diff --git a/README.md b/README.md index 86edc7c..8bdda3a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,12 @@ To install the dependencies, run the following command: uv sync ``` +To bring a code to a single format: + +```bash +uvx ruff format +``` + ### PyTorch version PyTorch is default set to CPU distributive: diff --git a/convert_trace_annos.py b/convert_trace_annos.py index ebe3f44..2ba729d 100644 --- a/convert_trace_annos.py +++ b/convert_trace_annos.py @@ -13,50 +13,52 @@ def is_empty(value): return True if value == EMPTY else False -def trace_test_cases_to_annos(db_path: Path, trace_file_path: Path): - db = get_db_client() - - insertions = list() - logger.info("Reading trace file and inserting annotations into table...") - with open(trace_file_path, mode="r", newline="", encoding="utf-8") as trace_file: - reader = csv.reader(trace_file) - current_tc = EMPTY - concat_summary = EMPTY - test_script = EMPTY - global_columns = next(reader) - for row in reader: - if row[0] == "TestCaseStart": - current_tc = row[1] - test_script = EMPTY - concat_summary = EMPTY - next(reader) - elif row[0] == "Summary": - continue - elif row[0] == "TestCaseEnd": - if not is_empty(current_tc) and not is_empty(concat_summary): - case_id = db.test_cases.get_or_insert( - test_script=test_script, test_case=current_tc - ) - annotation_id = db.annotations.get_or_insert(summary=concat_summary) - insertions.append( - db.cases_to_annos.insert( - case_id=case_id, annotation_id=annotation_id +def trace_test_cases_to_annos(trace_file_path: Path): + with get_db_client() as db: + insertions = list() + logger.info("Reading trace file and inserting annotations into table...") + with open( + trace_file_path, mode="r", newline="", encoding="utf-8" + ) as trace_file: + reader = csv.reader(trace_file) + current_tc = EMPTY + concat_summary = EMPTY + test_script = EMPTY + global_columns = next(reader) + for row in reader: + if row[0] == "TestCaseStart": + current_tc = row[1] + test_script = EMPTY + 
concat_summary = EMPTY + next(reader) + elif row[0] == "Summary": + continue + elif row[0] == "TestCaseEnd": + if not is_empty(current_tc) and not is_empty(concat_summary): + case_id = db.test_cases.get_or_insert( + test_script=test_script, test_case=current_tc + ) + annotation_id = db.annotations.get_or_insert( + summary=concat_summary + ) + insertions.append( + db.cases_to_annos.insert( + case_id=case_id, annotation_id=annotation_id + ) ) - ) - else: - if not is_empty(row[global_columns.index("TestCase")]): - if current_tc != row[global_columns.index("TestCase")]: - current_tc = row[global_columns.index("TestCase")] - if is_empty(test_script) and not is_empty( - row[global_columns.index("TestScript")] - ): - test_script = row[global_columns.index("TestScript")] - concat_summary += row[0] + else: + if not is_empty(row[global_columns.index("TestCase")]): + if current_tc != row[global_columns.index("TestCase")]: + current_tc = row[global_columns.index("TestCase")] + if is_empty(test_script) and not is_empty( + row[global_columns.index("TestScript")] + ): + test_script = row[global_columns.index("TestScript")] + concat_summary += row[0] - db.conn.commit() - logger.info( - f"Inserted {len(insertions)} testcase-annotations pairs to database. Successful: {sum(insertions)}" - ) + logger.info( + f"Inserted {len(insertions)} testcase-annotations pairs to database. 
Successful: {sum(insertions)}" + ) if __name__ == "__main__": diff --git a/main.py b/main.py index 5f9a4da..4637b45 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ import streamlit as st +from test2text.pages.documentation import show_documentation from test2text.pages.upload.annotations import show_annotations from test2text.pages.upload.requirements import show_requirements -from test2text.pages.controls.controls_page import controls_page -from test2text.pages.report import make_a_report +from test2text.pages.reports.report_by_req import make_a_report +from test2text.pages.reports.report_by_tc import make_a_tc_report from test2text.services.visualisation.visualize_vectors import visualize_vectors +from test2text.pages.controls.controls_page import controls_page def add_logo(): @@ -37,6 +39,10 @@ def add_logo(): ) add_logo() + about = st.Page( + show_documentation, title="About application", icon=":material/info:" + ) + annotations = st.Page( show_annotations, title="Annotations", icon=":material/database_upload:" ) @@ -44,14 +50,20 @@ def add_logo(): show_requirements, title="Requirements", icon=":material/database_upload:" ) cache_distances = st.Page(controls_page, title="Controls", icon=":material/cached:") - report = st.Page(make_a_report, title="Report", icon=":material/publish:") + report_by_req = st.Page( + make_a_report, title="Requirement's Report", icon=":material/publish:" + ) + report_by_tc = st.Page( + make_a_tc_report, title="Test cases Report", icon=":material/publish:" + ) visualization = st.Page( visualize_vectors, title="Visualize Vectors", icon=":material/dataset:" ) pages = { + "Home": [about], "Upload": [annotations, requirements], "Update": [cache_distances], - "Inspect": [report, visualization], + "Inspect": [report_by_req, report_by_tc, visualization], } pg = st.navigation(pages) diff --git a/pyproject.toml b/pyproject.toml index 73c607c..4d82387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,8 @@ name = "test2text" 
version = "0.1.0" description = "" authors = [ - {name = "Nikolai Dorofeev - d0rich",email = "dorich2000@gmail.com"} + {name = "Nikolai Dorofeev - d0rich", email = "dorich2000@gmail.com"}, + {name = "Anna Yamkovaya - anngoroshi", email = "avyamkovaya@gmail.com"} ] readme = "README.md" requires-python = ">=3.9" diff --git a/test2text/pages/controls/controls_page.py b/test2text/pages/controls/controls_page.py index cbfc218..166c333 100644 --- a/test2text/pages/controls/controls_page.py +++ b/test2text/pages/controls/controls_page.py @@ -1,22 +1,23 @@ +from test2text.services.db import get_db_client + + def controls_page(): import streamlit as st import plotly.express as px - from test2text.services.embeddings.annotation_embeddings_controls import ( - count_all_annotations, - count_embedded_annotations, - ) - st.header("Controls page") embedding_col, distances_col = st.columns(2) with embedding_col: st.subheader("Embedding") def refresh_counts(): - st.session_state["all_annotations_count"] = count_all_annotations() - st.session_state["embedded_annotations_count"] = ( - count_embedded_annotations() - ) + with get_db_client() as db: + st.session_state["all_annotations_count"] = db.count_all_entries( + "Annotations" + ) + st.session_state["embedded_annotations_count"] = ( + db.count_notnull_entries("embedding", from_table="Annotations") + ) refresh_counts() diff --git a/test2text/pages/documentation.py b/test2text/pages/documentation.py new file mode 100644 index 0000000..87be4e8 --- /dev/null +++ b/test2text/pages/documentation.py @@ -0,0 +1,98 @@ +import streamlit as st + +from test2text.services.db import get_db_client + + +def show_documentation(): + st.markdown(""" + # Test2Text Application Documentation + + ## About the Application + + **Test2Text** is a tool for computing requirement's coverage by tests and generating relevant reports. + The application provides a convenient interface for analysis the relationships between test cases and requirements. 
+
+
+    """)
+    st.divider()
+    st.markdown("""
+    ## HOW TO USE
+
+    ### Upload data
+    Click :gray-badge[:material/database_upload: Annotations] or :gray-badge[:material/database_upload: Requirements] to upload annotations and requirements from CSV files to the app's database.
+    Once Annotations and Requirements are loaded and Test cases are linked to Annotations, go to the next chapter.
+
+    ### Renew data
+    Click :gray-badge[:material/cached: Controls] to transform missing and new texts into numeric vectors (embeddings).
+    Update distances by embeddings for intelligent matching of Requirements and Annotations.
+    After distances are refreshed (all Annotations are linked to Requirements by distance), go to the next chapter.
+
+    ### Generate reports
+    Click :gray-badge[:material/publish: Requirement's Report] or :gray-badge[:material/publish: Test cases Report] to make a report.
+    Use filters and Smart search based on embeddings to select desired information.
+    Analyze selected requirements or test cases by plotted distances.
+    A list of all requirements/test cases and their annotations is shown here.
+
+    ### Visualize saved data
+    Click :gray-badge[:material/dataset: Visualize vectors] to plot distances between vector representations of all requirements and annotations in multidimensional spaces.
+
+
+    """)
+    st.divider()
+    with get_db_client() as db:
+        st.markdown("""## Database overview""")
+        table, row_count = st.columns(2)
+        with table:
+            st.write("Table name")
+        with row_count:
+            st.write("Number of entries")
+        for table_name, count in db.get_db_full_info.items():
+            with table:
+                st.write(table_name)
+            with row_count:
+                st.write(count)
+    st.divider()
+    st.markdown("""
+    ### Methodology
+    The application uses a pre-trained transformer model from the [sentence-transformers library](https://huggingface.co/sentence-transformers), specifically [nomic-ai/nomic-embed-text-v1](https://huggingface.co/nomic-ai/nomic-embed-text-v1), a model trained to produce high-quality vector embeddings for text.
+    The model returns, for each input text, a high-dimensional NumPy array (vector) of floating point numbers (the embedding).
+    These arrays make it possible to calculate Euclidean distances between test case annotations and requirements, showing how similar or dissimilar the two texts are.
+    """)
+
+    st.markdown("""
+    #### Euclidean (L2) Distance Formula
+    The Euclidean (L2) distance is a measure of the straight-line distance between two points (or vectors) in a multidimensional space.
+    It is widely used to compute the similarity or dissimilarity between two vector representations, such as text embeddings.
+    """)
+    st.markdown("""
+    Suppose we have two vectors:
+    """)
+    st.latex(r"""
+    \mathbf{a} = [a_1, a_2, ..., a_n] ,
+    """)
+    st.latex(r"""
+    \mathbf{b} = [b_1, b_2, ..., b_n]
+    """)
+
+    st.markdown("""
+    The L2 distance between **a** and **b** is calculated as:
+    """)
+
+    st.latex(r"""
+    L_2(\mathbf{a}, \mathbf{b}) = \sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2 + \cdots + (a_n - b_n)^2}
+    """)
+
+    st.markdown("""
+    Or, more compactly:
+    """)
+
+    st.latex(r"""
+    L_2(\mathbf{a}, \mathbf{b}) = \sqrt{\sum_{i=1}^n (a_i - b_i)^2}
+    """)
+
+    st.markdown("""
+    - A **smaller L2 distance** means the vectors are more similar.
+    - A **larger L2 distance** indicates greater dissimilarity.
+ """) + + st.markdown(""" + This formula is commonly used for comparing the semantic similarity of embeddings generated from text using models like Sentence Transformers. + """) diff --git a/test2text/pages/report.py b/test2text/pages/report.py deleted file mode 100644 index 3d5397c..0000000 --- a/test2text/pages/report.py +++ /dev/null @@ -1,119 +0,0 @@ -from itertools import groupby - -import streamlit as st -from test2text.services.db import get_db_client - - -def add_new_line(summary): - return summary.replace("\n", "
") - - -def make_a_report(): - st.header("Test2Text Report") - - db = get_db_client() - - st.subheader("Table of Contents") - - data = db.conn.execute(""" - SELECT - Requirements.id as req_id, - Requirements.external_id as req_external_id, - Requirements.summary as req_summary, - - Annotations.id as anno_id, - Annotations.summary as anno_summary, - - AnnotationsToRequirements.cached_distance as distance, - - TestCases.id as case_id, - TestCases.test_script as test_script, - TestCases.test_case as test_case - FROM - Requirements - JOIN AnnotationsToRequirements ON Requirements.id = AnnotationsToRequirements.requirement_id - JOIN Annotations ON Annotations.id = AnnotationsToRequirements.annotation_id - JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id - JOIN TestCases ON TestCases.id = CasesToAnnos.case_id - ORDER BY - Requirements.id, AnnotationsToRequirements.cached_distance, TestCases.id - """) - - current_annotations = {} - current_test_scripts = set() - - def write_requirement( - req_id, - req_external_id, - req_summary, - current_annotations: set[tuple], - current_test_scripts: set, - ): - if req_id is None and req_external_id is None: - return False - - with st.expander(f"#{req_id} Requirement {req_external_id}"): - st.subheader(f"Requirement {req_external_id}") - st.html(f"

{add_new_line(req_summary)}

") - st.subheader("Annotations") - anno, summary, dist = st.columns(3) - with anno: - st.write("Annonation's id") - with summary: - st.write("Summary") - with dist: - st.write("Distance") - for anno_id, anno_summary, distance in current_annotations: - anno, summary, dist = st.columns(3) - with anno: - st.write(f"{anno_id}") - with summary: - st.html(f"{add_new_line(anno_summary)}") - with dist: - st.write(round(distance, 2)) - - st.subheader("Test Scripts") - for test_script in current_test_scripts: - st.markdown(f"- {test_script}") - - progress_bar = st.empty() - rows = data.fetchall() - if not rows: - st.error("There is no data to inspect.\nPlease upload annotations.") - return None - max_progress = len(rows) - index = 0 - for (req_id, req_external_id, req_summary), group in groupby( - rows, lambda x: x[0:3] - ): - current_annotations = set() - current_test_scripts = set() - index += 1 - for ( - _, - _, - _, - anno_id, - anno_summary, - distance, - case_id, - test_script, - test_case, - ) in group: - current_annotations.add((anno_id, anno_summary, distance)) - current_test_scripts.add(test_script) - write_requirement( - req_id=req_id, - req_external_id=req_external_id, - req_summary=req_summary, - current_annotations=current_annotations, - current_test_scripts=current_test_scripts, - ) - - progress_bar.progress(round(index * 100 / max_progress), text="Processing...") - progress_bar.empty() - db.conn.close() - - -if __name__ == "__main__": - make_a_report() diff --git a/test2text/pages/reports/__init__.py b/test2text/pages/reports/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test2text/pages/reports/report_by_req.py b/test2text/pages/reports/report_by_req.py new file mode 100644 index 0000000..4880a12 --- /dev/null +++ b/test2text/pages/reports/report_by_req.py @@ -0,0 +1,263 @@ +from itertools import groupby +import numpy as np +import streamlit as st +from sqlite_vec import serialize_float32 + +from test2text.services.utils.math_utils 
import round_distance + +SUMMARY_LENGTH = 100 +LABELS_SUMMARY_LENGTH = 15 + + +def make_a_report(): + from test2text.services.db import get_db_client + + with get_db_client() as db: + from test2text.services.embeddings.embed import embed_requirement + from test2text.services.utils import unpack_float32 + from test2text.services.visualisation.visualize_vectors import ( + minifold_vectors_2d, + plot_2_sets_in_one_2d, + minifold_vectors_3d, + plot_2_sets_in_one_3d, + ) + + st.header("Test2Text Report") + + def write_annotations(current_annotations: set[tuple]): + st.write("id,", "Summary,", "Distance") + for anno_id, anno_summary, _, distance in current_annotations: + st.write(anno_id, anno_summary, round_distance(distance)) + + with st.container(border=True): + st.subheader("Filter requirements") + with st.expander("🔍 Filters"): + r_id, summary, embed = st.columns(3) + with r_id: + filter_id = st.text_input("ID", value="", key="filter_id") + st.info("Filter by external ID") + with summary: + filter_summary = st.text_input( + "Text content", value="", key="filter_summary" + ) + st.info("Search concrete phrases using SQL like expressions") + with embed: + filter_embedding = st.text_input( + "Smart rearch", value="", key="filter_embedding" + ) + st.info("Search using embeddings") + + where_clauses = [] + params = [] + + if filter_id.strip(): + where_clauses.append("Requirements.id = ?") + params.append(filter_id.strip()) + + if filter_summary.strip(): + where_clauses.append("Requirements.summary LIKE ?") + params.append(f"%{filter_summary.strip()}%") + + distance_sql = "" + distance_order_sql = "" + query_embedding_bytes = None + if filter_embedding.strip(): + query_embedding = embed_requirement(filter_embedding.strip()) + query_embedding_bytes = serialize_float32(query_embedding) + distance_sql = ", vec_distance_L2(embedding, ?) 
AS distance" + distance_order_sql = "distance ASC, " + + with st.container(border=True): + st.session_state.update({"req_form_submitting": True}) + data = db.get_ordered_values_from_requirements( + distance_sql, + where_clauses, + distance_order_sql, + params + [query_embedding_bytes] if distance_sql else params, + ) + + if distance_sql: + requirements_dict = { + f"{req_external_id} {summary[:SUMMARY_LENGTH]}... [smart search d={round_distance(distance)}]": req_id + for (req_id, req_external_id, summary, distance) in data + } + else: + requirements_dict = { + f"{req_external_id} {summary[:SUMMARY_LENGTH]}...": req_id + for (req_id, req_external_id, summary) in data + } + + st.subheader("Choose 1 of filtered requirements") + option = st.selectbox( + "Choose a requirement to work with", + requirements_dict.keys(), + key="filter_req_id", + ) + + if option: + clause = "Requirements.id = ?" + if clause in where_clauses: + idx = where_clauses.index(clause) + params.insert(idx, requirements_dict[option]) + else: + where_clauses.append(clause) + params.append(requirements_dict[option]) + + st.subheader("Filter Test cases") + + with st.expander("🔍 Filters"): + radius, limit = st.columns(2) + with radius: + filter_radius = st.number_input( + "Insert a radius", + value=1.00, + step=0.01, + key="filter_radius", + ) + st.info("Max distance to annotation") + with limit: + filter_limit = st.number_input( + "Test case limit to show", + min_value=1, + max_value=15, + value=15, + step=1, + key="filter_limit", + ) + st.info("Limit of selected test cases") + + if filter_radius: + where_clauses.append("distance <= ?") + params.append(f"{filter_radius}") + + if filter_limit: + params.append(f"{filter_limit}") + + rows = db.join_all_tables_by_requirements(where_clauses, params) + + if not rows: + st.error( + "There is no requested data to inspect.\n" + "Please check filters, completeness of the data or upload new annotations and requirements." 
+ ) + return None + + for ( + req_id, + req_external_id, + req_summary, + req_embedding, + ), group in groupby(rows, lambda x: x[0:4]): + st.divider() + with st.container(): + st.subheader(f" Inspect Requirement {req_external_id}") + st.write(req_summary) + current_test_cases = dict() + for ( + _, + _, + _, + _, + anno_id, + anno_summary, + anno_embedding, + distance, + case_id, + test_script, + test_case, + ) in group: + current_annotation = current_test_cases.get( + test_case, set() + ) + current_test_cases.update({test_case: current_annotation}) + current_test_cases[test_case].add( + (anno_id, anno_summary, anno_embedding, distance) + ) + + t_cs, anno, viz = st.columns(3) + with t_cs: + with st.container(border=True): + st.write("Test Cases") + st.info("Test cases of chosen Requirement") + st.radio( + "Test cases name", + current_test_cases.keys(), + key="radio_choice", + ) + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + if st.session_state["radio_choice"]: + with anno: + with st.container(border=True): + st.write("Annotations") + st.info( + "List of Annotations for chosen Test case" + ) + write_annotations( + current_annotations=current_test_cases[ + st.session_state["radio_choice"] + ] + ) + with viz: + with st.container(border=True): + st.write("Visualization") + select = st.selectbox( + "Choose type of visualization", ["2D", "3D"] + ) + anno_embeddings = [ + unpack_float32(anno_emb) + for _, _, anno_emb, _ in current_test_cases[ + st.session_state["radio_choice"] + ] + ] + anno_labels = [ + f"{anno_id}" + for anno_id, _, _, _ in current_test_cases[ + st.session_state["radio_choice"] + ] + ] + requirement_vectors = np.array( + [np.array(unpack_float32(req_embedding))] + ) + annotation_vectors = np.array(anno_embeddings) + if select == "2D": + plot_2_sets_in_one_2d( + minifold_vectors_2d( + requirement_vectors + ), + minifold_vectors_2d(annotation_vectors), + "Requirement", + "Annotations", + first_labels=[f"{req_external_id}"], + 
second_labels=anno_labels, + ) + else: + reqs_vectors_3d = minifold_vectors_3d( + requirement_vectors + ) + anno_vectors_3d = minifold_vectors_3d( + annotation_vectors + ) + plot_2_sets_in_one_3d( + reqs_vectors_3d, + anno_vectors_3d, + "Requirement", + "Annotations", + first_labels=[f"{req_external_id}"], + second_labels=anno_labels, + ) + + +if __name__ == "__main__": + make_a_report() diff --git a/test2text/pages/reports/report_by_tc.py b/test2text/pages/reports/report_by_tc.py new file mode 100644 index 0000000..94c5cb9 --- /dev/null +++ b/test2text/pages/reports/report_by_tc.py @@ -0,0 +1,269 @@ +from itertools import groupby +import numpy as np +import streamlit as st +from sqlite_vec import serialize_float32 + +from test2text.services.utils.math_utils import round_distance + + +SUMMARY_LENGTH = 100 + + +def make_a_tc_report(): + from test2text.services.db import get_db_client + + with get_db_client() as db: + from test2text.services.embeddings.embed import embed_requirement + from test2text.services.utils import unpack_float32 + from test2text.services.visualisation.visualize_vectors import ( + minifold_vectors_2d, + plot_2_sets_in_one_2d, + minifold_vectors_3d, + plot_2_sets_in_one_3d, + ) + + st.header("Test2Text Report") + + def write_requirements(current_requirements: set[tuple]): + st.write("External id,", "Summary,", "Distance") + for ( + _, + req_external_id, + req_summary, + _, + distance, + ) in current_requirements: + st.write(req_external_id, req_summary, round_distance(distance)) + + with st.container(border=True): + st.subheader("Filter test cases") + with st.expander("🔍 Filters"): + summary, embed = st.columns(2) + with summary: + filter_summary = st.text_input( + "Text content", value="", key="filter_summary" + ) + st.info("Search concrete phrases using SQL like expressions") + with embed: + filter_embedding = st.text_input( + "Smart rearch", value="", key="filter_embedding" + ) + st.info("Search using embeddings") + + where_clauses = [] + 
params = [] + + if filter_summary.strip(): + where_clauses.append("Testcases.test_case LIKE ?") + params.append(f"%{filter_summary.strip()}%") + + distance_sql = "" + distance_order_sql = "" + query_embedding_bytes = None + if filter_embedding.strip(): + query_embedding = embed_requirement(filter_embedding.strip()) + query_embedding_bytes = serialize_float32(query_embedding) + distance_sql = ", vec_distance_L2(embedding, ?) AS distance" + distance_order_sql = "distance ASC, " + + with st.container(border=True): + st.session_state.update({"tc_form_submitting": True}) + data = db.get_ordered_values_from_test_cases( + distance_sql, + where_clauses, + distance_order_sql, + params + [query_embedding_bytes] if distance_sql else params, + ) + if distance_sql: + tc_dict = { + f"{test_case} [smart search d={round_distance(distance)}]": tc_id + for (tc_id, _, test_case, distance) in data + } + else: + tc_dict = {test_case: tc_id for (tc_id, _, test_case) in data} + + st.subheader("Choose ONE of filtered test cases") + option = st.selectbox( + "Choose a requirement to work with", tc_dict.keys(), key="filter_tc_id" + ) + + if option: + where_clauses.append("Testcases.id = ?") + params.append(tc_dict[option]) + + st.subheader("Filter Requirements") + + with st.expander("🔍 Filters"): + radius, limit = st.columns(2) + with radius: + filter_radius = st.number_input( + "Insert a radius", + value=1.00, + step=0.01, + key="filter_radius", + ) + st.info("Max distance to annotation") + with limit: + filter_limit = st.number_input( + "Requirement's limit to show", + min_value=1, + max_value=15, + value=15, + step=1, + key="filter_limit", + ) + st.info("Limit of selected requirements") + + if filter_radius: + where_clauses.append("distance <= ?") + params.append(f"{filter_radius}") + + if filter_limit: + params.append(f"{filter_limit}") + + rows = db.join_all_tables_by_test_cases(where_clauses, params) + + if not rows: + st.error( + "There is no requested data to inspect.\n" + "Please 
check filters, completeness of the data or upload new annotations and requirements." + ) + return None + + for (tc_id, test_script, test_case), group in groupby( + rows, lambda x: x[0:3] + ): + st.divider() + with st.container(): + st.subheader(f"Inspect #{tc_id} Test case '{test_case}'") + st.write(f"From test script {test_script}") + current_annotations = dict() + for ( + _, + _, + _, + anno_id, + anno_summary, + anno_embedding, + distance, + req_id, + req_external_id, + req_summary, + req_embedding, + ) in group: + current_annotation = (anno_id, anno_summary, anno_embedding) + current_reqs = current_annotations.get( + current_annotation, set() + ) + current_annotations.update( + {current_annotation: current_reqs} + ) + current_annotations[current_annotation].add( + ( + req_id, + req_external_id, + req_summary, + req_embedding, + distance, + ) + ) + + t_cs, anno, viz = st.columns(3) + with t_cs: + with st.container(border=True): + st.write("Annotations") + st.info("Annotations linked to chosen Test case") + reqs_by_anno = { + f"#{anno_id} {anno_summary}": ( + anno_id, + anno_summary, + anno_embedding, + ) + for ( + anno_id, + anno_summary, + anno_embedding, + ) in current_annotations.keys() + } + radio_choice = st.radio( + "Annotation's id + summary", + reqs_by_anno.keys(), + key="radio_choice", + ) + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + if radio_choice: + with anno: + with st.container(border=True): + st.write("Requirements") + st.info( + "Found Requirements for chosen annotation" + ) + write_requirements( + current_annotations[ + reqs_by_anno[radio_choice] + ] + ) + with viz: + with st.container(border=True): + st.write("Visualization") + select = st.selectbox( + "Choose type of visualization", ["2D", "3D"] + ) + req_embeddings = [ + unpack_float32(req_emb) + for _, _, _, req_emb, _ in current_annotations[ + reqs_by_anno[radio_choice] + ] + ] + req_labels = [ + f"{ext_id}" + for _, ext_id, req_sum, _, _ in current_annotations[ + 
reqs_by_anno[radio_choice] + ] + ] + annotation_vectors = np.array( + [np.array(unpack_float32(anno_embedding))] + ) + requirement_vectors = np.array(req_embeddings) + if select == "2D": + plot_2_sets_in_one_2d( + minifold_vectors_2d(annotation_vectors), + minifold_vectors_2d( + requirement_vectors + ), + first_title="Annotation", + second_title="Requirements", + first_labels=radio_choice, + second_labels=req_labels, + ) + else: + reqs_vectors_3d = minifold_vectors_3d( + requirement_vectors + ) + anno_vectors_3d = minifold_vectors_3d( + annotation_vectors + ) + plot_2_sets_in_one_3d( + anno_vectors_3d, + reqs_vectors_3d, + first_title="Annotation", + second_title="Requirements", + first_labels=radio_choice, + second_labels=req_labels, + ) + + +if __name__ == "__main__": + make_a_tc_report() diff --git a/test2text/services/db/client.py b/test2text/services/db/client.py index 7de53fe..1c84951 100644 --- a/test2text/services/db/client.py +++ b/test2text/services/db/client.py @@ -1,4 +1,5 @@ import sqlite3 + import sqlite_vec import logging @@ -53,7 +54,7 @@ def _turn_on_foreign_keys(self): def _init_tables(self): self.requirements = RequirementsTable(self.conn, self.embedding_dim) self.annotations = AnnotationsTable(self.conn, self.embedding_dim) - self.test_cases = TestCasesTable(self.conn) + self.test_cases = TestCasesTable(self.conn, self.embedding_dim) self.annos_to_reqs = AnnotationsToRequirementsTable(self.conn) self.cases_to_annos = TestCasesToAnnotationsTable(self.conn) self.requirements.init_table() @@ -67,8 +68,234 @@ def close(self): self.conn.commit() self.conn.close() - def __del__(self): - self.close() - def __exit__(self, exc_type, exc_val, exc_tb): self.close() + + def __enter__(self): + return self + + def get_table_names(self): + """ + Returns a list of all user-defined tables in the database. 
+ + :return: List[str] - table names + """ + cursor = self.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';" + ) + tables = [row[0] for row in cursor.fetchall()] + cursor.close() + return tables + + def get_column_values(self, *columns: str, from_table: str): + cursor = self.conn.execute(f"SELECT {', '.join(columns)} FROM {from_table}") + return cursor.fetchall() + + @property + def get_db_full_info(self): + """ + Returns table information: + - row_count: number of records in the table + - columns: list of dicts as in get_extended_table_info (name, type, non-NULL count, typeof distribution) + + :return: dict + """ + db_tables_info = {} + table_names = self.get_table_names() + for table_name in table_names: + row_count = self.count_all_entries(table_name) + db_tables_info.update( + { + table_name: row_count, + } + ) + return db_tables_info + + def count_all_entries(self, from_table: str) -> int: + count = self.conn.execute(f"SELECT COUNT(*) FROM {from_table}").fetchone()[0] + return count + + def count_notnull_entries(self, *columns: str, from_table: str) -> int: + count = self.conn.execute( + f"SELECT COUNT(*) FROM {from_table} WHERE {' AND '.join([column + ' IS NOT NULL' for column in columns])}" + ).fetchone()[0] + return count + + def has_column(self, column_name: str, table_name: str) -> bool: + """ + Returns True if the table has a column, otherwise False. 
+ + :param column_name: name of the column + :param table_name: name of the table + :return: bool + """ + cursor = self.conn.execute(f'PRAGMA table_info("{table_name}")') + columns = [row[1] for row in cursor.fetchall()] # row[1] is the column name + cursor.close() + return column_name in columns + + def get_null_entries(self, from_table: str) -> list: + cursor = self.conn.execute( + f"SELECT id, summary FROM {from_table} WHERE embedding IS NULL" + ) + return cursor.fetchall() + + def get_distances(self) -> list[tuple[int, int, float]]: + """ + Returns a list of tuples containing the id of the annotation and the id of the requirement, + and the distance between their embeddings (anno_id, req_id, distance). + The distance is calculated using the L2 norm. The results are ordered by requirement ID and distance. + """ + cursor = self.conn.execute(""" + SELECT + Annotations.id AS anno_id, + Requirements.id AS req_id, + vec_distance_L2(Annotations.embedding, Requirements.embedding) AS distance + FROM Annotations, Requirements + WHERE Annotations.embedding IS NOT NULL AND Requirements.embedding IS NOT NULL + ORDER BY req_id, distance + """) + return cursor.fetchall() + + def get_embeddings_from_annotations_to_requirements_table(self): + """ + Returns a list of annotation's embeddings that are stored in the AnnotationsToRequirements table. + The embeddings are ordered by annotation ID. + """ + cursor = self.conn.execute(""" + SELECT embedding FROM Annotations + WHERE id IN ( + SELECT DISTINCT annotation_id FROM AnnotationsToRequirements + ) + """) + return cursor.fetchall() + + def join_all_tables_by_requirements( + self, where_clauses="", params=None + ) -> list[tuple]: + """ + Join all tables related to requirements based on the provided where clauses and parameters. 
+ return a list of tuples containing : + req_id, + req_external_id, + req_summary, + req_embedding, + anno_id, + anno_summary, + anno_embedding, + distance, + case_id, + test_script, + test_case + """ + where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + sql = f""" + SELECT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary, + Requirements.embedding as req_embedding, + + Annotations.id as anno_id, + Annotations.summary as anno_summary, + Annotations.embedding as anno_embedding, + + AnnotationsToRequirements.cached_distance as distance, + + TestCases.id as case_id, + TestCases.test_script as test_script, + TestCases.test_case as test_case + FROM + Requirements + JOIN AnnotationsToRequirements ON Requirements.id = AnnotationsToRequirements.requirement_id + JOIN Annotations ON Annotations.id = AnnotationsToRequirements.annotation_id + JOIN CasesToAnnos ON Annotations.id = CasesToAnnos.annotation_id + JOIN TestCases ON TestCases.id = CasesToAnnos.case_id + {where_sql} + ORDER BY + Requirements.id, AnnotationsToRequirements.cached_distance, TestCases.id + LIMIT ? 
+ """ + data = self.conn.execute(sql, params) + return data.fetchall() + + def get_ordered_values_from_requirements( + self, distance_sql="", where_clauses="", distance_order_sql="", params=None + ) -> list[tuple]: + where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + sql = f""" + SELECT + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary + {distance_sql} + FROM + Requirements + {where_sql} + ORDER BY + {distance_order_sql}Requirements.id + """ + data = self.conn.execute(sql, params) + return data.fetchall() + + def get_ordered_values_from_test_cases( + self, distance_sql="", where_clauses="", distance_order_sql="", params=None + ) -> list[tuple]: + where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + sql = f""" + SELECT + TestCases.id as case_id, + TestCases.test_script as test_script, + TestCases.test_case as test_case + {distance_sql} + FROM + TestCases + {where_sql} + ORDER BY + {distance_order_sql}TestCases.id + """ + data = self.conn.execute(sql, params) + return data.fetchall() + + def join_all_tables_by_test_cases( + self, where_clauses="", params=None + ) -> list[tuple]: + where_sql = "" + if where_clauses: + where_sql = f"WHERE {' AND '.join(where_clauses)}" + + sql = f""" + SELECT + TestCases.id as case_id, + TestCases.test_script as test_script, + TestCases.test_case as test_case, + + Annotations.id as anno_id, + Annotations.summary as anno_summary, + Annotations.embedding as anno_embedding, + + AnnotationsToRequirements.cached_distance as distance, + + Requirements.id as req_id, + Requirements.external_id as req_external_id, + Requirements.summary as req_summary, + Requirements.embedding as req_embedding + FROM + TestCases + JOIN CasesToAnnos ON TestCases.id = CasesToAnnos.case_id + JOIN Annotations ON Annotations.id = CasesToAnnos.annotation_id + JOIN AnnotationsToRequirements ON Annotations.id = AnnotationsToRequirements.annotation_id + 
JOIN Requirements ON Requirements.id = AnnotationsToRequirements.requirement_id + {where_sql} + ORDER BY + case_id, distance, req_id + LIMIT ? + """ + data = self.conn.execute(sql, params) + return data.fetchall() + + def get_embeddings_by_id(self, id1: int, from_table: str): + cursor = self.conn.execute( + f"SELECT embedding FROM {from_table} WHERE id = ?", (id1,) + ) + return cursor.fetchone() diff --git a/test2text/services/db/tables/test_case.py b/test2text/services/db/tables/test_case.py index 646acda..7db4ab4 100644 --- a/test2text/services/db/tables/test_case.py +++ b/test2text/services/db/tables/test_case.py @@ -1,27 +1,50 @@ +from string import Template from typing import Optional +from sqlite_vec import serialize_float32 +from sqlite3 import Connection from .abstract_table import AbstractTable class TestCasesTable(AbstractTable): + def __init__(self, connection: Connection, embedding_size: int): + super().__init__(connection) + self.embedding_size = embedding_size + def init_table(self): - self.connection.execute(""" - CREATE TABLE IF NOT EXISTS TestCases ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - test_script TEXT NOT NULL, - test_case TEXT NOT NULL, - UNIQUE (test_script, test_case) - ) - """) + self.connection.execute( + Template(""" + + CREATE TABLE IF NOT EXISTS TestCases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + test_script TEXT NOT NULL, + test_case TEXT NOT NULL, + embedding float[$embedding_size], + UNIQUE (test_script, test_case) + + CHECK ( + typeof(embedding) == 'null' or + (typeof(embedding) == 'blob' + and vec_length(embedding) == $embedding_size) + ) + ) + """).substitute(embedding_size=self.embedding_size) + ) - def insert(self, test_script: str, test_case: str) -> Optional[int]: + def insert( + self, test_script: str, test_case: str, embedding: list[float] = None + ) -> Optional[int]: cursor = self.connection.execute( """ - INSERT OR IGNORE INTO TestCases (test_script, test_case) - VALUES (?, ?) 
+ INSERT OR IGNORE INTO TestCases (test_script, test_case, embedding) + VALUES (?, ?, ?) RETURNING id """, - (test_script, test_case), + ( + test_script, + test_case, + serialize_float32(embedding) if embedding is not None else None, + ), ) result = cursor.fetchone() cursor.close() diff --git a/test2text/services/embeddings/annotation_embeddings_controls.py b/test2text/services/embeddings/annotation_embeddings_controls.py index 7bf1650..8815506 100644 --- a/test2text/services/embeddings/annotation_embeddings_controls.py +++ b/test2text/services/embeddings/annotation_embeddings_controls.py @@ -5,54 +5,41 @@ BATCH_SIZE = 30 -def count_all_annotations() -> int: - db = get_db_client() - count = db.conn.execute("SELECT COUNT(*) FROM Annotations").fetchone()[0] - return count - - -def count_embedded_annotations() -> int: - db = get_db_client() - count = db.conn.execute( - "SELECT COUNT(*) FROM Annotations WHERE embedding IS NOT NULL" - ).fetchone()[0] - return count - - OnProgress = Callable[[float], None] def embed_annotations(*_, embed_all=False, on_progress: OnProgress = None): - from .embed import embed_annotations_batch - - db = get_db_client() - annotations_count = count_all_annotations() - embedded_annotations_count = count_embedded_annotations() - if embed_all: - annotations_to_embed = annotations_count - else: - annotations_to_embed = annotations_count - embedded_annotations_count - - batch = [] - - def write_batch(batch: list[tuple[int, str]]): - embeddings = embed_annotations_batch([annotation for _, annotation in batch]) - for i, (anno_id, annotation) in enumerate(batch): - embedding = embeddings[i] - db.annotations.set_embedding(anno_id, embedding) - db.conn.commit() - - annotations = db.conn.execute(f""" - SELECT id, summary FROM Annotations - {"WHERE embedding IS NULL" if not embed_all else ""} - """) - - for i, (anno_id, summary) in enumerate(annotations.fetchall()): - if on_progress: - on_progress((i + 1) / annotations_to_embed) - batch.append((anno_id, 
summary)) - if len(batch) == BATCH_SIZE: - write_batch(batch) - batch = [] - - write_batch(batch) + with get_db_client() as db: + from .embed import embed_annotations_batch + + annotations_count = db.count_all_entries("Annotations") + embedded_annotations_count = db.count_notnull_entries( + "embedding", from_table="Annotations" + ) + if embed_all: + annotations_to_embed = annotations_count + else: + annotations_to_embed = annotations_count - embedded_annotations_count + + batch = [] + + def write_batch(batch: list[tuple[int, str]]): + embeddings = embed_annotations_batch( + [annotation for _, annotation in batch] + ) + for i, (anno_id, annotation) in enumerate(batch): + embedding = embeddings[i] + db.annotations.set_embedding(anno_id, embedding) + db.conn.commit() + + annotations = db.get_null_entries(from_table="Annotations") + + for i, (anno_id, summary) in enumerate(annotations): + if on_progress: + on_progress((i + 1) / annotations_to_embed) + batch.append((anno_id, summary)) + if len(batch) == BATCH_SIZE: + write_batch(batch) + batch = [] + + write_batch(batch) diff --git a/test2text/services/embeddings/cache_distances.py b/test2text/services/embeddings/cache_distances.py index 545a955..ee9c1da 100644 --- a/test2text/services/embeddings/cache_distances.py +++ b/test2text/services/embeddings/cache_distances.py @@ -2,34 +2,27 @@ def refresh_and_get_distances() -> list[float]: - db = get_db_client() - db.annos_to_reqs.recreate_table() - # Link requirements to annotations - annotations = db.conn.execute(""" - SELECT - Annotations.id AS anno_id, - Requirements.id AS req_id, - vec_distance_L2(Annotations.embedding, Requirements.embedding) AS distance - FROM Annotations, Requirements - WHERE Annotations.embedding IS NOT NULL AND Requirements.embedding IS NOT NULL - ORDER BY req_id, distance - """) - # Visualize distances - distances = [] - current_req_id = None - current_req_annos = 0 - for i, (anno_id, req_id, distance) in enumerate(annotations.fetchall()): - 
distances.append(distance) - if req_id != current_req_id: - current_req_id = req_id - current_req_annos = 0 - if current_req_annos < 5 or distance < 0.7: - db.annos_to_reqs.insert( - annotation_id=anno_id, requirement_id=req_id, cached_distance=distance - ) - current_req_annos += 1 - db.conn.commit() - return distances + with get_db_client() as db: + db.annos_to_reqs.recreate_table() + # Link requirements to annotations + annotations = db.get_distances() + # Visualize distances + distances = [] + current_req_id = None + current_req_annos = 0 + for i, (anno_id, req_id, distance) in enumerate(annotations): + distances.append(distance) + if req_id != current_req_id: + current_req_id = req_id + current_req_annos = 0 + if current_req_annos < 5 or distance < 0.7: + db.annos_to_reqs.insert( + annotation_id=anno_id, + requirement_id=req_id, + cached_distance=distance, + ) + current_req_annos += 1 + return distances if __name__ == "__main__": diff --git a/test2text/services/loaders/convert_trace_annos.py b/test2text/services/loaders/convert_trace_annos.py index 3764c8a..aad1ce9 100644 --- a/test2text/services/loaders/convert_trace_annos.py +++ b/test2text/services/loaders/convert_trace_annos.py @@ -26,58 +26,56 @@ def write_table_row(*args, **kwargs): def trace_test_cases_to_annos(trace_files: list): - db = get_db_client() - - st.info( - "Reading trace files and inserting test case + annotations pairs into database..." 
- ) - write_table_row( - "File name", - "Extracted pairs test cases + annotations", - "Inserted to data base", - "Ignored (dublicates or wrong id)", - ) - for i, file in enumerate(trace_files): - stringio = io.StringIO(file.getvalue().decode("utf-8")) - reader = csv.reader(stringio) - current_tc = EMPTY - concat_summary = EMPTY - test_script = EMPTY - global_columns = next(reader) - insertions = list() - for row in reader: - if row[0] == "TestCaseStart": - current_tc = row[1] - test_script = EMPTY - concat_summary = EMPTY - elif row[0] == "Summary": - continue - elif row[0] == "TestCaseEnd": - if not is_empty(current_tc) and not is_empty(concat_summary): - case_id = db.test_cases.get_or_insert( - test_script=test_script, test_case=current_tc - ) - annotation_id = db.annotations.get_or_insert(summary=concat_summary) - insertions.append( - db.cases_to_annos.insert( - case_id=case_id, annotation_id=annotation_id - ) - ) - else: - if not is_empty(row[global_columns.index("TestCase")]): - if current_tc != row[global_columns.index("TestCase")]: - current_tc = row[global_columns.index("TestCase")] - if is_empty(test_script) and not is_empty( - row[global_columns.index("TestScript")] - ): - test_script = row[global_columns.index("TestScript")] - concat_summary += row[0] + with get_db_client() as db: + st.info( + "Reading trace files and inserting test case + annotations pairs into database..." 
+ ) write_table_row( - file.name, - len(insertions), - sum(insertions), - len(insertions) - sum(insertions), + "File name", + "Extracted pairs test cases + annotations", + "Inserted to data base", + "Ignored (dublicates or wrong id)", ) - - db.conn.commit() - db.conn.close() + for i, file in enumerate(trace_files): + stringio = io.StringIO(file.getvalue().decode("utf-8")) + reader = csv.reader(stringio) + current_tc = EMPTY + concat_summary = EMPTY + test_script = EMPTY + global_columns = next(reader) + insertions = list() + for row in reader: + if row[0] == "TestCaseStart": + current_tc = row[1] + test_script = EMPTY + concat_summary = EMPTY + elif row[0] == "Summary": + continue + elif row[0] == "TestCaseEnd": + if not is_empty(current_tc) and not is_empty(concat_summary): + case_id = db.test_cases.get_or_insert( + test_script=test_script, test_case=current_tc + ) + annotation_id = db.annotations.get_or_insert( + summary=concat_summary + ) + insertions.append( + db.cases_to_annos.insert( + case_id=case_id, annotation_id=annotation_id + ) + ) + else: + if not is_empty(row[global_columns.index("TestCase")]): + if current_tc != row[global_columns.index("TestCase")]: + current_tc = row[global_columns.index("TestCase")] + if is_empty(test_script) and not is_empty( + row[global_columns.index("TestScript")] + ): + test_script = row[global_columns.index("TestScript")] + concat_summary += row[0] + write_table_row( + file.name, + len(insertions), + sum(insertions), + len(insertions) - sum(insertions), + ) diff --git a/test2text/services/loaders/index_annotations.py b/test2text/services/loaders/index_annotations.py index e9de241..781229b 100644 --- a/test2text/services/loaders/index_annotations.py +++ b/test2text/services/loaders/index_annotations.py @@ -12,27 +12,25 @@ def index_annotations_from_files(files: list, *_, on_file_start: OnFileStart = None): - db = get_db_client() - - for i, file in enumerate(files): - file_counter = None - if on_file_start: - file_counter = 
on_file_start(f"{i + 1}/{len(files)}", file.name) - stringio = io.StringIO(file.getvalue().decode("utf-8")) - reader = csv.reader(stringio) - insertions = [] - - for i, row in enumerate(reader): - if file_counter: - file_counter.write(i) - [summary, _, test_script, test_case, *_] = row - anno_id = db.annotations.get_or_insert(summary=summary) - tc_id = db.test_cases.get_or_insert( - test_script=test_script, test_case=test_case - ) - insertions.append( - db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id) - ) - - db.conn.commit() - return None + with get_db_client() as db: + for i, file in enumerate(files): + file_counter = None + if on_file_start: + file_counter = on_file_start(f"{i + 1}/{len(files)}", file.name) + stringio = io.StringIO(file.getvalue().decode("utf-8")) + reader = csv.reader(stringio) + insertions = [] + + for i, row in enumerate(reader): + if file_counter: + file_counter.write(i) + [summary, _, test_script, test_case, *_] = row + anno_id = db.annotations.get_or_insert(summary=summary) + tc_id = db.test_cases.get_or_insert( + test_script=test_script, test_case=test_case + ) + insertions.append( + db.cases_to_annos.insert(case_id=tc_id, annotation_id=anno_id) + ) + + return None diff --git a/test2text/services/loaders/index_requirements.py b/test2text/services/loaders/index_requirements.py index 713eb75..c112985 100644 --- a/test2text/services/loaders/index_requirements.py +++ b/test2text/services/loaders/index_requirements.py @@ -18,50 +18,47 @@ def index_requirements_from_files( *args, on_start_file: OnStartFile = None, on_requirement_written: OnRequirementWritten = None, -) -> tuple[int]: - db = get_db_client() - for i, file in enumerate(files): - if on_start_file: - on_start_file(i + 1, file.name) - stringio = io.StringIO(file.getvalue().decode("utf-8")) - reader = csv.reader(stringio) +) -> int: + with get_db_client() as db: + for i, file in enumerate(files): + if on_start_file: + on_start_file(i + 1, file.name) + stringio = 
io.StringIO(file.getvalue().decode("utf-8")) + reader = csv.reader(stringio) - try: - for _ in range(3): - next(reader) - except StopIteration: - raise ValueError( - f"The uploaded CSV file {file.name} does not have enough lines. " - "Please ensure it has at least 3 lines of data." - ) + try: + for _ in range(3): + next(reader) + except StopIteration: + raise ValueError( + f"The uploaded CSV file {file.name} does not have enough lines. " + "Please ensure it has at least 3 lines of data." + ) - batch = [] - last_requirement = "" - - def write_batch(): - nonlocal batch - embeddings = embed_requirements_batch( - [requirement for _, requirement in batch] - ) - for i, (external_id, requirement) in enumerate(batch): - embedding = embeddings[i] - db.requirements.insert(requirement, embedding, external_id) - if on_requirement_written: - on_requirement_written(external_id) - db.conn.commit() batch = [] + last_requirement = "" + + def write_batch(): + nonlocal batch + embeddings = embed_requirements_batch( + [requirement for _, requirement in batch] + ) + for i, (external_id, requirement) in enumerate(batch): + embedding = embeddings[i] + db.requirements.insert(requirement, embedding, external_id) + if on_requirement_written: + on_requirement_written(external_id) + db.conn.commit() + batch = [] - for row in reader: - [external_id, requirement, *_] = row - if requirement.startswith("..."): - requirement = last_requirement + requirement[3:] - last_requirement = requirement - batch.append((external_id, requirement)) - if len(batch) == BATCH_SIZE: - write_batch() - write_batch() - # Check requirements - cursor = db.conn.execute(""" - SELECT COUNT(*) FROM Requirements - """) - return cursor.fetchone()[0] + for row in reader: + [external_id, requirement, *_] = row + if requirement.startswith("..."): + requirement = last_requirement + requirement[3:] + last_requirement = requirement + batch.append((external_id, requirement)) + if len(batch) == BATCH_SIZE: + write_batch() + 
write_batch() + # Check requirements + return db.count_all_entries(from_table="Requirements") diff --git a/test2text/services/utils/math_utils.py b/test2text/services/utils/math_utils.py new file mode 100644 index 0000000..f43ba16 --- /dev/null +++ b/test2text/services/utils/math_utils.py @@ -0,0 +1,2 @@ +def round_distance(distance: float) -> float: + return round(distance, 2) diff --git a/test2text/services/visualisation/visualize_vectors.py b/test2text/services/visualisation/visualize_vectors.py index 4b13ba0..fe59298 100644 --- a/test2text/services/visualisation/visualize_vectors.py +++ b/test2text/services/visualisation/visualize_vectors.py @@ -11,15 +11,16 @@ FONT_SIZE = 18 DOT_SIZE_2D = 20 DOT_SIZE_3D = 10 +LABELS_SUMMARY_LENGTH = 15 def extract_annotation_vectors(db: DbClient): vectors = [] - embeddings = db.conn.execute("SELECT embedding FROM Annotations") - if embeddings.fetchone() is None: + embeddings = db.get_column_values("embedding", from_table="Annotations") + if not embeddings: st.error("Embeddings is empty. Please fill embeddings in annotations.") return None - for row in embeddings.fetchall(): + for row in embeddings: if row[0] is not None: vectors.append(np.array(unpack_float32(row[0]))) return np.array(vectors) @@ -27,78 +28,111 @@ def extract_annotation_vectors(db: DbClient): def extract_closest_annotation_vectors(db: DbClient): vectors = [] - embeddings = db.conn.execute(""" - SELECT embedding FROM Annotations - WHERE id IN ( - SELECT DISTINCT annotation_id FROM AnnotationsToRequirements - ) - """) - if embeddings.fetchone() is None: + embeddings = db.get_embeddings_from_annotations_to_requirements_table() + if not embeddings: st.error("Embeddings is empty. 
Please calculate and cache distances.") return None - for row in embeddings.fetchall(): + for row in embeddings: vectors.append(np.array(unpack_float32(row[0]))) return np.array(vectors) def extract_requirement_vectors(db: DbClient): vectors = [] - embeddings = db.conn.execute("SELECT embedding FROM Requirements") - if embeddings.fetchone() is None: + embeddings = db.get_column_values("embedding", from_table="Requirements") + if not embeddings: st.error("Embeddings is empty. Please fill embeddings in requirements.") return None - for row in embeddings.fetchall(): + for row in embeddings: vectors.append(np.array(unpack_float32(row[0]))) return np.array(vectors) def minifold_vectors_2d(vectors: np.array): - tsne = TSNE(n_components=2, random_state=0) + """ + Reduces high-dimensional vectors to 2D using TSNE. + Handles cases where the number of samples is too small for TSNE by returning the input as-is. + """ + n_samples = vectors.shape[0] + # TSNE requires perplexity < n_samples + if n_samples < 2: + return vectors.reshape(n_samples, -1)[:, :2] + perplexity = min(30, max(1, (n_samples - 1) // 3)) + tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity) vectors_2d = tsne.fit_transform(vectors) return vectors_2d def minifold_vectors_3d(vectors: np.array): - tsne = TSNE(n_components=3, random_state=0) + n_samples = vectors.shape[0] + # TSNE requires perplexity < n_samples + if n_samples < 2: + return vectors.reshape(n_samples, -1)[:, :3] + perplexity = min(30, n_samples - 1) if n_samples > 1 else 1 + tsne = TSNE(n_components=3, random_state=0, perplexity=perplexity) vectors_3d = tsne.fit_transform(vectors) return vectors_3d -def plot_vectors_2d(vectors_2d: np.array, title): - fig = px.scatter(x=vectors_2d[:, 0], y=vectors_2d[:, 1]) - fig.update_layout(title=title, xaxis_title="X", yaxis_title="Y") +def plot_vectors_2d(vectors_2d: np.array, title: str, labels: list = None): + fig = px.scatter( + x=vectors_2d[:, 0], + y=vectors_2d[:, 1], + text=labels, + ) + 
fig.update_traces(textposition="top center") + fig.update_layout( + title=title, + xaxis_title="X", + yaxis_title="Y", + ) st.plotly_chart(fig, use_container_width=True) -def plot_vectors_3d(vectors_3d: np.array, title): +def plot_vectors_3d(vectors_3d: np.array, title: str, labels: list = None): fig = px.scatter_3d( x=vectors_3d[:, 0], y=vectors_3d[:, 1], z=vectors_3d[:, 2], color=vectors_3d[:, 2], + text=labels, ) + fig.update_traces(textposition="top center") fig.update_layout(title=title, xaxis_title="X", yaxis_title="Y") st.plotly_chart(fig, use_container_width=True) def plot_2_sets_in_one_2d( - first_set_of_vec, second_set_of_vec, first_title, second_title + first_set_of_vec, + second_set_of_vec, + first_title, + second_title, + first_color="red", + second_color="green", + first_labels=None, + second_labels=None, ): fig = go.Figure() fig.add_trace( go.Scatter( x=first_set_of_vec[:, 0], y=first_set_of_vec[:, 1], - mode="markers", - name={first_title}, + mode="markers+text", + name=first_title, + text=first_labels, + textposition="top center", + marker=dict(color=f"{first_color}"), ) ) fig.add_trace( go.Scatter( x=second_set_of_vec[:, 0], y=second_set_of_vec[:, 1], - mode="markers", - name={second_title}, + mode="markers+text", + name=second_title, + text=second_labels, + textposition="top center", + marker=dict(color=f"{second_color}"), ) ) fig.update_layout( @@ -108,7 +142,14 @@ def plot_2_sets_in_one_2d( def plot_2_sets_in_one_3d( - first_set_of_vec, second_set_of_vec, first_title, second_title + first_set_of_vec, + second_set_of_vec, + first_title, + second_title, + first_color="red", + second_color="green", + first_labels=None, + second_labels=None, ): fig = go.Figure() fig.add_trace( @@ -116,8 +157,11 @@ def plot_2_sets_in_one_3d( x=first_set_of_vec[:, 0], y=first_set_of_vec[:, 1], z=first_set_of_vec[:, 2], - mode="markers", + mode="markers+text", name=first_title, + text=first_labels, + textposition="top left", + marker=dict(color=f"{first_color}"), ) ) 
@@ -126,8 +170,11 @@ def plot_2_sets_in_one_3d( x=second_set_of_vec[:, 0], y=second_set_of_vec[:, 1], z=second_set_of_vec[:, 2], - mode="markers", + mode="markers+text", name=second_title, + text=second_labels, + textposition="top center", + marker=dict(color=f"{second_color}"), ) ) @@ -144,61 +191,72 @@ def plot_2_sets_in_one_3d( def visualize_vectors(): st.header("Visualizing vectors") - db = get_db_client() - Req_tab, Anno_tab, Req_Anno_tab = st.tabs( - ["Requirements", "Annotations", "Requirements vs Annotations"] - ) - with Req_tab: - st.subheader("Requirements vectors") - progress_bar = st.progress(0) - - requirement_vectors = extract_requirement_vectors(db) - progress_bar.progress(20, "Extracted") - reqs_vectors_2d = minifold_vectors_2d(requirement_vectors) - progress_bar.progress(40, "Minifolded for 2D") - plot_vectors_2d(reqs_vectors_2d, "Requirements") - progress_bar.progress(60, "Plotted in 2D") - reqs_vectors_3d = minifold_vectors_3d(requirement_vectors) - progress_bar.progress(80, "Minifolded for 3D") - plot_vectors_3d(reqs_vectors_3d, "Requirements") - progress_bar.progress(100, "Plotted in 3D") - - with Anno_tab: - st.subheader("Annotations vectors") - progress_bar = st.progress(0) - - annotation_vectors = extract_annotation_vectors(db) - progress_bar.progress(20, "Extracted") - anno_vectors_2d = minifold_vectors_2d(annotation_vectors) - progress_bar.progress(40, "Minifolded for 2D") - plot_vectors_2d(anno_vectors_2d, "Annotations") - progress_bar.progress(60, "Plotted in 2D") - anno_vectors_3d = minifold_vectors_3d(annotation_vectors) - progress_bar.progress(80, "Minifolded for 3D") - plot_vectors_3d(anno_vectors_3d, "Annotations") - progress_bar.progress(100, "Plotted in 3D") - - with Req_Anno_tab: - # Show how these 2 groups of vectors are different - st.subheader("Requirements vs Annotations") - progress_bar = st.progress(40, "Extracted") - plot_2_sets_in_one_2d( - reqs_vectors_2d, anno_vectors_2d, "Requerements", "Annotations" - ) - 
progress_bar.progress(60, "Plotted in 2D") - - plot_2_sets_in_one_3d( - reqs_vectors_3d, anno_vectors_3d, "Requerements", "Annotations" - ) - progress_bar.progress(80, "Plotted in 3D") - - anno_vectors_2d = minifold_vectors_2d(extract_closest_annotation_vectors(db)) - - plot_2_sets_in_one_2d( - reqs_vectors_2d, anno_vectors_2d, "Requerements", "Annotations" + with get_db_client() as db: + req_tab, anno_tab, req_anno_tab = st.tabs( + ["Requirements", "Annotations", "Requirements vs Annotations"] ) - progress_bar.progress(100, "Minifolded and Plotted in 2D") - db.conn.close() + with req_tab: + st.subheader("Requirements vectors") + progress_bar = st.progress(0) + + requirement_vectors = extract_requirement_vectors(db) + progress_bar.progress(20, "Extracted") + reqs_vectors_2d = minifold_vectors_2d(requirement_vectors) + progress_bar.progress(40, "Minifolded for 2D") + req_labels = db.get_column_values("external_id", from_table="Requirements") + plot_vectors_2d(reqs_vectors_2d, "Requirements", labels=req_labels) + progress_bar.progress(60, "Plotted in 2D") + reqs_vectors_3d = minifold_vectors_3d(requirement_vectors) + progress_bar.progress(80, "Minifolded for 3D") + plot_vectors_3d(reqs_vectors_3d, "Requirements") + progress_bar.progress(100, "Plotted in 3D") + + with anno_tab: + st.subheader("Annotations vectors") + progress_bar = st.progress(0) + + annotation_vectors = extract_annotation_vectors(db) + progress_bar.progress(20, "Extracted") + anno_vectors_2d = minifold_vectors_2d(annotation_vectors) + progress_bar.progress(40, "Minifolded for 2D") + plot_vectors_2d(anno_vectors_2d, "Annotations") + progress_bar.progress(60, "Plotted in 2D") + anno_vectors_3d = minifold_vectors_3d(annotation_vectors) + progress_bar.progress(80, "Minifolded for 3D") + plot_vectors_3d(anno_vectors_3d, "Annotations") + progress_bar.progress(100, "Plotted in 3D") + + with req_anno_tab: + # Show how these 2 groups of vectors are different + st.subheader("Requirements vs Annotations") + 
progress_bar = st.progress(40, "Extracted") + plot_2_sets_in_one_2d( + reqs_vectors_2d, + anno_vectors_2d, + first_title="Requirements", + second_title="Annotations", + ) + progress_bar.progress(60, "Plotted in 2D") + + plot_2_sets_in_one_3d( + reqs_vectors_3d, + anno_vectors_3d, + first_title="Requirements", + second_title="Annotations", + ) + progress_bar.progress(80, "Plotted in 3D") + + anno_vectors_2d = minifold_vectors_2d( + extract_closest_annotation_vectors(db) + ) + + plot_2_sets_in_one_2d( + reqs_vectors_2d, + anno_vectors_2d, + first_title="Requirements", + second_title="Annotations", + ) + progress_bar.progress(100, "Minifolded and Plotted in 2D") if __name__ == "__main__": diff --git a/tests/test_db/test_tables/test_annotations.py b/tests/test_db/test_tables/test_annotations.py index a2d2feb..7257172 100644 --- a/tests/test_db/test_tables/test_annotations.py +++ b/tests/test_db/test_tables/test_annotations.py @@ -67,11 +67,7 @@ def test_set_embedding(self): orig_embedding = [0.1] * self.db.annotations.embedding_size self.db.annotations.set_embedding(id1, orig_embedding) self.db.conn.commit() - cursor = self.db.conn.execute( - "SELECT embedding FROM Annotations WHERE id = ?", (id1,) - ) - result = cursor.fetchone() - cursor.close() + result = self.db.get_embeddings_by_id(id1, "Annotations") self.assertIsNotNone(result) read_embedding = unpack_float32(result[0]) self.assertEqual(len(read_embedding), self.db.annotations.embedding_size) @@ -80,11 +76,7 @@ def test_set_embedding(self): new_embedding = [0.9] * self.db.annotations.embedding_size self.db.annotations.set_embedding(id1, new_embedding) self.db.conn.commit() - cursor = self.db.conn.execute( - "SELECT embedding FROM Annotations WHERE id = ?", (id1,) - ) - result = cursor.fetchone() - cursor.close() + result = self.db.get_embeddings_by_id(id1, "Annotations") self.assertIsNotNone(result) read_embedding = unpack_float32(result[0]) self.assertEqual(len(read_embedding), 
self.db.annotations.embedding_size) diff --git a/tests/test_db/test_tables/test_requirements.py b/tests/test_db/test_tables/test_requirements.py index 38e749d..adb372d 100644 --- a/tests/test_db/test_tables/test_requirements.py +++ b/tests/test_db/test_tables/test_requirements.py @@ -45,3 +45,18 @@ def test_insert_duplicate_external_id(self): id2 = self.db.requirements.insert("Test Requirement 3", external_id="ext-2") self.assertIsNotNone(id1) self.assertIsNone(id2) + + def test_insert_embedding(self): + embedding = [0.1] * self.db.requirements.embedding_size + id1 = self.db.requirements.insert("Test Requirement 5", embedding) + self.assertIsNotNone(id1) + + def test_insert_short_embedding(self): + short_embedding = [0.1] * (self.db.requirements.embedding_size - 1) + id1 = self.db.requirements.insert("Test Requirement 6", short_embedding) + self.assertIsNone(id1) + + def test_insert_long_embedding(self): + long_embedding = [0.1] * (self.db.requirements.embedding_size + 1) + id1 = self.db.requirements.insert("Test Requirement 7", long_embedding) + self.assertIsNone(id1) diff --git a/tests/test_db/test_tables/test_test_cases.py b/tests/test_db/test_tables/test_test_cases.py index 85d5725..a1bd922 100644 --- a/tests/test_db/test_tables/test_test_cases.py +++ b/tests/test_db/test_tables/test_test_cases.py @@ -42,3 +42,22 @@ def test_get_or_insert_duplicate(self): self.assertIsNotNone(id1) self.assertIsNotNone(id2) self.assertEqual(id1, id2) + + def test_insert_embedding(self): + embedding = [0.1] * self.db.test_cases.embedding_size + id1 = self.db.test_cases.insert("Test Script 12", "Test Case 12", embedding) + self.assertIsNotNone(id1) + + def test_insert_short_embedding(self): + short_embedding = [0.1] * (self.db.test_cases.embedding_size - 1) + id1 = self.db.test_cases.insert( + "Test Script 13", "Test Case 13", short_embedding + ) + self.assertIsNone(id1) + + def test_insert_long_embedding(self): + long_embedding = [0.1] * (self.db.test_cases.embedding_size + 1) + 
id1 = self.db.test_cases.insert( + "Test Script 14", "Test Case 14", long_embedding + ) + self.assertIsNone(id1)