Note
Go to the end to download the full example code.
pytorch parallelproj projection layer
In this example, we show how to define a custom pytorch layer that can be used
to build a feed forward neural network that includes parallelproj forward and
backward projections (or any LinearOperator) and that is compatible with pytorch’s
autograd engine.
16 from __future__ import annotations
17
18 import array_api_compat.torch as torch
19 import matplotlib.pyplot as plt
20 import parallelproj
21 from array_api_compat import device
22
23
# Array device on which all tensors are created and all projections run:
# the CUDA GPU if parallelproj reports CUDA support, otherwise the CPU.
dev = "cuda" if parallelproj.cuda_present else "cpu"
Set up the forward projection layer
We subclass torch.autograd.Function to define a custom pytorch layer
that is compatible with pytorch’s autograd engine.
see also: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
class LinearSingleChannelOperator(torch.autograd.Function):
    """
    Function representing a linear operator acting on a mini batch of single channel images
    """

    @staticmethod
    def forward(
        ctx, x: torch.Tensor, operator: "parallelproj.LinearOperator"
    ) -> torch.Tensor:
        """forward pass of the linear operator

        Parameters
        ----------
        ctx : context object
            that can be used to store information for the backward pass
        x : torch.Tensor
            mini batch of single channel 3D images with shape
            (batch_size, 1, *operator.in_shape)
        operator : parallelproj.LinearOperator
            linear operator that can act on a single 3D image

        Returns
        -------
        torch.Tensor
            mini batch of operator outputs with shape
            (batch_size, *operator.out_shape)
        """

        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
        # With materialization disabled, backward may receive None instead of
        # a zero-filled gradient; backward handles that case explicitly.
        ctx.set_materialize_grads(False)
        # The operator is not a tensor, so it is stored directly on ctx
        # (save_for_backward is reserved for tensors).
        ctx.operator = operator

        batch_size = x.shape[0]
        y = torch.zeros(
            (batch_size,) + tuple(operator.out_shape), dtype=x.dtype, device=x.device
        )

        # loop over all samples in the batch and apply the linear operator
        # to the first (and only) channel
        for i in range(batch_size):
            y[i, ...] = operator(x[i, 0, ...].detach())

        return y

    @staticmethod
    def backward(
        ctx, grad_output: "torch.Tensor | None"
    ) -> "tuple[torch.Tensor | None, None]":
        """backward pass of the forward pass

        Parameters
        ----------
        ctx : context object
            that can be used to obtain information from the forward pass
        grad_output : torch.Tensor or None
            mini batch with shape (batch_size, *operator.out_shape), or None
            when no gradient flows into this output

        Returns
        -------
        torch.Tensor or None, None
            gradient w.r.t. x with shape (batch_size, 1, *operator.in_shape)
            (None if grad_output is None), and None for the operator argument
        """

        # For details on how to implement the backward pass, see
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use

        # since forward takes two input arguments (x, operator)
        # we have to return two values (the latter is always None)
        if grad_output is None:
            return None, None

        operator = ctx.operator

        batch_size = grad_output.shape[0]
        x = torch.zeros(
            (batch_size, 1) + tuple(operator.in_shape),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )

        # the gradient of x -> A x is A^T applied to the incoming gradient,
        # computed sample by sample and written into the single channel
        for i in range(batch_size):
            x[i, 0, ...] = operator.adjoint(grad_output[i, ...].detach())

        return x, None
