13 changes: 6 additions & 7 deletions bigframes/core/array_value.py
@@ -17,9 +17,8 @@
import datetime
import functools
import typing
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import google.cloud.bigquery
import pandas
import pyarrow as pa

@@ -91,7 +90,7 @@ def from_range(cls, start, end, step):
@classmethod
def from_table(
cls,
table: google.cloud.bigquery.Table,
table: Union[bq_data.BiglakeIcebergTable, bq_data.GbqNativeTable],
session: Session,
*,
columns: Optional[Sequence[str]] = None,
@@ -103,8 +102,6 @@
):
if offsets_col and primary_key:
raise ValueError("must set at most one of 'offsets', 'primary_key'")
# define data source only for needed columns, this makes row-hashing cheaper
table_def = bq_data.GbqTable.from_table(table, columns=columns or ())

# create ordering from info
ordering = None
@@ -115,7 +112,9 @@
[ids.ColumnId(key_part) for key_part in primary_key]
)

bf_schema = schemata.ArraySchema.from_bq_table(table, columns=columns)
bf_schema = schemata.ArraySchema.from_bq_schema(
table.physical_schema, columns=columns
)
# Scan all columns by default; we define this list so it can be pruned while preserving source_def
scan_list = nodes.ScanList(
tuple(
@@ -124,7 +123,7 @@
)
)
source_def = bq_data.BigqueryDataSource(
table=table_def,
table=table,
schema=bf_schema,
at_time=at_time,
sql_predicate=predicate,
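A rough sketch of the new calling convention (not code from this PR; `bq_table` is an assumed `google.cloud.bigquery.Table` and the `Session` wiring is elided): `ArrayValue.from_table` now receives a pre-built `bq_data` table description instead of a raw BigQuery table object.

```python
# Sketch only -- assumes an existing Session and a bq.Table fetched elsewhere.
from bigframes.core import array_value, bq_data

# Build the lightweight table description first; `columns` prunes the schema,
# which keeps the row-hash default ordering cheap.
native_table = bq_data.GbqNativeTable.from_table(bq_table, columns=("a", "b"))

value = array_value.ArrayValue.from_table(
    native_table,  # GbqNativeTable or BiglakeIcebergTable
    session,
    columns=["a", "b"],
)
```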
99 changes: 91 additions & 8 deletions bigframes/core/bq_data.py
@@ -22,7 +22,7 @@
import queue
import threading
import typing
from typing import Any, Iterator, Optional, Sequence, Tuple
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union

from google.cloud import bigquery_storage_v1
import google.cloud.bigquery as bq
@@ -37,23 +37,48 @@
import bigframes.core.ordering as orderings


# What is the line between metadata and core fields? Mostly, metadata fields are optional or unreliable, but it's fuzzy.
@dataclasses.dataclass(frozen=True)
class GbqTable:
class TableMetadata:
# This size metadata might be stale; don't use it where strict correctness is needed.
numBytes: Optional[int] = None
numRows: Optional[int] = None
location: Optional[str] = None
type: Optional[str] = None
created_time: Optional[datetime.datetime] = None
modified_time: Optional[datetime.datetime] = None


@dataclasses.dataclass(frozen=True)
class GbqNativeTable:
project_id: str = dataclasses.field()
dataset_id: str = dataclasses.field()
table_id: str = dataclasses.field()
physical_schema: Tuple[bq.SchemaField, ...] = dataclasses.field()
is_physically_stored: bool = dataclasses.field()
cluster_cols: typing.Optional[Tuple[str, ...]]
partition_col: Optional[str] = None
cluster_cols: typing.Optional[Tuple[str, ...]] = None
primary_key: Optional[Tuple[str, ...]] = None
metadata: TableMetadata = TableMetadata()

@staticmethod
def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqNativeTable:
# Subsetting fields with columns can reduce cost of row-hash default ordering
if columns:
schema = tuple(item for item in table.schema if item.name in columns)
else:
schema = tuple(table.schema)
return GbqTable(

metadata = TableMetadata(
numBytes=table.num_bytes,
numRows=table.num_rows,
location=table.location, # type: ignore
type=table.table_type, # type: ignore
created_time=table.created,
modified_time=table.modified,
)

return GbqNativeTable(
project_id=table.project,
dataset_id=table.dataset_id,
table_id=table.table_id,
@@ -62,15 +87,17 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
cluster_cols=None
if table.clustering_fields is None
else tuple(table.clustering_fields),
primary_key=tuple(_get_primary_keys(table)),
metadata=metadata,
)

@staticmethod
def from_ref_and_schema(
table_ref: bq.TableReference,
schema: Sequence[bq.SchemaField],
cluster_cols: Optional[Sequence[str]] = None,
) -> GbqTable:
return GbqTable(
) -> GbqNativeTable:
return GbqNativeTable(
project_id=table_ref.project,
dataset_id=table_ref.dataset_id,
table_id=table_ref.table_id,
@@ -84,12 +111,48 @@ def get_table_ref(self) -> bq.TableReference:
bq.DatasetReference(self.project_id, self.dataset_id), self.table_id
)

def get_full_id(self, quoted: bool = False) -> str:
if quoted:
return f"`{self.project_id}`.`{self.dataset_id}`.`{self.table_id}`"
return f"{self.project_id}.{self.dataset_id}.{self.table_id}"

@property
@functools.cache
def schema_by_id(self):
return {col.name: col for col in self.physical_schema}


@dataclasses.dataclass(frozen=True)
class BiglakeIcebergTable:
project_id: str = dataclasses.field()
catalog_id: str = dataclasses.field()
namespace_id: str = dataclasses.field()
table_id: str = dataclasses.field()
physical_schema: Tuple[bq.SchemaField, ...] = dataclasses.field()
cluster_cols: typing.Optional[Tuple[str, ...]]
metadata: TableMetadata

def get_full_id(self, quoted: bool = False) -> str:
if quoted:
return f"`{self.project_id}`.`{self.catalog_id}`.`{self.namespace_id}`.`{self.table_id}`"
return (
f"{self.project_id}.{self.catalog_id}.{self.namespace_id}.{self.table_id}"
)

@property
@functools.cache
def schema_by_id(self):
return {col.name: col for col in self.physical_schema}

@property
def partition_col(self) -> Optional[str]:
return None

@property
def primary_key(self) -> Optional[Tuple[str, ...]]:
return None


@dataclasses.dataclass(frozen=True)
class BigqueryDataSource:
"""
@@ -104,7 +167,7 @@ def __post_init__(self):
self.schema.names
)

table: GbqTable
table: Union[GbqNativeTable, BiglakeIcebergTable]
schema: bigframes.core.schema.ArraySchema
at_time: typing.Optional[datetime.datetime] = None
# Added for backwards compatibility, not validated
@@ -188,6 +251,8 @@ def get_arrow_batches(
project_id: str,
sample_rate: Optional[float] = None,
) -> ReadResult:
assert isinstance(data.table, GbqNativeTable)

table_mod_options = {}
read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}

@@ -245,3 +310,21 @@ def process_batch(pa_batch):
return ReadResult(
batches, session.estimated_row_count, session.estimated_total_bytes_scanned
)


def _get_primary_keys(
table: bq.Table,
) -> List[str]:
"""Get primary keys from table if they are set."""

primary_keys: List[str] = []
if (
(table_constraints := getattr(table, "table_constraints", None)) is not None
and (primary_key := table_constraints.primary_key) is not None
# This will be False for either None or empty list.
# We want primary_keys = None if no primary keys are set.
and (columns := primary_key.columns)
):
primary_keys = columns if columns is not None else []

return primary_keys
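A minimal sketch of the two new table descriptions (identifiers below are hypothetical, not from this PR): `GbqNativeTable` and `BiglakeIcebergTable` expose the same small surface (`physical_schema`, `get_full_id`, `schema_by_id`, `partition_col`, `primary_key`), which is what lets `BigqueryDataSource.table` hold either one.

```python
import google.cloud.bigquery as bq

from bigframes.core import bq_data  # module added in this PR

# Hypothetical identifiers, for illustration only.
iceberg = bq_data.BiglakeIcebergTable(
    project_id="my-project",
    catalog_id="my_catalog",
    namespace_id="my_namespace",
    table_id="my_table",
    physical_schema=(bq.SchemaField("id", "INTEGER"),),
    cluster_cols=None,
    metadata=bq_data.TableMetadata(),
)

print(iceberg.get_full_id(quoted=True))
# `my-project`.`my_catalog`.`my_namespace`.`my_table`
print(iceberg.schema_by_id["id"].field_type)  # INTEGER
print(iceberg.partition_col, iceberg.primary_key)  # None None
```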
4 changes: 1 addition & 3 deletions bigframes/core/compile/ibis_compiler/ibis_compiler.py
@@ -207,9 +207,7 @@ def _table_to_ibis(
source: bq_data.BigqueryDataSource,
scan_cols: typing.Sequence[str],
) -> ibis_types.Table:
full_table_name = (
f"{source.table.project_id}.{source.table.dataset_id}.{source.table.table_id}"
)
full_table_name = source.table.get_full_id(quoted=False)
# Physical schema might include unused columns, unsupported datatypes like JSON
physical_schema = ibis_bigquery.BigQuerySchema.to_ibis(
list(source.table.physical_schema)
18 changes: 15 additions & 3 deletions bigframes/core/compile/sqlglot/compiler.py
@@ -21,6 +21,7 @@

from bigframes.core import (
agg_expressions,
bq_data,
expression,
guid,
identifiers,
@@ -173,10 +174,21 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG
@_compile_node.register
def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR):
table = node.source.table
if isinstance(table, bq_data.GbqNativeTable):
project, dataset, table_id = table.project_id, table.dataset_id, table.table_id
elif isinstance(table, bq_data.BiglakeIcebergTable):
project, dataset, table_id = (
table.project_id,
table.catalog_id,
f"{table.namespace_id}.{table.table_id}",
)

else:
raise ValueError(f"Unrecognized table type: {table}")
return ir.SQLGlotIR.from_table(
table.project_id,
table.dataset_id,
table.table_id,
project,
dataset,
table_id,
col_names=[col.source_id for col in node.scan_list.items],
alias_names=[col.id.sql for col in node.scan_list.items],
uid_gen=child.uid_gen,
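To make the new branch in `compile_readtable` concrete, here is a standalone mirror of the dispatch (a sketch, not code from this PR): the Iceberg catalog takes the place of the dataset, and the namespace is folded into the table id before the three parts are handed to `SQLGlotIR.from_table`.

```python
from typing import Tuple, Union

from bigframes.core import bq_data

def table_parts(
    table: Union[bq_data.GbqNativeTable, bq_data.BiglakeIcebergTable],
) -> Tuple[str, str, str]:
    # Mirrors the branching in compile_readtable.
    if isinstance(table, bq_data.GbqNativeTable):
        return table.project_id, table.dataset_id, table.table_id
    if isinstance(table, bq_data.BiglakeIcebergTable):
        return (
            table.project_id,
            table.catalog_id,
            f"{table.namespace_id}.{table.table_id}",
        )
    raise ValueError(f"Unrecognized table type: {table}")

# e.g. a native table yields ("proj", "dataset", "tbl"),
# an Iceberg table yields ("proj", "catalog", "namespace.tbl").
```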
4 changes: 1 addition & 3 deletions bigframes/core/nodes.py
@@ -825,9 +825,7 @@ def variables_introduced(self) -> int:

@property
def row_count(self) -> typing.Optional[int]:
if self.source.sql_predicate is None and self.source.table.is_physically_stored:
return self.source.n_rows
return None
return self.source.n_rows

@property
def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
27 changes: 6 additions & 21 deletions bigframes/core/schema.py
@@ -17,7 +17,7 @@
from dataclasses import dataclass
import functools
import typing
from typing import Dict, List, Optional, Sequence
from typing import Dict, Optional, Sequence

import google.cloud.bigquery
import pyarrow
@@ -40,31 +40,16 @@ class ArraySchema:
def __iter__(self):
yield from self.items

@classmethod
def from_bq_table(
cls,
table: google.cloud.bigquery.Table,
column_type_overrides: Optional[
typing.Dict[str, bigframes.dtypes.Dtype]
] = None,
columns: Optional[Sequence[str]] = None,
):
if not columns:
fields = table.schema
else:
lookup = {field.name: field for field in table.schema}
fields = [lookup[col] for col in columns]

return ArraySchema.from_bq_schema(
fields, column_type_overrides=column_type_overrides
)

@classmethod
def from_bq_schema(
cls,
schema: List[google.cloud.bigquery.SchemaField],
schema: Sequence[google.cloud.bigquery.SchemaField],
column_type_overrides: Optional[Dict[str, bigframes.dtypes.Dtype]] = None,
columns: Optional[Sequence[str]] = None,
):
if columns:
lookup = {field.name: field for field in schema}
schema = [lookup[col] for col in columns]
if column_type_overrides is None:
column_type_overrides = {}
items = tuple(
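Column subsetting now lives on `from_bq_schema`, so callers that only have a schema tuple (such as `physical_schema` on the new table types) no longer need a `bq.Table` object. A small sketch with made-up field names:

```python
import google.cloud.bigquery as bq

from bigframes.core import schema as schemata

# Hypothetical schema fields, for illustration only.
fields = [
    bq.SchemaField("id", "INTEGER"),
    bq.SchemaField("name", "STRING"),
    bq.SchemaField("unused", "BOOLEAN"),
]

subset = schemata.ArraySchema.from_bq_schema(fields, columns=["name", "id"])
print(subset.names)  # ("name", "id") -- order follows the requested columns
```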
3 changes: 1 addition & 2 deletions bigframes/core/sql/__init__.py
@@ -28,7 +28,6 @@
import bigframes.core.compile.googlesql as googlesql

if TYPE_CHECKING:
import google.cloud.bigquery as bigquery

import bigframes.core.ordering

@@ -131,7 +130,7 @@ def infix_op(opname: str, left_arg: str, right_arg: str):
return f"{left_arg} {opname} {right_arg}"


def is_distinct_sql(columns: Iterable[str], table_ref: bigquery.TableReference) -> str:
def is_distinct_sql(columns: Iterable[str], table_ref) -> str:
is_unique_sql = f"""WITH full_table AS (
{googlesql.Select().from_(table_ref).select(columns).sql()}
),
4 changes: 2 additions & 2 deletions bigframes/dtypes.py
@@ -800,7 +800,7 @@ def convert_to_schema_field(
name, inner_field.field_type, mode="REPEATED", fields=inner_field.fields
)
if pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
inner_fields: list[pa.Field] = []
inner_fields: list[google.cloud.bigquery.SchemaField] = []
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
for i in range(struct_type.num_fields):
field = struct_type.field(i)
@@ -823,7 +823,7 @@


def bf_type_from_type_kind(
bq_schema: list[google.cloud.bigquery.SchemaField],
bq_schema: Sequence[google.cloud.bigquery.SchemaField],
) -> typing.Dict[str, Dtype]:
"""Converts bigquery sql type to the default bigframes dtype."""
return {name: dtype for name, dtype in map(convert_schema_field, bq_schema)}
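Loosening the parameter to `Sequence` lets callers pass the tuple stored on `GbqNativeTable.physical_schema` directly. A brief sketch with assumed field names:

```python
import google.cloud.bigquery as bq

import bigframes.dtypes as dtypes

# A tuple works now, not just a list (field names are hypothetical).
schema = (
    bq.SchemaField("x", "INTEGER"),
    bq.SchemaField("y", "STRING"),
)

# Maps each column name to its default BigFrames dtype.
print(dtypes.bf_type_from_type_kind(schema))
```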
4 changes: 2 additions & 2 deletions bigframes/operations/aggregations.py
@@ -205,7 +205,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
return dtypes.TIMEDELTA_DTYPE

if dtypes.is_numeric(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]): # type: ignore
return dtypes.INT_DTYPE
return input_types[0]

@@ -224,7 +224,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
# These will change if median is changed to exact implementation.
if not dtypes.is_orderable(input_types[0]):
raise TypeError(f"Type {input_types[0]} is not orderable")
if pd.api.types.is_bool_dtype(input_types[0]):
if pd.api.types.is_bool_dtype(input_types[0]): # type: ignore
return dtypes.INT_DTYPE
else:
return input_types[0]