Source code for easy_tpp.utils.misc

import copy
import os
import pickle

import numpy as np
import yaml

from easy_tpp.utils.const import RunnerPhase


def py_assert(condition, exception_type, msg):
    """An assert function that ensures the condition holds, otherwise throws a message.

    Args:
        condition (bool): a formula to ensure validity.
        exception_type (_StandardError): Error type, such as ValueError.
        msg (str): a message to throw out.

    Raises:
        exception_type: throw an error when the condition does not hold.
    """
    if not condition:
        raise exception_type(msg)
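A minimal usage sketch of py_assert (the conditions and messages below are made up for illustration):

>>> py_assert(2 > 1, ValueError, 'this condition holds, so nothing is raised')
>>> py_assert(isinstance('3', int), ValueError, 'expected an int')
Traceback (most recent call last):
    ...
ValueError: expected an int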
def make_config_string(config, max_num_key=4):
    """Generate a name for config files.

    Args:
        config (dict): configuration dict.
        max_num_key (int, optional): max number of keys to concat in the output. Defaults to 4.

    Returns:
        str: a concatenated string built from the config dict.
    """
    str_config = ''
    num_key = 0
    for k, v in config.items():
        if num_key < max_num_key:
            # for the moment we only record model name
            if k == 'name':
                str_config += str(v) + '_'
                num_key += 1
    return str_config[:-1]
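An illustrative call (the config keys and values are hypothetical); only the 'name' entry ends up in the returned string:

>>> make_config_string({'name': 'NHP', 'hidden_size': 32, 'num_layers': 2})
'NHP'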
def save_yaml_config(save_dir, config):
    """Save a dict of config to a yaml format file.

    Args:
        save_dir (str): the path to save the config file.
        config (dict): the target config object.
    """
    prt_dir = os.path.dirname(save_dir)

    from collections import OrderedDict
    # add yaml representer for different type
    yaml.add_representer(
        OrderedDict,
        lambda dumper, data: dumper.represent_mapping('tag:yaml.org,2002:map', data.items())
    )

    if prt_dir != '' and not os.path.exists(prt_dir):
        os.makedirs(prt_dir)

    with open(save_dir, 'w') as f:
        yaml.dump(config, stream=f, default_flow_style=False, sort_keys=False)

    return
def load_yaml_config(config_dir):
    """Load a yaml config file from disk.

    Args:
        config_dir (str or Path): the path of the config file.

    Returns:
        dict: the loaded config.
    """
    with open(config_dir) as config_file:
        # load configs
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    return config
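A sketch of the save/load round trip, assuming a writable working directory; the file name and config values are made up:

>>> save_yaml_config('demo_config.yaml', {'model': {'name': 'NHP', 'hidden_size': 32}})
>>> load_yaml_config('demo_config.yaml')
{'model': {'name': 'NHP', 'hidden_size': 32}}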
def get_stage(stage):
    """Map a stage name to the corresponding RunnerPhase.

    Args:
        stage (str): stage name, e.g., 'train', 'valid' or 'test'.

    Returns:
        RunnerPhase: the runner phase matching the stage name.
    """
    stage = stage.lower()
    if stage in ['train', 'training']:
        return RunnerPhase.TRAIN
    elif stage in ['valid', 'dev', 'eval']:
        return RunnerPhase.VALIDATE
    else:
        return RunnerPhase.PREDICT
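Illustratively, the lookup is case-insensitive and any unrecognized stage name falls back to the prediction phase:

>>> get_stage('Training') == RunnerPhase.TRAIN
True
>>> get_stage('test') == RunnerPhase.PREDICT
True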
def create_folder(*args):
    """Create path if the folder doesn't exist.

    Returns:
        str: the created folder's path.
    """
    path = os.path.join(*args)
    if not os.path.exists(path):
        os.makedirs(path)
    return path
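A small sketch, assuming the process may create directories under the current working directory (the folder names are hypothetical):

>>> path = create_folder('outputs', 'checkpoints')
>>> os.path.isdir(path)
True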
def load_pickle(file_dir):
    """Load from pickle file.

    Args:
        file_dir (str): dir of the pickle file.

    Returns:
        any type: the loaded data.
    """
    with open(file_dir, 'rb') as file:
        try:
            data = pickle.load(file, encoding='latin-1')
        except Exception:
            # rewind before retrying with the default encoding
            file.seek(0)
            data = pickle.load(file)
    return data
def save_pickle(file_dir, object_to_save):
    """Save the object to a pickle file.

    Args:
        file_dir (str): dir of the pickle file.
        object_to_save (any): the target data to be saved.
    """
    with open(file_dir, "wb") as f_out:
        pickle.dump(object_to_save, f_out)
    return
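A pickle round-trip sketch pairing save_pickle with load_pickle; the file name and payload are illustrative:

>>> save_pickle('demo.pkl', {'num_event_types': 3, 'seqs': [1, 2]})
>>> load_pickle('demo.pkl')
{'num_event_types': 3, 'seqs': [1, 2]}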
def has_key(target_dict, target_keys):
    """Check if the keys exist in the target dict.

    Args:
        target_dict (dict): a dict.
        target_keys (str, list): a key or a list of keys.

    Returns:
        bool: True if all the keys exist in the dict; False otherwise.
    """
    if not isinstance(target_keys, list):
        target_keys = [target_keys]
    for k in target_keys:
        if k not in target_dict:
            return False
    return True
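Illustrative checks with a made-up config dict; either a single key or a list of keys can be passed:

>>> has_key({'lr': 0.01, 'epochs': 10}, 'lr')
True
>>> has_key({'lr': 0.01, 'epochs': 10}, ['lr', 'batch_size'])
False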
def array_pad_cols(arr, max_num_cols, pad_index):
    """Pad the array by columns.

    Args:
        arr (np.array): target array to be padded.
        max_num_cols (int): target num cols for padded array.
        pad_index (int): pad index to fill out the padded elements.

    Returns:
        np.array: the padded array.
    """
    res = np.ones((arr.shape[0], max_num_cols)) * pad_index
    res[:, :arr.shape[1]] = arr
    return res
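A small sketch with made-up values; the array is widened to four columns and the new cells are filled with the pad index (output shown for a typical NumPy configuration):

>>> array_pad_cols(np.array([[1, 2], [3, 4]]), max_num_cols=4, pad_index=-1)
array([[ 1.,  2., -1., -1.],
       [ 3.,  4., -1., -1.]])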
def concat_element(arrs, pad_index):
    """Concat the same element from each batch output, padding to the max number of columns."""
    n_lens = len(arrs)
    n_elements = len(arrs[0])

    # find out the max seq len (num cols) in output arrays
    max_len = max([x[0].shape[1] for x in arrs])

    concated_outputs = []
    for j in range(n_elements):
        a_output = []
        for i in range(n_lens):
            arrs_ = array_pad_cols(arrs[i][j], max_num_cols=max_len, pad_index=pad_index)
            a_output.append(arrs_)

        concated_outputs.append(np.concatenate(a_output, axis=0))

    # n_elements * [ [n_lens, dim_of_element] ]
    return concated_outputs
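A sketch with two hypothetical batch outputs, each holding a single element; the shorter sequence is padded to the longest one before the batches are stacked row-wise:

>>> batch_1 = [np.array([[1., 2., 3.]])]  # one element, seq len 3
>>> batch_2 = [np.array([[4., 5.]])]      # same element from another batch, seq len 2
>>> out = concat_element([batch_1, batch_2], pad_index=0)
>>> out[0]
array([[1., 2., 3.],
       [4., 5., 0.]])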
def to_dict(obj, classkey=None):
    """Recursively convert an object into a plain dict.

    Callables, private attributes and the 'name' attribute are skipped; if
    classkey is given, the class name is stored under that key.
    """
    if isinstance(obj, dict):
        data = {}
        for (k, v) in obj.items():
            data[k] = to_dict(v, classkey)
        return data
    elif hasattr(obj, "_ast"):
        return to_dict(obj._ast())
    elif hasattr(obj, "__iter__") and not isinstance(obj, str):
        # strings are iterable too, but should be returned as-is
        return [to_dict(v, classkey) for v in obj]
    elif hasattr(obj, "__dict__"):
        data = dict([(key, to_dict(value, classkey))
                     for key, value in obj.__dict__.items()
                     if not callable(value) and not key.startswith('_') and key not in ['name']])
        if classkey is not None and hasattr(obj, "__class__"):
            data[classkey] = obj.__class__.__name__
        return data
    else:
        return obj
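A sketch with a hypothetical class; private attributes and the 'name' attribute are dropped from the result:

>>> class ModelConfig:
...     def __init__(self):
...         self.hidden_size = 32
...         self.name = 'NHP'   # dropped: 'name' is explicitly excluded
...         self._cache = None  # dropped: private attribute
>>> to_dict(ModelConfig())
{'hidden_size': 32}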
def dict_deep_update(target, source, is_add_new_key=True):
    """Update the 'target' dict by the 'source' dict deeply, and return a new dict copied from both.

    Args:
        target (dict): the dict to be updated.
        source (dict): the dict providing new values.
        is_add_new_key (bool): default True. Whether to add a key that is in source but not in target.

    Returns:
        dict: a new dict that contains both target and source values, keeping the values from source
            when a key is duplicated.
    """
    # deep copy to avoid modifying the original dict
    result = copy.deepcopy(target) if target is not None else {}
    if source is None:
        return result

    for key, value in source.items():
        if key not in result:
            if is_add_new_key:
                result[key] = value
            continue

        # both target and source have the same key
        base_type_list = [int, float, str, tuple, bool]
        if type(result[key]) in base_type_list or type(source[key]) in base_type_list:
            result[key] = value
        else:
            result[key] = dict_deep_update(result[key], source[key], is_add_new_key=is_add_new_key)
    return result
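An illustrative deep update with made-up training configs; nested dicts are merged while scalar values from the source win:

>>> base = {'trainer': {'lr': 0.01, 'epochs': 10}, 'seed': 1}
>>> override = {'trainer': {'lr': 0.001}, 'device': 'cpu'}
>>> dict_deep_update(base, override)
{'trainer': {'lr': 0.001, 'epochs': 10}, 'seed': 1, 'device': 'cpu'}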