Callbacks for TorchForecastingModel
- class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)
Bases: TQDMProgressBar
Darts’ Progress Bar for TorchForecastingModels.
Allows customizing which model stages (sanity checks, training, validation, prediction) display a progress bar.
This class is a PyTorch Lightning Callback and can be passed to the TorchForecastingModel constructor through the "callbacks" entry of the pl_trainer_kwargs parameter.
Examples
>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
- Parameters:
enable_sanity_check_bar (bool) – Whether to enable the progress bar for sanity checks.
enable_train_bar (bool) – Whether to enable the progress bar for training.
enable_validation_bar (bool) – Whether to enable the progress bar for validation.
enable_prediction_bar (bool) – Whether to enable the progress bar for prediction.
enable_train_bar_only (bool) – Whether to disable all progress bars except the bar for training.
**kwargs – Arguments passed to PyTorch Lightning’s TQDMProgressBar.
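The parameters split into Darts' own per-stage switches and extra keyword arguments that are forwarded to Lightning's TQDMProgressBar (for example refresh_rate, which appears among the attributes). The stdlib-only sketch below models how those flags might resolve; ProgressBarFlags, make_flags, and STAGE_FLAGS are hypothetical names, not Darts code, and it assumes that enable_train_bar_only switches the other three bars off while leaving the training bar governed by enable_train_bar.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Names of the stage flags documented for TFMProgressBar.
STAGE_FLAGS = (
    "enable_sanity_check_bar",
    "enable_train_bar",
    "enable_validation_bar",
    "enable_prediction_bar",
    "enable_train_bar_only",
)


@dataclass
class ProgressBarFlags:
    """Hypothetical model of TFMProgressBar's stage flags (not the Darts implementation)."""

    enable_sanity_check_bar: bool = True
    enable_train_bar: bool = True
    enable_validation_bar: bool = True
    enable_prediction_bar: bool = True
    enable_train_bar_only: bool = False
    # Everything else would be forwarded to Lightning's TQDMProgressBar.
    tqdm_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumption: "train bar only" overrides the other stage switches.
        if self.enable_train_bar_only:
            self.enable_sanity_check_bar = False
            self.enable_validation_bar = False
            self.enable_prediction_bar = False

    def enabled_stages(self) -> List[str]:
        """Return the stages that would display a progress bar."""
        stages = {
            "sanity_check": self.enable_sanity_check_bar,
            "train": self.enable_train_bar,
            "validation": self.enable_validation_bar,
            "prediction": self.enable_prediction_bar,
        }
        return [name for name, on in stages.items() if on]


def make_flags(**kwargs: Any) -> ProgressBarFlags:
    """Separate the stage flags from the kwargs destined for TQDMProgressBar."""
    own = {k: v for k, v in kwargs.items() if k in STAGE_FLAGS}
    rest = {k: v for k, v in kwargs.items() if k not in STAGE_FLAGS}
    return ProgressBarFlags(**own, tqdm_kwargs=rest)
```

For instance, make_flags(enable_train_bar_only=True, refresh_rate=10) leaves only the "train" stage enabled and keeps {"refresh_rate": 10} aside for the underlying tqdm bar.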
Attributes
state_key – Identifier for the state of the callback.
total_predict_batches_current_dataloader – The total number of prediction batches for the current dataloader, which may change from epoch to epoch.
total_test_batches_current_dataloader – The total number of testing batches for the current dataloader, which may change from epoch to epoch.
total_train_batches – The total number of training batches, which may change from epoch to epoch.
total_val_batches – The total number of validation batches across all validation dataloaders, which may change from epoch to epoch.
total_val_batches_current_dataloader – The total number of validation batches for the current dataloader, which may change from epoch to epoch.
is_disabled
is_enabled
predict_description
predict_progress_bar
process_position
refresh_rate
sanity_check_description
test_description
test_progress_bar
train_description
train_progress_bar
trainer
val_progress_bar
validation_description
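The state_key attribute names the slot under which a callback's state_dict() is stored in a checkpoint and from which load_state_dict() later restores it; in PyTorch Lightning the default state_key is the callback class's qualified name. The stdlib-only sketch below shows that round trip. RecordingCallback, save_checkpoint, and load_checkpoint are hypothetical stand-ins, and the checkpoint dictionary is a simplified layout rather than Lightning's full checkpoint format.

```python
from typing import Any, Dict, List


class RecordingCallback:
    """Hypothetical callback following the state_key / state_dict protocol."""

    def __init__(self) -> None:
        self.batches_seen = 0

    @property
    def state_key(self) -> str:
        # Mirrors Lightning's default: the class's qualified name.
        return type(self).__qualname__

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


def save_checkpoint(callbacks: List[RecordingCallback]) -> Dict[str, Any]:
    # Simplified layout: each callback's state is stored under its state_key.
    return {"callbacks": {cb.state_key: cb.state_dict() for cb in callbacks}}


def load_checkpoint(checkpoint: Dict[str, Any], callbacks: List[RecordingCallback]) -> None:
    for cb in callbacks:
        state = checkpoint["callbacks"].get(cb.state_key)
        if state is not None:  # callbacks without saved state are left untouched
            cb.load_state_dict(state)
```

Because the key is derived from the class, two instances of the same callback class would collide in this sketch; Lightning callbacks that can appear more than once typically override state_key to include distinguishing arguments.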
Methods
disable() – Disables the progress bar.
enable() – Enables the progress bar.
get_metrics(trainer, pl_module) – Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
init_predict_tqdm() – Override this to customize the tqdm bar for predicting.
init_sanity_tqdm() – Override this to customize the tqdm bar for the validation sanity run.
init_test_tqdm() – Override this to customize the tqdm bar for testing.
init_train_tqdm() – Override this to customize the tqdm bar for training.
init_validation_tqdm() – Override this to customize the tqdm bar for validation.
load_state_dict(state_dict) – Called when loading a checkpoint; implement this to reload the callback state from the given state_dict.
on_after_backward(trainer, pl_module) – Called after loss.backward() and before optimizers are stepped.
on_before_backward(trainer, pl_module, loss) – Called before loss.backward().
on_before_optimizer_step(trainer, pl_module, ...) – Called before optimizer.step().
on_before_zero_grad(trainer, pl_module, ...) – Called before optimizer.zero_grad().
on_exception(trainer, pl_module, exception) – Called when any trainer execution is interrupted by an exception.
on_fit_end(trainer, pl_module) – Called when fit ends.
on_fit_start(trainer, pl_module) – Called when fit begins.
on_load_checkpoint(trainer, pl_module, ...) – Called when loading a model checkpoint; use to reload state.
on_predict_batch_end(trainer, pl_module, ...) – Called when the predict batch ends.
on_predict_batch_start(trainer, pl_module, ...) – Called when the predict batch begins.
on_predict_end(trainer, pl_module) – Called when predict ends.
on_predict_epoch_end(trainer, pl_module) – Called when the predict epoch ends.
on_predict_epoch_start(trainer, pl_module) – Called when the predict epoch begins.
on_predict_start(trainer, pl_module) – Called when predict begins.
on_sanity_check_end(*_) – Called when the validation sanity check ends.
on_sanity_check_start(*_) – Called when the validation sanity check starts.
on_save_checkpoint(trainer, pl_module, ...) – Called when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(trainer, pl_module, ...[, ...]) – Called when the test batch ends.
on_test_batch_start(trainer, pl_module, ...) – Called when the test batch begins.
on_test_end(trainer, pl_module) – Called when testing ends.
on_test_epoch_end(trainer, pl_module) – Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module) – Called when the test epoch begins.
on_test_start(trainer, pl_module) – Called when testing begins.
on_train_batch_end(trainer, pl_module, ...) – Called when the train batch ends.
on_train_batch_start(trainer, pl_module, ...) – Called when the train batch begins.
on_train_end(*_) – Called when training ends.
on_train_epoch_end(trainer, pl_module) – Called when the train epoch ends.
on_train_epoch_start(trainer, *_) – Called when the train epoch begins.
on_train_start(*_) – Called when training begins.
on_validation_batch_end(trainer, pl_module, ...) – Called when the validation batch ends.
on_validation_batch_start(trainer, ...[, ...]) – Called when the validation batch begins.
on_validation_end(trainer, pl_module) – Called when the validation loop ends.
on_validation_epoch_end(trainer, pl_module) – Called when the validation epoch ends.
on_validation_epoch_start(trainer, pl_module) – Called when the validation epoch begins.
on_validation_start(trainer, pl_module) – Called when the validation loop begins.
print(*args[, sep]) – Prints without breaking the progress bar.
setup(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune begins.
state_dict() – Called when saving a checkpoint; implement this to generate the callback's state_dict.
teardown(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune ends.
has_dataloader_changed
reset_dataloader_idx_tracker
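The hook descriptions above imply a nesting: fit wraps the sanity check and training, each training epoch wraps its batches, and within a batch the backward hooks precede the optimizer-step hook. The stdlib-only sketch below records hooks from a toy loop to make that nesting visible. HookRecorder, run_validation, and toy_fit are hypothetical; the loop is a simplified model that omits many hooks and all hook arguments, and it assumes end-of-epoch validation runs before on_train_epoch_end, so it should not be read as Lightning's exact call sequence.

```python
from typing import Callable, List


class HookRecorder:
    """Hypothetical callback that records the name of every on_* hook it receives."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def __getattr__(self, name: str) -> Callable[..., None]:
        # Only invoked for attributes not found normally; treat any on_* name as a hook.
        if name.startswith("on_"):
            return lambda *args, **kwargs: self.calls.append(name)
        raise AttributeError(name)


def run_validation(cb: HookRecorder, batches: int) -> None:
    """One simplified validation loop."""
    cb.on_validation_start()
    cb.on_validation_epoch_start()
    for _ in range(batches):
        cb.on_validation_batch_start()
        cb.on_validation_batch_end()
    cb.on_validation_epoch_end()
    cb.on_validation_end()


def toy_fit(cb: HookRecorder, epochs: int = 1, train_batches: int = 2, val_batches: int = 1) -> None:
    """Simplified fit loop; the real trainer calls more hooks, with arguments."""
    cb.on_fit_start()
    cb.on_sanity_check_start()
    run_validation(cb, 1)  # sanity check: a short validation run
    cb.on_sanity_check_end()
    cb.on_train_start()
    for _ in range(epochs):
        cb.on_train_epoch_start()
        for _ in range(train_batches):
            cb.on_train_batch_start()
            cb.on_before_backward()        # just before loss.backward()
            cb.on_after_backward()         # after backward, before the optimizer steps
            cb.on_before_optimizer_step()  # just before optimizer.step()
            cb.on_train_batch_end()
        run_validation(cb, val_batches)  # end-of-epoch validation
        cb.on_train_epoch_end()
    cb.on_train_end()
    cb.on_fit_end()
```

A progress bar callback uses exactly these boundaries: epoch-start hooks are where a bar would be reset, batch-end hooks where it would advance, and end hooks where it would be closed.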
- init_predict_tqdm()
Override this to customize the tqdm bar for predicting.
- Return type:
Tqdm
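init_predict_tqdm() is a factory hook: a subclass can call the parent implementation and then adjust the bar it returns. The stdlib-only sketch below shows that override pattern. FakeTqdm, BaseProgressBar, and StageAwareBar are hypothetical stand-ins (not Lightning or Darts classes), the description string is illustrative, and the sketch assumes the returned bar exposes a writable disable flag, as tqdm bars do; it does not claim this is how TFMProgressBar itself disables bars.

```python
class FakeTqdm:
    """Hypothetical stand-in for the Tqdm bar, keeping only the attributes used here."""

    def __init__(self, desc: str, leave: bool = True, disable: bool = False) -> None:
        self.desc = desc
        self.leave = leave
        self.disable = disable


class BaseProgressBar:
    """Hypothetical stand-in for a progress bar's tqdm factory method."""

    predict_description = "Predicting"

    def init_predict_tqdm(self) -> FakeTqdm:
        return FakeTqdm(desc=self.predict_description, leave=True)


class StageAwareBar(BaseProgressBar):
    """Override pattern: build the parent's bar, then switch it off per a stage flag."""

    def __init__(self, enable_prediction_bar: bool = True) -> None:
        self.enable_prediction_bar = enable_prediction_bar

    def init_predict_tqdm(self) -> FakeTqdm:
        bar = super().init_predict_tqdm()
        # Keep any disabling the parent already applied; add ours on top.
        bar.disable = bar.disable or not self.enable_prediction_bar
        return bar
```

Reusing super() keeps the parent's description and layout settings, so the override only has to express the one behavior it changes.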