Source code for arraybridge.decorators

"""
Memory type declaration decorators for OpenHCS.

This module provides decorators for explicitly declaring the memory interface
of pure functions, enforcing Clause 106-A (Declared Memory Types) and supporting
memory-type-aware dispatching and orchestration.

These decorators annotate functions with input_memory_type and output_memory_type
attributes and provide automatic thread-local CUDA stream management for GPU
frameworks to enable true parallelization across multiple threads.

REFACTORED: Uses enum-driven metaprogramming to eliminate 79% of code duplication.
"""

import functools
import inspect
import logging
import threading
from enum import Enum
from typing import Any, Callable, Optional, TypeVar

import numpy as np

from arraybridge.dtype_scaling import SCALING_FUNCTIONS
from arraybridge.framework_ops import _FRAMEWORK_OPS
from arraybridge.oom_recovery import _execute_with_oom_recovery
from arraybridge.slice_processing import process_slices
from arraybridge.types import MemoryType
from arraybridge.utils import optional_import

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


[docs] class DtypeConversion(Enum): """Data type conversion modes for all memory type functions.""" PRESERVE_INPUT = "preserve" # Keep input dtype (default) NATIVE_OUTPUT = "native" # Use framework's native output UINT8 = "uint8" # Force uint8 (0-255 range) UINT16 = "uint16" # Force uint16 (microscopy standard) INT16 = "int16" # Force int16 (signed microscopy data) INT32 = "int32" # Force int32 (large integer values) FLOAT32 = "float32" # Force float32 (GPU performance) FLOAT64 = "float64" # Force float64 (maximum precision) @property def numpy_dtype(self): """Get the corresponding numpy dtype.""" dtype_map = { self.UINT8: np.uint8, self.UINT16: np.uint16, self.INT16: np.int16, self.INT32: np.int32, self.FLOAT32: np.float32, self.FLOAT64: np.float64, } return dtype_map.get(self, None)
# Thread-local cache for lazy-loaded GPU frameworks _gpu_frameworks_cache = {} def _create_lazy_getter(framework_name: str): """Factory function that creates a lazy import getter for a framework.""" def getter(): if framework_name not in _gpu_frameworks_cache: _gpu_frameworks_cache[framework_name] = optional_import(framework_name) if _gpu_frameworks_cache[framework_name] is not None: logger.debug( f"🔧 Lazy imported {framework_name} in thread " f"{threading.current_thread().name}" ) return _gpu_frameworks_cache[framework_name] return getter # Auto-generate lazy getters for all GPU frameworks for mem_type in MemoryType: ops = _FRAMEWORK_OPS[mem_type] if ops["lazy_getter"] is not None: getter_func = _create_lazy_getter(ops["import_name"]) globals()[f"_get_{ops['import_name']}"] = getter_func # Thread-local storage for GPU streams and contexts _thread_gpu_contexts = threading.local() class ThreadGPUContext: """Thread-local GPU context manager for CUDA streams.""" def __init__(self): self.cupy_stream = None self.torch_stream = None self.tensorflow_device = None self.jax_device = None def get_cupy_stream(self): """Get or create thread-local CuPy stream.""" if self.cupy_stream is None: cupy = globals().get("_get_cupy", lambda: None)() # noqa: F821 if cupy is not None and hasattr(cupy, "cuda"): self.cupy_stream = cupy.cuda.Stream() logger.debug(f"🔧 Created CuPy stream for thread {threading.current_thread().name}") return self.cupy_stream def get_torch_stream(self): """Get or create thread-local PyTorch stream.""" if self.torch_stream is None: torch = globals().get("_get_torch", lambda: None)() # noqa: F821 if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available(): self.torch_stream = torch.cuda.Stream() logger.debug( f"🔧 Created PyTorch stream for thread " f"{threading.current_thread().name}" ) return self.torch_stream def _get_thread_gpu_context(): """Get or create thread-local GPU context.""" if not hasattr(_thread_gpu_contexts, "context"): _thread_gpu_contexts.context = ThreadGPUContext() return _thread_gpu_contexts.context
[docs] def memory_types( input_type: str, output_type: str, contract: Optional[Callable[[Any], bool]] = None ) -> Callable[[F], F]: """ Base decorator for declaring memory types of a function. This is the foundation decorator that all memory-type-specific decorators build upon. """ def decorator(func: F) -> F: @functools.wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) # Apply contract validation if provided if contract is not None and not contract(result): raise ValueError(f"Function {func.__name__} violated its output contract") return result # Attach memory type metadata wrapper.input_memory_type = input_type wrapper.output_memory_type = output_type return wrapper return decorator
def _create_dtype_wrapper(func, mem_type: MemoryType, func_name: str): """ Auto-generate dtype preservation wrapper for any memory type. This single function replaces 6 nearly-identical dtype wrapper functions. """ _FRAMEWORK_OPS[mem_type] scale_func = SCALING_FUNCTIONS[mem_type.value] @functools.wraps(func) def dtype_wrapper(image, *args, slice_by_slice: bool = False, **kwargs): # Get dtype_config from kwargs (injected by OpenHCS) dtype_config = kwargs.pop("dtype_config") # Extract the conversion mode from the config dataclass dtype_conversion = dtype_config.default_dtype_conversion # Store original dtype original_dtype = getattr(image, "dtype", None) # Handle slice_by_slice processing for 3D arrays if slice_by_slice and hasattr(image, "ndim") and image.ndim == 3: result = process_slices(image, func, args, kwargs) else: # Call the original function normally result = func(image, *args, **kwargs) def _apply_dtype_conversion(array): if dtype_conversion is None or not hasattr(array, "dtype"): return array if dtype_conversion == DtypeConversion.PRESERVE_INPUT: # Preserve input dtype if original_dtype is not None and array.dtype != original_dtype: return scale_func(array, original_dtype) return array if dtype_conversion == DtypeConversion.NATIVE_OUTPUT: # Return framework's native output dtype return array # Force specific dtype target_dtype = dtype_conversion.numpy_dtype if target_dtype is not None: return scale_func(array, target_dtype) return array try: # Apply dtype conversion to the main output if isinstance(result, tuple): if not result: return result converted_main = _apply_dtype_conversion(result[0]) return (converted_main, *result[1:]) return _apply_dtype_conversion(result) except Exception as e: logger.error( f"Error in {mem_type.value} dtype/slice preserving wrapper " f"for {func_name}: {e}" ) # Return unmodified result on conversion errors return result # Update function signature to include new parameters try: original_sig = inspect.signature(func) new_params = list(original_sig.parameters.values()) # Check if parameters already exist param_names = [p.name for p in new_params] # Add slice_by_slice parameter if "slice_by_slice" not in param_names: slice_param = inspect.Parameter( "slice_by_slice", inspect.Parameter.KEYWORD_ONLY, default=False, annotation=bool ) new_params.append(slice_param) # Create new signature new_sig = original_sig.replace(parameters=new_params) dtype_wrapper.__signature__ = new_sig # Update docstring if dtype_wrapper.__doc__: dtype_wrapper.__doc__ += ( f"\n\n Additional Parameters " f"(added by {mem_type.value} decorator):\n" ) dtype_wrapper.__doc__ += ( " slice_by_slice (bool, optional): " "Process 3D arrays slice-by-slice.\n" ) dtype_wrapper.__doc__ += ( " Defaults to False. " "Prevents cross-slice contamination.\n" ) except Exception as e: logger.warning(f"Could not update signature for {func_name}: {e}") return dtype_wrapper def _create_gpu_wrapper(func, mem_type: MemoryType, oom_recovery: bool): """ Auto-generate GPU stream/device wrapper for any GPU memory type. This function creates the GPU-specific wrapper with stream management and OOM recovery. """ ops = _FRAMEWORK_OPS[mem_type] framework_name = ops["import_name"] lazy_getter = globals().get(ops["lazy_getter"]) @functools.wraps(func) def gpu_wrapper(*args, **kwargs): framework = lazy_getter() # Check if GPU is available for this framework if framework is not None: gpu_check_expr = ops["gpu_check"].format(mod=framework_name) try: gpu_available = eval(gpu_check_expr, {framework_name: framework}) except Exception: gpu_available = False if gpu_available: # Get thread-local context ctx = _get_thread_gpu_context() # Get stream if framework supports it stream = None if mem_type == MemoryType.CUPY: stream = ctx.get_cupy_stream() elif mem_type == MemoryType.TORCH: stream = ctx.get_torch_stream() # Define execution function that captures args/kwargs def execute_with_stream(): if stream is not None: with stream: return func(*args, **kwargs) else: return func(*args, **kwargs) # Execute with OOM recovery if enabled if oom_recovery and ops["has_oom_recovery"]: return _execute_with_oom_recovery(execute_with_stream, mem_type.value) else: return execute_with_stream() # CPU fallback or framework not available return func(*args, **kwargs) # Preserve memory type attributes gpu_wrapper.input_memory_type = func.input_memory_type gpu_wrapper.output_memory_type = func.output_memory_type return gpu_wrapper def _create_memory_decorator(mem_type: MemoryType): """ Factory function that creates a decorator for a specific memory type. This single factory replaces 6 nearly-identical decorator functions. """ ops = _FRAMEWORK_OPS[mem_type] def decorator( func=None, *, input_type=mem_type.value, output_type=mem_type.value, oom_recovery=True, contract=None, ): """ Decorator for {mem_type} memory type functions. Args: func: Function to decorate (when used as @decorator) input_type: Expected input memory type (default: {mem_type}) output_type: Expected output memory type (default: {mem_type}) oom_recovery: Enable automatic OOM recovery (default: True) contract: Optional validation function for outputs Returns: Decorated function with memory type metadata and dtype preservation """ def inner_decorator(func): # Apply base memory_types decorator memory_decorator = memory_types( input_type=input_type, output_type=output_type, contract=contract ) func = memory_decorator(func) # Apply dtype preservation wrapper func = _create_dtype_wrapper(func, mem_type, func.__name__) # Apply GPU wrapper if this is a GPU memory type if ops["gpu_check"] is not None: func = _create_gpu_wrapper(func, mem_type, oom_recovery) return func # Handle both @decorator and @decorator() forms if func is None: return inner_decorator return inner_decorator(func) # Set proper function name and docstring decorator.__name__ = mem_type.value decorator.__doc__ = decorator.__doc__.format(mem_type=ops["display_name"]) return decorator # Auto-generate all 6 memory type decorators for mem_type in MemoryType: decorator_func = _create_memory_decorator(mem_type) globals()[mem_type.value] = decorator_func # Export all decorators __all__ = [ "memory_types", "DtypeConversion", "numpy", # noqa: F822 "cupy", # noqa: F822 "torch", # noqa: F822 "tensorflow", # noqa: F822 "jax", # noqa: F822 "pyclesperanto", # noqa: F822 ]