From b5eab97f2ce708c5459e6ea12a1ed1b8d31caf3d Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 2 Feb 2026 11:50:52 +0100 Subject: [PATCH] added check for sharding --- src/weathergen/model/model_interface.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 23334749b..528db56d9 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -189,13 +189,17 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): maybe_sharded_sd = {} for param_name, full_tensor in params.items(): sharded_meta_param = meta_sharded_sd.get(param_name) - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) - maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + if sharded_meta_param is None: + logger.info(f"Sharding meta parameters is None for: {param_name}") + maybe_sharded_sd[param_name] = torch.nn.Parameter(full_tensor) + else: + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)