from functools import partial
import chex
import jax
import jax.numpy as np
from fibermc.jax_utils import _vectorize, _jit_vectorize
def _get_hull_orientation(hull: np.ndarray) -> np.ndarray:
"""Returns -1 if the provided hull is oriented clockwise and
+1 otherwise.
"""
chex.assert_shape(hull, (..., None, 2))
return np.sign(np.cross(hull[1] - hull[0], hull[2] - hull[0]))
def _reduce_clip_params(clips: np.ndarray, axis: int = -2) -> np.ndarray:
if axis == -1 or axis == clips.ndim - 1:
raise ValueError(f"Do not reduce_clip_params over coordinate axis.")
clips: np.ndarray = np.moveaxis(clips, axis, 0)
lo, hi = np.moveaxis(clips, -1, 0)
clip_params: np.ndarray = np.stack([lo.max(0), hi.min(0)], axis=-1)
return clip_params
def _hull_to_walls(hull: np.ndarray) -> np.ndarray:
"""Computes the 'walls' (line-segments) associated with the given `hull`
containing the vertices of the hull.
Parameters
----------
hull: np.ndarray
np.ndarray of vertices comprising the hull in question.
Returns
-------
walls: np.ndarray
np.ndarray of 'walls' (line segments) corresponding to the hull.
"""
chex.assert_shape(hull, (..., None, 2))
walls: np.ndarray = np.stack([hull, np.roll(hull, -1, axis=-2)], axis=-2)
return walls
def _in01(x: np.ndarray, open: bool = True) -> np.ndarray:
"""Returns a boolean-valued np.ndarray representing whether elements
of `x` lie in the interval (0, 1] or (0, 1); the latter if `open`=True.
"""
return np.logical_and(0 < x, x < 1) if open else np.logical_and(0 <= x, x <= 1)
@partial(np.vectorize, signature="(n,2),()->(n,2)")
def _orient_hull(hull: np.ndarray, target_orientation: int):
"""Given a hull and a target orientation, ensures the hull
is organized with that orientation; if it doesn't it reverses
the hull's orientation; the correctly oriented hull is returned
"""
chex.assert_shape(hull, (None, 2))
# determine whether the hull ought to be reversed
should_reverse = target_orientation != _get_hull_orientation(hull)
reverse: callable = lambda hull: hull[::-1]
dont_reverse: callable = lambda hull: hull
oriented_hull: np.ndarray = jax.lax.cond(
should_reverse, reverse, dont_reverse, operand=hull
)
return oriented_hull
@_jit_vectorize(signature="(2,2),(2)->(2,2)")
def apply_fiber_clip(fiber: np.ndarray, clip_params: np.ndarray) -> np.ndarray:
"""Clips the provided fibers according to the clip parameters."""
clipped_fibers: np.ndarray = interpolate(clip_params, fiber)
return clipped_fibers
def orient_hull_ccw(hull: np.ndarray) -> np.ndarray:
"""Returns the same hull as provided but guaranteed to be
oriented counter-clockwise.
"""
oriented_hull: np.ndarray = _orient_hull(hull, +1)
return oriented_hull
def orient_hull_cw(hull: np.ndarray) -> np.ndarray:
"""Returns the same hull as provided but guaranteed to be
oriented clockwise.
"""
oriented_hull: np.ndarray = _orient_hull(hull, -1)
return oriented_hull
@_vectorize(signature="(2,2),(2,2)->(2)")
def intersect_segments(
first_segment: np.ndarray, second_segment: np.ndarray
) -> np.ndarray:
"""Computes the intersection between two line segments.
Note: returns np.array([np.inf, np.inf]) if there is no intersection.
"""
# setup the linear system
A: np.ndarray = np.stack(
[first_segment[1] - first_segment[0], second_segment[0] - second_segment[1]],
axis=-1,
)
b: np.ndarray = second_segment[0] - first_segment[0]
# determine if the coefficient matrix is singular
abs_determinant: np.ndarray = np.abs(np.linalg.det(A))
# specify both sides of the branch
linear_solve: callable = partial(np.linalg.solve, A)
is_no_solution: callable = lambda _: np.array([np.inf, np.inf])
# compute the intersection
is_solvable: bool = abs_determinant > 0
intersection: np.ndarray = jax.lax.cond(
is_solvable, linear_solve, is_no_solution, operand=b
)
return intersection
@_jit_vectorize(signature="(2,2),(2,2)->(2),(2),()")
def clip_wrt_wall(fiber: np.ndarray, wall: np.ndarray) -> np.ndarray:
"""Clips a fiber with respect to an oriented line segment `wall`.
Parameters
----------
fiber: np.ndarray
np.ndarray of shape (2, fiber_dim); the fiber.
wall:
np.ndarray of shape (2, fiber_dim) containing the endpoints of the wall.
Returns
-------
clip_parameters: np.ndarray
interpolation parameters for clipped fiber of shape 2.
endpoint_sides: int {1, -1}
indicates which side of wall each fiber endpoint resides.
has_intersection: bool
indicates if the fiber and the wall had non-empty intersection.
"""
# closure taking a query point and returning which side of the wall it falls on
which_side: callable = lambda query_point: np.sign(
np.cross(wall[1] - wall[0], query_point - wall[0])
)
# which side of the wall each endpoint of the fiber falls on
endpoint_sides: np.ndarray = np.stack([which_side(f) for f in fiber], axis=0)
start_sides, end_sides = endpoint_sides
# find the intersection (if it exists)
intersection: np.ndarray = intersect_segments(fiber, wall)
has_intersection: np.ndarray = _in01(intersection, open=True).all(-1)
alpha, _ = np.moveaxis(intersection, -1, 0)
intersection_select: np.ndarray = jax.lax.select(
start_sides > end_sides,
np.stack([0.0, alpha]),
np.stack([alpha, 1.0]),
)
# determine the clip parameters
clip_parameters: np.ndarray = jax.lax.select(
has_intersection, intersection_select, np.array([0.0, 1.0])
)
return clip_parameters, endpoint_sides, has_intersection
@_jit_vectorize(signature="(2,2),(n,2)->(2)")
def _clip_inside_convex_hull(fiber: np.ndarray, hull: np.ndarray) -> np.ndarray:
"""Clip fiber(s) to inside of a convex hull.
Parameters
----------
fiber:
np.ndarray of shape (2, fiber_dim) containing the fiber.
hull:
np.ndarray of shape (num_vertices, fiber_dim) containing the vertices
of the hull.
Returns
-------
clip_parameters: np.ndarray
array of clipping parameters.
"""
# ensure we have a valid hull with greater than 2 vertices
num_vertices, _ = hull.shape
assert (
num_vertices > 2
), f"Hull must contain more than 2 vertices (got hull with {num_vertices} vertices)."
# orient the hull and obtain its walls
hull: np.ndarray = orient_hull_ccw(hull)
walls = _hull_to_walls(hull)
# clip the fibers with respect to the walls of the hull
clips, endpoint_sides, has_intersections = clip_wrt_wall(fiber, walls)
clip: np.ndarray = _reduce_clip_params(clips)
# determine if there are any intersections
any_intersections: np.ndarray = has_intersections.any(-1)
is_fully_inside: np.ndarray = (endpoint_sides > 0).all((-1, -2))
is_fully_outside: np.ndarray = np.logical_and(~is_fully_inside, ~any_intersections)
# compute the clipping parameters
clip_parameters: np.ndarray = jax.lax.select(
is_fully_outside, np.zeros_like(clip), clip
)
return clip_parameters
[docs]
def clip_inside_convex_hull(fibers: np.ndarray, hull: np.ndarray) -> np.ndarray:
"""Wraps src.utils.geometry_utils._clip_inside_convex_hull and
applies the clipping parameters to the fibers.
"""
clip_parameters = _clip_inside_convex_hull(fibers, hull)
clipped_fibers: np.ndarray = apply_fiber_clip(fibers, clip_parameters)
return clipped_fibers
@_jit_vectorize(signature="(),(2,2)->(2)")
def interpolate(alpha: float, segment: np.ndarray) -> np.ndarray:
"""Interpolates `alpha` amount along `segment`."""
x0, x1 = np.moveaxis(segment, -2, 0)
interpolated: np.ndarray = (1 - alpha) * x0 + alpha * x1
return interpolated
@_jit_vectorize(signature="(2,2),(k,2)->(k,2,2)")
def apply_fiber_multiclip(fiber: np.ndarray, clip_params: np.ndarray) -> np.ndarray:
return interpolate(clip_params, fiber)