Note
Go to the end to download the full example code.
TOF OSEM with projection data
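This example demonstrates the use of the OSEM (ordered subsets expectation maximization) algorithm to minimize the negative Poisson log-likelihood function

\[ f(x) = \sum_{i=1}^m \bar{y}_i(x) - y_i \log(\bar{y}_i(x)) \]

subject to

\[ x \geq 0 \]

using the linear forward model

\[ \bar{y}(x) = A x + s \]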
Tip
parallelproj is python array API compatible meaning it supports different
array backends (e.g. numpy, cupy, torch, …) and devices (CPU or GPU).
Choose your preferred array API xp and device dev below.
30 from __future__ import annotations
31 from parallelproj import Array
32
33 import array_api_compat.numpy as xp
34
35 # import array_api_compat.cupy as xp
36 # import array_api_compat.torch as xp
37
38 import parallelproj
39 from array_api_compat import to_device
40 import array_api_compat.numpy as np
41 import matplotlib.pyplot as plt
42 import matplotlib.animation as animation
43 from copy import copy
44
# choose a device (CPU or CUDA GPU) matching the selected array backend xp
if "numpy" in xp.__name__:
    # using numpy, device must be cpu
    dev = "cpu"
elif "cupy" in xp.__name__:
    # using cupy, only cuda devices are possible
    dev = xp.cuda.Device(0)
elif "torch" in xp.__name__:
    # using torch valid choices are 'cpu' or 'cuda';
    # use 'cuda' only if parallelproj.cuda_present is set
    dev = "cuda" if parallelproj.cuda_present else "cpu"
else:
    # fail early with a clear message instead of a NameError on `dev` later on
    raise ValueError(f"unsupported array backend {xp.__name__!r}")
Setup of the forward model \(\bar{y}(x) = A x + s\)
We set up a linear forward operator \(A\) consisting of an image-based resolution model, a (TOF or non-TOF) PET projector and an attenuation model
Note
The OSEM implementation below works with all linear operators that
subclass LinearOperator (e.g. the high-level projectors).
setup the LOR descriptor that defines the sinogram
img_shape = (40, 40, 8)
voxel_size = (2.0, 2.0, 2.0)

# LOR descriptor defining the sinogram geometry (sinogram axes in RVP order)
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    radial_trim=10,
    max_ring_difference=2,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

# projector mapping images of shape img_shape to sinograms of lor_desc
proj = parallelproj.RegularPolygonPETProjector(
    lor_desc, img_shape=img_shape, voxel_size=voxel_size
)

# simple test image: background of ones plus a few "hot rods" with value 5
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32)
c0, c1 = proj.in_shape[0] // 2, proj.in_shape[1] // 2
x_true[(c0 - 2) : (c0 + 2), (c1 - 2) : (c1 + 2), :] = 5.0
x_true[4, c1, 2:] = 5.0
x_true[c0, 4, :-2] = 5.0

# set a 2-voxel wide border along the first two axes to zero
x_true[:2] = 0
x_true[-2:] = 0
x_true[:, :2] = 0
x_true[:, -2:] = 0
Attenuation image and sinogram setup
# attenuation image: 0.01 inside the object support (x_true > 0), 0 elsewhere
x_att = xp.astype(x_true > 0, xp.float32) * 0.01
# attenuation sinogram exp(-P x_att); computed before TOF is enabled, so it is non-TOF
att_sino = xp.exp(-proj(x_att))
Complete PET forward model setup
We combine an image-based resolution model, a non-TOF or TOF PET projector and an attenuation model into a single linear operator.
# enable TOF - comment out this statement if you want to run non-TOF
proj.tof_parameters = parallelproj.TOFParameters(
    num_tofbins=13, tofbin_width=12.0, sigma_tof=12.0
)

# the attenuation sinogram is always non-TOF, so the TOF case needs a
# dedicated operator (presumably applying the non-TOF factors to all TOF bins)
att_op = (
    parallelproj.TOFNonTOFElementwiseMultiplicationOperator(proj.out_shape, att_sino)
    if proj.tof
    else parallelproj.ElementwiseMultiplicationOperator(att_sino)
)

# image-based resolution model: a FWHM of 4.5 (units of voxel_size) converted
# to a Gaussian sigma (FWHM ~ 2.35 sigma) expressed in voxels
res_model = parallelproj.GaussianFilterOperator(
    proj.in_shape, sigma=4.5 / (2.35 * proj.voxel_size)
)

# full forward operator: resolution model -> projection -> attenuation
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model))
Simulation of projection data
We use the ground truth \(x_{true}\) defined above to simulate noise-free data, add a constant contamination, and generate noisy data \(y\) by adding Poisson noise.
# simulate the noise-free data of the linear part of the forward model
noise_free_data = pet_lin_op(x_true)

# generate a constant contamination sinogram whose value is 50% of the
# mean of the noise-free data
contamination = xp.full(
    noise_free_data.shape,
    0.5 * float(xp.mean(noise_free_data)),
    dtype=xp.float32,
    device=dev,
)

# expected data A x_true + s (updated in place)
noise_free_data += contamination

# draw Poisson noise with numpy on the CPU (fixed seed for reproducibility)
# and move the noisy counts back to the selected backend / device
np.random.seed(1)
y = xp.asarray(
    np.random.poisson(parallelproj.to_numpy_array(noise_free_data)),
    dtype=xp.float64,
    device=dev,
)
Splitting of the forward model into subsets \(A^k\)
Calculate the view numbers and slices for each subset. We will use the subset views to set up a sequence of projectors, each projecting only a subset of views. The slices can be used to extract the corresponding subsets from full data or correction sinograms.
num_subsets = 10

# views of every subset plus slices to extract the subsets from the
# (possibly TOF) data sinograms ...
subset_views, subset_slices = proj.lor_descriptor.get_distributed_views_and_slices(
    num_subsets, len(proj.out_shape)
)
# ... and slices for the 3-axis non-TOF sinograms (e.g. the attenuation sinogram)
_, subset_slices_non_tof = proj.lor_descriptor.get_distributed_views_and_slices(
    num_subsets, 3
)

# the projector is copied once per subset below, so drop its cached LOR endpoints first
proj.clear_cached_lor_endpoints()

# every subset forward operator consists of
# (1) image-based resolution model
# (2) subset projector
# (3) multiplication with the corresponding subset of the attenuation sinogram
pet_subset_linop_seq = []
for i in range(num_subsets):
    print(f"subset {i:02} containing views {subset_views[i]}")

    # shallow copy of the full projector restricted to the views of this subset
    subset_proj = copy(proj)
    subset_proj.views = subset_views[i]

    subset_att_op = (
        parallelproj.TOFNonTOFElementwiseMultiplicationOperator(
            subset_proj.out_shape, att_sino[subset_slices_non_tof[i]]
        )
        if subset_proj.tof
        else parallelproj.ElementwiseMultiplicationOperator(
            att_sino[subset_slices_non_tof[i]]
        )
    )

    pet_subset_linop_seq.append(
        parallelproj.CompositeLinearOperator([subset_att_op, subset_proj, res_model])
    )

pet_subset_linop_seq = parallelproj.LinearOperatorSequence(pet_subset_linop_seq)
subset 00 containing views [ 0 10 20 30 40 50 60 70 80]
subset 01 containing views [ 5 15 25 35 45 55 65 75 85]
subset 02 containing views [ 1 11 21 31 41 51 61 71 81]
subset 03 containing views [ 6 16 26 36 46 56 66 76 86]
subset 04 containing views [ 2 12 22 32 42 52 62 72 82]
subset 05 containing views [ 7 17 27 37 47 57 67 77 87]
subset 06 containing views [ 3 13 23 33 43 53 63 73 83]
subset 07 containing views [ 8 18 28 38 48 58 68 78 88]
subset 08 containing views [ 4 14 24 34 44 54 64 74 84]
subset 09 containing views [ 9 19 29 39 49 59 69 79 89]
EM update to minimize \(f(x)\)
The EM update that can be used in MLEM or OSEM is given by [DLR77] [SV82] [LC84] [HL94]

\[ x^+ = \frac{x}{A^H 1} A^H \frac{y}{A x + s} \]

to calculate the minimizer of \(f(x)\) iteratively.
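To monitor the convergence we calculate the relative cost

\[ \frac{f(x) - f(x^*)}{|f(x^*)|} \]

and the distance to the optimal point

\[ \frac{\|x - x^*\|}{\|x^*\|}. \]

(In this example only the cost \(f(x)\) of the final OSEM iterate is evaluated; no reference solution \(x^*\) is computed.)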
We setup a function that calculates a single MLEM/OSEM update given the current solution, a linear forward operator, data, contamination and the adjoint of ones.
def em_update(
    x_cur: Array,
    data: Array,
    op: parallelproj.LinearOperator,
    s: Array,
    adjoint_ones: Array,
) -> Array:
    """Perform one multiplicative EM (MLEM / OSEM subset) update.

    Computes ``x_cur * op.adjoint(data / (op(x_cur) + s)) / adjoint_ones``.

    Parameters
    ----------
    x_cur : Array
        current image estimate
    data : Array
        measured (noisy) data matching the output space of ``op``
    op : parallelproj.LinearOperator
        linear forward operator (full operator for MLEM, subset operator for OSEM)
    s : Array
        additive contamination in the output space of ``op``
    adjoint_ones : Array
        ``op.adjoint`` applied to an array of ones ("sensitivity" image)

    Returns
    -------
    Array
        updated image estimate
    """
    # expected data under the current estimate
    expected = op(x_cur) + s
    # back-project the ratio of measured to expected data
    correction = op.adjoint(data / expected)
    return x_cur * correction / adjoint_ones
Run the OSEM iterations
Note that the OSEM iterations are almost the same as the MLEM iterations. The only difference is that in every subset update, we pass an operator that projects a subset, a subset of the data and a subset of the contamination.
The “sensitivity” images are also calculated separately for each subset.
# number of OSEM iterations (full passes over all subsets), chosen such that
# roughly 20 subset updates are performed in total - but at least one pass
num_iter = max(1, 20 // len(pet_subset_linop_seq))

# initialize x with ones (the EM update is multiplicative, so voxels that
# start at zero would stay zero)
x = xp.ones(pet_lin_op.in_shape, dtype=xp.float64, device=dev)

# calculate the "sensitivity" images A_k^H 1 for all subsets k
# (use a dedicated loop variable so the image estimate x is not shadowed)
subset_adjoint_ones = [
    subset_op.adjoint(xp.ones(subset_op.out_shape, dtype=xp.float64, device=dev))
    for subset_op in pet_subset_linop_seq
]

# OSEM iterations: one EM update per subset, using only the subset of the
# data and of the contamination
for i in range(num_iter):
    for k, sl in enumerate(subset_slices):
        print(
            f"OSEM iteration {(i + 1):03} / {num_iter:03}"
            f" - subset {(k + 1):03} / {len(subset_slices):03}",
            end="\r",
        )
        x = em_update(
            x, y[sl], pet_subset_linop_seq[k], contamination[sl], subset_adjoint_ones[k]
        )
OSEM iteration 001 / 001 / 002
OSEM iteration 002 / 001 / 002
OSEM iteration 003 / 001 / 002
OSEM iteration 004 / 001 / 002
OSEM iteration 005 / 001 / 002
OSEM iteration 006 / 001 / 002
OSEM iteration 007 / 001 / 002
OSEM iteration 008 / 001 / 002
OSEM iteration 009 / 001 / 002
OSEM iteration 010 / 001 / 002
OSEM iteration 001 / 002 / 002
OSEM iteration 002 / 002 / 002
OSEM iteration 003 / 002 / 002
OSEM iteration 004 / 002 / 002
OSEM iteration 005 / 002 / 002
OSEM iteration 006 / 002 / 002
OSEM iteration 007 / 002 / 002
OSEM iteration 008 / 002 / 002
OSEM iteration 009 / 002 / 002
OSEM iteration 010 / 002 / 002
Calculation of the negative Poisson log-likelihood function of the reconstruction
# evaluate the negative Poisson log-likelihood f(x) of the final reconstruction
# (the constant term sum(log(y!)) is omitted, which is why the value can be
# negative); no reference solution is available here, so neither the relative
# cost nor the distance to the optimal point is computed
ybar = pet_lin_op(x) + contamination
cost = float(xp.sum(ybar - y * xp.log(ybar)))
print(
    f"\nOSEM cost {cost:.6E} after {num_iter:03} iterations with {num_subsets} subsets"
)
OSEM cost -2.667261E+05 after 002 iterations with 10 subsets
Visualize the results
def _update_img(i):
    """Show plane ``i`` (last axis) of the true image and of the reconstruction.

    Frame callback for matplotlib's FuncAnimation; reads the module-level
    ``img0``, ``img1``, ``ax``, ``x_true_np`` and ``x_np``.
    """
    for artist, volume in ((img0, x_true_np), (img1, x_np)):
        artist.set_data(volume[:, :, i])
    ax[0].set_title(f"true image - plane {i:02}")
    ax[1].set_title(f"OSEM iteration {num_iter} - {num_subsets} subsets - plane {i:02}")
    return (img0, img1)
349
350
# convert ground truth and reconstruction to numpy arrays for plotting
x_true_np = parallelproj.to_numpy_array(x_true)
x_np = parallelproj.to_numpy_array(x)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# common grey scale for both panels, set by the maximum of the reconstruction
vmax = x_np.max()
img0 = ax[0].imshow(x_true_np[:, :, 0], cmap="Greys", vmin=0, vmax=vmax)
img1 = ax[1].imshow(x_np[:, :, 0], cmap="Greys", vmin=0, vmax=vmax)
# initial titles come from the same frame callback the animation uses
_update_img(0)
fig.tight_layout()
ani = animation.FuncAnimation(
    fig, _update_img, frames=x_np.shape[2], interval=200, blit=False
)
fig.show()
Total running time of the script: (0 minutes 6.973 seconds)
Related examples