import os
import importlib
from collections import namedtuple
from typing import TypeAlias
import numpy as np
import caskade as ck
# Alias of caskade's ``ArrayLike`` type, declared as an explicit TypeAlias
# (presumably for use in type annotations elsewhere in the package).
ArrayLike: TypeAlias = ck.ArrayLike
# Named tuple wrapping the JAX ``lax.top_k`` output so callers can read
# ``.values`` / ``.indices`` exactly as they do on the result of ``torch.topk``.
TopKResult = namedtuple("TopKResult", "values indices")
class Backend:
    """Backend-agnostic array API over PyTorch or JAX.

    Assigning :attr:`backend` (``"torch"`` or ``"jax"``) does three things:
    - it imports the library into ``self.module`` (``torch`` or ``jax.numpy``);
    - it sets caskade's global backend to the same name;
    - it binds every name listed in :attr:`_DISPATCHED` to the matching
      ``_<name>_torch`` / ``_<name>_jax`` implementation.

    Methods defined directly on the class forward to the function of the same
    name on ``self.module``.
    """

    # Public attribute names bound per backend: ``self.<name>`` is set to
    # ``self._<name>_torch`` or ``self._<name>_jax`` by ``_bind_implementations``.
    # Keeping a single list guarantees both backends expose the same names.
    _DISPATCHED = (
        "make_array",
        "concatenate",
        "copy",
        "tolist",
        "view",
        "as_array",
        "to",
        "to_numpy",
        "gammaln",
        "logit",
        "sigmoid",
        "repeat",
        "stack",
        "transpose",
        "upsample2d",
        "pad",
        "roll",
        "clamp",
        "flatten",
        "conv2d",
        "mean",
        "std",
        "sum",
        "max",
        "min",
        "topk",
        "bessel_j1",
        "lgamma",
        "hessian",
        "jacobian",
        "jacfwd",
        "grad",
        "vmap",
        "long",
        "fill_at_indices",
        "add_at_indices",
        "and_at_indices",
        "unsqueeze",
        "cat",
        "gradient",
        "detach",
        "avg_pool2d",
        "rand",
        "randn",
        "randint",
        "split_key",
        "meshgrid",
        "device",
        "numel",
        "chunk",
        "jit",
        "norm",
        "arange",
        "linspace",
    )

    # Integer types accepted where a torch-style "int or pair" argument
    # (kernel size, stride, padding) must be expanded for JAX.
    _INT_TYPES = (int, np.integer)

    def __init__(self, backend=None):
        self.backend = backend

    @property
    def backend(self):
        """Name of the active backend (``"torch"`` or ``"jax"``)."""
        return self._backend

    @backend.setter
    def backend(self, backend):
        """Select and load a backend.

        ``None`` resolves to the ``CASKADE_BACKEND`` environment variable,
        falling back to caskade's current backend.

        Raises:
            ValueError: if the name is not ``"torch"`` or ``"jax"``. The name is
                validated before caskade's global backend is modified, so a bad
                name leaves caskade and this object unchanged.
        """
        if backend is None:
            backend = os.getenv("CASKADE_BACKEND", ck.backend.backend)
        if backend not in ("torch", "jax"):
            raise ValueError(f"Unsupported backend: {backend}")
        ck.backend.backend = backend
        self._load_backend(backend)
        self._backend = backend

    def _load_backend(self, backend):
        """Import the backend library into ``self.module`` and bind its implementations."""
        if backend == "torch":
            self.module = importlib.import_module("torch")
            self.setup_torch()
        elif backend == "jax":
            self.module = importlib.import_module("jax.numpy")
            self.setup_jax()
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def _bind_implementations(self, suffix):
        """Point every dispatched public name at its ``_<name>_<suffix>`` method."""
        for name in self._DISPATCHED:
            setattr(self, name, getattr(self, f"_{name}_{suffix}"))
        # These two do not follow the ``_<name>_<suffix>`` naming scheme.
        self._array_type = getattr(self, f"_array_type_{suffix}")
        self.Size = getattr(self, f"_size_{suffix}")

    def setup_torch(self):
        """Bind the PyTorch implementations (``self.module`` is ``torch`` here)."""
        self._bind_implementations("torch")
        # Public alias of the private ``torch._C._LinAlgError``.
        self.LinAlgErr = self.module.linalg.LinAlgError

    def setup_jax(self):
        """Bind the JAX implementations (``self.module`` is ``jax.numpy`` here).

        This also does two things beyond binding:
        - it enables 64-bit mode globally (``jax_enable_x64``);
        - it seeds the stored PRNG key that ``rand``/``randn``/``randint`` use
          when no explicit key is passed.
        """
        self.jax = importlib.import_module("jax")
        self.jax.config.update("jax_enable_x64", True)
        self._bind_implementations("jax")
        # Broad fallback: no JAX-specific linear-algebra error type is used.
        self.LinAlgErr = Exception
        self.key = self.jax.random.key(
            np.random.randint(0, 2**31 - 1)
        )  # random initial state

    @property
    def array_type(self):
        """Array class of the active backend (``torch.Tensor`` / ``jax.numpy.ndarray``)."""
        return self._array_type()

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _pair(value):
        """Expand an int to ``(value, value)``; convert any other sequence to a tuple."""
        if isinstance(value, Backend._INT_TYPES):
            return (int(value), int(value))
        return tuple(value)

    @staticmethod
    def _shape_from_args(size):
        """Accept both ``f(2, 3)`` and ``f((2, 3))`` call styles."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return size[0]
        return size

    def _next_key(self, key):
        """Return ``key``, or split a fresh subkey off the stored state when it is None."""
        if key is None:
            self.key, key = self.jax.random.split(self.key)  # update key
        return key

    # ------------------------------------------------------------------
    # Backend-specific implementations (bound by ``_bind_implementations``)
    # ------------------------------------------------------------------

    def _make_array_torch(self, array, dtype=None, device=None):
        return self.module.tensor(array, dtype=dtype, device=device)

    def _make_array_jax(self, array, dtype=None, **kwargs):
        return self.module.array(array, dtype=dtype)

    def _array_type_torch(self):
        return self.module.Tensor

    def _array_type_jax(self):
        return self.module.ndarray

    def _concatenate_torch(self, arrays, dim=0):
        return self.module.cat(arrays, dim=dim)

    def _concatenate_jax(self, arrays, dim=0):
        return self.module.concatenate(arrays, axis=dim)

    def _copy_torch(self, array):
        return array.detach().clone()

    def _copy_jax(self, array):
        return self.module.copy(array)

    def _tolist_torch(self, array):
        return array.detach().cpu().tolist()

    def _tolist_jax(self, array):
        return array.block_until_ready().tolist()

    def _view_torch(self, array, shape):
        return array.reshape(shape)

    def _view_jax(self, array, shape):
        return array.reshape(shape)

    def _as_array_torch(self, array, dtype=None, device=None):
        return self.module.as_tensor(array, dtype=dtype, device=device)

    def _as_array_jax(self, array, dtype=None, **kwargs):
        return self.module.asarray(array, dtype=dtype)

    def _to_torch(self, array, dtype=None, device=None):
        return array.to(dtype=dtype, device=device)

    def _to_jax(self, array, dtype=None, device=None):
        if dtype is not None:
            array = array.astype(dtype)
        return self.jax.device_put(array, device=device)

    def _to_numpy_torch(self, array):
        return array.detach().cpu().numpy()

    def _to_numpy_jax(self, array):
        return np.array(array.block_until_ready())

    def _repeat_torch(self, a, repeats, axis=None):
        return self.module.repeat_interleave(a, repeats, dim=axis)

    def _repeat_jax(self, a, repeats, axis=None):
        return self.module.repeat(a, repeats, axis=axis)

    def _stack_torch(self, arrays, dim=0):
        return self.module.stack(arrays, dim=dim)

    def _stack_jax(self, arrays, dim=0):
        return self.module.stack(arrays, axis=dim)

    def _transpose_torch(self, array, *args):
        return self.module.transpose(array, *args)

    def _transpose_jax(self, array, *args):
        # torch.transpose swaps exactly two axes, which is jnp.swapaxes.
        return self.module.swapaxes(array, *args)

    def _gammaln_torch(self, array):
        return self.module.special.gammaln(array)

    def _gammaln_jax(self, array):
        return self.jax.scipy.special.gammaln(array)

    def _sigmoid_torch(self, array):
        return self.module.sigmoid(array)

    def _sigmoid_jax(self, array):
        return self.jax.nn.sigmoid(array)

    def _logit_torch(self, array):
        return self.module.logit(array)

    def _logit_jax(self, array):
        return self.jax.scipy.special.logit(array)

    def _upsample2d_torch(self, array, scale_factor, method):
        """Upsample the last two axes by ``scale_factor`` and divide by ``scale_factor**2``."""
        U = self.module.nn.Upsample(scale_factor=scale_factor, mode=method)
        array = U(array) / scale_factor**2
        return array

    def _upsample2d_jax(self, array, scale_factor, method):
        """JAX twin of ``_upsample2d_torch``: resize the last two axes, divide by ``scale_factor**2``.

        ``jax.image.resize`` supports ``"nearest"``, so the requested method is
        passed through unchanged. The output size is truncated to an int, so
        non-integer factors work too.
        """
        new_shape = list(array.shape)
        new_shape[-2] = int(array.shape[-2] * scale_factor)
        new_shape[-1] = int(array.shape[-1] * scale_factor)
        resized = self.jax.image.resize(array, tuple(new_shape), method=method)
        # Same normalisation as the torch path so both backends agree.
        return resized / scale_factor**2

    def _pad_torch(self, array, padding, mode="constant"):
        return self.module.nn.functional.pad(array, padding, mode=mode)

    def _pad_jax(self, array, padding, mode="constant"):
        """Emulate ``torch.nn.functional.pad`` with ``jnp.pad``.

        Torch mode names are translated to their numpy equivalents. ``padding``
        lists (before, after) pairs starting from the LAST axis.
        """
        if mode == "replicate":
            mode = "edge"
        elif mode == "circular":
            mode = "wrap"
        ndim = array.ndim
        pad_width = [(0, 0)] * ndim
        for i in range(0, len(padding), 2):
            dim_idx = -(i // 2 + 1)
            if abs(dim_idx) <= ndim:
                pad_width[dim_idx] = (padding[i], padding[i + 1])
        return self.module.pad(array, pad_width, mode=mode)

    def _roll_torch(self, array, shifts, dims):
        return self.module.roll(array, shifts, dims=dims)

    def _roll_jax(self, array, shifts, dims):
        return self.module.roll(array, shifts, axis=dims)

    def _clamp_torch(self, array, min=None, max=None):
        return self.module.clamp(array, min, max)

    def _clamp_jax(self, array, min=None, max=None):
        return self.module.clip(array, min, max)

    def _long_torch(self, array):
        return array.long()

    def _long_jax(self, array):
        return self.module.astype(array, self.module.int64)

    def _conv2d_torch(self, array, kernel, padding, stride=1):
        return self.module.nn.functional.conv2d(
            array,
            kernel,
            padding=padding,
            stride=stride,
        )

    def _conv2d_jax(self, array, kernel, padding, stride=1):
        """JAX twin of ``torch.nn.functional.conv2d`` (default NCHW / OIHW layout).

        ``padding`` may take any of these forms:
        - a string (``"same"`` / ``"valid"``, any case);
        - a torch-style int or ``(pad_h, pad_w)``;
        - an already lax-style sequence of ``(low, high)`` pairs.

        ``stride`` may be an int or a pair. No kernel flip is needed:
        ``lax.conv_general_dilated`` is a cross-correlation, like torch.
        """
        if isinstance(padding, str):
            pad = padding.upper()
        elif isinstance(padding, self._INT_TYPES) or all(
            isinstance(p, self._INT_TYPES) for p in padding
        ):
            pad_h, pad_w = self._pair(padding)
            pad = ((pad_h, pad_h), (pad_w, pad_w))
        else:
            pad = padding
        return self.jax.lax.conv_general_dilated(
            array, kernel, window_strides=self._pair(stride), padding=pad
        )

    def _mean_torch(self, array, dim=None):
        return self.module.mean(array, dim=dim)

    def _mean_jax(self, array, dim=None):
        return self.module.mean(array, axis=dim)

    def _std_torch(self, array, dim=None):
        # correction=0 -> population std, matching the jnp.std default (ddof=0).
        return self.module.std(array, dim=dim, correction=0)

    def _std_jax(self, array, dim=None):
        return self.module.std(array, axis=dim)

    def _sum_torch(self, array, dim=None):
        return self.module.sum(array, dim=dim)

    def _sum_jax(self, array, dim=None):
        return self.module.sum(array, axis=dim)

    def _cumprod_torch(self, array, dim=None):
        # torch.cumprod requires a dim; None means "over the flattened array".
        if dim is None:
            array = array.reshape(-1)
            dim = 0
        return self.module.cumprod(array, dim=dim)

    def _cumprod_jax(self, array, dim=None):
        return self.module.cumprod(array, axis=dim)

    def _max_torch(self, array, dim=None):
        # With a dim, torch.max returns (values, indices); keep only values.
        return (
            self.module.max(array)
            if dim is None
            else self.module.max(array, dim=dim).values
        )

    def _max_jax(self, array, dim=None):
        return self.module.max(array, axis=dim)

    def _min_torch(self, array, dim=None):
        return (
            self.module.min(array)
            if dim is None
            else self.module.min(array, dim=dim).values
        )

    def _min_jax(self, array, dim=None):
        return self.module.min(array, axis=dim)

    def _topk_torch(self, array, k):
        return self.module.topk(array, k=k)

    def _topk_jax(self, array, k):
        res = self.jax.lax.top_k(array, k=k)
        return TopKResult(values=res[0], indices=res[1])

    def _bessel_j1_torch(self, array):
        return self.module.special.bessel_j1(array)

    def _bessel_j1_jax(self, array):
        # bessel_jn returns J_0..J_v stacked on axis 0; keep J_1.
        return self.jax.scipy.special.bessel_jn(array, v=1)[-1]

    def _lgamma_torch(self, array):
        return self.module.lgamma(array)

    def _lgamma_jax(self, array):
        return self.jax.lax.lgamma(array)

    def _grad_torch(self, func, argnums=0):
        return self.module.func.grad(func, argnums=argnums)

    def _grad_jax(self, func, argnums=0):
        return self.jax.grad(func, argnums=argnums)

    def _jacobian_torch(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        return self.module.autograd.functional.jacobian(
            func, x, strategy=strategy, vectorize=vectorize, create_graph=create_graph
        )

    def _jacobian_jax(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        # ``vectorize``/``create_graph`` are accepted only for signature parity with torch.
        if "forward" in strategy:
            return self.jax.jacfwd(func)(x)
        return self.jax.jacrev(func)(x)

    def _jacfwd_torch(self, func, argnums=0, randomness="error"):
        return self.module.func.jacfwd(func, argnums=argnums, randomness=randomness)

    def _jacfwd_jax(self, func, argnums=0, randomness="error"):
        return self.jax.jacfwd(func, argnums=argnums)

    def _hessian_torch(self, func, argnums=0):
        return self.module.func.hessian(func, argnums=argnums)

    def _hessian_jax(self, func, argnums=0):
        return self.jax.hessian(func, argnums=argnums)

    def _vmap_torch(
        self, func, in_dims=0, out_dims=0, randomness="error", chunk_size=None
    ):
        return self.module.vmap(
            func,
            in_dims=in_dims,
            out_dims=out_dims,
            randomness=randomness,
            chunk_size=chunk_size,
        )

    def _vmap_jax(
        self, func, in_dims=0, out_dims=0, randomness="error", chunk_size=None
    ):
        return self.jax.vmap(func, in_axes=in_dims, out_axes=out_dims)

    # Index updates: torch mutates ``array`` in place; JAX returns a new array.
    def _fill_at_indices_torch(self, array, indices, values):
        array[indices] = values
        return array

    def _fill_at_indices_jax(self, array, indices, values):
        return array.at[indices].set(values)

    def _add_at_indices_torch(self, array, indices, values):
        array[indices] += values
        return array

    def _add_at_indices_jax(self, array, indices, values):
        return array.at[indices].add(values)

    def _and_at_indices_torch(self, array, indices, values):
        array[indices] &= values
        return array

    def _and_at_indices_jax(self, array, indices, values):
        return array.at[indices].set(array[indices] & values)

    def _flatten_torch(self, array, start_dim=0, end_dim=-1):
        return array.flatten(start_dim, end_dim)

    def _flatten_jax(self, array, start_dim=0, end_dim=-1):
        """Emulate ``torch.flatten``: merge axes ``start_dim..end_dim`` (inclusive).

        Raises:
            IndexError: if either dim is out of range.
            ValueError: if ``start_dim`` comes after ``end_dim``.
        """
        ndim = array.ndim
        if ndim == 0:
            # torch.flatten returns a one-element 1-d result for 0-d input.
            return self.module.reshape(array, (1,))
        for d in (start_dim, end_dim):
            if not -ndim <= d < ndim:
                raise IndexError(
                    f"flatten(): dimension {d} out of range for {ndim}-d array"
                )
        start = start_dim % ndim
        end = end_dim % ndim
        if start > end:
            raise ValueError("flatten(): start_dim cannot come after end_dim")
        shape = tuple(array.shape)
        # Explicit merged size instead of -1 so zero-size arrays are unambiguous.
        merged = int(np.prod(shape[start : end + 1], dtype=np.int64))
        return self.module.reshape(
            array, shape[:start] + (merged,) + shape[end + 1 :]
        )

    def _unsqueeze_torch(self, array, dim):
        return self.module.unsqueeze(array, dim)

    def _unsqueeze_jax(self, array, dim):
        return self.module.expand_dims(array, axis=dim)

    def _cat_torch(self, array, dim=0):
        return self.module.cat(array, dim=dim)

    def _cat_jax(self, array, dim=0):
        return self.module.concatenate(array, axis=dim)

    def _gradient_torch(self, array, spacing=1, dim=None):
        return self.module.gradient(array, spacing=spacing, dim=dim, edge_order=1)

    def _gradient_jax(self, array, spacing=1, dim=None):
        # spacing is a positional argument in jax and edge_order is not implemented
        return self.module.gradient(array, spacing, axis=dim, edge_order=None)

    def _detach_torch(self, array):
        return array.detach()

    def _detach_jax(self, array):
        return self.jax.lax.stop_gradient(array)

    def _avg_pool2d_torch(self, array, kernel, stride=None, padding=0):
        return self.module.nn.functional.avg_pool2d(
            array,
            kernel,
            padding=padding,
            stride=stride,
        )

    def _avg_pool2d_jax(self, array, kernel, stride=None, padding=0):
        """JAX twin of ``torch.nn.functional.avg_pool2d`` over the last two axes.

        ``kernel``, ``stride`` (default: ``kernel``) and ``padding`` may each be
        an int or a pair. Zero padding is included in the average, matching
        torch's default ``count_include_pad=True``.
        """
        k_h, k_w = self._pair(kernel)
        s_h, s_w = self._pair(kernel if stride is None else stride)
        p_h, p_w = self._pair(padding)
        summed = self.jax.lax.reduce_window(
            array,
            init_value=0,
            computation=self.jax.lax.add,
            window_dimensions=(1, 1, k_h, k_w),
            window_strides=(1, 1, s_h, s_w),
            padding=((0, 0), (0, 0), (p_h, p_h), (p_w, p_w)),
        )
        return summed / (k_h * k_w)

    # Random numbers: torch uses its global RNG; JAX consumes ``key`` or, when it
    # is None, a subkey split off the stored ``self.key``.
    def _rand_torch(self, *size, key=None):
        return self.module.rand(*size)

    def _rand_jax(self, *size, key=None):
        key = self._next_key(key)
        return self.jax.random.uniform(key, shape=self._shape_from_args(size))

    def _randn_torch(self, *size, dtype=None, device=None, key=None):
        return self.module.randn(*size, dtype=dtype, device=device)

    def _randn_jax(self, *size, dtype=None, device=None, key=None):
        key = self._next_key(key)
        if dtype is None:
            dtype = float  # jax.random.normal's own default
        return self.jax.random.normal(
            key, shape=self._shape_from_args(size), dtype=dtype
        )

    def _randint_torch(self, high, size, low=0, dtype=None, device=None, key=None):
        return self.module.randint(
            low=low, high=high, size=size, dtype=dtype, device=device
        )

    def _randint_jax(self, high, size, low=0, dtype=None, device=None, key=None):
        key = self._next_key(key)
        if dtype is None:
            # jax.random.randint rejects None (it canonicalises to a float dtype).
            dtype = int
        return self.jax.random.randint(
            key, minval=low, maxval=high, shape=size, dtype=dtype
        )

    def _split_key_torch(self, key):
        raise NotImplementedError(
            "`split_key` should not be used with the `torch` backend."
        )

    def _split_key_jax(self, key):
        return self.jax.random.split(key)

    def _meshgrid_torch(self, *arrays, indexing="ij"):
        return self.module.meshgrid(*arrays, indexing=indexing)

    def _meshgrid_jax(self, *arrays, indexing="ij"):
        # Accept a single list/tuple of arrays like torch.meshgrid does.
        if len(arrays) == 1 and isinstance(arrays[0], (list, tuple)):
            arrays = arrays[0]
        return self.module.meshgrid(*arrays, indexing=indexing)

    def _device_torch(self, array):
        return array.device

    def _device_jax(self, array):
        # JAX Tracers (used during jit/vmap) do NOT have a device.
        # Returning None tells JAX to use the default/current device context.
        return None

    def _numel_torch(self, array):
        return array.numel()

    def _numel_jax(self, array):
        return array.size

    def _size_torch(self, shape):
        return self.module.Size(shape)

    def _size_jax(self, shape):
        return tuple(shape)

    def _chunk_torch(self, array, chunks, dim=0):
        return self.module.chunk(array, chunks, dim=dim)

    def _chunk_jax(self, array, chunks, dim=0):
        """Split like ``torch.chunk``: pieces of ``ceil(n / chunks)``, the last one smaller.

        This may return fewer than ``chunks`` pieces, exactly as torch does
        (``np.array_split`` would instead balance the sizes).
        """
        length = array.shape[dim]
        if length == 0:
            # torch returns ``chunks`` empty pieces here, as array_split does.
            return self.module.array_split(array, chunks, axis=dim)
        step = -(-length // chunks)  # ceiling division
        return self.module.split(array, list(range(step, length, step)), axis=dim)

    def _jit_torch(self, func, **kwargs):
        return func

    def _jit_jax(self, func, **kwargs):
        return self.jax.jit(func, **kwargs)

    def _norm_torch(self, array, dim=None, **kwargs):
        return self.linalg.norm(array, dim=dim, **kwargs)

    def _norm_jax(self, array, dim=None, **kwargs):
        return self.linalg.norm(array, axis=dim, **kwargs)

    def _arange_torch(self, *args, dtype=None, device=None):
        return self.module.arange(*args, dtype=dtype, device=device)

    def _arange_jax(self, *args, dtype=None, device=None):
        return self.module.arange(*args, dtype=dtype)

    def _linspace_torch(self, start, end, steps, dtype=None, device=None):
        return self.module.linspace(start, end, steps, dtype=dtype, device=device)

    def _linspace_jax(self, start, end, steps, dtype=None, device=None):
        return self.module.linspace(start, end, steps, dtype=dtype)

    # ------------------------------------------------------------------
    # Thin wrappers: forward to the same-named function of ``self.module``.
    # ------------------------------------------------------------------

    def searchsorted(self, array, value):
        return self.module.searchsorted(array, value)

    def any(self, array):
        return self.module.any(array)

    def all(self, array):
        return self.module.all(array)

    def log(self, array):
        return self.module.log(array)

    def safe_log(self, array):
        # NOTE(review): neither ``torch`` nor ``jax.numpy`` appears to define
        # ``safe_log``, so this likely raises AttributeError on both backends.
        # Confirm the intended semantics (e.g. a clamped log) and implement.
        return self.module.safe_log(array)

    def log10(self, array):
        return self.module.log10(array)

    def exp(self, array):
        return self.module.exp(array)

    def sin(self, array):
        return self.module.sin(array)

    def cos(self, array):
        return self.module.cos(array)

    def cosh(self, array):
        return self.module.cosh(array)

    def sqrt(self, array):
        return self.module.sqrt(array)

    def abs(self, array):
        return self.module.abs(array)

    def conj(self, array):
        return self.module.conj(array)

    def nan_to_num(self, array, posinf=None, neginf=None):
        return self.module.nan_to_num(array, posinf=posinf, neginf=neginf)

    def floor(self, array):
        return self.module.floor(array)

    def atleast_1d(self, array):
        return self.module.atleast_1d(array)

    def tanh(self, array):
        return self.module.tanh(array)

    def arctan(self, array):
        return self.module.arctan(array)

    def atan(self, array):
        return self.module.atan(array)

    def atanh(self, array):
        return self.module.atanh(array)

    def arctan2(self, y, x):
        return self.module.arctan2(y, x)

    def arcsin(self, array):
        return self.module.arcsin(array)

    def arcsinh(self, array):
        return self.module.arcsinh(array)

    def arccos(self, array):
        return self.module.arccos(array)

    def arccosh(self, array):
        return self.module.arccosh(array)

    def round(self, array):
        return self.module.round(array)

    def zeros(self, shape, dtype=None, device=None):
        return self.module.zeros(shape, dtype=dtype, device=device)

    def zeros_like(self, array, dtype=None):
        return self.module.zeros_like(array, dtype=dtype)

    def ones(self, shape, dtype=None, device=None):
        return self.module.ones(shape, dtype=dtype, device=device)

    def ones_like(self, array, dtype=None):
        return self.module.ones_like(array, dtype=dtype)

    def empty(self, shape, dtype=None, device=None):
        return self.module.empty(shape, dtype=dtype, device=device)

    def eye(self, n, dtype=None, device=None):
        return self.module.eye(n, dtype=dtype, device=device)

    def diag(self, array):
        return self.module.diag(array)

    def outer(self, a, b):
        return self.module.outer(a, b)

    def minimum(self, a, b):
        return self.module.minimum(a, b)

    def maximum(self, a, b):
        return self.module.maximum(a, b)

    def isnan(self, array):
        return self.module.isnan(array)

    def isfinite(self, array):
        return self.module.isfinite(array)

    def where(self, condition, x, y):
        return self.module.where(condition, x, y)

    def allclose(self, a, b, rtol=1e-5, atol=1e-8):
        return self.module.allclose(a, b, rtol=rtol, atol=atol)

    def tile(self, array, dims):
        return self.module.tile(array, dims)

    def flip(self, array, dims):
        return self.module.flip(array, dims)

    def is_finite(self, array):
        return self.module.isfinite(array)

    def prod(self, array, dim=None):
        return self.module.prod(array) if dim is None else self.module.prod(array, dim)

    def cumprod(self, array, dim=None):
        """Cumulative product along ``dim``; ``None`` means over the flattened array.

        ``torch.cumprod`` has no flatten mode (it requires a dim), so the
        flattening is done here for both backends.
        """
        if dim is None:
            array = array.reshape(-1)
            dim = 0
        return self.module.cumprod(array, dim)

    def einsum(self, equation, *operands):
        return self.module.einsum(equation, *operands)

    def argmax(self, array, dim=None):
        return self.module.argmax(array, dim)

    def dot(self, *arrays):
        return self.module.dot(*arrays)

    # ------------------------------------------------------------------
    # Submodules and constants of the active backend.
    # ------------------------------------------------------------------

    @property
    def linalg(self):
        return self.module.linalg

    @property
    def fft(self):
        return self.module.fft

    @property
    def inf(self):
        return self.module.inf

    @property
    def bool(self):
        return self.module.bool

    @property
    def int32(self):
        return self.module.int32

    @property
    def float32(self):
        return self.module.float32

    @property
    def float64(self):
        return self.module.float64

    @property
    def pi(self):
        return self.module.pi

    @property
    def nan(self):
        return self.module.nan
# Module-level singleton. Constructed with no argument, it resolves the backend
# from the CASKADE_BACKEND environment variable (falling back to caskade's
# current backend) and loads that library when this module is imported.
backend = Backend()