Skip to content

Remove nnx.List#5223

Open
samanklesaria wants to merge 3 commits intogoogle:mainfrom
samanklesaria:new_sequential
Open

Remove nnx.List#5223
samanklesaria wants to merge 3 commits intogoogle:mainfrom
samanklesaria:new_sequential

Conversation

@samanklesaria
Copy link
Collaborator

@samanklesaria samanklesaria commented Feb 4, 2026

An an effort to simplify Flax Modules, this PR replaces the wacky nnx.List that pretends to be a list but is really a dict that converts its indices to strings on the fly with a regular python list. This should simplify the code considerably, and make faster flattening possible down the road.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @samanklesaria, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly simplifies the internal structure of Flax Modules by eliminating the custom nnx.List class. This change streamlines the handling of sequential module components by leveraging native Python lists, enhancing readability and reducing the cognitive overhead associated with a specialized list implementation. The transition to standard Python lists, often wrapped with nnx.data for JAX compatibility, aligns with the goal of making Flax Modules more intuitive and easier to work with.

Highlights

  • Removal of nnx.List: The custom nnx.List class, which mimicked a list but internally used a dictionary with string-converted integer keys, has been completely removed from the flax.nnx.helpers module.
  • Replacement with Standard Python Lists: All instances of nnx.List across benchmarks, examples, and tests have been replaced with standard Python lists ([]). For cases where the list elements need to be treated as JAX pytree nodes, nnx.data([...]) is now used to explicitly mark them.
  • Simplified Module Structure: The Sequential module's initialization has been updated to directly use nnx.data(list(fns)) instead of relying on the removed nnx.List.
  • Codebase Cleanup: Imports, docstrings, and error messages referencing nnx.List have been updated or removed to reflect its deprecation and removal, leading to a cleaner and more consistent codebase.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • benchmarks/nnx_graph_overhead.py
    • Replaced nnx.List with a standard Python list for self.intermediates.
  • examples/gemma/transformer.py
    • Replaced nnx.List with nnx.data for self.layers.
  • examples/nnx_toy_examples/hijax_demo.py
    • Replaced nnx.List with nnx.data for self.blocks initialization.
    • Updated isinstance checks from nnx.List to list.
  • flax/nnx/init.py
    • Removed the import statement for nnx.List.
  • flax/nnx/helpers.py
    • Removed the entire nnx.List class definition.
    • Added import for data from flax.nnx.pytreelib.
    • Updated Sequential module to use nnx.data(list(fns)) instead of List(fns).
  • flax/nnx/pytreelib.py
    • Removed the suggestion to use nnx.List from an error message.
  • flax/nnx/statelib.py
    • Removed example code demonstrating nnx.List usage from the restore_int_paths docstring.
  • tests/nnx/graph_utils_test.py
    • Updated test_unflatten, test_unflatten_pure_dict, test_unflatten_empty, and test_unflatten_return_variables to use standard Python lists instead of nnx.List.
  • tests/nnx/helpers_test.py
    • Removed the test_list_mutable_sequence test case.
  • tests/nnx/integration_test.py
    • Updated test_replace_by_pure_dict to initialize self.layers as a standard Python list.
  • tests/nnx/module_test.py
    • Updated test_deref_through_jit, test_deref_number_of_fields, and test_clone to use standard Python lists instead of nnx.List.
  • tests/nnx/partitioning_test.py
    • Updated multiple test cases (test_partition, test_complete_partitioning, test_complete_partitioning_plus_ellipsis, test_inclomplete_partition_error, test_ellipsis_not_last_error, test_update_from, test_update_from_with_array_leaf, test_grad_example, test_get_paritition) to use standard Python lists instead of nnx.List.
  • tests/nnx/state_test.py
    • Updated test_integer_access and test_diff to use standard Python lists instead of nnx.List.
  • tests/nnx/transforms_test.py
    • Updated test_grad, test_grad_with_multiple_ref_types, and test_grad_with_type_predicate to use standard Python lists instead of nnx.List.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a great step towards simplifying the flax.nnx API by removing the custom nnx.List. However, the replacement of nnx.List with a standard Python list is not always correct. When a list of modules or variables is an attribute of an nnx.Module, it must be explicitly marked as nnx.data() to be included in the module's state. Otherwise, it's treated as a static attribute, and its contents (parameters, etc.) are ignored by NNX's state management, which can lead to incorrect behavior like parameters not being trained. I've identified several places in the benchmarks and tests where this needs to be corrected. The changes in the examples and helpers that correctly use nnx.data() are great.

@samanklesaria samanklesaria marked this pull request as ready for review February 4, 2026 23:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant