# reference: https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py
from typing import Any, Union, Optional, Collection, Dict
from numbers import Number
import torch
import numpy as np
from numba import njit
def _is_scalar(value: Any) -> bool:
# check if the value is a scalar
# 1. python bool object, number object: isinstance(value, Number)
# 2. numpy scalar: isinstance(value, np.generic)
# 3. python object rather than dict / Batch / tensor
# the check of dict / Batch is omitted because this only checks a value.
# a dict / Batch will eventually check their values
if isinstance(value, torch.Tensor):
return value.numel() == 1 and not value.shape
else:
# np.asanyarray will cause dead loop in some cases
return np.isscalar(value)
def _is_number(value: Any) -> bool:
# isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc.
# isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc.
# isinstance(value, np.bool_) checks np.bool_(True), etc.
# similar to np.isscalar but np.isscalar('st') returns True
return isinstance(value, (Number, np.number, np.bool_))
def _to_array_with_correct_type(obj: Any) -> np.ndarray:
if isinstance(obj, np.ndarray) and issubclass(
obj.dtype.type, (np.bool_, np.number)
):
return obj # most often case
# convert the value to np.ndarray
# convert to object obj type if neither bool nor number
# raises an exception if array's elements are tensors themselves
obj_array = np.asanyarray(obj)
if not issubclass(obj_array.dtype.type, (np.bool_, np.number)):
obj_array = obj_array.astype(object)
if obj_array.dtype == object:
# scalar ndarray with object obj type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
# array([{}, array({}, dtype=object)], dtype=object)
if not obj_array.shape:
obj_array = obj_array.item(0)
elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)):
return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]])
elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return
def _parse_value(obj: Any) -> Optional[Union[np.ndarray, torch.Tensor]]:
if (
(
isinstance(obj, np.ndarray)
and issubclass(obj.dtype.type, (np.bool_, np.number))
)
or isinstance(obj, torch.Tensor)
or obj is None
): # third often case
return obj
elif _is_number(obj): # second often case, but it is more time-consuming
return np.asanyarray(obj)
else:
if (
not isinstance(obj, np.ndarray)
and isinstance(obj, Collection)
and len(obj) > 0
and all(isinstance(element, torch.Tensor) for element in obj)
):
try:
return torch.stack(obj) # type: ignore
except RuntimeError as exception:
raise TypeError(
"Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet."
) from exception
# None, scalar, normal obj list (main case)
# or an actual list of objects
try:
obj = _to_array_with_correct_type(obj)
except ValueError as exception:
raise TypeError(
"Batch does not support heterogeneous list/"
"tuple of tensors as unique value yet."
) from exception
return obj
[docs]def to_torch(
x: Any,
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = "cpu",
) -> torch.Tensor:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
): # most often case
x = torch.from_numpy(x).to(device) # type: ignore
if dtype is not None:
x = x.type(dtype)
return x
elif isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
return x.to(device) # type: ignore
elif isinstance(x, (np.number, np.bool_, Number)):
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, (list, tuple)):
return to_torch(_parse_value(x), dtype, device)
else: # fallback
raise TypeError(f"object {x} cannot be converted to torch.")
def _gae_return(
v_s: np.ndarray,
v_s_: np.ndarray,
rew: np.ndarray,
end_flag: np.ndarray,
gamma: float,
gae_lambda: float,
) -> np.ndarray:
returns = np.zeros(rew.shape, dtype=np.float32)
delta = rew + v_s_ * gamma - v_s
discount = (1.0 - end_flag) * (gamma * gae_lambda)
gae = 0.0
for i in range(rew.shape[0] - 1, -1, -1):
gae = delta[i] + discount[i] * gae
returns[i] = gae
return returns
@njit
def _nstep_return(
rew: np.ndarray,
end_flag: np.ndarray,
target_q: np.ndarray,
indices: np.ndarray,
gamma: float,
n_step: int,
) -> np.ndarray:
gamma_buffer = np.ones(n_step + 1)
for i in range(1, n_step + 1):
gamma_buffer[i] = gamma_buffer[i - 1] * gamma
target_shape = target_q.shape
bsz = target_shape[0]
# change target_q to 2d array
target_q = target_q.reshape(bsz, -1)
returns = np.zeros(target_q.shape)
gammas = np.full(indices[0].shape, n_step)
for n in range(n_step - 1, -1, -1):
now = indices[n]
gammas[end_flag[now] > 0] = n + 1
returns[end_flag[now] > 0] = 0.0
returns = rew[now].reshape(bsz, 1) + gamma * returns
target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
return target_q.reshape(target_shape)
[docs]class Postprocessor:
[docs] @staticmethod
def gae_return(
state_value,
next_state_value,
reward,
done,
gamma: float = 0.99,
gae_lambda: float = 0.95,
):
adv = _gae_return(
state_value, next_state_value, reward, done, gamma, gae_lambda
)
return adv
[docs] @staticmethod
def compute_episodic_return(
batch: Dict[str, Any],
state_value: np.ndarray = None,
next_state_value: np.ndarray = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
):
if isinstance(batch["rew"], torch.Tensor):
rew = batch["rew"].cpu().numpy()
else:
rew = batch["rew"]
if isinstance(batch["done"], torch.Tensor):
done = batch["done"].cpu().numpy()
else:
done = batch["done"]
if next_state_value is None:
assert np.isclose(gae_lambda, 1.0)
next_state_value = np.zeros_like(rew)
else:
# mask next_state_value
next_state_value = next_state_value * (1.0 - done)
state_value = (
np.roll(next_state_value, 1) if state_value is None else state_value
)
# XXX(ming): why we clip the unfinished index?
# end_flag = batch.done.copy()
# truncated
# end_flag[np.isin(indices, buffer.unfinished_index())] = True
if gae_lambda == 0:
returns = rew + gamma * next_state_value
else:
advantage = Postprocessor.gae_return(
state_value,
next_state_value,
rew,
done,
gamma,
gae_lambda,
)
returns = advantage + state_value
# normalization varies from each policy, so we don't do it here
return returns, advantage