# mypy: disable-error-code="union-attr"
from scipy.fft import next_fast_len
from typing import Optional, Annotated, Literal, Union
from caskade import Module, forward, Param
from .simulator import NameType
from ..utils import (
meshgrid,
gaussian_quadrature_grid,
gaussian_quadrature_integrator,
)
from ..lenses.base import Lens
from ..light.base import Source
from ..backend_obj import backend, ArrayLike
# Public API of this module: only the LensSource simulator is exported.
__all__ = ("LensSource",)
class LensSource(Module):
    """Lens image of a source.

    Straightforward simulator to sample a lensed image of a source object.
    Constructs a sampling grid internally based on the pixelscale and gridding
    parameters. It can automatically upscale and fine sample an image. This is
    the most straightforward simulator to view the image if you already have a
    lens and source chosen.

    Example usage:

    .. code:: python

        import matplotlib.pyplot as plt
        import caustics

        cosmo = caustics.FlatLambdaCDM()
        lens = caustics.lenses.SIS(cosmology=cosmo, x0=0.0, y0=0.0, th_ein=1.0)
        source = caustics.sources.Sersic(x0=0.0, y0=0.0, q=0.5, phi=0.4, n=2.0, Re=1.0, Ie=1.0)
        sim = caustics.sims.LensSource(
            lens, source, pixelscale=0.05, pixels_x=100, upsample_factor=2
        )

        img = sim()
        plt.imshow(img, origin="lower")
        plt.show()

    Attributes
    ----------
    lens: Lens
        caustics lens mass model object
    source: Source
        caustics light object which defines the background source
    pixelscale: float
        pixelscale of the sampling grid at native resolution (same units as
        ``x0``/``y0``, i.e. arcsec per pixel).
    pixels_x: int
        number of pixels on the x-axis for the sampling grid
    lens_light: Source, optional
        caustics light object which defines the lensing object's light
    pixels_y: Optional[int]
        number of pixels on the y-axis for the sampling grid. If left as
        ``None`` then this will simply be equal to ``pixels_x``
    upsample_factor: int (default 1)
        Amount of upsampling to model the image. For example
        ``upsample_factor = 2`` indicates that the image will be sampled at
        double the resolution, then averaged back to the original resolution
        (given by ``pixelscale`` and ``pixels_x``/``pixels_y``).
    quad_level: int (default None)
        sub pixel integration resolution. This will use Gaussian quadrature to
        sample the image at a higher resolution, then integrate the image back
        to the original resolution. This is useful for high accuracy
        integration of the image, but may increase memory usage and runtime.
    psf_mode: {"fft", "conv2d"} (default "fft")
        How the PSF convolution is computed: a zero-padded FFT convolution, or
        a direct ``conv2d`` with ``"same"`` padding.
    psf_shape: tuple of int, optional
        Shape of the PSF. Only used when ``psf`` is ``None``; otherwise the
        shape is taken from ``psf`` itself.
    psf: ArrayLike, optional (default ``[[1.0]]``)
        An image to convolve with the scene. It is normalized to unit sum
        before convolution, and a ``(1, 1)`` PSF disables convolution. Note
        that if ``upsample_factor > 1`` the psf must also be at the higher
        resolution. If ``None``, ``psf_shape`` must be given and the PSF value
        is left to be supplied through the ``psf`` parameter.
    x0: float (default 0.0)
        x coordinate of the center of the field of view (arcsec).
    y0: float (default 0.0)
        y coordinate of the center of the field of view (arcsec).
    name: string (default "sim")
        a name for this simulator in the parameter DAG.

    Notes
    -----
    - The simulator pads the image on each side by half the PSF size (in
      upsampled pixels) so the convolution is valid over the returned field of
      view. The padding is derived from ``psf_shape`` and is cropped away
      before the image is returned.
    - The upsample factor increases the resolution of the image by the given
      factor. For example, ``upsample_factor = 2`` samples the image at double
      the resolution, then averages back to the original resolution. This is
      used when a PSF is provided at higher resolution than the original
      image. Note that when a PSF is used, ``upsample_factor`` must equal the
      PSF upsampling level.
    - Use the ``quad_level`` parameter for arbitrary pixel integration
      accuracy. This uses Gaussian quadrature to sample the image at a higher
      resolution, then integrates the image back to the original resolution.
      This is useful for high accuracy integration of the image, but is not
      recommended for large images as it will be slow. ``quad_level`` and
      ``upsample_factor`` can be used together to achieve high accuracy
      integration of the image convolved with a PSF.
    - A `Pixelated` light source is defined by bilinear interpolation of the
      provided image. This means that sub-pixel integration is not required
      for accurate integration of the pixels. However, if you are using a PSF
      then you should still use ``upsample_factor`` (if your PSF is
      supersampled) to ensure that everything is sampled at the PSF
      resolution.
    """  # noqa: E501

    def __init__(
        self,
        lens: Annotated[Lens, "caustics lens mass model object"],
        source: Annotated[
            Source, "caustics light object which defines the background source"
        ],
        pixelscale: Annotated[float, "pixelscale of the sampling grid"],
        pixels_x: Annotated[
            int, "number of pixels on the x-axis for the sampling grid"
        ],
        lens_light: Annotated[
            Optional[Source],
            "caustics light object which defines the lensing object's light",
        ] = None,
        pixels_y: Annotated[
            Optional[int], "number of pixels on the y-axis for the sampling grid"
        ] = None,
        upsample_factor: Annotated[int, "Amount of upsampling to model the image"] = 1,
        quad_level: Annotated[Optional[int], "sub pixel integration resolution"] = None,
        psf_mode: Annotated[
            Literal["fft", "conv2d"], "Mode for convolving psf"
        ] = "fft",
        psf_shape: Annotated[Optional[tuple[int, ...]], "The shape of the psf"] = None,
        # NOTE: the list default is never mutated; it is converted to a new
        # backend array below before being stored.
        psf: Annotated[
            Optional[Union[ArrayLike, list]],
            "An image to convolve with the scene",
            True,
        ] = [[1.0]],
        x0: Annotated[
            Optional[Union[ArrayLike, float]],
            "center of the fov for the lens source image",
            True,
        ] = 0.0,
        y0: Annotated[
            Optional[Union[ArrayLike, float]],
            "center of the fov for the lens source image",
            True,
        ] = 0.0,
        name: NameType = "sim",
    ):
        super().__init__(name)

        # Configure PSF
        if psf_mode not in ("fft", "conv2d"):
            raise ValueError(
                f"psf_mode should be one of 'fft' or 'conv2d', not {psf_mode}"
            )
        self._psf_mode = psf_mode
        if psf is not None:
            psf = backend.as_array(psf)
        elif psf_shape is None:
            # _build_grid needs a PSF shape to size the convolution padding.
            raise ValueError("psf_shape must be provided when psf is None")
        self._psf_shape = psf.shape if psf is not None else psf_shape

        # Build parameters
        self.psf = Param("psf", psf, self.psf_shape, units="unitless")
        self.x0 = Param("x0", x0, shape=(), units="arcsec")
        self.y0 = Param("y0", y0, shape=(), units="arcsec")
        self._pixelscale = pixelscale

        # Lensing models
        self.lens = lens
        self.source = source
        self.lens_light = lens_light

        # Image grid
        self._pixels_x = pixels_x
        self._pixels_y = pixels_x if pixels_y is None else pixels_y
        self._upsample_factor = upsample_factor
        self._quad_level = quad_level

        # Build the imaging grid
        self._build_grid()

    def to(self, device=None, dtype=None):
        """Move the module and the cached sampling grid/weights to ``device``/``dtype``."""
        super().to(device, dtype)
        self._grid = tuple(backend.to(x, device=device, dtype=dtype) for x in self._grid)  # type: ignore[has-type]
        self._weights = backend.to(self._weights, device=device, dtype=dtype)  # type: ignore[has-type]
        return self

    @property
    def upsample_factor(self):
        return self._upsample_factor

    @upsample_factor.setter
    def upsample_factor(self, value):
        value = int(value)
        assert value > 0, f"upsample_factor should be > 0, not {value}"
        self._upsample_factor = value
        self._build_grid()

    @property
    def pixels_x(self):
        return self._pixels_x

    @pixels_x.setter
    def pixels_x(self, value):
        self._pixels_x = value
        self._build_grid()

    @property
    def pixels_y(self):
        return self._pixels_y

    @pixels_y.setter
    def pixels_y(self, value):
        self._pixels_y = value
        self._build_grid()

    @property
    def quad_level(self):
        return self._quad_level

    @quad_level.setter
    def quad_level(self, value):
        value = None if value is None else int(value)
        assert (
            value is None or value > 0
        ), f"quad_level should be None or > 0, not {value}"
        self._quad_level = value
        self._build_grid()

    @property
    def pixelscale(self):
        return self._pixelscale

    @pixelscale.setter
    def pixelscale(self, value):
        self._pixelscale = value
        self._build_grid()

    @property
    def psf_shape(self):
        return self._psf_shape

    @psf_shape.setter
    def psf_shape(self, value):
        self._psf_shape = value
        self._build_grid()

    @property
    def psf_mode(self):
        return self._psf_mode

    @psf_mode.setter
    def psf_mode(self, value):
        assert value in (
            "fft",
            "conv2d",
        ), f"psf_mode should be one of 'fft' or 'conv2d', not {value}"
        self._psf_mode = value
        self._build_grid()

    def _build_grid(self):
        """(Re)build the cached sampling grid, quadrature weights and FFT size.

        Called whenever a gridding attribute changes. All sizes here are in
        upsampled pixels. ``_psf_pad`` and ``_n_pix`` are stored in (x, y)
        order, while the image arrays themselves are laid out as
        (rows = y, cols = x).
        """
        # Half the PSF size on each side: x pad from the PSF's last axis,
        # y pad from its first axis.
        self._psf_pad = (
            self.psf_shape[1] // 2,  # upsample pixels
            self.psf_shape[0] // 2,  # upsample pixels
        )
        self._n_pix = (
            self.pixels_x * self.upsample_factor
            + self._psf_pad[0] * 2,  # upsample pixels
            self.pixels_y * self.upsample_factor
            + self._psf_pad[1] * 2,  # upsample pixels
        )
        self._grid = meshgrid(
            self.pixelscale / self.upsample_factor,  # upsample pixelscale
            self._n_pix[0],  # upsample pixels
            self._n_pix[1],  # upsample pixels
        )
        self._weights = backend.ones(
            (1, 1), dtype=self._grid[0].dtype, device=self._grid[0].device
        )
        if self.quad_level is not None and self.quad_level > 1:
            finegrid_x, finegrid_y, weights = gaussian_quadrature_grid(
                self.pixelscale / self.upsample_factor, *self._grid, self.quad_level
            )
            self._grid = (finegrid_x, finegrid_y)
            self._weights = weights
        else:
            # Add a trailing length-1 "quadrature" axis so both branches share
            # the same layout for gaussian_quadrature_integrator.
            self._grid = (
                backend.unsqueeze(self._grid[0], -1),
                backend.unsqueeze(self._grid[1], -1),
            )

        # FFT size for the PSF convolution. rfft2/irfft2 apply ``s`` to the
        # last two axes, i.e. (rows = y, cols = x), so the y size comes first.
        # next_fast_len picks sizes with small prime factors for fast FFTs.
        self._s = (next_fast_len(self._n_pix[1]), next_fast_len(self._n_pix[0]))

    def _fft2_padded(self, x):
        """
        Compute the 2D real FFT of a tensor, zero-padded to ``self._s``.

        Parameters
        ----------
        x: ArrayLike
            The input tensor to be transformed (last two axes are y, x).

        Returns
        -------
        ArrayLike
            The 2D FFT of the zero-padded input tensor.
        """
        return backend.fft.rfft2(x, self._s)

    def _unpad_fft(self, x):
        """
        Undo the PSF-centre shift and zero-padding of an FFT convolution.

        FFT convolution with a PSF whose array starts at index (0, 0) shifts
        the result by the PSF centre. Rolling back by ``psf_shape[0] // 2``
        rows and ``psf_shape[1] // 2`` columns recentres it, and the zero
        padding added up to ``self._s`` is then cropped off.

        Parameters
        ----------
        x: ArrayLike
            The convolved tensor of shape ``self._s`` (rows = y, cols = x).

        Returns
        -------
        ArrayLike
            The recentred tensor cropped to ``(n_pix_y, n_pix_x)``.
        """
        return backend.roll(
            x,
            (-self._psf_pad[1], -self._psf_pad[0]),  # (rows = y, cols = x)
            dims=(-2, -1),
        )[..., : self._n_pix[1], : self._n_pix[0]]

    def _sample_brightness(self, brightness, x, y, shape, chunk_size):
        """Evaluate ``brightness`` at flat coordinates and integrate each pixel.

        The flat samples are reshaped to ``shape`` (the fine grid layout whose
        last axis holds the quadrature points). They are then reduced with the
        cached quadrature weights.
        """
        mu_fine = backend.vmap(brightness, chunk_size=chunk_size)(x, y).reshape(
            shape
        )
        return gaussian_quadrature_integrator(mu_fine, self._weights)

    @forward
    def __call__(
        self,
        psf: Annotated[ArrayLike, "Param"],
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        source_light: bool = True,
        lens_light: bool = True,
        lens_source: bool = True,
        psf_convolve: bool = True,
        chunk_size: Optional[int] = None,
    ):
        """
        Simulate the image.

        Parameters
        ----------
        source_light: boolean
            when true the source light will be sampled
        lens_light: boolean
            when true the lens light will be sampled
        lens_source: boolean
            when true, the source light model will be lensed by the lens mass distribution
        psf_convolve: boolean
            when true the image will be convolved with the psf (a ``(1, 1)``
            psf always disables convolution)
        chunk_size: int
            when not None, the image will be sampled in chunks of this size. This may help reduce memory usage.

        Returns
        -------
        ArrayLike
            The image at native resolution, shape ``(pixels_y, pixels_x)``.
        """
        # Automatically turn off light for missing objects
        if self.source is None:
            source_light = False
        if self.lens_light is None:
            lens_light = False
        if psf.shape == (1, 1):
            psf_convolve = False

        grid = (self._grid[0] + x0, self._grid[1] + y0)
        fine_shape = grid[0].shape
        x_flat = backend.flatten(grid[0])
        y_flat = backend.flatten(grid[1])

        # Sample the source light
        if source_light:
            if lens_source:
                # Source is lensed by the lens mass distribution
                bx, by = backend.vmap(self.lens.raytrace, chunk_size=chunk_size)(
                    x_flat, y_flat
                )
            else:
                # Source is imaged without lensing
                bx, by = x_flat, y_flat
            mu = self._sample_brightness(
                self.source.brightness, bx, by, fine_shape, chunk_size
            )
        else:
            # Source is not added to the scene
            mu = backend.zeros_like(
                grid[0][..., 0], dtype=grid[0].dtype
            )  # chop off quad dim

        # Sample the lens light
        if lens_light:
            mu = mu + self._sample_brightness(
                self.lens_light.brightness, x_flat, y_flat, fine_shape, chunk_size
            )

        # Convolve the PSF
        if psf_convolve:
            if self.psf_mode == "fft":
                mu_fft = self._fft2_padded(mu)
                psf_fft = self._fft2_padded(psf / backend.sum(psf))
                mu = self._unpad_fft(backend.fft.irfft2(mu_fft * psf_fft, self._s).real)
            elif self.psf_mode == "conv2d":
                # conv2d is a cross-correlation; flipping the kernel makes it a
                # true convolution.
                mu = (
                    backend.conv2d(
                        mu[None, None],
                        backend.to(
                            (backend.flip(psf, (0, 1)) / backend.sum(psf)),
                            dtype=mu.dtype,
                        )[None, None],
                        padding="same",
                    )
                    .squeeze(0)
                    .squeeze(0)
                )

        # Crop the PSF padding (rows = y, cols = x) to the requested field of view
        mu_clipped = mu[
            self._psf_pad[1] : self.pixels_y * self.upsample_factor + self._psf_pad[1],
            self._psf_pad[0] : self.pixels_x * self.upsample_factor + self._psf_pad[0],
        ]
        # Average upsampled pixels back to the native resolution
        mu_native_resolution = (
            backend.avg_pool2d(mu_clipped[None, None], self.upsample_factor)
            .squeeze(0)
            .squeeze(0)
        )
        return mu_native_resolution