Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ df = ds.generate(n=10)
```python
# OpenAI
gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY")
# OpenAI-compatible service that does not require a key
gen_alt = chatan.generator(
"openai", base_url="https://api.example.com/v1"
)

# Anthropic
gen = chatan.generator("anthropic", "YOUR_ANTHROPIC_API_KEY")
Expand Down
2 changes: 2 additions & 0 deletions docs/source/datasets_and_generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Supported generator providers
Chatan includes built-in clients for a few common model sources:

* ``openai`` - access GPT models via the OpenAI API
(use ``base_url`` for OpenAI-compatible endpoints; API key required
only when ``base_url`` is omitted)
* ``anthropic`` - use Claude models from Anthropic
* ``transformers``/``huggingface`` - run local HuggingFace models with ``transformers``

Expand Down
8 changes: 6 additions & 2 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ Basic Usage
.. code-block:: python

import chatan

gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY")

gen = chatan.generator("openai", "YOUR_OPENAI_API_KEY")
# use an OpenAI-compatible service that doesn't need a key
# gen = chatan.generator(
# "openai", base_url="https://api.example.com/v1"
# )
# or for Anthropic
# gen = chatan.generator("anthropic", "YOUR_ANTHROPIC_API_KEY")

Expand Down
37 changes: 33 additions & 4 deletions src/chatan/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,34 @@ def generate(self, prompt: str, **kwargs) -> str:
class OpenAIGenerator(BaseGenerator):
"""OpenAI GPT generator."""

def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", **kwargs):
self.client = openai.OpenAI(api_key=api_key)
def __init__(
self,
api_key: Optional[str] = None,
model: str = "gpt-3.5-turbo",
base_url: Optional[str] = None,
**kwargs,
) -> None:
"""Initialize the generator.

Parameters
----------
api_key:
API key for authenticating with the OpenAI service. Required if
``base_url`` is not provided.
model:
Model name to use for generation.
base_url:
Optional custom API base URL for OpenAI compatible services.
**kwargs:
Additional default parameters passed to ``chat.completions.create``.
"""

client_kwargs = {}
if api_key is not None:
client_kwargs["api_key"] = api_key
if base_url is not None:
client_kwargs["base_url"] = base_url
self.client = openai.OpenAI(**client_kwargs)
self.model = model
self.default_kwargs = kwargs

Expand Down Expand Up @@ -209,7 +235,7 @@ def __init__(self, provider: str, api_key: Optional[str] = None, **kwargs):
provider_lower = provider.lower()
try:
if provider_lower == "openai":
if api_key is None:
if api_key is None and kwargs.get("base_url") is None:
raise ValueError("API key is required for OpenAI")
self._generator = OpenAIGenerator(api_key, **kwargs)
elif provider_lower == "anthropic":
Expand Down Expand Up @@ -256,6 +282,9 @@ def generator(
provider: str = "openai", api_key: Optional[str] = None, **kwargs
) -> GeneratorClient:
"""Create a generator client."""
if provider.lower() in {"openai", "anthropic"} and api_key is None:
if provider.lower() == "openai":
if api_key is None and kwargs.get("base_url") is None:
raise ValueError("API key is required")
elif provider.lower() == "anthropic" and api_key is None:
raise ValueError("API key is required")
return GeneratorClient(provider, api_key, **kwargs)
Loading
Loading