# Source code for easy_tpp.model.torch_model.torch_ode_tpp

import torch
from torch import nn

from easy_tpp.model.torch_model.torch_baselayer import DNN
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.utils import rk4_step_method


def flatten_parameters(model):
    """Concatenate every parameter of ``model`` into a single 1-D tensor.

    Args:
        model (nn.Module): module whose ``parameters()`` are flattened, in
            iteration order.

    Returns:
        tensor: 1-D tensor whose length is the total number of parameter
        elements; an empty tensor when the model has no parameters.
    """
    flat_parameters = [p.flatten() for p in model.parameters()]
    if not flat_parameters:
        # torch.cat rejects an empty list; a parameter-free model flattens to nothing.
        return torch.empty(0)
    return torch.cat(flat_parameters)


class NeuralODEAdjoint(torch.autograd.Function):
    """Fixed-step Neural ODE solve whose gradients come from the adjoint method.

    ``forward`` integrates ``dz/dt = ode_fn(z)`` from ``z_init`` over ``delta_t``
    using ``num_sample_times`` equal solver steps without recording an autograd
    graph. ``backward`` integrates the augmented system ``[z, a, dL/dtheta]``
    backwards in time with the same solver and step count to recover gradients
    for ``z_init`` and for the parameters of ``ode_fn``.
    """

    def __init__(self, device):
        # The Function is invoked through the static ``apply``; this constructor is
        # kept only so existing ``NeuralODEAdjoint(device)`` calls keep working.
        super(NeuralODEAdjoint, self).__init__()
        self.device = device

    @staticmethod
    def forward(ctx, z_init, delta_t, ode_fn, solver, num_sample_times, *model_parameters):
        """Integrate the ODE forward in time.

        Args:
            ctx: autograd context.
            z_init (tensor): initial state (ODETPP passes [batch_size, hidden_size]).
            delta_t (tensor): total integration time, broadcastable against the
                state (ODETPP passes [batch_size, 1]).
            ode_fn (callable): right-hand side ``f`` of ``dz/dt = f(z)``.
            solver (callable): one-step solver called as
                ``solver(diff_func, dt, z0)``; ``backward`` also calls it with a
                list of tensors as ``z0``.
            num_sample_times (int): number of equal sub-steps covering ``delta_t``.
            *model_parameters (tensor): parameters of ``ode_fn`` to differentiate.

        Returns:
            tensor: the state after integrating over ``delta_t``.
        """
        ctx.ode_fn = ode_fn
        ctx.solver = solver
        ctx.delta_t = delta_t
        ctx.model_parameters = model_parameters
        ctx.num_sample_times = num_sample_times

        # Split the interval into num_sample_times equal solver steps.
        step_dt = delta_t * (1.0 / num_sample_times)
        with torch.no_grad():
            state = z_init
            for _ in range(num_sample_times):
                state = solver(diff_func=ode_fn, dt=step_dt, z0=state)

        # Only the final state is needed: backward reconstructs the trajectory.
        ctx.save_for_backward(state)
        return state

    @staticmethod
    def backward(ctx, grad_z):
        """Propagate ``grad_z`` back through the ODE with the adjoint method.

        Args:
            ctx: autograd context filled in by ``forward``.
            grad_z (tensor): gradient of the loss w.r.t. the forward output.

        Returns:
            tuple: gradient for ``z_init``; ``None`` for ``delta_t``, ``ode_fn``,
            ``solver`` and ``num_sample_times``; then one gradient per model
            parameter, shaped like that parameter.
        """
        output_state = ctx.saved_tensors[0]
        ode_fn = ctx.ode_fn
        solver = ctx.solver
        model_parameters = ctx.model_parameters
        num_sample_times = ctx.num_sample_times

        def aug_dynamics(aug_states):
            """Dynamics of the augmented state [z, a, theta-grads], run backwards in time."""
            tmp_z = aug_states[0]
            tmp_neg_a = -aug_states[1]
            with torch.set_grad_enabled(True):
                tmp_z = tmp_z.detach().requires_grad_(True)
                func_eval = ode_fn(tmp_z)
                # One VJP pass gives -a^T df/dz and -a^T df/dtheta; the theta
                # products are summed over the batch by autograd.
                tmp_ds = torch.autograd.grad(
                    (func_eval,), (tmp_z, *model_parameters),
                    grad_outputs=tmp_neg_a,
                    allow_unused=True,
                    retain_graph=True)
            # allow_unused=True yields None for inputs f does not depend on;
            # their contribution is zero.
            neg_adfdz = torch.zeros_like(tmp_z) if tmp_ds[0] is None else tmp_ds[0]
            neg_adfdtheta = [
                torch.flatten(torch.zeros_like(var) if grad is None else grad)
                for var, grad in zip(model_parameters, tmp_ds[1:])
            ]
            return [func_eval.detach(), neg_adfdz, *neg_adfdtheta]

        step_dt = ctx.delta_t * (1.0 / num_sample_times)
        with torch.no_grad():
            # theta-gradient accumulators start at zero at t_1 (flattened per parameter).
            init_var_grad = [torch.zeros_like(torch.flatten(var)) for var in model_parameters]
            # [z(t_1), a(t_1), theta-grad accumulators]
            states = [output_state, grad_z, *init_var_grad]
            for _ in range(num_sample_times):
                states = solver(aug_dynamics, -step_dt, states)

            grad_z0 = states[1]

            # The accumulators pick up delta_t's leading dims through broadcasting
            # (e.g. [batch_size, numel] for a [batch_size, 1] delta_t); average them
            # away. A scalar delta_t leaves them 1-D, so there is nothing to average.
            # NOTE(review): autograd.grad already sums the theta VJP over the batch and
            # each row is then scaled by its own step; the average equals the exact
            # gradient only when all rows share the same delta_t — confirm intended.
            grad_theta = []
            for var, var_grad in zip(model_parameters, states[2:]):
                if var_grad.dim() > 1:
                    var_grad = var_grad.flatten(0, -2).mean(dim=0)
                grad_theta.append(var_grad.reshape(var.shape))

        return (grad_z0, None, None, None, None, *grad_theta)
