- class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)
Bases: TQDMProgressBar
Darts’ Progress Bar for TorchForecastingModels.
Lets you choose for which model stages (sanity checks, training, validation, prediction) a progress bar is displayed.
This class is a PyTorch Lightning Callback and can be passed to the TorchForecastingModel constructor through the pl_trainer_kwargs parameter.
Examples
>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(input_chunk_length=1, output_chunk_length=1, pl_trainer_kwargs={"callbacks": [prog_bar]})
- Parameters
enable_sanity_check_bar (bool) – Whether to enable the progress bar for sanity checks.
enable_train_bar (bool) – Whether to enable the progress bar for training.
enable_validation_bar (bool) – Whether to enable the progress bar for validation.
enable_prediction_bar (bool) – Whether to enable the progress bar for prediction.
enable_train_bar_only (bool) – Whether to disable all progress bars except the bar for training.
**kwargs – Keyword arguments passed to PyTorch Lightning’s TQDMProgressBar.
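The way the flags combine can be summarized in a small, self-contained sketch. `resolve_bar_visibility` is a hypothetical helper, not part of darts, and it assumes that `enable_train_bar_only=True` simply switches off the sanity check, validation, and prediction bars while leaving the training bar controlled by `enable_train_bar`; check the darts source if the exact precedence matters for your use case.

```python
from typing import Dict


def resolve_bar_visibility(
    enable_sanity_check_bar: bool = True,
    enable_train_bar: bool = True,
    enable_validation_bar: bool = True,
    enable_prediction_bar: bool = True,
    enable_train_bar_only: bool = False,
) -> Dict[str, bool]:
    """Hypothetical helper: which stages would show a progress bar for these flags."""
    if enable_train_bar_only:
        # Assumption: "train bar only" turns off every non-training bar and
        # leaves the training bar governed by enable_train_bar.
        enable_sanity_check_bar = False
        enable_validation_bar = False
        enable_prediction_bar = False
    return {
        "sanity_check": enable_sanity_check_bar,
        "train": enable_train_bar,
        "validation": enable_validation_bar,
        "prediction": enable_prediction_bar,
    }


print(resolve_bar_visibility(enable_train_bar_only=True))
# {'sanity_check': False, 'train': True, 'validation': False, 'prediction': False}
print(resolve_bar_visibility(enable_validation_bar=False))
# {'sanity_check': True, 'train': True, 'validation': False, 'prediction': True}
```

With the real class, the second call corresponds to `TFMProgressBar(enable_validation_bar=False)`, passed to a model through `pl_trainer_kwargs={"callbacks": [...]}` in the same way as in the example above.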
Attributes
is_disabled
is_enabled
predict_description
predict_progress_bar
process_position
refresh_rate
sanity_check_description
state_key – Identifier for the state of the callback.
test_description
test_progress_bar
total_predict_batches_current_dataloader – The total number of prediction batches for the current dataloader, which may change from epoch to epoch.
total_test_batches_current_dataloader – The total number of testing batches for the current dataloader, which may change from epoch to epoch.
total_train_batches – The total number of training batches, which may change from epoch to epoch.
total_val_batches – The total number of validation batches across all validation dataloaders, which may change from epoch to epoch.
total_val_batches_current_dataloader – The total number of validation batches for the current dataloader, which may change from epoch to epoch.
train_description
train_progress_bar
trainer
val_progress_bar
validation_description
Methods
disable() – Disable the progress bar.
enable() – Enable the progress bar.
get_metrics(trainer, pl_module) – Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
has_dataloader_changed(dataloader_idx)
init_predict_tqdm() – Override this to customize the tqdm bar for predicting.
init_sanity_tqdm() – Override this to customize the tqdm bar for the validation sanity run.
init_test_tqdm() – Override this to customize the tqdm bar for testing.
init_train_tqdm() – Override this to customize the tqdm bar for training.
init_validation_tqdm() – Override this to customize the tqdm bar for validation.
load_state_dict(state_dict) – Called when loading a checkpoint; implement to reload callback state given the callback's state_dict.
on_after_backward(trainer, pl_module) – Called after loss.backward() and before optimizers are stepped.
on_before_backward(trainer, pl_module, loss) – Called before loss.backward().
on_before_optimizer_step(trainer, pl_module, ...) – Called before optimizer.step().
on_before_zero_grad(trainer, pl_module, ...) – Called before optimizer.zero_grad().
on_exception(trainer, pl_module, exception) – Called when any trainer execution is interrupted by an exception.
on_fit_end(trainer, pl_module) – Called when fit ends.
on_fit_start(trainer, pl_module) – Called when fit begins.
on_load_checkpoint(trainer, pl_module, ...) – Called when loading a model checkpoint; use to reload state.
on_predict_batch_end(trainer, pl_module, ...) – Called when the predict batch ends.
on_predict_batch_start(trainer, pl_module, ...) – Called when the predict batch begins.
on_predict_end(trainer, pl_module) – Called when prediction ends.
on_predict_epoch_end(trainer, pl_module) – Called when the predict epoch ends.
on_predict_epoch_start(trainer, pl_module) – Called when the predict epoch begins.
on_predict_start(trainer, pl_module) – Called when prediction begins.
on_sanity_check_end(*_) – Called when the validation sanity check ends.
on_sanity_check_start(*_) – Called when the validation sanity check starts.
on_save_checkpoint(trainer, pl_module, ...) – Called when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(trainer, pl_module, ...[, ...]) – Called when the test batch ends.
on_test_batch_start(trainer, pl_module, ...) – Called when the test batch begins.
on_test_end(trainer, pl_module) – Called when testing ends.
on_test_epoch_end(trainer, pl_module) – Called when the test epoch ends.
on_test_epoch_start(trainer, pl_module) – Called when the test epoch begins.
on_test_start(trainer, pl_module) – Called when testing begins.
on_train_batch_end(trainer, pl_module, ...) – Called when the train batch ends.
on_train_batch_start(trainer, pl_module, ...) – Called when the train batch begins.
on_train_end(*_) – Called when training ends.
on_train_epoch_end(trainer, pl_module) – Called when the train epoch ends.
on_train_epoch_start(trainer, *_) – Called when the train epoch begins.
on_train_start(*_) – Called when training begins.
on_validation_batch_end(trainer, pl_module, ...) – Called when the validation batch ends.
on_validation_batch_start(trainer, ...[, ...]) – Called when the validation batch begins.
on_validation_end(trainer, pl_module) – Called when the validation loop ends.
on_validation_epoch_end(trainer, pl_module) – Called when the validation epoch ends.
on_validation_epoch_start(trainer, pl_module) – Called when the validation epoch begins.
on_validation_start(trainer, pl_module) – Called when the validation loop begins.
print(*args[, sep]) – Print without breaking the progress bar.
reset_dataloader_idx_tracker()
setup(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune begins.
state_dict() – Called when saving a checkpoint; implement to generate the callback's state_dict.
teardown(trainer, pl_module, stage) – Called when fit, validate, test, predict, or tune ends.
- BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]'
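BAR_FORMAT is a tqdm `bar_format` template whose fields (`l_bar`, `bar`, `n_fmt`, and so on) tqdm computes and substitutes at render time. Plain `str.format` with made-up sample values is enough to see the shape of a rendered line; the values below are illustrative only, and real output depends on tqdm's own computations, the terminal width, and timing.

```python
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

# Made-up sample values standing in for the fields tqdm fills in itself.
line = BAR_FORMAT.format(
    l_bar="Epoch 0:  50%|",
    bar="#####     ",
    n_fmt="5",
    total_fmt="10",
    elapsed="00:02",
    remaining="00:02",
    rate_noinv_fmt="2.50it/s",
    postfix=", loss=0.412",  # sample postfix, including its leading separator
)
print(line)
# Epoch 0:  50%|#####     | 5/10 [00:02<00:02, 2.50it/s, loss=0.412]
```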
- disable()
Disable the progress bar.
- Return type
None
- enable()
Enable the progress bar.
The Trainer calls this in, e.g., pre-training routines like the learning rate finder to temporarily enable and disable the training progress bar.
- Return type
None
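The temporary toggling described above amounts to pausing the bar around a routine and restoring its previous state afterwards, even if the routine raises. The sketch below uses a hypothetical `StubProgressBar` and a hypothetical `progress_bar_paused` context manager that only mimic the documented `enable()`, `disable()`, `is_enabled`, and `is_disabled` members; it is an illustration of the idea, not how the Lightning Trainer is implemented internally.

```python
from contextlib import contextmanager
from typing import Iterator


class StubProgressBar:
    """Hypothetical minimal bar exposing enable()/disable() and the two state properties."""

    def __init__(self) -> None:
        self._enabled = True

    def enable(self) -> None:
        self._enabled = True

    def disable(self) -> None:
        self._enabled = False

    @property
    def is_enabled(self) -> bool:
        return self._enabled

    @property
    def is_disabled(self) -> bool:
        return not self._enabled


@contextmanager
def progress_bar_paused(bar: StubProgressBar) -> Iterator[StubProgressBar]:
    """Disable the bar for the duration of a routine, then restore its previous state."""
    was_enabled = bar.is_enabled
    bar.disable()
    try:
        yield bar
    finally:
        # Only re-enable a bar that was enabled before the routine started.
        if was_enabled:
            bar.enable()


bar = StubProgressBar()
with progress_bar_paused(bar):
    print("inside routine, enabled:", bar.is_enabled)  # False
print("after routine, enabled:", bar.is_enabled)  # True
```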
- get_metrics(trainer, pl_module)
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar.
Here is an example of how to override the defaults:
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items
- Return type
dict[str, Union[int, str, float, dict[str, float]]]
- Returns
Dictionary with the items to be displayed in the progress bar.
- has_dataloader_changed(dataloader_idx)
- Return type
bool
- init_predict_tqdm()
Override this to customize the tqdm bar for predicting.
- Return type
Tqdm
- init_sanity_tqdm()
Override this to customize the tqdm bar for the validation sanity run.
- Return type
Tqdm
- init_test_tqdm()
Override this to customize the tqdm bar for testing.
- Return type
Tqdm
- init_validation_tqdm()
Override this to customize the tqdm bar for validation.
- Return type
Tqdm
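The `init_*_tqdm` hooks are meant to be overridden, typically by calling the parent implementation and adjusting the bar it returns. The sketch below shows that pattern with stand-in classes (`StubBar`, `StubBaseProgressBar`, and `QuietValidationBar` are all hypothetical) so that it runs without PyTorch Lightning or tqdm installed; in real code the parent would be `TQDMProgressBar` or `TFMProgressBar` and the returned object a `Tqdm` bar. It is not darts' actual implementation.

```python
from dataclasses import dataclass


@dataclass
class StubBar:
    """Hypothetical stand-in for a tqdm bar, keeping only the fields used here."""

    desc: str
    disable: bool = False


class StubBaseProgressBar:
    """Hypothetical stand-in for the parent progress bar class."""

    def init_validation_tqdm(self) -> StubBar:
        return StubBar(desc="Validation")


class QuietValidationBar(StubBaseProgressBar):
    """Override pattern: reuse the parent's bar, then adjust it."""

    def __init__(self, enable_validation_bar: bool = True) -> None:
        self.enable_validation_bar = enable_validation_bar

    def init_validation_tqdm(self) -> StubBar:
        bar = super().init_validation_tqdm()
        # Still return a bar object, but mark it disabled when the stage is switched off.
        bar.disable = bar.disable or not self.enable_validation_bar
        return bar


print(QuietValidationBar(enable_validation_bar=False).init_validation_tqdm())
# StubBar(desc='Validation', disable=True)
```

Returning a disabled bar rather than no bar keeps callers that expect an object to update or close working unchanged; that design choice is part of this sketch's assumptions.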
- property is_disabled: bool
- Return type
bool
- property is_enabled: bool
- Return type
bool
- load_state_dict(state_dict)
Called when loading a checkpoint; implement to reload callback state given the callback’s state_dict.
- Parameters
state_dict (dict[str, Any]) – the callback state returned by state_dict.
- Return type
None
- on_after_backward(trainer, pl_module)
Called after loss.backward() and before optimizers are stepped.
- Return type
None
- on_before_backward(trainer, pl_module, loss)
Called before loss.backward().
- Return type
None
- on_before_optimizer_step(trainer, pl_module, optimizer)
Called before optimizer.step().
- Return type
None
- on_before_zero_grad(trainer, pl_module, optimizer)
Called before optimizer.zero_grad().
- Return type
None
- on_exception(trainer, pl_module, exception)
Called when any trainer execution is interrupted by an exception.
- Return type
None
- on_fit_end(trainer, pl_module)
Called when fit ends.
- Return type
None
- on_fit_start(trainer, pl_module)
Called when fit begins.
- Return type
None
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint; use to reload state.
- Parameters
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
- Return type
None
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the predict batch ends.
- Return type
None
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the predict batch begins.
- Return type
None
- on_predict_end(trainer, pl_module)
Called when prediction ends.
- Return type
None
- on_predict_epoch_end(trainer, pl_module)
Called when the predict epoch ends.
- Return type
None
- on_predict_epoch_start(trainer, pl_module)
Called when the predict epoch begins.
- Return type
None
- on_predict_start(trainer, pl_module)
Called when prediction begins.
- Return type
None
- on_sanity_check_end(*_)
Called when the validation sanity check ends.
- Return type
None
- on_sanity_check_start(*_)
Called when the validation sanity check starts.
- Return type
None
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type
None
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the test batch ends.
- Return type
None
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the test batch begins.
- Return type
None
- on_test_end(trainer, pl_module)
Called when testing ends.
- Return type
None
- on_test_epoch_end(trainer, pl_module)
Called when the test epoch ends.
- Return type
None
- on_test_epoch_start(trainer, pl_module)
Called when the test epoch begins.
- Return type
None
- on_test_start(trainer, pl_module)
Called when testing begins.
- Return type
None
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
Note
The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
- Return type
None
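As a worked example of the note above, assuming the normalization is a plain division by `accumulate_grad_batches` (an assumption of this sketch): with `accumulate_grad_batches=4` and a `training_step` loss of 2.0, `outputs["loss"]` would hold 0.5, and multiplying by 4 recovers the original value. Both helpers are hypothetical and use plain floats instead of tensors.

```python
def normalized_batch_loss(raw_loss: float, accumulate_grad_batches: int) -> float:
    """Hypothetical: scale a training_step loss by the number of accumulation steps."""
    return raw_loss / accumulate_grad_batches


def raw_batch_loss(normalized_loss: float, accumulate_grad_batches: int) -> float:
    """Hypothetical: undo the scaling to get back the value training_step returned."""
    return normalized_loss * accumulate_grad_batches


# accumulate_grad_batches=4, training_step returned 2.0
reported = normalized_batch_loss(2.0, accumulate_grad_batches=4)
print(reported)  # 0.5
print(raw_batch_loss(reported, accumulate_grad_batches=4))  # 2.0
```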
- on_train_batch_start(trainer, pl_module, batch, batch_idx)
Called when the train batch begins.
- Return type
None
- on_train_end(*_)
Called when training ends.
- Return type
None
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:
    import pytorch_lightning as L  # the snippet refers to the package as L
    import torch

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss

    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- Return type
None
- on_train_epoch_start(trainer, *_)
Called when the train epoch begins.
- Return type
None
- on_train_start(*_)
Called when training begins.
- Return type
None
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the validation batch ends.
- Return type
None
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the validation batch begins.
- Return type
None
- on_validation_end(trainer, pl_module)
Called when the validation loop ends.
- Return type
None
- on_validation_epoch_end(trainer, pl_module)
Called when the validation epoch ends.
- Return type
None
- on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
- Return type
None
- on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- Return type
None
- property predict_description: str
- Return type
str
- property predict_progress_bar: tqdm_asyncio
- Return type
tqdm_asyncio
- print(*args, sep=' ', **kwargs)
Print without breaking the progress bar.
- Return type
None
- property process_position: int
- Return type
int
- property refresh_rate: int
- Return type
int
- reset_dataloader_idx_tracker()
- Return type
None
- property sanity_check_description: str
- Return type
str
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- Return type
None
- state_dict()
Called when saving a checkpoint; implement to generate the callback’s state_dict.
- Return type
dict[str, Any]
- Returns
A dictionary containing callback state.
- property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. A callback implementation needs to provide a unique state key if (1) the callback has state and (2) the state of multiple instances of that callback must be kept separately.
- Return type
str
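The save/load contract described above (state stored under `checkpoint["callbacks"][state_key]`, written by `state_dict` and read back by `load_state_dict`) can be sketched without Lightning. `CountingCallback` is a hypothetical class that does not subclass Lightning's `Callback`; it only shows why two instances of the same callback need distinct state keys so their entries do not overwrite each other.

```python
from typing import Any, Dict


class CountingCallback:
    """Hypothetical stateful callback illustrating state_key / state_dict / load_state_dict."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.batches_seen = 0

    @property
    def state_key(self) -> str:
        # Include the distinguishing argument so two instances get different keys.
        return f"CountingCallback[name={self.name}]"

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


# Saving: each callback's state goes under checkpoint["callbacks"][state_key].
a, b = CountingCallback("a"), CountingCallback("b")
a.batches_seen, b.batches_seen = 3, 7
checkpoint = {"callbacks": {cb.state_key: cb.state_dict() for cb in (a, b)}}

# Loading: fresh instances look up their own entry by the same key.
a2, b2 = CountingCallback("a"), CountingCallback("b")
for cb in (a2, b2):
    cb.load_state_dict(checkpoint["callbacks"][cb.state_key])

print(sorted(checkpoint["callbacks"]))
# ['CountingCallback[name=a]', 'CountingCallback[name=b]']
```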
- teardown(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune ends.
- Return type
None
- property test_description: str
- Return type
str
- property test_progress_bar: tqdm_asyncio
- Return type
tqdm_asyncio
- property total_predict_batches_current_dataloader: Union[int, float]
The total number of prediction batches for the current dataloader, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return inf if the predict dataloader is of infinite size.
- Return type
Union[int, float]
- property total_test_batches_current_dataloader: Union[int, float]
The total number of testing batches for the current dataloader, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.
- Return type
Union[int, float]
- property total_train_batches: Union[int, float]
The total number of training batches, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.
- Return type
Union[int, float]
- property total_val_batches: Union[int, float]
The total number of validation batches across all validation dataloaders, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return inf if a validation dataloader is of infinite size.
- Return type
Union[int, float]
- property total_val_batches_current_dataloader: Union[int, float]
The total number of validation batches for the current dataloader, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.
- Return type
Union[int, float]
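Because these totals are typed `Union[int, float]` and may be `inf`, code that feeds them into a bar has to decide what an infinite total means. The sketch below makes two assumptions of its own: that an across-dataloaders total behaves like a plain sum (so one infinite loader makes the whole total infinite), and that an infinite or NaN total is best passed to a tqdm-style bar as `None`, i.e. unknown length. Both helpers are hypothetical, and Lightning's own handling may differ in detail.

```python
import math
from typing import Optional, Sequence, Union

Number = Union[int, float]


def summed_val_batches(per_dataloader: Sequence[Number]) -> Number:
    """Hypothetical: total over all validation dataloaders; one infinite loader makes it inf."""
    return sum(per_dataloader)


def bar_total(total: Number) -> Optional[Number]:
    """Hypothetical: turn an infinite or NaN total into None (unknown length)."""
    if isinstance(total, float) and (math.isinf(total) or math.isnan(total)):
        return None
    return total


print(bar_total(summed_val_batches([10, 5])))  # 15
print(bar_total(summed_val_batches([10, math.inf])))  # None
```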
- property train_description: str
- Return type
str
- property train_progress_bar: tqdm_asyncio
- Return type
tqdm_asyncio
- property trainer: Trainer
- Return type
Trainer
- property val_progress_bar: tqdm_asyncio
- Return type
tqdm_asyncio
- property validation_description: str
- Return type
str