"""
Overview:
    Tagging utilities based on the PixAI taggers, exported to ONNX format and
    hosted on the Hugging Face Hub repository ``onopix/pixai-tagger-onnx``.
"""
import json
from typing import List, Tuple, Any

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

from imgutils.data import load_image, ImageTyping
from imgutils.preprocess import create_pillow_transforms
from imgutils.tagging.format import remove_underline
from imgutils.tagging.overlap import drop_overlap_tags
from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache

EXP_REPO = 'onopix/pixai-tagger-onnx'
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

_DEFAULT_MODEL_NAME = 'tagger_v_2_3_2'


@ts_lru_cache()
def _get_pixai_model(model_name):
    """
    Load an ONNX model from the Hugging Face Hub.

    :param model_name: The name of the model to load.
    :type model_name: str
    :return: The loaded ONNX model session.
    """
    # Use the shared filename constant so repo layout changes stay in one place.
    return open_onnx_model(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{model_name}/{MODEL_FILENAME}',
    ))


@ts_lru_cache()
def _get_pixai_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int]]:
    """
    Get labels for the pixai model.

    :param model_name: The name of the model.
    :type model_name: str
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :return: A tuple of (tag names, general tag indexes, character tag indexes).
    :rtype: Tuple[List[str], List[int], List[int]]
    """
    df = pd.read_csv(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{model_name}/{LABEL_FILENAME}',
    ))
    name_series = df["name"]
    if no_underline:
        name_series = name_series.map(remove_underline)
    tag_names = name_series.tolist()

    # Category semantics follow the danbooru convention: 0 = general, 4 = character.
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, general_indexes, character_indexes


@ts_lru_cache()
def _get_pixai_weights(model_name):
    """
    Load the classifier head weight matrix for a pixai model.

    :param model_name: The name of the model.
    :type model_name: str
    :return: The loaded ``.npz`` archive (contains ``weight`` and ``bias`` arrays).
    :rtype: numpy.lib.npyio.NpzFile
    """
    return np.load(hf_hub_download(
        repo_id=EXP_REPO,
        filename=f'{model_name}/matrix.npz',
    ))


@ts_lru_cache()
def _open_preprocess_transforms(model_name: str):
    """
    Load and build the pillow preprocessing pipeline declared in the model's
    ``preprocess.json`` file.

    :param model_name: The name of the model.
    :type model_name: str
    :return: A callable pillow transform pipeline.
    """
    with open(hf_hub_download(
            repo_id=EXP_REPO,
            filename=f'{model_name}/preprocess.json',
    )) as f:
        return create_pillow_transforms(json.load(f)['stages'])


def _prepare_image_for_tagging(image: ImageTyping, model_name: str):
    """
    Prepare an image for tagging by applying the model's preprocessing pipeline.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: Name of the model.
    :type model_name: str
    :return: The prepared image as a numpy array, with a leading batch axis.
    :rtype: numpy.ndarray
    """
    image = load_image(image, force_background='white', mode='RGB')
    image_array = _open_preprocess_transforms(model_name)(image)
    return np.expand_dims(image_array, axis=0)


def _postprocess_embedding(
        pred, embedding, logit,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.15,
        character_threshold: float = 0.7,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('general', 'character'),
):
    """
    Post-process the embedding and prediction results.

    :param pred: The prediction array.
    :type pred: numpy.ndarray
    :param embedding: The embedding array.
    :type embedding: numpy.ndarray
    :param logit: The logit array.
    :type logit: numpy.ndarray
    :param model_name: The name of the model used.
    :type model_name: str
    :param general_threshold: Threshold for general tags.
    :type general_threshold: float
    :param character_threshold: Threshold for character tags.
    :type character_threshold: float
    :param no_underline: Whether to remove underscores from tag names.
    :type no_underline: bool
    :param drop_overlap: Whether to drop overlapping tags.
    :type drop_overlap: bool
    :param fmt: The format of the output.
    :type fmt: Any
    :return: The post-processed results.
    """
    assert len(pred.shape) == len(embedding.shape) == 1, \
        f'Both pred and embeddings shapes should be 1-dim, ' \
        f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
    tag_names, general_indexes, character_indexes = _get_pixai_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))

    general_names = [labels[i] for i in general_indexes]
    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

    character_names = [labels[i] for i in character_indexes]
    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
        {
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embedding.astype(np.float32),
            'prediction': pred.astype(np.float32),
            'logit': logit.astype(np.float32),
        }
    )


