# Source code for easy_tpp.config_factory.hpo_config

from easy_tpp.config_factory.config import Config
from easy_tpp.config_factory.runner_config import RunnerConfig
from easy_tpp.utils import parse_uri_to_protocol_and_path, py_assert


class HPOConfig(Config):
    """Settings for a hyper-parameter optimization (HPO) run.

    Holds the HPO framework id, the storage URI for results, the
    continuation flag, the number of trials and the number of jobs.
    """

    def __init__(self, framework_id, storage_uri, is_continuous, num_trials, num_jobs):
        """Store the HPO settings, filling in defaults for missing values.

        Args:
            framework_id (str): hpo framework id; any falsy value is replaced
                by 'optuna'.
            storage_uri (str): result storage dir; stored unchanged.
            is_continuous (bool): whether to continuously do the optimization;
                None is replaced by True.
            num_trials (int): number of trials used in optimization; any falsy
                value (including 0) is replaced by 50.
            num_jobs (int): number of jobs; None is replaced by 1.
        """
        # These two use a truthiness test, so '' and 0 also trigger the default.
        if not framework_id:
            framework_id = 'optuna'
        if not num_trials:
            num_trials = 50

        # These two only fall back on an explicit None, so False / 0 are kept.
        if is_continuous is None:
            is_continuous = True
        if num_jobs is None:
            num_jobs = 1

        self.framework_id = framework_id
        self.is_continuous = is_continuous
        self.num_trials = num_trials
        self.storage_uri = storage_uri
        self.num_jobs = num_jobs

    @property
    def storage_protocol(self):
        """Get the protocol part of the storage URI.

        Returns:
            str: first element of the pair returned by
            ``parse_uri_to_protocol_and_path(self.storage_uri)``.
        """
        protocol, _unused_path = parse_uri_to_protocol_and_path(self.storage_uri)
        return protocol

    @property
    def storage_path(self):
        """Get the path part of the storage URI.

        Returns:
            str: second element of the pair returned by
            ``parse_uri_to_protocol_and_path(self.storage_uri)``.
        """
        _unused_protocol, path = parse_uri_to_protocol_and_path(self.storage_uri)
        return path

    def get_yaml_config(self):
        """Return the config in dict (yaml compatible) format.

        Returns:
            dict: the five HPO settings keyed by their attribute names.
        """
        yaml_dict = dict(
            framework_id=self.framework_id,
            storage_uri=self.storage_uri,
            is_continuous=self.is_continuous,
            num_trials=self.num_trials,
            num_jobs=self.num_jobs,
        )
        return yaml_dict

    @staticmethod
    def parse_from_yaml_config(yaml_config, **kwargs):
        """Build an HPOConfig from a yaml-derived mapping.

        Args:
            yaml_config (dict): configs from yaml file, or None.
            **kwargs: accepted but not used by this method.

        Returns:
            EasyTPP.HPOConfig: config built from the mapping, or None when
            ``yaml_config`` is None. Keys absent from the mapping are read as
            None, so the constructor defaults apply.
        """
        if yaml_config is None:
            return None

        field_names = ('framework_id', 'storage_uri', 'is_continuous', 'num_trials', 'num_jobs')
        init_kwargs = {name: yaml_config.get(name) for name in field_names}
        return HPOConfig(**init_kwargs)

    def copy(self):
        """Copy the config.

        Returns:
            EasyTPP.HPOConfig: a new instance built from the current
            attribute values.
        """
        duplicate = HPOConfig(
            self.framework_id,
            self.storage_uri,
            self.is_continuous,
            self.num_trials,
            self.num_jobs,
        )
        return duplicate
@Config.register('hpo_runner_config')
class HPORunnerConfig(Config):
    """Pairs an HPO config with the runner config it tunes."""

    def __init__(self, hpo_config, runner_config):
        """Initialize the config class.

        Args:
            hpo_config (EasyTPP.HPOConfig): hpo config class.
            runner_config (EasyTPP.RunnerConfig): runner config class.
        """
        self.runner_config = runner_config
        self.hpo_config = hpo_config

    @staticmethod
    def parse_from_yaml_config(yaml_config, **kwargs):
        """Parse from the yaml to generate the config object.

        The runner config is parsed from the whole mapping and the HPO config
        from its 'hpo' entry; ``kwargs`` is forwarded to both parsers.

        Args:
            yaml_config (dict): configs from yaml file.

        Returns:
            EasyTPP.HPORunnerConfig: Config class for HPO specs.
        """
        # Order matters: the runner section is parsed before the hpo section.
        parsed_runner = RunnerConfig.parse_from_yaml_config(yaml_config, **kwargs)

        hpo_section = yaml_config.get('hpo')
        parsed_hpo = HPOConfig.parse_from_yaml_config(hpo_section, **kwargs)

        # HPOConfig.parse_from_yaml_config returns None when there is no 'hpo'
        # section. py_assert is handed ValueError and the message below;
        # presumably it raises that error when the condition is false -- verify
        # against easy_tpp.utils.
        has_hpo = parsed_hpo is not None
        py_assert(has_hpo, ValueError, 'No hpo configs is provided for HyperTuner')

        return HPORunnerConfig(parsed_hpo, parsed_runner)

    def copy(self):
        """Copy the config.

        The nested ``hpo_config`` and ``runner_config`` objects are passed to
        the new instance as-is (same objects), not copied themselves.

        Returns:
            EasyTPP.HPORunnerConfig: a copy of current config.
        """
        shared_hpo = self.hpo_config
        shared_runner = self.runner_config
        return HPORunnerConfig(shared_hpo, shared_runner)