pytorch parallelproj projection layer

In this example, we show how to define custom PyTorch layers that can be used in a feed-forward neural network that includes parallelproj forward and back projections (or any other LinearOperator) and that are compatible with PyTorch’s autograd engine.

https://mybinder.org/badge_logo.svg
16 from __future__ import annotations
17
18 import array_api_compat.torch as torch
19 import matplotlib.pyplot as plt
20 import parallelproj
21 from array_api_compat import device
22
23
# Choose where arrays live and where all calculations run:
# the CUDA GPU when parallelproj reports CUDA support, otherwise the CPU.
dev = "cuda" if parallelproj.cuda_present else "cpu"

Set up the forward projection layer

We subclass torch.autograd.Function to define a custom PyTorch layer that is compatible with PyTorch’s autograd engine. See also: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

 40 class LinearSingleChannelOperator(torch.autograd.Function):
 41     """
 42     Function representing a linear operator acting on a mini batch of single channel images
 43     """
 44
 45     @staticmethod
 46     def forward(
 47         ctx, x: torch.Tensor, operator: parallelproj.LinearOperator
 48     ) -> torch.Tensor:
 49         """forward pass of the linear operator
 50
 51         Parameters
 52         ----------
 53         ctx : context object
 54             that can be used to store information for the backward pass
 55         x : torch.Tensor
 56             mini batch of 3D images with dimension (batch_size, 1, num_voxels_x, num_voxels_y, num_voxels_z)
 57         operator : parallelproj.LinearOperator
 58             linear operator that can act on a single 3D image
 59
 60         Returns
 61         -------
 62         torch.Tensor
 63             mini batch of 3D images with dimension (batch_size, opertor.out_shape)
 64         """
 65
 66         # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
 67         ctx.set_materialize_grads(False)
 68         ctx.operator = operator
 69
 70         batch_size = x.shape[0]
 71         y = torch.zeros(
 72             (batch_size,) + operator.out_shape, dtype=x.dtype, device=device(x)
 73         )
 74
 75         # loop over all samples in the batch and apply linear operator
 76         # to the first channel
 77         for i in range(batch_size):
 78             y[i, ...] = operator(x[i, 0, ...].detach())
 79
 80         return y
 81
 82     @staticmethod
 83     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
 84         """backward pass of the forward pass
 85
 86         Parameters
 87         ----------
 88         ctx : context object
 89             that can be used to obtain information from the forward pass
 90         grad_output : torch.Tensor
 91             mini batch of dimension (batch_size, operator.out_shape)
 92
 93         Returns
 94         -------
 95         torch.Tensor, None
 96             mini batch of 3D images with dimension (batch_size, 1, opertor.in_shape)
 97         """
 98
 99         # For details on how to implement the backward pass, see
100         # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
101
102         # since forward takes two input arguments (x, operator)
103         # we have to return two arguments (the latter is None)
104         if grad_output is None:
105             return None, None
106         else:
107             operator = ctx.operator
108
109             batch_size = grad_output.shape[0]
110             x = torch.zeros(
111                 (batch_size, 1) + operator.in_shape,
112                 dtype=grad_output.dtype,
113                 device=device(grad_output),
114             )
115
116             # loop over all samples in the batch and apply linear operator
117             # to the first channel
118             for i in range(batch_size):
119                 x[i, 0, ...] = operator.adjoint(grad_output[i, ...].detach())
120
121             return x, None

Set up the back projection layer

We subclass torch.autograd.Function to define a custom PyTorch layer that is compatible with PyTorch’s autograd engine. See also: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

