malib.rl.pg package

Submodules

malib.rl.pg.config module

malib.rl.pg.policy module

class malib.rl.pg.policy.PGPolicy(observation_space: Space, action_space: Space, model_config: Dict[str, Any], custom_config: Dict[str, Any], **kwargs)

Bases: Policy

Build a REINFORCE policy whose input and output dims are determined by observation_space and action_space, respectively.

Parameters:
  • observation_space (spaces.Space) – The observation space.

  • action_space (spaces.Space) – The action space.

  • model_config (Dict[str, Any]) – The model configuration dict.

  • custom_config (Dict[str, Any]) – The custom configuration dict.

  • is_fixed (bool, optional) – Whether the policy is fixed (not trainable). Defaults to False.

Raises:
  • NotImplementedError – Raised for action space types other than Box and Discrete.

  • TypeError – Unexpected action space.
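
The dimension inference described above can be sketched without malib. In the sketch below, `Space`, `Discrete`, `Box` and `MultiBinary` are hypothetical stand-ins for the corresponding gym spaces, and `infer_io_dims` is an illustrative helper, not part of malib's API. It assumes flattened observations, one output logit per discrete action, and one output (for example a Gaussian mean) per continuous action dimension; how malib splits TypeError from NotImplementedError is also an assumption here.

```python
from dataclasses import dataclass
from math import prod
from typing import Tuple


class Space:
    """Hypothetical base class standing in for gym.spaces.Space."""


@dataclass
class Discrete(Space):
    n: int  # number of discrete actions


@dataclass
class Box(Space):
    shape: Tuple[int, ...]  # bounds omitted for brevity


@dataclass
class MultiBinary(Space):
    n: int  # example of a space type the policy does not support


def infer_io_dims(observation_space: Box, action_space: Space) -> Tuple[int, int]:
    """Return (input_dim, output_dim) for a policy network (illustrative)."""
    input_dim = prod(observation_space.shape)  # flattened observation size
    if not isinstance(action_space, Space):
        # Assumed: anything that is not a space at all is a TypeError.
        raise TypeError(f"Unexpected action space: {action_space!r}")
    if isinstance(action_space, Discrete):
        return input_dim, action_space.n  # one logit per action
    if isinstance(action_space, Box):
        return input_dim, prod(action_space.shape)  # one output per action dim
    # Mirrors the documented restriction to Box and Discrete.
    raise NotImplementedError(
        f"Unsupported action space: {type(action_space).__name__}"
    )


print(infer_io_dims(Box((4,)), Discrete(2)))  # (4, 2)
print(infer_io_dims(Box((3, 3)), Box((2,))))  # (9, 2)
```
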

compute_action(observation: Tensor, act_mask: Optional[Tensor], evaluate: bool, hidden_state: Optional[Any] = None, **kwargs) → Tuple[Any, Any, Any, Any]

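
The docstring does not describe what compute_action does internally. For a discrete action space, a REINFORCE policy typically turns logits into a categorical distribution; the sketch below shows how a mask like act_mask could be applied before sampling. It assumes that a mask entry of 1 marks a legal action and that evaluate=True selects the greedy action; both are assumptions, `masked_action` is a hypothetical helper, and it does not reproduce the four-element tuple malib returns.

```python
import math
import random
from typing import Optional, Sequence


def masked_action(
    logits: Sequence[float],
    act_mask: Optional[Sequence[int]] = None,
    evaluate: bool = False,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick an action index from logits, ignoring masked-out actions."""
    if act_mask is None:
        act_mask = [1] * len(logits)  # no mask: every action is legal
    legal = [i for i, m in enumerate(act_mask) if m]
    if not legal:
        raise ValueError("act_mask leaves no legal action")

    if evaluate:
        # Greedy choice among legal actions (assumed evaluation behaviour).
        return max(legal, key=lambda i: logits[i])

    # Softmax over legal actions only; subtracting the max keeps exp() stable.
    top = max(logits[i] for i in legal)
    weights = [math.exp(logits[i] - top) for i in legal]
    rng = rng or random.Random()
    return rng.choices(legal, weights=weights, k=1)[0]


print(masked_action([2.0, 5.0, 1.0], act_mask=[1, 0, 1], evaluate=True))  # 0
```

Masking before the softmax, rather than zeroing probabilities afterwards, guarantees that an illegal action can never be sampled.
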
value_function(observation: Tensor, evaluate: bool, **kwargs)

Compute the critic's value estimates.

malib.rl.pg.trainer module

class malib.rl.pg.trainer.PGTrainer(training_config: Dict[str, Any], policy_instance: Optional[Policy] = None)

Bases: Trainer

Initialize a trainer for a given type of policy.

Parameters:
  • learning_mode (str) – Learning mode indication, either off_policy or on_policy.

  • training_config (Dict[str, Any], optional) – The training configuration. Defaults to None.

  • policy_instance (Policy, optional) – A policy instance. If None, the policy must be supplied later by resetting the trainer. Defaults to None.

post_process(batch: Batch, agent_filter: Sequence[str]) → Batch

Post-process a sampled batch.

Parameters:
  • batch (Batch) – Sampled batch.

Raises:
  • NotImplementedError – Not implemented.

Returns:
  A batch instance.

Return type:
  Batch
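
As documented, post_process raises NotImplementedError. A subclass that does implement it for REINFORCE could, for example, replace per-step rewards with the discounted reward-to-go G_t = r_t + gamma * G_{t+1}. The sketch below assumes the batch is a plain dict mapping agent IDs to reward lists and that agent_filter lists the agents to keep; neither assumption comes from malib, and `reward_to_go`, `post_process_sketch` and `gamma` are hypothetical names.

```python
from typing import Dict, List, Sequence


def reward_to_go(rewards: Sequence[float], gamma: float) -> List[float]:
    """Discounted return G_t = r_t + gamma * G_{t+1}, computed backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def post_process_sketch(
    batch: Dict[str, Sequence[float]],
    agent_filter: Sequence[str],
    gamma: float = 0.99,
) -> Dict[str, List[float]]:
    """Keep only the agents in agent_filter and map their rewards to returns."""
    return {
        agent: reward_to_go(rewards, gamma)
        for agent, rewards in batch.items()
        if agent in agent_filter
    }


batch = {"agent_0": [1.0, 1.0, 1.0], "agent_1": [0.0, 2.0]}
print(post_process_sketch(batch, ["agent_0"], gamma=0.5))
# {'agent_0': [1.75, 1.5, 1.0]}
```
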

setup()

Set up the optimizers.

train(batch: Batch) → Dict[str, Any]

Run training and return an info dict.

Parameters:
  • batch (Union[Dict[AgentID, Batch], Batch]) – A single batch, or a dict mapping agent IDs to batches.

Returns:
  A dict of training information.

Return type:
  Dict[str, Any]
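
The quantity a REINFORCE trainer minimises is the surrogate loss -mean_t[log pi(a_t | s_t) * (G_t - b_t)], where G_t is the discounted return and b_t an optional baseline. The sketch below computes that scalar from plain floats and packs it into an info dict. It is illustrative only: using value_function outputs as the baseline is an assumption, the info-dict keys are hypothetical rather than malib's, and no gradient step is taken (that requires an autograd framework).

```python
from typing import Dict, Optional, Sequence


def reinforce_loss_info(
    log_probs: Sequence[float],
    returns: Sequence[float],
    baselines: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
    """Compute loss = -mean_t[log_prob_t * (G_t - b_t)] plus simple statistics."""
    if not log_probs or len(log_probs) != len(returns):
        raise ValueError("log_probs and returns must be non-empty and equal length")
    if baselines is None:
        baselines = [0.0] * len(returns)  # plain REINFORCE without a baseline
    advantages = [g - b for g, b in zip(returns, baselines)]
    loss = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / len(log_probs)
    return {
        "loss": loss,
        "mean_return": sum(returns) / len(returns),
        "mean_advantage": sum(advantages) / len(advantages),
    }


print(reinforce_loss_info([-0.5, -1.0], [2.0, 1.0]))
# {'loss': 1.0, 'mean_return': 1.5, 'mean_advantage': 1.5}
```

Subtracting a baseline leaves the expected gradient unchanged while typically reducing its variance, which is why a critic's value estimates are a common choice for b_t.
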