EasyTPP Utilities Modules
- utils.py_assert(condition, exception_type, msg)[source]
An assert helper that checks that the condition holds and raises an exception otherwise.
- Parameters:
condition (bool) – the condition that must hold.
exception_type (_StandardError) – the error type to raise, such as ValueError.
msg (str) – the message passed to the raised exception.
- Raises:
exception_type – raised when the condition does not hold.
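A minimal sketch of how such a helper might look, assuming it simply raises `exception_type(msg)` when the condition is false. The body below is an illustration written for this page, not the library's source.

```python
def py_assert(condition, exception_type, msg):
    """Raise exception_type(msg) when condition is falsy (illustrative sketch)."""
    if not condition:
        raise exception_type(msg)


# Usage: validate an argument before using it.
batch_size = 32  # hypothetical value
py_assert(batch_size > 0, ValueError, "batch_size must be positive")
```

Passing the exception type as an argument lets the caller choose between, say, ValueError and TypeError, which a bare `assert` statement cannot do.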
- utils.make_config_string(config, max_num_key=4)[source]
Generate a name for config files.
- Parameters:
config (dict) – configuration dict.
max_num_key (int, optional) – max number of keys to concat in the output. Defaults to 4.
- Returns:
a concatenated string from config dict.
- Return type:
str
- utils.create_folder(*args)[source]
Create the folder if it does not already exist.
- Returns:
the created folder’s path.
- Return type:
str
- utils.save_yaml_config(save_dir, config)[source]
Save a config dict to a YAML file.
- Parameters:
save_dir (str) – the path to save config file.
config (dict) – the target config object.
- utils.load_yaml_config(config_dir)[source]
Load yaml config file from disk.
- Parameters:
config_dir (str or Path) – the path of the config file.
- Returns:
the loaded config.
- Return type:
dict
- class utils.RunnerPhase(value)[source]
Bases:
ExplicitEnum
Model runner phase enum.
- TRAIN = 'train'
- VALIDATE = 'validate'
- PREDICT = 'predict'
- class utils.LogConst[source]
Bases:
object
Format for log handler.
- DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'
- DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s] - %(levelname)s: %(message)s'
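As a rough illustration of how a format string like DEFAULT_FORMAT is consumed, the sketch below attaches it to a standard-library logging handler. The format string is copied from the constant above; the logger name and message are hypothetical.

```python
import logging

# Copied from the documented constant; everything else here is illustrative.
DEFAULT_FORMAT = "[%(asctime)s] [%(levelname)s] %(message)s"

logger = logging.getLogger("easytpp_demo")  # hypothetical logger name
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(DEFAULT_FORMAT))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("training started")
```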
- utils.load_pickle(file_dir)[source]
Load from pickle file.
- Parameters:
file_dir (BinaryIO) – dir of the pickle file.
- Returns:
the loaded data.
- Return type:
any type
- utils.has_key(target_dict, target_keys)[source]
Check if the keys exist in the target dict.
- Parameters:
target_dict (dict) – a dict.
target_keys (str, list) – a key or a list of keys.
- Returns:
True if all the keys exist in the dict; False otherwise.
- Return type:
bool
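The check described above can be sketched as follows. This is an assumption-based illustration: it treats a non-list `target_keys` as a single key, which is one reading of the documented `(str, list)` type.

```python
def has_key(target_dict, target_keys):
    """Return True only if every key in target_keys is present in target_dict."""
    # Accept a single key as well as a list of keys (an assumption based on
    # the documented `(str, list)` parameter type).
    if not isinstance(target_keys, list):
        target_keys = [target_keys]
    return all(key in target_dict for key in target_keys)


config = {"lr": 0.001, "batch_size": 32}  # hypothetical config
print(has_key(config, "lr"))               # True
print(has_key(config, ["lr", "dropout"]))  # False
```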
- utils.array_pad_cols(arr, max_num_cols, pad_index)[source]
Pad the array by columns.
- Parameters:
arr (np.array) – target array to be padded.
max_num_cols (int) – target num cols for padded array.
pad_index (int) – pad index used to fill the padded elements.
- Returns:
the padded array.
- Return type:
np.array
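A sketch of column padding with NumPy, assuming a 2-D input that is padded on the right and never truncated. The actual function may handle other shapes or edge cases differently.

```python
import numpy as np


def array_pad_cols(arr, max_num_cols, pad_index):
    """Right-pad a 2-D array with pad_index until it has max_num_cols columns."""
    num_missing = max_num_cols - arr.shape[1]
    if num_missing <= 0:
        return arr  # already wide enough; this sketch does not truncate
    return np.pad(arr, ((0, 0), (0, num_missing)),
                  mode="constant", constant_values=pad_index)


events = np.array([[1, 2], [3, 4]])  # hypothetical event-type ids
padded = array_pad_cols(events, max_num_cols=4, pad_index=0)
print(padded.shape)  # (2, 4)
```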
- class utils.MetricsTracker[source]
Bases:
object
Track and record the metrics.
- update_best(key, value, epoch)[source]
Update the recorder for the best metrics.
- Parameters:
key (str) – metrics key.
value (float) – metrics value.
epoch (int) – num of epoch.
- Raises:
NotImplementedError – for keys other than ‘loglike’.
- Returns:
whether the recorder has been updated.
- Return type:
bool
- utils.set_device(gpu=-1)[source]
Set up the device.
- Parameters:
gpu (int, optional) – num of GPU to use. Defaults to -1 (do not use a GPU, i.e., use the CPU).
- utils.set_optimizer(optimizer, params, lr)[source]
Set up the optimizer.
- Parameters:
optimizer (str) – name of the optimizer.
params (dict) – dict of params for the optimizer.
lr (float) – learning rate.
- Raises:
NotImplementedError – if the optimizer’s name is wrong or the optimizer is not supported.
- Returns:
torch optimizer.
- Return type:
torch.optim
- utils.set_seed(seed=1029)[source]
Set up the random seed.
- Parameters:
seed (int, optional) – random seed. Defaults to 1029.
- utils.save_pickle(file_dir, object_to_save)[source]
Save the object to a pickle file.
- Parameters:
file_dir (str) – dir of the pickle file.
object_to_save (any) – the target data to be saved.
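A save/load round trip along the lines of save_pickle and load_pickle could look like the sketch below. It assumes both functions take a file path; the file name and payload are hypothetical.

```python
import os
import pickle
import tempfile


def save_pickle(file_dir, object_to_save):
    """Serialize object_to_save to the file at file_dir (illustrative sketch)."""
    with open(file_dir, "wb") as f:
        pickle.dump(object_to_save, f)


def load_pickle(file_dir):
    """Load and return whatever object was pickled at file_dir."""
    with open(file_dir, "rb") as f:
        return pickle.load(f)


# Round trip through a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "metrics.pkl")  # hypothetical file name
    save_pickle(path, {"loglike": -1.23, "epoch": 5})
    restored = load_pickle(path)
```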
- utils.count_model_params(model)[source]
Count the number of params of the model.
- Parameters:
model (torch.nn.Module) – a torch model.
- Returns:
total num of the parameters.
- Return type:
int
- class utils.Registrable[source]
Bases:
object
Any class that inherits from Registrable gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod @BaseClass.register(name).
After that, you can call BaseClass.list_available() to get the keys for the registered subclasses, and BaseClass.by_name(name) to get the corresponding subclass. Note that the registry stores the subclasses themselves, not class instances. In most cases you would then call from_params(params) on the returned subclass.
- classmethod register(name, constructor=None, overwrite=False)[source]
Register a class under a particular name.
- Parameters:
name (str) – the name to register the class under.
constructor (str) – optional (default=None). The name of the method to use on the class to construct the object. If this is given, we will use this method (which must be a classmethod) instead of the default constructor.
overwrite (bool) – optional (default=False). If True, overwrites any existing models registered under name. Else, throws an error if a model is already registered under name.
# Examples
To use this class, you would typically have a base class that inherits from Registrable:

```python
class Transform(Registrable):
    ...
```

Then, if you want to register a subclass, you decorate it like this:

```python
@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        ...
```

Registering a class like this will let you instantiate a class from a config file, where you give "type": "shift-transform", and keys corresponding to the parameters of the __init__ method (note that for this to work, those parameters must have type annotations).
If you want to have the instantiation from a config file call a method other than the constructor, either because you have several different construction paths that could be taken for the same object (as we do in Transform) or because you have logic you want to happen before you get to the constructor, you can register a specific @classmethod as the constructor to use.
- classmethod by_name(name)[source]
Returns a callable function that constructs an argument of the registered class. Because you can register particular functions as constructors for specific names, this isn’t necessarily the __init__ method of some class.
- classmethod resolve_class_name(name)[source]
Returns the subclass that corresponds to the given name, along with the name of the method that was registered as a constructor for that name, if any. This method also allows name to be a fully-specified module name, instead of a name that was already added to the Registry. In that case, you cannot use a separate function as a constructor (as you need to call cls.register() in order to tell us what separate function to use).
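To make the registry pattern concrete, here is a stripped-down sketch of a Registrable-style base class. It is an assumption-laden illustration of the idea (register, by_name, list_available), written independently of the library's source. It omits resolve_class_name and from_params, and the attributes stored in ShiftTransform are hypothetical.

```python
from collections import defaultdict


class Registrable:
    """Minimal named registry, keyed by the base class register() is called on."""

    _registry = defaultdict(dict)

    @classmethod
    def register(cls, name, constructor=None, overwrite=False):
        registry = Registrable._registry[cls]

        def add_subclass(subclass):
            if name in registry and not overwrite:
                raise ValueError(f"{name} is already registered for {cls.__name__}")
            # Store the class itself (not an instance) plus the optional
            # name of an alternative constructor.
            registry[name] = (subclass, constructor)
            return subclass

        return add_subclass

    @classmethod
    def by_name(cls, name):
        subclass, constructor = Registrable._registry[cls][name]
        # Fall back to the class itself when no constructor name was given.
        return subclass if constructor is None else getattr(subclass, constructor)

    @classmethod
    def list_available(cls):
        return list(Registrable._registry[cls].keys())


class Transform(Registrable):
    pass


@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        self.param1 = param1
        self.param2 = param2


transform = Transform.by_name("shift-transform")(param1=1, param2="a")
print(Transform.list_available())  # ['shift-transform']
```

Keying the registry by base class keeps the names registered under Transform separate from those of any other Registrable hierarchy.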
- utils.parse_uri_to_protocol_and_path(uri)[source]
Parse a URI into two parts, protocol and path. If the URI has no protocol, ‘file’ is used as the default.
- Parameters:
uri (str) – the URI identifying a resource, in the form ‘protocol://uri’.
- Returns:
the protocol (str), i.e. the method to access the resource, and the URI (str), i.e. the location of the resource.
- Return type:
tuple of (str, str)
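The split described above can be sketched with plain string handling. This is a minimal illustration that assumes the separator is the literal `://` and that anything without it is a local file path. The example URIs are hypothetical.

```python
def parse_uri_to_protocol_and_path(uri):
    """Split 'protocol://path' into (protocol, path); default protocol is 'file'."""
    if "://" in uri:
        protocol, path = uri.split("://", 1)
        return protocol, path
    return "file", uri


print(parse_uri_to_protocol_and_path("hdfs://data/train.pkl"))  # ('hdfs', 'data/train.pkl')
print(parse_uri_to_protocol_and_path("data/train.pkl"))         # ('file', 'data/train.pkl')
```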
- utils.is_master_process()[source]
Check if the process is the master process across all machines.
- Returns:
bool
- utils.is_local_master_process()[source]
Check if the process is the master process in the local machine.
- Returns:
bool
- utils.dict_deep_update(target, source, is_add_new_key=True)[source]
Deeply update the ‘target’ dict with the ‘source’ dict, and return a new dict that is a deep copy of both.
- Parameters:
target (dict) – the dict to be updated.
source (dict) – the dict providing the new values.
is_add_new_key (bool, optional) – whether to add keys that are in source but not in target. Defaults to True.
- Returns:
a new dict containing the values of both target and source; when a key appears in both, the value from source is kept.
- Return type:
dict
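The merge semantics can be illustrated with a short recursive sketch. It assumes nested dicts are merged key by key and that any non-dict value in source replaces the one in target. The config values are hypothetical.

```python
import copy


def dict_deep_update(target, source, is_add_new_key=True):
    """Return a new dict: a deep copy of target, recursively updated by source."""
    result = copy.deepcopy(target)
    for key, value in source.items():
        if key not in result:
            if is_add_new_key:
                result[key] = copy.deepcopy(value)
            continue
        if isinstance(result[key], dict) and isinstance(value, dict):
            # Merge nested dicts instead of replacing them wholesale.
            result[key] = dict_deep_update(result[key], value, is_add_new_key)
        else:
            # On duplicate keys the value from source wins.
            result[key] = copy.deepcopy(value)
    return result


base = {"model": {"hidden_size": 32, "dropout": 0.1}, "seed": 1029}
override = {"model": {"hidden_size": 64}, "gpu": 0}
merged = dict_deep_update(base, override)
print(merged)
# {'model': {'hidden_size': 64, 'dropout': 0.1}, 'seed': 1029, 'gpu': 0}
```

Because the result is built from deep copies, `base` and `override` are left unchanged in this sketch.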
- utils.rk4_step_method(diff_func, dt, z0)[source]
Fourth order Runge-Kutta method for solving ODEs.
- Parameters:
diff_func (function(dt, state)) – the differential equation.
dt (Tensor) – tensor with shape […, 1], equal to t1 - t0.
z0 (Tensor) – tensor with shape […, dim], the state at t0.
- Returns:
the updated state, a tensor with shape […, dim].
- Return type:
Tensor
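For readers unfamiliar with the scheme, the sketch below performs one classical fourth-order Runge-Kutta step on plain Python floats rather than tensors. The function name `rk4_step` is hypothetical, and the first argument of `diff_func` is interpreted here as the time elapsed since t0, which is an assumption about the documented `function(dt, state)` signature.

```python
import math


def rk4_step(diff_func, dt, z0):
    """One classical RK4 step for dz/dt = diff_func(t, z), with t measured from t0."""
    k1 = diff_func(0.0, z0)
    k2 = diff_func(dt / 2.0, z0 + dt * k1 / 2.0)
    k3 = diff_func(dt / 2.0, z0 + dt * k2 / 2.0)
    k4 = diff_func(dt, z0 + dt * k3)
    # Weighted average of the four slope estimates.
    return z0 + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0


# dz/dt = -z has the exact solution z(t) = z0 * exp(-t).
z1 = rk4_step(lambda t, z: -z, dt=0.1, z0=1.0)
error = abs(z1 - math.exp(-0.1))
```

Comparing against the closed-form solution of a simple decay equation is a quick sanity check; for this step size the error should be far below 1e-6.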
- class utils.PaddingStrategy(value)[source]
Bases:
ExplicitEnum
Possible values for the padding argument in [EventTokenizer.__call__]. Useful for tab-completion in an IDE.
- LONGEST = 'longest'
- MAX_LENGTH = 'max_length'
- DO_NOT_PAD = 'do_not_pad'
- class utils.ExplicitEnum(value)[source]
Bases:
str, Enum
Enum with more explicit error message for missing values.
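One common way to get a more explicit error for missing values is to override the `_missing_` hook of `enum.Enum`. The sketch below assumes that approach and an error message of its own wording; the library's actual message may differ.

```python
from enum import Enum


class ExplicitEnum(str, Enum):
    """Enum whose lookup failure lists the valid values (illustrative sketch)."""

    @classmethod
    def _missing_(cls, value):
        valid = [member.value for member in cls]
        raise ValueError(
            f"{value!r} is not a valid {cls.__name__}; choose one of {valid}"
        )


class PaddingStrategy(ExplicitEnum):
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


strategy = PaddingStrategy("longest")  # lookup by value
```

Mixing in `str` means members compare equal to their string values, so `strategy == "longest"` holds and config files can keep using plain strings.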
- class utils.TruncationStrategy(value)[source]
Bases:
ExplicitEnum
Possible values for the truncation argument in [EventTokenizer.__call__]. Useful for tab-completion in an IDE.
- LONGEST_FIRST = 'longest_first'
- DO_NOT_TRUNCATE = 'do_not_truncate'