# Source code for parallelproj.pet_scanners

"""description of PET scanner geometries (detector coordinates)"""

from __future__ import annotations

import abc
from parallelproj import Array
import matplotlib.pyplot as plt
import numpy as np
import array_api_compat

from types import ModuleType
from array_api_compat import size

from .backend import to_numpy_array


class PETScannerModule(abc.ABC):
    """abstract base class for PET scanner module

    A module is a collection of ``num_lor_endpoints`` LOR endpoints.
    Subclasses implement :meth:`get_raw_lor_endpoints`; :meth:`get_lor_endpoints`
    applies the (optional) affine transformation on top of the raw coordinates.
    """

    def __init__(
        self,
        xp: ModuleType,
        dev: str,
        num_lor_endpoints: int,
        affine_transformation_matrix: Array | None = None,
    ) -> None:
        """
        Parameters
        ----------
        xp: ModuleType
            array module to use for storing the LOR endpoints
        dev: str
            device to use for storing the LOR endpoints
        num_lor_endpoints : int
            number of LOR endpoints in the module
        affine_transformation_matrix : Array | None, optional
            4x4 affine transformation matrix applied to the LOR endpoint coordinates,
            default None
            if None, the 4x4 identity matrix is used
        """
        self._xp = xp
        self._dev = dev
        self._num_lor_endpoints = num_lor_endpoints
        self._lor_endpoint_numbers = self.xp.arange(num_lor_endpoints, device=self.dev)

        if affine_transformation_matrix is None:
            # store the documented default (the identity); the flag lets
            # get_lor_endpoints() skip the no-op matrix product
            self._affine_transformation_matrix = self.xp.eye(4, device=self.dev)
            self._has_affine_transformation = False
        else:
            self._affine_transformation_matrix = affine_transformation_matrix
            self._has_affine_transformation = True

    @property
    def xp(self) -> ModuleType:
        """array module to use for storing the LOR endpoints"""
        return self._xp

    @property
    def dev(self) -> str:
        """device to use for storing the LOR endpoints"""
        return self._dev

    @property
    def num_lor_endpoints(self) -> int:
        """total number of LOR endpoints in the module

        Returns
        -------
        int
        """
        return self._num_lor_endpoints

    @property
    def lor_endpoint_numbers(self) -> Array:
        """array enumerating all the LOR endpoints in the module

        Returns
        -------
        Array
        """
        return self._lor_endpoint_numbers

    @property
    def affine_transformation_matrix(self) -> Array:
        """4x4 affine transformation matrix (identity if none was given)

        Returns
        -------
        Array
        """
        return self._affine_transformation_matrix

    @abc.abstractmethod
    def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array:
        """mapping from LOR endpoint indices within module to an array of "raw" world coordinates

        Parameters
        ----------
        inds : Array | None, optional
            a non-negative integer array of indices, default None
            if None means all possible indices [0, ... , num_lor_endpoints - 1]

        Returns
        -------
        Array
            a len(inds) x 3 float array with the world coordinates of the LOR endpoints
        """
        if inds is None:
            inds = self.lor_endpoint_numbers
        raise NotImplementedError

    def get_lor_endpoints(self, inds: Array | None = None) -> Array:
        """mapping from LOR endpoint indices within module to an array of "transformed" world coordinates

        Parameters
        ----------
        inds : Array | None, optional
            a non-negative integer array of indices, default None
            if None means all possible indices [0, ... , num_lor_endpoints - 1]

        Returns
        -------
        Array
            a len(inds) x 3 float array with the world coordinates of the LOR endpoints
            including an affine transformation
        """
        lor_endpoints = self.get_raw_lor_endpoints(inds)

        if self._has_affine_transformation:
            # append a homogeneous coordinate of 1 so that the translation
            # column of the 4x4 matrix is applied as well
            tmp = self.xp.ones((lor_endpoints.shape[0], 4), device=self.dev)
            tmp[:, :-1] = lor_endpoints
            lor_endpoints = (tmp @ self.affine_transformation_matrix.T)[:, :3]

        return lor_endpoints

    def show_lor_endpoints(
        self,
        ax: plt.Axes,
        annotation_fontsize: float = 0,
        annotation_prefix: str = "",
        annotation_offset: int = 0,
        transformed: bool = True,
        **kwargs,
    ) -> None:
        """show the LOR coordinates in a 3D scatter plot

        Parameters
        ----------
        ax : plt.Axes
            3D matplotlib axes
        annotation_fontsize : float, optional
            fontsize of LOR endpoint number annotation, by default 0
            (0 or less disables the annotation)
        annotation_prefix : str, optional
            prefix for annotation, by default ''
        annotation_offset : int, optional
            number to add to crystal number, by default 0
        transformed : bool, optional
            use transformed instead of raw coordinates, by default True
        **kwargs
            passed to ``ax.scatter``
        """
        if transformed:
            all_lor_endpoints = self.get_lor_endpoints()
        else:
            all_lor_endpoints = self.get_raw_lor_endpoints()

        # matplotlib needs host (numpy) arrays
        all_lor_endpoints = to_numpy_array(all_lor_endpoints)

        ax.scatter(
            all_lor_endpoints[:, 0],
            all_lor_endpoints[:, 1],
            all_lor_endpoints[:, 2],
            **kwargs,
        )

        # make the box aspect follow the axis limits so distances are not distorted
        ax.set_box_aspect(
            [ub - lb for lb, ub in (getattr(ax, f"get_{a}lim")() for a in "xyz")]
        )

        ax.set_xlabel("x0")
        ax.set_ylabel("x1")
        ax.set_zlabel("x2")

        if annotation_fontsize > 0:
            for num in self.lor_endpoint_numbers:
                # use a Python int so the label text does not depend on the
                # array backend (e.g. a 0-d torch tensor prints as "tensor(5)")
                i = int(num)
                ax.text(
                    all_lor_endpoints[i, 0],
                    all_lor_endpoints[i, 1],
                    all_lor_endpoints[i, 2],
                    f"{annotation_prefix}{i + annotation_offset}",
                    fontsize=annotation_fontsize,
                )
