Source code for parallelproj.operators

"""basic linear operators"""

from __future__ import annotations

from types import ModuleType
import abc
import numpy as np
import array_api_compat
from array_api_compat import device, get_namespace
from parallelproj import Array
from collections.abc import Sequence

import parallelproj


class LinearOperator(abc.ABC):
    """abstract base class for linear operators

    Subclasses provide ``in_shape``, ``out_shape``, ``_apply`` and
    ``_adjoint``.  The public :meth:`apply` and :meth:`adjoint` methods wrap
    those and additionally multiply by the scalar :attr:`scale`
    (respectively by its complex conjugate).
    """

    def __init__(self) -> None:
        self._scale = 1

    @property
    @abc.abstractmethod
    def in_shape(self) -> tuple[int, ...]:
        """shape of the input array"""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def out_shape(self) -> tuple[int, ...]:
        """shape of the output array"""
        raise NotImplementedError

    @property
    def scale(self) -> int | float | complex:
        """scalar factor applied to the linear operator"""
        return self._scale

    @scale.setter
    def scale(self, value: int | float | complex):
        if not np.isscalar(value):
            raise ValueError(
                f"scale must be a scalar, got {type(value).__name__}"
            )
        self._scale = value

    @abc.abstractmethod
    def _apply(self, x: Array) -> Array:
        """forward step :math:`y = Ax`"""
        raise NotImplementedError

    @abc.abstractmethod
    def _adjoint(self, y: Array) -> Array:
        """adjoint step :math:`x = A^H y`"""
        raise NotImplementedError

    def apply(self, x: Array) -> Array:
        """(scaled) forward step :math:`y = \\alpha A x`

        Parameters
        ----------
        x : Array

        Returns
        -------
        Array
        """
        # skip the multiplication (and the extra array) for the default scale
        if self._scale == 1:
            return self._apply(x)
        else:
            return self._scale * self._apply(x)

    def __call__(self, x: Array) -> Array:
        """alias to apply(x)"""
        return self.apply(x)

    def adjoint(self, y: Array) -> Array:
        """(scaled) adjoint step :math:`x = \\overline{\\alpha} A^H y`

        Parameters
        ----------
        y : Array

        Returns
        -------
        Array
        """
        if self._scale == 1:
            return self._adjoint(y)
        else:
            return self._scale.conjugate() * self._adjoint(y)

    def adjointness_test(
        self,
        xp: ModuleType,
        dev: str,
        verbose: bool = False,
        iscomplex: bool = False,
        **kwargs,
    ) -> bool:
        """test whether the adjoint is correctly implemented

        Random arrays ``x`` and ``y`` are drawn and the inner products
        :math:`\\langle Ax, y \\rangle` and :math:`\\langle x, A^H y \\rangle`
        are compared.

        Parameters
        ----------
        xp : ModuleType
            array module to use
        dev : str
            device (cpu or cuda)
        verbose : bool, optional
            verbose output
        iscomplex : bool, optional
            use complex arrays
        **kwargs : dict
            passed to np.isclose

        Returns
        -------
        bool
            whether the adjoint is correctly implemented
        """
        dtype = xp.complex128 if iscomplex else xp.float64

        def random_array(shape: tuple[int, ...]) -> Array:
            return xp.asarray(np.random.rand(*shape), device=dev, dtype=dtype)

        x = random_array(self.in_shape)
        y = random_array(self.out_shape)

        if iscomplex:
            x = x + 1j * random_array(self.in_shape)
            y = y + 1j * random_array(self.out_shape)

        x_fwd = self.apply(x)
        y_adj = self.adjoint(y)

        if iscomplex:
            ip1 = complex(xp.sum(xp.conj(x_fwd) * y))
            ip2 = complex(xp.sum(xp.conj(x) * y_adj))
        else:
            ip1 = float(xp.sum(x_fwd * y))
            ip2 = float(xp.sum(x * y_adj))

        if verbose:
            print(ip1, ip2)

        # np.isclose returns a numpy bool; convert to the annotated type
        return bool(np.isclose(ip1, ip2, **kwargs))

    def norm(
        self,
        xp: ModuleType,
        dev: str,
        num_iter: int = 30,
        iscomplex: bool = False,
        verbose: bool = False,
    ) -> float:
        """estimate norm of the linear operator using power iterations

        Parameters
        ----------
        xp : ModuleType
            array module to use
        dev : str
            device (cpu or cuda)
        num_iter : int, optional
            number of power iterations (at least 1)
        iscomplex : bool, optional
            use complex arrays
        verbose : bool, optional
            verbose output

        Returns
        -------
        float
            the norm of the linear operator; 0.0 if the iterate is mapped
            to exactly zero (e.g. for a zero operator)

        Raises
        ------
        ValueError
            if ``num_iter`` is smaller than 1
        """
        if num_iter < 1:
            raise ValueError("num_iter must be at least 1")

        dtype = xp.complex128 if iscomplex else xp.float64

        x = xp.asarray(np.random.rand(*self.in_shape), device=dev, dtype=dtype)

        if iscomplex:
            x = x + 1j * xp.asarray(
                np.random.rand(*self.in_shape), device=dev, dtype=dtype
            )

        for i in range(num_iter):
            x = self.adjoint(self.apply(x))
            # norm of A^H A x, which is the quantity whose square root is
            # reported as the operator norm
            norm_squared = xp.sqrt(xp.sum(xp.abs(x) ** 2))

            if float(norm_squared) == 0:
                # the iterate was mapped to zero: normalizing would divide
                # by zero and turn every later iterate into NaN
                return 0.0

            x /= float(norm_squared)

            if verbose:
                print(f"{(i+1):03} {float(xp.sqrt(norm_squared)):.2E}")

        return float(xp.sqrt(norm_squared))
