ai_novel/tests/test_llm_providers.py

"""Tests for LLM provider factory."""
import pytest
from unittest.mock import Mock, patch
from ai_novel.llm_providers import LLMProviderFactory


def test_create_openai_llm():
    """Test creating OpenAI LLM."""
    config = {
        "type": "openai",
        "model": "gpt-4",
        "temperature": 0.7,
        "max_tokens": 2000,
        "api_key": "test-key",
    }
    with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
        mock_instance = Mock()
        mock_openai.return_value = mock_instance

        llm = LLMProviderFactory.create_llm(config)

        mock_openai.assert_called_once_with(
            model="gpt-4",
            temperature=0.7,
            max_tokens=2000,
            openai_api_key="test-key",
            openai_api_base=None,
        )
        assert llm == mock_instance


def test_create_openai_compatible_llm():
    """Test creating OpenAI-compatible LLM."""
    config = {
        "type": "openai_compatible",
        "model": "anthropic/claude-3-haiku",
        "temperature": 0.8,
        "max_tokens": 1500,
        "api_key": "test-key",
        "base_url": "https://openrouter.ai/api/v1",
    }
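    # The OpenAI-compatible path still goes through ChatOpenAI; only
    # openai_api_base changes, here pointed at OpenRouter.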
    with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
        mock_instance = Mock()
        mock_openai.return_value = mock_instance

        llm = LLMProviderFactory.create_llm(config)

        mock_openai.assert_called_once_with(
            model="anthropic/claude-3-haiku",
            temperature=0.8,
            max_tokens=1500,
            openai_api_key="test-key",
            openai_api_base="https://openrouter.ai/api/v1",
        )
        assert llm == mock_instance


def test_create_ollama_llm():
    """Test creating Ollama LLM."""
    config = {
        "type": "ollama",
        "model": "llama3.1",
        "temperature": 0.6,
        "base_url": "http://localhost:11434",
    }
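    # The Ollama path forwards no API key; only model, temperature, and
    # base_url are expected to reach ChatOllama.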
    with patch('ai_novel.llm_providers.ChatOllama') as mock_ollama:
        mock_instance = Mock()
        mock_ollama.return_value = mock_instance

        llm = LLMProviderFactory.create_llm(config)

        mock_ollama.assert_called_once_with(
            model="llama3.1",
            temperature=0.6,
            base_url="http://localhost:11434",
        )
        assert llm == mock_instance


def test_unsupported_provider_type():
    """Test error handling for unsupported provider type."""
    config = {
        "type": "unsupported_provider",
        "model": "some-model",
    }
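    # pytest.raises(match=...) applies re.search to the exception text; the
    # message below has no regex metacharacters, so it matches literally.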
    with pytest.raises(ValueError, match="Unsupported provider type: unsupported_provider"):
        LLMProviderFactory.create_llm(config)


def test_default_values():
    """Test that default values are used when not specified."""
    config = {"type": "openai"}
    with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
        mock_instance = Mock()
        mock_openai.return_value = mock_instance

        llm = LLMProviderFactory.create_llm(config)

        mock_openai.assert_called_once_with(
            model="gpt-3.5-turbo",  # default
            temperature=0.7,  # default
            max_tokens=2000,  # default
            openai_api_key=None,
            openai_api_base=None,
        )
        assert llm == mock_instance
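
These tests pin down the factory's contract (parameter names, defaults, and the error message) completely enough to sketch the module under test. Below is a minimal, hypothetical reconstruction of ai_novel/llm_providers.py implied by the assertions above, not the actual module, assuming the chat model classes come from the langchain-openai and langchain-ollama packages:

"""Sketch of ai_novel/llm_providers.py as implied by the tests; the real module may differ."""
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama


class LLMProviderFactory:
    @staticmethod
    def create_llm(config: dict):
        provider_type = config.get("type")
        if provider_type in ("openai", "openai_compatible"):
            # Both variants construct ChatOpenAI; "openai_compatible" simply
            # supplies a custom base URL (base_url is None for plain "openai").
            return ChatOpenAI(
                model=config.get("model", "gpt-3.5-turbo"),
                temperature=config.get("temperature", 0.7),
                max_tokens=config.get("max_tokens", 2000),
                openai_api_key=config.get("api_key"),
                openai_api_base=config.get("base_url"),
            )
        if provider_type == "ollama":
            # The Ollama test passes model, temperature, and base_url
            # explicitly, so defaults on this path are an assumption,
            # not pinned by the suite.
            return ChatOllama(
                model=config.get("model"),
                temperature=config.get("temperature", 0.7),
                base_url=config.get("base_url"),
            )
        raise ValueError(f"Unsupported provider type: {provider_type}")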