class BlockPETScannerModule(PETScannerModule):
    """Block (rectangular cuboid) PET scanner module

    The LOR endpoints form a regular grid centered at the origin. If an affine
    transformation is given, it is applied once during construction and the
    result is stored in :attr:`lor_endpoints`. The base class is initialized
    without a transformation, so :meth:`get_lor_endpoints` and
    :meth:`get_raw_lor_endpoints` return the same coordinates.
    """

    def __init__(
        self,
        xp: ModuleType,
        dev: str,
        shape: tuple[int, int, int],
        spacing: tuple[float, float, float],
        affine_transformation_matrix: Array | None = None,
    ) -> None:
        """
        Parameters
        ----------
        xp : ModuleType
            array module to use for storing the LOR endpoints
        dev : str
            device to use for storing the LOR endpoints
        shape : tuple[int, int, int]
            shape of the regular grid of LOR endpoints forming the block module
        spacing : tuple[float, float, float]
            spacing between the LOR endpoints in each direction
        affine_transformation_matrix : Array | None, optional
            4x4 affine transformation matrix applied to the LOR endpoint coordinates,
            default None
            if None, the 4x4 identity matrix is used
        """
        self._shape = shape
        self._spacing = spacing

        # 1D coordinates along each direction, centered around 0
        axis_coords = [
            spacing[k] * (np.arange(shape[k], dtype=np.float32) - (shape[k] - 1) / 2)
            for k in range(3)
        ]

        # the grid is built with numpy because the indexing kwarg of meshgrid
        # is ignored by array_api_compat.torch (version 1.12.0)
        grid = np.meshgrid(*axis_coords, indexing="ij")
        endpoints = xp.asarray(
            np.stack([g.ravel() for g in grid], axis=-1), device=dev
        )

        if affine_transformation_matrix is not None:
            # homogeneous coordinates: (x0, x1, x2, 1)
            homogeneous = xp.ones((endpoints.shape[0], 4), device=dev)
            homogeneous[:, :-1] = endpoints
            endpoints = (homogeneous @ affine_transformation_matrix.T)[:, :3]

        self._lor_endpoints = endpoints

        super().__init__(xp, dev, shape[0] * shape[1] * shape[2], None)

    @property
    def shape(self) -> tuple[int, int, int]:
        """shape of the block module

        Returns
        -------
        tuple[int, int, int]
        """
        return self._shape

    @property
    def spacing(self) -> tuple[float, float, float]:
        """spacing of the block module

        Returns
        -------
        tuple[float, float, float]
        """
        return self._spacing

    @property
    def lor_endpoints(self) -> Array:
        """LOR endpoints of the block module (affine transformation already applied)

        Returns
        -------
        Array
        """
        return self._lor_endpoints

    def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array:
        """return the rows of :attr:`lor_endpoints` selected by ``inds`` (all if None)"""
        selected = self.lor_endpoint_numbers if inds is None else inds
        return self.xp.take(self.lor_endpoints, selected, axis=0)
