
Crash when loading the final checkpoint (saved after the configured number of steps finished) #32

@tomiock


Bug description

The final checkpoint cannot be loaded; the program crashes when it tries to resume from it.

Steps to reproduce

  1. Train until the configured number of steps is reached.
  2. Increase the number of steps for a new run.
  3. Launch the new run, which tries to load the final checkpoint.
  4. The run crashes.

Traceback:

[rank0]:Traceback (most recent call last):
[rank0]:  File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:  File "<frozen runpy>", line 88, in _run_code
[rank0]:  File "/home-local/tockier/torchtitan/torchtitan/train.py", line 675, in <module>
[rank0]:    trainer.train()
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
[rank0]:    return f(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/torchtitan/train.py", line 557, in train
[rank0]:    self.checkpointer.load(step=job_config.checkpoint.load_step)
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 121, in decorate_context
[rank0]:    return func(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/torchtitan/components/checkpoint.py", line 599, in load
[rank0]:    self.dcp_load(
[rank0]:  File "/home-local/tockier/torchtitan/torchtitan/components/checkpoint.py", line 443, in dcp_load
[rank0]:    dcp.load(state_dict, checkpoint_id=checkpoint_id)
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 88, in wrapper
[rank0]:    result = func(*args, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 475, in inner_func
[rank0]:    return func(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 184, in load
[rank0]:    _load_state_dict(
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 279, in _load_state_dict
[rank0]:    central_plan = distW.reduce_scatter("plan", local_step, global_step)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 219, in reduce_scatter
[rank0]:    raise result
[rank0]:torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
[rank0]:Traceback (most recent call last): (RANK 0)
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank0]:    local_data = map_fun()
[rank0]:                 ^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/logger.py", line 88, in wrapper
[rank0]:    result = func(*args, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 266, in local_step
[rank0]:    local_plan = planner.create_local_plan()
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 350, in create_local_plan
[rank0]:    return create_default_local_load_plan(
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home-local/tockier/torchtitan/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/default_planner.py", line 471, in create_default_local_load_plan
[rank0]:    raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]:RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_0.
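The error is raised by the load planner, not by the file reader. `dcp.load` plans the read from the state_dict of the new run (what it expects to restore), and `create_default_local_load_plan` raises this `RuntimeError` for any key that is absent from the checkpoint metadata. One plausible reading, not confirmed in this report, is that the final checkpoint was written without the dataloader state that a resumed run expects. The stdlib-only sketch below models that check; `load_plan_keys`, `FINAL_CKPT_KEYS`, `RESUME_STATE_KEYS` and the key `model.output.weight` are hypothetical, and only `dataloader.dp_rank_0` comes from the traceback. The `strict=False` branch is meant to resemble the `allow_partial_load` option that recent PyTorch releases expose on `DefaultLoadPlanner` (check that your installed version has it).

```python
"""Stdlib-only model of the check behind
"Missing key in checkpoint state_dict: dataloader.dp_rank_0."

All names here are hypothetical; this does not call torch.
"""


def load_plan_keys(target_keys, checkpoint_keys, strict=True):
    """Return the keys that would be read from the checkpoint.

    Every key the *new run* asks for must exist in the checkpoint.
    With strict=False, missing keys are skipped instead of raising
    (assumed analogue of a partial-load option).
    """
    planned = []
    for fqn in target_keys:
        if fqn not in checkpoint_keys:
            if strict:
                raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
            continue  # partial load: leave this entry at its current value
        planned.append(fqn)
    return planned


# Assumption: the final checkpoint holds model state only,
# while the resumed run also expects the dataloader state.
FINAL_CKPT_KEYS = {"model.output.weight"}
RESUME_STATE_KEYS = ["model.output.weight", "dataloader.dp_rank_0"]


if __name__ == "__main__":
    try:
        load_plan_keys(RESUME_STATE_KEYS, FINAL_CKPT_KEYS)
    except RuntimeError as err:
        print(err)  # Missing key in checkpoint state_dict: dataloader.dp_rank_0.

    # Non-strict variant: the missing dataloader entry is skipped.
    print(load_plan_keys(RESUME_STATE_KEYS, FINAL_CKPT_KEYS, strict=False))
    # ['model.output.weight']
```

Skipping the key avoids the crash, but the skipped state is simply not restored: if the dataloader entry is dropped, the data iterator starts from its initial position rather than where the previous run stopped.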

Labels: bug (Something isn't working)