# MIT License
# Copyright (c) 2021 MARL @ SJTU
# Author: Ziyu Wan
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from collections import defaultdict
from typing import Dict, Sequence, List, Tuple, Union, Any
import logging
import warnings
import itertools
from open_spiel.python.policy import (
Policy as OSPolicy,
TabularPolicy,
joint_action_probabilities_aux,
)
from open_spiel.python.algorithms import policy_aggregator
from open_spiel.python.algorithms.exploitability import nash_conv
try:
import pyspiel
except ImportError as e:
logging.warning(
"Cannot import open spiel, if you wanna run meta game experiment, please install it before that."
)
import numpy as np
import torch
from malib.utils.typing import PolicyID, AgentID
from malib.rl.common.policy import Policy
from malib.common.strategy_spec import StrategySpec
# Reference: https://github.com/JBLanier/pipeline-psro
class NFSPPolicies(OSPolicy):
    """Joint policy to be evaluated.

    Combines several tabular policies, each responsible for the players listed
    in its ``player_ids``, into one OpenSpiel policy over all players.
    """

    # Value of ``state.current_player()`` at simultaneous-move nodes
    # (``pyspiel.PlayerId.SIMULTANEOUS``).
    _SIMULTANEOUS_PLAYER = -2

    def __init__(self, game, nfsp_policies: List[TabularPolicy]):
        """Build the joint policy.

        Args:
            game: The open_spiel game the policies act in.
            nfsp_policies: Per-player policies; each one is registered for every
                id in its ``player_ids``. If two policies claim the same player
                id, the later one in the list wins.
        """
        policies = {}
        for policy in nfsp_policies:
            policies.update(dict.fromkeys(policy.player_ids, policy))
        super().__init__(game, list(range(game.num_players())))
        self._policies = policies

    def action_probabilities(self, state: Any, player_id=None):
        """Return a dict mapping actions to probabilities at ``state``.

        Args:
            state: An open_spiel state.
            player_id: Player to query. When given, the call is delegated to
                that player's policy. This is also how
                ``joint_action_probabilities_aux`` calls back into this object.
                When ``None``, the acting player of ``state`` is used.

        Returns:
            A dict from action to probability. At a simultaneous-move node with
            ``player_id=None``, keys are joint-action indices built from each
            player's action, with player 0 most significant and base
            ``game.num_distinct_actions()``. Values are the product of the
            per-player probabilities.
        """
        if player_id is not None:
            return self._policies[player_id].action_probabilities(state, player_id)

        cur_player = state.current_player()
        if cur_player != self._SIMULTANEOUS_PLAYER:
            # NOTE(review): chance/terminal nodes have no registered policy and
            # raise KeyError here, as in the original implementation.
            return self._policies[cur_player].action_probabilities(state, cur_player)

        # Simultaneous node: enumerate every combination of per-player actions.
        actions_per_player, probs_per_player = joint_action_probabilities_aux(
            state, self
        )
        dim = self.game.num_distinct_actions()
        prob_dict = {}
        for actions, probs in zip(
            itertools.product(*actions_per_player),
            itertools.product(*probs_per_player),
        ):
            # Base-``dim`` encoding; equals ``a0 * dim + a1`` for two players and
            # stays collision-free for more.
            # NOTE(review): this may differ from pyspiel's own flat joint-action
            # encoding — confirm against consumers of these keys.
            joint_action = 0
            for action in actions:
                joint_action = joint_action * dim + action
            prob_dict[joint_action] = np.prod(probs)
        return prob_dict
def compute_act_probs(
    game: "pyspiel.Game",
    policy: "Policy",
    state: "pyspiel.State",
    player_id: int,
    use_observation,
    epsilon: float = 1e-5,
):
    """Query an RL policy for ``player_id``'s action distribution at ``state``.

    The network input is the player's observation tensor (or information-state
    tensor) with the legal-action mask appended to it.

    Args:
        game: Unused; kept for interface compatibility.
        policy: RL policy exposing ``compute_action(...)`` and ``device``.
        state: The open_spiel state to evaluate.
        player_id: The player whose view and legal actions are used.
        use_observation: If truthy, use ``state.observation_tensor``; otherwise
            use ``state.information_state_tensor``.
        epsilon: Lower bound each legal-action probability is clipped to before
            renormalization, so no legal action gets zero probability.

    Returns:
        A dict mapping each legal action to its probability. When there is at
        least one legal action, the values sum to 1.
    """
    info_state = (
        state.observation_tensor(player_id)
        if use_observation
        else state.information_state_tensor(player_id)
    )
    legal_actions = state.legal_actions(player_id)
    action_mask = state.legal_actions_mask(player_id)
    obs_np = np.asarray(info_state + action_mask, dtype=np.float32)
    act_mask_np = np.asarray(action_mask, dtype=np.float32)

    # Add the batch axis with numpy: building a tensor from a Python list of
    # ndarrays is slow and makes torch emit a UserWarning.
    _, action_probs, _, _ = policy.compute_action(
        observation=torch.as_tensor(
            obs_np[np.newaxis], dtype=torch.float32, device=policy.device
        ),
        act_mask=torch.as_tensor(
            act_mask_np[np.newaxis], dtype=torch.float32, device=policy.device
        ),
        evaluate=True,
        hidden_state=None,
    )

    # Keep only legal-action entries (same ascending order as legal_actions),
    # floor them at epsilon, then renormalize to avoid zero probabilities.
    legal_act_probs = np.clip(action_probs[0][np.nonzero(act_mask_np)], epsilon, None)
    legal_act_probs = legal_act_probs / legal_act_probs.sum()
    return dict(zip(legal_actions, legal_act_probs))
