From 4e84565e768f35182a93632249dd77987edc70c7 Mon Sep 17 00:00:00 2001 From: ehddnr301 Date: Mon, 8 Sep 2025 19:50:39 +0900 Subject: [PATCH 1/3] =?UTF-8?q?=EA=B7=B8=EB=9E=98=ED=94=84=20=EB=B9=8C?= =?UTF-8?q?=EB=8D=94=20=EB=B0=8F=20Lang2SQL=20UI=20=EC=97=85=EB=8D=B0?= =?UTF-8?q?=EC=9D=B4=ED=8A=B8=20#134?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/graph_builder.py | 54 ++++++++++++++++++++++++++++++++++++++ interface/lang2sql.py | 24 +++++++++++++---- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/interface/graph_builder.py b/interface/graph_builder.py index c7505fc..6fbcc31 100644 --- a/interface/graph_builder.py +++ b/interface/graph_builder.py @@ -128,6 +128,50 @@ def render_sequence(sequence: List[str]) -> str: # 프리셋에서는 QUERY_MAKER 자동 포함 use_query_maker = True +# GET_TABLE_INFO 설정 +st.subheader("GET_TABLE_INFO 설정") +_prev_cfg = st.session_state.get("graph_config", {}) + +_retriever_options = { + "기본": "벡터 검색 (기본)", + "Reranker": "Reranker 검색 (정확도 향상)", +} +_retriever_keys = list(_retriever_options.keys()) +_retriever_default = _prev_cfg.get("retriever_name", "기본") +_retriever_index = ( + _retriever_keys.index(_retriever_default) + if _retriever_default in _retriever_keys + else 0 +) + +retriever_name = st.selectbox( + "테이블 검색기", + options=_retriever_keys, + format_func=lambda x: _retriever_options[x], + index=_retriever_index, +) + +top_n = st.slider( + "검색할 테이블 정보 개수", + min_value=1, + max_value=20, + value=int(_prev_cfg.get("top_n", 5)), + step=1, +) + +_device_options = ["cpu", "cuda"] +_device_default = _prev_cfg.get("device", "cpu") +_device_index = ( + _device_options.index(_device_default) + if _device_default in _device_options + else 0 +) +device = st.selectbox( + "모델 실행 장치", + options=_device_options, + index=_device_index, +) + def build_sequence_with_qm( preset: str, use_profile: bool, use_context: bool, use_qm: bool @@ -166,6 +210,9 @@ def build_sequence_with_qm( "use_profile": use_profile, "use_context": use_context, "use_query_maker": use_query_maker, + "retriever_name": retriever_name, + "top_n": top_n, + "device": device, } # 선택이 바뀌면 자동으로 세션 그래프 갱신 @@ -174,6 +221,10 @@ def build_sequence_with_qm( _builder = build_state_graph(sequence) st.session_state["graph"] = _builder.compile() st.session_state["graph_config"] = config + # Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달 + st.session_state["default_retriever_name"] = retriever_name + st.session_state["default_top_n"] = top_n + st.session_state["default_device"] = device st.info("그래프가 세션에 적용되었습니다.") # 수동 새로고침 버튼 @@ -181,6 +232,9 @@ def build_sequence_with_qm( _builder = build_state_graph(sequence) st.session_state["graph"] = _builder.compile() st.session_state["graph_config"] = config + st.session_state["default_retriever_name"] = retriever_name + st.session_state["default_top_n"] = top_n + st.session_state["default_device"] = device st.success("세션 그래프가 새로고침되었습니다.") with st.expander("현재 세션 그래프 설정"): diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 7b7265a..058ca73 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -302,10 +302,17 @@ def should_show(_key: str) -> bool: index=0, ) +_device_options = ["cpu", "cuda"] +_default_device = st.session_state.get("default_device", "cpu") +_device_index = ( + _device_options.index(_default_device) + if _default_device in _device_options + else 0 +) device = st.selectbox( "모델 실행 장치를 선택하세요:", - options=["cpu", "cuda"], - index=0, + options=_device_options, + index=_device_index, ) retriever_options = { @@ -313,18 +320,25 @@ def should_show(_key: str) -> bool: "Reranker": "Reranker 검색 (정확도 향상)", } +_retriever_keys = list(retriever_options.keys()) +_default_retriever = st.session_state.get("default_retriever_name", "기본") +_retriever_index = ( + _retriever_keys.index(_default_retriever) + if _default_retriever in _retriever_keys + else 0 +) user_retriever = st.selectbox( "검색기 유형을 선택하세요:", - options=list(retriever_options.keys()), + options=_retriever_keys, format_func=lambda x: retriever_options[x], - index=0, + index=_retriever_index, ) user_top_n = st.slider( "검색할 테이블 정보 개수:", min_value=1, max_value=20, - value=5, + value=int(st.session_state.get("default_top_n", 5)), step=1, help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.", ) From a42ba1f46830785299544643538997aadab23eb6 Mon Sep 17 00:00:00 2001 From: ehddnr301 Date: Mon, 8 Sep 2025 19:54:07 +0900 Subject: [PATCH 2/3] =?UTF-8?q?chore:=20black=20pre-commit=20=EC=A0=81?= =?UTF-8?q?=EC=9A=A9=20#134?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interface/graph_builder.py | 4 +--- interface/lang2sql.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/interface/graph_builder.py b/interface/graph_builder.py index 6fbcc31..6eb347e 100644 --- a/interface/graph_builder.py +++ b/interface/graph_builder.py @@ -162,9 +162,7 @@ def render_sequence(sequence: List[str]) -> str: _device_options = ["cpu", "cuda"] _device_default = _prev_cfg.get("device", "cpu") _device_index = ( - _device_options.index(_device_default) - if _device_default in _device_options - else 0 + _device_options.index(_device_default) if _device_default in _device_options else 0 ) device = st.selectbox( "모델 실행 장치", diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 058ca73..3a3cbe1 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -305,9 +305,7 @@ def should_show(_key: str) -> bool: _device_options = ["cpu", "cuda"] _default_device = st.session_state.get("default_device", "cpu") _device_index = ( - _device_options.index(_default_device) - if _default_device in _device_options - else 0 + _device_options.index(_default_device) if _default_device in _device_options else 0 ) device = st.selectbox( "모델 실행 장치를 선택하세요:", From 5881db676c8adad2d5e74db45ab8acbf2614ccea Mon Sep 17 00:00:00 2001 From: ehddnr301 Date: Mon, 8 Sep 2025 19:55:00 +0900 Subject: [PATCH 3/3] chore: bump version to 0.2.2 --- version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.py b/version.py index 4403cf1..7c5b6f3 100644 --- a/version.py +++ b/version.py @@ -18,4 +18,4 @@ - PATCH는 1로 증가합니다. """ -__version__ = "0.2.1" +__version__ = "0.2.2"