class NeuralODE(nn.Module):
    """Module wrapper that evolves a state with ``NeuralODEAdjoint``."""

    def __init__(self, model, solver, num_sample_times, device):
        """Store the ODE right-hand side and solver settings.

        Args:
            model (nn.Module): right-hand side ``f`` of ``dz/dt = f(z)``.
            solver (callable): one-step solver ``solver(diff_func, dt, z0)``.
            num_sample_times (int): solver sub-steps per call.
            device: stored for compatibility; not used by this module.
        """
        super().__init__()
        self.model = model
        self.solver = solver
        # Snapshot kept for backward compatibility; forward() re-reads the model's
        # parameters so it always differentiates the tensors the model currently holds.
        self.params = list(model.parameters())
        self.num_sample_times = num_sample_times
        self.device = device

    def forward(self, input_state, delta_time):
        """Evolve ``input_state`` over ``delta_time``.

        Args:
            input_state (tensor): initial state (ODETPP passes [batch_size, hidden_size]).
            delta_time (tensor): integration time, broadcastable against the state
                (ODETPP passes [batch_size, 1]).

        Returns:
            tensor: the state after ``delta_time`` (final state only, not the
            intermediate sub-steps).
        """
        return NeuralODEAdjoint.apply(input_state,
                                      delta_time,
                                      self.model,
                                      self.solver,
                                      self.num_sample_times,
                                      *self.model.parameters())
