From 18e5eabc5a92681d099431881c6fbfd1929896d6 Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Mon, 1 Jan 2024 22:42:03 +0800
Subject: [PATCH 1/4] add PreProcessor for VLM

---
 src/uform/models.py | 154 ++++++++++++++++++++++----------------------
 1 file changed, 77 insertions(+), 77 deletions(-)

diff --git a/src/uform/models.py b/src/uform/models.py
index d821f35..20207ab 100644
--- a/src/uform/models.py
+++ b/src/uform/models.py
@@ -332,6 +332,75 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return features, embeddings
 
 
+class PreProcessor:
+    def __init__(self, tokenizer_path: PathLike, max_position_embeddings: int, pad_token_idx: int, image_size: int = 224) -> None:
+        self._image_size = image_size
+        self._image_transform = Compose(
+            [
+                Resize(self._image_size, interpolation=InterpolationMode.BICUBIC),
+                convert_to_rgb,
+                CenterCrop(self._image_size),
+                ToTensor(),
+                Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ]
+        )
+
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+        self._tokenizer.no_padding()
+        self.max_position_embeddings = max_position_embeddings
+        self._pad_token_idx = pad_token_idx
+
+    def preprocess_image(self, images: Union[Image, List[Image]]) -> Tensor:
+        """Transforms one or more Pillow images into Torch Tensors.
+
+        :param images: image or list of images to preprocess
+        """
+
+        if isinstance(images, list):
+            batch_images = torch.empty(
+                (len(images), 3, self._image_size, self._image_size),
+                dtype=torch.float32,
+            )
+
+            for i, image in enumerate(images):
+                batch_images[i] = self._image_transform(image)
+
+            return batch_images
+        else:
+            return self._image_transform(images).unsqueeze(0)
+
+    def preprocess_text(self, texts: Union[str, List[str]]) -> Dict[str, Tensor]:
+        """Transforms one or more strings into a dictionary with tokenized strings and attention masks.
+
+        :param texts: text or list of texts to tokenize
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        input_ids = torch.full(
+            (len(texts), self.max_position_embeddings),
+            fill_value=self._pad_token_idx,
+            dtype=torch.int64,
+        )
+
+        attention_mask = torch.zeros(
+            len(texts), self.max_position_embeddings, dtype=torch.int32
+        )
+        encoded = self._tokenizer.encode_batch(texts)
+
+        for i, seq in enumerate(encoded):
+            seq_len = min(len(seq), self.max_position_embeddings)
+            input_ids[i, :seq_len] = torch.LongTensor(
+                seq.ids[: self.max_position_embeddings]
+            )
+            attention_mask[i, :seq_len] = 1
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
 class VLM(nn.Module):
     """
     Vision-Language Model for multi-modal embeddings.
@@ -350,22 +419,11 @@ def __init__(self, config: Dict, tokenizer_path: PathLike): self.text_encoder = TextEncoder(**config["text_encoder"]) self.image_encoder = VisualEncoder(**config["image_encoder"]) - self._tokenizer = Tokenizer.from_file(tokenizer_path) - self._tokenizer.no_padding() - self._pad_token_idx = self.text_encoder.padding_idx + self.preprocess = PreProcessor(tokenizer_path, self.text_encoder.max_position_embeddings, self.text_encoder.padding_idx, self._image_size) + + self.preprocess_text = self.preprocess.preprocess_text + self.preprocess_image = self.preprocess.preprocess_image - self._image_transform = Compose( - [ - Resize(self._image_size, interpolation=InterpolationMode.BICUBIC), - convert_to_rgb, - CenterCrop(self._image_size), - ToTensor(), - Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ] - ) def encode_image( self, @@ -461,53 +519,6 @@ def get_matching_scores(self, embeddings: Tensor) -> Tensor: return self.text_encoder.forward_matching(embeddings) - def preprocess_text(self, texts: Union[str, List[str]]) -> Dict[str, Tensor]: - """Transforms one or more strings into dictionary with tokenized strings and attention masks. - - :param texts: text of list of texts to tokenizer - """ - if isinstance(texts, str): - texts = [texts] - - input_ids = torch.full( - (len(texts), self.text_encoder.max_position_embeddings), - fill_value=self._pad_token_idx, - dtype=torch.int64, - ) - - attention_mask = torch.zeros( - len(texts), self.text_encoder.max_position_embeddings, dtype=torch.int32 - ) - encoded = self._tokenizer.encode_batch(texts) - - for i, seq in enumerate(encoded): - seq_len = min(len(seq), self.text_encoder.max_position_embeddings) - input_ids[i, :seq_len] = torch.LongTensor( - seq.ids[: self.text_encoder.max_position_embeddings] - ) - attention_mask[i, :seq_len] = 1 - - return {"input_ids": input_ids, "attention_mask": attention_mask} - - def preprocess_image(self, images: Union[Image, List[Image]]) -> Tensor: - """Transforms one or more Pillow images into Torch Tensors. 
-
-        :param images: image or list of images to preprocess
-        """
-
-        if isinstance(images, list):
-            batch_images = torch.empty(
-                (len(images), 3, self._image_size, self._image_size),
-                dtype=torch.float32,
-            )
-
-            for i, image in enumerate(images):
-                batch_images[i] = self._image_transform(image)
-
-            return batch_images
-        else:
-            return self._image_transform(images).unsqueeze(0)
-
     def forward(
         self,
         images: torch.Tensor,
@@ -558,22 +569,11 @@ def __init__(self, tokenizer_path, pad_token_idx, url: str = "localhost:7001"):
         self._client = httpclient
         self._triton_client = self._client.InferenceServerClient(url=url)
 
-        self._tokenizer = Tokenizer.from_file(tokenizer_path)
-        self._tokenizer.no_padding()
-        self._pad_token_idx = pad_token_idx
+        self.preprocess = PreProcessor(tokenizer_path, self.text_encoder.max_position_embeddings, self.text_encoder.padding_idx, self._image_size)
+
+        self.preprocess_text = self.preprocess.preprocess_text
+        self.preprocess_image = self.preprocess.preprocess_image
 
-        self._image_transform = Compose(
-            [
-                Resize(self._image_size, interpolation=InterpolationMode.BICUBIC),
-                convert_to_rgb,
-                CenterCrop(self._image_size),
-                ToTensor(),
-                Normalize(
-                    mean=(0.48145466, 0.4578275, 0.40821073),
-                    std=(0.26862954, 0.26130258, 0.27577711),
-                ),
-            ]
-        )
 
     def encode_image(
         self,

From 1c7005a618c8f6e576d4ade9b9a5ae9f84b8d03f Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Mon, 1 Jan 2024 22:46:56 +0800
Subject: [PATCH 2/4] add torch, onnx, and coreml example

---
 scripts/example.py | 125 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 scripts/example.py

diff --git a/scripts/example.py b/scripts/example.py
new file mode 100644
index 0000000..7df2f20
--- /dev/null
+++ b/scripts/example.py
@@ -0,0 +1,125 @@
+import os
+from typing import Dict, Tuple, List
+
+import coremltools as ct
+import onnxruntime
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from uform.models import PreProcessor
+from functools import partial
+import time
+
+import uform
+
+# export TOKENIZERS_PARALLELISM=true
+
+def preprocess_data(func):
+    def wrapper(self, *args, **kwargs):
+        result = func(self, *args, **kwargs)
+        if self.method == "onnx":
+            result = {k: v.cpu().numpy() for k, v in result[0].items()}, {
+                self.input_image_name: result[1].cpu().numpy()
+            }
+        elif self.method == "coreml":
+            input_ids = result[0]["input_ids"].type(torch.int32).cpu().numpy()
+            attention_mask = result[0]["attention_mask"].type(torch.int32).cpu().numpy()
+
+            result = {"input_ids": input_ids, "attention_mask": attention_mask}, {
+                self.input_image_name: result[1].cpu().numpy()
+            }
+        return result
+
+    return wrapper
+
+
+class MyModel:
+    def __init__(self, method: str, model_fpath: str) -> None:
+        self.method = method
+        self.model_fpath = model_fpath
+        max_position_embeddings = 50
+        if method == "torch":
+            self.model = uform.get_model(model_fpath)
+            self.image_model = partial(self.model.encode_image, return_features=True)
+            self.text_model = partial(self.model.encode_text, return_features=True)
+        elif method == "onnx":
+            fname = "multilingual.{}-encoder.onnx"
+            image_ort_session = onnxruntime.InferenceSession(
+                os.path.join(model_fpath, fname.format("image")), providers=["CPUExecutionProvider"]
+            )
+            text_ort_session = onnxruntime.InferenceSession(
+                os.path.join(model_fpath, fname.format("text")), providers=["CPUExecutionProvider"]
+            )
+            input_ids = text_ort_session.get_inputs()[0]
+            max_position_embeddings = input_ids.shape[-1]
+
+            def predict_func(ort_session, 
data): + out = ort_session.run(None, data) + return torch.tensor(out[0]), torch.tensor(out[1]) + + self.image_model = partial(predict_func, image_ort_session) + self.text_model = partial(predict_func, text_ort_session) + + input_image = image_ort_session.get_inputs()[0] + self.input_image_name = input_image.name + elif method == "coreml": + fname = "multilingual-v2.{}-encoder.mlpackage" + image_mlmodel = ct.models.MLModel(os.path.join(model_fpath, fname.format("image"))) + text_mlmodel = ct.models.MLModel(os.path.join(model_fpath, fname.format("text"))) + + def predict_func(model, data): + out = model.predict(data) + return torch.tensor(out["features"]), torch.tensor(out["embeddings"]) + + self.image_model = partial(predict_func, image_mlmodel) + self.text_model = partial(predict_func, text_mlmodel) + + input_image = image_mlmodel.input_description._fd_spec[0] + input_text_lst = text_mlmodel.input_description._fd_spec + self.input_image_name = input_image.name + + self.preprocess = PreProcessor( + os.path.join(self.model_fpath, "tokenizer.json"), max_position_embeddings, 1, 224 + ) + + @preprocess_data + def preprocess_text_image(self, text, image) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + image_data = self.preprocess.preprocess_image(image) + text_data = self.preprocess.preprocess_text(text) + return text_data, image_data + + def forward(self, text_data, image_data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image_features, image_embedding = self.image_model(image_data) + text_features, text_embedding = self.text_model(text_data) + return image_features, image_embedding, text_features, text_embedding + + def __call__(self, text: str, image) -> float: + text_data, image_data = self.preprocess_text_image(text, image) + image_features, image_embedding = self.image_model(image_data) + text_features, text_embedding = self.text_model(text_data) + + similarity = F.cosine_similarity(image_embedding, text_embedding) + if self.method == "torch": + joint_embedding = self.model.encode_multimodal( + image_features=image_features, + text_features=text_features, + attention_mask=text_data["attention_mask"], + ) + score = self.model.get_matching_scores(joint_embedding) + print("torch score", score) + return similarity + +if __name__ == "__main__": + text = 'a small red panda in a zoo' + image = Image.open('red_panda.jpg') + model_fpath = ... 
+
+    for method in ["torch", "onnx", "coreml"]:
+        model = MyModel(method, model_fpath)
+        text_data, image_data = model.preprocess_text_image(text, image)
+        model.forward(text_data, image_data) # just for warm-up
+        loop_cnt = 10
+        s1 = time.time()
+        for _ in range(loop_cnt):
+            model.forward(text_data, image_data)
+        print(method, time.time() - s1)
\ No newline at end of file

From 53de171a2d654974463b835145de0b94f6de2277 Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Sat, 6 Jan 2024 02:28:04 +0800
Subject: [PATCH 3/4] remove PreProcessor from models.py

---
 src/uform/models.py | 154 ++++++++++++++++++++++----------------------
 1 file changed, 77 insertions(+), 77 deletions(-)

diff --git a/src/uform/models.py b/src/uform/models.py
index 20207ab..d821f35 100644
--- a/src/uform/models.py
+++ b/src/uform/models.py
@@ -332,75 +332,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return features, embeddings
 
 
-class PreProcessor:
-    def __init__(self, tokenizer_path: PathLike, max_position_embeddings: int, pad_token_idx: int, image_size: int = 224) -> None:
-        self._image_size = image_size
-        self._image_transform = Compose(
-            [
-                Resize(self._image_size, interpolation=InterpolationMode.BICUBIC),
-                convert_to_rgb,
-                CenterCrop(self._image_size),
-                ToTensor(),
-                Normalize(
-                    mean=(0.48145466, 0.4578275, 0.40821073),
-                    std=(0.26862954, 0.26130258, 0.27577711),
-                ),
-            ]
-        )
-
-        self._tokenizer = Tokenizer.from_file(tokenizer_path)
-        self._tokenizer.no_padding()
-        self.max_position_embeddings = max_position_embeddings
-        self._pad_token_idx = pad_token_idx
-
-    def preprocess_image(self, images: Union[Image, List[Image]]) -> Tensor:
-        """Transforms one or more Pillow images into Torch Tensors.
-
-        :param images: image or list of images to preprocess
-        """
-
-        if isinstance(images, list):
-            batch_images = torch.empty(
-                (len(images), 3, self._image_size, self._image_size),
-                dtype=torch.float32,
-            )
-
-            for i, image in enumerate(images):
-                batch_images[i] = self._image_transform(image)
-
-            return batch_images
-        else:
-            return self._image_transform(images).unsqueeze(0)
-
-    def preprocess_text(self, texts: Union[str, List[str]]) -> Dict[str, Tensor]:
-        """Transforms one or more strings into a dictionary with tokenized strings and attention masks.
-
-        :param texts: text or list of texts to tokenize
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-
-        input_ids = torch.full(
-            (len(texts), self.max_position_embeddings),
-            fill_value=self._pad_token_idx,
-            dtype=torch.int64,
-        )
-
-        attention_mask = torch.zeros(
-            len(texts), self.max_position_embeddings, dtype=torch.int32
-        )
-        encoded = self._tokenizer.encode_batch(texts)
-
-        for i, seq in enumerate(encoded):
-            seq_len = min(len(seq), self.max_position_embeddings)
-            input_ids[i, :seq_len] = torch.LongTensor(
-                seq.ids[: self.max_position_embeddings]
-            )
-            attention_mask[i, :seq_len] = 1
-
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
-
-
 class VLM(nn.Module):
     """
     Vision-Language Model for multi-modal embeddings.
@@ -419,11 +350,22 @@ def __init__(self, config: Dict, tokenizer_path: PathLike):
         self.text_encoder = TextEncoder(**config["text_encoder"])
         self.image_encoder = VisualEncoder(**config["image_encoder"])
 
-        self.preprocess = PreProcessor(tokenizer_path, self.text_encoder.max_position_embeddings, self.text_encoder.padding_idx, self._image_size)
-
-        self.preprocess_text = self.preprocess.preprocess_text
-        self.preprocess_image = self.preprocess.preprocess_image
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+        self._tokenizer.no_padding()
+        self._pad_token_idx = self.text_encoder.padding_idx
 
+        self._image_transform = Compose(
+            [
+                Resize(self._image_size, interpolation=InterpolationMode.BICUBIC),
+                convert_to_rgb,
+                CenterCrop(self._image_size),
+                ToTensor(),
+                Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ]
+        )
 
     def encode_image(
         self,
@@ -519,6 +461,53 @@ def get_matching_scores(self, embeddings: Tensor) -> Tensor:
 
         return self.text_encoder.forward_matching(embeddings)
 
+    def preprocess_text(self, texts: Union[str, List[str]]) -> Dict[str, Tensor]:
+        """Transforms one or more strings into a dictionary with tokenized strings and attention masks.
+
+        :param texts: text or list of texts to tokenize
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        input_ids = torch.full(
+            (len(texts), self.text_encoder.max_position_embeddings),
+            fill_value=self._pad_token_idx,
+            dtype=torch.int64,
+        )
+
+        attention_mask = torch.zeros(
+            len(texts), self.text_encoder.max_position_embeddings, dtype=torch.int32
+        )
+        encoded = self._tokenizer.encode_batch(texts)
+
+        for i, seq in enumerate(encoded):
+            seq_len = min(len(seq), self.text_encoder.max_position_embeddings)
+            input_ids[i, :seq_len] = torch.LongTensor(
+                seq.ids[: self.text_encoder.max_position_embeddings]
+            )
+            attention_mask[i, :seq_len] = 1
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    def preprocess_image(self, images: Union[Image, List[Image]]) -> Tensor:
+        """Transforms one or more Pillow images into Torch Tensors.
+
+        :param images: image or list of images to preprocess
+        """
+
+        if isinstance(images, list):
+            batch_images = torch.empty(
+                (len(images), 3, self._image_size, self._image_size),
+                dtype=torch.float32,
+            )
+
+            for i, image in enumerate(images):
+                batch_images[i] = self._image_transform(image)
+
+            return batch_images
+        else:
+            return self._image_transform(images).unsqueeze(0)
+
     def forward(
         self,
         images: torch.Tensor,
@@ -569,11 +558,22 @@ def __init__(self, tokenizer_path, pad_token_idx, url: str = "localhost:7001"):
         self._client = httpclient
         self._triton_client = self._client.InferenceServerClient(url=url)
 
-        self.preprocess = PreProcessor(tokenizer_path, self.text_encoder.max_position_embeddings, self.text_encoder.padding_idx, self._image_size)
-
-        self.preprocess_text = self.preprocess.preprocess_text
-        self.preprocess_image = self.preprocess.preprocess_image
+        self._tokenizer = Tokenizer.from_file(tokenizer_path)
+        self._tokenizer.no_padding()
+        self._pad_token_idx = pad_token_idx
 
+        self._image_transform = Compose(
+            [
+                Resize(self._image_size, interpolation=InterpolationMode.BICUBIC),
+                convert_to_rgb,
+                CenterCrop(self._image_size),
+                ToTensor(),
+                Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ]
+        )
 
     def encode_image(
         self,

From decd1a942d467f60f1efada6d410866f80e83a47 Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Mon, 15 Jan 2024 20:27:08 +0800
Subject: [PATCH 4/4] add get_local_model function for testing the torch model

---
 scripts/example.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/scripts/example.py b/scripts/example.py
index 7df2f20..dc8b391 100644
--- a/scripts/example.py
+++ b/scripts/example.py
@@ -1,17 +1,16 @@
 import os
-from typing import Dict, Tuple, List
+from typing import Dict, Tuple, List, Optional
 
 import coremltools as ct
 import onnxruntime
 import torch
 import torch.nn.functional as F
+import json
 from PIL import Image
-from uform.models import PreProcessor
+from uform.models import PreProcessor, VLM
 from functools import partial
 import time
 
-import uform
-
 # export TOKENIZERS_PARALLELISM=true
 
 def preprocess_data(func):
@@ -32,6 +31,20 @@ def wrapper(self, *args, **kwargs):
 
     return wrapper
 
+def get_local_model(model_name: str, token: Optional[str] = None) -> VLM:
+    config_path = f"{model_name}/torch_config.json"
+    state = torch.load(f"{model_name}/torch_weight.pt")
+
+    tokenizer_path = f"{model_name}/tokenizer.json"
+
+    with open(config_path, "r") as f:
+        model = VLM(json.load(f), tokenizer_path)
+
+    model.image_encoder.load_state_dict(state["image_encoder"])
+    model.text_encoder.load_state_dict(state["text_encoder"])
+
+    return model.eval()
+
 
 class MyModel:
     def __init__(self, method: str, model_fpath: str) -> None:
@@ -39,7 +52,7 @@ def __init__(self, method: str, model_fpath: str) -> None:
         self.model_fpath = model_fpath
         max_position_embeddings = 50
         if method == "torch":
-            self.model = uform.get_model(model_fpath)
+            self.model = get_local_model(model_fpath)
             self.image_model = partial(self.model.encode_image, return_features=True)
             self.text_model = partial(self.model.encode_text, return_features=True)
         elif method == "onnx":
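
For reference, a minimal sketch of how the pieces fit together once all four patches are applied, assuming a locally exported checkpoint directory containing torch_config.json, torch_weight.pt, and tokenizer.json (the layout get_local_model in scripts/example.py expects). The checkpoint directory and image path below are placeholders, not files shipped with the repository.

    import json

    import torch
    import torch.nn.functional as F
    from PIL import Image

    from uform.models import VLM

    model_dir = "path/to/checkpoint"  # placeholder: torch_config.json, torch_weight.pt, tokenizer.json
    image_path = "red_panda.jpg"      # placeholder sample image

    # Build the model the same way get_local_model in scripts/example.py does.
    with open(f"{model_dir}/torch_config.json", "r") as f:
        model = VLM(json.load(f), f"{model_dir}/tokenizer.json")
    state = torch.load(f"{model_dir}/torch_weight.pt")
    model.image_encoder.load_state_dict(state["image_encoder"])
    model.text_encoder.load_state_dict(state["text_encoder"])
    model = model.eval()

    # After PATCH 3, preprocessing lives on the VLM itself again.
    text_data = model.preprocess_text("a small red panda in a zoo")
    image_data = model.preprocess_image(Image.open(image_path))

    # Both encoders return (features, embeddings) when return_features=True,
    # as used by scripts/example.py.
    with torch.no_grad():
        _, image_embedding = model.encode_image(image_data, return_features=True)
        _, text_embedding = model.encode_text(text_data, return_features=True)

    print(F.cosine_similarity(image_embedding, text_embedding).item())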