Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ __pycache__/
perda.egg-info/
*.pyc
*.pyo

# Local ML model artifacts (downloaded on demand)
perda/models/stsb-cross-encoder/
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,15 @@ Not easy to modify and iterate on the source code, but useful for quick setup an
`pip install git+https://github.com/Penn-Electric-Racing/PER-Data-Analyzer.git@main`


## Search Model Setup

The local cross-encoder model is installed automatically with `pip install` and packaged under:

`perda/models/stsb-cross-encoder/`

Natural-language search still falls back to loading from Hugging Face if local model files are unavailable.


## Code Demo

See [Tutorial.ipynb](notebooks/Tutorial.ipynb)
3,879 changes: 3,856 additions & 23 deletions notebooks/Tutorial.ipynb

Large diffs are not rendered by default.

224 changes: 193 additions & 31 deletions perda/utils/search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
import re
from pathlib import Path

from sentence_transformers.cross_encoder import CrossEncoder

from ..analyzer.single_run_data import SingleRunData


PACKAGED_MODEL_DIR = Path(__file__).resolve().parents[1] / "models" / "stsb-cross-encoder"
REPO_MODEL_DIR = Path(__file__).resolve().parents[2] / "models" / "stsb-cross-encoder"


def _resolve_local_model_dir() -> Path:
if PACKAGED_MODEL_DIR.exists():
return PACKAGED_MODEL_DIR

if REPO_MODEL_DIR.exists():
return REPO_MODEL_DIR

return PACKAGED_MODEL_DIR


LOCAL_MODEL_DIR = _resolve_local_model_dir()


# Abbreviation -> expansion map for vehicle subsystem names. Tokens pulled
# from C++ variable names and from user queries are lowercased before lookup,
# so keys are lowercase; values are the full lowercase phrases fed to the
# semantic search model (e.g. "ams" -> "accumulator management system").
ABBREVIATIONS = {
    "pcm": "powertrain control module",
    "pdu": "power distribution unit",
    "ams": "accumulator management system",
    "bms": "battery management system",
    "dash": "dashboard",
    "moc": "motor controller",
    "nav": "navigation",
    "bat": "battery",
    "bspd": "brake system plausibility device",
    "rtds": "ready to drive sound",
    "imd": "insulation monitoring device",
    "flt": "fault",
}


