# Source code for caustics.sims.microlens

from typing import Annotated, Literal, Optional

from caskade import Module, forward

from .simulator import NameType
from ..lenses.base import Lens
from ..light.base import Source
from ..backend_obj import backend, ArrayLike

__all__ = ("Microlens",)


class Microlens(Module):
    """Compute the total flux from a microlens system within a field of view.

    A straightforward simulator that computes the total flux of a lensed image
    of a source object within a rectangular field of view (fov). Image-plane
    sample points are constructed internally according to the user settings,
    ray traced to the source plane with ``lens.raytrace`` and evaluated with
    ``source.brightness``; the flux is the mean sampled brightness times the
    fov area.

    Example usage::

        import torch
        import caustics

        cosmo = caustics.FlatLambdaCDM()
        lens = caustics.lenses.SIS(
            cosmology=cosmo, z_s=1.0, x0=0.0, y0=0.0, th_ein=1.0
        )
        source = caustics.light.Sersic(
            x0=0.0, y0=0.0, q=0.5, phi=0.4, n=2.0, Re=1.0, Ie=1.0
        )
        sim = caustics.sims.Microlens(lens, source)

        fov = torch.tensor([-1.0, 1.0, -1.0, 1.0])
        print("Flux and uncertainty: ", sim(fov=fov))

    Attributes
    ----------
    lens: Lens
        caustics lens mass model object
    source: Source
        caustics light object which defines the background source
    name: string (default "sim")
        a name for this simulator in the parameter DAG.
    """  # noqa: E501

    def __init__(
        self,
        lens: Annotated[Lens, "caustics lens mass model object"],
        source: Annotated[
            Source, "caustics light object which defines the background source"
        ],
        name: NameType = "sim",
    ):
        super().__init__(name)
        self.lens = lens
        self.source = source

    @forward
    def __call__(
        self,
        fov: ArrayLike,
        method: Literal["mcmc", "grid"] = "mcmc",
        N_mcmc: int = 10000,
        N_grid: int = 100,
        key: Optional[ArrayLike] = None,
    ):
        """Forward pass of the simulator.

        Parameters
        ----------
        fov: ArrayLike
            Field of view box of the simulation in arcseconds indexed as
            (x_min, x_max, y_min, y_max)
        method: str (default "mcmc")
            Method for sampling the image. Choose from "mcmc" or "grid".
            Despite its name, "mcmc" draws ``N_mcmc`` independent, uniformly
            distributed points in the fov (plain Monte Carlo, no Markov
            chain). "grid" evaluates the centres of an ``N_grid x N_grid``
            grid of equal cells covering the fov (midpoint rule).
        N_mcmc: int
            Number of sample points if method is "mcmc". Must be at least 1.
        N_grid: int
            Number of sample points for the sampling grid on each axis if
            method is "grid". Must be at least 1.
        key: Optional[ArrayLike]
            A jax.random.key to be used when the Jax backend is used

        Returns
        -------
        ArrayLike
            Total flux from the microlens system within the field of view
        ArrayLike
            Error estimate on the total flux: std of the sampled brightness
            times the fov area, divided by the square root of the number of
            samples. For "grid" the samples are deterministic, so this is
            only a heuristic.

        Raises
        ------
        ValueError
            If ``method`` is not "mcmc" or "grid", or if the sample count for
            the chosen method is smaller than 1.
        """
        x_min, x_max, y_min, y_max = fov[0], fov[1], fov[2], fov[3]
        area = (x_max - x_min) * (y_max - y_min)

        if method == "mcmc":
            if N_mcmc < 1:
                raise ValueError(f"N_mcmc must be at least 1, got {N_mcmc}")
            # Split the key so the x and y draws use independent streams;
            # without a key, backend.rand is called with key=None.
            if key is not None:
                key_x, key_y = backend.split_key(key)
            else:
                key_x, key_y = None, None
            sample_x = backend.rand(N_mcmc, key=key_x) * (x_max - x_min) + x_min
            sample_y = backend.rand(N_mcmc, key=key_y) * (y_max - y_min) + y_min
            sqrt_n = N_mcmc**0.5
        elif method == "grid":
            if N_grid < 1:
                raise ValueError(f"N_grid must be at least 1, got {N_grid}")
            # Sample cell centres, not the fov edges: with linspace endpoints
            # every sample is weighted by area / N_grid**2 although the edge
            # rows/columns lie on the boundary, biasing the flux by
            # O(1/N_grid). The midpoint rule reduces this to O(1/N_grid**2).
            half_dx = (x_max - x_min) / (2 * N_grid)
            half_dy = (y_max - y_min) / (2 * N_grid)
            x = backend.linspace(x_min + half_dx, x_max - half_dx, N_grid)
            y = backend.linspace(y_min + half_dy, y_max - half_dy, N_grid)
            sample_x, sample_y = backend.meshgrid(x, y, indexing="ij")
            # N_grid**2 samples in total, so sqrt(number of samples) == N_grid.
            sqrt_n = N_grid
        else:
            raise ValueError(f"Invalid method: {method}, choose from 'mcmc' or 'grid'")

        # Ray trace image-plane samples to the source plane and evaluate the
        # source brightness there.
        bx, by = self.lens.raytrace(sample_x, sample_y)
        mu = self.source.brightness(bx, by)
        return backend.mean(mu) * area, backend.std(mu) * area / sqrt_n