Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions test2text/pages/controls/controls_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ def controls_page():

def refresh_counts():
with get_db_client() as db:
st.session_state["all_annotations_count"] = db.count_all_entries(
"Annotations"
)
st.session_state["all_annotations_count"] = db.annotations.count
st.session_state["embedded_annotations_count"] = (
db.count_notnull_entries("embedding", from_table="Annotations")
)
Expand Down
58 changes: 53 additions & 5 deletions test2text/services/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def __enter__(self):
return self

def get_table_names(self):
def get_table_names(self) -> list[str]:
"""
Returns a list of all user-defined tables in the database.

Expand All @@ -87,7 +87,13 @@ def get_table_names(self):
cursor.close()
return tables

def get_column_values(self, *columns: str, from_table: str):
def get_column_values(self, *columns: str, from_table: str) -> list[tuple]:
"""
Returns the values of the specified columns from the specified table.
:param columns: list of column names
:param from_table: name of the table
:return: list of tuples containing the values of the specified columns
"""
cursor = self.conn.execute(f"SELECT {', '.join(columns)} FROM {from_table}")
return cursor.fetchall()

Expand Down Expand Up @@ -116,6 +122,11 @@ def count_all_entries(self, from_table: str) -> int:
return count

def count_notnull_entries(self, *columns: str, from_table: str) -> int:
"""
Count the number of non-null entries in the specified columns of the specified table.
:param columns: list of column names
:param from_table: name of the table
"""
count = self.conn.execute(
f"SELECT COUNT(*) FROM {from_table} WHERE {' AND '.join([column + ' IS NOT NULL' for column in columns])}"
).fetchone()[0]
Expand All @@ -135,6 +146,9 @@ def has_column(self, column_name: str, table_name: str) -> bool:
return column_name in columns

def get_null_entries(self, from_table: str) -> list:
"""
Returns values (id and summary) witch has null values in its embedding column.
"""
cursor = self.conn.execute(
f"SELECT id, summary FROM {from_table} WHERE embedding IS NULL"
)
Expand Down Expand Up @@ -174,8 +188,8 @@ def join_all_tables_by_requirements(
self, where_clauses="", params=None
) -> list[tuple]:
"""
Join all tables related to requirements based on the provided where clauses and parameters.
return a list of tuples containing :
Extract values from requirements with related annotations and their test cases based on the provided where clauses and parameters.
Return a list of tuples containing :
req_id,
req_external_id,
req_summary,
Expand Down Expand Up @@ -222,6 +236,14 @@ def join_all_tables_by_requirements(
def get_ordered_values_from_requirements(
self, distance_sql="", where_clauses="", distance_order_sql="", params=None
) -> list[tuple]:
"""
Extracted values from Requirements table based on the provided where clauses and specified parameters ordered by distance and id.
Return a list of tuples containing :
req_id,
req_external_id,
req_summary,
distance between annotation and requirement embeddings,
"""
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
sql = f"""
SELECT
Expand All @@ -241,6 +263,14 @@ def get_ordered_values_from_requirements(
def get_ordered_values_from_test_cases(
self, distance_sql="", where_clauses="", distance_order_sql="", params=None
) -> list[tuple]:
"""
Extracted values from TestCases table based on the provided where clauses and specified parameters ordered by distance and id.
Return a list of tuples containing :
case_id,
test_script,
test_case,
distance between test case and typed by user text embeddings if it is specified,
"""
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
sql = f"""
SELECT
Expand All @@ -260,6 +290,21 @@ def get_ordered_values_from_test_cases(
def join_all_tables_by_test_cases(
self, where_clauses="", params=None
) -> list[tuple]:
"""
Join all tables related to test cases based on the provided where clauses and specified parameters.
Return a list of tuples containing :
case_id,
test_script,
test_case,
anno_id,
anno_summary,
anno_embedding,
distance between annotation and requirement embeddings,
req_id,
req_external_id,
req_summary,
req_embedding
"""
where_sql = ""
if where_clauses:
where_sql = f"WHERE {' AND '.join(where_clauses)}"
Expand Down Expand Up @@ -294,7 +339,10 @@ def join_all_tables_by_test_cases(
data = self.conn.execute(sql, params)
return data.fetchall()

def get_embeddings_by_id(self, id1: int, from_table: str):
def get_embeddings_by_id(self, id1: int, from_table: str) -> float:
"""
Returns the embedding of the specified id from the specified table.
"""
cursor = self.conn.execute(
f"SELECT embedding FROM {from_table} WHERE id = ?", (id1,)
)
Expand Down
4 changes: 4 additions & 0 deletions test2text/services/db/streamlit_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@


def get_db_client() -> DbClient:
"""
Returns a DbClient instance connected to the database where requirements, annotations, test cases and their relations are stored.
:return: DbClient instance
"""
from test2text.services.utils import res_folder

return DbClient(res_folder.get_file_path("db.sqlite3"))
26 changes: 24 additions & 2 deletions test2text/services/db/tables/annos_to_reqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@


class AnnotationsToRequirementsTable(AbstractTable):
def init_table(self):
"""
This class represents the relationship between annotations and requirements in the database by closest distance between them.
"""

def init_table(self) -> None:
"""
Creates the AnnotationsToRequirements table in the database if it does not already exist.
"""
self.connection.execute("""
CREATE TABLE IF NOT EXISTS AnnotationsToRequirements (
annotation_id INTEGER NOT NULL,
Expand All @@ -15,7 +22,10 @@ def init_table(self):
)
""")

def recreate_table(self):
def recreate_table(self) -> None:
"""
Drops the AnnotationsToRequirements table if it exists and recreates it.
"""
self.connection.execute("""
DROP TABLE IF EXISTS AnnotationsToRequirements
""")
Expand All @@ -24,6 +34,13 @@ def recreate_table(self):
def insert(
self, annotation_id: int, requirement_id: int, cached_distance: float
) -> bool:
"""
Inserts a new entry into the AnnotationsToRequirements table.
:param annotation_id: The ID of the annotation
:param requirement_id: The ID of the requirement
:param cached_distance: The cached distance between the annotation and the requirement
:return: True if the insertion was successful, False otherwise.
"""
try:
cursor = self.connection.execute(
"""
Expand All @@ -42,7 +59,12 @@ def insert(
pass
return False

@property
def count(self) -> int:
"""
Returns the number of entries in the AnnotationsToRequirements table.
:return: int - the number of entries in the table.
"""
cursor = self.connection.execute(
"SELECT COUNT(*) FROM AnnotationsToRequirements"
)
Expand Down
33 changes: 33 additions & 0 deletions test2text/services/db/tables/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@


class AnnotationsTable(AbstractTable):
"""
This class represents the annotations of test cases in the database.
"""

def __init__(self, connection: Connection, embedding_size: int):
super().__init__(connection)
self.embedding_size = embedding_size

def init_table(self):
"""
Creates the Annotations table in the database if it does not already exist.
"""
self.connection.execute(
Template("""
CREATE TABLE IF NOT EXISTS Annotations (
Expand All @@ -29,6 +36,12 @@ def init_table(self):
)

def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]:
"""
Inserts a new annotation into the database. If the annotation already exists, it updates the existing record.
:param summary: The summary of the annotation
:param embedding: The embedding of the annotation (optional)
:return: The ID of the inserted or updated annotation, or None if the annotation already exists and was updated.
"""
cursor = self.connection.execute(
"""
INSERT OR IGNORE INTO Annotations (summary, embedding)
Expand All @@ -45,6 +58,12 @@ def insert(self, summary: str, embedding: list[float] = None) -> Optional[int]:
return None

def get_or_insert(self, summary: str, embedding: list[float] = None) -> int:
"""
Inserts a new annotation into the database if it does not already exist, otherwise returns the existing annotation's ID.
:param summary: The summary of the annotation
:param embedding: The embedding of the annotation (optional)
:return: The ID of the inserted or existing annotation.
"""
inserted_id = self.insert(summary, embedding)
if inserted_id is not None:
return inserted_id
Expand All @@ -61,6 +80,11 @@ def get_or_insert(self, summary: str, embedding: list[float] = None) -> int:
return result[0]

def set_embedding(self, anno_id: int, embedding: list[float]) -> None:
"""
Sets the embedding for a given annotation ID.
:param anno_id: The ID of the annotation
:param embedding: The new embedding for the annotation
"""
if len(embedding) != self.embedding_size:
raise ValueError(
f"Embedding size must be {self.embedding_size}, got {len(embedding)}"
Expand All @@ -74,3 +98,12 @@ def set_embedding(self, anno_id: int, embedding: list[float]) -> None:
""",
(serialized_embedding, anno_id),
)

@property
def count(self) -> int:
"""
Returns the number of entries in the Annotations table.
:return: int - the number of entries in the table.
"""
cursor = self.connection.execute("SELECT COUNT(*) FROM Annotations")
return cursor.fetchone()[0]
11 changes: 10 additions & 1 deletion test2text/services/db/tables/cases_to_annos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@


class TestCasesToAnnotationsTable(AbstractTable):
def init_table(self):
"""
This class represents the relationship between test cases and annotations in the database.
"""

def init_table(self) -> None:
self.connection.execute("""
CREATE TABLE IF NOT EXISTS CasesToAnnos (
case_id INTEGER NOT NULL,
Expand Down Expand Up @@ -40,6 +44,11 @@ def insert(self, case_id: int, annotation_id: int) -> bool:
pass
return False

@property
def count(self) -> int:
"""
Returns the number of entries in the CasesToAnnos table.
:return: int - the number of entries in the table.
"""
cursor = self.connection.execute("SELECT COUNT(*) FROM CasesToAnnos")
return cursor.fetchone()[0]
25 changes: 24 additions & 1 deletion test2text/services/db/tables/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@


class RequirementsTable(AbstractTable):
"""
This class represents the requirements for test cases in the database.
"""

def __init__(self, connection: Connection, embedding_size: int):
super().__init__(connection)
self.embedding_size = embedding_size

def init_table(self):
def init_table(self) -> None:
"""
Creates the Requirements table in the database if it does not already exist.
"""
self.connection.execute(
Template("""
CREATE TABLE IF NOT EXISTS Requirements (
Expand All @@ -32,6 +39,13 @@ def init_table(self):
def insert(
self, summary: str, embedding: list[float] = None, external_id: str = None
) -> Optional[int]:
"""
Inserts a new requirement into the database. If the requirement already exists, it updates the existing record.
:param summary: The summary of the requirement
:param embedding: The embedding of the requirement (optional)
:param external_id: The external ID of the requirement (optional)
:return: The ID of the inserted or updated requirement, or None if the requirement already exists and was updated.
"""
cursor = self.connection.execute(
"""
INSERT OR IGNORE INTO Requirements (summary, embedding, external_id)
Expand All @@ -49,3 +63,12 @@ def insert(
return result[0]
else:
return None

@property
def count(self) -> int:
"""
Returns the number of entries in the Requirements table.
:return: int - the number of entries in the table.
"""
cursor = self.connection.execute("SELECT COUNT(*) FROM Requirements")
return cursor.fetchone()[0]
25 changes: 25 additions & 0 deletions test2text/services/db/tables/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def __init__(self, connection: Connection, embedding_size: int):
self.embedding_size = embedding_size

def init_table(self):
"""
Creates the TestCases table in the database if it does not already exist.
"""
self.connection.execute(
Template("""

Expand All @@ -34,6 +37,13 @@ def init_table(self):
def insert(
self, test_script: str, test_case: str, embedding: list[float] = None
) -> Optional[int]:
"""
Inserts a new test case into the database. If the test case already exists, it updates the existing record.
:param test_script: The test script of the test case
:param test_case: The test case of the test case
:param embedding: The embedding of the test case (optional)
:return: The ID of the inserted or updated test case, or None if the test case already exists and was updated.
"""
cursor = self.connection.execute(
"""
INSERT OR IGNORE INTO TestCases (test_script, test_case, embedding)
Expand All @@ -54,6 +64,12 @@ def insert(
return None

def get_or_insert(self, test_script: str, test_case: str) -> int:
"""
Inserts a new test case into the database if it does not already exist, otherwise returns the existing test case's ID.
:param test_script: The test script of the test case
:param test_case: The test case of the test case
:return: The ID of the inserted or existing test case.
"""
inserted_id = self.insert(test_script, test_case)
if inserted_id is not None:
return inserted_id
Expand All @@ -68,3 +84,12 @@ def get_or_insert(self, test_script: str, test_case: str) -> int:
result = cursor.fetchone()
cursor.close()
return result[0]

@property
def count(self) -> int:
"""
Returns the number of entries in the TestCases table.
:return: int - the number of entries in the table.
"""
cursor = self.connection.execute("SELECT COUNT(*) FROM TestCases")
return cursor.fetchone()[0]
2 changes: 1 addition & 1 deletion test2text/services/loaders/index_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def write_batch():
write_batch()
write_batch()
# Check requirements
return db.count_all_entries(from_table="Requirements")
return db.requirements.count
Loading