from typing import Optional
from warnings import warn
from caskade import forward
from ..backend_obj import backend, ArrayLike
from .base import ThinLens, CosmologyType, NameType, ZType
from ..utils import vmap_reduce
# Public API of this module: only the BatchedPlane class is exported.
__all__ = ("BatchedPlane",)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
class BatchedPlane(ThinLens):
    """
    A class for combining multiple thin lenses into a single lensing plane.

    It is assumed that the lens parameters will have a batch dimension;
    internally this class will vmap over the batch dimension and return the
    combined lensing quantity. This class can only handle a single lens type;
    if you want to combine different lens types, use the `SinglePlane` class.

    Attributes
    ----------
    name: str
        The name of the batched plane lens.
    cosmology: Cosmology
        An instance of the Cosmology class.
    lens: ThinLens
        A ThinLens object that will be vmapped over into a single lensing plane.
    chunk_size: Optional[int]
        Forwarded unchanged as ``chunk_size`` to every ``vmap_reduce`` call.
    """

    def __init__(
        self,
        cosmology: CosmologyType,
        lens: ThinLens,
        name: NameType = None,
        z_l: ZType = None,
        z_s: ZType = None,
        chunk_size: Optional[int] = None,
    ):
        """
        Initialize the BatchedPlane lens model.

        Parameters
        ----------
        cosmology: Cosmology
            An instance of the Cosmology class, passed to the parent
            constructor.
        lens: ThinLens
            The lens model to be vmapped over. Its ``z_l`` and ``z_s`` are
            replaced by the ones of this plane.
        name: str, optional
            The name of this lens, passed to the parent constructor.
        z_l: optional
            The lens redshift, passed to the parent constructor.
        z_s: optional
            The source redshift, passed to the parent constructor.
        chunk_size: int, optional
            Stored on the instance and forwarded to ``vmap_reduce``.
        """
        super().__init__(cosmology, z_l=z_l, name=name, z_s=z_s)
        self.hierarchical_link("lens", lens)

        # The redshifts of the wrapped lens are always tied to this plane; the
        # warning only fires when that replaces a value the user had fixed.
        # stacklevel=2 attributes the warning to the caller of __init__
        # instead of to this module.
        if lens.z_l.static:
            warn(
                f"Lens model {lens.name} has a static lens redshift. This is now overwritten by the BatchedPlane ({self.name}) lens redshift. To prevent this warning, set the lens redshift of the lens model to be dynamic before adding to the system.",
                stacklevel=2,
            )
        self.lens.z_l = self.z_l

        if lens.z_s.static:
            warn(
                f"Lens model {lens.name} has a static source redshift. This is now overwritten by the BatchedPlane ({self.name}) source redshift. To prevent this warning, set the source redshift of the lens model to be dynamic before adding to the system.",
                stacklevel=2,
            )
        self.lens.z_s = self.z_s

        self.chunk_size = chunk_size

    @forward
    def reduced_deflection_angle(
        self,
        x: ArrayLike,
        y: ArrayLike,
        lens_params,
        lens_dims,
    ) -> tuple[ArrayLike, ArrayLike]:
        """
        Calculate the total deflection angle by summing
        the deflection angles of all individual lenses.

        Parameters
        ----------
        x: ArrayLike
            The x-coordinate in the lens plane.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinate in the lens plane.

            *Unit: arcsec*

        lens_params:
            Sequence of batched parameter values for the wrapped lens. It is
            unpacked into the vmapped call and the per-call values are handed
            to the wrapped lens as a single tuple. Presumably injected by the
            ``forward`` machinery rather than by the caller -- TODO confirm.
        lens_dims:
            Iterable of batch-dimension specifiers, appended to ``in_dims``
            after the two ``None`` entries used for ``x`` and ``y``.
            Presumably one entry per element of ``lens_params`` -- TODO confirm.

        Returns
        -------
        x_component: ArrayLike
            The x-component of the deflection angle.

            *Unit: arcsec*

        y_component: ArrayLike
            The y-component of the deflection angle.

            *Unit: arcsec*

        """

        def single_lens(x_, y_, *params):
            # Everything after the coordinates is one lens's parameter set.
            return self.lens.reduced_deflection_angle(x_, y_, params)

        def sum_components(alphas):
            # Sum the x and y components separately over dimension 0.
            return backend.sum(alphas[0], dim=0), backend.sum(alphas[1], dim=0)

        vr_deflection_angle = vmap_reduce(
            single_lens,
            reduce_func=sum_components,
            chunk_size=self.chunk_size,
            # None for x and y, followed by the specifiers for the lens params.
            in_dims=(None, None) + tuple(lens_dims),
            out_dims=(0, 0),
        )
        return vr_deflection_angle(x, y, *lens_params)

    @forward
    def convergence(
        self,
        x: ArrayLike,
        y: ArrayLike,
        lens_params,
        lens_dims,
    ) -> ArrayLike:
        """
        Calculate the total projected mass density by
        summing the mass densities of all individual lenses.

        Parameters
        ----------
        x: ArrayLike
            The x-coordinate in the lens plane.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinate in the lens plane.

            *Unit: arcsec*

        lens_params:
            Sequence of batched parameter values for the wrapped lens. It is
            unpacked into the vmapped call and the per-call values are handed
            to the wrapped lens as a single tuple. Presumably injected by the
            ``forward`` machinery rather than by the caller -- TODO confirm.
        lens_dims:
            Iterable of batch-dimension specifiers, appended to ``in_dims``
            after the two ``None`` entries used for ``x`` and ``y``.
            Presumably one entry per element of ``lens_params`` -- TODO confirm.

        Returns
        -------
        ArrayLike
            The total projected mass density.

            *Unit: unitless*

        """

        def single_lens(x_, y_, *params):
            # Everything after the coordinates is one lens's parameter set.
            return self.lens.convergence(x_, y_, params)

        # No reduce_func/out_dims given: the defaults of vmap_reduce apply
        # (presumably a sum over the batch -- TODO confirm against the helper).
        vr_convergence = vmap_reduce(
            single_lens,
            chunk_size=self.chunk_size,
            in_dims=(None, None) + tuple(lens_dims),
        )
        return vr_convergence(x, y, *lens_params)

    @forward
    def potential(
        self,
        x: ArrayLike,
        y: ArrayLike,
        lens_params,
        lens_dims,
    ) -> ArrayLike:
        """
        Compute the total lensing potential by summing
        the lensing potentials of all individual lenses.

        Parameters
        ----------
        x: ArrayLike
            The x-coordinate in the lens plane.

            *Unit: arcsec*

        y: ArrayLike
            The y-coordinate in the lens plane.

            *Unit: arcsec*

        lens_params:
            Sequence of batched parameter values for the wrapped lens. It is
            unpacked into the vmapped call and the per-call values are handed
            to the wrapped lens as a single tuple. Presumably injected by the
            ``forward`` machinery rather than by the caller -- TODO confirm.
        lens_dims:
            Iterable of batch-dimension specifiers, appended to ``in_dims``
            after the two ``None`` entries used for ``x`` and ``y``.
            Presumably one entry per element of ``lens_params`` -- TODO confirm.

        Returns
        -------
        ArrayLike
            The total lensing potential.

            *Unit: arcsec^2*

        """

        def single_lens(x_, y_, *params):
            # Everything after the coordinates is one lens's parameter set.
            return self.lens.potential(x_, y_, params)

        # No reduce_func/out_dims given: the defaults of vmap_reduce apply
        # (presumably a sum over the batch -- TODO confirm against the helper).
        vr_potential = vmap_reduce(
            single_lens,
            chunk_size=self.chunk_size,
            in_dims=(None, None) + tuple(lens_dims),
        )
        return vr_potential(x, y, *lens_params)