From 15a9a472f4dc54670bcca2ffdf7986c77fc8b175 Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Mon, 23 Feb 2026 20:51:10 +0530 Subject: [PATCH] feat: add GPU-native kron support for Diagonal matrices --- src/host/linalg.jl | 54 ++++++++++++++++++++++++++++++++++++++++ test/testsuite/linalg.jl | 38 ++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index ea4c89cc..92310d2e 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -952,3 +952,57 @@ for wrapa in trans_adj_wrappers, wrapb in trans_adj_wrappers return kron!(C, A, B) end end + +@kernel function kron_diag_dense_kernel!(C, @Const(a), @Const(B)) + ci, cj = @index(Global, NTuple) + mb = size(B, 1) + nb = size(B, 2) + i = fld1(ci, mb) + bi = mod1(ci, mb) + j = fld1(cj, nb) + bj = mod1(cj, nb) + @inbounds C[ci, cj] = (i == j) ? a[i] * B[bi, bj] : zero(eltype(C)) +end + +@kernel function kron_dense_diag_kernel!(C, @Const(A), @Const(b)) + ci, cj = @index(Global, NTuple) + nb = length(b) + i = fld1(ci, nb) + bi = mod1(ci, nb) + j = fld1(cj, nb) + bj = mod1(cj, nb) + @inbounds C[ci, cj] = (bi == bj) ? A[i, j] * b[bi] : zero(eltype(C)) +end + +function LinearAlgebra.kron!(C::AbstractGPUMatrix, A::Diagonal{T1, <:AbstractGPUVector}, B::AbstractGPUMatrix{T2}) where {T1, T2} + size(C) == (length(A.diag) * size(B, 1), length(A.diag) * size(B, 2)) || throw(DimensionMismatch()) + backend = KernelAbstractions.get_backend(C) + kron_diag_dense_kernel!(backend)(C, A.diag, B, ndrange = size(C)) + return C +end + +function LinearAlgebra.kron(A::Diagonal{T1, <:AbstractGPUVector}, B::AbstractGPUMatrix{T2}) where {T1, T2} + T = promote_type(T1, T2) + return kron!(similar(B, T, length(A.diag) * size(B, 1), length(A.diag) * size(B, 2)), A, B) +end + +function LinearAlgebra.kron!(C::AbstractGPUMatrix, A::AbstractGPUMatrix{T1}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2} + size(C) == (size(A, 1) * length(B.diag), size(A, 2) * length(B.diag)) || throw(DimensionMismatch()) + backend = KernelAbstractions.get_backend(C) + kron_dense_diag_kernel!(backend)(C, A, B.diag, ndrange = size(C)) + return C +end + +function LinearAlgebra.kron(A::AbstractGPUMatrix{T1}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2} + T = promote_type(T1, T2) + return kron!(similar(A, T, size(A, 1) * length(B.diag), size(A, 2) * length(B.diag)), A, B) +end + +function LinearAlgebra.kron!(C::Diagonal{<:Any, <:AbstractGPUVector}, A::Diagonal{T1, <:AbstractGPUVector}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2} + kron!(C.diag, A.diag, B.diag) + return C +end + +function LinearAlgebra.kron(A::Diagonal{T1, <:AbstractGPUVector}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2} + Diagonal(kron(A.diag, B.diag)) +end diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index cf3b4448..34a4b722 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -567,3 +567,41 @@ end end end end + +@testsuite "linalg/kron_diagonal" (AT, eltypes) -> begin + for T in filter(T -> T == Float32 || T == Float64, eltypes) + n, m = 16, 8 + a, b = rand(T, n), rand(T, m) + + # Diagonal*Diagonal + R = kron(Diagonal(adapt(AT, a)), Diagonal(adapt(AT, b))) + @test R isa Diagonal + @test Array(R.diag) ≈ kron(a, b) + + # Diagonal*Dense + B = rand(T, m, m) + R2 = kron(Diagonal(adapt(AT, a)), adapt(AT, B)) + @test Array(R2) ≈ kron(Matrix(Diagonal(a)), B) + + # Dense*Diagonal + A = rand(T, n, n) + R3 = kron(adapt(AT, A), Diagonal(adapt(AT, b))) + @test Array(R3) ≈ kron(A, Matrix(Diagonal(b))) + + # kron! Diagonal*Diagonal + C1 = Diagonal(adapt(AT, zeros(T, n * m))) + kron!(C1, Diagonal(adapt(AT, a)), Diagonal(adapt(AT, b))) + @test C1 isa Diagonal + @test Array(C1.diag) ≈ kron(a, b) + + # kron! Diagonal*Dense + C2 = adapt(AT, zeros(T, n * m, n * m)) + kron!(C2, Diagonal(adapt(AT, a)), adapt(AT, B)) + @test Array(C2) ≈ kron(Matrix(Diagonal(a)), B) + + # kron! Dense*Diagonal + C3 = adapt(AT, zeros(T, n * m, n * m)) + kron!(C3, adapt(AT, A), Diagonal(adapt(AT, b))) + @test Array(C3) ≈ kron(A, Matrix(Diagonal(b))) + end +end