PDHG and LM-SPHG to optimize the Poisson logL and total variation

This example demonstrates the use of the primal dual hybrid gradient (PDHG) algorithm, the listmode stochastic PDHG (LM-SPDHG) to minimize the negative Poisson log-likelihood function combined with a total variation regularizer:

\[f(x) = \sum_{i=1}^m \bar{d}_i (x) - d_i \log(\bar{d}_i (x)) + \beta \|\nabla x \|_{1,2}\]

subject to

\[x \geq 0\]

using the linear forward model

\[\bar{d}(x) = A x + s\]

Warning

Running this example using GPU arrays (e.g. using cupy as array backend) is highly recommended due to “longer” execution times with CPU arrays

from __future__ import annotations
import matplotlib.pyplot as plt
from vis import show_vol_cuts
from img import elliptic_cylinder_phantom
import numpy as np
import math
from array_api_compat import size

import parallelproj.operators
import parallelproj.tof
import parallelproj.pet_scanners
import parallelproj.pet_lors
import parallelproj.projectors
from parallelproj import to_numpy_array, count_event_multiplicity
from array_utils import suggest_array_backend_and_device

# To use a specific backend and/or device, replace the None arguments, e.g.:
#   xp, dev = suggest_array_backend_and_device(backend="numpy", dev="cpu") or by setting xp and dev manually
xp, dev = suggest_array_backend_and_device(None, None)

Input Parameters

# image scale (can be used to simulated more or less counts)
img_scale = 0.1
# number of MLEM iterations to init. PDHG and LM-SPDHG
num_iter_mlem = 10
# number of PDHG iterations
num_iter_pdhg = 2000
# number of subsets for SPDHG and LM-SPDHG
num_subsets = 10
# number of iterations for stochastic PDHGs
num_iter_spdhg = 200
# prior weight
beta = 10.0
# step size ratio for LM-SPDHG
gamma = 1.0 / img_scale
# rho value for LM-SPHDHG
rho = 0.9999
# contaminaton in every sinogram bin relative to mean of trues sinogram
contam = 1.0


# subset probabilities for SPDHG
p_g = 0.5  # gradient update
p_a = (1 - p_g) / num_subsets  # data subset update

track_cost = True

Simulation of PET data in sinogram space

In this example, we use simulated listmode data for which we first need to setup a sinogram forward model to create a noise-free and noisy emission sinogram that can be converted to listmode data.

Setup of the sinogram forward model

We setup a linear forward operator \(A\) consisting of an image-based resolution model, a non-TOF PET projector and an attenuation model

num_rings = 5
scanner = parallelproj.pet_scanners.RegularPolygonPETScannerGeometry(
    xp,
    dev,
    radius=350.0,
    num_sides=28,
    num_lor_endpoints_per_side=16,
    lor_spacing=4.0,
    ring_positions=xp.linspace(-10, 10, num_rings, device=dev),
    symmetry_axis=2,
)

# setup the LOR descriptor that defines the sinogram

img_shape = (40, 40, 5)
voxel_size = (4.0, 4.0, 4.0)

lor_desc = parallelproj.pet_lors.RegularPolygonPETLORDescriptor(
    scanner,
    parallelproj.pet_lors.Michelogram(
        scanner.num_rings, max_ring_difference=num_rings - 1, span=1
    ),
    radial_trim=170,
    sinogram_order=parallelproj.pet_lors.SinogramSpatialAxisOrder.RVP,
)

proj = parallelproj.projectors.RegularPolygonPETProjector(
    lor_desc, img_shape=img_shape, voxel_size=voxel_size
)

x_true = elliptic_cylinder_phantom(
    xp, dev, image_shape=img_shape, voxel_size=voxel_size
)

# scale image to get more counts
x_true *= img_scale

Attenuation image and sinogram setup

# setup an attenuation image
x_att = 0.01 * xp.astype(x_true > 0, xp.float32)
# calculate the attenuation sinogram
att_sino = xp.exp(-proj(x_att))

Complete sinogram PET forward model setup

We combine an image-based resolution model, a non-TOF or TOF PET projector and an attenuation model into a single linear operator.

# enable TOF - comment if you want to run non-TOF
proj.tof_parameters = parallelproj.tof.TOFParameters(
    num_tofbins=17, tofbin_width=12.0, sigma_tof=12.0
)