class ODETPP(TorchBaseModel):
    """Torch implementation of a TPP with Neural ODE state evolution, which is a simplified version of TPP in
    https://arxiv.org/abs/2011.04583, ICLR 2021

    code reference: https://msurtsukov.github.io/Neural-ODE/;
    https://github.com/liruilong940607/NeuralODE/blob/master/NeuralODE.py
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs; ``model_specs``
                must contain ``'ode_num_sample_per_step'`` and ``'time_factor'``.
        """
        super(ODETPP, self).__init__(model_config)

        # Hidden state -> positive per-type intensities.
        self.layer_intensity = nn.Sequential(
            nn.Linear(self.hidden_size, self.num_event_types),
            nn.Softplus())

        # Right-hand side of the ODE dz/dt = event_model(z).
        self.event_model = DNN(inputs_dim=self.hidden_size,
                               hidden_size=[self.hidden_size])
        self.ode_num_sample_per_step = model_config.model_specs['ode_num_sample_per_step']
        self.time_factor = model_config.model_specs['time_factor']
        self.solver = rk4_step_method
        self.layer_neural_ode = NeuralODE(model=self.event_model,
                                          solver=self.solver,
                                          num_sample_times=self.ode_num_sample_per_step,
                                          device=self.device)

    def forward(self, time_delta_seqs, type_seqs, **kwargs):
        """Run the ODE recurrence over the sequence.

        At step ``i`` the running state is incremented by the embedding of
        ``type_seqs[:, i]`` and then evolved by the ODE for
        ``time_delta_seqs[:, i] / time_factor``.

        Args:
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.

        Returns:
            tuple(tensor, tensor): both [batch_size, seq_len, hidden_size]:
                the evolved state after each step, and that evolved state plus
                the same step's type embedding.
        """
        # [batch_size, seq_len, hidden_size]
        type_seq_emb = self.layer_type_emb(type_seqs)
        # [batch_size, seq_len, 1]
        time_delta_seqs_ = time_delta_seqs[..., None]

        total_state_at_event_minus = []
        total_state_at_event_plus = []
        # zeros_like already follows the embedding's device and dtype.
        last_state = torch.zeros_like(type_seq_emb[:, 0, :])
        for type_emb, dt in zip(torch.unbind(type_seq_emb, dim=-2),
                                torch.unbind(time_delta_seqs_, dim=-2)):
            dt = dt / self.time_factor
            last_state = self.layer_neural_ode(last_state + type_emb, dt)
            total_state_at_event_minus.append(last_state)
            total_state_at_event_plus.append(last_state + type_emb)

        # [batch_size, seq_len, hidden_size]
        state_ti = torch.stack(total_state_at_event_minus, dim=1)
        # [batch_size, seq_len, hidden_size]
        state_to_evolve = torch.stack(total_state_at_event_plus, dim=1)
        return state_ti, state_to_evolve

    def loglike_loss(self, batch):
        """Compute the negative log-likelihood loss.

        Args:
            batch (list): ``(time_seqs, time_delta_seqs, type_seqs,
                batch_non_pad_mask, _, type_mask)``.

        Returns:
            tuple: (loss, num_events), where loss is the summed negative
            log-likelihood of the batch.
        """
        _, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        # The first dtime is zero, so pair each event's type with the gap to
        # the following event. [batch_size, max_len - 1]
        next_time_delta_seqs = time_delta_seqs[:, 1:]
        state_ti, state_ti_plus = self.forward(next_time_delta_seqs, type_seqs[:, :-1])

        # Intensities right before events 1..max_len-1 (the first event has no
        # preceding interval). [batch_size, max_len - 1, num_event_types]
        lambda_at_event = self.layer_intensity(state_ti)

        # Monte-Carlo sample offsets inside each inter-event interval.
        # [batch_size, max_len - 1, num_mc_sample]
        interval_t_sample = self.make_dtime_loss_samples(next_time_delta_seqs)

        # NOTE(review): state_ti_plus[:, i] is the state evolved to the *next*
        # event plus the *current* event's embedding, yet it seeds sampling over
        # the interval that starts at the current event. The state right after
        # the current event would be the ODE input inside forward(). Confirm the
        # intended alignment.
        # [batch_size, max_len - 1, num_mc_sample, hidden_size]
        sample_state_ti = self.compute_states_at_sample_times(state_ti_plus, interval_t_sample)

        # [batch_size, max_len - 1, num_mc_sample, num_event_types]
        lambda_t_sample = self.layer_intensity(sample_state_ti)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                        lambdas_loss_samples=lambda_t_sample,
                                                                        time_delta_seq=next_time_delta_seqs,
                                                                        seq_mask=batch_non_pad_mask[:, 1:],
                                                                        lambda_type_mask=type_mask[:, 1:])

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, state_ti_plus, sample_dtimes):
        """Compute the states at sampling times.

        The ODE is advanced with a single ``self.solver`` step over each sampled
        offset; broadcasting evaluates every (position, sample) pair at once.

        Args:
            state_ti_plus (tensor): [batch_size, seq_len, hidden_size], starting states.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], delta times in sampling.

        Returns:
            tensor: [batch_size, seq_len, num_samples, hidden_size], states at sampling times.
        """
        # NOTE(review): this uses one solver step, while NeuralODE splits its
        # interval into ode_num_sample_per_step sub-steps; accuracy may differ.
        state = self.solver(diff_func=self.event_model,
                            dt=sample_dtimes[..., None],  # [batch_size, seq_len, num_samples, 1]
                            z0=state_ti_plus[..., None, :])  # [batch_size, seq_len, 1, hidden_size]
        return state

    def compute_intensities_at_sample_times(self,
                                            time_seqs,
                                            time_delta_seqs,
                                            type_seqs,
                                            sample_dtimes,
                                            **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs (unused here).
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.
            **kwargs: ``compute_last_step_only`` (bool, default False) restricts the
                result to the last position; also forwarded to ``forward``.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
                intensity at each timestamp for each event type (num_times is 1
                when ``compute_last_step_only`` is set).
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        _, state_ti_plus = self.forward(time_delta_seqs, type_seqs, **kwargs)

        if compute_last_step_only:
            # Only the last position is needed, so solve the ODE for that slice
            # alone instead of every position. Slicing first is equivalent to
            # slicing the result, assuming event_model acts on the last dim only.
            state_ti_plus = state_ti_plus[:, -1:, :]
            sample_dtimes = sample_dtimes[:, -1:, :]

        # [batch_size, num_times, num_mc_sample, hidden_size]
        sample_state_ti = self.compute_states_at_sample_times(state_ti_plus, sample_dtimes)

        # [batch_size, num_times, num_mc_sample, num_event_types]
        return self.layer_intensity(sample_state_ti)