From e7df6bbefc6de15ae717146db0497d166289751f Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 18:34:12 +0800 Subject: [PATCH 01/18] dev(narugo): guess what is this --- requirements-zoo.txt | 4 +- zoo/ptagger/__init__.py | 0 zoo/ptagger/model.py | 154 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 zoo/ptagger/__init__.py create mode 100644 zoo/ptagger/model.py diff --git a/requirements-zoo.txt b/requirements-zoo.txt index 515110506c7..d58c1794a1d 100644 --- a/requirements-zoo.txt +++ b/requirements-zoo.txt @@ -27,4 +27,6 @@ tabulate git+https://github.com/deepghs/waifuc.git@main#egg=waifuc pyquery httpx -onnxslim==0.1.32 \ No newline at end of file +onnxslim==0.1.32 +procslib +unibox \ No newline at end of file diff --git a/zoo/ptagger/__init__.py b/zoo/ptagger/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py new file mode 100644 index 00000000000..698a05af2ae --- /dev/null +++ b/zoo/ptagger/model.py @@ -0,0 +1,154 @@ +import json +import os.path + +import numpy as np +import onnx +import onnxruntime +import torch +from PIL import Image +from ditk import logging +from hbutils.system import TemporaryDirectory +from hfutils.operate import get_hf_fs, get_hf_client +from procslib import get_model +from procslib.models.pixai_tagger import PixAITaggerInference +from thop import profile, clever_format +from torch import nn + +from imgutils.preprocess import parse_torchvision_transforms +from test.testings import get_testfile +from zoo.utils import onnx_optimize + + +class ModuleWrapper(nn.Module): + def __init__(self, base_module: nn.Module, classifier: nn.Module): + super().__init__() + self.base_module = base_module + self.classifier = classifier + + self._output_features = None + self._register_hook() + + def _register_hook(self): + def hook_fn(module, input_tensor, output_tensor): + assert isinstance(input_tensor, tuple) and len(input_tensor) == 1 + input_tensor = input_tensor[0] + self._output_features = input_tensor + + self.classifier.register_forward_hook(hook_fn) + + def forward(self, x: torch.Tensor): + logits = self.base_module(x) + preds = torch.sigmoid(logits) + + if self._output_features is None: + raise RuntimeError("Target module did not receive any input during forward pass") + features, self._output_features = self._output_features, None + assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}' + features = torch.flatten(features, start_dim=1) + + return features, logits, preds + + +def load_model(model_name: str = "tagger_v_2_2_7"): + model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') + infer_model = model.model + transforms = model.transform + return infer_model, transforms + + +def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): + hf_fs = get_hf_fs() + hf_client = get_hf_client() + + os.makedirs(export_dir, exist_ok=True) + + model, transforms = load_model(model_name) + image = Image.open(get_testfile('genshin_post.jpg')) + dummy_input = transforms(image).unsqueeze(0) + logging.info(f'Dummy input size: {dummy_input.shape!r}') + + with torch.no_grad(): + expected_dummy_output = model(dummy_input) + logging.info(f'Dummy output size: {expected_dummy_output.shape!r}') + + classifier = model.get_classifier() + classifier_position = None + for name, module in model.named_modules(): + if module is classifier: + classifier_position = name + break + if not classifier_position: + raise RuntimeError(f'No classifier module found in model {type(model)}.') + logging.info(f'Classifier module found at {classifier_position!r}:\n{classifier}') + + wrapped_model = ModuleWrapper(model, classifier=classifier) + with torch.no_grad(): + conv_features, conv_output, conv_preds = wrapped_model(dummy_input) + logging.info(f'Shape of embeddings: {conv_features.shape!r}') + logging.info(f'Sample of expected logits:\n' + f'{expected_dummy_output[:, -10:]}\n' + f'Sample of actual logits:\n' + f'{conv_output[:, -10:]}') + close_matrix = torch.isclose(expected_dummy_output, conv_output, atol=1e-3) + ratio = close_matrix.type(torch.float32).mean() + logging.info(f'{ratio * 100:.2f}% of the logits value are the same.') + assert close_matrix.all(), 'Not all values can match.' + + logging.info('Profiling model ...') + macs, params = profile(model, inputs=(dummy_input,)) + s_macs, s_params = clever_format([macs, params], "%.1f") + logging.info(f'Params: {s_params}, FLOPs: {s_macs}') + + with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f: + json.dump({ + 'stages': parse_torchvision_transforms(transforms), + }, f, indent=4, sort_keys=True) + + onnx_filename = os.path.join(export_dir, 'model.onnx') + with TemporaryDirectory() as td: + temp_model_onnx = os.path.join(td, 'model.onnx') + logging.info(f'Exporting temporary ONNX model to {temp_model_onnx!r} ...') + torch.onnx.export( + wrapped_model, + dummy_input, + temp_model_onnx, + input_names=['input'], + output_names=['embedding', 'logits', 'output'], + dynamic_axes={ + 'input': {0: 'batch_size'}, + 'embedding': {0: 'batch_size'}, + 'logits': {0: 'batch_size'}, + 'output': {0: 'batch_size'}, + }, + opset_version=14, + do_constant_folding=True, + export_params=True, + verbose=False, + custom_opsets=None, + ) + + model = onnx.load(temp_model_onnx) + if not no_optimize: + logging.info('Optimizing onnx model ...') + model = onnx_optimize(model) + + output_model_dir, _ = os.path.split(onnx_filename) + if output_model_dir: + os.makedirs(output_model_dir, exist_ok=True) + logging.info(f'Complete model saving to {onnx_filename!r} ...') + onnx.save(model, onnx_filename) + + session = onnxruntime.InferenceSession(onnx_filename) + o_logits, o_embeddings = session.run(['logits', 'embedding'], {'input': dummy_input.numpy()}) + emb_1 = o_embeddings / np.linalg.norm(o_embeddings, axis=-1, keepdims=True) + emb_2 = conv_features.numpy() / np.linalg.norm(conv_features.numpy(), axis=-1, keepdims=True) + emb_sims = (emb_1 * emb_2).sum() + logging.info(f'Similarity of the embeddings is {emb_sims:.5f}.') + assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.' + + +if __name__ == '__main__': + logging.try_init_root(level=logging.INFO) + extract( + export_dir='test_ex', + ) From 32b8388be4c1ca0e6e37861e36926c69b0ed854f Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 18:40:24 +0800 Subject: [PATCH 02/18] dev(narugo): unlimit timm version --- requirements-zoo.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-zoo.txt b/requirements-zoo.txt index d58c1794a1d..cb696283cfc 100644 --- a/requirements-zoo.txt +++ b/requirements-zoo.txt @@ -13,7 +13,7 @@ tensorboard einops thop accelerate -timm~=0.6.13 +timm ftfy regex git+https://github.com/openai/CLIP.git From 6f5b05d1be36404e2c38b8b0ea475eaf3a5f4806 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 18:42:17 +0800 Subject: [PATCH 03/18] dev(narugo): unlimit timm version --- requirements-zoo.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-zoo.txt b/requirements-zoo.txt index cb696283cfc..2eb15ee5f08 100644 --- a/requirements-zoo.txt +++ b/requirements-zoo.txt @@ -1,4 +1,4 @@ -torch<2 +torch lpips matplotlib torchvision From a0f7177ab3238fa76516bc6080cc26d60f288f24 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 19:54:54 +0800 Subject: [PATCH 04/18] dev(narugo): addd x --- zoo/ptagger/model.py | 182 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 177 insertions(+), 5 deletions(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 698a05af2ae..480d0bfd6a8 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -1,14 +1,21 @@ +import copy +import datetime import json import os.path import numpy as np import onnx import onnxruntime +import pandas as pd import torch from PIL import Image from ditk import logging +from hbutils.string import plural_word from hbutils.system import TemporaryDirectory -from hfutils.operate import get_hf_fs, get_hf_client +from hfutils.cache import delete_detached_cache +from hfutils.operate import get_hf_fs, get_hf_client, upload_directory_as_directory +from hfutils.repository import hf_hub_repo_file_url +from natsort import natsorted from procslib import get_model from procslib.models.pixai_tagger import PixAITaggerInference from thop import profile, clever_format @@ -17,6 +24,7 @@ from imgutils.preprocess import parse_torchvision_transforms from test.testings import get_testfile from zoo.utils import onnx_optimize +from zoo.wd14.tags import _get_tag_by_name class ModuleWrapper(nn.Module): @@ -53,7 +61,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') infer_model = model.model transforms = model.transform - return infer_model, transforms + return model, infer_model, transforms def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): @@ -62,7 +70,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) - model, transforms = load_model(model_name) + raw_model, model, transforms = load_model(model_name) + raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) logging.info(f'Dummy input size: {dummy_input.shape!r}') @@ -99,11 +108,75 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') + src_repo_id = raw_model.model_version_map[model_name]['repo_id'] + src_model_filename = raw_model.model_version_map[model_name]['ckpt_name'] + + with open(os.path.join(export_dir, 'meta.json'), 'w') as f: + json.dump({ + 'num_classes': conv_preds.shape[-1], + 'num_features': conv_features.shape[-1], + 'params': params, + 'flops': macs, + 'name': model_name, + 'model_cls': type(model).__name__, + 'input_size': dummy_input.shape[2], + 'repo_id': src_repo_id, + 'model_filename': src_model_filename, + 'created_at': hf_client.get_paths_info( + repo_id=src_repo_id, + repo_type='model', + paths=[src_model_filename], + expand=True + )[0].last_commit.date.timestamp(), + }, f, indent=4, sort_keys=True) + + logging.info(f'Writing transforms:\n{transforms}') with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f: json.dump({ 'stages': parse_torchvision_transforms(transforms), }, f, indent=4, sort_keys=True) + df_p_tags = pd.read_csv(hf_client.hf_hub_download( + repo_id='deepghs/site_tags', + repo_type='dataset', + filename='danbooru.donmai.us/tags.csv' + )) + logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}') + d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')} + + num_classes = raw_model.model_version_map[model_name]['num_classes'] + logging.info(f'Num classes: {num_classes!r}') + d_tags = {v: k for k, v in raw_model.tag_map.items()} + r_tags = [] + for i in range(num_classes): + category = 0 if i < raw_model.gen_tag_count else 4 + if (category, d_tags[i]) in d_p_tags: + tag_id = d_p_tags[(category, d_tags[i])]['id'] + count = d_p_tags[(category, d_tags[i])]['post_count'] + else: + logging.warning(f'Cannot find tag {d_tags[i]!r}, category: {category!r}.') + tag_info = _get_tag_by_name(d_tags[i]) + if tag_info['name'] != d_tags[i]: + logging.warning(f'Not found matching tags for {d_tags[i]!r}, will be ignored.') + tag_id = -1 + count = -1 + else: + logging.info(f'Tag info found from danbooru - {tag_info!r}.') + tag_id = tag_info['id'] + count = tag_info['post_count'] + r_tags.append({ + 'id': i, + 'tag_id': tag_id, + 'name': d_tags[i], + 'category': category, + 'count': count, + }) + df_tags = pd.DataFrame(r_tags) + tags_file = os.path.join(export_dir, 'selected_tags.csv') + logging.info(f'Tags List:\n{df_tags}\n' + f'Saving to {tags_file!r} ...') + df_tags.to_csv(tags_file, index=False) + onnx_filename = os.path.join(export_dir, 'model.onnx') with TemporaryDirectory() as td: temp_model_onnx = os.path.join(td, 'model.onnx') @@ -147,8 +220,107 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.' +def sync(repository: str = 'onopix/pixai-tagger-onnx'): + hf_client = get_hf_client() + hf_fs = get_hf_fs() + delete_detached_cache() + if not hf_client.repo_exists(repo_id=repository, repo_type='model'): + hf_client.create_repo(repo_id=repository, repo_type='model', private=True) + attr_lines = hf_fs.read_text(f'{repository}/.gitattributes').splitlines(keepends=False) + attr_lines.append('*.json filter=lfs diff=lfs merge=lfs -text') + attr_lines.append('*.csv filter=lfs diff=lfs merge=lfs -text') + hf_fs.write_text(f'{repository}/.gitattributes', os.linesep.join(attr_lines)) + + if hf_client.file_exists( + repo_id=repository, + repo_type='model', + filename='models.parquet', + ): + df_models = pd.read_parquet(hf_client.hf_hub_download( + repo_id=repository, + repo_type='model', + filename='models.parquet', + )) + d_models = {item['name']: item for item in df_models.to_dict('records')} + else: + d_models = {} + + for model_name in ["tagger_v_2_2_7"]: + with TemporaryDirectory() as upload_dir: + logging.info(f'Exporting model {model_name!r} ...') + os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) + try: + extract( + export_dir=os.path.join(upload_dir, model_name), + no_optimize=False, + ) + except Exception: + logging.exception(f'Error when exporting {model_name!r}, skipped.') + continue + + with open(os.path.join(upload_dir, model_name, 'meta.json'), 'r') as f: + meta_info = json.load(f) + c_meta_info = copy.deepcopy(meta_info) + d_models[meta_info['name']] = c_meta_info + + df_models = pd.DataFrame(list(d_models.values())) + df_models = df_models.sort_values(by=['created_at'], ascending=False) + df_models.to_parquet(os.path.join(upload_dir, 'models.parquet'), index=False) + + with open(os.path.join(upload_dir, 'README.md'), 'w') as f: + print('---', file=f) + print('pipeline_tag: image-classification', file=f) + print('base_model:', file=f) + for rid in natsorted(set(df_models['repo_id'][:100])): + print(f'- {rid}', file=f) + print('language:', file=f) + print('- en', file=f) + print('tags:', file=f) + print('- timm', file=f) + print('- image', file=f) + print('- dghs-imgutils', file=f) + print('library_name: dghs-imgutils', file=f) + print('---', file=f) + print('', file=f) + + print('ONNX export version from [TIMM](https://huggingface.co/timm).', file=f) + print('', file=f) + + print(f'# Models', file=f) + print(f'', file=f) + + df_shown = pd.DataFrame([ + { + "Name": f'[{item["name"]}]({hf_hub_repo_file_url(repo_id=item["repo_id"], repo_type="model", path=item["model_filename"])})', + 'Params': clever_format(item["params"], "%.1f"), + 'Flops': clever_format(item["flops"], "%.1f"), + 'Input Size': item['input_size'], + "Features": item['num_features'], + "Classes": item['num_classes'], + 'Model': item['model_cls'], + 'Created At': datetime.datetime.fromtimestamp(item['created_at']).strftime('%Y-%m-%d'), + 'created_at': item['created_at'], + } + for item in df_models.to_dict('records') + ]) + df_shown = df_shown.sort_values(by=['created_at'], ascending=[False]) + del df_shown['created_at'] + print(f'{plural_word(len(df_shown), "ONNX model")} exported in total.', file=f) + print(f'', file=f) + print(df_shown.to_markdown(index=False), file=f) + print(f'', file=f) + + upload_directory_as_directory( + repo_id=repository, + repo_type='model', + local_directory=upload_dir, + path_in_repo='.', + message=f'Export model {model_name!r}', + ) + + if __name__ == '__main__': logging.try_init_root(level=logging.INFO) - extract( - export_dir='test_ex', + sync( + repository='onopix/pixai-tagger-onnx' ) From 8a973bd2680ca82a9cc72f3e56f68701e3bbbc20 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 19:55:29 +0800 Subject: [PATCH 05/18] dev(narugo): addd x --- zoo/ptagger/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 480d0bfd6a8..d32caf4b4ff 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -283,7 +283,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): print('---', file=f) print('', file=f) - print('ONNX export version from [TIMM](https://huggingface.co/timm).', file=f) + print('PixAI Tagger ONNX Exported Version.', file=f) print('', file=f) print(f'# Models', file=f) From 91cee2e48311e8fc5dab472e2c2d98d2b25a8beb Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 20:05:44 +0800 Subject: [PATCH 06/18] dev(narugo): addd x --- zoo/ptagger/model.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index d32caf4b4ff..04134323f79 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -103,6 +103,19 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo logging.info(f'{ratio * 100:.2f}% of the logits value are the same.') assert close_matrix.all(), 'Not all values can match.' + matrix_data_file = os.path.join(export_dir, 'matrix.npz') + bias = classifier.bias.detach().numpy() + weight = classifier.weight.detach().numpy().T + logging.info(f'Saving matrix data file to {matrix_data_file!r}, ' + f'bias: {bias.dtype!r}{bias.shape!r}, weight: {weight.dtype!r}{weight.shape!r}.') + np.savez( + matrix_data_file, + bias=bias, + weight=weight, + ) + expected_logits = conv_features.detach().numpy() @ weight + bias + np.testing.assert_allclose(conv_output.detach().numpy(), expected_logits, rtol=1e-03, atol=1e-05) + logging.info('Profiling model ...') macs, params = profile(model, inputs=(dummy_input,)) s_macs, s_params = clever_format([macs, params], "%.1f") From 22085242921cabac668d61270620cde06ccad1d2 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 20:20:22 +0800 Subject: [PATCH 07/18] dev(narugo): add infer code --- imgutils/tagging/__init__.py | 1 + imgutils/tagging/pixai.py | 230 +++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 imgutils/tagging/pixai.py diff --git a/imgutils/tagging/__init__.py b/imgutils/tagging/__init__.py index 297c1ae4d84..33b17be9e96 100644 --- a/imgutils/tagging/__init__.py +++ b/imgutils/tagging/__init__.py @@ -17,4 +17,5 @@ from .mldanbooru import get_mldanbooru_tags from .order import sort_tags from .overlap import drop_overlap_tags +from .pixai import get_pixai_tags from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb diff --git a/imgutils/tagging/pixai.py b/imgutils/tagging/pixai.py new file mode 100644 index 00000000000..d963cf8a9f6 --- /dev/null +++ b/imgutils/tagging/pixai.py @@ -0,0 +1,230 @@ +import json +from typing import List, Tuple, Any + +import numpy as np +import pandas as pd +from huggingface_hub import hf_hub_download + +from imgutils.data import load_image, ImageTyping +from imgutils.preprocess import create_pillow_transforms +from imgutils.tagging.format import remove_underline +from imgutils.tagging.overlap import drop_overlap_tags +from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache + +EXP_REPO = 'onopix/pixai-tagger-onnx' +MODEL_FILENAME = "model.onnx" +LABEL_FILENAME = "selected_tags.csv" + +_DEFAULT_MODEL_NAME = 'tagger_v_2_2_7' + + +@ts_lru_cache() +def _get_pixai_model(model_name): + """ + Load an ONNX model from the Hugging Face Hub. + + :param model_name: The name of the model to load. + :type model_name: str + :return: The loaded ONNX model. + :rtype: ONNXModel + """ + return open_onnx_model(hf_hub_download( + repo_id=EXP_REPO, + filename=f'{model_name}/model.onnx', + )) + + +@ts_lru_cache() +def _get_pixai_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int]]: + """ + Get labels for the pixai model. + + :param model_name: The name of the model. + :type model_name: str + :param no_underline: If True, replaces underscores in tag names with spaces. + :type no_underline: bool + :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories. + :rtype: Tuple[List[str], List[int], List[int]] + """ + df = pd.read_csv(hf_hub_download( + repo_id=EXP_REPO, + filename=f'{model_name}/selected_tags.csv', + )) + name_series = df["name"] + if no_underline: + name_series = name_series.map(remove_underline) + tag_names = name_series.tolist() + + general_indexes = list(np.where(df["category"] == 0)[0]) + character_indexes = list(np.where(df["category"] == 4)[0]) + return tag_names, general_indexes, character_indexes + + +@ts_lru_cache() +def _get_pixai_weights(model_name): + """ + Load the weights for a pixai model. + + :param model_name: The name of the model. + :type model_name: str + :return: The loaded weights. + :rtype: numpy.ndarray + """ + return np.load(hf_hub_download( + repo_id=EXP_REPO, + filename=f'{model_name}/matrix.npz', + )) + + +@ts_lru_cache() +def _open_preprocess_transforms(model_name: str): + with open(hf_hub_download( + repo_id=EXP_REPO, + filename=f'{model_name}/preprocess.json', + )) as f: + return create_pillow_transforms(json.load(f)['stages']) + + +def _prepare_image_for_tagging(image: ImageTyping, model_name: str): + """ + Prepare an image for tagging by resizing and padding it. + + :param image: The input image. + :type image: ImageTyping + :param model_name: Name of the model. + :type model_name: str + :return: The prepared image as a numpy array. + :rtype: numpy.ndarray + """ + image = load_image(image, force_background='white', mode='RGB') + image_array = _open_preprocess_transforms(model_name)(image) + return np.expand_dims(image_array, axis=0) + + +def _postprocess_embedding( + pred, embedding, logit, + model_name: str = _DEFAULT_MODEL_NAME, + general_threshold: float = 0.15, + character_threshold: float = 0.7, + no_underline: bool = False, + drop_overlap: bool = False, + fmt: Any = ('general', 'character'), +): + """ + Post-process the embedding and prediction results. + + :param pred: The prediction array. + :type pred: numpy.ndarray + :param embedding: The embedding array. + :type embedding: numpy.ndarray + :param logit: The logit array. + :type logit: numpy.ndarray + :param model_name: The name of the model used. + :type model_name: str + :param general_threshold: Threshold for general tags. + :type general_threshold: float + :param character_threshold: Threshold for character tags. + :type character_threshold: float + :param no_underline: Whether to remove underscores from tag names. + :type no_underline: bool + :param drop_overlap: Whether to drop overlapping tags. + :type drop_overlap: bool + :param fmt: The format of the output. + :type fmt: Any + :return: The post-processed results. + """ + assert len(pred.shape) == len(embedding.shape) == 1, \ + f'Both pred and embeddings shapes should be 1-dim, ' \ + f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.' + tag_names, general_indexes, character_indexes = _get_pixai_labels(model_name, no_underline) + labels = list(zip(tag_names, pred.astype(float))) + + general_names = [labels[i] for i in general_indexes] + general_res = {x: v.item() for x, v in general_names if v > general_threshold} + if drop_overlap: + general_res = drop_overlap_tags(general_res) + + character_names = [labels[i] for i in character_indexes] + character_res = {x: v.item() for x, v in character_names if v > character_threshold} + + return vreplace( + fmt, + { + 'general': general_res, + 'character': character_res, + 'tag': {**general_res, **character_res}, + 'embedding': embedding.astype(np.float32), + 'prediction': pred.astype(np.float32), + 'logit': logit.astype(np.float32), + } + ) + + +def get_pixai_tags( + image: ImageTyping, + model_name: str = _DEFAULT_MODEL_NAME, + general_threshold: float = 0.15, + character_threshold: float = 0.7, + no_underline: bool = False, + drop_overlap: bool = False, + fmt: Any = ('general', 'character'), +): + """ + Get tags for an image using pixai taggers. + + :param image: The input image. + :type image: ImageTyping + :param model_name: The name of the model to use. + :type model_name: str + :param general_threshold: The threshold for general tags. + :type general_threshold: float + :param character_threshold: The threshold for character tags. + :type character_threshold: float + :param no_underline: If True, replaces underscores in tag names with spaces. + :type no_underline: bool + :param drop_overlap: If True, drops overlapping tags. + :type drop_overlap: bool + :param fmt: Return format, default is ``('general', 'character')``. + ``embedding`` is also supported for feature extraction. + :type fmt: Any + :return: Prediction result based on the provided fmt. + + .. note:: + The fmt argument can include the following keys: + + - ``rating``: a dict containing ratings and their confidences + - ``general``: a dict containing general tags and their confidences + - ``character``: a dict containing character tags and their confidences + - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences + - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization + - ``logit``: a 1-dim logit of image, before softmax. + - ``prediction``: a 1-dim prediction result of image + + You can extract embedding of the given image with the follwing code + + >>> from imgutils.tagging import get_pixai_tags + >>> + >>> embedding = get_pixai_tags('pixai/1.jpg', fmt='embedding') + >>> embedding.shape + (1024, ) + + This embedding is valuable for constructing indices that enable rapid querying of images based on + visual features within large-scale datasets. + """ + + model = _get_pixai_model(model_name) + _, _, target_size, _ = model.get_inputs()[0].shape + input_ = _prepare_image_for_tagging(image, model_name=model_name) + preds, logits, embeddings = model.run(['output', 'logits', 'embedding'], {'input': input_}) + + return _postprocess_embedding( + pred=preds[0], + embedding=embeddings[0], + logit=logits[0], + model_name=model_name, + general_threshold=general_threshold, + character_threshold=character_threshold, + no_underline=no_underline, + drop_overlap=drop_overlap, + fmt=fmt, + ) From d9563dee886c0b2064b0bb9cdc664045a1cb3f5b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 20:53:16 +0800 Subject: [PATCH 08/18] dev(narugo): add new export code --- zoo/ptagger/model.py | 49 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 04134323f79..7f75f0e8b93 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -58,10 +58,44 @@ def forward(self, x: torch.Tensor): def load_model(model_name: str = "tagger_v_2_2_7"): - model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') + hf_client = get_hf_client() + try: + model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') + created_at = hf_client.get_paths_info( + repo_id=model.model_version_map[model_name]['repo_id'], + repo_type='model', + paths=[model.model_version_map[model_name]['ckpt_name']], + expand=True + )[0].last_commit.date.timestamp() + + except KeyError: + model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') + state_dicts = torch.load(hf_client.hf_hub_download( + repo_id=model.model_version_map[model_name]['repo_id'], + repo_type='model', + filename=model.model_version_map[model_name]['ckpt_name'], + ), map_location="cpu") + state_dicts_head = torch.load(hf_client.hf_hub_download( + repo_id=model.model_version_map[model_name], + repo_type='model', + filename=f'{model_name}.pth', + ), map_location="cpu") + state_dicts['head.weight'] = state_dicts_head['head.0.weight'] + state_dicts['head.bias'] = state_dicts_head['head.0.bias'] + model.model.load_state_dict(state_dicts) + model.model = model.model.to(model.device) + model.model.eval() + + created_at = hf_client.get_paths_info( + repo_id=model.model_version_map[model_name]['repo_id'], + repo_type='model', + paths=[f'{model_name}.pth'], + expand=True + )[0].last_commit.date.timestamp() + infer_model = model.model transforms = model.transform - return model, infer_model, transforms + return model, infer_model, transforms, created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): @@ -70,7 +104,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) - raw_model, model, transforms = load_model(model_name) + raw_model, model, transforms, created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) @@ -135,12 +169,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'input_size': dummy_input.shape[2], 'repo_id': src_repo_id, 'model_filename': src_model_filename, - 'created_at': hf_client.get_paths_info( - repo_id=src_repo_id, - repo_type='model', - paths=[src_model_filename], - expand=True - )[0].last_commit.date.timestamp(), + 'created_at': created_at, }, f, indent=4, sort_keys=True) logging.info(f'Writing transforms:\n{transforms}') @@ -258,7 +287,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): else: d_models = {} - for model_name in ["tagger_v_2_2_7"]: + for model_name in ["tagger_v_2_3_2", "tagger_v_2_2_7"]: with TemporaryDirectory() as upload_dir: logging.info(f'Exporting model {model_name!r} ...') os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True) From ad5cbb0504f86dd3fb0c945a958736d1daf4b009 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 20:59:30 +0800 Subject: [PATCH 09/18] dev(narugo): add new export code --- zoo/ptagger/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 7f75f0e8b93..c591f4a654f 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -60,6 +60,7 @@ def forward(self, x: torch.Tensor): def load_model(model_name: str = "tagger_v_2_2_7"): hf_client = get_hf_client() try: + logging.info(f'Try loading model {model_name!r} ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], @@ -69,6 +70,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except KeyError: + logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[model_name]['repo_id'], @@ -85,6 +87,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): model.model.load_state_dict(state_dicts) model.model = model.model.to(model.device) model.model.eval() + logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( repo_id=model.model_version_map[model_name]['repo_id'], @@ -294,6 +297,7 @@ def sync(repository: str = 'onopix/pixai-tagger-onnx'): try: extract( export_dir=os.path.join(upload_dir, model_name), + model_name=model_name, no_optimize=False, ) except Exception: From 0a5707eec7bc047daa678d5205534885a6f05605 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 21:00:59 +0800 Subject: [PATCH 10/18] dev(narugo): add new export code, ci skip --- zoo/ptagger/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index c591f4a654f..89fca2f3966 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -69,7 +69,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): expand=True )[0].last_commit.date.timestamp() - except KeyError: + except (KeyError, ValueError): logging.info('Cannot directly load it, load from head weights ...') model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( From 4db5f4515fb451c287b9eb951d3636bbd9812b88 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 21:02:46 +0800 Subject: [PATCH 11/18] dev(narugo): add new export code, ci skip --- zoo/ptagger/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 89fca2f3966..cc63ced0f9d 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -70,15 +70,16 @@ def load_model(model_name: str = "tagger_v_2_2_7"): )[0].last_commit.date.timestamp() except (KeyError, ValueError): + alt_model_name = "tagger_v_2_2_7" logging.info('Cannot directly load it, load from head weights ...') - model: PixAITaggerInference = get_model("pixai_tagger", model_version='tagger_v_2_2_7', device='cpu') + model: PixAITaggerInference = get_model("pixai_tagger", model_version=alt_model_name, device='cpu') state_dicts = torch.load(hf_client.hf_hub_download( - repo_id=model.model_version_map[model_name]['repo_id'], + repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', - filename=model.model_version_map[model_name]['ckpt_name'], + filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( - repo_id=model.model_version_map[model_name], + repo_id=model.model_version_map[alt_model_name], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") @@ -90,7 +91,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): logging.info('Head weights loaded.') created_at = hf_client.get_paths_info( - repo_id=model.model_version_map[model_name]['repo_id'], + repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', paths=[f'{model_name}.pth'], expand=True From 58865938d3639984f85c4501a1e607391666ca67 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 21:03:42 +0800 Subject: [PATCH 12/18] dev(narugo): add new export code, ci skip --- zoo/ptagger/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index cc63ced0f9d..c425cbd53c8 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -79,7 +79,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") state_dicts_head = torch.load(hf_client.hf_hub_download( - repo_id=model.model_version_map[alt_model_name], + repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', filename=f'{model_name}.pth', ), map_location="cpu") From 45f9473e908e944767ee8ac1b409267330c66b77 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 24 Mar 2025 21:06:26 +0800 Subject: [PATCH 13/18] dev(narugo): add new export code, ci skip --- zoo/ptagger/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index c425cbd53c8..f12655206b1 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -159,8 +159,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') - src_repo_id = raw_model.model_version_map[model_name]['repo_id'] - src_model_filename = raw_model.model_version_map[model_name]['ckpt_name'] + src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id'] + src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name'] with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ @@ -190,7 +190,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}') d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')} - num_classes = raw_model.model_version_map[model_name]['num_classes'] + num_classes = raw_model.model_version_map[raw_model.model_version]['num_classes'] logging.info(f'Num classes: {num_classes!r}') d_tags = {v: k for k, v in raw_model.tag_map.items()} r_tags = [] From 0531ed03bfa60ee6157ce7277c556e6aed21140b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 25 Mar 2025 14:10:08 +0800 Subject: [PATCH 14/18] dev(narugo): update them, ci skip --- zoo/ptagger/model.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index f12655206b1..29122e41f7c 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -68,6 +68,8 @@ def load_model(model_name: str = "tagger_v_2_2_7"): paths=[model.model_version_map[model_name]['ckpt_name']], expand=True )[0].last_commit.date.timestamp() + model_repo_id = model.model_version_map[model_name]['repo_id'] + model_file = model.model_version_map[model_name]['ckpt_name'] except (KeyError, ValueError): alt_model_name = "tagger_v_2_2_7" @@ -78,10 +80,12 @@ def load_model(model_name: str = "tagger_v_2_2_7"): repo_type='model', filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") + model_repo_id = model.model_version_map[model_name]['repo_id'] + model_file = f'{model_name}.pth' state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[alt_model_name]['repo_id'], repo_type='model', - filename=f'{model_name}.pth', + filename=model_file, ), map_location="cpu") state_dicts['head.weight'] = state_dicts_head['head.0.weight'] state_dicts['head.bias'] = state_dicts_head['head.0.bias'] @@ -99,7 +103,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): infer_model = model.model transforms = model.transform - return model, infer_model, transforms, created_at + return model, infer_model, transforms, (model_repo_id, model_file), created_at def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False): @@ -108,7 +112,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo os.makedirs(export_dir, exist_ok=True) - raw_model, model, transforms, created_at = load_model(model_name) + raw_model, model, transforms, (model_repo_id, model_filename), created_at = load_model(model_name) raw_model: PixAITaggerInference image = Image.open(get_testfile('genshin_post.jpg')) dummy_input = transforms(image).unsqueeze(0) @@ -159,9 +163,6 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') - src_repo_id = raw_model.model_version_map[raw_model.model_version]['repo_id'] - src_model_filename = raw_model.model_version_map[raw_model.model_version]['ckpt_name'] - with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], @@ -171,8 +172,8 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo 'name': model_name, 'model_cls': type(model).__name__, 'input_size': dummy_input.shape[2], - 'repo_id': src_repo_id, - 'model_filename': src_model_filename, + 'repo_id': model_repo_id, + 'model_filename': model_filename, 'created_at': created_at, }, f, indent=4, sort_keys=True) From 696179931fa8d32ac8fdebed4ca439c5205786d6 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 25 Mar 2025 14:14:18 +0800 Subject: [PATCH 15/18] dev(narugo): update them, ci skip --- zoo/ptagger/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 29122e41f7c..e9c82bf0c82 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -80,7 +80,7 @@ def load_model(model_name: str = "tagger_v_2_2_7"): repo_type='model', filename=model.model_version_map[alt_model_name]['ckpt_name'], ), map_location="cpu") - model_repo_id = model.model_version_map[model_name]['repo_id'] + model_repo_id = model.model_version_map[alt_model_name]['repo_id'] model_file = f'{model_name}.pth' state_dicts_head = torch.load(hf_client.hf_hub_download( repo_id=model.model_version_map[alt_model_name]['repo_id'], From ed681c6375a46133e3d83c39e8d1122fa984a7e5 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 25 Mar 2025 14:58:17 +0800 Subject: [PATCH 16/18] dev(narugo): update them, ci skip --- zoo/ptagger/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index e9c82bf0c82..0b25a5d0f8c 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -19,6 +19,7 @@ from procslib import get_model from procslib.models.pixai_tagger import PixAITaggerInference from thop import profile, clever_format +from timm.models._hub import save_for_hf from torch import nn from imgutils.preprocess import parse_torchvision_transforms @@ -163,6 +164,13 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo s_macs, s_params = clever_format([macs, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_macs}') + logging.info('Exporting model weights ...') + save_for_hf( + model, + expected_logits, + safe_serialization='both', + ) + with open(os.path.join(export_dir, 'meta.json'), 'w') as f: json.dump({ 'num_classes': conv_preds.shape[-1], From 6b0c30a0f9f73b7488f6a15332db79b8bbf9b923 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 25 Mar 2025 15:00:15 +0800 Subject: [PATCH 17/18] dev(narugo): update them, ci skip --- zoo/ptagger/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zoo/ptagger/model.py b/zoo/ptagger/model.py index 0b25a5d0f8c..e8bdb22bc49 100644 --- a/zoo/ptagger/model.py +++ b/zoo/ptagger/model.py @@ -167,7 +167,7 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo logging.info('Exporting model weights ...') save_for_hf( model, - expected_logits, + export_dir, safe_serialization='both', ) From c50557befe4671ab18bb30efd607a193c46ec50d Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 26 Mar 2025 14:02:15 +0800 Subject: [PATCH 18/18] dev(narugo): default model is tagger_v_2_3_2, ci skip --- imgutils/tagging/pixai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imgutils/tagging/pixai.py b/imgutils/tagging/pixai.py index d963cf8a9f6..6ab1e1e87d6 100644 --- a/imgutils/tagging/pixai.py +++ b/imgutils/tagging/pixai.py @@ -15,7 +15,7 @@ MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" -_DEFAULT_MODEL_NAME = 'tagger_v_2_2_7' +_DEFAULT_MODEL_NAME = 'tagger_v_2_3_2' @ts_lru_cache()