Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/api_doc/generic/attachment.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
imgutils.generic.attachment
=======================================

.. currentmodule:: imgutils.generic.attachment

.. automodule:: imgutils.generic.attachment



Attachment
-----------------------------------------

.. autoclass:: Attachment
    :members: __init__, encoder_model, predict


open_attachment
-----------------------------------------

.. autofunction:: open_attachment



1 change: 1 addition & 0 deletions docs/source/api_doc/generic/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ imgutils.generic
.. toctree::
    :maxdepth: 3

    attachment
    classify
    enhance
    clip
Expand Down
238 changes: 238 additions & 0 deletions imgutils/generic/attachment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
"""
This module provides functionality for handling attachments in machine learning models,
particularly those hosted on Hugging Face's model hub. It includes tools for loading,
managing and making predictions with ONNX models for classification, tagging and regression tasks.

The module provides a caching mechanism for model loading and thread-safe operations
for concurrent access to models and their metadata.

An example of attachment models is `deepghs/eattach_monochrome_experiments <https://huggingface.co/deepghs/eattach_monochrome_experiments>`_.

.. note::
    If you want to train a custom attachment model for taggers,
    take a look at our framework `deepghs/emb_attachments <https://github.com/deepghs/emb_attachments>`_.
"""

import json
import os
from threading import Lock
from typing import Optional, Any, Tuple

import numpy as np
from huggingface_hub import hf_hub_download

from ..utils import open_onnx_model, vreplace, ts_lru_cache


