FIX: accumulate eval loss across all iterations #558
Conversation
Previously, the evaluation loss was computed per iteration and overwritten, leading to incorrect averaging when multiple eval iterations are used. This fix accumulates the numerator and denominator separately across all eval iterations and computes the final average at the end.
Pull request overview
Fixes evaluation loss averaging when eval_iters > 1 by accumulating loss numerators/denominators across all evaluation iterations and computing the final average once at the end.
Changes:
- Replace per-iteration overwrite of `total_loss_dict` with cross-iteration accumulation of loss numerators/denominators.
- Compute the final averaged evaluation loss after completing all eval iterations.
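The accumulation pattern described above can be sketched as follows. This is an illustrative standalone version, not the exact Primus code: the function name `accumulate_eval_loss` and the `(loss_sum, num_tokens)` per-key pairs are assumptions based on the PR description.

```python
def accumulate_eval_loss(loss_dicts):
    """Sketch: average eval loss across all eval iterations instead of
    overwriting it each iteration. `loss_dicts` yields one dict per eval
    iteration, mapping key -> (loss_sum, num_tokens)."""
    total_loss_numerators = {}
    total_loss_denominators = {}
    for loss_dict in loss_dicts:
        for key, (loss_sum, num_tokens) in loss_dict.items():
            # Accumulate numerator and denominator separately ...
            total_loss_numerators[key] = total_loss_numerators.get(key, 0.0) + loss_sum
            total_loss_denominators[key] = total_loss_denominators.get(key, 0.0) + num_tokens
    # ... and compute the final average once, after all eval iterations.
    return {
        key: total_loss_numerators[key] / total_loss_denominators[key]
        for key in total_loss_numerators
    }
```

With two iterations contributing `(10.0, 5)` and `(2.0, 5)` for the same key, this yields `12.0 / 10 = 1.2` rather than the last iteration's per-iteration average.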
```python
if total_loss_denominators[key] > 0:
    total_loss_dict[key] = total_loss_numerators[key] / total_loss_denominators[key]
else:
    total_loss_dict[key] = 0.0
```
`total_loss_denominators[key]` may be a `torch.Tensor` (e.g., token-count scalars in the `(loss_sum, num_tokens)` reporting path). In that case, `if total_loss_denominators[key] > 0:` can raise, since only a single-element Tensor can be coerced to a Python boolean. Consider removing this guard (division by zero shouldn't happen) or making it Tensor-safe (e.g., check `denom.item()` when `denom.numel() == 1`, and keep the zero case on the same device/dtype).
Suggested change:

```diff
-if total_loss_denominators[key] > 0:
-    total_loss_dict[key] = total_loss_numerators[key] / total_loss_denominators[key]
-else:
-    total_loss_dict[key] = 0.0
+denom = total_loss_denominators[key]
+numer = total_loss_numerators[key]
+# Handle both Python scalars and torch.Tensors safely.
+if isinstance(denom, torch.Tensor):
+    if denom.numel() == 1:
+        # Scalar tensor: use .item() to decide on zero/non-zero.
+        if denom.item() > 0:
+            total_loss_dict[key] = numer / denom
+        else:
+            # Keep zero on the same device/dtype as the loss/denominator.
+            if isinstance(numer, torch.Tensor):
+                total_loss_dict[key] = torch.zeros_like(numer)
+            else:
+                total_loss_dict[key] = torch.zeros(
+                    (), device=denom.device, dtype=denom.dtype
+                )
+    else:
+        # Non-scalar tensor denominator: assume valid and divide directly.
+        total_loss_dict[key] = numer / denom
+else:
+    # Non-tensor (Python scalar) path, preserve original behavior.
+    if denom > 0:
+        total_loss_dict[key] = numer / denom
+    else:
+        total_loss_dict[key] = 0.0
```
```python
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
    numerator += val[0]
    denominator += val[1]
else:
    # legacy behavior. we average over the number of microbatches,
    # and so the denominator is 1.
    numerator += val
    denominator += 1
```
The microbatch loss aggregation only treats tuple/list values as (numerator, denominator). In training, the same aggregation also supports the upstream Megatron format where the value is a 2-element torch.Tensor (val.numel() == 2) containing [loss_sum, num_tokens] (see primus/modules/trainer/megatron/trainer.py). If evaluation receives that format, this code will fall back to the legacy branch and compute an incorrect denominator. Please add the same 2-element tensor handling here to keep eval loss averaging correct.
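A rough sketch of the extra branch this comment asks for, with all three value formats handled in one place. This is illustrative, not the trainer's actual code; the duck-typed `numel` check stands in for `isinstance(val, torch.Tensor)` so the sketch runs without torch installed:

```python
def accumulate(val, numerator, denominator):
    """Sketch: accumulate one microbatch loss value, supporting
    (numerator, denominator) tuples/lists, the upstream Megatron
    2-element tensor [loss_sum, num_tokens], and legacy scalars."""
    if isinstance(val, (tuple, list)):
        # New reporting: explicit (loss_sum, num_tokens) pair.
        numerator += val[0]
        denominator += val[1]
    elif hasattr(val, "numel") and val.numel() == 2:
        # Upstream Megatron format: 2-element tensor [loss_sum, num_tokens].
        # (In real code this branch would check isinstance(val, torch.Tensor).)
        numerator += float(val[0])
        denominator += float(val[1])
    else:
        # Legacy behavior: average over microbatches, denominator counts batches.
        numerator += float(val)
        denominator += 1
    return numerator, denominator
```

Without the middle branch, a `[loss_sum, num_tokens]` tensor falls into the legacy path and the whole 2-element tensor is added to the numerator while the denominator only grows by 1, which is exactly the incorrect averaging the comment describes.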