Conversation
```python
def initialize_carry(batch_size, hidden_size):
    # Returns a tuple of LSTM states (hidden and cell states)
    return nn.LSTMCell(features=hidden_size).initialize_carry(
        rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size)
    )
```
> always the same rng, is that intended?

I think it is, so the reset states are always the same (I borrowed this from purejaxrl).

> if so, they can be precomputed, no?

In fact, the function is called with three different shapes during a training run: at the setup of RecurrentPPO, during rollout collection, and during the network updates. But these values can indeed be precomputed.

I'll ask a friend who knows LSTM PPO in JAX well, to be sure.
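Since the key is fixed, the reset carries could indeed be computed once per shape and reused. Here is a minimal sketch of that idea, assuming the Flax linen `LSTMCell` API as used in the diff above (the caching decorator is an illustrative choice, not the PR's code):

```python
import functools

import flax.linen as nn
import jax


@functools.lru_cache(maxsize=None)
def initialize_carry(batch_size: int, hidden_size: int):
    # The PRNGKey is constant, so the carry for a given (batch_size,
    # hidden_size) pair is deterministic and safe to cache: each of the
    # three shapes seen during training is computed only once.
    return nn.LSTMCell(features=hidden_size).initialize_carry(
        rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size)
    )
```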
```python
if normalize_advantage and len(advantages) > 1:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
```
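As an aside, a quick illustration of what the `len(advantages) > 1` guard is presumably protecting against (an assumption on my part, based on the same guard in SB3): normalizing a single advantage collapses it to zero, erasing the learning signal for that minibatch.

```python
import jax.numpy as jnp

a = jnp.array([2.0])
# The std of a single element is 0, so normalization maps the only
# advantage to 0 (the 1e-8 term merely avoids a division by zero).
print((a - a.mean()) / (a.std() + 1e-8))  # [0.]
```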
```python
# TODO: something weird here because the params argument isn't used and only actor_state.params instead
```
> does this result in an error if `params` is used?

No, the code still works. This comes from the sbx PPO; do you want me to do a quick PR to fix it?
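For readers outside the thread, here is a generic JAX sketch (not the sbx code) of why the explicit `params` argument is the idiomatic one to use: `jax.grad` differentiates with respect to the function's arguments, while values closed over from an outer scope are traced as constants.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Uses the explicit `params` argument, so jax.grad(loss_fn) produces
    # non-zero gradients; reading a closed-over copy of the parameters
    # here instead would make the gradient with respect to `params` zero.
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)


params = {"w": jnp.ones((3, 1))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
grads = jax.grad(loss_fn)(params, x, y)
```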
Hello @araffin, sorry I couldn't finish the PR. The code runs with LSTM-PPO, but it doesn't learn properly. Would you prefer I close the PR or leave it open?
Hello,
Implement a first running version of RecurrentPPO in sbx (but the algorithm doesn't learn yet). Still needs to be improved to make it functional.
Description
Implement a first running version of RecurrentPPO with an LSTM layer. The algorithm doesn't support Dict observations yet, and doesn't work with arbitrary n_steps, n_envs, and batch sizes (n_steps has to be a multiple of batch_size).
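For illustration, a hypothetical usage sketch mirroring sb3-contrib's RecurrentPPO API (the policy string and hyperparameters here are assumptions, not the PR's tested configuration):

```python
from sbx import RecurrentPPO  # available on this PR's branch only

# n_steps = 128 is a multiple of batch_size = 64, satisfying the
# current constraint of the implementation.
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", n_steps=128, batch_size=64)
model.learn(total_timesteps=10_000)
```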
Introduces:

- `sbx/recurrentppo` directory with:
  - `policies.py` that adds an LSTM layer to the Actor and the Critic
  - `recurrentppo.py` that handles the RecurrentPPO model
- `recurrent.py` in `sbx/common` to create helper functions for the recurrent rollout buffer

I will keep working on the feature, but here is a list of TODOs I thought of below. I tried to comment the code to make the changes clear, but let me know if I can improve that!
TODOs:

- `policies.py` with the lstm_states
- `recurrentppo.py` L313

Do you see any other things to do, @araffin?
Motivation and Context
Types of changes
Checklist:
- `make format` (required)
- `make check-codestyle` and `make lint` (required)
- `make pytest` and `make type` both pass. (required)
- `make doc` (required)

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line.