-
Notifications
You must be signed in to change notification settings - Fork 22
Open
Description
This bug is very weird....
I cannot reproduce it on a Linux machine — only on my local MacBook....
Basically, with a batch size of 1, I get this weird error
input_ids = tokenizer(
[prefix1], add_special_tokens=False, return_tensors="pt", padding=True
)["input_ids"]  # crashes with:
[1] 2415 bus error python examples/generate_json.py
But with batch size = 2, it disappears...
input_ids = tokenizer(
[prefix1, prefix1], add_special_tokens=False, return_tensors="pt", padding=True
)["input_ids"]

Full reproduction script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
if __name__ == "__main__":
    model_id = "gpt2"

    # Load model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # GPT-2 ships without a pad token; reuse EOS so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Load the JSON grammar and wrap it in a constrained logits processor.
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate.
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"

    # FIX: keep the full tokenizer output instead of discarding everything
    # but input_ids. With padding=True the attention_mask must reach
    # generate(); otherwise the model infers the mask from pad_token_id,
    # which is ambiguous here because pad == eos.
    encoding = tokenizer(
        [prefix1], add_special_tokens=False, return_tensors="pt", padding=True
    )  # batch size 1: previously triggered "bus error" on macOS
    # encoding = tokenizer(
    #     [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
    # )  # batch size 2 works fine

    output = model.generate(
        encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        do_sample=False,
        max_new_tokens=60,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
    )

    # Decode the generated token ids back to text.
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
"""
'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
"""
Metadata
Metadata
Assignees
Labels
No labels