from typing import Annotated, Literal, Optional
from caskade import Module, forward
from .simulator import NameType
from ..lenses.base import Lens
from ..light.base import Source
from ..backend_obj import backend, ArrayLike
# Names exported by ``from <this module> import *``.
__all__ = ("Microlens",)
[docs]
class Microlens(Module):
    """Computes the total flux from a microlens system within an fov.

    Straightforward simulator to compute the total flux of a lensed image of a
    source object within a given field of view. Constructs the sampling points
    internally based on the user settings.

    Example usage:: python

        import matplotlib.pyplot as plt
        import torch
        import caustics

        cosmo = caustics.FlatLambdaCDM()
        lens = caustics.lenses.SIS(cosmology = cosmo, x0 = 0., y0 = 0., th_ein = 1.)
        source = caustics.sources.Sersic(x0 = 0., y0 = 0., q = 0.5, phi = 0.4, n = 2., Re = 1., Ie = 1.)
        # The constructor takes only (lens, source, name); there is no ``z_s`` argument.
        sim = caustics.sims.Microlens(lens, source)

        fov = torch.tensor([-1., 1., -1., 1.])
        print("Flux and uncertainty: ", sim(fov=fov))

    Attributes
    ----------
    lens: Lens
        caustics lens mass model object
    source: Source
        caustics light object which defines the background source
    name: string (default "sim")
        a name for this simulator in the parameter DAG.
    """  # noqa: E501

    def __init__(
        self,
        lens: Annotated[Lens, "caustics lens mass model object"],
        source: Annotated[
            Source, "caustics light object which defines the background source"
        ],
        name: NameType = "sim",
    ):
        super().__init__(name)
        self.lens = lens
        self.source = source

    @forward
    def __call__(
        self,
        fov: ArrayLike,
        method: Literal["mcmc", "grid"] = "mcmc",
        N_mcmc: int = 10000,
        N_grid: int = 100,
        key: Optional[ArrayLike] = None,
    ) -> tuple[ArrayLike, ArrayLike]:
        """Forward pass of the simulator.

        The image-plane sample points are ray-traced through ``self.lens``, the
        source brightness is evaluated at the traced positions, and the mean
        brightness is multiplied by the area of the field of view.

        Parameters
        ----------
        fov: ArrayLike
            Field of view box of the simulation in arcseconds indexed as (x_min, x_max, y_min, y_max)
        method: str (default "mcmc")
            Method for sampling the image. Choose from "mcmc" or "grid".
            "mcmc" draws ``N_mcmc`` random points with ``backend.rand`` and
            rescales them to the fov box; "grid" evaluates an
            ``N_grid`` x ``N_grid`` mesh built with ``backend.linspace``.
        N_mcmc: int
            Number of sample points for the source sampling if method is "mcmc"
        N_grid: int
            Number of sample points for the sampling grid on each axis if method is "grid"
        key: Optional[ArrayLike]
            A jax.random.key to be used when the Jax backend is used

        Returns
        -------
        ArrayLike
            Total flux from the microlens system within the field of view
        ArrayLike
            Error estimate on the total flux

        Raises
        ------
        ValueError
            If ``method`` is neither "mcmc" nor "grid".
        """
        if method == "mcmc":
            # NOTE: despite the option name, this is plain Monte Carlo
            # integration (independent random draws), not a Markov chain.
            if key is not None:
                # Separate keys so the x and y draws do not reuse one stream.
                key_x, key_y = backend.split_key(key)
            else:
                key_x, key_y = None, None
            sample_x = backend.rand(N_mcmc, key=key_x) * (fov[1] - fov[0]) + fov[0]
            sample_y = backend.rand(N_mcmc, key=key_y) * (fov[3] - fov[2]) + fov[2]
            # Square root of the number of samples (N_mcmc points).
            sqrt_n_samples = N_mcmc**0.5
        elif method == "grid":
            # Regular mesh with N_grid points along each axis.
            x = backend.linspace(fov[0], fov[1], N_grid)
            y = backend.linspace(fov[2], fov[3], N_grid)
            sample_x, sample_y = backend.meshgrid(x, y, indexing="ij")
            # N_grid**2 samples in total, so the square root is N_grid itself.
            sqrt_n_samples = N_grid
        else:
            raise ValueError(f"Invalid method: {method}, choose from 'mcmc' or 'grid'")

        # Shared by both methods: map image-plane samples to the source plane
        # and evaluate the source brightness there.
        bx, by = self.lens.raytrace(sample_x, sample_y)
        mu = self.source.brightness(bx, by)

        # Area of the field of view box.
        A = (fov[1] - fov[0]) * (fov[3] - fov[2])

        # Flux = mean brightness * area; error = sample std * area / sqrt(N).
        return backend.mean(mu) * A, backend.std(mu) * A / sqrt_n_samples