diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fa5317..b2885d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -172,9 +172,17 @@ jobs: - name: Stage native libraries in project run: | + set -euo pipefail + mkdir -p src/MLXSharp/runtimes/osx-arm64/native cp artifacts/native/osx-arm64/libmlxsharp.dylib src/MLXSharp/runtimes/osx-arm64/native/ - cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/ + + if [ -f artifacts/native/osx-arm64/mlx.metallib ]; then + cp artifacts/native/osx-arm64/mlx.metallib src/MLXSharp/runtimes/osx-arm64/native/ + else + echo "::warning::mlx.metallib not found in macOS native artifact; continuing without Metal shaders" + fi + mkdir -p src/MLXSharp/runtimes/linux-x64/native cp artifacts/native/linux-x64/libmlxsharp.so src/MLXSharp/runtimes/linux-x64/native/ @@ -186,7 +194,12 @@ jobs: TEST_OUTPUT="src/MLXSharp.Tests/bin/Release/net9.0" mkdir -p "$TEST_OUTPUT/runtimes/osx-arm64/native" cp src/MLXSharp/runtimes/osx-arm64/native/libmlxsharp.dylib "$TEST_OUTPUT/runtimes/osx-arm64/native/" - cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/" + + if [ -f src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib ]; then + cp src/MLXSharp/runtimes/osx-arm64/native/mlx.metallib "$TEST_OUTPUT/runtimes/osx-arm64/native/" + else + echo "::warning::mlx.metallib not staged; tests will continue without Metal shaders" + fi ls -la "$TEST_OUTPUT/runtimes/osx-arm64/native/" - name: Run tests diff --git a/native/include/mlxsharp/api.h b/native/include/mlxsharp/api.h index db2e9a4..11ce0ed 100644 --- a/native/include/mlxsharp/api.h +++ b/native/include/mlxsharp/api.h @@ -137,10 +137,21 @@ typedef struct mlx_usage { int output_tokens; } mlx_usage; +typedef struct mlxsharp_session_options { + const char* chat_model_id; + const char* embedding_model_id; + const char* image_model_id; + const char* native_model_directory; + const char* tokenizer_path; + int enable_native_runner; + int max_generated_tokens; + float temperature; + float top_p; + int top_k; +} mlxsharp_session_options; + int mlxsharp_create_session( - const char* chat_model_id, - const char* embedding_model_id, - const char* image_model_id, + const mlxsharp_session_options* options, void** session); int mlxsharp_generate_text( diff --git a/native/src/mlxsharp.cpp b/native/src/mlxsharp.cpp index 2b7da9b..befda34 100644 --- a/native/src/mlxsharp.cpp +++ b/native/src/mlxsharp.cpp @@ -43,11 +43,36 @@ struct mlxsharp_session { std::string chat_model; std::string embedding_model; std::string image_model; - mlxsharp_session(mlxsharp_context_t* ctx, std::string chat, std::string embed, std::string image) + std::string native_model_directory; + std::string tokenizer_path; + bool enable_native_runner; + int max_generated_tokens; + float temperature; + float top_p; + int top_k; + mlxsharp_session( + mlxsharp_context_t* ctx, + std::string chat, + std::string embed, + std::string image, + std::string native_dir, + std::string tokenizer, + bool enable_runner, + int max_tokens, + float temperature_value, + float top_p_value, + int top_k_value) : context(ctx), chat_model(std::move(chat)), embedding_model(std::move(embed)), - image_model(std::move(image)) {} + image_model(std::move(image)), + native_model_directory(std::move(native_dir)), + tokenizer_path(std::move(tokenizer)), + enable_native_runner(enable_runner), + max_generated_tokens(max_tokens), + temperature(temperature_value), + top_p(top_p_value), + top_k(top_k_value) {} }; namespace { @@ -57,6 +82,7 @@ thread_local std::string g_last_error; constexpr const char* kNullContext = "Context pointer is null."; constexpr const char* kNullArray = "Array pointer is null."; constexpr const char* kNullOutParameter = "Output parameter is null."; +constexpr const char* kNullSessionOptions = "Session options pointer is null."; constexpr const char* kShapeMismatch = "Element count does not match provided shape."; constexpr const char* kNonContiguous = "Array data is not contiguous."; constexpr const char* kUnsupportedDType = "Unsupported dtype."; @@ -316,8 +342,26 @@ mlxsharp_session_t* make_session_ptr( mlxsharp_context_t* context, std::string chat_model, std::string embedding_model, - std::string image_model) { - auto* handle = new (std::nothrow) mlxsharp_session(context, std::move(chat_model), std::move(embedding_model), std::move(image_model)); + std::string image_model, + std::string native_model_directory, + std::string tokenizer_path, + bool enable_native_runner, + int max_generated_tokens, + float temperature, + float top_p, + int top_k) { + auto* handle = new (std::nothrow) mlxsharp_session( + context, + std::move(chat_model), + std::move(embedding_model), + std::move(image_model), + std::move(native_model_directory), + std::move(tokenizer_path), + enable_native_runner, + max_generated_tokens, + temperature, + top_p, + top_k); if (handle == nullptr) { throw std::bad_alloc(); } @@ -356,22 +400,43 @@ void ensure_contiguous(const mlx::core::array& arr) { extern "C" { int mlxsharp_create_session( - const char* chat_model_id, - const char* embedding_model_id, - const char* image_model_id, + const mlxsharp_session_options* options, void** session) { if (session == nullptr) { return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, "Session output pointer is null."); } return invoke([&]() -> int { - auto chat = chat_model_id != nullptr ? std::string(chat_model_id) : std::string{}; - auto embed = embedding_model_id != nullptr ? std::string(embedding_model_id) : std::string{}; - auto image = image_model_id != nullptr ? std::string(image_model_id) : std::string{}; + if (options == nullptr) { + return set_error(MLXSHARP_STATUS_INVALID_ARGUMENT, kNullSessionOptions); + } + + auto chat = options->chat_model_id != nullptr ? std::string(options->chat_model_id) : std::string{}; + auto embed = options->embedding_model_id != nullptr ? std::string(options->embedding_model_id) : std::string{}; + auto image = options->image_model_id != nullptr ? std::string(options->image_model_id) : std::string{}; + auto native_dir = options->native_model_directory != nullptr ? std::string(options->native_model_directory) : std::string{}; + auto tokenizer = options->tokenizer_path != nullptr ? std::string(options->tokenizer_path) : std::string{}; + const bool enable_runner = options->enable_native_runner != 0; + const int max_tokens = options->max_generated_tokens; + const float temperature = options->temperature; + const float top_p = options->top_p; + const int top_k = options->top_k; auto device = mlx::core::default_device(); + mlx::core::set_default_device(device); auto* context = make_context_ptr(device); - auto* handle = make_session_ptr(context, std::move(chat), std::move(embed), std::move(image)); + auto* handle = make_session_ptr( + context, + std::move(chat), + std::move(embed), + std::move(image), + std::move(native_dir), + std::move(tokenizer), + enable_runner, + max_tokens, + temperature, + top_p, + top_k); *session = handle; return MLXSHARP_STATUS_SUCCESS; }); diff --git a/src/MLXSharp/Backends/MlxNativeBackend.cs b/src/MLXSharp/Backends/MlxNativeBackend.cs index 16e642b..febaa23 100644 --- a/src/MLXSharp/Backends/MlxNativeBackend.cs +++ b/src/MLXSharp/Backends/MlxNativeBackend.cs @@ -27,7 +27,8 @@ public static MlxNativeBackend Create(MlxClientOptions options) ArgumentNullException.ThrowIfNull(options); MlxNativeLibrary.EnsureLoaded(options.LibraryPath); - var status = MlxNativeMethods.CreateSession(options.ChatModelId, options.EmbeddingModelId, options.ImageModelId, out var session); + using var sessionOptions = new MarshaledSessionOptions(options); + var status = MlxNativeMethods.CreateSession(in sessionOptions.Value, out var session); if (status != 0 || session.IsInvalid) { session.Dispose(); @@ -201,6 +202,50 @@ private MlxTextResult GenerateTextFallback(MlxTextRequest request) return existing; } + private sealed class MarshaledSessionOptions : IDisposable + { + public MlxSessionOptions Value; + + public MarshaledSessionOptions(MlxClientOptions options) + { + Value = new MlxSessionOptions + { + ChatModelId = Allocate(options.ChatModelId), + EmbeddingModelId = Allocate(options.EmbeddingModelId), + ImageModelId = Allocate(options.ImageModelId), + NativeModelDirectory = Allocate(options.NativeModelDirectory), + TokenizerPath = Allocate(options.TokenizerPath), + EnableNativeModelRunner = options.EnableNativeModelRunner ? 1 : 0, + MaxGeneratedTokens = options.MaxGeneratedTokens, + Temperature = options.Temperature, + TopP = options.TopP, + TopK = options.TopK, + }; + } + + public void Dispose() + { + Free(Value.ChatModelId); + Free(Value.EmbeddingModelId); + Free(Value.ImageModelId); + Free(Value.NativeModelDirectory); + Free(Value.TokenizerPath); + } + + private static nint Allocate(string? value) + { + return value is null ? nint.Zero : Marshal.StringToCoTaskMemUTF8(value); + } + + private static void Free(nint pointer) + { + if (pointer != nint.Zero) + { + Marshal.FreeCoTaskMem(pointer); + } + } + } + private void ThrowIfDisposed() { if (_disposed) diff --git a/src/MLXSharp/Native/MlxNativeMethods.cs b/src/MLXSharp/Native/MlxNativeMethods.cs index a567e00..60f1e28 100644 --- a/src/MLXSharp/Native/MlxNativeMethods.cs +++ b/src/MLXSharp/Native/MlxNativeMethods.cs @@ -18,8 +18,8 @@ internal static partial class MlxNativeMethods { private const string LibraryName = "libmlxsharp"; - [LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session", StringMarshalling = StringMarshalling.Utf8)] - public static partial int CreateSession(string chatModelId, string embeddingModelId, string imageModelId, out SafeMlxSessionHandle session); + [LibraryImport(LibraryName, EntryPoint = "mlxsharp_create_session")] + public static partial int CreateSession(in MlxSessionOptions options, out SafeMlxSessionHandle session); [LibraryImport(LibraryName, EntryPoint = "mlxsharp_release_session")] public static partial void ReleaseSession(nint session); @@ -142,3 +142,18 @@ internal struct MlxUsage public int InputTokens; public int OutputTokens; } + +[StructLayout(LayoutKind.Sequential)] +internal struct MlxSessionOptions +{ + public nint ChatModelId; + public nint EmbeddingModelId; + public nint ImageModelId; + public nint NativeModelDirectory; + public nint TokenizerPath; + public int EnableNativeModelRunner; + public int MaxGeneratedTokens; + public float Temperature; + public float TopP; + public int TopK; +}