Skip to content

Add a transformers client in contrib.#89

Open
joshgreaves wants to merge 10 commits intomainfrom
transformers-client
Open

Add a transformers client in contrib.#89
joshgreaves wants to merge 10 commits intomainfrom
transformers-client

Conversation

@joshgreaves
Copy link
Contributor

@joshgreaves joshgreaves commented Feb 11, 2026

Generated description

Below is a concise technical summary of the changes proposed in this PR:

graph LR
call_("__call__"):::added
inference_loop_("_inference_loop"):::added
process_batch_("_process_batch"):::added
generate_batch_("_generate_batch"):::added
tokenizer_("_tokenizer"):::added
model_("_model"):::added
TRANSFORMERS_LIBRARY_("TRANSFORMERS_LIBRARY"):::added
TORCH_LIBRARY_("TORCH_LIBRARY"):::added
call_ -- "Queues LLMRequest and starts background batching task for responses." --> inference_loop_
inference_loop_ -- "Groups requests by params and invokes batch processor for execution." --> process_batch_
process_batch_ -- "Prepares conversations and calls synchronous generator for model outputs." --> generate_batch_
generate_batch_ -- "Formats messages and decodes tokens using cached tokenizer." --> tokenizer_
generate_batch_ -- "Invokes model.generate with inputs, temperature, and max_new_tokens." --> model_
model_ -- "Loads pretrained AutoModelForCausalLM weights via Transformers library." --> TRANSFORMERS_LIBRARY_
tokenizer_ -- "Fetches AutoTokenizer to encode inputs and decode outputs." --> TRANSFORMERS_LIBRARY_
model_ -- "Uses torch for dtype, device movement, and no-grad evaluation." --> TORCH_LIBRARY_
classDef added stroke:#15AA7A
classDef removed stroke:#CD5270
classDef modified stroke:#EDAC4C
linkStyle default stroke:#CBD5E1,font-size:13px
Loading

Introduces a TransformersLLMClient to the contrib module to enable local model inference with automatic request batching and device-specific optimizations. Updates the project's dependency groups and CI workflows to support the new client and ensure consistent environment synchronization.

TopicDetails
Dependency Updates Renames the llamacpp dependency group to llamacpp_client and updates CI workflows to synchronize all dependency groups and extras.
Modified files (4)
  • .github/workflows/ruff.yml
  • .github/workflows/unit-tests.yml
  • pyproject.toml
  • src/ares/contrib/llama_cpp.py
Latest Contributors(2)
UserCommitDate
joshua.greaves@gmail.comBump-ARES-to-0.0.2-72January 29, 2026
ryan@withmartian.comAdd-Tinker-Example-58January 29, 2026
Transformers Client Implements TransformersLLMClient with support for automatic batching, lazy-loaded models, and auto-detection of CUDA, MPS, or CPU devices.
Modified files (3)
  • pyproject.toml
  • src/ares/contrib/transformers_client.py
  • src/ares/contrib/transformers_client_test.py
Latest Contributors(2)
UserCommitDate
joshua.greaves@gmail.comBump-ARES-to-0.0.2-72January 29, 2026
ryan@withmartian.comAdd-Tinker-Example-58January 29, 2026
This pull request is reviewed by Baz. Review like a pro on (Baz).

@joshgreaves joshgreaves requested a review from rsmith49 February 11, 2026 23:29
Comment on lines 201 to 207
def _ensure_inference_task_started(self) -> None:
"""Lazy-start the background inference task if not already running."""
if self._inference_task is None or self._inference_task.done():
# Create weakref to self so task can detect when client is GC'd
weak_self = weakref.ref(self)
self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
_LOGGER.info("TransformersLLMClient started background inference task")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per CLAUDE.md async logging guidance, TransformersLLMClient started background inference task never logs id(self) or another client identifier, so if multiple clients run concurrently you cannot trace which instance started its background task; can we include the client id (e.g., id(self)) in this log?

