diff --git a/projects/mock_transformers/dist_infer_bloom.py b/projects/mock_transformers/dist_infer_bloom.py
index b738ef7d6..e0497a726 100644
--- a/projects/mock_transformers/dist_infer_bloom.py
+++ b/projects/mock_transformers/dist_infer_bloom.py
@@ -82,9 +82,10 @@ def __init__(self, config):
     parallel_config = DictConfig(
         dict(
             data_parallel_size=1,
-            tensor_parallel_size=2,
-            pipeline_parallel_size=1,  # set to 1, unsupport pipeline parallel now
-            pipeline_num_layers=None,
+            tensor_parallel_size=1,
+            pipeline_parallel_size=4,  # number of pipeline stages
+            pipeline_num_layers=24,  # total number of layers to split across the stages
+            custom_pipeline_stage_id=[0]*6 + [1]*6 + [2]*6 + [3]*6,  # 6 layers per stage
             device_type="cpu",
         )
     )
@@ -95,6 +96,8 @@ def __init__(self, config):
     # set model to cuda
     dist.set_device_type("cuda")
     model._apply(dist.convert_to_distributed_default_setting)
+
+    init_env.auto_set_pipeline_stage_id(model, pipeline_parallel_size=parallel_config.pipeline_parallel_size)
 
     # initial tokenizer
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", use_fast=False)
diff --git a/projects/mock_transformers/dist_infer_opt.py b/projects/mock_transformers/dist_infer_opt.py
index bc04517df..169b382f3 100644
--- a/projects/mock_transformers/dist_infer_opt.py
+++ b/projects/mock_transformers/dist_infer_opt.py
@@ -73,37 +73,46 @@ def __init__(self, *args, **kwargs):
     parallel_config = DictConfig(
         dict(
             data_parallel_size=1,
-            tensor_parallel_size=2,
+            tensor_parallel_size=1,
             pipeline_parallel_size=1,  # set to 1, unsupport pipeline parallel now
-            pipeline_num_layers=None,
+            # pipeline_num_layers=12,
+            # custom_pipeline_stage_id=[0]*3 + [1]*3 + [2]*3 + [3]*3,
             device_type="cpu",
         )
     )
     dist.setup_dist_util(parallel_config)
-
+    placement_sbp_dict = dict(
+        placement=flow.env.all_device_placement("cuda"),
+        sbp=flow.sbp.broadcast,
+    )
+
     # initial and load model
-    model = AutoModelForCausalLM.from_pretrained("facebook/opt-2.7b", torch_dtype=flow.float16)
+    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=flow.float16)
     # set model to cuda
     dist.set_device_type("cuda")
     model._apply(dist.convert_to_distributed_default_setting)
+
+    model = init_env.auto_set_pipeline_stage_id(model, pipeline_parallel_size=parallel_config.pipeline_parallel_size)
+
     # initial tokenizer
-    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b", use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)
     # get input_ids
     prompt = "Hello, I'm am conscious and"
-    input_ids = tokenizer(prompt, return_tensors="np").input_ids
-    input_ids = flow.from_numpy(input_ids)
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    # input_ids = flow.from_numpy(input_ids)
     input_ids = input_ids.to_global(
         sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
         placement=dist.get_layer_placement(0),
     )
-
+
     # generate id
-    placement_sbp_dict = dict(
-        placement=flow.env.all_device_placement("cuda"),
-        sbp=flow.sbp.broadcast,
-    )
-    with global_mode(True, **placement_sbp_dict):
-        generated_ids = model.generate(input_ids, max_length=30)
-    out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-    print(out_put_ids)
+    for i in range(100):
+        with global_mode(True, **placement_sbp_dict):
+            model = init_env.compile_auto_placement(
+                model,
+                input_ids
+            )
+            generated_ids = model.generate(input_ids, max_length=30)
+            out_put_ids = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            print(out_put_ids)
diff --git a/projects/mock_transformers/init_env.py b/projects/mock_transformers/init_env.py
index 814f8cf3c..7847f6c81 100644
--- a/projects/mock_transformers/init_env.py
+++ b/projects/mock_transformers/init_env.py
@@ -18,9 +18,15 @@
 flow.mock_torch.enable()
 
