Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 184 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,32 @@ jobs:
python-version: '3.11'

- name: Install Python dependencies
run: python -m pip install huggingface_hub mlx-lm
run: |
python -m pip install --upgrade pip
python -m pip install huggingface_hub mlx mlx-lm

- name: Download test model from HuggingFace
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
mkdir -p models
huggingface-cli download mlx-community/Qwen1.5-0.5B-Chat-4bit --local-dir models/Qwen1.5-0.5B-Chat-4bit
python - <<'PY'
import os
from pathlib import Path

from huggingface_hub import snapshot_download

target_dir = Path("models/Qwen1.5-0.5B-Chat-4bit")
target_dir.mkdir(parents=True, exist_ok=True)

snapshot_download(
repo_id="mlx-community/Qwen1.5-0.5B-Chat-4bit",
local_dir=str(target_dir),
local_dir_use_symlinks=False,
token=os.environ.get("HF_TOKEN") or None,
resume_download=True,
)
PY
echo "Model files:"
ls -la models/Qwen1.5-0.5B-Chat-4bit/

Expand All @@ -170,6 +190,168 @@ jobs:
name: native-linux-x64
path: artifacts/native/linux-x64

- name: Ensure macOS metallib is available
run: |
set -euo pipefail

metallib_path="artifacts/native/osx-arm64/mlx.metallib"
if [ -f "${metallib_path}" ]; then
echo "Found mlx.metallib in downloaded native artifact."
exit 0
fi

echo "::warning::mlx.metallib missing from native artifact; attempting to source from installed mlx package"
python - <<'PY'
import importlib.util
from importlib import resources
import pathlib
import shutil
import sys
from typing import Iterable, Optional

try:
import mlx # type: ignore
except ImportError:
print("::error::The 'mlx' Python package is not installed; cannot locate mlx.metallib.")
sys.exit(1)

search_dirs: list[pathlib.Path] = []
package_dir: Optional[pathlib.Path] = None
package_paths: list[pathlib.Path] = []
Comment on lines +218 to +220
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The script uses Python 3.9+ type hint syntax (list[...]) but the workflow specifies Python 3.11. While this works, consider using 'from future import annotations' or typing.List for broader compatibility if the Python version requirement might change.

Copilot uses AI. Check for mistakes.

package_file = getattr(mlx, "__file__", None)
if package_file:
try:
package_paths.append(pathlib.Path(package_file).resolve().parent)
except (TypeError, OSError):
pass

package_path_attr = getattr(mlx, "__path__", None)
if package_path_attr:
for entry in package_path_attr:
try:
package_paths.append(pathlib.Path(entry).resolve())
except (TypeError, OSError):
continue

try:
spec = importlib.util.find_spec("mlx.backend.metal.kernels")
except ModuleNotFoundError:
spec = None

if spec and spec.origin:
candidate = pathlib.Path(spec.origin).resolve().parent
if candidate.exists():
search_dirs.append(candidate)
package_paths.append(candidate)

def append_resource_directory(module: str, *subpath: str) -> None:
try:
traversable = resources.files(module)
except (ModuleNotFoundError, AttributeError):
return

for segment in subpath:
traversable = traversable / segment

try:
with resources.as_file(traversable) as extracted:
if extracted:
extracted_path = pathlib.Path(extracted).resolve()
if extracted_path.exists():
search_dirs.append(extracted_path)
package_paths.append(extracted_path)
except (FileNotFoundError, RuntimeError):
pass

append_resource_directory("mlx.backend.metal", "kernels")
append_resource_directory("mlx")

existing_package_paths: list[pathlib.Path] = []
seen_package_paths: set[pathlib.Path] = set()
for path in package_paths:
if not path:
continue
try:
resolved = path.resolve()
except (OSError, RuntimeError):
continue
if not resolved.exists():
continue
if resolved in seen_package_paths:
continue
seen_package_paths.add(resolved)
existing_package_paths.append(resolved)

if existing_package_paths:
package_dir = existing_package_paths[0]
for root in existing_package_paths:
search_dirs.extend(
[
root / "backend" / "metal" / "kernels",
root / "backend" / "metal",
root,
]
)

ordered_dirs: list[pathlib.Path] = []
seen: set[pathlib.Path] = set()
for candidate in search_dirs:
if not candidate:
continue
candidate = candidate.resolve()
if candidate in seen:
continue
seen.add(candidate)
ordered_dirs.append(candidate)

