From 42ed91d3493275aa7a92e5f5ebbd585bbb2ba1e8 Mon Sep 17 00:00:00 2001 From: EvgeniyS99 Date: Tue, 11 Jul 2023 09:57:09 +0000 Subject: [PATCH 1/3] Added weighted BCE --- batchflow/models/torch/losses/binary.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/batchflow/models/torch/losses/binary.py b/batchflow/models/torch/losses/binary.py index 2c552d752..532f31c62 100644 --- a/batchflow/models/torch/losses/binary.py +++ b/batchflow/models/torch/losses/binary.py @@ -5,6 +5,8 @@ from torch import nn import torch.nn.functional as F +from .core import Weighted + class BCE(nn.Module): @@ -191,3 +193,27 @@ def forward(self, prediction, target): sensitivity = (squared_error * inverse).sum() / (inverse.sum() + self.eps) return self.r * specificity + (1 - self.r) * sensitivity + + + +class WeightedBCE(nn.Module): + """ Weighted BCE for the unbalanced data which computes weights dynamically """ + def __init__(self): + super().__init__() + + def forward(self, prediction, target): + mask = target.float() + num_positive = (mask == 1).sum() + num_negative = (mask == 0).sum() + + mask[mask == 1] = num_negative / (num_positive + num_negative) + mask[mask == 0] = num_positive / (num_positive + num_negative) + + loss = torch.zeros(1, device=prediction.device) + loss += F.binary_cross_entropy_with_logits(prediction, target, weight=mask) + + return loss + +losses = [WeightedBCE(), Dice()] +weights = [.5, .5] +comboloss = Weighted(losses, weights) \ No newline at end of file From 353c32d9671b2f48399cd94a19684e2347bdca79 Mon Sep 17 00:00:00 2001 From: EvgeniyS99 Date: Wed, 12 Jul 2023 08:03:50 +0000 Subject: [PATCH 2/3] Change naming --- batchflow/models/torch/losses/binary.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/batchflow/models/torch/losses/binary.py b/batchflow/models/torch/losses/binary.py index 532f31c62..d98ae400f 100644 --- a/batchflow/models/torch/losses/binary.py +++ b/batchflow/models/torch/losses/binary.py @@ -196,8 +196,8 @@ def forward(self, prediction, target): -class WeightedBCE(nn.Module): - """ Weighted BCE for the unbalanced data which computes weights dynamically """ +class BalancedWeightedBCE(nn.Module): + """ Balanced weighted BCE loss for the unbalanced data which computes weights dynamically """ def __init__(self): super().__init__() @@ -212,8 +212,4 @@ def forward(self, prediction, target): loss = torch.zeros(1, device=prediction.device) loss += F.binary_cross_entropy_with_logits(prediction, target, weight=mask) - return loss - -losses = [WeightedBCE(), Dice()] -weights = [.5, .5] -comboloss = Weighted(losses, weights) \ No newline at end of file + return loss \ No newline at end of file From 151f865462b56027afc56848ca6dc6aa5acbf62b Mon Sep 17 00:00:00 2001 From: EvgeniyS99 Date: Wed, 12 Jul 2023 08:07:21 +0000 Subject: [PATCH 3/3] Remove extra import --- batchflow/models/torch/losses/binary.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/batchflow/models/torch/losses/binary.py b/batchflow/models/torch/losses/binary.py index d98ae400f..9af1f70df 100644 --- a/batchflow/models/torch/losses/binary.py +++ b/batchflow/models/torch/losses/binary.py @@ -5,8 +5,6 @@ from torch import nn import torch.nn.functional as F -from .core import Weighted - class BCE(nn.Module):