Reproduce
import transformers
import transformers_cfg
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    print("transformers version:", transformers.__version__)
    print("transformers_cfg module:", transformers_cfg)

    # Load model and tokenizer
    llama_tokenizer = AutoTokenizer.from_pretrained("saibo/llama-1B")
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_model = AutoModelForCausalLM.from_pretrained("saibo/llama-1B")

    # Load the JSON grammar
    with open("examples/grammars/json.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", llama_tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate for a batch of two prompts (padding is required for batching)
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    input_ids = llama_tokenizer(
        [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
    )["input_ids"]
    output = llama_model.generate(
        input_ids,
        do_sample=False,
        max_length=50,
        num_beams=1,
        logits_processor=[grammar_processor],
        repetition_penalty=1.0,
        num_return_sequences=1,
    )

    # Decode output
    generations = llama_tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

Context
saibo/llama-1B is a randomly initialized model used for debugging. Even though it is not a trained LLM, the grammar constraint should still force it to produce some JSON-like structure, but it fails to do so.
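As a comparison, a single-prompt run without padding can help isolate whether batched padding is what breaks the constraint. The following is a minimal sketch reusing the model, tokenizer, and grammar string from the script above; building a fresh logits processor is a precaution, since the incremental constraint may carry parsing state across generate calls.

    single_grammar = IncrementalGrammarConstraint(grammar_str, "root", llama_tokenizer)
    single_processor = GrammarConstrainedLogitsProcessor(single_grammar)

    # Single prompt, so no padding is involved; if this run produces JSON-like
    # output while the batched run above does not, the problem likely lies in
    # how padded positions interact with the grammar constraint.
    single_ids = llama_tokenizer(prefix1, add_special_tokens=False, return_tensors="pt")["input_ids"]
    single_output = llama_model.generate(
        single_ids,
        do_sample=False,
        max_length=50,
        num_beams=1,
        logits_processor=[single_processor],
    )
    print(llama_tokenizer.batch_decode(single_output, skip_special_tokens=True))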