diff --git a/README.md b/README.md index 9b236a83..f50450aa 100644 --- a/README.md +++ b/README.md @@ -123,5 +123,12 @@ easily. Then you can `import tower` and you're off to the races! uv run python ``` +To run tests: + +```bash +uv sync --locked --all-extras --dev +uv run pytest tests +``` + If you need to get the latest OpenAPI SDK, you can run `./scripts/generate-python-api-client.sh`. diff --git a/pyproject.toml b/pyproject.toml index 71efe0ab..7c7e3531 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,6 @@ build-backend = "maturin" [project] name = "tower" version = "0.3.13" - - - - - - - - description = "Tower CLI and runtime environment for Tower." authors = [{ name = "Tower Computing Inc.", email = "brad@tower.dev" }] readme = "README.md" @@ -68,8 +60,8 @@ tower = { workspace = true } [dependency-groups] dev = [ - "openapi-python-client>=0.12.1", + "openapi-python-client>=0.12.1", "pytest>=8.3.5", "pytest-httpx>=0.35.0", - "pyiceberg[sql-sqlite]>=0.9.0", + "pyiceberg[sql-sqlite]>=0.9.0", ] diff --git a/src/tower/utils/pyarrow.py b/src/tower/utils/pyarrow.py index a0ad414c..9b8c1111 100644 --- a/src/tower/utils/pyarrow.py +++ b/src/tower/utils/pyarrow.py @@ -3,93 +3,202 @@ import pyarrow as pa import pyarrow.compute as pc -import pyiceberg.types as types +from pyiceberg import types as iceberg_types from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.expressions import ( BooleanExpression, - And, Or, Not, - EqualTo, NotEqualTo, - GreaterThan, GreaterThanOrEqual, - LessThan, LessThanOrEqual, - Literal, Reference + And, + Or, + Not, + EqualTo, + NotEqualTo, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual, + Reference, ) -def arrow_to_iceberg_type(arrow_type): + +class FieldIdManager: + """ + Manages the assignment of unique field IDs. + Field IDs in Iceberg start from 1. + """ + + def __init__(self, start_id=1): + # Initialize current_id to start_id - 1 so the first call to get_next_id() returns start_id + self.current_id = start_id - 1 + + def get_next_id(self) -> int: + """Returns the next available unique field ID.""" + self.current_id += 1 + return self.current_id + + +def arrow_to_iceberg_type_recursive( + arrow_type: pa.DataType, field_id_manager: FieldIdManager +) -> iceberg_types.IcebergType: """ - Convert a PyArrow type to a PyIceberg type. Special thanks to Claude for - the help on this. + Recursively convert a PyArrow DataType to a PyIceberg type, + managing field IDs for nested structures. 
""" - - if pa.types.is_boolean(arrow_type): - return types.BooleanType() + # Primitive type mappings (most remain the same) + if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + return iceberg_types.StringType() elif pa.types.is_integer(arrow_type): - # Check the bit width to determine the appropriate Iceberg integer type - bit_width = arrow_type.bit_width - if bit_width <= 32: - return types.IntegerType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.IntegerType() else: - return types.LongType() + return iceberg_types.LongType() elif pa.types.is_floating(arrow_type): - if arrow_type.bit_width == 32: - return types.FloatType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.FloatType() else: - return types.DoubleType() - elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): - return types.StringType() - elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): - return types.BinaryType() + return iceberg_types.DoubleType() + elif pa.types.is_boolean(arrow_type): + return iceberg_types.BooleanType() elif pa.types.is_date(arrow_type): - return types.DateType() - elif pa.types.is_timestamp(arrow_type): - return types.TimestampType() + return iceberg_types.DateType() elif pa.types.is_time(arrow_type): - return types.TimeType() + return iceberg_types.TimeType() + elif pa.types.is_timestamp(arrow_type): + if arrow_type.tz is not None: # type: ignore + return iceberg_types.TimestamptzType() + else: + return iceberg_types.TimestampType() + elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): + return iceberg_types.BinaryType() + elif pa.types.is_fixed_size_binary(arrow_type): + return iceberg_types.FixedType(length=arrow_type.byte_width) # type: ignore elif pa.types.is_decimal(arrow_type): - precision = arrow_type.precision - scale = arrow_type.scale - return types.DecimalType(precision, scale) - elif pa.types.is_list(arrow_type): - element_type = arrow_to_iceberg_type(arrow_type.value_type) - return types.ListType(element_type) + return iceberg_types.DecimalType(arrow_type.precision, arrow_type.scale) # type: ignore + + # Nested type mappings + elif ( + pa.types.is_list(arrow_type) + or pa.types.is_large_list(arrow_type) + or pa.types.is_fixed_size_list(arrow_type) + ): + # The element field itself in Iceberg needs an ID. + element_id = field_id_manager.get_next_id() + + # Recursively convert the list's element type. + # arrow_type.value_type is the DataType of the elements. + # arrow_type.value_field is the Field of the elements (contains name, type, nullability). + element_pyarrow_type = arrow_type.value_type # type: ignore + element_iceberg_type = arrow_to_iceberg_type_recursive( + element_pyarrow_type, field_id_manager + ) + + # Determine if the elements themselves are required (not nullable). + element_is_required = not arrow_type.value_field.nullable # type: ignore + + return iceberg_types.ListType( + element_id=element_id, + element_type=element_iceberg_type, + element_required=element_is_required, + ) elif pa.types.is_struct(arrow_type): - fields = [] - for i, field in enumerate(arrow_type): - name = field.name - field_type = arrow_to_iceberg_type(field.type) - fields.append(types.NestedField(i + 1, name, field_type, required=not field.nullable)) - return types.StructType(*fields) + struct_iceberg_fields = [] + # arrow_type is a StructType. Iterate through its fields. 
+ for i in range(arrow_type.num_fields): # type: ignore + pyarrow_child_field = arrow_type.field(i) # This is a pyarrow.Field + + # Each field within the struct needs its own unique ID. + nested_field_id = field_id_manager.get_next_id() + nested_iceberg_type = arrow_to_iceberg_type_recursive( + pyarrow_child_field.type, field_id_manager + ) + + doc = None + if pyarrow_child_field.metadata and b"doc" in pyarrow_child_field.metadata: + doc = pyarrow_child_field.metadata[b"doc"].decode("utf-8") + + struct_iceberg_fields.append( + iceberg_types.NestedField( + field_id=nested_field_id, + name=pyarrow_child_field.name, + field_type=nested_iceberg_type, + required=not pyarrow_child_field.nullable, + doc=doc, + ) + ) + return iceberg_types.StructType(*struct_iceberg_fields) elif pa.types.is_map(arrow_type): - key_type = arrow_to_iceberg_type(arrow_type.key_type) - value_type = arrow_to_iceberg_type(arrow_type.item_type) - return types.MapType(key_type, value_type) + # Iceberg MapType requires IDs for key and value fields. + key_id = field_id_manager.get_next_id() + value_id = field_id_manager.get_next_id() + + key_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.key_type, field_id_manager + ) # type: ignore + value_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.item_type, field_id_manager + ) # type: ignore + + # PyArrow map keys are always non-nullable by Arrow specification. + # Nullability of map values comes from the item_field. + value_is_required = not arrow_type.item_field.nullable # type: ignore + + return iceberg_types.MapType( + key_id=key_id, + key_type=key_iceberg_type, + value_id=value_id, + value_type=value_iceberg_type, + value_required=value_is_required, + ) else: raise ValueError(f"Unsupported Arrow type: {arrow_type}") -def convert_pyarrow_field(num, field) -> types.NestedField: - name = field.name - field_type = arrow_to_iceberg_type(field.type) - field_id = num + 1 # Iceberg requires field IDs +def convert_pyarrow_schema( + arrow_schema: pa.Schema, schema_id: int = 1, start_field_id: int = 1 +) -> IcebergSchema: + """ + Convert a PyArrow schema to a PyIceberg schema. + + Args: + arrow_schema: The input PyArrow.Schema. + schema_id: The schema ID for the Iceberg schema. + start_field_id: The starting ID for field ID assignment. + Returns: + An IcebergSchema object. + """ + field_id_manager = FieldIdManager(start_id=start_field_id) + iceberg_fields = [] + + for pyarrow_field in arrow_schema: # pyarrow_field is a pa.Field object + # Assign a unique ID for this top-level field. + top_level_field_id = field_id_manager.get_next_id() - return types.NestedField( - field_id, - name, - field_type, - required=not field.nullable - ) + # Recursively convert the field's type. This will handle ID assignment + # for any nested structures using the same field_id_manager. 
+ iceberg_field_type = arrow_to_iceberg_type_recursive( + pyarrow_field.type, field_id_manager + ) + doc = None + if pyarrow_field.metadata and b"doc" in pyarrow_field.metadata: + doc = pyarrow_field.metadata[b"doc"].decode("utf-8") -def convert_pyarrow_schema(arrow_schema: pa.Schema) -> IcebergSchema: - """Convert a PyArrow schema to a PyIceberg schema.""" - fields = [convert_pyarrow_field(i, field) for i, field in enumerate(arrow_schema)] - return IcebergSchema(*fields) + iceberg_fields.append( + iceberg_types.NestedField( + field_id=top_level_field_id, + name=pyarrow_field.name, + field_type=iceberg_field_type, + required=not pyarrow_field.nullable, # Top-level field nullability + doc=doc, + ) + ) + return IcebergSchema(*iceberg_fields, schema_id=schema_id) def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: """Extract field name and literal value from a comparison expression.""" # First, convert the expression to a string and parse it expr_str = str(expr) - + # PyArrow expression strings look like: "(field_name == literal)" or similar # Need to determine the operator and then split accordingly operators = ["==", "!=", ">", ">=", "<", "<="] @@ -98,36 +207,38 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: if op in expr_str: op_used = op break - + if not op_used: - raise ValueError(f"Could not find comparison operator in expression: {expr_str}") - + raise ValueError( + f"Could not find comparison operator in expression: {expr_str}" + ) + # Remove parentheses and split by operator expr_clean = expr_str.strip("()") parts = expr_clean.split(op_used) if len(parts) != 2: raise ValueError(f"Expected binary comparison in expression: {expr_str}") - + # Determine which part is the field and which is the literal field_name = None literal_value = None - + # Clean up the parts left = parts[0].strip() right = parts[1].strip() - + # Typically field name doesn't have quotes, literals (strings) do if left.startswith('"') or left.startswith("'"): # Right side is the field field_name = right # Extract the literal value - this is a simplification - literal_value = left.strip('"\'') + literal_value = left.strip("\"'") else: # Left side is the field field_name = left # Extract the literal value - this is a simplification - literal_value = right.strip('"\'') - + literal_value = right.strip("\"'") + # Try to convert numeric literals try: if "." 
in literal_value: @@ -137,17 +248,18 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: except ValueError: # Keep as string if not numeric pass - + return field_name, literal_value + def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpression]: """Convert a PyArrow compute expression to a PyIceberg boolean expression.""" if expr is None: return None - + # Handle the expression based on its string representation expr_str = str(expr) - + # Handle logical operations if "and" in expr_str.lower() and isinstance(expr, pc.Expression): # This is a simplification - in real code, you'd need to parse the expression @@ -156,7 +268,7 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return And( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "or" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification @@ -164,13 +276,13 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return Or( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "not" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification inner_expr = None # You'd need to extract this return Not(convert_pyarrow_expression(inner_expr)) - + # Handle comparison operations try: if "==" in expr_str: @@ -204,13 +316,13 @@ def convert_pyarrow_expressions(exprs: List[pc.Expression]) -> BooleanExpression """ if not exprs: raise ValueError("No expressions provided") - + if len(exprs) == 1: return convert_pyarrow_expression(exprs[0]) - + # Combine multiple expressions with AND result = convert_pyarrow_expression(exprs[0]) for expr in exprs[1:]: result = And(result, convert_pyarrow_expression(expr)) - + return result diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py index 3ef0b027..bf044b25 100644 --- a/tests/tower/test_tables.py +++ b/tests/tower/test_tables.py @@ -14,13 +14,14 @@ # Imports the library under test import tower + def get_temp_dir(): """Create a temporary directory and return its file:// URL.""" # Create a temporary directory that will be automatically cleaned up temp_dir = tempfile.TemporaryDirectory() abs_path = pathlib.Path(temp_dir.name).absolute() - file_url = urljoin('file:', pathname2url(str(abs_path))) - + file_url = urljoin("file:", pathname2url(str(abs_path))) + # Return both the URL and the path to the temporary directory return file_url, abs_path @@ -38,21 +39,41 @@ def in_memory_catalog(): def test_reading_and_writing_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], 
schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # If we write some data to the table, that should be...OK. table = table.insert(data_with_schema) @@ -66,24 +87,48 @@ def test_reading_and_writing_to_tables(in_memory_catalog): avg_age = df.select(pl.mean("age").alias("mean_age")).collect().item() assert avg_age == 30.0 + def test_upsert_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -91,13 +136,22 @@ def test_upsert_to_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Now we'll update records in the table. - data_with_schema = pa.Table.from_pylist([ - {"id": 2, "username": "bobb", "name": "Bob", "age": 26, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 26, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + ], + schema=schema, + ) # And make sure we can upsert the data. table = table.upsert(data_with_schema, join_cols=["username"]) - assert table.rows_affected().updates == 1 + assert table.rows_affected().updates == 1 # Now let's read from the table and see what we get back out. 
df = table.to_polars() @@ -107,24 +161,48 @@ def test_upsert_to_tables(in_memory_catalog): # The age should match what we updated the relevant record to assert res["age"].item() == 26 + def test_delete_from_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -132,9 +210,7 @@ def test_delete_from_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Perform the underlying delete from the table... - table.delete(filters=[ - table.column("username") == "bobb" - ]) + table.delete(filters=[table.column("username") == "bobb"]) # ...and let's make sure that record is actually gone. df = table.to_polars() @@ -143,14 +219,17 @@ def test_delete_from_tables(in_memory_catalog): all_rows = df.collect() assert all_rows.height == 2 + def test_getting_schemas_for_tables(in_memory_catalog): - original_schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + original_schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) @@ -163,3 +242,317 @@ def test_getting_schemas_for_tables(in_memory_catalog): assert new_schema.field("id") is not None assert new_schema.field("age") is not None assert new_schema.field("created_at") is not None + + +def test_list_of_structs(in_memory_catalog): + """Tests writing and reading a list of non-nullable structs.""" + table_name = "test_list_of_structs_table" + # Define a pyarrow schema with a list of structs + # The list 'tags' can be null, but its elements (structs) are not nullable. + # Inside the struct, 'key' is non-nullable, 'value' is nullable. 
+ item_struct_type = pa.struct( + [ + pa.field("key", pa.string(), nullable=False), # Non-nullable key + pa.field("value", pa.int64(), nullable=True), # Nullable value + ] + ) + # The 'item' field represents the elements of the list. It's non-nullable. + # This means each element in the list must be a valid struct, not a Python None. + pa_schema = pa.schema( + [ + pa.field("doc_id", pa.int32(), nullable=False), + pa.field( + "tags", + pa.list_(pa.field("item", item_struct_type, nullable=False)), + nullable=True, + ), + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + { + "doc_id": 1, + "tags": [{"key": "user", "value": 100}, {"key": "priority", "value": 1}], + }, + { + "doc_id": 2, + "tags": [ + {"key": "source", "value": 200}, + {"key": "reviewed", "value": None}, + ], + }, # Null value for a struct field + {"doc_id": 3, "tags": []}, # Empty list + {"doc_id": 4, "tags": None}, # Null list + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + # Read back and verify + df_read = table.to_polars().collect() # Collect to get a Polars DataFrame + + assert df_read.shape[0] == 4 + assert df_read["doc_id"].to_list() == [1, 2, 3, 4] + + # Verify nested data (Polars handles structs and lists well) + # For doc_id = 1 + tags_doc1 = ( + df_read.filter(pl.col("doc_id") == 1).select("tags").row(0)[0] + ) # Get the list of structs + assert len(tags_doc1) == 2 + assert tags_doc1[0]["key"] == "user" + assert tags_doc1[0]["value"] == 100 + assert tags_doc1[1]["key"] == "priority" + assert tags_doc1[1]["value"] == 1 + + # For doc_id = 2 (with a null inside a struct) + tags_doc2 = df_read.filter(pl.col("doc_id") == 2).select("tags").row(0)[0] + assert len(tags_doc2) == 2 + assert tags_doc2[0]["key"] == "source" + assert tags_doc2[0]["value"] == 200 + assert tags_doc2[1]["key"] == "reviewed" + assert tags_doc2[1]["value"] is None + + # For doc_id = 3 (empty list) + tags_doc3 = df_read.filter(pl.col("doc_id") == 3).select("tags").row(0)[0] + assert len(tags_doc3) == 0 + + # For doc_id = 4 (null list should also be an empty list) + tags_doc4 = df_read.filter(pl.col("doc_id") == 4).select("tags").row(0)[0] + assert tags_doc4 == [] + + +def test_nested_structs(in_memory_catalog): + """Tests writing and reading a table with nested structs.""" + table_name = "test_nested_structs_table" + # Define a pyarrow schema with nested structs + # config: struct> + settings_struct_type = pa.struct( + [ + pa.field("retries", pa.int8(), nullable=False), + pa.field("timeout", pa.int32(), nullable=True), + pa.field("active", pa.bool_(), nullable=False), + ] + ) + pa_schema = pa.schema( + [ + pa.field("record_id", pa.string(), nullable=False), + pa.field( + "config", + pa.struct( + [ + pa.field("name", pa.string(), nullable=True), + pa.field( + "details", settings_struct_type, nullable=True + ), # This inner struct can be null + ] + ), + nullable=True, + ), # The outer 'config' struct can also be null + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + { + "record_id": "rec1", + "config": { + "name": "Default", + "details": {"retries": 3, "timeout": 
1000, "active": True}, + }, + }, + { + "record_id": "rec2", + "config": { + "name": "Fast", + "details": {"retries": 1, "timeout": None, "active": True}, + }, + }, # Null timeout + { + "record_id": "rec3", + "config": {"name": "Inactive", "details": None}, + }, # Null inner struct + {"record_id": "rec4", "config": None}, # Null outer struct + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + # Read back and verify + df_read = table.to_polars().collect() + + assert df_read.shape[0] == 4 + assert df_read["record_id"].to_list() == ["rec1", "rec2", "rec3", "rec4"] + + # Verify nested data for rec1 + config_rec1 = ( + df_read.filter(pl.col("record_id") == "rec1").select("config").row(0)[0] + ) + assert config_rec1["name"] == "Default" + details_rec1 = config_rec1["details"] + assert details_rec1["retries"] == 3 + assert details_rec1["timeout"] == 1000 + assert details_rec1["active"] is True + + # Verify nested data for rec2 (null timeout) + config_rec2 = ( + df_read.filter(pl.col("record_id") == "rec2").select("config").row(0)[0] + ) + assert config_rec2["name"] == "Fast" + details_rec2 = config_rec2["details"] + assert details_rec2["retries"] == 1 + assert details_rec2["timeout"] is None + assert details_rec2["active"] is True + + # Verify nested data for rec3 (null inner struct 'details') + config_rec3 = ( + df_read.filter(pl.col("record_id") == "rec3").select("config").row(0)[0] + ) + assert config_rec3["name"] == "Inactive" + assert config_rec3["details"] is None # The 'details' struct itself is null + + # Verify nested data for rec4 (null outer struct 'config') + config_rec4 = ( + df_read.filter(pl.col("record_id") == "rec4").select("config").row(0)[0] + ) + assert config_rec4 is None # The 'config' struct is null + + +def test_list_of_primitive_types(in_memory_catalog): + """Tests writing and reading a list of primitive types.""" + table_name = "test_list_of_primitives_table" + pa_schema = pa.schema( + [ + pa.field("event_id", pa.int32(), nullable=False), + pa.field( + "scores", + pa.list_(pa.field("score", pa.float32(), nullable=False)), + nullable=True, + ), # List of non-nullable floats + pa.field( + "keywords", + pa.list_(pa.field("keyword", pa.string(), nullable=True)), + nullable=True, + ), # List of nullable strings + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + {"event_id": 1, "scores": [1.0, 2.5, 3.0], "keywords": ["alpha", "beta", None]}, + {"event_id": 2, "scores": [], "keywords": ["gamma"]}, + {"event_id": 3, "scores": None, "keywords": None}, + {"event_id": 4, "scores": [4.2], "keywords": []}, + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + df_read = table.to_polars().collect() + + assert df_read.shape[0] == 4 + + # Event 1 + row1 = df_read.filter(pl.col("event_id") == 1) + assert row1.select("scores").to_series()[0].to_list() == [1.0, 2.5, 3.0] + assert row1.select("keywords").to_series()[0].to_list() == ["alpha", "beta", None] + + # Event 2 + row2 = df_read.filter(pl.col("event_id") == 2) + assert row2.select("scores").to_series()[0].to_list() == [] + assert 
row2.select("keywords").to_series()[0].to_list() == ["gamma"] + + # Event 3 + row3 = df_read.filter(pl.col("event_id") == 3) + assert row3.select("scores").to_series()[0] is None + assert row3.select("keywords").to_series()[0] is None + + +def test_map_type_simple(in_memory_catalog): + """Tests writing and reading a simple map type.""" + table_name = "test_map_type_simple_table" + # Map from string to string. Keys are non-nullable, values can be nullable. + pa_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field( + "properties", + pa.map_(pa.string(), pa.string(), keys_sorted=False), + nullable=True, + ), + # Note: PyArrow map values are nullable by default if item_field is not specified with nullable=False + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + # PyArrow represents maps as a list of structs with 'key' and 'value' fields + data_to_write = [ + {"id": 1, "properties": [("color", "blue"), ("size", "large")]}, + { + "id": 2, + "properties": [("status", "pending"), ("owner", None)], + }, # Null value in map + {"id": 3, "properties": []}, # Empty map + {"id": 4, "properties": None}, # Null map field + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + df_read = table.to_polars().collect() + assert df_read.shape[0] == 4 + + # Verify map data + # Polars represents map as list of structs: struct + # Row 1 + props1_series = df_read.filter(pl.col("id") == 1).select("properties").to_series() + # The series item is already a list of dictionaries + props1_list = props1_series[0] + expected_props1 = [ + {"key": "color", "value": "blue"}, + {"key": "size", "value": "large"}, + ] + # Sort by key for consistent comparison if order is not guaranteed + assert sorted(props1_list, key=lambda x: x["key"]) == sorted( + expected_props1, key=lambda x: x["key"] + ) + + # Row 2 + props2_series = df_read.filter(pl.col("id") == 2).select("properties").to_series() + props2_list = props2_series[0] + expected_props2 = [ + {"key": "status", "value": "pending"}, + {"key": "owner", "value": None}, + ] + assert sorted(props2_list, key=lambda x: x["key"]) == sorted( + expected_props2, key=lambda x: x["key"] + ) + + # Row 3 (empty map) + props3_series = df_read.filter(pl.col("id") == 3).select("properties").to_series() + assert props3_series[0].to_list() == [] + + # Row 4 (null map) + props4_series = df_read.filter(pl.col("id") == 4).select("properties").to_series() + assert props4_series[0] is None
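
For reference, a minimal sketch of how the new `convert_pyarrow_schema` helper from this diff might be exercised on a nested schema. The import path `tower.utils.pyarrow` is an assumption based on the file's location (`src/tower/utils/pyarrow.py`); adjust it to however the package actually exposes the module.

```python
# Hedged usage sketch, not part of the change set.
# Assumes the helper is importable as tower.utils.pyarrow (matches the file
# path src/tower/utils/pyarrow.py introduced/modified in this diff).
import pyarrow as pa

from tower.utils.pyarrow import convert_pyarrow_schema

# A schema with one primitive column and one nested (list) column.
arrow_schema = pa.schema(
    [
        pa.field("id", pa.int64(), nullable=False),
        pa.field(
            "tags",
            pa.list_(pa.field("item", pa.string(), nullable=False)),
            nullable=True,
        ),
    ]
)

iceberg_schema = convert_pyarrow_schema(arrow_schema)

# With the FieldIdManager starting at 1, IDs are handed out in traversal
# order: 1 -> id, 2 -> tags, 3 -> tags.element.
print(iceberg_schema)
```

The field-ID bookkeeping is the main behavioral difference from the old `arrow_to_iceberg_type` path: every top-level field, list element, struct member, and map key/value now gets a unique ID from a single `FieldIdManager`, which is what Iceberg requires for nested types.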