Source code for magtrack._cupy
"""Compatibility helpers for optional CuPy support.
This module exposes ``cp`` and ``cupyx`` objects that either proxy to the real
CuPy modules (when available) or provide lightweight NumPy/SciPy fallbacks.
"""
from __future__ import annotations
from functools import lru_cache
from types import SimpleNamespace
from typing import Any
import numpy as _np
[docs]
_cupy_available = False
try: # pragma: no cover - exercised implicitly when CuPy is installed
import cupy as _cupy # type: ignore
import cupyx as _cupyx # type: ignore
except ImportError: # pragma: no cover - exercised when CuPy is absent
_cupyx = None
else:
_cupy_available = True
if _cupy_available:
[docs]
cp = _cupy # type: ignore[assignment]
cupyx = _cupyx # type: ignore[assignment]
else:
try:
import scipy as _scipy # type: ignore
except ImportError as exc: # pragma: no cover - SciPy is an optional dep
raise ImportError(
"SciPy is required for the CPU fallback when CuPy is unavailable."
) from exc
class _NumPyCupyCompat:
"""Subset of the CuPy API implemented with NumPy."""
def __init__(self) -> None:
self._np = _np
self.random = _np.random
self.fft = _np.fft
self.linalg = _np.linalg
self.testing = _np.testing
self.cuda = SimpleNamespace(is_available=lambda: False)
def __getattr__(self, name: str) -> Any:
return getattr(self._np, name)
def asarray(self, obj, dtype=None):
return self._np.asarray(obj, dtype=dtype)
def array(self, obj, dtype=None):
return self._np.array(obj, dtype=dtype)
def asnumpy(self, obj):
return obj
def get_array_module(self, _):
return self._np
class _CupyxScipyCompat:
"""Subset of :mod:`cupyx.scipy` backed by SciPy."""
def __init__(self) -> None:
self._scipy = _scipy
self.signal = _scipy.signal
self.ndimage = _scipy.ndimage
self.special = _scipy.special
def get_array_module(self, _):
return self
def __getattr__(self, name: str) -> Any:
return getattr(self._scipy, name)
cp = _NumPyCupyCompat()
cupyx = SimpleNamespace(scipy=_CupyxScipyCompat())
[docs]
def is_cupy_available() -> bool:
"""Return ``True`` when the real CuPy package is importable."""
return _cupy_available
@lru_cache(maxsize=1)
[docs]
def check_cupy() -> bool:
"""Perform a more thorough check to ensure CuPy and CUDA are usable."""
if not is_cupy_available():
return False
try:
if not _cupy.cuda.is_available(): # type: ignore[union-attr]
return False
_cupy.random.randint(0, 1, size=(1,)) # type: ignore[union-attr]
except Exception: # pragma: no cover - defensive fallback
return False
else:
return True