"""
Tests for configuration validation.

Tests cover:
- Individual config section validation
- Main Config validation
- ConfigValidationError exception
- Validation helpers
"""

import pytest
from unittest.mock import patch

from agent_orchestrator.config import (
    Config,
    ConfigValidationError,
    BudgetsConfig,
    AgentBudgetConfig,
    ControlLoopConfig,
    RiskGateConfig,
    RoutingConfig,
    LoggingConfig,
    get_config,
    reload_config,
    validate_config,
    VALID_RISK_LEVELS,
    VALID_LOG_LEVELS,
)


# =============================================================================
# AgentBudgetConfig Tests
# =============================================================================

class TestAgentBudgetConfigValidation:
    """Checks AgentBudgetConfig.validate() against each of its numeric limits."""

    def test_valid_config(self):
        """A budget with non-negative limits and a positive rate has no errors."""
        budget = AgentBudgetConfig(
            daily_token_limit=100000,
            daily_cost_limit=50.0,
            rate_limit_rpm=60,
        )
        assert budget.validate("test_agent") == []

    def test_negative_token_limit(self):
        """A negative token limit yields one error naming the field."""
        budget = AgentBudgetConfig(daily_token_limit=-100, daily_cost_limit=50.0)
        problems = budget.validate("test_agent")

        assert len(problems) == 1
        message = problems[0]
        assert "daily_token_limit" in message and "non-negative" in message

    def test_negative_cost_limit(self):
        """A negative cost limit yields one error naming the field."""
        budget = AgentBudgetConfig(daily_token_limit=100000, daily_cost_limit=-10.0)
        problems = budget.validate("test_agent")

        assert len(problems) == 1
        assert "daily_cost_limit" in problems[0]

    def test_zero_rate_limit(self):
        """A rate limit of zero requests per minute is rejected."""
        budget = AgentBudgetConfig(
            daily_token_limit=100000,
            daily_cost_limit=50.0,
            rate_limit_rpm=0,
        )
        problems = budget.validate("test_agent")

        assert len(problems) == 1
        assert "rate_limit_rpm" in problems[0]

    def test_multiple_errors(self):
        """Every invalid field is reported; validation does not stop early."""
        budget = AgentBudgetConfig(
            daily_token_limit=-100,
            daily_cost_limit=-10.0,
            rate_limit_rpm=0,
        )
        assert len(budget.validate("test_agent")) == 3


# =============================================================================
# BudgetsConfig Tests
# =============================================================================

class TestBudgetsConfigValidation:
    """Checks BudgetsConfig.validate() for the global and per-task budgets."""

    def test_valid_config(self):
        """Sensible budgets and an in-range warning threshold validate cleanly."""
        budgets = BudgetsConfig(
            daily_budget_usd=50.0,
            task_budget_usd=5.0,
            warn_threshold_percent=80,
        )
        assert budgets.validate() == []

    def test_zero_daily_budget(self):
        """A daily budget of zero is reported as not positive."""
        budgets = BudgetsConfig(daily_budget_usd=0.0, task_budget_usd=5.0)
        problems = budgets.validate()

        matching = [p for p in problems if "daily_budget_usd" in p and "positive" in p]
        assert matching

    def test_task_budget_exceeds_daily(self):
        """A per-task budget larger than the daily budget is rejected."""
        budgets = BudgetsConfig(daily_budget_usd=10.0, task_budget_usd=20.0)
        problems = budgets.validate()

        assert [p for p in problems if "cannot exceed" in p]

    def test_warn_threshold_out_of_range(self):
        """A warning threshold above 100 percent is rejected with its range."""
        budgets = BudgetsConfig(
            daily_budget_usd=50.0,
            task_budget_usd=5.0,
            warn_threshold_percent=150,
        )
        problems = budgets.validate()

        matching = [
            p for p in problems if "warn_threshold_percent" in p and "0 and 100" in p
        ]
        assert matching


# =============================================================================
# ControlLoopConfig Tests
# =============================================================================

