# Source code for arraybridge.types

"""
Memory type definitions for arraybridge.

This module defines the MemoryType enum and related constants for managing
different array/tensor frameworks.
"""

from enum import Enum, unique
from typing import Any, Callable, TypeVar

# Generic type variable. It is not referenced in this part of the module;
# presumably used by generic helpers elsewhere in the package — TODO confirm.
T = TypeVar("T")
# Alias for a single-argument callable that accepts any object and returns
# any object. NOTE(review): presumably the signature of a memory-type
# conversion function (array in, converted array out) — verify against callers.
ConversionFunc = Callable[[Any], Any]


# @unique makes class creation fail if two members ever share a value.
# Without it a duplicate would become a silent alias, and iterating the
# enum (as the value-derived validation sets do) would skip that member.
@unique
class MemoryType(Enum):
    """Enum representing different array/tensor framework types.

    Each member's value is the lowercase framework name as a string, so a
    member can be looked up from that string, e.g. ``MemoryType("numpy")``.
    """

    NUMPY = "numpy"
    CUPY = "cupy"
    TORCH = "torch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    PYCLESPERANTO = "pyclesperanto"
# Memory type sets
# Members this module classifies as CPU memory types.
CPU_MEMORY_TYPES: set[MemoryType] = {MemoryType.NUMPY}
# Members this module classifies as GPU memory types.
# NOTE(review): the classification is per framework, not per device; torch,
# tensorflow and jax arrays can also live in host memory — confirm intent.
GPU_MEMORY_TYPES: set[MemoryType] = {
    MemoryType.CUPY,
    MemoryType.TORCH,
    MemoryType.TENSORFLOW,
    MemoryType.JAX,
    MemoryType.PYCLESPERANTO,
}
# Union of the CPU and GPU sets above.
SUPPORTED_MEMORY_TYPES: set[MemoryType] = CPU_MEMORY_TYPES | GPU_MEMORY_TYPES

# String value sets for validation
# Built from every MemoryType member, not from SUPPORTED_MEMORY_TYPES. The two
# agree only while every member is listed in the CPU or GPU set.
VALID_MEMORY_TYPES: set[str] = {mt.value for mt in MemoryType}
VALID_GPU_MEMORY_TYPES: set[str] = {mt.value for mt in GPU_MEMORY_TYPES}

# Memory type constants for direct access (the members' plain string values)
MEMORY_TYPE_NUMPY = MemoryType.NUMPY.value
MEMORY_TYPE_CUPY = MemoryType.CUPY.value
MEMORY_TYPE_TORCH = MemoryType.TORCH.value
MEMORY_TYPE_TENSORFLOW = MemoryType.TENSORFLOW.value
MEMORY_TYPE_JAX = MemoryType.JAX.value
MEMORY_TYPE_PYCLESPERANTO = MemoryType.PYCLESPERANTO.value