
Conversation

@forforever73
Author

Adding supplemental evaluation results for reference.

Performance

https://github.com/stepfun-ai/Step-3.5-Flash/blob/main/llama.cpp/docs/step3.5-flash.md

Accuracy

Accuracy was evaluated against a BF16 vLLM baseline.

Tested at the maximum 256k context on 8 × H200 devices:

| Dataset | vLLM BF16 Baseline | step3.5_flash_fp16.gguf |
|---|---|---|
| IFEVAL (keywords / existence) | 98.08% (±2.13) | 98.33% (±2.89) |
| HMMT25 | 98.44% (±1.86) | 97.50% |

Tested at the maximum 256k context on a Mac Studio, repeated 64 times and averaged:

| Model | Device | Repeats | Average |
|---|---|---|---|
| vLLM BF16 baseline | H200 | 64 | 84.38% |
| step3.5_flash_Q4_K_S.gguf | Mac Studio | 64 | 82.89% |


IIIIIllllIIIIIlllll commented Feb 3, 2026

great work! thank you!
It works fine on my 395, at about 22 tokens/s.

@gopinath87607

Is this exactly the same modification as in the forked Step llama.cpp, or is it a new one?

@forforever73
Author

@gopinath87607 Only the register name (step3p5) was changed in the convert_hf_to_gguf part; everything else is exactly the same.
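For anyone diffing the two trees, this is roughly the kind of change being referred to. A minimal, illustrative sketch only: the decorator and base-class names follow current convert_hf_to_gguf.py conventions, and the HF architecture string and enum member here are assumptions, not the exact identifiers from this PR.

```python
# Illustrative sketch of a "register name" in convert_hf_to_gguf.py (not the actual diff).
# The HF architecture string and the MODEL_ARCH member are assumptions.
@ModelBase.register("Step3p5ForCausalLM")   # the "step3p5" registration mentioned above
class Step35Model(TextModel):
    model_arch = gguf.MODEL_ARCH.STEP35     # maps to the "step35" GGUF architecture seen in the logs below
```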


tarruda commented Feb 3, 2026

I tried running this branch with Codex. While it works, I see some tool-call tokens leaking into the UI:


Additionally, I see some warnings in llama-server:

slot init_sampler: id  1 | task 3684 | init sampler, took 4.65 ms, tokens: text = 47025, total = 47025
slot update_slots: id  1 | task 3684 | erasing old context checkpoint (pos_min = 33429, pos_max = 35988, size = 330.030 MiB)
slot update_slots: id  1 | task 3684 | created context checkpoint 8 of 8 (pos_min = 44401, pos_max = 46960, size = 330.030 MiB)
slot print_timing: id  1 | task 3684 | 
prompt eval time =   10377.77 ms /  2080 tokens (    4.99 ms per token,   200.43 tokens per second)
       eval time =    6575.80 ms /   169 tokens (   38.91 ms per token,    25.70 tokens per second)
      total time =   16953.57 ms /  2249 tokens
slot      release: id  1 | task 3684 | stop processing: n_tokens = 47193, truncated = 0
srv  update_slots: all slots are idle
srv  log_server_r: done request: POST /v1/responses 192.168.10.78 200
Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.
srv  params_from_: Chat format: Hermes 2 Pro


AesSedai commented Feb 3, 2026

I pulled and compiled with this commit, produced a BF16 with convert_hf_to_gguf, then attempted to run llama-imatrix on it, and the results looked very suspect:

llama-imatrix output on commit `2f0f12e70`
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
build: 7907 (2f0f12e70) with GNU 14.2.1 for Linux x86_64
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected memory use with initial parameters [MiB]:
llama_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 3090):  24135 total,  11876 used,  11995 free vs. target of   1024
llama_params_fit_impl:   - CUDA1 (NVIDIA GeForce RTX 3090):  24135 total,   7252 used,  16619 free vs. target of   1024
llama_params_fit_impl: projected to use 19129 MiB of device memory vs. 47743 MiB of free device memory
llama_params_fit_impl: targets for free memory can be met on all devices, no changes needed
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 15.59 seconds
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 3090) (0000:06:10.0) - 23871 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA GeForce RTX 3090) (0000:06:11.0) - 23871 MiB free
llama_model_loader: loaded meta data with 49 key-value pairs and 754 tensors from /mnt/srv/snowdrift/ggml/Step-3.5-Flash/Step-3.5-Flash-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = step35
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Step 3.5 Flash
llama_model_loader: - kv   3:                         general.size_label str              = 288x7.4B
llama_model_loader: - kv   4:                            general.license str              = apache-2.0
llama_model_loader: - kv   5:                   general.base_model.count u32              = 1
llama_model_loader: - kv   6:                  general.base_model.0.name str              = Step 3.5 Flash
llama_model_loader: - kv   7:          general.base_model.0.organization str              = Stepfun Ai
llama_model_loader: - kv   8:              general.base_model.0.repo_url str              = https://huggingface.co/stepfun-ai/ste...
llama_model_loader: - kv   9:                         step35.block_count u32              = 45
llama_model_loader: - kv  10:                      step35.context_length u32              = 262144
llama_model_loader: - kv  11:                    step35.embedding_length u32              = 4096
llama_model_loader: - kv  12:                 step35.feed_forward_length u32              = 11264
llama_model_loader: - kv  13:                step35.attention.head_count arr[i32,45]      = [64, 96, 96, 96, 64, 96, 96, 96, 64, ...
llama_model_loader: - kv  14:                      step35.rope.freq_base f32              = 5000000.000000
llama_model_loader: - kv  15:                step35.attention.key_length u32              = 128
llama_model_loader: - kv  16:              step35.attention.value_length u32              = 128
llama_model_loader: - kv  17:                          general.file_type u32              = 32
llama_model_loader: - kv  18:             step35.attention.head_count_kv arr[i32,45]      = [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, ...
llama_model_loader: - kv  19:            step35.attention.sliding_window u32              = 512
llama_model_loader: - kv  20:    step35.attention.sliding_window_pattern arr[i32,45]      = [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, ...
llama_model_loader: - kv  21:             step35.rope.scaling.apply_mask u32              = 1
llama_model_loader: - kv  22:                        step35.expert_count u32              = 288
llama_model_loader: - kv  23:                   step35.expert_used_count u32              = 8
llama_model_loader: - kv  24:          step35.expert_feed_forward_length u32              = 1280
llama_model_loader: - kv  25:   step35.expert_shared_feed_forward_length u32              = 1280
llama_model_loader: - kv  26:                  step35.expert_gating_func u32              = 2
llama_model_loader: - kv  27:                step35.expert_weights_scale f32              = 3.000000
llama_model_loader: - kv  28:                 step35.expert_weights_norm bool             = true
llama_model_loader: - kv  29:           step35.leading_dense_block_count u32              = 3
llama_model_loader: - kv  30:                  step35.moe_every_n_layers u32              = 1
llama_model_loader: - kv  31:      step35.rope.dimension_count_per_layer arr[i32,45]      = [64, 128, 128, 128, 64, 128, 128, 128...
llama_model_loader: - kv  32:    step35.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  33:            step35.rope.freq_base_per_layer arr[f32,45]      = [5000000.000000, 10000.000000, 10000....
llama_model_loader: - kv  34:                       step35.swiglu_limits arr[f32,45]      = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  35:                step35.swiglu_limits_shared arr[f32,45]      = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  36:               general.quantization_version u32              = 2
llama_model_loader: - kv  37:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  38:                         tokenizer.ggml.pre str              = deepseek-v3
llama_model_loader: - kv  39:                      tokenizer.ggml.tokens arr[str,128896]  = ["<|begin▁of▁sentence|>", "<�...
llama_model_loader: - kv  40:                  tokenizer.ggml.token_type arr[i32,128896]  = [3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  41:                      tokenizer.ggml.merges arr[str,127741]  = ["Ġ t", "Ġ a", "i n", "Ġ Ġ", "h e...
llama_model_loader: - kv  42:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  43:                tokenizer.ggml.eos_token_id u32              = 128007
llama_model_loader: - kv  44:            tokenizer.ggml.padding_token_id u32              = 1
llama_model_loader: - kv  45:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  46:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  47:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  48:                    tokenizer.chat_template str              = {% macro render_content(content) %}{%...
llama_model_loader: - type  f32:  266 tensors
llama_model_loader: - type bf16:  488 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = BF16
print_info: file size   = 366.95 GiB (16.00 BPW) 
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 128007 ('<|im_end|>')
load: special tokens cache size = 818
load: token to piece cache size = 0.8220 MB
print_info: arch                  = step35
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 4096
print_info: n_embd_inp            = 4096
print_info: n_layer               = 45
print_info: n_head                = [64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64, 96, 96, 96, 64]
print_info: n_head_kv             = 8
print_info: n_rot                 = 128
print_info: n_swa                 = 512
print_info: is_swa_any            = 1
print_info: n_embd_head_k         = 128
print_info: n_embd_head_v         = 128
print_info: n_gqa                 = [8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8, 12, 12, 12, 8]
print_info: n_embd_k_gqa          = 1024
print_info: n_embd_v_gqa          = 1024
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-05
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: n_ff                  = 11264
print_info: n_expert              = 288
print_info: n_expert_used         = 8
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = 0
print_info: rope type             = 2
print_info: rope scaling          = linear
print_info: freq_base_train       = 5000000.0
print_info: freq_scale_train      = 1
print_info: freq_base_swa         = 10000.0
print_info: freq_scale_swa        = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: model type            = ?B
print_info: model params          = 196.96 B
print_info: general.name          = Step 3.5 Flash
print_info: vocab type            = BPE
print_info: n_vocab               = 128896
print_info: n_merges              = 127741
print_info: BOS token             = 0 '<|begin▁of▁sentence|>'
print_info: EOS token             = 128007 '<|im_end|>'
print_info: EOT token             = 128007 '<|im_end|>'
print_info: PAD token             = 1 '<|end▁of▁sentence|>'
print_info: LF token              = 201 'Ċ'
print_info: FIM PRE token         = 128801 '<|fim▁begin|>'
print_info: FIM SUF token         = 128800 '<|fim▁hole|>'
print_info: FIM MID token         = 128802 '<|fim▁end|>'
print_info: EOG token             = 128007 '<|im_end|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 44 repeating layers to GPU
load_tensors: offloaded 46/46 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 375759.26 MiB
load_tensors:        CUDA0 model buffer size =  5898.51 MiB
load_tensors:        CUDA1 model buffer size =  5973.75 MiB
....................................................................................................
common_init_result: added <|im_end|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 2048
llama_context: n_ctx_seq     = 2048
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 2048
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 5000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (2048) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.49 MiB
llama_kv_cache_iswa: creating non-SWA KV cache, size = 2048 cells
llama_kv_cache:      CUDA0 KV buffer size =    48.00 MiB
llama_kv_cache:      CUDA1 KV buffer size =    48.00 MiB
llama_kv_cache: size =   96.00 MiB (  2048 cells,  12 layers,  1/1 seqs), K (f16):   48.00 MiB, V (f16):   48.00 MiB
llama_kv_cache_iswa: creating     SWA KV cache, size = 2048 cells
llama_kv_cache:      CUDA0 KV buffer size =   136.00 MiB
llama_kv_cache:      CUDA1 KV buffer size =   128.00 MiB
llama_kv_cache: size =  264.00 MiB (  2048 cells,  33 layers,  1/1 seqs), K (f16):  132.00 MiB, V (f16):  132.00 MiB
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve:      CUDA0 compute buffer size =  5794.25 MiB
sched_reserve:      CUDA1 compute buffer size =  1103.00 MiB
sched_reserve:  CUDA_Host compute buffer size =    96.09 MiB
sched_reserve: graph nodes  = 3422
sched_reserve: graph splits = 151 (with bs=2048), 87 (with bs=1)
sched_reserve: reserve took 21.70 ms, sched copies = 1

system_info: n_threads = 56 (n_threads_batch = 56) / 56 | CUDA : ARCHS = 860 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 
compute_imatrix: tokenizing the input ..
compute_imatrix: tokenization took 585.593 ms
compute_imatrix: computing over 200 chunks, n_ctx=2048, batch_size=2048, n_seq=1
compute_imatrix: 12.43 seconds per pass - ETA 41.43 minutes
[1]86644.2818,[2]87846.2570,[3]85126.1948,[4]85482.5234,[5]86821.5460,[6]86843.7771,[7]85988.4366,[8]87247.1141,[9]88148.0087,
save_imatrix: entry '               blk.43.ffn_up_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.42.ffn_down_exps.weight' has partial data (3.12%)
save_imatrix: entry '             blk.39.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.38.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.39.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.37.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.36.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.40.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.35.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.35.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '               blk.34.ffn_up_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.34.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.33.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.33.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '               blk.39.ffn_up_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.32.ffn_down_exps.weight' has partial data (3.12%)
save_imatrix: entry '               blk.32.ffn_up_exps.weight' has partial data (3.12%)
save_imatrix: entry '             blk.34.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.31.ffn_down_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.31.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.40.ffn_gate_exps.weight' has partial data (2.78%)
save_imatrix: entry '             blk.43.ffn_gate_exps.weight' has partial data (2.78%)

I canceled it because the partial data for the experts and the 80,000+ PPL make it seem like something has gone wrong in the conversion or inference process somewhere.


IIIIIllllIIIIIlllll commented Feb 3, 2026


I'm seeing the same issue with the leaked 'tool_call' tokens.

Edited: However, the result is correct; it indeed helped me write the HTML game I wanted. @forforever73


drrros commented Feb 3, 2026

Same in Cline.

Running with:

LLAMA_SET_ROWS=1 ./build/bin/llama-server --model /mnt/ds1nfs/codellamaweights/stepfun/step3p5_flash_Q4_K_S.gguf --port 30000 --host 192.168.0.60 -c $((256*1024)) -fa on --reasoning-format auto --no-mmap --jinja --temp 1.0

Speed I'm getting:

prompt eval time =  266478.36 ms / 47459 tokens (    5.61 ms per token,   178.10 tokens per second)
       eval time =    3651.26 ms /   141 tokens (   25.90 ms per token,    38.62 tokens per second)
      total time =  270129.61 ms / 47600 tokens

This is on an EPYC 9274F / 12×32 GB 4800 MT/s / dual NVIDIA A5000.

@forforever73
Author

@AesSedai Sorry about that. For now, please use the pre-quantized GGUF model: https://huggingface.co/stepfun-ai/Step-3.5-Flash-Int4
This is because an offline +1 adjustment was applied to the weights before conversion. I’ll move this part into convert_hf_to_gguf as soon as possible.
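To make that concrete, here is a minimal sketch of how such an offline adjustment could be folded into the converter's modify_tensors hook instead of being applied to the weights beforehand. The tensor-name pattern and class name are placeholders, not the actual ones; only the follow-up change will define which tensors really need the +1.

```python
# Hypothetical sketch: folding an offline "+1" weight adjustment into
# convert_hf_to_gguf.py instead of pre-patching the checkpoint.
# The tensor-name pattern below is a placeholder assumption.
class Step35Model(TextModel):
    def modify_tensors(self, data_torch, name, bid):
        if name.endswith("some_adjusted_tensor.weight"):   # placeholder pattern
            data_torch = data_torch + 1.0                   # apply the offline +1 at conversion time
        return [(self.map_tensor_name(name), data_torch)]
```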


eauchs commented Feb 3, 2026

I'm getting around 23 tokens/s with an M3 Max 128 GB; this is really great!

@forforever73
Author

Tool calling is still missing some support in llama.cpp at the moment. I’ll submit the next PR to address this as soon as possible 💪🙂

@IIIIIllllIIIIIlllll

> Tool calling is still missing some support in llama.cpp at the moment. I’ll submit the next PR to address this as soon as possible 💪🙂

After testing, I found that this bug occurs when multiple MCP tools are provided.

If only one MCP tool is provided (as far as I can tell), the issue does not occur.


tarruda commented Feb 3, 2026

> Tool calling is still missing some support in llama.cpp at the moment. I’ll submit the next PR to address this as soon as possible 💪🙂

Looking forward to it!

This is the best LLM I could run locally so far, thank you for it!

@joonanykanen

@tarruda I share your thoughts. This model seems extremely intelligent. Running ~16 tok/s with 2×RTX 3090 and 128 GB DDR4. Makes me want to invest in Pro 6000 Blackwells lmao!

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

If someone wants a version with fully working reasoning + tool calling, I've added a cherry-picked version of my autoparser branch. Already tested with OpenCode and works great so far.

https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun


tarruda commented Feb 3, 2026

> If someone wants a version with fully working reasoning + tool calling, I've added a cherry-picked version of my autoparser branch. Already tested with OpenCode and works great so far.
>
> https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun

Thank you @pwilkin, will use that branch for now!


drrros commented Feb 3, 2026

@pwilkin

> https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun

This doesn't compile for me:

git status
On branch autoparser-stepfun
Your branch is up to date with 'origin/autoparser-stepfun'.

nothing to commit, working tree clean
...
cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA_ALL_QUANTS=ON && cmake --build build --config Release -j 24
...
/bin/ld: ../../bin/libllama.so.0.0.7931: undefined reference to `bool llama_model_loader::get_key_or_arr<float, 512ul>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::array<float, 512ul>&, unsigned int, bool)'
collect2: error: ld returned 1 exit status
gmake[2]: *** [examples/simple/CMakeFiles/llama-simple.dir/build.make:102: bin/llama-simple] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:3749: examples/simple/CMakeFiles/llama-simple.dir/all] Error 2
gmake[1]: *** Waiting for unfinished jobs....
[ 65%] Building CXX object common/CMakeFiles/common.dir/json-partial.cpp.o
[ 65%] Linking CXX executable ../../bin/llama-simple-chat
/bin/ld: ../../bin/libllama.so.0.0.7931: undefined reference to `bool llama_model_loader::get_key_or_arr<float, 512ul>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::array<float, 512ul>&, unsigned int, bool)'
collect2: error: ld returned 1 exit status
gmake[2]: *** [examples/simple-chat/CMakeFiles/llama-simple-chat.dir/build.make:102: bin/llama-simple-chat] Error 1
gmake[1]: *** [CMakeFiles/Makefile2:3779: examples/simple-chat/CMakeFiles/llama-simple-chat.dir/all] Error 2

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

@drrros sorry, forgot to commit that fix, try now.


ngladitz commented Feb 3, 2026

> @pwilkin
>
> https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun
>
> This doesn't compile for me:

I ran into the same issue, but taking https://github.com/pwilkin/llama.cpp/tree/autoparser and then cherry-picking this PR's commit on top worked for me.

I do occasionally see "Invalid diff:" exceptions. A tool "string" parameter that happens to consist of only digits (and so is incidentally also a legal integer) is shown once with and once without quotes.


Edgar-I commented Feb 3, 2026

@pwilkin Compiling now, thanks

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

> I do occasionally see "Invalid diff:" exceptions. A tool "string" parameter that happens to consist of only digits (and so is incidentally also a legal integer) is shown once with and once without quotes.

That's a good debug case, could you possibly paste it here?

@forforever73
Author

@pwilkin I tried your branch and it does fix the tool call issue — thanks!
Is this a general issue, or something tied to this PR or the Step 3.5 model?

Comment on lines 968 to 976
def add_rope_scaling_apply_mask(self, yarn_only_types: Sequence[str] | None) -> None:
    apply_mask = 0x3  # default: apply on all layers (backwards compatible)
    if isinstance(yarn_only_types, list):
        apply_mask = 0
        if "full_attention" in yarn_only_types:
            apply_mask |= 0x1
        if "sliding_attention" in yarn_only_types:
            apply_mask |= 0x2
    self.add_uint32(Keys.Rope.SCALING_APPLY_MASK.format(arch=self.arch), int(apply_mask))
Collaborator

this is too hacky IMO, we already had a notion of hparams.swa_layers and it should be used instead. See MiMo2 model for an example:

ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);

Comment on lines 57 to 60
const uint32_t apply_mask = hparams.rope_scaling_apply_mask;
if ((is_swa && (apply_mask & 0x2)) || (!is_swa && (apply_mask & 0x1))) {
    rope_factors = model.get_rope_factors(cparams, il);
}
Collaborator

The use of the word "mask" here is quite ambiguous and can be interpreted as something like an attention mask. Either rename it to "bitmask" or, better, don't use a bit mask at all and store it as a dedicated std::array<bool, LLAMA_MAX_LAYERS>.

You really don't need a mask here: is_swa already provides the SWA-layer info, so the rope scaling flag can just be an array of bool.

Comment on lines 217 to 218
std::array<float, LLAMA_MAX_LAYERS> swiglu_limits;
std::array<float, LLAMA_MAX_LAYERS> swiglu_limits_shared;
Collaborator

I'm not a fan of calling the same thing by different names.

This is just clamping; even the Python code calls it "clamp". I don't care how config.json calls it.

Suggested change
- std::array<float, LLAMA_MAX_LAYERS> swiglu_limits;
- std::array<float, LLAMA_MAX_LAYERS> swiglu_limits_shared;
+ std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp;   // clamping for expert FFN
+ std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared exp

Comment on lines 2480 to 2490
format("%s.rope.scaling.apply_mask", ml.get_arch_name().c_str()),
hparams.rope_scaling_apply_mask,
false
);

hparams.has_rope_freq_base_per_layer = ml.get_key_or_arr(
format("%s.rope.freq_base_per_layer", ml.get_arch_name().c_str()),
hparams.rope_freq_base_per_layer,
hparams.n_layer,
false
);
Collaborator

Add these as proper LLM_KV_* keys, like all other models do.

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

> @pwilkin I tried your branch and it does fix the tool call issue — thanks! Is this a general issue, or something tied to this PR or the Step 3.5 model?

I'm refactoring the parser in general so that it handles new typical templates automatically (and I tackle a few edge cases that are annoying during agentic coding). It's just that the model doesn't have a dedicated parser in master yet (which is how things were done till now).


ngladitz commented Feb 3, 2026

> I do occasionally see "Invalid diff:" exceptions. A tool "string" parameter that happens to consist of only digits (and so is incidentally also a legal integer) is shown once with and once without quotes.
>
> That's a good debug case, could you possibly paste it here?

@pwilkin I think with the following test case I consistently see the ref parameter being generated without quotes.
At first I thought this was tied to streaming, but that doesn't seem to be the case.

Test case here
curl -N http://localhost:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{
                "model": "step-3.5-flash",
                "stream": false,
                "messages": [
                        {"role": "user", "content": "call the magic tool with ref 5123123 and name fooBar"}
                ],
                "tools": [
                        {
                                "type": "function",
                                "function": {
                                        "name": "magic",
                                        "description": "Magic tool that takes a hash",
                                        "parameters": {
                                                "type": "object",
                                                "properties": {
                                                        "name": {"type": "string"},
                                                        "ref": {"type": "string"}
                                                },
                                                "required": ["name", "ref"]
                                        }
                                }
                        }
                ],
                "tool_choice": "auto"
        }'

Relevant output (reformatted for readability):

{
    "tool_calls": [
        {
            "type": "function",
            "function": {
                "name": "magic",
                "arguments": {
                    "name": "fooBar",
                    "ref": 5123123
                }
            },
            "id": "EmNp5CqLXcPOl91dF0OqiEIYsZyQ1TY3"
        }
    ]
}

The tool schema says both parameters should be strings, but ref is missing quotes.

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

@ngladitz Yeah, good case, thanks, will fix and add to tests.

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

@ngladitz BTW it is, in a sense, a streaming problem: the tool reads the input as a string but then parses it as a number, so there's a divergence between the partial parse result (a string) and the final result (a number). In the JSON rendering, the string version isn't a prefix of the number version, hence the error.
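A tiny, standalone illustration of that divergence (plain Python, not the llama.cpp parser itself): the partial value renders as a quoted JSON string, the final value renders as a bare number, and the latter is not an extension of the former, so a prefix-based streaming diff rejects it.

```python
import json

partial_value = "5123123"   # what the partial parse holds: a JSON string of digits
final_value = 5123123       # what the final parse holds: a JSON number

print(json.dumps(partial_value))  # "5123123"  (with quotes)
print(json.dumps(final_value))    # 5123123    (without quotes)

# A streaming diff that expects the final rendering to extend the partial one fails here:
print(json.dumps(final_value).startswith(json.dumps(partial_value)))  # False -> "Invalid diff"
```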

@pwilkin
Collaborator

pwilkin commented Feb 3, 2026

@ngladitz aight, can you please check if the newest commit on https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun properly supports streaming with that tool?

pwilkin mentioned this pull request Feb 3, 2026

ngladitz commented Feb 3, 2026

> @ngladitz aight, can you please check if the newest commit on https://github.com/pwilkin/llama.cpp/tree/autoparser-stepfun properly supports streaming with that tool?

@pwilkin Thank you, that seems to have fixed both my reduced test case and my actual use ❤️


AesSedai commented Feb 3, 2026

@forforever73 I've re-converted the BF16 and am running a new imatrix; the values look correct now, PPL is approximately 3-4, and the experts are showing much better data coverage. Thanks!

lvyichen and others added 4 commits February 4, 2026 19:15
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Comment on lines 7723 to 7724
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [1 if lt == "sliding_attention" else 0 for lt in layer_types]
Collaborator

Suggested change
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [1 if lt == "sliding_attention" else 0 for lt in layer_types]
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [lt == "sliding_attention" for lt in layer_types]

You need to change this otherwise CI will fail.