class TestControlLoopConfigValidation:
    """Checks ControlLoopConfig.validate() for its interval and threshold fields."""

    def test_valid_config(self):
        """Positive intervals and thresholds validate cleanly."""
        loop = ControlLoopConfig(
            health_check_interval=60,
            max_auto_prompt_attempts=2,
            stuck_idle_threshold_minutes=10,
            stuck_error_threshold_count=3,
        )
        assert loop.validate() == []

    def test_zero_health_check_interval(self):
        """A zero health-check interval is reported."""
        problems = ControlLoopConfig(health_check_interval=0).validate()
        assert [p for p in problems if "health_check_interval" in p]

    def test_negative_auto_prompt_attempts(self):
        """A negative auto-prompt attempt count is reported."""
        problems = ControlLoopConfig(max_auto_prompt_attempts=-1).validate()
        assert [p for p in problems if "max_auto_prompt_attempts" in p]

    def test_zero_stuck_thresholds(self):
        """Both stuck-detection thresholds are rejected when zero."""
        loop = ControlLoopConfig(
            stuck_idle_threshold_minutes=0,
            stuck_error_threshold_count=0,
        )
        problems = loop.validate()
        assert len(problems) == 2


# =============================================================================
# RiskGateConfig Tests
# =============================================================================

class TestRiskGateConfigValidation:
    """Checks RiskGateConfig.validate() for risk level and approval timeout."""

    def test_valid_config(self):
        """Every level listed in VALID_RISK_LEVELS is accepted."""
        for risk_level in VALID_RISK_LEVELS:
            gate = RiskGateConfig(
                default_risk_level=risk_level,
                approval_timeout_seconds=3600,
            )
            problems = gate.validate()
            # The message identifies which level failed, since this loop runs
            # inside a single test.
            assert problems == [], f"{risk_level!r}: {problems}"

    def test_invalid_risk_level(self):
        """An unknown level yields one error naming the field and the bad value."""
        problems = RiskGateConfig(default_risk_level="extreme").validate()

        assert len(problems) == 1
        message = problems[0]
        assert "default_risk_level" in message
        assert "extreme" in message

    def test_case_insensitive_risk_level(self):
        """Upper-case spellings of a valid level are accepted."""
        assert RiskGateConfig(default_risk_level="HIGH").validate() == []

    def test_zero_timeout(self):
        """A zero approval timeout is reported."""
        problems = RiskGateConfig(approval_timeout_seconds=0).validate()
        assert [p for p in problems if "approval_timeout_seconds" in p]


# =============================================================================
# RoutingConfig Tests
# =============================================================================

class TestRoutingConfigValidation:
    """Checks RoutingConfig.validate() for the CLI soft/hard usage limits."""

    def test_valid_config(self):
        """A soft limit strictly below the hard limit, both in [0, 1], is valid."""
        routing = RoutingConfig(cli_soft_limit_pct=0.85, cli_hard_limit_pct=0.95)
        assert routing.validate() == []

    def test_soft_limit_out_of_range(self):
        """A soft limit above 1.0 is reported together with the allowed range."""
        routing = RoutingConfig(cli_soft_limit_pct=1.5, cli_hard_limit_pct=2.0)
        problems = routing.validate()

        matching = [
            p for p in problems if "cli_soft_limit_pct" in p and "0.0 and 1.0" in p
        ]
        assert matching

    def test_soft_exceeds_hard(self):
        """A soft limit above the hard limit breaks the ordering rule."""
        routing = RoutingConfig(cli_soft_limit_pct=0.95, cli_hard_limit_pct=0.85)
        assert [p for p in routing.validate() if "must be less than" in p]

    def test_equal_limits(self):
        """Equal limits also break the ordering rule, which is strict."""
        routing = RoutingConfig(cli_soft_limit_pct=0.90, cli_hard_limit_pct=0.90)
        assert [p for p in routing.validate() if "must be less than" in p]


# =============================================================================
# LoggingConfig Tests
# =============================================================================

class TestLoggingConfigValidation:
    """Checks LoggingConfig.validate() for the configured log level."""

    def test_valid_config(self):
        """Every level listed in VALID_LOG_LEVELS is accepted."""
        for log_level in VALID_LOG_LEVELS:
            problems = LoggingConfig(level=log_level).validate()
            # The message identifies which level failed inside the loop.
            assert problems == [], f"{log_level!r}: {problems}"

    def test_invalid_log_level(self):
        """An unknown level yields one error naming the 'logging.level' key."""
        problems = LoggingConfig(level="TRACE").validate()

        assert len(problems) == 1
        assert "logging.level" in problems[0]

    def test_case_insensitive_log_level(self):
        """Lower-case spellings of a valid level are accepted."""
        assert LoggingConfig(level="info").validate() == []


