diff --git a/robin/eval/run_llava.py b/robin/eval/run_llava.py
index e867ea6..930d8c4 100644
--- a/robin/eval/run_llava.py
+++ b/robin/eval/run_llava.py
@@ -3,7 +3,7 @@
 from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 from robin.conversation import conv_templates, SeparatorStyle
-from robin.model.builder import load_pretrained_model
+from robin.model.builder import load_pretrained_model, LlavaMetaModel
 from robin.utils import disable_torch_init
 from robin.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
@@ -28,7 +28,7 @@ def eval_model(args):
     disable_torch_init()
 
     model_name = get_model_name_from_path(args.model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.llm_type)
 
     qs = args.query
     if model.config.mm_use_im_start_end:
@@ -89,6 +89,7 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
     parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list())
     parser.add_argument("--image-file", type=str, required=True)
     parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
diff --git a/robin/model/builder.py b/robin/model/builder.py
index e2f5964..05286a7 100644
--- a/robin/model/builder.py
+++ b/robin/model/builder.py
@@ -5,10 +5,11 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
 import torch
 
 from robin.model import *
+from robin.model.llava_arch import LlavaMetaModel
 from robin.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
-def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
+def load_pretrained_model(model_path, model_base, model_name, llm_type=None, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
     kwargs = {"device_map": device_map}
 
     if load_8bit:
@@ -31,10 +32,19 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
             lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_base)
             print('Loading LLaVA from base model...')
+
+            # resolve llm_type to a ModelType enum member
+            if llm_type is not None:
+                try:
+                    llm_type = LlavaMetaModel.ModelType(llm_type)
+                except ValueError as e:
+                    raise ValueError(f"Invalid llm type provided: {e}. Supported llm types are {', '.join(LlavaMetaModel.get_model_type_list())}")
+            else:
+                llm_type = LlavaMetaModel.get_model_type_from_model_name(model_name)
 
-            if 'mistral' in model_name.lower():
+            if llm_type == LlavaMetaModel.ModelType.LlavaMistralModel:
                 model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
-            elif any(x in model_name.lower() for x in ['neox', 'pythia', 'hi-nolin']):
+            elif llm_type == LlavaMetaModel.ModelType.LlavaGPTNeoXModel:
                 model = LlavaGPTNeoXForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             else:
                 model = LlavaLlamaForCausalLM.from_pretrained(
diff --git a/robin/model/llava_arch.py b/robin/model/llava_arch.py
index f79317c..250c1ce 100644
--- a/robin/model/llava_arch.py
+++ b/robin/model/llava_arch.py
@@ -18,6 +18,8 @@
 import torch
 import torch.nn as nn
 
+from enum import Enum
+
 from .multimodal_encoder.builder import build_vision_tower
 from .multimodal_projector.builder import build_vision_projector
 
@@ -26,6 +28,9 @@
 class LlavaMetaModel:
 
+    registry = {}
+    ModelType = None
+
     def __init__(self, config):
         super(LlavaMetaModel, self).__init__(config)
 
@@ -33,6 +38,29 @@ def __init__(self, config):
             self.vision_tower = build_vision_tower(config, delay_load=True)
             self.mm_projector = build_vision_projector(config)
 
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        cls.registry[cls.__name__] = cls
+        LlavaMetaModel.ModelType = Enum('ModelType', [(subcls.__name__, subcls.config_class.model_type) for subcls in cls.registry.values()])
+
+    # for when model_type is not passed for backwards compatibility
+    @classmethod
+    def get_model_type_from_model_name(cls, model_name: str) -> ModelType:
+        model_name = model_name.lower()
+        if 'mpt' in model_name:
+            return cls.ModelType.LlavaMPTModel
+        elif 'mistral' in model_name:
+            return cls.ModelType.LlavaMistralModel
+        elif any(x in model_name for x in ['neox', 'pythia', 'hi-nolin']):
+            return cls.ModelType.LlavaGPTNeoXModel
+        else:
+            return cls.ModelType.LlavaLlamaModel
+
+    @classmethod
+    def get_model_type_list(cls):
+        return [m.value for m in LlavaMetaModel.ModelType]
+
+
     def get_vision_tower(self):
         vision_tower = getattr(self, 'vision_tower', None)
         if type(vision_tower) is list:
diff --git a/robin/serve/cli.py b/robin/serve/cli.py
index cfa774e..ae47416 100644
--- a/robin/serve/cli.py
+++ b/robin/serve/cli.py
@@ -3,7 +3,8 @@
 from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 from robin.conversation import conv_templates, SeparatorStyle
-from robin.model.builder import load_pretrained_model
+from robin.model.builder import load_pretrained_model, LlavaMetaModel
+
 from robin.utils import disable_torch_init
 from robin.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
@@ -32,7 +33,7 @@ def main(args):
 
     model_name = get_model_name_from_path(args.model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.llm_type, args.load_8bit, args.load_4bit, device=args.device)
 
     conv = conv_templates[args.conv_mode].copy()
     roles = conv.roles
@@ -102,6 +103,7 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-path", type=str, default="agi-collective/mistral-7b-oh-siglip-so400m-finetune-lora")
default="agi-collective/mistral-7b-oh-siglip-so400m-finetune-lora") parser.add_argument("--model-base", type=str, default="teknium/OpenHermes-2.5-Mistral-7B") + parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list()) parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--conv-mode", type=str, default="vicuna_v1") diff --git a/robin/serve/model_worker.py b/robin/serve/model_worker.py index 19b3557..cd0f02d 100644 --- a/robin/serve/model_worker.py +++ b/robin/serve/model_worker.py @@ -18,7 +18,7 @@ from robin.constants import WORKER_HEART_BEAT_INTERVAL from robin.utils import (build_logger, server_error_msg, pretty_print_semaphore) -from robin.model.builder import load_pretrained_model +from robin.model.builder import load_pretrained_model, LlavaMetaModel from robin.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from transformers import TextIteratorStreamer @@ -44,7 +44,7 @@ def heart_beat_worker(controller): class ModelWorker: def __init__(self, controller_addr, worker_addr, worker_id, no_register, - model_path, model_base, model_name, + model_path, model_base, model_name, llm_type, load_8bit, load_4bit, device): self.controller_addr = controller_addr self.worker_addr = worker_addr @@ -63,7 +63,7 @@ def __init__(self, controller_addr, worker_addr, self.device = device logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( - model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device) + model_path, model_base, self.model_name, llm_type, load_8bit, load_4bit, device=self.device) self.is_multimodal = 'llava' in self.model_name.lower() if not no_register: @@ -259,6 +259,7 @@ async def get_status(request: Request): parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-base", type=str, default=None) parser.add_argument("--model-name", type=str) + parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list()) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") parser.add_argument("--limit-model-concurrency", type=int, default=5) @@ -279,6 +280,7 @@ async def get_status(request: Request): args.model_path, args.model_base, args.model_name, + args.llm_type, args.load_8bit, args.load_4bit, args.device) diff --git a/robin/serve/pipeline.py b/robin/serve/pipeline.py index e756e0c..9444196 100644 --- a/robin/serve/pipeline.py +++ b/robin/serve/pipeline.py @@ -28,10 +28,11 @@ def load_image(image_file): class LlavaMistralPipeline: - def __init__(self, model_path, model_base, device="cuda", load_8bit=False, load_4bit=False, temperature=.2, max_new_tokens=512): + def __init__(self, model_path, model_base, llm_type=None, device="cuda", load_8bit=False, load_4bit=False, temperature=.2, max_new_tokens=512): self.model_path = model_path self.model_base = model_base + self.llm_type = llm_type self.device = device self.load_8bit = load_8bit self.load_4bit = load_4bit @@ -48,7 +49,7 @@ def __init__(self, 
model_path, model_base, device="cuda", load_8bit=False, load_ model_name = get_model_name_from_path(self.model_path) - self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(self.model_path, self.model_base, model_name, self.load_8bit, self.load_4bit, device=self.device) + self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(self.model_path, self.model_base, model_name, self.llm_type, self.load_8bit, self.load_4bit, device=self.device) def _load_image_tensor(self, image_file): diff --git a/robin/train/train.py b/robin/train/train.py index f35767e..f12840f 100644 --- a/robin/train/train.py +++ b/robin/train/train.py @@ -33,6 +33,7 @@ from robin import conversation as conversation_lib from robin.model import LlavaMistralForCausalLM, LlavaGPTNeoXForCausalLM, LlavaLlamaForCausalLM#, LlavaMPTForCausalLM [TODO] mpt is commented out at robin.model.__init__ +from robin.model.builder import LlavaMetaModel from robin.mm_utils import tokenizer_image_token, expand2square from PIL import Image @@ -51,6 +52,7 @@ def rank0_print(*args, **kwargs): @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + llm_type: Optional[LlavaMetaModel.ModelType] = field(default=None) version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) @@ -781,6 +783,10 @@ def train(USE_FLASH_ATTN_2=False): local_rank = training_args.local_rank compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + # parse llm type input as an Enum + # no need for error checking as HFArgumentParser does it for Enums + llm_type = LlavaMetaModel.ModelType(model_args.llm_type) if model_args.llm_type is not None else LlavaMetaModel.get_model_type_from_model_name(model_args.model_name_or_path) + bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig @@ -800,19 +806,18 @@ def train(USE_FLASH_ATTN_2=False): )) if model_args.vision_tower is not None: - model_name = model_args.model_name_or_path.lower() rank0_print("Loading model of type:", end=' ') - if 'mpt' in model_name: - rank0_print("MPT") - config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) - config.attn_config['attn_impl'] = training_args.mpt_attn_impl - model = LlavaMPTForCausalLM.from_pretrained( - model_args.model_name_or_path, - config=config, - cache_dir=training_args.cache_dir, - **bnb_model_from_pretrained_args - ) - elif 'mistral' in model_name: + # if llm_type == LlavaMetaModel.ModelType.LlavaMPTModel: + # rank0_print("MPT") + # config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + # config.attn_config['attn_impl'] = training_args.mpt_attn_impl + # model = LlavaMPTForCausalLM.from_pretrained( + # model_args.model_name_or_path, + # config=config, + # cache_dir=training_args.cache_dir, + # **bnb_model_from_pretrained_args + # ) + if llm_type == LlavaMetaModel.ModelType.LlavaMistralModel: rank0_print("Mistral") model = LlavaMistralForCausalLM.from_pretrained( model_args.model_name_or_path, @@ -820,7 +825,7 @@ def train(USE_FLASH_ATTN_2=False): use_flash_attention_2 = USE_FLASH_ATTN_2, **bnb_model_from_pretrained_args ) - elif any(x in model_name for x in ['neox', 'pythia', 'hi-nolin']): + elif llm_type == LlavaMetaModel.ModelType.LlavaGPTNeoXModel: rank0_print("NeoX") model = 
                 model_args.model_name_or_path,
@@ -828,7 +833,7 @@ def train(USE_FLASH_ATTN_2=False):
                 use_flash_attention_2 = USE_FLASH_ATTN_2, # The current architecture does not support Flash Attention 2.0
                 **bnb_model_from_pretrained_args
             )
-        else:
+        elif llm_type == LlavaMetaModel.ModelType.LlavaLlamaModel:
             rank0_print("Llama")
             model = LlavaLlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
@@ -836,6 +841,8 @@ def train(USE_FLASH_ATTN_2=False):
                 **bnb_model_from_pretrained_args,
                 use_flash_attention_2 = USE_FLASH_ATTN_2,
             )
+        else:
+            raise NotImplementedError
     else:
         model = transformers.LlamaForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -878,21 +885,21 @@ def make_inputs_require_grad(module, input, output):
             rank0_print("Adding LoRA adapters...")
             model = get_peft_model(model, lora_config)
 
-    if 'mpt' in model_args.model_name_or_path:
-        tokenizer = transformers.AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path,
-            cache_dir=training_args.cache_dir,
-            model_max_length=training_args.model_max_length,
-            padding_side="right"
-        )
-    else:
-        #print(model_args.model_name_or_path)
-        tokenizer = transformers.AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path,
-            cache_dir=training_args.cache_dir,
-            model_max_length=training_args.model_max_length,
-            padding_side="right",
-        )
+    # if llm_type == LlavaMetaModel.ModelType.LlavaMPTModel:
+    #     tokenizer = transformers.AutoTokenizer.from_pretrained(
+    #         model_args.model_name_or_path,
+    #         cache_dir=training_args.cache_dir,
+    #         model_max_length=training_args.model_max_length,
+    #         padding_side="right"
+    #     )
+    # else:
+    #print(model_args.model_name_or_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side="right",
+    )
 
     if model_args.version == "v0":
         if tokenizer.pad_token is None:
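
For reference, a minimal, self-contained sketch of the subclass-registry pattern this patch introduces in llava_arch.py: __init_subclass__ records each concrete LlavaMetaModel subclass and rebuilds the ModelType Enum from every subclass's config_class.model_type, and get_model_type_list() is what the new --llm-type arguments expose as argparse choices. The DummyMistralConfig class below is illustrative only; in the repo the real Hugging Face config classes supply model_type.

    from enum import Enum

    class MetaModel:
        registry = {}
        ModelType = None

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)
            # record the subclass, then rebuild the Enum:
            # member name = subclass name, member value = its config's model_type string
            cls.registry[cls.__name__] = cls
            MetaModel.ModelType = Enum(
                'ModelType',
                [(sub.__name__, sub.config_class.model_type) for sub in cls.registry.values()]
            )

        @classmethod
        def get_model_type_list(cls):
            return [m.value for m in MetaModel.ModelType]

    class DummyMistralConfig:      # stand-in for the real HF config class
        model_type = 'mistral'

    class LlavaMistralModel(MetaModel):
        config_class = DummyMistralConfig

    print(MetaModel.get_model_type_list())   # ['mistral'] -> usable as argparse choices
    print(MetaModel.ModelType('mistral'))    # ModelType.LlavaMistralModel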