import numpy as np
from easy_tpp.utils import is_torch_available, is_tf_available
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, a `tf.Tensor` or a `np.ndarray`.

    Torch and TensorFlow are only imported when `is_torch_available()` /
    `is_tf_available()` report them as installed, so this is safe to call when
    either framework is missing. JAX arrays are NOT detected.
    """
    if is_torch_available():
        import torch

        if isinstance(x, torch.Tensor):
            return True
    if is_tf_available():
        import tensorflow as tf

        if isinstance(x, tf.Tensor):
            return True
    return isinstance(x, np.ndarray)
def _is_numpy(x):
return isinstance(x, np.ndarray)
def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.

    Returns True for `np.ndarray` and any of its subclasses, False otherwise.
    """
    return isinstance(x, np.ndarray)
def _is_torch(x):
import torch
return isinstance(x, torch.Tensor)
def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not.

    If `is_torch_available()` reports torch as missing, this returns False without
    importing torch, so it is safe to call when torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x):
import torch
return isinstance(x, torch.device)
def is_torch_device(x):
    """
    Tests if `x` is a torch device or not.

    If `is_torch_available()` reports torch as missing, this returns False without
    importing torch, so it is safe to call when torch is not installed.
    """
    return False if not is_torch_available() else _is_torch_device(x)
def _is_torch_dtype(x):
import torch
if isinstance(x, str):
if hasattr(torch, x):
x = getattr(torch, x)
else:
return False
return isinstance(x, torch.dtype)
def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype (or the string name of one) or not.

    If `is_torch_available()` reports torch as missing, this returns False without
    importing torch, so it is safe to call when torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)
def _is_tensorflow(x):
    """
    Return whether `x` is a `tf.Tensor` (subclasses included).

    tensorflow is imported at call time, so this must only be called when it is installed.
    """
    from tensorflow import Tensor

    return isinstance(x, Tensor)
def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not.

    If `is_tf_available()` reports tensorflow as missing, this returns False without
    importing it, so it is safe to call when tensorflow is not installed.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
def _is_tf_symbolic_tensor(x):
    """
    Return whether `x` is a symbolic (non-eager) TensorFlow tensor.

    tensorflow is imported at call time, so this must only be called when it is installed.
    """
    import tensorflow as tf

    # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
    if hasattr(tf, "is_symbolic_tensor"):
        return tf.is_symbolic_tensor(x)
    # Fallback for older TF: an exact type match rather than isinstance.
    # NOTE(review): presumably because eager tensors are subclasses of `tf.Tensor`
    # and must be excluded; confirm against the TF versions supported.
    return type(x) is tf.Tensor
def is_tf_symbolic_tensor(x):
    """
    Tests if `x` is a tensorflow symbolic tensor or not (i.e. not eager).

    If `is_tf_available()` reports tensorflow as missing, this returns False without
    importing it, so it is safe to call when tensorflow is not installed.
    """
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)