class RegularPolygonPETScannerModule(PETScannerModule):
    """Regular polygon PET scanner module (detectors on a regular polygon)"""

    def __init__(
        self,
        xp: ModuleType,
        dev: str,
        radius: float,
        num_sides: int,
        num_lor_endpoints_per_side: int,
        lor_spacing: float,
        ax0: int = 2,
        ax1: int = 1,
        affine_transformation_matrix: Array | None = None,
        phis: None | Array = None,
    ) -> None:
        """
        Parameters
        ----------
        xp: ModuleType
            array module to use for storing the LOR endpoints
        dev: str
            device to use for storing the LOR endpoints
        radius : float
            inner radius of the regular polygon
        num_sides: int
            number of sides of the regular polygon
        num_lor_endpoints_per_side: int
            number of LOR endpoints per side
        lor_spacing : float
            spacing between the LOR endpoints in the polygon direction
        ax0 : int, optional
            axis number for the first direction, by default 2
        ax1 : int, optional
            axis number for the second direction, by default 1
        affine_transformation_matrix : Array | None, optional
            4x4 affine transformation matrix applied to the LOR endpoint coordinates,
            default None
            if None, the 4x4 identity matrix is used
        phis : None | Array, optional
            angle of each side, by default None
            None means that the sides are equally spaced around a circle
        """
        self._radius = radius
        self._num_sides = num_sides
        self._num_lor_endpoints_per_side = num_lor_endpoints_per_side
        self._ax0 = ax0
        self._ax1 = ax1
        self._lor_spacing = lor_spacing
        super().__init__(
            xp,
            dev,
            num_sides * num_lor_endpoints_per_side,
            affine_transformation_matrix,
        )

        # angle of each "side"
        if phis is None:
            self._phis = (
                2
                * self.xp.pi
                * self.xp.arange(self._num_sides, dtype=xp.float32, device=dev)
                / self.num_sides
            )
        else:
            self._phis = phis

    @property
    def radius(self) -> float:
        """inner radius of the regular polygon

        Returns
        -------
        float
        """
        return self._radius

    @property
    def num_sides(self) -> int:
        """number of sides of the regular polygon

        Returns
        -------
        int
        """
        return self._num_sides

    @property
    def num_lor_endpoints_per_side(self) -> int:
        """number of LOR endpoints per side

        Returns
        -------
        int
        """
        return self._num_lor_endpoints_per_side

    @property
    def ax0(self) -> int:
        """axis number for the first module direction

        Returns
        -------
        int
        """
        return self._ax0

    @property
    def ax1(self) -> int:
        """axis number for the second module direction

        Returns
        -------
        int
        """
        return self._ax1

    @property
    def lor_spacing(self) -> float:
        """spacing between the LOR endpoints in a module along the polygon

        Returns
        -------
        float
        """
        return self._lor_spacing

    @property
    def phis(self) -> Array:
        """azimuthal angle of each side

        Returns
        -------
        Array
        """
        return self._phis

    # abstract method from base class to be implemented
    def get_raw_lor_endpoints(self, inds: Array | None = None) -> Array:
        """mapping from LOR endpoint indices within module to "raw" world coordinates

        Parameters
        ----------
        inds : Array | None, optional
            a non-negative integer array of indices, default None
            if None means all possible indices [0, ... , num_lor_endpoints - 1]

        Returns
        -------
        Array
            a len(inds) x 3 float array; only the columns ax0 and ax1 are
            filled, the remaining column stays 0
        """
        if inds is None:
            inds = self.lor_endpoint_numbers

        # side (polygon face) of each endpoint and its position along that
        # side, centered around the middle of the side
        side = inds // self.num_lor_endpoints_per_side
        tmp = inds - side * self.num_lor_endpoints_per_side
        tmp = self.xp.astype(tmp, self.xp.float32) - (
            self.num_lor_endpoints_per_side / 2 - 0.5
        )

        phi = self.xp.take(self._phis, side)

        # one row per *requested* index - allocating num_lor_endpoints rows
        # breaks (shape mismatch) whenever a subset of indices is requested
        lor_endpoints = self.xp.zeros((size(inds), 3), device=self.dev)

        # point on the polygon: radius along (cos phi, sin phi) plus the
        # in-side offset along the perpendicular direction (-sin phi, cos phi)
        lor_endpoints[:, self.ax0] = (
            self.xp.cos(phi) * self.radius - self.xp.sin(phi) * self.lor_spacing * tmp
        )
        lor_endpoints[:, self.ax1] = (
            self.xp.sin(phi) * self.radius + self.xp.cos(phi) * self.lor_spacing * tmp
        )

        return lor_endpoints
