Hf checkpoint conversion for distributed checkpoints #424

Open. BlueCrescent wants to merge 29 commits into main from hf_checkpoint_conversion_for_fsdp2 (base: main).
Commits
All 29 commits are by BlueCrescent:

- e6b7cab feat(config): Added OmegaConf based serializer save_yaml_config_dict().
- 9fa51ec feat(huggingface): Added conversion of distributed gpt2 checkpoints t…
- a73de85 chore: Merge branch 'fix_rotary_transform_deferred_init' into hf_chec…
- d7d0956 refactor: More robust parent directory path handling.
- 8957f19 docs: better dcp to torch conversion docstring
- 527a0d2 fix: Added handling for missing directory.
- 95cead4 fix: use Path instead of string
- b8cf4ea fix: use cpu device for dcp to torch converted checkpoints
- 652e77a fix: error handling if wrong model key is set in checkpoint conversion
- fca72dc feat(utility): Moved MultiProcessingCudaEnv from tests to modalities.
- ace93c7 feat(utility): Added option to set init_process_group kwargs in cuda …
- 53eb907 feat(utility): Extended get_model_from_config for distributed checkpo…
- 3a4b46c feat(huggingface): Added dcp specific conversion verification logic.
- 642466d fix(huggingface): Better dcp config conversion.
- f54abc6 feat(config): Added interoperability between PyTorchDtypes and Precis…
- 3fbe498 fix(huggingface): Correct conversion of model dtype.
- ee4e244 fix(config): circular import
- 1b4cfe0 feat(checkpointing): improvements for dcp to torch checkpoint conversion
- 3a67ed9 revert(config): Removed PrecisionEnum <-> PyTorchDtypes interoperabil…
- ddbb8cc fix(huggingface): output parity between dcp and converted hf checkpoints
- 5a36d48 fix(model): Corrected type casting in rotary pos embeddings to match …
- bce2ae1 feat(utility): Added weights printing to print_forward_hook.
- 5da0e7f fix(requirements): Excluded bugged transformers versions.
- f902152 feat(utility): Added EnvOverride utility for temporary changing envir…
- d520095 fix(huggingface): Setting some environment variables when loading dcp…
- 42a7e42 fix(checkpointing): Moved EnvOverride into load_dcp_config so that al…
- 03e07f5 fix(huggingface): Made single node dcp config generation more robust …
- 8a9ff2f test(utility): Made manager shutdown in monitor_child_processes optio…
- 9ae218d test(huggingface): Added unit tests for dcp to hf conversion.
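Commit f902152 adds an EnvOverride utility for temporarily changing environment variables, which the diff below uses in load_dcp_config to force single-process values for LOCAL_RANK, RANK, and WORLD_SIZE. The actual implementation in modalities.utils.env is not part of this PR view; the following is a minimal sketch of what such a dict-based context manager could look like (the class body here is an assumption, not the PR's code):

```python
import os


class EnvOverride:
    """Sketch of an env-var override context manager: temporarily sets the
    given variables and restores their previous values (or absence) on exit."""

    def __init__(self, overrides: dict[str, str]):
        self._overrides = overrides
        self._saved: dict[str, str | None] = {}

    def __enter__(self):
        for key, value in self._overrides.items():
            self._saved[key] = os.environ.get(key)  # None means "was unset"
            os.environ[key] = value
        return self

    def __exit__(self, exc_type, exc, tb):
        for key, old in self._saved.items():
            if old is None:
                os.environ.pop(key, None)  # restore "unset" state
            else:
                os.environ[key] = old
        return False  # do not suppress exceptions


# Demo with a made-up variable name so the surrounding environment is untouched.
os.environ.pop("DEMO_WORLD_SIZE", None)  # ensure it starts unset
with EnvOverride({"DEMO_WORLD_SIZE": "1"}):
    inside = os.environ["DEMO_WORLD_SIZE"]
outside = os.environ.get("DEMO_WORLD_SIZE")
```

Restoring the "previously unset" state (rather than just the previous value) matters here, because leaving a stale RANK or WORLD_SIZE behind could confuse later distributed initialization.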
File changed (new file, 111 lines added):

```python
import os
from pathlib import Path

import torch
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.filesystem import FileSystemReader
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

from modalities.config.config import ConfigDictType, load_app_config_dict, save_yaml_config_dict
from modalities.utils.env import EnvOverride


def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str:
    """Converts a DCP (Distributed Checkpoint) checkpoint, including
    FSDP2, PP, or TP checkpoints, to a standard PyTorch checkpoint.

    Args:
        dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files (may include FSDP2, PP, or TP).
        output_dir (str): Directory to save the converted PyTorch checkpoint.
        model_key (str): Key of the model configuration in the modalities config.

    Returns:
        str: Path to the converted config file.
    """
    os.makedirs(output_dir, exist_ok=True)
    torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin")
    torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file)
    # TODO: This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file),
    # since we only want to convert the model state dict here. In future torch versions this function might
    # support converting only parts of the checkpoint.
    # (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save)
    sd: STATE_DICT_TYPE = {}
    planner = _EmptyStateDictLoadPlanner(keys=["app.model"], allow_partial_load=True)
    _load_state_dict(sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=planner, no_dist=True)
    torch.save(sd["app"]["model"], torch_checkpoint_file)
    return torch_config_file


def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str:
    """Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading.

    Args:
        dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files.
        output_dir (str): Directory to save the converted config file.
        model_key (str): Key of the model configuration in the modalities config.
        torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file.

    Returns:
        str: Path to the converted config file.
    """
    config_src, dcp_config = load_dcp_config(dcp_checkpoint_dir)
    config_dst: str = os.path.join(output_dir, os.path.basename(config_src))
    if os.path.exists(config_dst):
        raise FileExistsError(f"Config file '{config_dst}' already exists.")
    torch_config: ConfigDictType = {
        "checkpointed_model": {
            "component_key": "model",
            "variant_key": "fsdp1_checkpointed",
            "config": {
                "checkpoint_loading": {
                    "component_key": "checkpoint_loading",
                    "variant_key": "torch",
                    "config": {
                        "device": "cpu",
                        "precision": "FP32",
                    },
                },
                "model": {
                    "instance_key": "model",
                    "pass_type": "BY_REFERENCE",
                },
                "checkpoint_path": torch_checkpoint_file,
            },
        },
    }
    if model_key not in dcp_config:
        raise KeyError(
            f"Model key '{model_key}' not found in config file '{config_src}'."
            f" Available keys: {list(dcp_config.keys())}"
        )
    torch_config["model"] = dcp_config[model_key]
    torch_config["model"]["config"]["use_meta_device"] = False
    save_yaml_config_dict(torch_config, Path(config_dst))
    return config_dst


def load_dcp_config(dcp_checkpoint_dir: str) -> tuple[str, ConfigDictType]:
    with EnvOverride({"LOCAL_RANK": "0", "RANK": "0", "WORLD_SIZE": "1"}):
        config_src: str | None = find_yaml_config_in_dir(dcp_checkpoint_dir)
        if config_src is None:
            config_src = find_yaml_config_in_dir(str(Path(dcp_checkpoint_dir).parent))
        if config_src is None:
            raise FileNotFoundError("No YAML config file found in checkpoint directory or its parent.")
        dcp_config = load_app_config_dict(Path(config_src), experiment_id="-1")
    return config_src, dcp_config


def find_yaml_config_in_dir(directory: str) -> str | None:
    """Finds the first YAML config file in the given directory.

    Args:
        directory (str): Directory to search for YAML files.

    Returns:
        str | None: Path to the found YAML file or None if not found.
    """
    if not os.path.isdir(directory) or not os.access(directory, os.R_OK):
        # Directory does not exist or is not accessible
        return None
    for filename in os.listdir(directory):
        if filename.endswith(".yaml") or filename.endswith(".yml"):
            return os.path.join(directory, filename)
    return None
```