diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..e8aab8b
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "workbench.colorTheme": "Default Dark Modern"
+}
\ No newline at end of file
diff --git a/apex_plus/search/engine.py b/apex_plus/search/engine.py
index 6cf1e19..a8596a8 100644
--- a/apex_plus/search/engine.py
+++ b/apex_plus/search/engine.py
@@ -48,6 +48,7 @@ def generate_schedules(
         device_memory_capacity = cluster.get_device_memory_capacity()
         num_devices = cluster.get_num_devices()
 
+        # print(f"num_devices: {num_devices}")
         # 1. Model-level data parallelism.
         for num_replicas in _get_divisors(num_devices):
             if not cluster.is_partitionable(num_replicas):
@@ -177,7 +178,7 @@ generate_schedules(
         return parallel_schedules
 
     def generate_plans(self, arch: str, cluster: "Cluster") -> List[ExecutionPlan]:
-
+        # print(f"devices: {cluster.get_num_devices()}")
         # Generate all possible parallel schedules.
         if arch == "encoder":
             parallel_schedules = self.generate_schedules(
@@ -225,14 +226,19 @@ def search(
         model_config=[],
         ttft_slo = 10,
         tpot_slo = 10,
-        max_batch_size = 0) -> List[ExecutionPlan]:
+        max_batch_size = 0,
+        distserve = False) -> List[ExecutionPlan]:
         """Search for the best execution plan."""
         candidate_plans = self.generate_plans(self.arch, self.cluster)
         print(f"Generated {len(candidate_plans)} {self.arch} candidate plans.")
+        # print(self.simulator.num_total_nodes, self.simulator.num_total_devices)
+        # print(self.cluster)
+
         outputs: List[Tuple[ExecutionPlan, SimulatorOutput]] = []
         slo_targets = [ttft_slo, tpot_slo]
 
         for plan in tqdm(candidate_plans):
+        # for plan in candidate_plans:
             requests, output = self.simulator.simulate(
                 plan,
                 self.arch,
@@ -241,7 +247,8 @@ def search(
                 req_percentiles,
                 token_percentiles,
                 slo_targets,
-                max_batch_size)
+                max_batch_size,
+                distserve)
             if output is None:
                 # Invalid plan (e.g., when the model does not fit in memory).
                 continue
@@ -253,7 +260,7 @@ def search(
         print("=" * 80)
         outputs = sorted(outputs, key=lambda x: x[1].total_time)
 
-        # Print either best plan or all plans based off flag
+        # # Print either best plan or all plans based off flag
         if return_all_plans:
             for i, (plan, output) in enumerate(outputs):
                 print(f"* Parallel schedule {i} for {self.arch}:")
diff --git a/apex_plus/simulator/simulator.py b/apex_plus/simulator/simulator.py
index 7e96ef2..a8bb4ea 100644
--- a/apex_plus/simulator/simulator.py
+++ b/apex_plus/simulator/simulator.py
@@ -17,18 +17,15 @@ from apex_plus.utils.dtype import DTYPE
 
 GB = 1024 * 1024 * 1024
-WORKSPACE = 1 * GB  # a constant buffer for each device to run the program
-
-MAX_NUM_INPUT_TOKENS = 64 * 1024  # Max in profile/scripts/gemm.py
-
+WORKSPACE = 1 * GB
+MAX_NUM_INPUT_TOKENS = 64 * 1024
 US_TO_SEC = 1000000
MS_TO_SEC = 1000 US_TO_MS = 1000 - +KV_CACHE_TRANSFER_POWER = 1000 @dataclass class SimulatorOutput: - param_size_per_device: float available_memory_per_device: float num_requests_per_iteration: float @@ -40,20 +37,22 @@ class SimulatorOutput: total_time: float total_energy: float - class Simulator: - def __init__( self, model: Transformer, cluster: Cluster, trace: Trace, dtype: dict, + prefill_gpu: int = 0, # GPU for prefill phase + decode_gpu: int = 1, # GPU for decode phase ) -> None: self.model = model self.cluster = cluster self.trace = trace self.dtype = dtype + self.prefill_gpu = prefill_gpu + self.decode_gpu = decode_gpu self.gpu = cluster.get_device().device_type self.gpu_memory = cluster.get_device_memory_capacity() self.peak_flops = cluster.get_device().peak_flops[self.highest_prec()] @@ -63,12 +62,7 @@ def __init__( self.cluster_size_per_node = self.num_total_devices // self.num_total_nodes def highest_prec(self) -> DTYPE: - data_type = [] - data_type.append(self.dtype["w"]) - data_type.append(self.dtype["kv"]) - data_type.append(self.dtype["act"]) - # Dealing with mixed precesion - # Assuming we dequantize the value for computation + data_type = [self.dtype["w"], self.dtype["kv"], self.dtype["act"]] highest_precision = DTYPE.FLOAT8 if DTYPE.FLOAT16 in data_type: highest_precision = DTYPE.FLOAT16 @@ -76,14 +70,8 @@ def highest_prec(self) -> DTYPE: highest_precision = DTYPE.FLOAT32 return highest_precision - def dispatch( - self, - requests: List[Request], - factor: int, - ) -> List[List[Request]]: + def dispatch(self, requests: List[Request], factor: int) -> List[List[Request]]: sublists = [[] for _ in range(factor)] - # Distribute elements in a round-robin fashion - # Can be replaced with more sophisticated strategy for index, element in enumerate(requests): sublists[index % factor].append(element) return sublists @@ -110,15 +98,11 @@ def get_metrics( seq_lens: List[int] = [], arch: str = '', slo_targets: List[int] = [], - ): - - + ): def calculate_tbt_percentiles(latency_dict): token_latencies_per_request = [latency for latency in latency_dict.values()] avg_tbt_vals = [] percentile_vals = [] - - # Calculate avg TBT per request for latency_list in token_latencies_per_request: list_length = len(latency_list) avg_tbt = 0 @@ -127,11 +111,8 @@ def calculate_tbt_percentiles(latency_dict): ttlt = sum(latency_list) avg_tbt = (ttlt - ttft) / list_length avg_tbt_vals.append(avg_tbt) - - # Calculate all necessary percentiles for tbt for percentile in token_percentiles: percentile_vals.append(np.percentile(avg_tbt_vals, percentile)) - return percentile_vals def calculate_slo_metrics(latency_dict, slo_targets): @@ -140,56 +121,27 @@ def calculate_slo_metrics(latency_dict, slo_targets): ttft_target = slo_targets[0] tpot_target = slo_targets[1] slo_metrics = [] - - # Calculate Percentage of requests that are <= TTFT_SLO - # TTFT is just the first token in the latency ttft_slo_counter = 0 for latency_list in token_latencies_per_request: - if(latency_list[0]/US_TO_MS <= ttft_target): + if latency_list[0] / US_TO_MS <= ttft_target: ttft_slo_counter += 1 - slo_metrics.append( (ttft_slo_counter/num_requests) * 100 ) - - # Calculate Percentage of requests that have an avg TPOT <= TPOT_SLO + slo_metrics.append((ttft_slo_counter / num_requests) * 100) tpot_slo_counter = 0 - tok_latencies_per_req_after_first_tok = [sublist[1:] for sublist in token_latencies_per_request] - for latency_list in token_gen_times: - # Calculate Avg TPOT per request - avg_tpot = np.mean(latency_list)/US_TO_MS - if(avg_tpot 
<= tpot_target): - tpot_slo_counter += 1 - slo_metrics.append( (tpot_slo_counter/num_requests) * 100 ) - + for latency_list in token_latencies_per_request: + avg_tpot = np.mean(latency_list[1:]) / US_TO_MS if len(latency_list) > 1 else float('inf') + if avg_tpot <= tpot_target: + tpot_slo_counter += 1 + slo_metrics.append((tpot_slo_counter / num_requests) * 100) return slo_metrics - - # Store performance metrics - Time to first token, TPOT ,P50, P95, & other latencies + performance_metrics: List[Tuple[str, float]] = [] performance_metrics_units: List[str] = [] - performance_metrics.append( - ("Throughput: Avg. Tokens generated per second", float("NaN")) - ) - performance_metrics.append( - ("Throughput: Avg. Tokens processed per second", float("NaN")) - ) + performance_metrics.append(("Throughput: Avg. Tokens generated per second", float("NaN"))) + performance_metrics.append(("Throughput: Avg. Tokens processed per second", float("NaN"))) performance_metrics.append(("Throughput: Requests per second", float("NaN"))) - performance_metrics.append( - ("Latency: Avg. Time to first token (TTFT in msec)", float("NaN")) - ) - performance_metrics.append( - ("Latency: Avg. Time per output token (TPOT in msec)", float("NaN")) - ) - performance_metrics_units += [ - "tokens/sec", - "tokens/sec", - "requests/sec", - "msec", - "msec", - ] - - num_layers = 0 - num_heads = 0 - head_dim = 0 - hidden_size = 0 - theoretical_peak_flops = self.peak_flops + performance_metrics.append(("Latency: Avg. Time to first token (TTFT in msec)", float("NaN"))) + performance_metrics.append(("Latency: Avg. Time per output token (TPOT in msec)", float("NaN"))) + performance_metrics_units += ["tokens/sec", "tokens/sec", "requests/sec", "msec", "msec"] if hasattr(model_config, "num_layers"): num_layers = model_config.num_layers @@ -202,106 +154,53 @@ def calculate_slo_metrics(latency_dict, slo_targets): else: raise ValueError("Unable to get model layers, heads, or hidden size") head_dim = hidden_size // num_heads - num_parameters = num_layers * hidden_size * hidden_size * 12 tpot = 0.0 avg_ttft = 0.0 mbu = 0.0 - # If encoder, there the output_len is 0 - if arch == "encoder": - avg_output_len = 0.0 - # Decoders that generate tokens - else: - # TPOT after the first token(this is also known as inter-token latency) + if arch != "encoder": token_gen_times = [value[1:] for value in request_token_gen_times.values()] - flat_token_gen_times = [ - item for sublist in token_gen_times for item in sublist - ] - tpot = np.mean(flat_token_gen_times) / MS_TO_SEC - avg_ttft = ( - np.mean([value[0] for value in request_token_gen_times.values()]) - / US_TO_MS - ) - - # Calculate token percentiles - token_percentiles = token_percentiles + [50, 95] - token_percentiles.sort() - # Avg percentile vals are returned in order of sorted percentiles + flat_token_gen_times = [item for sublist in token_gen_times for item in sublist] + tpot = np.mean(flat_token_gen_times) / MS_TO_SEC if flat_token_gen_times else 0.0 + avg_ttft = np.mean([value[0] for value in request_token_gen_times.values()]) / US_TO_MS + token_percentiles = sorted(token_percentiles + [50, 95]) avg_percentile_vals = calculate_tbt_percentiles(request_token_gen_times) - # Add to performance_metrics for index, percentile in enumerate(token_percentiles): - performance_metrics.append( - ( - f"Avg. TBT Percentile: P{percentile}", - avg_percentile_vals[index] / US_TO_MS, - ) - ) + performance_metrics.append((f"Avg. 
TBT Percentile: P{percentile}", avg_percentile_vals[index] / US_TO_MS)) performance_metrics_units.append("msec") - # MBU - kv_cache_size = ( - 2 * num_layers * num_heads * head_dim * self.dtype["kv"].size - ) + kv_cache_size = 2 * num_layers * num_heads * head_dim * self.dtype["kv"].size tpot_sec = tpot / MS_TO_SEC - theoretical_peak_mem_bandwidth = self.peak_mem_bandwidth - observed_mem_bandwidth = (num_parameters + kv_cache_size) / tpot_sec - mbu = (observed_mem_bandwidth / theoretical_peak_mem_bandwidth) * 100 + observed_mem_bandwidth = (num_parameters + kv_cache_size) / tpot_sec if tpot_sec else 0 + mbu = (observed_mem_bandwidth / self.peak_mem_bandwidth) * 100 - # Tokens gen per second - token_throughput = avg_output_len * len(requests) / (total_time / US_TO_SEC) + token_throughput = avg_output_len * len(requests) / (total_time / US_TO_SEC) if total_time else 0 performance_metrics[0] = (performance_metrics[0][0], token_throughput) - # Tokens processed per second - performance_metrics[1] = ( - performance_metrics[1][0], - (avg_input_len + avg_output_len) * len(requests) / (total_time / US_TO_SEC), - ) - # Requests per second - performance_metrics[2] = ( - performance_metrics[2][0], - len(requests) / (total_time / US_TO_SEC), - ) - # Time to first token + performance_metrics[1] = (performance_metrics[1][0], (avg_input_len + avg_output_len) * len(requests) / (total_time / US_TO_SEC)) + performance_metrics[2] = (performance_metrics[2][0], len(requests) / (total_time / US_TO_SEC)) performance_metrics[3] = (performance_metrics[3][0], avg_ttft) - # TPOT after the first token(this is also known as inter-token latency) performance_metrics[4] = (performance_metrics[4][0], tpot) - # Calculate request percentiles - request_latencies = [ - sum(token_latencies) for token_latencies in request_token_gen_times.values() - ] - req_percentiles = req_percentiles + [50, 95] - req_percentiles.sort() + request_latencies = [sum(token_latencies) for token_latencies in request_token_gen_times.values()] + req_percentiles = sorted(req_percentiles + [50, 95]) for percentile in req_percentiles: percentile_val = np.percentile(request_latencies, percentile) / US_TO_SEC - performance_metrics.append( - ( - f"Request Completion Latency: {percentile}th percentile", - percentile_val, - ) - ) + performance_metrics.append((f"Request Completion Latency: {percentile}th percentile", percentile_val)) performance_metrics_units.append("sec") - # Calculate Avg MFU observed_throughput = token_throughput mfus: List[float] = [] - for index, seq_len in enumerate(seq_lens): - theoretical_throughput = theoretical_peak_flops / ( - 6 * num_parameters + 12 * num_layers * num_heads * head_dim * seq_len - ) + for seq_len in seq_lens: + theoretical_throughput = self.peak_flops / (6 * num_parameters + 12 * num_layers * num_heads * head_dim * seq_len) mfu = (observed_throughput / theoretical_throughput) * 100 mfus.append(mfu) - avg_mfu = np.mean(mfus) performance_metrics.append((f"Avg. 
MFU Per iteration", avg_mfu)) performance_metrics_units.append("%") - - # Append mbu - performance_metrics.append((f"MBU ", mbu)) + performance_metrics.append((f"MBU", mbu)) performance_metrics_units.append("%") - # Get SLO Metrics slo_metrics = calculate_slo_metrics(request_token_gen_times, slo_targets) - return performance_metrics, performance_metrics_units, slo_metrics def simulate( @@ -314,147 +213,148 @@ def simulate( token_percentiles: List[int] = [], slo_targets: List[int] = [], max_batch_size: int = 0, + distserve: bool = False, ) -> Optional[SimulatorOutput]: + if self.num_total_devices < 2: + raise ValueError("Need at least 2 GPUs for prefill and decode disaggregation") + parallel_schedule = execution_plan.parallel_schedule stage_schedule = parallel_schedule.stage_schedule num_stages = parallel_schedule.num_stages num_model_replicas = parallel_schedule.num_model_replicas - num_attn_cell_replicas = 0 - for cell_schedule in stage_schedule.cell_schedules: - if cell_schedule.cell.is_attn(): - num_attn_cell_replicas = cell_schedule.num_replicas - break - assert num_attn_cell_replicas > 0 - + num_attn_cell_replicas = next( + cell_schedule.num_replicas for cell_schedule in stage_schedule.cell_schedules if cell_schedule.cell.is_attn() + ) + # print(f"cluster: {self.cluster}") param_sizes = stage_schedule.get_param_size_per_device(self.dtype["w"]) - num_devices = len(param_sizes) - - available_memories = [ - self.gpu_memory - WORKSPACE - param_size for param_size in param_sizes - ] - if any(avail_mem < 0 for avail_mem in available_memories): - # Invalid. + if len(param_sizes) == 1: + return None, None + # print(f"param_sizes: {param_sizes}") + # print(f"self.prefill_gpu: {self.prefill_gpu}") + # print(f"self.decode_gpu: {self.decode_gpu}") + available_memory_prefill = self.gpu_memory - WORKSPACE - param_sizes[self.prefill_gpu] + available_memory_decode = self.gpu_memory - WORKSPACE - param_sizes[self.decode_gpu] + if any(avail_mem < 0 for avail_mem in [available_memory_prefill, available_memory_decode]): return None - min_available_memory = min(available_memories) + WORKSPACE + min_available_memory_prefill = available_memory_prefill + WORKSPACE + min_available_memory_decode = available_memory_decode + WORKSPACE param_size = max(param_sizes) - # Calculate the maximum number of tokens that can be stored in KV cache. - # This limits the maximum number of sequences that can be batched. - kv_token_sizes = ( - [1] * num_devices - if arch == "encoder" - else stage_schedule.get_kv_token_size_per_device(self.dtype["kv"]) - ) - - max_num_tokens = min( - int(available_memories[i] // kv_token_sizes[i]) for i in range(num_devices) - ) - # Evenly partition the KV cache for each stage. 
+ # KV cache on decode GPU only + kv_token_sizes = [0] * self.num_total_devices + kv_token_sizes[self.decode_gpu] = stage_schedule.get_kv_token_size_per_device(self.dtype["kv"])[0] + max_num_tokens = int(available_memory_decode // kv_token_sizes[self.decode_gpu]) max_num_tokens_per_stage = max_num_tokens // num_stages - # Statistics list_of_exe_time = [] num_reqs_per_iteration = [] num_tokens_per_iteration = [] request_token_gen_times = {} - - requests = [] - model_replica_time = [] - total_energy = 0.0 - # Copy to avoid altering the original traces copied_requests = copy.deepcopy(self.trace.requests) - ### Finished housekeeping; starts the actual simulation ### - - # Split a list of requests to n sublists, - # where n = num_model_replicas + # Split requests across model replicas model_requests = self.dispatch(copied_requests, num_model_replicas) + total_time = 0.0 + for model_replica in range(num_model_replicas): model_replica_energy = 0.0 - stage_iter_times = [] - stage_requests = self.dispatch(model_requests[model_replica], num_stages) - for stage in range(num_stages): + model_replica_time = [] + + # Prefill phase on prefill GPU + prefill_requests = copy.deepcopy(model_requests[model_replica]) + ( + prefill_updated_requests, + prefill_iter_times, + prefill_reqs_per_iter, + prefill_tokens_per_iter, + prefill_exe_time, + prefill_token_gen_times, + prefill_energy, + ) = self.sub_simulate( + execution_plan, + arch, + frequency, + prefill_requests, + num_attn_cell_replicas, + max_num_tokens_per_stage, + max_batch_size, + phase="prefill", + ) + if prefill_updated_requests is None: + return None - cell_replica_iter_times = [] - cell_requests = self.dispatch( - stage_requests[stage], num_attn_cell_replicas + # Calculate KV cache transfer time + kv_cache_size = sum( + 2 * model_config.num_layers * model_config.num_heads * (model_config.hidden_size // model_config.num_heads) * self.dtype["kv"].size * req.input_len + for req in prefill_updated_requests + ) + if self.num_total_nodes == 1: + kv_transfer_time = get_p2p_comm_time( + gpu=self.gpu, + num_nodes=1, + num_gpus_per_node=2, + dtype=self.dtype["kv"], + num_elements=kv_cache_size, ) - for cell_replica in range(num_attn_cell_replicas): - target_requests = cell_requests[cell_replica] - - ( - updated_requests, - cell_replica_iter_time, - reqs_per_iter, - tokens_per_iter, - exe_time, - request_token_gen_time, - stage_energy, - ) = self.sub_simulate( - execution_plan, - arch, - frequency, - target_requests, - num_attn_cell_replicas, - max_num_tokens_per_stage, - max_batch_size, - ) - - if updated_requests is None: - return None, None - requests.extend( - updated_requests - ) # timestamp updated with completion time - id_num = model_replica * stage * cell_replica - renamed_gen_time = { - f"{key}_{id_num}": value - for key, value in request_token_gen_time.items() - } - request_token_gen_times.update(renamed_gen_time) - cell_replica_iter_times.append(cell_replica_iter_time) - num_reqs_per_iteration.append(reqs_per_iter) - num_tokens_per_iteration.append(tokens_per_iter) - list_of_exe_time.append(exe_time) - model_replica_energy += ( - stage_energy // num_attn_cell_replicas * num_stages - ) - # Note: dividied by cell replicas as the energy scaling of cell is already handled in Line 693 - # Multiplied with num_stages because we only simulate one stage, but num_stages run concurrently - - # iteration time = slowest among the cell replicas - max_cell_iter_time = self.merge_max_elements(cell_replica_iter_times) - stage_iter_times.append(max_cell_iter_time) 
- - interleaved_list = [ - val - for pair in itertools.zip_longest(*stage_iter_times) - for val in pair - if val is not None - ] - - model_replica_iter_times = [] - if len(interleaved_list) <= num_stages: - model_replica_iter_times = interleaved_list.copy() else: - # Creates a sliding window to find the bottlenecked stage in the pipeline - for i in range(len(interleaved_list) - num_stages): - window = interleaved_list[i : i + num_stages] - model_replica_iter_times.append(max(window)) - model_replica_time.append(sum(model_replica_iter_times)) - total_energy += model_replica_energy + kv_transfer_time = get_p2p_comm_time( + gpu=self.gpu, + num_nodes=2, + num_gpus_per_node=1, + dtype=self.dtype["kv"], + num_elements=kv_cache_size, + ) - # Final execution time = the slowest among the replicas - total_time = max(model_replica_time) - ### Finished simulation; calculate the statistics of the results ### + # Decode phase on decode GPU + decode_requests = prefill_updated_requests + ( + decode_updated_requests, + decode_iter_times, + decode_reqs_per_iter, + decode_tokens_per_iter, + decode_exe_time, + decode_token_gen_times, + decode_energy, + ) = self.sub_simulate( + execution_plan, + arch, + frequency, + decode_requests, + num_attn_cell_replicas, + max_num_tokens_per_stage, + max_batch_size, + phase="decode", + ) + if decode_updated_requests is None: + return None + + # Combine statistics + renamed_gen_times = { + f"{key}_{model_replica}": prefill_token_gen_times.get(key, []) + decode_token_gen_times.get(key, []) + for key in set(list(prefill_token_gen_times.keys()) + list(decode_token_gen_times.keys())) + } + request_token_gen_times.update(renamed_gen_times) + # list_of_exe_time.extend(prefill_exe_time + [(name, t + kv_transfer_time) if name != "Wait" else (name, t) for name, t in decode_exe_time]) + # integrate the prefill and decode exe time + prefill_exe_time_dict = {name: t for name, t in prefill_exe_time} + decode_exe_time_dict = {name: t for name, t in decode_exe_time} + total_exe_time = [] + for name, t in decode_exe_time_dict.items(): + if name in prefill_exe_time_dict: + total_exe_time.append((name, prefill_exe_time_dict[name] + decode_exe_time_dict[name])) + total_exe_time.append(("KV Transfer", kv_transfer_time)) + list_of_exe_time.append(total_exe_time) + num_reqs_per_iteration.extend(prefill_reqs_per_iter + decode_reqs_per_iter) + num_tokens_per_iteration.extend(prefill_tokens_per_iter + decode_tokens_per_iter) + model_replica_energy += prefill_energy + decode_energy + kv_transfer_time * KV_CACHE_TRANSFER_POWER + model_replica_time.extend(prefill_iter_times + decode_iter_times + [kv_transfer_time]) + total_energy += model_replica_energy + total_time = max(total_time, sum(model_replica_time)) - avg_input_len = sum(request.input_len for request in copied_requests) / len( - copied_requests - ) - avg_output_len = sum(request.output_len for request in copied_requests) / len( - copied_requests - ) + avg_input_len = sum(request.input_len for request in copied_requests) / len(copied_requests) + avg_output_len = sum(request.output_len for request in copied_requests) / len(copied_requests) performance_metrics, performance_metrics_units, slo_metrics = self.get_metrics( model_config, avg_input_len, @@ -466,28 +366,24 @@ def simulate( token_percentiles, num_tokens_per_iteration, arch, - slo_targets) + slo_targets, + ) exe_stat_dict = {} idle_time = [] + # print(f"list_of_exe_time: {list_of_exe_time}") for exe_t in list_of_exe_time: - # Calculating idle time + # print(f"exe_t: {exe_t}") 
idle_time.append(total_time - sum(t * num_stages for _, t in exe_t)) summed_data = defaultdict(float) - # A function may be called multiple times in one iter, creating multiple entries - # merging these entries for each iter for key, value in exe_t: summed_data[key] += value result = [(key, value) for key, value in summed_data.items()] - # Modifying the data structure so it's easier to compute mean and std later for name, time in result: exe_stat_dict.setdefault(name, []).append(time * num_stages) - - # Compute the statistics exe_stat = [] for name in exe_stat_dict.keys(): exe_lst = exe_stat_dict.get(name) - # Wait time will be counted as Idle time if name == "Wait": for i, val in enumerate(exe_lst): idle_time[i] += val @@ -497,11 +393,10 @@ def simulate( exe_stat.append((name, exe_mean, exe_std)) exe_stat.append(("Idle", np.mean(idle_time), np.std(idle_time))) - requests = sorted(requests, key=lambda x: x.time_stamp) - + requests = sorted(copied_requests, key=lambda x: x.time_stamp) return requests, SimulatorOutput( param_size_per_device=param_size / GB, - available_memory_per_device=min_available_memory / GB, + available_memory_per_device=min_available_memory_decode / GB, num_requests_per_iteration=np.mean(num_reqs_per_iteration), num_tokens_per_iteration=np.mean(num_tokens_per_iteration), time_statistics=exe_stat, @@ -521,208 +416,223 @@ def sub_simulate( num_attn_cell_replicas: int, max_num_tokens_per_stage: int, max_batch_size: int, + phase: str = "combined", # "prefill" or "decode" ): parallel_schedule = execution_plan.parallel_schedule stage_schedule = parallel_schedule.stage_schedule num_stages = parallel_schedule.num_stages - - min_num_replicas = min( - cell_schedule.num_replicas - for cell_schedule in stage_schedule.cell_schedules - ) + min_num_replicas = min(cell_schedule.num_replicas for cell_schedule in stage_schedule.cell_schedules) num_cached_tokens = 0 req_counter = 0 - num_generated_tokens: Dict[int, int] = {} # request_id -> num_tokens - running: List[int] = [] # request_ids - stopped: List[int] = [] # request_ids - + num_generated_tokens: Dict[int, int] = {} + running: List[int] = [] + stopped: List[int] = [] get_seq_len = lambda request_id: ( - requests[request_id].input_len + num_generated_tokens[request_id] + requests[request_id].input_len + num_generated_tokens.get(request_id, 0) ) - # Statistics. execution_time: List[Tuple[str, float]] = [] - num_cells = len(stage_schedule.cell_schedules) - for i in range(num_cells): - cell = stage_schedule.cell_schedules[i].cell + for i, cell_schedule in enumerate(stage_schedule.cell_schedules): + cell = cell_schedule.cell execution_time.append((cell.get_name(), 0.0)) for comm in stage_schedule.reshard_comms[i]: execution_time.append((comm.comm_type.name, 0.0)) - if parallel_schedule.num_stages > 1: + if num_stages > 1: execution_time.append(("SendRecv", 0.0)) num_reqs_per_iteration: List[int] = [] num_tokens_per_iteration: List[int] = [] - # Simulate the execution. time_per_iteration: List[float] = [] - # Time metrics for each request, request id is the key and value of list is time per token request_token_gen_times: Dict[str, List[float]] = {} - internal_clock = 0 # decide whether a request has arrived - wait_next_req_time = 0 # the idle time of waiting for next request to come - energy = 0 # energy consumption + internal_clock = 0 + wait_next_req_time = 0 + energy = 0 + while True: - # Batch requests. 
input_lens: List[int] = [] cached_lens: List[int] = [] - new_running: List[int] = [] - while running: - request_id = running.pop(0) - while num_cached_tokens + 1 > max_num_tokens_per_stage: - if running: - victim = running.pop(-1) - stopped.append(victim) - num_cached_tokens -= get_seq_len(victim) - else: + + if phase == "prefill": + # Batch requests for prefill + while running: + request_id = running.pop(0) + seq_len = get_seq_len(request_id) + if num_cached_tokens + seq_len > max_num_tokens_per_stage: stopped.append(request_id) - num_cached_tokens -= get_seq_len(request_id) - break - else: - input_lens.append(1) - num_cached_tokens += 1 - cached_lens.append(num_generated_tokens[request_id] + 1) - new_running.append(request_id) - running = new_running - - # Resume the stopped requests. - # Sort in the order of request_id. - stopped = sorted(stopped) - while stopped: - request_id = stopped[0] - seq_len = get_seq_len(request_id) - if num_cached_tokens + seq_len + 1 > max_num_tokens_per_stage: - break - request_id = stopped.pop(0) - input_lens.append(1) - num_cached_tokens += seq_len + 1 - cached_lens.append(num_generated_tokens[request_id] + 1) - running.append(request_id) - - # Batch new requests. - if not stopped: + num_cached_tokens -= seq_len + else: + input_lens.append(requests[request_id].input_len) + cached_lens.append(0) + new_running.append(request_id) + running = new_running + + # Add new requests while req_counter < len(requests): request_id = req_counter input_len = requests[request_id].input_len - # If the KV cache does not have enough space, stop. if num_cached_tokens + input_len > max_num_tokens_per_stage: break - num_tokens = sum(input_lens) + input_len - # If the total number of tokens exceeds the maximum, stop. - if ( - num_tokens * num_attn_cell_replicas / min_num_replicas - > MAX_NUM_INPUT_TOKENS - ): + if num_tokens * num_attn_cell_replicas / min_num_replicas > MAX_NUM_INPUT_TOKENS: break - - curr_batch_size = len(running) - if(curr_batch_size == max_batch_size and max_batch_size != 0): + if len(running) == max_batch_size and max_batch_size != 0: break - - # Request has not yet arrived if requests[request_id].time_stamp > internal_clock: break - num_cached_tokens += input_len input_lens.append(input_len) cached_lens.append(0) running.append(request_id) - num_generated_tokens[request_id] = 0 req_counter += 1 - if not running: - if req_counter < len(requests): - # Cannot proceed. - # This can happen when the space for the KV cache is - # too small to store even a single sequence. - if num_cached_tokens + input_len > max_num_tokens_per_stage: - return None, None, None, None, None, None + if not running: + if req_counter < len(requests): + if num_cached_tokens + input_len > max_num_tokens_per_stage: + return None, None, None, None, None, None, None + wait_next_req_time += requests[req_counter].time_stamp - internal_clock + internal_clock = requests[req_counter].time_stamp else: - # Or because the requests are coming too slow; - # wait until next request comes. - if requests[req_counter].time_stamp > internal_clock: - wait_next_req_time += ( - requests[req_counter].time_stamp - internal_clock - ) - internal_clock = requests[req_counter].time_stamp - else: - return None, None, None, None, None, None - - else: - # All the requests are finished. - assert num_cached_tokens == 0, num_cached_tokens - assert not stopped, stopped - break - - # Record the number of requests and tokens. 
- num_reqs_per_iteration.append(len(running) * num_attn_cell_replicas) - num_tokens_per_iteration.append(sum(input_lens) * num_attn_cell_replicas) + break - # Get the execution time of a stage with the given input if running - if running: - stage_execution_time, stage_energy = self.get_stage_execution_time( - execution_plan.parallel_schedule.stage_schedule, - num_attn_cell_replicas, - input_lens, - cached_lens, - self.gpu, - frequency, - self.cluster_size_per_node, - ) - if num_stages > 1: - stage_execution_time.append( - self.get_cross_stage_comm_time( - sum(input_lens), - execution_plan.stage_clusters, - self.gpu, - self.cluster_size_per_node, - ) - ) - time_per_iteration.append(sum(stage_execution_time)) - internal_clock += sum(stage_execution_time) - energy += sum(stage_energy) - # Update the statistics. - for i in range(len(execution_time)): - execution_time[i] = ( - execution_time[i][0], - execution_time[i][1] + stage_execution_time[i], + # Execute prefill + if running: + stage_execution_time, stage_energy = self.get_stage_execution_time( + stage_schedule, + num_attn_cell_replicas, + input_lens, + cached_lens, + self.gpu, + frequency, + self.cluster_size_per_node, + phase="prefill", + gpu_id=self.prefill_gpu, ) - - # Remove finished requests from the batch. Update logged time per token - for request_id in running: - num_generated_tokens[request_id] += 1 - if num_generated_tokens[request_id] == 1: - request_token_gen_times[request_id] = [ - time_per_iteration[-1] * num_stages - ] - else: - request_token_gen_times[request_id].append( - time_per_iteration[-1] * num_stages + time_per_iteration.append(sum(stage_execution_time)) + internal_clock += sum(stage_execution_time) + energy += sum(stage_energy) + for i in range(len(execution_time)): + execution_time[i] = ( + execution_time[i][0], + execution_time[i][1] + stage_execution_time[i], ) - new_running: List[int] = [] - for request_id in running: - num_generated = num_generated_tokens[request_id] - if arch == "encoder": - output_len = 0 + + # Record TTFT + for request_id in running: + request_token_gen_times[str(request_id)] = [time_per_iteration[-1] * num_stages] + requests[request_id].time_stamp = internal_clock * num_stages + running = [] # Prefill complete, move to decode + + elif phase == "decode": + # Batch requests for decode + while running: + request_id = running.pop(0) + seq_len = get_seq_len(request_id) + if num_cached_tokens + seq_len + 1 > max_num_tokens_per_stage: + stopped.append(request_id) + num_cached_tokens -= seq_len else: - output_len = requests[request_id].output_len - if num_generated < output_len: + input_lens.append(1) + num_cached_tokens += 1 + cached_lens.append(num_generated_tokens.get(request_id, 0) + 1) new_running.append(request_id) - else: - # Finished processing; update the time_stamp to completed time. 
- num_cached_tokens -= get_seq_len(request_id) - 1 - requests[request_id].time_stamp = internal_clock * num_stages running = new_running - execution_time.append(("Wait", wait_next_req_time)) + # Resume stopped requests + stopped = sorted(stopped) + while stopped: + request_id = stopped[0] + seq_len = get_seq_len(request_id) + if num_cached_tokens + seq_len + 1 > max_num_tokens_per_stage: + break + request_id = stopped.pop(0) + input_lens.append(1) + num_cached_tokens += seq_len + 1 + cached_lens.append(num_generated_tokens.get(request_id, 0) + 1) + running.append(request_id) + # Add new requests + while req_counter < len(requests): + request_id = req_counter + input_len = requests[request_id].input_len + if num_cached_tokens + input_len + 1 > max_num_tokens_per_stage: + break + num_tokens = sum(input_lens) + 1 + if num_tokens * num_attn_cell_replicas / min_num_replicas > MAX_NUM_INPUT_TOKENS: + break + if len(running) >= max_batch_size and max_batch_size != 0: + break + if requests[request_id].time_stamp > internal_clock: + break + num_cached_tokens += input_len + 1 + input_lens.append(1) + cached_lens.append(num_generated_tokens.get(request_id, 0) + 1) + running.append(request_id) + num_generated_tokens[request_id] = 0 + req_counter += 1 + + if not running: + if req_counter < len(requests): + if num_cached_tokens + input_len > max_num_tokens_per_stage: + return None, None, None, None, None, None, None + wait_next_req_time += requests[req_counter].time_stamp - internal_clock + internal_clock = requests[req_counter].time_stamp + else: + break + + # Execute decode + if running: + stage_execution_time, stage_energy = self.get_stage_execution_time( + stage_schedule, + num_attn_cell_replicas, + input_lens, + cached_lens, + self.gpu, + frequency, + self.cluster_size_per_node, + phase="decode", + gpu_id=self.decode_gpu, + ) + time_per_iteration.append(sum(stage_execution_time)) + internal_clock += sum(stage_execution_time) + energy += sum(stage_energy) + for i in range(len(execution_time)): + execution_time[i] = ( + execution_time[i][0], + execution_time[i][1] + stage_execution_time[i], + ) + + # Update token generation + for request_id in running: + num_generated_tokens[request_id] = num_generated_tokens.get(request_id, 0) + 1 + if str(request_id) not in request_token_gen_times: + request_token_gen_times[str(request_id)] = [] + request_token_gen_times[str(request_id)].append(time_per_iteration[-1] * num_stages) + + # Remove finished requests + new_running = [] + for request_id in running: + if num_generated_tokens[request_id] < requests[request_id].output_len: + new_running.append(request_id) + else: + num_cached_tokens -= get_seq_len(request_id) - 1 + requests[request_id].time_stamp = internal_clock * num_stages + running = new_running + + else: + raise ValueError(f"Unsupported phase: {phase}") + + num_reqs_per_iteration.append(len(running) * num_attn_cell_replicas) + num_tokens_per_iteration.append(sum(input_lens) * num_attn_cell_replicas) + + execution_time.append(("Wait", wait_next_req_time)) return ( requests, time_per_iteration, - np.mean(num_reqs_per_iteration), - np.mean(num_tokens_per_iteration), + num_reqs_per_iteration, + num_tokens_per_iteration, execution_time, request_token_gen_times, energy, @@ -737,33 +647,23 @@ def get_stage_execution_time( gpu_type: str, frequency: int, cluster_size_per_node: int, - ) -> List[float]: - # Calculate the number of input tokens per cell. 
- num_total_input_tokens = ( - sum(input_lens_per_attn_replica) * num_attn_cell_replicas - ) - + phase: str = "combined", + gpu_id: int = 0, + ) -> Tuple[List[float], List[float]]: + num_total_input_tokens = sum(input_lens_per_attn_replica) * num_attn_cell_replicas execution_time: List[float] = [] execution_energy: List[float] = [] for i, cell_schedule in enumerate(stage_schedule.cell_schedules): - # Split the input tokens evenly among the replicas. num_replicas = cell_schedule.num_replicas - num_input_tokens = ( - num_total_input_tokens + num_replicas - 1 - ) // num_replicas + num_input_tokens = (num_total_input_tokens + num_replicas - 1) // num_replicas num_devices = cell_schedule.get_num_devices() - # For mixed precision, assume the data with lower precision - # will be dequantized to match data of higher precision comp_type = self.highest_prec() - # Cell execution. - # We leverage the fact that the 0-th device is always assigned the - # most number of tasks. cell_execution_time = 0.0 cell_execution_energy = 0.0 task_dict = cell_schedule.task_mapping.tasks_per_device[0] for task_type, tasks in task_dict.items(): - if task_type == "MHAHead" or task_type == "MQAHead": + if task_type in ["MHAHead", "MQAHead"]: exe_time, exe_energy = mha_time( gpu_type, frequency, @@ -806,8 +706,6 @@ def get_stage_execution_time( cell_execution_time += exe_time cell_execution_energy += exe_energy elif task_type.startswith("ExpertMLPFilter"): - # Each expert will get topk / E of the input tokens where E - # is the total number of experts. num_total_experts = cell_schedule.cell.num_experts topk = cell_schedule.cell.topk exe_time, exe_energy = mlp_time( @@ -836,16 +734,12 @@ def get_stage_execution_time( execution_time.append(cell_execution_time) execution_energy.append(cell_execution_energy * num_devices) - if ( - cell_schedule.cell.get_name() == "MoE" - or cell_schedule.cell.get_name() == "SwiMoE" - ): + if cell_schedule.cell.get_name() in ["MoE", "SwiMoE"]: if len(task_dict) < cell_schedule.cell.num_experts: num_devices = len(cell_schedule.task_mapping.tasks_per_device) num_input_tokens = max(num_input_tokens // num_devices, 1) hidden_size = self.model.hidden_size - # Resharding. for comm in stage_schedule.reshard_comms[i]: if comm.num_devices < cluster_size_per_node: num_nodes = 1 @@ -853,7 +747,6 @@ def get_stage_execution_time( else: num_nodes = comm.num_devices // cluster_size_per_node num_devices_per_node = cluster_size_per_node - num_input_tokens *= comm.size_factor num_input_tokens = max(num_input_tokens, 1) if comm.comm_type == CommType.AllReduce: @@ -867,9 +760,7 @@ def get_stage_execution_time( elif comm.comm_type == CommType.AllToAll: num_elements = num_input_tokens * hidden_size else: - raise NotImplementedError( - f"Unsupported comm type: {comm.comm_type}" - ) + raise NotImplementedError(f"Unsupported comm type: {comm.comm_type}") comm_time = get_comm_time( comm.comm_type, gpu_type, @@ -880,7 +771,6 @@ def get_stage_execution_time( ) execution_time.append(comm_time) - # Multiply the block execution time by the number of blocks. 
return [t * stage_schedule.num_blocks for t in execution_time], [ e * stage_schedule.num_blocks for e in execution_energy ] @@ -895,19 +785,10 @@ def get_cross_stage_comm_time( hidden_size = self.model.hidden_size num_total_devices = sum(cluster.get_num_devices() for cluster in stage_clusters) cross_node = num_total_devices > cluster_size_per_node - if cross_node: - return get_p2p_comm_time( - gpu=gpu_type, - num_nodes=2, - num_gpus_per_node=1, - dtype=self.dtype["act"], - num_elements=num_input_tokens * hidden_size, - ) - else: - return get_p2p_comm_time( - gpu=gpu_type, - num_nodes=1, - num_gpus_per_node=2, - dtype=self.dtype["act"], - num_elements=num_input_tokens * hidden_size, - ) + return get_p2p_comm_time( + gpu=gpu_type, + num_nodes=2 if cross_node else 1, + num_gpus_per_node=1, + dtype=self.dtype["act"], + num_elements=num_input_tokens * hidden_size, + ) \ No newline at end of file diff --git a/apex_plus/simulator/simulator_origin.py b/apex_plus/simulator/simulator_origin.py new file mode 100644 index 0000000..5a279f6 --- /dev/null +++ b/apex_plus/simulator/simulator_origin.py @@ -0,0 +1,914 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple +from collections import defaultdict + +import numpy as np +import itertools +import copy + +from apex_plus.cluster.cluster import Cluster +from apex_plus.execution.plan import ExecutionPlan +from apex_plus.ir.transformer import Transformer +from apex_plus.parallel.comm import CommType +from apex_plus.parallel.schedule import StageSchedule +from apex_plus.simulator.comm_profile import get_comm_time, get_p2p_comm_time +from apex_plus.simulator.comp_profile import mha_time, mlp_time, glu_time, swiglu_time +from apex_plus.simulator.trace import Trace, Request +from apex_plus.utils.dtype import DTYPE + +GB = 1024 * 1024 * 1024 +WORKSPACE = 1 * GB # a constant buffer for each device to run the program + +MAX_NUM_INPUT_TOKENS = 64 * 1024 # Max in profile/scripts/gemm.py + +US_TO_SEC = 1000000 +MS_TO_SEC = 1000 +US_TO_MS = 1000 + + +@dataclass +class SimulatorOutput: + + param_size_per_device: float + available_memory_per_device: float + num_requests_per_iteration: float + num_tokens_per_iteration: float + time_statistics: List[Tuple[str, float]] + performance_metrics: List[Tuple[str, float]] + performance_metrics_units: List[str] + slo_metrics: List[float] + total_time: float + total_energy: float + + +class Simulator: + + def __init__( + self, + model: Transformer, + cluster: Cluster, + trace: Trace, + dtype: dict, + ) -> None: + self.model = model + self.cluster = cluster + self.trace = trace + self.dtype = dtype + self.gpu = cluster.get_device().device_type + self.gpu_memory = cluster.get_device_memory_capacity() + self.peak_flops = cluster.get_device().peak_flops[self.highest_prec()] + self.peak_mem_bandwidth = cluster.get_device().peak_mem_bandwidth + self.num_total_nodes = cluster.get_num_nodes() + self.num_total_devices = cluster.get_num_devices() + self.cluster_size_per_node = self.num_total_devices // self.num_total_nodes + + def highest_prec(self) -> DTYPE: + data_type = [] + data_type.append(self.dtype["w"]) + data_type.append(self.dtype["kv"]) + data_type.append(self.dtype["act"]) + # Dealing with mixed precesion + # Assuming we dequantize the value for computation + highest_precision = DTYPE.FLOAT8 + if DTYPE.FLOAT16 in data_type: + highest_precision = DTYPE.FLOAT16 + if DTYPE.FLOAT32 in data_type: + highest_precision = DTYPE.FLOAT32 + return highest_precision + + def dispatch( + self, + requests: 
List[Request], + factor: int, + ) -> List[List[Request]]: + sublists = [[] for _ in range(factor)] + # Distribute elements in a round-robin fashion + # Can be replaced with more sophisticated strategy + for index, element in enumerate(requests): + sublists[index % factor].append(element) + return sublists + + def merge_max_elements(self, lists): + max_length = max(len(lst) for lst in lists) + extended_lists = [lst + [None] * (max_length - len(lst)) for lst in lists] + merged_list = [ + max(filter(lambda x: x is not None, elements)) + for elements in zip(*extended_lists) + ] + return merged_list + + def get_metrics( + self, + model_config: Transformer, + avg_input_len: int = 0, + avg_output_len: int = 0, + requests: List[Trace] = [], + total_time: float = 0.0, + request_token_gen_times: Dict[str, List[float]] = {}, + req_percentiles: List[int] = [], + token_percentiles: List[int] = [], + seq_lens: List[int] = [], + arch: str = '', + slo_targets: List[int] = [], + ): + + + def calculate_tbt_percentiles(latency_dict): + token_latencies_per_request = [latency for latency in latency_dict.values()] + avg_tbt_vals = [] + percentile_vals = [] + + # Calculate avg TBT per request + for latency_list in token_latencies_per_request: + list_length = len(latency_list) + avg_tbt = 0 + if list_length > 1: + ttft = latency_list[0] + ttlt = sum(latency_list) + avg_tbt = (ttlt - ttft) / list_length + avg_tbt_vals.append(avg_tbt) + + # Calculate all necessary percentiles for tbt + for percentile in token_percentiles: + percentile_vals.append(np.percentile(avg_tbt_vals, percentile)) + + return percentile_vals + + def calculate_slo_metrics(latency_dict, slo_targets): + token_latencies_per_request = [latency for latency in latency_dict.values()] + num_requests = len(token_latencies_per_request) + ttft_target = slo_targets[0] + tpot_target = slo_targets[1] + slo_metrics = [] + + # Calculate Percentage of requests that are <= TTFT_SLO + # TTFT is just the first token in the latency + ttft_slo_counter = 0 + for latency_list in token_latencies_per_request: + if(latency_list[0]/US_TO_MS <= ttft_target): + ttft_slo_counter += 1 + slo_metrics.append( (ttft_slo_counter/num_requests) * 100 ) + + # Calculate Percentage of requests that have an avg TPOT <= TPOT_SLO + tpot_slo_counter = 0 + tok_latencies_per_req_after_first_tok = [sublist[1:] for sublist in token_latencies_per_request] + for latency_list in token_gen_times: + # Calculate Avg TPOT per request + avg_tpot = np.mean(latency_list)/US_TO_MS + if(avg_tpot <= tpot_target): + tpot_slo_counter += 1 + slo_metrics.append( (tpot_slo_counter/num_requests) * 100 ) + + return slo_metrics + + # Store performance metrics - Time to first token, TPOT ,P50, P95, & other latencies + performance_metrics: List[Tuple[str, float]] = [] + performance_metrics_units: List[str] = [] + performance_metrics.append( + ("Throughput: Avg. Tokens generated per second", float("NaN")) + ) + performance_metrics.append( + ("Throughput: Avg. Tokens processed per second", float("NaN")) + ) + performance_metrics.append(("Throughput: Requests per second", float("NaN"))) + performance_metrics.append( + ("Latency: Avg. Time to first token (TTFT in msec)", float("NaN")) + ) + performance_metrics.append( + ("Latency: Avg. 
Time per output token (TPOT in msec)", float("NaN")) + ) + performance_metrics_units += [ + "tokens/sec", + "tokens/sec", + "requests/sec", + "msec", + "msec", + ] + + num_layers = 0 + num_heads = 0 + head_dim = 0 + hidden_size = 0 + theoretical_peak_flops = self.peak_flops + + if hasattr(model_config, "num_layers"): + num_layers = model_config.num_layers + num_heads = model_config.num_heads + hidden_size = model_config.hidden_size + elif hasattr(model_config, "num_decoder_layers"): + num_layers = model_config.num_decoder_layers + num_heads = model_config.num_decoder_heads + hidden_size = model_config.decoder_hidden_size + else: + raise ValueError("Unable to get model layers, heads, or hidden size") + head_dim = hidden_size // num_heads + + num_parameters = num_layers * hidden_size * hidden_size * 12 + + tpot = 0.0 + avg_ttft = 0.0 + mbu = 0.0 + # If encoder, there the output_len is 0 + if arch == "encoder": + avg_output_len = 0.0 + # Decoders that generate tokens + else: + # TPOT after the first token(this is also known as inter-token latency) + token_gen_times = [value[1:] for value in request_token_gen_times.values()] + flat_token_gen_times = [ + item for sublist in token_gen_times for item in sublist + ] + tpot = np.mean(flat_token_gen_times) / MS_TO_SEC + avg_ttft = ( + np.mean([value[0] for value in request_token_gen_times.values()]) + / US_TO_MS + ) + + # Calculate token percentiles + token_percentiles = token_percentiles + [50, 95] + token_percentiles.sort() + # Avg percentile vals are returned in order of sorted percentiles + avg_percentile_vals = calculate_tbt_percentiles(request_token_gen_times) + # Add to performance_metrics + for index, percentile in enumerate(token_percentiles): + performance_metrics.append( + ( + f"Avg. TBT Percentile: P{percentile}", + avg_percentile_vals[index] / US_TO_MS, + ) + ) + performance_metrics_units.append("msec") + # MBU + kv_cache_size = ( + 2 * num_layers * num_heads * head_dim * self.dtype["kv"].size + ) + tpot_sec = tpot / MS_TO_SEC + theoretical_peak_mem_bandwidth = self.peak_mem_bandwidth + observed_mem_bandwidth = (num_parameters + kv_cache_size) / tpot_sec + mbu = (observed_mem_bandwidth / theoretical_peak_mem_bandwidth) * 100 + + # Tokens gen per second + token_throughput = avg_output_len * len(requests) / (total_time / US_TO_SEC) + performance_metrics[0] = (performance_metrics[0][0], token_throughput) + # Tokens processed per second + performance_metrics[1] = ( + performance_metrics[1][0], + (avg_input_len + avg_output_len) * len(requests) / (total_time / US_TO_SEC), + ) + # Requests per second + performance_metrics[2] = ( + performance_metrics[2][0], + len(requests) / (total_time / US_TO_SEC), + ) + # Time to first token + performance_metrics[3] = (performance_metrics[3][0], avg_ttft) + # TPOT after the first token(this is also known as inter-token latency) + performance_metrics[4] = (performance_metrics[4][0], tpot) + # Calculate request percentiles + + request_latencies = [ + sum(token_latencies) for token_latencies in request_token_gen_times.values() + ] + req_percentiles = req_percentiles + [50, 95] + req_percentiles.sort() + for percentile in req_percentiles: + percentile_val = np.percentile(request_latencies, percentile) / US_TO_SEC + performance_metrics.append( + ( + f"Request Completion Latency: {percentile}th percentile", + percentile_val, + ) + ) + performance_metrics_units.append("sec") + + # Calculate Avg MFU + observed_throughput = token_throughput + mfus: List[float] = [] + for index, seq_len in enumerate(seq_lens): + 
theoretical_throughput = theoretical_peak_flops / ( + 6 * num_parameters + 12 * num_layers * num_heads * head_dim * seq_len + ) + mfu = (observed_throughput / theoretical_throughput) * 100 + mfus.append(mfu) + + avg_mfu = np.mean(mfus) + performance_metrics.append((f"Avg. MFU Per iteration", avg_mfu)) + performance_metrics_units.append("%") + + # Append mbu + performance_metrics.append((f"MBU ", mbu)) + performance_metrics_units.append("%") + + # Get SLO Metrics + slo_metrics = calculate_slo_metrics(request_token_gen_times, slo_targets) + + return performance_metrics, performance_metrics_units, slo_metrics + + def simulate( + self, + execution_plan: ExecutionPlan, + arch: str, + frequency: int, + model_config: Transformer, + req_percentiles: List[int] = [], + token_percentiles: List[int] = [], + slo_targets: List[int] = [], + max_batch_size: int = 0, + distserve: bool = False, + ) -> Optional[SimulatorOutput]: + parallel_schedule = execution_plan.parallel_schedule + stage_schedule = parallel_schedule.stage_schedule + num_stages = parallel_schedule.num_stages + num_model_replicas = parallel_schedule.num_model_replicas + num_attn_cell_replicas = 0 + for cell_schedule in stage_schedule.cell_schedules: + if cell_schedule.cell.is_attn(): + num_attn_cell_replicas = cell_schedule.num_replicas + break + assert num_attn_cell_replicas > 0 + + param_sizes = stage_schedule.get_param_size_per_device(self.dtype["w"]) + num_devices = len(param_sizes) + + available_memories = [ + self.gpu_memory - WORKSPACE - param_size for param_size in param_sizes + ] + if any(avail_mem < 0 for avail_mem in available_memories): + # Invalid. + return None + min_available_memory = min(available_memories) + WORKSPACE + param_size = max(param_sizes) + + # Calculate the maximum number of tokens that can be stored in KV cache. + # This limits the maximum number of sequences that can be batched. + kv_token_sizes = ( + [1] * num_devices + if arch == "encoder" + else stage_schedule.get_kv_token_size_per_device(self.dtype["kv"]) + ) + + max_num_tokens = min( + int(available_memories[i] // kv_token_sizes[i]) for i in range(num_devices) + ) + # Evenly partition the KV cache for each stage. 
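# Illustrative sketch, not part of the patch: the KV-cache token budget rule
# used above. Each device's free memory (capacity minus workspace and weights)
# divided by the per-token KV size bounds its token count; the budget is the
# minimum across devices, then split evenly over pipeline stages. All names
# and numbers below are hypothetical, not taken from a real profile.
def kv_token_budget(gpu_memory, workspace, param_sizes, kv_token_sizes, num_stages):
    available = [gpu_memory - workspace - p for p in param_sizes]
    if any(a < 0 for a in available):
        return None  # the weights alone do not fit
    max_tokens = min(int(a // k) for a, k in zip(available, kv_token_sizes))
    return max_tokens // num_stages

_GB = 1024 ** 3
# e.g. 80 GB devices, 1 GB workspace, 35 GB of weights and ~160 KB of KV cache
# per token on each of 2 devices, with 2 pipeline stages:
print(kv_token_budget(80 * _GB, 1 * _GB, [35 * _GB] * 2, [160 * 1024] * 2, num_stages=2))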
+ max_num_tokens_per_stage = max_num_tokens // num_stages + + # Statistics + list_of_exe_time = [] + num_reqs_per_iteration = [] + num_tokens_per_iteration = [] + request_token_gen_times = {} + + requests = [] + model_replica_time = [] + + total_energy = 0.0 + # Copy to avoid altering the original traces + copied_requests = copy.deepcopy(self.trace.requests) + + ### Finished housekeeping; starts the actual simulation ### + + # Split a list of requests to n sublists, + # where n = num_model_replicas + model_requests = self.dispatch(copied_requests, num_model_replicas) + for model_replica in range(num_model_replicas): + model_replica_energy = 0.0 + stage_iter_times = [] + stage_requests = self.dispatch(model_requests[model_replica], num_stages) + for stage in range(num_stages): + + cell_replica_iter_times = [] + cell_requests = self.dispatch( + stage_requests[stage], num_attn_cell_replicas + ) + for cell_replica in range(num_attn_cell_replicas): + target_requests = cell_requests[cell_replica] + + ( + updated_requests, + cell_replica_iter_time, + reqs_per_iter, + tokens_per_iter, + exe_time, + request_token_gen_time, + stage_energy, + ) = self.sub_simulate( + execution_plan, + arch, + frequency, + target_requests, + num_attn_cell_replicas, + max_num_tokens_per_stage, + max_batch_size, + ) + + if updated_requests is None: + return None, None + requests.extend( + updated_requests + ) # timestamp updated with completion time + id_num = model_replica * stage * cell_replica + renamed_gen_time = { + f"{key}_{id_num}": value + for key, value in request_token_gen_time.items() + } + request_token_gen_times.update(renamed_gen_time) + cell_replica_iter_times.append(cell_replica_iter_time) + num_reqs_per_iteration.append(reqs_per_iter) + num_tokens_per_iteration.append(tokens_per_iter) + list_of_exe_time.append(exe_time) + model_replica_energy += ( + stage_energy // num_attn_cell_replicas * num_stages + ) + # Note: dividied by cell replicas as the energy scaling of cell is already handled in Line 693 + # Multiplied with num_stages because we only simulate one stage, but num_stages run concurrently + + # iteration time = slowest among the cell replicas + max_cell_iter_time = self.merge_max_elements(cell_replica_iter_times) + stage_iter_times.append(max_cell_iter_time) + + interleaved_list = [ + val + for pair in itertools.zip_longest(*stage_iter_times) + for val in pair + if val is not None + ] + + model_replica_iter_times = [] + if len(interleaved_list) <= num_stages: + model_replica_iter_times = interleaved_list.copy() + else: + # Creates a sliding window to find the bottlenecked stage in the pipeline + for i in range(len(interleaved_list) - num_stages): + window = interleaved_list[i : i + num_stages] + model_replica_iter_times.append(max(window)) + model_replica_time.append(sum(model_replica_iter_times)) + total_energy += model_replica_energy + + # Final execution time = the slowest among the replicas + total_time = max(model_replica_time) + + ### Finished simulation; calculate the statistics of the results ### + + avg_input_len = sum(request.input_len for request in copied_requests) / len( + copied_requests + ) + avg_output_len = sum(request.output_len for request in copied_requests) / len( + copied_requests + ) + performance_metrics, performance_metrics_units, slo_metrics = self.get_metrics( + model_config, + avg_input_len, + avg_output_len, + copied_requests, + total_time, + request_token_gen_times, + req_percentiles, + token_percentiles, + num_tokens_per_iteration, + arch, + slo_targets) + + 
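# Illustrative sketch, not part of the patch: the pipeline-time model used
# above. Per-stage iteration times are interleaved and each pipeline step is
# bounded by the slowest stage within a window of num_stages consecutive
# entries. Toy numbers only.
import itertools

def pipeline_iteration_times(stage_iter_times, num_stages):
    interleaved = [
        v
        for pair in itertools.zip_longest(*stage_iter_times)
        for v in pair
        if v is not None
    ]
    if len(interleaved) <= num_stages:
        return interleaved
    return [
        max(interleaved[i : i + num_stages])
        for i in range(len(interleaved) - num_stages)
    ]

# Two stages, three iterations each; the second stage is the bottleneck.
print(pipeline_iteration_times([[10.0, 11.0, 10.0], [14.0, 15.0, 13.0]], num_stages=2))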
exe_stat_dict = {} + idle_time = [] + for exe_t in list_of_exe_time: + # Calculating idle time + idle_time.append(total_time - sum(t * num_stages for _, t in exe_t)) + summed_data = defaultdict(float) + # A function may be called multiple times in one iter, creating multiple entries + # merging these entries for each iter + for key, value in exe_t: + summed_data[key] += value + result = [(key, value) for key, value in summed_data.items()] + # Modifying the data structure so it's easier to compute mean and std later + for name, time in result: + exe_stat_dict.setdefault(name, []).append(time * num_stages) + + # Compute the statistics + exe_stat = [] + for name in exe_stat_dict.keys(): + exe_lst = exe_stat_dict.get(name) + # Wait time will be counted as Idle time + if name == "Wait": + for i, val in enumerate(exe_lst): + idle_time[i] += val + continue + exe_mean = np.mean(exe_lst) + exe_std = np.std(exe_lst) + exe_stat.append((name, exe_mean, exe_std)) + exe_stat.append(("Idle", np.mean(idle_time), np.std(idle_time))) + + requests = sorted(requests, key=lambda x: x.time_stamp) + + return requests, SimulatorOutput( + param_size_per_device=param_size / GB, + available_memory_per_device=min_available_memory / GB, + num_requests_per_iteration=np.mean(num_reqs_per_iteration), + num_tokens_per_iteration=np.mean(num_tokens_per_iteration), + time_statistics=exe_stat, + performance_metrics=performance_metrics, + performance_metrics_units=performance_metrics_units, + slo_metrics=slo_metrics, + total_time=total_time, + total_energy=total_energy, + ) + + def sub_simulate( + self, + execution_plan: ExecutionPlan, + arch: str, + frequency: int, + requests: List[Request], + num_attn_cell_replicas: int, + max_num_tokens_per_stage: int, + max_batch_size: int, + ): + parallel_schedule = execution_plan.parallel_schedule + stage_schedule = parallel_schedule.stage_schedule + num_stages = parallel_schedule.num_stages + + min_num_replicas = min( + cell_schedule.num_replicas + for cell_schedule in stage_schedule.cell_schedules + ) + + num_cached_tokens = 0 + req_counter = 0 + num_generated_tokens: Dict[int, int] = {} # request_id -> num_tokens + running: List[int] = [] # request_ids + stopped: List[int] = [] # request_ids + + get_seq_len = lambda request_id: ( + requests[request_id].input_len + num_generated_tokens[request_id] + ) + + # Statistics. + execution_time: List[Tuple[str, float]] = [] + num_cells = len(stage_schedule.cell_schedules) + for i in range(num_cells): + cell = stage_schedule.cell_schedules[i].cell + execution_time.append((cell.get_name(), 0.0)) + for comm in stage_schedule.reshard_comms[i]: + execution_time.append((comm.comm_type.name, 0.0)) + if parallel_schedule.num_stages > 1: + execution_time.append(("SendRecv", 0.0)) + + num_reqs_per_iteration: List[int] = [] + num_tokens_per_iteration: List[int] = [] + # Simulate the execution. + time_per_iteration: List[float] = [] + # Time metrics for each request, request id is the key and value of list is time per token + request_token_gen_times: Dict[str, List[float]] = {} + internal_clock = 0 # decide whether a request has arrived + wait_next_req_time = 0 # the idle time of waiting for next request to come + energy = 0 # energy consumption + while True: + # Batch requests. 
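# Illustrative sketch, not part of the patch: the per-function time-statistics
# aggregation used above, omitting the stage scaling and the Wait/Idle
# bookkeeping. (name, time) pairs are summed per name within each iteration,
# then mean and standard deviation are taken across iterations. Toy values.
from collections import defaultdict
import numpy as np

def summarize(list_of_exe_time):
    per_name = {}
    for exe_t in list_of_exe_time:          # one list of (name, time) per iteration
        summed = defaultdict(float)
        for name, t in exe_t:               # a name may appear several times per iteration
            summed[name] += t
        for name, t in summed.items():
            per_name.setdefault(name, []).append(t)
    return [(name, float(np.mean(ts)), float(np.std(ts))) for name, ts in per_name.items()]

print(summarize([[("MHA", 2.0), ("MLP", 3.0), ("MHA", 1.0)],
                 [("MHA", 2.5), ("MLP", 3.5)]]))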
+ input_lens: List[int] = [] + cached_lens: List[int] = [] + + new_running: List[int] = [] + while running: + request_id = running.pop(0) + while num_cached_tokens + 1 > max_num_tokens_per_stage: + if running: + victim = running.pop(-1) + stopped.append(victim) + num_cached_tokens -= get_seq_len(victim) + else: + stopped.append(request_id) + num_cached_tokens -= get_seq_len(request_id) + break + else: + input_lens.append(1) + num_cached_tokens += 1 + cached_lens.append(num_generated_tokens[request_id] + 1) + new_running.append(request_id) + running = new_running + + # Resume the stopped requests. + # Sort in the order of request_id. + stopped = sorted(stopped) + while stopped: + request_id = stopped[0] + seq_len = get_seq_len(request_id) + if num_cached_tokens + seq_len + 1 > max_num_tokens_per_stage: + break + request_id = stopped.pop(0) + input_lens.append(1) + num_cached_tokens += seq_len + 1 + cached_lens.append(num_generated_tokens[request_id] + 1) + running.append(request_id) + + # Batch new requests. + if not stopped: + while req_counter < len(requests): + request_id = req_counter + input_len = requests[request_id].input_len + # If the KV cache does not have enough space, stop. + if num_cached_tokens + input_len > max_num_tokens_per_stage: + break + + num_tokens = sum(input_lens) + input_len + # If the total number of tokens exceeds the maximum, stop. + if ( + num_tokens * num_attn_cell_replicas / min_num_replicas + > MAX_NUM_INPUT_TOKENS + ): + break + + curr_batch_size = len(running) + if(curr_batch_size == max_batch_size and max_batch_size != 0): + break + + # Request has not yet arrived + if requests[request_id].time_stamp > internal_clock: + break + + num_cached_tokens += input_len + input_lens.append(input_len) + cached_lens.append(0) + running.append(request_id) + + num_generated_tokens[request_id] = 0 + req_counter += 1 + + if not running: + if req_counter < len(requests): + # Cannot proceed. + # This can happen when the space for the KV cache is + # too small to store even a single sequence. + if num_cached_tokens + input_len > max_num_tokens_per_stage: + return None, None, None, None, None, None + else: + # Or because the requests are coming too slow; + # wait until next request comes. + if requests[req_counter].time_stamp > internal_clock: + wait_next_req_time += ( + requests[req_counter].time_stamp - internal_clock + ) + internal_clock = requests[req_counter].time_stamp + else: + return None, None, None, None, None, None + + else: + # All the requests are finished. + assert num_cached_tokens == 0, num_cached_tokens + assert not stopped, stopped + break + + # Record the number of requests and tokens. + num_reqs_per_iteration.append(len(running) * num_attn_cell_replicas) + num_tokens_per_iteration.append(sum(input_lens) * num_attn_cell_replicas) + + # Get the execution time of a stage with the given input if running + if running: + stage_execution_time, stage_energy = self.get_stage_execution_time( + execution_plan.parallel_schedule.stage_schedule, + num_attn_cell_replicas, + input_lens, + cached_lens, + self.gpu, + frequency, + self.cluster_size_per_node, + ) + if num_stages > 1: + stage_execution_time.append( + self.get_cross_stage_comm_time( + sum(input_lens), + execution_plan.stage_clusters, + self.gpu, + self.cluster_size_per_node, + ) + ) + time_per_iteration.append(sum(stage_execution_time)) + internal_clock += sum(stage_execution_time) + energy += sum(stage_energy) + # Update the statistics. 
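# Illustrative sketch, not part of the patch: the admission checks applied
# above before a new request joins the batch, simplified to a single attention
# replica (the real code scales the token ceiling by the replica counts).
# All names and numbers are hypothetical.
def can_admit(input_len, num_cached_tokens, batched_tokens, batch_size,
              arrival_time, clock, max_tokens_per_stage, max_input_tokens,
              max_batch_size=0):
    if num_cached_tokens + input_len > max_tokens_per_stage:
        return False  # KV cache would overflow
    if batched_tokens + input_len > max_input_tokens:
        return False  # exceeds the profiled per-iteration token limit
    if max_batch_size and batch_size >= max_batch_size:
        return False  # batch is already at the configured size
    if arrival_time > clock:
        return False  # request has not arrived yet
    return True

print(can_admit(input_len=512, num_cached_tokens=30_000, batched_tokens=258,
                batch_size=3, arrival_time=0.0, clock=10.0,
                max_tokens_per_stage=32_768, max_input_tokens=65_536))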
+ for i in range(len(execution_time)): + execution_time[i] = ( + execution_time[i][0], + execution_time[i][1] + stage_execution_time[i], + ) + + # Remove finished requests from the batch. Update logged time per token + for request_id in running: + num_generated_tokens[request_id] += 1 + if num_generated_tokens[request_id] == 1: + request_token_gen_times[request_id] = [ + time_per_iteration[-1] * num_stages + ] + else: + request_token_gen_times[request_id].append( + time_per_iteration[-1] * num_stages + ) + new_running: List[int] = [] + for request_id in running: + num_generated = num_generated_tokens[request_id] + if arch == "encoder": + output_len = 0 + else: + output_len = requests[request_id].output_len + if num_generated < output_len: + new_running.append(request_id) + else: + # Finished processing; update the time_stamp to completed time. + num_cached_tokens -= get_seq_len(request_id) - 1 + requests[request_id].time_stamp = internal_clock * num_stages + running = new_running + + execution_time.append(("Wait", wait_next_req_time)) + + return ( + requests, + time_per_iteration, + np.mean(num_reqs_per_iteration), + np.mean(num_tokens_per_iteration), + execution_time, + request_token_gen_times, + energy, + ) + + def get_stage_execution_time( + self, + stage_schedule: StageSchedule, + num_attn_cell_replicas: int, + input_lens_per_attn_replica: List[int], + cached_lens_per_attn_replica: List[int], + gpu_type: str, + frequency: int, + cluster_size_per_node: int, + ) -> List[float]: + # Calculate the number of input tokens per cell. + num_total_input_tokens = ( + sum(input_lens_per_attn_replica) * num_attn_cell_replicas + ) + + execution_time: List[float] = [] + execution_energy: List[float] = [] + for i, cell_schedule in enumerate(stage_schedule.cell_schedules): + # Split the input tokens evenly among the replicas. + num_replicas = cell_schedule.num_replicas + num_input_tokens = ( + num_total_input_tokens + num_replicas - 1 + ) // num_replicas + num_devices = cell_schedule.get_num_devices() + # For mixed precision, assume the data with lower precision + # will be dequantized to match data of higher precision + comp_type = self.highest_prec() + + # Cell execution. + # We leverage the fact that the 0-th device is always assigned the + # most number of tasks. 
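# Illustrative sketch, not part of the patch: the ceiling-division split of
# the batched tokens across a cell's replicas used above. The busiest replica
# (and, within it, the 0-th device, which holds the most tasks) therefore sets
# the cell's execution time. Toy values only.
def tokens_per_replica(num_total_input_tokens: int, num_replicas: int) -> int:
    return (num_total_input_tokens + num_replicas - 1) // num_replicas

assert tokens_per_replica(1000, 3) == 334   # 334 + 333 + 333
assert tokens_per_replica(64, 8) == 8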
+ cell_execution_time = 0.0 + cell_execution_energy = 0.0 + task_dict = cell_schedule.task_mapping.tasks_per_device[0] + for task_type, tasks in task_dict.items(): + if task_type == "MHAHead" or task_type == "MQAHead": + exe_time, exe_energy = mha_time( + gpu_type, + frequency, + tasks, + comp_type, + input_lens_per_attn_replica, + cached_lens_per_attn_replica, + True, + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type == "BiMHAHead": + exe_time, exe_energy = mha_time( + gpu_type, + frequency, + tasks, + comp_type, + input_lens_per_attn_replica, + cached_lens_per_attn_replica, + False, + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type == "MLPFilter": + exe_time, exe_energy = mlp_time( + gpu_type, frequency, tasks, comp_type, num_input_tokens + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type == "GLUFilter": + exe_time, exe_energy = glu_time( + gpu_type, frequency, tasks, comp_type, num_input_tokens + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type == "SwiGLUFilter": + exe_time, exe_energy = swiglu_time( + gpu_type, frequency, tasks, comp_type, num_input_tokens + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type.startswith("ExpertMLPFilter"): + # Each expert will get topk / E of the input tokens where E + # is the total number of experts. + num_total_experts = cell_schedule.cell.num_experts + topk = cell_schedule.cell.topk + exe_time, exe_energy = mlp_time( + gpu_type, + frequency, + tasks, + comp_type, + max(num_input_tokens * topk // num_total_experts, 1), + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + elif task_type.startswith("ExpertSwiGLUFilter"): + num_total_experts = cell_schedule.cell.num_experts + topk = cell_schedule.cell.topk + exe_time, exe_energy = swiglu_time( + gpu_type, + frequency, + tasks, + comp_type, + max(num_input_tokens * topk // num_total_experts, 1), + ) + cell_execution_time += exe_time + cell_execution_energy += exe_energy + else: + raise ValueError(f"Unsupported task type: {task_type}") + execution_time.append(cell_execution_time) + execution_energy.append(cell_execution_energy * num_devices) + + if ( + cell_schedule.cell.get_name() == "MoE" + or cell_schedule.cell.get_name() == "SwiMoE" + ): + if len(task_dict) < cell_schedule.cell.num_experts: + num_devices = len(cell_schedule.task_mapping.tasks_per_device) + num_input_tokens = max(num_input_tokens // num_devices, 1) + + hidden_size = self.model.hidden_size + # Resharding. 
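# Illustrative sketch, not part of the patch: the MoE routing estimate used
# above. With top-k routing over E experts, each expert processes roughly
# tokens * k / E of the batch, never rounded below a single token. Toy values.
def tokens_per_expert(num_input_tokens: int, topk: int, num_experts: int) -> int:
    return max(num_input_tokens * topk // num_experts, 1)

assert tokens_per_expert(4096, topk=2, num_experts=8) == 1024
assert tokens_per_expert(3, topk=1, num_experts=8) == 1   # floor clamped to 1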
+ for comm in stage_schedule.reshard_comms[i]: + if comm.num_devices < cluster_size_per_node: + num_nodes = 1 + num_devices_per_node = comm.num_devices + else: + num_nodes = comm.num_devices // cluster_size_per_node + num_devices_per_node = cluster_size_per_node + + num_input_tokens *= comm.size_factor + num_input_tokens = max(num_input_tokens, 1) + if comm.comm_type == CommType.AllReduce: + num_elements = num_input_tokens * hidden_size + elif comm.comm_type == CommType.AllGather: + num_elements = num_input_tokens * comm.num_devices * hidden_size + num_input_tokens *= comm.num_devices + elif comm.comm_type == CommType.ReduceScatter: + num_elements = num_input_tokens * hidden_size + num_input_tokens = max(num_input_tokens // comm.num_devices, 1) + elif comm.comm_type == CommType.AllToAll: + num_elements = num_input_tokens * hidden_size + else: + raise NotImplementedError( + f"Unsupported comm type: {comm.comm_type}" + ) + comm_time = get_comm_time( + comm.comm_type, + gpu_type, + num_nodes, + num_devices_per_node, + self.dtype["act"], + num_elements, + ) + execution_time.append(comm_time) + + # Multiply the block execution time by the number of blocks. + return [t * stage_schedule.num_blocks for t in execution_time], [ + e * stage_schedule.num_blocks for e in execution_energy + ] + + def get_cross_stage_comm_time( + self, + num_input_tokens: int, + stage_clusters: List[Cluster], + gpu_type: str, + cluster_size_per_node: int, + ) -> float: + hidden_size = self.model.hidden_size + num_total_devices = sum(cluster.get_num_devices() for cluster in stage_clusters) + cross_node = num_total_devices > cluster_size_per_node + if cross_node: + return get_p2p_comm_time( + gpu=gpu_type, + num_nodes=2, + num_gpus_per_node=1, + dtype=self.dtype["act"], + num_elements=num_input_tokens * hidden_size, + ) + else: + return get_p2p_comm_time( + gpu=gpu_type, + num_nodes=1, + num_gpus_per_node=2, + dtype=self.dtype["act"], + num_elements=num_input_tokens * hidden_size, + ) diff --git a/apex_plus/utils/__pycache__/dtype.cpython-312.pyc b/apex_plus/utils/__pycache__/dtype.cpython-312.pyc new file mode 100644 index 0000000..07324d2 Binary files /dev/null and b/apex_plus/utils/__pycache__/dtype.cpython-312.pyc differ diff --git a/main.py b/main.py index e2ec5ee..1c5f347 100644 --- a/main.py +++ b/main.py @@ -44,9 +44,13 @@ def main(args: argparse.Namespace): model, model_config = get_model_ir( args.model, args.num_experts, args.topk, args.capacity_factor ) - - encoder_cluster = Cluster.from_gpu(args.gpu, args.num_nodes, 1) - cluster = Cluster.from_gpu(args.gpu, args.num_nodes, args.num_gpus_per_node) + # handle distserve + if args.distserve: + encoder_cluster = Cluster.from_gpu(args.gpu, 2, 1) + cluster = Cluster.from_gpu(args.gpu, 1, 2) + else: + encoder_cluster = Cluster.from_gpu(args.gpu, args.num_nodes, 1) + cluster = Cluster.from_gpu(args.gpu, args.num_nodes, args.num_gpus_per_node) if args.trace_file: trace = Trace.from_dynamic(args.trace_file) @@ -74,6 +78,7 @@ def main(args: argparse.Namespace): args.ttft_slo, args.tpot_slo, args.max_batch_size, + args.distserve, ) # updated traces by adding encode time trace = Trace(trace) if model.num_decoder_blocks > 0: @@ -89,6 +94,7 @@ def main(args: argparse.Namespace): args.ttft_slo, args.tpot_slo, args.max_batch_size, + args.distserve, ) @@ -201,6 +207,15 @@ def main(args: argparse.Namespace): default=0, help="Define max batch size. This is also known as max number of sequences." 
) + # Define whether the prefill and decode phases should run on different GPUs + parser.add_argument( + "--distserve", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable 1P1D DistServe simulation (enabled by default; pass --no-distserve to disable).", + ) + + args = parser.parse_args() main(args)
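# Illustrative sketch, not part of the patch: the intended flag semantics,
# assuming the BooleanOptionalAction form of --distserve above. DistServe
# simulation is on by default and can be switched off with --no-distserve.
import argparse

_parser = argparse.ArgumentParser()
_parser.add_argument("--distserve", action=argparse.BooleanOptionalAction, default=True)

print(_parser.parse_args([]).distserve)                  # True  (default)
print(_parser.parse_args(["--distserve"]).distserve)     # True
print(_parser.parse_args(["--no-distserve"]).distserve)  # False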