Source code for malib.common.distributions

"""
Probability distributions.
Reference: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/distributions.py
"""

from abc import ABC, abstractmethod

import operator
import math

from typing import Any, Dict, List, Optional, Tuple, Union
from functools import reduce

import gym
import torch

from gym import spaces
from torch import nn
from torch.nn import functional as F
from torch.distributions import Bernoulli, Categorical, Normal

from torch.distributions.utils import lazy_property
from torch.distributions import utils as distr_utils
from torch.distributions.categorical import Categorical as TorchCategorical


class Distribution(ABC):
    """Abstract base class for distributions."""

    def __init__(self):
        super(Distribution, self).__init__()
        self.distribution = None

    @abstractmethod
    def proba_distribution_net(
        self, *args, **kwargs
    ) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
        """Create the layers and parameters that represent the distribution.

        Subclasses must define this, but the arguments and return type vary
        between concrete classes.
        """

    @abstractmethod
    def proba_distribution(self, *args, **kwargs) -> "Distribution":
        """Set parameters of the distribution.

        :return: self
        """

    @abstractmethod
    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the log likelihood.

        :param x: the taken action
        :return: the log likelihood of the distribution
        """

    @abstractmethod
    def entropy(self) -> Optional[torch.Tensor]:
        """
        Return the Shannon entropy of the distribution.

        :return: the entropy, or None if no analytical form is known
        """

    @abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Return a sample from the probability distribution.

        :return: the stochastic action
        """

    @abstractmethod
    def prob(self) -> torch.Tensor:
        """Return the probability tensor of the underlying distribution.

        Returns:
            torch.Tensor: A probability tensor
        """

    @abstractmethod
    def mode(self) -> torch.Tensor:
        """
        Return the most likely action (deterministic output)
        from the probability distribution.

        :return: the deterministic action
        """

    def get_actions(self, deterministic: bool = False) -> torch.Tensor:
        """
        Return actions according to the probability distribution.

        :param deterministic: if True, return the mode instead of a sample
        :return: actions
        """
        if deterministic:
            return self.mode()
        return self.sample()

    @abstractmethod
    def actions_from_params(self, *args, **kwargs) -> torch.Tensor:
        """
        Return samples from the probability distribution
        given its parameters.

        :return: actions
        """

    @abstractmethod
    def log_prob_from_params(
        self, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return samples and the associated log probabilities
        from the probability distribution given its parameters.

        :return: actions and log prob
        """

def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum the components of the ``log_prob`` or the entropy.

    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,)
    """
    if len(tensor.shape) > 1:
        tensor = tensor.sum(dim=1)
    else:
        tensor = tensor.sum()
    return tensor

class DiagGaussianDistribution(Distribution):
    """
    Gaussian distribution with diagonal covariance matrix, for continuous actions.

    :param action_dim: Dimension of the action space.
    """

    def __init__(self, action_dim: int):
        super(DiagGaussianDistribution, self).__init__()
        self.action_dim = action_dim
        self.mean_actions = None
        self.log_std = None

    def proba_distribution_net(
        self, latent_dim: int, log_std_init: float = 0.0
    ) -> Tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the mean of the Gaussian, the other parameter will be the
        standard deviation (log std in fact to allow negative values)

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :return:
        """
        mean_actions = nn.Linear(latent_dim, self.action_dim)
        # TODO: allow action dependent std
        log_std = nn.Parameter(
            torch.ones(self.action_dim) * log_std_init, requires_grad=True
        )
        return mean_actions, log_std

    def proba_distribution(
        self, mean_actions: torch.Tensor, log_std: torch.Tensor
    ) -> "DiagGaussianDistribution":
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions:
        :param log_std:
        :return:
        """
        action_std = torch.ones_like(mean_actions) * log_std.exp()
        self.distribution = Normal(mean_actions, action_std)
        return self

    def prob(self) -> torch.Tensor:
        # A continuous distribution has no finite probability vector
        raise RuntimeError("Normal distribution has no static probs")

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """
        Get the log probabilities of actions according to the distribution.
        Note that you must first call the ``proba_distribution()`` method.

        :param actions:
        :return:
        """
        log_prob = self.distribution.log_prob(actions)
        return sum_independent_dims(log_prob)

    def entropy(self) -> torch.Tensor:
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> torch.Tensor:
        # Reparametrization trick to pass gradients
        return self.distribution.rsample()

    def mode(self) -> torch.Tensor:
        return self.distribution.mean

    def actions_from_params(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        deterministic: bool = False,
    ) -> torch.Tensor:
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, mean_actions: torch.Tensor, log_std: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the log probability of taking an action
        given the distribution parameters.

        :param mean_actions:
        :param log_std:
        :return:
        """
        actions = self.actions_from_params(mean_actions, log_std)
        log_prob = self.log_prob(actions)
        return actions, log_prob

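
# A minimal usage sketch (illustrative only, not part of the module's API):
# build a diagonal Gaussian head for a hypothetical 3-dim continuous action
# space and sample a batch of actions. All shapes here are assumptions chosen
# for the example.
def _example_diag_gaussian() -> None:
    dist = DiagGaussianDistribution(action_dim=3)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=64)
    latent = torch.randn(8, 64)  # hypothetical policy features
    actions, log_prob = dist.log_prob_from_params(mean_net(latent), log_std)
    assert actions.shape == (8, 3) and log_prob.shape == (8,)
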
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Gaussian distribution with diagonal covariance matrix, followed by a squashing
    function (tanh) to ensure bounds.

    :param action_dim: Dimension of the action space.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, action_dim: int, epsilon: float = 1e-6):
        super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon
        self.gaussian_actions = None

    def proba_distribution(
        self, mean_actions: torch.Tensor, log_std: torch.Tensor
    ) -> "SquashedDiagGaussianDistribution":
        super(SquashedDiagGaussianDistribution, self).proba_distribution(
            mean_actions, log_std
        )
        return self

    def log_prob(
        self, actions: torch.Tensor, gaussian_actions: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Inverse tanh
        # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
        # We use a numerically stable formulation instead (see TanhBijector.atanh)
        if gaussian_actions is None:
            # It will be clipped to avoid NaN when inverting tanh
            gaussian_actions = TanhBijector.inverse(actions)

        # Log likelihood for a Gaussian distribution
        log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(
            gaussian_actions
        )
        # Squash correction (from original SAC implementation);
        # this comes from the fact that tanh is bijective and differentiable
        log_prob -= torch.sum(torch.log(1 - actions**2 + self.epsilon), dim=1)
        return log_prob

    def entropy(self) -> Optional[torch.Tensor]:
        # No analytical form,
        # entropy needs to be estimated using -log_prob.mean()
        return None

    def sample(self) -> torch.Tensor:
        # Reparametrization trick to pass gradients
        self.gaussian_actions = super().sample()
        return torch.tanh(self.gaussian_actions)

    def mode(self) -> torch.Tensor:
        self.gaussian_actions = super().mode()
        # Squash the output
        return torch.tanh(self.gaussian_actions)

    def log_prob_from_params(
        self, mean_actions: torch.Tensor, log_std: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        action = self.actions_from_params(mean_actions, log_std)
        log_prob = self.log_prob(action, self.gaussian_actions)
        return action, log_prob

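
# A minimal usage sketch (illustrative only): the squashed variant behaves like
# DiagGaussianDistribution above, but tanh keeps every sampled action in [-1, 1].
def _example_squashed_gaussian() -> None:
    dist = SquashedDiagGaussianDistribution(action_dim=2)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=16)
    actions, log_prob = dist.log_prob_from_params(
        mean_net(torch.randn(4, 16)), log_std
    )
    assert actions.abs().max() <= 1.0  # bounded by the tanh squashing
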
class CategoricalDistribution(Distribution):
    """
    Categorical distribution for discrete actions.

    Args:
        action_dim (int): Number of discrete actions.
    """

    def __init__(self, action_dim: int):
        super(CategoricalDistribution, self).__init__()
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Categorical distribution.
        You can then get probabilities using a softmax.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """
        action_logits = nn.Linear(latent_dim, self.action_dim)
        return action_logits

    def proba_distribution(
        self, action_logits: torch.Tensor, action_mask: torch.Tensor = None
    ) -> "CategoricalDistribution":
        if action_mask is not None:
            self.distribution = MaskedCategorical(action_logits, action_mask)
        else:
            self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        return self.distribution.log_prob(actions)

    def entropy(self) -> torch.Tensor:
        return self.distribution.entropy()

    def sample(self) -> torch.Tensor:
        return self.distribution.sample()

    def prob(self) -> torch.Tensor:
        return self.distribution.probs

    def mode(self) -> torch.Tensor:
        return torch.argmax(self.distribution.probs, dim=1)

    def actions_from_params(
        self, action_logits: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, action_logits: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        actions = self.actions_from_params(action_logits, deterministic)
        log_prob = self.log_prob(actions)
        return actions, log_prob

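
# A minimal usage sketch (illustrative only): a discrete head over 5 actions,
# sampled with and without an action mask. The mask semantics follow
# MaskedCategorical below: masked-out entries get zero probability.
def _example_categorical() -> None:
    dist = CategoricalDistribution(action_dim=5)
    logits = dist.proba_distribution_net(latent_dim=32)(torch.randn(4, 32))
    actions = dist.proba_distribution(logits).sample()
    assert actions.shape == (4,)
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0]]).repeat(4, 1)
    masked_actions = dist.proba_distribution(logits, mask).sample()
    assert (masked_actions < 2).all()  # only unmasked actions are drawn
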
class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions.

    :param action_dims: List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super(MultiCategoricalDistribution, self).__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits (flattened) of the MultiCategorical distribution.
        You can then get probabilities using a softmax on each sub-space.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """
        action_logits = nn.Linear(latent_dim, sum(self.action_dims))
        return action_logits

    def proba_distribution(
        self, action_logits: torch.Tensor
    ) -> "MultiCategoricalDistribution":
        self.distribution = [
            Categorical(logits=split)
            for split in torch.split(action_logits, tuple(self.action_dims), dim=1)
        ]
        return self

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        # Extract each discrete action and compute log prob for their respective distributions
        return torch.stack(
            [
                dist.log_prob(action)
                for dist, action in zip(self.distribution, torch.unbind(actions, dim=1))
            ],
            dim=1,
        ).sum(dim=1)

    def entropy(self) -> torch.Tensor:
        return torch.stack([dist.entropy() for dist in self.distribution], dim=1).sum(
            dim=1
        )

    def sample(self) -> torch.Tensor:
        return torch.stack([dist.sample() for dist in self.distribution], dim=1)

    def prob(self) -> torch.Tensor:
        # Required by the abstract base class (without it the class cannot be
        # instantiated); concatenate the probabilities of each sub-space to
        # match the flattened logits layout.
        return torch.cat([dist.probs for dist in self.distribution], dim=1)

    def mode(self) -> torch.Tensor:
        return torch.stack(
            [torch.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1
        )

    def actions_from_params(
        self, action_logits: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, action_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob

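
# A minimal usage sketch (illustrative only): a multi-discrete head with
# sub-spaces of size 3 and 4; the flat logits layer is split per sub-space.
def _example_multi_categorical() -> None:
    dist = MultiCategoricalDistribution(action_dims=[3, 4])
    logits_net = dist.proba_distribution_net(latent_dim=16)  # Linear(16, 7)
    actions, log_prob = dist.log_prob_from_params(logits_net(torch.randn(2, 16)))
    assert actions.shape == (2, 2) and log_prob.shape == (2,)
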
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    :param action_dims: Number of binary actions
    """

    def __init__(self, action_dims: int):
        super(BernoulliDistribution, self).__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Bernoulli distribution.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """
        action_logits = nn.Linear(latent_dim, self.action_dims)
        return action_logits

    def proba_distribution(
        self, action_logits: torch.Tensor
    ) -> "BernoulliDistribution":
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        return self.distribution.log_prob(actions).sum(dim=1)

    def entropy(self) -> torch.Tensor:
        return self.distribution.entropy().sum(dim=1)

    def sample(self) -> torch.Tensor:
        return self.distribution.sample()

    def prob(self) -> torch.Tensor:
        # Required by the abstract base class (without it the class cannot be
        # instantiated)
        return self.distribution.probs

    def mode(self) -> torch.Tensor:
        return torch.round(self.distribution.probs)

    def actions_from_params(
        self, action_logits: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, action_logits: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob

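
# A minimal usage sketch (illustrative only): a MultiBinary head with 4
# independent on/off switches; each sampled component is 0 or 1.
def _example_bernoulli() -> None:
    dist = BernoulliDistribution(action_dims=4)
    logits_net = dist.proba_distribution_net(latent_dim=8)
    actions = dist.actions_from_params(logits_net(torch.randn(2, 8)))
    assert actions.shape == (2, 4) and ((actions == 0) | (actions == 1)).all()
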
class StateDependentNoiseDistribution(Distribution):
    """
    Distribution class for using generalized State Dependent Exploration (gSDE).
    Paper: https://arxiv.org/abs/2005.05719

    It is used to create the noise exploration matrix and
    compute the log probability of an action with that noise.

    :param action_dim: Dimension of the action space.
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,)
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice,
        ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounds are satisfied.
    :param learn_features: Whether to learn features for gSDE or not.
        This will enable gradients to be backpropagated through the features
        ``latent_sde`` in the code.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(
        self,
        action_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
    ):
        super(StateDependentNoiseDistribution, self).__init__()
        self.action_dim = action_dim
        self.latent_sde_dim = None
        self.mean_actions = None
        self.log_std = None
        self.weights_dist = None
        self.exploration_mat = None
        self.exploration_matrices = None
        self._latent_sde = None
        self.use_expln = use_expln
        self.full_std = full_std
        self.epsilon = epsilon
        self.learn_features = learn_features
        if squash_output:
            self.bijector = TanhBijector(epsilon)
        else:
            self.bijector = None

    def get_std(self, log_std: torch.Tensor) -> torch.Tensor:
        """
        Get the standard deviation from the learned parameter
        (log of it by default). This ensures that the std is positive.

        :param log_std:
        :return:
        """
        if self.use_expln:
            # From the gSDE paper: keeps the variance above zero
            # and prevents it from growing too fast
            below_threshold = torch.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zero out the values that are below the threshold
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            above_threshold = (torch.log1p(safe_log_std) + 1.0) * (log_std > 0)
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = torch.exp(log_std)

        if self.full_std:
            return std
        # Reduce the number of parameters:
        return torch.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

    def sample_weights(self, log_std: torch.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrix,
        using a centered Gaussian distribution.

        :param log_std:
        :param batch_size:
        """
        std = self.get_std(log_std)
        self.weights_dist = Normal(torch.zeros_like(std), std)
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))

    def proba_distribution_net(
        self,
        latent_dim: int,
        log_std_init: float = -2.0,
        latent_sde_dim: Optional[int] = None,
    ) -> Tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that controls the weights of the noise matrix.

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :param latent_sde_dim: Dimension of the last layer of the features extractor
            for gSDE. By default, it is shared with the policy network.
        :return:
        """
        # Network for the deterministic action, it represents the mean of the distribution
        mean_actions_net = nn.Linear(latent_dim, self.action_dim)
        # When we learn features for the noise, the feature dimension
        # can be different between the policy and the noise network
        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
        # Reduce the number of parameters if needed
        log_std = (
            torch.ones(self.latent_sde_dim, self.action_dim)
            if self.full_std
            else torch.ones(self.latent_sde_dim, 1)
        )
        # Transform it to a parameter so it can be optimized
        log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        # Sample an exploration matrix
        self.sample_weights(log_std)
        return mean_actions_net, log_std

    def proba_distribution(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        latent_sde: torch.Tensor,
    ) -> "StateDependentNoiseDistribution":
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions:
        :param log_std:
        :param latent_sde:
        :return:
        """
        # Stop gradient if we don't want to influence the features
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        variance = torch.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
        self.distribution = Normal(mean_actions, torch.sqrt(variance + self.epsilon))
        return self

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        if self.bijector is not None:
            gaussian_actions = self.bijector.inverse(actions)
        else:
            gaussian_actions = actions
        # log likelihood for a gaussian
        log_prob = self.distribution.log_prob(gaussian_actions)
        # Sum along action dim
        log_prob = sum_independent_dims(log_prob)

        if self.bijector is not None:
            # Squash correction (from original SAC implementation)
            log_prob -= torch.sum(
                self.bijector.log_prob_correction(gaussian_actions), dim=1
            )
        return log_prob

    def entropy(self) -> Optional[torch.Tensor]:
        if self.bijector is not None:
            # No analytical form,
            # entropy needs to be estimated using -log_prob.mean()
            return None
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> torch.Tensor:
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def prob(self) -> torch.Tensor:
        # Required by the abstract base class (without it the class cannot be
        # instantiated); a continuous distribution has no finite probability vector
        raise RuntimeError("Normal distribution has no static probs")

    def mode(self) -> torch.Tensor:
        actions = self.distribution.mean
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def get_noise(self, latent_sde: torch.Tensor) -> torch.Tensor:
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return torch.mm(latent_sde, self.exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(1)
        # (batch_size, 1, n_actions)
        noise = torch.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(1)

    def actions_from_params(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        latent_sde: torch.Tensor,
        deterministic: bool = False,
    ) -> torch.Tensor:
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self,
        mean_actions: torch.Tensor,
        log_std: torch.Tensor,
        latent_sde: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        actions = self.actions_from_params(mean_actions, log_std, latent_sde)
        log_prob = self.log_prob(actions)
        return actions, log_prob

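
# A minimal usage sketch (illustrative only): a gSDE head whose exploration
# noise is a function of the latent features. sample_weights() is called inside
# proba_distribution_net(), so the head can sample immediately.
def _example_gsde() -> None:
    dist = StateDependentNoiseDistribution(action_dim=2)
    mean_net, log_std = dist.proba_distribution_net(latent_dim=16)
    latent = torch.randn(1, 16)  # hypothetical feature batch
    actions = dist.actions_from_params(mean_net(latent), log_std, latent_sde=latent)
    assert actions.shape == (1, 2)
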
class TanhBijector(object):
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh)
    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super(TanhBijector, self).__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

    @staticmethod
    def atanh(x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x) / (1 - x))
        """
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: torch.Tensor) -> torch.Tensor:
        """
        Inverse tanh.

        :param y:
        :return:
        """
        eps = torch.finfo(y.dtype).eps
        # Clip the action to avoid NaN
        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))

    def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
        # Squash correction (from original SAC implementation)
        return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)

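
# A minimal usage sketch (illustrative only): inverse() undoes forward() up to
# clipping, which is exactly what SquashedDiagGaussianDistribution.log_prob()
# relies on when recovering the pre-squash Gaussian actions.
def _example_tanh_bijector() -> None:
    x = torch.linspace(-3.0, 3.0, steps=7)
    x_rec = TanhBijector.inverse(TanhBijector.forward(x))
    assert torch.allclose(x, x_rec, atol=1e-4)
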
class MaskedCategorical:
    def __init__(self, scores, mask=None):
        self.mask = mask
        if mask is None:
            self.cat_distr = TorchCategorical(F.softmax(scores, dim=-1))
            # Number of categories (the size of the last dimension)
            self.n = scores.shape[-1]
            self.log_n = math.log(self.n)
        else:
            if not isinstance(self.mask, torch.Tensor):
                # Build the mask tensor on the same device as the scores
                self.mask = torch.as_tensor(
                    self.mask, dtype=torch.float32, device=scores.device
                )
            self.n = self.mask.sum(dim=-1)
            self.log_n = (self.n + 1e-17).log()
            self.cat_distr = TorchCategorical(
                MaskedCategorical.masked_softmax(scores, self.mask)
            )

    @lazy_property
    def probs(self):
        return self.cat_distr.probs

    @lazy_property
    def logits(self):
        return self.cat_distr.logits

    def entropy(self):
        # A regular method (not a lazy_property), so that MaskedCategorical
        # matches the ``entropy()`` interface of torch's Categorical
        if self.mask is None:
            return self.cat_distr.entropy() * (self.n != 1)
        entropy = -torch.sum(
            self.cat_distr.logits * self.cat_distr.probs * self.mask, dim=-1
        )
        does_not_have_one_category = (self.n != 1.0).to(dtype=torch.float32)
        # to make sure that the entropy is precisely zero when there is only one category
        return entropy * does_not_have_one_category

    @lazy_property
    def normalized_entropy(self):
        return self.entropy() / (self.log_n + 1e-17)

    def sample(self):
        return self.cat_distr.sample()

    def rsample(self, temperature=None, gumbel_noise=None):
        if gumbel_noise is None:
            with torch.no_grad():
                uniforms = torch.empty_like(self.probs).uniform_()
                uniforms = distr_utils.clamp_probs(uniforms)
                gumbel_noise = -(-uniforms.log()).log()
            # TODO(ming): This is used for debugging (to get the same samples) and is not differentiable.
            # gumbel_noise = None
            # _sample = self.cat_distr.sample()
            # sample = torch.zeros_like(self.probs)
            # sample.scatter_(-1, _sample[:, None], 1.0)
            # return sample, gumbel_noise
        elif gumbel_noise.shape != self.probs.shape:
            raise ValueError("gumbel_noise must have the same shape as probs")

        if temperature is None:
            with torch.no_grad():
                scores = self.logits + gumbel_noise
                scores = MaskedCategorical.masked_softmax(scores, self.mask)
                sample = torch.zeros_like(scores)
                sample.scatter_(-1, scores.argmax(dim=-1, keepdim=True), 1.0)
                return sample, gumbel_noise
        else:
            scores = (self.logits + gumbel_noise) / temperature
            sample = MaskedCategorical.masked_softmax(scores, self.mask)
            return sample, gumbel_noise

    def log_prob(self, value):
        if value.dtype == torch.long:
            if self.mask is None:
                return self.cat_distr.log_prob(value)
            else:
                return self.cat_distr.log_prob(value) * (self.n != 0.0).to(
                    dtype=torch.float32
                )
        else:
            max_values, mv_idxs = value.max(dim=-1)
            relaxed = (max_values - torch.ones_like(max_values)).sum().item() != 0.0
            if relaxed:
                raise ValueError(
                    "The log_prob can't be calculated for the relaxed sample!"
                )
            return self.cat_distr.log_prob(mv_idxs) * (self.n != 0.0).to(
                dtype=torch.float32
            )

    @staticmethod
    def masked_softmax(logits, mask):
        """
        This method returns a valid probability distribution for a particular
        instance if its corresponding row in the ``mask`` matrix is not a zero
        vector. Otherwise, a uniform distribution is returned.
        This is just a technical workaround that allows ``Categorical`` class usage.
        If probs don't sum to one there will be an exception during sampling.
        """
        if mask is None:
            # No mask: fall back to a plain softmax (avoids multiplying by None)
            return F.softmax(logits, dim=-1)
        probs = F.softmax(logits, dim=-1) * mask
        probs = probs + (mask.sum(dim=-1, keepdim=True) == 0.0).to(dtype=torch.float32)
        Z = probs.sum(dim=-1, keepdim=True)
        return probs / Z

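
# A minimal usage sketch (illustrative only): masked_softmax() zeroes out the
# masked entries and falls back to a uniform distribution for rows whose mask
# is all zeros.
def _example_masked_softmax() -> None:
    logits = torch.zeros(2, 3)
    mask = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
    probs = MaskedCategorical.masked_softmax(logits, mask)
    assert torch.allclose(probs[0], torch.tensor([0.5, 0.5, 0.0]))
    assert torch.allclose(probs[1], torch.full((3,), 1.0 / 3.0))
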
def make_proba_distribution(
    action_space: gym.spaces.Space,
    use_sde: bool = False,
    dist_kwargs: Optional[Dict[str, Any]] = None,
) -> Distribution:
    """Return an instance of Distribution for the correct type of action space.

    Args:
        action_space (gym.spaces.Space): The action space.
        use_sde (bool, optional): Force the use of StateDependentNoiseDistribution
            instead of DiagGaussianDistribution. Defaults to False.
        dist_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments to pass
            to the probability distribution. Defaults to None.

    Raises:
        NotImplementedError: Probability distribution not implemented for the
            specified action space.

    Returns:
        Distribution: The appropriate Distribution object
    """
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
        cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
        return cls(reduce(operator.mul, action_space.shape), **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        # Convert the numpy ``nvec`` array to a plain list of ints
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        raise NotImplementedError(
            "Error: probability distribution not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )

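
# A minimal usage sketch (illustrative only): the factory picks the
# distribution class from the gym action-space type.
def _example_make_proba_distribution() -> None:
    box = spaces.Box(low=-1.0, high=1.0, shape=(3,))
    assert isinstance(make_proba_distribution(box), DiagGaussianDistribution)
    discrete = spaces.Discrete(5)
    assert isinstance(make_proba_distribution(discrete), CategoricalDistribution)
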
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> torch.Tensor:
    """
    Wrapper for the PyTorch implementation of the full form KL Divergence

    :param dist_true: the p distribution
    :param dist_pred: the q distribution
    :return: KL(dist_true||dist_pred)
    """
    # KL Divergence for different distribution types is out of scope
    assert (
        dist_true.__class__ == dist_pred.__class__
    ), "Error: input distributions should be the same type"

    # MultiCategoricalDistribution is not a PyTorch Distribution subclass
    # so we need to implement it ourselves!
    if isinstance(dist_pred, MultiCategoricalDistribution):
        assert (
            dist_pred.action_dims == dist_true.action_dims
        ), "Error: distributions must have the same input space"
        return torch.stack(
            [
                torch.distributions.kl_divergence(p, q)
                for p, q in zip(dist_true.distribution, dist_pred.distribution)
            ],
            dim=1,
        ).sum(dim=1)
    # Use the PyTorch kl_divergence implementation
    else:
        return torch.distributions.kl_divergence(
            dist_true.distribution, dist_pred.distribution
        )

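
# A minimal usage sketch (illustrative only): KL between two categorical heads
# over the same 4-action space; identical logits give zero divergence.
def _example_kl_divergence() -> None:
    logits = torch.randn(2, 4)
    p = CategoricalDistribution(4).proba_distribution(logits)
    q = CategoricalDistribution(4).proba_distribution(logits)
    kl = kl_divergence(p, q)
    assert kl.shape == (2,) and torch.allclose(kl, torch.zeros(2))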