import torch
from torch import nn
from easy_tpp.model.torch_model.torch_baselayer import DNN
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
from easy_tpp.utils import rk4_step_method


def flatten_parameters(model):
    """Concatenate all parameters of ``model`` into a single flat vector."""
    p_shapes = []
    flat_parameters = []
    for p in model.parameters():
        p_shapes.append(p.size())
        flat_parameters.append(p.flatten())
    return torch.cat(flat_parameters)


class NeuralODEAdjoint(torch.autograd.Function):
    """Autograd Function implementing the adjoint method: the forward pass
    integrates the ODE with gradients disabled, and the backward pass recovers
    gradients by solving an augmented ODE backwards in time."""

    def __init__(self, device):
        super(NeuralODEAdjoint, self).__init__()
        self.device = device

    @staticmethod
    def forward(ctx, z_init, delta_t, ode_fn, solver, num_sample_times, *model_parameters):
        """Integrate the ODE forward over one inter-event interval.

        Args:
            ctx: autograd context used to stash objects for the backward pass.
            z_init (tensor): [batch_size, hidden_size], state at the start of the interval.
            delta_t (tensor): [batch_size, 1], length of the interval to integrate over.
            ode_fn (nn.Module): network defining the state derivative dz/dt.
            solver (callable): one-step ODE solver, e.g. an RK4 step.
            num_sample_times (int): number of solver sub-steps within the interval.
            *model_parameters: parameters of ode_fn, listed explicitly so that
                autograd can route gradients to them in backward.

        Returns:
            tensor: [batch_size, hidden_size], state at the end of the interval.
        """
        ctx.ode_fn = ode_fn
        ctx.solver = solver
        ctx.delta_t = delta_t
        ctx.model_parameters = model_parameters
        ctx.num_sample_times = num_sample_times

        total_state = []
        dt_ratio = 1.0 / num_sample_times
        delta_t = delta_t * dt_ratio  # sub-step size
        with torch.no_grad():
            state = z_init
            for i in range(num_sample_times):
                # [batch_size, hidden_size]
                state = solver(diff_func=ode_fn, dt=delta_t, z0=state)
                total_state.append(state)

        # Only the final state is saved and returned; the intermediate states
        # collected in total_state are not used further.
        ctx.save_for_backward(state)
        return state

    @staticmethod
    def backward(ctx, grad_z):
        output_state = ctx.saved_tensors[0]  # saved_tensors is a tuple
        ode_fn = ctx.ode_fn
        solver = ctx.solver
        delta_t = ctx.delta_t
        model_parameters = ctx.model_parameters
        num_sample_times = ctx.num_sample_times

        # Dynamics of the augmented system, integrated backwards in time
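        # Adjoint method (cf. Chen et al., 2018, "Neural Ordinary Differential
        # Equations"): with the adjoint state a(t) = dL/dz(t), the augmented
        # system solved backwards in time is
        #     dz/dt = f(z),    da/dt = -a^T df/dz,
        # and the parameter gradient accumulates -a^T df/dtheta along the way.
        # aug_dynamics below returns exactly these three right-hand sides.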
        def aug_dynamics(aug_states):
            tmp_z = aug_states[0]
            tmp_neg_a = -aug_states[1]
            with torch.set_grad_enabled(True):
                tmp_z = tmp_z.detach().requires_grad_(True)
                func_eval = ode_fn(tmp_z)
                # Vector-Jacobian products -a^T df/dz and -a^T df/dtheta
                tmp_ds = torch.autograd.grad(
                    (func_eval,), (tmp_z, *model_parameters),
                    grad_outputs=tmp_neg_a,
                    allow_unused=True,
                    retain_graph=True)
            neg_adfdz = tmp_ds[0]
            neg_adfdtheta = [torch.flatten(var) for var in tmp_ds[1:]]
            return [func_eval, neg_adfdz, *neg_adfdtheta]

        dt_ratio = 1.0 / num_sample_times
        delta_t = delta_t * dt_ratio

        with torch.no_grad():
            # Construct the backward state for the ODE solver:
            # flatten each parameter tensor \theta so its gradient can be
            # accumulated alongside z(t) and a(t)
            init_var_grad = [torch.zeros_like(torch.flatten(var)) for var in model_parameters]

            # [z(t_1), a(t_1), \theta]
            z1 = output_state
            a1 = grad_z
            states = [z1, a1, *init_var_grad]
            for i in range(num_sample_times):
                # Negative dt: integrate the augmented system backwards in time
                states = solver(aug_dynamics, -delta_t, states)

            grad_z0 = states[1]
            # Average the per-sample parameter gradients over the batch and
            # restore each parameter's original shape
            grad_theta = [torch.reshape(torch.mean(var_grad, dim=0), var.shape) for var, var_grad in
                          zip(model_parameters, states[2:])]
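
        # The returned tuple must align one-to-one with forward's inputs:
        # gradient w.r.t. z_init, None for delta_t, ode_fn, solver and
        # num_sample_times (no gradients are propagated to them), then one
        # gradient per model parameter.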
        return (grad_z0, None, None, None, None, *grad_theta)


class NeuralODE(nn.Module):
    """Thin nn.Module wrapper exposing the adjoint-based ODE solve as a layer."""

    def __init__(self, model, solver, num_sample_times, device):
        super().__init__()
        self.model = model
        self.solver = solver
        self.params = [w for w in model.parameters()]
        self.num_sample_times = num_sample_times
        self.device = device

    def forward(self, input_state, delta_time):
        """Evolve the hidden state over one inter-event interval.

        Args:
            input_state (tensor): [batch_size, hidden_size], state at the start of the interval.
            delta_time (tensor): [batch_size, 1], length of the interval.

        Returns:
            tensor: [batch_size, hidden_size], state at the end of the interval.
        """
        output_state = NeuralODEAdjoint.apply(input_state,
                                              delta_time,
                                              self.model,
                                              self.solver,
                                              self.num_sample_times,
                                              *self.params)
        # [batch_size, hidden_size]
        return output_state


class ODETPP(TorchBaseModel):
    """Torch implementation of a TPP whose hidden state evolves with a Neural ODE.
    It is a simplified version of the model in https://arxiv.org/abs/2011.04583, ICLR 2021.

    Code reference: https://msurtsukov.github.io/Neural-ODE/;
    https://github.com/liruilong940607/NeuralODE/blob/master/NeuralODE.py
    """

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super(ODETPP, self).__init__(model_config)
        self.layer_intensity = nn.Sequential(
            nn.Linear(self.hidden_size, self.num_event_types),
            nn.Softplus())
        self.event_model = DNN(inputs_dim=self.hidden_size,
                               hidden_size=[self.hidden_size])
        self.ode_num_sample_per_step = model_config.model_specs['ode_num_sample_per_step']
        self.time_factor = model_config.model_specs['time_factor']
        self.solver = rk4_step_method
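        # rk4_step_method takes (diff_func, dt, z0) and advances z0 by one solver
        # step of size dt; by its name it presumably implements the classical
        # Runge-Kutta-4 update
        #     k1 = f(z0), k2 = f(z0 + dt/2*k1), k3 = f(z0 + dt/2*k2), k4 = f(z0 + dt*k3),
        #     z1 = z0 + dt/6 * (k1 + 2*k2 + 2*k3 + k4).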
        self.layer_neural_ode = NeuralODE(model=self.event_model,
                                          solver=self.solver,
                                          num_sample_times=self.ode_num_sample_per_step,
                                          device=self.device)

    def forward(self, time_delta_seqs, type_seqs, **kwargs):
        """Call the model.

        Args:
            time_delta_seqs (tensor): [batch_size, seq_len], inter-event time seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.

        Returns:
            tuple: hidden states right before the events and the states used to
            evolve the dynamics within each interval, both of shape
            [batch_size, seq_len, hidden_size].
        """
        # [batch_size, seq_len, hidden_size]
        type_seq_emb = self.layer_type_emb(type_seqs)

        # [batch_size, seq_len, 1]
        time_delta_seqs_ = time_delta_seqs[..., None]

        total_state_at_event_minus = []
        total_state_at_event_plus = []
        last_state = torch.zeros_like(type_seq_emb[:, 0, :], device=self.device)
        for type_emb, dt in zip(torch.unbind(type_seq_emb, dim=-2),
                                torch.unbind(time_delta_seqs_, dim=-2)):
            # Rescale the inter-event time to stabilize the ODE solve
            dt = dt / self.time_factor
            # Add the event embedding (jump) and evolve over the interval
            last_state = self.layer_neural_ode(last_state + type_emb, dt)
            total_state_at_event_minus.append(last_state)
            total_state_at_event_plus.append(last_state + type_emb)
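
        # In symbols, with e_i the embedding of event i, dt_i the scaled
        # inter-event time that follows it, and z_0 = 0, the loop above computes
        #     z_i = ODESolve(z_{i-1} + e_i, dt_i)   -> total_state_at_event_minus
        #     z_i + e_i                             -> total_state_at_event_plus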

        # [batch_size, seq_len, hidden_size]
        state_ti = torch.stack(total_state_at_event_minus, dim=1)

        # [batch_size, seq_len, hidden_size]
        state_to_evolve = torch.stack(total_state_at_event_plus, dim=1)

        return state_ti, state_to_evolve

    def loglike_loss(self, batch):
        """Compute the loglike loss.

        Args:
            batch (list): batch input.

        Returns:
            tuple: loglike loss, num events.
        """
        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _, type_mask = batch

        state_ti, state_ti_plus = self.forward(time_delta_seqs[:, 1:], type_seqs[:, :-1])

        # Num of samples in each batch and num of event time points in the sequence
        batch_size, seq_len, _ = state_ti.size()

        # Lambda(t) right before each event time point
        # lambda_at_event - [batch_size, num_times=max_len-1, num_event_types]
        # Here we drop the last event because it has no delta_time label (it cannot decay)
        lambda_at_event = self.layer_intensity(state_ti)

        # interval_t_sample - [batch_size, num_times=max_len-1, num_mc_sample]
        # for every batch and every event point => do a sampling (num_mc_sampling)
        # the first dtime is zero, so we use time_delta_seqs[:, 1:]
        interval_t_sample = self.make_dtime_loss_samples(time_delta_seqs[:, 1:])

        # [batch_size, num_times=max_len-1, num_mc_sample, hidden_size]
        sample_state_ti = self.compute_states_at_sample_times(state_ti_plus, interval_t_sample)

        # [batch_size, num_times=max_len-1, num_mc_sample, num_event_types]
        lambda_t_sample = self.layer_intensity(sample_state_ti)

        event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
                                                                        lambdas_loss_samples=lambda_t_sample,
                                                                        time_delta_seq=time_delta_seqs[:, 1:],
                                                                        seq_mask=batch_non_pad_mask[:, 1:],
                                                                        lambda_type_mask=type_mask[:, 1:])

        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_states_at_sample_times(self, state_ti_plus, sample_dtimes):
        """Compute the states at sampling times.

        Args:
            state_ti_plus (tensor): [batch_size, seq_len, hidden_size], states right after the events.
            sample_dtimes (tensor): [batch_size, seq_len, num_samples], delta times in sampling.

        Returns:
            tensor: [batch_size, seq_len, num_samples, hidden_size], hidden states at sampling times.
        """
        # Use broadcasting to compute the states at all sample points of all time
        # steps with a single solver step; the adjoint wrapper is not used here,
        # so gradients flow through the solver directly.
        # h_ts shape (batch_size, seq_len, num_samples, hidden_dim)
        state = self.solver(diff_func=self.event_model,
                            dt=sample_dtimes[..., None],  # [batch_size, seq_len, num_samples, 1]
                            z0=state_ti_plus[..., None, :])  # [batch_size, seq_len, 1, hidden_size]
        return state

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
            intensity at each sampled timestamp for each event type.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        # Evolve the hidden states through the observed events
        state_ti, state_ti_plus = self.forward(time_delta_seqs, type_seqs, **kwargs)

        # Num of samples in each batch and num of event time points in the sequence
        batch_size, seq_len, _ = state_ti.size()

        # [batch_size, seq_len, num_mc_sample, hidden_size]
        sample_state_ti = self.compute_states_at_sample_times(state_ti_plus, sample_dtimes)

        if compute_last_step_only:
            # [batch_size, 1, num_mc_sample, num_event_types]
            sampled_intensities = self.layer_intensity(sample_state_ti[:, -1:, :, :])
        else:
            # [batch_size, seq_len, num_mc_sample, num_event_types]
            sampled_intensities = self.layer_intensity(sample_state_ti)
        return sampled_intensities
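

if __name__ == "__main__":
    # A minimal smoke test of the ODE layer in isolation (a sketch with
    # hypothetical sizes; it does not build ODETPP itself, which requires a
    # full EasyTPP ModelConfig).
    hidden_size = 16
    func = DNN(inputs_dim=hidden_size, hidden_size=[hidden_size])
    ode_layer = NeuralODE(model=func,
                          solver=rk4_step_method,
                          num_sample_times=4,
                          device='cpu')

    z0 = torch.zeros(8, hidden_size, requires_grad=True)  # [batch_size, hidden_size]
    dt = torch.rand(8, 1)                                  # [batch_size, 1]

    z1 = ode_layer(z0, dt)   # [batch_size, hidden_size]
    z1.sum().backward()      # exercises the adjoint backward pass
    print(z1.shape, z0.grad.shape)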