Source code for fibermc.implicit_differentiation

"""This module contains procedures for implicit differentiation setups in
conjunction with fiber sampling applications.
"""

import functools
from typing import TypeVar

import jax
import jax.numpy as np
from jaxopt.implicit_diff import custom_root

from fibermc.jax_utils import divide00

# Placeholder type variable (named after JAX pytrees) used in this module to
# annotate the `params` argument of the module-level `bisection_solver`.
pytree: type = TypeVar("Pytree")


# @functools.partial(jax.jit, static_argnums=(0, 2))
def bisect(f: callable, fiber: np.ndarray, num_iterations: int = 10) -> float:
    """Search for a sign change of `f` along `fiber` by repeated halving.

    The segment is parameterized as fiber[0] + t * (fiber[1] - fiber[0]) and
    the search is carried out over the scalar t, starting from the bracket
    spanned by t = 0 and t = 1.

    Parameters
    ----------
    f: callable
        scalar-valued function evaluated at points on the segment.
    fiber: np.ndarray
        pair of segment endpoints, read as fiber[0] and fiber[1].
    num_iterations: int
        number of halving steps; they run in a plain Python loop.

    Returns
    -------
    t: float
        first entry of the final bracket (a scalar array). That entry is only
        ever replaced by midpoints at which `f` evaluated negative.
    """
    start, end = fiber[0], fiber[1]

    def value_at(t):
        # f at the point a fraction t of the way along the segment
        return f(start + t * (end - start))

    # orient the bracket so its first entry is the end meant to carry the
    # negative value: flip the unit interval when f is positive at t = 0
    bracket = jax.lax.cond(
        value_at(0.0) > 0.0,
        lambda _: np.array([1.0, 0]),
        lambda _: np.array([0.0, 1.0]),
        None,
    )

    def halve(current):
        negative_end, other_end = current
        middle = (negative_end + other_end) / 2.0
        # a negative midpoint replaces the first entry, otherwise the second
        return jax.lax.cond(
            value_at(middle) < 0.0,
            lambda m: np.array([m, other_end]),
            lambda m: np.array([negative_end, m]),
            middle,
        )

    for _ in range(num_iterations):
        bracket = halve(bracket)

    return bracket[0]


def get_interpolant(alpha: np.ndarray, fiber: np.ndarray) -> np.ndarray:
    """Return the point a fraction `alpha` of the way from fiber[0] to fiber[1].

    Parameters
    ----------
    alpha: np.ndarray
        interpolation parameter; 0 gives fiber[0] and 1 gives fiber[1].
    fiber: np.ndarray
        pair of segment endpoints, read as fiber[0] and fiber[1].

    Returns
    -------
    point: np.ndarray
        fiber[0] + alpha * (fiber[1] - fiber[0]).
    """
    origin = fiber[0]
    direction = fiber[1] - origin
    return origin + alpha * direction


def bind_optimality_condition(f: callable) -> callable:
    """Bind `f` into an optimality condition over the interpolation parameter.

    Parameters
    ----------
    f: callable
        function applied as f(params, z), where z is a point on the fiber.

    Returns
    -------
    optimality_condition: callable
        function of (x, params, fiber) that evaluates `f` at the point on
        `fiber` selected by the interpolation parameter `x`.
    """

    def optimality_condition(
        x: np.ndarray, params: dict, fiber: np.ndarray
    ) -> np.ndarray:
        """Computes the value of the constraint function, which takes on the
        value zero when the constraint is satisfied.

        Parameters
        ----------
        x: np.ndarray
            interpolation parameter selecting the point
            fiber[0] + x * (fiber[1] - fiber[0]); the constraint is satisfied
            when that point lies on the zero set of `f`.
        params: dict
            auxiliary parameters passed as the first argument of `f`.
        fiber: np.ndarray
            line segment along which `x` interpolates. It enters only through
            get_interpolant(x, fiber); no distance-to-fiber term is computed.

        Returns
        -------
        constraint_value: np.ndarray
            the value of f(params, get_interpolant(x, fiber)).
        """
        z: np.ndarray = get_interpolant(x, fiber)
        constraint_value: np.ndarray = f(params, z)
        return constraint_value

    return optimality_condition


def bind_solver(f: callable) -> callable:
    """Build a bisection solver for `f` wrapped with `custom_root`.

    The returned solver is decorated with
    `custom_root(bind_optimality_condition(f))`, so the optimality condition
    bound to `f` is what `custom_root` is given for this solver.

    Parameters
    ----------
    f: callable
        function applied as f(params, z), where z is a point on the fiber.

    Returns
    -------
    bisection_solver: callable
        solver with signature (x_init, params, fiber).
    """

    @custom_root(bind_optimality_condition(f))
    def bisection_solver(
        x_init: np.ndarray, params: dict, fiber: np.ndarray
    ) -> np.ndarray:
        """Computes the interpolation parameter of the (intersection) point
        that is (1) on the line segment represented by `fiber` and (2) at
        which the scalar field takes on the value zero.

        Parameters
        ----------
        x_init: np.ndarray
            unused but necessary for signature of jaxopt.custom_root.
        params: dict
            auxiliary parameters for the bound callable `f`.
        fiber: np.ndarray
            line segment for which the scalar field, when evaluated at each
            endpoint, returns a value with opposing sign.

        Returns
        -------
        x: np.ndarray
            result of `bisect`: the interpolation parameter for which
            fiber[0] + x * (fiber[1] - fiber[0]) approximates a zero of
            f(params, .) on the segment.
        """
        x: np.ndarray = bisect(functools.partial(f, params), fiber)
        return x

    return bisection_solver
