Conversation
| def _ensure_inference_task_started(self) -> None: | ||
| """Lazy-start the background inference task if not already running.""" | ||
| if self._inference_task is None or self._inference_task.done(): | ||
| # Create weakref to self so task can detect when client is GC'd | ||
| weak_self = weakref.ref(self) | ||
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | ||
| _LOGGER.info("TransformersLLMClient started background inference task") |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, TransformersLLMClient started background inference task never logs id(self) or another client identifier, so if multiple clients run concurrently you cannot trace which instance started its background task; can we include the client id (e.g., id(self)) in this log?
| def _ensure_inference_task_started(self) -> None: | |
| """Lazy-start the background inference task if not already running.""" | |
| if self._inference_task is None or self._inference_task.done(): | |
| # Create weakref to self so task can detect when client is GC'd | |
| weak_self = weakref.ref(self) | |
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | |
| _LOGGER.info("TransformersLLMClient started background inference task") | |
| def _ensure_inference_task_started(self) -> None: | |
| """Lazy-start the background inference task if not already running.""" | |
| if self._inference_task is None or self._inference_task.done(): | |
| # Create weakref to self so task can detect when client is GC'd | |
| weak_self = weakref.ref(self) | |
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | |
| _LOGGER.info("TransformersLLMClient(id=%s) started background inference task", id(self)) |
Finding type: AI Coding Guidelines
There was a problem hiding this comment.
Not necessary, IMO
There was a problem hiding this comment.
The log currently just says TransformersLLMClient started background inference task, so when multiple clients are running there’s no way to tell which instance triggered it; including id(self) (or similar) gives us a lightweight identifier for debugging without changing functionality.
| # Process batch | ||
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | ||
| await self._process_batch(batch) | ||
|
|
||
| except Exception as e: | ||
| _LOGGER.exception("Error in inference loop: %s", e) |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, the Error in inference loop exception log never includes id(self) (or similar client identifier), so failed batches cannot be traced back to a particular client instance; can we include the client id in this log?
| # Process batch | |
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | |
| await self._process_batch(batch) | |
| except Exception as e: | |
| _LOGGER.exception("Error in inference loop: %s", e) | |
| # Process batch | |
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | |
| await self._process_batch(batch) | |
| except Exception as e: | |
| _LOGGER.exception("Error in inference loop for client %s: %s", hex(id(self)), e) |
Finding type: AI Coding Guidelines
There was a problem hiding this comment.
Not necessary IMO
There was a problem hiding this comment.
I hear you, but tracing the background task failures to a specific client instance is hard without an identifier—the exception already includes the client’s ID now, so we can correlate Error in inference loop for client … messages with the right request queue.
|
|
||
| except Exception as e: | ||
| _LOGGER.exception("Error processing batch: %s", e) |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, the Error processing batch exception log never includes id(self) (or similar client identifier), so downstream errors cannot be tied back to a particular client instance when multiple are running; can we include the client id here as well?
| except Exception as e: | |
| _LOGGER.exception("Error processing batch: %s", e) | |
| except Exception as e: | |
| _LOGGER.exception("Error processing batch in client id=%s: %s", id(self), e) |
Finding type: AI Coding Guidelines
| # Create minimal tokenizer | ||
| minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained( | ||
| "gpt2", # Uses cached tokenizer vocab | ||
| model_max_length=32, | ||
| ) | ||
| minimal_tokenizer.pad_token = minimal_tokenizer.eos_token | ||
| minimal_tokenizer.padding_side = "left" | ||
| # Add minimal chat template | ||
| minimal_tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{% endfor %}" |
There was a problem hiding this comment.
[Testing] The “no-download” integration test still calls transformers.GPT2Tokenizer.from_pretrained("gpt2") (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny tokenizers.Tokenizer/PreTrainedTokenizerFast with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.
Context for Agents
The “no-download” integration test still calls `transformers.GPT2Tokenizer.from_pretrained("gpt2")` (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny `tokenizers.Tokenizer`/`PreTrainedTokenizerFast` with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.
File: src/ares/contrib/transformers_client_test.py
Line: 430| @functools.cached_property | ||
| def _inference_task(self) -> asyncio.Task[None]: | ||
| """Lazy-initialized background inference task. | ||
|
|
||
| The task automatically exits when the client is garbage collected via weakref. | ||
| """ | ||
| weak_self = weakref.ref(self) | ||
| task = asyncio.create_task(self._inference_loop(weak_self)) | ||
| _LOGGER.info("TransformersLLMClient started background inference task") | ||
| return task | ||
|
|
||
| @functools.cached_property | ||
| def _device(self) -> str: | ||
| """Resolved device.""" |
There was a problem hiding this comment.
TransformersLLMClient logs TransformersLLMClient started background inference task without id(self) even though this async background task can run concurrently for multiple clients, so tracing is ambiguous; CLAUDE.md requires async logging to include object identifiers—can we log the client id in these statements?
Finding type: AI Coding Guidelines
Prompt for AI Agents:
In src/ares/contrib/transformers_client.py around lines 154 to 167, the _inference_task
cached_property logs "TransformersLLMClient started background inference task" without
identifying which client started the task. Update the logging call in that method to
include a unique identifier for the client (for example id(self)) and optionally the
model_name so concurrent clients are distinguishable. Make the log message something
like: "TransformersLLMClient started background inference task (id=%s, model=%s)" and
pass id(self) and self.model_name as parameters to the logger.
| @pytest.mark.asyncio | ||
| async def test_integration_with_minimal_model(): | ||
| """Integration test with a minimal GPT2 model. | ||
|
|
||
| Creates a tiny GPT2 model from scratch for testing the full pipeline. | ||
| Note: Downloads GPT2 tokenizer vocab on first run (~500KB, cached after). | ||
| """ | ||
| # Create minimal GPT2 config - vocab_size must match GPT2Tokenizer (50257) | ||
| config = transformers.GPT2Config( | ||
| vocab_size=50257, # Must match GPT2Tokenizer vocab | ||
| n_positions=32, | ||
| n_ctx=32, | ||
| n_embd=32, | ||
| n_layer=2, | ||
| n_head=4, | ||
| ) | ||
|
|
||
| minimal_model = transformers.GPT2LMHeadModel(config) | ||
| minimal_model.eval() | ||
|
|
||
| # GPT2 tokenizer is lightweight and cached after first download | ||
| minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained( | ||
| "gpt2", | ||
| model_max_length=32, | ||
| ) |
There was a problem hiding this comment.
test_integration_with_minimal_model calls transformers.GPT2Tokenizer.from_pretrained("gpt2"), which performs a real network download inside src/ares/contrib/transformers_client_test.py; CLAUDE.md mandates that unit tests under src/ must mock external services and real integration tests belong under integration_tests/, so this test should either mock the tokenizer download or move to the integration test suite (CLAUDE.md).
Finding type: AI Coding Guidelines
Prompt for AI Agents:
In src/ares/contrib/transformers_client_test.py around lines 286 to 310, the
test_integration_with_minimal_model function calls
transformers.GPT2Tokenizer.from_pretrained("gpt2") which performs a real network
download (not allowed for unit tests under src/). Refactor by mocking that call: replace
the direct from_pretrained invocation with a context manager that patches
transformers.GPT2Tokenizer.from_pretrained to return the locally constructed
minimal_tokenizer (use mock.patch.object(transformers.GPT2Tokenizer, "from_pretrained",
return_value=minimal_tokenizer)) so no network I/O occurs; alternatively, if you intend
this to be a true integration test, move the whole test function into integration_tests/
and update imports accordingly.
| self._generate_batch, | ||
| chat_conversations, | ||
| temperature=temperature, | ||
| max_new_tokens=max_new_tokens, | ||
| ) |
There was a problem hiding this comment.
_generate_batch is only called with chat_conversations, temperature, and max_new_tokens, so any other LLMRequest.to_chat_completion_kwargs() fields (e.g. top_p, stop/stop_sequences, tools, etc.) never reach _model.generate; those request options now have no effect when using this client
Finding type: Logical Bugs
Prompt for AI Agents:
In src/ares/contrib/transformers_client.py around lines 279 to 283, the call to
self._generate_batch only forwards (chat_conversations, temperature, max_new_tokens) so
other per-request options returned by to_chat_completion_kwargs (e.g., top_p,
stop/stop_sequences, any decoding options or tool hints) are lost. Refactor so that when
grouping requests you capture and propagate the full generation kwargs for the group
(use the first request's kwargs as the group's canonical kwargs, and only group requests
whose kwargs are identical or compatible), then change _process_batch/_generate_batch
signatures to accept a dict of generation kwargs and pass those through into the
tokenizer/model calls (model.generate and any tokenizer settings) instead of only
passing temperature and max_new_tokens. Ensure to update the grouping logic and tests
accordingly so stop sequences, top_p, etc. are applied per-request group.
| outputs = self._model.generate( | ||
| **inputs, # type: ignore[arg-type] | ||
| max_new_tokens=max_new_tokens, | ||
| temperature=temperature, | ||
| do_sample=temperature > 0, | ||
| pad_token_id=self._tokenizer.pad_token_id, | ||
| ) | ||
|
|
There was a problem hiding this comment.
self._model.generate can return a ModelOutput when config.return_dict_in_generate=True (e.g. some HF configs), but we immediately do generated_ids = outputs[:, input_lengths:] as if a tensor, so the call will fail (no __getitem__) for those models; can we either pass return_dict_in_generate=False to generate or pull outputs.sequences before slicing?
| outputs = self._model.generate( | |
| **inputs, # type: ignore[arg-type] | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| do_sample=temperature > 0, | |
| pad_token_id=self._tokenizer.pad_token_id, | |
| ) | |
| outputs = self._model.generate( | |
| **inputs, # type: ignore[arg-type] | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| do_sample=temperature > 0, | |
| pad_token_id=self._tokenizer.pad_token_id, | |
| return_dict_in_generate=False, | |
| ) |
Finding type: Logical Bugs
| async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None: | ||
| """Background task that batches and processes requests. | ||
|
|
||
| Groups requests by (temperature, max_tokens) during collection to preserve per-request | ||
| semantics. Collects until batch_wait_ms elapses OR any group reaches max_batch_size. | ||
|
|
||
| Tradeoff: Grouping at collection time is more efficient than collecting mixed batches | ||
| and splitting later, but less efficient than a full meta-queue with per-parameter timers. | ||
| Future: Could extend to a queue-of-queues where each parameter combo gets its own queue | ||
| and timer, allowing truly independent batching per parameter set. | ||
|
|
||
| Args: | ||
| weak_self: Weakref to the client - task exits when client is GC'd | ||
| """ | ||
| while weak_self() is not None: |
There was a problem hiding this comment.
[Logic] The background task is supposed to stop once the client is garbage collected, but _inference_loop closes over self because it is invoked as an instance method (asyncio.create_task(self._inference_loop(weak_self))). The coroutine holds a strong reference to self, so weak_self() can never become None and the loop never exits. Every client instance therefore leaks its inference task/model/tokenizer forever, preventing process shutdown and continuously pinning GPU/CPU memory. Please restructure the worker so it only keeps a weak reference (e.g. make _inference_loop a @staticmethod/stand‑alone coroutine that reacquires the client each iteration and breaks when the weak ref is dead, or explicitly cancel the task when the client is disposed) to honour the documented lifecycle.
Context for Agents
The background task is supposed to stop once the client is garbage collected, but `_inference_loop` closes over `self` because it is invoked as an instance method (`asyncio.create_task(self._inference_loop(weak_self))`). The coroutine holds a strong reference to `self`, so `weak_self()` can never become `None` and the loop never exits. Every client instance therefore leaks its inference task/model/tokenizer forever, preventing process shutdown and continuously pinning GPU/CPU memory. Please restructure the worker so it only keeps a weak reference (e.g. make `_inference_loop` a `@staticmethod`/stand‑alone coroutine that reacquires the client each iteration and breaks when the weak ref is dead, or explicitly cancel the task when the client is disposed) to honour the documented lifecycle.
File: src/ares/contrib/transformers_client.py
Line: 207
rsmith49
left a comment
There was a problem hiding this comment.
This is cool @joshgreaves!
A couple comments, only one I feel pretty strongly about is truncation since that seems to be really important for local models, given how quickly context length explodes in coding problems.
| await self._request_queue.put(ValueAndFuture(value=req, future=future)) | ||
| return await future | ||
|
|
||
| async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None: |
There was a problem hiding this comment.
IMO the performance gains by doing this batch aggregation loop (for people running LLMs locally) are not going to outweight the debugging cost for inference issues compared to just running all single inference. If we want to leave it as a cool feature that's fine. But it is going to be a nightmare to debug inference errors.
That being said, it is a cool implementation and fun to see
There was a problem hiding this comment.
This is a good point, it might make sense to have a separate client for this.
It makes a pretty massive difference speed-wise doing this, so at least for this implementation I'd like to keep it in.
| pass | ||
|
|
||
|
|
||
| def _detect_device() -> str: |
There was a problem hiding this comment.
The torch logic here is nice, noting here so I remember to move it to a shared contrib/utils or something that can be used by transformer-lens as well (and any other local inference things)
| input_texts, | ||
| return_tensors="pt", | ||
| padding=True, | ||
| truncation=True, |
There was a problem hiding this comment.
I've found truncation is actually the most complicated part of local inference - for instance, truncation=true will fully take out the <|im_end|> or </assistant> or whatever tags exist that indicate the LLM should respond to the user, instead of completing the user turn.
I think we should make truncation_strategy a first class implementation (probably in a contrib/utils module so other local inference methods can utilize it too). But at the minimum, we should include some kind of instance method or init param that specifies how truncation is done. For now something like the below would work
@dataclasses.dataclass(frozen=True, kw_only=True)
class TransformersLLMClient(llm_clients.LLMClient):
...
truncation_strategy: str | Callable[[str], str] = "auto"
def _generate_batch(
...
if callable(self.truncation_strategy):
input_texts = [self.truncation_strategy(text) for text in input_texts]
hf_truncation = False
elif self.truncation_strategy == "auto":
hf_truncation = True
inputs: transformers.BatchEncoding = self._tokenizer(
input_texts,
return_tensors="pt",
padding=True,
truncation=True,
).to(self._device)
Generated description
Below is a concise technical summary of the changes proposed in this PR:
graph LR call_("__call__"):::added inference_loop_("_inference_loop"):::added process_batch_("_process_batch"):::added generate_batch_("_generate_batch"):::added tokenizer_("_tokenizer"):::added model_("_model"):::added TRANSFORMERS_LIBRARY_("TRANSFORMERS_LIBRARY"):::added TORCH_LIBRARY_("TORCH_LIBRARY"):::added call_ -- "Queues LLMRequest and starts background batching task for responses." --> inference_loop_ inference_loop_ -- "Groups requests by params and invokes batch processor for execution." --> process_batch_ process_batch_ -- "Prepares conversations and calls synchronous generator for model outputs." --> generate_batch_ generate_batch_ -- "Formats messages and decodes tokens using cached tokenizer." --> tokenizer_ generate_batch_ -- "Invokes model.generate with inputs, temperature, and max_new_tokens." --> model_ model_ -- "Loads pretrained AutoModelForCausalLM weights via Transformers library." --> TRANSFORMERS_LIBRARY_ tokenizer_ -- "Fetches AutoTokenizer to encode inputs and decode outputs." --> TRANSFORMERS_LIBRARY_ model_ -- "Uses torch for dtype, device movement, and no-grad evaluation." --> TORCH_LIBRARY_ classDef added stroke:#15AA7A classDef removed stroke:#CD5270 classDef modified stroke:#EDAC4C linkStyle default stroke:#CBD5E1,font-size:13pxIntroduces a
TransformersLLMClientto thecontribmodule to enable local model inference with automatic request batching and device-specific optimizations. Updates the project's dependency groups and CI workflows to support the new client and ensure consistent environment synchronization.llamacppdependency group tollamacpp_clientand updates CI workflows to synchronize all dependency groups and extras.Modified files (4)
Latest Contributors(2)
TransformersLLMClientwith support for automatic batching, lazy-loaded models, and auto-detection of CUDA, MPS, or CPU devices.Modified files (3)
Latest Contributors(2)