diff --git a/README.md b/README.md
index 77ff0f3..17550f8 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ In either case you get better results if you've trained or fine-tuned the model
 |
- Toolio is primarily developed by the crew at Oori Data. We offer data pipelines and software engineering services around AI/LLM applications. |
+ Toolio is primarily developed by the crew at Oori Data. We offer LLMOps, data pipelines and software engineering services around AI/LLM applications. |
We'd love your help, though! [Click to learn how to make contributions to the project](https://github.com/OoriData/Toolio/wiki/Notes-for-contributors).
diff --git a/demo/mlx_enhanced_features.py b/demo/mlx_enhanced_features.py
new file mode 100644
index 0000000..00ca99d
--- /dev/null
+++ b/demo/mlx_enhanced_features.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+'''
+Demo of MLX 0.29.0 and MLX-LM 0.27.0 enhanced features:
+- Prompt caching for faster repeated queries
+- 4-bit KV cache quantization for memory efficiency
+'''
+
+# Assumes the toolio package is installed (e.g. `pip install -e .` from the repo root)
+from toolio.schema_helper import Model
+
+def demo_enhanced_features():
+ print("MLX Enhanced Features Demo")
+ print("=" * 50)
+
+ # Initialize model with prompt cache support
+ model = Model()
+
+ # Load model with max_kv_size for prompt caching
+ print("\n1. Loading model with prompt cache support...")
+ model.load(
+ "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+        max_kv_size=4096  # Cap the prompt/KV cache at 4096 tokens (mlx_lm uses a rotating cache when this is set)
+ )
+ print(" ✓ Model loaded with prompt caching enabled")
+
+ # Example schema for structured output
+ schema = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "age": {"type": "integer"},
+ "skills": {
+ "type": "array",
+ "items": {"type": "string"}
+ }
+ },
+ "required": ["name", "age", "skills"]
+ }
+
+ # Test 1: Basic generation with prompt caching
+ print("\n2. Testing prompt caching...")
+ messages = [
+ {"role": "user", "content": "Generate a profile for a software developer"}
+ ]
+
+    print("   Generation with prompt caching enabled:")
+ result = model.completion(
+ messages=messages,
+ schema=schema,
+ cache_prompt=True, # Enable prompt caching
+ max_tokens=100
+ )
+ for chunk in result:
+ print(f" {chunk}", end="")
+ print()
+
+ # Test 2: Generation with 4-bit KV quantization
+ print("\n3. Testing 4-bit KV cache quantization...")
+ messages2 = [
+ {"role": "user", "content": "Generate a profile for a data scientist"}
+ ]
+
+ result = model.completion(
+ messages=messages2,
+ schema=schema,
+ kv_bits=4, # Enable 4-bit quantization
+ kv_group_size=64,
+ quantized_kv_start=100, # Start quantizing after 100 tokens
+ max_tokens=100
+ )
+
+ print(" Generation with 4-bit KV cache:")
+ for chunk in result:
+ print(f" {chunk}", end="")
+ print()
+
+ print("\n" + "=" * 50)
+ print("Enhanced features successfully demonstrated!")
+ print("\nKey improvements with MLX 0.29.0 & MLX-LM 0.27.0:")
+ print("• Prompt caching reduces latency for repeated queries")
+    print("• 4-bit KV quantization shrinks the KV cache to roughly a quarter of its fp16 size")
+    print("• A quantized KV cache generally preserves output quality while saving memory")
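+    # Back-of-envelope for the memory claim above (assuming an fp16 baseline):
+    # each cached key/value element drops from 16 bits to 4 bits, plus a small
+    # per-group scale/bias overhead, i.e. roughly a 3-4x reduction in KV cache size.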
+
+if __name__ == "__main__":
+ demo_enhanced_features()
\ No newline at end of file
diff --git a/pylib/common.py b/pylib/common.py
index 545ef6c..32219b5 100644
--- a/pylib/common.py
+++ b/pylib/common.py
@@ -20,7 +20,7 @@
# from mlx_lm.models import olmo # Will say:: To run olmo install ai2-olmo: pip install ai2-olmo
-from ogbujipt import word_loom
+import wordloom
class model_flag(int, Flag):
@@ -89,7 +89,7 @@ def obj_file_path_parent(obj):
HERE = obj_file_path_parent(lambda: 0)
with open(HERE / Path('resource/language.toml'), mode='rb') as fp:
- LANG = word_loom.load(fp)
+ LANG = wordloom.load(fp)
class model_runner_base:
diff --git a/pylib/schema_helper.py b/pylib/schema_helper.py
index d84d77c..54ccf2c 100644
--- a/pylib/schema_helper.py
+++ b/pylib/schema_helper.py
@@ -12,8 +12,8 @@
from typing import Iterable
import mlx.core as mx
-# from mlx_lm.models.cache import KVCache, _BaseCache
-# from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.cache import KVCache, _BaseCache
+from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.generate import load, stream_generate # , GenerationResponse
from toolio.vendor.llm_structured_output import JsonSchemaAcceptorDriver
@@ -27,11 +27,11 @@
'sampler': None,
'logits_processors': None,
'max_kv_size': None,
- # 'prompt_cache': None,
+ 'prompt_cache': None,
'prefill_step_size': 512,
- 'kv_bits': None,
+    'kv_bits': None,  # Set to e.g. 4 or 8 to store the KV cache at that many bits per element
'kv_group_size': 64,
- 'quantized_kv_start': 0,
+    'quantized_kv_start': 0,  # Begin quantizing the cache once it holds more than this many tokens
'prompt_progress_callback': None
}
@@ -80,9 +80,9 @@ def __init__(self):
self.eos_id = None
self.json_schema_acceptor_driver_factory = None
# Note: If for example the user loads a cache from a file, and we support prompt caching that way, they should not have to re-specify init params such as max_kv_size
- # self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
+ self._prompt_cache = None
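+        # Possible extension (sketch only, not wired up here): a cache previously persisted
+        # with mlx_lm.models.cache.save_prompt_cache could be restored via load_prompt_cache
+        # and assigned to self._prompt_cache, so params such as max_kv_size need not be re-specified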
- def load(self, model_path: str):
+ def load(self, model_path: str, max_kv_size: int | None = None):
'''
Load locally or download from Huggingface hub.
'''
@@ -94,6 +94,11 @@ def load(self, model_path: str):
self.vocabulary, self.eos_id
)
)
+ # Initialize prompt cache if requested
+ if max_kv_size:
+ self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
+ else:
+ self._prompt_cache = None
def completion(
self,
@@ -125,7 +130,9 @@ def completion(
prompt_tokens = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
self._prompt_length = len(prompt_tokens) # Store prompt length
- logits_generator = stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs)
+ # Use prompt cache if available and requested
+ if cache_prompt and self._prompt_cache:
+ kwargs['prompt_cache'] = self._prompt_cache
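+            # Note: mlx_lm extends this cache in place as it generates; it does not
+            # prefix-match or trim automatically, so reuse only helps when a later
+            # prompt builds on the already-cached one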
self._step_count = 0
for generation_resp in stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs):
diff --git a/pyproject.toml b/pyproject.toml
index e3d344e..dc7a628 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,14 +24,15 @@ classifiers = [
dependencies = [
# MLX libraries are only installed on Apple Silicon Macs running macOS 13.5+ with Python 3.8+, as required
# For more on environment markers see: https://hatch.pypa.io/dev/config/dependency/#environment-markers
- "mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
- "mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
+ "mlx>=0.29.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
+ "mlx_lm>=0.27.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
# Rather than a marker such as platform_release >= '22.6.0' (corresponding to macOS 13.5 or later) we'll make this a runtime check
# This is the former logic, but in some scenarios, e.g. running within Docker, we can get platform_release markers such as Linux kernel version "6.10.14-linuxkit", which are not valid version strings according to Python's packaging.version
# "mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
# "mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
- "ogbujipt>=0.9.3",
+ "wordloom",
+ # "ogbujipt>=0.10.0",
"fastapi>=0.115.3",
"click",
"httpx>=0.27.2",
@@ -40,18 +41,26 @@ dependencies = [
]
[project.optional-dependencies]
+# Extra requirements for the built-in slate of tools
tools = [
- # Extra requirements for the built-in slate of tools
"google-re2"
]
+# Extra requirements for project contributors & other developers
dev = [
- # Extra requirements for project contributors & other developers
"build",
+ "twine",
+ "hatch",
+ "pipdeptree",
+ "ruff",
+ "pytest",
+ "pytest-mock",
"pytest-asyncio",
+ "pytest-cov",
+ "respx",
"pytest-httpserver",
- # "pytest-mock"
]
+
[project.urls]
Documentation = "https://OoriData.github.io/Toolio/"
Issues = "https://github.com/OoriData/Toolio/issues"