feat: add GPU-native kron support for Diagonal matrices#690

Open
shreyas-omkar wants to merge 1 commit into JuliaGPU:master from shreyas-omkar:master

@shreyas-omkar (Contributor) commented Feb 23, 2026

Feat #668

Diagonal⊗Diagonal: previously required manual densification (diagm), causing an O(n⁴) memory blowup (1 GiB at n=128, OOM beyond). Now returns a Diagonal directly, storing only the n² diagonal entries (O(n²) memory).
Diagonal⊗Dense / Dense⊗Diagonal: previously failed with scalar indexing errors. Now handled by dedicated GPU kernels.

@shreyas-omkar (Contributor, Author)

Benchmark plots, before vs. after (images not reproduced):

Before: the dense workaround for Diagonal⊗Diagonal scales as O(n⁴). At n=128 it already takes 13 ms and 1024 MiB.

After: Diagonal⊗Diagonal now stores only the n² diagonal entries and runs up to n=2048 in 0.4 ms. Dense cases are ~100x faster.

Before: Diagonal⊗Diagonal allocates 4→64→1024 MiB (n=32→64→128). n=256 would OOM at 16 GiB.

After: Diagonal⊗Diagonal stays at 0.004→0.06→16 MiB across n=32→2048, i.e. O(n²) instead of O(n⁴).

Note on Dense cases (before): the "before" benchmark uses the dense matrix workaround kron(diagm(d), B), not the call that users actually hit. Without the fix, real usage would fail immediately with a scalar indexing error.

Device: RTX3060 (sm_86)

@shreyas-omkar shreyas-omkar marked this pull request as ready for review February 23, 2026 18:38
@kshyatt (Member) commented Feb 24, 2026

Nice! One question I have: am I reading the graphs correctly that the Dense x Diagonal cases are now allocating more/taking longer?

@shreyas-omkar (Contributor, Author)

The Dense⊗Diagonal and Diagonal⊗Dense cases use the same amount of memory before and after the fix because the result is always a full dense matrix. In kron(A, Diagonal(b)), where A is n×n and b has m elements, the output is a dense (n*m) × (n*m) matrix and every entry has to be stored, so the whole thing must be allocated either way. The fix for these cases was only about correctness: previously the call failed on GPU with a scalar indexing error; now it runs, and memory is unchanged because the output size is unchanged.
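The index mapping behind the Dense⊗Diagonal case can be sketched on the CPU as follows. This is only an illustration, not the PR's GPU kernel: `kron_dense_diag` is a hypothetical helper name, and it uses plain `Array`s from the standard library so it runs without a GPU.

```julia
using LinearAlgebra

# Hypothetical helper (not part of GPUArrays): computes kron(A, Diagonal(b))
# into a dense matrix using only slice-wise broadcasts.
function kron_dense_diag(A::AbstractMatrix, b::AbstractVector)
    n1, n2 = size(A)
    m = length(b)
    T = promote_type(eltype(A), eltype(b))
    C = zeros(T, n1 * m, n2 * m)          # full dense output, as in the PR discussion
    # View C as a 4-D array indexed (k, i, l, j), where
    # row = (i - 1) * m + k and col = (j - 1) * m + l (column-major layout).
    C4 = reshape(C, m, n1, m, n2)
    for k in 1:m
        # Only entries with k == l are non-zero: C4[k, i, k, j] = A[i, j] * b[k]
        C4[k, :, k, :] .= A .* b[k]
    end
    return C
end

A = [1 2; 3 4]
b = [10, 20, 30]
C = kron_dense_diag(A, b)
println(C == kron(A, Matrix(Diagonal(b))))
```

The output still has (n*m)² stored entries, which is why memory for the Dense cases does not change with the fix.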

The Diagonal⊗Diagonal case is different. kron(Diagonal(a), Diagonal(b)) is itself a diagonal matrix, so only n*m values matter and everything else is zero. Before, we still created the full (n*m) × (n*m) dense matrix to hold it, which wastes space. Now we directly return Diagonal(kron(a, b)), so instead of storing (n*m)² values we store only n*m.
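The identity used here can be checked on small CPU vectors. The snippet below is a minimal sketch with made-up values for `a` and `b`; it uses only the `LinearAlgebra` standard library and no GPU arrays.

```julia
using LinearAlgebra

a = [1.0, 2.0, 3.0]   # illustrative values, n = 3
b = [10.0, 20.0]      # illustrative values, m = 2

# Compact result: only the n*m diagonal entries are stored.
D = Diagonal(kron(a, b))

# Dense reference: the full (n*m) × (n*m) Kronecker product.
K = kron(Matrix(Diagonal(a)), Matrix(Diagonal(b)))

println(Matrix(D) == K)      # the two representations agree
println(length(D.diag))      # stored values in the Diagonal form
println(length(K))           # stored values in the dense form
```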

At n = 128, this means going from storing a ~1024 MiB dense matrix to just ~0.06 MiB.
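The figures quoted in this thread can be reproduced with a short calculation. This sketch assumes 4-byte elements (for example Float32); the PR does not state the element type, so treat that as an assumption chosen because it matches the reported numbers.

```julia
# Assumption: 4 bytes per element (e.g. Float32); not stated in the PR.
const BYTES_PER_ELEM = 4

mib(nbytes) = nbytes / 2^20

# Dense result of kron on two n×n matrices: an n² × n² matrix.
dense_mib(n) = mib((n^2)^2 * BYTES_PER_ELEM)

# Diagonal result: only the n² diagonal entries are stored.
diag_mib(n) = mib(n^2 * BYTES_PER_ELEM)

for n in (32, 64, 128, 256, 2048)
    println("n = ", n, ": dense ≈ ", dense_mib(n), " MiB, Diagonal ≈ ", diag_mib(n), " MiB")
end
```

Under that assumption, n = 128 gives 1024 MiB dense versus 0.0625 MiB for the Diagonal form, and n = 256 gives 16384 MiB (16 GiB) dense, in line with the numbers above.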

The Dense cases only look higher now because the Diagonal⊗Diagonal case dropped so much after the fix.

@kshyatt (Member) commented Feb 24, 2026

Oh I see what's going on, the x-axes are not the same. It would be great to show the %/OOM change in time in the future for such plots

@shreyas-omkar (Contributor, Author)

> Oh I see what's going on, the x-axes are not the same. It would be great to show the %/OOM change in time in the future for such plots

Sure, I'll make a note of it. Thank you.

@kshyatt (Member) commented Feb 24, 2026

While we're modifying this logic anyway, would it be possible to provide a lower level integration to https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#Base.kron! so that users can provide/reuse a pre-allocated output C?

@shreyas-omkar (Contributor, Author)

> While we're modifying this logic anyway, would it be possible to provide a lower level integration to https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#Base.kron! so that users can provide/reuse a pre-allocated output C?

Sure, I'd be happy to do this. It will add to my learning.
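For illustration, a pre-allocated-output variant for the Diagonal⊗Diagonal case might look like the sketch below. `kron_diag!` is a hypothetical name, not the method added in this PR, and the example runs on CPU `Vector`s; it relies only on `reshape` and broadcasting, which is the kind of formulation that avoids scalar indexing.

```julia
using LinearAlgebra

# Hypothetical helper (not the PR's implementation): writes kron(A, B) for two
# Diagonal matrices into the pre-allocated Diagonal C and returns C.
function kron_diag!(C::Diagonal, A::Diagonal, B::Diagonal)
    n = length(A.diag)
    m = length(B.diag)
    length(C.diag) == n * m ||
        throw(DimensionMismatch("C has $(length(C.diag)) diagonal entries, expected $(n * m)"))
    # reshape(C.diag, m, n)[k, i] aliases C.diag[(i - 1) * m + k],
    # which must equal A.diag[i] * B.diag[k].
    reshape(C.diag, m, n) .= B.diag .* transpose(A.diag)
    return C
end

A = Diagonal([1.0, 2.0, 3.0])
B = Diagonal([10.0, 20.0])
C = Diagonal(zeros(6))        # output buffer, allocated once and reusable

kron_diag!(C, A, B)
println(C.diag == kron(A.diag, B.diag))
```

Reusing `C` across calls avoids allocating a fresh output each time, which is the point of the `kron!` request.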

@github-actions (Contributor)

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/src/host/linalg.jl b/src/host/linalg.jl
index 92310d2..1538d2a 100644
--- a/src/host/linalg.jl
+++ b/src/host/linalg.jl
@@ -1004,5 +1004,5 @@ function LinearAlgebra.kron!(C::Diagonal{<:Any, <:AbstractGPUVector}, A::Diagona
 end
 
 function LinearAlgebra.kron(A::Diagonal{T1, <:AbstractGPUVector}, B::Diagonal{T2, <:AbstractGPUVector}) where {T1, T2}
-    Diagonal(kron(A.diag, B.diag))
+    return Diagonal(kron(A.diag, B.diag))
 end

@shreyas-omkar (Contributor, Author)

@kshyatt please take a look.
