
add ModelTorchDistributed backend for spatial parallelism along with tests #842

Draft
mahf708 wants to merge 2 commits into ai2cm:main from E3SM-Project:e3sm/oscar/sp-distributed-class

Conversation


@mahf708 mahf708 commented Feb 16, 2026

Add a ModelTorchDistributed backend to begin enabling training and inference in a spatially parallelized context.

Changes:

  • In fme.core.distributed:

    • ModelTorchDistributed (a DistributedBackend) in the model_torch_distributed module
    • model_torch_distributed_comm contains comm utilities for ModelTorchDistributed
    • model_torch_distributed_utils contains tensor utilities for ModelTorchDistributed
  • Expanded fme.core.distributed.parallel_tests and verified with NPROC=1,2,3,4 (on 4xA100 pod)

  • Updated signatures of base classes as needed

  • Tests added

  • Optional (for now unpinned) dependency on physicsnemo

Closes #749


mahf708 commented Feb 16, 2026

This builds on top of @odiazib's PR (#749) and previous draft work (#719). I decided to significantly increase test coverage; arguably overkill, but I thought these tests would help with the planned upcoming development (data loading/writing, training operations, layering, etc.). There are some tricky (and not very pretty) parts in how we are borrowing comm utilities from makani, but they are hidden well enough for now. In particular, some nontrivial hacks were needed to make testing efficient and easy. Some of the tests are fairly trivial in nature (e.g., barrier), and they balloon the size of this PR by a fair bit. Hopefully we will keep future PRs to <1000 lines. Copying @odiazib, @elynnwu, and @mcgibbon for awareness. Feedback welcome.

Tested locally with make parallel_tests NPROC=X for X=1,2,3,4. Also ran the full test suite on A100 GPUs as in the CI files. All passing as far as I could tell (except for the timed tests running overtime).

run: |
python -m pip install uv
uv pip install --system -c constraints.txt -e .[dev]
uv pip install --system --no-build-isolation -c constraints.txt -e .[dev,spatial-parallelism]
Contributor

Is it possible to run the very-fast tests without spatial parallelism? It's nice to have a test that doesn't use them to make sure the "optional" part of optional dependency is done correctly.
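
A minimal sketch of how such a guard could look, assuming physicsnemo is the optional dependency and that the backend lives at fme.core.distributed.model_torch_distributed as the PR summary describes (the marker name and test are hypothetical):

import importlib.util

import pytest

HAS_PHYSICSNEMO = importlib.util.find_spec("physicsnemo") is not None

requires_spatial_parallelism = pytest.mark.skipif(
    not HAS_PHYSICSNEMO,
    reason="physicsnemo (optional spatial-parallelism extra) is not installed",
)


@requires_spatial_parallelism
def test_spatial_backend_is_importable():
    # Only runs when the optional extra is installed; the very-fast suite
    # still passes on a base install because this test is skipped there.
    from fme.core.distributed.model_torch_distributed import ModelTorchDistributed

    assert hasattr(ModelTorchDistributed, "is_available")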



@requires_parallel
def test_gather_raises_not_implemented(monkeypatch):
Contributor

Suggestion: Make this a unit test of the ModelTorchDistributed backend, rather than an integration test involving Distributed.
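
A rough sketch of the suggested unit test; constructing ModelTorchDistributed directly and the gather call signature are assumptions here, since the actual backend may need a fixture to set it up:

import pytest
import torch

from fme.core.distributed.model_torch_distributed import ModelTorchDistributed


def test_gather_raises_not_implemented_unit():
    # Exercise the backend directly rather than going through Distributed.
    backend = ModelTorchDistributed()  # assumption: constructible in a test context
    with pytest.raises(NotImplementedError):
        backend.gather(torch.zeros(1))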



@requires_parallel
def test_gather_irregular_raises_not_implemented(monkeypatch):
Contributor

Question: Do we need tests for these NotImplementedError's? We do plan to implement them.

else:
monkeypatch.delenv("W_PARALLEL_SIZE", raising=False)

result = ModelTorchDistributed.is_available()
Contributor

Suggestion: If the logic can be refactored to a helper in some way, or you have is_available take input arguments that default to the environment variables, you can unit test this without environment monkeypatching.
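
A sketch of that refactor, assuming for illustration that availability is keyed on the W_PARALLEL_SIZE environment variable used in the test above (the PR's actual availability logic may differ):

import os


class ModelTorchDistributed:
    @classmethod
    def is_available(cls, w_parallel_size: str | None = None) -> bool:
        # Default to the environment, but let callers (and tests) pass the
        # value explicitly so no monkeypatching is needed.
        if w_parallel_size is None:
            w_parallel_size = os.environ.get("W_PARALLEL_SIZE")
        return w_parallel_size is not None and int(w_parallel_size) > 1


def test_is_available_with_explicit_value():
    assert ModelTorchDistributed.is_available(w_parallel_size="2")
    assert not ModelTorchDistributed.is_available(w_parallel_size="1")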

return singleton

@classmethod
def reset(cls) -> None:
Contributor

I am nervous about something like this, because distributed backends generally use global state which doesn't get properly reset when we do something like this.

"""
return self._distributed.total_ranks

def get_local_rank(self) -> int:
Contributor

Question: What is a "local" rank?

def get_local_rank(self) -> int:
return self._distributed.get_local_rank()

def get_sampler(
Contributor

It would be great to add a test of this; I think you're right about the changes.

return self._distributed.local_batch_size(batch_size)

def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
def reduce_mean(self, tensor: torch.Tensor, group=None) -> torch.Tensor:
Contributor

Suggestion: We could abstract away concerns like specific group names to make it easier on users. For example, can we use a data_parallel_only: bool = True argument instead? or a kind: Literal["data_parallel", "model", "all"] = "data_parallel"?
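
A sketch of what that interface could look like; the _comm_group_for helper that maps the user-facing kind onto a process group is hypothetical:

from typing import Literal

import torch
import torch.distributed as dist


class ModelTorchDistributed:
    def reduce_mean(
        self,
        tensor: torch.Tensor,
        kind: Literal["data_parallel", "model", "all"] = "data_parallel",
    ) -> torch.Tensor:
        # Resolve the high-level kind to a process group internally so group
        # names never leak into the public interface.
        group = self._comm_group_for(kind)  # hypothetical helper
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor / dist.get_world_size(group=group)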

@@ -169,26 +191,28 @@ def get_local_slices(

def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
Contributor

Generally speaking we need reductions along the data-parallel dimension to get global maps of mean outputs in our aggregators. We probably also need to do global area-weighted means for certain operations in the correctors.

Author

We also need to reduce over other orthogonal dimensions (say, for zonal means), but that can wait a bit.

assert dist.local_batch_size(global_batch) == expected


def test_local_batch_size_not_divisible():
Contributor

Does this test need to be run in a parallel context or would it work in serial?

def get_local_slices(
self,
tensor_shape,
rank: int | None = None,
Contributor

Issue: Both the rank and data_parallel_dim arguments here are ignored.

Suggestion: I have this PR to let the rank argument get removed: #839

I don't know what needs to be done yet about data_parallel_dim.

Comment on lines +338 to +345
def comm_get_size(self, key: str):
return self._distributed.comm_get_size(key)

def comm_get_group(self, key: str):
return self._distributed.comm_get_group(key)

def comm_get_rank(self, key: str):
return self._distributed.comm_get_rank(key)
Contributor

Issue: These are low-level operations. Can we hide them inside backend methods? Other backends don't have a comm.

Comment on lines +43 to +44
from physicsnemo.distributed.manager import DistributedManager
from physicsnemo.distributed.config import ProcessGroupNode, ProcessGroupConfig
Contributor

@mcgibbon mcgibbon Feb 17, 2026

If this is all we're using from physicsnemo, we could consider vendorizing/forking the code, i.e. copy-pasting it into a subdirectory here. The last time I checked these seemed pretty isolated / didn't depend on other infrastructure in physicsnemo, and it would avoid a dependency.

return 1

def comm_get_group(self, key: str):
return None
Contributor

Question: Is this a valid return value for comm_get_group?

...

@abstractmethod
def comm_get_size(self, key: str): ...
Contributor

Issue: These need return value type hints.
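
For example (the ProcessGroup | None return on comm_get_group is an assumption, prompted by the question above about the serial backend returning None):

from abc import ABC, abstractmethod

from torch.distributed import ProcessGroup


class DistributedBackend(ABC):
    @abstractmethod
    def comm_get_size(self, key: str) -> int: ...

    @abstractmethod
    def comm_get_rank(self, key: str) -> int: ...

    @abstractmethod
    def comm_get_group(self, key: str) -> ProcessGroup | None: ...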

def shutdown(self):
return self._distributed.shutdown()

def comm_get_size(self, key: str):
Contributor

These comm_ methods are only used in tests; do we need them, and do we need them to be public on Distributed?

return batch_size

def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
def reduce_mean(self, tensor: torch.Tensor, group=None) -> torch.Tensor:
Contributor

I would like to avoid adding low-level information like the group names above this PR, but we should talk about it. For an initial PR we should avoid features that need this.

from fme.core.cuhpx.sht import SHT as CuHpxSHT
from fme.core.cuhpx.sht import iSHT as CuHpxiSHT
from fme.core.device import get_device
from fme.core.distributed import Distributed
Contributor

These changes to gridded_ops and their associated tests in parallel_tests can/probably should be split into their own PR.

@@ -0,0 +1,81 @@
import logging
Contributor

This test should probably be split off with the gridded_ops changes into its own PR.

@@ -0,0 +1,82 @@
import pytest
Contributor

What coverage does this file add that isn't already covered by test_local_slices.py?

Contributor

It would be nice if we could test this code in a unit test without the full process of setting up Distributed. But it's probably a test of DistributedManager.


# Create test data
data_tensor_host = torch.randn(1, 2, nx, ny, device="cpu")
area_weights_host = torch.ones(nx, ny).to("cpu") * 5
Contributor

Issue: Please update area weights to not use ones (maybe 1 plus a random uniform?) to cover how area is applied.
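
One way to follow that suggestion (array sizes here are illustrative): strictly positive but non-uniform weights, so the area weighting cannot silently cancel out the way a constant field does.

import torch

nx, ny = 8, 16  # illustrative; keep whatever sizes the test already uses
data_tensor_host = torch.randn(1, 2, nx, ny, device="cpu")
# 1 + uniform noise: non-constant, strictly positive area weights.
area_weights_host = 1.0 + torch.rand(nx, ny, device="cpu")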

# depending on the batch/data parallel index/rank.
x_global_ranked = x_global_base + dist.data_parallel_rank
x_local_ranked = x_global_ranked[dist.get_local_slices(global_shape, dist.rank)]
x_local_reduced = dist.reduce_mean(x_local_ranked)
Contributor

I don't understand how this test passes under spatial parallelism, given this isn't a reduction over the data-parallel group, and the test is supposed to require that it is.

Author

This never passes under spatial parallelism :) because it is slyly being skipped by the pytest hackery 😺 (which we will get rid of in the next edit).

Author

For the record, I am not surprised at all that the CPU tests bombed. I didn't run a single one of these tests, and I totally forgot about that until after I stopped working on this. The revision will address those too.

@mahf708 mahf708 marked this pull request as draft February 18, 2026 19:41