Source code for malib.rollout.envs.mdp.env

from typing import Dict, Any, Tuple, List, Union, Sequence

import gym

from malib.utils.typing import AgentID
from malib.utils.logging import Logger
from malib.rollout.envs.env import Environment


class MDPEnvironment(Environment):
    """Single-agent environment wrapping the example MDPs from ``blackhc.mdp``.

    The scenario is selected with the ``env_id`` config entry and converted to
    an environment object via its ``to_env()`` method.  All per-agent results
    are keyed by the single agent id ``"agent"``.
    """

    def __init__(self, **configs):
        """Create the wrapped MDP environment.

        Args:
            **configs: Options forwarded unchanged to ``Environment.__init__``.
                Must contain ``env_id``, one of ``"one_round_dmdp"``,
                ``"two_round_dmdp"``, ``"one_round_nmdp"``,
                ``"two_round_nmdp"`` or ``"multi_round_nmdp"``.

        Raises:
            ImportError: If ``blackhc.mdp`` cannot be imported.
            KeyError: If ``env_id`` is missing from ``configs``.
            ValueError: If ``env_id`` is not one of the known scenarios.
        """
        super().__init__(**configs)

        # Imported lazily so that the optional dependency is only required
        # when this environment is actually instantiated.
        try:
            from blackhc import mdp
            from blackhc.mdp import example as mdp_examples
        except ImportError:
            Logger.error(
                "please run `pip install -e .[dev]` before using MDPEnvironment"
            )
            # Bare raise re-raises the active exception with its traceback.
            raise

        # Maps each supported env_id to the factory that builds its MDP.
        scenarios = {
            "one_round_dmdp": mdp_examples._one_round_dmdp,
            "two_round_dmdp": mdp_examples._two_round_dmdp,
            "one_round_nmdp": mdp_examples._one_round_nmdp,
            "two_round_nmdp": mdp_examples._two_round_nmdp,
            "multi_round_nmdp": mdp_examples._multi_round_nmdp,
        }

        env_id = configs["env_id"]
        if env_id not in scenarios:
            raise ValueError(
                f"env_id={env_id} not a legal environment id, you should init mdp environments from one of them: {scenarios.keys()}"
            )

        self.env = scenarios[env_id]().to_env()
        self._possible_agents = ["agent"]

    @property
    def possible_agents(self) -> List[AgentID]:
        """Return the list of agent ids, here always ``["agent"]``."""
        return self._possible_agents

    @property
    def observation_spaces(self) -> Dict[AgentID, gym.Space]:
        """Return the wrapped env's observation space keyed by agent id."""
        return dict.fromkeys(self.possible_agents, self.env.observation_space)

    @property
    def action_spaces(self) -> Dict[AgentID, gym.Space]:
        """Return the wrapped env's action space keyed by agent id."""
        return dict.fromkeys(self.possible_agents, self.env.action_space)

    def time_step(
        self, actions: Dict[AgentID, Any]
    ) -> Tuple[
        None,
        Dict[AgentID, Any],
        Dict[AgentID, float],
        Dict[AgentID, bool],
        Any,
    ]:
        """Advance the wrapped environment by one step.

        Args:
            actions: Mapping of agent id to action; only ``actions["agent"]``
                is read and passed to the wrapped env's ``_step``.

        Returns:
            A 5-tuple ``(state, observations, rewards, dones, info)``.
            ``state`` is always ``None``; observations, rewards and dones are
            dicts keyed by every id in ``possible_agents``; ``info`` is
            whatever the wrapped env returned, passed through unchanged
            (NOTE(review): it is not keyed by agent id -- confirm callers
            expect that).
        """
        obs, rew, done, info = self.env._step(actions["agent"])

        obs = dict.fromkeys(self.possible_agents, obs)
        rew = dict.fromkeys(self.possible_agents, rew)
        done = dict.fromkeys(self.possible_agents, done)

        return None, obs, rew, done, info

    def reset(
        self, max_step: Union[int, None] = None
    ) -> Union[None, Sequence[Dict[AgentID, Any]]]:
        """Reset the base-class state and the wrapped environment.

        Args:
            max_step: Forwarded to ``Environment.reset``; defaults to ``None``.

        Returns:
            A pair ``(state, observations)`` where ``state`` is always ``None``
            and ``observations`` maps every id in ``possible_agents`` to the
            observation returned by the wrapped env's ``_reset``.
        """
        super(MDPEnvironment, self).reset(max_step=max_step)

        observation = self.env._reset()
        return None, dict.fromkeys(self.possible_agents, observation)

    def close(self):
        """Close the wrapped environment and return its result."""
        return self.env.close()

    def render(self, *args, **kwargs):
        """Render the wrapped environment.

        Positional and keyword arguments are accepted but not forwarded; the
        wrapped env's ``_render`` is called without arguments.
        """
        return self.env._render()

    def seed(self, seed: Union[int, None] = None):
        """Forward ``seed`` to the wrapped environment's ``seed`` method."""
        return self.env.seed(seed)