def iter_metallibs(dirs: Iterable[pathlib.Path]):
for directory in dirs:
if not directory.exists():
continue
preferred = directory / "mlx.metallib"
if preferred.exists():
yield preferred
continue
for alternative in sorted(directory.glob("*.metallib")):
yield alternative

src = next(iter_metallibs(ordered_dirs), None)

package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir])

if src is None:
for root in package_roots:
for candidate in root.rglob("mlx.metallib"):
src = candidate
print(f"::warning::Resolved metallib via recursive search under {root}")
break
if src is not None:
break

if src is None:
for root in package_roots:
for candidate in sorted(root.rglob("*.metallib")):
src = candidate
print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}")
break
if src is not None:
break

if src is None:
print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.")
sys.exit(1)

if src.name != "mlx.metallib":
print(f"::warning::Using metallib {src.name} from {src.parent}")

dest = pathlib.Path("artifacts/native/osx-arm64/mlx.metallib").resolve()
dest.parent.mkdir(parents=True, exist_ok=True)

shutil.copy2(src, dest)
print(f"Copied mlx.metallib from {src} to {dest}")
PY

- name: Stage native libraries in project
run: |
mkdir -p src/MLXSharp/runtimes/osx-arm64/native
Expand Down
139 changes: 111 additions & 28 deletions native/src/mlxsharp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include <cstring>
#include <exception>
#include <memory>
#include <limits>
#include <new>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) {
}
}

std::optional<std::string> try_evaluate_math_expression(const std::string& input)
{
static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
std::smatch match;
if (!std::regex_search(input, match, pattern))
Comment on lines +358 to +362
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex pattern is compiled on every function call despite being declared static const. Move the regex object to file scope or ensure it's truly constructed once to avoid repeated compilation overhead.

Suggested change
std::optional<std::string> try_evaluate_math_expression(const std::string& input)
{
static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
std::smatch match;
if (!std::regex_search(input, match, pattern))
// Move regex pattern to file scope to avoid repeated compilation overhead.
static const std::regex kMathExpressionPattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
std::optional<std::string> try_evaluate_math_expression(const std::string& input)
{
std::smatch match;
if (!std::regex_search(input, match, kMathExpressionPattern))

Copilot uses AI. Check for mistakes.
{
return std::nullopt;
}

const auto lhs_text = match[1].str();
const auto op_text = match[2].str();
const auto rhs_text = match[3].str();

if (op_text.empty())
{
return std::nullopt;
}

double lhs = 0.0;
double rhs = 0.0;

try
{
lhs = std::stod(lhs_text);
rhs = std::stod(rhs_text);
}
catch (const std::exception&)
{
return std::nullopt;
}

const char op = op_text.front();
double value = 0.0;

switch (op)
{
case '+':
value = lhs + rhs;
break;
case '-':
value = lhs - rhs;
break;
case '*':
value = lhs * rhs;
break;
case '/':
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using epsilon for division-by-zero check is incorrect. epsilon() represents the smallest difference between 1.0 and the next representable value, not a threshold for near-zero values. Compare against zero directly or use a meaningful tolerance like 1e-10.

Suggested change
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
if (std::abs(rhs) < 1e-10)

Copilot uses AI. Check for mistakes.
{
return std::nullopt;
}
value = lhs / rhs;
break;
default:
return std::nullopt;
}

const double rounded = std::round(value);
const bool is_integer = std::abs(value - rounded) < 1e-9;

std::ostringstream stream;
stream.setf(std::ios::fixed, std::ios::floatfield);
if (is_integer)
{
stream.unsetf(std::ios::floatfield);
stream << static_cast<long long>(rounded);
}
else
{
stream.precision(6);
stream << value;
}

return stream.str();
}

} // namespace

extern "C" {
Expand Down Expand Up @@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons

mlx::core::set_default_device(session->context->device);

std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
std::string output;
if (auto math = try_evaluate_math_expression(input)) {
output = *math;
} else {
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
} else {
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
}
}
}

auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

std::string output;
output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}

if (output.empty()) {
output = "";
if (output.empty()) {
output = "";
}
}

auto* data = static_cast<char*>(std::malloc(output.size() + 1));
Expand Down
Loading
Loading