malib.algorithm.common.policy module
Implementation of basic PyTorch-based policy class
- class malib.algorithm.common.policy.Policy(registered_name: str, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, model_config: Optional[Dict[str, Any]] = None, custom_config: Optional[Dict[str, Any]] = None, **kwargs)
Bases: object
Create a policy instance.
- Parameters
registered_name (str) – Registered policy name.
observation_space (gym.spaces.Space) – Raw observation space of related environment agent(s), determines the model input space.
action_space (gym.spaces.Space) – Raw action space of related environment agent(s).
model_config (Dict[str,Any]) – Model configuration used to construct models. Defaults to None.
custom_config (Dict[str,Any]) – Custom configuration, including some hyper-parameters. Defaults to None.
- property actor: Any
Return the actor (the policy model); cannot be None.
- abstract compute_action(observation: numpy.ndarray, **kwargs) → Tuple[numpy.ndarray, numpy.ndarray, List[numpy.ndarray]]
Compute a single action at each rollout step. Returns three elements: action, action_dist, and a list of rnn_state.
- abstract compute_actions(observation: numpy.ndarray, **kwargs) → numpy.ndarray
Compute batched actions for the current policy with given inputs.
Legal keys in kwargs:
behavior_mode: behavior mode, used to distinguish between different behaviors when computing actions.
action_mask: action mask.
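As an illustration of how an action_mask might be used inside a compute_actions implementation, here is a minimal sketch. It is not MALib's actual code: the helper name masked_greedy_actions, the logits input, and the convention that 1 marks a legal action are all assumptions made for this example.

```python
import numpy as np


def masked_greedy_actions(logits: np.ndarray, action_mask: np.ndarray) -> np.ndarray:
    """Pick the highest-scoring legal action for each row of a batch.

    Hypothetical helper: assumes `logits` and `action_mask` share the shape
    (batch_size, num_actions) and that 1 marks a legal action, 0 an illegal one.
    """
    # Illegal actions are set to -inf so argmax can never select them.
    masked = np.where(action_mask.astype(bool), logits, -np.inf)
    return masked.argmax(axis=-1)


logits = np.array([[0.1, 2.0, 0.3], [1.5, 0.2, 0.9]])
action_mask = np.array([[1, 0, 1], [0, 1, 0]])
actions = masked_greedy_actions(logits, action_mask)
```

In both rows the unmasked best action is illegal, so the result differs from a plain argmax over logits. A sampling policy would typically apply the same mask before normalizing the distribution.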
- property critic: Any
Return the critic; can be None.
- property description
Return a dict of basic attributes to identify policy.
The essential elements of returned description:
registered_name: self.registered_name
observation_space: self.observation_space
action_space: self.action_space
model_config: self.model_config
custom_config: self.custom_config
- Returns
A dictionary.
- property exploration_callback: Callable
- get_initial_state(batch_size: Optional[int] = None) → List[numpy.ndarray]
Return a list of RNN states if the models are RNNs.
- load_state(state_dict: Dict[str, Any]) → None
Load a state dict from outside.
Note that the keys in state_dict must already exist in the state handler.
- Parameters
state_dict (Dict[str, Any]) – A dict of state dicts.
- Raise
KeyError
- register_state(obj: Any, name: str) → None
Register the state of obj. Called in the init function to register model states.
- Example:
>>> class CustomPolicy(Policy):
...     def __init__(
...         self,
...         registered_name,
...         observation_space,
...         action_space,
...         model_config,
...         custom_config
...     ):
...         # ...
...         actor = MLP(...)
...         self.register_state(actor, "actor")
- Parameters
obj (Any) – Any object. An object that is not a torch.nn.Module will be wrapped as a Simpleobject.
name (str) – Human-readable name used to identify the state.
- Raise
malib.utils.errors.RepeatedAssign
- Returns
None
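The interplay between register_state and load_state can be sketched with a small stand-in class. This is only an illustration under stated assumptions, not the library's implementation: StateRegistry and get_state are hypothetical names, RepeatedAssign is redefined locally in place of malib.utils.errors.RepeatedAssign, and the sketch assumes that the error is triggered by registering the same name twice.

```python
from typing import Any, Dict


class RepeatedAssign(Exception):
    """Local stand-in for malib.utils.errors.RepeatedAssign."""


class StateRegistry:
    """Hypothetical stand-in for the state handling described above."""

    def __init__(self) -> None:
        self._states: Dict[str, Any] = {}

    def register_state(self, obj: Any, name: str) -> None:
        # Assumption: registering an existing name is the "repeated assign" case.
        if name in self._states:
            raise RepeatedAssign(f"state '{name}' is already registered")
        self._states[name] = obj

    def load_state(self, state_dict: Dict[str, Any]) -> None:
        # Every incoming key must already be registered, otherwise KeyError.
        for name in state_dict:
            if name not in self._states:
                raise KeyError(name)
        # Assumption: loading simply replaces the stored objects.
        self._states.update(state_dict)

    def get_state(self, name: str) -> Any:
        return self._states[name]


registry = StateRegistry()
registry.register_state({"weight": 0.0}, "actor")
registry.load_state({"actor": {"weight": 1.0}})
loaded = registry.get_state("actor")

try:
    registry.load_state({"critic": {}})  # "critic" was never registered
    unknown_key_rejected = False
except KeyError:
    unknown_key_rejected = True

try:
    registry.register_state({"weight": 2.0}, "actor")  # same name again
    repeated_rejected = False
except RepeatedAssign:
    repeated_rejected = True
```

Validating all keys before updating keeps a failed load from leaving the registry half-updated, which is one reasonable design choice for this kind of handler.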
- set_actor(actor) → None
Set the actor. Note that repeated assignment raises a warning.
- Raise
RuntimeWarning – on repeated assignment.
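A setter that warns on repeated assignment could look like the following minimal sketch. ActorHolder is a hypothetical stand-in, not the Policy class itself, and two details are assumptions: that the warning fires only when an actor is already set, and that the new actor still replaces the old one.

```python
import warnings
from typing import Any, Optional


class ActorHolder:
    """Hypothetical stand-in showing a setter that warns on repeated assignment."""

    def __init__(self) -> None:
        self._actor: Optional[Any] = None

    def set_actor(self, actor: Any) -> None:
        # Assumption: warn if an actor is already set, then overwrite it anyway.
        if self._actor is not None:
            warnings.warn("actor is already set; overwriting", RuntimeWarning)
        self._actor = actor


holder = ActorHolder()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure no warning is filtered out
    holder.set_actor("first_actor")   # first assignment: no warning
    holder.set_actor("second_actor")  # repeated assignment: RuntimeWarning

warning_categories = [w.category for w in caught]
```

Callers that want repeated assignment to fail hard can promote the warning to an exception with warnings.simplefilter("error", RuntimeWarning).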