diff --git a/pandas_gbq/dry_runs.py b/pandas_gbq/dry_runs.py
new file mode 100644
index 00000000..7168dd97
--- /dev/null
+++ b/pandas_gbq/dry_runs.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 pandas-gbq Authors All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+from __future__ import annotations
+
+import copy
+from typing import Any, List
+
+from google.cloud import bigquery
+import pandas
+
+
+def get_query_stats(
+    query_job: bigquery.QueryJob,
+) -> pandas.Series:
+    """Return important stats from the query job as a pandas Series."""
+
+    index: List[Any] = []
+    values: List[Any] = []
+
+    # Add the raw BigQuery schema.
+    index.append("bigquerySchema")
+    values.append(query_job.schema)
+
+    job_api_repr = copy.deepcopy(query_job._properties)
+
+    # jobReference might not be populated for "job optional" queries.
+    job_ref = job_api_repr.get("jobReference", {})
+    for key, val in job_ref.items():
+        index.append(key)
+        values.append(val)
+
+    configuration = job_api_repr.get("configuration", {})
+    query_config = configuration.get("query", {})
+    index.append("jobType")
+    values.append(configuration.get("jobType", None))
+    index.append("dispatchedSql")
+    values.append(query_config.get("query", None))
+
+    for key in ("destinationTable", "useLegacySql"):
+        index.append(key)
+        values.append(query_config.get(key, None))
+
+    statistics = job_api_repr.get("statistics", {})
+    query_stats = statistics.get("query", {})
+    for key in (
+        "referencedTables",
+        "totalBytesProcessed",
+        "cacheHit",
+        "statementType",
+    ):
+        index.append(key)
+        values.append(query_stats.get(key, None))
+
+    creation_time = statistics.get("creationTime", None)
+    index.append("creationTime")
+    values.append(
+        pandas.Timestamp(creation_time, unit="ms", tz="UTC")
+        if creation_time is not None
+        else None
+    )
+
+    result = pandas.Series(values, index=index)
+    if result["totalBytesProcessed"] is None:
+        result["totalBytesProcessed"] = 0
+    else:
+        result["totalBytesProcessed"] = int(result["totalBytesProcessed"])
+
+    return result
diff --git a/pandas_gbq/gbq.py b/pandas_gbq/gbq.py
index dcc96d49..69aabedb 100644
--- a/pandas_gbq/gbq.py
+++ b/pandas_gbq/gbq.py
@@ -114,6 +114,7 @@ def read_gbq(
     *,
     col_order=None,
     bigquery_client=None,
+    dry_run: bool = False,
 ):
     r"""Read data from Google BigQuery to a pandas DataFrame.
 
@@ -264,11 +265,15 @@ def read_gbq(
     bigquery_client : google.cloud.bigquery.Client, optional
         A Google Cloud BigQuery Python Client instance. If provided, it will be
         used for reading data, while the project and credentials parameters
        will be ignored.
-
+    dry_run : bool, default False
+        If True, run the query as a dry run: no data is read, and job
+        statistics are returned instead of query results.
+
     Returns
     -------
-    df: DataFrame
-        DataFrame representing results of query.
+    df: DataFrame or Series
+        DataFrame representing results of query. If ``dry_run=True``, returns
+        a pandas Series of job statistics instead.
""" if dialect is None: dialect = context.dialect @@ -323,7 +326,11 @@ def read_gbq( max_results=max_results, progress_bar_type=progress_bar_type, dtypes=dtypes, + dry_run=dry_run, ) + # When dry_run=True, run_query returns a Pandas series + if dry_run: + return final_df else: final_df = connector.download_table( query_or_table, diff --git a/pandas_gbq/gbq_connector.py b/pandas_gbq/gbq_connector.py index 81f726f6..dec1a00c 100644 --- a/pandas_gbq/gbq_connector.py +++ b/pandas_gbq/gbq_connector.py @@ -16,6 +16,7 @@ if typing.TYPE_CHECKING: # pragma: NO COVER import pandas +from pandas_gbq import dry_runs import pandas_gbq.constants from pandas_gbq.contexts import context import pandas_gbq.core.read @@ -176,7 +177,14 @@ def download_table( user_dtypes=dtypes, ) - def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): + def run_query( + self, + query, + max_results=None, + progress_bar_type=None, + dry_run: bool = False, + **kwargs, + ): from google.cloud import bigquery job_config_dict = { @@ -212,6 +220,7 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): self._start_timer() job_config = bigquery.QueryJobConfig.from_api_repr(job_config_dict) + job_config.dry_run = dry_run if FEATURES.bigquery_has_query_and_wait: rows_iter = pandas_gbq.query.query_and_wait_via_client_library( @@ -236,12 +245,14 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): timeout_ms=timeout_ms, ) - dtypes = kwargs.get("dtypes") + if dry_run: + return dry_runs.get_query_stats(rows_iter.job) + return self._download_results( rows_iter, max_results=max_results, progress_bar_type=progress_bar_type, - user_dtypes=dtypes, + user_dtypes=kwargs.get("dtypes"), ) def _download_results( diff --git a/pandas_gbq/query.py b/pandas_gbq/query.py index 83575a9c..e564e052 100644 --- a/pandas_gbq/query.py +++ b/pandas_gbq/query.py @@ -179,7 +179,12 @@ def query_and_wait( # getQueryResults() instead of tabledata.list, which returns the correct # response with DML/DDL queries. 
    try:
-        return query_reply.result(max_results=max_results)
+        rows_iter = query_reply.result(max_results=max_results)
+        # Keep a reference to the QueryJob on the RowIterator so dry-run
+        # statistics can be read from it later; only set it if missing.
+        if not hasattr(rows_iter, "job") or rows_iter.job is None:
+            rows_iter.job = query_reply
+        return rows_iter
     except connector.http_error as ex:
         connector.process_http_error(ex)
 
@@ -195,6 +200,29 @@ def query_and_wait_via_client_library(
     max_results: Optional[int],
     timeout_ms: Optional[int],
 ):
+    # For dry runs, call query() directly so we hold the QueryJob, whose
+    # statistics are needed to build the dry-run result.
+    if job_config.dry_run:
+        query_job = try_query(
+            connector,
+            functools.partial(
+                client.query,
+                query,
+                job_config=job_config,
+                location=location,
+                project=project_id,
+            ),
+        )
+        # A dry run is validated server-side but never executed, so the job
+        # is already complete; result() returns an empty row iterator.
+        rows_iter = query_job.result(
+            max_results=max_results,
+            timeout=timeout_ms / 1000.0 if timeout_ms else None,
+        )
+        if not hasattr(rows_iter, "job") or rows_iter.job is None:
+            rows_iter.job = query_job
+        return rows_iter
+
     rows_iter = try_query(
         connector,
         functools.partial(
diff --git a/tests/system/test_gbq.py b/tests/system/test_gbq.py
index 5b85b9ed..4cdb3ebe 100644
--- a/tests/system/test_gbq.py
+++ b/tests/system/test_gbq.py
@@ -654,6 +654,18 @@ def test_columns_and_col_order_raises_error(self, project_id):
             dialect="standard",
         )
 
+    def test_read_gbq_with_dry_run(self, project_id):
+        query = "SELECT 1"
+        result = gbq.read_gbq(
+            query,
+            project_id=project_id,
+            credentials=self.credentials,
+            dialect="standard",
+            dry_run=True,
+        )
+        assert isinstance(result, pandas.Series)
+        assert result["totalBytesProcessed"] >= 0
+
 
 class TestToGBQIntegration(object):
     @pytest.fixture(autouse=True, scope="function")
diff --git a/tests/unit/test_dry_runs.py b/tests/unit/test_dry_runs.py
new file mode 100644
index 00000000..8d72066e
--- /dev/null
+++ b/tests/unit/test_dry_runs.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2025 pandas-gbq Authors All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+from unittest import mock
+
+from google.cloud import bigquery
+import pandas
+import pandas.testing
+
+from pandas_gbq import dry_runs
+
+
+def test_get_query_stats():
+    mock_query_job = mock.create_autospec(bigquery.QueryJob)
+    total_bytes_processed = 15
+    mock_query_job._properties = {
+        "kind": "bigquery#job",
+        "etag": "e-tag",
+        "id": "id",
+        "selfLink": "self-link",
+        "user_email": "user-email",
+        "configuration": {
+            "query": {
+                "query": "SELECT * FROM `test_table`",
+                "destinationTable": {
+                    "projectId": "project-id",
+                    "datasetId": "dataset-id",
+                    "tableId": "table-id",
+                },
+                "writeDisposition": "WRITE_TRUNCATE",
+                "priority": "INTERACTIVE",
+                "useLegacySql": False,
+            },
+            "jobType": "QUERY",
+        },
+        "jobReference": {
+            "projectId": "project-id",
+            "jobId": "job-id",
+            "location": "US",
+        },
+        "statistics": {
+            "creationTime": 1767037135155.0,
+            "startTime": 1767037135238.0,
+            "endTime": 1767037135353.0,
+            "totalBytesProcessed": f"{total_bytes_processed}",
+            "query": {
+                "totalBytesProcessed": f"{total_bytes_processed}",
+                "totalBytesBilled": "0",
+                "cacheHit": True,
+                "statementType": "SELECT",
+            },
+            "reservation_id": "reservation_id",
+            "edition": "ENTERPRISE",
+            "reservationGroupPath": [""],
+        },
+        "status": {"state": "DONE"},
+        "principal_subject": "principal_subject",
+        "jobCreationReason": {"code": "REQUESTED"},
+    }
+    expected_index = pandas.Index(
+        [
+            "bigquerySchema",
+            "projectId",
+            "jobId",
+            "location",
+            "jobType",
+            "dispatchedSql",
+            "destinationTable",
+            "useLegacySql",
+            "referencedTables",
+            "totalBytesProcessed",
+            "cacheHit",
+            "statementType",
+            "creationTime",
+        ]
+    )
+
+    result = dry_runs.get_query_stats(mock_query_job)
+
+    assert isinstance(result, pandas.Series)
+    pandas.testing.assert_index_equal(expected_index, result.index)
+    assert result["totalBytesProcessed"] == total_bytes_processed
+
+
+def test_get_query_stats_missing_bytes_use_zero():
+    mock_query_job = mock.create_autospec(bigquery.QueryJob)
+    mock_query_job._properties = {
+        "kind": "bigquery#job",
+        "etag": "e-tag",
+        "id": "id",
+        "selfLink": "self-link",
+        "user_email": "user-email",
+        "configuration": {
+            "query": {
+                "query": "SELECT * FROM `test_table`",
+                "destinationTable": {
+                    "projectId": "project-id",
+                    "datasetId": "dataset-id",
+                    "tableId": "table-id",
+                },
+                "writeDisposition": "WRITE_TRUNCATE",
+                "priority": "INTERACTIVE",
+                "useLegacySql": False,
+            },
+            "jobType": "QUERY",
+        },
+        "jobReference": {
+            "projectId": "project-id",
+            "jobId": "job-id",
+            "location": "US",
+        },
+        "statistics": {
+            "creationTime": 1767037135155.0,
+            "startTime": 1767037135238.0,
+            "endTime": 1767037135353.0,
+            "query": {
+                "cacheHit": True,
+                "statementType": "SELECT",
+            },
+            "reservation_id": "reservation_id",
+            "edition": "ENTERPRISE",
+            "reservationGroupPath": [""],
+        },
+        "status": {"state": "DONE"},
+        "principal_subject": "principal_subject",
+        "jobCreationReason": {"code": "REQUESTED"},
+    }
+    expected_index = pandas.Index(
+        [
+            "bigquerySchema",
+            "projectId",
+            "jobId",
+            "location",
+            "jobType",
+            "dispatchedSql",
+            "destinationTable",
+            "useLegacySql",
+            "referencedTables",
+            "totalBytesProcessed",
+            "cacheHit",
+            "statementType",
+            "creationTime",
+        ]
+    )
+
+    result = dry_runs.get_query_stats(mock_query_job)
+
+    assert isinstance(result, pandas.Series)
+    pandas.testing.assert_index_equal(expected_index, result.index)
+    assert result["totalBytesProcessed"] == 0
diff --git a/tests/unit/test_gbq.py b/tests/unit/test_gbq.py
index 6eafe9e2..6af42c47 100644
--- a/tests/unit/test_gbq.py
+++ b/tests/unit/test_gbq.py
@@ -77,6 +77,8 @@ def generate_schema():
 @pytest.fixture(autouse=True)
 def default_bigquery_client(mock_bigquery_client, mock_query_job, mock_row_iterator):
     mock_query_job.result.return_value = mock_row_iterator
+    # Point RowIterator.job at the QueryJob so dry_run can read statistics.
+    mock_row_iterator.job = mock_query_job
     mock_bigquery_client.list_rows.return_value = mock_row_iterator
     mock_bigquery_client.query.return_value = mock_query_job
 
@@ -938,3 +940,41 @@ def test_run_query_with_dml_query(mock_bigquery_client, mock_query_job):
     type(mock_query_job).destination = mock.PropertyMock(return_value=None)
     connector.run_query("UPDATE tablename SET value = '';")
     mock_bigquery_client.list_rows.assert_not_called()
+
+
+def test_read_gbq_with_dry_run(mock_bigquery_client, mock_query_job):
+    total_bytes_processed = 15
+    type(mock_query_job)._properties = mock.PropertyMock(
+        return_value={
+            "statistics": {
+                "creationTime": 1767037135155.0,
+                "startTime": 1767037135238.0,
+                "endTime": 1767037135353.0,
+                "totalBytesProcessed": f"{total_bytes_processed}",
+                "query": {
+                    "totalBytesProcessed": f"{total_bytes_processed}",
+                    "totalBytesBilled": "0",
+                    "cacheHit": True,
+                    "statementType": "SELECT",
+                },
+                "reservation_id": "reservation_id",
+                "edition": "ENTERPRISE",
+                "reservationGroupPath": [""],
+            },
+        }
+    )
+
+    dry_run_result = gbq.read_gbq("SELECT 1", project_id="my-project", dry_run=True)
+
+    # Newer BigQuery clients dispatch via query_and_wait; older ones use query.
+    if (
+        hasattr(mock_bigquery_client, "query_and_wait")
+        and mock_bigquery_client.query_and_wait.called
+    ):
+        _, kwargs = mock_bigquery_client.query_and_wait.call_args
+        job_config = kwargs["job_config"]
+    else:
+        _, kwargs = mock_bigquery_client.query.call_args
+        job_config = kwargs["job_config"]
+    assert job_config.dry_run is True
+    assert dry_run_result["totalBytesProcessed"] == total_bytes_processed
diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py
index 2437fa02..1ab7e54f 100644
--- a/tests/unit/test_query.py
+++ b/tests/unit/test_query.py
@@ -170,15 +170,19 @@ def test_query_response_bytes(size_in_bytes, formatted_text):
 def test__wait_for_query_job_exits_when_done(mock_bigquery_client):
     connector = _make_connector()
     connector.client = mock_bigquery_client
-    connector.start = datetime.datetime(2020, 1, 1).timestamp()
 
     mock_query = mock.create_autospec(google.cloud.bigquery.QueryJob)
     type(mock_query).state = mock.PropertyMock(side_effect=("RUNNING", "DONE"))
     mock_query.result.side_effect = concurrent.futures.TimeoutError("fake timeout")
 
-    with freezegun.freeze_time("2020-01-01 00:00:00", tick=False):
+    frozen_time = datetime.datetime(2020, 1, 1)
+    with freezegun.freeze_time(frozen_time, tick=False):
+        # Set the start time inside the frozen context and mock
+        # get_elapsed_seconds so the measured elapsed time is always 0.
+        connector.start = frozen_time.timestamp()
+        connector.get_elapsed_seconds = mock.Mock(return_value=0.0)
         module_under_test._wait_for_query_job(
-            connector, mock_bigquery_client, mock_query, 60
+            connector, mock_bigquery_client, mock_query, 1000
         )
 
     mock_bigquery_client.cancel_job.assert_not_called()
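
For reviewers, a minimal usage sketch of the new flag (not part of the diff; the public-dataset query and `my-project` ID are illustrative placeholders):

```python
import pandas_gbq

# A dry run asks BigQuery to validate the query and estimate the bytes it
# would scan without actually executing it, so nothing is billed.
stats = pandas_gbq.read_gbq(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`",
    project_id="my-project",  # placeholder project ID
    dry_run=True,
)

# Instead of a DataFrame, read_gbq returns a pandas Series of job statistics.
print(stats["totalBytesProcessed"])  # estimated bytes the query would scan
print(stats["statementType"])        # e.g. "SELECT"
print(stats["bigquerySchema"])       # schema the query result would have
```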
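Since the comments in `pandas_gbq/query.py` mention dry-run cost calculation, here is a sketch of turning `totalBytesProcessed` into a rough on-demand cost estimate. The $6.25/TiB rate is an assumption that varies by region, edition, and over time, so verify it against current BigQuery pricing:

```python
# Assumed on-demand rate in USD per TiB scanned; check current pricing.
USD_PER_TIB = 6.25
BYTES_PER_TIB = 2**40

# Example value as it would come from a dry-run stats Series,
# i.e. stats["totalBytesProcessed"].
bytes_processed = 65_935_918

estimated_cost = bytes_processed / BYTES_PER_TIB * USD_PER_TIB
print(f"Estimated on-demand cost: ${estimated_cost:.6f}")
```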