easy_tpp.model.torch_model.torch_sahp
Classes
SAHP(model_config) – Torch implementation of Self-Attentive Hawkes Process, ICML 2020.
- class easy_tpp.model.torch_model.torch_sahp.SAHP(model_config)
Torch implementation of Self-Attentive Hawkes Process, ICML 2020. Part of the code is adapted from https://github.com/yangalan123/anhp-andtt/blob/master/sahp
I slightly modified the original code because it was not stable.
- __init__(model_config)
Initialize the model.
- Parameters:
model_config (EasyTPP.ModelConfig) – config of model specs.
- state_decay(encode_state, mu, eta, gamma, duration_t)
Compute the pre-intensity states (Equation (15)).
- Parameters:
encode_state (tensor) – [batch_size, seq_len, hidden_size].
mu (tensor) – [batch_size, seq_len, hidden_size].
eta (tensor) – [batch_size, seq_len, hidden_size].
gamma (tensor) – [batch_size, seq_len, hidden_size].
duration_t (tensor) – [batch_size, seq_len, num_samples].
- Returns:
hidden states at event times.
- Return type:
tensor
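The decay interpolates between two states. The formula below is a sketch of that interpolation. It assumes the exponential-decay form from the SAHP paper, applied elementwise over the hidden dimension. Here mu_i, eta_i and gamma_i are the mu, eta and gamma vectors at position i, and Δt is one entry of duration_t. How encode_state enters the computation is not shown, and the exact arithmetic in this method may differ.

```latex
\mathbf{s}_i(\Delta t) = \boldsymbol{\mu}_i + \left(\boldsymbol{\eta}_i - \boldsymbol{\mu}_i\right) \odot \exp\!\left(-\boldsymbol{\gamma}_i \,\Delta t\right), \qquad \Delta t = t - t_i,\ t \in (t_i, t_{i+1}]
```

At Δt = 0 the state equals eta_i. For positive gamma_i, the state relaxes toward mu_i as Δt grows.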
- forward(time_seqs, time_delta_seqs, event_seqs, attention_mask)
Call the model.
- Parameters:
time_seqs (tensor) – [batch_size, seq_len], timestamp seqs.
time_delta_seqs (tensor) – [batch_size, seq_len], inter-event time seqs.
event_seqs (tensor) – [batch_size, seq_len], event type seqs.
attention_mask (tensor) – [batch_size, seq_len, seq_len], attention masks.
- Returns:
hidden states at event times.
- Return type:
tensor
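Self-attention masks over an event sequence are pairwise over positions. They usually combine a causal constraint (no event attends to later events) with a padding constraint. The sketch below builds such a mask under two assumptions: a [batch_size, seq_len, seq_len] shape, and the convention that True marks a blocked query-key pair. The EasyTPP data pipeline may use the opposite convention. build_attention_mask and pad_mask are hypothetical names.

```python
import torch


def build_attention_mask(seq_len: int, pad_mask: torch.Tensor) -> torch.Tensor:
    """Return a [batch_size, seq_len, seq_len] boolean mask (hypothetical helper).

    Assumed convention: True marks a key position the query must NOT attend to.
    pad_mask is [batch_size, seq_len] with True at padded events.
    """
    # Causal part: query i may only attend to keys j <= i, so block j > i.
    causal = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()  # [seq_len, seq_len]
    # Padding part: no query may attend to a padded key.
    key_pad = pad_mask[:, None, :]                                 # [batch_size, 1, seq_len]
    return causal[None, :, :] | key_pad                            # [batch_size, seq_len, seq_len]


pad_mask = torch.tensor([[False, False, True]])  # the last event is padding
mask = build_attention_mask(3, pad_mask)
print(mask.shape)                # torch.Size([1, 3, 3])
print(mask[0].int().tolist())    # [[0, 1, 1], [0, 0, 1], [0, 0, 1]]
```

Keeping the mask boolean makes it cheap to combine the two constraints with `|` before it is turned into additive attention biases.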
- loglike_loss(batch)
Compute the log-likelihood loss.
- Parameters:
batch (tuple, list) – batch input.
- Returns:
log-likelihood loss and number of events.
- Return type:
list
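A temporal point process log-likelihood has two parts. The first is the sum of log intensities at the observed events. The second subtracts the integral of the total intensity over the observed window. A common way to approximate that integral is Monte Carlo sampling inside each inter-event interval, which fits the sampled-time methods on this class.

The sketch below follows that approach and returns a (loss, number of events) pair like the one documented above. tpp_log_likelihood and its argument layout are hypothetical, not the EasyTPP signature. How the library aligns intervals with events may also differ.

```python
import torch


def tpp_log_likelihood(event_lambdas, type_seqs, sample_lambdas, dtimes, seq_mask):
    """Monte Carlo negative log-likelihood of a temporal point process (sketch).

    event_lambdas:  [batch_size, seq_len, num_event_types], intensities at event times
    type_seqs:      [batch_size, seq_len], int64 type of each observed event
    sample_lambdas: [batch_size, seq_len, num_samples, num_event_types],
                    intensities at times sampled inside each interval
    dtimes:         [batch_size, seq_len], length of each interval
    seq_mask:       [batch_size, seq_len], True for real (non-padded) events
    """
    # Event term: log intensity of the type that actually occurred.
    lambda_at_event = event_lambdas.gather(-1, type_seqs.unsqueeze(-1)).squeeze(-1)
    # Small constant guards against log(0).
    event_ll = torch.log(lambda_at_event + 1e-8).masked_fill(~seq_mask, 0.0)

    # Non-event term: integral of the total intensity over each interval,
    # approximated as interval length * mean total intensity over the samples.
    total_lambda = sample_lambdas.sum(dim=-1)  # [batch_size, seq_len, num_samples]
    non_event_ll = (total_lambda.mean(dim=-1) * dtimes).masked_fill(~seq_mask, 0.0)

    num_events = int(seq_mask.sum())
    loss = -(event_ll.sum() - non_event_ll.sum())
    return loss, num_events


B, L, S, K = 2, 4, 10, 3
event_lambdas = torch.full((B, L, K), 0.5)
type_seqs = torch.zeros(B, L, dtype=torch.long)
sample_lambdas = torch.full((B, L, S, K), 0.5)
dtimes = torch.ones(B, L)
seq_mask = torch.ones(B, L, dtype=torch.bool)
loss, n = tpp_log_likelihood(event_lambdas, type_seqs, sample_lambdas, dtimes, seq_mask)
print(n)  # 8
```

With constant intensities of 0.5 per type and unit intervals, each of the 8 events contributes log 0.5 and each interval contributes 1.5 to the integral, so the loss is 12 + 8 log 2.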
- compute_states_at_sample_times(encode_state, sample_dtimes)
Compute the hidden states at sampled times.
- Parameters:
encode_state (tensor) – three tensors, each of shape [batch_size, seq_len, hidden_size].
sample_dtimes (tensor) – [batch_size, seq_len, num_samples].
- Returns:
[batch_size, seq_len, num_samples, hidden_size], hidden state at each sampled time.
- Return type:
tensor
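The [batch_size, seq_len, num_samples, hidden_size] return shape comes from broadcasting. The three [batch_size, seq_len, hidden_size] tensors are combined with the [batch_size, seq_len, num_samples] offsets. The sketch below rests on two assumptions: encode_state is the tuple (mu, eta, gamma), and the decay has the SAHP-style form mu + (eta - mu) * exp(-gamma * dt). states_at_sample_times is a hypothetical stand-in, not the method body.

```python
import torch


def states_at_sample_times(encode_state, sample_dtimes):
    """Decay (mu, eta, gamma), each [B, L, H], to offsets [B, L, S] (sketch).

    Assumption: the decay is mu + (eta - mu) * exp(-gamma * dt), elementwise over H.
    """
    # Add a sample axis to each parameter tensor: [B, L, H] -> [B, L, 1, H].
    mu, eta, gamma = [x.unsqueeze(-2) for x in encode_state]
    # Add a hidden axis to the offsets: [B, L, S] -> [B, L, S, 1].
    dt = sample_dtimes.unsqueeze(-1)
    # Broadcasting [B, L, 1, H] against [B, L, S, 1] gives [B, L, S, H].
    return mu + (eta - mu) * torch.exp(-gamma * dt)


B, L, S, H = 2, 5, 7, 16
encode_state = tuple(torch.rand(B, L, H) for _ in range(3))
sample_dtimes = torch.rand(B, L, S)
states = states_at_sample_times(encode_state, sample_dtimes)
print(states.shape)  # torch.Size([2, 5, 7, 16])
```

Unsqueezing instead of repeating the tensors avoids materializing num_samples copies of mu, eta and gamma.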
- compute_intensities_at_sample_times(time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs)
Compute the intensities at sampled times.
- Parameters:
time_seqs (tensor) – [batch_size, seq_len], timestamp seqs.
time_delta_seqs (tensor) – [batch_size, seq_len], inter-event time seqs.
type_seqs (tensor) – [batch_size, seq_len], event type seqs.
sample_dtimes (tensor) – [batch_size, seq_len, num_samples], sampled inter-event times.
- Returns:
[batch_size, seq_len, num_samples, num_event_types], intensity at all sampled times.
- Return type:
tensor
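The output's last dimension is num_event_types rather than hidden_size. The sampled hidden states must therefore be projected onto event types and passed through a positive activation. The SAHP paper uses a softplus for this. The sketch below assumes a single linear projection followed by softplus. IntensityHead and its layer names are hypothetical, not EasyTPP attributes.

```python
import torch
from torch import nn


class IntensityHead(nn.Module):
    """Hypothetical head mapping hidden states [..., H] to intensities [..., K]."""

    def __init__(self, hidden_size: int, num_event_types: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_event_types)
        # Softplus keeps every intensity positive.
        self.softplus = nn.Softplus()

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        return self.softplus(self.proj(states))


head = IntensityHead(hidden_size=16, num_event_types=3)
states = torch.randn(2, 5, 7, 16)  # [batch_size, seq_len, num_samples, hidden_size]
intensities = head(states)
print(intensities.shape)  # torch.Size([2, 5, 7, 3])
```

nn.Linear acts on the last dimension only, so the same head serves intensities at event times ([batch_size, seq_len, hidden_size]) and at sampled times.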