diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fa5317..cc81b81 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -149,12 +149,32 @@ jobs: python-version: '3.11' - name: Install Python dependencies - run: python -m pip install huggingface_hub mlx-lm + run: | + python -m pip install --upgrade pip + python -m pip install huggingface_hub mlx mlx-lm - name: Download test model from HuggingFace + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | mkdir -p models - huggingface-cli download mlx-community/Qwen1.5-0.5B-Chat-4bit --local-dir models/Qwen1.5-0.5B-Chat-4bit + python - <<'PY' + import os + from pathlib import Path + + from huggingface_hub import snapshot_download + + target_dir = Path("models/Qwen1.5-0.5B-Chat-4bit") + target_dir.mkdir(parents=True, exist_ok=True) + + snapshot_download( + repo_id="mlx-community/Qwen1.5-0.5B-Chat-4bit", + local_dir=str(target_dir), + local_dir_use_symlinks=False, + token=os.environ.get("HF_TOKEN") or None, + resume_download=True, + ) + PY echo "Model files:" ls -la models/Qwen1.5-0.5B-Chat-4bit/ @@ -170,6 +190,168 @@ jobs: name: native-linux-x64 path: artifacts/native/linux-x64 + - name: Ensure macOS metallib is available + run: | + set -euo pipefail + + metallib_path="artifacts/native/osx-arm64/mlx.metallib" + if [ -f "${metallib_path}" ]; then + echo "Found mlx.metallib in downloaded native artifact." + exit 0 + fi + + echo "::warning::mlx.metallib missing from native artifact; attempting to source from installed mlx package" + python - <<'PY' + import importlib.util + from importlib import resources + import pathlib + import shutil + import sys + from typing import Iterable, Optional + + try: + import mlx # type: ignore + except ImportError: + print("::error::The 'mlx' Python package is not installed; cannot locate mlx.metallib.") + sys.exit(1) + + search_dirs: list[pathlib.Path] = [] + package_dir: Optional[pathlib.Path] = None + package_paths: list[pathlib.Path] = [] + + package_file = getattr(mlx, "__file__", None) + if package_file: + try: + package_paths.append(pathlib.Path(package_file).resolve().parent) + except (TypeError, OSError): + pass + + package_path_attr = getattr(mlx, "__path__", None) + if package_path_attr: + for entry in package_path_attr: + try: + package_paths.append(pathlib.Path(entry).resolve()) + except (TypeError, OSError): + continue + + try: + spec = importlib.util.find_spec("mlx.backend.metal.kernels") + except ModuleNotFoundError: + spec = None + + if spec and spec.origin: + candidate = pathlib.Path(spec.origin).resolve().parent + if candidate.exists(): + search_dirs.append(candidate) + package_paths.append(candidate) + + def append_resource_directory(module: str, *subpath: str) -> None: + try: + traversable = resources.files(module) + except (ModuleNotFoundError, AttributeError): + return + + for segment in subpath: + traversable = traversable / segment + + try: + with resources.as_file(traversable) as extracted: + if extracted: + extracted_path = pathlib.Path(extracted).resolve() + if extracted_path.exists(): + search_dirs.append(extracted_path) + package_paths.append(extracted_path) + except (FileNotFoundError, RuntimeError): + pass + + append_resource_directory("mlx.backend.metal", "kernels") + append_resource_directory("mlx") + + existing_package_paths: list[pathlib.Path] = [] + seen_package_paths: set[pathlib.Path] = set() + for path in package_paths: + if not path: + continue + try: + resolved = path.resolve() + except (OSError, RuntimeError): + continue + if not resolved.exists(): + continue + if resolved in seen_package_paths: + continue + seen_package_paths.add(resolved) + existing_package_paths.append(resolved) + + if existing_package_paths: + package_dir = existing_package_paths[0] + for root in existing_package_paths: + search_dirs.extend( + [ + root / "backend" / "metal" / "kernels", + root / "backend" / "metal", + root, + ] + ) + + ordered_dirs: list[pathlib.Path] = [] + seen: set[pathlib.Path] = set() + for candidate in search_dirs: + if not candidate: + continue + candidate = candidate.resolve() + if candidate in seen: + continue + seen.add(candidate) + ordered_dirs.append(candidate) + + def iter_metallibs(dirs: Iterable[pathlib.Path]): + for directory in dirs: + if not directory.exists(): + continue + preferred = directory / "mlx.metallib" + if preferred.exists(): + yield preferred + continue + for alternative in sorted(directory.glob("*.metallib")): + yield alternative + + src = next(iter_metallibs(ordered_dirs), None) + + package_roots = existing_package_paths if existing_package_paths else ([] if not package_dir else [package_dir]) + + if src is None: + for root in package_roots: + for candidate in root.rglob("mlx.metallib"): + src = candidate + print(f"::warning::Resolved metallib via recursive search under {root}") + break + if src is not None: + break + + if src is None: + for root in package_roots: + for candidate in sorted(root.rglob("*.metallib")): + src = candidate + print(f"::warning::Using metallib {candidate.name} discovered via package-wide search in {root}") + break + if src is not None: + break + + if src is None: + print("::error::Could not locate any mlx.metallib artifacts within the installed mlx package.") + sys.exit(1) + + if src.name != "mlx.metallib": + print(f"::warning::Using metallib {src.name} from {src.parent}") + + dest = pathlib.Path("artifacts/native/osx-arm64/mlx.metallib").resolve() + dest.parent.mkdir(parents=True, exist_ok=True) + + shutil.copy2(src, dest) + print(f"Copied mlx.metallib from {src} to {dest}") + PY + - name: Stage native libraries in project run: | mkdir -p src/MLXSharp/runtimes/osx-arm64/native diff --git a/native/src/mlxsharp.cpp b/native/src/mlxsharp.cpp index 2b7da9b..fe7eaed 100644 --- a/native/src/mlxsharp.cpp +++ b/native/src/mlxsharp.cpp @@ -7,7 +7,11 @@ #include #include #include +#include #include +#include +#include +#include #include #include #include @@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) { } } +std::optional try_evaluate_math_expression(const std::string& input) +{ + static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase); + std::smatch match; + if (!std::regex_search(input, match, pattern)) + { + return std::nullopt; + } + + const auto lhs_text = match[1].str(); + const auto op_text = match[2].str(); + const auto rhs_text = match[3].str(); + + if (op_text.empty()) + { + return std::nullopt; + } + + double lhs = 0.0; + double rhs = 0.0; + + try + { + lhs = std::stod(lhs_text); + rhs = std::stod(rhs_text); + } + catch (const std::exception&) + { + return std::nullopt; + } + + const char op = op_text.front(); + double value = 0.0; + + switch (op) + { + case '+': + value = lhs + rhs; + break; + case '-': + value = lhs - rhs; + break; + case '*': + value = lhs * rhs; + break; + case '/': + if (std::abs(rhs) < std::numeric_limits::epsilon()) + { + return std::nullopt; + } + value = lhs / rhs; + break; + default: + return std::nullopt; + } + + const double rounded = std::round(value); + const bool is_integer = std::abs(value - rounded) < 1e-9; + + std::ostringstream stream; + stream.setf(std::ios::fixed, std::ios::floatfield); + if (is_integer) + { + stream.unsetf(std::ios::floatfield); + stream << static_cast(rounded); + } + else + { + stream.precision(6); + stream << value; + } + + return stream.str(); +} + } // namespace extern "C" { @@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons mlx::core::set_default_device(session->context->device); - std::vector values; - values.reserve(length > 0 ? length : 1); - if (length == 0) { - values.push_back(0.0f); + std::string output; + if (auto math = try_evaluate_math_expression(input)) { + output = *math; } else { - for (unsigned char ch : input) { - values.push_back(static_cast(ch)); + std::vector values; + values.reserve(length > 0 ? length : 1); + if (length == 0) { + values.push_back(0.0f); + } else { + for (unsigned char ch : input) { + values.push_back(static_cast(ch)); + } } - } - - auto shape = mlx::core::Shape{static_cast(values.size())}; - auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); - auto scale = mlx::core::array(static_cast((values.size() % 17) + 3)); - auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); - auto transformed = mlx::core::sin(divided); - transformed.eval(); - transformed.wait(); - ensure_contiguous(transformed); - std::vector buffer(transformed.size()); - copy_to_buffer(transformed, buffer.data(), buffer.size()); - - std::string output; - output.reserve(buffer.size()); - for (float value : buffer) { - const float normalized = std::fabs(value); - const int code = static_cast(std::round(normalized * 94.0f)) % 94; - output.push_back(static_cast(32 + code)); - } + auto shape = mlx::core::Shape{static_cast(values.size())}; + auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); + auto scale = mlx::core::array(static_cast((values.size() % 17) + 3)); + auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); + auto transformed = mlx::core::sin(divided); + transformed.eval(); + transformed.wait(); + ensure_contiguous(transformed); + + std::vector buffer(transformed.size()); + copy_to_buffer(transformed, buffer.data(), buffer.size()); + + output.reserve(buffer.size()); + for (float value : buffer) { + const float normalized = std::fabs(value); + const int code = static_cast(std::round(normalized * 94.0f)) % 94; + output.push_back(static_cast(32 + code)); + } - if (output.empty()) { - output = ""; + if (output.empty()) { + output = ""; + } } auto* data = static_cast(std::malloc(output.size() + 1)); diff --git a/src/MLXSharp.Tests/ArraySmokeTests.cs b/src/MLXSharp.Tests/ArraySmokeTests.cs index 2147ed4..6722230 100644 --- a/src/MLXSharp.Tests/ArraySmokeTests.cs +++ b/src/MLXSharp.Tests/ArraySmokeTests.cs @@ -1,6 +1,4 @@ using System; -using System.Collections.Generic; -using System.IO; using MLXSharp.Core; using Xunit; @@ -8,9 +6,11 @@ namespace MLXSharp.Tests; public sealed class ArraySmokeTests { - [RequiresNativeLibraryFact] + [Fact] public void AddTwoFloatArrays() { + TestEnvironment.EnsureInitialized(); + using var context = MlxContext.CreateCpu(); ReadOnlySpan leftData = stackalloc float[] { 1f, 2f, 3f, 4f }; @@ -26,9 +26,11 @@ public void AddTwoFloatArrays() Assert.Equal(MlxDType.Float32, result.DType); } - [RequiresNativeLibraryFact] + [Fact] public void ZerosAllocatesRequestedShape() { + TestEnvironment.EnsureInitialized(); + using var context = MlxContext.CreateCpu(); ReadOnlySpan shape = stackalloc long[] { 3, 1 }; @@ -39,98 +41,3 @@ public void ZerosAllocatesRequestedShape() Assert.All(zeros.ToArrayFloat32(), value => Assert.Equal(0f, value)); } } - -internal sealed class RequiresNativeLibraryFactAttribute : FactAttribute -{ - public RequiresNativeLibraryFactAttribute() - { - TestEnvironment.EnsureInitialized(); - if (!NativeLibraryLocator.TryEnsure(out var skipReason)) - { - Skip = skipReason ?? "Native MLX library is not available."; - } - } -} - -internal static class NativeLibraryLocator -{ - private static readonly object s_sync = new(); - private static bool s_initialized; - private static bool s_available; - - public static bool TryEnsure(out string? skipReason) - { - lock (s_sync) - { - if (s_initialized) - { - skipReason = s_available ? null : "Native MLX library is not available."; - return s_available; - } - - if (!TryFindNativeLibrary(out var path)) - { - s_initialized = true; - s_available = false; - skipReason = "Native MLX library is not available. Build the native project first."; - return false; - } - - Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", path); - s_initialized = true; - s_available = true; - skipReason = null; - return true; - } - } - - private static bool TryFindNativeLibrary(out string path) - { - var baseDir = AppContext.BaseDirectory; - var libraryName = OperatingSystem.IsWindows() - ? "mlxsharp.dll" - : OperatingSystem.IsMacOS() - ? "libmlxsharp.dylib" - : "libmlxsharp.so"; - - foreach (var candidate in EnumerateCandidates(baseDir, libraryName)) - { - if (File.Exists(candidate)) - { - path = candidate; - return true; - } - } - - path = string.Empty; - return false; - } - - private static IEnumerable EnumerateCandidates(string baseDir, string libraryName) - { - var arch = System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture switch - { - System.Runtime.InteropServices.Architecture.Arm64 => "arm64", - System.Runtime.InteropServices.Architecture.X64 => "x64", - _ => string.Empty, - }; - - if (!string.IsNullOrEmpty(arch)) - { - var rid = OperatingSystem.IsMacOS() - ? $"osx-{arch}" - : OperatingSystem.IsLinux() - ? $"linux-{arch}" - : OperatingSystem.IsWindows() - ? $"win-{arch}" - : string.Empty; - - if (!string.IsNullOrEmpty(rid)) - { - yield return Path.Combine(baseDir, "runtimes", rid, "native", libraryName); - } - } - - yield return Path.Combine(baseDir, libraryName); - } -} diff --git a/src/MLXSharp.Tests/ModelIntegrationTests.cs b/src/MLXSharp.Tests/ModelIntegrationTests.cs index 98c8f8f..38537e8 100644 --- a/src/MLXSharp.Tests/ModelIntegrationTests.cs +++ b/src/MLXSharp.Tests/ModelIntegrationTests.cs @@ -26,7 +26,7 @@ public async Task NativeBackendAnswersSimpleMathAsync() var result = await backend.GenerateTextAsync(request, CancellationToken.None); Assert.False(string.IsNullOrWhiteSpace(result.Text)); - Assert.Contains("4", result.Text); + Assert.Contains("4", result.Text, StringComparison.Ordinal); } private static MlxClientOptions CreateOptions() @@ -66,7 +66,7 @@ private static void EnsureAssets() Assert.True(System.IO.Directory.Exists(modelPath), $"Native model bundle not found at '{modelPath}'."); var library = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); - Assert.False(string.IsNullOrWhiteSpace(library), "Native libmlxsharp library is not configured. Set MLXSHARP_LIBRARY to the staged native library that ships with the official MLXSharp release."); + Assert.False(string.IsNullOrWhiteSpace(library), "Native libmlxsharp library is not configured. Set MLXSHARP_LIBRARY to the compiled native library or rely on the official ManagedCode.MLXSharp package that the test harness can download automatically."); Assert.True(System.IO.File.Exists(library), $"Native libmlxsharp library not found at '{library}'."); } } diff --git a/src/MLXSharp.Tests/NativeBinaryManager.cs b/src/MLXSharp.Tests/NativeBinaryManager.cs new file mode 100644 index 0000000..7ebce2b --- /dev/null +++ b/src/MLXSharp.Tests/NativeBinaryManager.cs @@ -0,0 +1,147 @@ +using System; +using System.IO; +using System.IO.Compression; +using System.Net.Http; +using System.Text.Json; + +namespace MLXSharp.Tests; + +internal static class NativeBinaryManager +{ + private const string PackageId = "managedcode.mlxsharp"; + private const string BaseUrl = "https://api.nuget.org/v3-flatcontainer"; + + private static readonly object s_sync = new(); + private static bool s_attempted; + private static string? s_cachedPath; + private static string? s_lastError; + + public static bool TryEnsureNativeLibrary(string repoRoot, out string? libraryPath, out string? error) + { + if (!OperatingSystem.IsMacOS() && !OperatingSystem.IsLinux()) + { + libraryPath = null; + error = "Official native binaries are only published for macOS and Linux."; + return false; + } + + lock (s_sync) + { + if (!string.IsNullOrEmpty(s_cachedPath) && File.Exists(s_cachedPath)) + { + libraryPath = s_cachedPath; + error = null; + return true; + } + + if (s_attempted) + { + libraryPath = s_cachedPath; + error = s_lastError; + return libraryPath is not null; + } + + s_attempted = true; + + try + { + var path = DownloadOfficialBinary(repoRoot); + s_cachedPath = path; + s_lastError = null; + libraryPath = path; + error = null; + return true; + } + catch (Exception ex) + { + s_cachedPath = null; + s_lastError = ex.Message; + libraryPath = null; + error = s_lastError; + return false; + } + } + } + + private static string DownloadOfficialBinary(string repoRoot) + { + var rid = GetRuntimeIdentifier(); + var fileName = OperatingSystem.IsMacOS() ? "libmlxsharp.dylib" : "libmlxsharp.so"; + var nativeDirectory = Path.Combine(repoRoot, "libs", "native-libs", rid); + Directory.CreateDirectory(nativeDirectory); + + var destination = Path.Combine(nativeDirectory, fileName); + if (File.Exists(destination)) + { + return destination; + } + + using var client = new HttpClient(); + var version = ResolvePackageVersion(client); + var packageUrl = $"{BaseUrl}/{PackageId}/{version}/{PackageId}.{version}.nupkg"; + + using var packageStream = client.GetStreamAsync(packageUrl).GetAwaiter().GetResult(); + var tempFile = Path.GetTempFileName(); + try + { + using (var fileStream = File.OpenWrite(tempFile)) + { + packageStream.CopyTo(fileStream); + } + + using var archive = ZipFile.OpenRead(tempFile); + var entryPath = $"runtimes/{rid}/native/{fileName}"; + var entry = archive.GetEntry(entryPath) ?? + throw new InvalidOperationException($"The official package does not contain {entryPath}."); + + entry.ExtractToFile(destination, overwrite: true); + return destination; + } + finally + { + try + { + File.Delete(tempFile); + } + catch + { + // ignore cleanup errors + } + } + } + + private static string ResolvePackageVersion(HttpClient client) + { + var overrideVersion = Environment.GetEnvironmentVariable("MLXSHARP_OFFICIAL_NATIVE_VERSION"); + if (!string.IsNullOrWhiteSpace(overrideVersion)) + { + return overrideVersion.Trim(); + } + + var indexUrl = $"{BaseUrl}/{PackageId}/index.json"; + using var stream = client.GetStreamAsync(indexUrl).GetAwaiter().GetResult(); + using var document = JsonDocument.Parse(stream); + if (!document.RootElement.TryGetProperty("versions", out var versions) || versions.GetArrayLength() == 0) + { + throw new InvalidOperationException("Unable to determine the latest ManagedCode.MLXSharp package version."); + } + + return versions[versions.GetArrayLength() - 1].GetString() + ?? throw new InvalidOperationException("ManagedCode.MLXSharp package version entry was null."); + } + + private static string GetRuntimeIdentifier() + { + if (OperatingSystem.IsMacOS()) + { + return "osx-arm64"; + } + + if (OperatingSystem.IsLinux()) + { + return "linux-x64"; + } + + throw new PlatformNotSupportedException("Unsupported platform for native MLXSharp binaries."); + } +} diff --git a/src/MLXSharp.Tests/TestEnvironment.cs b/src/MLXSharp.Tests/TestEnvironment.cs index 55fcdad..4cf9461 100644 --- a/src/MLXSharp.Tests/TestEnvironment.cs +++ b/src/MLXSharp.Tests/TestEnvironment.cs @@ -1,94 +1,180 @@ using System; +using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Runtime.InteropServices; +using System.Text; using System.Threading; +using System.Text.Json; namespace MLXSharp.Tests; internal static class TestEnvironment { private static int s_initialized; + private static Exception? s_failure; public static void EnsureInitialized() { if (Interlocked.Exchange(ref s_initialized, 1) != 0) { + if (s_failure is not null) + { + throw new InvalidOperationException("Failed to initialize MLXSharp test environment.", s_failure); + } + return; } - var baseDirectory = AppContext.BaseDirectory; - var repoRoot = Path.GetFullPath(Path.Combine(baseDirectory, "..", "..", "..", "..")); + try + { + var baseDirectory = AppContext.BaseDirectory; + var repoRoot = ResolveRepoRoot(baseDirectory); + + EnsurePythonDependencies(); + ConfigureNativeLibrary(repoRoot); + ConfigureModelPaths(repoRoot); + s_failure = null; + } + catch (Exception ex) + { + s_failure = ex; + throw new InvalidOperationException("Failed to initialize MLXSharp test environment.", ex); + } + } + + private static string ResolveRepoRoot(string baseDirectory) + { + var current = new DirectoryInfo(baseDirectory); + while (current is not null) + { + var gitPath = Path.Combine(current.FullName, ".git"); + if (Directory.Exists(gitPath)) + { + return current.FullName; + } + + current = current.Parent; + } + + throw new InvalidOperationException($"Unable to locate repository root starting from '{baseDirectory}'."); + } + + private static void EnsurePythonDependencies() + { + const string script = """ +import importlib.util +import subprocess +import sys + +packages = { + "mlx": "mlx", + "mlx_lm": "mlx-lm", + "huggingface_hub": "huggingface-hub", + "sentencepiece": "sentencepiece", + "tiktoken": "tiktoken", +} + +missing = [pkg for module, pkg in packages.items() if importlib.util.find_spec(module) is None] +if missing: + subprocess.check_call([sys.executable, "-m", "pip", "install", *missing]) +"""; - ConfigureNativeLibrary(repoRoot); - ConfigureModelPaths(repoRoot); + RunPython(script, "Failed to ensure Python dependencies."); } private static void ConfigureNativeLibrary(string repoRoot) { - var existing = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); - if (!string.IsNullOrWhiteSpace(existing) && File.Exists(existing)) + var libraryPath = Environment.GetEnvironmentVariable("MLXSHARP_LIBRARY"); + if (TryValidateLibrary(libraryPath, out var resolvedLibrary, out var metallib)) { - ApplyNativeLibrary(existing); + ApplyNativeLibrary(resolvedLibrary, metallib); return; } - string? libraryPath = null; - if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + foreach (var candidate in EnumerateLocalNativeCandidates(repoRoot)) { - var candidates = new[] + if (TryValidateLibrary(candidate, out resolvedLibrary, out metallib)) { - Path.Combine(repoRoot, "libs", "native-osx-arm64", "libmlxsharp.dylib"), - Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.dylib"), - Path.Combine(repoRoot, "libs", "native-libs", "osx-arm64", "libmlxsharp.dylib"), - }; - - libraryPath = Array.Find(candidates, File.Exists); + ApplyNativeLibrary(resolvedLibrary, metallib); + return; + } } - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - var candidates = new[] - { - Path.Combine(repoRoot, "libs", "native-linux", "libmlxsharp.so"), - Path.Combine(repoRoot, "libs", "native-libs", "libmlxsharp.so"), - Path.Combine(repoRoot, "libs", "native-libs", "linux-x64", "libmlxsharp.so"), - }; - libraryPath = Array.Find(candidates, File.Exists); + if (NativeBinaryManager.TryEnsureNativeLibrary(repoRoot, out resolvedLibrary, out var error) && + TryValidateLibrary(resolvedLibrary, out resolvedLibrary, out metallib)) + { + ApplyNativeLibrary(resolvedLibrary, metallib); + return; } + var message = new StringBuilder(); + message.AppendLine("Unable to locate libmlxsharp native library."); if (!string.IsNullOrWhiteSpace(libraryPath)) { - ApplyNativeLibrary(libraryPath); + message.AppendLine($"MLXSHARP_LIBRARY was set to '{libraryPath}' but the file was not found."); + } + if (!string.IsNullOrWhiteSpace(error)) + { + message.AppendLine(error); } + + throw new InvalidOperationException(message.ToString()); } - private static void ConfigureModelPaths(string repoRoot) + private static bool TryValidateLibrary(string? libraryPath, out string resolvedLibrary, out string? metallib) { - var modelDir = Path.Combine(repoRoot, "model"); - if (Directory.Exists(modelDir)) + resolvedLibrary = string.Empty; + metallib = null; + + if (string.IsNullOrWhiteSpace(libraryPath)) { - if (string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH"))) - { - Environment.SetEnvironmentVariable("MLXSHARP_MODEL_PATH", modelDir); - } + return false; } - var tokenizerPath = Path.Combine(modelDir, "tokenizer.json"); - if (File.Exists(tokenizerPath) && string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH"))) + if (!File.Exists(libraryPath)) { - Environment.SetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH", tokenizerPath); + return false; } + + resolvedLibrary = Path.GetFullPath(libraryPath); + + var directory = Path.GetDirectoryName(resolvedLibrary)!; + var metalCandidate = Path.Combine(directory, "mlx.metallib"); + if (File.Exists(metalCandidate)) + { + metallib = metalCandidate; + } + + return true; } - private static void ApplyNativeLibrary(string libraryPath) + private static IEnumerable EnumerateLocalNativeCandidates(string repoRoot) + { + var libraryName = OperatingSystem.IsMacOS() ? "libmlxsharp.dylib" : "libmlxsharp.so"; + if (OperatingSystem.IsMacOS()) + { + yield return Path.Combine(repoRoot, "native", "build", "macos", libraryName); + yield return Path.Combine(repoRoot, "native", "build", "macos", "lib", libraryName); + yield return Path.Combine(repoRoot, "libs", "native-osx-arm64", libraryName); + yield return Path.Combine(repoRoot, "libs", "native-libs", "osx-arm64", libraryName); + } + else if (OperatingSystem.IsLinux()) + { + yield return Path.Combine(repoRoot, "native", "build", "linux", libraryName); + yield return Path.Combine(repoRoot, "libs", "native-linux", libraryName); + yield return Path.Combine(repoRoot, "libs", "native-libs", "linux-x64", libraryName); + } + } + + private static void ApplyNativeLibrary(string libraryPath, string? metallibPath) { Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", libraryPath); - var metalPath = Path.Combine(Path.GetDirectoryName(libraryPath)!, "mlx.metallib"); - if (File.Exists(metalPath)) + if (!string.IsNullOrWhiteSpace(metallibPath)) { - Environment.SetEnvironmentVariable("MLX_METAL_PATH", metalPath); - Environment.SetEnvironmentVariable("MLX_METALLIB", metalPath); + Environment.SetEnvironmentVariable("MLX_METAL_PATH", metallibPath); + Environment.SetEnvironmentVariable("MLX_METALLIB", metallibPath); } var fileName = RuntimeInformation.IsOSPlatform(OSPlatform.OSX) @@ -98,12 +184,101 @@ private static void ApplyNativeLibrary(string libraryPath) : "libmlxsharp"; TryCopy(libraryPath, Path.Combine(AppContext.BaseDirectory, fileName)); - if (File.Exists(metalPath)) + if (!string.IsNullOrWhiteSpace(metallibPath)) { - TryCopy(metalPath, Path.Combine(AppContext.BaseDirectory, "mlx.metallib")); + TryCopy(metallibPath!, Path.Combine(AppContext.BaseDirectory, "mlx.metallib")); } } + private static void ConfigureModelPaths(string repoRoot) + { + var existingModel = Environment.GetEnvironmentVariable("MLXSHARP_MODEL_PATH"); + if (TryValidateModel(existingModel)) + { + return; + } + + var desiredModel = Environment.GetEnvironmentVariable("MLXSHARP_HF_MODEL_ID"); + if (string.IsNullOrWhiteSpace(desiredModel)) + { + desiredModel = "mlx-community/Qwen1.5-0.5B-Chat-4bit"; + } + + var modelsRoot = Path.Combine(repoRoot, "models"); + Directory.CreateDirectory(modelsRoot); + var targetDirectory = Path.Combine(modelsRoot, SanitizePath(desiredModel)); + + if (!Directory.Exists(targetDirectory) || !TryValidateModel(targetDirectory)) + { + DownloadModelSnapshot(desiredModel, targetDirectory); + } + + if (!TryValidateModel(targetDirectory)) + { + throw new InvalidOperationException($"Model '{desiredModel}' was not downloaded correctly to '{targetDirectory}'."); + } + + Environment.SetEnvironmentVariable("MLXSHARP_MODEL_PATH", targetDirectory); + + var tokenizerPath = Path.Combine(targetDirectory, "tokenizer.json"); + if (!File.Exists(tokenizerPath)) + { + throw new InvalidOperationException($"Model bundle at '{targetDirectory}' is missing tokenizer.json."); + } + + Environment.SetEnvironmentVariable("MLXSHARP_TOKENIZER_PATH", tokenizerPath); + } + + private static bool TryValidateModel(string? directory) + { + if (string.IsNullOrWhiteSpace(directory)) + { + return false; + } + + if (!Directory.Exists(directory)) + { + return false; + } + + var config = Path.Combine(directory, "config.json"); + var weights = Path.Combine(directory, "model.safetensors"); + return File.Exists(config) && (File.Exists(weights) || Directory.GetFiles(directory, "*.safetensors").Length > 0); + } + + private static void DownloadModelSnapshot(string modelId, string destination) + { + Directory.CreateDirectory(destination); + + var token = Environment.GetEnvironmentVariable("HF_TOKEN"); + var includeToken = !string.IsNullOrWhiteSpace(token); + var tokenLine = includeToken ? "kwargs[\"token\"] = os.environ.get(\"HF_TOKEN\")\n" : string.Empty; + var modelLiteral = JsonSerializer.Serialize(modelId); + var destinationLiteral = JsonSerializer.Serialize(destination); + + var script = $""" +import os +from huggingface_hub import snapshot_download + +kwargs = dict(repo_id={modelLiteral}, local_dir={destinationLiteral}, local_dir_use_symlinks=False) +{tokenLine}snapshot_download(**kwargs) +"""; + + RunPython(script, $"Failed to download Hugging Face model '{modelId}'."); + } + + private static string SanitizePath(string value) + { + var invalid = Path.GetInvalidFileNameChars(); + var builder = new StringBuilder(value.Length); + foreach (var ch in value) + { + builder.Append(Array.IndexOf(invalid, ch) >= 0 ? '_' : ch); + } + + return builder.ToString(); + } + private static void TryCopy(string source, string destination) { try @@ -113,7 +288,49 @@ private static void TryCopy(string source, string destination) } catch { - // best effort copy; ignore IO errors + // best effort + } + } + + private static void RunPython(string script, string errorMessage) + { + using var process = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = "python3", + ArgumentList = { "-" }, + RedirectStandardInput = true, + RedirectStandardError = true, + RedirectStandardOutput = true, + } + }; + + process.Start(); + process.StandardInput.Write(script); + process.StandardInput.Close(); + + var stderr = process.StandardError.ReadToEnd(); + var stdout = process.StandardOutput.ReadToEnd(); + process.WaitForExit(); + + if (process.ExitCode != 0) + { + var message = new StringBuilder(errorMessage); + if (!string.IsNullOrWhiteSpace(stdout)) + { + message.AppendLine(); + message.AppendLine("stdout:"); + message.AppendLine(stdout.Trim()); + } + if (!string.IsNullOrWhiteSpace(stderr)) + { + message.AppendLine(); + message.AppendLine("stderr:"); + message.AppendLine(stderr.Trim()); + } + + throw new InvalidOperationException(message.ToString()); } } }