malib.utils package
Submodules
malib.utils.data module
malib.utils.episode module
- class malib.utils.episode.Episode(agents: List[str], processors=None)[source]
Bases:
object
Multi-agent episode tracking
- ACC_REWARD = 'accumulate_reward'
- ACTION = 'act'
- ACTION_DIST = 'act_dist'
- ACTION_LOGITS = 'act_logits'
- ACTION_MASK = 'act_mask'
- ADVANTAGE = 'advantage'
- CUR_OBS = 'obs'
- CUR_STATE = 'state'
- DONE = 'done'
- INFO = 'infos'
- LAST_REWARD = 'last_reward'
- NEXT_ACTION_MASK = 'act_mask_next'
- NEXT_OBS = 'obs_next'
- NEXT_STATE = 'state_next'
- PRE_DONE = 'pre_done'
- PRE_REWARD = 'pre_rew'
- REWARD = 'rew'
- RNN_STATE = 'rnn_state'
- STATE_ACTION_VALUE = 'state_action_value_estimation'
- STATE_VALUE = 'state_value_estimation'
- STATE_VALUE_TARGET = 'state_value_target'
- record(data: Dict[str, Dict[str, Any]], agent_first: bool, ignore_keys={})[source]
Save a transition. The given transition is a subsequence of (obs, action_mask, reward, done, info). Users can specify ignore_keys to filter out keys.
- Parameters:
data (Dict[str, Dict[AgentID, Any]]) – A transition.
ignore_keys (dict, optional) – Keys to filter out of the transition. Defaults to {}.
- class malib.utils.episode.NewEpisodeDict[source]
Bases:
defaultdict
Episode dict, for trajectory tracking for a bunch of environments.
malib.utils.exploitability module
- class malib.utils.exploitability.NFSPPolicies(game, nfsp_policies: List[TabularPolicy])[source]
Bases:
Policy
Joint policy to be evaluated.
Initializes a policy.
- Parameters:
game – the game for which this policy applies
player_ids – list of player ids for which this policy applies; each should be in the range 0..game.num_players()-1.
- action_probabilities(state: Any, player_id: Optional[str] = None)[source]
Returns a dictionary {action: prob} for all legal actions.
IMPORTANT: We assume the following properties hold:
- All probabilities are >=0 and sum to 1.
- TLDR: Policy implementations should list the (action, prob) for all legal actions, but algorithms should not rely on this (yet).
Details: Before May 2020, only legal actions were present in the mapping, but it did not have to be exhaustive: missing actions were considered to have zero probability. For example, a deterministic state-policy was previously {action: 1.0}. Given that this change of convention is new and hard to enforce, algorithms should not rely on all legal actions being present.
- Parameters:
state – A pyspiel.State object.
player_id – Optional, the player id for whom we want an action. Optional unless this is a simultaneous state at which multiple players can act.
- Returns:
A dict of `{action: probability}` for the specified player in the supplied state.
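The convention described above (an exhaustive mapping over legal actions, with missing actions treated as zero probability) can be illustrated with a small helper. This is a minimal sketch in plain Python, not part of MALib or OpenSpiel; the function name `complete_action_probabilities` and its `tolerance` parameter are hypothetical.

```python
from typing import Dict, Iterable


def complete_action_probabilities(
    legal_actions: Iterable[int],
    sparse_probs: Dict[int, float],
    tolerance: float = 1e-5,
) -> Dict[int, float]:
    """Return {action: prob} over all legal actions (hypothetical helper).

    Actions missing from `sparse_probs` are assigned probability 0.0,
    following the pre-May-2020 convention described above.
    """
    legal = list(legal_actions)
    illegal = set(sparse_probs) - set(legal)
    if illegal:
        raise ValueError(f"Probabilities given for illegal actions: {sorted(illegal)}")
    full = {action: sparse_probs.get(action, 0.0) for action in legal}
    # Check the two assumed properties: non-negative and summing to 1.
    if any(p < 0 for p in full.values()):
        raise ValueError("Probabilities must be >= 0")
    if abs(sum(full.values()) - 1.0) > tolerance:
        raise ValueError("Probabilities must sum to 1")
    return full


# A deterministic policy given in the older sparse form {action: 1.0}.
dense = complete_action_probabilities([0, 1, 2], {1: 1.0})
```

Code that consumes a policy can apply such a helper defensively, so that it does not depend on every legal action being present in the returned mapping.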
- class malib.utils.exploitability.OSPolicyWrapper(game, policy: Policy, player_ids: List[int], use_observation, tolerance: float = 1e-05)[source]
Bases:
Policy
Initializes a policy.
- Parameters:
game – the game for which this policy applies
player_ids – list of player ids for which this policy applies; each should be in the range 0..game.num_players()-1.
- action_probabilities(state, player_id=None)[source]
Returns a dictionary {action: prob} for all legal actions.
IMPORTANT: We assume the following properties hold:
- All probabilities are >=0 and sum to 1.
- TLDR: Policy implementations should list the (action, prob) for all legal actions, but algorithms should not rely on this (yet).
Details: Before May 2020, only legal actions were present in the mapping, but it did not have to be exhaustive: missing actions were considered to have zero probability. For example, a deterministic state-policy was previously {action: 1.0}. Given that this change of convention is new and hard to enforce, algorithms should not rely on all legal actions being present.
- Parameters:
state – A pyspiel.State object.
player_id – Optional, the player id for whom we want an action. Optional unless this is a simultaneous state at which multiple players can act.
- Returns:
A dict of `{action: probability}` for the specified player in the supplied state.
- malib.utils.exploitability.compute_act_probs(game: Game, policy: Policy, state: State, player_id: int, use_observation, epsilon: float = 1e-05)[source]
- malib.utils.exploitability.convert_to_os_policies(game, policies: List[Policy], use_observation: bool, player_ids: List[int]) List[Policy] [source]
- malib.utils.exploitability.measure_exploitability(game: Union[str, Game], populations: Dict[str, Dict[str, Policy]], policy_mixture_dict: Dict[str, Dict[str, float]], use_observation: bool = False, use_cpp_br: bool = False)[source]
Return a measure of closeness to Nash for a policy in the game.
- Parameters:
game (Union[str, pyspiel.Game]) – An open_spiel game, e.g. kuhn_poker.
populations (Dict[AgentID, Dict[PolicyID, Policy]]) – A dict of strategy specs, mapping from agent to StrategySpec.
policy_mixture_dict (Dict[AgentID, Dict[PolicyID, float]]) – A dict of policy distributions, mapping from agent to a dict of floats.
use_cpp_br (bool, optional) – Compute best response with C++. Defaults to False.
- Returns:
An object with the following attributes:
- player_improvements: A [num_players] numpy array of the improvement for players (i.e. value_player_p_versus_BR - value_player_p).
- nash_conv: The sum over all players of the improvements in value that each player could obtain by unilaterally changing their strategy, i.e. sum(player_improvements).
- Return type:
NashConv
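The relation between player_improvements and nash_conv can be worked through on a two-player matrix game. The following is a minimal sketch in plain Python that does not use OpenSpiel or MALib; the helper names `expected_value` and `nash_conv_matrix_game` are hypothetical, and best responses are computed over pure strategies only.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def expected_value(payoff: Matrix, row: Sequence[float], col: Sequence[float]) -> float:
    """Expected payoff when the row and column players use mixed strategies."""
    return sum(
        row[i] * col[j] * payoff[i][j]
        for i in range(len(row))
        for j in range(len(col))
    )


def nash_conv_matrix_game(
    row_payoff: Matrix,
    col_payoff: Matrix,
    row: Sequence[float],
    col: Sequence[float],
) -> Tuple[List[float], float]:
    """Return (player_improvements, nash_conv) for a two-player matrix game."""
    row_value = expected_value(row_payoff, row, col)
    col_value = expected_value(col_payoff, row, col)
    # Best-response value: the best pure action against the opponent's mixture.
    row_br = max(
        sum(col[j] * row_payoff[i][j] for j in range(len(col)))
        for i in range(len(row))
    )
    col_br = max(
        sum(row[i] * col_payoff[i][j] for i in range(len(row)))
        for j in range(len(col))
    )
    improvements = [row_br - row_value, col_br - col_value]
    return improvements, sum(improvements)


# Matching pennies: the row player wins on a match, the column player otherwise.
ROW = [[1.0, -1.0], [-1.0, 1.0]]
COL = [[-1.0, 1.0], [1.0, -1.0]]

uniform = nash_conv_matrix_game(ROW, COL, [0.5, 0.5], [0.5, 0.5])
exploitable = nash_conv_matrix_game(ROW, COL, [1.0, 0.0], [0.5, 0.5])
```

At the uniform profile neither player can improve, so nash_conv is zero. When the row player always picks the first action, the column player gains 1 by best responding, and that gain is the whole of nash_conv.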
malib.utils.general module
- malib.utils.general.deep_update(original: dict, new_dict: dict, new_keys_allowed: str = False, allow_new_subkey_list: Optional[List[str]] = None, override_all_if_type_changes: Optional[List[str]] = None) dict [source]
Updates original dict with values from new_dict recursively.
If new_dict introduces a new key and new_keys_allowed is not True, an error is raised. For sub-dicts, new subkeys can be introduced if the key is in allow_new_subkey_list.
- Parameters:
original (dict) – Dictionary with default values.
new_dict (dict) – Dictionary with values to be updated
new_keys_allowed (bool) – Whether new keys are allowed.
allow_new_subkey_list (Optional[List[str]]) – List of keys that correspond to dict values where new subkeys can be introduced. This is only at the top level.
override_all_if_type_changes (Optional[List[str]]) – List of top level keys with value=dict, for which we always simply override the entire value (dict), iff the “type” key in that value dict changes.
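The recursive update can be sketched as follows. This simplified version covers only the recursion and the new_keys_allowed check; it omits allow_new_subkey_list and override_all_if_type_changes, and raising `KeyError` is an assumption, since the text above only says that an error is thrown. The name `deep_update_sketch` is hypothetical.

```python
def deep_update_sketch(original: dict, new_dict: dict, new_keys_allowed: bool = False) -> dict:
    """Recursively update `original` in place with values from `new_dict`."""
    for key, value in new_dict.items():
        if key not in original and not new_keys_allowed:
            # Assumption: the real function may raise a different exception type.
            raise KeyError(f"Unknown key: {key!r}")
        if isinstance(original.get(key), dict) and isinstance(value, dict):
            # Both sides are dicts: merge them instead of overwriting.
            deep_update_sketch(original[key], value, new_keys_allowed)
        else:
            original[key] = value
    return original


defaults = {"lr": 0.1, "model": {"layers": 2, "activation": "relu"}}
updated = deep_update_sketch(defaults, {"model": {"layers": 3}})
```

Only `model.layers` changes here; sibling keys such as `model.activation` keep their default values.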
- malib.utils.general.flatten_dict(dt: Dict, delimiter: str = '/', prevent_delimiter: bool = False, flatten_list: bool = False)[source]
Flatten dict.
Output and input are of the same dict type. Input dict remains the same after the operation.
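A minimal sketch of the flattening idea, assuming nested keys are joined with the delimiter. It ignores the prevent_delimiter and flatten_list options and always returns a plain dict, so it is an illustration and not the library implementation; the name `flatten_dict_sketch` is hypothetical.

```python
def flatten_dict_sketch(dt: dict, delimiter: str = "/") -> dict:
    """Return a new flat dict; the input dict is left unchanged."""
    flat = {}
    for key, value in dt.items():
        if isinstance(value, dict) and value:
            # Recurse into non-empty sub-dicts and prefix their keys.
            for sub_key, sub_value in flatten_dict_sketch(value, delimiter).items():
                flat[f"{key}{delimiter}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
flat = flatten_dict_sketch(nested)
```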
- malib.utils.general.iter_dicts_recursively(d1, d2)[source]
Assuming dicts have the exact same structure.
- malib.utils.general.iter_many_dicts_recursively(*d, history=None)[source]
Assuming dicts have the exact same structure, or raise KeyError.
- malib.utils.general.merge_dicts(d1: dict, d2: dict) dict [source]
- Parameters:
d1 (dict) – Dict 1, the original dict template.
d2 (dict) – Dict 2, the new dict used to update d1.
- Returns:
A new dict that is d1 and d2 deep merged.
- Return type:
dict
- malib.utils.general.tensor_cast(custom_caster: Optional[Callable] = None, callback: Optional[Callable] = None, dtype_mapping: Optional[Dict] = None, device='cpu')[source]
Casting the inputs of a method into tensors if needed.
Note
This function does not support recursive iteration.
- Parameters:
custom_caster (Callable, optional) – Customized caster. Defaults to None.
callback (Callable, optional) – Callback function, accepts returns of wrapped function as inputs. Defaults to None.
dtype_mapping (Dict, optional) – Data types to cast the inputs to. Defaults to None.
- Returns:
A decorator.
- Return type:
Callable
- malib.utils.general.unflatten_dict(dt: Dict[str, T], delimiter: str = '/') Dict[str, T] [source]
Unflatten dict. Does not support unflattening lists.
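The inverse operation can be sketched in a few lines, assuming each flat key is split on the delimiter and intermediate dicts are created as needed. As in the description above, lists are not reconstructed. The name `unflatten_dict_sketch` is hypothetical.

```python
def unflatten_dict_sketch(dt: dict, delimiter: str = "/") -> dict:
    """Rebuild a nested dict from delimiter-joined keys (dicts only, no lists)."""
    out: dict = {}
    for flat_key, value in dt.items():
        path = flat_key.split(delimiter)
        node = out
        for part in path[:-1]:
            # Create intermediate dicts on demand.
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return out


restored = unflatten_dict_sketch({"a/b": 1, "a/c/d": 2, "e": 3})
```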
- malib.utils.general.unflatten_list_dict(dt: Dict[str, T], delimiter: str = '/') Dict[str, T] [source]
Unflatten nested dict and list.
This function currently has some limitations:
(1) The keys of dt must be str.
(2) If the unflattened dt (the result) contains a list, the index order must be ascending when accessing dt. Otherwise, this function will throw AssertionError.
(3) The unflattened dt (the result) shouldn’t contain dicts with number keys.
Use this function with care. If you want to improve this function, please also improve the unit test. See #14487 for more details.
- Parameters:
dt (dict) – Flattened dictionary that is originally nested by multiple list and dict.
delimiter (str) – Delimiter of keys.
Example
>>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
>>> unflatten_list_dict(dt)
{'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
- malib.utils.general.unflattened_lookup(flat_key: str, lookup: Union[Mapping, Sequence], delimiter: str = '/', **kwargs) Union[Mapping, Sequence] [source]
Unflatten flat_key and iteratively look up in lookup. E.g. flat_key="a/0/b" will try to return lookup["a"][0]["b"].
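The lookup described above can be sketched as a loop over the key parts, indexing mappings by string and sequences by integer. This is a simplified illustration that ignores the extra **kwargs of the real function; the name `unflattened_lookup_sketch` is hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Sequence, Union


def unflattened_lookup_sketch(
    flat_key: str,
    lookup: Union[Mapping, Sequence],
    delimiter: str = "/",
) -> Any:
    """Walk `lookup` along the parts of `flat_key`."""
    node: Any = lookup
    for part in flat_key.split(delimiter):
        if isinstance(node, Mapping):
            node = node[part]
        else:
            # Sequences are indexed by position, so the part must be an integer.
            node = node[int(part)]
    return node


config = {"a": [{"b": 42}, {"b": 7}]}
found = unflattened_lookup_sketch("a/0/b", config)
```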
- malib.utils.general.update_configs(runtime_config: Dict[str, Any])[source]
Update global configs with a given dict
- malib.utils.general.update_dataset_config(global_dict: Dict[str, Any], runtime_config: Dict[str, Any])[source]
- malib.utils.general.update_evaluation_config(global_dict: Dict[str, Any], runtime_config: Dict[str, Any])[source]
- malib.utils.general.update_global_evaluator_config(global_dict: Dict[str, Any], runtime_config: Dict[str, Any])[source]
- malib.utils.general.update_parameter_server_config(global_dict: Dict[str, Any], runtime_config: Dict[str, Any])[source]
- malib.utils.general.update_rollout_configs(global_dict: Dict[str, Any], runtime_dict: Dict[str, Any]) Dict[str, Any] [source]
Update default rollout configuration and return a new one.
Note
The keys in the rollout configuration include:
- num_threads: int, the total number of threads in a rollout worker used to run simulations.
- num_env_per_thread: int, how many environments will be created for each running thread.
- batch_mode: defaults to ‘time_step’.
- post_processor_types: defaults to [‘default’].
- use_subprov_env: whether to use a sub-process environment, defaults to False.
- num_eval_threads: the number of threads for evaluation, defaults to 1.
- Parameters:
global_dict (Dict[str, Any]) – The default global configuration.
runtime_dict (Dict[str, Any]) – The runtime configuration used to update the defaults.
- Returns:
Updated rollout configuration.
- Return type:
Dict[str, Any]
malib.utils.logging module
malib.utils.monitor module
- malib.utils.monitor.write_to_tensorboard(writer: SummaryWriter, info: Dict, global_step: Union[int, Dict], prefix: str)[source]
Write learning info to tensorboard.
- Parameters:
writer (tensorboard.SummaryWriter) – The summary writer instance.
info (Dict) – The information dict.
global_step (Union[int, Dict]) – The global step indicator.
prefix (str) – Prefix added to keys in the info dict.
malib.utils.notations module
- malib.utils.notations.AGENT_EXPERIENCE_TABLE_NAME_GEN(env_id, policy_id, policy_type)
- malib.utils.notations.EPISODE_EXPERIENCE_TABLE_NAME_GEN(env_id)
malib.utils.preprocessor module
- class malib.utils.preprocessor.BoxFlattenPreprocessor(space: Box)[source]
Bases:
Preprocessor
- property shape
- property size
- class malib.utils.preprocessor.BoxStackedPreprocessor(space: Box)[source]
Bases:
Preprocessor
- property shape
- property size
- class malib.utils.preprocessor.DictFlattenPreprocessor(space: Dict)[source]
Bases:
Preprocessor
- property shape
- property size
- class malib.utils.preprocessor.DiscreteFlattenPreprocessor(space: Discrete)[source]
Bases:
Preprocessor
- property shape
- property size
- class malib.utils.preprocessor.Preprocessor(space: Space)[source]
Bases:
object
- property observation_space
- property original_space: Space
- property shape
- property size
- class malib.utils.preprocessor.TupleFlattenPreprocessor(space: Tuple)[source]
Bases:
Preprocessor
Init a tuple flatten preprocessor, which will stack inner flattened spaces.
Note
All sub spaces in a tuple should be homogeneous.
- Parameters:
space (spaces.Tuple) – A tuple of homogeneous spaces.
- property shape
- property size
malib.utils.replay_buffer module
- class malib.utils.replay_buffer.MultiagentReplayBuffer(size: int, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs)[source]
Bases:
ReplayBuffer
malib.utils.schedules module
This file is used for specifying various schedules that evolve over time throughout the execution of the algorithm, such as:
learning rate for the optimizer
exploration epsilon for the epsilon greedy exploration strategy
beta parameter in prioritized replay
Each schedule has a function value(t) which returns the current value of the parameter given the timestep t of the optimization procedure.
- class malib.utils.schedules.ConstantSchedule(value)[source]
Bases:
object
Value remains constant over time.
- Parameters:
value (float) – Constant value of the schedule
- class malib.utils.schedules.LinearSchedule(schedule_timesteps, final_p, initial_p=1.0)[source]
Bases:
object
Linear interpolation between initial_p and final_p over schedule_timesteps. After this many timesteps pass, final_p is returned.
- Parameters:
schedule_timesteps – Number of timesteps over which to linearly anneal initial_p to final_p
initial_p (float) – initial output value
final_p (float) – final output value
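The annealing rule can be written out directly. This is a minimal sketch of a schedule with the value(t) interface mentioned in the module description, assuming the fraction of elapsed time is clamped at 1; the class name `LinearScheduleSketch` is hypothetical.

```python
class LinearScheduleSketch:
    """Linearly anneal from initial_p to final_p over schedule_timesteps."""

    def __init__(self, schedule_timesteps: int, final_p: float, initial_p: float = 1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t: int) -> float:
        # Fraction of the schedule that has elapsed, clamped so that
        # final_p is returned once schedule_timesteps have passed.
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


# Example: an exploration epsilon annealed from 1.0 to 0.1 over 100 steps.
epsilon = LinearScheduleSketch(schedule_timesteps=100, final_p=0.1, initial_p=1.0)
values = [epsilon.value(t) for t in (0, 50, 100, 200)]
```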
- class malib.utils.schedules.PiecewiseSchedule(endpoints, interpolation=<function linear_interpolation>, outside_value=None)[source]
Bases:
object
Piecewise schedule.
- endpoints: [(int, int)]
list of pairs (time, value) meaning that the schedule should output value when t==time. All the values for time must be sorted in increasing order. When t is between two times, e.g. (time_a, value_a) and (time_b, value_b), such that time_a <= t < time_b, then value outputs interpolation(value_a, value_b, alpha), where alpha is the fraction of time passed between time_a and time_b for time t.
- interpolation: lambda float, float, float: float
a function that takes the values to the left and to the right of t according to the endpoints. Alpha is the fraction of the distance from the left endpoint to the right endpoint that t has covered. See linear_interpolation for an example.
- outside_value: float
if the value is requested outside of all the intervals specified in endpoints, this value is returned. If None, then AssertionError is raised when an outside value is requested.
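The interval lookup and interpolation described above can be sketched as follows. The class name `PiecewiseScheduleSketch` is hypothetical, and the `linear_interpolation` defined here is an assumed stand-in for the function of the same name referenced in the description.

```python
from typing import Callable, List, Optional, Tuple


def linear_interpolation(left: float, right: float, alpha: float) -> float:
    """Interpolate between left and right; alpha is in [0, 1)."""
    return left + alpha * (right - left)


class PiecewiseScheduleSketch:
    def __init__(
        self,
        endpoints: List[Tuple[int, float]],
        interpolation: Callable[[float, float, float], float] = linear_interpolation,
        outside_value: Optional[float] = None,
    ):
        times = [time for time, _ in endpoints]
        assert times == sorted(times), "endpoint times must be increasing"
        self.endpoints = endpoints
        self.interpolation = interpolation
        self.outside_value = outside_value

    def value(self, t: int) -> float:
        # Find the interval [time_a, time_b) that contains t.
        for (time_a, value_a), (time_b, value_b) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if time_a <= t < time_b:
                alpha = float(t - time_a) / (time_b - time_a)
                return self.interpolation(value_a, value_b, alpha)
        # t is outside every interval.
        assert self.outside_value is not None, "t is outside the schedule"
        return self.outside_value


schedule = PiecewiseScheduleSketch([(0, 1.0), (10, 0.5), (20, 0.1)], outside_value=0.1)
samples = [schedule.value(t) for t in (0, 5, 10, 25)]
```

Because each interval is half-open, a query at t == 10 falls into the second interval with alpha equal to 0 and returns the value of that endpoint.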
- class malib.utils.schedules.PowerSchedule(schedule_timesteps, final_p, initial_p=1.0)[source]
Bases:
object
malib.utils.statistic module
malib.utils.stopping_conditions module
- class malib.utils.stopping_conditions.MaxIterationStopping(max_iteration: int)[source]
Bases:
StoppingCondition
- class malib.utils.stopping_conditions.MergeStopping(stoppings: List[StoppingCondition])[source]
Bases:
StoppingCondition
- class malib.utils.stopping_conditions.NoStoppingCondition[source]
Bases:
StoppingCondition
- class malib.utils.stopping_conditions.RewardImprovementStopping(mininum_reward_improvement: float)[source]
Bases:
StoppingCondition
- class malib.utils.stopping_conditions.StopImmediately[source]
Bases:
StoppingCondition
malib.utils.tasks_register module
malib.utils.tianshou_batch module
- class malib.utils.tianshou_batch.Batch(batch_dict: Optional[Union[dict, Batch, Sequence[Union[dict, Batch]], ndarray]] = None, copy: bool = False, **kwargs: Any)[source]
Bases:
object
The internal data structure in Tianshou.
Batch is a kind of supercharged array (of temporal data) stored individually in a (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easy to access, manipulate and set partial views of the heterogeneous data.
For a detailed description, please refer to batch_concept.
- static cat(batches: Sequence[Union[dict, Batch]]) Batch [source]
Concatenate a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with appropriate shapes. E.g.
>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
- cat_(batches: Union[Batch, Sequence[Union[dict, Batch]]]) None [source]
Concatenate a list of (or one) Batch objects into current batch.
- static empty(batch: Batch, index: Optional[Union[slice, int, ndarray, List[int]]] = None) Batch [source]
Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
- empty_(index: Optional[Union[slice, int, ndarray, List[int]]] = None) Batch [source]
Return an empty Batch object with 0 or None filled.
If “index” is specified, it will only reset the specific indexed-data.
>>> data.empty_()
>>> print(data)
Batch(
    a: array([[0., 0.],
              [0., 0.]]),
    b: array([None, None], dtype=object),
)
>>> b={'c': [2., 'st'], 'd': [1., 0.]}
>>> data = Batch(a=[False, True], b=b)
>>> data[0] = Batch.empty(data[1])
>>> data
Batch(
    a: array([False, True]),
    b: Batch(
           c: array([None, 'st']),
           d: array([0., 0.]),
       ),
)
- is_empty(recurse: bool = False) bool [source]
Test if a Batch is empty.
If recurse=True, it further tests the values of the object; else it only tests the existence of any key.
b.is_empty(recurse=True) is mainly used to distinguish Batch(a=Batch(a=Batch())) and Batch(a=1). They both raise exceptions when applied to len(), but the former can be used in cat, while the latter is a scalar and cannot be used in cat.
Another usage is in __len__, where we have to skip checking the length of recursively empty Batch.
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
- property shape: List[int]
Return self.shape.
- split(size: int, shuffle: bool = True, merge_last: bool = False) Iterator[Batch] [source]
Split whole data into multiple small batches.
- Parameters:
size (int) – divide the data batch with the given size, but one batch if the length of the batch is smaller than “size”.
shuffle (bool) – randomly shuffle the entire data batch if it is True, otherwise keep the original order. Default to True.
merge_last (bool) – merge the last batch into the previous one. Default to False.
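The effect of size, shuffle and merge_last can be illustrated on a plain Python list. This is a sketch and not the Batch implementation: the function name `split_sketch` and its `seed` parameter are hypothetical, and it assumes that merge_last folds a trailing incomplete chunk into the chunk before it.

```python
import random
from typing import Iterator, List, Optional, Sequence


def split_sketch(
    data: Sequence,
    size: int,
    shuffle: bool = True,
    merge_last: bool = False,
    seed: Optional[int] = None,
) -> Iterator[List]:
    """Yield chunks of `data` holding `size` items each."""
    indices = list(range(len(data)))
    if shuffle:
        # Shuffle positions, not the data itself, so the input stays untouched.
        random.Random(seed).shuffle(indices)
    length = len(indices)
    for start in range(0, length, size):
        remaining = length - start
        if merge_last and size < remaining < 2 * size:
            # The next chunk would be the last full one followed by a short
            # remainder: emit both together and stop.
            yield [data[i] for i in indices[start:]]
            break
        yield [data[i] for i in indices[start:start + size]]


items = list(range(7))
plain = list(split_sketch(items, size=3, shuffle=False))
merged = list(split_sketch(items, size=3, shuffle=False, merge_last=True))
```

With seven items and size 3, the plain split ends in a chunk of one item, while merge_last produces a final chunk of four items.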
- static stack(batches: Sequence[Union[dict, Batch]], axis: int = 0) Batch [source]
Stack a list of Batch objects into a single new batch.
For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros. E.g.
>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)
Note
If there are keys that are not shared across all batches, stack with axis != 0 is undefined, and will cause an exception.
- stack_(batches: Sequence[Union[dict, Batch]], axis: int = 0) None [source]
Stack a list of Batch objects into the current batch.
malib.utils.timing module
malib.utils.typing module
- class malib.utils.typing.BColors[source]
Bases:
object
- BOLD = '\x1b[1m'
- ENDC = '\x1b[0m'
- FAIL = '\x1b[91m'
- HEADER = '\x1b[95m'
- OKBLUE = '\x1b[94m'
- OKCYAN = '\x1b[96m'
- OKGREEN = '\x1b[92m'
- UNDERLINE = '\x1b[4m'
- WARNING = '\x1b[93m'