diff --git a/README.md b/README.md
index 77ff0f3..17550f8 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ In either case you get better results if you've trained or fine-tuned the model
 
-Toolio is primarily developed by the crew at Oori Data. We offer data pipelines and software engineering services around AI/LLM applications.
+Toolio is primarily developed by the crew at Oori Data. We offer LLMOps, data pipelines and software engineering services around AI/LLM applications.
 
 We'd love your help, though! [Click to learn how to make contributions to the project](https://github.com/OoriData/Toolio/wiki/Notes-for-contributors).
diff --git a/demo/mlx_enhanced_features.py b/demo/mlx_enhanced_features.py
new file mode 100644
index 0000000..00ca99d
--- /dev/null
+++ b/demo/mlx_enhanced_features.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+'''
+Demo of MLX 0.29.0 and MLX-LM 0.27.0 enhanced features:
+- Prompt caching for faster repeated queries
+- 4-bit KV cache quantization for memory efficiency
+'''
+
+import sys
+sys.path.insert(0, '..')
+
+from pylib.schema_helper import Model
+
+def demo_enhanced_features():
+    print("MLX Enhanced Features Demo")
+    print("=" * 50)
+
+    # Initialize model with prompt cache support
+    model = Model()
+
+    # Load model with max_kv_size for prompt caching
+    print("\n1. Loading model with prompt cache support...")
+    model.load(
+        "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+        max_kv_size=4096  # Enable prompt cache with 4K tokens
+    )
+    print("   ✓ Model loaded with prompt caching enabled")
+
+    # Example schema for structured output
+    schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string"},
+            "age": {"type": "integer"},
+            "skills": {
+                "type": "array",
+                "items": {"type": "string"}
+            }
+        },
+        "required": ["name", "age", "skills"]
+    }
+
+    # Test 1: Basic generation with prompt caching
+    print("\n2. Testing prompt caching...")
+    messages = [
+        {"role": "user", "content": "Generate a profile for a software developer"}
+    ]
+
+    print("   First generation (cold cache):")
+    result = model.completion(
+        messages=messages,
+        schema=schema,
+        cache_prompt=True,  # Enable prompt caching
+        max_tokens=100
+    )
+    for chunk in result:
+        print(f"  {chunk}", end="")
+    print()
+
+    # Test 2: Generation with 4-bit KV quantization
+    print("\n3. Testing 4-bit KV cache quantization...")
+    messages2 = [
+        {"role": "user", "content": "Generate a profile for a data scientist"}
+    ]
+
+    result = model.completion(
+        messages=messages2,
+        schema=schema,
+        kv_bits=4,  # Enable 4-bit quantization
+        kv_group_size=64,
+        quantized_kv_start=100,  # Start quantizing after 100 tokens
+        max_tokens=100
+    )
+
+    print("   Generation with 4-bit KV cache:")
+    for chunk in result:
+        print(f"  {chunk}", end="")
+    print()
+
+    print("\n" + "=" * 50)
+    print("Enhanced features successfully demonstrated!")
+    print("\nKey improvements with MLX 0.29.0 & MLX-LM 0.27.0:")
+    print("• Prompt caching reduces latency for repeated queries")
+    print("• 4-bit KV quantization shrinks the KV cache to roughly a quarter of its fp16 size")
+    print("• The quantized KV cache maintains quality while saving memory")
+
+if __name__ == "__main__":
+    demo_enhanced_features()
\ No newline at end of file
diff --git a/pylib/common.py b/pylib/common.py
index 545ef6c..32219b5 100644
--- a/pylib/common.py
+++ b/pylib/common.py
@@ -20,7 +20,7 @@
 # from mlx_lm.models import olmo  # Will say:: To run olmo install ai2-olmo: pip install ai2-olmo
 
-from ogbujipt import word_loom
+import wordloom
 
 
 class model_flag(int, Flag):
@@ -89,7 +89,7 @@ def obj_file_path_parent(obj):
 
 HERE = obj_file_path_parent(lambda: 0)
 with open(HERE / Path('resource/language.toml'), mode='rb') as fp:
-    LANG = word_loom.load(fp)
+    LANG = wordloom.load(fp)
 
 
 class model_runner_base:
diff --git a/pylib/schema_helper.py b/pylib/schema_helper.py
index d84d77c..54ccf2c 100644
--- a/pylib/schema_helper.py
+++ b/pylib/schema_helper.py
@@ -12,8 +12,8 @@
 from typing import Iterable
 
 import mlx.core as mx
-# from mlx_lm.models.cache import KVCache, _BaseCache
-# from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.cache import KVCache, _BaseCache
+from mlx_lm.models.cache import make_prompt_cache
 from mlx_lm.generate import load, stream_generate  # , GenerationResponse
 
 from toolio.vendor.llm_structured_output import JsonSchemaAcceptorDriver
@@ -27,11 +27,11 @@
     'sampler': None,
     'logits_processors': None,
     'max_kv_size': None,
-    # 'prompt_cache': None,
+    'prompt_cache': None,
     'prefill_step_size': 512,
-    'kv_bits': None,
+    'kv_bits': None,  # Bits for KV cache quantization (e.g. 4 or 8); None disables it
     'kv_group_size': 64,
-    'quantized_kv_start': 0,
+    'quantized_kv_start': 0,  # Start quantizing after N tokens
     'prompt_progress_callback': None
 }
@@ -80,9 +80,9 @@ def __init__(self):
         self.eos_id = None
         self.json_schema_acceptor_driver_factory = None
         # Note: If for example the user loads a cache from a file, and we support prompt caching that way, they should not have to re-specify init params such as max_kv_size
-        # self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
+        self._prompt_cache = None
 
-    def load(self, model_path: str):
+    def load(self, model_path: str, max_kv_size: int | None = None):
         '''
         Load locally or download from Huggingface hub.
         '''
@@ -94,6 +94,11 @@ def load(self, model_path: str):
                 self.vocabulary, self.eos_id
             )
         )
+        # Initialize prompt cache if requested
+        if max_kv_size:
+            self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
+        else:
+            self._prompt_cache = None
 
     def completion(
             self,
@@ -125,7 +130,9 @@ def completion(
         prompt_tokens = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         self._prompt_length = len(prompt_tokens)  # Store prompt length
 
-        logits_generator = stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs)
+        # Use prompt cache if available and requested
+        if cache_prompt and self._prompt_cache:
+            kwargs['prompt_cache'] = self._prompt_cache
 
         self._step_count = 0
         for generation_resp in stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs):
diff --git a/pyproject.toml b/pyproject.toml
index e3d344e..dc7a628 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,14 +24,15 @@ classifiers = [
 dependencies = [
     # MLX libraries are only installed on Apple Silicon Macs running macOS 13.5+ with Python 3.8+, as required
     # For more on environment markers see: https://hatch.pypa.io/dev/config/dependency/#environment-markers
-    "mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
-    "mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
+    "mlx>=0.29.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
+    "mlx_lm>=0.27.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
     # Rather than a marker such as platform_release >= '22.6.0' (corresponding to macOS 13.5 or later) we'll make this a runtime check
     # This is the former logic, but in some scenarios, e.g. running within Docker, we can get platform_release markers such as Linux kernel version "6.10.14-linuxkit", which are not valid version strings according to Python's packaging.version
     # "mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
     # "mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
-    "ogbujipt>=0.9.3",
+    "wordloom",
+    # "ogbujipt>=0.10.0",
     "fastapi>=0.115.3",
     "click",
     "httpx>=0.27.2",
 ]
@@ -40,18 +41,26 @@ dependencies = [
 
 [project.optional-dependencies]
+# Extra requirements for the built-in slate of tools
 tools = [
-    # Extra requirements for the built-in slate of tools
     "google-re2"
 ]
 
+# Extra requirements for project contributors & other developers
 dev = [
-    # Extra requirements for project contributors & other developers
     "build",
+    "twine",
+    "hatch",
+    "pipdeptree",
+    "ruff",
+    "pytest",
+    "pytest-mock",
    "pytest-asyncio",
+    "pytest-cov",
+    "respx",
     "pytest-httpserver",
-    # "pytest-mock"
 ]
+
 [project.urls]
 Documentation = "https://OoriData.github.io/Toolio/"
 Issues = "https://github.com/OoriData/Toolio/issues"
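
Aside: the note in pylib/schema_helper.py about loading a cache from a file points at persisting prompt caches across sessions. Below is a minimal sketch of that idea, not something this diff implements; it assumes mlx_lm's save_prompt_cache/load_prompt_cache helpers (companions to make_prompt_cache in recent mlx_lm releases), and the model ID and cache file name are purely illustrative.

# Sketch only: persist a warmed prompt cache so init params such as max_kv_size
# need not be re-specified on reload. Assumes mlx_lm's save_prompt_cache and
# load_prompt_cache helpers; model ID, prompt and file name are illustrative.
from mlx_lm.generate import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache

model, tokenizer = load('mlx-community/Qwen2.5-1.5B-Instruct-4bit')
cache = make_prompt_cache(model, max_kv_size=4096)

# Warm the cache by prefilling a shared prompt prefix
for _ in stream_generate(model, tokenizer, 'You are a helpful assistant.',
                         max_tokens=1, prompt_cache=cache):
    pass
save_prompt_cache('prompt_cache.safetensors', cache, metadata={'max_kv_size': '4096'})

# Later session: reload the cache instead of re-prefilling the shared prefix
cache, metadata = load_prompt_cache('prompt_cache.safetensors', return_metadata=True)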
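
The pyproject.toml comments above explain why the macOS 13.5 floor is enforced at runtime rather than via a platform_release environment marker (kernel-style release strings such as "6.10.14-linuxkit" are not valid version strings). A minimal sketch of such a runtime gate follows; it is illustrative only, not Toolio's actual check, and mlx_available() is a hypothetical helper name.

# Illustrative runtime check for the MLX requirements (Apple Silicon, macOS 13.5+),
# as referenced in the pyproject.toml comments. Not Toolio's actual implementation;
# mlx_available() is a hypothetical helper name.
import sys
import platform

def mlx_available() -> bool:
    '''True if this host can run the MLX backend (Apple Silicon Mac, macOS 13.5+)'''
    if sys.platform != 'darwin' or platform.machine() != 'arm64':
        return False
    release = platform.mac_ver()[0]  # e.g. '14.6.1'; empty string off macOS
    parts = release.split('.')
    try:
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
    except (IndexError, ValueError):
        return False
    return (major, minor) >= (13, 5)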