Source code for easy_tpp.model.torch_model.torch_nhp

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class ContTimeLSTMCell(nn.Module):
    """LSTM Cell in Neural Hawkes Process, NeurIPS'17."""

    def __init__(self, hidden_dim, beta=1.0):
        """Initialize the continuous LSTM cell.

        Args:
            hidden_dim (int): dim of hidden state.
            beta (float, optional): beta in nn.Softplus. Defaults to 1.0.
        """
        super(ContTimeLSTMCell, self).__init__()
        self.hidden_dim = hidden_dim
        self.init_dense_layer(hidden_dim, bias=True, beta=beta)

    def init_dense_layer(self, hidden_dim, bias, beta):
        """Initialize linear layers given Equations (5a-6c) in the paper.

        Args:
            hidden_dim (int): dim of hidden state.
            bias (bool): whether to use bias term in nn.Linear.
            beta (float): beta in nn.Softplus.
        """
        self.layer_input = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_forget = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_output = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_input_bar = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_forget_bar = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_pre_c = nn.Linear(hidden_dim * 2, hidden_dim, bias=bias)
        self.layer_decay = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim, bias=bias),
            nn.Softplus(beta=beta))

    def forward(self, x_i, hidden_i_minus, cell_i_minus, cell_bar_i_minus_1):
        """Update the continuous-time LSTM cell.

        Args:
            x_i (tensor): event embedding vector at t_i.
            hidden_i_minus (tensor): hidden state at t_i-.
            cell_i_minus (tensor): cell state at t_i-.
            cell_bar_i_minus_1 (tensor): cell bar state at t_{i-1}.

        Returns:
            list: cell state, cell bar state, decay and output at t_i.
        """
        x_i_ = torch.cat((x_i, hidden_i_minus), dim=1)

        # update input gate - Equation (5a)
        gate_input = torch.sigmoid(self.layer_input(x_i_))

        # update forget gate - Equation (5b)
        gate_forget = torch.sigmoid(self.layer_forget(x_i_))

        # update output gate - Equation (5d)
        gate_output = torch.sigmoid(self.layer_output(x_i_))

        # update input bar - similar to Equation (5a)
        gate_input_bar = torch.sigmoid(self.layer_input_bar(x_i_))

        # update forget bar - similar to Equation (5b)
        gate_forget_bar = torch.sigmoid(self.layer_forget_bar(x_i_))

        # update gate z - Equation (5c)
        gate_pre_c = torch.tanh(self.layer_pre_c(x_i_))

        # update gate decay - Equation (6c)
        gate_decay = self.layer_decay(x_i_)

        # update cell state to t_i+ - Equation (6a)
        cell_i = gate_forget * cell_i_minus + gate_input * gate_pre_c

        # update cell bar state - Equation (6b)
        cell_bar_i = gate_forget_bar * cell_bar_i_minus_1 + gate_input_bar * gate_pre_c

        return cell_i, cell_bar_i, gate_decay, gate_output

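    # For reference, the update implemented in forward() above reads, in a
    # compact notation ("*" is the elementwise product, [x; h] concatenation,
    # W/b the weights and biases of the linear layers):
    #
    #   i, f, o, i_bar, f_bar = sigmoid(W [x_i; h(t_i-)] + b)    # (5a, 5b, 5d)
    #   z       = tanh(W_z [x_i; h(t_i-)] + b_z)                 # (5c)
    #   delta   = softplus(W_d [x_i; h(t_i-)] + b_d)             # (6c)
    #   c_i     = f * c(t_i-) + i * z                            # (6a)
    #   c_bar_i = f_bar * c_bar_{i-1} + i_bar * z                # (6b)
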
    def decay(self, cell_i, cell_bar_i, gate_decay, gate_output, dtime):
        """Decay the cell and hidden states according to Equation (7).

        Args:
            cell_i (tensor): cell state at t_i.
            cell_bar_i (tensor): cell bar state at t_i.
            gate_decay (tensor): gate decay state at t_i.
            gate_output (tensor): gate output state at t_i.
            dtime (tensor): delta time to decay.

        Returns:
            list: cell and hidden state tensors after the decay.
        """
        c_t = cell_bar_i + (cell_i - cell_bar_i) * torch.exp(-gate_decay * dtime)
        h_t = gate_output * torch.tanh(c_t)

        return c_t, h_t

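# A minimal usage sketch of the cell (illustrative only; the shapes and values
# below are hypothetical, not part of the library API):
#
#   cell = ContTimeLSTMCell(hidden_dim=16)
#   x_i = torch.randn(4, 16)                      # event embeddings (batch, hidden)
#   h_t, c_t, c_bar = torch.zeros(4, 48).chunk(3, dim=1)
#   cell_i, cell_bar_i, decay_i, output_i = cell(x_i, h_t, c_t, c_bar)
#   c_t, h_t = cell.decay(cell_i, cell_bar_i, decay_i, output_i,
#                         torch.rand(4, 1))       # states at t_i + dt

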
class NHP(TorchBaseModel):
    """Torch implementation of The Neural Hawkes Process: A Neurally Self-Modulating
    Multivariate Point Process, NeurIPS 2017, https://arxiv.org/abs/1612.09328.
    """

    def __init__(self, model_config):
        """Initialize the NHP model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(NHP, self).__init__(model_config)
        self.beta = model_config.model_specs.get('beta', 1.0)
        self.bias = model_config.model_specs.get('bias', False)

        self.rnn_cell = ContTimeLSTMCell(self.hidden_size)

        self.layer_intensity = nn.Sequential(
            nn.Linear(self.hidden_size, self.num_event_types, bias=self.bias),
            nn.Softplus(beta=self.beta))

    def init_state(self, batch_size):
        """Initialize hidden and cell states.

        Args:
            batch_size (int): size of batch data.

        Returns:
            list: list of hidden states, cell states and cell bar states.
        """
        h_t, c_t, c_bar = torch.zeros(batch_size,
                                      3 * self.hidden_size,
                                      device=self.device).chunk(3, dim=1)
        return h_t, c_t, c_bar

    def forward(self, batch, **kwargs):
        """Call the model.

        Args:
            batch (tuple, list): batch input.

        Returns:
            list: hidden states, [batch_size, seq_len, hidden_dim], states right
            before the next event happens; stacked decay states,
            [batch_size, max_seq_length, 4, hidden_dim], states right after
            each event happens.
        """
        time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _, type_mask = batch

        all_hiddens = []
        all_outputs = []
        all_cells = []
        all_cell_bars = []
        all_decays = []

        max_steps = kwargs.get('max_steps', None)
        max_decay_time = kwargs.get('max_decay_time', 5.0)

        # the last event has no time label
        max_seq_length = max_steps if max_steps is not None else event_seq.size(1) - 1

        batch_size = len(event_seq)
        h_t, c_t, c_bar_i = self.init_state(batch_size)

        # if there is only one event, we do not decay
        if max_seq_length == 1:
            types_sub_batch = event_seq[:, 0]
            x_t = self.layer_type_emb(types_sub_batch)
            cell_i, c_bar_i, decay_i, output_i = self.rnn_cell(x_t, h_t, c_t, c_bar_i)

            # append all outputs
            all_outputs.append(output_i)
            all_decays.append(decay_i)
            all_cells.append(cell_i)
            all_cell_bars.append(c_bar_i)
            all_hiddens.append(h_t)
        else:
            # loop over all events
            for i in range(max_seq_length):
                if i == event_seq.size(1) - 1:
                    dt = torch.ones_like(time_delta_seq[:, i]) * max_decay_time
                else:
                    dt = time_delta_seq[:, i + 1]  # need to carefully check here
                types_sub_batch = event_seq[:, i]
                x_t = self.layer_type_emb(types_sub_batch)

                # cell_i (batch_size, process_dim)
                cell_i, c_bar_i, decay_i, output_i = self.rnn_cell(x_t, h_t, c_t, c_bar_i)

                # states decay - Equation (7) in the paper
                c_t, h_t = self.rnn_cell.decay(cell_i,
                                               c_bar_i,
                                               decay_i,
                                               output_i,
                                               dt[:, None])

                # append all outputs
                all_outputs.append(output_i)
                all_decays.append(decay_i)
                all_cells.append(cell_i)
                all_cell_bars.append(c_bar_i)
                all_hiddens.append(h_t)

        # (batch_size, max_seq_length, hidden_dim)
        cell_stack = torch.stack(all_cells, dim=1)
        cell_bar_stack = torch.stack(all_cell_bars, dim=1)
        decay_stack = torch.stack(all_decays, dim=1)
        output_stack = torch.stack(all_outputs, dim=1)

        # [batch_size, max_seq_length, hidden_dim]
        hiddens_stack = torch.stack(all_hiddens, dim=1)

        # [batch_size, max_seq_length, 4, hidden_dim]
        decay_states_stack = torch.stack((cell_stack,
                                          cell_bar_stack,
                                          decay_stack,
                                          output_stack),
                                         dim=2)

        return hiddens_stack, decay_states_stack

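    # Worked shape example (hypothetical numbers): for a padded batch of 2
    # sequences of length 5 with hidden_size 16 and no max_steps, the loop
    # runs over the first 4 events, so forward() returns
    #   hiddens_stack:      (2, 4, 16)
    #   decay_states_stack: (2, 4, 4, 16)   # (cell, cell_bar, decay, output)
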
    def loglike_loss(self, batch):
        """Compute the log-likelihood loss.

        Args:
            batch (list): batch input.

        Returns:
            list: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch
        hiddens_ti, decay_states = self.forward(batch)

        # num of samples in each batch and num of event time points in the sequence
        batch_size, seq_len, _ = hiddens_ti.size()

        # lambda(t) right before each event time point
        # lambda_at_event - [batch_size, num_times=max_len-1, num_event_types]
        # the last event is already excluded by forward() because it has no
        # delta_time label (cannot decay)
        lambda_at_event = self.layer_intensity(hiddens_ti)

        # compute the big lambda integral in Equation (8)
        # 1 - take num_mc_sample rand points in each event interval
        # 2 - compute the lambda value for every sample point
        # 3 - take the average of these sample points
        # 4 - multiply by the interval length

        # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample]
        # for every batch and every event point => do a sampling (num_mc_sampling)
        # the first dtime is zero, so we use time_delta_seqs[:, 1:]
        interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # [batch_size, num_times=max_len-1, num_mc_sample, hidden_size]
        state_t_sample = self.compute_states_at_sample_times(decay_states, interval_t_sample)

        # [batch_size, num_times=max_len-1, num_mc_sample, num_event_types]
        lambda_t_sample = self.layer_intensity(state_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=lambda_at_event,
            lambdas_loss_samples=lambda_t_sample,
            time_delta_seq=time_delta_seqs[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            lambda_type_mask=type_mask[:, 1:])

        # (num_samples, num_times)
        loss = -(event_ll - non_event_ll).sum()

        return loss, num_events

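    # The loss above is the negative log-likelihood of Equation (8),
    #
    #   log L = sum_i log lambda_{k_i}(t_i) - int_0^T sum_k lambda_k(t) dt,
    #
    # where the integral (the non-event term) is approximated per interval by
    # Monte Carlo over the points produced by make_dtime_loss_samples:
    #
    #   int_{t_i}^{t_{i+1}} sum_k lambda_k(t) dt
    #       ~= dt_i * mean_j [ sum_k lambda_k(t_i + tau_j) ],
    #
    # with tau_j the num_mc_sample offsets drawn inside the interval of
    # length dt_i.
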
    def compute_states_at_sample_times(self, decay_states, sample_dtimes):
        """Compute the states at sampling times.

        Args:
            decay_states (tensor): states right after the events.
            sample_dtimes (tensor): delta times in sampling.

        Returns:
            tensor: hidden states at sampling times.
        """
        # update the states given the last event
        # cells (batch_size, num_times, hidden_dim)
        cells, cell_bars, decays, outputs = decay_states.unbind(dim=-2)

        # use broadcasting to compute the decays at all time steps
        # at all sample points
        # h_ts shape (batch_size, num_times, num_mc_sample, hidden_dim)
        # cells[:, :, None, :] (batch_size, num_times, 1, hidden_dim)
        _, h_ts = self.rnn_cell.decay(cells[:, :, None, :],
                                      cell_bars[:, :, None, :],
                                      decays[:, :, None, :],
                                      outputs[:, :, None, :],
                                      sample_dtimes[..., None])

        return h_ts

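    # Broadcasting check for the decay call above (hypothetical sizes B=2,
    # T=4 event times, M=10 MC samples, H=16):
    #   cells[:, :, None, :]     -> (2, 4, 1, 16)
    #   sample_dtimes[..., None] -> (2, 4, 10, 1)
    # the elementwise ops inside decay() broadcast both operands to
    # (2, 4, 10, 16), which is the shape of h_ts.
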
    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], time seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each timestamp for each event type.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        input_ = time_seqs, time_delta_seqs, type_seqs, None, None, None

        # forward to the last-but-one event
        hiddens_ti, decay_states = self.forward(input_, **kwargs)

        # num of samples in each batch and num of event time points in the sequence
        batch_size, seq_len, _ = hiddens_ti.size()

        # update the states given the last event
        # cells (batch_size, num_times, hidden_dim)
        cells, cell_bars, decays, outputs = decay_states.unbind(dim=-2)

        if compute_last_step_only:
            interval_t_sample = sample_dtimes[:, -1:, :, None]
            _, h_ts = self.rnn_cell.decay(cells[:, -1:, None, :],
                                          cell_bars[:, -1:, None, :],
                                          decays[:, -1:, None, :],
                                          outputs[:, -1:, None, :],
                                          interval_t_sample)

            # [batch_size, 1, num_mc_sample, num_event_types]
            sampled_intensities = self.layer_intensity(h_ts)
        else:
            # interval_t_sample - [batch_size, num_times, num_mc_sample, 1]
            interval_t_sample = sample_dtimes[..., None]

            # use broadcasting to compute the decays at all time steps
            # at all sample points
            # h_ts shape (batch_size, num_times, num_mc_sample, hidden_dim)
            # cells[:, :, None, :] (batch_size, num_times, 1, hidden_dim)
            _, h_ts = self.rnn_cell.decay(cells[:, :, None, :],
                                          cell_bars[:, :, None, :],
                                          decays[:, :, None, :],
                                          outputs[:, :, None, :],
                                          interval_t_sample)

            # [batch_size, num_times, num_mc_sample, num_event_types]
            sampled_intensities = self.layer_intensity(h_ts)

        return sampled_intensities
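

# End-to-end usage sketch (hypothetical shapes; assumes `model` is an NHP
# instance built from a valid EasyTPP ModelConfig and that time_seqs,
# time_delta_seqs and type_seqs are padded to a common length L):
#
#   B, L, M = 2, 8, 10                      # batch, seq length, MC samples
#   sample_dtimes = torch.rand(B, L, M)     # sampled offsets per interval
#   lambdas = model.compute_intensities_at_sample_times(
#       time_seqs, time_delta_seqs, type_seqs, sample_dtimes,
#       compute_last_step_only=True)
#   # lambdas: (B, 1, M, num_event_types), intensities after the last event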