class AdjointLinearSingleChannelOperator(torch.autograd.Function):
    """
    Function representing the adjoint of a linear operator acting on a mini batch of single channel images.

    The forward pass applies ``operator.adjoint`` to every sample of the input
    batch and stores the result in the first (and only) channel of the output.
    The backward pass applies ``operator`` to the first channel of every
    sample of the incoming gradient.

    The backward pass works on detached tensors, so it is only differentiable
    once. It is decorated with ``once_differentiable``, so double backward
    raises an error instead of silently producing wrong higher-order
    gradients.
    """

    @staticmethod
    def forward(
        ctx, x: torch.Tensor, operator: "parallelproj.LinearOperator"
    ) -> torch.Tensor:
        """forward pass of the adjoint of the linear operator

        Parameters
        ----------
        ctx : context object
            that can be used to store information for the backward pass
        x : torch.Tensor
            mini batch of operator outputs (e.g. projection data) with shape
            (batch_size,) + operator.out_shape -- note: no channel axis
        operator : parallelproj.LinearOperator
            linear operator whose adjoint is applied to every sample

        Returns
        -------
        torch.Tensor
            mini batch of single channel images with shape
            (batch_size, 1) + operator.in_shape
        """

        # gradients that are not needed arrive in backward as None instead of
        # zero-filled tensors (handled explicitly in backward)
        ctx.set_materialize_grads(False)
        # the operator is not a tensor, so it is stored directly on ctx
        ctx.operator = operator

        batch_size = x.shape[0]
        y = torch.zeros(
            (batch_size, 1) + operator.in_shape, dtype=x.dtype, device=x.device
        )

        # loop over all samples in the batch, apply the adjoint and store the
        # result in the first channel
        for i in range(batch_size):
            y[i, 0, ...] = operator.adjoint(x[i, ...].detach())

        return y

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
        """backward pass of the forward pass

        Parameters
        ----------
        ctx : context object
            that can be used to obtain information from the forward pass
        grad_output : torch.Tensor or None
            mini batch of gradients with shape (batch_size, 1) + operator.in_shape,
            or None if no gradient flows back into this layer

        Returns
        -------
        torch.Tensor or None, None
            mini batch with shape (batch_size,) + operator.out_shape
            (None if grad_output is None), and None for the
            (non-differentiable) operator argument
        """

        # For details on how to implement the backward pass, see
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use

        # since forward takes two input arguments (x, operator)
        # we have to return two arguments (the latter is None)
        if grad_output is None:
            return None, None

        operator = ctx.operator

        batch_size = grad_output.shape[0]
        x = torch.zeros(
            (batch_size,) + operator.out_shape,
            dtype=grad_output.dtype,
            device=grad_output.device,
        )

        # the adjoint of the adjoint is the operator itself: apply it to the
        # first channel of every sample of the incoming gradient
        for i in range(batch_size):
            x[i, ...] = operator(grad_output[i, 0, ...].detach())

        return x, None

Set up a minimal non-TOF PET projector

We set up a minimal non-TOF PET projector for a small scanner with three rings.

# scanner made of a regular 12-sided polygon with 6 LOR endpoints per side,
# repeated in `num_rings` rings along the symmetry axis
num_rings = 3
scanner = parallelproj.RegularPolygonPETScannerGeometry(
    torch,
    dev,
    radius=35.0,
    num_sides=12,
    num_lor_endpoints_per_side=6,
    lor_spacing=3.0,
    # create the ring positions on the same device as all other arrays
    ring_positions=torch.linspace(-4, 4, num_rings, device=dev),
    symmetry_axis=1,
)

# setup the LOR descriptor that defines the sinogram
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    radial_trim=10,
    max_ring_difference=1,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

# non-TOF projector mapping a (20, 5, 20) image to the sinogram defined above
proj = parallelproj.RegularPolygonPETProjector(
    lor_desc, img_shape=(20, 5, 20), voxel_size=(2.0, 2.0, 2.0)
)

Define a mini batch of input and output tensors

batch_size = 2

# keyword arguments shared by both random mini batches: float32 on the
# selected device, tracked by autograd
_batch_kwargs = dict(device=dev, dtype=torch.float32, requires_grad=True)

# mini batch of single channel images: (batch_size, 1) + proj.in_shape
x = torch.rand((batch_size, 1) + proj.in_shape, **_batch_kwargs)

