"""
Approval Queue - Centralized approval request management.

This module provides:
- Centralized queue for all approval requests
- Automatic timeout handling
- Priority assignment derived from risk level (requests are currently dispatched immediately, not reordered)
- Statistics and monitoring

Usage:
    from agent_orchestrator.interrupt import ApprovalQueue

    queue = ApprovalQueue(db, cli_handler, async_handler)
    await queue.start()

    # Submit approval request
    response = await queue.submit(
        agent_id="claude-code",
        action_type="command",
        target="git push",
        risk_level="high",
    )
"""

import asyncio
import inspect
import logging
import sys
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Optional, Protocol

from ..persistence.database import OrchestratorDB
from ..persistence.models import Approval
from .cli_handler import ApprovalResponse, ApprovalDecision, CLIInterruptHandler
from .async_handler import AsyncInterruptHandler


logger = logging.getLogger(__name__)


class InterruptHandler(Protocol):
    """Structural interface that every approval handler satisfies.

    Any object exposing a compatible ``request_approval`` coroutine method
    (such as the CLI and async handlers held by ``ApprovalQueue``) matches
    this protocol without inheriting from it.
    """

    async def request_approval(
        self,
        agent_id: str,
        action_type: str,
        target: str,
        risk_level: str,
        context: Optional[dict[str, Any]] = None,
        diff: Optional[str] = None,
    ) -> ApprovalResponse:
        """Ask for a decision on a proposed agent action and return it."""


class HandlerMode(Enum):
    """How the approval queue picks the handler that receives a request.

    CLI   -- always use the CLI handler (blocking).
    ASYNC -- always use the async handler (webhook).
    AUTO  -- choose per request based on the runtime environment.
    """

    CLI = "cli"
    ASYNC = "async"
    AUTO = "auto"


class ApprovalPriority(Enum):
    """Urgency ranking of approval requests; a smaller value is more urgent."""

    CRITICAL = 0  # most urgent: meant to be handled immediately
    HIGH = 1
    MEDIUM = 2
    LOW = 3


@dataclass
class QueuedApproval:
    """State the queue keeps for one outstanding approval request."""

    # Identity of the request and the action it asks permission for.
    approval_id: str
    agent_id: str
    action_type: str
    target: str
    risk_level: str
    # Ordering hint; derived from risk_level unless the submitter overrides it.
    priority: ApprovalPriority
    # When the request was created and the deadline after which it expires.
    created_at: datetime
    timeout_at: datetime
    # Optional extras forwarded to the handler.
    context: dict[str, Any] = field(default_factory=dict)
    diff: Optional[str] = None
    # Resolved with the ApprovalResponse the submitter is waiting on.
    future: Optional[asyncio.Future] = None


@dataclass
class QueueConfig:
    """Tunable settings for the approval queue (durations in seconds)."""

    # Fallback timeout for risk levels missing from ``timeout_by_risk``.
    default_timeout_seconds: int = 5 * 60

    # Timeout per lower-cased risk level. "critical" is intentionally short
    # so an unanswered critical request fails fast.
    timeout_by_risk: dict[str, int] = field(
        default_factory=lambda: {
            "low": 1 * 60,
            "medium": 5 * 60,
            "high": 10 * 60,
            "critical": 1 * 60,
        }
    )

    # Once this many requests are pending, new submissions are auto-rejected.
    max_queue_size: int = 100

    # Strategy for choosing between the CLI and async handlers.
    handler_mode: HandlerMode = HandlerMode.AUTO

    # Period of the background sweep that expires overdue requests.
    cleanup_interval_seconds: int = 60

    # Whether to notify when the queue overflows.
    # NOTE(review): not read by the queue code shown here — confirm consumer.
    notify_on_overflow: bool = True


@dataclass
class QueueStats:
    """Counters and averages describing approval-queue activity."""

    # Counters accumulated since the last reset.
    total_submitted: int = 0
    total_approved: int = 0
    total_rejected: int = 0
    total_timeout: int = 0
    total_skipped: int = 0
    # Requests currently awaiting a decision.
    current_pending: int = 0
    # Mean wait, in seconds, over the most recent decided requests.
    avg_wait_time_seconds: float = 0.0
    # Moment these statistics were (re)created.
    last_reset: datetime = field(default_factory=datetime.now)


