Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# CI workflow: run the pytest suite on every push and pull request.
name: Test

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # --no-deps: the packages the tests need are installed explicitly below,
          # avoiding a full (heavy) dependency resolution of the project.
          pip install --no-deps .
          pip install pytest onnxruntime transformers huggingface_hub
      - name: Run tests
        run: |
          pytest -vv

4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ docker run -d --name tei --gpus all -p 8080:80 -v $volume:/data --pull always gh
```

```bash
python index_to.py pgvector
python index_to.py pgvector --pgstring <PGSTRING>
# or for local onnx inference
python index_to.py pgvector --pgstring <PGSTRING> --local
```

5. Use embedding for recommendation
Expand Down
64 changes: 48 additions & 16 deletions curiosity/embedding.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from typing import List
from typing import List, Tuple

from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np

from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

def average_pool(last_hidden_states,
attention_mask):
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings, counting only non-padding positions.

    Args:
        last_hidden_states: (batch, seq_len, hidden) token embeddings.
        attention_mask: (batch, seq_len) mask — 1 for real tokens, 0 for padding.

    Returns:
        (batch, hidden) tensor of per-sequence averaged embeddings.
    """
    # Zero out padding positions before summing so they do not contribute.
    pad_positions = ~attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(pad_positions, 0.0).sum(dim=1)
    # Divide by the number of real tokens in each sequence.
    token_counts = attention_mask.sum(dim=1)[..., None]
    return summed / token_counts

def encode_hf(input_texts: List[str], model_id: str = 'intfloat/multilingual-e5-large',
prefix: str = 'intfloat/multilingual-e5-large'):
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
def encode_hf(input_texts: List[str], model_id: str = 'intfloat/multilingual-e5-large',
prefix: str = 'intfloat/multilingual-e5-large'):
import torch.nn.functional as F
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
input_texts = [prefix + input_text for input_text in input_texts]
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=512,
Expand All @@ -24,5 +26,35 @@ def encode_hf(input_texts: List[str], model_id: str = 'intfloat/multilingual-e5-
embeddings = average_pool(outputs.last_hidden_state,
batch_dict['attention_mask'])
# normalize embeddings
embeddings = F.normalize(embeddings)
return embeddings
embeddings = F.normalize(embeddings)
return embeddings


def load_onnx(model_id: str = 'texonom/multilingual-e5-small-4096') -> Tuple[AutoTokenizer, ort.InferenceSession]:
    """Load the tokenizer and a CPU ONNX session for local inference.

    Fetches 'onnx/model_quantized.onnx' from the Hugging Face hub
    (hf_hub_download caches it locally) and opens an onnxruntime session
    restricted to the CPU execution provider.

    Args:
        model_id: Hugging Face repo id hosting the quantized ONNX export.

    Returns:
        A (tokenizer, session) pair ready to pass to encode_onnx().
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    weights_path = hf_hub_download(model_id, 'onnx/model_quantized.onnx')
    cpu_session = ort.InferenceSession(
        weights_path, providers=['CPUExecutionProvider'])
    return loaded_tokenizer, cpu_session


def encode_onnx(input_texts: List[str], tokenizer: AutoTokenizer,
                session: ort.InferenceSession, prefix: str = '') -> List[List[float]]:
    """Embed texts with a local ONNX model (mean-pooled, L2-normalized).

    Tokenizes the (optionally prefixed) texts, runs the ONNX session, then
    mean-pools over non-padding tokens and L2-normalizes each vector — the
    numpy mirror of average_pool() used on the HF/torch path.

    Args:
        input_texts: texts to embed.
        tokenizer: tokenizer matching the ONNX model.
        session: onnxruntime inference session (see load_onnx).
        prefix: string prepended to every text (e.g. a task prefix).

    Returns:
        One L2-normalized embedding, as a list of floats, per input text.
    """
    prefixed = [prefix + text for text in input_texts]
    encoded = tokenizer(prefixed, max_length=512,
                        padding=True, truncation=True, return_tensors='np')
    # onnxruntime expects int64 tensors; cast every tokenizer output.
    feeds = {name: arr.astype('int64') for name, arr in encoded.items()}
    # Some tokenizers omit token_type_ids; the model input still requires it.
    if 'token_type_ids' not in feeds:
        feeds['token_type_ids'] = np.zeros_like(encoded['input_ids']).astype('int64')
    # First session output is assumed to be the last hidden state.
    hidden = session.run(None, feeds)[0]
    mask = encoded['attention_mask']
    # Mean-pool over real (mask == 1) tokens only.
    kept = np.where(mask[..., None] == 1, hidden, 0.0)
    pooled = kept.sum(axis=1) / mask.sum(axis=1, keepdims=True)
    normalized = pooled / np.linalg.norm(pooled, axis=1, keepdims=True)
    return normalized.tolist()
25 changes: 24 additions & 1 deletion index_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,26 @@
import chromadb
from datasets import load_dataset, Dataset
from tei import TEIClient
from curiosity.embedding import load_onnx, encode_onnx
from huggingface_hub import HfApi
import vecs
import faiss as vdb

from curiosity.data import load_documents


<<<<<<< feature/load-onnx-multilingual-e5-small-model-for-local-inference -- Incoming Change
def pgvector(dataset_id="texonom/texonom-md", dimension=384,
prefix="", subset=None, stream=False, pgstring=None,
tei_host="localhost", tei_port='8080', tei_protocol="http",
batch_size=1000, start_index=None, end_index=None,
local=False, model_id="texonom/multilingual-e5-small-4096"):
=======
def pgvector(dataset_id="texonom/texonom-md", dimension=384,
prefix="", subset=None, stream=False, pgstring=None,
tei_host="localhost", tei_port='8080', tei_protocol="http",
batch_size=1000, start_index=None, end_index=None, limit=3000):
>>>>>>> main -- Current Change
# Load DB and dataset
assert pgstring is not None
vx = vecs.create_client(pgstring)
Expand All @@ -35,8 +44,22 @@ def pgvector(dataset_id="texonom/texonom-md", dimension=384,
dataset = dataset[int(start_index):]
dataset = Dataset.from_dict(dataset)

<<<<<<< feature/load-onnx-multilingual-e5-small-model-for-local-inference -- Incoming Change
# Batch processing function
if local:
tokenizer, session = load_onnx(model_id)

def embed(texts):
return encode_onnx(texts, tokenizer, session)
else:
teiclient = TEIClient(host=tei_host, port=tei_port, protocol=tei_protocol)

def embed(texts):
return teiclient.embed_batch_sync(texts)
=======
# Batch processing function
teiclient = TEIClient(host=tei_host, port=tei_port, protocol=tei_protocol, limit=3000)
>>>>>>> main -- Current Change

def batch_encode(batch_data: Dict) -> Dict:
start = time.time()
Expand All @@ -52,7 +75,7 @@ def batch_encode(batch_data: Dict) -> Dict:
row['text'] = row['text'].strip()[:limit]
input_texts = [
f"{prefix}{row['title']}\n{row['text']}\n{row['refs']}\nParent: {row['parent']}" for row in rows]
embeddings = teiclient.embed_batch_sync(input_texts)
embeddings = embed(input_texts)
metadatas = [{'title': row['title'] if row['title'] is not None else '',
'text': row['text'] if row['text'] is not None else '',
'created': row['created'] if row['created'] is not None else '',
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies = [
"teicli>=0.3.1",
"faiss-cpu>=1.7.4",
"vecs>=0.4.1",
"onnxruntime>=1.16.0",
]
description = "Add your description here"
name = "curiosity"
Expand Down
11 changes: 11 additions & 0 deletions tests/test_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from curiosity.embedding import load_onnx, encode_onnx


def test_encode_onnx_output_shape():
    """encode_onnx returns a non-empty list-of-floats list for one input."""
    tok, sess = load_onnx('texonom/multilingual-e5-small-4096')
    vectors = encode_onnx(["hello world"], tok, sess)
    assert isinstance(vectors, list)
    assert len(vectors) == 1
    first = vectors[0]
    assert isinstance(first, list)
    assert len(first) > 0