"""Source code for easy_tpp.utils.import_utils: helpers that detect optional backend packages (torch, TensorFlow)."""

import functools
import importlib.util
import sys
from collections import OrderedDict
from typing import Union, Tuple

from easy_tpp.utils.log_utils import default_logger as logger

if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib_metadata.version(pkg_name)
        except importlib_metadata.PackageNotFoundError:
            pass
        logger.debug(f"Detected {pkg_name} version {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists


# TensorFlow is only reported as available when the ``tensorflow`` module is
# importable AND one of the known TensorFlow distributions has installed
# metadata from which a version string can be read.
# ``_tf_version`` is bound unconditionally so it is None (not undefined) when
# TensorFlow is absent.
_tf_version = None
_tf_available = _is_package_available("tensorflow")
if _tf_available:
    candidates = (
        "tensorflow",
        "tensorflow-cpu",
        "tensorflow-gpu",
        "tf-nightly",
        "tf-nightly-cpu",
        "tf-nightly-gpu",
        "intel-tensorflow",
        "intel-tensorflow-avx512",
        "tensorflow-rocm",
        "tensorflow-macos",
        "tensorflow-aarch64",
    )
    # The ``tensorflow`` module can be provided by any of these distributions;
    # take the version from the first one whose metadata is installed.
    for pkg in candidates:
        try:
            _tf_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _tf_available = _tf_version is not None

_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")

# ``_torch_version`` is "N/A" when torch is missing or has no distribution metadata.
_torch_available, _torch_version = _is_package_available("torch", return_version=True)


def is_torch_available():
    """Return True if the ``torch`` module was found when this module was imported."""
    return _torch_available
def get_torch_version():
    """Return the torch version detected at import time ("N/A" if torch or its metadata is missing)."""
    return _torch_version
def is_torchvision_available():
    """Return True if the ``torchvision`` module was found when this module was imported."""
    return _torchvision_available
def is_torch_cuda_available():
    """Return True if torch is installed and ``torch.cuda.is_available()`` is True."""
    if not is_torch_available():
        return False
    # Imported lazily: this module does not import torch at top level.
    import torch

    return torch.cuda.is_available()
def is_tf_available():
    """Return True if TensorFlow and its distribution metadata were found at import time."""
    return _tf_available
def is_tf_gpu_available():
    """Return True if TensorFlow is installed and can see at least one GPU device.

    Always returns a bool. The previous version could fall through and return
    None in its legacy branch, and it gated on a lexicographic comparison of
    ``tf.__version__`` strings.
    """
    if not is_tf_available():
        return False
    # Imported lazily: this module does not import tensorflow at top level.
    import tensorflow as tf

    # Feature-detect the modern API instead of comparing version strings.
    tf_config = getattr(tf, "config", None)
    list_physical_devices = getattr(tf_config, "list_physical_devices", None)
    if list_physical_devices is not None:
        return bool(list_physical_devices("GPU"))

    # Older TensorFlow releases lack tf.config.list_physical_devices; enumerate
    # local devices through the legacy client API instead.
    from tensorflow.python.client import device_lib

    return any(device.device_type == "GPU" for device in device_lib.list_local_devices())
def is_torch_mps_available():
    """Return True if torch is installed and its Apple MPS backend is usable.

    ``torch.device('mps')`` only builds a device descriptor and does not probe
    the hardware, so it succeeds even where MPS is unusable. The backend's own
    ``is_available()`` is queried instead.
    """
    if not is_torch_available():
        return False
    try:
        # Submodule import; torch builds without an MPS backend module raise here.
        from torch.backends import mps as mps_backend
    except ImportError:
        return False
    return bool(mps_backend.is_available())
def is_torch_gpu_available():
    """Return True if torch can use either a CUDA device or the Apple MPS backend."""
    # Short-circuit: the MPS probe is skipped when CUDA is already available.
    return is_torch_cuda_available() or is_torch_mps_available()
def is_tensorflow_probability_available():
    """Return True if the ``tensorflow_probability`` module was found when this module was imported."""
    return _tensorflow_probability_available
def torch_only_method(fn):
    """Decorate ``fn`` so that calling it raises ``ImportError`` when torch is missing.

    The check runs at call time, not at decoration time, so decorated functions
    can be defined in modules that are imported without torch installed.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not _torch_available:
            # NOTE(review): this module never reads USE_TORCH / USE_TF, so the
            # environment-variable hint below may be stale -- confirm.
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper


# Message templates below are formatted with ``.format(name)`` by
# ``requires_backends``; ``{0}`` is the name of the object requiring the backend.

# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""

# docstyle-ignore
TF_IMPORT_ERROR_WITH_PYTORCH = """
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!

If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""

# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
"""

# Backend key -> (availability predicate, error-message template), consulted
# by ``requires_backends``.
BACKENDS_MAPPING = OrderedDict(
    [
        ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
    ]
)
def requires_backends(obj, backends):
    """Raise ``ImportError`` unless every backend in ``backends`` is installed.

    Args:
        obj: The function, class, or instance that needs the backends. It is used
            only to name the requester in the error message.
        backends: One backend key, or a list/tuple of keys, from ``BACKENDS_MAPPING``
            (``"tensorflow_probability"``, ``"tf"``, ``"torch"``, ``"torchvision"``).

    Raises:
        ImportError: If any requested backend is unavailable. When only one of
            torch/tf is requested and only the other is installed, a more specific
            message about the "TF" class-name prefix is raised first.
        KeyError: If a backend name is not a key of ``BACKENDS_MAPPING``.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__

    # Raise an error for users who might not realize that classes without "TF" are torch-only
    if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
        raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))

    # Raise the inverse error for PyTorch users trying to load TF classes
    if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
        raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))

    # Collect every missing backend so the user sees all problems in one error.
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))