Source code for malib.rl.coma.critic

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ming Zhou

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Union, Dict, Any, Tuple

import numpy as np
import torch
import gym

from torch import nn
from gym import spaces

from malib.utils.episode import Episode
from malib.utils.tianshou_batch import Batch
from malib.models.torch import make_net


class COMADiscreteCritic(nn.Module):
    """Centralized critic used by COMA for discrete action spaces.

    The critic wraps a network created by ``make_net``. It can be called either
    with an already assembled input tensor, or with a dict that maps agent ids
    to per-agent batches, in which case the per-agent data is merged into one
    tensor by :meth:`_build_inputs` first.
    """

    def __init__(
        self,
        centralized_obs_space: "gym.Space",
        action_space: "gym.Space",
        net_type: Union[str, None] = None,
        device: str = "cpu",
        **kwargs
    ) -> None:
        """Create the critic and its underlying network.

        Args:
            centralized_obs_space: Space passed to ``make_net`` as
                ``observation_space``.
            action_space: Space passed to ``make_net`` as ``action_space``.
            net_type: Network type forwarded to ``make_net``. The value
                ``"rnn"`` additionally switches :meth:`forward` to the
                recurrent code path.
            device: Device identifier forwarded to ``make_net``.
            **kwargs: Extra keyword arguments forwarded to ``make_net``.
        """
        super().__init__()
        self.net_type = net_type
        self.net = make_net(
            observation_space=centralized_obs_space,
            action_space=action_space,
            device=device,
            net_type=net_type,
            **kwargs
        )

    def _build_inputs(self, agent_batch: Dict[str, "Batch"]) -> torch.Tensor:
        """Merge per-agent batches into a single critic input tensor.

        For every agent the result holds the concatenation of its state, its
        observation and the joint action of all agents in which the agent's
        own action entries are zeroed out.

        Args:
            agent_batch: Mapping from agent id to a batch that can be indexed
                with ``Episode.CUR_STATE``, ``Episode.CUR_OBS`` and
                ``Episode.ACTION``. Each entry is expected to be a tensor
                shaped ``(batch, time_step(optional), inner_dim)``.
                NOTE(review): the action entry is assumed to carry a trailing
                per-action dimension (e.g. one-hot) - confirm with the caller.

        Returns:
            A tensor shaped ``(batch, time_step(optional), n_agents, dim)``.
        """
        agents = list(agent_batch.keys())
        n_agents = len(agents)

        # Stack along a new agent axis:
        # (batch, time_step(optional), inner_dim)
        #   -> (batch, time_step(optional), n_agents, inner_dim)
        states = torch.stack(
            [agent_batch[k][Episode.CUR_STATE] for k in agents], dim=-2
        )
        observations = torch.stack(
            [agent_batch[k][Episode.CUR_OBS] for k in agents], dim=-2
        )
        actions = torch.stack(
            [agent_batch[k][Episode.ACTION] for k in agents], dim=-2
        )
        n_action = actions.shape[-1]

        # Row i masks out the action slots that belong to agent i:
        # (n_agents, n_agents) -> (n_agents, n_agents * n_action)
        agent_mask = 1 - torch.eye(n_agents, device=states.device)
        agent_mask = agent_mask.repeat_interleave(n_action, dim=1)

        # Flatten all agents' actions into one joint-action vector per sample.
        # Using the leading dims of `states` keeps this independent of whether
        # a time axis is present; the explicit last size (instead of -1) also
        # works for empty batches.
        lead_shape = tuple(states.shape[:-2])
        joint_actions = actions.reshape(lead_shape + (1, n_agents * n_action))

        # Broadcasting over the singleton agent axis yields one masked copy of
        # the joint action per agent: (..., n_agents, n_agents * n_action)
        masked_actions = joint_actions * agent_mask

        # shape as: (batch, time_step(optional), n_agents, inner_dim)
        return torch.cat([states, observations, masked_actions], dim=-1)

    def forward(
        self, inputs: Union[Dict[str, "Batch"], torch.Tensor]
    ) -> Union[Tuple[torch.Tensor, Any], torch.Tensor]:
        """Evaluate the critic network.

        Args:
            inputs: Either a ready-made input tensor, or a dict of per-agent
                batches which is converted with :meth:`_build_inputs`.

        Returns:
            If ``net_type`` is ``"rnn"``, a tuple ``(logits, hidden_state)``
            whose leading dimensions are restored to those of ``inputs``;
            otherwise whatever the wrapped network returns for ``inputs``.
        """
        if isinstance(inputs, dict):
            inputs = self._build_inputs(inputs)
        assert isinstance(inputs, torch.Tensor), type(inputs)

        if self.net_type != "rnn":
            return self.net(inputs)

        lead_shape = tuple(inputs.shape[:-1])
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        flat_inputs = inputs.reshape(-1, inputs.shape[-1])
        logits, hidden_state = self.net(flat_inputs)
        logits = logits.reshape(lead_shape + (-1,))
        # Assumes every hidden tensor has the flattened batch as its first
        # dimension - TODO confirm against the recurrent net implementation.
        hidden_state = {
            k: v.reshape(lead_shape + tuple(v.shape[1:]))
            for k, v in hidden_state.items()
        }
        return (logits, hidden_state)