
ByT5Tokenizer support #124

@urroxyz

Description


An implementation of ByT5Tokenizer support:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Define a minimal base interface (optional)
class TokenizerMiddleMapping:
    def map(self, token_id: int):
        raise NotImplementedError("Subclasses must implement this method.")

# Custom byte mapping for ByT5 tokens; provides __len__ and handles tensor token ids.
class ByT5Mapping(TokenizerMiddleMapping):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_id_to_bytes = {}
        self._vocab_size = len(tokenizer.get_vocab())

    def map(self, token_id):
        # Ensure token_id is a Python integer.
        if isinstance(token_id, torch.Tensor):
            token_id = token_id.item()

        if token_id in self.token_id_to_bytes:
            return self.token_id_to_bytes[token_id]

        token_str = self.tokenizer.convert_ids_to_tokens(token_id)
        # Strip the SentencePiece-style word-boundary marker (▁) if present.
        if token_str.startswith("▁"):
            token_str = token_str[1:]

        if not token_str:
            # Fall back to a single NUL byte for tokens with no visible text.
            byte_repr = [0]
        else:
            byte_repr = list(token_str.encode("utf-8"))

        self.token_id_to_bytes[token_id] = byte_repr
        return byte_repr

    def __len__(self):
        return self._vocab_size

# Manually build a ByteTrie using the custom mapping.
def build_byte_trie(tokenizer, mapping):
    from transformers_cfg.tokenization.byte_trie import ByteTrie

    vocab = tokenizer.get_vocab()  # token -> token_id dict
    token_ids_to_ignore = set(tokenizer.all_special_ids)
    trie = ByteTrie()

    # Insert each token's byte representation into the trie.
    for token, token_id in vocab.items():
        if token_id in token_ids_to_ignore:
            continue
        byte_repr = mapping.map(token_id)
        trie.insert(byte_repr, token_id)
    trie.vocab_size = len(vocab)
    return trie

# Inputs
transcription = "Me gustan esas bandas."
translation = "I like those bands."
lang = "Spanish"
metalang = "English"
is_segmented = False

prompt = f"""Provide the glosses for the following transcription in {lang}.

Transcription in {lang}: {transcription}
Transcription segmented: {is_segmented}
Translation in {metalang}: {translation}

Glosses: 
"""

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model and tokenizer setup
model_id = "lecslab/glosslm"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Monkey-patch to add the 'vocab' attribute if missing.
if not hasattr(tokenizer, "vocab"):
    tokenizer.vocab = tokenizer.get_vocab()

model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)

# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Define grammar constraints: the output must pick one alternative from each
# of the four rules, separated by spaces, e.g. "me like those bands".
grammar_str = """
root   ::= one ws two ws three ws four
ws     ::= " "
one    ::= "me" | "to me" | "for me"
two    ::= "like" | "please"
three  ::= "these" | "those"
four   ::= "bands" | "gangs" | "groups"
"""

# Create our custom mapping and build a trie based on it.
mapping = ByT5Mapping(tokenizer)
trie = build_byte_trie(tokenizer, mapping)

# Apply grammar constraints by providing both our custom trie and mapping.
grammar = IncrementalGrammarConstraint(
    grammar_str, "root", tokenizer, trie=trie, homomorphism=mapping
)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Generate constrained output
output_ids = model.generate(
    **inputs,
    max_length=1000,
    logits_processor=[grammar_processor],
    repetition_penalty=1.1,
    num_return_sequences=1,
)

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(outputs[0])

Other SentencePiece-based tokenizers can be mapped similarly; a minimal sketch follows.
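As a rough illustration of that generalization (my own sketch, not part of transformers-cfg), a mapping for a SentencePiece tokenizer such as T5's could translate the ▁ word-boundary marker into a leading space byte instead of stripping it; whether that matches a given model's detokenization rules should be verified per model.

# Hypothetical mapping for SentencePiece tokenizers, reusing the
# TokenizerMiddleMapping interface defined above. Unlike the ByT5 version,
# the ▁ marker is rendered as a space (0x20) rather than dropped.
class SentencePieceMapping(TokenizerMiddleMapping):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_id_to_bytes = {}
        self._vocab_size = len(tokenizer.get_vocab())

    def map(self, token_id):
        # Ensure token_id is a Python integer.
        if isinstance(token_id, torch.Tensor):
            token_id = token_id.item()
        if token_id in self.token_id_to_bytes:
            return self.token_id_to_bytes[token_id]

        token_str = self.tokenizer.convert_ids_to_tokens(token_id)
        # SentencePiece uses ▁ to mark a word boundary; render it as a space,
        # and fall back to a single NUL byte for tokens with no visible text.
        byte_repr = list(token_str.replace("▁", " ").encode("utf-8")) or [0]

        self.token_id_to_bytes[token_id] = byte_repr
        return byte_repr

    def __len__(self):
        return self._vocab_size

The build_byte_trie helper above can then be reused unchanged with this mapping.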
