5 changes: 3 additions & 2 deletions robin/eval/run_llava.py
@@ -3,7 +3,7 @@

from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from robin.conversation import conv_templates, SeparatorStyle
from robin.model.builder import load_pretrained_model
from robin.model.builder import load_pretrained_model, LlavaMetaModel
from robin.utils import disable_torch_init
from robin.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

@@ -28,7 +28,7 @@ def eval_model(args):
disable_torch_init()

model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.llm_type)

qs = args.query
if model.config.mm_use_im_start_end:
@@ -89,6 +89,7 @@ def eval_model(args):
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list())
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--query", type=str, required=True)
parser.add_argument("--conv-mode", type=str, default=None)
16 changes: 13 additions & 3 deletions robin/model/builder.py
@@ -5,10 +5,11 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from robin.model import *
from robin.model.llava_arch import LlavaMetaModel
from robin.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
def load_pretrained_model(model_path, model_base, model_name, llm_type=None, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}

if load_8bit:
@@ -31,10 +32,19 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base)
print('Loading LLaVA from base model...')

# Resolve the llm_type string into a ModelType enum member
if llm_type is not None:
try:
llm_type = LlavaMetaModel.ModelType(llm_type)
except ValueError as e:
raise ValueError(f"Invalid llm_type provided: {e}. Supported values are {', '.join(LlavaMetaModel.get_model_type_list())}")
else:
llm_type = LlavaMetaModel.get_model_type_from_model_name(model_name)

if 'mistral' in model_name.lower():
if llm_type == LlavaMetaModel.ModelType.LlavaMistralModel:
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
elif any(x in model_name.lower() for x in ['neox', 'pythia', 'hi-nolin']):
elif llm_type == LlavaMetaModel.ModelType.LlavaGPTNeoXModel:
model = LlavaGPTNeoXForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
else:
model = LlavaLlamaForCausalLM.from_pretrained(
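For reference, looking up an Enum member by value raises ValueError (not KeyError) when the value is unknown, which is what the handler above relies on. A minimal illustration; the member names match the diff, but the value strings here are illustrative:

from enum import Enum

# Functional Enum API, mirroring how LlavaMetaModel.ModelType is built.
ModelType = Enum('ModelType', [('LlavaMistralModel', 'mistral'), ('LlavaLlamaModel', 'llama')])

print(ModelType('mistral'))    # ModelType.LlavaMistralModel -- lookup by value
try:
    ModelType('unknown')
except ValueError as err:
    print(err)                 # e.g. 'unknown' is not a valid ModelType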
28 changes: 28 additions & 0 deletions robin/model/llava_arch.py
@@ -18,6 +18,8 @@
import torch
import torch.nn as nn

from enum import Enum

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

@@ -26,13 +28,39 @@

class LlavaMetaModel:

registry = {}
ModelType = None

def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)

if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.registry[cls.__name__] = cls
LlavaMetaModel.ModelType = Enum('ModelType', [(subcls.__name__, subcls.config_class.model_type) for subcls in cls.registry.values()])

# Fallback for when llm_type is not passed explicitly (backwards compatibility)
@classmethod
def get_model_type_from_model_name(cls, model_name: str) -> ModelType:
model_name = model_name.lower()
if 'mpt' in model_name:
return cls.ModelType.LlavaMPTModel
elif 'mistral' in model_name:
return cls.ModelType.LlavaMistralModel
elif any(x in model_name for x in ['neox', 'pythia', 'hi-nolin']):
return cls.ModelType.LlavaGPTNeoXModel
else:
return cls.ModelType.LlavaLlamaModel

@classmethod
def get_model_type_list(cls):
return [m.value for m in LlavaMetaModel.ModelType]


def get_vision_tower(self):
vision_tower = getattr(self, 'vision_tower', None)
if type(vision_tower) is list:
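As a rough standalone sketch (simplified, hypothetical class names and model_type strings), the __init_subclass__ hook above implements the following pattern: every concrete model class registers itself, and ModelType is rebuilt as an Enum mapping subclass names to their config model_type values, so a new backend only needs to subclass LlavaMetaModel to become selectable:

from enum import Enum

class MetaModel:
    registry = {}
    ModelType = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Record the concrete subclass and rebuild the enum from the full registry.
        MetaModel.registry[cls.__name__] = cls
        MetaModel.ModelType = Enum(
            'ModelType',
            [(sub.__name__, sub.model_type) for sub in MetaModel.registry.values()]
        )

class MistralModel(MetaModel):
    model_type = 'mistral'      # stands in for config_class.model_type

class GPTNeoXModel(MetaModel):
    model_type = 'gpt_neox'

print([m.value for m in MetaModel.ModelType])    # ['mistral', 'gpt_neox'] -> CLI choices
print(MetaModel.ModelType('mistral').name)       # 'MistralModel' -> selects the model class

This replaces scattered substring checks on the model name ('mistral', 'neox', 'pythia', ...) with a centrally registered mapping; the name-based lookup survives only as the get_model_type_from_model_name fallback.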
6 changes: 4 additions & 2 deletions robin/serve/cli.py
@@ -3,7 +3,8 @@

from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from robin.conversation import conv_templates, SeparatorStyle
from robin.model.builder import load_pretrained_model
from robin.model.builder import load_pretrained_model, LlavaMetaModel

from robin.utils import disable_torch_init
from robin.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

@@ -32,7 +33,7 @@ def main(args):

model_name = get_model_name_from_path(args.model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.llm_type, args.load_8bit, args.load_4bit, device=args.device)

conv = conv_templates[args.conv_mode].copy()
roles = conv.roles
@@ -102,6 +103,7 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="agi-collective/mistral-7b-oh-siglip-so400m-finetune-lora")
parser.add_argument("--model-base", type=str, default="teknium/OpenHermes-2.5-Mistral-7B")
parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list())
parser.add_argument("--image-file", type=str, required=True)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
8 changes: 5 additions & 3 deletions robin/serve/model_worker.py
@@ -18,7 +18,7 @@
from robin.constants import WORKER_HEART_BEAT_INTERVAL
from robin.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from robin.model.builder import load_pretrained_model
from robin.model.builder import load_pretrained_model, LlavaMetaModel
from robin.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from robin.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
@@ -44,7 +44,7 @@ def heart_beat_worker(controller):
class ModelWorker:
def __init__(self, controller_addr, worker_addr,
worker_id, no_register,
model_path, model_base, model_name,
model_path, model_base, model_name, llm_type,
load_8bit, load_4bit, device):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
@@ -63,7 +63,7 @@ def __init__(self, controller_addr, worker_addr,
self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
model_path, model_base, self.model_name, llm_type, load_8bit, load_4bit, device=self.device)
self.is_multimodal = 'llava' in self.model_name.lower()

if not no_register:
@@ -259,6 +259,7 @@ async def get_status(request: Request):
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--model-name", type=str)
parser.add_argument("--llm-type", type=str, default=None, choices=LlavaMetaModel.get_model_type_list())
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
parser.add_argument("--limit-model-concurrency", type=int, default=5)
@@ -279,6 +280,7 @@ async def get_status(request: Request):
args.model_path,
args.model_base,
args.model_name,
args.llm_type,
args.load_8bit,
args.load_4bit,
args.device)
5 changes: 3 additions & 2 deletions robin/serve/pipeline.py
@@ -28,10 +28,11 @@ def load_image(image_file):

class LlavaMistralPipeline:

def __init__(self, model_path, model_base, device="cuda", load_8bit=False, load_4bit=False, temperature=.2, max_new_tokens=512):
def __init__(self, model_path, model_base, llm_type=None, device="cuda", load_8bit=False, load_4bit=False, temperature=.2, max_new_tokens=512):

self.model_path = model_path
self.model_base = model_base
self.llm_type = llm_type
self.device = device
self.load_8bit = load_8bit
self.load_4bit = load_4bit
@@ -48,7 +49,7 @@ def __init__(self, model_path, model_base, device="cuda", load_8bit=False, load_

model_name = get_model_name_from_path(self.model_path)

self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(self.model_path, self.model_base, model_name, self.load_8bit, self.load_4bit, device=self.device)
self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(self.model_path, self.model_base, model_name, self.llm_type, self.load_8bit, self.load_4bit, device=self.device)


def _load_image_tensor(self, image_file):
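A hedged usage sketch of the extended loader, pinning the LLM class explicitly instead of relying on model-name heuristics. The checkpoint names are the defaults from robin/serve/cli.py; the model_name string is illustrative:

from robin.model.builder import load_pretrained_model, LlavaMetaModel

# Passing llm_type selects the backend directly; passing None falls back to
# LlavaMetaModel.get_model_type_from_model_name(model_name).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="agi-collective/mistral-7b-oh-siglip-so400m-finetune-lora",
    model_base="teknium/OpenHermes-2.5-Mistral-7B",
    model_name="llava-mistral-7b-lora",                          # illustrative name
    llm_type=LlavaMetaModel.ModelType.LlavaMistralModel.value,   # or any string from get_model_type_list()
)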
65 changes: 36 additions & 29 deletions robin/train/train.py
@@ -33,6 +33,7 @@

from robin import conversation as conversation_lib
from robin.model import LlavaMistralForCausalLM, LlavaGPTNeoXForCausalLM, LlavaLlamaForCausalLM#, LlavaMPTForCausalLM [TODO] mpt is commented out at robin.model.__init__
from robin.model.builder import LlavaMetaModel
from robin.mm_utils import tokenizer_image_token, expand2square

from PIL import Image
@@ -51,6 +52,7 @@ def rank0_print(*args, **kwargs):
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
llm_type: Optional[LlavaMetaModel.ModelType] = field(default=None)
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
@@ -781,6 +783,10 @@ def train(USE_FLASH_ATTN_2=False):
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

# Parse the llm_type input into a ModelType enum member.
# No extra error checking is needed: HfArgumentParser already validates Enum-typed fields.
llm_type = LlavaMetaModel.ModelType(model_args.llm_type) if model_args.llm_type is not None else LlavaMetaModel.get_model_type_from_model_name(model_args.model_name_or_path)

bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
@@ -800,42 +806,43 @@
))

if model_args.vision_tower is not None:
model_name = model_args.model_name_or_path.lower()
rank0_print("Loading model of type:", end=' ')
if 'mpt' in model_name:
rank0_print("MPT")
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
config.attn_config['attn_impl'] = training_args.mpt_attn_impl
model = LlavaMPTForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
**bnb_model_from_pretrained_args
)
elif 'mistral' in model_name:
# if llm_type == LlavaMetaModel.ModelType.LlavaMPTModel:
# rank0_print("MPT")
# config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
# config.attn_config['attn_impl'] = training_args.mpt_attn_impl
# model = LlavaMPTForCausalLM.from_pretrained(
# model_args.model_name_or_path,
# config=config,
# cache_dir=training_args.cache_dir,
# **bnb_model_from_pretrained_args
# )
if llm_type == LlavaMetaModel.ModelType.LlavaMistralModel:
rank0_print("Mistral")
model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
use_flash_attention_2 = USE_FLASH_ATTN_2,
**bnb_model_from_pretrained_args
)
elif any(x in model_name for x in ['neox', 'pythia', 'hi-nolin']):
elif llm_type == LlavaMetaModel.ModelType.LlavaGPTNeoXModel:
rank0_print("NeoX")
model = LlavaGPTNeoXForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
use_flash_attention_2 = USE_FLASH_ATTN_2, # The current architecture does not support Flash Attention 2.0
**bnb_model_from_pretrained_args
)
else:
elif llm_type == LlavaMetaModel.ModelType.LlavaLlamaModel:
rank0_print("Llama")
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
**bnb_model_from_pretrained_args,
use_flash_attention_2 = USE_FLASH_ATTN_2,
)
else:
raise NotImplementedError
else:
model = transformers.LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
@@ -878,21 +885,21 @@ def make_inputs_require_grad(module, input, output):
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)

if 'mpt' in model_args.model_name_or_path:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right"
)
else:
#print(model_args.model_name_or_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
)
# if llm_type == LlavaMetaModel.ModelType.LlavaMPTModel:
# tokenizer = transformers.AutoTokenizer.from_pretrained(
# model_args.model_name_or_path,
# cache_dir=training_args.cache_dir,
# model_max_length=training_args.model_max_length,
# padding_side="right"
# )
# else:
#print(model_args.model_name_or_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
)

if model_args.version == "v0":
if tokenizer.pad_token is None: