# MIT License
# Copyright (c) 2021 MARL @ SJTU
# Author: Ming Zhou
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import itertools
from argparse import Namespace
from threading import Lock
from typing import Any, Dict, Optional, Sequence

import torch

from malib.common.strategy_spec import StrategySpec
from malib.remote.interface import RemoteInterface
from malib.rl.common.policy import Policy
from malib.utils.logging import Logger
class Table:
    """Lock-protected holder for one policy's weights, with an optional optimizer.

    Args:
        policy_meta_data (Dict[str, Any]): Must contain ``"policy_cls"`` and
            ``"kwargs"`` (a mapping unpacked into an ``argparse.Namespace``).
            When it also contains ``"optim_config"``, a policy instance is
            built from ``policy_cls`` and a ``torch.optim`` optimizer named by
            ``optim_config["type"]`` (learning rate ``optim_config["lr"]``) is
            created over that policy's parameters.
    """

    def __init__(self, policy_meta_data: Dict[str, Any]):
        policy_cls = policy_meta_data["policy_cls"]
        optim_config = policy_meta_data.get("optim_config")
        policy_init_kwargs = Namespace(**policy_meta_data["kwargs"])

        # Guards reads/writes of ``state_dict``.
        self.lock = Lock()
        # Weights are unset until ``set_weights`` is called.
        self.state_dict = None
        # Always define both attributes so accessing them never raises
        # AttributeError when no optimizer is configured.
        self.policy: Policy = None
        self.optimizer: torch.optim.Optimizer = None

        if optim_config is None:
            return

        self.policy = policy_cls(
            observation_space=policy_init_kwargs.observation_space,
            action_space=policy_init_kwargs.action_space,
            model_config=policy_init_kwargs.model_config,
            custom_config=policy_init_kwargs.custom_config,
            **policy_init_kwargs.kwargs,
        )
        # ``parameters()`` is consumed as a mapping whose values are iterables
        # of tensors (assumed from usage here -- confirm against Policy); all
        # of them are flattened into a single parameter stream.
        parameters = itertools.chain.from_iterable(
            self.policy.parameters().values()
        )
        optimizer_cls = getattr(torch.optim, optim_config["type"])
        self.optimizer = optimizer_cls(parameters, lr=optim_config["lr"])

    def set_weights(self, state_dict: Dict[str, Any]):
        """Replace the stored weights.

        Args:
            state_dict (Dict[str, Any]): A dict of weights; it is stored by
                reference, not copied.
        """
        with self.lock:
            self.state_dict = state_dict

    def apply_gradients(self, *gradients):
        """Apply gradients to the table's parameters (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Table.apply_gradients is not implemented yet")

    def get_weights(self) -> Dict[str, Any]:
        """Retrieve the stored weights.

        Returns:
            Dict[str, Any]: The object last passed to ``set_weights`` (the same
            object, not a copy), or ``None`` if no weights have been set.
        """
        with self.lock:
            return self.state_dict
class ParameterServer(RemoteInterface):
    """Stores policy weights as one :class:`Table` per ``{spec_id}/{spec_policy_id}`` key.

    Table creation is serialized by ``self.lock``; reads and writes of an
    individual table go through that table's own methods.
    """

    def __init__(self, **kwargs):
        # ``kwargs`` is accepted but not used by this class.
        self.tables: Dict[str, Table] = {}
        self.lock = Lock()

    def _table_name(self, spec_id: str, spec_policy_id: str) -> str:
        """Build the table key for a (strategy spec, policy) pair."""
        return f"{spec_id}/{spec_policy_id}"

    def start(self):
        """Log that the server has started (debug helper)."""
        Logger.info("Parameter server started")

    def apply_gradients(self, table_name: str, gradients: Sequence[Any]):
        """Apply gradients to a data table.

        Args:
            table_name (str): The specified table name.
            gradients (Sequence[Any]): Given gradients to update parameters.

        Raises:
            NotImplementedError: Not implemented yet.
        """
        raise NotImplementedError

    def get_weights(self, spec_id: str, spec_policy_id: str) -> Dict[str, Any]:
        """Retrieve weights of the table ``{spec_id}/{spec_policy_id}``.

        Args:
            spec_id (str): Strategy spec id.
            spec_policy_id (str): Related policy id.

        Returns:
            Dict[str, Any]: A dict with keys ``spec_id``, ``spec_policy_id``
            and ``weights`` (``None`` if no weights were set yet).

        Raises:
            KeyError: If no table exists for the given ids.
        """
        table_name = self._table_name(spec_id, spec_policy_id)
        weights = self.tables[table_name].get_weights()
        return {
            "spec_id": spec_id,
            "spec_policy_id": spec_policy_id,
            "weights": weights,
        }

    def set_weights(
        self, spec_id: str, spec_policy_id: str, state_dict: Dict[str, Any]
    ):
        """Set weights of the parameter table ``{spec_id}/{spec_policy_id}``.

        Args:
            spec_id (str): StrategySpec id.
            spec_policy_id (str): Policy id in the specified strategy spec.
            state_dict (Dict[str, Any]): A dict that specifies the parameters.

        Raises:
            KeyError: If no table exists for the given ids.
        """
        table_name = self._table_name(spec_id, spec_policy_id)
        self.tables[table_name].set_weights(state_dict)

    def create_table(self, strategy_spec: StrategySpec) -> Optional[str]:
        """Create parameter tables for a strategy spec.

        Traverses the policy ids in ``strategy_spec`` and creates a table for
        every policy id that does not have a corresponding table yet; existing
        tables are left untouched.

        Args:
            strategy_spec (StrategySpec): A strategy spec instance.

        Returns:
            Optional[str]: The table name (formatted as
            ``{strategy_spec_id}/{policy_id}``) of the last policy id visited,
            or ``None`` if the spec has no policy ids.
        """
        # Initialized up front: previously an empty ``policy_ids`` left this
        # name unbound and the final ``return`` raised UnboundLocalError.
        table_name = None
        with self.lock:
            for policy_id in strategy_spec.policy_ids:
                table_name = self._table_name(strategy_spec.id, policy_id)
                if table_name in self.tables:
                    continue
                meta_data = strategy_spec.get_meta_data().copy()
                self.tables[table_name] = Table(meta_data)
        return table_name