"""
Tests for the observability module.

Tests cover:
- Token Audit integration
- AI Observer client
- Alert Manager
- Notification channels (factory helpers, retry config, Slack and webhook delivery)
"""

import pytest
from datetime import datetime, date
from unittest.mock import MagicMock, AsyncMock

from agent_orchestrator.observability.token_audit import (
    TokenAuditClient,
    UsageEntry,
    UsageSummary,
    CostAlert,
)

from agent_orchestrator.observability.ai_observer import (
    AIObserverClient,
    MetricPoint,
    AlertRule as ObserverAlertRule,
)

from agent_orchestrator.observability.alerts import (
    AlertManager,
    Alert,
    AlertRule,
    AlertSeverity,
    AlertType,
    AlertState,
    NotificationChannel,
)


# =============================================================================
# Token Audit Tests
# =============================================================================

class TestUsageEntry:
    """Unit tests for the ``UsageEntry`` record type."""

    # Required fields shared by every entry constructed in this class.
    _BASE_FIELDS = dict(
        provider="anthropic",
        model="claude-3-sonnet",
        input_tokens=1000,
        output_tokens=200,
        cost_usd=0.01,
    )

    def test_create_entry(self):
        """A minimal entry exposes exactly the values it was built with."""
        record = UsageEntry(timestamp=datetime.now(), **self._BASE_FIELDS)

        assert record.provider == "anthropic"
        assert record.input_tokens == 1000
        assert record.cost_usd == 0.01

    def test_entry_with_metadata(self):
        """Optional agent/task/tag metadata is stored alongside the base fields."""
        record = UsageEntry(
            timestamp=datetime.now(),
            agent_id="claude-code",
            task_id="task-001",
            tags={"project": "test"},
            **self._BASE_FIELDS,
        )

        assert record.agent_id == "claude-code"
        assert record.tags["project"] == "test"


class TestUsageSummary:
    """Unit tests for the ``UsageSummary`` aggregate."""

    def test_create_summary(self):
        """Aggregate totals passed to the constructor are stored verbatim."""
        totals = {
            "total_input_tokens": 10000,
            "total_output_tokens": 2000,
            "total_cost_usd": 0.10,
            "total_requests": 5,
        }
        result = UsageSummary(
            period_start=datetime.now(),
            period_end=datetime.now(),
            **totals,
        )

        assert result.total_input_tokens == 10000
        assert result.total_cost_usd == 0.10

    def test_summary_with_breakdowns(self):
        """A per-provider breakdown is kept keyed by provider name."""
        # Each provider gets its own (equal-valued) breakdown dict.
        per_provider = {
            provider_name: {"input_tokens": 5000, "cost_usd": 0.05}
            for provider_name in ("anthropic", "openai")
        }
        result = UsageSummary(
            period_start=datetime.now(),
            period_end=datetime.now(),
            by_provider=per_provider,
        )

        assert "anthropic" in result.by_provider
        assert result.by_provider["anthropic"]["cost_usd"] == 0.05


