Feat/kv cache fp8 support #26
luozixin2 merged 10 commits into SJTU-DENG-Lab:feat/kv-cache-fp8-support
Conversation
Main changes:
- Add GPTQ Marlin (W4A16) and AWQ Marlin (W4A16) quantization strategies
- Fix loader.py to correctly load gptq_marlin-format weights (supporting Marlin's repacked qweight and permuted scales)
- Update quantize_model.py to export the gptq_marlin format (symmetric quantization + Marlin repack/permute)
- Update linear.py:
  - Add an `_offline_quant_bits` buffer to store the quantization bit width
  - Add GPTQ runtime shuffle support (gptq_shuffle)
  - Add lazy repack support for GPTQ/AWQ Marlin (`_maybe_prepare_offline_gptq_marlin`/`_awq_marlin`)
  - Standardize on the vLLM format (int32 packed, fp16 scales)
- Simplify the individual strategy files and remove duplicated code
- Remove the old AllSpark Marlin implementation files
- Add multiple benchmark config files (GPTQ/AWQ Marlin, one per bit width)
benchmark_results is a locally generated benchmarking artifact and should not be committed to the repository. This commit removes it as a regular deletion and relies on the benchmark_results/ rule in .gitignore to prevent it from being committed again.
- Add quant-method=auto support: use auto-gptq / awq for real calibration-based quantization
- Add calibration data parameters: --calib-text-file, --calib-num-samples, --calib-seq-len, etc.
- Implement _export_autogptq_to_vllm_weights: export vLLM-format weights from an auto-gptq quantized model
- Implement _export_awq_to_vllm_weights: export vLLM-format weights from an awq quantized model
- Keep the old quant-method=simple implementation for backward compatibility
- Fix shape inference and TP sharding logic for gptq_marlin scales in loader.py
- Fix linear_gptq_marlin_w4a16.py by removing an unnecessary bf16->fp16 conversion
Main refactoring:
1. **diffulex/layer/linear.py** - significantly simplify the quantization logic (-197 lines):
   - Add `_forward_base()`: a unified forward dispatcher that replaces the duplicated quantization branching in subclasses (see the sketch after this list)
   - Add `_build_offline_forward_kwargs()`: unified construction of offline-quantization (GPTQ/AWQ) forward arguments
   - Add helper methods such as `_get_linear_strategy()`, `_offline_meta()`, `_infer_gptq_weight_bits()`
   - Fix the edge case in `LoRAMixin.merge_lora` where the base weight is None
   - Remove unused imports (marlin_zero_points, unpack_cols, marlin_make_empty_g_idx)
2. **diffulex/utils/loader.py** - improve performance and code structure:
   - Scan the safetensors files once to build a key_to_file index, avoiding repeated file I/O
   - Cache the result of `model.named_modules()` to avoid rebuilding the dict
   - Add `_find_offline_capable_module()`: unified module lookup logic
   - Add `_load_tensors_for_prefix()`: centralized tensor loading that only opens the necessary files
   - Replace print() with logger.warning()/logger.exception() to standardize logging
3. **diffulex/engine/model_runner.py** - eliminate duplicated loops:
   - Cache the list of attention modules once in `allocate_kv_cache`
   - Replace the repeated module traversal loops with `enumerate(attn_modules)`
4. **diffulex/utils/quantization/strategies/linear_int4_w4a16.py** - fix a missing implementation:
   - Add the `quantize_weight_for_kernel` method, fixing a W4A16 online-quantization runtime error
5. Delete the unused config file `gptq_marlin_w2_bf16kv_varlen.yml`
Testing: verified that W8A16 online quantization and GPTQ offline quantization work correctly.
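To make the dispatcher idea concrete, here is a minimal, self-contained sketch. The method names follow the commit message, but the bodies and attributes (`_strategy`, `_offline_meta_dict`, `quant_kind="linear"`) are illustrative assumptions, not the actual `diffulex/layer/linear.py` implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class LinearBase(nn.Module):
    """Simplified sketch of a unified forward dispatcher (not the real Diffulex class)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = None
        self._strategy = None           # set by the quantization registry in the real code
        self._offline_meta_dict = None  # packed GPTQ/AWQ tensors, if loaded from an offline checkpoint

    def _get_linear_strategy(self):
        return self._strategy

    def _offline_meta(self):
        return self._offline_meta_dict

    def _build_offline_forward_kwargs(self):
        # In the real code this collects qweight / qzeros / scales / g_idx, etc.
        return dict(self._offline_meta_dict or {})

    def _forward_base(self, x: torch.Tensor) -> torch.Tensor:
        strategy = self._get_linear_strategy()
        if strategy is None:
            # Plain bf16/fp16 path: fall back to a dense GEMM.
            return F.linear(x, self.weight, self.bias)
        if self._offline_meta() is not None:
            # Offline (GPTQ/AWQ) path: packed tensors instead of a dense weight.
            return strategy.linear_forward(x, None, self.bias, **self._build_offline_forward_kwargs())
        # Online quantization path (e.g. W8A16 / W8A8).
        return strategy.linear_forward(x, self.weight, self.bias, quant_kind="linear")

    forward = _forward_base


# Usage: with no strategy attached, the dispatcher simply does F.linear.
layer = LinearBase(16, 32)
y = layer(torch.randn(4, 16))
```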
- Change the final summary from the last step's instantaneous throughput to a true average (total tokens / total time); a minimal sketch of this computation follows below
- Add ms/step statistics to make performance analysis easier
- Fix the issue where only the last step's instantaneous value was shown instead of the average
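A minimal sketch of the averaging described above (hypothetical names; not the benchmark script's actual code):

```python
def summarize(step_token_counts: list[int], step_times_s: list[float]) -> dict:
    """Report true averages over the whole run instead of the last step's instantaneous values."""
    total_tokens = sum(step_token_counts)
    total_time = sum(step_times_s)
    return {
        # Average throughput: total tokens divided by total elapsed time.
        "avg_tokens_per_s": total_tokens / total_time if total_time > 0 else 0.0,
        # Average latency per step, reported in milliseconds.
        "avg_ms_per_step": 1000.0 * total_time / len(step_times_s) if step_times_s else 0.0,
    }
```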
- Quantized linear: drop kwargs/pop and redundant availability checks, cache out_features and the necessary intermediate tensors
- Call vLLM CUDA ops directly (W8A8/GPTQ/AWQ/Marlin, etc.) to reduce Python glue overhead
- Handle qweight/scales layout and contiguity at load time, avoiding repeated work in forward
- Remove profiler record annotations from linear.py to keep the code clean
- Add trace/profile helper analysis scripts and related tests
… strategies
- Remove all .item() calls in LinearBase hot paths (GPU->CPU sync breaks graph capture)
- Add Python-side meta cache (_offline_quant_*_py, _gptq_is_shuffled_py, etc.)
- Use in-place fill_() + Python mirrors for state updates (see the sketch below)
- Simplify linear quantization strategies for future CUDA Graph support
- Remove fast_path checks and redundant branching in linear_marlin_int8_w8a16
- Remove fast_path in linear_int8_w8a8 (unified vLLM path)
- Simplify linear_gptq_w4a16 (direct torch.ops._C.gptq_gemm call)
- Make linear_fp8_w8a16 use explicit quant_scales parameter
- Fix FP8 weight layout: do not force contiguous for transpose-view (KxN stride0==1)
- Remove profiler record_function wrappers (graph-friendly)
Net: -129 lines, cleaner codebase ready for CUDA Graph capture
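A small sketch of the fill_() + Python-mirror pattern referenced above; the attribute names echo the commit message, but the class itself is illustrative, not the LinearBase code.

```python
import torch


class OfflineQuantState:
    """Track a flag both as a CUDA tensor (graph-safe) and a Python bool (no device sync)."""

    def __init__(self, device: str = "cuda"):
        # GPU-side buffer whose storage stays stable across CUDA Graph capture/replay.
        self._gptq_is_shuffled = torch.zeros((), dtype=torch.bool, device=device)
        # Python mirror, read in hot paths instead of calling .item().
        self._gptq_is_shuffled_py = False

    def mark_shuffled(self) -> None:
        # In-place fill_() updates the captured buffer without reallocating it.
        self._gptq_is_shuffled.fill_(True)
        self._gptq_is_shuffled_py = True

    def is_shuffled(self) -> bool:
        # Reading the Python mirror avoids the GPU->CPU sync that breaks graph capture.
        return self._gptq_is_shuffled_py
```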
- Add per-layer ForwardPlan to pre-resolve bf16/quant/offline paths and reduce per-call Python branching (see the sketch below).
- Prefer direct torch.ops kernels (GPTQ/AWQ/Marlin) with static args for stable capture.
- Fix D2F static CUDA graph capture/replay metadata (token buckets + cu_seqlens) and add profiler flag.
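Roughly, pre-resolving a per-layer forward path can look like the sketch below; `ForwardPlan`, `offline_gemm`, and `quant_strategy` are simplified placeholders here, not the actual Diffulex APIs.

```python
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class ForwardPlan:
    """Resolved once per layer so the per-call path has no Python branching."""
    fn: Callable[[torch.Tensor], torch.Tensor]


def build_plan(layer) -> ForwardPlan:
    # Decide the path once, e.g. at load time or before CUDA Graph capture.
    if getattr(layer, "offline_qweight", None) is not None:
        # Offline GPTQ/AWQ/Marlin kernel path (hypothetical method name).
        return ForwardPlan(fn=lambda x: layer.offline_gemm(x))
    if getattr(layer, "quant_strategy", None) is not None:
        # Online quantization strategy path.
        return ForwardPlan(
            fn=lambda x: layer.quant_strategy.linear_forward(
                x, layer.weight, layer.bias, quant_kind="linear"
            )
        )
    # Plain dense path.
    return ForwardPlan(fn=lambda x: torch.nn.functional.linear(x, layer.weight, layer.bias))
```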
- Fix tensor shape mismatch bug in static+CUDA Graph decode mode (model_runner.py)
- Improve bucket selection logic for variable token counts
- Add safety fallback when runtime batch exceeds captured capacity
- Fix metadata buffer initialization and padding
- Add new static mode benchmark configs:
  - awq_bf16kv_static.yml
  - gptq_marlin_w4_bf16kv_static.yml
  - gptq_marlin_w8_bf16kv_static.yml
- Update quantization strategies and loader utilities
- Update benchmark configurations for consistency
- Remove the bench configs and quantization architecture docs added after v0.0.1
- Move W8A16/DP tuning knobs from environment variables into Config/strategy.configure
- Remove hardcoded local paths and default GPUs from the examples/scripts and fix syntax issues
Merged commit 600eb4c into SJTU-DENG-Lab:feat/kv-cache-fp8-support
Actionable comments posted: 9
🤖 Fix all issues with AI agents
In `@diffulex/engine/model_runner.py`:
- Around line 38-46: dist.init_process_group is calling init_method with an
unconditional access to config.device_ids[rank] which can raise if device_ids is
missing; guard that access by computing device_id first (use
config.device_ids[rank] only if getattr(config, "device_ids", None) is truthy,
otherwise compute from config.device_start + rank) and then pass device_id into
dist.init_process_group (and keep the existing assert using
torch.cuda.device_count()); update the call site referencing config.device_ids,
device_id, dist.init_process_group, and torch.cuda.device_count() accordingly.
In `@diffulex/utils/loader.py`:
- Around line 286-294: The fuzzy matching in the loop (using
offline_capable_modules, module_name and leaf) is too permissive and can map
wrong modules; replace the loose checks with stricter rules: keep exact match,
allow only full-dot-boundary suffix/prefix matches (name == module_name or name
== f"{parent}.{module_name}" style), and remove the single-segment fallback
(name.split(".")[-1] == leaf) or tighten it to require the last two path
segments to match (compare the last two segments of name and module_name) before
returning cand; update the matching logic inside the function that iterates
offline_capable_modules and add a short comment explaining the fallback
behavior.
In `@diffulex/utils/quantization/quantize_model.py`:
- Around line 104-119: The helper _load_calib_texts silently returns fewer
samples when the file has fewer lines than requested; change the logic in
_load_calib_texts so that when len(lines) < num_samples you emit a clear warning
(use the warnings module) indicating fewer samples are available and return all
lines, and only use the rng.sample path when len(lines) >= num_samples (i.e.,
change the current len(lines) <= num_samples branch to handle the < case with a
warning and return, keeping rng.sample(lines, k=num_samples) for the sampling
case); reference the function name _load_calib_texts and the variables lines,
num_samples, seed when making the change.
In `@diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py`:
- Line 159: The AWQ strategy currently hardcodes is_k_full=True which can be
incorrect for row-parallel tensor-parallel cases; change the assignment of
is_k_full in the linear_awq_marlin_w4a16 strategy to compute it dynamically
(e.g., mirror the GPTQ Marlin logic) using the existing has_g_idx and
row_parallel variables so is_k_full = not has_g_idx or row_parallel (or
equivalent logic used in the GPTQ Marlin implementation); update the place where
the tuple/flag is constructed (the entry labeled is_k_full) so it derives from
has_g_idx and row_parallel rather than the literal True.
In `@diffulex/utils/quantization/strategies/linear_fp8_w8a16.py`:
- Around line 115-129: The else branch assumes weight is non-null but calls
id(weight) and self.quantize(weight) even when weight is None; add an explicit
check at the start of that branch (e.g., if weight is None: raise
ValueError(...) or handle a valid fallback) so you never call id() or
self.quantize with None; update the block that uses wid = id(weight), cached =
self._weight_cache.get(wid), self.quantize(weight), and subsequent assignments
to q_fp8, meta, scales, q_kn to only run when weight is verified non-None and
provide a clear error message referencing weight and quant_scales for easier
debugging.
- Around line 62-64: The get_storage_dtype function currently returns
torch.uint8 although the comment and vLLM use an FP8 float dtype; change the
return value in get_storage_dtype to use PyTorch's FP8 dtype
(torch.float8_e4m3fn) so it returns (torch.float8_e4m3fn, 1) instead of
(torch.uint8, 1), ensuring the storage dtype matches vLLM's FP8 implementation.
In `@diffulex/utils/quantization/strategies/linear_fp8_w8a8.py`:
- Around line 109-117: The code calls self.quantize(weight) without guarding
against weight being None, which will raise if weight is absent; add the same
null-check used in the W8A16 variant before computing wid/using the cache (or
directly before calling self.quantize) so that when weight is None you skip
quantization and set q_fp8/w_scale to appropriate defaults or handle the
None-case consistently; update the branch around wid, self._weight_cache, and
the call to self.quantize(weight) to first check "if weight is None" and return
or assign safe placeholders used by the rest of the method.
In `@diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py`:
- Around line 181-202: The torch.ops._C.gptq_marlin_gemm call has its arguments
shifted and doesn't match vLLM's 17-arg signature; reorder the call so arguments
follow gptq_marlin_gemm(a, c_or_none, b_q_weight, b_scales, global_scale,
b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, b_q_type, size_m,
size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float): pass
reshaped_x, None, qweight, scales, marlin_bias (as global_scale), zp (as
b_zeros_or_none), g_idx_t, g_idx_sort_t, workspace, wtype.id, m, n, k,
is_k_full, use_atomic_add, True (use_fp32_reduce), False (is_zp_float) so all
parameters (reshaped_x, qweight, scales, marlin_bias, zp, g_idx_t, g_idx_sort_t,
workspace, wtype.id, m, n, k, is_k_full, use_atomic_add) are in the correct
positions for torch.ops._C.gptq_marlin_gemm.
In `@diffulex/utils/quantization/strategies/linear_gptq_w4a16.py`:
- Around line 130-131: Remove the redundant device comparison when retrieving
cached empty tensors: rely on the cache key (dev_key) to ensure device
correctness and only check "empty is None" before creating a new tensor.
Concretely, in the block that currently reads "if empty is None or empty.device
!= device:", change it to "if empty is None" and keep the creation line "empty =
torch.empty((0,), device=device, dtype=torch.int)" so CPU tensors (device.index
is None) are handled consistently.
🧹 Nitpick comments (17)
diffulex_bench/configs/example.yml (1)
21-25: Consider aligning `max_num_batched_tokens` with the new `max_model_len`. With `max_model_len: 4096`, keeping `max_num_batched_tokens: 4096` effectively limits full-length batches to a single sequence. If the example aims to allow multi-sequence batching at 4k context, consider raising `max_num_batched_tokens` or adding a brief note to avoid confusion.
diffulex/utils/quantization/registry.py (1)
87-103: Update error message to include new quantization methods. The error message on lines 101-102 lists supported methods but doesn't include the newly added `gptq_marlin`, `gptq_marlin_24`, and `awq_marlin` aliases.
✏️ Suggested fix
```diff
 if s not in aliases:
     raise ValueError(
         f"Unsupported linear quant dtype={dtype!r}. "
-        "Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq/marlin"
+        "Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/gptq_marlin/gptq_marlin_24/awq/awq_marlin/marlin"
     )
```
diffulex_kernel/__init__.py (1)
24-45: Consider caching imported modules to avoid repeated imports. Each call to `__getattr__` re-executes the import statement. For frequently accessed attributes, consider caching the imported module in the module's namespace:
♻️ Suggested improvement
```diff
 def __getattr__(name: str):
     if name == "dllm_flash_attn_decode":
         from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode
-
+        globals()[name] = dllm_flash_attn_decode
         return dllm_flash_attn_decode
     if name == "dllm_flash_attn_prefill":
         from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_prefill
-
+        globals()[name] = dllm_flash_attn_prefill
         return dllm_flash_attn_prefill
-    if name == "store_kvcache_distinct_layout":
-        from diffulex_kernel.python.kv_cache_kernels import store_kvcache_distinct_layout
-
-        return store_kvcache_distinct_layout
-    if name == "store_kvcache_unified_layout":
-        from diffulex_kernel.python.kv_cache_kernels import store_kvcache_unified_layout
-
-        return store_kvcache_unified_layout
-    if name == "load_kvcache":
-        from diffulex_kernel.python.kv_cache_kernels import load_kvcache
-
-        return load_kvcache
+    # ... similar for other attributes
     raise AttributeError(name)
```
diffulex/utils/quantization/strategies/linear_marlin_int8_w8a16.py (2)
69-82: Silent exception swallowing may hide configuration bugs. The `try/except` blocks with bare `pass` silently ignore any errors when parsing config values. This could hide misconfigurations (e.g., a string value that can't be converted to int). Consider logging a warning or using a more specific exception type.
♻️ Suggested improvement
```diff
 def configure(self, *, diffulex_config: Any | None = None) -> None:
     # Prefer explicit config fields over environment-variable based tuning.
     if diffulex_config is None:
         return
     try:
         bn = int(getattr(diffulex_config, "linear_w8a16_quant_block_n", self._quant_block_n))
         self._quant_block_n = max(1, bn)
-    except Exception:
-        pass
+    except (TypeError, ValueError) as e:
+        import warnings
+        warnings.warn(f"Invalid linear_w8a16_quant_block_n config: {e}")
     try:
         thr = int(getattr(diffulex_config, "linear_w8a16_allspark_cublas_m_threshold", self._cublas_m_thr))
         self._cublas_m_thr = max(1, thr)
-    except Exception:
-        pass
+    except (TypeError, ValueError) as e:
+        import warnings
+        warnings.warn(f"Invalid linear_w8a16_allspark_cublas_m_threshold config: {e}")
```
304-305: Complex `n` calculation could use a clarifying comment. The logic for determining `n` from multiple sources (out_features, bias, scales) is dense. A brief comment explaining the priority and fallback chain would improve maintainability.
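As an illustration of the priority chain such a comment could document, something along these lines (the function and attribute names here are assumptions, not the strategy's actual variables):

```python
import torch


def infer_out_features(out_features, bias: torch.Tensor | None, scales: torch.Tensor | None) -> int:
    # Priority: explicit out_features -> bias length -> scales' N dimension.
    if out_features is not None:
        return int(out_features)
    if bias is not None:
        return int(bias.shape[0])
    if scales is not None:
        return int(scales.shape[-1])
    raise ValueError("Cannot infer n: no out_features, bias, or scales available.")
```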
diffulex/strategy/d2f/engine/model_runner.py (1)
355-361: Consider narrowing the exception handling for forward plan setup. The bare `except Exception` silently swallows all errors, which could mask legitimate issues during `LinearBase.enable_forward_plan()` setup (e.g., attribute errors, type errors, or device mismatches). Consider catching only expected exceptions like `ImportError` or `AttributeError`, or at minimum logging when an exception occurs:
♻️ Suggested improvement
```diff
 try:
     from diffulex.layer.linear import LinearBase
     for m in self.model.modules():
         if isinstance(m, LinearBase):
             m.enable_forward_plan(True)
-except Exception:
-    pass
+except (ImportError, AttributeError):
+    pass  # LinearBase not available or missing enable_forward_plan
```
diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py (2)
20-20: Unused import: `torch.nn.functional` is imported but never used. `F` is imported but not referenced anywhere in this file.
🔧 Suggested fix
```diff
 import torch
-import torch.nn.functional as F
```
127-143: Duplicated empty tensor creation logic can be extracted. The logic for creating/caching empty g_idx tensors is repeated for both `g_idx` and `g_idx_sort_indices`. Consider extracting a helper method.
♻️ Suggested refactor
```diff
+    def _get_empty_tensor(self, dev_key: int, device: torch.device) -> torch.Tensor:
+        empty = self._empty_cache.get(dev_key)
+        if empty is None:
+            empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
+            self._empty_cache[dev_key] = empty
+        return empty
+
     # Then in linear_forward:
-    if g_idx is None or g_idx.numel() == 0:
-        empty = self._empty_cache.get(dev_key)
-        if empty is None:
-            empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
-            self._empty_cache[dev_key] = empty
-        g_idx_t = empty
-    else:
-        g_idx_t = g_idx
-    if g_idx_sort_indices is None or g_idx_sort_indices.numel() == 0:
-        empty = self._empty_cache.get(dev_key)
-        if empty is None:
-            empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32)
-            self._empty_cache[dev_key] = empty
-        g_idx_sort_t = empty
-    else:
-        g_idx_sort_t = g_idx_sort_indices
+    empty = self._get_empty_tensor(dev_key, device)
+    g_idx_t = g_idx if (g_idx is not None and g_idx.numel() > 0) else empty
+    g_idx_sort_t = g_idx_sort_indices if (g_idx_sort_indices is not None and g_idx_sort_indices.numel() > 0) else empty
```
diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py (1)
22-22: Unused import: `apply_awq_marlin_linear` is imported but never used. The import fallback assigns it to `None`, but the actual function is never called in the code.
🔧 Suggested fix
```diff
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (  # type: ignore
-        apply_awq_marlin_linear,
         marlin_make_empty_g_idx,
         should_use_atomic_add_reduce,
         marlin_permute_bias,
     )
```
diffulex/utils/quantization/strategies/linear_gptq_w4a16.py (1)
111-115: Silent fallback to F.linear may mask configuration errors. When `qweight`, `qzeros`, or `scales` are missing, the code falls back to `F.linear(x, weight, bias)`. This could hide misconfigurations where offline weights were expected but not loaded. Consider logging a warning.
⚠️ Add warning for fallback path
```diff
 if qweight is None or qzeros is None or scales is None:
     # correctness fallback (should not happen for offline GPTQ weights)
     if weight is None:
         raise RuntimeError("GPTQ offline weights missing packed tensors and bf16 weight is not present.")
+    import warnings
+    warnings.warn(
+        f"GPTQ packed tensors missing; falling back to F.linear. "
+        "This may indicate a configuration issue.",
+        RuntimeWarning,
+        stacklevel=2,
+    )
     return F.linear(x, weight, bias)
```
diffulex/utils/loader.py (4)
24-29: Silent exception swallowing in `_read_quantize_config`. Catching all exceptions and returning an empty dict could hide I/O errors or JSON parse errors that indicate real problems (e.g., corrupted config file).
⚠️ Consider logging exceptions
```diff
 try:
     with open(cfg_path, "r") as f:
         data = json.load(f)
     return data if isinstance(data, dict) else {}
-except Exception:
+except Exception as e:
+    logger.debug(f"Failed to read quantize_config.json: {e}")
     return {}
```
71-79: `_infer_module_device` returns CPU as default, which may cause device mismatch issues. If a module has no parameters or buffers (unlikely but possible for some wrapper modules), the function returns `torch.device("cpu")`, which could cause silent device mismatches when setting weights.
💡 Consider raising or warning for empty modules
```diff
 def _infer_module_device(module: nn.Module) -> torch.device:
     w = getattr(module, "weight", None)
     if isinstance(w, torch.Tensor):
         return w.device
     for p in module.parameters(recurse=False):
         return p.device
     for b in module.buffers(recurse=False):
         return b.device
+    # Module has no parameters/buffers; caller should handle this case
     return torch.device("cpu")
```
132-135: Complex `g_idx` handling with nested conditionals. The logic for handling `g_idx` with `numel()` checks is duplicated and hard to follow. The `getattr(g_idx, "numel", lambda: 1)()` pattern is unusual.
♻️ Simplify g_idx handling
```diff
-    if g_idx is None:
-        module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
-    else:
-        if getattr(g_idx, "numel", lambda: 1)() == 0:
-            module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
-        else:
-            module.gptq_g_idx = g_idx.to(dtype=torch.int32)
+    if g_idx is None or g_idx.numel() == 0:
+        module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device)
+    else:
+        module.gptq_g_idx = g_idx.to(dtype=torch.int32)
```
382-399: Complex dimension inference logic with multiple fallback paths. The GPTQ dimension inference has many branches for empty qzeros, Marlin checkpoints, and standard formats. Consider extracting this into a helper function for clarity and testability.
diffulex/utils/quantization/quantize_model.py (3)
272-273: AWQ only supports 4-bit - consider future-proofing. The current restriction to 4-bit for AWQ is documented, but the error message could be more informative about when other bit widths might be supported.
357-357: Empty perm tensor created on CUDA unconditionally. Line 357 creates `empty_perm` on the `"cuda"` device, but if `device="cpu"` was specified earlier (which would fail at lines 314-315), this could cause issues. The earlier check should catch this, but the hardcoded device is inconsistent.
🔧 Use consistent device handling
```diff
-    empty_perm = torch.empty((0,), dtype=torch.int32, device="cuda")
+    target_device = "cuda"  # gptq_marlin requires CUDA
+    empty_perm = torch.empty((0,), dtype=torch.int32, device=target_device)
```
652-670: Copying all safetensors files then adding quantized weights could cause confusion. The workflow copies original safetensors files and then adds a new quantized file. Users might not realize the original weights are still present. Consider documenting this or optionally removing original weight files.
📝 Add documentation comment
```diff
 # Copy all model files (config, tokenizer, etc.) to output directory
 import shutil
 print(f"\nCopying model files to {output_path}...")
 model_path_obj = Path(model_path)
-# First, copy original safetensors files (for non-quantized layers like lm_head, embeddings, etc.)
+# First, copy original safetensors files (for non-quantized layers like lm_head, embeddings, etc.)
+# NOTE: Both original and quantized weight files will exist in output directory.
+# The loader prioritizes quantized weights (*.qweight) over original weights.
 print("  Copying original safetensors files (for non-quantized layers)...")
```
```python
dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=config.device_ids[rank])
device_id = (getattr(config, "device_start", 0) or 0) + rank + config.device_ids[rank]
# Choose CUDA device for this TP rank.
# config.device_ids is already a list of logical CUDA device indices (respecting CUDA_VISIBLE_DEVICES).
# Do NOT add rank again, otherwise rank 1 with device_ids=[0,1] becomes device 2.
if getattr(config, "device_ids", None):
    device_id = config.device_ids[rank]
else:
    device_id = (getattr(config, "device_start", 0) or 0) + rank
assert 0 <= device_id < torch.cuda.device_count(), f"Invalid device_id {device_id}."
```
Potential inconsistency in device_ids access.
Line 38 unconditionally accesses config.device_ids[rank], but lines 42-45 check if device_ids exists before using it. If config.device_ids is None or missing, line 38 will raise an error before the fallback logic is reached.
🐛 Suggested fix
```diff
     init_method = f"tcp://{config.master_addr}:{config.master_port}"
-    dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=config.device_ids[rank])
     # Choose CUDA device for this TP rank.
     # config.device_ids is already a list of logical CUDA device indices (respecting CUDA_VISIBLE_DEVICES).
     # Do NOT add rank again, otherwise rank 1 with device_ids=[0,1] becomes device 2.
     if getattr(config, "device_ids", None):
         device_id = config.device_ids[rank]
     else:
         device_id = (getattr(config, "device_start", 0) or 0) + rank
+    dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=device_id)
```
```python
for name, cand in offline_capable_modules.items():
    if (
        name == module_name
        or name.endswith("." + module_name)
        or module_name.endswith("." + name)
        or (name.split(".")[-1] == leaf)
    ):
        return cand
return None
```
Fuzzy module name matching could cause incorrect mappings.
The fallback logic using endswith and split(".")[-1] comparisons is permissive and could match unintended modules with similar names (e.g., layer.q_proj matching other_layer.q_proj).
⚠️ Consider more precise matching
The current logic:
```python
or name.endswith("." + module_name)
or module_name.endswith("." + name)
or (name.split(".")[-1] == leaf)
```
Could match modules incorrectly. Consider documenting this behavior or tightening the matching criteria.
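A sketch of stricter matching along the lines the reviewer suggests (a standalone helper under assumed inputs, not the loader's actual function):

```python
def find_offline_capable_module(offline_capable_modules: dict, module_name: str):
    """Match on exact names or full dot-boundary suffixes; no single-segment fallback."""
    # 1. Exact match.
    cand = offline_capable_modules.get(module_name)
    if cand is not None:
        return cand
    # 2. Dot-boundary suffix match only (e.g. "model.layers.0.q_proj" vs "layers.0.q_proj").
    for name, module in offline_capable_modules.items():
        if name.endswith("." + module_name) or module_name.endswith("." + name):
            # 3. Require the last two path segments to agree, so "layer.q_proj"
            #    cannot silently map to "other_layer.q_proj".
            if name.split(".")[-2:] == module_name.split(".")[-2:]:
                return module
    return None
```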
```python
def _load_calib_texts(
    calib_text_file: str, *, num_samples: int, seed: int
) -> list[str]:
    p = Path(calib_text_file)
    if not p.exists():
        raise FileNotFoundError(f"calib_text_file 不存在: {calib_text_file}")
    lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()]
    lines = [ln for ln in lines if ln]
    if not lines:
        raise ValueError(f"calib_text_file 为空: {calib_text_file}")
    if num_samples <= 0:
        raise ValueError(f"calib_num_samples 必须 > 0, got {num_samples}")
    if len(lines) <= num_samples:
        return lines[:num_samples]
    rng = random.Random(seed)
    return rng.sample(lines, k=num_samples)
```
Potential issues in _load_calib_texts.
- Lines 116-117: If `len(lines) <= num_samples`, it returns `lines[:num_samples]`, but if `len(lines) < num_samples`, this silently returns fewer samples than requested.
- The seed parameter allows reproducibility, which is good.
⚠️ Consider warning when fewer samples available
```diff
     if len(lines) <= num_samples:
+        if len(lines) < num_samples:
+            import warnings
+            warnings.warn(
+                f"Requested {num_samples} samples but only {len(lines)} available in {calib_text_file}",
+                UserWarning,
+            )
         return lines[:num_samples]
```
```python
m,
n,
k,
True,  # is_k_full
```
Hardcoded is_k_full=True may not always be correct.
Unlike the GPTQ Marlin strategy which computes is_k_full based on has_g_idx and row_parallel, the AWQ strategy hardcodes it to True. This could cause issues with row-parallel tensor-parallel configurations.
🔧 Consider computing is_k_full dynamically
```diff
+    # AWQ doesn't use g_idx; is_k_full depends on row-parallel setting.
+    row_parallel = bool(tp_dim == 1)
+    is_k_full = True  # AWQ has no g_idx, so is_k_full follows row_parallel logic
+    # Note: For row-parallel, vLLM's marlin_is_k_full(has_g_idx=False, is_row_parallel=True) returns True
+
     out = torch.ops._C.gptq_marlin_gemm(
         ...
-        True,  # is_k_full
+        is_k_full,
         ...
     )
```
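For comparison, the GPTQ-Marlin-style rule stated in the review comment can be written as a tiny helper; this is a sketch of that rule as described above, not vLLM's exact implementation:

```python
def compute_is_k_full(has_g_idx: bool, row_parallel: bool) -> bool:
    # Rule as stated in the review comment: derive the flag from has_g_idx and
    # row_parallel instead of hardcoding True.
    return (not has_g_idx) or row_parallel
```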
```python
def get_storage_dtype(self) -> tuple[torch.dtype, int]:
    # FP8 weights are stored as uint8 (1 byte per element)
    # vLLM stores fp8 weights as float8 dtype tensor
    return torch.uint8, 1
```
🧩 Analysis chain
🌐 Web query:
vLLM FP8 quantization dtype storage float8
💡 Result:
- What "FP8" means in vLLM: vLLM uses the two standard FP8 formats E4M3FN and E5M2 (different range/precision tradeoffs). vLLM's `fp8` generally refers to FP8 E4M3. [1][2]
- Online FP8 quantization (`quantization="fp8"`): vLLM dynamically quantizes an FP16/BF16 model so that Linear weights are quantized to FP8_E4M3 with a per-tensor scale (activations are dynamically scaled each forward pass). [1]
- Actual tensor dtype used for stored FP8 weights: when a checkpoint is FP8-serialized, vLLM allocates weight tensors with PyTorch's float8 dtype `torch.float8_e4m3fn` (i.e., real float8 storage, not int8). [3]
- FP8 KV-cache storage: `kv_cache_dtype` supports `"fp8"`/`"fp8_e4m3"` and `"fp8_e5m2"` (CUDA 11.8+; ROCm supports `"fp8"` = E4M3). With E4M3, vLLM notes you typically need an FP32 scaling factor alongside each quantized tensor. [2][4]

Sources: FP8 W8A8 docs [1], Engine args / --kv-cache-dtype [2], vLLM FP8 layer code showing torch.float8_e4m3fn weight dtype [3], Quantized KV-cache docs [4].
Return value should match vLLM's actual FP8 dtype storage.
The comment states "vLLM stores fp8 weights as float8 dtype tensor", but the method returns torch.uint8. vLLM actually stores FP8-quantized weights using PyTorch's torch.float8_e4m3fn dtype. Update the return value to torch.float8_e4m3fn to match both the comment and vLLM's implementation.
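For reference, a self-contained sketch of per-tensor FP8 E4M3 weight storage using PyTorch's native float8 dtype, consistent with the suggested fix (an illustration, not the strategy's actual code):

```python
import torch


def quantize_weight_fp8_e4m3(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor FP8 E4M3 quantization: returns (fp8 weight, fp32 scale)."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Scale so the largest magnitude maps to the FP8 representable maximum.
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    q = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale.to(torch.float32)


def dequantize_weight_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Upcast and rescale to recover an approximation of the original weight.
    return q.to(torch.float32) * scale
```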
🤖 Prompt for AI Agents
In `@diffulex/utils/quantization/strategies/linear_fp8_w8a16.py` around lines 62 -
64, The get_storage_dtype function currently returns torch.uint8 although the
comment and vLLM use an FP8 float dtype; change the return value in
get_storage_dtype to use PyTorch's FP8 dtype (torch.float8_e4m3fn) so it returns
(torch.float8_e4m3fn, 1) instead of (torch.uint8, 1), ensuring the storage dtype
matches vLLM's FP8 implementation.
| if weight is not None and quant_scales is not None: | ||
| # Expected: weight is fp8 K×N tensor (transpose-view is fine). | ||
| q_kn = weight.to(device=x.device) | ||
| scales = quant_scales.to(device=x.device, dtype=torch.float32).reshape(1) | ||
| else: | ||
| # Lazy cache: use weight tensor id as key (only for bf16/fp16/fp32 weights) | ||
| weight_id = id(weight) | ||
|
|
||
| # Check cache | ||
| if weight_id in self._weight_cache: | ||
| quantized_weight, scales = self._weight_cache[weight_id] | ||
| # Ensure cached tensors are on the correct device | ||
| if quantized_weight.device != x.device: | ||
| quantized_weight = quantized_weight.to(device=x.device) | ||
| scales = scales.to(device=x.device) | ||
| wid = id(weight) | ||
| cached = self._weight_cache.get(wid) | ||
| if cached is None or cached[0].device != x.device: | ||
| q_fp8, meta = self.quantize(weight) | ||
| q_fp8 = q_fp8.to(device=x.device) | ||
| scales = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1) | ||
| q_kn = q_fp8 | ||
| self._weight_cache[wid] = (q_fp8, scales) | ||
| else: | ||
| # Quantize weight and cache it | ||
| quantized_weight, scales = self.quantize_weight_for_kernel(weight, device=x.device) | ||
| # Cache the quantized weight and scales | ||
| self._weight_cache[weight_id] = (quantized_weight, scales) | ||
|
|
||
| # Speed-first option: cache dequantized bf16 weight for F.linear (cuBLAS) | ||
| # This trades extra GPU memory for throughput. | ||
| import os | ||
| if os.getenv("DIFFULEX_FP8_W8A16_PREFER_CUBLAS", "0") == "1": | ||
| deq_key = id(weight) if weight.dtype != torch.uint8 else id(quantized_weight) | ||
| deq_w = self._dequant_weight_cache.get(deq_key) | ||
| if deq_w is None or deq_w.device != x.device: | ||
| # Dequantize: FP8[N,K] * scales[N] -> bf16[N,K] | ||
| deq_w = self.dequantize(quantized_weight, scales) | ||
| self._dequant_weight_cache[deq_key] = deq_w | ||
| return F.linear(x, deq_w, bias) | ||
|
|
||
| # Try to use TileLang kernel if available | ||
| fp8_w8a16_gemm = None | ||
| if self.weight_dtype_str == "fp8_e4m3": | ||
| fp8_w8a16_gemm = _fp8_e4m3_w8a16_gemm | ||
| elif self.weight_dtype_str == "fp8_e5m2": | ||
| fp8_w8a16_gemm = _fp8_e5m2_w8a16_gemm | ||
|
|
||
| if _TILELANG_AVAILABLE and fp8_w8a16_gemm is not None: | ||
| try: | ||
| # Check device | ||
| if x.device.type != 'cuda': | ||
| return self._fallback_python_forward(x, quantized_weight, scales, bias) | ||
|
|
||
| # Get shapes | ||
| M, K = x.shape | ||
| N, K_w = quantized_weight.shape | ||
| assert K == K_w, f"K dimension mismatch: {K} != {K_w}" | ||
|
|
||
| # Bucket M to reduce compilation churn | ||
| M_bucket = M | ||
| if M > 1: | ||
| if M <= 64: | ||
| M_bucket = 1 << (M - 1).bit_length() | ||
| else: | ||
| M_bucket = ((M + 63) // 64) * 64 | ||
|
|
||
| x_for_kernel = x | ||
| if M_bucket != M: | ||
| x_pad = torch.zeros((M_bucket, K), device=x.device, dtype=x.dtype) | ||
| x_pad[:M, :] = x | ||
| x_for_kernel = x_pad | ||
|
|
||
| # TileLang autotune: use warmup + config cache pattern | ||
| cache_key = (str(x.device), M_bucket, N, K) | ||
| config = self._tl_autotune_config_cache.get(cache_key) | ||
|
|
||
| if _AUTOTUNE_AVAILABLE and is_warming_up() and config is None: | ||
| # Warmup phase: run autotune with real inputs | ||
| try: | ||
| assert self.spec.fp8_view_dtype is not None | ||
| qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype) | ||
| with set_autotune_inputs([x_for_kernel, qweight_fp8, scales]): | ||
| kernel = fp8_w8a16_gemm(M_bucket, N, K) | ||
| config = kernel.config | ||
| self._tl_autotune_config_cache[cache_key] = config | ||
| except Exception: | ||
| # Fallback to default config if autotune fails | ||
| config = None | ||
|
|
||
| # Use cached config or default parameters | ||
| assert self.spec.fp8_view_dtype is not None | ||
| qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype) | ||
| if config is not None: | ||
| kernel = fp8_w8a16_gemm(M_bucket, N, K, **config) | ||
| else: | ||
| # Default config (backward compatible) | ||
| kernel = fp8_w8a16_gemm(M_bucket, N, K, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128) | ||
|
|
||
| # Call kernel - out_idx=[3] means output is the 4th parameter | ||
| assert self.spec.fp8_view_dtype is not None | ||
| qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype) | ||
| output_full = kernel(x_for_kernel, qweight_fp8, scales) | ||
| output = output_full[:M, :] if M_bucket != M else output_full | ||
|
|
||
| # Add bias if present | ||
| if bias is not None: | ||
| output = output + bias | ||
|
|
||
| return output | ||
| except Exception as e: | ||
| # Fallback to Python implementation on any error | ||
| import warnings | ||
| error_msg = str(e) | ||
|
|
||
| # Extract meaningful error information | ||
| if 'sm_' in error_msg and ('not defined' in error_msg or 'fatal' in error_msg): | ||
| # CUDA architecture not supported - silently fallback | ||
| pass | ||
| elif 'Compilation error' in error_msg: | ||
| # Extract the actual error | ||
| idx = error_msg.find('Compilation error') | ||
| after = error_msg[idx + len('Compilation error'):] | ||
| lines = after.split('\n') | ||
| for line in lines: | ||
| line = line.strip() | ||
| if line and not line.startswith('#') and ('error:' in line.lower() or 'fatal' in line.lower()): | ||
| error_msg = f"CUDA compilation error: {line[:200]}" | ||
| break | ||
| else: | ||
| error_msg = "CUDA compilation error (see logs for details)" | ||
| elif 'pipeline' in error_msg.lower() and 'stage' in error_msg.lower(): | ||
| # Pipeline stages mismatch - silently fallback | ||
| pass | ||
| else: | ||
| # Truncate very long error messages | ||
| if len(error_msg) > 200: | ||
| error_msg = error_msg[:200] + "..." | ||
|
|
||
| # Only warn for unexpected errors | ||
| if 'CUDA architecture not supported' not in error_msg and 'sm_' not in error_msg and 'Pipeline stages' not in error_msg: | ||
| warnings.warn( | ||
| f"TileLang kernel failed, falling back to Python implementation: {error_msg}", | ||
| UserWarning, | ||
| ) | ||
| return self._fallback_python_forward(x, quantized_weight, scales, bias) | ||
| else: | ||
| # TileLang not available, use Python reference | ||
| return self._fallback_python_forward(x, quantized_weight, scales, bias) | ||
|
|
||
| def _fallback_python_forward( | ||
| self, | ||
| x: torch.Tensor, | ||
| quantized_weight: torch.Tensor, | ||
| scales: torch.Tensor, | ||
| bias: Optional[torch.Tensor], | ||
| ) -> torch.Tensor: | ||
| """Fallback Python implementation: dequantize + F.linear.""" | ||
| # Dequantize for reference implementation | ||
| dequantized_weight = self.dequantize(quantized_weight, scales) | ||
|
|
||
| # Compute linear output | ||
| return F.linear(x, dequantized_weight, bias) | ||
|
|
||
| def clear_cache(self) -> None: | ||
| """Clear the weight quantization cache. | ||
|
|
||
| Useful for memory management or when weights are updated (e.g., fine-tuning). | ||
| """ | ||
| self._weight_cache.clear() | ||
| self._dequant_weight_cache.clear() | ||
| q_kn, scales = cached |
Missing null check for weight in the else branch.
When the condition on line 115 is false, the code enters the else branch which calls self.quantize(weight). If weight is None (and only quant_scales is None), this will fail. Consider adding an explicit check.
🐛 Suggested fix
```diff
     if weight is not None and quant_scales is not None:
         # Expected: weight is fp8 K×N tensor (transpose-view is fine).
         q_kn = weight.to(device=x.device)
         scales = quant_scales.to(device=x.device, dtype=torch.float32).reshape(1)
     else:
+        if weight is None:
+            raise ValueError("weight cannot be None when quant_scales is not provided")
         wid = id(weight)
         cached = self._weight_cache.get(wid)
```
| wid = id(weight) | ||
| cached = self._weight_cache.get(wid) | ||
| if cached is None or cached[0].device != x.device: | ||
| q_fp8, meta = self.quantize(weight) | ||
| q_fp8 = q_fp8.to(device=x.device) | ||
| w_scale = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1) | ||
| self._weight_cache[wid] = (q_fp8, w_scale) | ||
| else: | ||
| weight_id = id(weight) | ||
| cached = self._weight_cache.get(weight_id) | ||
| if cached is None: | ||
| qweight, w_scales = self.quantize_weight_for_kernel(weight, device=x.device) | ||
| self._weight_cache[weight_id] = (qweight, w_scales) | ||
| else: | ||
| qweight, w_scales = cached | ||
| if qweight.device != x.device: | ||
| qweight = qweight.to(device=x.device) | ||
| w_scales = w_scales.to(device=x.device) | ||
| self._weight_cache[weight_id] = (qweight, w_scales) | ||
|
|
||
| # Optional: use cuBLAS BF16 (dequant once) | ||
| import os | ||
| if os.getenv("DIFFULEX_FP8_W8A8_PREFER_CUBLAS", "0") == "1": | ||
| deq_key = weight_id | ||
| deq_w = self._dequant_weight_cache.get(deq_key) | ||
| if deq_w is None or deq_w.device != x.device: | ||
| deq_w = self.dequantize(qweight, w_scales) | ||
| self._dequant_weight_cache[deq_key] = deq_w | ||
| # Also dequantize activation | ||
| x_q_temp, x_scales_temp = self.quantize_act_for_kernel(x, device=x.device) | ||
| x_deq = self._dequantize_act(x_q_temp, x_scales_temp) | ||
| return F.linear(x_deq, deq_w, bias) | ||
|
|
||
| # Quantize activation per-row | ||
| if x.dtype not in (torch.bfloat16, torch.float16, torch.float32): | ||
| x = x.to(torch.bfloat16) | ||
| x_q, x_scales = self.quantize_act_for_kernel(x, device=x.device) | ||
|
|
||
| # Try to use TileLang kernel if available | ||
| # For W8A8, weight_dtype and act_dtype should match (both e4m3 or both e5m2) | ||
| fp8_w8a8_gemm = None | ||
| if self.weight_dtype_str == "fp8_e4m3" and self.act_dtype_str == "fp8_e4m3": | ||
| fp8_w8a8_gemm = _fp8_e4m3_w8a8_gemm | ||
| elif self.weight_dtype_str == "fp8_e5m2" and self.act_dtype_str == "fp8_e5m2": | ||
| fp8_w8a8_gemm = _fp8_e5m2_w8a8_gemm | ||
|
|
||
| if _TILELANG_AVAILABLE and fp8_w8a8_gemm is not None: | ||
| try: | ||
| # Check device | ||
| if x.device.type != 'cuda': | ||
| return self._fallback_python_forward(x_q, x_scales, qweight, w_scales, bias) | ||
|
|
||
| # Get shapes | ||
| M, K = x_q.shape | ||
| N, K_w = qweight.shape | ||
| assert K == K_w, f"K dimension mismatch: {K} != {K_w}" | ||
|
|
||
| # Bucket M to reduce compilation churn | ||
| M_bucket = M | ||
| if M > 1: | ||
| if M <= 64: | ||
| M_bucket = 1 << (M - 1).bit_length() | ||
| else: | ||
| M_bucket = ((M + 63) // 64) * 64 | ||
| q_fp8, w_scale = cached |
Same null-check concern as in W8A16 strategy.
Similar to the W8A16 variant, this branch can be entered when weight is None (since the only guard is quant_kind being unused). The self.quantize(weight) call on line 112 will fail if weight is None.
🐛 Suggested fix
```diff
 def linear_forward(
     self,
     x: torch.Tensor,
     weight: torch.Tensor,
     bias: Optional[torch.Tensor],
     *,
     quant_kind: str,
     **kwargs: Any,
 ) -> torch.Tensor:
     _ = quant_kind
+    if weight is None:
+        raise ValueError("weight cannot be None for FP8 W8A8 forward pass")
     wid = id(weight)
     cached = self._weight_cache.get(wid)
```
🤖 Prompt for AI Agents
In `diffulex/utils/quantization/strategies/linear_fp8_w8a8.py` around lines 109 - 117, the code calls self.quantize(weight) without guarding against weight being None, so a missing weight only fails later inside quantize() with an unhelpful error. Add the same null check used in the W8A16 variant before computing wid and touching self._weight_cache (or directly before calling self.quantize): check "if weight is None" first and either raise or assign safe placeholders for q_fp8/w_scale so the rest of the method behaves consistently.
# Directly call the underlying CUDA op to minimize Python glue.
out = torch.ops._C.gptq_marlin_gemm(
    reshaped_x,
    None,
    qweight,
    marlin_bias,
    scales,
    None,
    None,
    zp,
    g_idx_t,
    g_idx_sort_t,
    workspace,
    wtype.id,
    m,
    n,
    k,
    is_k_full,
    use_atomic_add,
    True,  # use_fp32_reduce
    False,  # is_zp_float
)
🧩 Analysis chain
🌐 Web query:
vLLM gptq_marlin_gemm function signature arguments
💡 Result:
In vLLM, the custom op schema for gptq_marlin_gemm (as registered with PyTorch) is:
gptq_marlin_gemm(
Tensor a,
Tensor? c_or_none,
Tensor b_q_weight,
Tensor b_scales,
Tensor? global_scale,
Tensor? b_zeros_or_none,
Tensor? g_idx_or_none,
Tensor? perm_or_none,
Tensor workspace,
int b_q_type,
SymInt size_m,
SymInt size_n,
SymInt size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float
) -> Tensor
[1]
A matching C++ implementation signature looks like:
torch::Tensor gptq_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
    bool is_zp_float);
[2]
Argument order is exactly as shown above. a is the activation (MxK), b_q_weight/b_scales/(b_zeros) are the quantized weight + params, g_idx/perm are optional GPTQ grouping/permutation tensors, workspace is scratch memory, and size_m/size_n/size_k are the GEMM sizes. [1]
Sources:
[1] https://huggingface.co/kernels-community/quantization/blob/main/torch-ext/torch_binding.cpp
[2] https://huggingface.co/RedHatAI/quantization/blob/main/gptq_marlin/gptq_marlin.cu
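Beyond external references, the schema registered in the local environment can be inspected at runtime; a small sketch, assuming vLLM is installed and that importing vllm._custom_ops registers the torch.ops._C ops (verify against your build):

import torch
import vllm._custom_ops  # noqa: F401  # assumption: importing this registers torch.ops._C.*

# Print the schema PyTorch actually registered for the op in this environment,
# so the positional argument order can be confirmed against the installed build.
print(torch.ops._C.gptq_marlin_gemm.default._schema)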
🏁 Script executed:
# First, let's check if this is a local implementation or from vLLM
cd /tmp && find . -name "*.py" -type f 2>/dev/null | head -20

🏁 Script executed:
# Search for gptq_marlin_gemm in the codebase
rg "gptq_marlin_gemm" --type py -B 5 -A 10

🏁 Script executed:
# Check the file structure and imports
head -50 diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py

🏁 Script executed:
# Look for any vLLM package or imports
rg "vllm|gptq_marlin" --type py | head -30
Critical: argument order mismatch with vLLM's gptq_marlin_gemm signature
The call passes 19 arguments, but vLLM's gptq_marlin_gemm expects only 17. The function signature is:
gptq_marlin_gemm(
a, c_or_none, b_q_weight, b_scales, global_scale, b_zeros_or_none,
g_idx_or_none, perm_or_none, workspace, b_q_type, size_m, size_n,
size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float
)
The code's arguments diverge from arg 3 onward: marlin_bias is passed where b_scales is expected, scales where global_scale is expected, and subsequent arguments are offset, with workspace, sizes, and flags in incorrect positions. This will cause type mismatches or silent data corruption.
🤖 Prompt for AI Agents
In `diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py` around lines 181 - 202, the torch.ops._C.gptq_marlin_gemm call has its arguments shifted and doesn't match vLLM's 17-argument signature. Reorder the call so the arguments follow gptq_marlin_gemm(a, c_or_none, b_q_weight, b_scales, global_scale, b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, b_q_type, size_m, size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float): pass reshaped_x, None, qweight, scales, marlin_bias (as global_scale), zp (as b_zeros_or_none), g_idx_t, g_idx_sort_t, workspace, wtype.id, m, n, k, is_k_full, use_atomic_add, True (use_fp32_reduce), False (is_zp_float), so every tensor and flag lands in the position the op expects.
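For illustration, a sketch of the call laid out against the 17-argument schema quoted above, wrapped as a helper so each local variable's slot is explicit. Whether marlin_bias belongs in the global_scale slot (as the prompt suggests) or should be applied after the GEMM is an assumption to verify against the installed vLLM build; this sketch applies it afterwards.

import torch


def gptq_marlin_matmul(reshaped_x, qweight, scales, zp, g_idx_t, g_idx_sort_t,
                       workspace, wtype, m, n, k, is_k_full, use_atomic_add,
                       marlin_bias=None):
    # Positional layout follows the schema cited above:
    # (a, c_or_none, b_q_weight, b_scales, global_scale, b_zeros_or_none,
    #  g_idx_or_none, perm_or_none, workspace, b_q_type, size_m, size_n,
    #  size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float)
    out = torch.ops._C.gptq_marlin_gemm(
        reshaped_x,    # a: activations (M x K)
        None,          # c_or_none: no pre-allocated output
        qweight,       # b_q_weight: Marlin-repacked quantized weight
        scales,        # b_scales: per-group weight scales
        None,          # global_scale: assumed unused for plain GPTQ W4A16
        zp,            # b_zeros_or_none: zero points (often None for symmetric GPTQ)
        g_idx_t,       # g_idx_or_none: group index
        g_idx_sort_t,  # perm_or_none: activation permutation
        workspace,     # scratch buffer
        wtype.id,      # b_q_type: vLLM ScalarType id
        m, n, k,       # size_m, size_n, size_k
        is_k_full,
        use_atomic_add,
        True,          # use_fp32_reduce
        False,         # is_zp_float
    )
    if marlin_bias is not None:
        out = out + marlin_bias  # assumption: add bias outside the kernel
    return out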
if empty is None or empty.device != device:
    empty = torch.empty((0,), device=device, dtype=torch.int)
Redundant device check in empty tensor cache.
The cache already uses dev_key as the key, so checking empty.device != device is redundant: if the device key matches, the cached tensor's device should match as well. However, that assumption could break for CPU tensors, where device.index is None.
🔧 Simplify the cache check
  if g_idx is None or g_idx.numel() == 0:
      empty = self._empty_cache.get(dev_key)
-     if empty is None or empty.device != device:
+     if empty is None:
          empty = torch.empty((0,), device=device, dtype=torch.int)
          self._empty_cache[dev_key] = empty
      g_idx_t = empty

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if empty is None:
    empty = torch.empty((0,), device=device, dtype=torch.int)
🤖 Prompt for AI Agents
In `diffulex/utils/quantization/strategies/linear_gptq_w4a16.py` around lines 130 - 131, remove the redundant device comparison when retrieving cached empty tensors: rely on the cache key (dev_key) for device correctness and only check "empty is None" before creating a new tensor. Concretely, change the block that currently reads "if empty is None or empty.device != device:" to "if empty is None:" and keep the creation line "empty = torch.empty((0,), device=device, dtype=torch.int)" so CPU tensors (device.index is None) are handled consistently.
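A minimal sketch of the per-device caching this implies, keying on (device.type, device.index) so the simplified "empty is None" check stays safe for CPU as well; the class name and attribute names here are illustrative, not the repository's actual ones.

from typing import Dict, Optional, Tuple

import torch


class EmptyGidxCache:
    def __init__(self) -> None:
        # Key by (device type, device index) so CPU (index None) and each CUDA
        # device get distinct entries; then "empty is None" alone is sufficient.
        self._empty_cache: Dict[Tuple[str, Optional[int]], torch.Tensor] = {}

    def get_empty(self, device: torch.device) -> torch.Tensor:
        dev_key = (device.type, device.index)
        empty = self._empty_cache.get(dev_key)
        if empty is None:
            empty = torch.empty((0,), device=device, dtype=torch.int)
            self._empty_cache[dev_key] = empty
        return empty


# Usage when g_idx is absent or empty:
#   g_idx_t = cache.get_empty(x.device)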
Summary by CodeRabbit
New Features
Improvements
Chores