Source code for easy_tpp.utils.torch_utils

import os
import random

import numpy as np
import torch

from easy_tpp.utils.import_utils import is_torch_mps_available


def set_seed(seed=1029):
    """Setup random seed.

    Args:
        seed (int, optional): random seed. Defaults to 1029.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
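A minimal usage sketch: calling set_seed once before building the model so that weight initialization and data shuffling are reproducible across runs. The seed value 42 below is arbitrary and chosen only for illustration.

# Example (illustrative): fix all random sources, then draw a reproducible tensor.
import torch
from easy_tpp.utils.torch_utils import set_seed

set_seed(42)                # arbitrary seed value for the example
x = torch.randn(3, 4)       # identical across runs with the same seed
print(x.sum().item())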
def set_device(gpu=-1):
    """Setup the device.

    Args:
        gpu (int, optional): index of the GPU to use. Defaults to -1 (do not use GPU, i.e., use CPU).
    """
    if gpu >= 0:
        if torch.cuda.is_available():
            device = torch.device("cuda:" + str(gpu))
        elif is_torch_mps_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device("cpu")
    return device
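A rough usage sketch, assuming a single-GPU setup: the returned device object can be passed to .to() when placing models and tensors. The toy linear model below is illustrative only.

# Example (illustrative): pick a device and move a toy model and batch onto it.
import torch
from easy_tpp.utils.torch_utils import set_device

device = set_device(gpu=0)          # falls back to MPS or CPU if CUDA is unavailable
model = torch.nn.Linear(8, 2).to(device)
batch = torch.randn(16, 8, device=device)
logits = model(batch)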
def set_optimizer(optimizer, params, lr):
    """Setup the optimizer.

    Args:
        optimizer (str): name of the optimizer.
        params (dict): dict of params for the optimizer.
        lr (float): learning rate.

    Raises:
        NotImplementedError: if the optimizer's name is wrong or the optimizer is not supported,
            we raise an error.

    Returns:
        torch.optim: torch optimizer.
    """
    if isinstance(optimizer, str):
        if optimizer.lower() == "adam":
            optimizer = "Adam"
        try:
            optimizer = getattr(torch.optim, optimizer)(params, lr=lr)
        except Exception:
            raise NotImplementedError("optimizer={} is not supported.".format(optimizer))
    return optimizer
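A minimal sketch of how the helper might be called. The name "adam" and the learning rate 1e-3 are illustrative; any attribute of torch.optim (e.g. "SGD", "AdamW") should resolve the same way via getattr.

# Example (illustrative): build an optimizer by name for a toy model's parameters.
import torch
from easy_tpp.utils.torch_utils import set_optimizer

model = torch.nn.Linear(8, 2)
optimizer = set_optimizer("adam", model.parameters(), lr=1e-3)  # "adam" is normalized to "Adam"

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()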
def count_model_params(model):
    """Count the number of params of the model.

    Args:
        model (torch.nn.Module): a torch model.

    Returns:
        int: total num of the parameters.
    """
    return sum(p.numel() for p in model.parameters())
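A quick sketch on a toy model where the count can be checked by hand: a Linear(8, 2) layer has 8 * 2 weights plus 2 biases.

# Example (illustrative): count parameters of a small linear layer.
import torch
from easy_tpp.utils.torch_utils import count_model_params

model = torch.nn.Linear(8, 2)
print(count_model_params(model))    # 8 * 2 + 2 = 18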