diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 7428cd147..603da5b59 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -111,4 +111,7 @@ logprob RNFS pogodin kesha -antonov \ No newline at end of file +antonov +rfdetr +basemodule +IMAGENET diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 61a98ddea..9b614d409 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -2,9 +2,8 @@ import Spinner from '../../components/Spinner'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; import { - useImageSegmentation, DEEPLAB_V3_RESNET50, - DeeplabLabel, + useImageSegmentation, } from 'react-native-executorch'; import { Canvas, @@ -44,16 +43,20 @@ const numberToColor: number[][] = [ ]; export default function ImageSegmentationScreen() { - const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); const { setGlobalGenerating } = useContext(GeneratingContext); - useEffect(() => { - setGlobalGenerating(model.isGenerating); - }, [model.isGenerating, setGlobalGenerating]); + const { isReady, isGenerating, downloadProgress, forward } = + useImageSegmentation({ + model: DEEPLAB_V3_RESNET50, + }); const [imageUri, setImageUri] = useState(''); const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); const [segImage, setSegImage] = useState(null); const [canvasSize, setCanvasSize] = useState({ width: 0, height: 0 }); + useEffect(() => { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); if (!image?.uri) return; @@ -69,12 +72,8 @@ export default function ImageSegmentationScreen() { if (!imageUri || imageSize.width === 0 || imageSize.height === 0) return; try { const { width, height } = imageSize; - const output = await model.forward(imageUri, [DeeplabLabel.ARGMAX]); - const argmax = output[DeeplabLabel.ARGMAX] || []; - const uniqueValues = new Set(); - for (let i = 0; i < argmax.length; i++) { - uniqueValues.add(argmax[i]); - } + const output = await forward(imageUri, [], true); + const argmax = output.ARGMAX || []; const pixels = new Uint8Array(width * height * 4); for (let row = 0; row < height; row++) { @@ -105,11 +104,11 @@ export default function ImageSegmentationScreen() { } }; - if (!model.isReady) { + if (!isReady) { return ( ); } diff --git a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md index 3e541bd3b..676ae012e 100644 --- a/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/03-hooks/02-computer-vision/useImageSegmentation.md @@ -21,12 +21,15 @@ import { DEEPLAB_V3_RESNET50, } from 'react-native-executorch'; -const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); +const model = useImageSegmentation({ + model: DEEPLAB_V3_RESNET50, +}); const imageUri = 'file::///Users/.../cute_cat.png'; try { - const outputDict = await model.forward(imageUri); + const result = await model.forward(imageUri); + // result.ARGMAX is an Int32Array of per-pixel class indices } catch (error) { console.error(error); } @@ -36,9 +39,13 @@ try { `useImageSegmentation` takes [`ImageSegmentationProps`](../../06-api-reference/interfaces/ImageSegmentationProps.md) that consists of: -- `model` containing [`modelSource`](../../06-api-reference/interfaces/ImageSegmentationProps.md#modelsource). +- `model` - An object containing: + - `modelName` - The name of a built-in model. See [`ModelSources`](../../06-api-reference/type-aliases/ModelSources.md) for the list of supported models. + - `modelSource` - The location of the model binary (a URL or a bundled resource). - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ImageSegmentationProps.md#preventload) which prevents auto-loading of the model. +The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed. + You need more details? Check the following resources: - For detailed information about `useImageSegmentation` arguments check this section: [`useImageSegmentation` arguments](../../06-api-reference/functions/useImageSegmentation.md#parameters). @@ -47,45 +54,70 @@ You need more details? Check the following resources: ### Returns -`useImageSegmentation` returns an object called `ImageSegmentationType` containing bunch of functions to interact with image segmentation models. To get more details please read: [`ImageSegmentationType` API Reference](../../06-api-reference/interfaces/ImageSegmentationType.md). +`useImageSegmentation` returns an [`ImageSegmentationType`](../../06-api-reference/interfaces/ImageSegmentationType.md) object containing: + +- `isReady` - Whether the model is loaded and ready to process images. +- `isGenerating` - Whether the model is currently processing an image. +- `error` - An error object if the model failed to load or encountered a runtime error. +- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary. +- `forward` - A function to run inference on an image. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions. +To run the model, use the [`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) method. It accepts three arguments: -- The image can be a remote URL, a local file URI, or a base64-encoded image. -- The [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes. -- The [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image. +- [`imageSource`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]` (no class masks). The `ARGMAX` map is always returned regardless of this parameter. +- [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions (e.g. 224x224 for `DEEPLAB_V3_RESNET50`). :::warning Setting `resizeToInput` to `false` will make `forward` faster. ::: -[`forward`](../../06-api-reference/interfaces/ImageSegmentationType.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/interfaces/ImageSegmentationType.md#resizetoinput): +`forward` returns a promise resolving to an object containing: + +- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel. +- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class. -- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability. -- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/interfaces/ImageSegmentationType.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel. +The return type is fully typed — TypeScript narrows it based on the labels you pass in `classesOfInterest`. ## Example ```typescript +import { + useImageSegmentation, + DEEPLAB_V3_RESNET50, + DeeplabLabel, +} from 'react-native-executorch'; + function App() { - const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); + const model = useImageSegmentation({ + model: DEEPLAB_V3_RESNET50, + }); - // ... - const imageUri = 'file::///Users/.../cute_cat.png'; + const handleSegment = async () => { + if (!model.isReady) return; + + const imageUri = 'file::///Users/.../cute_cat.png'; + + try { + const result = await model.forward(imageUri, ['CAT', 'PERSON'], true); + + // result.ARGMAX — Int32Array of per-pixel class indices + // result.CAT — Float32Array of per-pixel probabilities for CAT + // result.PERSON — Float32Array of per-pixel probabilities for PERSON + } catch (error) { + console.error(error); + } + }; - try { - const outputDict = await model.forward(imageUri, [DeeplabLabel.CAT], true); - } catch (error) { - console.error(error); - } // ... } ``` ## Supported models -| Model | Number of classes | Class list | -| ------------------------------------------------------------------------------------------------ | ----------------- | ------------------------------------------------------------------- | -| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) | +| Model | Number of classes | Class list | +| ------------------------------------------------------------------------------------------------ | ----------------- | ----------------------------------------------------------------------------------------- | +| [deeplabv3_resnet50](https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3) | 21 | [DeeplabLabel](../../06-api-reference/enumerations/DeeplabLabel.md) | +| selfie-segmentation | 2 | [SelfieSegmentationLabel](../../06-api-reference/enumerations/SelfieSegmentationLabel.md) | diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md index f315e72b0..858713368 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageSegmentationModule.md @@ -19,14 +19,15 @@ import { const imageUri = 'path/to/image.png'; -// Creating an instance -const imageSegmentationModule = new ImageSegmentationModule(); - -// Loading the model -await imageSegmentationModule.load(DEEPLAB_V3_RESNET50); +// Creating an instance from a built-in model +const segmentation = await ImageSegmentationModule.fromModelName({ + modelName: 'deeplab-v3', + modelSource: DEEPLAB_V3_RESNET50, +}); // Running the model -const outputDict = await imageSegmentationModule.forward(imageUri); +const result = await segmentation.forward(imageUri); +// result.ARGMAX — Int32Array of per-pixel class indices ``` ### Methods @@ -35,34 +36,75 @@ All methods of `ImageSegmentationModule` are explained in details here: [`ImageS ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ImageSegmentationModule.md#load) method with the following parameters: +`ImageSegmentationModule` uses static factory methods instead of `new()` + `load()`. There are two ways to create an instance: + +### Built-in models — `fromModelName` + +Use [`fromModelName`](../../06-api-reference/classes/ImageSegmentationModule.md#frommodelname) for models that ship with built-in label maps and preprocessing configs: + +```typescript +const segmentation = await ImageSegmentationModule.fromModelName( + DEEPLAB_V3_RESNET50, + (progress) => console.log(`Download: ${Math.round(progress * 100)}%`) +); +``` -- [`model`](../../06-api-reference/classes/ImageSegmentationModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/ImageSegmentationModule.md#modelsource) - Location of the used model. +The `config` parameter is a discriminated union — TypeScript ensures you provide the correct fields for each model name. Available built-in models: `'deeplab-v3'`, `'selfie-segmentation'`. -- [`onDownloadProgressCallback`](../../06-api-reference/classes/ImageSegmentationModule.md#ondownloadprogresscallback) - Callback to track download progress. +### Custom models — `fromCustomConfig` -This method returns a promise, which can resolve to an error or void. +Use [`fromCustomConfig`](../../06-api-reference/classes/ImageSegmentationModule.md#fromcustomconfig) for custom-exported segmentation models with your own label map: + +```typescript +const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; + +const segmentation = await ImageSegmentationModule.fromCustomConfig( + 'https://example.com/custom_model.pte', + { + labelMap: MyLabels, + preprocessorConfig: { + normMean: [0.485, 0.456, 0.406], + normStd: [0.229, 0.224, 0.225], + }, + } +); +``` + +The `preprocessorConfig` is optional. If omitted, no input normalization is applied. The module instance will be typed to your custom label map — `forward` will accept and return keys from `MyLabels`. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. ## Running the model -To run the model, you can use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method on the module object. It accepts three arguments: a required image - can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64), an optional list of classes, and an optional flag whether to resize the output to the original dimensions. +To run the model, use the [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) method. It accepts three arguments: -- The image can be a remote URL, a local file URI, or a base64-encoded image. -- The [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) list contains classes for which to output the full results. By default the list is empty, and only the most probable classes are returned (essentially an arg max for each pixel). Look at [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md) enum for possible classes. -- The [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput) flag specifies whether the output will be rescaled back to the size of the input image. The default is `true`. The model runs inference on a scaled (probably smaller) version of your image (224x224 for the `DEEPLAB_V3_RESNET50`). If you choose to resize, the output will be `number[]` of size `width * height` of your original image. +- [`imageSource`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (required) - The image to segment. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). +- [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - An array of label keys indicating which per-class probability masks to include in the output. Defaults to `[]`. The `ARGMAX` map is always returned regardless. +- [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) (optional) - Whether to resize the output masks to the original input image dimensions. Defaults to `true`. If `false`, returns the raw model output dimensions. :::warning -Setting `resize` to true will make `forward` slower. +Setting `resizeToInput` to `false` will make `forward` faster. ::: -[`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) returns a promise which can resolve either to an error or a dictionary containing number arrays with size depending on [`resizeToInput`](../../06-api-reference/classes/ImageSegmentationModule.md#resizetoinput): +`forward` returns a promise resolving to an object containing: + +- `ARGMAX` - An `Int32Array` where each element is the class index with the highest probability for that pixel. +- For each label included in `classesOfInterest`, a `Float32Array` of per-pixel probabilities for that class. -- For the key [`DeeplabLabel.ARGMAX`](../../06-api-reference/enumerations/DeeplabLabel.md#argmax) the array contains for each pixel an integer corresponding to the class with the highest probability. -- For every other key from [`DeeplabLabel`](../../06-api-reference/enumerations/DeeplabLabel.md), if the label was included in [`classesOfInterest`](../../06-api-reference/classes/ImageSegmentationModule.md#classesofinterest) the dictionary will contain an array of floats corresponding to the probability of this class for every pixel. +The return type narrows based on the labels passed in `classesOfInterest`: + +```typescript +// Only ARGMAX in the result +const result = await segmentation.forward(imageUri); +result.ARGMAX; // Int32Array + +// ARGMAX + requested class masks +const result = await segmentation.forward(imageUri, ['CAT', 'DOG']); +result.ARGMAX; // Int32Array +result.CAT; // Float32Array +result.DOG; // Float32Array +``` ## Managing memory -The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you load the module again. +The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ImageSegmentationModule.md#forward) after [`delete`](../../06-api-reference/classes/ImageSegmentationModule.md#delete) unless you create a new instance. diff --git a/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md b/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md new file mode 100644 index 000000000..912d29f5d --- /dev/null +++ b/docs/docs/06-api-reference/enumerations/SelfieSegmentationLabel.md @@ -0,0 +1,21 @@ +# Enumeration: SelfieSegmentationLabel + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:81](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L81) + +Labels used in the selfie image segmentation model. + +## Enumeration Members + +### BACKGROUND + +> **BACKGROUND**: `1` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:83](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L83) + +--- + +### SELFIE + +> **SELFIE**: `0` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:82](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L82) diff --git a/docs/docs/06-api-reference/type-aliases/LabelEnum.md b/docs/docs/06-api-reference/type-aliases/LabelEnum.md new file mode 100644 index 000000000..9414676dc --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/LabelEnum.md @@ -0,0 +1,8 @@ +# Type Alias: LabelEnum + +> **LabelEnum** = `Readonly`\<`Record`\<`string`, `number` \| `string`\>\> + +Defined in: [packages/react-native-executorch/src/types/common.ts:146](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/common.ts#L146) + +A readonly record mapping string keys to numeric or string values. +Used to represent enum-like label maps for models. diff --git a/docs/docs/06-api-reference/type-aliases/ModelNameOf.md b/docs/docs/06-api-reference/type-aliases/ModelNameOf.md new file mode 100644 index 000000000..e962ab698 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/ModelNameOf.md @@ -0,0 +1,13 @@ +# Type Alias: ModelNameOf\ + +> **ModelNameOf**\<`C`\> = `C`\[`"modelName"`\] + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:45](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L45) + +Extracts the model name from a [ModelSources](ModelSources.md) config object. + +## Type Parameters + +### C + +`C` _extends_ [`ModelSources`](ModelSources.md) diff --git a/docs/docs/06-api-reference/type-aliases/ModelSources.md b/docs/docs/06-api-reference/type-aliases/ModelSources.md new file mode 100644 index 000000000..eefe51641 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/ModelSources.md @@ -0,0 +1,9 @@ +# Type Alias: ModelSources + +> **ModelSources** = \{ `modelName`: `"deeplab-v3"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} \| \{ `modelName`: `"selfie-segmentation"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} \| \{ `modelName`: `"rfdetr"`; `modelSource`: [`ResourceSource`](ResourceSource.md); \} + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:27](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L27) + +Per-model config for [ImageSegmentationModule.fromModelName](../classes/ImageSegmentationModule.md#frommodelname). +Each model name maps to its required fields. +Add new union members here when a model needs extra sources or options. diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md b/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md new file mode 100644 index 000000000..ce957311d --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationConfig.md @@ -0,0 +1,43 @@ +# Type Alias: SegmentationConfig\ + +> **SegmentationConfig**\<`T`\> = `object` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:15](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L15) + +Configuration for a custom segmentation model. + +## Type Parameters + +### T + +`T` _extends_ [`LabelEnum`](LabelEnum.md) + +The [LabelEnum](LabelEnum.md) type for the model. + +## Properties + +### labelMap + +> **labelMap**: `T` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:16](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L16) + +The enum-like object mapping class names to indices. + +--- + +### preprocessorConfig? + +> `optional` **preprocessorConfig**: `object` + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:17](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L17) + +Optional preprocessing parameters. + +#### normMean? + +> `optional` **normMean**: [`Triple`](Triple.md)\<`number`\> + +#### normStd? + +> `optional` **normStd**: [`Triple`](Triple.md)\<`number`\> diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md b/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md new file mode 100644 index 000000000..96939a346 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationLabels.md @@ -0,0 +1,15 @@ +# Type Alias: SegmentationLabels\ + +> **SegmentationLabels**\<`M`\> = `ModelConfigsType`\[`M`\]\[`"labelMap"`\] + +Defined in: [packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts:45](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts#L45) + +Resolves the [LabelEnum](LabelEnum.md) for a given built-in model name. + +## Type Parameters + +### M + +`M` _extends_ [`SegmentationModelName`](SegmentationModelName.md) + +A built-in model name from [SegmentationModelName](SegmentationModelName.md). diff --git a/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md b/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md new file mode 100644 index 000000000..ba7197d06 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/SegmentationModelName.md @@ -0,0 +1,8 @@ +# Type Alias: SegmentationModelName + +> **SegmentationModelName** = [`ModelSources`](ModelSources.md)\[`"modelName"`\] + +Defined in: [packages/react-native-executorch/src/types/imageSegmentation.ts:38](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/imageSegmentation.ts#L38) + +Union of all built-in segmentation model names +(e.g. `'deeplab-v3'`, `'selfie-segmentation'`, `'rfdetr'`). diff --git a/docs/docs/06-api-reference/type-aliases/Triple.md b/docs/docs/06-api-reference/type-aliases/Triple.md new file mode 100644 index 000000000..41fa3d2b1 --- /dev/null +++ b/docs/docs/06-api-reference/type-aliases/Triple.md @@ -0,0 +1,13 @@ +# Type Alias: Triple\ + +> **Triple**\<`T`\> = readonly \[`T`, `T`, `T`\] + +Defined in: [packages/react-native-executorch/src/types/common.ts:153](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/types/common.ts#L153) + +A readonly triple of values, typically used for per-channel normalization parameters. + +## Type Parameters + +### T + +`T` diff --git a/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md b/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md new file mode 100644 index 000000000..36135377c --- /dev/null +++ b/docs/docs/06-api-reference/variables/SELFIE_SEGMENTATION.md @@ -0,0 +1,15 @@ +# Variable: SELFIE_SEGMENTATION + +> `const` **SELFIE_SEGMENTATION**: `object` + +Defined in: [packages/react-native-executorch/src/constants/modelUrls.ts:533](https://github.com/software-mansion/react-native-executorch/blob/ec04754e2ea2ad38fe30c36a9250db47f020a06e/packages/react-native-executorch/src/constants/modelUrls.ts#L533) + +## Type Declaration + +### modelName + +> `readonly` **modelName**: `"selfie-segmentation"` = `'selfie-segmentation'` + +### modelSource + +> `readonly` **modelSource**: `"https://ai.swmansion.com/storage/jc_tests/selfie.pte"` = `SELFIE_SEGMENTATION_MODEL` diff --git a/packages/react-native-executorch/common/rnexecutorch/Log.h b/packages/react-native-executorch/common/rnexecutorch/Log.h index 9381324ab..bb17a53ec 100644 --- a/packages/react-native-executorch/common/rnexecutorch/Log.h +++ b/packages/react-native-executorch/common/rnexecutorch/Log.h @@ -371,6 +371,8 @@ namespace rnexecutorch { enum class LOG_LEVEL : uint8_t { Info, /**< Informational messages that highlight the progress of the application. */ + Warn, /**< Warning messages that a non-critical error occurred during + program execution */ Error, /**< Error events of considerable importance that will prevent normal program execution. */ Debug /**< Detailed information, typically of interest only when diagnosing @@ -384,6 +386,8 @@ inline android_LogPriority androidLogLevel(LOG_LEVEL logLevel) { switch (logLevel) { case LOG_LEVEL::Info: return ANDROID_LOG_INFO; + case LOG_LEVEL::Warn: + return ANDROID_LOG_WARN; case LOG_LEVEL::Error: return ANDROID_LOG_ERROR; case LOG_LEVEL::Debug: @@ -404,6 +408,9 @@ inline void handleIosLog(LOG_LEVEL logLevel, const char *buffer) { case LOG_LEVEL::Info: os_log_info(OS_LOG_DEFAULT, "%{public}s", buffer); return; + case LOG_LEVEL::Warn: + os_log(OS_LOG_DEFAULT, "%{public}s", buffer); + return; case LOG_LEVEL::Error: os_log_error(OS_LOG_DEFAULT, "%{public}s", buffer); return; diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 7a4426e06..7804fb5b7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -5,14 +5,14 @@ #include #include #include -#include +#include #include #include #include #include -#include #include #include +#include #include #include #include @@ -44,7 +44,7 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadImageSegmentation", RnExecutorchInstaller::loadModel< - models::image_segmentation::ImageSegmentation>( + models::image_segmentation::BaseImageSegmentation>( jsiRuntime, jsCallInvoker, "loadImageSegmentation")); jsiRuntime->global().setProperty( @@ -92,6 +92,7 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadOCR", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadOCR")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadVerticalOCR", RnExecutorchInstaller::loadModel( @@ -101,10 +102,12 @@ void RnExecutorchInstaller::injectJSIBindings( *jsiRuntime, "loadSpeechToText", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadSpeechToText")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadTextToSpeechKokoro", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadTextToSpeechKokoro")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadVAD", RnExecutorchInstaller::loadModel< diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index bdf3f97cb..bd29500b0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -217,7 +217,8 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize) { std::pair readImageToTensor(const std::string &path, const std::vector &tensorDims, - bool maintainAspectRatio) { + bool maintainAspectRatio, std::optional normMean, + std::optional normStd) { cv::Mat input = image_processing::readImage(path); cv::Size imageSize = input.size(); @@ -241,6 +242,11 @@ readImageToTensor(const std::string &path, cv::cvtColor(input, input, cv::COLOR_BGR2RGB); + if (normMean.has_value() && normStd.has_value()) { + return {image_processing::getTensorFromMatrix( + tensorDims, input, normMean.value(), normStd.value()), + imageSize}; + } return {image_processing::getTensorFromMatrix(tensorDims, input), imageSize}; } } // namespace image_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h index 27934330a..1b0c10b33 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h @@ -51,5 +51,7 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize); std::pair readImageToTensor(const std::string &path, const std::vector &tensorDims, - bool maintainAspectRatio = false); + bool maintainAspectRatio = false, + std::optional normMean = std::nullopt, + std::optional normStd = std::nullopt); } // namespace rnexecutorch::image_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index c40fa2569..56de2b423 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -19,6 +19,9 @@ using executorch::runtime::Result; class BaseModel { public: + virtual ~BaseModel() = default; + BaseModel(BaseModel &&) = default; + BaseModel &operator=(BaseModel &&) = default; BaseModel( const std::string &modelSource, std::shared_ptr callInvoker, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp new file mode 100644 index 000000000..141ec430e --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp @@ -0,0 +1,238 @@ +#include "BaseImageSegmentation.h" +#include "jsi/jsi.h" + +#include + +#include +#include +#include +#include +#include + +namespace rnexecutorch::models::image_segmentation { + +BaseImageSegmentation::BaseImageSegmentation( + const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + initModelImageSize(); +} + +BaseImageSegmentation::BaseImageSegmentation( + const std::string &modelSource, std::vector normMean, + std::vector normStd, std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + initModelImageSize(); + if (normMean.size() == 3) { + normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); + } else { + log(LOG_LEVEL::Warn, + "normMean must have 3 elements — ignoring provided value."); + } + if (normStd.size() == 3) { + normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); + } else { + log(LOG_LEVEL::Warn, + "normStd must have 3 elements — ignoring provided value."); + } +} + +void BaseImageSegmentation::initModelImageSize() { + auto inputShapes = getAllInputShapes(); + if (inputShapes.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, + "Model seems to not take any input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + "Unexpected model input size, expected at least 2 " + "dimensions but got: " + + std::to_string(modelInputShape.size()) + "."); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); + numModelPixels = modelImageSize.area(); +} + +TensorPtr BaseImageSegmentation::preprocess(const std::string &imageSource, + cv::Size &originalSize) { + auto [inputTensor, origSize] = image_processing::readImageToTensor( + imageSource, getAllInputShapes()[0], false, normMean_, normStd_); + originalSize = origSize; + return inputTensor; +} + +std::shared_ptr BaseImageSegmentation::generate( + std::string imageSource, std::vector allClasses, + std::set> classesOfInterest, bool resize) { + + cv::Size originalSize; + auto inputTensor = preprocess(imageSource, originalSize); + + auto forwardResult = BaseModel::forward(inputTensor); + + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed. " + "Ensure the model input is correct."); + } + + return postprocess(forwardResult->at(0).toTensor(), originalSize, allClasses, + classesOfInterest, resize); +} + +std::shared_ptr BaseImageSegmentation::postprocess( + const Tensor &tensor, cv::Size originalSize, + std::vector &allClasses, + std::set> &classesOfInterest, bool resize) { + + const auto *dataPtr = tensor.const_data_ptr(); + auto resultData = std::span(dataPtr, tensor.numel()); + + // Read output dimensions directly from tensor shape + std::size_t numChannels = + (tensor.dim() >= 3) ? tensor.size(tensor.dim() - 3) : 1; + std::size_t outputH = tensor.size(tensor.dim() - 2); + std::size_t outputW = tensor.size(tensor.dim() - 1); + std::size_t outputPixels = outputH * outputW; + cv::Size outputSize(outputW, outputH); + + // Copy class data directly into OwningArrayBuffers (single copy from span) + std::vector> resultClasses; + resultClasses.reserve(numChannels); + + if (numChannels == 1) { + // Binary segmentation (e.g. selfie segmentation) + auto fg = std::make_shared(resultData.data(), + outputPixels * sizeof(float)); + auto bg = std::make_shared(outputPixels * sizeof(float)); + auto *fgPtr = reinterpret_cast(fg->data()); + auto *bgPtr = reinterpret_cast(bg->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + bgPtr[pixel] = 1.0f - fgPtr[pixel]; + } + resultClasses.push_back(fg); + resultClasses.push_back(bg); + } else { + // Multi-class segmentation (e.g. DeepLab, RF-DETR) + for (std::size_t cl = 0; cl < numChannels; ++cl) { + resultClasses.push_back(std::make_shared( + resultData.data() + cl * outputPixels, outputPixels * sizeof(float))); + } + } + + // Softmax + argmax in class-major order + auto argmax = + std::make_shared(outputPixels * sizeof(int32_t)); + auto *argmaxPtr = reinterpret_cast(argmax->data()); + + if (numChannels == 1) { + auto *fgPtr = reinterpret_cast(resultClasses[0]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 0 : 1; + } + } else { + std::vector maxLogits(outputPixels, + -std::numeric_limits::infinity()); + std::vector sumExp(outputPixels, 0.0f); + + // Pass 1: find per-pixel max and argmax + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + if (clPtr[pixel] > maxLogits[pixel]) { + maxLogits[pixel] = clPtr[pixel]; + argmaxPtr[pixel] = static_cast(cl); + } + } + } + + // Pass 2: subtract max, exp, accumulate sum + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + clPtr[pixel] = std::exp(clPtr[pixel] - maxLogits[pixel]); + sumExp[pixel] += clPtr[pixel]; + } + } + + // Pass 3: normalize by sum + for (std::size_t cl = 0; cl < numChannels; ++cl) { + auto *clPtr = reinterpret_cast(resultClasses[cl]->data()); + for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) { + clPtr[pixel] /= sumExp[pixel]; + } + } + } + + // Filter classes of interest + auto buffersToReturn = std::make_shared>>(); + for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { + if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) { + (*buffersToReturn)[allClasses[cl]] = resultClasses[cl]; + } + } + + // Resize selected classes and argmax + if (resize) { + cv::Mat argmaxMat(outputSize, CV_32SC1, argmax->data()); + cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, + cv::InterpolationFlags::INTER_NEAREST); + argmax = std::make_shared( + argmaxMat.data, originalSize.area() * sizeof(int32_t)); + + for (auto &[label, arrayBuffer] : *buffersToReturn) { + cv::Mat classMat(outputSize, CV_32FC1, arrayBuffer->data()); + cv::resize(classMat, classMat, originalSize); + arrayBuffer = std::make_shared( + classMat.data, originalSize.area() * sizeof(float)); + } + } + + return populateDictionary(argmax, buffersToReturn); +} + +std::shared_ptr BaseImageSegmentation::populateDictionary( + std::shared_ptr argmax, + std::shared_ptr>> + classesToOutput) { + auto promisePtr = std::make_shared>(); + std::future doneFuture = promisePtr->get_future(); + + std::shared_ptr dictPtr = nullptr; + callInvoker->invokeAsync( + [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { + dictPtr = std::make_shared(runtime); + auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); + + auto int32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = + int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) + .getObject(runtime); + dictPtr->setProperty(runtime, "ARGMAX", int32Array); + + for (auto &[classLabel, owningBuffer] : *classesToOutput) { + auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); + + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) + .getObject(runtime); + + dictPtr->setProperty( + runtime, jsi::String::createFromAscii(runtime, classLabel.data()), + float32Array); + } + promisePtr->set_value(); + }); + + doneFuture.wait(); + return dictPtr; +} + +} // namespace rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h similarity index 53% rename from packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h rename to packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h index 301833ce8..f46f41d69 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.h @@ -3,12 +3,12 @@ #include #include #include +#include #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include #include -#include namespace rnexecutorch { namespace models::image_segmentation { @@ -17,32 +17,45 @@ using namespace facebook; using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class ImageSegmentation : public BaseModel { +class BaseImageSegmentation : public BaseModel { public: - ImageSegmentation(const std::string &modelSource, - std::shared_ptr callInvoker); + BaseImageSegmentation(const std::string &modelSource, + std::shared_ptr callInvoker); + + BaseImageSegmentation(const std::string &modelSource, + std::vector normMean, std::vector normStd, + std::shared_ptr callInvoker); + [[nodiscard("Registered non-void function")]] std::shared_ptr - generate(std::string imageSource, + generate(std::string imageSource, std::vector allClasses, std::set> classesOfInterest, bool resize); -private: - std::shared_ptr +protected: + virtual TensorPtr preprocess(const std::string &imageSource, + cv::Size &originalSize); + virtual std::shared_ptr postprocess(const Tensor &tensor, cv::Size originalSize, - std::set> classesOfInterest, + std::vector &allClasses, + std::set> &classesOfInterest, bool resize); + + cv::Size modelImageSize; + std::size_t numModelPixels; + std::optional normMean_; + std::optional normStd_; + std::shared_ptr populateDictionary( std::shared_ptr argmax, std::shared_ptr>> classesToOutput); - static constexpr std::size_t numClasses{ - constants::kDeeplabV3Resnet50Labels.size()}; - cv::Size modelImageSize; - std::size_t numModelPixels; +private: + void initModelImageSize(); }; } // namespace models::image_segmentation -REGISTER_CONSTRUCTOR(models::image_segmentation::ImageSegmentation, std::string, +REGISTER_CONSTRUCTOR(models::image_segmentation::BaseImageSegmentation, + std::string, std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h deleted file mode 100644 index 847c44373..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/Constants.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include -#include - -namespace rnexecutorch::models::image_segmentation::constants { -inline constexpr std::array kDeeplabV3Resnet50Labels = { - "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", - "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", - "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", - "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", - "TVMONITOR"}; -} // namespace rnexecutorch::models::image_segmentation::constants \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp deleted file mode 100644 index a2c1ae865..000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include "ImageSegmentation.h" - -#include - -#include -#include -#include -#include -#include -#include - -namespace rnexecutorch::models::image_segmentation { - -ImageSegmentation::ImageSegmentation( - const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { - auto inputShapes = getAllInputShapes(); - if (inputShapes.size() == 0) { - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Model seems to not take any input tensors."); - } - std::vector modelInputShape = inputShapes[0]; - if (modelInputShape.size() < 2) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimentions " - "but got: %zu.", - modelInputShape.size()); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); - } - modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); - numModelPixels = modelImageSize.area(); -} - -std::shared_ptr ImageSegmentation::generate( - std::string imageSource, - std::set> classesOfInterest, bool resize) { - auto [inputTensor, originalSize] = - image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); - - auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { - throw RnExecutorchError(forwardResult.error(), - "The model's forward function did not succeed. " - "Ensure the model input is correct."); - } - - return postprocess(forwardResult->at(0).toTensor(), originalSize, - classesOfInterest, resize); -} - -std::shared_ptr ImageSegmentation::postprocess( - const Tensor &tensor, cv::Size originalSize, - std::set> classesOfInterest, bool resize) { - - auto dataPtr = static_cast(tensor.const_data_ptr()); - auto resultData = std::span(dataPtr, tensor.numel()); - - // We copy the ET-owned data to jsi array buffers that can be directly - // returned to JS - std::vector> resultClasses; - resultClasses.reserve(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - auto classBuffer = std::make_shared( - &resultData[cl * numModelPixels], numModelPixels * sizeof(float)); - resultClasses.push_back(classBuffer); - } - - // Apply softmax per each pixel across all classes - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - std::vector classValues(numClasses); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - classValues[cl] = - reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - numerical::softmax(classValues); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - reinterpret_cast(resultClasses[cl]->data())[pixel] = - classValues[cl]; - } - } - - // Calculate the maximum class for each pixel - auto argmax = - std::make_shared(numModelPixels * sizeof(int32_t)); - for (std::size_t pixel = 0; pixel < numModelPixels; ++pixel) { - float max = reinterpret_cast(resultClasses[0]->data())[pixel]; - int maxInd = 0; - for (int cl = 1; cl < numClasses; ++cl) { - if (reinterpret_cast(resultClasses[cl]->data())[pixel] > max) { - maxInd = cl; - max = reinterpret_cast(resultClasses[cl]->data())[pixel]; - } - } - reinterpret_cast(argmax->data())[pixel] = maxInd; - } - - auto buffersToReturn = std::make_shared>>(); - for (std::size_t cl = 0; cl < numClasses; ++cl) { - if (classesOfInterest.contains(constants::kDeeplabV3Resnet50Labels[cl])) { - (*buffersToReturn)[constants::kDeeplabV3Resnet50Labels[cl]] = - resultClasses[cl]; - } - } - - // Resize selected classes and argmax - if (resize) { - cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data()); - cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, - cv::InterpolationFlags::INTER_NEAREST); - argmax = std::make_shared( - argmaxMat.data, originalSize.area() * sizeof(int32_t)); - - for (auto &[label, arrayBuffer] : *buffersToReturn) { - cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data()); - cv::resize(classMat, classMat, originalSize); - arrayBuffer = std::make_shared( - classMat.data, originalSize.area() * sizeof(float)); - } - } - return populateDictionary(argmax, buffersToReturn); -} - -std::shared_ptr ImageSegmentation::populateDictionary( - std::shared_ptr argmax, - std::shared_ptr>> - classesToOutput) { - // Synchronize the invoked thread to return when the dict is constructed - auto promisePtr = std::make_shared>(); - std::future doneFuture = promisePtr->get_future(); - - std::shared_ptr dictPtr = nullptr; - callInvoker->invokeAsync( - [argmax, classesToOutput, &dictPtr, promisePtr](jsi::Runtime &runtime) { - dictPtr = std::make_shared(runtime); - auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, argmax); - - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = - int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) - .getObject(runtime); - dictPtr->setProperty(runtime, "ARGMAX", int32Array); - - for (auto &[classLabel, owningBuffer] : *classesToOutput) { - auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); - auto float32Array = - float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - - dictPtr->setProperty( - runtime, jsi::String::createFromAscii(runtime, classLabel.data()), - float32Array); - } - promisePtr->set_value(); - }); - - doneFuture.wait(); - return dictPtr; -} - -} // namespace rnexecutorch::models::image_segmentation diff --git a/packages/react-native-executorch/src/constants/commonVision.ts b/packages/react-native-executorch/src/constants/commonVision.ts new file mode 100644 index 000000000..05eeba759 --- /dev/null +++ b/packages/react-native-executorch/src/constants/commonVision.ts @@ -0,0 +1,4 @@ +import { Triple } from '../types/common'; + +export const IMAGENET1K_MEAN: Triple = [0.485, 0.456, 0.406]; +export const IMAGENET1K_STD: Triple = [0.229, 0.224, 0.225]; diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 6e76e52b7..20cacd050 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -521,8 +521,18 @@ const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${VERSION_TAG}/xnnpa * @category Models - Image Segmentation */ export const DEEPLAB_V3_RESNET50 = { + modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50_MODEL, -}; +} as const; + +const SELFIE_SEGMENTATION_MODEL = `${URL_PREFIX}-selfie-segmentation/${VERSION_TAG}/xnnpack/selfie-segmentation.pte`; +/** + * @category Models - Image segmentation + */ +export const SELFIE_SEGMENTATION = { + modelName: 'selfie-segmentation', + modelSource: SELFIE_SEGMENTATION_MODEL, +} as const; // Image Embeddings const CLIP_VIT_BASE_PATCH32_IMAGE_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${VERSION_TAG}/clip-vit-base-patch32-vision_xnnpack.pte`; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts index c7a352e9a..88831f9aa 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts @@ -1,23 +1,115 @@ -import { useModule } from '../useModule'; -import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegmentationModule'; +import { useState, useEffect } from 'react'; +import { + ImageSegmentationModule, + SegmentationLabels, +} from '../../modules/computer_vision/ImageSegmentationModule'; import { ImageSegmentationProps, ImageSegmentationType, + ModelNameOf, + ModelSources, } from '../../types/imageSegmentation'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; /** * React hook for managing an Image Segmentation model instance. * + * @typeParam C - A {@link ModelSources} config specifying which built-in model to load. + * @param props - Configuration object containing `model` config and optional `preventLoad` flag. + * @returns An object with model state (`error`, `isReady`, `isGenerating`, `downloadProgress`) and a typed `forward` function. + * + * @example + * ```ts + * const { isReady, forward } = useImageSegmentation({ + * model: { modelName: 'deeplab-v3', modelSource: DEEPLAB_V3_RESNET50 }, + * }); + * ``` + * * @category Hooks - * @param ImageSegmentationProps - Configuration object containing `model` source and optional `preventLoad` flag. - * @returns Ready to use Image Segmentation model. */ -export const useImageSegmentation = ({ +export const useImageSegmentation = ({ model, preventLoad = false, -}: ImageSegmentationProps): ImageSegmentationType => - useModule({ - module: ImageSegmentationModule, - model, - preventLoad, - }); +}: ImageSegmentationProps): ImageSegmentationType< + SegmentationLabels> +> => { + const [error, setError] = useState(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [instance, setInstance] = useState + > | null>(null); + + useEffect(() => { + if (preventLoad) return; + + let isMounted = true; + let currentInstance: ImageSegmentationModule> | null = null; + + (async () => { + setDownloadProgress(0); + setError(null); + setIsReady(false); + try { + currentInstance = await ImageSegmentationModule.fromModelName( + model, + (progress) => { + if (isMounted) setDownloadProgress(progress); + } + ); + if (isMounted) { + setInstance(currentInstance); + setIsReady(true); + } + } catch (err) { + if (isMounted) setError(parseUnknownError(err)); + } + })(); + + return () => { + isMounted = false; + currentInstance?.delete(); + }; + + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [model.modelName, model.modelSource, preventLoad]); + + const forward = async >>( + imageSource: string, + classesOfInterest: K[] = [], + resizeToInput: boolean = true + ) => { + if (!isReady || !instance) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' + ); + } + if (isGenerating) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + } + try { + setIsGenerating(true); + return await instance.forward( + imageSource, + classesOfInterest, + resizeToInput + ); + } finally { + setIsGenerating(false); + } + }; + + return { + error, + isReady, + isGenerating, + downloadProgress, + forward, + }; +}; diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts index 39b10249b..1a35885d5 100644 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ b/packages/react-native-executorch/src/hooks/useModule.ts @@ -35,19 +35,24 @@ export const useModule = < useEffect(() => { if (preventLoad) return; + let isMounted = true; + (async () => { setDownloadProgress(0); setError(null); try { setIsReady(false); - await moduleInstance.load(model, setDownloadProgress); - setIsReady(true); + await moduleInstance.load(model, (progress: number) => { + if (isMounted) setDownloadProgress(progress); + }); + if (isMounted) setIsReady(true); } catch (err) { - setError(parseUnknownError(err)); + if (isMounted) setError(parseUnknownError(err)); } })(); return () => { + isMounted = false; moduleInstance.delete(); }; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 8b4035232..32b97fa58 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -3,7 +3,7 @@ import { ResourceFetcher, ResourceFetcherAdapter, } from './utils/ResourceFetcher'; - +import { Triple } from './types/common'; /** * Configuration that goes to the `initExecutorch`. * You can pass either bare React Native or Expo configuration. @@ -36,7 +36,11 @@ export function cleanupExecutorch() { // eslint-disable no-var declare global { var loadStyleTransfer: (source: string) => any; - var loadImageSegmentation: (source: string) => any; + var loadImageSegmentation: ( + source: string, + normMean: Triple | [], + normStd: Triple | [] + ) => any; var loadClassification: (source: string) => any; var loadObjectDetection: (source: string) => any; var loadExecutorchModule: (source: string) => any; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts index ddba7cdb7..f2de6edd7 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts @@ -1,82 +1,202 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; -import { ResourceSource } from '../../types/common'; -import { DeeplabLabel } from '../../types/imageSegmentation'; +import { ResourceSource, LabelEnum } from '../../types/common'; +import { + DeeplabLabel, + ModelNameOf, + ModelSources, + SegmentationConfig, + SegmentationModelName, + SelfieSegmentationLabel, +} from '../../types/imageSegmentation'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; +import { RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; -import { Logger } from '../../common/Logger'; + +const ModelConfigs = { + 'deeplab-v3': { + labelMap: DeeplabLabel, + preprocessorConfig: undefined, + }, + 'selfie-segmentation': { + labelMap: SelfieSegmentationLabel, + preprocessorConfig: undefined, + }, +} as const satisfies Record< + SegmentationModelName, + SegmentationConfig +>; + +/** @internal */ +type ModelConfigsType = typeof ModelConfigs; + +/** + * Resolves the {@link LabelEnum} for a given built-in model name. + * + * @typeParam M - A built-in model name from {@link SegmentationModelName}. + * + * @category Types + */ +export type SegmentationLabels = + ModelConfigsType[M]['labelMap']; + +/** + * @internal + * Resolves the label type: if `T` is a {@link SegmentationModelName}, looks up its labels + * from the built-in config; otherwise uses `T` directly as a {@link LabelEnum}. + */ +type ResolveLabels = + T extends SegmentationModelName ? SegmentationLabels : T; /** - * Module for image segmentation tasks. + * Generic image segmentation module with type-safe label maps. + * Use a model name (e.g. `'deeplab-v3'`) as the generic parameter for built-in models, + * or a custom label enum for custom configs. + * + * @typeParam T - Either a built-in model name (`'deeplab-v3'`, `'selfie-segmentation'`) + * or a custom {@link LabelEnum} label map. * * @category Typescript API */ -export class ImageSegmentationModule extends BaseModule { +export class ImageSegmentationModule< + T extends SegmentationModelName | LabelEnum, +> extends BaseModule { + private labelMap: ResolveLabels; + private allClassNames: string[]; + + private constructor(labelMap: ResolveLabels, nativeModule: unknown) { + super(); + this.labelMap = labelMap; + this.allClassNames = Object.keys(this.labelMap).filter((k) => + isNaN(Number(k)) + ); + this.nativeModule = nativeModule; + } + + // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) + override async load() {} + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. - * To track the download progress, supply a callback function `onDownloadProgressCallback`. + * Creates a segmentation instance for a built-in model. + * The config object is discriminated by `modelName` — each model can require different fields. * - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param config - A {@link ModelSources} object specifying which model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ImageSegmentationModule` instance typed to the chosen model's label map. + * + * @example + * ```ts + * const segmentation = await ImageSegmentationModule.fromModelName({ + * modelName: 'deeplab-v3', + * modelSource: 'https://example.com/deeplab.pte', + * }); + * ``` */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { - try { - const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource - ); - if (!paths?.[0]) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' - ); - } + static async fromModelName( + config: C, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise>> { + const { modelName, modelSource } = config; + const { labelMap } = ModelConfigs[modelName]; + const { preprocessorConfig } = ModelConfigs[ + modelName + ] as SegmentationConfig; + const normMean = preprocessorConfig?.normMean ?? []; + const normStd = preprocessorConfig?.normStd ?? []; + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (!paths?.[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' + ); + } + const nativeModule = global.loadImageSegmentation( + paths[0], + normMean, + normStd + ); + return new ImageSegmentationModule>( + labelMap as ResolveLabels>, + nativeModule + ); + } - this.nativeModule = global.loadImageSegmentation(paths[0]); - } catch (error) { - Logger.error('Load failed:', error); - throw parseUnknownError(error); + /** + * Creates a segmentation instance with a user-provided label map and custom config. + * Use this when working with a custom-exported segmentation model that is not one of the built-in models. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param config - A {@link SegmentationConfig} object with the label map and optional preprocessing parameters. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ImageSegmentationModule` instance typed to the provided label map. + * + * @example + * ```ts + * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; + * const segmentation = await ImageSegmentationModule.fromCustomConfig( + * 'https://example.com/custom_model.pte', + * { labelMap: MyLabels }, + * ); + * ``` + */ + static async fromCustomConfig( + modelSource: ResourceSource, + config: SegmentationConfig, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise> { + const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource); + if (!paths?.[0]) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. Please retry.' + ); } + const normMean = config.preprocessorConfig?.normMean ?? []; + const normStd = config.preprocessorConfig?.normStd ?? []; + const nativeModule = global.loadImageSegmentation( + paths[0], + normMean, + normStd + ); + return new ImageSegmentationModule( + config.labelMap as ResolveLabels, + nativeModule + ); } /** - * Executes the model's forward pass + * Executes the model's forward pass to perform semantic segmentation on the provided image. * - * @param imageSource - a fetchable resource or a Base64-encoded string. - * @param classesOfInterest - an optional list of DeeplabLabel used to indicate additional arrays of probabilities to output (see section "Running the model"). The default is an empty list. - * @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`. - * @returns A dictionary where keys are `DeeplabLabel` and values are arrays of probabilities for each pixel belonging to the corresponding class. + * @param imageSource - A string representing the image source (e.g., a file path, URI, or Base64-encoded string). + * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. + * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. + * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. + * @throws {RnExecutorchError} If the model is not loaded. */ - async forward( + async forward>( imageSource: string, - classesOfInterest?: DeeplabLabel[], - resizeToInput?: boolean - ): Promise>> { + classesOfInterest: K[] = [], + resizeToInput: boolean = true + ): Promise & Record> { if (this.nativeModule == null) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling forward().' + 'The model is currently not loaded.' ); } - const stringDict = await this.nativeModule.generate( - imageSource, - (classesOfInterest || []).map((label) => DeeplabLabel[label]), - resizeToInput ?? true + const classesOfInterestNames = classesOfInterest.map((label) => + String(label) ); - let enumDict: { [key in DeeplabLabel]?: number[] } = {}; + const nativeResult = await this.nativeModule.generate( + imageSource, + this.allClassNames, + classesOfInterestNames, + resizeToInput + ); - for (const key in stringDict) { - if (key in DeeplabLabel) { - const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel]; - enumDict[enumKey] = stringDict[key]; - } - } - return enumDict; + return nativeResult as Record<'ARGMAX', Int32Array> & + Record; } } diff --git a/packages/react-native-executorch/src/types/common.ts b/packages/react-native-executorch/src/types/common.ts index 7b87f31b6..384caa861 100644 --- a/packages/react-native-executorch/src/types/common.ts +++ b/packages/react-native-executorch/src/types/common.ts @@ -136,3 +136,18 @@ export interface TensorPtr { sizes: number[]; scalarType: ScalarType; } + +/** + * A readonly record mapping string keys to numeric or string values. + * Used to represent enum-like label maps for models. + * + * @category Types + */ +export type LabelEnum = Readonly>; + +/** + * A readonly triple of values, typically used for per-channel normalization parameters. + * + * @category Types + */ +export type Triple = readonly [T, T, T]; diff --git a/packages/react-native-executorch/src/types/genericImageSegmentation.ts b/packages/react-native-executorch/src/types/genericImageSegmentation.ts new file mode 100644 index 000000000..e69de29bb diff --git a/packages/react-native-executorch/src/types/imageSegmentation.ts b/packages/react-native-executorch/src/types/imageSegmentation.ts index 02d9eec10..6d79a801d 100644 --- a/packages/react-native-executorch/src/types/imageSegmentation.ts +++ b/packages/react-native-executorch/src/types/imageSegmentation.ts @@ -1,5 +1,47 @@ import { RnExecutorchError } from '../errors/errorUtils'; -import { ResourceSource } from './common'; +import { LabelEnum, Triple, ResourceSource } from './common'; + +/** + * Configuration for a custom segmentation model. + * + * @typeParam T - The {@link LabelEnum} type for the model. + * @property labelMap - The enum-like object mapping class names to indices. + * @property preprocessorConfig - Optional preprocessing parameters. + * @property preprocessorConfig.normMean - Per-channel mean values for input normalization. + * @property preprocessorConfig.normStd - Per-channel standard deviation values for input normalization. + * + * @category Types + */ +export type SegmentationConfig = { + labelMap: T; + preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; +}; + +/** + * Per-model config for {@link ImageSegmentationModule.fromModelName}. + * Each model name maps to its required fields. + * Add new union members here when a model needs extra sources or options. + * + * @category Types + */ +export type ModelSources = + | { modelName: 'deeplab-v3'; modelSource: ResourceSource } + | { modelName: 'selfie-segmentation'; modelSource: ResourceSource }; + +/** + * Union of all built-in segmentation model names + * (e.g. `'deeplab-v3'`, `'selfie-segmentation'`). + * + * @category Types + */ +export type SegmentationModelName = ModelSources['modelName']; + +/** + * Extracts the model name from a {@link ModelSources} config object. + * + * @category Types + */ +export type ModelNameOf = C['modelName']; /** * Labels used in the DeepLab image segmentation model. @@ -28,30 +70,41 @@ export enum DeeplabLabel { SOFA, TRAIN, TVMONITOR, - ARGMAX, // Additional label not present in the model +} + +/** + * Labels used in the selfie image segmentation model. + * + * @category Types + */ +export enum SelfieSegmentationLabel { + SELFIE, + BACKGROUND, } /** * Props for the `useImageSegmentation` hook. * - * @property {Object} model - An object containing the model source. - * @property {ResourceSource} model.modelSource - The source of the image segmentation model binary. + * @typeParam C - A {@link ModelSources} config specifying which built-in model to load. + * @property model - The model config containing `modelName` and `modelSource`. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. * * @category Types */ -export interface ImageSegmentationProps { - model: { modelSource: ResourceSource }; +export interface ImageSegmentationProps { + model: C; preventLoad?: boolean; } /** * Return type for the `useImageSegmentation` hook. - * Manages the state and operations for Computer Vision image segmentation (e.g., DeepLab). + * Manages the state and operations for image segmentation models. + * + * @typeParam L - The {@link LabelEnum} representing the model's class labels. * * @category Types */ -export interface ImageSegmentationType { +export interface ImageSegmentationType { /** * Contains the error object if the model failed to load, download, or encountered a runtime error during segmentation. */ @@ -75,14 +128,14 @@ export interface ImageSegmentationType { /** * Executes the model's forward pass to perform semantic segmentation on the provided image. * @param imageSource - A string representing the image source (e.g., a file path, URI, or base64 string) to be processed. - * @param classesOfInterest - An optional array of `DeeplabLabel` enums. If provided, the model will only return segmentation masks for these specific classes. - * @param resizeToInput - an optional boolean to indicate whether the output should be resized to the original input image dimensions. If `false`, returns the model output without any resizing (see section "Running the model"). Defaults to `true`. - * @returns A Promise that resolves to an object mapping each detected `DeeplabLabel` to its corresponding segmentation mask (represented as a flattened array of numbers). + * @param classesOfInterest - An optional array of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless. + * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`. + * @returns A Promise resolving to an object with an `'ARGMAX'` `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. */ - forward: ( + forward: ( imageSource: string, - classesOfInterest?: DeeplabLabel[], + classesOfInterest?: K[], resizeToInput?: boolean - ) => Promise>>; + ) => Promise & Record>; }