class ModularizedPETScannerGeometry:
    """description of a PET scanner geometry consisting of LOR endpoint modules"""

    def __init__(self, modules: tuple[PETScannerModule, ...]):
        """
        Parameters
        ----------
        modules : tuple[PETScannerModule, ...]
            a tuple of scanner modules
            (the array module and device are taken from the first module)
        """
        self._modules = modules
        self._num_modules = len(self._modules)
        self._num_lor_endpoints_per_module = self.xp.asarray(
            [x.num_lor_endpoints for x in self._modules], device=self.dev
        )
        self._num_lor_endpoints = int(self.xp.sum(self._num_lor_endpoints_per_module))

        self.setup_all_lor_endpoints()

    def setup_all_lor_endpoints(self) -> None:
        """calculate the position of all lor endpoints by iterating over
        the modules and calculating the transformed coordinates of all
        module endpoints
        """
        # work on plain Python ints: computing each offset as sum() over an
        # array slice is O(num_modules**2) element accesses (each one a
        # potential device->host sync)
        counts = [int(n) for n in self._num_lor_endpoints_per_module]

        offsets = []
        running_total = 0
        for count in counts:
            offsets.append(running_total)
            running_total += count

        self._all_lor_endpoints_index_offset = self.xp.asarray(offsets, device=self.dev)

        self._all_lor_endpoints = self.xp.zeros(
            (self._num_lor_endpoints, 3), device=self.dev, dtype=self.xp.float32
        )

        for module, start, count in zip(self._modules, offsets, counts):
            self._all_lor_endpoints[start : start + count, :] = self.xp.astype(
                module.get_lor_endpoints(), self.xp.float32
            )

        # module number of every LOR endpoint, in linear index order
        self._all_lor_endpoints_module_number = self.xp.asarray(
            [i for i, count in enumerate(counts) for _ in range(count)],
            device=self.dev,
        )

    @property
    def modules(self) -> tuple[PETScannerModule, ...]:
        """tuple of modules defining the scanner"""
        return self._modules

    @property
    def num_modules(self) -> int:
        """the number of modules defining the scanner"""
        return self._num_modules

    @property
    def num_lor_endpoints_per_module(self) -> Array:
        """array showing how many LOR endpoints are in every module"""
        return self._num_lor_endpoints_per_module

    @property
    def num_lor_endpoints(self) -> int:
        """the total number of LOR endpoints in the scanner"""
        return self._num_lor_endpoints

    @property
    def all_lor_endpoints_index_offset(self) -> Array:
        """the offset in the linear (flattened) index for all LOR endpoints"""
        return self._all_lor_endpoints_index_offset

    @property
    def all_lor_endpoints_module_number(self) -> Array:
        """the module number of all LOR endpoints"""
        return self._all_lor_endpoints_module_number

    @property
    def all_lor_endpoints(self) -> Array:
        """the world coordinates of all LOR endpoints"""
        return self._all_lor_endpoints

    @property
    def xp(self) -> ModuleType:
        """array module to use for storing the LOR endpoints"""
        return self._modules[0].xp

    @property
    def dev(self) -> str:
        """device to use for storing the LOR endpoints"""
        return self._modules[0].dev

    def linear_lor_endpoint_index(
        self,
        module: Array,
        index_in_module: Array,
    ) -> Array:
        """transform the module + index_in_modules indices into a flattened / linear LOR endpoint index

        Parameters
        ----------
        module : Array
            containing module numbers
        index_in_module : Array
            containing index in modules

        Returns
        -------
        Array
            the flattened LOR endpoint index
        """
        return (
            self.xp.take(self.all_lor_endpoints_index_offset, module, axis=0)
            + index_in_module
        )

    def get_lor_endpoints(self, module: Array, index_in_module: Array) -> Array:
        """get the coordinates for LOR endpoints defined by module and index in module

        Parameters
        ----------
        module : Array
            the module number of the LOR endpoints
        index_in_module : Array
            the index in module number of the LOR endpoints

        Returns
        -------
        Array
            a len(module) x 3 array with the world coordinates of the LOR endpoints
        """
        return self.xp.take(
            self.all_lor_endpoints,
            self.linear_lor_endpoint_index(module, index_in_module),
            axis=0,
        )

    def show_lor_endpoints(
        self, ax: plt.Axes, show_linear_index: bool = True, **kwargs
    ) -> None:
        """show all LOR endpoints in a 3D plot

        Parameters
        ----------
        ax : plt.Axes
            a 3D matplotlib axes
        show_linear_index : bool, optional
            annotate the LOR endpoints with the linear LOR endpoint index;
            if False, annotate with "module,index_in_module"
        **kwargs :
            keyword arguments passed to show_lor_endpoints() of the scanner module
        """
        for i, module in enumerate(self.modules):
            if show_linear_index:
                # plain int so the annotation arithmetic is backend independent
                offset = int(self.all_lor_endpoints_index_offset[i])
                prefix = ""
            else:
                offset = 0
                prefix = f"{i},"

            module.show_lor_endpoints(
                ax, annotation_offset=offset, annotation_prefix=prefix, **kwargs
            )
