
@Luodian Luodian commented Jan 10, 2026

Summary

  • Disable find_unused_parameters in DDP wrapper for performance improvement
  • Add documentation explaining the reasoning and how to debug if issues arise

Problem

find_unused_parameters=True adds overhead (the proposed configuration is sketched after this list):

  • Extra autograd graph traversal after every forward pass
  • Reduced comm/compute overlap
  • CPU overhead for large models
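
For reference, the intended wrapper call looks roughly like the sketch below. This is illustrative only: the process-group setup and the placeholder model stand in for the actual code in training/train.py.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

# Placeholder model standing in for the real backbone + pfc head.
model = torch.nn.Linear(512, 512).cuda(local_rank)

model = DDP(
    model,
    device_ids=[local_rank],
    find_unused_parameters=False,  # no unused params in this loop, so skip the extra traversal
    static_graph=True,             # already set; the usage pattern is fixed after the first step
)
```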

Analysis

Code review confirms this training loop does NOT have unused parameters:

| Check | Result |
| --- | --- |
| All backbone outputs flow to loss? | ✅ Yes, via pfc() for all strategies |
| Branching skips modules? | ❌ No, branching affects input shape/indexing only |
| Optional modules exist? | Only the pooling head (use_head), and it is always used in training |

What find_unused_parameters does

Forward pass → DDP traverses the autograd graph → marks unused parameters as "ready" so gradient reduction does not wait on them

If no params are unused, this traversal is pure overhead.
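
One way to double-check that claim outside of DDP is to run a single forward/backward pass and list the trainable parameters that received no gradient. The snippet below is an illustrative check, not part of this PR; build_model, train_loader, and compute_loss are stand-ins for the project's own code.

```python
# Hypothetical sanity check on a single GPU, without DDP: after one
# forward/backward pass, any trainable parameter whose .grad is still None
# was not used in the loss computation.
model = build_model()                 # assumption: the un-wrapped training model
batch = next(iter(train_loader))      # assumption: one representative batch

loss = compute_loss(model, batch)     # assumption: same loss path as training
loss.backward()

unused = [name for name, p in model.named_parameters()
          if p.requires_grad and p.grad is None]
print("unused parameters:", unused or "none")
```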

What static_graph=True does

Caches the parameter usage pattern after warmup, mitigating some overhead.
But if no params are unused, find_unused_parameters=False is still cleaner.

Testing Instructions

If this change causes errors:

```
# Error you might see:
RuntimeError: Expected to have finished reduction in the prior iteration...

# Debug with:
TORCH_DISTRIBUTED_DEBUG=INFO torchrun ... --max_steps 10
```

The debug output will show which parameters didn't receive gradients.
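
As an aside (not part of this change), the same debug level can also be set from Python; setting the environment variable before the process group is initialized should ensure it takes effect:

```python
import os
import torch.distributed as dist

# Assumption: TORCH_DISTRIBUTED_DEBUG is read during distributed setup, so set
# it before init_process_group. "DETAIL" gives more verbose per-parameter logs.
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")
dist.init_process_group(backend="nccl")
```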

Rollback

If issues occur, simply change the line back to:

```python
find_unused_parameters=True,  # Re-enabled due to [reason]
```

Files Changed

  • training/train.py - DDP configuration with documentation

Problem:
find_unused_parameters=True causes DDP to traverse the autograd graph after
every forward pass to detect unused parameters. This adds CPU overhead and
can reduce comm/compute overlap.

Analysis:
Code review confirms this training loop does NOT have unused parameters:
- All backbone outputs flow to loss via pfc() for all data strategies
  (residual, frame sampling, collage)
- Branching in the training loop is about input shape/indexing, not
  skipping parameterized modules
- The only 'optional' module is the pooling head (gated by use_head),
  but training code always uses pooler_output/head_output

static_graph=True is already set, which:
- Caches the used/unused parameter set after warmup
- Mitigates overhead if find_unused_parameters was needed
- But the cleanest performance path is find_unused_parameters=False

Testing:
If this causes 'Expected to have finished reduction' errors:
1. Re-enable find_unused_parameters=True
2. Debug with TORCH_DISTRIBUTED_DEBUG=INFO to identify unused params