class ApprovalQueue:
    """
    Centralized approval queue manager.

    Routes each submitted request to an interrupt handler (CLI or async),
    enforces per-request timeouts, records request status in the database,
    and tracks statistics.

    ``submit`` waits until the request is decided by a handler, times out,
    is cancelled via ``cancel``/``cancel_all``, or the queue is stopped.
    Requests are dispatched to a handler immediately; ``priority`` is
    recorded on each request but does not currently reorder processing.
    """

    # Number of most recent wait times averaged into avg_wait_time_seconds.
    _WAIT_TIME_WINDOW = 100

    def __init__(
        self,
        db: OrchestratorDB,
        cli_handler: Optional[CLIInterruptHandler] = None,
        async_handler: Optional[AsyncInterruptHandler] = None,
        config: Optional[QueueConfig] = None,
    ):
        """
        Initialize the approval queue.

        Args:
            db: Database used to persist approval records
            cli_handler: CLI interrupt handler
            async_handler: Async interrupt handler
            config: Queue configuration; defaults to ``QueueConfig()``
        """
        self.db = db
        self.cli_handler = cli_handler
        self.async_handler = async_handler
        self.config = config or QueueConfig()

        # Queue storage. ``_queue`` is reserved for priority-ordered dispatch
        # and is not used yet; ``_pending`` holds every undecided request.
        self._queue: asyncio.PriorityQueue[tuple[int, float, QueuedApproval]] = asyncio.PriorityQueue()
        self._pending: dict[str, QueuedApproval] = {}

        # Handler tasks keyed by approval id. Holding a strong reference stops
        # the event loop from garbage-collecting a task mid-execution, and lets
        # timeouts/cancellations stop a handler that is still waiting.
        self._processing: dict[str, asyncio.Task] = {}

        # Statistics; the deque keeps only the most recent wait times.
        self._stats = QueueStats()
        self._wait_times: deque[float] = deque(maxlen=self._WAIT_TIME_WINDOW)

        # Background tasks
        self._running = False
        self._processor_task: Optional[asyncio.Task] = None
        self._cleanup_task: Optional[asyncio.Task] = None

        # Callbacks (plain functions or coroutine functions)
        self._on_approval: Optional[Callable[[ApprovalResponse], Any]] = None
        self._on_rejection: Optional[Callable[[ApprovalResponse], Any]] = None
        self._on_timeout: Optional[Callable[[ApprovalResponse], Any]] = None

    async def start(self) -> None:
        """Start the background sweep that expires overdue requests."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

        logger.info("Approval queue started")

    async def stop(self) -> None:
        """
        Stop the queue.

        Cancels the cleanup sweep, rejects every pending request with
        "Queue stopped", and cancels handler tasks that are still running.
        """
        self._running = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        for approval_id, queued in list(self._pending.items()):
            response = ApprovalResponse.reject(
                reason="Queue stopped",
                approval_id=approval_id,
            )
            try:
                self._resolve(queued, response, status="rejected", notes="Queue stopped")
            except Exception:
                # _resolve releases the waiter before writing to the database,
                # so a failure here only affects the persisted record. Keep
                # shutting down the remaining requests.
                logger.exception("Failed to record stop of approval %s", approval_id)

        # Let cancelled handler tasks unwind. Skip the current task in case
        # stop() is invoked from inside a callback run by a handler task.
        current = asyncio.current_task()
        leftovers = [
            task for task in self._processing.values()
            if task is not current and not task.done()
        ]
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)

        logger.info("Approval queue stopped")

    async def submit(
        self,
        agent_id: str,
        action_type: str,
        target: str,
        risk_level: str,
        context: Optional[dict[str, Any]] = None,
        diff: Optional[str] = None,
        priority: Optional[ApprovalPriority] = None,
        timeout_seconds: Optional[int] = None,
    ) -> ApprovalResponse:
        """
        Submit an approval request and wait for its decision.

        The request is persisted, handed to a handler, and resolved by
        whichever comes first: the handler's answer, the timeout,
        ``cancel``/``cancel_all``, or ``stop``.

        Args:
            agent_id: Agent requesting approval
            action_type: Type of action
            target: Target of action
            risk_level: Risk level
            context: Additional context
            diff: Diff for file edits
            priority: Request priority (derived from risk_level if omitted)
            timeout_seconds: Custom timeout (looked up in
                ``config.timeout_by_risk`` by risk level if omitted)

        Returns:
            ApprovalResponse with decision. A full queue yields an immediate
            rejection. If the awaiting task is cancelled, a
            "Request cancelled" rejection is returned instead of propagating
            the cancellation.
        """
        context = context or {}

        # Check queue capacity
        if len(self._pending) >= self.config.max_queue_size:
            logger.error("Approval queue full, auto-rejecting")
            return ApprovalResponse.reject(
                reason="Approval queue full",
                decided_by="system",
            )

        if priority is None:
            priority = self._risk_to_priority(risk_level)

        if timeout_seconds is None:
            timeout_seconds = self.config.timeout_by_risk.get(
                risk_level.lower(),
                self.config.default_timeout_seconds,
            )

        now = datetime.now()
        approval_id = self.db.generate_approval_id()

        queued = QueuedApproval(
            approval_id=approval_id,
            agent_id=agent_id,
            action_type=action_type,
            target=target,
            risk_level=risk_level,
            priority=priority,
            created_at=now,
            timeout_at=now + timedelta(seconds=timeout_seconds),
            context=context,
            diff=diff,
            future=asyncio.get_running_loop().create_future(),
        )

        # Persist first: if the database write fails, nothing has been
        # registered yet, so no orphaned entry occupies a queue slot.
        approval = Approval(
            id=approval_id,
            agent_id=agent_id,
            action_type=action_type,
            target=target,
            risk_level=risk_level,
            status="pending",
        )
        self.db.create_approval(approval)

        self._pending[approval_id] = queued
        self._stats.total_submitted += 1
        self._stats.current_pending = len(self._pending)

        # Dispatch immediately (priority does not reorder processing yet).
        task = asyncio.create_task(self._process_approval(queued))
        self._processing[approval_id] = task
        task.add_done_callback(
            lambda _task, key=approval_id: self._processing.pop(key, None)
        )

        try:
            # shield() stops the timeout from cancelling the shared future, so
            # _expire can still resolve it with a TIMEOUT response.
            response = await asyncio.wait_for(
                asyncio.shield(queued.future), timeout=timeout_seconds
            )
        except asyncio.TimeoutError:
            await self._expire(queued)
            # Either the timeout response set by _expire, or a decision that
            # landed just before the timer fired.
            response = queued.future.result()
        except asyncio.CancelledError:
            # Deliberately converted into a rejection rather than re-raised.
            response = ApprovalResponse.reject(
                reason="Request cancelled",
                approval_id=approval_id,
            )
            self._resolve(queued, response, status="rejected", notes="Request cancelled")
        finally:
            self._pending.pop(approval_id, None)
            self._stats.current_pending = len(self._pending)

        self._update_stats(response, queued)
        return response

    async def _process_approval(self, queued: QueuedApproval) -> None:
        """
        Ask the selected handler for a decision and deliver it.

        The handler's response is delivered, and callbacks fired, only if the
        request is still undecided. A late answer that arrives after a
        timeout or cancellation is therefore ignored. Handler errors become
        a "Processing error" rejection.
        """
        try:
            handler = self._select_handler()
            response = await handler.request_approval(
                agent_id=queued.agent_id,
                action_type=queued.action_type,
                target=queued.target,
                risk_level=queued.risk_level,
                context=queued.context,
                diff=queued.diff,
            )
        except Exception as e:
            logger.exception("Error processing approval %s: %s", queued.approval_id, e)
            if queued.future and not queued.future.done():
                queued.future.set_result(ApprovalResponse.reject(
                    reason=f"Processing error: {str(e)}",
                    approval_id=queued.approval_id,
                ))
            return

        if queued.future and not queued.future.done():
            queued.future.set_result(response)
            await self._trigger_callbacks(response)

    def _resolve(
        self,
        queued: QueuedApproval,
        response: ApprovalResponse,
        status: str,
        notes: str,
    ) -> bool:
        """
        Decide a request on the system's behalf unless it is already decided.

        Releases the waiting submitter, cancels the handler task still
        working on the request, and records ``status`` in the database.

        Returns:
            True if this call decided the request; False if it had already
            been decided, in which case nothing is changed.
        """
        future = queued.future
        if future is None or future.done():
            return False

        future.set_result(response)
        self._cancel_processing(queued.approval_id)
        self.db.update_approval(
            approval_id=queued.approval_id,
            status=status,
            decided_by="system",
            decision_notes=notes,
        )
        return True

    def _cancel_processing(self, approval_id: str) -> None:
        """Cancel the handler task for ``approval_id`` if it is still running."""
        task = self._processing.get(approval_id)
        if task is not None and not task.done():
            task.cancel()

    async def _expire(self, queued: QueuedApproval) -> None:
        """Resolve a request as timed out, if it is still undecided."""
        response = ApprovalResponse.timeout(approval_id=queued.approval_id)
        if self._resolve(
            queued,
            response,
            status="timeout",
            notes="Timed out waiting for response",
        ):
            logger.warning("Approval %s timed out", queued.approval_id)
            await self._trigger_callbacks(response)

    def _select_handler(self) -> InterruptHandler:
        """
        Select the handler according to ``config.handler_mode``.

        Raises:
            RuntimeError: If the required handler is not configured.
        """
        mode = self.config.handler_mode

        if mode == HandlerMode.CLI:
            if self.cli_handler:
                return self.cli_handler
            raise RuntimeError("CLI handler not configured")

        if mode == HandlerMode.ASYNC:
            if self.async_handler:
                return self.async_handler
            raise RuntimeError("Async handler not configured")

        # AUTO: prefer the CLI handler when someone can answer on the terminal.
        if self.cli_handler and self._stdin_is_interactive():
            return self.cli_handler
        if self.async_handler:
            return self.async_handler
        if self.cli_handler:
            return self.cli_handler
        raise RuntimeError("No handlers configured")

    @staticmethod
    def _stdin_is_interactive() -> bool:
        """Return True if stdin exists, is open, and is a terminal."""
        stdin = sys.stdin
        if stdin is None:  # e.g. pythonw or detached processes
            return False
        try:
            return stdin.isatty()
        except ValueError:  # stdin has been closed
            return False

    def _risk_to_priority(self, risk_level: str) -> ApprovalPriority:
        """Map a risk level (case-insensitive) to a priority; default MEDIUM."""
        mapping = {
            "critical": ApprovalPriority.CRITICAL,
            "high": ApprovalPriority.HIGH,
            "medium": ApprovalPriority.MEDIUM,
            "low": ApprovalPriority.LOW,
        }
        return mapping.get(risk_level.lower(), ApprovalPriority.MEDIUM)

    def _update_stats(self, response: ApprovalResponse, queued: QueuedApproval) -> None:
        """Count the decision and fold the request's wait time into the average."""
        decision = response.decision
        if decision == ApprovalDecision.APPROVED:
            self._stats.total_approved += 1
        elif decision == ApprovalDecision.REJECTED:
            self._stats.total_rejected += 1
        elif decision == ApprovalDecision.TIMEOUT:
            self._stats.total_timeout += 1
        elif decision == ApprovalDecision.SKIPPED:
            self._stats.total_skipped += 1

        # The bounded deque discards the oldest sample beyond the window.
        wait_time = (datetime.now() - queued.created_at).total_seconds()
        self._wait_times.append(wait_time)
        self._stats.avg_wait_time_seconds = sum(self._wait_times) / len(self._wait_times)

    async def _trigger_callbacks(self, response: ApprovalResponse) -> None:
        """
        Invoke the callback registered for ``response.decision``, if any.

        Callbacks may be plain functions or coroutine functions. Errors are
        logged and never propagate to the caller.
        """
        decision = response.decision
        if decision == ApprovalDecision.APPROVED:
            callback = self._on_approval
        elif decision == ApprovalDecision.REJECTED:
            callback = self._on_rejection
        elif decision == ApprovalDecision.TIMEOUT:
            callback = self._on_timeout
        else:
            callback = None

        if callback is None:
            return

        try:
            result = callback(response)
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            logger.exception("Callback error: %s", e)

    async def _cleanup_loop(self) -> None:
        """Periodically expire overdue requests while the queue is running."""
        while self._running:
            try:
                await asyncio.sleep(self.config.cleanup_interval_seconds)
                await self._check_timeouts()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception("Cleanup error: %s", e)

    async def _check_timeouts(self) -> None:
        """
        Expire every pending request whose deadline has passed.

        This is a backstop for the per-request timer in ``submit``.
        """
        now = datetime.now()
        # Snapshot first: _expire awaits callbacks, during which _pending may change.
        overdue = [q for q in self._pending.values() if now >= q.timeout_at]
        for queued in overdue:
            await self._expire(queued)

    # =========================================================================
    # Callback registration
    # =========================================================================

    def on_approval(self, callback: Callable[[ApprovalResponse], Any]) -> None:
        """Register a callback (sync or async) for approvals."""
        self._on_approval = callback

    def on_rejection(self, callback: Callable[[ApprovalResponse], Any]) -> None:
        """Register a callback (sync or async) for rejections."""
        self._on_rejection = callback

    def on_timeout(self, callback: Callable[[ApprovalResponse], Any]) -> None:
        """Register a callback (sync or async) for timeouts."""
        self._on_timeout = callback

    # =========================================================================
    # Query methods
    # =========================================================================

    def get_stats(self) -> QueueStats:
        """Return the live statistics object with current_pending refreshed."""
        self._stats.current_pending = len(self._pending)
        return self._stats

    def get_pending_count(self) -> int:
        """Get number of pending approvals."""
        return len(self._pending)

    def get_pending_approvals(self) -> list[dict[str, Any]]:
        """Get a JSON-friendly summary of each pending approval."""
        return [
            {
                "approval_id": q.approval_id,
                "agent_id": q.agent_id,
                "action_type": q.action_type,
                "target": q.target,
                "risk_level": q.risk_level,
                "priority": q.priority.value,
                "created_at": q.created_at.isoformat(),
                "timeout_at": q.timeout_at.isoformat(),
            }
            for q in self._pending.values()
        ]

    def cancel(self, approval_id: str) -> bool:
        """
        Cancel a pending approval request.

        The waiting ``submit`` call receives a "Cancelled" rejection, the
        handler task is cancelled, and the database record is marked
        rejected.

        Args:
            approval_id: The approval to cancel

        Returns:
            True if this call cancelled the request; False if it is unknown
            or had already been decided.
        """
        queued = self._pending.get(approval_id)
        if queued is None:
            return False

        response = ApprovalResponse.reject(
            reason="Cancelled",
            decided_by="system",
            approval_id=approval_id,
        )
        cancelled = self._resolve(
            queued, response, status="rejected", notes="Cancelled by queue"
        )

        self._pending.pop(approval_id, None)
        self._stats.current_pending = len(self._pending)
        return cancelled

    def cancel_all(self, agent_id: Optional[str] = None) -> int:
        """
        Cancel all pending approvals, optionally filtered by agent.

        Returns:
            Number of approvals cancelled
        """
        cancelled = 0

        for approval_id, queued in list(self._pending.items()):
            if agent_id is None or queued.agent_id == agent_id:
                if self.cancel(approval_id):
                    cancelled += 1

        return cancelled

    def reset_stats(self) -> None:
        """Reset queue statistics."""
        self._stats = QueueStats()
        self._wait_times = deque(maxlen=self._WAIT_TIME_WINDOW)
