add ModelTorchDistributed backend for spatial parallelism along with tests #842
mahf708 wants to merge 2 commits into ai2cm:main
Conversation
This builds on top of @odiazib's PR (#749) and previous draft work (#719). I decided to significantly increase test coverage; arguably overkill, but I thought these tests would help with the planned upcoming dev (data loading/writing, training operations, layering, etc.). There are some tricky (and not very pretty) parts in how we are borrowing comm utilities from makani, but they are hidden enough for now. In particular, some nontrivial hacks were needed to facilitate efficient/easy testing. Some of the tests are fairly trivial in nature (e.g., barrier) and they balloon the size of this PR by a fair bit. Hopefully we will keep future PRs to <1000 lines. Copying @odiazib, @elynnwu, and @mcgibbon for awareness. Feedback welcome. Tested locally with
run: |
  python -m pip install uv
  uv pip install --system -c constraints.txt -e .[dev]
  uv pip install --system --no-build-isolation -c constraints.txt -e .[dev,spatial-parallelism]
Is it possible to run the very-fast tests without spatial parallelism? It's nice to have a test that doesn't use it, to make sure the "optional" part of the optional dependency is done correctly.
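One way to keep that property exercised on the test side (a minimal sketch; whether a module-level skip or a separate CI job without the extra is the right place is up to you):

```python
import pytest

# A minimal sketch: skip the spatially parallel tests when the optional extra
# is missing, so the rest of the suite still runs without it.
physicsnemo = pytest.importorskip("physicsnemo")
```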
@requires_parallel
def test_gather_raises_not_implemented(monkeypatch):
Suggestion: Make this a unit test of the ModelTorchDistributed backend, rather than an integration test involving Distributed.
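A rough sketch of what that might look like (the import path, constructor call, and gather signature are assumptions about the backend's interface, not the actual code):

```python
import pytest
import torch

from fme.core.distributed import ModelTorchDistributed  # assumed import path


def test_gather_raises_not_implemented_unit():
    # Exercise the backend class directly rather than going through Distributed.
    backend = ModelTorchDistributed()  # assumed to be constructible in a test context
    with pytest.raises(NotImplementedError):
        backend.gather(torch.zeros(2, 2))
```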
@requires_parallel
def test_gather_irregular_raises_not_implemented(monkeypatch):
Question: Do we need tests for these NotImplementedErrors? We do plan to implement them.
else:
    monkeypatch.delenv("W_PARALLEL_SIZE", raising=False)

result = ModelTorchDistributed.is_available()
Suggestion: If the logic can be refactored into a helper in some way, or if is_available takes input arguments that default to the environment variables, you can unit test this without environment monkeypatching.
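A rough sketch of the second option (the availability check shown here is illustrative only, not the actual logic):

```python
import os


class ModelTorchDistributed:
    @classmethod
    def is_available(
        cls,
        w_parallel_size: str | None = None,  # defaults to the environment
    ) -> bool:
        # Read the environment only when the caller didn't supply a value,
        # so unit tests can pass values directly instead of monkeypatching.
        if w_parallel_size is None:
            w_parallel_size = os.environ.get("W_PARALLEL_SIZE", "1")
        return int(w_parallel_size) > 1
```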
return singleton

@classmethod
def reset(cls) -> None:
I am nervous about something like this, because distributed backends generally use global state which doesn't get properly reset when we do something like this.
| """ | ||
| return self._distributed.total_ranks | ||
|
|
||
| def get_local_rank(self) -> int: |
Question: What is a "local" rank?
def get_local_rank(self) -> int:
    return self._distributed.get_local_rank()

def get_sampler(
It would be great to add a test of this; I think you're right about the changes.
return self._distributed.local_batch_size(batch_size)

-def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
+def reduce_mean(self, tensor: torch.Tensor, group=None) -> torch.Tensor:
Suggestion: We could abstract away concerns like specific group names to make it easier on users. For example, can we use a data_parallel_only: bool = True argument instead? Or a kind: Literal["data_parallel", "model", "all"] = "data_parallel"?
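A rough sketch of the second option (the _group_for helper is hypothetical, just to show where the mapping from kind to process group would live):

```python
from typing import Literal

import torch


def reduce_mean(
    self,
    tensor: torch.Tensor,
    kind: Literal["data_parallel", "model", "all"] = "data_parallel",
) -> torch.Tensor:
    # Callers name a reduction scope; the backend resolves it to a process
    # group internally, so group handles never leak into user code.
    group = self._group_for(kind)  # hypothetical helper
    ...
```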
@@ -169,26 +191,28 @@ def get_local_slices(

def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
Generally speaking we need reductions along the data-parallel dimension to get global maps of mean outputs in our aggregators. We probably also need to do global area-weighted means for certain operations in the correctors.
We also need to reduce over other orthogonal dimensions (say, for zonal means), but that can wait a bit.
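For context, a generic torch.distributed sketch of a reduction restricted to one process group (not this repo's API; the group is assumed to have been created with dist.new_group over the data-parallel ranks):

```python
import torch
import torch.distributed as dist


def reduce_mean_over_group(tensor: torch.Tensor, group) -> torch.Tensor:
    # Sum across the ranks of the given group only, then divide by that
    # group's size to get a mean over just the data-parallel dimension.
    out = tensor.clone()
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
    out /= dist.get_world_size(group=group)
    return out
```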
assert dist.local_batch_size(global_batch) == expected


def test_local_batch_size_not_divisible():
Does this test need to be run in a parallel context or would it work in serial?
def get_local_slices(
    self,
    tensor_shape,
    rank: int | None = None,
Issue: Both the rank and data_parallel_dim arguments here are ignored.
Suggestion: I have this PR to remove the rank argument: #839
I don't know yet what needs to be done about data_parallel_dim.
def comm_get_size(self, key: str):
    return self._distributed.comm_get_size(key)

def comm_get_group(self, key: str):
    return self._distributed.comm_get_group(key)

def comm_get_rank(self, key: str):
    return self._distributed.comm_get_rank(key)
Issue: These are low-level operations. Can we hide them inside backend methods? Other backends don't have a comm.
from physicsnemo.distributed.manager import DistributedManager
from physicsnemo.distributed.config import ProcessGroupNode, ProcessGroupConfig
If this is all we're using from physicsnemo, we could consider vendorizing/forking the code, i.e., copy-pasting it into a subdirectory here. The last time I checked, these pieces seemed pretty isolated and didn't depend on other infrastructure in physicsnemo, and vendorizing would avoid a dependency.
return 1

def comm_get_group(self, key: str):
    return None
Question: Is this a valid return value for comm_get_group?
...

@abstractmethod
def comm_get_size(self, key: str): ...
Issue: These need return value type hints.
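For instance (the int return types and the base-class name are assumptions based on how these methods are used elsewhere in the PR):

```python
from abc import ABC, abstractmethod


class DistributedBackend(ABC):  # hypothetical base-class name
    @abstractmethod
    def comm_get_size(self, key: str) -> int: ...

    @abstractmethod
    def comm_get_rank(self, key: str) -> int: ...
```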
def shutdown(self):
    return self._distributed.shutdown()

def comm_get_size(self, key: str):
These comm_ methods are only used in tests; do we need them, and do they need to be public on Distributed?
return batch_size

-def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
+def reduce_mean(self, tensor: torch.Tensor, group=None) -> torch.Tensor:
I would like to avoid adding low-level information like the group names above this PR, but we should talk about it. For an initial PR we should avoid features that need this.
from fme.core.cuhpx.sht import SHT as CuHpxSHT
from fme.core.cuhpx.sht import iSHT as CuHpxiSHT
from fme.core.device import get_device
from fme.core.distributed import Distributed
These changes to gridded_ops and their associated tests in parallel_tests can/probably should be split into their own PR.
@@ -0,0 +1,81 @@
import logging
This test should probably be split off with the gridded_ops changes into its own PR.
@@ -0,0 +1,82 @@
import pytest
What coverage does this file add that isn't already covered by test_local_slices.py?
It would be nice if we could test this code in a unit test without the full process of setting up Distributed. But it's probably a test of DistributedManager.
# Create test data
data_tensor_host = torch.randn(1, 2, nx, ny, device="cpu")
area_weights_host = torch.ones(nx, ny).to("cpu") * 5
Issue: Please update area weights to not use ones (maybe 1 plus a random uniform?) to cover how area is applied.
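For example, something along these lines (the exact distribution and the placeholder sizes are illustrative only):

```python
import torch

nx, ny = 8, 16  # placeholder sizes; the test defines its own
# Non-uniform, strictly positive weights so area weighting is actually exercised.
area_weights_host = 1.0 + torch.rand(nx, ny, device="cpu")
```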
# depending on the batch/data parallel index/rank.
x_global_ranked = x_global_base + dist.data_parallel_rank
x_local_ranked = x_global_ranked[dist.get_local_slices(global_shape, dist.rank)]
x_local_reduced = dist.reduce_mean(x_local_ranked)
I don't understand how this test passes under spatial parallelism, given this isn't a reduction over the data-parallel group, and the test is supposed to require that it is.
this never passes under sp :) because it is slyly being skipped by the pytest hackery 😺 (which we will get rid of in the next edit)
For the record, I am not surprised at all that the CPU tests bombed. I didn't even run a single one of these tests, and I totally forgot about that until after I stopped working on this. I will address those in the revision too.
Add a ModelTorchDistributed backend to begin enabling training and inference in a spatially parallelized context.

Changes:
- In fme.core.distributed:
  - Expanded fme.core.distributed.parallel_tests and verified with NPROC=1,2,3,4 (on a 4xA100 pod)
  - Updated signatures of base classes as needed
  - Tests added
- Optional (for now unpinned) dependency on physicsnemo

Closes #749