72 changes: 52 additions & 20 deletions examples/cim/rl/algorithms/dqn.py
@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple

import torch
from torch.optim import RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.exploration import EpsilonGreedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -23,32 +24,62 @@


class MyQNet(DiscreteQNet):
def __init__(self, state_dim: int, action_num: int) -> None:
def __init__(
self,
state_dim: int,
action_num: int,
dueling_param: Optional[Tuple[dict, dict]] = None,
) -> None:
super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf)
self._optim = RMSprop(self._fc.parameters(), lr=learning_rate)

self._use_dueling = dueling_param is not None
self._fc = FullyConnected(input_dim=state_dim, output_dim=0 if self._use_dueling else action_num, **q_net_conf)
if self._use_dueling:
q_kwargs, v_kwargs = dueling_param
self._q = FullyConnected(input_dim=self._fc.output_dim, output_dim=action_num, **q_kwargs)
self._v = FullyConnected(input_dim=self._fc.output_dim, output_dim=1, **v_kwargs)

self._optim = RMSprop(self.parameters(), lr=learning_rate)

def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
return self._fc(states)
logits = self._fc(states)
if self._use_dueling:
q = self._q(logits)
v = self._v(logits)
logits = q - q.mean(dim=1, keepdim=True) + v
return logits


def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
q_kwargs = {
"hidden_dims": [128],
"activation": torch.nn.LeakyReLU,
"output_activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}
v_kwargs = {
"hidden_dims": [128],
"activation": torch.nn.LeakyReLU,
"output_activation": None,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0,
}

return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num),
exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}),
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
},
),
],
q_net=MyQNet(
state_dim,
action_num,
dueling_param=(q_kwargs, v_kwargs),
),
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
warmup=100,
)

@@ -64,6 +95,7 @@ def get_dqn(name: str) -> DQNTrainer:
num_epochs=10,
soft_update_coef=0.1,
double=False,
random_overwrite=False,
alpha=1.0,
beta=1.0,
),
)
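
Note on the dueling head added to MyQNet above: when dueling_param is given, the shared trunk feeds two branches, an advantage branch with one output per action and a value branch with a single output, combined as A(s, a) - mean_a A(s, a) + V(s). The standalone sketch below reproduces just that aggregation so the change can be checked in isolation; the class and layer names are illustrative and are not part of MARO's API.

    # Minimal sketch of the dueling aggregation used in MyQNet (illustrative only;
    # layer names and sizes are assumptions, not MARO's FullyConnected API).
    import torch
    import torch.nn as nn

    class DuelingHead(nn.Module):
        def __init__(self, feature_dim: int, action_num: int) -> None:
            super().__init__()
            self.adv = nn.Linear(feature_dim, action_num)  # advantage branch A(s, a)
            self.val = nn.Linear(feature_dim, 1)           # state-value branch V(s)

        def forward(self, features: torch.Tensor) -> torch.Tensor:
            a = self.adv(features)
            v = self.val(features)
            # Subtract the mean advantage so A and V are identifiable, then add V(s).
            return a - a.mean(dim=1, keepdim=True) + v

    features = torch.randn(4, 128)             # batch of 4 trunk outputs
    q_values = DuelingHead(128, 5)(features)   # shape (4, 5): one Q-value per action
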
2 changes: 1 addition & 1 deletion examples/cim/rl/config.py
@@ -35,4 +35,4 @@

action_num = len(action_shaping_conf["action_space"])

algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg
algorithm = "dqn" # ac, ppo, dqn or discrete_maddpg
61 changes: 52 additions & 9 deletions examples/cim/rl/env_sampler.py
@@ -85,30 +85,73 @@ def _post_step(self, cache_element: CacheElement) -> None:
def _post_eval_step(self, cache_element: CacheElement) -> None:
self._post_step(cache_element)

def post_collect(self, info_list: list, ep: int) -> None:
def post_collect(self, ep: int) -> None:
if len(self._info_list) == 0:
return

# print the env metric from each rollout worker
for info in info_list:
for info in self._info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")

# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in self._info_list) / num_envs for key in metric_keys}
avg_metric["n_episode"] = len(self._info_list)
print(f"average env summary (episode {ep}): {avg_metric}")

self.metrics.update(avg_metric)
self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
self._info_list.clear()

def post_evaluate(self, ep: int) -> None:
if len(self._info_list) == 0:
return

def post_evaluate(self, info_list: list, ep: int) -> None:
# print the env metric from each rollout worker
for info in info_list:
for info in self._info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")

# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in self._info_list) / num_envs for key in metric_keys}
avg_metric["n_episode"] = len(self._info_list)
print(f"average env summary (episode {ep}): {avg_metric}")

self.metrics.update({"val/" + k: v for k, v in avg_metric.items()})
self._info_list.clear()

def monitor_metrics(self) -> float:
return -self.metrics["val/container_shortage"]
return -self.metrics["val/shortage_percentage"]

@staticmethod
def merge_metrics(metrics_list: List[dict]) -> dict:
n_episode = sum(m["n_episode"] for m in metrics_list)
metrics: dict = {
"order_requirements": sum(m["order_requirements"] * m["n_episode"] for m in metrics_list) / n_episode,
"container_shortage": sum(m["container_shortage"] * m["n_episode"] for m in metrics_list) / n_episode,
"operation_number": sum(m["operation_number"] * m["n_episode"] for m in metrics_list) / n_episode,
"n_episode": n_episode,
}
metrics["shortage_percentage"] = metrics["container_shortage"] / metrics["order_requirements"]

metrics_list = [m for m in metrics_list if "val/shortage_percentage" in m]
if len(metrics_list) > 0:
n_episode = sum(m["val/n_episode"] for m in metrics_list)
metrics.update(
{
"val/order_requirements": sum(
m["val/order_requirements"] * m["val/n_episode"] for m in metrics_list
)
/ n_episode,
"val/container_shortage": sum(
m["val/container_shortage"] * m["val/n_episode"] for m in metrics_list
)
/ n_episode,
"val/operation_number": sum(m["val/operation_number"] * m["val/n_episode"] for m in metrics_list)
/ n_episode,
"val/n_episode": n_episode,
},
)
metrics["val/shortage_percentage"] = metrics["val/container_shortage"] / metrics["val/order_requirements"]

return metrics
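
Note on the new merge_metrics above: it combines the metrics reported by several parallel samplers by weighting each episode-averaged value with that sampler's episode count, then recomputes shortage_percentage from the merged values; the same episode-weighted averaging is applied to the "val/"-prefixed evaluation metrics when they are present. A minimal sketch of that weighting, with made-up keys and numbers for illustration:

    # Minimal sketch of the episode-weighted averaging performed by merge_metrics
    # (keys and values are illustrative only).
    def weighted_merge(metrics_list: list, keys: tuple) -> dict:
        n_episode = sum(m["n_episode"] for m in metrics_list)
        merged = {
            key: sum(m[key] * m["n_episode"] for m in metrics_list) / n_episode
            for key in keys
        }
        merged["n_episode"] = n_episode
        return merged

    merged = weighted_merge(
        [
            {"container_shortage": 10.0, "order_requirements": 100.0, "n_episode": 2},
            {"container_shortage": 40.0, "order_requirements": 100.0, "n_episode": 1},
        ],
        keys=("container_shortage", "order_requirements"),
    )
    # merged["container_shortage"] == 20.0, i.e. (10 * 2 + 40 * 1) / 3
    merged["shortage_percentage"] = merged["container_shortage"] / merged["order_requirements"]
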
16 changes: 2 additions & 14 deletions examples/vm_scheduling/rl/algorithms/dqn.py
@@ -6,7 +6,7 @@
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from maro.rl.exploration import MultiLinearExplorationScheduler
from maro.rl.exploration import EpsilonGreedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQNParams, DQNTrainer
@@ -58,19 +58,7 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
return ValueBasedPolicy(
name=name,
q_net=MyQNet(state_dim, action_num, num_features),
exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}),
exploration_scheduling_options=[
(
"epsilon",
MultiLinearExplorationScheduler,
{
"splits": [(100, 0.32)],
"initial_value": 0.4,
"last_ep": 400,
"final_value": 0.0,
},
),
],
explore_strategy=EpsilonGreedy(epsilon=0.4, num_actions=action_num),
warmup=100,
)
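
Note on the exploration change in both DQN examples: the (epsilon_greedy, {...}) strategy plus MultiLinearExplorationScheduler options are replaced by a single EpsilonGreedy(epsilon=0.4, num_actions=action_num) object passed as explore_strategy, so epsilon appears to stay fixed at 0.4 in these examples (the newly exported LinearExploration presumably covers annealed schedules). The sketch below shows what epsilon-greedy selection over a batch of Q-values looks like; it is illustrative only and is not the internals of maro.rl.exploration.EpsilonGreedy.

    # Minimal sketch of epsilon-greedy action selection over a batch of Q-values
    # (illustrative only; not the internals of maro.rl.exploration.EpsilonGreedy).
    import numpy as np

    def epsilon_greedy_actions(q_values: np.ndarray, epsilon: float, num_actions: int) -> np.ndarray:
        greedy = q_values.argmax(axis=1)                              # exploit: best action per state
        random = np.random.randint(num_actions, size=len(q_values))  # explore: uniform random action
        explore_mask = np.random.rand(len(q_values)) < epsilon
        return np.where(explore_mask, random, greedy)

    q = np.array([[0.1, 0.9, 0.3], [0.5, 0.2, 0.4]])
    actions = epsilon_greedy_actions(q, epsilon=0.4, num_actions=3)  # shape (2,)
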

34 changes: 21 additions & 13 deletions examples/vm_scheduling/rl/env_sampler.py
@@ -166,29 +166,35 @@ def _post_step(self, cache_element: CacheElement) -> None:
def _post_eval_step(self, cache_element: CacheElement) -> None:
self._post_step(cache_element)

def post_collect(self, info_list: list, ep: int) -> None:
def post_collect(self, ep: int) -> None:
if len(self._info_list) == 0:
return

# print the env metric from each rollout worker
for info in info_list:
for info in self._info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")

# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
print(f"average env metric (episode {ep}): {avg_metric}")
metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in self._info_list) / num_envs for key in metric_keys}
print(f"average env metric (episode {ep}): {avg_metric}")

self._info_list.clear()

def post_evaluate(self, ep: int) -> None:
if len(self._info_list) == 0:
return

def post_evaluate(self, info_list: list, ep: int) -> None:
# print the env metric from each rollout worker
for info in info_list:
for info in self._info_list:
print(f"env summary (evaluation episode {ep}): {info['env_metric']}")

# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys}
print(f"average env metric (evaluation episode {ep}): {avg_metric}")
metric_keys, num_envs = self._info_list[0]["env_metric"].keys(), len(self._info_list)
avg_metric = {key: sum(tr["env_metric"][key] for tr in self._info_list) / num_envs for key in metric_keys}
print(f"average env metric (evaluation episode {ep}): {avg_metric}")

for info in info_list:
for info in self._info_list:
core_requirement = info["actions_by_core_requirement"]
action_sequence = info["action_sequence"]
# plot action sequence
@@ -231,3 +237,5 @@ def post_evaluate(self, info_list: list, ep: int) -> None:

plt.cla()
plt.close("all")

self._info_list.clear()
2 changes: 1 addition & 1 deletion maro/rl/distributed/abs_worker.py
@@ -34,7 +34,7 @@ def __init__(
super(AbsWorker, self).__init__()

self._id = f"worker.{idx}"
self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger()
self._logger: Union[LoggerV2, DummyLogger] = logger or DummyLogger()

# ZMQ sockets and streams
self._context = Context.instance()
12 changes: 4 additions & 8 deletions maro/rl/exploration/__init__.py
@@ -1,14 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler
from .strategies import epsilon_greedy, gaussian_noise, uniform_noise
from .strategies import EpsilonGreedy, ExploreStrategy, LinearExploration

__all__ = [
"AbsExplorationScheduler",
"LinearExplorationScheduler",
"MultiLinearExplorationScheduler",
"epsilon_greedy",
"gaussian_noise",
"uniform_noise",
"ExploreStrategy",
"EpsilonGreedy",
"LinearExploration",
]