Source code for fibermc.graphics_utils

import functools
from typing import List, Tuple

from jax import jit
import jax.numpy as np
import numpy as static_np


def make_batch_idxs(num_elements: int, batch_size: int) -> list:
    """Constructs a list of batches of size `batch_size`. Each
    batch is itself an 1D integer-valued np.ndarray of length
    `batch_size`.

    Parameters
    ----------
    num_elements: int
        total number of elements to construct the batches from.
    batch_size: int
        integer-valued size of each batch to construct.

    Examples
    --------
    >>> make_batch_idxs(9, 3)
    >>> [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    >>> make_batch_idxs(9, 2)
    >> [[0, 1], [2, 3], [4, 5], [6, 7], [8]]

    Returns
    -------
    batches: List[np.ndarray]
        list of integer-valued `np.np.ndarray`s comprising each 'batch'.
    """
    assert (
        batch_size <= num_elements
    ), f"batch size {batch_size} cannot be larger than the total number of inputs {num_elements}"
    assert batch_size > 0, f"batch size {batch_size} must be a positive integer "

    # --- construct the list of 'full' (size `batch_size`) batches
    num_full_batches: int = num_elements // batch_size
    full_batches: List[np.ndarray] = [
        np.arange(i * batch_size, (i + 1) * batch_size) for i in range(num_full_batches)
    ]

    # --- (optionally) add a last non-'full' batch (see example 2 in the docstring)
    batches: List[np.ndarray] = (
        full_batches
        if ((num_elements % batch_size) == 0)
        else (full_batches + [np.arange(num_full_batches * batch_size, num_elements)])
    )
    return batches


@functools.partial(jit, static_argnums=(1,))
def translation_from_id(tile_id: np.ndarray, tile_dimension: int) -> np.ndarray:
    """Computes the translation vector associated with the provided tile
    identifier and the dimension of each tile.

    Parameters
    ----------
    tile_id: np.ndarray
        np.ndarray of length 2 which contains a row and column index, respectively,
        to identify the location of the tile of interest.
    tile_dimension: int
        dimension of each tile (assumed to be equivalent along both rows and columns).

    Returns
    -------
    translation: np.ndarray
        2D array representing the translation vector applied to an object to place
        it in the bounds of the tile of interest.
    """
    tile_row, tile_column = tile_id
    translation: np.ndarray = np.array(
        [tile_column * tile_dimension, tile_row * tile_dimension]
    )
    return translation


def translate_fibers(fibers: np.ndarray, translation: np.ndarray) -> np.ndarray:
    """TODO don't need this method?"""
    fibers: np.ndarray = fibers + translation
    return fibers


def compute_tile_ids(target_shape: np.ndarray, tile_dimension: int) -> np.ndarray:
    """compute_tile_ids.

    Parameters
    ----------
    target_shape: np.ndarray
        shape of the target array (i.e., the size of the array to be rendered downstream
        from this utility).
    tile_dimension: int
        dimension of each tile (assumed to be equivalent along both rows and columns).

    Returns
    -------
    tile_ids: np.ndarray
        np.np.ndarray of shape (target_rows * target_columns, 2) containing each 2D
        tile identifier along the second axis (i.e., axis=1).
    """
    # --- ensure the target shape is divided by the tile dimension
    target_rows, target_columns = target_shape[:2]
    # assert divides(target_rows, tile_dimension), f"target number of rows {target_rows} does not divide tile dimension {tile_dimension}"
    # assert divides(target_columns, tile_dimension), f"target number of columns {target_columns} does not divide tile dimension {tile_dimension}"

    num_tiles_x: int = target_columns // tile_dimension
    num_tiles_y: int = target_rows // tile_dimension

    tile_ids: np.ndarray = np.dstack(
        np.meshgrid(np.arange(num_tiles_x), np.arange(num_tiles_y))
    ).reshape(-1, 2)
    return tile_ids


def get_bounds_from_hull(pixel_hull: np.ndarray) -> np.ndarray:
    return np.array(
        [
            pixel_hull[:, 0].min(),
            pixel_hull[:, 1].min(),
            pixel_hull[:, 0].max(),
            pixel_hull[:, 1].max(),
        ]
    )


