Source code for easy_tpp.config_factory.data_config

from easy_tpp.config_factory.config import Config


[docs]class DataSpecConfig(Config):
[docs] def __init__(self, **kwargs): """Initialize the Config class. """ self.num_event_types = kwargs.get('num_event_types') self.pad_token_id = kwargs.get('pad_token_id') self.padding_side = kwargs.get('padding_side') self.truncation_side = kwargs.get('truncation_side') self.padding_strategy = kwargs.get('padding_strategy') self.max_len = kwargs.get('max_len') self.truncation_strategy = kwargs.get('truncation_strategy') self.num_event_types_pad = self.num_event_types + 1 self.model_input_names = kwargs.get('model_input_names') if self.padding_side is not None and self.padding_side not in ["right", "left"]: raise ValueError( f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" ) if self.truncation_side is not None and self.truncation_side not in ["right", "left"]: raise ValueError( f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" )
[docs] def get_yaml_config(self): """Return the config in dict (yaml compatible) format. Returns: dict: config of the data specs in dict format. """ return { 'num_event_types': self.num_event_types, 'pad_token_id': self.pad_token_id, 'padding_side': self.padding_side, 'truncation_side': self.truncation_side, 'padding_strategy': self.padding_strategy, 'truncation_strategy': self.truncation_strategy, 'max_len': self.max_len }
[docs] @staticmethod def parse_from_yaml_config(yaml_config): """Parse from the yaml to generate the config object. Args: yaml_config (dict): configs from yaml file. Returns: DataSpecConfig: Config class for data specs. """ return DataSpecConfig(**yaml_config)
[docs] def copy(self): """Copy the config. Returns: DataSpecConfig: a copy of current config. """ return DataSpecConfig(num_event_types_pad=self.num_event_types_pad, num_event_types=self.num_event_types, event_pad_index=self.pad_token_id, padding_side=self.padding_side, truncation_side=self.truncation_side, padding_strategy=self.padding_strategy, truncation_strategy=self.truncation_strategy, max_len=self.max_len)
[docs]class DataConfig(Config):
[docs] def __init__(self, train_dir, valid_dir, test_dir, specs=None): """Initialize the DataConfig object. Args: train_dir (str): dir of tran set. valid_dir (str): dir of valid set. test_dir (str): dir of test set. specs (dict, optional): specs of dataset. Defaults to None. """ self.train_dir = train_dir self.valid_dir = valid_dir self.test_dir = test_dir self.data_specs = specs or DataSpecConfig() self.data_format = train_dir.split('.')[-1]
[docs] def get_yaml_config(self): """Return the config in dict (yaml compatible) format. Returns: dict: config of the data in dict format. """ return { 'train_dir': self.train_dir, 'valid_dir': self.valid_dir, 'test_dir': self.test_dir, 'data_format': self.data_format, 'data_specs': self.data_specs.get_yaml_config(), }
[docs] @staticmethod def parse_from_yaml_config(yaml_config): """Parse from the yaml to generate the config object. Args: yaml_config (dict): configs from yaml file. Returns: EasyTPP.DataConfig: Config class for data. """ return DataConfig( train_dir=yaml_config.get('train_dir'), valid_dir=yaml_config.get('valid_dir'), test_dir=yaml_config.get('test_dir'), specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs')) )
[docs] def copy(self): """Copy the config. Returns: EasyTPP.DataConfig: a copy of current config. """ return DataConfig(train_dir=self.train_dir, valid_dir=self.valid_dir, test_dir=self.test_dir, specs=self.data_specs)
[docs] def get_data_dir(self, split): """Get the dir of the source raw data. Args: split (str): dataset split notation, 'train', 'dev' or 'valid', 'test'. Returns: str: dir of the source raw data file. """ split = split.lower() if split == 'train': return self.train_dir elif split in ['dev', 'valid']: return self.valid_dir else: return self.test_dir