# reference: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py
import operator
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
import gym
import functools
import copy
from torch import nn
from malib.utils.preprocessor import get_preprocessor
# An ``nn.Module`` subclass (the class itself, not an instance), e.g. ``nn.ReLU``.
ModuleType = Type[nn.Module]
# A device specifier value; the same union used by the ``device`` parameters of
# MLP / Net / Recurrent below (``Type[torch.device]`` would mean the class itself).
Device = Union[str, int, torch.device]
def miniblock(
    input_size: int,
    output_size: int = 0,
    norm_layer: Optional[ModuleType] = None,
    activation: Optional[ModuleType] = None,
    linear_layer: Type[nn.Linear] = nn.Linear,
) -> List[nn.Module]:
    """Build one ``linear -> [norm] -> [activation]`` stage.

    Args:
        input_size (int): Number of input features of the linear layer.
        output_size (int, optional): Number of output features of the linear layer;
            also the argument given to ``norm_layer``. Defaults to 0.
        norm_layer (Optional[ModuleType], optional): Normalization class, instantiated
            as ``norm_layer(output_size)``; skipped when None. Defaults to None.
        activation (Optional[ModuleType], optional): Activation class, instantiated
            without arguments; skipped when None. Defaults to None.
        linear_layer (Type[nn.Linear], optional): Class used for the linear layer,
            instantiated as ``linear_layer(input_size, output_size)``.
            Defaults to nn.Linear.

    Returns:
        List[nn.Module]: The instantiated layers in application order.
    """
    block: List[nn.Module] = [linear_layer(input_size, output_size)]
    if norm_layer is not None:
        block.append(norm_layer(output_size))  # type: ignore
    if activation is not None:
        block.append(activation())
    return block
class MLP(nn.Module):
    """Multi-layer perceptron: ``(linear -> norm? -> activation?) per hidden size``,
    followed by an optional final linear layer."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
        device: Optional[Union[str, int, torch.device]] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
    ) -> None:
        """Create a MLP.

        Args:
            input_dim (int): dimension of the input vector.
            output_dim (int, optional): dimension of the output vector. If set to 0,
                there is no final linear layer and the output dimension is the last
                hidden size (``input_dim`` when there are no hidden layers).
                Defaults to 0.
            hidden_sizes (Sequence[int], optional): sizes of the hidden layers, not
                including input_dim and output_dim. Defaults to ().
            norm_layer (Optional[Union[ModuleType, Sequence[ModuleType]]], optional):
                normalization class applied after each hidden linear layer and before
                its activation, e.g. ``nn.LayerNorm`` or ``nn.BatchNorm1d``. A list or
                tuple with one entry per hidden layer selects a different class per
                layer. Defaults to None (no normalization).
            activation (Optional[Union[ModuleType, Sequence[ModuleType]]], optional):
                activation class applied after each hidden layer, or a list/tuple with
                one entry per hidden layer. Defaults to nn.ReLU.
            device (Optional[Union[str, int, torch.device]], optional): device inputs
                are moved to in :meth:`forward`. The parameters themselves are not
                moved here; call ``.to(device)`` on the module. Defaults to None.
            linear_layer (Type[nn.Linear], optional): class used for every linear
                layer. Defaults to nn.Linear.

        Raises:
            AssertionError: if a per-layer norm/activation sequence does not have
                ``len(hidden_sizes)`` entries.
        """
        super().__init__()
        self.device = device
        self.input_dim = input_dim
        num_hidden = len(hidden_sizes)
        norm_layer_list = self._expand_per_layer(norm_layer, num_hidden)
        activation_list = self._expand_per_layer(activation, num_hidden)
        sizes = [input_dim] + list(hidden_sizes)
        model: List[nn.Module] = []
        for in_dim, out_dim, norm, activ in zip(
            sizes[:-1], sizes[1:], norm_layer_list, activation_list
        ):
            model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
        if output_dim > 0:
            model += [linear_layer(sizes[-1], output_dim)]
        self.output_dim = output_dim or sizes[-1]
        self.model = nn.Sequential(*model)

    @staticmethod
    def _expand_per_layer(spec: Any, num_layers: int) -> List[Optional[ModuleType]]:
        """Turn None / one module class / a per-layer list or tuple into one entry per layer."""
        if not spec:
            return [None] * num_layers
        if isinstance(spec, (list, tuple)):
            assert len(spec) == num_layers, (
                f"expected {num_layers} per-layer modules, got {len(spec)}"
            )
            return list(spec)
        return [spec] * num_layers

    def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Run the MLP on ``obs``.

        Args:
            obs (Union[np.ndarray, torch.Tensor]): input batch. It is converted/cast to
                a float32 tensor and moved to ``self.device`` when that is set.

        Returns:
            torch.Tensor: output of the last layer.
        """
        if isinstance(obs, np.ndarray):
            obs = obs.copy()  # avoid not-writable warning here
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        if obs.dim() > 2 and obs.shape[-1] != self.input_dim:
            # Observations shaped (bsz, *state_shape) are flattened to (bsz, input_dim)
            # so they match the flat input size of the first layer. Inputs whose last
            # axis already equals input_dim (e.g. (bsz, len, input_dim)) are untouched.
            obs = obs.flatten(1)
        return self.model(obs)  # type: ignore