# =============================================================================
# Main Config Tests
# =============================================================================

class TestConfigValidation:
    """Tests for the top-level Config validation that aggregates all sections."""

    def test_default_config_is_valid(self):
        """Test default configuration is valid."""
        config = Config()
        errors = config.validate(raise_on_error=False)
        assert errors == []

    def test_is_valid_method(self):
        """Test is_valid() helper method."""
        config = Config()
        assert config.is_valid() is True

    def test_get_validation_summary(self):
        """Test get_validation_summary() reports a clean default config."""
        config = Config()
        summary = config.get_validation_summary()

        assert summary["valid"] is True
        assert summary["error_count"] == 0
        assert summary["errors"] == []
        assert len(summary["sections_validated"]) > 0

    def test_validation_error_raised(self):
        """Test ConfigValidationError is raised for invalid config."""
        config = Config()
        # Break one section directly on this Config instance.
        config.budgets.daily_budget_usd = -10.0

        with pytest.raises(ConfigValidationError) as exc_info:
            config.validate()

        assert len(exc_info.value.errors) > 0
        assert "daily_budget_usd" in str(exc_info.value)

    def test_validation_error_not_raised_when_disabled(self):
        """Test errors are returned, not raised, when raise_on_error=False."""
        config = Config()
        config.budgets.daily_budget_usd = -10.0

        errors = config.validate(raise_on_error=False)
        assert len(errors) > 0
        # The returned list must describe the actual problem, not an unrelated one.
        assert any("daily_budget_usd" in e for e in errors), errors

    def test_multiple_section_errors(self):
        """Test errors from multiple sections are all collected."""
        config = Config()
        config.budgets.daily_budget_usd = -10.0
        config.control_loop.health_check_interval = 0
        config.risk_gate.default_risk_level = "invalid"

        errors = config.validate(raise_on_error=False)
        assert len(errors) >= 3
        # A count alone could be met by one section emitting several errors
        # (e.g. the budget cross-check), so require each broken field explicitly.
        for field in ("daily_budget_usd", "health_check_interval", "default_risk_level"):
            assert any(field in e for e in errors), f"no error mentions {field}: {errors}"


# =============================================================================
# Module Function Tests
# =============================================================================

class TestModuleFunctions:
    """Tests for module-level functions.

    get_config()/reload_config() operate on the module-global ``_config``
    cache. The tests that touch it go through pytest's ``monkeypatch`` so the
    original cached value is restored after each test. That stops this class
    from leaking global state into tests that run afterwards.
    """

    def test_validate_config_function(self):
        """Test validate_config() returns a list of errors."""
        errors = validate_config()
        assert isinstance(errors, list)

    def test_get_config_with_validation(self, monkeypatch):
        """Test get_config(validate=True) validates a freshly built config."""
        import agent_orchestrator.config as config_module

        # Clear the cached global so get_config() has to rebuild it. Using
        # monkeypatch (not plain assignment) restores the previous value on teardown.
        monkeypatch.setattr(config_module, "_config", None)

        # Should not raise with default valid config
        config = get_config(validate=True)
        assert config is not None

    def test_reload_config_with_validation(self, monkeypatch):
        """Test reload_config(validate=True) validates."""
        import agent_orchestrator.config as config_module

        # Register the current cached value with monkeypatch, so whatever
        # reload_config() stores in the global is undone after the test.
        monkeypatch.setattr(config_module, "_config", config_module._config)

        # Should not raise with default valid config
        config = reload_config(validate=True)
        assert config is not None


# =============================================================================
# ConfigValidationError Tests
# =============================================================================

class TestConfigValidationError:
    """Checks the message and payload of the ConfigValidationError exception."""

    def test_error_message_format(self):
        """The rendered message states the error count and lists every error."""
        messages = ["error 1", "error 2", "error 3"]
        rendered = str(ConfigValidationError(messages))

        assert "3 error(s)" in rendered
        for expected in ("error 1", "error 2", "error 3"):
            assert expected in rendered

    def test_errors_attribute(self):
        """The list passed to the constructor is exposed as ``.errors``."""
        messages = ["test error"]
        assert ConfigValidationError(messages).errors == messages