# For TOF, att_sino has no TOF-bins dimension while the projector output does.
# broadcast_to adds a trailing singleton via expand_dims and broadcasts it over
# the TOF-bins axis without copying data (zero-stride view).
att_values = (
    xp.broadcast_to(xp.expand_dims(att_sino, axis=-1), proj.out_shape)
    if proj.tof
    else att_sino
)
att_op = parallelproj.operators.ElementwiseMultiplicationOperator(att_values)

res_model = parallelproj.operators.GaussianFilterOperator(
    proj.in_shape, sigma=[4.5 / (2.35 * float(vs)) for vs in proj.voxel_size]
)

# compose all 3 operators into a single linear operator
pet_lin_op = parallelproj.operators.CompositeLinearOperator((att_op, proj, res_model))

Simulation of sinogram projection data

We setup an arbitrary ground truth \(x_{true}\) and simulate noise-free and noisy data \(y\) by adding Poisson noise.

# simulated noise-free data
noise_free_data = pet_lin_op(x_true)

# generate a contant contamination sinogram
contamination = xp.full(
    noise_free_data.shape,
    contam * float(xp.mean(noise_free_data)),
    device=dev,
    dtype=xp.float32,
)

noise_free_data += contamination

# add Poisson noise
np.random.seed(1)
d = xp.asarray(
    np.random.poisson(to_numpy_array(noise_free_data)),
    device=dev,
    dtype=xp.int16,
)

Run quick MLEM as initialization

x_mlem = xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev)
# calculate A^H 1
adjoint_ones = pet_lin_op.adjoint(
    xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev)
)

for i in range(num_iter_mlem):
    print(f"MLEM iteration {(i + 1):03} / {num_iter_mlem:03}", end="\r")
    dbar = pet_lin_op(x_mlem) + contamination
    x_mlem *= pet_lin_op.adjoint(d / dbar) / adjoint_ones

Setup the cost function

def cost_function(img):
    exp = pet_lin_op(img) + contamination
    res = float(xp.sum(exp - d * xp.log(exp)))
    res += beta * float(xp.sum(xp.linalg.vector_norm(op_G(img), axis=0)))
    return res

PDHG

PDHG algorithm to minimize negative Poisson log-likelihood + regularization

Input Poisson data \(d\)
Initialize \(x,y,w,S_A,S_G,T\)
Preprocessing \(\overline{z} = z = A^T y + \nabla^T w\)
Repeat, until stopping criterion fulfilled
Update \(x \gets \text{proj}_{\geq 0} \left( x - T \overline{z} \right)\)
Update \(y^+ \gets \text{prox}_{D^*}^{S_A} ( y + S_A ( A x + s))\)
Update \(w^+ \gets \beta \, \text{prox}_{R^*}^{S_G/\beta} ((w + S_G \nabla x)/\beta)\)
Update \(\Delta z \gets A^T (y^+ - y) + \nabla^T (w^+ - w)\)
Update \(z \gets z + \Delta z\)
Update \(\bar{z} \gets z + \Delta z\)
Update \(y \gets y^+\)
Update \(w \gets w^+\)
Return \(x\)

See [EMS19] [SH22] for more details.

Proximal operator of the convex dual of the negative Poisson log-likelihood

\((\text{prox}_{D^*}^{S}(y))_i = \text{prox}_{D^*}^{S}(y_i) = \frac{1}{2} \left(y_i + 1 - \sqrt{ (y_i-1)^2 + 4 S d_i} \right)\)

Step sizes

\(S_A = \gamma \, \text{diag}(\frac{\rho}{A 1})\)

\(S_G = \gamma \, \text{diag}(\frac{\rho}{|\nabla|})\)

\(T_A = \gamma^{-1} \text{diag}(\frac{\rho}{A^T 1})\)

\(T_G = \gamma^{-1} \text{diag}(\frac{\rho}{|\nabla|})\)

\(T = \min T_A, T_G\) pointwise

op_G = parallelproj.operators.FiniteForwardDifference(pet_lin_op.in_shape)

# initialize primal and dual variables
x_pdhg = 1.0 * x_mlem
y = 1 - d / (pet_lin_op(x_pdhg) + contamination)

# initialize dual variable for the gradient
w = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev)

