54 changes: 42 additions & 12 deletions byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py
@@ -5,7 +5,6 @@
import pathlib
from multiprocessing import Queue
from typing import List

import torch
import torch.nn as nn
import torch.distributed as dist
@@ -92,16 +91,30 @@ def build_inputs(self, forward_inputs):
).cuda()

is_context = forward_inputs["is_context"]

batch_offset = forward_inputs["cache_batch_offset"]

if is_context:
forward_inputs["full_attention_mask"] = get_context_masks(
forward_inputs["input_ids"],
forward_inputs["attention_mask"]
)
slot_offset = torch.tensor([forward_inputs["valid_slot_ids"][0] * batch_offset],
device = forward_inputs["position_ids"].device,
dtype = forward_inputs["position_ids"].dtype).unsqueeze(1)
else:
bsz = forward_inputs["input_ids"].shape[0]
forward_inputs["full_attention_mask"] = get_decode_masks(
forward_inputs["input_ids"],
forward_inputs["all_kv_len"]
)
forward_inputs["seq_lens"] = torch.tensor( [x + y for x, y in zip(forward_inputs["all_q_len"], forward_inputs["all_kv_len"])],
dtype=torch.int,
device=forward_inputs["position_ids"].device)
slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
device = forward_inputs["position_ids"].device,
dtype = forward_inputs["position_ids"].dtype).unsqueeze(1)
forward_inputs["slot_mapping"] = forward_inputs["position_ids"] + slot_offset
return forward_inputs
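
Note: the slot-mapping logic this hunk moves into build_inputs reduces to the standalone sketch below; the helper function is illustrative only, with tensor names mirroring the diff.

import torch

def compute_slot_mapping(position_ids, valid_slot_ids, cache_batch_offset, is_context):
    # Each sequence owns cache_batch_offset slots in the paged KV cache.
    if is_context:
        # context phase: a single request, offset by its assigned slot id
        slot_offset = torch.tensor([valid_slot_ids[0] * cache_batch_offset],
                                   device=position_ids.device,
                                   dtype=position_ids.dtype).unsqueeze(1)
    else:
        # decode phase: one offset per batch entry
        bsz = position_ids.shape[0]
        slot_offset = torch.arange(0, bsz * cache_batch_offset, cache_batch_offset,
                                   device=position_ids.device,
                                   dtype=position_ids.dtype).unsqueeze(1)
    # broadcast: position within the sequence + start of that sequence's slot range
    return position_ids + slot_offset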


@@ -193,7 +206,6 @@ def mp_forward(self, *args):

# wait for one subprocess to send its result back to the main process
output_dict = self._output_queues.get(block=True)

return output_dict

# ROCM_HIPGRAPH modify
@@ -240,22 +252,40 @@ def signal_handler(signum, frame):
logger.info(f"{local_rank}/{world_size} rank is ready")

graph = torch.cuda.CUDAGraph()

s = torch.cuda.Stream()
# model process loop
while True:
(
forward_inputs,
) = input_queue.get(block=True)

# Only copy tensors to GPU if we aren't in replay mode.
# In future we'll want to take advantage of the copy_ operator.
if 'replay' not in forward_inputs:
forward_inputs["cache_batch_offset"] = model.cache_batch_offset
inputs_dict = self.build_inputs(forward_inputs)

# this is the graph capture phase
if 'capture' in forward_inputs:
graph.reset() # reset cuda graph each time
inputs_dict = self.build_inputs(forward_inputs)
# model.forward(inputs_dict)
graph = torch.cuda.CUDAGraph() # reset cuda graph each time

_NUM_WARMUP_ITERS=2
with torch.cuda.stream(s):
for _ in range(_NUM_WARMUP_ITERS):
logits = model.forward(inputs_dict)

torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()
with torch.cuda.graph(graph):
model.forward(inputs_dict)
output_dict = model.forward(inputs_dict)

torch.cuda.synchronize()
output_dict = dict()

output_dict["duration_ms"] = 0
# TP: rank 0 sends the result back to the main process
if local_rank == 0:
output_queue.put(output_dict)
continue

log = forward_inputs.get("log", False)
@@ -267,13 +297,13 @@ def signal_handler(signum, frame):
workspace_dir.mkdir(exist_ok=True, parents=True)
forward_inputs["log_file"] = open(workspace_dir / "run.log", "w")


inputs_dict = self.build_inputs(forward_inputs)
start_time = time.perf_counter_ns()

# output_dict = model.forward(inputs_dict)
graph.replay()

if 'replay' in forward_inputs:
with torch.cuda.stream(s):
graph.replay()
else:
output_dict = model.forward(inputs_dict)
torch.cuda.synchronize()
end_time = time.perf_counter_ns()
duration_ms = round((end_time - start_time) / 1e6, 3)
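
For reference, the warmup / capture / replay flow above follows the standard torch.cuda.CUDAGraph recipe. A minimal self-contained sketch (the model, shapes, and warmup count are made up; the real code routes inputs through build_inputs and multiprocessing queues):

import torch

model = torch.nn.Linear(1024, 1024).cuda().half()                       # stand-in model
static_input = torch.randn(8, 1024, device="cuda", dtype=torch.half)    # capture-time buffer

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):                       # warmup on a side stream
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):                # capture: records kernels, does not run them eagerly
    static_output = model(static_input)

static_input.copy_(torch.randn_like(static_input))   # refresh inputs in place
graph.replay()                               # rerun the captured kernels on the new data
torch.cuda.synchronize()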
@@ -1234,31 +1234,15 @@ def forward(
**kwargs,
) -> Union[Tuple, MoeModelOutputWithPast]:
residual = None
bsz = input_ids.shape[0]
is_context = kwargs.get("is_context")
valid_slot_ids = kwargs.get("valid_slot_ids")
batch_offset = kwargs.get("cache_batch_offset")
if is_context:
slot_offset = torch.tensor([valid_slot_ids[0] * batch_offset],
device = position_ids.device,
dtype = position_ids.dtype).unsqueeze(1)
else:
slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
device = position_ids.device,
dtype = position_ids.dtype).unsqueeze(1)
kwargs["slot_mapping"] = position_ids + slot_offset

if kwargs.pop("override_hidden_states", False):
random_seed = kwargs.pop("random_seed", None)
layer_index = kwargs.pop("fixed_layer_index", -1)
layer_index = layer_index % len(self.layers)

# create random input ids on cpu and copy to device
if random_seed is not None:
# RuntimeError: Cannot call CUDAGeneratorImpl::set_current_seed during CUDA graph capture.
torch.manual_seed(random_seed)
random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)

hidden_states = self.embed_tokens(random_input_ids)
#torch.manual_seed(random_seed)
#random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)
hidden_states = self.embed_tokens(input_ids)

for _ in self.layers:
layer_outputs, residual = self.layers[layer_index](
@@ -1271,7 +1255,7 @@
output_router_logits=False,
use_cache=False,
**kwargs,
)
)
else:
hidden_states = self.embed_tokens(input_ids)
for decoder_layer in self.layers:
@@ -1289,9 +1273,8 @@

hidden_states, _ = self.norm(hidden_states, residual)

return MoeModelOutputWithPast(
last_hidden_state=hidden_states
)
return hidden_states
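
The commented-out manual_seed lines reflect a CUDA graph constraint: the RNG seed cannot be set during capture (hence the RuntimeError noted in the removed comment). One common workaround, sketched here with made-up names and sizes, is to generate the random ids once before capture and reuse the static buffer inside the graph:

import torch

vocab_size, hidden_size = 32000, 4096                      # hypothetical sizes
torch.manual_seed(1234)                                     # seeding is fine outside capture
static_ids = torch.randint(10, vocab_size, (1, 1024), dtype=torch.int64, device="cuda")
embed = torch.nn.Embedding(vocab_size, hidden_size).cuda().half()   # stand-in for embed_tokens

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    hidden_states = embed(static_ids)                       # warmup
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    hidden_states = embed(static_ids)                       # no RNG calls inside the capture
graph.replay()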



class MixtralForCausalLM(MixtralPreTrainedModel):
@@ -1387,6 +1370,9 @@ def forward(

# print(f'{os.environ.get("LOCAL_RANK", "0")} {outputs=}')
hidden_states = outputs[0]

# Q: Why are we converting the linear output to float? Why not compute in float?
# Also, should this be half?
logits = self.lm_head(hidden_states)
logits = logits.float()
# print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states.shape=}')
@@ -213,6 +213,7 @@ def init_kvcache(self, dtype):
while max_num_blocks * self.block_size < max_seq_len * max_batch_size:
max_num_blocks += 4096
self.max_num_blocks_per_seq = (max_seq_len + self.block_size - 1) // self.block_size
self.cache_batch_offset = self.block_size * self.max_num_blocks_per_seq
block_tables_lst: List[List[int]] = []
for batch_idx in range(max_batch_size):
block_start = self.max_num_blocks_per_seq * batch_idx
@@ -231,17 +232,19 @@
return block_tables, past_key_values

def forward(self, inputs : Dict[str, torch.Tensor]):
inputs["cache_batch_offset"] = self.block_size * self.max_num_blocks_per_seq

model_outputs = self.transformer_model.forward(
**inputs,
past_key_values=(self.block_tables, self.kv_cache)
)


# context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size]
# decode: [max_batch_size, 1]

logits = model_outputs.logits

output_dict = {
"logits": logits
}
return output_dict
return output_dict
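
For orientation, the new cache_batch_offset appears to be the per-sequence stride in the paged KV cache: block_size * max_num_blocks_per_seq slots are reserved per sequence, so sequence b starts at slot b * cache_batch_offset. A toy example with made-up sizes:

block_size = 16
max_seq_len = 64
max_batch_size = 4

max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size   # 4 blocks per sequence
cache_batch_offset = block_size * max_num_blocks_per_seq                # 64 slots per sequence

# Each sequence owns a contiguous block range, so the slot for (batch b, position p)
# is p + b * cache_batch_offset -- which is what slot_mapping = position_ids + slot_offset computes.
block_tables = [
    list(range(max_num_blocks_per_seq * b, max_num_blocks_per_seq * (b + 1)))
    for b in range(max_batch_size)
]
print(block_tables)         # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(cache_batch_offset)   # 64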
26 changes: 16 additions & 10 deletions byte_infer_perf/llm_perf/bench_model.py
@@ -131,19 +131,12 @@ def update_template(mode, batch_size, seq_len):
input_template = update_template("context", 1, 1024)
is_graph = int(os.environ.get("ENABLE_GRAPH", "0"))

if is_graph:
#ROCM_HIPGRAPH modify
input_template['capture'] = 1
engine.mp_forward(input_template)
input_template.pop('capture')

start_time = time.perf_counter_ns()
for _ in range(num_warm_iter):
engine.mp_forward(input_template)
duration_s = round((time.perf_counter_ns() - start_time) / 1e9, 3)
logger.info(f"warmup cost: {duration_s}s")


def results_to_csv(file_path, results):
batch_size_set = sorted(results.keys())
seq_len_set = set()
@@ -167,6 +160,8 @@ def results_to_csv(file_path, results):

log_results = []
if xpu_config["perf_config"]["perf_context"]:
print(f'>>> Beginning Context', flush=True)

batch_size_list = [1]
seq_len_list = xpu_config["perf_config"]["seq_len_list"]

@@ -187,7 +182,12 @@ def results_to_csv(file_path, results):
test_iter = 0
duration_ms = 0.
while test_iter < total_test_iter:
result = engine.mp_forward(input_template)
if is_graph:
input_template['replay'] = 1
result = engine.mp_forward(input_template)
input_template.pop('replay')
else:
result = engine.mp_forward(input_template)
if start_iters > 0:
start_iters -= 1
continue
@@ -203,10 +203,11 @@
lines = workspace.joinpath("rank_0", "run.log").read_text().splitlines()
log_results[-1] += f", {lines[0]}"
print(log_results[-1])
print(f'>>> End of sequence length', flush=True)
results_to_csv(workspace.joinpath("context_perf.csv"), context_results)


if xpu_config["perf_config"]["perf_decode"]:
print(f'>>> Beginning Decode', flush=True)
batch_size_list = xpu_config["perf_config"]["batch_size_list"]
seq_len_list = xpu_config["perf_config"]["seq_len_list"]

@@ -230,7 +231,12 @@

duration_ms = 0.
while test_iter < total_test_iter:
result = engine.mp_forward(input_template)
if is_graph:
input_template['replay'] = 1
result = engine.mp_forward(input_template)
input_template.pop('replay')
else:
result = engine.mp_forward(input_template)
if start_iters > 0:
start_iters -= 1
continue
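
Taken together with gpu_mp_engine.py, the driver-side protocol appears to be: send the template once with a 'capture' key to build the graph, then set 'replay' on each timed iteration. A condensed, runnable sketch; _FakeEngine and update_template below are stand-ins for the real objects in bench_model.py:

import os

class _FakeEngine:
    def mp_forward(self, inputs):
        return {"duration_ms": 0.0, "replayed": "replay" in inputs}

def update_template(mode, batch_size, seq_len):
    return {"mode": mode, "batch_size": batch_size, "seq_len": seq_len}

engine = _FakeEngine()
is_graph = int(os.environ.get("ENABLE_GRAPH", "0"))
input_template = update_template("context", 1, 1024)

if is_graph:
    input_template["capture"] = 1        # one capture pass builds the CUDA graph in the worker
    engine.mp_forward(input_template)
    input_template.pop("capture")

for _ in range(10):                      # timed iterations
    if is_graph:
        input_template["replay"] = 1     # replay the captured graph instead of an eager forward
        result = engine.mp_forward(input_template)
        input_template.pop("replay")
    else:
        result = engine.mp_forward(input_template)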