"""Memory conversion public API for OpenHCS."""
from typing import Any
import numpy as np
from arraybridge.converters_registry import get_converter
from arraybridge.framework_config import _FRAMEWORK_CONFIG
from arraybridge.types import MemoryType, VALID_MEMORY_TYPES
[docs]
def convert_memory(data: Any, source_type: str, target_type: str, gpu_id: int) -> Any:
    """
    Convert data between memory types using the unified converter infrastructure.

    Both ``source_type`` and ``target_type`` may be given either as the plain
    string name or as a ``MemoryType`` member; members are reduced to their
    ``.value`` before any lookup so the two endpoints are treated the same way.

    Args:
        data: The data to convert
        source_type: The source memory type (e.g., "numpy", "torch")
        target_type: The target memory type (e.g., "cupy", "jax")
        gpu_id: The target GPU device ID, forwarded unchanged to the converter

    Returns:
        The converted data in the target memory type

    Raises:
        ValueError: If target_type is not in VALID_MEMORY_TYPES. An invalid
            source_type is expected to be rejected by ``get_converter`` with
            the same exception type.
        MemoryConversionError: If conversion fails (raised by the underlying
            converter, not by this function directly)
    """
    # Normalize enum members to their string values for BOTH endpoints;
    # previously only target_type was normalized.
    if isinstance(source_type, MemoryType):
        source_type = source_type.value
    if isinstance(target_type, MemoryType):
        target_type = target_type.value

    # Validate the target up front so the dynamic attribute lookup below
    # cannot be driven by an arbitrary string.
    if target_type not in VALID_MEMORY_TYPES:
        raise ValueError(
            f"Invalid target_type '{target_type}'. Available types: {sorted(VALID_MEMORY_TYPES)}"
        )

    converter = get_converter(source_type)  # Will raise ValueError if invalid
    # Converters expose one ``to_<memory type>`` method per target.
    method = getattr(converter, f"to_{target_type}")
    return method(data, gpu_id)
[docs]
def detect_memory_type(data: Any) -> str:
    """
    Identify which memory framework an object belongs to.

    NumPy arrays are recognized directly by an ``isinstance`` check. Any other
    object is classified by the top-level package of its type's defining
    module, compared against each framework's ``import_name`` in the
    framework config.

    Args:
        data: The object whose memory type should be determined

    Returns:
        The memory type string of the matching framework (e.g., "numpy", "torch")

    Raises:
        ValueError: If no configured framework matches the object's module
    """
    # Fast path: NumPy arrays are the most common input.
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    module_name = type(data).__module__
    # Only the root package matters, e.g. "torch" for a "torch.*" module.
    root_package, _, _ = module_name.partition(".")

    for memory_kind, framework in _FRAMEWORK_CONFIG.items():
        package = framework["import_name"]
        # Types whose module lives under "jaxlib" are counted as "jax".
        is_jax_alias = package == "jax" and root_package == "jaxlib"
        if root_package == package or is_jax_alias:
            return memory_kind.value

    raise ValueError(f"Unknown memory type for {type(data)} (module: {module_name})")