Source code for easy_tpp.config_factory.runner_config

import copy
import os

from easy_tpp.config_factory.config import Config
from easy_tpp.config_factory.data_config import DataConfig
from easy_tpp.config_factory.model_config import TrainerConfig, ModelConfig, BaseConfig
from easy_tpp.utils import create_folder, logger, get_unique_id, get_stage, RunnerPhase, \
    MetricsHelper, DefaultRunnerConfig, py_assert, is_torch_available, is_tf_available, is_tf_gpu_available, \
    is_torch_gpu_available
from easy_tpp.utils.const import Backend


@Config.register('runner_config')
class RunnerConfig(Config):
    def __init__(self, base_config, model_config, data_config, trainer_config):
        """Initialize the Config class.

        Args:
            base_config (EasyTPP.BaseConfig): BaseConfig object.
            model_config (EasyTPP.ModelConfig): ModelConfig object.
            data_config (EasyTPP.DataConfig): DataConfig object.
            trainer_config (EasyTPP.TrainerConfig): TrainerConfig object.
        """
        self.data_config = data_config
        self.model_config = model_config
        self.base_config = base_config
        self.trainer_config = trainer_config

        self.ensure_valid_config()
        self.update_config()

        # save the complete config
        save_dir = self.base_config.specs['output_config_dir']
        self.save_to_yaml_file(save_dir)
        logger.info(f'Save the config to {save_dir}')
    def get_yaml_config(self):
        """Return the config in dict (yaml-compatible) format.

        Returns:
            dict: config of the runner config in dict format.
        """
        return {
            'data_config': self.data_config.get_yaml_config(),
            'base_config': self.base_config.get_yaml_config(),
            'model_config': self.model_config.get_yaml_config(),
            'trainer_config': self.trainer_config.get_yaml_config()
        }
    @staticmethod
    def parse_from_yaml_config(yaml_config, **kwargs):
        """Parse the yaml config to generate the config object.

        Args:
            yaml_config (dict): configs from the yaml file.

        Returns:
            RunnerConfig: Config class for trainer specs.
        """
        direct_parse = kwargs.get('direct_parse', False)
        if not direct_parse:
            exp_id = kwargs.get('experiment_id')
            yaml_exp_config = yaml_config[exp_id]
            dataset_id = yaml_exp_config.get('base_config').get('dataset_id')
            if dataset_id is None:
                dataset_id = DefaultRunnerConfig.DEFAULT_DATASET_ID
            try:
                yaml_data_config = yaml_config['data'][dataset_id]
            except KeyError:
                raise RuntimeError('dataset_id={} is not found in config.'.format(dataset_id))
            data_config = DataConfig.parse_from_yaml_config(yaml_data_config)
            # add the exp id to the base config
            yaml_exp_config.get('base_config').update(exp_id=exp_id)
        else:
            yaml_exp_config = yaml_config
            data_config = DataConfig.parse_from_yaml_config(yaml_exp_config.get('data_config'))

        base_config = BaseConfig.parse_from_yaml_config(yaml_exp_config.get('base_config'))
        model_config = ModelConfig.parse_from_yaml_config(yaml_exp_config.get('model_config'))
        trainer_config = TrainerConfig.parse_from_yaml_config(yaml_exp_config.get('trainer_config'))
        return RunnerConfig(
            data_config=data_config,
            base_config=base_config,
            model_config=model_config,
            trainer_config=trainer_config
        )
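    # A minimal usage sketch for the default (non-direct) parsing path above;
    # not part of the original class. The yaml layout is inferred from the
    # lookups in parse_from_yaml_config, and the file name, dataset id ('taxi')
    # and experiment id ('NHP_train') are hypothetical:
    #
    #   import yaml
    #
    #   with open('configs/experiment_config.yaml') as f:
    #       yaml_config = yaml.safe_load(f)
    #   # expected shape:
    #   #   data:
    #   #     taxi: {...}            # looked up via base_config.dataset_id
    #   #   NHP_train:
    #   #     base_config: {...}
    #   #     model_config: {...}
    #   #     trainer_config: {...}
    #   runner_config = RunnerConfig.parse_from_yaml_config(yaml_config,
    #                                                       experiment_id='NHP_train')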
    def ensure_valid_config(self):
        """Do some sanity checks on the config to avoid conflicts in the settings."""
        # during testing we don't shuffle by default
        self.trainer_config.shuffle = False
        # during testing we don't apply tfb (tensorboard) by default
        self.trainer_config.use_tfb = False
        return
    def update_config(self):
        """Update the config dict."""
        model_folder_name = get_unique_id()
        log_folder = create_folder(self.base_config.base_dir, model_folder_name)
        model_folder = create_folder(log_folder, 'models')

        self.base_config.specs['log_folder'] = log_folder
        self.base_config.specs['saved_model_dir'] = os.path.join(model_folder, 'saved_model')
        self.base_config.specs['saved_log_dir'] = os.path.join(log_folder, 'log')
        self.base_config.specs['output_config_dir'] = os.path.join(
            log_folder, f'{self.base_config.exp_id}_output.yaml')

        if self.trainer_config.use_tfb:
            self.base_config.specs['tfb_train_dir'] = create_folder(log_folder, 'tfb_train')
            self.base_config.specs['tfb_valid_dir'] = create_folder(log_folder, 'tfb_valid')

        current_stage = get_stage(self.base_config.stage)
        is_training = current_stage == RunnerPhase.TRAIN
        self.model_config.is_training = is_training
        self.model_config.gpu = self.trainer_config.gpu

        # propagate the dataset config => model config
        self.model_config.num_event_types_pad = self.data_config.data_specs.num_event_types_pad
        self.model_config.num_event_types = self.data_config.data_specs.num_event_types
        self.model_config.pad_token_id = self.data_config.data_specs.pad_token_id
        self.model_config.max_len = self.data_config.data_specs.max_len

        # propagate the base config => model config
        model_id = self.base_config.model_id
        self.model_config.model_id = model_id

        if self.base_config.model_id == 'ODETPP' and self.base_config.backend == Backend.TF:
            py_assert(self.data_config.data_specs.padding_strategy == 'max_length',
                      ValueError,
                      'For ODETPP in TensorFlow, we must pad all sequences to '
                      'the same length (max len of the sequences)!')

        run = current_stage
        use_torch = self.base_config.backend == Backend.Torch
        device = 'GPU' if self.trainer_config.gpu >= 0 else 'CPU'

        py_assert(is_torch_available() if use_torch else is_tf_available(),
                  ValueError,
                  f'Backend {self.base_config.backend} is not supported in the current environment yet!')

        if use_torch and device == 'GPU':
            py_assert(is_torch_gpu_available(),
                      ValueError,
                      'Torch CUDA is not supported in the current environment yet!')

        if not use_torch and device == 'GPU':
            py_assert(is_tf_gpu_available(),
                      ValueError,
                      'TensorFlow GPU is not supported in the current environment yet!')

        critical_msg = '{run} model {model_name} using {device} ' \
                       'with {tf_torch} backend'.format(run=run,
                                                        model_name=model_id,
                                                        device=device,
                                                        tf_torch=self.base_config.backend)
        logger.critical(critical_msg)
        return
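    # For reference, the folder layout produced by update_config above is roughly:
    #
    #   <base_config.base_dir>/<unique_id>/    -> specs['log_folder']
    #       models/saved_model                 -> specs['saved_model_dir']
    #       log                                -> specs['saved_log_dir']
    #       <exp_id>_output.yaml               -> specs['output_config_dir']
    #       tfb_train/, tfb_valid/             -> only when trainer_config.use_tfb is set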
    def get_metric_functions(self):
        return MetricsHelper.get_metrics_callback_from_names(self.trainer_config.metrics)
    def get_metric_direction(self, metric_name='rmse'):
        return MetricsHelper.get_metric_direction(metric_name)
    def copy(self):
        """Copy the config.

        Returns:
            RunnerConfig: a copy of the current config.
        """
        return RunnerConfig(
            base_config=copy.deepcopy(self.base_config),
            model_config=copy.deepcopy(self.model_config),
            data_config=copy.deepcopy(self.data_config),
            trainer_config=copy.deepcopy(self.trainer_config)
        )
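
# A hedged end-to-end sketch of the direct-parse path plus the metric helpers;
# not part of the original module, and the dict contents below are placeholders,
# not a tested configuration:
#
#   yaml_exp_config = {
#       'data_config': {...},
#       'base_config': {...},
#       'model_config': {...},
#       'trainer_config': {...},
#   }
#   runner_config = RunnerConfig.parse_from_yaml_config(yaml_exp_config,
#                                                       direct_parse=True)
#   metric_fns = runner_config.get_metric_functions()
#   direction = runner_config.get_metric_direction('rmse')
#   config_copy = runner_config.copy()  # deep copies of all four sub-configs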