
Feat/kv cache fp8 support #26

Merged
luozixin2 merged 10 commits into SJTU-DENG-Lab:feat/kv-cache-fp8-support from luozixin2:feat/kv-cache-fp8-support
Jan 28, 2026

Conversation

@luozixin2
Collaborator

@luozixin2 luozixin2 commented Jan 18, 2026

Summary by CodeRabbit

  • New Features

    • Added support for vLLM-compatible GPTQ Marlin and AWQ Marlin quantization formats.
    • Introduced new configuration parameters for W8A16 quantization tuning.
    • Enhanced CLI options for benchmark configuration overrides.
  • Improvements

    • Optimized linear layer forward planning for improved CUDA graph performance.
    • Refactored quantization strategies to leverage external kernel implementations.
  • Chores

    • Removed benchmark result artifacts and deprecated configuration files.
    • Cleaned up legacy kernel implementations and corresponding PyTorch bindings.


luozixin2 added 2 commits January 18, 2026 05:43
Main changes:
- Add GPTQ Marlin (W4A16) and AWQ Marlin (W4A16) quantization strategies
- Fix loader.py to correctly load gptq_marlin-format weights (supporting Marlin-specific repacked qweight and permuted scales)
- Update quantize_model.py to support exporting the gptq_marlin format (symmetric quantization + Marlin repack/permute)
- Update linear.py:
  - Add an _offline_quant_bits buffer to store the quantization bit width
  - Add GPTQ runtime shuffle support (gptq_shuffle)
  - Add lazy repack support for GPTQ/AWQ Marlin (_maybe_prepare_offline_gptq_marlin/_awq_marlin)
  - Standardize on the vLLM format (int32 packed, fp16 scales); see the sketch after this list
- Simplify the individual strategy files and remove duplicated code
- Remove the old AllSpark Marlin implementation files
- Add multiple benchmark config files (one per GPTQ/AWQ Marlin bit width)
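For orientation, here is a minimal sketch of what the unified vLLM-style layout implies for a single GPTQ-quantized linear layer; the sizes, bit width, and group_size below are illustrative assumptions, not values taken from this PR.

import torch

# Hypothetical shapes for one GPTQ linear layer in the vLLM-style layout
# (int32-packed weights, fp16 scales). All sizes here are assumptions.
in_features, out_features, bits, group_size = 4096, 4096, 4, 128
pack_factor = 32 // bits  # each int32 word packs 32/bits quantized values

qweight = torch.zeros(in_features // pack_factor, out_features, dtype=torch.int32)
qzeros = torch.zeros(in_features // group_size, out_features // pack_factor, dtype=torch.int32)
scales = torch.zeros(in_features // group_size, out_features, dtype=torch.float16)

print(qweight.shape, qzeros.shape, scales.shape)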
benchmark_results contains locally generated benchmark artifacts and should not be tracked in the repository.
This commit removes it as a regular deletion and relies on the benchmark_results/ rule in .gitignore to keep it from being committed again.
@coderabbitai

coderabbitai bot commented Jan 18, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


@luozixin2 luozixin2 marked this pull request as draft January 18, 2026 06:01
luozixin2 added 8 commits January 18, 2026 06:40
- Add quant-method=auto support: use auto-gptq / awq for real calibration-based quantization
- Add calibration data parameters: --calib-text-file, --calib-num-samples, --calib-seq-len, etc.
- Implement _export_autogptq_to_vllm_weights: export vLLM-format weights from an auto-gptq quantized model
- Implement _export_awq_to_vllm_weights: export vLLM-format weights from an awq quantized model
- Keep the old quant-method=simple implementation for backward compatibility
- Fix the gptq_marlin scales shape inference and TP sharding logic in loader.py
- Fix linear_gptq_marlin_w4a16.py by removing an unnecessary bf16->fp16 conversion
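For context, calibration-based quantization with auto-gptq typically follows the pattern below. This is a generic auto-gptq usage sketch under assumed paths and settings, not the repository's _export_autogptq_to_vllm_weights implementation.

# Sketch of a generic auto-gptq calibration flow (not this repo's code).
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

model_path = "path/to/model"          # hypothetical path
tokenizer = AutoTokenizer.from_pretrained(model_path)
calib_texts = ["an example calibration sentence"] * 128  # stands in for --calib-text-file lines

quant_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False, sym=True)
model = AutoGPTQForCausalLM.from_pretrained(model_path, quant_config)

# auto-gptq consumes tokenized examples (input_ids / attention_mask)
examples = [tokenizer(t, return_tensors="pt") for t in calib_texts]
model.quantize(examples)               # runs calibration and produces qweight/qzeros/scales
model.save_quantized("quantized_out")  # these tensors are then repacked into the vLLM format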
Main refactoring:

1. **diffulex/layer/linear.py** - significantly simplified quantization logic (-197 lines):
   - Added `_forward_base()`: a unified forward dispatcher that replaces the duplicated quantization branching in subclasses
   - Added `_build_offline_forward_kwargs()`: unified construction of the offline-quantization (GPTQ/AWQ) forward arguments
   - Added helper methods such as `_get_linear_strategy()`, `_offline_meta()`, and `_infer_gptq_weight_bits()`
   - Fixed the edge case in `LoRAMixin.merge_lora` where the base weight is None
   - Removed unused imports (marlin_zero_points, unpack_cols, marlin_make_empty_g_idx)

2. **diffulex/utils/loader.py** - improved performance and code structure:
   - Scan the safetensors files once to build a key_to_file index, avoiding repeated file I/O (see the sketch after this commit message)
   - Cache the result of `model.named_modules()` to avoid rebuilding the dict
   - Added `_find_offline_capable_module()`: unified module lookup logic
   - Added `_load_tensors_for_prefix()`: centralized tensor loading that opens only the necessary files
   - Replaced print() with logger.warning()/logger.exception() for consistent logging

3. **diffulex/engine/model_runner.py** - eliminated duplicated loops:
   - Cache the list of attention modules once in `allocate_kv_cache`
   - Replace the repeated module-traversal loops with `enumerate(attn_modules)`

4. **diffulex/utils/quantization/strategies/linear_int4_w4a16.py** - fixed a missing implementation:
   - Added the `quantize_weight_for_kernel` method, fixing a runtime error in W4A16 online quantization

5. Deleted the unused config file `gptq_marlin_w2_bf16kv_varlen.yml`

Testing: verified that W8A16 online quantization and GPTQ offline quantization work correctly
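A minimal sketch of the one-pass key_to_file index and prefix loading described in point 2 above; the function names mirror the commit message, but the bodies are illustrative rather than the repository's exact code.

import torch
from pathlib import Path
from safetensors import safe_open

def build_key_to_file_index(model_dir: str) -> dict[str, Path]:
    # Scan every shard once and remember which file owns each tensor key.
    key_to_file: dict[str, Path] = {}
    for shard in sorted(Path(model_dir).glob("*.safetensors")):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for key in f.keys():
                key_to_file[key] = shard
    return key_to_file

def load_tensors_for_prefix(key_to_file: dict[str, Path], prefix: str) -> dict[str, torch.Tensor]:
    # Only open the shards that actually contain keys under this prefix.
    wanted = [k for k in key_to_file if k.startswith(prefix)]
    tensors: dict[str, torch.Tensor] = {}
    for shard in {key_to_file[k] for k in wanted}:
        with safe_open(shard, framework="pt", device="cpu") as f:
            for k in wanted:
                if key_to_file[k] == shard:
                    tensors[k] = f.get_tensor(k)
    return tensors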
- Change the final summary from the last step's instantaneous throughput to a true average (total tokens / total time)
- Add ms/step statistics to make performance analysis easier
- Fix the issue where only the last step's instantaneous value was reported instead of the average
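The reporting change boils down to arithmetic like the following; the per-step numbers are made up for illustration.

# Sketch: report the true run-level average instead of the last step's instantaneous value.
tokens_per_step = [256, 256, 256, 128]        # illustrative per-step token counts
step_times_s = [0.031, 0.030, 0.032, 0.018]   # illustrative per-step wall-clock times

total_tokens = sum(tokens_per_step)
total_time_s = sum(step_times_s)
avg_throughput = total_tokens / total_time_s             # tokens/s over the whole run
ms_per_step = 1000.0 * total_time_s / len(step_times_s)  # the new ms/step statistic
print(f"avg: {avg_throughput:.1f} tok/s, {ms_per_step:.2f} ms/step")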
- Quantized linear: drop the kwargs/pop and duplicated availability checks; cache out_features and the necessary intermediate tensors
- Call vLLM CUDA ops directly (W8A8/GPTQ/AWQ/Marlin, etc.) to reduce Python glue overhead
- Handle qweight/scales layout and contiguity at load time, so forward no longer repeats the work (see the sketch after this list)
- Remove the profiler record annotations from linear.py to keep the code clean
- Add trace/profile helper analysis scripts and related tests
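A rough sketch of moving the layout work to load time, as described above; the attribute names echo the PR's terminology, but the function itself is illustrative, not the repository's code.

import torch

def prepare_offline_layout(module: torch.nn.Module) -> None:
    # Sketch: normalize packed-weight layout once when weights are loaded,
    # so linear_forward never has to call .contiguous() or fix dtypes per call.
    qweight = getattr(module, "qweight", None)
    scales = getattr(module, "scales", None)
    if qweight is not None and not qweight.is_contiguous():
        module.qweight = qweight.contiguous()
    if scales is not None:
        module.scales = scales.to(torch.float16).contiguous()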
… strategies

- Remove all .item() calls in LinearBase hot paths (GPU->CPU sync breaks graph capture)
  - Add Python-side meta cache (_offline_quant_*_py, _gptq_is_shuffled_py, etc.)
  - Use in-place fill_() + Python mirrors for state updates
- Simplify linear quantization strategies for future CUDA Graph support
  - Remove fast_path checks and redundant branching in linear_marlin_int8_w8a16
  - Remove fast_path in linear_int8_w8a8 (unified vLLM path)
  - Simplify linear_gptq_w4a16 (direct torch.ops._C.gptq_gemm call)
  - Make linear_fp8_w8a16 use explicit quant_scales parameter
- Fix FP8 weight layout: do not force contiguous for transpose-view (KxN stride0==1)
- Remove profiler record_function wrappers (graph-friendly)

Net: -129 lines, cleaner codebase ready for CUDA Graph capture
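To illustrate why dropping .item() matters for capture, here is a sketch of the buffer-plus-Python-mirror pattern this commit describes; the class is illustrative, and only the attribute names echo the commit message.

import torch

class GraphSafeStateSketch(torch.nn.Module):
    # Keep a GPU-side buffer (so state survives checkpointing / device moves)
    # and mirror it in a plain Python attribute that hot paths read instead of
    # calling .item(), which forces a GPU->CPU sync and breaks graph capture.
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("_gptq_is_shuffled", torch.zeros(1, dtype=torch.int32))
        self._gptq_is_shuffled_py = False

    def mark_shuffled(self) -> None:
        self._gptq_is_shuffled.fill_(1)   # in-place update, capture friendly
        self._gptq_is_shuffled_py = True  # Python mirror, no device sync to read

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._gptq_is_shuffled_py:     # branch on the mirror, not on .item()
            return x
        return x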
- Add per-layer ForwardPlan to pre-resolve bf16/quant/offline paths and reduce per-call Python branching.
- Prefer direct torch.ops kernels (GPTQ/AWQ/Marlin) with static args for stable capture.
- Fix D2F static CUDA graph capture/replay metadata (token buckets + cu_seqlens) and add profiler flag.
- Fix tensor shape mismatch bug in static+CUDA Graph decode mode (model_runner.py)
  - Improve bucket selection logic for variable token counts (see the sketch after this commit message)
  - Add safety fallback when runtime batch exceeds captured capacity
  - Fix metadata buffer initialization and padding

- Add new static mode benchmark configs:
  - awq_bf16kv_static.yml
  - gptq_marlin_w4_bf16kv_static.yml
  - gptq_marlin_w8_bf16kv_static.yml

- Update quantization strategies and loader utilities
- Update benchmark configurations for consistency
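A sketch of the bucket selection and safety fallback referenced above; the bucket sizes and function name are hypothetical and not taken from model_runner.py.

# Sketch: pick the smallest captured token bucket that fits the runtime batch,
# and fall back to eager (non-graph) execution if nothing captured is large enough.
def select_bucket(num_tokens: int, captured_buckets: list[int]) -> int | None:
    for bucket in sorted(captured_buckets):
        if num_tokens <= bucket:
            return bucket   # replay the CUDA graph captured for this bucket size
    return None             # safety fallback: run the normal (uncaptured) path

assert select_bucket(37, [32, 64, 128]) == 64
assert select_bucket(300, [32, 64, 128]) is None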
- Remove the bench configs and quantization architecture docs added after v0.0.1
- Consolidate W8A16/DP tuning parameters from environment variables into Config/strategy.configure
- Remove hardcoded local paths and default GPU selections from examples/scripts, and fix syntax issues
@luozixin2 luozixin2 marked this pull request as ready for review January 28, 2026 02:39
@luozixin2 luozixin2 merged commit 600eb4c into SJTU-DENG-Lab:feat/kv-cache-fp8-support Jan 28, 2026
1 check passed
@luozixin2 luozixin2 deleted the feat/kv-cache-fp8-support branch January 28, 2026 02:47

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🤖 Fix all issues with AI agents
In `@diffulex/engine/model_runner.py`:
- Around line 38-46: dist.init_process_group is calling init_method with an
unconditional access to config.device_ids[rank] which can raise if device_ids is
missing; guard that access by computing device_id first (use
config.device_ids[rank] only if getattr(config, "device_ids", None) is truthy,
otherwise compute from config.device_start + rank) and then pass device_id into
dist.init_process_group (and keep the existing assert using
torch.cuda.device_count()); update the call site referencing config.device_ids,
device_id, dist.init_process_group, and torch.cuda.device_count() accordingly.

In `@diffulex/utils/loader.py`:
- Around line 286-294: The fuzzy matching in the loop (using
offline_capable_modules, module_name and leaf) is too permissive and can map
wrong modules; replace the loose checks with stricter rules: keep exact match,
allow only full-dot-boundary suffix/prefix matches (name == module_name or name
== f"{parent}.{module_name}" style), and remove the single-segment fallback
(name.split(".")[-1] == leaf) or tighten it to require the last two path
segments to match (compare the last two segments of name and module_name) before
returning cand; update the matching logic inside the function that iterates
offline_capable_modules and add a short comment explaining the fallback
behavior.

In `@diffulex/utils/quantization/quantize_model.py`:
- Around line 104-119: The helper _load_calib_texts silently returns fewer
samples when the file has fewer lines than requested; change the logic in
_load_calib_texts so that when len(lines) < num_samples you emit a clear warning
(use the warnings module) indicating fewer samples are available and return all
lines, and only use the rng.sample path when len(lines) >= num_samples (i.e.,
change the current len(lines) <= num_samples branch to handle the < case with a
warning and return, keeping rng.sample(lines, k=num_samples) for the sampling
case); reference the function name _load_calib_texts and the variables lines,
num_samples, seed when making the change.

In `@diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py`:
- Line 159: The AWQ strategy currently hardcodes is_k_full=True which can be
incorrect for row-parallel tensor-parallel cases; change the assignment of
is_k_full in the linear_awq_marlin_w4a16 strategy to compute it dynamically
(e.g., mirror the GPTQ Marlin logic) using the existing has_g_idx and
row_parallel variables so is_k_full = not has_g_idx or row_parallel (or
equivalent logic used in the GPTQ Marlin implementation); update the place where
the tuple/flag is constructed (the entry labeled is_k_full) so it derives from
has_g_idx and row_parallel rather than the literal True.

In `@diffulex/utils/quantization/strategies/linear_fp8_w8a16.py`:
- Around line 115-129: The else branch assumes weight is non-null but calls
id(weight) and self.quantize(weight) even when weight is None; add an explicit
check at the start of that branch (e.g., if weight is None: raise
ValueError(...) or handle a valid fallback) so you never call id() or
self.quantize with None; update the block that uses wid = id(weight), cached =
self._weight_cache.get(wid), self.quantize(weight), and subsequent assignments
to q_fp8, meta, scales, q_kn to only run when weight is verified non-None and
provide a clear error message referencing weight and quant_scales for easier
debugging.
- Around line 62-64: The get_storage_dtype function currently returns
torch.uint8 although the comment and vLLM use an FP8 float dtype; change the
return value in get_storage_dtype to use PyTorch's FP8 dtype
(torch.float8_e4m3fn) so it returns (torch.float8_e4m3fn, 1) instead of
(torch.uint8, 1), ensuring the storage dtype matches vLLM's FP8 implementation.

In `@diffulex/utils/quantization/strategies/linear_fp8_w8a8.py`:
- Around line 109-117: The code calls self.quantize(weight) without guarding
against weight being None, which will raise if weight is absent; add the same
null-check used in the W8A16 variant before computing wid/using the cache (or
directly before calling self.quantize) so that when weight is None you skip
quantization and set q_fp8/w_scale to appropriate defaults or handle the
None-case consistently; update the branch around wid, self._weight_cache, and
the call to self.quantize(weight) to first check "if weight is None" and return
or assign safe placeholders used by the rest of the method.

In `@diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py`:
- Around line 181-202: The torch.ops._C.gptq_marlin_gemm call has its arguments
shifted and doesn't match vLLM's 17-arg signature; reorder the call so arguments
follow gptq_marlin_gemm(a, c_or_none, b_q_weight, b_scales, global_scale,
b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, b_q_type, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float): pass
reshaped_x, None, qweight, scales, marlin_bias (as global_scale), zp (as
b_zeros_or_none), g_idx_t, g_idx_sort_t, workspace, wtype.id, m, n, k,
is_k_full, use_atomic_add, True (use_fp32_reduce), False (is_zp_float) so all
parameters (reshaped_x, qweight, scales, marlin_bias, zp, g_idx_t, g_idx_sort_t,
workspace, wtype.id, m, n, k, is_k_full, use_atomic_add) are in the correct
positions for torch.ops._C.gptq_marlin_gemm.

In `@diffulex/utils/quantization/strategies/linear_gptq_w4a16.py`:
- Around line 130-131: Remove the redundant device comparison when retrieving
cached empty tensors: rely on the cache key (dev_key) to ensure device
correctness and only check "empty is None" before creating a new tensor.
Concretely, in the block that currently reads "if empty is None or empty.device
!= device:", change it to "if empty is None" and keep the creation line "empty =
torch.empty((0,), device=device, dtype=torch.int)" so CPU tensors (device.index
is None) are handled consistently.
🧹 Nitpick comments (17)
diffulex_bench/configs/example.yml (1)

21-25: Consider aligning max_num_batched_tokens with the new max_model_len.

With max_model_len: 4096, keeping max_num_batched_tokens: 4096 effectively limits full-length batches to a single sequence. If the example aims to allow multi‑sequence batching at 4k context, consider raising max_num_batched_tokens or adding a brief note to avoid confusion.

diffulex/utils/quantization/registry.py (1)

87-103: Update error message to include new quantization methods.

The error message on line 101-102 lists supported methods but doesn't include the newly added gptq_marlin, gptq_marlin_24, and awq_marlin aliases.

✏️ Suggested fix
     if s not in aliases:
         raise ValueError(
             f"Unsupported linear quant dtype={dtype!r}. "
-            "Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq/marlin"
+            "Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/gptq_marlin/gptq_marlin_24/awq/awq_marlin/marlin"
         )
diffulex_kernel/__init__.py (1)

24-45: Consider caching imported modules to avoid repeated imports.

Each call to __getattr__ re-executes the import statement. For frequently accessed attributes, consider caching the imported module in the module's namespace:

♻️ Suggested improvement
 def __getattr__(name: str):
     if name == "dllm_flash_attn_decode":
         from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode
-
+        globals()[name] = dllm_flash_attn_decode
         return dllm_flash_attn_decode
     if name == "dllm_flash_attn_prefill":
         from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_prefill
-
+        globals()[name] = dllm_flash_attn_prefill
         return dllm_flash_attn_prefill
-    if name == "store_kvcache_distinct_layout":
-        from diffulex_kernel.python.kv_cache_kernels import store_kvcache_distinct_layout
-
-        return store_kvcache_distinct_layout
-    if name == "store_kvcache_unified_layout":
-        from diffulex_kernel.python.kv_cache_kernels import store_kvcache_unified_layout
-
-        return store_kvcache_unified_layout
-    if name == "load_kvcache":
-        from diffulex_kernel.python.kv_cache_kernels import load_kvcache
-
-        return load_kvcache
+    # ... similar for other attributes
     raise AttributeError(name)
diffulex/utils/quantization/strategies/linear_marlin_int8_w8a16.py (2)

69-82: Silent exception swallowing may hide configuration bugs.

The try/except blocks with bare pass silently ignore any errors when parsing config values. This could hide misconfigurations (e.g., a string value that can't be converted to int). Consider logging a warning or using a more specific exception type.

♻️ Suggested improvement
     def configure(self, *, diffulex_config: Any | None = None) -> None:
         # Prefer explicit config fields over environment-variable based tuning.
         if diffulex_config is None:
             return
         try:
             bn = int(getattr(diffulex_config, "linear_w8a16_quant_block_n", self._quant_block_n))
             self._quant_block_n = max(1, bn)
-        except Exception:
-            pass
+        except (TypeError, ValueError) as e:
+            import warnings
+            warnings.warn(f"Invalid linear_w8a16_quant_block_n config: {e}")
         try:
             thr = int(getattr(diffulex_config, "linear_w8a16_allspark_cublas_m_threshold", self._cublas_m_thr))
             self._cublas_m_thr = max(1, thr)
-        except Exception:
-            pass
+        except (TypeError, ValueError) as e:
+            import warnings
+            warnings.warn(f"Invalid linear_w8a16_allspark_cublas_m_threshold config: {e}")

304-305: Complex n calculation could use clarifying comment.

The logic for determining n from multiple sources (out_features, bias, scales) is dense. A brief comment explaining the priority and fallback chain would improve maintainability.

diffulex/strategy/d2f/engine/model_runner.py (1)

355-361: Consider narrowing the exception handling for forward plan setup.

The bare except Exception silently swallows all errors, which could mask legitimate issues during LinearBase.enable_forward_plan() setup (e.g., attribute errors, type errors, or device mismatches).

Consider catching only expected exceptions like ImportError or AttributeError, or at minimum logging when an exception occurs:

♻️ Suggested improvement
         try:
             from diffulex.layer.linear import LinearBase
             for m in self.model.modules():
                 if isinstance(m, LinearBase):
                     m.enable_forward_plan(True)
-        except Exception:
-            pass
+        except (ImportError, AttributeError):
+            pass  # LinearBase not available or missing enable_forward_plan
diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py (2)

20-20: Unused import: torch.nn.functional is imported but never used.

F is imported but not referenced anywhere in this file.

🔧 Suggested fix
 import torch
-import torch.nn.functional as F

127-143: Duplicated empty tensor creation logic can be extracted.

The logic for creating/caching empty g_idx tensors is repeated for both g_idx and g_idx_sort_indices. Consider extracting a helper method.

♻️ Suggested refactor
+    def _get_empty_tensor(self, dev_key: int, device: torch.device) -> torch.Tensor:
+        empty = self._empty_cache.get(dev_key)
+        if empty is None:
+            empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
+            self._empty_cache[dev_key] = empty
+        return empty
+
     # Then in linear_forward:
-        if g_idx is None or g_idx.numel() == 0:
-            empty = self._empty_cache.get(dev_key)
-            if empty is None:
-                empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
-                self._empty_cache[dev_key] = empty
-            g_idx_t = empty
-        else:
-            g_idx_t = g_idx
-        if g_idx_sort_indices is None or g_idx_sort_indices.numel() == 0:
-            empty = self._empty_cache.get(dev_key)
-            if empty is None:
-                empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
-                self._empty_cache[dev_key] = empty
-            g_idx_sort_t = empty
-        else:
-            g_idx_sort_t = g_idx_sort_indices
+        empty = self._get_empty_tensor(dev_key, device)
+        g_idx_t = g_idx if (g_idx is not None and g_idx.numel() > 0) else empty
+        g_idx_sort_t = g_idx_sort_indices if (g_idx_sort_indices is not None and g_idx_sort_indices.numel() > 0) else empty
diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py (1)

22-22: Unused import: apply_awq_marlin_linear is imported but never used.

The import fallback assigns it to None, but the actual function is never called in the code.

🔧 Suggested fix
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (  # type: ignore
-        apply_awq_marlin_linear,
         marlin_make_empty_g_idx,
         should_use_atomic_add_reduce,
         marlin_permute_bias,
     )
diffulex/utils/quantization/strategies/linear_gptq_w4a16.py (1)

111-115: Silent fallback to F.linear may mask configuration errors.

When qweight, qzeros, or scales are missing, the code falls back to F.linear(x, weight, bias). This could hide misconfigurations where offline weights were expected but not loaded. Consider logging a warning.

⚠️ Add warning for fallback path
         if qweight is None or qzeros is None or scales is None:
             # correctness fallback (should not happen for offline GPTQ weights)
             if weight is None:
                 raise RuntimeError("GPTQ offline weights missing packed tensors and bf16 weight is not present.")
+            import warnings
+            warnings.warn(
+                f"GPTQ packed tensors missing; falling back to F.linear. "
+                "This may indicate a configuration issue.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
             return F.linear(x, weight, bias)
diffulex/utils/loader.py (4)

24-29: Silent exception swallowing in _read_quantize_config.

Catching all exceptions and returning an empty dict could hide I/O errors or JSON parse errors that indicate real problems (e.g., corrupted config file).

⚠️ Consider logging exceptions
     try:
         with open(cfg_path, "r") as f:
             data = json.load(f)
         return data if isinstance(data, dict) else {}
-    except Exception:
+    except Exception as e:
+        logger.debug(f"Failed to read quantize_config.json: {e}")
         return {}

71-79: _infer_module_device returns CPU as default, which may cause device mismatch issues.

If a module has no parameters or buffers (unlikely but possible for some wrapper modules), the function returns torch.device("cpu"), which could cause silent device mismatches when setting weights.

💡 Consider raising or warning for empty modules
 def _infer_module_device(module: nn.Module) -> torch.device:
     w = getattr(module, "weight", None)
     if isinstance(w, torch.Tensor):
         return w.device
     for p in module.parameters(recurse=False):
         return p.device
     for b in module.buffers(recurse=False):
         return b.device
+    # Module has no parameters/buffers; caller should handle this case
     return torch.device("cpu")

132-135: Complex g_idx handling with nested conditionals.

The logic for handling g_idx with numel() checks is duplicated and hard to follow. The getattr(g_idx, "numel", lambda: 1)() pattern is unusual.

♻️ Simplify g_idx handling
-    if g_idx is None:
-        module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
-    else:
-        if getattr(g_idx, "numel", lambda: 1)() == 0:
-            module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
-        else:
-            module.gptq_g_idx = g_idx.to(dtype=torch.int32)
+    if g_idx is None or g_idx.numel() == 0:
+        module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
+    else:
+        module.gptq_g_idx = g_idx.to(dtype=torch.int32)

382-399: Complex dimension inference logic with multiple fallback paths.

The GPTQ dimension inference has many branches for empty qzeros, Marlin checkpoints, and standard formats. Consider extracting this into a helper function for clarity and testability.
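One possible shape for such a helper, covering only the standard (non-Marlin, non-empty-qzeros) layout; the shapes assumed here follow the usual GPTQ packing convention and may differ from what loader.py actually handles.

import torch

def infer_gptq_dims(qweight: torch.Tensor, scales: torch.Tensor, bits: int) -> tuple[int, int]:
    """Infer (in_features, out_features) for a standard GPTQ checkpoint.

    Assumes qweight is int32 of shape [in_features // (32 // bits), out_features]
    and scales is fp16 of shape [in_features // group_size, out_features].
    Marlin-repacked and empty-qzeros checkpoints need separate branches.
    """
    pack_factor = 32 // bits
    in_features = qweight.shape[0] * pack_factor
    out_features = scales.shape[1]
    return in_features, out_features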

diffulex/utils/quantization/quantize_model.py (3)

272-273: AWQ only supports 4-bit - consider future-proofing.

The current restriction to 4-bit for AWQ is documented, but the error message could be more informative about when other bit widths might be supported.


357-357: Empty perm tensor created on CUDA unconditionally.

Line 357 creates empty_perm on "cuda" device, but if device="cpu" was specified earlier (which would fail at line 314-315), this could cause issues. The earlier check should catch this, but the hardcoded device is inconsistent.

🔧 Use consistent device handling
-        empty_perm = torch.empty((0,), dtype=torch.int32, device="cuda")
+        target_device = "cuda"  # gptq_marlin requires CUDA
+        empty_perm = torch.empty((0,), dtype=torch.int32, device=target_device)

652-670: Copying all safetensors files then adding quantized weights could cause confusion.

The workflow copies original safetensors files and then adds a new quantized file. Users might not realize the original weights are still present. Consider documenting this or optionally removing original weight files.

📝 Add documentation comment
     # Copy all model files (config, tokenizer, etc.) to output directory
     import shutil
     print(f"\nCopying model files to {output_path}...")
     model_path_obj = Path(model_path)
     
-    # First, copy original safetensors files (for non-quantized layers like lm_head, embeddings, etc.)
+    # First, copy original safetensors files (for non-quantized layers like lm_head, embeddings, etc.)
+    # NOTE: Both original and quantized weight files will exist in output directory.
+    # The loader prioritizes quantized weights (*.qweight) over original weights.
     print("  Copying original safetensors files (for non-quantized layers)...")

Comment on lines 38 to 46
dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=config.device_ids[rank])
device_id = (getattr(config, "device_start", 0) or 0) + rank + config.device_ids[rank]
# Choose CUDA device for this TP rank.
# config.device_ids is already a list of logical CUDA device indices (respecting CUDA_VISIBLE_DEVICES).
# Do NOT add rank again, otherwise rank 1 with device_ids=[0,1] becomes device 2.
if getattr(config, "device_ids", None):
device_id = config.device_ids[rank]
else:
device_id = (getattr(config, "device_start", 0) or 0) + rank
assert 0 <= device_id < torch.cuda.device_count(), f"Invalid device_id {device_id}."

⚠️ Potential issue | 🟠 Major

Potential inconsistency in device_ids access.

Line 38 unconditionally accesses config.device_ids[rank], but lines 42-45 check if device_ids exists before using it. If config.device_ids is None or missing, line 38 will raise an error before the fallback logic is reached.

🐛 Suggested fix
         init_method = f"tcp://{config.master_addr}:{config.master_port}"
-        dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=config.device_ids[rank])
         # Choose CUDA device for this TP rank.
         # config.device_ids is already a list of logical CUDA device indices (respecting CUDA_VISIBLE_DEVICES).
         # Do NOT add rank again, otherwise rank 1 with device_ids=[0,1] becomes device 2.
         if getattr(config, "device_ids", None):
             device_id = config.device_ids[rank]
         else:
             device_id = (getattr(config, "device_start", 0) or 0) + rank
+        dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=device_id)

Comment on lines +286 to +294
for name, cand in offline_capable_modules.items():
if (
name == module_name
or name.endswith("." + module_name)
or module_name.endswith("." + name)
or (name.split(".")[-1] == leaf)
):
return cand
return None

⚠️ Potential issue | 🟡 Minor

Fuzzy module name matching could cause incorrect mappings.

The fallback logic using endswith and split(".")[-1] comparisons is permissive and could match unintended modules with similar names (e.g., layer.q_proj matching other_layer.q_proj).

⚠️ Consider more precise matching

The current logic:

or name.endswith("." + module_name)
or module_name.endswith("." + name)
or (name.split(".")[-1] == leaf)

Could match modules incorrectly. Consider documenting this behavior or tightening the matching criteria.
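As a sketch, a stricter version of the lookup could look like this; the helper name mirrors _find_offline_capable_module, but which rules to keep is up to the maintainers.

def find_offline_capable_module_strict(offline_capable_modules: dict, module_name: str):
    # Exact match first.
    if module_name in offline_capable_modules:
        return offline_capable_modules[module_name]
    for name, cand in offline_capable_modules.items():
        # Only allow suffix/prefix matches on a full dot boundary, and require the
        # last two path segments to agree, e.g. "layers.0.q_proj" matches
        # "model.layers.0.q_proj" but not "other.q_proj".
        if name.endswith("." + module_name) or module_name.endswith("." + name):
            if name.split(".")[-2:] == module_name.split(".")[-2:]:
                return cand
    return None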


Comment on lines +104 to +119
def _load_calib_texts(
calib_text_file: str, *, num_samples: int, seed: int
) -> list[str]:
p = Path(calib_text_file)
if not p.exists():
raise FileNotFoundError(f"calib_text_file 不存在: {calib_text_file}")
lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()]
lines = [ln for ln in lines if ln]
if not lines:
raise ValueError(f"calib_text_file 为空: {calib_text_file}")
if num_samples <= 0:
raise ValueError(f"calib_num_samples 必须 > 0, got {num_samples}")
if len(lines) <= num_samples:
return lines[:num_samples]
rng = random.Random(seed)
return rng.sample(lines, k=num_samples)

⚠️ Potential issue | 🟡 Minor

Potential issues in _load_calib_texts.

  1. Line 116-117: If len(lines) <= num_samples, it returns lines[:num_samples], but if len(lines) < num_samples, this silently returns fewer samples than requested.
  2. The seed parameter allows reproducibility, which is good.
⚠️ Consider warning when fewer samples available
     if len(lines) <= num_samples:
+        if len(lines) < num_samples:
+            import warnings
+            warnings.warn(
+                f"Requested {num_samples} samples but only {len(lines)} available in {calib_text_file}",
+                UserWarning,
+            )
         return lines[:num_samples]

m,
n,
k,
True, # is_k_full

⚠️ Potential issue | 🟡 Minor

Hardcoded is_k_full=True may not always be correct.

Unlike the GPTQ Marlin strategy which computes is_k_full based on has_g_idx and row_parallel, the AWQ strategy hardcodes it to True. This could cause issues with row-parallel tensor-parallel configurations.

🔧 Consider computing is_k_full dynamically
+        # AWQ doesn't use g_idx; is_k_full depends on row-parallel setting.
+        row_parallel = bool(tp_dim == 1)
+        is_k_full = True  # AWQ has no g_idx, so is_k_full follows row_parallel logic
+        # Note: For row-parallel, vLLM's marlin_is_k_full(has_g_idx=False, is_row_parallel=True) returns True
+
         out = torch.ops._C.gptq_marlin_gemm(
             ...
-            True,  # is_k_full
+            is_k_full,
             ...
         )

Comment on lines 62 to 64
def get_storage_dtype(self) -> tuple[torch.dtype, int]:
# FP8 weights are stored as uint8 (1 byte per element)
# vLLM stores fp8 weights as float8 dtype tensor
return torch.uint8, 1

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, examine the file structure and the specific lines in question
cat -n diffulex/utils/quantization/strategies/linear_fp8_w8a16.py | head -80

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 465


🏁 Script executed:

# Search for any references to float8 or FP8 in the codebase
rg "float8|float8_e4m3" --type py

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 465


🏁 Script executed:

# Look for vLLM imports and usage patterns
rg "from vllm|import vllm" --type py -A 3

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 465


🏁 Script executed:

# Search for other get_storage_dtype implementations to understand the pattern
ast-grep --pattern 'def get_storage_dtype($$$) {
  $$$
}'

Repository: SJTU-DENG-Lab/Diffulex

Length of output: 465


🌐 Web query:

vLLM FP8 quantization dtype storage float8

💡 Result:

  • What “FP8” means in vLLM: vLLM uses the two standard FP8 formats E4M3FN and E5M2 (different range/precision tradeoffs). vLLM’s “fp8” generally refers to FP8 E4M3. [1][2]
  • Online FP8 quantization (quantization="fp8"): vLLM dynamically quantizes an FP16/BF16 model so that Linear weights are quantized to FP8_E4M3 with a per-tensor scale (activations are dynamically scaled each forward pass). [1]
  • Actual tensor dtype used for stored FP8 weights: when a checkpoint is FP8-serialized, vLLM allocates weight tensors with PyTorch’s float8 dtype torch.float8_e4m3fn (i.e., real float8 storage, not int8). [3]
  • FP8 KV-cache storage: kv_cache_dtype supports "fp8"/"fp8_e4m3" and "fp8_e5m2" (CUDA 11.8+; ROCm supports "fp8" = E4M3). With E4M3, vLLM notes you typically need an FP32 scaling factor alongside each quantized tensor. [2][4]

Sources: FP8 W8A8 docs [1], Engine args / --kv-cache-dtype [2], vLLM FP8 layer code showing torch.float8_e4m3fn weight dtype [3], Quantized KV-cache docs [4].


Return value should match vLLM's actual FP8 dtype storage.

The comment states "vLLM stores fp8 weights as float8 dtype tensor", but the method returns torch.uint8. vLLM actually stores FP8-quantized weights using PyTorch's torch.float8_e4m3fn dtype. Update the return value to torch.float8_e4m3fn to match both the comment and vLLM's implementation.
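Applied literally, the fix would look roughly like this; the surrounding class is elided and its name here is illustrative, while the method name matches the strategy API discussed above.

import torch

class Fp8StorageDtypeSketch:
    def get_storage_dtype(self) -> tuple[torch.dtype, int]:
        # Match vLLM: FP8 weights are stored with PyTorch's float8 dtype,
        # still 1 byte per element.
        return torch.float8_e4m3fn, 1

assert Fp8StorageDtypeSketch().get_storage_dtype() == (torch.float8_e4m3fn, 1)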


Comment on lines +115 to +129
if weight is not None and quant_scales is not None:
# Expected: weight is fp8 K×N tensor (transpose-view is fine).
q_kn = weight.to(device=x.device)
scales = quant_scales.to(device=x.device, dtype=torch.float32).reshape(1)
else:
# Lazy cache: use weight tensor id as key (only for bf16/fp16/fp32 weights)
weight_id = id(weight)

# Check cache
if weight_id in self._weight_cache:
quantized_weight, scales = self._weight_cache[weight_id]
# Ensure cached tensors are on the correct device
if quantized_weight.device != x.device:
quantized_weight = quantized_weight.to(device=x.device)
scales = scales.to(device=x.device)
wid = id(weight)
cached = self._weight_cache.get(wid)
if cached is None or cached[0].device != x.device:
q_fp8, meta = self.quantize(weight)
q_fp8 = q_fp8.to(device=x.device)
scales = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1)
q_kn = q_fp8
self._weight_cache[wid] = (q_fp8, scales)
else:
# Quantize weight and cache it
quantized_weight, scales = self.quantize_weight_for_kernel(weight, device=x.device)
# Cache the quantized weight and scales
self._weight_cache[weight_id] = (quantized_weight, scales)

# Speed-first option: cache dequantized bf16 weight for F.linear (cuBLAS)
# This trades extra GPU memory for throughput.
import os
if os.getenv("DIFFULEX_FP8_W8A16_PREFER_CUBLAS", "0") == "1":
deq_key = id(weight) if weight.dtype != torch.uint8 else id(quantized_weight)
deq_w = self._dequant_weight_cache.get(deq_key)
if deq_w is None or deq_w.device != x.device:
# Dequantize: FP8[N,K] * scales[N] -> bf16[N,K]
deq_w = self.dequantize(quantized_weight, scales)
self._dequant_weight_cache[deq_key] = deq_w
return F.linear(x, deq_w, bias)

# Try to use TileLang kernel if available
fp8_w8a16_gemm = None
if self.weight_dtype_str == "fp8_e4m3":
fp8_w8a16_gemm = _fp8_e4m3_w8a16_gemm
elif self.weight_dtype_str == "fp8_e5m2":
fp8_w8a16_gemm = _fp8_e5m2_w8a16_gemm

if _TILELANG_AVAILABLE and fp8_w8a16_gemm is not None:
try:
# Check device
if x.device.type != 'cuda':
return self._fallback_python_forward(x, quantized_weight, scales, bias)

# Get shapes
M, K = x.shape
N, K_w = quantized_weight.shape
assert K == K_w, f"K dimension mismatch: {K} != {K_w}"

# Bucket M to reduce compilation churn
M_bucket = M
if M > 1:
if M <= 64:
M_bucket = 1 << (M - 1).bit_length()
else:
M_bucket = ((M + 63) // 64) * 64

x_for_kernel = x
if M_bucket != M:
x_pad = torch.zeros((M_bucket, K), device=x.device, dtype=x.dtype)
x_pad[:M, :] = x
x_for_kernel = x_pad

# TileLang autotune: use warmup + config cache pattern
cache_key = (str(x.device), M_bucket, N, K)
config = self._tl_autotune_config_cache.get(cache_key)

if _AUTOTUNE_AVAILABLE and is_warming_up() and config is None:
# Warmup phase: run autotune with real inputs
try:
assert self.spec.fp8_view_dtype is not None
qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype)
with set_autotune_inputs([x_for_kernel, qweight_fp8, scales]):
kernel = fp8_w8a16_gemm(M_bucket, N, K)
config = kernel.config
self._tl_autotune_config_cache[cache_key] = config
except Exception:
# Fallback to default config if autotune fails
config = None

# Use cached config or default parameters
assert self.spec.fp8_view_dtype is not None
qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype)
if config is not None:
kernel = fp8_w8a16_gemm(M_bucket, N, K, **config)
else:
# Default config (backward compatible)
kernel = fp8_w8a16_gemm(M_bucket, N, K, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)

# Call kernel - out_idx=[3] means output is the 4th parameter
assert self.spec.fp8_view_dtype is not None
qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype)
output_full = kernel(x_for_kernel, qweight_fp8, scales)
output = output_full[:M, :] if M_bucket != M else output_full

# Add bias if present
if bias is not None:
output = output + bias

return output
except Exception as e:
# Fallback to Python implementation on any error
import warnings
error_msg = str(e)

# Extract meaningful error information
if 'sm_' in error_msg and ('not defined' in error_msg or 'fatal' in error_msg):
# CUDA architecture not supported - silently fallback
pass
elif 'Compilation error' in error_msg:
# Extract the actual error
idx = error_msg.find('Compilation error')
after = error_msg[idx + len('Compilation error'):]
lines = after.split('\n')
for line in lines:
line = line.strip()
if line and not line.startswith('#') and ('error:' in line.lower() or 'fatal' in line.lower()):
error_msg = f"CUDA compilation error: {line[:200]}"
break
else:
error_msg = "CUDA compilation error (see logs for details)"
elif 'pipeline' in error_msg.lower() and 'stage' in error_msg.lower():
# Pipeline stages mismatch - silently fallback
pass
else:
# Truncate very long error messages
if len(error_msg) > 200:
error_msg = error_msg[:200] + "..."

# Only warn for unexpected errors
if 'CUDA architecture not supported' not in error_msg and 'sm_' not in error_msg and 'Pipeline stages' not in error_msg:
warnings.warn(
f"TileLang kernel failed, falling back to Python implementation: {error_msg}",
UserWarning,
)
return self._fallback_python_forward(x, quantized_weight, scales, bias)
else:
# TileLang not available, use Python reference
return self._fallback_python_forward(x, quantized_weight, scales, bias)

def _fallback_python_forward(
self,
x: torch.Tensor,
quantized_weight: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
"""Fallback Python implementation: dequantize + F.linear."""
# Dequantize for reference implementation
dequantized_weight = self.dequantize(quantized_weight, scales)

# Compute linear output
return F.linear(x, dequantized_weight, bias)

def clear_cache(self) -> None:
"""Clear the weight quantization cache.

Useful for memory management or when weights are updated (e.g., fine-tuning).
"""
self._weight_cache.clear()
self._dequant_weight_cache.clear()
q_kn, scales = cached

⚠️ Potential issue | 🟡 Minor

Missing null check for weight in the else branch.

When the condition on line 115 is false, the code enters the else branch which calls self.quantize(weight). If weight is None (and only quant_scales is None), this will fail. Consider adding an explicit check.

🐛 Suggested fix
         if weight is not None and quant_scales is not None:
             # Expected: weight is fp8 K×N tensor (transpose-view is fine).
             q_kn = weight.to(device=x.device)
             scales = quant_scales.to(device=x.device, dtype=torch.float32).reshape(1)
         else:
+            if weight is None:
+                raise ValueError("weight cannot be None when quant_scales is not provided")
             wid = id(weight)
             cached = self._weight_cache.get(wid)

Comment on lines +109 to +117
wid = id(weight)
cached = self._weight_cache.get(wid)
if cached is None or cached[0].device != x.device:
q_fp8, meta = self.quantize(weight)
q_fp8 = q_fp8.to(device=x.device)
w_scale = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1)
self._weight_cache[wid] = (q_fp8, w_scale)
else:
weight_id = id(weight)
cached = self._weight_cache.get(weight_id)
if cached is None:
qweight, w_scales = self.quantize_weight_for_kernel(weight, device=x.device)
self._weight_cache[weight_id] = (qweight, w_scales)
else:
qweight, w_scales = cached
if qweight.device != x.device:
qweight = qweight.to(device=x.device)
w_scales = w_scales.to(device=x.device)
self._weight_cache[weight_id] = (qweight, w_scales)

# Optional: use cuBLAS BF16 (dequant once)
import os
if os.getenv("DIFFULEX_FP8_W8A8_PREFER_CUBLAS", "0") == "1":
deq_key = weight_id
deq_w = self._dequant_weight_cache.get(deq_key)
if deq_w is None or deq_w.device != x.device:
deq_w = self.dequantize(qweight, w_scales)
self._dequant_weight_cache[deq_key] = deq_w
# Also dequantize activation
x_q_temp, x_scales_temp = self.quantize_act_for_kernel(x, device=x.device)
x_deq = self._dequantize_act(x_q_temp, x_scales_temp)
return F.linear(x_deq, deq_w, bias)

# Quantize activation per-row
if x.dtype not in (torch.bfloat16, torch.float16, torch.float32):
x = x.to(torch.bfloat16)
x_q, x_scales = self.quantize_act_for_kernel(x, device=x.device)

# Try to use TileLang kernel if available
# For W8A8, weight_dtype and act_dtype should match (both e4m3 or both e5m2)
fp8_w8a8_gemm = None
if self.weight_dtype_str == "fp8_e4m3" and self.act_dtype_str == "fp8_e4m3":
fp8_w8a8_gemm = _fp8_e4m3_w8a8_gemm
elif self.weight_dtype_str == "fp8_e5m2" and self.act_dtype_str == "fp8_e5m2":
fp8_w8a8_gemm = _fp8_e5m2_w8a8_gemm

if _TILELANG_AVAILABLE and fp8_w8a8_gemm is not None:
try:
# Check device
if x.device.type != 'cuda':
return self._fallback_python_forward(x_q, x_scales, qweight, w_scales, bias)

# Get shapes
M, K = x_q.shape
N, K_w = qweight.shape
assert K == K_w, f"K dimension mismatch: {K} != {K_w}"

# Bucket M to reduce compilation churn
M_bucket = M
if M > 1:
if M <= 64:
M_bucket = 1 << (M - 1).bit_length()
else:
M_bucket = ((M + 63) // 64) * 64
q_fp8, w_scale = cached

⚠️ Potential issue | 🟡 Minor

Same null-check concern as in W8A16 strategy.

Similar to the W8A16 variant, this branch can be entered when weight is None (since the only guard is quant_kind being unused). The self.quantize(weight) call on line 112 will fail if weight is None.

🐛 Suggested fix
     def linear_forward(
         self,
         x: torch.Tensor,
         weight: torch.Tensor,
         bias: Optional[torch.Tensor],
         *,
         quant_kind: str,
         **kwargs: Any,
     ) -> torch.Tensor:
         _ = quant_kind
+        if weight is None:
+            raise ValueError("weight cannot be None for FP8 W8A8 forward pass")
         wid = id(weight)
         cached = self._weight_cache.get(wid)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    def linear_forward(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        *,
        quant_kind: str,
        **kwargs: Any,
    ) -> torch.Tensor:
        _ = quant_kind
        if weight is None:
            raise ValueError("weight cannot be None for FP8 W8A8 forward pass")
        wid = id(weight)
        cached = self._weight_cache.get(wid)
        if cached is None or cached[0].device != x.device:
            q_fp8, meta = self.quantize(weight)
            q_fp8 = q_fp8.to(device=x.device)
            w_scale = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1)
            self._weight_cache[wid] = (q_fp8, w_scale)
        else:
            q_fp8, w_scale = cached
🤖 Prompt for AI Agents
In `@diffulex/utils/quantization/strategies/linear_fp8_w8a8.py` around lines 109 -
117, The code calls self.quantize(weight) without guarding against weight being
None, which will raise if weight is absent; add the same null-check used in the
W8A16 variant before computing wid/using the cache (or directly before calling
self.quantize) so that when weight is None you skip quantization and set
q_fp8/w_scale to appropriate defaults or handle the None-case consistently;
update the branch around wid, self._weight_cache, and the call to
self.quantize(weight) to first check "if weight is None" and return or assign
safe placeholders used by the rest of the method.

Comment on lines +181 to +202
        # Directly call the underlying CUDA op to minimize Python glue.
        out = torch.ops._C.gptq_marlin_gemm(
            reshaped_x,
            None,
            qweight,
            marlin_bias,
            scales,
            None,
            None,
            zp,
            g_idx_t,
            g_idx_sort_t,
            workspace,
            wtype.id,
            m,
            n,
            k,
            is_k_full,
            use_atomic_add,
            True,   # use_fp32_reduce
            False,  # is_zp_float
        )

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

vLLM gptq_marlin_gemm function signature arguments

💡 Result:

In vLLM, the custom op schema for gptq_marlin_gemm (as registered with PyTorch) is:

gptq_marlin_gemm(
    Tensor a,
    Tensor? c_or_none,
    Tensor b_q_weight,
    Tensor b_scales,
    Tensor? global_scale,
    Tensor? b_zeros_or_none,
    Tensor? g_idx_or_none,
    Tensor? perm_or_none,
    Tensor workspace,
    int b_q_type,
    SymInt size_m,
    SymInt size_n,
    SymInt size_k,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float
) -> Tensor

[1]

A matching C++ implementation signature looks like:

torch::Tensor gptq_marlin_gemm(
  torch::Tensor& a,
  std::optional<torch::Tensor> c_or_none,
  torch::Tensor& b_q_weight,
  torch::Tensor& b_scales,
  std::optional<torch::Tensor> const& global_scale_or_none,
  std::optional<torch::Tensor> const& b_zeros_or_none,
  std::optional<torch::Tensor> const& g_idx_or_none,
  std::optional<torch::Tensor> const& perm_or_none,
  torch::Tensor& workspace,
  vllm::ScalarTypeId const& b_q_type_id,
  int64_t size_m, int64_t size_n, int64_t size_k,
  bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
  bool is_zp_float);

[2]

Argument order is exactly as shown above. a is the activation (MxK), b_q_weight/b_scales/(b_zeros) are the quantized weight + params, g_idx/perm are optional GPTQ grouping/permutation tensors, workspace is scratch memory, and size_m/size_n/size_k are the GEMM sizes. [1]

Sources:
[1] https://huggingface.co/kernels-community/quantization/blob/main/torch-ext/torch_binding.cpp
[2] https://huggingface.co/RedHatAI/quantization/blob/main/gptq_marlin/gptq_marlin.cu


🏁 Script executed:

# First, let's check if this is a local implementation or from vLLM
cd /tmp && find . -name "*.py" -type f 2>/dev/null | head -20



🏁 Script executed:

# Search for gptq_marlin_gemm in the codebase
rg "gptq_marlin_gemm" --type py -B 5 -A 10



🏁 Script executed:

# Check the file structure and imports
head -50 diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py



🏁 Script executed:

# Look for any vLLM package or imports
rg "vllm|gptq_marlin" --type py | head -30



Critical: argument order mismatch with vLLM's gptq_marlin_gemm signature

The call passes 19 arguments, but vLLM's gptq_marlin_gemm expects only 17. The function signature is:

gptq_marlin_gemm(
    a, c_or_none, b_q_weight, b_scales, global_scale, b_zeros_or_none,
    g_idx_or_none, perm_or_none, workspace, b_q_type, size_m, size_n,
    size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float
)

The code's arguments diverge from arg 3 onward: marlin_bias is passed where b_scales is expected, scales where global_scale is expected, and subsequent arguments are offset, with workspace, sizes, and flags in incorrect positions. This will cause type mismatches or silent data corruption.

🤖 Prompt for AI Agents
In `@diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py` around
lines 181 - 202, The torch.ops._C.gptq_marlin_gemm call has its arguments
shifted and doesn't match vLLM's 17-arg signature; reorder the call so arguments
follow gptq_marlin_gemm(a, c_or_none, b_q_weight, b_scales, global_scale,
b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, b_q_type, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float): pass
reshaped_x, None, qweight, scales, marlin_bias (as global_scale), zp (as
b_zeros_or_none), g_idx_t, g_idx_sort_t, workspace, wtype.id, m, n, k,
is_k_full, use_atomic_add, True (use_fp32_reduce), False (is_zp_float) so all
parameters (reshaped_x, qweight, scales, marlin_bias, zp, g_idx_t, g_idx_sort_t,
workspace, wtype.id, m, n, k, is_k_full, use_atomic_add) are in the correct
positions for torch.ops._C.gptq_marlin_gemm.
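
For reference, if the installed binding really does expose the 17-argument schema quoted in the web-query result above, the reordered call would look roughly like the sketch below. The tensor names (reshaped_x, qweight, scales, zp, g_idx_t, g_idx_sort_t, workspace, wtype.id, m, n, k, is_k_full, use_atomic_add, marlin_bias) come from the snippet under review; the wrapper function and the bias handling after the GEMM are assumptions, added because that schema has no bias slot (the agent prompt above places marlin_bias in the global_scale position, which a bias is not). Some newer vLLM builds also accept an optional bias tensor directly after b_q_weight, so the installed binding's schema should be confirmed before adopting either ordering.

import torch

def gptq_marlin_gemm_17(reshaped_x, qweight, scales, zp, g_idx_t, g_idx_sort_t,
                        workspace, wtype_id, m, n, k, is_k_full, use_atomic_add,
                        marlin_bias=None):
    """Sketch of the 17-argument ordering; verify against the installed vLLM binding."""
    out = torch.ops._C.gptq_marlin_gemm(
        reshaped_x,    # a
        None,          # c_or_none
        qweight,       # b_q_weight
        scales,        # b_scales
        None,          # global_scale
        zp,            # b_zeros_or_none
        g_idx_t,       # g_idx_or_none
        g_idx_sort_t,  # perm_or_none
        workspace,
        wtype_id,      # b_q_type
        m,
        n,
        k,
        is_k_full,
        use_atomic_add,
        True,          # use_fp32_reduce
        False,         # is_zp_float
    )
    # No bias slot in this schema, so add the bias after the GEMM (assumption).
    return out if marlin_bias is None else out + marlin_bias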

Comment on lines +130 to +131
            if empty is None or empty.device != device:
                empty = torch.empty((0,), device=device, dtype=torch.int)

⚠️ Potential issue | 🟡 Minor

Redundant device check in the empty-tensor cache.

The cache is already keyed by dev_key, so the empty.device != device comparison is redundant: if the device index matches, the device should match as well. The one caveat is CPU tensors, whose device.index is None, so an index-based key has to account for that case (see the keying sketch after the agent prompt below).

🔧 Simplify the cache check
         if g_idx is None or g_idx.numel() == 0:
             empty = self._empty_cache.get(dev_key)
-            if empty is None or empty.device != device:
+            if empty is None:
                 empty = torch.empty((0,), device=device, dtype=torch.int)
                 self._empty_cache[dev_key] = empty
             g_idx_t = empty
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-            if empty is None or empty.device != device:
-                empty = torch.empty((0,), device=device, dtype=torch.int)
+            if empty is None:
+                empty = torch.empty((0,), device=device, dtype=torch.int)
🤖 Prompt for AI Agents
In `@diffulex/utils/quantization/strategies/linear_gptq_w4a16.py` around lines 130
- 131, Remove the redundant device comparison when retrieving cached empty
tensors: rely on the cache key (dev_key) to ensure device correctness and only
check "empty is None" before creating a new tensor. Concretely, in the block
that currently reads "if empty is None or empty.device != device:", change it to
"if empty is None" and keep the creation line "empty = torch.empty((0,),
device=device, dtype=torch.int)" so CPU tensors (device.index is None) are
handled consistently.
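
To make the CPU caveat concrete, here is a minimal, self-contained sketch of a device-keyed empty-tensor cache. The _empty_cache name mirrors the snippet, but the dev_key construction shown here (a (device type, device index) tuple with a missing index mapped to -1) is an assumption rather than what the file actually does; with a key like this, a cache hit already implies the right device, which is why the extra empty.device != device comparison adds nothing.

from typing import Dict, Tuple

import torch

_empty_cache: Dict[Tuple[str, int], torch.Tensor] = {}

def get_empty_int_tensor(device: torch.device) -> torch.Tensor:
    # Key on (device type, device index); torch.device("cpu") and a bare
    # torch.device("cuda") both report index None, so normalize that to -1.
    dev_key = (device.type, -1 if device.index is None else device.index)
    empty = _empty_cache.get(dev_key)
    if empty is None:
        empty = torch.empty((0,), device=device, dtype=torch.int)
        _empty_cache[dev_key] = empty
    return empty

If the real dev_key were just device.index, torch.device("cpu") and a bare torch.device("cuda") would both key on None and could hand back a tensor on the wrong device, which is the kind of mismatch the comment is pointing at.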
