
add ModelTorchDistributed with tests #847

Open
mahf708 wants to merge 3 commits into ai2cm:main from mahf708:modtordis

Conversation


mahf708 commented Feb 18, 2026

add ModelTorchDistributed with tests

Changes:

  • fme.core.distributed has a new ModelDistributedBackend that allows for parallelism over spatial dimensions as well as batch/data (see the sketch below).

  • torch is pinned to a minimum of 2.4.0 to use newer distributed facilities.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Closes #749
Closes #842
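
The thread does not show the backend's internals, so the following is only a rough sketch of what combined data and spatial parallelism can look like with the torch >= 2.4 distributed facilities; the function name, mesh axes, and dimension names below are illustrative and are not the PR's actual API.

# Illustrative sketch only; not the PR's ModelDistributedBackend.
# Assumes torch.distributed can be initialized for
# data_parallel * spatial_parallel ranks (e.g. via torchrun env vars).
import torch
from torch.distributed.device_mesh import init_device_mesh


def build_mesh(data_parallel: int, spatial_parallel: int):
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # 2-D process mesh: one axis for batch/data parallelism,
    # one axis for spatial (model) parallelism.
    mesh = init_device_mesh(
        device_type,
        (data_parallel, spatial_parallel),
        mesh_dim_names=("data", "spatial"),
    )
    # Per-axis sub-meshes provide separate process groups, e.g. gradient
    # all-reduce over "data" and exchanges of spatial shards over "spatial".
    return mesh["data"], mesh["spatial"]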

Comment on lines 61 to 67
mahf708 (Author)

I'm not thrilled with these names or this general method of forcing; if we have to go this route, I'd prefer to call the new one "model" (since it can parallelize over all sorts of dims/tags).

mcgibbon (Contributor) commented Feb 18, 2026

For now, can you keep the behavior Oscar was going to use? That is, if (let's call it) FME_DISTRIBUTED_H or _W is set > 1, then the spatial backend is used. If you need a way to force it, perhaps use the spatial backend if one or both are set, and use the torch backend only if both are unset. I don't think we currently need a way to force non-distributed from the CLI, so we shouldn't add that feature.
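
A minimal sketch of the selection logic being suggested here; the environment variable names come from the comment above, while the returned backend labels and the function name are placeholders, not identifiers from the PR.

import os


def _choose_backend() -> str:
    # Suggested default: if either spatial extent is set above 1, use the
    # spatial ("model") backend; otherwise fall back to the plain torch
    # data-parallel backend.
    h = int(os.environ.get("FME_DISTRIBUTED_H", "1"))
    w = int(os.environ.get("FME_DISTRIBUTED_W", "1"))
    if h > 1 or w > 1:
        return "model"  # spatial + data parallelism
    return "torch"  # data parallelism only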

mahf708 (Author)

Mostly addressed, but I'd rather force people to pick one explicitly; we can streamline defaults later.

logger.debug("Barrier on rank %d", self._rank)
torch.distributed.barrier(device_ids=self._device_ids)

def shutdown(self):
mcgibbon (Contributor) commented Feb 18, 2026

Not for this PR, and not necessarily related to spatial parallelism, but we should think about defining a context manager for parallelism that makes sure cleanup happens when the context exits, kind of like we do with GlobalTimer. I don't think we currently call it properly in the unit tests, for example.
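
A minimal sketch of what such a context manager could look like, using the Distributed singleton and the shutdown() method visible in the diff above; the helper name and the import path are assumptions, and this is not code from the PR.

from contextlib import contextmanager

# Import path assumed from the PR description; adjust to the actual module.
from fme.core.distributed import Distributed


@contextmanager
def distributed_context():
    # Hypothetical helper, analogous to how GlobalTimer is managed: guarantee
    # that shutdown() runs even if the enclosed code raises.
    dist = Distributed.get_instance()
    try:
        yield dist
    finally:
        dist.shutdown()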

mahf708 (Author)

FWIW, the torch backend is the one with problematic teardown for some reason; the new one is slightly cleaner. I can look into why...

"""
dist = Distributed.get_instance()
global_shape = (2, 4, 4)
n_dp = dist.total_data_parallel_ranks
mcgibbon (Contributor)

Is this desired? I feel like we should set a constant shape like 4 that covers the cases we plan to run with, so that special values like 3 data parallel ranks are more interesting (which my test may need to be refactored to properly manage, idk).

mahf708 (Author)

For now, I think it is an easy cop-out to get things going and running with arbitrary tests. If we hardcode the batch dimension, we will need to instrument pytest skipping and such. I like the idea that these tests can run successfully with 5,000,000 ranks ;)
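
A sketch of the trade-off being described: scaling the test's leading (batch) dimension with the number of data-parallel ranks means every rank gets a non-empty shard at any world size, so no pytest skips are needed. The actual test's shape logic is not shown in full above; the scaling factor below is arbitrary and purely illustrative.

# Illustrative only; not the exact logic of the test in the diff.
dist = Distributed.get_instance()
n_dp = dist.total_data_parallel_ranks
# Leading dimension grows with the number of data-parallel ranks,
# so the test runs unchanged for any world size.
global_shape = (2 * n_dp, 4, 4)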

mcgibbon (Contributor)

Good tests!

"tensorly-torch",
"torch-harmonics==0.8.0",
"torch>=2",
"torch>=2.4.0",
mcgibbon (Contributor)

No action needed here; I think our dependency-only image already uses 2.7.1, and I don't believe there's anything in this PR that warrants rebuilding it again.

mahf708 (Author)

thanks!
