"""Implementations of utilities for sampling fibers and various FMC estimators."""
import functools
import jax
import jax.numpy as np
import jax.random as npr
import fibermc.implicit_differentiation as implicit_differentiation
import fibermc.utils as utils
import fibermc.geometry_utils as geometry_utils
# Alias for the 64-bit float type exposed by jax.numpy (imported here as `np`).
# It is not referenced elsewhere in the visible portion of this file.
FP64: type = np.float64
# Alias used in type annotations for the `params` argument of the estimators below.
pytree: type = dict
# NOTE: stray "[docs]" Sphinx source-link marker removed here (HTML scrape artifact, not code).
def sample(
    key: np.ndarray,
    bounds: np.ndarray,
    num_fibers: int,
    fiber_length: float,
    dtype: type = np.float32,
) -> np.ndarray:
    """Draw random fibers (line segments) of a fixed length.

    Start points are drawn uniformly from the rectangle described by `bounds`
    and orientations are drawn uniformly from [-pi, pi); each end point is the
    start point displaced by `fiber_length` along the sampled orientation.
    End points may therefore fall outside of `bounds`.

    Parameters
    ----------
    key: np.ndarray
        Pseudo-random number generation key (e.g., from jax.random.PRNGKey).
        It is split internally into independent keys for the x coordinates,
        the y coordinates and the angles.
    bounds: np.ndarray [4,]
        Rectangular sampling domain for the start points, laid out as
        (min_x, min_y, max_x, max_y).
    num_fibers: int [>0]
        Number of fibers to sample.
    fiber_length: float [>0]
        Length of every sampled fiber.
    dtype: type
        Numeric type used for the sampled values (default: np.float32).

    Returns
    -------
    fibers: np.ndarray
        Array of shape (num_fibers, 2, 2). For fiber `i`, `fibers[i, 0]` is the
        start point and `fibers[i, 1]` is the end point.
    """
    # Key derivation order matters for reproducibility: first split into
    # (location, angle), then split the location key into (x, y).
    location_key, angular_key = npr.split(key, 2)
    x_key, y_key = npr.split(location_key, 2)

    def _uniform(subkey, low, high):
        # One uniform draw per fiber in [low, high).
        return npr.uniform(
            subkey,
            shape=(num_fibers,),
            dtype=dtype,
            minval=low,
            maxval=high,
        )

    start_x = _uniform(x_key, bounds[0], bounds[2])
    start_y = _uniform(y_key, bounds[1], bounds[3])
    angles = _uniform(angular_key, -np.pi, np.pi)

    # (num_fibers, 2) arrays of start points and unit direction vectors.
    starts = np.stack((start_x, start_y), axis=-1)
    directions = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
    ends = starts + (fiber_length * directions)

    # Pair each start with its end along a new second-to-last axis.
    return np.stack((starts, ends), axis=-2)
