Describe the bug
When I run StreamingLLM (the a_shape pattern) in MInference with the Qwen3-0.6B model, the computed per-token losses are highly abnormal compared to other models, such as the Qwen2.5 series. The results and example code are shown below. Even in dense mode, the output is still abnormal.
Code
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference
import os, torch
import torch.nn.functional as F
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
prompt = """
"""
model_name = "/media/public/models/huggingface/Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# Patch the model with MInference's A-shape (StreamingLLM-style) attention: 1 initial/sink token, 255 local tokens
minference_patch = MInference("a_shape", model_name, attn_kwargs={"n_local": 255, "n_init": 1})
model = minference_patch(model)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
input_ids = inputs["input_ids"]
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits # [1, seq_len, vocab_size]
shifted_logits = logits[:, :-1, :] # [1, seq_len-1, vocab]
shifted_labels = input_ids[:, 1:] # [1, seq_len-1]
log_probs = F.log_softmax(shifted_logits, dim=-1) # log-probs over vocab
token_logprobs = log_probs.gather(2, shifted_labels.unsqueeze(-1)).squeeze(-1) # [1, seq_len-1]
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
print(f"Prompt: {prompt!r}\n")
total_logprob = 0.0
for i in range(len(token_logprobs[0])):
    prev_token = input_tokens[i]
    current_token = input_tokens[i + 1]
    logprob = token_logprobs[0, i].item()
    total_logprob += logprob
    print(f"Token {i + 1}: {current_token!r} | log P({current_token} | ... {prev_token}) = {logprob:.4f}")
print(total_logprob)
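One way to quantify the abnormality is to compute the same per-token log-probs on an unpatched copy of the model and compare the averages. The sketch below is a minimal example, assuming the snippet above has already run (so model_name, inputs, input_ids, and token_logprobs exist) and that a second, unpatched copy of Qwen3-0.6B fits on the GPU; mean_token_logprob is a hypothetical helper, not part of MInference.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

def mean_token_logprob(model, input_ids):
    # Average log P(token_i | tokens_<i) over the prompt, i.e. the negative per-token loss.
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    labels = input_ids[:, 1:]
    token_logprobs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
    return token_logprobs.mean().item()

# Unpatched baseline (no MInference patch applied), loaded as a second copy.
baseline = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

print("patched  mean log-prob:", token_logprobs.mean().item())
print("baseline mean log-prob:", mean_token_logprob(baseline, input_ids))

If the patched mean is far below the baseline (and far below what Qwen2.5 gives on the same prompt), that isolates the regression to the patched attention path rather than the loss computation itself.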