diff --git a/README.md b/README.md
index 22e373d..c981732 100644
--- a/README.md
+++ b/README.md
@@ -5,24 +5,32 @@
## 💠Release news
-### Latest release
+### Latest experimental
+
+- LlamaCPP Python wrapper support ([#116](https://github.com/epfl-dlab/transformers-CFG/pull/116))
+
+### Latest stable
#### **[v0.2.7 Latest](https://github.com/epfl-dlab/transformers-CFG/releases/tag/v0.2.7)** (2025-03-02)
#### **Features**
-- **(CLI)** Types and MLX support (#93)
-- **(Regex)** Negation, wildcard, and repetition bracket operators (#94, #95, #96, #104)
-- **(Models)** Qwen2 and Qwen2.5 (#97)
-- **(Logits)** Resuable `GrammarConstrainedLogitsProcessor` across generations for efficiency (#100)
-- **(Backend)** Pytest for testing (#109)
-- **(CI/CD)** GitHub Actions workflow for automation (#110)
+- Types and MLX support ([#93](https://github.com/epfl-dlab/transformers-CFG/pull/93))
+- Negation, wildcard, and repetition bracket operators ([#94](https://github.com/epfl-dlab/transformers-CFG/pull/94), [#95](https://github.com/epfl-dlab/transformers-CFG/pull/95), [#96](https://github.com/epfl-dlab/transformers-CFG/pull/96), [#104](https://github.com/epfl-dlab/transformers-CFG/pull/104))
+- Qwen2 and Qwen2.5 ([#97](https://github.com/epfl-dlab/transformers-CFG/pull/97))
+- Reusable `GrammarConstrainedLogitsProcessor` across generations for efficiency ([#100](https://github.com/epfl-dlab/transformers-CFG/pull/100))
+- Pytest for testing ([#109](https://github.com/epfl-dlab/transformers-CFG/pull/109))
+- GitHub Actions workflow for automation ([#110](https://github.com/epfl-dlab/transformers-CFG/pull/110))
#### **Bug fixes**
-- Avoid computing full masks and optimized type additions (#101)
-- Refactored grammar encoding to improve structure (#99)
-- EOS token now correctly masks (#108)
-- Multiple bugs removed and aesthetics improved (#107)
+- Avoided computing full masks and optimized type additions ([#101](https://github.com/epfl-dlab/transformers-CFG/pull/101))
+- Refactored grammar encoding to improve structure ([#99](https://github.com/epfl-dlab/transformers-CFG/pull/99))
+- EOS token is now masked correctly ([#108](https://github.com/epfl-dlab/transformers-CFG/pull/108))
+- Multiple bugs removed and aesthetics improved ([#107](https://github.com/epfl-dlab/transformers-CFG/pull/107))
### Recent releases
@@ -196,6 +204,8 @@ if __name__ == "__main__":
### Transformers *Pipeline*
+
+
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -251,9 +261,60 @@ for generation_group in generations:
print(generation['generated_text'])
```
+
+
### LlamaCPP Python
+Use the `llama-cpp-python` adapter, which is loaded automatically when you pass `adapter="llama-cpp-python"` to `GrammarConstrainedLogitsProcessor`.
+
+```py
+import io
+import torch
+import logging
+from contextlib import redirect_stderr
+from llama_cpp import Llama
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
+from transformers import AutoTokenizer
+
+logging.basicConfig(level=logging.INFO)
-Coming soon!
+# Define your EBNF grammar (you can replace this with your own)
+ebnf_grammar = """
+
+ root ::= "The animal is a " animal "."
+
+ animal ::= "cat" | "fish"
+
+ """
+
+# Load the tokenizer matching your model
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b")
+
+# Silence llama.cpp's stderr output while loading the model via llama-cpp-python
+f = io.StringIO()
+with redirect_stderr(f):
+ model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False)
+
+# Create the grammar constraint and the logits processor with the llama-cpp-python adapter.
+grammar_constraint = IncrementalGrammarConstraint(ebnf_grammar, "root", tokenizer)
+grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint, adapter="llama-cpp-python")
+
+# Define a prompt.
+prompt = """The text says, "The animal is a dog." The answer is obvious. """
+
+# Use the text completion API with the logits processor.
+response = model.create_completion(
+ stream=True,
+ prompt=prompt,
+ logits_processor=[grammar_processor],
+ max_tokens=100,
+)
+
+for token in response:
+ token_text = token["choices"][0]["text"]
+ print(token_text, end="", flush=True)
+
+```
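+
+The `adapter` argument is resolved by importing `transformers_cfg.adapters.<adapter>` and calling the function of the same name with the processor instance, so in principle other backends can be wired in the same way. Below is a minimal, untested sketch of such a module; the module name `my_backend` and its input handling are illustrative only, not part of the library:
+
+```py
+# transformers_cfg/adapters/my_backend.py  (hypothetical module name)
+import torch
+
+
+def my_backend(processor):
+    """Wrap a GrammarConstrainedLogitsProcessor for a non-transformers backend."""
+
+    def adapter_func(input_ids, scores):
+        # Normalize to a batch of token-id lists and a 2-D score tensor,
+        # then delegate to the shared grammar-constrained masking.
+        if not isinstance(scores, torch.Tensor):
+            scores = torch.tensor(scores)
+        if scores.dim() == 1:
+            scores = scores.unsqueeze(0)
+        if input_ids and isinstance(input_ids[0], int):
+            input_ids = [input_ids]
+        return processor.process_logits(input_ids, scores)
+
+    return adapter_func
+```
+
+Passing `adapter="my_backend"` would then load this module; if the import fails, the processor logs a warning and falls back to the default `transformers` behavior.
+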
## 💡 Why use `transformers-cfg`?
diff --git a/examples/generate_llama_cpp_python.py b/examples/generate_llama_cpp_python.py
new file mode 100644
index 0000000..ff220f5
--- /dev/null
+++ b/examples/generate_llama_cpp_python.py
@@ -0,0 +1,46 @@
+import io
+import torch
+import logging
+from contextlib import redirect_stderr
+from llama_cpp import Llama
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
+from transformers import AutoTokenizer
+
+logging.basicConfig(level=logging.INFO)
+
+# Define your EBNF grammar (you can replace this with your own)
+ebnf_grammar = """
+
+ root ::= "The animal is a " animal "."
+
+ animal ::= "cat" | "fish"
+
+ """
+
+# Load the tokenizer matching your model
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b")
+
+# Silence llama.cpp's stderr output while loading the model via llama-cpp-python
+f = io.StringIO()
+with redirect_stderr(f):
+ model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False)
+
+# Create the grammar constraint and the logits processor with the llama-cpp-python adapter.
+grammar_constraint = IncrementalGrammarConstraint(ebnf_grammar, "root", tokenizer)
+grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint, adapter="llama-cpp-python")
+
+# Define a prompt.
+prompt = """The text says, "The animal is a dog." The answer is obvious. """
+
+# Use the text completion API with the logits processor.
+response = model.create_completion(
+ stream=True,
+ prompt=prompt,
+ logits_processor=[grammar_processor],
+ max_tokens=100,
+)
+
+for token in response:
+ token_text = token["choices"][0]["text"]
+ print(token_text, end="", flush=True)
diff --git a/transformers_cfg/adapters/__init__.py b/transformers_cfg/adapters/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/transformers_cfg/adapters/__init__.py
@@ -0,0 +1 @@
+
diff --git a/transformers_cfg/adapters/llama_cpp_python.py b/transformers_cfg/adapters/llama_cpp_python.py
new file mode 100644
index 0000000..2e313d1
--- /dev/null
+++ b/transformers_cfg/adapters/llama_cpp_python.py
@@ -0,0 +1,107 @@
+import logging
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def llama_cpp_python(processor):
+ """
+ Adapter function for llama-cpp-python.
+
+ Args:
+ processor: A GrammarConstrainedLogitsProcessor instance
+
+ Returns:
+ A function that can be used as a logits processor with llama-cpp-python
+ """
+ reinit_attempts = 0
+ reinit_max = 3
+ accumulated_tokens = []
+
+ def _force_eos(scores):
+ eos_token = processor.grammar_constraint.tokenizer.eos_token_id
+ logger.warning(f"Forcing EOS token: {eos_token}")
+ mask = torch.full_like(scores, fill_value=-float("inf"))
+ if scores.dim() == 2:
+ mask[:, eos_token] = 0
+ else:
+ mask[eos_token] = 0
+ return mask
+
+ def adapter_func(input_ids, scores):
+ nonlocal reinit_attempts, accumulated_tokens
+
+ # Normalize input_ids to a list of token sequences
+ if np.isscalar(input_ids):
+ input_ids = [int(input_ids)]
+ elif isinstance(input_ids, np.ndarray):
+ input_ids = input_ids.tolist()
+ elif isinstance(input_ids, list):
+ input_ids = [int(i) if isinstance(i, np.generic) else i for i in input_ids]
+ elif isinstance(input_ids, np.generic):
+ input_ids = [int(input_ids)]
+
+ # Ensure we have a batch (list of token lists)
+ if input_ids and isinstance(input_ids[0], int):
+ input_ids = [input_ids]
+
+ # Convert scores to a torch.Tensor if needed
+ if not isinstance(scores, torch.Tensor):
+ scores = torch.tensor(scores)
+
+ # Ensure scores is 2D: [batch, vocab_size]
+ if scores.dim() == 1:
+ scores = scores.unsqueeze(0)
+
+ # Track tokens for debugging
+ if len(input_ids[0]) > len(accumulated_tokens):
+ new_token = input_ids[0][-1]
+ accumulated_tokens.append(new_token)
+ try:
+ token_text = processor.grammar_constraint.tokenizer.decode([new_token])
+ logger.debug(f"Added token: {new_token} ({token_text})")
+ except Exception:
+ logger.debug(f"Added token: {new_token} (cannot decode)")
+
+ # Check for consistency: if the length of our input token sequence
+ # does not match what the grammar expects, then reinitialize
+ current_length = len(input_ids[0])
+ if hasattr(processor.grammar_constraint, "last_size") and processor.grammar_constraint.last_size is not None:
+ expected_length = processor.grammar_constraint.last_size + 1
+ if current_length != expected_length:
+ logger.warning(f"Length mismatch: current={current_length}, expected={expected_length}. Reinitializing.")
+ processor.reset()
+ reinit_attempts = 0
+
+ try:
+ processed_scores = processor.process_logits(input_ids, scores)
+ reinit_attempts = 0
+ except ValueError as e:
+ error_msg = str(e)
+ if "All stacks are empty" in error_msg:
+ # Try to recover by reinitializing the grammar constraint
+ if reinit_attempts < reinit_max:
+ logger.warning(f"Grammar constraint error: {error_msg}. Attempt {reinit_attempts+1}/{reinit_max} to recover.")
+ processor.reset()
+ reinit_attempts += 1
+ try:
+ processed_scores = processor.process_logits(input_ids, scores)
+ except ValueError as e2:
+ logger.error(f"Recovery failed: {str(e2)}")
+ processed_scores = _force_eos(scores)
+ else:
+ # If reinitialization has already been attempted enough times,
+ # treat the output as complete and force EOS
+ logger.error(f"Max retries ({reinit_max}) exceeded. Current text: {processor.grammar_constraint.tokenizer.decode(accumulated_tokens)}")
+ processed_scores = _force_eos(scores)
+ else:
+ logger.error(f"Unexpected error: {error_msg}")
+ raise e
+
+ # Remove the batch dimension if present
+ if processed_scores.dim() == 2 and processed_scores.size(0) == 1:
+ processed_scores = processed_scores.squeeze(0)
+ return processed_scores.detach().cpu().numpy()
+
+ return adapter_func
diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py
index fc70e90..978fcd5 100644
--- a/transformers_cfg/generation/logits_process.py
+++ b/transformers_cfg/generation/logits_process.py
@@ -2,6 +2,7 @@
import math
import os
import pprint
+import importlib
from typing import Optional, Literal
import torch
@@ -25,6 +26,7 @@ def __init__(
valid_token_start_idx: Optional[int] = None,
execution_mode: Literal["speculation", "full_mask"] = "full_mask",
device: Optional[torch.device] = None,
+ adapter: str = "transformers",
) -> None:
self.last_size = None
self.grammar_constraint = grammar_constraint
@@ -33,6 +35,30 @@ def __init__(
self.execution_mode = execution_mode
self.device = device
+ # Create an alias for llama-cpp-python
+ if adapter == "llama-cpp-python":
+ adapter = "llama_cpp_python"
+
+ self.adapter = adapter
+
+ # Load adapter if specified and not "transformers"
+ self._adapter_func = None
+ if adapter != "transformers":
+ try:
+ # Import the adapter module
+ adapter_module = importlib.import_module(
+ f"transformers_cfg.adapters.{adapter}"
+ )
+ # Get the adapter function with the same name as the module
+ adapter_func = getattr(adapter_module, adapter)
+ # Create the adapter function with this processor
+ self._adapter_func = adapter_func(self)
+ except (ImportError, AttributeError) as e:
+ logger.warning(
+ f"Failed to load adapter '{adapter}': {str(e)}. "
+ f"Falling back to default transformers behavior."
+ )
+
def mask_logits(
self, logits: torch.FloatTensor, device: torch.device
) -> torch.FloatTensor:
@@ -158,9 +184,11 @@ def process_logits(
return masked_scores
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor
- ) -> torch.FloatTensor:
+ def __call__(self, input_ids, scores):
+ # If we have an adapter function, use it
+ if self._adapter_func is not None:
+ return self._adapter_func(input_ids, scores)
+ # Otherwise, use the default behavior
return self.process_logits(input_ids, scores)
def reset(self):