Source code for arraybridge.stack_utils

"""
Stack utilities module for OpenHCS.

This module provides functions for stacking 2D slices into a 3D array
and unstacking a 3D array into 2D slices, with explicit memory type handling.

This module enforces Clause 278 — Mandatory 3D Output Enforcement:
All functions must return a 3D array of shape [Z, Y, X], even when operating
on a single 2D slice. No logic may check, coerce, or infer rank at unstack time.
"""

import logging
from typing import Any

from arraybridge.converters import detect_memory_type
from arraybridge.framework_config import _FRAMEWORK_CONFIG
from arraybridge.types import GPU_MEMORY_TYPES, MemoryType
from arraybridge.utils import optional_import

logger = logging.getLogger(name=__name__)

# 🔍 MEMORY CONVERSION LOGGING: one debug record emitted at import time, useful for
# confirming that this module's logger is wired into the application's logging setup.
logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled")


def _is_2d(data: Any) -> bool:
    """
    Check if data is a 2D array.

    Args:
        data: Data to check

    Returns:
        True if data is 2D, False otherwise
    """
    # Check if data has a shape attribute
    if not hasattr(data, "shape"):
        return False

    # Check if shape has length 2
    return len(data.shape) == 2


def _is_3d(data: Any) -> bool:
    """
    Check if data is a 3D array.

    Args:
        data: Data to check

    Returns:
        True if data is 3D, False otherwise
    """
    # Check if data has a shape attribute
    if not hasattr(data, "shape"):
        return False

    # Check if shape has length 3
    return len(data.shape) == 3


def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None:
    """
    Validate the GPU device ID when the target memory type lives on a GPU.

    Memory types that are not GPU-backed are accepted without looking at ``gpu_id``.

    Args:
        memory_type: Memory type string, compared against the ``.value`` of each
            entry in ``GPU_MEMORY_TYPES``
        gpu_id: GPU device index requested by the caller

    Raises:
        ValueError: If ``memory_type`` is a GPU memory type and ``gpu_id`` is negative
    """
    gpu_type_values = {gpu_type.value for gpu_type in GPU_MEMORY_TYPES}
    if memory_type not in gpu_type_values:
        # CPU-side memory types have no device index to validate.
        return
    if gpu_id < 0:
        raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.")


# NOTE: Per-framework allocation expressions live in framework_config.py
# (_FRAMEWORK_CONFIG), replacing the previously scattered _ALLOCATION_OPS dict.