Set up the back projection layer
We subclass torch.autograd.Function to define a custom pytorch layer
that is compatible with pytorch’s autograd engine.
see also: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
class AdjointLinearSingleChannelOperator(torch.autograd.Function):
    """
    Function representing the adjoint of a linear operator acting on a mini batch of single channel images
    """

    @staticmethod
    def forward(
        ctx, x: torch.Tensor, operator: "parallelproj.LinearOperator"
    ) -> torch.Tensor:
        """forward pass of the adjoint of the linear operator

        Parameters
        ----------
        ctx : context object
            that can be used to store information for the backward pass
        x : torch.Tensor
            mini batch in the operator's output space with shape
            (batch_size, *operator.out_shape) -- note: no channel dimension
        operator : parallelproj.LinearOperator
            linear operator that can act on a single 3D image

        Returns
        -------
        torch.Tensor
            mini batch of single channel 3D images with shape
            (batch_size, 1, *operator.in_shape)
        """

        # With materialization disabled, backward may receive None instead of
        # a zero-filled gradient; backward handles that case explicitly.
        ctx.set_materialize_grads(False)
        # The operator is not a tensor, so it is stored directly on ctx.
        ctx.operator = operator

        batch_size = x.shape[0]
        y = torch.zeros(
            (batch_size, 1) + tuple(operator.in_shape), dtype=x.dtype, device=x.device
        )

        # loop over all samples in the batch, apply the adjoint and
        # write the result into the single output channel
        for i in range(batch_size):
            y[i, 0, ...] = operator.adjoint(x[i, ...].detach())

        return y

    @staticmethod
    def backward(
        ctx, grad_output: "torch.Tensor | None"
    ) -> "tuple[torch.Tensor | None, None]":
        """backward pass of the forward pass

        Parameters
        ----------
        ctx : context object
            that can be used to obtain information from the forward pass
        grad_output : torch.Tensor or None
            mini batch with shape (batch_size, 1, *operator.in_shape), or None
            when no gradient flows into this output

        Returns
        -------
        torch.Tensor or None, None
            gradient w.r.t. x with shape (batch_size, *operator.out_shape)
            (None if grad_output is None), and None for the operator argument
        """

        # For details on how to implement the backward pass, see
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use

        # since forward takes two input arguments (x, operator)
        # we have to return two values (the latter is always None)
        if grad_output is None:
            return None, None

        operator = ctx.operator

        batch_size = grad_output.shape[0]
        x = torch.zeros(
            (batch_size,) + tuple(operator.out_shape),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )

        # the gradient of y -> A^T y is A applied to the incoming gradient,
        # read from the single channel of each sample
        for i in range(batch_size):
            x[i, ...] = operator(grad_output[i, 0, ...].detach())

        return x, None
Set up a minimal non-TOF PET projector
We set up a minimal non-TOF PET projector for a small scanner with three rings.
num_rings = 3

# Small regular-polygon scanner whose rings are placed evenly between
# -4 and 4 along the symmetry axis (axis 1).
scanner = parallelproj.RegularPolygonPETScannerGeometry(
    torch,
    dev,
    radius=35.0,
    num_sides=12,
    num_lor_endpoints_per_side=6,
    lor_spacing=3.0,
    ring_positions=torch.linspace(-4, 4, num_rings),
    symmetry_axis=1,
)

# LOR descriptor defining the sinogram: trims 10 radial bins, allows a ring
# difference of at most 1 and uses the RVP spatial axis order.
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
    max_ring_difference=1,
    radial_trim=10,
)

# Projector for a 20 x 5 x 20 image with 2 mm isotropic voxels.
proj = parallelproj.RegularPolygonPETProjector(
    lor_desc,
    voxel_size=(2.0, 2.0, 2.0),
    img_shape=(20, 5, 20),
)
Define a mini batch of input and output tensors
batch_size = 2

# Random mini batch in image space, shape (batch_size, 1, *proj.in_shape).
# It is a leaf tensor that requires gradients, so x.grad is filled by backward().
x = torch.rand(
    (batch_size, 1, *proj.in_shape),
    dtype=torch.float32,
    device=dev,
    requires_grad=True,
)

# Random mini batch in sinogram space, shape (batch_size, *proj.out_shape).
y = torch.rand(
    (batch_size, *proj.out_shape),
    dtype=torch.float32,
    device=dev,
    requires_grad=True,
)
Define the forward and backward projection layers
# Custom autograd Functions are invoked through their .apply entry point.
fwd_op_layer = LinearSingleChannelOperator.apply
adjoint_op_layer = AdjointLinearSingleChannelOperator.apply

# forward projection of the image batch: A x
f1 = fwd_op_layer(x, proj)
print("forward projection (Ax) .:", f1.shape, type(f1), device(f1))

# back projection of the sinogram batch: A^T y
b1 = adjoint_op_layer(y, proj)
print("back projection (A^T y) .:", b1.shape, type(b1), device(b1))

# chain both layers (A^T A x) so that backpropagation passes through each
fb1 = adjoint_op_layer(
    fwd_op_layer(x, proj),
    proj,
)
print("back + forward projection (A^TAx) .:", fb1.shape, type(fb1), device(fb1))
forward projection (Ax) .: torch.Size([2, 53, 36, 7]) <class 'torch.Tensor'> cpu
back projection (A^T y) .: torch.Size([2, 1, 20, 5, 20]) <class 'torch.Tensor'> cpu
back + forward projection (A^TAx) .: torch.Size([2, 1, 20, 5, 20]) <class 'torch.Tensor'> cpu
Define a dummy loss function and trigger the backpropagation
# scalar dummy loss: sum of squares of the A^T A x mini batch
dummy_loss = fb1.pow(2).sum()

# backpropagate through both custom layers; this populates x.grad
dummy_loss.backward()