z = pet_lin_op.adjoint(y) + op_G.adjoint(w)
zbar = 1.0 * z
# calculate PHDG step sizes
tmp = pet_lin_op(xp.ones(pet_lin_op.in_shape, dtype=xp.float32, device=dev))
tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp)
S_A = gamma * rho / tmp

T_A = (
    (1 / gamma)
    * rho
    / pet_lin_op.adjoint(xp.ones(pet_lin_op.out_shape, dtype=xp.float32, device=dev))
)

op_G_norm = op_G.norm(xp, dev, num_iter=100)
S_G = gamma * rho / op_G_norm
T_G = (1 / gamma) * rho / op_G_norm

T = xp.where(
    T_A < T_G, T_A, xp.full(pet_lin_op.in_shape, T_G, device=dev, dtype=xp.float32)
)

Run PDHG

print("")
cost_pdhg = np.zeros(num_iter_pdhg, dtype=np.float32)

for i in range(num_iter_pdhg):
    x_pdhg -= T * zbar
    x_pdhg = xp.where(x_pdhg < 0, xp.zeros_like(x_pdhg), x_pdhg)

    if track_cost:
        cost_pdhg[i] = cost_function(x_pdhg)

    if i == (num_iter_spdhg - 1):
        x_pdhg_early = 1.0 * x_pdhg

    y_plus = y + S_A * (pet_lin_op(x_pdhg) + contamination)
    # prox of convex conjugate of negative Poisson logL
    y_plus = 0.5 * (y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_A * d))

    w_plus = (w + S_G * op_G(x_pdhg)) / beta
    # prox of convex conjugate of TV
    denom = xp.linalg.vector_norm(w_plus, axis=0)
    w_plus /= xp.where(denom < 1, xp.ones_like(denom), denom)
    w_plus *= beta

    delta_z = pet_lin_op.adjoint(y_plus - y) + op_G.adjoint(w_plus - w)
    y = 1.0 * y_plus
    w = 1.0 * w_plus

    z = z + delta_z
    zbar = z + delta_z

    print(f"PDHG iter {(i+1):04} / {num_iter_pdhg}, cost {cost_pdhg[i]:.7e}", end="\r")

Conversion of the emission sinogram to listmode

Using RegularPolygonPETProjector.convert_sinogram_to_listmode() we can convert an integer non-TOF or TOF sinogram to an event list for listmode processing.

Warning

Note: The created event list is “ordered” and should be shuffled depending on the strategy to define subsets in LM-OSEM.

print(f"\nGenerating LM events ({float(xp.sum(d)):.2e})")
event_start_coords, event_end_coords, event_tofbins = proj.convert_sinogram_to_listmode(
    d
)

Shuffle the simulated “ordered” LM events

random_inds = np.random.permutation(event_start_coords.shape[0])
event_start_coords = event_start_coords[random_inds, :]
event_end_coords = event_end_coords[random_inds, :]
event_tofbins = event_tofbins[random_inds]

Setup of the LM subset projectors and LM subset forward models

# slices that define which elements of the event list belong to each subset
# here every "num_subset-th element" is used
subset_slices_lm = [slice(i, None, num_subsets) for i in range(num_subsets)]

lm_pet_subset_linop_seq = []

for i, sl in enumerate(subset_slices_lm):
    subset_lm_proj = parallelproj.projectors.ListmodePETProjector(
        1 * event_start_coords[sl, :],
        1 * event_end_coords[sl, :],
        proj.in_shape,
        proj.voxel_size,
        proj.img_origin,
    )

    # recalculate the attenuation factor for all LM events
    # this needs to be a non-TOF projection
    subset_att_list = xp.exp(-subset_lm_proj(x_att))

    # enable TOF in the LM projector
    subset_lm_proj.tof_parameters = proj.tof_parameters
    if proj.tof:
        subset_lm_proj.event_tofbins = xp.asarray(event_tofbins[sl], copy=True)
        subset_lm_proj.tof = proj.tof

    subset_lm_att_op = parallelproj.operators.ElementwiseMultiplicationOperator(
        subset_att_list
    )

    lm_pet_subset_linop_seq.append(
        parallelproj.operators.CompositeLinearOperator(
            (subset_lm_att_op, subset_lm_proj, res_model)
        )
    )

