"""description of PET LORs (and sinograms bins) consisting of two detector endpoints"""
from __future__ import annotations
import abc
import enum
from types import ModuleType
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from matplotlib.axes import Axes
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
from ._backend import Array, to_numpy_array
from .operators import LinearOperator
from .pet_scanners import (
ModularizedPETScannerGeometry,
RegularPolygonPETScannerGeometry,
)
[docs]
class SinogramSpatialAxisOrder(enum.Enum):
"""order of spatial axis in a sinogram R (radial), V (view), P (plane)"""
RVP = enum.auto()
"""[radial,view,plane]"""
RPV = enum.auto()
"""[radial,plane,view]"""
VRP = enum.auto()
"""[view,radial,plane]"""
VPR = enum.auto()
"""[view,plane,radial]"""
PRV = enum.auto()
"""[plane,radial,view]"""
PVR = enum.auto()
"""[plane,view,radial]"""
[docs]
class Michelogram:
"""Axial plane layout for a cylindrical PET scanner under odd span.
Encapsulates the segment / axial-position combinatorics that map every
valid ring pair :math:`(s, e)` onto a sinogram plane index under
span conventions.
For span :math:`S` (odd) and a maximum ring difference :math:`D`, each
ring pair with :math:`|e - s| \\le D` is assigned a segment via
:meth:`ring_diff_to_segment`. Ring pairs sharing the same
:math:`(\\text{segment},\\; s + e)` collapse into a single plane.
Planes are ordered by :math:`(|\\text{seg}|,\\; -\\text{seg},\\; s + e)`
:math:`[0, +1, -1, +2, -2, \\ldots]` with axial bins increasing in
:math:`s + e` (equivalently in z for equispaced rings).
The class knows nothing about ring z-positions, scanner radius, or
sinogram axis ordering — it operates on pure integer indices. Consumers
(e.g. :class:`RegularPolygonPETLORDescriptor`,
:class:`SinogramAxialCompressionOperator`) combine it with the geometry-
and array-API-specific information they need.
For span ``= 1`` the layout reduces to the unspanned Michelogram (each
ring pair is its own plane with :attr:`max_multiplicity` ``== 1``); the
ordering ``rd = 0, +1, -1, +2, -2, ...`` with each ring difference
sorted by ring sum.
Parameters
----------
num_rings : int
Number of detector rings (:math:`R \\ge 1`).
max_ring_difference : int
Maximum ring difference :math:`|e - s|` considered (:math:`\\ge 0`).
Values larger than ``num_rings - 1`` have no extra effect.
span : int, optional
Axial compression factor — must be odd and :math:`\\ge 1`.
Default ``1`` (no compression).
Examples
--------
>>> m = Michelogram(num_rings=3, max_ring_difference=2, span=3)
>>> int(m.num_planes)
7
>>> int(m.max_multiplicity)
2
>>> int(m.ring_diff_to_segment(0)), int(m.ring_diff_to_segment(2)), \
int(m.ring_diff_to_segment(-2))
(0, 1, -1)
"""
# ------------------------------------------------------------------
# Construction / core formula
# ------------------------------------------------------------------
def __init__(
self,
num_rings: int,
max_ring_difference: int,
span: int = 1,
) -> None:
if not isinstance(num_rings, int) or num_rings < 1:
raise ValueError("num_rings must be a positive integer")
if not isinstance(max_ring_difference, int) or max_ring_difference < 0:
raise ValueError("max_ring_difference must be a non-negative integer")
if not isinstance(span, int) or span < 1 or span % 2 == 0:
raise ValueError("span must be an odd positive integer")
self._num_rings = int(num_rings)
self._max_ring_difference = int(max_ring_difference)
self._span = int(span)
self._half_span = (self._span - 1) // 2
self._build()
[docs]
def ring_diff_to_segment(self, rd: int) -> int:
"""Signed segment number for a given ring difference :math:`e - s`.
Returns
-------
int
``0`` if :math:`|rd| \\le \\text{half\\_span}`, otherwise
:math:`\\pm k` with :math:`k = \\lceil (|rd| - \\text{half\\_span}) / S \\rceil`
and sign equal to that of :math:`rd`.
"""
S = self._span
half_span = self._half_span
abs_rd = abs(rd)
if abs_rd <= half_span:
return 0
k = (abs_rd - half_span + S - 1) // S
return k if rd > 0 else -k
def _build(self) -> None:
"""Compute and cache the full plane layout."""
R = self._num_rings
D = self._max_ring_difference
# Group every valid ring pair (s, e) by (segment, s + e). Iteration
# order here does not matter — we sort the result.
plane_groups: dict[tuple[int, int], list[tuple[int, int]]] = {}
for s in range(R):
for e in range(R):
rd = e - s
if abs(rd) > D:
continue
seg = self.ring_diff_to_segment(rd)
key = (seg, s + e)
plane_groups.setdefault(key, []).append((s, e))
# standard segment sequence + within-segment axial-midpoint order.
sorted_keys = sorted(
plane_groups.keys(), key=lambda k: (abs(k[0]), -k[0], k[1])
)
num_planes = len(sorted_keys)
plane_segment = np.empty(num_planes, dtype=np.int32)
plane_axial_midpoint_int = np.empty(num_planes, dtype=np.int32)
plane_multiplicity = np.empty(num_planes, dtype=np.int32)
for pi, key in enumerate(sorted_keys):
plane_segment[pi] = key[0]
plane_axial_midpoint_int[pi] = key[1]
plane_multiplicity[pi] = len(plane_groups[key])
max_mult = int(plane_multiplicity.max()) if num_planes > 0 else 0
plane_start_rings = np.zeros((num_planes, max_mult), dtype=np.int32)
plane_end_rings = np.zeros((num_planes, max_mult), dtype=np.int32)
plane_mask = np.zeros((num_planes, max_mult), dtype=np.float32)
# Inverse lookup table; -1 indicates an invalid pair (|rd| > D).
plane_for_ring_pair_table = np.full((R, R), -1, dtype=np.int32)
for pi, key in enumerate(sorted_keys):
pairs = plane_groups[key]
for k, (s, e) in enumerate(pairs):
plane_start_rings[pi, k] = s
plane_end_rings[pi, k] = e
plane_mask[pi, k] = 1.0
plane_for_ring_pair_table[s, e] = pi
self._num_planes = num_planes
self._max_multiplicity = max_mult
self._plane_segment = plane_segment
self._plane_axial_midpoint_int = plane_axial_midpoint_int
self._plane_multiplicity = plane_multiplicity
self._plane_start_rings = plane_start_rings
self._plane_end_rings = plane_end_rings
self._plane_mask = plane_mask
self._plane_for_ring_pair_table = plane_for_ring_pair_table
# ------------------------------------------------------------------
# Read-only properties
# ------------------------------------------------------------------
@property
def num_rings(self) -> int:
"""Number of rings."""
return self._num_rings
@property
def max_ring_difference(self) -> int:
"""Maximum ring difference :math:`|e - s|`."""
return self._max_ring_difference
@property
def span(self) -> int:
"""Axial compression factor (odd)."""
return self._span
@property
def num_planes(self) -> int:
"""Total number of sinogram planes."""
return self._num_planes
@property
def max_multiplicity(self) -> int:
"""Largest plane multiplicity (most ring pairs in any one plane)."""
return self._max_multiplicity
@property
def plane_segment(self) -> np.ndarray:
"""Signed segment number for each plane, shape ``(num_planes,)``,
dtype ``int32``."""
return self._plane_segment
@property
def plane_axial_midpoint_int(self) -> np.ndarray:
"""Integer axial midpoint :math:`s + e` (= twice the actual midpoint)
for each plane, shape ``(num_planes,)``, dtype ``int32``."""
return self._plane_axial_midpoint_int
@property
def plane_multiplicity(self) -> np.ndarray:
"""Number of ring pairs contributing to each plane,
shape ``(num_planes,)``, dtype ``int32``."""
return self._plane_multiplicity
@property
def plane_start_rings(self) -> np.ndarray:
"""Contributing start ring indices per plane, right-padded with ``0``.
Shape ``(num_planes, max_multiplicity)``, dtype ``int32``. Use
:attr:`plane_mask` to identify the valid entries.
"""
return self._plane_start_rings
@property
def plane_end_rings(self) -> np.ndarray:
"""Contributing end ring indices per plane, right-padded with ``0``.
Shape ``(num_planes, max_multiplicity)``, dtype ``int32``. Use
:attr:`plane_mask` to identify the valid entries.
"""
return self._plane_end_rings
@property
def plane_mask(self) -> np.ndarray:
"""Validity mask for :attr:`plane_start_rings` / :attr:`plane_end_rings`.
Shape ``(num_planes, max_multiplicity)``, dtype ``float32``. Entries
are ``1.0`` for valid contributing ring pairs and ``0.0`` for
right-padding.
"""
return self._plane_mask
@property
def plane_for_ring_pair_table(self) -> np.ndarray:
"""``(num_rings, num_rings)`` lookup table whose entry ``[s, e]``
is the plane index for ring pair ``(s, e)``, or ``-1`` if
:math:`|e - s| > \\text{max\\_ring\\_difference}`."""
return self._plane_for_ring_pair_table
# ------------------------------------------------------------------
# Lookups
# ------------------------------------------------------------------
[docs]
def plane_for_ring_pair(self, s: int, e: int) -> int:
"""Plane index for the ring pair ``(s, e)``.
Raises
------
IndexError
If either ``s`` or ``e`` is outside ``[0, num_rings)``.
ValueError
If :math:`|e - s| > \\text{max\\_ring\\_difference}`.
"""
if not (0 <= s < self._num_rings) or not (0 <= e < self._num_rings):
raise IndexError(
f"ring indices out of range: ({s}, {e}); "
f"num_rings={self._num_rings}"
)
pi = int(self._plane_for_ring_pair_table[s, e])
if pi < 0:
raise ValueError(
f"ring pair ({s}, {e}) has |rd|={abs(e - s)} > "
f"max_ring_difference={self._max_ring_difference}"
)
return pi
# ------------------------------------------------------------------
# Geometry helpers
# ------------------------------------------------------------------
[docs]
def average_z_per_plane(self, ring_positions) -> tuple[np.ndarray, np.ndarray]:
"""Mean ring z-coordinate per plane, separately for start and end rings.
Equivalent to averaging ``ring_positions`` over the contributing ring
pairs of each plane. For span ``=1`` planes this is trivially the
single contributing ring's z; for span ``> 1`` planes it produces the
averaged-LOR z-position used by the spanned setup of
:class:`RegularPolygonPETLORDescriptor` and by
:meth:`show_segment_lors`.
Parameters
----------
ring_positions : array-like, shape ``(num_rings,)``
z-coordinate of each ring (any backend; converted via ``np.asarray``).
Returns
-------
start_z : np.ndarray, shape ``(num_planes,)``, dtype ``float32``
end_z : np.ndarray, shape ``(num_planes,)``, dtype ``float32``
"""
ring_pos = np.asarray(ring_positions, dtype=np.float64)
if ring_pos.ndim != 1 or ring_pos.shape[0] != self._num_rings:
raise ValueError(
"ring_positions must be a 1-D array of length "
f"num_rings={self._num_rings}"
)
mult = self._plane_multiplicity.astype(np.float64)
start_z = (ring_pos[self._plane_start_rings] * self._plane_mask).sum(
axis=1
) / mult
end_z = (ring_pos[self._plane_end_rings] * self._plane_mask).sum(axis=1) / mult
return start_z.astype(np.float32), end_z.astype(np.float32)
# ------------------------------------------------------------------
# Axial compression
# ------------------------------------------------------------------
[docs]
def compression_index_maps_to(
self, target: "Michelogram"
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Build gather/scatter index maps to a pre-built target Michelogram.
Returns the integer index structures needed to map planes of this
Michelogram onto planes of ``target``.
Both Michelograms must describe the same scanner geometry
(``target.num_rings == self.num_rings``), and ``target.span`` must
be an integer multiple of ``self.span``. Because both spans are
odd by construction, the ratio ``target.span / self.span`` is then
automatically odd, which guarantees that every ring pair of any
input plane shares the same target plane — so the operation is a
single-valued gather.
The target's ``max_ring_difference`` must be at least
``self.max_ring_difference`` so every input ring pair has a target
plane. If it is strictly greater, the resulting maps still work but
some output planes will have zero multiplicity (output bins that no
input ring pair contributes to).
Parameters
----------
target : Michelogram
Pre-built target Michelogram. Validation rules above.
Returns
-------
target_for_p1 : np.ndarray, shape ``(self.num_planes,)``, dtype ``int64``
For each plane of this Michelogram, the corresponding plane index
in ``target``.
idx2d : np.ndarray, shape ``(target.num_planes, target_max_mult)``, ``int64``
For each target plane, the indices in this Michelogram that
contribute, right-padded with ``0``. Use ``mask`` to filter.
mask : np.ndarray, same shape as ``idx2d``, dtype ``float32``
``1.0`` for valid entries, ``0.0`` for right-padding.
target_multiplicity : np.ndarray, shape ``(target.num_planes,)``, ``int32``
Number of self-planes folded into each target plane.
Raises
------
TypeError
If ``target`` is not a :class:`Michelogram` instance.
ValueError
If ``target.num_rings`` differs from ``self.num_rings``;
if ``target.span < self.span``;
if ``self.span`` does not divide ``target.span``;
or if ``target.max_ring_difference < self.max_ring_difference``.
"""
if not isinstance(target, Michelogram):
raise TypeError("target must be a Michelogram instance")
if target.num_rings != self._num_rings:
raise ValueError(
f"target.num_rings ({target.num_rings}) must match "
f"self.num_rings ({self._num_rings})"
)
if target.span < self._span:
raise ValueError(
f"target.span ({target.span}) must be >= self.span ({self._span})"
)
if target.span % self._span != 0:
raise ValueError(
f"target.span ({target.span}) must be an integer multiple "
f"of self.span ({self._span})"
)
if target.max_ring_difference < self._max_ring_difference:
raise ValueError(
f"target.max_ring_difference ({target.max_ring_difference}) "
f"must be >= self.max_ring_difference "
f"({self._max_ring_difference}) so every input ring pair "
"has a target plane"
)
num_planes_in = self._num_planes
target_for_p1 = np.empty(num_planes_in, dtype=np.int64)
# Any ring pair in an input plane maps to the same target plane
# under the divisibility condition checked above, so we take the
# first one as a representative.
for pi in range(num_planes_in):
s = int(self._plane_start_rings[pi, 0])
e = int(self._plane_end_rings[pi, 0])
target_for_p1[pi] = target.plane_for_ring_pair(s, e)
# Invert: per target plane, list contributing input plane indices.
num_planes_out = target.num_planes
groups: list[list[int]] = [[] for _ in range(num_planes_out)]
for pi in range(num_planes_in):
groups[int(target_for_p1[pi])].append(pi)
target_multiplicity = np.fromiter(
(len(g) for g in groups), dtype=np.int32, count=num_planes_out
)
target_max_mult = int(target_multiplicity.max()) if num_planes_out > 0 else 0
idx2d = np.zeros((num_planes_out, target_max_mult), dtype=np.int64)
mask = np.zeros((num_planes_out, target_max_mult), dtype=np.float32)
for n, g in enumerate(groups):
if g:
idx2d[n, : len(g)] = g
mask[n, : len(g)] = 1.0
return target_for_p1, idx2d, mask, target_multiplicity
[docs]
def compression_index_maps(
self, target_span: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Build gather/scatter index maps to a higher-span Michelogram.
Convenience wrapper around :meth:`compression_index_maps_to` that
builds the target Michelogram internally as
``Michelogram(self.num_rings, self.max_ring_difference, span=target_span)``.
Parameters
----------
target_span : int
Odd integer ``>= self.span``; additionally
``(target_span // self.span)`` must be odd.
Returns
-------
See :meth:`compression_index_maps_to`.
Raises
------
ValueError
If ``target_span`` is not a positive odd integer. Further
validation errors are raised by
:meth:`compression_index_maps_to`.
"""
if not isinstance(target_span, int) or target_span < 1 or target_span % 2 == 0:
raise ValueError("target_span must be an odd positive integer")
target = Michelogram(
self._num_rings, self._max_ring_difference, span=target_span
)
return self.compression_index_maps_to(target)
# ------------------------------------------------------------------
# Visualization
# ------------------------------------------------------------------
[docs]
def show(
self,
ax: Axes,
show_merge_lines: bool = True,
plane_index_fontsize: float = 6,
**kwargs,
) -> None:
"""Draw the Michelogram scatter plot onto ``ax``.
Each point represents a valid ring pair ``(s, e)``, colored by
``abs(segment)``. For ``span > 1``, ring pairs that share the same
``(segment, s + e)`` and therefore collapse into the same sinogram
plane are connected by a thin grey line when ``show_merge_lines``
is ``True``.
Parameters
----------
ax : plt.Axes
2-D matplotlib axes (not 3-D).
show_merge_lines : bool, optional
Draw lines connecting ring pairs that merge into the same plane.
Defaults to ``True``. Only has a visible effect for ``span > 1``.
plane_index_fontsize : float, optional
Font size of the per-plane index annotations placed at each
ring-pair (or merged-group) centroid. Defaults to ``6``. Useful
knob when the Michelogram is large (lower to avoid overlap) or
small (raise for readability).
**kwargs
Forwarded to ``ax.scatter`` (e.g. ``s=4``, ``cmap="RdBu_r"``).
"""
self._draw_axes(
ax,
show_merge_lines=show_merge_lines,
plane_index_fontsize=plane_index_fontsize,
**kwargs,
)
[docs]
def show_segment_lors(
self,
ring_positions,
axs=None,
uncompressed_lor_kwargs: dict | None = None,
compressed_lor_kwargs: dict | None = None,
inset_plane_index_fontsize: float = 4,
):
"""Side-view LOR diagram per segment with a Michelogram inset.
Mirrors the descriptor's
:meth:`RegularPolygonPETLORDescriptor.show_segment_lors`, but takes
``ring_positions`` explicitly so the Michelogram can be visualised
standalone (e.g. with ``np.arange(num_rings)`` for a purely
schematic plot, or with the user's actual ring z-positions).
Subplots are arranged in a 2-row grid (when negative segments exist):
* **columns** indexed by ``abs(segment)``: 0, 1, 2, ...
* **row 0** non-negative segments (0, +1, +2, ...)
* **row 1, col 0** Michelogram inset
* **row 1, col >= 1** negative segments (-1, -2, ...)
Each LOR subplot shows the uncompressed (per-ring-pair) LORs as
solid black lines and the compressed (axially-averaged) LORs as
dashed coloured lines.
Parameters
----------
ring_positions : array-like, shape ``(num_rings,)``
z-coordinate of each ring.
axs : 2-D array-like of Axes, optional
Pre-existing axes of shape ``(n_rows, n_cols)``. If ``None``,
a new figure is created.
uncompressed_lor_kwargs : dict, optional
Style overrides for the uncompressed LOR lines.
compressed_lor_kwargs : dict, optional
Style overrides for the compressed LOR lines.
Returns
-------
matplotlib.figure.Figure
"""
R = self._num_rings
D = self._max_ring_difference
ring_pos = np.asarray(ring_positions, dtype=np.float64)
if ring_pos.ndim != 1 or ring_pos.shape[0] != R:
raise ValueError(
f"ring_positions must be a 1-D array of length num_rings={R}"
)
start_z, end_z = self.average_z_per_plane(ring_pos)
start_z_np = np.asarray(start_z, dtype=np.float64)
end_z_np = np.asarray(end_z, dtype=np.float64)
seg_arr_np = np.asarray(self._plane_segment, dtype=np.int32)
all_segs = sorted(set(int(v) for v in seg_arr_np))
abs_segs = sorted(set(abs(s) for s in all_segs))
n_cols = len(abs_segs)
neg_segs = [s for s in all_segs if s < 0]
n_rows = 2 if neg_segs else 1
unc_kw: dict = {"color": "black", "lw": 1.0, "alpha": 0.5}
if uncompressed_lor_kwargs:
unc_kw.update(uncompressed_lor_kwargs)
com_kw: dict = {"lw": 1.5, "alpha": 0.9, "linestyle": "--"}
if compressed_lor_kwargs:
com_kw.update(compressed_lor_kwargs)
created_fig = axs is None
if created_fig:
fig, raw = plt.subplots(
n_rows,
n_cols,
figsize=(3 * n_cols, 4 * n_rows),
squeeze=False,
)
_axs = raw
else:
_axs = np.asarray(axs)
fig = _axs.flat[0].get_figure()
# coordinate normalisation
z_min, z_max = ring_pos.min(), ring_pos.max()
z_span = max(z_max - z_min, 1.0)
margin = 0.12 * z_span
x_L, x_R = -z_span / 2.0, z_span / 2.0
ring_r = 0.35 * z_span / max(R, 1)
# uncompressed ring-pair lookup keyed by signed segment
uncompressed: dict[int, list[tuple[int, int]]] = {s: [] for s in all_segs}
for s_ring in range(R):
for e_ring in range(R):
rd = e_ring - s_ring
if abs(rd) > D:
continue
seg = self.ring_diff_to_segment(rd)
if seg in uncompressed:
uncompressed[seg].append((s_ring, e_ring))
n_colors = len(abs_segs)
base_cmap = plt.get_cmap("tab10" if n_colors <= 10 else "tab20")
for col_idx, abs_seg in enumerate(abs_segs):
color = base_cmap(col_idx)
for row_idx in range(n_rows):
ax = _axs[row_idx, col_idx]
# [1, 0]: Michelogram inset instead of the non-existent
# segment -0.
if row_idx == 1 and abs_seg == 0:
self._draw_axes(
ax, plane_index_fontsize=inset_plane_index_fontsize
)
continue
seg_val = abs_seg if row_idx == 0 else -abs_seg
if seg_val not in all_segs:
ax.axis("off")
continue
# (a) uncompressed ring-pair LORs
for s_r, e_r in uncompressed[seg_val]:
ax.plot(
[x_L, x_R],
[ring_pos[s_r], ring_pos[e_r]],
**unc_kw,
)
# (b) compressed (averaged) LORs
mask = seg_arr_np == seg_val
kw = dict(com_kw)
if "color" not in kw:
kw["color"] = color
for sz, ez in zip(start_z_np[mask], end_z_np[mask]):
ax.plot([x_L, x_R], [float(sz), float(ez)], **kw)
# detector rings: one Circle per ring at each detector side
for xpos in [x_L, x_R]:
for z in ring_pos:
ax.add_patch(
Circle(
(xpos, float(z)),
ring_r,
edgecolor="black",
facecolor="lightgray",
lw=0.6,
zorder=5,
)
)
if seg_val > 0:
seg_label = f"+{abs_seg}"
elif seg_val < 0:
seg_label = f"-{abs_seg}"
else:
seg_label = "0"
n_compressed = int((seg_arr_np == seg_val).sum())
n_uncompressed = len(uncompressed[seg_val])
ax.set_title(
f"seg {seg_label} {n_compressed} / {n_uncompressed}",
fontsize="small",
)
ax.set_xlim(x_L - 2 * ring_r, x_R + 2 * ring_r)
ax.set_ylim(z_min - margin * 2.0, z_max + margin * 2.0)
ax.set_aspect("equal")
ax.axis("off")
if row_idx == 0 and col_idx == 0:
ax.legend(
handles=[
Line2D(
[0],
[0],
color="black",
lw=1.0,
alpha=0.5,
label="uncompressed",
),
Line2D(
[0],
[0],
color="black",
lw=1.5,
linestyle="--",
label="compressed",
),
],
loc="upper right",
fontsize="x-small",
)
if created_fig:
fig.tight_layout()
return fig
def _draw_axes(
self,
ax: Axes,
show_merge_lines: bool = True,
plane_index_fontsize: float = 6,
**kwargs,
) -> None:
"""Internal helper: draw the Michelogram onto an existing axes.
Shared between :meth:`show` (main use) and :meth:`show_segment_lors`
(inset).
"""
R = self._num_rings
D = self._max_ring_difference
S = self._span
# Unroll the padded layout into flat arrays of valid ring pairs.
total = int(self._plane_mask.sum())
start_arr = np.empty(total, dtype=np.float32)
end_arr = np.empty(total, dtype=np.float32)
seg_arr = np.empty(total, dtype=np.int32)
idx = 0
for pi in range(self._num_planes):
seg = int(self._plane_segment[pi])
m = int(self._plane_multiplicity[pi])
for k in range(m):
start_arr[idx] = int(self._plane_start_rings[pi, k])
end_arr[idx] = int(self._plane_end_rings[pi, k])
seg_arr[idx] = seg
idx += 1
abs_seg_arr = np.abs(seg_arr)
n_colors = int(abs_seg_arr.max()) + 1
base_cmap = plt.get_cmap("tab10" if n_colors <= 10 else "tab20")
cmap = ListedColormap([base_cmap(i) for i in range(n_colors)])
norm = BoundaryNorm(np.arange(-0.5, n_colors, 1.0), cmap.N)
kwargs.setdefault("s", 20)
ax.scatter(
start_arr,
end_arr,
c=abs_seg_arr.astype(np.float32),
cmap=cmap,
norm=norm,
**kwargs,
)
if show_merge_lines and S > 1:
for pi in range(self._num_planes):
m = int(self._plane_multiplicity[pi])
if m > 1:
xs = self._plane_start_rings[pi, :m].astype(np.float32)
ys = self._plane_end_rings[pi, :m].astype(np.float32)
order = np.argsort(xs)
ax.plot(
xs[order],
ys[order],
color="gray",
lw=0.5,
alpha=0.5,
)
# Annotate each plane at the group centroid with its plane index.
for pi in range(self._num_planes):
m = int(self._plane_multiplicity[pi])
xs = [int(self._plane_start_rings[pi, k]) for k in range(m)]
ys = [int(self._plane_end_rings[pi, k]) for k in range(m)]
cx = float(np.mean(xs))
cy = float(np.mean(ys))
ax.text(
cx,
cy,
str(pi),
ha="center",
va="center",
fontsize=plane_index_fontsize,
color="black",
fontweight="bold",
zorder=10,
)
ax.set_xlabel("start ring")
ax.set_ylabel("end ring")
ax.set_title(f"Michelogram\n(span={S}, max Dring={D})", fontsize="small")
ax.set_aspect("equal")
ax.set_xlim(-0.5, R - 0.5)
ax.set_ylim(-0.5, R - 0.5)
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"num_rings={self._num_rings}, "
f"max_ring_difference={self._max_ring_difference}, "
f"span={self._span})"
)
[docs]
class PETLORDescriptor(abc.ABC):
"""abstract base class to describe which modules / indices in modules of a
modularized PET scanner are in coincidence; defining geometrical LORs"""
def __init__(self, scanner: ModularizedPETScannerGeometry) -> None:
"""
Parameters
----------
scanner : ModularizedPETScannerGeometry
a modularized PET scanner
"""
self._scanner = scanner
[docs]
@abc.abstractmethod
def get_lor_coordinates(self) -> tuple[Array, Array]:
"""return the start and end coordinates of all (or a subset of) LORs"""
raise NotImplementedError
@property
def scanner(self) -> ModularizedPETScannerGeometry:
"""the scanner for which coincidences are described"""
return self._scanner
@property
def xp(self) -> ModuleType:
"""array module to use for storing the LOR endpoints"""
return self.scanner.xp
@property
def dev(self) -> str:
"""device to use for storing the LOR endpoints"""
return self.scanner.dev
[docs]
class EqualBlockPETLORDescriptor(PETLORDescriptor):
"""LOR descriptor for scanner consisting of block modules where each
block module has the same number of LOR endpoints"""
def __init__(
self, scanner: ModularizedPETScannerGeometry, all_block_pairs: Array
) -> None:
"""
Parameters
----------
scanner : ModularizedPETScannerGeometry
A modularized PET scanner consisting of block modules
with the same number of LOR endpoints.
all_block_pairs : Array
An array containing pairs of integer numbers encoding
which block pairs are in coincidence and form valid LORs.
Returns
-------
None
"""
# check if all modules (blocks) have the same number of LOR enpoints
lor_endpoints_per_block = [x.num_lor_endpoints for x in scanner.modules]
if not all(x == lor_endpoints_per_block[0] for x in lor_endpoints_per_block):
raise ValueError(
"All modules (blocks) must have the same number of LOR endpoints"
)
super().__init__(scanner)
self._scanner = scanner
self._all_block_pairs = self.xp.asarray(all_block_pairs, device=self.dev)
self._num_lorendpoints_per_block = self.scanner.modules[0].num_lor_endpoints
self._num_lors_per_block_pair = self._num_lorendpoints_per_block**2
@property
def all_block_pairs(self) -> Array:
"""all block pairs in coincidence"""
return self._all_block_pairs
@property
def num_block_pairs(self) -> int:
"""number of block pairs in coincidence"""
return self._all_block_pairs.shape[0]
@property
def num_lorendpoints_per_block(self) -> int:
"""number of LOR endpoints per block"""
return self._num_lorendpoints_per_block
@property
def num_lors_per_block_pair(self) -> int:
"""number of LORs per block pair"""
return self._num_lors_per_block_pair
@property
def xp(self) -> ModuleType:
"""array module to use for storing the LOR endpoints"""
return self.scanner.xp
@property
def dev(self) -> str:
"""device to use for storing the LOR endpoints"""
return self.scanner.dev
[docs]
def get_lor_coordinates(
self, block_pair_nums: None | Array = None
) -> tuple[Array, Array]:
"""
Get the coordinates of LORs for the given block pair numbers.
Parameters
----------
block_pair_nums : None or Array, optional
The block pair numbers for which to retrieve the LOR coordinates.
If None, all block pair numbers will be used.
Returns
-------
tuple[Array, Array]
A tuple containing two arrays:
- the start coordinates of the LORs, with shape (N, 3), where N is the total number of LORs.
- the end coordinates of the LORs, with shape (N, 3)
"""
if block_pair_nums is None:
block_pair_nums = self.xp.arange(
self._all_block_pairs.shape[0], device=self.dev
)
assert block_pair_nums is not None
num_selected = block_pair_nums.shape[0]
# get start and end block indices for all selected block pairs
bp = self.xp.take(
self._all_block_pairs, block_pair_nums, axis=0
) # (num_selected, 2)
start_blocks = bp[:, 0] # (num_selected,)
end_blocks = bp[:, 1] # (num_selected,)
# build the within-block endpoint index pairs via meshgrid (computed once)
tmp = self.xp.arange(self._num_lorendpoints_per_block, device=self.dev)
a, b = self.xp.meshgrid(tmp, tmp, indexing="ij")
a = self.xp.reshape(a, (-1,)) # (num_lors_per_block_pair,)
b = self.xp.reshape(b, (-1,)) # (num_lors_per_block_pair,)
# flat index over all (block_pair, lor) combinations
lor_idx = self.xp.arange(
num_selected * self._num_lors_per_block_pair, device=self.dev
)
within_pair_idx = lor_idx % self._num_lors_per_block_pair
pair_idx = lor_idx // self._num_lors_per_block_pair
# tile endpoint indices and repeat block indices across all LORs
a_all = self.xp.take(a, within_pair_idx, axis=0)
b_all = self.xp.take(b, within_pair_idx, axis=0)
start_blocks_all = self.xp.take(start_blocks, pair_idx, axis=0)
end_blocks_all = self.xp.take(end_blocks, pair_idx, axis=0)
xstart = self.scanner.get_lor_endpoints(start_blocks_all, a_all)
xend = self.scanner.get_lor_endpoints(end_blocks_all, b_all)
return xstart, xend
[docs]
def show_block_pair_lors(
self, ax: Axes, block_pair_nums: Array, lw: float = 0.2, **kwargs
) -> None:
"""show all LORs connecting all endpoints between blocks forming a block pairs
Parameters
----------
ax : plt.Axes
a 3D matplotlib axes
block_pair_nums : int
the block pair numbers to show
lw : float, optional
the line width, by default 0.2
"""
xs, xe = self.get_lor_coordinates(block_pair_nums=block_pair_nums)
p1s = to_numpy_array(xs)
p2s = to_numpy_array(xe)
ls = np.hstack([p1s, p2s]).copy()
ls = ls.reshape((-1, 2, 3))
lc = Line3DCollection(ls, linewidths=lw, **kwargs)
ax.add_collection(lc)
[docs]
class RegularPolygonPETLORDescriptor(PETLORDescriptor):
"""LOR descriptor for a regular polygon PET scanner where
we have coincidences within and between "rings (polygons of modules)"
The geometrical LORs can be sorted into a sinogram having a
"plane", "view" and "radial" axis."""
def __init__(
self,
scanner: RegularPolygonPETScannerGeometry,
michelogram: Michelogram | None = None,
radial_trim: int = 3,
sinogram_order: SinogramSpatialAxisOrder = SinogramSpatialAxisOrder.RVP,
) -> None:
"""
Parameters
----------
scanner : RegularPolygonPETScannerGeometry
a regular polygon PET scanner.
michelogram : Michelogram, optional
the axial plane layout — the single source of truth for the
spanning combinatorics (segments, axial midpoints, ring-pair
grouping, ordering). If ``None`` (default), a span-1 layout with
no constraint on the ring difference is used, i.e.
``Michelogram(scanner.num_rings, scanner.num_rings - 1, span=1)``.
The Michelogram must have ``num_rings == scanner.num_rings``.
radial_trim : int, optional
number of geometrial LORs to disregard in the radial direction.
Defaults to 3.
sinogram_order : SinogramSpatialAxisOrder, optional
the order of the sinogram axes. Defaults to
``SinogramSpatialAxisOrder.RVP``.
"""
super().__init__(scanner)
if michelogram is None:
michelogram = Michelogram(
num_rings=scanner.num_rings,
max_ring_difference=scanner.num_rings - 1,
span=1,
)
elif michelogram.num_rings != scanner.num_rings:
raise ValueError(
f"michelogram.num_rings ({michelogram.num_rings}) must equal "
f"scanner.num_rings ({scanner.num_rings})"
)
self._scanner = scanner
self._radial_trim = radial_trim
self._michelogram = michelogram
self._max_ring_difference = self._michelogram.max_ring_difference
self._span = self._michelogram.span
self._num_rad = (scanner.num_lor_endpoints_per_ring + 1) - 2 * self._radial_trim
self._num_views = scanner.num_lor_endpoints_per_ring // 2
self._sinogram_order = sinogram_order
# declare all attributes set by the setup methods so they are
# visible in __init__
self._num_planes: int = 0
# None only when span > 1; the properties guard with AttributeError
# before returning.
self._start_plane_index: Array | None = None
self._end_plane_index: Array | None = None
# always set to a real Array by the setup methods
self._start_plane_z: Array = None # type: ignore[assignment]
self._end_plane_z: Array = None # type: ignore[assignment]
self._plane_multiplicity: Array = None # type: ignore[assignment]
self._plane_segment: Array = None # type: ignore[assignment]
self._start_in_ring_index: Array = None # type: ignore[assignment]
self._end_in_ring_index: Array = None # type: ignore[assignment]
self._setup_plane_data()
self._setup_view_indices()
@property
def scanner(self) -> RegularPolygonPETScannerGeometry:
"""the scanner for which coincidences are described"""
return self._scanner
@property
def radial_trim(self) -> int:
"""number of geometrial LORs to disregard in the radial direction"""
return self._radial_trim
@property
def max_ring_difference(self) -> int:
"""the maximum ring difference"""
return self._max_ring_difference
@property
def num_planes(self) -> int:
"""number of planes in the sinogram"""
return self._num_planes
@property
def num_rad(self) -> int:
"""number of radial elements in the sinogram"""
return self._num_rad
@property
def num_views(self) -> int:
"""number of views in the sinogram"""
return self._num_views
@property
def span(self) -> int:
"""axial compression factor (1 = no compression)"""
return self._span
@property
def michelogram(self) -> Michelogram:
"""The :class:`Michelogram` describing the axial plane layout.
This is the single source of truth for the spanning combinatorics
(segments, axial midpoints, ring-pair grouping, ordering). Useful
for visualization, axial compression operators, or any user code
that needs access to the integer ring-pair structure.
"""
return self._michelogram
@property
def start_plane_index(self) -> Array:
"""start ring index for all planes (only defined for span=1)"""
if self._span > 1:
raise AttributeError(
"start_plane_index is not defined for span > 1. Use start_plane_z instead."
)
assert self._start_plane_index is not None
return self._start_plane_index
@property
def end_plane_index(self) -> Array:
"""end ring index for all planes (only defined for span=1)"""
if self._span > 1:
raise AttributeError(
"end_plane_index is not defined for span > 1. Use end_plane_z instead."
)
assert self._end_plane_index is not None
return self._end_plane_index
@property
def start_plane_z(self) -> Array:
"""start z-coordinate for all planes (averaged over constituent ring pairs for span > 1)"""
return self._start_plane_z
@property
def end_plane_z(self) -> Array:
"""end z-coordinate for all planes (averaged over constituent ring pairs for span > 1)"""
return self._end_plane_z
@property
def plane_multiplicity(self) -> Array:
"""number of ring pairs contributing to each plane (always 1 for span=1)"""
return self._plane_multiplicity
@property
def plane_segment(self) -> Array:
"""segment number for each plane (equals abs(rd) for span=1)"""
return self._plane_segment
@property
def start_in_ring_index(self) -> Array:
"""start index within ring for all views - shape (num_view, num_rad)"""
return self._start_in_ring_index
@property
def end_in_ring_index(self) -> Array:
"""end index within ring for all views - shape (num_view, num_rad)"""
return self._end_in_ring_index
@property
def sinogram_order(self) -> SinogramSpatialAxisOrder:
"""the order of the sinogram axes"""
return self._sinogram_order
@property
def plane_axis_num(self) -> int:
"""the axis number of the plane axis"""
return self.sinogram_order.name.find("P")
@property
def radial_axis_num(self) -> int:
"""the axis number of the radial axis"""
return self.sinogram_order.name.find("R")
@property
def view_axis_num(self) -> int:
"""the axis number of the view axis"""
return self.sinogram_order.name.find("V")
@property
def spatial_sinogram_shape(self) -> tuple[int, ...]:
"""the shape of the sinogram in spatial order"""
shape = [0, 0, 0]
shape[self.plane_axis_num] = self.num_planes
shape[self.view_axis_num] = self.num_views
shape[self.radial_axis_num] = self.num_rad
return tuple(shape)
def __str__(self) -> str:
"""string representation"""
return (
self.__class__.__name__
+ " with spatial sinogram shape ("
+ ", ".join(
[
f"{self.spatial_sinogram_shape[i]}{self.sinogram_order.name[i]}"
for i in range(3)
]
)
+ ")"
)
def _setup_plane_data(self) -> None:
"""Project the Michelogram's per-plane data onto the scanner's xp/device.
Reads ``plane_segment`` and ``plane_multiplicity`` directly from the
Michelogram and casts them to ``xp`` arrays on the scanner's device.
Computes ``start_plane_z`` / ``end_plane_z`` by averaging the scanner's
ring positions over each plane's contributing ring pairs via
:meth:`Michelogram.average_z_per_plane`.
For ``span == 1`` the per-plane data is single-valued, so
``start_plane_index`` and ``end_plane_index`` are exposed as 1-D
arrays of ring indices (the first column of the Michelogram's padded
layout). For ``span > 1`` they remain ``None`` and the
:attr:`start_plane_index` / :attr:`end_plane_index` properties raise
``AttributeError``; the padded per-plane ring indices are available
via ``self.michelogram.plane_start_rings`` / ``plane_end_rings``.
"""
m = self._michelogram
xp = self.xp
dev = self.dev
self._num_planes = m.num_planes
self._plane_segment = xp.asarray(m.plane_segment, device=dev)
self._plane_multiplicity = xp.asarray(m.plane_multiplicity, device=dev)
ring_positions_np = np.asarray(
to_numpy_array(self._scanner.ring_positions), dtype=np.float64
)
start_z, end_z = m.average_z_per_plane(ring_positions_np)
self._start_plane_z = xp.asarray(start_z, device=dev)
self._end_plane_z = xp.asarray(end_z, device=dev)
if self._span == 1:
self._start_plane_index = xp.asarray(
m.plane_start_rings[:, 0], device=dev
)
self._end_plane_index = xp.asarray(
m.plane_end_rings[:, 0], device=dev
)
else:
self._start_plane_index = None
self._end_plane_index = None
def _setup_view_indices(self) -> None:
"""setup the start / end view indices"""
n = self._scanner.num_lor_endpoints_per_ring
m = 2 * (n // 2)
self._start_in_ring_index = self.xp.zeros(
(self._num_views, self._num_rad), dtype=self.xp.int32, device=self.dev
)
self._end_in_ring_index = self.xp.zeros(
(self._num_views, self._num_rad), dtype=self.xp.int32, device=self.dev
)
for view in np.arange(self._num_views):
self._start_in_ring_index[view, :] = self.xp.astype(
(
self.xp.concat((self.xp.arange(m) // 2, self.xp.asarray([n // 2])))
- int(view)
)[self._radial_trim : -self._radial_trim],
self.xp.int32,
)
self._end_in_ring_index[view, :] = self.xp.astype(
(
self.xp.concat(
(self.xp.asarray([-1]), -((self.xp.arange(m) + 4) // 2))
)
- int(view)
)[self._radial_trim : -self._radial_trim],
self.xp.int32,
)
# shift the negative indices
self._start_in_ring_index = self.xp.where(
self._start_in_ring_index >= 0,
self._start_in_ring_index,
self._start_in_ring_index + n,
)
self._end_in_ring_index = self.xp.where(
self._end_in_ring_index >= 0,
self._end_in_ring_index,
self._end_in_ring_index + n,
)
[docs]
def get_lor_coordinates(
self,
views: None | Array = None,
) -> tuple[Array, Array]:
"""return the start and end coordinates of all LORs / or a subset of views
Parameters
----------
views : None | Array, optional
the views to consider, by default None means all views
Returns
-------
xstart, xend : Array
2 dimensional floating point arrays containing the start and end coordinates of all LORs
"""
if views is None:
views = self.xp.arange(self.num_views, device=self.dev)
# --- (1) setup the LOR start / end points for all views of plane 0
start_in_ring_index = self.xp.take(self.start_in_ring_index, views, axis=0)
end_in_ring_index = self.xp.take(self.end_in_ring_index, views, axis=0)
if self.view_axis_num > self.radial_axis_num:
start_in_ring_index = start_in_ring_index.T
end_in_ring_index = end_in_ring_index.T
shape_2d = start_in_ring_index.shape
start_inds_2d = self.xp.reshape(start_in_ring_index, (-1,))
end_inds_2d = self.xp.reshape(end_in_ring_index, (-1,))
xstart_2d = self.xp.reshape(
self.scanner.get_lor_endpoints(
self.xp.zeros_like(start_inds_2d), start_inds_2d
),
shape_2d + (3,),
)
xend_2d = self.xp.reshape(
self.scanner.get_lor_endpoints(
self.xp.zeros_like(end_inds_2d), end_inds_2d
),
shape_2d + (3,),
)
xstart_3d = []
xend_3d = []
# --- (2) stack copies of the plane 0 LOR start / end points for all planes with updated "z" coordinates
for i in range(self.num_planes):
xstart = self.xp.asarray(xstart_2d, copy=True)
xend = self.xp.asarray(xend_2d, copy=True)
xstart[..., self._scanner.symmetry_axis] = float(self._start_plane_z[i])
xend[..., self._scanner.symmetry_axis] = float(self._end_plane_z[i])
xstart_3d.append(xstart)
xend_3d.append(xend)
xstart_3d = self.xp.stack(xstart_3d, axis=self.plane_axis_num)
xend_3d = self.xp.stack(xend_3d, axis=self.plane_axis_num)
return xstart_3d, xend_3d
[docs]
def show_views(
self, ax: Axes, views: Array, planes: Array, lw: float = 0.2, **kwargs
) -> None:
"""show all LORs of a single view in a given plane
Parameters
----------
ax : plt.Axes
a 3D matplotlib axes
view : int
the view number
plane : int
the plane number
lw : float, optional
the line width, by default 0.2
"""
xs, xe = self.get_lor_coordinates(views=views)
xs = self.xp.reshape(
self.xp.take(xs, planes, axis=self.plane_axis_num), (-1, 3)
)
xe = self.xp.reshape(
self.xp.take(xe, planes, axis=self.plane_axis_num), (-1, 3)
)
p1s = to_numpy_array(xs)
p2s = to_numpy_array(xe)
ls = np.hstack([p1s, p2s]).copy()
ls = ls.reshape((-1, 2, 3))
lc = Line3DCollection(ls, linewidths=lw, **kwargs)
ax.add_collection(lc)
[docs]
def show_michelogram(
self,
ax: Axes,
show_merge_lines: bool = True,
**kwargs,
) -> None:
"""Visualize the Michelogram.
Thin wrapper around :meth:`Michelogram.show`; see that method for full
documentation of arguments.
"""
self._michelogram.show(ax, show_merge_lines=show_merge_lines, **kwargs)
[docs]
def show_segment_lors(
self,
axs=None,
uncompressed_lor_kwargs: dict | None = None,
compressed_lor_kwargs: dict | None = None,
):
"""Side-view LOR diagram per segment with the Michelogram inset.
Thin wrapper around :meth:`Michelogram.show_segment_lors`; this
method supplies the scanner's ring positions automatically. See
:meth:`Michelogram.show_segment_lors` for full documentation.
"""
ring_positions_np = np.asarray(
to_numpy_array(self._scanner.ring_positions), dtype=np.float64
)
return self._michelogram.show_segment_lors(
ring_positions_np,
axs=axs,
uncompressed_lor_kwargs=uncompressed_lor_kwargs,
compressed_lor_kwargs=compressed_lor_kwargs,
)
[docs]
def get_distributed_views_and_slices(
self, num_subsets: int, num_dim: int
) -> tuple[list[Array], list[tuple[slice, ...]]]:
"""distribute sinogram views numbers into subsets
Parameters
----------
num_subsets : int
number of subsets
num_dim : int
number of dimensions of the sinogram
to setup the subset slices
(e.g. 3 for non-TOF, 4 for TOF)
Returns
-------
tuple[list[Array], list[tuple[slice, ...]]]
subset views numbers and subset slices
"""
subset_nums = []
for i in range(num_subsets // 2):
subset_nums += [x for x in range(i, num_subsets, num_subsets // 2)]
subset_slices = []
subset_views = []
all_views = self.xp.arange(self.num_views, device=self.dev)
for i in subset_nums:
sl = num_dim * [slice(None)]
sl[self.view_axis_num] = slice(i, None, num_subsets)
sl = tuple(sl)
subset_slices.append(sl)
subset_views.append(all_views[sl[self.view_axis_num]])
return subset_views, subset_slices
[docs]
class SinogramAxialCompressionOperator(LinearOperator):
"""Linear operator that axially compresses a span-1 PET sinogram into a higher odd span.
For an input :class:`RegularPolygonPETLORDescriptor` with ``span=1`` and a
target odd span :math:`S`, every span-1 ring pair
:math:`(s, e)` is assigned to an output bin
:math:`(\\text{segment}, \\text{axial midpoint})` where
* ``segment`` is determined by the ring difference :math:`rd = e - s` under
target span :math:`S`
:meth:`RegularPolygonPETLORDescriptor._ring_diff_to_segment`),
* ``axial midpoint`` is :math:`s + e` (an integer equal to twice the actual
midpoint).
All span-1 ring pairs sharing the same
:math:`(\\text{segment}, s + e)` collapse into a single output plane.
Two reduction modes are supported:
* ``mode="sum"`` (default). The output plane is the **sum** of the
contributing input planes:
.. math::
y_n \\;=\\; \\sum_{p_1 \\in \\mathcal{G}(n)} x_{p_1}
\\qquad
\\left(G^T y\\right)_{p_1} \\;=\\; y_{\\,\\tau(p_1)}\\,.
This is the natural reduction for **counts-like** sinograms — emission
data, measured counts, randoms, etc. — which add when ring pairs are
grouped together.
* ``mode="average"``. The output plane is the **mean** of the
contributing input planes:
.. math::
y_n \\;=\\; \\frac{1}{m_n} \\sum_{p_1 \\in \\mathcal{G}(n)} x_{p_1}
\\qquad
\\left(G_{\\rm avg}^T y\\right)_{p_1}
\\;=\\; \\frac{y_{\\,\\tau(p_1)}}{m_{\\,\\tau(p_1)}}\\,.
This is the natural reduction for **multiplicative-factor** sinograms —
attenuation factors, sensitivity / normalisation factors, geometric
efficiency — which should *average* rather than *sum* when ring pairs
are grouped together.
In both expressions, :math:`\\mathcal{G}(n)` is the set of input plane
indices mapped to output plane :math:`n`, :math:`m_n = |\\mathcal{G}(n)|`
is the plane multiplicity, and :math:`\\tau(p_1)` is the output plane
index for input plane :math:`p_1`.
Output plane ordering matches that of
:class:`RegularPolygonPETLORDescriptor` constructed with the same scanner,
``radial_trim``, ``max_ring_difference``, and ``sinogram_order`` but with
``span=target_span``. That companion descriptor is exposed as
:attr:`out_lor_descriptor` for visualization (e.g. ``show_michelogram``,
``show_segment_lors``) or for composing the operator with a span-:math:`S`
projector.
The closed-form operator 2-norms are
.. math::
\\|G_{\\rm sum}\\|_2 = \\sqrt{\\max_n m_n}\\,, \\qquad
\\|G_{\\rm avg}\\|_2 = 1 / \\sqrt{\\min_n m_n}\\,,
derived from :math:`G_{\\rm sum} G_{\\rm sum}^T = \\operatorname{diag}(m_n)`
and :math:`G_{\\rm avg} G_{\\rm avg}^T = \\operatorname{diag}(1/m_n)`.
:meth:`norm` returns these directly without power iteration.
Parameters
----------
lor_descriptor : RegularPolygonPETLORDescriptor
A ``span=1`` LOR descriptor whose sinogram is to be compressed.
target_span : int
Odd integer ``>= 1`` giving the target axial compression. ``1`` is
accepted and yields an identity-like operator (each input plane maps to
a single output plane in the same span-1 order).
mode : {"sum", "average"}, optional
Reduction mode. ``"sum"`` (default) is appropriate for counts-like
sinograms; ``"average"`` is appropriate for multiplicative-factor
sinograms such as attenuation or sensitivity factors.
num_tof_bins : int or None, optional
If ``None`` (default), the operator acts on the 3D spatial sinogram
with shape :attr:`RegularPolygonPETLORDescriptor.spatial_sinogram_shape`.
If a positive integer, the operator acts on a 4D TOF sinogram whose
trailing axis (size ``num_tof_bins``) is the TOF axis and is passed
through unchanged.
Examples
--------
>>> import array_api_compat.numpy as xp
>>> import parallelproj.pet_scanners as pps
>>> import parallelproj.pet_lors as ppl
>>> scanner = pps.RegularPolygonPETScannerGeometry(
... xp, "cpu", radius=65.0, num_sides=12, num_lor_endpoints_per_side=4,
... lor_spacing=4.0, ring_positions=xp.asarray([0.0, 1.0, 2.0]),
... symmetry_axis=2,
... )
>>> lor_s1 = ppl.RegularPolygonPETLORDescriptor(
... scanner, ppl.Michelogram(scanner.num_rings, 2, span=1), radial_trim=1,
... )
>>> comp = ppl.SinogramAxialCompressionOperator(lor_s1, target_span=3)
>>> comp.in_shape, comp.out_shape # doctest: +SKIP
((..., ..., 9), (..., ..., 7))
>>> comp.adjointness_test(xp, "cpu")
True
"""
def __init__(
self,
lor_descriptor: RegularPolygonPETLORDescriptor,
target_span: int,
mode: str = "sum",
num_tof_bins: int | None = None,
) -> None:
if not isinstance(lor_descriptor, RegularPolygonPETLORDescriptor):
raise TypeError("lor_descriptor must be a RegularPolygonPETLORDescriptor")
if lor_descriptor.span != 1:
raise ValueError("input lor_descriptor must have span=1")
if not isinstance(target_span, int) or target_span < 1 or target_span % 2 == 0:
raise ValueError("target_span must be an odd positive integer")
if mode not in ("sum", "average"):
raise ValueError(
f"mode must be 'sum' or 'average', got {mode!r}"
)
if num_tof_bins is not None and (
not isinstance(num_tof_bins, int) or num_tof_bins < 1
):
raise ValueError("num_tof_bins must be a positive integer or None")
super().__init__()
self._lor_descriptor = lor_descriptor
self._target_span = int(target_span)
self._mode = mode
self._num_tof_bins = num_tof_bins
self._xp = lor_descriptor.xp
self._dev = lor_descriptor.dev
self._plane_axis = lor_descriptor.plane_axis_num
# Build the target Michelogram exactly once and reuse it for both
# the companion descriptor and the compression index maps below.
target_michelogram = Michelogram(
num_rings=lor_descriptor.scanner.num_rings,
max_ring_difference=lor_descriptor.max_ring_difference,
span=self._target_span,
)
self._out_lor_descriptor = RegularPolygonPETLORDescriptor(
scanner=lor_descriptor.scanner,
michelogram=target_michelogram,
radial_trim=lor_descriptor.radial_trim,
sinogram_order=lor_descriptor.sinogram_order,
)
self._build_index_maps(target_michelogram)
# in/out shapes honour sinogram_order's plane_axis_num and optional TOF.
spatial_in = tuple(lor_descriptor.spatial_sinogram_shape)
spatial_out = tuple(self._out_lor_descriptor.spatial_sinogram_shape)
if num_tof_bins is None:
self._in_shape = spatial_in
self._out_shape = spatial_out
else:
self._in_shape = spatial_in + (int(num_tof_bins),)
self._out_shape = spatial_out + (int(num_tof_bins),)
def _build_index_maps(self, target_michelogram: Michelogram) -> None:
"""Build the gather/scatter index structures from the Michelogram.
All the combinatorial work — segment assignment, ring-pair grouping,
STIR-standard plane ordering, padded index construction — lives on
:class:`Michelogram`. This method just converts those numpy arrays
to the descriptor's ``xp`` and ``dev`` and stores them.
``target_michelogram`` is the pre-built target Michelogram already
used to construct the companion span-N descriptor, reused here so
the layout is built only once per operator.
Stores on ``self``:
* ``_target_for_p1`` : shape ``(num_planes_1,)``, the output plane
index for every input plane. Used by :meth:`_adjoint`.
* ``_idx2d_flat`` : shape ``(num_planes_n * max_mult,)``, flattened
gather index. ``xp.take(x, _idx2d_flat, axis=plane_axis)`` followed
by a reshape gives ``(num_planes_n, max_mult)`` along the plane axis.
* ``_mask2d`` : shape ``(num_planes_n, max_mult)``, ``1.0`` for
valid entries and ``0.0`` for right-padding. Used to zero-out
padding contributions in :meth:`_apply`.
* ``_multiplicity`` : shape ``(num_planes_n,)``, multiplicity of each
output plane.
* ``_max_mult`` : the largest plane multiplicity.
Because the companion span-N descriptor is built from the same
Michelogram instance, its plane ordering and per-plane multiplicity
agree with this operator's by construction; no cross-check is needed.
"""
target_for_p1, idx2d, mask2d, multiplicity = (
self._lor_descriptor.michelogram.compression_index_maps_to(
target_michelogram
)
)
num_planes_n = int(multiplicity.shape[0])
max_mult = int(idx2d.shape[1]) if num_planes_n > 0 else 0
xp = self._xp
dev = self._dev
self._num_planes_1 = int(target_for_p1.shape[0])
self._num_planes_n = num_planes_n
self._max_mult = max_mult
self._min_mult = int(multiplicity.min()) if num_planes_n > 0 else 0
self._target_for_p1 = xp.asarray(target_for_p1, device=dev)
# store flat indices because the array API standard only requires
# 1-D indices for xp.take (multi-dim take is not portable).
self._idx2d_flat = xp.asarray(idx2d.reshape(-1), device=dev)
self._mask2d = xp.asarray(mask2d, device=dev)
self._multiplicity = xp.asarray(multiplicity, device=dev)
# Pre-computed reciprocals used by mode="average". Stored as
# multiplications instead of divisions inside the hot path.
inv_multiplicity = (1.0 / multiplicity.astype(np.float32)).astype(np.float32)
self._inv_multiplicity = xp.asarray(inv_multiplicity, device=dev)
# inv-multiplicity broadcast onto the input-plane axis (one entry per
# input plane, equal to 1/m_{tau(p1)}). Used by the average-mode
# adjoint.
inv_multiplicity_at_target = inv_multiplicity[target_for_p1]
self._inv_multiplicity_at_target = xp.asarray(
inv_multiplicity_at_target, device=dev
)
# ------------------------------------------------------------------
# LinearOperator interface
# ------------------------------------------------------------------
@property
def in_shape(self) -> tuple[int, ...]:
return self._in_shape
@property
def out_shape(self) -> tuple[int, ...]:
return self._out_shape
def _apply(self, x: Array) -> Array:
"""Compress along the plane axis.
For ``mode="sum"``: ``y_n = sum_{p1 in group(n)} x_{p1}``.
For ``mode="average"``: divide the sum by the per-plane multiplicity.
"""
xp = self._xp
ax = self._plane_axis
# gather all contributing input planes per output plane. After the
# 1-D take we still have x.ndim axes, with the plane axis enlarged to
# num_planes_n * max_mult; the reshape then splits that into
# (num_planes_n, max_mult).
gathered = xp.take(x, self._idx2d_flat, axis=ax)
new_shape = list(x.shape)
new_shape[ax] = self._num_planes_n
new_shape.insert(ax + 1, self._max_mult)
gathered = xp.reshape(gathered, tuple(new_shape))
# Broadcast the (num_planes_n, max_mult) mask across all other axes.
mask_shape = [1] * gathered.ndim
mask_shape[ax] = self._num_planes_n
mask_shape[ax + 1] = self._max_mult
gathered = gathered * xp.reshape(self._mask2d, tuple(mask_shape))
# Sum over the multiplicity axis (at ax + 1).
result = xp.sum(gathered, axis=ax + 1)
if self._mode == "average":
# Divide every output plane by its multiplicity m_n.
inv_shape = [1] * result.ndim
inv_shape[ax] = self._num_planes_n
result = result * xp.reshape(self._inv_multiplicity, tuple(inv_shape))
return result
def _adjoint(self, y: Array) -> Array:
"""Expand the compressed sinogram back along the plane axis.
For ``mode="sum"``: ``x_{p1} = y_{tau(p1)}`` — each input plane gets
the value of its target output plane.
For ``mode="average"``: the broadcast value is additionally divided
by the multiplicity of the target output plane.
"""
xp = self._xp
ax = self._plane_axis
result = xp.take(y, self._target_for_p1, axis=ax)
if self._mode == "average":
# Divide every input plane by m_{tau(p1)}.
inv_shape = [1] * result.ndim
inv_shape[ax] = self._num_planes_1
result = result * xp.reshape(
self._inv_multiplicity_at_target, tuple(inv_shape)
)
return result
[docs]
def norm(
self,
xp: ModuleType,
dev: str,
num_iter: int = 30,
iscomplex: bool = False,
verbose: bool = False,
) -> float:
"""Operator 2-norm in closed form.
Because each input plane belongs to exactly one output plane,
* ``mode="sum"``: :math:`G G^T = \\operatorname{diag}(m_n)`
and therefore :math:`\\|G\\|_2 = \\sqrt{\\max_n m_n}`.
* ``mode="average"``: :math:`G_{\\rm avg} G_{\\rm avg}^T
= \\operatorname{diag}(1/m_n)`
and therefore :math:`\\|G_{\\rm avg}\\|_2 = 1 / \\sqrt{\\min_n m_n}`.
Both norms are independent of TOF, ``xp``, and ``dev``; the inherited
signature is retained for compatibility with
:class:`LinearOperator.norm` but its arguments (``xp``, ``dev``,
``num_iter``, ``iscomplex``, ``verbose``) are ignored.
"""
if self._mode == "sum":
return float(np.sqrt(self._max_mult))
# mode == "average"
return float(1.0 / np.sqrt(self._min_mult))
# ------------------------------------------------------------------
# Public read-only properties
# ------------------------------------------------------------------
@property
def lor_descriptor(self) -> RegularPolygonPETLORDescriptor:
"""The input (span-1) LOR descriptor."""
return self._lor_descriptor
@property
def out_lor_descriptor(self) -> RegularPolygonPETLORDescriptor:
"""Auto-built companion descriptor whose plane ordering matches this
operator's output."""
return self._out_lor_descriptor
@property
def target_span(self) -> int:
"""Target span (odd, >= 1)."""
return self._target_span
@property
def mode(self) -> str:
"""Reduction mode, either ``"sum"`` or ``"average"``."""
return self._mode
@property
def num_tof_bins(self) -> int | None:
"""Number of TOF bins, or ``None`` for a non-TOF operator."""
return self._num_tof_bins
@property
def plane_multiplicity(self) -> Array:
"""Number of span-1 planes that collapse into each output plane.
Shape ``(num_planes_n,)``. Equals the diagonal of
:math:`G G^T`.
"""
return self._multiplicity
@property
def target_plane_for_input_plane(self) -> Array:
"""Output plane index for each span-1 input plane.
Shape ``(num_planes_1,)``. Useful for the closed-form check
:math:`(G^T G\\,\\mathbf{1})_{p_1} = m_{\\,\\tau(p_1)}`.
"""
return self._target_for_p1
@property
def max_plane_multiplicity(self) -> int:
"""Largest plane multiplicity (:math:`\\|G\\|_2^2`)."""
return self._max_mult
@property
def num_planes_in(self) -> int:
"""Number of span-1 input planes."""
return self._num_planes_1
@property
def num_planes_out(self) -> int:
"""Number of span-:math:`S` output planes."""
return self._num_planes_n
def __str__(self) -> str:
tof_str = (
f", {self._num_tof_bins} TOF bins" if self._num_tof_bins is not None else ""
)
return (
f"{self.__class__.__name__}("
f"target_span={self._target_span}, mode={self._mode!r}, "
f"num_planes: {self._num_planes_1} -> {self._num_planes_n}, "
f"max_multiplicity={self._max_mult}{tof_str})"
)