An implementation of ByT5Tokenizer support:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
# Define a minimal base interface (optional).
class TokenizerMiddleMapping:
    def map(self, token_id: int):
        raise NotImplementedError("Subclasses must implement this method.")
# Custom mapping for ByT5 tokens with __len__ and proper token_id conversion.
class ByT5Mapping(TokenizerMiddleMapping):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_id_to_bytes = {}  # cache: token_id -> list of byte values
        self._vocab_size = len(tokenizer.get_vocab())

    def map(self, token_id):
        # Ensure token_id is a Python integer.
        if isinstance(token_id, torch.Tensor):
            token_id = token_id.item()
        if token_id in self.token_id_to_bytes:
            return self.token_id_to_bytes[token_id]
        token_str = self.tokenizer.convert_ids_to_tokens(token_id)
        # Strip the SentencePiece word-boundary indicator (▁) if present.
        if token_str.startswith("▁"):
            token_str = token_str[1:]
        if not token_str:
            byte_repr = [0]
        else:
            byte_repr = list(token_str.encode("utf-8"))
        self.token_id_to_bytes[token_id] = byte_repr
        return byte_repr

    def __len__(self):
        return self._vocab_size
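# Quick sanity check (optional; hypothetical values, exact ids depend on the
# vocab): a single ASCII token should map to its UTF-8 byte values, e.g.
#   ByT5Mapping(tokenizer).map(tokenizer.convert_tokens_to_ids("a"))  # -> [97]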
# Manually build a ByteTrie using the custom mapping.
def build_byte_trie(tokenizer, mapping):
    from transformers_cfg.tokenization.byte_trie import ByteTrie

    vocab = tokenizer.get_vocab()  # token -> token_id dict
    token_ids_to_ignore = set(tokenizer.all_special_ids)
    trie = ByteTrie()
    # Insert each token's byte representation into the trie.
    for token, token_id in vocab.items():
        if token_id in token_ids_to_ignore:
            continue
        byte_repr = mapping.map(token_id)
        trie.insert(byte_repr, token_id)
    trie.vocab_size = len(vocab)
    return trie
# Inputs
transcription = "Me gustan esas bandas."
translation = "I like those bands."
lang = "Spanish"
metalang = "English"
is_segmented = False
prompt = f"""Provide the glosses for the following transcription in {lang}.
Transcription in {lang}: {transcription}
Transcription segmented: {is_segmented}
Translation in {metalang}: {translation}
Glosses:
"""
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Model and tokenizer setup
model_id = "lecslab/glosslm"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Monkey-patch to add the 'vocab' attribute if missing.
if not hasattr(tokenizer, "vocab"):
    tokenizer.vocab = tokenizer.get_vocab()
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Define grammar constraints
grammar_str = """
root ::= one ws two ws three ws four
ws ::= " "
one ::= "me" | "to me" | "for me"
two ::= "like" | "please"
three ::= "these" | "those"
four ::= "bands" | "gangs" | "groups"
"""
# Create our custom mapping and build a trie based on it.
mapping = ByT5Mapping(tokenizer)
trie = build_byte_trie(tokenizer, mapping)
# Apply grammar constraints by providing both our custom trie and mapping.
grammar = IncrementalGrammarConstraint(
    grammar_str, "root", tokenizer, trie=trie, homomorphism=mapping
)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
# Generate constrained output
outputs_ids = model.generate(
    **inputs,
    max_length=1000,
    logits_processor=[grammar_processor],
    repetition_penalty=1.1,
    num_return_sequences=1,
)
outputs = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)
print(outputs[0])
Other SentencePiece-based tokenizers can be mapped similarly.
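For reference, a minimal sketch of the analogous mapping for a standard SentencePiece tokenizer (e.g. T5/mT5). This is untested; the class name SentencePieceMapping is hypothetical and it assumes the same TokenizerMiddleMapping interface as above. The main difference is that ▁ encodes a leading space, so it is replaced with a real space rather than dropped:
class SentencePieceMapping(TokenizerMiddleMapping):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_id_to_bytes = {}
        self._vocab_size = len(tokenizer.get_vocab())

    def map(self, token_id):
        if isinstance(token_id, torch.Tensor):
            token_id = token_id.item()
        if token_id in self.token_id_to_bytes:
            return self.token_id_to_bytes[token_id]
        token_str = self.tokenizer.convert_ids_to_tokens(token_id)
        # SentencePiece uses ▁ as a word-boundary marker; map it to a real
        # space so the byte representation matches the decoded text.
        token_str = token_str.replace("▁", " ")
        byte_repr = list(token_str.encode("utf-8")) or [0]
        self.token_id_to_bytes[token_id] = byte_repr
        return byte_repr

    def __len__(self):
        return self._vocab_size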