diff --git a/diffulex/engine/dp_worker.py b/diffulex/engine/dp_worker.py index 0281930..0f03004 100755 --- a/diffulex/engine/dp_worker.py +++ b/diffulex/engine/dp_worker.py @@ -4,11 +4,13 @@ import atexit import traceback import faulthandler +import asyncio import multiprocessing as mp from typing import Any from multiprocessing.connection import wait as mp_wait +from concurrent.futures import ThreadPoolExecutor from diffulex.config import Config from diffulex.engine.tp_worker import DiffulexTPWorker @@ -140,6 +142,7 @@ def __init__(self, model, **kwargs): self._gid_counter = 0 self._gid_map = {} # (replica, local_id) -> global_id self._rev_gid_map = {} # global_id -> (replica, local_id) + self._executor = ThreadPoolExecutor(max_workers=self.dp_size) atexit.register(self.exit) def _ask(self, replica: int, cmd: str, *args): @@ -159,7 +162,15 @@ def _ask(self, replica: int, cmd: str, *args): return payload raise RuntimeError(f"DP child #{replica} error: {payload}") + async def _ask_async(self, replica: int, cmd: str, *args): + """Async version of _ask that runs in a thread pool.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self._ask, replica, cmd, *args) + def exit(self): + # Shutdown executor + if hasattr(self, '_executor'): + self._executor.shutdown(wait=True) for i, p in enumerate(self.ps): if p.is_alive(): try: @@ -178,6 +189,17 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): self._rev_gid_map[gid] = (target, local_id) return gid + async def add_request_async(self, prompt: str | list[int], sampling_params: SamplingParams): + """Async version of add_request.""" + target = self._rr + self._rr = (self._rr + 1) % self.dp_size + local_id = await self._ask_async(target, "add_request", prompt, sampling_params) + gid = self._gid_counter + self._gid_counter += 1 + self._gid_map[(target, local_id)] = gid + self._rev_gid_map[gid] = (target, local_id) + return gid + def step(self): all_outputs = [] total_tokens = 0 @@ -206,9 +228,54 @@ def step(self): merged_deltas.append((gid, toks, fin)) return all_outputs, total_tokens, any_prefill, merged_diff_steps, merged_deltas + async def step_async(self): + """Async version of step that runs all DP replicas concurrently.""" + all_outputs = [] + total_tokens = 0 + any_prefill = False + merged_diff_steps = {} + merged_deltas = [] + + # Check all replicas in parallel + tasks = [self._ask_async(i, "is_finished") for i in range(self.dp_size)] + done_flags = await asyncio.gather(*tasks) + + # Step all non-finished replicas in parallel + step_tasks = [] + for i, done in enumerate(done_flags): + if not done: + step_tasks.append((i, self._ask_async(i, "step"))) + + if step_tasks: + step_results = await asyncio.gather(*[task for _, task in step_tasks]) + for (i, _), (outputs, num_tokens, is_prefill, n_diff_steps, deltas) in zip(step_tasks, step_results): + if outputs: + # remap local seq_ids to global ids + for sid, toks in outputs: + gid = self._gid_map.get((i, sid), None) + if gid is not None: + all_outputs.append((gid, toks)) + total_tokens += num_tokens + any_prefill = any_prefill or is_prefill + if n_diff_steps: + merged_diff_steps.update(n_diff_steps) + if deltas: + for sid, toks, fin in deltas: + gid = self._gid_map.get((i, sid), None) + if gid is not None: + merged_deltas.append((gid, toks, fin)) + + return all_outputs, total_tokens, any_prefill, merged_diff_steps, merged_deltas + def is_finished(self): return all(self._ask(i, "is_finished") for i in range(self.dp_size)) + async 
def is_finished_async(self): + """Async version of is_finished that checks all replicas in parallel.""" + tasks = [self._ask_async(i, "is_finished") for i in range(self.dp_size)] + results = await asyncio.gather(*tasks) + return all(results) + def generate(self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True): """Load-balanced generate with random shuffling and stable order restoration. - Randomly shuffle inputs to balance load across DP replicas. @@ -292,3 +359,111 @@ def generate(self, prompts: list[str] | list[list[int]], sampling_params: Sampli restored[orig_idx] = out assert all(x is not None for x in restored), "Mismatch in outputs after DP collection" return restored + + async def generate_async(self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True): + """Async version of generate that allows concurrent request handling.""" + import random + n = len(prompts) + idxs = list(range(n)) + random.shuffle(idxs) + shuffled_prompts = [prompts[i] for i in idxs] + # Align sampling params with shuffled prompts + if isinstance(sampling_params, list): + if len(sampling_params) == n: + shuffled_sps = [sampling_params[i] for i in idxs] + elif len(sampling_params) == self.dp_size: + # per-shard SP; keep as-is and broadcast per-shard below + shuffled_sps = sampling_params + else: + shuffled_sps = [sampling_params[0]] * n + else: + shuffled_sps = sampling_params + + # Even partition of shuffled inputs + base = n // self.dp_size + rem = n % self.dp_size + slices = {} + start = 0 + for i in range(self.dp_size): + add = base + (1 if i < rem else 0) + end = start + add + if start < end: + slices[i] = (start, end) + start = end + + # Send generate requests to all replicas concurrently + async def send_generate(replica_idx: int, start_idx: int, end_idx: int): + if isinstance(shuffled_sps, list): + if len(shuffled_sps) == n: + sp_arg = shuffled_sps[start_idx:end_idx] + elif len(shuffled_sps) == self.dp_size: + sp_arg = shuffled_sps[replica_idx] + else: + sp_arg = shuffled_sps[0] + else: + sp_arg = shuffled_sps + conn = self.conns[replica_idx] + # Send in executor to avoid blocking + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + conn.send, + ("generate", shuffled_prompts[start_idx:end_idx], sp_arg, use_tqdm) + ) + return replica_idx + + # Send all requests concurrently + send_tasks = [ + send_generate(i, s, e) for i, (s, e) in slices.items() + ] + await asyncio.gather(*send_tasks) + + # Collect results asynchronously + collected = {} + pending = set(slices.keys()) + conn_to_idx = {self.conns[i]: i for i in slices.keys()} + + async def wait_for_result(replica_idx: int): + conn = self.conns[replica_idx] + loop = asyncio.get_event_loop() + try: + # Poll for data availability in executor, then recv + def check_and_recv(): + # Poll is non-blocking, but we run it in executor to be safe + if conn.poll(): + return conn.recv() + return None + + # Poll until data is available + while True: + result = await loop.run_in_executor(self._executor, check_and_recv) + if result is not None: + tag, payload = result + break + await asyncio.sleep(0.001) # Small sleep to yield control + except EOFError: + p = self.ps[replica_idx] + exitcode = p.exitcode + raise RuntimeError( + f"DP child #{replica_idx} terminated unexpectedly during generate (exitcode={exitcode}). 
" + f"Enable envs: PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 for more info." + ) + if tag == "ok": + collected[replica_idx] = payload + else: + raise RuntimeError(f"DP child #{replica_idx} error: {payload}") + + # Wait for all results concurrently + await asyncio.gather(*[wait_for_result(i) for i in slices.keys()]) + + # Restore to original order + restored = [None] * n + for i, (s, e) in slices.items(): + outs = collected.get(i, []) + # outs are aligned with shuffled order s:e + for local_k, out in enumerate(outs): + global_pos = s + local_k + orig_idx = idxs[global_pos] + restored[orig_idx] = out + assert all(x is not None for x in restored), "Mismatch in outputs after DP collection" + return restored diff --git a/diffulex/engine/model_runner.py b/diffulex/engine/model_runner.py index 0316dd0..4a2694b 100755 --- a/diffulex/engine/model_runner.py +++ b/diffulex/engine/model_runner.py @@ -1,5 +1,7 @@ import torch import pickle +import asyncio +from concurrent.futures import ThreadPoolExecutor import torch.distributed as dist @@ -72,6 +74,9 @@ def exit(self): self.shm.unlink() if not self.enforce_eager: del self.graphs, self.graph_pool + # Clean up executor if it exists + if hasattr(self, '_executor'): + self._executor.shutdown(wait=True) torch.cuda.synchronize() dist.destroy_process_group() @@ -110,6 +115,16 @@ def call(self, method_name, *args): method = getattr(self, method_name, None) return method(*args) + async def call_async(self, method_name, *args): + """Async version of call that runs in a thread pool executor.""" + loop = asyncio.get_event_loop() + # Use default executor or create one if needed + executor = getattr(self, '_executor', None) + if executor is None: + executor = ThreadPoolExecutor(max_workers=1) + self._executor = executor + return await loop.run_in_executor(executor, self.call, method_name, *args) + def load_model(self, config: Config): """Instantiate the underlying model; override to customize.""" return AutoModelForDiffusionLM.from_config(config) diff --git a/diffulex/engine/tp_worker.py b/diffulex/engine/tp_worker.py index 3ea53c5..474a884 100755 --- a/diffulex/engine/tp_worker.py +++ b/diffulex/engine/tp_worker.py @@ -1,4 +1,6 @@ import atexit +import asyncio +from concurrent.futures import ThreadPoolExecutor import torch.multiprocessing as mp @@ -63,6 +65,12 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): # Return seq_id so caller can build a stable mapping return seq.seq_id + async def add_request_async(self, prompt: str | list[int], sampling_params: SamplingParams): + """Async version of add_request (currently synchronous but provided for API consistency).""" + # Tokenization and sequence creation are fast, but we make it async for consistency + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.add_request, prompt, sampling_params) + def step(self): seqs, is_prefill = self.scheduler.schedule() sample_output = self.model_runner.call("run", seqs, is_prefill) @@ -73,9 +81,33 @@ def step(self): deltas = [] return outputs, num_tokens, is_prefill, n_diff_steps, deltas + async def step_async(self): + """Async version of step that runs model inference in a thread pool.""" + loop = asyncio.get_event_loop() + executor = getattr(self, '_step_executor', None) + if executor is None: + executor = ThreadPoolExecutor(max_workers=1) + self._step_executor = executor + + def _step(): + seqs, is_prefill = self.scheduler.schedule() + sample_output = self.model_runner.call("run", seqs, 
is_prefill) + n_diff_steps = self.scheduler.postprocess(seqs, sample_output) + outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + num_tokens = sum(seq.num_tokens for seq in seqs) if is_prefill else sum(seq.new_tokens for seq in seqs) + deltas = [] + return outputs, num_tokens, is_prefill, n_diff_steps, deltas + + return await loop.run_in_executor(executor, _step) + def is_finished(self): return self.scheduler.is_finished() + async def is_finished_async(self): + """Async version of is_finished (currently synchronous but provided for API consistency).""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.is_finished) + def generate( self, prompts: list[str] | list[list[int]], @@ -129,3 +161,60 @@ def generate( if use_tqdm: pbar.close() return outputs + + async def generate_async( + self, + prompts: list[str] | list[list[int]], + sampling_params: SamplingParams | list[SamplingParams], + use_tqdm: bool = True, + ) -> list[str]: + """Async version of generate that allows concurrent request handling.""" + if use_tqdm: + pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) + if not isinstance(sampling_params, list): + sampling_params = [sampling_params] * len(prompts) + # Map internal seq_id -> input index to keep output order stable + seqid_to_idx = {} + # Add requests synchronously to avoid race conditions with scheduler + # The actual async benefit comes from the inference steps, not request addition + for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + sid = self.add_request(prompt, sp) + seqid_to_idx[sid] = idx + outputs = [None] * len(prompts) + prefill_throughput = decode_throughput = 0. + n_steps = 0 + n_diff_steps = [-1] * len(prompts) + while not await self.is_finished_async(): + t = perf_counter() + n_steps += 1 + output, num_tokens, is_prefill, cur_n_diff_steps, _ = await self.step_async() + if use_tqdm: + if is_prefill: + prefill_throughput = num_tokens / (perf_counter() - t) + else: + decode_throughput = num_tokens / (perf_counter() - t) + pbar.set_postfix({ + "Prefill": f"{int(prefill_throughput)}tok/s", + "Decode": f"{int(decode_throughput)}tok/s", + }) + if cur_n_diff_steps: + for seq_id, n_step in cur_n_diff_steps.items(): + if seq_id in seqid_to_idx and n_step >= 0: + n_diff_steps[seqid_to_idx[seq_id]] = n_step + for seq_id, token_ids in output: + if seq_id in seqid_to_idx: + outputs[seqid_to_idx[seq_id]] = token_ids + if use_tqdm: + pbar.update(1) + await asyncio.sleep(0) + + print(f"Finished in {n_steps} steps, prefill throughput: {prefill_throughput:.2f} tok/s, decode throughput: {decode_throughput:.2f} tok/s") + assert all(toks is not None for toks in outputs), "Some sequences did not produce outputs" + outputs = [{ + "text": self.tokenizer.decode(token_ids).split(self.tokenizer.eos_token)[0], + "token_ids": token_ids[:token_ids.index(self.config.eos)] if self.config.eos in token_ids else token_ids, + "n_diff_steps": n_diff_step, + } for token_ids, n_diff_step in zip(outputs, n_diff_steps)] + if use_tqdm: + pbar.close() + return outputs diff --git a/diffulex/sampler/base.py b/diffulex/sampler/base.py index 34f394f..56734ab 100644 --- a/diffulex/sampler/base.py +++ b/diffulex/sampler/base.py @@ -86,10 +86,20 @@ def __init__(self): self.seq_last_logits_map: dict[str, torch.Tensor] = {} def _fetch_last_logits(self, logits: torch.Tensor, seq: SequenceBase) -> torch.Tensor: + seq_id_str = str(seq.seq_id) if seq.has_to_cache_block: last_logits = logits[seq.to_cache_last_token_id] 
- self.seq_last_logits_map[seq.seq_id] = last_logits - return self.seq_last_logits_map[seq.seq_id] + self.seq_last_logits_map[seq_id_str] = last_logits + return last_logits + # If no cached block, return cached value if available, otherwise use last logit + if seq_id_str in self.seq_last_logits_map: + return self.seq_last_logits_map[seq_id_str] + # Fallback: use last logit from current batch and cache it + last_logits = logits[-1] if logits.shape[0] > 0 else None + if last_logits is not None: + self.seq_last_logits_map[seq_id_str] = last_logits + return last_logits + raise ValueError(f"Cannot fetch last logits for sequence {seq.seq_id}: empty logits tensor") def _shift_logits(self, logits, last_logit=None): if logits.shape[1] == 0: diff --git a/diffulex/strategy/block_diffusion/engine/kvcache_manager.py b/diffulex/strategy/block_diffusion/engine/kvcache_manager.py index 9659c10..66fcf3d 100644 --- a/diffulex/strategy/block_diffusion/engine/kvcache_manager.py +++ b/diffulex/strategy/block_diffusion/engine/kvcache_manager.py @@ -28,7 +28,7 @@ def may_append(self, seq: "BDSequence") -> None: if last_block.hash == -1: prev_end_token = seq.cached_or_caching_num_tokens - seq.caching_num_tokens - 1 prev_block_idx = prev_end_token // self.block_size - if prev_block_idx < seq.num_blocks: + if 0 <= prev_block_idx < seq.num_blocks: token_ids: list[int] = seq.block(prev_block_idx) prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 h = self.compute_hash(token_ids, prefix) diff --git a/diffulex_kernel/python/dllm_flash_attn.py b/diffulex_kernel/python/dllm_flash_attn.py index 099ed68..9b2caa3 100644 --- a/diffulex_kernel/python/dllm_flash_attn.py +++ b/diffulex_kernel/python/dllm_flash_attn.py @@ -131,7 +131,7 @@ def kernel( for i in T.Parallel(BLOCK_M): scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - for i in T.parallel(BLOCK_M): + for i in T.Parallel(BLOCK_M): scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) for i, j in T.Parallel(BLOCK_M, BLOCK_N): diff --git a/diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py b/diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py index 7c863c4..ce7b1ff 100755 --- a/diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py +++ b/diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py @@ -15,8 +15,16 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.platforms.rocm import use_rocm_custom_paged_attention -from vllm.triton_utils import tl, triton +# from vllm.platforms.rocm import use_rocm_custom_paged_attention +try: + from vllm.platforms.rocm import use_rocm_custom_paged_attention # vLLM newer +except Exception: + # vLLM older / CUDA-only env: treat as disabled + def use_rocm_custom_paged_attention() -> bool: + return False + +import triton +import triton.language as tl from diffulex_legacy.layers.attention.ops.prefix_prefill import context_attention_fwd diff --git a/diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py b/diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py index 71be261..4cc13af 100755 --- a/diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py +++ b/diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py @@ -36,7 +36,9 @@ import logging from vllm.platforms import current_platform -from vllm.triton_utils import tl, triton +import triton +import triton.language as tl + is_hip_ = current_platform.is_rocm() diff 
--git a/examples/test_async_inference.py b/examples/test_async_inference.py new file mode 100644 index 0000000..e3f517a --- /dev/null +++ b/examples/test_async_inference.py @@ -0,0 +1,148 @@ +import argparse +import os +import asyncio +import time +from pathlib import Path +import sys + +from tqdm import tqdm +from transformers import AutoTokenizer + +from diffulex import Diffulex, SamplingParams + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + + +async def run_async_inference(worker, prompts, sampling_params): + """Run async inference using generate_async.""" + outputs = await worker.generate_async(prompts, sampling_params, use_tqdm=True) + return outputs + + +def main() -> None: + parser = argparse.ArgumentParser(description="Test async inference with fast_dllm_v2") + parser.add_argument( + "--model", + type=str, + default="/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B", + help="Fast_dLLM_v2 model directory", + ) + parser.add_argument( + "--prompt", + type=str, + default="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n", + help="Input prompt for testing", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=256, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=10, + help="Number of prompts to test (will duplicate the prompt)", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.25, + help="GPU memory utilization", + ) + args = parser.parse_args() + + print("=" * 80) + print("Testing Async Inference with Fast_dLLM_v2") + print("=" * 80) + print(f"[model] {args.model}") + print(f"[prompt] {args.prompt[:100]}..." 
if len(args.prompt) > 100 else f"[prompt] {args.prompt}") + print(f"[num_prompts] {args.num_prompts}") + print(f"[max_tokens] {args.max_tokens}") + print(f"[temperature] {args.temperature}") + print("=" * 80) + + # Create Diffulex engine for async inference + print("\n[Initializing Diffulex engine...]") + worker = Diffulex( + model=args.model, + use_lora=False, + model_name="fast_dllm_v2", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=args.gpu_memory_utilization, + max_num_batched_tokens=2048, + max_num_seqs=args.num_prompts, + max_model_len=2048, + kv_cache_layout="unified", + decoding_strategy="block_diffusion", + mask_token_id=151665, + master_addr="127.0.0.1", + master_port=2333, + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True) + + prompts = [args.prompt] * args.num_prompts + sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens) + + print("\n[Running async inference...]") + print("=" * 80) + + # Run async inference + try: + start_time = time.time() + outputs = asyncio.run(run_async_inference(worker, prompts, sampling_params)) + end_time = time.time() + + elapsed_time = end_time - start_time + total_tokens = sum(len(o['token_ids']) for o in outputs) + + print("\n" + "=" * 80) + print("[Async Inference Results]") + print("=" * 80) + print(f"Generated {len(outputs)} outputs") + print(f"Total tokens: {total_tokens}") + print(f"Total time: {elapsed_time:.2f} seconds") + print(f"Average TPS: {total_tokens / elapsed_time:.2f} tok/s") + if outputs and 'n_diff_steps' in outputs[0]: + avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) + print(f"Average steps: {avg_diff_steps:.2f}") + + print("\n" + "=" * 80) + print("[Individual Results]") + print("=" * 80) + for idx, output in enumerate(outputs): + print(f"\n[Output {idx + 1}/{len(outputs)}]") + print(f"Text: {output['text']}") + print(f"Token IDs length: {len(output['token_ids'])}") + if 'n_diff_steps' in output: + print(f"Number of steps: {output['n_diff_steps']}") + print("-" * 80) + + except Exception as e: + print(f"\n[Error during async inference]") + print(f"Error: {e}") + import traceback + traceback.print_exc() + finally: + # Cleanup + print("\n[Cleaning up...]") + worker.exit() + print("[Done]") + + +if __name__ == "__main__": + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + main() diff --git a/examples/test_dream_dvllm_gsm8k.py b/examples/test_dream_dvllm_gsm8k.py index 1affedb..39fa506 100755 --- a/examples/test_dream_dvllm_gsm8k.py +++ b/examples/test_dream_dvllm_gsm8k.py @@ -10,7 +10,7 @@ from transformers import AutoTokenizer from diffulex import Diffulex, SamplingParams - +import diffulex.model.dream def summarize_profiling(csv_path: str) -> dict: totals = {} @@ -64,8 +64,7 @@ def summarize_profiling(csv_path: str) -> dict: tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) sampling_params = SamplingParams(temperature=0.0, max_tokens=256) - dataset = load_dataset( - "gsm8k", "main", split="test")["question"][:10] + dataset = load_dataset("gsm8k", "main")['test']['question'][:] prompts = [tokenizer.bos_token + FEW_SHOTS + p for p in tqdm(dataset)] output_file = "log/profiles/perf_dvllm_dream_7B.json" diff --git a/examples/test_fastdllmv2_diffulex_gsm8k.py b/examples/test_fastdllmv2_diffulex_gsm8k.py index 3950537..110c090 100755 --- a/examples/test_fastdllmv2_diffulex_gsm8k.py +++ b/examples/test_fastdllmv2_diffulex_gsm8k.py @@ -40,6 
+40,7 @@ def summarize_profiling(csv_path: str) -> dict: if __name__ == "__main__": model = "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B" + local_data_path = "/data1/LargeData/gsm8k" LLM = Diffulex( model, use_lora=False, @@ -49,7 +50,7 @@ def summarize_profiling(csv_path: str) -> dict: tensor_parallel_size=1, gpu_memory_utilization=0.25, max_num_batched_tokens=2048, - max_num_seqs=20, + max_num_seqs=1, max_model_len=2048, kv_cache_layout="unified", decoding_strategy="block_diffusion", @@ -58,13 +59,14 @@ def summarize_profiling(csv_path: str) -> dict: tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) sampling_params = SamplingParams(temperature=0.0, max_tokens=256) - dataset = load_dataset("gsm8k", "main", split="test")["question"][:10] + # dataset = load_dataset("gsm8k", "main", split="test")["question"][:10] + dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10] prompts = [ FEW_SHOTS + f"<|im_start|>user\nQuestion: {question}\nAnswer:<|im_end|>\n<|im_start|>assistant\n" for question in tqdm(dataset) ] - output_file = "log/profiles/perf_dvllm_dream_7B.json" + output_file = "log/profiles/perf_dvllm_fastdllmv2_7B.json" if os.path.exists(output_file): os.remove(output_file) # with VizTracer(output_file=output_file, file_info=True) as tracer: diff --git a/examples/test_llada_dvllm_human_eval.py b/examples/test_llada_dvllm_human_eval.py index 5e3608f..5f52c72 100755 --- a/examples/test_llada_dvllm_human_eval.py +++ b/examples/test_llada_dvllm_human_eval.py @@ -38,7 +38,7 @@ def summarize_profiling(csv_path: str) -> dict: if __name__ == "__main__": WEIGHT_DIR = "/data1/ckpts" DATA_DIR = "/data1/LargeData" - model = f"{WEIGHT_DIR}/GSAI-ML/llada-8b-instruct" + model = f"{WEIGHT_DIR}/GSAI-ML/LLaDA-8B-Instruct" LLM = LLM( model, lora_path=f"{WEIGHT_DIR}/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora", diff --git a/examples/test_sdar_dvllm.py b/examples/test_sdar_dvllm.py index 78fbbd7..28aded5 100644 --- a/examples/test_sdar_dvllm.py +++ b/examples/test_sdar_dvllm.py @@ -97,14 +97,14 @@ def main() -> None: parser.add_argument( "--model", type=str, - default="/home/lzx/SDAR/training/model/SDAR-1.7B-Chat", + default="/data1/ckpts/SDAR/SDAR-1.7B-Chat", help="SDAR HF model directory (contains config.json + model.safetensors).", ) parser.add_argument("--device", type=int, default=0) parser.add_argument( "--converted-dir", type=str, - default="/home/lzx/tmp/diffulex_sdar_converted", + default="/home/ljp/tmp/diffulex_sdar_converted", help="Output directory for converted checkpoint keys (Diffulex-native).", ) parser.add_argument("--prompt", type=str, default="你好,请用一句话介绍 SDAR。") @@ -136,6 +136,7 @@ def main() -> None: # Build Config + load model weights using Diffulex loader. from diffulex.config import Config from diffulex.model.auto_model import AutoModelForDiffusionLM + import diffulex.model.sdar cfg = Config( model=str(model_dir), diff --git a/scripts/test_dvllm_dream_gsm8k.sh b/scripts/test_dvllm_dream_gsm8k.sh index 25ff47e..c3705bb 100755 --- a/scripts/test_dvllm_dream_gsm8k.sh +++ b/scripts/test_dvllm_dream_gsm8k.sh @@ -1,2 +1,2 @@ #!/usr/bin/zsh -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python examples/test_dream_dvllm_gsm8k.py 2>&1 | tee log/test_dvllm_dream_gsm8k.log \ No newline at end of file +CUDA_VISIBLE_DEVICES=0,1 python examples/test_dream_dvllm_gsm8k.py 2>&1 | tee log/test_dvllm_dream_gsm8k.log \ No newline at end of file
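Note on the concurrency pattern introduced in diffulex/engine/dp_worker.py above: each DP replica is driven over a blocking multiprocessing Connection, so the new async methods wrap the existing synchronous _ask round-trip in loop.run_in_executor and overlap the per-replica calls with asyncio.gather. The sketch below is a minimal, self-contained illustration of that pattern; the class and method names here are placeholders, not Diffulex APIs. Inside a coroutine, asyncio.get_running_loop() is the recommended spelling and behaves the same as the get_event_loop() call used in the patch.

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

class ReplicaPool:
    """Toy stand-in for a set of DP replicas driven over blocking pipe calls."""

    def __init__(self, n_replicas: int):
        self.n = n_replicas
        self._executor = ThreadPoolExecutor(max_workers=n_replicas)

    def _ask(self, replica: int, cmd: str) -> str:
        # Placeholder for a blocking send/recv round-trip over a pipe.
        time.sleep(0.1)
        return f"replica {replica} handled {cmd}"

    async def _ask_async(self, replica: int, cmd: str) -> str:
        # Run the blocking call in a worker thread so the event loop stays free.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self._ask, replica, cmd)

    async def step_all(self) -> list[str]:
        # Overlap the blocking round-trips across all replicas.
        return await asyncio.gather(*(self._ask_async(i, "step") for i in range(self.n)))

async def main() -> None:
    pool = ReplicaPool(4)
    print(await pool.step_all())  # ~0.1 s total instead of ~0.4 s sequentially

asyncio.run(main())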
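The load-balanced generate paths (both the existing generate and the new generate_async) shuffle the prompts, partition the shuffled list contiguously across replicas, and then undo the shuffle when collecting results. A minimal sketch of that index bookkeeping, using dummy data in place of real generation:

import random

prompts = [f"prompt-{k}" for k in range(7)]
dp_size = 3

# Shuffle once and remember the permutation.
idxs = list(range(len(prompts)))
random.shuffle(idxs)
shuffled = [prompts[i] for i in idxs]

# Even contiguous partition of the shuffled list, as in dp_worker.generate.
base, rem = divmod(len(prompts), dp_size)
slices, start = {}, 0
for r in range(dp_size):
    end = start + base + (1 if r < rem else 0)
    if start < end:
        slices[r] = (start, end)
    start = end

# Pretend each replica echoes its shard, then restore the original order:
# shuffled position s + local_k came from original index idxs[s + local_k].
restored = [None] * len(prompts)
for r, (s, e) in slices.items():
    outs = [f"out({p})" for p in shuffled[s:e]]
    for local_k, out in enumerate(outs):
        restored[idxs[s + local_k]] = out

assert all(x is not None for x in restored)
print(restored)  # aligned with the original `prompts` order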
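Because generate_async awaits between engine steps (run_in_executor calls in the DP path, and an explicit await asyncio.sleep(0) in tp_worker.generate_async), a caller can run other coroutines alongside a generation, which is the point of the async API. A minimal usage sketch, assuming an engine constructed as in examples/test_async_inference.py; the heartbeat coroutine here is illustrative and not part of Diffulex.

import asyncio
from diffulex import Diffulex, SamplingParams

async def heartbeat(stop: asyncio.Event) -> None:
    # Illustrative side task: the event loop stays responsive during generation.
    while not stop.is_set():
        print("engine alive")
        await asyncio.sleep(1.0)

async def main(engine: Diffulex) -> None:
    stop = asyncio.Event()
    hb = asyncio.create_task(heartbeat(stop))
    try:
        outputs = await engine.generate_async(
            ["What is 2+2?"],
            SamplingParams(temperature=0.0, max_tokens=64),
            use_tqdm=False,
        )
        print(outputs[0]["text"])
    finally:
        stop.set()
        await hb

# engine = Diffulex(model=..., ...)  # construct as in examples/test_async_inference.py
# asyncio.run(main(engine))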