Source code for caustics.lenses.pixelated_convergence

# mypy: disable-error-code="index,dict-item"
from typing import Optional, Annotated, Union, Literal

import numpy as np
from caskade import forward, Param

from ..backend_obj import backend, ArrayLike
from ..utils import interp2d
from .base import ThinLens, CosmologyType, NameType, ZType
from . import func

# Public API of this module: only the lens class is exported by `from ... import *`.
__all__ = ("PixelatedConvergence",)


class PixelatedConvergence(ThinLens):
    # Placeholder parameter values (a 10x10 log-spaced map centred at the origin).
    _null_params = {
        "x0": 0.0,
        "y0": 0.0,
        "convergence_map": np.logspace(0, 1, 100, dtype=np.float32).reshape(10, 10),
    }

    def __init__(
        self,
        pixelscale: Annotated[float, "pixelscale"],
        cosmology: CosmologyType,
        z_l: ZType = None,
        z_s: ZType = None,
        x0: Annotated[
            Optional[Union[ArrayLike, float]],
            "The x-coordinate of the center of the grid",
            True,
        ] = backend.make_array(0.0),
        y0: Annotated[
            Optional[Union[ArrayLike, float]],
            "The y-coordinate of the center of the grid",
            True,
        ] = backend.make_array(0.0),
        convergence_map: Annotated[
            Optional[ArrayLike],
            "A 2D tensor representing the convergence map",
            True,
        ] = None,
        scale: Annotated[
            Optional[ArrayLike],
            "A scale factor to multiply by the convergence map",
            True,
        ] = 1.0,
        shape: Annotated[
            tuple[Optional[int], ...], "The shape of the convergence map"
        ] = (
            None,
            None,
        ),
        convolution_mode: Annotated[
            Literal["fft", "conv2d"],
            "The convolution mode for calculating deflection angles and lensing potential",
        ] = "fft",
        use_next_fast_len: Annotated[
            bool,
            "If True, adds additional padding to speed up the FFT by calling `scipy.fft.next_fast_len`",
        ] = True,
        padding: Annotated[
            Literal["zero", "circular", "reflect", "tile"],
            "Specifies the type of padding",
        ] = "zero",
        window_kernel: Annotated[float, "Amount of kernel to be windowed"] = 1.0 / 8.0,
        name: NameType = None,
    ):
        """Strong lensing with a user provided kappa map.

        PixelatedConvergence is a thin lens whose mass distribution is a
        user-provided, square convergence (kappa) map. Deflection angles and
        the lensing potential are obtained by convolving the map with
        precomputed kernels, either via FFT or via a direct 2D convolution.

        Parameters
        ----------
        pixelscale: float
            Side length of one pixel of the convergence map. The grid's field
            of view is ``n_pix * pixelscale``.

            *Unit: arcsec*

        cosmology: Cosmology
            An instance of the cosmological parameters.

        z_l: Optional[ArrayLike]
            The redshift of the lens.

            *Unit: unitless*

        z_s: Optional[ArrayLike]
            The redshift of the source.

            *Unit: unitless*

        x0: Optional[ArrayLike]
            The x-coordinate of the center of the grid.

            *Unit: arcsec*

        y0: Optional[ArrayLike]
            The y-coordinate of the center of the grid.

            *Unit: arcsec*

        convergence_map: Optional[ArrayLike]
            A square 2D tensor representing the convergence map.

            *Unit: unitless*

        scale: Optional[ArrayLike]
            Non-negative factor that multiplies ``convergence_map`` whenever the
            deflection, potential or convergence is evaluated. Default is 1.

            *Unit: unitless*

        shape: Optional[tuple[int, ...]]
            The shape of the convergence map. Must have length 2.

        convolution_mode: str, optional
            The convolution mode for calculating deflection angles and lensing
            potential. It can be either "fft" (Fast Fourier Transform) or
            "conv2d" (2D convolution). Default is "fft".

        use_next_fast_len: bool, optional
            If True, adds additional padding to speed up the FFT by calling
            `scipy.fft.next_fast_len`. The speed boost can be substantial when
            `n_pix` is a multiple of a small prime number. Default is True.

        padding: { "zero", "circular", "reflect", "tile" }
            Specifies the type of padding to use:

            - "zero" does zero padding.
            - "circular" uses cyclic boundaries.
            - "reflect" does reflection padding.
            - "tile" tiles the image 2x2. This is essentially identical to
              circular padding, but simpler.

            Zero padding represents an isolated overdensity. The other schemes
            represent a mass distribution embedded in a field of similar mass
            distributions. Generally you should use either "zero" or "tile".

        window_kernel: float, optional
            Amount of kernel to be windowed: the fraction of the kernel size over
            which a linear window brings the kernel edges to zero, for FFT
            stability. Only applied when ``padding != "zero"`` and
            ``convolution_mode == "fft"``. Set to 0 for no windowing. Default is
            1/8.

        name: string, optional
            The name of the PixelatedConvergence object.

        Raises
        ------
        ValueError
            If ``convergence_map`` is not 2D, if ``shape`` does not have length
            2, or if ``convolution_mode`` is not "fft" or "conv2d".
        AssertionError
            If the convergence map is not square.
        """
        super().__init__(cosmology, z_l, name=name, z_s=z_s)

        if convergence_map is not None and convergence_map.ndim != 2:
            raise ValueError(
                f"convergence_map must be 2D. Received a {convergence_map.ndim}D tensor"
            )
        elif shape is not None and len(shape) != 2:
            raise ValueError(f"shape must specify a 2D tensor. Received shape={shape}")

        self.x0 = Param("x0", x0, shape=(), units="arcsec")
        self.y0 = Param("y0", y0, shape=(), units="arcsec")
        self.convergence_map = Param(
            "convergence_map", convergence_map, shape, units="unitless"
        )
        # `scale` multiplies a dimensionless map, so it is itself dimensionless.
        self.scale = Param("scale", scale, shape=(), units="unitless", valid=(0, None))

        self.pixelscale = pixelscale
        # NOTE(review): `assert` is stripped under `python -O`; kept so the
        # exception type callers may rely on (AssertionError) is unchanged.
        assert (
            self.convergence_map.shape[0] == self.convergence_map.shape[1]
        ), f"Convergence map must be square, not {self.convergence_map.shape}"
        self.n_pix = self.convergence_map.shape[0]
        self.use_next_fast_len = use_next_fast_len
        self.padding = padding

        # Construct the real-space kernels for deflection and potential.
        self.ax_kernel, self.ay_kernel, self.potential_kernel = (
            func.build_kernels_pixelated_convergence(pixelscale, self.n_pix)
        )
        # Taper the deflection kernels towards their edges when non-zero padding
        # is combined with FFT convolution (FFT stability).
        if padding != "zero" and convolution_mode == "fft" and window_kernel > 0:
            window = func.build_window_pixelated_convergence(
                window_kernel, self.ax_kernel.shape
            )
            self.ax_kernel = self.ax_kernel * window
            self.ay_kernel = self.ay_kernel * window

        self.potential_kernel_tilde = None
        self.ax_kernel_tilde = None
        self.ay_kernel_tilde = None
        self._s = None

        # The setter validates the mode and builds the *_kernel_tilde arrays.
        self.convolution_mode = convolution_mode

    def to(self, device=None, dtype=None):
        """
        Move the PixelatedConvergence object and all its tensors to the
        specified device and dtype.

        Parameters
        ----------
        device: optional
            The target device to move the tensors to.
        dtype: optional
            The target data type to cast the real-space kernels to.

        Returns
        -------
        PixelatedConvergence
            ``self``, to allow ``lens = lens.to(...)`` chaining.
        """
        super().to(device, dtype)
        self.potential_kernel = backend.to(
            self.potential_kernel, device=device, dtype=dtype
        )
        self.ax_kernel = backend.to(self.ax_kernel, device=device, dtype=dtype)
        self.ay_kernel = backend.to(self.ay_kernel, device=device, dtype=dtype)
        # Rebuild the derived kernels from the moved/cast real-space kernels.
        # In "conv2d" mode they are aliases of the real kernels and must follow
        # the new dtype. In "fft" mode the FFTs are recomputed so their complex
        # precision matches the new real dtype; casting a complex array to a
        # real dtype would discard its imaginary part.
        self.convolution_mode = self.convolution_mode
        return self

    @property
    def convolution_mode(self):
        """
        Get the convolution mode of the PixelatedConvergence object.

        Returns
        -------
        string
            The convolution mode, either "fft" or "conv2d".
        """
        return self._convolution_mode

    @convolution_mode.setter
    def convolution_mode(self, convolution_mode: str):
        """
        Set the convolution mode and (re)build the kernels used by it.

        Parameters
        ----------
        convolution_mode: string
            The convolution mode to be set, either "fft" or "conv2d".

        Raises
        ------
        ValueError
            If ``convolution_mode`` is neither "fft" nor "conv2d".
        """
        if convolution_mode == "fft":
            # Precompute the real FFTs of the kernels at the padded FFT size.
            fft_size = func._fft_size(self.n_pix)
            self.potential_kernel_tilde = backend.fft.rfft2(
                self.potential_kernel, fft_size
            )
            self.ax_kernel_tilde = backend.fft.rfft2(self.ax_kernel, fft_size)
            self.ay_kernel_tilde = backend.fft.rfft2(self.ay_kernel, fft_size)
        elif convolution_mode == "conv2d":
            # Direct convolution uses the real-space kernels as-is.
            self.potential_kernel_tilde = self.potential_kernel
            self.ax_kernel_tilde = self.ax_kernel
            self.ay_kernel_tilde = self.ay_kernel
        else:
            raise ValueError(
                f"invalid convolution_mode {convolution_mode!r}; "
                "expected 'fft' or 'conv2d'"
            )
        self._convolution_mode = convolution_mode

    @forward
    def reduced_deflection_angle(
        self,
        x: ArrayLike,
        y: ArrayLike,
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        convergence_map: Annotated[ArrayLike, "Param"],
        scale: Annotated[ArrayLike, "Param"],
    ) -> tuple[ArrayLike, ArrayLike]:
        """
        Compute the deflection angles at the specified positions using the
        scaled convergence map (``convergence_map * scale``).

        Parameters
        ----------
        x: ArrayLike
            The x-coordinates of the positions to compute the deflection angles for.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinates of the positions to compute the deflection angles for.

            *Unit: arcsec*

        Returns
        -------
        x_component: ArrayLike
            Deflection Angle in the x-direction.

            *Unit: arcsec*

        y_component: ArrayLike
            Deflection Angle in the y-direction.

            *Unit: arcsec*
        """
        return func.reduced_deflection_angle_pixelated_convergence(
            x0,
            y0,
            convergence_map * scale,
            x,
            y,
            self.ax_kernel_tilde,
            self.ay_kernel_tilde,
            self.pixelscale,
            self.n_pix * self.pixelscale,
            self.n_pix,
            self.padding,
            self.convolution_mode,
        )

    @forward
    def potential(
        self,
        x: ArrayLike,
        y: ArrayLike,
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        convergence_map: Annotated[ArrayLike, "Param"],
        scale: Annotated[ArrayLike, "Param"],
    ) -> ArrayLike:
        """
        Compute the lensing potential at the specified positions using the
        scaled convergence map (``convergence_map * scale``).

        Parameters
        ----------
        x: ArrayLike
            The x-coordinates of the positions to compute the lensing potential for.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinates of the positions to compute the lensing potential for.

            *Unit: arcsec*

        Returns
        -------
        ArrayLike
            The lensing potential at the specified positions.

            *Unit: arcsec^2*
        """
        return func.potential_pixelated_convergence(
            x0,
            y0,
            convergence_map * scale,
            x,
            y,
            self.potential_kernel_tilde,
            self.pixelscale,
            self.n_pix * self.pixelscale,
            self.n_pix,
            self.padding,
            self.convolution_mode,
        )

    @forward
    def convergence(
        self,
        x: ArrayLike,
        y: ArrayLike,
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        convergence_map: Annotated[ArrayLike, "Param"],
        scale: Annotated[ArrayLike, "Param"],
    ) -> ArrayLike:
        """
        Compute the convergence at the specified positions by interpolating the
        scaled convergence map (``convergence_map * scale``).

        Parameters
        ----------
        x: ArrayLike
            The x-coordinates of the positions to compute the convergence for.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinates of the positions to compute the convergence for.

            *Unit: arcsec*

        Returns
        -------
        ArrayLike
            The convergence at the specified positions, with the same shape as ``x``.

            *Unit: unitless*
        """
        fov_x = convergence_map.shape[1] * self.pixelscale
        fov_y = convergence_map.shape[0] * self.pixelscale
        # Offsets from the grid centre are divided by half the field of view,
        # mapping the grid edges to +/-1 for interp2d.
        return interp2d(
            convergence_map * scale,
            backend.view(x - x0, -1) / fov_x * 2,
            backend.view(y - y0, -1) / fov_y * 2,
        ).reshape(x.shape)