diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index f58b2c7f18..385c87d9cd 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1837,10 +1837,6 @@ def aten_scaled_dot_product_attention(
 
     if enable_gqa:
         key, value = _attention_repeat_kv_for_group_query(query, key, value)
-    else:
-        assert query.shape[1] == key.shape[1] == value.shape[1], (
-            "SDPA (MHA) requires q_num_heads = kv_num_heads"
-        )
 
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index d344723408..7e1eeb89a7 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -228,6 +228,37 @@ def forward(self, q, k, v):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_optional_enable_gqa_in_attention(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    q,
+                    k,
+                    v,
+                )
+
+        model = Model()
+
+        # scaled_dot_product_attention works even if query.shape[1] != key.shape[1]
+        # due to broadcasting
+        query = torch.randn(2, 1, 8, 16)
+        key = torch.randn(2, 2, 8, 16)
+        value = torch.randn(2, 2, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (
+                query,
+                key,
+                value,
+            ),
+            input_names=["query", "key", "value"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
     def test_bitwise_and_scalar(self):
         class Model(torch.nn.Module):
             def forward(self, x):
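
Reviewer note (editorial sketch, not part of the patch): the deleted assertion was stricter than eager PyTorch, because torch.nn.functional.scaled_dot_product_attention broadcasts the head dimension when the query and key/value head counts differ, even with enable_gqa left at its default. A minimal eager-mode check of that behavior, reusing the shapes from the new test; the printed shape is my expectation under broadcasting, not something stated in the patch:

    import torch

    # 1 query head attends over 2 key/value heads; dim 1 broadcasts in the
    # q @ k^T and attn @ v matmuls, so no explicit GQA handling is required.
    query = torch.randn(2, 1, 8, 16)
    key = torch.randn(2, 2, 8, 16)
    value = torch.randn(2, 2, 8, 16)

    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    print(out.shape)  # expected: torch.Size([2, 2, 8, 16])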