29 commits
e6b7cab
feat(config): Added OmegaConf based serializer save_yaml_config_dict().
BlueCrescent Nov 27, 2025
9fa51ec
feat(huggingface): Added conversion of distributed gpt2 checkpoints t…
BlueCrescent Nov 27, 2025
a73de85
chore: Merge branch 'fix_rotary_transform_deferred_init' into hf_chec…
BlueCrescent Nov 28, 2025
d7d0956
refactor: More robust parent directory path handling.
BlueCrescent Nov 28, 2025
8957f19
docs: better dcp to torch conversion docstring
BlueCrescent Nov 28, 2025
527a0d2
fix: Added handling for missing directory.
BlueCrescent Nov 28, 2025
95cead4
fix: use Path instead of string
BlueCrescent Nov 28, 2025
b8cf4ea
fix: use cpu device for dcp to torch converted checkpoints
BlueCrescent Nov 28, 2025
652e77a
fix: error handling if wrong model key is set in checkpoint conversion
BlueCrescent Nov 28, 2025
fca72dc
feat(utility): Moved MultiProcessingCudaEnv from tests to modalities.
BlueCrescent Dec 2, 2025
ace93c7
feat(utility): Added option to set init_process_group kwargs in cuda …
BlueCrescent Dec 3, 2025
53eb907
feat(utility): Extended get_model_from_config for distributed checkpo…
BlueCrescent Dec 3, 2025
3a4b46c
feat(huggingface): Added dcp specific conversion verification logic.
BlueCrescent Dec 3, 2025
642466d
fix(huggingface): Better dcp config conversion.
BlueCrescent Dec 3, 2025
f54abc6
feat(config): Added interoperability between PyTorchDtypes and Precis…
BlueCrescent Dec 3, 2025
3fbe498
fix(huggingface): Correct conversion of model dtype.
BlueCrescent Dec 3, 2025
ee4e244
fix(config): circular import
BlueCrescent Dec 3, 2025
1b4cfe0
feat(checkpointing): improvements for dcp to torch checkpoint conversion
BlueCrescent Dec 5, 2025
3a67ed9
revert(config): Removed PrecisionEnum <-> PyTorchDtypes interoperabil…
BlueCrescent Dec 5, 2025
ddbb8cc
fix(huggingface): output parity between dcp and converted hf checkpoints
BlueCrescent Dec 5, 2025
5a36d48
fix(model): Corrected type casting in rotary pos embeddings to match …
BlueCrescent Dec 5, 2025
bce2ae1
feat(utility): Added weights printing to print_forward_hook.
BlueCrescent Dec 8, 2025
5da0e7f
fix(requirements): Excluded bugged transformers versions.
BlueCrescent Dec 8, 2025
f902152
feat(utility): Added EnvOverride utility for temporary changing envir…
BlueCrescent Dec 9, 2025
d520095
fix(huggingface): Setting some environment variables when loading dcp…
BlueCrescent Dec 9, 2025
42a7e42
fix(checkpointing): Moved EnvOverride into load_dcp_config so that al…
BlueCrescent Dec 11, 2025
03e07f5
fix(huggingface): Made single node dcp config generation more robust …
BlueCrescent Dec 11, 2025
8a9ff2f
test(utility): Made manager shutdown in monitor_child_processes optio…
BlueCrescent Dec 11, 2025
9ae218d
test(huggingface): Added unit tests for dcp to hf conversion.
BlueCrescent Dec 11, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
"packaging",
"tqdm",
"pyyaml",
"transformers",
"transformers!=4.57.2,!=4.57.3", # Exclude versions with known issues. Can probably be removed if a version >4.57.3 is released.
"datasets",
"protobuf",
"SentencePiece",
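
As an illustrative aside (not part of this diff), a small runtime guard could assert that none of the excluded transformers versions is installed; the specific version set below is an assumption based on the constraint above:

import transformers
from packaging.version import Version

# Versions excluded in pyproject.toml (assumption: only 4.57.2 and 4.57.3 are affected).
_BAD_TRANSFORMERS_VERSIONS = {Version("4.57.2"), Version("4.57.3")}

if Version(transformers.__version__) in _BAD_TRANSFORMERS_VERSIONS:
    raise RuntimeError(
        f"transformers {transformers.__version__} has known issues; install a version "
        "outside of 4.57.2/4.57.3 (see pyproject.toml)."
    )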
111 changes: 111 additions & 0 deletions src/modalities/checkpointing/convert_dcp_to_torch.py
@@ -0,0 +1,111 @@
import os
from pathlib import Path

import torch
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.filesystem import FileSystemReader
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

from modalities.config.config import ConfigDictType, load_app_config_dict, save_yaml_config_dict
from modalities.utils.env import EnvOverride


def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str:
"""Converts a DCP (Distributed Checkpoint) checkpoint—including
FSDP2, PP, or TP checkpoints—to a standard PyTorch checkpoint.

Args:
dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files (may include FSDP2, PP, or TP).
output_dir (str): Directory to save the converted PyTorch checkpoint.
model_key (str): Key of the model configuration in the modalities config.
Returns:
str: Path to the converted config file.
"""
os.makedirs(output_dir, exist_ok=True)
torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin")
torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file)
# TODO This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file)
# since we only want to convert the model state dict here. In future torch versions this function might
# support converting only parts of the checkpoint.
# (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save)
sd: STATE_DICT_TYPE = {}
planner = _EmptyStateDictLoadPlanner(keys=["app.model"], allow_partial_load=True)
_load_state_dict(sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=planner, no_dist=True)
torch.save(sd["app"]["model"], torch_checkpoint_file)
return torch_config_file


def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str:
"""Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading.
Args:
dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files.
output_dir (str): Directory to save the converted config file.
model_key (str): Key of the model configuration in the modalities config.
torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file.
Returns:
str: Path to the converted config file.
"""
config_src, dcp_config = load_dcp_config(dcp_checkpoint_dir)
config_dst: str = os.path.join(output_dir, os.path.basename(config_src))
if os.path.exists(config_dst):
raise FileExistsError(f"Config file '{config_dst}' already exists.")
torch_config: ConfigDictType = {
"checkpointed_model": {
"component_key": "model",
"variant_key": "fsdp1_checkpointed",
"config": {
"checkpoint_loading": {
"component_key": "checkpoint_loading",
"variant_key": "torch",
"config": {
"device": "cpu",
"precision": "FP32",
},
},
"model": {
"instance_key": "model",
"pass_type": "BY_REFERENCE",
},
"checkpoint_path": torch_checkpoint_file,
},
},
}
if model_key not in dcp_config:
raise KeyError(
f"Model key '{model_key}' not found in config file '{config_src}'."
f" Available keys: {list(dcp_config.keys())}"
)
torch_config["model"] = dcp_config[model_key]
torch_config["model"]["config"]["use_meta_device"] = False
save_yaml_config_dict(torch_config, Path(config_dst))
return config_dst


def load_dcp_config(dcp_checkpoint_dir: str) -> tuple[str, ConfigDictType]:
with EnvOverride({"LOCAL_RANK": "0", "RANK": "0", "WORLD_SIZE": "1"}):
config_src: str | None = find_yaml_config_in_dir(dcp_checkpoint_dir)
if config_src is None:
config_src = find_yaml_config_in_dir(str(Path(dcp_checkpoint_dir).parent))
if config_src is None:
raise FileNotFoundError("No YAML config file found in checkpoint directory or its parent.")
dcp_config = load_app_config_dict(Path(config_src), experiment_id="-1")
return config_src, dcp_config


def find_yaml_config_in_dir(directory: str) -> str | None:
"""Finds the first YAML config file in the given directory.

Args:
directory (str): Directory to search for YAML files.

Returns:
str | None: Path to the found YAML file or None if not found.
"""
if not os.path.isdir(directory) or not os.access(directory, os.R_OK):
# Directory does not exist or is not accessible
return None
for filename in os.listdir(directory):
if filename.endswith(".yaml") or filename.endswith(".yml"):
return os.path.join(directory, filename)
return None
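
A minimal usage sketch for the converter above; the paths are hypothetical, and the DCP checkpoint directory (or its parent) is expected to contain the original modalities YAML config:

from pathlib import Path
from modalities.checkpointing.convert_dcp_to_torch import convert_dcp_to_torch

# Hypothetical checkpoint locations.
dcp_dir = "/checkpoints/run_42/dcp_checkpoint"
output_dir = "/checkpoints/run_42/torch_export"

# Writes pytorch_model.bin and a converted YAML config into output_dir and
# returns the path of that config.
converted_config_path = convert_dcp_to_torch(dcp_dir, output_dir, model_key="model_raw")
print(Path(converted_config_path).name)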
21 changes: 18 additions & 3 deletions src/modalities/config/config.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
from functools import partial
from pathlib import Path
@@ -497,13 +499,14 @@ class ParallelDegreeConfig(BaseModel):
# Recursive type representing arbitrary-depth YAML config structures.
YAMLPrimitive = str | int | float | bool | None
YAMLValue: TypeAlias = YAMLPrimitive | Path | list["YAMLValue"] | dict[str, "YAMLValue"]
ConfigDictType: TypeAlias = dict[str, YAMLValue]


def load_app_config_dict(
config_file_path: Path,
experiment_id: Optional[str] = None,
additional_resolver_funs: Optional[dict[str, Resolver]] = None,
) -> dict[str, YAMLValue]:
) -> ConfigDictType:
"""Load the application configuration from the given YAML file.

Args:
@@ -512,7 +515,7 @@ def load_app_config_dict(
additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions.

Returns:
dict[str, YAMLValue]: Dictionary representation of the config file with arbitrary depth.
ConfigDictType: Dictionary representation of the config file with arbitrary depth.
"""

def cuda_env_resolver_fun(var_name: str) -> int | str | None:
@@ -528,6 +531,7 @@ def modalities_env_resolver_fun(var_name: str, kwargs: dict[str, Any]) -> str |
def node_env_resolver_fun(var_name: str) -> int | None:
if var_name == "num_cpus":
return os.cpu_count()
return None

OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True)
modalities_env_kwargs: dict[str, Any] = {
@@ -546,6 +550,17 @@ def node_env_resolver_fun(var_name: str) -> int | None:
OmegaConf.register_new_resolver(resolver_name, resolver_fun, replace=True)

cfg = OmegaConf.load(config_file_path)
config_dict = cast(dict[str, YAMLValue], OmegaConf.to_container(cfg, resolve=True))
config_dict = cast(ConfigDictType, OmegaConf.to_container(cfg, resolve=True))

return config_dict


def save_yaml_config_dict(config_dict: ConfigDictType, output_file_path: Path) -> None:
"""Saves the given config dictionary as a YAML file.

Args:
config_dict (ConfigDictType): Configuration dictionary to save.
output_file_path (Path): Path to the output YAML file.
"""
cfg = OmegaConf.create(config_dict)
OmegaConf.save(cfg, output_file_path)
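
A small round-trip sketch of the new serializer together with the existing loader; the file names are assumptions:

from pathlib import Path
from modalities.config.config import load_app_config_dict, save_yaml_config_dict

# Load a config (resolving OmegaConf resolvers), then write the resolved dictionary back to disk.
config_dict = load_app_config_dict(Path("config.yaml"), experiment_id="-1")
save_yaml_config_dict(config_dict, Path("config_resolved.yaml"))
# Reloading config_resolved.yaml should yield the same dictionary, assuming the
# environment-based resolvers produce identical values in both runs.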
119 changes: 109 additions & 10 deletions src/modalities/conversion/gpt2/conversion_model.py
@@ -1,40 +1,49 @@
import warnings

import torch
import torch.nn as nn
from tqdm import tqdm

from modalities.checkpointing.convert_dcp_to_torch import load_dcp_config
from modalities.config.config import ConfigDictType, PrecisionEnum, ProcessGroupBackendType
from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config
from modalities.conversion.gpt2.modeling_gpt2 import GPT2DecoderLayer, GPT2ForCausalLM
from modalities.models.components.layer_norms import LayerNormConfig
from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block, PositionTypes
from modalities.models.model import SwiGLU
from modalities.models.utils import ModelTypeEnum, get_model_from_config
from modalities.running_env.cuda_env import MultiProcessingCudaEnv
from modalities.running_env.env_utils import PyTorchDtypes


def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM, GPT2LLM]:
def convert_model_checkpoint(modalities_config: ConfigDictType) -> tuple[GPT2ForCausalLM, GPT2LLM]:
"""Converts the modalities model to a Huggingface transformers model.
Both the loaded modalities model and the converted Huggingface model are returned
so that they can be compared.

Args:
modalities_config (dict): Modalities config dictionary.
modalities_config (ConfigDictType): Modalities config dictionary.

Returns:
tuple[GPT2ForCausalLM, GPT2LLM]: Converted Hugging Face model and the original modalities model.
"""
gpt2_config = convert_model_config(modalities_config)
hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=torch.bfloat16)
dtype = PrecisionEnum(
modalities_config["checkpointed_model"]["config"]["checkpoint_loading"]["config"]["precision"]
)
hf_model = GPT2ForCausalLM(gpt2_config).to(dtype=dtype.value)
modalities_model = get_model_from_config(modalities_config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)
_copy_weights_model(hf_model, modalities_model)
return hf_model, modalities_model


def convert_model_config(modalities_config: dict) -> GPT2Config:
def convert_model_config(modalities_config: ConfigDictType) -> GPT2Config:
"""Converts the modalities model configuration to a Huggingface transformers configuration.
For this, the model_raw or model section of the modalities config is used.
Corresponding entries are mapped to the Huggingface configuration.

Args:
modalities_config (dict): Modalities config dictionary.
modalities_config (ConfigDictType): Modalities config dictionary.

Returns:
GPT2Config: Converted Huggingface model configuration.
@@ -43,6 +52,12 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
_check_conversion_criteria(config)

ffn_norm_key = "ffn_norm_config"
attention_type = _map_attention_type(config)
if attention_type != "sdpa":
warnings.warn(
f"transformers checkpoint will not save the attention implementation "
f"(set to {attention_type}) and use sdpa by default."
)

return GPT2Config(
vocab_size=config["vocab_size"],
@@ -62,11 +77,24 @@ def convert_model_config(modalities_config: dict) -> GPT2Config:
layer_norm_bias=_get_layer_norm_value(config[ffn_norm_key]["config"], "bias"),
max_position_embeddings=config["sequence_length"],
rope_theta=config["attention_config"]["qkv_transforms"][0]["config"]["base_freq"],
_attn_implementation=_map_attention_type(config),
attn_implementation=attention_type,
output_attentions=False,
)


def check_converted_dcp_model(
hf_model_dir: str, dcp_dir: str, num_testruns: int, device_id_modalities: str | int, device_hf: str
):
new_config: ConfigDictType = _build_single_node_dcp_config(dcp_dir)
hf_model = _load_hf_model_for_dcp_comparison(hf_model_dir, new_config, device_hf)
vocab_size: int = new_config["model_raw" if "model_raw" in new_config else "model"]["config"]["vocab_size"]
if isinstance(device_id_modalities, str):
device_id_modalities = int(device_id_modalities.replace("cuda:", ""))
with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=device_id_modalities):
modalities_model = get_model_from_config(new_config, model_type=ModelTypeEnum.DCP_CHECKPOINTED_MODEL)
check_converted_model(hf_model, modalities_model, num_testruns=num_testruns, vocab_size=vocab_size)


def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM, num_testruns: int, vocab_size: int):
"""Tests the converted model by inputting a random token sequence and comparing the output logits of both models.

@@ -85,14 +113,85 @@ def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM,
modalities_logits = modalities_model(inputs)[modalities_model.prediction_key].to("cpu")

assert llama_logits.shape == modalities_logits.shape
assert llama_logits.dtype == modalities_logits.dtype
assert torch.equal(llama_logits, modalities_logits)


def _check_conversion_criteria(model_config: dict) -> None:
def _build_single_node_dcp_config(dcp_dir: str) -> ConfigDictType:
"""Builds a modalities config dictionary for loading a DCP checkpointed model on a single node.

Args:
dcp_dir (str): Directory containing the DCP checkpoint.

Returns:
ConfigDictType: New modalities config dictionary for loading the DCP checkpointed model.
"""
_, dcp_config = load_dcp_config(dcp_dir)
model_key = "model_raw" if "model_raw" in dcp_config else "model"
new_config: ConfigDictType = {
"fsdp_model": dcp_config["fsdp_model"],
"initialized_model": dcp_config["initialized_model"],
model_key: dcp_config[model_key],
}
if "settings" in dcp_config:
new_config["settings"] = dcp_config["settings"]
new_config["settings"]["config_file_path"] = "converted_dcp_config.yaml"
if "dp_degree" in dcp_config:
new_config["dp_degree"] = dcp_config["dp_degree"]
if "optimizer" in dcp_config:
new_config["optimizer"] = dcp_config["optimizer"]
if "lr_scheduler" in dcp_config:
new_config["lr_scheduler"] = dcp_config["lr_scheduler"]
new_config["app_state"] = {
"component_key": "app_state",
"variant_key": "dcp",
"config": {
"raw_app_state": dcp_config["app_state_raw" if "app_state_raw" in dcp_config else "app_state"],
"checkpoint_dir_path": dcp_dir,
},
}
new_config["device_mesh"] = {
"component_key": "device_mesh",
"variant_key": "default",
"config": {
"device_type": "cuda",
"data_parallel_shard_degree": 1,
"world_size": 1,
},
}
new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key
new_config["initialized_model"]["config"]["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"}
return new_config


def _load_hf_model_for_dcp_comparison(
hf_model_dir: str, dcp_modalities_config: ConfigDictType, device_hf: str
) -> GPT2ForCausalLM:
# Need execution dtype of FSDP2 to get same outputs from model.
dtype = dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"]
hf_model: GPT2ForCausalLM = (
GPT2ForCausalLM.from_pretrained(hf_model_dir, local_files_only=True, trust_remote_code=True)
.to(device=device_hf)
.to(PyTorchDtypes(dtype).value)
)
# Need to match attention implementation
hf_model.config._attn_implementation = _map_attention_type(
dcp_modalities_config["model_raw" if "model_raw" in dcp_modalities_config else "model"]["config"]
)
# Rotary embedding frequencies are not downcast in FSDP2.
# Therefore, we need to ensure they remain in the original precision.
hf_model.model.rotary_emb.inv_freq = hf_model.model.rotary_emb.original_inv_freq.to(
hf_model.model.rotary_emb.inv_freq.device
)

return hf_model


def _check_conversion_criteria(model_config: ConfigDictType) -> None:
"""Checks that the modalities config fulfills criteria necessary for conversion

Args:
model_config (dict): model or model_raw part of the Modalities config dictionary.
model_config (ConfigDictType): model or model_raw part of the Modalities config dictionary.

Returns:
None
@@ -116,12 +215,12 @@ def _check_conversion_criteria(model_config: dict) -> None:
), "All norms must have the same eps setting."


def _get_layer_norm_value(config: dict, field: str) -> bool | float | int:
def _get_layer_norm_value(config: ConfigDictType, field: str) -> bool | float | int:
default = LayerNormConfig.model_fields[field].default
return config.get(field, default)


def _map_attention_type(config: dict):
def _map_attention_type(config: ConfigDictType) -> str:
if config["attention_implementation"] == "pytorch_flash":
attention_impl = "sdpa"
elif config["attention_implementation"] == "manual":
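
Finally, a hedged sketch of how the new DCP verification entry point might be invoked after a conversion; directory names, device choices, and the number of test runs are assumptions, not taken from this PR:

from modalities.conversion.gpt2.conversion_model import check_converted_dcp_model

# Compares output logits of the converted Hugging Face checkpoint against the
# original DCP checkpoint on random token batches (hypothetical paths).
check_converted_dcp_model(
    hf_model_dir="/checkpoints/run_42/hf_export",
    dcp_dir="/checkpoints/run_42/dcp_checkpoint",
    num_testruns=4,
    device_id_modalities="cuda:0",
    device_hf="cuda:1",
)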