class MatrixOperator(LinearOperator):
    """Linear Operator defined by dense matrix multiplication

    Examples
    --------
    .. minigallery:: parallelproj.MatrixOperator
    """

    def __init__(self, A: Array) -> None:
        """init method

        Parameters
        ----------
        A : Array
            2D real or complex array representing the matrix
        """
        super().__init__()
        self._A = A

    @property
    def in_shape(self) -> tuple[int]:
        # number of matrix columns
        n_cols = self._A.shape[1]
        return (n_cols,)

    @property
    def out_shape(self) -> tuple[int]:
        # number of matrix rows
        n_rows = self._A.shape[0]
        return (n_rows,)

    @property
    def A(self) -> Array:
        """matrix of the operator"""
        return self._A

    @property
    def xp(self) -> ModuleType:
        """array module of the operator"""
        return array_api_compat.get_namespace(self._A)

    def iscomplex(self) -> bool:
        """bool whether the operator is complex"""
        xp = self.xp
        # a tuple of kinds is True if the dtype matches any of its entries
        return xp.isdtype(self._A.dtype, (xp.complex64, xp.complex128))

    def _apply(self, x: Array) -> Array:
        xp = self.xp
        return xp.matmul(self._A, x)

    def _adjoint(self, y: Array) -> Array:
        xp = self.xp
        # conjugate only for complex matrices; real ones are just transposed
        matrix = xp.conj(self._A) if self.iscomplex() else self._A
        return xp.matmul(matrix.T, y)
