Source code for malib.utils.exploitability

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ziyu Wan

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import defaultdict
from typing import Dict, Sequence, List, Tuple, Union, Any

import logging
import warnings
import itertools

from open_spiel.python.policy import (
    Policy as OSPolicy,
    TabularPolicy,
    joint_action_probabilities_aux,
)
from open_spiel.python.algorithms import policy_aggregator
from open_spiel.python.algorithms.exploitability import nash_conv

try:
    import pyspiel
except ImportError as e:
    logging.warning(
        "Cannot import open spiel, if you wanna run meta game experiment, please install it before that."
    )

import numpy as np
import torch

from malib.utils.typing import PolicyID, AgentID
from malib.rl.common.policy import Policy
from malib.common.strategy_spec import StrategySpec


# reference: https://github.com/JBLanier/pipeline-psro


class NFSPPolicies(OSPolicy):
    """Joint policy to be evaluated."""

    def __init__(self, game, nfsp_policies: List[TabularPolicy]):
        policies = {}
        player_ids = list(range(game.num_players()))
        for policy in nfsp_policies:
            policies.update(dict.fromkeys(policy.player_ids, policy))
        super(NFSPPolicies, self).__init__(game, player_ids)
        self._policies = policies

    def action_probabilities(self, state: Any, player_id: str = None):
        if player_id is None:
            cur_player = state.current_player()
        else:
            cur_player = player_id
        if cur_player == -2:
            # it is a simultaneous node
            actions_per_player, probs_per_player = joint_action_probabilities_aux(
                state, self
            )
            dim = self.game.num_distinct_actions()
            res = [
                (actions[0] * dim + actions[1], np.prod(probs))
                for actions, probs in zip(
                    itertools.product(*actions_per_player),
                    itertools.product(*probs_per_player),
                )
            ]
            prob_dict = dict(res)
        else:
            prob_dict = self._policies[cur_player].action_probabilities(
                state, cur_player
            )
        return prob_dict
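

# A quick worked example of the joint-action encoding used in the
# simultaneous-node branch of `NFSPPolicies.action_probabilities` above
# (illustrative note, not part of the original module). With
# `num_distinct_actions() == 3`, per-player actions (1, 2) with per-player
# probabilities (0.5, 0.4) are flattened into a single dictionary entry:
#
#   dim = 3
#   actions, probs = (1, 2), (0.5, 0.4)
#   flat_index = actions[0] * dim + actions[1]   # -> 5
#   joint_prob = np.prod(probs)                  # -> 0.2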


def compute_act_probs(
    game: pyspiel.Game,
    policy: "Policy",
    state: pyspiel.State,
    player_id: int,
    use_observation,
    epsilon: float = 1e-5,
):
    """Query the RL policy at ``state`` for ``player_id`` and return a dict
    mapping each legal action to its (clipped and renormalized) probability."""

    info_state = (
        state.observation_tensor(player_id)
        if use_observation
        else state.information_state_tensor(player_id)
    )
    legal_actions = state.legal_actions(player_id)
    action_mask = state.legal_actions_mask(player_id)
    obs_np = np.asarray(info_state + action_mask, dtype=np.float32)
    act_mask_np = np.asarray(action_mask, dtype=np.float32)
    action, action_probs, logits, hidden_state = policy.compute_action(
        observation=torch.as_tensor(
            [obs_np], dtype=torch.float32, device=policy.device
        ),
        act_mask=torch.as_tensor(
            [action_mask], dtype=torch.float32, device=policy.device
        ),
        evaluate=True,
        hidden_state=None,
    )
    # clip legal-action probabilities away from zero, then renormalize so they
    # still form a valid distribution
    legal_act_probs = np.clip(action_probs[0][np.nonzero(act_mask_np)], epsilon, None)
    legal_act_probs = legal_act_probs / legal_act_probs.sum()
    return {
        act_name: act_prob
        for act_name, act_prob in zip(legal_actions, legal_act_probs)
        # if act_prob > 0
    }


class OSPolicyWrapper(OSPolicy):
    """Wraps a MALib RL policy as an open_spiel policy."""

    def __init__(
        self,
        game,
        policy: "Policy",
        player_ids: List[int],
        use_observation,
        tolerance: float = 1e-5,
    ):
        super().__init__(game, player_ids)
        self.rl_policy = policy
        self.use_observation = use_observation
        self.tolerance = tolerance

    def action_probabilities(self, state, player_id=None):
        return compute_act_probs(
            self.game,
            self.rl_policy,
            state,
            player_id,
            self.use_observation,
            self.tolerance,
        )


def convert_to_os_policies(
    game, policies: List["Policy"], use_observation: bool, player_ids: List[int]
) -> List[OSPolicy]:
    res = []
    for rl_policy in policies:
        res.append(
            OSPolicyWrapper(
                game, rl_policy, player_ids, use_observation, tolerance=1e-5
            )
        )
    return res


def measure_exploitability(
    game: Union[str, pyspiel.Game],
    populations: Dict[AgentID, Dict[PolicyID, Policy]],
    policy_mixture_dict: Dict[AgentID, Dict[PolicyID, float]],
    use_observation: bool = False,
    use_cpp_br: bool = False,
):
    """Return a measure of closeness to Nash for a policy in the game.

    Args:
        game (Union[str, pyspiel.Game]): An open_spiel game, e.g. kuhn_poker.
        populations (Dict[AgentID, Dict[PolicyID, Policy]]): A dict of policy populations, mapping from agent to a dict of policies keyed by policy id.
        policy_mixture_dict (Dict[AgentID, Dict[PolicyID, float]]): A dict of policy distributions, mapping from agent to a dict of mixture weights (floats) keyed by policy id.
        use_observation (bool, optional): Query policies with observation tensors instead of information-state tensors. Defaults to False.
        use_cpp_br (bool, optional): Compute best response with C++. Defaults to False.

    Returns:
        NashConv: An object with the following attributes:
            - player_improvements: A `[num_players]` numpy array of the improvement for players (i.e. value_player_p_versus_BR - value_player_p).
            - nash_conv: The sum over all players of the improvements in value that each player could obtain by unilaterally changing their strategy, i.e. sum(player_improvements).
    """

    if isinstance(game, str):
        game = pyspiel.load_game(game)

    weights = {}
    # convert dict to list
    for agent_id, population in populations.items():
        pids = list(population.keys())
        populations[agent_id] = list(population.values())
        policy_dist = policy_mixture_dict[agent_id]
        weights[agent_id] = [policy_dist[k] for k in pids]

    # convert each policy to open_spiel policy
    policies = [
        convert_to_os_policies(
            game, populations[f"player_{i}"], use_observation, player_ids=[i]
        )
        for i in range(game.num_players())
    ]
    weights = [weights[f"player_{i}"] for i in range(game.num_players())]

    aggregator = policy_aggregator.PolicyAggregator(game)
    with warnings.catch_warnings(record=True) as w:
        aggr_policies = aggregator.aggregate(
            range(game.num_players()), policies=policies, weights=weights
        )
        if len(w) == 1 and w[0]._category_name == "VisibleDeprecationWarning":
            print(
                "\033[93m [WARNING] A `VisibleDeprecationWarning` is caused by calling `np.copy` on an object array. You can avoid it by modifying line 240 of `open_spiel.python.algorithms.policy_aggregator.py` to `new_reaches = [e.copy() for e in my_reaches]`\033[0m"
            )

    return nash_conv(
        game=game,
        policy=aggr_policies,
        return_only_nash_conv=False,
        use_cpp_br=use_cpp_br,
    )
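

# --- Illustrative usage sketch (not part of the original module) ------------
# A minimal, hypothetical example of how `measure_exploitability` might be
# called for a two-player game such as kuhn_poker. `policy_a` and `policy_b`
# stand in for trained `malib.rl.common.policy.Policy` instances and are
# assumptions, not objects defined here.
#
#   populations = {
#       "player_0": {"policy_0": policy_a},
#       "player_1": {"policy_0": policy_b},
#   }
#   policy_mixture_dict = {
#       "player_0": {"policy_0": 1.0},
#       "player_1": {"policy_0": 1.0},
#   }
#   info = measure_exploitability(
#       "kuhn_poker", populations, policy_mixture_dict, use_observation=False
#   )
#   print(info.nash_conv, info.player_improvements)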