Source code for easy_tpp.model.torch_model.torch_fullynn

import torch
from torch import nn
from torch.autograd import grad

from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class CumulHazardFunctionNetwork(nn.Module):
    """Cumulative Hazard Function Network.

    ref: https://github.com/wassname/torch-neuralpointprocess
    """

    def __init__(self, model_config):
        super(CumulHazardFunctionNetwork, self).__init__()
        self.hidden_size = model_config.hidden_size
        self.num_mlp_layers = model_config.model_specs['num_mlp_layers']
        self.num_event_types = model_config.num_event_types

        # transform inter-event time embedding
        self.layer_dense_1 = nn.Linear(in_features=1, out_features=self.hidden_size)
        # concat rnn states and inter-event time embedding
        self.layer_dense_2 = nn.Linear(in_features=self.hidden_size * 2, out_features=self.hidden_size)
        # mlp layers
        self.module_list = nn.ModuleList(
            [nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size)
             for _ in range(self.num_mlp_layers - 1)])
        self.layer_dense_3 = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.num_event_types),
            nn.Softplus())

        self.params_eps = torch.finfo(torch.float32).eps

        # ensure positiveness of parameters
        self.init_weights_positive()

    def init_weights_positive(self):
        for p in self.parameters():
            p.data = torch.abs(p.data)
            p.data = torch.clamp(p.data, min=self.params_eps)

    def forward(self, hidden_states, time_delta_seqs):
        for p in self.parameters():
            p.data = torch.clamp(p.data, min=self.params_eps)

        time_delta_seqs.requires_grad_(True)

        # [batch_size, seq_len, hidden_size]
        t = self.layer_dense_1(time_delta_seqs.unsqueeze(dim=-1))

        # [batch_size, seq_len, hidden_size]
        out = torch.tanh(self.layer_dense_2(torch.cat([hidden_states, t], dim=-1)))
        for layer in self.module_list:
            out = torch.tanh(layer(out))

        # [batch_size, seq_len, num_event_types]
        integral_lambda = self.layer_dense_3(out)

        # [batch_size, seq_len]
        derivative_integral_lambda = grad(
            integral_lambda.sum(dim=-1).mean(),
            time_delta_seqs,
            create_graph=True,
            retain_graph=True)[0]

        return integral_lambda, derivative_integral_lambda
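
# Illustrative shape-check sketch for CumulHazardFunctionNetwork (hypothetical
# config values; `SimpleNamespace` stands in for the EasyTPP.ModelConfig
# fields read above):
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(hidden_size=16, num_event_types=3,
#                              model_specs={'num_mlp_layers': 2})
#     net = CumulHazardFunctionNetwork(config)
#     hidden = torch.randn(4, 10, 16)   # [batch_size, seq_len, hidden_size]
#     dtimes = torch.rand(4, 10)        # [batch_size, seq_len]
#     integral, intensity = net(hidden, dtimes)
#     assert integral.shape == (4, 10, 3)  # cumulative hazard per event type
#     assert intensity.shape == (4, 10)    # d(hazard)/dt via autograd
#
# Note that `forward` differentiates the *mean* of the summed hazards, so the
# returned derivative equals the per-event derivative scaled by
# 1 / (batch_size * seq_len).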

class FullyNN(TorchBaseModel):
    """Torch implementation of Fully Neural Network based Model for General Temporal Point Processes,
    NeurIPS 2019. https://arxiv.org/abs/1905.09690

    ref: https://github.com/KanghoonYoon/torch-neuralpointprocess/blob/master/module.py;
    https://github.com/wassname/torch-neuralpointprocess
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(FullyNN, self).__init__(model_config)

        self.rnn_type = model_config.rnn_type
        self.rnn_list = [nn.LSTM, nn.RNN, nn.GRU]
        for sub_rnn_class in self.rnn_list:
            if sub_rnn_class.__name__ == self.rnn_type:
                self.layer_rnn = sub_rnn_class(input_size=1 + self.hidden_size,
                                               hidden_size=self.hidden_size,
                                               num_layers=1,
                                               batch_first=True,
                                               dropout=0.1)

        self.layer_intensity = CumulHazardFunctionNetwork(model_config)
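
    # Note: with num_layers=1, PyTorch ignores the `dropout` argument of
    # nn.LSTM/nn.GRU/nn.RNN (dropout is only applied between stacked recurrent
    # layers) and emits a UserWarning when constructing the module above.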

    def forward(self, time_seqs, time_delta_seqs, type_seqs):
        """Call the model.

        Args:
            time_seqs (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.

        Returns:
            tuple: cumulative hazard [batch_size, seq_len, num_event_types] and its
            derivative w.r.t. the inter-event times [batch_size, seq_len].
        """
        # [batch_size, seq_len, hidden_size]
        type_embedding = self.layer_type_emb(type_seqs)

        # [batch_size, seq_len, hidden_size + 1]
        rnn_input = torch.cat((type_embedding, time_delta_seqs.unsqueeze(-1)), dim=-1)

        # [batch_size, seq_len, hidden_size]
        # states right after the event
        hidden_states, _ = self.layer_rnn(rnn_input)

        # [batch_size, seq_len, num_event_types], [batch_size, seq_len]
        integral_lambda, derivative_integral_lambda = self.layer_intensity(hidden_states, time_delta_seqs)

        return integral_lambda, derivative_integral_lambda

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (tuple, list): batch input.

        Returns:
            tuple: loglike loss, num events.
        """
        # [batch_size, seq_len]
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        # [batch_size, seq_len, num_event_types], [batch_size, seq_len]
        integral_lambda, derivative_integral_lambda = self.forward(time_seqs[:, :-1],
                                                                   time_delta_seqs[:, 1:],
                                                                   type_seqs[:, :-1])

        # [batch_size, seq_len]
        event_ll = (derivative_integral_lambda + self.eps).log() * batch_non_pad_mask[:, 1:]

        # [batch_size, seq_len]
        # multiplied by sequence mask
        non_event_ll = integral_lambda.sum(-1) * batch_non_pad_mask[:, 1:]

        num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0]

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events
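
    # Note: `loglike_loss` is the negative log-likelihood of a temporal point
    # process,
    #     -log L = -sum_i log lambda(t_i) + integral_0^T lambda(s) ds,
    # where FullyNN models the compensator Lambda (cumulative hazard) directly:
    # the event term uses the autograd derivative dLambda/dt and the non-event
    # term is Lambda itself, so no Monte Carlo integration of the intensity is
    # needed, unlike intensity-based models.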

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute intensities at sampled times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event times.

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)
        num_samples = sample_dtimes.size()[-1]

        # [batch_size, seq_len, hidden_size, num_samples]
        type_emb = torch.tile(self.layer_type_emb(type_seqs)[..., None], (1, 1, 1, num_samples))

        # [batch_size, seq_len, hidden_size + 1, num_samples]
        rnn_input = torch.cat([type_emb, sample_dtimes.unsqueeze(-2)], dim=-2)

        # [batch_size, num_samples, seq_len, hidden_size + 1]
        rnn_input = torch.permute(rnn_input, (0, 3, 1, 2))

        # [batch_size * num_samples, seq_len, hidden_size + 1]
        rnn_input = torch.reshape(rnn_input, (-1, rnn_input.size(-2), rnn_input.size(-1)))

        # [batch_size * num_samples, seq_len, hidden_size]
        # states right after the event
        hidden_states, _ = self.layer_rnn(rnn_input)

        # [batch_size, num_samples, seq_len, hidden_size]
        hidden_states = torch.reshape(hidden_states,
                                      (-1, num_samples, hidden_states.size(-2), hidden_states.size(-1)))

        # [batch_size, seq_len, num_samples, hidden_size]
        hidden_states = torch.transpose(hidden_states, -2, -3)

        # [batch_size, seq_len, num_samples]
        _, derivative_integral_lambda = self.layer_intensity(hidden_states, sample_dtimes)

        # FIX: need to fix this later
        # current version does not truly support multi-event types
        # [batch_size, seq_len, num_samples, num_event_types]
        derivative_integral_lambda = torch.tile(derivative_integral_lambda[..., None],
                                                (1, 1, 1, self.num_event_types))

        if compute_last_step_only:
            lambdas = derivative_integral_lambda[:, -1:, :, :]
        else:
            # [batch_size, seq_len, num_samples, num_event_types]
            lambdas = derivative_integral_lambda
        return lambdas
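
# Illustrative training-step sketch (hypothetical: `model_config` and
# `train_loader` are assumed to come from EasyTPP's config and data pipelines,
# omitted here). `batch` follows the tuple layout unpacked in `loglike_loss`:
#
#     model = FullyNN(model_config)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     for batch in train_loader:
#         loss, num_events = model.loglike_loss(batch)
#         optimizer.zero_grad()
#         (loss / num_events).backward()
#         optimizer.step()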