EasyTPP Utilities Modules

utils.py_assert(condition, exception_type, msg)

Assert that the condition holds; otherwise, raise an exception of the given type with the given message.

Parameters:
  • condition (bool) – the condition to check.

  • exception_type (_StandardError) – the error type to raise, such as ValueError.

  • msg (str) – the error message.

Raises:

exception_type – raised when the condition does not hold.
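This is a minimal sketch of the behavior described above. It is based on the docstring and is not the EasyTPP source. The caller check_batch_size is hypothetical.

```python
def py_assert(condition, exception_type, msg):
    """Sketch: raise exception_type(msg) when condition is falsy."""
    if not condition:
        raise exception_type(msg)


def check_batch_size(batch_size):
    # Hypothetical caller: fail fast with a descriptive ValueError
    py_assert(batch_size > 0, ValueError, f"batch_size must be positive, got {batch_size}")
    return batch_size


print(check_batch_size(8))
```

Python strips a bare assert statement when it runs with -O. An explicit raise is not stripped, and the caller also chooses the exception type.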

utils.make_config_string(config, max_num_key=4)

Generate a name for config files.

Parameters:
  • config (dict) – configuration dict.

  • max_num_key (int, optional) – maximum number of keys to concatenate in the output. Defaults to 4.

Returns:

a string concatenated from the config dict.

Return type:

str

utils.create_folder(*args)

Create the folder at the given path if it doesn't exist.

Returns:

the created folder’s path.

Return type:

str
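This is a short sketch of the behavior. It assumes the positional arguments are path components joined with os.path.join, which the signature suggests but the docstring does not state.

```python
import os
import tempfile


def create_folder(*args):
    """Sketch: join the path parts and create the folder if it is missing."""
    path = os.path.join(*args)
    # exist_ok=True makes the call a no-op when the folder already exists
    os.makedirs(path, exist_ok=True)
    return path


# Create a nested folder under a temporary directory
base = tempfile.mkdtemp()
out_dir = create_folder(base, "checkpoints", "run_1")
print(out_dir)
```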

utils.save_yaml_config(save_dir, config)

Save a config dict to a YAML file.

Parameters:
  • save_dir (str) – the path to save the config file to.

  • config (dict) – the target config object.

utils.load_yaml_config(config_dir)

Load a YAML config file from disk.

Parameters:

config_dir (str or Path) – the path of the config file.

Returns:

the loaded config.

Return type:

dict

class utils.RunnerPhase(value)

Bases: ExplicitEnum

Model runner phase enum.

TRAIN = 'train'
VALIDATE = 'validate'
PREDICT = 'predict'
class utils.LogConst

Bases: object

Format strings for log handlers.

DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'
DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s] - %(levelname)s: %(message)s'
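The sketch below attaches a format string such as DEFAULT_FORMAT to a handler using the standard logging module. The logger name "easytpp_demo" is hypothetical, and this is not how EasyTPP wires up its own handlers.

```python
import logging

# Format string copied from LogConst.DEFAULT_FORMAT above
DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'

formatter = logging.Formatter(DEFAULT_FORMAT)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger("easytpp_demo")  # hypothetical logger name
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Emits a line of the form: [<asctime>] [INFO] training started
logger.info("training started")
```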
utils.load_pickle(file_dir)

Load an object from a pickle file.

Parameters:

file_dir (str) – path of the pickle file.

Returns:

the loaded data.

Return type:

Any

utils.has_key(target_dict, target_keys)

Check whether the keys exist in the target dict.

Parameters:
  • target_dict (dict) – a dict.

  • target_keys (str, list) – a single key or a list of keys.

Returns:

True if all the keys exist in the dict; False otherwise.

Return type:

bool
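This is a minimal sketch of the documented contract. It assumes that a single str key is treated as a one-element list.

```python
def has_key(target_dict, target_keys):
    """Sketch: return True only if every key is present in target_dict."""
    # Accept a single key as well as a list of keys
    if not isinstance(target_keys, list):
        target_keys = [target_keys]
    return all(key in target_dict for key in target_keys)


config = {"lr": 0.001, "batch_size": 32}
print(has_key(config, ["lr", "batch_size"]))
```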

utils.array_pad_cols(arr, max_num_cols, pad_index)

Pad the array by columns.

Parameters:
  • arr (np.array) – target array to be padded.

  • max_num_cols (int) – number of columns of the padded array.

  • pad_index (int) – value used to fill the padded elements.

Returns:

the padded array.

Return type:

np.array
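The sketch below shows one way to pad columns with NumPy. It assumes arr is 2-D and that padding is appended on the right. It is an illustration, not the library code.

```python
import numpy as np


def array_pad_cols(arr, max_num_cols, pad_index):
    """Sketch: right-pad a 2-D array with pad_index up to max_num_cols columns."""
    # Start from an array filled with the pad value, then copy the original columns in
    padded = np.full((arr.shape[0], max_num_cols), pad_index, dtype=arr.dtype)
    padded[:, :arr.shape[1]] = arr
    return padded


seqs = np.array([[1, 2], [3, 4]])
padded_seqs = array_pad_cols(seqs, max_num_cols=4, pad_index=-1)
print(padded_seqs)
# [[ 1  2 -1 -1]
#  [ 3  4 -1 -1]]
```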

class utils.MetricsHelper

Bases: object

MAXIMIZE = 'maximize'
MINIMIZE = 'minimize'
static get_metric_function(name)
static get_metric_direction(name)
static get_all_registered_metric()
static register(name, direction, overwrite=True)
static metrics_dict_to_str(metrics_dict)

Convert a metrics dict to a string for display in the console.

static get_metrics_callback_from_names(metric_names)

Get the metric function callbacks for the given metric names.

class utils.MetricsTracker

Bases: object

Track and record the metrics.

__init__()
update_best(key, value, epoch)

Update the recorder for the best metrics.

Parameters:
  • key (str) – metrics key.

  • value (float) – metrics value.

  • epoch (int) – the current epoch number.

Raises:

NotImplementedError – for keys other than ‘loglike’.

Returns:

whether the recorder has been updated.

Return type:

bool
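This is a hypothetical sketch of the contract described above. It assumes a higher log-likelihood is better. The attribute names current_best and best_epoch are illustrative and not taken from EasyTPP.

```python
class MetricsTracker:
    """Hypothetical sketch of a best-metric recorder that only supports 'loglike'."""

    def __init__(self):
        # Log-likelihood: higher is better, so start from negative infinity
        self.current_best = {"loglike": float("-inf")}
        self.best_epoch = None

    def update_best(self, key, value, epoch):
        if key != "loglike":
            raise NotImplementedError(f"Tracking is not implemented for key: {key}")
        if value > self.current_best[key]:
            self.current_best[key] = value
            self.best_epoch = epoch
            return True
        return False


tracker = MetricsTracker()
first = tracker.update_best("loglike", -1.5, epoch=0)  # True: first value seen
worse = tracker.update_best("loglike", -2.0, epoch=1)  # False: lower log-likelihood
```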

utils.set_device(gpu=-1)

Set up the device.

Parameters:

gpu (int, optional) – ID of the GPU to use. Defaults to -1 (no GPU, i.e., use the CPU).

utils.set_optimizer(optimizer, params, lr)

Set up the optimizer.

Parameters:
  • optimizer (str) – name of the optimizer.

  • params (dict) – dict of params for the optimizer.

  • lr (float) – learning rate.

Raises:
NotImplementedError – raised if the optimizer name is invalid or the optimizer is not supported.

Returns:

a torch optimizer.

Return type:

torch.optim.Optimizer

utils.set_seed(seed=1029)

Set the random seed.

Parameters:

seed (int, optional) – random seed. Defaults to 1029.
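This is a sketch of seeding for reproducibility. Which generators the real function seeds is an assumption here. The sketch covers Python's random module and NumPy only, so it runs without torch installed.

```python
import random

import numpy as np


def set_seed(seed=1029):
    """Sketch: seed Python's and NumPy's global random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    # A PyTorch project would usually also call torch.manual_seed(seed) here;
    # it is left out so this sketch runs without torch installed.


set_seed(1029)
first_draw = (random.random(), np.random.rand())
set_seed(1029)
second_draw = (random.random(), np.random.rand())
```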

utils.save_pickle(file_dir, object_to_save)

Save the object to a pickle file.

Parameters:
  • file_dir (str) – path of the pickle file.

  • object_to_save (any) – the target data to be saved.
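A round trip with save_pickle and load_pickle can be sketched with the standard pickle module. These are simplified stand-ins written for illustration, not the EasyTPP implementations.

```python
import os
import pickle
import tempfile


def save_pickle(file_dir, object_to_save):
    # pickle requires the file to be opened in binary mode
    with open(file_dir, "wb") as f:
        pickle.dump(object_to_save, f)


def load_pickle(file_dir):
    with open(file_dir, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "data.pkl")
save_pickle(path, {"event_types": [0, 1, 2]})
restored = load_pickle(path)
print(restored)
```

Only unpickle files from trusted sources, because loading a pickle can execute arbitrary code.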

utils.count_model_params(model)

Count the number of parameters of the model.

Parameters:

model (torch.nn.Module) – a torch model.

Returns:

the total number of parameters.

Return type:

int

class utils.Registrable

Bases: object

Any class that inherits from Registrable gains access to a named registry for its subclasses. To register them, decorate them with the classmethod @BaseClass.register(name).

After that, you can call BaseClass.list_available() to get the keys for the registered subclasses, and BaseClass.by_name(name) to get the corresponding subclass.

Note that the registry stores the subclasses themselves, not class instances. In most cases you would then call from_params(params) on the returned subclass.
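The sketch below shows the registry pattern described above. It implements register (including the optional constructor and overwrite arguments), by_name, and list_available. Storing the registry in a dict keyed by base class is an assumption. Default-first ordering and from_params are left out.

```python
class Registrable:
    """Minimal registry sketch: each base class keeps its own name -> subclass map."""

    # Maps base class -> {name: (subclass, constructor_method_name)}
    _registry = {}

    @classmethod
    def register(cls, name, constructor=None, overwrite=False):
        registry = Registrable._registry.setdefault(cls, {})

        def add_subclass_to_registry(subclass):
            if name in registry and not overwrite:
                raise ValueError(f"{name} is already registered for {cls.__name__}")
            registry[name] = (subclass, constructor)
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name):
        subclass, constructor = Registrable._registry[cls][name]
        # Use the registered classmethod if one was given, else the class itself
        return subclass if constructor is None else getattr(subclass, constructor)

    @classmethod
    def list_available(cls):
        return list(Registrable._registry.get(cls, {}))


class Transform(Registrable):
    pass


@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        self.param1 = param1
        self.param2 = param2


@Transform.register("scaled-shift", constructor="from_scale")
class ScaledShift(Transform):
    def __init__(self, offset: float):
        self.offset = offset

    @classmethod
    def from_scale(cls, scale: float):
        return cls(offset=2 * scale)


shift = Transform.by_name("shift-transform")(param1=1, param2="a")
scaled = Transform.by_name("scaled-shift")(scale=3)
```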

classmethod register(name, constructor=None, overwrite=False)

Register a class under a particular name.

Parameters:
  • name (str) – the name to register the class under.

  • constructor (str, optional) – the name of the method to use on the class to construct the object. If this is given, we will use this method (which must be a classmethod) instead of the default constructor. Defaults to None.

  • overwrite (bool, optional) – if True, overwrites any existing models registered under name; otherwise, throws an error if a model is already registered under name. Defaults to False.

Examples:

To use this class, you would typically have a base class that inherits from Registrable:

```python
class Transform(Registrable):
    ...
```

Then, if you want to register a subclass, you decorate it like this:

```python
@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        ...
```

Registering a class like this lets you instantiate it from a config file, where you give "type": "shift-transform" and keys corresponding to the parameters of the __init__ method (note that for this to work, those parameters must have type annotations). You may want instantiation from a config file to call a method other than the constructor. This can happen because you have several construction paths for the same object (as we do in Transform), or because you want some logic to run before the constructor. In either case, you can register a specific @classmethod as the constructor to use.

classmethod by_name(name)

Returns a callable function that constructs an argument of the registered class. Because you can register particular functions as constructors for specific names, this isn’t necessarily the __init__ method of some class.

classmethod resolve_class_name(name)

Returns the subclass that corresponds to the given name, along with the name of the method that was registered as a constructor for that name, if any. This method also allows name to be a fully-specified module name, instead of a name that was already added to the Registry. In that case, you cannot use a separate function as a constructor (as you need to call cls.register() in order to tell us what separate function to use).

classmethod list_available()

List the names of the registered subclasses, with the default first if it exists.

classmethod registry_dict()
class utils.Timer(unit='m')

Bases: object

Measure the elapsed time between start() and end().

__init__(unit='m')
start()
end()
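This is a sketch of a start/end timer. The unit mapping ('s', 'm', 'h') and returning a float from end() are assumptions. The docstring only says the default unit is 'm'.

```python
import time


class Timer:
    """Sketch of a start/end timer; the unit handling is an assumption."""

    # Seconds per unit: 's' = seconds, 'm' = minutes, 'h' = hours (assumed)
    _SECONDS_PER_UNIT = {"s": 1.0, "m": 60.0, "h": 3600.0}

    def __init__(self, unit="m"):
        if unit not in self._SECONDS_PER_UNIT:
            raise ValueError(f"Unsupported unit: {unit}")
        self.unit = unit
        self._start = None

    def start(self):
        # perf_counter is monotonic, so system clock changes do not affect it
        self._start = time.perf_counter()

    def end(self):
        elapsed = time.perf_counter() - self._start
        return elapsed / self._SECONDS_PER_UNIT[self.unit]


timer = Timer(unit="s")
timer.start()
time.sleep(0.01)
elapsed_s = timer.end()
print(f"{elapsed_s:.3f}s")
```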
utils.concat_element(arrs, pad_index)

Concatenate the elements from each batch output.

utils.get_stage(stage)
utils.to_dict(obj, classkey=None)
utils.parse_uri_to_protocol_and_path(uri)

Parse a URI into two parts: protocol and path. Use ‘file’ as the default protocol when the URI has no protocol.

Parameters:

uri (str) – the URI identifying a resource, in the form ‘protocol://path’.

Returns:

protocol (str), the method used to access the resource, and path (str), the location of the resource.

Return type:

tuple
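This is a minimal sketch of the parsing rule described above. It assumes the separator is the first "://" and that a URI without one is treated as a local path.

```python
def parse_uri_to_protocol_and_path(uri):
    """Sketch: split 'protocol://path'; fall back to 'file' when no protocol is given."""
    if "://" in uri:
        # Split only on the first separator so the path may contain '://'
        protocol, path = uri.split("://", 1)
    else:
        protocol, path = "file", uri
    return protocol, path


print(parse_uri_to_protocol_and_path("s3://bucket/data.pkl"))
print(parse_uri_to_protocol_and_path("/tmp/data.pkl"))
```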

utils.is_master_process()

Check whether the process is the master process across all machines.

Returns:

bool

utils.is_local_master_process()

Check whether the process is the master process on the local machine.

Returns:

bool

utils.dict_deep_update(target, source, is_add_new_key=True)

Recursively update the target dict with the source dict, and return a new dict deep-copied from target and source.

Parameters:
  • target (dict) – the dict to update.

  • source (dict) – the dict whose values are merged into target.

  • is_add_new_key (bool, optional) – whether to add keys that are in source but not in target. Defaults to True.

Returns:

a new dict containing the values of both target and source; when a key appears in both, the value from source is kept.

Return type:

dict
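The sketch below illustrates the deep-update semantics described above. It assumes nested dicts are merged recursively, that is_add_new_key applies at every level, and that the inputs are never mutated.

```python
import copy


def dict_deep_update(target, source, is_add_new_key=True):
    """Sketch: return a deep copy of target, recursively updated with source."""
    result = copy.deepcopy(target)
    for key, value in source.items():
        if key not in result:
            # Keys that exist only in source are added only when allowed
            if is_add_new_key:
                result[key] = copy.deepcopy(value)
        elif isinstance(result[key], dict) and isinstance(value, dict):
            # Both sides are dicts: merge recursively instead of replacing
            result[key] = dict_deep_update(result[key], value, is_add_new_key)
        else:
            # Duplicate key: the value from source wins
            result[key] = copy.deepcopy(value)
    return result


base = {"model": {"hidden_size": 32, "dropout": 0.1}, "lr": 1e-3}
override = {"model": {"hidden_size": 64}, "seed": 2024}
merged = dict_deep_update(base, override)
# merged == {"model": {"hidden_size": 64, "dropout": 0.1}, "lr": 1e-3, "seed": 2024}
```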

class utils.DefaultRunnerConfig

Bases: object

DEFAULT_DATASET_ID = 'conttime'
utils.rk4_step_method(diff_func, dt, z0)

Fourth-order Runge-Kutta method for solving ODEs.

Parameters:
  • diff_func (callable) – the differential equation, with signature function(dt, state).

  • dt (Tensor) – shape […, 1], equal to t1 - t0.

  • z0 (Tensor) – shape […, dim], the state at t0.

Returns:

the updated state at t1, a tensor with shape […, dim].
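This is a sketch of one classic RK4 step, z1 = z0 + dt/6 · (k1 + 2k2 + 2k3 + k4). It assumes diff_func(t, z) returns dz/dt with t measured relative to t0. The exact meaning of the first argument in the library is not spelled out here. rk4_step is a hypothetical name, not the library function.

```python
import math


def rk4_step(diff_func, dt, z0):
    """Sketch of one RK4 step from t0 to t1 = t0 + dt.

    Assumes diff_func(t, z) returns dz/dt, with t measured relative to t0.
    Works for floats, NumPy arrays, or tensors that broadcast with dt.
    """
    k1 = diff_func(0.0, z0)
    k2 = diff_func(dt / 2, z0 + dt * k1 / 2)
    k3 = diff_func(dt / 2, z0 + dt * k2 / 2)
    k4 = diff_func(dt, z0 + dt * k3)
    # Weighted average of the four slope estimates
    return z0 + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


# dz/dt = z with z(0) = 1 has the exact solution z(t) = exp(t)
z1 = rk4_step(lambda t, z: z, 0.1, 1.0)
print(z1, math.exp(0.1))
```

Each step has a local error of order dt^5, so for small dt the result closely tracks the exact solution.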

utils.is_tf_available()
utils.is_tensorflow_probability_available()
utils.is_torchvision_available()
utils.is_torch_cuda_available()
utils.is_tf_gpu_available()
utils.is_torch_gpu_available()
utils.is_torch_available()
utils.requires_backends(obj, backends)
class utils.PaddingStrategy(value)

Bases: ExplicitEnum

Possible values for the padding argument in EventTokenizer.__call__. Useful for tab-completion in an IDE.

LONGEST = 'longest'
MAX_LENGTH = 'max_length'
DO_NOT_PAD = 'do_not_pad'
class utils.ExplicitEnum(value)

Bases: str, Enum

Enum with more explicit error message for missing values.
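The sketch below shows how a str-valued Enum can give a more explicit error for missing values by overriding the standard _missing_ hook. The message wording is an assumption. RunnerPhase is reproduced from its members listed above.

```python
from enum import Enum


class ExplicitEnum(str, Enum):
    """Sketch: a str-valued Enum whose lookup error lists the valid values."""

    @classmethod
    def _missing_(cls, value):
        # Enum calls this when value matches no member; raise a more helpful error
        valid = [member.value for member in cls]
        raise ValueError(f"{value!r} is not a valid {cls.__name__}, please select one of {valid}")


class RunnerPhase(ExplicitEnum):
    TRAIN = "train"
    VALIDATE = "validate"
    PREDICT = "predict"


phase = RunnerPhase("train")  # lookup by value
```

The str mixin makes members compare equal to their plain string values, so RunnerPhase.TRAIN == "train".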

class utils.TruncationStrategy(value)

Bases: ExplicitEnum

Possible values for the truncation argument in EventTokenizer.__call__. Useful for tab-completion in an IDE.

LONGEST_FIRST = 'longest_first'
DO_NOT_TRUNCATE = 'do_not_truncate'
utils.is_torch_device(x)

Tests if x is a torch device or not. Safe to call even if torch is not installed.

utils.is_numpy_array(x)

Tests if x is a numpy array or not.
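This is a sketch of how an is_torch_device check can stay safe without torch installed. It checks for the package with importlib.util.find_spec and imports torch lazily only if it is present. Whether EasyTPP uses exactly this mechanism is an assumption.

```python
import importlib.util

import numpy as np


def is_torch_available():
    # find_spec reports whether torch is installed without importing it
    return importlib.util.find_spec("torch") is not None


def is_torch_device(x):
    """Sketch: False when torch is missing, else an isinstance check."""
    if not is_torch_available():
        return False
    import torch  # imported lazily so the function is safe without torch

    return isinstance(x, torch.device)


def is_numpy_array(x):
    return isinstance(x, np.ndarray)


print(is_numpy_array(np.zeros(3)), is_torch_device("cpu"))
```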