def compute_background_area(image: np.ndarray, mask: np.ndarray) -> float:
    """Compute the total area of the background region associated with an image,
    when the image is projected onto the unit square [0, 1] x [0, 1].

    Parameters
    ----------
    image: np.ndarray
        array containing an image; currently assumed to be square (i.e. image.shape[0] == image.shape[1]).
    mask: np.ndarray
        integer or boolean-valued array with a 1 (True) at indices where the image
        is background and a 0 (False) at indices where the image is foreground.

    Example
    -------
    >>> compute_background_area(np.zeros((2, 2)), np.eye(2))
    >>> 0.5

    Returns
    -------
    background_area: float
        proportion of the image that is background (computed as the product of the area of a
        pixel and the number of background pixels).
    """
    num_rows, num_columns, num_channels = image.shape
    assert (
        num_rows == num_columns
    ), f"image with {num_rows} rows and {num_columns} columns is not square (non-square images are not supported)."

    pixel_area: float = (1 / num_rows) ** 2
    num_background_pixels: int = np.sum(mask)
    background_area: float = pixel_area * num_background_pixels
    return background_area


def compute_background_pixel_hulls(image: np.ndarray, mask: np.ndarray) -> list:
    """Instantiates convex hulls (represented as arrays of line segments) for each pixel
    contained in the background of the image, which is identified with a value of `1.`` in
    the `mask` array at the corresponding location.

    Parameters
    ----------
    image: np.ndarray
        array containing an image; currently assumed to be square (i.e. image.shape[0] == image.shape[1]).
    mask: np.ndarray
        integer or boolean-valued array with a 1 (True) at indices where the image
        is background and a 0 (False) at indices where the image is foreground.

    Returns
    -------
    hulls: List[np.ndarray]
        list of convex hulls for each background pixel (that is, the length of this list
        is equal to np.sum(mask)); each hull is represented as a (4, 2) array of line segments
        oriented counter-clockwise from the bottom-right hand corner.
    """
    num_rows, num_columns, _ = image.shape
    assert (
        num_rows == num_columns
    ), f"image with {num_rows} rows and {num_columns} columns is not square (non-square images are not supported)."

    pixel_side_length: float = 1 / num_rows
    background_indices: np.ndarray = np.array(np.nonzero(mask)).T
    hulls: List[np.ndarray] = []

    for x, y in background_indices:
        hull: np.ndarray = np.array(
            [
                [
                    (x + 1) * pixel_side_length,
                    y * pixel_side_length,
                ],  # bottom right-hand corner
                [
                    x * pixel_side_length,
                    y * pixel_side_length,
                ],  # bottom left-hand corner
                [
                    x * pixel_side_length,
                    (y + 1) * pixel_side_length,
                ],  # upper left-hand corner
                [
                    (x + 1) * pixel_side_length,
                    (y + 1) * pixel_side_length,
                ],  # upper right-hand corner
            ]
        )
        hull_counter_clockwise: np.ndarray = hull[::-1]
        hulls.append(hull_counter_clockwise)

    return hulls


