[WIP][veomni] feat: support offloading/loading the veomni model/optimizer #4916
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces functionality to offload and load veomni models and optimizers to/from CPU, controlled by new configuration flags. The implementation is mostly sound, but I've identified a significant issue in the optimizer handling logic. The functions for offloading and loading optimizer state contain duplicated code and a potential bug that could prevent MultiOptimizer states from being processed correctly. I've provided suggestions to refactor this logic to improve its correctness and maintainability.
Reviewed code in `offload_veomni_optimizer`:

```python
if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to("cpu", non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
```
The logic for offloading optimizer state is duplicated between the single-optimizer and MultiOptimizer branches and can be refactored into a single loop to improve readability and maintainability. In addition, the initial `if not optimizer.state: return` check is a potential bug for MultiOptimizer: if its top-level state dictionary is empty while its sub-optimizers do hold state, the function exits prematurely without offloading anything.
Suggested change:

```python
optimizers_to_process = []
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    optimizers_to_process.extend(optimizer.optimizers_dict.values())
else:
    optimizers_to_process.append(optimizer)
for opt in optimizers_to_process:
    if not opt.state:
        continue
    for param_group in opt.param_groups:
        for param in param_group["params"]:
            if param not in opt.state:
                continue
            state = opt.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)
```
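One caveat worth keeping in mind with this pattern (not specific to this PR): a `.to("cpu", non_blocking=True)` call only queues the device-to-host copy, so control returns before the data has actually landed in CPU memory. A synchronization is needed before the offloaded state is read or the GPU memory is treated as free. A minimal illustration of the ordering, assuming a CUDA device is available:

```python
import torch

if torch.cuda.is_available():
    gpu_state = torch.randn(1024, device="cuda")
    cpu_state = gpu_state.to("cpu", non_blocking=True)  # copy is only queued here
    torch.cuda.synchronize()  # wait for the copy before reading cpu_state
    assert torch.equal(cpu_state, gpu_state.cpu())
```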
Reviewed code in the corresponding load function:

```python
if not optimizer.state:
    return

# Check if this is a MultiOptimizer (for ep and non-ep parameters when ep+fsdp2 is enabled)
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    for sub_opt in optimizer.optimizers_dict.values():
        if not sub_opt.state:
            continue
        for param_group in sub_opt.param_groups:
            for param in param_group["params"]:
                state = sub_opt.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device_id, non_blocking=True)
else:
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
```
Similar to `offload_veomni_optimizer`, the logic here is duplicated and contains a potential bug with the initial `if not optimizer.state:` check for MultiOptimizer. This can be refactored to improve correctness and maintainability.
Suggested change:

```python
optimizers_to_process = []
if hasattr(optimizer, "_is_multi_optimizer") and optimizer._is_multi_optimizer:
    optimizers_to_process.extend(optimizer.optimizers_dict.values())
else:
    optimizers_to_process.append(optimizer)
for opt in optimizers_to_process:
    if not opt.state:
        continue
    for param_group in opt.param_groups:
        for param in param_group["params"]:
            if param not in opt.state:
                continue
            state = opt.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)
```
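Since the two suggested blocks above differ only in the destination device (`"cpu"` when offloading, `device_id` when loading back), a further step could be to share one device-parameterized helper between them. Below is a minimal, self-contained sketch of that idea; the function names are illustrative rather than the PR's actual API, and the `_is_multi_optimizer` / `optimizers_dict` attributes are taken from the diff above:

```python
import torch

def _iter_optimizers(optimizer):
    """Yield the sub-optimizers of a MultiOptimizer, or the optimizer itself."""
    if getattr(optimizer, "_is_multi_optimizer", False):
        yield from optimizer.optimizers_dict.values()
    else:
        yield optimizer

def move_optimizer_state(optimizer, device):
    """Move every tensor-valued optimizer state entry to `device` in place."""
    for opt in _iter_optimizers(optimizer):
        if not opt.state:
            continue
        for group in opt.param_groups:
            for param in group["params"]:
                state = opt.state.get(param)
                if not state:
                    continue
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device, non_blocking=True)

# CPU-only smoke test: populate AdamW state, then round-trip it.
model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 4)).sum().backward()
opt.step()  # creates exp_avg / exp_avg_sq state tensors
move_optimizer_state(opt, "cpu")
move_optimizer_state(opt, "cuda" if torch.cuda.is_available() else "cpu")
```

With such a helper, both call sites reduce to `move_optimizer_state(optimizer, "cpu")` and `move_optimizer_state(optimizer, device_id)` respectively.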
What does this PR do?
As title.
Given the significant differences between the fsdp1+ep and fsdp2+ep implementations in Veomni, I will prioritize completing the functional verification of RL training based on fsdp2.
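The diff discussed in this conversation covers the optimizer state; the model half of the offload typically follows the same in-place pattern over parameters, gradients, and buffers. A rough sketch for a plain `nn.Module` is shown below — the helper name is illustrative, not the PR's actual API, and FSDP2/DTensor sharding details are deliberately ignored:

```python
import torch

def move_module_to(module: torch.nn.Module, device) -> None:
    """Move parameters, gradients, and buffers to `device` in place."""
    for param in module.parameters():
        param.data = param.data.to(device, non_blocking=True)
        if param.grad is not None:
            param.grad.data = param.grad.data.to(device, non_blocking=True)
    for buf in module.buffers():
        buf.data = buf.data.to(device, non_blocking=True)

# Example: offload to CPU, then reload onto the accelerator if one exists.
model = torch.nn.Linear(8, 8)
move_module_to(model, "cpu")
move_module_to(model, "cuda" if torch.cuda.is_available() else "cpu")
```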
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`; multiple modules can be combined, like `[megatron, fsdp, doc]`.
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`.
  - If the PR breaks any API, prepend `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.

Test
API and Usage Example
# Add code snippet or script demonstrating how to use this

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once the PR is ready for CI, request a run in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.