def estimate_field_length(
    field: callable,
    fibers: np.ndarray,
    params: tuple,
    negative: bool = True,
) -> np.ndarray:
    """Estimate the total fiber length on which a scalar `field` is negative (or positive).

    Each fiber is classified by the sign of `field` at its two end points:

    * both end points on the counted side: the whole fiber length is added;
    * exactly one end point on the counted side: the length between that end
      point and the crossing point returned by the implicit solver is added;
    * neither end point on the counted side: nothing is added.

    Parameters
    ----------
    field: callable[[...], float]
        Scalar real-valued callable invoked as `field(params, point)`, where
        `point` is a single fiber end point.
    fibers: np.ndarray
        Array of fibers indexed along axis 0; `fibers[:, 0]` are the start
        points and `fibers[:, 1]` are the end points.
    params: pytree
        Auxiliary data handed unchanged to `field` and to the solver (e.g.,
        parameters). The signature annotates it as `tuple`.
    negative: bool
        If True (default), measure the length on which `field` is negative;
        if False, measure the length on which `field` is positive.

    Returns
    -------
    total_length: np.ndarray
        Array of shape (1,) holding the estimated total length.

    Note: this estimator assumes that `field` changes sign on a lengthscale
    larger than the length of each fiber (i.e., at most one crossing per fiber).
    """
    # `field` evaluated at a batch of points, with `params` already bound.
    vector_field: callable = functools.partial(
        jax.vmap(field, in_axes=(None, 0)), params
    )
    solver_base: callable = implicit_differentiation.bind_solver(field)
    # Maps a batch of fibers to one crossing point per fiber.
    # NOTE(review): the jitted callable wraps a lambda created on every call, so
    # it is presumably re-traced per invocation -- confirm whether that matters.
    solver: callable = jax.jit(
        jax.vmap(
            lambda fiber, params: implicit_differentiation.get_interpolant(
                solver_base(np.empty(0), params, fiber), fiber
            ),
            in_axes=(0, None),
        )
    )
    batched_norm: callable = jax.vmap(utils.custom_norm)

    start_signs: np.ndarray = utils.zero_one_sign(vector_field(fibers[:, 0]).ravel())
    end_signs: np.ndarray = utils.zero_one_sign(vector_field(fibers[:, 1]).ravel())

    # Sign codes produced by `utils.zero_one_sign`.
    # Assumes 0.0 encodes a negative field value and 1.0 a non-negative one --
    # TODO confirm against `utils.zero_one_sign`.
    # The boolean `negative` argument is deliberately left untouched.
    counted_sign, other_sign = (0.0, 1.0) if negative else (1.0, 0.0)

    entire_fiber_cond: np.ndarray = np.logical_and(
        start_signs == counted_sign, end_signs == counted_sign
    )
    from_start_cond: np.ndarray = np.logical_and(
        start_signs == counted_sign, end_signs == other_sign
    )
    from_end_cond: np.ndarray = np.logical_and(
        start_signs == other_sign, end_signs == counted_sign
    )

    def _select(cond: np.ndarray) -> np.ndarray:
        # Keep fibers where `cond` holds; replace all others by all-zero fibers
        # so that array shapes stay fixed.
        return np.where(cond.reshape(-1, 1, 1), fibers, np.zeros_like(fibers))

    total_length: np.ndarray = np.zeros(1)

    # case: the entire fiber is counted
    whole_fibers: np.ndarray = _select(entire_fiber_cond)
    total_length += batched_norm(whole_fibers[:, 1] - whole_fibers[:, 0]).sum()

    # case: count from the start of the fiber to the crossing point
    from_start: np.ndarray = _select(from_start_cond)
    total_length += batched_norm(
        solver(from_start, params) - from_start[:, 0]
    ).sum()

    # case: count from the crossing point to the end of the fiber
    from_end: np.ndarray = _select(from_end_cond)
    total_length += batched_norm(
        from_end[:, 1] - solver(from_end, params)
    ).sum()

    return total_length
# NOTE: stray "[docs]" Sphinx source-link marker removed here (HTML scrape artifact, not code).
def estimate_field_area(
    field: callable,
    fibers: np.ndarray,
    params: pytree,
    negative: bool = True,
) -> np.ndarray:
    """Estimate the area on which a scalar `field` is negative (or positive).

    The estimate is the fiber length on which `field` has the requested sign
    (see `estimate_field_length`) divided by the summed length of all fibers.

    Parameters
    ----------
    field: callable[[...], float]
        Scalar real-valued callable taking `params` and a fiber end point.
    fibers: np.ndarray
        Array of fibers indexed along axis 0; `fibers[:, 0]` are the start
        points and `fibers[:, 1]` are the end points.
    params: pytree
        Auxiliary data provided to the field (e.g., parameters).
    negative: bool
        If True (default), estimate the area on which `field` is negative;
        if False, the area on which `field` is positive.

    Returns
    -------
    total_area: np.ndarray
        Ratio of the counted fiber length to the total sampled fiber length.
        NOTE(review): being a ratio of lengths, this presumably is the area
        relative to the sampling domain -- confirm against callers.

    Note: this estimator assumes that `field` changes sign on a lengthscale
    larger than the length of each fiber.
    """
    # Numerator: fiber length lying on the requested side of the field.
    covered_length: np.ndarray = estimate_field_length(
        field, fibers, params, negative=negative
    )

    # Denominator: total length of every sampled fiber.
    segment_vectors: np.ndarray = fibers[:, 1] - fibers[:, 0]
    sampled_length: np.ndarray = jax.vmap(utils.custom_norm)(segment_vectors).sum()

    return covered_length / sampled_length
