diff --git a/.gitignore b/.gitignore
index 4906374..f76bfeb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,10 +161,13 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
+# Visual Studio Code local settings
+.vscode/
+
 # SSH Keys
 *_host_key
 *.key
 *.pub
 
 # config files
-*.ini
+*.ini
\ No newline at end of file
diff --git a/SSH/config.ini.TEMPLATE b/SSH/config.ini.TEMPLATE
index 16ac131..70a5a82 100644
--- a/SSH/config.ini.TEMPLATE
+++ b/SSH/config.ini.TEMPLATE
@@ -29,6 +29,8 @@ server_version_string = OpenSSH_8.2p1 Ubuntu-4ubuntu0.3
 ##### OpenAI
 llm_provider = openai
 model_name = gpt-4o
+# To use another OpenAI-compatible endpoint, uncomment and set:
+# llm_endpoint = https://example.com
 
 ##### ollama llama3
 #llm_provider = ollama
diff --git a/SSH/ssh_server.py b/SSH/ssh_server.py
index 2d72982..fcd2d7d 100755
--- a/SSH/ssh_server.py
+++ b/SSH/ssh_server.py
@@ -16,10 +16,10 @@ from base64 import b64encode
 from operator import itemgetter
 
 from langchain_openai import ChatOpenAI
-from langchain_aws import ChatBedrock, ChatBedrockConverse
+from langchain_aws import ChatBedrockConverse
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_ollama import ChatOllama
-from langchain_core.messages import HumanMessage, SystemMessage, trim_messages
+from langchain_core.messages import HumanMessage, trim_messages
 from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
 from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -295,7 +295,7 @@ def llm_get_session_history(session_id: str) -> BaseChatMessageHistory:
     return llm_sessions[session_id]
 
 def get_user_accounts() -> dict:
-    if (not 'user_accounts' in config) or (len(config.items('user_accounts')) == 0):
+    if ("user_accounts" not in config) or (len(config.items("user_accounts")) == 0):
         raise ValueError("No user accounts found in configuration file.")
 
     accounts = dict()
@@ -305,15 +305,21 @@ def get_user_accounts() -> dict:
 
     return accounts
 
-def choose_llm(llm_provider: Optional[str] = None, model_name: Optional[str] = None):
-    llm_provider_name = llm_provider or config['llm'].get("llm_provider", "openai")
+def choose_llm(
+    llm_provider: Optional[str] = None,
+    model_name: Optional[str] = None,
+    llm_endpoint: Optional[str] = None,
+):
+    llm_provider_name: str
+    if llm_provider is not None:
+        llm_provider_name = llm_provider
+    else:
+        llm_provider_name = config["llm"].get("llm_provider", "openai")
     llm_provider_name = llm_provider_name.lower()
     model_name = model_name or config['llm'].get("model_name", "gpt-3.5-turbo")
 
     if llm_provider_name == 'openai':
-        llm_model = ChatOpenAI(
-            model=model_name
-        )
+        llm_model = ChatOpenAI(model=model_name, base_url=llm_endpoint)
     elif llm_provider_name == 'ollama':
         llm_model = ChatOllama(
             model=model_name
@@ -452,7 +458,11 @@ def get_prompts(prompt: Optional[str], prompt_file: Optional[str]) -> dict:
 llm_system_prompt = prompts["system_prompt"]
 llm_user_prompt = prompts["user_prompt"]
 
-llm = choose_llm(config['llm'].get("llm_provider"), config['llm'].get("model_name"))
+llm = choose_llm(
+    config["llm"].get("llm_provider"),
+    config["llm"].get("model_name"),
+    config["llm"].get("llm_endpoint"),
+)
 
 llm_sessions = dict()
 
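
Reviewer note: a minimal sketch of what choose_llm() now builds when llm_endpoint
is set, assuming an OpenAI-compatible server (e.g. vLLM or LiteLLM) is listening
at the given URL. The endpoint, api_key, and model name below are placeholders,
not values taken from this change:

    # Hypothetical config.ini values for this sketch:
    #   [llm]
    #   llm_provider = openai
    #   model_name = my-local-model
    #   llm_endpoint = http://localhost:8000/v1
    from langchain_openai import ChatOpenAI

    # base_url routes requests to the OpenAI-compatible endpoint; when it is
    # None, ChatOpenAI falls back to the default api.openai.com behavior.
    llm = ChatOpenAI(
        model="my-local-model",
        base_url="http://localhost:8000/v1",
        api_key="sk-placeholder",  # local servers usually accept any token
    )
    print(llm.invoke("uname -a").content)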