113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
"""Tests for LLM provider factory."""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, patch
|
|
|
|
from ai_novel.llm_providers import LLMProviderFactory
|
|
|
|
|
|
def test_create_openai_llm():
    """The "openai" provider builds a ChatOpenAI client from the config.

    No ``base_url`` is supplied here, so the factory is expected to forward
    ``openai_api_base=None`` to the constructor.
    """
    openai_config = dict(
        type="openai",
        model="gpt-4",
        temperature=0.7,
        max_tokens=2000,
        api_key="test-key",
    )
    expected_kwargs = dict(
        model="gpt-4",
        temperature=0.7,
        max_tokens=2000,
        openai_api_key="test-key",
        openai_api_base=None,
    )
    fake_llm = Mock()

    # Patch the class where the factory looks it up, so the real client is
    # never constructed; the patched class returns our sentinel mock.
    with patch(
        "ai_novel.llm_providers.ChatOpenAI", return_value=fake_llm
    ) as chat_openai_cls:
        result = LLMProviderFactory.create_llm(openai_config)

        chat_openai_cls.assert_called_once_with(**expected_kwargs)
        assert result == fake_llm
|
|
|
|
|
|
def test_create_openai_compatible_llm():
    """The "openai_compatible" provider reuses ChatOpenAI with a custom base URL.

    The configured ``base_url`` must be forwarded as ``openai_api_base`` so
    requests go to the third-party endpoint instead of OpenAI.
    """
    router_url = "https://openrouter.ai/api/v1"
    compatible_config = dict(
        type="openai_compatible",
        model="anthropic/claude-3-haiku",
        temperature=0.8,
        max_tokens=1500,
        api_key="test-key",
        base_url=router_url,
    )
    expected_kwargs = dict(
        model="anthropic/claude-3-haiku",
        temperature=0.8,
        max_tokens=1500,
        openai_api_key="test-key",
        openai_api_base=router_url,
    )
    fake_llm = Mock()

    # Swap out the client class at its lookup site in the factory module.
    with patch(
        "ai_novel.llm_providers.ChatOpenAI", return_value=fake_llm
    ) as chat_openai_cls:
        result = LLMProviderFactory.create_llm(compatible_config)

        chat_openai_cls.assert_called_once_with(**expected_kwargs)
        assert result == fake_llm
|
|
|
|
|
|
def test_create_ollama_llm():
    """The "ollama" provider builds a ChatOllama client from the config.

    Only ``model``, ``temperature`` and ``base_url`` are expected to reach
    the constructor; no ``max_tokens`` is passed for this provider.
    """
    ollama_config = dict(
        type="ollama",
        model="llama3.1",
        temperature=0.6,
        base_url="http://localhost:11434",
    )
    expected_kwargs = dict(
        model="llama3.1",
        temperature=0.6,
        base_url="http://localhost:11434",
    )
    fake_llm = Mock()

    # Replace ChatOllama inside the factory module with a mock class.
    with patch(
        "ai_novel.llm_providers.ChatOllama", return_value=fake_llm
    ) as chat_ollama_cls:
        result = LLMProviderFactory.create_llm(ollama_config)

        chat_ollama_cls.assert_called_once_with(**expected_kwargs)
        assert result == fake_llm
|
|
|
|
|
|
def test_unsupported_provider_type():
    """An unknown provider ``type`` is rejected with a descriptive ValueError."""
    bad_config = dict(type="unsupported_provider", model="some-model")

    with pytest.raises(ValueError) as exc_info:
        LLMProviderFactory.create_llm(bad_config)

    # Plain substring check: the expected text has no regex metacharacters,
    # so this is equivalent to passing it as ``match=``.
    assert "Unsupported provider type: unsupported_provider" in str(exc_info.value)
|
|
|
|
|
|
def test_default_values():
    """Test that default values are used when not specified.

    Only ``type`` is given, so model, temperature and max_tokens must come
    from the factory's defaults, and the absent ``api_key``/``base_url`` are
    expected to be forwarded to ChatOpenAI as ``None``.
    """
    config = {"type": "openai"}

    with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
        mock_instance = Mock()
        mock_openai.return_value = mock_instance

        llm = LLMProviderFactory.create_llm(config)

        mock_openai.assert_called_once_with(
            model="gpt-3.5-turbo",  # default
            temperature=0.7,  # default
            max_tokens=2000,  # default
            openai_api_key=None,
            openai_api_base=None,
        )
        # Without this check `llm` would be unused, and a factory that built
        # the client but returned something else would still pass.
        assert llm is mock_instance
|