class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)

Bases: TQDMProgressBar

Darts’ Progress Bar for TorchForecastingModels.

Lets you choose for which model stages (sanity checks, training, validation, prediction) a progress bar is displayed.

This class is a PyTorch Lightning Callback and can be passed to the TorchForecastingModel constructor through the pl_trainer_kwargs parameter (as an element of its "callbacks" list).

Examples

>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})

Parameters
  • enable_sanity_check_bar (bool) – Whether to enable the progress bar for sanity checks.

  • enable_train_bar (bool) – Whether to enable the progress bar for training.

  • enable_validation_bar (bool) – Whether to enable the progress bar for validation.

  • enable_prediction_bar (bool) – Whether to enable the progress bar for prediction.

  • enable_train_bar_only (bool) – Whether to disable all progress bars except the bar for training.

  • **kwargs – Arguments passed to PyTorch Lightning’s TQDMProgressBar.
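The interplay between the per-stage flags and enable_train_bar_only can be modeled with a small pure-Python sketch. It is not the Darts implementation: the helper resolve_bar_flags is hypothetical, and it assumes (from the parameter description above) that enable_train_bar_only switches off the sanity-check, validation, and prediction bars while leaving enable_train_bar as passed.

```python
def resolve_bar_flags(
    enable_sanity_check_bar: bool = True,
    enable_train_bar: bool = True,
    enable_validation_bar: bool = True,
    enable_prediction_bar: bool = True,
    enable_train_bar_only: bool = False,
) -> dict[str, bool]:
    """Hypothetical helper: report which stages would display a progress bar."""
    if enable_train_bar_only:
        # assumption: every non-training bar is switched off,
        # the training bar keeps whatever value was passed
        enable_sanity_check_bar = False
        enable_validation_bar = False
        enable_prediction_bar = False
    return {
        "sanity_check": enable_sanity_check_bar,
        "train": enable_train_bar,
        "validation": enable_validation_bar,
        "prediction": enable_prediction_bar,
    }


# models TFMProgressBar(enable_train_bar_only=True) from the example above
print(resolve_bar_flags(enable_train_bar_only=True))
# {'sanity_check': False, 'train': True, 'validation': False, 'prediction': False}
```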

Attributes

state_key

Identifier for the state of the callback.

total_predict_batches_current_dataloader

The total number of prediction batches for the current dataloader, which may change from epoch to epoch.

total_test_batches_current_dataloader

The total number of testing batches for the current dataloader, which may change from epoch to epoch.

total_train_batches

The total number of training batches, which may change from epoch to epoch.

total_val_batches

The total number of validation batches across all validation dataloaders, which may change from epoch to epoch.

total_val_batches_current_dataloader

The total number of validation batches for the current dataloader, which may change from epoch to epoch.

is_disabled

is_enabled

predict_description

predict_progress_bar

process_position

refresh_rate

sanity_check_description

test_description

test_progress_bar

train_description

train_progress_bar

trainer

val_progress_bar

validation_description

Methods

disable()

Disable the progress bar.

enable()

Enable the progress bar.

get_metrics(trainer, pl_module)

Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.

init_predict_tqdm()

Override this to customize the tqdm bar for predicting.

init_sanity_tqdm()

Override this to customize the tqdm bar for the validation sanity run.

init_test_tqdm()

Override this to customize the tqdm bar for testing.

init_train_tqdm()

Override this to customize the tqdm bar for training.

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

load_state_dict(state_dict)

Called when loading a checkpoint; implement to reload the callback state from the callback's state_dict.

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

on_before_optimizer_step(trainer, pl_module, ...)

Called before optimizer.step().

on_before_zero_grad(trainer, pl_module, ...)

Called before optimizer.zero_grad().

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer, pl_module)

Called when fit ends.

on_fit_start(trainer, pl_module)

Called when fit begins.

on_load_checkpoint(trainer, pl_module, ...)

Called when loading a model checkpoint; use it to reload state.

on_predict_batch_end(trainer, pl_module, ...)

Called when the predict batch ends.

on_predict_batch_start(trainer, pl_module, ...)

Called when the predict batch begins.

on_predict_end(trainer, pl_module)

Called when predict ends.

on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

on_predict_start(trainer, pl_module)

Called when predict begins.

on_sanity_check_end(*_)

Called when the validation sanity check ends.

on_sanity_check_start(*_)

Called when the validation sanity check starts.

on_save_checkpoint(trainer, pl_module, ...)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

on_test_batch_end(trainer, pl_module, ...[, ...])

Called when the test batch ends.

on_test_batch_start(trainer, pl_module, ...)

Called when the test batch begins.

on_test_end(trainer, pl_module)

Called when the test ends.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_test_start(trainer, pl_module)

Called when the test begins.

on_train_batch_end(trainer, pl_module, ...)

Called when the train batch ends.

on_train_batch_start(trainer, pl_module, ...)

Called when the train batch begins.

on_train_end(*_)

Called when training ends.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

on_train_epoch_start(trainer, *_)

Called when the train epoch begins.

on_train_start(*_)

Called when training begins.

on_validation_batch_end(trainer, pl_module, ...)

Called when the validation batch ends.

on_validation_batch_start(trainer, ...[, ...])

Called when the validation batch begins.

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

print(*args[, sep])

Print a message without breaking the progress bar.

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

state_dict()

Called when saving a checkpoint; implement to generate the callback's state_dict.

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

has_dataloader_changed

reset_dataloader_idx_tracker

BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]'
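BAR_FORMAT is a tqdm bar_format template: tqdm substitutes each {...} field with live values when it renders the bar. The sketch below lists the fields with the standard library's string.Formatter and renders the template with made-up sample values to preview the layout; the real values (including the comma tqdm puts in front of the postfix) are produced by tqdm at runtime.

```python
from string import Formatter

BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

# names of the placeholders that tqdm fills in, in order of appearance
fields = [name for _, name, _, _ in Formatter().parse(BAR_FORMAT) if name]
print(fields)
# ['l_bar', 'bar', 'n_fmt', 'total_fmt', 'elapsed', 'remaining', 'rate_noinv_fmt', 'postfix']

# hypothetical sample values, only to show how the pieces line up
sample = {
    "l_bar": "Epoch 0:  50%|",
    "bar": "#####     ",
    "n_fmt": "5",
    "total_fmt": "10",
    "elapsed": "00:01",
    "remaining": "00:01",
    "rate_noinv_fmt": "5.00it/s",
    "postfix": ", train_loss=0.12",
}
print(BAR_FORMAT.format(**sample))
# Epoch 0:  50%|#####     | 5/10 [00:01<00:01, 5.00it/s, train_loss=0.12]
```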
disable()

Disable the progress bar.

Return type

None

enable()

Enable the progress bar.

The Trainer calls this, for example, in pre-training routines such as the learning rate finder, to temporarily enable and disable the training progress bar.

Return type

None
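The temporary enable/disable pattern mentioned for enable() can be sketched with a context manager. ToyBar and temporarily_disabled are hypothetical stand-ins written for illustration: they only mimic the enable(), disable(), and is_enabled members documented on this page and do not reproduce what the Trainer does internally.

```python
from contextlib import contextmanager
from typing import Iterator


class ToyBar:
    """Hypothetical stand-in exposing enable(), disable() and is_enabled."""

    def __init__(self) -> None:
        self._enabled = True

    def enable(self) -> None:
        self._enabled = True

    def disable(self) -> None:
        self._enabled = False

    @property
    def is_enabled(self) -> bool:
        return self._enabled


@contextmanager
def temporarily_disabled(bar: ToyBar) -> Iterator[ToyBar]:
    """Disable `bar` inside the block and restore its previous state afterwards."""
    was_enabled = bar.is_enabled
    bar.disable()
    try:
        yield bar
    finally:
        # only re-enable if the bar was enabled before entering the block
        if was_enabled:
            bar.enable()


bar = ToyBar()
with temporarily_disabled(bar):
    print(bar.is_enabled)  # False
print(bar.is_enabled)  # True
```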

get_metrics(trainer, pl_module)

Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar.

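Here is an example of how to override the defaults in a subclass:

from darts.utils.callbacks import TFMProgressBar

class MyProgressBar(TFMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # don't show the version number
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items

Return type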

dict[str, Union[int, str, float, dict[str, float]]]

Returns

Dictionary with the items to be displayed in the progress bar.

has_dataloader_changed(dataloader_idx)
Return type

bool
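has_dataloader_changed and reset_dataloader_idx_tracker carry no description here. A plausible reading, stated as an assumption rather than taken from the source, is that the bar remembers the index of the last evaluation dataloader it saw, reports whether a batch comes from a different dataloader than the previous one, and forgets the remembered index on reset. The hypothetical DataloaderTracker below models that behavior in plain Python.

```python
from typing import Optional


class DataloaderTracker:
    """Hypothetical model of has_dataloader_changed / reset_dataloader_idx_tracker."""

    def __init__(self) -> None:
        self._current_idx: Optional[int] = None

    def has_dataloader_changed(self, dataloader_idx: int) -> bool:
        # compare with the previously seen index, then remember the new one
        changed = self._current_idx != dataloader_idx
        self._current_idx = dataloader_idx
        return changed

    def reset_dataloader_idx_tracker(self) -> None:
        # forget the last index, so the next call reports a change
        self._current_idx = None


tracker = DataloaderTracker()
print([tracker.has_dataloader_changed(i) for i in (0, 0, 1, 1, 0)])
# [True, False, True, False, True]
```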

init_predict_tqdm()

Override this to customize the tqdm bar for predicting.

Return type

Tqdm

init_sanity_tqdm()

Override this to customize the tqdm bar for the validation sanity run.

Return type

Tqdm

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Return type

Tqdm

init_train_tqdm()

Override this to customize the tqdm bar for training.

Return type

Tqdm

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Return type

Tqdm

property is_disabled: bool
Return type

bool

property is_enabled: bool
Return type

bool

load_state_dict(state_dict)

Called when loading a checkpoint; implement to reload the callback state from the callback’s state_dict.

Parameters

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type

None

on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

Return type

None

on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

Return type

None

on_before_optimizer_step(trainer, pl_module, optimizer)

Called before optimizer.step().

Return type

None

on_before_zero_grad(trainer, pl_module, optimizer)

Called before optimizer.zero_grad().

Return type

None

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type

None

on_fit_end(trainer, pl_module)

Called when fit ends.

Return type

None

on_fit_start(trainer, pl_module)

Called when fit begins.

Return type

None

on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint; use it to reload state.

Parameters
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

Return type

None

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the predict batch ends.

Return type

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the predict batch begins.

Return type

None

on_predict_end(trainer, pl_module)

Called when predict ends.

Return type

None

on_predict_epoch_end(trainer, pl_module)

Called when the predict epoch ends.

Return type

None

on_predict_epoch_start(trainer, pl_module)

Called when the predict epoch begins.

Return type

None

on_predict_start(trainer, pl_module)

Called when predict begins.

Return type

None

on_sanity_check_end(*_)

Called when the validation sanity check ends.

Return type

None

on_sanity_check_start(*_)

Called when the validation sanity check starts.

Return type

None

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type

None

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

Return type

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the test batch begins.

Return type

None

on_test_end(trainer, pl_module)

Called when the test ends.

Return type

None

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

Return type

None

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

Return type

None

on_test_start(trainer, pl_module)

Called when the test begins.

Return type

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None
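Normalizing with respect to accumulate_grad_batches means dividing by it, so that the gradients accumulated over one window correspond to the mean loss rather than the sum. A minimal arithmetic sketch in plain Python (both helper names are hypothetical) of what the hook sees and how to recover the value returned by training_step:

```python
def normalize_loss(loss: float, accumulate_grad_batches: int) -> float:
    # the value seen in outputs["loss"] inside on_train_batch_end
    return loss / accumulate_grad_batches


def raw_loss(normalized_loss: float, accumulate_grad_batches: int) -> float:
    # undo the normalization to get back the value returned by training_step
    return normalized_loss * accumulate_grad_batches


seen_in_hook = normalize_loss(0.8, accumulate_grad_batches=4)
print(seen_in_hook)  # 0.2
print(raw_loss(seen_in_hook, accumulate_grad_batches=4))  # 0.8
```

Inside a real callback, the factor is available as trainer.accumulate_grad_batches.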

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called when the train batch begins.

Return type

None

on_train_end(*_)

Called when training ends.

Return type

None

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import pytorch_lightning as pl
import torch


class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type

None

on_train_epoch_start(trainer, *_)

Called when the train epoch begins.

Return type

None

on_train_start(*_)

Called when training begins.

Return type

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

Return type

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the validation batch begins.

Return type

None

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type

None

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type

None

on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

Return type

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type

None

property predict_description: str
Return type

str

property predict_progress_bar: tqdm_asyncio
Return type

tqdm_asyncio

print(*args, sep=' ', **kwargs)

Print a message without breaking the progress bar.

Return type

None

property process_position: int
Return type

int

property refresh_rate: int
Return type

int

reset_dataloader_idx_tracker()
Return type

None

property sanity_check_description: str
Return type

str

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type

None

state_dict()

Called when saving a checkpoint; implement to generate the callback’s state_dict.

Return type

dict[str, Any]

Returns

A dictionary containing callback state.

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary via checkpoint["callbacks"][state_key]. A callback implementation must provide a unique state key if (1) the callback has state and (2) the state of multiple instances of that callback needs to be maintained.

Return type

str
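The state_dict / load_state_dict / state_key contract can be illustrated with a plain dictionary standing in for the checkpoint. CounterCallback is a hypothetical stateful callback, not part of Darts or Lightning; it builds its key from the class name plus a distinguishing constructor argument (one possible way to make the key unique, chosen here for illustration) so that two instances do not overwrite each other under checkpoint["callbacks"].

```python
from typing import Any


class CounterCallback:
    """Hypothetical stateful callback counting batches."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.batches_seen = 0

    @property
    def state_key(self) -> str:
        # unique per instance, so several instances can coexist in one checkpoint
        return f"{type(self).__qualname__}[name={self.name}]"

    def state_dict(self) -> dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.batches_seen = state_dict["batches_seen"]


# saving: each callback's state goes under checkpoint["callbacks"][state_key]
a, b = CounterCallback("a"), CounterCallback("b")
a.batches_seen, b.batches_seen = 3, 7
checkpoint = {"callbacks": {cb.state_key: cb.state_dict() for cb in (a, b)}}

# loading: fresh instances look up their own entry by the same key
a2, b2 = CounterCallback("a"), CounterCallback("b")
for cb in (a2, b2):
    cb.load_state_dict(checkpoint["callbacks"][cb.state_key])
print(a2.batches_seen, b2.batches_seen)  # 3 7
```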

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type

None

property test_description: str
Return type

str

property test_progress_bar: tqdm_asyncio
Return type

tqdm_asyncio

property total_predict_batches_current_dataloader: Union[int, float]

The total number of prediction batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return inf if the predict dataloader is of infinite size.

Return type

Union[int, float]

property total_test_batches_current_dataloader: Union[int, float]

The total number of testing batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.

Return type

Union[int, float]

property total_train_batches: Union[int, float]

The total number of training batches, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.

Return type

Union[int, float]
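Because these totals are typed Union[int, float] and may be inf, code that consumes them can normalize a non-finite value to None, which tqdm treats as an unknown total. A minimal sketch of such a normalization; to_tqdm_total is a hypothetical helper, not a Darts or Lightning function.

```python
import math
from typing import Optional, Union


def to_tqdm_total(total: Union[int, float, None]) -> Optional[int]:
    """Hypothetical helper: map an infinite, NaN or missing batch count to None."""
    if total is None or math.isinf(total) or math.isnan(total):
        return None  # unknown length
    return int(total)


print(to_tqdm_total(120))  # 120
print(to_tqdm_total(float("inf")))  # None
```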

property total_val_batches: Union[int, float]

The total number of validation batches across all validation dataloaders, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return inf if a validation dataloader is of infinite size.

Return type

Union[int, float]

property total_val_batches_current_dataloader: Union[int, float]

The total number of validation batches for the current dataloader, which may change from epoch to epoch.

Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.

Return type

Union[int, float]
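total_val_batches spans all validation dataloaders, whereas total_val_batches_current_dataloader covers only the one currently being iterated. Assuming the overall total is the sum of the per-dataloader counts (an assumption, not stated above), a single infinite dataloader makes the overall total inf as well. The per-dataloader counts in this plain-Python sketch are made up.

```python
import math

# hypothetical per-dataloader validation batch counts
num_val_batches = [10, 5]

overall_total = sum(num_val_batches)  # plays the role of total_val_batches
current_total = num_val_batches[1]  # total_val_batches_current_dataloader while on dataloader 1
print(overall_total, current_total)  # 15 5

# one infinite dataloader makes the overall total infinite
print(sum([10, math.inf]))  # inf
```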

property train_description: str
Return type

str

property train_progress_bar: tqdm_asyncio
Return type

tqdm_asyncio

property trainer: Trainer
Return type

Trainer

property val_progress_bar: tqdm_asyncio
Return type

tqdm_asyncio

property validation_description: str
Return type

str