Add n-step return support with n_steps parameter #74
Merged
Pull Request Overview
This PR adds n-step return support to the stable-baselines3-contrib reinforcement learning library by introducing an n_steps parameter across various algorithms. The main goal is to enable multi-step temporal difference learning, which can improve sample efficiency and learning stability.
- Adds `n_steps` parameter to all off-policy algorithms (SAC, TD3, TQC, CrossQ, DQN, DDPG)
- Updates training logic to handle discount factors from n-step returns instead of fixed gamma values
- Refactors PPO logging to include separate policy and entropy loss tracking
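
To illustrate what "discount factors from n-step returns" means, here is a minimal sketch in plain Python. It is not the code from this PR: the function name `n_step_return` and its signature are hypothetical, and the real buffer most likely works on batched arrays. The idea is that the rewards of up to `n_steps` consecutive transitions are summed with increasing powers of gamma, and the factor applied to the bootstrapped value becomes gamma raised to the number of steps actually used, rather than a fixed gamma.

```python
from typing import List, Tuple


def n_step_return(
    rewards: List[float],
    dones: List[bool],
    gamma: float,
    n_steps: int,
) -> Tuple[float, float, bool]:
    """Hypothetical helper, for illustration only.

    Accumulates up to n_steps discounted rewards and returns
    (n_step_reward, discount, done), where discount is gamma**k
    for the k steps actually used.
    """
    total = 0.0
    discount = 1.0
    for reward, done in zip(rewards[:n_steps], dones[:n_steps]):
        total += discount * reward
        discount *= gamma
        if done:
            # The episode ended inside the window: stop accumulating.
            return total, discount, True
    return total, discount, False


# Three steps, no termination: reward 1 + 0.5 + 0.25, discount 0.5**3.
full_window = n_step_return([1.0, 1.0, 1.0], [False, False, False], gamma=0.5, n_steps=3)

# Termination at the second step truncates the window.
truncated = n_step_return([1.0, 1.0, 1.0], [False, True, False], gamma=0.5, n_steps=3)
```

With `n_steps=1` this reduces to the usual one-step case, where the discount is simply gamma.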
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| sbx/common/off_policy_algorithm.py | Adds n_steps parameter and NStepReplayBuffer integration |
| sbx/common/type_aliases.py | Extends ReplayBufferSamplesNp to include discounts field |
| sbx/sac/sac.py | Updates SAC algorithm to support n-step returns with discount handling |
| sbx/td3/td3.py | Updates TD3 algorithm to support n-step returns with discount handling |
| sbx/tqc/tqc.py | Updates TQC algorithm to support n-step returns with discount handling |
| sbx/crossq/crossq.py | Updates CrossQ algorithm to support n-step returns with discount handling |
| sbx/dqn/dqn.py | Updates DQN algorithm to support n-step returns with discount handling |
| sbx/ddpg/ddpg.py | Updates DDPG algorithm to support n-step returns |
| sbx/ppo/ppo.py | Refactors logging to separate policy and entropy losses |
| sbx/ppo/policies.py | Updates import statements and type annotations |
| setup.py | Updates dependency versions for stable_baselines3, jax, and black |
| sbx/version.txt | Bumps version from 0.21.0 to 0.22.0 |
| tests/test_run.py | Adds n_steps parameter to test configuration |
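
As a rough illustration of why the sample type gains a discounts field, the sketch below shows a TD target computed with a per-sample discount instead of a single gamma. The class `SamplesWithDiscounts` and the function `td_targets` are hypothetical stand-ins, not the definitions in `sbx/common/type_aliases.py` or the algorithm files, and plain lists are used here in place of arrays.

```python
from typing import List, NamedTuple


class SamplesWithDiscounts(NamedTuple):
    """Hypothetical stand-in for a replay sample batch with a discounts field."""

    rewards: List[float]    # n-step accumulated rewards
    dones: List[float]      # 1.0 if the episode terminated within the window
    discounts: List[float]  # per-sample factor, e.g. gamma**k


def td_targets(batch: SamplesWithDiscounts, next_q_values: List[float]) -> List[float]:
    """Compute reward + (1 - done) * discount * next_q for each sample."""
    return [
        reward + (1.0 - done) * discount * next_q
        for reward, done, discount, next_q in zip(
            batch.rewards, batch.dones, batch.discounts, next_q_values
        )
    ]


batch = SamplesWithDiscounts(
    rewards=[1.75, 1.5],
    dones=[0.0, 1.0],
    discounts=[0.125, 0.25],
)
targets = td_targets(batch, next_q_values=[8.0, 8.0])
```

The second sample terminated, so its target is just the accumulated reward and the bootstrap term is dropped.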
Comments suppressed due to low confidence (3)
setup.py:45
- The JAX version constraint `<0.7.0` may be incorrect. JAX has not released version 0.7.0 as of January 2025. The latest stable JAX versions are in the 0.4.x series. Consider using a more realistic upper bound like `<0.5.0` or removing the upper bound entirely.
"jax>=0.4.24,<0.7.0", # tf probability not compatible yet with latest jax version
setup.py:65
- Black version 25.1.0 does not exist. As of January 2025, Black's latest versions are in the 24.x series. Consider using a realistic version constraint like `"black>=24.2.0,<25"`.
"black>=25.1.0,<26",
sbx/ppo/ppo.py:183
- The removal of this line creates an unused variable `ent_key` that was being generated but is now missing. This could cause a NameError if `ent_key` is used elsewhere in the code, or it might indicate that some entropy-related functionality has been inadvertently removed.
)
Description
See Stable-Baselines-Team/stable-baselines3-contrib#297
Motivation and Context
Types of changes
Checklist:
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line