import copy
import os
import pickle
import numpy as np
import yaml
from easy_tpp.utils.const import RunnerPhase

def py_assert(condition, exception_type, msg):
    """An assert function that ensures the condition holds, otherwise raises an error with a message.

    Args:
        condition (bool): the condition to check.
        exception_type (type): the error class to raise, such as ValueError.
        msg (str): the message attached to the raised error.

    Raises:
        exception_type: raised when the condition does not hold.
    """
    if not condition:
        raise exception_type(msg)
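
# Usage sketch (illustrative values; passes silently when the condition holds):
#   >>> py_assert(isinstance(3, int), ValueError, 'expected an int')
#   >>> py_assert(isinstance('3', int), ValueError, 'expected an int')
#   ValueError: expected an int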

def make_config_string(config, max_num_key=4):
    """Generate a name for config files.

    Args:
        config (dict): configuration dict.
        max_num_key (int, optional): max number of keys to concat in the output. Defaults to 4.

    Returns:
        str: a concatenated string built from the config dict.
    """
    str_config = ''
    num_key = 0
    for k, v in config.items():
        if num_key < max_num_key:  # for the moment we only record the model name
            if k == 'name':
                str_config += str(v) + '_'
                num_key += 1
    return str_config[:-1]
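
# Usage sketch (illustrative config; only the 'name' key is recorded for now):
#   >>> make_config_string({'name': 'NHP', 'hidden_size': 32})
#   'NHP'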

def save_yaml_config(save_dir, config):
    """Save a config dict to a yaml-format file.

    Args:
        save_dir (str): the path of the target yaml file.
        config (dict): the config object to save.
    """
    prt_dir = os.path.dirname(save_dir)

    from collections import OrderedDict

    # add a yaml representer so OrderedDict is dumped as a plain mapping
    yaml.add_representer(
        OrderedDict,
        lambda dumper, data: dumper.represent_mapping('tag:yaml.org,2002:map', data.items())
    )

    if prt_dir != '' and not os.path.exists(prt_dir):
        os.makedirs(prt_dir)

    with open(save_dir, 'w') as f:
        yaml.dump(config, stream=f, default_flow_style=False, sort_keys=False)

def load_yaml_config(config_dir):
    """Load a yaml config file from disk.

    Args:
        config_dir (str): the path of the config file.

    Returns:
        dict: the loaded config.
    """
    with open(config_dir) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    return config
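
# Round-trip sketch (hypothetical path; assumes the directory is writable):
#   >>> save_yaml_config('./configs/run.yaml', {'model': {'name': 'NHP'}})
#   >>> load_yaml_config('./configs/run.yaml')
#   {'model': {'name': 'NHP'}}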

def get_stage(stage):
    """Map a stage name to the corresponding RunnerPhase.

    Args:
        stage (str): the name of the stage, e.g., 'train', 'valid' or 'test'.

    Returns:
        RunnerPhase: the matched phase; any unrecognized name falls through to RunnerPhase.PREDICT.
    """
    stage = stage.lower()
    if stage in ['train', 'training']:
        return RunnerPhase.TRAIN
    elif stage in ['valid', 'dev', 'eval']:
        return RunnerPhase.VALIDATE
    else:
        return RunnerPhase.PREDICT
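
# Usage sketch (matching is case-insensitive; exact repr depends on how RunnerPhase is defined):
#   >>> get_stage('Training')
#   RunnerPhase.TRAIN
#   >>> get_stage('test')  # unrecognized names fall through to PREDICT
#   RunnerPhase.PREDICT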

def create_folder(*args):
    """Create the path if the folder doesn't exist.

    Args:
        *args (str): path fragments joined by os.path.join.

    Returns:
        str: the created folder's path.
    """
    path = os.path.join(*args)
    if not os.path.exists(path):
        os.makedirs(path)
    return path
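
# Usage sketch (hypothetical folders, created relative to the working directory; separator shown for POSIX):
#   >>> create_folder('logs', 'run_1')
#   'logs/run_1'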

def load_pickle(file_dir):
    """Load data from a pickle file.

    Args:
        file_dir (str): the path of the pickle file.

    Returns:
        Any: the loaded data.
    """
    with open(file_dir, 'rb') as file:
        try:
            data = pickle.load(file, encoding='latin-1')
        except Exception:
            file.seek(0)  # rewind before retrying without the legacy encoding
            data = pickle.load(file)
    return data

def save_pickle(file_dir, object_to_save):
    """Save an object to a pickle file.

    Args:
        file_dir (str): the path of the pickle file.
        object_to_save (Any): the target data to be saved.
    """
    with open(file_dir, "wb") as f_out:
        pickle.dump(object_to_save, f_out)
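
# Round-trip sketch (hypothetical path):
#   >>> save_pickle('cache.pkl', {'seq_len': [3, 5, 2]})
#   >>> load_pickle('cache.pkl')
#   {'seq_len': [3, 5, 2]}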

def has_key(target_dict, target_keys):
    """Check if the keys exist in the target dict.

    Args:
        target_dict (dict): a dict.
        target_keys (str, list): a key or a list of keys.

    Returns:
        bool: True if all the keys exist in the dict; False otherwise.
    """
    if not isinstance(target_keys, list):
        target_keys = [target_keys]
    for k in target_keys:
        if k not in target_dict:
            return False
    return True
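
# Usage sketch (illustrative dict):
#   >>> has_key({'lr': 0.01, 'epochs': 10}, ['lr', 'epochs'])
#   True
#   >>> has_key({'lr': 0.01}, 'batch_size')
#   False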

def array_pad_cols(arr, max_num_cols, pad_index):
    """Pad the array by columns.

    Args:
        arr (np.array): target array to be padded.
        max_num_cols (int): target num cols for the padded array.
        pad_index (int): pad index to fill the padded elements.

    Returns:
        np.array: the padded array.
    """
    # np.ones defaults to float64, so the result is a float array regardless of arr's dtype
    res = np.ones((arr.shape[0], max_num_cols)) * pad_index
    res[:, :arr.shape[1]] = arr
    return res
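
# Usage sketch (illustrative values; note the float output dtype):
#   >>> array_pad_cols(np.array([[1, 2], [3, 4]]), max_num_cols=4, pad_index=-1)
#   array([[ 1.,  2., -1., -1.],
#          [ 3.,  4., -1., -1.]])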

def concat_element(arrs, pad_index):
    """Concat each element from every batch output, padding to the max sequence length.

    Args:
        arrs (list): a list of batch outputs, where each batch is a list of np.array elements.
        pad_index (int): pad index to fill the padded elements.

    Returns:
        list: n_elements arrays, each concatenated over all batches along axis 0.
    """
    n_lens = len(arrs)
    n_elements = len(arrs[0])

    # find the max seq len (num cols) among the output arrays
    max_len = max([x[0].shape[1] for x in arrs])

    concated_outputs = []
    for j in range(n_elements):
        a_output = []
        for i in range(n_lens):
            arrs_ = array_pad_cols(arrs[i][j], max_num_cols=max_len, pad_index=pad_index)
            a_output.append(arrs_)

        concated_outputs.append(np.concatenate(a_output, axis=0))

    # n_elements * [ [n_lens, dim_of_element] ]
    return concated_outputs
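
# Usage sketch (two hypothetical batches with one element each; the shorter batch is padded to 5 columns):
#   >>> batch_a = [np.zeros((2, 3))]
#   >>> batch_b = [np.zeros((2, 5))]
#   >>> out = concat_element([batch_a, batch_b], pad_index=-1)
#   >>> out[0].shape
#   (4, 5)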

def to_dict(obj, classkey=None):
    """Recursively convert an object and its nested attributes into plain dicts and lists.

    Args:
        obj (Any): the object to convert.
        classkey (str, optional): if given, store each object's class name under this key.

    Returns:
        Any: a dict / list / plain value mirroring the input object.
    """
    if isinstance(obj, dict):
        data = {}
        for (k, v) in obj.items():
            data[k] = to_dict(v, classkey)
        return data
    elif hasattr(obj, "_ast"):
        return to_dict(obj._ast())
    elif hasattr(obj, "__iter__") and not isinstance(obj, str):
        # strings are iterable too and must be excluded to avoid infinite recursion
        return [to_dict(v, classkey) for v in obj]
    elif hasattr(obj, "__dict__"):
        data = dict([(key, to_dict(value, classkey))
                     for key, value in obj.__dict__.items()
                     if not callable(value) and not key.startswith('_') and key not in ['name']])
        if classkey is not None and hasattr(obj, "__class__"):
            data[classkey] = obj.__class__.__name__
        return data
    else:
        return obj
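
# Usage sketch (hypothetical config object; attributes named 'name' are skipped by design):
#   >>> class ModelConfig:
#   ...     def __init__(self):
#   ...         self.hidden_size = 32
#   ...         self.dropout = 0.1
#   >>> to_dict(ModelConfig())
#   {'hidden_size': 32, 'dropout': 0.1}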

def dict_deep_update(target, source, is_add_new_key=True):
    """Update the 'target' dict by the 'source' dict deeply, and return a new dict copied deeply from both.

    Args:
        target (dict): the dict to be updated.
        source (dict): the dict whose values take precedence.
        is_add_new_key (bool, optional): whether to add keys that exist in source but not in target.
            Defaults to True.

    Returns:
        dict: a new dict containing the values of both target and source, keeping the values from
            source when a key is duplicated.
    """
    # deep copy to avoid modifying the original dict
    result = copy.deepcopy(target) if target is not None else {}
    if source is None:
        return result

    for key, value in source.items():
        if key not in result:
            if is_add_new_key:
                result[key] = value
            continue

        # both target and source have the same key
        base_type_list = [int, float, str, tuple, bool]
        if type(result[key]) in base_type_list or type(source[key]) in base_type_list:
            result[key] = value
        else:
            result[key] = dict_deep_update(result[key], source[key], is_add_new_key=is_add_new_key)

    return result
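
# Usage sketch (illustrative nested configs; source values win on duplicated keys):
#   >>> base = {'model': {'lr': 0.1, 'layers': 2}}
#   >>> override = {'model': {'lr': 0.01}}
#   >>> dict_deep_update(base, override)
#   {'model': {'lr': 0.01, 'layers': 2}}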