class Net(nn.Module):
    """MLP backbone with optional dueling heads and multi-atom (distributional) output.

    Args:
        state_shape: observation shape; the backbone input size is ``prod(state_shape)``.
        action_shape: action shape; the output size is ``prod(action_shape) * num_atoms``.
        hidden_sizes: hidden layer sizes of the backbone MLP.
        norm_layer: normalization class used after each hidden linear layer.
        activation: activation class used after each hidden layer.
        device: device passed to the underlying MLP(s).
        softmax: apply ``softmax`` over the last dimension of the output.
        concat: add ``prod(action_shape) * num_atoms`` to the input size and create no
            output layer (the output is the last hidden layer).
        num_atoms: atoms per action; when > 1 the output is viewed as
            ``(bsz, -1, num_atoms)``.
        dueling_param: ``(q_kwargs, v_kwargs)`` extra MLP kwargs for the Q (advantage)
            and V heads; enables the dueling architecture when given.
    """

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.softmax = softmax
        self.num_atoms = num_atoms
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape)) * num_atoms
        if concat:
            # presumably for critics fed a concatenated (obs, action) vector
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        # with dueling heads or concat the backbone gets no output layer of its own
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.model = MLP(
            input_dim, output_dim, hidden_sizes, norm_layer, activation, device
        )
        self.output_dim = self.model.output_dim
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = (0, 0) if concat else (action_dim, num_atoms)
            q_cfg: Dict[str, Any] = {
                **q_kwargs,
                "input_dim": self.output_dim,
                "output_dim": q_output_dim,
                "device": self.device,
            }
            v_cfg: Dict[str, Any] = {
                **v_kwargs,
                "input_dim": self.output_dim,
                "output_dim": v_output_dim,
                "device": self.device,
            }
            self.Q, self.V = MLP(**q_cfg), MLP(**v_cfg)
            self.output_dim = self.Q.output_dim

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: obs -> MLP backbone -> (optional dueling heads) -> logits.

        Args:
            obs: observation batch; the first dimension is the batch size.
            state: returned unchanged.
            info: unused; kept for interface compatibility.

        Returns:
            Tuple[torch.Tensor, Any]: ``(logits, state)``.
        """
        logits = self.model(obs)
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            # dueling aggregation: Q = A - mean_a(A) + V
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
class Recurrent(nn.Module):
    """LSTM network: ``fc1 -> LSTM -> fc2`` where ``fc2`` reads the last time step.

    Args:
        layer_num: number of stacked LSTM layers.
        state_shape: observation shape; ``fc1`` takes ``prod(state_shape)`` features,
            and the last axis of ``obs`` must already have that size (no flattening
            is performed).
        action_shape: the output size is ``prod(action_shape)``.
        device: device observations are moved to in :meth:`forward`.
        hidden_layer_size: width of the ``fc1`` output and of the LSTM hidden state.
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]],
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.device = device
        self.nn = nn.LSTM(
            input_size=hidden_layer_size,
            hidden_size=hidden_layer_size,
            num_layers=layer_num,
            batch_first=True,
        )
        self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size)
        self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape)))

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Dict[str, torch.Tensor]] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Mapping: obs -> fc1 -> LSTM -> fc2 (last time step) -> logits.

        Args:
            obs: ``[bsz, dim]`` in evaluation mode (treated as a length-1 sequence) or
                ``[bsz, len, dim]`` in training mode.
            state: optional dict with ``"hidden"`` and ``"cell"`` tensors in
                ``[bsz, num_layers, hidden_layer_size]`` layout, e.g. the state
                returned by a previous call. None starts from zero state.
            info: unused; kept for interface compatibility.

        Returns:
            The ``[bsz, prod(action_shape)]`` logits and a dict with detached
            ``"hidden"``/``"cell"`` tensors in ``[bsz, num_layers, hidden_layer_size]``
            layout.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,  # type: ignore
            dtype=torch.float32,
        )
        # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        if len(obs.shape) == 2:
            obs = obs.unsqueeze(-2)
        obs = self.fc1(obs)
        self.nn.flatten_parameters()
        if state is None:
            obs, (hidden, cell) = self.nn(obs)
        else:
            # the state is stored batch-first ([bsz, num_layers, ...]), but nn.LSTM
            # expects h0/c0 as [num_layers, bsz, ...] even with batch_first=True
            obs, (hidden, cell) = self.nn(
                obs,
                (
                    state["hidden"].transpose(0, 1).contiguous(),
                    state["cell"].transpose(0, 1).contiguous(),
                ),
            )
        obs = self.fc2(obs[:, -1])
        # keep the first dim as batch size: [bsz, num_layers, ...]
        return obs, {
            "hidden": hidden.transpose(0, 1).detach(),
            "cell": cell.transpose(0, 1).detach(),
        }
class ActorCritic(nn.Module):
    """Container module that owns an actor and a critic network.

    Registering both networks as submodules lets ``actor_critic.parameters()``
    yield the parameters of both, instead of combining them via set.union or
    list + list (see issue #449).

    Args:
        actor (nn.Module): The actor network.
        critic (nn.Module): The critic network.
    """

    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
        super(ActorCritic, self).__init__()
        # tuple assignment binds left to right: actor is registered before critic
        self.actor, self.critic = actor, critic
class DataParallelNet(nn.Module):
    """DataParallel wrapper for training agent with multi-GPU.

    This class only converts a non-tensor observation to a float32 tensor (moved to
    the current CUDA device when CUDA is available) before delegating to
    ``nn.DataParallel``. If the input is a nested dictionary, create a similar class
    that performs the conversion for it.

    Args:
        net (nn.Module): the network to be distributed across GPUs.
    """

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = nn.DataParallel(net)

    def forward(
        self, obs: Union[np.ndarray, torch.Tensor], *args: Any, **kwargs: Any
    ) -> Tuple[Any, Any]:
        """Convert ``obs`` and run the wrapped network.

        Args:
            obs: observation batch; non-tensor input becomes a float32 tensor.
            *args: extra positional arguments, forwarded after ``obs``.
            **kwargs: keyword arguments, forwarded unchanged.

        Returns:
            Whatever the wrapped network returns.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, dtype=torch.float32)
        if torch.cuda.is_available():
            obs = obs.cuda()
        # pass obs positionally: ``net(obs=..., *args)`` would bind args[0] to the
        # first parameter as well and raise "got multiple values for argument"
        return self.net(obs, *args, **kwargs)
def _parse_model_config_from_dict(**kwargs) -> Dict[str, Any]:
"""Parse a given raw dict configuration to capable parameter dict.
Raises:
NotImplementedError: _description_
Returns:
Dict[str, Any]: _description_
"""
res = {}
for k, v in kwargs.items():
if k in ["input_dim", "output_dim", "num_atoms", "hidden_layer_size"]:
res[k] = int(v)
elif k == "hidden_sizes":
assert isinstance(v, Sequence) and isinstance(v[0], int)
res[k] = copy.deepcopy(v)
elif k in ["norm_layer", "activation", "linear_layer"]:
# convert str to module obj
if isinstance(v, str):
res[k] = getattr(torch.nn, v)
elif isinstance(v, Sequence):
res[k] = [getattr(torch.nn, _v) for _v in v]
else:
raise TypeError(f"unexpected type for model configuration: {type(v)}")
elif k in ["softmax", "concat"]:
res[k] = bool(v)
elif k in ["dueling_param", "actor", "critic", "net"]:
res[k] = v if k != "dueling_param" else copy.deepcopy(v)
elif k in ["state_shape", "action_shape", "layer_num"]:
if isinstance(v, Sequence):
res[k] = [int(_v) for _v in v]
else:
res[k] = int(v)
return res
def _make_net_from_observation(
    observation_space: gym.Space, device: Device, parsed_model_config: Dict[str, Any]
) -> nn.Module:
    """Make a :class:`Net` whose input size is derived from ``observation_space``.

    ``state_shape`` is always taken from
    ``get_preprocessor(observation_space)(observation_space).shape`` (presumably a
    flattened shape; see ``malib.utils.preprocessor``), overriding any
    ``state_shape`` present in ``parsed_model_config``.

    Args:
        observation_space (gym.Space): Observation space.
        device (Device): Device passed to ``Net`` and used to place its parameters.
        parsed_model_config (Dict[str, Any]): Parsed model configuration dict; it is
            not modified.

    Returns:
        nn.Module: A ``Net`` instance on ``device``.
    """
    preprocessor = get_preprocessor(observation_space)(observation_space)
    # Build a fresh dict: the caller's config stays untouched, and a user-supplied
    # ``state_shape`` can no longer collide with the one derived here.
    net_kwargs: Dict[str, Any] = {
        **parsed_model_config,
        "device": device,
        "state_shape": preprocessor.shape,
    }
    # MLP.forward moves observations to ``device``; the parameters must live there too.
    return Net(**net_kwargs).to(device)
def _flat_dim(shape: Sequence[int]) -> int:
    """Number of elements in a tensor of ``shape`` (1 for an empty shape)."""
    return int(np.prod(shape))


def _default_action_shape(action_space: gym.Space) -> Tuple[int, ...]:
    """Output shape implied by ``action_space``: ``(n,)`` for Discrete, else its shape."""
    if isinstance(action_space, gym.spaces.Discrete):
        return (action_space.n,)
    return action_space.shape


def make_net(
    observation_space: gym.Space,
    action_space: gym.Space,
    device: Device,
    net_type: Optional[str] = None,
    **kwargs,
) -> nn.Module:
    """Create a network instance with specific network configuration.

    Args:
        observation_space (gym.Space): Observation space; its preprocessor shape sets
            the input size (``Discrete`` spaces count as size 1 for ``mlp``/``rnn``).
        action_space (gym.Space): Action space; sets the output size when
            ``output_dim`` (mlp) or ``action_shape`` (general_net / rnn) is not given.
        device (Device): Device the network is placed on.
        net_type (str, optional): One of ``mlp``, ``general_net`` (alias ``net``),
            ``rnn``, ``actor_critic``, ``data_parallel``. When None, a ``Net`` is
            built from the observation space. Defaults to None.
        **kwargs: Raw model configuration, normalized by
            ``_parse_model_config_from_dict`` (unrecognized keys are dropped).

    Raises:
        ValueError: Unexpected network type.

    Returns:
        nn.Module: A network instance on ``device``.
    """
    # ``device`` is handled explicitly below; the parser would drop it anyway.
    parsed_model_config = _parse_model_config_from_dict(**kwargs)
    if net_type is None:
        # place parameters on ``device`` like every other branch does
        return _make_net_from_observation(
            observation_space, device, parsed_model_config
        ).to(device)

    if net_type == "mlp":
        cls = MLP
        # the input dim is always derived from the observation space
        parsed_model_config["input_dim"] = (
            1
            if isinstance(observation_space, gym.spaces.Discrete)
            else _flat_dim(get_preprocessor(observation_space)(observation_space).shape)
        )
        if "output_dim" not in parsed_model_config:
            parsed_model_config["output_dim"] = _flat_dim(
                _default_action_shape(action_space)
            )
    elif net_type in ("general_net", "net"):
        cls = Net
        parsed_model_config["state_shape"] = get_preprocessor(observation_space)(
            observation_space
        ).shape
        if "action_shape" not in parsed_model_config:
            parsed_model_config["action_shape"] = _default_action_shape(action_space)
    elif net_type == "rnn":
        cls = Recurrent
        parsed_model_config["state_shape"] = (
            (1,)
            if isinstance(observation_space, gym.spaces.Discrete)
            else get_preprocessor(observation_space)(observation_space).shape
        )
        if "action_shape" not in parsed_model_config:
            parsed_model_config["action_shape"] = _default_action_shape(action_space)
    elif net_type == "actor_critic":
        cls = ActorCritic
    elif net_type == "data_parallel":
        cls = DataParallelNet
    else:
        raise ValueError(f"Unexpected net type: {net_type}")

    # ActorCritic and DataParallelNet take no ``device`` argument
    if net_type not in ("actor_critic", "data_parallel"):
        parsed_model_config["device"] = device
    return cls(**parsed_model_config).to(device)