diff --git a/.gitignore b/.gitignore index 75580df..4319d63 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,7 @@ src/outputs/ # scrapped data scrapped_docs/* -llms/* \ No newline at end of file +*_env/ + +# vscode +.vscode* \ No newline at end of file diff --git a/client_test/call_agent.py b/client_test/call_agent.py index bd0b28c..a46c476 100644 --- a/client_test/call_agent.py +++ b/client_test/call_agent.py @@ -2,24 +2,80 @@ import pathlib import requests +USER_NAME = 'erfan_miahi' +PASS = 'temp' + +def login(port: int = 8000): + # data = {'username': USER_NAME, 'email': 'mhi.erfan1@gmail.com', 'password': PASS} + # response = requests.post(f'http://localhost:{port}/register', json=data) + data = {'username': USER_NAME, 'password': PASS} + form_data = { + 'username': USER_NAME, + 'password': PASS + } + + # Headers to be sent in the POST request + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + } + + response = requests.post(f'http://localhost:{port}/token', data=form_data, headers=headers) + if response.status_code == 200: + print('Request successful!') + # Accessing the response content + print('Response content:', response.json()) + return response.json()['access_token'] + else: + print('Request failed with status code:', response.status_code) + print('Response content:', response.text) + return None + def send_request(input_text, proceeding_text, port=8000): - response = requests.put(f"http://localhost:{port}/generate", json={ - "file_path": str(pathlib.Path(__file__).parent.absolute()), # This file's path + token = login(port) + if not token: + print("Failed to retrieve token. Exiting.") + return + + # Headers to be sent in the POST request + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {token}' + } + print('Headers: ', headers) + + # Ensure the endpoint and method are correct (change PUT to POST if necessary) + response = requests.post(f"http://localhost:{port}/generate", json={ + "file_path": str(pathlib.Path(__file__).parent.absolute()), # This file's path "prior_context": input_text, "proceeding_context": proceeding_text, "max_decode_length": 128, - }, timeout=180) - output_data = response.json() - if "error" in output_data: - print(f"Error: {output_data['error']}") - return - output_text = output_data["generated_text"] - score = output_data["score"] + }, timeout=180, headers=headers) + + # Check for successful request + if response.status_code == 200: + try: + # Check if the response is JSON + if 'application/json' in response.headers.get('Content-Type', ''): + output_data = response.json() + if "error" in output_data: + print(f"Error: {output_data['error']}") + return + output_text = output_data["generated_text"] + score = output_data["score"] - print("Input text: " + input_text) - print(f"Generated text ({score:.3f}):") - print(output_text) + print("Input text: " + input_text) + print(f"Generated text ({score:.3f}):") + print(output_text) + else: + print("Response is not JSON:") + print(response.text) + except ValueError as e: + print('Error decoding JSON:', e) + print('Response content:', response.text) + else: + print(f"Request failed with status code: {response.status_code}") + print(f"Response content: {response.text}") if __name__ == "__main__": diff --git a/eval/run_eval.py b/eval/run_eval.py index c7ce6a8..f4eb6d7 100644 --- a/eval/run_eval.py +++ b/eval/run_eval.py @@ -21,7 +21,8 @@ sys.path.append('../') from src import config_handler -from src.modeling import ModelProvider, FIM_HOLE_TOKEN +from 
src.modeling.model_provider import ModelProvider +from src.modeling.tokens import FIM_HOLE_TOKEN from src.routers.fine_tuner import finetune_model, ProjectFinetuneData from src.training import finetune from benchmarks import run_human_eval_benchmark diff --git a/eval/run_rag_eval.py b/eval/run_rag_eval.py index 632d5bb..1d89166 100644 --- a/eval/run_rag_eval.py +++ b/eval/run_rag_eval.py @@ -31,7 +31,8 @@ from src import config_handler # from src import finetune, modeling # from src.data_formatting import IGNORE_INDEX, FIM_HOLE_TOKEN -from src.modeling import ModelProvider, FIM_HOLE_TOKEN +from src.modeling.model_provider import ModelProvider +from src.modeling.tokens import FIM_HOLE_TOKEN from src.rag import retrieve_context, VectorStoreProvider from src.routers.fine_tuner import collect_item_data, finetune_model, ProjectFinetuneData from src.training import finetune diff --git a/eval/utils.py b/eval/utils.py index 0a0f5d1..d589ebc 100644 --- a/eval/utils.py +++ b/eval/utils.py @@ -8,7 +8,7 @@ import torch sys.path.append('../') -from src.modeling import ModelProvider +from src.modeling.model_provider import ModelProvider def create_new_model_tuple(model_provider: ModelProvider): diff --git a/requirements.txt b/requirements.txt index 71334d9..3742c3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,27 @@ -transformers==4.38.2 +vllm==0.5.1 +transformers fastapi==0.108.0 -torch==2.2.1 +torch jax pydantic[email]==2.5.3 uvicorn[standard]==0.20.0 -datasets==2.16.1 -accelerate==0.27.2 +datasets +accelerate attrdict -tqdm==4.66.1 -bitsandbytes==0.41.2.post2 -peft==0.6.2 +tqdm +bitsandbytes +peft pytest==7.2.1 hydra-core==1.3.2 omegaconf==2.3.0 mock==5.1.0 -numpy==1.24.4 +numpy aiosqlite==0.20.0 -openai==1.13.3 -trl==0.7.10 +openai +trl packaging==23.2 ninja==1.11.1.1 -scipy==1.12.0 +scipy python-jose==3.3.0 passlib==1.7.4 python-multipart==0.0.5 @@ -38,5 +39,4 @@ charset-normalizer==3.3.2 pymupdf==1.24.5 pymupdf4llm==0.0.5 nougat-ocr==0.1.17 -llama-index==0.10.48.post1 -nougat-ocr==0.1.17 \ No newline at end of file +llama-index \ No newline at end of file diff --git a/src/conf/config.yaml b/src/conf/config.yaml index 490dc46..ac1f9d6 100644 --- a/src/conf/config.yaml +++ b/src/conf/config.yaml @@ -45,6 +45,9 @@ model: quant_type: nf4 optim: paged_adamw_32bit gradient_checkpointing: True + + # inference model parameters + inference_model_type: vllm inference: diff --git a/src/main.py b/src/main.py index 3b01aff..ef81a6f 100644 --- a/src/main.py +++ b/src/main.py @@ -12,7 +12,9 @@ sys.path.append('../') from src import config_handler, database -from src.modeling import ModelProvider, set_main_thread_id +from src.session import set_main_thread_id +from src.modeling.model_provider import ModelProvider + from src.rag import VectorStoreProvider from src.routers import auth, fine_tuner, generator from src.users import SessionTracker diff --git a/src/modeling.py b/src/modeling.py deleted file mode 100644 index 0eb1238..0000000 --- a/src/modeling.py +++ /dev/null @@ -1,459 +0,0 @@ -import abc -from argparse import Namespace -from concurrent.futures import CancelledError -import copy -from functools import partial -import logging -import os -import threading -import types -from typing import NamedTuple, Tuple - -import bitsandbytes as bnb -from huggingface_hub import snapshot_download -from peft import ( - prepare_model_for_kbit_training, - LoraConfig, - get_peft_model -) -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import 
BitsAndBytesConfig - -from src.users import SessionTracker - - -logger = logging.getLogger(__name__) - -# Create a global variable to store the main thread ID -GLOBAL_MAIN_THREAD_ID = None - -EOT_TOKEN = '<|EOT|>' -FIM_BEGIN_TOKEN = '<|fim▁begin|>' # 32016 -FIM_HOLE_TOKEN = '<|fim▁hole|>' # 32015 -FIM_END_TOKEN = '<|fim▁end|>' # 32017 - - -def thread_hook(username, *args): - current_thread_id = threading.get_ident() - - session_tracker = SessionTracker.get_instance() - activity_threads = session_tracker.get_user_activity_threads(username) - - if current_thread_id not in activity_threads: - raise CancelledError("Cancelled by new request") - - -def set_main_thread_id(): - global GLOBAL_MAIN_THREAD_ID - GLOBAL_MAIN_THREAD_ID = threading.get_ident() - - -class ModelLoader(abc.ABC): - """This class has the responsibility of providing the functionality to load the model and its utilities, including tokenizer. - - Args: - config (Namespace): The configuration object. - """ - - def __init__(self, config: Namespace): - self._config = config - - def _determine_model_type(self): - if self._config.fp16: - return torch.float16 - elif self._config.bf16: - return torch.bfloat16 - else: - return torch.float32 - - def _determine_device(self): - if self._config.device == 'cpu': - return torch.device('cpu') - elif self._config.device == 'cuda': - assert torch.cuda.is_available(), 'CUDA device is not available' - return torch.device('cuda') - else: - raise Exception(f"Unknown device: {self._config.device}") - - def _find_all_linear_names(self, model): - cls = bnb.nn.Linear4bit if self._config.bits == 4 else \ - (bnb.nn.Linear8bitLt if self._config.bits == 8 else torch.nn.Linear) - lora_module_names = set() - for name, module in model.named_modules(): - if isinstance(module, cls): - names = name.split('.') - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') - return list(lora_module_names) - - def _determine_flash_attention(self): - if not self._config.use_flash_attention: - return False - - flash_attention_compatible = \ - 'cuda' in self._config.device and (self._config.fp16 or self._config.bf16) - - if not flash_attention_compatible: - raise ValueError("Flash attention is only compatible with CUDA and FP16/BF16!") - - return True - - @abc.abstractmethod - def load_model(self) -> Tuple[torch.nn.Module, dict]: - pass - - -def make_synchronous_func(tokenizer, lock, source_func): - def new_func(*args, **kwargs): - with lock: - return source_func(*args, **kwargs) - return new_func - - -class SynchronizedTokenizer(): - """ - A synchronized tokenizer class that creates partially synchronous tokenizers - to prevent tokenizer concurrency issues. - This could probably be made more efficient by allowing multiple encode / decodes - at the same time, but just making sure that writes are synchronous with encodes / - decodes. For now, the this seems efficient enough. 
- """ - - def from_tokenizer(tokenizer): - tokenizer._lock = threading.RLock() - - # Take the functions of this class and copy them over to the tokenizer - # This allows us to copy the tokenizer - for func in ('__getstate__', '__setstate__', '__deepcopy__'): - source_func = getattr(SynchronizedTokenizer, func) - setattr(tokenizer, func, types.MethodType(source_func, tokenizer)) - - # Make these function synchronous with delete and set var functions - for func in ('__call__', 'encode', 'decode'): - source_func = getattr(tokenizer, func) - setattr(tokenizer, func, make_synchronous_func(tokenizer, tokenizer._lock, source_func)) - - for func in ('__setattr__', '__delattr__'): - source_func = getattr(tokenizer, func) - setattr(tokenizer, func, make_synchronous_func(tokenizer, tokenizer._lock, source_func)) - - return tokenizer - - def __getstate__(self): - state = self.__dict__.copy() - # Remove the lock from the state because it cannot be pickled. - del state['_lock'] - return state - - def __setstate__(self, state): - # Reinitialize the lock after unpickling. - state['_lock'] = threading.RLock() - self.__dict__.update(state) - - def __deepcopy__(self, memo): - # Customize the deepcopy behavior to handle the lock. - cls = self.__class__ - # Create a new instance without calling __init__ (to avoid creating another lock initially). - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k == '_lock': - # Initialize a new lock for the copied object. - setattr(result, k, threading.RLock()) - else: - setattr(result, k, copy.deepcopy(v, memo)) - return result - - -class StandardModelLoader(ModelLoader): - - def load_model(self, device=None) -> Tuple[torch.nn.Module, dict]: - - global GLOBAL_MAIN_THREAD_ID - - model_dir = self._config.model_dir - model_name = self._config.model_name - - device = device or self._determine_device() - dtype = self._determine_model_type() - use_flash_attention = self._determine_flash_attention() - - if use_flash_attention: - attn_implementation = 'flash_attention_2' - logger.info('Using flash attention') - else: - attn_implementation = None - - # Ideally we should check all the files, but for now just check one - model_dir = os.path.join(model_dir, model_name) - if not os.path.isfile(os.path.join(model_dir, 'config.json')): - snapshot_download(model_name, local_dir=model_dir, local_dir_use_symlinks=False, - ignore_patterns=['*.msgpack', '*.h5']) - - os.environ['TOKENIZERS_PARALLELISM'] = \ - os.environ.get('TOKENIZERS_PARALLELISM', 'true') - tokenizer = AutoTokenizer.from_pretrained(model_dir) - tokenizer.model_max_length = self._config.context_length - tokenizer.truncation_side = 'left' - - model = AutoModelForCausalLM.from_pretrained( - model_dir, attn_implementation=attn_implementation, - device_map=device, torch_dtype=dtype) - model.to(device) - - GLOBAL_MAIN_THREAD_ID = threading.get_ident() - - return model, {'tokenizer': SynchronizedTokenizer.from_tokenizer(tokenizer)} - - -class LoraModelLoader(ModelLoader): - - def load_model(self) -> Tuple[torch.nn.Module, dict]: - - global GLOBAL_MAIN_THREAD_ID - - model_dir = self._config.model_dir - model_name = self._config.model_name - - device = self._determine_device() - dtype = self._determine_model_type() - use_flash_attention = self._determine_flash_attention() - - if use_flash_attention: - logger.info("Using flash attention") - - # Ideally we should check all the files, but for now just check one - model_dir = os.path.join(model_dir, model_name) - if not 
os.path.isfile(os.path.join(model_dir, 'config.json')): - snapshot_download(model_name, local_dir=model_dir, local_dir_use_symlinks=False, - ignore_patterns=['*.msgpack', '*.h5']) - - tokenizer = AutoTokenizer.from_pretrained(model_dir) - tokenizer.model_max_length = self._config.context_length - tokenizer.truncation_side = 'left' - - model = AutoModelForCausalLM.from_pretrained( - model_dir, use_flash_attention_2=use_flash_attention, - device_map=device, torch_dtype=dtype - ) - - modules = self._find_all_linear_names(model) - logger.info(f"Modules to be fine-tuned: {modules}") - config = LoraConfig( - r=self._config.lora_r, - lora_alpha=self._config.lora_alpha, - target_modules=modules, - lora_dropout=self._config.lora_dropout, - bias="none", - task_type="CAUSAL_LM", - ) - model = get_peft_model(model, config) - model.to(device) - - for _, md in model.named_modules(): - md.register_forward_hook(thread_hook) - - GLOBAL_MAIN_THREAD_ID = threading.get_ident() - - return model, {'tokenizer': tokenizer} - - -class QLoraModelLoader(ModelLoader): - - def load_model(self) -> Tuple[torch.nn.Module, dict]: - - global GLOBAL_MAIN_THREAD_ID - - model_dir = self._config.model_dir - model_name = self._config.model_name - - device = self._determine_device() - dtype = self._determine_model_type() - use_flash_attention = self._determine_flash_attention() - - if use_flash_attention: - logger.info('Using flash attention') - - # Ideally we should check all the files, but for now just check one - model_dir = os.path.join(model_dir, model_name) - if not os.path.isfile(os.path.join(model_dir, 'config.json')): - snapshot_download(model_name, local_dir=model_dir, local_dir_use_symlinks=False, - ignore_patterns=['*.msgpack', '*.h5']) - - tokenizer = AutoTokenizer.from_pretrained(model_dir) - tokenizer.model_max_length = self._config.context_length - tokenizer.truncation_side = 'left' - - bb_config = BitsAndBytesConfig( - load_in_4bit=self._config.bits == 4, - load_in_8bit=self._config.bits == 8, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_use_double_quant=self._config.double_quant, - bnb_4bit_quant_type=self._config.quant_type, - ) - - model = AutoModelForCausalLM.from_pretrained( - model_dir, use_flash_attention_2=use_flash_attention, - # device_map=device, # it weirdly crashes if I add the device for QLORA - torch_dtype=dtype, use_cache=False, - load_in_4bit=self._config.bits == 4, - load_in_8bit=self._config.bits == 8, - config=bb_config - ) - - model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=self._config.gradient_checkpointing) - - modules = self._find_all_linear_names(model) - logger.info(f"Modules to be fine-tuned: {modules}") - config = LoraConfig( - r=self._config.lora_r, - lora_alpha=self._config.lora_alpha, - target_modules=modules, - lora_dropout=self._config.lora_dropout, - bias="none", - task_type="CAUSAL_LM", - ) - model = get_peft_model(model, config) - model.to(device) - - GLOBAL_MAIN_THREAD_ID = threading.get_ident() - - return model, {'tokenizer': tokenizer} - - -# Named tuple for model and model utilities -ModelTuple = NamedTuple( - 'ModelTuple', - [('model', torch.nn.Module), ('model_utils', dict)] -) - - -class ModelProvider: - _instance = None - _lock = threading.Lock() - _model_loaders = { - 'standard': StandardModelLoader, - 'lora': LoraModelLoader, - 'qlora': QLoraModelLoader - } - - @classmethod - def get_instance(cls, config: dict = None): - # First check for the singleton instance existence without acquiring the lock - if 
cls._instance is None: - # Acquire the lock and check again to ensure no other thread created the instance - with cls._lock: - if cls._instance is None: - cls._instance = cls(config) - return cls._instance - - def __init__(self, config: dict): - if ModelProvider._instance is not None: - raise Exception("This class is a singleton!") - else: - ModelProvider._instance = self - - # Initialize the model here - self.config = config - self._models: dict[str, ModelTuple] = {} - self._user_locks = {} # Locks for each user's model - - # New models are cloned from the base model, which is stored on the CPU - # This drastically reduces load times for newly active users - model_loader = ModelProvider._model_loaders[config.model_type](config) - self._target_device = model_loader._determine_device() - model, model_utils = model_loader.load_model(device=self._target_device) - self._base_model_tuple = ModelTuple(model, model_utils) - - def _get_user_lock(self, username: str): - """Get the lock for the user's model.""" - if username not in self._user_locks: - with ModelProvider._lock: - if username not in self._user_locks: - self._user_locks[username] = threading.Lock() - return self._user_locks[username] - - def create_new_model_tuple(self) -> ModelTuple: - """Create a new model tuple from the base model.""" - model = copy.deepcopy(self._base_model_tuple.model) - model = model.to(self._target_device) - # We also want to copy the tokenizer because HuggingFace tokenizer may - # throw an error if you try to use them concurrently - # Previously was having this problem: - # https://github.com/huggingface/tokenizers/issues/537 - model_utils = copy.deepcopy(self._base_model_tuple.model_utils) - return ModelTuple(model, model_utils) - - def _register_preemption_hooks(self, model: torch.nn.Module, username: str): - """Register hooks that preempt the model if a newer request is made. - - This allows us to cancel the current request if a new one is made, - saving resources and preventing the server from being overwhelmed for - requests that are no longer needed. - - Args: - model (torch.Module): The model to register the hooks on. - username (str): The username of the user. 
- """ - for _, md in model.named_modules(): - md.register_forward_hook( - partial(thread_hook, username)) - - def get_model_tuple(self, username: str) -> ModelTuple: - """Get the model and model utilities for the user.""" - with self._get_user_lock(username): - if username not in self._models: - self._models[username] = self.create_new_model_tuple() - model = self._models[username].model - self._register_preemption_hooks(model, username) - return self._models[username] - - def get_model(self, username: str): - return self.get_model_tuple(username).model - - def get_model_utils(self, username: str): - return self.get_model_tuple(username).model_utils - - def update_model(self, username: str, model: torch.nn.Module): - """Update the model for the user.""" - with self._get_user_lock(username): - if username in self._models: - self._models[username] = ModelTuple(model, self._models[username].model_utils) - else: - raise ValueError(f"Model for user {username} does not exist.") - - def delete_model(self, username: str): - """Delete the model for the user.""" - with self._get_user_lock(username): - if username in self._models: - del self._models[username] - else: - logger.warning(f"Tried to delete model for user {username}, but it does not exist.") - - torch.cuda.empty_cache() - - with ModelProvider._lock: - if username in self._user_locks: - del self._user_locks[username] - else: - logger.warning(f"Tried to delete lock for user {username}, but it does not exist.") - - -def get_model(username: str): - model_provider = ModelProvider.get_instance() - return model_provider.get_model(username) - - -def get_model_utils(username: str): - model_provider = ModelProvider.get_instance() - return model_provider.get_model_utils(username) - - -def get_tokenizer(username: str): - model_utils = get_model_utils(username) - return model_utils['tokenizer'] diff --git a/src/modeling/model_hub.py b/src/modeling/model_hub.py new file mode 100644 index 0000000..4f36a64 --- /dev/null +++ b/src/modeling/model_hub.py @@ -0,0 +1,21 @@ +from src.modeling.model_provider import ModelProvider + + +def get_model(username: str): + model_provider = ModelProvider.get_instance() + return model_provider.get_model(username) + + +def get_model_utils(username: str): + model_provider = ModelProvider.get_instance() + return model_provider.get_model_utils(username) + + +def get_tokenizer(username: str): + model_utils = get_model_utils(username) + return model_utils['tokenizer'] + + +def get_inference_model(username: str): + model_provider = ModelProvider.get_instance() + return model_provider.get_inference_model(username) diff --git a/src/modeling/model_loaders/__init__.py b/src/modeling/model_loaders/__init__.py new file mode 100644 index 0000000..49d0c5f --- /dev/null +++ b/src/modeling/model_loaders/__init__.py @@ -0,0 +1,3 @@ +from src.modeling.model_loaders.standard import StandardModelLoader +from src.modeling.model_loaders.lora import LoraModelLoader +from src.modeling.model_loaders.qlora import QLoraModelLoader \ No newline at end of file diff --git a/src/modeling/model_loaders/base.py b/src/modeling/model_loaders/base.py new file mode 100644 index 0000000..0ef9b98 --- /dev/null +++ b/src/modeling/model_loaders/base.py @@ -0,0 +1,80 @@ +import abc +from argparse import Namespace +from typing import NamedTuple, Tuple, Union, Any, Optional + +import bitsandbytes as bnb +import torch +from vllm import LLM + +from src.modeling.model_wrappers import VLLMWrapper, HuggingFaceModelWrapper + + +class ModelLoader(abc.ABC): + """This class 
has the responsibility of providing the functionality to load the model and its utilities, including tokenizer. + + Args: + config (Namespace): The configuration object. + """ + + _model_wrappers = { + 'vllm': VLLMWrapper, + 'huggingface': HuggingFaceModelWrapper + } + + def __init__(self, config: Namespace): + self._config = config + + def _determine_model_type(self): + if self._config.fp16: + return torch.float16 + elif self._config.bf16: + return torch.bfloat16 + else: + return torch.float32 + + def _determine_device(self): + if self._config.device == 'cpu': + return torch.device('cpu') + elif self._config.device == 'cuda': + assert torch.cuda.is_available(), 'CUDA device is not available' + return torch.device('cuda') + else: + raise Exception(f"Unknown device: {self._config.device}") + + def _find_all_linear_names(self, model): + cls = bnb.nn.Linear4bit if self._config.bits == 4 else \ + (bnb.nn.Linear8bitLt if self._config.bits == 8 else torch.nn.Linear) + lora_module_names = set() + for name, module in model.named_modules(): + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + def _determine_flash_attention(self): + if not self._config.use_flash_attention: + return False + + flash_attention_compatible = \ + 'cuda' in self._config.device and (self._config.fp16 or self._config.bf16) + + if not flash_attention_compatible: + raise ValueError("Flash attention is only compatible with CUDA and FP16/BF16!") + + return True + + @abc.abstractmethod + def load_model(self) -> Tuple[torch.nn.Module, dict]: + pass + + def load_inference_model(self) -> Tuple[Union[LLM, torch.nn.Module], dict]: + model, utils = self._load_inference_model(self._config.inference_model_type) + model = self._model_wrappers[self._config.inference_model_type](model, utils.get('tokenizer', None)) + return model, utils + + @abc.abstractmethod + def _load_inference_model(self) -> Tuple[Union[LLM, torch.nn.Module], dict]: + pass \ No newline at end of file diff --git a/src/modeling/model_loaders/lora.py b/src/modeling/model_loaders/lora.py new file mode 100644 index 0000000..e7305e1 --- /dev/null +++ b/src/modeling/model_loaders/lora.py @@ -0,0 +1,83 @@ +import logging +import os +import threading +from typing import NamedTuple, Tuple, Union, Any, Optional + +from huggingface_hub import snapshot_download +from peft import ( + prepare_model_for_kbit_training, + LoraConfig, + get_peft_model +) +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from vllm import LLM + +# from src.routers.generator import GenerateData +from src.modeling.model_loaders.base import ModelLoader +from src.session import thread_hook + +logger = logging.getLogger(__name__) + + +class LoraModelLoader(ModelLoader): + + def load_model(self) -> Tuple[torch.nn.Module, dict]: + + global GLOBAL_MAIN_THREAD_ID + + model_dir = self._config.model_dir + model_name = self._config.model_name + + device = self._determine_device() + dtype = self._determine_model_type() + use_flash_attention = self._determine_flash_attention() + + if use_flash_attention: + logger.info("Using flash attention") + + # Ideally we should check all the files, but for now just check one + model_dir = os.path.join(model_dir, model_name) + if not os.path.isfile(os.path.join(model_dir, 'config.json')): + snapshot_download(model_name, local_dir=model_dir, 
local_dir_use_symlinks=False, + ignore_patterns=['*.msgpack', '*.h5']) + + tokenizer = AutoTokenizer.from_pretrained(model_dir) + tokenizer.model_max_length = self._config.context_length + tokenizer.truncation_side = 'left' + + model = AutoModelForCausalLM.from_pretrained( + model_dir, use_flash_attention_2=use_flash_attention, + device_map=device, torch_dtype=dtype + ) + + modules = self._find_all_linear_names(model) + logger.info(f"Modules to be fine-tuned: {modules}") + config = LoraConfig( + r=self._config.lora_r, + lora_alpha=self._config.lora_alpha, + target_modules=modules, + lora_dropout=self._config.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + model.to(device) + + for _, md in model.named_modules(): + md.register_forward_hook(thread_hook) + + GLOBAL_MAIN_THREAD_ID = threading.get_ident() + + return model, {'tokenizer': tokenizer} + + def _load_inference_model(self) -> Tuple[Union[LLM, torch.nn.Module], dict]: + model_name = self._config.model_name + + model = LLM( + model=model_name, + gpu_memory_utilization=0.4, + enable_lora=True + ) + + return model, {} \ No newline at end of file diff --git a/src/modeling/model_loaders/qlora.py b/src/modeling/model_loaders/qlora.py new file mode 100644 index 0000000..d50c759 --- /dev/null +++ b/src/modeling/model_loaders/qlora.py @@ -0,0 +1,89 @@ +import logging +import os +import threading +from typing import NamedTuple, Tuple, Union, Any, Optional + +from huggingface_hub import snapshot_download +from peft import ( + prepare_model_for_kbit_training, + LoraConfig, + get_peft_model +) +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import BitsAndBytesConfig + +from src.modeling.model_loaders.base import ModelLoader + + +logger = logging.getLogger(__name__) + + +class QLoraModelLoader(ModelLoader): + + def load_model(self) -> Tuple[torch.nn.Module, dict]: + + global GLOBAL_MAIN_THREAD_ID + + model_dir = self._config.model_dir + model_name = self._config.model_name + + device = self._determine_device() + dtype = self._determine_model_type() + use_flash_attention = self._determine_flash_attention() + + if use_flash_attention: + logger.info('Using flash attention') + + # Ideally we should check all the files, but for now just check one + model_dir = os.path.join(model_dir, model_name) + if not os.path.isfile(os.path.join(model_dir, 'config.json')): + snapshot_download(model_name, local_dir=model_dir, local_dir_use_symlinks=False, + ignore_patterns=['*.msgpack', '*.h5']) + + tokenizer = AutoTokenizer.from_pretrained(model_dir) + tokenizer.model_max_length = self._config.context_length + tokenizer.truncation_side = 'left' + + bb_config = BitsAndBytesConfig( + load_in_4bit=self._config.bits == 4, + load_in_8bit=self._config.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_use_double_quant=self._config.double_quant, + bnb_4bit_quant_type=self._config.quant_type, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_dir, use_flash_attention_2=use_flash_attention, + # device_map=device, # it weirdly crashes if I add the device for QLORA + torch_dtype=dtype, use_cache=False, + load_in_4bit=self._config.bits == 4, + load_in_8bit=self._config.bits == 8, + config=bb_config + ) + + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=self._config.gradient_checkpointing) + + modules = self._find_all_linear_names(model) + logger.info(f"Modules to be fine-tuned: {modules}") + config = LoraConfig( 
+            r=self._config.lora_r,
+            lora_alpha=self._config.lora_alpha,
+            target_modules=modules,
+            lora_dropout=self._config.lora_dropout,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        model = get_peft_model(model, config)
+        model.to(device)
+
+        GLOBAL_MAIN_THREAD_ID = threading.get_ident()
+
+        return model, {'tokenizer': tokenizer}
+
+    def _load_inference_model(self) -> Tuple[torch.nn.Module, dict]:
+        raise ValueError('vLLM does not support QLoRA yet!')
+
+
diff --git a/src/modeling/model_loaders/standard.py b/src/modeling/model_loaders/standard.py
new file mode 100644
index 0000000..50ae94e
--- /dev/null
+++ b/src/modeling/model_loaders/standard.py
@@ -0,0 +1,142 @@
+import copy
+import logging
+import os
+import threading
+import types
+from typing import NamedTuple, Tuple, Union, Any, Optional
+
+from huggingface_hub import snapshot_download
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from vllm import LLM
+
+from src.modeling.model_loaders.base import ModelLoader
+
+
+logger = logging.getLogger(__name__)
+
+
+def make_synchronous_func(tokenizer, lock, source_func):
+    def new_func(*args, **kwargs):
+        with lock:
+            return source_func(*args, **kwargs)
+    return new_func
+
+
+class SynchronizedTokenizer():
+    """
+    A synchronized tokenizer class that creates partially synchronous tokenizers
+    to prevent tokenizer concurrency issues.
+    This could probably be made more efficient by allowing multiple encodes / decodes
+    at the same time, while only making sure that writes are synchronous with encodes /
+    decodes. For now, this seems efficient enough.
+    """
+
+    def from_tokenizer(tokenizer):
+        tokenizer._lock = threading.RLock()
+
+        # Take the functions of this class and copy them over to the tokenizer
+        # This allows us to copy the tokenizer
+        for func in ('__getstate__', '__setstate__', '__deepcopy__'):
+            source_func = getattr(SynchronizedTokenizer, func)
+            setattr(tokenizer, func, types.MethodType(source_func, tokenizer))
+
+        # Make these functions synchronous with the delete and set attribute functions
+        for func in ('__call__', 'encode', 'decode'):
+            source_func = getattr(tokenizer, func)
+            setattr(tokenizer, func, make_synchronous_func(tokenizer, tokenizer._lock, source_func))
+
+        for func in ('__setattr__', '__delattr__'):
+            source_func = getattr(tokenizer, func)
+            setattr(tokenizer, func, make_synchronous_func(tokenizer, tokenizer._lock, source_func))
+
+        return tokenizer
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove the lock from the state because it cannot be pickled.
+        del state['_lock']
+        return state
+
+    def __setstate__(self, state):
+        # Reinitialize the lock after unpickling.
+        state['_lock'] = threading.RLock()
+        self.__dict__.update(state)
+
+    def __deepcopy__(self, memo):
+        # Customize the deepcopy behavior to handle the lock.
+        cls = self.__class__
+        # Create a new instance without calling __init__ (to avoid creating another lock initially).
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if k == '_lock':
+                # Initialize a new lock for the copied object.
+                setattr(result, k, threading.RLock())
+            else:
+                setattr(result, k, copy.deepcopy(v, memo))
+        return result
+
+
+
+
+class StandardModelLoader(ModelLoader):
+
+    def _load_vllm_model(self):
+        model_name = self._config.model_name
+
+        model = LLM(
+            model=model_name,
+            gpu_memory_utilization=0.4,
+            enable_lora=False
+        )
+
+        return model, {}
+
+    def _load_huggingface_model(self, device=None):
+        global GLOBAL_MAIN_THREAD_ID
+
+        model_dir = self._config.model_dir
+        model_name = self._config.model_name
+
+        device = device or self._determine_device()
+        dtype = self._determine_model_type()
+        use_flash_attention = self._determine_flash_attention()
+
+        if use_flash_attention:
+            attn_implementation = 'flash_attention_2'
+            logger.info('Using flash attention')
+        else:
+            attn_implementation = None
+
+        # Ideally we should check all the files, but for now just check one
+        model_dir = os.path.join(model_dir, model_name)
+        if not os.path.isfile(os.path.join(model_dir, 'config.json')):
+            snapshot_download(model_name, local_dir=model_dir, local_dir_use_symlinks=False,
+                              ignore_patterns=['*.msgpack', '*.h5'])
+
+        os.environ['TOKENIZERS_PARALLELISM'] = \
+            os.environ.get('TOKENIZERS_PARALLELISM', 'true')
+        tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        tokenizer.model_max_length = self._config.context_length
+        tokenizer.truncation_side = 'left'
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_dir, attn_implementation=attn_implementation,
+            device_map=device, torch_dtype=dtype)
+        model.to(device)
+
+        GLOBAL_MAIN_THREAD_ID = threading.get_ident()
+
+        return model, {'tokenizer': SynchronizedTokenizer.from_tokenizer(tokenizer)}
+
+    def load_model(self, device=None) -> Tuple[torch.nn.Module, dict]:
+        return self._load_huggingface_model(device)
+
+    def _load_inference_model(self, model_type: str) -> Tuple[Union[LLM, torch.nn.Module], dict]:
+        if model_type == 'huggingface':
+            return self._load_huggingface_model()
+        elif model_type == 'vllm':
+            return self._load_vllm_model()
+        else:
+            raise ValueError(f'This model type is not supported: {model_type}')
\ No newline at end of file
diff --git a/src/modeling/model_provider.py b/src/modeling/model_provider.py
new file mode 100644
index 0000000..05dbadd
--- /dev/null
+++ b/src/modeling/model_provider.py
@@ -0,0 +1,166 @@
+import abc
+from argparse import Namespace
+from concurrent.futures import CancelledError
+import copy
+from functools import partial
+import logging
+import os
+import threading
+from typing import NamedTuple
+
+import torch
+
+from src.modeling.model_loaders import StandardModelLoader, LoraModelLoader, QLoraModelLoader
+from src.session import thread_hook
+
+logger = logging.getLogger(__name__)
+
+
+# Named tuple for model and model utilities
+ModelTuple = NamedTuple(
+    'ModelTuple',
+    [('model', torch.nn.Module), ('model_utils', dict)]
+)
+
+class ModelProvider:
+    _instance = None
+    _lock = threading.Lock()
+    _model_loaders = {
+        'standard': StandardModelLoader,
+        'lora': LoraModelLoader,
+        'qlora': QLoraModelLoader
+    }
+
+    @classmethod
+    def get_instance(cls, config: dict = None):
+        # First check for the singleton instance existence without acquiring the lock
+        if cls._instance is None:
+            # Acquire the lock and check again to ensure no other thread created the instance
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = cls(config)
+        return cls._instance
+
+    def __init__(self, config: dict):
+        if ModelProvider._instance is not None:
+            raise Exception("This class is a singleton!")
+        else:
+            ModelProvider._instance = self
+
+        # 
Initialize the model here + self.config = config + self._models: dict[str, ModelTuple] = {} + self._inference_models: dict[str, ModelTuple] = {} + self._user_locks = {} # Locks for each user's model + + # New models are cloned from the base model, which is stored on the CPU + # This drastically reduces load times for newly active users + model_loader = ModelProvider._model_loaders[config.model_type](config) + self._target_device = model_loader._determine_device() + model, model_utils = model_loader.load_model(device=self._target_device) + inf_model, inf_model_utils = model_loader.load_inference_model() #NOTE: VLLM only works in GPU mode for now + self._base_model_tuple = ModelTuple(model, model_utils) + self._base_inf_model_tuple = ModelTuple(inf_model, inf_model_utils) + + def _get_user_lock(self, username: str): + """Get the lock for the user's model.""" + if username not in self._user_locks: + with ModelProvider._lock: + if username not in self._user_locks: + self._user_locks[username] = threading.Lock() + return self._user_locks[username] + + def create_new_model_tuple(self) -> ModelTuple: + """Create a new model tuple from the base model.""" + model = copy.deepcopy(self._base_model_tuple.model) + model = model.to(self._target_device) + # We also want to copy the tokenizer because HuggingFace tokenizer may + # throw an error if you try to use them concurrently + # Previously was having this problem: + # https://github.com/huggingface/tokenizers/issues/537 + model_utils = copy.deepcopy(self._base_model_tuple.model_utils) + return ModelTuple(model, model_utils) + + def create_new_inference_model_tuple(self) -> ModelTuple: + """Create a new model tuple from the base model.""" + # model = copy.deepcopy(self._base_inf_model_tuple.model) + model = self._base_inf_model_tuple.model + # model = model.to(self._target_device) + # We also want to copy the tokenizer because HuggingFace tokenizer may + # throw an error if you try to use them concurrently + # Previously was having this problem: + # https://github.com/huggingface/tokenizers/issues/537 + # model_utils = copy.deepcopy(self._base_inf_model_tuple.model_utils) + model_utils = self._base_inf_model_tuple.model_utils + return ModelTuple(model, model_utils) + + def _register_preemption_hooks(self, model: torch.nn.Module, username: str): + """Register hooks that preempt the model if a newer request is made. + + This allows us to cancel the current request if a new one is made, + saving resources and preventing the server from being overwhelmed for + requests that are no longer needed. + + Args: + model (torch.Module): The model to register the hooks on. + username (str): The username of the user. 
+ """ + for _, md in model.named_modules(): + md.register_forward_hook( + partial(thread_hook, username)) + + def get_model_tuple(self, username: str) -> ModelTuple: + """Get the model and model utilities for the user.""" + with self._get_user_lock(username): + if username not in self._models: + self._models[username] = self.create_new_inference_model_tuple() + model = self._models[username].model + self._register_preemption_hooks(model, username) + return self._models[username] + + def get_inference_model_tuple(self, username: str) -> ModelTuple: + """Get the inference model and model utilities for the user.""" + # TODO: think about how to use a single inference model in the case of multiple users + with self._get_user_lock(username): # TODO: the inference and training lock should be different + if username not in self._inference_models: + self._inference_models[username] = self.create_new_inference_model_tuple() + # TODO: figure out forward hook is required for vllm like models + # model = self._inf_models[username].model + # self._register_preemption_hooks(model, username) + + return self._inference_models[username] + + def get_inference_model(self, username: str): + return self.get_inference_model_tuple(username).model + + def get_model(self, username: str): + return self.get_model_tuple(username).model + + def get_model_utils(self, username: str): + return self.get_model_tuple(username).model_utils + + def update_model(self, username: str, model: torch.nn.Module): + """Update the model for the user.""" + with self._get_user_lock(username): + if username in self._models: + self._models[username] = ModelTuple(model, self._models[username].model_utils) + else: + raise ValueError(f"Model for user {username} does not exist.") + + def delete_model(self, username: str): + """Delete the model for the user.""" + with self._get_user_lock(username): + if username in self._models: + del self._models[username] + else: + logger.warning(f"Tried to delete model for user {username}, but it does not exist.") + + torch.cuda.empty_cache() + + with ModelProvider._lock: + if username in self._user_locks: + del self._user_locks[username] + else: + logger.warning(f"Tried to delete lock for user {username}, but it does not exist.") + + diff --git a/src/modeling/model_wrappers/__init__.py b/src/modeling/model_wrappers/__init__.py new file mode 100644 index 0000000..b9f66db --- /dev/null +++ b/src/modeling/model_wrappers/__init__.py @@ -0,0 +1,2 @@ +from src.modeling.model_wrappers.hugging_face import HuggingFaceModelWrapper +from src.modeling.model_wrappers.vllm import VLLMWrapper \ No newline at end of file diff --git a/src/modeling/model_wrappers/base.py b/src/modeling/model_wrappers/base.py new file mode 100644 index 0000000..e895112 --- /dev/null +++ b/src/modeling/model_wrappers/base.py @@ -0,0 +1,54 @@ +import abc +import torch +from typing import Union, Optional +from vllm import LLM +from omegaconf import DictConfig +from transformers import PreTrainedTokenizer, BatchEncoding +from llama_index.core import VectorStoreIndex + +from src.training.data_formatting import prepare_input +from src.types import GenerateData + + + +class ModelWrapper(abc.ABC): + + def __init__( + self, + model: Union[torch.nn.Module, LLM], + tokenizer: PreTrainedTokenizer = None, + **kwargs + ): + self._wrapped_model = model + self._wrapped_tokenizer = tokenizer + + def __getattr__(self, name): + # Delegate attribute access to the wrapped instance + return getattr(self._wrapped_model, name) + + def __setattr__(self, name, value): 
+ if name == '_wrapped_model' or name == '_wrapped_tokenizer': + super().__setattr__(name, value) + else: + setattr(self._wrapped_model, name, value) + + def __delattr__(self, name): + delattr(self._wrapped_model, name) + + @abc.abstractmethod + def generate_completion( + self, + item: GenerateData, + config: DictConfig, + vector_store: Optional[VectorStoreIndex] = None, + ) -> dict: + pass + + def _prepare_input( + self, + item: GenerateData, + config: DictConfig, + tokenizer: PreTrainedTokenizer, + vector_store: Optional[VectorStoreIndex] = None, + ) -> BatchEncoding: + return prepare_input(item, config, tokenizer, vector_store) \ No newline at end of file diff --git a/src/modeling/model_wrappers/hugging_face.py b/src/modeling/model_wrappers/hugging_face.py new file mode 100644 index 0000000..f5f0f32 --- /dev/null +++ b/src/modeling/model_wrappers/hugging_face.py @@ -0,0 +1,53 @@ +import torch +from typing import Optional + +from llama_index.core import VectorStoreIndex +from omegaconf import DictConfig +from transformers import GenerationConfig, PreTrainedTokenizer + +from src.modeling.model_wrappers.base import ModelWrapper +from src.types import GenerateData + +class HuggingFaceModelWrapper(ModelWrapper): + + _generation_config = GenerationConfig( + temperature=0.7, top_k=5, do_sample=True, + ) + + def __init__( + self, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizer, + **kwargs + ): + super().__init__(model, tokenizer) + + def generate_completion( + self, + item: GenerateData, + config: DictConfig, + vector_store: Optional[VectorStoreIndex] = None, + ) -> dict: + + tokenizer = self._wrapped_tokenizer + + inputs = self._prepare_input(item, config, tokenizer, vector_store).to(self.device) + + outputs = self.generate( + **inputs, max_new_tokens=config.inference.max_gen_length, + return_dict_in_generate=True, output_scores=True, + generation_config=self._generation_config, + ) + + out_tokens = outputs.sequences[0][inputs.input_ids.shape[1]:] + output_text = tokenizer.decode(out_tokens, skip_special_tokens=True) + logits = torch.stack(outputs.scores[-len(out_tokens):]).squeeze(1) + probs = logits.softmax(dim=1).gather(1, out_tokens.unsqueeze(1)) + + perplexity = torch.exp(-torch.sum(torch.log(probs)) / len(out_tokens)).item() + + return { + 'outputs': outputs, + 'output_text': output_text, + 'perplexity': perplexity, + } \ No newline at end of file diff --git a/src/modeling/model_wrappers/vllm.py b/src/modeling/model_wrappers/vllm.py new file mode 100644 index 0000000..235ee8d --- /dev/null +++ b/src/modeling/model_wrappers/vllm.py @@ -0,0 +1,32 @@ +from typing import NamedTuple, Tuple, Union, Any, Optional + +from vllm import LLM, SamplingParams +from llama_index.core import VectorStoreIndex +from omegaconf import DictConfig + +from src.modeling.model_wrappers.base import ModelWrapper +from src.types import GenerateData + +class VLLMWrapper(ModelWrapper): + + _sampling_params = SamplingParams(temperature=0.75, top_k=5) + + def generate_completion( + self, + item: GenerateData, + config: DictConfig, + vector_store: Optional[VectorStoreIndex] = None, + ) -> dict: + + input_ids = self._prepare_input(item, config, self.get_tokenizer(), vector_store).input_ids.tolist() + + outputs = self.generate(prompt_token_ids=input_ids, sampling_params=self._sampling_params) + + output_text = outputs[0].outputs[0].text + + perplexity = 0 #TODO: calculate perplexity later + return { + 'outputs': outputs, + 'output_text': output_text, + 'perplexity': perplexity, + } \ No newline at end of file 
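The wrapper classes above give the vLLM and HuggingFace backends a single generate_completion interface, selected by the new model.inference_model_type key in src/conf/config.yaml. A minimal sketch of how a request reaches a wrapper once the server has initialized ModelProvider (as src/main.py does); the username, the example contexts, and the way cfg is loaded here are assumptions for illustration, not values taken from this diff:

from omegaconf import OmegaConf

from src.modeling.model_hub import get_inference_model
from src.types import GenerateData

# Assumption: ModelProvider.get_instance(cfg) has already been called at startup,
# so get_inference_model() can hand back a per-user wrapper.
cfg = OmegaConf.load('src/conf/config.yaml')

item = GenerateData(
    file_path='example.py',
    prior_context='def add(a, b):\n    ',
    proceeding_context='\n',
    max_decode_length=64,
)

wrapper = get_inference_model('some_user')       # VLLMWrapper or HuggingFaceModelWrapper
result = wrapper.generate_completion(item, cfg)  # no vector_store passed, so no RAG context
print(result['output_text'], result['perplexity'])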
diff --git a/src/modeling/tokens.py b/src/modeling/tokens.py new file mode 100644 index 0000000..ad306b5 --- /dev/null +++ b/src/modeling/tokens.py @@ -0,0 +1,4 @@ +EOT_TOKEN = '<|EOT|>' +FIM_BEGIN_TOKEN = '<|fim▁begin|>' # 32016 +FIM_HOLE_TOKEN = '<|fim▁hole|>' # 32015 +FIM_END_TOKEN = '<|fim▁end|>' # 32017 \ No newline at end of file diff --git a/src/routers/fine_tuner.py b/src/routers/fine_tuner.py index 9bbbe70..0ac371f 100644 --- a/src/routers/fine_tuner.py +++ b/src/routers/fine_tuner.py @@ -16,13 +16,14 @@ from src.auto_generation.problem_generator import LibraryProblemGenerator from src.crawler.docs_scraper import get_doc_data from src.documents import read_from_bytes, retrieve_from_cache, save_to_cache -from src.modeling import get_model, get_tokenizer +from src.modeling.model_hub import get_model, get_tokenizer from src.rag import VectorStoreProvider from src.training import finetune from src.training.interactive.train_multi_step_sft import ( generate_solutions, train_multi_step_sft_with_verification, ) +from src.types import ProjectFinetuneData from src.users import SessionTracker, validate_user_session @@ -34,14 +35,6 @@ global_finetune_lock = threading.Lock() -class ProjectFinetuneData(BaseModel): - project_dict: Optional[Dict[str, str]] = None - language: Optional[str] = None - libraries: Optional[List[str]] = None - urls: Optional[List[str]] = None - documents: Optional[List[Tuple[str, bytes]]] = None - - # @router.get('/learn/project') # async def finetune_project_form( # username: Annotated[str, Depends(validate_user_session)], diff --git a/src/routers/generator.py b/src/routers/generator.py index 08963a2..97a9602 100644 --- a/src/routers/generator.py +++ b/src/routers/generator.py @@ -1,7 +1,7 @@ from functools import partial import logging import math -from typing import Annotated, List, Optional, Union +from typing import Annotated, List, Dict, Optional, Union from concurrent.futures import CancelledError from fastapi import APIRouter, Depends @@ -13,38 +13,23 @@ from tqdm import tqdm from src import config_handler, modeling -from src.modeling import get_model, get_tokenizer +from src.modeling.model_hub import get_inference_model from src.rag import get_vector_store, retrieve_context from src.training.data_formatting import format_inference_input, format_rag_query +from src.types import GenerateData from src.users import validate_user_session +logger = logging.getLogger(__name__) +router = APIRouter() + + GENERATION_KWARGS = dict( temperature=0.7, do_sample=True, top_k=5, # num_beams=3, early_stopping=True, # do_sample=True, temperature=1.1, top_k=3, ) -logger = logging.getLogger(__name__) -router = APIRouter() - -class GenerateData(BaseModel): - """Data class for the /generate endpoint. - - Args: - file_path (str): The path to the file to generate from. - prior_context (str): The prior context to generate from. - proceeding_context (Optional[str], optional): - The proceeding context to generate from. Defaults to None. - If provided, FIM is used, otherwise next token prediction is used. - max_decode_length (int, optional): - The maximum length of the generated sequence. Defaults to 128. 
- """ - file_path: Optional[str] = None - prior_context: str - proceeding_context: Optional[str] = None - max_decode_length: int = 256 - @router.post('/generate') def generate( @@ -62,105 +47,23 @@ def generate( return result -def prepare_input( - item: GenerateData, - config: DictConfig, - device: Union[str, torch.device], - tokenizer: PreTrainedTokenizer, - vector_store: Optional[VectorStoreIndex] = None, - ) -> BatchEncoding: - """Prepare and format input for the model. - - Depending on the input, the example will either be formatted as - a next token prediction problem or a fill in the middle problem. - RAG context will also be added if enabled in the config. - - Args: - item (GenerateData): The input data from a REST request. - config (DictConfig): The config global config. - model (Model): The model. - tokenizer (Tokenizer): The tokenizer. - vector_store (VectorStoreIndex, optional): The vector store for RAG. Defaults to None. - - Returns: - BatchEncoding: The formatted input with input_ids and an attention_mask. - """ - use_rag = config.rag.get('enabled', False) and vector_store is not None - if use_rag: - # Create a user context string from the prior and proceeding context - user_context_str = format_rag_query( - prior_context = item.prior_context, - proceeding_context = item.proceeding_context, - max_length = config.rag.max_embed_context_length, - ) - - # Then retrieve the relevant context strings from the vector store - retrieved = retrieve_context(user_context_str, vector_store, top_k=config.rag.n_chunks_per_generation) - else: - retrieved = None - - context_length = config.model.context_length - - inputs = format_inference_input( - preceeding_text = item.prior_context, - tokenizer = tokenizer, - config = config, - proceeding_text = item.proceeding_context, - file_path = item.file_path, - max_decode_length = config.inference.max_gen_length, - context_length = context_length, - retrieved_context = retrieved, - ) - - return inputs.to(device) - - def generate_task(item: GenerateData, config: DictConfig, username: str): logging.info(f"Generating text for user: {username}.") - model = get_model(username) - tokenizer = get_tokenizer(username) + model = get_inference_model(username) + # tokenizer = get_tokenizer(username) vector_store = get_vector_store(username) - results = generate_completion(item, config, model, tokenizer, vector_store) + results = model.generate_completion(item, config, vector_store) output_text = results['output_text'] - score = results.get('perplexity').item() + score = results.get('perplexity') if not math.isfinite(score): score = None return {'generated_text': output_text, 'score': score} -def generate_completion( - item: GenerateData, - config: DictConfig, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizer, - vector_store: Optional[VectorStoreIndex] = None, - ) -> dict: - inputs = prepare_input(item, config, model.device, tokenizer, vector_store) - - outputs = model.generate( - **inputs, max_new_tokens=config.inference.max_gen_length, - return_dict_in_generate=True, output_scores=True, - **GENERATION_KWARGS, - ) - - out_tokens = outputs.sequences[0][inputs.input_ids.shape[1]:] - output_text = tokenizer.decode(out_tokens, skip_special_tokens=True) - logits = torch.stack(outputs.scores[-len(out_tokens):]).squeeze(1) - probs = logits.softmax(dim=1).gather(1, out_tokens.unsqueeze(1)) - - perplexity = torch.exp(-torch.sum(torch.log(probs)) / len(out_tokens)) - - return { - 'outputs': outputs, - 'output_text': output_text, - 'perplexity': perplexity, - } - - 
def batch_generate_completions( items: List[GenerateData], config: DictConfig, diff --git a/src/session.py b/src/session.py new file mode 100644 index 0000000..ea2b349 --- /dev/null +++ b/src/session.py @@ -0,0 +1,19 @@ +from concurrent.futures import CancelledError +import threading +from src.users import SessionTracker + +GLOBAL_MAIN_THREAD_ID = None + +def thread_hook(username, *args): + current_thread_id = threading.get_ident() + + session_tracker = SessionTracker.get_instance() + activity_threads = session_tracker.get_user_activity_threads(username) + + if current_thread_id not in activity_threads: + raise CancelledError("Cancelled by new request") + + +def set_main_thread_id(): + global GLOBAL_MAIN_THREAD_ID + GLOBAL_MAIN_THREAD_ID = threading.get_ident() diff --git a/src/training/data_formatting.py b/src/training/data_formatting.py index 630f1a5..691f624 100644 --- a/src/training/data_formatting.py +++ b/src/training/data_formatting.py @@ -7,13 +7,16 @@ from typing import Optional, Sequence from fastapi import APIRouter +from llama_index.core import VectorStoreIndex from omegaconf import DictConfig import torch from transformers.tokenization_utils_base import BatchEncoding from transformers import PreTrainedTokenizer -from src.modeling import FIM_BEGIN_TOKEN, FIM_HOLE_TOKEN, FIM_END_TOKEN +from src.modeling.tokens import FIM_BEGIN_TOKEN, FIM_HOLE_TOKEN, FIM_END_TOKEN +from src.rag import retrieve_context +from src.types import GenerateData logger = logging.getLogger(__name__) router = APIRouter() @@ -31,6 +34,57 @@ RAG_QUERY_FIM_HOLE_TOKEN = '' +def prepare_input( + item: GenerateData, + config: DictConfig, + tokenizer: PreTrainedTokenizer, + vector_store: Optional[VectorStoreIndex] = None, + ) -> BatchEncoding: + """Prepare and format input for the model. + + Depending on the input, the example will either be formatted as + a next token prediction problem or a fill in the middle problem. + RAG context will also be added if enabled in the config. + + Args: + item (GenerateData): The input data from a REST request. + config (DictConfig): The config global config. + tokenizer (Tokenizer): The tokenizer. + vector_store (VectorStoreIndex, optional): The vector store for RAG. Defaults to None. + + Returns: + Dict: The formatted input with input_ids and an attention_mask. 
+ """ + use_rag = config.rag.get('enabled', False) and vector_store is not None + if use_rag: + # Create a user context string from the prior and proceeding context + user_context_str = format_rag_query( + prior_context = item.prior_context, + proceeding_context = item.proceeding_context, + max_length = config.rag.max_embed_context_length, + ) + + # Then retrieve the relevant context strings from the vector store + retrieved = retrieve_context(user_context_str, vector_store, top_k=config.rag.n_chunks_per_generation) + else: + retrieved = None + + context_length = config.model.context_length + + inputs = format_inference_input( + preceeding_text = item.prior_context, + tokenizer = tokenizer, + config = config, + proceeding_text = item.proceeding_context, + file_path = item.file_path, + max_decode_length = config.inference.max_gen_length, + context_length = context_length, + retrieved_context = retrieved, + ) + + return inputs + + @dataclass class FIMTrainSample: """A single FIM sample for training.""" @@ -283,6 +337,13 @@ def format_ntp_inference_input( context_tokens, ]).unsqueeze(0) + # attention_mask = input_ids.ne(tokenizer.pad_token_id).long() + + # return dict( + # input_ids = input_ids, + # attention_mask = attention_mask + # ) + inputs = BatchEncoding({ 'input_ids': input_ids, 'attention_mask': input_ids.ne(tokenizer.pad_token_id).long(), @@ -411,6 +472,13 @@ def format_fim_inference_input( torch.tensor([suffix_tok_id]), ]).unsqueeze(0) + # attention_mask = input_ids.ne(tokenizer.pad_token_id).long() + + # return dict( + # input_ids = input_ids, + # attention_mask = attention_mask + # ) + inputs = BatchEncoding({ 'input_ids': input_ids, 'attention_mask': input_ids.ne(tokenizer.pad_token_id).long(), diff --git a/src/training/trainer.py b/src/training/trainer.py index a12e8da..6fdcda1 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -24,14 +24,14 @@ EvalPrediction, EvalLoopOutput, has_length, - is_torch_tpu_available, + # is_torch_tpu_available, ) from transformers.training_args import TrainingArguments from transformers.utils import is_accelerate_available, logging -if is_torch_tpu_available(check_device=False): - import torch_xla.core.xla_model as xm +# if is_torch_tpu_available(check_device=False): +# import torch_xla.core.xla_model as xm if is_accelerate_available(): from accelerate import __version__ as accelerate_version @@ -178,8 +178,8 @@ def evaluation_loop( inputs_decode = self._prepare_input( inputs[main_input_name]) if args.include_inputs_for_metrics else None - if is_torch_tpu_available(): - xm.mark_step() + # if is_torch_tpu_available(): + # xm.mark_step() # Update containers on host if loss is not None: diff --git a/src/types.py b/src/types.py new file mode 100644 index 0000000..14bdc5b --- /dev/null +++ b/src/types.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from typing import Dict, List, Optional, Tuple + +class ProjectFinetuneData(BaseModel): + project_dict: Optional[Dict[str, str]] = None + language: Optional[str] = None + libraries: Optional[List[str]] = None + urls: Optional[List[str]] = None + documents: Optional[List[Tuple[str, bytes]]] = None + +class GenerateData(BaseModel): + file_path: Optional[str] = None + prior_context: str + proceeding_context: Optional[str] = None + max_decode_length: int = 256 \ No newline at end of file