class RegularPolygonPETScannerGeometry(ModularizedPETScannerGeometry):
    """description of a PET scanner geometry consisting of stacked regular polygons

    Examples
    --------
    .. minigallery:: parallelproj.RegularPolygonPETScannerGeometry
    """

    def __init__(
        self,
        xp: ModuleType,
        dev: str,
        radius: float,
        num_sides: int,
        num_lor_endpoints_per_side: int,
        lor_spacing: float,
        ring_positions: Array,
        symmetry_axis: int,
        phis: None | Array = None,
    ) -> None:
        """
        Parameters
        ----------
        xp: ModuleType
            array module to use for storing the LOR endpoints
        dev: str
            device to use for storing the LOR endpoints
        radius : float
            radius of the scanner
        num_sides : int
            number of sides (faces) of each regular polygon
        num_lor_endpoints_per_side : int
            number of LOR endpoints in each side (face) of each polygon
        lor_spacing : float
            spacing between the LOR endpoints in each side
        ring_positions : Array
            1D array with the coordinate of the rings along the ring axis
        symmetry_axis : int
            the ring axis (0,1,2)
        phis : None | Array, optional
            angle of each side, by default None
            None means that the sides are equally spaced around a circle

        Raises
        ------
        ValueError
            if symmetry_axis is not 0, 1 or 2
        """
        # in-plane axes (ax0, ax1) of the polygons for each ring axis
        in_plane_axes = {0: (2, 1), 1: (0, 2), 2: (1, 0)}
        if symmetry_axis not in in_plane_axes:
            raise ValueError(
                f"symmetry_axis must be 0, 1 or 2, got {symmetry_axis!r}"
            )

        self._radius = radius
        self._num_sides = num_sides
        self._num_lor_endpoints_per_side = num_lor_endpoints_per_side
        self._lor_spacing = lor_spacing
        self._symmetry_axis = symmetry_axis
        self._ring_positions = ring_positions
        self._ax0, self._ax1 = in_plane_axes[symmetry_axis]

        modules = []

        for ring in range(self.num_rings):
            # translate each polygon to its ring position along the ring axis
            aff_mat = xp.eye(4, device=dev)
            aff_mat[symmetry_axis, -1] = ring_positions[ring]

            modules.append(
                RegularPolygonPETScannerModule(
                    xp,
                    dev,
                    radius,
                    num_sides,
                    num_lor_endpoints_per_side=num_lor_endpoints_per_side,
                    lor_spacing=lor_spacing,
                    affine_transformation_matrix=aff_mat,
                    ax0=self._ax0,
                    ax1=self._ax1,
                    phis=phis,
                )
            )

        super().__init__(tuple(modules))

        # all rings have the same number of endpoints, so the index within
        # the ring is linear index - ring number * endpoints per ring
        self._all_lor_endpoints_index_in_ring = (
            self.xp.arange(self.num_lor_endpoints, device=dev)
            - self.all_lor_endpoints_ring_number * self.num_lor_endpoints_per_module[0]
        )

    @property
    def radius(self) -> float:
        """radius of the scanner"""
        return self._radius

    @property
    def num_sides(self) -> int:
        """number of sides (faces) of each polygon"""
        return self._num_sides

    @property
    def num_lor_endpoints_per_side(self) -> int:
        """number of LOR endpoints per side (face) in each polygon"""
        return self._num_lor_endpoints_per_side

    @property
    def num_rings(self) -> int:
        """number of rings (regular polygons)"""
        return self._ring_positions.shape[0]

    @property
    def lor_spacing(self) -> float:
        """the spacing between the LOR endpoints in every side (face) of each polygon"""
        return self._lor_spacing

    @property
    def symmetry_axis(self) -> int:
        """The symmetry axis. Also called axial (or ring) direction."""
        return self._symmetry_axis

    @property
    def all_lor_endpoints_ring_number(self) -> Array:
        """the ring (regular polygon) number of all LOR endpoints"""
        return self._all_lor_endpoints_module_number

    @property
    def all_lor_endpoints_index_in_ring(self) -> Array:
        """the index within the ring (regular polygon) of all LOR endpoints"""
        return self._all_lor_endpoints_index_in_ring

    @property
    def num_lor_endpoints_per_ring(self) -> int:
        """the number of LOR endpoints per ring (regular polygon)"""
        return int(self._num_lor_endpoints_per_module[0])

    @property
    def ring_positions(self) -> Array:
        """the ring (regular polygon) positions"""
        return self._ring_positions
class DemoPETScannerGeometry(RegularPolygonPETScannerGeometry):
    """Demo PET scanner geometry consisting of a 34-ogon with 16 LOR endpoints per side and 36 rings"""

    def __init__(
        self,
        xp: ModuleType,
        dev: str,
        radius: float = 0.5 * (744.1 + 2 * 8.51),
        num_sides: int = 34,
        num_lor_endpoints_per_side: int = 16,
        lor_spacing: float = 4.03125,
        num_rings: int = 36,
        symmetry_axis: int = 2,
    ) -> None:
        """
        Parameters
        ----------
        xp : ModuleType
            array module
        dev : str
            the device to use
        radius : float, optional
            radius of the regular polygon, by default 0.5*(744.1 + 2 * 8.51)
        num_sides : int, optional
            number of sides of the polygon, by default 34
        num_lor_endpoints_per_side : int, optional
            number of LOR endpoints per side, by default 16
        lor_spacing : float, optional
            spacing between the LOR endpoints, by default 4.03125
        num_rings : int, optional
            number of rings, by default 36
        symmetry_axis : int, optional
            symmetry (axial) axis of the scanner, by default 2
        """
        # rings are spaced by 5.32; each consecutive group of 9 rings is
        # shifted by an additional 2.8 per group
        ring_index = xp.arange(num_rings, device=dev, dtype=xp.float32)
        group_index = xp.astype(xp.arange(num_rings, device=dev) // 9, xp.float32)
        positions = 5.32 * ring_index + group_index * 2.8

        # center the rings: the first ring is at 0, so subtracting half of the
        # maximum makes the positions symmetric around 0
        positions = positions - 0.5 * xp.max(positions)

        super().__init__(
            xp,
            dev,
            radius=radius,
            num_sides=num_sides,
            num_lor_endpoints_per_side=num_lor_endpoints_per_side,
            lor_spacing=lor_spacing,
            ring_positions=positions,
            symmetry_axis=symmetry_axis,
        )