Source code for malib.rl.common.policy

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ming Zhou

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from abc import ABCMeta, abstractmethod
from typing import Dict, Any, Tuple, Union
from enum import IntEnum

import torch
import torch.nn as nn

from gym import spaces

from malib.utils.preprocessor import get_preprocessor
from malib.common.distributions import make_proba_distribution, Distribution


[docs]class SimpleObject: def __init__(self, obj, name): # assert hasattr(obj, name), f"Object: {obj} has no such attribute named `{name}`" self.obj = obj self.name = name def __str__(self): return f"<SimpleObject, name={self.name}, obj={self.obj}>" def __repr__(self): return f"<SimpleObject, name={self.name}, obj={self.obj}>"
[docs] def load_state_dict(self, v): setattr(self.obj, self.name, v)
[docs] def state_dict(self): value = getattr(self.obj, self.name) return value
Action = Any ActionDist = Any Logits = Any
[docs]class Policy(metaclass=ABCMeta): def __init__( self, observation_space, action_space, model_config, custom_config, **kwargs ): _locals = locals() _locals.pop("self") self._init_args = _locals self._observation_space = observation_space self._action_space = action_space self._model_config = model_config or {} self._custom_config = custom_config or {} self._state_handler_dict = {} self._preprocessor = get_preprocessor( observation_space, mode=self._custom_config.get("preprocess_mode", "flatten"), )(observation_space) self._device = torch.device( "cuda" if self._custom_config.get("use_cuda") else "cpu" ) self._registered_networks: Dict[str, nn.Module] = {} if isinstance(action_space, spaces.Discrete): self.action_type = "discrete" elif isinstance(action_space, spaces.Box): self.action_type = "continuous" else: raise NotImplementedError( "Does not support other action space type settings except Box and Discrete. {}".format( type(action_space) ) ) self.use_cuda = self._custom_config.get("use_cuda", False) self.dist_fn: Distribution = make_proba_distribution( action_space=action_space, use_sde=custom_config.get("use_sde", False), dist_kwargs=custom_config.get("dist_kwargs", None), ) @property def model_config(self): return self._model_config @property def device(self) -> str: return self._device @property def custom_config(self) -> Dict[str, Any]: return self._custom_config @property def target_actor(self): return self._target_actor @target_actor.setter def target_actor(self, value: Any): self._target_actor = value @property def actor(self): return self._actor @actor.setter def actor(self, value: Any): self._actor = value @property def critic(self): return self._critic @critic.setter def critic(self, value: Any): self._critic = value @property def target_critic(self): return self._target_critic @target_critic.setter def target_critic(self, value: Any): self._target_critic = value
[docs] def load_state_dict(self, state_dict: Dict[str, Any]): """Load state dict outside. Args: state_dict (Dict[str, Any]): A dict of states. """ for k, v in state_dict.items(): self._state_handler_dict[k].load_state_dict(v)
[docs] def state_dict(self, device=None): """Return state dict in real time""" if device is None: res = {k: v.state_dict() for k, v in self._state_handler_dict.items()} else: res = {} for k, v in self._state_handler_dict.items(): if isinstance(v, torch.nn.Module): tmp = {} for _k, _v in v.state_dict().items(): tmp[_k] = _v.cpu() else: tmp = v.state_dict() res[k] = tmp return res
[docs] def register_state(self, obj: Any, name: str) -> None: """Register state of obj. Called in init function to register model states. Example: >>> class CustomPolicy(Policy): ... def __init__( ... self, ... registered_name, ... observation_space, ... action_space, ... model_config, ... custom_config ... ): ... # ... ... actor = MLP(...) ... self.register_state(actor, "actor") Args: obj (Any): Any object, for non `torch.nn.Module`, it will be wrapped as a `Simpleobject`. name (str): Humanreadable name, to identify states. Raises: errors.RepeatedAssignError: [description] """ # if not isinstance(obj, nn.Module): if obj.__class__.__module__ == "builtins": n = SimpleObject(self, name) n.load_state_dict(obj) obj = n self._state_handler_dict[name] = obj if isinstance(obj, nn.Module): self._registered_networks[name] = obj
[docs] def deregister_state(self, name: str): if self._state_handler_dict.get(name) is None: print(f"No such state tagged with: {name}") else: self._state_handler_dict.pop(name) print(f"Deregister state tagged with: {name}")
[docs] def get_initial_state(self, batch_size: int = None): return None
@property def preprocessor(self): return self._preprocessor
[docs] @abstractmethod def compute_action( self, observation: torch.Tensor, act_mask: Union[torch.Tensor, None], evaluate: bool, hidden_state: Any = None, **kwargs, ) -> Tuple[Action, ActionDist, Logits, Any]: pass
[docs] def save(self, path, global_step=0, hard: bool = False): state_dict = {"global_step": global_step, **self.state_dict()} torch.save(state_dict, path)
[docs] def load(self, path: str): state_dict = torch.load(path) print( f"[Model Loading] Load policy model with global step={state_dict.pop('global_step')}" ) self.load_state_dict(state_dict)
[docs] def reset(self, **kwargs): """Reset parameters or behavior policies.""" pass
[docs] @classmethod def copy(cls, instance, replacement: Dict): return cls(replacement=replacement, **instance._init_args)
@property def registered_networks(self) -> Dict[str, nn.Module]: return self._registered_networks
[docs] def to(self, device: str = None, use_copy: bool = False) -> "Policy": """Convert policy to a given device. If `use_copy`, then return a copy. If device is None, do not change device. Args: device (str): Device identifier. use_copy (bool, optional): User copy or not. Defaults to False. Raises: NotImplementedError: Not implemented error. Returns: Policy: A policy instance """ if device is None: device = "cpu" if not self.use_cuda else "cuda" cond1 = "cpu" in device and self.use_cuda cond2 = "cuda" in device and not self.use_cuda if "cpu" in device: use_cuda = False else: use_cuda = self._custom_config.get("use_cuda", False) replacement = {} if cond1 or cond2: # retrieve networks here for k, v in self.registered_networks.items(): _v = v.to(device) if not use_copy: setattr(self, k, _v) else: replacement[k] = _v else: # fixed bug: replacement cannot be None. for k, v in self.registered_networks.items(): replacement[k] = v if use_copy: ret = self.copy(self, replacement=replacement) else: self.use_cuda = use_cuda ret = self return ret
[docs] def parameters(self) -> Dict[str, Dict]: """Return trainable parameters.""" res = {} for name, net in self.registered_networks.items(): res[name] = net.parameters() return res
[docs] def update_parameters(self, parameter_dict: Dict[str, Any]): """Update local parameters with given parameter dict. Args: parameter_dict (Dict[str, Parameter]): A dict of paramters """ for k, parameters in parameter_dict.items(): target = self.registered_networks[k] for target_param, param in zip(target.parameters(), parameters): target_param.data.copy_(param.data)
[docs] def coordinate(self, state: Dict[str, torch.Tensor], message: Any) -> Any: """Coordinate with other agents here""" raise NotImplementedError