Source code for easy_tpp.model.torch_model.torch_intensity_free

import torch
import torch.distributions as D
from torch import nn
from torch.distributions import Categorical, TransformedDistribution
from torch.distributions import MixtureSameFamily as TorchMixtureSameFamily
from torch.distributions import Normal as TorchNormal

from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


def clamp_preserve_gradients(x, min_val, max_val):
    """Clamp the tensor while preserving gradients in the clamped region.

    Args:
        x (tensor): tensor to be clamped.
        min_val (float): minimum value.
        max_val (float): maximum value.
    """
    return x + (x.clamp(min_val, max_val) - x).detach()
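
# Illustrative check (not part of the original module): the straight-through
# trick above keeps the clamped value in the forward pass while the backward
# pass sees the identity, so entries outside [min_val, max_val] still get a
# gradient of 1 where a plain `clamp` would give 0.
_x = torch.tensor([-2.0, 0.5, 4.0], requires_grad=True)
_y = clamp_preserve_gradients(_x, 0.0, 1.0)   # forward values: [0.0, 0.5, 1.0]
_y.sum().backward()
assert torch.equal(_x.grad, torch.ones(3))    # plain clamp would zero the ends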

class Normal(TorchNormal):
    """Normal distribution, redefining `log_cdf` and `log_survival_function`
    because no numerically stable implementation of them is available for
    the normal distribution.
    """

    def log_cdf(self, x):
        cdf = clamp_preserve_gradients(self.cdf(x), 1e-7, 1 - 1e-7)
        return cdf.log()

    def log_survival_function(self, x):
        cdf = clamp_preserve_gradients(self.cdf(x), 1e-7, 1 - 1e-7)
        return torch.log(1.0 - cdf)
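
# Why the clamping matters (illustrative snippet): in float32 the Gaussian CDF
# underflows to exactly 0 far in the left tail, so a naive `cdf(x).log()`
# returns -inf, while the clamped version stays finite and differentiable.
_n = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
_pt = torch.tensor(-40.0)
print(_n.cdf(_pt).log())  # -inf: the CDF underflows to 0
print(_n.log_cdf(_pt))    # log(1e-7) ~= -16.12, finite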

class MixtureSameFamily(TorchMixtureSameFamily):
    """Mixture (same-family) distribution, with `log_cdf` and
    `log_survival_function` redefined.
    """

    def log_cdf(self, x):
        x = self._pad(x)
        log_cdf_x = self.component_distribution.log_cdf(x)
        mix_logits = self.mixture_distribution.logits
        return torch.logsumexp(log_cdf_x + mix_logits, dim=-1)

    def log_survival_function(self, x):
        x = self._pad(x)
        log_sf_x = self.component_distribution.log_survival_function(x)
        mix_logits = self.mixture_distribution.logits
        return torch.logsumexp(log_sf_x + mix_logits, dim=-1)
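
# Quick numerical check (illustrative): the logsumexp above implements the
# mixture-CDF identity F(x) = sum_k w_k F_k(x) in log space.
_mix = Categorical(logits=torch.zeros(2))  # two equal-weight components
_comp = Normal(loc=torch.tensor([-1.0, 2.0]), scale=torch.tensor([0.5, 1.0]))
_m = MixtureSameFamily(_mix, _comp)
_pt = torch.tensor(0.3)
_direct = (0.5 * _comp.cdf(_pt)).sum().log()
print(_m.log_cdf(_pt), _direct)  # agree up to the 1e-7 clamping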

class LogNormalMixtureDistribution(TransformedDistribution):
    """Mixture of log-normal distributions.

    Args:
        locs (tensor): [batch_size, seq_len, num_mix_components].
        log_scales (tensor): [batch_size, seq_len, num_mix_components].
        log_weights (tensor): [batch_size, seq_len, num_mix_components].
        mean_log_inter_time (float): average log-inter-event-time.
        std_log_inter_time (float): std of log-inter-event-times.
    """

    def __init__(self, locs, log_scales, log_weights, mean_log_inter_time, std_log_inter_time, validate_args=None):
        mixture_dist = D.Categorical(logits=log_weights)
        component_dist = Normal(loc=locs, scale=log_scales.exp())
        GMM = MixtureSameFamily(mixture_dist, component_dist)

        # Standardize in log space only when the statistics are non-trivial,
        # then map to positive inter-event times via exp.
        if mean_log_inter_time == 0.0 and std_log_inter_time == 1.0:
            transforms = []
        else:
            transforms = [D.AffineTransform(loc=mean_log_inter_time, scale=std_log_inter_time)]
        self.mean_log_inter_time = mean_log_inter_time
        self.std_log_inter_time = std_log_inter_time
        transforms.append(D.ExpTransform())
        self.transforms = transforms

        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        self.sign = int(sign)
        super().__init__(GMM, transforms, validate_args=validate_args)

    def log_cdf(self, x):
        # Pull x back through the inverse transforms, then evaluate the base
        # Gaussian mixture. If the overall transform is decreasing
        # (sign == -1), the CDF and survival function swap roles.
        for transform in self.transforms[::-1]:
            x = transform.inv(x)
        if self._validate_args:
            self.base_dist._validate_sample(x)
        if self.sign == 1:
            return self.base_dist.log_cdf(x)
        else:
            return self.base_dist.log_survival_function(x)

    def log_survival_function(self, x):
        for transform in self.transforms[::-1]:
            x = transform.inv(x)
        if self._validate_args:
            self.base_dist._validate_sample(x)
        if self.sign == 1:
            return self.base_dist.log_survival_function(x)
        else:
            return self.base_dist.log_cdf(x)
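
# Illustrative usage with made-up shapes (batch_size=2, seq_len=3, 4 mixture
# components); parameter values are random and carry no meaning.
_B, _S, _K = 2, 3, 4
_dist = LogNormalMixtureDistribution(
    locs=torch.randn(_B, _S, _K),
    log_scales=torch.randn(_B, _S, _K),
    log_weights=torch.log_softmax(torch.randn(_B, _S, _K), dim=-1),
    mean_log_inter_time=0.0,
    std_log_inter_time=1.0,
)
_tau = torch.rand(_B, _S).clamp(min=1e-5)       # positive inter-event times
print(_dist.log_prob(_tau).shape)               # torch.Size([2, 3])
print(_dist.log_survival_function(_tau).shape)  # torch.Size([2, 3])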

class IntensityFree(TorchBaseModel):
    """Torch implementation of Intensity-Free Learning of Temporal Point Processes,
    ICLR 2020. https://openreview.net/pdf?id=HygOjhEYDH

    Reference: https://github.com/shchur/ifl-tpp
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(IntensityFree, self).__init__(model_config)
        self.num_mix_components = model_config.model_specs['num_mix_components']
        # one feature for the log inter-event time, plus the type embedding
        self.num_features = 1 + self.hidden_size

        self.layer_rnn = nn.GRU(input_size=self.num_features,
                                hidden_size=self.hidden_size,
                                num_layers=1,
                                batch_first=True)
        self.mark_linear = nn.Linear(self.hidden_size, self.num_event_types_pad)
        self.linear = nn.Linear(self.hidden_size, 3 * self.num_mix_components)
    def forward(self, time_delta_seqs, type_seqs):
        """Call the model.

        Args:
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.

        Returns:
            tensor: hidden states, [batch_size, seq_len, hidden_size], states right
                before the event happens.
        """
        # [batch_size, seq_len, 1]
        # We don't normalize the inter-event times here.
        temporal_seqs = torch.log(time_delta_seqs + self.eps).unsqueeze(-1)

        # [batch_size, seq_len, hidden_size]
        type_emb = self.layer_type_emb(type_seqs)

        # [batch_size, seq_len, hidden_size + 1]
        rnn_input = torch.cat([temporal_seqs, type_emb], dim=-1)

        # [batch_size, seq_len, hidden_size]
        context = self.layer_rnn(rnn_input)[0]

        return context
    def loglike_loss(self, batch):
        """Compute the log-likelihood loss.

        Args:
            batch (list): batch input.

        Returns:
            tuple: loglikelihood loss and number of events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        # Empirical mean/std of log-inter-event-times over non-padded positions,
        # used to standardize the base mixture.
        mean_log_inter_time = \
            torch.masked_select(time_delta_seqs[:, 1:], batch_non_pad_mask[:, 1:]).clamp(1e-5).log().mean()
        std_log_inter_time = \
            torch.masked_select(time_delta_seqs[:, 1:], batch_non_pad_mask[:, 1:]).clamp(1e-5).log().std()

        # [batch_size, seq_len, hidden_size]
        context = self.forward(time_delta_seqs[:, 1:], type_seqs[:, :-1])

        # [batch_size, seq_len, 3 * num_mix_components]
        raw_params = self.linear(context)
        locs = raw_params[..., :self.num_mix_components]
        log_scales = raw_params[..., self.num_mix_components: (2 * self.num_mix_components)]
        log_weights = raw_params[..., (2 * self.num_mix_components):]

        log_scales = clamp_preserve_gradients(log_scales, -5.0, 3.0)
        log_weights = torch.log_softmax(log_weights, dim=-1)
        inter_time_dist = LogNormalMixtureDistribution(
            locs=locs,
            log_scales=log_scales,
            log_weights=log_weights,
            mean_log_inter_time=mean_log_inter_time,
            std_log_inter_time=std_log_inter_time
        )

        inter_times = time_delta_seqs[:, 1:].clamp(min=1e-5)
        # [batch_size, seq_len]
        log_p = inter_time_dist.log_prob(inter_times)

        # The survival term for the last event is intentionally disabled here:
        # (batch_size, 1)
        # last_event_idx = batch_non_pad_mask.sum(-1, keepdim=True).long() - 1
        # log_surv_all = inter_time_dist.log_survival_function(inter_times)
        # (batch_size,)
        # log_surv_last = torch.gather(log_surv_all, dim=-1, index=last_event_idx).squeeze(-1)

        # [batch_size, seq_len, num_marks]
        mark_logits = torch.log_softmax(self.mark_linear(context), dim=-1)
        mark_dist = Categorical(logits=mark_logits)
        log_p += mark_dist.log_prob(type_seqs[:, :-1])

        # [batch_size, seq_len]
        log_p *= batch_non_pad_mask[:, 1:]

        # [batch_size]
        loss = -(log_p.sum(-1)).mean()

        num_events = torch.masked_select(batch_non_pad_mask[:, 1:], batch_non_pad_mask[:, 1:]).size()[0]
        return loss, num_events
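
# Minimal sketch of the core of `loglike_loss` on dummy tensors (shapes and
# values are made up for illustration): split the linear head's output into
# the three mixture parameter blocks, build the distribution, and take the
# mean negative log-likelihood of the inter-event times.
_B, _S, _H, _K = 2, 5, 8, 3
_context = torch.randn(_B, _S, _H)              # stand-in for the GRU output
_head = nn.Linear(_H, 3 * _K)
_locs, _log_scales, _log_weights = _head(_context).chunk(3, dim=-1)
_log_scales = clamp_preserve_gradients(_log_scales, -5.0, 3.0)
_log_weights = torch.log_softmax(_log_weights, dim=-1)
_dist = LogNormalMixtureDistribution(_locs, _log_scales, _log_weights, 0.0, 1.0)
_tau = torch.rand(_B, _S).clamp(min=1e-5)
_loss = -_dist.log_prob(_tau).sum(-1).mean()    # NLL of inter-event times only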