diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py
index 4a1e1f82ef..2dd5087651 100644
--- a/python/graphstorm/gsf.py
+++ b/python/graphstorm/gsf.py
@@ -549,8 +549,6 @@ def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
             alpha = config.alpha if config.alpha is not None else 0.25
             gamma = config.gamma if config.gamma is not None else 2.
             loss_func = FocalLossFunc(alpha, gamma)
-            # Focal loss expects 1-dimensional output
-            decoder_output_dim = 1
         else:
             raise RuntimeError(
                 f"Unknown classification loss {config.class_loss_func}")
@@ -590,7 +588,6 @@ def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
                 alpha = config.alpha if config.alpha is not None else 0.25
                 gamma = config.gamma if config.gamma is not None else 2.
                 loss_func[ntype] = FocalLossFunc(alpha, gamma)
-                decoder_output_dim = 1
             else:
                 raise RuntimeError(
                     f"Unknown classification loss {config.class_loss_func}")
@@ -711,7 +708,7 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
                     "Focal loss only works with binary classification. "
                     "num_classes should be set to 2."
                 )
-            decoder_output_dim = 1
+            decoder_output_dim = 2
         else:
             decoder_output_dim = num_classes
 
@@ -1490,7 +1487,7 @@ def restore_builtin_model_from_artifacts(model_dir, json_file, yaml_file):
     with the trained model weights, a JSON file that is the GConstruct configuration
     specification with data-derived transformations, and a YAML file that is the
     Graphstorm train configuration updated with runtime arguments.
-    
+
     This method uses the `GSMetaData` and `GSDglDistGraphFromMetadata` to create a
     lightweight graph that only contains graph structure, and then use it to restore a built-in
    GraphStorm model, and return both the model and the graph construction and model configuration objects.
diff --git a/python/graphstorm/model/loss_func.py b/python/graphstorm/model/loss_func.py
index a41b010e6e..16ccec8cb2 100644
--- a/python/graphstorm/model/loss_func.py
+++ b/python/graphstorm/model/loss_func.py
@@ -16,7 +16,6 @@
     Loss functions.
 """
 import logging
-import warnings
 
 import torch as th
 from torch import nn
@@ -123,13 +122,6 @@ def __init__(self, alpha=0.25, gamma=2):
         super(FocalLossFunc, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
-        # TODO: Focal loss should also produce (N, num_classes) output
-        if get_rank() == 0:
-            warnings.warn(
-                "Focal loss currently produces predictions with shape (N, 1) where N "
-                "is the number of targets. This behavior will change in v0.5.0"
-                "to produce predictions of shape (N, 2).",
-                FutureWarning)
 
     def forward(self, logits, labels):
         """ The forward function.
@@ -137,31 +129,42 @@ def forward(self, logits, labels):
         Parameters
         ----------
         logits: torch.Tensor
-            The prediction results.
+            The prediction results, of shape (N, 2), containing [score_class0, score_class1].
         labels: torch.Tensor
-            The training labels.
+            The training labels, of shape (N,).
 
         Returns
         -------
         loss: Tensor
             The loss value.
+
+        .. versionchanged:: 0.5.0
+            Expects logits of shape (N, 2) instead of (N, 1), matching other loss
+            functions (e.g. binary cross-entropy).
+
         """
-        # We need to reshape logits into a 1D float tensor
-        # and cast labels into a float tensor.
-        inputs = logits.squeeze()
+        # Extract the logits for the positive class (class 1)
+        inputs = logits[:, 1]  # Shape: (N,)
         targets = labels.float()
+        # Compute probabilities with a sigmoid
         pred = th.sigmoid(inputs)
+
+        # Compute the per-example binary cross-entropy from the logits
         ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+        # Get the probability assigned to the correct class
         pred_t = pred * targets + (1 - pred) * (1 - targets)
+
+        # Apply the focal modulation that down-weights easy examples
         loss = ce_loss * ((1 - pred_t) ** self.gamma)
 
+        # Apply alpha class balancing
         if self.alpha >= 0:
             alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
             loss = alpha_t * loss
-        loss = loss.mean()
-        return loss
+
+        return loss.mean()
 
     @property
     def in_dims(self):
diff --git a/tests/unit-tests/test_loss.py b/tests/unit-tests/test_loss.py
index cce0693a56..73967074db 100644
--- a/tests/unit-tests/test_loss.py
+++ b/tests/unit-tests/test_loss.py
@@ -89,26 +89,57 @@ def test_WeightedLinkPredictAdvBCELossFunc(num_pos, num_neg):
     assert_almost_equal(loss.numpy(),gt_loss.numpy())
 
 def test_FocalLossFunc():
+    # Test case 1: strong and weak predictions for both classes
     alpha = 0.25
-    gamma = 2.
+    gamma = 2.0
     loss_func = FocalLossFunc(alpha, gamma)
-    logits = th.tensor([[0.6330],[0.9946],[0.2322],[0.0115],[0.9159],[0.5752],[0.4491], [0.9231],[0.7170],[0.2761]])
-    labels = th.tensor([0, 0, 0, 1, 1, 1, 0, 1, 0, 0])
-    # Manually call the torchvision.ops.sigmoid_focal_loss to generate the loss value
-    gt_loss = th.tensor(0.1968)
-    loss = loss_func(logits, labels)
-    assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)
-
+    # Create (N, 2) logits holding scores for both classes
+    logits = th.tensor([
+        [2.0, -2.0],   # Strong prediction for class 0
+        [-3.0, 3.0],   # Strong prediction for class 1
+        [0.1, -0.1],   # Weak prediction for class 0
+        [-0.2, 0.2]    # Weak prediction for class 1
+    ])
+    labels = th.tensor([0, 1, 0, 1])
+
+    # Get our implementation's loss
+    our_loss = loss_func(logits, labels)
+
+    # Expected value precomputed with torchvision on the positive-class logits.
+    # To get this value we ran:
+    #     from torchvision.ops import sigmoid_focal_loss
+    #     tv_loss = sigmoid_focal_loss(
+    #         logits[:, 1],  # Take logits for the positive class
+    #         labels.float(),
+    #         alpha=alpha,
+    #         gamma=gamma,
+    #         reduction='mean'
+    #     )
+    tv_loss = th.tensor(0.0352)
+
+    assert_almost_equal(our_loss.numpy(), tv_loss.numpy(), decimal=4)
+
+    # Test case 2: the original 1-D test logits x expanded to (N, 2) rows [x, -x]
     alpha = 0.2
     gamma = 1.5
     loss_func = FocalLossFunc(alpha, gamma)
-    logits = th.tensor([2.8205, 0.4035, 0.8215, 1.9420, 0.2400, 2.8565, 1.8330, 0.7786, 2.0962, 1.0399])
+
+    logits_orig = th.tensor([
+        [2.8205, -2.8205], [0.4035, -0.4035], [0.8215, -0.8215],
+        [1.9420, -1.9420], [0.2400, -0.2400], [2.8565, -2.8565],
+        [1.8330, -1.8330], [0.7786, -0.7786], [2.0962, -2.0962],
+        [1.0399, -1.0399]
+    ])
     labels = th.tensor([0, 0, 1, 0, 1, 1, 1, 1, 0, 0])
-    # Manually call the torchvision.ops.sigmoid_focal_loss to generate the loss value
-    gt_loss = th.tensor(0.6040)
-    loss = loss_func(logits, labels)
-    assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)
+
+    # Get our implementation's loss
+    our_loss = loss_func(logits_orig, labels)
+
+    # Expected value precomputed with torchvision.ops.sigmoid_focal_loss as above
+    tv_loss = th.tensor(0.1335)
+
+    assert_almost_equal(our_loss.numpy(), tv_loss.numpy(), decimal=4)
 
 @pytest.mark.parametrize("num_pos", [1, 8, 32])
 @pytest.mark.parametrize("num_neg", [1, 8, 32])
@@ -192,13 +223,3 @@ def test_ShrinkageLossFunc():
     gt_loss = th.tensor(0.0692)
     loss = loss_func(logits, labels)
     assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)
-
-
-if __name__ == '__main__':
-    test_FocalLossFunc()
-    test_LinkPredictBPRLossFunc()
-    test_WeightedLinkPredictBPRLossFunc()
-    test_ShrinkageLossFunc()
-
-    test_LinkPredictAdvBCELossFunc(16, 128)
-    test_WeightedLinkPredictAdvBCELossFunc(16, 128)
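
Reviewer note: a minimal standalone sketch of the new calling convention, for sanity-checking the (N, 2) contract locally. It is not part of the patch; it assumes torchvision is available and that `FocalLossFunc` is importable from `graphstorm.model.loss_func`, and the variable names are illustrative.

    import torch as th
    from torchvision.ops import sigmoid_focal_loss
    from graphstorm.model.loss_func import FocalLossFunc

    alpha, gamma = 0.25, 2.0
    loss_func = FocalLossFunc(alpha, gamma)

    # The decoder now emits two columns per target: [score_class0, score_class1].
    logits = th.tensor([[2.0, -2.0],
                        [-3.0, 3.0]])
    labels = th.tensor([0, 1])

    ours = loss_func(logits, labels)
    # torchvision's focal loss consumes only the positive-class logits.
    ref = sigmoid_focal_loss(logits[:, 1], labels.float(),
                             alpha=alpha, gamma=gamma, reduction="mean")
    assert th.allclose(ours, ref, atol=1e-6)

As in the unit test above, FocalLossFunc applied to (N, 2) logits should match torchvision's sigmoid_focal_loss applied to the positive-class column, since both implement the same alpha-balanced focal modulation of binary cross-entropy.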