"""Array-backend abstraction for caustics (module ``caustics.backend_obj``)."""

import os
import importlib
from collections import namedtuple
from typing import TypeAlias

import numpy as np
import caskade as ck

# Re-export caskade's ``ArrayLike`` type alias under this module's namespace.
ArrayLike: TypeAlias = ck.ArrayLike


# Shared return type for ``Backend.topk``: ``jax.lax.top_k`` yields a plain
# ``(values, indices)`` tuple, so the JAX path wraps it in this namedtuple to
# give the same ``.values`` / ``.indices`` access that ``torch.topk`` offers.
TopKResult = namedtuple("TopKResult", ("values", "indices"))


class Backend:
    """Thin dispatch layer that runs array code on either PyTorch or JAX.

    Assigning :attr:`backend` imports the chosen library into :attr:`module`
    and binds every name listed in ``_DISPATCHED_OPS`` to the matching
    ``_<name>_torch`` / ``_<name>_jax`` implementation.  Those
    implementations smooth over API differences (``dim`` vs ``axis``,
    immutable JAX arrays, explicit PRNG keys, ...).  The remaining public
    methods forward straight to ``self.module`` because torch and
    ``jax.numpy`` expose them under the same name.
    """

    # Instance attributes re-bound by ``_bind_ops`` whenever the backend changes.
    _DISPATCHED_OPS = (
        "make_array", "_array_type", "concatenate", "copy", "tolist", "view",
        "as_array", "to", "to_numpy", "gammaln", "logit", "sigmoid", "repeat",
        "stack", "transpose", "upsample2d", "pad", "roll", "clamp", "flatten",
        "conv2d", "mean", "std", "sum", "max", "min", "topk", "bessel_j1",
        "lgamma", "hessian", "jacobian", "jacfwd", "grad", "vmap", "long",
        "fill_at_indices", "add_at_indices", "and_at_indices", "unsqueeze",
        "cat", "gradient", "detach", "avg_pool2d", "rand", "randn", "randint",
        "split_key", "meshgrid", "device", "numel", "Size", "chunk", "jit",
        "norm", "arange", "linspace",
    )
    # Dispatched names whose implementation stem differs from the attribute
    # name, e.g. ``Size`` is implemented by ``_size_torch`` / ``_size_jax``.
    _OP_STEM_ALIASES = {"_array_type": "array_type", "Size": "size"}

    def __init__(self, backend=None):
        """Create a backend.

        Parameters
        ----------
        backend : str or None
            ``"torch"`` or ``"jax"``.  ``None`` reads the ``CASKADE_BACKEND``
            environment variable, falling back to caskade's current backend.
        """
        self.backend = backend

    @property
    def backend(self):
        """Name of the active backend (``"torch"`` or ``"jax"``)."""
        return self._backend

    @backend.setter
    def backend(self, backend):
        if backend is None:
            backend = os.getenv("CASKADE_BACKEND", ck.backend.backend)
        # Keep caskade's global backend in sync with this object.
        ck.backend.backend = backend
        self._load_backend(backend)
        self._backend = backend

    def _load_backend(self, backend):
        """Import the array module for ``backend`` and bind its operations.

        Raises
        ------
        ValueError
            If ``backend`` is neither ``"torch"`` nor ``"jax"``.
        """
        if backend == "torch":
            self.module = importlib.import_module("torch")
            self.setup_torch()
        elif backend == "jax":
            self.module = importlib.import_module("jax.numpy")
            self.setup_jax()
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def _bind_ops(self, suffix):
        """Bind each name in ``_DISPATCHED_OPS`` to ``self._<stem>_<suffix>``."""
        for name in self._DISPATCHED_OPS:
            stem = self._OP_STEM_ALIASES.get(name, name)
            setattr(self, name, getattr(self, f"_{stem}_{suffix}"))

    def setup_torch(self):
        """Bind every dispatched operation to its ``_*_torch`` implementation."""
        self._bind_ops("torch")
        # Public alias of ``torch._C._LinAlgError`` (the same class object).
        self.LinAlgErr = self.module.linalg.LinAlgError

    def setup_jax(self):
        """Bind every dispatched operation to its ``_*_jax`` implementation."""
        self.jax = importlib.import_module("jax")
        # JAX is 32-bit by default; enable float64/int64 support.
        self.jax.config.update("jax_enable_x64", True)
        self._bind_ops("jax")
        # No JAX-specific linear-algebra exception is used; fall back to the
        # base class so ``except backend.LinAlgErr`` still works.
        self.LinAlgErr = Exception
        # Random initial PRNG state; rand/randn/randint advance it when no
        # explicit key is passed.
        self.key = self.jax.random.key(np.random.randint(0, 2**31 - 1))

    @property
    def array_type(self):
        """Array class of the active backend (``torch.Tensor`` / ``jax.numpy.ndarray``)."""
        return self._array_type()

    # ----------------------------------------------------------------- creation

    def _make_array_torch(self, array, dtype=None, device=None):
        return self.module.tensor(array, dtype=dtype, device=device)

    def _make_array_jax(self, array, dtype=None, **kwargs):
        # Extra kwargs (e.g. ``device``) are accepted for signature parity and ignored.
        return self.module.array(array, dtype=dtype)

    def _array_type_torch(self):
        return self.module.Tensor

    def _array_type_jax(self):
        return self.module.ndarray

    def _as_array_torch(self, array, dtype=None, device=None):
        return self.module.as_tensor(array, dtype=dtype, device=device)

    def _as_array_jax(self, array, dtype=None, **kwargs):
        return self.module.asarray(array, dtype=dtype)

    def _arange_torch(self, *args, dtype=None, device=None):
        return self.module.arange(*args, dtype=dtype, device=device)

    def _arange_jax(self, *args, dtype=None, device=None):
        return self.module.arange(*args, dtype=dtype)

    def _linspace_torch(self, start, end, steps, dtype=None, device=None):
        return self.module.linspace(start, end, steps, dtype=dtype, device=device)

    def _linspace_jax(self, start, end, steps, dtype=None, device=None):
        return self.module.linspace(start, end, steps, dtype=dtype)

    def _size_torch(self, shape):
        return self.module.Size(shape)

    def _size_jax(self, shape):
        return tuple(shape)

    # ------------------------------------------------------ copying / transfer

    def _copy_torch(self, array):
        return array.detach().clone()

    def _copy_jax(self, array):
        return self.module.copy(array)

    def _tolist_torch(self, array):
        return array.detach().cpu().tolist()

    def _tolist_jax(self, array):
        return array.block_until_ready().tolist()

    def _to_torch(self, array, dtype=None, device=None):
        return array.to(dtype=dtype, device=device)

    def _to_jax(self, array, dtype=None, device=None):
        if dtype is not None:
            array = array.astype(dtype)
        return self.jax.device_put(array, device=device)

    def _to_numpy_torch(self, array):
        return array.detach().cpu().numpy()

    def _to_numpy_jax(self, array):
        return np.array(array.block_until_ready())

    def _detach_torch(self, array):
        return array.detach()

    def _detach_jax(self, array):
        return self.jax.lax.stop_gradient(array)

    def _device_torch(self, array):
        return array.device

    def _device_jax(self, array):
        # JAX tracers (inside jit/vmap) have no device. Returning None lets
        # JAX use the default/current device context.
        return None

    def _numel_torch(self, array):
        return array.numel()

    def _numel_jax(self, array):
        return array.size

    # ------------------------------------------------------------------ shaping

    def _concatenate_torch(self, arrays, dim=0):
        return self.module.cat(arrays, dim=dim)

    def _concatenate_jax(self, arrays, dim=0):
        return self.module.concatenate(arrays, axis=dim)

    def _cat_torch(self, array, dim=0):
        return self.module.cat(array, dim=dim)

    def _cat_jax(self, array, dim=0):
        return self.module.concatenate(array, axis=dim)

    def _stack_torch(self, arrays, dim=0):
        return self.module.stack(arrays, dim=dim)

    def _stack_jax(self, arrays, dim=0):
        return self.module.stack(arrays, axis=dim)

    def _view_torch(self, array, shape):
        return array.reshape(shape)

    def _view_jax(self, array, shape):
        return array.reshape(shape)

    def _repeat_torch(self, a, repeats, axis=None):
        return self.module.repeat_interleave(a, repeats, dim=axis)

    def _repeat_jax(self, a, repeats, axis=None):
        return self.module.repeat(a, repeats, axis=axis)

    def _transpose_torch(self, array, *args):
        return self.module.transpose(array, *args)

    def _transpose_jax(self, array, *args):
        # torch.transpose swaps exactly two axes, i.e. numpy's swapaxes.
        return self.module.swapaxes(array, *args)

    def _unsqueeze_torch(self, array, dim):
        return self.module.unsqueeze(array, dim)

    def _unsqueeze_jax(self, array, dim):
        return self.module.expand_dims(array, axis=dim)

    def _flatten_torch(self, array, start_dim=0, end_dim=-1):
        return array.flatten(start_dim, end_dim)

    def _flatten_jax(self, array, start_dim=0, end_dim=-1):
        """Merge axes ``start_dim..end_dim`` (inclusive), mirroring ``torch.flatten``.

        Raises
        ------
        ValueError
            If ``start_dim`` resolves to an axis after ``end_dim``.
        """
        shape = tuple(array.shape)
        if not shape:
            # torch.flatten turns a 0-d tensor into shape (1,).
            return self.module.reshape(array, (1,))
        ndim = len(shape)
        start = start_dim % ndim
        stop = end_dim % ndim + 1
        if start >= stop:
            raise ValueError("flatten: start_dim cannot come after end_dim")
        # Explicit size instead of -1 so zero-sized axes stay unambiguous.
        merged = int(np.prod(shape[start:stop], dtype=np.int64))
        return self.module.reshape(array, shape[:start] + (merged,) + shape[stop:])

    def _chunk_torch(self, array, chunks, dim=0):
        return self.module.chunk(array, chunks, dim=dim)

    def _chunk_jax(self, array, chunks, dim=0):
        """Split like ``torch.chunk``: pieces of ``ceil(n / chunks)``, last one smaller.

        May return fewer than ``chunks`` pieces, exactly as torch does.
        """
        size = array.shape[dim]
        step = -(-size // chunks)  # ceil division
        if step == 0:
            # Empty axis: torch returns ``chunks`` empty pieces.
            return self.module.array_split(array, chunks, axis=dim)
        return self.module.split(array, list(range(step, size, step)), axis=dim)

    def _meshgrid_torch(self, *arrays, indexing="ij"):
        return self.module.meshgrid(*arrays, indexing=indexing)

    def _meshgrid_jax(self, *arrays, indexing="ij"):
        # Accept a single list/tuple of arrays like torch.meshgrid does.
        if len(arrays) == 1 and isinstance(arrays[0], (list, tuple)):
            arrays = arrays[0]
        return self.module.meshgrid(*arrays, indexing=indexing)

    def _pad_torch(self, array, padding, mode="constant"):
        return self.module.nn.functional.pad(array, padding, mode=mode)

    def _pad_jax(self, array, padding, mode="constant"):
        """Pad using torch's flat ``(left, right, ...)`` spec, last axis first."""
        # Translate torch mode names to their numpy/jax equivalents.
        if mode == "replicate":
            mode = "edge"
        elif mode == "circular":
            mode = "wrap"
        ndim = array.ndim
        pad_width = [(0, 0)] * ndim
        for i in range(0, len(padding), 2):
            dim_idx = -(i // 2 + 1)  # pair k applies to axis -(k+1)
            if abs(dim_idx) <= ndim:
                pad_width[dim_idx] = (padding[i], padding[i + 1])
        return self.module.pad(array, pad_width, mode=mode)

    def _roll_torch(self, array, shifts, dims):
        return self.module.roll(array, shifts, dims=dims)

    def _roll_jax(self, array, shifts, dims):
        return self.module.roll(array, shifts, axis=dims)

    # ------------------------------------------------------ elementwise / math

    def _gammaln_torch(self, array):
        return self.module.special.gammaln(array)

    def _gammaln_jax(self, array):
        return self.jax.scipy.special.gammaln(array)

    def _lgamma_torch(self, array):
        return self.module.lgamma(array)

    def _lgamma_jax(self, array):
        return self.jax.lax.lgamma(array)

    def _sigmoid_torch(self, array):
        return self.module.sigmoid(array)

    def _sigmoid_jax(self, array):
        return self.jax.nn.sigmoid(array)

    def _logit_torch(self, array):
        return self.module.logit(array)

    def _logit_jax(self, array):
        return self.jax.scipy.special.logit(array)

    def _bessel_j1_torch(self, array):
        return self.module.special.bessel_j1(array)

    def _bessel_j1_jax(self, array):
        # bessel_jn returns J_0..J_v stacked on axis 0; keep only J_1.
        return self.jax.scipy.special.bessel_jn(array, v=1)[-1]

    def _clamp_torch(self, array, min=None, max=None):
        return self.module.clamp(array, min, max)

    def _clamp_jax(self, array, min=None, max=None):
        return self.module.clip(array, min, max)

    def _long_torch(self, array):
        return array.long()

    def _long_jax(self, array):
        return self.module.astype(array, self.module.int64)

    def _norm_torch(self, array, dim=None, **kwargs):
        return self.linalg.norm(array, dim=dim, **kwargs)

    def _norm_jax(self, array, dim=None, **kwargs):
        return self.linalg.norm(array, axis=dim, **kwargs)

    def _gradient_torch(self, array, spacing=1, dim=None):
        return self.module.gradient(array, spacing=spacing, dim=dim, edge_order=1)

    def _gradient_jax(self, array, spacing=1, dim=None):
        # spacing is positional in jax and edge_order is not implemented there.
        return self.module.gradient(array, spacing, axis=dim, edge_order=None)

    # --------------------------------------------------------------- reductions

    def _mean_torch(self, array, dim=None):
        return self.module.mean(array, dim=dim)

    def _mean_jax(self, array, dim=None):
        return self.module.mean(array, axis=dim)

    def _std_torch(self, array, dim=None):
        # correction=0 gives the population std, matching numpy/jax's ddof=0.
        return self.module.std(array, dim=dim, correction=0)

    def _std_jax(self, array, dim=None):
        return self.module.std(array, axis=dim)

    def _sum_torch(self, array, dim=None):
        return self.module.sum(array, dim=dim)

    def _sum_jax(self, array, dim=None):
        return self.module.sum(array, axis=dim)

    # Not bound by setup_torch/setup_jax; kept for backward compatibility.
    def _cumprod_torch(self, array, dim=None):
        return self.module.cumprod(array, dim=dim)

    def _cumprod_jax(self, array, dim=None):
        return self.module.cumprod(array, axis=dim)

    def _max_torch(self, array, dim=None):
        # torch.max(x, dim=...) returns (values, indices); keep only values.
        if dim is None:
            return self.module.max(array)
        return self.module.max(array, dim=dim).values

    def _max_jax(self, array, dim=None):
        return self.module.max(array, axis=dim)

    def _min_torch(self, array, dim=None):
        if dim is None:
            return self.module.min(array)
        return self.module.min(array, dim=dim).values

    def _min_jax(self, array, dim=None):
        return self.module.min(array, axis=dim)

    def _topk_torch(self, array, k):
        return self.module.topk(array, k=k)

    def _topk_jax(self, array, k):
        values, indices = self.jax.lax.top_k(array, k=k)
        return TopKResult(values=values, indices=indices)

    # ------------------------------------------------------- image operations

    def _upsample2d_torch(self, array, scale_factor, method):
        """Upsample the last two axes, then divide by ``scale_factor**2``.

        The division keeps the total sum unchanged for nearest-neighbour
        upsampling by an integer factor.
        """
        U = self.module.nn.Upsample(scale_factor=scale_factor, mode=method)
        return U(array) / scale_factor**2

    def _upsample2d_jax(self, array, scale_factor, method):
        """JAX counterpart of ``_upsample2d_torch`` (same normalisation)."""
        new_shape = list(array.shape)
        # torch's Upsample floors the output size for non-integer factors.
        new_shape[-2] = int(array.shape[-2] * scale_factor)
        new_shape[-1] = int(array.shape[-1] * scale_factor)
        # jax.image.resize supports "nearest" and "bilinear" directly.
        resized = self.jax.image.resize(array, new_shape, method=method)
        return resized / scale_factor**2

    def _conv2d_torch(self, array, kernel, padding, stride=1):
        return self.module.nn.functional.conv2d(
            array,
            kernel,
            padding=padding,
            stride=stride,
        )

    def _conv2d_jax(self, array, kernel, padding, stride=1):
        """2D convolution accepting torch-style ``padding`` / ``stride``.

        ``lax.conv_general_dilated`` defaults to NCHW inputs and OIHW
        kernels. Like torch it computes a cross-correlation, so the kernel
        is not flipped.
        """
        if isinstance(stride, (int, np.integer)):
            stride = (stride, stride)
        if isinstance(padding, (int, np.integer)):
            padding = ((padding, padding), (padding, padding))
        elif isinstance(padding, (tuple, list)) and all(
            isinstance(p, (int, np.integer)) for p in padding
        ):
            # torch's per-axis (pad_h, pad_w) -> lax's ((lo, hi), (lo, hi)).
            padding = tuple((p, p) for p in padding)
        # String padding ("same"/"valid") and explicit pairs pass through.
        return self.jax.lax.conv_general_dilated(
            array, kernel, window_strides=tuple(stride), padding=padding
        )

    def _avg_pool2d_torch(self, array, kernel, stride=None, padding=0):
        return self.module.nn.functional.avg_pool2d(
            array,
            kernel,
            padding=padding,
            stride=stride,
        )

    def _avg_pool2d_jax(self, array, kernel, stride=None, padding=0):
        """Average pool over the last two axes of an NCHW array.

        Zero padding is counted in the average, matching torch's default
        ``count_include_pad=True``.
        """
        if stride is None:
            stride = kernel
        pads = ((0, 0), (0, 0), (padding, padding), (padding, padding))
        summed = self.jax.lax.reduce_window(
            array,
            init_value=0,
            computation=self.jax.lax.add,
            window_dimensions=(1, 1, kernel, kernel),
            window_strides=(1, 1, stride, stride),
            padding=pads,
        )
        return summed / kernel**2

    # ---------------------------------------------------------- autodiff / jit

    def _grad_torch(self, func, argnums=0):
        return self.module.func.grad(func, argnums=argnums)

    def _grad_jax(self, func, argnums=0):
        return self.jax.grad(func, argnums=argnums)

    def _jacobian_torch(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        return self.module.autograd.functional.jacobian(
            func, x, strategy=strategy, vectorize=vectorize, create_graph=create_graph
        )

    def _jacobian_jax(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        # vectorize/create_graph exist only for torch parity and are ignored.
        if "forward" in strategy:
            return self.jax.jacfwd(func)(x)
        return self.jax.jacrev(func)(x)

    def _jacfwd_torch(self, func, argnums=0, randomness="error"):
        return self.module.func.jacfwd(func, argnums=argnums, randomness=randomness)

    def _jacfwd_jax(self, func, argnums=0, randomness="error"):
        return self.jax.jacfwd(func, argnums=argnums)

    def _hessian_torch(self, func, argnums=0):
        return self.module.func.hessian(func, argnums=argnums)

    def _hessian_jax(self, func, argnums=0):
        return self.jax.hessian(func, argnums=argnums)

    def _vmap_torch(
        self, func, in_dims=0, out_dims=0, randomness="error", chunk_size=None
    ):
        return self.module.vmap(
            func,
            in_dims=in_dims,
            out_dims=out_dims,
            randomness=randomness,
            chunk_size=chunk_size,
        )

    def _vmap_jax(
        self, func, in_dims=0, out_dims=0, randomness="error", chunk_size=None
    ):
        # randomness/chunk_size are torch-only and ignored here.
        return self.jax.vmap(func, in_axes=in_dims, out_axes=out_dims)

    def _jit_torch(self, func, **kwargs):
        return func

    def _jit_jax(self, func, **kwargs):
        return self.jax.jit(func, **kwargs)

    # ------------------------------------------------------ indexed updates

    def _fill_at_indices_torch(self, array, indices, values):
        # In-place on torch: the input tensor itself is modified and returned.
        array[indices] = values
        return array

    def _fill_at_indices_jax(self, array, indices, values):
        return array.at[indices].set(values)

    def _add_at_indices_torch(self, array, indices, values):
        array[indices] += values
        return array

    def _add_at_indices_jax(self, array, indices, values):
        return array.at[indices].add(values)

    def _and_at_indices_torch(self, array, indices, values):
        array[indices] &= values
        return array

    def _and_at_indices_jax(self, array, indices, values):
        return array.at[indices].set(array[indices] & values)

    # ----------------------------------------------------------------- random

    @staticmethod
    def _shape_from_size(size):
        """Accept both ``f(2, 3)`` and ``f((2, 3))`` size conventions."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return size[0]
        return size

    def _rand_torch(self, *size, key=None):
        # ``key`` is accepted for signature parity; torch uses its global RNG.
        return self.module.rand(*size)

    def _rand_jax(self, *size, key=None):
        if key is None:
            # Advance the stored PRNG state so successive calls differ.
            self.key, key = self.jax.random.split(self.key)
        return self.jax.random.uniform(key, shape=self._shape_from_size(size))

    def _randn_torch(self, *size, dtype=None, device=None, key=None):
        return self.module.randn(*size, dtype=dtype, device=device)

    def _randn_jax(self, *size, dtype=None, device=None, key=None):
        if key is None:
            self.key, key = self.jax.random.split(self.key)
        return self.jax.random.normal(
            key, shape=self._shape_from_size(size), dtype=dtype
        )

    def _randint_torch(self, high, size, low=0, dtype=None, device=None, key=None):
        return self.module.randint(
            low=low, high=high, size=size, dtype=dtype, device=device
        )

    def _randint_jax(self, high, size, low=0, dtype=None, device=None, key=None):
        if key is None:
            self.key, key = self.jax.random.split(self.key)
        if dtype is None:
            # None would resolve to a float dtype, which jax.random.randint
            # rejects; int is int64 here because setup_jax enables x64.
            dtype = int
        return self.jax.random.randint(
            key, minval=low, maxval=high, shape=size, dtype=dtype
        )

    def _split_key_torch(self, key):
        raise NotImplementedError(
            "`split_key` should not be used with the `torch` backend."
        )

    def _split_key_jax(self, key):
        return self.jax.random.split(key)

    # ------------------------------------------- direct module pass-throughs

    def searchsorted(self, array, value):
        """Indices at which ``value`` would be inserted to keep ``array`` sorted."""
        return self.module.searchsorted(array, value)

    def any(self, array):
        """True if any element is truthy."""
        return self.module.any(array)

    def all(self, array):
        """True if every element is truthy."""
        return self.module.all(array)

    def log(self, array):
        """Elementwise natural logarithm."""
        return self.module.log(array)

    def safe_log(self, array):
        """Forward to ``module.safe_log``.

        NOTE(review): neither torch nor ``jax.numpy`` appears to define
        ``safe_log``, so this likely raises AttributeError — confirm the
        intended semantics before relying on it.
        """
        return self.module.safe_log(array)

    def log10(self, array):
        """Elementwise base-10 logarithm."""
        return self.module.log10(array)

    def exp(self, array):
        """Elementwise exponential."""
        return self.module.exp(array)

    def sin(self, array):
        """Elementwise sine."""
        return self.module.sin(array)

    def cos(self, array):
        """Elementwise cosine."""
        return self.module.cos(array)

    def cosh(self, array):
        """Elementwise hyperbolic cosine."""
        return self.module.cosh(array)

    def sqrt(self, array):
        """Elementwise square root."""
        return self.module.sqrt(array)

    def abs(self, array):
        """Elementwise absolute value."""
        return self.module.abs(array)

    def conj(self, array):
        """Elementwise complex conjugate."""
        return self.module.conj(array)

    def nan_to_num(self, array, posinf=None, neginf=None):
        """Replace NaN/inf with finite values (``posinf``/``neginf`` override infinities)."""
        return self.module.nan_to_num(array, posinf=posinf, neginf=neginf)

    def floor(self, array):
        """Elementwise floor."""
        return self.module.floor(array)

    def atleast_1d(self, array):
        """View ``array`` with at least one dimension."""
        return self.module.atleast_1d(array)

    def tanh(self, array):
        """Elementwise hyperbolic tangent."""
        return self.module.tanh(array)

    def arctan(self, array):
        """Elementwise arctangent."""
        return self.module.arctan(array)

    def atan(self, array):
        """Elementwise arctangent (array-API spelling)."""
        return self.module.atan(array)

    def atanh(self, array):
        """Elementwise inverse hyperbolic tangent."""
        return self.module.atanh(array)

    def arctan2(self, y, x):
        """Elementwise four-quadrant arctangent of ``y / x``."""
        return self.module.arctan2(y, x)

    def arcsin(self, array):
        """Elementwise arcsine."""
        return self.module.arcsin(array)

    def arcsinh(self, array):
        """Elementwise inverse hyperbolic sine."""
        return self.module.arcsinh(array)

    def arccos(self, array):
        """Elementwise arccosine."""
        return self.module.arccos(array)

    def arccosh(self, array):
        """Elementwise inverse hyperbolic cosine."""
        return self.module.arccosh(array)

    def round(self, array):
        """Elementwise rounding."""
        return self.module.round(array)

    def zeros(self, shape, dtype=None, device=None):
        """Array of zeros."""
        return self.module.zeros(shape, dtype=dtype, device=device)

    def zeros_like(self, array, dtype=None):
        """Zeros with the shape of ``array``."""
        return self.module.zeros_like(array, dtype=dtype)

    def ones(self, shape, dtype=None, device=None):
        """Array of ones."""
        return self.module.ones(shape, dtype=dtype, device=device)

    def ones_like(self, array, dtype=None):
        """Ones with the shape of ``array``."""
        return self.module.ones_like(array, dtype=dtype)

    def empty(self, shape, dtype=None, device=None):
        """Uninitialised array (zero-filled on JAX)."""
        return self.module.empty(shape, dtype=dtype, device=device)

    def eye(self, n, dtype=None, device=None):
        """``n x n`` identity matrix."""
        return self.module.eye(n, dtype=dtype, device=device)

    def diag(self, array):
        """Diagonal extraction / construction."""
        return self.module.diag(array)

    def outer(self, a, b):
        """Outer product."""
        return self.module.outer(a, b)

    def minimum(self, a, b):
        """Elementwise minimum."""
        return self.module.minimum(a, b)

    def maximum(self, a, b):
        """Elementwise maximum."""
        return self.module.maximum(a, b)

    def isnan(self, array):
        """Elementwise NaN test."""
        return self.module.isnan(array)

    def isfinite(self, array):
        """Elementwise finiteness test."""
        return self.module.isfinite(array)

    def where(self, condition, x, y):
        """Select from ``x`` where ``condition`` holds, else from ``y``."""
        return self.module.where(condition, x, y)

    def allclose(self, a, b, rtol=1e-5, atol=1e-8):
        """True if ``a`` and ``b`` agree within tolerance."""
        return self.module.allclose(a, b, rtol=rtol, atol=atol)

    def tile(self, array, dims):
        """Repeat ``array`` according to ``dims``."""
        return self.module.tile(array, dims)

    def flip(self, array, dims):
        """Reverse ``array`` along ``dims``."""
        return self.module.flip(array, dims)

    def is_finite(self, array):
        """Alias of :meth:`isfinite`."""
        return self.module.isfinite(array)

    def prod(self, array, dim=None):
        """Product of all elements, or along ``dim``."""
        if dim is None:
            return self.module.prod(array)
        return self.module.prod(array, dim)

    def cumprod(self, array, dim=None):
        """Cumulative product along ``dim``; ``None`` flattens first (numpy semantics)."""
        if dim is None:
            # torch.cumprod requires a dim, so flatten explicitly for both backends.
            array = self.module.reshape(array, (-1,))
            dim = 0
        return self.module.cumprod(array, dim)

    def einsum(self, equation, *operands):
        """Einstein summation."""
        return self.module.einsum(equation, *operands)

    def argmax(self, array, dim=None):
        """Index of the maximum, over everything or along ``dim``."""
        return self.module.argmax(array, dim)

    def dot(self, *arrays):
        """Dot product."""
        return self.module.dot(*arrays)

    # ---------------------------------------------------- namespaces / constants

    @property
    def linalg(self):
        """The backend's ``linalg`` namespace."""
        return self.module.linalg

    @property
    def fft(self):
        """The backend's ``fft`` namespace."""
        return self.module.fft

    @property
    def inf(self):
        return self.module.inf

    @property
    def bool(self):
        return self.module.bool

    @property
    def int32(self):
        return self.module.int32

    @property
    def float32(self):
        return self.module.float32

    @property
    def float64(self):
        return self.module.float64

    @property
    def pi(self):
        return self.module.pi

    @property
    def nan(self):
        return self.module.nan
# Module-level shared instance. With no explicit argument, the backend comes
# from the CASKADE_BACKEND environment variable or caskade's current default.
backend = Backend(backend=None)