From 70a72f8dd007453fbe8832744bd94fd8d58b3147 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 22:47:48 +0900 Subject: [PATCH 01/11] =?UTF-8?q?docs=20:=20.gitignore=20table=5Finfo=5Fdb?= =?UTF-8?q?=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 ++- llm_utils/graph.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 34d9aae..e36e2b6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ dist/ .venv/ test_lhm/ .cursorignore -.vscode \ No newline at end of file +.vscode +table_info_db \ No newline at end of file diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 69a10b9..cbfe633 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -31,6 +31,7 @@ class QueryMakerState(TypedDict): searched_tables: dict[str, dict[str, str]] best_practice_query: str refined_input: str + question_profile: dict generated_query: str retriever_name: str top_n: int From f347482d73778959def5ef6435a57864055a27fc Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 22:50:26 +0900 Subject: [PATCH 02/11] =?UTF-8?q?feat(graph):=20add=20profile=E2=80=91awar?= =?UTF-8?q?e=20query=E2=80=91refiner=20chain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 신규 함수 추가 - 사용자 입력·테이블 목록에 질문 프로파일()을 포함하는 프롬프트 조합 - 구체화 질문 품질 향상 및 LLM 컨텍스트 강화 --- llm_utils/chains.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index a0a5f27..54fc98d 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -72,8 +72,38 @@ def create_query_maker_chain(llm): return query_maker_prompt | llm +def create_query_refiner_with_profile_chain(llm): + prompt = get_prompt_template("query_refiner_prompt") + + tool_choice_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + MessagesPlaceholder(variable_name="user_input"), + SystemMessagePromptTemplate.from_template( + "다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:" + ), + MessagesPlaceholder(variable_name="searched_tables"), + # 프로파일 정보 입력 + SystemMessagePromptTemplate.from_template( + "다음은 사용자의 질문을 분석한 프로파일 정보입니다." + ), + MessagesPlaceholder("profile_prompt"), + SystemMessagePromptTemplate.from_template( + """ + 위 사용자의 입력과 위 조건을 바탕으로 + 분석 관점에서 **충분히 답변 가능한 형태**로 + "구체화된 질문"을 작성하세요. 
+ """, + ), + ] + ) + + return tool_choice_prompt | llm + + query_refiner_chain = create_query_refiner_chain(llm) query_maker_chain = create_query_maker_chain(llm) +query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm) if __name__ == "__main__": query_refiner_chain.invoke() From b423bfcb8110d7a3c1a902271d8f8d41249a6605 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 23:07:03 +0900 Subject: [PATCH 03/11] =?UTF-8?q?feat(chains):=20add=20QuestionProfile=20m?= =?UTF-8?q?odel=20and=20profile=E2=80=91extraction=20chain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Pydantic 모델 `QuestionProfile` 정의 - 시계열·집계·필터·그룹화·랭킹·기간비교 플래그와 `intent_type` 포함 * `profile_prompt` 템플릿 추가 * `create_profile_extraction_chain()` 함수로 LLM 기반 프로파일 추출 지원 --- llm_utils/chains.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index 54fc98d..d5d9c57 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -4,12 +4,14 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) +from pydantic import BaseModel, Field from .llm_factory import get_llm from dotenv import load_dotenv from prompt.template_loader import get_prompt_template + env_path = os.path.join(os.getcwd(), ".env") if os.path.exists(env_path): @@ -20,6 +22,16 @@ llm = get_llm() +class QuestionProfile(BaseModel): + is_timeseries: bool = Field(description="시계열 분석 필요 여부") + is_aggregation: bool = Field(description="집계 함수 필요 여부") + has_filter: bool = Field(description="조건 필터 필요 여부") + is_grouped: bool = Field(description="그룹화 필요 여부") + has_ranking: bool = Field(description="정렬/순위 필요 여부") + has_temporal_comparison: bool = Field(description="기간 비교 포함 여부") + intent_type: str = Field(description="질문의 주요 의도 유형") + + def create_query_refiner_chain(llm): prompt = get_prompt_template("query_refiner_prompt") tool_choice_prompt = ChatPromptTemplate.from_messages( @@ -101,6 +113,33 @@ def create_query_refiner_with_profile_chain(llm): return tool_choice_prompt | llm +from langchain.prompts import PromptTemplate + +profile_prompt = PromptTemplate( + input_variables=["question"], + template=""" +You are an assistant that analyzes a user question and extracts the following profiles as JSON: +- is_timeseries (boolean) +- is_aggregation (boolean) +- has_filter (boolean) +- is_grouped (boolean) +- has_ranking (boolean) +- has_temporal_comparison (boolean) +- intent_type (one of: trend, lookup, comparison, distribution) + +Return only valid JSON matching the QuestionProfile schema. 
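+
+Example output for the question "지난 30일 일별 주문 수 추이" (illustrative):
+{{"is_timeseries": true, "is_aggregation": true, "has_filter": true, "is_grouped": true, "has_ranking": false, "has_temporal_comparison": false, "intent_type": "trend"}}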
+ +Question: +{question} +""".strip(), +) + + +def create_profile_extraction_chain(llm): + chain = profile_prompt | llm.with_structured_output(QuestionProfile) + return chain + + query_refiner_chain = create_query_refiner_chain(llm) query_maker_chain = create_query_maker_chain(llm) query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm) From 885e3f558729d339b74c4d8be70595725d9b2085 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 23:14:56 +0900 Subject: [PATCH 04/11] =?UTF-8?q?feat(graph)=20:=20PROFILE=5FEXTRACTION=20?= =?UTF-8?q?=EB=85=B8=EB=93=9C=20=EC=B6=94=EA=B0=80=20=EC=82=AC=EC=9A=A9?= =?UTF-8?q?=EC=9E=90=EC=9D=98=20=EC=A7=88=EB=AC=B8=20=ED=8A=B9=EC=A7=95?= =?UTF-8?q?=EC=9D=84=20=EC=A0=95=EC=9D=98=EB=90=9C=20=EC=9C=A0=ED=98=95?= =?UTF-8?q?=EC=97=90=20=EB=94=B0=EB=9D=BC=20=EC=B6=94=EC=B6=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/graph.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/llm_utils/graph.py b/llm_utils/graph.py index cbfe633..2513179 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -11,6 +11,7 @@ from llm_utils.chains import ( query_refiner_chain, query_maker_chain, + query_refiner_with_profile_chain, ) from llm_utils.tools import get_info_from_db @@ -38,6 +39,17 @@ class QueryMakerState(TypedDict): device: str +# 노드 함수: PROFILE_EXTRACTION 노드 +def profile_extraction_node(state: QueryMakerState): + + result = query_refiner_with_profile_chain.invoke( + {"question": state["messages"][0].content} + ) + + state["question_profile"] = result + return state + + # 노드 함수: QUERY_REFINER 노드 def query_refiner_node(state: QueryMakerState): res = query_refiner_chain.invoke( From 5e5accd4bc1817c07e5bae208ecd820cf583d6ad Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 23:17:24 +0900 Subject: [PATCH 05/11] =?UTF-8?q?feat(graph)=20:=20CONTEXT=5FENRICHMENT=20?= =?UTF-8?q?=EB=85=B8=EB=93=9C=20=EC=B6=94=EA=B0=80=20refined=EB=90=9C=20?= =?UTF-8?q?=EC=A7=88=EB=AC=B8=EC=97=90=EC=84=9C=20=ED=85=8C=EC=9D=B4?= =?UTF-8?q?=EB=B8=94=20=EC=A0=95=EB=B3=B4=EB=A5=BC=20=EC=A0=95=EB=A0=AC?= =?UTF-8?q?=ED=95=98=EB=8A=94=20=EB=85=B8=EB=93=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/graph.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 2513179..3afa212 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -65,6 +65,56 @@ def query_refiner_node(state: QueryMakerState): return state +# 노드 함수: CONTEXT_ENRICHMENT 노드 +def context_enrichment_node(state: QueryMakerState): + + import json + + searched_tables = state["searched_tables"] + searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2) + + question_profile = state["question_profile"].model_dump() + question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2) + + from langchain.prompts import PromptTemplate + + enrichment_prompt = PromptTemplate( + input_variables=["refined_question", "profiles", "related_tables"], + template=""" + You are a smart assistant that takes a user question and enriches it using: + 1. Question profiles: {profiles} + 2. Table metadata (names, columns, descriptions): + {related_tables} + + Tasks: + - Correct any wrong terms by matching them to actual column names. + - If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). 
+ - If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). + - Output the enriched question only. + + Refined question: + {refined_question} + + Using the refined version for enrichment, but keep original intent in mind. + """.strip(), + ) + + llm = get_llm() + prompt = enrichment_prompt.format_prompt( + refined_question=state["refined_input"], + profiles=question_profile_json, + related_tables=searched_tables_json, + ) + enriched_text = llm.invoke(prompt.to_messages()) + + state["refined_input"] = enriched_text + from langchain_core.messages import HumanMessage + + state["messages"].append(HumanMessage(content=enriched_text.content)) + + return state + + def get_table_info_node(state: QueryMakerState): # retriever_name과 top_n을 이용하여 검색 수행 documents_dict = search_tables( From ffd9dc8a4a40c4f2464f638f26c10391c06dcce6 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sun, 11 May 2025 23:27:12 +0900 Subject: [PATCH 06/11] =?UTF-8?q?feat(graph):=20query=5Frefiner=5Fwith=5Fp?= =?UTF-8?q?rofile=5Fnode=20=EC=B6=94=EA=B0=80=20\=20=EC=A7=88=EB=AC=B8=20?= =?UTF-8?q?=EC=9E=AC=EC=A0=95=EC=9D=98=EC=8B=9C=20=ED=8A=B9=EC=A7=95=20?= =?UTF-8?q?=EC=A0=95=EB=B3=B4=20=ED=99=9C=EC=9A=A9=ED=95=B4=20=EC=9E=AC?= =?UTF-8?q?=EC=A0=95=EC=9D=98=ED=95=98=EB=8A=94=20=EB=85=B8=EB=93=9C=20uti?= =?UTF-8?q?ls.py=20=EC=95=88=20profile=5Fto=5Ftext=20(=EC=A7=88=EB=AC=B8?= =?UTF-8?q?=20=ED=8A=B9=EC=A7=95=20=EC=A0=95=EB=B3=B4=20=ED=94=84=EB=A1=AC?= =?UTF-8?q?=ED=94=84=ED=8A=B8=20=EC=B6=94=EA=B0=80=20=ED=95=A8=EC=88=98)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/graph.py | 21 +++++++++++++++++++++ llm_utils/utils.py | 17 +++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 llm_utils/utils.py diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 3afa212..d0ca83e 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -16,6 +16,7 @@ from llm_utils.tools import get_info_from_db from llm_utils.retrieval import search_tables +from llm_utils.utils import profile_to_text # 노드 식별자 정의 QUERY_REFINER = "query_refiner" @@ -65,6 +66,26 @@ def query_refiner_node(state: QueryMakerState): return state +# 노드 함수: QUERY_REFINER 노드 +def query_refiner_with_profile_node(state: QueryMakerState): + + profile_bullets = profile_to_text(state["question_profile"]) + res = query_refiner_with_profile_chain.invoke( + input={ + "user_input": [state["messages"][0].content], + "user_database_env": [state["user_database_env"]], + "best_practice_query": [state["best_practice_query"]], + "searched_tables": [json.dumps(state["searched_tables"])], + "profile_prompt": [profile_bullets], + } + ) + state["messages"].append(res) + state["refined_input"] = res + + print("refined_input before context enrichment : ", res.content) + return state + + # 노드 함수: CONTEXT_ENRICHMENT 노드 def context_enrichment_node(state: QueryMakerState): diff --git a/llm_utils/utils.py b/llm_utils/utils.py new file mode 100644 index 0000000..2057b5c --- /dev/null +++ b/llm_utils/utils.py @@ -0,0 +1,17 @@ +def profile_to_text(profile_obj) -> str: + mapping = { + "is_timeseries": "• 시계열 분석 필요", + "is_aggregation": "• 집계 함수 필요", + "has_filter": "• WHERE 조건 필요", + "is_grouped": "• GROUP BY 필요", + "has_ranking": "• 정렬/순위 필요", + "has_temporal_comparison": "• 기간 비교 필요", + } + bullets = [ + text for field, text in mapping.items() if getattr(profile_obj, field, False) + ] + intent = getattr(profile_obj, "intent_type", None) + if intent: + bullets.append(f"• 의도 유형 → {intent}") + 
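+    # 해당되는 항목이 하나도 없고 intent_type도 없으면 빈 문자열("")을 반환합니다.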
+ return "\n".join(bullets) From 78379969c5e32ac68ae6bcbc1b1627ccac12bcc5 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 12 May 2025 13:42:13 +0900 Subject: [PATCH 07/11] =?UTF-8?q?feat(enriched=5Fgraph)=20:=20=ED=94=84?= =?UTF-8?q?=EB=A1=9C=ED=8C=8C=EC=9D=BC=20=EC=B6=94=EC=B6=9C,=20=EC=BB=A8?= =?UTF-8?q?=ED=85=8D=EC=8A=A4=ED=8A=B8=20=EB=B3=B4=EA=B0=95=20=EB=85=B8?= =?UTF-8?q?=EB=93=9C=EB=A5=BC=20=EC=B6=94=EA=B0=80=ED=95=9C=20=EC=83=88?= =?UTF-8?q?=EB=A1=9C=EC=9A=B4=20=EA=B7=B8=EB=9E=98=ED=94=84=20=ED=8C=8C?= =?UTF-8?q?=EC=9D=BC=EC=9D=84=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/chains.py | 1 + llm_utils/enriched_graph.py | 41 +++++++++++++++++++++++++++++++ llm_utils/graph.py | 9 ++++--- pyproject.toml => pyproject.toml_ | 0 4 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 llm_utils/enriched_graph.py rename pyproject.toml => pyproject.toml_ (100%) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index d5d9c57..e13d01f 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -142,6 +142,7 @@ def create_profile_extraction_chain(llm): query_refiner_chain = create_query_refiner_chain(llm) query_maker_chain = create_query_maker_chain(llm) +profile_extraction_chain = create_profile_extraction_chain(llm) query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm) if __name__ == "__main__": diff --git a/llm_utils/enriched_graph.py b/llm_utils/enriched_graph.py new file mode 100644 index 0000000..1018ec6 --- /dev/null +++ b/llm_utils/enriched_graph.py @@ -0,0 +1,41 @@ +import json + +from langgraph.graph import StateGraph, END +from llm_utils.graph import ( + QueryMakerState, + GET_TABLE_INFO, + PROFILE_EXTRACTION, + QUERY_REFINER, + CONTEXT_ENRICHMENT, + QUERY_MAKER, + get_table_info_node, + profile_extraction_node, + query_refiner_with_profile_node, + context_enrichment_node, + query_maker_node, +) + +""" +기본 워크플로우에 '프로파일 추출(PROFILE_EXTRACTION)'과 '컨텍스트 보강(CONTEXT_ENRICHMENT)'를 +추가한 확장된 그래프입니다. 
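+
+실행 흐름: GET_TABLE_INFO → PROFILE_EXTRACTION → QUERY_REFINER → CONTEXT_ENRICHMENT → QUERY_MAKER → END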
+""" + +# StateGraph 생성 및 구성 +builder = StateGraph(QueryMakerState) +builder.set_entry_point(GET_TABLE_INFO) + +# 노드 추가 +builder.add_node(GET_TABLE_INFO, get_table_info_node) +builder.add_node(QUERY_REFINER, query_refiner_with_profile_node) +builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) +builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) +builder.add_node(QUERY_MAKER, query_maker_node) + +# 기본 엣지 설정 +builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION) +builder.add_edge(PROFILE_EXTRACTION, QUERY_REFINER) +builder.add_edge(QUERY_REFINER, CONTEXT_ENRICHMENT) +builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER) + +# QUERY_MAKER 노드 후 종료 +builder.add_edge(QUERY_MAKER, END) diff --git a/llm_utils/graph.py b/llm_utils/graph.py index d0ca83e..36bb774 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -12,6 +12,7 @@ query_refiner_chain, query_maker_chain, query_refiner_with_profile_chain, + profile_extraction_chain, ) from llm_utils.tools import get_info_from_db @@ -24,6 +25,8 @@ TOOL = "tool" TABLE_FILTER = "table_filter" QUERY_MAKER = "query_maker" +PROFILE_EXTRACTION = "profile_extraction" +CONTEXT_ENRICHMENT = "context_enrichment" # 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) @@ -43,11 +46,10 @@ class QueryMakerState(TypedDict): # 노드 함수: PROFILE_EXTRACTION 노드 def profile_extraction_node(state: QueryMakerState): - result = query_refiner_with_profile_chain.invoke( - {"question": state["messages"][0].content} - ) + result = profile_extraction_chain.invoke({"question": state["messages"][0].content}) state["question_profile"] = result + print("profile_extraction_node : ", result) return state @@ -132,6 +134,7 @@ def context_enrichment_node(state: QueryMakerState): from langchain_core.messages import HumanMessage state["messages"].append(HumanMessage(content=enriched_text.content)) + print("After context enrichment : ", enriched_text.content) return state diff --git a/pyproject.toml b/pyproject.toml_ similarity index 100% rename from pyproject.toml rename to pyproject.toml_ From 9e0318d40adcc704266960056daca0a91e43e8a3 Mon Sep 17 00:00:00 2001 From: seyeong Date: Mon, 12 May 2025 14:39:48 +0900 Subject: [PATCH 08/11] =?UTF-8?q?feat(streamlit)=20:=20=EC=82=AC=EC=9D=B4?= =?UTF-8?q?=EB=93=9C=20=EB=B0=94=EC=97=90=20=ED=94=84=EB=A1=9C=ED=8C=8C?= =?UTF-8?q?=EC=9D=BC=20=EC=B6=94=EC=B6=9C=20&=20=EC=BB=A8=ED=85=8D?= =?UTF-8?q?=EC=8A=A4=ED=8A=B8=20=EB=B3=B4=EA=B0=95=20=EC=9B=8C=ED=81=AC?= =?UTF-8?q?=ED=94=8C=EB=A1=9C=EC=9A=B0=20=EC=82=AC=EC=9A=A9=20=EC=B2=B4?= =?UTF-8?q?=ED=81=AC=EB=B0=95=EC=8A=A4=20=EC=B6=94=EA=B0=80,=20=EA=B7=B8?= =?UTF-8?q?=EB=9E=98=ED=94=84=20=EC=97=B0=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/lang2sql.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 973831a..3e06094 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -11,6 +11,7 @@ from llm_utils.connect_db import ConnectDB from llm_utils.graph import builder +from llm_utils.enriched_graph import builder as enriched_builder DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" SIDEBAR_OPTIONS = { @@ -65,7 +66,10 @@ def execute_query( # 세션 상태에서 그래프 가져오기 graph = st.session_state.get("graph") if graph is None: - graph = builder.compile() + graph_builder = ( + enriched_builder if st.session_state.get("use_enriched") else builder + ) + graph = graph_builder.compile() st.session_state["graph"] = graph res = graph.invoke( @@ -124,14 +128,29 @@ 
def display_result( st.title("Lang2SQL") +# 워크플로우 선택(UI) +use_enriched = st.sidebar.checkbox( + "프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False +) + # 세션 상태 초기화 -if "graph" not in st.session_state: - st.session_state["graph"] = builder.compile() +if ( + "graph" not in st.session_state + or st.session_state.get("use_enriched") != use_enriched +): + graph_builder = enriched_builder if use_enriched else builder + st.session_state["graph"] = graph_builder.compile() + + # 프로파일 추출 & 컨텍스트 보강 그래프 + st.session_state["use_enriched"] = use_enriched st.info("Lang2SQL이 성공적으로 시작되었습니다.") # 새로고침 버튼 추가 if st.sidebar.button("Lang2SQL 새로고침"): - st.session_state["graph"] = builder.compile() + graph_builder = ( + enriched_builder if st.session_state.get("use_enriched") else builder + ) + st.session_state["graph"] = graph_builder.compile() st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.") user_query = st.text_area( From 4b41c968cb74162dbca47b43f8558a717449a3e8 Mon Sep 17 00:00:00 2001 From: seyeong Date: Tue, 13 May 2025 11:59:09 +0900 Subject: [PATCH 09/11] =?UTF-8?q?refactor=20:=20context=5Fenrichment=5Fnod?= =?UTF-8?q?e,=20query=5Frefiner=5Fwith=5Fprofile=5Fnode,=20profile=5Fextra?= =?UTF-8?q?ction=5Fnode=20=EC=A3=BC=EC=84=9D=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/graph.py | 51 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 36bb774..25f3d17 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -45,7 +45,22 @@ class QueryMakerState(TypedDict): # 노드 함수: PROFILE_EXTRACTION 노드 def profile_extraction_node(state: QueryMakerState): - + """ + 자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다. + + 이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부, + 그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다. + + 추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다: + - `is_timeseries`: 시계열 분석 필요 여부 + - `is_aggregation`: 집계 함수 필요 여부 + - `has_filter`: 조건 필터 필요 여부 + - `is_grouped`: 그룹화 필요 여부 + - `has_ranking`: 정렬/순위 필요 여부 + - `has_temporal_comparison`: 기간 비교 포함 여부 + - `intent_type`: 질문의 주요 의도 유형 + + """ result = profile_extraction_chain.invoke({"question": state["messages"][0].content}) state["question_profile"] = result @@ -70,6 +85,10 @@ def query_refiner_node(state: QueryMakerState): # 노드 함수: QUERY_REFINER 노드 def query_refiner_with_profile_node(state: QueryMakerState): + """ + 자연어 쿼리로부터 질문 유형(PROFILE)을 사용해 자연어 질의를 확장하는 노드입니다. + + """ profile_bullets = profile_to_text(state["question_profile"]) res = query_refiner_with_profile_chain.invoke( @@ -90,8 +109,32 @@ def query_refiner_with_profile_node(state: QueryMakerState): # 노드 함수: CONTEXT_ENRICHMENT 노드 def context_enrichment_node(state: QueryMakerState): + """ + 주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다. + + 이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다. + 보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다. + + 주요 작업: + - 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다. + - 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안"). + - 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’ → ‘USA’). + - 보강된 질문을 출력합니다. - import json + Args: + state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체. + 상태 객체는 `refined_input`, `question_profile`, `searched_tables` 등의 정보를 포함합니다. + + Returns: + QueryMakerState: 보강된 질문이 포함된 상태 객체. 
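+
+    Note:
+        enriched_text는 LLM 응답 메시지(AIMessage) 객체이며,
+        보강된 질문 텍스트 자체는 enriched_text.content로 접근합니다.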
+ + Example: + Given the refined question "What are the total sales in the last month?", + the function would enrich it with additional information such as: + - Ensuring the time period is specified correctly. + - Correcting any column names if necessary. + - Returning the enriched version of the question. + """ searched_tables = state["searched_tables"] searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2) @@ -131,9 +174,7 @@ def context_enrichment_node(state: QueryMakerState): enriched_text = llm.invoke(prompt.to_messages()) state["refined_input"] = enriched_text - from langchain_core.messages import HumanMessage - - state["messages"].append(HumanMessage(content=enriched_text.content)) + state["messages"].append(enriched_text) print("After context enrichment : ", enriched_text.content) return state From 0a510e62246bf2a2bc509690a7df535c4ca3abc6 Mon Sep 17 00:00:00 2001 From: seyeong Date: Tue, 13 May 2025 12:44:13 +0900 Subject: [PATCH 10/11] =?UTF-8?q?refactor=20:=20query=5Fenrichment=5Fchain?= =?UTF-8?q?=20=EB=A7=8C=EB=93=A4=EA=B3=A0=20context=5Fenrichment=5Fnode?= =?UTF-8?q?=EC=97=90=20=EC=97=B0=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/chains.py | 27 +++++++++++++++++++++++++++ llm_utils/graph.py | 36 +++++++----------------------------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/llm_utils/chains.py b/llm_utils/chains.py index e13d01f..b8f5556 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -135,6 +135,32 @@ def create_query_refiner_with_profile_chain(llm): ) +def create_query_enrichment_chain(llm): + + enrichment_prompt = PromptTemplate( + input_variables=["refined_question", "profiles", "related_tables"], + template=""" + You are a smart assistant that takes a user question and enriches it using: + 1. Question profiles: {profiles} + 2. Table metadata (names, columns, descriptions): + {related_tables} + + Tasks: + - Correct any wrong terms by matching them to actual column names. + - If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). + - If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). + - Output the enriched question only. + + Refined question: + {refined_question} + + Using the refined version for enrichment, but keep original intent in mind. 
+ """.strip(), + ) + + return enrichment_prompt | llm + + def create_profile_extraction_chain(llm): chain = profile_prompt | llm.with_structured_output(QuestionProfile) return chain @@ -144,6 +170,7 @@ def create_profile_extraction_chain(llm): query_maker_chain = create_query_maker_chain(llm) profile_extraction_chain = create_profile_extraction_chain(llm) query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm) +query_enrichment_chain = create_query_enrichment_chain(llm) if __name__ == "__main__": query_refiner_chain.invoke() diff --git a/llm_utils/graph.py b/llm_utils/graph.py index 25f3d17..598671a 100644 --- a/llm_utils/graph.py +++ b/llm_utils/graph.py @@ -13,6 +13,7 @@ query_maker_chain, query_refiner_with_profile_chain, profile_extraction_chain, + query_enrichment_chain, ) from llm_utils.tools import get_info_from_db @@ -142,36 +143,13 @@ def context_enrichment_node(state: QueryMakerState): question_profile = state["question_profile"].model_dump() question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2) - from langchain.prompts import PromptTemplate - - enrichment_prompt = PromptTemplate( - input_variables=["refined_question", "profiles", "related_tables"], - template=""" - You are a smart assistant that takes a user question and enriches it using: - 1. Question profiles: {profiles} - 2. Table metadata (names, columns, descriptions): - {related_tables} - - Tasks: - - Correct any wrong terms by matching them to actual column names. - - If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). - - If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). - - Output the enriched question only. - - Refined question: - {refined_question} - - Using the refined version for enrichment, but keep original intent in mind. 
- """.strip(), - ) - - llm = get_llm() - prompt = enrichment_prompt.format_prompt( - refined_question=state["refined_input"], - profiles=question_profile_json, - related_tables=searched_tables_json, + enriched_text = query_enrichment_chain.invoke( + input={ + "refined_question": state["refined_input"], + "profiles": question_profile_json, + "related_tables": searched_tables_json, + } ) - enriched_text = llm.invoke(prompt.to_messages()) state["refined_input"] = enriched_text state["messages"].append(enriched_text) From 1707d463c0abc1ebb008478420745e8216f8a10e Mon Sep 17 00:00:00 2001 From: seyeong Date: Tue, 13 May 2025 13:24:59 +0900 Subject: [PATCH 11/11] =?UTF-8?q?refactor=20:=20profile=5Fextraction=5Fpro?= =?UTF-8?q?mpt,=20query=5Fenrichment=5Fprompt=20=EB=A7=88=ED=81=AC?= =?UTF-8?q?=EB=8B=A4=EC=9A=B4=EC=9C=BC=EB=A1=9C=20=EB=B6=84=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_utils/chains.py | 57 ++++++++--------------------- prompt/profile_extraction_prompt.md | 19 ++++++++++ prompt/query_enrichment_prompt.md | 22 +++++++++++ 3 files changed, 56 insertions(+), 42 deletions(-) create mode 100644 prompt/profile_extraction_prompt.md create mode 100644 prompt/query_enrichment_prompt.md diff --git a/llm_utils/chains.py b/llm_utils/chains.py index b8f5556..587538c 100644 --- a/llm_utils/chains.py +++ b/llm_utils/chains.py @@ -113,55 +113,28 @@ def create_query_refiner_with_profile_chain(llm): return tool_choice_prompt | llm -from langchain.prompts import PromptTemplate - -profile_prompt = PromptTemplate( - input_variables=["question"], - template=""" -You are an assistant that analyzes a user question and extracts the following profiles as JSON: -- is_timeseries (boolean) -- is_aggregation (boolean) -- has_filter (boolean) -- is_grouped (boolean) -- has_ranking (boolean) -- has_temporal_comparison (boolean) -- intent_type (one of: trend, lookup, comparison, distribution) - -Return only valid JSON matching the QuestionProfile schema. - -Question: -{question} -""".strip(), -) - - def create_query_enrichment_chain(llm): + prompt = get_prompt_template("query_enrichment_prompt") - enrichment_prompt = PromptTemplate( - input_variables=["refined_question", "profiles", "related_tables"], - template=""" - You are a smart assistant that takes a user question and enriches it using: - 1. Question profiles: {profiles} - 2. Table metadata (names, columns, descriptions): - {related_tables} - - Tasks: - - Correct any wrong terms by matching them to actual column names. - - If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). - - If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). - - Output the enriched question only. - - Refined question: - {refined_question} - - Using the refined version for enrichment, but keep original intent in mind. 
- """.strip(), + enrichment_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + ] ) - return enrichment_prompt | llm + chain = enrichment_prompt | llm + return chain def create_profile_extraction_chain(llm): + prompt = get_prompt_template("profile_extraction_prompt") + + profile_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(prompt), + ] + ) + chain = profile_prompt | llm.with_structured_output(QuestionProfile) return chain diff --git a/prompt/profile_extraction_prompt.md b/prompt/profile_extraction_prompt.md new file mode 100644 index 0000000..606e037 --- /dev/null +++ b/prompt/profile_extraction_prompt.md @@ -0,0 +1,19 @@ +# Role + +You are an assistant that analyzes a user question and extracts the following profiles as JSON: +- is_timeseries (boolean) +- is_aggregation (boolean) +- has_filter (boolean) +- is_grouped (boolean) +- has_ranking (boolean) +- has_temporal_comparison (boolean) +- intent_type (one of: trend, lookup, comparison, distribution) + +# Input + +Question: +{question} + +# Output Example + +The output must be a valid JSON matching the QuestionProfile schema. diff --git a/prompt/query_enrichment_prompt.md b/prompt/query_enrichment_prompt.md new file mode 100644 index 0000000..98fbb6f --- /dev/null +++ b/prompt/query_enrichment_prompt.md @@ -0,0 +1,22 @@ +# Role + +You are a smart assistant that takes a user question and enriches it using: +1. Question profiles: {profiles} +2. Table metadata (names, columns, descriptions): + {related_tables} + +# Tasks + +- Correct any wrong terms by matching them to actual column names. +- If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). +- If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). +- Output the enriched question only. + +# Input + +Refined question: +{refined_question} + +# Notes + +Using the refined version for enrichment, but keep the original intent in mind.
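
확장 그래프의 최소 호출 예시입니다. 패치 본문에 포함되지 않은 참고용 스케치로, 상태 키는 PATCH 01/11의 QueryMakerState 정의를 따르며, 아래의 구체적인 값(user_database_env, retriever_name, device 등)은 임의 가정입니다:

    from langchain_core.messages import HumanMessage

    from llm_utils.enriched_graph import builder

    # 프로파일 추출 + 컨텍스트 보강 노드가 포함된 그래프를 컴파일합니다.
    graph = builder.compile()

    res = graph.invoke(
        {
            "messages": [HumanMessage(content="지난 30일 일별 주문 수 추이를 보여줘")],
            "user_database_env": "postgres",  # 가정값
            "best_practice_query": "",
            "retriever_name": "기본",  # 가정값
            "top_n": 5,
            "device": "cpu",  # 가정값
        }
    )

    # QUERY_MAKER 노드가 채우는 최종 SQL
    print(res["generated_query"])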