# Source code for easy_tpp.model.torch_model.torch_rmtpp

import math

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class RMTPP(TorchBaseModel):
    """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.

    https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf

    The model embeds event types and event times, runs them through a
    single-layer RNN and maps the RNN state to a per-event-type intensity that
    changes exponentially with the time elapsed since the last event.

    Attributes such as ``hidden_size``, ``num_event_types``, ``device`` and
    ``layer_type_emb`` are read here but not defined in this class; they are
    presumably set up by ``TorchBaseModel.__init__`` -- verify against the
    base class.
    """

    # Upper bound on every intensity value, to avoid explosion during HPO.
    _MAX_INTENSITY = 1e5
    # The same bound in log space. Clamping the exponent *before* torch.exp
    # keeps the forward value finite; clamping only after exp lets exp overflow
    # to inf, and its backward pass then yields 0 * inf = NaN gradients.
    _MAX_LOG_INTENSITY = math.log(_MAX_INTENSITY)

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super().__init__(model_config)

        self.layer_temporal_emb = nn.Linear(1, self.hidden_size)
        self.layer_rnn = nn.RNN(input_size=self.hidden_size,
                                hidden_size=self.hidden_size,
                                num_layers=1,
                                batch_first=True)
        self.layer_hidden = nn.Linear(self.hidden_size, self.num_event_types)

        # Bias and time-slope terms of the log-intensity. They are wrapped in
        # nn.Parameter so that they are optimised, included in state_dict and
        # moved together with the module; as plain tensors they would stay
        # frozen at their random initial values.
        # Checkpoints written before these became parameters do not contain
        # the two entries; load such checkpoints with strict=False.
        self.factor_intensity_base = nn.Parameter(
            torch.empty([1, 1, self.num_event_types], device=self.device))
        self.factor_intensity_current_influence = nn.Parameter(
            torch.empty([1, 1, self.num_event_types], device=self.device))
        nn.init.xavier_normal_(self.factor_intensity_base)
        nn.init.xavier_normal_(self.factor_intensity_current_influence)

    def state_decay(self, states_to_decay, duration_t):
        """Equation (11), which computes the intensity.

        Computes ``exp(layer_hidden(h) + w * dt + b)`` capped at
        ``_MAX_INTENSITY``, where ``w`` is
        ``factor_intensity_current_influence`` and ``b`` is
        ``factor_intensity_base``.

        Args:
            states_to_decay (tensor): RNN states whose last dim is hidden_size.
            duration_t (tensor): elapsed time since the last event; must be
                broadcastable against ``layer_hidden(states_to_decay)``
                (callers in this class pass a trailing dim of size 1).

        Returns:
            tensor: intensities, last dim is num_event_types.
        """
        # [..., num_event_types]
        log_intensity = (self.layer_hidden(states_to_decay)
                         + self.factor_intensity_current_influence * duration_t
                         + self.factor_intensity_base)

        # Bound the exponent first so exp never overflows (see the class
        # constants); the final clamp keeps the exact upper bound of 1e5.
        log_intensity = log_intensity.clamp(max=self._MAX_LOG_INTENSITY)
        return torch.exp(log_intensity).clamp(max=self._MAX_INTENSITY)

    def forward(self, time_seqs, time_delta_seqs, type_seqs, **kwargs):
        """Call the model.

        Args:
            time_seqs (tensor): [batch_size, seq_len], event times.
            time_delta_seqs (tensor): [batch_size, seq_len], time elapsed until
                the next event; used as the decay duration.
            type_seqs (tensor): [batch_size, seq_len], event types.
            **kwargs: optional ``max_steps`` (int) overriding the sequence
                length used for the single-event check below.

        Returns:
            tuple: ``(h_t, decay_states)``. ``decay_states`` is the RNN output,
            [batch_size, seq_len, hidden_size], i.e. the states right after
            each event. ``h_t`` is ``state_decay`` applied to those states with
            ``time_delta_seqs``, i.e. the intensity right before the next
            event, [batch_size, seq_len, num_event_types].
        """
        max_steps = kwargs.get('max_steps', None)

        # last event has no time label
        max_seq_length = max_steps if max_steps is not None else type_seqs.size(1) - 1

        # [batch_size, seq_len, hidden_size]
        type_emb = self.layer_type_emb(type_seqs)

        # [batch_size, seq_len, hidden_size]
        temporal_emb = self.layer_temporal_emb(time_seqs[..., None])

        # [batch_size, seq_len, hidden_size]
        # states right after the event
        decay_states, _ = self.layer_rnn(type_emb + temporal_emb)

        if max_seq_length == 1:
            # if only one event, then we don't decay
            # NOTE(review): this branch returns the raw RNN states (last dim
            # hidden_size) where every other path returns an intensity (last
            # dim num_event_types). Behaviour is kept as-is; confirm callers
            # really expect this.
            h_t = decay_states
        else:
            # States decay - Equation (7) in the paper
            # states before the happening of the next event
            h_t = self.state_decay(decay_states, time_delta_seqs[..., None])

        return h_t, decay_states

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (list): batch input, unpacked as ``(time_seqs,
                time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask)``.

        Returns:
            tuple: loglikelihood loss and num of events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        # Inputs are events 0..N-1; the decay duration for each of them is the
        # time delta to the following event, hence the shifted slices.
        lambda_at_event, decay_states = self.forward(time_seqs[:, :-1],
                                                     time_delta_seqs[:, 1:],
                                                     type_seqs[:, :-1])

        # Compute the big lambda integral in equation (8)
        # 1 - take num_mc_sample rand points in each event interval
        # 2 - compute its lambda value for every sample point
        # 3 - take average of these sample points
        # 4 - times the interval length

        # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample]
        # for every batch and every event point => do a sampling (num_mc_sampling)
        # the first dtime is zero, so we use time_delta_seqs[:, 1:]
        interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # [batch_size, num_times = max_len - 1, num_sample, event_num]
        lambda_t_sample = self.compute_states_at_sample_times(decay_states, interval_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=lambda_at_event,
            lambdas_loss_samples=lambda_t_sample,
            time_delta_seq=time_delta_seqs[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            lambda_type_mask=type_mask[:, 1:])

        # negative log-likelihood summed over all entries
        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, decay_states, sample_dtimes):
        """Compute the intensities at sampled times.

        Despite the method name, the value returned is the output of
        ``state_decay``, i.e. an intensity per event type, not a hidden state.

        Args:
            decay_states (tensor): [batch_size, seq_len, hidden_size].
            sample_dtimes (tensor): [batch_size, seq_len, num_samples].

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types],
            intensity at each sampled time.
        """
        # Use broadcasting to evaluate every sample of every interval at once:
        # decay_states[..., None, :]: [batch_size, seq_len, 1, hidden_size]
        # sample_dtimes[..., None]:   [batch_size, seq_len, num_mc_sample, 1]
        return self.state_decay(decay_states[..., None, :], sample_dtimes[..., None])

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs,
                                            sample_times, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_times (tensor): [batch_size, seq_len, num_sample], sampled
                inter-event timestamps.
            **kwargs: ``compute_last_step_only`` (bool, default False) restricts
                the output to the last position; all kwargs are also forwarded
                to ``forward``.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each timestamp for each event type. ``num_times`` is 1
            when ``compute_last_step_only`` is set.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        # Only the post-event RNN states are needed here.
        _, decay_states = self.forward(time_seqs, time_delta_seqs, type_seqs, **kwargs)

        if compute_last_step_only:
            # Keep the time axis (size 1) so the result stays 4-D.
            # [batch_size, 1, num_mc_sample, 1]
            interval_t_sample = sample_times[:, -1:, :, None]
            # [batch_size, 1, 1, hidden_size]
            states = decay_states[:, -1:, None, :]
        else:
            # [batch_size, num_times, num_mc_sample, 1]
            interval_t_sample = sample_times[..., None]
            # [batch_size, num_times, 1, hidden_size]
            states = decay_states[..., None, :]

        # Broadcasting evaluates the intensity at all sample points at once.
        return self.state_decay(states, interval_t_sample)