lm_pet_subset_linop_seq = parallelproj.operators.LinearOperatorSequence(
    lm_pet_subset_linop_seq
)

# create the contamination list
contamination_list = xp.full(
    event_start_coords.shape[0],
    float(xp.reshape(contamination, (size(contamination),))[0]),
    device=dev,
    dtype=xp.float32,
)

Calculate event multiplicity \(\mu\) for each event in the list

events = xp.concat(
    [event_start_coords, event_end_coords, xp.expand_dims(event_tofbins, -1)], axis=1
)
mu = count_event_multiplicity(events)

Listmode SPDHG

Listmode SPDHG algorithm to minimize negative Poisson log-likelihood

Input event list \(N\), contamination list \(s_N\)
Calculate event counts \(\mu_e\) for each \(e\) in \(N\)
Initialize \(x,(S_i)_i,S_G,T,(p_i)_i\)
Initialize list \(y_{N} = 1 - (\mu_N /(A^{LM}_{N} x + s_{N}))\)
Preprocessing \(\overline{z} = z = {A^T} 1 - {A^{LM}_N}^T (y_N-1)/\mu_N\)
Split lists \(N\), \(s_N\) and \(y_N\) into \(n\) sublists \(N_i\), \(y_{N_i}\) and \(s_{N_i}\)
Repeat, until stopping criterion fulfilled
Update \(x \gets \text{proj}_{\geq 0} \left( x - T \overline{z} \right)\)
Select \(i \in \{ 1,\ldots,n+1\}\) randomly according to \((p_i)_i\)
if \(i \leq n\):
Update \(y_{N_i}^+ \gets \text{prox}_{D^*}^{S_i} \left( y_{N_i} + S_i \left(A^{LM}_{N_i} x + s^{LM}_{N_i} \right) \right)\)
Update \(\Delta z \gets {A^{LM}_{N_i}}^T \left(\frac{y_{N_i}^+ - y_{N_i}}{\mu_{N_i}}\right)\)
Update \(y_{N_i} \gets y_{N_i}^+\)
else:
Update \(w^+ \gets \beta \, \text{prox}_{R^*}^{S_G/\beta} ((w + S_G \nabla x)/\beta)\)
Update \(\Delta z \gets \nabla^T (w^+ - w)\)
Update \(w \gets w+\)
Update \(z \gets z + \Delta z\)
Update \(\bar{z} \gets z + (\Delta z/p_i)\)
Return \(x\)

Step sizes

\(S_i = \gamma \, \text{diag}(\frac{\rho}{A^{LM}_{N_i} 1})\)

\(T_i = \gamma^{-1} \text{diag}(\frac{\rho p_i}{{A^{LM}_{N_i}}^T 1/\mu_{N_i}})\)

\(T = \min_{i=1,\ldots,n+1} T_i\) pointwise

Initialize variables

# Intialize image x with solution from quick LM OSEM
x_lmspdhg = 1.0 * x_mlem

# setup dual variable for data subsets
ys = []
for k, sl in enumerate(subset_slices_lm):
    ys.append(
        1 - (mu[sl] / (lm_pet_subset_linop_seq[k](x_lmspdhg) + contamination_list[sl]))
    )

# initialize dual variable for the gradient
w_lm = xp.zeros(op_G.out_shape, dtype=xp.float32, device=dev)

z = 1.0 * adjoint_ones
for k, sl in enumerate(subset_slices_lm):
    z += lm_pet_subset_linop_seq[k].adjoint((ys[k] - 1) / mu[sl])
    # tmp = lm_pet_subset_linop_seq[k].adjoint(1 / mu[sl])
z += op_G.adjoint(w_lm)
zbar = 1.0 * z

Calculate the step sizes

S_A_lm = []
ones_img = xp.ones(img_shape, dtype=xp.float32, device=dev)

for lm_op in lm_pet_subset_linop_seq:
    tmp = lm_op(ones_img)
    tmp = xp.where(tmp == 0, xp.min(tmp[tmp > 0]), tmp)
    S_A_lm.append(gamma * rho / tmp)


T_A_lm = xp.zeros(
    (num_subsets + 1,) + pet_lin_op.in_shape, dtype=xp.float32, device=dev
)
for k, sl in enumerate(subset_slices_lm):
    tmp = lm_pet_subset_linop_seq[k].adjoint(1 / mu[sl])
    T_A_lm[k] = (rho * p_a / gamma) / tmp
