import abc
import numpy as np
from malib.utils.typing import Tuple, Any, Union, Callable
class BaseAggregator(abc.ABC):
    """Abstract base for aggregators that reduce a collection of values.

    Subclasses must implement :meth:`apply`.
    """

    def __init__(self, name: str):
        """Store the aggregator's name.

        Args:
            name: Identifier kept on the instance as ``_name``.
        """
        self._name = name

    @abc.abstractmethod
    def apply(self, x):
        """Aggregate ``x``; concrete subclasses must override this.

        Args:
            x: The values to aggregate.

        Raises:
            NotImplementedError: Always, when the base implementation is
                invoked directly (e.g. through ``super()``).
        """
        raise NotImplementedError
class Mean(BaseAggregator):
    """Aggregate values by an optionally weighted mean or scaled sum."""

    def __init__(
        self, weights: Tuple[float] = None, scale: int = None, *args, **kwargs
    ):
        """Configure the aggregator.

        Args:
            weights: Optional per-element multipliers applied to the input
                before reduction. ``None`` or an empty sequence disables
                weighting.
            scale: Optional divisor. When truthy, :meth:`apply` returns
                ``sum(x) / scale`` instead of ``mean(x)``.
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.
        """
        super().__init__(name="mean")
        self._scale = scale
        # Test against None (and emptiness) explicitly: truth-testing an
        # array-like with more than one element raises ValueError.
        if weights is not None and np.size(weights) > 0:
            self._weights = np.array(weights)
        else:
            self._weights = None

    def apply(self, x, *args: Any, **kwargs: Any) -> Any:
        """Return the (weighted) mean of ``x``, or its sum divided by ``scale``.

        Args:
            x: Array-like values to aggregate.
            *args: Extra positional arguments forwarded to the reduction
                (``np.sum`` when a scale is set, ``np.mean`` otherwise).
            **kwargs: Extra keyword arguments forwarded to the same reduction.

        Returns:
            The result of ``np.sum(x, ...) / scale`` when a truthy scale was
            configured, else ``np.mean(x, ...)``.
        """
        # `is not None` rather than truthiness: a multi-element ndarray has no
        # unambiguous boolean value.
        if self._weights is not None:
            x = np.multiply(np.array(x), self._weights)
        if self._scale:
            # Forward the reduction arguments to the sum so that both branches
            # interpret them identically (e.g. ``axis``).
            return np.divide(np.sum(x, *args, **kwargs), self._scale)
        return np.mean(x, *args, **kwargs)
class Max(BaseAggregator):
    """Aggregate values by taking the maximum, optionally after weighting."""

    def __init__(self, weights: Tuple[float] = None, *args, **kwargs):
        """Configure the aggregator.

        Args:
            weights: Optional per-element multipliers applied to the input
                before the maximum is taken. ``None`` or an empty sequence
                disables weighting.
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.
        """
        super().__init__(name="max")
        # Test against None (and emptiness) explicitly: truth-testing an
        # array-like with more than one element raises ValueError.
        if weights is not None and np.size(weights) > 0:
            self._weights = np.array(weights)
        else:
            self._weights = None

    def apply(self, x, *args: Any, **kwargs: Any) -> Any:
        """Return the maximum of ``x`` after applying any configured weights.

        Args:
            x: Array-like values to aggregate.
            *args: Extra positional arguments forwarded to ``np.max``.
            **kwargs: Extra keyword arguments forwarded to ``np.max``.

        Returns:
            The result of ``np.max`` over the (weighted) input.
        """
        # `is not None` rather than truthiness: a multi-element ndarray has no
        # unambiguous boolean value.
        if self._weights is not None:
            x = np.multiply(np.array(x), self._weights)
        return np.max(x, *args, **kwargs)
class Min(BaseAggregator):
    """Aggregate values by taking the minimum, optionally after weighting."""

    def __init__(self, weights: Tuple[float] = None, *args, **kwargs):
        """Configure the aggregator.

        Args:
            weights: Optional per-element multipliers applied to the input
                before the minimum is taken. ``None`` or an empty sequence
                disables weighting.
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.
        """
        super().__init__(name="min")
        # Test against None (and emptiness) explicitly: truth-testing an
        # array-like with more than one element raises ValueError.
        if weights is not None and np.size(weights) > 0:
            self._weights = np.array(weights)
        else:
            self._weights = None

    def apply(self, x, *args: Any, **kwargs: Any) -> Any:
        """Return the minimum of ``x`` after applying any configured weights.

        Args:
            x: Array-like values to aggregate.
            *args: Extra positional arguments forwarded to ``np.min``.
            **kwargs: Extra keyword arguments forwarded to ``np.min``.

        Returns:
            The result of ``np.min`` over the (weighted) input.
        """
        # `is not None` rather than truthiness: a multi-element ndarray has no
        # unambiguous boolean value.
        if self._weights is not None:
            x = np.multiply(np.array(x), self._weights)
        return np.min(x, *args, **kwargs)
class Aggregator:
    """Name-to-builder registry for aggregator classes."""

    # Names of the built-in aggregators.
    MEAN = "mean"
    MAX = "max"
    MIN = "min"

    # Registry mapping an aggregator name to the callable that builds it.
    m = {MEAN: Mean, MAX: Max, MIN: Min}

    @staticmethod
    def register(name: str, cls_build_func: Callable):
        """Register (or overwrite) the builder stored under ``name``.

        Args:
            name: Key under which the builder is stored.
            cls_build_func: Callable that constructs the aggregator.
        """
        # Item assignment, not ``dict.update(name, func)``: ``update`` accepts
        # at most one positional argument and would raise TypeError.
        Aggregator.m[name] = cls_build_func

    @staticmethod
    def get(name: str) -> Union[Callable, None]:
        """Return the builder registered under ``name``.

        Args:
            name: Key to look up.

        Returns:
            The registered callable, or ``None`` if ``name`` is unknown.
        """
        return Aggregator.m.get(name)