Async support for inference engine #24
Changes from all commits: 7cddae4, ec91fab, a653222, c9dee19, 1725038, 9f96381
First changed file (the model runner):

```diff
@@ -1,5 +1,7 @@
 import torch
 import pickle
+import asyncio
+from concurrent.futures import ThreadPoolExecutor

 import torch.distributed as dist
@@ -72,6 +74,9 @@ def exit(self):
         self.shm.unlink()
         if not self.enforce_eager:
             del self.graphs, self.graph_pool
+        # Clean up executor if it exists
+        if hasattr(self, '_executor'):
+            self._executor.shutdown(wait=True)
         torch.cuda.synchronize()
         dist.destroy_process_group()
@@ -110,6 +115,16 @@ def call(self, method_name, *args):
         method = getattr(self, method_name, None)
         return method(*args)

+    async def call_async(self, method_name, *args):
+        """Async version of call that runs in a thread pool executor."""
+        loop = asyncio.get_event_loop()
+        # Use default executor or create one if needed
+        executor = getattr(self, '_executor', None)
+        if executor is None:
+            executor = ThreadPoolExecutor(max_workers=1)
+            self._executor = executor
+        return await loop.run_in_executor(executor, self.call, method_name, *args)
```
Comment on lines +118 to +126
Use `asyncio.get_running_loop()` instead of `asyncio.get_event_loop()`.

Two issues: `asyncio.get_event_loop()` is discouraged inside coroutines (its fallback of creating a loop when none is running is deprecated in recent Python versions), so `asyncio.get_running_loop()` should be used here; and the lazy creation of `_executor` is not thread-safe, so concurrent callers could each create their own executor.

🔧 Proposed fix

```diff
 async def call_async(self, method_name, *args):
     """Async version of call that runs in a thread pool executor."""
-    loop = asyncio.get_event_loop()
-    # Use default executor or create one if needed
-    executor = getattr(self, '_executor', None)
-    if executor is None:
-        executor = ThreadPoolExecutor(max_workers=1)
-        self._executor = executor
+    loop = asyncio.get_running_loop()
+    # Lazily create executor on first use
+    if not hasattr(self, '_executor') or self._executor is None:
+        self._executor = ThreadPoolExecutor(max_workers=1)
-    return await loop.run_in_executor(executor, self.call, method_name, *args)
+    return await loop.run_in_executor(self._executor, self.call, method_name, *args)
```

Note: the race condition is unlikely in practice, since the method is normally awaited from a single event loop.
```diff

     def load_model(self, config: Config):
         """Instantiate the underlying model; override to customize."""
         return AutoModelForDiffusionLM.from_config(config)
```
Second changed file (the LLM engine):

```diff
@@ -1,4 +1,6 @@
 import atexit
+import asyncio
+from concurrent.futures import ThreadPoolExecutor

 import torch.multiprocessing as mp
@@ -63,6 +65,12 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
         # Return seq_id so caller can build a stable mapping
         return seq.seq_id

+    async def add_request_async(self, prompt: str | list[int], sampling_params: SamplingParams):
+        """Async version of add_request (currently synchronous but provided for API consistency)."""
+        # Tokenization and sequence creation are fast, but we make it async for consistency
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.add_request, prompt, sampling_params)

     def step(self):
         seqs, is_prefill = self.scheduler.schedule()
         sample_output = self.model_runner.call("run", seqs, is_prefill)
@@ -73,9 +81,33 @@ def step(self):
         deltas = []
         return outputs, num_tokens, is_prefill, n_diff_steps, deltas

+    async def step_async(self):
+        """Async version of step that runs model inference in a thread pool."""
+        loop = asyncio.get_event_loop()
+        executor = getattr(self, '_step_executor', None)
+        if executor is None:
+            executor = ThreadPoolExecutor(max_workers=1)
+            self._step_executor = executor
+
+        def _step():
+            seqs, is_prefill = self.scheduler.schedule()
+            sample_output = self.model_runner.call("run", seqs, is_prefill)
+            n_diff_steps = self.scheduler.postprocess(seqs, sample_output)
+            outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
+            num_tokens = sum(seq.num_tokens for seq in seqs) if is_prefill else sum(seq.new_tokens for seq in seqs)
+            deltas = []
+            return outputs, num_tokens, is_prefill, n_diff_steps, deltas
+
+        return await loop.run_in_executor(executor, _step)
```
Comment on lines +84 to +101
Thread pool executor is not cleaned up in `exit()`.

The `_step_executor` created lazily in `step_async` is never shut down when the engine exits.

🔧 Proposed fix - add cleanup in exit()

```diff
     def exit(self):
         if getattr(self, "_exited", False):
             return
         self._exited = True
+        # Shutdown step executor if created
+        if hasattr(self, '_step_executor') and self._step_executor is not None:
+            self._step_executor.shutdown(wait=False)
         if hasattr(self, "model_runner") and self.model_runner is not None:
```
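As a self-contained illustration of the teardown the review asks for, here is a minimal sketch pairing a lazily created executor with an idempotent `exit()`; the class and method names are illustrative, not the engine's actual implementation:

```python
from concurrent.futures import ThreadPoolExecutor


class StepRunner:
    """Toy engine owning a lazily created single-worker executor."""

    def __init__(self):
        self._step_executor = None
        self._exited = False

    def _ensure_executor(self):
        # Create the worker thread only when the async path is actually used.
        if self._step_executor is None:
            self._step_executor = ThreadPoolExecutor(max_workers=1)
        return self._step_executor

    def exit(self):
        # Idempotent teardown, mirroring the guard in the proposed fix above.
        if self._exited:
            return
        self._exited = True
        if self._step_executor is not None:
            # wait=False: do not block shutdown on an in-flight step.
            self._step_executor.shutdown(wait=False)
```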
```diff

     def is_finished(self):
         return self.scheduler.is_finished()

+    async def is_finished_async(self):
+        """Async version of is_finished (currently synchronous but provided for API consistency)."""
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.is_finished)

     def generate(
         self,
         prompts: list[str] | list[list[int]],
@@ -129,3 +161,60 @@ def generate(
         if use_tqdm:
             pbar.close()
         return outputs
+
+    async def generate_async(
+        self,
+        prompts: list[str] | list[list[int]],
+        sampling_params: SamplingParams | list[SamplingParams],
+        use_tqdm: bool = True,
+    ) -> list[str]:
+        """Async version of generate that allows concurrent request handling."""
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
+        if not isinstance(sampling_params, list):
+            sampling_params = [sampling_params] * len(prompts)
+        # Map internal seq_id -> input index to keep output order stable
+        seqid_to_idx = {}
+        # Add requests synchronously to avoid race conditions with scheduler
+        # The actual async benefit comes from the inference steps, not request addition
+        for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            sid = self.add_request(prompt, sp)
+            seqid_to_idx[sid] = idx
+        outputs = [None] * len(prompts)
+        prefill_throughput = decode_throughput = 0.
+        n_steps = 0
+        n_diff_steps = [-1] * len(prompts)
+        while not await self.is_finished_async():
+            t = perf_counter()
+            n_steps += 1
+            output, num_tokens, is_prefill, cur_n_diff_steps, _ = await self.step_async()
+            if use_tqdm:
+                if is_prefill:
+                    prefill_throughput = num_tokens / (perf_counter() - t)
+                else:
+                    decode_throughput = num_tokens / (perf_counter() - t)
+                pbar.set_postfix({
+                    "Prefill": f"{int(prefill_throughput)}tok/s",
+                    "Decode": f"{int(decode_throughput)}tok/s",
+                })
+            if cur_n_diff_steps:
+                for seq_id, n_step in cur_n_diff_steps.items():
+                    if seq_id in seqid_to_idx and n_step >= 0:
+                        n_diff_steps[seqid_to_idx[seq_id]] = n_step
+            for seq_id, token_ids in output:
+                if seq_id in seqid_to_idx:
+                    outputs[seqid_to_idx[seq_id]] = token_ids
+                    if use_tqdm:
+                        pbar.update(1)
+            await asyncio.sleep(0)
+
+        print(f"Finished in {n_steps} steps, prefill throughput: {prefill_throughput:.2f} tok/s, decode throughput: {decode_throughput:.2f} tok/s")
+        assert all(toks is not None for toks in outputs), "Some sequences did not produce outputs"
+        outputs = [{
+            "text": self.tokenizer.decode(token_ids).split(self.tokenizer.eos_token)[0],
+            "token_ids": token_ids[:token_ids.index(self.config.eos)] if self.config.eos in token_ids else token_ids,
+            "n_diff_steps": n_diff_step,
+        } for token_ids, n_diff_step in zip(outputs, n_diff_steps)]
+        if use_tqdm:
+            pbar.close()
+        return outputs
```
Remove unused variables `pending` and `conn_to_idx`.

These variables are assigned but never used in the async version of `generate_async`.

🔧 Proposed fix

```diff
 # Collect results asynchronously
 collected = {}
-pending = set(slices.keys())
-conn_to_idx = {self.conns[i]: i for i in slices.keys()}
```

🧰 Tools
🪛 Ruff (0.14.10)

423-423: Local variable `pending` is assigned to but never used. Remove assignment to unused variable `pending` (F841)

424-424: Local variable `conn_to_idx` is assigned to but never used. Remove assignment to unused variable `conn_to_idx` (F841)
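Finally, a hedged usage sketch of how the new async API could be driven from an event loop. Engine construction is omitted and the prompt/parameter values are placeholders; only `generate_async()` and the keys of its returned dicts come from this PR:

```python
import asyncio


async def serve(llm, prompts, sampling_params):
    # Run generation as a task so other coroutines (e.g. request handlers)
    # can keep running on the same event loop while each inference step is
    # offloaded to the engine's thread pool via step_async().
    task = asyncio.create_task(llm.generate_async(prompts, sampling_params, use_tqdm=False))
    results = await task
    for out in results:
        print(out["text"], out["n_diff_steps"])

# Example entry point (llm, prompts and sampling_params are assumed to exist):
# asyncio.run(serve(llm, ["Hello world"], sampling_params))
```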