
CUDA: add gqa_ratio 4 for GLM 4.7 flash #18953

Merged
am17an merged 6 commits into ggml-org:master from am17an:glm_4.7_headsize on Jan 22, 2026

Conversation

@am17an
Collaborator

@am17an am17an commented Jan 20, 2026

Enable FA for GLM 4.7. I'm not sure it's optimal, but at least it does not fall back to the CPU. Fixes #18944

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) labels on Jan 20, 2026
@am17an
Collaborator Author

am17an commented Jan 20, 2026

Looks like there is a bug there; taking a look.

@JohannesGaessler
Collaborator

I would have thought the correct patch is this, but the results are wrong:

diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 9e98da95f..6cca5b2ec 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         } else {
-            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
             for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                    // Wide version of KQ_C is column-major
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                    } else {
+                        // Wide version of KQ_C is column-major
 #if defined(AMD_WMMA_AVAILABLE)
-                    // RDNA matrix C is column-major.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                        // RDNA matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
 #else
-                    // swap A and B for CUDA.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
 #endif // defined(AMD_WMMA_AVAILABLE)
+                    }
                 }
             }
         }

@JohannesGaessler
Collaborator

I forgot that the tests in test-backend-ops are not reliable for MLA models because they do not account for the CUDA backend using the K data also for V. The tests should be adapted to pass a view of K as V specifically for head size 576 since that is only used for MLA models. I pushed a patch that works correctly with /opt/models/GLM-4.7-Flash-Q4_K_M.gguf on the server.
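
A minimal sketch of what "pass a view of K as V" could look like in a test, assuming the MLA layout described later in this thread (V data is the first DV elements of each 576-wide K row); the helper, tensor names, and shapes are illustrative and not the actual change made in #18986:

#include "ggml.h"

// Hypothetical helper: assumes an initialized ggml_context * ctx, as test-backend-ops has.
static void build_mla_k_and_v_view(ggml_context * ctx) {
    const int64_t hsk = 576, hsv = 512, n_kv = 512, n_head_kv = 1;   // illustrative sizes
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hsk, n_kv, n_head_kv);
    ggml_tensor * v = ggml_view_3d(ctx, k,
        hsv, n_kv, n_head_kv,    // keep only the first hsv elements of each K row
        k->nb[1], k->nb[2],      // reuse K's row/head strides so V rows alias K rows
        0);                      // offset 0: V data starts at the beginning of each K row
    (void) v;                    // v would then be passed as the V argument of the FA op
}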

@am17an
Collaborator Author

am17an commented Jan 21, 2026

Addressed the rest of your comments. In general this model spews out nonsense with and without FA; I think that should be fixed by #18980.

@ggerganov
Member

I forgot that the tests in test-backend-ops are not reliable for MLA models because they do not account for the CUDA backend using the K data also for V. The tests should be adapted to pass a view of K as V specifically for head size 576 since that is only used for MLA models.

We can add the tests, but unless there is a technical blocker, it would be useful to support separate K and V data in the CUDA implementation as well.

@mdierolf

This patch is working well for me, combined with --override-kv deepseek2.expert_gating_func=int:2, to resolve most of the issues with GLM 4.7 Flash.

2100 tokens/s prompt processing and 90 tokens/s generation on an RTX 6000 Blackwell. Output looks overall pretty decent when using the Unsloth FP16 GGUF.
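
A minimal invocation sketch combining the two (the server binary and model path here are illustrative, not taken from the comment above):

llama-server -m GLM-4.7-Flash-F16.gguf -ngl 99 -fa on \
  --override-kv deepseek2.expert_gating_func=int:2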

@ggerganov
Member

ggerganov commented Jan 21, 2026

@JohannesGaessler @am17an This PR should be OK to merge. I have outlined a plan for improving the implementation and fixing the tests in #18986. Let's continue there after we merge this.

@am17an
Collaborator Author

am17an commented Jan 21, 2026

In the CI, the following tests are failing; locally I am able to get the tests to pass. So I'm thinking there's probably a bug in the tile kernel that we haven't fixed:

  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=3,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=32,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,2,1,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
  FLASH_ATTN_EXT(hsk=576,hsv=512,nh=4,nr23=[4,1],kv=512,nb=35,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])

@JohannesGaessler
Collaborator

I didn't touch any of the tests, so it's expected that they are still failing.

@ggerganov
Member

In the CI, the following tests are failing; locally I am able to get the tests to pass.

@am17an The tests should fail locally too. For me, they fail on DGX Spark.

On which setup are they passing for you?

@ggerganov
Member

Btw, the Windows workflow fails to compile: https://github.com/ggml-org/llama.cpp/actions/runs/21195664178/job/60970976275?pr=18953

Any ideas?

@am17an
Collaborator Author

am17an commented Jan 21, 2026

The tests now fail for me too. Is the fix going to go in #18986?

@ggerganov
Member

The tests now fail for me too. Is the fix going to go in #18986?

Yes, it needs some more work and will be merged after this PR.

Regarding the Windows build, from the logs I think this is the problematic part:

const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
#elif defined(AMD_WMMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#else // Volta
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#endif // defined(TURING_MMA_AVAILABLE)

2026-01-21T10:56:39.5136943Z D:\a\llama.cpp\llama.cpp\ggml\src\ggml-cuda\template-instances\../fattn-mma-f16.cuh(454): error: the size of an array must be greater than zero
2026-01-21T10:56:39.5137536Z       T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
2026-01-21T10:56:39.5137756Z                   ^
2026-01-21T10:56:39.5137908Z           detected during:
2026-01-21T10:56:39.5141101Z             instantiation of "void flash_attn_ext_f16_iter<DKQ,DV,ncols1,ncols2,nwarps,use_logit_softcap,mla,needs_fixup,is_fixup,last_iter,oob_check,T_A_KQ,T_B_KQ,T_C_KQ,T_A_VKQ,T_B_VKQ,T_C_VKQ>(const float2 *, const half2 *, const half2 *, const half *, float2 *, float2 *, float, float, float, uint3, int, int, int, int, half2 *, half2 *, half2 *, half *, T_B_KQ *, T_C_VKQ *, float *, float *, int, int, int) [with DKQ=576, DV=512, ncols1=2, ncols2=4, nwarps=2, use_logit_softcap=false, mla=true, needs_fixup=false, is_fixup=false, last_iter=false, oob_check=false, T_A_KQ=ggml_cuda_mma::tile<8, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR_MIRRORED>, T_B_KQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_C_KQ=ggml_cuda_mma::tile<32, 8, float, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_A_VKQ=ggml_cuda_mma::tile<8, 4, half2, ggml_cuda_mma::DATA_LAYOUT_J_MAJOR_MIRRORED>, T_B_VKQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_C_VKQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>]" at line 1106
2026-01-21T10:56:39.5145657Z             instantiation of "void flash_attn_ext_f16_process_tile<DKQ,DV,ncols1,ncols2,nwarps,use_logit_softcap,mla,needs_fixup,is_fixup>(const float2 *, const half2 *, const half2 *, const half *, const float *, float2 *, float2 *, float, float, float, uint3, int, int, int, int, int, int, int, int, int, int) [with DKQ=576, DV=512, ncols1=2, ncols2=4, nwarps=2, use_logit_softcap=false, mla=true, needs_fixup=false, is_fixup=false]" at line 1561
2026-01-21T10:56:39.5148489Z             instantiation of "void flash_attn_ext_f16<DKQ,DV,ncols1,ncols2,use_logit_softcap,mla>(const char *, const char *, const char *, const char *, const char *, const int *, float *, float2 *, float, float, float, float, uint32_t, float, int32_t, uint3, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int64_t, int32_t, int32_t, int64_t, int32_t, int32_t, int32_t, int32_t, int32_t, int64_t) [with DKQ=576, DV=512, ncols1=2, ncols2=4, use_logit_softcap=false, mla=true]" at line 1664
2026-01-21T10:56:39.5150933Z             instantiation of "void ggml_cuda_flash_attn_ext_mma_f16_case<DKQ,DV,ncols1,ncols2>(ggml_backend_cuda_context &, ggml_tensor *) [with DKQ=576, DV=512, ncols1=2, ncols2=4]" at line 11 of D:\a\llama.cpp\llama.cpp\ggml\src\ggml-cuda\template-instances\fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
2026-01-21T10:56:39.5151880Z 

If TURING_MMA_AVAILABLE is not defined, we get a zero-sized array here.
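
A minimal, self-contained illustration of the failure mode (the numbers are illustrative, not the kernel's actual template parameters): an array extent computed by integer division silently collapses to zero once the divisor exceeds the dividend, which is exactly what nvcc reports above.

#include <cstdio>

template <int nbatch_fa, int np, int tile_J>
struct kq_accumulator {
    static constexpr int n = nbatch_fa / (np * tile_J);
    static_assert(n > 0, "KQ accumulator array would be zero-sized");
    float KQ_C[n];
};

int main() {
    kq_accumulator<64, 2, 8> ok;        // 64 / 16 == 4 -> fine
    // kq_accumulator<16, 4, 8> bad;    // 16 / 32 == 0 -> trips the static_assert
    std::printf("accumulator tiles: %d\n", (int) (sizeof(ok.KQ_C) / sizeof(float)));
    return 0;
}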

@simcop2387

Forgive me if this is covered in previous discussions and I'm not understanding, but I'm attempting to build this PR myself.

I'm building my own Docker containers (with a bunch of other stuff bundled, hence not using the official ones). The base image right now is nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04 with
ARG PYTHON_VER=3.10
ARG CUDA_DOCKER_ARCH="61;70;75;80;86;89"

and finally building with

RUN cmake -Bbuild -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_CUDA_FA_ALL_QUANTS=ON -DGGML_NATIVE=OFF -DCMAKE_CUDA_ARCHITECTURES="61;70;75;80;86;89" .
RUN cmake --build build --config Release -j$(nproc)

And running into a failure during build after applying it:

#18 309.5             instantiation of "void ggml_cuda_flash_attn_ext_mma_f16_case<DKQ,DV,ncols1,ncols2>(ggml_backend_cuda_context &, ggml_tensor *) [with DKQ=576, DV=512, ncols1=2, ncols2=4]" at line 11 of /build/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
#18 309.5 
#18 309.5 /build/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-mma-f16.cuh(859): error: static assertion failed with "bad loop size"
#18 309.5               static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
#18 309.5               ^
#18 309.5           detected during:
#18 309.5             instantiation of "void flash_attn_ext_f16_iter<DKQ,DV,ncols1,ncols2,nwarps,use_logit_softcap,mla,needs_fixup,is_fixup,last_iter,oob_check,T_A_KQ,T_B_KQ,T_C_KQ,T_A_VKQ,T_B_VKQ,T_C_VKQ>(const float2 *, const half2 *, const half2 *, const half *, float2 *, float2 *, float, float, float, uint3, int, int, int, int, half2 *, half2 *, half2 *, half *, T_B_KQ *, T_C_VKQ *, float *, float *, int, int, int) [with DKQ=576, DV=512, ncols1=2, ncols2=4, nwarps=2, use_logit_softcap=true, mla=true, needs_fixup=false, is_fixup=true, last_iter=true, oob_check=false, T_A_KQ=ggml_cuda_mma::tile<8, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR_MIRRORED>, T_B_KQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_C_KQ=ggml_cuda_mma::tile<32, 8, float, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_A_VKQ=ggml_cuda_mma::tile<8, 4, half2, ggml_cuda_mma::DATA_LAYOUT_J_MAJOR_MIRRORED>, T_B_VKQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>, T_C_VKQ=ggml_cuda_mma::tile<32, 4, half2, ggml_cuda_mma::DATA_LAYOUT_I_MAJOR>]" at line 1115

@JohannesGaessler
Collaborator

The compilation problems have to do with Volta; I'll push a fix.

@JohannesGaessler
Collaborator

For Volta the minimum number of KQ columns is 32, so for the kernel templates with 8 and 16 columns the number of parallel warps per CUDA block was being calculated incorrectly. This tripped the static asserts even though that code should never actually be executed. I fixed the calculation of parallel warps to handle this edge case correctly.
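
A minimal sketch of the idea as described, with an illustrative clamp helper (these identifiers are assumptions, not the kernel's actual names): treat templates that request fewer than Volta's minimum of 32 KQ columns as if they used 32 columns when deriving compile-time sizes, so the 8- and 16-column instantiations stay well-formed even though those paths never execute on Volta.

constexpr int VOLTA_MIN_KQ_COLS = 32;   // assumption for illustration

constexpr int volta_effective_kq_cols(int cols_per_warp) {
    return cols_per_warp < VOLTA_MIN_KQ_COLS ? VOLTA_MIN_KQ_COLS : cols_per_warp;
}

static_assert(volta_effective_kq_cols(8)  == 32, "8-column templates are widened");
static_assert(volta_effective_kq_cols(16) == 32, "16-column templates are widened");
static_assert(volta_effective_kq_cols(64) == 64, "wider templates are unchanged");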

@simcop2387

For Volta the minimum number of KQ columns is 32, so for the kernel templates with 8 and 16 columns the number of parallel warps per CUDA block was being calculated incorrectly. This tripped the static asserts even though that code should never actually be executed. I fixed the calculation of parallel warps to handle this edge case correctly.

Awesome, glad I helped find something :) I'll give the updated branch a test shortly.

@IceFog72

IceFog72 commented Jan 21, 2026

Is -ctk q8_0 -ctv q8_0 broken too? For me it again uses only the CPU. Without it, everything is OK.

@JohannesGaessler
Collaborator

I forgot: I reverted the change to exclude kernels with 8 KQ columns. I meant in my review comment that this variant should be excluded if and only if there are problems with it (but those are now fixed).

@ubergarm
Contributor

ubergarm commented Jan 21, 2026

Only had time for one quick test before running out the door today, but my initial impression is that this PR is looking good:

  • ik_llama.cpp
    • Final estimate: PPL over 565 chunks for n_ctx=512 = 8.4759 +/- 0.0615
  • mainline
    • Final estimate: PPL = 9.2083 +/- 0.06770
[sweep-bench plot: GLM-4.7-Flash]

Details:
model=/mnt/raid/models/ubergarm/GLM-4.7-Flash-GGUF/GLM-4.7-Flash-MXFP4.gguf

## ik_llama.cpp
CUDA_VISIBLE_DEVICES="0" \
./build/bin/llama-sweep-bench \
  --model "$model" \
  -c 69632 \
  -fa on \
  -ger \
  --merge-qkv \
  -mla 3 -amb 1024 \
  -ngl 99 \
  -ub 4096 -b 4096 \
  --threads 1 \
  --n-predict 128 \
  --warmup-batch

CUDA_VISIBLE_DEVICES="0," \
./build/bin/llama-perplexity \
    --model "$model"\
    -f wiki.test.raw \
    --seed 1337 \
    -mla 3 -amb 1024 \
    -ub 4096 -b 4096 \
    --ctx-size 512 \
    -ngl 99 \
    --threads 1 \
    --no-mmap \
    --validate-quants

## this PR rebased with  4ca23c and 2472acd47 and ug/port-sweep-bench
CUDA_VISIBLE_DEVICES="0" \
./build/bin/llama-sweep-bench \
  --model "$model" \
  -c 69632 \
  -fit off \
  -fa on \
  -ngl 99 \
  -ub 4096 -b 4096 \
  --threads 1

CUDA_VISIBLE_DEVICES="0," \
./build/bin/llama-perplexity \
    --model "$model"\
    -f wiki.test.raw \
    --seed 1337 \
    -ub 4096 -b 4096 \
    --ctx-size 512 \
    -ngl 99 \
    --threads 1 \
    --no-mmap

This model is odd in that the MXFP4 with no imatrix is scoring quite a bit "better" than the baseline BF16...

@JohannesGaessler
Collaborator

Perplexity over Wikitext is fundamentally the wrong metric for judging the quality of an instruct-tuned model. What you should look at when it comes to the impact of quantization is KL divergence vs. the full-precision model. For judging the quality in an absolute sense there currently just isn't good tooling in the llama.cpp ecosystem.
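
For reference, the quantity being suggested is the standard KL divergence of the quantized model's next-token distribution Q from the full-precision model's distribution P, averaged over evaluation positions:

$$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{t \in \mathcal{V}} P(t)\,\log\frac{P(t)}{Q(t)}$$

(llama-perplexity can report this via its --kl-divergence mode, using logits saved from the full-precision run with --kl-divergence-base, if I recall the flags correctly.)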

@ggerganov
Member

@JohannesGaessler Good to merge?

@askmyteapot

askmyteapot commented Jan 22, 2026

I don't know if it's worth mentioning, but with my P40 + 3090 setup, FA on for PP is half the speed of FA off. Freshly built from this PR:

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q8.gguf -ngl 99 -ts 1;1 -fa 1 -p 8192
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: no
| model             | size      | params  | backend | ngl | fa | ts        | test   | t/s           |
| ----------------- | --------- | ------- | ------- | --- | -- | --------- | ------ | ------------- |
| deepseek2 ?B Q8_0 | 29.65 GiB | 29.94 B | CUDA    | 99  | 1  | 1.00/1.00 | pp8192 | 426.98 ± 0.51 |
| deepseek2 ?B Q8_0 | 29.65 GiB | 29.94 B | CUDA    | 99  | 1  | 1.00/1.00 | tg128  | 47.85 ± 0.37  |

build: a10d87b (7786)

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q8.gguf -ngl 99 -ts 1;1 -fa 0 -p 8192
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: no
| model             | size      | params  | backend | ngl | ts        | test   | t/s           |
| ----------------- | --------- | ------- | ------- | --- | --------- | ------ | ------------- |
| deepseek2 ?B Q8_0 | 29.65 GiB | 29.94 B | CUDA    | 99  | 1.00/1.00 | pp8192 | 821.59 ± 1.18 |
| deepseek2 ?B Q8_0 | 29.65 GiB | 29.94 B | CUDA    | 99  | 1.00/1.00 | tg128  | 46.82 ± 0.26  |

build: a10d87b (7786)

3090 Only
D:\llama.cpp(llvmB)>set CUDA_VISIBLE_DEVICES=0

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q4_0.gguf -ngl 99 -fa 0 -p 8192
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model             | size      | params  | backend | ngl | test   | t/s            |
| ----------------- | --------- | ------- | ------- | --- | ------ | -------------- |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | pp8192 | 2054.59 ± 4.99 |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | tg128  | 122.26 ± 0.26  |

build: a10d87b (7786)

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q4_0.gguf -ngl 99 -fa 1 -p 8192
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model             | size      | params  | backend | ngl | fa | test   | t/s            |
| ----------------- | --------- | ------- | ------- | --- | -- | ------ | -------------- |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | 1  | pp8192 | 2352.28 ± 4.55 |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | 1  | tg128  | 122.70 ± 0.35  |

build: a10d87b (7786)

P40 Only
D:\llama.cpp(llvmB)>set CUDA_VISIBLE_DEVICES=1

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q4_0.gguf -ngl 99 -fa 0 -p 8192
ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: no
| model             | size      | params  | backend | ngl | test   | t/s           |
| ----------------- | --------- | ------- | ------- | --- | ------ | ------------- |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | pp8192 | 415.10 ± 0.38 |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | tg128  | 40.89 ± 0.14  |

build: a10d87b (7786)

D:\llama.cpp(llvmB)>llama-bench.exe -m D:\text-generation-webui\models\GLM-4.7-FLASH-Q4_0.gguf -ngl 99 -fa 1 -p 8192
ggml_cuda_init: found 1 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: no
| model             | size      | params  | backend | ngl | fa | test   | t/s           |
| ----------------- | --------- | ------- | ------- | --- | -- | ------ | ------------- |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | 1  | pp8192 | 210.82 ± 0.08 |
| deepseek2 ?B Q4_0 | 15.78 GiB | 29.94 B | CUDA    | 99  | 1  | tg128  | 44.05 ± 0.16  |

build: a10d87b (7786)

Collaborator

@JohannesGaessler JohannesGaessler left a comment

My general preference would be that we fix CI failures prior to a merge, but it's fine if we take care of it soon after.

@am17an am17an merged commit b70d251 into ggml-org:master Jan 22, 2026
79 of 80 checks passed
@paccerdk

The changes do not improve performance with a quantized KV cache, which still falls back to the CPU. Is that to be expected?

llama-bench -m /data/AI/models/GLM-4.7-Flash-UD-Q4_K_XL.gguf -ngl 99 -fa 0 -p 2048
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model                      | size      | params  | backend | ngl | test   | t/s             |
| -------------------------- | --------- | ------- | ------- | --- | ------ | --------------- |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | pp2048 | 2539.03 ± 10.99 |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | tg128  | 107.04 ± 0.48   |

llama-bench -m /data/AI/models/GLM-4.7-Flash-UD-Q4_K_XL.gguf -ngl 99 -fa 1 -p 2048 -ctk f16 -ctv f16
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model                      | size      | params  | backend | ngl | fa | test   | t/s            |
| -------------------------- | --------- | ------- | ------- | --- | -- | ------ | -------------- |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | 1  | pp2048 | 2762.03 ± 8.73 |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | 1  | tg128  | 107.79 ± 0.67  |

llama-bench -m /data/AI/models/GLM-4.7-Flash-UD-Q4_K_XL.gguf -ngl 99 -fa 1 -p 2048 -ctk q8_0 -ctv q8_0
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model                      | size      | params  | backend | ngl | type_k | type_v | fa | test   | t/s           |
| -------------------------- | --------- | ------- | ------- | --- | ------ | ------ | -- | ------ | ------------- |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | q8_0   | q8_0   | 1  | pp2048 | 52.93 ± 1.10  |
| deepseek2 ?B Q4_K - Medium | 16.31 GiB | 29.94 B | CUDA    | 99  | q8_0   | q8_0   | 1  | tg128  | 39.85 ± 1.74  |

build: b70d251 (7803)

@am17an am17an deleted the glm_4.7_headsize branch January 22, 2026 12:50
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jan 23, 2026
ronaldmannak pushed a commit to PicoMLX/llama.cpp that referenced this pull request Jan 24, 2026
linus-amg added a commit to linus-amg/llama.cpp that referenced this pull request Jan 24, 2026
This enables MMA-based flash attention on RDNA3 GPUs (gfx1100/1101/1102)
for models with head size 576, such as GLM-4.7-Flash and other MLA
(Multi-head Latent Attention) models.

Previously, flash attention with head size 576 only worked on CUDA
(via PR ggml-org#18953) and RDNA4. RDNA3 users had to disable flash attention,
resulting in ~3x slower inference.

Changes:
- fattn.cu: Route RDNA3 + head size 576 to MMA kernel (was RDNA4-only)
- fattn-mma-f16.cuh: Enable AMD WMMA for all RDNA3/RDNA4, allow DKQ==576
- mma.cuh: Add RDNA3 to make_identity_mat(), add f16->f16 WMMA intrinsic

Tested on AMD RX 7900 XTX (gfx1100) with GLM-4.7-Flash-REAP-23B:
- FA off: ~77 t/s
- FA on (before, broken): ~27 t/s
- FA on (after fix): ~83 t/s
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jan 26, 2026
maxious added a commit to maxious/llama.cpp that referenced this pull request Jan 27, 2026
When V is a view of K but with different head dimensions (e.g., GLM-4.7-Flash
with K=576, V=512), we cannot simply reuse K's data pointer for V.

For MLA models, the K tensor layout is [kv_lora_scaled (DV), pe (DQK-DV)],
so V data is the first DV elements of each K row.

This fix extracts the correct V data from K when DQK != DV in:
- ggml_sycl_op_flash_attn_1 (basic FA path)
- ggml_sycl_op_flash_attn_coopmat (XMX path)
- ggml_sycl_op_flash_attn_mkl (oneMKL path)

Fixes GPU memory faults and incorrect results in backend tests for
hsk=576,hsv=512 configurations.

Aligns with upstream PRs ggml-org#18953, ggml-org#18986, ggml-org#19067 that implement V-less KV cache
for MLA models like DeepSeek and GLM-4.7-Flash.

Amp-Thread-ID: https://ampcode.com/threads/T-019bf97a-9105-718e-84fb-320913c5f0c6
Co-authored-by: Amp <amp@ampcode.com>
maxious added a commit to maxious/llama.cpp that referenced this pull request Jan 31, 2026
maxious added a commit to maxious/llama.cpp that referenced this pull request Feb 1, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes)


Development

Successfully merging this pull request may close these issues.

Eval bug: FLASH_ATTN_EXT GLM 4.7 Flash tensor schema not supported on CUDA

9 participants