Note
Go to the end to download the full example code.
PDHG to optimize the Poisson logL and directional TV (structural prior)
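This example demonstrates the use of the primal dual hybrid gradient (PDHG) algorithm to minimize the negative Poisson log-likelihood function combined with a directional total variation (DTV) regularizer (a structural prior):
\(f(x) = \sum_{i=1}^{m} \left( \bar{d}_i(x) - d_i \log \bar{d}_i(x) \right) + \beta \, \|P_\xi \nabla x\|_{1,2}\)
subject to
\(x \geq 0\)
using the linear forward model
\(\bar{d}(x) = A x + s\)
where \(d\) are the measured data, \(s\) is the contamination, \(\nabla\) is the finite forward difference operator and \(P_\xi\) projects the image gradient onto the orthogonal complement of the normalized gradient field \(\xi\) of the structural image. See [EB16] and [EMS19] for details.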
Tip
parallelproj is compatible with the Python array API, meaning it supports different
array backends (e.g. numpy, cupy, torch, …) and devices (CPU or GPU).
Choose your preferred array API xp and device dev below.
Warning
Running this example with GPU arrays (e.g. using cupy as the array backend) is highly recommended, since execution times with CPU arrays are considerably longer.
38 from __future__ import annotations
39
40 import array_api_compat.numpy as xp
41
42 # import array_api_compat.cupy as xp
43 # import array_api_compat.torch as xp
44
45 import parallelproj
46 from array_api_compat import to_device
47 import array_api_compat.numpy as np
48 import matplotlib.pyplot as plt
49
# choose a device (CPU or CUDA GPU) that matches the selected array backend xp
if "numpy" in xp.__name__:
    # using numpy, device must be cpu
    dev = "cpu"
elif "cupy" in xp.__name__:
    # using cupy, only cuda devices are possible
    dev = xp.cuda.Device(0)
elif "torch" in xp.__name__:
    # using torch valid choices are 'cpu' or 'cuda'
    dev = "cuda" if parallelproj.cuda_present else "cpu"
else:
    # fail early instead of with a NameError on `dev` further down
    raise ValueError(
        f"unsupported array backend {xp.__name__!r}; use numpy, cupy or torch"
    )
Input Parameters
# image scale (can be used to simulate more or fewer counts)
img_scale = 0.1
# number of MLEM iterations used to initialize PDHG
num_iter_mlem = 10
# number of PDHG iterations
num_iter_pdhg = 1000
# prior weight
beta = 6.0
# step size ratio for PDHG (scaled inversely with the image scale)
gamma = 1.0 / img_scale
# safety factor (< 1) applied to all PDHG step sizes
rho = 0.9999
# contamination in every sinogram bin relative to mean of trues sinogram
contam = 1.0


# evaluate the cost function in every PDHG iteration
# (costs one additional forward projection per iteration)
track_cost = True
Simulation of PET data in sinogram space
In this example, we use simulated sinogram data, for which we first need to set up a sinogram forward model to create a noise-free and a noisy emission sinogram.
Setup of the sinogram forward model
We set up a linear forward operator \(A\) consisting of an image-based resolution model, a non-TOF PET projector and an attenuation model.
# setup a small regular polygon PET scanner with num_rings rings
num_rings = 2
scanner = parallelproj.RegularPolygonPETScannerGeometry(
    xp,
    dev,
    radius=350.0,
    num_sides=28,
    num_lor_endpoints_per_side=16,
    lor_spacing=4.0,
    # create the ring positions on the same device as all other arrays
    ring_positions=xp.linspace(-2.5, 2.5, num_rings, device=dev),
    symmetry_axis=2,
)

# image grid (presumably in mm, like the scanner dimensions)
img_shape = (40, 40, 4)
voxel_size = (4.0, 4.0, 2.5)

# setup the LOR descriptor that defines the sinogram
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    radial_trim=170,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

# (non-TOF) sinogram projector for the image grid defined above
proj = parallelproj.RegularPolygonPETProjector(
    lor_desc, img_shape=img_shape, voxel_size=voxel_size
)
127
# simple test image: background of ones plus a few "hot rods" with value 5
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32)
c0, c1 = proj.in_shape[0] // 2, proj.in_shape[1] // 2
x_true[(c0 - 2) : (c0 + 2), (c1 - 2) : (c1 + 2), :] = 5.0
# NOTE(review): assuming proj.in_shape == img_shape == (40, 40, 4), tmp_n below
# is 10, so this rod in row 4 is zeroed again by the masking and never appears
x_true[4, c1, 2:] = 5.0
x_true[c0, 4, :-2] = 5.0

# zero a quarter of the rows on either side and the two outermost columns
tmp_n = proj.in_shape[0] // 4
x_true[:tmp_n, ...] = 0
x_true[-tmp_n:, ...] = 0
x_true[:, :2, ...] = 0
x_true[:, -2:, ...] = 0

# structural prior image: negative square root of the (not yet scaled) true
# image, i.e. same edges with inverted contrast, except for part of the central
# hot region, which is set to the background value so that the structural
# image does not match the activity there
x_struct = -xp.sqrt(x_true)
x_struct[c0 : (c0 + 2), c1 : (c1 + 2), :] = -1.0

# scale image to get more counts (after x_struct was derived from it)
x_true *= img_scale
Attenuation image and sinogram setup
# attenuation image: uniform value 0.01 wherever the true activity is positive
x_att = xp.astype(x_true > 0, xp.float32) * 0.01
# attenuation sinogram: exp(-forward projection of the attenuation image)
att_sino = xp.exp(-proj(x_att))
Complete sinogram PET forward model setup
We combine an image-based resolution model, a non-TOF or TOF PET projector and an attenuation model into a single linear operator.
# enable TOF - uncomment if you want to run TOF recons
# proj.tof_parameters = parallelproj.TOFParameters(
#     num_tofbins=17, tofbin_width=12.0, sigma_tof=12.0
# )

# the attenuation sinogram is always non-TOF, so TOF projections need a
# dedicated multiplication operator (presumably broadcasting over TOF bins)
att_op = (
    parallelproj.TOFNonTOFElementwiseMultiplicationOperator(proj.out_shape, att_sino)
    if proj.tof
    else parallelproj.ElementwiseMultiplicationOperator(att_sino)
)

# image-based resolution model: Gaussian with a FWHM of 4.5 (same units as the
# voxel size), converted to a standard deviation in voxels (FWHM ~= 2.35 sigma)
res_model = parallelproj.GaussianFilterOperator(
    proj.in_shape, sigma=4.5 / (2.35 * proj.voxel_size)
)

# compose all 3 operators into a single linear operator; res_model acts on the
# image, so the operators are applied right to left:
# A x = att_op(proj(res_model(x)))
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model))
Simulation of sinogram projection data
We simulate noise-free data by forward projecting the ground truth \(x_{true}\) and adding a constant contamination sinogram \(s\), and obtain the noisy data \(d\) by adding Poisson noise.
# simulated noise-free data
noise_free_data = pet_lin_op(x_true)

# generate a constant contamination sinogram (contam times the mean of the trues)
contamination = xp.full(
    noise_free_data.shape,
    contam * float(xp.mean(noise_free_data)),
    device=dev,
    dtype=xp.float32,
)

noise_free_data += contamination

# add Poisson noise (sampled with numpy on the CPU for every backend)
# The counts are stored as float32: integer counts are exact in float32 up to
# 2**24, whereas int16 silently wraps around above 32767 counts per bin (e.g.
# when img_scale is increased). All later uses of d combine it with float32
# arrays, where int16 was promoted to float32 anyway.
np.random.seed(1)
d = xp.asarray(
    np.random.poisson(np.asarray(to_device(noise_free_data, "cpu"))),
    device=dev,
    dtype=xp.float32,
)
Run quick MLEM as initialization
# sensitivity image A^H 1 that normalizes every MLEM update
adjoint_ones = pet_lin_op.adjoint(
    xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev)
)
# MLEM starts from a uniform image of ones
x_mlem = xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev)

for i in range(num_iter_mlem):
    print(f"MLEM iteration {(i + 1):03} / {num_iter_mlem:03}", end="\r")
    # expected data of the current estimate including contamination
    dbar = pet_lin_op(x_mlem) + contamination
    # multiplicative EM update: back projected data / expectation ratio
    # divided by the sensitivity image
    x_mlem *= pet_lin_op.adjoint(d / dbar) / adjoint_ones
Set up the cost function
def cost_function(img):
    """Return the DTV-regularized negative Poisson log-likelihood of ``img``.

    The data term is sum(dbar - d * log(dbar)) with dbar = A img + contamination.
    The prior term is beta times the sum over voxels of the Euclidean norm of
    the projected gradient op_G(img) (norm taken over axis 0). Uses the
    module-level ``pet_lin_op``, ``contamination``, ``d``, ``beta`` and ``op_G``.
    """
    expected = pet_lin_op(img) + contamination
    data_fidelity = float(xp.sum(expected - d * xp.log(expected)))
    grad_norms = xp.linalg.vector_norm(op_G(img), axis=0)
    return data_fidelity + beta * float(xp.sum(grad_norms))
PDHG
PDHG algorithm to minimize negative Poisson log-likelihood + regularization
See [EMS19] [SH22] for more details.
Proximal operator of the convex dual of the negative Poisson log-likelihood
\((\text{prox}_{D^*}^{S}(y))_i = \text{prox}_{D^*}^{S_i}(y_i) = \frac{1}{2} \left(y_i + 1 - \sqrt{ (y_i-1)^2 + 4 S_i d_i} \right)\)
Step sizes
\(S_A = \gamma \, \text{diag}(\frac{\rho}{A 1})\)
\(S_G = \gamma \, \frac{\rho}{\|P_\xi \nabla\|}\)
\(T_A = \gamma^{-1} \text{diag}(\frac{\rho}{A^T 1})\)
\(T_G = \gamma^{-1} \frac{\rho}{\|P_\xi \nabla\|}\)
\(T = \min(T_A, T_G)\) (pointwise)
where \(\|P_\xi \nabla\|\) is the (iteratively estimated) operator norm of the projected gradient operator.
# "normal" (non-projected) finite forward difference gradient operator
G = parallelproj.FiniteForwardDifference(pet_lin_op.in_shape)
# directional TV operator: gradient G followed by the projection P that is
# defined by the gradient field of the structural prior image
joint_vector_field = G(x_struct)
P = parallelproj.GradientFieldProjectionOperator(joint_vector_field, eta=1e-4)
op_G = parallelproj.CompositeLinearOperator((P, G))

# dual variable of the gradient term, starting at zero
w = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev)

# primal variable starts from a copy of the MLEM image; the data dual variable
# starts at 1 - d / (A x + s), the derivative of the negative Poisson logL
# with respect to the expected data at that image
x_pdhg = 1.0 * x_mlem
y = 1 - d / (pet_lin_op(x_pdhg) + contamination)

# z = A^H y + G^H w and its extrapolated copy zbar
z = pet_lin_op.adjoint(y) + op_G.adjoint(w)
zbar = 1.0 * z
# calculate PDHG step sizes

# data-space step size S_A = gamma * rho / (A 1); sinogram bins where A 1 is
# zero get the smallest positive value to avoid a division by zero
tmp = pet_lin_op(xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev))
tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp)
S_A = gamma * rho / tmp

# image-space step size T_A = rho / (gamma * A^H 1), computed in float32 like
# all other arrays (float64 here used to promote T and T * zbar to float64).
# Voxels where A^H 1 is zero get T_A = inf and thus fall back to T_G below.
T_A = (
    (1 / gamma)
    * rho
    / pet_lin_op.adjoint(xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev))
)

# step sizes of the projected gradient operator based on its operator norm
op_G_norm = op_G.norm(xp, dev, num_iter=100)
S_G = gamma * rho / op_G_norm
T_G = (1 / gamma) * rho / op_G_norm

# pointwise minimum of T_A and T_G; the constant T_G image has to live on the
# same device as T_A (without device=dev it is created on the default device,
# which breaks xp.where for GPU torch arrays) and uses float32 as well
T = xp.where(
    T_A < T_G,
    T_A,
    xp.full(pet_lin_op.in_shape, float(T_G), dtype=xp.float32, device=dev),
)
Run PDHG
print("")
# cost per iteration, always kept as a numpy array on the CPU; the dtype must be
# a numpy dtype (xp.float32 is e.g. torch.float32 for the torch backend, which
# np.zeros rejects)
cost_pdhg = np.zeros(num_iter_pdhg, dtype=np.float32)

for i in range(num_iter_pdhg):
    # primal update with the extrapolated z, followed by projection onto x >= 0
    x_pdhg -= T * zbar
    x_pdhg = xp.where(x_pdhg < 0, xp.zeros_like(x_pdhg), x_pdhg)

    if track_cost:
        cost_pdhg[i] = cost_function(x_pdhg)

    # dual update of the data term
    y_plus = y + S_A * (pet_lin_op(x_pdhg) + contamination)
    # prox of convex conjugate of negative Poisson logL
    y_plus = 0.5 * (y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_A * d))

    # dual update of the DTV term: the prox of the convex conjugate of
    # beta * ||.||_{1,2} is a pointwise projection onto the ball of radius beta
    w_plus = (w + S_G * op_G(x_pdhg)) / beta
    denom = xp.linalg.vector_norm(w_plus, axis=0)
    w_plus /= xp.where(denom < 1, xp.ones_like(denom), denom)
    w_plus *= beta

    # change of z = A^H y + G^H w caused by the dual updates
    delta_z = pet_lin_op.adjoint(y_plus - y) + op_G.adjoint(w_plus - w)
    y = 1.0 * y_plus
    w = 1.0 * w_plus

    # update z and its extrapolation zbar = z_new + (z_new - z_old)
    z = z + delta_z
    zbar = z + delta_z

    print(f"PDHG iter {(i+1):04} / {num_iter_pdhg}, cost {cost_pdhg[i]:.7e}", end="\r")
Visualizations
# convert all images to numpy arrays for plotting
x_true_np = parallelproj.to_numpy_array(x_true)
x_struct_np = parallelproj.to_numpy_array(x_struct)
x_pdhg_np = parallelproj.to_numpy_array(x_pdhg)

# indices of the central slice along each of the three axes
pl0, pl1, pl2 = (n // 2 for n in x_true_np.shape)

fig, ax = plt.subplots(2, 3, figsize=(9, 5), tight_layout=True)
vmax = 1.2 * x_true_np.max()

# column 0: ground truth (top: slice pl2 along axis 2, bottom: slice pl0 along axis 0)
ax[0, 0].imshow(x_true_np[:, :, pl2], cmap="Greys", vmin=0, vmax=vmax)
ax[1, 0].imshow(x_true_np[pl0, :, :].T, cmap="Greys", vmin=0, vmax=vmax)
ax[0, 0].set_title("true img", fontsize="medium")

# column 1: DTV PDHG reconstruction, same grey scale as the ground truth
ax[0, 1].imshow(x_pdhg_np[:, :, pl2], cmap="Greys", vmin=0, vmax=vmax)
ax[1, 1].imshow(x_pdhg_np[pl0, :, :].T, cmap="Greys", vmin=0, vmax=vmax)
ax[0, 1].set_title(f"DTV PDHG {num_iter_pdhg} it.", fontsize="medium")

# column 2: structural prior image with its own automatic grey scale
ax[0, 2].imshow(x_struct_np[:, :, pl2], cmap="Greys")
ax[1, 2].imshow(x_struct_np[pl0, :, :].T, cmap="Greys")
ax[0, 2].set_title("structural img", fontsize="medium")

fig.show()
if track_cost:
    # cost per PDHG iteration (only available when it was tracked)
    fig2, ax2 = plt.subplots(figsize=(6, 4), tight_layout=True)
    ax2.plot(cost_pdhg, ".-", label="PDHG")
    ax2.set_xlabel("iteration")
    ax2.set_title("cost", fontsize="medium")
    ax2.grid(ls=":")
    ax2.legend()
    fig2.show()
Related examples
PDHG and LM-SPDHG to optimize the Poisson logL and total variation
DePierro’s algorithm to optimize the Poisson logL with quadratic intensity prior