Initial commit: AI Novel Generation Tool with prologue support and progress tracking
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for AI Novel Writer."""
|
91
tests/test_config.py
Normal file
91
tests/test_config.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""Tests for configuration management."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from ai_novel.config import Config
|
||||
|
||||
|
||||
def test_default_config():
|
||||
"""Test default configuration loading."""
|
||||
config = Config()
|
||||
|
||||
assert config.get("project_dir") == "."
|
||||
assert config.get("novelist_llm.type") == "openai"
|
||||
assert config.get("novelist_llm.model") == "gpt-3.5-turbo"
|
||||
assert config.get("summarizer_llm.type") == "openai"
|
||||
|
||||
|
||||
def test_config_from_file():
|
||||
"""Test loading configuration from file."""
|
||||
test_config = {
|
||||
"project_dir": "test_novel",
|
||||
"novelist_llm": {
|
||||
"type": "ollama",
|
||||
"model": "llama3.1",
|
||||
"temperature": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
yaml.dump(test_config, f)
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
config = Config(config_path)
|
||||
assert config.get("project_dir") == "test_novel"
|
||||
assert config.get("novelist_llm.type") == "ollama"
|
||||
assert config.get("novelist_llm.model") == "llama3.1"
|
||||
assert config.get("novelist_llm.temperature") == 0.8
|
||||
# Should still have default values for unspecified keys
|
||||
assert config.get("summarizer_llm.type") == "openai"
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
|
||||
def test_config_get_with_default():
|
||||
"""Test getting configuration values with defaults."""
|
||||
config = Config()
|
||||
|
||||
assert config.get("nonexistent.key", "default") == "default"
|
||||
assert config.get("novelist_llm.nonexistent", "default") == "default"
|
||||
|
||||
|
||||
def test_create_example_config():
|
||||
"""Test creating example configuration file."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.yaml', delete=False) as f:
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
Config.create_example_config(config_path)
|
||||
assert Path(config_path).exists()
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
example_config = yaml.safe_load(f)
|
||||
|
||||
assert "project_dir" in example_config
|
||||
assert "novelist_llm" in example_config
|
||||
assert "summarizer_llm" in example_config
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
|
||||
def test_config_save_to_file():
|
||||
"""Test saving configuration to file."""
|
||||
config = Config()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.yaml', delete=False) as f:
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
config.save_to_file(config_path)
|
||||
assert Path(config_path).exists()
|
||||
|
||||
# Load and verify
|
||||
new_config = Config(config_path)
|
||||
assert new_config.get("novelist_llm.type") == config.get("novelist_llm.type")
|
||||
assert new_config.get("novelist_llm.model") == config.get("novelist_llm.model")
|
||||
finally:
|
||||
Path(config_path).unlink()
|
112
tests/test_llm_providers.py
Normal file
112
tests/test_llm_providers.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Tests for LLM provider factory."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from ai_novel.llm_providers import LLMProviderFactory
|
||||
|
||||
|
||||
def test_create_openai_llm():
|
||||
"""Test creating OpenAI LLM."""
|
||||
config = {
|
||||
"type": "openai",
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"api_key": "test-key"
|
||||
}
|
||||
|
||||
with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
|
||||
mock_instance = Mock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
llm = LLMProviderFactory.create_llm(config)
|
||||
|
||||
mock_openai.assert_called_once_with(
|
||||
model="gpt-4",
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
openai_api_key="test-key",
|
||||
openai_api_base=None
|
||||
)
|
||||
assert llm == mock_instance
|
||||
|
||||
|
||||
def test_create_openai_compatible_llm():
|
||||
"""Test creating OpenAI-compatible LLM."""
|
||||
config = {
|
||||
"type": "openai_compatible",
|
||||
"model": "anthropic/claude-3-haiku",
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 1500,
|
||||
"api_key": "test-key",
|
||||
"base_url": "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
|
||||
mock_instance = Mock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
llm = LLMProviderFactory.create_llm(config)
|
||||
|
||||
mock_openai.assert_called_once_with(
|
||||
model="anthropic/claude-3-haiku",
|
||||
temperature=0.8,
|
||||
max_tokens=1500,
|
||||
openai_api_key="test-key",
|
||||
openai_api_base="https://openrouter.ai/api/v1"
|
||||
)
|
||||
assert llm == mock_instance
|
||||
|
||||
|
||||
def test_create_ollama_llm():
|
||||
"""Test creating Ollama LLM."""
|
||||
config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.1",
|
||||
"temperature": 0.6,
|
||||
"base_url": "http://localhost:11434"
|
||||
}
|
||||
|
||||
with patch('ai_novel.llm_providers.ChatOllama') as mock_ollama:
|
||||
mock_instance = Mock()
|
||||
mock_ollama.return_value = mock_instance
|
||||
|
||||
llm = LLMProviderFactory.create_llm(config)
|
||||
|
||||
mock_ollama.assert_called_once_with(
|
||||
model="llama3.1",
|
||||
temperature=0.6,
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
assert llm == mock_instance
|
||||
|
||||
|
||||
def test_unsupported_provider_type():
|
||||
"""Test error handling for unsupported provider type."""
|
||||
config = {
|
||||
"type": "unsupported_provider",
|
||||
"model": "some-model"
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider type: unsupported_provider"):
|
||||
LLMProviderFactory.create_llm(config)
|
||||
|
||||
|
||||
def test_default_values():
|
||||
"""Test that default values are used when not specified."""
|
||||
config = {"type": "openai"}
|
||||
|
||||
with patch('ai_novel.llm_providers.ChatOpenAI') as mock_openai:
|
||||
mock_instance = Mock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
llm = LLMProviderFactory.create_llm(config)
|
||||
|
||||
mock_openai.assert_called_once_with(
|
||||
model="gpt-3.5-turbo", # default
|
||||
temperature=0.7, # default
|
||||
max_tokens=2000, # default
|
||||
openai_api_key=None,
|
||||
openai_api_base=None
|
||||
)
|
Reference in New Issue
Block a user