From 6c41b96cb0aaf16cbcbecdc5ba33cd70ac533ad8 Mon Sep 17 00:00:00 2001
From: Abhik Singla
Date: Wed, 2 Mar 2022 10:39:30 -0800
Subject: [PATCH 1/3] Working simple custom buffer replacement

---
 rllib/agents/cql/cql_sac.py |  2 ++
 rllib/agents/dqn/dqn.py     | 24 ++++++++++++++----------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/rllib/agents/cql/cql_sac.py b/rllib/agents/cql/cql_sac.py
index ed1d7021e57c..4aba097220d9 100644
--- a/rllib/agents/cql/cql_sac.py
+++ b/rllib/agents/cql/cql_sac.py
@@ -43,6 +43,8 @@
     "alpha_upper_bound": 1.0,
     # Lower bound for alpha value during the lagrangian constraint
     "alpha_lower_bound": 0.0,
+    # custom replay buffer
+    "replay_buffer": None,
 })
 # __sphinx_doc_end__
 # yapf: enable

diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 696633d96459..9768329152ba 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -216,16 +216,20 @@ def execution_plan(workers: WorkerSet,
     else:
         prio_args = {}

-    local_replay_buffer = LocalReplayBuffer(
-        num_shards=1,
-        learning_starts=config["learning_starts"],
-        buffer_size=config["buffer_size"],
-        replay_batch_size=config["train_batch_size"],
-        replay_mode=config["multiagent"]["replay_mode"],
-        replay_sequence_length=config.get("replay_sequence_length", 1),
-        replay_burn_in=config.get("burn_in", 0),
-        replay_zero_init_states=config.get("zero_init_states", True),
-        **prio_args)
+    if config.get("replay_buffer"):
+        local_replay_buffer = config.get("replay_buffer")
+        local_replay_buffer = local_replay_buffer()
+    else:
+        local_replay_buffer = LocalReplayBuffer(
+            num_shards=1,
+            learning_starts=config["learning_starts"],
+            buffer_size=config["buffer_size"],
+            replay_batch_size=config["train_batch_size"],
+            replay_mode=config["multiagent"]["replay_mode"],
+            replay_sequence_length=config.get("replay_sequence_length", 1),
+            replay_burn_in=config.get("burn_in", 0),
+            replay_zero_init_states=config.get("zero_init_states", True),
+            **prio_args)

     input_reader = workers.local_worker().input_reader
     is_offline_training = isinstance(input_reader, InMemoryInputReader)
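
As of patch 1/3, config["replay_buffer"] holds a zero-argument callable (a
class or factory) whose return value stands in for LocalReplayBuffer. A
minimal sketch of such a plug-in buffer is below. SimpleFIFOBuffer is a
hypothetical name; the only interface assumed is the add_batch() call made
by StoreToReplayBuffer (visible in the diff) plus a replay() method like the
one RLlib's LocalReplayBuffer exposes for the downstream replay op.

import random

from ray.rllib.policy.sample_batch import SampleBatch


class SimpleFIFOBuffer:
    """Hypothetical drop-in buffer: plain FIFO storage, uniform sampling."""

    def __init__(self, capacity=50000):
        self.capacity = capacity
        self._storage = []

    def add_batch(self, batch: SampleBatch) -> None:
        # StoreToReplayBuffer calls this for every rollout batch;
        # evict the oldest batch once capacity is exceeded.
        self._storage.append(batch)
        if len(self._storage) > self.capacity:
            self._storage.pop(0)

    def replay(self):
        # The replay op polls this; return None until data is available.
        if not self._storage:
            return None
        return random.choice(self._storage)


# Patch 1/3 instantiates the callable with no arguments:
#     local_replay_buffer = config.get("replay_buffer")
#     local_replay_buffer = local_replay_buffer()
# so registering the class itself is enough:
config_overrides = {"replay_buffer": SimpleFIFOBuffer}
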
From d2fbaed2c8b626dc86093229c173e9b72ab48e70 Mon Sep 17 00:00:00 2001
From: Abhik Singla
Date: Wed, 2 Mar 2022 13:09:59 -0800
Subject: [PATCH 2/3] support replay buffer init with offline data

---
 rllib/agents/dqn/dqn.py | 43 ++++++++++++++++++-----------------------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 9768329152ba..d839d69702d9 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -218,7 +218,10 @@ def execution_plan(workers: WorkerSet,

     if config.get("replay_buffer"):
         local_replay_buffer = config.get("replay_buffer")
-        local_replay_buffer = local_replay_buffer()
+        input_reader = workers.local_worker().input_reader
+        assert isinstance(input_reader, InMemoryInputReader)
+        local_replay_buffer = local_replay_buffer(input_reader)
+        is_offline_training = True
     else:
         local_replay_buffer = LocalReplayBuffer(
             num_shards=1,
@@ -231,29 +234,21 @@
             replay_zero_init_states=config.get("zero_init_states", True),
             **prio_args)

-    input_reader = workers.local_worker().input_reader
-    is_offline_training = isinstance(input_reader, InMemoryInputReader)
-    if is_offline_training:
-        # if we have an InMemoryInputReader, then we are in Offline Training
-        # which means that we don't need the sampling pipeline setup
-        for batch in input_reader.get_all():
-            local_replay_buffer.add_batch(batch)
-    else:
-        parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
-        num_async = config.get("parallel_rollouts_num_async")
-        # This could be set to None explicitly
-        if not num_async:
-            num_async = 1
-        rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
-
-        # We execute the following steps concurrently:
-        # (1) Generate rollouts and store them in our local replay buffer. Calling
-        # next() on store_op drives this.
-        store_op = rollouts.for_each(
-            StoreToReplayBuffer(local_buffer=local_replay_buffer))
-        if config.get("execution_plan_custom_store_ops"):
-            custom_store_ops = config["execution_plan_custom_store_ops"]
-            store_op = store_op.for_each(custom_store_ops(workers, config))
+    parallel_rollouts_mode = config.get("parallel_rollouts_mode", "bulk_sync")
+    num_async = config.get("parallel_rollouts_num_async")
+    # This could be set to None explicitly
+    if not num_async:
+        num_async = 1
+    rollouts = ParallelRollouts(workers, mode=parallel_rollouts_mode, num_async=num_async)
+
+    # We execute the following steps concurrently:
+    # (1) Generate rollouts and store them in our local replay buffer. Calling
+    # next() on store_op drives this.
+    store_op = rollouts.for_each(
+        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+    if config.get("execution_plan_custom_store_ops"):
+        custom_store_ops = config["execution_plan_custom_store_ops"]
+        store_op = store_op.for_each(custom_store_ops(workers, config))

     def update_prio(item):
         samples, info_dict = item

From db53f8eefb1ecee99a0a5320217d9f25ccb6164f Mon Sep 17 00:00:00 2001
From: Abhik Singla
Date: Wed, 2 Mar 2022 16:56:55 -0800
Subject: [PATCH 3/3] fixed buffer config bug

---
 rllib/agents/dqn/dqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index d839d69702d9..acd3866a03c9 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -220,7 +220,7 @@ def execution_plan(workers: WorkerSet,
         local_replay_buffer = config.get("replay_buffer")
         input_reader = workers.local_worker().input_reader
         assert isinstance(input_reader, InMemoryInputReader)
-        local_replay_buffer = local_replay_buffer(input_reader)
+        local_replay_buffer = local_replay_buffer(config, prio_args, input_reader)
         is_offline_training = True
     else:
         local_replay_buffer = LocalReplayBuffer(
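
With the final signature from patch 3/3, the configured callable is invoked
as config["replay_buffer"](config, prio_args, input_reader), so a custom
factory can rebuild the stock buffer and seed it from offline data. A sketch
under those assumptions follows: offline_seeded_buffer is a hypothetical
name, the LocalReplayBuffer kwargs are copied from the diff's default
branch, the import path is assumed to match RLlib releases of that era, and
the input_reader.get_all() loop mirrors the code patch 2/3 removed from the
execution plan.

from ray.rllib.execution.replay_buffer import LocalReplayBuffer


def offline_seeded_buffer(config, prio_args, input_reader):
    # Rebuild the default buffer with the same kwargs as the else-branch.
    buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config.get("replay_sequence_length", 1),
        replay_burn_in=config.get("burn_in", 0),
        replay_zero_init_states=config.get("zero_init_states", True),
        **prio_args)
    # The caller asserts input_reader is an InMemoryInputReader, so
    # get_all() yields every stored batch (the loop patch 2/3 deleted).
    for batch in input_reader.get_all():
        buffer.add_batch(batch)
    return buffer


# The execution plan then resolves:
#     local_replay_buffer = config.get("replay_buffer")
#     local_replay_buffer = local_replay_buffer(config, prio_args, input_reader)
config_overrides = {"replay_buffer": offline_seeded_buffer}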