class Attachment:
    """
    Manage an attachment model stored in a Hugging Face model repository.

    An attachment is read from two files located under ``<model_name>/`` in the repository:
    ``meta.json`` (metadata describing the problem) and ``model.onnx`` (the ONNX model itself).
    Both files are fetched lazily on first use and then cached on the instance.

    :meth:`predict` dispatches on ``meta['problem']['type']`` and currently handles
    ``'classification'`` and ``'tagging'``; any other problem type raises :class:`ValueError`.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model (the sub-directory inside the repository)
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Initialize the Attachment instance with repository and model information.

        Nothing is downloaded here; metadata and model are loaded on first access.

        :param repo_id: The Hugging Face repository ID
        :type repo_id: str
        :param model_name: Name of the model
        :type model_name: str
        :param hf_token: Optional Hugging Face authentication token
        :type hf_token: Optional[str]
        """
        self.repo_id = repo_id
        self.model_name = model_name
        self._meta_value = None  # parsed meta.json, filled in by the ``_meta`` property
        self._model = None  # loaded ONNX model, filled in by ``_open_model``

        self._hf_token = hf_token
        # Not used by any method of this class; kept so the attribute set stays unchanged.
        self._global_lock = Lock()
        # Guards the lazy initialisation of both ``_meta_value`` and ``_model``.
        self._model_lock = Lock()

    def _get_hf_token(self) -> Optional[str]:
        """
        Retrieve the Hugging Face authentication token.

        The token given to the constructor takes precedence; when it is empty or ``None``,
        the ``HF_TOKEN`` environment variable is used instead.

        :return: Authentication token if available
        :rtype: Optional[str]
        """
        return self._hf_token or os.environ.get('HF_TOKEN')

    @property
    def _meta(self):
        """
        Load and cache model metadata from the Hugging Face repository.

        The file ``<model_name>/meta.json`` is downloaded and parsed on first access;
        later accesses return the cached value.

        :return: Model metadata as parsed from JSON
        :rtype: dict
        """
        with self._model_lock:
            if self._meta_value is None:
                meta_file = hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename=f'{self.model_name}/meta.json',
                    token=self._get_hf_token(),
                )
                # JSON is UTF-8 by specification; do not depend on the platform's
                # locale encoding, otherwise non-ASCII labels/tags may be misread.
                with open(meta_file, 'r', encoding='utf-8') as f:
                    self._meta_value = json.load(f)

        return self._meta_value

    @property
    def encoder_model(self) -> str:
        """
        Get the encoder model name from metadata.

        :return: Value of the ``encoder_model`` field of the metadata
        :rtype: str
        """
        return self._meta['encoder_model']

    def _open_model(self):
        """
        Load and cache the ONNX model from Hugging Face.

        The file ``<model_name>/model.onnx`` is downloaded and opened with
        ``open_onnx_model`` on first call; later calls return the cached object.

        :return: Loaded ONNX model
        :rtype: object
        """
        with self._model_lock:
            if self._model is None:
                self._model = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename=f'{self.model_name}/model.onnx',
                    token=self._get_hf_token(),
                ))

        return self._model

    def _predict_raw(self, embedding: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Make raw predictions using the model.

        The embedding is fed to the model input named ``input`` and the outputs named
        ``logits`` and ``prediction`` are requested.

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :return: Tuple of logits and predictions
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        model = self._open_model()
        logits, prediction = model.run(['logits', 'prediction'], {'input': embedding})
        return logits, prediction

    def _predict_classification(self, embedding: np.ndarray, fmt: Any = 'top'):
        """
        Make classification predictions.

        Labels are read from ``meta['problem']['labels']``. For every sample the
        following values are offered to ``vreplace`` for building the output:

        * ``scores`` - mapping of label to predicted value
        * ``top`` - tuple ``(top_label, top_score)``
        * ``top_label`` / ``top_score`` - label with the highest predicted value, and that value
        * ``logit`` / ``prediction`` - raw model outputs for the sample

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param fmt: Format specification for output, resolved by ``vreplace``
        :type fmt: Any
        :return: List of formatted prediction results, one per sample
        :rtype: list
        """
        # Keep labels as native Python objects so that the ``scores`` keys and
        # ``top_label`` are plain values instead of NumPy scalars (consistent
        # with the tagging branch, which converts via ``tolist()``).
        labels = list(self._meta['problem']['labels'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
        for logit, pred in zip(logits, prediction):
            scores = dict(zip(labels, pred.tolist()))
            maxidx = int(np.argmax(pred))
            top_label, top_score = labels[maxidx], pred[maxidx].item()
            top = top_label, top_score
            retval.append(vreplace(fmt, {
                'scores': scores,
                'top': top,
                'top_label': top_label,
                'top_score': top_score,
                'logit': logit,
                'prediction': pred,
            }))

        return retval

    def _predict_tagging(self, embedding: np.ndarray, threshold: float = 0.3, fmt: Any = 'tags'):
        """
        Make tagging predictions.

        Tags are read from ``meta['problem']['tags']``. For every sample the
        following values are offered to ``vreplace`` for building the output:

        * ``tags`` - mapping of tag to predicted value, limited to tags whose
          value is greater than or equal to ``threshold``
        * ``logit`` / ``prediction`` - raw model outputs for the sample

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param threshold: Confidence threshold for tag selection
        :type threshold: float
        :param fmt: Format specification for output, resolved by ``vreplace``
        :type fmt: Any
        :return: List of formatted prediction results, one per sample
        :rtype: list
        """
        tags = np.array(self._meta['problem']['tags'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
        for logit, pred in zip(logits, prediction):
            # Boolean mask shared by the values and the tag names.
            selection = pred >= threshold
            pvalues, ptags = pred[selection], tags[selection]
            result = dict(zip(ptags.tolist(), pvalues.tolist()))
            retval.append(vreplace(fmt, {
                'tags': result,
                'logit': logit,
                'prediction': pred,
            }))

        return retval

    def predict(self, embedding: np.ndarray, **kwargs):
        """
        Make predictions based on the problem type declared in the metadata.

        A 1-dimensional embedding is treated as a single sample and a single result
        is returned; a 2-dimensional embedding is treated as a batch and a list of
        results is returned. The input is converted to ``float32`` before inference.

        Supported problem types are ``'classification'`` and ``'tagging'``.

        :param embedding: Input embedding, 1-dimensional or 2-dimensional
        :type embedding: np.ndarray
        :param kwargs: Additional arguments passed to specific prediction methods
            (``fmt`` for classification; ``threshold`` and ``fmt`` for tagging)
        :return: Prediction results in specified format
        :raises ValueError: If embedding shape is invalid or problem type is unknown
        """
        # ``asarray`` also accepts plain sequences and avoids a copy when the
        # input is already a float32 array.
        embedding = np.asarray(embedding, dtype=np.float32)
        if embedding.ndim == 1:
            single = True
            # Add a leading axis so the model always receives a 2-dimensional input.
            embedding = embedding[np.newaxis, ...]
        elif embedding.ndim == 2:
            single = False
        else:
            # Report the shape, not the array contents.
            raise ValueError(f'Unexpected embedding shape - {embedding.shape!r}.')

        problem_type = self._meta['problem']['type']
        if problem_type == 'classification':
            result = self._predict_classification(embedding, **kwargs)
        elif problem_type == 'tagging':
            result = self._predict_tagging(embedding, **kwargs)
        else:
            raise ValueError(f'Unknown problem type - {problem_type!r}.')

        if single:
            result = result[0]
        return result


@ts_lru_cache()
def open_attachment(repo_id: str, model_name: str, hf_token: Optional[str] = None) -> 'Attachment':
    """
    Obtain an :class:`Attachment` for the given repository and model.

    The function is wrapped with ``ts_lru_cache``, so repeated calls with the same
    arguments are expected to return the previously created instance instead of
    building a new one.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    :return: An Attachment instance
    :rtype: Attachment
    """
    attachment = Attachment(repo_id=repo_id, model_name=model_name, hf_token=hf_token)
    return attachment
Loading
Loading