malib.rl.common package

Submodules

malib.rl.common.misc module

malib.rl.common.misc.gumbel_softmax(logits: Tensor, temperature=1.0, mask: Optional[Tensor] = None, explore=False) Tensor[source]

Convert a softmax sample to one-hot form while keeping gradient computation intact.

Parameters:
  • logits (torch.Tensor) – Raw logits tensor.

  • temperature (float, optional) – Temperature to control the distribution density. Defaults to 1.0.

  • mask (torch.Tensor, optional) – Action masking. Defaults to None.

  • explore (bool, optional) – Whether to add exploration noise. Defaults to False.

Returns:

Generated Gumbel-Softmax sample, shaped as (batch_size, n_classes).

Return type:

torch.Tensor

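A minimal usage sketch (imports and shapes are illustrative; noise is random when explore is enabled, so concrete values will vary):

>>> import torch
>>> from malib.rl.common.misc import gumbel_softmax
>>> logits = torch.randn(4, 3, requires_grad=True)
>>> sample = gumbel_softmax(logits, temperature=0.5, explore=True)
>>> sample.shape  # one row per batch element, one column per class
torch.Size([4, 3])
>>> sample.sum().backward()  # gradients still flow back to the raw logits
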
malib.rl.common.misc.masked_logits(logits: Tensor, mask: Tensor)[source]
malib.rl.common.misc.onehot_from_logits(logits: Tensor, eps=0.0)[source]

Given a batch of logits, return one-hot samples using an epsilon-greedy strategy (based on the given epsilon).

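For illustration, a hedged sketch of the greedy case (with eps=0.0 the argmax class is always selected; exact tensor formatting may differ):

>>> import torch
>>> from malib.rl.common.misc import onehot_from_logits
>>> logits = torch.tensor([[0.1, 2.0, -1.0]])
>>> onehot_from_logits(logits, eps=0.0)  # greedy: the argmax class gets probability 1
tensor([[0., 1., 0.]])
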
malib.rl.common.misc.sample_gumbel(shape: torch.Size, eps: float = 1e-20, tens_type: Type = torch.FloatTensor) Tensor[source]

Sample noise from a uniform distribution with a given shape. Note that the returned tensor is detached from gradient computation.

Parameters:
  • shape (torch.Size) – Target shape.

  • eps (float, optional) – Tolerance to avoid NaN. Defaults to 1e-20.

  • tens_type (Type, optional) – Indicates the data type of the sampled noise. Defaults to torch.FloatTensor.

Returns:

A tensor as sampled noise.

Return type:

torch.Tensor

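A small sketch (values are random; the key point is that the returned noise does not require gradients):

>>> import torch
>>> from malib.rl.common.misc import sample_gumbel
>>> noise = sample_gumbel(torch.Size([2, 5]))
>>> noise.shape
torch.Size([2, 5])
>>> noise.requires_grad
False
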
malib.rl.common.misc.soft_update(target: Module, source: Module, tau: float)[source]

Perform a soft update of the target network's parameters toward the source network.

Parameters:
  • target (torch.nn.Module) – Network to copy parameters to.

  • source (torch.nn.Module) – Network whose parameters to copy.

  • tau (float) – Weight factor for the update, in the range [0, 1].

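Soft updates typically follow the Polyak rule, target = tau * source + (1 - tau) * target, applied parameter-wise. A minimal usage sketch (the layer sizes are arbitrary):

>>> import torch.nn as nn
>>> from malib.rl.common.misc import soft_update
>>> source, target = nn.Linear(8, 2), nn.Linear(8, 2)
>>> soft_update(target, source, tau=0.01)  # nudge target parameters 1% toward source
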
malib.rl.common.misc.softmax(logits: Tensor, temperature: float, mask: Optional[Tensor] = None, explore: bool = True) Tensor[source]

Apply softmax to the given logits, with temperature control of the distribution density and optional exploration noise.

Parameters:
  • logits (torch.Tensor) – Logits tensor.

  • temperature (float) – Temperature controls the distribution density.

  • mask (torch.Tensor, optional) – Applying action mask if not None. Defaults to None.

  • explore (bool, optional) – Add noise to the generated distribution or not. Defaults to True.

Raises:

TypeError – Logits should be a torch.Tensor.

Returns:

softmax tensor, shaped as (batch_size, n_classes).

Return type:

torch.Tensor

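A brief sketch with action masking (the binary mask format shown here is an assumption for illustration):

>>> import torch
>>> from malib.rl.common.misc import softmax
>>> logits = torch.randn(2, 4)
>>> mask = torch.tensor([[1., 1., 0., 1.], [1., 0., 1., 1.]])
>>> probs = softmax(logits, temperature=1.0, mask=mask, explore=False)
>>> probs.shape  # one probability distribution per row, masked actions suppressed
torch.Size([2, 4])
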
malib.rl.common.policy module

class malib.rl.common.policy.Policy(observation_space, action_space, model_config, custom_config, **kwargs)[source]

Bases: object

property actor
abstract compute_action(observation: Tensor, act_mask: Optional[Tensor], evaluate: bool, hidden_state: Optional[Any] = None, **kwargs) Tuple[Any, Any, Any, Any][source]
coordinate(state: Dict[str, Tensor], message: Any) Any[source]

Coordinate with other agents here.

classmethod copy(instance, replacement: Dict)[source]
property critic
property custom_config: Dict[str, Any]
deregister_state(name: str)[source]
property device: str
get_initial_state(batch_size: Optional[int] = None)[source]
load(path: str)[source]
load_state_dict(state_dict: Dict[str, Any])[source]

Load a state dict provided from outside the policy.

Parameters:

state_dict (Dict[str, Any]) – A dict of states.

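A hedged sketch of syncing registered states between two compatible policy instances (policy_a and policy_b are placeholders built from the same spaces and configs):

>>> policy_b.load_state_dict(policy_a.state_dict())
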
property model_config
parameters() Dict[str, Dict][source]

Return trainable parameters.

property preprocessor
register_state(obj: Any, name: str) None[source]

Register state of obj. Called in init function to register model states.

Example

>>> class CustomPolicy(Policy):
...     def __init__(
...         self,
...         registered_name,
...         observation_space,
...         action_space,
...         model_config,
...         custom_config
...     ):
...         # ...
...         actor = MLP(...)
...         self.register_state(actor, "actor")
Parameters:
  • obj (Any) – Any object; a non-torch.nn.Module object will be wrapped as a SimpleObject.

  • name (str) – Human-readable name used to identify the state.

Raises:

errors.RepeatedAssignError – Raised when a state has already been registered under the given name.

property registered_networks: Dict[str, Module]
reset(**kwargs)[source]

Reset parameters or behavior policies.

save(path, global_step=0, hard: bool = False)[source]
state_dict(device=None)[source]

Return the state dict in real time.

property target_actor
property target_critic
to(device: Optional[str] = None, use_copy: bool = False) Policy[source]

Convert the policy to a given device. If use_copy is set, return a copy; if device is None, leave the device unchanged.

Parameters:
  • device (str) – Device identifier.

  • use_copy (bool, optional) – Whether to return a copy. Defaults to False.

Raises:

NotImplementedError – Not implemented error.

Returns:

A policy instance

Return type:

Policy

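For example, to obtain an independent copy on a GPU for an evaluation worker (the device string is an assumption; any valid torch device identifier should work):

>>> eval_policy = policy.to("cuda:0", use_copy=True)  # original policy stays on its device
>>> same_policy = policy.to("cpu")                    # converts without copying
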
update_parameters(parameter_dict: Dict[str, Any])[source]

Update local parameters with given parameter dict.

Parameters:

parameter_dict (Dict[str, Any]) – A dict of parameters.

class malib.rl.common.policy.SimpleObject(obj, name)[source]

Bases: object

load_state_dict(v)[source]
state_dict()[source]

malib.rl.common.trainer module

class malib.rl.common.trainer.Trainer(training_config: Dict[str, Any], policy_instance: Optional[Policy] = None)[source]

Bases: object

Initialize a trainer for a given type of policy.

Parameters:
  • learning_mode (str) – Learning mode indication, either off_policy or on_policy.

  • training_config (Dict[str, Any], optional) – The training configuration. Defaults to None.

  • policy_instance (Policy, optional) – A policy instance; if None, it must be supplied later via reset. Defaults to None.

property counter
parameters()[source]
property policy
post_process(batch: Batch, agent_filter: Sequence[str]) Batch[source]

Post-process a sampled batch here.

Parameters:

batch (Batch) – Sampled batch.

Raises:

NotImplementedError – Not implemented.

Returns:

A batch instance.

Return type:

Batch

reset(policy_instance=None, configs=None, learning_mode: Optional[str] = None)[source]

Reset the current trainer with a given policy instance, training configuration, or learning mode.

Note

Be careful when resetting the learning mode, since it will change the sampling behavior. Specifically, the on_policy mode will sample data sequentially, returning a torch.DataLoader to the method self.train. For the off_policy case, the sampler will sample data randomly, returning a dict to the method self.train.

Parameters:
  • policy_instance (Policy, optional) – A policy instance. Defaults to None.

  • configs (Dict[str, Any], optional) – A training configuration used to update existing one. Defaults to None.

  • learning_mode (str, optional) – Learning mode, could be off_policy or on_policy. Defaults to None.

abstract setup()[source]

Set up optimizers here.

step_counter()[source]
abstract train(batch: Batch) Dict[str, float][source]

Run training and return an info dict.

Parameters:

batch (Union[Dict[AgentID, Batch], Batch]) – A batch, or a dict of batches keyed by agent ID.

Returns:

A dict of training metrics.

Return type:

Dict[str, float]

property training_config: Dict[str, Any]
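
As an illustration of the intended subclassing pattern, here is a minimal, hedged sketch of a trainer that creates an optimizer in setup() and performs one gradient step per train() call. The loss, the optimizer, and the batch attribute names (batch.obs) are placeholders, not part of the library API:

>>> import torch
>>> from malib.rl.common.trainer import Trainer
>>> class MyTrainer(Trainer):
...     def setup(self):
...         # build an optimizer over the policy's actor parameters
...         self.optimizer = torch.optim.Adam(
...             self.policy.actor.parameters(),
...             lr=self.training_config.get("lr", 1e-3),
...         )
...     def train(self, batch):
...         # placeholder objective: a real trainer would compute an RL loss here
...         loss = self.policy.actor(torch.as_tensor(batch.obs)).mean()
...         self.optimizer.zero_grad()
...         loss.backward()
...         self.optimizer.step()
...         return {"loss": loss.item()}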