# Source code for malib.rollout.envs.pettingzoo.env

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ming Zhou

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import List, Dict, Any, Tuple, Union, Sequence

import importlib
import gym
import pettingzoo

from malib.utils.typing import AgentID
from malib.rollout.envs.env import Environment


class PettingZooEnv(Environment):
    """Adapter exposing a PettingZoo scenario through the ``Environment`` API.

    Expected keyword configuration:

    * ``env_id``: a string of the form ``"<domain>.<scenario>"``; the scenario
      is looked up as an attribute of the module ``pettingzoo.<domain>``.
    * ``scenario_configs``: a dict forwarded to the scenario constructor. The
      optional key ``parallel_simulate`` (default ``False``) is consumed here
      and selects ``scenario.parallel_env`` instead of ``scenario.env``.
    """

    def __init__(self, **configs):
        super().__init__(**configs)

        env_id = configs["env_id"]
        # Work on a copy: ``parallel_simulate`` is popped below and the
        # caller's dict must not be mutated.
        scenario_configs = configs["scenario_configs"].copy()
        parallel_simulate = scenario_configs.pop("parallel_simulate", False)

        try:
            domain_id, scenario_id = env_id.split(".")
        except ValueError as err:
            # Same exception type as before, but with an actionable message.
            raise ValueError(
                f"env_id must have the form '<domain>.<scenario>', got {env_id!r}"
            ) from err

        domain = importlib.import_module(f"pettingzoo.{domain_id}")
        scenario = getattr(domain, scenario_id)

        if parallel_simulate:
            self.env: pettingzoo.ParallelEnv = scenario.parallel_env(**scenario_configs)
        else:
            self.env: pettingzoo.AECEnv = scenario.env(**scenario_configs)

        self._parallel_simulate = parallel_simulate
        # Spaces are resolved once, for every agent that may ever appear.
        self._action_spaces = {
            agent: self.env.action_space(agent) for agent in self.env.possible_agents
        }
        self._observation_spaces = {
            agent: self.env.observation_space(agent)
            for agent in self.env.possible_agents
        }

    @property
    def possible_agents(self) -> List[AgentID]:
        """A copy of the wrapped environment's ``possible_agents`` list."""
        return self.env.possible_agents.copy()

    @property
    def parallel_simulate(self) -> bool:
        """Whether the wrapped environment was built with ``parallel_env``."""
        return self._parallel_simulate

    @property
    def action_spaces(self) -> Dict[AgentID, gym.Space]:
        """Action space of each possible agent, keyed by agent id."""
        return self._action_spaces

    @property
    def observation_spaces(self) -> Dict[AgentID, gym.Space]:
        """Observation space of each possible agent, keyed by agent id."""
        return self._observation_spaces

    @staticmethod
    def _merge_dones(
        terminations: Dict[AgentID, bool], truncations: Dict[AgentID, bool]
    ) -> Dict[AgentID, bool]:
        """Return per-agent done flags: terminated OR truncated."""
        return {
            agent: bool(terminated) or bool(truncations.get(agent, False))
            for agent, terminated in terminations.items()
        }

    def time_step(
        self, actions: Dict[AgentID, Any]
    ) -> Tuple[
        None,
        Dict[AgentID, Any],
        Dict[AgentID, float],
        Dict[AgentID, bool],
        Dict[AgentID, Any],
    ]:
        """Advance the wrapped environment by one step.

        Args:
            actions: Mapping from agent id to action. In turn-based (AEC) mode
                only the entry for ``self.env.agent_selection`` is used; in
                parallel mode the whole dict is forwarded to ``env.step``.

        Returns:
            A 5-tuple ``(state, observations, rewards, dones, infos)``.
            ``state`` is always ``None``. ``dones[agent]`` is ``True`` when the
            agent is terminated *or* truncated, so that an episode cut off by
            the environment is still reported as finished.
        """
        if not self.parallel_simulate:
            self.env.step(actions[self.env.agent_selection])
            rewards = self.env.rewards.copy()
            observations = {
                agent: self.env.observe(agent) for agent in self.possible_agents
            }
            # ``truncations`` is read defensively in case the wrapped
            # environment does not expose it.
            truncations = getattr(self.env, "truncations", None) or {}
            dones = self._merge_dones(self.env.terminations, truncations)
            infos = self.env.infos.copy()
        else:
            observations, rewards, terminations, truncations, infos = map(
                dict, self.env.step(actions)
            )
            dones = self._merge_dones(terminations, truncations)
        return None, observations, rewards, dones, infos

    def reset(
        self, max_step: Union[int, None] = None
    ) -> Tuple[None, Dict[AgentID, Any]]:
        """Reset the wrapped environment and return the initial observations.

        Args:
            max_step: Forwarded unchanged to the parent class ``reset``.

        Returns:
            A 2-tuple ``(state, observations)`` where ``state`` is always
            ``None`` and ``observations`` maps agent id to observation.
        """
        super().reset(max_step)

        observations = self.env.reset()
        if isinstance(observations, tuple):
            # Some parallel environments return ``(observations, infos)``;
            # only the observations belong in this method's result.
            observations = observations[0]
        if observations is None:
            # Turn-based (AEC) environments return nothing from ``reset``;
            # query each agent's observation explicitly.
            observations = {
                agent: self.env.observe(agent) for agent in self.possible_agents
            }
        return None, observations

    def seed(self, seed: int = None):
        """Forward ``seed`` to the wrapped environment's ``seed`` method."""
        # NOTE(review): presumably some PettingZoo versions only accept seeds
        # through ``reset(seed=...)`` — confirm against the pinned version.
        return self.env.seed(seed)

    def render(self, *args, **kwargs):
        """Render the wrapped environment.

        Positional and keyword arguments are accepted but not forwarded.
        """
        return self.env.render()

    def close(self):
        """Close the wrapped environment."""
        return self.env.close()
# if __name__ == "__main__":
#     from malib.rollout.envs.pettingzoo.scenario_configs_ref import SCENARIO_CONFIGS
#
#     for env_id, scenario_configs in SCENARIO_CONFIGS.items():
#         env = PettingZooEnv(env_id=env_id, scenario_configs=scenario_configs)
#         action_spaces = env.action_spaces
#         _, observations = env.reset()
#         done = False
#         cnt = 0
#         while not done:
#             actions = {k: action_spaces[k].sample() for k in observations.keys()}
#             _, observations, rewards, dones, infos = env.step(actions)
#             done = dones["__all__"]
#             cnt += 1
#         print("[{}] test passed after {} steps".format(env_id, cnt))