Conversation
fme/core/distributed/distributed.py (outdated)
I'm not thrilled with these names or this general method of forcing; if we have to go this route, I'd prefer to call the new one "model" (since it can parallelize over all sorts of dims/tags).
For now, can you keep the behavior Oscar was going to use? That is, if (let's call it) FME_DISTRIBUTED_H or _W is set > 1, then the spatial backend is used. If you need a way to force it to be used, perhaps use the spatial backend if one or both is set, and only use the torch backend if both are unset? I don't think we currently need a way to force non-distributed from the CLI, so we shouldn't add that feature.
Mostly addressed, but I'd rather force people to pick one; we can streamline defaults later.
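For reference, a minimal sketch of the selection rule discussed above, assuming the (still hypothetical) FME_DISTRIBUTED_H / FME_DISTRIBUTED_W environment variables and using "model" and "torch" as placeholder backend names:

```python
import os


def _spatial_parallelism_requested() -> bool:
    # Hypothetical env vars from the discussion above; a value > 1 on either
    # dimension is treated as a request for the spatial ("model") backend.
    h = int(os.environ.get("FME_DISTRIBUTED_H", "1"))
    w = int(os.environ.get("FME_DISTRIBUTED_W", "1"))
    return h > 1 or w > 1


def select_backend_name() -> str:
    # Fall back to the plain torch data-parallel backend when neither is set.
    return "model" if _spatial_parallelism_requested() else "torch"
```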
```python
logger.debug("Barrier on rank %d", self._rank)
torch.distributed.barrier(device_ids=self._device_ids)


def shutdown(self):
```
Not for this PR, and not necessarily related to spatial parallelism, but we should think about defining a context manager for parallelism that makes sure cleanup happens when the context exits, kind of like we do with GlobalTimer. I think we currently don't call it properly in the unit tests, for example.
FWIW, the torch backend is the one with problematic teardown for some reason; the new one is slightly cleaner. I can look into why...
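For what it's worth, a rough sketch of the context-manager idea, assuming `Distributed` is importable from `fme.core.distributed` and exposes `get_instance()` and `shutdown()` (names as they appear in this diff; everything else is a placeholder):

```python
import contextlib

from fme.core.distributed import Distributed  # assumed import path


@contextlib.contextmanager
def distributed_context():
    # Yield the singleton and guarantee teardown when the block exits,
    # similar in spirit to how GlobalTimer scopes its lifetime.
    dist = Distributed.get_instance()
    try:
        yield dist
    finally:
        dist.shutdown()
```

Unit tests could then wrap their bodies in `with distributed_context() as dist:` so teardown runs even when an assertion fails.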
| """ | ||
| dist = Distributed.get_instance() | ||
| global_shape = (2, 4, 4) | ||
| n_dp = dist.total_data_parallel_ranks |
Is this desired? I feel like we should set a constant shape like 4 that covers the cases we plan to run with, so that special values like 3 data-parallel ranks are more interesting (my test may need to be refactored to properly manage that, idk).
For now, I think it's an easy cop-out to get things going and running with arbitrary tests. If we hardcode the batch dimension, we will need to instrument pytest skipping and such. I like the idea that these tests can run successfully with 5,000,000 ranks ;)
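To make the trade-off concrete, here is a sketch of the two approaches, reusing `Distributed.get_instance()` and `total_data_parallel_ranks` from the diff above; the test names and skip logic are illustrative only:

```python
import pytest

from fme.core.distributed import Distributed  # assumed import path


def test_with_rank_dependent_shape():
    # Current approach: derive the batch dimension from the number of
    # data-parallel ranks so the test works for any world size.
    dist = Distributed.get_instance()
    n_dp = dist.total_data_parallel_ranks
    global_shape = (n_dp, 4, 4)
    ...  # scatter/gather assertions elided


def test_with_fixed_shape():
    # Alternative: hardcode a batch of 4 and skip when the world size does
    # not divide it evenly, which makes rank counts like 3 interesting.
    dist = Distributed.get_instance()
    if 4 % dist.total_data_parallel_ranks != 0:
        pytest.skip("world size does not evenly divide the fixed batch of 4")
    global_shape = (4, 4, 4)
    ...  # scatter/gather assertions elided
```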
| "tensorly-torch", | ||
| "torch-harmonics==0.8.0", | ||
| "torch>=2", | ||
| "torch>=2.4.0", |
No action needed here; I think our dependency-only image already uses 2.7.1, and I don't believe there's anything in this PR that warrants rebuilding the image again.
add ModelTorchDistributed with tests
Changes:
- fme.core.distributed has a new ModelDistributedBackend that allows for parallelism over spatial dimensions as well as batch/data.
- torch is pinned with a minimum of 2.4.0 to use new facilities for distributed, etc. (see the sketch after this list).
- Tests added.
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated.
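As one concrete example of what the 2.4 pin enables (hypothetical here; the PR may rely on different facilities), `torch.distributed.device_mesh` can express the combined data x spatial layout directly. The mesh sizes and dimension names below are placeholders:

```python
from torch.distributed.device_mesh import init_device_mesh

# Sketch for a hypothetical 8-rank job: 2 data-parallel groups x 4 spatial ranks.
# init_device_mesh sets up the default process group from the usual torchrun env.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("data", "spatial"))
data_group = mesh["data"].get_group()        # process group for batch/data parallelism
spatial_group = mesh["spatial"].get_group()  # process group for the spatial split
```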
Closes #749
Closes #842