# Source code for malib.utils.episode

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ming Zhou

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, Any, List, Sequence
from collections import defaultdict

import traceback
import numpy as np

from malib.utils.typing import AgentID, EnvID


class Episode:
    """Multi-agent episode tracking.

    Stores, per agent, lists of transition columns (observations, actions,
    rewards, ...) keyed by the column-name constants declared below.
    """

    CUR_OBS = "obs"
    NEXT_OBS = "obs_next"
    ACTION = "act"
    ACTION_MASK = "act_mask"
    NEXT_ACTION_MASK = "act_mask_next"
    PRE_REWARD = "pre_rew"
    PRE_DONE = "pre_done"
    REWARD = "rew"
    DONE = "done"
    ACTION_LOGITS = "act_logits"
    ACTION_DIST = "act_dist"
    INFO = "infos"

    # optional
    STATE_VALUE = "state_value_estimation"
    STATE_ACTION_VALUE = "state_action_value_estimation"
    CUR_STATE = "state"  # current global state
    NEXT_STATE = "state_next"  # next global state
    LAST_REWARD = "last_reward"

    # post process
    ACC_REWARD = "accumulate_reward"
    ADVANTAGE = "advantage"
    STATE_VALUE_TARGET = "state_value_target"

    # model states
    RNN_STATE = "rnn_state"

    def __init__(self, agents: "List[AgentID]", processors=None):
        """Initialize per-agent transition storage.

        Args:
            agents (List[AgentID]): Registered agent ids.
            processors: Reserved; currently unused (kept for interface
                compatibility).
        """
        # self.processors = processors
        self.agents = agents
        # `list` is the idiomatic default factory (equivalent to `lambda: []`).
        self.agent_entry = {agent: defaultdict(list) for agent in self.agents}

    def __getitem__(self, __k: "AgentID") -> "Dict[str, List]":
        """Return an agent dict.

        Args:
            __k (AgentID): Registered agent id.

        Returns:
            Dict[str, List]: A dict of transitions.
        """
        return self.agent_entry[__k]

    def __setitem__(self, __k: "AgentID", v: "Dict[str, List]") -> None:
        """Set an agent episode.

        Args:
            __k (AgentID): Agent id.
            v (Dict[str, List]): Transition dict.
        """
        self.agent_entry[__k] = v
[docs] def record( self, data: Dict[str, Dict[str, Any]], agent_first: bool, ignore_keys={} ): """Save a transiton. The given transition is a sub sequence of (obs, action_mask, reward, done, info). Users specify ignore keys to filter keys. Args: data (Dict[str, Dict[AgentID, Any]]): A transition. ignore_keys (dict, optional): . Defaults to {}. """ if agent_first: for agent, kvs in data.items(): for k, v in kvs.items(): self.agent_entry[agent][k].append(v) else: for k, agent_trans in data.items(): for agent, _v in agent_trans.items(): self.agent_entry[agent][k].append(_v)
    def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]:
        """Convert episode to numpy array-like data.

        Lossy: agents with fewer than two recorded observations are dropped,
        since they cannot form a single (obs, obs_next) pair.
        """
        res = {}
        for agent, agent_trajectory in self.agent_entry.items():
            # Need at least two observations to build one transition pair.
            if len(agent_trajectory[Episode.CUR_OBS]) < 2:
                continue
            tmp = {}
            try:
                for k, v in agent_trajectory.items():
                    if k in [Episode.CUR_OBS, Episode.CUR_STATE, Episode.ACTION_MASK]:
                        # move to next obs: v[1:] becomes the "*_next" column,
                        # v[:-1] the current column (both length len(v) - 1).
                        tmp[f"{k}_next"] = np.stack(v[1:])
                        tmp[k] = np.stack(v[:-1])
                    elif k in [Episode.PRE_DONE, Episode.PRE_REWARD]:
                        # ignore 'pre_': k[4:] strips the "pre_" prefix, so
                        # "pre_done"/"pre_rew" are exported as "done"/"rew",
                        # shifted by one step to align with (obs, act) pairs.
                        tmp[k[4:]] = np.stack(v[1:])
                        if k == Episode.PRE_DONE:
                            # Last recorded pre_done flag must be truthy,
                            # i.e. the episode actually terminated.
                            # NOTE(review): `assert` is stripped under -O.
                            assert v[-1], v
                    else:
                        tmp[k] = np.stack(v)
            except Exception as e:
                # Surface the full stack before re-raising (np.stack fails on
                # ragged/mismatched entries).
                print(traceback.format_exc())
                raise e
            res[agent] = tmp
        # agent trajectory length check: every exported column must match the
        # length of the (shifted) observation column.
        for agent, trajectory in res.items():
            expected_length = len(trajectory[Episode.CUR_OBS])
            for k, v in trajectory.items():
                assert len(v) == expected_length, (len(v), k, expected_length)
        return dict(res)