class TestTokenAuditClient:
    """Tests for TokenAuditClient class."""

    @pytest.fixture
    def mock_db(self):
        """Create a mock database whose ``query_usage`` returns no rows by default."""
        db = MagicMock()
        db.query_usage.return_value = []
        return db

    @pytest.fixture
    def client(self, mock_db):
        """Create a token audit client backed by the mock database."""
        return TokenAuditClient(db=mock_db)

    @pytest.mark.asyncio
    async def test_query_usage_empty(self, client):
        """Test querying usage with no data."""
        entries = await client.query_usage()
        assert entries == []

    @pytest.mark.asyncio
    async def test_record_usage(self, client, mock_db):
        """Test recording usage entry is forwarded to the database once."""
        mock_db.record_usage = MagicMock()

        success = await client.record_usage(
            provider="anthropic",
            model="claude-3",
            input_tokens=1000,
            output_tokens=200,
            cost_usd=0.01,
        )

        assert success is True
        mock_db.record_usage.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_daily_summary(self, client, mock_db):
        """Test that the daily summary aggregates the rows the database returns."""
        mock_db.query_usage.return_value = [
            {
                "timestamp": datetime.now().isoformat(),
                "provider": "anthropic",
                "model": "claude-3",
                "input_tokens": 1000,
                "output_tokens": 200,
                "cost_usd": 0.01,
            }
        ]

        summary = await client.get_daily_summary()

        assert summary.total_requests == 1  # One record returned from mock
        assert summary.total_input_tokens == 1000
        assert summary.total_output_tokens == 200

    def test_check_request_threshold(self, client):
        """Test checking request cost threshold."""
        client.set_threshold("per_request", 5.0)

        alert = client.check_request_threshold(3.0)
        assert alert is None

        alert = client.check_request_threshold(6.0)
        assert alert is not None
        assert alert.exceeded is True

    def test_set_thresholds(self, client):
        """Test setting thresholds updates the corresponding private fields."""
        client.set_threshold("daily", 100.0)
        client.set_threshold("hourly", 20.0)
        client.set_threshold("per_request", 5.0)

        assert client._daily_threshold == 100.0
        assert client._hourly_threshold == 20.0
        assert client._request_threshold == 5.0

    def test_alert_callback(self, client):
        """Test a registered callback fires only when the threshold is exceeded."""
        alerts_received = []

        def on_alert(alert):
            alerts_received.append(alert)

        client.on_alert(on_alert)
        # Pin the threshold explicitly so the test does not depend on the
        # client's default per-request threshold.
        client.set_threshold("per_request", 5.0)

        # Below threshold: check_request_threshold returns None here (see
        # test_check_request_threshold), so no callback is expected.
        client.check_request_threshold(3.0)
        assert alerts_received == []

        # Above threshold: exactly one alert is delivered to the callback.
        client.check_request_threshold(10.0)
        assert len(alerts_received) == 1
        assert alerts_received[0].exceeded is True

    def test_generate_report(self, client):
        """Test report generation formats token counts and cost."""
        summary = UsageSummary(
            period_start=datetime.now(),
            period_end=datetime.now(),
            total_input_tokens=10000,
            total_output_tokens=2000,
            total_cost_usd=0.10,
            total_requests=5,
        )

        report = client.generate_report(summary)

        assert "TOKEN USAGE REPORT" in report
        assert "10,000" in report  # Formatted token count
        assert "$0.1000" in report  # Cost


# =============================================================================
# AI Observer Tests
# =============================================================================

class TestAIObserverClient:
    """Unit tests for ``AIObserverClient`` metric recording and alert rules."""

    @pytest.fixture
    def client(self):
        """Provide a fresh observer client for each test."""
        return AIObserverClient()

    @staticmethod
    def _gt_rule():
        """Build the ``test_metric > 50.0`` rule shared by the alert tests."""
        return ObserverAlertRule(
            name="test_rule",
            metric="test_metric",
            condition="gt",
            threshold=50.0,
        )

    @staticmethod
    def _capture_alerts(client):
        """Register a collecting callback on *client* and return its list."""
        captured = []
        client.on_alert(lambda a: captured.append(a))
        return captured

    def test_record_metric(self, client):
        """A recorded value can be read back by metric name."""
        client.record_metric("test_metric", 42.0)

        points = client.get_metrics(name="test_metric")

        assert len(points) == 1
        assert points[0].value == 42.0

    def test_record_metric_with_tags(self, client):
        """Tags supplied at record time are kept on the metric point."""
        client.record_metric("test_metric", 42.0, tags={"agent": "test"})

        points = client.get_metrics(name="test_metric")

        assert points[0].tags["agent"] == "test"

    def test_add_alert_rule(self, client):
        """Added rules are registered under their name."""
        client.add_alert_rule(self._gt_rule())

        assert "test_rule" in client._alert_rules

    def test_alert_triggers(self, client):
        """A value above a ``gt`` threshold fires exactly one alert."""
        client.add_alert_rule(self._gt_rule())
        captured = self._capture_alerts(client)

        client.record_metric("test_metric", 60.0)  # above 50.0

        assert len(captured) == 1
        assert captured[0].rule_name == "test_rule"

    def test_alert_not_triggered(self, client):
        """A value below a ``gt`` threshold fires nothing."""
        client.add_alert_rule(self._gt_rule())
        captured = self._capture_alerts(client)

        client.record_metric("test_metric", 40.0)  # below 50.0

        assert len(captured) == 0

    def test_record_standard_metrics(self, client):
        """The convenience recorders together emit five metric points."""
        agent = "claude-code"
        client.record_request_cost(0.05, agent)
        client.record_token_usage(1000, 200, agent)
        client.record_request_latency(150.0, agent)
        client.record_agent_health(agent, 0.95)

        # cost + input tokens + output tokens + latency + health
        assert len(client.get_metrics()) == 5