def get_pixai_tags(
        image: ImageTyping,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.15,
        character_threshold: float = 0.7,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('general', 'character'),
):
    """
    Get tags for an image using pixai taggers.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the model to use.
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :type general_threshold: float
    :param character_threshold: The threshold for character tags.
    :type character_threshold: float
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('general', 'character')``.
        ``embedding`` is also supported for feature extraction.
    :type fmt: Any
    :return: Prediction result based on the provided fmt.

    .. note::
        The fmt argument can include the following keys:

        - ``general``: a dict containing general tags and their confidences
        - ``character``: a dict containing character tags and their confidences
        - ``tag``: a dict containing all tags (including general and character) and their confidences
        - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization
        - ``logit``: a 1-dim logit of image, before sigmoid.
        - ``prediction``: a 1-dim prediction result of image

        You can extract embedding of the given image with the following code

        >>> from imgutils.tagging import get_pixai_tags
        >>>
        >>> embedding = get_pixai_tags('pixai/1.jpg', fmt='embedding')
        >>> embedding.shape
        (1024, )

        This embedding is valuable for constructing indices that enable rapid querying of images based on
        visual features within large-scale datasets.
    """
    model = _get_pixai_model(model_name)
    input_ = _prepare_image_for_tagging(image, model_name=model_name)
    preds, logits, embeddings = model.run(['output', 'logits', 'embedding'], {'input': input_})

    return _postprocess_embedding(
        pred=preds[0],
        embedding=embeddings[0],
        logit=logits[0],
        model_name=model_name,
        general_threshold=general_threshold,
        character_threshold=character_threshold,
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
    )
"""
Export script: converts PixAI tagger checkpoints to ONNX (with embedding,
logits and sigmoid outputs) and syncs them to ``onopix/pixai-tagger-onnx``.
"""
import copy
import datetime
import json
import os.path

import numpy as np
import onnx
import onnxruntime
import pandas as pd
import torch
from PIL import Image
from ditk import logging
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hfutils.cache import delete_detached_cache
from hfutils.operate import get_hf_fs, get_hf_client, upload_directory_as_directory
from hfutils.repository import hf_hub_repo_file_url
from natsort import natsorted
from procslib import get_model
from procslib.models.pixai_tagger import PixAITaggerInference
from thop import profile, clever_format
from timm.models._hub import save_for_hf
from torch import nn

from imgutils.preprocess import parse_torchvision_transforms
from test.testings import get_testfile
from zoo.utils import onnx_optimize
from zoo.wd14.tags import _get_tag_by_name


class ModuleWrapper(nn.Module):
    """
    Wrap a classification model so a single forward pass returns
    ``(features, logits, preds)``, capturing the classifier's input
    (i.e. the backbone features) via a forward hook.
    """

    def __init__(self, base_module: nn.Module, classifier: nn.Module):
        super().__init__()
        self.base_module = base_module
        self.classifier = classifier

        self._output_features = None
        self._register_hook()

    def _register_hook(self):
        # Capture the tensor fed INTO the classifier; that is the feature embedding.
        def hook_fn(module, input_tensor, output_tensor):
            assert isinstance(input_tensor, tuple) and len(input_tensor) == 1
            input_tensor = input_tensor[0]
            self._output_features = input_tensor

        self.classifier.register_forward_hook(hook_fn)

    def forward(self, x: torch.Tensor):
        logits = self.base_module(x)
        preds = torch.sigmoid(logits)

        if self._output_features is None:
            raise RuntimeError("Target module did not receive any input during forward pass")
        # Consume the captured features so stale values cannot leak into a later call.
        features, self._output_features = self._output_features, None
        # Features are expected to be (B, C) or (B, C, 1, 1, ...); flatten to (B, C).
        assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}'
        features = torch.flatten(features, start_dim=1)

        return features, logits, preds


