%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import vmap
from torch.func import jacfwd

Let’s take a deeper dive into Caustics!#

In this introduction, we showcase some of the features and design principles of caustics. We will cover:

  1. How to get started from one of our pre-built Simulators

  2. Visualization of the Simulator graph (DAG of caustics modules)

  3. Distinction between Static and Dynamic parameters

  4. How to create a batch of simulations

  5. Semantic structure of the Simulator input

  6. Taking gradients w.r.t. parameters with PyTorch autodiff functionalities

  7. Swapping in flexible modules like the Pixelated representation for more advanced usage

  8. How to create your own Simulator

Getting started with the LensSource Simulator#

For this first introduction, we use the simplest modules in caustics for the lens and source, namely the SIE and the Sersic modules. We also assume a FlatLambdaCDM cosmology.

from caustics import LensSource, SIE, Sersic, FlatLambdaCDM

# Define parameters of the camera pixel grid
pixelscale = 0.04  # arcsec/pixel
pixels = 100

# Instantiate modules for the simulator
cosmo = FlatLambdaCDM(name="cosmology")
lens = SIE(
    cosmology=cosmo, name="lens", z_s=1, z_l=0.5, x0=0, y0=0, q=0.9, phi=0.4, Rein=1
)
source = Sersic(name="source", x0=0, y0=0, q=0.5, phi=0.9, n=1, Re=0.1, Ie=10)
simulator = LensSource(lens, source, pixelscale=pixelscale, pixels_x=pixels)

Generating a simulation of a strong gravitational lens#

source_params = source.get_values()
lens_params = lens.get_values()
params = simulator.get_values()
# Generate a lensed image
y = simulator(params=params)

fig, axs = plt.subplots(1, 3, figsize=(15, 4))

# A meshgrid to show the source
x = torch.linspace(-0.5, 0.5, 100)
X, Y = torch.meshgrid(x, x, indexing="xy")

ax = axs[0]
ax.set_title(r"Sérsic source")
source_im = source.brightness(X, Y, params=source_params)
ax.imshow(source_im, origin="lower", extent=(-0.5, 0.5, -0.5, 0.5), cmap="gray")
ax.set_ylabel(r"$\beta_y$ ['']")
ax.set_xlabel(r"$\beta_x$ ['']")

ax = axs[1]
ax.set_title(r"SIE mass distribution")
lens_im = lens.convergence(X * 2, Y * 2, params=lens_params)
ax.imshow(lens_im, origin="lower", extent=(-1, 1, -1, 1), cmap="hot")
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")

ax = axs[2]
ax.set_title(r"Lensed image")
ax.imshow(y, origin="lower", extent=(-1, 1, -1, 1), cmap="gray")
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")
../_images/32a09a25c70d848d62369fd818876676591b4f1225356f56e3e56e041f54bef4.png

Visualization of the Simulator DAG#

simulator.graphviz()
../_images/47be82097bf589dae091a61059abe2782a913a97ed3ac440a877b2977adb1414.svg

Static vs Dynamic parameters#

In the DAG shown above,

  • Dynamic parameters are shown in white boxes

  • Static parameters are shown in grey boxes

The distinction between the two types can be summarized as follows

  • Dynamic parameters are fed as input to the simulator and can be batched over (data parallelism)

  • Static parameters have fixed values. Their values are stored in the internal DAG and are broadcast when batching computations
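
To make the distinction concrete, here is a minimal pure-Python sketch. `ToyParam` and `toy_simulator` are hypothetical names invented for this illustration and are not the caustics implementation. The sketch only shows the idea that dynamic values are supplied at call time, in declaration order, while static values are stored once and reused for every element of a batch.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter node; not the caustics Param class."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value  # None means "dynamic": must be supplied at call time

    @property
    def dynamic(self):
        return self.value is None


def toy_simulator(params, dynamic_values):
    """Fill in dynamic parameters, in declaration order, from dynamic_values."""
    supplied = iter(dynamic_values)
    resolved = {}
    for p in params:
        resolved[p.name] = next(supplied) if p.dynamic else p.value
    # Toy "simulation": a brightness scaled by the Einstein radius
    return resolved["Ie"] * resolved["Rein"]


params = [ToyParam("Rein"), ToyParam("Ie", value=10.0)]  # Rein dynamic, Ie static

# Batch over the dynamic parameter; the static value is reused for every element
batch = [0.5, 1.0, 1.5]
images = [toy_simulator(params, [rein]) for rein in batch]
print(images)  # [5.0, 10.0, 15.0]
```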

# Making a parameter dynamic
simulator.lens.z_s.to_dynamic()
simulator.graphviz()  # z_s turns white, which makes it disappear when we don't show the dynamic parameters (first option False)
../_images/68c2a2fb70b459888024c59b9c12ecfcfd431aae2ecf6c2afc90b42aac23a9ab.svg
# Making a parameter static
simulator.lens.z_s.to_static()
simulator.graphviz()  # z_s turns grey
../_images/47be82097bf589dae091a61059abe2782a913a97ed3ac440a877b2977adb1414.svg

Simulating a batch of observations#

We apply vmap to the simulator to simulate a batch of observations. In this example, we create a batch of examples that differ only by their Einstein radius. To do this, we make all the other parameters static. This is done in the cell below.

# All parameters static except the Einstein radius
simulator.lens.Rein.to_dynamic()
# Create a grid of Einstein radii
b = torch.linspace(0.5, 1.5, 5).view(-1, 1)  # Shape is [B, 1]
ys = vmap(simulator)(b)

fig, axs = plt.subplots(1, 5, figsize=(20, 4))

for i, ax in enumerate(axs.flatten()):
    ax.axis("off")
    ax.imshow(ys[i], cmap="gray")
    ax.set_title(f"$Rein = {b[i].item():.2f}$")
plt.subplots_adjust(wspace=0, hspace=0)
../_images/356c72a3c5f7ef06b19f88ddf390ede615d99f153b4b934b077921cafa16fbf0.png
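
As a rough illustration of what vmap does here, the sketch below replaces the caustics simulator with a hypothetical `toy_simulator` that maps a single parameter of shape [1] to a small image. Under that assumption, applying vmap over the leading dimension gives the same result as calling the function once per batch element and stacking the outputs.

```python
import torch
from torch import vmap

# Toy stand-in for a simulator: maps one parameter of shape [1] to a [P, P] image
P = 4
grid = torch.linspace(-1.0, 1.0, P)
X, Y = torch.meshgrid(grid, grid, indexing="xy")


def toy_simulator(rein):
    # A ring-like profile whose radius is set by the input parameter
    return torch.exp(-((torch.sqrt(X**2 + Y**2) - rein) ** 2))


b = torch.linspace(0.5, 1.5, 5).view(-1, 1)  # shape [B, 1]
ys_vmap = vmap(toy_simulator)(b)  # shape [B, P, P]
# Same result, computed with one call per batch element
ys_loop = torch.stack([toy_simulator(b_i) for b_i in b])

print(ys_vmap.shape)
```

The trailing dimension of size 1 in `b` is the parameter dimension; vmap only maps over the leading batch dimension.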

Semantic structure of the input#

The simulator accepts its input in several formats to support different use cases:

  1. Flattened tensor, for deep neural networks as in Hezaveh et al. (2017)

  2. Semantic list, to separate the input in terms of high-level modules like Lens and Source

  3. Low-level dictionary, to decompose the parameters at the level of the leaves of the DAG

Below, we illustrate how to use all of these structures. For completeness, we also use vmap.

# Make some parameters dynamic for this example
simulator.source.Ie.to_dynamic()
simulator.lens.Rein.to_dynamic()

Flattened Tensor#

To make sure the order of the parameters is correct, print the simulator. The order of the dynamic parameters is read top to bottom.

print(simulator)
sim|LensSource
    psf|static: 1
    x0|static: 0
    y0|static: 0
    lens|SIE
        cosmology|FlatLambdaCDM
            h0|static: 0.677
            critical_density_0|static: 1.27e+11
            Om0|static: 0.31
        z_s|static: 1
        z_l|static: 0.5
        x0|static: 0
        y0|static: 0
        q|static: 0.9
        phi|static: 0.4
        Rein|dynamic: 1
    source|Sersic
        x0|static: 0
        y0|static: 0
        q|static: 0.5
        phi|static: 0.9
        n|static: 1
        Re|static: 0.1
        Ie|dynamic: 10
B = 5  # Batch dimension
Rein = torch.rand(B, 1)
Ie = torch.rand(B, 1)
x = torch.concat([Rein, Ie], dim=1)  # Concat along the feature dimension

# Now we can use vmap to simulate multiple images at once
ys = vmap(simulator)(x)

Semantic lists#

A semantic list is simply a list over module parameters like the one we used earlier: [lens_params, source_params]. Note that we could also include cosmological parameters in that list.

# Make some parameters dynamic for this example
simulator.source.Ie.to_dynamic()
simulator.lens.Rein.to_dynamic()
simulator.lens.x0.to_dynamic()
simulator.lens.cosmology.h0.to_dynamic()

print(simulator)
sim|LensSource
    psf|static: 1
    x0|static: 0
    y0|static: 0
    lens|SIE
        cosmology|FlatLambdaCDM
            h0|dynamic: 0.677
            critical_density_0|static: 1.27e+11
            Om0|static: 0.31
        z_s|static: 1
        z_l|static: 0.5
        x0|dynamic: 0
        y0|static: 0
        q|static: 0.9
        phi|static: 0.4
        Rein|dynamic: 1
    source|Sersic
        x0|static: 0
        y0|static: 0
        q|static: 0.5
        phi|static: 0.9
        n|static: 1
        Re|static: 0.1
        Ie|dynamic: 10
B = 5
cosmo_param = torch.rand(B, 1)  # h0
lens_x0_param = torch.randn(B, 1)  # x0
lens_Rein_param = torch.randn(B, 1)  # Rein
source_param = torch.rand(B, 1)  # Ie

x = torch.cat([cosmo_param, lens_x0_param, lens_Rein_param, source_param], dim=-1)
ys = vmap(simulator)(x)

Low-level Dictionary#

Build the dictionary so that its nesting mirrors the structure of the graph.

B = 5
x0 = torch.randn(B, 1)
Rein = torch.randn(B, 1)
Ie = torch.rand(B, 1)
h0 = torch.rand(B, 1)
x = {
    "lens": {
        "x0": x0,
        "Rein": Rein,
        "cosmology": {
            "h0": h0,
        },
    },
    "source": {
        "Ie": Ie,
    },
}
ys = vmap(simulator)(x)
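
The relationship between the dictionary and the flattened ordering can be sketched in plain Python. `GRAPH_ORDER`, `dict_to_flat` and `flat_to_dict` are hypothetical names written for this illustration; they are not caustics functions, and the order is copied by hand from the `print(simulator)` output above rather than read from the graph.

```python
# Hypothetical description of the graph order (top to bottom in print(simulator))
GRAPH_ORDER = [
    ("lens", "cosmology", "h0"),
    ("lens", "x0"),
    ("lens", "Rein"),
    ("source", "Ie"),
]


def dict_to_flat(nested):
    """Read the dynamic values out of a nested dict in graph order."""
    flat = []
    for path in GRAPH_ORDER:
        node = nested
        for key in path:
            node = node[key]
        flat.append(node)
    return flat


def flat_to_dict(flat):
    """Inverse of dict_to_flat: rebuild the nested dict from an ordered list."""
    nested = {}
    for path, value in zip(GRAPH_ORDER, flat):
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested


x_dict = {
    "lens": {"x0": 0.0, "Rein": 1.0, "cosmology": {"h0": 0.7}},
    "source": {"Ie": 10.0},
}
x_flat = dict_to_flat(x_dict)
print(x_flat)  # [0.7, 0.0, 1.0, 10.0]
```

Values in the dictionary are looked up by name, so the key order inside it does not matter. In the flattened form, position is the only thing that identifies a parameter.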

Computing gradients with automatic differentiation#

Computing gradients is particularly useful for optimization. Since taking gradients w.r.t. list or dictionary inputs is not possible with torch.func.grad, we need a small wrapper around the simulator. For optimization, the wrapper will often be a log-likelihood function. For now, we use a generic lambda wrapper.

In the case of the semantic list input, the wrapper has the general form

lambda *x: simulator(x)

The low-level dictionary input is a bit more involved but can be worked out on a case-by-case basis.

Note: apply vmap around the gradient function (e.g. jacfwd or grad) to handle batched computation.

# Choose some sensible values to compute the gradient
cosmo_param = torch.tensor([0.7])  # h0
lens_x0_param = torch.tensor([0.0])  # x0
lens_Rein_param = torch.tensor([1.0])  # Rein
source_param = torch.tensor([10.0])  # Ie

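jacfwd returns a tuple of 4 tensors, one per argument listed in argnums, each of shape [pixels, pixels, D], where D is the number of elements in that argument (here D = 1). Wrapping the call in vmap adds a leading batch dimension B.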

jac = jacfwd(lambda *x: simulator(x), argnums=(0, 1, 2, 3))(
    cosmo_param, lens_x0_param, lens_Rein_param, source_param
)
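
The output structure of jacfwd can be checked on a small example. The sketch below uses a hypothetical `toy_simulator` (a Gaussian blob, not the caustics simulator) with two arguments of shape [1], and shows the tuple of Jacobians as well as the extra batch dimension that appears when jacfwd is wrapped in vmap.

```python
import torch
from torch import vmap
from torch.func import jacfwd

P = 3
grid = torch.linspace(-1.0, 1.0, P)
X, Y = torch.meshgrid(grid, grid, indexing="xy")


def toy_simulator(x0, amplitude):
    # Toy image: a Gaussian blob centred at x0 and scaled by amplitude
    return amplitude * torch.exp(-((X - x0) ** 2 + Y**2))


x0 = torch.tensor([0.0])
amplitude = torch.tensor([2.0])

# One Jacobian per argument listed in argnums, each of shape [P, P, 1]
jac = jacfwd(toy_simulator, argnums=(0, 1))(x0, amplitude)
print(len(jac), jac[0].shape, jac[1].shape)

# Wrapping jacfwd in vmap adds a leading batch dimension: [B, P, P, 1]
B = 5
x0_batch = torch.zeros(B, 1)
amplitude_batch = torch.ones(B, 1)
jac_batch = vmap(jacfwd(toy_simulator, argnums=(0, 1)))(x0_batch, amplitude_batch)
print(jac_batch[0].shape)
```

Because the toy image is linear in `amplitude`, its Jacobian w.r.t. that argument is just the unscaled blob, which is a convenient sanity check.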

fig, axs = plt.subplots(1, 4, figsize=(20, 4))

# Titles follow the order of the arguments passed to jacfwd: h0, x0, Rein, Ie
titles = [
    r"$\nabla_{h_0} f(\mathbf{x})$",
    r"$\nabla_{x_0} f(\mathbf{x})$",
    r"$\nabla_{Rein} f(\mathbf{x})$",
    r"$\nabla_{I_e} f(\mathbf{x})$",
]
jacs = torch.concat(jac, dim=-1)
for i, ax in enumerate(axs.flatten()):
    ax.axis("off")
    ax.imshow(jacs[..., i], cmap="seismic", vmin=-10, vmax=10)
    ax.set_title(titles[i], fontsize=18)
../_images/959197b4fc5f7e151afffaa988afe6f7a4bc3c36d919432e8cd0c842465c023b.png

Pixelated representations#

The examples above used very simple modules. Here, we showcase how easily we can swap in flexible representations to model more realistic systems.

  • Pixelated is the module used to represent the background source with a grid of pixels

  • PixelatedConvergence is the module used to represent the convergence of the lens with a grid of pixels

For this example, we will use source samples from the PROBES dataset (Stone et al., 2019) and convergence maps sampled from Illustris TNG (Nelson et al., 2019, see Adam et al., 2023 for preprocessing, or use this link to download the maps).

from caustics import Pixelated, PixelatedConvergence

# Some static parameters for the simulator
pixelscale = 0.07
source_pixelscale = 0.25 * pixelscale
z_l = 0.5
z_s = 1.0
x0 = 0
y0 = 0

# Construct the Simulator with Pixelated and PixelatedConvergence modules
cosmo = FlatLambdaCDM(name="cosmo")
source = Pixelated(
    name="source", shape=(256, 256), pixelscale=source_pixelscale, x0=x0, y0=y0
)
lens = PixelatedConvergence(
    cosmology=cosmo,
    name="lens",
    pixelscale=pixelscale,
    shape=(128, 128),
    z_l=z_l,
    z_s=z_s,
)
simulator = LensSource(lens, source, pixelscale=pixelscale, pixels_x=pixels)

simulator.graphviz()
../_images/bc8b620549bda94abd956a0b2ae1e9e5876c924ce60d9c3611e3fde86d40c186.svg

In the cell below, we load the maps from saved assets. If you downloaded the datasets mentioned above, you can use the commented code to load maps from them instead.

# import h5py

# B = 10
# path_to_kappa_maps = "/path/to/hkappa128hst_TNG100_rau_trainset.h5"  # modify this to your system path
# index = [250] + sorted(list(np.random.randint(251, 1000, size=B-1)))
# kappa_map = torch.tensor(h5py.File(path_to_kappa_maps, "r")["kappa"][index])

# path_to_source_maps = "/path/to/probes.h5"  # modify this to your system path
# index = [101] + sorted(list(np.random.randint(251, 1000, size=B-1)))
# filter_ = 0  # grz filters: 0 is g, etc.
# source_map = torch.tensor(
#     h5py.File(path_to_source_maps, "r")["galaxies"][index, ..., filter_]
# )

# Load saved assets for demonstration
kappa_maps = torch.tensor(
    np.load("assets/kappa_maps.npz", allow_pickle=True)["kappa_maps"]
)
source_maps = torch.tensor(
    np.load("assets/source_maps.npz", allow_pickle=True)["source_maps"]
)

# Cherry picked example
source_map = source_maps[0]
kappa_map = kappa_maps[0]

Make a simulation by feeding the maps as input to the simulator (using semantic list inputs)

y = simulator([kappa_map, source_map])

fig, axs = plt.subplots(1, 3, figsize=(15, 4))

beta_extent = [
    -source_pixelscale * source_map.shape[0] / 2,
    source_pixelscale * source_map.shape[0] / 2,
] * 2

ax = axs[0]
ax.set_title(r"Source map")
ax.imshow(source_map, origin="lower", cmap="gray", extent=beta_extent)
ax.set_ylabel(r"$\beta_y$ ['']")
ax.set_xlabel(r"$\beta_x$ ['']")

theta_extent = [-pixelscale * pixels / 2, pixelscale * pixels / 2] * 2

ax = axs[1]
ax.set_title(r"Convergence map")
ax.imshow(
    kappa_map,
    origin="lower",
    cmap="hot",
    extent=theta_extent,
    norm=plt.cm.colors.LogNorm(vmin=1e-1, vmax=10),
)
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")
ax.set_title(r"Convergence map")

ax = axs[2]
ax.set_title(r"Lensed image")
ax.imshow(y, origin="lower", extent=theta_extent, cmap="gray")
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")
ax.set_title(r"Lensed image")
../_images/a12f1d574d3477524b6bb46923852b38bd9736fd87c0b0abfbd25ab2e7b704c9.png

Batching works the same way as before and is fast. Below, we time a batch of 4 simulations on a laptop.

%%timeit

ys = vmap(simulator)([kappa_maps, source_maps])
13.2 ms ± 293 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

fig, axs = plt.subplots(3, 3, figsize=(9, 9))

ys = vmap(simulator)([kappa_maps, source_maps])
for i in range(3):
    ax = axs[i, 0]
    ax.axis("off")
    ax.imshow(
        source_maps[len(ys) - 1 - i],
        origin="lower",
        cmap="gray",
        norm=plt.cm.colors.LogNorm(vmin=1e-2, vmax=1, clip=True),
    )

    ax = axs[i, 1]
    ax.axis("off")
    ax.imshow(
        kappa_maps[len(ys) - 1 - i],
        origin="lower",
        cmap="hot",
        norm=plt.cm.colors.LogNorm(vmin=1e-1, vmax=10),
    )

    ax = axs[i, 2]
    ax.axis("off")
    ax.imshow(
        ys[len(ys) - 1 - i],
        origin="lower",
        cmap="gray",
        norm=plt.cm.colors.LogNorm(vmin=1e-2, vmax=1, clip=True),
    )
axs[0, 0].set_title(r"Source map")
axs[0, 1].set_title(r"Convergence map")
axs[0, 2].set_title(r"Lensed image")
plt.subplots_adjust(wspace=0, hspace=0)
../_images/d5b05d04cb1108a3b712a53de23693337f4d6ec3ed05b95d0c97582efe369931.png

Creating your own Simulator#

Here, we only introduce the general design principles to create a simulator. More comprehensive explanations can be found in the caskade docs.

A Simulator is very much like a neural network in PyTorch#

A simulator inherits from the super class caskade.Module, similar to how a neural network inherits from the nn.Module class in PyTorch.

from caustics import Module, Param, forward

class MySim(Module):
    def __init__(self):
        super().__init__()
        self.p = Param("p")

    @forward
    def myfunction(self, x, p):
        ...
  • The __init__ method constructs the computation graph, initializes the caustics modules, and can prepare or store variables for the forward method.

  • The forward method is where the actual simulation happens.

  • x generally denotes a set of parameters which affect the computations in the simulator graph.

How to use a Simulator in your workflow#

Like a neural network, MySim (and in general any caustics module) must be instantiated outside the main workload. This is because caustics builds a graph internally every time a module is created. Ideally, this happens only once to avoid overhead. In general, you can use the following code pattern:


# Instantiation
simulator = MySim()

# Heavy workload
for n in range(N):
    y = vmap(simulator)(x)

This allows you to perform expensive computations that only need to happen once in the __init__ method while keeping your forward method lightweight.
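
The pattern can be sketched without caustics. `ToySim` below is a hypothetical class that does not inherit from Module. It builds a pixel grid once in `__init__` and reuses it in every call. The pixel-centre convention used for the grid is an assumption made for this illustration.

```python
class ToySim:
    """Hypothetical simulator-like class; it does not use caustics."""

    def __init__(self, pixels, pixelscale):
        # One-off setup: pixel-centre coordinates of a square grid centred on zero
        # (a common convention, assumed here for illustration)
        half = (pixels - 1) / 2
        axis = [(i - half) * pixelscale for i in range(pixels)]
        self.xgrid = [[x for x in axis] for _ in axis]
        self.ygrid = [[y for _ in axis] for y in axis]

    def forward(self, amplitude):
        # Lightweight call: reuses the grid built in __init__
        return [
            [amplitude / (1.0 + x * x + y * y) for x, y in zip(row_x, row_y)]
            for row_x, row_y in zip(self.xgrid, self.ygrid)
        ]


simulator = ToySim(pixels=5, pixelscale=0.04)  # instantiate once
images = [simulator.forward(a) for a in (1.0, 2.0, 3.0)]  # call many times
print(images[1][2][2])  # central pixel of the second image: 2.0
```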

How to feed parameters to the different modules#

This is probably the easiest part of building a Simulator: you only provide the values when calling the top-level simulator.

Here is a minimal example that shows how to feed the parameters to the forward method:

@forward
def raytrace(self, x, y):
    alpha_x, alpha_y = self.lens.reduced_deflection_angle(x, y)
    beta_x = x - alpha_x  # lens equation
    beta_y = y - alpha_y
    return self.source.brightness(beta_x, beta_y)


# Parameters are passed only at the top-level call
sim.raytrace(xgrid, ygrid, params)

You might worry that params can have a relatively complex structure (flattened tensor, semantic list, low-level dictionary). caustics handles this complexity for you. You only need to make sure that params contains all the dynamic parameters required by your custom simulator. This design works for every caustics module and each of their methods, meaning that params is always the last argument in a caustics method call signature.

The only details that you need to handle explicitly in your own simulator are things like the camera pixel positions (xgrid and ygrid) and source redshifts (z_s). Those are often constructed in the __init__ method because they can be assumed fixed. Thus, the example above assumes that they can be retrieved from the self registry. A Simulator is often an abstraction of an instrument with many fixed variables to describe it, or aimed at a specific observation.

Of course, you could have more complex workflows for which this assumption is not true. For example, you might want to infer the PSF parameters of your instrument and need to feed them to the simulator as dynamic parameters. The next section has what you need to fully customize your simulator.
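
Returning to the ray-tracing example in this section, the lens equation can be worked through with plain numbers. The sketch below swaps the SIE for a singular isothermal sphere (SIS), whose reduced deflection has constant magnitude equal to the Einstein radius and points radially away from the lens centre. The function names and the Gaussian source are hypothetical choices for this illustration and do not use caustics.

```python
import math


def sis_deflection(x, y, rein, x0=0.0, y0=0.0):
    """Reduced deflection angle of a singular isothermal sphere (SIS).

    Undefined at the lens centre itself (r = 0).
    """
    dx, dy = x - x0, y - y0
    r = math.hypot(dx, dy)
    return rein * dx / r, rein * dy / r


def gaussian_source(bx, by, sigma=0.1):
    """Toy source brightness, peaked at the origin of the source plane."""
    return math.exp(-(bx * bx + by * by) / (2 * sigma * sigma))


def raytrace(x, y, rein):
    alpha_x, alpha_y = sis_deflection(x, y, rein)
    beta_x = x - alpha_x  # lens equation
    beta_y = y - alpha_y
    return gaussian_source(beta_x, beta_y)


# A pixel on the Einstein ring maps back to the source centre, so it is bright
print(raytrace(1.0, 0.0, rein=1.0))  # 1.0
# A pixel well outside the ring maps far from the source, so it is dark
print(raytrace(2.0, 0.0, rein=1.0))
```

This is why a compact source placed directly behind the lens centre appears as a ring of radius Rein in the image plane.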

Creating your own variables as leaves in the DAG#

You can register new variables in the DAG for custom calculations as follows

from caustics import Module, forward, Param


class MySim(Module):
    def __init__(self):
        super().__init__()  # Don't forget to use super!!
        # shape has to be a tuple, e.g. shape=(1,). This can be any shape you need.
        self.my_dynamic_arg = Param(
            "my_dynamic_arg", value=None, shape=(1,)
        )  # register a dynamic parameter in the DAG
        self.my_static_arg = Param(
            "my_static_arg", value=1.0, shape=()
        )  # register a static parameter in the DAG

    @forward
    def forward(self, x, my_dynamic_arg, my_static_arg):

        # My very complex workflow
        ...
        return my_dynamic_arg * x + my_static_arg


sim = MySim()
sim.graphviz()
../_images/2dd77f65ae66d5a751c0a146014cf75a1f90c5223ae53896f5875ef94636d47c.svg