"""Tests for LLM provider factory.""" import pytest from unittest.mock import Mock, patch from ai_novel.llm_providers import LLMProviderFactory def test_create_openai_llm(): """Test creating OpenAI LLM.""" config = { "type": "openai", "model": "gpt-4", "temperature": 0.7, "max_tokens": 2000, "api_key": "test-key" } with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai: mock_instance = Mock() mock_openai.return_value = mock_instance llm = LLMProviderFactory.create_llm(config) mock_openai.assert_called_once_with( model="gpt-4", temperature=0.7, max_tokens=2000, openai_api_key="test-key", openai_api_base=None ) assert llm == mock_instance def test_create_openai_compatible_llm(): """Test creating OpenAI-compatible LLM.""" config = { "type": "openai_compatible", "model": "anthropic/claude-3-haiku", "temperature": 0.8, "max_tokens": 1500, "api_key": "test-key", "base_url": "https://openrouter.ai/api/v1" } with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai: mock_instance = Mock() mock_openai.return_value = mock_instance llm = LLMProviderFactory.create_llm(config) mock_openai.assert_called_once_with( model="anthropic/claude-3-haiku", temperature=0.8, max_tokens=1500, openai_api_key="test-key", openai_api_base="https://openrouter.ai/api/v1" ) assert llm == mock_instance def test_create_ollama_llm(): """Test creating Ollama LLM.""" config = { "type": "ollama", "model": "llama3.1", "temperature": 0.6, "base_url": "http://localhost:11434" } with patch('ai_novel.llm_providers.ChatOllama') as mock_ollama: mock_instance = Mock() mock_ollama.return_value = mock_instance llm = LLMProviderFactory.create_llm(config) mock_ollama.assert_called_once_with( model="llama3.1", temperature=0.6, base_url="http://localhost:11434" ) assert llm == mock_instance def test_unsupported_provider_type(): """Test error handling for unsupported provider type.""" config = { "type": "unsupported_provider", "model": "some-model" } with pytest.raises(ValueError, match="Unsupported provider type: unsupported_provider"): LLMProviderFactory.create_llm(config) def test_default_values(): """Test that default values are used when not specified.""" config = {"type": "openai"} with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai: mock_instance = Mock() mock_openai.return_value = mock_instance llm = LLMProviderFactory.create_llm(config) mock_openai.assert_called_once_with( model="gpt-3.5-turbo", # default temperature=0.7, # default max_tokens=2000, # default openai_api_key=None, openai_api_base=None )