def _allocate_stack_array(
    memory_type: str, stack_shape: tuple, first_slice: Any, gpu_id: int
) -> Any:
    """
    Allocate an empty 3D array for stacking slices, driven by framework config.

    The entry ``_FRAMEWORK_CONFIG[MemoryType(memory_type)]`` supplies the
    allocation as Python source strings:

    - ``"allocate_stack"``: expression evaluated with ``eval`` to produce the
      array. None means the framework stacks via a custom handler, and this
      function returns None.
    - ``"allocate_context"`` (optional): expression evaluated to a context
      manager that wraps the allocation.

    ``eval`` is called without explicit namespaces, so these expressions
    resolve names from this function's locals and the module globals. The
    locals ``np``, ``cupy``, ``torch``, ``tf``, ``jnp`` and ``dtype`` exist only
    for that purpose (hence the ``noqa: F841`` markers). Renaming or removing
    them, or other locals such as ``mod``, ``stack_shape``, ``gpu_id`` or
    ``sample_converted``, may break the configured expressions.

    Args:
        memory_type: The target memory type (a ``MemoryType`` value string)
        stack_shape: The shape of the stack (Z, Y, X)
        first_slice: The first slice. Its dtype, or the dtype of a copy
            converted to ``memory_type`` when the config asks for dtype
            conversion, is exposed to the allocation expression as ``dtype``.
        gpu_id: The GPU device ID. It is forwarded to ``convert_memory`` and is
            visible to the configured expressions.

    Returns:
        The pre-allocated array, or None when the framework's
        ``"allocate_stack"`` entry is None (pyclesperanto)

    Raises:
        ValueError: If the framework module cannot be imported
            (``optional_import`` returned None)
        ValueError: Presumably, if ``memory_type`` is not a valid
            ``MemoryType`` value (standard ``Enum`` lookup) — confirm
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    allocate_expr = config["allocate_stack"]

    # Check if allocation is None (pyclesperanto uses custom stacking)
    if allocate_expr is None:
        return None

    # Import the module
    mod = optional_import(mem_type.value)
    if mod is None:
        raise ValueError(f"{mem_type.value} is required for memory type {memory_type}")

    # Handle dtype conversion if needed.
    # "needs_dtype_conversion" is either a plain flag or a callable that decides
    # per input, given the first slice and the memory-type detector.
    needs_conversion = config["needs_dtype_conversion"]
    if callable(needs_conversion):
        # It's a callable that determines if conversion is needed
        needs_conversion = needs_conversion(first_slice, detect_memory_type)

    # Initialize variables for eval expressions
    sample_converted = None
    if needs_conversion:
        # Convert the first slice into the target memory type. Its dtype, rather
        # than the source slice's dtype, then becomes ``dtype`` below.
        from arraybridge.converters import convert_memory

        first_slice_source_type = detect_memory_type(first_slice)
        sample_converted = convert_memory(
            data=first_slice,
            source_type=first_slice_source_type,
            target_type=memory_type,
            gpu_id=gpu_id,
        )

    # Set up local variables for eval.
    # Each framework handle is bound only for the matching memory type; the
    # others stay None. numpy is always requested. optional_import appears to
    # return None when a package is missing (see the ``mod is None`` check above).
    np = optional_import("numpy")  # noqa: F841
    cupy = mod if mem_type == MemoryType.CUPY else None  # noqa: F841
    torch = mod if mem_type == MemoryType.TORCH else None  # noqa: F841
    tf = mod if mem_type == MemoryType.TENSORFLOW else None  # noqa: F841
    jnp = optional_import("jax.numpy") if mem_type == MemoryType.JAX else None  # noqa: F841
    # dtype is used in allocate_expr eval below (for numpy framework)
    dtype = (  # noqa: F841
        sample_converted.dtype
        if sample_converted is not None
        else (first_slice.dtype if hasattr(first_slice, "dtype") else None)
    )

    # Execute allocation with context if needed.
    # NOTE(review): eval executes strings taken from _FRAMEWORK_CONFIG, which is
    # expected to hold in-package constants from arraybridge.framework_config.
    # Never let untrusted text reach that config.
    allocate_context = config.get("allocate_context")
    if allocate_context:
        context = eval(allocate_context)
        with context:
            return eval(allocate_expr)
    else:
        return eval(allocate_expr)


def stack_slices(slices: list[Any], memory_type: str, gpu_id: int) -> Any:
    """
    Stack 2D slices into a 3D array with the specified memory type.

    STRICT VALIDATION: every slice must be a 2D array, and all slices must share
    the same shape. No automatic handling of improper inputs.

    Args:
        slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.)
        memory_type: The memory type to use for the stacked array (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)

    Returns:
        A 3D array with the specified memory type of shape [Z, Y, X]

    Raises:
        ValueError: If memory_type is not supported or slices is empty
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If slices are not 2D arrays
        ValueError: If slices do not all have the same shape
        MemoryConversionError: If conversion fails
    """
    if not slices:
        raise ValueError("Cannot stack empty list of slices")

    # Verify all slices are 2D
    for i, slice_data in enumerate(slices):
        if not _is_2d(slice_data):
            raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.")

    # Verify all slices share the first slice's shape. Without this check a
    # mismatched slice (e.g. shape (1, X) against (Y, X)) can be silently
    # broadcast into the pre-allocated stack by the ``result[i] = ...`` assignment.
    first_slice = slices[0]
    expected_shape = tuple(first_slice.shape)
    for i, slice_data in enumerate(slices[1:], start=1):
        slice_shape = tuple(slice_data.shape)
        if slice_shape != expected_shape:
            raise ValueError(
                f"Slice at index {i} has shape {slice_shape}, expected {expected_shape}. "
                "All slices must have the same shape."
            )

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]

    # Frameworks with a custom stack handler (pyclesperanto) build the whole
    # stack themselves, so no pre-allocation (or sample conversion) is needed.
    stack_handler = config.get("stack_handler")
    if stack_handler:
        mod = optional_import(mem_type.value)
        return stack_handler(slices, memory_type, gpu_id, mod)

    # Pre-allocate the final 3D array to avoid an intermediate list and a final stack operation
    stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1])
    result = _allocate_stack_array(memory_type, stack_shape, first_slice, gpu_id)

    # Looked up once: optional framework-specific assignment (JAX arrays are immutable,
    # so the handler returns an updated array instead of mutating in place).
    assign_handler = config.get("assign_slice")

    # arraybridge.converters is already loaded at module import; this just binds the name.
    from arraybridge.converters import convert_memory

    conversion_count = 0
    for i, slice_data in enumerate(slices):
        source_type = detect_memory_type(slice_data)

        if source_type == memory_type:
            converted_data = slice_data
        else:
            conversion_count += 1
            converted_data = convert_memory(
                data=slice_data, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
            )

        if assign_handler:
            # Custom assignment (JAX immutability)
            result = assign_handler(result, i, converted_data)
        else:
            # Standard in-place assignment
            result[i] = converted_data

    # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen;
    # silent success for no-conversion cases to reduce log pollution.
    if conversion_count > 0:
        logger.debug(
            f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} "
            f"slices to {memory_type}"
        )

    return result
def unstack_slices(
    array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True
) -> list[Any]:
    """
    Split a 3D array into 2D slices along axis 0 and convert to the specified memory type.

    STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs.
    The rank check runs before memory-type detection, so any non-3D input fails with
    the documented ValueError.

    Args:
        array: 3D array to split - MUST BE 3D
        memory_type: The memory type to use for the output slices (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)
        validate_slices: If True, validates that each extracted slice is 2D

    Returns:
        List of 2D slices in the specified memory type (empty if axis 0 has length 0)

    Raises:
        ValueError: If array is not 3D
        ValueError: If validate_slices is True and any extracted slice is not 2D
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If memory_type is not supported
        MemoryConversionError: If conversion fails
    """
    # Verify the array is 3D - fail loudly before any type detection or conversion
    if not _is_3d(array):
        raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    source_type = detect_memory_type(array)
    needs_conversion = source_type != memory_type

    # Convert the whole array once, then slice (cheaper than converting each slice)
    if needs_conversion:
        from arraybridge.converters import convert_memory

        logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")
        array = convert_memory(
            data=array, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
        )

    # Extract slices along axis 0 (already in the target memory type)
    slices = [array[i] for i in range(array.shape[0])]

    # Validate that all extracted slices are 2D if requested
    if validate_slices:
        for i, slice_data in enumerate(slices):
            if not _is_2d(slice_data):
                raise ValueError(
                    f"Extracted slice at index {i} is not 2D. "
                    f"This indicates a malformed 3D array."
                )

    # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues.
    # The empty-array warning is independent of whether a conversion happened.
    if needs_conversion:
        logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices")
    if not slices:
        logger.warning("🔄 UNSTACK_SLICES: No slices extracted (empty array)")

    return slices