Source code for arraybridge.converters

"""Memory conversion public API for OpenHCS."""

from typing import Any

import numpy as np

from arraybridge.converters_registry import get_converter
from arraybridge.framework_config import _FRAMEWORK_CONFIG
from arraybridge.types import MemoryType, VALID_MEMORY_TYPES


def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    Both ``source_type`` and ``target_type`` may be given either as a plain
    string or as a ``MemoryType`` member; members are reduced to their string
    ``value`` before any lookup so the two arguments are treated the same way.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch")
        target_type: The target memory type (e.g., "cupy", "jax")
        gpu_id: The target GPU device ID

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If target_type is not in VALID_MEMORY_TYPES, or if
            get_converter rejects source_type
        MemoryConversionError: If conversion fails (propagated from the
            converter method; not raised directly here)
    """
    # Normalize enum members for BOTH endpoints. Previously only target_type
    # was normalized, so a MemoryType source was handed to get_converter raw.
    if isinstance(source_type, MemoryType):
        source_type = source_type.value
    if isinstance(target_type, MemoryType):
        target_type = target_type.value

    # Validate the target first so the to_<target> attribute lookup below
    # cannot be driven by an arbitrary string.
    if target_type not in VALID_MEMORY_TYPES:
        raise ValueError(
            f"Invalid target_type '{target_type}'. Available types: {sorted(VALID_MEMORY_TYPES)}"
        )

    converter = get_converter(source_type)  # Will raise ValueError if invalid
    # Converters expose one method per destination, named "to_<memory type>".
    method = getattr(converter, f"to_{target_type}")
    return method(data, gpu_id)
def detect_memory_type(data: Any) -> str:
    """
    Identify which memory type (array framework) ``data`` belongs to.

    NumPy arrays are recognized directly with an ``isinstance`` check. For
    anything else, the top-level package of ``type(data).__module__`` is
    compared against each framework's ``import_name`` in ``_FRAMEWORK_CONFIG``.

    Args:
        data: The object to classify

    Returns:
        The memory type string of the first matching framework
        (e.g., "numpy", "torch")

    Raises:
        ValueError: If no configured framework matches the object's module
    """
    # Fast path: NumPy is the most common case and needs no config lookup.
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    owner_module = type(data).__module__
    # Only the root package matters, e.g. "torch" for a dotted submodule path.
    root_package, _, _ = owner_module.partition(".")

    for candidate, settings in _FRAMEWORK_CONFIG.items():
        framework_name = settings["import_name"]
        if root_package == framework_name:
            return candidate.value
        # Objects whose type lives under "jaxlib" are attributed to "jax".
        if framework_name == "jax" and root_package == "jaxlib":
            return candidate.value

    raise ValueError(f"Unknown memory type for {type(data)} (module: {owner_module})")