"""
Unit tests for the persistence layer.
"""

import pytest
from pathlib import Path
from datetime import datetime, date

from agent_orchestrator.persistence.database import OrchestratorDB
from agent_orchestrator.persistence.models import (
    Agent,
    Task,
    Run,
    HealthSample,
    Approval,
    Decision,
)
from agent_orchestrator.journal.status_packet import (
    StatusPacket,
    TaskArtifacts,
)


class TestOrchestratorDB:
    """Tests for the OrchestratorDB class and related persistence models.

    Every test that takes ``db`` gets an ``OrchestratorDB`` opened on the
    path provided by the ``temp_db_path`` fixture. That fixture is not
    defined in this module (presumably in conftest.py); these tests assume
    it yields a fresh path per test so no state leaks between tests.
    """

    @pytest.fixture
    def db(self, temp_db_path: Path) -> OrchestratorDB:
        """Create a test database at the path supplied by ``temp_db_path``."""
        return OrchestratorDB(temp_db_path)

    # =========================================================================
    # Agent Tests
    # =========================================================================

    def test_create_and_get_agent(self, db: OrchestratorDB):
        """An agent written with create_agent is read back field-for-field."""
        agent = Agent(
            id="test-agent-1",
            tool="claude_code",
            worktree_path="/path/to/worktree",
            branch="feature/test",
            status="idle",
        )
        db.create_agent(agent)

        retrieved = db.get_agent("test-agent-1")
        assert retrieved is not None
        assert retrieved.id == "test-agent-1"
        assert retrieved.tool == "claude_code"
        assert retrieved.worktree_path == "/path/to/worktree"
        assert retrieved.branch == "feature/test"
        assert retrieved.status == "idle"

    def test_get_nonexistent_agent(self, db: OrchestratorDB):
        """get_agent returns None for an unknown agent ID."""
        retrieved = db.get_agent("nonexistent")
        assert retrieved is None

    def test_update_agent_status(self, db: OrchestratorDB):
        """update_agent_status persists the new status."""
        agent = Agent(id="test-agent-2", tool="gemini_cli", status="idle")
        db.create_agent(agent)

        db.update_agent_status("test-agent-2", "running")

        retrieved = db.get_agent("test-agent-2")
        # Guard first so a missing row fails as an assertion, not AttributeError.
        assert retrieved is not None
        assert retrieved.status == "running"

    def test_get_agents_by_status(self, db: OrchestratorDB):
        """get_agents_by_status returns exactly the agents in that status."""
        db.create_agent(Agent(id="agent-1", tool="claude_code", status="idle"))
        db.create_agent(Agent(id="agent-2", tool="gemini_cli", status="running"))
        db.create_agent(Agent(id="agent-3", tool="codex", status="idle"))

        idle_agents = db.get_agents_by_status("idle")
        assert len(idle_agents) == 2
        # Check identity, not just the count, so a wrong filter is caught.
        assert {a.id for a in idle_agents} == {"agent-1", "agent-3"}

        running_agents = db.get_agents_by_status("running")
        assert len(running_agents) == 1
        assert running_agents[0].id == "agent-2"

    def test_delete_agent(self, db: OrchestratorDB):
        """delete_agent removes the agent so get_agent returns None."""
        db.create_agent(Agent(id="to-delete", tool="claude_code"))
        assert db.get_agent("to-delete") is not None

        db.delete_agent("to-delete")
        assert db.get_agent("to-delete") is None

    # =========================================================================
    # Task Tests
    # =========================================================================

    def test_create_and_get_task(self, db: OrchestratorDB):
        """A task written with create_task is read back with its fields."""
        task = Task(
            id="task-1",
            description="Implement user authentication",
            task_type="multi_file_refactor",
            priority=1,
            status="pending",
        )
        db.create_task(task)

        retrieved = db.get_task("task-1")
        assert retrieved is not None
        assert retrieved.description == "Implement user authentication"
        assert retrieved.task_type == "multi_file_refactor"
        assert retrieved.priority == 1

    def test_get_pending_tasks(self, db: OrchestratorDB):
        """get_pending_tasks returns tasks ordered by descending priority."""
        db.create_task(Task(id="t1", description="Low priority", task_type="test", priority=0))
        db.create_task(Task(id="t2", description="High priority", task_type="test", priority=10))
        db.create_task(Task(id="t3", description="Medium priority", task_type="test", priority=5))

        pending = db.get_pending_tasks()
        assert len(pending) == 3
        assert pending[0].id == "t2"  # Highest priority first
        assert pending[1].id == "t3"
        assert pending[2].id == "t1"

    def test_assign_task(self, db: OrchestratorDB):
        """assign_task records the agent and moves the task to 'assigned'."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))
        db.create_task(Task(id="task-1", description="Test", task_type="test"))

        db.assign_task("task-1", "agent-1")

        task = db.get_task("task-1")
        assert task is not None
        assert task.assigned_agent_id == "agent-1"
        assert task.status == "assigned"

    # =========================================================================
    # Run-Until-Done Tests
    # =========================================================================

    def test_task_run_until_done_fields(self, db: OrchestratorDB):
        """run_until_done and max_retries round-trip; attempt_count starts at 0."""
        task = Task(
            id="task-rud",
            description="Retry task",
            task_type="test",
            run_until_done=True,
            max_retries=5,
        )
        db.create_task(task)

        retrieved = db.get_task("task-rud")
        assert retrieved is not None
        assert retrieved.run_until_done is True
        assert retrieved.max_retries == 5
        assert retrieved.attempt_count == 0

    def test_record_task_attempt(self, db: OrchestratorDB):
        """record_task_attempt increments and returns the attempt counter."""
        db.create_task(Task(
            id="task-attempt",
            description="Test",
            task_type="test",
            run_until_done=True,
        ))

        # First attempt
        attempt1 = db.record_task_attempt("task-attempt")
        assert attempt1 == 1

        # Second attempt
        attempt2 = db.record_task_attempt("task-attempt")
        assert attempt2 == 2

        # Verify the counter and timestamp were persisted.
        task = db.get_task("task-attempt")
        assert task is not None
        assert task.attempt_count == 2
        assert task.last_attempt_at is not None

    def test_get_retryable_tasks(self, db: OrchestratorDB):
        """Only failed run_until_done tasks with retries left are retryable."""
        # Create a retryable task
        db.create_task(Task(
            id="task-retry",
            description="Retry me",
            task_type="test",
            status="failed",
            run_until_done=True,
            max_retries=3,
            attempt_count=1,
        ))

        # Create a non-retryable task (not run_until_done)
        db.create_task(Task(
            id="task-no-retry",
            description="No retry",
            task_type="test",
            status="failed",
            run_until_done=False,
        ))

        # Create an exhausted task (attempt_count == max_retries)
        db.create_task(Task(
            id="task-exhausted",
            description="Exhausted",
            task_type="test",
            status="failed",
            run_until_done=True,
            max_retries=3,
            attempt_count=3,
        ))

        retryable = db.get_retryable_tasks()
        assert len(retryable) == 1
        assert retryable[0].id == "task-retry"

    def test_reset_task_for_retry(self, db: OrchestratorDB):
        """reset_task_for_retry returns the task to 'pending' and clears state."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))
        db.create_task(Task(
            id="task-reset",
            description="Reset me",
            task_type="test",
            status="failed",
            assigned_agent_id="agent-1",
            error_message="Previous error",
            run_until_done=True,
        ))

        db.reset_task_for_retry("task-reset")

        task = db.get_task("task-reset")
        assert task is not None
        assert task.status == "pending"
        assert task.assigned_agent_id is None
        assert task.error_message is None

    def test_mark_task_exhausted(self, db: OrchestratorDB):
        """mark_task_exhausted fails the task and annotates the error message."""
        db.create_task(Task(
            id="task-exhaust",
            description="Exhaust me",
            task_type="test",
            status="running",
            run_until_done=True,
        ))

        db.mark_task_exhausted("task-exhaust", "Final failure reason")

        task = db.get_task("task-exhaust")
        assert task is not None
        assert task.status == "failed"
        # Guard so a missing message fails clearly rather than with TypeError
        # from the `in` checks below.
        assert task.error_message is not None
        assert "[Exhausted after max retries]" in task.error_message
        assert "Final failure reason" in task.error_message

    def test_task_can_retry_property(self):
        """Task.can_retry / retries_remaining reflect status and attempts.

        This exercises the model directly and does not touch the database.
        """
        # Can retry: run_until_done, failed, attempts remaining
        task = Task(
            id="t1",
            description="Test",
            task_type="test",
            status="failed",
            run_until_done=True,
            max_retries=3,
            attempt_count=1,
        )
        assert task.can_retry is True
        assert task.retries_remaining == 2

        # Cannot retry: not run_until_done
        task2 = Task(
            id="t2",
            description="Test",
            task_type="test",
            status="failed",
            run_until_done=False,
        )
        assert task2.can_retry is False

        # Cannot retry: exhausted
        task3 = Task(
            id="t3",
            description="Test",
            task_type="test",
            status="failed",
            run_until_done=True,
            max_retries=3,
            attempt_count=3,
        )
        assert task3.can_retry is False
        assert task3.retries_remaining == 0

        # Cannot retry: not failed
        task4 = Task(
            id="t4",
            description="Test",
            task_type="test",
            status="completed",
            run_until_done=True,
            max_retries=3,
            attempt_count=1,
        )
        assert task4.can_retry is False

    # =========================================================================
    # Run Tests
    # =========================================================================

    def test_start_and_complete_run(self, db: OrchestratorDB):
        """complete_run stores outcome, token counts, cost and end time."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        run_id = db.start_run("run-1", "agent-1", "task-1")
        assert run_id == "run-1"

        packet = StatusPacket(
            agent_id="agent-1",
            task_id="task-1",
            status="completed",
            progress_summary="Implemented feature",
            artifacts=TaskArtifacts(
                tokens_input=1000,
                tokens_output=500,
                cost_usd=0.05,
            ),
        )
        db.complete_run("run-1", "success", packet)

        run = db.get_run("run-1")
        assert run is not None
        assert run.outcome == "success"
        assert run.tokens_input == 1000
        assert run.tokens_output == 500
        # Floats compared approximately, consistent with the usage tests.
        assert run.cost_usd == pytest.approx(0.05)
        assert run.ended_at is not None

    def test_get_agent_runs(self, db: OrchestratorDB):
        """get_agent_runs honours its ``limit`` argument."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))
        db.start_run("run-1", "agent-1")
        db.start_run("run-2", "agent-1")
        db.start_run("run-3", "agent-1")

        runs = db.get_agent_runs("agent-1", limit=2)
        assert len(runs) == 2

    # =========================================================================
    # Health Sample Tests
    # =========================================================================

    def test_record_health_sample(self, db: OrchestratorDB):
        """A recorded health sample is returned by get_latest_health_sample."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        sample = HealthSample(
            agent_id="agent-1",
            token_burn_rate=100.0,
            error_count=2,
            consecutive_same_error=1,
            is_stuck=False,
        )
        db.record_health_sample(sample)

        latest = db.get_latest_health_sample("agent-1")
        assert latest is not None
        assert latest.token_burn_rate == 100.0
        assert latest.error_count == 2

    def test_get_stuck_agents(self, db: OrchestratorDB):
        """get_stuck_agents returns only agents whose sample has is_stuck=True."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))
        db.create_agent(Agent(id="agent-2", tool="gemini_cli"))

        # Agent 1 is stuck
        db.record_health_sample(HealthSample(
            agent_id="agent-1",
            is_stuck=True,
            stuck_reason="repeated_error_loop",
        ))

        # Agent 2 is healthy
        db.record_health_sample(HealthSample(
            agent_id="agent-2",
            is_stuck=False,
        ))

        stuck = db.get_stuck_agents()
        assert len(stuck) == 1
        assert stuck[0].agent_id == "agent-1"

    # =========================================================================
    # Approval Tests
    # =========================================================================

    def test_create_and_get_approval(self, db: OrchestratorDB):
        """A pending approval is listed by get_pending_approvals."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        approval = Approval(
            id="approval-1",
            agent_id="agent-1",
            action_type="command",
            target="npm install lodash",
            risk_level="medium",
            status="pending",
        )
        db.create_approval(approval)

        pending = db.get_pending_approvals()
        assert len(pending) == 1
        assert pending[0].target == "npm install lodash"

    def test_update_approval(self, db: OrchestratorDB):
        """Deciding an approval removes it from the pending list."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))
        db.create_approval(Approval(
            id="approval-1",
            agent_id="agent-1",
            action_type="command",
            target="npm install",
            risk_level="medium",
        ))

        db.update_approval("approval-1", "approved", "user", "Looks safe")

        pending = db.get_pending_approvals()
        assert len(pending) == 0  # No longer pending

    # =========================================================================
    # Usage Tests
    # =========================================================================

    def test_update_daily_usage(self, db: OrchestratorDB):
        """Repeated update_daily_usage calls accumulate into one daily total."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        # First usage
        db.update_daily_usage("agent-1", 1000, 500, 0.05)

        usage = db.get_daily_usage("agent-1")
        assert usage["tokens_input"] == 1000
        assert usage["tokens_output"] == 500
        # Approximate float comparison, matching the accumulated check below.
        assert usage["cost_usd"] == pytest.approx(0.05)
        assert usage["requests_count"] == 1

        # Second usage (should accumulate)
        db.update_daily_usage("agent-1", 2000, 1000, 0.10)

        usage = db.get_daily_usage("agent-1")
        assert usage["tokens_input"] == 3000
        assert usage["tokens_output"] == 1500
        assert usage["cost_usd"] == pytest.approx(0.15)
        assert usage["requests_count"] == 2

    def test_get_daily_usage_no_data(self, db: OrchestratorDB):
        """get_daily_usage returns zeroed totals for an agent with no data."""
        usage = db.get_daily_usage("nonexistent-agent")
        assert usage["tokens_input"] == 0
        assert usage["tokens_output"] == 0
        assert usage["cost_usd"] == 0.0

    # =========================================================================
    # MCP Tool Usage Tests
    # =========================================================================

    def test_record_mcp_tool_call(self, db: OrchestratorDB):
        """Tokens from a recorded MCP call show up in today's per-server usage."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        db.record_mcp_tool_call(
            agent_id="agent-1",
            run_id="run-1",
            tool_name="read_file",
            mcp_server="filesystem",
            tokens_used=100,
            duration_ms=50,
            success=True,
        )

        usage = db.get_mcp_usage_today(mcp_server="filesystem")
        assert usage == 100

    # =========================================================================
    # Decision Tests
    # =========================================================================

    def test_record_decision(self, db: OrchestratorDB):
        """A recorded decision is returned by get_recent_decisions."""
        db.create_agent(Agent(id="agent-1", tool="claude_code"))

        decision = Decision(
            id="dec-001",
            agent_id="agent-1",
            decision="Use JWT for authentication",
            rationale="Matches existing patterns",
            reversible=True,
        )
        db.record_decision(decision)

        decisions = db.get_recent_decisions()
        assert len(decisions) == 1
        assert decisions[0].decision == "Use JWT for authentication"

    # =========================================================================
    # Utility Tests
    # =========================================================================

    def test_generate_run_id(self, db: OrchestratorDB):
        """Generated run IDs carry the 'run-' prefix and the agent ID."""
        run_id = db.generate_run_id("agent-1")
        assert run_id.startswith("run-")
        assert "agent-1" in run_id

    def test_generate_approval_id(self, db: OrchestratorDB):
        """Generated approval IDs carry the 'approval-' prefix."""
        approval_id = db.generate_approval_id()
        assert approval_id.startswith("approval-")
