import math
import torch
from torch import nn
from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


class AttNHP(TorchBaseModel):
"""Torch implementation of Attentive Neural Hawkes Process, ICLR 2022.
https://arxiv.org/abs/2201.00044.
Source code: https://github.com/yangalan123/anhp-andtt/blob/master/anhp/model/xfmr_nhp_fast.py
"""

    def __init__(self, model_config):
"""Initialize the model
Args:
model_config (EasyTPP.ModelConfig): config of model specs.
"""
super(AttNHP, self).__init__(model_config)
self.d_model = model_config.hidden_size
self.use_norm = model_config.use_ln
self.d_time = model_config.time_emb_size
        self.div_term = torch.exp(
            torch.arange(0, self.d_time, 2) * -(math.log(10000.0) / self.d_time)
        ).reshape(1, 1, -1)
self.n_layers = model_config.num_layers
self.n_head = model_config.num_heads
self.dropout = model_config.dropout_rate
self.heads = []
for i in range(self.n_head):
self.heads.append(
nn.ModuleList(
[EncoderLayer(
self.d_model + self.d_time,
MultiHeadAttention(1, self.d_model + self.d_time, self.d_model, self.dropout,
output_linear=False),
use_residual=False,
dropout=self.dropout
)
for _ in range(self.n_layers)
]
)
)
self.heads = nn.ModuleList(self.heads)
if self.use_norm:
self.norm = nn.LayerNorm(self.d_model)
self.inten_linear = nn.Linear(self.d_model * self.n_head, self.num_event_types)
self.softplus = nn.Softplus()
self.layer_event_emb = nn.Linear(self.d_model + self.d_time, self.d_model)
self.layer_intensity = nn.Sequential(self.inten_linear, self.softplus)
self.eps = torch.finfo(torch.float32).eps
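
    # Architecture summary: `num_heads` independent stacks of `num_layers`
    # cross-attention layers are kept in `self.heads`, each working on the
    # concatenation of event embeddings and temporal embeddings
    # (width hidden_size + time_emb_size). The per-head outputs are concatenated
    # and mapped through Linear + Softplus (`layer_intensity`) to give one
    # intensity value per event type.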

    def compute_temporal_embedding(self, time):
"""Compute the temporal embedding.
Args:
time (tensor): [batch_size, seq_len].
Returns:
tensor: [batch_size, seq_len, emb_size].
"""
batch_size = time.size(0)
seq_len = time.size(1)
pe = torch.zeros(batch_size, seq_len, self.d_time).to(time)
_time = time.unsqueeze(-1)
div_term = self.div_term.to(time)
pe[..., 0::2] = torch.sin(_time * div_term)
pe[..., 1::2] = torch.cos(_time * div_term)
return pe
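
    # For reference, the embedding above is the standard sinusoidal encoding,
    # indexed by absolute event time rather than position:
    #   pe[..., 2j]     = sin(t / 10000^(2j / d_time))
    #   pe[..., 2j + 1] = cos(t / 10000^(2j / d_time))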

    def forward_pass(self, init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask):
"""update the structure sequentially.
Args:
init_cur_layer (tensor): [batch_size, seq_len, hidden_size]
time_emb (tensor): [batch_size, seq_len, hidden_size]
sample_time_emb (tensor): [batch_size, seq_len, hidden_size]
event_emb (tensor): [batch_size, seq_len, hidden_size]
combined_mask (tensor): [batch_size, seq_len, hidden_size]
Returns:
tensor: [batch_size, seq_len, hidden_size*2]
"""
cur_layers = []
seq_len = event_emb.size(1)
for head_i in range(self.n_head):
# [batch_size, seq_len, hidden_size]
cur_layer_ = init_cur_layer
for layer_i in range(self.n_layers):
                # each layer concatenates the temporal embedding
                # [batch_size, seq_len, hidden_size + time_emb_size]
                layer_ = torch.cat([cur_layer_, sample_time_emb], dim=-1)
                # build the combined input from the event embeddings and the layer embeddings
                # [batch_size, seq_len * 2, hidden_size + time_emb_size]
                _combined_input = torch.cat([event_emb, layer_], dim=1)
enc_layer = self.heads[head_i][layer_i]
# compute the output
enc_output = enc_layer(_combined_input, combined_mask)
# the layer output
# [batch_size, seq_len, hidden_size]
_cur_layer_ = enc_output[:, seq_len:, :]
# add residual connection
cur_layer_ = torch.tanh(_cur_layer_) + cur_layer_
# event emb
event_emb = torch.cat([enc_output[:, :seq_len, :], time_emb], dim=-1)
if self.use_norm:
cur_layer_ = self.norm(cur_layer_)
cur_layers.append(cur_layer_)
cur_layer_ = torch.cat(cur_layers, dim=-1)
return cur_layer_
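
    # In forward_pass, the event embeddings and the layer (sample-time) embeddings are
    # concatenated along the sequence dimension, so one attention call lets the layer
    # tokens query the observed events under the combined mask; the second half of the
    # output updates `cur_layer_` through a tanh residual, while the first half
    # refreshes the event embeddings for the next layer.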

    def seq_encoding(self, time_seqs, event_seqs):
"""Encode the sequence.
Args:
time_seqs (tensor): time seqs input, [batch_size, seq_len].
event_seqs (_type_): event type seqs input, [batch_size, seq_len].
Returns:
tuple: event embedding, time embedding and type embedding.
"""
        # [batch_size, seq_len, time_emb_size]
        time_emb = self.compute_temporal_embedding(time_seqs)
        # [batch_size, seq_len, hidden_size]
        type_emb = torch.tanh(self.layer_type_emb(event_seqs.long()))
        # [batch_size, seq_len, hidden_size + time_emb_size]
        event_emb = torch.cat([type_emb, time_emb], dim=-1)
return event_emb, time_emb, type_emb

    def make_layer_mask(self, attention_mask):
"""Create a tensor to do masking on layers.
Args:
attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len]
Returns:
tensor: aim to keep the current layer, the same size of attention mask
a diagonal matrix, [batch_size, seq_len, seq_len]
"""
# [batch_size, seq_len, seq_len]
        # build the mask on the same device as the attention mask to avoid device mismatches
        layer_mask = (torch.eye(attention_mask.size(1), device=attention_mask.device) < 1).unsqueeze(0).expand_as(attention_mask)
return layer_mask

    def make_combined_att_mask(self, attention_mask, layer_mask):
"""Combined attention mask and layer mask.
Args:
attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len]
layer_mask (tensor): mask for other layers, [batch_size, seq_len, seq_len]
Returns:
tensor: [batch_size, seq_len * 2, seq_len * 2]
"""
# [batch_size, seq_len, seq_len * 2]
combined_mask = torch.cat([attention_mask, layer_mask], dim=-1)
# [batch_size, seq_len, seq_len * 2]
contextual_mask = torch.cat([attention_mask, torch.ones_like(layer_mask)], dim=-1)
# [batch_size, seq_len * 2, seq_len * 2]
combined_mask = torch.cat([contextual_mask, combined_mask], dim=1)
return combined_mask
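
    # Block structure of the combined mask (assuming a True entry means the position
    # is masked out, as in the attention layers used above):
    #
    #     [ attention_mask | all True   ]   <- rows for the observed event tokens
    #     [ attention_mask | layer_mask ]   <- rows for the layer (query) tokens
    #
    # i.e. event tokens never attend to layer tokens, while each layer token attends
    # to the events allowed by attention_mask and, among layer tokens, only to itself.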

    def forward(self, time_seqs, event_seqs, attention_mask, sample_times=None):
"""Call the model.
Args:
time_seqs (tensor): [batch_size, seq_len], sequences of timestamps.
event_seqs (tensor): [batch_size, seq_len], sequences of event types.
attention_mask (tensor): [batch_size, seq_len, seq_len], masks for event sequences.
sample_times (tensor, optional): [batch_size, seq_len, num_samples]. Defaults to None.
Returns:
tensor: states at sampling times, [batch_size, seq_len, num_samples].
"""
event_emb, time_emb, type_emb = self.seq_encoding(time_seqs, event_seqs)
init_cur_layer = torch.zeros_like(type_emb)
layer_mask = self.make_layer_mask(attention_mask)
if sample_times is None:
sample_time_emb = time_emb
else:
sample_time_emb = self.compute_temporal_embedding(sample_times)
combined_mask = self.make_combined_att_mask(attention_mask, layer_mask)
cur_layer_ = self.forward_pass(init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask)
return cur_layer_

    def loglike_loss(self, batch):
"""Compute the loglike loss.
Args:
batch (list): batch input.
Returns:
list: loglike loss, num events.
"""
time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch
# 1. compute event-loglik
        # the prediction of the last event has no label, so we only go up to the last-but-one event
        # in the attention mask, the diagonal is False, i.e. not masked
enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, 1:, :-1], time_seqs[:, 1:])
# [batch_size, seq_len, num_event_types]
lambda_at_event = self.layer_intensity(enc_out)
# 2. compute non-event-loglik (using MC sampling to compute integral)
# 2.1 sample times
# [batch_size, seq_len, num_sample]
temp_time = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])
# [batch_size, seq_len, num_sample]
sample_times = temp_time + time_seqs[:, :-1].unsqueeze(-1)
# 2.2 compute intensities at sampled times
# [batch_size, seq_len = max_len - 1, num_sample, event_num]
lambda_t_sample = self.compute_intensities_at_sample_times(time_seqs[:, :-1],
time_delta_seqs[:, :-1], # not used
type_seqs[:, :-1],
sample_times,
attention_mask=attention_mask[:, 1:, :-1])
event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
lambdas_loss_samples=lambda_t_sample,
time_delta_seq=time_delta_seqs[:, 1:],
seq_mask=batch_non_pad_mask[:, 1:],
lambda_type_mask=type_mask[:, 1:])
        # negative log-likelihood summed over the batch
        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events
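
    # The loss above is the negative log-likelihood
    #     -log L = -sum_i log lambda_{k_i}(t_i) + integral of lambda_total(t) dt,
    # where the integral over each inter-event interval is approximated by Monte Carlo:
    # intensities are evaluated at the times produced by make_dtime_loss_samples and
    # aggregated inside compute_loglikelihood.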

    def compute_states_at_sample_times(self, time_seqs, type_seqs, attention_mask, sample_times):
"""Compute the states at sampling times.
Args:
time_seqs (tensor): [batch_size, seq_len], sequences of timestamps.
time_delta_seqs (tensor): [batch_size, seq_len], sequences of delta times.
type_seqs (tensor): [batch_size, seq_len], sequences of event types.
attention_mask (tensor): [batch_size, seq_len, seq_len], masks for event sequences.
sample_dtimes (tensor): delta times in sampling.
Returns:
tensor: hiddens states at sampling times.
"""
batch_size = type_seqs.size(0)
seq_len = type_seqs.size(1)
num_samples = sample_times.size(-1)
# [num_samples, batch_size, seq_len]
sample_times = sample_times.permute((2, 0, 1))
# [num_samples * batch_size, seq_len]
_sample_time = sample_times.reshape(num_samples * batch_size, -1)
# [num_samples * batch_size, seq_len]
_types = type_seqs.expand(num_samples, -1, -1).reshape(num_samples * batch_size, -1)
# [num_samples * batch_size, seq_len]
_times = time_seqs.expand(num_samples, -1, -1).reshape(num_samples * batch_size, -1)
        # [num_samples * batch_size, seq_len, seq_len]
_attn_mask = attention_mask.unsqueeze(0).expand(num_samples, -1, -1, -1).reshape(num_samples * batch_size,
seq_len,
seq_len)
# [num_samples * batch_size, seq_len, hidden_size]
encoder_output = self.forward(_times,
_types,
_attn_mask,
_sample_time)
# [num_samples, batch_size, seq_len, hidden_size]
encoder_output = encoder_output.reshape(num_samples, batch_size, seq_len, -1)
# [batch_size, seq_len, num_samples, hidden_size]
encoder_output = encoder_output.permute((1, 2, 0, 3))
return encoder_output

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs):
"""Compute the intensity at sampled times.
Args:
time_seqs (tensor): [batch_size, seq_len], sequences of timestamps.
time_delta_seqs (tensor): [batch_size, seq_len], sequences of delta times.
type_seqs (tensor): [batch_size, seq_len], sequences of event types.
sampled_dtimes (tensor): [batch_size, seq_len, num_sample], sampled time delta sequence.
Returns:
tensor: intensities as sampled_dtimes, [batch_size, seq_len, num_samples, event_num].
"""
attention_mask = kwargs.get('attention_mask', None)
compute_last_step_only = kwargs.get('compute_last_step_only', False)
if attention_mask is None:
batch_size, seq_len = time_seqs.size()
            # default causal mask, created on the same device as the input sequences
            attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=time_seqs.device), diagonal=1).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)
if sample_times.size()[1] < time_seqs.size()[1]:
            # sample_times here only covers the last time step,
            # so as a temporary workaround we tile it across the whole sequence
            # [batch_size, seq_len, num_samples]
sample_times = time_seqs[:, :, None] + torch.tile(sample_times, [1, time_seqs.size()[1], 1])
# [batch_size, seq_len, num_samples, hidden_size]
encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_times)
if compute_last_step_only:
lambdas = self.layer_intensity(encoder_output[:, -1:, :, :])
else:
# [batch_size, seq_len, num_samples, num_event_types]
lambdas = self.layer_intensity(encoder_output)
return lambdas
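

if __name__ == '__main__':
    # A minimal, self-contained sketch of the tensor manipulations used above
    # (temporal embedding and combined attention mask) for a toy batch. It does
    # not instantiate AttNHP, since that requires a full EasyTPP ModelConfig;
    # the shapes and formulas simply mirror the methods defined in this module.
    batch_size, seq_len, d_time = 2, 3, 4
    time_seqs = torch.tensor([[0.0, 1.5, 2.1],
                              [0.0, 0.7, 3.2]])

    # Sinusoidal temporal embedding, as in compute_temporal_embedding.
    div_term = torch.exp(torch.arange(0, d_time, 2) * -(math.log(10000.0) / d_time)).reshape(1, 1, -1)
    pe = torch.zeros(batch_size, seq_len, d_time)
    pe[..., 0::2] = torch.sin(time_seqs.unsqueeze(-1) * div_term)
    pe[..., 1::2] = torch.cos(time_seqs.unsqueeze(-1) * div_term)

    # Causal attention mask (True = masked) plus the diagonal layer mask,
    # combined into the [seq_len * 2, seq_len * 2] mask used in forward().
    attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    attention_mask = attention_mask.unsqueeze(0).expand(batch_size, -1, -1)
    layer_mask = (torch.eye(seq_len) < 1).unsqueeze(0).expand_as(attention_mask)
    contextual_mask = torch.cat([attention_mask, torch.ones_like(layer_mask)], dim=-1)
    combined_mask = torch.cat([contextual_mask, torch.cat([attention_mask, layer_mask], dim=-1)], dim=1)

    print(pe.shape)             # torch.Size([2, 3, 4])
    print(combined_mask.shape)  # torch.Size([2, 6, 6])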