diff --git a/README.md b/README.md
index b5d5920..9145670 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ table th:nth-of-type(4) {
| Use case | Quality-optimized | Balanced | Speed-optimized |
|----------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
-| Text generation following instructions | [MPT-30B-Instruct](llm-models/mpt/mpt-30b/)<br>[Llama-2-70b-chat-hf](llm-models/llamav2/llamav2-70b) | [MPT-7B-Instruct](llm-models/mpt/mpt-7b)<br>[MPT-7B-8k-Instruct](llm-models/mpt/mpt-7b-8k)<br>[Llama-2-7b-chat-hf](llm-models/llamav2/llamav2-7b)<br>[Llama-2-13b-chat-hf](llm-models/llamav2/llamav2-13b) | |
+| Text generation following instructions | [MPT-30B-Instruct](llm-models/mpt/mpt-30b/)<br>[Llama-2-70b-chat-hf](llm-models/llamav2/llamav2-70b) | [mistral-7b](llm-models/mistral/mistral-7b)<br>[MPT-7B-Instruct](llm-models/mpt/mpt-7b)<br>[MPT-7B-8k-Instruct](llm-models/mpt/mpt-7b-8k)<br>[Llama-2-7b-chat-hf](llm-models/llamav2/llamav2-7b)<br>[Llama-2-13b-chat-hf](llm-models/llamav2/llamav2-13b) | |
| Text embeddings (English only) | | [bge-large-en (0.3B)](llm-models/embedding/bge/bge-large)<br>[e5-large-v2 (0.3B)](llm-models/embedding/e5-v2)<br>[instructor-xl (1.3B)](llm-models/embedding/instructor-xl)* | [bge-base-en (0.1B)](llm-models/embedding/bge)<br>[e5-base-v2 (0.1B)](llm-models/embedding/e5-v2) |
| Transcription (speech to text) | | [whisper-large-v2](llm-models/transcription/whisper) (1.6B)<br>[whisper-medium](llm-models/transcription/whisper) (0.8B) | |
| Image generation | | [stable-diffusion-xl](llm-models/image_generation/stable_diffusion) | |
diff --git a/llm-models/config/a10_config_zero2.json b/llm-models/config/a10_config_zero2.json
new file mode 100644
index 0000000..250e51f
--- /dev/null
+++ b/llm-models/config/a10_config_zero2.json
@@ -0,0 +1,47 @@
+{
+ "fp16": {
+ "enabled": false
+ },
+ "bf16": {
+ "enabled": true
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 5e7,
+ "reduce_bucket_size": "auto",
+ "reduce_scatter": true,
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ }
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 50,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
diff --git a/llm-models/llamav2/llamav2-70b/07_fine_tune_lora.py b/llm-models/llamav2/llamav2-70b/07_fine_tune_lora.py
new file mode 100644
index 0000000..4cd7567
--- /dev/null
+++ b/llm-models/llamav2/llamav2-70b/07_fine_tune_lora.py
@@ -0,0 +1,74 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC
+# MAGIC # Fine tune llama-2-70b with LoRA and deepspeed on a single node
+# MAGIC
+# MAGIC [Llama 2](https://huggingface.co/meta-llama) is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. It is trained on 2T tokens and supports a context window of up to 4k tokens. [Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf) is the 70B pretrained model, converted to the Hugging Face Transformers format.
+# MAGIC
+# MAGIC This notebook fine-tunes the [llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf) model on the [dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) dataset.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `Standard_NC48ads_A100_v4` on Azure with 2 A100-80GB GPUs, `p4d.24xlarge` on AWS with 8 A100-40GB GPUs
+# MAGIC
+# MAGIC Requirements:
+# MAGIC - To get access to the model on Hugging Face, please visit the [Meta website](https://ai.meta.com/resources/models-and-libraries/llama-downloads) and accept the license terms and acceptable use policy before submitting the access request form. Requests are typically processed in 1-2 days.
+# MAGIC
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Install the missing libraries
+
+# COMMAND ----------
+
+# MAGIC %pip install deepspeed==0.9.5 xformers
+# MAGIC %pip install git+https://github.com/huggingface/peft.git
+# MAGIC %pip install bitsandbytes==0.40.1 einops==0.6.1 trl==0.4.7
+# MAGIC %pip install -U torch==2.0.1 accelerate==0.21.0 transformers==4.31.0
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+import os
+os.environ["HF_HOME"] = "/local_disk0/hf"
+os.environ["HF_DATASETS_CACHE"] = "/local_disk0/hf"
+os.environ["TRANSFORMERS_CACHE"] = "/local_disk0/hf"
+
+# COMMAND ----------
+
+from huggingface_hub import notebook_login
+
+# Login to Huggingface to get access to the model
+notebook_login()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Fine tune the model with `deepspeed`
+# MAGIC
+# MAGIC The fine-tuning logic is written in `scripts/fine_tune_lora.py`. The dataset used for fine-tuning is the [dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) dataset. Set `--num_gpus` in the cell below to the number of GPUs on your instance.
+# MAGIC
+# MAGIC
+
+# COMMAND ----------
+
+# MAGIC %sh
+# MAGIC deepspeed \
+# MAGIC --num_gpus 2 \
+# MAGIC scripts/fine_tune_lora.py \
+# MAGIC --output_dir="/local_disk0/output"
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC The final model checkpoint (LoRA adapter and tokenizer) is saved at `/local_disk0/final_model`. The cells below list the saved files and sketch how to load the adapter for inference.
+
+# COMMAND ----------
+
+# MAGIC %sh
+# MAGIC ls /local_disk0/final_model
+
+# COMMAND ----------
+
+
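+# MAGIC %md
+# MAGIC The cell below is a minimal, optional sketch of loading the saved LoRA adapter back onto the base model for a quick sanity-check generation. It assumes the adapter and tokenizer were written to `/local_disk0/final_model` by the script above, that the Hugging Face login above grants access to `meta-llama/Llama-2-70b-hf`, and that loading the base model in 4-bit fits on your instance; adjust paths and generation settings as needed.
+
+# COMMAND ----------
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from peft import PeftModel
+
+# Assumed locations: base weights on the Hugging Face Hub, adapter saved by scripts/fine_tune_lora.py
+base_model_name = "meta-llama/Llama-2-70b-hf"
+adapter_path = "/local_disk0/final_model"
+
+tokenizer = AutoTokenizer.from_pretrained(adapter_path)
+
+# Load the base model in 4-bit to keep memory usage manageable on a single node
+base_model = AutoModelForCausalLM.from_pretrained(
+    base_model_name,
+    quantization_config=BitsAndBytesConfig(
+        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
+    ),
+    device_map="auto",
+)
+
+# Attach the trained LoRA adapter on top of the frozen base model
+model = PeftModel.from_pretrained(base_model, adapter_path)
+model.eval()
+
+inputs = tokenizer("What is machine learning?", return_tensors="pt").to(model.device)
+with torch.no_grad():
+    output = model.generate(**inputs, max_new_tokens=100)
+print(tokenizer.decode(output[0], skip_special_tokens=True))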
diff --git a/llm-models/llamav2/llamav2-70b/scripts/fine_tune_lora.py b/llm-models/llamav2/llamav2-70b/scripts/fine_tune_lora.py
new file mode 100644
index 0000000..4304476
--- /dev/null
+++ b/llm-models/llamav2/llamav2-70b/scripts/fine_tune_lora.py
@@ -0,0 +1,411 @@
+import json
+import logging
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional, Union
+
+import bitsandbytes as bnb
+import torch
+from datasets import Dataset, load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft.tuners.lora import LoraLayer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    HfArgumentParser,
+    IntervalStrategy,
+    PreTrainedTokenizer,
+    SchedulerType,
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+
+logger = logging.getLogger(__name__)
+
+ROOT_PATH = Path(__file__).parent.parent
+MODEL_PATH = 'meta-llama/Llama-2-70b-hf'
+TOKENIZER_PATH = 'meta-llama/Llama-2-70b-hf'
+DEFAULT_TRAINING_DATASET = "mosaicml/dolly_hhrlhf"
+CONFIG_PATH = "../../config/a10_config_zero2.json"
+LOCAL_OUTPUT_DIR = "/dbfs/llama-2-fine-tune/output"
+TRANSFORMER_CACHE = "/local_disk0/hf"
+DEFAULT_PAD_TOKEN = "[PAD]"
+IGNORE_INDEX = -100
+DEFAULT_SEED = 68
+
+
+@dataclass
+class HFTrainingArguments:
+ local_rank: Optional[str] = field(default="-1")
+ dataset: Optional[str] = field(default=DEFAULT_TRAINING_DATASET)
+ cache_dir: Optional[str] = field(default=TRANSFORMER_CACHE)
+ use_auth_token: Optional[str] = field(default=None)
+ model: Optional[str] = field(default=MODEL_PATH)
+ tokenizer: Optional[str] = field(default=TOKENIZER_PATH)
+ max_seq_len: Optional[int] = field(default=256)
+
+ final_model_output_path: Optional[str] = field(default="/local_disk0/final_model")
+
+ deepspeed_config: Optional[str] = field(default=CONFIG_PATH)
+
+ adam8bit: bool = field(
+ default=False,
+ metadata={"help": "Use 8-bit adam."}
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=4,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_r: int = field(
+ default=64,
+ metadata={"help": "Lora R dimension."}
+ )
+ lora_alpha: float = field(
+ default=16,
+ metadata={"help": " Lora alpha."}
+ )
+ lora_dropout: float = field(
+ default=0.0,
+ metadata={"help":"Lora dropout."}
+ )
+ max_memory_MB: int = field(
+ default=80000,
+ metadata={"help": "Free memory per gpu."}
+ )
+ output_dir: Optional[str] = field(default="/local_disk0/output")
+ per_device_train_batch_size: Optional[int] = field(default=1)
+ per_device_eval_batch_size: Optional[int] = field(default=1)
+ gradient_checkpointing: Optional[bool] = field(default=True)
+ gradient_accumulation_steps: Optional[int] = field(default=1)
+ learning_rate: Optional[float] = field(default=1e-6)
+ optim: Optional[str] = field(default="adamw_hf")
+ num_train_epochs: Optional[int] = field(default=1)
+ max_steps: Optional[int] = field(default=-1)
+ adam_beta1: float = field(default=0.9)
+ adam_beta2: float = field(default=0.999)
+ adam_epsilon: float = field(default=1e-8)
+ lr_scheduler_type: Union[SchedulerType, str] = field(
+ default="cosine",
+ )
+ warmup_steps: int = field(default=0)
+ weight_decay: Optional[float] = field(default=1)
+ logging_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ evaluation_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ save_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ fp16: Optional[bool] = field(default=False)
+ bf16: Optional[bool] = field(default=True)
+ save_steps: Optional[int] = field(default=100)
+ logging_steps: Optional[int] = field(default=10)
+
+
+def find_all_linear_names(args, model):
+ cls = bnb.nn.Linear4bit if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
+ lora_module_names = set()
+ for name, module in model.named_modules():
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def load_training_dataset(
+ tokenizer,
+ path_or_dataset: str = DEFAULT_TRAINING_DATASET,
+ max_seq_len: int = 256,
+ seed: int = DEFAULT_SEED,
+) -> Dataset:
+ logger.info(f"Loading dataset from {path_or_dataset}")
+ dataset = load_dataset(path_or_dataset)
+ logger.info(f"Training: found {dataset['train'].num_rows} rows")
+ logger.info(f"Eval: found {dataset['test'].num_rows} rows")
+
+ # Reformat input data, add prompt template if needed
+ def _reformat_data(row):
+ return row["prompt"] + row["response"]
+
+ # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
+ def tokenize(element):
+ input_batch = []
+ attention_masks = []
+
+ outputs = tokenizer(
+ _reformat_data(element),
+ truncation=True,
+ padding=True,
+ max_length=max_seq_len,
+ return_overflowing_tokens=False,
+ return_length=True,
+ )
+
+ for length, input_ids, attention_mask in zip(
+ outputs["length"], outputs["input_ids"], outputs["attention_mask"]
+ ):
+ if length == max_seq_len:
+ input_batch.append(input_ids)
+ attention_masks.append(attention_mask)
+
+ return {"input_ids": input_batch, "attention_mask": attention_masks}
+
+ train_tokenized_dataset = dataset["train"].map(
+ tokenize, batched=True, remove_columns=dataset["train"].column_names
+ )
+ eval_tokenized_dataset = dataset["test"].map(
+ tokenize, batched=True, remove_columns=dataset["test"].column_names
+ )
+
+ return train_tokenized_dataset, eval_tokenized_dataset
+
+def get_model(args) -> AutoModelForCausalLM:
+ logger.info(f"Loading model: {args.model}")
+ if torch.cuda.is_available():
+ n_gpus = torch.cuda.device_count()
+ max_memory = f'{args.max_memory_MB}MB'
+ max_memory = {i: max_memory for i in range(n_gpus)}
+ device_map = "auto"
+
+ # if we are in a distributed setting, we need to set the device map and max memory per device
+ if os.environ.get('LOCAL_RANK') is not None:
+ local_rank = int(os.environ.get('LOCAL_RANK', '0'))
+ device_map = {'': local_rank}
+ max_memory = {'': max_memory[local_rank]}
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.model,
+ cache_dir=args.cache_dir,
+ load_in_4bit=args.bits == 4,
+ load_in_8bit=args.bits == 8,
+ device_map=device_map,
+ max_memory=max_memory,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=args.bits == 4,
+ load_in_8bit=args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ bnb_4bit_use_double_quant=args.double_quant,
+ bnb_4bit_quant_type=args.quant_type,
+ ),
+ torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)),
+ use_auth_token=args.use_auth_token
+ )
+
+ setattr(model, 'model_parallel', True)
+ setattr(model, 'is_parallelizable', True)
+ model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
+ model.config.use_cache = False
+
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+
+ logger.info("Loading adapters.")
+ modules = find_all_linear_names(args, model)
+ config = LoraConfig(
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ target_modules=modules,
+ lora_dropout=args.lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, config)
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+ return model
+
+
+def get_tokenizer(args, model) -> PreTrainedTokenizer:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer,
+ cache_dir=args.cache_dir,
+ padding_side="right",
+ use_fast=False, # Fast tokenizer giving issues.
+ tokenizer_type='llama',
+ use_auth_token=args.use_auth_token,
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # LLaMA tokenizer may not have correct special tokens set.
+ # Check and add them if missing to prevent them from being parsed into different tokens.
+ # Note that these are present in the vocabulary.
+    # Note also that `model.config.pad_token_id` is 0, which corresponds to the `<unk>` token.
+ logger.info('Adding special tokens.')
+ tokenizer.add_special_tokens({
+ "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
+ "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
+ "unk_token": tokenizer.convert_ids_to_tokens(
+ model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
+ ),
+ })
+ return tokenizer
+
+def print_trainable_parameters(args, model):
+ """
+ Prints the number of trainable parameters in the model.
+ """
+ trainable_params = 0
+ all_param = 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
+ if param.requires_grad:
+ trainable_params += param.numel()
+ if args.bits == 4: trainable_params /= 2
+ print(
+ f"trainable params: {trainable_params} || "
+ f"all params: {all_param} || "
+ f"trainable: {100 * trainable_params / all_param}"
+ )
+
+def train(args: HFTrainingArguments):
+ set_seed(DEFAULT_SEED)
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.deepspeed_config:
+ with open(args.deepspeed_config) as json_data:
+ deepspeed_config_dict = json.load(json_data)
+ else:
+ deepspeed_config_dict = None
+
+ training_args = TrainingArguments(
+ local_rank=args.local_rank,
+ output_dir=args.output_dir,
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
+ gradient_checkpointing=args.gradient_checkpointing,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ learning_rate=args.learning_rate,
+ optim=args.optim,
+ num_train_epochs=args.num_train_epochs,
+ max_steps=args.max_steps,
+ adam_beta1=args.adam_beta1,
+ adam_beta2=args.adam_beta2,
+ adam_epsilon=args.adam_epsilon,
+ lr_scheduler_type=args.lr_scheduler_type,
+ warmup_steps=args.warmup_steps,
+ weight_decay=args.weight_decay,
+ logging_strategy=args.logging_strategy,
+ evaluation_strategy=args.evaluation_strategy,
+ save_strategy=args.save_strategy,
+ fp16=args.fp16,
+ bf16=args.bf16,
+ deepspeed=deepspeed_config_dict,
+ save_steps=args.save_steps,
+ logging_steps=args.logging_steps,
+ push_to_hub=False,
+ disable_tqdm=True,
+ report_to=[],
+ # group_by_length=True,
+ ddp_find_unused_parameters=False,
+ # fsdp=["full_shard", "offload"],
+ )
+ model = get_model(args)
+ model.print_trainable_parameters()
+ tokenizer = get_tokenizer(args, model)
+ train_dataset, eval_dataset = load_training_dataset(
+ tokenizer, path_or_dataset=args.dataset, max_seq_len=args.max_seq_len
+ )
+
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+
+ logger.info("Training the model")
+ trainer.train()
+
+ logger.info(f"Saving Model to {args.final_model_output_path}")
+ trainer.save_model(output_dir=args.final_model_output_path)
+ tokenizer.save_pretrained(args.final_model_output_path)
+
+ logger.info("Training finished.")
+
+
+def main():
+ parser = HfArgumentParser(HFTrainingArguments)
+
+ parsed = parser.parse_args_into_dataclasses()
+ args: HFTrainingArguments = parsed[0]
+
+ train(args)
+
+
+if __name__ == "__main__":
+ os.environ["HF_HOME"] = "/local_disk0/hf"
+ os.environ["HF_DATASETS_CACHE"] = "/local_disk0/hf"
+ os.environ["TRANSFORMERS_CACHE"] = "/local_disk0/hf"
+
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
+ level=logging.INFO,
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ try:
+ main()
+ except Exception:
+ logger.exception("main failed")
+ raise
diff --git a/llm-models/mistral/README.md b/llm-models/mistral/README.md
new file mode 100644
index 0000000..3789684
--- /dev/null
+++ b/llm-models/mistral/README.md
@@ -0,0 +1,28 @@
+
+
+
+# Example notebooks for the mistral models on Databricks
+
+[Mistral 7B](https://huggingface.co/mistralai) is a 7.3B parameter model that:
+
+- Outperforms Llama 2 13B on all benchmarks
+- Outperforms Llama 1 34B on many benchmarks
+- Approaches CodeLlama 7B performance on code, while remaining good at English tasks
+- Uses Grouped-query attention (GQA) for faster inference
+- Uses Sliding Window Attention (SWA) to handle longer sequences at smaller cost
+
+Mistral 7B is released under the Apache 2.0 license and can be used without restrictions.
\ No newline at end of file
diff --git a/llm-models/mistral/mistral-7b/01_b_load_inference_vllm.py b/llm-models/mistral/mistral-7b/01_b_load_inference_vllm.py
new file mode 100644
index 0000000..52cdeac
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/01_b_load_inference_vllm.py
@@ -0,0 +1,231 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Mistral-7B-Instruct Inference with vllm on Databricks
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, trained on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC [vllm](https://github.com/vllm-project/vllm/tree/main) is an open-source library that makes LLM inference fast with various optimizations.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - There could be CUDA incompatibility issues when installing and using vllm on 13.x GPU ML Runtime.
+# MAGIC - Instance: `g5.4xlarge` on AWS
+# MAGIC
+# MAGIC GPU instances with at least 16GB of GPU memory are sufficient for inference on a single input (batch inference requires slightly more memory). On Azure, it is possible to use `Standard_NC6s_v3` or `Standard_NC4as_T4_v3`.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Speed up inference with vllm
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Install vllm
+
+# COMMAND ----------
+
+# MAGIC %pip install vllm transformers==4.34.0 accelerate==0.20.3
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Load model using vllm
+
+# COMMAND ----------
+
+from vllm import LLM
+llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Inference
+
+# COMMAND ----------
+
+# To leverage instruction fine-tuning, each prompt should be surrounded by [INST] and [/INST] tokens. The very first instruction should begin with a beginning-of-sentence token id; subsequent instructions should not. The assistant generation is ended by the end-of-sentence token id.
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+PROMPT_FOR_GENERATION_FORMAT = """
+[INST]<<SYS>>
+{system_prompt}
+<</SYS>>
+
+{instruction}
+[/INST]
+""".format(
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
+ instruction="{instruction}"
+)
+
+# COMMAND ----------
+
+from vllm import SamplingParams
+
+# Define parameters to generate text
+def gen_text(prompts, use_template=False, **kwargs):
+ if use_template:
+ full_prompts = [
+ PROMPT_FOR_GENERATION_FORMAT.format(instruction=prompt)
+ for prompt in prompts
+ ]
+ else:
+ full_prompts = prompts
+
+ # the default max length is pretty small (16), which would cut the generated output in the middle, so it's necessary to increase the threshold to the complete response
+ if "max_tokens" not in kwargs:
+ kwargs["max_tokens"] = 512
+
+ # configure other text generation arguments, see common configurable args here: https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
+ # kwargs.update(
+ # {
+ # "temperature": 0.8,
+ # }
+ # )
+
+ sampling_params = SamplingParams(**kwargs)
+
+ outputs = llm.generate(full_prompts, sampling_params=sampling_params)
+ texts = [out.outputs[0].text for out in outputs]
+
+ return texts
+
+# COMMAND ----------
+
+# Inference on a single input
+results = gen_text(["What is a large language model?"])
+print(results[0])
+
+# COMMAND ----------
+
+# Use args such as temperature and max_tokens to control text generation
+results = gen_text(["What is a large language model?"], temperature=0.5, max_tokens=100, use_template=True)
+print(results[0])
+
+# COMMAND ----------
+
+# Check the generation quality when the context is long
+from transformers import AutoTokenizer
+long_input = """Provide a concise summary of the below passage.
+
+Hannah Arendt was one of the seminal political thinkers of the twentieth century. The power and originality of her thinking was evident in works such as The Origins of Totalitarianism, The Human Condition, On Revolution and The Life of the Mind. In these works and in numerous essays she grappled with the most crucial political events of her time, trying to grasp their meaning and historical import, and showing how they affected our categories of moral and political judgment. What was required, in her view, was a new framework that could enable us to come to terms with the twin horrors of the twentieth century, Nazism and Stalinism. She provided such framework in her book on totalitarianism, and went on to develop a new set of philosophical categories that could illuminate the human condition and provide a fresh perspective on the nature of political life.
+
+Although some of her works now belong to the classics of the Western tradition of political thought, she has always remained difficult to classify. Her political philosophy cannot be characterized in terms of the traditional categories of conservatism, liberalism, and socialism. Nor can her thinking be assimilated to the recent revival of communitarian political thought, to be found, for example, in the writings of A. MacIntyre, M. Sandel, C. Taylor and M. Walzer. Her name has been invoked by a number of critics of the liberal tradition, on the grounds that she presented a vision of politics that stood in opposition some key liberal principles. There are many strands of Arendt’s thought that could justify such a claim, in particular, her critique of representative democracy, her stress on civic engagement and political deliberation, her separation of morality from politics, and her praise of the revolutionary tradition. However, it would be a mistake to view Arendt as an anti-liberal thinker. Arendt was in fact a stern defender of constitutionalism and the rule of law, an advocate of fundamental human rights (among which she included not only the right to life, liberty, and freedom of expression, but also the right to action and to opinion), and a critic of all forms of political community based on traditional ties and customs, as well as those based on religious, ethnic, or racial identity.
+
+Arendt’s political thought cannot, in this sense, be identified either with the liberal tradition or with the claims advanced by a number of its critics. Arendt did not conceive of politics as a means for the satisfaction of individual preferences, nor as a way to integrate individuals around a shared conception of the good. Her conception of politics is based instead on the idea of active citizenship, that is, on the value and importance of civic engagement and collective deliberation about all matters affecting the political community. If there is a tradition of thought with which Arendt can be identified, it is the classical tradition of civic republicanism originating in Aristotle and embodied in the writings of Machiavelli, Montesquieu, Jefferson, and Tocqueville. According to this tradition politics finds its authentic expression whenever citizens gather together in a public space to deliberate and decide about matters of collective concern. Political activity is valued not because it may lead to agreement or to a shared conception of the good, but because it enables each citizen to exercise his or her powers of agency, to develop the capacities for judgment and to attain by concerted action some measure of political efficacy."""
+
+def get_num_tokens(text):
+ mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", padding_side="left")
+ inputs = mistral_tokenizer(text, return_tensors="pt").input_ids.to("cuda")
+ return inputs.shape[1]
+
+print('number of tokens for input:', get_num_tokens(long_input))
+
+results = gen_text([long_input], use_template=True, max_tokens=150)
+print(results[0])
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Batch Inference
+# MAGIC
+
+# COMMAND ----------
+
+# From databricks-dolly-15k
+inputs = [
+ "Think of some family rules to promote a healthy family relationship",
+ "In the series A Song of Ice and Fire, who is the founder of House Karstark?",
+ "which weighs more, cold or hot water?",
+ "Write a short paragraph about why you should not have both a pet cat and a pet bird.",
+ "Is beauty objective or subjective?",
+ "What is SVM?",
+ "What is the current capital of Japan?",
+ "Name 10 colors",
+ "How should I invest my money?",
+ "What are some ways to improve the value of your home?",
+ "What does fasting mean?",
+ "What is cloud computing in simple terms?",
+ "What is the meaning of life?",
+ "What is Linux?",
+ "Why do people like gardening?",
+ "What makes for a good photograph?"
+]
+
+# COMMAND ----------
+
+results = gen_text(inputs, use_template=True)
+
+for i, output in enumerate(results):
+    print(f"======Output No. {i+1}======")
+    print(output)
+    print("\n")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Measure inference speed
+# MAGIC Text generation speed is often measured in tokens per second (tokens/s), the average number of tokens the model generates per second.
+# MAGIC
+
+# COMMAND ----------
+
+import time
+import logging
+
+
+def get_gen_text_throughput(prompt, use_template=True, **kwargs):
+ """
+ Return tuple ( number of tokens / sec, num tokens, output ) of the generated tokens
+ """
+ if use_template:
+ full_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=prompt)
+ else:
+ full_prompt = prompt
+
+ if "max_tokens" not in kwargs:
+ kwargs["max_tokens"] = 512
+ sampling_params = SamplingParams(**kwargs)
+
+ num_input_tokens = get_num_tokens(full_prompt)
+
+ # measure the time it takes for text generation
+ start = time.time()
+ outputs = llm.generate(full_prompt, sampling_params=sampling_params)
+ duration = time.time() - start
+
+ # get the number of generated tokens
+ token_ids = outputs[0].outputs[0].token_ids
+ n_tokens = len(token_ids)
+
+ # show the generated text in logging
+ text = outputs[0].outputs[0].text
+
+ return (n_tokens / duration, n_tokens, text)
+
+# COMMAND ----------
+
+throughput, n_tokens, text = get_gen_text_throughput("What is ML?", use_template=False)
+
+print(f"{throughput:.2f} tokens/sec, {n_tokens} tokens (not including prompt)")
+
+# COMMAND ----------
+
+# When the context or the generated text is long, it takes longer to generate each token on average
+throughput, n_tokens, text = get_gen_text_throughput(long_input, use_template=True, max_tokens=200)
+
+print(f"{throughput:2f} tokens/sec, {n_tokens} tokens (not including prompt)")
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/01_load_inference.py b/llm-models/mistral/mistral-7b/01_load_inference.py
new file mode 100644
index 0000000..970c607
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/01_load_inference.py
@@ -0,0 +1,242 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Mistral-7B-Instruct Inference on Databricks
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, trained on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.4xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install -U transformers==4.34.0
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Inference
+# MAGIC The example in the model card should also work on Databricks with the same environment.
+
+# COMMAND ----------
+
+# Load model to text generation pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import transformers
+import torch
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-Instruct-v0.1 at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/commits/main
+model = "mistralai/Mistral-7B-Instruct-v0.1"
+revision = "3dc28cf29d2edd31a0a7b8f0b21637059815b4d5"
+
+tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
+pipeline = transformers.pipeline(
+ "text-generation",
+ model=model,
+ tokenizer=tokenizer,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ revision=revision,
+ do_sample=True,
+ return_full_text=False
+)
+
+# Required tokenizer setting for batch inference
+pipeline.tokenizer.pad_token_id = tokenizer.eos_token_id
+
+# COMMAND ----------
+
+# To leverage instruction fine-tuning, each prompt should be surrounded by [INST] and [/INST] tokens. The very first instruction should begin with a beginning-of-sentence token id; subsequent instructions should not. The assistant generation is ended by the end-of-sentence token id.
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+PROMPT_FOR_GENERATION_FORMAT = """
+[INST]<<SYS>>
+{system_prompt}
+<</SYS>>
+
+{instruction}
+[/INST]
+""".format(
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
+ instruction="{instruction}"
+)
+
+# COMMAND ----------
+
+# Define parameters to generate text
+def gen_text(prompts, use_template=False, **kwargs):
+ if use_template:
+ full_prompts = [
+ PROMPT_FOR_GENERATION_FORMAT.format(instruction=prompt)
+ for prompt in prompts
+ ]
+ else:
+ full_prompts = prompts
+
+ if "batch_size" not in kwargs:
+ kwargs["batch_size"] = 1
+
+ # the default max length is pretty small (20), which would cut the generated output in the middle, so it's necessary to increase the threshold to the complete response
+ if "max_new_tokens" not in kwargs:
+ kwargs["max_new_tokens"] = 512
+
+ # configure other text generation arguments, see common configurable args here: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
+ kwargs.update(
+ {
+ "pad_token_id": tokenizer.eos_token_id, # Hugging Face sets pad_token_id to eos_token_id by default; setting here to not see redundant message
+ "eos_token_id": tokenizer.eos_token_id,
+ }
+ )
+
+ outputs = pipeline(full_prompts, **kwargs)
+ outputs = [out[0]["generated_text"] for out in outputs]
+
+ return outputs
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Inference on a single input
+
+# COMMAND ----------
+
+results = gen_text(["[INST]What is a large language model?[/INST]"])
+print(results[0])
+
+# COMMAND ----------
+
+# Use args such as temperature and max_new_tokens to control text generation
+results = gen_text(["What is a large language model?"], temperature=0.5, max_new_tokens=100, use_template=True)
+print(results[0])
+
+# COMMAND ----------
+
+# Check the generation quality when the context is long
+
+long_input = """Provide a concise summary of the below passage.
+
+Hannah Arendt was one of the seminal political thinkers of the twentieth century. The power and originality of her thinking was evident in works such as The Origins of Totalitarianism, The Human Condition, On Revolution and The Life of the Mind. In these works and in numerous essays she grappled with the most crucial political events of her time, trying to grasp their meaning and historical import, and showing how they affected our categories of moral and political judgment. What was required, in her view, was a new framework that could enable us to come to terms with the twin horrors of the twentieth century, Nazism and Stalinism. She provided such framework in her book on totalitarianism, and went on to develop a new set of philosophical categories that could illuminate the human condition and provide a fresh perspective on the nature of political life.
+
+Although some of her works now belong to the classics of the Western tradition of political thought, she has always remained difficult to classify. Her political philosophy cannot be characterized in terms of the traditional categories of conservatism, liberalism, and socialism. Nor can her thinking be assimilated to the recent revival of communitarian political thought, to be found, for example, in the writings of A. MacIntyre, M. Sandel, C. Taylor and M. Walzer. Her name has been invoked by a number of critics of the liberal tradition, on the grounds that she presented a vision of politics that stood in opposition some key liberal principles. There are many strands of Arendt’s thought that could justify such a claim, in particular, her critique of representative democracy, her stress on civic engagement and political deliberation, her separation of morality from politics, and her praise of the revolutionary tradition. However, it would be a mistake to view Arendt as an anti-liberal thinker. Arendt was in fact a stern defender of constitutionalism and the rule of law, an advocate of fundamental human rights (among which she included not only the right to life, liberty, and freedom of expression, but also the right to action and to opinion), and a critic of all forms of political community based on traditional ties and customs, as well as those based on religious, ethnic, or racial identity.
+
+Arendt’s political thought cannot, in this sense, be identified either with the liberal tradition or with the claims advanced by a number of its critics. Arendt did not conceive of politics as a means for the satisfaction of individual preferences, nor as a way to integrate individuals around a shared conception of the good. Her conception of politics is based instead on the idea of active citizenship, that is, on the value and importance of civic engagement and collective deliberation about all matters affecting the political community. If there is a tradition of thought with which Arendt can be identified, it is the classical tradition of civic republicanism originating in Aristotle and embodied in the writings of Machiavelli, Montesquieu, Jefferson, and Tocqueville. According to this tradition politics finds its authentic expression whenever citizens gather together in a public space to deliberate and decide about matters of collective concern. Political activity is valued not because it may lead to agreement or to a shared conception of the good, but because it enables each citizen to exercise his or her powers of agency, to develop the capacities for judgment and to attain by concerted action some measure of political efficacy."""
+
+def get_num_tokens(text):
+ inputs = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
+ return inputs.shape[1]
+
+print('number of tokens for input:', get_num_tokens(long_input))
+
+results = gen_text([long_input], max_new_tokens=150, use_template=True)
+print(results[0])
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Batch inference
+
+# COMMAND ----------
+
+# From databricks-dolly-15k
+inputs = [
+ "Think of some family rules to promote a healthy family relationship",
+ "In the series A Song of Ice and Fire, who is the founder of House Karstark?",
+ "which weighs more, cold or hot water?",
+ "Write a short paragraph about why you should not have both a pet cat and a pet bird.",
+ "Is beauty objective or subjective?",
+ "What is SVM?",
+ "What is the current capital of Japan?",
+ "Name 10 colors",
+ "How should I invest my money?",
+ "What are some ways to improve the value of your home?",
+ "What does fasting mean?",
+ "What is cloud computing in simple terms?",
+ "What is the meaning of life?",
+ "What is Linux?",
+ "Why do people like gardening?",
+ "What makes for a good photograph?"
+]
+
+# COMMAND ----------
+
+# Set batch size
+results = gen_text(inputs, use_template=True, batch_size=8)
+
+for output in results:
+ print(output)
+ print('\n')
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Measure inference speed
+# MAGIC Text generation speed is often measured in tokens per second (tokens/s), the average number of tokens the model generates per second.
+# MAGIC
+
+# COMMAND ----------
+
+import time
+import logging
+
+
+def get_gen_text_throughput(prompt, use_template=True, **kwargs):
+ """
+ Return tuple ( number of tokens / sec, num tokens, output ) of the generated tokens
+ """
+ if use_template:
+ full_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=prompt)
+ else:
+ full_prompt = prompt
+
+ if "max_new_tokens" not in kwargs:
+ kwargs["max_new_tokens"] = 512
+
+ kwargs.update(
+ {
+ "do_sample": True,
+ "pad_token_id": tokenizer.eos_token_id,
+ "eos_token_id": tokenizer.eos_token_id,
+ "return_tensors": True, # make the pipeline return token ids instead of decoded text to get the number of generated tokens
+ }
+ )
+
+ num_input_tokens = get_num_tokens(full_prompt)
+
+ # measure the time it takes for text generation
+ start = time.time()
+ outputs = pipeline(full_prompt, **kwargs)
+ duration = time.time() - start
+
+ # get the number of generated tokens
+ n_tokens = len(outputs[0]["generated_token_ids"])
+
+ # show the generated text in logging
+ result = tokenizer.batch_decode(
+ outputs[0]["generated_token_ids"][num_input_tokens:], skip_special_tokens=True
+ )
+ result = "".join(result)
+
+ return ((n_tokens - num_input_tokens) / duration, (n_tokens - num_input_tokens), result)
+
+# COMMAND ----------
+
+throughput, n_tokens, result = get_gen_text_throughput("What is ML?", use_template=False)
+
+print(f"{throughput} tokens/sec, {n_tokens} tokens (not including prompt)")
+
+# COMMAND ----------
+
+# When the context or the generated text is long, it takes longer to generate each token on average
+throughput, n_tokens, result = get_gen_text_throughput(long_input, max_new_tokens=200, use_template=True)
+
+print(f"{throughput} tokens/sec, {n_tokens} tokens (not including prompt)")
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/02_[chat]_mlflow_logging_inference.py b/llm-models/mistral/mistral-7b/02_[chat]_mlflow_logging_inference.py
new file mode 100644
index 0000000..7be5c56
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/02_[chat]_mlflow_logging_inference.py
@@ -0,0 +1,322 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Manage Mistral-7B-Instruct as a chat completion model with MLflow on Databricks
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, trained on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install -U "mlflow-skinny[databricks]>=2.4.1"
+# MAGIC %pip install -U transformers==4.34.0
+# MAGIC %pip install -U databricks-sdk
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Log the model to MLflow
+
+# COMMAND ----------
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-Instruct-v0.1 at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/commits/main
+model = "mistralai/Mistral-7B-Instruct-v0.1"
+revision = "3dc28cf29d2edd31a0a7b8f0b21637059815b4d5"
+
+from huggingface_hub import snapshot_download
+
+# If the model has been downloaded in previous cells, this will not re-download the large model files; it only fetches any remaining files in the repo
+snapshot_location = snapshot_download(repo_id=model, revision=revision)
+
+# COMMAND ----------
+
+import json
+import mlflow
+import torch
+import transformers
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+
+class ChatStoppingCriteria(StoppingCriteria):
+ def __init__(self, stops=[]):
+ super().__init__()
+ self.stops = [stop.to("cuda") for stop in stops]
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop[2:] == input_ids[0][-(len(stop) - 2) :])).item():
+ return True
+
+ return False
+
+
+# Define PythonModel which is compatible to OpenAI-compatible APIs to log with mlflow.pyfunc.log_model
+class MistralChat(mlflow.pyfunc.PythonModel):
+ def load_context(self, context):
+ """
+ This method initializes the tokenizer and language model
+ using the specified model repository.
+ """
+ # Initialize tokenizer and language model
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ context.artifacts["repository"], padding_side="left"
+ )
+ self.model = transformers.AutoModelForCausalLM.from_pretrained(
+ context.artifacts["repository"],
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ device_map="cuda",
+ pad_token_id=self.tokenizer.eos_token_id,
+ )
+ self.model.eval()
+
+ def _generate_response(
+ self, messages, candidate_count, temperature, max_tokens, stop
+ ):
+ """
+ This method generates prediction for a single input.
+ """
+ encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
+
+ # Encode the input and generate prediction
+ encoded_input = encodeds.to("cuda")
+ generation_config = transformers.GenerationConfig(
+ max_new_tokens=max_tokens,
+ do_sample=True,
+ temperature=temperature,
+ num_return_sequences=candidate_count,
+ )
+ if stop:
+ stop_words_ids = [
+ self.tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
+ for stop_word in stop
+ ]
+ stopping_criteria = StoppingCriteriaList(
+ [ChatStoppingCriteria(stops=stop_words_ids)]
+ )
+ else:
+ stopping_criteria=None
+
+ output = self.model.generate(
+ encoded_input,
+ generation_config=generation_config,
+ stopping_criteria=stopping_criteria,
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.eos_token_id,
+ )
+
+        response_messages = []
+        # Number of prompt tokens (encoded_input has shape [1, prompt_length])
+        prompt_length = encoded_input.shape[1]
+
+        # Decode the predictions to text
+        output_tokens = 0
+        for i in range(len(output)):
+            # Strip the prompt tokens so only the generated response is decoded
+            generated_response = self.tokenizer.decode(
+                output[i][prompt_length:], skip_special_tokens=True
+            )
+
+            gen_length = len(output[i]) - prompt_length
+
+            generated_response = {
+                "message": {
+                    "role": "assistant",
+                    "content": generated_response,
+                },
+                "metadata": {"finish_reason": "length" if gen_length == max_tokens else "stop"},
+            }
+
+            response_messages.append(generated_response)
+            output_tokens += gen_length
+
+ metadata = {
+ "input_tokens": prompt_length,
+ "output_tokens": output_tokens,
+ "total_tokens": prompt_length+output_tokens,
+ "model": "mistralai/Mistral-7B-Instruct-v0.1",
+ "route_type": "llm/v1/chat",
+ }
+
+ return response_messages, metadata
+
+ def predict(self, context, model_input, params=None):
+ """
+ This method generates prediction for the given input.
+ The input parameters are compatible with `llm/v1/chat`
+ https://mlflow.org/docs/latest/gateway/index.html#chat
+ """
+
+ outputs = []
+
+        # The standard parameters for chat routes with type llm/v1/chat can be found at
+ # https://mlflow.org/docs/latest/gateway/index.html#chat
+ messages = model_input["messages"][0]
+ candidate_count = params.get("candidate_count", 1)
+ temperature = params.get("temperature", 1.0)
+ max_tokens = params.get("max_tokens", 100)
+ stop = params.get("stop", [])
+
+ response_messages, metadata = self._generate_response(
+ messages, candidate_count, temperature, max_tokens, stop
+ )
+
+ outputs.append({"candidates": response_messages, "metadata": metadata})
+
+ # {"candidates": [...]} is the required response format for MLflow AI gateway -- see 07_ai_gateway for example
+ return outputs
+
+# COMMAND ----------
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types import DataType, Schema, ColSpec, ParamSchema, ParamSpec
+import pandas as pd
+
+# Define input and output schema
+input_schema = Schema([ColSpec(DataType.string, "messages")])
+output_schema = Schema([ColSpec(DataType.string)])
+param_schema = ParamSchema([
+ ParamSpec("candidate_count", "long", 1),
+ ParamSpec("temperature", "double", 1.0),
+ ParamSpec("max_tokens", "long", 512),
+ ParamSpec("stop", "string", None),
+])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema, params=param_schema)
+
+# Define input example
+input_example = pd.DataFrame(
+ {
+ "messages": [
+ [
+ {"role": "user", "content": "What is ML?"},
+ ]
+ ],
+ }
+)
+
+# Log the model with its details such as artifacts, pip requirements and input example
+with mlflow.start_run() as run:
+ mlflow.pyfunc.log_model(
+ "model",
+ python_model=MistralChat(),
+ artifacts={"repository":
+ snapshot_location},
+ input_example=input_example,
+ pip_requirements=["torch==2.0.1", "transformers==4.34.0", "accelerate==0.21.0", "torchvision==0.15.2"],
+ signature=signature,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Register the model to Unity Catalog
+# MAGIC By default, MLflow registers models in the Databricks workspace model registry. To register models in Unity Catalog instead, we follow the [documentation](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html) and set the registry server as Databricks Unity Catalog.
+# MAGIC
+# MAGIC In order to register a model in Unity Catalog, there are [several requirements](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html#requirements), such as Unity Catalog must be enabled in your workspace.
+# MAGIC
+
+# COMMAND ----------
+
+# Configure MLflow Python client to register model in Unity Catalog
+import mlflow
+mlflow.set_registry_uri("databricks-uc")
+
+# COMMAND ----------
+
+# Register model to Unity Catalog
+
+registered_name = "models.default.mistral_7b_chat_completion" # Note that the UC model name follows the pattern .., corresponding to the catalog, schema, and registered model name
+
+
+result = mlflow.register_model(
+ "runs:/" + run.info.run_id + "/model",
+ registered_name,
+)
+
+# COMMAND ----------
+
+from mlflow import MlflowClient
+
+client = MlflowClient()
+
+# Choose the right model version registered in the above cell.
+client.set_registered_model_alias(
+ name=registered_name, alias="Champion", version=result.version
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Load the model from Unity Catalog
+
+# COMMAND ----------
+
+import mlflow
+import pandas as pd
+
+registered_name = "models.default.mistral_7b_chat_completion"
+loaded_model = mlflow.pyfunc.load_model(f"models:/{registered_name}@Champion")
+
+# Make a prediction using the loaded model
+loaded_model.predict(
+ {
+ "messages": [
+ [
+ {"role": "user", "content": "You are a helpful assistant. Answer the following question.\n What is ML?"},
+ ]
+ ],
+ }
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create Model Serving Endpoint
+# MAGIC Once the model is registered, we can use the API to create a Databricks GPU Model Serving Endpoint that serves the Mistral-7B-Instruct chat completion model.
+# MAGIC
+# MAGIC Note that the below deployment requires GPU model serving. For more information on GPU model serving, contact the Databricks team or sign up [here](https://docs.google.com/forms/d/1-GWIlfjlIaclqDz6BPODI2j1Xg4f4WbFvBXyebBpN-Y/edit).
+
+# COMMAND ----------
+
+# Provide a name to the serving endpoint
+endpoint_name = 'mistral-7b-chat-completion'
+
+# COMMAND ----------
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.serving import EndpointCoreConfigInput
+w = WorkspaceClient()
+
+model_version = result # the returned result of mlflow.register_model
+
+# Specify the type of compute (CPU, GPU_SMALL, GPU_MEDIUM, etc.)
+# Choose `GPU_MEDIUM` on AWS, and `GPU_LARGE` on Azure
+workload_type = "GPU_LARGE"
+
+config = EndpointCoreConfigInput.from_dict({
+ "served_models": [
+ {
+ "name": f'{model_version.name.replace(".", "_")}_{model_version.version}',
+ "model_name": model_version.name,
+ "model_version": model_version.version,
+ "workload_type": workload_type,
+ "workload_size": "Small",
+ "scale_to_zero_enabled": "False",
+ }
+ ]
+})
+w.serving_endpoints.create(name=endpoint_name, config=config)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Once the model serving endpoint is ready, you can query it easily with LangChain (see `04_langchain` for example code) running in the same workspace, or call its REST API directly as sketched below.
+
+# COMMAND ----------
+
+
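+# MAGIC %md
+# MAGIC The cell below is a minimal sketch of calling the endpoint's REST API directly from this workspace once the endpoint reports as ready. It assumes the notebook context can supply the workspace URL and an API token, and that the endpoint accepts the optional `params` field; the payload mirrors the `input_example` logged above, so adjust it to your needs.
+
+# COMMAND ----------
+
+import requests
+
+# Workspace URL and token pulled from the notebook context; a personal access token works as well
+API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
+API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
+
+data = {
+    "dataframe_records": [
+        {"messages": [{"role": "user", "content": "What is ML?"}]}
+    ],
+    # Optional generation parameters; drop this field if the endpoint rejects it
+    "params": {"max_tokens": 128, "temperature": 0.8},
+}
+
+response = requests.post(
+    url=f"{API_ROOT}/serving-endpoints/{endpoint_name}/invocations",
+    headers={"Authorization": f"Bearer {API_TOKEN}"},
+    json=data,
+)
+print(response.json())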
diff --git a/llm-models/mistral/mistral-7b/02_mlflow_logging_inference.py b/llm-models/mistral/mistral-7b/02_mlflow_logging_inference.py
new file mode 100644
index 0000000..378b983
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/02_mlflow_logging_inference.py
@@ -0,0 +1,260 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Manage the Mistral-7B-Instruct model with MLflow on Databricks
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, trained on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install -U "mlflow-skinny[databricks]>=2.4.1"
+# MAGIC %pip install -U transformers==4.34.0
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Log the model to MLflow
+
+# COMMAND ----------
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-Instruct-v0.1 at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/commits/main
+model = "mistralai/Mistral-7B-Instruct-v0.1"
+revision = "3dc28cf29d2edd31a0a7b8f0b21637059815b4d5"
+
+from huggingface_hub import snapshot_download
+
+# If the model has been downloaded in previous cells, this will not re-download the large model files; it only fetches any remaining files in the repo
+snapshot_location = snapshot_download(repo_id=model, revision=revision)
+
+# COMMAND ----------
+
+import mlflow
+import torch
+import transformers
+
+# Define the prompt template to get the expected features and performance for the chat versions. See the reference code on GitHub for details: https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L212
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+# Define PythonModel to log with mlflow.pyfunc.log_model
+
+class Mistral7B(mlflow.pyfunc.PythonModel):
+ def load_context(self, context):
+ """
+ This method initializes the tokenizer and language model
+ using the specified model repository.
+ """
+ # Initialize tokenizer and language model
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ context.artifacts['repository'], padding_side="left")
+ self.model = transformers.AutoModelForCausalLM.from_pretrained(
+ context.artifacts['repository'],
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ device_map="auto",
+ pad_token_id=self.tokenizer.eos_token_id)
+ self.model.eval()
+
+ def _build_prompt(self, instruction):
+ """
+ This method generates the prompt for the model.
+ """
+ return f"""[INST]<>\n{DEFAULT_SYSTEM_PROMPT}\n<>\n\n\n{instruction}[/INST]\n"""
+
+ def _generate_response(self, prompt, temperature, max_new_tokens):
+ """
+ This method generates prediction for a single input.
+ """
+ # Build the prompt
+ prompt = self._build_prompt(prompt)
+
+ # Encode the input and generate prediction
+ encoded_input = self.tokenizer.encode(prompt, return_tensors='pt').to('cuda')
+ output = self.model.generate(encoded_input, do_sample=True, temperature=temperature, max_new_tokens=max_new_tokens)
+
+        # Decode only the newly generated tokens, skipping the prompt tokens
+        prompt_length = len(encoded_input[0])
+        generated_response = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
+
+ return generated_response
+
+ def predict(self, context, model_input):
+ """
+ This method generates prediction for the given input.
+ """
+
+ outputs = []
+
+ for i in range(len(model_input)):
+ prompt = model_input["prompt"][i]
+ temperature = model_input.get("temperature", [1.0])[i]
+ max_new_tokens = model_input.get("max_new_tokens", [100])[i]
+
+ outputs.append(self._generate_response(prompt, temperature, max_new_tokens))
+
+ # {"candidates": [...]} is the required response format for MLflow AI gateway -- see 07_ai_gateway for example
+ return {"candidates": outputs}
+
+# COMMAND ----------
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types import DataType, Schema, ColSpec
+
+import pandas as pd
+
+# Define input and output schema
+input_schema = Schema([
+ ColSpec(DataType.string, "prompt"),
+ ColSpec(DataType.double, "temperature", optional=True),
+ ColSpec(DataType.long, "max_new_tokens", optional=True)])
+output_schema = Schema([ColSpec(DataType.string)])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+
+# Define input example
+input_example=pd.DataFrame({
+ "prompt":["what is ML?"],
+ "temperature": [0.5],
+ "max_new_tokens": [100]})
+
+# Log the model with its details such as artifacts, pip requirements and input example
+with mlflow.start_run() as run:
+ mlflow.pyfunc.log_model(
+ "model",
+ python_model=Mistral7B(),
+ artifacts={'repository' : snapshot_location},
+ input_example=input_example,
+ pip_requirements=["torch==2.0.1", "transformers==4.34.0", "accelerate==0.21.0", "torchvision==0.15.2"],
+ signature=signature,
+ )
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Register the model to Unity Catalog
+# MAGIC By default, MLflow registers models in the Databricks workspace model registry. To register models in Unity Catalog instead, we follow the [documentation](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html) and set the registry server as Databricks Unity Catalog.
+# MAGIC
+# MAGIC In order to register a model in Unity Catalog, there are [several requirements](https://docs.databricks.com/machine-learning/manage-model-lifecycle/index.html#requirements), such as Unity Catalog must be enabled in your workspace.
+# MAGIC
+
+# COMMAND ----------
+
+# Configure MLflow Python client to register model in Unity Catalog
+import mlflow
+mlflow.set_registry_uri("databricks-uc")
+
+# COMMAND ----------
+
+# Register model to Unity Catalog
+# This may take a couple of minutes to complete
+
+registered_name = "models.default.mistral_7b_instruct" # Note that the UC model name follows the pattern .., corresponding to the catalog, schema, and registered model name
+
+
+result = mlflow.register_model(
+ "runs:/"+run.info.run_id+"/model",
+ registered_name,
+)
+
+# COMMAND ----------
+
+from mlflow import MlflowClient
+client = MlflowClient()
+
+# Choose the right model version registered in the above cell.
+client.set_registered_model_alias(name=registered_name, alias="Champion", version=result.version)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Load the model from Unity Catalog
+
+# COMMAND ----------
+
+import mlflow
+import pandas as pd
+
+loaded_model = mlflow.pyfunc.load_model(f"models:/{registered_name}@Champion")
+
+# Make a prediction using the loaded model
+loaded_model.predict(
+ {
+ "prompt": ["What is ML?", "What is large language model?"],
+ "temperature": [0.1, 0.5],
+ "max_new_tokens": [100, 100],
+ }
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create Model Serving Endpoint
+# MAGIC Once the model is registered, we can use the API to create a Databricks GPU Model Serving Endpoint that serves the Mistral-7B-Instruct model.
+# MAGIC
+# MAGIC Note that the below deployment requires GPU model serving. For more information on GPU model serving, contact the Databricks team or sign up [here](https://docs.google.com/forms/d/1-GWIlfjlIaclqDz6BPODI2j1Xg4f4WbFvBXyebBpN-Y/edit).
+
+# COMMAND ----------
+
+# Provide a name to the serving endpoint
+endpoint_name = 'mistral-7b-instruct'
+
+# COMMAND ----------
+
+databricks_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
+token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
+
+# COMMAND ----------
+
+import requests
+import json
+
+deploy_headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
+deploy_url = f'{databricks_url}/api/2.0/serving-endpoints'
+
+model_version = result # the returned result of mlflow.register_model
+served_name = f'{model_version.name.replace(".", "_")}_{model_version.version}'
+
+# Specify the type of compute (CPU, GPU_SMALL, GPU_MEDIUM, etc.)
+# Choose `GPU_MEDIUM` on AWS, and `GPU_LARGE` on Azure
+workload_type = "GPU_LARGE"
+
+endpoint_config = {
+ "name": endpoint_name,
+ "config": {
+ "served_models": [{
+ "name": served_name,
+ "model_name": model_version.name,
+ "model_version": model_version.version,
+ "workload_type": workload_type,
+ "workload_size": "Small",
+ "scale_to_zero_enabled": "False"
+ }]
+ }
+}
+endpoint_json = json.dumps(endpoint_config, indent=' ')
+
+# Send a POST request to the API
+deploy_response = requests.request(method='POST', headers=deploy_headers, url=deploy_url, data=endpoint_json)
+
+if deploy_response.status_code != 200:
+ raise Exception(f'Request failed with status {deploy_response.status_code}, {deploy_response.text}')
+
+# Show the response of the POST request
+# When first creating the serving endpoint, it should show that the state 'ready' is 'NOT_READY'
+# You can check the status on the Databricks model serving endpoint page, it is expected to take ~35 min for the serving endpoint to become ready
+print(deploy_response.json())
+
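+# COMMAND ----------
+
+# Optional: check the endpoint state programmatically with the same REST API
+# (a minimal sketch reusing `deploy_url`, `deploy_headers`, and `endpoint_name` from above)
+status_response = requests.get(f'{deploy_url}/{endpoint_name}', headers=deploy_headers)
+print(status_response.json().get('state'))
+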
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Once the model serving endpoint is ready, you can query it easily with LangChain (see `04_langchain` for example code) running in the same workspace.
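+# MAGIC
+# MAGIC For instance, a minimal sketch (the endpoint must be in the "Ready" state; the response-key handling below is an assumption based on the pyfunc output format `{"candidates": [...]}` and may need adjusting):
+# MAGIC ```python
+# MAGIC from langchain.llms import Databricks
+# MAGIC
+# MAGIC def transform_output(response):
+# MAGIC     # The pyfunc model logged above returns {"candidates": [...]}
+# MAGIC     return response["candidates"][0]
+# MAGIC
+# MAGIC llm = Databricks(endpoint_name=endpoint_name, transform_output_fn=transform_output)
+# MAGIC print(llm("How to master Python in 3 days?"))
+# MAGIC ```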
diff --git a/llm-models/mistral/mistral-7b/03_[chat]_serve_driver_proxy.py b/llm-models/mistral/mistral-7b/03_[chat]_serve_driver_proxy.py
new file mode 100644
index 0000000..1c69e13
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/03_[chat]_serve_driver_proxy.py
@@ -0,0 +1,172 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC
+# MAGIC # Serving Mistral-7B-Instruct as chat completion via vllm with a cluster driver proxy app
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, fine-tuned on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC [vllm](https://github.com/vllm-project/vllm/tree/main) is an open-source library that makes LLM inference fast with various optimizations.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install -U vllm==0.2.0 fschat==0.2.30 transformers==4.34.0 accelerate==0.20.3
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Inference
+# MAGIC The example in the model card should also work on Databricks with the same environment.
+
+# COMMAND ----------
+
+from vllm import LLM
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-Instruct-v0.1 at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/commits/main
+model = "mistralai/Mistral-7B-Instruct-v0.1"
+revision = "3dc28cf29d2edd31a0a7b8f0b21637059815b4d5"
+
+llm = LLM(model=model, revision=revision)
+
+# COMMAND ----------
+
+from transformers import StoppingCriteria, StoppingCriteriaList
+from vllm import SamplingParams
+
+import fastchat
+from fastchat.conversation import Conversation, SeparatorStyle
+from fastchat.model.model_adapter import get_conversation_template
+
+DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. \n\n If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
+
+def build_prompt(prompts) -> str:
+ conv = get_conversation_template(model)
+ conv = Conversation(
+ name=conv.name,
+ system_template=conv.system_template,
+ system_message=conv.system_message,
+ roles=conv.roles,
+ messages=list(conv.messages), # prevent in-place modification
+ offset=conv.offset,
+ sep_style=SeparatorStyle(conv.sep_style),
+ sep=conv.sep,
+ sep2=conv.sep2,
+ stop_str=conv.stop_str,
+ stop_token_ids=conv.stop_token_ids,
+ )
+
+ if isinstance(prompts, str):
+ prompt = prompts
+ else:
+ for message in prompts:
+ msg_role = message["role"]
+ if msg_role == "system":
+ conv.system_message = message["content"]
+ elif msg_role == "user":
+ conv.append_message(conv.roles[0], message["content"])
+ elif msg_role == "assistant":
+ conv.append_message(conv.roles[1], message["content"])
+ else:
+ raise ValueError(f"Unknown role: {msg_role}")
+
+ # Add a blank message for the assistant.
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ return prompt
+
+
+# Define parameters to generate text
+def gen_text_for_serving(prompt, **kwargs):
+ prompt = build_prompt(prompt)
+
+ # Get the input params for the standard parameters for chat routes: https://mlflow.org/docs/latest/gateway/index.html#chat
+ kwargs.setdefault("max_tokens", 512)
+ kwargs.setdefault("temperature", 0.1)
+
+ sampling_params = SamplingParams(**kwargs)
+ outputs = llm.generate(prompt, sampling_params=sampling_params)
+
+ output_response = []
+ for request_output in outputs:
+ response_messages = [{
+ "message": {
+ "role": "assistant",
+ "content": completion_output.text,
+ },
+ "metadata": {"finish_reason": completion_output.finish_reason},
+ } for completion_output in request_output.outputs]
+ input_length = len(request_output.prompt_token_ids)
+ output_length = sum([len(completion_output.token_ids) for completion_output in request_output.outputs])
+ metadata = {
+ "input_tokens": input_length,
+ "output_tokens": output_length,
+ "total_tokens": input_length+input_length,
+ "model": "mistralai/Mistral-7B-Instruct-v0.1",
+ "route_type": "llm/v1/chat",
+ }
+ output_response.append({"candidates": response_messages, "metadata":metadata})
+
+ return output_response
+
+# COMMAND ----------
+
+# See all standard parameters from https://mlflow.org/docs/latest/gateway/index.html#chat
+print(
+ gen_text_for_serving(
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is ML?"},
+ ],
+ )
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Serve with Flask
+
+# COMMAND ----------
+
+from flask import Flask, jsonify, request
+
+app = Flask("mistral-7b-chat-completion")
+
+@app.route('/', methods=['POST'])
+def serve_mistral_7b_chat_completion():
+ resp = gen_text_for_serving(**request.json)
+ return jsonify(resp)
+
+# COMMAND ----------
+
+from dbruntime.databricks_repl_context import get_context
+ctx = get_context()
+
+port = "7777"
+driver_proxy_api = f"https://{ctx.browserHostName}/driver-proxy-api/o/0/{ctx.clusterId}/{port}"
+
+print(f"""
+driver_proxy_api = '{driver_proxy_api}'
+cluster_id = '{ctx.clusterId}'
+port = {port}
+""")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Keep `app.run` running, and the app can be used with LangChain ([documentation](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/databricks.html#wrapping-a-cluster-driver-proxy-app)), or called directly over HTTP as in the sketch below:
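+# MAGIC
+# MAGIC The snippet below is a minimal sketch (the helper name is illustrative): the Flask app above forwards the JSON body to `gen_text_for_serving(**request.json)`, so `prompt` carries the chat messages and the remaining keys are vllm `SamplingParams`.
+# MAGIC ```python
+# MAGIC import requests
+# MAGIC import json
+# MAGIC
+# MAGIC def request_mistral_7b_chat(messages, temperature=0.1, max_tokens=512):
+# MAGIC   token = ... # TODO: fill in with your Databricks personal access token that can access the cluster that runs this driver proxy notebook
+# MAGIC   url = ...   # TODO: fill in with the driver_proxy_api output above
+# MAGIC
+# MAGIC   headers = {
+# MAGIC     "Content-Type": "application/json",
+# MAGIC     "Authorization": f"Bearer {token}"
+# MAGIC   }
+# MAGIC   data = {
+# MAGIC     "prompt": messages,
+# MAGIC     "temperature": temperature,
+# MAGIC     "max_tokens": max_tokens,
+# MAGIC   }
+# MAGIC
+# MAGIC   response = requests.post(url, headers=headers, data=json.dumps(data))
+# MAGIC   return response.json()
+# MAGIC
+# MAGIC
+# MAGIC request_mistral_7b_chat([
+# MAGIC   {"role": "system", "content": "You are a helpful assistant."},
+# MAGIC   {"role": "user", "content": "What is Databricks?"},
+# MAGIC ])
+# MAGIC ```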
+# MAGIC
+# MAGIC Or you could try using `ai_query` ([documentation](https://docs.databricks.com/sql/language-manual/functions/ai_query.html)) to call this driver proxy from Databricks SQL, following the same pattern as in the `03_serve_driver_proxy` notebook.
+# MAGIC
+# MAGIC Note: [AI Functions](https://docs.databricks.com/large-language-models/ai-functions.html) is in public preview; to enable the feature for your workspace, please submit this [form](https://docs.google.com/forms/d/e/1FAIpQLScVyh5eRioqGwuUVxj9JOiKBAo0-FWi7L3f4QWsKeyldqEw8w/viewform).
+
+# COMMAND ----------
+
+app.run(host="0.0.0.0", port=port, debug=True, use_reloader=False)
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/03_serve_driver_proxy.py b/llm-models/mistral/mistral-7b/03_serve_driver_proxy.py
new file mode 100644
index 0000000..deace36
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/03_serve_driver_proxy.py
@@ -0,0 +1,155 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC
+# MAGIC # Serving Mistral-7B-Instruct via vllm with a cluster driver proxy app
+# MAGIC
+# MAGIC The [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Large Language Model (LLM) is an instruct fine-tuned version of the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) generative text model, fine-tuned on a variety of publicly available conversation datasets.
+# MAGIC
+# MAGIC [vllm](https://github.com/vllm-project/vllm/tree/main) is an open-source library that makes LLM inference fast with various optimizations.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install -U vllm==0.2.0 transformers==4.34.0 accelerate==0.20.3
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Inference
+# MAGIC The example in the model card should also work on Databricks with the same environment.
+
+# COMMAND ----------
+
+from vllm import LLM
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-Instruct-v0.1 at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/commits/main
+model = "mistralai/Mistral-7B-Instruct-v0.1"
+revision = "3dc28cf29d2edd31a0a7b8f0b21637059815b4d5"
+
+llm = LLM(model=model, revision=revision)
+
+# COMMAND ----------
+
+from vllm import SamplingParams
+
+# A prompt template like the following guides the model to follow instructions and respond to the input; empirically it tends to produce better responses
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+PROMPT_FOR_GENERATION_FORMAT = """
+[INST]<<SYS>>
+{system_prompt}
+<</SYS>>
+
+
+{instruction}
+[/INST]
+""".format(
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
+ instruction="{instruction}"
+)
+
+# Define parameters to generate text
+def gen_text_for_serving(prompt, **kwargs):
+ prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=prompt)
+
+    # The default max_tokens is pretty small, which would cut the generated output in the middle, so increase it to allow a complete response
+ if "max_tokens" not in kwargs:
+ kwargs["max_tokens"] = 512
+
+ sampling_params = SamplingParams(**kwargs)
+
+ outputs = llm.generate(prompt, sampling_params=sampling_params)
+ texts = [out.outputs[0].text for out in outputs]
+
+ return texts[0]
+
+# COMMAND ----------
+
+print(gen_text_for_serving("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+# See full list of configurable args: https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
+print(gen_text_for_serving("How to master Python in 3 days?", temperature=0.1, max_tokens=100))
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Serve with Flask
+
+# COMMAND ----------
+
+from flask import Flask, jsonify, request
+
+app = Flask("mistral-7b-instruct")
+
+@app.route('/', methods=['POST'])
+def serve_mistral_7b_instruct():
+ resp = gen_text_for_serving(**request.json)
+ return jsonify(resp)
+
+# COMMAND ----------
+
+from dbruntime.databricks_repl_context import get_context
+ctx = get_context()
+
+port = "7777"
+driver_proxy_api = f"https://{ctx.browserHostName}/driver-proxy-api/o/0/{ctx.clusterId}/{port}"
+
+print(f"""
+driver_proxy_api = '{driver_proxy_api}'
+cluster_id = '{ctx.clusterId}'
+port = {port}
+""")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Keep `app.run` running, and the app can be used with LangChain ([documentation](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/databricks.html#wrapping-a-cluster-driver-proxy-app)), or called directly over HTTP with:
+# MAGIC ```python
+# MAGIC import requests
+# MAGIC import json
+# MAGIC
+# MAGIC def request_mistral_7b(prompt, temperature=1.0, max_new_tokens=1024):
+# MAGIC token = ... # TODO: fill in with your Databricks personal access token that can access the cluster that runs this driver proxy notebook
+# MAGIC url = ... # TODO: fill in with the driver_proxy_api output above
+# MAGIC
+# MAGIC headers = {
+# MAGIC "Content-Type": "application/json",
+# MAGIC "Authentication": f"Bearer {token}"
+# MAGIC }
+# MAGIC data = {
+# MAGIC "prompt": prompt,
+# MAGIC "temperature": temperature,
+# MAGIC "max_new_tokens": max_new_tokens,
+# MAGIC }
+# MAGIC
+# MAGIC response = requests.post(url, headers=headers, data=json.dumps(data))
+# MAGIC return response.text
+# MAGIC
+# MAGIC
+# MAGIC request_mistral_7b("What is databricks?")
+# MAGIC ```
+# MAGIC Or you could try using `ai_query` ([documentation](https://docs.databricks.com/sql/language-manual/functions/ai_query.html)) to call this driver proxy from Databricks SQL with:
+# MAGIC ```
+# MAGIC SELECT ai_query('cluster_id:port', -- TODO: fill in the cluster_id and port number from output above.
+# MAGIC named_struct('prompt', 'What is databricks?', 'temperature', CAST(0.1 AS Double)),
+# MAGIC 'returnType', 'STRING')
+# MAGIC ```
+# MAGIC Note: [AI Functions](https://docs.databricks.com/large-language-models/ai-functions.html) is in public preview; to enable the feature for your workspace, please submit this [form](https://docs.google.com/forms/d/e/1FAIpQLScVyh5eRioqGwuUVxj9JOiKBAo0-FWi7L3f4QWsKeyldqEw8w/viewform).
+
+# COMMAND ----------
+
+app.run(host="0.0.0.0", port=port, debug=True, use_reloader=False)
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/04_[chat]_langchain.py b/llm-models/mistral/mistral-7b/04_[chat]_langchain.py
new file mode 100644
index 0000000..1a810bb
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/04_[chat]_langchain.py
@@ -0,0 +1,68 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Load Mistral-7B-Instruct as chat completion from LangChain on Databricks
+# MAGIC
+# MAGIC This example notebook shows how to wrap Databricks endpoints as LLMs in LangChain. It supports two endpoint types:
+# MAGIC
+# MAGIC - Serving endpoint, recommended for production and development. See `02_[chat]_mlflow_logging_inference` for how to create one.
+# MAGIC - Cluster driver proxy app, recommended for interactive development. See `03_[chat]_serve_driver_proxy` for how to create one.
+# MAGIC
+# MAGIC Environment tested:
+# MAGIC - MLR: 14.0 ML
+# MAGIC - Instance:
+# MAGIC - Wrapping a serving endpoint: `i3.xlarge` on AWS, `Standard_DS3_v2` on Azure
+# MAGIC - Wrapping a cluster driver proxy app: `g5.4xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure (same instance as the driver proxy app)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Wrapping Databricks endpoints as LLMs in LangChain
+# MAGIC - If the model is a serving endpoint, it requires a model serving endpoint (see `02_[chat]_mlflow_logging_inference` for how to create one) to be in the "Ready" state.
+# MAGIC - If the model is a cluster driver proxy app, it requires the driver proxy app of the `03_[chat]_serve_driver_proxy` example notebook running.
+# MAGIC - If running a Databricks notebook attached to the same cluster that runs the app, you only need to specify the driver port to create a `Databricks` instance.
+# MAGIC   - If running on a different cluster, you can manually specify the cluster ID to use, as well as the Databricks workspace hostname and personal access token (see the sketch below).
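+# MAGIC
+# MAGIC A minimal sketch of the different-cluster case (the hostname, token, and cluster ID below are placeholders, and `transform_input`/`transform_output` are the functions defined later in this notebook):
+# MAGIC ```python
+# MAGIC from langchain.llms import Databricks
+# MAGIC
+# MAGIC llm = Databricks(
+# MAGIC     host="<workspace-hostname>",          # e.g. "xxxx.cloud.databricks.com"
+# MAGIC     api_token="<personal-access-token>",  # a PAT that can access the driver proxy app's cluster
+# MAGIC     cluster_id="<driver-proxy-cluster-id>",
+# MAGIC     cluster_driver_port="7777",
+# MAGIC     transform_input_fn=transform_input,
+# MAGIC     transform_output_fn=transform_output,
+# MAGIC )
+# MAGIC ```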
+
+# COMMAND ----------
+
+# MAGIC %pip install -q -U langchain
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+from langchain.llms import Databricks
+def transform_input(**request):
+ request["prompt"] = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": request["prompt"]},
+ ]
+ request["stop"] = []
+ return request
+
+def transform_output(response):
+    # Extract the assistant's answer from the response.
+ return response[0]["candidates"][0]["message"]["content"]
+
+
+# COMMAND ----------
+
+# If using serving endpoint, the model serving endpoint is created in `02_[chat]_mlflow_logging_inference`
+# llm = Databricks(endpoint_name='mistral-7b-instruct',
+# transform_input_fn=transform_input,
+# transform_output_fn=transform_output,)
+
+# If the model is a cluster driver proxy app on the same cluster, you only need to specify the driver port.
+llm = Databricks(cluster_driver_port="7777",
+ transform_input_fn=transform_input,
+ transform_output_fn=transform_output,)
+
+# If the model is a cluster driver proxy app on the different cluster, you need to provide the cluster id
+# llm = Databricks(cluster_id="0000-000000-xxxxxxxx"
+# cluster_driver_port="7777",
+# transform_input_fn=transform_input,
+# transform_output_fn=transform_output,)
+
+print(llm("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/04_langchain.py b/llm-models/mistral/mistral-7b/04_langchain.py
new file mode 100644
index 0000000..210aeba
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/04_langchain.py
@@ -0,0 +1,110 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Load Mistral-7B-Instruct from LangChain on Databricks
+# MAGIC
+# MAGIC This example notebook adapts the [LangChain integration documentation](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/databricks) and shows how to wrap Databricks endpoints as LLMs in LangChain. It supports two endpoint types:
+# MAGIC
+# MAGIC - Serving endpoint, recommended for production and development. See `02_mlflow_logging_inference` for how to create one.
+# MAGIC - Cluster driver proxy app, recommended for interactive development. See `03_serve_driver_proxy` for how to create one.
+# MAGIC
+# MAGIC Environment tested:
+# MAGIC - MLR: 14.0 ML
+# MAGIC - Instance:
+# MAGIC - Wrapping a serving endpoint: `i3.xlarge` on AWS, `Standard_DS3_v2` on Azure
+# MAGIC - Wrapping a cluster driver proxy app: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure (same instance as the driver proxy app)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Wrapping a cluster driver proxy app
+# MAGIC The LangChain Databricks integration also works with a cluster driver proxy app when given the driver port the proxy runs on.
+# MAGIC
+# MAGIC It requires the driver proxy app of the `03_serve_driver_proxy` example notebook running.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Same cluster
+# MAGIC If using the same cluster that runs the `03_serve_driver_proxy` notebook, specifying `cluster_driver_port` is required.
+
+# COMMAND ----------
+
+# MAGIC %pip install -q -U langchain
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+from langchain.llms import Databricks
+
+# COMMAND ----------
+
+llm = Databricks(cluster_driver_port="7777")
+
+print(llm("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+# If the app accepts extra parameters like `temperature`,
+# you can set them in `model_kwargs`.
+llm = Databricks(cluster_driver_port="7777", model_kwargs={"temperature": 0.1})
+
+print(llm("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+# Use `transform_input_fn` and `transform_output_fn` if the app
+# expects a different input schema and does not return a JSON string,
+# respectively, or you want to apply a prompt template on top.
+
+
+def transform_input(**request):
+ """
+ Add more instructions into the prompt.
+ """
+ full_prompt = f"""[INST] Let's think step by step. User: {request["prompt"]}[/INST]
+ """
+ request["prompt"] = full_prompt
+ return request
+
+
+def transform_output(response):
+ """
+    Add timestamps to the answers.
+ """
+ from datetime import datetime
+ now = datetime.now()
+ current_time = now.strftime("%d/%m/%Y %H:%M:%S")
+ return f"[{current_time}] Mistral-7B: {response}"
+
+
+llm = Databricks(
+ cluster_driver_port="7777",
+ transform_input_fn=transform_input,
+ transform_output_fn=transform_output,
+)
+
+print(llm("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ### Different cluster
+# MAGIC If using a different cluster, it's required to also specify `cluster_id`, which you can find in the cluster configuration page.
+
+# COMMAND ----------
+
+# MAGIC %pip install -q -U langchain
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+from langchain.llms import Databricks
+
+# TODO: this cluster ID is a placeholder; please replace `cluster_id` with the actual cluster ID of the driver proxy app's cluster
+llm = Databricks(cluster_id="1004-185119-szsdrjqn", cluster_driver_port="7777")
+
+print(llm("How to master Python in 3 days?"))
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/05_fine_tune_deepspeed.py b/llm-models/mistral/mistral-7b/05_fine_tune_deepspeed.py
new file mode 100644
index 0000000..8d48dbb
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/05_fine_tune_deepspeed.py
@@ -0,0 +1,280 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC
+# MAGIC # Fine tune Mistral-7B with deepspeed
+# MAGIC
+# MAGIC The [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) Large Language Model (LLM) is a pretrained generative text model with 7 billion parameters. Mistral-7B-v0.1 outperforms Llama 2 13B on all benchmarks.
+# MAGIC
+# MAGIC This notebook fine-tunes the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) model on the [mosaicml/dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) dataset.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `Standard_NC96ads_A100_v4` on Azure with 4 A100-80GB GPUs, `g5.24xlarge` on AWS with 4 A10G-24GB GPUs
+# MAGIC
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Install the missing libraries
+
+# COMMAND ----------
+
+# MAGIC %pip install -U torch==2.1.0
+# MAGIC %pip install -U accelerate==0.23.0 transformers==4.34.0
+# MAGIC %pip install deepspeed==0.10.3 xformers==0.0.22
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+import os
+
+os.environ["HF_HOME"] = "/local_disk0/hf"
+os.environ["HF_DATASETS_CACHE"] = "/local_disk0/hf"
+os.environ["TRANSFORMERS_CACHE"] = "/local_disk0/hf"
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Fine tune the model with `deepspeed`
+# MAGIC
+# MAGIC The fine-tuning logic is written in `scripts/fine_tune_deepspeed.py`. The dataset used for fine-tuning is the [mosaicml/dolly_hhrlhf](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) dataset, passed via the `--dataset` argument below.
+# MAGIC
+# MAGIC
+
+# COMMAND ----------
+
+!deepspeed \
+--num_gpus=1 \
+scripts/fine_tune_deepspeed.py \
+--final_model_output_path="/dbfs/llm" \
+--output_dir="/local_disk0/output" \
+--dataset="mosaicml/dolly_hhrlhf" \
+--model="mistralai/Mistral-7B-v0.1" \
+--tokenizer="mistralai/Mistral-7B-v0.1" \
+--deepspeed_config="../../config/a10_config.json" \
+--fp16=false \
+--bf16=true \
+--per_device_train_batch_size=24 \
+--per_device_eval_batch_size=24 \
+--gradient_checkpointing=true \
+--gradient_accumulation_steps=1 \
+--save_steps=500 \
+--num_train_epochs=1
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Model checkpoint is saved at `/dbfs/llm`.
+
+# COMMAND ----------
+
+# MAGIC %sh
+# MAGIC ls /dbfs/llm
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Save the model to mlflow
+
+# COMMAND ----------
+
+import pandas as pd
+import numpy as np
+import transformers
+import mlflow
+import torch
+import accelerate
+
+class Mistral7B(mlflow.pyfunc.PythonModel):
+ def load_context(self, context):
+ """
+ This method initializes the tokenizer and language model
+ using the specified model repository.
+ """
+ # Initialize tokenizer and language model
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+ context.artifacts['repository'], padding_side="left")
+
+ config = transformers.AutoConfig.from_pretrained(
+ context.artifacts['repository'],
+ trust_remote_code=True
+ )
+
+ self.model = transformers.AutoModelForCausalLM.from_pretrained(
+ context.artifacts['repository'],
+ config=config,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True)
+ self.model.to(device='cuda')
+
+ self.model.eval()
+
+ def _build_prompt(self, instruction):
+ """
+ This method generates the prompt for the model.
+ """
+ INSTRUCTION_KEY = "### Instruction:"
+ RESPONSE_KEY = "### Response:"
+ INTRO_BLURB = (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request."
+ )
+
+ return f"""{INTRO_BLURB}
+ {INSTRUCTION_KEY}
+ {instruction}
+ {RESPONSE_KEY}
+ """
+
+ def predict(self, context, model_input):
+ """
+ This method generates prediction for the given input.
+ """
+ generated_text = []
+ for index, row in model_input.iterrows():
+ prompt = row["prompt"]
+            # Read per-row parameters, falling back to defaults; "max_tokens" matches the input schema below
+            temperature = row.get("temperature", 1.0)
+            max_new_tokens = row.get("max_tokens", 100)
+ full_prompt = self._build_prompt(prompt)
+ encoded_input = self.tokenizer.encode(full_prompt, return_tensors="pt").to('cuda')
+ output = self.model.generate(encoded_input, do_sample=True, temperature=temperature, max_new_tokens=max_new_tokens)
+ prompt_length = len(encoded_input[0])
+ generated_text.append(self.tokenizer.batch_decode(output[:,prompt_length:], skip_special_tokens=True))
+ return pd.Series(generated_text)
+
+# COMMAND ----------
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types import DataType, Schema, ColSpec
+
+# Define input and output schema
+input_schema = Schema([
+ ColSpec(DataType.string, "prompt"),
+ ColSpec(DataType.double, "temperature"),
+ ColSpec(DataType.long, "max_tokens")])
+output_schema = Schema([ColSpec(DataType.string)])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+
+# Define input example
+input_example=pd.DataFrame({
+ "prompt":["what is ML?"],
+ "temperature": [0.5],
+ "max_tokens": [100]})
+
+# Log the model with its details such as artifacts, pip requirements and input example
+# This may take about 12 minutes to complete
+with mlflow.start_run() as run:
+ mlflow.pyfunc.log_model(
+ "model",
+        python_model=Mistral7B(),
+ artifacts={'repository' : "/dbfs/llm"},
+ pip_requirements=[f"torch=={torch.__version__}",
+ f"transformers=={transformers.__version__}",
+ f"accelerate=={accelerate.__version__}", "einops", "sentencepiece"],
+ input_example=input_example,
+ signature=signature
+ )
+
+# COMMAND ----------
+
+import mlflow
+import pandas as pd
+
+logged_model = "runs:/"+run.info.run_id+"/model"
+
+# Load model as a PyFuncModel.
+loaded_model = mlflow.pyfunc.load_model(logged_model)
+
+# Predict on a Pandas DataFrame.
+input_example=pd.DataFrame({"prompt":["what is ML?", "Name 10 colors."], "temperature": [0.5, 0.2],"max_tokens": [100, 200]})
+loaded_model.predict(input_example)
+
+# COMMAND ----------
+
+instructions = [
+ "Write a love letter to Edgar Allan Poe.",
+ "Write a tweet announcing Dolly, a large language model from Databricks.",
+ "I'm selling my Nikon D-750, write a short blurb for my ad.",
+ "Explain to me the difference between nuclear fission and fusion.",
+ "Give me a list of 5 science fiction books I should read next.",
+ "What are considerations I should keep in mind when planning a backcountry backpacking trip?",
+ "I'm planning a trip to San Francisco, what are some things I should make sure to do and see?",
+ "Give me a description of Kurt Vonnegut's literary style.",
+ "How would you describe the literary style of Toni Morrison?",
+ "What is the literary style of Jorge Luis Borges?",
+ "Describe the process of fermentation to me in terms of its chemical processes.",
+ "Write a short story about a little brown teddy bear who fell in love.",
+ "What are important considerations to keep in mind when defining an enterprise AI strategy?",
+ "Help! I'm pitching at YC in an hour and I don't have a business plan. Give me a list of tech startup ideas that are sure to get me accepted.",
+ "At John Deere, our core values are integrity, quality, commitment, and innovation. Write a mission statement that talks about how these values inform our approach to creating intelligent connected machines that enable lives to leap forward.",
+ "Label each of the following as either a scientific concept or a product: Nikon D750, quantum entanglement, CRISPR, and a Macbook Pro.",
+ """Extract all the people and places from the following passage:
+
+Input:
+Basquiat was born on December 22, 1960, in Park Slope, Brooklyn, New York City, the second of four children to Matilde Basquiat (née Andrades, 1934–2008) and Gérard Basquiat (1930–2013). He had an older brother, Max, who died shortly before his birth, and two younger sisters, Lisane (b. 1964) and Jeanine (b. 1967). His father was born in Port-au-Prince, Haiti and his mother was born in Brooklyn to Puerto Rican parents. He was raised Catholic.""",
+ """Write a press release declaring the completion of Atlantis II, a facility designed for long-term human habitation at the bottom of the ocean. Be sure to mention some of its advanced technological features.""",
+ """Give me a one line summary of this:
+
+Input:
+Coffee is one of the most widely consumed beverages in the world. It has primarily consumed due to its stimulant effect and unique taste since the ancient times. Afterwards, its consumption has been historically associated with a lower risk of some diseases such as type 2 diabetes mellitus, obesity, cardiovascular disease and some type of cancer and thus it has also consumed due to health benefits. It contains many bioactive compounds such as caffeine, chlorogenic acids and diterpenoid alcohols which have so far been associated with many potential health benefits. For example, caffeine reduces risk of developing neurodegenerative disease and chlorogenic acids (CGA) and diterpene alcohols have many health benefits such as antioxidant and chemo-preventive. Coffee also have harmful effects. For example, diterpenoid alcohols increases serum homocysteine and cholesterol levels and thus it has adverse effects on cardiovascular system. Overall, the study that supports the health benefits of coffee is increasing. But, it is thought-provoking that the association with health benefits of coffee consumption and frequency at different levels in each study. For this reason, we aimed to examine the health effect of the coffee and how much consumption is to investigate whether it meets the claimed health benefits.""",
+ 'Give me a different way to say the following to a 4 year old: "Son, this is the last time I\'m going to tell you. Go to bed!"',
+ """I'm going to give you a passage from the book Neuromancer and I'd like you to answer the following question: What is the tool that allows Case to access the matrix?
+
+Input:
+Case was twenty-four. At twenty-two, he'd been a cowboy, a rustler, one of the best in the Sprawl. He'd been trained by the best, by McCoy Pauley and Bobby Quine, legends in the biz. He'd operated on an almost permanent adrenaline high, a byproduct of youth and proficiency, jacked into a custom cyberspace deck that projected his disembodied consciousness into the consensual hallucination that was the matrix.""",
+ """What is the default configuration for new DBSQL warehouses?
+
+Input:
+Databricks SQL Serverless supports serverless compute. Admins can create serverless SQL warehouses (formerly SQL endpoints) that enable instant compute and are managed by Databricks. Serverless SQL warehouses use compute clusters in your Databricks account. Use them with Databricks SQL queries just like you normally would with the original customer-hosted SQL warehouses, which are now called classic SQL warehouses. Databricks changed the name from SQL endpoint to SQL warehouse because, in the industry, endpoint refers to either a remote computing device that communicates with a network that it’s connected to, or an entry point to a cloud service. A data warehouse is a data management system that stores current and historical data from multiple sources in a business friendly manner for easier insights and reporting. SQL warehouse accurately describes the full capabilities of this compute resource. If serverless SQL warehouses are enabled for your account, note the following: New SQL warehouses are serverless by default when you create them from the UI. New SQL warehouses are not serverless by default when you create them using the API, which requires that you explicitly specify serverless. You can also create new pro or classic SQL warehouses using either method. You can upgrade a pro or classic SQL warehouse to a serverless SQL warehouse or a classic SQL warehouse to a pro SQL warehouse. You can also downgrade from serverless to pro or classic. This feature only affects Databricks SQL. It does not affect how Databricks Runtime clusters work with notebooks and jobs in the Data Science & Engineering or Databricks Machine Learning workspace environments. Databricks Runtime clusters always run in the classic data plane in your AWS account. See Serverless quotas. If your account needs updated terms of use, workspace admins are prompted in the Databricks SQL UI. If your workspace has an AWS instance profile, you might need to update the trust relationship to support serverless compute, depending on how and when it was created.""",
+ """Write a helpful, friendly reply to the customer who wrote this letter:
+
+Input:
+I am writing to express my deep disappointment and frustration with the iPhone 14 Pro Max that I recently purchased. As a long-time Apple user and loyal customer, I was excited to upgrade to the latest and greatest iPhone model, but unfortunately, my experience with this device has been nothing short of a nightmare.
+Firstly, I would like to address the issue of battery life on this device. I was under the impression that Apple had made significant improvements to their battery technology, but unfortunately, this has not been my experience. Despite using the phone conservatively, I find that I have to charge it at least twice a day just to ensure it doesn't die on me when I need it the most. This is extremely inconvenient and frustrating, especially when I have to carry around a bulky power bank or constantly hunt for charging outlets.
+Moreover, I have encountered numerous issues with the software and hardware of the iPhone 14 Pro Max. The phone frequently freezes or crashes, and I have experienced several instances of apps crashing or not working properly. The phone also takes an unacceptably long time to start up, and I find myself waiting for several minutes before I can even use it. As someone who relies on their phone for both personal and professional purposes, this is incredibly frustrating and has caused me to miss important calls and messages.
+Furthermore, I am extremely disappointed with the camera quality on this device. Despite Apple's claims of improved camera technology, I have found that the photos I take on this phone are often blurry or grainy, and the colors are not as vibrant as I would like. This is especially disappointing considering the high price point of the iPhone 14 Pro Max, which is marketed as a premium smartphone with a top-of-the-line camera.
+In addition, I am disappointed with the lack of innovation and new features on the iPhone 14 Pro Max. For a phone that is marketed as the "next big thing," it feels like a minor upgrade from the previous model. The design is virtually unchanged, and the new features that have been added, such as 5G connectivity and the A16 Bionic chip, are not significant enough to justify the high price point of this device. I expected more from Apple, a company that prides itself on innovation and creativity.
+Furthermore, the customer service experience that I have had with Apple has been less than satisfactory. I have tried reaching out to Apple support numerous times, but have been met with unhelpful and dismissive responses. The representatives that I spoke with seemed to be more interested in closing the case quickly than actually addressing my concerns and finding a solution to my problems. This has left me feeling frustrated and unheard, and I do not feel like my concerns have been taken seriously.
+Overall, I feel as though I have been let down by Apple and their latest iPhone offering. As a loyal customer who has invested a significant amount of money into their products over the years, I expect better from a company that prides itself on innovation and customer satisfaction. I urge Apple to take these concerns seriously and make necessary improvements to the iPhone 14 Pro Max and future models.
+Thank you for your attention to this matter.""",
+ """Give me a list of the main complaints in this customer support ticket. Do not write a reply.
+
+Input:
+I am writing to express my deep disappointment and frustration with the iPhone 14 Pro Max that I recently purchased. As a long-time Apple user and loyal customer, I was excited to upgrade to the latest and greatest iPhone model, but unfortunately, my experience with this device has been nothing short of a nightmare.
+
+Firstly, I would like to address the issue of battery life on this device. I was under the impression that Apple had made significant improvements to their battery technology, but unfortunately, this has not been my experience. Despite using the phone conservatively, I find that I have to charge it at least twice a day just to ensure it doesn't die on me when I need it the most. This is extremely inconvenient and frustrating, especially when I have to carry around a bulky power bank or constantly hunt for charging outlets.
+
+Furthermore, I am extremely disappointed with the camera quality on this device. Despite Apple's claims of improved camera technology, I have found that the photos I take on this phone are often blurry or grainy, and the colors are not as vibrant as I would like. This is especially disappointing considering the high price point of the iPhone 14 Pro Max, which is marketed as a premium smartphone with a top-of-the-line camera.
+
+Overall, I feel as though I have been let down by Apple and their latest iPhone offering. As a loyal customer who has invested a significant amount of money into their products over the years, I expect better from a company that prides itself on innovation and customer satisfaction. I urge Apple to take these concerns seriously and make necessary improvements to the iPhone 14 Pro Max and future models.
+
+Thank you for your attention to this matter.
+""",
+ # Test how Dolly deals with absurd "facts"
+ "Abraham Lincoln was secretly an experienced vampire hunter. What is the historical evidence for this?",
+ "George Washington was sent back in time from the future by an advanced civilization living in the Alpha Centauri system in 3000AD. What is the historical evidence for this?",
+ "As we all know, the Moon was recently discovered to not be real, but in fact is only a simulation. What was the scientific evidence that established this?",
+ "Scientists have recently proven that the Earth is actually flat. Explain the evidence for this.",
+]
+
+# COMMAND ----------
+
+input_example=pd.DataFrame({"prompt":instructions, "temperature": [0.2]*len(instructions),"max_tokens": [200]*len(instructions)})
+loaded_model.predict(input_example)
+
+# COMMAND ----------
+
+result = loaded_model.predict(input_example)
+
+# COMMAND ----------
+
+type(result)
+
+# COMMAND ----------
+
+for i in result:
+ print(i)
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/06_fine_tune_qlora.py b/llm-models/mistral/mistral-7b/06_fine_tune_qlora.py
new file mode 100644
index 0000000..d8331fe
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/06_fine_tune_qlora.py
@@ -0,0 +1,403 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Fine tune Mistral-7B with QLoRA
+# MAGIC
+# MAGIC The [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) Large Language Model (LLM) is a pretrained generative text model with 7 billion parameters. Mistral-7B-v0.1 outperforms Llama 2 13B on all benchmarks.
+# MAGIC
+# MAGIC This notebook fine-tunes the [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) model on the [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) dataset.
+# MAGIC
+# MAGIC Environment for this notebook:
+# MAGIC - Runtime: 14.0 GPU ML Runtime
+# MAGIC - Instance: `g5.xlarge` on AWS, `Standard_NV36ads_A10_v5` on Azure
+# MAGIC
+# MAGIC We leverage the PEFT library from Hugging Face, as well as QLoRA for more memory efficient finetuning.
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Install required packages
+# MAGIC
+# MAGIC Run the cells below to set up and install the required libraries. For our experiment we will need `accelerate`, `peft`, `transformers`, `datasets`, and TRL to leverage the recent [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer). We will use `bitsandbytes` to [quantize the base model into 4bit](https://huggingface.co/blog/4bit-transformers-bitsandbytes). We will also install `einops`, which some model implementations require.
+
+# COMMAND ----------
+
+# %pip install git+https://github.com/huggingface/peft.git
+# %pip install torch==2.1.0 accelerate==0.23.0
+%pip install -U transformers==4.34.0
+%pip install bitsandbytes==0.41.1 einops==0.7.0 trl==0.7.1 peft==0.5.0
+dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# Define some parameters
+model_output_location = "/local_disk0/mistral-7b-lora-fine-tune"
+local_output_dir = "/local_disk0/results"
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Dataset
+# MAGIC
+# MAGIC We will use the [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) dataset.
+
+# COMMAND ----------
+
+from datasets import load_dataset
+
+dataset_name = "databricks/databricks-dolly-15k"
+dataset = load_dataset(dataset_name, split="train")
+
+# COMMAND ----------
+
+INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+INSTRUCTION_KEY = "### Instruction:"
+INPUT_KEY = "Input:"
+RESPONSE_KEY = "### Response:"
+END_KEY = "### End"
+
+PROMPT_NO_INPUT_FORMAT = """{intro}
+
+{instruction_key}
+{instruction}
+
+{response_key}
+{response}
+
+{end_key}""".format(
+ intro=INTRO_BLURB,
+ instruction_key=INSTRUCTION_KEY,
+ instruction="{instruction}",
+ response_key=RESPONSE_KEY,
+ response="{response}",
+ end_key=END_KEY
+)
+
+PROMPT_WITH_INPUT_FORMAT = """{intro}
+
+{instruction_key}
+{instruction}
+
+{input_key}
+{input}
+
+{response_key}
+{response}
+
+{end_key}""".format(
+ intro=INTRO_BLURB,
+ instruction_key=INSTRUCTION_KEY,
+ instruction="{instruction}",
+ input_key=INPUT_KEY,
+ input="{input}",
+ response_key=RESPONSE_KEY,
+ response="{response}",
+ end_key=END_KEY
+)
+
+def apply_prompt_template(examples):
+ instruction = examples["instruction"]
+ response = examples["response"]
+ context = examples.get("context")
+
+ if context:
+ full_prompt = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context)
+ else:
+ full_prompt = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
+ return { "text": full_prompt }
+
+dataset = dataset.map(apply_prompt_template)
+
+# COMMAND ----------
+
+dataset["text"][0]
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Loading the model
+# MAGIC
+# MAGIC In this section we will load the [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1), quantize it in 4bit and attach LoRA adapters on it.
+
+# COMMAND ----------
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
+
+# It is suggested to pin the revision commit hash and not change it, for reproducibility, because the uploader might change the model afterwards; you can find the commit history of Mistral-7B-v0.1 at https://huggingface.co/mistralai/Mistral-7B-v0.1/commits/main
+model = "mistralai/Mistral-7B-v0.1"
+revision = "f801b4a1012022c23ef76287422b9f11eb901061"
+
+tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+bnb_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+)
+
+model = AutoModelForCausalLM.from_pretrained(
+ model,
+ quantization_config=bnb_config,
+ revision=revision,
+ trust_remote_code=True,
+)
+model.config.use_cache = False
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Load the configuration file in order to create the LoRA model.
+# MAGIC
+# MAGIC According to the QLoRA paper, it is important to target all linear layers in the transformer block for maximum performance. Therefore we programmatically collect all of the model's 4-bit linear layers (the attention projections as well as the MLP layers) and pass them as the LoRA target modules.
+
+# COMMAND ----------
+
+# Choose all linear layers from the model
+import bitsandbytes as bnb
+
+def find_all_linear_names(model):
+ cls = bnb.nn.Linear4bit
+ lora_module_names = set()
+ for name, module in model.named_modules():
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+linear_layers = find_all_linear_names(model)
+print(f"Linear layers in the model: {linear_layers}")
+
+# COMMAND ----------
+
+from peft import LoraConfig
+
+lora_alpha = 16
+lora_dropout = 0.1
+lora_r = 64
+
+peft_config = LoraConfig(
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ r=lora_r,
+ bias="none",
+ task_type="CAUSAL_LM",
+ target_modules=linear_layers,
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Loading the trainer
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Here we will use the [`SFTTrainer` from TRL library](https://huggingface.co/docs/trl/main/en/sft_trainer) that gives a wrapper around transformers `Trainer` to easily fine-tune models on instruction based datasets using PEFT adapters. Let's first load the training arguments below.
+
+# COMMAND ----------
+
+from transformers import TrainingArguments
+
+per_device_train_batch_size = 4
+gradient_accumulation_steps = 4
+optim = "paged_adamw_32bit"
+save_steps = 500
+logging_steps = 100
+learning_rate = 2e-4
+max_grad_norm = 0.3
+max_steps = 1000
+warmup_ratio = 0.03
+lr_scheduler_type = "constant"
+
+training_arguments = TrainingArguments(
+ output_dir=local_output_dir,
+ per_device_train_batch_size=per_device_train_batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ optim=optim,
+ save_steps=save_steps,
+ logging_steps=logging_steps,
+ learning_rate=learning_rate,
+ fp16=True,
+ max_grad_norm=max_grad_norm,
+ max_steps=max_steps,
+ warmup_ratio=warmup_ratio,
+ group_by_length=True,
+ lr_scheduler_type=lr_scheduler_type,
+ ddp_find_unused_parameters=False,
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Then finally pass everything to the trainer.
+
+# COMMAND ----------
+
+from trl import SFTTrainer
+
+max_seq_length = 512
+
+trainer = SFTTrainer(
+ model=model,
+ train_dataset=dataset,
+ peft_config=peft_config,
+ dataset_text_field="text",
+ max_seq_length=max_seq_length,
+ tokenizer=tokenizer,
+ args=training_arguments,
+)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC We will also pre-process the model by upcasting the layer norms to float32 for more stable training.
+
+# COMMAND ----------
+
+for name, module in trainer.model.named_modules():
+ if "norm" in name:
+ module = module.to(torch.float32)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Train the model
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC Now let's train the model! Simply call `trainer.train()`
+
+# COMMAND ----------
+
+trainer.train()
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Save the LoRA model
+
+# COMMAND ----------
+
+trainer.save_model(model_output_location)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Log the fine tuned model to MLFlow
+
+# COMMAND ----------
+
+import torch
+from peft import PeftModel, PeftConfig
+
+peft_model_id = model_output_location
+config = PeftConfig.from_pretrained(peft_model_id)
+
+from huggingface_hub import snapshot_download
+# Download the Mistral-7B-v0.1 model snapshot from huggingface
+snapshot_location = snapshot_download(repo_id=config.base_model_name_or_path)
+
+
+# COMMAND ----------
+
+import mlflow
+class Mistral7BQLORA(mlflow.pyfunc.PythonModel):
+ def load_context(self, context):
+ self.tokenizer = AutoTokenizer.from_pretrained(context.artifacts['repository'])
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+ config = PeftConfig.from_pretrained(context.artifacts['lora'])
+ base_model = AutoModelForCausalLM.from_pretrained(
+ context.artifacts['repository'],
+ return_dict=True,
+ load_in_4bit=True,
+ device_map={"":0},
+ trust_remote_code=True,
+ )
+ self.model = PeftModel.from_pretrained(base_model, context.artifacts['lora'])
+
+ def predict(self, context, model_input):
+ prompt = model_input["prompt"][0]
+ temperature = model_input.get("temperature", [1.0])[0]
+ max_tokens = model_input.get("max_tokens", [100])[0]
+ batch = self.tokenizer(prompt, padding=True, truncation=True,return_tensors='pt').to('cuda')
+ with torch.cuda.amp.autocast():
+ output_tokens = self.model.generate(
+ input_ids = batch.input_ids,
+ max_new_tokens=max_tokens,
+ temperature=temperature,
+ top_p=0.7,
+ num_return_sequences=1,
+ do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+ )
+ generated_text = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+
+ return generated_text
+
+# COMMAND ----------
+
+from mlflow.models.signature import ModelSignature
+from mlflow.types import DataType, Schema, ColSpec
+import pandas as pd
+import mlflow
+
+# Define input and output schema
+input_schema = Schema([
+ ColSpec(DataType.string, "prompt"),
+ ColSpec(DataType.double, "temperature"),
+ ColSpec(DataType.long, "max_tokens")])
+output_schema = Schema([ColSpec(DataType.string)])
+signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+
+# Define input example
+input_example=pd.DataFrame({
+ "prompt":["what is ML?"],
+ "temperature": [0.5],
+ "max_tokens": [100]})
+
+with mlflow.start_run() as run:
+ mlflow.pyfunc.log_model(
+ "model",
+ python_model=Mistral7BQLORA(),
+ artifacts={'repository' : snapshot_location, "lora": peft_model_id},
+ pip_requirements=["torch", "transformers", "accelerate", "einops", "loralib", "bitsandbytes", "peft"],
+        input_example=input_example,
+ signature=signature
+ )
+
+# COMMAND ----------
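+
+# MAGIC %md
+# MAGIC Optionally, register the logged model in the MLflow Model Registry so it can be referenced by name when creating a serving endpoint later. The registry name below is just an example; adjust it to your own naming convention.
+
+# COMMAND ----------
+
+# Optional: register the logged model under an example name.
+registered_model = mlflow.register_model(f"runs:/{run.info.run_id}/model", "mistral-7b-qlora")
+
+# COMMAND ----------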
+
+# MAGIC %md
+# MAGIC Run inference with the model logged in MLflow.
+
+# COMMAND ----------
+
+import mlflow
+import pandas as pd
+
+
+prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+### Instruction:
+if one get corona and you are self isolating and it is not severe, is there any meds that one can take?
+
+### Response: """
+# Load model as a PyFuncModel.
+run_id = run.info.run_id
+logged_model = f"runs:/{run_id}/model"
+
+loaded_model = mlflow.pyfunc.load_model(logged_model)
+
+text_example=pd.DataFrame({
+ "prompt":[prompt],
+ "temperature": [0.5],
+ "max_tokens": [100]})
+
+# Predict on a Pandas DataFrame.
+loaded_model.predict(text_example)
diff --git a/llm-models/mistral/mistral-7b/07_ai_gateway.py b/llm-models/mistral/mistral-7b/07_ai_gateway.py
new file mode 100644
index 0000000..6a5bebe
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/07_ai_gateway.py
@@ -0,0 +1,76 @@
+# Databricks notebook source
+# MAGIC %md
+# MAGIC # Manage access to Databricks Serving Endpoint with AI Gateway
+# MAGIC
+# MAGIC This example notebook demonstrates how to use MLflow AI Gateway ([see announcement blog](https://www.databricks.com/blog/announcing-mlflow-ai-gateway)) with a Databricks Serving Endpoint.
+# MAGIC
+# MAGIC Requirement:
+# MAGIC - A Databricks serving endpoint that is in the "Ready" status. Please refer to the `02_mlflow_logging_inference` example notebook for steps to create a Databricks serving endpoint.
+# MAGIC
+# MAGIC Environment:
+# MAGIC - MLR: 13.3 ML
+# MAGIC - Instance: `i3.xlarge` on AWS, `Standard_DS3_v2` on Azure
+
+# COMMAND ----------
+
+# MAGIC %pip install --upgrade "mlflow[gateway]>=2.6"
+
+# COMMAND ----------
+
+dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# TODO: Please change `endpoint_name` to your Databricks serving endpoint name if it's different
+# The code below assumes you've created an endpoint named "mistral-7b-instruct" according to 02_mlflow_logging_inference
+endpoint_name = "mistral-7b-instruct"
+gateway_route_name = f"{endpoint_name}_completion"
+
+# COMMAND ----------
+
+# Databricks workspace URL and API token used by the AI Gateway Route to reach the Databricks serving endpoint
+databricks_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
+token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
+
+# COMMAND ----------
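+
+# MAGIC %md
+# MAGIC Optionally, confirm that the serving endpoint is in the "Ready" state before creating a route. This is a sketch that assumes the Databricks serving-endpoints REST API (`GET /api/2.0/serving-endpoints/{name}`) is reachable from this workspace; the exact response fields may vary.
+
+# COMMAND ----------
+
+import requests
+
+# Query the endpoint's state with the workspace URL and token obtained above.
+endpoint_info = requests.get(
+    f"{databricks_url}/api/2.0/serving-endpoints/{endpoint_name}",
+    headers={"Authorization": f"Bearer {token}"},
+)
+endpoint_info.raise_for_status()
+print(endpoint_info.json().get("state"))
+
+# COMMAND ----------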
+
+import mlflow.gateway
+mlflow.gateway.set_gateway_uri("databricks")
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Create an AI Gateway Route
+
+# COMMAND ----------
+
+mlflow.gateway.create_route(
+ name=gateway_route_name,
+ route_type="llm/v1/completions",
+    model={
+ "name": endpoint_name,
+ "provider": "databricks-model-serving",
+ "databricks_model_serving_config": {
+ "databricks_api_token": token,
+ "databricks_workspace_url": databricks_url
+ }
+ }
+)
+
+# COMMAND ----------
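+
+# MAGIC %md
+# MAGIC The route can be inspected (and later removed) with the same fluent `mlflow.gateway` client; the calls below are a sketch. `delete_route` is left commented out so the route stays available for the query example that follows.
+
+# COMMAND ----------
+
+# Look up the route that was just created.
+print(mlflow.gateway.get_route(gateway_route_name))
+
+# Clean up once the route is no longer needed:
+# mlflow.gateway.delete_route(gateway_route_name)
+
+# COMMAND ----------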
+
+# MAGIC %md
+# MAGIC ## Query an AI Gateway Route
+# MAGIC The code below uses `mlflow.gateway.query` to query the `Route` created in the cell above.
+# MAGIC
+# MAGIC Note that `mlflow.gateway.query` doesn't need to run in the same notebook or on the same cluster, and it doesn't require the Databricks URL or API token, which makes it convenient for multiple users within the same organization to access a served model.
+
+# COMMAND ----------
+
+response = mlflow.gateway.query(
+ route=gateway_route_name,
+ data={"prompt": "What is MLflow?", "temperature": 0.3, "max_new_tokens": 512}
+)
+
+print(response['candidates'][0]['text'])
+
+# COMMAND ----------
+
+
diff --git a/llm-models/mistral/mistral-7b/README.md b/llm-models/mistral/mistral-7b/README.md
new file mode 100644
index 0000000..4b443e8
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/README.md
@@ -0,0 +1,50 @@
+# Example notebooks for the Mistral 7B model on Databricks
+
+This folder contains the following examples for Mistral 7B models:
+
+| **File** | **Description** | **Model Used** | **GPU Minimum Requirement** |
+|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------:|:---------------------:|:---------------------------:|
+| [01_load_inference](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/01_load_inference.py) | Environment setup and suggested configurations when inferencing mistral-7b-instruct models on Databricks. | `mistral-7b-instruct` | 1xA10-24GB |
+| [02_mlflow_logging_inference](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/02_mlflow_logging_inference.py) | Save, register, and load mistral-7b-instruct models with MLflow, and create a Databricks model serving endpoint. | `mistral-7b-instruct` | 1xA10-24GB |
+| [02_[chat]_mlflow_logging_inference](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/02_[chat]_mlflow_logging_inference.py) | Save, register, and load mistral-7b-instruct models with MLflow, and create a Databricks model serving endpoint for chat completion. | `mistral-7b-instruct` | 1xA10-24GB |
+| [03_serve_driver_proxy](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/03_serve_driver_proxy.py) | Serve mistral-7b-instruct models on the cluster driver node using Flask. | `mistral-7b-instruct` | 1xA10-24GB |
+| [03_[chat]_serve_driver_proxy](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/03_[chat]_serve_driver_proxy.py) | Serve mistral-7b-instruct models as chat completion on the cluster driver node using Flask. | `mistral-7b-instruct` | 1xA10-24GB |
+| [04_langchain](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/04_langchain.py) | Integrate a serving endpoint or cluster driver proxy app with LangChain and query. | N/A | N/A |
+| [04_[chat]_langchain](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/04_[chat]_langchain.py) | Integrate a serving endpoint and set up a LangChain chat model. | N/A | N/A |
+| [05_fine_tune_deepspeed](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/05_fine_tune_deepspeed.py) | Fine-tune mistral-7b models leveraging DeepSpeed. | `mistral-7b` | 4xA10 or 2xA100-80GB |
+| [06_fine_tune_qlora](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/06_fine_tune_qlora.py) | Fine-tune mistral-7b models with QLORA. | `mistral-7b` | 1xA10 |
+| [07_ai_gateway](https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/07_ai_gateway.py) | Manage an MLflow AI Gateway Route that accesses a Databricks model serving endpoint. | N/A | N/A |
diff --git a/llm-models/mistral/mistral-7b/scripts/fine_tune_deepspeed.py b/llm-models/mistral/mistral-7b/scripts/fine_tune_deepspeed.py
new file mode 100644
index 0000000..e5d7a9f
--- /dev/null
+++ b/llm-models/mistral/mistral-7b/scripts/fine_tune_deepspeed.py
@@ -0,0 +1,248 @@
+from dataclasses import field, dataclass
+import json
+import logging
+import os
+import numpy as np
+from pathlib import Path
+import torch
+from typing import Optional, Union, Tuple
+
+from datasets import Dataset, load_dataset
+import transformers
+
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ DataCollatorForLanguageModeling,
+ HfArgumentParser,
+ IntervalStrategy,
+ PreTrainedTokenizer,
+ SchedulerType,
+ Trainer,
+ TrainingArguments,
+ set_seed,
+)
+
+logger = logging.getLogger(__name__)
+
+ROOT_PATH = Path(__file__).parent.parent
+MODEL_PATH = "mistralai/Mistral-7B-v0.1"
+TOKENIZER_PATH = "mistralai/Mistral-7B-v0.1"
+DEFAULT_TRAINING_DATASET = "mosaicml/dolly_hhrlhf"
+CONFIG_PATH = "../../config/a10_config.json"
+LOCAL_OUTPUT_DIR = "/dbfs/Mistral-7B-fine-tune/output"
+DEFAULT_SEED = 68
+
+
+@dataclass
+class HFTrainingArguments:
+ local_rank: Optional[str] = field(default="-1")
+ dataset: Optional[str] = field(default=DEFAULT_TRAINING_DATASET)
+ model: Optional[str] = field(default=MODEL_PATH)
+ tokenizer: Optional[str] = field(default=TOKENIZER_PATH)
+ max_seq_len: Optional[int] = field(default=256)
+
+ final_model_output_path: Optional[str] = field(default="/local_disk0/final_model")
+
+ deepspeed_config: Optional[str] = field(default=CONFIG_PATH)
+
+ output_dir: Optional[str] = field(default=None)
+ per_device_train_batch_size: Optional[int] = field(default=1)
+ per_device_eval_batch_size: Optional[int] = field(default=1)
+ gradient_checkpointing: Optional[bool] = field(default=True)
+ gradient_accumulation_steps: Optional[int] = field(default=1)
+ learning_rate: Optional[float] = field(default=1e-6)
+ optim: Optional[str] = field(default="adamw_hf")
+ num_train_epochs: Optional[int] = field(default=None)
+ max_steps: Optional[int] = field(default=-1)
+ adam_beta1: float = field(default=0.9)
+ adam_beta2: float = field(default=0.999)
+ adam_epsilon: float = field(default=1e-8)
+ lr_scheduler_type: Union[SchedulerType, str] = field(
+ default="cosine",
+ )
+ warmup_steps: int = field(default=0)
+ weight_decay: Optional[float] = field(default=1)
+ logging_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ evaluation_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ save_strategy: Optional[Union[str, IntervalStrategy]] = field(
+ default=IntervalStrategy.STEPS
+ )
+ fp16: Optional[bool] = field(default=False)
+ bf16: Optional[bool] = field(default=True)
+ save_steps: Optional[int] = field(default=100)
+ logging_steps: Optional[int] = field(default=10)
+
+
+def load_training_dataset(
+ tokenizer,
+ path_or_dataset: str = DEFAULT_TRAINING_DATASET,
+ max_seq_len: int = 256,
+ seed: int = DEFAULT_SEED,
+) -> Tuple[Dataset, Dataset]:
+ logger.info(f"Loading dataset from {path_or_dataset}")
+ dataset = load_dataset(path_or_dataset)
+ logger.info(f"Training: found {dataset['train'].num_rows} rows")
+ logger.info(f"Eval: found {dataset['test'].num_rows} rows")
+
+ # Reformat input data, add prompt template if needed
+ def _reformat_data(row):
+ return row["prompt"] + row["response"]
+
+ # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
+ def tokenize(element):
+ input_batch = []
+ attention_masks = []
+
+ outputs = tokenizer(
+ _reformat_data(element),
+ truncation=True,
+ padding=True,
+ max_length=max_seq_len,
+ return_overflowing_tokens=False,
+ return_length=True,
+ )
+
+ for length, input_ids, attention_mask in zip(
+ outputs["length"], outputs["input_ids"], outputs["attention_mask"]
+ ):
+ if length == max_seq_len:
+ input_batch.append(input_ids)
+ attention_masks.append(attention_mask)
+
+ return {"input_ids": input_batch, "attention_mask": attention_masks}
+
+ train_tokenized_dataset = dataset["train"].map(
+ tokenize, batched=True, remove_columns=dataset["train"].column_names
+ )
+ eval_tokenized_dataset = dataset["test"].map(
+ tokenize, batched=True, remove_columns=dataset["test"].column_names
+ )
+
+ return train_tokenized_dataset, eval_tokenized_dataset
+
+
+def get_model(
+ pretrained_name_or_path: str = MODEL_PATH
+) -> AutoModelForCausalLM:
+ logger.info(f"Loading model: {pretrained_name_or_path}")
+
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        pretrained_name_or_path,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+        device_map=None,
+    )
+
+ model.config.use_cache = False
+
+ return model
+
+
+def get_tokenizer(
+ pretrained_name_or_path: str,
+) -> PreTrainedTokenizer:
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_name_or_path, trust_remote_code=True, padding_side="left"
+    )
+ tokenizer.pad_token = tokenizer.eos_token
+ return tokenizer
+
+
+def train(args: HFTrainingArguments):
+ set_seed(DEFAULT_SEED)
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ tokenizer = get_tokenizer(args.tokenizer)
+ train_dataset, eval_dataset = load_training_dataset(
+ tokenizer, path_or_dataset=args.dataset, max_seq_len=args.max_seq_len
+ )
+ model = get_model(pretrained_name_or_path=args.model)
+
+ if args.deepspeed_config:
+ with open(args.deepspeed_config) as json_data:
+ deepspeed_config_dict = json.load(json_data)
+ else:
+ deepspeed_config_dict = None
+
+ training_args = TrainingArguments(
+ local_rank=args.local_rank,
+ output_dir=args.output_dir,
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
+ gradient_checkpointing=args.gradient_checkpointing,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ learning_rate=args.learning_rate,
+ optim=args.optim,
+ num_train_epochs=args.num_train_epochs,
+ max_steps=args.max_steps,
+ adam_beta1=args.adam_beta1,
+ adam_beta2=args.adam_beta2,
+ adam_epsilon=args.adam_epsilon,
+ lr_scheduler_type=args.lr_scheduler_type,
+ warmup_steps=args.warmup_steps,
+ weight_decay=args.weight_decay,
+ logging_strategy=args.logging_strategy,
+ evaluation_strategy=args.evaluation_strategy,
+ save_strategy=args.save_strategy,
+ fp16=args.fp16,
+ bf16=args.bf16,
+ deepspeed=deepspeed_config_dict,
+ save_steps=args.save_steps,
+ logging_steps=args.logging_steps,
+ push_to_hub=False,
+ disable_tqdm=True,
+ report_to=["tensorboard"],
+ # group_by_length=True,
+ ddp_find_unused_parameters=False,
+ # fsdp=["full_shard", "offload"],
+ )
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+
+ logger.info("Training the model")
+ trainer.train()
+
+ logger.info(f"Saving Model to {args.final_model_output_path}")
+ trainer.save_model(output_dir=args.final_model_output_path)
+ tokenizer.save_pretrained(args.final_model_output_path)
+
+ logger.info("Training finished.")
+
+
+def main():
+ parser = HfArgumentParser(HFTrainingArguments)
+
+ parsed = parser.parse_args_into_dataclasses()
+ args: HFTrainingArguments = parsed[0]
+
+ train(args)
+
+
+if __name__ == "__main__":
+ os.environ["HF_HOME"] = "/local_disk0/hf"
+ os.environ["HF_DATASETS_CACHE"] = "/local_disk0/hf"
+ os.environ["TRANSFORMERS_CACHE"] = "/local_disk0/hf"
+
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
+ level=logging.INFO,
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ try:
+ main()
+ except Exception:
+ logger.exception("main failed")
+ raise
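+
+
+# Example launch command (a sketch; the companion 05_fine_tune_deepspeed notebook may invoke this script
+# differently, and the GPU count and config path should match your cluster):
+#
+#   deepspeed --num_gpus=4 fine_tune_deepspeed.py \
+#     --deepspeed_config ../../config/a10_config_zero2.json \
+#     --output_dir /local_disk0/output \
+#     --max_steps 200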