malib.algorithm.common.reward module
Implementation of the basic PyTorch-based reward model class.
- class malib.algorithm.common.reward.Reward(registered_name: str, reward_type: str, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, model_config: Optional[Dict[str, Any]] = None, custom_config: Optional[Dict[str, Any]] = None)
Bases: object
Create a reward model instance.
- Parameters
registered_name (str) – Registered name of the reward model.
reward_type (str) – Detailed type of the reward function used in practice.
observation_space (gym.spaces.Space) – Raw observation space of related environment agent(s), determines the model input space.
action_space (gym.spaces.Space) – Raw action space of related environment agent(s).
model_config (Dict[str,Any]) – Model configuration used to construct models. Defaults to None.
custom_config (Dict[str,Any]) – Custom configuration, which may include hyper-parameters. Defaults to None.
- abstract compute_reward(observation: numpy.ndarray, action: numpy.ndarray, **kwargs) → Tuple[Any]
Compute a single reward at each rollout step. Returns a tuple with one element: the reward.
- abstract compute_rewards(observation: numpy.ndarray, action: numpy.ndarray, **kwargs) → numpy.ndarray
Compute batched rewards for the current policy with the given inputs.
Legal keys in kwargs:
behavior_mode: the behavior mode, used to distinguish different behaviors when computing actions.
action_mask: the action mask.
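A minimal sketch of how a subclass might split the two abstract methods, assuming NumPy and a hypothetical reward (the negative squared L2 norm of the action). `RewardBase` and `ActionPenaltyReward` are hypothetical names; the stand-in base only mirrors the two documented method signatures so the example runs without malib installed.

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple

import numpy as np


class RewardBase(ABC):
    """Hypothetical stand-in mirroring the two documented abstract methods."""

    @abstractmethod
    def compute_reward(self, observation: np.ndarray, action: np.ndarray, **kwargs) -> Tuple[Any]:
        ...

    @abstractmethod
    def compute_rewards(self, observation: np.ndarray, action: np.ndarray, **kwargs) -> np.ndarray:
        ...


class ActionPenaltyReward(RewardBase):
    """Hypothetical reward: negative squared L2 norm of the action."""

    def compute_reward(self, observation, action, **kwargs):
        # Single step: return a one-element tuple holding the scalar reward.
        return (-float(np.sum(np.square(action))),)

    def compute_rewards(self, observation, action, **kwargs):
        # Batched: the leading axis is the batch; sum squares over the rest.
        action = np.asarray(action, dtype=np.float64)
        flat = np.square(action).reshape(action.shape[0], -1)
        return -np.sum(flat, axis=1)
```

For a single step, `compute_reward(np.zeros(3), np.array([1.0, 2.0]))` returns `(-5.0,)`; the batched call treats the leading axis as the batch dimension and returns one reward per row.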
- property description
Return a dict of basic attributes that identify the reward.
The essential elements of the returned description:
registered_name: self.registered_name
observation_space: self.observation_space
action_space: self.action_space
model_config: self.model_config
custom_config: self.custom_config
- Returns
A dictionary.
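A minimal sketch of the `description` property, using a hypothetical stand-in class `DescribedReward` with plain strings in place of `gym.spaces.Space` objects. It returns exactly the five keys listed above; the real class may include further entries.

```python
from typing import Any, Dict, Optional


class DescribedReward:
    """Hypothetical stand-in that keeps its constructor arguments."""

    def __init__(
        self,
        registered_name: str,
        observation_space: Any,
        action_space: Any,
        model_config: Optional[Dict[str, Any]] = None,
        custom_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.registered_name = registered_name
        self.observation_space = observation_space
        self.action_space = action_space
        self.model_config = model_config
        self.custom_config = custom_config

    @property
    def description(self) -> Dict[str, Any]:
        # The five essential elements listed in the documentation.
        return {
            "registered_name": self.registered_name,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "model_config": self.model_config,
            "custom_config": self.custom_config,
        }
```

Because the dict holds only constructor arguments, two stand-in instances built with equal arguments have equal descriptions, which is one way to identify a reward.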
- property exploration_callback: Callable
- load_state(state_dict: Dict[str, Any]) → None
Load a state dict from outside.
Note that every key in state_dict must already exist in the state handler.
- Parameters
state_dict (Dict[str, Any]) – A dict of states, keyed by registered state name.
- Raise
KeyError
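A minimal sketch of the `load_state` contract, assuming registered states live in a name-keyed dict (the `StateHolder` class and `_state_handler` attribute are hypothetical). Objects exposing `load_state_dict`, as `torch.nn.Module` does, are restored in place; any other value is replaced, which is an assumption rather than documented behavior.

```python
from typing import Any, Dict


class StateHolder:
    """Hypothetical stand-in illustrating the load_state contract."""

    def __init__(self) -> None:
        # Registered name -> object (attribute name is an assumption).
        self._state_handler: Dict[str, Any] = {}

    def load_state(self, state_dict: Dict[str, Any]) -> None:
        # Validate every key first, so a bad dict changes nothing.
        unknown = [k for k in state_dict if k not in self._state_handler]
        if unknown:
            raise KeyError(f"unregistered state keys: {unknown}")
        for name, value in state_dict.items():
            target = self._state_handler[name]
            if hasattr(target, "load_state_dict"):
                # torch.nn.Module-like objects restore their state in place.
                target.load_state_dict(value)
            else:
                # Plain values are simply replaced (assumed behavior).
                self._state_handler[name] = value
```

This sketch checks all keys before loading anything, so a dict with one unknown key leaves every state untouched; the documentation only states that `KeyError` is raised, not whether loading is all-or-nothing.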
- register_state(obj: Any, name: str) → None
Register the state of obj. Call this in the __init__ method to register model states.
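- Example:
>>> class CustomReward(Reward):
...     def __init__(
...         self,
...         registered_name,
...         reward_type,
...         observation_space,
...         action_space,
...         model_config,
...         custom_config,
...     ):
...         super().__init__(
...             registered_name, reward_type, observation_space,
...             action_space, model_config, custom_config,
...         )
...         actor = MLP(...)  # model construction elided
...         self.register_state(actor, "actor")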
- Parameters
obj (Any) – Any object. An object that is not a torch.nn.Module will be wrapped as a Simpleobject.
name (str) – Human-readable name used to identify the state.
- Raise
malib.utils.errors.RepeatedAssign
- Returns
None
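A minimal sketch of the `register_state` contract, assuming duck typing (an object exposing `state_dict` is stored as-is, as a `torch.nn.Module` would be) instead of an explicit `torch.nn.Module` check, so that it runs without PyTorch. `StateRegistry`, `SimpleObject`, and the `_state_handler` attribute are hypothetical; `RepeatedAssign` is a local stand-in for `malib.utils.errors.RepeatedAssign`.

```python
from typing import Any, Dict


class RepeatedAssign(Exception):
    """Local stand-in for malib.utils.errors.RepeatedAssign."""


class SimpleObject:
    """Hypothetical wrapper giving a plain value a state_dict-like interface."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.value = state["value"]


class StateRegistry:
    """Hypothetical stand-in illustrating the register_state contract."""

    def __init__(self) -> None:
        self._state_handler: Dict[str, Any] = {}

    def register_state(self, obj: Any, name: str) -> None:
        # Refuse to overwrite an existing registration.
        if name in self._state_handler:
            raise RepeatedAssign(name)
        # Duck-typing assumption: objects without state_dict get wrapped.
        if not hasattr(obj, "state_dict"):
            obj = SimpleObject(obj)
        self._state_handler[name] = obj
```

Registering the same name twice raises instead of overwriting, which protects a subclass from silently replacing, for example, its "actor" state.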