# =============================================================================
# Alert Manager Tests
# =============================================================================

class TestAlertManager:
    """Unit tests for ``AlertManager`` rules, checks, and alert lifecycle."""

    @pytest.fixture
    def manager(self):
        """Provide a fresh alert manager (with its built-in rules) per test."""
        return AlertManager()

    @staticmethod
    def _subscribe(manager):
        """Attach a collecting callback to *manager*; return the collected list."""
        fired = []
        manager.on_alert(lambda a: fired.append(a))
        return fired

    @staticmethod
    def _make_alert(alert_id, alert_type, severity, title, message):
        """Construct an ``Alert`` whose rule name is always ``"test"``."""
        return Alert(
            alert_id=alert_id,
            rule_name="test",
            alert_type=alert_type,
            severity=severity,
            title=title,
            message=message,
        )

    def test_default_rules_exist(self, manager):
        """The manager starts with built-in rules, including stuck-agent and cost."""
        rules = manager._rules
        assert len(rules) > 0
        for expected in ("stuck_agent_warning", "cost_threshold_warning"):
            assert expected in rules

    def test_add_rule(self, manager):
        """A custom rule is registered under its name."""
        custom = AlertRule(
            name="custom_rule",
            alert_type=AlertType.ERROR_SPIKE,
            severity=AlertSeverity.WARNING,
            threshold=5.0,
        )

        manager.add_rule(custom)

        assert "custom_rule" in manager._rules

    def test_remove_rule(self, manager):
        """Removing an existing rule reports success and drops it."""
        target = "stuck_agent_warning"
        assert manager.remove_rule(target) is True
        assert target not in manager._rules

    def test_add_notification_channel(self, manager):
        """A notification channel is registered under its name."""
        slack = NotificationChannel(
            name="test_slack",
            channel_type="slack",
            slack_channel="#alerts",
        )

        manager.add_channel(slack)

        assert "test_slack" in manager._channels

    @pytest.mark.asyncio
    async def test_check_stuck_agent_fires(self, manager):
        """An agent stuck for ten minutes raises a STUCK_AGENT alert."""
        fired = self._subscribe(manager)

        await manager.check_stuck_agent(
            agent_id="test-agent",
            is_stuck=True,
            stuck_duration_seconds=600,  # 10 minutes
            reason="Repeated error loop",
        )

        assert len(fired) == 1
        assert fired[0].alert_type == AlertType.STUCK_AGENT

    @pytest.mark.asyncio
    async def test_check_stuck_agent_below_threshold(self, manager):
        """An agent stuck for only one minute raises nothing."""
        fired = self._subscribe(manager)

        await manager.check_stuck_agent(
            agent_id="test-agent",
            is_stuck=True,
            stuck_duration_seconds=60,  # 1 minute
            reason="Test",
        )

        assert len(fired) == 0

    @pytest.mark.asyncio
    async def test_check_cost_threshold_fires(self, manager):
        """A cost of 75.0 raises a COST_THRESHOLD alert under default rules."""
        fired = self._subscribe(manager)

        await manager.check_cost_threshold(current_cost=75.0)

        assert len(fired) == 1
        assert fired[0].alert_type == AlertType.COST_THRESHOLD

    @pytest.mark.asyncio
    async def test_check_error_spike_fires(self, manager):
        """Twenty errors in a minute raise an ERROR_SPIKE alert."""
        fired = self._subscribe(manager)

        await manager.check_error_spike(
            error_count=20,
            time_window_seconds=60,
        )

        assert len(fired) == 1
        assert fired[0].alert_type == AlertType.ERROR_SPIKE

    @pytest.mark.asyncio
    async def test_check_budget_exhausted_fires(self, manager):
        """A fully spent budget raises a BUDGET_EXHAUSTED alert."""
        fired = self._subscribe(manager)

        await manager.check_budget_exhausted(
            agent_id="test-agent",
            budget_percentage=1.0,  # 100%
        )

        assert len(fired) == 1
        assert fired[0].alert_type == AlertType.BUDGET_EXHAUSTED

    @pytest.mark.asyncio
    async def test_check_health_degraded_fires(self, manager):
        """A health score of 0.3 raises a HEALTH_DEGRADED alert."""
        fired = self._subscribe(manager)

        await manager.check_health_degraded(
            agent_id="test-agent",
            health_score=0.3,  # 30%
        )

        assert len(fired) == 1
        assert fired[0].alert_type == AlertType.HEALTH_DEGRADED

    def test_acknowledge_alert(self, manager):
        """Acknowledging an active alert records who acknowledged it."""
        alert_id = "alert-000001"
        alert = self._make_alert(
            alert_id,
            AlertType.STUCK_AGENT,
            AlertSeverity.WARNING,
            "Test Alert",
            "Test message",
        )
        manager._active_alerts[alert_id] = alert

        ok = manager.acknowledge_alert(
            alert_id,
            acknowledged_by="user@test.com",
            notes="Looking into it",
        )

        assert ok is True
        assert alert.state == AlertState.ACKNOWLEDGED
        assert alert.acknowledged_by == "user@test.com"

    def test_silence_alert(self, manager):
        """Silencing an active alert moves it to the SILENCED state."""
        alert_id = "alert-000001"
        alert = self._make_alert(
            alert_id,
            AlertType.STUCK_AGENT,
            AlertSeverity.WARNING,
            "Test",
            "Test",
        )
        manager._active_alerts[alert_id] = alert

        ok = manager.silence_alert(alert_id)

        assert ok is True
        assert alert.state == AlertState.SILENCED

    def test_get_active_alerts(self, manager):
        """Active alerts can be listed unfiltered, by severity, or by type."""
        stuck = self._make_alert(
            "alert-1",
            AlertType.STUCK_AGENT,
            AlertSeverity.WARNING,
            "Test 1",
            "Test",
        )
        costly = self._make_alert(
            "alert-2",
            AlertType.COST_THRESHOLD,
            AlertSeverity.CRITICAL,
            "Test 2",
            "Test",
        )
        manager._active_alerts["alert-1"] = stuck
        manager._active_alerts["alert-2"] = costly

        assert len(manager.get_active_alerts()) == 2

        by_severity = manager.get_active_alerts(severity=AlertSeverity.CRITICAL)
        assert len(by_severity) == 1
        assert by_severity[0].alert_id == "alert-2"

        by_type = manager.get_active_alerts(alert_type=AlertType.STUCK_AGENT)
        assert len(by_type) == 1
        assert by_type[0].alert_id == "alert-1"

    def test_get_alert_summary(self, manager):
        """The summary counts active alerts and groups them by severity."""
        manager._active_alerts["alert-1"] = self._make_alert(
            "alert-1",
            AlertType.STUCK_AGENT,
            AlertSeverity.WARNING,
            "Test",
            "Test",
        )

        report = manager.get_alert_summary()

        assert report["active_count"] == 1
        assert "warning" in report["active_by_severity"]

    @pytest.mark.asyncio
    async def test_cooldown_prevents_duplicate_alerts(self, manager):
        """A second breach inside the rule's cooldown does not re-fire."""
        fired = self._subscribe(manager)

        await manager.check_cost_threshold(current_cost=75.0)
        assert len(fired) == 1  # first breach fires

        await manager.check_cost_threshold(current_cost=80.0)
        assert len(fired) == 1  # still suppressed by cooldown


