
Fix compiler toolkit CI by removing duplicated buffer registration #2435

Open

yiming0416 wants to merge 1 commit into main from yiming/fix_compiler_toolkit_ci

Conversation

@yiming0416 (Contributor)

`self.rope.cache` and `self.freqs_cis` are the same tensor object registered as buffers on two different modules, so tracing sees them as two distinct graph inputs for the same underlying data.

This PR removes the `register_buffer` call from RoPE and stores `cache` as a plain tensor attribute there instead, keeping only the Decoder-level `register_buffer("freqs_cis", ...)`.
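A minimal sketch of the aliasing described above (the `RoPE`/`Decoder` module shapes here are made up to mirror the description, not the actual torchtitan code): registering the same tensor as a buffer on two modules produces two buffer entries, which a tracer can treat as two separate graph inputs.

```python
import torch
import torch.nn as nn

class RoPE(nn.Module):
    def __init__(self, cache: torch.Tensor):
        super().__init__()
        # Same tensor object also registered on the parent Decoder below.
        self.register_buffer("cache", cache, persistent=False)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        freqs_cis = torch.randn(8, 4)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.rope = RoPE(freqs_cis)

m = Decoder()
# remove_duplicate=False exposes the aliasing a tracer has to contend with.
names = [n for n, _ in m.named_buffers(remove_duplicate=False)]
print(names)  # ['freqs_cis', 'rope.cache'] -- one tensor, two buffer names
assert m.freqs_cis is m.rope.cache
```

With the default `remove_duplicate=True`, `named_buffers()` collapses the pair to a single entry, but tracing frontends that enumerate buffers per-module can still see both names.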

This fixes the compiler toolkit CI:

```shell
NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train MODULE=compiler_toolkit.llama3 CONFIG=compiler_toolkit_llama3_debugmodel ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2
```

@yiming0416 (Author)

@tianyu-l This fixes the compiler toolkit CI failure after your config system change.

Let me know if you are okay with changing the model code; otherwise I can fix it from the compiler toolkit experiment side.

@tianyu-l (Contributor) left a comment

Sounds OK to me. @fegin please take a look.

@wconstab (Contributor)

> This PR removes the register_buffer from RoPE and just stores cache as a plain tensor attribute there

This doesn't have any impact on checkpointing, does it?

@yiming0416 (Author)

> This PR removes the register_buffer from RoPE and just stores cache as a plain tensor attribute there
>
> This doesn't have any impact on checkpointing, does it?

@wconstab Originally the cache was registered as a non-persistent buffer, which won't appear in state_dict, so I assume it shouldn't affect checkpointing?
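A quick sketch backing this up (module names here are illustrative, not the actual torchtitan classes): a non-persistent buffer and a plain tensor attribute both stay out of `state_dict`, so swapping one for the other should leave the checkpoint contents unchanged.

```python
import torch
import torch.nn as nn

class WithNonPersistentBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False keeps this buffer out of state_dict.
        self.register_buffer("cache", torch.ones(2), persistent=False)

class WithPlainAttribute(nn.Module):
    def __init__(self):
        super().__init__()
        self.cache = torch.ones(2)  # plain attribute, never in state_dict

print(list(WithNonPersistentBuffer().state_dict()))  # []
print(list(WithPlainAttribute().state_dict()))       # []
```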

@yiming0416 (Author)

@fegin could you take a look? thanks!

@fegin (Contributor) left a comment

LGTM, I don't think checkpointing is a problem.

```diff
 self.config = config
 # Buffer registered later in init_weights
-self.register_buffer("cache", self._precompute(), persistent=False)
+self.cache: torch.Tensor = self._precompute()
```

One other consideration: IIUC, when registered as a non-persistent buffer, `cache` will at least be moved to the device when the module is moved with `.to(device)`. Having it as a plain tensor will not do this. Will this break our initialization flow (e.g. starting on `meta` and moving to `cuda`)?
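The distinction being raised can be sketched as follows. `Module.to()` recursively converts registered buffers but silently skips plain tensor attributes; the demo below uses a dtype move so it runs without a GPU, but the same applies to `.to("cuda")` or a meta-device init flow.

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: tracked by Module.to().
        self.register_buffer("buf", torch.ones(2), persistent=False)
        # Plain attribute: invisible to Module.to().
        self.plain = torch.ones(2)

m = M().to(torch.float64)
print(m.buf.dtype, m.plain.dtype)  # torch.float64 torch.float32
```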


Labels: ciflow/8gpu, CLA Signed

4 participants