Suggested change
def _ensure_inference_task_started(self) -> None:
"""Lazy-start the background inference task if not already running."""
if self._inference_task is None or self._inference_task.done():
# Create weakref to self so task can detect when client is GC'd
weak_self = weakref.ref(self)
self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
_LOGGER.info("TransformersLLMClient started background inference task")
def _ensure_inference_task_started(self) -> None:
"""Lazy-start the background inference task if not already running."""
if self._inference_task is None or self._inference_task.done():
# Create weakref to self so task can detect when client is GC'd
weak_self = weakref.ref(self)
self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
_LOGGER.info("TransformersLLMClient(id=%s) started background inference task", id(self))

Finding type: AI Coding Guidelines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary, IMO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The log currently just says TransformersLLMClient started background inference task, so when multiple clients are running there’s no way to tell which instance triggered it; including id(self) (or similar) gives us a lightweight identifier for debugging without changing functionality.

Comment on lines 266 to 271
# Process batch
_LOGGER.debug("Processing batch of %d request(s)", len(batch))
await self._process_batch(batch)

except Exception as e:
_LOGGER.exception("Error in inference loop: %s", e)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per CLAUDE.md async logging guidance, the Error in inference loop exception log never includes id(self) (or similar client identifier), so failed batches cannot be traced back to a particular client instance; can we include the client id in this log?

Suggested change
# Process batch
_LOGGER.debug("Processing batch of %d request(s)", len(batch))
await self._process_batch(batch)
except Exception as e:
_LOGGER.exception("Error in inference loop: %s", e)
# Process batch
_LOGGER.debug("Processing batch of %d request(s)", len(batch))
await self._process_batch(batch)
except Exception as e:
_LOGGER.exception("Error in inference loop for client %s: %s", hex(id(self)), e)

Finding type: AI Coding Guidelines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary IMO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hear you, but tracing the background task failures to a specific client instance is hard without an identifier—the exception already includes the client’s ID now, so we can correlate Error in inference loop for client … messages with the right request queue.

Comment on lines +314 to +316

except Exception as e:
_LOGGER.exception("Error processing batch: %s", e)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per CLAUDE.md async logging guidance, the Error processing batch exception log never includes id(self) (or similar client identifier), so downstream errors cannot be tied back to a particular client instance when multiple are running; can we include the client id here as well?

Suggested change
except Exception as e:
_LOGGER.exception("Error processing batch: %s", e)
except Exception as e:
_LOGGER.exception("Error processing batch in client id=%s: %s", id(self), e)

Finding type: AI Coding Guidelines

Comment on lines 422 to 430
# Create minimal tokenizer
minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained(
"gpt2", # Uses cached tokenizer vocab
model_max_length=32,
)
minimal_tokenizer.pad_token = minimal_tokenizer.eos_token
minimal_tokenizer.padding_side = "left"
# Add minimal chat template
minimal_tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{% endfor %}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important

