from ...utils import batch_lm
from ...backend_obj import backend
from ...constants import arcsec_to_rad, c_Mpc_s, days_to_seconds
from warnings import warn
[docs]
def triangle_contains(p, v):
"""
determine if point v is inside triangle p. Where p is a (3,2) tensor, and v
is a (2,) tensor.
"""
p01 = p[1] - p[0]
p02 = p[2] - p[0]
dp0p02 = p[0][0] * p02[1] - p[0][1] * p02[0]
dp0p01 = p[0][0] * p01[1] - p[0][1] * p01[0]
dp01p02 = p01[0] * p02[1] - p01[1] * p02[0]
dvp02 = v[0] * p02[1] - v[1] * p02[0]
dvp01 = v[0] * p01[1] - v[1] * p01[0]
a = (dvp02 - dp0p02) / dp01p02
b = -(dvp01 - dp0p01) / dp01p02
return (a >= 0) & (b >= 0) & (a + b <= 1)
[docs]
def triangle_area(p):
"""
Determine the area of triangle p where p is a (3,2) tensor.
"""
return 0.5 * backend.abs(
p[0][0] * (p[1][1] - p[2][1])
+ p[1][0] * (p[2][1] - p[0][1])
+ p[2][0] * (p[0][1] - p[1][1])
)
[docs]
def triangle_neighbors(p):
"""
Build a set of neighbors for triangle p where p is a (3,2) tensor. The
neighbors all have the same shape as p, but are various translations and
reflections of p that share a common edge or vertex.
"""
p01 = p[1] - p[0]
p02 = p[2] - p[0]
p12 = p[2] - p[1]
pref = -(p - p[0]) + p[0]
return backend.stack(
(
p,
p + p01,
p - p01,
p + p02,
p - p02,
p + p12,
p - p12,
pref,
pref + p01,
pref + 2 * p01,
pref + p02,
pref + 2 * p02,
pref + p01 + p02,
),
dim=0,
)
[docs]
def triangle_upsample(p):
"""
Upsample triangle p where p is a (3,2) tensor. The upsampled triangles are
all triangles internal to p built by taking the midpoints of the edges of p.
"""
p01 = (p[1] + p[0]) / 2
p02 = (p[2] + p[0]) / 2
p12 = (p[2] + p[1]) / 2
return backend.stack(
(
backend.stack((p[0], p01, p02), dim=0),
backend.stack((p01, p[1], p12), dim=0),
backend.stack((p02, p12, p[2]), dim=0),
backend.stack((p01, p12, p02), dim=0),
),
dim=0,
)
[docs]
def triangle_equals(p1, p2):
"""
Determine if two triangles are equal. Where p1 and p2 are (3,2) tensors.
"""
return backend.all(backend.abs(p1 - p2) < 1e-6)
[docs]
def remove_triangle_duplicates(p):
unique_triangles = backend.zeros((0, 3, 2), device=p.device, dtype=p.dtype)
B = p.shape[0]
batch_triangle_equals = backend.vmap(triangle_equals, in_dims=(None, 0))
for i in range(B):
# Compare current triangle with all triangles in the unique list
if i == 0 or not backend.any(batch_triangle_equals(p[i], unique_triangles)):
unique_triangles = backend.concatenate(
(unique_triangles, backend.unsqueeze(p[i], 0)), dim=0
)
return unique_triangles
[docs]
def forward_raytrace_rootfind(ix, iy, bx, by, raytrace):
"""
Perform a forward ray-tracing operation which maps from the source plane to
the image plane.
Parameters
----------
ix: ArrayLike
ArrayLike of x coordinate in the image plane. This initializes the
ray-tracing optimization. Should have shape (B, 2).
*Unit: arcsec*
iy: ArrayLike
ArrayLike of y coordinate in the image plane. This initializes the
ray-tracing optimization. Should have shape (B, 2).
bx: ArrayLike
ArrayLike of x coordinate in the source plane. Should be a scalar.
*Unit: arcsec*
by: ArrayLike
ArrayLike of y coordinate in the source plane. Should be a scalar.
*Unit: arcsec*
raytrace: function
function that takes in the x and y coordinates in the image plane and
returns the x and y coordinates in the source plane.
Returns
-------
x_component: ArrayLike
x-coordinate ArrayLike of the ray-traced light rays
*Unit: arcsec*
y_component: ArrayLike
y-coordinate ArrayLike of the ray-traced light rays
*Unit: arcsec*
"""
ixy = backend.stack((ix, iy), dim=1) # has shape (B, Din:2)
bxy = backend.stack((bx, by)) * backend.ones(
(ix.shape[0], 1), device=bx.device
) # has shape (B, Dout:2)
# Optimize guesses in image plane
x, l, c = batch_lm( # noqa: E741 Unused `l` variable
ixy,
bxy,
lambda *a, **k: backend.stack(
raytrace(a[0][..., 0], a[0][..., 1], *a[1:], **k), dim=-1
),
)
return x
[docs]
def remove_duplicate_points(x, epsilon):
"""
Remove duplicate points from the coordinates list.
"""
unique_points = backend.zeros((0, 2), device=x.device, dtype=x.dtype)
for i in range(x.shape[0]):
# Compare current point with all points in the unique list
if i == 0 or not backend.any(
backend.norm(x[i] - unique_points, dim=1) < epsilon
):
unique_points = backend.concatenate((unique_points, x[i][None]), dim=0)
return unique_points
[docs]
def forward_raytrace(s, raytrace, x0, y0, fov, n, epsilon, max_depth=25):
# Construct a tiling of the image plane (squares at this point)
X, Y = backend.meshgrid(
backend.linspace(x0 - fov / 2, x0 + fov / 2, n),
backend.linspace(y0 - fov / 2, y0 + fov / 2, n),
indexing="ij",
)
E = backend.stack((X, Y), dim=-1)
E = backend.to(E, device=x0.device)
# build the upper and lower triangles within the squares of the grid
E = backend.concatenate(
(
backend.stack((E[:-1, :-1], E[:-1, 1:], E[1:, 1:]), dim=-2),
backend.stack((E[:-1, :-1], E[1:, :-1], E[1:, 1:]), dim=-2),
),
dim=0,
).reshape(-1, 3, 2)
i = 0
while True:
# Expand the search to neighboring triangles
if i > 0: # no need for neighbors in the first iteration
E = backend.vmap(triangle_neighbors)(E)
E = E.reshape(-1, 3, 2)
E = remove_triangle_duplicates(E)
# Upsample the triangles
E = backend.vmap(triangle_upsample)(E)
E = E.reshape(-1, 3, 2)
S = raytrace(E[..., 0], E[..., 1])
S = backend.stack(S, dim=-1)
# Identify triangles that contain the source plane point
locate = backend.vmap(triangle_contains, in_dims=(0, None))(S, s)
E = E[locate]
i += 1
# Triangles now smaller than resolution, try to find exact points
if triangle_area(E[0]) < epsilon**2:
# Rootfind the source plane point in the triangle
Emid = backend.sum(E, dim=1) / 3
Emid = forward_raytrace_rootfind(
Emid[..., 0], Emid[..., 1], s[0], s[1], raytrace
)
Smid = raytrace(Emid[..., 0], Emid[..., 1])
Smid = backend.stack(Smid, dim=-1)
if backend.all(
backend.vmap(triangle_contains)(E, Emid)
) and backend.allclose(Smid, s, atol=epsilon):
break
Emid = Emid[
backend.norm(Smid - s, dim=1) < epsilon
] # ensure only good points are returned if max_depth reached
if i > max_depth:
warn(
"Forward raytrace unable to converge reliably, reached maximum depth of 25 iterations. There may be singularities in the lens, or other numerical challenges."
)
break
# Remove duplicates
unique = remove_duplicate_points(
backend.stack((Emid[..., 0], Emid[..., 1]), dim=1), epsilon
)
return unique[..., 0], unique[..., 1]
[docs]
def physical_from_reduced_deflection_angle(ax, ay, d_s, d_ls):
"""
Computes the physical deflection angle of the given the reduced deflection angles [arcsec].
Parameters
----------
ax: ArrayLike
ArrayLike of x axis reduced deflection angles in the lens plane.
*Unit: arcsec*
y: ArrayLike
ArrayLike of y axis reduced deflection angles in the lens plane.
*Unit: arcsec*
d_s: float
distance to the source.
*Unit: Mpc*
d_ls: float
distance from lens to source.
*Unit: Mpc*
Returns
--------
x_component: ArrayLike
Physical deflection Angle in the x-direction.
*Unit: arcsec*
y_component: ArrayLike
Physical deflection Angle in the y-direction.
*Unit: arcsec*
"""
return (d_s / d_ls) * ax, (d_s / d_ls) * ay
[docs]
def reduced_from_physical_deflection_angle(ax, ay, d_s, d_ls):
"""
Computes the reduced deflection angle of the lens at given coordinates [arcsec].
Parameters
----------
ax: ArrayLike
ArrayLike of x axis physical deflection angles in the lens plane.
*Unit: arcsec*
y: ArrayLike
ArrayLike of y axis physical deflection angles in the lens plane.
*Unit: arcsec*
d_s: float
distance to the source.
*Unit: Mpc*
d_ls: float
distance from lens to source.
*Unit: Mpc*
Returns
--------
x_component: ArrayLike
Reduced deflection Angle in the x-direction.
*Unit: arcsec*
y_component: ArrayLike
Reduced deflection Angle in the y-direction.
*Unit: arcsec*
"""
return (d_ls / d_s) * ax, (d_ls / d_s) * ay
[docs]
def time_delay_arcsec2_to_days(d_l, d_s, d_ls, z_l):
"""
Computes a scaling factor to use in time delay calculations which converts
the time delay (i.e. potential and deflection angle squared terms) from
arcsec^2 to units of days.
"""
return (1 + z_l) / c_Mpc_s * d_s * d_l / d_ls * arcsec_to_rad**2 / days_to_seconds