
Conversation

Collaborator

@ji-huazhong ji-huazhong commented Jan 14, 2026

What does this PR do?

As title.

Given the significant differences between the fsdp1+ep and fsdp2+ep implementations in Veomni, I will prioritize completing the functional verification of RL on top of fsdp2.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with commas, like [megatron, fsdp, doc]
    • {type} is one of feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this
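A rough, hypothetical sketch of how the offload/load helpers discussed in the review below might be wrapped around a rollout step. Only offload_veomni_optimizer and the device_id parameter of the load helper are named in this PR's code; the import path, the name load_veomni_optimizer, and the surrounding flow are assumptions, not part of the PR.

import torch

# Hypothetical usage sketch, not taken from this PR. offload_veomni_optimizer
# is named in the review below; the module path and the load-side helper name
# are assumptions.
from verl.utils.veomni_offload import load_veomni_optimizer, offload_veomni_optimizer  # assumed


def rollout_with_offload(optimizer: torch.optim.Optimizer):
    device_id = torch.cuda.current_device()

    # Move optimizer state to CPU so rollout has the GPU memory to itself.
    offload_veomni_optimizer(optimizer)

    # ... run generation / rollout here ...

    # Bring the state back onto the training device before the next update.
    load_veomni_optimizer(optimizer, device_id)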

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces functionality to offload and load veomni models and optimizers to/from CPU, controlled by new configuration flags. The implementation is mostly sound, but I've identified a significant issue in the optimizer handling logic. The functions for offloading and loading optimizer state contain duplicated code and a potential bug that could prevent MultiOptimizer states from being processed correctly. I've provided suggestions to refactor this logic to improve its correctness and maintainability.
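For readers unfamiliar with MultiOptimizer: the offload/load code quoted below relies only on a small duck-typed surface. The stand-in sketched here models nothing beyond the attributes actually referenced in the diff (_is_multi_optimizer, optimizers_dict, and the standard state/param_groups of the wrapped optimizers); Veomni's real MultiOptimizer is not shown on this page, so treat this purely as an illustration.

import torch


class MultiOptimizerStandIn:
    """Duck-typed stand-in for Veomni's MultiOptimizer (illustration only)."""

    # The offload/load helpers check this via hasattr(...) and truthiness.
    _is_multi_optimizer = True

    def __init__(self, ep_params, non_ep_params):
        # One wrapped optimizer for expert-parallel params, one for the rest.
        self.optimizers_dict = {
            "ep": torch.optim.AdamW(ep_params),
            "non_ep": torch.optim.AdamW(non_ep_params),
        }
        # Per the review comments below, the wrapper's own state may be empty
        # even while the wrapped optimizers hold state, which is why an early
        # "if not optimizer.state: return" can skip them entirely.
        self.state = {}
        self.param_groups = []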

Comment on lines +73 to +93
if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)

Severity: high

The logic for offloading optimizer state is duplicated between the single-optimizer and MultiOptimizer branches. It can be refactored into a single loop to improve readability and maintainability. Additionally, the initial check "if not optimizer.state: return" could cause a bug for MultiOptimizer: if its top-level state dictionary is empty while its sub-optimizers still have state, the function exits prematurely without offloading anything.

Suggested change

Removed:

if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)

Suggested replacement:

optimizers_to_process = []
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    optimizers_to_process.extend(optimizer.optimizers_dict.values())
else:
    optimizers_to_process.append(optimizer)

for opt in optimizers_to_process:
    if not opt.state:
        continue
    for param_group in opt.param_groups:
        for param in param_group["params"]:
            if param not in opt.state:
                continue
            state = opt.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)

Comment on lines +98 to +118
if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device_id, non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)

Severity: high

Similar to offload_veomni_optimizer, the logic here is duplicated, and the initial "if not optimizer.state:" check has the same potential bug for MultiOptimizer. It can be refactored in the same way to improve correctness and maintainability.

Suggested change

Removed:

if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device_id, non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)

Suggested replacement:

optimizers_to_process = []
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    optimizers_to_process.extend(optimizer.optimizers_dict.values())
else:
    optimizers_to_process.append(optimizer)

for opt in optimizers_to_process:
    if not opt.state:
        continue
    for param_group in opt.param_groups:
        for param in param_group["params"]:
            if param not in opt.state:
                continue
            state = opt.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
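One general PyTorch note on both suggestions (an observation, not something the review asks for): copies issued with non_blocking=True are asynchronous with respect to the host, so the device-to-CPU offload direction in particular needs a synchronization point before the CPU tensors are read; CPU-to-device loads are ordered on the stream, so subsequent kernels see the data. A minimal sketch:

import torch

# General PyTorch caveat, not specific to this PR: after issuing
# .to(..., non_blocking=True) copies, wait for them to complete before
# any host-side code reads the moved optimizer state.
if torch.cuda.is_available():
    torch.cuda.synchronize()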
