"""
Implementation of basic PyTorch-based policy class
"""
import gym
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from malib.utils import errors
from malib.utils.typing import (
DataTransferType,
ModelConfig,
Dict,
Any,
Tuple,
Callable,
List,
)
from malib.utils.preprocessor import get_preprocessor, Mode
from malib.utils.notations import deprecated
[docs]class SimpleObject:
def __init__(self, obj, name):
assert hasattr(obj, name), f"Object: {obj} has no such attribute named `{name}`"
self.obj = obj
self.name = name
[docs] def load_state_dict(self, v):
setattr(self.obj, self.name, v)
[docs] def state_dict(self):
value = getattr(self.obj, self.name)
return value
DEFAULT_MODEL_CONFIG = {
"actor": {
"network": "mlp",
"layers": [
{"units": 64, "activation": "ReLU"},
{"units": 64, "activation": "ReLU"},
],
"output": {"activation": False},
},
"critic": {
"network": "mlp",
"layers": [
{"units": 64, "activation": "ReLU"},
{"units": 64, "activation": "ReLU"},
],
"output": {"activation": False},
},
}
[docs]class Policy(metaclass=ABCMeta):
def __init__(
self,
registered_name: str,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
model_config: ModelConfig = None,
custom_config: Dict[str, Any] = None,
**kwargs,
):
"""Create a policy instance.
:param str registered_name: Registered policy name.
:param gym.spaces.Space observation_space: Raw observation space of related environment agent(s), determines
the model input space.
:param gym.spaces.Space action_space: Raw action space of related environment agent(s).
:param Dict[str,Any] model_config: Model configuration to construct models. Default to None.
:param Dict[str,Any] custom_config: Custom configuration, includes some hyper-parameters. Default to None.
"""
self.registered_name = registered_name
self.observation_space = observation_space
self.action_space = action_space
self.device = torch.device("cpu")
self.custom_config = {
"gamma": 0.99,
"use_cuda": False,
"use_dueling": False,
"preprocess_mode": Mode.FLATTEN,
}
self.model_config = DEFAULT_MODEL_CONFIG
if custom_config is None:
custom_config = {}
self.custom_config.update(custom_config)
# FIXME(ming): use deep update rule
if model_config is None:
model_config = {}
self.model_config.update(model_config)
self.preprocessor = get_preprocessor(
observation_space, self.custom_config["preprocess_mode"]
)(observation_space)
self._state_handler_dict = {}
self._actor = None
self._critic = None
self._exploration_callback = None
self._kwargs = kwargs
@property
def exploration_callback(self) -> Callable:
return self._exploration_callback
[docs] def register_state(self, obj: Any, name: str) -> None:
"""Register state of obj. Called in init function to register model states.
Example:
>>> class CustomPolicy(Policy):
... def __init__(
... self,
... registered_name,
... observation_space,
... action_space,
... model_config,
... custom_config
... ):
... # ...
... actor = MLP(...)
... self.register_state(actor, "actor")
:param Any obj: Any object, for non `torch.nn.Module`, it will be wrapped as a `Simpleobject`.
:param str name: Humanreadable name, to identify states.
:raise: malib.utils.errors.RepeatedAssign
:return: None
"""
if not isinstance(obj, nn.Module):
obj = SimpleObject(self, name)
if self._state_handler_dict.get(name, None) is not None:
raise errors.RepeatedAssignError(
f"state handler named with {name} is not None."
)
self._state_handler_dict[name] = obj
[docs] def deregister_state(self, name: str):
if self._state_handler_dict.get(name) is None:
print(f"No such state tagged with: {name}")
else:
self._state_handler_dict.pop(name)
print(f"Deregister state tagged with: {name}")
@property
def description(self):
"""Return a dict of basic attributes to identify policy.
The essential elements of returned description:
- registered_name: `self.registered_name`
- observation_space: `self.observation_space`
- action_space: `self.action_space`
- model_config: `self.model_config`
- custom_config: `self.custom_config`
:return: A dictionary.
"""
return {
"registered_name": self.registered_name,
"observation_space": self.observation_space,
"action_space": self.action_space,
"model_config": self.model_config,
"custom_config": self.custom_config,
}
[docs] @abstractmethod
def compute_actions(
self, observation: DataTransferType, **kwargs
) -> DataTransferType:
"""Compute batched actions for the current policy with given inputs.
Legal keys in kwargs:
- behavior_mode: behavior mode used to distinguish different behavior of compute actions.
- action_mask: action mask.
"""
[docs] @abstractmethod
def compute_action(
self, observation: DataTransferType, **kwargs
) -> Tuple[DataTransferType, DataTransferType, List[DataTransferType]]:
"""Compute single action when rollout at each step, return 3 elements:
action, action_dist, a list of rnn_state
"""
[docs] def get_initial_state(self, batch_size: int = None) -> List[DataTransferType]:
"""Return a list of rnn states if models are rnns"""
return []
[docs] def state_dict(self):
"""Return state dict in real time"""
res = {k: v.state_dict() for k, v in self._state_handler_dict.items()}
return res
[docs] def load_state(self, state_dict: Dict[str, Any]) -> None:
"""Load state dict outside.
Note that the keys in `state_dict` should be existed in state handler.
:param state_dict: Dict[str, Any], A dict of state dict
:raise: KeyError
"""
for k, v in state_dict.items():
self._state_handler_dict[k].load_state_dict(v)
# XXX(ziyu): Add tests for it.
[docs] def set_weights(self, parameters: Dict[str, Any]):
"""Set parameter weights.
:param parameters: Dict[str, Any], A dict of parameters.
:return:
"""
for k, v in parameters.items():
# FIXME(ming): strict mode for parameter reload
self._state_handler_dict[k].load_state_dict(v)
[docs] def set_actor(self, actor) -> None:
"""Set actor. Note repeated assign will raise a warning
:raise RuntimeWarning, repeated assign.
"""
# if self._actor is not None:
# raise RuntimeWarning("repeated actor assign")
self._actor = actor
[docs] def set_critic(self, critic):
"""Set critic"""
# if self._critic is not None:
# raise RuntimeWarning("repeated critic assign")
self._critic = critic
@property
def actor(self) -> Any:
"""Return policy, cannot be None"""
return self._actor
@property
def critic(self) -> Any:
"""Return critic, can be None"""
return self._critic
[docs] @deprecated
def train(self):
pass
[docs] @deprecated
def eval(self):
pass
# @abstractmethod
[docs] def reset(self):
"""Reset policy intermediates"""
pass
[docs] def to_device(self, device):
self.device = device
return self
[docs] def value_function(self, *args, **kwargs):
"""Compute values of critic."""
raise NotImplementedError