Source code for malib.utils.stoppers

from malib.utils.typing import Status, MetricType, MetricEntry, Dict, List


[docs]class Stopper: def __init__(self, config: Dict, tasks: List = None): """Create a stopper instance with metric fields. This fields should cover all feasible attributes from rollout/training results. :param Dict config: Configuration to control the stopping. :param List tasks: A list of sub task identifications. Default to None """ self._config = config self.tasks_status = dict.fromkeys(tasks, Status.NORMAL) if tasks else {} def __call__(self, results: Dict[str, MetricEntry], global_step: int) -> bool: """Parse results and determine whether we should terminate tasks.""" raise NotImplementedError @property def info(self): """Return statistics for analysis""" raise NotImplementedError
[docs] def set_terminate(self, task_id: str) -> None: """Terminate sub task tagged with task_id, and set status to terminate.""" assert task_id in self.tasks_status, (task_id, self.tasks_status) self.tasks_status[task_id] = Status.TERMINATE
[docs] def all(self): """Judge whether all tasks have been terminated :return: a bool value indicates terminated or not """ terminate = len(self.tasks_status) > 0 for status in self.tasks_status.values(): if status == Status.NORMAL: terminate = False break return terminate
[docs]class SimpleRolloutStopper(Stopper): """SimpleRolloutStopper will check the equivalence between evaluate results and""" def __init__(self, config, tasks: List = None): super(SimpleRolloutStopper, self).__init__(config, tasks) self._config["max_step"] = self._config.get("max_step", 100) self._info = {MetricType.REACH_MAX_STEP: False} @property def max_iteration(self): return self._config["max_step"] def __call__(self, results: Dict[str, MetricEntry], global_step): """Default rollout stopper will return true when global_step reaches to an oracle""" if global_step == self._config["max_step"]: self._info[MetricType.REACH_MAX_STEP] = True return True return False @property def info(self): raise self._info
[docs]class NonStopper(Stopper): """NonStopper always return false""" def __init__(self, config, tasks=None): super(NonStopper, self).__init__(config, tasks) def __call__(self, *args, **kwargs): return False @property def info(self): return {}
[docs]class SimpleTrainingStopper(Stopper): """SimpleRolloutStopper will check the equivalence between evaluate results and""" def __init__(self, config: Dict, tasks: List = None): super(SimpleTrainingStopper, self).__init__(config, tasks) self._config["max_step"] = self._config.get("max_step", 100) self._info = { MetricType.REACH_MAX_STEP: False, } def __call__(self, results: Dict[str, MetricEntry], global_step): """Ignore training loss, use global step.""" if global_step == self._config["max_step"]: self._info[MetricType.REACH_MAX_STEP] = True return True return False @property def info(self): return self._info
[docs]def get_stopper(name: str): """Return a stopper class with given type name. :param str name: Stopper name, choices {simple_rollout, simple_training}. :return: A stopper type. """ return { "none": NonStopper, "simple_rollout": SimpleRolloutStopper, "simple_training": SimpleTrainingStopper, }[name]