Source code for malib.evaluator.psro

"""
Implementation of global evaluator for Policy Space Response Oracle (PSRO) algorithms. This evaluator will evaluate the
exploitablility between weighted payoff and an oracle payoff.
"""


from malib.evaluator.base_evaluator import BaseEvaluator
from malib.utils.typing import RolloutFeedback, EvaluateResult, TrainingFeedback, Union


[docs]class PSROEvaluator(BaseEvaluator): """Evaluator for Policy Space Response Oracle algorithms"""
[docs] class StopMetrics: """Supported stopping metrics""" MAX_ITERATION = "max_iteration" """Max iteration""" PAYOFF_DIFF_THRESHOLD = "payoff_diff_threshold" """Threshold of difference between the estimated payoff of best response and NE's""" NASH_COV = "nash cov"
def __init__(self, **config): """Create a PSRO evaluator instance. :param Dict[str,Any] config: A dictionary of stopping metrics. """ stop_metrics = config.get("stop_metrics", {}) super(PSROEvaluator, self).__init__(stop_metrics, name="PSRO") self._iteration = 0
[docs] def evaluate( self, content: Union[RolloutFeedback, TrainingFeedback], weighted_payoffs=None, oracle_payoffs=None, trainable_mapping=None, ): """Evaluate global convergence by comparing the margin between Nash and best response. Or, an estimation of exploitability """ res = EvaluateResult.default_result() # res[EvaluateResult.AVE_REWARD] = { # aid: weighted_payoffs[aid] for aid in trainable_mapping # } # nash_cov = 0.0 # for aid, weighted_payoff in weighted_payoffs.items(): # nash_cov += abs(weighted_payoff - oracle_payoffs[aid]) # nash_cov /= 2.0 # res["exploitability"] = nash_cov # default by no limitation on iteration self._iteration += 1 res[ EvaluateResult.REACHED_MAX_ITERATION ] = self._iteration == self._metrics.get( PSROEvaluator.StopMetrics.MAX_ITERATION, 100 ) if res[EvaluateResult.REACHED_MAX_ITERATION]: res[EvaluateResult.CONVERGED] = True return res