import math

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
class RMTPP(TorchBaseModel):
    """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.

    https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf

    The intensity after an event is computed (see ``state_decay``) as
    ``exp(layer_hidden(h) + factor_intensity_current_influence * dt
    + factor_intensity_base)``. Here ``h`` is the RNN state right after the event
    and ``dt`` is the time elapsed since it.
    """

    # Upper bound applied to every intensity value so it cannot explode
    # (the original code cites hyper-parameter optimisation runs).
    _max_intensity = 1e5

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super().__init__(model_config)

        # Embeds the scalar timestamp of each event into hidden_size dims.
        self.layer_temporal_emb = nn.Linear(1, self.hidden_size)
        self.layer_rnn = nn.RNN(input_size=self.hidden_size, hidden_size=self.hidden_size,
                                num_layers=1, batch_first=True)
        # Projects RNN states to one intensity logit per event type.
        self.layer_hidden = nn.Linear(self.hidden_size, self.num_event_types)

        # These are wrapped in nn.Parameter so they are registered on the module.
        # That way the optimizer updates them, state_dict saves and restores them,
        # and they move with .to(device). As plain tensors they kept their random
        # initial values forever and were re-drawn on every model construction.
        # NOTE(review): checkpoints written before this change lack these two
        # keys, so a strict load_state_dict of such a checkpoint will report them
        # as missing.
        self.factor_intensity_base = nn.Parameter(
            torch.empty([1, 1, self.num_event_types], device=self.device))
        self.factor_intensity_current_influence = nn.Parameter(
            torch.empty([1, 1, self.num_event_types], device=self.device))
        nn.init.xavier_normal_(self.factor_intensity_base)
        nn.init.xavier_normal_(self.factor_intensity_current_influence)

    def state_decay(self, states_to_decay, duration_t):
        """Equation (11): compute the intensity after decaying states by ``duration_t``.

        Args:
            states_to_decay (tensor): [..., hidden_size], RNN states right after events.
            duration_t (tensor): elapsed time since each event, with a trailing
                singleton dim so it broadcasts against the num_event_types axis.

        Returns:
            tensor: [..., num_event_types], intensity per event type, capped at
            ``_max_intensity``.
        """
        # [..., num_event_types]
        logits = (self.layer_hidden(states_to_decay)
                  + self.factor_intensity_current_influence * duration_t
                  + self.factor_intensity_base)
        # Cap the exponent rather than the result. If exp() overflowed to inf and
        # the result were clamped afterwards, backward would compute 0 * inf = NaN.
        # min(exp(x), M) == exp(min(x, log M)), so the intensity cap is unchanged.
        return torch.exp(logits.clamp(max=math.log(self._max_intensity)))

    def forward(self, time_seqs, time_delta_seqs, type_seqs, **kwargs):
        """Encode the event history and evaluate the intensity at each step.

        Args:
            time_seqs (tensor): [batch_size, seq_len], event timestamps.
            time_delta_seqs (tensor): [batch_size, seq_len], time by which each
                post-event state is decayed before the intensity is evaluated.
            type_seqs (tensor): [batch_size, seq_len], event types.
            **kwargs: accepted for backward compatibility (e.g. ``max_steps``);
                not used.

        Returns:
            tuple:
                intensities, [batch_size, seq_len, num_event_types], from
                ``state_decay`` applied to each state and its time delta;
                decay_states, [batch_size, seq_len, hidden_size], RNN states
                right after each event.
        """
        # [batch_size, seq_len, hidden_size]
        type_emb = self.layer_type_emb(type_seqs)
        # [batch_size, seq_len, hidden_size]
        temporal_emb = self.layer_temporal_emb(time_seqs[..., None])
        # States right after each event: [batch_size, seq_len, hidden_size]
        decay_states, _ = self.layer_rnn(type_emb + temporal_emb)

        # Decay is applied for every sequence length, so the first output always
        # has num_event_types channels. The former length-2 special case returned
        # raw hidden states here, which have the wrong width.
        intensities = self.state_decay(decay_states, time_delta_seqs[..., None])
        return intensities, decay_states

    def loglike_loss(self, batch):
        """Compute the negative log-likelihood loss.

        Args:
            batch (list): (time_seqs, time_delta_seqs, type_seqs,
                batch_non_pad_mask, <unused>, type_mask).

        Returns:
            tuple: loglikelihood loss and num of events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        # The model conditions on every event except the last. time_delta_seqs[:, 1:]
        # is the gap from each of those events to the next one, so lambda_at_event
        # is the intensity at the following event.
        lambda_at_event, decay_states = self.forward(
            time_seqs[:, :-1], time_delta_seqs[:, 1:], type_seqs[:, :-1])

        # Non-event term (the integral of the intensity, equation (8)) uses Monte
        # Carlo estimation. Random times are sampled inside each inter-event
        # interval and the intensity is evaluated at them. Averaging and scaling
        # by the interval length are presumably done inside compute_loglikelihood,
        # which receives the interval lengths.
        # The first dtime is zero, so the intervals are time_delta_seqs[:, 1:].
        # interval_t_sample: [batch_size, seq_len - 1, num_mc_sample]
        interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])
        # [batch_size, seq_len - 1, num_mc_sample, num_event_types]
        lambda_t_sample = self.compute_states_at_sample_times(decay_states, interval_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=lambda_at_event,
            lambdas_loss_samples=lambda_t_sample,
            time_delta_seq=time_delta_seqs[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            lambda_type_mask=type_mask[:, 1:])

        loss = -(event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, decay_states, sample_dtimes):
        """Compute the intensity at sampled times after each event.

        Despite the method name, the values returned are intensities
        (produced by ``state_decay``), not hidden states.

        Args:
            decay_states (tensor): [batch_size, seq_len, hidden_size].
            sample_dtimes (tensor): [batch_size, seq_len, num_samples].

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types].
        """
        # Broadcast [B, L, 1, H] states against [B, L, S, 1] elapsed times.
        return self.state_decay(decay_states[..., None, :], sample_dtimes[..., None])

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_times (tensor): [batch_size, seq_len, num_sample], sampled
                elapsed times measured from each event.
            **kwargs: ``compute_last_step_only`` (bool, default False) restricts
                the computation to the last event. All kwargs are also
                forwarded to ``forward``.

        Returns:
            tensor: [batch_size, seq_len, num_sample, num_event_types], or
            [batch_size, 1, num_sample, num_event_types] when
            ``compute_last_step_only`` is set.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        # Only the post-event RNN states are needed here.
        _, decay_states = self.forward(time_seqs, time_delta_seqs, type_seqs, **kwargs)

        if compute_last_step_only:
            # [batch_size, 1, 1, hidden_size] against [batch_size, 1, num_sample, 1]
            return self.state_decay(decay_states[:, -1:, None, :],
                                    sample_times[:, -1:, :, None])

        # [batch_size, seq_len, 1, hidden_size] against [batch_size, seq_len, num_sample, 1]
        return self.state_decay(decay_states[..., None, :], sample_times[..., None])