# MIT License
# Copyright (c) 2021 MARL @ SJTU
# Author: Ming Zhou
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check, Sequence
from copy import deepcopy
from collections import defaultdict
import h5py
import torch
import pickle
import numpy as np
from malib.utils.tianshou_batch import (
Batch,
_alloc_by_keys_diff,
_create_value,
_parse_value,
)
[docs]@no_type_check
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor): # most often case
return x.detach().cpu().numpy()
elif isinstance(x, np.ndarray): # second often case
return x
elif isinstance(x, (np.number, np.bool_, Number)):
return np.asanyarray(x)
elif x is None:
return np.array(None, dtype=object)
elif isinstance(x, (dict, Batch)):
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
x.to_numpy()
return x
elif isinstance(x, (list, tuple)):
return to_numpy(_parse_value(x))
else: # fallback
return np.asanyarray(x)
# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[ # type: ignore
int,
float,
Batch,
np.ndarray,
torch.Tensor,
object,
"Hdf5ConvertibleType", # type: ignore
]
Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore
[docs]class ReplayBuffer:
def __init__(
self,
size: int,
stack_num: int = 1,
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False,
**kwargs
) -> None:
self.capacity = size
self.data = {}
self.flag = 0
self.size = 0
def __len__(self):
return self.size
[docs] def add_batch(self, data: Dict[str, np.ndarray]):
any_v = list(data.values())[0]
shift = min(0, self.capacity - any_v.shape[0] - self.flag)
for k, v in data.items():
if k not in self.data:
self.data[k] = np.zeros((self.capacity,) + v.shape[1:], dtype=v.dtype)
assert v.shape[0] == any_v.shape[0], (any_v.shape, v.shape, k)
self.data[k] = np.roll(self.data[k], shift=shift, axis=0)
self.data[k][self.flag + shift : self.flag + shift + v.shape[0]] = v
self.flag = (self.flag + shift + any_v.shape[0]) % self.capacity
self.size = min(self.size + any_v.shape[0], self.capacity)
[docs] def sample_indices(self, batch_size: int) -> Sequence[int]:
indices = np.random.choice(self.size, batch_size)
return indices
[docs] def sample(self, batch_size: int) -> Tuple[Batch, List[int]]:
indices = self.sample_indices(batch_size)
samples = {k: v[indices] for k, v in self.data.items()}
return Batch(samples), indices
[docs]class MultiagentReplayBuffer(ReplayBuffer):
def __init__(
self,
size: int,
stack_num: int = 1,
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False,
**kwargs
) -> None:
super().__init__(
size, stack_num, ignore_obs_next, save_only_last_obs, sample_avail, **kwargs
)
self.data = defaultdict(
lambda: ReplayBuffer(
size=size,
stack_num=stack_num,
ignore_obs_next=ignore_obs_next,
save_only_last_obs=save_only_last_obs,
sample_avail=sample_avail,
**kwargs
)
)
[docs] def add_batch(self, data: Dict[str, Dict[str, np.ndarray]]):
for agent, _data in data.items():
self.data[agent].add_batch(_data)
size_candidates = set()
for e in self.data.values():
size_candidates.add(e.size)
assert len(size_candidates) == 1, (
size_candidates,
{k: v.size for k, v in self.data.items()},
)
self.size = list(self.data.values())[0].size
[docs] def sample(self, batch_size: int) -> Dict[str, Tuple[Batch, List[int]]]:
agent_batch_tups = {
agent: data.sample(batch_size) for agent, data in self.data.items()
}
return agent_batch_tups