From c924819c1bcb7b5ebd0d8dc881081dcf0fd02e7e Mon Sep 17 00:00:00 2001 From: Urro Date: Fri, 7 Mar 2025 16:28:44 -0500 Subject: [PATCH 1/9] Update logits_process.py Integrate LlamaCPP Python wrapper (`llama-cpp-python`) --- transformers_cfg/generation/logits_process.py | 138 +++++++++++++++--- 1 file changed, 115 insertions(+), 23 deletions(-) diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py index fc70e90..423409c 100644 --- a/transformers_cfg/generation/logits_process.py +++ b/transformers_cfg/generation/logits_process.py @@ -4,6 +4,7 @@ import pprint from typing import Optional, Literal +import numpy as np import torch import logging from transformers.generation.logits_process import ( @@ -25,21 +26,30 @@ def __init__( valid_token_start_idx: Optional[int] = None, execution_mode: Literal["speculation", "full_mask"] = "full_mask", device: Optional[torch.device] = None, + library: str = "transformers" ) -> None: + # Initialize basic attributes. self.last_size = None self.grammar_constraint = grammar_constraint self.batch_parsing_states = None self.valid_token_start_idx = valid_token_start_idx self.execution_mode = execution_mode self.device = device + self.library = library + # For llama-cpp-python, initialize additional attributes. + if self.library == "llama-cpp-python": + self.reinit_attempts = 0 + self.reinit_max = 3 + self.accumulated_tokens = [] def mask_logits( self, logits: torch.FloatTensor, device: torch.device ) -> torch.FloatTensor: + # Clone logits to avoid modifying the original tensor. masked_logits = logits.clone() if self.execution_mode == "speculation": - # try to accept the most likely token + # Try to accept the most likely token. acceptance = torch.zeros( (logits.shape[0], len(self.grammar_constraint.homomorphism)), dtype=torch.bool, @@ -56,20 +66,21 @@ def mask_logits( if is_next_token_accepted: acceptance[i, next_token] = True else: - # resolve each stack to a tensor of True/False for each token - # indicating acceptance + # Resolve each stack to a tensor of True/False for each token + # indicating acceptance. # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device) acceptance[i] = self.grammar_constraint.filter_vocab( self.batch_parsing_states[i], device ) else: + # In full_mask mode, filter vocabulary for the entire batch. acceptance = self.grammar_constraint.batch_filter_vocab( self.batch_parsing_states, device ) - # if the logits size of the model is more than the tokennizer vocab + # If the logits size of the model is more than the tokenizer vocab, # we artificially expand the acceptance tensor and block everything - # beyond the tokenizer vocab size + # beyond the tokenizer vocab size. acceptance_vocab_size = acceptance.shape[-1] masked_logits_vocab_size = masked_logits.shape[-1] if masked_logits_vocab_size != acceptance_vocab_size: @@ -84,22 +95,20 @@ def mask_logits( ) acceptance = torch.cat((acceptance, false_tensor), dim=-1) - # acceptance is a tensor of shape (batch_size, vocab_size) - # get the indices of the accepted tokens - # do the following operation only in debug mode + # If in debug mode, print accepted token indices and tokens. if os.getenv("DEBUG_MODE") == "True": - # convert acceptance to numpy array + # Convert acceptance to a numpy array. 
batch_size, vocab_size = acceptance.shape acceptance_np = acceptance.cpu().numpy() accepted_x, accepted_y = acceptance_np.nonzero() - # dict of {batch_index: [accepted_token_indices]} - # initialize the dict with empty list + # Dict of {batch_index: [accepted_token_indices]} + # Initialize the dict with empty lists. accepted_token_indices = {i: [] for i in range(batch_size)} for x, y in zip(accepted_x, accepted_y): accepted_token_indices[x].append(y) logger.debug("Accepted token indices for the current batch:") logger.debug("\n" + pprint.pformat(accepted_token_indices)) - # convert token_ids to tokens + # Convert token_ids to tokens. accepted_tokens = { i: [ self.grammar_constraint.tokenizer.decode([token_id]) @@ -109,24 +118,26 @@ def mask_logits( } logger.debug("Accepted tokens for the current batch:") logger.debug("\n" + pprint.pformat(accepted_tokens)) - # Logits to -inf where False + # Set logits to -inf for tokens that are not accepted. masked_logits[~acceptance] = -math.inf return masked_logits def process_logits( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor + self, input_ids: list, scores: torch.FloatTensor ) -> torch.FloatTensor: """ - :param input_ids: - :param scores: - :return: + Process logits by updating the grammar parsing states and masking logits. + + :param input_ids: List of token sequences. + :param scores: Logits tensor. + :return: Masked logits tensor. """ if self.device is None: device = scores.device - # we dynamically create stacks at the first call, so that we know the batch size and beam size + # Dynamically create stacks at the first call, so that we know the batch size. if self.batch_parsing_states is None: self.batch_parsing_states = [ - # self.grammar_constraint.init_stacks() + # Use a deep copy of the initial parsing state. copy.deepcopy( self.grammar_constraint.string_recognizer.get_initial_parsing_state() ) @@ -145,8 +156,7 @@ def process_logits( [len(acc_state.stacks) for acc_state in self.batch_parsing_states] ) ) - # logger.debug("stacks: \n" + pprint.pformat(self.batch_parsing_states.stacks)) - + # Update grammar parsing states based on the current input token sequences. self.batch_parsing_states = ( self.grammar_constraint.update_state_with_batch_token_seqs( input_ids, self.batch_parsing_states, self.valid_token_start_idx @@ -154,16 +164,98 @@ def process_logits( ) logger.debug(f"input_ids: {input_ids}") + # Mask logits based on grammar constraints. masked_scores = self.mask_logits(scores, device) return masked_scores + def _force_eos(self, scores: torch.FloatTensor) -> torch.FloatTensor: + """ + Force logits so that only the EOS token is allowed; all other tokens + get a score of -infinity. + """ + eos_token = self.grammar_constraint.tokenizer.eos_token_id + logger.warning(f"Forcing EOS token: {eos_token}") + mask = torch.full_like(scores, fill_value=-float("inf")) + if scores.dim() == 2: + mask[:, eos_token] = 0 + else: + mask[eos_token] = 0 + return mask + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor + self, input_ids, scores ) -> torch.FloatTensor: - return self.process_logits(input_ids, scores) + # For the llama-cpp-python branch, perform additional normalization and error handling. + if self.library == "llama-cpp-python": + # Normalize input_ids to be a list of token sequences. 
+ if np.isscalar(input_ids): + input_ids = [int(input_ids)] + elif isinstance(input_ids, np.ndarray): + input_ids = input_ids.tolist() + elif isinstance(input_ids, list): + input_ids = [int(i) if isinstance(i, np.generic) else i for i in input_ids] + elif isinstance(input_ids, np.generic): + input_ids = [int(input_ids)] + if input_ids and isinstance(input_ids[0], int): + input_ids = [input_ids] + + if not isinstance(scores, torch.Tensor): + scores = torch.tensor(scores) + if scores.dim() == 1: + scores = scores.unsqueeze(0) + + # Track token accumulation for debugging. + if len(input_ids[0]) > len(self.accumulated_tokens): + new_token = input_ids[0][-1] + self.accumulated_tokens.append(new_token) + try: + token_text = self.grammar_constraint.tokenizer.decode([new_token]) + logger.debug(f"Added token: {new_token} ({token_text})") + except Exception: + logger.debug(f"Added token: {new_token} (cannot decode)") + + # Check for consistency: if the current length does not match the expected length, + # reinitialize the grammar constraint. + current_length = len(input_ids[0]) + if hasattr(self.grammar_constraint, "last_size") and self.grammar_constraint.last_size is not None: + expected_length = self.grammar_constraint.last_size + 1 + if current_length != expected_length: + logger.warning(f"Length mismatch: current={current_length}, expected={expected_length}. Reinitializing grammar constraint.") + self.grammar_constraint.reset() + self.batch_parsing_states = None + self.reinit_attempts = 0 + try: + processed_scores = self.process_logits(input_ids, scores) + self.reinit_attempts = 0 + except ValueError as e: + error_msg = str(e) + if "All stacks are empty" in error_msg: + if self.reinit_attempts < self.reinit_max: + logger.warning(f"Grammar constraint error: {error_msg}. Attempt {self.reinit_attempts+1}/{self.reinit_max} to recover.") + self.grammar_constraint.reset() + self.batch_parsing_states = None + self.reinit_attempts += 1 + try: + processed_scores = self.process_logits(input_ids, scores) + except ValueError as e2: + logger.error(f"Recovery failed: {str(e2)}") + processed_scores = self._force_eos(scores) + else: + logger.error(f"Max retries ({self.reinit_max}) exceeded. Forcing EOS.") + processed_scores = self._force_eos(scores) + else: + logger.error(f"Unexpected error: {error_msg}") + raise e + if processed_scores.dim() == 2 and processed_scores.size(0) == 1: + processed_scores = processed_scores.squeeze(0) + return processed_scores.detach().cpu().numpy() + else: + # Default transformers behavior. + return self.process_logits(input_ids, scores) def reset(self): + # Reset the grammar parsing states. self.batch_parsing_states = None if isinstance(self.grammar_constraint, IncrementalGrammarConstraint): self.grammar_constraint.reset() From 6a8c281a6cd9872327d6059535beac2e73d3db04 Mon Sep 17 00:00:00 2001 From: Urro Date: Fri, 7 Mar 2025 16:34:13 -0500 Subject: [PATCH 2/9] Update logits_process.py Clean up comments --- transformers_cfg/generation/logits_process.py | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py index 423409c..994e766 100644 --- a/transformers_cfg/generation/logits_process.py +++ b/transformers_cfg/generation/logits_process.py @@ -28,7 +28,6 @@ def __init__( device: Optional[torch.device] = None, library: str = "transformers" ) -> None: - # Initialize basic attributes. 
self.last_size = None self.grammar_constraint = grammar_constraint self.batch_parsing_states = None @@ -36,7 +35,6 @@ def __init__( self.execution_mode = execution_mode self.device = device self.library = library - # For llama-cpp-python, initialize additional attributes. if self.library == "llama-cpp-python": self.reinit_attempts = 0 self.reinit_max = 3 @@ -45,11 +43,9 @@ def __init__( def mask_logits( self, logits: torch.FloatTensor, device: torch.device ) -> torch.FloatTensor: - # Clone logits to avoid modifying the original tensor. masked_logits = logits.clone() if self.execution_mode == "speculation": - # Try to accept the most likely token. acceptance = torch.zeros( (logits.shape[0], len(self.grammar_constraint.homomorphism)), dtype=torch.bool, @@ -66,21 +62,14 @@ def mask_logits( if is_next_token_accepted: acceptance[i, next_token] = True else: - # Resolve each stack to a tensor of True/False for each token - # indicating acceptance. - # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device) acceptance[i] = self.grammar_constraint.filter_vocab( self.batch_parsing_states[i], device ) else: - # In full_mask mode, filter vocabulary for the entire batch. acceptance = self.grammar_constraint.batch_filter_vocab( self.batch_parsing_states, device ) - # If the logits size of the model is more than the tokenizer vocab, - # we artificially expand the acceptance tensor and block everything - # beyond the tokenizer vocab size. acceptance_vocab_size = acceptance.shape[-1] masked_logits_vocab_size = masked_logits.shape[-1] if masked_logits_vocab_size != acceptance_vocab_size: @@ -95,20 +84,15 @@ def mask_logits( ) acceptance = torch.cat((acceptance, false_tensor), dim=-1) - # If in debug mode, print accepted token indices and tokens. if os.getenv("DEBUG_MODE") == "True": - # Convert acceptance to a numpy array. batch_size, vocab_size = acceptance.shape acceptance_np = acceptance.cpu().numpy() accepted_x, accepted_y = acceptance_np.nonzero() - # Dict of {batch_index: [accepted_token_indices]} - # Initialize the dict with empty lists. accepted_token_indices = {i: [] for i in range(batch_size)} for x, y in zip(accepted_x, accepted_y): accepted_token_indices[x].append(y) logger.debug("Accepted token indices for the current batch:") logger.debug("\n" + pprint.pformat(accepted_token_indices)) - # Convert token_ids to tokens. accepted_tokens = { i: [ self.grammar_constraint.tokenizer.decode([token_id]) @@ -118,37 +102,22 @@ def mask_logits( } logger.debug("Accepted tokens for the current batch:") logger.debug("\n" + pprint.pformat(accepted_tokens)) - # Set logits to -inf for tokens that are not accepted. masked_logits[~acceptance] = -math.inf return masked_logits def process_logits( self, input_ids: list, scores: torch.FloatTensor ) -> torch.FloatTensor: - """ - Process logits by updating the grammar parsing states and masking logits. - - :param input_ids: List of token sequences. - :param scores: Logits tensor. - :return: Masked logits tensor. - """ if self.device is None: device = scores.device - # Dynamically create stacks at the first call, so that we know the batch size. if self.batch_parsing_states is None: self.batch_parsing_states = [ - # Use a deep copy of the initial parsing state. 
copy.deepcopy( self.grammar_constraint.string_recognizer.get_initial_parsing_state() ) for _ in range(len(input_ids)) ] - - if os.getenv("DEBUG_MODE") == "True": - print("-" * 80) - logger.debug("input_ids: \n" + pprint.pformat(input_ids)) - # logger.debug("scores: \n" + pprint.pformat(scores)) logger.debug("last_size: \n" + pprint.pformat(self.last_size)) logger.debug( "num of stacks: \n" @@ -156,23 +125,16 @@ def process_logits( [len(acc_state.stacks) for acc_state in self.batch_parsing_states] ) ) - # Update grammar parsing states based on the current input token sequences. self.batch_parsing_states = ( self.grammar_constraint.update_state_with_batch_token_seqs( input_ids, self.batch_parsing_states, self.valid_token_start_idx ) ) logger.debug(f"input_ids: {input_ids}") - - # Mask logits based on grammar constraints. masked_scores = self.mask_logits(scores, device) return masked_scores def _force_eos(self, scores: torch.FloatTensor) -> torch.FloatTensor: - """ - Force logits so that only the EOS token is allowed; all other tokens - get a score of -infinity. - """ eos_token = self.grammar_constraint.tokenizer.eos_token_id logger.warning(f"Forcing EOS token: {eos_token}") mask = torch.full_like(scores, fill_value=-float("inf")) @@ -186,7 +148,6 @@ def _force_eos(self, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__( self, input_ids, scores ) -> torch.FloatTensor: - # For the llama-cpp-python branch, perform additional normalization and error handling. if self.library == "llama-cpp-python": # Normalize input_ids to be a list of token sequences. if np.isscalar(input_ids): @@ -215,8 +176,6 @@ def __call__( except Exception: logger.debug(f"Added token: {new_token} (cannot decode)") - # Check for consistency: if the current length does not match the expected length, - # reinitialize the grammar constraint. current_length = len(input_ids[0]) if hasattr(self.grammar_constraint, "last_size") and self.grammar_constraint.last_size is not None: expected_length = self.grammar_constraint.last_size + 1 @@ -255,7 +214,6 @@ def __call__( return self.process_logits(input_ids, scores) def reset(self): - # Reset the grammar parsing states. 
self.batch_parsing_states = None if isinstance(self.grammar_constraint, IncrementalGrammarConstraint): self.grammar_constraint.reset() From 719a44615b382adf37d9cdd58e6a1e347e641808 Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 17:39:25 -0500 Subject: [PATCH 3/9] Update logits_process.py --- transformers_cfg/generation/logits_process.py | 148 ++++++++---------- 1 file changed, 63 insertions(+), 85 deletions(-) diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py index 994e766..978fcd5 100644 --- a/transformers_cfg/generation/logits_process.py +++ b/transformers_cfg/generation/logits_process.py @@ -2,9 +2,9 @@ import math import os import pprint +import importlib from typing import Optional, Literal -import numpy as np import torch import logging from transformers.generation.logits_process import ( @@ -26,7 +26,7 @@ def __init__( valid_token_start_idx: Optional[int] = None, execution_mode: Literal["speculation", "full_mask"] = "full_mask", device: Optional[torch.device] = None, - library: str = "transformers" + adapter: str = "transformers", ) -> None: self.last_size = None self.grammar_constraint = grammar_constraint @@ -34,11 +34,30 @@ def __init__( self.valid_token_start_idx = valid_token_start_idx self.execution_mode = execution_mode self.device = device - self.library = library - if self.library == "llama-cpp-python": - self.reinit_attempts = 0 - self.reinit_max = 3 - self.accumulated_tokens = [] + + # Create an alias for llama-cpp-python + if adapter == "llama-cpp-python": + adapter = "llama_cpp_python" + + self.adapter = adapter + + # Load adapter if specified and not "transformers" + self._adapter_func = None + if adapter != "transformers": + try: + # Import the adapter module + adapter_module = importlib.import_module( + f"transformers_cfg.adapters.{adapter}" + ) + # Get the adapter function with the same name as the module + adapter_func = getattr(adapter_module, adapter) + # Create the adapter function with this processor + self._adapter_func = adapter_func(self) + except (ImportError, AttributeError) as e: + logger.warning( + f"Failed to load adapter '{adapter}': {str(e)}. " + f"Falling back to default transformers behavior." 
+ ) def mask_logits( self, logits: torch.FloatTensor, device: torch.device @@ -46,6 +65,7 @@ def mask_logits( masked_logits = logits.clone() if self.execution_mode == "speculation": + # try to accept the most likely token acceptance = torch.zeros( (logits.shape[0], len(self.grammar_constraint.homomorphism)), dtype=torch.bool, @@ -62,6 +82,9 @@ def mask_logits( if is_next_token_accepted: acceptance[i, next_token] = True else: + # resolve each stack to a tensor of True/False for each token + # indicating acceptance + # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device) acceptance[i] = self.grammar_constraint.filter_vocab( self.batch_parsing_states[i], device ) @@ -70,6 +93,9 @@ def mask_logits( self.batch_parsing_states, device ) + # if the logits size of the model is more than the tokennizer vocab + # we artificially expand the acceptance tensor and block everything + # beyond the tokenizer vocab size acceptance_vocab_size = acceptance.shape[-1] masked_logits_vocab_size = masked_logits.shape[-1] if masked_logits_vocab_size != acceptance_vocab_size: @@ -84,15 +110,22 @@ def mask_logits( ) acceptance = torch.cat((acceptance, false_tensor), dim=-1) + # acceptance is a tensor of shape (batch_size, vocab_size) + # get the indices of the accepted tokens + # do the following operation only in debug mode if os.getenv("DEBUG_MODE") == "True": + # convert acceptance to numpy array batch_size, vocab_size = acceptance.shape acceptance_np = acceptance.cpu().numpy() accepted_x, accepted_y = acceptance_np.nonzero() + # dict of {batch_index: [accepted_token_indices]} + # initialize the dict with empty list accepted_token_indices = {i: [] for i in range(batch_size)} for x, y in zip(accepted_x, accepted_y): accepted_token_indices[x].append(y) logger.debug("Accepted token indices for the current batch:") logger.debug("\n" + pprint.pformat(accepted_token_indices)) + # convert token_ids to tokens accepted_tokens = { i: [ self.grammar_constraint.tokenizer.decode([token_id]) @@ -102,22 +135,35 @@ def mask_logits( } logger.debug("Accepted tokens for the current batch:") logger.debug("\n" + pprint.pformat(accepted_tokens)) + # Logits to -inf where False masked_logits[~acceptance] = -math.inf return masked_logits def process_logits( - self, input_ids: list, scores: torch.FloatTensor + self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: + """ + :param input_ids: + :param scores: + :return: + """ if self.device is None: device = scores.device + # we dynamically create stacks at the first call, so that we know the batch size and beam size if self.batch_parsing_states is None: self.batch_parsing_states = [ + # self.grammar_constraint.init_stacks() copy.deepcopy( self.grammar_constraint.string_recognizer.get_initial_parsing_state() ) for _ in range(len(input_ids)) ] + + if os.getenv("DEBUG_MODE") == "True": + print("-" * 80) + logger.debug("input_ids: \n" + pprint.pformat(input_ids)) + # logger.debug("scores: \n" + pprint.pformat(scores)) logger.debug("last_size: \n" + pprint.pformat(self.last_size)) logger.debug( "num of stacks: \n" @@ -125,93 +171,25 @@ def process_logits( [len(acc_state.stacks) for acc_state in self.batch_parsing_states] ) ) + # logger.debug("stacks: \n" + pprint.pformat(self.batch_parsing_states.stacks)) + self.batch_parsing_states = ( self.grammar_constraint.update_state_with_batch_token_seqs( input_ids, self.batch_parsing_states, self.valid_token_start_idx ) ) logger.debug(f"input_ids: {input_ids}") + masked_scores = self.mask_logits(scores, 
device) return masked_scores - def _force_eos(self, scores: torch.FloatTensor) -> torch.FloatTensor: - eos_token = self.grammar_constraint.tokenizer.eos_token_id - logger.warning(f"Forcing EOS token: {eos_token}") - mask = torch.full_like(scores, fill_value=-float("inf")) - if scores.dim() == 2: - mask[:, eos_token] = 0 - else: - mask[eos_token] = 0 - return mask - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__( - self, input_ids, scores - ) -> torch.FloatTensor: - if self.library == "llama-cpp-python": - # Normalize input_ids to be a list of token sequences. - if np.isscalar(input_ids): - input_ids = [int(input_ids)] - elif isinstance(input_ids, np.ndarray): - input_ids = input_ids.tolist() - elif isinstance(input_ids, list): - input_ids = [int(i) if isinstance(i, np.generic) else i for i in input_ids] - elif isinstance(input_ids, np.generic): - input_ids = [int(input_ids)] - if input_ids and isinstance(input_ids[0], int): - input_ids = [input_ids] - - if not isinstance(scores, torch.Tensor): - scores = torch.tensor(scores) - if scores.dim() == 1: - scores = scores.unsqueeze(0) - - # Track token accumulation for debugging. - if len(input_ids[0]) > len(self.accumulated_tokens): - new_token = input_ids[0][-1] - self.accumulated_tokens.append(new_token) - try: - token_text = self.grammar_constraint.tokenizer.decode([new_token]) - logger.debug(f"Added token: {new_token} ({token_text})") - except Exception: - logger.debug(f"Added token: {new_token} (cannot decode)") - - current_length = len(input_ids[0]) - if hasattr(self.grammar_constraint, "last_size") and self.grammar_constraint.last_size is not None: - expected_length = self.grammar_constraint.last_size + 1 - if current_length != expected_length: - logger.warning(f"Length mismatch: current={current_length}, expected={expected_length}. Reinitializing grammar constraint.") - self.grammar_constraint.reset() - self.batch_parsing_states = None - self.reinit_attempts = 0 - try: - processed_scores = self.process_logits(input_ids, scores) - self.reinit_attempts = 0 - except ValueError as e: - error_msg = str(e) - if "All stacks are empty" in error_msg: - if self.reinit_attempts < self.reinit_max: - logger.warning(f"Grammar constraint error: {error_msg}. Attempt {self.reinit_attempts+1}/{self.reinit_max} to recover.") - self.grammar_constraint.reset() - self.batch_parsing_states = None - self.reinit_attempts += 1 - try: - processed_scores = self.process_logits(input_ids, scores) - except ValueError as e2: - logger.error(f"Recovery failed: {str(e2)}") - processed_scores = self._force_eos(scores) - else: - logger.error(f"Max retries ({self.reinit_max}) exceeded. Forcing EOS.") - processed_scores = self._force_eos(scores) - else: - logger.error(f"Unexpected error: {error_msg}") - raise e - if processed_scores.dim() == 2 and processed_scores.size(0) == 1: - processed_scores = processed_scores.squeeze(0) - return processed_scores.detach().cpu().numpy() - else: - # Default transformers behavior. 
- return self.process_logits(input_ids, scores) + def __call__(self, input_ids, scores): + # If we have an adapter function, use it + if self._adapter_func is not None: + return self._adapter_func(input_ids, scores) + # Otherwise, use the default behavior + return self.process_logits(input_ids, scores) def reset(self): self.batch_parsing_states = None From fa0f922a432a6ef5a6b5b73d80bf3dbc8f715f9e Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 17:39:56 -0500 Subject: [PATCH 4/9] Create llama_cpp_python.py --- transformers_cfg/adapters/llama_cpp_python.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 transformers_cfg/adapters/llama_cpp_python.py diff --git a/transformers_cfg/adapters/llama_cpp_python.py b/transformers_cfg/adapters/llama_cpp_python.py new file mode 100644 index 0000000..2e313d1 --- /dev/null +++ b/transformers_cfg/adapters/llama_cpp_python.py @@ -0,0 +1,107 @@ +import logging +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def llama_cpp_python(processor): + """ + Adapter function for llama-cpp-python. + + Args: + processor: A GrammarConstrainedLogitsProcessor instance + + Returns: + A function that can be used as a logits processor with llama-cpp-python + """ + reinit_attempts = 0 + reinit_max = 3 + accumulated_tokens = [] + + def _force_eos(scores): + eos_token = processor.grammar_constraint.tokenizer.eos_token_id + logger.warning(f"Forcing EOS token: {eos_token}") + mask = torch.full_like(scores, fill_value=-float("inf")) + if scores.dim() == 2: + mask[:, eos_token] = 0 + else: + mask[eos_token] = 0 + return mask + + def adapter_func(input_ids, scores): + nonlocal reinit_attempts, accumulated_tokens + + # Normalize input_ids to a list of token sequences + if np.isscalar(input_ids): + input_ids = [int(input_ids)] + elif isinstance(input_ids, np.ndarray): + input_ids = input_ids.tolist() + elif isinstance(input_ids, list): + input_ids = [int(i) if isinstance(i, np.generic) else i for i in input_ids] + elif isinstance(input_ids, np.generic): + input_ids = [int(input_ids)] + + # Ensure we have a batch (list of token lists) + if input_ids and isinstance(input_ids[0], int): + input_ids = [input_ids] + + # Convert scores to a torch.Tensor if needed + if not isinstance(scores, torch.Tensor): + scores = torch.tensor(scores) + + # Ensure scores is 2D: [batch, vocab_size] + if scores.dim() == 1: + scores = scores.unsqueeze(0) + + # Track tokens for debugging + if len(input_ids[0]) > len(accumulated_tokens): + new_token = input_ids[0][-1] + accumulated_tokens.append(new_token) + try: + token_text = processor.grammar_constraint.tokenizer.decode([new_token]) + logger.debug(f"Added token: {new_token} ({token_text})") + except Exception: + logger.debug(f"Added token: {new_token} (cannot decode)") + + # Check for consistency: if the length of our input token sequence + # does not match what the grammar expects, then reinitialize + current_length = len(input_ids[0]) + if hasattr(processor.grammar_constraint, "last_size") and processor.grammar_constraint.last_size is not None: + expected_length = processor.grammar_constraint.last_size + 1 + if current_length != expected_length: + logger.warning(f"Length mismatch: current={current_length}, expected={expected_length}. 
Reinitializing.") + processor.reset() + reinit_attempts = 0 + + try: + processed_scores = processor.process_logits(input_ids, scores) + reinit_attempts = 0 + except ValueError as e: + error_msg = str(e) + if "All stacks are empty" in error_msg: + # Try to recover by reinitializing the grammar constraint + if reinit_attempts < reinit_max: + logger.warning(f"Grammar constraint error: {error_msg}. Attempt {reinit_attempts+1}/{reinit_max} to recover.") + processor.reset() + reinit_attempts += 1 + try: + processed_scores = processor.process_logits(input_ids, scores) + except ValueError as e2: + logger.error(f"Recovery failed: {str(e2)}") + processed_scores = _force_eos(scores) + else: + # If reinitialization has already been attempted enough times, + # treat the output as complete and force EOS + logger.error(f"Max retries ({reinit_max}) exceeded. Current text: {processor.grammar_constraint.tokenizer.decode(accumulated_tokens)}") + processed_scores = _force_eos(scores) + else: + logger.error(f"Unexpected error: {error_msg}") + raise e + + # Remove the batch dimension if present + if processed_scores.dim() == 2 and processed_scores.size(0) == 1: + processed_scores = processed_scores.squeeze(0) + return processed_scores.detach().cpu().numpy() + + return adapter_func From 0795e789090455c4469000fc99f2a70e45c9fcb4 Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 17:40:04 -0500 Subject: [PATCH 5/9] Create __init__.py --- transformers_cfg/adapters/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 transformers_cfg/adapters/__init__.py diff --git a/transformers_cfg/adapters/__init__.py b/transformers_cfg/adapters/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/transformers_cfg/adapters/__init__.py @@ -0,0 +1 @@ + From c057a3b8c802b9b71f736c78e719f2b3a8ed9182 Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 17:49:35 -0500 Subject: [PATCH 6/9] Create generate_llama_cpp_python --- examples/generate_llama_cpp_python | 46 ++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 examples/generate_llama_cpp_python diff --git a/examples/generate_llama_cpp_python b/examples/generate_llama_cpp_python new file mode 100644 index 0000000..ff220f5 --- /dev/null +++ b/examples/generate_llama_cpp_python @@ -0,0 +1,46 @@ +import io +import torch +import logging +from contextlib import redirect_stderr +from llama_cpp import Llama +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from transformers import AutoTokenizer + +logging.basicConfig(level=logging.INFO) + +# Define your EBNF grammar (you can replace this with your own) +ebnf_grammar = """ + + root ::= "The animal is a " animal "." + + animal ::= "cat" | "fish" + + """ + +# Load the tokenizer matching your model +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b") + +# Redirect stderr and load the model via llama-cpp-python +f = io.StringIO() +with redirect_stderr(f): + model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False) + +# Create the grammar constraint and the logits processor with the new parameter. +grammar_constraint = IncrementalGrammarConstraint(ebnf_grammar, "root", tokenizer) +grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint, adapter="llama-cpp-python") + +# Define a prompt. +prompt = """The text says, "The animal is a dog." The answer is obvious. 
""" + +# Use the text completion API with the logits processor. +response = model.create_completion( + stream=True, + prompt=prompt, + logits_processor=[grammar_processor], + max_tokens=100, +) + +for token in response: + token_text = token["choices"][0]["text"] + print(token_text, end="", flush=True) From 651734db321de4c1d41d3f7f6ee8031bc6ccd66f Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 17:51:11 -0500 Subject: [PATCH 7/9] Update README.md --- README.md | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 22e373d..37158e8 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,8 @@ if __name__ == "__main__": ### Transformers *Pipeline* +
+ ```py import torch from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline @@ -251,9 +253,60 @@ for generation_group in generations: print(generation['generated_text']) ``` +
+ ### LlamaCPP Python +Use the `llama-cpp-python` adapter, automatically loadable with the `adapter` parameter. + +```py +import io +import torch +import logging +from contextlib import redirect_stderr +from llama_cpp import Llama +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +from transformers import AutoTokenizer + +logging.basicConfig(level=logging.INFO) + +# Define your EBNF grammar (you can replace this with your own) +ebnf_grammar = """ + + root ::= "The animal is a " animal "." -Coming soon! + animal ::= "cat" | "fish" + + """ + +# Load the tokenizer matching your model +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b") + +# Redirect stderr and load the model via llama-cpp-python +f = io.StringIO() +with redirect_stderr(f): + model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False) + +# Create the grammar constraint and the logits processor with the new parameter. +grammar_constraint = IncrementalGrammarConstraint(ebnf_grammar, "root", tokenizer) +grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint, adapter="llama-cpp-python") + +# Define a prompt. +prompt = """The text says, "The animal is a dog." The answer is obvious. """ + +# Use the text completion API with the logits processor. +response = model.create_completion( + stream=True, + prompt=prompt, + logits_processor=[grammar_processor], + max_tokens=100, +) + +for token in response: + token_text = token["choices"][0]["text"] + print(token_text, end="", flush=True) + +``` ## 💡 Why use `transformers-cfg`? From 7d530ef445b73cff9b543890d39606c3a424cd09 Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 18:20:34 -0500 Subject: [PATCH 8/9] Update README.md --- README.md | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 37158e8..6e2cfa0 100644 --- a/README.md +++ b/README.md @@ -5,24 +5,32 @@ ## 💭 Release news +### Experimental development + +
+ +- LlamaCPP Python wrapper support ([#116](https://github.com/epfl-dlab/transformers-CFG/pull/116)) + +
+ ### Latest release #### **[v0.2.7 Latest](https://github.com/epfl-dlab/transformers-CFG/releases/tag/v0.2.7)** (2025-03-02) #### **Features** -- **(CLI)** Types and MLX support (#93) -- **(Regex)** Negation, wildcard, and repetition bracket operators (#94, #95, #96, #104) -- **(Models)** Qwen2 and Qwen2.5 (#97) -- **(Logits)** Resuable `GrammarConstrainedLogitsProcessor` across generations for efficiency (#100) -- **(Backend)** Pytest for testing (#109) -- **(CI/CD)** GitHub Actions workflow for automation (#110) +- Types and MLX support ([#93](https://github.com/epfl-dlab/transformers-CFG/pull/93)) +- Negation, wildcard, and repetition bracket operators ([#94](https://github.com/epfl-dlab/transformers-CFG/pull/94), [#95](https://github.com/epfl-dlab/transformers-CFG/pull/95), [#96](https://github.com/epfl-dlab/transformers-CFG/pull/96), [#104](https://github.com/epfl-dlab/transformers-CFG/pull/104)) +- Qwen2 and Qwen2.5 ([#97](https://github.com/epfl-dlab/transformers-CFG/pull/97)) +- Resuable `GrammarConstrainedLogitsProcessor` for efficiency ([#100](https://github.com/epfl-dlab/transformers-CFG/pull/100)) +- Pytest for testing ([#109](https://github.com/epfl-dlab/transformers-CFG/pull/109)) +- GitHub Actions workflow for automation ([#110](https://github.com/epfl-dlab/transformers-CFG/pull/110)) #### **Bug fixes** -- Avoid computing full masks and optimized type additions (#101) -- Refactored grammar encoding to improve structure (#99) -- EOS token now correctly masks (#108) -- Multiple bugs removed and aesthetics improved (#107) +- Avoid computing full masks and optimized type additions ([#101](https://github.com/epfl-dlab/transformers-CFG/pull/101)) +- Refactored grammar encoding to improve structure ([#99](https://github.com/epfl-dlab/transformers-CFG/pull/99)) +- EOS token now correctly masks ([#108](https://github.com/epfl-dlab/transformers-CFG/pull/108)) +- Multiple bugs removed and aesthetics improved ([#107](https://github.com/epfl-dlab/transformers-CFG/pull/107)) ### Recent releases From 0cd54712cd4356e6649e84478b1bf5264a9b7e11 Mon Sep 17 00:00:00 2001 From: URRO Date: Sat, 8 Mar 2025 18:21:21 -0500 Subject: [PATCH 9/9] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6e2cfa0..c981732 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ ## 💭 Release news -### Experimental development +### Latest experimental
@@ -13,7 +13,7 @@
-### Latest release +### Latest stable #### **[v0.2.7 Latest](https://github.com/epfl-dlab/transformers-CFG/releases/tag/v0.2.7)** (2025-03-02) #### **Features**