From 2c271b3b44fdbac0343b1b5f654ad93e42ff20f1 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 5 May 2023 16:52:22 +0800 Subject: [PATCH 01/12] Runnable. Should setup a benchmark and test performance. --- examples/cim/rl/algorithms/dqn.py | 3 +- maro/rl/training/__init__.py | 3 +- maro/rl/training/algorithms/ddpg.py | 3 - maro/rl/training/algorithms/dqn.py | 52 ++++++++++------ maro/rl/training/algorithms/sac.py | 3 - maro/rl/training/replay_memory.py | 95 ++++++++++++++++++++++++++++- maro/rl/training/trainer.py | 3 +- 7 files changed, 134 insertions(+), 28 deletions(-) diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py index 022275552..4e40034b8 100644 --- a/examples/cim/rl/algorithms/dqn.py +++ b/examples/cim/rl/algorithms/dqn.py @@ -64,6 +64,7 @@ def get_dqn(name: str) -> DQNTrainer: num_epochs=10, soft_update_coef=0.1, double=False, - random_overwrite=False, + alpha=1.0, + beta=1.0, ), ) diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py index a77296f98..b3dfd8c61 100644 --- a/maro/rl/training/__init__.py +++ b/maro/rl/training/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from .proxy import TrainingProxy -from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory +from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, PriorityReplayMemory, RandomMultiReplayMemory, RandomReplayMemory from .train_ops import AbsTrainOps, RemoteOps, remote from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer from .training_manager import TrainingManager @@ -12,6 +12,7 @@ "TrainingProxy", "FIFOMultiReplayMemory", "FIFOReplayMemory", + "PriorityReplayMemory", "RandomMultiReplayMemory", "RandomReplayMemory", "AbsTrainOps", diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index aaa0b7454..bf7b0f8d4 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -261,9 +261,6 @@ def _register_policy(self, policy: RLPolicy) -> None: assert isinstance(policy, ContinuousRLPolicy) self._policy = policy - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> AbsTrainOps: return DDPGOps( name=self._policy.name, diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 5a4f938ab..0029bd818 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -2,12 +2,13 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Dict, cast +from typing import Dict, Tuple, cast +import numpy as np import torch from maro.rl.policy import RLPolicy, ValueBasedPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, PriorityReplayMemory, RemoteOps, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @@ -27,11 +28,12 @@ class DQNParams(BaseTrainerParams): sequentially with wrap-around. 
""" + alpha: float + beta: float num_epochs: int = 1 update_target_every: int = 5 soft_update_coef: float = 0.1 double: bool = False - random_overwrite: bool = False class DQNOps(AbsTrainOps): @@ -54,20 +56,21 @@ def __init__( self._reward_discount = reward_discount self._soft_update_coef = params.soft_update_coef self._double = params.double - self._loss_func = torch.nn.MSELoss() self._target_policy: ValueBasedPolicy = clone(self._policy) self._target_policy.set_name(f"target_{self._policy.name}") self._target_policy.eval() - def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: + def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the loss of the batch. Args: batch (TransitionBatch): Batch. + weight (np.ndarray): Weight of each data entry. Returns: loss (torch.Tensor): The loss of the batch. + td_error (torch.Tensor): TD-error of the batch. """ assert isinstance(batch, TransitionBatch) assert isinstance(self._policy, ValueBasedPolicy) @@ -79,6 +82,8 @@ def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: rewards = ndarray_to_tensor(batch.rewards, device=self._device) terminals = ndarray_to_tensor(batch.terminals, device=self._device).float() + weight = ndarray_to_tensor(weight, device=self._device) + with torch.no_grad(): if self._double: self._policy.exploit() @@ -91,7 +96,9 @@ def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() q_values = self._policy.q_values_tensor(states, actions) - return self._loss_func(q_values, target_q_values) + td_error = target_q_values - q_values + + return (td_error.pow(2) * weight).mean(), td_error @remote def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: @@ -103,7 +110,8 @@ def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: Returns: grad (torch.Tensor): The gradient of the batch. """ - return self._policy.get_gradients(self._get_batch_loss(batch)) + loss, _ = self._get_batch_loss(batch) + return self._policy.get_gradients(loss) def update_with_grad(self, grad_dict: dict) -> None: """Update the network with remotely computed gradients. @@ -114,14 +122,20 @@ def update_with_grad(self, grad_dict: dict) -> None: self._policy.train() self._policy.apply_gradients(grad_dict) - def update(self, batch: TransitionBatch) -> None: + def update(self, batch: TransitionBatch, weight: np.ndarray) -> np.ndarray: """Update the network using a batch. Args: batch (TransitionBatch): Batch. + weight (np.ndarray): Weight of each data entry. 
+ + Returns: + td_errors (np.ndarray) """ self._policy.train() - self._policy.train_step(self._get_batch_loss(batch)) + loss, td_error = self._get_batch_loss(batch, weight) + self._policy.train_step(loss) + return td_error.detach().numpy() def get_non_policy_state(self) -> dict: return { @@ -168,20 +182,18 @@ def __init__( def build(self) -> None: self._ops = cast(DQNOps, self.get_ops()) - self._replay_memory = RandomReplayMemory( + self._replay_memory = PriorityReplayMemory( capacity=self._replay_memory_capacity, state_dim=self._ops.policy_state_dim, action_dim=self._ops.policy_action_dim, - random_overwrite=self._params.random_overwrite, + alpha=self._params.alpha, + beta=self._params.beta, ) def _register_policy(self, policy: RLPolicy) -> None: assert isinstance(policy, ValueBasedPolicy) self._policy = policy - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> AbsTrainOps: return DQNOps( name=self._policy.name, @@ -191,13 +203,19 @@ def get_local_ops(self) -> AbsTrainOps: params=self._params, ) - def _get_batch(self, batch_size: int = None) -> TransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray]: + replay_memory = cast(PriorityReplayMemory, self.replay_memory) + batch = replay_memory.sample(batch_size or self._batch_size) + weight = replay_memory.get_weight() + return batch, weight def train_step(self) -> None: assert isinstance(self._ops, DQNOps) + replay_memory = cast(PriorityReplayMemory, self.replay_memory) for _ in range(self._params.num_epochs): - self._ops.update(self._get_batch()) + batch, weight = self._get_batch() + td_error = self._ops.update(batch, weight) + replay_memory.update_weight(td_error) self._try_soft_update_target() diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index 7daf99c7d..54a6c4cbd 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -272,9 +272,6 @@ async def train_step_as_task(self) -> None: if early_stop: break - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> SoftActorCriticOps: return SoftActorCriticOps( name=self._policy.name, diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py index da1e7d692..e5f187889 100644 --- a/maro/rl/training/replay_memory.py +++ b/maro/rl/training/replay_memory.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from abc import ABCMeta, abstractmethod -from typing import List +from typing import List, Optional import numpy as np @@ -88,6 +88,74 @@ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: return np.random.choice(self._size, size=batch_size, replace=True) +class PriorityReplayIndexScheduler(AbsIndexScheduler): + """ + Indexer for priority replay memory: https://arxiv.org/abs/1511.05952. + + Args: + capacity (int): Maximum capacity of the replay memory. + alpha (float): Alpha (see original paper for explanation). + beta (float): Alpha (see original paper for explanation). 
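+
+    Note:
+        ``alpha`` is the priority exponent (``alpha = 0`` recovers uniform sampling) and ``beta``
+        is the exponent of the importance-sampling correction applied to the weights of the
+        sampled transitions.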
+ """ + def __init__( + self, + capacity: int, + alpha: float, + beta: float, + ) -> None: + super(PriorityReplayIndexScheduler, self).__init__(capacity) + self._alpha = alpha + self._beta = beta + self._max_prio = self._min_prio = 1.0 + self._weights = np.zeros(capacity, dtype=np.float32) + + self._ptr = self._size = 0 + + self._last_sample_indexes: Optional[np.ndarray] = None + + def init_weights(self, indexes: np.ndarray) -> None: + self._weights[indexes] = self._max_prio ** self._alpha + + def get_weight(self) -> np.ndarray: + # important sampling weight calculation + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + return (self._weights[self._last_sample_indexes] / self._min_prio) ** (-self._beta) + + def update_weight(self, weight: np.ndarray) -> None: + weight = np.abs(weight) + np.finfo(np.float32).eps.item() + self._weights[self._last_sample_indexes] = weight ** self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) + + def get_put_indexes(self, batch_size: int) -> np.ndarray: + if self._ptr + batch_size <= self._capacity: + indexes = np.arange(self._ptr, self._ptr + batch_size) + self._ptr += batch_size + else: + overwrites = self._ptr + batch_size - self._capacity + indexes = np.concatenate( + [ + np.arange(self._ptr, self._capacity), + np.arange(overwrites), + ], + ) + self._ptr = overwrites + + self._size = min(self._size + batch_size, self._capacity) + self.init_weights(indexes) + return indexes + + def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: + assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}" + assert self._size > 0, "Cannot sample from an empty memory." + + weights = self._weights[:self._size] + weights = weights / weights.sum() + self._last_sample_indexes = np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=False) + return self._last_sample_indexes + + class FIFOIndexScheduler(AbsIndexScheduler): """First-in-first-out index scheduler. 
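The importance-sampling simplification in ``get_weight`` above holds because the ``p_sum * N`` factor cancels between the per-sample weight and the normalizing term. A quick NumPy check with illustrative toy priorities (values chosen only for demonstration, not taken from the code above):

    import numpy as np

    beta = 0.6
    priorities = np.array([0.5, 1.0, 2.0, 0.1, 4.0, 0.8])  # toy per-slot priorities
    sampled = np.array([1, 4, 3])  # indexes as returned by get_sample_indexes

    p_sum, p_min, n = priorities.sum(), priorities.min(), len(priorities)
    full = ((priorities[sampled] / p_sum * n) ** -beta) / ((p_min / p_sum * n) ** -beta)
    simplified = (priorities[sampled] / p_min) ** -beta
    assert np.allclose(full, simplified)  # the p_sum * N factor cancels out

    # The resulting weights scale the squared TD-errors, as in DQNOps._get_batch_loss:
    td_error = np.array([0.3, -1.2, 0.05])
    loss = (td_error ** 2 * simplified).mean()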
@@ -306,6 +374,31 @@ def random_overwrite(self) -> bool: return self._random_overwrite +class PriorityReplayMemory(ReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dim: int, + alpha: float, + beta: float, + ) -> None: + super(PriorityReplayMemory, self).__init__( + capacity, + state_dim, + action_dim, + PriorityReplayIndexScheduler(capacity, alpha, beta), + ) + + def get_weight(self) -> np.ndarray: + assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) + return self._idx_scheduler.get_weight() + + def update_weight(self, weight: np.ndarray) -> None: + assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) + self._idx_scheduler.update_weight(weight) + + class FIFOReplayMemory(ReplayMemory): def __init__( self, diff --git a/maro/rl/training/trainer.py b/maro/rl/training/trainer.py index 774954f6c..53bd123d8 100644 --- a/maro/rl/training/trainer.py +++ b/maro/rl/training/trainer.py @@ -271,9 +271,8 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: transition_batch = self._preprocess_batch(transition_batch) self.replay_memory.put(transition_batch) - @abstractmethod def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - raise NotImplementedError + return transition_batch def _assert_ops_exists(self) -> None: if not self.ops: From b61b300e8c559768df71b59eef8de709686c2399 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 5 May 2023 21:39:33 +0800 Subject: [PATCH 02/12] Refine logic --- examples/cim/rl/config.py | 2 +- maro/rl/training/algorithms/dqn.py | 13 +++++----- maro/rl/training/replay_memory.py | 38 ++++++++++++++---------------- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/examples/cim/rl/config.py b/examples/cim/rl/config.py index a46194900..e91e5edfc 100644 --- a/examples/cim/rl/config.py +++ b/examples/cim/rl/config.py @@ -35,4 +35,4 @@ action_num = len(action_shaping_conf["action_space"]) -algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg +algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 0029bd818..130cc9797 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -203,19 +203,20 @@ def get_local_ops(self) -> AbsTrainOps: params=self._params, ) - def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray]: + def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray, np.ndarray]: replay_memory = cast(PriorityReplayMemory, self.replay_memory) - batch = replay_memory.sample(batch_size or self._batch_size) - weight = replay_memory.get_weight() - return batch, weight + indexes = replay_memory.get_sample_indexes(batch_size or self._batch_size) + batch = replay_memory.sample_by_indexes(indexes) + weight = replay_memory.get_weight(indexes) + return batch, indexes, weight def train_step(self) -> None: assert isinstance(self._ops, DQNOps) replay_memory = cast(PriorityReplayMemory, self.replay_memory) for _ in range(self._params.num_epochs): - batch, weight = self._get_batch() + batch, indexes, weight = self._get_batch() td_error = self._ops.update(batch, weight) - replay_memory.update_weight(td_error) + replay_memory.update_weight(indexes, td_error) self._try_soft_update_target() diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py index e5f187889..e31b99636 100644 --- a/maro/rl/training/replay_memory.py +++ b/maro/rl/training/replay_memory.py 
@@ -111,20 +111,19 @@ def __init__( self._ptr = self._size = 0 - self._last_sample_indexes: Optional[np.ndarray] = None - def init_weights(self, indexes: np.ndarray) -> None: self._weights[indexes] = self._max_prio ** self._alpha - def get_weight(self) -> np.ndarray: + def get_weight(self, indexes: np.ndarray) -> np.ndarray: # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) - return (self._weights[self._last_sample_indexes] / self._min_prio) ** (-self._beta) + return (self._weights[indexes] / self._min_prio) ** (-self._beta) - def update_weight(self, weight: np.ndarray) -> None: + def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None: + assert indexes.shape == weight.shape weight = np.abs(weight) + np.finfo(np.float32).eps.item() - self._weights[self._last_sample_indexes] = weight ** self._alpha + self._weights[indexes] = weight ** self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) @@ -152,8 +151,7 @@ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: weights = self._weights[:self._size] weights = weights / weights.sum() - self._last_sample_indexes = np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=False) - return self._last_sample_indexes + return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=False) class FIFOIndexScheduler(AbsIndexScheduler): @@ -222,11 +220,11 @@ def capacity(self) -> int: def state_dim(self) -> int: return self._state_dim - def _get_put_indexes(self, batch_size: int) -> np.ndarray: + def get_put_indexes(self, batch_size: int) -> np.ndarray: """Please refer to the doc string in AbsIndexScheduler.""" return self._idx_scheduler.get_put_indexes(batch_size) - def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray: + def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: """Please refer to the doc string in AbsIndexScheduler.""" return self._idx_scheduler.get_sample_indexes(batch_size) @@ -293,10 +291,10 @@ def put(self, transition_batch: TransitionBatch) -> None: if transition_batch.old_logps is not None: match_shape(transition_batch.old_logps, (batch_size,)) - self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch) + self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch) self._n_sample = min(self._n_sample + transition_batch.size, self._capacity) - def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None: + def put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None: """Store a transition batch into the memory at the give indexes. Args: @@ -326,7 +324,7 @@ def sample(self, batch_size: int = None) -> TransitionBatch: Returns: batch (TransitionBatch): The sampled batch. 
""" - indexes = self._get_sample_indexes(batch_size) + indexes = self.get_sample_indexes(batch_size) return self.sample_by_indexes(indexes) def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch: @@ -390,13 +388,13 @@ def __init__( PriorityReplayIndexScheduler(capacity, alpha, beta), ) - def get_weight(self) -> np.ndarray: + def get_weight(self, indexes: np.ndarray) -> np.ndarray: assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) - return self._idx_scheduler.get_weight() + return self._idx_scheduler.get_weight(indexes) - def update_weight(self, weight: np.ndarray) -> None: + def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None: assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) - self._idx_scheduler.update_weight(weight) + self._idx_scheduler.update_weight(indexes, weight) class FIFOReplayMemory(ReplayMemory): @@ -486,9 +484,9 @@ def put(self, transition_batch: MultiTransitionBatch) -> None: assert match_shape(transition_batch.agent_states[i], (batch_size, self._agent_states_dims[i])) assert match_shape(transition_batch.next_agent_states[i], (batch_size, self._agent_states_dims[i])) - self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch=transition_batch) + self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch=transition_batch) - def _put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None: + def put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None: """Store a transition batch into the memory at the give indexes. Args: @@ -517,7 +515,7 @@ def sample(self, batch_size: int = None) -> MultiTransitionBatch: Returns: batch (MultiTransitionBatch): The sampled batch. """ - indexes = self._get_sample_indexes(batch_size) + indexes = self.get_sample_indexes(batch_size) return self.sample_by_indexes(indexes) def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch: From d6f775db5ffcc49ed5f36ed54b20836b86008b44 Mon Sep 17 00:00:00 2001 From: Jinyu Wang Date: Wed, 10 May 2023 08:31:13 +0000 Subject: [PATCH 03/12] fix DiscreteRLPolicy set_state bug --- maro/rl/policy/discrete_rl_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index c5d67b7e4..50e6d8943 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -244,7 +244,7 @@ def get_state(self) -> dict: } def set_state(self, policy_state: dict) -> None: - self._q_net.set_state(policy_state) + self._q_net.set_state(policy_state["net"]) self._warmup = policy_state["policy"]["warmup"] self._call_count = policy_state["policy"]["call_count"] From 133fb95603338a6a740f64acf4f602226c1cfebc Mon Sep 17 00:00:00 2001 From: Default Date: Wed, 10 May 2023 20:58:32 +0800 Subject: [PATCH 04/12] Test DQN on GYM passed --- maro/rl/policy/discrete_rl_policy.py | 3 - maro/rl/training/__init__.py | 10 +- maro/rl/training/algorithms/dqn.py | 61 +++++++++---- maro/rl/training/replay_memory.py | 13 +-- tests/rl/gym_wrapper/common.py | 22 ++++- tests/rl/gym_wrapper/env_sampler.py | 9 +- tests/rl/tasks/ac/__init__.py | 3 + tests/rl/tasks/ddpg/__init__.py | 3 + tests/rl/tasks/dqn/__init__.py | 132 +++++++++++++++++++++++++++ tests/rl/tasks/dqn/config.yml | 32 +++++++ tests/rl/tasks/ppo/__init__.py | 3 + tests/rl/tasks/sac/__init__.py | 3 + 12 files changed, 257 insertions(+), 37 deletions(-) create mode 100644 tests/rl/tasks/dqn/__init__.py create 
mode 100644 tests/rl/tasks/dqn/config.yml diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index c5d67b7e4..759239db8 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -176,9 +176,6 @@ def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) assert match_shape(q_values, (states.shape[0],)) # [B] return q_values - def explore(self) -> None: - pass # Overwrite the base method and turn off explore mode. - def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor: return self._get_actions_with_probs_impl(states, **kwargs)[0] diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py index b3dfd8c61..0e4488915 100644 --- a/maro/rl/training/__init__.py +++ b/maro/rl/training/__init__.py @@ -2,7 +2,13 @@ # Licensed under the MIT license. from .proxy import TrainingProxy -from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, PriorityReplayMemory, RandomMultiReplayMemory, RandomReplayMemory +from .replay_memory import ( + FIFOMultiReplayMemory, + FIFOReplayMemory, + PrioritizedReplayMemory, + RandomMultiReplayMemory, + RandomReplayMemory, +) from .train_ops import AbsTrainOps, RemoteOps, remote from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer from .training_manager import TrainingManager @@ -12,7 +18,7 @@ "TrainingProxy", "FIFOMultiReplayMemory", "FIFOReplayMemory", - "PriorityReplayMemory", + "PrioritizedReplayMemory", "RandomMultiReplayMemory", "RandomReplayMemory", "AbsTrainOps", diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 130cc9797..a12ba4a11 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -8,7 +8,15 @@ import torch from maro.rl.policy import RLPolicy, ValueBasedPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, PriorityReplayMemory, RemoteOps, SingleAgentTrainer, remote +from maro.rl.training import ( + AbsTrainOps, + BaseTrainerParams, + PrioritizedReplayMemory, + RandomReplayMemory, + RemoteOps, + SingleAgentTrainer, + remote, +) from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @@ -16,6 +24,9 @@ @dataclass class DQNParams(BaseTrainerParams): """ + use_prioritized_replay (bool, default=False): Whether to use prioritized replay memory. + alpha (float, default=0.4): Alpha in prioritized replay memory. + beta (float, default=0.6): Beta in prioritized replay memory. num_epochs (int, default=1): Number of training epochs. update_target_every (int, default=5): Number of gradient steps between target model updates. soft_update_coef (float, default=0.1): Soft update coefficient, e.g., @@ -28,8 +39,9 @@ class DQNParams(BaseTrainerParams): sequentially with wrap-around. 
""" - alpha: float - beta: float + use_prioritized_replay: bool = False + alpha: float = 0.4 + beta: float = 0.6 num_epochs: int = 1 update_target_every: int = 5 soft_update_coef: float = 0.1 @@ -90,9 +102,7 @@ def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[t actions_by_eval_policy = self._policy.get_actions_tensor(next_states) next_q_values = self._target_policy.q_values_tensor(next_states, actions_by_eval_policy) else: - self._target_policy.exploit() - actions = self._target_policy.get_actions_tensor(next_states) - next_q_values = self._target_policy.q_values_tensor(next_states, actions) + next_q_values = self._target_policy.q_values_for_all_actions_tensor(next_states).max(dim=1)[0] target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() q_values = self._policy.q_values_tensor(states, actions) @@ -182,13 +192,22 @@ def __init__( def build(self) -> None: self._ops = cast(DQNOps, self.get_ops()) - self._replay_memory = PriorityReplayMemory( - capacity=self._replay_memory_capacity, - state_dim=self._ops.policy_state_dim, - action_dim=self._ops.policy_action_dim, - alpha=self._params.alpha, - beta=self._params.beta, - ) + + if self._params.use_prioritized_replay: + self._replay_memory = PrioritizedReplayMemory( + capacity=self._replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + alpha=self._params.alpha, + beta=self._params.beta, + ) + else: + self._replay_memory = RandomReplayMemory( + capacity=self._replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + random_overwrite=False, + ) def _register_policy(self, policy: RLPolicy) -> None: assert isinstance(policy, ValueBasedPolicy) @@ -204,19 +223,23 @@ def get_local_ops(self) -> AbsTrainOps: ) def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray, np.ndarray]: - replay_memory = cast(PriorityReplayMemory, self.replay_memory) - indexes = replay_memory.get_sample_indexes(batch_size or self._batch_size) - batch = replay_memory.sample_by_indexes(indexes) - weight = replay_memory.get_weight(indexes) + indexes = self.replay_memory.get_sample_indexes(batch_size or self._batch_size) + batch = self.replay_memory.sample_by_indexes(indexes) + + if self._params.use_prioritized_replay: + weight = cast(PrioritizedReplayMemory, self.replay_memory).get_weight(indexes) + else: + weight = np.ones(len(indexes)) + return batch, indexes, weight def train_step(self) -> None: assert isinstance(self._ops, DQNOps) - replay_memory = cast(PriorityReplayMemory, self.replay_memory) for _ in range(self._params.num_epochs): batch, indexes, weight = self._get_batch() td_error = self._ops.update(batch, weight) - replay_memory.update_weight(indexes, td_error) + if self._params.use_prioritized_replay: + cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error) self._try_soft_update_target() diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py index e31b99636..59a380cd6 100644 --- a/maro/rl/training/replay_memory.py +++ b/maro/rl/training/replay_memory.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from abc import ABCMeta, abstractmethod -from typing import List, Optional +from typing import List import numpy as np @@ -97,6 +97,7 @@ class PriorityReplayIndexScheduler(AbsIndexScheduler): alpha (float): Alpha (see original paper for explanation). beta (float): Alpha (see original paper for explanation). 
""" + def __init__( self, capacity: int, @@ -112,7 +113,7 @@ def __init__( self._ptr = self._size = 0 def init_weights(self, indexes: np.ndarray) -> None: - self._weights[indexes] = self._max_prio ** self._alpha + self._weights[indexes] = self._max_prio**self._alpha def get_weight(self, indexes: np.ndarray) -> np.ndarray: # important sampling weight calculation @@ -123,7 +124,7 @@ def get_weight(self, indexes: np.ndarray) -> np.ndarray: def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None: assert indexes.shape == weight.shape weight = np.abs(weight) + np.finfo(np.float32).eps.item() - self._weights[indexes] = weight ** self._alpha + self._weights[indexes] = weight**self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) @@ -149,7 +150,7 @@ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}" assert self._size > 0, "Cannot sample from an empty memory." - weights = self._weights[:self._size] + weights = self._weights[: self._size] weights = weights / weights.sum() return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=False) @@ -372,7 +373,7 @@ def random_overwrite(self) -> bool: return self._random_overwrite -class PriorityReplayMemory(ReplayMemory): +class PrioritizedReplayMemory(ReplayMemory): def __init__( self, capacity: int, @@ -381,7 +382,7 @@ def __init__( alpha: float, beta: float, ) -> None: - super(PriorityReplayMemory, self).__init__( + super(PrioritizedReplayMemory, self).__init__( capacity, state_dim, action_dim, diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py index 538a5f996..a6a39445b 100644 --- a/tests/rl/gym_wrapper/common.py +++ b/tests/rl/gym_wrapper/common.py @@ -3,12 +3,14 @@ from typing import cast +from gym import spaces + from maro.simulator import Env from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine env_conf = { - "topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4 + "topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4 "start_tick": 0, "durations": 100000, # Set a very large number "options": {}, @@ -19,8 +21,18 @@ num_agents = len(learn_env.agent_idx_list) gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env -gym_action_space = gym_env.action_space gym_state_dim = gym_env.observation_space.shape[0] -gym_action_dim = gym_action_space.shape[0] -action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high -action_limit = gym_action_space.high[0] +gym_action_space = gym_env.action_space +is_discrete = isinstance(gym_action_space, spaces.Discrete) +if is_discrete: + gym_action_space = cast(spaces.Discrete, gym_action_space) + gym_action_dim = 1 + gym_action_num = gym_action_space.n + action_lower_bound, action_upper_bound = None, None # Should never be used + action_limit = None # Should never be used +else: + gym_action_space = cast(spaces.Box, gym_action_space) + gym_action_dim = gym_action_space.shape[0] + gym_action_num = -1 # Should never be used + action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high + action_limit = action_upper_bound[0] diff --git a/tests/rl/gym_wrapper/env_sampler.py b/tests/rl/gym_wrapper/env_sampler.py index f95aaa546..e740bafdb 100644 --- a/tests/rl/gym_wrapper/env_sampler.py +++ b/tests/rl/gym_wrapper/env_sampler.py @@ -1,9 +1,10 @@ # 
Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Dict, List, Tuple, Type, Union +from typing import Any, Dict, List, Tuple, Type, Union, cast import numpy as np +from gym import spaces from maro.rl.policy.abs_policy import AbsPolicy from maro.rl.rollout import AbsEnvSampler, CacheElement @@ -40,6 +41,10 @@ def __init__( self._sample_rewards = [] self._eval_rewards = [] + gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env + gym_action_space = gym_env.action_space + self._is_discrete = isinstance(gym_action_space, spaces.Discrete) + def _get_global_and_agent_state_impl( self, event: DecisionEvent, @@ -48,7 +53,7 @@ def _get_global_and_agent_state_impl( return None, {0: event.state} def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict: - return {k: Action(v) for k, v in action_dict.items()} + return {k: Action(v.item() if self._is_discrete else v) for k, v in action_dict.items()} def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]: be = self._env.business_engine diff --git a/tests/rl/tasks/ac/__init__.py b/tests/rl/tasks/ac/__init__.py index 24cc961fc..d9a73ecf3 100644 --- a/tests/rl/tasks/ac/__init__.py +++ b/tests/rl/tasks/ac/__init__.py @@ -19,6 +19,7 @@ action_upper_bound, gym_action_dim, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -109,6 +110,8 @@ def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer: ) +assert not is_discrete + algorithm = "ac" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/ddpg/__init__.py b/tests/rl/tasks/ddpg/__init__.py index 861904a43..cd097ddbc 100644 --- a/tests/rl/tasks/ddpg/__init__.py +++ b/tests/rl/tasks/ddpg/__init__.py @@ -20,6 +20,7 @@ gym_action_dim, gym_action_space, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -123,6 +124,8 @@ def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer: ) +assert not is_discrete + algorithm = "ddpg" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/dqn/__init__.py b/tests/rl/tasks/dqn/__init__.py new file mode 100644 index 000000000..bc8185110 --- /dev/null +++ b/tests/rl/tasks/dqn/__init__.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import numpy as np +import torch +from torch.optim import Adam + +from maro.rl.model import DiscreteQNet, FullyConnected +from maro.rl.policy import ValueBasedPolicy +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.training.algorithms import DQNParams, DQNTrainer + +from tests.rl.gym_wrapper.common import gym_action_num, gym_state_dim, is_discrete, learn_env, num_agents, test_env +from tests.rl.gym_wrapper.env_sampler import GymEnvSampler + +net_conf = { + "hidden_dims": [256], + "activation": torch.nn.ReLU, + "output_activation": None, +} +lr = 1e-3 + + +class LinearExplore: + def __init__(self) -> None: + self._call_count = 0 + + def explore_func( + self, + state: np.ndarray, + action: np.ndarray, + num_actions: int, + *, + explore_steps: int, + start_explore_prob: float, + end_explore_prob: float, + ) -> np.ndarray: + ratio = min(self._call_count / explore_steps, 1.0) + epsilon = start_explore_prob + (end_explore_prob - start_explore_prob) * ratio + explore_flag = np.random.random() < epsilon + action = np.array([np.random.randint(num_actions) if explore_flag else act for act in action]) + + self._call_count += 1 + return action + + +linear_explore = LinearExplore() + + +class MyQNet(DiscreteQNet): + def __init__(self, state_dim: int, action_num: int) -> None: + super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num) + + self._mlp = FullyConnected( + input_dim=state_dim, + output_dim=action_num, + **net_conf, + ) + self._optim = Adam(self._mlp.parameters(), lr=lr) + + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + return self._mlp(states) + + +def get_dqn_policy( + name: str, + state_dim: int, + action_num: int, +) -> ValueBasedPolicy: + return ValueBasedPolicy( + name=name, + q_net=MyQNet(state_dim=state_dim, action_num=action_num), + exploration_strategy=( + linear_explore.explore_func, + { + "explore_steps": 10000, + "start_explore_prob": 1.0, + "end_explore_prob": 0.02, + }, + ), + warmup=0, # TODO: check this + ) + + +def get_dqn_trainer( + name: str, +) -> DQNTrainer: + return DQNTrainer( + name=name, + params=DQNParams( + use_prioritized_replay=False, # + # alpha=0.4, + # beta=0.6, + num_epochs=50, + update_target_every=10, + soft_update_coef=1.0, + ), + replay_memory_capacity=50000, + batch_size=64, + reward_discount=1.0, + ) + + +assert is_discrete + +algorithm = "dqn" +agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} +policies = [ + get_dqn_policy( + f"{algorithm}_{i}.policy", + state_dim=gym_state_dim, + action_num=gym_action_num, + ) + for i in range(num_agents) +] +trainers = [get_dqn_trainer(f"{algorithm}_{i}") for i in range(num_agents)] + +device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)} if torch.cuda.is_available() else None + +rl_component_bundle = RLComponentBundle( + env_sampler=GymEnvSampler( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + ), + agent2policy=agent2policy, + policies=policies, + trainers=trainers, + device_mapping=device_mapping, +) + +__all__ = ["rl_component_bundle"] diff --git a/tests/rl/tasks/dqn/config.yml b/tests/rl/tasks/dqn/config.yml new file mode 100644 index 000000000..aa3971127 --- /dev/null +++ b/tests/rl/tasks/dqn/config.yml @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario. 
+# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/dqn" +log_path: "tests/rl/log/dqn_cartpole" +main: + num_episodes: 3000 + num_steps: 50 + eval_schedule: 50 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG +training: + mode: simple + load_path: null + load_episode: null + checkpointing: + path: null + interval: 5 + logging: + stdout: INFO + file: DEBUG diff --git a/tests/rl/tasks/ppo/__init__.py b/tests/rl/tasks/ppo/__init__.py index 15fc71069..722fce328 100644 --- a/tests/rl/tasks/ppo/__init__.py +++ b/tests/rl/tasks/ppo/__init__.py @@ -11,6 +11,7 @@ action_upper_bound, gym_action_dim, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -36,6 +37,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer: ) +assert not is_discrete + algorithm = "ppo" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/sac/__init__.py b/tests/rl/tasks/sac/__init__.py index 1e033f12b..421ea4e96 100644 --- a/tests/rl/tasks/sac/__init__.py +++ b/tests/rl/tasks/sac/__init__.py @@ -24,6 +24,7 @@ gym_action_dim, gym_action_space, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -133,6 +134,8 @@ def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCrit ) +assert not is_discrete + algorithm = "sac" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ From b4e94a06298cf0dfd927dfc24c1fbf80c06b6c43 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 12 May 2023 10:44:30 +0800 Subject: [PATCH 05/12] Refine explore strategy --- examples/cim/rl/algorithms/dqn.py | 16 +--- examples/vm_scheduling/rl/algorithms/dqn.py | 16 +--- maro/rl/exploration/__init__.py | 5 +- maro/rl/exploration/strategies.py | 98 ++++++++++++++++++++- maro/rl/policy/discrete_rl_policy.py | 31 ++----- tests/rl/tasks/dqn/__init__.py | 13 ++- 6 files changed, 117 insertions(+), 62 deletions(-) diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py index 4e40034b8..158cf20f2 100644 --- a/examples/cim/rl/algorithms/dqn.py +++ b/examples/cim/rl/algorithms/dqn.py @@ -4,7 +4,7 @@ import torch from torch.optim import RMSprop -from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy +from maro.rl.exploration import EpsilonGreedy from maro.rl.model import DiscreteQNet, FullyConnected from maro.rl.policy import ValueBasedPolicy from maro.rl.training.algorithms import DQNParams, DQNTrainer @@ -36,19 +36,7 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli return ValueBasedPolicy( name=name, q_net=MyQNet(state_dim, action_num), - exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}), - exploration_scheduling_options=[ - ( - "epsilon", - MultiLinearExplorationScheduler, - { - "splits": [(2, 0.32)], - "initial_value": 0.4, - "last_ep": 5, - "final_value": 0.0, - }, - ), - ], + explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num), warmup=100, ) diff --git a/examples/vm_scheduling/rl/algorithms/dqn.py b/examples/vm_scheduling/rl/algorithms/dqn.py index 643d6c6d4..78be0d7bd 100644 --- a/examples/vm_scheduling/rl/algorithms/dqn.py +++ b/examples/vm_scheduling/rl/algorithms/dqn.py @@ -6,7 +6,7 @@ from torch.optim import SGD from torch.optim.lr_scheduler import 
CosineAnnealingWarmRestarts -from maro.rl.exploration import MultiLinearExplorationScheduler +from maro.rl.exploration import EpsilonGreedy from maro.rl.model import DiscreteQNet, FullyConnected from maro.rl.policy import ValueBasedPolicy from maro.rl.training.algorithms import DQNParams, DQNTrainer @@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str return ValueBasedPolicy( name=name, q_net=MyQNet(state_dim, action_num, num_features), - exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}), - exploration_scheduling_options=[ - ( - "epsilon", - MultiLinearExplorationScheduler, - { - "splits": [(100, 0.32)], - "initial_value": 0.4, - "last_ep": 400, - "final_value": 0.0, - }, - ), - ], + explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num), warmup=100, ) diff --git a/maro/rl/exploration/__init__.py b/maro/rl/exploration/__init__.py index 383cca89a..3db026099 100644 --- a/maro/rl/exploration/__init__.py +++ b/maro/rl/exploration/__init__.py @@ -2,12 +2,15 @@ # Licensed under the MIT license. from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler -from .strategies import epsilon_greedy, gaussian_noise, uniform_noise +from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration, epsilon_greedy, gaussian_noise, uniform_noise __all__ = [ "AbsExplorationScheduler", "LinearExplorationScheduler", "MultiLinearExplorationScheduler", + "ExploreStrategy", + "EpsilonGreedy", + "LinearExploration", "epsilon_greedy", "gaussian_noise", "uniform_noise", diff --git a/maro/rl/exploration/strategies.py b/maro/rl/exploration/strategies.py index c85340c78..46eb10513 100644 --- a/maro/rl/exploration/strategies.py +++ b/maro/rl/exploration/strategies.py @@ -1,11 +1,105 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - -from typing import Union +from abc import abstractmethod +from typing import Any, Union import numpy as np +class ExploreStrategy: + def __init__(self) -> None: + pass + + @abstractmethod + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + """ + Args: + state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla + eps-greedy exploration and is put here to conform to the function signature required for the exploration + strategy parameter for ``DQN``. + action (np.ndarray): Action(s) chosen greedily by the policy. + + Returns: + Exploratory actions. + """ + raise NotImplementedError + + +class EpsilonGreedy(ExploreStrategy): + """Epsilon-greedy exploration. Returns uniformly random action with probability `epsilon` or returns original + action with probability `1.0 - epsilon`. + + Args: + num_actions (int): Number of possible actions. + epsilon (float): The probability that a random action will be selected. 
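+
+    Example (illustrative; ``state`` is not used by this strategy):
+
+        >>> import numpy as np
+        >>> strategy = EpsilonGreedy(num_actions=4, epsilon=0.2)
+        >>> strategy.get_action(state=np.zeros((4, 1)), action=np.array([0, 3, 1, 2])).shape
+        (4,)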
+ """ + + def __init__(self, num_actions: int, epsilon: float) -> None: + super(EpsilonGreedy, self).__init__() + + assert 0.0 <= epsilon <= 1.0 + + self._num_actions = num_actions + self._eps = epsilon + + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return np.array( + [act if np.random.random() > self._eps else np.random.randint(self._num_actions) for act in action], + ) + + +class LinearExploration(ExploreStrategy): + """Epsilon greedy which the probability `epsilon` is linearly interpolated between `start_explore_prob` and + `end_explore_prob` over `explore_steps`. After this many timesteps pass, `epsilon` is fixed to `end_explore_prob`. + + Args: + num_actions (int): Number of possible actions. + explore_steps (int) + start_explore_prob (float) + end_explore_prob (float) + """ + + def __init__( + self, + num_actions: int, + explore_steps: int, + start_explore_prob: float, + end_explore_prob: float, + ) -> None: + super(LinearExploration, self).__init__() + + self._call_count = 0 + + self._num_actions = num_actions + self._explore_steps = explore_steps + self._start_explore_prob = start_explore_prob + self._end_explore_prob = end_explore_prob + + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + ratio = min(self._call_count / self._explore_steps, 1.0) + epsilon = self._start_explore_prob + (self._end_explore_prob - self._start_explore_prob) * ratio + explore_flag = np.random.random() < epsilon + action = np.array([np.random.randint(self._num_actions) if explore_flag else act for act in action]) + + self._call_count += 1 + return action + + def epsilon_greedy( state: np.ndarray, action: np.ndarray, diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index 759239db8..073976014 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -2,15 +2,14 @@ # Licensed under the MIT license. from abc import ABCMeta -from typing import Callable, Dict, List, Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch -from maro.rl.exploration import epsilon_greedy +from maro.rl.exploration import ExploreStrategy from maro.rl.model import DiscretePolicyNet, DiscreteQNet from maro.rl.utils import match_shape, ndarray_to_tensor -from maro.utils import clone from .abs_policy import RLPolicy @@ -69,8 +68,7 @@ class ValueBasedPolicy(DiscreteRLPolicy): name (str): Name of the policy. q_net (DiscreteQNet): Q-net used in this value-based policy. trainable (bool, default=True): Whether this policy is trainable. - exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy. - exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options. + explore_strategy (Optional[ExploreStrategy], default=None): Explore strategy. warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy. Helps exploration. 
""" @@ -80,8 +78,7 @@ def __init__( name: str, q_net: DiscreteQNet, trainable: bool = True, - exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}), - exploration_scheduling_options: List[tuple] = None, + explore_strategy: Optional[ExploreStrategy] = None, warmup: int = 50000, ) -> None: assert isinstance(q_net, DiscreteQNet) @@ -94,15 +91,7 @@ def __init__( warmup=warmup, ) self._q_net = q_net - - self._exploration_func = exploration_strategy[0] - self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing - self._exploration_schedulers = ( - [opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options] - if exploration_scheduling_options is not None - else [] - ) - + self._explore_strategy = explore_strategy self._softmax = torch.nn.Softmax(dim=1) @property @@ -184,14 +173,8 @@ def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[ q_matrix_softmax = self._softmax(q_matrix) _, actions = q_matrix.max(dim=1) # [B], [B] - if self._is_exploring: - actions = self._exploration_func( - states, - actions.cpu().numpy(), - self.action_num, - **self._exploration_params, - **kwargs, - ) + if self._is_exploring and self._explore_strategy is not None: + actions = self._explore_strategy.get_action(state=states.cpu().numpy(), action=actions.cpu().numpy()) actions = ndarray_to_tensor(actions, device=self._device) actions = actions.unsqueeze(1) diff --git a/tests/rl/tasks/dqn/__init__.py b/tests/rl/tasks/dqn/__init__.py index bc8185110..f3c7be033 100644 --- a/tests/rl/tasks/dqn/__init__.py +++ b/tests/rl/tasks/dqn/__init__.py @@ -4,6 +4,7 @@ import torch from torch.optim import Adam +from maro.rl.exploration import LinearExploration from maro.rl.model import DiscreteQNet, FullyConnected from maro.rl.policy import ValueBasedPolicy from maro.rl.rl_component.rl_component_bundle import RLComponentBundle @@ -69,13 +70,11 @@ def get_dqn_policy( return ValueBasedPolicy( name=name, q_net=MyQNet(state_dim=state_dim, action_num=action_num), - exploration_strategy=( - linear_explore.explore_func, - { - "explore_steps": 10000, - "start_explore_prob": 1.0, - "end_explore_prob": 0.02, - }, + explore_strategy=LinearExploration( + num_actions=action_num, + explore_steps=10000, + start_explore_prob=1.0, + end_explore_prob=0.02, ), warmup=0, # TODO: check this ) From 7e004709281e7618bf04d72907dfb331698d8f66 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 12 May 2023 11:38:39 +0800 Subject: [PATCH 06/12] Minor --- maro/rl/training/replay_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py index 59a380cd6..164c2580c 100644 --- a/maro/rl/training/replay_memory.py +++ b/maro/rl/training/replay_memory.py @@ -152,7 +152,7 @@ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: weights = self._weights[: self._size] weights = weights / weights.sum() - return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=False) + return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=True) class FIFOIndexScheduler(AbsIndexScheduler): From ca6b387bfd855564c2116db46fba31745f8f77f4 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 12 May 2023 13:19:37 +0800 Subject: [PATCH 07/12] Minor --- tests/rl/gym_wrapper/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py 
index a6a39445b..4a4bd12b2 100644 --- a/tests/rl/gym_wrapper/common.py +++ b/tests/rl/gym_wrapper/common.py @@ -10,7 +10,7 @@ from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine env_conf = { - "topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4 + "topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1 "start_tick": 0, "durations": 100000, # Set a very large number "options": {}, From 790a37338a9f83b2bda61b291eab9dfed86f0538 Mon Sep 17 00:00:00 2001 From: Default Date: Mon, 15 May 2023 09:59:00 +0800 Subject: [PATCH 08/12] almost ready --- examples/cim/rl/config.py | 2 +- examples/cim/rl/env_sampler.py | 52 +++++++-- examples/vm_scheduling/rl/env_sampler.py | 34 +++--- maro/rl/distributed/abs_worker.py | 2 +- maro/rl/policy/discrete_rl_policy.py | 4 +- maro/rl/rollout/__init__.py | 3 +- maro/rl/rollout/batch_env_sampler.py | 105 ++++++++++++------ maro/rl/rollout/env_sampler.py | 55 +++++++-- maro/rl/rollout/worker.py | 19 +++- .../training/algorithms/base/ac_ppo_base.py | 30 ++--- maro/rl/training/algorithms/ddpg.py | 17 +-- maro/rl/training/algorithms/dqn.py | 21 ++-- maro/rl/training/algorithms/maddpg.py | 3 +- maro/rl/training/algorithms/sac.py | 17 +-- maro/rl/training/training_manager.py | 1 - maro/rl/training/worker.py | 6 +- maro/rl/workflows/callback.py | 18 ++- maro/rl/workflows/main.py | 14 +-- .../scenarios/cim/business_engine.py | 1 + tests/rl/gym_wrapper/common.py | 2 +- tests/rl/gym_wrapper/env_sampler.py | 38 ++++++- tests/rl/tasks/ac/config_distributed.yml | 45 ++++++++ tests/rl/tasks/ddpg/config_distributed.yml | 45 ++++++++ tests/rl/tasks/dqn/config_distributed.yml | 45 ++++++++ tests/rl/tasks/ppo/config_distributed.yml | 45 ++++++++ tests/rl/tasks/sac/config_distributed.yml | 45 ++++++++ 26 files changed, 523 insertions(+), 146 deletions(-) create mode 100644 tests/rl/tasks/ac/config_distributed.yml create mode 100644 tests/rl/tasks/ddpg/config_distributed.yml create mode 100644 tests/rl/tasks/dqn/config_distributed.yml create mode 100644 tests/rl/tasks/ppo/config_distributed.yml create mode 100644 tests/rl/tasks/sac/config_distributed.yml diff --git a/examples/cim/rl/config.py b/examples/cim/rl/config.py index e91e5edfc..a46194900 100644 --- a/examples/cim/rl/config.py +++ b/examples/cim/rl/config.py @@ -35,4 +35,4 @@ action_num = len(action_shaping_conf["action_space"]) -algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg +algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg diff --git a/examples/cim/rl/env_sampler.py b/examples/cim/rl/env_sampler.py index c7cd241e4..5797dcbb1 100644 --- a/examples/cim/rl/env_sampler.py +++ b/examples/cim/rl/env_sampler.py @@ -85,30 +85,64 @@ def _post_step(self, cache_element: CacheElement) -> None: def _post_eval_step(self, cache_element: CacheElement) -> None: self._post_step(cache_element) - def post_collect(self, info_list: list, ep: int) -> None: + def post_collect(self, ep: int) -> None: + if len(self._info_list) == 0: + return + # print the env metric from each rollout worker - for info in info_list: + for info in self._info_list: print(f"env summary (episode {ep}): {info['env_metric']}") # average env metric - metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) - avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys} + metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list) + avg_metric = {key: sum(info["env_metric"][key] 
for info in self._info_list) / num_envs for key in metric_keys} + avg_metric["n_episode"] = len(self._info_list) print(f"average env summary (episode {ep}): {avg_metric}") self.metrics.update(avg_metric) self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")} + self._info_list.clear() + + def post_evaluate(self, ep: int) -> None: + if len(self._info_list) == 0: + return - def post_evaluate(self, info_list: list, ep: int) -> None: # print the env metric from each rollout worker - for info in info_list: + for info in self._info_list: print(f"env summary (episode {ep}): {info['env_metric']}") # average env metric - metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) - avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys} + metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list) + avg_metric = {key: sum(info["env_metric"][key] for info in self._info_list) / num_envs for key in metric_keys} + avg_metric["n_episode"] = len(self._info_list) print(f"average env summary (episode {ep}): {avg_metric}") self.metrics.update({"val/" + k: v for k, v in avg_metric.items()}) + self._info_list.clear() def monitor_metrics(self) -> float: - return -self.metrics["val/container_shortage"] + return -self.metrics["val/shortage_percentage"] + + @staticmethod + def merge_metrics(metrics_list: List[dict]) -> dict: + n_episode = sum(m["n_episode"] for m in metrics_list) + metrics: dict = { + "order_requirements": sum(m["order_requirements"] * m["n_episode"] for m in metrics_list) / n_episode, + "container_shortage": sum(m["container_shortage"] * m["n_episode"] for m in metrics_list) / n_episode, + "operation_number": sum(m["operation_number"] * m["n_episode"] for m in metrics_list) / n_episode, + "n_episode": n_episode, + } + metrics["shortage_percentage"] = metrics["container_shortage"] / metrics["order_requirements"] + + metrics_list = [m for m in metrics_list if "val/shortage_percentage" in m] + if len(metrics_list) > 0: + n_episode = sum(m["val/n_episode"] for m in metrics_list) + metrics.update({ + "val/order_requirements": sum(m["val/order_requirements"] * m["val/n_episode"] for m in metrics_list) / n_episode, + "val/container_shortage": sum(m["val/container_shortage"] * m["val/n_episode"] for m in metrics_list) / n_episode, + "val/operation_number": sum(m["val/operation_number"] * m["val/n_episode"] for m in metrics_list) / n_episode, + "val/n_episode": n_episode, + }) + metrics["val/shortage_percentage"] = metrics["val/container_shortage"] / metrics["val/order_requirements"] + + return metrics diff --git a/examples/vm_scheduling/rl/env_sampler.py b/examples/vm_scheduling/rl/env_sampler.py index 3fc39776e..813df3208 100644 --- a/examples/vm_scheduling/rl/env_sampler.py +++ b/examples/vm_scheduling/rl/env_sampler.py @@ -166,29 +166,35 @@ def _post_step(self, cache_element: CacheElement) -> None: def _post_eval_step(self, cache_element: CacheElement) -> None: self._post_step(cache_element) - def post_collect(self, info_list: list, ep: int) -> None: + def post_collect(self, ep: int) -> None: + if len(self._info_list) == 0: + return + # print the env metric from each rollout worker - for info in info_list: + for info in self._info_list: print(f"env summary (episode {ep}): {info['env_metric']}") # print the average env metric - if len(info_list) > 1: - metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) - avg_metric = {key: sum(tr["env_metric"][key] for tr in 
info_list) / num_envs for key in metric_keys} - print(f"average env metric (episode {ep}): {avg_metric}") + metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list) + avg_metric = {key: sum(tr["env_metric"][key] for tr in self._info_list) / num_envs for key in metric_keys} + print(f"average env metric (episode {ep}): {avg_metric}") + + self._info_list.clear() + + def post_evaluate(self, ep: int) -> None: + if len(self._info_list) == 0: + return - def post_evaluate(self, info_list: list, ep: int) -> None: # print the env metric from each rollout worker - for info in info_list: + for info in self._info_list: print(f"env summary (evaluation episode {ep}): {info['env_metric']}") # print the average env metric - if len(info_list) > 1: - metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) - avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys} - print(f"average env metric (evaluation episode {ep}): {avg_metric}") + metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list) + avg_metric = {key: sum(tr["env_metric"][key] for tr in self._info_list) / num_envs for key in metric_keys} + print(f"average env metric (evaluation episode {ep}): {avg_metric}") - for info in info_list: + for info in self._info_list: core_requirement = info["actions_by_core_requirement"] action_sequence = info["action_sequence"] # plot action sequence @@ -231,3 +237,5 @@ def post_evaluate(self, info_list: list, ep: int) -> None: plt.cla() plt.close("all") + + self._info_list.clear() diff --git a/maro/rl/distributed/abs_worker.py b/maro/rl/distributed/abs_worker.py index 7da7e9435..475be6f1a 100644 --- a/maro/rl/distributed/abs_worker.py +++ b/maro/rl/distributed/abs_worker.py @@ -34,7 +34,7 @@ def __init__( super(AbsWorker, self).__init__() self._id = f"worker.{idx}" - self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger() + self._logger: Union[LoggerV2, DummyLogger] = logger or DummyLogger() # ZMQ sockets and streams self._context = Context.instance() diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index aeb0e4767..344be00d8 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -177,7 +177,7 @@ def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[ actions = self._explore_strategy.get_action(state=states.cpu().numpy(), action=actions.cpu().numpy()) actions = ndarray_to_tensor(actions, device=self._device) - actions = actions.unsqueeze(1) + actions = actions.unsqueeze(1).long() return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1] def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: @@ -316,7 +316,7 @@ def get_state(self) -> dict: } def set_state(self, policy_state: dict) -> None: - self._policy_net.set_state(policy_state) + self._policy_net.set_state(policy_state["net"]) self._warmup = policy_state["policy"]["warmup"] self._call_count = policy_state["policy"]["call_count"] diff --git a/maro/rl/rollout/__init__.py b/maro/rl/rollout/__init__.py index 17315f615..d5ab1c083 100644 --- a/maro/rl/rollout/__init__.py +++ b/maro/rl/rollout/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
from .batch_env_sampler import BatchEnvSampler -from .env_sampler import AbsAgentWrapper, AbsEnvSampler, CacheElement, ExpElement, SimpleAgentWrapper +from .env_sampler import AbsAgentWrapper, AbsEnvSampler, CacheElement, EnvSamplerInterface, ExpElement, SimpleAgentWrapper from .worker import RolloutWorker __all__ = [ @@ -10,6 +10,7 @@ "AbsAgentWrapper", "AbsEnvSampler", "CacheElement", + "EnvSamplerInterface", "ExpElement", "SimpleAgentWrapper", "RolloutWorker", diff --git a/maro/rl/rollout/batch_env_sampler.py b/maro/rl/rollout/batch_env_sampler.py index b6b0976f0..719104eae 100644 --- a/maro/rl/rollout/batch_env_sampler.py +++ b/maro/rl/rollout/batch_env_sampler.py @@ -6,6 +6,7 @@ from itertools import chain from typing import Any, Dict, List, Optional, Union +import numpy as np import torch import zmq from zmq import Context, Poller @@ -15,7 +16,16 @@ from maro.rl.utils.objects import FILE_SUFFIX from maro.utils import DummyLogger, LoggerV2 -from .env_sampler import ExpElement +from .env_sampler import EnvSamplerInterface, ExpElement + + +def _split(total: int, k: int) -> List[int]: + """Split integer `total` into `k` groups where the sum of the `k` groups equals to `total` and all groups are + as close as possible. + """ + + p, q = total // k, total % k + return [p + 1] * q + [p] * (k - q) class ParallelTaskController(object): @@ -39,7 +49,7 @@ def __init__(self, port: int = 20000, logger: LoggerV2 = None) -> None: self._poller.register(self._task_endpoint, zmq.POLLIN) self._workers: set = set() - self._logger: Union[DummyLogger, LoggerV2] = logger if logger is not None else DummyLogger() + self._logger: Union[DummyLogger, LoggerV2] = logger or DummyLogger() def _wait_for_workers_ready(self, k: int) -> None: while len(self._workers) < k: @@ -50,7 +60,14 @@ def _recv_result_for_target_index(self, index: int) -> Any: assert isinstance(rep, dict) return rep["result"] if rep["index"] == index else None - def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: float = None) -> List[dict]: + def collect( + self, + req: dict, + parallelism: int, + min_replies: int = None, + grace_factor: float = None, + unique_params: Optional[dict] = None, + ) -> list: """Send a task request to a set of remote workers and collect the results. Args: @@ -62,18 +79,24 @@ def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_fa minimum required replies (as determined by ``min_replies``). For example, if the minimum required replies are received in T seconds, it will allow an additional T * grace_factor seconds to collect the remaining results. + unique_params (Optional[float], default=None): Unique params for each worker. Returns: A list of results. Each element in the list is a dict that contains results from a worker. """ + if unique_params is not None: + for key, params in unique_params.items(): + assert len(params) == parallelism + assert len(set(req.keys()) & set(unique_params.keys())) == 0, "Parameter overwritten is not allowed." 
+ self._wait_for_workers_ready(parallelism) - if min_replies is None: - min_replies = parallelism + min_replies = min_replies or parallelism start_time = time.time() results: list = [] - for worker_id in list(self._workers)[:parallelism]: - self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)]) + for i, worker_id in enumerate(list(self._workers)[:parallelism]): + cur_params = {key: params[i] for key, params in unique_params.items()} if unique_params is not None else {} + self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes({**req, **cur_params})]) self._logger.debug(f"Sent {parallelism} roll-out requests...") while len(results) < min_replies: @@ -104,48 +127,45 @@ def exit(self) -> None: self._context.term() -class BatchEnvSampler: +class BatchEnvSampler(EnvSamplerInterface): """Facility that samples from multiple copies of an environment in parallel. No environment is created here. Instead, it uses a ParallelTaskController to send roll-out requests to a set of remote workers and collect results from them. Args: - sampling_parallelism (int): Parallelism for sampling from the environment. - port (int): Network port that the internal ``ParallelTaskController`` uses to talk to the remote workers. + sampling_parallelism (int, default=1): Parallelism for sampling from the environment. + port (int, default=DEFAULT_ROLLOUT_PRODUCER_PORT): Network port that the internal ``ParallelTaskController`` + uses to talk to the remote workers. min_env_samples (int, default=None): The minimum number of results to collect in one round of remote sampling. If it is None, it defaults to the value of ``sampling_parallelism``. grace_factor (float, default=None): Factor that determines the additional wait time after receiving the minimum required env samples (as determined by ``min_env_samples``). For example, if the minimum required samples are received in T seconds, it will allow an additional T * grace_factor seconds to collect the remaining results. - eval_parallelism (int, default=None): Parallelism for policy evaluation on remote workers. + eval_parallelism (int, default=1): Parallelism for policy evaluation on remote workers. logger (LoggerV2, default=None): Optional logger for logging key events. 
""" def __init__( self, - sampling_parallelism: int, - port: int = None, + sampling_parallelism: int = 1, + port: int = DEFAULT_ROLLOUT_PRODUCER_PORT, min_env_samples: int = None, grace_factor: float = None, - eval_parallelism: int = None, + eval_parallelism: int = 1, logger: LoggerV2 = None, ) -> None: super(BatchEnvSampler, self).__init__() - self._logger: Union[LoggerV2, DummyLogger] = logger if logger is not None else DummyLogger() - self._controller = ParallelTaskController( - port=port if port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, - logger=logger, - ) + self._logger: Union[LoggerV2, DummyLogger] = logger or DummyLogger() + self._controller = ParallelTaskController(port=port, logger=logger) - self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism - self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism + self._sampling_parallelism = sampling_parallelism + self._min_env_samples = min_env_samples or self._sampling_parallelism self._grace_factor = grace_factor - self._eval_parallelism = 1 if eval_parallelism is None else eval_parallelism + self._eval_parallelism = eval_parallelism self._ep = 0 - self._end_of_episode = True def sample( self, @@ -165,26 +185,23 @@ def sample( A dict that contains the collected experiences and additional information. """ # increment episode depending on whether the last episode has concluded - if self._end_of_episode: - self._ep += 1 - + self._ep += 1 self._logger.info(f"Collecting roll-out data for episode {self._ep}") req = { "type": "sample", "policy_state": policy_state, - "num_steps": num_steps, "index": self._ep, } + num_steps = [None] * self._sampling_parallelism if num_steps is None else _split(num_steps, self._sampling_parallelism) results = self._controller.collect( req, self._sampling_parallelism, min_replies=self._min_env_samples, grace_factor=self._grace_factor, + unique_params={"num_steps": num_steps}, ) - self._end_of_episode = any(res["end_of_episode"] for res in results) merged_experiences: List[List[ExpElement]] = list(chain(*[res["experiences"] for res in results])) return { - "end_of_episode": self._end_of_episode, "experiences": merged_experiences, "info": [res["info"][0] for res in results], } @@ -194,9 +211,12 @@ def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int "type": "eval", "policy_state": policy_state, "index": self._ep, - "num_eval_episodes": num_episodes, } # -1 signals test - results = self._controller.collect(req, self._eval_parallelism) + results = self._controller.collect( + req, + self._eval_parallelism, + unique_params={"num_eval_episodes": _split(num_episodes, self._eval_parallelism)}, + ) return { "info": [res["info"][0] for res in results], } @@ -223,10 +243,21 @@ def load_policy_state(self, path: str) -> List[str]: def exit(self) -> None: self._controller.exit() - def post_collect(self, info_list: list, ep: int) -> None: - req = {"type": "post_collect", "info_list": info_list, "index": ep} - self._controller.collect(req, 1) + def post_collect(self, ep: int) -> None: + req = {"type": "post_collect", "index": ep} + self._controller.collect(req, self._sampling_parallelism) + + def post_evaluate(self, ep: int) -> None: + req = {"type": "post_evaluate", "index": ep} + self._controller.collect(req, self._eval_parallelism) + + def monitor_metrics(self) -> float: + req = {"type": "monitor_metrics", "index": self._ep} + return float(np.mean(self._controller.collect(req, self._sampling_parallelism))) - def 
post_evaluate(self, info_list: list, ep: int) -> None: - req = {"type": "post_evaluate", "info_list": info_list, "index": ep} - self._controller.collect(req, 1) + def get_metrics(self) -> dict: + req = {"type": "get_metrics", "index": self._ep} + metrics_list = self._controller.collect(req, self._sampling_parallelism) + req = {"type": "merge_metrics", "metrics_list": metrics_list, "index": self._ep} + metrics = self._controller.collect(req, 1)[0] + return metrics diff --git a/maro/rl/rollout/env_sampler.py b/maro/rl/rollout/env_sampler.py index 2825ef4aa..41118fb9d 100644 --- a/maro/rl/rollout/env_sampler.py +++ b/maro/rl/rollout/env_sampler.py @@ -235,7 +235,39 @@ def make_exp_element(self) -> ExpElement: ) -class AbsEnvSampler(object, metaclass=ABCMeta): +class EnvSamplerInterface(object, metaclass=ABCMeta): + @abstractmethod + def monitor_metrics(self) -> float: + raise NotImplementedError + + @abstractmethod + def load_policy_state(self, path: str) -> List[str]: + raise NotImplementedError + + @abstractmethod + def sample( + self, + policy_state: Optional[Dict[str, Dict[str, Any]]] = None, + num_steps: Optional[int] = None, + ) -> dict: + raise NotImplementedError + + @abstractmethod + def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict: + raise NotImplementedError + + def post_collect(self, ep: int) -> None: + """Routines to be invoked at the end of training episodes""" + + def post_evaluate(self, ep: int) -> None: + """Routines to be invoked at the end of evaluation episodes""" + + @abstractmethod + def get_metrics(self) -> dict: + raise NotImplementedError + + +class AbsEnvSampler(EnvSamplerInterface, metaclass=ABCMeta): """Simulation data collector and policy evaluator. Args: @@ -262,6 +294,8 @@ def __init__( reward_eval_delay: int = None, max_episode_length: int = None, ) -> None: + super(AbsEnvSampler, self).__init__() + assert learn_env is not test_env, "Please use different envs for training and testing." self._learn_env = learn_env @@ -281,6 +315,7 @@ def __init__( self._current_episode_length = 0 self._info: dict = {} + self._info_list: List[dict] = [] # Will NOT be cleared in `reset()`. self.metrics: dict = {} assert self._reward_eval_delay is None or self._reward_eval_delay >= 0 @@ -531,9 +566,11 @@ def sample( total_experiences += experiences + info = deepcopy(self._info) # TODO: may have overhead issues. Leave to future work. + self._info_list.append(info) return { "experiences": [total_experiences], - "info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work. 
+ "info": [info], } def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None: @@ -560,7 +597,6 @@ def load_policy_state(self, path: str) -> List[str]: def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict: self._switch_env(self._test_env) - info_list = [] for _ in range(num_episodes): self._reset() @@ -606,9 +642,9 @@ def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int self._calc_reward(cache_element) self._post_eval_step(cache_element) - info_list.append(self._info) + self._info_list.append(self._info) - return {"info": info_list} + return {"info": self._info_list} @abstractmethod def _post_step(self, cache_element: CacheElement) -> None: @@ -618,8 +654,9 @@ def _post_step(self, cache_element: CacheElement) -> None: def _post_eval_step(self, cache_element: CacheElement) -> None: raise NotImplementedError - def post_collect(self, info_list: list, ep: int) -> None: - """Routines to be invoked at the end of training episodes""" + def get_metrics(self) -> dict: + return self.metrics - def post_evaluate(self, info_list: list, ep: int) -> None: - """Routines to be invoked at the end of evaluation episodes""" + @staticmethod + def merge_metrics(metrics_list: List[dict]) -> dict: + return metrics_list[0] diff --git a/maro/rl/rollout/worker.py b/maro/rl/rollout/worker.py index 1532a6489..124efd438 100644 --- a/maro/rl/rollout/worker.py +++ b/maro/rl/rollout/worker.py @@ -39,7 +39,7 @@ def __init__( producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, logger=logger, ) - self._env_sampler = rl_component_bundle.env_sampler + self._env_sampler = rl_component_bundle.env_sampler # TODO: deep copy? def _compute(self, msg: list) -> None: """Perform a full or partial episode of roll-out for sampling or evaluation. 
@@ -53,7 +53,6 @@ def _compute(self, msg: list) -> None: else: req = bytes_to_pyobj(msg[-1]) assert isinstance(req, dict) - assert req["type"] in {"sample", "eval", "set_policy_state", "post_collect", "post_evaluate"} if req["type"] in ("sample", "eval"): result = ( @@ -61,12 +60,20 @@ def _compute(self, msg: list) -> None: if req["type"] == "sample" else self._env_sampler.eval(policy_state=req["policy_state"], num_episodes=req["num_eval_episodes"]) ) - self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]})) + elif req["type"] == "monitor_metrics": + result = self._env_sampler.monitor_metrics() + elif req["type"] == "get_metrics": + result = self._env_sampler.get_metrics() + elif req["type"] == "merge_metrics": + result = self._env_sampler.merge_metrics(metrics_list=req["metrics_list"]) else: if req["type"] == "set_policy_state": self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"]) elif req["type"] == "post_collect": - self._env_sampler.post_collect(info_list=req["info_list"], ep=req["index"]) + self._env_sampler.post_collect(ep=req["index"]) + elif req["type"] == "post_evaluate": + self._env_sampler.post_evaluate(ep=req["index"]) else: - self._env_sampler.post_evaluate(info_list=req["info_list"], ep=req["index"]) - self._stream.send(pyobj_to_bytes({"result": True, "index": req["index"]})) + raise ValueError(f"Invalid remote function call: {req['type']}") + result = True + self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]})) diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py index 5903d2ec7..8f990257a 100644 --- a/maro/rl/training/algorithms/base/ac_ppo_base.py +++ b/maro/rl/training/algorithms/base/ac_ppo_base.py @@ -175,18 +175,14 @@ def update_actor(self, batch: TransitionBatch) -> Tuple[float, bool]: self._policy.train_step(loss) return loss.detach().cpu().numpy().item(), early_stop - def update_actor_with_grad(self, grad_dict_and_early_stop: Tuple[dict, bool]) -> bool: + def update_actor_with_grad(self, grad_dict: dict) -> None: """Update the actor network with remotely computed gradients. Args: - grad_dict_and_early_stop (Tuple[dict, bool]): Gradients and early stop indicator. - - Returns: - early stop indicator + grad_dict (dict): Gradients. 
""" self._policy.train() - self._policy.apply_gradients(grad_dict_and_early_stop[0]) - return grad_dict_and_early_stop[1] + self._policy.apply_gradients(grad_dict) def get_non_policy_state(self) -> dict: return { @@ -322,28 +318,22 @@ def _get_batch(self) -> TransitionBatch: return batch def train_step(self) -> None: - assert isinstance(self._ops, ACBasedOps) - + ops = cast(ACBasedOps, self._ops) batch = self._get_batch() - for _ in range(self._params.grad_iters): - early_stop = self._ops.update_actor(batch) + _, early_stop = ops.update_actor(batch) if early_stop: break - for _ in range(self._params.grad_iters): - self._ops.update_critic(batch) + ops.update_critic(batch) async def train_step_as_task(self) -> None: - assert isinstance(self._ops, RemoteOps) - + ops = cast(ACBasedOps, self._ops) batch = self._get_batch() - for _ in range(self._params.grad_iters): - grad_dict, early_stop = await self._ops.get_actor_grad(batch) - self._ops.update_actor_with_grad(grad_dict) + grad_dict, early_stop = await ops.get_actor_grad(batch) + ops.update_actor_with_grad(grad_dict) if early_stop: break - for _ in range(self._params.grad_iters): - self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) + ops.update_critic_with_grad(await ops.get_critic_grad(batch)) diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index bf7b0f8d4..f3c89cced 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -274,7 +274,7 @@ def _get_batch(self, batch_size: int = None) -> TransitionBatch: return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) def train_step(self) -> None: - assert isinstance(self._ops, DDPGOps) + ops = cast(DDPGOps, self._ops) if self._replay_memory.n_sample < self._params.n_start_train: print( @@ -285,13 +285,15 @@ def train_step(self) -> None: for _ in range(self._params.num_epochs): batch = self._get_batch() - self._ops.update_critic(batch) - self._ops.update_actor(batch) + ops.update_critic(batch) + _, early_stop = ops.update_actor(batch) self._try_soft_update_target() + if early_stop: + break async def train_step_as_task(self) -> None: - assert isinstance(self._ops, RemoteOps) + ops = cast(DDPGOps, self._ops) if self._replay_memory.n_sample < self._params.n_start_train: print( @@ -302,9 +304,10 @@ async def train_step_as_task(self) -> None: for _ in range(self._params.num_epochs): batch = self._get_batch() - self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) - grad_dict, early_stop = await self._ops.get_actor_grad(batch) - self._ops.update_actor_with_grad(grad_dict) + ops.update_critic_with_grad(await ops.get_critic_grad(batch)) + grad_dict, early_stop = await ops.get_actor_grad(batch) + ops.update_actor_with_grad(grad_dict) + self._try_soft_update_target() if early_stop: break diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index a12ba4a11..80528b3cd 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -111,17 +111,19 @@ def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[t return (td_error.pow(2) * weight).mean(), td_error @remote - def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + def get_batch_grad(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[Dict[str, torch.Tensor], np.ndarray]: """Compute the network's gradients of a batch. Args: batch (TransitionBatch): Batch. 
+ weight (np.ndarray): Weight of each data entry. Returns: grad (torch.Tensor): The gradient of the batch. + td_error (np.ndarray): TD error. """ - loss, _ = self._get_batch_loss(batch) - return self._policy.get_gradients(loss) + loss, td_error = self._get_batch_loss(batch, weight) + return self._policy.get_gradients(loss), td_error.detach().numpy() def update_with_grad(self, grad_dict: dict) -> None: """Update the network with remotely computed gradients. @@ -234,20 +236,23 @@ def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarra return batch, indexes, weight def train_step(self) -> None: - assert isinstance(self._ops, DQNOps) + ops = cast(DQNOps, self._ops) for _ in range(self._params.num_epochs): batch, indexes, weight = self._get_batch() - td_error = self._ops.update(batch, weight) + td_error = ops.update(batch, weight) if self._params.use_prioritized_replay: cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error) self._try_soft_update_target() async def train_step_as_task(self) -> None: - assert isinstance(self._ops, RemoteOps) + ops = cast(DQNOps, self._ops) for _ in range(self._params.num_epochs): - batch = self._get_batch() - self._ops.update_with_grad(await self._ops.get_batch_grad(batch)) + batch, indexes, weight = self._get_batch() + grad, td_error = await ops.get_batch_grad(batch, weight) + ops.update_with_grad(grad) + if self._params.use_prioritized_replay: + cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error) self._try_soft_update_target() diff --git a/maro/rl/training/algorithms/maddpg.py b/maro/rl/training/algorithms/maddpg.py index d17f91e14..59af9ced7 100644 --- a/maro/rl/training/algorithms/maddpg.py +++ b/maro/rl/training/algorithms/maddpg.py @@ -478,7 +478,8 @@ async def train_step_as_task(self) -> None: ops.update_critic_with_grad(critic_grad) # Update actors - actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch)[0] for ops in self._actor_ops_list]) + return_list = await asyncio.gather(*[ops.get_actor_grad(batch) for ops in self._actor_ops_list]) + actor_grad_list = [e[0] for e in return_list] for ops, actor_grad in zip(self._actor_ops_list, actor_grad_list): ops.update_actor_with_grad(actor_grad) diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index 54a6c4cbd..455cbce98 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -237,7 +237,7 @@ def _register_policy(self, policy: RLPolicy) -> None: self._policy = policy def train_step(self) -> None: - assert isinstance(self._ops, SoftActorCriticOps) + ops = cast(SoftActorCriticOps, self._ops) if self._replay_memory.n_sample < self._params.n_start_train: print( @@ -248,13 +248,15 @@ def train_step(self) -> None: for _ in range(self._params.num_epochs): batch = self._get_batch() - self._ops.update_critic(batch) - self._ops.update_actor(batch) + ops.update_critic(batch) + _, early_stop = ops.update_actor(batch) self._try_soft_update_target() + if early_stop: + break async def train_step_as_task(self) -> None: - assert isinstance(self._ops, RemoteOps) + ops = cast(SoftActorCriticOps, self._ops) if self._replay_memory.n_sample < self._params.n_start_train: print( @@ -265,9 +267,10 @@ async def train_step_as_task(self) -> None: for _ in range(self._params.num_epochs): batch = self._get_batch() - self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) - grad_dict, early_stop = await self._ops.get_actor_grad(batch) - 
self._ops.update_actor_with_grad(grad_dict) + ops.update_critic_with_grad(await ops.get_critic_grad(batch)) + grad_dict, early_stop = await ops.get_actor_grad(batch) + ops.update_actor_with_grad(grad_dict) + self._try_soft_update_target() if early_stop: break diff --git a/maro/rl/training/training_manager.py b/maro/rl/training/training_manager.py index 9d6b36b15..0ad39d699 100644 --- a/maro/rl/training/training_manager.py +++ b/maro/rl/training/training_manager.py @@ -80,7 +80,6 @@ def __init__( def train_step(self) -> None: if self._proxy_address: - async def train_step() -> Iterable: return await asyncio.gather( *[trainer_.train_step_as_task() for trainer_ in self._trainer_dict.values()] diff --git a/maro/rl/training/worker.py b/maro/rl/training/worker.py index 4cb1528f4..aed9505b2 100644 --- a/maro/rl/training/worker.py +++ b/maro/rl/training/worker.py @@ -26,7 +26,7 @@ class TrainOpsWorker(AbsWorker): so that the proxy can keep track of its connection status. rl_component_bundle (RLComponentBundle): Resources to launch the RL workflow. producer_host (str): IP address of the proxy host to connect to. - producer_port (int, default=10001): Port of the proxy host to connect to. + producer_port (int, default=DEFAULT_TRAINING_BACKEND_PORT): Port of the proxy host to connect to. """ def __init__( @@ -34,13 +34,13 @@ def __init__( idx: int, rl_component_bundle: RLComponentBundle, producer_host: str, - producer_port: int = None, + producer_port: int = DEFAULT_TRAINING_BACKEND_PORT, logger: LoggerV2 = None, ) -> None: super(TrainOpsWorker, self).__init__( idx=idx, producer_host=producer_host, - producer_port=producer_port if producer_port is not None else DEFAULT_TRAINING_BACKEND_PORT, + producer_port=producer_port, logger=logger, ) diff --git a/maro/rl/workflows/callback.py b/maro/rl/workflows/callback.py index 1c5a2c2f7..7bee0411e 100644 --- a/maro/rl/workflows/callback.py +++ b/maro/rl/workflows/callback.py @@ -6,24 +6,22 @@ import copy import os import typing -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import pandas as pd -from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler +from maro.rl.rollout import EnvSamplerInterface from maro.rl.training import TrainingManager from maro.utils import LoggerV2 if typing.TYPE_CHECKING: from maro.rl.workflows.main import TrainingWorkflow -EnvSampler = Union[AbsEnvSampler, BatchEnvSampler] - class Callback(object): def __init__(self) -> None: self.workflow: Optional[TrainingWorkflow] = None - self.env_sampler: Optional[EnvSampler] = None + self.env_sampler: Optional[EnvSamplerInterface] = None self.training_manager: Optional[TrainingManager] = None self.logger: Optional[LoggerV2] = None @@ -107,8 +105,8 @@ def _dump_metric_history(self) -> None: df.to_csv(os.path.join(self._path, "metrics_valid.csv"), index=True) def on_training_end(self, ep: int) -> None: - if len(self.env_sampler.metrics) > 0: - metrics = copy.deepcopy(self.env_sampler.metrics) + metrics = copy.deepcopy(self.env_sampler.get_metrics()) + if len(metrics) > 0: metrics["ep"] = ep if ep in self._full_metrics: self._full_metrics[ep].update(metrics) @@ -117,8 +115,8 @@ def on_training_end(self, ep: int) -> None: self._dump_metric_history() def on_validation_end(self, ep: int) -> None: - if len(self.env_sampler.metrics) > 0: - metrics = copy.deepcopy(self.env_sampler.metrics) + metrics = copy.deepcopy(self.env_sampler.get_metrics()) + if len(metrics) > 0: metrics["ep"] = ep if ep in self._full_metrics: self._full_metrics[ep].update(metrics) @@ 
-136,7 +134,7 @@ def __init__( self, workflow: TrainingWorkflow, callbacks: List[Callback], - env_sampler: EnvSampler, + env_sampler: EnvSamplerInterface, training_manager: TrainingManager, logger: LoggerV2, ) -> None: diff --git a/maro/rl/workflows/main.py b/maro/rl/workflows/main.py index a1f89b82a..8da3da4d0 100644 --- a/maro/rl/workflows/main.py +++ b/maro/rl/workflows/main.py @@ -6,10 +6,10 @@ import os import sys import time -from typing import List, Union +from typing import List from maro.rl.rl_component.rl_component_bundle import RLComponentBundle -from maro.rl.rollout import AbsEnvSampler, BatchEnvSampler, ExpElement +from maro.rl.rollout import BatchEnvSampler, EnvSamplerInterface, ExpElement from maro.rl.training import TrainingManager from maro.rl.utils import get_torch_device from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none @@ -92,7 +92,7 @@ def _get_args() -> argparse.Namespace: def _get_env_sampler( rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttributes, -) -> Union[AbsEnvSampler, BatchEnvSampler]: +) -> EnvSamplerInterface: if env_attr.parallel_rollout: assert env_attr.env_sampling_parallelism is not None return BatchEnvSampler( @@ -190,7 +190,7 @@ def run(self, rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttri collect_time += time.time() - tc0 - env_sampler.post_collect(total_info_list, ep) + env_sampler.post_collect(ep) tu0 = time.time() env_attr.logger.info(f"Roll-out completed for episode {ep}. Training started...") @@ -208,11 +208,11 @@ def run(self, rl_component_bundle: RLComponentBundle, env_attr: WorkflowEnvAttri cbm.on_validation_start(ep) eval_point_index += 1 - result = env_sampler.eval( + env_sampler.eval( policy_state=training_manager.get_policy_state() if not env_attr.is_single_thread else None, num_episodes=env_attr.num_eval_episodes, ) - env_sampler.post_evaluate(result["info"], ep) + env_sampler.post_evaluate(ep) cbm.on_validation_end(ep) @@ -238,7 +238,7 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}") result = env_sampler.eval(num_episodes=env_attr.num_eval_episodes) - env_sampler.post_evaluate(result["info"], -1) + env_sampler.post_evaluate(-1) if isinstance(env_sampler, BatchEnvSampler): env_sampler.exit() diff --git a/maro/simulator/scenarios/cim/business_engine.py b/maro/simulator/scenarios/cim/business_engine.py index 6a26dbe99..f8afb50e7 100644 --- a/maro/simulator/scenarios/cim/business_engine.py +++ b/maro/simulator/scenarios/cim/business_engine.py @@ -277,6 +277,7 @@ def get_metrics(self) -> DocableDict: { "order_requirements": total_booking, "container_shortage": total_shortage, + "shortage_percentage": total_shortage / total_booking, "operation_number": self._total_operate_num, }, ) diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py index 4a4bd12b2..2d1db2f7e 100644 --- a/tests/rl/gym_wrapper/common.py +++ b/tests/rl/gym_wrapper/common.py @@ -10,7 +10,7 @@ from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine env_conf = { - "topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1 + "topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1 "start_tick": 0, "durations": 100000, # Set a very large number "options": {}, diff --git a/tests/rl/gym_wrapper/env_sampler.py b/tests/rl/gym_wrapper/env_sampler.py index e740bafdb..a0ad0e8d8 
100644 --- a/tests/rl/gym_wrapper/env_sampler.py +++ b/tests/rl/gym_wrapper/env_sampler.py @@ -72,7 +72,7 @@ def _post_eval_step(self, cache_element: CacheElement) -> None: rewards = list(self._env.metrics["reward_record"].values()) self._eval_rewards.append((len(rewards), np.sum(rewards))) - def post_collect(self, info_list: list, ep: int) -> None: + def post_collect(self, ep: int) -> None: if len(self._sample_rewards) > 0: cur = { "n_steps": sum([n for n, _ in self._sample_rewards]), @@ -89,7 +89,7 @@ def post_collect(self, info_list: list, ep: int) -> None: else: self.metrics = {"n_interactions": self._total_number_interactions} - def post_evaluate(self, info_list: list, ep: int) -> None: + def post_evaluate(self, ep: int) -> None: if len(self._eval_rewards) > 0: cur = { "val/n_steps": sum([n for n, _ in self._eval_rewards]), @@ -102,3 +102,37 @@ def post_evaluate(self, info_list: list, ep: int) -> None: self._eval_rewards.clear() else: self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")} + + @staticmethod + def merge_metrics(metrics_list: List[dict]) -> dict: + metrics = {"n_interactions": sum(m["n_interactions"] for m in metrics_list)} + + tmp_metrics_list = [m for m in metrics_list if "n_steps" in m ] + if len(tmp_metrics_list) > 0: + n_steps = sum(m["n_steps"] for m in tmp_metrics_list) + n_segment = sum(m["n_segment"] for m in tmp_metrics_list) + metrics.update( + { + "n_steps": n_steps, + "n_segment": n_segment, + "avg_reward": sum(m["avg_reward"] * m["n_segment"] for m in tmp_metrics_list) / n_segment, + "avg_n_steps": n_steps / n_segment, + "max_n_steps": max(m["max_n_steps"] for m in tmp_metrics_list), + } + ) + + tmp_metrics_list = [m for m in metrics_list if "val/n_steps" in m ] + if len(tmp_metrics_list) > 0: + n_steps = sum(m["val/n_steps"] for m in tmp_metrics_list) + n_segment = sum(m["val/n_segment"] for m in tmp_metrics_list) + metrics.update( + { + "val/n_steps": n_steps, + "val/n_segment": n_segment, + "val/avg_reward": sum(m["val/avg_reward"] * m["val/n_segment"] for m in tmp_metrics_list) / n_segment, + "val/avg_n_steps": n_steps / n_segment, + "val/max_n_steps": max(m["val/max_n_steps"] for m in tmp_metrics_list), + } + ) + + return metrics diff --git a/tests/rl/tasks/ac/config_distributed.yml b/tests/rl/tasks/ac/config_distributed.yml new file mode 100644 index 000000000..61137d8ef --- /dev/null +++ b/tests/rl/tasks/ac/config_distributed.yml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario (parallel mode). +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. 
+ +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/ac" +log_path: "tests/rl/log/ac" +main: + num_episodes: 1000 + num_steps: null + eval_schedule: 5 + num_eval_episodes: 10 + min_n_sample: 5000 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG + parallelism: + sampling: 2 + eval: 2 + min_env_samples: 2 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 21000 +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: null + interval: 5 + logging: + stdout: INFO + file: DEBUG + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 diff --git a/tests/rl/tasks/ddpg/config_distributed.yml b/tests/rl/tasks/ddpg/config_distributed.yml new file mode 100644 index 000000000..fc158bcd5 --- /dev/null +++ b/tests/rl/tasks/ddpg/config_distributed.yml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario (parallel mode). +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/ddpg" +log_path: "tests/rl/log/ddpg_walker2d" +main: + num_episodes: 80000 + num_steps: 50 + eval_schedule: 200 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG + parallelism: + sampling: 2 + eval: 2 + min_env_samples: 2 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 21000 +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: null + interval: 200 + logging: + stdout: INFO + file: DEBUG + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 diff --git a/tests/rl/tasks/dqn/config_distributed.yml b/tests/rl/tasks/dqn/config_distributed.yml new file mode 100644 index 000000000..e8ac86831 --- /dev/null +++ b/tests/rl/tasks/dqn/config_distributed.yml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario (parallel mode). +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/dqn" +log_path: "tests/rl/log/dqn_cartpole" +main: + num_episodes: 3000 + num_steps: 50 + eval_schedule: 50 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG + parallelism: + sampling: 2 + eval: 2 + min_env_samples: 2 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 21000 +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: null + interval: 5 + logging: + stdout: INFO + file: DEBUG + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 diff --git a/tests/rl/tasks/ppo/config_distributed.yml b/tests/rl/tasks/ppo/config_distributed.yml new file mode 100644 index 000000000..96b6a1edf --- /dev/null +++ b/tests/rl/tasks/ppo/config_distributed.yml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario (parallel mode). +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. 
+ +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/ppo" +log_path: "tests/rl/log/ppo_walker2d" +main: + num_episodes: 1000 + num_steps: 4000 + eval_schedule: 5 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG + parallelism: + sampling: 2 + eval: 2 + min_env_samples: 2 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 21000 +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: null + interval: 5 + logging: + stdout: INFO + file: DEBUG + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 diff --git a/tests/rl/tasks/sac/config_distributed.yml b/tests/rl/tasks/sac/config_distributed.yml new file mode 100644 index 000000000..eba8e973d --- /dev/null +++ b/tests/rl/tasks/sac/config_distributed.yml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario (parallel mode). +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/sac" +log_path: "tests/rl/log/sac_walker2d" +main: + num_episodes: 80000 + num_steps: 50 + eval_schedule: 200 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG + parallelism: + sampling: 2 + eval: 2 + min_env_samples: 2 + grace_factor: 0.2 + controller: + host: "127.0.0.1" + port: 21000 +training: + mode: parallel + load_path: null + load_episode: null + checkpointing: + path: null + interval: 200 + logging: + stdout: INFO + file: DEBUG + proxy: + host: "127.0.0.1" + frontend: 10000 + backend: 10001 + num_workers: 2 From e884f4d465ee4a77098cbf69991a4b218f9bcabe Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Mon, 15 May 2023 10:03:22 +0800 Subject: [PATCH 09/12] Run pre-commit --- examples/cim/rl/env_sampler.py | 21 +++++++++++++------ maro/rl/rollout/__init__.py | 9 +++++++- maro/rl/rollout/batch_env_sampler.py | 4 +++- .../training/algorithms/base/ac_ppo_base.py | 2 +- maro/rl/training/algorithms/ddpg.py | 2 +- maro/rl/training/algorithms/dqn.py | 1 - maro/rl/training/algorithms/sac.py | 2 +- maro/rl/training/training_manager.py | 1 + tests/rl/gym_wrapper/env_sampler.py | 15 ++++++------- 9 files changed, 38 insertions(+), 19 deletions(-) diff --git a/examples/cim/rl/env_sampler.py b/examples/cim/rl/env_sampler.py index 5797dcbb1..850253721 100644 --- a/examples/cim/rl/env_sampler.py +++ b/examples/cim/rl/env_sampler.py @@ -137,12 +137,21 @@ def merge_metrics(metrics_list: List[dict]) -> dict: metrics_list = [m for m in metrics_list if "val/shortage_percentage" in m] if len(metrics_list) > 0: n_episode = sum(m["val/n_episode"] for m in metrics_list) - metrics.update({ - "val/order_requirements": sum(m["val/order_requirements"] * m["val/n_episode"] for m in metrics_list) / n_episode, - "val/container_shortage": sum(m["val/container_shortage"] * m["val/n_episode"] for m in metrics_list) / n_episode, - "val/operation_number": sum(m["val/operation_number"] * m["val/n_episode"] for m in metrics_list) / n_episode, - "val/n_episode": n_episode, - }) + metrics.update( + { + "val/order_requirements": sum( + m["val/order_requirements"] * m["val/n_episode"] for m in metrics_list + ) + / n_episode, + "val/container_shortage": sum( + m["val/container_shortage"] * m["val/n_episode"] for m in metrics_list + ) + / n_episode, + "val/operation_number": 
sum(m["val/operation_number"] * m["val/n_episode"] for m in metrics_list) + / n_episode, + "val/n_episode": n_episode, + }, + ) metrics["val/shortage_percentage"] = metrics["val/container_shortage"] / metrics["val/order_requirements"] return metrics diff --git a/maro/rl/rollout/__init__.py b/maro/rl/rollout/__init__.py index d5ab1c083..0e4b0b42c 100644 --- a/maro/rl/rollout/__init__.py +++ b/maro/rl/rollout/__init__.py @@ -2,7 +2,14 @@ # Licensed under the MIT license. from .batch_env_sampler import BatchEnvSampler -from .env_sampler import AbsAgentWrapper, AbsEnvSampler, CacheElement, EnvSamplerInterface, ExpElement, SimpleAgentWrapper +from .env_sampler import ( + AbsAgentWrapper, + AbsEnvSampler, + CacheElement, + EnvSamplerInterface, + ExpElement, + SimpleAgentWrapper, +) from .worker import RolloutWorker __all__ = [ diff --git a/maro/rl/rollout/batch_env_sampler.py b/maro/rl/rollout/batch_env_sampler.py index 719104eae..27bfc2e2b 100644 --- a/maro/rl/rollout/batch_env_sampler.py +++ b/maro/rl/rollout/batch_env_sampler.py @@ -192,7 +192,9 @@ def sample( "policy_state": policy_state, "index": self._ep, } - num_steps = [None] * self._sampling_parallelism if num_steps is None else _split(num_steps, self._sampling_parallelism) + num_steps = ( + [None] * self._sampling_parallelism if num_steps is None else _split(num_steps, self._sampling_parallelism) + ) results = self._controller.collect( req, self._sampling_parallelism, diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py index 8f990257a..7834c6dc4 100644 --- a/maro/rl/training/algorithms/base/ac_ppo_base.py +++ b/maro/rl/training/algorithms/base/ac_ppo_base.py @@ -10,7 +10,7 @@ from maro.rl.model import VNet from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, RemoteOps, SingleAgentTrainer, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, FIFOReplayMemory, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, discount_cumsum, get_torch_device, ndarray_to_tensor diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index f3c89cced..4bb5b42c1 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -8,7 +8,7 @@ from maro.rl.model import QNet from maro.rl.policy import ContinuousRLPolicy, RLPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote +from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 80528b3cd..aa99a64f2 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -13,7 +13,6 @@ BaseTrainerParams, PrioritizedReplayMemory, RandomReplayMemory, - RemoteOps, SingleAgentTrainer, remote, ) diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index 455cbce98..a51a61d75 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -8,7 +8,7 @@ from maro.rl.model import QNet from maro.rl.policy import ContinuousRLPolicy, RLPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote +from 
maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, SingleAgentTrainer, remote from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone diff --git a/maro/rl/training/training_manager.py b/maro/rl/training/training_manager.py index 0ad39d699..9d6b36b15 100644 --- a/maro/rl/training/training_manager.py +++ b/maro/rl/training/training_manager.py @@ -80,6 +80,7 @@ def __init__( def train_step(self) -> None: if self._proxy_address: + async def train_step() -> Iterable: return await asyncio.gather( *[trainer_.train_step_as_task() for trainer_ in self._trainer_dict.values()] diff --git a/tests/rl/gym_wrapper/env_sampler.py b/tests/rl/gym_wrapper/env_sampler.py index a0ad0e8d8..830610e3e 100644 --- a/tests/rl/gym_wrapper/env_sampler.py +++ b/tests/rl/gym_wrapper/env_sampler.py @@ -106,8 +106,8 @@ def post_evaluate(self, ep: int) -> None: @staticmethod def merge_metrics(metrics_list: List[dict]) -> dict: metrics = {"n_interactions": sum(m["n_interactions"] for m in metrics_list)} - - tmp_metrics_list = [m for m in metrics_list if "n_steps" in m ] + + tmp_metrics_list = [m for m in metrics_list if "n_steps" in m] if len(tmp_metrics_list) > 0: n_steps = sum(m["n_steps"] for m in tmp_metrics_list) n_segment = sum(m["n_segment"] for m in tmp_metrics_list) @@ -118,10 +118,10 @@ def merge_metrics(metrics_list: List[dict]) -> dict: "avg_reward": sum(m["avg_reward"] * m["n_segment"] for m in tmp_metrics_list) / n_segment, "avg_n_steps": n_steps / n_segment, "max_n_steps": max(m["max_n_steps"] for m in tmp_metrics_list), - } + }, ) - - tmp_metrics_list = [m for m in metrics_list if "val/n_steps" in m ] + + tmp_metrics_list = [m for m in metrics_list if "val/n_steps" in m] if len(tmp_metrics_list) > 0: n_steps = sum(m["val/n_steps"] for m in tmp_metrics_list) n_segment = sum(m["val/n_segment"] for m in tmp_metrics_list) @@ -129,10 +129,11 @@ def merge_metrics(metrics_list: List[dict]) -> dict: { "val/n_steps": n_steps, "val/n_segment": n_segment, - "val/avg_reward": sum(m["val/avg_reward"] * m["val/n_segment"] for m in tmp_metrics_list) / n_segment, + "val/avg_reward": sum(m["val/avg_reward"] * m["val/n_segment"] for m in tmp_metrics_list) + / n_segment, "val/avg_n_steps": n_steps / n_segment, "val/max_n_steps": max(m["val/max_n_steps"] for m in tmp_metrics_list), - } + }, ) return metrics From 240ec218dd87ff050c0b23b3a58dad350a1eec99 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Mon, 15 May 2023 10:15:25 +0800 Subject: [PATCH 10/12] Minor --- maro/rl/rollout/worker.py | 2 +- tests/rl/gym_wrapper/common.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/maro/rl/rollout/worker.py b/maro/rl/rollout/worker.py index 124efd438..cb2e39459 100644 --- a/maro/rl/rollout/worker.py +++ b/maro/rl/rollout/worker.py @@ -39,7 +39,7 @@ def __init__( producer_port=producer_port if producer_port is not None else DEFAULT_ROLLOUT_PRODUCER_PORT, logger=logger, ) - self._env_sampler = rl_component_bundle.env_sampler # TODO: deep copy? + self._env_sampler = rl_component_bundle.env_sampler def _compute(self, msg: list) -> None: """Perform a full or partial episode of roll-out for sampling or evaluation. 
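As a quick sanity check on the segment-weighted averaging in the gym wrapper's `merge_metrics` (reformatted by the pre-commit patch above), the arithmetic below uses made-up numbers for two workers and shows why `avg_reward` is weighted by `n_segment` before re-normalizing rather than averaged naively.

# Hypothetical training metrics reported by two rollout workers:
worker_a = {"n_steps": 300, "n_segment": 3, "avg_reward": 90.0}
worker_b = {"n_steps": 100, "n_segment": 1, "avg_reward": 50.0}
metrics_list = [worker_a, worker_b]

n_segment = sum(m["n_segment"] for m in metrics_list)                                 # 4
avg_reward = sum(m["avg_reward"] * m["n_segment"] for m in metrics_list) / n_segment
# (90.0 * 3 + 50.0 * 1) / 4 = 80.0; a plain mean of the two averages (70.0) would
# over-weight worker_b, which contributed only one trajectory segment.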
diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py index 2d1db2f7e..d9b8c51e3 100644 --- a/tests/rl/gym_wrapper/common.py +++ b/tests/rl/gym_wrapper/common.py @@ -10,7 +10,9 @@ from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine env_conf = { - "topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1 + # Envs with discrete action: {CartPole-v1} + # Envs with continuous action: {HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4} + "topology": "Walker2d-v4", "start_tick": 0, "durations": 100000, # Set a very large number "options": {}, From 607d3b6c743da067227f28a57c10f9d24d45b09e Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 23 May 2023 15:52:26 +0800 Subject: [PATCH 11/12] Refine explore strategy, add prioritized sampling support; add DDQN example; add DQN test (#590) * Runnable. Should setup a benchmark and test performance. * Refine logic * Test DQN on GYM passed * Refine explore strategy * Minor * Minor * Add Dueling DQN in CIM scenario * Resolve PR comments * Add one more explanation --- examples/cim/rl/algorithms/dqn.py | 72 ++++++--- examples/cim/rl/config.py | 2 +- examples/vm_scheduling/rl/algorithms/dqn.py | 16 +- maro/rl/exploration/__init__.py | 12 +- maro/rl/exploration/scheduling.py | 127 ---------------- maro/rl/exploration/strategies.py | 159 ++++++++++---------- maro/rl/model/fc_block.py | 13 +- maro/rl/policy/discrete_rl_policy.py | 36 +---- maro/rl/training/__init__.py | 9 +- maro/rl/training/algorithms/ddpg.py | 3 - maro/rl/training/algorithms/dqn.py | 90 ++++++++--- maro/rl/training/algorithms/sac.py | 3 - maro/rl/training/replay_memory.py | 108 ++++++++++++- maro/rl/training/trainer.py | 3 +- tests/rl/gym_wrapper/common.py | 22 ++- tests/rl/gym_wrapper/env_sampler.py | 9 +- tests/rl/performance.md | 35 ++++- tests/rl/tasks/ac/__init__.py | 3 + tests/rl/tasks/ddpg/__init__.py | 3 + tests/rl/tasks/dqn/__init__.py | 104 +++++++++++++ tests/rl/tasks/dqn/config.yml | 32 ++++ tests/rl/tasks/ppo/__init__.py | 3 + tests/rl/tasks/sac/__init__.py | 3 + 23 files changed, 537 insertions(+), 330 deletions(-) delete mode 100644 maro/rl/exploration/scheduling.py create mode 100644 tests/rl/tasks/dqn/__init__.py create mode 100644 tests/rl/tasks/dqn/config.yml diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py index 022275552..c2c5f1952 100644 --- a/examples/cim/rl/algorithms/dqn.py +++ b/examples/cim/rl/algorithms/dqn.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from typing import Optional, Tuple import torch from torch.optim import RMSprop -from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy +from maro.rl.exploration import EpsilonGreedy from maro.rl.model import DiscreteQNet, FullyConnected from maro.rl.policy import ValueBasedPolicy from maro.rl.training.algorithms import DQNParams, DQNTrainer @@ -23,32 +24,62 @@ class MyQNet(DiscreteQNet): - def __init__(self, state_dim: int, action_num: int) -> None: + def __init__( + self, + state_dim: int, + action_num: int, + dueling_param: Optional[Tuple[dict, dict]] = None, + ) -> None: super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num) - self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf) - self._optim = RMSprop(self._fc.parameters(), lr=learning_rate) + + self._use_dueling = dueling_param is not None + self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf) + if self._use_dueling: + q_kwargs, v_kwargs = dueling_param + self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs) + self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs) + + self._optim = RMSprop(self.parameters(), lr=learning_rate) def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: - return self._fc(states) + logits = self._fc(states) + if self._use_dueling: + q = self._q(logits) + v = self._v(logits) + logits = q - q.mean(dim=1, keepdim=True) + v + return logits def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy: + q_kwargs = { + "hidden_dims": [128], + "activation": torch.nn.LeakyReLU, + "output_activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "skip_connection": False, + "head": True, + "dropout_p": 0.0, + } + v_kwargs = { + "hidden_dims": [128], + "activation": torch.nn.LeakyReLU, + "output_activation": None, + "softmax": False, + "batch_norm": True, + "skip_connection": False, + "head": True, + "dropout_p": 0.0, + } + return ValueBasedPolicy( name=name, - q_net=MyQNet(state_dim, action_num), - exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}), - exploration_scheduling_options=[ - ( - "epsilon", - MultiLinearExplorationScheduler, - { - "splits": [(2, 0.32)], - "initial_value": 0.4, - "last_ep": 5, - "final_value": 0.0, - }, - ), - ], + q_net=MyQNet( + state_dim, + action_num, + dueling_param=(q_kwargs, v_kwargs), + ), + explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num), warmup=100, ) @@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer: num_epochs=10, soft_update_coef=0.1, double=False, - random_overwrite=False, + alpha=1.0, + beta=1.0, ), ) diff --git a/examples/cim/rl/config.py b/examples/cim/rl/config.py index a46194900..e91e5edfc 100644 --- a/examples/cim/rl/config.py +++ b/examples/cim/rl/config.py @@ -35,4 +35,4 @@ action_num = len(action_shaping_conf["action_space"]) -algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg +algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg diff --git a/examples/vm_scheduling/rl/algorithms/dqn.py b/examples/vm_scheduling/rl/algorithms/dqn.py index 643d6c6d4..78be0d7bd 100644 --- a/examples/vm_scheduling/rl/algorithms/dqn.py +++ b/examples/vm_scheduling/rl/algorithms/dqn.py @@ -6,7 +6,7 @@ from torch.optim import SGD from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts -from maro.rl.exploration import MultiLinearExplorationScheduler +from maro.rl.exploration import EpsilonGreedy 
from maro.rl.model import DiscreteQNet, FullyConnected from maro.rl.policy import ValueBasedPolicy from maro.rl.training.algorithms import DQNParams, DQNTrainer @@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str return ValueBasedPolicy( name=name, q_net=MyQNet(state_dim, action_num, num_features), - exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}), - exploration_scheduling_options=[ - ( - "epsilon", - MultiLinearExplorationScheduler, - { - "splits": [(100, 0.32)], - "initial_value": 0.4, - "last_ep": 400, - "final_value": 0.0, - }, - ), - ], + explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num), warmup=100, ) diff --git a/maro/rl/exploration/__init__.py b/maro/rl/exploration/__init__.py index 383cca89a..7be8b579d 100644 --- a/maro/rl/exploration/__init__.py +++ b/maro/rl/exploration/__init__.py @@ -1,14 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler -from .strategies import epsilon_greedy, gaussian_noise, uniform_noise +from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration __all__ = [ - "AbsExplorationScheduler", - "LinearExplorationScheduler", - "MultiLinearExplorationScheduler", - "epsilon_greedy", - "gaussian_noise", - "uniform_noise", + "ExploreStrategy", + "EpsilonGreedy", + "LinearExploration", ] diff --git a/maro/rl/exploration/scheduling.py b/maro/rl/exploration/scheduling.py deleted file mode 100644 index 3981729c9..000000000 --- a/maro/rl/exploration/scheduling.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod -from typing import List, Tuple - - -class AbsExplorationScheduler(ABC): - """Abstract exploration scheduler. - - Args: - exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the - scheduler is applied. - param_name (str): Name of the exploration parameter to which the scheduler is applied. - initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used - when instantiating the policy will be used as the initial value. - """ - - def __init__(self, exploration_params: dict, param_name: str, initial_value: float = None) -> None: - super().__init__() - self._exploration_params = exploration_params - self.param_name = param_name - if initial_value is not None: - self._exploration_params[self.param_name] = initial_value - - def get_value(self) -> float: - return self._exploration_params[self.param_name] - - @abstractmethod - def step(self) -> None: - raise NotImplementedError - - -class LinearExplorationScheduler(AbsExplorationScheduler): - """Linear exploration parameter schedule. - - Args: - exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the - scheduler is applied. - param_name (str): Name of the exploration parameter to which the scheduler is applied. - last_ep (int): Last episode. - final_value (float): The value of the exploration parameter corresponding to ``last_ep``. - start_ep (int, default=1): starting episode. - initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used - when instantiating the policy will be used as the initial value. 
- """ - - def __init__( - self, - exploration_params: dict, - param_name: str, - *, - last_ep: int, - final_value: float, - start_ep: int = 1, - initial_value: float = None, - ) -> None: - super().__init__(exploration_params, param_name, initial_value=initial_value) - self.final_value = final_value - if last_ep > 1: - self.delta = (self.final_value - self._exploration_params[self.param_name]) / (last_ep - start_ep) - else: - self.delta = 0 - - def step(self) -> None: - if self._exploration_params[self.param_name] == self.final_value: - return - - self._exploration_params[self.param_name] += self.delta - - -class MultiLinearExplorationScheduler(AbsExplorationScheduler): - """Exploration parameter schedule that consists of multiple linear phases. - - Args: - exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the - scheduler is applied. - param_name (str): Name of the exploration parameter to which the scheduler is applied. - splits (List[Tuple[int, float]]): List of points that separate adjacent linear phases. Each - point is a (episode, parameter_value) tuple that indicates the end of one linear phase and - the start of another. These points do not have to be given in any particular order. There - cannot be two points with the same first element (episode), or a ``ValueError`` will be raised. - last_ep (int): Last episode. - final_value (float): The value of the exploration parameter corresponding to ``last_ep``. - start_ep (int, default=1): starting episode. - initial_value (float, default=None): Initial value for the exploration parameter. If None, the value from - the original dictionary the policy is instantiated with will be used as the initial value. - """ - - def __init__( - self, - exploration_params: dict, - param_name: str, - *, - splits: List[Tuple[int, float]], - last_ep: int, - final_value: float, - start_ep: int = 1, - initial_value: float = None, - ) -> None: - super().__init__(exploration_params, param_name, initial_value=initial_value) - - # validate splits - splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)] - splits.sort() - for (ep, _), (ep2, _) in zip(splits, splits[1:]): - if ep == ep2: - raise ValueError("The zeroth element of split points must be unique") - - self.final_value = final_value - self._splits = splits - self._ep = start_ep - self._split_index = 1 - self._delta = (self._splits[1][1] - self._exploration_params[self.param_name]) / (self._splits[1][0] - start_ep) - - def step(self) -> None: - if self._split_index == len(self._splits): - return - - self._exploration_params[self.param_name] += self._delta - self._ep += 1 - if self._ep == self._splits[self._split_index][0]: - self._split_index += 1 - if self._split_index < len(self._splits): - self._delta = (self._splits[self._split_index][1] - self._splits[self._split_index - 1][1]) / ( - self._splits[self._split_index][0] - self._splits[self._split_index - 1][0] - ) diff --git a/maro/rl/exploration/strategies.py b/maro/rl/exploration/strategies.py index c85340c78..37b164389 100644 --- a/maro/rl/exploration/strategies.py +++ b/maro/rl/exploration/strategies.py @@ -1,93 +1,100 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - -from typing import Union +from abc import abstractmethod +from typing import Any import numpy as np -def epsilon_greedy( - state: np.ndarray, - action: np.ndarray, - num_actions: int, - *, - epsilon: float, -) -> np.ndarray: - """Epsilon-greedy exploration. 
+class ExploreStrategy: + def __init__(self) -> None: + pass + + @abstractmethod + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + """ + Args: + state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla + eps-greedy exploration and is put here to conform to the function signature required for the exploration + strategy parameter for ``DQN``. + action (np.ndarray): Action(s) chosen greedily by the policy. + + Returns: + Exploratory actions. + """ + raise NotImplementedError + + +class EpsilonGreedy(ExploreStrategy): + """Epsilon-greedy exploration. Returns uniformly random action with probability `epsilon` or returns original + action with probability `1.0 - epsilon`. Args: - state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla - eps-greedy exploration and is put here to conform to the function signature required for the exploration - strategy parameter for ``DQN``. - action (np.ndarray): Action(s) chosen greedily by the policy. num_actions (int): Number of possible actions. epsilon (float): The probability that a random action will be selected. - - Returns: - Exploratory actions. """ - return np.array([act if np.random.random() > epsilon else np.random.randint(num_actions) for act in action]) + def __init__(self, num_actions: int, epsilon: float) -> None: + super(EpsilonGreedy, self).__init__() -def uniform_noise( - state: np.ndarray, - action: np.ndarray, - min_action: Union[float, list, np.ndarray] = None, - max_action: Union[float, list, np.ndarray] = None, - *, - low: Union[float, list, np.ndarray], - high: Union[float, list, np.ndarray], -) -> Union[float, np.ndarray]: - """Apply a uniform noise to a continuous multidimensional action. + assert 0.0 <= epsilon <= 1.0 - Args: - state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise - exploration scheme and is put here to conform to the function signature for the exploration in continuous - action spaces. - action (np.ndarray): Action(s) chosen greedily by the policy. - min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space. - max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space. - low (Union[float, list, np.ndarray]): Lower bound for the noise range. - high (Union[float, list, np.ndarray]): Upper bound for the noise range. - - Returns: - Exploration actions with added noise. - """ - if min_action is None and max_action is None: - return action + np.random.uniform(low, high, size=action.shape) - else: - return np.clip(action + np.random.uniform(low, high, size=action.shape), min_action, max_action) - - -def gaussian_noise( - state: np.ndarray, - action: np.ndarray, - min_action: Union[float, list, np.ndarray] = None, - max_action: Union[float, list, np.ndarray] = None, - *, - mean: Union[float, list, np.ndarray] = 0.0, - stddev: Union[float, list, np.ndarray] = 1.0, - relative: bool = False, -) -> Union[float, np.ndarray]: - """Apply a gaussian noise to a continuous multidimensional action. 
+ self._num_actions = num_actions + self._eps = epsilon + + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return np.array( + [act if np.random.random() > self._eps else np.random.randint(self._num_actions) for act in action], + ) + + +class LinearExploration(ExploreStrategy): + """Epsilon greedy which the probability `epsilon` is linearly interpolated between `start_explore_prob` and + `end_explore_prob` over `explore_steps`. After this many timesteps pass, `epsilon` is fixed to `end_explore_prob`. Args: - state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise - exploration scheme and is put here to conform to the function signature for the exploration in continuous - action spaces. - action (np.ndarray): Action(s) chosen greedily by the policy. - min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space. - max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space. - mean (Union[float, list, np.ndarray], default=0.0): Gaussian noise mean. - stddev (Union[float, list, np.ndarray], default=1.0): Standard deviation for the Gaussian noise. - relative (bool, default=False): If True, the generated noise is treated as a relative measure and will - be multiplied by the action itself before being added to the action. - - Returns: - Exploration actions with added noise (a numpy ndarray). + num_actions (int): Number of possible actions. + explore_steps (int): Maximum number of steps to interpolate probability. + start_explore_prob (float): Starting explore probability. + end_explore_prob (float): Ending explore probability. """ - noise = np.random.normal(loc=mean, scale=stddev, size=action.shape) - if min_action is None and max_action is None: - return action + ((noise * action) if relative else noise) - else: - return np.clip(action + ((noise * action) if relative else noise), min_action, max_action) + + def __init__( + self, + num_actions: int, + explore_steps: int, + start_explore_prob: float, + end_explore_prob: float, + ) -> None: + super(LinearExploration, self).__init__() + + self._call_count = 0 + + self._num_actions = num_actions + self._explore_steps = explore_steps + self._start_explore_prob = start_explore_prob + self._end_explore_prob = end_explore_prob + + def get_action( + self, + state: np.ndarray, + action: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + ratio = min(self._call_count / self._explore_steps, 1.0) + epsilon = self._start_explore_prob + (self._end_explore_prob - self._start_explore_prob) * ratio + explore_flag = np.random.random() < epsilon + action = np.array([np.random.randint(self._num_actions) if explore_flag else act for act in action]) + + self._call_count += 1 + return action diff --git a/maro/rl/model/fc_block.py b/maro/rl/model/fc_block.py index 9154d6ff0..f7f78e518 100644 --- a/maro/rl/model/fc_block.py +++ b/maro/rl/model/fc_block.py @@ -13,7 +13,7 @@ class FullyConnected(nn.Module): Args: input_dim (int): Network input dimension. - output_dim (int): Network output dimension. + output_dim (int): Network output dimension. If it is 0, will not create the top layer. hidden_dims (List[int]): Dimensions of hidden layers. Its length is the number of hidden layers. For example, `hidden_dims=[128, 256]` refers to two hidden layers with output dim of 128 and 256, respectively. 
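Both strategies above take the batch of greedy actions (and, for API uniformity, the states) and return possibly-perturbed actions. A short usage sketch against the `maro.rl.exploration` API introduced in this patch; the state and action values are made up:

```python
import numpy as np

from maro.rl.exploration import EpsilonGreedy, LinearExploration

states = np.zeros((4, 8), dtype=np.float32)   # placeholder; unused by epsilon-greedy
greedy_actions = np.array([0, 2, 1, 3])       # actions chosen greedily by the policy

eps = EpsilonGreedy(num_actions=4, epsilon=0.4)
print(eps.get_action(state=states, action=greedy_actions))

# Anneals epsilon from 1.0 to 0.02 over 10,000 calls, then stays at 0.02.
linear = LinearExploration(num_actions=4, explore_steps=10_000, start_explore_prob=1.0, end_explore_prob=0.02)
print(linear.get_action(state=states, action=greedy_actions))
```

Note that `LinearExploration.get_action` draws a single explore-or-exploit flag per call, so each call either resamples the whole batch uniformly or returns it unchanged, whereas `EpsilonGreedy` decides per action.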
activation (Optional[Type[torch.nn.Module], default=nn.ReLU): Activation class provided by ``torch.nn`` or a @@ -52,7 +52,6 @@ def __init__( super(FullyConnected, self).__init__() self._input_dim = input_dim self._hidden_dims = hidden_dims if hidden_dims is not None else [] - self._output_dim = output_dim # network features self._activation = activation if activation else None @@ -76,9 +75,13 @@ def __init__( self._build_layer(in_dim, out_dim, activation=self._activation) for in_dim, out_dim in zip(dims, dims[1:]) ] # top layer - layers.append( - self._build_layer(dims[-1], self._output_dim, head=self._head, activation=self._output_activation), - ) + if output_dim != 0: + layers.append( + self._build_layer(dims[-1], output_dim, head=self._head, activation=self._output_activation), + ) + self._output_dim = output_dim + else: + self._output_dim = hidden_dims[-1] self._net = nn.Sequential(*layers) diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index 289d150e7..344be00d8 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -2,15 +2,14 @@ # Licensed under the MIT license. from abc import ABCMeta -from typing import Callable, Dict, List, Tuple +from typing import Dict, Optional, Tuple import numpy as np import torch -from maro.rl.exploration import epsilon_greedy +from maro.rl.exploration import ExploreStrategy from maro.rl.model import DiscretePolicyNet, DiscreteQNet from maro.rl.utils import match_shape, ndarray_to_tensor -from maro.utils import clone from .abs_policy import RLPolicy @@ -69,8 +68,7 @@ class ValueBasedPolicy(DiscreteRLPolicy): name (str): Name of the policy. q_net (DiscreteQNet): Q-net used in this value-based policy. trainable (bool, default=True): Whether this policy is trainable. - exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy. - exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options. + explore_strategy (Optional[ExploreStrategy], default=None): Explore strategy. warmup (int, default=50000): Number of steps for uniform-random action selection, before running real policy. Helps exploration. """ @@ -80,8 +78,7 @@ def __init__( name: str, q_net: DiscreteQNet, trainable: bool = True, - exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}), - exploration_scheduling_options: List[tuple] = None, + explore_strategy: Optional[ExploreStrategy] = None, warmup: int = 50000, ) -> None: assert isinstance(q_net, DiscreteQNet) @@ -94,15 +91,7 @@ def __init__( warmup=warmup, ) self._q_net = q_net - - self._exploration_func = exploration_strategy[0] - self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing - self._exploration_schedulers = ( - [opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options] - if exploration_scheduling_options is not None - else [] - ) - + self._explore_strategy = explore_strategy self._softmax = torch.nn.Softmax(dim=1) @property @@ -176,9 +165,6 @@ def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) assert match_shape(q_values, (states.shape[0],)) # [B] return q_values - def explore(self) -> None: - pass # Overwrite the base method and turn off explore mode. 
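The `output_dim=0` convention added to `FullyConnected` lets the block serve as a headless trunk whose output width falls back to the last hidden layer, which is exactly how the dueling Q-net uses it. A hedged sketch, assuming the keyword arguments shown in the docstring above and default values for the remaining constructor options:

```python
import torch

from maro.rl.model import FullyConnected

# output_dim=0 -> no top layer; the block's output_dim becomes hidden_dims[-1].
trunk = FullyConnected(input_dim=16, output_dim=0, hidden_dims=[128, 64], activation=torch.nn.LeakyReLU)
assert trunk.output_dim == 64
features = trunk(torch.randn(32, 16))  # -> shape [32, 64]
```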
- def _get_actions_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor: return self._get_actions_with_probs_impl(states, **kwargs)[0] @@ -187,17 +173,11 @@ def _get_actions_with_probs_impl(self, states: torch.Tensor, **kwargs) -> Tuple[ q_matrix_softmax = self._softmax(q_matrix) _, actions = q_matrix.max(dim=1) # [B], [B] - if self._is_exploring: - actions = self._exploration_func( - states, - actions.cpu().numpy(), - self.action_num, - **self._exploration_params, - **kwargs, - ) + if self._is_exploring and self._explore_strategy is not None: + actions = self._explore_strategy.get_action(state=states.cpu().numpy(), action=actions.cpu().numpy()) actions = ndarray_to_tensor(actions, device=self._device) - actions = actions.unsqueeze(1) + actions = actions.unsqueeze(1).long() return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1] def _get_actions_with_logps_impl(self, states: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py index a77296f98..0e4488915 100644 --- a/maro/rl/training/__init__.py +++ b/maro/rl/training/__init__.py @@ -2,7 +2,13 @@ # Licensed under the MIT license. from .proxy import TrainingProxy -from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory +from .replay_memory import ( + FIFOMultiReplayMemory, + FIFOReplayMemory, + PrioritizedReplayMemory, + RandomMultiReplayMemory, + RandomReplayMemory, +) from .train_ops import AbsTrainOps, RemoteOps, remote from .trainer import AbsTrainer, BaseTrainerParams, MultiAgentTrainer, SingleAgentTrainer from .training_manager import TrainingManager @@ -12,6 +18,7 @@ "TrainingProxy", "FIFOMultiReplayMemory", "FIFOReplayMemory", + "PrioritizedReplayMemory", "RandomMultiReplayMemory", "RandomReplayMemory", "AbsTrainOps", diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py index aaa0b7454..bf7b0f8d4 100644 --- a/maro/rl/training/algorithms/ddpg.py +++ b/maro/rl/training/algorithms/ddpg.py @@ -261,9 +261,6 @@ def _register_policy(self, policy: RLPolicy) -> None: assert isinstance(policy, ContinuousRLPolicy) self._policy = policy - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> AbsTrainOps: return DDPGOps( name=self._policy.name, diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 5a4f938ab..a12ba4a11 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -2,12 +2,21 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Dict, cast +from typing import Dict, Tuple, cast +import numpy as np import torch from maro.rl.policy import RLPolicy, ValueBasedPolicy -from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote +from maro.rl.training import ( + AbsTrainOps, + BaseTrainerParams, + PrioritizedReplayMemory, + RandomReplayMemory, + RemoteOps, + SingleAgentTrainer, + remote, +) from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone @@ -15,6 +24,9 @@ @dataclass class DQNParams(BaseTrainerParams): """ + use_prioritized_replay (bool, default=False): Whether to use prioritized replay memory. + alpha (float, default=0.4): Alpha in prioritized replay memory. + beta (float, default=0.6): Beta in prioritized replay memory. 
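In the parameters above, `alpha` controls how strongly TD-error priorities skew sampling (0 recovers uniform sampling) and `beta` controls how strongly the importance-sampling weights correct the resulting bias (1 is full correction), following the prioritized experience replay paper. A sketch of enabling prioritized replay through `DQNParams`; the alpha/beta values below are the conventional PER settings, not this patch's defaults:

```python
from maro.rl.training.algorithms import DQNParams

params = DQNParams(
    use_prioritized_replay=True,
    alpha=0.6,               # priority exponent: how much priorities skew sampling
    beta=0.4,                # importance-sampling exponent: how much the bias is corrected
    num_epochs=10,
    update_target_every=5,
    soft_update_coef=0.1,
    double=True,             # can be combined with double DQN
)
```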
num_epochs (int, default=1): Number of training epochs. update_target_every (int, default=5): Number of gradient steps between target model updates. soft_update_coef (float, default=0.1): Soft update coefficient, e.g., @@ -27,11 +39,13 @@ class DQNParams(BaseTrainerParams): sequentially with wrap-around. """ + use_prioritized_replay: bool = False + alpha: float = 0.4 + beta: float = 0.6 num_epochs: int = 1 update_target_every: int = 5 soft_update_coef: float = 0.1 double: bool = False - random_overwrite: bool = False class DQNOps(AbsTrainOps): @@ -54,20 +68,21 @@ def __init__( self._reward_discount = reward_discount self._soft_update_coef = params.soft_update_coef self._double = params.double - self._loss_func = torch.nn.MSELoss() self._target_policy: ValueBasedPolicy = clone(self._policy) self._target_policy.set_name(f"target_{self._policy.name}") self._target_policy.eval() - def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: + def _get_batch_loss(self, batch: TransitionBatch, weight: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the loss of the batch. Args: batch (TransitionBatch): Batch. + weight (np.ndarray): Weight of each data entry. Returns: loss (torch.Tensor): The loss of the batch. + td_error (torch.Tensor): TD-error of the batch. """ assert isinstance(batch, TransitionBatch) assert isinstance(self._policy, ValueBasedPolicy) @@ -79,19 +94,21 @@ def _get_batch_loss(self, batch: TransitionBatch) -> torch.Tensor: rewards = ndarray_to_tensor(batch.rewards, device=self._device) terminals = ndarray_to_tensor(batch.terminals, device=self._device).float() + weight = ndarray_to_tensor(weight, device=self._device) + with torch.no_grad(): if self._double: self._policy.exploit() actions_by_eval_policy = self._policy.get_actions_tensor(next_states) next_q_values = self._target_policy.q_values_tensor(next_states, actions_by_eval_policy) else: - self._target_policy.exploit() - actions = self._target_policy.get_actions_tensor(next_states) - next_q_values = self._target_policy.q_values_tensor(next_states, actions) + next_q_values = self._target_policy.q_values_for_all_actions_tensor(next_states).max(dim=1)[0] target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() q_values = self._policy.q_values_tensor(states, actions) - return self._loss_func(q_values, target_q_values) + td_error = target_q_values - q_values + + return (td_error.pow(2) * weight).mean(), td_error @remote def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: @@ -103,7 +120,8 @@ def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: Returns: grad (torch.Tensor): The gradient of the batch. """ - return self._policy.get_gradients(self._get_batch_loss(batch)) + loss, _ = self._get_batch_loss(batch) + return self._policy.get_gradients(loss) def update_with_grad(self, grad_dict: dict) -> None: """Update the network with remotely computed gradients. @@ -114,14 +132,20 @@ def update_with_grad(self, grad_dict: dict) -> None: self._policy.train() self._policy.apply_gradients(grad_dict) - def update(self, batch: TransitionBatch) -> None: + def update(self, batch: TransitionBatch, weight: np.ndarray) -> np.ndarray: """Update the network using a batch. Args: batch (TransitionBatch): Batch. + weight (np.ndarray): Weight of each data entry. 
+ + Returns: + td_errors (np.ndarray) """ self._policy.train() - self._policy.train_step(self._get_batch_loss(batch)) + loss, td_error = self._get_batch_loss(batch, weight) + self._policy.train_step(loss) + return td_error.detach().numpy() def get_non_policy_state(self) -> dict: return { @@ -168,20 +192,27 @@ def __init__( def build(self) -> None: self._ops = cast(DQNOps, self.get_ops()) - self._replay_memory = RandomReplayMemory( - capacity=self._replay_memory_capacity, - state_dim=self._ops.policy_state_dim, - action_dim=self._ops.policy_action_dim, - random_overwrite=self._params.random_overwrite, - ) + + if self._params.use_prioritized_replay: + self._replay_memory = PrioritizedReplayMemory( + capacity=self._replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + alpha=self._params.alpha, + beta=self._params.beta, + ) + else: + self._replay_memory = RandomReplayMemory( + capacity=self._replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + random_overwrite=False, + ) def _register_policy(self, policy: RLPolicy) -> None: assert isinstance(policy, ValueBasedPolicy) self._policy = policy - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> AbsTrainOps: return DQNOps( name=self._policy.name, @@ -191,13 +222,24 @@ def get_local_ops(self) -> AbsTrainOps: params=self._params, ) - def _get_batch(self, batch_size: int = None) -> TransitionBatch: - return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + def _get_batch(self, batch_size: int = None) -> Tuple[TransitionBatch, np.ndarray, np.ndarray]: + indexes = self.replay_memory.get_sample_indexes(batch_size or self._batch_size) + batch = self.replay_memory.sample_by_indexes(indexes) + + if self._params.use_prioritized_replay: + weight = cast(PrioritizedReplayMemory, self.replay_memory).get_weight(indexes) + else: + weight = np.ones(len(indexes)) + + return batch, indexes, weight def train_step(self) -> None: assert isinstance(self._ops, DQNOps) for _ in range(self._params.num_epochs): - self._ops.update(self._get_batch()) + batch, indexes, weight = self._get_batch() + td_error = self._ops.update(batch, weight) + if self._params.use_prioritized_replay: + cast(PrioritizedReplayMemory, self.replay_memory).update_weight(indexes, td_error) self._try_soft_update_target() diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py index 7daf99c7d..54a6c4cbd 100644 --- a/maro/rl/training/algorithms/sac.py +++ b/maro/rl/training/algorithms/sac.py @@ -272,9 +272,6 @@ async def train_step_as_task(self) -> None: if early_stop: break - def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - return transition_batch - def get_local_ops(self) -> SoftActorCriticOps: return SoftActorCriticOps( name=self._policy.name, diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py index da1e7d692..164c2580c 100644 --- a/maro/rl/training/replay_memory.py +++ b/maro/rl/training/replay_memory.py @@ -88,6 +88,73 @@ def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: return np.random.choice(self._size, size=batch_size, replace=True) +class PriorityReplayIndexScheduler(AbsIndexScheduler): + """ + Indexer for priority replay memory: https://arxiv.org/abs/1511.05952. + + Args: + capacity (int): Maximum capacity of the replay memory. 
+ alpha (float): Alpha (see original paper for explanation). + beta (float): Alpha (see original paper for explanation). + """ + + def __init__( + self, + capacity: int, + alpha: float, + beta: float, + ) -> None: + super(PriorityReplayIndexScheduler, self).__init__(capacity) + self._alpha = alpha + self._beta = beta + self._max_prio = self._min_prio = 1.0 + self._weights = np.zeros(capacity, dtype=np.float32) + + self._ptr = self._size = 0 + + def init_weights(self, indexes: np.ndarray) -> None: + self._weights[indexes] = self._max_prio**self._alpha + + def get_weight(self, indexes: np.ndarray) -> np.ndarray: + # important sampling weight calculation + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + return (self._weights[indexes] / self._min_prio) ** (-self._beta) + + def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None: + assert indexes.shape == weight.shape + weight = np.abs(weight) + np.finfo(np.float32).eps.item() + self._weights[indexes] = weight**self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) + + def get_put_indexes(self, batch_size: int) -> np.ndarray: + if self._ptr + batch_size <= self._capacity: + indexes = np.arange(self._ptr, self._ptr + batch_size) + self._ptr += batch_size + else: + overwrites = self._ptr + batch_size - self._capacity + indexes = np.concatenate( + [ + np.arange(self._ptr, self._capacity), + np.arange(overwrites), + ], + ) + self._ptr = overwrites + + self._size = min(self._size + batch_size, self._capacity) + self.init_weights(indexes) + return indexes + + def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: + assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}" + assert self._size > 0, "Cannot sample from an empty memory." + + weights = self._weights[: self._size] + weights = weights / weights.sum() + return np.random.choice(np.arange(self._size), p=weights, size=batch_size, replace=True) + + class FIFOIndexScheduler(AbsIndexScheduler): """First-in-first-out index scheduler. @@ -154,11 +221,11 @@ def capacity(self) -> int: def state_dim(self) -> int: return self._state_dim - def _get_put_indexes(self, batch_size: int) -> np.ndarray: + def get_put_indexes(self, batch_size: int) -> np.ndarray: """Please refer to the doc string in AbsIndexScheduler.""" return self._idx_scheduler.get_put_indexes(batch_size) - def _get_sample_indexes(self, batch_size: int = None) -> np.ndarray: + def get_sample_indexes(self, batch_size: int = None) -> np.ndarray: """Please refer to the doc string in AbsIndexScheduler.""" return self._idx_scheduler.get_sample_indexes(batch_size) @@ -225,10 +292,10 @@ def put(self, transition_batch: TransitionBatch) -> None: if transition_batch.old_logps is not None: match_shape(transition_batch.old_logps, (batch_size,)) - self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch) + self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch) self._n_sample = min(self._n_sample + transition_batch.size, self._capacity) - def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None: + def put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None: """Store a transition batch into the memory at the give indexes. Args: @@ -258,7 +325,7 @@ def sample(self, batch_size: int = None) -> TransitionBatch: Returns: batch (TransitionBatch): The sampled batch. 
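The comment in `get_weight` above states that the importance-sampling weight ((p_j / p_sum * N) ** -beta) / ((p_min / p_sum * N) ** -beta) simplifies to (p_j / p_min) ** -beta. A quick numerical check of that identity:

```python
import numpy as np

priorities = np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)  # p_j = |td_error| ** alpha
beta = 0.4
n, p_sum, p_min = len(priorities), priorities.sum(), priorities.min()

full = ((priorities / p_sum) * n) ** (-beta) / (((p_min / p_sum) * n) ** (-beta))
simplified = (priorities / p_min) ** (-beta)
assert np.allclose(full, simplified)  # the p_sum and N factors cancel out

print(simplified)  # entries at the minimum priority get weight 1.0; high-priority ones are down-weighted
```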
""" - indexes = self._get_sample_indexes(batch_size) + indexes = self.get_sample_indexes(batch_size) return self.sample_by_indexes(indexes) def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch: @@ -306,6 +373,31 @@ def random_overwrite(self) -> bool: return self._random_overwrite +class PrioritizedReplayMemory(ReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dim: int, + alpha: float, + beta: float, + ) -> None: + super(PrioritizedReplayMemory, self).__init__( + capacity, + state_dim, + action_dim, + PriorityReplayIndexScheduler(capacity, alpha, beta), + ) + + def get_weight(self, indexes: np.ndarray) -> np.ndarray: + assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) + return self._idx_scheduler.get_weight(indexes) + + def update_weight(self, indexes: np.ndarray, weight: np.ndarray) -> None: + assert isinstance(self._idx_scheduler, PriorityReplayIndexScheduler) + self._idx_scheduler.update_weight(indexes, weight) + + class FIFOReplayMemory(ReplayMemory): def __init__( self, @@ -393,9 +485,9 @@ def put(self, transition_batch: MultiTransitionBatch) -> None: assert match_shape(transition_batch.agent_states[i], (batch_size, self._agent_states_dims[i])) assert match_shape(transition_batch.next_agent_states[i], (batch_size, self._agent_states_dims[i])) - self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch=transition_batch) + self.put_by_indexes(self.get_put_indexes(batch_size), transition_batch=transition_batch) - def _put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None: + def put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None: """Store a transition batch into the memory at the give indexes. Args: @@ -424,7 +516,7 @@ def sample(self, batch_size: int = None) -> MultiTransitionBatch: Returns: batch (MultiTransitionBatch): The sampled batch. 
""" - indexes = self._get_sample_indexes(batch_size) + indexes = self.get_sample_indexes(batch_size) return self.sample_by_indexes(indexes) def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch: diff --git a/maro/rl/training/trainer.py b/maro/rl/training/trainer.py index 774954f6c..53bd123d8 100644 --- a/maro/rl/training/trainer.py +++ b/maro/rl/training/trainer.py @@ -271,9 +271,8 @@ def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: transition_batch = self._preprocess_batch(transition_batch) self.replay_memory.put(transition_batch) - @abstractmethod def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: - raise NotImplementedError + return transition_batch def _assert_ops_exists(self) -> None: if not self.ops: diff --git a/tests/rl/gym_wrapper/common.py b/tests/rl/gym_wrapper/common.py index 538a5f996..4a4bd12b2 100644 --- a/tests/rl/gym_wrapper/common.py +++ b/tests/rl/gym_wrapper/common.py @@ -3,12 +3,14 @@ from typing import cast +from gym import spaces + from maro.simulator import Env from tests.rl.gym_wrapper.simulator.business_engine import GymBusinessEngine env_conf = { - "topology": "Walker2d-v4", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4 + "topology": "CartPole-v1", # HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Ant-v4, CartPole-v1 "start_tick": 0, "durations": 100000, # Set a very large number "options": {}, @@ -19,8 +21,18 @@ num_agents = len(learn_env.agent_idx_list) gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env -gym_action_space = gym_env.action_space gym_state_dim = gym_env.observation_space.shape[0] -gym_action_dim = gym_action_space.shape[0] -action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high -action_limit = gym_action_space.high[0] +gym_action_space = gym_env.action_space +is_discrete = isinstance(gym_action_space, spaces.Discrete) +if is_discrete: + gym_action_space = cast(spaces.Discrete, gym_action_space) + gym_action_dim = 1 + gym_action_num = gym_action_space.n + action_lower_bound, action_upper_bound = None, None # Should never be used + action_limit = None # Should never be used +else: + gym_action_space = cast(spaces.Box, gym_action_space) + gym_action_dim = gym_action_space.shape[0] + gym_action_num = -1 # Should never be used + action_lower_bound, action_upper_bound = gym_action_space.low, gym_action_space.high + action_limit = action_upper_bound[0] diff --git a/tests/rl/gym_wrapper/env_sampler.py b/tests/rl/gym_wrapper/env_sampler.py index f95aaa546..e740bafdb 100644 --- a/tests/rl/gym_wrapper/env_sampler.py +++ b/tests/rl/gym_wrapper/env_sampler.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Any, Dict, List, Tuple, Type, Union +from typing import Any, Dict, List, Tuple, Type, Union, cast import numpy as np +from gym import spaces from maro.rl.policy.abs_policy import AbsPolicy from maro.rl.rollout import AbsEnvSampler, CacheElement @@ -40,6 +41,10 @@ def __init__( self._sample_rewards = [] self._eval_rewards = [] + gym_env = cast(GymBusinessEngine, learn_env.business_engine).gym_env + gym_action_space = gym_env.action_space + self._is_discrete = isinstance(gym_action_space, spaces.Discrete) + def _get_global_and_agent_state_impl( self, event: DecisionEvent, @@ -48,7 +53,7 @@ def _get_global_and_agent_state_impl( return None, {0: event.state} def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict: - return {k: Action(v) for k, v in action_dict.items()} + return {k: Action(v.item() if self._is_discrete else v) for k, v in action_dict.items()} def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]: be = self._env.business_engine diff --git a/tests/rl/performance.md b/tests/rl/performance.md index 75442035c..c49687750 100644 --- a/tests/rl/performance.md +++ b/tests/rl/performance.md @@ -1,11 +1,15 @@ # Performance for Gym Task Suite We benchmarked the MARO RL Toolkit implementation in Gym task suite. Some are compared to the benchmarks in -[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). We've tried to align the +[OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#) and [RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md). We've tried to align the hyper-parameters for these benchmarks , but limited by the environment version difference, there may be some gaps between the performance here and that in Spinning Up benchmarks. Generally speaking, the performance is comparable. -## Experimental Setting +## Compare with OpenAI Spinning Up + +We compare the performance of PPO, SAC, and DDPG in MARO with [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#). + +### Experimental Setting The hyper-parameters are set to align with those used in [Spinning Up](https://spinningup.openai.com/en/latest/spinningup/bench.html#experiment-details): @@ -29,7 +33,7 @@ The hyper-parameters are set to align with those used in More details about the parameters can be found in *tests/rl/tasks/*. -## Performance +### Performance Five environments from the MuJoCo Gym task suite are reported in Spinning Up, they are: HalfCheetah, Hopper, Walker2d, Swimmer, and Ant. The commit id of the code used to conduct the experiments for MARO RL benchmarks is ee25ce1e97. 
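Relatedly, the `_translate_to_env_action` change earlier in this patch unwraps single-element arrays into Python scalars for discrete gym environments while passing continuous actions through untouched. A tiny sketch of that translation; a plain function stands in for the env sampler method, and the gym wrapper's own `Action` class is omitted:

```python
import numpy as np


def translate(action: np.ndarray, is_discrete: bool):
    """Discrete envs expect a scalar index; Box envs take the array as-is."""
    return action.item() if is_discrete else action


print(translate(np.array([1]), is_discrete=True))                              # -> 1 (int) for CartPole-v1
print(translate(np.array([0.3, -0.7], dtype=np.float32), is_discrete=False))   # unchanged for MuJoCo tasks
```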
@@ -52,3 +56,28 @@ python tests/rl/plot.py --smooth WINDOWSIZE | [**Walker2d**](https://gymnasium.farama.org/environments/mujoco/walker2d/) | ![Wab](https://spinningup.openai.com/en/latest/_images/pytorch_walker2d_performance.svg) | ![Wa1](./log/Walker2d_1.png) | ![Wa11](./log/Walker2d_11.png) | | [**Swimmer**](https://gymnasium.farama.org/environments/mujoco/swimmer/) | ![Swb](https://spinningup.openai.com/en/latest/_images/pytorch_swimmer_performance.svg) | ![Sw1](./log/Swimmer_1.png) | ![Sw11](./log/Swimmer_11.png) | | [**Ant**](https://gymnasium.farama.org/environments/mujoco/ant/) | ![Anb](https://spinningup.openai.com/en/latest/_images/pytorch_ant_performance.svg) | ![An1](./log/Ant_1.png) | ![An11](./log/Ant_11.png) | + +## Compare with RL Baseline Zoo + +[RL Baseline Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/benchmark.md) provides a comprehensive set of benchmarks for multiple algorithms and environments. +However, unlike OpenAI Spinning Up, it does not provide the complete learning curve. Instead, we can only find the final metrics in it. +We therefore leave the comparison with RL Baseline Zoo as a minor addition. + +We compare the performance of DQN with RL Baseline Zoo. + +### Experimental Setting + +- Batch size: size 64 for each gradient descent step; +- Network: size (256) with relu units; +- Performance metric: measured as the average trajectory return across the batch collected at 10 epochs; +- Total timesteps: 150,000. + +### Performance + +More details about the parameters can be found in *tests/rl/tasks/*. +Please refer to the original link of RL Baseline Zoo for the baseline metrics. + +| algo | env_id |mean_reward| +|--------|-------------------------------|----------:| +|DQN |CartPole-v1 | 500.00 | +|DQN |MountainCar-v0 | -116.90 | diff --git a/tests/rl/tasks/ac/__init__.py b/tests/rl/tasks/ac/__init__.py index 24cc961fc..d9a73ecf3 100644 --- a/tests/rl/tasks/ac/__init__.py +++ b/tests/rl/tasks/ac/__init__.py @@ -19,6 +19,7 @@ action_upper_bound, gym_action_dim, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -109,6 +110,8 @@ def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer: ) +assert not is_discrete + algorithm = "ac" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/ddpg/__init__.py b/tests/rl/tasks/ddpg/__init__.py index 861904a43..cd097ddbc 100644 --- a/tests/rl/tasks/ddpg/__init__.py +++ b/tests/rl/tasks/ddpg/__init__.py @@ -20,6 +20,7 @@ gym_action_dim, gym_action_space, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -123,6 +124,8 @@ def get_ddpg_trainer(name: str, state_dim: int, action_dim: int) -> DDPGTrainer: ) +assert not is_discrete + algorithm = "ddpg" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/dqn/__init__.py b/tests/rl/tasks/dqn/__init__.py new file mode 100644 index 000000000..b4b0befa7 --- /dev/null +++ b/tests/rl/tasks/dqn/__init__.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import torch +from torch.optim import Adam + +from maro.rl.exploration import LinearExploration +from maro.rl.model import DiscreteQNet, FullyConnected +from maro.rl.policy import ValueBasedPolicy +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.training.algorithms import DQNParams, DQNTrainer + +from tests.rl.gym_wrapper.common import gym_action_num, gym_state_dim, is_discrete, learn_env, num_agents, test_env +from tests.rl.gym_wrapper.env_sampler import GymEnvSampler + +net_conf = { + "hidden_dims": [256], + "activation": torch.nn.ReLU, + "output_activation": None, +} +lr = 1e-3 + + +class MyQNet(DiscreteQNet): + def __init__(self, state_dim: int, action_num: int) -> None: + super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num) + + self._mlp = FullyConnected( + input_dim=state_dim, + output_dim=action_num, + **net_conf, + ) + self._optim = Adam(self._mlp.parameters(), lr=lr) + + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + return self._mlp(states) + + +def get_dqn_policy( + name: str, + state_dim: int, + action_num: int, +) -> ValueBasedPolicy: + return ValueBasedPolicy( + name=name, + q_net=MyQNet(state_dim=state_dim, action_num=action_num), + explore_strategy=LinearExploration( + num_actions=action_num, + explore_steps=10000, + start_explore_prob=1.0, + end_explore_prob=0.02, + ), + warmup=0, # TODO: check this + ) + + +def get_dqn_trainer( + name: str, +) -> DQNTrainer: + return DQNTrainer( + name=name, + params=DQNParams( + use_prioritized_replay=False, # + # alpha=0.4, + # beta=0.6, + num_epochs=50, + update_target_every=10, + soft_update_coef=1.0, + ), + replay_memory_capacity=50000, + batch_size=64, + reward_discount=1.0, + ) + + +assert is_discrete + +algorithm = "dqn" +agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} +policies = [ + get_dqn_policy( + f"{algorithm}_{i}.policy", + state_dim=gym_state_dim, + action_num=gym_action_num, + ) + for i in range(num_agents) +] +trainers = [get_dqn_trainer(f"{algorithm}_{i}") for i in range(num_agents)] + +device_mapping = {f"{algorithm}_{i}.policy": "cuda:0" for i in range(num_agents)} if torch.cuda.is_available() else None + +rl_component_bundle = RLComponentBundle( + env_sampler=GymEnvSampler( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + ), + agent2policy=agent2policy, + policies=policies, + trainers=trainers, + device_mapping=device_mapping, +) + +__all__ = ["rl_component_bundle"] diff --git a/tests/rl/tasks/dqn/config.yml b/tests/rl/tasks/dqn/config.yml new file mode 100644 index 000000000..aa3971127 --- /dev/null +++ b/tests/rl/tasks/dqn/config.yml @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for GYM scenario. +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. 
+ +job: gym_rl_workflow +scenario_path: "tests/rl/tasks/dqn" +log_path: "tests/rl/log/dqn_cartpole" +main: + num_episodes: 3000 + num_steps: 50 + eval_schedule: 50 + num_eval_episodes: 10 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG +training: + mode: simple + load_path: null + load_episode: null + checkpointing: + path: null + interval: 5 + logging: + stdout: INFO + file: DEBUG diff --git a/tests/rl/tasks/ppo/__init__.py b/tests/rl/tasks/ppo/__init__.py index 15fc71069..722fce328 100644 --- a/tests/rl/tasks/ppo/__init__.py +++ b/tests/rl/tasks/ppo/__init__.py @@ -11,6 +11,7 @@ action_upper_bound, gym_action_dim, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -36,6 +37,8 @@ def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer: ) +assert not is_discrete + algorithm = "ppo" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ diff --git a/tests/rl/tasks/sac/__init__.py b/tests/rl/tasks/sac/__init__.py index 1e033f12b..421ea4e96 100644 --- a/tests/rl/tasks/sac/__init__.py +++ b/tests/rl/tasks/sac/__init__.py @@ -24,6 +24,7 @@ gym_action_dim, gym_action_space, gym_state_dim, + is_discrete, learn_env, num_agents, test_env, @@ -133,6 +134,8 @@ def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCrit ) +assert not is_discrete + algorithm = "sac" agent2policy = {agent: f"{algorithm}_{agent}.policy" for agent in learn_env.agent_idx_list} policies = [ From deaa23d851fe3601852fe9eff10594855066012c Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 23 May 2023 16:17:19 +0800 Subject: [PATCH 12/12] pre commit --- maro/rl/training/algorithms/dqn.py | 14 -------------- maro/rl/workflows/main.py | 2 +- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py index 0003f6fe6..aa99a64f2 100644 --- a/maro/rl/training/algorithms/dqn.py +++ b/maro/rl/training/algorithms/dqn.py @@ -8,28 +8,14 @@ import torch from maro.rl.policy import RLPolicy, ValueBasedPolicy - -from maro.rl.training import ( - AbsTrainOps, - BaseTrainerParams, - PrioritizedReplayMemory, - RandomReplayMemory, - SingleAgentTrainer, - remote, -) - -from maro.rl.training import AbsTrainOps, BaseTrainerParams, RandomReplayMemory, RemoteOps, SingleAgentTrainer, remote - from maro.rl.training import ( AbsTrainOps, BaseTrainerParams, PrioritizedReplayMemory, RandomReplayMemory, - RemoteOps, SingleAgentTrainer, remote, ) - from maro.rl.utils import TransitionBatch, get_torch_device, ndarray_to_tensor from maro.utils import clone diff --git a/maro/rl/workflows/main.py b/maro/rl/workflows/main.py index 8da3da4d0..9566240c3 100644 --- a/maro/rl/workflows/main.py +++ b/maro/rl/workflows/main.py @@ -237,7 +237,7 @@ def evaluate_only_workflow(rl_component_bundle: RLComponentBundle, env_attr: Wor loaded = env_sampler.load_policy_state(path) env_attr.logger.info(f"Loaded policies {loaded} into env sampler from {path}") - result = env_sampler.eval(num_episodes=env_attr.num_eval_episodes) + env_sampler.eval(num_episodes=env_attr.num_eval_episodes) env_sampler.post_evaluate(-1) if isinstance(env_sampler, BatchEnvSampler):
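The CartPole test above sets `soft_update_coef=1.0` together with `update_target_every=10`, which turns the Polyak soft update into a periodic hard copy of the online network; the `DQNParams` default of 0.1 blends the two networks instead. A minimal sketch of the update rule itself, illustrative rather than MARO's internal implementation:

```python
import torch


def soft_update(target: torch.nn.Module, online: torch.nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * o_param)


online_net, target_net = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
soft_update(target_net, online_net, tau=1.0)  # hard copy, as configured in the CartPole test
soft_update(target_net, online_net, tau=0.1)  # gentle tracking, as in DQNParams' default
```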