Source code for darts.ad.aggregators.ensemble_sklearn_aggregator

"""
Ensemble scikit-learn aggregator
--------------------------------
"""

from typing import Sequence

import numpy as np
from sklearn.ensemble import BaseEnsemble

from darts import TimeSeries
from darts.ad.aggregators.aggregators import FittableAggregator
from darts.logging import raise_if_not


class EnsembleSklearnAggregator(FittableAggregator):
    def __init__(self, model: BaseEnsemble) -> None:
        """Ensemble scikit-learn aggregator

        Aggregator wrapped around the sklearn ensemble model
        `sklearn ensemble model <https://scikit-learn.org/stable/modules/ensemble.html>`_.

        During fitting, the wrapped model learns to map the stacked values of the
        input series to the flattened anomaly labels. At prediction time, its
        output is turned back into one ``TimeSeries`` per input series.

        An error is raised (through ``raise_if_not``) if ``model`` is not an
        instance of ``sklearn.ensemble.BaseEnsemble``.

        Parameters
        ----------
        model
            The sklearn ensemble model (an instance of ``BaseEnsemble``).
        """
        # Implicit string concatenation is used rather than a backslash line
        # continuation inside the literal. The latter would embed the next
        # line's indentation spaces into the message.
        raise_if_not(
            isinstance(model, BaseEnsemble),
            "EnsembleSklearnAggregator is expecting a model of type BaseEnsemble "
            f"(from sklearn ensemble), found type {type(model)}.",
        )

        self.model = model
        super().__init__()

    def __str__(self) -> str:
        """Return ``"EnsembleSklearnAggregator: <name>"``.

        ``<name>`` is the part of ``str(self.model)`` before the first ``"("``,
        i.e. the model's name without its parameter listing.
        """
        model_name = self.model.__str__().split("(")[0]
        return f"EnsembleSklearnAggregator: {model_name}"

    def _fit_core(self, anomalies: Sequence[np.ndarray], series: Sequence[np.ndarray]):
        """Fit the wrapped ensemble model.

        Parameters
        ----------
        anomalies
            One label array per series. Each array is flattened, and the results
            are concatenated to form the target ``y``.
        series
            One array per series, concatenated along axis 0 to form the feature
            matrix ``X``. Presumably each is (n_timesteps, n_components), so that
            rows line up with the flattened labels; TODO confirm against the
            base class.
        """
        X = np.concatenate(series, axis=0)
        y = np.concatenate([s.flatten() for s in anomalies], axis=0)
        self.model.fit(y=y, X=X)

    def _predict_core(self, series: Sequence[TimeSeries]) -> Sequence[TimeSeries]:
        """Predict with the wrapped model, returning one ``TimeSeries`` per input.

        Each output series reuses the time index of its input and holds the
        model's predictions for that input's values.
        """
        # assume that parallelization occurs at sklearn model level
        return [
            TimeSeries.from_times_and_values(
                s.time_index,
                self.model.predict(s.values(copy=False)),
            )
            for s in series
        ]