import copy
import functools
from typing import OrderedDict

import numpy as np
import torch

from malib import settings
from malib.utils.typing import Dict, Callable, List, Tuple
def update_configs(update_dict, ori_dict=None):
    """Return a new config with ``update_dict`` merged recursively into a base.

    The base is ``ori_dict`` when given, otherwise ``settings.DEFAULT_CONFIG``.
    The base is shallow-copied, so it is never modified. Keys it holds that
    ``update_dict`` does not touch keep referring to the same value objects.

    For each entry of ``update_dict``:
      * a dict value is merged recursively into a shallow copy of the existing
        value when that value is a dict; otherwise it is merged into ``{}``.
      * any other value is shallow-copied and replaces the existing one.

    Args:
        update_dict: Overrides to apply.
        ori_dict: Base config. Defaults to ``settings.DEFAULT_CONFIG``.

    Returns:
        The merged config.
    """
    base = settings.DEFAULT_CONFIG if ori_dict is None else ori_dict
    merged = copy.copy(base)
    for key, value in update_dict.items():
        if not isinstance(value, dict):
            merged[key] = copy.copy(value)
            continue
        current = merged.get(key)
        nested_base = current if isinstance(current, dict) else {}
        merged[key] = update_configs(value, nested_base)
    return merged
# TODO(ming): will be replaced with many dicts
[docs]def iter_dicts_recursively(d1, d2):
"""Assuming dicts have the exact same structure."""
for k, v in d1.items():
assert k in d2
if isinstance(v, (dict, OrderedDict)):
yield from iter_dicts_recursively(d1[k], d2[k])
else:
yield d1, d2, k, d1[k], d2[k]
def iter_many_dicts_recursively(*d, history=None):
    """Walk any number of identically structured dicts in lockstep.

    Structure is driven by the first dict. A missing key in any other dict
    raises ``KeyError``.

    Args:
        *d: The dicts to walk together.
        history: Optional key path of the current level. When it is a list,
            the path to each leaf is tracked and yielded as well.

    Yields:
        ``(parents, key, leaves)`` when ``history`` is None, otherwise
        ``(path, parents, key, leaves)``. Here ``parents`` is the tuple of
        dicts at this level, ``leaves`` is the tuple of values under ``key``,
        and ``path`` is ``history + [key]``.
    """
    first = d[0]
    for key, value in first.items():
        if isinstance(value, (dict, OrderedDict)):
            children = [sub[key] for sub in d]
            child_history = None if history is None else history + [key]
            yield from iter_many_dicts_recursively(*children, history=child_history)
            continue
        leaves = tuple(sub[key] for sub in d)
        if history is None:
            yield d, key, leaves
        else:
            yield history + [key], d, key, leaves
class BufferDict(dict):
    """A ``dict`` of (possibly nested) buffers with batch-wise helpers.

    Leaves are indexed with ``leaf[indices]`` (see :meth:`index`) and written
    with ``leaf[index] = ...`` (see :meth:`set_data`). :attr:`capacity` reads
    ``leaf.shape[0]``.
    """

    @property
    def capacity(self) -> int:
        """Largest leading dimension (``shape[0]``) among all leaf buffers.

        Raises:
            ValueError: if the dict contains no leaves (``max`` of nothing).
        """
        return max(v.shape[0] for _, _, v in iterate_recursively(self))

    def index(self, indices):
        """Return a new BufferDict holding ``leaf[indices]`` for every leaf."""
        return self.index_func(self, indices)

    def index_func(self, x, indices):
        """Recursively index ``x``.

        Dicts (including BufferDict, a dict subclass) are rebuilt as
        BufferDict with the same keys. Any other value ``v`` becomes
        ``v[indices]``.
        """
        if isinstance(x, dict):
            return BufferDict({k: self.index_func(v, indices) for k, v in x.items()})
        return x[indices]

    def set_data(self, index, new_data):
        """Write ``new_data`` into this buffer at position ``index``.

        ``new_data`` mirrors (a subset of) this dict's key structure. Each of
        its leaves must be a ``torch.Tensor`` or ``np.ndarray``.

        Raises:
            TypeError: if a leaf of ``new_data`` has any other type.
        """
        return self.set_data_func(self, index, new_data)

    def set_data_func(self, x, index, new_data):
        """Recursively copy the leaves of ``new_data`` into ``x[...][index]``."""
        if isinstance(new_data, dict):
            for nk, nv in new_data.items():
                self.set_data_func(x[nk], index, nv)
            return
        if isinstance(new_data, torch.Tensor):
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            t = new_data.detach().cpu().numpy()
        elif isinstance(new_data, np.ndarray):
            t = new_data
        else:
            raise TypeError(
                f"Unexpected type for new insert data: {type(new_data)}, expected is np.ndarray"
            )
        # Copy so the buffer never aliases the caller's (or the tensor's) memory.
        x[index] = t.copy()