def bisection_constraint(
    f: callable, x: np.ndarray, params: tuple, fiber: np.ndarray
) -> np.ndarray:
    """Computes the scalar value of the constraint function, which takes on
    the value zero when the constraint is satisfied.

    Parameters
    ----------
    f: callable
        scalar field applied as f(params, z); its level-set at zero defines
        the constraint.
    x: np.ndarray
        interpolation parameter selecting the point
        fiber[0] + x * (fiber[1] - fiber[0]) at which `f` is evaluated.
    params: tuple
        auxiliary parameters for `f`.
    fiber: np.ndarray
        line segment along which `x` interpolates. It enters only through
        get_interpolant(x, fiber); no distance-to-fiber term is computed.

    Returns
    -------
    field_constraint: np.ndarray
        f(params, get_interpolant(x, fiber)) with singleton dimensions
        squeezed away.
    """
    z: np.ndarray = get_interpolant(x, fiber)
    field_constraint: float = np.squeeze(f(params, z))
    return field_constraint


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def bisection_solver(params: pytree, fiber: np.ndarray, f: callable) -> np.ndarray:
    """Computes the interpolation parameter of the (intersection) point that
    is (1) on the line segment represented by `fiber` and (2) at which the
    scalar field takes on the value zero.

    Parameters
    ----------
    params: pytree
        auxiliary parameters for `f`.
    fiber: np.ndarray
        line segment for which the scalar field `f`, when evaluated at each
        endpoint, returns a value with opposing sign.
    f: callable
        function which should return a real-valued scalar when applied like:
        f(params, x). Declared non-differentiable via nondiff_argnums.

    Returns
    -------
    fixed_point: np.ndarray
        result of `bisect`: the interpolation parameter for which
        fiber[0] + fixed_point * (fiber[1] - fiber[0]) approximates a zero of
        f(params, .) on the segment.

    See also: fibermc.implicit_differentiation.bisection_solver_{forward,
    backward}
    """
    fixed_point: np.ndarray = bisect(functools.partial(f, params), fiber)
    return fixed_point


def bisection_solver_forward(
    params: tuple, fiber: np.ndarray, f: callable
) -> tuple:
    """Forward rule of the custom vjp: runs `bisection_solver` and records the
    residuals needed by the backward pass.

    Parameters
    ----------
    params: tuple
        auxiliary parameters for `f`.
    fiber: np.ndarray
        line segment for which `f`, when evaluated at each endpoint, returns
        a value with opposing sign.
    f: callable
        function which should return a real-valued scalar when applied like:
        f(params, x).

    Returns
    -------
    payload: tuple
        first element is the solver output (the interpolation parameter);
        second element is the tuple of residuals (params, fixed_point, fiber)
        used in the backward pass to compute the vjp.
    """
    # determine the intersection parameter
    fixed_point: np.ndarray = bisection_solver(params, fiber, f)

    # collect the residuals to be used in the backward pass
    residuals: tuple = (params, fixed_point, fiber)
    return (fixed_point, residuals)


def bisection_solver_backward(
    f: callable, residuals: tuple, incoming_gradient: np.ndarray
) -> tuple:
    """Computes the vector-Jacobian product associated with the bisection
    solver procedure, using implicit differentiation.

    Writing F(x, params) = f(params, get_interpolant(x, fiber)), the solver
    output x satisfies F = 0, so the returned cotangent is the vjp of F with
    respect to params applied to -incoming_gradient / (dF/dx).

    Parameters
    ----------
    f: callable
        function which should return a real-valued scalar when applied like:
        f(params, x). It comes first because it is the non-differentiable
        argument of `bisection_solver`.
    residuals: tuple
        (params, fixed_point, fiber) as saved by `bisection_solver_forward`.
    incoming_gradient: np.ndarray
        gradient signal arising from some downstream (from the perspective of
        the forward pass, that is) computation; it is chained back upstream
        of wherever `bisection_solver` was called.

    Returns
    -------
    cotangents: tuple
        (cotangent for params, None); nothing is propagated to `fiber`.
    """
    # unpack residuals
    params, fixed_point, fiber = residuals

    # f's univariate analogues
    def f_params(_params):
        return bisection_constraint(f, fixed_point, _params, fiber)

    def f_spatial(_x):
        return bisection_constraint(f, _x, params, fiber)

    # partial vjp w.r.t. params
    _, vjp_params = jax.vjp(f_params, params)

    # derivative of the constraint w.r.t. the interpolation parameter
    jacobian_f: np.ndarray = jax.jacobian(f_spatial)(fixed_point)

    # solve for the intermediate vjp
    A: np.ndarray = jacobian_f.T
    b: np.ndarray = incoming_gradient
    # NOTE(review): divide00 is presumably a division that maps 0/0 to 0;
    # confirm against fibermc.jax_utils
    intermediate_vjp: np.ndarray = -1.0 * divide00(b, A)

    final_vjp: tuple = vjp_params(intermediate_vjp)
    return (final_vjp[0], None)


# --- solver custom vjp binding
bisection_solver.defvjp(bisection_solver_forward, bisection_solver_backward)