# Source code for darts.ad.aggregators.ensemble_sklearn_aggregator

"""
Ensemble scikit-learn aggregator
--------------------------------

Aggregator that wraps an ensemble model from scikit-learn; see the
`sklearn ensemble documentation <https://scikit-learn.org/stable/modules/ensemble.html>`_.
"""

from typing import Sequence

import numpy as np
from sklearn.ensemble import BaseEnsemble

from darts import TimeSeries
from darts.ad.aggregators.aggregators import FittableAggregator
from darts.logging import raise_if_not


class EnsembleSklearnAggregator(FittableAggregator):
    """Aggregator that delegates the aggregation to a scikit-learn ensemble model.

    During fitting, every time step of the input ``series`` becomes one sample
    (all non-time dimensions of a series are flattened into feature columns), and
    the matching time step of ``actual_anomalies`` becomes its label. At
    prediction time the fitted model's ``predict`` output becomes the values of
    the returned series, indexed like the input series.
    """

    def __init__(self, model: BaseEnsemble) -> None:
        """
        Parameters
        ----------
        model
            A scikit-learn ensemble estimator, i.e. an instance of
            ``sklearn.ensemble.BaseEnsemble``. The instance is stored as-is (not
            copied), so fitting this aggregator fits the caller's object in place.

        Raises
        ------
        An error raised through darts' ``raise_if_not`` when ``model`` is not a
        ``BaseEnsemble`` instance (presumably ``ValueError`` — confirm against
        ``darts.logging``).
        """
        raise_if_not(
            isinstance(model, BaseEnsemble),
            # Implicit concatenation keeps the message free of the stray
            # backslash / indentation whitespace a line continuation would add.
            "Aggregator is expecting a model of type BaseEnsemble "
            f"(from sklearn ensemble), found type {type(model)}.",
        )
        self.model = model
        super().__init__()

    def __str__(self) -> str:
        # Only the estimator's class name (e.g. "RandomForestClassifier"), not
        # its full parameter repr.
        return "EnsembleSklearnAggregator: {}".format(type(self.model).__name__)

    @staticmethod
    def _to_features(s: TimeSeries) -> np.ndarray:
        """Return ``s`` as a 2-D array: one row per time step, every non-time
        dimension flattened into the columns."""
        return s.all_values(copy=False).reshape(len(s), -1)

    def _fit_core(
        self,
        actual_anomalies: Sequence[TimeSeries],
        series: Sequence[TimeSeries],
    ) -> "EnsembleSklearnAggregator":
        """Fit the wrapped ensemble model and return ``self``.

        Parameters
        ----------
        actual_anomalies
            Ground-truth labels; each series must provide exactly one value per
            time step, since it is reshaped to ``(len(s),)``.
        series
            Input series whose time steps become training samples.

        Notes
        -----
        Samples and labels are stacked independently. This assumes that
        ``series[i]`` and ``actual_anomalies[i]`` are aligned and of equal length
        — presumably validated by the base class before this is called; confirm.
        """
        X = np.concatenate([self._to_features(s) for s in series], axis=0)
        y = np.concatenate(
            [s.all_values(copy=False).reshape(len(s)) for s in actual_anomalies],
            axis=0,
        )
        self.model.fit(y=y, X=X)
        return self

    def _predict_core(self, series: Sequence[TimeSeries]) -> Sequence[TimeSeries]:
        """Predict one aggregated series per input series.

        Features are built exactly as in ``_fit_core``. The model's predictions
        become the values of a new series built on the input's time index.
        """
        return [
            TimeSeries.from_times_and_values(
                s.time_index,
                self.model.predict(self._to_features(s)),
            )
            for s in series
        ]