# mini batch of projection data: (batch_size,) + proj.out_shape
y = torch.rand((batch_size,) + proj.out_shape, **_batch_kwargs)

Define the forward and backward projection layers

# the .apply methods of the custom autograd Functions act as the layers
fwd_op_layer = LinearSingleChannelOperator.apply
adjoint_op_layer = AdjointLinearSingleChannelOperator.apply

# forward projection A x of the image mini batch
f1 = fwd_op_layer(x, proj)
print("forward projection (Ax) .:", f1.shape, type(f1), device(f1))

# back projection A^T y of the projection data mini batch
b1 = adjoint_op_layer(y, proj)
print("back projection (A^T y) .:", b1.shape, type(b1), device(b1))

# composition A^T A x, computed from a fresh forward projection of x
fx = fwd_op_layer(x, proj)
fb1 = adjoint_op_layer(fx, proj)
print("back + forward projection (A^TAx) .:", fb1.shape, type(fb1), device(fb1))
forward projection (Ax) .: torch.Size([2, 53, 36, 7]) <class 'torch.Tensor'> cpu
back projection (A^T y) .: torch.Size([2, 1, 20, 5, 20]) <class 'torch.Tensor'> cpu
back + forward projection (A^TAx) .: torch.Size([2, 1, 20, 5, 20]) <class 'torch.Tensor'> cpu

Define a dummy loss function and trigger the backpropagation

# dummy scalar loss: sum of the squared entries of A^T A x
dummy_loss = fb1.pow(2).sum()

# run autograd, which invokes the backward methods of both custom layers
dummy_loss.backward()

