diff --git a/training/train.py b/training/train.py
index d3139f7..6ee6f3f 100644
--- a/training/train.py
+++ b/training/train.py
@@ -370,13 +370,23 @@ def _expand(name, v):
 global_step = 0
 
 def wrap_ddp(model):
+    # Performance optimization: disable find_unused_parameters
+    #
+    # find_unused_parameters=True adds overhead from an extra traversal of the autograd graph.
+    # Analysis shows this training loop does NOT have unused parameters:
+    #   - All backbone outputs flow to the loss via pfc() for every data strategy
+    #   - Branching is about input shape/indexing, not about skipping modules
+    #   - The only "optional" module is the pooling head (gated by use_head), and it is always used
+    #
+    # If you encounter "Expected to have finished reduction" errors, re-enable
+    # find_unused_parameters=True and debug with TORCH_DISTRIBUTED_DEBUG=INFO
     return torch.nn.parallel.DistributedDataParallel(
         module=model,
         broadcast_buffers=False,
         device_ids=[local_rank],
         bucket_cap_mb=32,
-        find_unused_parameters=True,
-        static_graph=True,
+        find_unused_parameters=False,  # Disabled for performance
+        static_graph=True,  # Enables further optimizations when the graph is fixed
     )
 
 backbone_ddp = wrap_ddp(backbone)
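
For reviewers who want to sanity-check the "no unused parameters" claim before relying on find_unused_parameters=False, here is a minimal sketch that runs one backward pass on the unwrapped model and fails if any trainable parameter received no gradient. It is illustrative only and not part of the patch: assert_all_params_used is a hypothetical helper, and the backbone/pfc/batch names in the usage comment stand in for this repo's actual objects.

```python
# Minimal sketch (not part of the patch): verify that every trainable parameter
# receives a gradient after one backward pass, which is the condition DDP needs
# when find_unused_parameters=False. Run it single-process, before DDP wrapping.
import torch


def assert_all_params_used(model: torch.nn.Module, loss: torch.Tensor) -> None:
    """Raise if any trainable parameter of `model` did not receive a gradient."""
    model.zero_grad(set_to_none=True)  # clear grads so stale values cannot hide unused params
    loss.backward()                    # populates .grad only for parameters in the graph
    unused = [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]
    if unused:
        raise RuntimeError(f"Parameters with no gradient (unused in graph): {unused}")


# Hypothetical usage, assuming `backbone`, `pfc`, `batch`, and `labels` exist:
#     embeddings = backbone(batch)
#     loss = pfc(embeddings, labels)
#     assert_all_params_used(backbone, loss)
```

If the check passes on representative batches from each data strategy, find_unused_parameters=False should be safe; if a reduction error still appears under DDP, TORCH_DISTRIBUTED_DEBUG=DETAIL gives per-parameter detail beyond the INFO level mentioned in the comment above.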