PDHG to optimize the Poisson logL and directional TV (structural prior)

This example demonstrates how to use the primal-dual hybrid gradient (PDHG) algorithm to minimize the negative Poisson log-likelihood combined with a directional total variation (DTV) regularizer, which acts as a structural prior:

\[f(x) = \sum_{i=1}^m \left( \bar{d}_i(x) - d_i \log \bar{d}_i(x) \right) + \beta \, \|P_{\xi} \nabla x\|_{1,2}\]

subject to

\[x \geq 0\]

using the linear forward model

\[\bar{d}(x) = A x + s\]

where \(s\) denotes the additive contamination (e.g. scattered and random coincidences). See [EB16] and [EMS19] for details.
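The directional TV term differs from plain TV only by the voxelwise projection \(P_{\xi}\). The minimal NumPy sketch below assumes the common definition \(\xi = \nabla v / \sqrt{|\nabla v|^2 + \eta^2}\) for a structural image \(v\), with \(P_{\xi} u = u - \langle \xi, u \rangle \xi\) applied at every voxel. The functions `project_gradient_field` and `dtv` are hypothetical helpers. The normalization used by parallelproj's `GradientFieldProjectionOperator` may differ in detail.

```python
import numpy as np


def project_gradient_field(grad_u, grad_v, eta=1e-4):
    """Hypothetical sketch of P_xi: remove the component of grad_u along xi.

    grad_u, grad_v: arrays of shape (ndim, *img_shape) holding the gradient of
    the image to regularize and the gradient of the structural image.
    Assumes xi = grad_v / sqrt(|grad_v|^2 + eta^2), computed voxelwise.
    """
    norm_v = np.sqrt(np.sum(grad_v**2, axis=0) + eta**2)
    xi = grad_v / norm_v
    # voxelwise inner product <xi, grad_u>
    inner = np.sum(xi * grad_u, axis=0)
    return grad_u - inner * xi


def dtv(grad_u, grad_v, beta=1.0, eta=1e-4):
    """beta * ||P_xi grad_u||_{1,2}: voxelwise Euclidean norm, summed over voxels."""
    proj = project_gradient_field(grad_u, grad_v, eta)
    return beta * float(np.sum(np.sqrt(np.sum(proj**2, axis=0))))


# toy 2D gradient fields of shape (2, 4, 4)
grad_v = np.zeros((2, 4, 4))
grad_v[0] = 1.0  # structural gradient points along axis 0 everywhere
parallel = np.zeros((2, 4, 4))
parallel[0] = 3.0  # aligned with xi -> (almost) completely removed
orthogonal = np.zeros((2, 4, 4))
orthogonal[1] = 2.0  # orthogonal to xi -> left unchanged

print(dtv(parallel, grad_v))  # close to 0
print(dtv(orthogonal, grad_v))  # 16 voxels * 2.0 = 32.0
```

Gradient components aligned with the structural gradient are (almost) not penalized. Components orthogonal to it are penalized as in plain TV. Where \(\nabla v = 0\), the term reduces to ordinary TV because \(\xi = 0\) there.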

Tip

parallelproj is compatible with the Python array API standard, meaning it supports different array backends (e.g. numpy, cupy, torch, …) and devices (CPU or GPU). Choose your preferred array API namespace xp and device dev below.

Warning

Running this example with GPU arrays (e.g. using cupy as the array backend) is highly recommended, because execution with CPU arrays is considerably slower.

```python
from __future__ import annotations

import array_api_compat.numpy as xp

# import array_api_compat.cupy as xp
# import array_api_compat.torch as xp

import parallelproj
from array_api_compat import to_device
import array_api_compat.numpy as np
import matplotlib.pyplot as plt

# choose a device (CPU or CUDA GPU)
if "numpy" in xp.__name__:
    # using numpy, device must be cpu
    dev = "cpu"
elif "cupy" in xp.__name__:
    # using cupy, only cuda devices are possible
    dev = xp.cuda.Device(0)
elif "torch" in xp.__name__:
    # using torch, valid choices are 'cpu' or 'cuda'
    if parallelproj.cuda_present:
        dev = "cuda"
    else:
        dev = "cpu"
```

Input Parameters

```python
# image scale (can be used to simulate more or fewer counts)
img_scale = 0.1
# number of MLEM iterations used to initialize PDHG
num_iter_mlem = 10
# number of PDHG iterations
num_iter_pdhg = 1000
# prior weight
beta = 6.0
# step size ratio for PDHG
gamma = 1.0 / img_scale
# rho value for PDHG
rho = 0.9999
# contamination in every sinogram bin relative to mean of trues sinogram
contam = 1.0

# evaluate the cost function in every PDHG iteration
# (requires an additional forward projection per iteration)
track_cost = True
```

Simulation of PET data in sinogram space

In this example, we use simulated sinogram data. To create the noise-free and noisy emission sinograms, we first need to set up a sinogram forward model.

Setup of the sinogram forward model

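We set up a linear forward operator \(A\) consisting of an image-based resolution model, a non-TOF PET projector and an attenuation model. As a first step, we define the scanner geometry, the sinogram LOR descriptor, the projector, and the ground truth and structural test images.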

```python
num_rings = 2
scanner = parallelproj.RegularPolygonPETScannerGeometry(
    xp,
    dev,
    radius=350.0,
    num_sides=28,
    num_lor_endpoints_per_side=16,
    lor_spacing=4.0,
    ring_positions=xp.linspace(-2.5, 2.5, num_rings),
    symmetry_axis=2,
)

img_shape = (40, 40, 4)
voxel_size = (4.0, 4.0, 2.5)

# set up the LOR descriptor that defines the sinogram
lor_desc = parallelproj.RegularPolygonPETLORDescriptor(
    scanner,
    radial_trim=170,
    sinogram_order=parallelproj.SinogramSpatialAxisOrder.RVP,
)

proj = parallelproj.RegularPolygonPETProjector(
    lor_desc, img_shape=img_shape, voxel_size=voxel_size
)

# set up a simple test image containing a few "hot rods"
x_true = xp.ones(proj.in_shape, device=dev, dtype=xp.float32)
c0 = proj.in_shape[0] // 2
c1 = proj.in_shape[1] // 2
x_true[(c0 - 2) : (c0 + 2), (c1 - 2) : (c1 + 2), :] = 5.0
x_true[4, c1, 2:] = 5.0
x_true[c0, 4, :-2] = 5.0

tmp_n = proj.in_shape[0] // 4
x_true[:tmp_n, :, :] = 0
x_true[-tmp_n:, :, :] = 0
x_true[:, :2, :] = 0
x_true[:, -2:, :] = 0

# set up a structural prior image
x_struct = -1.0 * xp.sqrt(x_true)
x_struct[(c0) : (c0 + 2), (c1) : (c1 + 2), :] = -1.0

# scale the image to control the number of counts (see img_scale)
x_true *= img_scale
```

Attenuation image and sinogram setup

```python
# set up an attenuation image
x_att = 0.01 * xp.astype(x_true > 0, xp.float32)
# calculate the attenuation sinogram
att_sino = xp.exp(-proj(x_att))
```

Complete sinogram PET forward model setup

We combine an image-based resolution model, a non-TOF or TOF PET projector and an attenuation model into a single linear operator.

```python
# enable TOF - uncomment if you want to run TOF recons
# proj.tof_parameters = parallelproj.TOFParameters(
#    num_tofbins=17, tofbin_width=12.0, sigma_tof=12.0
# )

# set up the attenuation multiplication operator, which is different
# for TOF and non-TOF since the attenuation sinogram is always non-TOF
if proj.tof:
    att_op = parallelproj.TOFNonTOFElementwiseMultiplicationOperator(
        proj.out_shape, att_sino
    )
else:
    att_op = parallelproj.ElementwiseMultiplicationOperator(att_sino)

# image-based Gaussian resolution model with a FWHM of 4.5 (in the units of
# voxel_size); sigma is given in voxels, using FWHM = 2.35 * sigma
res_model = parallelproj.GaussianFilterOperator(
    proj.in_shape, sigma=4.5 / (2.35 * proj.voxel_size)
)

# compose all 3 operators into a single linear operator
pet_lin_op = parallelproj.CompositeLinearOperator((att_op, proj, res_model))
```
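A composite operator applies its factors from right to left, \(A x = M_{att} \, P \, R \, x\), so its adjoint is \(A^T = R^T P^T M_{att}^T\). The sketch below illustrates this with small dense NumPy matrices that stand in for the resolution model, the projector and the attenuation factors. All values are toy assumptions, not parallelproj objects. The adjoint is checked with the dot-product test \(\langle A x, y \rangle = \langle x, A^T y \rangle\). The factor 2.35 in the resolution model is \(2\sqrt{2 \ln 2}\), the ratio between FWHM and standard deviation of a Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)

n_vox, n_bins = 6, 10

# toy stand-ins (assumptions, not parallelproj operators):
# R: symmetric image-space blurring matrix (resolution model)
R = np.eye(n_vox) * 0.6 + np.eye(n_vox, k=1) * 0.2 + np.eye(n_vox, k=-1) * 0.2
# P: non-negative "projector" mapping images to sinogram bins
P = rng.uniform(0.0, 1.0, size=(n_bins, n_vox))
# att: attenuation factors exp(-line integral of mu), one per sinogram bin
att = np.exp(-rng.uniform(0.0, 0.5, size=n_bins))


def forward(x):
    """A x = att * (P (R x)): factors applied from right to left."""
    return att * (P @ (R @ x))


def adjoint(y):
    """A^T y = R^T (P^T (att * y)): adjoints applied in reverse order."""
    return R.T @ (P.T @ (att * y))


# dot-product (adjoint) test
x = rng.standard_normal(n_vox)
y = rng.standard_normal(n_bins)
lhs = np.dot(forward(x), y)
rhs = np.dot(x, adjoint(y))
print(abs(lhs - rhs))  # zero up to floating point round-off

# FWHM / sigma ratio of a Gaussian
print(2 * np.sqrt(2 * np.log(2)))  # approx. 2.3548
```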

Simulation of sinogram projection data

Using the ground truth image \(x_{true}\) defined above, we simulate noise-free data, add a constant contamination sinogram, and generate the noisy data \(d\) by adding Poisson noise.

```python
# simulated noise-free data
noise_free_data = pet_lin_op(x_true)

# generate a constant contamination sinogram
contamination = xp.full(
    noise_free_data.shape,
    contam * float(xp.mean(noise_free_data)),
    device=dev,
    dtype=xp.float32,
)

noise_free_data += contamination

# add Poisson noise
np.random.seed(1)
d = xp.asarray(
    np.random.poisson(np.asarray(to_device(noise_free_data, "cpu"))),
    device=dev,
    dtype=xp.int16,
)
```

Run quick MLEM as initialization

```python
x_mlem = xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev)
# calculate A^H 1
adjoint_ones = pet_lin_op.adjoint(
    xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev)
)

for i in range(num_iter_mlem):
    print(f"MLEM iteration {(i + 1):03} / {num_iter_mlem:03}", end="\r")
    # MLEM update: x <- x * A^H (d / (A x + s)) / A^H 1
    dbar = pet_lin_op(x_mlem) + contamination
    x_mlem *= pet_lin_op.adjoint(d / dbar) / adjoint_ones
```
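The multiplicative MLEM update above can be tried on a small dense problem. The following sketch uses a hypothetical random system matrix `A`, ground truth, constant contamination `s` and Poisson data. It records the negative Poisson log-likelihood after every iteration. For Poisson data with a known additive background, the EM update keeps the image non-negative and does not increase this cost.

```python
import numpy as np

rng = np.random.default_rng(1)

# hypothetical toy problem: dense non-negative system matrix A,
# ground truth x_true, constant contamination s and Poisson data d
A = rng.uniform(0.0, 1.0, size=(40, 10))
x_true = rng.uniform(1.0, 5.0, size=10)
s = np.full(40, 0.5 * np.mean(A @ x_true))
d = rng.poisson(A @ x_true + s).astype(np.float64)


def neg_logl(x):
    """Negative Poisson log-likelihood (up to a constant independent of x)."""
    dbar = A @ x + s
    return float(np.sum(dbar - d * np.log(dbar)))


def mlem(num_iter):
    x = np.ones(A.shape[1])
    sens = A.T @ np.ones(A.shape[0])  # A^T 1, the "sensitivity" image
    costs = [neg_logl(x)]
    for _ in range(num_iter):
        dbar = A @ x + s
        # multiplicative EM update, keeps x non-negative
        x = x * (A.T @ (d / dbar)) / sens
        costs.append(neg_logl(x))
    return x, costs


x_mlem, costs = mlem(10)
print(costs[0], costs[-1])  # the cost decreases
```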

Set up the cost function

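```python
def cost_function(img):
    """Negative Poisson log-likelihood (up to a constant) plus beta * DTV of img.

    Note: op_G (the projected gradient operator) is defined further below,
    before this function is called for the first time.
    """
    exp = pet_lin_op(img) + contamination
    res = float(xp.sum(exp - d * xp.log(exp)))
    res += beta * float(xp.sum(xp.linalg.vector_norm(op_G(img), axis=0)))
    return res
```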
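The regularizer in `cost_function` is the mixed norm \(\|\cdot\|_{1,2}\). It takes the Euclidean norm over the gradient-direction axis 0 and sums over all voxels. The NumPy sketch below pairs it with a forward-difference gradient. It assumes the difference is set to zero at the upper boundary of each axis; the boundary handling of `parallelproj.FiniteForwardDifference` may differ. Both helper names are hypothetical.

```python
import numpy as np


def forward_diff(img):
    """Forward differences along every axis, stacked along a new axis 0.

    Assumes the difference is zero at the upper boundary of each axis.
    """
    grad = np.zeros((img.ndim,) + img.shape, dtype=img.dtype)
    for ax in range(img.ndim):
        # select all but the last element along axis ax
        sl = [slice(None)] * img.ndim
        sl[ax] = slice(None, -1)
        grad[ax][tuple(sl)] = np.diff(img, axis=ax)
    return grad


def norm_12(v):
    """||v||_{1,2}: Euclidean norm over axis 0, summed over all voxels."""
    return float(np.sum(np.sqrt(np.sum(v**2, axis=0))))


img = np.zeros((4, 4))
img[:, 2:] = 1.0  # vertical edge between columns 1 and 2
print(norm_12(forward_diff(img)))  # 4 rows with one unit jump each -> 4.0
print(norm_12(forward_diff(np.ones((4, 4)))))  # constant image -> 0.0
```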

PDHG

PDHG algorithm to minimize the negative Poisson log-likelihood plus regularization

Input: Poisson data \(d\)
Initialize \(x, y, w, S_A, S_G, T\)
Preprocessing: \(\overline{z} = z = A^T y + \nabla^T w\)
Repeat until a stopping criterion is fulfilled:
    Update \(x \gets \text{proj}_{\geq 0} \left( x - T \overline{z} \right)\)
    Update \(y^+ \gets \text{prox}_{D^*}^{S_A} ( y + S_A ( A x + s))\)
    Update \(w^+ \gets \beta \, \text{prox}_{R^*}^{S_G/\beta} ((w + S_G \nabla x)/\beta)\)
    Update \(\Delta z \gets A^T (y^+ - y) + \nabla^T (w^+ - w)\)
    Update \(z \gets z + \Delta z\)
    Update \(\overline{z} \gets z + \Delta z\)
    Update \(y \gets y^+\)
    Update \(w \gets w^+\)
Return \(x\)

Here \(D\) denotes the negative Poisson log-likelihood and \(R = \|\cdot\|_{1,2}\). For the directional TV prior, the gradient \(\nabla\) is replaced by the projected gradient \(P_{\xi} \nabla\) (op_G in the code below).

See [EMS19] and [SH22] for more details.
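The iteration can be tried on a small hypothetical 1D problem. The sketch below uses a dense random matrix `A` and a forward-difference matrix `G` as the gradient. In 1D, \(\|\cdot\|_{1,2}\) reduces to the plain \(\ell_1\) norm, so the \(w\) update becomes a clip to \([-\beta, \beta]\). The sketch uses \(\rho = 0.7\) instead of the example's \(\rho = 0.9999\). That value is an assumption made so that the conservative bound \(\sqrt{2}\,\rho < 1\) on the step-size condition of the stacked operator \([A; G]\) holds. It is not meant to reproduce parallelproj's reconstruction.

```python
import numpy as np

rng = np.random.default_rng(2)

# hypothetical toy problem (not the parallelproj model used in this example)
n = 12
A = rng.uniform(0.0, 1.0, size=(30, n))
x_true = np.repeat([1.0, 4.0, 2.0], n // 3)
s = np.full(30, 1.0)
d = rng.poisson(A @ x_true + s).astype(np.float64)

# 1D forward-difference matrix as a stand-in for the gradient operator
G = np.diff(np.eye(n), axis=0)  # shape (n - 1, n)
beta = 0.5


def cost(x):
    dbar = A @ x + s
    return float(np.sum(dbar - d * np.log(dbar)) + beta * np.sum(np.abs(G @ x)))


# step sizes following the formulas in the text
gamma, rho = 1.0, 0.7
S_A = gamma * rho / (A @ np.ones(n))
G_norm = np.linalg.norm(G, 2)
S_G = gamma * rho / G_norm
T = np.minimum(rho / (gamma * (A.T @ np.ones(A.shape[0]))), rho / (gamma * G_norm))

x = np.ones(n)
y = 1 - d / (A @ x + s)
w = np.zeros(n - 1)
z = A.T @ y + G.T @ w
zbar = z.copy()
cost_init = cost(x)

for _ in range(2000):
    # primal step with projection onto x >= 0
    x = np.maximum(x - T * zbar, 0.0)
    # prox of the convex conjugate of the Poisson NLL
    y_plus = y + S_A * (A @ x + s)
    y_plus = 0.5 * (y_plus + 1 - np.sqrt((y_plus - 1) ** 2 + 4 * S_A * d))
    # prox of the conjugate of beta * ||.||_1: clip to [-beta, beta]
    w_plus = np.clip(w + S_G * (G @ x), -beta, beta)
    delta_z = A.T @ (y_plus - y) + G.T @ (w_plus - w)
    y, w = y_plus, w_plus
    z = z + delta_z
    zbar = z + delta_z

print(cost_init, cost(x))  # cost after PDHG is lower
```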

Proximal operator of the convex conjugate of the negative Poisson log-likelihood

\((\text{prox}_{D^*}^{S}(y))_i = \text{prox}_{D^*}^{S}(y_i) = \frac{1}{2} \left(y_i + 1 - \sqrt{ (y_i-1)^2 + 4 S d_i} \right)\)
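For \(D(u) = \sum_i u_i - d_i \log u_i\), the convex conjugate is \(D^*(y) = -\sum_i d_i \log(1 - y_i)\) plus a constant, defined for \(y_i < 1\). The prox is therefore the root of \((y_i - v_i) + S d_i / (1 - y_i) = 0\) that lies below 1, which is the formula above. The sketch below checks this condition numerically with a hypothetical helper `prox_poisson_conj`. It also checks that for \(d_i = 0\) the formula reduces to \(\min(v_i, 1)\).

```python
import numpy as np


def prox_poisson_conj(v, S, d):
    """Elementwise prox of S * D^* (D = negative Poisson logL), formula from the text."""
    return 0.5 * (v + 1 - np.sqrt((v - 1) ** 2 + 4 * S * d))


rng = np.random.default_rng(3)
v = rng.normal(0.0, 3.0, size=1000)
S = rng.uniform(0.1, 2.0, size=1000)
d = rng.poisson(5.0, size=1000).astype(np.float64)

y = prox_poisson_conj(v, S, d)

# first-order optimality of min_y S * D^*(y) + 0.5 * (y - v)^2 (for d > 0)
residual = (y - v) + S * d / (1 - y)
print(np.max(np.abs(residual[d > 0])))  # close to 0
# the dual variable stays in the domain of D^*
print(np.all(y <= 1))
```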

Step sizes

\(S_A = \gamma \, \text{diag}\left(\frac{\rho}{A 1}\right)\)

\(S_G = \gamma \, \frac{\rho}{\|\nabla\|}\)

\(T_A = \gamma^{-1} \, \text{diag}\left(\frac{\rho}{A^T 1}\right)\)

\(T_G = \gamma^{-1} \, \frac{\rho}{\|\nabla\|}\)

\(T = \min(T_A, T_G)\) (elementwise), where \(\|\nabla\|\) denotes the operator norm of the gradient operator.
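For a non-negative \(A\), the choices of \(S_A\) and \(T_A\) resemble the diagonal preconditioning of Pock and Chambolle (2011) with \(\alpha = 1\). That choice bounds the norm of the preconditioned operator \(S_A^{1/2} A T_A^{1/2}\) by \(\rho\). The factor \(\gamma\) cancels in this product; it only shifts step length between the primal and the dual update. The sketch checks both properties for a hypothetical random matrix standing in for \(A\).

```python
import numpy as np

rng = np.random.default_rng(4)

# hypothetical non-negative system matrix standing in for A
A = rng.uniform(0.0, 1.0, size=(50, 20))
rho = 0.9999


def precond_norm(gamma, rho=rho):
    """Spectral norm of S_A^(1/2) A T_A^(1/2) for the step sizes in the text."""
    S_A = gamma * rho / (A @ np.ones(A.shape[1]))  # gamma * rho / (A 1)
    T_A = rho / (gamma * (A.T @ np.ones(A.shape[0])))  # rho / (gamma * A^T 1)
    K = np.sqrt(S_A)[:, None] * A * np.sqrt(T_A)[None, :]
    return float(np.linalg.norm(K, 2))


print(precond_norm(10.0))  # bounded by rho
print(precond_norm(1.0))  # same value: gamma cancels
```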

```python
# set up the "normal" gradient operator
G = parallelproj.FiniteForwardDifference(pet_lin_op.in_shape)
# calculate the joint vector field based on the structural prior image
joint_vector_field = G(x_struct)
# set up the projected gradient operator P_xi grad
P = parallelproj.GradientFieldProjectionOperator(joint_vector_field, eta=1e-4)
op_G = parallelproj.CompositeLinearOperator((P, G))

# initialize primal and dual variables
x_pdhg = 1.0 * x_mlem
y = 1 - d / (pet_lin_op(x_pdhg) + contamination)

# initialize dual variable for the gradient
w = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev)

z = pet_lin_op.adjoint(y) + op_G.adjoint(w)
zbar = 1.0 * z

# calculate PDHG step sizes
# S_A = gamma * rho / (A 1), replacing zeros in A 1 by its smallest positive value
tmp = pet_lin_op(xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev))
tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp)
S_A = gamma * rho / tmp

# T_A = rho / (gamma * A^T 1)
T_A = (
    (1 / gamma)
    * rho
    / pet_lin_op.adjoint(xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev))
)

# S_G and T_G are scalars based on an estimate of the operator norm of op_G
op_G_norm = op_G.norm(xp, dev, num_iter=100)
S_G = gamma * rho / op_G_norm
T_G = (1 / gamma) * rho / op_G_norm

# T = min(T_A, T_G) elementwise; the constant array must live on the same device
T = xp.where(
    T_A < T_G,
    T_A,
    xp.full(pet_lin_op.in_shape, T_G, dtype=xp.float32, device=dev),
)
```
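`op_G.norm(xp, dev, num_iter=100)` returns an estimate of the operator norm \(\|P_{\xi} \nabla\|\). A common way to estimate such a norm is power iteration on \(K^T K\). The sketch below shows that technique on a dense matrix. Treat it as an assumption about what such a routine does, not as parallelproj's implementation; `power_iteration_norm` is a hypothetical name. For the 1D forward-difference matrix on 16 pixels used here, the exact spectral norm is \(2\cos(\pi/32)\).

```python
import numpy as np


def power_iteration_norm(K, num_iter=100, seed=0):
    """Estimate the spectral norm ||K||_2 via power iteration on K^T K."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(K.shape[1])
    for _ in range(num_iter):
        x = K.T @ (K @ x)
        x /= np.linalg.norm(x)
    # x approximates the top right-singular vector, so ||K x|| approximates ||K||_2
    return float(np.linalg.norm(K @ x))


# 1D forward-difference operator on 16 pixels as an example K
K = np.diff(np.eye(16), axis=0)
est = power_iteration_norm(K, num_iter=500)
print(est, np.linalg.norm(K, 2))  # both close to 2 * cos(pi / 32)
```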

Run PDHG

```python
print("")
cost_pdhg = np.zeros(num_iter_pdhg, dtype=np.float32)

for i in range(num_iter_pdhg):
    # primal update followed by projection onto the non-negative orthant
    x_pdhg -= T * zbar
    x_pdhg = xp.where(x_pdhg < 0, xp.zeros_like(x_pdhg), x_pdhg)

    if track_cost:
        cost_pdhg[i] = cost_function(x_pdhg)

    y_plus = y + S_A * (pet_lin_op(x_pdhg) + contamination)
    # prox of convex conjugate of negative Poisson logL
    y_plus = 0.5 * (y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_A * d))

    w_plus = (w + S_G * op_G(x_pdhg)) / beta
    # prox of convex conjugate of TV: voxelwise projection onto the unit ball
    denom = xp.linalg.vector_norm(w_plus, axis=0)
    w_plus /= xp.where(denom < 1, xp.ones_like(denom), denom)
    w_plus *= beta

    delta_z = pet_lin_op.adjoint(y_plus - y) + op_G.adjoint(w_plus - w)
    y = 1.0 * y_plus
    w = 1.0 * w_plus

    z = z + delta_z
    zbar = z + delta_z

    print(f"PDHG iter {(i+1):04} / {num_iter_pdhg}, cost {cost_pdhg[i]:.7e}", end="\r")
```
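The three lines after `# prox of convex conjugate of TV` compute \(\beta\,\text{prox}_{R^*}((w + S_G \nabla x)/\beta)\). The conjugate of \(R = \|\cdot\|_{1,2}\) is the indicator of the set in which every voxel's vector has Euclidean norm at most 1. Its prox is therefore the projection onto that set, independent of the step size. The standalone NumPy sketch below uses a hypothetical name `project_dual_tv`. It shows that the scaled version clips every voxel's vector to norm \(\beta\) and leaves shorter vectors untouched.

```python
import numpy as np


def project_dual_tv(w, beta):
    """Project each voxel's vector w[:, ...] onto the Euclidean ball of radius beta.

    Equals beta * prox_{R^*}(w / beta) for R = ||.||_{1,2}, since the prox of
    an indicator function is the projection onto the corresponding set.
    """
    v = w / beta
    denom = np.sqrt(np.sum(v**2, axis=0))
    # only vectors outside the unit ball are rescaled
    v = v / np.maximum(denom, 1.0)
    return beta * v


rng = np.random.default_rng(5)
w = rng.normal(0.0, 10.0, size=(3, 8, 8, 4))
beta = 6.0
w_proj = project_dual_tv(w, beta)

norms = np.sqrt(np.sum(w_proj**2, axis=0))
print(norms.max() <= beta + 1e-12)  # True
```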

Visualizations

```python
x_true_np = parallelproj.to_numpy_array(x_true)
x_struct_np = parallelproj.to_numpy_array(x_struct)
x_pdhg_np = parallelproj.to_numpy_array(x_pdhg)

pl2 = x_true_np.shape[2] // 2
pl1 = x_true_np.shape[1] // 2
pl0 = x_true_np.shape[0] // 2

fig, ax = plt.subplots(2, 3, figsize=(9, 5), tight_layout=True)
vmax = 1.2 * x_true_np.max()
ax[0, 0].imshow(x_true_np[:, :, pl2], cmap="Greys", vmin=0, vmax=vmax)
ax[0, 1].imshow(x_pdhg_np[:, :, pl2], cmap="Greys", vmin=0, vmax=vmax)
ax[0, 2].imshow(x_struct_np[:, :, pl2], cmap="Greys")

ax[1, 0].imshow(x_true_np[pl0, :, :].T, cmap="Greys", vmin=0, vmax=vmax)
ax[1, 1].imshow(x_pdhg_np[pl0, :, :].T, cmap="Greys", vmin=0, vmax=vmax)
ax[1, 2].imshow(x_struct_np[pl0, :, :].T, cmap="Greys")

ax[0, 0].set_title("true img", fontsize="medium")
ax[0, 1].set_title(f"DTV PDHG {num_iter_pdhg} it.", fontsize="medium")
ax[0, 2].set_title("structural img", fontsize="medium")
fig.show()

if track_cost:
    fig2, ax2 = plt.subplots(1, 1, figsize=(6, 4), tight_layout=True)
    ax2.plot(cost_pdhg, ".-", label="PDHG")
    ax2.grid(ls=":")
    ax2.legend()
    ax2.set_xlabel("iteration")
    ax2.set_title("cost", fontsize="medium")
    fig2.show()
```

Related examples

PDHG and LM-SPHG to optimize the Poisson logL and total variation

DePierro’s algorithm to optimize the Poisson logL with quadratic intensity prior

TOF listmode MLEM with projection data

TOF-MLEM with projection data
