import os
import random
import numpy as np
import torch
from easy_tpp.utils.import_utils import is_torch_mps_available
def set_seed(seed=1029):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU and all
    CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed (int, optional): random seed. Defaults to 1029.
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime does not re-seed hashing in the
    # current interpreter; it only takes effect for child processes that
    # inherit this environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one, so multi-GPU runs are
    # reproducible as well. The call is a safe no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def set_device(gpu=-1):
    """Select the torch device to run on.

    Args:
        gpu (int, optional): index of the GPU to use. Any negative value
            selects the CPU. Defaults to -1 (not use GPU, i.e., use CPU).

    Returns:
        torch.device: ``cuda:<gpu>`` when a GPU is requested and CUDA is
        available, ``mps`` when a GPU is requested and only the MPS backend is
        available, otherwise ``cpu``.
    """
    if gpu >= 0:
        if torch.cuda.is_available():
            return torch.device("cuda:" + str(gpu))
        if is_torch_mps_available():
            return torch.device("mps")
    # Either the CPU was requested explicitly, or a GPU was requested but no
    # accelerator backend is available: fall back to the CPU instead of
    # failing on an unassigned local variable.
    return torch.device("cpu")
def set_optimizer(optimizer, params, lr):
    """Build a ``torch.optim`` optimizer from its name.

    The name is matched against the optimizer classes exposed by
    ``torch.optim``. An exact match is preferred; otherwise the match is
    case-insensitive, so ``"adam"``, ``"sgd"`` or ``"adamw"`` resolve to
    ``Adam``, ``SGD`` and ``AdamW`` respectively.

    Args:
        optimizer (str): name of the optimizer.
        params (iterable): parameters (or parameter groups) to optimize; passed
            to the optimizer constructor as its first argument.
        lr (float): learning rate.

    Raises:
        NotImplementedError: if the optimizer's name is wrong, the optimizer is
            not supported, or its construction fails. The underlying error is
            attached as ``__cause__``.

    Returns:
        torch.optim.Optimizer: the constructed optimizer.
    """

    def _is_optimizer_class(obj):
        # Guards against same-named submodules (e.g. ``torch.optim.sgd``).
        return isinstance(obj, type) and issubclass(obj, torch.optim.Optimizer)

    resolved = optimizer
    if isinstance(optimizer, str):
        exact = getattr(torch.optim, optimizer, None)
        if not _is_optimizer_class(exact):
            wanted = optimizer.lower()
            for attr in dir(torch.optim):
                if attr.lower() != wanted:
                    continue
                if _is_optimizer_class(getattr(torch.optim, attr, None)):
                    resolved = attr
                    break

    try:
        return getattr(torch.optim, resolved)(params, lr=lr)
    except Exception as err:
        # Keep the historical exception type for callers, but chain the real
        # cause so that e.g. an invalid learning rate is still diagnosable.
        raise NotImplementedError(
            "optimizer={} is not supported.".format(optimizer)
        ) from err
def count_model_params(model):
    """Count the number of params of the model.

    Every tensor yielded by ``model.parameters()`` is counted via ``numel()``,
    regardless of whether it requires gradients.

    Args:
        model (torch.nn.Module): a torch model.

    Returns:
        int: total num of the parameters.
    """
    return sum(p.numel() for p in model.parameters())