# mypy: disable-error-code="misc", disable-error-code="attr-defined"
from math import pi, ceil
from typing import Callable, Optional, Tuple, Dict, Union, Any, Literal
from importlib import import_module
from functools import partial, lru_cache
from scipy.special import roots_legendre
from .constants import rad_to_deg, deg_to_rad
from .backend_obj import backend, ArrayLike
def _import_func_or_class(module_path: str) -> Any:
"""
Import a function or class from a module path
Parameters
----------
module_path : str
The module path to import from
Returns
-------
Callable
The imported function or class
"""
module_name, name = module_path.rsplit(".", 1)
mod = import_module(module_name)
return getattr(mod, name) # type: ignore
def _eval_expression(input_string: str) -> Union[int, float]:
"""
Evaluates a string expression to create an integer or float
Parameters
----------
input_string : str
The string expression to evaluate
Returns
-------
Union[int, float]
The result of the evaluation
Raises
------
NameError
If a disallowed constant is used
"""
# Allowed modules to use string evaluation
allowed_names = {"pi": pi}
# Compile the input string
code = compile(input_string, "<string>", "eval")
# Check for disallowed names
for name in code.co_names:
if name not in allowed_names:
# Throw an error if a disallowed name is used
raise NameError(f"Use of {name} not allowed")
# Evaluate the input string without using builtins
# for security
return eval(code, {"__builtins__": {}}, allowed_names)
[docs]
def flip_axis_ratio(q, phi):
"""
Makes the value of 'q' positive, then swaps x and y axes if 'q' is larger than 1.
Parameters
----------
q: ArrayLike
ArrayLike containing values to be processed.
phi: ArrayLike
ArrayLike containing the phi values for the orientation of the axes.
Returns
-------
Tuple[ArrayLike, ArrayLike]
Tuple containing the processed 'q' and 'phi' ArrayLikes.
"""
q = backend.abs(q)
return backend.where(q > 1, 1 / q, q), backend.where(q > 1, phi + pi / 2, phi)
[docs]
def translate_rotate(x, y, x0, y0, phi: Optional[ArrayLike] = None):
"""
Translates and rotates the points (x, y) by subtracting (x0, y0)
and applying rotation angle phi.
Parameters
----------
x: ArrayLike
ArrayLike containing the x-coordinates.
y: ArrayLike
ArrayLike containing the y-coordinates.
x0: ArrayLike
ArrayLike containing the x-coordinate translation values.
y0: ArrayLike
ArrayLike containing the y-coordinate translation values.
phi: Optional[ArrayLike], optional)
ArrayLike containing the rotation angles. If None, no rotation is applied. Defaults to None.
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the translated and rotated x and y coordinates.
"""
xt = x - x0
yt = y - y0
if phi is not None:
# Apply R(-phi)
c_phi = backend.cos(phi)
s_phi = backend.sin(phi)
# Simultaneous assignment
return xt * c_phi + yt * s_phi, yt * c_phi - xt * s_phi # fmt: skip
return xt, yt
[docs]
def derotate(vx, vy, phi: Optional[ArrayLike] = None):
"""
Applies inverse rotation to the velocity components (vx, vy) using the rotation angle phi.
Parameters
----------
vx: ArrayLike
ArrayLike containing the x-component of velocity.
vy: ArrayLike
ArrayLike containing the y-component of velocity.
phi: Optional[ArrayLike], optional)
ArrayLike containing the rotation angles. If None, no rotation is applied. Defaults to None.
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the derotated x and y components of velocity.
"""
if phi is None:
return vx, vy
c_phi = backend.cos(phi)
s_phi = backend.sin(phi)
return vx * c_phi - vy * s_phi, vx * s_phi + vy * c_phi # fmt: skip
[docs]
def to_elliptical(x, y, q: ArrayLike):
"""
Converts Cartesian coordinates to elliptical coordinates.
Parameters
----------
x: ArrayLike
ArrayLike containing the x-coordinates.
y: ArrayLike
ArrayLike containing the y-coordinates.
q: ArrayLike
ArrayLike containing the elliptical parameters.
Returns
-------
Tuple: ArrayLike, ArrayLike
Tuple containing the x and y coordinates in elliptical form.
"""
return x, y / q
[docs]
def meshgrid(
pixelscale, nx, ny=None, device=None, dtype=backend.float32
) -> Tuple[ArrayLike, ArrayLike]:
"""
Generates a 2D meshgrid based on the provided pixelscale and dimensions.
Parameters
----------
pixelscale: float
The scale of the meshgrid in each dimension.
nx: int
The number of grid points along the x-axis.
ny: int
The number of grid points along the y-axis.
device: optional
The device on which to create the tensor. Defaults to None.
dtype: optional
The desired data type of the tensor. Defaults to torch.float32.
Returns
-------
Tuple: [ArrayLike, ArrayLike]
The generated meshgrid as a tuple of ArrayLikes.
"""
if ny is None:
ny = nx
xs = backend.linspace(-1, 1, nx, device=device, dtype=dtype) * pixelscale * (nx - 1) / 2 # fmt: skip
ys = backend.linspace(-1, 1, ny, device=device, dtype=dtype) * pixelscale * (ny - 1) / 2 # fmt: skip
return backend.meshgrid([xs, ys], indexing="xy")
[docs]
def plane_to_world_gnomonic(px, py, crval):
"""
Perform a gnomonic projection from a tangent plane to the celestial sphere
world coordinates.
Parameters
----------
px: ArrayLike
The x-coordinate of the point on the tangent plane in degrees.
py: ArrayLike
The y-coordinate of the point on the tangent plane in degrees.
crval: ArrayLike
The celestial sphere world coordinates in degrees where the tangent
plane meets the celestial sphere, should be a shape (2,) tensor. It is
assumed that the tangent plane is centered at (0,0) for these
coordinates. Thus ``crval`` matches the standard FITS convention.
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the right ascension and declination in degrees.
"""
plane = backend.stack((px, py), -1) * deg_to_rad
rho = backend.sqrt(backend.sum(plane**2, dim=-1))
c = backend.arctan(rho)
# Convert to sky coordinates
ra = crval[0] + rad_to_deg * backend.arctan2(
plane[..., 0] * backend.sin(c),
rho * backend.cos(crval[1] * deg_to_rad) * backend.cos(c)
- plane[..., 1] * backend.sin(crval[1] * deg_to_rad) * backend.sin(c),
)
dec = backend.where(
rho == 0,
crval[1],
rad_to_deg
* backend.arcsin(
backend.cos(c) * backend.sin(crval[1] * deg_to_rad)
+ plane[..., 1] * backend.sin(c) * backend.cos(crval[1] * deg_to_rad) / rho
),
)
return ra, dec
[docs]
def pixel_to_plane(i, j, crpix, CD, sip_powers=[], sip_coefs=[], crplane=None):
"""
Convert pixel coordinates to a tangent plane using the WCS information. This
matches the FITS convention for SIP transformations.
For more information see:
* FITS World Coordinate System (WCS):
https://fits.gsfc.nasa.gov/fits_wcs.html
* Representations of world coordinates in FITS, 2002, by Geisen and
Calabretta
* The SIP Convention for Representing Distortion in FITS Image Headers,
2008, by Shupe and Hook
Parameters
----------
i: ArrayLike
The first coordinate of the pixel in pixel units. The origin may be
either 0 indexed (python convention) or 1 indexed (FITS convention),
simply ensure that ``crpix`` has the same convention.
j: ArrayLike
The second coordinate of the pixel in pixel units. The origin may be
either 0 indexed (python convention) or 1 indexed (FITS convention),
simply ensure that ``crpix`` has the same convention.
crpix: ArrayLike
The reference pixel in pixel units, should be a shape (2,) tensor. This
is the point that will be placed at ``crval`` in the world coordinates.
The origin may be either 0 indexed (python convention) or 1 indexed
(FITS convention), simply ensure that ``i`` and ``j`` have the same
convention.
CD: ArrayLike
The CD matrix in degrees per pixel. This 2x2 matrix is used to convert
from pixel to degree units and also handles rotation/skew.
sip_powers: ArrayLike
The powers of the pixel coordinates for the SIP distortion, should be a
shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the powers in order ``i,
j``.
sip_coefs: ArrayLike
The coefficients of the pixel coordinates for the SIP distortion, should
be a shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the coefficients in order
``delta_x, delta_y``.
crplane: Optional[ArrayLike], optional
The reference plane coordinates in degrees, should be a shape (2,)
tensor. This is the point that will be placed at ``crpix`` in the pixel
coordinates. If None, it is assumed to be (0, 0). Defaults to None.
Note
----
The representation of the SIP powers and coefficients assumes that the SIP
polynomial will use the same orders for both the x and y coordinates. If
this is not the case you may use zeros for the coefficients to ensure all
polynomial combinations are evaluated. However, it is very common to have
the same orders for both.
Note
----
While it is not perfect, an approximate inverse for the SIP distortion can
be determined by taking the negative of the coefficients (and using the
``plane_to_pixel`` function).
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the x and y tangent plane coordinates in degrees.
"""
if crplane is None:
crplane = backend.zeros_like(crpix)
pixel = backend.stack((i, j), -1) - crpix
delta_p = backend.zeros_like(pixel)
for p in range(len(sip_powers)):
delta_p += sip_coefs[p] * backend.unsqueeze(
backend.prod(pixel ** sip_powers[p], dim=-1), -1
)
plane = backend.einsum("ij,...j->...i", CD, pixel + delta_p) + crplane
return plane[..., 0], plane[..., 1]
[docs]
def pixel_to_world(
i,
j,
crpix,
crval,
CD,
sip_powers=[],
sip_coefs=[],
crplane=None,
):
"""
Convert pixel coordinates to world coordinates using the WCS information.
This matches the FITS convention for SIP transformations.
For more information see:
* FITS World Coordinate System (WCS):
https://fits.gsfc.nasa.gov/fits_wcs.html
* Representations of world coordinates in FITS, 2002, by Geisen and
Calabretta
* The SIP Convention for Representing Distortion in FITS Image Headers,
2008, by Shupe and Hook
Parameters
----------
i: ArrayLike
The first coordinate of the pixel in pixel units. The origin may be either 0
indexed (python convention) or 1 indexed (FITS convention), simply
ensure that ``crpix`` has the same convention.
j: ArrayLike
The second coordinate of the pixel in pixel units. The origin may be either 0
indexed (python convention) or 1 indexed (FITS convention), simply
ensure that ``crpix`` has the same convention.
crpix: ArrayLike
The reference pixel in pixel units, should be a shape (2,) tensor. This
is the point that will be placed at ``crval`` in the world coordinates.
The origin may be either 0 indexed (python convention) or 1 indexed
(FITS convention), simply ensure that ``i`` and ``j`` have the same
convention.
crval: ArrayLike
The reference world coordinates in degrees, should be a shape (2,)
tensor. This is the point that will be placed at ``crpix`` in the pixel
coordinates.
CD: ArrayLike
The CD matrix in degrees per pixel. This 2x2 matrix is used to convert
from pixel to world units and also handles rotation/skew.
sip_powers: ArrayLike
The powers of the pixel coordinates for the SIP distortion, should be a
shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the powers in order ``i,
j``.
sip_coefs: ArrayLike
The coefficients of the pixel coordinates for the SIP distortion, should
be a shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the coefficients in order
``delta_x, delta_y``.
crplane: Optional[ArrayLike], optional
The reference plane coordinates in degrees, should be a shape (2,)
tensor. This is the point that will be placed at ``crpix`` in the pixel
coordinates. If None, it is assumed to be (0, 0). Defaults to None.
Note
----
The representation of the SIP powers and coefficients assumes that the SIP
polynomial will use the same orders for both the x and y coordinates. If
this is not the case you may use zeros for the coefficients to ensure all
polynomial combinations are evaluated. However, it is very common to have
the same orders for both.
Note
----
While it is not perfect, an approximate inverse for the SIP distortion can
be determined by taking the negative of the coefficients (and using the
``world_to_pixel`` function).
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the right ascension and declination in degrees.
"""
px, py = pixel_to_plane(i, j, crpix, CD, sip_powers, sip_coefs, crplane)
ra, dec = plane_to_world_gnomonic(px, py, crval)
return ra, dec
[docs]
def world_to_plane_gnomonic(ra, dec, crval):
"""
Perform a gnomonic projection from the celestial sphere
world coordinates to a tangent plane.
Parameters
----------
ra: ArrayLike
The right ascension in degrees.
dec: ArrayLike
The declination in degrees.
crval: ArrayLike
The celestial sphere world coordinates in degrees where the tangent
plane meets the celestial sphere, should be a shape (2,) tensor. It is
assumed that the tangent plane is centered at (0,0) for these
coordinates. Thus ``crval`` matches the standard FITS convention.
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the x and y tangent plane coordinates in degrees.
"""
ra = ra * deg_to_rad
dec = dec * deg_to_rad
cosc = backend.sin(crval[1] * deg_to_rad) * backend.sin(dec) + backend.cos(
crval[1] * deg_to_rad
) * backend.cos(dec) * backend.cos(ra - crval[0] * deg_to_rad)
x = backend.cos(dec) * backend.sin(ra - crval[0] * deg_to_rad) / cosc
y = (
backend.cos(crval[1] * deg_to_rad) * backend.sin(dec)
- backend.sin(crval[1] * deg_to_rad)
* backend.cos(dec)
* backend.cos(ra - crval[0] * deg_to_rad)
) / cosc
return x * rad_to_deg, y * rad_to_deg
[docs]
def plane_to_pixel(px, py, crpix, CD, sip_powers=[], sip_coefs=[], crplane=None):
"""
Convert tangent plane coordinates to pixel coordinates using the WCS
information. This matches the FITS convention for SIP transformations.
For more information see:
* FITS World Coordinate System (WCS):
https://fits.gsfc.nasa.gov/fits_wcs.html
* Representations of world coordinates in FITS, 2002, by Geisen and
Calabretta
* The SIP Convention for Representing Distortion in FITS Image Headers,
2008, by Shupe and Hook
Parameters
----------
px: ArrayLike
The x-coordinate of the point on the tangent plane in degrees.
py: ArrayLike
The y-coordinate of the point on the tangent plane in degrees.
crpix: ArrayLike
The reference pixel in pixel units, should be a shape (2,) tensor. This
is the point that will be placed at ``crval`` in the world coordinates.
The origin may be either 0 indexed (python convention) or 1 indexed
(FITS convention), ``i`` and ``j`` will have the same convention.
CD: ArrayLike
The CD matrix in degrees per pixel. This 2x2 matrix is used to convert
from pixel to world units and also handles rotation/skew.
sip_powers: ArrayLike
The powers of the pixel coordinates for the SIP distortion, should be a
shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the powers in order ``px,
py``.
sip_coefs: ArrayLike
The coefficients of the pixel coordinates for the SIP distortion, should
be a shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the coefficients in order
``delta_x, delta_y``.
crplane: Optional[ArrayLike], optional
The reference plane coordinates in degrees, should be a shape (2,)
tensor. This is the point that will be placed at ``crpix`` in the pixel
coordinates. If None, it is assumed to be (0, 0). Defaults to None.
Note
----
The representation of the SIP powers and coefficients assumes that the SIP
polynomial will use the same orders for both the x and y coordinates. If
this is not the case you may use zeros for the coefficients to ensure all
polynomial combinations are evaluated. However, it is very common to have
the same orders for both.
Note
----
While it is not perfect, an approximate inverse for the SIP distortion can
be determined by taking the negative of the coefficients (and using the
``pixel_to_plane`` function).
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the ``i`` and ``j`` pixel coordinates (in pixel units).
"""
if crplane is None:
crplane = backend.zeros_like(crpix)
plane = backend.stack((px, py), -1) - crplane
iCD = backend.linalg.inv(CD)
pixel = backend.einsum("ij,...j->...i", iCD, plane)
delta_w = backend.zeros_like(plane)
for i in range(len(sip_powers)):
delta_w += sip_coefs[i] * backend.unsqueeze(
backend.prod(pixel ** sip_powers[i], dim=-1), -1
)
pixel += delta_w + crpix
return pixel[..., 0], pixel[..., 1]
[docs]
def world_to_pixel(
ra,
dec,
crpix,
crval,
CD,
sip_powers=[],
sip_coefs=[],
crplane=None,
):
"""
Convert world coordinates to pixel coordinates using the WCS information.
This matches the FITS convention for SIP transformations.
For more information see:
* FITS World Coordinate System (WCS):
https://fits.gsfc.nasa.gov/fits_wcs.html
* Representations of world coordinates in FITS, 2002, by Geisen and
Calabretta
* The SIP Convention for Representing Distortion in FITS Image Headers,
2008, by Shupe and Hook
Parameters
----------
ra: ArrayLike
The right ascension in degrees.
dec: ArrayLike
The declination in degrees.
crpix: ArrayLike
The reference pixel in pixel units, should be a shape (2,) tensor. This
is the point that will be placed at ``crval`` in the world coordinates.
The origin may be either 0 indexed (python convention) or 1 indexed
(FITS convention), ``i`` and ``j`` will have the same convention.
crval: ArrayLike
The reference world coordinates in degrees, should be a shape (2,)
tensor. This is the point that will be placed at ``crpix`` in the pixel
coordinates (unless ``crplane`` is non-zero).
CD: ArrayLike
The CD matrix in degrees per pixel. This 2x2 matrix is used to convert
from pixel to world units and also handles rotation/skew.
powers: ArrayLike
The powers of the pixel coordinates for the SIP distortion, should be a
shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the powers in order ``i,
j``.
coefs: ArrayLike
The coefficients of the pixel coordinates for the SIP distortion, should
be a shape (N orders, 2) tensor. ``N orders`` is the number of non-zero
polynomial coefficients. The second axis has the coefficients in order
``delta_x, delta_y``.
Note
----
The representation of the SIP powers and coefficients assumes that the SIP
polynomial will use the same orders for both the x and y coordinates. If
this is not the case you may use zeros for the coefficients to ensure all
polynomial combinations are evaluated. However, it is very common to have
the same orders for both.
Note
----
While it is not perfect, an approximate inverse for the SIP distortion can
be determined by taking the negative of the coefficients (and using the
``pixel_to_world`` function).
Returns
-------
Tuple: [ArrayLike, ArrayLike]
Tuple containing the x and y pixel coordinates (in pixels).
"""
px, py = world_to_plane_gnomonic(ra, dec, crval)
i, j = plane_to_pixel(px, py, crpix, CD, sip_powers, sip_coefs, crplane)
return i, j
@lru_cache(maxsize=32)
def _quad_table(n, p, dtype, device):
"""
Generate a meshgrid for quadrature points using Legendre-Gauss quadrature.
Parameters
----------
n : int
The number of quadrature points in each dimension.
p : torch.ArrayLike
The pixelscale.
dtype : torch.dtype
The desired data type of the tensor.
device : torch.device
The device on which to create the tensor.
Returns
-------
Tuple[torch.ArrayLike, torch.ArrayLike, torch.ArrayLike]
The generated meshgrid as a tuple of ArrayLikes.
"""
abscissa, weights = roots_legendre(n)
w = backend.as_array(weights, dtype=dtype, device=device)
a = p * backend.as_array(abscissa, dtype=dtype, device=device) / 2.0
X, Y = backend.meshgrid(a, a, indexing="xy")
W = backend.outer(w, w) / 4.0
X, Y = X.reshape(-1), Y.reshape(-1) # flatten
return X, Y, W.reshape(-1)
[docs]
def gaussian_quadrature_grid(
pixelscale,
X,
Y,
quad_level=3,
):
"""
Generates a 2D meshgrid for Gaussian quadrature based on the provided pixelscale and dimensions.
Parameters
----------
pixelscale : float
The scale of the meshgrid in each dimension.
X : ArrayLike
The x-coordinates of the pixel centers.
Y : ArrayLike
The y-coordinates of the pixel centers.
quad_level : int, optional
The number of quadrature points in each dimension. Default is 3.
Returns
-------
Tuple[ArrayLike, ArrayLike]
The generated meshgrid as a tuple of ArrayLikes.
Example
-------
Usage would look something like:: python
X, Y = meshgrid(pixelscale, nx, ny)
Xs, Ys, weight = gaussian_quadrature_grid(pixelscale, X, Y, quad_level)
F = your_brightness_function(Xs, Ys, other, parameters)
res = gaussian_quadrature_integrator(F, weight)
"""
# collect gaussian quadrature weights
abscissaX, abscissaY, weight = _quad_table(
quad_level, pixelscale, dtype=X.dtype, device=backend.device(X)
)
# Gaussian quadrature evaluation points
Xs = backend.repeat(X[..., None], quad_level**2, -1) + abscissaX
Ys = backend.repeat(Y[..., None], quad_level**2, -1) + abscissaY
return Xs, Ys, weight
[docs]
def gaussian_quadrature_integrator(
F: ArrayLike,
weight: ArrayLike,
):
"""
Performs a pixel-wise integration using Gaussian quadrature.
It takes the brightness function evaluated at the quadrature points `F`
and the quadrature weights `weight` as input.
The result is the integrated brightness function at each pixel.
Parameters
----------
F : ArrayLike
The brightness function evaluated at the quadrature points.
weight : ArrayLike
The quadrature weights as provided by the get_pixel_quad_integrator_grid function.
Returns
-------
ArrayLike
The integrated brightness function at each pixel.
Example
-------
Usage would look something like:: python
X, Y = meshgrid(pixelscale, nx, ny)
Xs, Ys, weight = gaussian_quadrature_grid(pixelscale, X, Y, quad_level)
F = your_brightness_function(Xs, Ys, other, parameters)
res = gaussian_quadrature_integrator(F, weight)
"""
return backend.sum(F * weight, dim=-1)
[docs]
def quad(
F: Callable,
pixelscale: float,
X: ArrayLike,
Y: ArrayLike,
args: Tuple = (),
quad_level: int = 3,
):
"""
Performs a pixel-wise integration on a function using Gaussian quadrature.
Parameters
----------
F : Callable
The brightness function to be evaluated at the quadrature points.
The function should take as input: F(X, Y, *args).
pixelscale : float
The scale of each pixel.
X : ArrayLike
The x-coordinates of the pixels.
Y : ArrayLike
The y-coordinates of the pixels.
args : Optional[Tuple], optional
Additional arguments to be passed to the brightness function, by default None.
quad_level : int, optional
The level of quadrature to use, by default 3.
Returns
-------
ArrayLike
The integrated brightness function at each pixel.
"""
X, Y, weight = gaussian_quadrature_grid(pixelscale, X, Y, quad_level)
F = F(X, Y, *args)
return gaussian_quadrature_integrator(F, weight)
[docs]
def safe_divide(num, denom):
"""
Safely divides two tensors, returning zero where the denominator is zero.
Parameters
----------
num: ArrayLike
The numerator tensor.
denom: ArrayLike
The denominator tensor.
Returns
-------
ArrayLike
The result of the division, with zero where the denominator was zero.
"""
return backend.where(denom != 0, num / denom, backend.zeros_like(num))
[docs]
def safe_log(x):
"""
Safely applies the logarithm to a tensor, returning zero where the tensor is zero.
Parameters
----------
x: ArrayLike
The input tensor.
Returns
-------
ArrayLike
The result of applying the logarithm, with zero where the input was zero.
"""
return backend.where(x != 0, backend.log(x), backend.zeros_like(x))
def _h_poly(t):
"""Helper function to compute the 'h' polynomial matrix used in the
cubic spline.
Parameters
----------
t: ArrayLike
A 1D tensor representing the normalized x values.
Returns
-------
ArrayLike
A 2D tensor of size (4, len(t)) representing the 'h' polynomial matrix.
"""
tt = t[None, :] ** (backend.arange(4, device=backend.device(t))[:, None])
A = backend.as_array(
[[1, 0, -3, 2], [0, 1, -2, 1], [0, 0, 3, -2], [0, 0, -1, 1]],
dtype=t.dtype,
device=backend.device(t),
)
return A @ tt
[docs]
def interp1d(
x: ArrayLike,
y: ArrayLike,
xs: ArrayLike,
extend: Literal["extrapolate", "const", "linear"] = "extrapolate",
) -> ArrayLike:
"""Compute the 1D cubic spline interpolation for the given data points
using PyTorch.
Parameters
----------
x: ArrayLike
A 1D tensor representing the x-coordinates of the known data points.
y: ArrayLike
A 1D tensor representing the y-coordinates of the known data points.
xs: ArrayLike
A 1D tensor representing the x-coordinates of the positions where
the cubic spline function should be evaluated.
extend: (str, optional)
The method for handling extrapolation, either "const", "extrapolate", or "linear".
Default is "extrapolate".
"const": Use the value of the last known data point for extrapolation.
"linear": Use linear extrapolation based on the last two known data points.
"extrapolate": Use cubic extrapolation of data.
Returns
-------
ArrayLike
A 1D tensor representing the interpolated values at the specified positions (xs).
"""
m = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
m = backend.cat([m[0:1], (m[1:] + m[:-1]) / 2, m[-1:]])
idxs = backend.searchsorted(x[:-1], xs) - 1
dx = x[idxs + 1] - x[idxs]
hh = _h_poly((xs - x[idxs]) / dx)
ret = hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx # fmt: skip
if extend == "const":
ret[xs > x[-1]] = y[-1]
elif extend == "linear":
indices = xs > x[-1]
ret[indices] = y[-1] + (xs[indices] - x[-1]) * (y[-1] - y[-2]) / (x[-1] - x[-2])
return ret
[docs]
def interp2d(
im: ArrayLike,
x: ArrayLike,
y: ArrayLike,
method: Literal["linear", "nearest"] = "linear",
padding_mode: str = "zeros",
) -> ArrayLike:
"""
Interpolates a 2D image at specified coordinates. Similar to
`torch.nn.functional.grid_sample` with `align_corners=False`.
Parameters
----------
im: ArrayLike
A 2D tensor representing the image.
x: ArrayLike
A 0D or 1D tensor of x coordinates at which to interpolate.
y: ArrayLike
A 0D or 1D tensor of y coordinates at which to interpolate.
method: (str, optional)
Interpolation method. Either 'nearest' or 'linear'. Defaults to
'linear'.
padding_mode: (str, optional)
Defines the padding mode when out-of-bound indices are encountered.
Either 'zeros', 'clamp', or 'extrapolate'. Defaults to 'zeros' which
fills padded coordinates with zeros. The 'clamp' mode clamps the
coordinates to the image boundaries (essentially taking the border
values out to infinity). The 'extrapolate' mode extrapolates the outer
linear interpolation beyond the last pixel boundary.
Raises
------
ValueError
If `im` is not a 2D tensor.
ValueError
If `x` is not a 0D or 1D tensor.
ValueError
If `y` is not a 0D or 1D tensor.
ValueError
If `padding_mode` is not 'extrapolate' or 'zeros'.
ValueError
If `method` is not 'nearest' or 'linear'.
Returns
-------
ArrayLike
ArrayLike with the same shape as `x` and `y` containing the interpolated
values.
"""
if im.ndim != 2:
raise ValueError(f"im must be 2D (received {im.ndim}D tensor)")
if x.ndim > 1:
raise ValueError(f"x must be 0 or 1D (received {x.ndim}D tensor)")
if y.ndim > 1:
raise ValueError(f"y must be 0 or 1D (received {y.ndim}D tensor)")
if padding_mode not in ["extrapolate", "clamp", "zeros"]:
raise ValueError(f"{padding_mode} is not a valid padding mode")
if padding_mode == "clamp":
x = backend.clamp(x, -1, 1)
y = backend.clamp(y, -1, 1)
else:
idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1)
# Convert coordinates to pixel indices
h, w = im.shape
x = 0.5 * ((x + 1) * w - 1)
y = 0.5 * ((y + 1) * h - 1)
if method == "nearest":
result = im[
backend.clamp(backend.long(backend.round(y)), 0, h - 1),
backend.clamp(backend.long(backend.round(x)), 0, w - 1),
]
elif method == "linear":
x0 = backend.clamp(backend.long(backend.floor(x)), 0, w - 2)
y0 = backend.clamp(backend.long(backend.floor(y)), 0, h - 2)
x1 = x0 + 1
y1 = y0 + 1
fa = im[y0, x0]
fb = im[y1, x0]
fc = im[y0, x1]
fd = im[y1, x1]
dx1 = x1 - x
dx0 = x - x0
dy1 = y1 - y
dy0 = y - y0
result = fa * dx1 * dy1 + fb * dx1 * dy0 + fc * dx0 * dy1 + fd * dx0 * dy0 # fmt: skip
else:
raise ValueError(f"{method} is not a valid interpolation method")
if padding_mode == "zeros": # else padding_mode == "extrapolate"
result = backend.where(idxs_out_of_bounds, backend.zeros_like(result), result)
return result
[docs]
def interp3d(
cu: ArrayLike,
x: ArrayLike,
y: ArrayLike,
t: ArrayLike,
method: Literal["linear", "nearest"] = "linear",
padding_mode: Literal["zeros", "extrapolate"] = "zeros",
) -> ArrayLike:
"""
Interpolates a 3D image at specified coordinates.
Similar to `torch.nn.functional.grid_sample` with `align_corners=False`.
Parameters
----------
cu: ArrayLike
A 3D tensor representing the cube.
x: ArrayLike
A 0D or 1D tensor of x coordinates at which to interpolate.
y: ArrayLike
A 0D or 1D tensor of y coordinates at which to interpolate.
t: ArrayLike
A 0D or 1D tensor of t coordinates at which to interpolate.
method: (str, optional)
Interpolation method. Either 'nearest' or 'linear'. Defaults to 'linear'.
padding_mode: (str, optional)
Defines the padding mode when out-of-bound indices are encountered.
Either 'zeros' or 'extrapolate'. Defaults to 'zeros'.
Raises
------
ValueError
If `cu` is not a 3D tensor.
ValueError
If `x` is not a 0D or 1D tensor.
ValueError
If `y` is not a 0D or 1D tensor.
ValueError
If `t` is not a 0D or 1D tensor.
ValueError
If `padding_mode` is not 'extrapolate' or 'zeros'.
ValueError
If `method` is not 'nearest' or 'linear'.
Returns
-------
ArrayLike
ArrayLike with the same shape as `x` and `y` containing the interpolated values.
"""
if cu.ndim != 3:
raise ValueError(f"im must be 3D (received {cu.ndim}D tensor)")
if t.ndim > 1:
raise ValueError(f"t must be 0 or 1D (received {t.ndim}D tensor)")
if padding_mode not in ["extrapolate", "zeros"]:
raise ValueError(f"{padding_mode} is not a valid padding mode")
idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1) | (t < -1) | (t > 1)
# Convert coordinates to pixel indices
d, h, w = cu.shape
x = 0.5 * ((x + 1) * w - 1)
y = 0.5 * ((y + 1) * h - 1)
t = 0.5 * ((t + 1) * d - 1)
if method == "nearest":
result = cu[
backend.clamp(backend.long(backend.round(t)), 0, d - 1),
backend.clamp(backend.long(backend.round(y)), 0, h - 1),
backend.clamp(backend.long(backend.round(x)), 0, w - 1),
]
elif method == "linear":
x0 = backend.clamp(backend.long(backend.floor(x)), 0, w - 2)
y0 = backend.clamp(backend.long(backend.floor(y)), 0, h - 2)
t0 = backend.clamp(backend.long(backend.floor(t)), 0, d - 2)
x1 = x0 + 1
y1 = y0 + 1
t1 = t0 + 1
fa = cu[t0, y0, x0]
fb = cu[t0, y1, x0]
fc = cu[t0, y0, x1]
fd = cu[t0, y1, x1]
fe = cu[t1, y0, x0]
ff = cu[t1, y1, x0]
fg = cu[t1, y0, x1]
fh = cu[t1, y1, x1]
xd = x - x0
yd = y - y0
td = t - t0
c00 = fa * (1 - xd) + fc * xd
c01 = fe * (1 - xd) + fg * xd
c10 = fb * (1 - xd) + fd * xd
c11 = ff * (1 - xd) + fh * xd
c0 = c00 * (1 - yd) + c10 * yd
c1 = c01 * (1 - yd) + c11 * yd
result = c0 * (1 - td) + c1 * td
else:
raise ValueError(f"{method} is not a valid interpolation method")
if padding_mode == "zeros": # else padding_mode == "extrapolate"
result = backend.where(idxs_out_of_bounds, backend.zeros_like(result), result)
return result
# Bicubic interpolation coefficients
# These are the coefficients for the bicubic interpolation kernel.
# To quote numerical recipes:
# The formulas that obtain the c’s from the function and derivative values
# are just a complicated linear transformation, with coefficients which,
# having been determined once in the mists of numerical history, can be
# tabulated and forgotten
BC = (
(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
(0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0),
(-3, 0, 0, 3, 0, 0, 0, 0, -2, 0, 0, -1, 0, 0, 0, 0),
(2, 0, 0, -2, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0),
(0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0),
(0, 0, 0, 0, -3, 0, 0, 3, 0, 0, 0, 0, -2, 0, 0, -1),
(0, 0, 0, 0, 2, 0, 0, -2, 0, 0, 0, 0, 1, 0, 0, 1),
(-3, 3, 0, 0, -2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
(0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, -2, -1, 0, 0),
(9, -9, 9, -9, 6, 3, -3, -6, 6, -6, -3, 3, 4, 2, 1, 2),
(-6, 6, -6, 6, -4, -2, 2, 4, -3, 3, 3, -3, -2, -1, -1, -2),
(2, -2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
(0, 0, 0, 0, 0, 0, 0, 0, 2, -2, 0, 0, 1, 1, 0, 0),
(-6, 6, -6, 6, -3, -3, 3, 3, -4, 4, 2, -2, -2, -2, -1, -1),
(4, -4, 4, -4, 2, 2, -2, -2, 2, -2, -2, 2, 1, 1, 1, 1),
)
[docs]
def bicubic_kernels(Z, d1, d2):
"""
This is just a quick script to compute the necessary derivatives using
finite differences. This is not the most accurate way to compute the
derivatives, but it is good enough for most purposes.
"""
dZ1 = backend.zeros_like(Z)
dZ2 = backend.zeros_like(Z)
dZ12 = backend.zeros_like(Z)
# First derivatives on first axis
# df/dx = (f(x+h, y) - f(x-h, y)) / 2h
dZ1 = backend.fill_at_indices(dZ1, slice(1, -1), (Z[:-2] - Z[2:]) / (2 * d1))
dZ1 = backend.fill_at_indices(dZ1, 0, (Z[0] - Z[1]) / d1)
dZ1 = backend.fill_at_indices(dZ1, -1, (Z[-2] - Z[-1]) / d1)
# First derivatives on second axis
# df/dy = (f(x,y+h) - f(x,y-h)) / h
dZ2 = backend.fill_at_indices(
dZ2, (slice(None), slice(1, -1)), (Z[:, :-2] - Z[:, 2:]) / (2 * d2)
)
dZ2 = backend.fill_at_indices(dZ2, (slice(None), 0), (Z[:, 0] - Z[:, 1]) / d2)
dZ2 = backend.fill_at_indices(dZ2, (slice(None), -1), (Z[:, -2] - Z[:, -1]) / d2)
# Second derivatives across both axes
# d2f/dxdy = (f(x-h, y-k) - f(x-h, y+k) - f(x+h, y-k) + f(x+h, y+k)) / (4hk)
dZ12 = backend.fill_at_indices(
dZ12,
(slice(1, -1), slice(1, -1)),
(Z[:-2, :-2] - Z[:-2, 2:] - Z[2:, :-2] + Z[2:, 2:]) / (4 * d1 * d2),
)
return dZ1, dZ2, dZ12
[docs]
def interp_bicubic(
x,
y,
Z,
dZ1=None,
dZ2=None,
dZ12=None,
get_Y: bool = True,
get_dY: bool = False,
get_ddY: bool = False,
):
"""
Compute bicubic interpolation of a 2D grid at arbitrary locations. This will
smoothly interpolate a grid of points, including smooth first derivatives
and smooth cross derivative (d^2Y/dxdy). For the derivatives, continuity is
enforced, though the transition may be sharp as higher order derivatives are
not considered.
The interpolation requires knowing the values of the first derivative in
each axis and the cross derivative. If these are not provided, they will be
estimated using central differences. For this function, the derivatives
should be provided in pixel units. The interpolation will be more accurate
if an analytic value is available for the derivatives.
See Numerical Recipes in C, Chapter 3 (specifically: "Higher Order for
Smoothness: Bicubic Interpolation") for more details.
Parameters
----------
x : torch.ArrayLike
x-coordinates of the points to interpolate. Must be a 0D or 1D tensor.
It should be in (-1,1) fov units, meaning that -1 is the left edge of
the left pixel, and 1 is the right edge of the right pixel.
y : torch.ArrayLike
y-coordinates of the points to interpolate. Must be a 0D or 1D tensor.
It should be in (-1,1) fov units, meaning that -1 is the bottom edge of
the bottom pixel, and 1 is the top edge of the top pixel.
Z : torch.ArrayLike
2D grid of values to interpolate. The first axis corresponds to the
y-axis and the second axis to the x-axis. The values in Z correspond to
pixel center values, so Z[0,0] is the value at the center of the bottom
left corner pixel of the grid. The grid should be at least 2x2 so the
bicubic interpolation can go between the values.
dZ1 : torch.ArrayLike or None
First derivative of Z along the x-axis. If None, it will be estimated
using central differences. Note that the derivative should be computed
in pixel units, meaning that the distance from one pixel to the next is
considered "1" in these units.
dZ2 : torch.ArrayLike or None
First derivative of Z along the y-axis. If None, it will be estimated
using central differences. Note that the derivative should be computed
in pixel units, meaning that the distance from one pixel to the next is
considered "1" in these units.
dZ12 : torch.ArrayLike or None
Second derivative of Z along both axes. If None, it will be estimated
using central differences. Note that the derivative should be computed
in pixel units, meaning that the distance from one pixel to the next is
considered "1" in these units.
get_Y : bool
Whether to return the interpolated values. This will add the estimated Y
values to the return tuple
get_dY : bool
Whether to return the interpolated first derivatives. This will add dY1
and dY2 to the return tuple
get_ddY : bool
Whether to return the interpolated second derivatives. This will add
dY12, dY11, and dY22 to the return tuple
Returns
-------
Y : torch.ArrayLike or None
Interpolated values at the given locations. Only returned if get_Y is
True
dY1 : torch.ArrayLike or None
Interpolated first derivative along the x-axis. Only returned if get_dY
is True
dY2 : torch.ArrayLike or None
Interpolated first derivative along the y-axis. Only returned if get_dY
is True
dY12 : torch.ArrayLike or None
Interpolated second derivative along both axes. Only returned if get_ddY
is True
dY11 : torch.ArrayLike or None
Interpolated second derivative along the x-axis. Only returned if
get_ddY is True
dY22 : torch.ArrayLike or None
Interpolated second derivative along the y-axis. Only returned if
get_ddY is True
"""
if Z.ndim != 2:
raise ValueError(f"Z must be 2D (received {Z.ndim}D tensor)")
if x.ndim > 1:
raise ValueError(f"x must be 0 or 1D (received {x.ndim}D tensor)")
if y.ndim > 1:
raise ValueError(f"y must be 0 or 1D (received {y.ndim}D tensor)")
# Convert coordinates to pixel indices
h, w = Z.shape
x = 0.5 * ((x + 1) * w - 1)
x = backend.clamp(x, -0.5, w - 0.5)
y = 0.5 * ((y + 1) * h - 1)
y = backend.clamp(y, -0.5, h - 0.5)
# Compute bicubic kernels if not provided
if dZ1 is None or dZ2 is None or dZ12 is None:
_dZ1, _dZ2, _dZ12 = bicubic_kernels(Z, 1.0, 1.0)
if dZ1 is None:
dZ1 = _dZ1
if dZ2 is None:
dZ2 = _dZ2
if dZ12 is None:
dZ12 = _dZ12
# Extract pixel values
x0 = backend.long(backend.floor(x))
y0 = backend.long(backend.floor(y))
x1 = x0 + 1
y1 = y0 + 1
x0 = backend.clamp(x0, 0, w - 2)
x1 = backend.clamp(x1, 1, w - 1)
y0 = backend.clamp(y0, 0, h - 2)
y1 = backend.clamp(y1, 1, h - 1)
# Build interpolation vector
v = []
v.append(Z[y0, x0])
v.append(Z[y0, x1])
v.append(Z[y1, x1])
v.append(Z[y1, x0])
v.append(dZ1[y0, x0])
v.append(dZ1[y0, x1])
v.append(dZ1[y1, x1])
v.append(dZ1[y1, x0])
v.append(dZ2[y0, x0])
v.append(dZ2[y0, x1])
v.append(dZ2[y1, x1])
v.append(dZ2[y1, x0])
v.append(dZ12[y0, x0])
v.append(dZ12[y0, x1])
v.append(dZ12[y1, x1])
v.append(dZ12[y1, x0])
v = backend.stack(v, dim=-1)
# Compute interpolation coefficients
c = (
backend.as_array(BC, dtype=v.dtype, device=backend.device(v))
[docs]
@ backend.unsqueeze(v, -1)
).reshape(-1, 4, 4)
# Compute interpolated values
return_interp = []
t = backend.where(
(x < 0), (x % 1) - 1, backend.where(x >= w - 1, x % 1 + 1, x % 1)
) # TODO: change to x - x0
u = backend.where((y < 0), (y % 1) - 1, backend.where(y >= h - 1, y % 1 + 1, y % 1))
if get_Y:
Y = backend.zeros_like(x)
for i in range(4):
for j in range(4):
Y = Y + c[:, i, j] * t**i * u**j
return_interp.append(Y)
if get_dY:
dY1 = backend.zeros_like(x)
dY2 = backend.zeros_like(x)
for i in range(4):
for j in range(4):
if i > 0:
dY1 = dY1 + i * c[:, i, j] * t ** (i - 1) * u**j
if j > 0:
dY2 = dY2 + j * c[:, i, j] * t**i * u ** (j - 1)
return_interp.append(dY1)
return_interp.append(dY2)
if get_ddY:
dY12 = backend.zeros_like(x)
dY11 = backend.zeros_like(x)
dY22 = backend.zeros_like(x)
for i in range(4):
for j in range(4):
if i > 0 and j > 0:
dY12 = dY12 + i * j * c[:, i, j] * t ** (i - 1) * u ** (j - 1)
if i > 1:
dY11 = dY11 + i * (i - 1) * c[:, i, j] * t ** (i - 2) * u**j
if j > 1:
dY22 = dY22 + j * (j - 1) * c[:, i, j] * t**i * u ** (j - 2)
return_interp.append(dY12)
return_interp.append(dY11)
return_interp.append(dY22)
return tuple(return_interp)
def vmap_n(
func: Callable,
depth: int = 1,
in_dims: Union[int, Tuple] = 0,
out_dims: Union[int, Tuple[int, ...]] = 0,
randomness: str = "error",
) -> Callable:
"""
Transforms a function `depth` times using `torch.vmap` with the same arguments passed each time.
Returns `func` transformed `depth` times by `vmap`, with the same arguments
passed to `vmap` each time.
Parameters
----------
func: Callable
The function to transform.
depth: (int, optional)
The number of times to apply `torch.vmap`. Defaults to 1.
in_dims: (Union[int, Tuple], optional)
The dimensions to vectorize over in the input. Defaults to 0.
out_dims: (Union[int, Tuple[int, ...]], optional):
The dimensions to vectorize over in the output. Defaults to 0.
randomness: (str, optional)
How to handle randomness. Defaults to 'error'.
Raises
------
ValueError
If `depth` is less than 1.
Returns
-------
Callable
The transformed function.
TODO: test.
"""
if depth < 1:
raise ValueError("vmap_n depth must be >= 1")
vmapd_func = func
for _ in range(depth):
vmapd_func = backend.vmap(vmapd_func, in_dims, out_dims, randomness)
return vmapd_func
def _chunk_input(x, k, in_dims, chunk_size):
if isinstance(in_dims, tuple):
if chunk_size is None:
n_chunks = 1
else:
i = 0
while in_dims[i] is None:
i += 1
B = x[i].shape[in_dims[i]]
n_chunks = ceil(B / chunk_size)
# Break data into chunks
chunks = [[] for _ in range(n_chunks)]
for subx, in_dim in zip(x, in_dims):
if in_dim is None:
subchunking = [subx] * n_chunks
else:
subchunking = backend.chunk(subx, n_chunks, dim=in_dim)
for j, subchunk in enumerate(subchunking):
chunks[j].append(subchunk)
else: # isinstance(in_dims, dict)
if chunk_size is None:
n_chunks = 1
else:
for key, value in in_dims.items():
if value is not None:
B = k[key].shape[value]
n_chunks = ceil(B / chunk_size)
break
# Break data into chunks
chunks = [{} for _ in range(n_chunks)]
for key, value in in_dims.items():
if value is None:
subchunking = [k[key]] * n_chunks
else:
subchunking = backend.chunk(k[key], n_chunks, dim=value)
for j, subchunk in enumerate(subchunking):
chunks[j][key] = subchunk
return chunks
[docs]
def vmap_reduce(
func: Callable,
reduce_func: Callable = lambda x: backend.sum(x, dim=0),
chunk_size: Optional[int] = None,
in_dims: Union[Tuple[int, ...], Dict[str, int]] = (0,),
out_dims: Union[int, Tuple[int, ...]] = 0,
**kwargs,
) -> Callable:
"""
Applies `torch.vmap` to `func` and then reduces the output using
`reduce_func` along the appropriate dimensions. This saves on memory
management if the dimension being reduced can cause the intermediate tensor
(before reduction) to be large.
Note
----
The chunking and reduction is only "one level deep". If the output of `func`
is still large even after chunking, this function will not completely solve
the problem. Essentially if the batch dimension divided by chunk_size is
still larger than chunk_size, then you will still have a large intermediate
tensor.
Parameters
----------
func: Callable
The function to transform.
reduce_func: Callable
The function to reduce the output of `func`.
in_dims: Tuple[int,...]
The dimensions to vectorize over in the input.
out_dims: Tuple[int,...]
The dimension to stack the output over.
chunk_size: (Optional[int])
The size of the chunks to process. If None, the entire input is
processed at once.
kwargs: Dict
Additional keyword arguments to pass to `torch.vmap`.
Returns
-------
ArrayLike
The reduced output.
"""
if isinstance(in_dims, tuple):
vfunc = backend.vmap(func, in_dims, **kwargs)
else: # isinstance(in_dims, dict)
vfunc = backend.vmap(func, (in_dims,), **kwargs)
def wrapped(*x, **k):
# Determine chunks
chunks = _chunk_input(x, k, in_dims, chunk_size)
# Process and reduce the chunks
if isinstance(in_dims, tuple):
out = tuple(reduce_func(vfunc(*chunk)) for chunk in chunks)
else: # isinstance(in_dims, dict)
out = tuple(reduce_func(vfunc(chunk)) for chunk in chunks)
# Stack the output
if isinstance(out_dims, int):
out = backend.stack(out, dim=out_dims)
else:
out = tuple(
backend.stack([o[i] for o in out], dim=d)
for i, d in enumerate(out_dims)
)
# Reduce the output
return reduce_func(out)
return wrapped
[docs]
def cluster_means(xs: ArrayLike, k: int, key=None):
"""
Computes cluster means using the k-means++ initialization algorithm.
Parameters
----------
xs: ArrayLike
A tensor of data points.
k: int
The number of clusters.
key: Optional[Union[None, jax.random.key]]
A jax.random.key if using the Jax backend.
Returns
-------
ArrayLike
A tensor of cluster means.
"""
b = len(xs)
mean_idxs = [
int(backend.randint(high=b, size=(), device=backend.device(xs), key=key).item())
]
means = [xs[mean_idxs[0]]]
for _ in range(1, k):
unselected_xs = backend.stack(
[x for i, x in enumerate(xs) if i not in mean_idxs]
)
# Distances to all means
d2s = backend.sum(
(unselected_xs[:, None, :] - backend.stack(means)[None, :, :]) ** 2, -1
)
# Distances to closest mean
d2s_closest = backend.as_array(
[d2s[i, m] for i, m in enumerate(d2s.argmin(-1))]
)
# Add point furthest from closest mean as next mean
new_idx = int(backend.argmax(d2s_closest).item())
means.append(unselected_xs[new_idx])
mean_idxs.append(new_idx)
return backend.stack(means)
def _lm_step(f, X, Y, Cinv, L, Lup, Ldn, epsilon, L_min, L_max):
# Forward
fY = f(X)
dY = Y - fY
# Jacobian
J = backend.jacfwd(f)(X)
J = backend.to(J, dtype=X.dtype)
if Cinv.ndim == 1:
chi2 = backend.sum(dY**2 * Cinv, -1)
else:
chi2 = backend.sum(dY @ Cinv @ dY, -1)
# Gradient
if Cinv.ndim == 1:
grad = J.T @ (dY * Cinv)
else:
grad = J.T @ Cinv @ dY
# Hessian
if Cinv.ndim == 1:
hess = J.T @ (J * Cinv.reshape(-1, 1))
else:
hess = J.T @ Cinv @ J
hess_perturb = L * backend.eye(hess.shape[0], device=backend.device(hess))
hess = hess + hess_perturb
# Step
h = backend.linalg.solve(hess, grad)
# New chi^2
fYnew = f(X + h)
dYnew = Y - fYnew
if Cinv.ndim == 1:
chi2_new = backend.sum(dYnew**2 * Cinv, -1)
else:
chi2_new = backend.sum(dYnew @ Cinv @ dYnew, -1)
# Test
expected_improvement = backend.dot(h, hess @ h) + 2 * backend.dot(h, grad)
rho = (chi2 - chi2_new) / backend.abs(expected_improvement) # fmt: skip
# Update
X = backend.where(rho >= epsilon, X + h, X)
chi2 = backend.where(rho > epsilon, chi2_new, chi2)
L = backend.clamp(backend.where(rho >= epsilon, L / Ldn, L * Lup), L_min, L_max)
return X, L, chi2
[docs]
def batch_lm(
X, # B, Din
Y, # B, Dout
f, # Din -> Dout
C=None, # B, Dout, Dout !or! B, Dout
epsilon=1e-1,
L=1e0,
L_dn=11.0,
L_up=9.0,
max_iter=50,
L_min=1e-9,
L_max=1e9,
stopping=1e-4,
f_args=(),
f_kwargs={},
):
B, Din = X.shape
B, Dout = Y.shape
if len(X) != len(Y):
raise ValueError("x and y must having matching batch dimension")
if C is None:
C = backend.ones_like(Y)
if C.ndim == 2:
Cinv = 1 / C
else:
Cinv = backend.linalg.inv(C)
Cinv = backend.to(Cinv, dtype=X.dtype)
v_lm_step = backend.vmap(
partial(
_lm_step,
lambda x: f(x, *f_args, **f_kwargs),
Lup=L_up,
Ldn=L_dn,
epsilon=epsilon,
L_min=L_min,
L_max=L_max,
)
)
L = L * backend.ones(B, device=backend.device(X), dtype=X.dtype)
for _ in range(max_iter):
Xnew, L, C = v_lm_step(X, Y, Cinv, L)
if (
backend.all(backend.abs(Xnew - X) < stopping)
and backend.sum(L < 1e-2).item() > B / 3
):
break
if backend.all(L >= L_max):
break
X = Xnew
return X, L, C
[docs]
def gaussian(pixelscale, nx, ny, sigma, upsample=1, dtype=backend.float32, device=None):
X, Y = backend.meshgrid(
backend.linspace(
-(nx * upsample - 1) * pixelscale / 2,
(nx * upsample - 1) * pixelscale / 2,
nx * upsample,
dtype=dtype,
device=device,
),
backend.linspace(
-(ny * upsample - 1) * pixelscale / 2,
(ny * upsample - 1) * pixelscale / 2,
ny * upsample,
dtype=dtype,
device=device,
),
indexing="xy",
)
Z = backend.exp(-0.5 * (X**2 + Y**2) / sigma**2)
Z = backend.sum(Z.reshape(ny, upsample, nx, upsample), dim=(1, 3))
return Z / backend.sum(Z)