+
+import copy  # noqa
+import onefx as fx  # noqa
+from typing import List, Dict, Any  # noqa
 from oneflow import Tensor, nn  # noqa
 from transformers import modeling_utils  # noqa
 from transformers.modeling_utils import _load_state_dict_into_model  # noqa
+from libai.utils import distributed as dist  # noqa
+
 
 
 # ---------------- mock _load_state_dict_into_model ------------------
@@ -111,3 +117,183 @@ def flow_softmax(*args, **kwargs):
 
 
 nn.functional.softmax = flow_softmax
+
+# =============================================
+# ---------------- helper functions -----------
+# =============================================
+
+def set_pipeline_stage_id(self, placement):
+    for param in self.parameters():
+        param.data = param.data.to_global(placement=placement)
+
+nn.Module.set_pipeline_stage_id = set_pipeline_stage_id
+
+
+def sizeof_fmt(num, suffix='B'):
+    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
+        if abs(num) < 1024.0:
+            return f"{num:.2f} {unit}{suffix}"
+        num /= 1024.0
+    return f"{num:.2f} Yi{suffix}"
+
+def print_model(model, depth=0, max_depth=2, last_child=False, prefix=''):
+    indent = " "
+    stage_str = ""
+    if hasattr(model, "layer_idx"):
+        layer_idx = getattr(model, "layer_idx")
+        stage_idx = getattr(model, "stage_idx")
+        # only show stage info if the whole subtree shares the same layer index
+        same_placement = True
+        for path, module in model.named_modules():
+            if getattr(module, "layer_idx") != layer_idx:
+                same_placement = False
+        if same_placement:
+            stage_str = f" stage{stage_idx}_ranks{dist.get_layer_placement(layer_idx).ranks} "
+
+    if depth > max_depth:
+        return
+    if isinstance(model, nn.Module):
+        params = sum(p.numel() for p in model.parameters())
+        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(model.__class__.__name__) + ": " + stage_str + sizeof_fmt(params) + " params")
+    elif isinstance(model, nn.Sequential):
+        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(model.__class__.__name__) + ": " + str(len(list(model.named_children()))) + " modules")
+    else:
+        print(indent * depth + ("└─" if last_child else "├─") + prefix + str(type(model).__name__))
+    for i, (name, child) in enumerate(model.named_children()):
+        print_model(child, depth=depth+1, max_depth=max_depth, last_child=i==len(list(model.named_children()))-1, prefix=f'[{name}] ')
+
+
+def auto_set_pipeline_stage_id(model, pipeline_parallel_size=1):
+    # Count the repeated, integer-named layers (e.g. transformer blocks) seen so far
+    count = 0
+    max_depth = 1
+    name_stage_dict = {}
+    # Iterate over all submodules and their paths
+    for path, module in model.named_modules():
+        # Get the name and class of the module
+        name = path.split(".")[-1]
+        prefix_path = ".".join(path.split(".")[:-1])
+        module_cls = type(module)
+
+        # A digit-named child may be one entry of a repeated layer list (e.g. "h.0", "h.1", ...)
+        if name.isdigit():
+            # The layer counts as repeated if another module shares the same parent path and class
+            repeated = False
+            for n, m in model.named_modules():
+                prefix_n = ".".join(n.split(".")[:-1])
+                if m is not module and prefix_n == prefix_path and type(m) == module_cls:
+                    max_depth = max(len(n.split(".")), max_depth)
+                    repeated = True
+            if repeated:
+                count += 1
+                # print(f"Layer {name} with path {path} is repeated. {count}")
{count}") + + name_stage_dict[path] = max(count-1, 0) + + + length = (count + pipeline_parallel_size - 1) // pipeline_parallel_size + param_id_set = set() # skip shared weight param + + for path, module in model.named_modules(): + # Add to_global to the parameter + layer_idx = name_stage_dict[path] + stage_idx = layer_idx // length + setattr(module, "stage_idx", stage_idx) + setattr(module, "layer_idx", layer_idx) + if len(path.split(".")) >= max_depth or len(list(module.named_children())) == 0: + for param in module.parameters(): + if id(param) not in param_id_set: + param.data = param.data.to_global(placement=dist.get_layer_placement(layer_idx)) + param_id_set.add(id(param)) + + if dist.is_main_process(): + print_model(model, depth=0, max_depth=100 if max_depth==1 else max_depth) + # Return the modified model + return model + +# ---------------def fx for auto changing placement ---------------------- + + +class AutoPlacementInterpreter(fx.Interpreter): + def __init__(self, mod : flow.nn.Module): + gm = fx.symbolic_trace(mod) + super().__init__(gm) + + self.global_infos : Dict[int, Dict[int, Any]] = {} + self.node_id = 0 + + def run(self, *args) -> Any: + return_val = super().run(*args) + return return_val + + def run_node(self, n : fx.Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + global_info_to_replace = None + max_rank_sum = -1 + for arg in args: + if not isinstance(arg, flow.Tensor): + continue + if arg.is_local or len(arg.placement.ranks) == 0: + continue + placement = arg.placement + sbp = arg.sbp + # print(sum(placement.ranks)) + if max_rank_sum < sum(placement.ranks): + max_rank_sum = sum(placement.ranks) + global_info_to_replace = (placement, sbp) + # elif max_rank_sum == sum(placement.ranks) and zip(placement_to_replace.ranks, placement.ranks).all(lambda x, y: x == y): + # raise ValueError("There is two different placements with same rank sum. 
" + # + f"They are {placement_to_replace} and {placement}.") + + if max_rank_sum == -1: + self.node_id += 1 + return_val = super().run_node(n) + return return_val + + for arg_id in range(len(args)): + if isinstance(arg, flow.Tensor) and sum(arg.placement.ranks) < max_rank_sum: + self.global_infos.setdefault(self.node_id, {})[arg_id] = global_info_to_replace + n.update_arg(arg_id, args[arg_id].to_global(global_info_to_replace[0], global_info_to_replace[1])) + + return_val = super().run_node(n) + return return_val + + +def add_auto_placement(model: flow.nn.Module, global_info_dict: Dict[int, Dict[int, List[int]]]) -> flow.nn.Module: + model = copy.deepcopy(model) + fx_model: fx.GraphModule = fx.symbolic_trace(model) + + for node_id, node in enumerate(fx_model.graph.nodes): + print(node_id, " ", node.op) + if not node_id in global_info_dict: + continue + + for idx, arg in enumerate(node.args): + if not idx in global_info_dict[node_id]: + continue + global_info = global_info_dict[node_id][idx] + new_node = fx.Node(fx_model.graph, f"auto_placement_{node_id}_{idx}", "call_function", flow.to_global, (arg, global_info[0], global_info[1]), {}) + node.prepend(new_node) + node.update_arg(idx, new_node) + + fx_model.graph.lint() + fx_model.recompile() + return fx_model + +def compile_auto_placement(model: flow.nn.Module, input_x: flow.Tensor): + assert input_x.is_global + interpret = AutoPlacementInterpreter(model) + interpret.run(input_x) + model = add_auto_placement(model, interpret.global_infos) + return model + +# b = flow.ones( +# (2,2), +# sbp=[flow.sbp.broadcast, flow.sbp.broadcast], +# placement=flow.placement("cuda", ranks=[[2], [3]]) +# ) +# demo_module = demoModule() +# interpret = AutoPlacementInterpreter(demo_module) +# c = interpret.run(b) +# model = add_auto_placement(demo_module, interpret.global_infos) +# print(model.code) +# print(model(b)) \ No newline at end of file