Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backend/alembic/versions/054_add_table_select_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""add table_select_answer column to chat_record

Revision ID: 054_table_select
Revises: 5755c0b95839
Create Date: 2025-12-23

"""
from alembic import op
import sqlalchemy as sa

revision = '054_table_select'
down_revision = '5755c0b95839'
branch_labels = None
depends_on = None


def upgrade():
    """Add the nullable ``table_select_answer`` text column to ``chat_record``."""
    new_column = sa.Column('table_select_answer', sa.Text(), nullable=True)
    op.add_column('chat_record', new_column)


def downgrade():
    """Drop the ``table_select_answer`` column from ``chat_record``."""
    table_name, column_name = 'chat_record', 'table_select_answer'
    op.drop_column(table_name, column_name)
20 changes: 20 additions & 0 deletions backend/apps/chat/curd/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s
return result


def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
    """Save the LLM table-selection result onto a ChatRecord.

    The record is looked up with ``get_chat_record_by_id``, its
    ``table_select_answer`` attribute is set to ``answer``, and an UPDATE
    touching only that column is executed and committed.

    Args:
        session: Database session used for the lookup, the UPDATE and the commit.
        record_id: Id of the chat record to update.
        answer: Table-selection answer text to store.

    Returns:
        A new ``ChatRecord`` built from the looked-up record's dumped fields,
        created after the answer is assigned and before the UPDATE runs.

    Raises:
        Exception: If ``record_id`` is falsy (``None`` or ``0``).
    """
    if not record_id:
        raise Exception("Record id cannot be None")

    chat_record = get_chat_record_by_id(session, record_id)
    chat_record.table_select_answer = answer

    # Detached copy built before the commit; presumably so the returned
    # object stays usable once the session commits -- TODO confirm.
    snapshot = ChatRecord(**chat_record.model_dump())

    session.execute(
        update(ChatRecord)
        .where(and_(ChatRecord.id == chat_record.id))
        .values(table_select_answer=chat_record.table_select_answer)
    )
    session.commit()

    return snapshot


def save_recommend_question_answer(session: SessionDep, record_id: int,
answer: dict = None) -> ChatRecord:
if not record_id:
Expand Down
3 changes: 3 additions & 0 deletions backend/apps/chat/models/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class OperationEnum(Enum):
GENERATE_SQL_WITH_PERMISSIONS = '5'
CHOOSE_DATASOURCE = '6'
GENERATE_DYNAMIC_SQL = '7'
SELECT_TABLE = '8' # LLM 表选择


class ChatFinishStep(Enum):
Expand Down Expand Up @@ -112,6 +113,7 @@ class ChatRecord(SQLModel, table=True):
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
table_select_answer: str = Field(sa_column=Column(Text, nullable=True))
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
error: str = Field(sa_column=Column(Text, nullable=True))
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
Expand All @@ -137,6 +139,7 @@ class ChatRecordResult(BaseModel):
predict_data: Optional[str] = None
recommended_question: Optional[str] = None
datasource_select_answer: Optional[str] = None
table_select_answer: Optional[str] = None
finish: Optional[bool] = None
error: Optional[str] = None
analysis_record_id: Optional[int] = None
Expand Down
40 changes: 35 additions & 5 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
if not ds:
raise SingleMessageError("No available datasource configuration found")
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
question=chat_question.question, embedding=embedding)
# 延迟 get_table_schema 调用到 init_record 之后,以便记录 LLM 表选择日志
self._pending_schema_params = {
'session': session,
'current_user': current_user,
'ds': ds,
'question': chat_question.question,
'embedding': embedding,
'history_questions': history_questions,
'config': config
}

self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
Expand Down Expand Up @@ -224,6 +232,22 @@ def init_messages(self):

def init_record(self, session: Session) -> ChatRecord:
    """Create the chat record for the current question and resolve any deferred schema.

    The record is saved first so that its id exists when the schema is
    built; the id is passed to ``get_table_schema`` as ``record_id``. If
    ``_pending_schema_params`` is missing or empty, only the record is
    created.

    Args:
        session: Session used to save the question record.

    Returns:
        The record stored on ``self.record``.
    """
    self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question)

    deferred = getattr(self, '_pending_schema_params', None)
    if not deferred:
        return self.record

    # NOTE(review): the schema lookup uses the session stored in the deferred
    # parameters, not the ``session`` argument of this method -- confirm that
    # this is intended.
    self.chat_question.db_schema = get_table_schema(
        session=deferred['session'],
        current_user=deferred['current_user'],
        ds=deferred['ds'],
        question=deferred['question'],
        embedding=deferred['embedding'],
        history_questions=deferred['history_questions'],
        config=deferred['config'],
        record_id=self.record.id,
    )
    # Clear the deferred parameters so the schema is not fetched again.
    self._pending_schema_params = None

    return self.record

def get_record(self):
Expand Down Expand Up @@ -349,7 +373,9 @@ def generate_recommend_questions_task(self, _session: Session):
session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
embedding=False)
embedding=False,
config=self.config,
record_id=self.record.id)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
Expand Down Expand Up @@ -494,7 +520,9 @@ def select_datasource(self, _session: Session):
self.ds)
self.chat_question.db_schema = get_table_schema(session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question)
question=self.chat_question.question,
config=self.config,
record_id=self.record.id)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type_name
# save chat
Expand Down Expand Up @@ -997,7 +1025,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
session=_session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question)
question=self.chat_question.question,
config=self.config,
record_id=self.record.id)
else:
self.validate_history_ds(_session)

Expand Down
40 changes: 34 additions & 6 deletions backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from sqlbot_xpack.permissions.models.ds_rules import DsRules
from sqlmodel import select

from apps.ai_model.model_factory import LLMConfig
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
from apps.datasource.embedding.table_embedding import calc_table_embedding
from apps.datasource.llm_select.table_selection import calc_table_llm_selection
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
Expand Down Expand Up @@ -416,7 +418,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True) -> str:
embedding: bool = True, history_questions: List[str] = None,
config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
Expand All @@ -425,7 +428,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
all_tables = [] # temp save all tables

# 构建 table_name -> table_obj 映射,用于 LLM 表选择
table_name_to_obj = {}
for obj in table_objs:
table_name_to_obj[obj.table.table_name] = obj

schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
Expand Down Expand Up @@ -453,16 +461,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables.append(t_obj)
all_tables.append(t_obj)

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = calc_table_embedding(tables, question)
# do table selection
used_llm_selection = False # 标记是否使用了 LLM 表选择
if embedding and tables:
if settings.TABLE_LLM_SELECTION_ENABLED and config:
# 使用 LLM 表选择
selected_table_names = calc_table_llm_selection(
config=config,
table_objs=table_objs,
question=question,
ds_table_relation=ds.table_relation,
history_questions=history_questions,
lang=lang,
session=session,
record_id=record_id
)
if selected_table_names:
# 根据选中的表名筛选 tables
selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj]
tables = [t for t in tables if t.get('id') in selected_table_ids]
used_llm_selection = True # LLM 成功选择了表
elif settings.TABLE_EMBEDDING_ENABLED:
# 使用 RAG 表选择
tables = calc_table_embedding(tables, question, history_questions)
# splice schema
if tables:
for s in tables:
schema_str += s.get('schema_table')

# field relation
if tables and ds.table_relation:
# field relation - LLM 表选择模式下不补全关联表,完全信任 LLM 的选择结果
if tables and ds.table_relation and not used_llm_selection:
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
if relations:
# Complete the missing table
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/datasource/llm_select/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Author: SQLBot
# Date: 2025/12/23
Loading