Source code for easy_tpp.model.torch_model.torch_sahp

import torch
import torch.nn as nn

from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, \
    TimeShiftedPositionalEncoding
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class SAHP(TorchBaseModel):
    """Torch implementation of Self-Attentive Hawkes Process, ICML 2020.

    Part of the code is collected from https://github.com/yangalan123/anhp-andtt/blob/master/sahp.
    The original code is slightly modified because it was not stable.
    """
    def __init__(self, model_config):
        """Initialize the model

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(SAHP, self).__init__(model_config)
        self.d_model = model_config.hidden_size
        self.d_time = model_config.time_emb_size
        self.use_norm = model_config.use_ln

        # position vector, used for temporal encoding
        self.layer_position_emb = TimeShiftedPositionalEncoding(d_model=self.d_model,
                                                                device=self.device)

        self.n_layers = model_config.num_layers
        self.n_head = model_config.num_heads
        self.dropout = model_config.dropout_rate

        # convert hidden vectors into per-event-type scores
        self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types)
        self.softplus = nn.Softplus()

        self.stack_layers = nn.ModuleList(
            [EncoderLayer(
                self.d_model,
                MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout,
                                   output_linear=False),
                use_residual=False,
                dropout=self.dropout
            ) for _ in range(self.n_layers)])

        if self.use_norm:
            self.norm = nn.LayerNorm(self.d_model)

        # Equation (12): mu
        self.mu = torch.empty([self.d_model, self.num_event_types], device=self.device)
        # Equation (13): eta
        self.eta = torch.empty([self.d_model, self.num_event_types], device=self.device)
        # Equation (14): gamma
        self.gamma = torch.empty([self.d_model, self.num_event_types], device=self.device)

        nn.init.xavier_normal_(self.mu)
        nn.init.xavier_normal_(self.eta)
        nn.init.xavier_normal_(self.gamma)
    def state_decay(self, encode_state, mu, eta, gamma, duration_t):
        """Equation (15), which computes the pre-intensity states

        Args:
            encode_state (tensor): [batch_size, seq_len, hidden_size], encoder output at event times.
            mu (tensor): [..., hidden_size, num_event_types], broadcastable against encode_state.
            eta (tensor): [..., hidden_size, num_event_types], broadcastable against encode_state.
            gamma (tensor): [..., hidden_size, num_event_types], broadcastable against encode_state.
            duration_t (tensor): [batch_size, seq_len, 1] or [batch_size, seq_len, num_sample, 1],
                elapsed time since the last event.

        Returns:
            tensor: [batch_size, seq_len, (num_sample,) num_event_types], pre-intensity states at the queried times.
        """
        # [batch_size, seq_len, (num_sample,) num_event_types]
        states = torch.matmul(encode_state, mu) + (
                torch.matmul(encode_state, eta) - torch.matmul(encode_state, mu)) * torch.exp(
            -torch.matmul(encode_state, gamma) * duration_t)
        return states
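For reference, a minimal standalone sketch of the decay in state_decay (Equation (15)): the pre-intensity state starts at matmul(h, eta) right after an event and, for positive decay rates, relaxes towards the baseline matmul(h, mu) at rate matmul(h, gamma) as the elapsed time grows. All sizes and values below are toy choices for illustration only, not part of EasyTPP.

import torch

torch.manual_seed(0)
batch_size, seq_len, d_model, num_event_types = 2, 5, 8, 3

# Toy stand-ins for the encoder output and the (mu, eta, gamma) matrices from __init__.
h = torch.randn(batch_size, seq_len, d_model)
mu = torch.randn(d_model, num_event_types)
eta = torch.randn(d_model, num_event_types)
gamma = torch.randn(d_model, num_event_types)

dt = torch.rand(batch_size, seq_len, 1)   # elapsed time since the last event

mu_h, eta_h, gamma_h = h @ mu, h @ eta, h @ gamma
# Equation (15): start at eta_h right after the event and relax towards mu_h
# at rate gamma_h as dt grows (for positive gamma_h).
states = mu_h + (eta_h - mu_h) * torch.exp(-gamma_h * dt)
print(states.shape)                       # torch.Size([2, 5, 3]); one value per event type

# At dt = 0 the state equals eta_h, the post-event level.
states_at_zero = mu_h + (eta_h - mu_h) * torch.exp(-gamma_h * torch.zeros_like(dt))
assert torch.allclose(states_at_zero, eta_h)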
    def forward(self, time_seqs, time_delta_seqs, event_seqs, attention_mask):
        """Call the model

        Args:
            time_seqs (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            event_seqs (tensor): [batch_size, seq_len], event type seqs.
            attention_mask (tensor): [batch_size, seq_len, seq_len], attention masks.

        Returns:
            tensor: hidden states at event times.
        """
        type_embedding = self.layer_type_emb(event_seqs)
        position_embedding = self.layer_position_emb(time_seqs, time_delta_seqs)

        enc_output = type_embedding + position_embedding

        for enc_layer in self.stack_layers:
            enc_output = enc_layer(
                enc_output,
                mask=attention_mask)

            if self.use_norm:
                enc_output = self.norm(enc_output)

        # [batch_size, seq_len, hidden_dim]
        return enc_output
    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (tuple, list): batch input.

        Returns:
            tuple: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch

        enc_out = self.forward(time_seqs[:, :-1], time_delta_seqs[:, 1:], type_seqs[:, :-1],
                               attention_mask[:, 1:, :-1])

        cell_t = self.state_decay(encode_state=enc_out,
                                  mu=self.mu[None, ...],
                                  eta=self.eta[None, ...],
                                  gamma=self.gamma[None, ...],
                                  duration_t=time_delta_seqs[:, 1:, None])

        # 1. compute event-loglik
        # [batch_size, seq_len, num_event_types]
        lambda_at_event = self.softplus(cell_t)

        # 2. compute non-event-loglik (using MC sampling to compute integral)
        # 2.1 sample times
        # [batch_size, seq_len, num_sample]
        sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # 2.2 compute intensities at sampled times
        # [batch_size, num_times = max_len - 1, num_sample, event_num]
        state_t_sample = self.compute_states_at_sample_times(encode_state=enc_out,
                                                             sample_dtimes=sample_dtimes)
        lambda_t_sample = self.softplus(state_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                        lambdas_loss_samples=lambda_t_sample,
                                                                        time_delta_seq=time_delta_seqs[:, 1:],
                                                                        seq_mask=batch_non_pad_mask[:, 1:],
                                                                        lambda_type_mask=type_mask[:, 1:])

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events
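The non-event term of the log-likelihood is the integral of the intensity over each inter-event interval, which loglike_loss approximates by Monte Carlo: intensities are evaluated at dtimes drawn inside each interval (via self.make_dtime_loss_samples) and averaged, then scaled by the interval length. A minimal standalone sketch of that estimator on a toy closed-form intensity follows; lambda_fn and the uniform sampling below are purely illustrative and not part of EasyTPP.

import torch

torch.manual_seed(0)

def lambda_fn(t):
    # hypothetical toy intensity with a known integral: lambda(t) = 1 + t
    return 1.0 + t

dt = torch.tensor([0.5, 1.0, 2.0])     # lengths of three inter-event intervals
num_sample = 1000

# sample points uniformly inside each interval: [num_intervals, num_sample]
sample_times = torch.rand(dt.shape[0], num_sample) * dt[:, None]

# Monte Carlo estimate: mean intensity over the samples, scaled by the interval length
mc_integral = lambda_fn(sample_times).mean(dim=-1) * dt

exact = dt + 0.5 * dt ** 2             # closed-form integral of 1 + t over [0, dt]
print(mc_integral)                     # close to `exact` = tensor([0.6250, 1.5000, 4.0000])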
    def compute_states_at_sample_times(self, encode_state, sample_dtimes):
        """Compute the hidden states at sampled times.

        Args:
            encode_state (tensor): [batch_size, seq_len, hidden_size], encoder output at event times.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples].

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], pre-intensity state at each sampled time.
        """
        cell_states = self.state_decay(encode_state[:, :, None, :],
                                       self.mu[None, None, ...],
                                       self.eta[None, None, ...],
                                       self.gamma[None, None, ...],
                                       sample_dtimes[:, :, :, None])

        return cell_states
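The broadcasting in compute_states_at_sample_times is easy to misread, so here is a quick standalone shape check with toy tensors (random values, only the shapes matter):

import torch

batch_size, seq_len, num_samples, d_model, num_event_types = 2, 5, 10, 8, 3

encode_state = torch.randn(batch_size, seq_len, d_model)
mu = torch.randn(d_model, num_event_types)
sample_dtimes = torch.rand(batch_size, seq_len, num_samples)

# [batch_size, seq_len, 1, d_model] @ [1, 1, d_model, num_event_types]
#   -> [batch_size, seq_len, 1, num_event_types]
projected = torch.matmul(encode_state[:, :, None, :], mu[None, None, ...])
print(projected.shape)       # torch.Size([2, 5, 1, 3])

# Multiplying by the sampled dtimes broadcasts the singleton axes against each other:
# [batch_size, seq_len, 1, num_event_types] * [batch_size, seq_len, num_samples, 1]
#   -> [batch_size, seq_len, num_samples, num_event_types]
decayed = projected * sample_dtimes[:, :, :, None]
print(decayed.shape)         # torch.Size([2, 5, 10, 3])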
    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute the intensities at sampled times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times.
        """
        attention_mask = kwargs.get('attention_mask', None)
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        if attention_mask is None:
            batch_size, seq_len = time_seqs.size()
            attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

        # [batch_size, seq_len, hidden_size]
        enc_out = self.forward(time_seqs, time_delta_seqs, type_seqs, attention_mask)

        # [batch_size, seq_len, num_samples, num_event_types]
        encoder_output = self.compute_states_at_sample_times(enc_out, sample_dtimes)

        if compute_last_step_only:
            lambdas = self.softplus(encoder_output[:, -1:, :, :])
        else:
            # [batch_size, seq_len, num_samples, num_event_types]
            lambdas = self.softplus(encoder_output)
        return lambdas
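When no attention_mask is supplied, the method falls back to a strictly upper-triangular (causal) mask, so that each position attends only to earlier events. A small standalone sketch of that construction for a toy seq_len of 4 (device argument omitted):

import torch

seq_len, batch_size = 4, 2
attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)
print(attention_mask[0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)

True entries (the strictly upper triangle) mark the future positions that are masked out of the attention.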