import os

from easy_tpp.preprocess.dataset import TPPDataset
from easy_tpp.preprocess.dataset import get_data_loader
from easy_tpp.preprocess.event_tokenizer import EventTokenizer
from easy_tpp.utils import load_pickle, py_assert


class TPPDataLoader:
    """Build backend-specific data loaders for the train / dev / test splits of a TPP dataset."""

    def __init__(self, data_config, backend, **kwargs):
        """Initialize the dataloader.

        Args:
            data_config (EasyTPP.DataConfig): data config. Must expose
                ``data_specs`` (its ``num_event_types`` is cached here and the
                whole object is handed to the tokenizer) and ``get_data_dir(split)``.
            backend (str): backend engine, e.g., tensorflow or torch.
            **kwargs: default loader options. ``batch_size`` and ``shuffle`` are
                read by :meth:`get_loader` unless overridden in that call.
        """
        self.data_config = data_config
        self.num_event_types = data_config.data_specs.num_event_types
        self.backend = backend
        self.kwargs = kwargs

    def get_loader(self, split='train', **kwargs):
        """Get the data loader for one split.

        Args:
            split (str, optional): which split to load, e.g. 'train', 'dev' or
                'test'. Defaults to 'train'.
            **kwargs: extra keyword arguments forwarded to ``get_data_loader``.
                ``batch_size`` / ``shuffle`` given here override the values
                passed to ``__init__``.

        Raises:
            NotImplementedError: the split's data file is not a ``.pkl`` file.
            KeyError: ``batch_size`` or ``shuffle`` was supplied neither here
                nor to ``__init__``.

        Returns:
            EasyTPP.DataLoader: the data loader for tpp data.
        """
        data_dir = self.data_config.get_data_dir(split)
        # splitext looks only at the final path component, so dots in directory
        # names cannot be mistaken for an extension; lower() also accepts '.PKL'.
        data_source_type = os.path.splitext(data_dir)[1].lstrip('.').lower()
        if data_source_type != 'pkl':
            raise NotImplementedError(
                f"Unsupported data source type '{data_source_type}' for split "
                f"'{split}' (path: {data_dir}); only .pkl files are supported.")

        # Per-call options win; otherwise fall back to the constructor defaults.
        # Merging avoids the "multiple values for keyword argument" TypeError the
        # old call raised when batch_size/shuffle were also passed per call.
        loader_kwargs = dict(kwargs)
        for key in ('batch_size', 'shuffle'):
            if key not in loader_kwargs:
                loader_kwargs[key] = self.kwargs[key]

        # NOTE(review): build_input_from_pkl is not defined in the visible part of
        # this class; presumably it is defined elsewhere in the file — confirm.
        data = self.build_input_from_pkl(data_dir, split)
        dataset = TPPDataset(data)
        tokenizer = EventTokenizer(self.data_config.data_specs)
        return get_data_loader(dataset, self.backend, tokenizer, **loader_kwargs)

    def train_loader(self, **kwargs):
        """Return the train loader.

        Args:
            **kwargs: forwarded to :meth:`get_loader`.

        Returns:
            EasyTPP.DataLoader: data loader for the 'train' split.
        """
        return self.get_loader('train', **kwargs)

    def valid_loader(self, **kwargs):
        """Return the validation loader.

        The validation set is stored under the split name 'dev'.

        Args:
            **kwargs: forwarded to :meth:`get_loader`.

        Returns:
            EasyTPP.DataLoader: data loader for the 'dev' split.
        """
        return self.get_loader('dev', **kwargs)

    def test_loader(self, **kwargs):
        """Return the test loader.

        Args:
            **kwargs: forwarded to :meth:`get_loader`.

        Returns:
            EasyTPP.DataLoader: data loader for the 'test' split.
        """
        return self.get_loader('test', **kwargs)