easy_tpp.model.torch_model.torch_ode_tpp
Torch implementation of a TPP with Neural ODE state evolution, which is a simplified version of TPP in https://arxiv.org/abs/2011.04583, ICLR 2021.
- class easy_tpp.model.torch_model.torch_ode_tpp.NeuralODEAdjoint(device)
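- static forward(ctx, z_init, delta_t, ode_fn, solver, num_sample_times, *model_parameters)
- Parameters:
ctx – autograd context.
z_init (tensor) – initial state of the ODE.
delta_t (tensor) – [batch_size, num_sample_times], time deltas to integrate over.
ode_fn – function (module) defining the ODE dynamics.
solver – ODE solver step function.
num_sample_times (int) – number of integration steps.
*model_parameters – parameters of ode_fn, passed so that backward() can return their gradients.
- Returns:
the state obtained by integrating ode_fn from z_init over delta_t.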
- static backward(ctx, grad_z)
Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
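For orientation, the adjoint idea behind a class like NeuralODEAdjoint can be sketched as a custom torch.autograd.Function: forward integrates the ODE without recording a graph, and backward integrates the state and the adjoint a(t) = dL/dz(t) backwards in time, accumulating parameter gradients through vector-Jacobian products. This is a minimal sketch assuming a fixed-step explicit Euler solver and a scalar time span; EulerAdjoint and Decay are hypothetical names, not EasyTPP's implementation.

```python
import math

import torch
from torch import nn


class EulerAdjoint(torch.autograd.Function):
    """Hypothetical adjoint-method sketch (not EasyTPP's code).

    Integrates dz/dt = ode_fn(z) over [0, t_total] with explicit Euler steps.
    """

    @staticmethod
    def forward(ctx, z0, t_total, ode_fn, num_steps, *params):
        dt = t_total / num_steps
        z = z0
        # No graph is kept here, so memory does not grow with num_steps.
        with torch.no_grad():
            for _ in range(num_steps):
                z = z + dt * ode_fn(z)
        ctx.ode_fn, ctx.dt, ctx.num_steps = ode_fn, dt, num_steps
        ctx.params = params
        ctx.save_for_backward(z)
        return z

    @staticmethod
    def backward(ctx, grad_z):
        (z,) = ctx.saved_tensors
        dt, params = ctx.dt, ctx.params
        adj = grad_z  # adjoint state dL/dz at t_total
        grad_params = [torch.zeros_like(p) for p in params]
        for _ in range(ctx.num_steps):
            # Grad mode is off inside backward by default; re-enable it for the VJPs.
            with torch.enable_grad():
                z_req = z.detach().requires_grad_(True)
                f = ctx.ode_fn(z_req)
                vjps = torch.autograd.grad(
                    f, (z_req, *params), grad_outputs=adj, allow_unused=True
                )
            # One Euler step backwards in time for state, adjoint and parameter grads.
            z = z - dt * f.detach()
            if vjps[0] is not None:
                adj = adj + dt * vjps[0]
            for i, g in enumerate(vjps[1:]):
                if g is not None:
                    grad_params[i] = grad_params[i] + dt * g
        # One gradient per forward input: z0, t_total, ode_fn, num_steps, *params.
        return (adj, None, None, None, *grad_params)


class Decay(nn.Module):
    """Toy dynamics dz/dt = rate * z with a learnable rate."""

    def __init__(self, rate):
        super().__init__()
        self.rate = nn.Parameter(torch.tensor(rate))

    def forward(self, z):
        return self.rate * z


f = Decay(-0.5)
z0 = torch.tensor([2.0], requires_grad=True)
z_T = EulerAdjoint.apply(z0, 1.0, f, 1000, *f.parameters())
z_T.sum().backward()
# Analytic: z(1) = 2e^-0.5, dz(1)/dz0 = e^-0.5, dz(1)/drate = 2e^-0.5
print(z_T.item(), z0.grad.item(), f.rate.grad.item(), 2 * math.exp(-0.5))
```

The gradients come from an approximate reverse-time integration rather than from backpropagating through stored solver steps, which is the memory/accuracy trade-off the adjoint method makes.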
- class easy_tpp.model.torch_model.torch_ode_tpp.NeuralODE(model, solver, num_sample_times, device)
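The solver argument is a step function that advances the state. As an assumption about its shape, here are two hypothetical fixed-step solvers with a (diff_func, dt, z0) interface, explicit Euler and classic fourth-order Runge-Kutta, plus a helper that splits a time span into equal steps; none of these names or signatures is taken from EasyTPP.

```python
import math

import torch


def euler_step(diff_func, dt, z0):
    """Explicit Euler: z(t + dt) ~ z0 + dt * f(z0)."""
    return z0 + dt * diff_func(z0)


def rk4_step(diff_func, dt, z0):
    """Classic 4th-order Runge-Kutta step for an autonomous ODE dz/dt = f(z)."""
    k1 = diff_func(z0)
    k2 = diff_func(z0 + 0.5 * dt * k1)
    k3 = diff_func(z0 + 0.5 * dt * k2)
    k4 = diff_func(z0 + dt * k3)
    return z0 + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def integrate(solver, diff_func, z0, t_total, num_steps):
    """Split [0, t_total] into num_steps equal steps and apply `solver` to each."""
    dt = t_total / num_steps
    z = z0
    for _ in range(num_steps):
        z = solver(diff_func, dt, z)
    return z


def decay(z):
    # dz/dt = -z has the exact solution z(t) = z0 * exp(-t).
    return -z


z0 = torch.tensor([1.0, 2.0])
print(integrate(euler_step, decay, z0, 1.0, 10))
print(integrate(rk4_step, decay, z0, 1.0, 10))
print(z0 * math.exp(-1.0))
```

With the same number of steps, RK4 is far more accurate than Euler at the cost of four dynamics evaluations per step instead of one.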
- class easy_tpp.model.torch_model.torch_ode_tpp.ODETPP(model_config)
Torch implementation of a TPP with Neural ODE state evolution, which is a simplified version of TPP in https://arxiv.org/abs/2011.04583, ICLR 2021
Code references: https://msurtsukov.github.io/Neural-ODE/; https://github.com/liruilong940607/NeuralODE/blob/master/NeuralODE.py
- __init__(model_config)
Initialize the model.
- Parameters:
model_config (EasyTPP.ModelConfig) – config of model specs.
- forward(time_delta_seqs, type_seqs, **kwargs)
Call the model.
- Parameters:
time_delta_seqs (tensor) – [batch_size, seq_len], inter-event time seqs.
type_seqs (tensor) – [batch_size, seq_len], event type seqs.
- Returns:
hidden states at event times.
- Return type:
tensor
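Conceptually, a model of this kind alternates two moves per event: evolve the hidden state through the ODE over the inter-event time, then apply a jump that injects the event type. The sketch below assumes the jump simply adds a type embedding and uses a small MLP as dynamics with Euler steps; ToyODETPP, hidden_size, num_event_types and num_steps are hypothetical, and the real ODETPP may differ.

```python
import torch
from torch import nn


class ToyODETPP(nn.Module):
    """Hypothetical sketch of ODE-driven hidden states for a TPP (not EasyTPP's ODETPP)."""

    def __init__(self, num_event_types, hidden_size, num_steps=4):
        super().__init__()
        self.type_emb = nn.Embedding(num_event_types, hidden_size)
        # Dynamics f(h) for dh/dt = f(h); a small MLP is an assumption.
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.num_steps = num_steps

    def evolve(self, h, dt):
        """Euler-integrate h over dt ([batch_size, 1]) in num_steps equal steps."""
        step = dt / self.num_steps
        for _ in range(self.num_steps):
            h = h + step * self.dynamics(h)
        return h

    def forward(self, time_delta_seqs, type_seqs):
        # time_delta_seqs, type_seqs: [batch_size, seq_len]
        emb = self.type_emb(type_seqs)  # [batch_size, seq_len, hidden_size]
        h = torch.zeros_like(emb[:, 0])  # initial state
        states_plus = []
        for i in range(type_seqs.size(1)):
            h = self.evolve(h, time_delta_seqs[:, i : i + 1])  # state just before event i
            h = h + emb[:, i]  # jump: add the event-type embedding
            states_plus.append(h)
        return torch.stack(states_plus, dim=1)  # [batch_size, seq_len, hidden_size]


model = ToyODETPP(num_event_types=3, hidden_size=8)
dts = torch.rand(2, 5)
types = torch.randint(0, 3, (2, 5))
print(model(dts, types).shape)  # torch.Size([2, 5, 8])
```

This toy version returns the states right after each event, matching the role of state_ti_plus used further below.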
- loglike_loss(batch)
Compute the log-likelihood loss.
- Parameters:
batch (list) – batch input.
- Returns:
log-likelihood loss and number of events.
- Return type:
list
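Losses of this kind usually follow the standard TPP log-likelihood: the sum of log λ_{k_i}(t_i) over observed events minus the integral of the total intensity Σ_k λ_k(t), with each interval's integral estimated by Monte Carlo as Δt_i times the mean total intensity at sampled times. A minimal sketch under that assumption; mc_loglike_loss, its argument layout, and the masking convention are hypothetical and do not describe how loglike_loss unpacks its batch.

```python
import torch


def mc_loglike_loss(lambda_at_event, lambda_at_samples, type_seqs, time_delta_seqs, mask):
    """Negative TPP log-likelihood with a Monte Carlo integral (hypothetical helper).

    lambda_at_event:   [batch, seq_len, num_types], intensities at event times
    lambda_at_samples: [batch, seq_len, num_samples, num_types], intensities at
                       times sampled inside each inter-event interval
    type_seqs:         [batch, seq_len], observed event types (long)
    time_delta_seqs:   [batch, seq_len], interval lengths
    mask:              [batch, seq_len], 1.0 for real events, 0.0 for padding
    """
    # Intensity of the type that actually occurred: [batch, seq_len]
    lam_obs = lambda_at_event.gather(-1, type_seqs.unsqueeze(-1)).squeeze(-1)
    event_ll = (torch.log(lam_obs + 1e-8) * mask).sum()
    # Mean total intensity over samples, times interval length, approximates the integral.
    mean_total = lambda_at_samples.sum(-1).mean(-1)  # [batch, seq_len]
    non_event_ll = (mean_total * time_delta_seqs * mask).sum()
    num_events = int(mask.sum().item())
    return -(event_ll - non_event_ll), num_events


# Constant intensity 2.0 for a single type over three unit-length intervals.
lam_e = torch.full((1, 3, 1), 2.0)
lam_s = torch.full((1, 3, 5, 1), 2.0)
types = torch.zeros(1, 3, dtype=torch.long)
loss, n = mc_loglike_loss(lam_e, lam_s, types, torch.ones(1, 3), torch.ones(1, 3))
print(loss.item(), n)  # about 3.9206 (= 6 - 3 * ln 2), 3
```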
- compute_states_at_sample_times(state_ti_plus, sample_dtimes)
Compute the states at sampling times.
- Parameters:
state_ti_plus (tensor) – [batch_size, seq_len, hidden_size], states right after the events.
sample_dtimes (tensor) – [batch_size, seq_len, num_samples], delta times in sampling.
- Returns:
hidden states at sampling times.
- Return type:
tensor
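The shapes above imply a broadcast: each post-event state [batch_size, seq_len, hidden_size] is evolved separately by each of the num_samples delta times, giving [batch_size, seq_len, num_samples, hidden_size]. A minimal sketch, assuming fixed-step Euler integration and an arbitrary dynamics function; states_at_sample_times and num_steps are hypothetical.

```python
import torch


def states_at_sample_times(dynamics, state_ti_plus, sample_dtimes, num_steps=20):
    """Evolve each post-event state by every sampled delta time (Euler sketch).

    state_ti_plus: [batch, seq_len, hidden]
    sample_dtimes: [batch, seq_len, num_samples]
    returns:       [batch, seq_len, num_samples, hidden]
    """
    num_samples = sample_dtimes.size(-1)
    # One copy of each state per sample: [batch, seq_len, num_samples, hidden]
    h = state_ti_plus.unsqueeze(2).expand(-1, -1, num_samples, -1)
    step = sample_dtimes.unsqueeze(-1) / num_steps  # [batch, seq_len, num_samples, 1]
    for _ in range(num_steps):
        h = h + step * dynamics(h)
    return h


# With dh/dt = -h the exact result is h0 * exp(-dt).
h0 = torch.ones(2, 3, 4)
dts = torch.rand(2, 3, 5)
out = states_at_sample_times(lambda h: -h, h0, dts)
exact = torch.exp(-dts).unsqueeze(-1).expand_as(out)
print(out.shape)  # torch.Size([2, 3, 5, 4])
print((out - exact).abs().max())
```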
- compute_intensities_at_sample_times(time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs)
Compute the intensities at sampled times, not only at event times.
- Parameters:
time_seqs (tensor) – [batch_size, seq_len], time seqs.
time_delta_seqs (tensor) – [batch_size, seq_len], time delta seqs.
type_seqs (tensor) – [batch_size, seq_len], event type seqs.
sample_dtimes (tensor) – [batch_size, seq_len, num_sample], sampled inter-event timestamps.
- Returns:
[batch_size, num_times, num_mc_sample, num_event_types], intensity at each timestamp for each event type.
- Return type:
tensor
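Once hidden states at the sampled times are available, intensities are commonly obtained with a linear layer followed by softplus so that every λ_k stays positive. The sketch below assumes that kind of head; IntensityHead is hypothetical, and the exact layer ODETPP uses is not stated here.

```python
import torch
from torch import nn


class IntensityHead(nn.Module):
    """Hypothetical head mapping hidden states to per-type intensities."""

    def __init__(self, hidden_size, num_event_types):
        super().__init__()
        self.linear = nn.Linear(hidden_size, num_event_types)
        self.softplus = nn.Softplus()  # keeps intensities positive

    def forward(self, hidden):
        # hidden: [..., hidden_size] -> [..., num_event_types]
        return self.softplus(self.linear(hidden))


head = IntensityHead(hidden_size=8, num_event_types=3)
states = torch.randn(2, 5, 10, 8)  # [batch_size, num_times, num_mc_sample, hidden_size]
lam = head(states)
print(lam.shape)  # torch.Size([2, 5, 10, 3])
```

Because nn.Linear acts on the last dimension, the same head works for states at event times and at sampled times without reshaping.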