class TestAlertSeverityOrdering:
    """Tests for alert severity ordering."""

    def test_severity_order(self):
        """Severities are distinct and declared from least to most severe.

        Checks INFO < WARNING < ERROR < CRITICAL by position in the
        ``AlertSeverity`` enum's iteration (declaration) order, in addition to
        checking that their values do not collide.
        """
        severities = [
            AlertSeverity.INFO,
            AlertSeverity.WARNING,
            AlertSeverity.ERROR,
            AlertSeverity.CRITICAL,
        ]

        # Distinct values (so none of these members is an alias of another).
        values = [s.value for s in severities]
        assert len(values) == len(set(values))

        # Enum iteration follows declaration order; the four levels must appear
        # in ascending severity even if other members are declared between them.
        declared = list(AlertSeverity)
        positions = [declared.index(s) for s in severities]
        assert positions == sorted(positions)


# =============================================================================
# Notification Channel Tests
# =============================================================================

from agent_orchestrator.observability.alerts import (
    create_slack_webhook_channel,
    create_slack_api_channel,
    create_webhook_channel,
    create_email_channel,
    setup_default_alerting,
)


class TestNotificationChannelCreation:
    """Unit tests for the ``create_*_channel`` / default-setup factory helpers."""

    def test_create_slack_webhook_channel(self):
        """The Slack webhook helper stores the URL and defaults to WARNING."""
        hook = "https://hooks.slack.com/services/T123/B456/xxx"
        ch = create_slack_webhook_channel(name="test-slack", webhook_url=hook)

        assert ch.name == "test-slack"
        assert ch.channel_type == "slack"
        assert ch.webhook_url == hook
        assert ch.min_severity == AlertSeverity.WARNING

    def test_create_slack_webhook_channel_custom_severity(self):
        """An explicit minimum severity overrides the default."""
        ch = create_slack_webhook_channel(
            name="slack-critical",
            webhook_url="https://hooks.slack.com/...",
            min_severity=AlertSeverity.CRITICAL,
        )

        assert ch.min_severity == AlertSeverity.CRITICAL

    def test_create_slack_api_channel(self):
        """The Slack API helper stores the bot token and target channel."""
        bot_token = "xoxb-12345-abcdef"
        ch = create_slack_api_channel(
            name="slack-api",
            token=bot_token,
            channel="#alerts",
        )

        assert ch.name == "slack-api"
        assert ch.channel_type == "slack"
        assert ch.slack_token == bot_token
        assert ch.slack_channel == "#alerts"

    def test_create_webhook_channel(self):
        """A generic webhook channel starts with no extra headers."""
        endpoint = "https://monitoring.example.com/webhook"
        ch = create_webhook_channel(name="monitoring", url=endpoint)

        assert ch.name == "monitoring"
        assert ch.channel_type == "webhook"
        assert ch.webhook_url == endpoint
        assert ch.webhook_headers == {}

    def test_create_webhook_channel_with_headers(self):
        """Custom headers and timeout are carried onto the channel."""
        auth_headers = {"Authorization": "Bearer token123"}
        ch = create_webhook_channel(
            name="authenticated",
            url="https://api.example.com/alerts",
            headers=auth_headers,
            timeout_seconds=30,
        )

        assert ch.webhook_headers == {"Authorization": "Bearer token123"}
        assert ch.webhook_timeout_seconds == 30

    def test_create_email_channel(self):
        """The email helper keeps every recipient address."""
        recipients = ["oncall@example.com", "team@example.com"]
        ch = create_email_channel(name="email-alerts", recipients=recipients)

        assert ch.name == "email-alerts"
        assert ch.channel_type == "email"
        assert len(ch.email_addresses) == 2
        assert "oncall@example.com" in ch.email_addresses

    def test_setup_default_alerting(self):
        """Default setup registers slack/webhook channels plus built-in rules."""
        mgr = setup_default_alerting(
            slack_webhook_url="https://hooks.slack.com/test",
            webhook_url="https://webhook.example.com",
        )

        registered = mgr._channels
        for expected in ("slack", "webhook"):
            assert expected in registered
        assert len(mgr._rules) > 0  # built-in rules are present