print(f" backpropagted gradient {x.grad}")
backpropagted gradient tensor([[[[[20468784.0000, 22242056.0000, 23424154.0000,  ...,
           23413710.0000, 22221178.0000, 20439522.0000],
          [ 7623251.0000,  9889296.0000, 12323244.0000,  ...,
           12241651.0000,  9880237.0000,  7581908.0000],
          [32468618.0000, 34818896.0000, 36076256.0000,  ...,
           36298112.0000, 35010348.0000, 32578980.0000],
          [ 7625693.5000,  9906050.0000, 12334679.0000,  ...,
           12247036.0000,  9866977.0000,  7560809.0000],
          [20257866.0000, 22031570.0000, 23141980.0000,  ...,
           22865192.0000, 21744388.0000, 19985262.0000]],

         [[22231584.0000, 23776200.0000, 24752482.0000,  ...,
           24681202.0000, 23744882.0000, 22235108.0000],
          [ 9927428.0000, 12830780.0000, 15452092.0000,  ...,
           15339719.0000, 12754393.0000,  9844993.0000],
          [34839568.0000, 36719244.0000, 37658132.0000,  ...,
           37768556.0000, 36907952.0000, 35008496.0000],
          [ 9940504.0000, 12841288.0000, 15460707.0000,  ...,
           15419873.0000, 12775093.0000,  9838741.0000],
          [22014958.0000, 23488950.0000, 24407800.0000,  ...,
           24103142.0000, 23208954.0000, 21775234.0000]],

         [[23409278.0000, 24717122.0000, 25467456.0000,  ...,
           25424314.0000, 24672354.0000, 23351662.0000],
          [12321049.0000, 15478504.0000, 18471516.0000,  ...,
           18275326.0000, 15294733.0000, 12247346.0000],
          [36223832.0000, 37673708.0000, 38266076.0000,  ...,
           38341928.0000, 37820712.0000, 36284380.0000],
          [12340143.0000, 15501313.0000, 18489270.0000,  ...,
           18432228.0000, 15411611.0000, 12289568.0000],
          [23175564.0000, 24401020.0000, 25092264.0000,  ...,
           24823130.0000, 24152770.0000, 22917694.0000]],

         ...,

         [[23264558.0000, 24502806.0000, 25176430.0000,  ...,
           24933514.0000, 24213324.0000, 22986894.0000],
          [12370688.0000, 15464183.0000, 18511452.0000,  ...,
           18307956.0000, 15359501.0000, 12260117.0000],
          [36872384.0000, 38527688.0000, 39155272.0000,  ...,
           38755828.0000, 38152724.0000, 36634176.0000],
          [12491102.0000, 15670714.0000, 18763482.0000,  ...,
           18559618.0000, 15602688.0000, 12433445.0000],
          [23400622.0000, 24718656.0000, 25441352.0000,  ...,
           24938294.0000, 24277282.0000, 23049118.0000]],

         [[22126174.0000, 23597442.0000, 24478932.0000,  ...,
           24234654.0000, 23330004.0000, 21868746.0000],
          [ 9939735.0000, 12869183.0000, 15502543.0000,  ...,
           15360977.0000, 12812474.0000,  9925474.0000],
          [35561124.0000, 37556140.0000, 38520688.0000,  ...,
           38232848.0000, 37336524.0000, 35372740.0000],
          [10029829.0000, 13018194.0000, 15717819.0000,  ...,
           15561702.0000, 12974250.0000, 10045313.0000],
          [22294504.0000, 23833140.0000, 24734256.0000,  ...,
           24231738.0000, 23354164.0000, 21893926.0000]],

         [[20361952.0000, 22104834.0000, 23232602.0000,  ...,
           22987006.0000, 21895898.0000, 20165432.0000],
          [ 7665804.5000,  9964919.0000, 12338956.0000,  ...,
           12330028.0000,  9902195.0000,  7617023.5000],
          [33074716.0000, 35521532.0000, 36911780.0000,  ...,
           36694152.0000, 35431760.0000, 32959992.0000],
          [ 7714764.0000, 10064633.0000, 12475198.0000,  ...,
           12463844.0000, 10015962.0000,  7694401.5000],
          [20518130.0000, 22321674.0000, 23456176.0000,  ...,
           22972184.0000, 21892578.0000, 20150178.0000]]]],



       [[[[20787864.0000, 22657280.0000, 23764112.0000,  ...,
           23959648.0000, 22792720.0000, 20945532.0000],
          [ 7733161.5000, 10078768.0000, 12543189.0000,  ...,
           12606887.0000, 10185013.0000,  7796861.0000],
          [33223798.0000, 35685564.0000, 36941016.0000,  ...,
           37435784.0000, 36211080.0000, 33695564.0000],
          [ 7755915.0000, 10074892.0000, 12546208.0000,  ...,
           12766197.0000, 10295520.0000,  7881976.5000],
          [20432564.0000, 22243408.0000, 23403270.0000,  ...,
           23752880.0000, 22558192.0000, 20732876.0000]],

         [[22657698.0000, 24181532.0000, 25107258.0000,  ...,
           25270568.0000, 24344674.0000, 22783732.0000],
          [10119006.0000, 13083118.0000, 15707118.0000,  ...,
           15850157.0000, 13181271.0000, 10152499.0000],
          [35714356.0000, 37627344.0000, 38582952.0000,  ...,
           39064020.0000, 38258856.0000, 36253536.0000],
          [10108744.0000, 13041912.0000, 15681072.0000,  ...,
           16097965.0000, 13351507.0000, 10273107.0000],
          [22205600.0000, 23694030.0000, 24646084.0000,  ...,
           25058680.0000, 24121472.0000, 22588932.0000]],

         [[23840832.0000, 25134102.0000, 25804708.0000,  ...,
           25923704.0000, 25200310.0000, 23829900.0000],
          [12558097.0000, 15763003.0000, 18705270.0000,  ...,
           18829082.0000, 15809293.0000, 12612810.0000],
          [37104252.0000, 38627436.0000, 39234340.0000,  ...,
           39633700.0000, 39148108.0000, 37465892.0000],
          [12523459.0000, 15711688.0000, 18756298.0000,  ...,
           19182872.0000, 16075050.0000, 12805176.0000],
          [23355792.0000, 24593318.0000, 25285372.0000,  ...,
           25782488.0000, 25072072.0000, 23736660.0000]],

         ...,

         [[23749624.0000, 25096342.0000, 25839486.0000,  ...,
           26199408.0000, 25487242.0000, 24125230.0000],
          [12544147.0000, 15735324.0000, 18804286.0000,  ...,
           18908690.0000, 15925813.0000, 12710122.0000],
          [37244920.0000, 38898168.0000, 39477816.0000,  ...,
           39537836.0000, 38920464.0000, 37320452.0000],
          [12714554.0000, 15909244.0000, 18987640.0000,  ...,
           18830898.0000, 15759336.0000, 12549982.0000],
          [23671914.0000, 24985466.0000, 25744194.0000,  ...,
           25856632.0000, 25108698.0000, 23791328.0000]],

         [[22603378.0000, 24145572.0000, 25132738.0000,  ...,
           25571714.0000, 24661454.0000, 23033040.0000],
          [10079074.0000, 13089359.0000, 15804201.0000,  ...,
           15925462.0000, 13313573.0000, 10295177.0000],
          [36039012.0000, 38013104.0000, 38938008.0000,  ...,
           38971340.0000, 38050628.0000, 36043792.0000],
          [10212288.0000, 13244966.0000, 15936608.0000,  ...,
           15744201.0000, 13107703.0000, 10170716.0000],
          [22519602.0000, 24035972.0000, 24982896.0000,  ...,
           25183824.0000, 24222038.0000, 22642524.0000]],

         [[20760676.0000, 22612502.0000, 23821478.0000,  ...,
           24220228.0000, 23099918.0000, 21184484.0000],
          [ 7743985.0000, 10117988.0000, 12568213.0000,  ...,
           12777044.0000, 10274112.0000,  7870091.5000],
          [33519828.0000, 36040188.0000, 37398364.0000,  ...,
           37349576.0000, 36073080.0000, 33552928.0000],
          [ 7846918.0000, 10249990.0000, 12704178.0000,  ...,
           12606667.0000, 10148989.0000,  7825096.5000],
          [20721584.0000, 22547502.0000, 23719466.0000,  ...,
           23861136.0000, 22712056.0000, 20865760.0000]]]]])

