Source code for magtrack._cupy

"""Compatibility helpers for optional CuPy support.

This module exposes ``cp`` and ``cupyx`` objects that either proxy to the real
CuPy modules (when available) or provide lightweight NumPy/SciPy fallbacks.
"""

from __future__ import annotations

from functools import lru_cache
from types import SimpleNamespace
from typing import Any

import numpy as _np

# ``True`` once the real CuPy package has been imported successfully.
_cupy_available = False

try:  # pragma: no cover - exercised implicitly when CuPy is installed
    import cupy as _cupy  # type: ignore
    import cupyx as _cupyx  # type: ignore
except ImportError:  # pragma: no cover - exercised when CuPy is absent
    _cupy = None
    _cupyx = None
else:
    _cupy_available = True

if _cupy_available:
    cp = _cupy  # type: ignore[assignment]
    cupyx = _cupyx  # type: ignore[assignment]
else:
    try:
        import scipy as _scipy  # type: ignore

        # Import the submodules explicitly: older SciPy releases do not load
        # them as a side effect of ``import scipy``, so plain attribute access
        # such as ``scipy.signal`` would raise AttributeError there.
        from scipy import ndimage as _scipy_ndimage  # type: ignore
        from scipy import signal as _scipy_signal  # type: ignore
        from scipy import special as _scipy_special  # type: ignore
    except ImportError as exc:  # pragma: no cover - SciPy is an optional dep
        raise ImportError(
            "SciPy is required for the CPU fallback when CuPy is unavailable."
        ) from exc

    class _NumPyCupyCompat:
        """Subset of the CuPy API implemented with NumPy.

        Attributes not defined on this class are looked up on :mod:`numpy`,
        so e.g. ``cp.zeros`` resolves to ``numpy.zeros``.
        """

        def __init__(self) -> None:
            self._np = _np
            self.random = _np.random
            self.fft = _np.fft
            self.linalg = _np.linalg
            self.testing = _np.testing
            # Mirrors ``cupy.cuda.is_available()``: there is no GPU on this path.
            self.cuda = SimpleNamespace(is_available=lambda: False)

        def __getattr__(self, name: str) -> Any:
            # Only reached when normal lookup fails. Use the module-level
            # ``_np`` rather than ``self._np``: on an instance whose
            # ``__init__`` has not run (e.g. during copy/pickle
            # reconstruction), reading ``self._np`` would re-enter
            # ``__getattr__`` and recurse without bound.
            return getattr(_np, name)

        def asarray(self, obj, dtype=None, **kwargs):
            """Return ``numpy.asarray(obj, dtype=dtype, **kwargs)``.

            Extra keyword arguments (e.g. ``order``) are forwarded to NumPy.
            """
            return self._np.asarray(obj, dtype=dtype, **kwargs)

        def array(self, obj, dtype=None, **kwargs):
            """Return ``numpy.array(obj, dtype=dtype, **kwargs)``.

            Extra keyword arguments (e.g. ``copy``, ``order``, ``ndmin``) are
            forwarded to NumPy.
            """
            return self._np.array(obj, dtype=dtype, **kwargs)

        def asnumpy(self, obj):
            """Return *obj* as a NumPy array.

            Existing ``numpy.ndarray`` inputs are returned as-is (no copy).
            Other array-likes such as lists are converted, matching
            ``cupy.asnumpy``, which always returns an ``ndarray``.
            """
            return self._np.asarray(obj)

        def get_array_module(self, *args):
            """Return :mod:`numpy`; on this path every array is a host array.

            Accepts any number of arguments, like ``cupy.get_array_module``.
            """
            return self._np

    class _CupyxScipyCompat:
        """Subset of :mod:`cupyx.scipy` backed by SciPy.

        Attributes not defined on this class are looked up on :mod:`scipy`.
        """

        def __init__(self) -> None:
            self._scipy = _scipy
            self.signal = _scipy_signal
            self.ndimage = _scipy_ndimage
            self.special = _scipy_special

        def get_array_module(self, *args):
            """Return this compat namespace regardless of the arguments.

            Accepts any number of arguments, like
            ``cupyx.scipy.get_array_module``.
            """
            return self

        def __getattr__(self, name: str) -> Any:
            # Module-level ``_scipy`` avoids unbounded recursion on instances
            # whose ``__init__`` has not run (see ``_NumPyCupyCompat``).
            return getattr(_scipy, name)

    cp = _NumPyCupyCompat()
    cupyx = SimpleNamespace(scipy=_CupyxScipyCompat())
def is_cupy_available() -> bool:
    """Report whether the real CuPy package is importable.

    The answer reflects the import attempt made when this module was
    loaded. It does not verify that a CUDA device is actually usable;
    :func:`check_cupy` performs that deeper check.

    Returns:
        ``True`` if ``import cupy`` succeeded, ``False`` otherwise.
    """
    imported_ok = _cupy_available
    return imported_ok
@lru_cache(maxsize=1)
def check_cupy() -> bool:
    """Perform a more thorough check to ensure CuPy and CUDA are usable.

    Beyond :func:`is_cupy_available`, this confirms that CUDA reports a
    device and that a tiny random-integer kernel can actually run on it.
    The result is memoized with ``lru_cache``, so repeated calls reuse the
    first answer.

    Returns:
        ``True`` if CuPy is importable and the GPU probe succeeds;
        ``False`` otherwise, including when the probe raises any exception.
    """
    if not is_cupy_available():
        return False
    try:
        if not _cupy.cuda.is_available():  # type: ignore[union-attr]
            return False
        probe = _cupy.random.randint(0, 1, size=(1,))  # type: ignore[union-attr]
        # CuPy launches kernels asynchronously. Copying the result back to
        # the host waits for the kernel, so execution-time CUDA errors are
        # raised inside this ``try`` instead of later in unrelated code.
        probe.get()
    except Exception:  # pragma: no cover - defensive fallback
        # Deliberately broad: any failure means the GPU path is unusable.
        return False
    return True