class TestNotificationChannelConfig:
    """Unit tests for the retry fields of ``NotificationChannel``."""

    @staticmethod
    def _webhook(**overrides):
        """Build a minimal webhook channel, applying any field overrides."""
        return NotificationChannel(
            name="test",
            channel_type="webhook",
            webhook_url="https://example.com",
            **overrides,
        )

    def test_channel_retry_settings(self):
        """Without overrides a channel retries 3 times with a 1.0s delay."""
        ch = self._webhook()

        assert ch.retry_count == 3
        assert ch.retry_delay_seconds == 1.0

    def test_channel_custom_retry_settings(self):
        """Explicit retry settings replace the defaults."""
        ch = self._webhook(retry_count=5, retry_delay_seconds=2.0)

        assert ch.retry_count == 5
        assert ch.retry_delay_seconds == 2.0


class TestSlackNotification:
    """Tests for Slack notification functionality."""

    @pytest.fixture
    def manager(self):
        """Create alert manager."""
        return AlertManager()

    @pytest.fixture
    def slack_channel(self):
        """Create a Slack channel configured with an incoming-webhook URL."""
        return NotificationChannel(
            name="test-slack",
            channel_type="slack",
            webhook_url="https://hooks.slack.com/test",
        )

    @pytest.mark.asyncio
    async def test_slack_message_building(self, manager, slack_channel):
        """Test the Slack payload is posted to the webhook with blocks/attachments."""
        # Only names not already imported at module level; MagicMock and
        # AsyncMock come from the top-of-file import.
        from contextlib import asynccontextmanager
        from unittest.mock import patch

        manager.add_channel(slack_channel)

        captured_url = None
        captured_payload = None

        @asynccontextmanager
        async def mock_post(url, json=None, **kwargs):
            # Stand-in for ``session.post(...)``: record what was sent and
            # yield a successful response.
            nonlocal captured_url, captured_payload
            captured_url = url
            captured_payload = json

            mock_response = MagicMock()
            mock_response.status = 200
            mock_response.json = AsyncMock(return_value={"ok": True})
            yield mock_response

        @asynccontextmanager
        async def mock_session(*args, **kwargs):
            # Accept and ignore constructor arguments (e.g. ``timeout=``) so
            # the fake works however the code under test builds its session.
            session = MagicMock()
            session.post = mock_post
            yield session

        # Route the built-in cost rule to our channel for this test only, and
        # restore it afterwards in case rule objects are shared across managers.
        rule = manager._rules["cost_threshold_warning"]
        original_channels = rule.notify_channels
        rule.notify_channels = ["test-slack"]
        try:
            with patch("aiohttp.ClientSession", mock_session):
                await manager.check_cost_threshold(current_cost=75.0)
        finally:
            rule.notify_channels = original_channels

        # Verify the request target and payload structure
        assert captured_payload is not None
        assert captured_url == slack_channel.webhook_url
        assert "blocks" in captured_payload
        assert "attachments" in captured_payload

        # Check blocks contain expected content
        blocks = captured_payload["blocks"]
        assert len(blocks) >= 2  # Header and message at minimum


