2 changes: 1 addition & 1 deletion README.md
@@ -25,7 +25,7 @@ In either case you get better results if you've trained or fine-tuned the model

<table><tr>
<td><a href="https://oori.dev/"><img src="https://www.oori.dev/assets/branding/oori_Logo_FullColor.png" width="64" /></a></td>
<td>Toolio is primarily developed by the crew at <a href="https://oori.dev/">Oori Data</a>. We offer data pipelines and software engineering services around AI/LLM applications.</td>
<td>Toolio is primarily developed by the crew at <a href="https://oori.dev/">Oori Data</a>. We offer LLMOps, data pipelines and software engineering services around AI/LLM applications.</td>
</tr></table>

We'd love your help, though! [Click to learn how to make contributions to the project](https://github.com/OoriData/Toolio/wiki/Notes-for-contributors).
87 changes: 87 additions & 0 deletions demo/mlx_enhanced_features.py
@@ -0,0 +1,87 @@
#!/usr/bin/env python
'''
Demo of MLX 0.29.0 and MLX-LM 0.27.0 enhanced features:
- Prompt caching for faster repeated queries
- 4-bit KV cache quantization for memory efficiency
'''

import sys
sys.path.insert(0, '..')

from pylib.schema_helper import Model

def demo_enhanced_features():
    print("MLX Enhanced Features Demo")
    print("=" * 50)

    # Initialize model with prompt cache support
    model = Model()

    # Load model with max_kv_size for prompt caching
    print("\n1. Loading model with prompt cache support...")
    model.load(
        "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
        max_kv_size=4096  # Enable prompt cache with 4K tokens
    )
    print(" ✓ Model loaded with prompt caching enabled")

    # Example schema for structured output
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": {
                "type": "array",
                "items": {"type": "string"}
            }
        },
        "required": ["name", "age", "skills"]
    }

    # Test 1: Basic generation with prompt caching
    print("\n2. Testing prompt caching...")
    messages = [
        {"role": "user", "content": "Generate a profile for a software developer"}
    ]

    print(" First generation (cold cache):")
    result = model.completion(
        messages=messages,
        schema=schema,
        cache_prompt=True,  # Enable prompt caching
        max_tokens=100
    )
    for chunk in result:
        print(f" {chunk}", end="")
    print()

    # Test 2: Generation with 4-bit KV quantization
    print("\n3. Testing 4-bit KV cache quantization...")
    messages2 = [
        {"role": "user", "content": "Generate a profile for a data scientist"}
    ]

    result = model.completion(
        messages=messages2,
        schema=schema,
        kv_bits=4,  # Enable 4-bit quantization
        kv_group_size=64,
        quantized_kv_start=100,  # Start quantizing after 100 tokens
        max_tokens=100
    )

    print(" Generation with 4-bit KV cache:")
    for chunk in result:
        print(f" {chunk}", end="")
    print()

    print("\n" + "=" * 50)
    print("Enhanced features successfully demonstrated!")
    print("\nKey improvements with MLX 0.29.0 & MLX-LM 0.27.0:")
    print("• Prompt caching reduces latency for repeated queries")
    print("• 4-bit KV quantization reduces memory by ~50%")
    print("• Quantized attention maintains quality while saving memory")

if __name__ == "__main__":
    demo_enhanced_features()
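The docstring promises faster repeated queries, but the demo stops after a single cold-cache call. Below is a minimal sketch, assuming the demo's `Model` API (`load(..., max_kv_size=...)` and `completion(..., cache_prompt=True)`), of how a warm-cache repeat could be timed; `time_generation` and the commented calls are illustrative, not part of this changeset.

```python
# Illustrative timing harness; assumes the demo's Model API above.
import time

def time_generation(model, messages, schema):
    start = time.perf_counter()
    for _chunk in model.completion(messages=messages, schema=schema,
                                   cache_prompt=True, max_tokens=100):
        pass  # consume the stream; only wall-clock time matters here
    return time.perf_counter() - start

# cold = time_generation(model, messages, schema)  # first call fills the prompt cache
# warm = time_generation(model, messages, schema)  # repeat should reuse the cached prompt KV
# print(f"cold: {cold:.2f}s  warm: {warm:.2f}s")
```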
4 changes: 2 additions & 2 deletions pylib/common.py
@@ -20,7 +20,7 @@

# from mlx_lm.models import olmo # Will say:: To run olmo install ai2-olmo: pip install ai2-olmo

from ogbujipt import word_loom
import wordloom


class model_flag(int, Flag):
@@ -89,7 +89,7 @@ def obj_file_path_parent(obj):

HERE = obj_file_path_parent(lambda: 0)
with open(HERE / Path('resource/language.toml'), mode='rb') as fp:
LANG = word_loom.load(fp)
LANG = wordloom.load(fp)


class model_runner_base:
23 changes: 15 additions & 8 deletions pylib/schema_helper.py
@@ -12,8 +12,8 @@
from typing import Iterable

import mlx.core as mx
# from mlx_lm.models.cache import KVCache, _BaseCache
# from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.models.cache import KVCache, _BaseCache
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.generate import load, stream_generate # , GenerationResponse

from toolio.vendor.llm_structured_output import JsonSchemaAcceptorDriver
@@ -27,11 +27,11 @@
'sampler': None,
'logits_processors': None,
'max_kv_size': None,
# 'prompt_cache': None,
'prompt_cache': None,
'prefill_step_size': 512,
'kv_bits': None,
'kv_bits': None, # Set to 4 or 8 bits to quantize the KV cache
'kv_group_size': 64,
'quantized_kv_start': 0,
'quantized_kv_start': 0, # Start quantizing after N tokens
'prompt_progress_callback': None
}

@@ -80,9 +80,9 @@ def __init__(self):
self.eos_id = None
self.json_schema_acceptor_driver_factory = None
# Note: If for example the user loads a cache from a file, and we support prompt caching that way, they should not have to re-specify init params such as max_kv_size
# self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
self._prompt_cache = None

def load(self, model_path: str):
def load(self, model_path: str, max_kv_size: int | None = None):
'''
Load locally or download from Huggingface hub.
'''
@@ -94,6 +94,11 @@ def load(self, model_path: str):
self.vocabulary, self.eos_id
)
)
# Initialize prompt cache if requested
if max_kv_size:
self._prompt_cache = make_prompt_cache(self.model, max_kv_size)
else:
self._prompt_cache = None

def completion(
self,
@@ -125,7 +130,9 @@ def completion(
prompt_tokens = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
self._prompt_length = len(prompt_tokens) # Store prompt length

logits_generator = stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs)
# Use prompt cache if available and requested
if cache_prompt and self._prompt_cache:
kwargs['prompt_cache'] = self._prompt_cache

self._step_count = 0
for generation_resp in stream_generate(self.model, self.tokenizer, prompt_tokens, **kwargs):
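The note in `__init__` above anticipates loading a prompt cache from a file. A rough sketch of what that could look like, assuming `mlx_lm.models.cache` also provides `save_prompt_cache`/`load_prompt_cache`; the cache path and helper function are hypothetical, not part of this changeset:

```python
# Hypothetical sketch only: file-backed prompt cache reuse across runs.
from pathlib import Path
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache

CACHE_FILE = 'prompt_cache.safetensors'  # hypothetical location

def load_or_make_prompt_cache(model, max_kv_size=None):
    if Path(CACHE_FILE).exists():
        # Reloading from disk; per the note above, max_kv_size would not need re-specifying
        return load_prompt_cache(CACHE_FILE)
    return make_prompt_cache(model, max_kv_size)

# After a generation pass, the populated cache could be persisted:
# save_prompt_cache(CACHE_FILE, prompt_cache)
```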
21 changes: 15 additions & 6 deletions pyproject.toml
@@ -24,14 +24,15 @@ classifiers = [
dependencies = [
# MLX libraries are only installed on Apple Silicon Macs running macOS 13.5+ with Python 3.8+, as required
# For more on environment markers see: https://hatch.pypa.io/dev/config/dependency/#environment-markers
"mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
"mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
"mlx>=0.29.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
"mlx_lm>=0.27.0; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8'",
# Rather than a marker such as platform_release >= '22.6.0' (corresponding to macOS 13.5 or later) we'll make this a runtime check

# This is the former logic, but in some scenarios, e.g. running within Docker, we can get platform_release markers such as Linux kernel version "6.10.14-linuxkit", which are not valid version strings according to Python's packaging.version
# "mlx>=0.23.1; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
# "mlx_lm>=0.21.4; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version >= '3.8' and platform_release >= '22.6.0'",
"ogbujipt>=0.9.3",
"wordloom",
# "ogbujipt>=0.10.0",
"fastapi>=0.115.3",
"click",
"httpx>=0.27.2",
@@ -40,18 +41,26 @@ dependencies = [
]

[project.optional-dependencies]
# Extra requirements for the built-in slate of tools
tools = [
# Extra requirements for the built-in slate of tools
"google-re2"
]
# Extra requirements for project contributors & other developers
dev = [
# Extra requirements for project contributors & other developers
"build",
"twine",
"hatch",
"pipdeptree",
"ruff",
"pytest",
"pytest-mock",
"pytest-asyncio",
"pytest-cov",
"respx",
"pytest-httpserver",
# "pytest-mock"
]


[project.urls]
Documentation = "https://OoriData.github.io/Toolio/"
Issues = "https://github.com/OoriData/Toolio/issues"
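The dependency comments above defer the macOS 13.5+ requirement to a runtime check rather than a `platform_release` environment marker. A sketch of that kind of check, illustrative only and not necessarily how Toolio implements it:

```python
# Illustrative runtime gate for MLX availability; not the project's actual code.
import platform
import sys

def mlx_supported(min_macos=(13, 5)):
    '''True if running on Apple Silicon macOS new enough for MLX'''
    if sys.platform != 'darwin' or platform.machine() != 'arm64':
        return False
    ver = platform.mac_ver()[0]  # e.g. '14.6.1'; empty string when not on macOS
    if not ver:
        return False
    parts = tuple(int(p) for p in ver.split('.')[:2])
    if len(parts) == 1:  # pad a bare major version such as '15'
        parts = (parts[0], 0)
    return parts >= min_macos
```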