Check whether the gradients are calculated correctly

We use pytorch’s gradcheck function to check whether the implementation of the backward pass, needed to calculate the gradients, is correct.

This test can be slow, which is why we only execute it on the GPU. Note that parallelproj’s projectors use single precision, which is why we have to use larger values for atol and rtol.

# gradcheck compares the analytical gradients of the custom backward passes
# with finite-difference estimates (large eps/atol/rtol because the
# projectors compute in single precision)
if dev == "cpu":
    print("skipping (slow) gradient checks on cpu")
else:
    print("Running forward projection layer gradient test")
    grad_test_fwd = torch.autograd.gradcheck(
        fwd_op_layer, (x, proj), eps=1e-1, atol=1e-3, rtol=1e-3
    )

    print("Running adjoint projection layer gradient test")
    # keep the adjoint result under its own name instead of overwriting
    # the forward projection result
    grad_test_adjoint = torch.autograd.gradcheck(
        adjoint_op_layer, (y, proj), eps=1e-1, atol=1e-3, rtol=1e-3
    )
skipping (slow) gradient checks on cpu

Visualize the scanner geometry and image FOV

# 3D view of the scanner geometry together with the image FOV
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "3d"})
proj.show_geometry(ax)
fig.tight_layout()
fig.show()
01 run projection layer

Total running time of the script: (0 minutes 1.217 seconds)

Related examples

TOF OSEM with projection data

TOF OSEM with projection data

TOF-MLEM with projection data

TOF-MLEM with projection data

PET TOF sinogram projector

PET TOF sinogram projector

Non-TOF and TOF projections using a modularized (block) PET scanner geometry

Non-TOF and TOF projections using a modularized (block) PET scanner geometry

Gallery generated by Sphinx-Gallery