class OSPolicyWrapper(OSPolicy):
    """Expose a MALib RL policy through the OpenSpiel ``Policy`` interface."""

    def __init__(
        self,
        game,
        policy: "Policy",
        player_ids: List[int],
        use_observation,
        tolerance: float = 1e-5,
    ):
        """Wrap ``policy``.

        Args:
            game: The open_spiel game.
            policy: RL policy forwarded to ``compute_act_probs``.
            player_ids: Players this OpenSpiel policy is defined for.
            use_observation: Use observation tensors instead of
                information-state tensors as network input.
            tolerance: Probability floor applied to every legal action before
                renormalization.
        """
        super().__init__(game, player_ids)
        self.rl_policy = policy
        self.use_observation = use_observation
        self.tolerance = tolerance

    def action_probabilities(self, state, player_id=None):
        """Return a dict mapping legal actions to probabilities at ``state``.

        Args:
            state: An open_spiel state.
            player_id: Player to query. Defaults to ``state.current_player()``,
                following the OpenSpiel ``Policy`` convention. Without this
                default, ``None`` would reach the state's tensor accessors.
        """
        if player_id is None:
            player_id = state.current_player()
        return compute_act_probs(
            self.game,
            self.rl_policy,
            state,
            player_id,
            self.use_observation,
            self.tolerance,
        )
def convert_to_os_policies(
    game, policies: List["Policy"], use_observation: bool, player_ids: List[int]
) -> List[OSPolicy]:
    """Wrap each RL policy as an OpenSpiel policy.

    Args:
        game: The open_spiel game.
        policies: RL policies to wrap, in order.
        use_observation: Use observation tensors instead of information-state
            tensors as network input.
        player_ids: Player ids every wrapper is bound to.

    Returns:
        One ``OSPolicyWrapper`` per input policy, in the same order, each with a
        probability floor of ``1e-5``.
    """
    return [
        OSPolicyWrapper(game, rl_policy, player_ids, use_observation, tolerance=1e-5)
        for rl_policy in policies
    ]
def measure_exploitability(
    game: Union[str, "pyspiel.Game"],
    populations: Dict[AgentID, Dict[PolicyID, Policy]],
    policy_mixture_dict: Dict[AgentID, Dict[PolicyID, float]],
    use_observation: bool = False,
    use_cpp_br: bool = False,
):
    """Return a measure of closeness to Nash for a policy mixture in the game.

    Each agent's population is mixed according to ``policy_mixture_dict``. The
    per-player mixtures are aggregated into one OpenSpiel policy, and NashConv
    is computed for it.

    Args:
        game (Union[str, pyspiel.Game]): An open_spiel game, or its name
            (e.g. "kuhn_poker").
        populations (Dict[AgentID, Dict[PolicyID, Policy]]): Maps each agent id
            to its population, a dict from policy id to policy. Player ``i`` is
            looked up under the key ``"player_{i}"``. The dict is not modified.
        policy_mixture_dict (Dict[AgentID, Dict[PolicyID, float]]): Maps each
            agent id to a dict from policy id to mixture weight. Every policy id
            in that agent's population must have a weight.
        use_observation (bool, optional): Feed observation tensors instead of
            information-state tensors to the policies. Defaults to False.
        use_cpp_br (bool, optional): Compute best response with C++. Defaults to False.

    Returns:
        NashConv: An object with the following attributes:
            - player_improvements: A `[num_players]` numpy array of the improvement for players (i.e. value_player_p_versus_BR - value_player_p).
            - nash_conv: The sum over all players of the improvements in value that each player could obtain by unilaterally changing their strategy, i.e. sum(player_improvements).
    """
    if isinstance(game, str):
        game = pyspiel.load_game(game)

    # Flatten each population into parallel (policies, weights) lists without
    # mutating the caller's dict. The original code overwrote it in place, which
    # broke any second call with the same argument.
    policy_lists = {}
    weight_lists = {}
    for agent_id, population in populations.items():
        policy_dist = policy_mixture_dict[agent_id]
        policy_lists[agent_id] = list(population.values())
        weight_lists[agent_id] = [policy_dist[pid] for pid in population]

    # Convert each policy to an open_spiel policy, ordered by player index.
    policies = [
        convert_to_os_policies(
            game, policy_lists[f"player_{i}"], use_observation, player_ids=[i]
        )
        for i in range(game.num_players())
    ]
    weights = [weight_lists[f"player_{i}"] for i in range(game.num_players())]

    aggregator = policy_aggregator.PolicyAggregator(game)
    with warnings.catch_warnings(record=True) as caught:
        aggr_policies = aggregator.aggregate(
            range(game.num_players()), policies=policies, weights=weights
        )

    # Replace the known numpy VisibleDeprecationWarning with a hint. Re-emit
    # every other recorded warning instead of silently dropping it.
    hint_needed = False
    for record in caught:
        if record.category.__name__ == "VisibleDeprecationWarning":
            hint_needed = True
        else:
            warnings.warn_explicit(
                record.message, record.category, record.filename, record.lineno
            )
    if hint_needed:
        print(
            "\033[93m [WARNING] A `VisibleDeprecationWarning` is caused by the use of `np.copy` of an object. You can avoid it by modifying line 240 of `open_spiel.python.algorithms.policy_aggregator.py` as `new_reaches = [e.copy() for e in my_reaches]`\033[0m"
        )

    return nash_conv(
        game=game,
        policy=aggr_policies,
        return_only_nash_conv=False,
        use_cpp_br=use_cpp_br,
    )