Source code for darts.utils.torch

"""
Utils for PyTorch and its usage
-------------------------------
"""

from collections.abc import Callable
from functools import wraps
from typing import TypeVar

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.utils import check_random_state
from torch import Tensor
from torch.random import fork_rng, manual_seed

from darts.logging import get_logger, raise_log
from darts.utils.utils import MAX_NUMPY_SEED_VALUE, MAX_TORCH_SEED_VALUE, _is_method

T = TypeVar("T")
logger = get_logger(__name__)


class MonteCarloDropout(nn.Dropout):
    """Defines a Monte Carlo dropout Module as described
    in the paper https://arxiv.org/pdf/1506.02142.pdf.

    In summary, this technique uses regular dropout, which can be
    interpreted as a Bayesian approximation of a well-known probabilistic
    model: the Gaussian process. We can treat the many different networks
    (with different neurons dropped out) as Monte Carlo samples from the
    space of all available models. This provides mathematical grounds to
    reason about the model's uncertainty and, as it turns out, often
    improves its performance.
    """

    # mc dropout is deactivated at init; see `MonteCarloDropout.mc_dropout_enabled` for more info
    _mc_dropout_enabled = False

    def forward(self, input: Tensor) -> Tensor:
        # NOTE: we could use the following line in case a different rate
        # is used for inference:
        # return F.dropout(input, self.applied_rate, True, self.inplace)
        return F.dropout(input, self.p, self.mc_dropout_enabled, self.inplace)

    @property
    def mc_dropout_enabled(self) -> bool:
        # mc dropout is only activated on `PLForecastingModule.on_predict_start()`
        # otherwise, it is activated based on the `model.training` flag.
        return self._mc_dropout_enabled or self.training
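

# --- Illustrative sketch (not part of the darts source) ---
# A minimal, hedged example of drawing Monte Carlo dropout samples at inference
# time. The toy network and the manual toggling of `_mc_dropout_enabled` are
# assumptions made for illustration; within darts, the flag is switched on by
# `PLForecastingModule.on_predict_start()`.
def _example_mc_dropout_sampling(n_samples: int = 100):
    import torch  # local import; only this sketch needs the top-level torch module

    net = nn.Sequential(
        nn.Linear(8, 16),
        nn.ReLU(),
        MonteCarloDropout(p=0.2),
        nn.Linear(16, 1),
    )
    net.eval()  # regular dropout would now be disabled
    for module in net.modules():
        if isinstance(module, MonteCarloDropout):
            module._mc_dropout_enabled = True  # keep dropout active at inference

    x = torch.randn(32, 8)
    with torch.no_grad():
        samples = torch.stack([net(x) for _ in range(n_samples)])
    # per-input predictive mean and a simple uncertainty estimate
    return samples.mean(dim=0), samples.std(dim=0)

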
def random_method(decorated: Callable[..., T]) -> Callable[..., T]:
    """Decorator usable on any method within a class that provides an isolated torch random context.

    The decorator stores a `_random_instance` attribute on the object in order to persist
    the random state across successive calls to the RNG.

    Parameters
    ----------
    decorated
        A method to be run in an isolated torch random context.
    """
    # check that @random_method has been applied to a method.
    if not _is_method(decorated):
        raise_log(ValueError("@random_method can only be used on methods."), logger)

    @wraps(decorated)
    def decorator(self, *args, **kwargs) -> T:
        store_instance = False
        random_instance = None
        if "random_state" in kwargs.keys() and kwargs["random_state"] is not None:
            # get random state from model constructor or `predict()`
            random_instance = check_random_state(kwargs["random_state"])
            if not hasattr(self, "_random_instance"):
                # store random instance when called from model constructor
                store_instance = True
        elif not hasattr(self, "_random_instance"):
            # get random state for first time from other method
            store_instance = True
            random_instance = check_random_state(
                np.random.randint(0, high=MAX_NUMPY_SEED_VALUE)
            )

        # if no random instance is provided, use the one stored in the class
        if random_instance is None:
            random_instance = self._random_instance

        if store_instance:
            self._random_instance = random_instance

        # handle the randomness
        with fork_rng():
            manual_seed(random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
            return decorated(self, *args, **kwargs)

    return decorator
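

# --- Illustrative sketch (not part of the darts source) ---
# A minimal, hedged example of how `@random_method` can make a torch-sampling
# method reproducible. `_ToyModel` and its `sample()` method are assumptions made
# for illustration; in darts the decorator wraps model constructors and
# `predict()` methods, which accept a `random_state` argument.
import torch  # not imported above; only the example sketches need the top-level module


class _ToyModel:
    @random_method
    def sample(self, n: int, random_state=None) -> Tensor:
        # the draw happens inside an isolated, seeded torch RNG context
        return torch.randn(n)


def _example_random_method() -> None:
    a = _ToyModel().sample(3, random_state=42)
    b = _ToyModel().sample(3, random_state=42)
    # identical seeds yield identical draws, while the global torch RNG state is
    # left untouched thanks to `fork_rng()`
    assert torch.equal(a, b)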