# Source code for caustics.sims.lens_source

# mypy: disable-error-code="union-attr"
from scipy.fft import next_fast_len
from typing import Optional, Annotated, Literal, Union

from caskade import Module, forward, Param

from .simulator import NameType
from ..utils import (
    meshgrid,
    gaussian_quadrature_grid,
    gaussian_quadrature_integrator,
)
from ..lenses.base import Lens
from ..light.base import Source
from ..backend_obj import backend, ArrayLike

# Public names exported by this module.
__all__ = ("LensSource",)


class LensSource(Module):
    """Lens image of a source.

    Straightforward simulator to sample a lensed image of a source object.
    Constructs a sampling grid internally based on the pixelscale and gridding
    parameters. It can automatically upscale and fine sample an image. This is
    the most straightforward simulator to view the image if you already have a
    lens and source chosen.

    Example usage:

    .. code:: python

        import matplotlib.pyplot as plt
        import caustics

        cosmo = caustics.FlatLambdaCDM()
        lens = caustics.lenses.SIS(cosmology=cosmo, x0=0.0, y0=0.0, th_ein=1.0)
        source = caustics.sources.Sersic(
            x0=0.0, y0=0.0, q=0.5, phi=0.4, n=2.0, Re=1.0, Ie=1.0
        )
        # NOTE(review): an earlier version of this example also passed
        # ``z_s=1.0`` here, but ``LensSource.__init__`` accepts no such
        # argument; set the source redshift wherever the lens API expects it.
        sim = caustics.sims.LensSource(
            lens, source, pixelscale=0.05, pixels_x=100, upsample_factor=2
        )

        img = sim()
        plt.imshow(img, origin="lower")
        plt.show()

    Parameters
    ----------
    lens: Lens
        caustics lens mass model object
    source: Source
        caustics light object which defines the background source
    pixelscale: float
        pixelscale of the sampling grid.
    pixels_x: int
        number of pixels on the x-axis for the sampling grid
    lens_light: Source, optional
        caustics light object which defines the lensing object's light
    pixels_y: Optional[int]
        number of pixels on the y-axis for the sampling grid. If left as
        ``None`` then this will simply be equal to ``pixels_x``
    upsample_factor: int (default 1)
        Amount of upsampling to model the image. For example
        ``upsample_factor = 2`` indicates that the image will be sampled at
        double the resolution then averaged back to the original resolution
        (given by ``pixelscale`` and ``pixels_x``/``pixels_y``).
    quad_level: int (default None)
        sub pixel integration resolution. This will use Gaussian quadrature to
        sample the image at a higher resolution, then integrate the image back
        to the original resolution. This is useful for high accuracy
        integration of the image, but may increase memory usage and runtime.
    psf_mode: "fft" or "conv2d" (default "fft")
        Mode for convolving the psf.
    psf_shape: tuple of int, optional
        The shape of the psf. Only used when ``psf`` is ``None``; otherwise the
        shape is taken from ``psf`` itself.
    psf: ArrayLike, optional
        An image to convolve with the scene. Note that if
        ``upsample_factor > 1`` the psf must also be at the higher resolution.
        A psf of shape ``(1, 1)`` (the default) disables the convolution.
    x0: float or ArrayLike (default 0.0)
        x coordinate of the center of the field of view, added to the grid.
    y0: float or ArrayLike (default 0.0)
        y coordinate of the center of the field of view, added to the grid.
    name: string (default "sim")
        a name for this simulator in the parameter DAG.

    Raises
    ------
    ValueError
        If both ``psf`` and ``psf_shape`` are ``None``, since the padding of
        the sampling grid cannot be determined.

    Notes
    -----
    - The simulator automatically pads the sampling grid by half the PSF size
      on every side so that the convolution is valid over the whole requested
      image. This is only relevant if you are using a PSF.

    - The upsample factor will increase the resolution of the image by the
      given factor. For example, ``upsample_factor = 2`` will sample the image
      at double the resolution, then average back to the original resolution.
      This is used when a PSF is provided at higher resolution than the
      original image. Note that when a PSF is used, the upsample_factor must
      equal the PSF upsampling level.

    - For arbitrary pixel integration accuracy use the quad_level parameter.
      This will use Gaussian quadrature to sample the image at a higher
      resolution, then integrate the image back to the original resolution.
      This is useful for high accuracy integration of the image, but is not
      recommended for large images as it will be slow. The quad_level and
      upsample_factor can be used together to achieve high accuracy
      integration of the image convolved with a PSF.

    - A `Pixelated` light source is defined by bilinear interpolation of the
      provided image. This means that sub-pixel integration is not required
      for accurate integration of the pixels. However, if you are using a PSF
      then you should still use upsample_factor (if your PSF is supersampled)
      to ensure that everything is sampled at the PSF resolution.
    """

    def __init__(
        self,
        lens: Annotated[Lens, "caustics lens mass model object"],
        source: Annotated[
            Source, "caustics light object which defines the background source"
        ],
        pixelscale: Annotated[float, "pixelscale of the sampling grid"],
        pixels_x: Annotated[
            int, "number of pixels on the x-axis for the sampling grid"
        ],
        lens_light: Annotated[
            Optional[Source],
            "caustics light object which defines the lensing object's light",
        ] = None,
        pixels_y: Annotated[
            Optional[int], "number of pixels on the y-axis for the sampling grid"
        ] = None,
        upsample_factor: Annotated[int, "Amount of upsampling to model the image"] = 1,
        quad_level: Annotated[Optional[int], "sub pixel integration resolution"] = None,
        psf_mode: Annotated[
            Literal["fft", "conv2d"], "Mode for convolving psf"
        ] = "fft",
        psf_shape: Annotated[Optional[tuple[int, ...]], "The shape of the psf"] = None,
        psf: Annotated[
            Optional[Union[ArrayLike, list]],
            "An image to convolve with the scene",
            True,
        ] = [[1.0]],
        x0: Annotated[
            Optional[Union[ArrayLike, float]],
            "center of the fov for the lens source image",
            True,
        ] = 0.0,
        y0: Annotated[
            Optional[Union[ArrayLike, float]],
            "center of the fov for the lens source image",
            True,
        ] = 0.0,
        name: NameType = "sim",
    ):
        super().__init__(name)

        # Configure PSF. The default list is never mutated: it is only
        # converted into a backend array.
        self._psf_mode = psf_mode
        if psf is not None:
            psf = backend.as_array(psf)
        self._psf_shape = psf.shape if psf is not None else psf_shape
        if self._psf_shape is None:
            # Without a shape the grid padding in ``_build_grid`` cannot be
            # computed; fail here with a clear message instead of an opaque
            # ``TypeError`` on ``None[1]``.
            raise ValueError(
                "either `psf` or `psf_shape` must be provided so the padding "
                "of the sampling grid can be computed"
            )

        # Build parameters
        self.psf = Param("psf", psf, self.psf_shape, units="unitless")
        self.x0 = Param("x0", x0, shape=(), units="arcsec")
        self.y0 = Param("y0", y0, shape=(), units="arcsec")

        self._pixelscale = pixelscale

        # Lensing models
        self.lens = lens
        self.source = source
        self.lens_light = lens_light

        # Image grid
        self._pixels_x = pixels_x
        self._pixels_y = pixels_x if pixels_y is None else pixels_y
        self._upsample_factor = upsample_factor
        self._quad_level = quad_level

        # Build the imaging grid
        self._build_grid()

    def to(self, device=None, dtype=None):
        """Move the module, its sampling grid and quadrature weights.

        Parameters
        ----------
        device: optional
            target device, forwarded to ``backend.to``
        dtype: optional
            target dtype, forwarded to ``backend.to``

        Returns
        -------
        LensSource
            ``self``, to allow chaining.
        """
        super().to(device, dtype)
        # The grid and weights are plain attributes (not Params), so they are
        # converted explicitly here.
        self._grid = tuple(
            backend.to(x, device=device, dtype=dtype) for x in self._grid  # type: ignore[has-type]
        )
        self._weights = backend.to(self._weights, device=device, dtype=dtype)  # type: ignore[has-type]
        return self

    # Every setter below rebuilds the sampling grid, since the grid depends on
    # all of these configuration values.

    @property
    def upsample_factor(self):
        """Upsampling factor of the sampling grid relative to the output."""
        return self._upsample_factor

    @upsample_factor.setter
    def upsample_factor(self, value):
        value = int(value)
        assert value > 0, f"upsample_factor should be > 0, not {value}"
        self._upsample_factor = value
        self._build_grid()

    @property
    def pixels_x(self):
        """Number of output pixels along the x-axis."""
        return self._pixels_x

    @pixels_x.setter
    def pixels_x(self, value):
        self._pixels_x = value
        self._build_grid()

    @property
    def pixels_y(self):
        """Number of output pixels along the y-axis."""
        return self._pixels_y

    @pixels_y.setter
    def pixels_y(self, value):
        self._pixels_y = value
        self._build_grid()

    @property
    def quad_level(self):
        """Gaussian quadrature level, or ``None`` when quadrature is off."""
        return self._quad_level

    @quad_level.setter
    def quad_level(self, value):
        value = None if value is None else int(value)
        assert (
            value is None or value > 0
        ), f"quad_level should be None or > 0, not {value}"
        self._quad_level = value
        self._build_grid()

    @property
    def pixelscale(self):
        """Pixelscale of the output image."""
        return self._pixelscale

    @pixelscale.setter
    def pixelscale(self, value):
        self._pixelscale = value
        self._build_grid()

    @property
    def psf_shape(self):
        """Shape of the PSF image as ``(rows, columns)``."""
        return self._psf_shape

    @psf_shape.setter
    def psf_shape(self, value):
        self._psf_shape = value
        self._build_grid()

    @property
    def psf_mode(self):
        """Convolution mode, one of ``"fft"`` or ``"conv2d"``."""
        return self._psf_mode

    @psf_mode.setter
    def psf_mode(self, value):
        assert value in (
            "fft",
            "conv2d",
        ), f"psf_mode should be one of 'fft' or 'conv2d', not {value}"
        self._psf_mode = value
        self._build_grid()

    def _build_grid(self):
        """(Re)build the sampling grid, quadrature weights and FFT sizes.

        Sets ``_psf_pad`` and ``_n_pix`` (both ordered ``(x, y)``), ``_grid``,
        ``_weights`` and ``_s`` (ordered ``(y, x)`` to match the image axes).
        """
        # Padding in upsampled pixels, ordered (x, y): PSF columns pad the
        # x-axis and PSF rows pad the y-axis.
        self._psf_pad = (
            self.psf_shape[1] // 2,
            self.psf_shape[0] // 2,
        )
        # Size of the padded, upsampled grid, ordered (x, y).
        self._n_pix = (
            self.pixels_x * self.upsample_factor + self._psf_pad[0] * 2,
            self.pixels_y * self.upsample_factor + self._psf_pad[1] * 2,
        )
        self._grid = meshgrid(
            self.pixelscale / self.upsample_factor,  # upsampled pixelscale
            self._n_pix[0],
            self._n_pix[1],
        )
        # Trivial weights, used when no quadrature sub-sampling is requested.
        self._weights = backend.ones(
            (1, 1), dtype=self._grid[0].dtype, device=self._grid[0].device
        )
        if self.quad_level is not None and self.quad_level > 1:
            finegrid_x, finegrid_y, weights = gaussian_quadrature_grid(
                self.pixelscale / self.upsample_factor, *self._grid, self.quad_level
            )
            self._grid = (finegrid_x, finegrid_y)
            self._weights = weights
        else:
            # Add a trailing length-1 axis so both branches expose the same
            # layout to ``__call__``.
            self._grid = (
                backend.unsqueeze(self._grid[0], -1),
                backend.unsqueeze(self._grid[1], -1),
            )

        # FFT convolution is fastest when the image is padded to a "fast"
        # length. The sizes are ordered (y, x) because they apply to the last
        # two image axes, and the image is indexed (y, x) (see the final
        # clipping in ``__call__``). Using (x, y) here would trim one axis and
        # over-pad the other for non-square images.
        self._s = (next_fast_len(self._n_pix[1]), next_fast_len(self._n_pix[0]))

    def _fft2_padded(self, x):
        """
        Compute the 2D real FFT of a tensor, zero-padded to ``self._s``.

        Parameters
        ----------
        x: ArrayLike
            The input tensor to be transformed.

        Returns
        -------
        ArrayLike
            The 2D FFT of the input tensor with zero-padding.
        """
        return backend.fft.rfft2(x, self._s)

    def _unpad_fft(self, x):
        """
        Re-align and trim the result of an FFT convolution.

        The PSF is transformed with its corner at the origin, so the
        convolved image is shifted by the PSF half-size along each axis. This
        rolls the image back by that amount and trims it to ``self._s``.

        Parameters
        ----------
        x: ArrayLike
            The input tensor with padding.

        Returns
        -------
        ArrayLike
            The input tensor without padding.
        """
        # ``_psf_pad`` is ordered (x, y) while the image axes are (y, x), so
        # axis -2 is rolled by the y pad and axis -1 by the x pad.
        return backend.roll(
            x,
            (-self._psf_pad[1], -self._psf_pad[0]),
            dims=(-2, -1),
        )[..., : self._s[0], : self._s[1]]

    @forward
    def __call__(
        self,
        psf: Annotated[ArrayLike, "Param"],
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        source_light: bool = True,
        lens_light: bool = True,
        lens_source: bool = True,
        psf_convolve: bool = True,
        chunk_size: Optional[int] = None,
    ):
        """
        Sample the image of the scene at the native resolution.

        Parameters
        ----------
        psf: ArrayLike
            PSF image; presumably filled in from the ``psf`` Param by the
            ``forward`` decorator. A ``(1, 1)`` psf disables the convolution.
        x0: ArrayLike
            x offset added to the sampling grid (``x0`` Param).
        y0: ArrayLike
            y offset added to the sampling grid (``y0`` Param).
        source_light: boolean
            when true the source light will be sampled
        lens_light: boolean
            when true the lens light will be sampled
        lens_source: boolean
            when true, the source light model will be lensed by the lens mass
            distribution
        psf_convolve: boolean
            when true the image will be convolved with the psf
        chunk_size: int
            when not None, the image will be sampled in chunks of this size.
            This may help reduce memory usage.

        Returns
        -------
        ArrayLike
            The image with the PSF padding removed and average-pooled by
            ``upsample_factor`` back to the native resolution.
        """
        # Automatically turn off light for missing objects
        if self.source is None:
            source_light = False
        if self.lens_light is None:
            lens_light = False
        if psf.shape == (1, 1):
            psf_convolve = False

        grid = (self._grid[0] + x0, self._grid[1] + y0)

        # Sample the source light
        if source_light:
            if lens_source:
                # Source is lensed by the lens mass distribution
                bx, by = backend.vmap(self.lens.raytrace, chunk_size=chunk_size)(
                    backend.flatten(grid[0]), backend.flatten(grid[1])
                )
                mu_fine = backend.vmap(self.source.brightness, chunk_size=chunk_size)(
                    bx, by
                ).reshape(grid[0].shape)
                mu = gaussian_quadrature_integrator(mu_fine, self._weights)
            else:
                # Source is imaged without lensing
                mu_fine = backend.vmap(self.source.brightness, chunk_size=chunk_size)(
                    backend.flatten(grid[0]), backend.flatten(grid[1])
                ).reshape(grid[0].shape)
                mu = gaussian_quadrature_integrator(mu_fine, self._weights)
        else:
            # Source is not added to the scene
            mu = backend.zeros_like(
                grid[0][..., 0], dtype=grid[0].dtype
            )  # chop off quad dim

        # Sample the lens light
        if lens_light and self.lens_light is not None:
            mu_fine = backend.vmap(self.lens_light.brightness, chunk_size=chunk_size)(
                backend.flatten(grid[0]), backend.flatten(grid[1])
            ).reshape(grid[0].shape)
            mu += gaussian_quadrature_integrator(mu_fine, self._weights)

        # Convolve the PSF (normalized to unit sum in both modes)
        if psf_convolve:
            if self.psf_mode == "fft":
                mu_fft = self._fft2_padded(mu)
                psf_fft = self._fft2_padded(psf / psf.sum())
                mu = self._unpad_fft(backend.fft.irfft2(mu_fft * psf_fft, self._s).real)
            elif self.psf_mode == "conv2d":
                mu = (
                    backend.conv2d(
                        mu[None, None],
                        backend.to(
                            (backend.flip(psf, (0, 1)) / backend.sum(psf)),
                            dtype=mu.dtype,
                        )[None, None],
                        padding="same",
                    )
                    .squeeze(0)
                    .squeeze(0)
                )

        # Return to the desired image: drop the PSF padding (axis 0 is y,
        # axis 1 is x), then pool back to the native resolution.
        mu_clipped = mu[
            self._psf_pad[1] : self.pixels_y * self.upsample_factor + self._psf_pad[1],
            self._psf_pad[0] : self.pixels_x * self.upsample_factor + self._psf_pad[0],
        ]
        mu_native_resolution = (
            backend.avg_pool2d(mu_clipped[None, None], self.upsample_factor)
            .squeeze(0)
            .squeeze(0)
        )
        return mu_native_resolution