T_A_lm[-1] = T_G
T_lm = xp.min(T_A_lm, axis=0)

Run LM-SPDHG

print("")
cost_lmspdhg = np.zeros(num_iter_spdhg, dtype=np.float32)
psnr_lmspdhg = np.zeros(num_iter_spdhg, dtype=np.float32)

psnr_scale = float(xp.max(x_true))

for i in range(num_iter_spdhg):
    subset_sequence = np.random.permutation(2 * num_subsets)

    psnr_lmspdhg[i] = 10 * math.log10(
        (psnr_scale**2) / float(xp.mean((x_lmspdhg - x_pdhg) ** 2))
    )

    if track_cost:
        cost_lmspdhg[i] = cost_function(x_lmspdhg)
    print(
        f"LM-SPDHG iter {(i+1):04} / {num_iter_spdhg}, cost {cost_lmspdhg[i]:.7e}",
        end="\r",
    )

    for k in subset_sequence:
        x_lmspdhg -= T_lm * zbar
        x_lmspdhg = xp.where(x_lmspdhg < 0, xp.zeros_like(x_lmspdhg), x_lmspdhg)

        if k < num_subsets:
            sl = subset_slices_lm[k]
            y_plus = ys[k] + S_A_lm[k] * (
                lm_pet_subset_linop_seq[k](x_lmspdhg) + contamination_list[sl]
            )
            y_plus = 0.5 * (
                y_plus + 1 - xp.sqrt((y_plus - 1) ** 2 + 4 * S_A_lm[k] * mu[sl])
            )
            dz = lm_pet_subset_linop_seq[k].adjoint((y_plus - ys[k]) / mu[sl])
            ys[k] = y_plus
            p = p_a
        else:
            w_plus = (w_lm + S_G * op_G(x_lmspdhg)) / beta
            # prox of convex conjugate of TV
            denom = xp.linalg.vector_norm(w_plus, axis=0)
            w_plus /= xp.where(denom < 1, xp.ones_like(denom), denom)
            w_plus *= beta
            dz = op_G.adjoint(w_plus - w_lm)
            w_lm = 1.0 * w_plus
            p = p_g

        z = z + dz
        zbar = z + (dz / p)

Visualizations

x_true_np = to_numpy_array(x_true)
x_mlem_np = to_numpy_array(x_mlem)
x_pdhg_np = to_numpy_array(x_pdhg)
x_pdhg_early_np = to_numpy_array(x_pdhg_early)
x_lmspdhg_np = to_numpy_array(x_lmspdhg)

vmax = 1.2 * x_true_np.max()

for vol_np, title in [
    (x_true_np, "true image"),
    (x_mlem_np, "init image (MLEM)"),
    (x_pdhg_np, f"PDHG {num_iter_pdhg} it. (reference)"),
    (x_lmspdhg_np, f"LM-SPDHG {num_iter_spdhg} it. / {num_subsets} subsets"),
    (x_pdhg_early_np, f"PDHG {num_iter_spdhg} it."),
]:
    fig_i, _, widgets_i = show_vol_cuts(
        vol_np, voxel_size=voxel_size, vmin=0, vmax=vmax, fig_title=title
    )
    fig_i.show()
if track_cost:
    fig2, ax2 = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)
    for i in range(2):
        ax2[i].plot(cost_pdhg, ".-", label="PDHG")
        ax2[i].plot(cost_lmspdhg, ".-", label="LM-SPDHG")
        ax2[i].grid(ls=":")
        ax2[i].legend()
        ax2[i].set_ylim(None, cost_pdhg[10:].max())
    ax2[1].set_xlim(0, num_iter_spdhg)
    ax2[2].plot(psnr_lmspdhg, ".-")
    ax2[2].grid(ls=":")
    for axx in ax2.ravel():
        axx.set_xlabel("iteration")
    ax2[0].set_title("cost", fontsize="medium")
    ax2[1].set_title("cost (zoom)", fontsize="medium")
    ax2[2].set_title("PSNR LM-SPDHG vs ref", fontsize="medium")
    fig2.show()

Gallery generated by Sphinx-Gallery