diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 23334749b..528db56d9 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -189,13 +189,17 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): maybe_sharded_sd = {} for param_name, full_tensor in params.items(): sharded_meta_param = meta_sharded_sd.get(param_name) - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) - maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + if sharded_meta_param is None: + logger.info(f"Sharding meta parameters is None for: {param_name}") + maybe_sharded_sd[param_name] = torch.nn.Parameter(full_tensor) + else: + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)