def search(
data: SingleRunData,
search: str,
Expand All @@ -15,62 +53,186 @@ def search(
search : str
Query
"""
search = search.strip()
search = normalize_search_query(search)
if not search:
raise ValueError("Search query cannot be empty.")
search_list = search.lower().split(" ")

query_hits = []
if LOCAL_MODEL_DIR.exists():
model = CrossEncoder(str(LOCAL_MODEL_DIR))
else:
model = CrossEncoder("cross-encoder/stsb-distilroberta-base")

semantic_query = _expand_query(search)
query_terms = _extract_terms(search)

corpus = []
corpus_meta = []

for var_id in data.id_to_cpp_name.keys():
descript = data.id_to_descript[var_id]
cpp_name = data.id_to_cpp_name[var_id]

score = _determine_query_hit(search_list, cpp_name, descript)
if score:
query_hits.append((score, var_id))
if search == normalize_search_query(cpp_name):

print("Exact match found!")
print_single_result(data, cpp_name, descript, score=1.0)
return

corpus.append(create_card(cpp_name, descript))
corpus_meta.append((cpp_name, descript))

ranks = model.rank(semantic_query, corpus)

semantic_scores = {
int(rank["corpus_id"]): float(rank["score"])
for rank in ranks
}

keyword_weight = _keyword_weight(query_terms)
combined_ranks = []
for idx, card in enumerate(corpus):
cpp_name, descript = corpus_meta[idx]
semantic_score = semantic_scores.get(idx, 0.0)
keyword_score = _keyword_score(query_terms, cpp_name, descript, card)
combined_score = (
keyword_weight * keyword_score + (1.0 - keyword_weight) * semantic_score
)
combined_ranks.append((combined_score, idx))

# Sort by score descending
query_hits.sort(reverse=True, key=lambda x: x[0])
combined_ranks.sort(key=lambda x: x[0], reverse=True)

print("==== Search Results ====")
for score, var_id in query_hits:
descript = data.id_to_descript[var_id]
cpp_name = data.id_to_cpp_name[var_id]
print("Query: ", search)
for score, corpus_id in combined_ranks[:10]:
print("----------------------------")
cpp_name, descript = corpus_meta[corpus_id]

print_single_result(data, cpp_name, descript, score)

def normalize_search_query(query: str) -> str:
    """
    Normalize the search query by converting to lowercase and removing
    extra whitespace.

    Parameters
    ----------
    query : str
        The search query to normalize

    Returns
    -------
    str
        The normalized search query
    """
    # split() with no argument collapses any run of whitespace (spaces,
    # tabs, newlines); joining with a single space canonicalizes the text.
    basic_normalized = ' '.join(query.lower().strip().split())

    return basic_normalized

def print_single_result(
    data: "SingleRunData",
    cpp_name: str,
    descript: str,
    score: float,
) -> None:
    """
    Print one search result: score, variable id, C++ name, and description.

    Parameters
    ----------
    data : SingleRunData
        Run data providing the ``cpp_name_to_id`` reverse mapping
    cpp_name : str
        The C++ variable name of the hit
    descript : str
        The human-readable description of the variable
    score : float
        Match score in [0, 1]; printed with two decimals
    """
    # Reverse lookup: C++ name -> numeric variable id.
    var_id = data.cpp_name_to_id[cpp_name]

    print(f"Score: {score:.2f}")
    print(f"Variable ID: {var_id}")
    print(f"C++ Name: {cpp_name}")
    print(f"Description: {descript}")

def create_card(cpp_name: str, descript: str) -> str:
    """
    Create a card string for the search corpus.

    Parameters
    ----------
    cpp_name : str
        The C++ variable name
    descript : str
        The description of the variable

    Returns
    -------
    str
        A combined string of cpp_name and descript for the search corpus
    """
    tokens = []
    # Split on '.'/'_' first, then split each segment on camelCase
    # boundaries so compound names yield searchable words.
    for segment in re.split(r"[._]", cpp_name):
        for token in advanced_split(segment):
            lowered = token.lower()
            if lowered in ABBREVIATIONS:
                # Keep the original abbreviation in parentheses so exact
                # keyword matches on the short form still work.
                lowered = ABBREVIATIONS[lowered] + " (" + token + ")"
            tokens.append(lowered)

    normalized_descript = normalize_search_query(descript)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    expanded_context = " ".join(dict.fromkeys(tokens))
    return f"{expanded_context} | {normalized_descript}"


def _extract_terms(text: str) -> list[str]:
    """Return the lowercase alphanumeric tokens of the normalized text."""
    normalized = normalize_search_query(text)
    return re.findall(r"[a-z0-9]+", normalized)


def _expand_query(query: str) -> str:
    """Expand known abbreviations in the query into full phrases.

    Each query term is kept, and any term found in ABBREVIATIONS also
    contributes the terms of its expansion. Duplicates are dropped while
    preserving first-seen order.
    """
    seen: dict[str, None] = {}
    for term in _extract_terms(query):
        seen.setdefault(term, None)
        if term in ABBREVIATIONS:
            for expansion_term in _extract_terms(ABBREVIATIONS[term]):
                seen.setdefault(expansion_term, None)
    return " ".join(seen)


def _keyword_weight(query_terms: list[str]) -> float:
if not query_terms:
return 0.5

if len(query_terms) == 1 and len(query_terms[0]) <= 3:
return 0.8
if len(query_terms) == 1:
return 0.6
if len(query_terms) == 2:
return 0.45

return 0.35


def _keyword_score(
    query_terms: list[str],
    cpp_name: str,
    descript: str,
    card: str,
) -> float:
    """
    Score how well the query terms match a variable's name/description/card.

    Parameters
    ----------
    query_terms : list[str]
        Lowercase alphanumeric query tokens
    cpp_name : str
        The C++ variable name
    descript : str
        The variable description
    card : str
        The corpus card built by ``create_card``

    Returns
    -------
    float
        Keyword match score in [0, 1]
    """
    if not query_terms:
        return 0.0

    searchable_text = normalize_search_query(f"{cpp_name} {descript} {card}")
    searchable_tokens = _extract_terms(searchable_text)

    # Tiered scoring: whole-word hit > prefix hit > substring hit.
    raw_score = 0.0
    matched_terms = 0
    for term in query_terms:
        if re.search(rf"\b{re.escape(term)}\b", searchable_text):
            raw_score += 1.0
            matched_terms += 1
        elif any(token.startswith(term) for token in searchable_tokens):
            raw_score += 0.7
            matched_terms += 1
        elif term in searchable_text:
            raw_score += 0.4
            matched_terms += 1

    # Bonus when every query term matched at some tier.
    if matched_terms == len(query_terms):
        raw_score += 0.5

    max_possible = len(query_terms) + 0.5
    return min(raw_score / max_possible, 1.0)
def advanced_split(camel_case_string: str) -> list[str]:
    """Split a camelCase string into its component words.

    A separator is injected at every lowercase-to-uppercase boundary, then
    the string is split on it.
    """
    # Zero-width lookaround finds each aB boundary without consuming chars.
    with_separators = re.sub(
        r"(?<=[a-z])(?=[A-Z])", ".", camel_case_string
    )
    return with_separators.split(".")
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dynamic = ["dependencies"]
where = ["."]
include = ["perda*"]

[tool.setuptools.package-data]
perda = ["models/stsb-cross-encoder/*"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ numpy
plotly
pydantic
scipy
sentence-transformers

# Progress Bar
tqdm
Expand Down
23 changes: 23 additions & 0 deletions scripts/download_search_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

from sentence_transformers.cross_encoder import CrossEncoder


MODEL_ID = "cross-encoder/stsb-distilroberta-base"
TARGET_DIR = Path(__file__).resolve().parents[1] / "perda" / "models" / "stsb-cross-encoder"


def main() -> None:
    """Download the cross-encoder model into the package tree if missing."""
    expected_files = ("config.json", "model.safetensors", "tokenizer.json")
    already_downloaded = TARGET_DIR.exists() and all(
        (TARGET_DIR / filename).exists() for filename in expected_files
    )
    if already_downloaded:
        print(f"Model already present at: {TARGET_DIR}")
        return

    TARGET_DIR.parent.mkdir(parents=True, exist_ok=True)
    CrossEncoder(MODEL_ID).save(str(TARGET_DIR))
    print(f"Model downloaded and saved to: {TARGET_DIR}")


if __name__ == "__main__":
main()