From 03e46f90e3107f9ca0f7ff6f950f1d93e6dfc24d Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Thu, 22 Jan 2026 22:02:26 +0000 Subject: [PATCH 1/3] feat: add bigquery.ml.generate_text function --- GEMINI.md | 4 +- bigframes/bigquery/_operations/ml.py | 101 +++++++++++++++++- bigframes/bigquery/ml.py | 2 + bigframes/core/sql/ml.py | 96 +++++++++++++++-- .../bq_dataframes_ml_cross_validation.ipynb | 4 +- tests/unit/bigquery/test_ml.py | 43 ++++++++ .../evaluate_model_with_options.sql | 2 +- .../generate_text_model_basic.sql | 1 + .../generate_text_model_with_options.sql | 1 + .../global_explain_model_with_options.sql | 2 +- .../predict_model_with_options.sql | 2 +- tests/unit/core/sql/test_ml.py | 26 +++++ 12 files changed, 267 insertions(+), 17 deletions(-) create mode 100644 tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql create mode 100644 tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql diff --git a/GEMINI.md b/GEMINI.md index 0d447f17a48..2e4c1ea0a58 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -2,9 +2,9 @@ ## Testing -We use `nox` to instrument our tests. +We use `pytest` to instrument our tests. -- To test your changes, run unit tests with `nox`: +- To test your changes, run unit tests with `pytest`: ```bash nox -r -s unit diff --git a/bigframes/bigquery/_operations/ml.py b/bigframes/bigquery/_operations/ml.py index e5a5c5dfb68..1f78b42e888 100644 --- a/bigframes/bigquery/_operations/ml.py +++ b/bigframes/bigquery/_operations/ml.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import cast, Mapping, Optional, Union +from typing import Any, cast, List, Mapping, Optional, Union import bigframes_vendored.constants import google.cloud.bigquery @@ -431,3 +431,102 @@ def transform( return bpd.read_gbq_query(sql) else: return session.read_gbq_query(sql) + + +@log_adapter.method_logger(custom_base_name="bigquery_ml") +def generate_text( + model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + input_: Union[pd.DataFrame, dataframe.DataFrame, str], + *, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + flatten_json_output: Optional[bool] = None, + safety_settings: Optional[Mapping[str, str]] = None, + stop_sequences: Optional[List[str]] = None, + ground_with_google_search: Optional[bool] = None, + model_params: Optional[Mapping[str, Any]] = None, + request_type: Optional[str] = None, +) -> dataframe.DataFrame: + """ + Generates text using a BigQuery ML model. + + See the `BigQuery ML GENERATE_TEXT function syntax + `_ + for additional reference. + + Args: + model (bigframes.ml.base.BaseEstimator or str): + The model to use for text generation. + input_ (Union[bigframes.pandas.DataFrame, str]): + The DataFrame or query to use for text generation. + temperature (float, optional): + A FLOAT64 value that is used for sampling promiscuity. The value + must be in the range ``[0.0, 1.0]``. A lower temperature works well + for prompts that expect a more deterministic and less open-ended + or creative response, while a higher temperature can lead to more + diverse or creative results. A temperature of ``0`` is + deterministic, meaning that the highest probability response is + always selected. + max_output_tokens (int, optional): + An INT64 value that sets the maximum number of tokens in the + generated text. + top_k (int, optional): + An INT64 value that changes how the model selects tokens for + output. A ``top_k`` of ``1`` means the next selected token is the + most probable among all tokens in the model's vocabulary. A + ``top_k`` of ``3`` means that the next token is selected from + among the three most probable tokens by using temperature. The + default value is ``40``. + top_p (float, optional): + A FLOAT64 value that changes how the model selects tokens for + output. Tokens are selected from most probable to least probable + until the sum of their probabilities equals the ``top_p`` value. + For example, if tokens A, B, and C have a probability of 0.3, 0.2, + and 0.1 and the ``top_p`` value is ``0.5``, then the model will + select either A or B as the next token by using temperature. The + default value is ``0.95``. + flatten_json_output (bool, optional): + A BOOL value that determines the content of the generated JSON column. + safety_settings (Mapping[str, str], optional): + A STRUCT value that contains the safety settings for the model. + The STRUCT must have a ``category`` field of type STRING and a + ``threshold`` field of type STRING. + stop_sequences (List[str], optional): + An ARRAY value that contains the stop sequences for the model. + ground_with_google_search (bool, optional): + A BOOL value that determines whether to ground the model with Google Search. + model_params (Mapping[str, Any], optional): + A JSON value that contains the parameters for the model. + request_type (str, optional): + A STRING value that contains the request type for the model. + + Returns: + bigframes.pandas.DataFrame: + The generated text. + """ + import bigframes.pandas as bpd + + model_name, session = _get_model_name_and_session(model, input_) + table_sql = _to_sql(input_) + + sql = bigframes.core.sql.ml.generate_text( + model_name=model_name, + table=table_sql, + temperature=temperature, + max_output_tokens=max_output_tokens, + top_k=top_k, + top_p=top_p, + flatten_json_output=flatten_json_output, + safety_settings=safety_settings, + stop_sequences=stop_sequences, + ground_with_google_search=ground_with_google_search, + model_params=model_params, + request_type=request_type, + ) + + if session is None: + return bpd.read_gbq_query(sql) + else: + return session.read_gbq_query(sql) diff --git a/bigframes/bigquery/ml.py b/bigframes/bigquery/ml.py index 6ceadb324d5..ef9aa3288b8 100644 --- a/bigframes/bigquery/ml.py +++ b/bigframes/bigquery/ml.py @@ -23,6 +23,7 @@ create_model, evaluate, explain_predict, + generate_text, global_explain, predict, transform, @@ -35,4 +36,5 @@ "explain_predict", "global_explain", "transform", + "generate_text", ] diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 17493159250..4febdc6f140 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -14,7 +14,9 @@ from __future__ import annotations -from typing import Dict, Mapping, Optional, Union +import collections.abc +import json +from typing import Any, Dict, List, Mapping, Optional, Union import bigframes.core.compile.googlesql as googlesql import bigframes.core.sql @@ -100,14 +102,41 @@ def create_model_ddl( def _build_struct_sql( - struct_options: Mapping[str, Union[str, int, float, bool]] + struct_options: Mapping[ + str, + Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]], + ] ) -> str: if not struct_options: return "" rendered_options = [] for option_name, option_value in struct_options.items(): - rendered_val = bigframes.core.sql.simple_literal(option_value) + if option_name == "model_params": + json_str = json.dumps(option_value) + # Escape single quotes for SQL string literal + sql_json_str = json_str.replace("'", "''") + rendered_val = f"JSON'{sql_json_str}'" + elif isinstance(option_value, collections.abc.Mapping): + struct_body = ", ".join( + [ + f"{bigframes.core.sql.simple_literal(v)} AS {k}" + for k, v in option_value.items() + ] + ) + rendered_val = f"STRUCT({struct_body})" + elif isinstance(option_value, list): + rendered_val = ( + "[" + + ", ".join( + [bigframes.core.sql.simple_literal(v) for v in option_value] + ) + + "]" + ) + elif isinstance(option_value, bool): + rendered_val = str(option_value).lower() + else: + rendered_val = bigframes.core.sql.simple_literal(option_value) rendered_options.append(f"{rendered_val} AS {option_name}") return f", STRUCT({', '.join(rendered_options)})" @@ -151,7 +180,7 @@ def predict( """Encode the ML.PREDICT statement. See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference. """ - struct_options = {} + struct_options: Dict[str, Union[str, int, float, bool]] = {} if threshold is not None: struct_options["threshold"] = threshold if keep_original_columns is not None: @@ -160,10 +189,10 @@ def predict( struct_options["trial_id"] = trial_id sql = ( - f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})" + f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table}))" ) sql += _build_struct_sql(struct_options) - sql += ")\n" + sql += "\n" return sql @@ -205,13 +234,13 @@ def global_explain( """Encode the ML.GLOBAL_EXPLAIN statement. See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference. """ - struct_options = {} + struct_options: Dict[str, Union[str, int, float, bool]] = {} if class_level_explain is not None: struct_options["class_level_explain"] = class_level_explain - sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)}" + sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)})" sql += _build_struct_sql(struct_options) - sql += ")\n" + sql += "\n" return sql @@ -224,3 +253,52 @@ def transform( """ sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n" return sql + + +def generate_text( + model_name: str, + table: str, + *, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + flatten_json_output: Optional[bool] = None, + safety_settings: Optional[Mapping[str, str]] = None, + stop_sequences: Optional[List[str]] = None, + ground_with_google_search: Optional[bool] = None, + model_params: Optional[Mapping[str, Any]] = None, + request_type: Optional[str] = None, +) -> str: + """Encode the ML.GENERATE_TEXT statement. + See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference. + """ + struct_options: Dict[ + str, + Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]], + ] = {} + if temperature is not None: + struct_options["temperature"] = temperature + if max_output_tokens is not None: + struct_options["max_output_tokens"] = max_output_tokens + if top_k is not None: + struct_options["top_k"] = top_k + if top_p is not None: + struct_options["top_p"] = top_p + if flatten_json_output is not None: + struct_options["flatten_json_output"] = flatten_json_output + if safety_settings is not None: + struct_options["safety_settings"] = safety_settings + if stop_sequences is not None: + struct_options["stop_sequences"] = stop_sequences + if ground_with_google_search is not None: + struct_options["ground_with_google_search"] = ground_with_google_search + if model_params is not None: + struct_options["model_params"] = model_params + if request_type is not None: + struct_options["request_type"] = request_type + + sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {googlesql.identifier(model_name)}, ({table})" + sql += _build_struct_sql(struct_options) + sql += ")\n" + return sql diff --git a/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb index 501bfc88d31..3dc0eabf5a1 100644 --- a/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb +++ b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb @@ -991,7 +991,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "venv (3.10.14)", "language": "python", "name": "python3" }, @@ -1005,7 +1005,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tests/unit/bigquery/test_ml.py b/tests/unit/bigquery/test_ml.py index 96b97d68fe3..6aaaffa541a 100644 --- a/tests/unit/bigquery/test_ml.py +++ b/tests/unit/bigquery/test_ml.py @@ -163,3 +163,46 @@ def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock): assert "ML.TRANSFORM" in generated_sql assert f"MODEL `{MODEL_NAME}`" in generated_sql assert "(SELECT * FROM `pandas_df`)" in generated_sql + + +@mock.patch("bigframes.pandas.read_gbq_query") +@mock.patch("bigframes.pandas.read_pandas") +def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock): + df = pd.DataFrame({"col1": [1, 2, 3]}) + read_pandas_mock.return_value._to_sql_query.return_value = ( + "SELECT * FROM `pandas_df`", + [], + [], + ) + ml_ops.generate_text( + MODEL_SERIES, + input_=df, + temperature=0.5, + max_output_tokens=128, + top_k=20, + top_p=0.9, + flatten_json_output=True, + safety_settings={"hate_speech": "BLOCK_ONLY_HIGH"}, + stop_sequences=["a", "b"], + ground_with_google_search=True, + model_params={"param1": "value1"}, + request_type="TYPE", + ) + read_pandas_mock.assert_called_once() + read_gbq_query_mock.assert_called_once() + generated_sql = read_gbq_query_mock.call_args[0][0] + assert "ML.GENERATE_TEXT" in generated_sql + assert f"MODEL `{MODEL_NAME}`" in generated_sql + assert "(SELECT * FROM `pandas_df`)" in generated_sql + assert "STRUCT(0.5 AS temperature" in generated_sql + assert "128 AS max_output_tokens" in generated_sql + assert "20 AS top_k" in generated_sql + assert "0.9 AS top_p" in generated_sql + assert "true AS flatten_json_output" in generated_sql + assert ( + "STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings" in generated_sql + ) + assert "['a', 'b'] AS stop_sequences" in generated_sql + assert "true AS ground_with_google_search" in generated_sql + assert """JSON'{"param1": "value1"}' AS model_params""" in generated_sql + assert "'TYPE' AS request_type" in generated_sql diff --git a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql index 01eb4d37819..848c36907b9 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level)) +SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql new file mode 100644 index 00000000000..9d986876448 --- /dev/null +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql @@ -0,0 +1 @@ +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql new file mode 100644 index 00000000000..4a9f8ea9718 --- /dev/null +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql @@ -0,0 +1 @@ +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, JSON'{"param1": "value1"}' AS model_params, 'TYPE' AS request_type)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql index 1a3baa0c13b..2b2e8b12d22 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain)) +SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`), STRUCT(true AS class_level_explain) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql index 96c8074e4c1..4e696c486fd 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns)) +SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data)), STRUCT(true AS keep_original_columns) diff --git a/tests/unit/core/sql/test_ml.py b/tests/unit/core/sql/test_ml.py index 9721f42fee1..421169e8f67 100644 --- a/tests/unit/core/sql/test_ml.py +++ b/tests/unit/core/sql/test_ml.py @@ -177,3 +177,29 @@ def test_transform_model_basic(snapshot): table="SELECT * FROM new_data", ) snapshot.assert_match(sql, "transform_model_basic.sql") + + +def test_generate_text_model_basic(snapshot): + sql = bigframes.core.sql.ml.generate_text( + model_name="my_project.my_dataset.my_model", + table="SELECT * FROM new_data", + ) + snapshot.assert_match(sql, "generate_text_model_basic.sql") + + +def test_generate_text_model_with_options(snapshot): + sql = bigframes.core.sql.ml.generate_text( + model_name="my_project.my_dataset.my_model", + table="SELECT * FROM new_data", + temperature=0.5, + max_output_tokens=128, + top_k=20, + top_p=0.9, + flatten_json_output=True, + safety_settings={"hate_speech": "BLOCK_ONLY_HIGH"}, + stop_sequences=["a", "b"], + ground_with_google_search=True, + model_params={"param1": "value1"}, + request_type="TYPE", + ) + snapshot.assert_match(sql, "generate_text_model_with_options.sql") From 58f647b927459831afb19d4e36f1f05187a5ee10 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Fri, 23 Jan 2026 19:24:09 +0000 Subject: [PATCH 2/3] feat: Pass options as struct to ML.PREDICT and ML.GLOBAL_EXPLAIN This change corrects the SQL generation for and to pass the options as a struct, which is the correct syntax. The snapshot tests have been updated to reflect these changes. --- GEMINI.md | 4 ++-- bigframes/core/sql/ml.py | 8 ++++---- .../global_explain_model_with_options.sql | 2 +- .../predict_model_with_options.sql | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/GEMINI.md b/GEMINI.md index 2e4c1ea0a58..0d447f17a48 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -2,9 +2,9 @@ ## Testing -We use `pytest` to instrument our tests. +We use `nox` to instrument our tests. -- To test your changes, run unit tests with `pytest`: +- To test your changes, run unit tests with `nox`: ```bash nox -r -s unit diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 4febdc6f140..07647b1e4ec 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -189,10 +189,10 @@ def predict( struct_options["trial_id"] = trial_id sql = ( - f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table}))" + f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})" ) sql += _build_struct_sql(struct_options) - sql += "\n" + sql += ")\n" return sql @@ -238,9 +238,9 @@ def global_explain( if class_level_explain is not None: struct_options["class_level_explain"] = class_level_explain - sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)})" + sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)}" sql += _build_struct_sql(struct_options) - sql += "\n" + sql += ")\n" return sql diff --git a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql index 2b2e8b12d22..b8d158acfc7 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`), STRUCT(true AS class_level_explain) +SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql index 4e696c486fd..f320d47fcf4 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data)), STRUCT(true AS keep_original_columns) +SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns)) From 7e58f51fd8c32ed6c7011718bc88c5e1e21509a9 Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Fri, 23 Jan 2026 21:46:24 +0000 Subject: [PATCH 3/3] remove params --- bigframes/bigquery/_operations/ml.py | 12 +----------- bigframes/core/sql/ml.py | 6 ------ tests/unit/bigquery/test_ml.py | 6 ------ .../generate_text_model_with_options.sql | 2 +- tests/unit/core/sql/test_ml.py | 2 -- 5 files changed, 2 insertions(+), 26 deletions(-) diff --git a/bigframes/bigquery/_operations/ml.py b/bigframes/bigquery/_operations/ml.py index 1f78b42e888..29ab19550b0 100644 --- a/bigframes/bigquery/_operations/ml.py +++ b/bigframes/bigquery/_operations/ml.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, cast, List, Mapping, Optional, Union +from typing import cast, List, Mapping, Optional, Union import bigframes_vendored.constants import google.cloud.bigquery @@ -443,10 +443,8 @@ def generate_text( top_k: Optional[int] = None, top_p: Optional[float] = None, flatten_json_output: Optional[bool] = None, - safety_settings: Optional[Mapping[str, str]] = None, stop_sequences: Optional[List[str]] = None, ground_with_google_search: Optional[bool] = None, - model_params: Optional[Mapping[str, Any]] = None, request_type: Optional[str] = None, ) -> dataframe.DataFrame: """ @@ -489,16 +487,10 @@ def generate_text( default value is ``0.95``. flatten_json_output (bool, optional): A BOOL value that determines the content of the generated JSON column. - safety_settings (Mapping[str, str], optional): - A STRUCT value that contains the safety settings for the model. - The STRUCT must have a ``category`` field of type STRING and a - ``threshold`` field of type STRING. stop_sequences (List[str], optional): An ARRAY value that contains the stop sequences for the model. ground_with_google_search (bool, optional): A BOOL value that determines whether to ground the model with Google Search. - model_params (Mapping[str, Any], optional): - A JSON value that contains the parameters for the model. request_type (str, optional): A STRING value that contains the request type for the model. @@ -519,10 +511,8 @@ def generate_text( top_k=top_k, top_p=top_p, flatten_json_output=flatten_json_output, - safety_settings=safety_settings, stop_sequences=stop_sequences, ground_with_google_search=ground_with_google_search, - model_params=model_params, request_type=request_type, ) diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 07647b1e4ec..0b9427b9384 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -264,10 +264,8 @@ def generate_text( top_k: Optional[int] = None, top_p: Optional[float] = None, flatten_json_output: Optional[bool] = None, - safety_settings: Optional[Mapping[str, str]] = None, stop_sequences: Optional[List[str]] = None, ground_with_google_search: Optional[bool] = None, - model_params: Optional[Mapping[str, Any]] = None, request_type: Optional[str] = None, ) -> str: """Encode the ML.GENERATE_TEXT statement. @@ -287,14 +285,10 @@ def generate_text( struct_options["top_p"] = top_p if flatten_json_output is not None: struct_options["flatten_json_output"] = flatten_json_output - if safety_settings is not None: - struct_options["safety_settings"] = safety_settings if stop_sequences is not None: struct_options["stop_sequences"] = stop_sequences if ground_with_google_search is not None: struct_options["ground_with_google_search"] = ground_with_google_search - if model_params is not None: - struct_options["model_params"] = model_params if request_type is not None: struct_options["request_type"] = request_type diff --git a/tests/unit/bigquery/test_ml.py b/tests/unit/bigquery/test_ml.py index 6aaaffa541a..e52820f88a9 100644 --- a/tests/unit/bigquery/test_ml.py +++ b/tests/unit/bigquery/test_ml.py @@ -182,10 +182,8 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo top_k=20, top_p=0.9, flatten_json_output=True, - safety_settings={"hate_speech": "BLOCK_ONLY_HIGH"}, stop_sequences=["a", "b"], ground_with_google_search=True, - model_params={"param1": "value1"}, request_type="TYPE", ) read_pandas_mock.assert_called_once() @@ -199,10 +197,6 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo assert "20 AS top_k" in generated_sql assert "0.9 AS top_p" in generated_sql assert "true AS flatten_json_output" in generated_sql - assert ( - "STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings" in generated_sql - ) assert "['a', 'b'] AS stop_sequences" in generated_sql assert "true AS ground_with_google_search" in generated_sql - assert """JSON'{"param1": "value1"}' AS model_params""" in generated_sql assert "'TYPE' AS request_type" in generated_sql diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql index 4a9f8ea9718..7839ff3fbdd 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, JSON'{"param1": "value1"}' AS model_params, 'TYPE' AS request_type)) +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type)) diff --git a/tests/unit/core/sql/test_ml.py b/tests/unit/core/sql/test_ml.py index 421169e8f67..15e9ef0aa10 100644 --- a/tests/unit/core/sql/test_ml.py +++ b/tests/unit/core/sql/test_ml.py @@ -196,10 +196,8 @@ def test_generate_text_model_with_options(snapshot): top_k=20, top_p=0.9, flatten_json_output=True, - safety_settings={"hate_speech": "BLOCK_ONLY_HIGH"}, stop_sequences=["a", "b"], ground_with_google_search=True, - model_params={"param1": "value1"}, request_type="TYPE", ) snapshot.assert_match(sql, "generate_text_model_with_options.sql")