malib.rl.dqn package

Submodules

malib.rl.dqn.config module

malib.rl.dqn.policy module

class malib.rl.dqn.policy.DQNPolicy(observation_space: Space, action_space: Space, model_config: Dict[str, Any], custom_config: Dict[str, Any], **kwargs)

Bases: Policy

compute_action(observation: Tensor, act_mask: Optional[Tensor], evaluate: bool, hidden_state: Optional[Any] = None, **kwargs)

Compute actions during the rollout stage. Vector mode is not supported yet.

Parameters:
  • observation (Tensor) – The batched observations, with shape=(n_batch, obs_shape).

  • act_mask (Tensor, optional) – The batched action mask, with shape=(n_batch, mask_shape).

  • evaluate (bool) – Whether to turn off exploration.

  • hidden_state (Any, optional) – The hidden state. Defaults to None.
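The role of `evaluate` and the `eps` property can be pictured as masked epsilon-greedy selection. The stand-alone sketch below illustrates that idea in plain Python and is not MALib's implementation. It assumes the policy has already produced one row of Q-values per observation, that `act_mask` marks legal actions with 1 and illegal ones with 0, and that `eps` is the probability of taking a random legal action. The helper name `select_actions` is hypothetical.

```python
import random
from typing import List, Optional, Sequence


def select_actions(
    q_values: Sequence[Sequence[float]],
    act_mask: Optional[Sequence[Sequence[int]]],
    evaluate: bool,
    eps: float,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical masked epsilon-greedy selection over a batch of Q-value rows."""
    rng = rng or random.Random()
    actions = []
    for i, row in enumerate(q_values):
        # Assumed mask convention: 1 = legal, 0 = illegal; no mask means all legal.
        legal = [a for a in range(len(row)) if act_mask is None or act_mask[i][a]]
        if not legal:
            raise ValueError(f"row {i} has no legal action")
        if not evaluate and rng.random() < eps:
            # Explore: uniform choice among the legal actions only.
            actions.append(rng.choice(legal))
        else:
            # Exploit: greedy action restricted to the legal actions.
            actions.append(max(legal, key=lambda a: row[a]))
    return actions


q = [[0.1, 0.9, 0.3], [0.5, 0.2, 0.8]]
mask = [[1, 0, 1], [1, 1, 1]]
print(select_actions(q, mask, evaluate=True, eps=0.1))  # [2, 2]
```

With `evaluate=True` the sketch ignores `eps` and acts greedily, and the mask is applied before the argmax so that a high Q-value on an illegal action (action 1 in the first row) is never selected.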

property eps: float
load(path: str)
parameters()

Return trainable parameters.

reset(**kwargs)

Reset parameters or behavior policies.

save(path, global_step=0, hard: bool = False)
value_function(observation: Tensor, evaluate: bool, **kwargs) → ndarray

malib.rl.dqn.trainer module

class malib.rl.dqn.trainer.DQNTrainer(training_config: Dict[str, Any], policy_instance: Optional[Policy] = None)

Bases: Trainer

Initialize a trainer for a given type of policy.

Parameters:
  • training_config (Dict[str, Any]) – The training configuration.

  • policy_instance (Policy, optional) – A policy instance. If None, it must be reset later. Defaults to None.

post_process(batch: Batch, agent_filter: Sequence[str]) → Batch

Post-process a sampled batch.

Parameters:

batch (Batch) – Sampled batch.

Raises:

NotImplementedError – Not implemented.

Returns:

A batch instance.

Return type:

Batch
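The role of `agent_filter` is not described above. One plausible reading is that it names the agents whose data should survive post-processing. The sketch below illustrates only that assumed behaviour: it models a batch as a plain dict keyed by agent ID (a stand-in, not MALib's `Batch` type), and the helper name `filter_agents` is hypothetical.

```python
from typing import Dict, Sequence

# Hypothetical stand-in for a per-agent batch: field name -> list of per-step values.
FakeBatch = Dict[str, list]


def filter_agents(
    batch: Dict[str, FakeBatch], agent_filter: Sequence[str]
) -> Dict[str, FakeBatch]:
    """Keep only the sub-batches whose agent ID appears in agent_filter (assumed semantics)."""
    keep = set(agent_filter)
    return {agent_id: sub for agent_id, sub in batch.items() if agent_id in keep}


batch = {
    "agent_0": {"rew": [1.0, 0.0]},
    "agent_1": {"rew": [0.5, 0.5]},
}
print(filter_agents(batch, ["agent_1"]))  # {'agent_1': {'rew': [0.5, 0.5]}}
```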

setup()

Set up the optimizers.

train(batch: Batch)

Run training and return an info dict.

Parameters:

batch (Union[Dict[AgentID, Batch], Batch]) – A dict mapping agent IDs to batches, or a single batch.

Returns:

A training batch of data.

Return type:

Batch
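For orientation, the quantity a DQN trainer usually minimizes is the one-step temporal-difference error against a target network, with target y = r + gamma * (1 - done) * max_a' Q_target(s', a'). The plain-Python sketch below computes that target and a mean squared error for a small batch. It assumes that the rewards, done flags, target-network Q-values for the next states, and online Q-values for the taken actions are already available. The helper names `td_targets` and `mse_loss` are hypothetical, and the sketch does not claim to reproduce MALib's exact loss, which may use a different loss function or a Double-DQN target.

```python
from typing import List, Sequence


def td_targets(
    rewards: Sequence[float],
    dones: Sequence[bool],
    next_q_target: Sequence[Sequence[float]],
    gamma: float = 0.99,
) -> List[float]:
    """One-step DQN targets: r + gamma * (1 - done) * max_a' Q_target(s', a')."""
    return [
        r + gamma * (0.0 if d else 1.0) * max(q_next)
        for r, d, q_next in zip(rewards, dones, next_q_target)
    ]


def mse_loss(q_taken: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared TD error between Q(s, a) for the taken actions and the targets."""
    return sum((q - y) ** 2 for q, y in zip(q_taken, targets)) / len(targets)


rewards = [1.0, 0.0]
dones = [False, True]
next_q_target = [[0.5, 2.0], [3.0, 1.0]]
targets = td_targets(rewards, dones, next_q_target, gamma=0.5)
print(targets)  # [2.0, 0.0]
print(mse_loss([1.0, 0.5], targets))  # 0.625
```

The `done` flag zeroes the bootstrap term for terminal transitions, which is why the second target reduces to the reward alone. A real trainer would backpropagate such a loss through the online network only and could report it in the returned info dict.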