import numpy as np
from easy_tpp.utils import is_torch_available, is_tf_available
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor` or `np.ndarray`.

    The torch and tensorflow checks only run (and only import the framework) when `is_torch_available()` /
    `is_tf_available()` return a truthy value. No JAX array type is checked by this function.
    """
    if is_torch_available():
        # Imported lazily so that this module does not require torch at import time.
        import torch

        if isinstance(x, torch.Tensor):
            return True
    if is_tf_available():
        # Imported lazily so that this module does not require tensorflow at import time.
        import tensorflow as tf

        if isinstance(x, tf.Tensor):
            return True
    # Fall back to the numpy check, which needs no optional dependency.
    return isinstance(x, np.ndarray)
def _is_numpy(x):
    """
    Returns True when `x` is an instance of `np.ndarray`.
    """
    return isinstance(x, np.ndarray)


def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.
    """
    return _is_numpy(x)
def _is_torch(x):
    """
    Unguarded check that `x` is a `torch.Tensor`. Imports torch unconditionally.
    """
    import torch

    tensor_cls = torch.Tensor
    return isinstance(x, tensor_cls)
def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    """
    # Short-circuit before touching torch so the import is never attempted when it is unavailable.
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x):
    """
    Unguarded check that `x` is a `torch.device`. Imports torch unconditionally.
    """
    import torch

    device_cls = torch.device
    return isinstance(x, device_cls)
def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    """
    # The availability check comes first so `_is_torch_device` (which imports torch) only runs when torch exists.
    return False if not is_torch_available() else _is_torch_device(x)
def _is_torch_dtype(x):
    """
    Unguarded check that `x` is a `torch.dtype`. Imports torch unconditionally.

    A string is treated as the name of an attribute on the `torch` module: it is resolved with `getattr` and the
    resolved object is then checked. A string that names no attribute of `torch` yields False.
    """
    import torch

    candidate = x
    if isinstance(candidate, str):
        if not hasattr(torch, candidate):
            return False
        candidate = getattr(torch, candidate)
    return isinstance(candidate, torch.dtype)
def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    """
    # Short-circuit before touching torch so the import is never attempted when it is unavailable.
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)
def _is_tensorflow(x):
    """
    Unguarded check that `x` is a `tf.Tensor`. Imports tensorflow unconditionally.
    """
    import tensorflow as tf

    tensor_cls = tf.Tensor
    return isinstance(x, tensor_cls)
def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    """
    # Short-circuit before touching tensorflow so the import is never attempted when it is unavailable.
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
def _is_tf_symbolic_tensor(x):
    """
    Unguarded check that `x` is a symbolic tensorflow tensor. Imports tensorflow unconditionally.
    """
    import tensorflow as tf

    # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
    if not hasattr(tf, "is_symbolic_tensor"):
        # Fallback: exact type match, deliberately not `isinstance`, so subclasses of `tf.Tensor` do not match.
        # NOTE(review): presumably because eager tensors are a subclass of `tf.Tensor` - confirm.
        return type(x) == tf.Tensor
    return tf.is_symbolic_tensor(x)
def is_tf_symbolic_tensor(x):
    """
    Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
    installed.
    """
    # Short-circuit before touching tensorflow so the import is never attempted when it is unavailable.
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)