easy_tpp.model.torch_model.torch_rmtpp
Classes
RMTPP(model_config) – Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.
- class easy_tpp.model.torch_model.torch_rmtpp.RMTPP(model_config)[source]
Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016. https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf
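For reference, the KDD 2016 paper parametrizes the conditional intensity after the j-th event from the RNN hidden state h_j. The intensity is exponential-linear in the elapsed time, so the log density of the next event time has a closed form. The mark is predicted separately with a softmax over event types. The block below restates the paper's equations with time-specific parameters written plainly as v, w, b. This page does not say how the torch module maps these terms onto its own parameters, so treat the correspondence as an assumption.

```latex
\lambda^*(t) = \exp\!\big(\mathbf{v}^{\top}\mathbf{h}_j + w\,(t - t_j) + b\big), \qquad t > t_j

f^*(t) = \lambda^*(t)\,\exp\!\Big(-\int_{t_j}^{t} \lambda^*(\tau)\,d\tau\Big)

\log f^*(t) = \mathbf{v}^{\top}\mathbf{h}_j + w\,(t - t_j) + b
  + \frac{1}{w}\exp\!\big(\mathbf{v}^{\top}\mathbf{h}_j + b\big)
  - \frac{1}{w}\exp\!\big(\mathbf{v}^{\top}\mathbf{h}_j + w\,(t - t_j) + b\big)

P(y_{j+1} = k \mid \mathbf{h}_j) = \frac{\exp\!\big(\mathbf{V}^{y}_{k,:}\,\mathbf{h}_j + b^{y}_{k}\big)}{\sum_{k'} \exp\!\big(\mathbf{V}^{y}_{k',:}\,\mathbf{h}_j + b^{y}_{k'}\big)}
```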
- __init__(model_config)[source]
Initialize the model.
- Parameters:
model_config (EasyTPP.ModelConfig) – configuration of the model specs.
- forward(time_seqs, time_delta_seqs, type_seqs, **kwargs)[source]
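Call the model.
- Parameters:
time_seqs (tensor) – [batch_size, seq_len], timestamp seqs.
time_delta_seqs (tensor) – [batch_size, seq_len], inter-event time seqs.
type_seqs (tensor) – [batch_size, seq_len], event type seqs.
- Returns:
- hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens;
- stacked decay states, [batch_size, max_seq_length, 4, hidden_dim], states right after the event happens.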
- Return type:
list
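The paper computes the hidden state with a ReLU recurrence over the mark embedding and a timing feature: h_j = max(W^y y_j + W^t t_j + W^h h_{j-1} + b_h, 0). The pure-Python sketch below runs that recurrence over a single sequence. Several details are assumptions, not taken from this page: using the inter-event time as the timing feature, the hypothetical weight names (type_emb, w_time, w_hidden, bias), and a zero initial state. The torch module may use a different recurrent cell; for example, torch.nn.RNN defaults to a tanh nonlinearity.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rnn_hidden_states(
    type_seq: List[int],
    time_delta_seq: List[float],
    type_emb: Matrix,   # hypothetical [num_event_types][hidden_dim] embedding table
    w_time: Vector,     # hypothetical [hidden_dim] weights for the scalar timing feature
    w_hidden: Matrix,   # hypothetical [hidden_dim][hidden_dim] recurrent weights
    bias: Vector,       # hypothetical [hidden_dim] bias
) -> List[Vector]:
    """Return h_1..h_n for h_j = ReLU(emb[y_j] + w_time * dt_j + W_h h_{j-1} + b)."""
    hidden_dim = len(bias)
    h = [0.0] * hidden_dim  # assumed zero initial state h_0
    states = []
    for k, dt in zip(type_seq, time_delta_seq):
        recurrent = matvec(w_hidden, h)
        pre = [type_emb[k][i] + w_time[i] * dt + recurrent[i] + bias[i]
               for i in range(hidden_dim)]
        h = [max(x, 0.0) for x in pre]  # ReLU, as in the paper's recurrence
        states.append(h)
    return states


# Two event types, hidden_dim = 2, identity recurrent weights.
states = rnn_hidden_states(
    type_seq=[0, 1],
    time_delta_seq=[0.0, 2.0],
    type_emb=[[1.0, 0.0], [0.0, 1.0]],
    w_time=[0.5, 0.5],
    w_hidden=[[1.0, 0.0], [0.0, 1.0]],
    bias=[0.0, 0.0],
)
print(states)  # [[1.0, 0.0], [2.0, 2.0]]
```

Each h_j summarizes the history up to and including event j. It is the state the intensity of the next event is conditioned on.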
- loglike_loss(batch)[source]
Compute the log-likelihood loss.
- Parameters:
batch (list) – batch input.
- Returns:
log-likelihood loss and number of events.
- Return type:
tuple
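A minimal sketch of the paper's log-likelihood follows. Each observed event contributes log f*(dt) for its waiting time plus the log softmax probability of its type. The compensator is integrated in closed form. The inputs (past_logits standing for v^T h_j, mark_logits, w, b) are hypothetical stand-ins for quantities the model would compute. The torch implementation may instead use per-type intensities and a numerical estimate of the integral.

```python
import math
from typing import Sequence, Tuple


def log_time_density(past_logit: float, w: float, b: float, dt: float) -> float:
    """log f*(dt) for lambda*(t_j + dt) = exp(past_logit + w * dt + b).

    past_logit plays the role of v^T h_j. The compensator is integrated in
    closed form, which divides by w, so w must be non-zero.
    """
    a = past_logit + b
    return a + w * dt + math.exp(a) / w - math.exp(a + w * dt) / w


def log_softmax_at(logits: Sequence[float], k: int) -> float:
    """Numerically stable log(softmax(logits)[k])."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[k] - log_z


def rmtpp_nll(
    past_logits: Sequence[float],             # v^T h_j for each conditioning event j
    mark_logits: Sequence[Sequence[float]],   # V^y h_j + b^y for each j
    next_dts: Sequence[float],                # t_{j+1} - t_j
    next_types: Sequence[int],                # y_{j+1}
    w: float,
    b: float,
) -> Tuple[float, int]:
    """Return (negative log-likelihood, number of events scored)."""
    loglike = 0.0
    for p, ml, dt, k in zip(past_logits, mark_logits, next_dts, next_types):
        loglike += log_time_density(p, w, b, dt) + log_softmax_at(ml, k)
    return -loglike, len(next_dts)


loss, num_events = rmtpp_nll([0.0], [[0.0, 0.0]], [0.0], [0], w=1.0, b=0.0)
print(loss, num_events)  # loss = log 2 (about 0.693), num_events = 1
```

Returning the event count with the loss mirrors the documented (loss, number of events) pair, so a caller can normalize per event. With w > 0 the time density integrates to one over [0, ∞).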
- compute_states_at_sample_times(decay_states, sample_dtimes)[source]
Compute the hidden states at sampled times.
- Parameters:
decay_states (tensor) – [batch_size, seq_len, hidden_size].
sample_dtimes (tensor) – [batch_size, seq_len, num_samples].
- Returns:
hidden state at each sampled time.
- Return type:
tensor
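In RMTPP, h_j stays fixed between events. Time enters only through the linear term w (t - t_j) inside the exponent, so a state at a sampled offset dt is the event's state plus dt times a time weight. The sketch below shows the shape bookkeeping for the documented inputs, [batch_size, seq_len, hidden_size] and [batch_size, seq_len, num_samples]. It produces [batch_size, seq_len, num_samples, hidden_size]. Treating decay_states as quantities that drift linearly in time, and the hypothetical drift vector, are assumptions about what this method does.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def states_at_sample_times(
    decay_states: Tensor3,   # [batch_size][seq_len][hidden_size]
    sample_dtimes: Tensor3,  # [batch_size][seq_len][num_samples]
    drift: List[float],      # hypothetical [hidden_size] per-dimension time weights
) -> List[Tensor3]:
    """Return [batch_size][seq_len][num_samples][hidden_size] with state + dt * drift."""
    return [
        [
            [[s + dt * d for s, d in zip(state, drift)] for dt in dts]
            for state, dts in zip(seq_states, seq_dts)
        ]
        for seq_states, seq_dts in zip(decay_states, sample_dtimes)
    ]


out = states_at_sample_times([[[1.0, 2.0]]], [[[0.0, 1.0, 2.0]]], drift=[0.5, -1.0])
print(out)  # [[[[1.0, 2.0], [1.5, 1.0], [2.0, 0.0]]]]
```

On tensors the same result is a single broadcast, decay_states[:, :, None, :] + sample_dtimes[..., None] * drift, which avoids the nested loops.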
- compute_intensities_at_sample_times(time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs)[source]
Compute the intensities at sampled times, not only at event times.
- Parameters:
time_seqs (tensor) – [batch_size, seq_len], timestamp seqs.
time_delta_seqs (tensor) – [batch_size, seq_len], inter-event time seqs.
type_seqs (tensor) – [batch_size, seq_len], event type seqs.
sample_times (tensor) – [batch_size, seq_len, num_sample], sampled inter-event times.
- Returns:
[batch_size, seq_len, num_sample, num_event_types], intensity at each sampled time for each event type.
- Return type:
tensor
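Intensities at sampled times, rather than only at event times, are what a thinning sampler or a Monte Carlo estimate of the compensator needs. The sketch below makes two assumptions: each type k has an exponential-linear intensity exp(logit_k + w_k dt + b_k), and type_logits, w, b are hypothetical stand-ins for quantities derived from the hidden states. It evaluates that intensity on a grid of sampled offsets and then uses uniform samples to estimate the integral of the total intensity over one inter-event interval.

```python
import math
import random
from typing import List, Sequence


def intensities_at_sample_times(
    type_logits: Sequence[Sequence[float]],    # hypothetical [seq_len][num_event_types]
    sample_dtimes: Sequence[Sequence[float]],  # [seq_len][num_samples], offsets after each event
    w: Sequence[float],                        # hypothetical [num_event_types] time weights
    b: Sequence[float],                        # hypothetical [num_event_types] biases
) -> List[List[List[float]]]:
    """Return [seq_len][num_samples][num_event_types] with exp(logit_k + w_k * dt + b_k)."""
    return [
        [[math.exp(l + wk * dt + bk) for l, wk, bk in zip(logits, w, b)] for dt in dts]
        for logits, dts in zip(type_logits, sample_dtimes)
    ]


def mc_compensator(
    logits_j: Sequence[float],
    w: Sequence[float],
    b: Sequence[float],
    interval: float,
    num_samples: int,
    rng: random.Random,
) -> float:
    """Monte Carlo estimate of the integral of sum_k lambda_k over [0, interval]."""
    dts = [rng.uniform(0.0, interval) for _ in range(num_samples)]
    lam = intensities_at_sample_times([logits_j], [dts], w, b)[0]  # [num_samples][K]
    return interval * sum(sum(row) for row in lam) / num_samples


# One event type, logit 0, w = 1, b = 0: the exact integral over [0, 1] is e - 1.
estimate = mc_compensator([0.0], [1.0], [0.0], interval=1.0,
                          num_samples=100_000, rng=random.Random(0))
print(estimate, math.e - 1)
```

The exponential-linear form admits a closed-form integral, so the Monte Carlo estimate here serves only as a check. The same sampling approach also applies when no closed form exists.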