class CompositeLinearOperator(LinearOperator):
    """Composite Linear Operator defined by a sequence of Linear Operators

    Given a Sequence of operators

    .. math::
        A^0, A^1, \\ldots, A^{n-1}

    the composite linear operator is defined as

    .. math::
        A(x) = A^0( A^1( ... ( A^{n-1}(x) ) ) )

    Examples
    --------
    .. minigallery:: parallelproj.CompositeLinearOperator
    """

    def __init__(self, operators: Sequence[LinearOperator]):
        """init method

        Parameters
        ----------
        operators : Sequence[LinearOperator, ...]
            Sequence of linear operators
        """
        super().__init__()
        self._operators = operators

    @property
    def in_shape(self) -> tuple[int, ...]:
        # the last operator in the sequence is the first one applied
        innermost = self._operators[-1]
        return innermost.in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        # the first operator in the sequence is the last one applied
        outermost = self._operators[0]
        return outermost.out_shape

    @property
    def operators(self) -> Sequence[LinearOperator]:
        """tuple of linear operators"""
        return self._operators

    def _apply(self, x: Array) -> Array:
        # forward pass: apply A^{n-1} first and A^0 last
        result = x
        for operator in reversed(self._operators):
            result = operator(result)
        return result

    def _adjoint(self, y: Array) -> Array:
        # adjoint pass: the adjoints are applied in sequence order
        result = y
        for operator in self._operators:
            result = operator.adjoint(result)
        return result

    def __getitem__(self, i: int) -> LinearOperator:
        """get the i-th operator :math:`A_i`"""
        return self._operators[i]
class ElementwiseMultiplicationOperator(LinearOperator):
    """Element-wise multiplication operator (multiplication with a diagonal matrix)

    Examples
    --------
    .. minigallery:: parallelproj.ElementwiseMultiplicationOperator
    """

    def __init__(self, values: Array):
        """init method

        Parameters
        ----------
        values : Array
            values of the diagonal matrix
        """
        super().__init__()
        self._values = values

    @property
    def in_shape(self) -> tuple[int, ...]:
        return self._values.shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        return self._values.shape

    @property
    def xp(self) -> ModuleType:
        """array module of the operator"""
        return array_api_compat.get_namespace(self._values)

    @property
    def values(self) -> Array:
        """values that get multiplied"""
        return self._values

    def _apply(self, x: Array) -> Array:
        return self._values * x

    def _adjoint(self, y: Array) -> Array:
        # the adjoint of a diagonal operator multiplies with the conjugate
        if self.iscomplex():
            factor = self.xp.conj(self._values)
        else:
            factor = self._values
        return factor * y

    def iscomplex(self) -> bool:
        """bool whether the operator is complex"""
        xp = self.xp
        # a tuple of kinds is True if the dtype matches any of its entries
        return xp.isdtype(self._values.dtype, (xp.complex64, xp.complex128))
class TOFNonTOFElementwiseMultiplicationOperator(LinearOperator):
    """Element-wise multiplication operator between a non-TOF and TOF sinogram

    Every slice ``x[..., i]`` along the last axis of the input is multiplied
    with the same ``values`` array.

    Examples
    --------
    .. minigallery:: parallelproj.TOFNonTOFElementwiseMultiplicationOperator
    """

    def __init__(self, in_shape: tuple[int, ...], values: Array):
        """init method

        Parameters
        ----------
        in_shape: tuple[int, ...]
            shape of the TOF sinogram
        values : Array
            a non-TOF sinogram
        """
        super().__init__()
        self._in_shape = in_shape
        self._values = values

    @property
    def in_shape(self) -> tuple[int, ...]:
        return self._in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        return self._in_shape

    @property
    def xp(self) -> ModuleType:
        """array module of the operator"""
        return array_api_compat.get_namespace(self._values)

    @property
    def values(self) -> Array:
        """values that get multiplied"""
        return self._values

    def _multiply_along_last_axis(self, arr: Array, factor: Array) -> Array:
        """return a copy of arr with every slice arr[..., i] scaled by factor"""
        # suboptimal to copy array, but bug in torch asarray with copy = True
        out = 1 * arr
        # a trailing singleton axis lets factor broadcast over the last axis
        # of arr in a single call; multiplying in place keeps the dtype of
        # the copy, exactly like a slice-by-slice in-place multiplication
        out *= self.xp.expand_dims(factor, axis=-1)
        return out

    def _apply(self, x: Array) -> Array:
        return self._multiply_along_last_axis(x, self._values)

    def _adjoint(self, y: Array) -> Array:
        if self.iscomplex():
            factor = self.xp.conj(self._values)
        else:
            factor = self._values
        return self._multiply_along_last_axis(y, factor)

    def iscomplex(self) -> bool:
        """bool whether the operator is complex"""
        return self.xp.isdtype(
            self._values.dtype, self.xp.complex64
        ) or self.xp.isdtype(self._values.dtype, self.xp.complex128)
