Note
Go to the end to download the full example code.
MLEM with projection data of an open PET geometry
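This example demonstrates the use of the MLEM algorithm to minimize the negative Poisson log-likelihood function using “sinogram” data from an open PET geometry.
\[ f(x) = \sum_{i=1}^m \bar{y}_i(x) - y_i \log \bar{y}_i(x) \]
subject to
\[ x \geq 0 \]
using the linear forward model
\[ \bar{y}(x) = A x + s \]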
Tip
parallelproj is python array API compatible meaning it supports different
array backends (e.g. numpy, cupy, torch, …) and devices (CPU or GPU).
Choose your preferred array API xp and device dev below.
31 from __future__ import annotations
32 from parallelproj import Array
33
34 import array_api_compat.numpy as xp
35
36 # import array_api_compat.cupy as xp
37 # import array_api_compat.torch as xp
38
39 import parallelproj
40 from array_api_compat import to_device
41 import array_api_compat.numpy as np
42 import matplotlib.pyplot as plt
43 import matplotlib.animation as animation
44
# Select the compute device that matches the chosen array backend ``xp``.
if "numpy" in xp.__name__:
    # numpy arrays can only live on the CPU
    dev = "cpu"
elif "cupy" in xp.__name__:
    # cupy only supports CUDA devices; use the device with index 0
    dev = xp.cuda.Device(0)
elif "torch" in xp.__name__:
    # torch supports 'cpu' and 'cuda'; use CUDA whenever parallelproj reports it
    dev = "cuda" if parallelproj.cuda_present else "cpu"
Setup of the forward model \(\bar{y}(x) = A x + s\)
We set up a linear forward operator \(A\) consisting of an image-based resolution model, a non-TOF PET projector and an attenuation model.
Here we create an open geometry with 6 sides and a single ring, corresponding to a full geometry using 12 sides where 6 sides were removed.
# Open single-ring scanner: 6 detector sides placed at 6 of the 12 equally
# spaced angles 2*pi*k/12 (k = -1, 0, 1, 5, 6, 7).
num_rings = 1
scanner = parallelproj.RegularPolygonPETScannerGeometry(
    xp,
    dev,
    radius=65.0,
    num_sides=6,
    num_lor_endpoints_per_side=15,
    lor_spacing=2.3,
    ring_positions=xp.asarray([0.0], device=dev),
    symmetry_axis=2,
    # create the side angles on the selected device, exactly like
    # ring_positions, so that all geometry arrays share one device
    # (matters for GPU backends such as torch with dev="cuda")
    phis=(2 * xp.pi / 12) * xp.asarray([-1, 0, 1, 5, 6, 7], device=dev),
)
setup the LOR descriptor that defines the sinogram
# image grid: 40 x 40 voxels with a single voxel along the last axis
img_shape = (40, 40, 1)
voxel_size = (2.0, 2.0, 2.0)

# LOR descriptor that defines the sinogram layout for the scanner
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    radial_trim=1,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

# projector mapping images of shape img_shape to sinograms
proj = parallelproj.RegularPolygonPETProjector(
    lor_desc, img_shape=img_shape, voxel_size=voxel_size
)

# simple test image: uniform background of 1 with a few "hot rods" of value 5
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32)
c0 = proj.in_shape[0] // 2
c1 = proj.in_shape[1] // 2

rod_positions = (4, 8, 12, 16)
# rods at fixed positions along axis 0, located at the center of axis 1
for rod_pos in rod_positions:
    x_true[rod_pos, c1, :] = 5.0
# rods at fixed positions along axis 1, located at the center of axis 0
for rod_pos in rod_positions:
    x_true[c0, rod_pos, :] = 5.0

# zero a 2 voxel wide border at both ends of the first two axes
for border in (slice(None, 2), slice(-2, None)):
    x_true[border, :, :] = 0
    x_true[:, border, :] = 0
Attenuation image and sinogram setup
# attenuation image: 0.01 wherever the test image is positive, 0 elsewhere
support_mask = xp.astype(x_true > 0, xp.float32)
x_att = 0.01 * support_mask
# attenuation sinogram: exponential of the negated forward projection of x_att
att_sino = xp.exp(-proj(x_att))
Complete PET forward model setup
We combine an image-based resolution model, a non-TOF or TOF PET projector and an attenuation model into a single linear operator.
## TOF is disabled by default - uncomment the following lines to enable it
# proj.tof_parameters = parallelproj.TOFParameters(
#     num_tofbins=13 * 5,
#     tofbin_width=12.0 / 5,
#     sigma_tof=12.0 / 5,
# )

# image-based resolution model: Gaussian filter with
# sigma = 4.5 / (2.35 * voxel size)
# NOTE(review): presumably a FWHM of 4.5 converted to sigma in voxel units - confirm
res_model = parallelproj.GaussianFilterOperator(
    proj.in_shape, sigma=4.5 / (2.35 * proj.voxel_size)
)

# the attenuation sinogram is always non-TOF, so a TOF projector needs the
# dedicated operator that applies a non-TOF sinogram to TOF data
if not proj.tof:
    att_op = parallelproj.ElementwiseMultiplicationOperator(att_sino)
else:
    att_op = parallelproj.TOFNonTOFElementwiseMultiplicationOperator(
        proj.out_shape, att_sino
    )

# compose all 3 operators into a single linear operator
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model))
Visualization of the geometry
# two side-by-side 3D axes, each showing the scanner geometry together with
# the LORs of one sinogram view in the central plane
fig = plt.figure(figsize=(16, 8), tight_layout=True)
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122, projection="3d")

central_plane = num_rings // 2
# left panel: view 0, right panel: the middle view
for axis, view in ((ax1, 0), (ax2, proj.lor_descriptor.num_views // 2)):
    proj.show_geometry(axis)
    proj.lor_descriptor.show_views(
        axis,
        views=xp.asarray([view], device=dev),
        planes=xp.asarray([central_plane], device=dev),
        lw=0.5,
        color="k",
    )
    axis.set_title(f"view {view}, plane {central_plane}")

fig.tight_layout()
fig.show()

Simulation of projection data
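We set up an arbitrary ground truth \(x_{true}\) and simulate noise-free data; noisy data \(y\) can be obtained by adding Poisson noise. In the code below the Poisson noise step is commented out, so the noise-free data (including the contamination) are used as \(y\).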
# noise-free forward projection of the ground truth
noise_free_data = pet_lin_op(x_true)

# constant contamination sinogram at half the mean of the forward projection
contamination_level = 0.5 * float(xp.mean(noise_free_data))
contamination = xp.full(
    noise_free_data.shape,
    contamination_level,
    device=dev,
    dtype=xp.float32,
)

# add the contamination in place to obtain the expected data
noise_free_data += contamination

# Poisson noise is disabled - uncomment the following lines to use a noisy realization
# np.random.seed(1)
# y = xp.asarray(
#     np.random.poisson(parallelproj.to_numpy_array(noise_free_data)),
#     device=dev,
#     dtype=xp.float64,
# )

# use the noise-free data (including contamination) as measured data
y = noise_free_data
EM update to minimize \(f(x)\)
The EM update that can be used in MLEM or OSEM is given by [Dempster1977] [SV82] [LC84] [HL94]
\[ x^+ = \frac{x}{A^H 1} A^H \frac{y}{A x + s} \]
to calculate the minimizer of \(f(x)\) iteratively.
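To monitor the convergence we calculate the relative cost
\[ \frac{f(x) - f(x^*)}{|f(x^*)|} \]
and the distance to the optimal point
\[ \frac{\|x - x^*\|}{\|x^*\|} \]
where \(x^*\) denotes the minimizer of \(f(x)\).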
We setup a function that calculates a single MLEM/OSEM update given the current solution, a linear forward operator, data, contamination and the adjoint of ones.
def em_update(
    x_cur: Array,
    data: Array,
    op: parallelproj.LinearOperator,
    s: Array,
    adjoint_ones: Array,
) -> Array:
    """Perform a single multiplicative EM (MLEM/OSEM) update.

    Parameters
    ----------
    x_cur : Array
        current image estimate
    data : Array
        measured data
    op : parallelproj.LinearOperator
        linear forward operator, used as ``op(x)`` and ``op.adjoint(y)``
    s : Array
        contamination that is added to the forward projection
    adjoint_ones : Array
        adjoint of the operator applied to ones

    Returns
    -------
    Array
        the updated estimate
        ``x_cur * op.adjoint(data / (op(x_cur) + s)) / adjoint_ones``
    """
    # expected data for the current estimate
    expected = op(x_cur) + s
    # back projection of the ratio between measured and expected data
    ratio_backprojection = op.adjoint(data / expected)
    # multiplicative update, normalized by the adjoint of ones
    return x_cur * ratio_backprojection / adjoint_ones
Run the MLEM iterations
# number of MLEM iterations
num_iter = 100

# start from a uniform image of ones
x = xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev)

# A^H 1: adjoint of the forward operator applied to a sinogram of ones,
# used to normalize every EM update
ones_sino = xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev)
adjoint_ones = pet_lin_op.adjoint(ones_sino)

for i in range(num_iter):
    # the carriage return keeps the progress message on a single terminal line
    progress = f"MLEM iteration {(i + 1):03} / {num_iter:03}"
    print(progress, end="\r")
    x = em_update(x, y, pet_lin_op, contamination, adjoint_ones)
MLEM iteration 001 / 100
MLEM iteration 002 / 100
MLEM iteration 003 / 100
MLEM iteration 004 / 100
MLEM iteration 005 / 100
MLEM iteration 006 / 100
MLEM iteration 007 / 100
MLEM iteration 008 / 100
MLEM iteration 009 / 100
MLEM iteration 010 / 100
MLEM iteration 011 / 100
MLEM iteration 012 / 100
MLEM iteration 013 / 100
MLEM iteration 014 / 100
MLEM iteration 015 / 100
MLEM iteration 016 / 100
MLEM iteration 017 / 100
MLEM iteration 018 / 100
MLEM iteration 019 / 100
MLEM iteration 020 / 100
MLEM iteration 021 / 100
MLEM iteration 022 / 100
MLEM iteration 023 / 100
MLEM iteration 024 / 100
MLEM iteration 025 / 100
MLEM iteration 026 / 100
MLEM iteration 027 / 100
MLEM iteration 028 / 100
MLEM iteration 029 / 100
MLEM iteration 030 / 100
MLEM iteration 031 / 100
MLEM iteration 032 / 100
MLEM iteration 033 / 100
MLEM iteration 034 / 100
MLEM iteration 035 / 100
MLEM iteration 036 / 100
MLEM iteration 037 / 100
MLEM iteration 038 / 100
MLEM iteration 039 / 100
MLEM iteration 040 / 100
MLEM iteration 041 / 100
MLEM iteration 042 / 100
MLEM iteration 043 / 100
MLEM iteration 044 / 100
MLEM iteration 045 / 100
MLEM iteration 046 / 100
MLEM iteration 047 / 100
MLEM iteration 048 / 100
MLEM iteration 049 / 100
MLEM iteration 050 / 100
MLEM iteration 051 / 100
MLEM iteration 052 / 100
MLEM iteration 053 / 100
MLEM iteration 054 / 100
MLEM iteration 055 / 100
MLEM iteration 056 / 100
MLEM iteration 057 / 100
MLEM iteration 058 / 100
MLEM iteration 059 / 100
MLEM iteration 060 / 100
MLEM iteration 061 / 100
MLEM iteration 062 / 100
MLEM iteration 063 / 100
MLEM iteration 064 / 100
MLEM iteration 065 / 100
MLEM iteration 066 / 100
MLEM iteration 067 / 100
MLEM iteration 068 / 100
MLEM iteration 069 / 100
MLEM iteration 070 / 100
MLEM iteration 071 / 100
MLEM iteration 072 / 100
MLEM iteration 073 / 100
MLEM iteration 074 / 100
MLEM iteration 075 / 100
MLEM iteration 076 / 100
MLEM iteration 077 / 100
MLEM iteration 078 / 100
MLEM iteration 079 / 100
MLEM iteration 080 / 100
MLEM iteration 081 / 100
MLEM iteration 082 / 100
MLEM iteration 083 / 100
MLEM iteration 084 / 100
MLEM iteration 085 / 100
MLEM iteration 086 / 100
MLEM iteration 087 / 100
MLEM iteration 088 / 100
MLEM iteration 089 / 100
MLEM iteration 090 / 100
MLEM iteration 091 / 100
MLEM iteration 092 / 100
MLEM iteration 093 / 100
MLEM iteration 094 / 100
MLEM iteration 095 / 100
MLEM iteration 096 / 100
MLEM iteration 097 / 100
MLEM iteration 098 / 100
MLEM iteration 099 / 100
MLEM iteration 100 / 100
Calculation of the negative Poisson log-likelihood function of the reconstruction
# expected data of the reconstruction: forward model applied to x plus contamination
exp = pet_lin_op(x) + contamination
# negative Poisson log-likelihood of the reconstruction: sum(exp - y * log(exp))
nll_terms = exp - y * xp.log(exp)
cost = float(xp.sum(nll_terms))
print(f"\nMLEM cost {cost:.6E} after {num_iter:03} iterations")
MLEM cost -2.586407E+05 after 100 iterations
Visualize the results
def _update_img(i):
    """Animation callback showing plane ``i`` of the true and reconstructed image.

    Reads the module-level ``x_true_np``, ``x_np``, ``img0``, ``img1``, ``ax``
    and ``num_iter`` and returns the two updated image artists as a tuple.
    """
    # left panel: ground truth
    img0.set_data(x_true_np[:, :, i])
    ax[0].set_title(f"true image - plane {i:02}")
    # right panel: MLEM reconstruction
    img1.set_data(x_np[:, :, i])
    ax[1].set_title(f"MLEM iteration {num_iter} - plane {i:02}")
    return img0, img1
317
318
# convert both images to numpy arrays for matplotlib
x_true_np = parallelproj.to_numpy_array(x_true)
x_np = parallelproj.to_numpy_array(x)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# common grey scale for both panels, capped at the maximum of the reconstruction
vmax = x_np.max()
imshow_kwargs = dict(cmap="Greys", vmin=0, vmax=vmax)
img0 = ax[0].imshow(x_true_np[:, :, 0], **imshow_kwargs)
img1 = ax[1].imshow(x_np[:, :, 0], **imshow_kwargs)
ax[0].set_title(f"true image - plane {0:02}")
ax[1].set_title(f"MLEM iteration {num_iter} - plane {0:02}")
fig.tight_layout()
# one animation frame per plane along the last image axis
ani = animation.FuncAnimation(fig, _update_img, x_np.shape[2], interval=200, blit=False)
fig.show()
Total running time of the script: (0 minutes 0.646 seconds)
Related examples