diff --git a/src/llm_benchmark/algorithms/primes.py b/src/llm_benchmark/algorithms/primes.py index 9c56daa..5a5878a 100644 --- a/src/llm_benchmark/algorithms/primes.py +++ b/src/llm_benchmark/algorithms/primes.py @@ -1,6 +1,7 @@ from typing import List + class Primes: @staticmethod def is_prime(n: int) -> bool: @@ -14,7 +15,7 @@ def is_prime(n: int) -> bool: """ if n < 2: return False - for i in range(2, n): + for i in range(2, int(n ** 0.5) + 1): if n % i == 0: return False return True @@ -30,7 +31,7 @@ def sum_primes(n: int) -> int: int: Sum of primes from 0 to n """ sum_ = 0 - for i in range(n): + for i in range(2, n): if Primes.is_prime(i): sum_ += i return sum_ @@ -46,10 +47,13 @@ def prime_factors(n: int) -> List[int]: List[int]: List of prime factors """ ret = [] - while n > 1: - for i in range(2, n + 1): - if n % i == 0: - ret.append(i) - n = n // i - break - return ret + i = 2 + while i * i <= n: + if n % i: + i += 1 + else: + n //= i + ret.append(i) + if n > 1: + ret.append(n) + return ret \ No newline at end of file diff --git a/src/llm_benchmark/algorithms/sort.py b/src/llm_benchmark/algorithms/sort.py index 5f36289..d8e8c78 100644 --- a/src/llm_benchmark/algorithms/sort.py +++ b/src/llm_benchmark/algorithms/sort.py @@ -2,41 +2,49 @@ from typing import List +from sys import maxsize +from typing import List + + class Sort: @staticmethod def sort_list(v: List[int]) -> None: - """Sort a list of integers in place + """ + Sort a list of integers in place Args: v (List[int]): List of integers """ - for i in range(len(v)): - for j in range(i + 1, len(v)): - if v[i] > v[j]: - v[i], v[j] = v[j], v[i] + v.sort() @staticmethod def dutch_flag_partition(v: List[int], pivot_value: int) -> None: - """Dutch flag partitioning + """ + Dutch flag partitioning Args: v (List[int]): List of integers pivot_value (int): Pivot value """ - next_value = 0 - - for i in range(len(v)): - if v[i] < pivot_value: - v[i], v[next_value] = v[next_value], v[i] - next_value += 1 - for i in range(next_value, len(v)): - if v[i] == pivot_value: - v[i], v[next_value] = v[next_value], v[i] - next_value += 1 + low = 0 + mid = 0 + high = len(v) - 1 + + while mid <= high: + if v[mid] < pivot_value: + v[low], v[mid] = v[mid], v[low] + low += 1 + mid += 1 + elif v[mid] == pivot_value: + mid += 1 + else: + v[mid], v[high] = v[high], v[mid] + high -= 1 @staticmethod def max_n(v: List[int], n: int) -> List[int]: - """Find the maximum n numbers in a list + """ + Find the maximum n numbers in a list Args: v (List[int]): List of integers @@ -45,15 +53,5 @@ def max_n(v: List[int], n: int) -> List[int]: Returns: List[int]: List of maximum n values """ - tmp = v.copy() - ret = [-maxsize - 1] * n - for i in range(n): - max_val = tmp[0] - max_idx = 0 - for j in range(1, len(tmp)): - if tmp[j] > max_val: - max_val = tmp[j] - max_idx = j - ret[i] = max_val - tmp.pop(max_idx) - return ret + v.sort() + return v[-n:] \ No newline at end of file diff --git a/src/llm_benchmark/control/double.py b/src/llm_benchmark/control/double.py index 4be41d7..5d0f744 100644 --- a/src/llm_benchmark/control/double.py +++ b/src/llm_benchmark/control/double.py @@ -1,43 +1,35 @@ +import collections from typing import List class DoubleForLoop: @staticmethod def sum_square(n: int) -> int: - """Sum of squares of numbers from 0 to n (exclusive) + '''Sum of squares of numbers from 0 to n (exclusive) Args: n (int): Number to sum up to Returns: int: Sum of squares of numbers from 0 to n - """ - sum_ = 0 - for i in range(n): - for j in range(n): - if i == j: - sum_ += i * j - return sum_ + ''' + return sum(i * i for i in range(n)) @staticmethod def sum_triangle(n: int) -> int: - """Sum of triangle of numbers from 0 to n (exclusive) + '''Sum of triangle of numbers from 0 to n (exclusive) Args: n (int): Number to sum up to Returns: int: Sum of triangle of numbers from 0 to n - """ - sum_ = 0 - for i in range(n): - for j in range(i + 1): - sum_ += j - return sum_ + ''' + return sum(sum(range(i + 1)) for i in range(n)) @staticmethod - def count_pairs(arr: List[int]) -> int: - """Count pairs of numbers in an array + def count_pairs(arr: list[int]) -> int: + '''Count pairs of numbers in an array A pair is defined as exactly two numbers in the array that are equal. @@ -46,21 +38,13 @@ def count_pairs(arr: List[int]) -> int: Returns: int: Number of pairs in the array - """ - count = 0 - for i in range(len(arr)): - ndup = 0 - for j in range(len(arr)): - if arr[i] == arr[j]: - ndup += 1 - if ndup == 2: - count += 1 - - return count // 2 + ''' + from collections import Counter + return sum(count // 2 for count in Counter(arr).values() if count == 2) @staticmethod - def count_duplicates(arr0: List[int], arr1: List[int]) -> int: - """Count duplicates between two arrays + def count_duplicates(arr0: list[int], arr1: list[int]) -> int: + '''Count duplicates between two arrays Args: arr0 (List[int]): Array of integers @@ -68,26 +52,17 @@ def count_duplicates(arr0: List[int], arr1: List[int]) -> int: Returns: int: Number of duplicates between the two arrays - """ - count = 0 - for i in range(len(arr0)): - for j in range(len(arr1)): - if i == j and arr0[i] == arr1[j]: - count += 1 - return count + ''' + return sum(a == b for a, b in zip(arr0, arr1)) @staticmethod - def sum_matrix(m: List[List[int]]) -> int: - """Sum of matrix of integers + def sum_matrix(m: list[list[int]]) -> int: + '''Sum of matrix of integers Args: m (List[List[int]]): Matrix of integers Returns: int: Sum of matrix of integers - """ - sum_ = 0 - for i in range(len(m)): - for j in range(len(m[i])): - sum_ += m[i][j] - return sum_ + ''' + return sum(sum(row) for row in m) \ No newline at end of file diff --git a/src/llm_benchmark/control/single.py b/src/llm_benchmark/control/single.py index 9a314e6..dff7376 100644 --- a/src/llm_benchmark/control/single.py +++ b/src/llm_benchmark/control/single.py @@ -12,10 +12,7 @@ def sum_range(n: int) -> int: Returns: int: Sum of range of numbers from 0 to n """ - arr = [] - for i in range(n): - arr.append(i) - return sum(arr) + return n * (n - 1) // 2 @staticmethod def max_list(v: List[int]) -> int: @@ -27,11 +24,7 @@ def max_list(v: List[int]) -> int: Returns: int: Maximum value in the vector """ - max_val = v[0] - for i in range(1, len(v)): - if v[i] > max_val: - max_val = v[i] - return max_val + return max(v) @staticmethod def sum_modulus(n: int, m: int) -> int: @@ -44,8 +37,4 @@ def sum_modulus(n: int, m: int) -> int: Returns: int: Sum of modulus of numbers from 0 to n """ - arr = [] - for i in range(n): - if i % m == 0: - arr.append(i) - return sum(arr) + return sum(i for i in range(n) if i % m == 0) \ No newline at end of file diff --git a/src/llm_benchmark/datastructures/dslist.py b/src/llm_benchmark/datastructures/dslist.py index d282a9c..eeda557 100644 --- a/src/llm_benchmark/datastructures/dslist.py +++ b/src/llm_benchmark/datastructures/dslist.py @@ -12,10 +12,7 @@ def modify_list(v: List[int]) -> List[int]: Returns: List[int]: Modified list of integers """ - ret = [] - for i in range(len(v)): - ret.append(v[i] + 1) - return ret + return [x + 1 for x in v] @staticmethod def search_list(v: List[int], n: int) -> List[int]: @@ -29,11 +26,7 @@ def search_list(v: List[int], n: int) -> List[int]: Returns: List[int]: List of indices where the value is found """ - ret = [] - for i in range(len(v)): - if v[i] == n: - ret.append(i) - return ret + return [i for i, x in enumerate(v) if x == n] @staticmethod def sort_list(v: List[int]) -> List[int]: @@ -45,13 +38,7 @@ def sort_list(v: List[int]) -> List[int]: Returns: List[int]: Sorted list of integers """ - ret = v.copy() - for i in range(len(ret)): - for j in range(i + 1, len(ret)): - if ret[i] > ret[j]: - ret[i], ret[j] = ret[j], ret[i] - - return ret + return sorted(v) @staticmethod def reverse_list(v: List[int]) -> List[int]: @@ -63,10 +50,7 @@ def reverse_list(v: List[int]) -> List[int]: Returns: List[int]: Reversed list of integers """ - ret = [] - for i in range(len(v)): - ret.append(v[len(v) - 1 - i]) - return ret + return v[::-1] @staticmethod def rotate_list(v: List[int], n: int) -> List[int]: @@ -79,12 +63,8 @@ def rotate_list(v: List[int], n: int) -> List[int]: Returns: List[int]: Rotated list of integers """ - ret = [] - for i in range(n, len(v)): - ret.append(v[i]) - for i in range(n): - ret.append(v[i]) - return ret + n = n % len(v) + return v[n:] + v[:n] @staticmethod def merge_lists(v1: List[int], v2: List[int]) -> List[int]: @@ -97,9 +77,4 @@ def merge_lists(v1: List[int], v2: List[int]) -> List[int]: Returns: List[int]: Merged list of integers """ - ret = [] - for i in range(len(v1)): - ret.append(v1[i]) - for i in range(len(v2)): - ret.append(v2[i]) - return ret + return v1 + v2 \ No newline at end of file diff --git a/src/llm_benchmark/sql/query.py b/src/llm_benchmark/sql/query.py index 53f6885..d63c515 100644 --- a/src/llm_benchmark/sql/query.py +++ b/src/llm_benchmark/sql/query.py @@ -2,7 +2,16 @@ from textwrap import dedent + class SqlQuery: + _conn = None + + @classmethod + def get_connection(cls): + if cls._conn is None: + cls._conn = sqlite3.connect("data/chinook.db") + return cls._conn + @staticmethod def query_album(name: str) -> bool: """Check if an album exists @@ -13,11 +22,11 @@ def query_album(name: str) -> bool: Returns: bool: True if the album exists, False otherwise """ - conn = sqlite3.connect("data/chinook.db") + conn = SqlQuery.get_connection() cur = conn.cursor() - cur.execute(f"SELECT * FROM Album WHERE Title = '{name}'") - return len(cur.fetchall()) > 0 + cur.execute("SELECT 1 FROM Album WHERE Title = ?", (name,)) + return cur.fetchone() is not None @staticmethod def join_albums() -> list: @@ -26,26 +35,20 @@ def join_albums() -> list: Returns: list: """ - conn = sqlite3.connect("data/chinook.db") + conn = SqlQuery.get_connection() cur = conn.cursor() cur.execute( dedent( """\ SELECT - t.Name AS TrackName, ( - SELECT a2.Title - FROM Album a2 - WHERE a2.AlbumId = t.AlbumId - ) AS AlbumName, - ( - SELECT ar.Name - FROM Artist ar - JOIN Album a3 ON a3.ArtistId = ar.ArtistId - WHERE a3.AlbumId = t.AlbumId - ) AS ArtistName + t.Name AS TrackName, + a.Title AS AlbumName, + ar.Name AS ArtistName FROM Track t + JOIN Album a ON a.AlbumId = t.AlbumId + JOIN Artist ar ON ar.ArtistId = a.ArtistId """ ) ) @@ -58,7 +61,7 @@ def top_invoices() -> list: Returns: list: List of tuples """ - conn = sqlite3.connect("data/chinook.db") + conn = SqlQuery.get_connection() cur = conn.cursor() cur.execute( @@ -70,9 +73,10 @@ def top_invoices() -> list: i.Total FROM Invoice i - JOIN Customer c ON c.CustomerId = i.CustomerId + JOIN Customer c ON c.CustomerId = i.CustomerId = i.CustomerId ORDER BY i.Total DESC + LIMIT 10 """ ) ) - return cur.fetchall()[:10] + return cur.fetchall() \ No newline at end of file