"""
This module provides functionality for handling attachments in machine learning models,
particularly those hosted on Hugging Face's model hub. It includes tools for loading,
managing and making predictions with ONNX models for classification and tagging tasks.

The module provides a caching mechanism for model loading and thread-safe operations
for concurrent access to models and their metadata.

An example of attachment models is
`deepghs/eattach_monochrome_experiments <https://huggingface.co/deepghs/eattach_monochrome_experiments>`_.

.. note::
    If you want to train a custom attachment model for taggers,
    take a look at our framework
    `deepghs/emb_attachments <https://github.com/deepghs/emb_attachments>`_.
"""

import json
import os
from threading import Lock
from typing import Optional, Any, Tuple

import numpy as np
from huggingface_hub import hf_hub_download

from ..utils import open_onnx_model, vreplace, ts_lru_cache


class Attachment:
    """
    A class to manage machine learning model attachments from Hugging Face.

    This class handles model loading, caching, and prediction for the problem
    types currently supported by attachment models: classification and tagging.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Initialize the Attachment instance with repository and model information.
        """
        self.repo_id = repo_id
        self.model_name = model_name
        self._meta_value = None
        self._model = None

        self._hf_token = hf_token
        # Separate locks so a (cheap) metadata download never has to wait
        # behind a (slow) ONNX model load, and vice versa.
        self._global_lock = Lock()
        self._model_lock = Lock()

    def _get_hf_token(self) -> Optional[str]:
        """
        Retrieve the Hugging Face authentication token.

        Checks both instance variable and environment for token presence.

        :return: Authentication token if available
        :rtype: Optional[str]
        """
        return self._hf_token or os.environ.get('HF_TOKEN')

    @property
    def _meta(self):
        """
        Load and cache model metadata from the Hugging Face repository.

        Thread-safe and lazy: the ``meta.json`` file is downloaded at most once.

        :return: Model metadata as a dictionary
        :rtype: dict
        """
        with self._global_lock:
            if self._meta_value is None:
                with open(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename=f'{self.model_name}/meta.json',
                        token=self._get_hf_token(),
                ), 'r') as f:
                    self._meta_value = json.load(f)

        return self._meta_value

    @property
    def encoder_model(self) -> str:
        """
        Get the encoder model name from metadata.

        :return: Name of the encoder model
        :rtype: str
        """
        return self._meta['encoder_model']

    def _open_model(self):
        """
        Load and cache the ONNX model from Hugging Face.

        Thread-safe and lazy: the model is downloaded and opened at most once.

        :return: Loaded ONNX model
        :rtype: object
        """
        with self._model_lock:
            if self._model is None:
                self._model = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename=f'{self.model_name}/model.onnx',
                    token=self._get_hf_token(),
                ))

        return self._model

    def _predict_raw(self, embedding: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Make raw predictions using the model.

        :param embedding: Input embedding array, shape ``(batch, dim)``
        :type embedding: np.ndarray
        :return: Tuple of logits and predictions
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        model = self._open_model()
        logits, prediction = model.run(['logits', 'prediction'], {'input': embedding})
        return logits, prediction

    def _predict_classification(self, embedding: np.ndarray, fmt: Any = 'top'):
        """
        Make classification predictions.

        :param embedding: Input embedding array, shape ``(batch, dim)``
        :type embedding: np.ndarray
        :param fmt: Format specification for output
        :type fmt: Any
        :return: List of formatted prediction results, one per batch item
        :rtype: list
        """
        labels = np.array(self._meta['problem']['labels'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
        for logit, pred in zip(logits, prediction):
            # Use .tolist() so score keys are native ``str``, consistent with
            # ``top_label`` (``.item()``) and with ``_predict_tagging``.
            scores = dict(zip(labels.tolist(), pred.tolist()))
            maxidx = np.argmax(pred)
            top_label, top_score = labels[maxidx].item(), pred[maxidx].item()
            top = top_label, top_score
            retval.append(vreplace(fmt, {
                'scores': scores,
                'top': top,
                'top_label': top_label,
                'top_score': top_score,
                'logit': logit,
                'prediction': pred,
            }))

        return retval

    def _predict_tagging(self, embedding: np.ndarray, threshold: float = 0.3, fmt: Any = 'tags'):
        """
        Make tagging predictions.

        :param embedding: Input embedding array, shape ``(batch, dim)``
        :type embedding: np.ndarray
        :param threshold: Confidence threshold for tag selection
        :type threshold: float
        :param fmt: Format specification for output
        :type fmt: Any
        :return: List of formatted prediction results, one per batch item
        :rtype: list
        """
        tags = np.array(self._meta['problem']['tags'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
        for logit, pred in zip(logits, prediction):
            selection = pred >= threshold
            pvalues, ptags = pred[selection], tags[selection]
            result = dict(zip(ptags.tolist(), pvalues.tolist()))
            retval.append(vreplace(fmt, {
                'tags': result,
                'logit': logit,
                'prediction': pred,
            }))

        return retval

    def predict(self, embedding: np.ndarray, **kwargs):
        """
        Make predictions based on the problem type (classification or tagging).

        A 1-dim embedding is treated as a single sample and a single result is
        returned; a 2-dim embedding is treated as a batch and a list is returned.

        :param embedding: Input embedding array, 1-dim or 2-dim
        :type embedding: np.ndarray
        :param kwargs: Additional arguments passed to specific prediction methods
        :return: Prediction results in specified format
        :raises ValueError: If embedding shape is invalid or problem type is unknown
        """
        embedding = embedding.astype(np.float32)
        if len(embedding.shape) == 1:
            single = True
            embedding = embedding[np.newaxis, ...]
        elif len(embedding.shape) == 2:
            single = False
        else:
            raise ValueError(f'Unexpected embedding shape - {embedding!r}.')

        problem_type = self._meta['problem']['type']
        if problem_type == 'classification':
            result = self._predict_classification(embedding, **kwargs)
        elif problem_type == 'tagging':
            result = self._predict_tagging(embedding, **kwargs)
        else:
            raise ValueError(f'Unknown problem type - {problem_type!r}.')

        if single:
            result = result[0]
        return result


@ts_lru_cache()
def open_attachment(repo_id: str, model_name: str, hf_token: Optional[str] = None) -> 'Attachment':
    """
    Create and cache an Attachment instance.

    This function creates a new Attachment instance or returns a cached one
    if it was previously created with the same parameters.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    :return: An Attachment instance
    :rtype: Attachment
    """
    return Attachment(
        repo_id=repo_id,
        model_name=model_name,
        hf_token=hf_token,
    )
+ :param fmt: Output format specification defining which components to include + :type fmt: Any + :param attachments: Additional model attachments for extended tagging capabilities + :type attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] + + :return: Processed tagging results in the specified format + :rtype: Any + :raises ValueError: If attachment configuration is invalid or incompatible """ + attachments = dict(attachments or {}) + d_attachments: Dict[str, Tuple[Attachment, dict]] = {} + for attach_name, attach_tpl in attachments.items(): + if '/' in attach_name: + raise ValueError(f'Invalid attachment register name, no \'/\' required - {attach_name!r}.') + + if len(attach_tpl) == 2: + (attach_repo_id, attach_model_name), attach_kwargs = attach_tpl, {} + elif len(attach_tpl) == 3: + attach_repo_id, attach_model_name, attach_kwargs = attach_tpl + else: + raise ValueError(f'Invalid attachment tuple for {attach_name!r}, ' + f'2 or 3 elements expected but {attach_tpl!r} found.') + attachment = open_attachment(repo_id=attach_repo_id, model_name=attach_model_name) + expected_encoder_model = f'wdtagger:{MODEL_NAMES[model_name]}' + if attachment.encoder_model != expected_encoder_model: + raise ValueError(f'Attachment encoder model not match, ' + f'{expected_encoder_model!r} expected but {attachment.encoder_model!r} found ' + f'for {attach_name!r}.') + d_attachments[attach_name] = (attachment, attach_kwargs) + assert len(pred.shape) == len(embedding.shape) == 1, \ f'Both pred and embeddings shapes should be 1-dim, ' \ f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.' 
@@ -239,17 +269,27 @@ def _postprocess_embedding( character_res = {x: v.item() for x, v in character_names if v > character_threshold} - return vreplace( - fmt, - { - 'rating': rating, - 'general': general_res, - 'character': character_res, - 'tag': {**general_res, **character_res}, - 'embedding': embedding.astype(np.float32), - 'prediction': pred.astype(np.float32), - } - ) + mapping_values = { + 'rating': rating, + 'general': general_res, + 'character': character_res, + 'tag': {**general_res, **character_res}, + 'embedding': embedding.astype(np.float32), + 'prediction': pred.astype(np.float32), + } + + d_attach_infers = defaultdict(list) + for vname in vnames(fmt): + if '/' in vname and vname.split('/', maxsplit=1)[0] in d_attachments: + attach_name, attach_fmt_name = vname.split('/', maxsplit=1) + d_attach_infers[attach_name].append(attach_fmt_name) + for attach_name, attach_infer_names in d_attach_infers.items(): + attachment, attach_kwargs = d_attachments[attach_name] + attach_infer_values = attachment.predict(embedding=embedding, fmt=attach_infer_names, **attach_kwargs) + attach_mapping_names = [f'{attach_name}/{name}' for name in attach_infer_names] + mapping_values.update(dict(zip(attach_mapping_names, attach_infer_values))) + + return vreplace(fmt, mapping_values) def get_wd14_tags( @@ -261,7 +301,8 @@ def get_wd14_tags( character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, - fmt=('rating', 'general', 'character'), + fmt: Any = ('rating', 'general', 'character'), + attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, ): """ Get tags for an image using WD14 taggers. @@ -287,7 +328,13 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. 
+ :type fmt: Any + :param attachments: Additional model attachments for extended tagging capabilities + :type attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] + :return: Prediction result based on the provided fmt. + :rtype: Any + :raises ValueError: If attachment configuration is invalid or incompatible .. note:: The fmt argument can include the following keys: @@ -310,6 +357,44 @@ def get_wd14_tags( This embedding is valuable for constructing indices that enable rapid querying of images based on visual features within large-scale datasets. + .. note:: + The attachment system allows integration of additional tagging models to extend the base WD14 tagger's capabilities. + Attachments are specified using a dictionary with the following format: + + .. code-block:: python + + attachments = { + 'name': ('repo_id', 'model_name'), # Basic format + 'name': ('repo_id', 'model_name', {'threshold': 0.35}) # With additional parameters when predicting + } + + The ``fmt`` argument can include attachment results using the format 'name/key', where: + + - ``name``: The name specified in the attachments dictionary + - ``key``: The specific output type requested from the attachment + + For example: + + >>> from imgutils.tagging import get_wd14_tags + >>> + >>> # Using an attachment for additional style tagging + >>> results = get_wd14_tags( + ... 'image.jpg', + ... attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, + ... fmt=('general', 'monochrome/scores') + ... ) + >>> + >>> # Results will include both base tags and attachment outputs + >>> print(results) + ( + {'1girl': 0.99, ...}, + {'monochrome': 0.999, 'normal': 0.001}, + ) + + Multiple attachments can be used simultaneously, and each attachment can provide multiple output types + through its fmt specification. Ensure that attachment models are compatible with the base WD14 model's + embedding format. 
+ Example: Here are some images for example @@ -356,6 +441,7 @@ def get_wd14_tags( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, + attachments=attachments, ) @@ -371,7 +457,8 @@ def convert_wd14_emb_to_prediction( character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, - fmt=('rating', 'general', 'character'), + fmt: Any = ('rating', 'general', 'character'), + attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, denormalize: bool = False, denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, ): @@ -397,12 +484,16 @@ def convert_wd14_emb_to_prediction( :param drop_overlap: Remove overlapping tags to reduce redundancy :type drop_overlap: bool :param fmt: Specify return format structure for predictions, default is ``('rating', 'general', 'character')``. - :type fmt: tuple + :type fmt: Any + :param attachments: Additional model attachments for extended tagging capabilities + :type attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] :param denormalize: Whether to denormalize the embedding before prediction :type denormalize: bool :param denormalizer_name: Name of the denormalizer to use if denormalization is enabled :type denormalizer_name: str :return: For single embeddings: prediction result based on fmt. For batches: list of prediction results. + :rtype: Any + :raises ValueError: If attachment configuration is invalid or incompatible .. note:: Only the embeddings not get normalized can be converted to understandable prediction result. 
@@ -453,6 +544,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, + attachments=attachments, ) else: return [ @@ -467,6 +559,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, + attachments=attachments, ) for pred_item, emb_item in zip(pred, emb) ] diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index 289d20efadb..606037abbe4 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from hbutils.testing import tmatrix +from imgutils.generic.attachment import open_attachment from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model from test.testings import get_testfile @@ -13,6 +15,7 @@ def _release_model_after_run(): finally: _get_wd14_model.cache_clear() _open_denormalize_model.cache_clear() + open_attachment.cache_clear() @pytest.mark.unittest @@ -229,3 +232,91 @@ def test_denormalize_wd14_emb_multiple(self, files): assert rating == pytest.approx(expected_rating, abs=1e-2) assert general == pytest.approx(expected_general, abs=1e-2) assert character == pytest.approx(expected_character, abs=1e-2) + + @pytest.mark.parametrize(*tmatrix({ + ('type_', 'file'): [ + ('monochrome', '6130053.jpg'), + ('monochrome', '6125854(第 3 个复件).jpg'), + ('monochrome', '5221834.jpg'), + ('monochrome', '1951253.jpg'), + ('monochrome', '4879658.jpg'), + ('monochrome', '80750471_p3_master1200.jpg'), + + ('normal', '54566940_p0_master1200.jpg'), + ('normal', '60817155_p18_master1200.jpg'), + ('normal', '4945494.jpg'), + ('normal', '4008375.jpg'), + ('normal', '2416278.jpg'), + ('normal', '842709.jpg') + ], + }, mode='matrix')) + def test_get_wd14_tags_with_attachments(self, type_, file): + filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) + scores, (top_label, top_score) 
= get_wd14_tags( + filename, + fmt=('monochrome/scores', 'monochrome/top'), + attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, + ) + assert scores[type_] >= 0.5 + assert top_label == type_ + assert top_score >= 0.5 + + @pytest.mark.parametrize(*tmatrix({ + ('type_', 'file'): [ + ('monochrome', '6130053.jpg'), + ('monochrome', '6125854(第 3 个复件).jpg'), + ('monochrome', '5221834.jpg'), + ('monochrome', '1951253.jpg'), + ('monochrome', '4879658.jpg'), + ('monochrome', '80750471_p3_master1200.jpg'), + + ('normal', '54566940_p0_master1200.jpg'), + ('normal', '60817155_p18_master1200.jpg'), + ('normal', '4945494.jpg'), + ('normal', '4008375.jpg'), + ('normal', '2416278.jpg'), + ('normal', '842709.jpg') + ], + }, mode='matrix')) + def test_get_wd14_tags_with_attachments_extra_cfg(self, type_, file): + filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) + scores, (top_label, top_score) = get_wd14_tags( + filename, + fmt=('monochrome/scores', 'monochrome/top'), + attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1', {})}, + ) + assert scores[type_] >= 0.5 + assert top_label == type_ + assert top_score >= 0.5 + + @pytest.mark.parametrize(*tmatrix({ + ('type_', 'file'): [ + ('monochrome', '6130053.jpg'), + ('normal', '54566940_p0_master1200.jpg'), + ], + }, mode='matrix')) + def test_get_wd14_tags_with_attachments_invalid_attachment_name(self, type_, file): + filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) + with pytest.raises(ValueError): + scores, (top_label, top_score) = get_wd14_tags( + filename, + fmt=('monochrome/t/scores', 'monochrome/t/top'), + attachments={'monochrome/t': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, + ) + _ = top_label, top_score + + @pytest.mark.parametrize(*tmatrix({ + ('type_', 'file'): [ + ('monochrome', '6130053.jpg'), + ('normal', '54566940_p0_master1200.jpg'), + ], + }, mode='matrix')) + def 
test_get_wd14_tags_with_attachments_invalid_attachment_config(self, type_, file): + filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) + with pytest.raises(ValueError): + scores, (top_label, top_score) = get_wd14_tags( + filename, + fmt=('monochrome/scores', 'monochrome/top'), + attachments={'monochrome': ('deepghs/eattach_monochrome_experiments',)}, + ) + _ = top_label, top_score