malib.backend package

Submodules

malib.backend.offline_dataset_server module

class malib.backend.offline_dataset_server.OfflineDataset(table_capacity: int, max_consumer_size: int = 1024)

Bases: RemoteInterface

Construct an offline dataset. It maintains a dict of data tables, one for each training instance.

Parameters:
  • table_capacity (int) – Table capacity, i.e., the buffer size of each data table.

  • max_consumer_size (int, optional) – The maximum concurrency. Defaults to 1024.
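As a rough mental model of "a dict of data tables, each bounded by table_capacity", the sketch below keeps one bounded buffer per table name. It is a hypothetical stand-in (ToyOfflineDataset, table()), not MALib's implementation, and it assumes that a full table evicts its oldest samples first, as a ring-style replay buffer would.

```python
from collections import deque
from typing import Any, Deque, Dict


class ToyOfflineDataset:
    """Hypothetical stand-in: one bounded buffer per table name."""

    def __init__(self, table_capacity: int, max_consumer_size: int = 1024) -> None:
        self.table_capacity = table_capacity
        self.max_consumer_size = max_consumer_size  # kept for parity; unused here
        self.tables: Dict[str, Deque[Any]] = {}

    def table(self, name: str) -> Deque[Any]:
        # Create the table lazily; deque(maxlen=...) drops the oldest item when full.
        if name not in self.tables:
            self.tables[name] = deque(maxlen=self.table_capacity)
        return self.tables[name]


ds = ToyOfflineDataset(table_capacity=3)
for step in range(5):
    ds.table("agent_0").append(step)
print(list(ds.table("agent_0")))  # [2, 3, 4]
```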

end_consumer_pipe(name: str)

Kill the consumer pipeline for the given table name.

Parameters:

name (str) – Name of related datatable.

end_producer_pipe(name: str)

Kill the producer pipeline for the given table name.

Parameters:

name (str) – The name of related data table.

start()
start_consumer_pipe(name: str, batch_size: int) → Tuple[str, Queue]

Start a consumer pipeline. If no table named name exists yet, this call blocks until the table has been created.

Parameters:
  • name (str) – Name of datatable.

  • batch_size (int) – Batch size.

Returns:

A tuple of table name and queue for retrieving samples.

Return type:

Tuple[str, Queue]
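The blocking behavior described above (a consumer waits until some producer creates the table) can be sketched with a condition variable. TableRegistry and its create() method are hypothetical names; the real server is a remote actor and may block differently. The returned queue is empty here because no reader thread feeds it, and batch_size is accepted but unused.

```python
import queue
import threading
import time
from typing import Dict, List, Tuple


class TableRegistry:
    """Hypothetical registry: consumers block until a producer creates the table."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._tables: Dict[str, List] = {}

    def create(self, name: str) -> None:
        with self._cond:
            self._tables.setdefault(name, [])
            self._cond.notify_all()  # wake every consumer waiting on a table name

    def start_consumer_pipe(self, name: str, batch_size: int) -> Tuple[str, queue.Queue]:
        with self._cond:
            # wait_for re-checks the predicate after each wakeup, so it returns
            # at once if the table already exists and tolerates spurious wakeups.
            self._cond.wait_for(lambda: name in self._tables)
        # batch_size would size the batches a reader pushes; unused in this sketch.
        return name, queue.Queue()


reg = TableRegistry()
result = {}
t = threading.Thread(target=lambda: result.update(pipe=reg.start_consumer_pipe("agent_0", 32)))
t.start()
time.sleep(0.1)  # the consumer is now blocked waiting for "agent_0"
reg.create("agent_0")
t.join(timeout=1)
print(result["pipe"][0])  # agent_0
```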

start_producer_pipe(name: str, stack_num: int = 1, ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, **kwargs) → Tuple[str, Queue]

Start a producer pipeline and create the data table if it does not exist.

Parameters:
  • name (str) – Name of the data table to access.

  • stack_num (int, optional) – Indicates how many steps are stacked in a single data sample. Defaults to 1.

  • ignore_obs_next (bool, optional) – Whether to ignore the next observation. Defaults to False.

  • save_only_last_obs (bool, optional) – Whether to save only the last observation frame. Defaults to False.

  • sample_avail (bool, optional) – Whether to sample action masks. Defaults to False.

Returns:

A tuple of the table name and a queue for inserting samples.

Return type:

Tuple[str, Queue]
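To illustrate what stack_num means for a single sample, the sketch below gathers the stack_num consecutive steps ending at a given index. The helper stack_frames is hypothetical, and the padding rule (repeat the first observation when fewer than stack_num steps precede the index) is one common frame-stacking convention, assumed here rather than confirmed for MALib; episode boundaries are ignored.

```python
from typing import List, Sequence


def stack_frames(obs: Sequence[int], index: int, stack_num: int) -> List[int]:
    """Hypothetical helper: return the stack_num steps ending at index.

    Steps before the start of obs are padded by repeating obs[0]
    (an assumed convention; episode boundaries are not handled).
    """
    start = index - stack_num + 1
    return [obs[max(i, 0)] for i in range(start, index + 1)]


obs = [10, 11, 12, 13]
print(stack_frames(obs, 3, stack_num=3))  # [11, 12, 13]
print(stack_frames(obs, 1, stack_num=3))  # [10, 10, 11]
```

With the default stack_num=1, each sample is just the single step at index.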

malib.backend.offline_dataset_server.read_table(marker: RWLockFair, buffer: Union[MultiagentReplayBuffer, ReplayBuffer], batch_size: int, reader: Queue)
malib.backend.offline_dataset_server.write_table(marker: RWLockFair, buffer: Union[MultiagentReplayBuffer, ReplayBuffer], writer: Queue)
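read_table and write_table are listed without descriptions. Judging only from their signatures, they appear to move samples between a queue and a buffer under a lock. The sketch below shows that pattern under stated assumptions: toy_write_table and toy_read_table are hypothetical, a plain threading.Lock stands in for RWLockFair (a readers-writer lock would additionally let several readers sample concurrently), a list stands in for the replay buffer, and a None sentinel is assumed to end the producer.

```python
import queue
import random
import threading
from typing import Any, List


def toy_write_table(lock: threading.Lock, buffer: List[Any], writer: queue.Queue) -> None:
    # Drain the producer queue into the buffer until a None sentinel arrives.
    while True:
        batch = writer.get()
        if batch is None:  # assumed sentinel: the producer pipe was ended
            break
        with lock:  # exclusive access while mutating the buffer
            buffer.extend(batch)


def toy_read_table(lock: threading.Lock, buffer: List[Any], batch_size: int, reader: queue.Queue) -> None:
    # Sample a batch under the lock, then hand it to the consumer queue.
    with lock:
        sample = random.sample(buffer, k=min(batch_size, len(buffer)))
    reader.put(sample)


lock = threading.Lock()
buffer: List[Any] = []
writer: queue.Queue = queue.Queue()
reader: queue.Queue = queue.Queue()

t = threading.Thread(target=toy_write_table, args=(lock, buffer, writer))
t.start()
writer.put([1, 2, 3])
writer.put([4, 5])
writer.put(None)  # stop the writer thread
t.join()

toy_read_table(lock, buffer, 2, reader)
batch = reader.get()
print(len(batch))  # 2
```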

malib.backend.parameter_server module

class malib.backend.parameter_server.ParameterServer(**kwargs)

Bases: RemoteInterface

apply_gradients(table_name: str, gradients: Sequence[Any])

Apply gradients to a data table.

Parameters:
  • table_name (str) – The specified table name.

  • gradients (Sequence[Any]) – Given gradients to update parameters.

Raises:

NotImplementedError – Not implemented yet.

create_table(strategy_spec: StrategySpec) → str

Create parameter tables for the given strategy spec. This function traverses the existing policy ids in this spec and creates a table for each policy id that has no corresponding table.

Parameters:

strategy_spec (StrategySpec) – A strategy spec instance.

Returns:

Table name formatted as {strategy_spec_id}/{policy_id}.

Return type:

str
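The "create only the missing tables" rule and the {strategy_spec_id}/{policy_id} naming can be sketched as below. ToyStrategySpec (with hypothetical id and policy_ids attributes) and ToyParameterServer are stand-ins, not MALib classes; for illustration the sketch returns the list of newly created table names, whereas the documented method returns a str.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyStrategySpec:
    # Hypothetical stand-in for StrategySpec: an id plus the policy ids it holds.
    id: str
    policy_ids: List[str] = field(default_factory=list)


class ToyParameterServer:
    def __init__(self) -> None:
        self.tables: Dict[str, dict] = {}

    def create_table(self, spec: ToyStrategySpec) -> List[str]:
        created = []
        for pid in spec.policy_ids:
            name = f"{spec.id}/{pid}"
            if name not in self.tables:  # skip policies that already have a table
                self.tables[name] = {}
                created.append(name)
        return created


ps = ToyParameterServer()
spec = ToyStrategySpec("spec_a", ["policy-0"])
print(ps.create_table(spec))  # ['spec_a/policy-0']
spec.policy_ids.append("policy-1")
print(ps.create_table(spec))  # ['spec_a/policy-1']
```

Calling create_table repeatedly on a growing spec is therefore idempotent for policies that already have a table.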

get_weights(spec_id: str, spec_policy_id: str) → Dict[str, Any]

Retrieve weights. Returns a dict that includes the keys spec_id, spec_policy_id and weights.

Parameters:
  • spec_id (str) – Strategy spec id.

  • spec_policy_id (str) – Related policy id.

Returns:

A dict with the keys spec_id, spec_policy_id and weights.

Return type:

Dict[str, Any]

set_weights(spec_id: str, spec_policy_id: str, state_dict: Dict[str, Any])

Set the weights of a parameter table. The table name is {spec_id}/{spec_policy_id}.

Parameters:
  • spec_id (str) – StrategySpec id.

  • spec_policy_id (str) – Policy id in the specified strategy spec.

  • state_dict (Dict[str, Any]) – A dict that specifies the parameters.
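A set/get round trip ties the two methods together: both resolve the same {spec_id}/{spec_policy_id} table name, and get_weights echoes the identifiers back alongside the weights. ToyParamStore is a hypothetical in-process stand-in; the real ParameterServer is a remote interface and may store weights differently.

```python
from typing import Any, Dict


class ToyParamStore:
    """Hypothetical stand-in: tables keyed by "{spec_id}/{spec_policy_id}"."""

    def __init__(self) -> None:
        self.tables: Dict[str, Dict[str, Any]] = {}

    def set_weights(self, spec_id: str, spec_policy_id: str, state_dict: Dict[str, Any]) -> None:
        self.tables[f"{spec_id}/{spec_policy_id}"] = dict(state_dict)

    def get_weights(self, spec_id: str, spec_policy_id: str) -> Dict[str, Any]:
        weights = self.tables[f"{spec_id}/{spec_policy_id}"]
        # Echo the identifiers back with the weights, matching the documented keys.
        return {"spec_id": spec_id, "spec_policy_id": spec_policy_id, "weights": weights}


store = ToyParamStore()
store.set_weights("spec_a", "policy-0", {"layer.weight": [0.5, -0.5]})
reply = store.get_weights("spec_a", "policy-0")
print(sorted(reply))  # ['spec_id', 'spec_policy_id', 'weights']
```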

start()

For debugging.

class malib.backend.parameter_server.Table(policy_meta_data: Dict[str, Any])

Bases: object

apply_gradients(*gradients)
get_weights() → Dict[str, Any]

Retrieve model weights.

Returns:

Weights dict

Return type:

Dict[str, Any]

set_weights(state_dict: Dict[str, Any])

Update the weights with the given state dict.

Parameters:

state_dict (Dict[str, Any]) – A dict of weights
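A per-policy table only needs to hold one state dict and hand it out on request. The sketch below is a hypothetical ToyTable, not MALib's Table: it assumes a lock to serialize concurrent set/get calls and deep copies so that later mutations by a caller cannot silently change the stored weights. Both are design choices made for the illustration.

```python
import copy
import threading
from typing import Any, Dict


class ToyTable:
    """Hypothetical Table sketch: stores a private copy of the weights."""

    def __init__(self, policy_meta_data: Dict[str, Any]) -> None:
        self.policy_meta_data = policy_meta_data
        self._lock = threading.Lock()
        self._weights: Dict[str, Any] = {}

    def set_weights(self, state_dict: Dict[str, Any]) -> None:
        with self._lock:
            self._weights = copy.deepcopy(state_dict)  # isolate from the caller's dict

    def get_weights(self) -> Dict[str, Any]:
        with self._lock:
            return copy.deepcopy(self._weights)  # callers get their own copy


table = ToyTable({"policy_id": "policy-0"})
weights = {"w": [1.0, 2.0]}
table.set_weights(weights)
weights["w"][0] = 99.0      # the caller mutates its own dict afterwards
print(table.get_weights())  # {'w': [1.0, 2.0]}
```

Deep copies are simple but cost memory for large models; a real parameter server may prefer versioned or shared tensors instead.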