diff --git a/test2text/pages/controls/controls_page.py b/test2text/pages/controls/controls_page.py index 166c333..611008e 100644 --- a/test2text/pages/controls/controls_page.py +++ b/test2text/pages/controls/controls_page.py @@ -12,9 +12,7 @@ def controls_page(): def refresh_counts(): with get_db_client() as db: - st.session_state["all_annotations_count"] = db.count_all_entries( - "Annotations" - ) + st.session_state["all_annotations_count"] = db.annotations.count st.session_state["embedded_annotations_count"] = ( db.count_notnull_entries("embedding", from_table="Annotations") ) diff --git a/test2text/services/db/client.py b/test2text/services/db/client.py index 1c84951..3624e61 100644 --- a/test2text/services/db/client.py +++ b/test2text/services/db/client.py @@ -74,7 +74,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __enter__(self): return self - def get_table_names(self): + def get_table_names(self) -> list[str]: """ Returns a list of all user-defined tables in the database. @@ -87,7 +87,13 @@ def get_table_names(self): cursor.close() return tables - def get_column_values(self, *columns: str, from_table: str): + def get_column_values(self, *columns: str, from_table: str) -> list[tuple]: + """ + Returns the values of the specified columns from the specified table. + :param columns: list of column names + :param from_table: name of the table + :return: list of tuples containing the values of the specified columns + """ cursor = self.conn.execute(f"SELECT {', '.join(columns)} FROM {from_table}") return cursor.fetchall() @@ -116,6 +122,11 @@ def count_all_entries(self, from_table: str) -> int: return count def count_notnull_entries(self, *columns: str, from_table: str) -> int: + """ + Count the number of non-null entries in the specified columns of the specified table. + :param columns: list of column names + :param from_table: name of the table + """ count = self.conn.execute( f"SELECT COUNT(*) FROM {from_table} WHERE {' AND '.join([column + ' IS NOT NULL' for column in columns])}" ).fetchone()[0] @@ -135,6 +146,9 @@ def has_column(self, column_name: str, table_name: str) -> bool: return column_name in columns def get_null_entries(self, from_table: str) -> list: + """ + Returns values (id and summary) witch has null values in its embedding column. + """ cursor = self.conn.execute( f"SELECT id, summary FROM {from_table} WHERE embedding IS NULL" ) @@ -174,8 +188,8 @@ def join_all_tables_by_requirements( self, where_clauses="", params=None ) -> list[tuple]: """ - Join all tables related to requirements based on the provided where clauses and parameters. - return a list of tuples containing : + Extract values from requirements with related annotations and their test cases based on the provided where clauses and parameters. + Return a list of tuples containing : req_id, req_external_id, req_summary, @@ -222,6 +236,14 @@ def join_all_tables_by_requirements( def get_ordered_values_from_requirements( self, distance_sql="", where_clauses="", distance_order_sql="", params=None ) -> list[tuple]: + """ + Extracted values from Requirements table based on the provided where clauses and specified parameters ordered by distance and id. + Return a list of tuples containing : + req_id, + req_external_id, + req_summary, + distance between annotation and requirement embeddings, + """ where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" sql = f""" SELECT @@ -241,6 +263,14 @@ def get_ordered_values_from_requirements( def get_ordered_values_from_test_cases( self, distance_sql="", where_clauses="", distance_order_sql="", params=None ) -> list[tuple]: + """ + Extracted values from TestCases table based on the provided where clauses and specified parameters ordered by distance and id. + Return a list of tuples containing : + case_id, + test_script, + test_case, + distance between test case and typed by user text embeddings if it is specified, + """ where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" sql = f""" SELECT @@ -260,6 +290,21 @@ def get_ordered_values_from_test_cases( def join_all_tables_by_test_cases( self, where_clauses="", params=None ) -> list[tuple]: + """ + Join all tables related to test cases based on the provided where clauses and specified parameters. + Return a list of tuples containing : + case_id, + test_script, + test_case, + anno_id, + anno_summary, + anno_embedding, + distance between annotation and requirement embeddings, + req_id, + req_external_id, + req_summary, + req_embedding + """ where_sql = "" if where_clauses: where_sql = f"WHERE {' AND '.join(where_clauses)}" @@ -294,7 +339,10 @@ def join_all_tables_by_test_cases( data = self.conn.execute(sql, params) return data.fetchall() - def get_embeddings_by_id(self, id1: int, from_table: str): + def get_embeddings_by_id(self, id1: int, from_table: str) -> float: + """ + Returns the embedding of the specified id from the specified table. + """ cursor = self.conn.execute( f"SELECT embedding FROM {from_table} WHERE id = ?", (id1,) ) diff --git a/test2text/services/db/streamlit_conn.py b/test2text/services/db/streamlit_conn.py index 8f3bed6..3222ddd 100644 --- a/test2text/services/db/streamlit_conn.py +++ b/test2text/services/db/streamlit_conn.py @@ -2,6 +2,10 @@ def get_db_client() -> DbClient: + """ + Returns a DbClient instance connected to the database where requirements, annotations, test cases and their relations are stored. + :return: DbClient instance + """ from test2text.services.utils import res_folder return DbClient(res_folder.get_file_path("db.sqlite3")) diff --git a/test2text/services/db/tables/annos_to_reqs.py b/test2text/services/db/tables/annos_to_reqs.py index 06b32b8..af3a553 100644 --- a/test2text/services/db/tables/annos_to_reqs.py +++ b/test2text/services/db/tables/annos_to_reqs.py @@ -3,7 +3,14 @@ class AnnotationsToRequirementsTable(AbstractTable): - def init_table(self): + """ + This class represents the relationship between annotations and requirements in the database by closest distance between them. + """ + + def init_table(self) -> None: + """ + Creates the AnnotationsToRequirements table in the database if it does not already exist. + """ self.connection.execute(""" CREATE TABLE IF NOT EXISTS AnnotationsToRequirements ( annotation_id INTEGER NOT NULL, @@ -15,7 +22,10 @@ def init_table(self): ) """) - def recreate_table(self): + def recreate_table(self) -> None: + """ + Drops the AnnotationsToRequirements table if it exists and recreates it. + """ self.connection.execute(""" DROP TABLE IF EXISTS AnnotationsToRequirements """) @@ -24,6 +34,13 @@ def recreate_table(self): def insert( self, annotation_id: int, requirement_id: int, cached_distance: float ) -> bool: + """ + Inserts a new entry into the AnnotationsToRequirements table. + :param annotation_id: The ID of the annotation + :param requirement_id: The ID of the requirement + :param cached_distance: The cached distance between the annotation and the requirement + :return: True if the insertion was successful, False otherwise. + """ try: cursor = self.connection.execute( """ @@ -42,7 +59,12 @@ def insert( pass return False + @property def count(self) -> int: + """ + Returns the number of entries in the AnnotationsToRequirements table. + :return: int - the number of entries in the table. + """ cursor = self.connection.execute( "SELECT COUNT(*) FROM AnnotationsToRequirements" ) diff --git a/test2text/services/db/tables/annotations.py b/test2text/services/db/tables/annotations.py index b6be8f1..029b704 100644 --- a/test2text/services/db/tables/annotations.py +++ b/test2text/services/db/tables/annotations.py @@ -7,11 +7,18 @@ class AnnotationsTable(AbstractTable): + """ + This class represents the annotations of test cases in the database. + """ + def __init__(self, connection: Connection, embedding_size: int): super().__init__(connection) self.embedding_size = embedding_size def init_table(self): + """ + Creates the Annotations table in the database if it does not already exist. + """ self.connection.execute( Template(""" CREATE TABLE IF NOT EXISTS Annotations ( @@ -29,6 +36,12 @@ def init_table(self): ) def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]: + """ + Inserts a new annotation into the database. If the annotation already exists, it updates the existing record. + :param summary: The summary of the annotation + :param embedding: The embedding of the annotation (optional) + :return: The ID of the inserted or updated annotation, or None if the annotation already exists and was updated. + """ cursor = self.connection.execute( """ INSERT OR IGNORE INTO Annotations (summary, embedding) @@ -45,6 +58,12 @@ def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]: return None def get_or_insert(self, summary: str, embedding: list[float] = None) -> int: + """ + Inserts a new annotation into the database if it does not already exist, otherwise returns the existing annotation's ID. + :param summary: The summary of the annotation + :param embedding: The embedding of the annotation (optional) + :return: The ID of the inserted or existing annotation. + """ inserted_id = self.insert(summary, embedding) if inserted_id is not None: return inserted_id @@ -61,6 +80,11 @@ def get_or_insert(self, summary: str, embedding: list[float] = None) -> int: return result[0] def set_embedding(self, anno_id: int, embedding: list[float]) -> None: + """ + Sets the embedding for a given annotation ID. + :param anno_id: The ID of the annotation + :param embedding: The new embedding for the annotation + """ if len(embedding) != self.embedding_size: raise ValueError( f"Embedding size must be {self.embedding_size}, got {len(embedding)}" @@ -74,3 +98,12 @@ def set_embedding(self, anno_id: int, embedding: list[float]) -> None: """, (serialized_embedding, anno_id), ) + + @property + def count(self) -> int: + """ + Returns the number of entries in the Annotations table. + :return: int - the number of entries in the table. + """ + cursor = self.connection.execute("SELECT COUNT(*) FROM Annotations") + return cursor.fetchone()[0] diff --git a/test2text/services/db/tables/cases_to_annos.py b/test2text/services/db/tables/cases_to_annos.py index f0b7e9c..6eb0b7c 100644 --- a/test2text/services/db/tables/cases_to_annos.py +++ b/test2text/services/db/tables/cases_to_annos.py @@ -4,7 +4,11 @@ class TestCasesToAnnotationsTable(AbstractTable): - def init_table(self): + """ + This class represents the relationship between test cases and annotations in the database. + """ + + def init_table(self) -> None: self.connection.execute(""" CREATE TABLE IF NOT EXISTS CasesToAnnos ( case_id INTEGER NOT NULL, @@ -40,6 +44,11 @@ def insert(self, case_id: int, annotation_id: int) -> bool: pass return False + @property def count(self) -> int: + """ + Returns the number of entries in the CasesToAnnos table. + :return: int - the number of entries in the table. + """ cursor = self.connection.execute("SELECT COUNT(*) FROM CasesToAnnos") return cursor.fetchone()[0] diff --git a/test2text/services/db/tables/requirements.py b/test2text/services/db/tables/requirements.py index 91420f3..fb4cdd4 100644 --- a/test2text/services/db/tables/requirements.py +++ b/test2text/services/db/tables/requirements.py @@ -7,11 +7,18 @@ class RequirementsTable(AbstractTable): + """ + This class represents the requirements for test cases in the database. + """ + def __init__(self, connection: Connection, embedding_size: int): super().__init__(connection) self.embedding_size = embedding_size - def init_table(self): + def init_table(self) -> None: + """ + Creates the Requirements table in the database if it does not already exist. + """ self.connection.execute( Template(""" CREATE TABLE IF NOT EXISTS Requirements ( @@ -32,6 +39,13 @@ def init_table(self): def insert( self, summary: str, embedding: list[float] = None, external_id: str = None ) -> Optional[int]: + """ + Inserts a new requirement into the database. If the requirement already exists, it updates the existing record. + :param summary: The summary of the requirement + :param embedding: The embedding of the requirement (optional) + :param external_id: The external ID of the requirement (optional) + :return: The ID of the inserted or updated requirement, or None if the requirement already exists and was updated. + """ cursor = self.connection.execute( """ INSERT OR IGNORE INTO Requirements (summary, embedding, external_id) @@ -49,3 +63,12 @@ def insert( return result[0] else: return None + + @property + def count(self) -> int: + """ + Returns the number of entries in the Requirements table. + :return: int - the number of entries in the table. + """ + cursor = self.connection.execute("SELECT COUNT(*) FROM Requirements") + return cursor.fetchone()[0] diff --git a/test2text/services/db/tables/test_case.py b/test2text/services/db/tables/test_case.py index 7db4ab4..fd3c99c 100644 --- a/test2text/services/db/tables/test_case.py +++ b/test2text/services/db/tables/test_case.py @@ -12,6 +12,9 @@ def __init__(self, connection: Connection, embedding_size: int): self.embedding_size = embedding_size def init_table(self): + """ + Creates the TestCases table in the database if it does not already exist. + """ self.connection.execute( Template(""" @@ -34,6 +37,13 @@ def init_table(self): def insert( self, test_script: str, test_case: str, embedding: list[float] = None ) -> Optional[int]: + """ + Inserts a new test case into the database. If the test case already exists, it updates the existing record. + :param test_script: The test script of the test case + :param test_case: The test case of the test case + :param embedding: The embedding of the test case (optional) + :return: The ID of the inserted or updated test case, or None if the test case already exists and was updated. + """ cursor = self.connection.execute( """ INSERT OR IGNORE INTO TestCases (test_script, test_case, embedding) @@ -54,6 +64,12 @@ def insert( return None def get_or_insert(self, test_script: str, test_case: str) -> int: + """ + Inserts a new test case into the database if it does not already exist, otherwise returns the existing test case's ID. + :param test_script: The test script of the test case + :param test_case: The test case of the test case + :return: The ID of the inserted or existing test case. + """ inserted_id = self.insert(test_script, test_case) if inserted_id is not None: return inserted_id @@ -68,3 +84,12 @@ def get_or_insert(self, test_script: str, test_case: str) -> int: result = cursor.fetchone() cursor.close() return result[0] + + @property + def count(self) -> int: + """ + Returns the number of entries in the TestCases table. + :return: int - the number of entries in the table. + """ + cursor = self.connection.execute("SELECT COUNT(*) FROM TestCases") + return cursor.fetchone()[0] diff --git a/test2text/services/loaders/index_requirements.py b/test2text/services/loaders/index_requirements.py index c112985..1750703 100644 --- a/test2text/services/loaders/index_requirements.py +++ b/test2text/services/loaders/index_requirements.py @@ -61,4 +61,4 @@ def write_batch(): write_batch() write_batch() # Check requirements - return db.count_all_entries(from_table="Requirements") + return db.requirements.count diff --git a/tests/test_db/test_tables/test_annos_to_reqs.py b/tests/test_db/test_tables/test_annos_to_reqs.py index e9caf3b..eadb05b 100644 --- a/tests/test_db/test_tables/test_annos_to_reqs.py +++ b/tests/test_db/test_tables/test_annos_to_reqs.py @@ -13,40 +13,47 @@ def setUp(self): self.wrong_req = 8888 def test_insert_single(self): - count_before = self.db.annos_to_reqs.count() + count_before = self.db.annos_to_reqs.count inserted = self.db.annos_to_reqs.insert(self.anno1, self.req1, 1) - count_after = self.db.annos_to_reqs.count() + count_after = self.db.annos_to_reqs.count self.assertTrue(inserted) self.assertEqual(count_after, count_before + 1) def test_insert_multiple(self): - count_before = self.db.annos_to_reqs.count() + count_before = self.db.annos_to_reqs.count inserted1 = self.db.annos_to_reqs.insert(self.anno1, self.req1, 1) inserted2 = self.db.annos_to_reqs.insert(self.anno2, self.req2, 1) - count_after = self.db.annos_to_reqs.count() + count_after = self.db.annos_to_reqs.count self.assertTrue(inserted1) self.assertTrue(inserted2) self.assertEqual(count_after, count_before + 2) def test_insert_duplicate(self): - count_before = self.db.annos_to_reqs.count() + count_before = self.db.annos_to_reqs.count inserted1 = self.db.annos_to_reqs.insert(self.anno1, self.req1, 1) inserted2 = self.db.annos_to_reqs.insert(self.anno1, self.req1, 1) - count_after = self.db.annos_to_reqs.count() + count_after = self.db.annos_to_reqs.count self.assertTrue(inserted1) self.assertFalse(inserted2) # Second insertion should fail as it's a duplicate self.assertEqual(count_after, count_before + 1) def test_insert_wrong_annotation(self): - count_before = self.db.annos_to_reqs.count() + count_before = self.db.annos_to_reqs.count inserted = self.db.annos_to_reqs.insert(self.wrong_anno, self.req1, 1) - count_after = self.db.annos_to_reqs.count() + count_after = self.db.annos_to_reqs.count self.assertFalse(inserted) # Should fail due to foreign key constraint self.assertEqual(count_after, count_before) def test_insert_wrong_requirement(self): - count_before = self.db.annos_to_reqs.count() + count_before = self.db.annos_to_reqs.count inserted = self.db.annos_to_reqs.insert(self.anno1, self.wrong_req, 1) - count_after = self.db.annos_to_reqs.count() + count_after = self.db.annos_to_reqs.count self.assertFalse(inserted) # Should fail due to foreign key constraint self.assertEqual(count_before, count_after) + + def test_count(self): + count_before = self.db.annos_to_reqs.count + self.db.annos_to_reqs.insert(self.anno1, self.req1, 1) + self.db.annos_to_reqs.insert(self.anno2, self.req2, 1) + count_after = self.db.annos_to_reqs.count + self.assertEqual(count_after, count_before + 2) diff --git a/tests/test_db/test_tables/test_annotations.py b/tests/test_db/test_tables/test_annotations.py index 7257172..a6b9a2b 100644 --- a/tests/test_db/test_tables/test_annotations.py +++ b/tests/test_db/test_tables/test_annotations.py @@ -81,3 +81,9 @@ def test_set_embedding(self): read_embedding = unpack_float32(result[0]) self.assertEqual(len(read_embedding), self.db.annotations.embedding_size) self.assertEqual(round_vector(read_embedding), round_vector(new_embedding)) + + def test_count(self): + count_before = self.db.annotations.count + self.db.annotations.insert("Test Summary 11") + count_after = self.db.annotations.count + self.assertEqual(count_after, count_before + 1) diff --git a/tests/test_db/test_tables/test_cases_to_annos.py b/tests/test_db/test_tables/test_cases_to_annos.py index 6f6d998..aca1b42 100644 --- a/tests/test_db/test_tables/test_cases_to_annos.py +++ b/tests/test_db/test_tables/test_cases_to_annos.py @@ -13,40 +13,47 @@ def setUp(self): self.wrong_anno = 8888 def test_insert_single(self): - count_before = self.db.cases_to_annos.count() + count_before = self.db.cases_to_annos.count inserted = self.db.cases_to_annos.insert(self.case1, self.anno1) - count_after = self.db.cases_to_annos.count() + count_after = self.db.cases_to_annos.count self.assertTrue(inserted) self.assertEqual(count_after, count_before + 1) def test_insert_multiple(self): - count_before = self.db.cases_to_annos.count() + count_before = self.db.cases_to_annos.count inserted1 = self.db.cases_to_annos.insert(self.case1, self.anno1) inserted2 = self.db.cases_to_annos.insert(self.case2, self.anno2) - count_after = self.db.cases_to_annos.count() + count_after = self.db.cases_to_annos.count self.assertTrue(inserted1) self.assertTrue(inserted2) self.assertEqual(count_after, count_before + 2) def test_insert_duplicate(self): - count_before = self.db.cases_to_annos.count() + count_before = self.db.cases_to_annos.count inserted1 = self.db.cases_to_annos.insert(self.case1, self.anno1) inserted2 = self.db.cases_to_annos.insert(self.case1, self.anno1) - count_after = self.db.cases_to_annos.count() + count_after = self.db.cases_to_annos.count self.assertTrue(inserted1) self.assertFalse(inserted2) # Second insertion should fail as it's a duplicate self.assertEqual(count_after, count_before + 1) def test_insert_wrong_case(self): - count_before = self.db.cases_to_annos.count() + count_before = self.db.cases_to_annos.count inserted = self.db.cases_to_annos.insert(self.wrong_case, self.anno1) - count_after = self.db.cases_to_annos.count() + count_after = self.db.cases_to_annos.count self.assertFalse(inserted) # Should fail due to foreign key constraint self.assertEqual(count_after, count_before) def test_insert_wrong_annotation(self): - count_before = self.db.cases_to_annos.count() + count_before = self.db.cases_to_annos.count inserted = self.db.cases_to_annos.insert(self.case1, self.wrong_anno) - count_after = self.db.cases_to_annos.count() + count_after = self.db.cases_to_annos.count self.assertFalse(inserted) # Should fail due to foreign key constraint self.assertEqual(count_before, count_after) + + def test_count(self): + count_before = self.db.cases_to_annos.count + self.db.cases_to_annos.insert(self.case1, self.anno1) + self.db.cases_to_annos.insert(self.case2, self.anno2) + count_after = self.db.cases_to_annos.count + self.assertEqual(count_after, count_before + 2) diff --git a/tests/test_db/test_tables/test_requirements.py b/tests/test_db/test_tables/test_requirements.py index adb372d..125e9ce 100644 --- a/tests/test_db/test_tables/test_requirements.py +++ b/tests/test_db/test_tables/test_requirements.py @@ -60,3 +60,9 @@ def test_insert_long_embedding(self): long_embedding = [0.1] * (self.db.requirements.embedding_size + 1) id1 = self.db.requirements.insert("Test Requirement 7", long_embedding) self.assertIsNone(id1) + + def test_count(self): + count_before = self.db.requirements.count + self.db.requirements.insert("Test Requirement 8") + count_after = self.db.requirements.count + self.assertEqual(count_after, count_before + 1) diff --git a/tests/test_db/test_tables/test_test_cases.py b/tests/test_db/test_tables/test_test_cases.py index a1bd922..9009dfa 100644 --- a/tests/test_db/test_tables/test_test_cases.py +++ b/tests/test_db/test_tables/test_test_cases.py @@ -61,3 +61,9 @@ def test_insert_long_embedding(self): "Test Script 14", "Test Case 14", long_embedding ) self.assertIsNone(id1) + + def test_count(self): + count_before = self.db.test_cases.count + self.db.test_cases.insert("Test Script 15", "Test Case 15") + count_after = self.db.test_cases.count + self.assertEqual(count_after, count_before + 1)