class GaussianFilterOperator(LinearOperator):
    """Gaussian filter operator

    Examples
    --------
    .. minigallery:: parallelproj.GaussianFilterOperator
    """

    def __init__(self, in_shape: tuple[int, ...], sigma: float | Array, **kwargs):
        """init method

        Parameters
        ----------
        in_shape : tuple[int, ...]
            shape of the input array
        sigma: float | array
            standard deviation of the gaussian filter
        **kwargs : sometype
            passed to the ndimage gaussian_filter function
        """
        super().__init__()
        self._in_shape = in_shape
        self._sigma = sigma
        self._kwargs = kwargs

    @property
    def in_shape(self) -> tuple[int, ...]:
        return self._in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        return self._in_shape

    def _apply(self, x: Array) -> Array:
        xp = array_api_compat.get_namespace(x)

        # choose the array backend and the matching ndimage implementation;
        # the imports stay local so that cupy is only needed for CUDA arrays
        if parallelproj.is_cuda_array(x):
            import array_api_compat.cupy as backend
            import cupyx.scipy.ndimage as ndimage_module
        else:
            import scipy.ndimage as ndimage_module

            backend = np

        # an array-valued sigma is converted to the backend doing the filtering
        sigma = self._sigma
        if array_api_compat.is_array_api_obj(sigma):
            sigma = backend.asarray(sigma)

        filtered = ndimage_module.gaussian_filter(
            backend.asarray(x), sigma=sigma, **self._kwargs
        )

        # hand the result back to the caller's array module and device
        return array_api_compat.to_device(xp.from_dlpack(filtered), device(x))

    def _adjoint(self, y: Array) -> Array:
        # the forward filter is re-used as its own adjoint
        return self._apply(y)
class VstackOperator(LinearOperator):
    """Stacking operator for stacking multiple linear operators vertically

    All operators are applied to the same input; their outputs are raveled
    and concatenated into a single 1D output array.
    """

    def __init__(self, operators: tuple[LinearOperator, ...]) -> None:
        """init method

        Parameters
        ----------
        operators : tuple[LinearOperator, ...]
            tuple of linear operators
        """
        super().__init__()
        self._operators = operators
        # the input shape is taken from the first operator
        self._in_shape = self._operators[0].in_shape
        self._out_shapes = tuple(op.out_shape for op in operators)
        # np.prod returns numpy scalars (and the float 1.0 for an empty
        # shape), so convert to plain ints that are valid slice bounds
        self._raveled_out_shapes = tuple(
            int(np.prod(shape)) for shape in self._out_shapes
        )
        self._out_shape = (sum(self._raveled_out_shapes),)

        # setup the slices for slicing the raveled output array
        slices = []
        start = 0
        for length in self._raveled_out_shapes:
            end = start + length
            slices.append(slice(start, end, None))
            start = end
        self._slices = tuple(slices)

    @property
    def in_shape(self) -> tuple[int, ...]:
        return self._in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        return self._out_shape

    def _apply(self, x: Array) -> Array:
        xp = array_api_compat.get_namespace(x)
        y = xp.zeros(self._out_shape, dtype=x.dtype, device=device(x))

        # every operator writes its raveled output into its own segment
        for op, segment in zip(self._operators, self._slices):
            y[segment] = xp.reshape(op(x), (-1,))

        return y

    def _adjoint(self, y: Array) -> Array:
        xp = array_api_compat.get_namespace(y)
        x = xp.zeros(self._in_shape, dtype=y.dtype, device=device(y))

        # sum the adjoints of all operators applied to their output segment
        for op, segment, out_shape in zip(
            self._operators, self._slices, self._out_shapes
        ):
            x += op.adjoint(xp.reshape(y[segment], out_shape))

        return x