def iterate_recursively(d: Dict):
    """Yield ``(parent, key, value)`` for every non-dict leaf of a nested dict.

    Leaves are yielded depth-first in insertion order. BufferDict is a dict
    subclass, so it is descended into as well.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            yield d, key, value
        else:
            yield from iterate_recursively(value)
def _default_dtype_mapping(dtype):
    """Map a numpy/Python dtype to the torch dtype used by ``tensor_cast``.

    Branches are tested in order with ``in``/``==`` comparisons, so the first
    match wins:
      * ``np.int32``, ``np.int64``, ``int`` -> ``torch.int32``
      * ``float``, ``np.float32``, ``np.float16`` -> ``torch.float32``
      * ``np.float64`` -> ``torch.float64``
      * ``bool``, ``np.bool_`` -> ``torch.float32`` (not ``torch.bool``)

    Raises:
        NotImplementedError: for any dtype not listed above.
    """
    # FIXME(ming): cast 64 to 32?
    # NOTE: 64-bit integers are narrowed to torch.int32 here.
    if dtype in [np.int32, np.int64, int]:
        return torch.int32
    # NOTE(review): numpy dtype objects compare equal to types they can be
    # built from, so np.dtype("float64") == float is likely True. If so,
    # float64 ndarray dtypes match this branch (-> float32) before the
    # float64 branch below is reached. Confirm before relying on float64 output.
    elif dtype in [float, np.float32, np.float16]:
        return torch.float32
    elif dtype == np.float64:
        return torch.float64
    # Booleans are mapped to a float dtype rather than torch.bool.
    elif dtype in [bool, np.bool_]:
        return torch.float32
    else:
        raise NotImplementedError(f"dtype: {dtype} has no transmission rule.") from None
# Helper for tensor_cast: recursively applies a caster to every non-dict leaf.
def _walk(caster, v):
if isinstance(v, (dict, BufferDict)):
for k, _v in v.items():
v[k] = _walk(caster, _v)
else:
v = caster(v)
return v
def tensor_cast(
    custom_caster: Callable = None,
    callback: Callable = None,
    dtype_mapping: Callable = None,
    device="cpu",
):
    """Decorator factory that casts a function's inputs into tensors.

    Every positional and keyword argument is passed through ``_walk``. Dict
    arguments are traversed recursively and updated in place, and each
    non-dict leaf is handed to the caster.

    The default caster returns ``torch.Tensor`` leaves unchanged. Any other
    leaf is converted with ``torch.tensor`` to ``dtype_mapping(leaf.dtype)``
    on ``device``. Such leaves therefore need ``dtype`` and ``copy``
    attributes (e.g. numpy arrays).

    Args:
        custom_caster (Callable, optional): Replaces the default caster
            entirely. Defaults to None.
        callback (Callable, optional): Called with the wrapped function's
            return value. Its own result is ignored. Defaults to None.
        dtype_mapping (Callable, optional): Function mapping a leaf's
            ``dtype`` to a torch dtype. Defaults to ``_default_dtype_mapping``.
        device (optional): Device for tensors built by the default caster.
            Defaults to "cpu".

    Returns:
        Callable: A decorator.
    """
    dtype_mapping = dtype_mapping or _default_dtype_mapping

    def _default_caster(x):
        if isinstance(x, torch.Tensor):
            return x
        # Build directly in the target dtype. A float32 intermediate (e.g.
        # torch.FloatTensor) would round integers above 2**24. .copy() gives
        # torch a contiguous array that does not alias the caller's data.
        return torch.tensor(x.copy(), dtype=dtype_mapping(x.dtype), device=device)

    cast_to_tensor = custom_caster or _default_caster

    def decorator(func):
        @functools.wraps(func)
        def wrap(*args, **kwargs):
            new_args = [_walk(cast_to_tensor, arg) for arg in args]
            new_kwargs = {k: _walk(cast_to_tensor, v) for k, v in kwargs.items()}
            rets = func(*new_args, **new_kwargs)
            if callback is not None:
                callback(rets)
            return rets

        return wrap

    return decorator
def frozen_data(data):
    """Compute an integer hash for possibly nested, possibly unhashable data.

    * dict: XOR of ``hash((key, frozen_data(value)))`` over its entries.
    * list/tuple: XOR of ``frozen_data(element)`` over its elements.
    * anything else: ``hash(data)``.

    Because containers are combined with XOR, the result does not depend on
    entry order. Nested dicts and lists are supported at any depth, including
    inside sequences.

    NOTE(review): XOR also lets duplicate sequence elements cancel out (e.g.
    ``[1, 1]`` and ``[]`` both give 0). Confirm that collisions of this kind
    are acceptable for callers.
    """
    if isinstance(data, dict):
        parts = (hash((k, frozen_data(v))) for k, v in data.items())
    elif isinstance(data, (list, tuple)):
        # Recurse (instead of plain hash) so dicts/lists nested inside a
        # sequence are handled the same way the dict branch handles them.
        parts = (frozen_data(e) for e in data)
    else:
        return hash(data)
    _hash = 0
    for part in parts:
        _hash ^= part
    return _hash