malib.algorithm.common.reward module

Implementation of the basic PyTorch-based reward class.

class malib.algorithm.common.reward.Reward(registered_name: str, reward_type: str, observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space, model_config: Optional[Dict[str, Any]] = None, custom_config: Optional[Dict[str, Any]] = None)

Bases: object

Create a reward model instance.

Parameters
  • registered_name (str) – Registered policy name.

  • reward_type (str) – Concrete type of the reward function used in practice.

  • observation_space (gym.spaces.Space) – Raw observation space of related environment agent(s), determines the model input space.

  • action_space (gym.spaces.Space) – Raw action space of related environment agent(s).

  • model_config (Dict[str,Any]) – Model configuration used to construct models. Defaults to None.

  • custom_config (Dict[str,Any]) – Custom configuration, including some hyper-parameters. Defaults to None.

clip_rewards(rewards)

abstract compute_reward(observation: numpy.ndarray, action: numpy.ndarray, **kwargs) → Tuple[Any]

Compute a single reward at each rollout step. Returns one element: the reward.

abstract compute_rewards(observation: numpy.ndarray, action: numpy.ndarray, **kwargs) → numpy.ndarray

Compute batched rewards for the current policy with given inputs.

Legal keys in kwargs:

  • behavior_mode: the behavior mode, used to distinguish between different behaviors when computing actions.

  • action_mask: action mask.
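
A minimal sketch of how the two abstract methods might relate, assuming the single-step method can delegate to the batched one. RewardSketch and NegativeEffortReward are hypothetical stand-ins written for this illustration, not MALib classes, and the toy reward (negative squared action magnitude) is an assumption. The kwargs are accepted but not interpreted here.

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple

import numpy as np


class RewardSketch(ABC):
    """Hypothetical stand-in that mirrors only the two documented signatures."""

    @abstractmethod
    def compute_reward(
        self, observation: np.ndarray, action: np.ndarray, **kwargs
    ) -> Tuple[Any]:
        ...

    @abstractmethod
    def compute_rewards(
        self, observation: np.ndarray, action: np.ndarray, **kwargs
    ) -> np.ndarray:
        ...


class NegativeEffortReward(RewardSketch):
    """Toy reward (assumption for illustration): penalize action magnitude."""

    def compute_rewards(self, observation, action, **kwargs):
        # Batched path: one reward per row of `action`, result shape (batch,).
        return -np.sum(np.square(action), axis=-1)

    def compute_reward(self, observation, action, **kwargs):
        # Single-step path: add a batch axis, reuse the batched path,
        # and return a one-element tuple as the signature suggests.
        rewards = self.compute_rewards(observation[None], action[None], **kwargs)
        return (float(rewards[0]),)


reward_fn = NegativeEffortReward()
batch_obs = np.zeros((2, 3))
batch_act = np.array([[1.0, 2.0], [0.0, 0.0]])
batch_rewards = reward_fn.compute_rewards(batch_obs, batch_act)
step_reward = reward_fn.compute_reward(batch_obs[0], batch_act[0])
```

Delegating from compute_reward to compute_rewards keeps the per-step and batched results consistent by construction; a real subclass may need separate implementations.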

deregister_state(name: str)

property description

Return a dict of basic attributes that identify the reward.

The essential elements of the returned description:

  • registered_name: self.registered_name

  • observation_space: self.observation_space

  • action_space: self.action_space

  • model_config: self.model_config

  • custom_config: self.custom_config

Returns

A dictionary.
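
The shape of the returned dictionary can be sketched as follows. DescribedReward is a hypothetical stand-in that only mirrors the attribute names listed above, and the string placeholders stand for what would be gym.spaces.Space objects in practice.

```python
from typing import Any, Dict, Optional


class DescribedReward:
    """Hypothetical stand-in; it only mirrors the documented attribute names."""

    def __init__(
        self,
        registered_name: str,
        observation_space: Any,
        action_space: Any,
        model_config: Optional[Dict[str, Any]] = None,
        custom_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.registered_name = registered_name
        self.observation_space = observation_space
        self.action_space = action_space
        self.model_config = model_config
        self.custom_config = custom_config

    @property
    def description(self) -> Dict[str, Any]:
        # One entry per essential element listed in the documentation.
        return {
            "registered_name": self.registered_name,
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "model_config": self.model_config,
            "custom_config": self.custom_config,
        }


# Placeholder strings are used instead of real spaces (assumption).
toy = DescribedReward("toy_reward", "obs_space", "act_space", {"hidden": 64})
desc = toy.description
```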

eval()

property exploration_callback: Callable

load_state(state_dict: Dict[str, Any]) → None

Load a state dict supplied from outside.

Note that every key in state_dict must already exist in the state handler.

Parameters

state_dict (Dict[str, Any]) – A dict of states, keyed by registered state name.

Raises

KeyError
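
The key check described above can be sketched as follows. StateHandlerSketch is a hypothetical stand-in, and replacing the stored object with the loaded value is an assumption made for brevity; the actual class may instead forward each value to the registered object.

```python
from typing import Any, Dict


class StateHandlerSketch:
    """Hypothetical stand-in for the state handler behind load_state."""

    def __init__(self) -> None:
        self._state_handler: Dict[str, Any] = {}

    def register_state(self, obj: Any, name: str) -> None:
        self._state_handler[name] = obj

    def load_state(self, state_dict: Dict[str, Any]) -> None:
        # Validate every key first so a bad key leaves existing states untouched.
        unknown = [key for key in state_dict if key not in self._state_handler]
        if unknown:
            raise KeyError(f"unregistered state name(s): {unknown}")
        self._state_handler.update(state_dict)


handler = StateHandlerSketch()
handler.register_state({"w": 0}, "actor")
handler.load_state({"actor": {"w": 1}})  # accepted: "actor" is registered

try:
    handler.load_state({"critic": {"w": 2}})  # rejected: "critic" is unknown
    rejected = False
except KeyError:
    rejected = True
```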

register_state(obj: Any, name: str) → None

Register the state of obj. Call this in the __init__ function to register model states.

Example:

>>> class CustomReward(Reward):
...     def __init__(
...         self,
...         registered_name,
...         reward_type,
...         observation_space,
...         action_space,
...         model_config,
...         custom_config
...     ):
...         # ...
...         actor = MLP(...)
...         self.register_state(actor, "actor")

Parameters
  • obj (Any) – Any object. An object that is not a torch.nn.Module is wrapped as a Simpleobject.

  • name (str) – Human-readable name used to identify the state.

Raises

malib.utils.errors.RepeatedAssign

Returns

None
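
A minimal sketch of the duplicate-name rule, assuming names must be unique among registered states. RepeatedAssignError and StateRegistrySketch are hypothetical stand-ins for malib.utils.errors.RepeatedAssign and the Reward class. The deregister_state behavior shown (freeing the name for reuse) is also an assumption, since it is not documented here.

```python
from typing import Any, Dict


class RepeatedAssignError(Exception):
    """Hypothetical stand-in for malib.utils.errors.RepeatedAssign."""


class StateRegistrySketch:
    """Hypothetical stand-in that tracks registered states by name."""

    def __init__(self) -> None:
        self._states: Dict[str, Any] = {}

    def register_state(self, obj: Any, name: str) -> None:
        # Registering the same name twice is treated as an error.
        if name in self._states:
            raise RepeatedAssignError(f"state '{name}' is already registered")
        self._states[name] = obj

    def deregister_state(self, name: str) -> None:
        # Assumption: removing a state frees its name for re-registration.
        self._states.pop(name)


registry = StateRegistrySketch()
registry.register_state(object(), "actor")

try:
    registry.register_state(object(), "actor")  # same name again
    duplicate_rejected = False
except RepeatedAssignError:
    duplicate_rejected = True

registry.deregister_state("actor")
registry.register_state(object(), "actor")  # allowed after deregistration
```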

set_weights(parameters: Dict[str, Any])

Set parameter weights.

Parameters

parameters (Dict[str, Any]) – A dict of parameters.

Returns

state_dict()

Return the state dict in real time.

train()