diff --git a/example.env b/example.env deleted file mode 100644 index bcee7a6..0000000 --- a/example.env +++ /dev/null @@ -1,16 +0,0 @@ -# Airbyte -GITHUB_TOKEN= - -# Database credentials -DB_HOST=db -DB_AB_DESTINATION_HOST=host.docker.internal -DB_PORT=5432 -EXPOSED_DB_PORT=5432 -DB_NAME=databasex -DB_USER=postgres -DB_PASSWORD=postgres -DB_URL=postgresql://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME} - -# Google cloud gemini -GEMINI_API_KEY= -GEMINI_MODEL_NAME=gemini-2.0-flash \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 271ff40..00101b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,14 @@ -airbyte==0.31.3 -FastAPI==0.115.9 -uvicorn==0.34.3 -psycopg2-binary==2.9.10 -python-dotenv==1.1.0 -vanna==0.7.9 -chromadb -google-generativeai==0.8.5 -google-cloud-aiplatform==1.96.0 -onnxruntime==1.22.0 - -pytest==7.4.0 -pytest-asyncio==0.21.1 -pytest-mock==3.12.0 -pytest-cov==4.1.0 -httpx -requests-mock==1.11.0 - -torch>=2.2.0,<2.8.0 -transformers>=4.35.0,<5.0.0 -tokenizers>=0.15.0,<1.0.0 -scikit-learn>=1.5.2,<2.0.0 -numpy>=1.21.0,<1.27.0 +annotated-types==0.7.0 +anyio==4.11.0 +certifi==2025.10.5 +exceptiongroup==1.3.0 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.11 +ollama==0.6.0 +pydantic==2.12.3 +pydantic_core==2.41.4 +sniffio==1.3.1 +typing-inspection==0.4.2 +typing_extensions==4.15.0 diff --git a/src/api/controller/AskController.py b/src/api/controller/AskController.py index 2bc33a9..cd320a4 100644 --- a/src/api/controller/AskController.py +++ b/src/api/controller/AskController.py @@ -1,60 +1,204 @@ from src.assets.pattern.singleton import SingletonMeta from src.api.models import Question, Response +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain.schema import SystemMessage, HumanMessage, AIMessage +from langchain.memory import ConversationBufferWindowMemory from google import genai from src.api.database.MyVanna import MyVanna +import json +from typing import Dict, Optional +import hashlib from src.assets.aux.env import env + # Gemini env vars GEMINI_API_KEY = env["GEMINI_API_KEY"] GEMINI_MODEL_NAME = env["GEMINI_MODEL_NAME"] + class AskController(metaclass=SingletonMeta): def __init__(self): self.client = genai.Client(api_key=GEMINI_API_KEY) + # LLM principal para geração de respostas + self.llm = ChatGoogleGenerativeAI( + model=GEMINI_MODEL_NAME, + google_api_key=GEMINI_API_KEY, + temperature=0, + max_tokens=None, + timeout=None, + max_retries=2, + convert_system_message_to_human=True + ) + + # Memória conversacional (mantém últimas 5 interações) + self.memory = ConversationBufferWindowMemory( + k=5, # Número de interações a manter + memory_key="chat_history", + return_messages=True, + input_key="question", + output_key="answer" + ) + + # Instância do vanna self.vn = MyVanna(config={ - 'print_prompt': False, + 'print_prompt': False, 'print_sql': False, 'api_key': GEMINI_API_KEY, 'model_name': GEMINI_MODEL_NAME }) - self.vn.prepare() - def ask(self, question: Question): - client = genai.Client(api_key=GEMINI_API_KEY) - - vn = MyVanna(config={ - 'print_prompt': False, - 'print_sql': False, - 'api_key': GEMINI_API_KEY, - 'model_name': GEMINI_MODEL_NAME - }) + # Cache simples para queries SQL (em memória) + self.sql_cache: Dict[str, str] = {} + self.result_cache: Dict[str, any] = {} + def _preprocess_question(self, question: str) -> str: + """ + Pré-processa a pergunta usando LLM com contexto da memória + Versão simplificada sem LLMChain para evitar problemas de parsing + """ try: - sql_gerado = vn.generate_sql(question.question) + # Recupera histórico da memória + memory_vars = self.memory.load_memory_variables({}) + chat_history = memory_vars.get("chat_history", []) + + # Monta contexto do histórico + history_text = "" + if chat_history: + history_text = "Histórico da conversa:\n" + for msg in chat_history[-5:]: # Últimas 5 mensagens + if isinstance(msg, HumanMessage): + history_text += f"Usuário: {msg.content}\n" + elif isinstance(msg, AIMessage): + history_text += f"Assistente: {msg.content}\n" + history_text += "\n" - if "SELECT" not in sql_gerado.upper(): - return {"output": "Não consegui entender sua pergunta bem o suficiente para gerar uma resposta SQL válida."} + # Monta o prompt + system_prompt = """Você é um assistente especializado em processar perguntas sobre dados do GitHub. - resultado = vn.run_sql(sql_gerado) +Sua função é: +1. Analisar o histórico da conversa para entender o contexto +2. Resolver referências contextuais (ex: "e no mês passado?", "mostre mais detalhes", "e o outro repositório?") +3. Normalizar expressões temporais: + - "3 meses" → "90 dias" + - "1 ano e 2 meses" → "425 dias" + - Meses separados = 30 dias cada +4. Normalizar terminologia: + - "mudança" → "commit" + - "alteração" → "commit" +5. Expandir a pergunta com contexto necessário do histórico - if not resultado: - return {"output": "A consulta foi feita, mas não há dados correspondentes no banco."} +REGRAS CRÍTICAS: +- Se a pergunta fizer referência a algo anterior ("e aquele", "o outro", "também"), inclua o contexto explícito +- Se não houver referência contextual, retorne a pergunta apenas normalizada +- NÃO explique, NÃO confirme, NÃO dê exemplos +- Retorne APENAS a pergunta processada e expandida em uma única linha""" - # Prompt mais informativo - prompt = f""" - Você é um assistente que responde perguntas sobre dados extraídos do GitHub. + # Monta mensagens + messages = [ + SystemMessage(content=system_prompt) + ] + + # Adiciona histórico se existir + if history_text: + messages.append(HumanMessage(content=history_text)) + + # Adiciona pergunta atual + messages.append(HumanMessage(content=f"Pergunta a processar: {question}")) + + # Chama LLM + response = self.llm.invoke(messages) + + # Extrai conteúdo da resposta + if hasattr(response, 'content'): + processed = response.content.strip() + else: + processed = str(response).strip() + + # Remove qualquer explicação extra (pega só a primeira linha) + processed = processed.split('\n')[0].strip() + + return processed + + except Exception as e: + print(f"[Warning] Erro no preprocessing: {e}. Usando pergunta original.") + return question + + def _get_cache_key(self, text: str) -> str: + """Gera chave de cache baseada no hash da pergunta normalizada""" + normalized = text.lower().strip() + return hashlib.md5(normalized.encode()).hexdigest() + + def _validate_sql(self, sql: str) -> tuple[bool, Optional[str]]: + """Valida SQL gerado para segurança""" + sql_upper = sql.upper().strip() + + # Whitelist: apenas SELECT permitido + if not sql_upper.startswith("SELECT"): + return False, "Apenas queries SELECT são permitidas" + + # Blacklist: operações perigosas + dangerous_keywords = [ + "DELETE", "DROP", "TRUNCATE", "INSERT", + "UPDATE", "ALTER", "CREATE", "GRANT", "REVOKE" + ] + + for keyword in dangerous_keywords: + if keyword in sql_upper: + return False, f"Operação '{keyword}' não é permitida" + + # Limite de complexidade (número de JOINs) + join_count = sql_upper.count("JOIN") + if join_count > 10: + return False, "Query muito complexa (máximo 10 JOINs)" + + return True, None + + def _format_response_with_context(self, question: str, sql: str, result: any) -> str: + """Formata resposta final usando LLM com contexto conversacional""" + + # Recupera histórico da memória + memory_vars = self.memory.load_memory_variables({}) + chat_history = memory_vars.get("chat_history", []) + + # Monta contexto do histórico + history_context = "" + if chat_history: + history_context = "\n\nContexto da conversa anterior:\n" + for msg in chat_history[-3:]: # Últimas 3 mensagens + if isinstance(msg, HumanMessage): + history_context += f"Usuário: {msg.content}\n" + elif isinstance(msg, AIMessage): + history_context += f"Assistente: {msg.content}\n" + + prompt = f""" +Você é um assistente especializado em análise de dados do GitHub. - Pergunta do usuário: "{question.question}" +{history_context} - Resultado da consulta SQL: {resultado} +Pergunta atual: "{question}" - Gere uma resposta clara e útil para o usuário, explicando o que o resultado significa. - """ +SQL gerado e executado: +```sql +{sql} +``` - response = client.models.generate_content( +Resultado da consulta: {result} + +Com base no contexto da conversa e nos resultados, gere uma resposta: +1. Clara e direta +2. Em linguagem natural +3. Destacando insights relevantes +4. Relacionando com perguntas anteriores se aplicável +5. Formato estruturado se houver múltiplos dados + +Responda de forma conversacional e útil. +""" + + try: + response = self.client.models.generate_content( model=GEMINI_MODEL_NAME, contents=prompt, config={ @@ -62,9 +206,129 @@ def ask(self, question: Question): "response_schema": list[Response], } ) + return response.parsed[0].texto + except Exception as e: + print(f"[Error] Erro ao formatar resposta: {e}") + # Fallback para resposta simples + return f"Consulta executada com sucesso. Resultado: {result}" + + def ask(self, question: Question, session_id: Optional[str] = None) -> dict: + """ + Processa pergunta com contexto conversacional + + Args: + question: Objeto Question com a pergunta do usuário + session_id: ID da sessão para memória multi-usuário (futuro) + """ + + try: + original_question = question.question + print(f"[Original] {original_question}") + + # Etapa 1: Pré-processar com contexto + processed_question = self._preprocess_question(original_question) + print(f"[Preprocessed] {processed_question}") + + # Etapa 2: Verificar cache de SQL + cache_key = self._get_cache_key(processed_question) + + if cache_key in self.sql_cache: + print(f"[Cache Hit] SQL encontrado no cache") + sql_gerado = self.sql_cache[cache_key] + else: + # Gerar SQL com Vanna + sql_gerado = self.vn.generate_sql(processed_question) - texto = response.parsed[0].texto - return {"output": texto} + # Validar SQL + is_valid, error_msg = self._validate_sql(sql_gerado) + if not is_valid: + return { + "output": f"Query inválida: {error_msg}", + "error": True + } + + # Armazenar no cache + self.sql_cache[cache_key] = sql_gerado + print(f"[Cache Miss] SQL gerado e armazenado") + + print(f"[SQL] {sql_gerado}") + + # Etapa 3: Verificar cache de resultados + result_cache_key = hashlib.md5(sql_gerado.encode()).hexdigest() + + if result_cache_key in self.result_cache: + print(f"[Cache Hit] Resultado encontrado no cache") + resultado = self.result_cache[result_cache_key] + else: + # Executar SQL + resultado = self.vn.run_sql(sql_gerado) + + if not resultado: + # Salvar na memória mesmo sem resultado + self.memory.save_context( + inputs={"question": original_question}, + outputs={"answer": "Não há dados correspondentes no banco."} + ) + return { + "output": "A consulta foi executada, mas não há dados correspondentes.", + "sql": sql_gerado + } + + # Armazenar resultado no cache + self.result_cache[result_cache_key] = resultado + print(f"[Cache Miss] Resultado obtido e armazenado") + + # Etapa 4: Formatar resposta com contexto + resposta_formatada = self._format_response_with_context( + question=original_question, + sql=sql_gerado, + result=resultado + ) + + # Etapa 5: Salvar na memória + self.memory.save_context( + inputs={"question": original_question}, + outputs={"answer": resposta_formatada} + ) + + return { + "output": resposta_formatada, + "sql": sql_gerado, + "cached": result_cache_key in self.result_cache + } except Exception as e: - return {"output": f"Ocorreu um erro ao processar sua pergunta: {str(e)}"} \ No newline at end of file + import traceback + error_msg = f"Erro ao processar pergunta: {str(e)}" + print(f"[Error] {error_msg}") + print(f"[Error] Traceback: {traceback.format_exc()}") + + # Salvar erro na memória + try: + self.memory.save_context( + inputs={"question": question.question}, + outputs={"answer": error_msg} + ) + except: + pass + + return { + "output": error_msg, + "error": True + } + + def clear_memory(self): + """Limpa o histórico da conversa""" + self.memory.clear() + print("[Memory] Histórico limpo") + + def get_conversation_history(self) -> list: + """Retorna o histórico da conversa""" + memory_vars = self.memory.load_memory_variables({}) + return memory_vars.get("chat_history", []) + + def clear_cache(self): + """Limpa os caches de SQL e resultados""" + self.sql_cache.clear() + self.result_cache.clear() + print("[Cache] Caches limpos") \ No newline at end of file diff --git a/src/api/database/MyVanna.py b/src/api/database/MyVanna.py index 84c2e70..a21ab42 100644 --- a/src/api/database/MyVanna.py +++ b/src/api/database/MyVanna.py @@ -15,12 +15,61 @@ DB_PASSWORD = env["DB_PASSWORD"] DB_URL = env["DB_URL"] -class MyVanna(ChromaDB_VectorStore, GoogleGeminiChat): +class ChromaDB_VectorStoreReset(ChromaDB_VectorStore): def __init__(self, config=None): if config is None: config = {} - ChromaDB_VectorStore.__init__(self, config=config) + # Força o reset na inicialização + config["reset_on_init"] = config.get("reset_on_init", True) + + super().__init__(config=config) + + # Limpa as coleções após a inicialização padrão + if config["reset_on_init"]: + self._reset_collections() + + # Recria as coleções vazias + collection_metadata = config.get("collection_metadata", None) + self.documentation_collection = self.chroma_client.get_or_create_collection( + name="documentation", + embedding_function=self.embedding_function, + metadata=collection_metadata, + ) + self.ddl_collection = self.chroma_client.get_or_create_collection( + name="ddl", + embedding_function=self.embedding_function, + metadata=collection_metadata, + ) + self.sql_collection = self.chroma_client.get_or_create_collection( + name="sql", + embedding_function=self.embedding_function, + metadata=collection_metadata, + ) + + def _reset_collections(self): + """Limpa todas as coleções existentes""" + try: + self.chroma_client.delete_collection("documentation") + except Exception: + pass + + try: + self.chroma_client.delete_collection("ddl") + except Exception: + pass + + try: + self.chroma_client.delete_collection("sql") + except Exception: + pass + +class MyVanna(ChromaDB_VectorStoreReset, GoogleGeminiChat): + def __init__(self, config=None): + if config is None: + config = {} + + ChromaDB_VectorStoreReset.__init__(self, config=config) GEMINI_API_KEY = config.get('api_key') GEMINI_MODEL_NAME = config.get('model_name') @@ -123,32 +172,134 @@ def prepare(self): password = DB_PASSWORD ) - schema = self.get_schema() - - self.train(ddl=self.schema) + self.train(ddl=self.get_schema()) self.train(documentation=""" - A tabela repository armazena os repositórios, identificados por um ID único e nome. +Table: user_info + + id: Bigint primary key with default value from sequence + + login: Required username field (character varying) + + html_url: Required profile URL field (text) + +Table: milestone + + id: Bigint primary key with default value from sequence + + repository_id: Associated repository ID (integer, required) + + title: Milestone title (text, required) + + description: Milestone description (text, optional) + + number: Milestone number (integer, required) + + state: Milestone state (character varying, required) + + created_at: Creation timestamp with time zone + + updated_at: Update timestamp with time zone + + creator: Creator user ID (bigint, required) + +Table: repository + + id: Integer primary key with default value from sequence + + name: Repository name (character varying, required) + +Table: branch + + id: Bigint primary key with default value from sequence + + name: Branch name (character varying, required) + + repository_id: Associated repository ID (integer, required) + +Table: issue - A tabela user contém informações dos usuários, como login e URL de perfil. É usada como referência em outras tabelas, como quem criou issues, pull requests e milestones. + id: Bigint primary key with default value from sequence - A tabela milestone representa marcos definidos nos repositórios, contendo título, descrição, número, estado (aberta ou fechada), datas de criação e atualização, o repositório ao qual pertence e o usuário criador. + title: Issue title (text, required) - A tabela issue armazena tarefas ou bugs reportados. Contém título, corpo, número, autor, repositório, milestone relacionada, datas e URL. As atribuições de usuários a uma issue são registradas em issue_assignees. + body: Issue body/description (text, optional) - A tabela pull_requests armazena os pull requests criados nos repositórios. Inclui título, corpo, número, estado, criador, repositório, milestone (opcional), datas e URL. Os responsáveis são registrados em pull_request_assignees. + number: Issue number (integer, required) - A tabela branch armazena os nomes de branches de cada repositório. + html_url: Issue URL (text, optional) - A tabela commits armazena cada commit feito. Cada registro contém o SHA, mensagem, autor (usuário), repositório, branch (opcional), data de criação e URL. Também há referência à tabela de usuários. + created_at: Creation timestamp with time zone - A tabela parents_commits representa a relação entre commits e seus pais (para commits com múltiplos ancestrais, como em merges). Usa o SHA do commit pai e o ID do commit filho. + updated_at: Update timestamp with time zone - A tabela issue_assignees relaciona múltiplos usuários a uma mesma issue, representando atribuições de tarefas. É uma tabela de junção entre issues e usuários. + created_by: Creator user ID (bigint, required) - A tabela pull_request_assignees relaciona múltiplos usuários a um pull request, permitindo registrar quem é responsável por revisar ou aprovar um PR. + repository_id: Associated repository ID (bigint, required) - O modelo garante integridade por meio de chaves estrangeiras, e unicidade de registros por restrições compostas (como número + repositório para issues, pull requests e milestones). + milestone_id: Associated milestone ID (bigint, optional) + +Table: pull_requests + + id: Bigint primary key with default value from sequence + + created_by: Creator user ID (bigint, required) + + repository_id: Associated repository ID (integer, required) + + number: Pull request number (integer, required) + + state: Pull request state (character varying, required) + + title: Pull request title (text, required) + + body: Pull request body/description (text, optional) + + html_url: Pull request URL (text, required) + + created_at: Creation timestamp with time zone + + updated_at: Update timestamp with time zone + + milestone_id: Associated milestone ID (bigint, optional) + +Table: commits + + id: Bigint primary key with default value from sequence + + user_id: Author user ID (bigint, required) + + branch_id: Associated branch ID (integer, optional) + + pull_request_id: Associated pull request ID (bigint, optional) + + created_at: Creation timestamp with time zone + + message: Commit message (text, required) + + sha: Commit SHA hash (character varying, required) + + html_url: Commit URL (text, optional) + +Table: parents_commits + + id: Integer primary key with default value from sequence + + parent_sha: Parent commit SHA hash (character varying, required) + + commit_id: Child commit ID (integer, required) + +Table: issue_assignees + + issue_id: Issue ID (bigint, required, part of primary key) + + user_id: Assigned user ID (bigint, required, part of primary key) + +Table: pull_request_assignees + + pull_request_id: Pull request ID (bigint, required, part of primary key) + + user_id: Assigned user ID (bigint, required, part of primary key) """) self.train(sql=""" @@ -232,3 +383,5 @@ def prepare(self): ORDER BY total_commits DESC; """) + +