malib.algorithm.common.misc module
- class malib.algorithm.common.misc.EPSGreedy(action_dimension: int, threshold: float = 0.3)[source]
Bases: object
- class malib.algorithm.common.misc.OUNoise(action_dimension: int, scale=0.1, mu=0, theta=0.15, sigma=0.2)[source]
Bases: object
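The constructor arguments match the usual parameters of an Ornstein-Uhlenbeck process, so a minimal sketch of such a noise source may help. This is written in plain Python and is not MALib's implementation: the class name `OUNoiseSketch`, the `reset` and `noise` methods, the `rng` argument, and the way `scale` multiplies the output are all assumptions made for illustration.

```python
import random


class OUNoiseSketch:
    """Hypothetical Ornstein-Uhlenbeck noise source (illustrative only)."""

    def __init__(self, action_dimension, scale=0.1, mu=0.0, theta=0.15,
                 sigma=0.2, rng=random):
        self.action_dimension = action_dimension
        self.scale = scale
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.rng = rng
        self.reset()

    def reset(self):
        # Start every dimension at the long-run mean.
        self.state = [self.mu] * self.action_dimension

    def noise(self):
        # Discrete OU step: pull toward mu at rate theta, then add
        # Gaussian noise with standard deviation sigma.
        self.state = [
            x + self.theta * (self.mu - x) + self.sigma * self.rng.gauss(0.0, 1.0)
            for x in self.state
        ]
        # Assumption: scale is applied to the returned noise only.
        return [x * self.scale for x in self.state]
```

Because each step depends on the previous state, successive samples are correlated in time, which is why this kind of noise is often used for exploration in continuous action spaces.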
- malib.algorithm.common.misc.gumbel_softmax(logits: numpy.ndarray, temperature=1.0, hard=False, explore=True)[source]
Sample from the Gumbel-Softmax distribution and optionally discretize.
- Note:
modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
- Parameters
logits (DataTransferType) – Unnormalized log-probs.
temperature (float) – Positive scalar; lower values push samples closer to one-hot.
hard (bool) – If true, take the argmax, but differentiate w.r.t. the soft sample y.
- Returns
[batch_size, n_class] sample from the Gumbel-Softmax distribution. If hard=True, the returned sample is one-hot; otherwise it is a probability distribution that sums to 1 across classes.
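The behaviour described above can be sketched for a single row of logits in plain Python. This is an illustration under stated assumptions, not the library code: `gumbel_softmax_row` and the `rng` argument are hypothetical names, and since plain Python has no autograd, the `hard` branch only shows the forward-pass discretization, not the gradient through the soft sample.

```python
import math
import random


def gumbel_softmax_row(logits, temperature=1.0, hard=False, rng=random):
    """Hypothetical single-row Gumbel-Softmax sample (forward pass only)."""
    eps = 1e-20
    # Gumbel(0, 1) noise via inverse transform sampling.
    gumbels = [-math.log(-math.log(rng.random() + eps) + eps) for _ in logits]
    # Perturb the logits and divide by the temperature.
    noisy = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    # Numerically stable softmax.
    top = max(noisy)
    exps = [math.exp(v - top) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft
    # Discretize: one-hot vector at the argmax of the soft sample.
    k = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(soft))]
```

With `hard=False` the result is a probability vector; with `hard=True` exactly one entry is 1.0 and the rest are 0.0.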
- malib.algorithm.common.misc.gumbel_softmax_sample(logits, temperature, explore: bool = True)[source]
Draw a sample from the Gumbel-Softmax distribution.
- Note:
modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
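For reference, the standard Gumbel-Softmax sample for class i can be written as follows. The symbols are introduced here for illustration and do not appear in the source: \ell_i are the logits, \tau is the temperature, and g_i is Gumbel(0, 1) noise.

```latex
y_i = \frac{\exp\big((\ell_i + g_i)/\tau\big)}{\sum_{j} \exp\big((\ell_j + g_j)/\tau\big)},
\qquad
g_i = -\log(-\log u_i), \quad u_i \sim \mathrm{Uniform}(0, 1)
```

As \tau approaches 0 the sample approaches a one-hot vector; larger \tau gives smoother, more uniform samples.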
- malib.algorithm.common.misc.hard_update(target, source)[source]
Copy network parameters from source to target.
- Parameters
target (torch.nn.Module) – Net to copy parameters to.
source (torch.nn.Module) – Net whose parameters to copy
- malib.algorithm.common.misc.onehot_from_logits(logits, eps=0.0)[source]
Given a batch of logits, return one-hot samples using an epsilon-greedy strategy with the given eps.
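A minimal sketch of epsilon-greedy one-hot selection, using nested Python lists in place of tensors. The name `onehot_from_logits_sketch` and the `rng` argument are hypothetical, and the library function may choose its random actions differently.

```python
import random


def onehot_from_logits_sketch(logits, eps=0.0, rng=random):
    """Hypothetical epsilon-greedy one-hot selection over a batch of rows."""
    batch = []
    for row in logits:
        n = len(row)
        if rng.random() < eps:
            # Explore: pick an action uniformly at random.
            k = rng.randrange(n)
        else:
            # Exploit: pick the action with the largest logit.
            k = max(range(n), key=row.__getitem__)
        batch.append([1.0 if i == k else 0.0 for i in range(n)])
    return batch
```

With the default `eps=0.0` the selection is purely greedy, so each output row is the one-hot encoding of that row's argmax.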
- malib.algorithm.common.misc.sample_gumbel(shape, eps=1e-20, tens_type=<class 'torch.FloatTensor'>)[source]
Sample from Gumbel(0, 1).
- Note:
modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
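Gumbel(0, 1) noise is commonly drawn by inverse transform sampling, g = -log(-log(u)) with u uniform on (0, 1), where eps guards against log(0). The sketch below shows that recipe in plain Python; `sample_gumbel_sketch` and its `n` and `rng` arguments are hypothetical, and the library version takes a tensor shape and type instead.

```python
import math
import random


def sample_gumbel_sketch(n, eps=1e-20, rng=random):
    """Hypothetical: draw n Gumbel(0, 1) samples by inverse transform."""
    samples = []
    for _ in range(n):
        u = rng.random()  # uniform on [0, 1)
        # eps keeps both logarithms away from log(0).
        samples.append(-math.log(-math.log(u + eps) + eps))
    return samples
```

One way to sanity-check such a sampler is to compare the sample mean with the Euler-Mascheroni constant (about 0.5772), which is the mean of the Gumbel(0, 1) distribution.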
- malib.algorithm.common.misc.soft_update(target, source, tau)[source]
Perform DDPG soft update (move target params toward source based on weight factor tau).
- Parameters
target (torch.nn.Module) – Net to copy parameters to
source (torch.nn.Module) – Net whose parameters to copy
tau (float) – Weight factor for the update, in the range 0 to 1.
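The soft update rule is target = (1 - tau) * target + tau * source, applied to every parameter. The sketch below illustrates it with plain dictionaries of floats standing in for network parameters; `soft_update_sketch` is a hypothetical name, and the library function operates on torch.nn.Module instances instead.

```python
def soft_update_sketch(target, source, tau):
    """Hypothetical: blend source values into target in place.

    target and source are dicts mapping parameter names to floats.
    """
    for name, value in source.items():
        # Move each target parameter a fraction tau toward the source.
        target[name] = (1.0 - tau) * target[name] + tau * value
```

A small tau makes the target track the source slowly; with tau = 1 the rule reduces to a full copy of the source parameters, which is the effect hard_update describes.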