Skip to content

Conversation

@Aman071106
Copy link

Title: feat(data): Add 'axis' support to gm.data.pad

Description

This PR implements the missing axis argument in gm.data.pad, addressing the TODO left in gemma/gm/data/_functional.py.

Previously, gm.data.pad raised a NotImplementedError if axis != -1. This change enables padding along arbitrary axes, which is essential for handling multi-dimensional data pipelines in JAX (e.g., padding batch dimensions or specific feature dimensions).

Changes

  • gemma/gm/data/_functional.py:
    • Removed NotImplementedError guard.
    • Implemented pad_width construction logic to target the specified axis.
  • gemma/gm/data/_functional_test.py:
    • Added test_pad_axis regression test to verify padding on axis=0 and axis=-2.

Verification

  • Added new test case: test_pad_axis (Passed).
  • Ran existing tests: test_pad, test_seq2seq (Passed).

Related Issue

Resolves TODO(epot): Could add an axis= kwarg to support multi-dimensional arrays.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant