diff --git a/interface/graph_builder.py b/interface/graph_builder.py index c7505fc..6eb347e 100644 --- a/interface/graph_builder.py +++ b/interface/graph_builder.py @@ -128,6 +128,48 @@ def render_sequence(sequence: List[str]) -> str: # 프리셋에서는 QUERY_MAKER 자동 포함 use_query_maker = True +# GET_TABLE_INFO 설정 +st.subheader("GET_TABLE_INFO 설정") +_prev_cfg = st.session_state.get("graph_config", {}) + +_retriever_options = { + "기본": "벡터 검색 (기본)", + "Reranker": "Reranker 검색 (정확도 향상)", +} +_retriever_keys = list(_retriever_options.keys()) +_retriever_default = _prev_cfg.get("retriever_name", "기본") +_retriever_index = ( + _retriever_keys.index(_retriever_default) + if _retriever_default in _retriever_keys + else 0 +) + +retriever_name = st.selectbox( + "테이블 검색기", + options=_retriever_keys, + format_func=lambda x: _retriever_options[x], + index=_retriever_index, +) + +top_n = st.slider( + "검색할 테이블 정보 개수", + min_value=1, + max_value=20, + value=int(_prev_cfg.get("top_n", 5)), + step=1, +) + +_device_options = ["cpu", "cuda"] +_device_default = _prev_cfg.get("device", "cpu") +_device_index = ( + _device_options.index(_device_default) if _device_default in _device_options else 0 +) +device = st.selectbox( + "모델 실행 장치", + options=_device_options, + index=_device_index, +) + def build_sequence_with_qm( preset: str, use_profile: bool, use_context: bool, use_qm: bool @@ -166,6 +208,9 @@ def build_sequence_with_qm( "use_profile": use_profile, "use_context": use_context, "use_query_maker": use_query_maker, + "retriever_name": retriever_name, + "top_n": top_n, + "device": device, } # 선택이 바뀌면 자동으로 세션 그래프 갱신 @@ -174,6 +219,10 @@ def build_sequence_with_qm( _builder = build_state_graph(sequence) st.session_state["graph"] = _builder.compile() st.session_state["graph_config"] = config + # Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달 + st.session_state["default_retriever_name"] = retriever_name + st.session_state["default_top_n"] = top_n + st.session_state["default_device"] = device st.info("그래프가 세션에 적용되었습니다.") # 수동 새로고침 버튼 @@ -181,6 +230,9 @@ def build_sequence_with_qm( _builder = build_state_graph(sequence) st.session_state["graph"] = _builder.compile() st.session_state["graph_config"] = config + st.session_state["default_retriever_name"] = retriever_name + st.session_state["default_top_n"] = top_n + st.session_state["default_device"] = device st.success("세션 그래프가 새로고침되었습니다.") with st.expander("현재 세션 그래프 설정"): diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 7b7265a..3a3cbe1 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -302,10 +302,15 @@ def should_show(_key: str) -> bool: index=0, ) +_device_options = ["cpu", "cuda"] +_default_device = st.session_state.get("default_device", "cpu") +_device_index = ( + _device_options.index(_default_device) if _default_device in _device_options else 0 +) device = st.selectbox( "모델 실행 장치를 선택하세요:", - options=["cpu", "cuda"], - index=0, + options=_device_options, + index=_device_index, ) retriever_options = { @@ -313,18 +318,25 @@ def should_show(_key: str) -> bool: "Reranker": "Reranker 검색 (정확도 향상)", } +_retriever_keys = list(retriever_options.keys()) +_default_retriever = st.session_state.get("default_retriever_name", "기본") +_retriever_index = ( + _retriever_keys.index(_default_retriever) + if _default_retriever in _retriever_keys + else 0 +) user_retriever = st.selectbox( "검색기 유형을 선택하세요:", - options=list(retriever_options.keys()), + options=_retriever_keys, format_func=lambda x: retriever_options[x], - index=0, + index=_retriever_index, ) user_top_n = st.slider( "검색할 테이블 정보 개수:", min_value=1, max_value=20, - value=5, + value=int(st.session_state.get("default_top_n", 5)), step=1, help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.", ) diff --git a/version.py b/version.py index 4403cf1..7c5b6f3 100644 --- a/version.py +++ b/version.py @@ -18,4 +18,4 @@ - PATCH는 1로 증가합니다. """ -__version__ = "0.2.1" +__version__ = "0.2.2"