def estimate_hull_intersection_length(
    fibers: np.ndarray, hull: np.ndarray
) -> np.ndarray:
    """Estimate the total fiber length lying inside a (convex) hull.

    The fibers are clipped with `geometry_utils.clip_inside_convex_hull` and
    the lengths of the clipped fibers are summed.

    Parameters
    ----------
    fibers: np.ndarray
        Array of fibers indexed along axis 0; `fibers[:, 0]` are the start
        points and `fibers[:, 1]` are the end points.
    hull: np.ndarray
        Array describing the convex hull along axis 0, ordered
        counter-clockwise by default (see the example below, where the first
        vertex is repeated at the end).

    Example
    -------
    >>> fibers = np.array([[[-0.25, 0.25], [0.25, 0.25]]])
    >>> hull = np.array([
    ...     [0.5, 0.0],
    ...     [0.5, 0.5],
    ...     [0.0, 0.5],
    ...     [0.0, 0.0],
    ...     [0.5, 0.0],
    ... ])
    >>> estimate_hull_intersection_length(fibers, hull)  # approximately 0.25

    Returns
    -------
    estimated_intersection_length: np.ndarray
        Estimated length of the intersection between `fibers` and `hull`.
    """
    # Clip so that every resulting fiber lies within the hull.
    clipped: np.ndarray = geometry_utils.clip_inside_convex_hull(fibers, hull)
    clipped_starts, clipped_ends = clipped[:, 0], clipped[:, 1]

    per_fiber_length: np.ndarray = jax.vmap(utils.custom_norm)(
        clipped_ends - clipped_starts
    )
    return per_fiber_length.sum()
# NOTE: stray "[docs]" Sphinx source-link marker removed here (HTML scrape artifact, not code).
def estimate_hull_area(fibers: np.ndarray, hull: np.ndarray) -> np.ndarray:
    """Estimate the area of a convex shape with fiber Monte Carlo.

    Assumes the fibers were sampled from an extended domain around the hull.
    The estimate is the fiber length inside the hull (see
    `estimate_hull_intersection_length`) divided by the total fiber length.

    Parameters
    ----------
    fibers: np.ndarray
        Array of fibers indexed along axis 0; `fibers[:, 0]` are the start
        points and `fibers[:, 1]` are the end points.
    hull: np.ndarray
        Convex hull, forwarded unchanged to `estimate_hull_intersection_length`.

    Returns
    -------
    estimated_area: np.ndarray
        Ratio of the in-hull fiber length to the total fiber length.
    """
    # Denominator: length of all fibers, whether inside the hull or not.
    segment_vectors: np.ndarray = fibers[:, 1] - fibers[:, 0]
    sampled_length: np.ndarray = jax.vmap(utils.custom_norm)(segment_vectors).sum()

    # Numerator: cumulative fiber length inside the hull.
    inside_length: np.ndarray = estimate_hull_intersection_length(fibers, hull)

    return inside_length / sampled_length
