malib.rl.common package
Submodules
malib.rl.common.misc module
- malib.rl.common.misc.gumbel_softmax(logits: Tensor, temperature=1.0, mask: Optional[Tensor] = None, explore=False) → Tensor
Convert softmax output to a one-hot tensor while keeping the gradient computation intact.
- Parameters:
logits (torch.Tensor) – Raw logits tensor.
temperature (float, optional) – Temperature to control the distribution density. Defaults to 1.0.
mask (torch.Tensor, optional) – Action masking. Defaults to None.
explore (bool, optional) – Whether to add noise. Defaults to False.
- Returns:
Generated Gumbel-softmax tensor, shaped as (batch_size, n_classes).
- Return type:
torch.Tensor
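The forward computation described above can be sketched without PyTorch. This is a minimal illustration in plain Python, not MALib's implementation: the function name gumbel_softmax_forward and the rng argument are hypothetical, masking is omitted, and the gradient-preserving step is not shown because it needs autograd. In autograd frameworks that step is commonly written as a straight-through estimator, (one_hot - probs).detach() + probs.

```python
import math
import random


def gumbel_softmax_forward(logits, temperature=1.0, explore=False, rng=None):
    """Forward pass only: return (soft probabilities, one-hot of the argmax)."""
    rng = rng or random.Random()
    eps = 1e-20  # small tolerance that keeps the logarithms finite
    if explore:
        # Perturb each logit with Gumbel(0, 1) noise: -log(-log(U)).
        logits = [
            x - math.log(-math.log(rng.random() + eps) + eps) for x in logits
        ]
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = probs.index(max(probs))
    one_hot = [1.0 if i == best else 0.0 for i in range(len(probs))]
    return probs, one_hot
```

With explore=False the one-hot vector simply marks the largest logit; with explore=True the noise can move the argmax to another class.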
- malib.rl.common.misc.onehot_from_logits(logits: Tensor, eps=0.0)
Given a batch of logits, return a one-hot sample using an epsilon-greedy strategy (based on the given eps).
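An epsilon-greedy one-hot selection can be sketched as follows. This is an illustration in plain Python under stated assumptions, not the library's code: the name onehot_epsilon_greedy and the rng argument are hypothetical, and the batch is a list of lists instead of a tensor.

```python
import random


def onehot_epsilon_greedy(logits_batch, eps=0.0, rng=None):
    """Return one one-hot row per input row, chosen epsilon-greedily."""
    rng = rng or random.Random()
    result = []
    for row in logits_batch:
        if rng.random() < eps:
            choice = rng.randrange(len(row))  # explore: uniform random class
        else:
            choice = row.index(max(row))  # exploit: class with the largest logit
        result.append([1.0 if i == choice else 0.0 for i in range(len(row))])
    return result
```

With eps=0.0 the result is always the greedy one-hot; with eps=1.0 every row is chosen uniformly at random.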
- malib.rl.common.misc.sample_gumbel(shape: torch.Size, eps: float = 1e-20, tens_type: typing.Type = <class 'torch.FloatTensor'>) → Tensor
Sample noise of a given shape from a uniform distribution. Note that gradient computation is disabled for the returned tensor.
- Parameters:
shape (torch.Size) – Target shape.
eps (float, optional) – Tolerance to avoid NaN. Defaults to 1e-20.
tens_type (Type, optional) – Indicates the data type of the sampled noise. Defaults to torch.FloatTensor.
- Returns:
A tensor as sampled noise.
- Return type:
torch.Tensor
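The role of eps is easiest to see in the conventional recipe for Gumbel noise, which transforms uniform samples through two logarithms. The sketch below assumes that recipe (the documentation above does not spell it out) and uses plain Python; sample_gumbel_noise and its rng argument are hypothetical names.

```python
import math
import random


def sample_gumbel_noise(n, eps=1e-20, rng=None):
    """Draw n Gumbel(0, 1) samples by transforming uniform samples."""
    rng = rng or random.Random()
    # U lies in [0, 1); eps keeps both logarithms away from log(0).
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]
```

The mean of Gumbel(0, 1) is the Euler-Mascheroni constant (about 0.577), which gives a quick sanity check on a large sample.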
- malib.rl.common.misc.soft_update(target: Module, source: Module, tau: float)
Perform soft update.
- Parameters:
target (torch.nn.Module) – Net to copy parameters to
source (torch.nn.Module) – Net whose parameters to copy
tau (float) – Weight factor for the update, ranging from 0 to 1.
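A soft update is usually a per-parameter interpolation between the two networks. The sketch below assumes the common Polyak convention, target = tau * source + (1 - tau) * target, and uses plain dicts of floats in place of torch.nn.Module parameters; soft_update_sketch is a hypothetical name, and the library's weighting convention should be checked against its source.

```python
def soft_update_sketch(target, source, tau):
    """Blend source values into target in place (assumed Polyak convention)."""
    for name, value in source.items():
        # tau = 1 copies the source; tau = 0 leaves the target unchanged.
        target[name] = tau * value + (1.0 - tau) * target[name]


target_params = {"w": 0.0, "b": 1.0}
source_params = {"w": 1.0, "b": 3.0}
soft_update_sketch(target_params, source_params, tau=0.1)
```

A small tau makes the target network track the source slowly, which is the usual reason for preferring a soft update to a hard copy.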
- malib.rl.common.misc.softmax(logits: Tensor, temperature: float, mask: Optional[Tensor] = None, explore: bool = True) → Tensor
Apply softmax to the given logits, with distribution density control and optional exploration noise.
- Parameters:
logits (torch.Tensor) – Logits tensor.
temperature (float) – Temperature controls the distribution density.
mask (torch.Tensor, optional) – Applying action mask if not None. Defaults to None.
explore (bool, optional) – Add noise to the generated distribution or not. Defaults to True.
- Raises:
TypeError – Logits should be a torch.Tensor.
- Returns:
softmax tensor, shaped as (batch_size, n_classes).
- Return type:
torch.Tensor
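Temperature scaling and action masking can be illustrated for a single row of logits. This is a plain-Python sketch under assumptions: masked_softmax is a hypothetical name, a mask entry of 1 is taken to mean "allowed" and 0 "forbidden", at least one entry must be allowed, and exploration noise is left out.

```python
import math


def masked_softmax(logits, temperature, mask=None):
    """Softmax over one row of logits with temperature and an optional mask."""
    scaled = [x / temperature for x in logits]
    if mask is not None:
        # Forbidden actions get -inf so their probability becomes exactly 0.
        scaled = [x if m else -math.inf for x, m in zip(scaled, mask)]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

A higher temperature flattens the distribution, and a lower one concentrates the probability mass on the largest logit.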
malib.rl.common.policy module
- class malib.rl.common.policy.Policy(observation_space, action_space, model_config, custom_config, **kwargs)
Bases: object
- property actor
- abstract compute_action(observation: Tensor, act_mask: Optional[Tensor], evaluate: bool, hidden_state: Optional[Any] = None, **kwargs) → Tuple[Any, Any, Any, Any]
- property critic
- property custom_config: Dict[str, Any]
- property device: str
- load_state_dict(state_dict: Dict[str, Any])
Load an externally provided state dict.
- Parameters:
state_dict (Dict[str, Any]) – A dict of states.
- property model_config
- property preprocessor
- register_state(obj: Any, name: str) → None
Register state of obj. Called in init function to register model states.
Example
>>> class CustomPolicy(Policy):
...     def __init__(
...         self,
...         registered_name,
...         observation_space,
...         action_space,
...         model_config,
...         custom_config
...     ):
...         # ...
...         actor = MLP(...)
...         self.register_state(actor, "actor")
- Parameters:
obj (Any) – Any object; an object that is not a torch.nn.Module will be wrapped as a Simpleobject.
name (str) – Human-readable name used to identify the state.
- Raises:
errors.RepeatedAssignError – Raised if a state is already registered under name.
- property registered_networks: Dict[str, Module]
- property target_actor
- property target_critic
- to(device: Optional[str] = None, use_copy: bool = False) → Policy
Move the policy to a given device. If use_copy is True, return a copy. If device is None, the device is left unchanged.
- Parameters:
device (str) – Device identifier.
use_copy (bool, optional) – Whether to return a copy. Defaults to False.
- Raises:
NotImplementedError – Not implemented error.
- Returns:
A policy instance
- Return type:
Policy
malib.rl.common.trainer module
- class malib.rl.common.trainer.Trainer(training_config: Dict[str, Any], policy_instance: Optional[Policy] = None)
Bases: object
Initialize a trainer for a type of policies.
- Parameters:
learning_mode (str) – Learning mode indication, either off_policy or on_policy.
training_config (Dict[str, Any], optional) – The training configuration. Defaults to None.
policy_instance (Policy, optional) – A policy instance; if None, it must be supplied later via reset. Defaults to None.
- property counter
- property policy
- reset(policy_instance=None, configs=None, learning_mode: Optional[str] = None)
Reset current trainer, with given policy instance, training configuration or learning mode.
Note
Be careful when resetting the learning mode, since it changes the sampling behavior. Specifically, the on_policy mode samples data sequentially and returns a torch.DataLoader to the method self.train. In the off_policy case, the sampler samples data randomly and returns a dict.
- Parameters:
policy_instance (Policy, optional) – A policy instance. Defaults to None.
configs (Dict[str, Any], optional) – A training configuration used to update existing one. Defaults to None.
learning_mode (str, optional) – Learning mode, could be off_policy or on_policy. Defaults to None.
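The difference between the two sampling behaviors mentioned in the note can be pictured with a toy buffer. The helpers below are hypothetical and written in plain Python; they are not MALib's sampler, and they return lists, not a torch.DataLoader or a dict.

```python
import random


def iter_sequential(buffer, batch_size):
    """Yield consecutive batches in storage order (on_policy-style pass)."""
    for start in range(0, len(buffer), batch_size):
        yield buffer[start:start + batch_size]


def sample_random(buffer, batch_size, rng=None):
    """Draw one batch uniformly without replacement (off_policy-style draw)."""
    rng = rng or random.Random()
    return rng.sample(buffer, batch_size)
```

Sequential iteration visits every stored item exactly once per pass, while random sampling may revisit or skip items across successive draws.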
- property training_config: Dict[str, Any]