def load_model(model_name: str = "tagger_v_2_2_7"):
    """
    Load a PixAI tagger; when the requested version is not directly registered,
    fall back to the base checkpoint plus a head-only weight file named
    ``{model_name}.pth`` in the same repository.

    :return: ``(raw_model, infer_model, transforms, (repo_id, filename), created_at)``.
    """
    hf_client = get_hf_client()
    try:
        logging.info(f'Try loading model {model_name!r} ...')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version=model_name, device='cpu')
        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[model_name]['repo_id'],
            repo_type='model',
            paths=[model.model_version_map[model_name]['ckpt_name']],
            expand=True
        )[0].last_commit.date.timestamp()
        model_repo_id = model.model_version_map[model_name]['repo_id']
        model_file = model.model_version_map[model_name]['ckpt_name']

    except (KeyError, ValueError):
        alt_model_name = "tagger_v_2_2_7"
        logging.info('Cannot directly load it, load from head weights ...')
        model: PixAITaggerInference = get_model("pixai_tagger", model_version=alt_model_name, device='cpu')
        state_dicts = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            filename=model.model_version_map[alt_model_name]['ckpt_name'],
        ), map_location="cpu")
        model_repo_id = model.model_version_map[alt_model_name]['repo_id']
        model_file = f'{model_name}.pth'
        state_dicts_head = torch.load(hf_client.hf_hub_download(
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            filename=model_file,
        ), map_location="cpu")
        # Graft the version-specific classifier head onto the base backbone weights.
        state_dicts['head.weight'] = state_dicts_head['head.0.weight']
        state_dicts['head.bias'] = state_dicts_head['head.0.bias']
        model.model.load_state_dict(state_dicts)
        model.model = model.model.to(model.device)
        model.model.eval()
        logging.info('Head weights loaded.')

        created_at = hf_client.get_paths_info(
            repo_id=model.model_version_map[alt_model_name]['repo_id'],
            repo_type='model',
            paths=[f'{model_name}.pth'],
            expand=True
        )[0].last_commit.date.timestamp()

    infer_model = model.model
    transforms = model.transform
    return model, infer_model, transforms, (model_repo_id, model_file), created_at


