# Source code for malib.rl.dqn.trainer

# MIT License

# Copyright (c) 2021 MARL @ SJTU

# Author: Ming Zhou

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from argparse import Namespace
from typing import Tuple, Sequence

import copy

import torch
import numpy as np

from torch.nn import functional as F

from malib.utils.typing import AgentID

from malib.rl.common import misc
from malib.rl.common.trainer import Trainer
from malib.utils.schedules import LinearSchedule
from malib.utils.tianshou_batch import Batch


class DQNTrainer(Trainer):
    """Deep Q-learning trainer with an epsilon schedule and a soft-updated target critic.

    Training-config keys read by this class: ``exploration_fraction``,
    ``total_timesteps``, ``exploration_final_eps``, optional ``pretrain_eps``,
    ``optimizer`` (name of a class in ``torch.optim``), ``lr``, ``gamma`` and
    ``tau``.
    """

    def setup(self):
        """Create the exploration schedule, the target critic and the optimizer."""
        exploration_fraction = self._training_config["exploration_fraction"]
        total_timesteps = self._training_config["total_timesteps"]
        exploration_final_eps = self._training_config["exploration_final_eps"]
        # NOTE(review): ``fixed_eps`` is not read anywhere in this class;
        # presumably consumed elsewhere — confirm before removing.
        self.fixed_eps = self._training_config.get("pretrain_eps")
        # Schedule configured to go from 1.0 to ``exploration_final_eps`` over
        # ``exploration_fraction * total_timesteps`` steps.
        self.exploration = LinearSchedule(
            schedule_timesteps=int(exploration_fraction * total_timesteps),
            initial_p=1.0,
            final_p=exploration_final_eps,
        )
        optim_cls = getattr(torch.optim, self.training_config["optimizer"])
        # The target network starts as an exact copy of the online critic; it is
        # moved towards the critic via ``misc.soft_update`` after each update.
        self.target_critic = copy.deepcopy(self.policy.critic)
        self.optimizer: torch.optim.Optimizer = optim_cls(
            self.policy.critic.parameters(), lr=self.training_config["lr"]
        )

    def post_process(self, batch: Batch, agent_filter: Sequence[AgentID]) -> Batch:
        """Refresh the policy's epsilon and flatten the agent axis of the batch.

        Args:
            batch: Sampled training batch. Its numpy fields are reshaped in place
                when the policy has an agent dimension.
            agent_filter: Not used by this implementation.

        Returns:
            The same ``batch`` object.
        """
        # set exploration rate for policy
        update_eps = self.exploration.value(self.counter)
        self.policy.eps = update_eps
        if self.policy.agent_dimension > 0:
            for key, value in batch.items():
                if isinstance(value, np.ndarray):
                    # (batch_size, agent_num, *inner) -> (batch_size * agent_num, *inner)
                    batch[key] = value.reshape((-1,) + value.shape[2:])
        return batch

    def train(self, batch: Batch):
        """Perform one Q-learning gradient step on ``batch``.

        The loss is the MSE between the online critic's Q(s, a) and the
        one-step target ``rew + gamma * (1 - done) * max_a' Q_target(s', a')``.
        When ``act_mask_next`` is present, ``-1e9`` is added to every next-state
        Q value whose mask entry is 0 before taking the max. After the optimizer
        step the target critic is updated with ``misc.soft_update`` using ``tau``.

        Returns:
            A dict of scalar statistics (loss, target/eval/reward
            mean/min/max, and the current epsilon).
        """
        # NOTE(review): ``squeeze()`` also removes the batch axis when the batch
        # holds a single sample, which would break the ``gather`` below — confirm
        # the expected observation shape before relying on batch size 1.
        q_values, _ = self.policy.critic(batch.obs.squeeze())
        state_action_values = q_values.gather(
            -1, batch.act.long().view((-1, 1))
        ).view(-1)

        # The bootstrap target is a constant for the loss; building it under
        # no_grad avoids recording an autograd graph through the target critic.
        with torch.no_grad():
            next_state_q, _ = self.target_critic(batch.obs_next.squeeze())
            next_action_mask = batch.get("act_mask_next", None)
            if next_action_mask is not None:
                # Cast first: ``1.0 - mask`` raises for a bool tensor.
                illegal_action_mask = 1.0 - next_action_mask.to(next_state_q.dtype)
                # give very low value to illegal action logits
                next_state_q = next_state_q - illegal_action_mask * 1e9
            next_state_action_values = next_state_q.max(-1).values

            assert (
                batch.rew.shape == batch.done.shape == next_state_action_values.shape
            ), (
                batch.rew.shape,
                batch.done.shape,
                next_state_action_values.shape,
            )
            expected_state_values = (
                batch.rew.float()
                + self._training_config["gamma"]
                * (1.0 - batch.done.float())
                * next_state_action_values
            )

        assert expected_state_values.shape == state_action_values.shape, (
            expected_state_values.shape,
            state_action_values.shape,
        )

        self.optimizer.zero_grad()
        loss = F.mse_loss(state_action_values, expected_state_values)
        loss.backward()
        self.optimizer.step()

        misc.soft_update(
            self.target_critic, self.policy.critic, tau=self._training_config["tau"]
        )

        return {
            "loss": loss.item(),
            "mean_target": expected_state_values.mean().item(),
            "mean_eval": state_action_values.mean().item(),
            "min_eval": state_action_values.min().item(),
            "max_eval": state_action_values.max().item(),
            "max_target": expected_state_values.max().item(),
            "min_target": expected_state_values.min().item(),
            "mean_reward": batch.rew.mean().item(),
            "min_reward": batch.rew.min().item(),
            "max_reward": batch.rew.max().item(),
            "eps": self.policy.eps,
        }