easy_tpp.model.torch_model.torch_rmtpp

Classes

RMTPP(model_config)

Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.

class easy_tpp.model.torch_model.torch_rmtpp.RMTPP(model_config)[source]

Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016. https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf

__init__(model_config)[source]

Initialize the model

Parameters:

model_config (EasyTPP.ModelConfig) – config of model specs.

state_decay(states_to_decay, duration_t)[source]

Compute the intensity from the given states after the elapsed duration, following Equation (11) of the RMTPP paper.
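As a reading aid, the intensity in the RMTPP paper is commonly written as λ*(t) = exp(vᵀh_j + w(t − t_j) + b). The sketch below evaluates that form for a single hidden state in plain Python so that it runs without PyTorch. The function name `rmtpp_intensity` and the parameters `v`, `w`, `b` are illustrative assumptions, not attributes of this class, and the library's tensor implementation may differ.

```python
import math


def rmtpp_intensity(hidden, v, w, b, duration_t):
    """Evaluate exp(v . hidden + w * duration_t + b) for one hidden state.

    hidden, v  : lists of floats of equal length (hypothetical names)
    w, b       : scalar time-decay factor and base term (hypothetical names)
    duration_t : time elapsed since the last event
    """
    # Influence of the history, summarised by the hidden state.
    past_influence = sum(v_i * h_i for v_i, h_i in zip(v, hidden))
    return math.exp(past_influence + w * duration_t + b)


# With w < 0 the intensity shrinks as the time since the last event grows.
lam_at_event = rmtpp_intensity([0.5, -0.2], [1.0, 2.0], w=-1.0, b=0.0, duration_t=0.0)
lam_later = rmtpp_intensity([0.5, -0.2], [1.0, 2.0], w=-1.0, b=0.0, duration_t=1.0)
```

Here `duration_t` plays the role of the second argument of `state_decay`; the mapping of `hidden` to `states_to_decay` is assumed from the argument names only.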

forward(time_seqs, time_delta_seqs, type_seqs, **kwargs)[source]

Call the model.

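Parameters:
  • time_seqs (tensor) – [batch_size, seq_len], event time seqs.

  • time_delta_seqs (tensor) – [batch_size, seq_len], time delta seqs.

  • type_seqs (tensor) – [batch_size, seq_len], event type seqs.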

Returns:

hidden states, [batch_size, seq_len, hidden_dim], states right before the event happens;

stacked decay states, [batch_size, max_seq_length, 4, hidden_dim], states right after the event happens.

Return type:

list

loglike_loss(batch)[source]

Compute the loglike loss.

Parameters:

batch (list) – batch input.

Returns:

log-likelihood loss and number of events.

Return type:

tuple
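For orientation, a temporal point process log-likelihood has an event term (log-intensity at each observed event) minus a non-event term (the integral of the total intensity over each interval). The sketch below shows that structure for one sequence in plain Python. It assumes the integral is estimated by averaging intensities at sampled times; this page does not state how `loglike_loss` computes it. The name `loglike_loss_sketch` and its arguments are hypothetical.

```python
import math


def loglike_loss_sketch(event_intensities, sampled_total_intensities, time_deltas):
    """Return (negative log-likelihood, number of events) for one sequence.

    event_intensities         : intensity of the observed type at each event time
    sampled_total_intensities : per interval, the total intensity (summed over
                                event types) at each sampled time in that interval
    time_deltas               : length of each inter-event interval
    """
    # Event term: sum of log-intensities at the observed events.
    event_ll = sum(math.log(lam) for lam in event_intensities)
    # Non-event term: Monte Carlo estimate of the integrated total intensity,
    # i.e. mean sampled intensity times interval length (an assumption here).
    non_event_ll = sum(
        (sum(samples) / len(samples)) * dt
        for samples, dt in zip(sampled_total_intensities, time_deltas)
    )
    return -(event_ll - non_event_ll), len(event_intensities)


loss, num_events = loglike_loss_sketch(
    event_intensities=[1.0, 2.0],
    sampled_total_intensities=[[1.0, 1.0], [2.0, 4.0]],
    time_deltas=[0.5, 1.0],
)
```

Returning the event count alongside the loss lets a caller normalise the loss per event when averaging over batches of different sizes.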

compute_states_at_sample_times(decay_states, sample_dtimes)[source]

Compute the hidden states at sampled times.

Parameters:
  • decay_states (tensor) – [batch_size, seq_len, hidden_size].

  • sample_dtimes (tensor) – [batch_size, seq_len, num_samples].

Returns:

hidden state at each sampled time.

Return type:

tensor

compute_intensities_at_sample_times(time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs)[source]

Compute the intensity at sampled times, not only event times.

Parameters:
  • time_seqs (tensor) – [batch_size, seq_len], event time seqs.

  • time_delta_seqs (tensor) – [batch_size, seq_len], time delta seqs.

  • type_seqs (tensor) – [batch_size, seq_len], event type seqs.

  • sample_times (tensor) – [batch_size, seq_len, num_sample], sampled inter-event timestamps.

Returns:

[batch_size, num_times, num_mc_sample, num_event_types], intensity at each timestamp for each event type.

Return type:

tensor
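To make the `[batch_size, seq_len, num_sample]` shape of the sampled inter-event times concrete, the sketch below builds such a structure from time deltas using nested lists. The helper `make_sample_dtimes` is hypothetical, and evenly spaced offsets are an illustrative choice; how the library actually draws its sample times is not described on this page.

```python
def make_sample_dtimes(time_delta_seqs, num_sample):
    """Build a [batch_size, seq_len, num_sample] nested list of sampled offsets.

    Each inter-event interval of length dt is covered by num_sample evenly
    spaced offsets in [0, dt] (illustrative; random draws would also work).
    """
    if num_sample > 1:
        ratios = [i / (num_sample - 1) for i in range(num_sample)]
    else:
        ratios = [0.0]
    return [[[dt * r for r in ratios] for dt in seq] for seq in time_delta_seqs]


time_delta_seqs = [[0.0, 1.0, 2.0]]  # batch_size=1, seq_len=3
sample_dtimes = make_sample_dtimes(time_delta_seqs, num_sample=3)
```

Each innermost list holds the offsets at which the intensity would be evaluated within one interval, which is why the returned intensities carry an extra sample dimension.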