class LinearOperatorSequence(Sequence[LinearOperator]):
    """Sequence of linear operators

    .. math::
        A^0, A^1 \\ldots, A^{n-1}

    that can be evaluated independently.

    Examples
    --------
    .. minigallery:: parallelproj.LinearOperatorSequence
    """

    def __init__(self, operators: Sequence[LinearOperator]) -> None:
        """init method

        Parameters
        ----------
        operators : Sequence[LinearOperator, ...]
            Sequence of linear operators
        """
        self._operators = operators
        # the input shape is taken from the first operator
        first_operator = self._operators[0]
        self._in_shape = first_operator.in_shape
        self._out_shapes = [op.out_shape for op in operators]
        self._len = len(operators)

    @property
    def in_shape(self) -> tuple[int, ...]:
        """shape of the input array"""
        return self._in_shape

    @property
    def out_shapes(self) -> list[tuple[int, ...]]:
        """shapes of the output array of all subset operators"""
        return self._out_shapes

    @property
    def operators(self) -> Sequence[LinearOperator]:
        """all subset operators"""
        return self._operators

    def __len__(self) -> int:
        """length of operator sequence"""
        return self._len

    def __getitem__(self, i: int) -> LinearOperator:
        """get the i-th linear operator :math:`A^i`"""
        return self._operators[i]

    def apply(self, x: Array) -> list[Array]:
        """:math:`(A^0(x), A^1(x), \\ldots, A^{n-1}(x))`"""
        return [op(x) for op in self]

    def __call__(self, x: Array) -> list[Array]:
        """alias to apply(x)"""
        return self.apply(x)

    def adjoint(self, y: list[Array]) -> Array:
        """:math:`\\sum_i (A^i)^H y^i` for all :math:`i`"""
        back_projections = [op.adjoint(y[i]) for i, op in enumerate(self)]
        return sum(back_projections)

    def norms(self, xp: ModuleType, dev: str) -> list[float]:
        """:math:`\\text{norm}(A^i)` for all :math:`i`"""
        return [op.norm(xp, dev) for op in self]
class FiniteForwardDifference(LinearOperator):
    """finite difference gradient operator

    The forward step stacks, along a new leading axis, the forward
    differences ``x[i+1] - x[i]`` taken along every axis of the input.  The
    last entry along the differentiated axis is left at zero.
    """

    def __init__(self, in_shape: tuple[int, ...]) -> None:
        """init method

        Parameters
        ----------
        in_shape : tuple[int, ...]
            shape of the input array (at most 4 dimensions)

        Raises
        ------
        ValueError
            if ``in_shape`` has more than 4 dimensions
        """
        if len(in_shape) > 4:
            raise ValueError("only up to 4 dimensions supported")

        self._ndim = len(in_shape)
        self._in_shape = in_shape
        self._out_shape = (self.ndim,) + in_shape

        super().__init__()

    @property
    def in_shape(self) -> tuple[int, ...]:
        """shape of the input array"""
        return self._in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        """shape of the output array"""
        return self._out_shape

    @property
    def ndim(self) -> int:
        """number of dimensions of the input array"""
        return self._ndim

    def _index_along(self, axis: int, entry: slice | int) -> tuple:
        """index tuple selecting entry along axis and everything elsewhere"""
        index = [slice(None)] * self.ndim
        index[axis] = entry
        return tuple(index)

    def _apply(self, x: Array) -> Array:
        xp = array_api_compat.get_namespace(x)

        g = xp.zeros(self.out_shape, dtype=x.dtype, device=device(x))

        for axis in range(self.ndim):
            upper = self._index_along(axis, slice(1, None))
            lower = self._index_along(axis, slice(None, -1))
            # forward difference along axis; the last entry stays zero
            g[(axis,) + lower] = x[upper] - x[lower]

        return g

    def _adjoint(self, y: Array) -> Array:
        xp = array_api_compat.get_namespace(y)

        res = xp.zeros(self.in_shape, dtype=y.dtype, device=device(y))

        for axis in range(self.ndim):
            upper = self._index_along(axis, slice(1, None))
            lower = self._index_along(axis, slice(None, -1))
            first = self._index_along(axis, 0)
            last = self._index_along(axis, -1)

            # work on a copy so that the caller's array is not modified
            tmp = xp.asarray(y[axis, ...], copy=True)
            # the forward step never writes the last entry along axis, so
            # it must not contribute to the adjoint
            tmp[last] = 0

            # upper and first select disjoint entries, so together these two
            # updates add the contribution of this axis to every element
            res[upper] += -tmp[upper] + tmp[lower]
            res[first] += -tmp[first]

        return res