# NOTE: stray "[docs]" Sphinx source-link marker removed here (HTML scrape artifact, not code).
@functools.partial(jax.jit, static_argnums=(0, 3))
def clip_to_field(
    field: callable,
    fibers: np.ndarray,
    params: pytree,
    negative: bool = True,
) -> np.ndarray:
    """Clip fibers to the region where a scalar `field` is negative (or positive).

    Each fiber is classified by the sign of `field` at its two end points and
    contributes to one of three equally sized groups, which are stacked along
    axis 0 in this order:

    1. fibers with both end points on the kept side (returned unchanged);
    2. fibers whose start is on the kept side and whose end is not (returned
       as start point -> crossing point);
    3. fibers whose end is on the kept side and whose start is not (returned
       as crossing point -> end point).

    Within each group, a fiber which does not belong to that group is
    represented by an all-zero fiber, so array shapes do not depend on the data.

    Parameters
    ----------
    field: callable[[...], float]
        Scalar real-valued callable invoked as `field(params, point)`. It is a
        static argument of the enclosing `jax.jit`.
    fibers: np.ndarray
        Array of shape (num_fibers, 2, fiber_dim); `fibers[:, 0]` are the start
        points and `fibers[:, 1]` are the end points.
    params: pytree
        Auxiliary data provided to `field` and to the bisection solver.
    negative: bool
        If True (default), keep the parts of the fibers where `field` is
        negative; if False, where it is positive. It is a static argument of
        the enclosing `jax.jit`.

    Returns
    -------
    clipped_fibers: np.ndarray
        Array of shape (3 * num_fibers, 2, fiber_dim) holding the three groups
        described above.

    Note: presumably, as for the length estimator, `field` is assumed to change
    sign at most once along each fiber -- confirm against the solver.
    """
    # vectorize and partially evaluate `field`
    vector_field: callable = functools.partial(
        jax.vmap(field, in_axes=(None, 0)), params
    )
    solver: callable = jax.vmap(
        lambda fiber: implicit_differentiation.bisection_solver(params, fiber, field)
    )
    fiber_dim: int = fibers.shape[-1]

    # the sign of field(x) where x is each fiber endpoint
    start_signs: np.ndarray = utils.zero_one_sign(vector_field(fibers[:, 0]).ravel())
    end_signs: np.ndarray = utils.zero_one_sign(vector_field(fibers[:, 1]).ravel())

    # (default) 0: field(x) < 0 --- 1: field(x) >= 0
    # The boolean `negative` argument is deliberately left untouched.
    kept_sign, dropped_sign = (0.0, 1.0) if negative else (1.0, 0.0)

    entire_fiber_cond: np.ndarray = np.logical_and(
        start_signs == kept_sign, end_signs == kept_sign
    )
    from_start_cond: np.ndarray = np.logical_and(
        start_signs == kept_sign, end_signs == dropped_sign
    )
    from_end_cond: np.ndarray = np.logical_and(
        start_signs == dropped_sign, end_signs == kept_sign
    )

    def _select(cond: np.ndarray) -> np.ndarray:
        # Keep fibers where `cond` holds; zero out all others.
        return np.where(cond.reshape(-1, 1, 1), fibers, np.zeros_like(fibers))

    def _crossing_point(predicate: np.ndarray, fiber: np.ndarray) -> np.ndarray:
        # Crossing point of one fiber as a (1, fiber_dim) array when `predicate`
        # holds, zeros of the same shape otherwise.
        return jax.lax.cond(
            predicate,
            lambda fiber: implicit_differentiation.get_interpolant(
                solver(fiber[None, :, :]), fiber
            )[None, :],
            lambda fiber: np.zeros((1, fiber_dim)),
            operand=fiber,
        )

    def _crossing_points(cond: np.ndarray, selected: np.ndarray) -> np.ndarray:
        # Both branches above yield (1, fiber_dim), so the batched result is
        # (num_fibers, 1, fiber_dim). Index the singleton axis explicitly: an
        # axis-less squeeze would also drop any other axis of size 1.
        return jax.vmap(_crossing_point)(cond, selected)[:, 0]

    # case: keep the entire fiber
    inside_fibers: np.ndarray = _select(entire_fiber_cond)

    # case: clip from the start of the fiber to the intersection point
    from_start: np.ndarray = _select(from_start_cond)
    start_clipped_fibers: np.ndarray = np.stack(
        (from_start[:, 0], _crossing_points(from_start_cond, from_start)),
        axis=1,
    )

    # case: clip from the intersection point to the end of the fiber
    from_end: np.ndarray = _select(from_end_cond)
    end_clipped_fibers: np.ndarray = np.stack(
        (_crossing_points(from_end_cond, from_end), from_end[:, 1]),
        axis=1,
    )

    # aggregate all three groups along axis 0
    clipped_fibers: np.ndarray = np.vstack(
        (inside_fibers, start_clipped_fibers, end_clipped_fibers)
    )
    return clipped_fibers