diff --git a/backend/alembic/versions/054_add_table_select_answer.py b/backend/alembic/versions/054_add_table_select_answer.py
new file mode 100644
index 00000000..e75e24f0
--- /dev/null
+++ b/backend/alembic/versions/054_add_table_select_answer.py
@@ -0,0 +1,22 @@
+"""add table_select_answer column to chat_record
+
+Revision ID: 054_table_select
+Revises: 5755c0b95839
+Create Date: 2025-12-23
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = '054_table_select'
+down_revision = '5755c0b95839'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ op.add_column('chat_record', sa.Column('table_select_answer', sa.Text(), nullable=True))
+
+
+def downgrade():
+ op.drop_column('chat_record', 'table_select_answer')
diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py
index f93ad18e..94e97cbc 100644
--- a/backend/apps/chat/curd/chat.py
+++ b/backend/apps/chat/curd/chat.py
@@ -649,6 +649,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s
return result
+def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
+ """保存 LLM 表选择的结果到 ChatRecord"""
+ if not record_id:
+ raise Exception("Record id cannot be None")
+ record = get_chat_record_by_id(session, record_id)
+
+ record.table_select_answer = answer
+
+ result = ChatRecord(**record.model_dump())
+
+ stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
+ table_select_answer=record.table_select_answer,
+ )
+
+ session.execute(stmt)
+ session.commit()
+
+ return result
+
+
def save_recommend_question_answer(session: SessionDep, record_id: int,
answer: dict = None) -> ChatRecord:
if not record_id:
diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py
index 33872005..25d44dbc 100644
--- a/backend/apps/chat/models/chat_model.py
+++ b/backend/apps/chat/models/chat_model.py
@@ -40,6 +40,7 @@ class OperationEnum(Enum):
GENERATE_SQL_WITH_PERMISSIONS = '5'
CHOOSE_DATASOURCE = '6'
GENERATE_DYNAMIC_SQL = '7'
+ SELECT_TABLE = '8' # LLM 表选择
class ChatFinishStep(Enum):
@@ -112,6 +113,7 @@ class ChatRecord(SQLModel, table=True):
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
+ table_select_answer: str = Field(sa_column=Column(Text, nullable=True))
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
error: str = Field(sa_column=Column(Text, nullable=True))
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
@@ -137,6 +139,7 @@ class ChatRecordResult(BaseModel):
predict_data: Optional[str] = None
recommended_question: Optional[str] = None
datasource_select_answer: Optional[str] = None
+ table_select_answer: Optional[str] = None
finish: Optional[bool] = None
error: Optional[str] = None
analysis_record_id: Optional[int] = None
diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py
index b7402ab1..df8e3d7a 100644
--- a/backend/apps/chat/task/llm.py
+++ b/backend/apps/chat/task/llm.py
@@ -116,8 +116,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
if not ds:
raise SingleMessageError("No available datasource configuration found")
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
- chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
- question=chat_question.question, embedding=embedding)
+ # 延迟 get_table_schema 调用到 init_record 之后,以便记录 LLM 表选择日志
+ self._pending_schema_params = {
+ 'session': session,
+ 'current_user': current_user,
+ 'ds': ds,
+ 'question': chat_question.question,
+ 'embedding': embedding,
+ 'history_questions': history_questions,
+ 'config': config
+ }
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
@@ -224,6 +232,22 @@ def init_messages(self):
def init_record(self, session: Session) -> ChatRecord:
self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question)
+
+ # 如果有延迟的 schema 获取,现在执行(此时 record 已存在,可以记录 LLM 表选择日志)
+ if hasattr(self, '_pending_schema_params') and self._pending_schema_params:
+ params = self._pending_schema_params
+ self.chat_question.db_schema = get_table_schema(
+ session=params['session'],
+ current_user=params['current_user'],
+ ds=params['ds'],
+ question=params['question'],
+ embedding=params['embedding'],
+ history_questions=params['history_questions'],
+ config=params['config'],
+ record_id=self.record.id
+ )
+ self._pending_schema_params = None
+
return self.record
def get_record(self):
@@ -349,7 +373,9 @@ def generate_recommend_questions_task(self, _session: Session):
session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
- embedding=False)
+ embedding=False,
+ config=self.config,
+ record_id=self.record.id)
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
@@ -494,7 +520,9 @@ def select_datasource(self, _session: Session):
self.ds)
self.chat_question.db_schema = get_table_schema(session=_session,
current_user=self.current_user, ds=self.ds,
- question=self.chat_question.question)
+ question=self.chat_question.question,
+ config=self.config,
+ record_id=self.record.id)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type_name
# save chat
@@ -997,7 +1025,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
session=_session,
current_user=self.current_user,
ds=self.ds,
- question=self.chat_question.question)
+ question=self.chat_question.question,
+ config=self.config,
+ record_id=self.record.id)
else:
self.validate_history_ds(_session)
diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py
index 153e5088..8525b69d 100644
--- a/backend/apps/datasource/crud/datasource.py
+++ b/backend/apps/datasource/crud/datasource.py
@@ -7,8 +7,10 @@
from sqlbot_xpack.permissions.models.ds_rules import DsRules
from sqlmodel import select
+from apps.ai_model.model_factory import LLMConfig
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
from apps.datasource.embedding.table_embedding import calc_table_embedding
+from apps.datasource.llm_select.table_selection import calc_table_llm_selection
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
@@ -416,7 +418,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
- embedding: bool = True) -> str:
+ embedding: bool = True, history_questions: List[str] = None,
+ config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
@@ -425,7 +428,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
all_tables = [] # temp save all tables
+
+ # 构建 table_name -> table_obj 映射,用于 LLM 表选择
+ table_name_to_obj = {}
for obj in table_objs:
+ table_name_to_obj[obj.table.table_name] = obj
+
schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
@@ -453,16 +461,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables.append(t_obj)
all_tables.append(t_obj)
- # do table embedding
- if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
- tables = calc_table_embedding(tables, question)
+ # do table selection
+ used_llm_selection = False # 标记是否使用了 LLM 表选择
+ if embedding and tables:
+ if settings.TABLE_LLM_SELECTION_ENABLED and config:
+ # 使用 LLM 表选择
+ selected_table_names = calc_table_llm_selection(
+ config=config,
+ table_objs=table_objs,
+ question=question,
+ ds_table_relation=ds.table_relation,
+ history_questions=history_questions,
+ lang=lang,
+ session=session,
+ record_id=record_id
+ )
+ if selected_table_names:
+ # 根据选中的表名筛选 tables
+ selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj]
+ tables = [t for t in tables if t.get('id') in selected_table_ids]
+ used_llm_selection = True # LLM 成功选择了表
+ elif settings.TABLE_EMBEDDING_ENABLED:
+ # 使用 RAG 表选择
+ tables = calc_table_embedding(tables, question, history_questions)
# splice schema
if tables:
for s in tables:
schema_str += s.get('schema_table')
- # field relation
- if tables and ds.table_relation:
+ # field relation - LLM 表选择模式下不补全关联表,完全信任 LLM 的选择结果
+ if tables and ds.table_relation and not used_llm_selection:
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
if relations:
# Complete the missing table
diff --git a/backend/apps/datasource/llm_select/__init__.py b/backend/apps/datasource/llm_select/__init__.py
new file mode 100644
index 00000000..972af445
--- /dev/null
+++ b/backend/apps/datasource/llm_select/__init__.py
@@ -0,0 +1,2 @@
+# Author: SQLBot
+# Date: 2025/12/23
diff --git a/backend/apps/datasource/llm_select/table_selection.py b/backend/apps/datasource/llm_select/table_selection.py
new file mode 100644
index 00000000..9d36b784
--- /dev/null
+++ b/backend/apps/datasource/llm_select/table_selection.py
@@ -0,0 +1,243 @@
+# Author: SQLBot
+# Date: 2025/12/23
+import json
+import traceback
+from typing import List, Optional
+
+from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
+from sqlmodel import Session
+
+from apps.ai_model.model_factory import LLMConfig, LLMFactory
+from apps.chat.curd.chat import start_log, end_log, save_table_select_answer
+from apps.chat.models.chat_model import OperationEnum
+from apps.template.select_table.generator import get_table_selection_template
+from common.core.config import settings
+from common.utils.utils import SQLBotLogUtil, extract_nested_json
+
+
+def build_table_list_for_llm(table_objs: list) -> str:
+ """
+ 构建 LLM 输入的表列表 JSON
+
+ Args:
+ table_objs: 表对象列表,每个对象包含 table 属性
+
+ Returns:
+ JSON 格式的表列表字符串
+ """
+ table_list = []
+ for obj in table_objs:
+ table = obj.table
+ table_info = {
+ "name": table.table_name,
+ "comment": table.custom_comment.strip() if table.custom_comment else ""
+ }
+ table_list.append(table_info)
+ return json.dumps(table_list, ensure_ascii=False)
+
+
+def build_table_relations_str(ds_table_relation: list, table_objs: list) -> str:
+ """
+ 构建表关系字符串
+
+ Args:
+ ds_table_relation: 数据源的表关系配置
+ table_objs: 表对象列表
+
+ Returns:
+ 表关系字符串,格式如:表1.字段1 = 表2.字段2
+ """
+ if not ds_table_relation:
+ return ""
+
+ # 构建 table_id -> table_name 映射
+ table_dict = {}
+ field_dict = {}
+ for obj in table_objs:
+ table = obj.table
+ table_dict[table.id] = table.table_name
+ if obj.fields:
+ for field in obj.fields:
+ field_dict[field.id] = field.field_name
+
+ relations = list(filter(lambda x: x.get('shape') == 'edge', ds_table_relation))
+ if not relations:
+ return ""
+
+ relation_lines = []
+ for rel in relations:
+ source_table_id = int(rel.get('source', {}).get('cell', 0))
+ source_field_id = int(rel.get('source', {}).get('port', 0))
+ target_table_id = int(rel.get('target', {}).get('cell', 0))
+ target_field_id = int(rel.get('target', {}).get('port', 0))
+
+ source_table = table_dict.get(source_table_id)
+ source_field = field_dict.get(source_field_id)
+ target_table = table_dict.get(target_table_id)
+ target_field = field_dict.get(target_field_id)
+
+ if source_table and source_field and target_table and target_field:
+ relation_lines.append(f"{source_table}.{source_field} = {target_table}.{target_field}")
+
+ return "\n".join(relation_lines)
+
+
+def build_history_context(history_questions: List[str]) -> str:
+ """
+ 构建历史问题上下文
+
+ Args:
+ history_questions: 历史问题列表
+
+ Returns:
+ 格式化的历史问题字符串
+ """
+ if not history_questions:
+ return "无"
+
+ max_history = settings.MULTI_TURN_HISTORY_COUNT
+ recent_history = history_questions[-max_history:] if history_questions else []
+
+ if not recent_history:
+ return "无"
+
+ return "\n".join([f"- {q}" for q in recent_history])
+
+
+def parse_llm_response(response_text: str, all_table_names: list) -> List[str]:
+ """
+ 解析 LLM 返回的 JSON 响应
+
+ Args:
+ response_text: LLM 返回的文本
+ all_table_names: 所有可用的表名列表(用于验证)
+
+ Returns:
+ 选中的表名列表
+ """
+ try:
+ json_str = extract_nested_json(response_text)
+ if json_str:
+ result = json.loads(json_str)
+ if isinstance(result, dict) and 'tables' in result:
+ selected_tables = result.get('tables', [])
+ # 验证表名是否存在
+ valid_tables = [t for t in selected_tables if t in all_table_names]
+ return valid_tables
+ except Exception as e:
+ SQLBotLogUtil.error(f"Failed to parse LLM table selection response: {e}")
+
+ return []
+
+
+def calc_table_llm_selection(
+ config: LLMConfig,
+ table_objs: list,
+ question: str,
+ ds_table_relation: list = None,
+ history_questions: List[str] = None,
+ lang: str = "中文",
+ session: Session = None,
+ record_id: int = None
+) -> List[str]:
+ """
+ 使用 LLM 选择相关的表
+
+ Args:
+ config: LLM 配置
+ table_objs: 表对象列表,每个对象包含 table 和 fields 属性
+ question: 用户问题
+ ds_table_relation: 数据源的表关系配置
+ history_questions: 历史问题列表
+ lang: 语言
+ session: 数据库会话(用于记录日志)
+ record_id: 记录ID(用于记录日志)
+
+ Returns:
+ 选中的表名列表,失败时返回空列表
+ """
+ if not table_objs:
+ return []
+
+ current_log = None
+
+ try:
+ # 获取所有表名
+ all_table_names = [obj.table.table_name for obj in table_objs]
+
+ # 构建 LLM 输入
+ table_list_str = build_table_list_for_llm(table_objs)
+ table_relations_str = build_table_relations_str(ds_table_relation, table_objs)
+ history_context = build_history_context(history_questions)
+
+ # 获取提示词模板
+ template = get_table_selection_template()
+ system_prompt = template['system'].format(lang=lang)
+ user_prompt = template['user'].format(
+ table_list=table_list_str,
+ table_relations=table_relations_str if table_relations_str else "无",
+ history_questions=history_context,
+ question=question
+ )
+
+ # 构建消息
+ messages = [
+ SystemMessage(content=system_prompt),
+ HumanMessage(content=user_prompt)
+ ]
+
+ # 记录日志 - 开始
+ if session and record_id:
+ current_log = start_log(
+ session=session,
+ ai_modal_id=config.model_id,
+ ai_modal_name=config.model_name,
+ operate=OperationEnum.SELECT_TABLE,
+ record_id=record_id,
+ full_message=[{'type': msg.type, 'content': msg.content} for msg in messages]
+ )
+
+ # 创建 LLM 实例并调用
+ llm_instance = LLMFactory.create_llm(config)
+ llm = llm_instance.llm
+
+ SQLBotLogUtil.info(f"LLM table selection - question: {question}, tables count: {len(table_objs)}")
+
+ # 非流式调用
+ response = llm.invoke(messages)
+ response_text = response.content if hasattr(response, 'content') else str(response)
+
+ SQLBotLogUtil.info(f"LLM table selection response: {response_text}")
+
+ # 记录日志 - 结束
+ if session and record_id and current_log:
+ messages.append(AIMessage(content=response_text))
+ token_usage = {}
+ if hasattr(response, 'usage_metadata') and response.usage_metadata:
+ token_usage = dict(response.usage_metadata)
+ end_log(
+ session=session,
+ log=current_log,
+ full_message=[{'type': msg.type, 'content': msg.content} for msg in messages],
+ reasoning_content=None,
+ token_usage=token_usage
+ )
+
+ # 解析响应
+ selected_tables = parse_llm_response(response_text, all_table_names)
+
+ # 保存表选择结果到 ChatRecord
+ if session and record_id:
+ save_table_select_answer(session, record_id, response_text)
+
+ if selected_tables:
+ SQLBotLogUtil.info(f"LLM selected tables: {selected_tables}")
+ return selected_tables
+ else:
+ SQLBotLogUtil.warning("LLM table selection failed: 暂时无法找到表")
+ return []
+
+ except Exception as e:
+ SQLBotLogUtil.error(f"LLM table selection error: {e}")
+ traceback.print_exc()
+ return []
diff --git a/backend/apps/template/select_table/__init__.py b/backend/apps/template/select_table/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/apps/template/select_table/generator.py b/backend/apps/template/select_table/generator.py
new file mode 100644
index 00000000..83806e25
--- /dev/null
+++ b/backend/apps/template/select_table/generator.py
@@ -0,0 +1,6 @@
+from apps.template.template import get_base_template
+
+
+def get_table_selection_template():
+ template = get_base_template()
+ return template['template']['table_selection']
diff --git a/backend/common/core/config.py b/backend/common/core/config.py
index 4e09c201..09bc0565 100644
--- a/backend/common/core/config.py
+++ b/backend/common/core/config.py
@@ -115,6 +115,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
TABLE_EMBEDDING_COUNT: int = 10
DS_EMBEDDING_COUNT: int = 10
+ # LLM table selection settings
+ TABLE_LLM_SELECTION_ENABLED: bool = True # 是否启用 LLM 表选择(优先于 RAG)
+
+ # Multi-turn embedding settings
+ MULTI_TURN_EMBEDDING_ENABLED: bool = True
+ MULTI_TURN_HISTORY_COUNT: int = 3
+
ORACLE_CLIENT_PATH: str = '/opt/sqlbot/db_client/oracle_instant_client'
@field_validator('SQL_DEBUG',
@@ -123,6 +130,8 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
'PARSE_REASONING_BLOCK_ENABLED',
'PG_POOL_PRE_PING',
'TABLE_EMBEDDING_ENABLED',
+ 'TABLE_LLM_SELECTION_ENABLED',
+ 'MULTI_TURN_EMBEDDING_ENABLED',
mode='before')
@classmethod
def lowercase_bool(cls, v: Any) -> Any:
diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml
index a6f50da1..035b5bb4 100644
--- a/backend/templates/template.yaml
+++ b/backend/templates/template.yaml
@@ -622,6 +622,87 @@ template:
user: |
### sql:
{sql}
-
+
### 子查询映射表:
{sub_query}
+ table_selection:
+ system: |
+ ### 请使用语言:{lang} 回答
+
+ ### 说明:
+ 你是一个数据库专家,你需要根据用户的提问,从提供的数据库表列表中选择最相关的表,这些表将被用于后续生成SQL查询。
+
+ ### 表列表格式:
+ 提供给你的表列表格式为JSON数组:
+ [
+ {{"name": "表名", "comment": "表描述"}},
+ ...
+ ]
+
+ ### 表关系信息:
+ 如果存在表关系,会以如下格式提供:
+
+ 表1.字段1 = 表2.字段2
+ ...
+
+
+ ### 要求:
+ - 仔细分析用户问题,理解其数据查询需求
+ - 用户的问题可能比较模糊,你需要尽可能多地猜测可能需要的表
+ - 选择与用户问题最相关的表,包括:
+ 1. 直接包含所需数据的表
+ 2. 可能需要用于JOIN关联的表
+ 3. 包含筛选条件所需字段的表
+ - 选择所有你认为可能相关的表,不限数量,一般在5张以内
+ - 以JSON格式返回选中的表名列表,格式为:{{"tables": ["表名1", "表名2", ...]}}
+ - 如果无法确定相关的表,返回:{{"tables": [], "message": "无法确定相关的表"}}
+ - 不需要思考过程,请直接返回JSON结果
+
+ ### 示例:
+
+
+ 表列表: [{{"name": "orders", "comment": "订单表"}}, {{"name": "customers", "comment": "客户表"}}, {{"name": "products", "comment": "产品表"}}, {{"name": "order_items", "comment": "订单明细表"}}, {{"name": "categories", "comment": "产品类别表"}}]
+ 表关系: orders.customer_id = customers.id | order_items.order_id = orders.id | order_items.product_id = products.id
+ 问题: 查询销售额最高的客户
+
+
+
+
+
+ 表列表: [{{"name": "users", "comment": "用户表"}}, {{"name": "departments", "comment": "部门表"}}, {{"name": "salaries", "comment": "薪资表"}}, {{"name": "attendance", "comment": "考勤表"}}]
+ 表关系: users.dept_id = departments.id | salaries.user_id = users.id
+ 问题: 各部门薪资情况
+
+
+
+
+
+ 表列表: [{{"name": "sales", "comment": "销售记录"}}, {{"name": "regions", "comment": "区域表"}}, {{"name": "products", "comment": "产品表"}}]
+ 表关系: sales.region_id = regions.id | sales.product_id = products.id
+ 问题: 销售数据
+
+
+ 问题比较模糊,尽可能选择所有可能相关的表
+
+
+ ### 响应, 请直接返回JSON结果:
+ ```json
+
+ user: |
+ ### 表列表:
+ {table_list}
+
+ ### 表关系:
+ {table_relations}
+
+ ### 历史问题(用于理解上下文):
+ {history_questions}
+
+ ### 当前问题:
+ {question}