class GradientFieldProjectionOperator(LinearOperator):
    """Gradient Field Projection Operator

    See Ehrhardt and Betcke: "Multicontrast MRI Reconstruction with
    Structure-Guided Total Variation" https://doi.org/10.1137/15M1047325

    .. math::
        P_{\\xi_n}x = x - \\langle \\xi_n, x \\rangle \\xi_n

    .. math::
        \\xi_n = g_n / \\| g_n \\|_{\\eta}

    for the joint gradient field :math:`g_n`
    """

    def __init__(self, gradient_field: Array, eta: float = 0.0):
        """
        Parameters
        ----------
        gradient_field : Array
            a real gradient field.
            In 3D, the shape would be [3,n0,n1,n2].
            In 2D, the shape would be [2,n0,n1].
            This can be e.g. the output of the FiniteForwardDifference
            operator applied to a structural prior image.
        eta : float, optional
            smoothing parameter used in the pointwise gradient norm
            default 0.0

        Raises
        ------
        ValueError
            if the gradient field has a complex dtype
        """
        xp = get_namespace(gradient_field)
        dev = device(gradient_field)
        self._xp = xp
        self._dev = dev

        field_dtype = gradient_field.dtype
        if field_dtype == xp.complex64 or field_dtype == xp.complex128:
            raise ValueError("complex gradient fields not supported")

        self._eta = eta
        self._in_shape = gradient_field.shape
        self._out_shape = gradient_field.shape

        # the normalization is always carried out in float64
        field64 = xp.astype(gradient_field, xp.float64)

        # NOTE(review): eta**2 is added to every component before the sum
        # over axis 0, so it enters the norm once per gradient component --
        # confirm that this is the intended definition of the eta-norm
        pointwise_norm = xp.sqrt(xp.sum(field64**2 + self._eta**2, axis=0))
        nonzero = pointwise_norm > 0

        normalized = xp.zeros(
            gradient_field.shape,
            dtype=xp.float64,
            device=dev,
        )
        self._normalized_gradient_field = normalized

        # points with vanishing norm keep a zero direction, which avoids a
        # division by zero
        for component in range(self.out_shape[0]):
            normalized[component, ...][nonzero] = (
                field64[component, ...][nonzero] / pointwise_norm[nonzero]
            )

        super().__init__()

    @property
    def in_shape(self) -> tuple[int, ...]:
        return self._in_shape

    @property
    def out_shape(self) -> tuple[int, ...]:
        return self._out_shape

    @property
    def xp(self):
        """array module of the operator"""
        return self._xp

    @property
    def dev(self):
        """device of the operator"""
        return self._dev

    @property
    def eta(self) -> float:
        """smoothing parameter"""
        return self._eta

    @property
    def normalized_gradient_field(self):
        """normalized gradient field"""
        return self._normalized_gradient_field

    def _apply(self, x: Array) -> Array:
        xi = self._normalized_gradient_field
        # pointwise inner product of x with the normalized field
        inner_product = self._xp.sum(x * xi, axis=0)
        return x - inner_product * xi

    def _adjoint(self, y: Array) -> Array:
        # the forward step is re-used as its own adjoint
        return self._apply(y)