print(f" backpropagted gradient {x.grad}")
backpropagted gradient tensor([[[[[20468784.0000, 22242056.0000, 23424154.0000, ...,
23413710.0000, 22221178.0000, 20439522.0000],
[ 7623251.0000, 9889296.0000, 12323244.0000, ...,
12241651.0000, 9880237.0000, 7581908.0000],
[32468618.0000, 34818896.0000, 36076256.0000, ...,
36298112.0000, 35010348.0000, 32578980.0000],
[ 7625693.5000, 9906050.0000, 12334679.0000, ...,
12247036.0000, 9866977.0000, 7560809.0000],
[20257866.0000, 22031570.0000, 23141980.0000, ...,
22865192.0000, 21744388.0000, 19985262.0000]],
[[22231584.0000, 23776200.0000, 24752482.0000, ...,
24681202.0000, 23744882.0000, 22235108.0000],
[ 9927428.0000, 12830780.0000, 15452092.0000, ...,
15339719.0000, 12754393.0000, 9844993.0000],
[34839568.0000, 36719244.0000, 37658132.0000, ...,
37768556.0000, 36907952.0000, 35008496.0000],
[ 9940504.0000, 12841288.0000, 15460707.0000, ...,
15419873.0000, 12775093.0000, 9838741.0000],
[22014958.0000, 23488950.0000, 24407800.0000, ...,
24103142.0000, 23208954.0000, 21775234.0000]],
[[23409278.0000, 24717122.0000, 25467456.0000, ...,
25424314.0000, 24672354.0000, 23351662.0000],
[12321049.0000, 15478504.0000, 18471516.0000, ...,
18275326.0000, 15294733.0000, 12247346.0000],
[36223832.0000, 37673708.0000, 38266076.0000, ...,
38341928.0000, 37820712.0000, 36284380.0000],
[12340143.0000, 15501313.0000, 18489270.0000, ...,
18432228.0000, 15411611.0000, 12289568.0000],
[23175564.0000, 24401020.0000, 25092264.0000, ...,
24823130.0000, 24152770.0000, 22917694.0000]],
...,
[[23264558.0000, 24502806.0000, 25176430.0000, ...,
24933514.0000, 24213324.0000, 22986894.0000],
[12370688.0000, 15464183.0000, 18511452.0000, ...,
18307956.0000, 15359501.0000, 12260117.0000],
[36872384.0000, 38527688.0000, 39155272.0000, ...,
38755828.0000, 38152724.0000, 36634176.0000],
[12491102.0000, 15670714.0000, 18763482.0000, ...,
18559618.0000, 15602688.0000, 12433445.0000],
[23400622.0000, 24718656.0000, 25441352.0000, ...,
24938294.0000, 24277282.0000, 23049118.0000]],
[[22126174.0000, 23597442.0000, 24478932.0000, ...,
24234654.0000, 23330004.0000, 21868746.0000],
[ 9939735.0000, 12869183.0000, 15502543.0000, ...,
15360977.0000, 12812474.0000, 9925474.0000],
[35561124.0000, 37556140.0000, 38520688.0000, ...,
38232848.0000, 37336524.0000, 35372740.0000],
[10029829.0000, 13018194.0000, 15717819.0000, ...,
15561702.0000, 12974250.0000, 10045313.0000],
[22294504.0000, 23833140.0000, 24734256.0000, ...,
24231738.0000, 23354164.0000, 21893926.0000]],
[[20361952.0000, 22104834.0000, 23232602.0000, ...,
22987006.0000, 21895898.0000, 20165432.0000],
[ 7665804.5000, 9964919.0000, 12338956.0000, ...,
12330028.0000, 9902195.0000, 7617023.5000],
[33074716.0000, 35521532.0000, 36911780.0000, ...,
36694152.0000, 35431760.0000, 32959992.0000],
[ 7714764.0000, 10064633.0000, 12475198.0000, ...,
12463844.0000, 10015962.0000, 7694401.5000],
[20518130.0000, 22321674.0000, 23456176.0000, ...,
22972184.0000, 21892578.0000, 20150178.0000]]]],
[[[[20787864.0000, 22657280.0000, 23764112.0000, ...,
23959648.0000, 22792720.0000, 20945532.0000],
[ 7733161.5000, 10078768.0000, 12543189.0000, ...,
12606887.0000, 10185013.0000, 7796861.0000],
[33223798.0000, 35685564.0000, 36941016.0000, ...,
37435784.0000, 36211080.0000, 33695564.0000],
[ 7755915.0000, 10074892.0000, 12546208.0000, ...,
12766197.0000, 10295520.0000, 7881976.5000],
[20432564.0000, 22243408.0000, 23403270.0000, ...,
23752880.0000, 22558192.0000, 20732876.0000]],
[[22657698.0000, 24181532.0000, 25107258.0000, ...,
25270568.0000, 24344674.0000, 22783732.0000],
[10119006.0000, 13083118.0000, 15707118.0000, ...,
15850157.0000, 13181271.0000, 10152499.0000],
[35714356.0000, 37627344.0000, 38582952.0000, ...,
39064020.0000, 38258856.0000, 36253536.0000],
[10108744.0000, 13041912.0000, 15681072.0000, ...,
16097965.0000, 13351507.0000, 10273107.0000],
[22205600.0000, 23694030.0000, 24646084.0000, ...,
25058680.0000, 24121472.0000, 22588932.0000]],
[[23840832.0000, 25134102.0000, 25804708.0000, ...,
25923704.0000, 25200310.0000, 23829900.0000],
[12558097.0000, 15763003.0000, 18705270.0000, ...,
18829082.0000, 15809293.0000, 12612810.0000],
[37104252.0000, 38627436.0000, 39234340.0000, ...,
39633700.0000, 39148108.0000, 37465892.0000],
[12523459.0000, 15711688.0000, 18756298.0000, ...,
19182872.0000, 16075050.0000, 12805176.0000],
[23355792.0000, 24593318.0000, 25285372.0000, ...,
25782488.0000, 25072072.0000, 23736660.0000]],
...,
[[23749624.0000, 25096342.0000, 25839486.0000, ...,
26199408.0000, 25487242.0000, 24125230.0000],
[12544147.0000, 15735324.0000, 18804286.0000, ...,
18908690.0000, 15925813.0000, 12710122.0000],
[37244920.0000, 38898168.0000, 39477816.0000, ...,
39537836.0000, 38920464.0000, 37320452.0000],
[12714554.0000, 15909244.0000, 18987640.0000, ...,
18830898.0000, 15759336.0000, 12549982.0000],
[23671914.0000, 24985466.0000, 25744194.0000, ...,
25856632.0000, 25108698.0000, 23791328.0000]],
[[22603378.0000, 24145572.0000, 25132738.0000, ...,
25571714.0000, 24661454.0000, 23033040.0000],
[10079074.0000, 13089359.0000, 15804201.0000, ...,
15925462.0000, 13313573.0000, 10295177.0000],
[36039012.0000, 38013104.0000, 38938008.0000, ...,
38971340.0000, 38050628.0000, 36043792.0000],
[10212288.0000, 13244966.0000, 15936608.0000, ...,
15744201.0000, 13107703.0000, 10170716.0000],
[22519602.0000, 24035972.0000, 24982896.0000, ...,
25183824.0000, 24222038.0000, 22642524.0000]],
[[20760676.0000, 22612502.0000, 23821478.0000, ...,
24220228.0000, 23099918.0000, 21184484.0000],
[ 7743985.0000, 10117988.0000, 12568213.0000, ...,
12777044.0000, 10274112.0000, 7870091.5000],
[33519828.0000, 36040188.0000, 37398364.0000, ...,
37349576.0000, 36073080.0000, 33552928.0000],
[ 7846918.0000, 10249990.0000, 12704178.0000, ...,
12606667.0000, 10148989.0000, 7825096.5000],
[20721584.0000, 22547502.0000, 23719466.0000, ...,
23861136.0000, 22712056.0000, 20865760.0000]]]]])
Check whether the gradients are calculated correctly
We use pytorch’s gradcheck function to check whether the implementation of the backward pass, needed to calculate the gradients, is correct.
This test can be slow, which is why we only execute it on the GPU. Note that parallelproj’s projectors use single precision, which is why we have to use a larger finite-difference step (eps) as well as larger atol and rtol values.
# Finite-difference gradient checks are slow, so they are only run on the GPU.
if dev == "cpu":
    print("skipping (slow) gradient checks on cpu")
else:
    # parallelproj projects in single precision, hence the large
    # finite-difference step (eps) and the loose tolerances below.
    # gradcheck raises on a mismatch (raise_exception defaults to True),
    # so reaching the print statements means the check succeeded.
    print("Running forward projection layer gradient test")
    grad_test_fwd = torch.autograd.gradcheck(
        fwd_op_layer, (x, proj), eps=1e-1, atol=1e-3, rtol=1e-3
    )
    print(f"forward projection layer gradient test passed: {grad_test_fwd}")

    print("Running adjoint projection layer gradient test")
    # separate name so the forward result is not overwritten
    grad_test_adj = torch.autograd.gradcheck(
        adjoint_op_layer, (y, proj), eps=1e-1, atol=1e-3, rtol=1e-3
    )
    print(f"adjoint projection layer gradient test passed: {grad_test_adj}")
skipping (slow) gradient checks on cpu
Visualize the scanner geometry and image FOV
# Single 3D axes showing the scanner geometry together with the image FOV.
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "3d"})
proj.show_geometry(ax)
fig.tight_layout()
fig.show()

Total running time of the script: (0 minutes 1.217 seconds)
Related examples
Non-TOF and TOF projections using a modularized (block) PET scanner geometry