EasyTPP Utilities Modules
- utils.py_assert(condition, exception_type, msg)[source]
Assert that the condition holds; otherwise raise an exception of the given type with the given message.
- Parameters:
condition (bool) – the condition that must hold.
exception_type (_StandardError) – the error type to raise, such as ValueError.
msg (str) – the error message.
- Raises:
exception_type – raised when the condition does not hold.
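The behaviour described above can be pictured with a minimal sketch. The name `py_assert_sketch` is hypothetical and the body is an assumption based only on the documented parameters, not the library's actual source.

```python
def py_assert_sketch(condition, exception_type, msg):
    """Raise `exception_type(msg)` unless `condition` is truthy (assumed behaviour)."""
    if not condition:
        raise exception_type(msg)


# Usage: validate an argument before using it.
batch_size = 32
py_assert_sketch(batch_size > 0, ValueError, "batch_size must be positive")
```

Unlike the built-in `assert` statement, a helper of this kind is not stripped when Python runs with `-O`, and the caller chooses the exception type.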
- utils.make_config_string(config, max_num_key=4)[source]
Generate a name for config files.
- Parameters:
config (dict) – configuration dict.
max_num_key (int, optional) – max number of keys to concat in the output. Defaults to 4.
- Returns:
a concatenated string from config dict.
- Return type:
str
- utils.create_folder(*args)[source]
Create path if the folder doesn’t exist.
- Returns:
the created folder’s path.
- Return type:
str
- utils.save_yaml_config(save_dir, config)[source]
Save a config dict to a YAML file.
- Parameters:
save_dir (str) – the path to save config file.
config (dict) – the target config object.
- utils.load_yaml_config(config_dir)[source]
Load yaml config file from disk.
- Parameters:
config_dir (str or Path) – the path of the config file.
- Returns:
the loaded config.
- Return type:
dict
- class utils.RunnerPhase(value)[source]
Bases:
ExplicitEnum
Model runner phase enum.
- TRAIN = 'train'
- VALIDATE = 'validate'
- PREDICT = 'predict'
- class utils.LogConst[source]
Bases:
object
Format for log handler.
- DEFAULT_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'
- DEFAULT_FORMAT_LONG = '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s] - %(levelname)s: %(message)s'
- utils.load_pickle(file_dir)[source]
Load from pickle file.
- Parameters:
file_dir (BinaryIO) – dir of the pickle file.
- Returns:
the loaded data.
- Return type:
any type
- utils.has_key(target_dict, target_keys)[source]
Check if the keys exist in the target dict.
- Parameters:
target_dict (dict) – a dict.
target_keys (str, list) – a key or a list of keys.
- Returns:
True if all the keys exist in the dict; False otherwise.
- Return type:
bool
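A minimal sketch of the check described above, assuming a single string is treated as one key. The name `has_key_sketch` is hypothetical and the body is not taken from the library's source.

```python
def has_key_sketch(target_dict, target_keys):
    """Return True only if every requested key is present in the dict."""
    # Assumption: a bare string means "one key", not an iterable of characters.
    if isinstance(target_keys, str):
        target_keys = [target_keys]
    return all(key in target_dict for key in target_keys)


config = {"lr": 0.001, "batch_size": 32}
all_present = has_key_sketch(config, ["lr", "batch_size"])
some_missing = has_key_sketch(config, ["lr", "max_epoch"])
```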
- utils.array_pad_cols(arr, max_num_cols, pad_index)[source]
Pad the array by columns.
- Parameters:
arr (np.array) – target array to be padded.
max_num_cols (int) – target num cols for padded array.
pad_index (int) – the value used to fill the padded elements.
- Returns:
the padded array.
- Return type:
np.array
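Column padding of this kind can be sketched with `numpy.pad`. The sketch assumes a 2-D array that is padded on the right and has at most `max_num_cols` columns; the name `array_pad_cols_sketch` is hypothetical and the library's own implementation may differ (for example in the output dtype).

```python
import numpy as np


def array_pad_cols_sketch(arr, max_num_cols, pad_index):
    """Right-pad a 2-D array with `pad_index` until it has `max_num_cols` columns."""
    num_missing = max_num_cols - arr.shape[1]
    # Pad nothing along rows, and `num_missing` entries after the last column.
    return np.pad(arr, ((0, 0), (0, num_missing)),
                  mode="constant", constant_values=pad_index)


seqs = np.array([[1, 2], [3, 4]])
padded = array_pad_cols_sketch(seqs, max_num_cols=4, pad_index=0)
```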
- class utils.MetricsTracker[source]
Bases:
object
Track and record the metrics.
- update_best(key, value, epoch)[source]
Update the recorder for the best metrics.
- Parameters:
key (str) – metrics key.
value (float) – metrics value.
epoch (int) – num of epoch.
- Raises:
NotImplementedError – for keys other than ‘loglike’.
- Returns:
whether the recorder has been updated.
- Return type:
bool
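The documented contract of `update_best` can be illustrated with a small stand-in class. This is a sketch under stated assumptions: a higher log-likelihood counts as better, and the class and attribute names (`MetricsTrackerSketch`, `current_best`, `best_epoch`) are hypothetical.

```python
class MetricsTrackerSketch:
    """Keep the best value seen so far for the 'loglike' metric."""

    def __init__(self):
        # Assumption: larger log-likelihood is better, so start from -inf.
        self.current_best = {"loglike": float("-inf")}
        self.best_epoch = None

    def update_best(self, key, value, epoch):
        if key != "loglike":
            raise NotImplementedError(f"unsupported metric key: {key}")
        if value > self.current_best[key]:
            self.current_best[key] = value
            self.best_epoch = epoch
            return True
        return False


tracker = MetricsTrackerSketch()
first = tracker.update_best("loglike", -1.5, epoch=0)   # improves on -inf
second = tracker.update_best("loglike", -2.0, epoch=1)  # worse, no update
```

The boolean return value is convenient for deciding whether to save a checkpoint after an evaluation pass.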
- utils.set_device(gpu=-1)[source]
Set up the device.
- Parameters:
gpu (int, optional) – num of GPU to use. Defaults to -1 (no GPU, i.e., use the CPU).
- utils.set_optimizer(optimizer, params, lr)[source]
Set up the optimizer.
- Parameters:
optimizer (str) – name of the optimizer.
params (dict) – dict of params for the optimizer.
lr (float) – learning rate.
- Raises:
NotImplementedError – if the optimizer’s name is wrong or the optimizer is not supported.
- Returns:
torch optimizer.
- Return type:
torch.optim
- utils.set_seed(seed=1029)[source]
Set up the random seed.
- Parameters:
seed (int, optional) – random seed. Defaults to 1029.
- utils.save_pickle(file_dir, object_to_save)[source]
Save the object to a pickle file.
- Parameters:
file_dir (str) – dir of the pickle file.
object_to_save (any) – the target data to be saved.
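A save/load round trip of the kind `save_pickle` and `load_pickle` provide can be sketched with the standard `pickle` module. The function names are hypothetical, and the sketch assumes `file_dir` is a file path in both directions, which the library's signatures may not match exactly.

```python
import os
import pickle
import tempfile


def save_pickle_sketch(file_dir, object_to_save):
    """Serialize an object to the given path."""
    with open(file_dir, "wb") as f:
        pickle.dump(object_to_save, f)


def load_pickle_sketch(file_dir):
    """Read back an object previously written with pickle."""
    with open(file_dir, "rb") as f:
        return pickle.load(f)


# Round trip through a temporary directory so nothing is left on disk.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "data.pkl")
    save_pickle_sketch(path, {"seq_len": [3, 5]})
    restored = load_pickle_sketch(path)
```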
- utils.count_model_params(model)[source]
Count the number of params of the model.
- Parameters:
model (torch.nn.Module) – a torch model.
- Returns:
total num of the parameters.
- Return type:
int
- class utils.Registrable[source]
Bases:
object
Any class that inherits from Registrable gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod @BaseClass.register(name).
After that, you can call BaseClass.list_available() to get the keys for the registered subclasses, and BaseClass.by_name(name) to get the corresponding subclass. Note that the registry stores the subclasses themselves, not class instances. In most cases you would then call from_params(params) on the returned subclass.
- classmethod register(name, constructor=None, overwrite=False)[source]
Register a class under a particular name.
- Parameters:
name (str) – the name to register the class under.
constructor (str, optional) – the name of the method to use on the class to construct the object. If this is given, this method (which must be a classmethod) is used instead of the default constructor. Defaults to None.
overwrite (bool, optional) – if True, overwrites any existing model registered under name; otherwise, throws an error if a model is already registered under name. Defaults to False.
Examples
To use this class, you would typically have a base class that inherits from Registrable:
```python
class Transform(Registrable):
    ...
```
Then, if you want to register a subclass, you decorate it like this:
```python
@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        ...
```
Registering a class like this will let you instantiate a class from a config file, where you give "type": "shift-transform" and keys corresponding to the parameters of the __init__ method (note that for this to work, those parameters must have type annotations).
If you want the instantiation from a config file to call a method other than the constructor, either because you have several different construction paths that could be taken for the same object (as we do in Transform) or because you have logic you want to happen before you get to the constructor, you can register a specific @classmethod as the constructor to use.
- classmethod by_name(name)[source]
Returns a callable function that constructs an argument of the registered class. Because you can register particular functions as constructors for specific names, this isn’t necessarily the __init__ method of some class.
- classmethod resolve_class_name(name)[source]
Returns the subclass that corresponds to the given name, along with the name of the method that was registered as a constructor for that name, if any. This method also allows name to be a fully-specified module name, instead of a name that was already added to the Registry. In that case, you cannot use a separate function as a constructor (as you need to call cls.register() in order to tell us what separate function to use).
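The registry mechanism described for this class can be sketched in a few lines. This is an illustrative stand-in, not the library's implementation: `RegistrableSketch` and its `_registry` attribute are hypothetical names, and it omits `resolve_class_name` and the fully-specified module name lookup.

```python
class RegistrableSketch:
    # Maps each base class to {name: (subclass, constructor method name or None)}.
    _registry = {}

    @classmethod
    def register(cls, name, constructor=None, overwrite=False):
        registry = RegistrableSketch._registry.setdefault(cls, {})

        def add_subclass_to_registry(subclass):
            if name in registry and not overwrite:
                raise ValueError(f"{name} is already registered for {cls.__name__}")
            registry[name] = (subclass, constructor)
            return subclass  # the decorated class itself is left unchanged

        return add_subclass_to_registry

    @classmethod
    def by_name(cls, name):
        subclass, constructor = RegistrableSketch._registry[cls][name]
        # Without a registered constructor, the class itself is the callable.
        return subclass if constructor is None else getattr(subclass, constructor)

    @classmethod
    def list_available(cls):
        return list(RegistrableSketch._registry.get(cls, {}))


class Transform(RegistrableSketch):
    pass


@Transform.register("shift-transform")
class ShiftTransform(Transform):
    def __init__(self, param1: int, param2: str):
        self.param1 = param1
        self.param2 = param2


transform = Transform.by_name("shift-transform")(param1=1, param2="a")
```

Keying the registry by base class keeps the names of one hierarchy (here `Transform`) from colliding with names registered under another base class.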
- utils.parse_uri_to_protocol_and_path(uri)[source]
Parse a URI into two parts: protocol and path. The protocol defaults to ‘file’ when the URI does not specify one.
- Parameters:
uri (str) – the URI that identifies a resource, in a format like ‘protocol://uri’.
- Returns:
the protocol (the method to access the resource) and the path (the location of the resource).
- Return type:
tuple of (str, str)
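The parsing rule above can be sketched as follows. The name `parse_uri_sketch` is hypothetical, and splitting on the first `://` is an assumption; the library may treat malformed input differently.

```python
def parse_uri_sketch(uri):
    """Split 'protocol://path' into (protocol, path), defaulting to 'file'."""
    protocol, separator, path = uri.partition("://")
    if not separator:
        # No explicit protocol: treat the whole string as a local file path.
        return "file", uri
    return protocol, path


remote = parse_uri_sketch("hdfs://data/train.pkl")
local = parse_uri_sketch("data/train.pkl")
```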
- utils.is_master_process()[source]
Check if the process is the master process in all machines.
- Returns:
bool
- utils.is_local_master_process()[source]
Check if the process is the master process in the local machine.
- Returns:
bool
- utils.dict_deep_update(target, source, is_add_new_key=True)[source]
Deeply update the ‘target’ dict with the ‘source’ dict and return a new dict that is a deep copy of both.
- Parameters:
target (dict) – the dict to be updated.
source (dict) – the dict that provides the new values.
is_add_new_key (bool, optional) – whether to add keys that are in source but not in target. Defaults to True.
- Returns:
a new dict that contains both target and source values, with the values from source taking precedence when a key is duplicated.
- Return type:
dict
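A recursive merge with these semantics might look like the sketch below. The name `dict_deep_update_sketch` is hypothetical, and the handling of non-dict values (source simply replaces target) is an assumption.

```python
import copy


def dict_deep_update_sketch(target, source, is_add_new_key=True):
    """Return a new dict: a deep copy of `target` overridden by `source`."""
    result = copy.deepcopy(target)
    for key, value in source.items():
        if key not in result:
            if is_add_new_key:
                result[key] = copy.deepcopy(value)
        elif isinstance(result[key], dict) and isinstance(value, dict):
            # Both sides are dicts: merge them level by level.
            result[key] = dict_deep_update_sketch(result[key], value, is_add_new_key)
        else:
            result[key] = copy.deepcopy(value)
    return result


base = {"model": {"hidden_size": 32, "dropout": 0.1}, "seed": 1}
override = {"model": {"hidden_size": 64}, "gpu": 0}
merged = dict_deep_update_sketch(base, override)
merged_no_new = dict_deep_update_sketch(base, override, is_add_new_key=False)
```

Because the result is built from deep copies, neither `base` nor `override` is modified, which matters when a default config is reused across several experiments.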
- utils.rk4_step_method(diff_func, dt, z0)[source]
Fourth order Runge-Kutta method for solving ODEs.
- Parameters:
diff_func (function(dt, state)) – the differential equation.
dt (Tensor with shape […, 1]) – the step size, equal to t1 - t0.
z0 (Tensor with shape […, dim]) – the state at t0.
- Returns:
the updated state.
- Return type:
Tensor with shape […, dim]
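For reference, one step of the classical fourth-order Runge-Kutta scheme can be sketched as below. The sketch assumes an autonomous ODE, so `diff_func` takes only the state (which may differ from the library's calling convention); the name `rk4_step_sketch` is hypothetical. It uses plain floats, but the same arithmetic broadcasts over tensors.

```python
import math


def rk4_step_sketch(diff_func, dt, z0):
    """Advance the state z0 by one RK4 step of size dt for dz/dt = diff_func(z)."""
    k1 = diff_func(z0)
    k2 = diff_func(z0 + dt * k1 / 2.0)
    k3 = diff_func(z0 + dt * k2 / 2.0)
    k4 = diff_func(z0 + dt * k3)
    # Weighted average of the four slope estimates.
    return z0 + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0


# dz/dt = -z with z(0) = 1 has the exact solution z(t) = exp(-t).
z1 = rk4_step_sketch(lambda z: -z, dt=0.1, z0=1.0)
error = abs(z1 - math.exp(-0.1))
```

The local error of a single RK4 step scales with the fifth power of `dt`, which is why one step of size 0.1 already agrees closely with the exact solution in this example.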
- class utils.PaddingStrategy(value)[source]
Bases:
ExplicitEnum
Possible values for the padding argument in EventTokenizer.__call__. Useful for tab-completion in an IDE.
- LONGEST = 'longest'
- MAX_LENGTH = 'max_length'
- DO_NOT_PAD = 'do_not_pad'
- class utils.ExplicitEnum(value)[source]
Bases:
str, Enum
Enum with a more explicit error message for missing values.
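One common way to get a more explicit error for unknown values is to override the `_missing_` hook of `enum.Enum`. The sketch below assumes that approach; the class names and the message wording are hypothetical and may not match the library.

```python
from enum import Enum


class ExplicitEnumSketch(str, Enum):
    @classmethod
    def _missing_(cls, value):
        # Called by Enum when `value` matches no member; list the valid choices.
        valid = [member.value for member in cls]
        raise ValueError(
            f"{value!r} is not a valid {cls.__name__}; choose one of {valid}"
        )


class PaddingStrategySketch(ExplicitEnumSketch):
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


strategy = PaddingStrategySketch("longest")
```

Mixing in `str` means each member compares equal to its string value, so config values read from YAML can be compared against members directly.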
- class utils.TruncationStrategy(value)[source]
Bases:
ExplicitEnum
Possible values for the truncation argument in EventTokenizer.__call__. Useful for tab-completion in an IDE.
- LONGEST_FIRST = 'longest_first'
- DO_NOT_TRUNCATE = 'do_not_truncate'