[Testing] The “no-download” integration test still calls transformers.GPT2Tokenizer.from_pretrained("gpt2") (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny tokenizers.Tokenizer/PreTrainedTokenizerFast with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.

Context for Agents
The “no-download” integration test still calls `transformers.GPT2Tokenizer.from_pretrained("gpt2")` (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny `tokenizers.Tokenizer`/`PreTrainedTokenizerFast` with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.

File: src/ares/contrib/transformers_client_test.py
Line: 430

Comment on lines +154 to +167
@functools.cached_property
def _inference_task(self) -> asyncio.Task[None]:
"""Lazy-initialized background inference task.

The task automatically exits when the client is garbage collected via weakref.
"""
weak_self = weakref.ref(self)
task = asyncio.create_task(self._inference_loop(weak_self))
_LOGGER.info("TransformersLLMClient started background inference task")
return task

@functools.cached_property
def _device(self) -> str:
"""Resolved device."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformersLLMClient logs TransformersLLMClient started background inference task without id(self) even though this async background task can run concurrently for multiple clients, so tracing is ambiguous; CLAUDE.md requires async logging to include object identifiers—can we log the client id in these statements?

Finding type: AI Coding Guidelines

Prompt for AI Agents:

In src/ares/contrib/transformers_client.py around lines 154 to 167, the _inference_task
cached_property logs "TransformersLLMClient started background inference task" without
identifying which client started the task. Update the logging call in that method to
include a unique identifier for the client (for example id(self)) and optionally the
model_name so concurrent clients are distinguishable. Make the log message something
like: "TransformersLLMClient started background inference task (id=%s, model=%s)" and
pass id(self) and self.model_name as parameters to the logger.

Fix in Cursor

Comment on lines +286 to +310
@pytest.mark.asyncio
async def test_integration_with_minimal_model():
"""Integration test with a minimal GPT2 model.

Creates a tiny GPT2 model from scratch for testing the full pipeline.
Note: Downloads GPT2 tokenizer vocab on first run (~500KB, cached after).
"""
# Create minimal GPT2 config - vocab_size must match GPT2Tokenizer (50257)
config = transformers.GPT2Config(
vocab_size=50257, # Must match GPT2Tokenizer vocab
n_positions=32,
n_ctx=32,
n_embd=32,
n_layer=2,
n_head=4,
)

minimal_model = transformers.GPT2LMHeadModel(config)
minimal_model.eval()

# GPT2 tokenizer is lightweight and cached after first download
minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained(
"gpt2",
model_max_length=32,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_integration_with_minimal_model calls transformers.GPT2Tokenizer.from_pretrained("gpt2"), which performs a real network download inside src/ares/contrib/transformers_client_test.py; CLAUDE.md mandates that unit tests under src/ must mock external services and real integration tests belong under integration_tests/, so this test should either mock the tokenizer download or move to the integration test suite (CLAUDE.md).

Finding type: AI Coding Guidelines

Prompt for AI Agents:

In src/ares/contrib/transformers_client_test.py around lines 286 to 310, the
test_integration_with_minimal_model function calls
transformers.GPT2Tokenizer.from_pretrained("gpt2") which performs a real network
download (not allowed for unit tests under src/). Refactor by mocking that call: replace
the direct from_pretrained invocation with a context manager that patches
transformers.GPT2Tokenizer.from_pretrained to return the locally constructed
minimal_tokenizer (use mock.patch.object(transformers.GPT2Tokenizer, "from_pretrained",
return_value=minimal_tokenizer)) so no network I/O occurs; alternatively, if you intend
this to be a true integration test, move the whole test function into integration_tests/
and update imports accordingly.

Fix in Cursor

Comment on lines +279 to +283
self._generate_batch,
chat_conversations,
temperature=temperature,
max_new_tokens=max_new_tokens,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_generate_batch is only called with chat_conversations, temperature, and max_new_tokens, so any other LLMRequest.to_chat_completion_kwargs() fields (e.g. top_p, stop/stop_sequences, tools, etc.) never reach _model.generate; those request options now have no effect when using this client

Finding type: Logical Bugs

Prompt for AI Agents:

In src/ares/contrib/transformers_client.py around lines 279 to 283, the call to
self._generate_batch only forwards (chat_conversations, temperature, max_new_tokens) so
other per-request options returned by to_chat_completion_kwargs (e.g., top_p,
stop/stop_sequences, any decoding options or tool hints) are lost. Refactor so that when
grouping requests you capture and propagate the full generation kwargs for the group
(use the first request's kwargs as the group's canonical kwargs, and only group requests
whose kwargs are identical or compatible), then change _process_batch/_generate_batch
signatures to accept a dict of generation kwargs and pass those through into the
tokenizer/model calls (model.generate and any tokenizer settings) instead of only
passing temperature and max_new_tokens. Ensure to update the grouping logic and tests
accordingly so stop sequences, top_p, etc. are applied per-request group.

Fix in Cursor

Comment on lines +330 to +337
outputs = self._model.generate(
**inputs, # type: ignore[arg-type]
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=self._tokenizer.pad_token_id,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._model.generate can return a ModelOutput when config.return_dict_in_generate=True (e.g. some HF configs), but we immediately do generated_ids = outputs[:, input_lengths:] as if a tensor, so the call will fail (no __getitem__) for those models; can we either pass return_dict_in_generate=False to generate or pull outputs.sequences before slicing?

Suggested change
outputs = self._model.generate(
**inputs, # type: ignore[arg-type]
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=self._tokenizer.pad_token_id,
)
outputs = self._model.generate(
**inputs, # type: ignore[arg-type]
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=self._tokenizer.pad_token_id,
return_dict_in_generate=False,
)

Finding type: Logical Bugs

Comment on lines +193 to +207
async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None:
"""Background task that batches and processes requests.

Groups requests by (temperature, max_tokens) during collection to preserve per-request
semantics. Collects until batch_wait_ms elapses OR any group reaches max_batch_size.

Tradeoff: Grouping at collection time is more efficient than collecting mixed batches
and splitting later, but less efficient than a full meta-queue with per-parameter timers.
Future: Could extend to a queue-of-queues where each parameter combo gets its own queue
and timer, allowing truly independent batching per parameter set.

Args:
weak_self: Weakref to the client - task exits when client is GC'd
"""
while weak_self() is not None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important

[Logic] The background task is supposed to stop once the client is garbage collected, but _inference_loop closes over self because it is invoked as an instance method (asyncio.create_task(self._inference_loop(weak_self))). The coroutine holds a strong reference to self, so weak_self() can never become None and the loop never exits. Every client instance therefore leaks its inference task/model/tokenizer forever, preventing process shutdown and continuously pinning GPU/CPU memory. Please restructure the worker so it only keeps a weak reference (e.g. make _inference_loop a @staticmethod/stand‑alone coroutine that reacquires the client each iteration and breaks when the weak ref is dead, or explicitly cancel the task when the client is disposed) to honour the documented lifecycle.

Context for Agents
The background task is supposed to stop once the client is garbage collected, but `_inference_loop` closes over `self` because it is invoked as an instance method (`asyncio.create_task(self._inference_loop(weak_self))`). The coroutine holds a strong reference to `self`, so `weak_self()` can never become `None` and the loop never exits. Every client instance therefore leaks its inference task/model/tokenizer forever, preventing process shutdown and continuously pinning GPU/CPU memory. Please restructure the worker so it only keeps a weak reference (e.g. make `_inference_loop` a `@staticmethod`/stand‑alone coroutine that reacquires the client each iteration and breaks when the weak ref is dead, or explicitly cancel the task when the client is disposed) to honour the documented lifecycle.

File: src/ares/contrib/transformers_client.py
Line: 207

Copy link
Contributor

@rsmith49 rsmith49 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool @joshgreaves!

A couple comments, only one I feel pretty strongly about is truncation since that seems to be really important for local models, given how quickly context length explodes in coding problems.

await self._request_queue.put(ValueAndFuture(value=req, future=future))
return await future

async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the performance gains by doing this batch aggregation loop (for people running LLMs locally) are not going to outweight the debugging cost for inference issues compared to just running all single inference. If we want to leave it as a cool feature that's fine. But it is going to be a nightmare to debug inference errors.

That being said, it is a cool implementation and fun to see

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, it might make sense to have a separate client for this.
It makes a pretty massive difference speed-wise doing this, so at least for this implementation I'd like to keep it in.

pass


def _detect_device() -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The torch logic here is nice, noting here so I remember to move it to a shared contrib/utils or something that can be used by transformer-lens as well (and any other local inference things)

input_texts,
return_tensors="pt",
padding=True,
truncation=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've found truncation is actually the most complicated part of local inference - for instance, truncation=true will fully take out the <|im_end|> or </assistant> or whatever tags exist that indicate the LLM should respond to the user, instead of completing the user turn.

I think we should make truncation_strategy a first class implementation (probably in a contrib/utils module so other local inference methods can utilize it too). But at the minimum, we should include some kind of instance method or init param that specifies how truncation is done. For now something like the below would work

@dataclasses.dataclass(frozen=True, kw_only=True)
class TransformersLLMClient(llm_clients.LLMClient):
    ...
    truncation_strategy: str | Callable[[str], str] = "auto"

    def _generate_batch(
        ...
        if callable(self.truncation_strategy):
            input_texts = [self.truncation_strategy(text) for text in input_texts]
            hf_truncation = False
        elif self.truncation_strategy == "auto":
            hf_truncation = True

        inputs: transformers.BatchEncoding = self._tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self._device)    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants