From 916aabbd32feeeb36358eee063cf49decd6dd6d8 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Fri, 13 Feb 2026 15:39:35 +0100 Subject: [PATCH 01/27] wip --- .../app/image_segmentation/index.tsx | 50 ++++-- .../rnexecutorch/RnExecutorchInstaller.cpp | 5 +- .../src/hooks/useModule.ts | 7 + packages/react-native-executorch/src/index.ts | 3 +- .../GenericImageSegmentation.ts | 68 ++++++++ .../computer_vision/NewImageSegmentation.ts | 162 ++++++++++++++++++ .../src/types/genericImageSegmentation.ts | 0 7 files changed, 277 insertions(+), 18 deletions(-) create mode 100644 packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts create mode 100644 packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts create mode 100644 packages/react-native-executorch/src/types/genericImageSegmentation.ts diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 61a98ddea..03d719938 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -2,9 +2,9 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { - useImageSegmentation, + ImageSegmentation, DEEPLAB_V3_RESNET50, - DeeplabLabel, + SegmentationLabels, } from 'react-native-executorch'; import { Canvas, @@ -44,16 +44,34 @@ const numberToColor: number[][] = [ ]; export default function ImageSegmentationScreen() { - const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); const { setGlobalGenerating } = useContext(GeneratingContext); - useEffect(() => { - setGlobalGenerating(model.isGenerating); - }, [model.isGenerating, setGlobalGenerating]); + const [model, setModel] = useState | null>( + null + ); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); const 
[imageUri, setImageUri] = useState(''); const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); const [segImage, setSegImage] = useState(null); const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 }); + useEffect(() => { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + + useEffect(() => { + let instance: ImageSegmentation<'deeplab-v3'> | null = null; + (async () => { + instance = await ImageSegmentation.fromModelName( + DEEPLAB_V3_RESNET50.modelSource, + 'deeplab-v3', + setDownloadProgress + ); + setModel(instance); + })(); + return () => instance?.delete(); + }, []); + const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); if (!image?.uri) return; @@ -66,15 +84,13 @@ export default function ImageSegmentationScreen() { }; const runForward = async () => { - if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return; + if (!model || !imageUri || imageSize.width === 0 || imageSize.height === 0) + return; try { + setIsGenerating(true); const { width, height } = imageSize; - const output = await model.forward(imageUri, [DeeplabLabel.ARGMAX]); - const argmax = output[DeeplabLabel.ARGMAX] || []; - const uniqueValues = new Set(); - for (let i = 0; i < argmax.length; i++) { - uniqueValues.add(argmax[i]); - } + const output = await model.forward(imageUri, ['dupa'], true); + const argmax = output['ARGMAX'] || []; const pixels = new Uint8Array(width * height * 4); for (let row = 0; row < height; row++) { @@ -102,14 +118,16 @@ export default function ImageSegmentationScreen() { setSegImage(img); } catch (e) { console.error(e); + } finally { + setIsGenerating(false); } }; - if (!model.isReady) { + if (!model) { return ( ); } diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 7a4426e06..bceac64ad 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -10,9 +10,9 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -92,6 +92,7 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadOCR", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadOCR")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadVerticalOCR", RnExecutorchInstaller::loadModel( @@ -101,10 +102,12 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadSpeechToText", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadSpeechToText")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadTextToSpeechKokoro", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadTextToSpeechKokoro")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadVAD", RnExecutorchInstaller::loadModel< diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index 39b10249b..1140e9cc4 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -73,6 +73,12 @@ export const useModule = < } }; + const forwardGeneric = async ( + ...input: ForwardArgs + ): Promise => { + return await forward(...input); + }; + return { /** * Contains the error message if the model failed to load. 
@@ -94,5 +100,6 @@ export const useModule = < */ downloadProgress, forward, + forwardGeneric, }; }; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 8b4035232..cb8ba09b8 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -129,7 +129,8 @@ export * from './hooks/general/useExecutorchModule'; export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; -export * from './modules/computer_vision/ImageSegmentationModule'; +// export * from './modules/computer_vision/ImageSegmentationModule'; +export * from './modules/computer_vision/NewImageSegmentation'; export * from './modules/computer_vision/OCRModule'; export * from './modules/computer_vision/VerticalOCRModule'; export * from './modules/computer_vision/ImageEmbeddingsModule'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts b/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts new file mode 100644 index 000000000..aeee4dec0 --- /dev/null +++ b/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts @@ -0,0 +1,68 @@ +import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { ResourceSource } from '../../types/common'; +import { DeeplabLabel } from '../../types/imageSegmentation'; +import { BaseModule } from '../BaseModule'; + +/** + * Module for image segmentation tasks. + * + * @category Typescript API + */ + +// Allow string or number values (standard Enums use numbers) +type LabelMap = Record; + +export class ImageSegmentationModule extends BaseModule { + async load( + modelSource: ResourceSource, + onDownloadProgressCallback: (progress: number) => void = () => {} + ) { + // Implementation of model loading... 
+ } + + /** + * Generic forward pass that accepts a custom Label Enum. + * * @param imageSource - Path to the image. + * @param labelMap - The runtime Enum object (e.g., DeeplabLabel or a custom object). + * @param classesOfInterest - Array of keys from the provided Enum (e.g., ['PERSON', 'DOG']). + * @param resizeToInput - Whether to resize output to input dimensions. + */ + public async forwardGeneric( + imageSource: string, + labelMap: T, + classesOfInterest: (keyof T)[], + resizeToInput: boolean = true + ): Promise> { + // 1. Convert the string keys (e.g., "PERSON") to their numeric indices (e.g., 15) + // We use the runtime 'labelMap' object to look up the values. + const classIndices = (classesOfInterest || []).map( + (label) => labelMap[label] + ); + + // 2. Call the native module with the numeric indices + const result = await this.nativeModule.generate( + imageSource, + classIndices, + resizeToInput + ); + + return result; + } + + /** + * Convenience wrapper for the default DeeplabLabel model. 
+ */ + public async forward( + imageSource: string, + classesOfInterest: (keyof typeof DeeplabLabel)[], + resizeToInput: boolean = true + ) { + // Passes the default DeeplabLabel enum automatically + return this.forwardGeneric( + imageSource, + DeeplabLabel, + classesOfInterest, + resizeToInput + ); + } +} diff --git a/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts b/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts new file mode 100644 index 000000000..2138f2348 --- /dev/null +++ b/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts @@ -0,0 +1,162 @@ +import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { ResourceSource } from '../../types/common'; +import { DeeplabLabel } from '../../types/imageSegmentation'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError } from '../../errors/errorUtils'; + +type Enumish = Record; + +type SegmentationConfig = { + labelMap: T; +}; + +const ModelConfigs = { + 'deeplab-v3': { + labelMap: DeeplabLabel, + loader: (path: string) => global.loadImageSegmentation(path), + }, + 'selfie-segmentation': { + labelMap: { background: 0, object: 1 }, + loader: (path: string) => global.loadImageSegmentation(path), + }, + 'rfdetr': { + labelMap: DeeplabLabel, + loader: (path: string) => global.loadImageSegmentation(path), + }, +} as const; + +type ModelConfigsType = typeof ModelConfigs; +type ModelName = keyof ModelConfigsType; + +export type SegmentationLabels = + ModelConfigsType[M]['labelMap']; + +/** + * Generic image segmentation module with type-safe label maps. + */ +export class ImageSegmentation { + private labelMap: T; + private nativeModule: any; + + private constructor(labelMap: T, nativeModule: unknown) { + this.labelMap = labelMap; + this.nativeModule = nativeModule; + } + + /** + * Creates a segmentation instance for a known model. 
+ * The config object is strictly typed based on the modelName provided. + */ + static async fromModelName( + modelSource: ResourceSource, + modelName: N, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise> { + const { labelMap, loader } = ModelConfigs[modelName]; + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (paths === null || paths.length < 1) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); + } + const nativeModule = loader(paths[0] || ''); + return new ImageSegmentation(labelMap, nativeModule); + } + + /** + * Creates a segmentation instance with a user-provided label map and custom config. + */ + static async fromCustomConfig( + modelSource: ResourceSource, + config: SegmentationConfig, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise> { + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (paths === null || paths.length < 1) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); + } + const nativeModule = global.loadImageSegmentation(paths[0] || ''); + return new ImageSegmentation(config.labelMap, nativeModule); + } + + /** + * Executes the model's forward pass. + */ + async forward( + imageSource: string, + classesOfInterest: (keyof T)[] = [], + resizeToInput: boolean = true + ): Promise>> { + if (this.nativeModule == null) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded.' 
+ ); + } + + const classNames = classesOfInterest.map((label) => String(label)); + + const nativeResult = await this.nativeModule.generate( + imageSource, + classNames, + resizeToInput + ); + + const result: Partial> = {}; + for (const [key, maskData] of Object.entries(nativeResult)) { + if (key in this.labelMap) { + result[key as keyof T] = maskData as number[]; + } + } + return result; + } + + /** + * Unloads the model from memory. + */ + delete() { + if (this.nativeModule != null) { + this.nativeModule.unload(); + } + } +} + +// Type tests + +// async function _typeTests() { +// const deeplab = await ImageSegmentation.fromModelName('https://example.com/model.pte', 'deeplab-v3'); +// const deeplabResult = await deeplab.forward('image.jpg', ['PERSON', 'CAR', 'ARGMAX']); +// deeplabResult.PERSON; // OK +// deeplabResult.CAR; // OK +// // ERROR: 'BANANA' is not a DeeplabLabel key +// deeplabResult.BANANA; +// +// // fromModelName: selfie-segmentation — should autocomplete 'background' | 'object' +// const selfie = await ImageSegmentation.fromModelName('https://example.com/model.pte', 'selfie-segmentation'); +// const selfieResult = await selfie.forward('image.jpg', ['background']); +// selfieResult.background; // OK +// selfieResult.object; // OK +// // ERROR: 'PERSON' is not a selfie-segmentation key +// selfieResult.PERSON; +// +// // fromCustomConfig: custom labels — should infer from provided map +// const custom = await ImageSegmentation.fromCustomConfig('https://example.com/model.pte', { +// labelMap: { sky: 0, ground: 1, building: 2 } as const, +// }); +// const customResult = await custom.forward('image.jpg', ['sky', 'ground']); +// customResult.sky; // OK +// customResult.building; // OK +// // 'water' is not in the custom label map +// customResult.water; +// +// // ERORR: 'nonexistent-model' is not a known model name +// await ImageSegmentation.fromModelName('https://example.com/model.pte', 'nonexistent-model'); +// +// // forward classesOfInterest should 
only accept valid keys +// // 'INVALID' is not a DeeplabLabel key +// await deeplab.forward('image.jpg', ['INVALID']); +// } diff --git a/packages/react-native-executorch/src/types/genericImageSegmentation.ts b/packages/react-native-executorch/src/types/genericImageSegmentation.ts new file mode 100644 index 000000000..e69de29bb From 249b322ae7d8bd21c9bb687c00a43804124ac7e5 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 16 Feb 2026 20:54:56 +0100 Subject: [PATCH 02/27] wip --- .../BaseImageSegmentation.cpp | 260 ++++++++++++++++++ .../BaseImageSegmentation.h | 58 ++++ .../models/image_segmentation/Constants.h | 13 - .../image_segmentation/ImageSegmentation.cpp | 169 ------------ .../image_segmentation/ImageSegmentation.h | 37 +-- packages/react-native-executorch/src/index.ts | 6 +- .../computer_vision/NewImageSegmentation.ts | 180 ++++++------ .../src/types/imageSegmentation.ts | 111 +++++++- 8 files changed, 539 insertions(+), 295 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp create mode 100644 packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp new file mode 100644 index 000000000..5496c868f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -0,0 +1,260 @@ +#include "BaseImageSegmentation.h" +#include "jsi/jsi.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch::models::image_segmentation { + +BaseImageSegmentation::BaseImageSegmentation( + const 
std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + initModelImageSize(); +} + +BaseImageSegmentation::BaseImageSegmentation( + const std::string &modelSource, std::vector normMean, + std::vector normStd, std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + initModelImageSize(); + if (normMean.size() >= 3) { + normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); + } + if (normStd.size() >= 3) { + normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); + } +} + +void BaseImageSegmentation::initModelImageSize() { + auto inputShapes = getAllInputShapes(); + if (inputShapes.size() == 0) { + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, + "Model seems to not take any input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + char errorMessage[100]; + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected model input size, expected at least 2 dimentions " + "but got: %zu.", + modelInputShape.size()); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + errorMessage); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); + numModelPixels = modelImageSize.area(); +} + +TensorPtr BaseImageSegmentation::preprocess(const std::string &imageSource, + cv::Size &originalSize) { + if (normMean_.has_value() && normStd_.has_value()) { + cv::Mat input = image_processing::readImage(imageSource); + originalSize = input.size(); + cv::resize(input, input, modelImageSize); + cv::cvtColor(input, input, cv::COLOR_BGR2RGB); + return image_processing::getTensorFromMatrix( + getAllInputShapes()[0], input, normMean_.value(), normStd_.value()); + } + auto [inputTensor, origSize] = + image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); + originalSize = origSize; + return inputTensor; +} + +std::shared_ptr 
BaseImageSegmentation::generate( + std::string imageSource, std::vector allClasses, + std::set> classesOfInterest, bool resize) { + + cv::Size originalSize; + auto inputTensor = preprocess(imageSource, originalSize); + + auto forwardResult = BaseModel::forward(inputTensor); + + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed. " + "Ensure the model input is correct."); + } + + return postprocess(forwardResult->at(0).toTensor(), originalSize, allClasses, + classesOfInterest, resize); +} + +std::shared_ptr BaseImageSegmentation::postprocess( + const Tensor &tensor, cv::Size originalSize, + std::vector allClasses, + std::set> classesOfInterest, bool resize) { + + auto dataPtr = static_cast(tensor.const_data_ptr()); + auto resultData = std::span(dataPtr, tensor.numel()); + + // Infer output pixel count and channel count. + // If output spatial dims differ from input (e.g. model downsamples), + // derive pixel count from the tensor and allClasses.size(). + size_t numOutputChannels = tensor.numel() / numModelPixels; + size_t outputPixels = numModelPixels; + if (numOutputChannels != 1 && numOutputChannels != allClasses.size() && + !allClasses.empty() && tensor.numel() % allClasses.size() == 0) { + outputPixels = tensor.numel() / allClasses.size(); + numOutputChannels = allClasses.size(); + } + auto outputSide = static_cast(std::sqrt(outputPixels)); + cv::Size outputSize(outputSide, outputSide); + + std::vector> resultClasses; + auto argmax = + std::make_shared(outputPixels * sizeof(int32_t)); + + if (numOutputChannels == 1) { + // Binary segmentation path (e.g. 
selfie segmentation) + // The single channel contains probability values in [0, 1] + // Synthesize two class buffers: background (1-p) and foreground (p) + resultClasses.reserve(2); + + auto bgBuffer = + std::make_shared(outputPixels * sizeof(float)); + auto fgBuffer = + std::make_shared(outputPixels * sizeof(float)); + + auto *bgData = reinterpret_cast(bgBuffer->data()); + auto *fgData = reinterpret_cast(fgBuffer->data()); + auto *argmaxData = reinterpret_cast(argmax->data()); + + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + float p = resultData[pixel]; + bgData[pixel] = 1.0f - p; + fgData[pixel] = p; + argmaxData[pixel] = (p > 0.5f) ? 1 : 0; + } + + resultClasses.push_back(bgBuffer); + resultClasses.push_back(fgBuffer); + } else if (numOutputChannels == allClasses.size()) { + // Multi-class segmentation path (e.g. DeepLab-v3) + // Copy per-class buffers from the ET-owned tensor data + resultClasses.reserve(allClasses.size()); + for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { + auto classBuffer = std::make_shared( + &resultData[cl * outputPixels], outputPixels * sizeof(float)); + resultClasses.push_back(classBuffer); + } + + // Apply softmax per each pixel across all classes + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + std::vector classValues(allClasses.size()); + for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { + classValues[cl] = + reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + numerical::softmax(classValues); + for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { + reinterpret_cast(resultClasses[cl]->data())[pixel] = + classValues[cl]; + } + } + + // Calculate the maximum class for each pixel + auto *argmaxData = reinterpret_cast(argmax->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + float max = reinterpret_cast(resultClasses[0]->data())[pixel]; + int maxInd = 0; + for (std::size_t cl = 1; cl < allClasses.size(); ++cl) { + if 
(reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { + maxInd = static_cast(cl); + max = reinterpret_cast(resultClasses[cl]->data())[pixel]; + } + } + argmaxData[pixel] = maxInd; + } + } else { + char errorMessage[200]; + std::snprintf( + errorMessage, sizeof(errorMessage), + "Unexpected number of output channels: %zu. Expected 1 (binary) or " + "%zu (matching allClasses). Model output has %zu elements for %zu " + "pixels.", + numOutputChannels, allClasses.size(), tensor.numel(), outputPixels); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + errorMessage); + } + + // Filter classes of interest using allClasses labels + auto buffersToReturn = std::make_shared>>(); + for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { + if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) { + (*buffersToReturn)[allClasses[cl]] = resultClasses[cl]; + } + } + + // Resize selected classes and argmax + if (resize) { + cv::Mat argmaxMat(outputSize, CV_32SC1, argmax->data()); + cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, + cv::InterpolationFlags::INTER_NEAREST); + argmax = std::make_shared( + argmaxMat.data, originalSize.area() * sizeof(int32_t)); + + for (auto &[label, arrayBuffer] : *buffersToReturn) { + cv::Mat classMat(outputSize, CV_32FC1, arrayBuffer->data()); + cv::resize(classMat, classMat, originalSize); + arrayBuffer = std::make_shared( + classMat.data, originalSize.area() * sizeof(float)); + } + } + return populateDictionary(argmax, buffersToReturn); +} + +std::shared_ptr BaseImageSegmentation::populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput) { + // Synchronize the invoked thread to return when the dict is constructed + auto promisePtr = std::make_shared>(); + std::future doneFuture = promisePtr->get_future(); + + std::shared_ptr dictPtr = nullptr; + callInvoker->invokeAsync( + [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { + dictPtr = 
std::make_shared(runtime); + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); + + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = + int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dictPtr->setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *classesToOutput) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + + dictPtr->setProperty( + runtime, jsi::String::createFromAscii(runtime, classLabel.data()), + float32Array); + } + promisePtr->set_value(); + }); + + doneFuture.wait(); + return dictPtr; +} + +} // namespace rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h new file mode 100644 index 000000000..3444342a5 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include +#include + +namespace rnexecutorch { +namespace models::image_segmentation { +using namespace facebook; + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class BaseImageSegmentation : public BaseModel { +public: + BaseImageSegmentation(const std::string &modelSource, + std::shared_ptr callInvoker); + + BaseImageSegmentation(const std::string &modelSource, + std::vector normMean, std::vector normStd, + std::shared_ptr callInvoker); + + [[nodiscard("Registered non-void function")]] 
std::shared_ptr + generate(std::string imageSource, std::vector allClasses, + std::set> classesOfInterest, bool resize); + +protected: + virtual TensorPtr preprocess(const std::string &imageSource, + cv::Size &originalSize); + virtual std::shared_ptr + postprocess(const Tensor &tensor, cv::Size originalSize, + std::vector allClasses, + std::set> classesOfInterest, + bool resize); + + cv::Size modelImageSize; + std::size_t numModelPixels; + std::optional normMean_; + std::optional normStd_; + +private: + void initModelImageSize(); + +protected: + std::shared_ptr populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput); +}; +} // namespace models::image_segmentation +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h deleted file mode 100644 index 847c44373..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include -#include - -namespace rnexecutorch::models::image_segmentation::constants { -inline constexpr std::array kDeeplabV3Resnet50Labels = { - "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", - "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", - "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", - "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", - "TVMONITOR"}; -} // namespace rnexecutorch::models::image_segmentation::constants \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp index a2c1ae865..acf4bbdf7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp @@ -1,170 +1 @@ #include "ImageSegmentation.h" - -#include - -#include -#include -#include -#include -#include -#include - -namespace rnexecutorch::models::image_segmentation { - -ImageSegmentation::ImageSegmentation( - const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { - auto inputShapes = getAllInputShapes(); - if (inputShapes.size() == 0) { - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Model seems to not take any input tensors."); - } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " - "but got: %zu.", - modelInputShape.size()); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); - } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); - numModelPixels = modelImageSize.area(); -} - -std::shared_ptr ImageSegmentation::generate( - std::string imageSource, - std::set> classesOfInterest, bool resize) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); - - auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { - throw RnExecutorchError(forwardResult.error(), - "The model's forward function did not succeed. 
" - "Ensure the model input is correct."); - } - - return postprocess(forwardResult->at(0).toTensor(), originalSize, - classesOfInterest, resize); -} - -std::shared_ptr ImageSegmentation::postprocess( - const Tensor &tensor, cv::Size originalSize, - std::set> classesOfInterest, bool resize) { - - auto dataPtr = static_cast(tensor.const_data_ptr()); - auto resultData = std::span(dataPtr, tensor.numel()); - - // We copy the ET-owned data to jsi array buffers that can be directly - // returned to JS - std::vector> resultClasses; - resultClasses.reserve(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - auto classBuffer = std::make_shared( - &resultData[cl * numModelPixels], numModelPixels * sizeof(float)); - resultClasses.push_back(classBuffer); - } - - // Apply softmax per each pixel across all classes - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - std::vector classValues(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - classValues[cl] = - reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - numerical::softmax(classValues); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - reinterpret_cast(resultClasses[cl]->data())[pixel] = - classValues[cl]; - } - } - - // Calculate the maximum class for each pixel - auto argmax = - std::make_shared(numModelPixels * sizeof(int32_t)); - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - float max = reinterpret_cast(resultClasses[0]->data())[pixel]; - int maxInd = 0; - for (int cl = 1; cl < numClasses; ++cl) { - if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { - maxInd = cl; - max = reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - } - reinterpret_cast(argmax->data())[pixel] = maxInd; - } - - auto buffersToReturn = std::make_shared>>(); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - if (classesOfInterest.contains(constants::kDeeplabV3Resnet50Labels[cl])) { - (*buffersToReturn)[constants::kDeeplabV3Resnet50Labels[cl]] = - 
resultClasses[cl]; - } - } - - // Resize selected classes and argmax - if (resize) { - cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data()); - cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, - cv::InterpolationFlags::INTER_NEAREST); - argmax = std::make_shared( - argmaxMat.data, originalSize.area() * sizeof(int32_t)); - - for (auto &[label, arrayBuffer] : *buffersToReturn) { - cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data()); - cv::resize(classMat, classMat, originalSize); - arrayBuffer = std::make_shared( - classMat.data, originalSize.area() * sizeof(float)); - } - } - return populateDictionary(argmax, buffersToReturn); -} - -std::shared_ptr ImageSegmentation::populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput) { - // Synchronize the invoked thread to return when the dict is constructed - auto promisePtr = std::make_shared>(); - std::future doneFuture = promisePtr->get_future(); - - std::shared_ptr dictPtr = nullptr; - callInvoker->invokeAsync( - [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { - dictPtr = std::make_shared(runtime); - auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); - - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = - int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) - .getObject(runtime); - dictPtr->setProperty(runtime, "ARGMAX", int32Array); - - for (auto &[classLabel, owningBuffer] : *classesToOutput) { - auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); - auto float32Array = - float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - - dictPtr->setProperty( - runtime, jsi::String::createFromAscii(runtime, classLabel.data()), - float32Array); - } - promisePtr->set_value(); - }); - - doneFuture.wait(); - return dictPtr; -} - -} // namespace 
rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h index 301833ce8..4e4bf1baf 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h @@ -1,48 +1,19 @@ #pragma once -#include -#include -#include -#include - #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include -#include -#include +#include namespace rnexecutorch { namespace models::image_segmentation { using namespace facebook; -using executorch::aten::Tensor; -using executorch::extension::TensorPtr; - -class ImageSegmentation : public BaseModel { +class ImageSegmentation : public BaseImageSegmentation { public: - ImageSegmentation(const std::string &modelSource, - std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::shared_ptr - generate(std::string imageSource, - std::set> classesOfInterest, bool resize); - -private: - std::shared_ptr - postprocess(const Tensor &tensor, cv::Size originalSize, - std::set> classesOfInterest, - bool resize); - std::shared_ptr populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput); - - static constexpr std::size_t numClasses{ - constants::kDeeplabV3Resnet50Labels.size()}; - cv::Size modelImageSize; - std::size_t numModelPixels; + using BaseImageSegmentation::BaseImageSegmentation; }; } // namespace models::image_segmentation REGISTER_CONSTRUCTOR(models::image_segmentation::ImageSegmentation, std::string, + std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index cb8ba09b8..0e9f09ba6 100644 --- 
a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -36,7 +36,11 @@ export function cleanupExecutorch() { // eslint-disable no-var declare global { var loadStyleTransfer: (source: string) => any; - var loadImageSegmentation: (source: string) => any; + var loadImageSegmentation: ( + source: string, + normMean: number[], + normStd: number[] + ) => any; var loadClassification: (source: string) => any; var loadObjectDetection: (source: string) => any; var loadExecutorchModule: (source: string) => any; diff --git a/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts b/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts index 2138f2348..765a92c9a 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts @@ -1,29 +1,41 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; -import { DeeplabLabel } from '../../types/imageSegmentation'; +import { CocoLabel, DeeplabLabel } from '../../types/imageSegmentation'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; +import { BaseModule } from '../BaseModule'; -type Enumish = Record; +enum SelfieSegmentationLabel { + BACKGROUND, + SELFIE, +} +type Enumish = Readonly>; +export type Triple = readonly [T, T, T]; -type SegmentationConfig = { +type SegmentationConfig = { labelMap: T; + preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; }; +type ForwardReturnWithArgmax = Partial< + Record +>; + +const IMAGENET_MEAN: Triple = [0.485, 0.456, 0.406]; +const IMAGENET_STD: Triple = [0.229, 0.224, 0.225]; + const ModelConfigs = { 'deeplab-v3': { labelMap: DeeplabLabel, - loader: (path: string) => global.loadImageSegmentation(path), }, 
'selfie-segmentation': { - labelMap: { background: 0, object: 1 }, - loader: (path: string) => global.loadImageSegmentation(path), + labelMap: SelfieSegmentationLabel, }, 'rfdetr': { - labelMap: DeeplabLabel, - loader: (path: string) => global.loadImageSegmentation(path), + labelMap: CocoLabel, + preprocessorConfig: { normMean: IMAGENET_MEAN, normStd: IMAGENET_STD }, }, -} as const; +} as const satisfies Record>; type ModelConfigsType = typeof ModelConfigs; type ModelName = keyof ModelConfigsType; @@ -31,28 +43,63 @@ type ModelName = keyof ModelConfigsType; export type SegmentationLabels = ModelConfigsType[M]['labelMap']; +/** + * Resolves the label type: if T is a ModelName, look up its labels; otherwise use T directly as an Enumish. + */ +type ResolveLabels = T extends ModelName + ? SegmentationLabels + : T; + +/** + * Per-model config for `fromModelName`. Each model name maps to its required fields. + * Add new union members here when a model needs extra sources or options. + */ +type ModelSources = + | { modelName: 'deeplab-v3'; modelSource: ResourceSource } + | { modelName: 'selfie-segmentation'; modelSource: ResourceSource } + | { modelName: 'rfdetr'; modelSource: ResourceSource }; + +/** + * Extract the model name from a config object. + */ +type ModelNameOf = C['modelName']; + /** * Generic image segmentation module with type-safe label maps. + * Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, + * or a custom label enum for custom configs. 
*/ -export class ImageSegmentation { - private labelMap: T; - private nativeModule: any; +export class ImageSegmentation< + T extends ModelName | Enumish, +> extends BaseModule { + private labelMap: ResolveLabels; - private constructor(labelMap: T, nativeModule: unknown) { + private constructor(labelMap: ResolveLabels, nativeModule: unknown) { + super(); this.labelMap = labelMap; this.nativeModule = nativeModule; } + // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) + async load() {} + /** * Creates a segmentation instance for a known model. - * The config object is strictly typed based on the modelName provided. + * The config object is discriminated by `modelName` — each model can require different fields. */ - static async fromModelName( - modelSource: ResourceSource, - modelName: N, + static async fromModelName( + config: C, onDownloadProgress: (progress: number) => void = () => {} - ): Promise> { - const { labelMap, loader } = ModelConfigs[modelName]; + ): Promise>> { + const { modelName, modelSource } = config; + const modelConfig = ModelConfigs[modelName]; + const { labelMap } = modelConfig; + const preprocessorConfig = + 'preprocessorConfig' in modelConfig + ? modelConfig.preprocessorConfig + : undefined; + const normMean = [...(preprocessorConfig?.normMean ?? [])]; + const normStd = [...(preprocessorConfig?.normStd ?? [])]; const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); if (paths === null || paths.length < 1) { throw new RnExecutorchError( @@ -60,18 +107,25 @@ export class ImageSegmentation { 'The download has been interrupted. Please retry.' 
); } - const nativeModule = loader(paths[0] || ''); - return new ImageSegmentation(labelMap, nativeModule); + const nativeModule = global.loadImageSegmentation( + paths[0] || '', + normMean, + normStd + ); + return new ImageSegmentation>( + labelMap as ResolveLabels>, + nativeModule + ); } /** * Creates a segmentation instance with a user-provided label map and custom config. */ - static async fromCustomConfig( + static async fromCustomConfig( modelSource: ResourceSource, - config: SegmentationConfig, + config: SegmentationConfig, onDownloadProgress: (progress: number) => void = () => {} - ): Promise> { + ): Promise> { const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); if (paths === null || paths.length < 1) { throw new RnExecutorchError( @@ -79,8 +133,17 @@ export class ImageSegmentation { 'The download has been interrupted. Please retry.' ); } - const nativeModule = global.loadImageSegmentation(paths[0] || ''); - return new ImageSegmentation(config.labelMap, nativeModule); + const normMean = config.preprocessorConfig?.normMean ?? []; + const normStd = config.preprocessorConfig?.normStd ?? 
[]; + const nativeModule = global.loadImageSegmentation( + paths[0] || '', + [...normMean], + [...normStd] + ); + return new ImageSegmentation( + config.labelMap as ResolveLabels, + nativeModule + ); } /** @@ -88,9 +151,9 @@ export class ImageSegmentation { */ async forward( imageSource: string, - classesOfInterest: (keyof T)[] = [], + classesOfInterest: (keyof ResolveLabels | 'ARGMAX')[] = [], resizeToInput: boolean = true - ): Promise>> { + ): Promise>> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -98,65 +161,26 @@ export class ImageSegmentation { ); } - const classNames = classesOfInterest.map((label) => String(label)); + const allClassNames = Object.keys(this.labelMap).filter((k) => + isNaN(Number(k)) + ); + const classesOfInterestNames = classesOfInterest.map((label) => + String(label) + ); const nativeResult = await this.nativeModule.generate( imageSource, - classNames, + allClassNames, + classesOfInterestNames, resizeToInput ); - const result: Partial> = {}; + const result: ForwardReturnWithArgmax> = {}; for (const [key, maskData] of Object.entries(nativeResult)) { - if (key in this.labelMap) { - result[key as keyof T] = maskData as number[]; + if (key in this.labelMap || key === 'ARGMAX') { + result[key as keyof ResolveLabels] = maskData as number[]; } } return result; } - - /** - * Unloads the model from memory. 
- */ - delete() { - if (this.nativeModule != null) { - this.nativeModule.unload(); - } - } } - -// Type tests - -// async function _typeTests() { -// const deeplab = await ImageSegmentation.fromModelName('https://example.com/model.pte', 'deeplab-v3'); -// const deeplabResult = await deeplab.forward('image.jpg', ['PERSON', 'CAR', 'ARGMAX']); -// deeplabResult.PERSON; // OK -// deeplabResult.CAR; // OK -// // ERROR: 'BANANA' is not a DeeplabLabel key -// deeplabResult.BANANA; -// -// // fromModelName: selfie-segmentation — should autocomplete 'background' | 'object' -// const selfie = await ImageSegmentation.fromModelName('https://example.com/model.pte', 'selfie-segmentation'); -// const selfieResult = await selfie.forward('image.jpg', ['background']); -// selfieResult.background; // OK -// selfieResult.object; // OK -// // ERROR: 'PERSON' is not a selfie-segmentation key -// selfieResult.PERSON; -// -// // fromCustomConfig: custom labels — should infer from provided map -// const custom = await ImageSegmentation.fromCustomConfig('https://example.com/model.pte', { -// labelMap: { sky: 0, ground: 1, building: 2 } as const, -// }); -// const customResult = await custom.forward('image.jpg', ['sky', 'ground']); -// customResult.sky; // OK -// customResult.building; // OK -// // 'water' is not in the custom label map -// customResult.water; -// -// // ERORR: 'nonexistent-model' is not a known model name -// await ImageSegmentation.fromModelName('https://example.com/model.pte', 'nonexistent-model'); -// -// // forward classesOfInterest should only accept valid keys -// // 'INVALID' is not a DeeplabLabel key -// await deeplab.forward('image.jpg', ['INVALID']); -// } diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 02d9eec10..3160ec5b7 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ 
b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -28,7 +28,116 @@ export enum DeeplabLabel { SOFA, TRAIN, TVMONITOR, - ARGMAX, // Additional label not present in the model +} + +/** + * Labels used in the selfie image segmentation model. + * + * @category Types + */ +export enum SelfieSegmentationLabel { + SELFIE, + BACKGROUND, +} + +/** + * COCO 91-class labels used by RF-DETR segmentation models. + * Indices match the model's 91 output channels. + * + * @category Types + */ +export enum CocoLabel { + BACKGROUND = 0, + PERSON = 1, + BICYCLE = 2, + CAR = 3, + MOTORCYCLE = 4, + AIRPLANE = 5, + BUS = 6, + TRAIN = 7, + TRUCK = 8, + BOAT = 9, + TRAFFIC_LIGHT = 10, + FIRE_HYDRANT = 11, + _RESERVED_12 = 12, + STOP_SIGN = 13, + PARKING_METER = 14, + BENCH = 15, + BIRD = 16, + CAT = 17, + DOG = 18, + HORSE = 19, + SHEEP = 20, + COW = 21, + ELEPHANT = 22, + BEAR = 23, + ZEBRA = 24, + GIRAFFE = 25, + _RESERVED_26 = 26, + BACKPACK = 27, + UMBRELLA = 28, + _RESERVED_29 = 29, + _RESERVED_30 = 30, + HANDBAG = 31, + TIE = 32, + SUITCASE = 33, + FRISBEE = 34, + SKIS = 35, + SNOWBOARD = 36, + SPORTS_BALL = 37, + KITE = 38, + BASEBALL_BAT = 39, + BASEBALL_GLOVE = 40, + SKATEBOARD = 41, + SURFBOARD = 42, + TENNIS_RACKET = 43, + BOTTLE = 44, + _RESERVED_45 = 45, + WINE_GLASS = 46, + CUP = 47, + FORK = 48, + KNIFE = 49, + SPOON = 50, + BOWL = 51, + BANANA = 52, + APPLE = 53, + SANDWICH = 54, + ORANGE = 55, + BROCCOLI = 56, + CARROT = 57, + HOT_DOG = 58, + PIZZA = 59, + DONUT = 60, + CAKE = 61, + CHAIR = 62, + COUCH = 63, + POTTED_PLANT = 64, + BED = 65, + _RESERVED_66 = 66, + DINING_TABLE = 67, + _RESERVED_68 = 68, + _RESERVED_69 = 69, + TOILET = 70, + _RESERVED_71 = 71, + TV = 72, + LAPTOP = 73, + MOUSE = 74, + REMOTE = 75, + KEYBOARD = 76, + CELL_PHONE = 77, + MICROWAVE = 78, + OVEN = 79, + TOASTER = 80, + SINK = 81, + REFRIGERATOR = 82, + _RESERVED_83 = 83, + BOOK = 84, + CLOCK = 85, + VASE = 86, + SCISSORS = 87, + TEDDY_BEAR = 88, + HAIR_DRIER = 89, + TOOTHBRUSH 
= 90, } /** From 4c7dfade3170a259402a6a79d5dfab7f0773e3ab Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 16 Feb 2026 20:56:07 +0100 Subject: [PATCH 03/27] delete stuff --- .../GenericImageSegmentation.ts | 68 ------------------- 1 file changed, 68 deletions(-) delete mode 100644 packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts diff --git a/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts b/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts deleted file mode 100644 index aeee4dec0..000000000 --- a/packages/react-native-executorch/src/modules/computer_vision/GenericImageSegmentation.ts +++ /dev/null @@ -1,68 +0,0 @@ -import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; -import { DeeplabLabel } from '../../types/imageSegmentation'; -import { BaseModule } from '../BaseModule'; - -/** - * Module for image segmentation tasks. - * - * @category Typescript API - */ - -// Allow string or number values (standard Enums use numbers) -type LabelMap = Record; - -export class ImageSegmentationModule extends BaseModule { - async load( - modelSource: ResourceSource, - onDownloadProgressCallback: (progress: number) => void = () => {} - ) { - // Implementation of model loading... - } - - /** - * Generic forward pass that accepts a custom Label Enum. - * * @param imageSource - Path to the image. - * @param labelMap - The runtime Enum object (e.g., DeeplabLabel or a custom object). - * @param classesOfInterest - Array of keys from the provided Enum (e.g., ['PERSON', 'DOG']). - * @param resizeToInput - Whether to resize output to input dimensions. - */ - public async forwardGeneric( - imageSource: string, - labelMap: T, - classesOfInterest: (keyof T)[], - resizeToInput: boolean = true - ): Promise> { - // 1. 
Convert the string keys (e.g., "PERSON") to their numeric indices (e.g., 15) - // We use the runtime 'labelMap' object to look up the values. - const classIndices = (classesOfInterest || []).map( - (label) => labelMap[label] - ); - - // 2. Call the native module with the numeric indices - const result = await this.nativeModule.generate( - imageSource, - classIndices, - resizeToInput - ); - - return result; - } - - /** - * Convenience wrapper for the default DeeplabLabel model. - */ - public async forward( - imageSource: string, - classesOfInterest: (keyof typeof DeeplabLabel)[], - resizeToInput: boolean = true - ) { - // Passes the default DeeplabLabel enum automatically - return this.forwardGeneric( - imageSource, - DeeplabLabel, - classesOfInterest, - resizeToInput - ); - } -} From 4ae6370bfde6bd353819809f4305fafd02232224 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 08:37:05 +0100 Subject: [PATCH 04/27] a bunch of cleanups, add wip for a hook --- .../data_processing/ImageProcessing.cpp | 8 +- .../data_processing/ImageProcessing.h | 4 +- .../BaseImageSegmentation.cpp | 139 ++++------- .../BaseImageSegmentation.h | 11 +- .../src/constants/commonVision.ts | 4 + .../computer_vision/useImageSegmentation.ts | 113 ++++++++- packages/react-native-executorch/src/index.ts | 3 +- .../ImageSegmentationModule.ts | 221 ++++++++++++++---- .../computer_vision/NewImageSegmentation.ts | 186 --------------- .../src/types/common.ts | 15 ++ .../src/types/imageSegmentation.ts | 208 +++++------------ 11 files changed, 420 insertions(+), 492 deletions(-) create mode 100644 packages/react-native-executorch/src/constants/commonVision.ts delete mode 100644 packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index bdf3f97cb..bd29500b0 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -217,7 +217,8 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize) { std::pair readImageToTensor(const std::string &path, const std::vector &tensorDims, - bool maintainAspectRatio) { + bool maintainAspectRatio, std::optional normMean, + std::optional normStd) { cv::Mat input = image_processing::readImage(path); cv::Size imageSize = input.size(); @@ -241,6 +242,11 @@ readImageToTensor(const std::string &path, cv::cvtColor(input, input, cv::COLOR_BGR2RGB); + if (normMean.has_value() && normStd.has_value()) { + return {image_processing::getTensorFromMatrix( + tensorDims, input, normMean.value(), normStd.value()), + imageSize}; + } return {image_processing::getTensorFromMatrix(tensorDims, input), imageSize}; } } // namespace image_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h index 27934330a..1b0c10b33 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h @@ -51,5 +51,7 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize); std::pair readImageToTensor(const std::string &path, const std::vector &tensorDims, - bool maintainAspectRatio = false); + bool maintainAspectRatio = false, + std::optional normMean = std::nullopt, + std::optional normStd = std::nullopt); } // namespace rnexecutorch::image_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index 5496c868f..67790109f 100644 --- 
a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -1,13 +1,10 @@ #include "BaseImageSegmentation.h" #include "jsi/jsi.h" -#include #include #include -#include #include -#include #include #include #include @@ -57,16 +54,8 @@ void BaseImageSegmentation::initModelImageSize() { TensorPtr BaseImageSegmentation::preprocess(const std::string &imageSource, cv::Size &originalSize) { - if (normMean_.has_value() && normStd_.has_value()) { - cv::Mat input = image_processing::readImage(imageSource); - originalSize = input.size(); - cv::resize(input, input, modelImageSize); - cv::cvtColor(input, input, cv::COLOR_BGR2RGB); - return image_processing::getTensorFromMatrix( - getAllInputShapes()[0], input, normMean_.value(), normStd_.value()); - } - auto [inputTensor, origSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); + auto [inputTensor, origSize] = image_processing::readImageToTensor( + imageSource, getAllInputShapes()[0], false, normMean_, normStd_); originalSize = origSize; return inputTensor; } @@ -92,103 +81,73 @@ std::shared_ptr BaseImageSegmentation::generate( std::shared_ptr BaseImageSegmentation::postprocess( const Tensor &tensor, cv::Size originalSize, - std::vector allClasses, - std::set> classesOfInterest, bool resize) { + std::vector &allClasses, + std::set> &classesOfInterest, bool resize) { auto dataPtr = static_cast(tensor.const_data_ptr()); auto resultData = std::span(dataPtr, tensor.numel()); - // Infer output pixel count and channel count. - // If output spatial dims differ from input (e.g. model downsamples), - // derive pixel count from the tensor and allClasses.size(). 
- size_t numOutputChannels = tensor.numel() / numModelPixels; - size_t outputPixels = numModelPixels; - if (numOutputChannels != 1 && numOutputChannels != allClasses.size() && - !allClasses.empty() && tensor.numel() % allClasses.size() == 0) { - outputPixels = tensor.numel() / allClasses.size(); - numOutputChannels = allClasses.size(); - } - auto outputSide = static_cast(std::sqrt(outputPixels)); - cv::Size outputSize(outputSide, outputSide); - - std::vector> resultClasses; - auto argmax = - std::make_shared(outputPixels * sizeof(int32_t)); - - if (numOutputChannels == 1) { - // Binary segmentation path (e.g. selfie segmentation) - // The single channel contains probability values in [0, 1] - // Synthesize two class buffers: background (1-p) and foreground (p) - resultClasses.reserve(2); - - auto bgBuffer = - std::make_shared(outputPixels * sizeof(float)); - auto fgBuffer = - std::make_shared(outputPixels * sizeof(float)); - - auto *bgData = reinterpret_cast(bgBuffer->data()); - auto *fgData = reinterpret_cast(fgBuffer->data()); - auto *argmaxData = reinterpret_cast(argmax->data()); - + // Read output dimensions directly from tensor shape + std::size_t numChannels = + (tensor.dim() >= 3) ? tensor.size(tensor.dim() - 3) : 1; + std::size_t outputH = tensor.size(tensor.dim() - 2); + std::size_t outputW = tensor.size(tensor.dim() - 1); + std::size_t outputPixels = outputH * outputW; + cv::Size outputSize(static_cast(outputW), static_cast(outputH)); + + // Work with vectors, only wrap into OwningArrayBuffer at the end + std::vector> classBuffers; + std::vector argmaxData(outputPixels); + + if (numChannels == 1) { + // Binary segmentation (e.g. selfie segmentation) + std::vector bg(outputPixels); + std::vector fg(outputPixels); for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { float p = resultData[pixel]; - bgData[pixel] = 1.0f - p; - fgData[pixel] = p; + bg[pixel] = 1.0f - p; + fg[pixel] = p; argmaxData[pixel] = (p > 0.5f) ? 
1 : 0; } - - resultClasses.push_back(bgBuffer); - resultClasses.push_back(fgBuffer); - } else if (numOutputChannels == allClasses.size()) { - // Multi-class segmentation path (e.g. DeepLab-v3) - // Copy per-class buffers from the ET-owned tensor data - resultClasses.reserve(allClasses.size()); - for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { - auto classBuffer = std::make_shared( - &resultData[cl * outputPixels], outputPixels * sizeof(float)); - resultClasses.push_back(classBuffer); + classBuffers = {std::move(bg), std::move(fg)}; + } else { + // Multi-class segmentation (e.g. DeepLab, RF-DETR) + classBuffers.resize(numChannels); + for (std::size_t cl = 0; cl < numChannels; ++cl) { + classBuffers[cl].assign(&resultData[cl * outputPixels], + &resultData[(cl + 1) * outputPixels]); } - // Apply softmax per each pixel across all classes + // Apply softmax and compute argmax per pixel for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { - std::vector classValues(allClasses.size()); - for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { - classValues[cl] = - reinterpret_cast(resultClasses[cl]->data())[pixel]; + std::vector values(numChannels); + for (std::size_t cl = 0; cl < numChannels; ++cl) { + values[cl] = classBuffers[cl][pixel]; } - numerical::softmax(classValues); - for (std::size_t cl = 0; cl < allClasses.size(); ++cl) { - reinterpret_cast(resultClasses[cl]->data())[pixel] = - classValues[cl]; - } - } + numerical::softmax(values); - // Calculate the maximum class for each pixel - auto *argmaxData = reinterpret_cast(argmax->data()); - for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { - float max = reinterpret_cast(resultClasses[0]->data())[pixel]; + float maxVal = values[0]; int maxInd = 0; - for (std::size_t cl = 1; cl < allClasses.size(); ++cl) { - if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { + for (std::size_t cl = 0; cl < numChannels; ++cl) { + classBuffers[cl][pixel] = values[cl]; + if (values[cl] > maxVal) 
{ + maxVal = values[cl]; maxInd = static_cast(cl); - max = reinterpret_cast(resultClasses[cl]->data())[pixel]; } } argmaxData[pixel] = maxInd; } - } else { - char errorMessage[200]; - std::snprintf( - errorMessage, sizeof(errorMessage), - "Unexpected number of output channels: %zu. Expected 1 (binary) or " - "%zu (matching allClasses). Model output has %zu elements for %zu " - "pixels.", - numOutputChannels, allClasses.size(), tensor.numel(), outputPixels); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); } - // Filter classes of interest using allClasses labels + // Wrap into OwningArrayBuffers + auto argmax = std::make_shared(argmaxData); + std::vector> resultClasses; + resultClasses.reserve(classBuffers.size()); + for (auto &buf : classBuffers) { + resultClasses.push_back(std::make_shared(buf)); + } + + // Filter classes of interest auto buffersToReturn = std::make_shared>>(); for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { @@ -212,6 +171,7 @@ std::shared_ptr BaseImageSegmentation::postprocess( classMat.data, originalSize.area() * sizeof(float)); } } + return populateDictionary(argmax, buffersToReturn); } @@ -220,7 +180,6 @@ std::shared_ptr BaseImageSegmentation::populateDictionary( std::shared_ptr>> classesToOutput) { - // Synchronize the invoked thread to return when the dict is constructed auto promisePtr = std::make_shared>(); std::future doneFuture = promisePtr->get_future(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h index 3444342a5..baf3872a9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h @@ -35,8 +35,8 @@ class BaseImageSegmentation : public BaseModel { 
cv::Size &originalSize); virtual std::shared_ptr postprocess(const Tensor &tensor, cv::Size originalSize, - std::vector allClasses, - std::set> classesOfInterest, + std::vector &allClasses, + std::set> &classesOfInterest, bool resize); cv::Size modelImageSize; @@ -44,15 +44,14 @@ class BaseImageSegmentation : public BaseModel { std::optional normMean_; std::optional normStd_; -private: - void initModelImageSize(); - -protected: std::shared_ptr populateDictionary( std::shared_ptr argmax, std::shared_ptr>> classesToOutput); + +private: + void initModelImageSize(); }; } // namespace models::image_segmentation } // namespace rnexecutorch diff --git a/packages/react-native-executorch/src/constants/commonVision.ts b/packages/react-native-executorch/src/constants/commonVision.ts new file mode 100644 index 000000000..4e2f29473 --- /dev/null +++ b/packages/react-native-executorch/src/constants/commonVision.ts @@ -0,0 +1,4 @@ +import { Triple } from '../types/common'; + +export const IMAGENET_MEAN: Triple = [0.485, 0.456, 0.406]; +export const IMAGENET_STD: Triple = [0.229, 0.224, 0.225]; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index c7a352e9a..05af520ce 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -1,23 +1,112 @@ -import { useModule } from '../useModule'; -import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegmentationModule'; +import { useState, useEffect } from 'react'; import { + ImageSegmentation, + SegmentationLabels, +} from '../../modules/computer_vision/ImageSegmentationModule'; +import { + ImageSegmentationForwardReturn, ImageSegmentationProps, - ImageSegmentationType, + ModelNameOf, + ModelSources, } from '../../types/imageSegmentation'; +import { 
RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; /** * React hook for managing an Image Segmentation model instance. * + * @typeParam C - A {@link ModelSources} config specifying which built-in model to load. + * @param props - Configuration object containing `model` config and optional `preventLoad` flag. + * @returns An object with model state (`error`, `isReady`, `isGenerating`, `downloadProgress`) and a typed `forward` function. + * + * @example + * ```ts + * const { isReady, forward } = useImageSegmentation({ + * model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + * }); + * ``` + * * @category Hooks - * @param ImageSegmentationProps - Configuration object containing `model` source and optional `preventLoad` flag. - * @returns Ready to use Image Segmentation model. */ -export const useImageSegmentation = ({ +export const useImageSegmentation = ({ model, preventLoad = false, -}: ImageSegmentationProps): ImageSegmentationType => - useModule({ - module: ImageSegmentationModule, - model, - preventLoad, - }); +}: ImageSegmentationProps) => { + const [error, setError] = useState(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [instance, setInstance] = useState + > | null>(null); + + useEffect(() => { + if (preventLoad) return; + + let currentInstance: ImageSegmentation> | null = null; + + (async () => { + setDownloadProgress(0); + setError(null); + setIsReady(false); + try { + currentInstance = await ImageSegmentation.fromModelName( + model, + setDownloadProgress + ); + setInstance(currentInstance); + setIsReady(true); + } catch (err) { + setError(parseUnknownError(err)); + } + })(); + + return () => { + currentInstance?.delete(); + }; + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [model.modelName, 
model.modelSource, preventLoad]); + + const forward = async ( + imageSource: string, + classesOfInterest: ( + | keyof SegmentationLabels> + | 'ARGMAX' + )[] = ['ARGMAX'], + resizeToInput: boolean = true + ): Promise< + ImageSegmentationForwardReturn>> + > => { + if (!isReady || !instance) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' + ); + } + if (isGenerating) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + } + try { + setIsGenerating(true); + return await instance.forward( + imageSource, + classesOfInterest, + resizeToInput + ); + } finally { + setIsGenerating(false); + } + }; + + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + }; +}; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 0e9f09ba6..b33062346 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -133,8 +133,7 @@ export * from './hooks/general/useExecutorchModule'; export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; -// export * from './modules/computer_vision/ImageSegmentationModule'; -export * from './modules/computer_vision/NewImageSegmentation'; +export * from './modules/computer_vision/ImageSegmentationModule'; export * from './modules/computer_vision/OCRModule'; export * from './modules/computer_vision/VerticalOCRModule'; export * from './modules/computer_vision/ImageEmbeddingsModule'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts 
index ddba7cdb7..4bac7deca 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -1,82 +1,211 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; -import { DeeplabLabel } from '../../types/imageSegmentation'; +import { ResourceSource, LabelEnum } from '../../types/common'; +import { CocoLabel } from '../../types/objectDetection'; +import { + DeeplabLabel, + ImageSegmentationForwardReturn, + ModelNameOf, + ModelSources, + SegmentationConfig, + SegmentationModelName, + SelfieSegmentationLabel, +} from '../../types/imageSegmentation'; +import { IMAGENET_MEAN, IMAGENET_STD } from '../../constants/commonVision'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; +const ModelConfigs = { + 'deeplab-v3': { + labelMap: DeeplabLabel, + preprocessorConfig: undefined, + }, + 'selfie-segmentation': { + labelMap: SelfieSegmentationLabel, + preprocessorConfig: undefined, + }, + 'rfdetr': { + labelMap: CocoLabel, + preprocessorConfig: { normMean: IMAGENET_MEAN, normStd: IMAGENET_STD }, + }, +} as const satisfies Record< + SegmentationModelName, + SegmentationConfig +>; + +/** @internal */ +type ModelConfigsType = typeof ModelConfigs; + +/** + * Resolves the {@link LabelEnum} for a given built-in model name. + * + * @typeParam M - A built-in model name from {@link SegmentationModelName}. + * + * @category Types + */ +export type SegmentationLabels = + ModelConfigsType[M]['labelMap']; + +/** + * @internal + * Resolves the label type: if `T` is a {@link SegmentationModelName}, looks up its labels + * from the built-in config; otherwise uses `T` directly as a {@link LabelEnum}. 
+ */ +type ResolveLabels = + T extends SegmentationModelName ? SegmentationLabels : T; + /** - * Module for image segmentation tasks. + * Generic image segmentation module with type-safe label maps. + * Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, + * or a custom label enum for custom configs. + * + * @typeParam T - Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`) + * or a custom {@link LabelEnum} label map. * * @category Typescript API */ -export class ImageSegmentationModule extends BaseModule { +export class ImageSegmentation< + T extends SegmentationModelName | LabelEnum, +> extends BaseModule { + private labelMap: ResolveLabels; + + private constructor(labelMap: ResolveLabels, nativeModule: unknown) { + super(); + this.labelMap = labelMap; + this.nativeModule = nativeModule; + } + + // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) + override async load() { } + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. - * To track the download progress, supply a callback function `onDownloadProgressCallback`. + * Creates a segmentation instance for a built-in model. + * The config object is discriminated by `modelName` — each model can require different fields. + * + * @param config - A {@link ModelSources} object specifying which model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ImageSegmentation` instance typed to the chosen model's label map. * - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. 
+ * @example + * ```ts + * const segmentation = await ImageSegmentation.fromModelName({ + * modelName: 'deeplab-v3', + * modelSource: 'https://example.com/deeplab.pte', + * }); + * ``` */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { - try { - const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource - ); - if (!paths?.[0]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' - ); - } + static async fromModelName( + config: C, + onDownloadProgress: (progress: number) => void = () => { } + ): Promise>> { + const { modelName, modelSource } = config; + const { labelMap, preprocessorConfig } = ModelConfigs[modelName]; + const normMean = [...(preprocessorConfig?.normMean ?? [])]; + const normStd = [...(preprocessorConfig?.normStd ?? [])]; + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (!paths?.[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' + ); + } + const nativeModule = global.loadImageSegmentation( + paths[0] || '', + normMean, + normStd + ); + return new ImageSegmentation>( + labelMap as ResolveLabels>, + nativeModule + ); + } - this.nativeModule = global.loadImageSegmentation(paths[0]); - } catch (error) { - Logger.error('Load failed:', error); - throw parseUnknownError(error); + /** + * Creates a segmentation instance with a user-provided label map and custom config. + * Use this when working with a custom-exported segmentation model that is not one of the built-in models. + * + * @param modelSource - A fetchable resource pointing to the model binary. 
+ * @param config - A {@link SegmentationConfig} object with the label map and optional preprocessing parameters. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ImageSegmentation` instance typed to the provided label map. + * + * @example + * ```ts + * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; + * const segmentation = await ImageSegmentation.fromCustomConfig( + * 'https://example.com/custom_model.pte', + * { labelMap: MyLabels }, + * ); + * ``` + */ + static async fromCustomConfig( + modelSource: ResourceSource, + config: SegmentationConfig, + onDownloadProgress: (progress: number) => void = () => { } + ): Promise> { + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (paths === null || !paths[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); } + const normMean = config.preprocessorConfig?.normMean ?? []; + const normStd = config.preprocessorConfig?.normStd ?? []; + const nativeModule = global.loadImageSegmentation( + paths[0], + [...normMean], + [...normStd] + ); + return new ImageSegmentation( + config.labelMap as ResolveLabels, + nativeModule + ); } /** - * Executes the model's forward pass + * Executes the model's forward pass to perform semantic segmentation on the provided image. * - * @param imageSource - a fetchable resource or a Base64-encoded string. - * @param classesOfInterest - an optional list of DeeplabLabel used to indicate additional arrays of probabilities to output (see section "Running the model"). The default is an empty list. - * @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`. 
- * @returns A dictionary where keys are `DeeplabLabel` and values are arrays of probabilities for each pixel belonging to the corresponding class. + * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). + * @param classesOfInterest - An optional list of label keys (or `'ARGMAX'`) indicating which per-class probability masks to include in the output. Defaults to an empty list (only `ARGMAX` is always returned). + * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. + * @returns A Promise resolving to an object mapping each requested class label (and `'ARGMAX'`) to a flat array of per-pixel values. + * @throws {RnExecutorchError} If the model is not loaded. */ async forward( imageSource: string, - classesOfInterest?: DeeplabLabel[], - resizeToInput?: boolean - ): Promise>> { + classesOfInterest: (keyof ResolveLabels | 'ARGMAX')[] = ['ARGMAX'], + resizeToInput: boolean = true + ): Promise>> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' + 'The model is currently not loaded.' ); } - const stringDict = await this.nativeModule.generate( - imageSource, - (classesOfInterest || []).map((label) => DeeplabLabel[label]), - resizeToInput ?? 
true + const allClassNames = Object.keys(this.labelMap).filter((k) => + isNaN(Number(k)) + ); + const classesOfInterestNames = classesOfInterest.map((label) => + String(label) ); - let enumDict: { [key in DeeplabLabel]?: number[] } = {}; + const nativeResult = await this.nativeModule.generate( + imageSource, + allClassNames, + classesOfInterestNames, + resizeToInput + ); - for (const key in stringDict) { - if (key in DeeplabLabel) { - const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel]; - enumDict[enumKey] = stringDict[key]; + const result: ImageSegmentationForwardReturn> = {}; + for (const [key, maskData] of Object.entries(nativeResult)) { + if (key in this.labelMap || key === 'ARGMAX') { + result[key as keyof ResolveLabels] = maskData as number[]; } } - return enumDict; + return result; } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts b/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts deleted file mode 100644 index 765a92c9a..000000000 --- a/packages/react-native-executorch/src/modules/computer_vision/NewImageSegmentation.ts +++ /dev/null @@ -1,186 +0,0 @@ -import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; -import { CocoLabel, DeeplabLabel } from '../../types/imageSegmentation'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { RnExecutorchError } from '../../errors/errorUtils'; -import { BaseModule } from '../BaseModule'; - -enum SelfieSegmentationLabel { - BACKGROUND, - SELFIE, -} -type Enumish = Readonly>; -export type Triple = readonly [T, T, T]; - -type SegmentationConfig = { - labelMap: T; - preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; -}; - -type ForwardReturnWithArgmax = Partial< - Record ->; - -const IMAGENET_MEAN: Triple = [0.485, 0.456, 0.406]; -const IMAGENET_STD: Triple = [0.229, 0.224, 0.225]; - -const ModelConfigs = { - 'deeplab-v3': { - 
labelMap: DeeplabLabel, - }, - 'selfie-segmentation': { - labelMap: SelfieSegmentationLabel, - }, - 'rfdetr': { - labelMap: CocoLabel, - preprocessorConfig: { normMean: IMAGENET_MEAN, normStd: IMAGENET_STD }, - }, -} as const satisfies Record>; - -type ModelConfigsType = typeof ModelConfigs; -type ModelName = keyof ModelConfigsType; - -export type SegmentationLabels = - ModelConfigsType[M]['labelMap']; - -/** - * Resolves the label type: if T is a ModelName, look up its labels; otherwise use T directly as an Enumish. - */ -type ResolveLabels = T extends ModelName - ? SegmentationLabels - : T; - -/** - * Per-model config for `fromModelName`. Each model name maps to its required fields. - * Add new union members here when a model needs extra sources or options. - */ -type ModelSources = - | { modelName: 'deeplab-v3'; modelSource: ResourceSource } - | { modelName: 'selfie-segmentation'; modelSource: ResourceSource } - | { modelName: 'rfdetr'; modelSource: ResourceSource }; - -/** - * Extract the model name from a config object. - */ -type ModelNameOf = C['modelName']; - -/** - * Generic image segmentation module with type-safe label maps. - * Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, - * or a custom label enum for custom configs. - */ -export class ImageSegmentation< - T extends ModelName | Enumish, -> extends BaseModule { - private labelMap: ResolveLabels; - - private constructor(labelMap: ResolveLabels, nativeModule: unknown) { - super(); - this.labelMap = labelMap; - this.nativeModule = nativeModule; - } - - // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) - async load() {} - - /** - * Creates a segmentation instance for a known model. - * The config object is discriminated by `modelName` — each model can require different fields. 
- */ - static async fromModelName( - config: C, - onDownloadProgress: (progress: number) => void = () => {} - ): Promise>> { - const { modelName, modelSource } = config; - const modelConfig = ModelConfigs[modelName]; - const { labelMap } = modelConfig; - const preprocessorConfig = - 'preprocessorConfig' in modelConfig - ? modelConfig.preprocessorConfig - : undefined; - const normMean = [...(preprocessorConfig?.normMean ?? [])]; - const normStd = [...(preprocessorConfig?.normStd ?? [])]; - const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); - if (paths === null || paths.length < 1) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. Please retry.' - ); - } - const nativeModule = global.loadImageSegmentation( - paths[0] || '', - normMean, - normStd - ); - return new ImageSegmentation>( - labelMap as ResolveLabels>, - nativeModule - ); - } - - /** - * Creates a segmentation instance with a user-provided label map and custom config. - */ - static async fromCustomConfig( - modelSource: ResourceSource, - config: SegmentationConfig, - onDownloadProgress: (progress: number) => void = () => {} - ): Promise> { - const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); - if (paths === null || paths.length < 1) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. Please retry.' - ); - } - const normMean = config.preprocessorConfig?.normMean ?? []; - const normStd = config.preprocessorConfig?.normStd ?? []; - const nativeModule = global.loadImageSegmentation( - paths[0] || '', - [...normMean], - [...normStd] - ); - return new ImageSegmentation( - config.labelMap as ResolveLabels, - nativeModule - ); - } - - /** - * Executes the model's forward pass. 
- */ - async forward( - imageSource: string, - classesOfInterest: (keyof ResolveLabels | 'ARGMAX')[] = [], - resizeToInput: boolean = true - ): Promise>> { - if (this.nativeModule == null) { - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded.' - ); - } - - const allClassNames = Object.keys(this.labelMap).filter((k) => - isNaN(Number(k)) - ); - const classesOfInterestNames = classesOfInterest.map((label) => - String(label) - ); - - const nativeResult = await this.nativeModule.generate( - imageSource, - allClassNames, - classesOfInterestNames, - resizeToInput - ); - - const result: ForwardReturnWithArgmax> = {}; - for (const [key, maskData] of Object.entries(nativeResult)) { - if (key in this.labelMap || key === 'ARGMAX') { - result[key as keyof ResolveLabels] = maskData as number[]; - } - } - return result; - } -} diff --git a/packages/react-native-executorch/src/types/common.ts b/packages/react-native-executorch/src/types/common.ts index 7b87f31b6..384caa861 100644 --- a/packages/react-native-executorch/src/types/common.ts +++ b/packages/react-native-executorch/src/types/common.ts @@ -136,3 +136,18 @@ export interface TensorPtr { sizes: number[]; scalarType: ScalarType; } + +/** + * A readonly record mapping string keys to numeric or string values. + * Used to represent enum-like label maps for models. + * + * @category Types + */ +export type LabelEnum = Readonly>; + +/** + * A readonly triple of values, typically used for per-channel normalization parameters. 
+ * + * @category Types + */ +export type Triple = readonly [T, T, T]; diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 3160ec5b7..856ff79f1 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -1,5 +1,59 @@ -import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { LabelEnum, Triple, ResourceSource } from './common'; + +/** + * Configuration for a custom segmentation model. + * + * @typeParam T - The {@link LabelEnum} type for the model. + * @property labelMap - The enum-like object mapping class names to indices. + * @property preprocessorConfig - Optional preprocessing parameters. + * @property preprocessorConfig.normMean - Per-channel mean values for input normalization. + * @property preprocessorConfig.normStd - Per-channel standard deviation values for input normalization. + * + * @category Types + */ +export type SegmentationConfig = { + labelMap: T; + preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; +}; + +/** + * Return type for the segmentation model's forward pass. + * Maps class label keys (and optionally `"ARGMAX"`) to flat arrays of per-pixel values. + * + * @typeParam C - The {@link LabelEnum} type for the model. + * + * @category Types + */ +export type ImageSegmentationForwardReturn = Partial< + Record +>; + +/** + * Per-model config for {@link ImageSegmentation.fromModelName}. + * Each model name maps to its required fields. + * Add new union members here when a model needs extra sources or options. 
+ * + * @category Types + */ +export type ModelSources = + | { modelName: 'deeplab-v3'; modelSource: ResourceSource } + | { modelName: 'selfie-segmentation'; modelSource: ResourceSource } + | { modelName: 'rfdetr'; modelSource: ResourceSource }; + +/** + * Union of all built-in segmentation model names + * (e.g. `'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`). + * + * @category Types + */ +export type SegmentationModelName = ModelSources['modelName']; + +/** + * Extracts the model name from a {@link ModelSources} config object. + * + * @category Types + */ +export type ModelNameOf = C['modelName']; /** * Labels used in the DeepLab image segmentation model. @@ -40,158 +94,16 @@ export enum SelfieSegmentationLabel { BACKGROUND, } -/** - * COCO 91-class labels used by RF-DETR segmentation models. - * Indices match the model's 91 output channels. - * - * @category Types - */ -export enum CocoLabel { - BACKGROUND = 0, - PERSON = 1, - BICYCLE = 2, - CAR = 3, - MOTORCYCLE = 4, - AIRPLANE = 5, - BUS = 6, - TRAIN = 7, - TRUCK = 8, - BOAT = 9, - TRAFFIC_LIGHT = 10, - FIRE_HYDRANT = 11, - _RESERVED_12 = 12, - STOP_SIGN = 13, - PARKING_METER = 14, - BENCH = 15, - BIRD = 16, - CAT = 17, - DOG = 18, - HORSE = 19, - SHEEP = 20, - COW = 21, - ELEPHANT = 22, - BEAR = 23, - ZEBRA = 24, - GIRAFFE = 25, - _RESERVED_26 = 26, - BACKPACK = 27, - UMBRELLA = 28, - _RESERVED_29 = 29, - _RESERVED_30 = 30, - HANDBAG = 31, - TIE = 32, - SUITCASE = 33, - FRISBEE = 34, - SKIS = 35, - SNOWBOARD = 36, - SPORTS_BALL = 37, - KITE = 38, - BASEBALL_BAT = 39, - BASEBALL_GLOVE = 40, - SKATEBOARD = 41, - SURFBOARD = 42, - TENNIS_RACKET = 43, - BOTTLE = 44, - _RESERVED_45 = 45, - WINE_GLASS = 46, - CUP = 47, - FORK = 48, - KNIFE = 49, - SPOON = 50, - BOWL = 51, - BANANA = 52, - APPLE = 53, - SANDWICH = 54, - ORANGE = 55, - BROCCOLI = 56, - CARROT = 57, - HOT_DOG = 58, - PIZZA = 59, - DONUT = 60, - CAKE = 61, - CHAIR = 62, - COUCH = 63, - POTTED_PLANT = 64, - BED = 65, - _RESERVED_66 = 66, - 
DINING_TABLE = 67, - _RESERVED_68 = 68, - _RESERVED_69 = 69, - TOILET = 70, - _RESERVED_71 = 71, - TV = 72, - LAPTOP = 73, - MOUSE = 74, - REMOTE = 75, - KEYBOARD = 76, - CELL_PHONE = 77, - MICROWAVE = 78, - OVEN = 79, - TOASTER = 80, - SINK = 81, - REFRIGERATOR = 82, - _RESERVED_83 = 83, - BOOK = 84, - CLOCK = 85, - VASE = 86, - SCISSORS = 87, - TEDDY_BEAR = 88, - HAIR_DRIER = 89, - TOOTHBRUSH = 90, -} - /** * Props for the `useImageSegmentation` hook. * - * @property {Object} model - An object containing the model source. - * @property {ResourceSource} model.modelSource - The source of the image segmentation model binary. + * @typeParam C - A {@link ModelSources} config specifying which built-in model to load. + * @property model - The model config containing `modelName` and `modelSource`. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. * * @category Types */ -export interface ImageSegmentationProps { - model: { modelSource: ResourceSource }; +export interface ImageSegmentationProps { + model: C; preventLoad?: boolean; } - -/** - * Return type for the `useImageSegmentation` hook. - * Manages the state and operations for Computer Vision image segmentation (e.g., DeepLab). - * - * @category Types - */ -export interface ImageSegmentationType { - /** - * Contains the error object if the model failed to load, download, or encountered a runtime error during segmentation. - */ - error: RnExecutorchError | null; - - /** - * Indicates whether the segmentation model is loaded and ready to process images. - */ - isReady: boolean; - - /** - * Indicates whether the model is currently processing an image. - */ - isGenerating: boolean; - - /** - * Represents the download progress of the model binary as a value between 0 and 1. 
- */ - downloadProgress: number; - - /** - * Executes the model's forward pass to perform semantic segmentation on the provided image. - * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. - * @param classesOfInterest - An optional array of `DeeplabLabel` enums. If provided, the model will only return segmentation masks for these specific classes. - * @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`. - * @returns A Promise that resolves to an object mapping each detected `DeeplabLabel` to its corresponding segmentation mask (represented as a flattened array of numbers). - * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. - */ - forward: ( - imageSource: string, - classesOfInterest?: DeeplabLabel[], - resizeToInput?: boolean - ) => Promise>>; -} From f81b0745d85740013557e311918a152ad0c76eb6 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 09:04:42 +0100 Subject: [PATCH 05/27] type fix --- .../computer_vision/useImageSegmentation.ts | 14 +++++--------- .../computer_vision/ImageSegmentationModule.ts | 17 ++++++++--------- .../src/types/imageSegmentation.ts | 12 ------------ 3 files changed, 13 insertions(+), 30 deletions(-) diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index 05af520ce..9e56276cd 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -4,7 +4,6 @@ import { SegmentationLabels, } from '../../modules/computer_vision/ImageSegmentationModule'; import { - 
ImageSegmentationForwardReturn, ImageSegmentationProps, ModelNameOf, ModelSources, @@ -68,16 +67,13 @@ export const useImageSegmentation = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [model.modelName, model.modelSource, preventLoad]); - const forward = async ( + const forward = async < + K extends keyof SegmentationLabels> | 'ARGMAX' = 'ARGMAX', + >( imageSource: string, - classesOfInterest: ( - | keyof SegmentationLabels> - | 'ARGMAX' - )[] = ['ARGMAX'], + classesOfInterest: K[] = ['ARGMAX' as K], resizeToInput: boolean = true - ): Promise< - ImageSegmentationForwardReturn>> - > => { + ): Promise> => { if (!isReady || !instance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 4bac7deca..a6588a09c 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -3,7 +3,6 @@ import { ResourceSource, LabelEnum } from '../../types/common'; import { CocoLabel } from '../../types/objectDetection'; import { DeeplabLabel, - ImageSegmentationForwardReturn, ModelNameOf, ModelSources, SegmentationConfig, @@ -146,7 +145,7 @@ export class ImageSegmentation< onDownloadProgress: (progress: number) => void = () => { } ): Promise> { const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); - if (paths === null || !paths[0]) { + if (paths === null || paths.length < 1) { throw new RnExecutorchError( RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. Please retry.' @@ -155,7 +154,7 @@ export class ImageSegmentation< const normMean = config.preprocessorConfig?.normMean ?? []; const normStd = config.preprocessorConfig?.normStd ?? 
[]; const nativeModule = global.loadImageSegmentation( - paths[0], + paths[0]!, [...normMean], [...normStd] ); @@ -174,11 +173,11 @@ export class ImageSegmentation< * @returns A Promise resolving to an object mapping each requested class label (and `'ARGMAX'`) to a flat array of per-pixel values. * @throws {RnExecutorchError} If the model is not loaded. */ - async forward( + async forward | 'ARGMAX' = 'ARGMAX'>( imageSource: string, - classesOfInterest: (keyof ResolveLabels | 'ARGMAX')[] = ['ARGMAX'], + classesOfInterest: K[] = ['ARGMAX' as K], resizeToInput: boolean = true - ): Promise>> { + ): Promise> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -200,12 +199,12 @@ export class ImageSegmentation< resizeToInput ); - const result: ImageSegmentationForwardReturn> = {}; + const result: Partial> = {}; for (const [key, maskData] of Object.entries(nativeResult)) { if (key in this.labelMap || key === 'ARGMAX') { - result[key as keyof ResolveLabels] = maskData as number[]; + result[key as K] = maskData as number[]; } } - return result; + return result as Record; } } diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 856ff79f1..8573bde05 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -16,18 +16,6 @@ export type SegmentationConfig = { preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; }; -/** - * Return type for the segmentation model's forward pass. - * Maps class label keys (and optionally `"ARGMAX"`) to flat arrays of per-pixel values. - * - * @typeParam C - The {@link LabelEnum} type for the model. - * - * @category Types - */ -export type ImageSegmentationForwardReturn = Partial< - Record ->; - /** * Per-model config for {@link ImageSegmentation.fromModelName}. 
* Each model name maps to its required fields. From a2d477f3565d24140daf4dec46cfd7a7b169c5a7 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 09:16:59 +0100 Subject: [PATCH 06/27] type fix --- .../app/image_segmentation/index.tsx | 18 +++++++++-------- .../computer_vision/useImageSegmentation.ts | 8 +++----- .../ImageSegmentationModule.ts | 20 ++++++++++++------- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 03d719938..26869c7a4 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -1,11 +1,7 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { - ImageSegmentation, - DEEPLAB_V3_RESNET50, - SegmentationLabels, -} from 'react-native-executorch'; +import { ImageSegmentation } from 'react-native-executorch'; import { Canvas, Image as SkiaImage, @@ -63,8 +59,11 @@ export default function ImageSegmentationScreen() { let instance: ImageSegmentation<'deeplab-v3'> | null = null; (async () => { instance = await ImageSegmentation.fromModelName( - DEEPLAB_V3_RESNET50.modelSource, - 'deeplab-v3', + { + modelName: 'deeplab-v3', + modelSource: + 'https://ai.swmansion.com/storage/jc_tests/selfie_seg.pte', + }, setDownloadProgress ); setModel(instance); @@ -89,7 +88,10 @@ export default function ImageSegmentationScreen() { try { setIsGenerating(true); const { width, height } = imageSize; - const output = await model.forward(imageUri, ['dupa'], true); + const t1 = performance.now(); + const output = await model.forward(imageUri, [], true); + const t2 = performance.now(); + console.log(t2 - t1); const argmax = output['ARGMAX'] || []; const pixels = new Uint8Array(width * height * 4); diff --git 
a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index 9e56276cd..3d591abb2 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -67,13 +67,11 @@ export const useImageSegmentation = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [model.modelName, model.modelSource, preventLoad]); - const forward = async < - K extends keyof SegmentationLabels> | 'ARGMAX' = 'ARGMAX', - >( + const forward = async >>( imageSource: string, - classesOfInterest: K[] = ['ARGMAX' as K], + classesOfInterest: K[] = [], resizeToInput: boolean = true - ): Promise> => { + ): Promise> => { if (!isReady || !instance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index a6588a09c..32e75574f 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -151,10 +151,16 @@ export class ImageSegmentation< 'The download has been interrupted. Please retry.' ); } + if (!paths[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + "The download couldn't be completed. Please retry." + ); + } const normMean = config.preprocessorConfig?.normMean ?? []; const normStd = config.preprocessorConfig?.normStd ?? []; const nativeModule = global.loadImageSegmentation( - paths[0]!, + paths[0], [...normMean], [...normStd] ); @@ -168,16 +174,16 @@ export class ImageSegmentation< * Executes the model's forward pass to perform semantic segmentation on the provided image. 
* * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). - * @param classesOfInterest - An optional list of label keys (or `'ARGMAX'`) indicating which per-class probability masks to include in the output. Defaults to an empty list (only `ARGMAX` is always returned). + * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. - * @returns A Promise resolving to an object mapping each requested class label (and `'ARGMAX'`) to a flat array of per-pixel values. + * @returns A Promise resolving to an object mapping `'ARGMAX'` and each requested class label to a flat array of per-pixel values. * @throws {RnExecutorchError} If the model is not loaded. */ - async forward | 'ARGMAX' = 'ARGMAX'>( + async forward>( imageSource: string, - classesOfInterest: K[] = ['ARGMAX' as K], + classesOfInterest: K[] = [], resizeToInput: boolean = true - ): Promise> { + ): Promise> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -205,6 +211,6 @@ export class ImageSegmentation< result[key as K] = maskData as number[]; } } - return result as Record; + return result as Record; } } From 24026089061e3b39b21542daf8acf06f4b8c5ea5 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 09:26:01 +0100 Subject: [PATCH 07/27] remove useless stuff --- packages/react-native-executorch/src/hooks/useModule.ts | 7 ------- 1 file changed, 7 deletions(-) diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index 1140e9cc4..39b10249b 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ 
b/packages/react-native-executorch/src/hooks/useModule.ts @@ -73,12 +73,6 @@ export const useModule = < } }; - const forwardGeneric = async ( - ...input: ForwardArgs - ): Promise => { - return await forward(...input); - }; - return { /** * Contains the error message if the model failed to load. @@ -100,6 +94,5 @@ export const useModule = < */ downloadProgress, forward, - forwardGeneric, }; }; From 51d412aec5ecbd1d173b839812e332bf3a125887 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 11:51:39 +0100 Subject: [PATCH 08/27] hehe --- apps/computer-vision/app/image_segmentation/index.tsx | 2 +- .../modules/computer_vision/ImageSegmentationModule.ts | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 26869c7a4..d9d689432 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -89,7 +89,7 @@ export default function ImageSegmentationScreen() { setIsGenerating(true); const { width, height } = imageSize; const t1 = performance.now(); - const output = await model.forward(imageUri, [], true); + const output = await model.forward(imageUri, ['PERSON'], true); const t2 = performance.now(); console.log(t2 - t1); const argmax = output['ARGMAX'] || []; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 32e75574f..a6e7bbb58 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -111,7 +111,7 @@ export class ImageSegmentation< ); } const nativeModule = global.loadImageSegmentation( - paths[0] || '', + paths[0], normMean, normStd ); @@ -145,18 +145,12 @@ export class 
ImageSegmentation< onDownloadProgress: (progress: number) => void = () => { } ): Promise> { const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); - if (paths === null || paths.length < 1) { + if (!paths?.[0]) { throw new RnExecutorchError( RnExecutorchErrorCode.DownloadInterrupted, 'The download has been interrupted. Please retry.' ); } - if (!paths[0]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - "The download couldn't be completed. Please retry." - ); - } const normMean = config.preprocessorConfig?.normMean ?? []; const normStd = config.preprocessorConfig?.normStd ?? []; const nativeModule = global.loadImageSegmentation( From 9a7d8dc2fd39fb266f781d00387069acd79d5f10 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 13:03:47 +0100 Subject: [PATCH 09/27] rename --- .../app/image_segmentation/index.tsx | 11 +++-- .../computer_vision/useImageSegmentation.ts | 8 ++-- packages/react-native-executorch/src/index.ts | 6 +-- .../ImageSegmentationModule.ts | 41 +++++++++---------- .../src/types/imageSegmentation.ts | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index d9d689432..bd049ee48 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -1,7 +1,7 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { ImageSegmentation } from 'react-native-executorch'; +import { ImageSegmentationModule } from 'react-native-executorch'; import { Canvas, Image as SkiaImage, @@ -41,9 +41,8 @@ const numberToColor: number[][] = [ export default function ImageSegmentationScreen() { const { setGlobalGenerating } = useContext(GeneratingContext); - const [model, setModel] = useState | null>( - null - ); + const [model, 
setModel] = + useState | null>(null); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); const [imageUri, setImageUri] = useState(''); @@ -56,9 +55,9 @@ export default function ImageSegmentationScreen() { }, [isGenerating, setGlobalGenerating]); useEffect(() => { - let instance: ImageSegmentation<'deeplab-v3'> | null = null; + let instance: ImageSegmentationModule<'deeplab-v3'> | null = null; (async () => { - instance = await ImageSegmentation.fromModelName( + instance = await ImageSegmentationModule.fromModelName( { modelName: 'deeplab-v3', modelSource: diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index 3d591abb2..6bc38c9d5 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -1,6 +1,6 @@ import { useState, useEffect } from 'react'; import { - ImageSegmentation, + ImageSegmentationModule, SegmentationLabels, } from '../../modules/computer_vision/ImageSegmentationModule'; import { @@ -35,21 +35,21 @@ export const useImageSegmentation = ({ const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); - const [instance, setInstance] = useState > | null>(null); useEffect(() => { if (preventLoad) return; - let currentInstance: ImageSegmentation> | null = null; + let currentInstance: ImageSegmentationModule> | null = null; (async () => { setDownloadProgress(0); setError(null); setIsReady(false); try { - currentInstance = await ImageSegmentation.fromModelName( + currentInstance = await ImageSegmentationModule.fromModelName( model, setDownloadProgress ); diff --git a/packages/react-native-executorch/src/index.ts 
b/packages/react-native-executorch/src/index.ts index b33062346..32b97fa58 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -3,7 +3,7 @@ import { ResourceFetcher, ResourceFetcherAdapter, } from './utils/ResourceFetcher'; - +import { Triple } from './types/common'; /** * Configuration that goes to the `initExecutorch`. * You can pass either bare React Native or Expo configuration. @@ -38,8 +38,8 @@ declare global { var loadStyleTransfer: (source: string) => any; var loadImageSegmentation: ( source: string, - normMean: number[], - normStd: number[] + normMean: Triple | [], + normStd: Triple | [] ) => any; var loadClassification: (source: string) => any; var loadObjectDetection: (source: string) => any; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index a6e7bbb58..f236386ec 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -64,14 +64,18 @@ type ResolveLabels = * * @category Typescript API */ -export class ImageSegmentation< +export class ImageSegmentationModule< T extends SegmentationModelName | LabelEnum, > extends BaseModule { private labelMap: ResolveLabels; + private allClassNames: string[]; private constructor(labelMap: ResolveLabels, nativeModule: unknown) { super(); this.labelMap = labelMap; + this.allClassNames = Object.keys(this.labelMap).filter((k) => + isNaN(Number(k)) + ); this.nativeModule = nativeModule; } @@ -84,11 +88,11 @@ export class ImageSegmentation< * * @param config - A {@link ModelSources} object specifying which model to load and where to fetch it from. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. 
- * @returns A Promise resolving to an `ImageSegmentation` instance typed to the chosen model's label map. + * @returns A Promise resolving to an `ImageSegmentationModule` instance typed to the chosen model's label map. * * @example * ```ts - * const segmentation = await ImageSegmentation.fromModelName({ + * const segmentation = await ImageSegmentationModule.fromModelName({ * modelName: 'deeplab-v3', * modelSource: 'https://example.com/deeplab.pte', * }); @@ -98,11 +102,11 @@ export class ImageSegmentation< static async fromModelName( config: C, onDownloadProgress: (progress: number) => void = () => { } - ): Promise>> { + ): Promise>> { const { modelName, modelSource } = config; const { labelMap, preprocessorConfig } = ModelConfigs[modelName]; - const normMean = [...(preprocessorConfig?.normMean ?? [])]; - const normStd = [...(preprocessorConfig?.normStd ?? [])]; + const normMean = preprocessorConfig?.normMean ?? []; + const normStd = preprocessorConfig?.normStd ?? []; const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); if (!paths?.[0]) { throw new RnExecutorchError( @@ -115,7 +119,7 @@ export class ImageSegmentation< normMean, normStd ); - return new ImageSegmentation>( + return new ImageSegmentationModule>( labelMap as ResolveLabels>, nativeModule ); @@ -128,12 +132,12 @@ export class ImageSegmentation< * @param modelSource - A fetchable resource pointing to the model binary. * @param config - A {@link SegmentationConfig} object with the label map and optional preprocessing parameters. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. - * @returns A Promise resolving to an `ImageSegmentation` instance typed to the provided label map. + * @returns A Promise resolving to an `ImageSegmentationModule` instance typed to the provided label map. 
* * @example * ```ts * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; - * const segmentation = await ImageSegmentation.fromCustomConfig( + * const segmentation = await ImageSegmentationModule.fromCustomConfig( * 'https://example.com/custom_model.pte', * { labelMap: MyLabels }, * ); @@ -143,7 +147,7 @@ export class ImageSegmentation< modelSource: ResourceSource, config: SegmentationConfig, onDownloadProgress: (progress: number) => void = () => { } - ): Promise> { + ): Promise> { const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); if (!paths?.[0]) { throw new RnExecutorchError( @@ -155,10 +159,10 @@ export class ImageSegmentation< const normStd = config.preprocessorConfig?.normStd ?? []; const nativeModule = global.loadImageSegmentation( paths[0], - [...normMean], - [...normStd] + normMean, + normStd ); - return new ImageSegmentation( + return new ImageSegmentationModule( config.labelMap as ResolveLabels, nativeModule ); @@ -185,25 +189,20 @@ export class ImageSegmentation< ); } - const allClassNames = Object.keys(this.labelMap).filter((k) => - isNaN(Number(k)) - ); const classesOfInterestNames = classesOfInterest.map((label) => String(label) ); const nativeResult = await this.nativeModule.generate( imageSource, - allClassNames, + this.allClassNames, classesOfInterestNames, resizeToInput ); - const result: Partial> = {}; + const result: Partial> = {}; for (const [key, maskData] of Object.entries(nativeResult)) { - if (key in this.labelMap || key === 'ARGMAX') { - result[key as K] = maskData as number[]; - } + result[key as K | 'ARGMAX'] = maskData as number[]; } return result as Record; } diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 8573bde05..36e25cae6 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -17,7 +17,7 @@ export type 
SegmentationConfig = { }; /** - * Per-model config for {@link ImageSegmentation.fromModelName}. + * Per-model config for {@link ImageSegmentationModule.fromModelName}. * Each model name maps to its required fields. * Add new union members here when a model needs extra sources or options. * From d77560bccd64872659f7c14ef5c50b5b182e8e22 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 13:22:16 +0100 Subject: [PATCH 10/27] use hook in example app --- .../app/image_segmentation/index.tsx | 44 +++++-------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index bd049ee48..921554a09 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -1,7 +1,7 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { ImageSegmentationModule } from 'react-native-executorch'; +import { useImageSegmentation } from 'react-native-executorch'; import { Canvas, Image as SkiaImage, @@ -41,10 +41,13 @@ const numberToColor: number[][] = [ export default function ImageSegmentationScreen() { const { setGlobalGenerating } = useContext(GeneratingContext); - const [model, setModel] = - useState | null>(null); - const [isGenerating, setIsGenerating] = useState(false); - const [downloadProgress, setDownloadProgress] = useState(0); + const { isReady, isGenerating, downloadProgress, forward } = + useImageSegmentation({ + model: { + modelName: 'deeplab-v3', + modelSource: 'https://ai.swmansion.com/storage/jc_tests/selfie_seg.pte', + }, + }); const [imageUri, setImageUri] = useState(''); const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); const [segImage, setSegImage] = useState(null); @@ -54,22 +57,6 @@ export default function ImageSegmentationScreen() { 
setGlobalGenerating(isGenerating); }, [isGenerating, setGlobalGenerating]); - useEffect(() => { - let instance: ImageSegmentationModule<'deeplab-v3'> | null = null; - (async () => { - instance = await ImageSegmentationModule.fromModelName( - { - modelName: 'deeplab-v3', - modelSource: - 'https://ai.swmansion.com/storage/jc_tests/selfie_seg.pte', - }, - setDownloadProgress - ); - setModel(instance); - })(); - return () => instance?.delete(); - }, []); - const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); if (!image?.uri) return; @@ -82,15 +69,10 @@ export default function ImageSegmentationScreen() { }; const runForward = async () => { - if (!model || !imageUri || imageSize.width === 0 || imageSize.height === 0) - return; + if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return; try { - setIsGenerating(true); const { width, height } = imageSize; - const t1 = performance.now(); - const output = await model.forward(imageUri, ['PERSON'], true); - const t2 = performance.now(); - console.log(t2 - t1); + const output = await forward(imageUri, ['PERSON'], true); const argmax = output['ARGMAX'] || []; const pixels = new Uint8Array(width * height * 4); @@ -119,15 +101,13 @@ export default function ImageSegmentationScreen() { setSegImage(img); } catch (e) { console.error(e); - } finally { - setIsGenerating(false); } }; - if (!model) { + if (!isReady) { return ( ); From 407b40f56f85bec1bdc044c74cb59bb48533a053 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 14:40:21 +0100 Subject: [PATCH 11/27] return typed arrays instead of number[] --- .../src/hooks/computer_vision/useImageSegmentation.ts | 2 +- .../computer_vision/ImageSegmentationModule.ts | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index
6bc38c9d5..06ea4c61f 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -71,7 +71,7 @@ export const useImageSegmentation = ({ imageSource: string, classesOfInterest: K[] = [], resizeToInput: boolean = true - ): Promise> => { + ): Promise & Record> => { if (!isReady || !instance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index f236386ec..63e1e9410 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -174,14 +174,14 @@ export class ImageSegmentationModule< * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. - * @returns A Promise resolving to an object mapping `'ARGMAX'` and each requested class label to a flat array of per-pixel values. + * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded. 
*/ async forward>( imageSource: string, classesOfInterest: K[] = [], resizeToInput: boolean = true - ): Promise> { + ): Promise & Record> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -200,10 +200,7 @@ export class ImageSegmentationModule< resizeToInput ); - const result: Partial> = {}; - for (const [key, maskData] of Object.entries(nativeResult)) { - result[key as K | 'ARGMAX'] = maskData as number[]; - } - return result as Record; + return nativeResult as Record<'ARGMAX', Int32Array> & + Record; } } From 72352561a0cffd423b3383ec7841f82d7b4b9c16 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 14:45:38 +0100 Subject: [PATCH 12/27] fix --- .../src/hooks/computer_vision/useImageSegmentation.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index 06ea4c61f..c7321fb80 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -71,7 +71,7 @@ export const useImageSegmentation = ({ imageSource: string, classesOfInterest: K[] = [], resizeToInput: boolean = true - ): Promise & Record> => { + ) => { if (!isReady || !instance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, From 05af3263e307f84d0ddbd92d9000389ae1b7ac8f Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 15:22:23 +0100 Subject: [PATCH 13/27] add new model constants, fix native stuff --- .../image_segmentation/BaseImageSegmentation.cpp | 13 +++++-------- .../src/constants/modelUrls.ts | 12 +++++++++++- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp
b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index 67790109f..ddfb284d0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -39,13 +39,10 @@ void BaseImageSegmentation::initModelImageSize() { } std::vector modelInputShape = inputShapes[0]; if (modelInputShape.size() < 2) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " - "but got: %zu.", - modelInputShape.size()); throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); + "Unexpected model input size, expected at least 2 " + "dimensions but got: " + + std::to_string(modelInputShape.size()) + "."); } modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], modelInputShape[modelInputShape.size() - 2]); @@ -114,8 +111,8 @@ std::shared_ptr BaseImageSegmentation::postprocess( // Multi-class segmentation (e.g. 
DeepLab, RF-DETR) classBuffers.resize(numChannels); for (std::size_t cl = 0; cl < numChannels; ++cl) { - classBuffers[cl].assign(&resultData[cl * outputPixels], - &resultData[(cl + 1) * outputPixels]); + classBuffers[cl].assign(resultData.data() + cl * outputPixels, + resultData.data() + (cl + 1) * outputPixels); } // Apply softmax and compute argmax per pixel diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 6e76e52b7..20cacd050 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -521,8 +521,18 @@ const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${VERSION_TAG}/xnnpa * @category Models - Image Segmentation */ export const DEEPLAB_V3_RESNET50 = { + modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50_MODEL, -}; +} as const; + +const SELFIE_SEGMENTATION_MODEL = `${URL_PREFIX}-selfie-segmentation/${VERSION_TAG}/xnnpack/selfie-segmentation.pte`; +/** + * @category Models - Image Segmentation + */ +export const SELFIE_SEGMENTATION = { + modelName: 'selfie-segmentation', + modelSource: SELFIE_SEGMENTATION_MODEL, +} as const; // Image Embeddings const CLIP_VIT_BASE_PATCH32_IMAGE_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${VERSION_TAG}/clip-vit-base-patch32-vision_xnnpack.pte`; From 918d593509f0e24aae3b9aead265738fab373ef4 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Tue, 17 Feb 2026 15:29:07 +0100 Subject: [PATCH 14/27] add type for hook return --- .../app/image_segmentation/index.tsx | 12 ++--- .../computer_vision/useImageSegmentation.ts | 5 ++- .../src/types/imageSegmentation.ts | 45 +++++++++++++++++++ 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 921554a09..8c6310540 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++
b/apps/computer-vision/app/image_segmentation/index.tsx @@ -1,7 +1,10 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { useImageSegmentation } from 'react-native-executorch'; +import { + DEEPLAB_V3_RESNET50, + useImageSegmentation, +} from 'react-native-executorch'; import { Canvas, Image as SkiaImage, @@ -43,10 +46,7 @@ export default function ImageSegmentationScreen() { const { setGlobalGenerating } = useContext(GeneratingContext); const { isReady, isGenerating, downloadProgress, forward } = useImageSegmentation({ - model: { - modelName: 'deeplab-v3', - modelSource: 'https://ai.swmansion.com/storage/jc_tests/selfie_seg.pte', - }, + model: DEEPLAB_V3_RESNET50, }); const [imageUri, setImageUri] = useState(''); const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); @@ -72,7 +72,7 @@ export default function ImageSegmentationScreen() { if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return; try { const { width, height } = imageSize; - const output = await forward(imageUri, ['PERSON'], true); + const output = await forward(imageUri, [], true); const argmax = output['ARGMAX'] || []; const pixels = new Uint8Array(width * height * 4); diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index c7321fb80..a885dd9b8 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -5,6 +5,7 @@ import { } from '../../modules/computer_vision/ImageSegmentationModule'; import { ImageSegmentationProps, + ImageSegmentationType, ModelNameOf, ModelSources, } from '../../types/imageSegmentation'; @@ -30,7 +31,9 @@ import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; export 
const useImageSegmentation = ({ model, preventLoad = false, -}: ImageSegmentationProps) => { +}: ImageSegmentationProps): ImageSegmentationType< + SegmentationLabels> +> => { const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 36e25cae6..9c8bcc7ae 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -1,3 +1,4 @@ +import { RnExecutorchError } from '../errors/errorUtils'; import { LabelEnum, Triple, ResourceSource } from './common'; /** @@ -95,3 +96,47 @@ export interface ImageSegmentationProps { model: C; preventLoad?: boolean; } + +/** + * Return type for the `useImageSegmentation` hook. + * Manages the state and operations for image segmentation models. + * + * @typeParam L - The {@link LabelEnum} representing the model's class labels. + * + * @category Types + */ +export interface ImageSegmentationType { + /** + * Contains the error object if the model failed to load, download, or encountered a runtime error during segmentation. + */ + error: RnExecutorchError | null; + + /** + * Indicates whether the segmentation model is loaded and ready to process images. + */ + isReady: boolean; + + /** + * Indicates whether the model is currently processing an image. + */ + isGenerating: boolean; + + /** + * Represents the download progress of the model binary as a value between 0 and 1. + */ + downloadProgress: number; + + /** + * Executes the model's forward pass to perform semantic segmentation on the provided image. + * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. 
+ * @param classesOfInterest - An optional array of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. + * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. + * @returns A Promise resolving to an object with an `'ARGMAX'` `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. + * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. + */ + forward: ( + imageSource: string, + classesOfInterest?: K[], + resizeToInput?: boolean + ) => Promise & Record>; +} From 54f2d4d7e892c0f9072cae033a7cfc5c561e8885 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 08:35:06 +0100 Subject: [PATCH 15/27] chore: review suggestions --- .../models/image_segmentation/BaseImageSegmentation.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index ddfb284d0..fb3a6b152 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -33,7 +33,7 @@ BaseImageSegmentation::BaseImageSegmentation( void BaseImageSegmentation::initModelImageSize() { auto inputShapes = getAllInputShapes(); - if (inputShapes.size() == 0) { + if (inputShapes.empty()) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, "Model seems to not take any input tensors."); } @@ -81,7 +81,7 @@ std::shared_ptr BaseImageSegmentation::postprocess( std::vector &allClasses, std::set> &classesOfInterest, bool 
resize) { - auto dataPtr = static_cast(tensor.const_data_ptr()); + const auto *dataPtr = tensor.const_data_ptr(); auto resultData = std::span(dataPtr, tensor.numel()); // Read output dimensions directly from tensor shape @@ -90,7 +90,7 @@ std::shared_ptr BaseImageSegmentation::postprocess( std::size_t outputH = tensor.size(tensor.dim() - 2); std::size_t outputW = tensor.size(tensor.dim() - 1); std::size_t outputPixels = outputH * outputW; - cv::Size outputSize(static_cast(outputW), static_cast(outputH)); + cv::Size outputSize(outputW, outputH); // Work with vectors, only wrap into OwningArrayBuffer at the end std::vector> classBuffers; From 9fea26fe39dd2a326c157f14b1935d944c0f75b8 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 11:21:42 +0100 Subject: [PATCH 16/27] chore: review changes --- .../BaseImageSegmentation.cpp | 87 +++++++++++-------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index fb3a6b152..a924bc1f2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include namespace rnexecutorch::models::image_segmentation { @@ -92,56 +91,72 @@ std::shared_ptr BaseImageSegmentation::postprocess( std::size_t outputPixels = outputH * outputW; cv::Size outputSize(outputW, outputH); - // Work with vectors, only wrap into OwningArrayBuffer at the end - std::vector> classBuffers; - std::vector argmaxData(outputPixels); + // Copy class data directly into OwningArrayBuffers (single copy from span) + std::vector> resultClasses; + resultClasses.reserve(numChannels); if (numChannels == 1) { // Binary 
segmentation (e.g. selfie segmentation) - std::vector bg(outputPixels); - std::vector fg(outputPixels); + auto fg = std::make_shared(resultData.data(), + outputPixels * sizeof(float)); + auto bg = std::make_shared(outputPixels * sizeof(float)); + auto *fgPtr = reinterpret_cast(fg->data()); + auto *bgPtr = reinterpret_cast(bg->data()); for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { - float p = resultData[pixel]; - bg[pixel] = 1.0f - p; - fg[pixel] = p; - argmaxData[pixel] = (p > 0.5f) ? 1 : 0; + bgPtr[pixel] = 1.0f - fgPtr[pixel]; } - classBuffers = {std::move(bg), std::move(fg)}; + resultClasses.push_back(bg); + resultClasses.push_back(fg); } else { // Multi-class segmentation (e.g. DeepLab, RF-DETR) - classBuffers.resize(numChannels); for (std::size_t cl = 0; cl < numChannels; ++cl) { - classBuffers[cl].assign(resultData.data() + cl * outputPixels, - resultData.data() + (cl + 1) * outputPixels); + resultClasses.push_back(std::make_shared( + resultData.data() + cl * outputPixels, outputPixels * sizeof(float))); } + } + + // Softmax + argmax in class-major order + auto argmax = + std::make_shared(outputPixels * sizeof(int32_t)); + auto *argmaxPtr = reinterpret_cast(argmax->data()); - // Apply softmax and compute argmax per pixel + if (numChannels == 1) { + auto *fgPtr = reinterpret_cast(resultClasses[1]->data()); for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { - std::vector values(numChannels); - for (std::size_t cl = 0; cl < numChannels; ++cl) { - values[cl] = classBuffers[cl][pixel]; - } - numerical::softmax(values); - - float maxVal = values[0]; - int maxInd = 0; - for (std::size_t cl = 0; cl < numChannels; ++cl) { - classBuffers[cl][pixel] = values[cl]; - if (values[cl] > maxVal) { - maxVal = values[cl]; - maxInd = static_cast(cl); + argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 
1 : 0; + } + } else { + std::vector maxLogits(outputPixels, + -std::numeric_limits::infinity()); + std::vector sumExp(outputPixels, 0.0f); + + // Pass 1: find per-pixel max and argmax + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + if (clPtr[pixel] > maxLogits[pixel]) { + maxLogits[pixel] = clPtr[pixel]; + argmaxPtr[pixel] = static_cast(cl); } } - argmaxData[pixel] = maxInd; } - } - // Wrap into OwningArrayBuffers - auto argmax = std::make_shared(argmaxData); - std::vector> resultClasses; - resultClasses.reserve(classBuffers.size()); - for (auto &buf : classBuffers) { - resultClasses.push_back(std::make_shared(buf)); + // Pass 2: subtract max, exp, accumulate sum + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + clPtr[pixel] = std::exp(clPtr[pixel] - maxLogits[pixel]); + sumExp[pixel] += clPtr[pixel]; + } + } + + // Pass 3: normalize by sum + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + clPtr[pixel] /= sumExp[pixel]; + } + } } // Filter classes of interest From 25b3cbc0443df2853c5fe9c86efc4f2617970755 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 12:17:47 +0100 Subject: [PATCH 17/27] docs: add docs --- .cspell-wordlist.txt | 3 +- .../app/image_segmentation/index.tsx | 2 +- .../useImageSegmentation.md | 78 ++++++++++++------ .../ImageSegmentationModule.md | 82 ++++++++++++++----- .../enumerations/SelfieSegmentationLabel.md | 21 +++++ .../type-aliases/LabelEnum.md | 8 ++ .../type-aliases/ModelNameOf.md | 13 +++ .../type-aliases/ModelSources.md | 9 ++ .../type-aliases/SegmentationConfig.md | 43 ++++++++++ .../type-aliases/SegmentationLabels.md | 15 ++++ 
.../type-aliases/SegmentationModelName.md | 8 ++ .../06-api-reference/type-aliases/Triple.md | 13 +++ .../variables/SELFIE_SEGMENTATION.md | 15 ++++ .../ImageSegmentationModule.ts | 18 ++-- .../src/types/imageSegmentation.ts | 5 +- 15 files changed, 273 insertions(+), 60 deletions(-) create mode 100644 docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md create mode 100644 docs/docs/06-api-reference/type-aliases/LabelEnum.md create mode 100644 docs/docs/06-api-reference/type-aliases/ModelNameOf.md create mode 100644 docs/docs/06-api-reference/type-aliases/ModelSources.md create mode 100644 docs/docs/06-api-reference/type-aliases/SegmentationConfig.md create mode 100644 docs/docs/06-api-reference/type-aliases/SegmentationLabels.md create mode 100644 docs/docs/06-api-reference/type-aliases/SegmentationModelName.md create mode 100644 docs/docs/06-api-reference/type-aliases/Triple.md create mode 100644 docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 7428cd147..2e5bf8fb8 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -111,4 +111,5 @@ logprob RNFS pogodin kesha -antonov \ No newline at end of file +antonov +rfdetr diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 8c6310540..9b614d409 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -73,7 +73,7 @@ export default function ImageSegmentationScreen() { try { const { width, height } = imageSize; const output = await forward(imageUri, [], true); - const argmax = output['ARGMAX'] || []; + const argmax = output.ARGMAX || []; const pixels = new Uint8Array(width * height * 4); for (let row = 0; row < height; row++) { diff --git a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md index 
3e541bd3b..a58d70c14 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md @@ -21,12 +21,15 @@ import { DEEPLAB_V3_RESNET50, } from 'react-native-executorch'; -const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); +const model = useImageSegmentation({ + model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, +}); const imageUri = 'file::///Users/.../cute_cat.png'; try { - const outputDict = await model.forward(imageUri); + const result = await model.forward(imageUri); + // result.ARGMAX is an Int32Array of per-pixel class indices } catch (error) { console.error(error); } @@ -36,9 +39,13 @@ try { `useImageSegmentation` takes [`ImageSegmentationProps`](../../06-api-reference/interfaces/ImageSegmentationProps.md) that consists of: -- `model` containing [`modelSource`](../../06-api-reference/interfaces/ImageSegmentationProps.md#modelsource). +- `model` - An object containing: + - `modelName` - The name of a built-in model. See [`ModelSources`](../../06-api-reference/type-aliases/ModelSources.md) for the list of supported models. + - `modelSource` - The location of the model binary (a URL or a bundled resource). - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ImageSegmentationProps.md#preventload) which prevents auto-loading of the model. +The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed. + You need more details? Check the following resources: - For detailed information about `useImageSegmentation` arguments check this section: [`useImageSegmentation` arguments](../../06-api-reference/functions/useImageSegmentation.md#parameters). @@ -47,45 +54,70 @@ You need more details? 
Check the following resources: ### Returns -`useImageSegmentation` returns an object called `ImageSegmentationType` containing bunch of functions to interact with image segmentation models. To get more details please read: [`ImageSegmentationType` API Reference](../../06-api-reference/interfaces/ImageSegmentationType.md). +`useImageSegmentation` returns an [`ImageSegmentationType`](../../06-api-reference/interfaces/ImageSegmentationType.md) object containing: + +- `isReady` - Whether the model is loaded and ready to process images. +- `isGenerating` - Whether the model is currently processing an image. +- `error` - An error object if the model failed to load or encountered a runtime error. +- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary. +- `forward` - A function to run inference on an image. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions. +To run the model, use the `forward` method. It accepts three arguments: -- The image can be a remote URL, a local file URI, or a base64-encoded image. -- The [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes. -- The [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. 
The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image. +- `imageSource` (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- `classesOfInterest` (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. +- `resizeToInput` (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). :::warning -Setting `resizeToInput` to `false` will make `forward` faster. +Setting `resizeToInput` to `true` will make `forward` slower. ::: -[`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput): +`forward` returns a promise resolving to an object containing: + +- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel. +- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class. -- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability. 
-- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel. +The return type is fully typed — TypeScript narrows it based on the labels you pass in `classesOfInterest`. ## Example ```typescript +import { + useImageSegmentation, + DEEPLAB_V3_RESNET50, + DeeplabLabel, +} from 'react-native-executorch'; + function App() { - const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); + const model = useImageSegmentation({ + model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + }); - // ... - const imageUri = 'file::///Users/.../cute_cat.png'; + const handleSegment = async () => { + if (!model.isReady) return; + + const imageUri = 'file:///Users/.../cute_cat.png'; + + try { + const result = await model.forward(imageUri, ['CAT', 'PERSON'], true); + + // result.ARGMAX — Int32Array of per-pixel class indices + // result.CAT — Float32Array of per-pixel probabilities for CAT + // result.PERSON — Float32Array of per-pixel probabilities for PERSON + } catch (error) { + console.error(error); + } + }; - try { - const outputDict = await model.forward(imageUri, [DeeplabLabel.CAT], true); - } catch (error) { - console.error(error); - } // ...
} ``` ## Supported models -| Model | Number of classes | Class list | -| ------------------------------------------------------------------------------------------------ | ----------------- | ------------------------------------------------------------------- | -| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) | +| Model | Number of classes | Class list | +| ------------------------------------------------------------------------------------------------ | ----------------- | ----------------------------------------------------------------------------------------- | +| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) | +| selfie-segmentation | 2 | [SelfieSegmentationLabel](../../06-api-reference/enumerations/SelfieSegmentationLabel.md) | diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md index f315e72b0..b688300b5 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md @@ -19,14 +19,15 @@ import { const imageUri = 'path/to/image.png'; -// Creating an instance -const imageSegmentationModule = new ImageSegmentationModule(); - -// Loading the model -await imageSegmentationModule.load(DEEPLAB_V3_RESNET50); +// Creating an instance from a built-in model +const segmentation = await ImageSegmentationModule.fromModelName({ + modelName: 'deeplab-v3', + modelSource: DEEPLAB_V3_RESNET50, +}); // Running the model -const outputDict = await imageSegmentationModule.forward(imageUri); +const result = await segmentation.forward(imageUri); +// result.ARGMAX — Int32Array of per-pixel class indices ``` ### Methods @@ 
-35,34 +36,75 @@ All methods of `ImageSegmentationModule` are explained in details here: [`ImageS ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ImageSegmentationModule.md#load) method with the following parameters: +`ImageSegmentationModule` uses static factory methods instead of `new()` + `load()`. There are two ways to create an instance: + +### Built-in models — `fromModelName` + +Use [`fromModelName`](../../06-api-reference/classes/ImageSegmentationModule.md#frommodelname) for models that ship with built-in label maps and preprocessing configs: + +```typescript +const segmentation = await ImageSegmentationModule.fromModelName( + { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + (progress) => console.log(`Download: ${Math.round(progress * 100)}%`) +); +``` + +The `config` parameter is a discriminated union — TypeScript ensures you provide the correct fields for each model name. Available built-in models: `'deeplab-v3'`, `'selfie-segmentation'`. -- [`model`](../../06-api-reference/classes/ImageSegmentationModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/ImageSegmentationModule.md#modelsource) - Location of the used model. +### Custom models — `fromCustomConfig` -- [`onDownloadProgressCallback`](../../06-api-reference/classes/ImageSegmentationModule.md#ondownloadprogresscallback) - Callback to track download progress. +Use [`fromCustomConfig`](../../06-api-reference/classes/ImageSegmentationModule.md#fromcustomconfig) for custom-exported segmentation models with your own label map: -This method returns a promise, which can resolve to an error or void. 
+```typescript +const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; + +const segmentation = await ImageSegmentationModule.fromCustomConfig( + 'https://example.com/custom_model.pte', + { + labelMap: MyLabels, + preprocessorConfig: { + normMean: [0.485, 0.456, 0.406], + normStd: [0.229, 0.224, 0.225], + }, + } +); +``` + +The `preprocessorConfig` is optional. If omitted, no input normalization is applied. The module instance will be typed to your custom label map — `forward` will accept and return keys from `MyLabels`. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method on the module object. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions. +To run the model, use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method. It accepts three arguments: -- The image can be a remote URL, a local file URI, or a base64-encoded image. -- The [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes. -- The [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. 
The model runs inference on a scaled (probably smaller) version of your image (224x224 for the `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image. +- `imageSource` (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- `classesOfInterest` (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. +- `resizeToInput` (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. :::warning -Setting `resize` to true will make `forward` slower. +Setting `resizeToInput` to `true` will make `forward` slower. ::: -[`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput): +`forward` returns a promise resolving to an object containing: + +- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel. +- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class. -- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability. -- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel. 
+The return type narrows based on the labels passed in `classesOfInterest`: + +```typescript +// Only ARGMAX in the result +const argmaxOnly = await segmentation.forward(imageUri); +argmaxOnly.ARGMAX; // Int32Array + +// ARGMAX + requested class masks +const withMasks = await segmentation.forward(imageUri, ['CAT', 'DOG']); +withMasks.ARGMAX; // Int32Array +withMasks.CAT; // Float32Array +withMasks.DOG; // Float32Array +``` ## Managing memory -The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you load the module again. +The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you create a new instance.
diff --git a/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md b/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md new file mode 100644 index 000000000..912d29f5d --- /dev/null +++ b/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md @@ -0,0 +1,21 @@ +# Enumeration: SelfieSegmentationLabel + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:81](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L81) + +Labels used in the selfie image segmentation model. + +## Enumeration Members + +### BACKGROUND + +> **BACKGROUND**: `1` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:83](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L83) + +--- + +### SELFIE + +> **SELFIE**: `0` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:82](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L82) diff --git a/docs/docs/06-api-reference/type-aliases/LabelEnum.md b/docs/docs/06-api-reference/type-aliases/LabelEnum.md new file mode 100644 index 000000000..9414676dc --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/LabelEnum.md @@ -0,0 +1,8 @@ +# Type Alias: LabelEnum + +> **LabelEnum** = `Readonly`\<`Record`\<`string`, `number` \| `string`\>\> + +Defined in: [packages/react-native-executorch/src/types/common.ts:146](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/common.ts#L146) + +A readonly record mapping string keys to numeric or string values. 
+Used to represent enum-like label maps for models. diff --git a/docs/docs/06-api-reference/type-aliases/ModelNameOf.md b/docs/docs/06-api-reference/type-aliases/ModelNameOf.md new file mode 100644 index 000000000..e962ab698 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/ModelNameOf.md @@ -0,0 +1,13 @@ +# Type Alias: ModelNameOf\ + +> **ModelNameOf**\<`C`\> = `C`\[`"modelName"`\] + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:45](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L45) + +Extracts the model name from a [ModelSources](ModelSources.md) config object. + +## Type Parameters + +### C + +`C` _extends_ [`ModelSources`](ModelSources.md) diff --git a/docs/docs/06-api-reference/type-aliases/ModelSources.md b/docs/docs/06-api-reference/type-aliases/ModelSources.md new file mode 100644 index 000000000..eefe51641 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/ModelSources.md @@ -0,0 +1,9 @@ +# Type Alias: ModelSources + +> **ModelSources** = \{ `modelName`: `"deeplab-v3"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} \| \{ `modelName`: `"selfie-segmentation"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} \| \{ `modelName`: `"rfdetr"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:27](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L27) + +Per-model config for [ImageSegmentationModule.fromModelName](../classes/ImageSegmentationModule.md#frommodelname). +Each model name maps to its required fields. +Add new union members here when a model needs extra sources or options. 
diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md b/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md new file mode 100644 index 000000000..ce957311d --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md @@ -0,0 +1,43 @@ +# Type Alias: SegmentationConfig\ + +> **SegmentationConfig**\<`T`\> = `object` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:15](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L15) + +Configuration for a custom segmentation model. + +## Type Parameters + +### T + +`T` _extends_ [`LabelEnum`](LabelEnum.md) + +The [LabelEnum](LabelEnum.md) type for the model. + +## Properties + +### labelMap + +> **labelMap**: `T` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:16](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L16) + +The enum-like object mapping class names to indices. + +--- + +### preprocessorConfig? + +> `optional` **preprocessorConfig**: `object` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:17](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L17) + +Optional preprocessing parameters. + +#### normMean? + +> `optional` **normMean**: [`Triple`](Triple.md)\<`number`\> + +#### normStd? 
+ +> `optional` **normStd**: [`Triple`](Triple.md)\<`number`\> diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md b/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md new file mode 100644 index 000000000..96939a346 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md @@ -0,0 +1,15 @@ +# Type Alias: SegmentationLabels\ + +> **SegmentationLabels**\<`M`\> = `ModelConfigsType`\[`M`\]\[`"labelMap"`\] + +Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:45](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L45) + +Resolves the [LabelEnum](LabelEnum.md) for a given built-in model name. + +## Type Parameters + +### M + +`M` _extends_ [`SegmentationModelName`](SegmentationModelName.md) + +A built-in model name from [SegmentationModelName](SegmentationModelName.md). diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md b/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md new file mode 100644 index 000000000..ba7197d06 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md @@ -0,0 +1,8 @@ +# Type Alias: SegmentationModelName + +> **SegmentationModelName** = [`ModelSources`](ModelSources.md)\[`"modelName"`\] + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:38](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L38) + +Union of all built-in segmentation model names +(e.g. `'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`). 
diff --git a/docs/docs/06-api-reference/type-aliases/Triple.md b/docs/docs/06-api-reference/type-aliases/Triple.md new file mode 100644 index 000000000..41fa3d2b1 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/Triple.md @@ -0,0 +1,13 @@ +# Type Alias: Triple\ + +> **Triple**\<`T`\> = readonly \[`T`, `T`, `T`\] + +Defined in: [packages/react-native-executorch/src/types/common.ts:153](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/common.ts#L153) + +A readonly triple of values, typically used for per-channel normalization parameters. + +## Type Parameters + +### T + +`T` diff --git a/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md b/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md new file mode 100644 index 000000000..36135377c --- /dev/null +++ b/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md @@ -0,0 +1,15 @@ +# Variable: SELFIE_SEGMENTATION + +> `const` **SELFIE_SEGMENTATION**: `object` + +Defined in: [packages/react-native-executorch/src/constants/modelUrls.ts:533](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/constants/modelUrls.ts#L533) + +## Type Declaration + +### modelName + +> `readonly` **modelName**: `"selfie-segmentation"` = `'selfie-segmentation'` + +### modelSource + +> `readonly` **modelSource**: `"https://ai.swmansion.com/storage/jc_tests/selfie.pte"` = `SELFIE_SEGMENTATION_MODEL` diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 63e1e9410..5eb7ce847 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -1,6 +1,5 @@ import 
{ ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource, LabelEnum } from '../../types/common'; -import { CocoLabel } from '../../types/objectDetection'; import { DeeplabLabel, ModelNameOf, @@ -9,13 +8,15 @@ import { SegmentationModelName, SelfieSegmentationLabel, } from '../../types/imageSegmentation'; -import { IMAGENET_MEAN, IMAGENET_STD } from '../../constants/commonVision'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; -const ModelConfigs = { +const ModelConfigs: Record< + SegmentationModelName, + SegmentationConfig +> = { 'deeplab-v3': { labelMap: DeeplabLabel, preprocessorConfig: undefined, @@ -24,14 +25,7 @@ const ModelConfigs = { labelMap: SelfieSegmentationLabel, preprocessorConfig: undefined, }, - 'rfdetr': { - labelMap: CocoLabel, - preprocessorConfig: { normMean: IMAGENET_MEAN, normStd: IMAGENET_STD }, - }, -} as const satisfies Record< - SegmentationModelName, - SegmentationConfig ->; +} as const; /** @internal */ type ModelConfigsType = typeof ModelConfigs; @@ -59,7 +53,7 @@ type ResolveLabels = * Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, * or a custom label enum for custom configs. * - * @typeParam T - Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`) + * @typeParam T - Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`) * or a custom {@link LabelEnum} label map. 
* * @category Typescript API diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 9c8bcc7ae..6d79a801d 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -26,12 +26,11 @@ export type SegmentationConfig = { */ export type ModelSources = | { modelName: 'deeplab-v3'; modelSource: ResourceSource } - | { modelName: 'selfie-segmentation'; modelSource: ResourceSource } - | { modelName: 'rfdetr'; modelSource: ResourceSource }; + | { modelName: 'selfie-segmentation'; modelSource: ResourceSource }; /** * Union of all built-in segmentation model names - * (e.g. `'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`). + * (e.g. `'deeplab-v3'`, `'selfie-segmentation'`). * * @category Types */ From ac4a107954740ca2b25f8bb7c5293c7e43ce719b Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 12:18:19 +0100 Subject: [PATCH 18/27] lint: add words to cspell --- .cspell-wordlist.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 2e5bf8fb8..603da5b59 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -113,3 +113,5 @@ pogodin kesha antonov rfdetr +basemodule +IMAGENET From 16f703dcae65a77b141b304c7bb86e4ebae6844a Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 13:03:08 +0100 Subject: [PATCH 19/27] docs: add missing links --- .../03-hooks/02-computer-vision/useImageSegmentation.md | 8 ++++---- .../02-computer-vision/ImageSegmentationModule.md | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md index a58d70c14..01db4e4cd 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md @@ 
-64,11 +64,11 @@ You need more details? Check the following resources: ## Running the model -To run the model, use the `forward` method. It accepts three arguments: +To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments: -- `imageSource` (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). -- `classesOfInterest` (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. -- `resizeToInput` (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). +- [`imageSource`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. +- [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). :::warning Setting `resizeToInput` to `true` will make `forward` slower. 
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md index b688300b5..60e832eae 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md @@ -78,9 +78,9 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method. It accepts three arguments: -- `imageSource` (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). -- `classesOfInterest` (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. -- `resizeToInput` (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. +- [`imageSource`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. +- [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. :::warning Setting `resizeToInput` to `true` will make `forward` slower. 
From b4a8ce2de6dec6fc676675d54e3e77d30cec4074 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 13:05:39 +0100 Subject: [PATCH 20/27] docs: review suggestion --- docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md | 2 +- .../02-computer-vision/ImageSegmentationModule.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md index 01db4e4cd..bc0cc170b 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md @@ -71,7 +71,7 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageSeg - [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). :::warning -Setting `resizeToInput` to `true` will make `forward` slower. +Setting `resizeToInput` to `false` will make `forward` faster. ::: `forward` returns a promise resolving to an object containing: diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md index 60e832eae..4b618b5b1 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md @@ -83,7 +83,7 @@ To run the model, use the [`forward`](../../06-api-reference/classes/ImageSegmen - [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. 
:::warning -Setting `resizeToInput` to `true` will make `forward` slower. +Setting `resizeToInput` to `false` will make `forward` faster. ::: `forward` returns a promise resolving to an object containing: From dc926391b307c9b4133a4170bd5dc9adf94261ab Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 18 Feb 2026 13:09:22 +0100 Subject: [PATCH 21/27] docs: fix example --- docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md | 4 ++-- .../02-computer-vision/ImageSegmentationModule.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md index bc0cc170b..676ae012e 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md @@ -22,7 +22,7 @@ import { } from 'react-native-executorch'; const model = useImageSegmentation({ - model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + model: DEEPLAB_V3_RESNET50, }); const imageUri = 'file::///Users/.../cute_cat.png'; @@ -92,7 +92,7 @@ import { function App() { const model = useImageSegmentation({ - model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + model: DEEPLAB_V3_RESNET50, }); const handleSegment = async () => { diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md index 4b618b5b1..858713368 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md @@ -44,7 +44,7 @@ Use [`fromModelName`](../../06-api-reference/classes/ImageSegmentationModule.md# ```typescript const segmentation = await ImageSegmentationModule.fromModelName( - { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + DEEPLAB_V3_RESNET50, (progress) => console.log(`Download: 
${Math.round(progress * 100)}%`) ); ``` From 44e3454181175d8b0486f053a375145f6ba16757 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Thu, 19 Feb 2026 11:18:21 +0100 Subject: [PATCH 22/27] fix: types and revert bg, fg in selfie segmentation --- .../image_segmentation/BaseImageSegmentation.cpp | 6 +++--- .../computer_vision/ImageSegmentationModule.ts | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index a924bc1f2..d9350ee57 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -105,8 +105,8 @@ std::shared_ptr BaseImageSegmentation::postprocess( for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { bgPtr[pixel] = 1.0f - fgPtr[pixel]; } - resultClasses.push_back(bg); resultClasses.push_back(fg); + resultClasses.push_back(bg); } else { // Multi-class segmentation (e.g. DeepLab, RF-DETR) for (std::size_t cl = 0; cl < numChannels; ++cl) { @@ -121,9 +121,9 @@ std::shared_ptr BaseImageSegmentation::postprocess( auto *argmaxPtr = reinterpret_cast(argmax->data()); if (numChannels == 1) { - auto *fgPtr = reinterpret_cast(resultClasses[1]->data()); + auto *fgPtr = reinterpret_cast(resultClasses[0]->data()); for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { - argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 1 : 0; + argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 
0 : 1; } } else { std::vector maxLogits(outputPixels, diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 5eb7ce847..6b2c2b7d6 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -13,10 +13,7 @@ import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; import { Logger } from '../../common/Logger'; -const ModelConfigs: Record< - SegmentationModelName, - SegmentationConfig -> = { +const ModelConfigs = { 'deeplab-v3': { labelMap: DeeplabLabel, preprocessorConfig: undefined, @@ -25,7 +22,10 @@ const ModelConfigs: Record< labelMap: SelfieSegmentationLabel, preprocessorConfig: undefined, }, -} as const; +} as const satisfies Record< + SegmentationModelName, + SegmentationConfig +>; /** @internal */ type ModelConfigsType = typeof ModelConfigs; @@ -98,7 +98,10 @@ export class ImageSegmentationModule< onDownloadProgress: (progress: number) => void = () => { } ): Promise>> { const { modelName, modelSource } = config; - const { labelMap, preprocessorConfig } = ModelConfigs[modelName]; + const { labelMap } = ModelConfigs[modelName]; + const { preprocessorConfig } = ModelConfigs[ + modelName + ] as SegmentationConfig; const normMean = preprocessorConfig?.normMean ?? []; const normStd = preprocessorConfig?.normStd ?? 
[]; const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); From a9d58e1106b97770b523e61b837e2bf37b3c4d48 Mon Sep 17 00:00:00 2001 From: Jakub Chmura <92989966+chmjkb@users.noreply.github.com> Date: Fri, 20 Feb 2026 15:25:17 +0100 Subject: [PATCH 23/27] Update packages/react-native-executorch/src/constants/commonVision.ts Co-authored-by: Mateusz Kopcinski <120639731+mkopcins@users.noreply.github.com> --- .../react-native-executorch/src/constants/commonVision.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/react-native-executorch/src/constants/commonVision.ts b/packages/react-native-executorch/src/constants/commonVision.ts index 4e2f29473..05eeba759 100644 --- a/packages/react-native-executorch/src/constants/commonVision.ts +++ b/packages/react-native-executorch/src/constants/commonVision.ts @@ -1,4 +1,4 @@ import { Triple } from '../types/common'; -export const IMAGENET_MEAN: Triple = [0.485, 0.456, 0.406]; -export const IMAGENET_STD: Triple = [0.229, 0.224, 0.225]; +export const IMAGENET1K_MEAN: Triple = [0.485, 0.456, 0.406]; +export const IMAGENET1K_STD: Triple = [0.229, 0.224, 0.225]; From edd1b1d99e97ffb4de1308ab91a4766ae701f265 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Fri, 20 Feb 2026 15:54:30 +0100 Subject: [PATCH 24/27] chore: review suggextions --- .../common/rnexecutorch/Log.h | 7 +++++++ .../rnexecutorch/RnExecutorchInstaller.cpp | 4 ++-- .../common/rnexecutorch/models/BaseModel.h | 1 + .../BaseImageSegmentation.cpp | 11 +++++++++-- .../BaseImageSegmentation.h | 5 +++++ .../image_segmentation/ImageSegmentation.cpp | 1 - .../image_segmentation/ImageSegmentation.h | 19 ------------------- .../object_detection/ObjectDetection.cpp | 1 + .../computer_vision/useImageSegmentation.ts | 14 ++++++++++---- .../src/hooks/useModule.ts | 11 ++++++++--- 10 files changed, 43 insertions(+), 31 deletions(-) delete mode 100644 
packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp delete mode 100644 packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h diff --git a/packages/react-native-executorch/common/rnexecutorch/Log.h b/packages/react-native-executorch/common/rnexecutorch/Log.h index 9381324ab..bb17a53ec 100644 --- a/packages/react-native-executorch/common/rnexecutorch/Log.h +++ b/packages/react-native-executorch/common/rnexecutorch/Log.h @@ -371,6 +371,8 @@ namespace rnexecutorch { enum class LOG_LEVEL : uint8_t { Info, /**< Informational messages that highlight the progress of the application. */ + Warn, /**< Warning messages that a non-critical error occurred during + program execution */ Error, /**< Error events of considerable importance that will prevent normal program execution. */ Debug /**< Detailed information, typically of interest only when diagnosing @@ -384,6 +386,8 @@ inline android_LogPriority androidLogLevel(LOG_LEVEL logLevel) { switch (logLevel) { case LOG_LEVEL::Info: return ANDROID_LOG_INFO; + case LOG_LEVEL::Warn: + return ANDROID_LOG_WARN; case LOG_LEVEL::Error: return ANDROID_LOG_ERROR; case LOG_LEVEL::Debug: @@ -404,6 +408,9 @@ inline void handleIosLog(LOG_LEVEL logLevel, const char *buffer) { case LOG_LEVEL::Info: os_log_info(OS_LOG_DEFAULT, "%{public}s", buffer); return; + case LOG_LEVEL::Warn: + os_log(OS_LOG_DEFAULT, "%{public}s", buffer); + return; case LOG_LEVEL::Error: os_log_error(OS_LOG_DEFAULT, "%{public}s", buffer); return; diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index bceac64ad..7804fb5b7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include 
#include #include #include @@ -44,7 +44,7 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadImageSegmentation", RnExecutorchInstaller::loadModel< - models::image_segmentation::ImageSegmentation>( + models::image_segmentation::BaseImageSegmentation>( jsiRuntime, jsCallInvoker, "loadImageSegmentation")); jsiRuntime->global().setProperty( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index c40fa2569..a0f9d5446 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -19,6 +19,7 @@ using executorch::runtime::Result; class BaseModel { public: + virtual ~BaseModel() = default; BaseModel( const std::string &modelSource, std::shared_ptr callInvoker, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp index d9350ee57..141ec430e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -22,11 +23,17 @@ BaseImageSegmentation::BaseImageSegmentation( std::vector normStd, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { initModelImageSize(); - if (normMean.size() >= 3) { + if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); + } else { + log(LOG_LEVEL::Warn, + "normMean must have 3 elements — ignoring provided value."); } - if (normStd.size() >= 3) { + if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); + } else { + 
log(LOG_LEVEL::Warn, + "normStd must have 3 elements — ignoring provided value."); } } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h index baf3872a9..201efc0c0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h @@ -19,6 +19,7 @@ using executorch::extension::TensorPtr; class BaseImageSegmentation : public BaseModel { public: + ~BaseImageSegmentation() override = default; BaseImageSegmentation(const std::string &modelSource, std::shared_ptr callInvoker); @@ -54,4 +55,8 @@ class BaseImageSegmentation : public BaseModel { void initModelImageSize(); }; } // namespace models::image_segmentation + +REGISTER_CONSTRUCTOR(models::image_segmentation::BaseImageSegmentation, + std::string, std::vector, std::vector, + std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp deleted file mode 100644 index acf4bbdf7..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "ImageSegmentation.h" diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h deleted file mode 100644 index 4e4bf1baf..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include 
"rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include - -namespace rnexecutorch { -namespace models::image_segmentation { -using namespace facebook; - -class ImageSegmentation : public BaseImageSegmentation { -public: - using BaseImageSegmentation::BaseImageSegmentation; -}; -} // namespace models::image_segmentation - -REGISTER_CONSTRUCTOR(models::image_segmentation::ImageSegmentation, std::string, - std::vector, std::vector, - std::shared_ptr); -} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 8b5bc022f..54528572f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -2,6 +2,7 @@ #include #include +#include #include namespace rnexecutorch::models::object_detection { diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index a885dd9b8..88831f9aa 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -45,6 +45,7 @@ export const useImageSegmentation = ({ useEffect(() => { if (preventLoad) return; + let isMounted = true; let currentInstance: ImageSegmentationModule> | null = null; (async () => { @@ -54,16 +55,21 @@ export const useImageSegmentation = ({ try { currentInstance = await ImageSegmentationModule.fromModelName( model, - setDownloadProgress + (progress) => { + if (isMounted) setDownloadProgress(progress); + } ); - setInstance(currentInstance); - setIsReady(true); + if (isMounted) { + setInstance(currentInstance); + setIsReady(true); + } } catch 
(err) { - setError(parseUnknownError(err)); + if (isMounted) setError(parseUnknownError(err)); } })(); return () => { + isMounted = false; currentInstance?.delete(); }; diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index 39b10249b..1a35885d5 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -35,19 +35,24 @@ export const useModule = < useEffect(() => { if (preventLoad) return; + let isMounted = true; + (async () => { setDownloadProgress(0); setError(null); try { setIsReady(false); - await moduleInstance.load(model, setDownloadProgress); - setIsReady(true); + await moduleInstance.load(model, (progress: number) => { + if (isMounted) setDownloadProgress(progress); + }); + if (isMounted) setIsReady(true); } catch (err) { - setError(parseUnknownError(err)); + if (isMounted) setError(parseUnknownError(err)); } })(); return () => { + isMounted = false; moduleInstance.delete(); }; From aa9823f1d98db84a8b69923d6d54bcc5b55ac56e Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 23 Feb 2026 08:20:52 +0100 Subject: [PATCH 25/27] chore: remove unused header --- .../rnexecutorch/models/object_detection/ObjectDetection.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 54528572f..8b5bc022f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -2,7 +2,6 @@ #include #include -#include #include namespace rnexecutorch::models::object_detection { From 3e95f6ea1ff60e363da2e5493ad27e2fc7c01e97 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 23 Feb 2026 10:43:55 +0100 
Subject: [PATCH 26/27] post-rebase fi --- .../modules/computer_vision/ImageSegmentationModule.ts | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index 6b2c2b7d6..f2de6edd7 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -9,9 +9,8 @@ import { SelfieSegmentationLabel, } from '../../types/imageSegmentation'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; +import { RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; -import { Logger } from '../../common/Logger'; const ModelConfigs = { 'deeplab-v3': { @@ -74,7 +73,7 @@ export class ImageSegmentationModule< } // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) - override async load() { } + override async load() {} /** * Creates a segmentation instance for a built-in model. 
@@ -95,7 +94,7 @@ export class ImageSegmentationModule< static async fromModelName( config: C, - onDownloadProgress: (progress: number) => void = () => { } + onDownloadProgress: (progress: number) => void = () => {} ): Promise>> { const { modelName, modelSource } = config; const { labelMap } = ModelConfigs[modelName]; @@ -143,7 +142,7 @@ export class ImageSegmentationModule< static async fromCustomConfig( modelSource: ResourceSource, config: SegmentationConfig, - onDownloadProgress: (progress: number) => void = () => { } + onDownloadProgress: (progress: number) => void = () => {} ): Promise> { const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); if (!paths?.[0]) { From b6a4f1e6dd63239d0b694ee4719390e1bc139ed3 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Mon, 23 Feb 2026 13:41:31 +0100 Subject: [PATCH 27/27] fix --- .../common/rnexecutorch/models/BaseModel.h | 2 ++ .../models/image_segmentation/BaseImageSegmentation.h | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index a0f9d5446..56de2b423 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -20,6 +20,8 @@ using executorch::runtime::Result; class BaseModel { public: virtual ~BaseModel() = default; + BaseModel(BaseModel &&) = default; + BaseModel &operator=(BaseModel &&) = default; BaseModel( const std::string &modelSource, std::shared_ptr callInvoker, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h index 201efc0c0..f46f41d69 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h @@ -19,7 +19,6 @@ using executorch::extension::TensorPtr; class BaseImageSegmentation : public BaseModel { public: - ~BaseImageSegmentation() override = default; BaseImageSegmentation(const std::string &modelSource, std::shared_ptr callInvoker);