diff --git a/requirements.txt b/requirements.txt
index bc06f43..eca9e2c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,4 @@ tiktoken
 faiss-cpu
 langchain==0.1.0
 python-dotenv
-chromadb==0.3.29
\ No newline at end of file
+chromadb==0.4.22
\ No newline at end of file
diff --git a/src/app.py b/src/app.py
index c5ca9e0..9964d20 100644
--- a/src/app.py
+++ b/src/app.py
@@ -76,6 +76,7 @@ def search_page_sunwoo():
     uploaded_files = st.file_uploader("Upload a CSV file", type=["csv"], accept_multiple_files=True)
 
     search_sunwoo = Search_sunwoo()
+    retriever_tool = None
     pipeline_compression_retriever = None
 
     st.subheader("웹툰을 검색해보세요!")
@@ -94,7 +95,7 @@ def search_page_sunwoo():
             datas.append(df)
     if len(datas) > 0:
         # print("TEST!!!!!!!" + str(len(datas)))
-        pipeline_compression_retriever = search_sunwoo.make_retriever(datas)
+        retriever_tool, pipeline_compression_retriever = search_sunwoo.make_retriever(datas)
 
     chat_input_key = "search_chat_input_sunwoo"
     # 사용자 인풋 받기
@@ -110,7 +111,7 @@ def search_page_sunwoo():
            message_placeholder = st.empty()
            full_response = ""
            # assistant_response = search.receive_chat(prompt)
-           assistant_response = search_sunwoo.run(prompt, pipeline_compression_retriever)
+           assistant_response = search_sunwoo.run(prompt, retriever_tool, pipeline_compression_retriever)
 
            print(assistant_response)
            message_placeholder.markdown(assistant_response)
diff --git a/src/model/search_sunwoo/search.py b/src/model/search_sunwoo/search.py
index 13d33b3..9565111 100644
--- a/src/model/search_sunwoo/search.py
+++ b/src/model/search_sunwoo/search.py
@@ -50,7 +50,7 @@ class Response(BaseModel):
 
 class Search():
     def __init__(self):
-        self.Response = Response
+        pass
 
     def get_data_from_csv(self, file_path):
         """ Get data from csv file """
@@ -108,7 +108,8 @@ def get_embeddings(self, documents, cached_embedder, collection_name="webtoon"):
         vectorstore = Chroma.from_documents(
             documents,
             cached_embedder,
-            collection_name=collection_name)
+            collection_name=collection_name
+        )
         return vectorstore
 
     def get_retriever(self, vectorstore):
@@ -118,23 +119,6 @@ def get_retriever(self, vectorstore):
         )
         return retriever
 
-    def get_bm25_retriever(self, documents):
-        retriver = BM25Retriever.from_documents(documents)
-        return retriver
-
-    def get_elastic_vector(self, cached_embedder, documents):
-        elasticsearch_url = "https://50da3596960c471fb7fa70548b0a71d1.us-central1.gcp.cloud.es.io:443"
-        elastic_vectorstore = ElasticsearchStore.from_documents(
-            documents,
-            cached_embedder,
-            es_url=elasticsearch_url,
-            es_api_key=ELASTIC_API_KEY,
-            index_name="webtoon",
-            strategy=ElasticsearchStore.ExactRetrievalStrategy()
-        )
-        elastic_vectorstore.client.indices.refresh(index="webtoon")
-        return elastic_vectorstore
-
     def get_pipeline_compression_retriever(self, retriever, embeddings):
         """ Create a pipeline of document transformers and a retriever """
         ## filters
@@ -177,7 +161,7 @@ def parse(self, output):
         else:
             return AgentActionMessageLog(
                 tool=name, tool_input=inputs, log="", message_log=[output]
-            )
+            )
 
     def get_agent(self, retriever_tool):
         system_message = """
@@ -209,9 +193,7 @@ def get_agent(self, retriever_tool):
             openai_api_key=OPENAI_API_KEY,
             max_tokens=2000
         )
-
-        llm_with_tools = llm.bind_functions([retriever_tool, self.Response])
-
+        llm_with_tools = llm.bind_functions([retriever_tool, Response])
         agent = (
             {
                 "title": lambda x: x["title"],
@@ -252,16 +234,16 @@ def make_retriever(self, datas):
         retriever = self.get_retriever(vectorstore)
         pipeline_compression_retriever = self.get_pipeline_compression_retriever(retriever, cached_embedder)
-        return pipeline_compression_retriever
-
-    def run(self, query, pipeline_compression_retriever):
         retriever_tool = self.get_retriever_tool(pipeline_compression_retriever)
+        return retriever_tool, pipeline_compression_retriever
+
+    def run(self, query, retriever_tool, pipeline_compression_retriever):
         result = self.get_all_relevant_documents(query, pipeline_compression_retriever)
         if len(result) == 0:
             return "검색 결과가 없습니다."
 
-        agent_executor = self.get_agent(retriever_tool)
+        agent_executor = self.get_agent(retriever_tool)
         response = agent_executor(
             {
                 "title": result[0].metadata["title"],
@@ -281,7 +263,7 @@ def run(self, query, pipeline_compression_retriever):
         file_name = "info.csv"
 
         for title_id in title_id_list:
-            file_path = os.path.join("..", "streamlit/src/model/data/webtoon", title_id, file_name)
+            file_path = os.path.join("..", "streamlit/src/model/data/webtoon/info", title_id+".csv")
             absolute_file_path = os.path.abspath(file_path)
             df = pd.read_csv(absolute_file_path)
@@ -292,9 +274,9 @@ def run(self, query, pipeline_compression_retriever):
     ## 화산귀환, 신의탑, 전지적 독자 시점, 가비지타임, 선천적 얼간이들, 내가 키운 S급들, 대학원 탈출일지
     now = time.time()
     # pipeline_compression_retriever, vectorstore = search.make_retriever_from_url(namu_list)
-    pipeline_compression_retriever = search.make_retriever(datas)
+    retriever_tool, pipeline_compression_retriever = search.make_retriever(datas)
     print("Make Retriever: " + str(time.time()- now))
     now2 = time.time()
-    result = search.run(query, pipeline_compression_retriever)
+    result = search.run(query, retriever_tool, pipeline_compression_retriever)
     print("Find result: " + str(time.time()- now2))
     print(result)
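
A minimal usage sketch (not part of the patch) of the call order this change expects: make_retriever() now returns both the retriever tool and the pipeline compression retriever, and run() takes both back alongside the query. The import path and the sample CSV path below are assumptions for illustration only.

# Minimal sketch, assuming the reworked Search API above.
import pandas as pd

from model.search_sunwoo.search import Search  # assumed import path

search = Search()
datas = [pd.read_csv("data/webtoon/info/sample.csv")]  # hypothetical webtoon metadata CSV

# make_retriever() builds the vectorstore, compression retriever, and retriever tool once...
retriever_tool, pipeline_compression_retriever = search.make_retriever(datas)

# ...and run() reuses both for each query instead of rebuilding the tool per call.
result = search.run("화산귀환", retriever_tool, pipeline_compression_retriever)
print(result)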