Callbacks for TorchForecastingModel#
- class darts.utils.callbacks.PyTorchLightningPruningCallback(trial, monitor)[source]#
Bases: Callback
PyTorch Lightning callback to prune unpromising Optuna trials.
Reports the monitored metric to the Optuna trial after each validation epoch and raises optuna.TrialPruned when trial.should_prune() returns True.
For distributed (DDP) training, Study must use RDB storage, and check_pruned() must be called manually after Trainer.fit() completes.
- Parameters:
trial – A Trial corresponding to the current evaluation of the objective function.
monitor (str) – An evaluation metric for pruning, e.g., val_loss or val_acc. The metrics are obtained from the returned dictionaries from e.g. lightning.pytorch.LightningModule.training_step or lightning.pytorch.LightningModule.validation_epoch_end and the names thus depend on how this dictionary is formatted.
Examples
>>> import optuna
>>> from darts.models import TCNModel
>>> from darts.utils.callbacks import PyTorchLightningPruningCallback
>>> def objective(trial):
...     pruner = PyTorchLightningPruningCallback(trial, monitor="val_loss")
...     model = TCNModel(..., pl_trainer_kwargs={"callbacks": [pruner]})
...     model.fit(...)
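The report-and-prune flow described for this callback can be sketched without Optuna or Lightning installed. Everything in the block below is a hypothetical stand-in (TrialPruned, StubTrial, SketchPruningCallback, and the fixed-threshold pruning rule are all assumptions made for illustration); it is not the darts or Optuna implementation, only the control flow: read the monitored metric after validation, report it to the trial, and raise if the trial asks to prune.

```python
from types import SimpleNamespace


class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned (hypothetical, for illustration)."""


class StubTrial:
    """Hypothetical trial: prunes once a reported value exceeds a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.reported = {}  # step -> value

    def report(self, value, step):
        self.reported[step] = value

    def should_prune(self):
        # Prune if the most recently reported value is worse than the threshold.
        last_step = max(self.reported)
        return self.reported[last_step] > self.threshold


class SketchPruningCallback:
    """Sketch of a pruning callback; not the darts implementation."""

    def __init__(self, trial, monitor):
        self.trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer, pl_module):
        score = trainer.callback_metrics.get(self.monitor)
        if score is None:
            return  # metric was not logged, nothing to report
        self.trial.report(float(score), step=trainer.current_epoch)
        if self.trial.should_prune():
            raise TrialPruned(f"Trial was pruned at epoch {trainer.current_epoch}.")


def run_fake_training(val_losses, threshold):
    """Drive the callback with made-up losses; return the pruned epoch or None."""
    trial = StubTrial(threshold)
    callback = SketchPruningCallback(trial, monitor="val_loss")
    for epoch, loss in enumerate(val_losses):
        trainer = SimpleNamespace(
            current_epoch=epoch, callback_metrics={"val_loss": loss}
        )
        try:
            callback.on_validation_end(trainer, pl_module=None)
        except TrialPruned:
            return epoch
    return None


pruned_at = run_fake_training([0.9, 0.7, 1.5, 0.4], threshold=1.0)
```

In a real study the pruning decision would come from the pruner configured on the Optuna study rather than from a fixed threshold.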
Attributes
state_key – Identifier for the state of the callback.
Methods
check_pruned() – Raise optuna.TrialPruned manually if pruned.
load_state_dict(state_dict) – Called when loading a checkpoint, implement to reload callback state given callback's state_dict.
on_after_backward(trainer, pl_module) – Called after loss.backward() and before optimizers are stepped.
on_before_backward(trainer, pl_module, loss) – Called before loss.backward().
on_before_optimizer_step(trainer, pl_module, ...) – Called before optimizer.step().
on_before_zero_grad(trainer, pl_module, ...) – Called before optimizer.zero_grad().
on_exception(trainer, pl_module, exception) – Called when any trainer execution is interrupted by an exception.
on_fit_end(trainer, pl_module) – Called when fit ends.
on_fit_start(trainer, pl_module) – Called when fit begins.
on_load_checkpoint(trainer, pl_module, ...) – Called when loading a model checkpoint, use to reload state.
on_predict_batch_end(trainer, pl_module, ...) – Called when the predict batch ends.
on_predict_batch_start(trainer, pl_module, ...) – Called when the predict batch begins.
on_predict_end(trainer, pl_module) – Called when predict ends.
on_predict_epoch_end(trainer, pl_module) – Called when the predict epoch ends.
on_predict_epoch_start(trainer, pl_module) – Called when the predict epoch begins.
on_predict_start(trainer, pl_module) – Called when the predict begins.
on_sanity_check_end(trainer, pl_module) – Called when the validation sanity check ends.
on_sanity_check_start(trainer, pl_module) – Called when the validation sanity check starts.
on_save_checkpoint(trainer, pl_module, ...) – Called when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(trainer, pl_module, ...[, ...]) – Called when the test batch ends.
on_test_batch_start(trainer, pl_module, ...) – Called when the test batch begins.
on_test_end(trainer, pl_module) – Called when the test ends.
on_test_epoch_end(trainer, pl_module) – Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module) – Called when the test epoch begins.
on_test_start(trainer, pl_module) – Called when the test begins.
on_train_batch_end(trainer, pl_module, ...) – Called when the train batch ends.
on_train_batch_start(trainer, pl_module, ...) – Called when the train batch begins.
on_train_end(trainer, pl_module) – Called when the train ends.
on_train_epoch_end(trainer, pl_module) – Called when the train epoch ends.
on_train_epoch_start(trainer, pl_module) – Called when the train epoch begins.
on_train_start(trainer, pl_module) – Called when the train begins.
on_validation_batch_end(trainer, pl_module, ...) – Called when the validation batch ends.
on_validation_batch_start(trainer, ...[, ...]) – Called when the validation batch begins.
on_validation_end(trainer, pl_module) – Called when the validation loop ends.
on_validation_epoch_end(trainer, pl_module) – Called when the val epoch ends.
on_validation_epoch_start(trainer, pl_module) – Called when the val epoch begins.
on_validation_start(trainer, pl_module) – Called when the validation loop begins.
setup(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune begins.
state_dict() – Called when saving a checkpoint, implement to generate callback's state_dict.
teardown(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune ends.
- check_pruned()[source]#
Raise optuna.TrialPruned manually if pruned.
Currently, intermediate_values are not properly propagated between processes due to storage cache. Therefore, necessary information is kept in trial.system_attrs when the trial runs in a distributed situation. Please call this method right after calling lightning.pytorch.Trainer.fit(). If a callback doesn’t have any backend storage for DDP, this method does nothing.
- Return type:
None
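The "record during training, raise after fit" pattern behind check_pruned() can be illustrated with a small stand-alone sketch. The names below (TrialPruned, StubDistributedTrial, PRUNED_KEY, EPOCH_KEY, worker_marks_pruned) are hypothetical, and the attribute keys are invented for the example because the text does not state which keys the library uses. The sketch only assumes what the description says: the pruning decision is stored in trial.system_attrs, and the check does nothing when no decision was stored.

```python
class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned (hypothetical)."""


class StubDistributedTrial:
    """Hypothetical trial whose system_attrs dict plays the role of shared storage."""

    def __init__(self):
        self.system_attrs = {}


# Hypothetical key names; the real keys are not given in the text.
PRUNED_KEY = "sketch:pruned"
EPOCH_KEY = "sketch:epoch"


def worker_marks_pruned(trial, epoch):
    """Record a pruning decision instead of raising inside the training loop."""
    trial.system_attrs[PRUNED_KEY] = True
    trial.system_attrs[EPOCH_KEY] = epoch


def check_pruned(trial):
    """Raise TrialPruned if a pruning decision was recorded; otherwise do nothing."""
    if trial.system_attrs.get(PRUNED_KEY):
        epoch = trial.system_attrs[EPOCH_KEY]
        raise TrialPruned(f"Trial was pruned at epoch {epoch}.")


def objective(trial, prune_epoch=None):
    # Trainer.fit() would run here; a worker may record a pruning decision.
    if prune_epoch is not None:
        worker_marks_pruned(trial, prune_epoch)
    check_pruned(trial)  # call right after fit() completes
    return 0.0  # placeholder objective value
```

Calling objective(StubDistributedTrial()) returns normally, while objective(StubDistributedTrial(), prune_epoch=3) raises the stand-in TrialPruned after the simulated fit.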
- load_state_dict(state_dict)#
Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.
- Parameters:
state_dict (dict[str, Any]) – the callback state returned by state_dict.
- Return type:
None
- on_after_backward(trainer, pl_module)#
Called after loss.backward() and before optimizers are stepped.
- Return type:
None
- on_before_backward(trainer, pl_module, loss)#
Called before loss.backward().
- Return type:
None
- on_before_optimizer_step(trainer, pl_module, optimizer)#
Called before optimizer.step().
- Return type:
None
- on_before_zero_grad(trainer, pl_module, optimizer)#
Called before optimizer.zero_grad().
- Return type:
None
- on_exception(trainer, pl_module, exception)#
Called when any trainer execution is interrupted by an exception.
- Return type:
None
- on_fit_end(trainer, pl_module)#
Called when fit ends.
- Return type:
None
- on_load_checkpoint(trainer, pl_module, checkpoint)#
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
- Return type:
None
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#
Called when the predict batch ends.
- Return type:
None
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#
Called when the predict batch begins.
- Return type:
None
- on_predict_end(trainer, pl_module)#
Called when predict ends.
- Return type:
None
- on_predict_epoch_end(trainer, pl_module)#
Called when the predict epoch ends.
- Return type:
None
- on_predict_epoch_start(trainer, pl_module)#
Called when the predict epoch begins.
- Return type:
None
- on_predict_start(trainer, pl_module)#
Called when the predict begins.
- Return type:
None
- on_sanity_check_end(trainer, pl_module)#
Called when the validation sanity check ends.
- Return type:
None
- on_sanity_check_start(trainer, pl_module)#
Called when the validation sanity check starts.
- Return type:
None
- on_save_checkpoint(trainer, pl_module, checkpoint)#
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type:
None
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#
Called when the test batch ends.
- Return type:
None
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#
Called when the test batch begins.
- Return type:
None
- on_test_end(trainer, pl_module)#
Called when the test ends.
- Return type:
None
- on_test_epoch_end(trainer, pl_module)#
Called when the test epoch ends.
- Return type:
None
- on_test_epoch_start(trainer, pl_module)#
Called when the test epoch begins.
- Return type:
None
- on_test_start(trainer, pl_module)#
Called when the test begins.
- Return type:
None
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#
Called when the train batch ends.
- Return type:
None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
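A short arithmetic sketch of the note on outputs["loss"], assuming that "normalized w.r.t. accumulate_grad_batches" means the loss returned from training_step is divided by the accumulation factor. The helper names normalized_loss and recover_step_loss are hypothetical and exist only for this illustration.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Value expected in outputs["loss"], assuming division by the accumulation factor."""
    return step_loss / accumulate_grad_batches


def recover_step_loss(outputs_loss: float, accumulate_grad_batches: int) -> float:
    """Undo the normalization to get back the loss returned from training_step."""
    return outputs_loss * accumulate_grad_batches


# With a training_step loss of 2.0 and accumulate_grad_batches=4,
# the hook would see 0.5 under the stated assumption.
seen_in_hook = normalized_loss(2.0, accumulate_grad_batches=4)
original = recover_step_loss(seen_in_hook, accumulate_grad_batches=4)
```

A callback that logs the per-batch loss from this hook would therefore need to rescale it when gradient accumulation is enabled.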
- on_train_batch_start(trainer, pl_module, batch, batch_idx)#
Called when the train batch begins.
- Return type:
None
- on_train_end(trainer, pl_module)#
Called when the train ends.
- Return type:
None
- on_train_epoch_end(trainer, pl_module)#
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:
    import lightning as L
    import torch

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            # cache for the per-step losses of the current epoch
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss

    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- Return type:
None
- on_train_epoch_start(trainer, pl_module)#
Called when the train epoch begins.
- Return type:
None
- on_train_start(trainer, pl_module)#
Called when the train begins.
- Return type:
None
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#
Called when the validation batch ends.
- Return type:
None
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#
Called when the validation batch begins.
- Return type:
None
- on_validation_end(trainer, pl_module)[source]#
Called when the validation loop ends.
- Return type:
None
- on_validation_epoch_end(trainer, pl_module)#
Called when the val epoch ends.
- Return type:
None
- on_validation_epoch_start(trainer, pl_module)#
Called when the val epoch begins.
- Return type:
None
- on_validation_start(trainer, pl_module)#
Called when the validation loop begins.
- Return type:
None
- setup(trainer, pl_module, stage)#
Called when fit, validate, test, predict, or tune begins.
- Return type:
None
- state_dict()#
Called when saving a checkpoint, implement to generate callback’s state_dict.
- Return type:
dict[str, Any]
- Returns:
A dictionary containing callback state.
- property state_key: str#
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
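The interplay of state_key, state_dict() and load_state_dict() can be sketched with plain Python objects. CounterCallback, save_checkpoint and load_checkpoint are hypothetical names invented for this example, and the key format is an assumption; the sketch only mirrors the contract described above, where each instance's state lives under checkpoint["callbacks"][state_key] and two instances of the same class need distinct keys.

```python
from typing import Any


class CounterCallback:
    """Hypothetical stateful callback following the state_dict contract."""

    def __init__(self, name: str):
        self.name = name
        self.batches_seen = 0

    @property
    def state_key(self) -> str:
        # Include the instance name so two instances do not share one key.
        return f"{type(self).__name__}[name={self.name}]"

    def state_dict(self) -> dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


def save_checkpoint(callbacks) -> dict[str, Any]:
    """Collect every callback's state under its own state_key."""
    return {"callbacks": {cb.state_key: cb.state_dict() for cb in callbacks}}


def load_checkpoint(callbacks, checkpoint: dict[str, Any]) -> None:
    """Hand each callback the state stored under its state_key, if any."""
    for cb in callbacks:
        state = checkpoint["callbacks"].get(cb.state_key)
        if state is not None:
            cb.load_state_dict(state)


train_counter = CounterCallback("train")
val_counter = CounterCallback("val")
train_counter.batches_seen = 120
val_counter.batches_seen = 30

checkpoint = save_checkpoint([train_counter, val_counter])

restored = [CounterCallback("train"), CounterCallback("val")]
load_checkpoint(restored, checkpoint)
```

If both instances returned the same state_key, the second entry would overwrite the first in the checkpoint dictionary, which is why a unique key matters when several instances carry state.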
- teardown(trainer, pl_module, stage)#
Called when fit, validate, test, predict, or tune ends.
- Return type:
None
- class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)[source]#
Bases: TQDMProgressBar
Darts’ Progress Bar for TorchForecastingModels.
Lets you choose the model stages (sanity checks, training, validation, prediction) for which a progress bar is displayed.
This class is a PyTorch Lightning Callback and can be passed to the TorchForecastingModel constructor through the pl_trainer_kwargs parameter.
Examples
>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
- Parameters:
enable_sanity_check_bar (bool) – Whether to enable the progress bar for sanity checks.
enable_train_bar (bool) – Whether to enable the progress bar for training.
enable_validation_bar (bool) – Whether to enable the progress bar for validation.
enable_prediction_bar (bool) – Whether to enable the progress bar for prediction.
enable_train_bar_only (bool) – Whether to disable all progress bars except the bar for training.
**kwargs – Arguments passed to PyTorch Lightning’s TQDMProgressBar.
Attributes
state_key – Identifier for the state of the callback.
total_predict_batches_current_dataloader – The total number of prediction batches, which may change from epoch to epoch for current dataloader.
total_test_batches_current_dataloader – The total number of testing batches, which may change from epoch to epoch for current dataloader.
total_train_batches – The total number of training batches, which may change from epoch to epoch.
total_val_batches – The total number of validation batches, which may change from epoch to epoch for all val dataloaders.
total_val_batches_current_dataloader – The total number of validation batches, which may change from epoch to epoch for current dataloader.
is_disabled
is_enabled
predict_description
predict_progress_bar
process_position
refresh_rate
sanity_check_description
test_description
test_progress_bar
train_description
train_progress_bar
trainer
val_progress_bar
validation_description
Methods
disable() – You should provide a way to disable the progress bar.
enable() – You should provide a way to enable the progress bar.
get_metrics(trainer, pl_module) – Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
init_predict_tqdm() – Override this to customize the tqdm bar for predicting.
init_sanity_tqdm() – Override this to customize the tqdm bar for the validation sanity run.
init_test_tqdm() – Override this to customize the tqdm bar for testing.
init_train_tqdm() – Override this to customize the tqdm bar for training.
init_validation_tqdm() – Override this to customize the tqdm bar for validation.
load_state_dict(state_dict) – Called when loading a checkpoint, implement to reload callback state given callback's state_dict.
on_after_backward(trainer, pl_module) – Called after loss.backward() and before optimizers are stepped.
on_before_backward(trainer, pl_module, loss) – Called before loss.backward().
on_before_optimizer_step(trainer, pl_module, ...) – Called before optimizer.step().
on_before_zero_grad(trainer, pl_module, ...) – Called before optimizer.zero_grad().
on_exception(trainer, pl_module, exception) – Called when any trainer execution is interrupted by an exception.
on_fit_end(trainer, pl_module) – Called when fit ends.
on_fit_start(trainer, pl_module) – Called when fit begins.
on_load_checkpoint(trainer, pl_module, ...) – Called when loading a model checkpoint, use to reload state.
on_predict_batch_end(trainer, pl_module, ...) – Called when the predict batch ends.
on_predict_batch_start(trainer, pl_module, ...) – Called when the predict batch begins.
on_predict_end(trainer, pl_module) – Called when predict ends.
on_predict_epoch_end(trainer, pl_module) – Called when the predict epoch ends.
on_predict_epoch_start(trainer, pl_module) – Called when the predict epoch begins.
on_predict_start(trainer, pl_module) – Called when the predict begins.
on_sanity_check_end(*_) – Called when the validation sanity check ends.
on_sanity_check_start(*_) – Called when the validation sanity check starts.
on_save_checkpoint(trainer, pl_module, ...) – Called when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(trainer, pl_module, ...[, ...]) – Called when the test batch ends.
on_test_batch_start(trainer, pl_module, ...) – Called when the test batch begins.
on_test_end(trainer, pl_module) – Called when the test ends.
on_test_epoch_end(trainer, pl_module) – Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module) – Called when the test epoch begins.
on_test_start(trainer, pl_module) – Called when the test begins.
on_train_batch_end(trainer, pl_module, ...) – Called when the train batch ends.
on_train_batch_start(trainer, pl_module, ...) – Called when the train batch begins.
on_train_end(*_) – Called when the train ends.
on_train_epoch_end(trainer, pl_module) – Called when the train epoch ends.
on_train_epoch_start(trainer, *_) – Called when the train epoch begins.
on_train_start(*_) – Called when the train begins.
on_validation_batch_end(trainer, pl_module, ...) – Called when the validation batch ends.
on_validation_batch_start(trainer, ...[, ...]) – Called when the validation batch begins.
on_validation_end(trainer, pl_module) – Called when the validation loop ends.
on_validation_epoch_end(trainer, pl_module) – Called when the val epoch ends.
on_validation_epoch_start(trainer, pl_module) – Called when the val epoch begins.
on_validation_start(trainer, pl_module) – Called when the validation loop begins.
print(*args[, sep]) – You should provide a way to print without breaking the progress bar.
setup(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune begins.
state_dict() – Called when saving a checkpoint, implement to generate callback's state_dict.
teardown(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune ends.
has_dataloader_changed
reset_dataloader_idx_tracker
- init_predict_tqdm()[source]#
Override this to customize the tqdm bar for predicting.
- Return type:
Tqdm