class TestWebhookNotification:
    """Tests for webhook notification functionality."""

    @pytest.fixture
    def manager(self):
        """Create alert manager."""
        return AlertManager()

    @pytest.fixture
    def webhook_channel(self):
        """Create a generic webhook channel with one custom header."""
        return NotificationChannel(
            name="test-webhook",
            channel_type="webhook",
            webhook_url="https://webhook.example.com/alerts",
            webhook_headers={"X-Custom-Header": "test-value"},
        )

    @pytest.mark.asyncio
    async def test_webhook_payload(self, manager, webhook_channel):
        """Test webhook payload structure, target URL, and custom headers."""
        # Only names not already imported at module level; MagicMock comes
        # from the top-of-file import.
        from contextlib import asynccontextmanager
        from unittest.mock import patch

        manager.add_channel(webhook_channel)

        captured_url = None
        captured_payload = None
        captured_headers = None

        @asynccontextmanager
        async def mock_post(url, json=None, headers=None, **kwargs):
            # Stand-in for ``session.post(...)``: record what was sent and
            # yield a successful response.
            nonlocal captured_url, captured_payload, captured_headers
            captured_url = url
            captured_payload = json
            captured_headers = headers

            mock_response = MagicMock()
            mock_response.status = 200
            yield mock_response

        @asynccontextmanager
        async def mock_session(*args, **kwargs):
            # Accept and ignore constructor arguments (e.g. ``timeout=``) so
            # the fake works however the code under test builds its session.
            session = MagicMock()
            session.post = mock_post
            yield session

        # Route the built-in error-spike rule to our channel for this test
        # only, restoring it afterwards in case rule objects are shared.
        rule = manager._rules["error_spike"]
        original_channels = rule.notify_channels
        rule.notify_channels = ["test-webhook"]
        try:
            with patch("aiohttp.ClientSession", mock_session):
                await manager.check_error_spike(error_count=20, time_window_seconds=60)
        finally:
            rule.notify_channels = original_channels

        # Verify target and payload
        assert captured_payload is not None
        assert captured_url == webhook_channel.webhook_url
        for key in ("alert_id", "alert_type", "severity", "title", "message", "timestamp"):
            assert key in captured_payload
        assert captured_payload["alert_type"] == "error_spike"

        # Verify headers
        assert captured_headers is not None
        assert "X-Custom-Header" in captured_headers
        assert captured_headers["X-Custom-Header"] == "test-value"
