Source code for darts.utils.torch

"""
Utils for PyTorch and its usage
-------------------------------
"""

from functools import wraps
from inspect import signature
from typing import Any, Callable, TypeVar

import torch.nn as nn
import torch.nn.functional as F
from numpy.random import randint
from sklearn.utils import check_random_state
from torch import Tensor
from torch.random import fork_rng, manual_seed

from darts.logging import get_logger, raise_if_not

T = TypeVar("T")
logger = get_logger(__name__)

MAX_TORCH_SEED_VALUE = (1 << 31) - 1  # to accommodate 32-bit architectures
MAX_NUMPY_SEED_VALUE = (1 << 31) - 1


class MonteCarloDropout(nn.Dropout):
    """Defines Monte Carlo dropout Module as defined
    in the paper https://arxiv.org/pdf/1506.02142.pdf.

    In summary, this technique uses the regular dropout which can be interpreted as a Bayesian approximation of
    a well-known probabilistic model: the Gaussian process. We can treat the many different networks
    (with different neurons dropped out) as Monte Carlo samples from the space of all available models. This
    provides mathematical grounds to reason about the model's uncertainty and, as it turns out, often improves
    its performance.
    """

    # mc dropout is deactivated at init; see `MonteCarloDropout.mc_dropout_enabled` for more info
    _mc_dropout_enabled = False

    def forward(self, input: Tensor) -> Tensor:
        # NOTE: we could use the following line in case a different rate
        # is used for inference:
        # return F.dropout(input, self.applied_rate, True, self.inplace)
        return F.dropout(input, self.p, self.mc_dropout_enabled, self.inplace)

    @property
    def mc_dropout_enabled(self) -> bool:
        # mc dropout is only activated on `PLForecastingModule.on_predict_start()`
        # otherwise, it is activated based on the `model.training` flag.
        return self._mc_dropout_enabled or self.training
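
# --- Illustrative sketch (not part of the darts module): drawing Monte Carlo dropout
# samples from a toy network at inference time. The layer sizes and the direct toggling
# of `_mc_dropout_enabled` below are assumptions made for the example; within Darts this
# flag is set by `PLForecastingModule.on_predict_start()`.
#
#   import torch
#
#   net = nn.Sequential(
#       nn.Linear(8, 16), nn.ReLU(), MonteCarloDropout(p=0.2), nn.Linear(16, 1)
#   )
#   net.eval()  # a plain nn.Dropout would now be a no-op
#   for module in net.modules():
#       if isinstance(module, MonteCarloDropout):
#           module._mc_dropout_enabled = True  # keep dropout active at inference
#
#   x = torch.randn(32, 8)
#   samples = torch.stack([net(x) for _ in range(100)])  # 100 stochastic forward passes
#   mean, std = samples.mean(dim=0), samples.std(dim=0)  # predictive mean / uncertainty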
def _is_method(func: Callable[..., Any]) -> bool:
    """Check if the specified function is a method.

    Parameters
    ----------
    func
        the function to inspect.

    Returns
    -------
    bool
        true if `func` is a method, false otherwise.
    """
    spec = signature(func)
    if len(spec.parameters) > 0:
        if list(spec.parameters.keys())[0] == "self":
            return True
    return False
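
# --- Illustrative sketch (not part of the darts module): `_is_method` simply checks
# whether the first positional parameter is named "self". The functions below are
# hypothetical.
#
#   def free_function(x):
#       return x
#
#   class Foo:
#       def bound_method(self, x):
#           return x
#
#   _is_method(free_function)     # False
#   _is_method(Foo.bound_method)  # True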
def random_method(decorated: Callable[..., T]) -> Callable[..., T]:
    """Decorator usable on any method within a class that will provide an isolated torch random context.

    The decorator will store a `_random_instance` attribute on the object in order to persist
    successive calls to the RNG.

    Parameters
    ----------
    decorated
        A method to be run in an isolated torch random context.
    """
    # check that @random_method has been applied to a method.
    raise_if_not(
        _is_method(decorated), "@random_method can only be used on methods.", logger
    )

    @wraps(decorated)
    def decorator(self, *args, **kwargs) -> T:
        if "random_state" in kwargs.keys():
            self._random_instance = check_random_state(kwargs["random_state"])
        elif not hasattr(self, "_random_instance"):
            self._random_instance = check_random_state(
                randint(0, high=MAX_NUMPY_SEED_VALUE)
            )

        with fork_rng():
            manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
            return decorated(self, *args, **kwargs)

    return decorator
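
# --- Illustrative sketch (not part of the darts module): applying `random_method` to a
# model method so that passing `random_state` makes the torch RNG calls inside it
# reproducible. The class and method names are hypothetical; note that the wrapped method
# must accept `**kwargs`, since the decorator forwards `random_state` to it.
#
#   import torch
#
#   class MyModel:
#       @random_method
#       def sample_noise(self, n: int, **kwargs) -> Tensor:
#           return torch.randn(n)
#
#   m = MyModel()
#   a = m.sample_noise(3, random_state=42)
#   b = m.sample_noise(3, random_state=42)
#   assert torch.equal(a, b)  # same seed -> identical draws inside the forked RNG context
#
#   # without `random_state`, a RandomState is created once and reused across calls
#   # on the same object.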