From 92bf3819fed4c72e212eb28a5b5ffd297a14ddd0 Mon Sep 17 00:00:00 2001 From: Christian R <117322020+cdreetz@users.noreply.github.com> Date: Fri, 25 Jul 2025 23:55:39 -0500 Subject: [PATCH] Allow OpenAI base_url without api key --- README.md | 4 + docs/source/datasets_and_generators.rst | 2 + docs/source/quickstart.rst | 8 +- src/chatan/generator.py | 37 ++++- tests/test_generator.py | 209 ++++++++++++++++-------- 5 files changed, 187 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index f7a4748..cff0c22 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,10 @@ df = ds.generate(n=10) ```python # OpenAI gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY") +# OpenAI-compatible service that does not require a key +gen_alt = chatan.generator( + "openai", base_url="https://api.example.com/v1" +) # Anthropic gen = chatan.generator("anthropic", "YOUR_ANTHROPIC_API_KEY") diff --git a/docs/source/datasets_and_generators.rst b/docs/source/datasets_and_generators.rst index 68aacbe..8916317 100644 --- a/docs/source/datasets_and_generators.rst +++ b/docs/source/datasets_and_generators.rst @@ -10,6 +10,8 @@ Supported generator providers Chatan includes built-in clients for a few common model sources: * ``openai`` - access GPT models via the OpenAI API + (use ``base_url`` for OpenAI-compatible endpoints; API key required + only when ``base_url`` is omitted) * ``anthropic`` - use Claude models from Anthropic * ``transformers``/``huggingface`` - run local HuggingFace models with ``transformers`` diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index c6e6607..12a42e4 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -18,8 +18,12 @@ Basic Usage .. code-block:: python import chatan - - gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY") + + gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY") + # use an OpenAI-compatible service that doesn't need a key + # gen = chatan.generator( + # "openai", base_url="https://api.example.com/v1" + # ) # or for Anthropic # gen = chatan.generator("anthropic", "YOUR_ANTHROPIC_API_KEY") diff --git a/src/chatan/generator.py b/src/chatan/generator.py index d8bdbee..1d05f4e 100644 --- a/src/chatan/generator.py +++ b/src/chatan/generator.py @@ -29,8 +29,34 @@ def generate(self, prompt: str, **kwargs) -> str: class OpenAIGenerator(BaseGenerator): """OpenAI GPT generator.""" - def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs): - self.client = openai.OpenAI(api_key=api_key) + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-3.5-turbo", + base_url: Optional[str] = None, + **kwargs, + ) -> None: + """Initialize the generator. + + Parameters + ---------- + api_key: + API key for authenticating with the OpenAI service. Required if + ``base_url`` is not provided. + model: + Model name to use for generation. + base_url: + Optional custom API base URL for OpenAI compatible services. + **kwargs: + Additional default parameters passed to ``chat.completions.create``. + """ + + client_kwargs = {} + if api_key is not None: + client_kwargs["api_key"] = api_key + if base_url is not None: + client_kwargs["base_url"] = base_url + self.client = openai.OpenAI(**client_kwargs) self.model = model self.default_kwargs = kwargs @@ -209,7 +235,7 @@ def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs): provider_lower = provider.lower() try: if provider_lower == "openai": - if api_key is None: + if api_key is None and kwargs.get("base_url") is None: raise ValueError("API key is required for OpenAI") self._generator = OpenAIGenerator(api_key, **kwargs) elif provider_lower == "anthropic": @@ -256,6 +282,9 @@ def generator( provider: str = "openai", api_key: Optional[str] = None, **kwargs ) -> GeneratorClient: """Create a generator client.""" - if provider.lower() in {"openai", "anthropic"} and api_key is None: + if provider.lower() == "openai": + if api_key is None and kwargs.get("base_url") is None: + raise ValueError("API key is required") + elif provider.lower() == "anthropic" and api_key is None: raise ValueError("API key is required") return GeneratorClient(provider, api_key, **kwargs) diff --git a/tests/test_generator.py b/tests/test_generator.py index 0727338..0fdec61 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,20 +1,24 @@ """Comprehensive tests for generator module.""" -import pytest import sys -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import MagicMock, Mock, patch + +import pytest + from chatan.generator import ( - OpenAIGenerator, AnthropicGenerator, - GeneratorFunction, GeneratorClient, - generator + GeneratorFunction, + OpenAIGenerator, + generator, ) # Conditional imports for torch-dependent tests try: import torch + from chatan.generator import TransformersGenerator + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -23,21 +27,35 @@ class TestOpenAIGenerator: """Test OpenAI generator implementation.""" - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_init_default_model(self, mock_openai): """Test OpenAI generator initialization with default model.""" gen = OpenAIGenerator("test-key") assert gen.model == "gpt-3.5-turbo" mock_openai.assert_called_once_with(api_key="test-key") - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_init_custom_model(self, mock_openai): """Test OpenAI generator initialization with custom model.""" gen = OpenAIGenerator("test-key", model="gpt-4", temperature=0.8) assert gen.model == "gpt-4" assert gen.default_kwargs == {"temperature": 0.8} - @patch('openai.OpenAI') + @patch("openai.OpenAI") + def test_init_with_base_url(self, mock_openai): + """OpenAI generator should pass base_url when provided.""" + OpenAIGenerator(base_url="https://other") + mock_openai.assert_called_once_with(base_url="https://other") + + @patch("openai.OpenAI") + def test_init_with_base_url_and_key(self, mock_openai): + """API key is still passed when provided along with base_url.""" + OpenAIGenerator("test-key", base_url="https://other") + mock_openai.assert_called_once_with( + api_key="test-key", base_url="https://other" + ) + + @patch("openai.OpenAI") def test_generate_basic(self, mock_openai): """Test basic content generation.""" # Setup mock @@ -54,11 +72,10 @@ def test_generate_basic(self, mock_openai): assert result == "Generated content" mock_client.chat.completions.create.assert_called_once_with( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Test prompt"}] + model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Test prompt"}] ) - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_generate_with_kwargs(self, mock_openai): """Test generation with additional kwargs.""" mock_client = Mock() @@ -76,10 +93,10 @@ def test_generate_with_kwargs(self, mock_openai): model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Test"}], temperature=0.5, - max_tokens=100 + max_tokens=100, ) - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_kwargs_override(self, mock_openai): """Test that call-time kwargs override defaults.""" mock_client = Mock() @@ -100,14 +117,14 @@ def test_kwargs_override(self, mock_openai): class TestAnthropicGenerator: """Test Anthropic generator implementation.""" - @patch('anthropic.Anthropic') + @patch("anthropic.Anthropic") def test_init_default_model(self, mock_anthropic): """Test Anthropic generator initialization.""" gen = AnthropicGenerator("test-key") assert gen.model == "claude-3-sonnet-20240229" mock_anthropic.assert_called_once_with(api_key="test-key") - @patch('anthropic.Anthropic') + @patch("anthropic.Anthropic") def test_generate_basic(self, mock_anthropic): """Test basic content generation.""" mock_client = Mock() @@ -125,10 +142,10 @@ def test_generate_basic(self, mock_anthropic): mock_client.messages.create.assert_called_once_with( model="claude-3-sonnet-20240229", messages=[{"role": "user", "content": "Test prompt"}], - max_tokens=1000 + max_tokens=1000, ) - @patch('anthropic.Anthropic') + @patch("anthropic.Anthropic") def test_max_tokens_extraction(self, mock_anthropic): """Test that max_tokens is extracted from kwargs.""" mock_client = Mock() @@ -150,9 +167,9 @@ def test_max_tokens_extraction(self, mock_anthropic): @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") class TestTransformersGenerator: """Test TransformersGenerator functionality (only when torch is available).""" - - @patch('transformers.AutoTokenizer.from_pretrained') - @patch('transformers.AutoModelForCausalLM.from_pretrained') + + @patch("transformers.AutoTokenizer.from_pretrained") + @patch("transformers.AutoModelForCausalLM.from_pretrained") def test_transformers_init(self, mock_model, mock_tokenizer): """Test TransformersGenerator initialization.""" # Mock tokenizer @@ -160,14 +177,14 @@ def test_transformers_init(self, mock_model, mock_tokenizer): mock_tok.pad_token = None mock_tok.eos_token = "[EOS]" mock_tokenizer.return_value = mock_tok - + # Mock model mock_mdl = Mock() mock_model.return_value = mock_mdl - with patch('torch.cuda.is_available', return_value=False): + with patch("torch.cuda.is_available", return_value=False): gen = TransformersGenerator("gpt2") - + assert gen.model_name == "gpt2" assert gen.device == "cpu" mock_tokenizer.assert_called_once_with("gpt2") @@ -180,10 +197,10 @@ def test_template_substitution(self): """Test template variable substitution.""" mock_generator = Mock() mock_generator.generate.return_value = "Generated content" - + func = GeneratorFunction(mock_generator, "Write about {topic} in {style}") result = func({"topic": "AI", "style": "casual"}) - + assert result == "Generated content" mock_generator.generate.assert_called_once_with("Write about AI in casual") @@ -191,7 +208,7 @@ def test_missing_context_variable(self): """Test behavior with missing context variables.""" mock_generator = Mock() func = GeneratorFunction(mock_generator, "Write about {topic}") - + with pytest.raises(KeyError): func({"wrong_key": "value"}) @@ -199,10 +216,10 @@ def test_extra_context_variables(self): """Test behavior with extra context variables.""" mock_generator = Mock() mock_generator.generate.return_value = "Generated" - + func = GeneratorFunction(mock_generator, "Write about {topic}") result = func({"topic": "AI", "extra": "ignored"}) - + assert result == "Generated" mock_generator.generate.assert_called_once_with("Write about AI") @@ -210,20 +227,26 @@ def test_extra_context_variables(self): class TestGeneratorClient: """Test GeneratorClient interface.""" - @patch('chatan.generator.OpenAIGenerator') + @patch("chatan.generator.OpenAIGenerator") def test_openai_client_creation(self, mock_openai_gen): """Test OpenAI client creation.""" - client = GeneratorClient("openai", "test-key", temperature=0.7) - mock_openai_gen.assert_called_once_with("test-key", temperature=0.7) + client = GeneratorClient("openai", base_url="https://other") + mock_openai_gen.assert_called_once_with( + None, base_url="https://other" + ) - @patch('chatan.generator.AnthropicGenerator') + @patch("chatan.generator.AnthropicGenerator") def test_anthropic_client_creation(self, mock_anthropic_gen): """Test Anthropic client creation.""" - client = GeneratorClient("anthropic", "test-key", model="claude-3-opus-20240229") - mock_anthropic_gen.assert_called_once_with("test-key", model="claude-3-opus-20240229") + client = GeneratorClient( + "anthropic", "test-key", model="claude-3-opus-20240229" + ) + mock_anthropic_gen.assert_called_once_with( + "test-key", model="claude-3-opus-20240229" + ) @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") - @patch('chatan.generator.TransformersGenerator') + @patch("chatan.generator.TransformersGenerator") def test_transformers_client_creation(self, mock_hf_gen): """Test Transformers client creation.""" client = GeneratorClient("transformers", model="gpt2") @@ -232,8 +255,11 @@ def test_transformers_client_creation(self, mock_hf_gen): def test_transformers_client_creation_no_torch(self): """Test Transformers client creation when torch is not available.""" # Temporarily patch TRANSFORMERS_AVAILABLE to False - with patch('chatan.generator.TRANSFORMERS_AVAILALBE', False): - with pytest.raises(ImportError, match="Local model support requires additional dependencies"): + with patch("chatan.generator.TRANSFORMERS_AVAILALBE", False): + with pytest.raises( + ImportError, + match="Local model support requires additional dependencies", + ): GeneratorClient("transformers", model="gpt2") def test_unsupported_provider(self): @@ -241,12 +267,12 @@ def test_unsupported_provider(self): with pytest.raises(ValueError, match="Unsupported provider: invalid"): GeneratorClient("invalid", "test-key") - @patch('chatan.generator.OpenAIGenerator') + @patch("chatan.generator.OpenAIGenerator") def test_callable_returns_generator_function(self, mock_openai_gen): """Test that calling client returns GeneratorFunction.""" client = GeneratorClient("openai", "test-key") func = client("Template {var}") - + assert isinstance(func, GeneratorFunction) assert func.prompt_template == "Template {var}" @@ -259,19 +285,28 @@ def test_missing_api_key(self): with pytest.raises(ValueError, match="API key is required"): generator("openai") - @patch('chatan.generator.GeneratorClient') + def test_base_url_no_key_ok(self): + """Providing base_url should allow omitting the API key.""" + gen = generator("openai", base_url="https://other") + assert isinstance(gen, GeneratorClient) + + @patch("chatan.generator.GeneratorClient") def test_factory_creates_client(self, mock_client): """Test factory function creates GeneratorClient.""" - result = generator("openai", "test-key", temperature=0.5) - mock_client.assert_called_once_with("openai", "test-key", temperature=0.5) + result = generator( + "openai", base_url="https://other", temperature=0.5 + ) + mock_client.assert_called_once_with( + "openai", None, base_url="https://other", temperature=0.5 + ) - @patch('chatan.generator.GeneratorClient') + @patch("chatan.generator.GeneratorClient") def test_default_provider(self, mock_client): """Test default provider is openai.""" generator(api_key="test-key") mock_client.assert_called_once_with("openai", "test-key") - @patch('chatan.generator.GeneratorClient') + @patch("chatan.generator.GeneratorClient") def test_transformers_provider_no_key(self, mock_client): """Transformers provider should not require API key.""" generator("transformers", model="gpt2") @@ -281,7 +316,7 @@ def test_transformers_provider_no_key(self, mock_client): class TestIntegration: """Integration tests for generator components.""" - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_end_to_end_openai(self, mock_openai): """Test complete OpenAI generation pipeline.""" # Setup mock @@ -301,10 +336,47 @@ def test_end_to_end_openai(self, mock_openai): assert result == "The capital of France is Paris." mock_client.chat.completions.create.assert_called_once_with( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "What is the capital of France?"}] + messages=[{"role": "user", "content": "What is the capital of France?"}], ) - @patch('anthropic.Anthropic') + @patch("openai.OpenAI") + def test_end_to_end_openai_custom_base_url(self, mock_openai): + """Ensure custom base_url is passed to the OpenAI client.""" + mock_client = Mock() + mock_response = Mock() + mock_choice = Mock() + mock_choice.message.content = "Hi" + mock_response.choices = [mock_choice] + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + gen = generator("openai", base_url="https://other") + func = gen("Say hi") + result = func({}) + + assert result == "Hi" + mock_openai.assert_called_once_with(base_url="https://other") + + @patch("openai.OpenAI") + def test_end_to_end_openai_custom_base_url_with_key(self, mock_openai): + """base_url with key should pass both arguments.""" + mock_client = Mock() + mock_response = Mock() + mock_choice = Mock() + mock_choice.message.content = "Hi" + mock_response.choices = [mock_choice] + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + gen = generator("openai", "test-key", base_url="https://other") + func = gen("Hello") + func({}) + + mock_openai.assert_called_once_with( + api_key="test-key", base_url="https://other" + ) + + @patch("anthropic.Anthropic") def test_end_to_end_anthropic(self, mock_anthropic): """Test complete Anthropic generation pipeline.""" # Setup mock @@ -323,7 +395,7 @@ def test_end_to_end_anthropic(self, mock_anthropic): assert result == "Python is a programming language." - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_multiple_generations(self, mock_openai): """Test multiple generations with same generator.""" mock_client = Mock() @@ -336,17 +408,17 @@ def test_multiple_generations(self, mock_openai): gen = generator("openai", "test-key") func = gen("Generate {type}") - + result1 = func({"type": "poem"}) result2 = func({"type": "story"}) - + assert result1 == "Response" assert result2 == "Response" assert mock_client.chat.completions.create.call_count == 2 @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available") - @patch('transformers.AutoTokenizer.from_pretrained') - @patch('transformers.AutoModelForCausalLM.from_pretrained') + @patch("transformers.AutoTokenizer.from_pretrained") + @patch("transformers.AutoModelForCausalLM.from_pretrained") def test_end_to_end_transformers(self, mock_model, mock_tokenizer): """Test complete Transformers generation pipeline.""" # Mock tokenizer @@ -354,10 +426,13 @@ def test_end_to_end_transformers(self, mock_model, mock_tokenizer): mock_tok.pad_token = None mock_tok.eos_token = "[EOS]" mock_tok.eos_token_id = 2 - mock_tok.return_value = {'input_ids': torch.tensor([[1, 2, 3]]), 'attention_mask': torch.tensor([[1, 1, 1]])} + mock_tok.return_value = { + "input_ids": torch.tensor([[1, 2, 3]]), + "attention_mask": torch.tensor([[1, 1, 1]]), + } mock_tok.decode.return_value = "Hello" mock_tokenizer.return_value = mock_tok - + # Mock model mock_mdl = Mock() mock_mdl.generate.return_value = [torch.tensor([1, 2, 3, 4, 5])] @@ -365,14 +440,14 @@ def test_end_to_end_transformers(self, mock_model, mock_tokenizer): gen = generator("transformers", model="gpt2") func = gen("Say hi to {name}") - - with patch('torch.no_grad'): + + with patch("torch.no_grad"): result = func({"name": "Bob"}) assert result == "Hello" mock_tokenizer.assert_called_once_with("gpt2") - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_generator_function_with_variables(self, mock_openai): """GeneratorFunction should accept default variables.""" mock_client = Mock() @@ -390,21 +465,21 @@ def test_generator_function_with_variables(self, mock_openai): assert result == "Question about elephants" mock_client.chat.completions.create.assert_called_once_with( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Question about elephants"}] + messages=[{"role": "user", "content": "Question about elephants"}], ) def test_case_insensitive_provider(self): """Test that provider names are case insensitive.""" - with patch('chatan.generator.OpenAIGenerator') as mock_gen: + with patch("chatan.generator.OpenAIGenerator") as mock_gen: generator("OPENAI", "test-key") mock_gen.assert_called_once() - - with patch('chatan.generator.AnthropicGenerator') as mock_gen: + + with patch("chatan.generator.AnthropicGenerator") as mock_gen: generator("ANTHROPIC", "test-key") mock_gen.assert_called_once() if TORCH_AVAILABLE: - with patch('chatan.generator.TransformersGenerator') as mock_gen: + with patch("chatan.generator.TransformersGenerator") as mock_gen: generator("TRANSFORMERS", model="gpt2") mock_gen.assert_called_once() @@ -412,7 +487,7 @@ def test_case_insensitive_provider(self): class TestErrorHandling: """Test error handling scenarios.""" - @patch('openai.OpenAI') + @patch("openai.OpenAI") def test_openai_api_error(self, mock_openai): """Test handling of OpenAI API errors.""" mock_client = Mock() @@ -423,7 +498,7 @@ def test_openai_api_error(self, mock_openai): with pytest.raises(Exception, match="API Error"): gen.generate("Test prompt") - @patch('anthropic.Anthropic') + @patch("anthropic.Anthropic") def test_anthropic_api_error(self, mock_anthropic): """Test handling of Anthropic API errors.""" mock_client = Mock() @@ -438,8 +513,8 @@ def test_empty_response_handling(self): """Test handling of empty responses.""" mock_generator = Mock() mock_generator.generate.return_value = " " # Whitespace only - + func = GeneratorFunction(mock_generator, "Generate {thing}") result = func({"thing": "content"}) - - assert result == "" # Should be stripped to empty string \ No newline at end of file + + assert result == "" # Should be stripped to empty string