class NewEpisodeDict(defaultdict):
    """Episode dict, for trajectory tracking for a bunch of environments."""

    def __missing__(self, env_id):
        # Lazily create an episode for an unseen environment id, mirroring
        # defaultdict semantics: no factory means a plain KeyError.
        factory = self.default_factory
        if factory is None:
            raise KeyError(env_id)
        episode = factory()
        self[env_id] = episode
        return episode
[docs] def record( self, data: Dict[EnvID, Dict[str, Dict[str, Any]]], agent_first: bool, ignore_keys={}, ): for env_id, _data in data.items(): self[env_id].record(_data, agent_first, ignore_keys)
[docs] def to_numpy(self) -> Dict[EnvID, Dict[AgentID, Dict[str, np.ndarray]]]: """Lossy data transformer, which converts a dict of episode to a dict of numpy array like. (some episode may be empty)""" res = {} for k, v in self.items(): tmp: Dict[AgentID, Dict[str, np.ndarray]] = v.to_numpy() if len(tmp) == 0: continue res[k] = tmp return res
import copy
class NewEpisodeList:
    """Episode list, tracking trajectories for a fixed number of environments."""

    def __init__(self, num: "int", agents: "List[AgentID]") -> None:
        """Create one live episode per environment slot.

        Args:
            num (int): Number of concurrently running environments.
            agents (List[AgentID]): Registered agent ids, shared by every
                episode.
        """
        self.num = num
        self.agents = agents
        # One live (in-progress) episode per environment slot.
        self.episodes = [Episode(agents) for _ in range(num)]
        # Finished episodes parked here until exported via `to_numpy`.
        self.episode_buffer = []
    def record(
        self,
        data: List[Dict[str, Dict[str, Any]]],
        agent_first: bool,
        is_episode_done: List[bool],
        ignore_keys={},
    ):
        """Record one transition per environment and rotate finished episodes.

        Args:
            data (List[Dict[str, Dict[str, Any]]]): One transition dict per
                environment, aligned with ``self.episodes`` by position.
            agent_first (bool): Layout flag forwarded to ``Episode.record``.
            is_episode_done (List[bool]): Per-environment terminal flags.
            ignore_keys (dict, optional): Forwarded to ``Episode.record``.
        """
        for i, (episode, _data) in enumerate(zip(self.episodes, data)):
            episode.record(_data, agent_first, ignore_keys)
            # NOTE(review): rotation only happens when agent_first is False —
            # presumably because the column-first layout is the one used at
            # environment stepping time; confirm against callers.
            if is_episode_done[i] and not agent_first:
                # Park the finished episode for export, then start a fresh one
                # seeded with a deep copy of the terminal transition (so the
                # new episode owns its data and the old one is not mutated).
                self.episode_buffer.append(episode)
                new_episode = Episode(self.agents)
                tmp = {k: copy.deepcopy(v) for k, v in _data.items()}
                new_episode.record(tmp, agent_first, ignore_keys)
                self.episodes[i] = new_episode
[docs] def to_numpy(self) -> Dict[AgentID, Dict[str, np.ndarray]]: """Lossy data transformer, which converts a dict of episode to a dict of numpy array like. (some episode may be empty)""" res = [] for v in self.episode_buffer: tmp: Dict[AgentID, Dict[str, np.ndarray]] = v.to_numpy() if len(tmp) == 0: continue res.append(tmp) return res