# Source code for easy_tpp.model.torch_model.torch_thp

import torch
import torch.nn as nn

from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, TimePositionalEncoding
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class THP(TorchBaseModel):
    """Torch implementation of Transformer Hawkes Process, ICML 2020, https://arxiv.org/abs/2002.09291.

    Note: Part of the code is collected from https://github.com/yangalan123/anhp-andtt/tree/master/thp.
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs. This constructor reads
                ``hidden_size``, ``time_emb_size``, ``use_ln``, ``num_layers``, ``num_heads``
                and ``dropout_rate`` from it.
        """
        super().__init__(model_config)
        self.d_model = model_config.hidden_size
        self.d_time = model_config.time_emb_size
        self.use_norm = model_config.use_ln

        self.n_layers = model_config.num_layers
        self.n_head = model_config.num_heads
        self.dropout = model_config.dropout_rate

        self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device)

        # Per-event-type decay and base factors of the intensity (Equation (6)), shape
        # [1, num_event_types]. They are wrapped in nn.Parameter so they are registered with
        # the module: the optimizer updates them, .to(device) moves them and state_dict saves
        # them. As plain tensors they stayed frozen at their random initial values.
        # NOTE(review): this adds two keys to state_dict; checkpoints saved by the old code
        # will need strict=False to load.
        self.factor_intensity_base = nn.Parameter(
            torch.empty([1, self.num_event_types], device=self.device))
        self.factor_intensity_decay = nn.Parameter(
            torch.empty([1, self.num_event_types], device=self.device))
        nn.init.xavier_normal_(self.factor_intensity_base)
        nn.init.xavier_normal_(self.factor_intensity_decay)

        # convert hidden vectors into event-type-sized vector
        self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types)
        self.softplus = nn.Softplus()

        self.stack_layers = nn.ModuleList(
            [EncoderLayer(
                self.d_model,
                MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout,
                                   output_linear=False),
                use_residual=False,
                dropout=self.dropout,
            ) for _ in range(self.n_layers)])

    def forward(self, time_seqs, type_seqs, attention_mask):
        """Encode the event sequence with the stacked encoder layers.

        Args:
            time_seqs (tensor): [batch_size, seq_len], timestamp seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            attention_mask (tensor): [batch_size, seq_len, seq_len], attention masks.

        Returns:
            tensor: [batch_size, seq_len, hidden_size], hidden states at event times.
        """
        # [batch_size, seq_len, hidden_size]
        tem_enc = self.layer_temporal_encoding(time_seqs)
        # [batch_size, seq_len, hidden_size]
        enc_output = self.layer_type_emb(type_seqs)

        for enc_layer in self.stack_layers:
            # The temporal encoding is re-added before every layer. The add is out of place
            # because an in-place `+=` would overwrite a tensor that autograd may have saved
            # for the previous layer's backward pass.
            enc_output = enc_output + tem_enc
            enc_output = enc_layer(enc_output, mask=attention_mask)

        return enc_output

    def _intensity_states(self, hidden_states, dtimes):
        """Compute the pre-softplus intensity states of Equation (6).

        Evaluates ``decay * dtimes + Linear(hidden_states) + base``. The two factors have shape
        [1, num_event_types] and broadcast over the leading dimensions of the other operands.

        Args:
            hidden_states (tensor): [..., hidden_size], hidden states at events.
            dtimes (tensor): [..., 1], elapsed times after the corresponding event.

        Returns:
            tensor: [..., num_event_types], intensity states before the softplus.
        """
        return (self.factor_intensity_decay * dtimes
                + self.layer_intensity_hidden(hidden_states)
                + self.factor_intensity_base)

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (tuple, list): batch input, unpacked as (time_seqs, time_delta_seqs,
                type_seqs, batch_non_pad_mask, attention_mask, type_mask).

        Returns:
            tuple: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch

        # 1. compute event-loglik
        # Encode all events but the last. Each hidden state is paired with the following
        # inter-event gap, so every target quantity below uses the [:, 1:] slice.
        # [batch_size, seq_len - 1, hidden_size]
        enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, 1:, :-1])

        # intensity at the next event time, Equation (6)
        # [batch_size, seq_len - 1, num_event_types]
        intensity_states = self._intensity_states(enc_out, time_delta_seqs[:, 1:, None])
        lambda_at_event = self.softplus(intensity_states)

        # 2. compute non-event-loglik (using MC sampling to compute integral)
        # 2.1 sample dtimes
        # [batch_size, seq_len - 1, num_sample]
        sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # 2.2 compute intensities at sampled times
        # [batch_size, seq_len - 1, num_sample, num_event_types]
        state_t_sample = self.compute_states_at_sample_times(event_states=enc_out,
                                                             sample_dtimes=sample_dtimes)
        lambda_t_sample = self.softplus(state_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=lambda_at_event,
            lambdas_loss_samples=lambda_t_sample,
            time_delta_seq=time_delta_seqs[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            lambda_type_mask=type_mask[:, 1:])

        # negative log-likelihood, summed over the batch
        loss = -(event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, event_states, sample_dtimes):
        """Compute the pre-softplus intensity states at sampled times.

        Args:
            event_states (tensor): [batch_size, seq_len, hidden_size].
            sample_dtimes (tensor): [batch_size, seq_len, num_samples].

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity states
            (before the softplus) at each sampled time.
        """
        # [batch_size, seq_len, 1, hidden_size] broadcast against
        # [batch_size, seq_len, num_samples, 1]
        return self._intensity_states(event_states[:, :, None, :], sample_dtimes[..., None])

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute intensities at sampled times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs (not used by this
                model; kept for interface compatibility).
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps.
            **kwargs: optional ``attention_mask`` (tensor, [batch_size, seq_len, seq_len]); when
                omitted, an upper-triangular mask is built. Optional ``compute_last_step_only``
                (bool, default False) returns only the last position.

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all sampled
            times (seq_len is 1 when ``compute_last_step_only`` is True).
        """
        attention_mask = kwargs.get('attention_mask', None)
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        if attention_mask is None:
            batch_size, seq_len = time_seqs.size()
            # True strictly above the diagonal. Presumably True marks masked positions, which
            # makes this a causal mask; confirm against MultiHeadAttention.
            attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

        # [batch_size, seq_len, hidden_size]
        enc_out = self.forward(time_seqs, type_seqs, attention_mask)

        if compute_last_step_only:
            # The states are computed independently per position, so slicing before the
            # projection gives the same values without computing and discarding the rest.
            enc_out = enc_out[:, -1:]
            sample_dtimes = sample_dtimes[:, -1:]

        # [batch_size, seq_len (or 1), num_samples, num_event_types]
        state_t_sample = self.compute_states_at_sample_times(enc_out, sample_dtimes)
        return self.softplus(state_t_sample)