[docs] def create_pixel_hull( pixel_coordinate: np.ndarray, ccw: bool = True, dtype=np.float32 ) -> np.ndarray: """Create a pixel hull (i.e., a (4, 2) array with the vertices associated with a square convex hull) given a coordinate representing its location. Parameters ---------- pixel_coordinate: np.ndarray a length 2 array encoding the location (in terms of row/column) of the pixel for which to create a hull. ccw: bool whether to orient the hull counter-clockwise (default: True). dtype: type numeric type for the values comprising the resultant pixel hull (default: np.float32). Returns ------- pixel_hull: np.ndarray array containing 4 vertices of length 2 comprising the convex hull of the pixel. """ pixel_hull_clockwise: np.ndarray = np.array( [ [pixel_coordinate[0] + 1, pixel_coordinate[1]], [pixel_coordinate[0], pixel_coordinate[1]], [pixel_coordinate[0], pixel_coordinate[1] + 1], [pixel_coordinate[0] + 1, pixel_coordinate[1] + 1], ], dtype=dtype, ) if ccw: pixel_hull: np.ndarray = pixel_hull_clockwise[::-1] else: pixel_hull: np.ndarray = pixel_hull_clockwise return pixel_hull
def random_image(size: Tuple[int]) -> np.ndarray: """Generate a psuedo-random image (the model is uniform over the color-space) of size `size`. Parameters ---------- size: Tuple[int] tuple of integers representing the desired generated image shape. Returns ------- image: np.ndarray floating-point valued np.ndarray with values ranging between 0.0 and 1.0 sampled according to a uniform distribution. """ image: np.ndarray = static_np.random.uniform(size=(size)) return image @functools.partial(jit, static_argnums=(0, 1)) def get_pixel_coordinates(width: int, height: int, bounds: np.ndarray) -> np.ndarray: """Returns an array containing the elements of the Cartesian product of the sets of integers comprising the range of rows (0, 1, ..., size[0]-1) and columns (0, 1, ..., size[1]-1) given. Parameters ---------- bounds: ndaray np.ndarray of integers representing the associated image shape. Example ------- >>> get_pixel_coordinates(np.array([0, 0, 2, 2])) >>> array([0, 0], [0, 1], [1, 0], [1, 1]) Returns ------- pixel_coordinates: np.ndarray array of shape (np.prod(size), 2) containing all the length 2 pixel coordinates (row, column) associated with an image of size `size`. """ min_x, min_y, _, _ = bounds[:4] x_coords: np.ndarray = np.arange(width) + min_x y_coords: np.ndarray = np.arange(height) + min_y coordinates: np.ndarray = np.dstack(np.meshgrid(x_coords, y_coords)).reshape(-1, 2) return coordinates def compute_vertex_mask(vertices: np.ndarray) -> np.ndarray: """Constructs a binary mask array to prevent updates on 'boundary' vertices present in the input `vertices` array. Parameters ---------- vertices: np.ndarray (N, 2) array of vertex positions. Returns ------- vertex_mask: np.ndarray binary integer-valued {0, 1} array of length (N, 1) encoding whether, for each of the N vertices in the input array, whether that vertex should be held fixed (because it is at the boundary). """ # --- extract coordinate components vertex_x_coordinates: np.ndarray = vertices[:, 0] vertex_y_coordinates: np.ndarray = vertices[:, 1] # --- determine boundary coordinates min_x_coordinate: np.ndarray = vertices[:, 0].min() max_x_coordinate: np.ndarray = vertices[:, 0].max() min_y_coordinate: np.ndarray = vertices[:, 1].min() max_y_coordinate: np.ndarray = vertices[:, 1].max() # --- construct the mask x_mask: np.ndarray = static_np.logical_or( vertex_x_coordinates == min_x_coordinate, vertex_x_coordinates == max_x_coordinate, ) y_mask: np.ndarray = static_np.logical_or( vertex_y_coordinates == min_y_coordinate, vertex_y_coordinates == max_y_coordinate, ) full_mask: np.ndarray = ~static_np.logical_or(x_mask, y_mask) vertex_mask: np.ndarray = full_mask.astype(int).reshape(-1, 1) return vertex_mask def apply_vertex_mask(gradient: tuple, mask: np.ndarray) -> tuple: """TODO need this function?""" return (gradient[0], gradient[1] * mask) def apply_vertex_clip(bounds: tuple, params: tuple) -> tuple: """TODO don't take params, just take vertices?""" vertices: np.ndarray = params[1] min_x, min_y, max_x, max_y = bounds vertex_x: np.ndarray = np.clip(vertices[:, 0], a_min=min_x, a_max=max_x) vertex_y: np.ndarray = np.clip(vertices[:, 1], a_min=min_y, a_max=max_y) clipped_vertices: np.ndarray = np.stack((vertex_x, vertex_y)).T # TODO this is hacky? if len(params) == 2: return (params[0], clipped_vertices) elif len(params) == 3: return (params[0], clipped_vertices, params[2]) def _apply_vertex_clip(bounds: tuple, vertices: np.ndarray) -> tuple: """TODO don't take params, just take vertices?""" min_x, min_y, max_x, max_y = bounds vertex_x: np.ndarray = np.clip(vertices[:, 0], a_min=min_x, a_max=max_x) vertex_y: np.ndarray = np.clip(vertices[:, 1], a_min=min_y, a_max=max_y) clipped_vertices: np.ndarray = np.stack((vertex_x, vertex_y)).T return clipped_vertices def bounds_from_tile_id(tile_id: np.ndarray, tile_dimension: int) -> np.ndarray: """Computes an np.ndarray parameterizing a rectangular boundary, given a tile identifier and a dimension (measured in pixels) of each tile. """ tile_x, tile_y = tile_id return np.array( [ tile_x * tile_dimension, tile_y * tile_dimension, (tile_x + 1) * tile_dimension, (tile_y + 1) * tile_dimension, ] ) def get_device_batches(config) -> np.ndarray: # --- ensure the requested devices are available num_devices_available: int = get_num_devices() requested_devices: int = config.compute.multi_gpu.num_devices assert ( requested_devices <= num_devices_available ), f"requested {requested_devices} devices but only {num_devices_available} available." num_devices: int = requested_devices # --- if there are fewer tiles than devices, only use one device per tile num_tiles: int = config.tile_ids.shape[0] num_devices: int = min(num_tiles, num_devices) args.log.info(f"using {num_devices} of {num_devices_available} available devices") # --- aggregate batch indices of tiles for launching on devices TODO generalize assert ( num_tiles % num_devices == 0 ), "number of devices does not equally divide number of tiles" batch_idxs: np.ndarray = np.arange(num_tiles).reshape( num_devices, int((num_tiles / num_devices)) ) device_batches: np.ndarray = config.tile_ids[batch_idxs, :] return device_batches def get_tile( target_image: np.ndarray, tile_dimension: int, tile_id: np.ndarray ) -> np.ndarray: x_start, y_start = tile_id target_shape: np.ndarray = np.array(target_image.shape[:2]) num_pixels: int = np.prod(target_shape) pixels_per_tile: int = tile_dimension**2 tile: np.ndarray = np.zeros((tile_dimension, tile_dimension, 3)).astype(np.float32) target: np.ndarray = static_np.array(target_image) tile_x_start: int = x_start * tile_dimension tile_y_start: int = y_start * tile_dimension tile_x_end: int = tile_x_start + tile_dimension tile_y_end: int = tile_y_start + tile_dimension tile = np.array(target[tile_x_start:tile_x_end, tile_y_start:tile_y_end, :]) return tile def compute_tile_power( target_image: np.ndarray, tile_dimension: int, tile_id: np.ndarray ) -> float: tile: np.ndarray = get_tile(target_image, tile_dimension, tile_id) power_spectrum: np.ndarray = (np.abs(np.fft.fftn(tile)) ** 2).ravel() n: int = power_spectrum.size // 2 power: float = power_spectrum[:n].dot(np.arange(n)) return power def compute_tilewise_sparsity( target_image: np.ndarray, tile_dimension: int ) -> np.ndarray: tile_ids: np.ndarray = compute_tile_ids(target_image.shape, tile_dimension) tile_powers: np.ndarray = np.array( [ functools.partial(compute_tile_power, target_image, tile_dimension)(tile_id) for tile_id in tile_ids ] ) tile_powers: np.ndarray = tile_powers / tile_powers.sum() return 1.0 - tile_powers def rescale_sparsity( tilewise_sparsity: np.ndarray, min_sparsity: float, max_sparsity: float ) -> np.ndarray: min_power: float = tilewise_sparsity.min() max_power: float = tilewise_sparsity.max() rescaled: np.ndarray = (max_sparsity - min_sparsity) * ( (tilewise_sparsity - min_power) / (max_power - min_power) ) + min_sparsity return rescaled