diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
index 82f5ce59..21eae4e5 100644
--- a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
+++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -5,7 +5,6 @@ import pathlib
 from multiprocessing import Queue
 from typing import List
 
-import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -92,16 +91,30 @@ def build_inputs(self, forward_inputs):
         ).cuda()
 
         is_context = forward_inputs["is_context"]
+
+        batch_offset = forward_inputs["cache_batch_offset"]
+
         if is_context:
             forward_inputs["full_attention_mask"] = get_context_masks(
                 forward_inputs["input_ids"],
                 forward_inputs["attention_mask"]
             )
+            slot_offset = torch.tensor([forward_inputs["valid_slot_ids"][0] * batch_offset],
+                                       device=forward_inputs["position_ids"].device,
+                                       dtype=forward_inputs["position_ids"].dtype).unsqueeze(1)
         else:
+            bsz = forward_inputs["input_ids"].shape[0]
             forward_inputs["full_attention_mask"] = get_decode_masks(
                 forward_inputs["input_ids"],
                 forward_inputs["all_kv_len"]
             )
+            forward_inputs["seq_lens"] = torch.tensor(
+                [x + y for x, y in zip(forward_inputs["all_q_len"], forward_inputs["all_kv_len"])],
+                dtype=torch.int,
+                device=forward_inputs["position_ids"].device)
+            slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
+                                       device=forward_inputs["position_ids"].device,
+                                       dtype=forward_inputs["position_ids"].dtype).unsqueeze(1)
+        forward_inputs["slot_mapping"] = forward_inputs["position_ids"] + slot_offset
         return forward_inputs
@@ -193,7 +206,6 @@ def mp_forward(self, *args):
         # wait for one subprocess send result back to main process
         output_dict = self._output_queues.get(block=True)
-
         return output_dict
 
 
 # ROCM_HIPGRAPH modify
@@ -240,22 +252,40 @@ def signal_handler(signum, frame):
         logger.info(f"{local_rank}/{world_size} rank is ready")
 
         graph = torch.cuda.CUDAGraph()
-
+        s = torch.cuda.Stream()
         # model process loop
         while True:
             (
                 forward_inputs,
             ) = input_queue.get(block=True)
 
+            # Only copy tensors to the GPU if we aren't in replay mode.
+            # In future cases we'll want to take advantage of the copy_ operator.
+            if 'replay' not in forward_inputs:
+                forward_inputs["cache_batch_offset"] = model.cache_batch_offset
+                inputs_dict = self.build_inputs(forward_inputs)
+
+            # this is the capture phase of the graph
             if 'capture' in forward_inputs:
-                graph.reset()   # reset cuda graph each time
-                inputs_dict = self.build_inputs(forward_inputs)
-                # model.forward(inputs_dict)
+                graph = torch.cuda.CUDAGraph()   # reset cuda graph each time
+
+                _NUM_WARMUP_ITERS = 2
+                with torch.cuda.stream(s):
+                    for _ in range(_NUM_WARMUP_ITERS):
+                        logits = model.forward(inputs_dict)
+
+                torch.cuda.current_stream().wait_stream(s)
 
                 torch.cuda.synchronize()
                 with torch.cuda.graph(graph):
-                    model.forward(inputs_dict)
+                    output_dict = model.forward(inputs_dict)
+                torch.cuda.synchronize()
+
+                output_dict = dict()
+                output_dict["duration_ms"] = 0
+                # TP realization: rank0 send result back to main process
+                if local_rank == 0:
+                    output_queue.put(output_dict)
                 continue
 
             log = forward_inputs.get("log", False)
@@ -267,13 +297,13 @@ def signal_handler(signum, frame):
                 workspace_dir.mkdir(exist_ok=True, parents=True)
                 forward_inputs["log_file"] = open(workspace_dir / "run.log", "w")
-
-            inputs_dict = self.build_inputs(forward_inputs)
 
             start_time = time.perf_counter_ns()
-            # output_dict = model.forward(inputs_dict)
-            graph.replay()
-
+            if 'replay' in forward_inputs:
+                with torch.cuda.stream(s):
+                    graph.replay()
+            else:
+                output_dict = model.forward(inputs_dict)
             torch.cuda.synchronize()
             end_time = time.perf_counter_ns()
             duration_ms = round((end_time - start_time) / 1e6, 3)
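
The worker-loop changes above follow the standard PyTorch CUDA graph recipe: warm up on a side stream, capture one forward pass into a `torch.cuda.CUDAGraph`, then replay it while refreshing the captured input buffers in place. A minimal, self-contained sketch of that recipe (hypothetical `model` and `static_input`, assuming a CUDA device; not the engine code itself):

```python
import torch
import torch.nn as nn

# Hypothetical model and static input buffer (assumes a CUDA-capable device).
model = nn.Linear(1024, 1024).cuda().eval()
static_input = torch.randn(8, 1024, device="cuda")

# 1) Warm up on a side stream so one-time initializations happen outside capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(2):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 2) Capture a single forward pass; tensors used here become the graph's static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_output = model(static_input)

# 3) Replay: copy fresh data into the captured input buffer, then replay the graph.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(static_output.shape)   # torch.Size([8, 1024])
```

The in-place `copy_` in step 3 is the mechanism the "copy_ operator" comment above refers to: replay reuses the same device buffers, so new data must be copied into them rather than passed as new tensors.
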
diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py
index af1504b6..2e86c765 100644
--- a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py
+++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py
@@ -1234,31 +1234,15 @@ def forward(
         **kwargs,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         residual = None
-        bsz = input_ids.shape[0]
-        is_context = kwargs.get("is_context")
-        valid_slot_ids = kwargs.get("valid_slot_ids")
-        batch_offset = kwargs.get("cache_batch_offset")
-        if is_context:
-            slot_offset = torch.tensor([valid_slot_ids[0] * batch_offset],
-                                       device=position_ids.device,
-                                       dtype=position_ids.dtype).unsqueeze(1)
-        else:
-            slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
-                                       device=position_ids.device,
-                                       dtype=position_ids.dtype).unsqueeze(1)
-        kwargs["slot_mapping"] = position_ids + slot_offset
+
         if kwargs.pop("override_hidden_states", False):
             random_seed = kwargs.pop("random_seed", None)
             layer_index = kwargs.pop("fixed_layer_index", -1)
             layer_index = layer_index % len(self.layers)
 
-            # create random input ids on cpu and copy to device
-            if random_seed is not None:
-                # RuntimeError: Cannot call CUDAGeneratorImpl::set_current_seed during CUDA graph capture.
-                torch.manual_seed(random_seed)
-            random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)
-
-            hidden_states = self.embed_tokens(random_input_ids)
+            #torch.manual_seed(random_seed)
+            #random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)
+            hidden_states = self.embed_tokens(input_ids)
 
             for _ in self.layers:
                 layer_outputs, residual = self.layers[layer_index](
@@ -1271,7 +1255,7 @@ def forward(
                     output_router_logits=False,
                     use_cache=False,
                     **kwargs,
-                )
+                )
         else:
             hidden_states = self.embed_tokens(input_ids)
             for decoder_layer in self.layers:
@@ -1289,9 +1273,8 @@ def forward(
 
         hidden_states, _ = self.norm(hidden_states, residual)
 
-        return MoeModelOutputWithPast(
-            last_hidden_state=hidden_states
-        )
+        return hidden_states
+
 
 class MixtralForCausalLM(MixtralPreTrainedModel):
@@ -1387,6 +1370,9 @@ def forward(
         # print(f'{os.environ.get("LOCAL_RANK", "0")} {outputs=}')
         hidden_states = outputs[0]
+
+        # Q: Why are we converting the output of the linear call to float? Why not compute in float?
+        # Also, should this be half?
         logits = self.lm_head(hidden_states)
         logits = logits.float()
         # print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states.shape=}')
diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py
index 48ad43bf..a2293c3c 100644
--- a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py
+++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py
@@ -213,6 +213,7 @@ def init_kvcache(self, dtype):
         while max_num_blocks * self.block_size < max_seq_len * max_batch_size:
             max_num_blocks += 4096
         self.max_num_blocks_per_seq = (max_seq_len + self.block_size - 1) // self.block_size
+        self.cache_batch_offset = self.block_size * self.max_num_blocks_per_seq
         block_tables_lst: List[List[int]] = []
         for batch_idx in range(max_batch_size):
             block_start = self.max_num_blocks_per_seq * batch_idx
@@ -231,17 +232,19 @@ def init_kvcache(self, dtype):
         return block_tables, past_key_values
 
     def forward(self, inputs : Dict[str, torch.Tensor]):
-        inputs["cache_batch_offset"] = self.block_size * self.max_num_blocks_per_seq
+
         model_outputs = self.transformer_model.forward(
             **inputs,
             past_key_values=(self.block_tables, self.kv_cache)
         )
+
         # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size]
        # decode: [max_batch_size, 1]
+
         logits = model_outputs.logits
 
         output_dict = {
             "logits": logits
         }
-        return output_dict
\ No newline at end of file
+        return output_dict
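
For reference, the slot-mapping arithmetic that moved out of `MixtralModel.forward` and into `GpuMpEngine.build_inputs` works like this: each sequence owns a contiguous region of `cache_batch_offset = block_size * max_num_blocks_per_seq` KV-cache slots, and a token's slot index is its position plus its sequence's region offset (offset by batch index in decode, or by `valid_slot_ids[0]` in the context phase). A small illustrative sketch with toy sizes (not the real cache dimensions):

```python
import torch

# Toy illustration of the slot_mapping computed in build_inputs (hypothetical sizes).
block_size = 16
max_num_blocks_per_seq = 4
cache_batch_offset = block_size * max_num_blocks_per_seq   # 64 slots per sequence

# Decode step: batch of 3 sequences, each writing one new token at these positions.
position_ids = torch.tensor([[5], [9], [2]])
bsz = position_ids.shape[0]
slot_offset = torch.arange(0, bsz * cache_batch_offset, cache_batch_offset,
                           dtype=position_ids.dtype).unsqueeze(1)
slot_mapping = position_ids + slot_offset
print(slot_mapping.squeeze(1).tolist())                     # [5, 73, 130]

# Context (prefill) for the sequence in slot 1: all prompt positions land in its region.
valid_slot_ids = [1]
position_ids = torch.arange(4).unsqueeze(0)                 # [[0, 1, 2, 3]]
slot_offset = torch.tensor([valid_slot_ids[0] * cache_batch_offset]).unsqueeze(1)
print((position_ids + slot_offset).tolist())                # [[64, 65, 66, 67]]
```

Doing this in `build_inputs` rather than inside the model keeps the model's captured graph free of data-dependent host-side tensor construction.
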
xpu_config["perf_config"]["perf_context"]: + print(f'>>> Beginning Context', flush=True) + batch_size_list = [1] seq_len_list = xpu_config["perf_config"]["seq_len_list"] @@ -187,7 +182,12 @@ def results_to_csv(file_path, results): test_iter = 0 duration_ms = 0. while test_iter < total_test_iter: - result = engine.mp_forward(input_template) + if is_graph: + input_template['replay'] = 1 + result = engine.mp_forward(input_template) + input_template.pop('replay') + else: + result = engine.mp_forward(input_template) if start_iters > 0: start_iters -= 1 continue @@ -203,10 +203,11 @@ def results_to_csv(file_path, results): lines = workspace.joinpath("rank_0", "run.log").read_text().splitlines() log_results[-1] += f", {lines[0]}" print(log_results[-1]) + print(f'>>> End of sequence length', flush=True) results_to_csv(workspace.joinpath("context_perf.csv"), context_results) - if xpu_config["perf_config"]["perf_decode"]: + print(f'>>> Beginning Context', flush=True) batch_size_list = xpu_config["perf_config"]["batch_size_list"] seq_len_list = xpu_config["perf_config"]["seq_len_list"] @@ -230,7 +231,12 @@ def results_to_csv(file_path, results): duration_ms = 0. while test_iter < total_test_iter: - result = engine.mp_forward(input_template) + if is_graph: + input_template['replay'] = 1 + result = engine.mp_forward(input_template) + input_template.pop('replay') + else: + result = engine.mp_forward(input_template) if start_iters > 0: start_iters -= 1 continue
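
Taken together, the engine-side changes define a simple protocol over `mp_forward`: a request carrying a `capture` key builds inputs, warms up, and records the CUDA graph; a request carrying a `replay` key skips input building and replays the recorded graph on the side stream; anything else falls back to eager `model.forward`. A hedged, self-contained sketch of how a caller might drive that protocol (`FakeEngine` is a stand-in for `GpuMpEngine`; the real capture call site is not part of this diff):

```python
import os

# Stand-in engine that only reports which branch of the protocol a request hits.
class FakeEngine:
    def mp_forward(self, inputs):
        if "capture" in inputs:
            mode = "capture"   # build inputs, warm up, record the CUDA graph
        elif "replay" in inputs:
            mode = "replay"    # replay the recorded graph, skip input building
        else:
            mode = "eager"     # plain model.forward
        return {"mode": mode, "duration_ms": 0}

engine = FakeEngine()
input_template = {"is_context": False}
is_graph = int(os.environ.get("ENABLE_GRAPH", "1"))

if is_graph:
    input_template["capture"] = 1          # record the graph once, outside timing
    print(engine.mp_forward(input_template))
    input_template.pop("capture")

for _ in range(3):                          # timed iterations
    if is_graph:
        input_template["replay"] = 1
        result = engine.mp_forward(input_template)
        input_template.pop("replay")
    else:
        result = engine.mp_forward(input_template)
    print(result)
```
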