From fd9a59af298d833a1b4f0acff45eb9724e1ff270 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=80=E6=96=B9=E7=9A=84=E8=B5=9B=E5=8D=9A=E7=9C=9F?= =?UTF-8?q?=E5=90=9B?= Date: Wed, 24 Dec 2025 17:04:32 +0800 Subject: [PATCH] feat(table-selection): add LLM-based table selection for SQL generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new LLM-based table selection feature that can replace or work alongside the existing RAG-based table embedding. Key changes: - New table selection module (backend/apps/datasource/llm_select/) - New config option TABLE_LLM_SELECTION_ENABLED (default: true) - Add table_select_answer field to ChatRecord for logging LLM selections - Add SELECT_TABLE operation type for tracking in chat logs - Skip foreign key relation table completion when using LLM selection (LLM already sees table relations and can decide which tables to include) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../versions/054_add_table_select_answer.py | 22 ++ backend/apps/chat/curd/chat.py | 20 ++ backend/apps/chat/models/chat_model.py | 3 + backend/apps/chat/task/llm.py | 40 ++- backend/apps/datasource/crud/datasource.py | 40 ++- .../apps/datasource/llm_select/__init__.py | 2 + .../datasource/llm_select/table_selection.py | 243 ++++++++++++++++++ .../apps/template/select_table/__init__.py | 0 .../apps/template/select_table/generator.py | 6 + backend/common/core/config.py | 9 + backend/templates/template.yaml | 83 +++++- 11 files changed, 456 insertions(+), 12 deletions(-) create mode 100644 backend/alembic/versions/054_add_table_select_answer.py create mode 100644 backend/apps/datasource/llm_select/__init__.py create mode 100644 backend/apps/datasource/llm_select/table_selection.py create mode 100644 backend/apps/template/select_table/__init__.py create mode 100644 backend/apps/template/select_table/generator.py diff --git a/backend/alembic/versions/054_add_table_select_answer.py b/backend/alembic/versions/054_add_table_select_answer.py new file mode 100644 index 00000000..e75e24f0 --- /dev/null +++ b/backend/alembic/versions/054_add_table_select_answer.py @@ -0,0 +1,22 @@ +"""add table_select_answer column to chat_record + +Revision ID: 054_table_select +Revises: 5755c0b95839 +Create Date: 2025-12-23 + +""" +from alembic import op +import sqlalchemy as sa + +revision = '054_table_select' +down_revision = '5755c0b95839' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('chat_record', sa.Column('table_select_answer', sa.Text(), nullable=True)) + + +def downgrade(): + op.drop_column('chat_record', 'table_select_answer') diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index f93ad18e..94e97cbc 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -649,6 +649,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s return result +def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: + """保存 LLM 表选择的结果到 ChatRecord""" + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.table_select_answer = answer + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + table_select_answer=record.table_select_answer, + ) + + session.execute(stmt) + session.commit() + + return result + + def save_recommend_question_answer(session: SessionDep, record_id: int, answer: dict = None) -> ChatRecord: if not record_id: diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 33872005..25d44dbc 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -40,6 +40,7 @@ class OperationEnum(Enum): GENERATE_SQL_WITH_PERMISSIONS = '5' CHOOSE_DATASOURCE = '6' GENERATE_DYNAMIC_SQL = '7' + SELECT_TABLE = '8' # LLM 表选择 class ChatFinishStep(Enum): @@ -112,6 +113,7 @@ class ChatRecord(SQLModel, table=True): recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True)) recommended_question: str = Field(sa_column=Column(Text, nullable=True)) datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True)) + table_select_answer: str = Field(sa_column=Column(Text, nullable=True)) finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False)) error: str = Field(sa_column=Column(Text, nullable=True)) analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True)) @@ -137,6 +139,7 @@ class ChatRecordResult(BaseModel): predict_data: Optional[str] = None recommended_question: Optional[str] = None datasource_select_answer: Optional[str] = None + table_select_answer: Optional[str] = None finish: Optional[bool] = None error: Optional[str] = None analysis_record_id: Optional[int] = None diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index b7402ab1..df8e3d7a 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -116,8 +116,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) - chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds, - question=chat_question.question, embedding=embedding) + # 延迟 get_table_schema 调用到 init_record 之后,以便记录 LLM 表选择日志 + self._pending_schema_params = { + 'session': session, + 'current_user': current_user, + 'ds': ds, + 'question': chat_question.question, + 'embedding': embedding, + 'history_questions': history_questions, + 'config': config + } self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id) @@ -224,6 +232,22 @@ def init_messages(self): def init_record(self, session: Session) -> ChatRecord: self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question) + + # 如果有延迟的 schema 获取,现在执行(此时 record 已存在,可以记录 LLM 表选择日志) + if hasattr(self, '_pending_schema_params') and self._pending_schema_params: + params = self._pending_schema_params + self.chat_question.db_schema = get_table_schema( + session=params['session'], + current_user=params['current_user'], + ds=params['ds'], + question=params['question'], + embedding=params['embedding'], + history_questions=params['history_questions'], + config=params['config'], + record_id=self.record.id + ) + self._pending_schema_params = None + return self.record def get_record(self): @@ -349,7 +373,9 @@ def generate_recommend_questions_task(self, _session: Session): session=_session, current_user=self.current_user, ds=self.ds, question=self.chat_question.question, - embedding=False) + embedding=False, + config=self.config, + record_id=self.record.id) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number))) @@ -494,7 +520,9 @@ def select_datasource(self, _session: Session): self.ds) self.chat_question.db_schema = get_table_schema(session=_session, current_user=self.current_user, ds=self.ds, - question=self.chat_question.question) + question=self.chat_question.question, + config=self.config, + record_id=self.record.id) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type_name # save chat @@ -997,7 +1025,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True, session=_session, current_user=self.current_user, ds=self.ds, - question=self.chat_question.question) + question=self.chat_question.question, + config=self.config, + record_id=self.record.id) else: self.validate_history_ds(_session) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 153e5088..8525b69d 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -7,8 +7,10 @@ from sqlbot_xpack.permissions.models.ds_rules import DsRules from sqlmodel import select +from apps.ai_model.model_factory import LLMConfig from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user from apps.datasource.embedding.table_embedding import calc_table_embedding +from apps.datasource.llm_select.table_selection import calc_table_llm_selection from apps.datasource.utils.utils import aes_decrypt from apps.db.constant import DB from apps.db.db import get_tables, get_fields, exec_sql, check_connection @@ -416,7 +418,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, - embedding: bool = True) -> str: + embedding: bool = True, history_questions: List[str] = None, + config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str: schema_str = "" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: @@ -425,7 +428,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" tables = [] all_tables = [] # temp save all tables + + # 构建 table_name -> table_obj 映射,用于 LLM 表选择 + table_name_to_obj = {} for obj in table_objs: + table_name_to_obj[obj.table.table_name] = obj + schema_table = '' schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" table_comment = '' @@ -453,16 +461,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat tables.append(t_obj) all_tables.append(t_obj) - # do table embedding - if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: - tables = calc_table_embedding(tables, question) + # do table selection + used_llm_selection = False # 标记是否使用了 LLM 表选择 + if embedding and tables: + if settings.TABLE_LLM_SELECTION_ENABLED and config: + # 使用 LLM 表选择 + selected_table_names = calc_table_llm_selection( + config=config, + table_objs=table_objs, + question=question, + ds_table_relation=ds.table_relation, + history_questions=history_questions, + lang=lang, + session=session, + record_id=record_id + ) + if selected_table_names: + # 根据选中的表名筛选 tables + selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj] + tables = [t for t in tables if t.get('id') in selected_table_ids] + used_llm_selection = True # LLM 成功选择了表 + elif settings.TABLE_EMBEDDING_ENABLED: + # 使用 RAG 表选择 + tables = calc_table_embedding(tables, question, history_questions) # splice schema if tables: for s in tables: schema_str += s.get('schema_table') - # field relation - if tables and ds.table_relation: + # field relation - LLM 表选择模式下不补全关联表,完全信任 LLM 的选择结果 + if tables and ds.table_relation and not used_llm_selection: relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation)) if relations: # Complete the missing table diff --git a/backend/apps/datasource/llm_select/__init__.py b/backend/apps/datasource/llm_select/__init__.py new file mode 100644 index 00000000..972af445 --- /dev/null +++ b/backend/apps/datasource/llm_select/__init__.py @@ -0,0 +1,2 @@ +# Author: SQLBot +# Date: 2025/12/23 diff --git a/backend/apps/datasource/llm_select/table_selection.py b/backend/apps/datasource/llm_select/table_selection.py new file mode 100644 index 00000000..9d36b784 --- /dev/null +++ b/backend/apps/datasource/llm_select/table_selection.py @@ -0,0 +1,243 @@ +# Author: SQLBot +# Date: 2025/12/23 +import json +import traceback +from typing import List, Optional + +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage +from sqlmodel import Session + +from apps.ai_model.model_factory import LLMConfig, LLMFactory +from apps.chat.curd.chat import start_log, end_log, save_table_select_answer +from apps.chat.models.chat_model import OperationEnum +from apps.template.select_table.generator import get_table_selection_template +from common.core.config import settings +from common.utils.utils import SQLBotLogUtil, extract_nested_json + + +def build_table_list_for_llm(table_objs: list) -> str: + """ + 构建 LLM 输入的表列表 JSON + + Args: + table_objs: 表对象列表,每个对象包含 table 属性 + + Returns: + JSON 格式的表列表字符串 + """ + table_list = [] + for obj in table_objs: + table = obj.table + table_info = { + "name": table.table_name, + "comment": table.custom_comment.strip() if table.custom_comment else "" + } + table_list.append(table_info) + return json.dumps(table_list, ensure_ascii=False) + + +def build_table_relations_str(ds_table_relation: list, table_objs: list) -> str: + """ + 构建表关系字符串 + + Args: + ds_table_relation: 数据源的表关系配置 + table_objs: 表对象列表 + + Returns: + 表关系字符串,格式如:表1.字段1 = 表2.字段2 + """ + if not ds_table_relation: + return "" + + # 构建 table_id -> table_name 映射 + table_dict = {} + field_dict = {} + for obj in table_objs: + table = obj.table + table_dict[table.id] = table.table_name + if obj.fields: + for field in obj.fields: + field_dict[field.id] = field.field_name + + relations = list(filter(lambda x: x.get('shape') == 'edge', ds_table_relation)) + if not relations: + return "" + + relation_lines = [] + for rel in relations: + source_table_id = int(rel.get('source', {}).get('cell', 0)) + source_field_id = int(rel.get('source', {}).get('port', 0)) + target_table_id = int(rel.get('target', {}).get('cell', 0)) + target_field_id = int(rel.get('target', {}).get('port', 0)) + + source_table = table_dict.get(source_table_id) + source_field = field_dict.get(source_field_id) + target_table = table_dict.get(target_table_id) + target_field = field_dict.get(target_field_id) + + if source_table and source_field and target_table and target_field: + relation_lines.append(f"{source_table}.{source_field} = {target_table}.{target_field}") + + return "\n".join(relation_lines) + + +def build_history_context(history_questions: List[str]) -> str: + """ + 构建历史问题上下文 + + Args: + history_questions: 历史问题列表 + + Returns: + 格式化的历史问题字符串 + """ + if not history_questions: + return "无" + + max_history = settings.MULTI_TURN_HISTORY_COUNT + recent_history = history_questions[-max_history:] if history_questions else [] + + if not recent_history: + return "无" + + return "\n".join([f"- {q}" for q in recent_history]) + + +def parse_llm_response(response_text: str, all_table_names: list) -> List[str]: + """ + 解析 LLM 返回的 JSON 响应 + + Args: + response_text: LLM 返回的文本 + all_table_names: 所有可用的表名列表(用于验证) + + Returns: + 选中的表名列表 + """ + try: + json_str = extract_nested_json(response_text) + if json_str: + result = json.loads(json_str) + if isinstance(result, dict) and 'tables' in result: + selected_tables = result.get('tables', []) + # 验证表名是否存在 + valid_tables = [t for t in selected_tables if t in all_table_names] + return valid_tables + except Exception as e: + SQLBotLogUtil.error(f"Failed to parse LLM table selection response: {e}") + + return [] + + +def calc_table_llm_selection( + config: LLMConfig, + table_objs: list, + question: str, + ds_table_relation: list = None, + history_questions: List[str] = None, + lang: str = "中文", + session: Session = None, + record_id: int = None +) -> List[str]: + """ + 使用 LLM 选择相关的表 + + Args: + config: LLM 配置 + table_objs: 表对象列表,每个对象包含 table 和 fields 属性 + question: 用户问题 + ds_table_relation: 数据源的表关系配置 + history_questions: 历史问题列表 + lang: 语言 + session: 数据库会话(用于记录日志) + record_id: 记录ID(用于记录日志) + + Returns: + 选中的表名列表,失败时返回空列表 + """ + if not table_objs: + return [] + + current_log = None + + try: + # 获取所有表名 + all_table_names = [obj.table.table_name for obj in table_objs] + + # 构建 LLM 输入 + table_list_str = build_table_list_for_llm(table_objs) + table_relations_str = build_table_relations_str(ds_table_relation, table_objs) + history_context = build_history_context(history_questions) + + # 获取提示词模板 + template = get_table_selection_template() + system_prompt = template['system'].format(lang=lang) + user_prompt = template['user'].format( + table_list=table_list_str, + table_relations=table_relations_str if table_relations_str else "无", + history_questions=history_context, + question=question + ) + + # 构建消息 + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=user_prompt) + ] + + # 记录日志 - 开始 + if session and record_id: + current_log = start_log( + session=session, + ai_modal_id=config.model_id, + ai_modal_name=config.model_name, + operate=OperationEnum.SELECT_TABLE, + record_id=record_id, + full_message=[{'type': msg.type, 'content': msg.content} for msg in messages] + ) + + # 创建 LLM 实例并调用 + llm_instance = LLMFactory.create_llm(config) + llm = llm_instance.llm + + SQLBotLogUtil.info(f"LLM table selection - question: {question}, tables count: {len(table_objs)}") + + # 非流式调用 + response = llm.invoke(messages) + response_text = response.content if hasattr(response, 'content') else str(response) + + SQLBotLogUtil.info(f"LLM table selection response: {response_text}") + + # 记录日志 - 结束 + if session and record_id and current_log: + messages.append(AIMessage(content=response_text)) + token_usage = {} + if hasattr(response, 'usage_metadata') and response.usage_metadata: + token_usage = dict(response.usage_metadata) + end_log( + session=session, + log=current_log, + full_message=[{'type': msg.type, 'content': msg.content} for msg in messages], + reasoning_content=None, + token_usage=token_usage + ) + + # 解析响应 + selected_tables = parse_llm_response(response_text, all_table_names) + + # 保存表选择结果到 ChatRecord + if session and record_id: + save_table_select_answer(session, record_id, response_text) + + if selected_tables: + SQLBotLogUtil.info(f"LLM selected tables: {selected_tables}") + return selected_tables + else: + SQLBotLogUtil.warning("LLM table selection failed: 暂时无法找到表") + return [] + + except Exception as e: + SQLBotLogUtil.error(f"LLM table selection error: {e}") + traceback.print_exc() + return [] diff --git a/backend/apps/template/select_table/__init__.py b/backend/apps/template/select_table/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/template/select_table/generator.py b/backend/apps/template/select_table/generator.py new file mode 100644 index 00000000..83806e25 --- /dev/null +++ b/backend/apps/template/select_table/generator.py @@ -0,0 +1,6 @@ +from apps.template.template import get_base_template + + +def get_table_selection_template(): + template = get_base_template() + return template['template']['table_selection'] diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 4e09c201..09bc0565 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -115,6 +115,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: TABLE_EMBEDDING_COUNT: int = 10 DS_EMBEDDING_COUNT: int = 10 + # LLM table selection settings + TABLE_LLM_SELECTION_ENABLED: bool = True # 是否启用 LLM 表选择(优先于 RAG) + + # Multi-turn embedding settings + MULTI_TURN_EMBEDDING_ENABLED: bool = True + MULTI_TURN_HISTORY_COUNT: int = 3 + ORACLE_CLIENT_PATH: str = '/opt/sqlbot/db_client/oracle_instant_client' @field_validator('SQL_DEBUG', @@ -123,6 +130,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: 'PARSE_REASONING_BLOCK_ENABLED', 'PG_POOL_PRE_PING', 'TABLE_EMBEDDING_ENABLED', + 'TABLE_LLM_SELECTION_ENABLED', + 'MULTI_TURN_EMBEDDING_ENABLED', mode='before') @classmethod def lowercase_bool(cls, v: Any) -> Any: diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index a6f50da1..035b5bb4 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -622,6 +622,87 @@ template: user: | ### sql: {sql} - + ### 子查询映射表: {sub_query} + table_selection: + system: | + ### 请使用语言:{lang} 回答 + + ### 说明: + 你是一个数据库专家,你需要根据用户的提问,从提供的数据库表列表中选择最相关的表,这些表将被用于后续生成SQL查询。 + + ### 表列表格式: + 提供给你的表列表格式为JSON数组: + [ + {{"name": "表名", "comment": "表描述"}}, + ... + ] + + ### 表关系信息: + 如果存在表关系,会以如下格式提供: + + 表1.字段1 = 表2.字段2 + ... + + + ### 要求: + - 仔细分析用户问题,理解其数据查询需求 + - 用户的问题可能比较模糊,你需要尽可能多地猜测可能需要的表 + - 选择与用户问题最相关的表,包括: + 1. 直接包含所需数据的表 + 2. 可能需要用于JOIN关联的表 + 3. 包含筛选条件所需字段的表 + - 选择所有你认为可能相关的表,不限数量,一般在5张以内 + - 以JSON格式返回选中的表名列表,格式为:{{"tables": ["表名1", "表名2", ...]}} + - 如果无法确定相关的表,返回:{{"tables": [], "message": "无法确定相关的表"}} + - 不需要思考过程,请直接返回JSON结果 + + ### 示例: + + + 表列表: [{{"name": "orders", "comment": "订单表"}}, {{"name": "customers", "comment": "客户表"}}, {{"name": "products", "comment": "产品表"}}, {{"name": "order_items", "comment": "订单明细表"}}, {{"name": "categories", "comment": "产品类别表"}}] + 表关系: orders.customer_id = customers.id | order_items.order_id = orders.id | order_items.product_id = products.id + 问题: 查询销售额最高的客户 + + + {{"tables": ["orders", "customers"]}} + + + + + 表列表: [{{"name": "users", "comment": "用户表"}}, {{"name": "departments", "comment": "部门表"}}, {{"name": "salaries", "comment": "薪资表"}}, {{"name": "attendance", "comment": "考勤表"}}] + 表关系: users.dept_id = departments.id | salaries.user_id = users.id + 问题: 各部门薪资情况 + + + {{"tables": ["users", "departments", "salaries"]}} + + + + + 表列表: [{{"name": "sales", "comment": "销售记录"}}, {{"name": "regions", "comment": "区域表"}}, {{"name": "products", "comment": "产品表"}}] + 表关系: sales.region_id = regions.id | sales.product_id = products.id + 问题: 销售数据 + + + {{"tables": ["sales", "regions", "products"]}} + + 问题比较模糊,尽可能选择所有可能相关的表 + + + ### 响应, 请直接返回JSON结果: + ```json + + user: | + ### 表列表: + {table_list} + + ### 表关系: + {table_relations} + + ### 历史问题(用于理解上下文): + {history_questions} + + ### 当前问题: + {question}