import torch
import torch.nn as nn
from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, \
    TimeShiftedPositionalEncoding
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
class SAHP(TorchBaseModel):
    """Torch implementation of Self-Attentive Hawkes Process, ICML 2020.

    Part of the code is collected from https://github.com/yangalan123/anhp-andtt/blob/master/sahp
    I slightly modify the original code because it is not stable.

    Event-type embeddings and time-shifted positional encodings are summed and passed
    through a stack of self-attention encoder layers. Each encoder output ``h`` is
    projected by three learnable matrices ``mu``, ``eta`` and ``gamma``
    (Equations 12-14). The projections are combined with an exponential in the elapsed
    time (Equation 15, see ``state_decay``), and a softplus of the result gives one
    intensity per event type.
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs. This constructor
                reads ``hidden_size``, ``time_emb_size``, ``use_ln``, ``num_layers``,
                ``num_heads`` and ``dropout_rate``.
        """
        super().__init__(model_config)
        self.d_model = model_config.hidden_size
        self.d_time = model_config.time_emb_size
        self.use_norm = model_config.use_ln

        # position vector, used for temporal encoding
        self.layer_position_emb = TimeShiftedPositionalEncoding(d_model=self.d_model,
                                                                device=self.device)

        self.n_layers = model_config.num_layers
        self.n_head = model_config.num_heads
        self.dropout = model_config.dropout_rate

        # Linear map from hidden vectors to one value per event type.
        # NOTE(review): not referenced by any method of this class; left in place so the
        # module's parameter set (and state_dict keys) stays unchanged.
        self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types)
        self.softplus = nn.Softplus()

        self.stack_layers = nn.ModuleList(
            [EncoderLayer(
                self.d_model,
                MultiHeadAttention(self.n_head, self.d_model, self.d_model, self.dropout,
                                   output_linear=False),
                use_residual=False,
                dropout=self.dropout
            ) for _ in range(self.n_layers)])

        if self.use_norm:
            self.norm = nn.LayerNorm(self.d_model)

        # Equations (12)-(14): projections from the encoder output to num_event_types
        # values. In ``state_decay``, h @ eta is the state at zero elapsed time, and
        # h @ mu is the value the state moves towards as time passes (for positive
        # h @ gamma). h @ gamma is the rate of that movement.
        # They must be nn.Parameter so they are registered with the module. A plain
        # tensor attribute is not returned by self.parameters() (never trained), is
        # not saved in the state_dict, and is not moved by .to(device).
        self.mu = nn.Parameter(
            torch.empty([self.d_model, self.num_event_types], device=self.device))
        self.eta = nn.Parameter(
            torch.empty([self.d_model, self.num_event_types], device=self.device))
        self.gamma = nn.Parameter(
            torch.empty([self.d_model, self.num_event_types], device=self.device))
        nn.init.xavier_normal_(self.mu)
        nn.init.xavier_normal_(self.eta)
        nn.init.xavier_normal_(self.gamma)

    def state_decay(self, encode_state, mu, eta, gamma, duration_t):
        """Equation (15), which computes the pre-intensity states.

        Computes ``h @ mu + (h @ eta - h @ mu) * exp(-(h @ gamma) * duration_t)``.
        The state therefore equals ``h @ eta`` at ``duration_t == 0``. All operands are
        combined with ``torch.matmul`` and elementwise broadcasting.

        Args:
            encode_state (tensor): [..., hidden_size], encoder outputs ``h``.
            mu (tensor): [..., hidden_size, num_event_types], with leading dims
                broadcastable against those of ``encode_state`` (e.g. ``self.mu[None, ...]``).
            eta (tensor): same layout as ``mu``.
            gamma (tensor): same layout as ``mu``.
            duration_t (tensor): elapsed times, broadcastable against
                [..., num_event_types] (e.g. [batch_size, seq_len, 1]).

        Returns:
            tensor: [..., num_event_types], the states before the softplus.
        """
        # Compute each projection once; the mu projection is used twice below.
        mu_proj = torch.matmul(encode_state, mu)
        eta_proj = torch.matmul(encode_state, eta)
        # NOTE(review): the gamma projection is not sign-constrained, so for negative
        # entries the exponential grows with duration_t instead of decaying.
        gamma_proj = torch.matmul(encode_state, gamma)
        return mu_proj + (eta_proj - mu_proj) * torch.exp(-gamma_proj * duration_t)

    def forward(self, time_seqs, time_delta_seqs, event_seqs, attention_mask):
        """Call the model.

        Args:
            time_seqs (tensor): [batch_size, seq_len], timestamp seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            event_seqs (tensor): [batch_size, seq_len], event type seqs.
            attention_mask (tensor): [batch_size, seq_len, seq_len], attention masks
                passed unchanged to every encoder layer.

        Returns:
            tensor: [batch_size, seq_len, hidden_size], hidden states at event times.
        """
        type_embedding = self.layer_type_emb(event_seqs)
        position_embedding = self.layer_position_emb(time_seqs, time_delta_seqs)
        enc_output = type_embedding + position_embedding

        for enc_layer in self.stack_layers:
            enc_output = enc_layer(
                enc_output,
                mask=attention_mask)
            if self.use_norm:
                enc_output = self.norm(enc_output)

        # [batch_size, seq_len, hidden_size]
        return enc_output

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (tuple, list): batch input, unpacked as (time_seqs, time_delta_seqs,
                type_seqs, batch_non_pad_mask, attention_mask, type_mask).

        Returns:
            tuple: (negative log-likelihood summed over the batch, number of events).
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch

        # Encode every position except the last. Encoder output i is paired with
        # time_delta_seqs[:, i + 1] everywhere below.
        enc_out = self.forward(time_seqs[:, :-1], time_delta_seqs[:, 1:], type_seqs[:, :-1],
                               attention_mask[:, 1:, :-1])

        # 1. compute event-loglik: intensities at the observed event times
        cell_t = self.state_decay(encode_state=enc_out,
                                  mu=self.mu[None, ...],
                                  eta=self.eta[None, ...],
                                  gamma=self.gamma[None, ...],
                                  duration_t=time_delta_seqs[:, 1:, None])
        # [batch_size, seq_len, num_event_types]
        lambda_at_event = self.softplus(cell_t)

        # 2. compute non-event-loglik (using MC sampling to compute integral)
        # 2.1 sample times
        # [batch_size, seq_len, num_sample]
        sample_dtimes = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # 2.2 compute intensities at sampled times
        # [batch_size, num_times = max_len - 1, num_sample, event_num]
        state_t_sample = self.compute_states_at_sample_times(encode_state=enc_out,
                                                             sample_dtimes=sample_dtimes)
        lambda_t_sample = self.softplus(state_t_sample)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                        lambdas_loss_samples=lambda_t_sample,
                                                                        time_delta_seq=time_delta_seqs[:, 1:],
                                                                        seq_mask=batch_non_pad_mask[:, 1:],
                                                                        lambda_type_mask=type_mask[:, 1:])

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self,
                                       encode_state,
                                       sample_dtimes):
        """Compute the pre-intensity states at sampled times.

        Args:
            encode_state (tensor): [batch_size, seq_len, hidden_size], encoder outputs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled elapsed
                times (a seq_len of 1 broadcasts over all positions).

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], state at each
                sampled time (before the softplus).
        """
        # Insert a sample axis on the encoder output and a type axis on the times so
        # that state_decay broadcasts to [batch, seq, num_samples, num_event_types].
        cell_states = self.state_decay(encode_state[:, :, None, :],
                                       self.mu[None, None, ...],
                                       self.eta[None, None, ...],
                                       self.gamma[None, None, ...],
                                       sample_dtimes[:, :, :, None])
        return cell_states

    def compute_intensities_at_sample_times(self,
                                            time_seqs,
                                            time_delta_seqs,
                                            type_seqs,
                                            sample_dtimes,
                                            **kwargs):
        """Compute intensities at sampled times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], sampled inter-event timestamps.
            **kwargs: ``attention_mask`` (tensor, optional) and
                ``compute_last_step_only`` (bool, default False).

        Returns:
            tensor: [batch_size, seq_len, num_samples, num_event_types], intensity at all
                sampled times. The seq_len dimension is 1 when ``compute_last_step_only``
                is set.
        """
        attention_mask = kwargs.get('attention_mask', None)
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        if attention_mask is None:
            # Boolean mask that is True strictly above the diagonal, i.e. at future
            # positions (assumes the encoder treats True as "masked" — TODO confirm).
            batch_size, seq_len = time_seqs.size()
            attention_mask = torch.triu(torch.ones(seq_len, seq_len, device=self.device), diagonal=1).unsqueeze(0)
            attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

        # [batch_size, seq_len, hidden_size]
        enc_out = self.forward(time_seqs, time_delta_seqs, type_seqs, attention_mask)

        if compute_last_step_only:
            # The decay is computed position by position, so slicing before it gives
            # the same values as slicing its output. Slicing first avoids building the
            # full [batch, seq, num_samples, num_event_types] tensor.
            enc_out = enc_out[:, -1:]
            sample_dtimes = sample_dtimes[:, -1:]

        # [batch_size, seq_len (or 1), num_samples, num_event_types]
        cell_states = self.compute_states_at_sample_times(enc_out, sample_dtimes)
        lambdas = self.softplus(cell_states)
        return lambdas