def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bool = False):
    """
    Export one tagger version into ``export_dir``: ONNX model, head matrix,
    meta/preprocess JSON files and the ``selected_tags.csv`` tag list.
    """
    hf_client = get_hf_client()

    os.makedirs(export_dir, exist_ok=True)

    raw_model, model, transforms, (model_repo_id, model_filename), created_at = load_model(model_name)
    raw_model: PixAITaggerInference
    image = Image.open(get_testfile('genshin_post.jpg'))
    dummy_input = transforms(image).unsqueeze(0)
    logging.info(f'Dummy input size: {dummy_input.shape!r}')

    with torch.no_grad():
        expected_dummy_output = model(dummy_input)
        logging.info(f'Dummy output size: {expected_dummy_output.shape!r}')

    classifier = model.get_classifier()
    classifier_position = None
    for name, module in model.named_modules():
        if module is classifier:
            classifier_position = name
            break
    if not classifier_position:
        raise RuntimeError(f'No classifier module found in model {type(model)}.')
    logging.info(f'Classifier module found at {classifier_position!r}:\n{classifier}')

    # Sanity-check: the wrapped model must reproduce the raw model's logits.
    wrapped_model = ModuleWrapper(model, classifier=classifier)
    with torch.no_grad():
        conv_features, conv_output, conv_preds = wrapped_model(dummy_input)
    logging.info(f'Shape of embeddings: {conv_features.shape!r}')
    logging.info(f'Sample of expected logits:\n'
                 f'{expected_dummy_output[:, -10:]}\n'
                 f'Sample of actual logits:\n'
                 f'{conv_output[:, -10:]}')
    close_matrix = torch.isclose(expected_dummy_output, conv_output, atol=1e-3)
    ratio = close_matrix.type(torch.float32).mean()
    logging.info(f'{ratio * 100:.2f}% of the logits value are the same.')
    assert close_matrix.all(), 'Not all values can match.'

    # Export the classifier head as a plain matrix so logits can be recomputed
    # from embeddings without the full model.
    matrix_data_file = os.path.join(export_dir, 'matrix.npz')
    bias = classifier.bias.detach().numpy()
    weight = classifier.weight.detach().numpy().T
    logging.info(f'Saving matrix data file to {matrix_data_file!r}, '
                 f'bias: {bias.dtype!r}{bias.shape!r}, weight: {weight.dtype!r}{weight.shape!r}.')
    np.savez(
        matrix_data_file,
        bias=bias,
        weight=weight,
    )
    expected_logits = conv_features.detach().numpy() @ weight + bias
    np.testing.assert_allclose(conv_output.detach().numpy(), expected_logits, rtol=1e-03, atol=1e-05)

    logging.info('Profiling model ...')
    macs, params = profile(model, inputs=(dummy_input,))
    s_macs, s_params = clever_format([macs, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_macs}')

    logging.info('Exporting model weights ...')
    save_for_hf(
        model,
        export_dir,
        safe_serialization='both',
    )

    with open(os.path.join(export_dir, 'meta.json'), 'w') as f:
        json.dump({
            'num_classes': conv_preds.shape[-1],
            'num_features': conv_features.shape[-1],
            'params': params,
            'flops': macs,
            'name': model_name,
            'model_cls': type(model).__name__,
            'input_size': dummy_input.shape[2],
            'repo_id': model_repo_id,
            'model_filename': model_filename,
            'created_at': created_at,
        }, f, indent=4, sort_keys=True)

    logging.info(f'Writing transforms:\n{transforms}')
    with open(os.path.join(export_dir, 'preprocess.json'), 'w') as f:
        json.dump({
            'stages': parse_torchvision_transforms(transforms),
        }, f, indent=4, sort_keys=True)

    # Enrich the tag list with danbooru tag ids and post counts.
    df_p_tags = pd.read_csv(hf_client.hf_hub_download(
        repo_id='deepghs/site_tags',
        repo_type='dataset',
        filename='danbooru.donmai.us/tags.csv'
    ))
    logging.info(f'Loaded danbooru tags pool, columns: {df_p_tags.columns!r}')
    d_p_tags = {(item['category'], item['name']): item for item in df_p_tags.to_dict('records')}

    num_classes = raw_model.model_version_map[raw_model.model_version]['num_classes']
    logging.info(f'Num classes: {num_classes!r}')
    d_tags = {v: k for k, v in raw_model.tag_map.items()}
    r_tags = []
    for i in range(num_classes):
        # Indexes below gen_tag_count are general tags (category 0), the rest characters (4).
        category = 0 if i < raw_model.gen_tag_count else 4
        if (category, d_tags[i]) in d_p_tags:
            tag_id = d_p_tags[(category, d_tags[i])]['id']
            count = d_p_tags[(category, d_tags[i])]['post_count']
        else:
            logging.warning(f'Cannot find tag {d_tags[i]!r}, category: {category!r}.')
            tag_info = _get_tag_by_name(d_tags[i])
            if tag_info['name'] != d_tags[i]:
                logging.warning(f'Not found matching tags for {d_tags[i]!r}, will be ignored.')
                tag_id = -1
                count = -1
            else:
                logging.info(f'Tag info found from danbooru - {tag_info!r}.')
                tag_id = tag_info['id']
                count = tag_info['post_count']
        r_tags.append({
            'id': i,
            'tag_id': tag_id,
            'name': d_tags[i],
            'category': category,
            'count': count,
        })
    df_tags = pd.DataFrame(r_tags)
    tags_file = os.path.join(export_dir, 'selected_tags.csv')
    logging.info(f'Tags List:\n{df_tags}\n'
                 f'Saving to {tags_file!r} ...')
    df_tags.to_csv(tags_file, index=False)

    onnx_filename = os.path.join(export_dir, 'model.onnx')
    with TemporaryDirectory() as td:
        temp_model_onnx = os.path.join(td, 'model.onnx')
        logging.info(f'Exporting temporary ONNX model to {temp_model_onnx!r} ...')
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            temp_model_onnx,
            input_names=['input'],
            output_names=['embedding', 'logits', 'output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'embedding': {0: 'batch_size'},
                'logits': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
            opset_version=14,
            do_constant_folding=True,
            export_params=True,
            verbose=False,
            custom_opsets=None,
        )

        model = onnx.load(temp_model_onnx)
        if not no_optimize:
            logging.info('Optimizing onnx model ...')
            model = onnx_optimize(model)

        output_model_dir, _ = os.path.split(onnx_filename)
        if output_model_dir:
            os.makedirs(output_model_dir, exist_ok=True)
        logging.info(f'Complete model saving to {onnx_filename!r} ...')
        onnx.save(model, onnx_filename)

    # Validate the exported ONNX model: its embedding must match the torch one.
    session = onnxruntime.InferenceSession(onnx_filename)
    o_logits, o_embeddings = session.run(['logits', 'embedding'], {'input': dummy_input.numpy()})
    emb_1 = o_embeddings / np.linalg.norm(o_embeddings, axis=-1, keepdims=True)
    emb_2 = conv_features.numpy() / np.linalg.norm(conv_features.numpy(), axis=-1, keepdims=True)
    emb_sims = (emb_1 * emb_2).sum()
    logging.info(f'Similarity of the embeddings is {emb_sims:.5f}.')
    assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.'


def sync(repository: str = 'onopix/pixai-tagger-onnx'):
    """
    Export all known tagger versions and upload them (with an index parquet
    and a generated README) to the given Hugging Face model repository.
    """
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    delete_detached_cache()
    if not hf_client.repo_exists(repo_id=repository, repo_type='model'):
        hf_client.create_repo(repo_id=repository, repo_type='model', private=True)
        attr_lines = hf_fs.read_text(f'{repository}/.gitattributes').splitlines(keepends=False)
        attr_lines.append('*.json filter=lfs diff=lfs merge=lfs -text')
        attr_lines.append('*.csv filter=lfs diff=lfs merge=lfs -text')
        hf_fs.write_text(f'{repository}/.gitattributes', os.linesep.join(attr_lines))

    if hf_client.file_exists(
            repo_id=repository,
            repo_type='model',
            filename='models.parquet',
    ):
        df_models = pd.read_parquet(hf_client.hf_hub_download(
            repo_id=repository,
            repo_type='model',
            filename='models.parquet',
        ))
        d_models = {item['name']: item for item in df_models.to_dict('records')}
    else:
        d_models = {}

    for model_name in ["tagger_v_2_3_2", "tagger_v_2_2_7"]:
        with TemporaryDirectory() as upload_dir:
            logging.info(f'Exporting model {model_name!r} ...')
            os.makedirs(os.path.join(upload_dir, model_name), exist_ok=True)
            try:
                extract(
                    export_dir=os.path.join(upload_dir, model_name),
                    model_name=model_name,
                    no_optimize=False,
                )
            except Exception:
                logging.exception(f'Error when exporting {model_name!r}, skipped.')
                continue

            with open(os.path.join(upload_dir, model_name, 'meta.json'), 'r') as f:
                meta_info = json.load(f)
            c_meta_info = copy.deepcopy(meta_info)
            d_models[meta_info['name']] = c_meta_info

            df_models = pd.DataFrame(list(d_models.values()))
            df_models = df_models.sort_values(by=['created_at'], ascending=False)
            df_models.to_parquet(os.path.join(upload_dir, 'models.parquet'), index=False)

            with open(os.path.join(upload_dir, 'README.md'), 'w') as f:
                print('---', file=f)
                print('pipeline_tag: image-classification', file=f)
                print('base_model:', file=f)
                for rid in natsorted(set(df_models['repo_id'][:100])):
                    print(f'- {rid}', file=f)
                print('language:', file=f)
                print('- en', file=f)
                print('tags:', file=f)
                print('- timm', file=f)
                print('- image', file=f)
                print('- dghs-imgutils', file=f)
                print('library_name: dghs-imgutils', file=f)
                print('---', file=f)
                print('', file=f)

                print('PixAI Tagger ONNX Exported Version.', file=f)
                print('', file=f)

                print(f'# Models', file=f)
                print(f'', file=f)

                df_shown = pd.DataFrame([
                    {
                        "Name": f'[{item["name"]}]({hf_hub_repo_file_url(repo_id=item["repo_id"], repo_type="model", path=item["model_filename"])})',
                        'Params': clever_format(item["params"], "%.1f"),
                        'Flops': clever_format(item["flops"], "%.1f"),
                        'Input Size': item['input_size'],
                        "Features": item['num_features'],
                        "Classes": item['num_classes'],
                        'Model': item['model_cls'],
                        'Created At': datetime.datetime.fromtimestamp(item['created_at']).strftime('%Y-%m-%d'),
                        'created_at': item['created_at'],
                    }
                    for item in df_models.to_dict('records')
                ])
                df_shown = df_shown.sort_values(by=['created_at'], ascending=[False])
                del df_shown['created_at']
                print(f'{plural_word(len(df_shown), "ONNX model")} exported in total.', file=f)
                print(f'', file=f)
                print(df_shown.to_markdown(index=False), file=f)
                print(f'', file=f)

            upload_directory_as_directory(
                repo_id=repository,
                repo_type='model',
                local_directory=upload_dir,
                path_in_repo='.',
                message=f'Export model {model_name!r}',
            )


if __name__ == '__main__':
    logging.try_init_root(level=logging.INFO)
    sync(
        repository='onopix/pixai-tagger-onnx'
    )