From 99505d07459e36802af9fc3fd67939fc64c315d4 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Sat, 10 Jan 2026 16:04:01 +0800
Subject: [PATCH] perf: disable find_unused_parameters for faster DDP training

Problem:
find_unused_parameters=True causes DDP to traverse the autograd graph after
every forward pass to detect unused parameters. This adds CPU overhead and can
reduce communication/compute overlap.

Analysis:
Code review confirms this training loop does NOT have unused parameters:
- All backbone outputs flow to the loss via pfc() for all data strategies
  (residual, frame sampling, collage)
- Branching in the training loop is about input shape/indexing, not about
  skipping parameterized modules
- The only 'optional' module is the pooling head (gated by use_head), but the
  training code always uses pooler_output/head_output

static_graph=True is already set, which:
- Caches the used/unused parameter set after warmup
- Would mitigate the traversal overhead if find_unused_parameters were needed
- But the cleanest performance path is find_unused_parameters=False

Testing:
If this causes 'Expected to have finished reduction' errors:
1. Re-enable find_unused_parameters=True
2. Debug with TORCH_DISTRIBUTED_DEBUG=INFO to identify unused params
---
 training/train.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/training/train.py b/training/train.py
index d3139f7..6ee6f3f 100644
--- a/training/train.py
+++ b/training/train.py
@@ -370,13 +370,23 @@ def _expand(name, v):
     global_step = 0
 
     def wrap_ddp(model):
+        # Performance optimization: disable find_unused_parameters.
+        #
+        # find_unused_parameters=True adds overhead from extra autograd graph traversal.
+        # Analysis shows this training loop does NOT have unused parameters:
+        #   - All backbone outputs flow to the loss via pfc() for all data strategies
+        #   - Branching is about input shape/indexing, not skipping modules
+        #   - The only "optional" module is the pooling head (gated by use_head), always used
+        #
+        # If you encounter "Expected to have finished reduction" errors, re-enable
+        # find_unused_parameters=True and debug with TORCH_DISTRIBUTED_DEBUG=INFO.
         return torch.nn.parallel.DistributedDataParallel(
             module=model,
             broadcast_buffers=False,
             device_ids=[local_rank],
             bucket_cap_mb=32,
-            find_unused_parameters=True,
-            static_graph=True,
+            find_unused_parameters=False,  # disabled for performance
+            static_graph=True,  # enables further optimizations when the graph is fixed
         )
 
     backbone_ddp = wrap_ddp(backbone)
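
Note (suggested follow-up, not part of the patch): the "no unused parameters"
claim can be verified empirically before relying on
find_unused_parameters=False. The sketch below is a minimal, hypothetical
check, not project code: the Sequential model and tensor shapes are
placeholders, and in practice it should construct the real backbone (plus the
pfc head) and run one representative batch per data strategy (residual, frame
sampling, collage).

    import torch


    def find_params_without_grad(model, loss):
        # Names of trainable parameters that receive no gradient from `loss`.
        # Anything reported here is what DDP flags as an "unused parameter"
        # and what would break training under find_unused_parameters=False.
        model.zero_grad(set_to_none=True)
        loss.backward()
        return [name for name, p in model.named_parameters()
                if p.requires_grad and p.grad is None]


    if __name__ == "__main__":
        # Placeholder model and batch; substitute the real backbone and loss.
        model = torch.nn.Sequential(
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 4),
        )
        loss = model(torch.randn(8, 16)).sum()
        print("params without grad:",
              find_params_without_grad(model, loss) or "none")

If the list is non-empty for any strategy, keep find_unused_parameters=True and
run with TORCH_DISTRIBUTED_DEBUG=INFO (or DETAIL for per-parameter output) to
see which parameters are being skipped.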