175 changes: 175 additions & 0 deletions diffulex/engine/dp_worker.py
@@ -4,11 +4,13 @@
import atexit
import traceback
import faulthandler
import asyncio

import multiprocessing as mp

from typing import Any
from multiprocessing.connection import wait as mp_wait
from concurrent.futures import ThreadPoolExecutor

from diffulex.config import Config
from diffulex.engine.tp_worker import DiffulexTPWorker
@@ -140,6 +142,7 @@ def __init__(self, model, **kwargs):
self._gid_counter = 0
self._gid_map = {} # (replica, local_id) -> global_id
self._rev_gid_map = {} # global_id -> (replica, local_id)
self._executor = ThreadPoolExecutor(max_workers=self.dp_size)
atexit.register(self.exit)

def _ask(self, replica: int, cmd: str, *args):
@@ -159,7 +162,15 @@ def _ask(self, replica: int, cmd: str, *args):
return payload
raise RuntimeError(f"DP child #{replica} error: {payload}")

async def _ask_async(self, replica: int, cmd: str, *args):
"""Async version of _ask that runs in a thread pool."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self._executor, self._ask, replica, cmd, *args)

def exit(self):
# Shutdown executor
if hasattr(self, '_executor'):
self._executor.shutdown(wait=True)
for i, p in enumerate(self.ps):
if p.is_alive():
try:
@@ -178,6 +189,17 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
self._rev_gid_map[gid] = (target, local_id)
return gid

async def add_request_async(self, prompt: str | list[int], sampling_params: SamplingParams):
"""Async version of add_request."""
target = self._rr
self._rr = (self._rr + 1) % self.dp_size
local_id = await self._ask_async(target, "add_request", prompt, sampling_params)
gid = self._gid_counter
self._gid_counter += 1
self._gid_map[(target, local_id)] = gid
self._rev_gid_map[gid] = (target, local_id)
return gid

def step(self):
all_outputs = []
total_tokens = 0
@@ -206,9 +228,54 @@ def step(self):
merged_deltas.append((gid, toks, fin))
return all_outputs, total_tokens, any_prefill, merged_diff_steps, merged_deltas

async def step_async(self):
"""Async version of step that runs all DP replicas concurrently."""
all_outputs = []
total_tokens = 0
any_prefill = False
merged_diff_steps = {}
merged_deltas = []

# Check all replicas in parallel
tasks = [self._ask_async(i, "is_finished") for i in range(self.dp_size)]
done_flags = await asyncio.gather(*tasks)

# Step all non-finished replicas in parallel
step_tasks = []
for i, done in enumerate(done_flags):
if not done:
step_tasks.append((i, self._ask_async(i, "step")))

if step_tasks:
step_results = await asyncio.gather(*[task for _, task in step_tasks])
for (i, _), (outputs, num_tokens, is_prefill, n_diff_steps, deltas) in zip(step_tasks, step_results):
if outputs:
# remap local seq_ids to global ids
for sid, toks in outputs:
gid = self._gid_map.get((i, sid), None)
if gid is not None:
all_outputs.append((gid, toks))
total_tokens += num_tokens
any_prefill = any_prefill or is_prefill
if n_diff_steps:
merged_diff_steps.update(n_diff_steps)
if deltas:
for sid, toks, fin in deltas:
gid = self._gid_map.get((i, sid), None)
if gid is not None:
merged_deltas.append((gid, toks, fin))

return all_outputs, total_tokens, any_prefill, merged_diff_steps, merged_deltas

def is_finished(self):
return all(self._ask(i, "is_finished") for i in range(self.dp_size))

async def is_finished_async(self):
"""Async version of is_finished that checks all replicas in parallel."""
tasks = [self._ask_async(i, "is_finished") for i in range(self.dp_size)]
results = await asyncio.gather(*tasks)
return all(results)

def generate(self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True):
"""Load-balanced generate with random shuffling and stable order restoration.
- Randomly shuffle inputs to balance load across DP replicas.
@@ -292,3 +359,111 @@ def generate(self, prompts: list[str] | list[list[int]], sampling_params: Sampli
restored[orig_idx] = out
assert all(x is not None for x in restored), "Mismatch in outputs after DP collection"
return restored

async def generate_async(self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True):
"""Async version of generate that allows concurrent request handling."""
import random
n = len(prompts)
idxs = list(range(n))
random.shuffle(idxs)
shuffled_prompts = [prompts[i] for i in idxs]
# Align sampling params with shuffled prompts
if isinstance(sampling_params, list):
if len(sampling_params) == n:
shuffled_sps = [sampling_params[i] for i in idxs]
elif len(sampling_params) == self.dp_size:
# per-shard SP; keep as-is and broadcast per-shard below
shuffled_sps = sampling_params
else:
shuffled_sps = [sampling_params[0]] * n
else:
shuffled_sps = sampling_params

# Even partition of shuffled inputs
base = n // self.dp_size
rem = n % self.dp_size
slices = {}
start = 0
for i in range(self.dp_size):
add = base + (1 if i < rem else 0)
end = start + add
if start < end:
slices[i] = (start, end)
start = end

# Send generate requests to all replicas concurrently
async def send_generate(replica_idx: int, start_idx: int, end_idx: int):
if isinstance(shuffled_sps, list):
if len(shuffled_sps) == n:
sp_arg = shuffled_sps[start_idx:end_idx]
elif len(shuffled_sps) == self.dp_size:
sp_arg = shuffled_sps[replica_idx]
else:
sp_arg = shuffled_sps[0]
else:
sp_arg = shuffled_sps
conn = self.conns[replica_idx]
# Send in executor to avoid blocking
loop = asyncio.get_event_loop()
await loop.run_in_executor(
self._executor,
conn.send,
("generate", shuffled_prompts[start_idx:end_idx], sp_arg, use_tqdm)
)
return replica_idx

# Send all requests concurrently
send_tasks = [
send_generate(i, s, e) for i, (s, e) in slices.items()
]
await asyncio.gather(*send_tasks)

# Collect results asynchronously
collected = {}
pending = set(slices.keys())
conn_to_idx = {self.conns[i]: i for i in slices.keys()}
Comment on lines +421 to +424

⚠️ Potential issue | 🟡 Minor

Remove unused variables pending and conn_to_idx.

These variables are assigned in generate_async but never referenced.

🔧 Proposed fix
         # Collect results asynchronously
         collected = {}
-        pending = set(slices.keys())
-        conn_to_idx = {self.conns[i]: i for i in slices.keys()}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Before:
        # Collect results asynchronously
        collected = {}
        pending = set(slices.keys())
        conn_to_idx = {self.conns[i]: i for i in slices.keys()}
After:
        # Collect results asynchronously
        collected = {}
🧰 Tools
🪛 Ruff (0.14.10)

423-423: Local variable pending is assigned to but never used

Remove assignment to unused variable pending

(F841)


424-424: Local variable conn_to_idx is assigned to but never used

Remove assignment to unused variable conn_to_idx

(F841)

🤖 Prompt for AI Agents
In @diffulex/engine/dp_worker.py around lines 421 - 424, in generate_async, remove the unused local variables pending and conn_to_idx (they are assigned from slices.keys() and self.conns but never referenced); delete the lines that create them so that only the used variables collected and slices remain, and run tests/lint to confirm no other references to pending or conn_to_idx exist.


async def wait_for_result(replica_idx: int):
conn = self.conns[replica_idx]
loop = asyncio.get_event_loop()
try:
# Poll for data availability in executor, then recv
def check_and_recv():
# Poll is non-blocking, but we run it in executor to be safe
if conn.poll():
return conn.recv()
return None

# Poll until data is available
while True:
result = await loop.run_in_executor(self._executor, check_and_recv)
if result is not None:
tag, payload = result
break
await asyncio.sleep(0.001) # Small sleep to yield control
except EOFError:
p = self.ps[replica_idx]
exitcode = p.exitcode
raise RuntimeError(
f"DP child #{replica_idx} terminated unexpectedly during generate (exitcode={exitcode}). "
f"Enable envs: PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 for more info."
)
if tag == "ok":
collected[replica_idx] = payload
else:
raise RuntimeError(f"DP child #{replica_idx} error: {payload}")

# Wait for all results concurrently
await asyncio.gather(*[wait_for_result(i) for i in slices.keys()])

# Restore to original order
restored = [None] * n
for i, (s, e) in slices.items():
outs = collected.get(i, [])
# outs are aligned with shuffled order s:e
for local_k, out in enumerate(outs):
global_pos = s + local_k
orig_idx = idxs[global_pos]
restored[orig_idx] = out
assert all(x is not None for x in restored), "Mismatch in outputs after DP collection"
return restored
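
For orientation (this block is an editorial sketch, not part of the PR), here is one way the new async entry points in this file might be driven together. Only generate_async, add_request_async, step_async, and is_finished_async come from the diff above; the drive helper, the worker and sampling_params objects, and the printing are illustrative assumptions.

import asyncio


async def drive(worker, prompts, sampling_params):
    # Batch path: generate_async shards prompts across DP replicas and restores the original order.
    results = await worker.generate_async(prompts, sampling_params, use_tqdm=False)
    for r in results:
        print(r)
    # Manual path: enqueue requests, then step the engine until every replica reports finished.
    for p in prompts:
        await worker.add_request_async(p, sampling_params)
    while not await worker.is_finished_async():
        finished, _, _, _, _ = await worker.step_async()
        for gid, token_ids in finished:
            print(f"request {gid} finished with {len(token_ids)} tokens")

# asyncio.run(drive(worker, prompts, sampling_params))  # constructing `worker` depends on engine
# configuration that is not shown in this diff.
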
15 changes: 15 additions & 0 deletions diffulex/engine/model_runner.py
@@ -1,5 +1,7 @@
import torch
import pickle
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch.distributed as dist

@@ -72,6 +74,9 @@ def exit(self):
self.shm.unlink()
if not self.enforce_eager:
del self.graphs, self.graph_pool
# Clean up executor if it exists
if hasattr(self, '_executor'):
self._executor.shutdown(wait=True)
torch.cuda.synchronize()
dist.destroy_process_group()

@@ -110,6 +115,16 @@ def call(self, method_name, *args):
method = getattr(self, method_name, None)
return method(*args)

async def call_async(self, method_name, *args):
"""Async version of call that runs in a thread pool executor."""
loop = asyncio.get_event_loop()
# Use default executor or create one if needed
executor = getattr(self, '_executor', None)
if executor is None:
executor = ThreadPoolExecutor(max_workers=1)
self._executor = executor
return await loop.run_in_executor(executor, self.call, method_name, *args)
Comment on lines +118 to +126

⚠️ Potential issue | 🟡 Minor

Use asyncio.get_running_loop() and fix potential race condition in executor creation.

Two issues:

  1. asyncio.get_event_loop() is deprecated since Python 3.10 and may emit a DeprecationWarning. Use asyncio.get_running_loop(), which is the recommended approach when called from within a coroutine.

  2. The executor creation is not thread-safe. If multiple coroutines call call_async concurrently before _executor is set, multiple executors could be created.

🔧 Proposed fix
     async def call_async(self, method_name, *args):
         """Async version of call that runs in a thread pool executor."""
-        loop = asyncio.get_event_loop()
-        # Use default executor or create one if needed
-        executor = getattr(self, '_executor', None)
-        if executor is None:
-            executor = ThreadPoolExecutor(max_workers=1)
-            self._executor = executor
+        loop = asyncio.get_running_loop()
+        # Lazily create executor on first use
+        if not hasattr(self, '_executor') or self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=1)
-        return await loop.run_in_executor(executor, self.call, method_name, *args)
+        return await loop.run_in_executor(self._executor, self.call, method_name, *args)

Note: The race condition is unlikely in practice since ModelRunnerBase instances are typically used from a single async context, but worth addressing for correctness.

🤖 Prompt for AI Agents
In @diffulex/engine/model_runner.py around lines 118 - 126, replace asyncio.get_event_loop() with asyncio.get_running_loop() in call_async, and serialize executor creation with an asyncio.Lock: create the lock (e.g., self._executor_lock) in ModelRunnerBase.__init__, then in call_async use a double-checked pattern: if getattr(self, '_executor', None) is None, acquire the lock (async with self._executor_lock:), re-check _executor, and only then create and assign ThreadPoolExecutor(max_workers=1) to self._executor; finally call loop.run_in_executor(self._executor, self.call, method_name, *args). This fixes the deprecated API usage and prevents concurrent creation of multiple executors.
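
As an illustration of the pattern this prompt describes (an editorial sketch, not code from the PR or from the proposed fix above): it assumes a Python 3.10+ runtime and an _executor_lock attribute that does not exist in the PR, and it shows only the pieces relevant to executor creation.

import asyncio
from concurrent.futures import ThreadPoolExecutor


class ModelRunnerBase:
    def __init__(self):
        self._executor = None
        self._executor_lock = asyncio.Lock()  # guards lazy executor creation

    def call(self, method_name, *args):
        ...  # dispatch to the named method, as in the existing synchronous path

    async def call_async(self, method_name, *args):
        loop = asyncio.get_running_loop()
        if self._executor is None:
            async with self._executor_lock:
                # Re-check after acquiring the lock: another coroutine may have
                # created the executor while this one was waiting.
                if self._executor is None:
                    self._executor = ThreadPoolExecutor(max_workers=1)
        return await loop.run_in_executor(self._executor, self.call, method_name, *args)
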


def load_model(self, config: Config):
"""Instantiate the underlying model; override to customize."""
return AutoModelForDiffusionLM.from_config(config)
89 changes: 89 additions & 0 deletions diffulex/engine/tp_worker.py
@@ -1,4 +1,6 @@
import atexit
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch.multiprocessing as mp

@@ -63,6 +65,12 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
# Return seq_id so caller can build a stable mapping
return seq.seq_id

async def add_request_async(self, prompt: str | list[int], sampling_params: SamplingParams):
"""Async version of add_request (currently synchronous but provided for API consistency)."""
# Tokenization and sequence creation are fast, but we make it async for consistency
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.add_request, prompt, sampling_params)

def step(self):
seqs, is_prefill = self.scheduler.schedule()
sample_output = self.model_runner.call("run", seqs, is_prefill)
@@ -73,9 +81,33 @@ def step(self):
deltas = []
return outputs, num_tokens, is_prefill, n_diff_steps, deltas

async def step_async(self):
"""Async version of step that runs model inference in a thread pool."""
loop = asyncio.get_event_loop()
executor = getattr(self, '_step_executor', None)
if executor is None:
executor = ThreadPoolExecutor(max_workers=1)
self._step_executor = executor

def _step():
seqs, is_prefill = self.scheduler.schedule()
sample_output = self.model_runner.call("run", seqs, is_prefill)
n_diff_steps = self.scheduler.postprocess(seqs, sample_output)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
num_tokens = sum(seq.num_tokens for seq in seqs) if is_prefill else sum(seq.new_tokens for seq in seqs)
deltas = []
return outputs, num_tokens, is_prefill, n_diff_steps, deltas

return await loop.run_in_executor(executor, _step)
Comment on lines +84 to +101

⚠️ Potential issue | 🟠 Major

Thread pool executor is not cleaned up in exit().

The _step_executor ThreadPoolExecutor is lazily created in step_async() but is never shut down when the worker exits. This can lead to resource leaks and hanging threads.

🔧 Proposed fix - add cleanup in exit()
     def exit(self):
         if getattr(self, "_exited", False):
             return
         self._exited = True
+        # Shutdown step executor if created
+        if hasattr(self, '_step_executor') and self._step_executor is not None:
+            self._step_executor.shutdown(wait=False)
         if hasattr(self, "model_runner") and self.model_runner is not None:
🤖 Prompt for AI Agents
In @diffulex/engine/tp_worker.py around lines 84 - 101, the ThreadPoolExecutor created lazily in step_async() (stored as self._step_executor) is never shut down, which leaks threads; update the worker's exit() method to check for self._step_executor, call its shutdown(wait=True) (or shutdown(wait=False) if non-blocking termination is preferred), handle any exceptions, and set self._step_executor = None so that resources are released and repeated exits are safe; ensure exit() cleans up even if the executor was never created.
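
A minimal sketch of the requested cleanup (editorial, not part of the PR): the _exited flag and the model_runner check are taken from the proposed fix above, the rest of exit() is elided, and the broad exception handling is one reasonable choice rather than the project's own convention.

    def exit(self):
        if getattr(self, "_exited", False):
            return
        self._exited = True
        # Shut down the lazily created step executor, tolerating errors so that
        # exit() is safe to call repeatedly and even when step_async() never ran.
        executor = getattr(self, "_step_executor", None)
        if executor is not None:
            try:
                executor.shutdown(wait=True)
            except Exception:
                pass
            self._step_executor = None
        if hasattr(self, "model_runner") and self.model_runner is not None:
            ...  # existing model_runner teardown (not shown in this diff) continues here
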


def is_finished(self):
return self.scheduler.is_finished()

async def is_finished_async(self):
"""Async version of is_finished (currently synchronous but provided for API consistency)."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.is_finished)

def generate(
self,
prompts: list[str] | list[list[int]],
@@ -129,3 +161,60 @@ def generate(
if use_tqdm:
pbar.close()
return outputs

async def generate_async(
self,
prompts: list[str] | list[list[int]],
sampling_params: SamplingParams | list[SamplingParams],
use_tqdm: bool = True,
) -> list[str]:
"""Async version of generate that allows concurrent request handling."""
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)
# Map internal seq_id -> input index to keep output order stable
seqid_to_idx = {}
# Add requests synchronously to avoid race conditions with scheduler
# The actual async benefit comes from the inference steps, not request addition
for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
sid = self.add_request(prompt, sp)
seqid_to_idx[sid] = idx
outputs = [None] * len(prompts)
prefill_throughput = decode_throughput = 0.
n_steps = 0
n_diff_steps = [-1] * len(prompts)
while not await self.is_finished_async():
t = perf_counter()
n_steps += 1
output, num_tokens, is_prefill, cur_n_diff_steps, _ = await self.step_async()
if use_tqdm:
if is_prefill:
prefill_throughput = num_tokens / (perf_counter() - t)
else:
decode_throughput = num_tokens / (perf_counter() - t)
pbar.set_postfix({
"Prefill": f"{int(prefill_throughput)}tok/s",
"Decode": f"{int(decode_throughput)}tok/s",
})
if cur_n_diff_steps:
for seq_id, n_step in cur_n_diff_steps.items():
if seq_id in seqid_to_idx and n_step >= 0:
n_diff_steps[seqid_to_idx[seq_id]] = n_step
for seq_id, token_ids in output:
if seq_id in seqid_to_idx:
outputs[seqid_to_idx[seq_id]] = token_ids
if use_tqdm:
pbar.update(1)
await asyncio.sleep(0)

print(f"Finished in {n_steps} steps, prefill throughput: {prefill_throughput:.2f} tok/s, decode throughput: {decode_throughput:.2f} tok/s")
assert all(toks is not None for toks in outputs), "Some sequences did not produce outputs"
outputs = [{
"text": self.tokenizer.decode(token_ids).split(self.tokenizer.eos_token)[0],
"token_ids": token_ids[:token_ids.index(self.config.eos)] if self.config.eos in token_ids else token_ids,
"n_diff_steps": n_diff_step,
} for token_ids, n_diff_step in zip(outputs, n_diff_steps)]
if use_tqdm:
pbar.close()
return outputs
14 changes: 12 additions & 2 deletions diffulex/sampler/base.py
@@ -86,10 +86,20 @@ def __init__(self):
self.seq_last_logits_map: dict[str, torch.Tensor] = {}

def _fetch_last_logits(self, logits: torch.Tensor, seq: SequenceBase) -> torch.Tensor:
seq_id_str = str(seq.seq_id)
if seq.has_to_cache_block:
last_logits = logits[seq.to_cache_last_token_id]
self.seq_last_logits_map[seq.seq_id] = last_logits
return self.seq_last_logits_map[seq.seq_id]
self.seq_last_logits_map[seq_id_str] = last_logits
return last_logits
# If no cached block, return cached value if available, otherwise use last logit
if seq_id_str in self.seq_last_logits_map:
return self.seq_last_logits_map[seq_id_str]
# Fallback: use last logit from current batch and cache it
last_logits = logits[-1] if logits.shape[0] > 0 else None
if last_logits is not None:
self.seq_last_logits_map[seq_id_str] = last_logits
return last_logits
raise ValueError(f"Cannot fetch last logits for sequence {seq.seq_id}: empty logits tensor")

def _shift_logits(self, logits, last_logit=None):
if logits.shape[1] == 0: