Source code for easy_tpp.runner.base_runner

import logging
from abc import abstractmethod

from easy_tpp.preprocess import TPPDataLoader
from easy_tpp.utils import Registrable, Timer, logger, get_unique_id, LogConst, get_stage, RunnerPhase


class Runner(Registrable):
    """Registrable base runner class.

    Holds the runner config, optionally builds a ``TPPDataLoader`` from it,
    and exposes the public ``train`` / ``evaluate`` / ``gen`` / ``save`` /
    ``run`` entry points. The actual work is delegated to the ``_train_model``,
    ``_evaluate_model``, ``_gen_model``, ``_save_model`` and ``_load_model``
    hooks, which concrete runners are expected to override.
    """

    def __init__(self, runner_config, unique_model_dir=False, **kwargs):
        """Initialize the base runner.

        Args:
            runner_config (RunnerConfig): config for the runner.
            unique_model_dir (bool, optional): whether to give unique dir to
                save the model. Defaults to False.
            **kwargs: optional params. Only ``skip_data_loader`` (bool,
                default False) is read here; when true, no data loader is
                built and ``self._data_loader`` is left unset, so loaders must
                be passed explicitly to ``train`` / ``evaluate`` / ``gen``.
        """
        self.runner_config = runner_config

        # re-assign the model_dir
        if unique_model_dir:
            # NOTE(review): this writes ``runner_config.model_dir`` while
            # get_model_dir()/set_model_dir() use
            # ``base_config.specs['saved_model_dir']`` -- confirm which of the
            # two the concrete runners actually read when saving.
            runner_config.model_dir = (
                runner_config.base_config.specs['saved_model_dir'] + '_' + get_unique_id()
            )

        self.save_log()

        skip_data_loader = kwargs.get('skip_data_loader', False)
        if not skip_data_loader:
            # build data reader
            data_config = self.runner_config.data_config
            backend = self.runner_config.base_config.backend
            # Kept under its own name so the constructor's **kwargs is not
            # shadowed by the trainer settings forwarded to the data loader.
            loader_kwargs = self.runner_config.trainer_config.get_yaml_config()
            self._data_loader = TPPDataLoader(
                data_config=data_config,
                backend=backend,
                **loader_kwargs
            )

        # train/evaluate/gen use the timer even when loaders are supplied by
        # the caller, so it is created regardless of skip_data_loader.
        self.timer = Timer()

    @staticmethod
    def build_from_config(runner_config, unique_model_dir=False, **kwargs):
        """Build up the runner from runner config.

        Args:
            runner_config (RunnerConfig): config for the runner.
            unique_model_dir (bool, optional): whether to give unique dir to
                save the model. Defaults to False.
            **kwargs: forwarded unchanged to the runner constructor.

        Returns:
            Runner: an instance of the runner class registered under
            ``runner_config.base_config.runner_id``.
        """
        runner_cls = Runner.by_name(runner_config.base_config.runner_id)
        return runner_cls(runner_config, unique_model_dir=unique_model_dir, **kwargs)

    def get_config(self):
        """Return the runner config this runner was constructed with."""
        return self.runner_config

    def set_model_dir(self, model_dir):
        """Store ``model_dir`` under ``base_config.specs['saved_model_dir']``."""
        self.runner_config.base_config.specs['saved_model_dir'] = model_dir

    def get_model_dir(self):
        """Return the value of ``base_config.specs['saved_model_dir']``."""
        return self.runner_config.base_config.specs['saved_model_dir']

    def train(
            self,
            train_loader=None,
            valid_loader=None,
            test_loader=None,
            **kwargs
    ):
        """Train the model.

        Args:
            train_loader (EasyTPP.DataLoader, optional): data loader for train
                set. Defaults to None.
            valid_loader (EasyTPP.DataLoader, optional): data loader for valid
                set. Defaults to None.
            test_loader (EasyTPP.DataLoader, optional): data loader for test
                set. Defaults to None.
            **kwargs: forwarded unchanged to ``_train_model``.

        Returns:
            Whatever ``_train_model`` returns.
        """
        # no train and valid loader from outside; the internal loaders are
        # used only when BOTH are missing -- a single supplied loader is
        # passed through with the other left as None.
        if train_loader is None and valid_loader is None:
            train_loader = self._data_loader.train_loader()
            valid_loader = self._data_loader.valid_loader()

        # no test loader from outside and there indeed exists test data in config
        if test_loader is None and self.runner_config.data_config.test_dir is not None:
            test_loader = self._data_loader.test_loader()

        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_id = self.runner_config.base_config.model_id
        logger.info(f'Start {model_id} training...')

        model = self._train_model(
            train_loader,
            valid_loader,
            test_loader=test_loader,
            **kwargs
        )

        logger.info(f'End {model_id} train! Cost time: {timer.end()}')
        return model

    def evaluate(self, valid_loader=None, **kwargs):
        """Evaluate the model.

        Args:
            valid_loader (EasyTPP.DataLoader, optional): data loader to
                evaluate on. Defaults to None, in which case the internal
                data loader's validation loader is used.
            **kwargs: forwarded unchanged to ``_evaluate_model``.

        Returns:
            The ``'rmse'`` entry of the metrics returned by
            ``_evaluate_model``.

        Raises:
            KeyError: if the returned metrics have no ``'rmse'`` entry.
        """
        if valid_loader is None:
            valid_loader = self._data_loader.valid_loader()

        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_id = self.runner_config.base_config.model_id
        logger.info(f'Start {model_id} evaluation...')

        metric = self._evaluate_model(
            valid_loader,
            **kwargs
        )

        logger.info(f'End {model_id} evaluation! Cost time: {timer.end()}')
        # only the rmse entry is returned, for HPO to use
        return metric['rmse']

    def gen(self, gen_loader=None, **kwargs):
        """Run generation with the model.

        Args:
            gen_loader (EasyTPP.DataLoader, optional): data loader to generate
                from. Defaults to None, in which case the internal data
                loader's test loader is used.
            **kwargs: forwarded unchanged to ``_gen_model``.

        Returns:
            Whatever ``_gen_model`` returns.
        """
        if gen_loader is None:
            gen_loader = self._data_loader.test_loader()

        # dataset_id / model_id are read through base_config, exactly as in
        # train() and evaluate().
        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_name = self.runner_config.base_config.model_id
        logger.info(f'Start {model_name} generation...')

        model = self._gen_model(
            gen_loader,
            **kwargs
        )

        logger.info(f'End {model_name} generation! Cost time: {timer.end()}')
        return model

    # NOTE(review): @abstractmethod is enforced at instantiation time only if
    # the metaclass of Registrable derives from ABCMeta, which is not visible
    # from this file. The hook bodies are empty, so a hook that is not
    # overridden returns None.

    @abstractmethod
    def _train_model(self, train_loader, valid_loader, **kwargs):
        """Hook called by ``train``; receives ``test_loader`` via kwargs."""

    @abstractmethod
    def _evaluate_model(self, data_loader, **kwargs):
        """Hook called by ``evaluate``; its result is indexed with 'rmse'."""

    @abstractmethod
    def _gen_model(self, data_loader, **kwargs):
        """Hook called by ``gen``."""

    @abstractmethod
    def _save_model(self, model_dir, **kwargs):
        """Hook called by ``save``."""

    @abstractmethod
    def _load_model(self, model_dir, **kwargs):
        """Hook for loading a model from ``model_dir``."""

    def save_log(self):
        """Save log to local files.

        Attaches a ``logging.FileHandler`` writing to the path stored in
        ``base_config.specs['saved_log_dir']`` to the shared ``logger``.
        """
        # Despite the key name, the value is handed to FileHandler as the
        # log file name.
        log_dir = self.runner_config.base_config.specs['saved_log_dir']
        # NOTE(review): a new handler is added on every call (i.e. once per
        # Runner constructed); confirm that is intended when several runners
        # share one log path.
        fh = logging.FileHandler(log_dir)
        fh.setFormatter(logging.Formatter(LogConst.DEFAULT_FORMAT_LONG))
        logger.addHandler(fh)
        logger.info(f'Save the log to {log_dir}')

    def save(
            self,
            model_dir=None,
            **kwargs
    ):
        """Save the model by delegating to ``_save_model``.

        Args:
            model_dir (optional): passed through as-is. Defaults to None.
            **kwargs: forwarded unchanged to ``_save_model``.

        Returns:
            Whatever ``_save_model`` returns.
        """
        return self._save_model(model_dir, **kwargs)

    def run(self, **kwargs):
        """Start the runner.

        Dispatches on the stage configured in ``base_config.stage``: the train
        phase calls ``train``, the validate phase calls ``evaluate``, and any
        other stage falls through to ``gen``.

        Args:
            **kwargs (dict): optional params.

        Returns:
            EasyTPP.BaseModel, dict: the results of the process.
        """
        current_stage = get_stage(self.runner_config.base_config.stage)
        if current_stage == RunnerPhase.TRAIN:
            return self.train(**kwargs)
        elif current_stage == RunnerPhase.VALIDATE:
            return self.evaluate(**kwargs)
        else:
            return self.gen(**kwargs)