EasyTPP
GETTING STARTED
Introduction
Installation
Quick Start
USER GUIDE
Dataset
Model Training
Model Prediction
DEVELOPER GUIDE
Model Customization
ADVANCED TOPICS
Thinning Algorithm
Tensorboard
Performance Benchmarks
Implementation Details
API REFERENCE
Config
Preprocess
Model
Runner
Hyper-parameter Optimization
Tf and Torch Wrapper
Utilities
EasyTPP
Index
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
L
|
M
|
N
|
O
|
P
|
R
|
S
|
T
|
U
|
V
_
__init__() (config_factory.BaseConfig method)
(config_factory.DataConfig method)
(config_factory.DataSpecConfig method)
(config_factory.HPOConfig method)
(config_factory.HPORunnerConfig method)
(config_factory.ModelConfig method)
(config_factory.RunnerConfig method)
(easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
(easy_tpp.model.torch_model.torch_baselayer.DNN method)
(easy_tpp.model.torch_model.torch_baselayer.EncoderLayer method)
(easy_tpp.model.torch_model.torch_baselayer.MultiHeadAttention method)
(easy_tpp.model.torch_model.torch_baselayer.SublayerConnection method)
(easy_tpp.model.torch_model.torch_baselayer.TimePositionalEncoding method)
(easy_tpp.model.torch_model.torch_baselayer.TimeShiftedPositionalEncoding method)
(easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel method)
(easy_tpp.model.torch_model.torch_fullynn.CumulHazardFunctionNetwork method)
(easy_tpp.model.torch_model.torch_fullynn.FullyNN method)
(easy_tpp.model.torch_model.torch_intensity_free.IntensityFree method)
(easy_tpp.model.torch_model.torch_intensity_free.LogNormalMixtureDistribution method)
(easy_tpp.model.torch_model.torch_nhp.ContTimeLSTMCell method)
(easy_tpp.model.torch_model.torch_nhp.NHP method)
(easy_tpp.model.torch_model.torch_ode_tpp.NeuralODE method)
(easy_tpp.model.torch_model.torch_ode_tpp.NeuralODEAdjoint method)
(easy_tpp.model.torch_model.torch_ode_tpp.ODETPP method)
(easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
(easy_tpp.model.torch_model.torch_thinning.EventSampler method)
(easy_tpp.model.torch_model.torch_thp.THP method)
(preprocess.EventTokenizer method)
(preprocess.TPPDataLoader method)
(preprocess.TPPDataset method)
(runner.Runner method)
(runner.TPPRunner method)
(utils.MetricsTracker method)
(utils.Timer method)
A
activation_layer() (in module easy_tpp.model.torch_model.torch_baselayer)
array_pad_cols() (in module utils)
AttNHP (class in easy_tpp.model.torch_model.torch_attnhp)
B
backward() (easy_tpp.model.torch_model.torch_ode_tpp.NeuralODEAdjoint static method)
BaseConfig (class in config_factory)
build_from_config() (runner.Runner static method)
build_from_yaml_file() (config_factory.Config static method)
build_input_from_pkl() (preprocess.TPPDataLoader method)
by_name() (utils.Registrable class method)
C
clamp_preserve_gradients() (in module easy_tpp.model.torch_model.torch_intensity_free)
compute_intensities_at_sample_times() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
(easy_tpp.model.torch_model.torch_fullynn.FullyNN method)
(easy_tpp.model.torch_model.torch_nhp.NHP method)
(easy_tpp.model.torch_model.torch_ode_tpp.ODETPP method)
(easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
(easy_tpp.model.torch_model.torch_thp.THP method)
compute_intensity_upper_bound() (easy_tpp.model.torch_model.torch_thinning.EventSampler method)
compute_loglikelihood() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel method)
compute_states_at_sample_times() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
(easy_tpp.model.torch_model.torch_nhp.NHP method)
(easy_tpp.model.torch_model.torch_ode_tpp.ODETPP method)
(easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
(easy_tpp.model.torch_model.torch_thp.THP method)
compute_temporal_embedding() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
concat_element() (in module utils)
Config (class in config_factory)
config_factory
module
ContTimeLSTMCell (class in easy_tpp.model.torch_model.torch_nhp)
copy() (config_factory.BaseConfig method)
(config_factory.Config method)
(config_factory.DataConfig method)
(config_factory.DataSpecConfig method)
(config_factory.HPOConfig method)
(config_factory.HPORunnerConfig method)
(config_factory.ModelConfig method)
(config_factory.RunnerConfig method)
count_model_params() (in module utils)
create_folder() (in module utils)
CumulHazardFunctionNetwork (class in easy_tpp.model.torch_model.torch_fullynn)
D
DataConfig (class in config_factory)
DataSpecConfig (class in config_factory)
decay() (easy_tpp.model.torch_model.torch_nhp.ContTimeLSTMCell method)
DEFAULT_DATASET_ID (utils.DefaultRunnerConfig attribute)
DEFAULT_FORMAT (utils.LogConst attribute)
DEFAULT_FORMAT_LONG (utils.LogConst attribute)
DefaultRunnerConfig (class in utils)
dict_deep_update() (in module utils)
DNN (class in easy_tpp.model.torch_model.torch_baselayer)
DO_NOT_PAD (utils.PaddingStrategy attribute)
DO_NOT_TRUNCATE (utils.TruncationStrategy attribute)
draw_next_time_one_step() (easy_tpp.model.torch_model.torch_thinning.EventSampler method)
E
easy_tpp.model.tf_model
module
easy_tpp.model.torch_model
module
easy_tpp.model.torch_model.torch_attnhp
module
easy_tpp.model.torch_model.torch_baselayer
module
easy_tpp.model.torch_model.torch_basemodel
module
easy_tpp.model.torch_model.torch_fullynn
module
easy_tpp.model.torch_model.torch_intensity_free
module
easy_tpp.model.torch_model.torch_nhp
module
easy_tpp.model.torch_model.torch_ode_tpp
module
easy_tpp.model.torch_model.torch_rmtpp
module
easy_tpp.model.torch_model.torch_sahp
module
easy_tpp.model.torch_model.torch_thinning
module
easy_tpp.model.torch_model.torch_thp
module
EncoderLayer (class in easy_tpp.model.torch_model.torch_baselayer)
end() (utils.Timer method)
ensure_valid_config() (config_factory.RunnerConfig method)
evaluate() (runner.Runner method)
EventSampler (class in easy_tpp.model.torch_model.torch_thinning)
EventTokenizer (class in preprocess)
ExplicitEnum (class in utils)
F
forward() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
(easy_tpp.model.torch_model.torch_baselayer.DNN method)
(easy_tpp.model.torch_model.torch_baselayer.EncoderLayer method)
(easy_tpp.model.torch_model.torch_baselayer.GELU method)
(easy_tpp.model.torch_model.torch_baselayer.Identity method)
(easy_tpp.model.torch_model.torch_baselayer.MultiHeadAttention method)
(easy_tpp.model.torch_model.torch_baselayer.SublayerConnection method)
(easy_tpp.model.torch_model.torch_baselayer.TimePositionalEncoding method)
(easy_tpp.model.torch_model.torch_baselayer.TimeShiftedPositionalEncoding method)
(easy_tpp.model.torch_model.torch_fullynn.CumulHazardFunctionNetwork method)
(easy_tpp.model.torch_model.torch_fullynn.FullyNN method)
(easy_tpp.model.torch_model.torch_intensity_free.IntensityFree method)
(easy_tpp.model.torch_model.torch_nhp.ContTimeLSTMCell method)
(easy_tpp.model.torch_model.torch_nhp.NHP method)
(easy_tpp.model.torch_model.torch_ode_tpp.NeuralODE method)
(easy_tpp.model.torch_model.torch_ode_tpp.NeuralODEAdjoint static method)
(easy_tpp.model.torch_model.torch_ode_tpp.ODETPP method)
(easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
(easy_tpp.model.torch_model.torch_thp.THP method)
forward_pass() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
FullyNN (class in easy_tpp.model.torch_model.torch_fullynn)
G
GELU (class in easy_tpp.model.torch_model.torch_baselayer)
gen() (runner.Runner method)
generate_model_from_config() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel static method)
get() (config_factory.Config method)
get_all_registered_metric() (utils.MetricsHelper static method)
get_config() (runner.Runner method)
get_data_dir() (config_factory.DataConfig method)
get_data_loader() (in module preprocess)
get_loader() (preprocess.TPPDataLoader method)
get_logits_at_last_step() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel static method)
get_metric_direction() (config_factory.RunnerConfig method)
(utils.MetricsHelper static method)
get_metric_function() (utils.MetricsHelper static method)
get_metric_functions() (config_factory.RunnerConfig method)
get_metrics_callback_from_names() (utils.MetricsHelper static method)
get_model_dir() (runner.Runner method)
get_stage() (in module utils)
get_yaml_config() (config_factory.BaseConfig method)
(config_factory.Config method)
(config_factory.DataConfig method)
(config_factory.DataSpecConfig method)
(config_factory.HPOConfig method)
(config_factory.ModelConfig method)
(config_factory.RunnerConfig method)
H
has_key() (in module utils)
HPOConfig (class in config_factory)
HPORunnerConfig (class in config_factory)
I
Identity (class in easy_tpp.model.torch_model.torch_baselayer)
init_dense_layer() (easy_tpp.model.torch_model.torch_nhp.ContTimeLSTMCell method)
init_state() (easy_tpp.model.torch_model.torch_nhp.NHP method)
IntensityFree (class in easy_tpp.model.torch_model.torch_intensity_free)
is_local_master_process() (in module utils)
is_master_process() (in module utils)
is_numpy_array() (in module utils)
is_tensorflow_probability_available() (in module utils)
is_tf_available() (in module utils)
is_tf_gpu_available() (in module utils)
is_torch_available() (in module utils)
is_torch_cuda_available() (in module utils)
is_torch_device() (in module utils)
is_torch_gpu_available() (in module utils)
is_torchvision_available() (in module utils)
L
list_available() (utils.Registrable class method)
load_pickle() (in module utils)
load_yaml_config() (in module utils)
LogConst (class in utils)
loglike_loss() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
(easy_tpp.model.torch_model.torch_fullynn.FullyNN method)
(easy_tpp.model.torch_model.torch_intensity_free.IntensityFree method)
(easy_tpp.model.torch_model.torch_nhp.NHP method)
(easy_tpp.model.torch_model.torch_ode_tpp.ODETPP method)
(easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
(easy_tpp.model.torch_model.torch_thp.THP method)
LogNormalMixtureDistribution (class in easy_tpp.model.torch_model.torch_intensity_free)
LONGEST (utils.PaddingStrategy attribute)
LONGEST_FIRST (utils.TruncationStrategy attribute)
M
make_attn_mask_for_pad_sequence() (preprocess.EventTokenizer method)
make_combined_att_mask() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
make_config_string() (in module utils)
make_dtime_loss_samples() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel method)
make_layer_mask() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
make_pad_sequence() (preprocess.EventTokenizer static method)
make_type_mask_for_pad_sequence() (preprocess.EventTokenizer method)
MAX_LENGTH (utils.PaddingStrategy attribute)
MAXIMIZE (utils.MetricsHelper attribute)
metrics_dict_to_str() (utils.MetricsHelper static method)
MetricsHelper (class in utils)
MetricsTracker (class in utils)
MINIMIZE (utils.MetricsHelper attribute)
MixtureSameFamily (class in easy_tpp.model.torch_model.torch_intensity_free)
model_input_names (preprocess.EventTokenizer attribute)
ModelConfig (class in config_factory)
module
config_factory
easy_tpp.model.tf_model
easy_tpp.model.torch_model
easy_tpp.model.torch_model.torch_attnhp
easy_tpp.model.torch_model.torch_baselayer
easy_tpp.model.torch_model.torch_basemodel
easy_tpp.model.torch_model.torch_fullynn
easy_tpp.model.torch_model.torch_intensity_free
easy_tpp.model.torch_model.torch_nhp
easy_tpp.model.torch_model.torch_ode_tpp
easy_tpp.model.torch_model.torch_rmtpp
easy_tpp.model.torch_model.torch_sahp
easy_tpp.model.torch_model.torch_thinning
easy_tpp.model.torch_model.torch_thp
preprocess
runner
utils
MultiHeadAttention (class in easy_tpp.model.torch_model.torch_baselayer)
N
NeuralODE (class in easy_tpp.model.torch_model.torch_ode_tpp)
NeuralODEAdjoint (class in easy_tpp.model.torch_model.torch_ode_tpp)
NHP (class in easy_tpp.model.torch_model.torch_nhp)
Normal (class in easy_tpp.model.torch_model.torch_intensity_free)
O
ODETPP (class in easy_tpp.model.torch_model.torch_ode_tpp)
P
pad() (preprocess.EventTokenizer method)
padding_side (preprocess.EventTokenizer attribute)
PaddingStrategy (class in utils)
parse_from_yaml_config() (config_factory.BaseConfig static method)
(config_factory.Config static method)
(config_factory.DataConfig static method)
(config_factory.DataSpecConfig static method)
(config_factory.HPOConfig static method)
(config_factory.HPORunnerConfig static method)
(config_factory.ModelConfig static method)
(config_factory.RunnerConfig static method)
parse_uri_to_protocol_and_path() (in module utils)
pop() (config_factory.Config method)
PREDICT (utils.RunnerPhase attribute)
predict_multi_step_since_last_event() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel method)
predict_one_step_at_every_event() (easy_tpp.model.torch_model.torch_basemodel.TorchBaseModel method)
preprocess
module
py_assert() (in module utils)
R
register() (utils.MetricsHelper static method)
(utils.Registrable class method)
Registrable (class in utils)
registry_dict() (utils.Registrable class method)
requires_backends() (in module utils)
resolve_class_name() (utils.Registrable class method)
rk4_step_method() (in module utils)
RMTPP (class in easy_tpp.model.torch_model.torch_rmtpp)
run() (runner.Runner method)
run_one_epoch() (runner.TPPRunner method)
runner
module
Runner (class in runner)
RunnerConfig (class in config_factory)
RunnerPhase (class in utils)
S
SAHP (class in easy_tpp.model.torch_model.torch_sahp)
sample_accept() (easy_tpp.model.torch_model.torch_thinning.EventSampler method)
sample_exp_distribution() (easy_tpp.model.torch_model.torch_thinning.EventSampler method)
sample_uniform_distribution() (easy_tpp.model.torch_model.torch_thinning.EventSampler method)
save() (runner.Runner method)
save_log() (runner.Runner method)
save_pickle() (in module utils)
save_to_yaml_file() (config_factory.Config method)
save_yaml_config() (in module utils)
seq_encoding() (easy_tpp.model.torch_model.torch_attnhp.AttNHP method)
set() (config_factory.Config method)
set_backend() (config_factory.BaseConfig static method)
set_device() (in module utils)
set_model_dir() (runner.Runner method)
set_optimizer() (in module utils)
set_seed() (in module utils)
start() (utils.Timer method)
state_decay() (easy_tpp.model.torch_model.torch_rmtpp.RMTPP method)
(easy_tpp.model.torch_model.torch_sahp.SAHP method)
storage_path (config_factory.HPOConfig property)
storage_protocol (config_factory.HPOConfig property)
SublayerConnection (class in easy_tpp.model.torch_model.torch_baselayer)
T
test_loader() (preprocess.TPPDataLoader method)
THP (class in easy_tpp.model.torch_model.torch_thp)
TimePositionalEncoding (class in easy_tpp.model.torch_model.torch_baselayer)
Timer (class in utils)
TimeShiftedPositionalEncoding (class in easy_tpp.model.torch_model.torch_baselayer)
to_dict() (in module utils)
to_tf_dataset() (preprocess.TPPDataset method)
TorchBaseModel (class in easy_tpp.model.torch_model.torch_basemodel)
TPPDataLoader (class in preprocess)
TPPDataset (class in preprocess)
TPPRunner (class in runner)
TRAIN (utils.RunnerPhase attribute)
train() (runner.Runner method)
train_loader() (preprocess.TPPDataLoader method)
truncation_side (preprocess.EventTokenizer attribute)
TruncationStrategy (class in utils)
U
update() (config_factory.Config method)
update_best() (utils.MetricsTracker method)
update_config() (config_factory.RunnerConfig method)
utils
module
V
valid_loader() (preprocess.TPPDataLoader method)
VALIDATE (utils.RunnerPhase attribute)