# mypy: disable-error-code="index,dict-item"
from typing import Optional, Annotated, Union
import numpy as np
from caskade import forward, Param
from .base import ThinLens, CosmologyType, NameType, ZType
from ..backend_obj import backend, ArrayLike
from ..utils import interp2d
# Public API of this module.
__all__ = ("PixelatedDeflection",)
class PixelatedDeflection(ThinLens):
    """Thin lens defined by a pixelated map of reduced deflection angles.

    Deflections at arbitrary positions are obtained by interpolating a
    user-provided map laid out on a regular grid centred on ``(x0, y0)``
    with pixel size ``pixelscale``. Only the deflection is available:
    ``potential`` and ``convergence`` raise ``NotImplementedError``.
    """

    # Placeholder values for the dynamic parameters (presumably consumed by
    # the framework when building a default instance -- confirm). The map is
    # 3D with a leading axis of length 2 so that it passes the dimensionality
    # check in ``__init__`` and supports the ``[0]`` / ``[1]`` / ``shape[2]``
    # indexing in ``reduced_deflection_angle``.
    _null_params = {
        "x0": 0.0,
        "y0": 0.0,
        "deflection_map": np.linspace(-0.1, 0.1, 200, dtype=np.float32).reshape(
            2, 10, 10
        ),
    }

    def __init__(
        self,
        pixelscale: Annotated[float, "pixelscale", True],
        cosmology: CosmologyType,
        z_l: ZType = None,
        z_s: ZType = None,
        x0: Annotated[
            Optional[Union[ArrayLike, float]],
            "The x-coordinate of the center of the grid",
            True,
        ] = backend.make_array(0.0),
        y0: Annotated[
            Optional[Union[ArrayLike, float]],
            "The y-coordinate of the center of the grid",
            True,
        ] = backend.make_array(0.0),
        deflection_map: Annotated[
            Optional[ArrayLike],
            "A 3D tensor (2, nx, ny) representing the reduced deflection angle map",
            True,
        ] = None,
        shape: Annotated[
            tuple[Optional[int], ...], "The shape of the deflection map"
        ] = (
            2,
            None,
            None,
        ),
        name: NameType = None,
    ):
        """Strong lensing with a user-provided deflection map.

        This class computes deflection angles by interpolating the
        user-provided deflection map.

        Parameters
        ----------
        pixelscale: float
            Angular size of one pixel of the deflection map.
            *Unit: arcsec/pixel*
        cosmology: Cosmology
            An instance of the cosmological parameters.
        z_l: Optional[ArrayLike]
            The redshift of the lens.
            *Unit: unitless*
        z_s: Optional[ArrayLike]
            The redshift of the source.
            *Unit: unitless*
        x0: Optional[ArrayLike]
            The x-coordinate of the center of the grid.
            *Unit: arcsec*
        y0: Optional[ArrayLike]
            The y-coordinate of the center of the grid.
            *Unit: arcsec*
        deflection_map: Optional[ArrayLike]
            A 3D tensor of shape ``(2, nx, ny)``. Slice 0 is interpolated for
            the first returned deflection component and slice 1 for the
            second. Axis 2 spans the x direction and axis 1 the y direction.
            *Unit: unitless*
        shape: tuple[Optional[int], ...]
            The shape declared for the ``deflection_map`` parameter. It must
            have exactly three entries.
        name: Optional[str]
            The name of the PixelatedDeflection object.

        Raises
        ------
        ValueError
            If ``deflection_map`` is given and is not 3D, or if ``shape`` is
            given and does not have exactly three entries.
        """
        super().__init__(cosmology, z_l, name=name, z_s=z_s)
        if deflection_map is not None and deflection_map.ndim != 3:
            raise ValueError(
                f"deflection_map must be 3D (2, nx, ny). Received a {deflection_map.ndim}D tensor"
            )
        elif shape is not None and len(shape) != 3:
            raise ValueError(
                f"shape must specify a 3D tensor (2, nx, ny). Received shape={shape}"
            )
        # NOTE(review): the parameter framework may order parameters by
        # registration; keep this order (x0, y0, deflection_map, pixelscale)
        # unchanged.
        self.x0 = Param("x0", x0, shape=(), units="arcsec")
        self.y0 = Param("y0", y0, shape=(), units="arcsec")
        self.deflection_map = Param(
            "deflection_map", deflection_map, shape, units="unitless"
        )
        self.pixelscale = Param(
            "pixelscale", pixelscale, shape=(), units="arcsec/pixel"
        )

    @forward
    def reduced_deflection_angle(
        self,
        x: ArrayLike,
        y: ArrayLike,
        x0: Annotated[ArrayLike, "Param"],
        y0: Annotated[ArrayLike, "Param"],
        deflection_map: Annotated[ArrayLike, "Param"],
        pixelscale: Annotated[ArrayLike, "Param"],
    ) -> tuple[ArrayLike, ArrayLike]:
        """
        Compute the reduced deflection angle at the specified positions.

        Parameters
        ----------
        x: ArrayLike
            The x-coordinates of the positions to compute the deflection for.
            *Unit: arcsec*
        y: ArrayLike
            The y-coordinates of the positions to compute the deflection for.
            *Unit: arcsec*

        Returns
        -------
        tuple[ArrayLike, ArrayLike]
            The deflection interpolated from ``deflection_map[0]`` and from
            ``deflection_map[1]`` respectively, each reshaped to ``x.shape``.
            *Unit: unitless*
        """
        # Angular extent of the map: axis 2 spans x, axis 1 spans y.
        fov_x = deflection_map.shape[2] * pixelscale
        fov_y = deflection_map.shape[1] * pixelscale
        shape = x.shape
        # Flatten, then rescale offsets from the grid centre so that half a
        # field of view maps to +/-1 (presumably the normalised coordinate
        # convention interp2d expects -- confirm).
        x = backend.view(x - x0, -1) / fov_x * 2
        y = backend.view(y - y0, -1) / fov_y * 2
        return (
            interp2d(deflection_map[0], x, y).reshape(shape),
            interp2d(deflection_map[1], x, y).reshape(shape),
        )

    @forward
    def potential(self, x, y, **kwargs):
        """Not available for a pixelated deflection map; always raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Potential is not implemented for PixelatedDeflection."
        )

    @forward
    def convergence(self, x, y, **kwargs):
        """Not available for a pixelated deflection map; always raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Convergence is not implemented for PixelatedDeflection."
        )