diff --git a/batchflow/models/torch/losses/binary.py b/batchflow/models/torch/losses/binary.py index 2c552d752..9af1f70df 100644 --- a/batchflow/models/torch/losses/binary.py +++ b/batchflow/models/torch/losses/binary.py @@ -191,3 +191,23 @@ def forward(self, prediction, target): sensitivity = (squared_error * inverse).sum() / (inverse.sum() + self.eps) return self.r * specificity + (1 - self.r) * sensitivity + + + +class BalancedWeightedBCE(nn.Module): + """ Balanced weighted BCE loss for the unbalanced data which computes weights dynamically """ + def __init__(self): + super().__init__() + + def forward(self, prediction, target): + mask = target.float() + num_positive = (mask == 1).sum() + num_negative = (mask == 0).sum() + + mask[mask == 1] = num_negative / (num_positive + num_negative) + mask[mask == 0] = num_positive / (num_positive + num_negative) + + loss = torch.zeros(1, device=prediction.device) + loss += F.binary_cross_entropy_with_logits(prediction, target, weight=mask) + + return loss \ No newline at end of file