7 changes: 2 additions & 5 deletions python/graphstorm/gsf.py
@@ -549,8 +549,6 @@ def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
alpha = config.alpha if config.alpha is not None else 0.25
gamma = config.gamma if config.gamma is not None else 2.
loss_func = FocalLossFunc(alpha, gamma)
# Focal loss expects 1-dimensional output
decoder_output_dim = 1
else:
raise RuntimeError(
f"Unknown classification loss {config.class_loss_func}")
@@ -590,7 +588,6 @@ def create_builtin_node_decoder(g, decoder_input_dim, config, train_task):
alpha = config.alpha if config.alpha is not None else 0.25
gamma = config.gamma if config.gamma is not None else 2.
loss_func[ntype] = FocalLossFunc(alpha, gamma)
decoder_output_dim = 1
else:
raise RuntimeError(
f"Unknown classification loss {config.class_loss_func}")
@@ -711,7 +708,7 @@ def create_builtin_edge_decoder(g, decoder_input_dim, config, train_task):
"Focal loss only works with binary classification. "
"num_classes should be set to 2."
)
decoder_output_dim = 1
decoder_output_dim = 2
else:
decoder_output_dim = num_classes
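For context, a minimal sketch of why the decoder now emits two columns under focal loss: the loss picks out the positive-class column, so the decoder's output matches the (N, 2) convention of the other classification losses. Tensor values here are illustrative, not taken from GraphStorm.

```python
import torch as th

# Illustrative only: a binary-classification decoder now emits (N, 2) logits
# under focal loss, matching cross-entropy-based decoders.
num_classes = 2
decoder_output_dim = 2  # was 1 before this change

logits = th.randn(5, decoder_output_dim)  # (N, 2): [score_class0, score_class1]
positive_scores = logits[:, 1]            # focal loss consumes this column
assert positive_scores.shape == (5,)
```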

@@ -1490,7 +1487,7 @@ def restore_builtin_model_from_artifacts(model_dir, json_file, yaml_file):
with the trained model weights, a JSON file that is the GConstruct configuration specification
with data-derived transformations, and a YAML file that is the GraphStorm train configuration
updated with runtime arguments.

This method uses the `GSMetaData` and `GSDglDistGraphFromMetadata` classes to create a
lightweight graph that contains only the graph structure, then uses it to restore a
built-in GraphStorm model, returning the model together with the graph construction
and model configuration objects.
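For orientation, a hedged sketch of how the restore API described above might be invoked. The artifact paths are hypothetical, and the three-way unpacking is an assumption based on the docstring's description (the model plus the graph-construction and model-configuration objects), not a confirmed signature.

```python
from graphstorm.gsf import restore_builtin_model_from_artifacts

# Hypothetical artifact paths; real deployments will differ.
model, gconstruct_config, train_config = restore_builtin_model_from_artifacts(
    model_dir="artifacts/model/",           # trained model weights
    json_file="artifacts/gconstruct.json",  # GConstruct spec with data-derived transforms
    yaml_file="artifacts/train.yaml",       # train config updated with runtime arguments
)
```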
33 changes: 18 additions & 15 deletions python/graphstorm/model/loss_func.py
@@ -16,7 +16,6 @@
Loss functions.
"""
import logging
import warnings

import torch as th
from torch import nn
@@ -123,45 +122,49 @@ def __init__(self, alpha=0.25, gamma=2):
super(FocalLossFunc, self).__init__()
self.alpha = alpha
self.gamma = gamma
# TODO: Focal loss should also produce (N, num_classes) output
if get_rank() == 0:
warnings.warn(
"Focal loss currently produces predictions with shape (N, 1) where N "
"is the number of targets. This behavior will change in v0.5.0"
"to produce predictions of shape (N, 2).",
FutureWarning)

def forward(self, logits, labels):
""" The forward function.

Parameters
----------
logits: torch.Tensor
The prediction results.
The prediction results with shape (N, 2), containing [score_class0, score_class1].
labels: torch.Tensor
The training labels.
The training labels.

Returns
-------
loss: Tensor
The loss value.

.. versionchanged:: 0.5.0
Expects (N, 2)-shaped logits instead of (N, 1), matching other loss functions
(e.g. binary cross-entropy).

"""
# We need to reshape logits into a 1D float tensor
# and cast labels into a float tensor.
inputs = logits.squeeze()
# Extract logits for positive class (class 1)
inputs = logits[:, 1] # Shape: (N,)
targets = labels.float()

# Compute probabilities using sigmoid
pred = th.sigmoid(inputs)

# Compute binary cross entropy
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

# Get probability of correct class
pred_t = pred * targets + (1 - pred) * (1 - targets)

# Apply focal loss modulation
loss = ce_loss * ((1 - pred_t) ** self.gamma)

# Apply alpha balancing
if self.alpha >= 0:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_t * loss

loss = loss.mean()
return loss
return loss.mean()

@property
def in_dims(self):
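Pulling the new forward logic out of the diff, here is a self-contained sketch of the (N, 2) focal loss. It reproduces the computation above line for line and can be run directly; the sample tensors are illustrative.

```python
import torch as th
import torch.nn.functional as F

def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """Binary focal loss over (N, 2) logits, mirroring the diff above."""
    inputs = logits[:, 1]      # logits for the positive class (class 1)
    targets = labels.float()
    pred = th.sigmoid(inputs)  # probability of class 1
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    pred_t = pred * targets + (1 - pred) * (1 - targets)  # prob of the true class
    loss = ce_loss * ((1 - pred_t) ** gamma)              # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()

logits = th.tensor([[2.0, -2.0], [-3.0, 3.0]])
labels = th.tensor([0, 1])
print(focal_loss(logits, labels))  # small loss: both predictions are confident and correct
```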
67 changes: 44 additions & 23 deletions tests/unit-tests/test_loss.py
@@ -89,26 +89,57 @@ def test_WeightedLinkPredictAdvBCELossFunc(num_pos, num_neg):
assert_almost_equal(loss.numpy(), gt_loss.numpy())

def test_FocalLossFunc():
# Test case 1: mixed strong and weak predictions
alpha = 0.25
gamma = 2.
gamma = 2.0

loss_func = FocalLossFunc(alpha, gamma)
logits = th.tensor([[0.6330],[0.9946],[0.2322],[0.0115],[0.9159],[0.5752],[0.4491], [0.9231],[0.7170],[0.2761]])
labels = th.tensor([0, 0, 0, 1, 1, 1, 0, 1, 0, 0])
# Manually call the torchvision.ops.sigmoid_focal_loss to generate the loss value
gt_loss = th.tensor(0.1968)
loss = loss_func(logits, labels)
assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)

# Create logits for both classes
logits = th.tensor([
[2.0, -2.0], # Strong prediction for class 0
[-3.0, 3.0], # Strong prediction for class 1
[0.1, -0.1], # Weak prediction for class 0
[-0.2, 0.2] # Weak prediction for class 1
])
labels = th.tensor([0, 1, 0, 1])

# Get our implementation's loss
our_loss = loss_func(logits, labels)

# Get torchvision's loss using the positive class logits
# To get the results we used:
# from torchvision.ops import sigmoid_focal_loss
# tv_loss = sigmoid_focal_loss(
# logits[:, 1], # Take logits for positive class
# labels.float(),
# alpha=alpha,
# gamma=gamma,
# reduction='mean'
# )
tv_loss = th.tensor(0.0352)

assert_almost_equal(our_loss.numpy(), tv_loss.numpy(), decimal=4)

# Test case 2: Original test case
alpha = 0.2
gamma = 1.5
loss_func = FocalLossFunc(alpha, gamma)
logits = th.tensor([2.8205, 0.4035, 0.8215, 1.9420, 0.2400, 2.8565, 1.8330, 0.7786, 2.0962, 1.0399])

logits_orig = th.tensor([
[2.8205, -2.8205], [0.4035, -0.4035], [0.8215, -0.8215],
[1.9420, -1.9420], [0.2400, -0.2400], [2.8565, -2.8565],
[1.8330, -1.8330], [0.7786, -0.7786], [2.0962, -2.0962],
[1.0399, -1.0399]
])
labels = th.tensor([0, 0, 1, 0, 1, 1, 1, 1, 0, 0])
# Manually call the torchvision.ops.sigmoid_focal_loss to generate the loss value
gt_loss = th.tensor(0.6040)
loss = loss_func(logits, labels)
assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)

# Get our implementation's loss
our_loss = loss_func(logits_orig, labels)

# Get torchvision's loss
tv_loss = th.tensor(0.1335)

assert_almost_equal(our_loss.numpy(), tv_loss.numpy(), decimal=4)
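The hardcoded tv_loss reference values can be regenerated with torchvision's sigmoid_focal_loss (assuming torchvision is installed), mirroring the commented-out snippet in the test:

```python
import torch as th
from torchvision.ops import sigmoid_focal_loss

logits = th.tensor([[2.0, -2.0], [-3.0, 3.0], [0.1, -0.1], [-0.2, 0.2]])
labels = th.tensor([0, 1, 0, 1])

# Feed only the positive-class column, as FocalLossFunc does internally.
tv_loss = sigmoid_focal_loss(
    logits[:, 1], labels.float(), alpha=0.25, gamma=2.0, reduction="mean"
)
print(round(tv_loss.item(), 4))  # 0.0352
```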

@pytest.mark.parametrize("num_pos", [1, 8, 32])
@pytest.mark.parametrize("num_neg", [1, 8, 32])
@@ -192,13 +223,3 @@ def test_ShrinkageLossFunc():
gt_loss = th.tensor(0.0692)
loss = loss_func(logits, labels)
assert_almost_equal(loss.numpy(), gt_loss.numpy(), decimal=4)


if __name__ == '__main__':
test_FocalLossFunc()
test_LinkPredictBPRLossFunc()
test_WeightedLinkPredictBPRLossFunc()
test_ShrinkageLossFunc()

test_LinkPredictAdvBCELossFunc(16, 128)
test_WeightedLinkPredictAdvBCELossFunc(16, 128)