"""
Rate Limit Monitor - Real-time monitoring of CLI agent rate limits.

Provides:
- Background monitoring of rate limit states
- Threshold-based alerting
- Callback system for alerts
- Health status aggregation
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Optional

from .state_models import RateLimitState, CLIStateSnapshot, AgentHealthStatus
from .cli_state_reader import CLIStateReader, get_all_readers

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@dataclass
class RateLimitAlert:
    """A rate-limit notification raised for a single agent.

    Carries the usage percentage observed for the agent, the severity
    label that was crossed, the time the limit resets, and a
    human-readable message. ``created_at`` defaults to the moment the
    alert object is constructed.
    """

    agent_id: str
    agent_type: str
    current_percentage: float
    threshold: str  # "warning", "critical", "exhausted"
    reset_at: datetime
    message: str
    created_at: datetime = field(default_factory=datetime.now)

    @property
    def level(self) -> str:
        """Severity level of this alert; identical to ``threshold``."""
        severity = self.threshold
        return severity

    def to_dict(self) -> dict[str, Any]:
        """Return the alert as a plain dictionary.

        The percentage is rounded to two decimal places and both
        timestamps are rendered with ``datetime.isoformat()``.
        """
        payload: dict[str, Any] = {}
        payload["agent_id"] = self.agent_id
        payload["agent_type"] = self.agent_type
        payload["current_percentage"] = round(self.current_percentage, 2)
        payload["threshold"] = self.threshold
        payload["level"] = self.level
        payload["reset_at"] = self.reset_at.isoformat()
        payload["message"] = self.message
        payload["created_at"] = self.created_at.isoformat()
        return payload


class RateLimitMonitor:
    """
    Monitor rate limits across multiple CLI agents.

    Provides:
    - Real-time monitoring loop
    - Configurable alert thresholds
    - Callback-based alerting
    - State aggregation

    Alerts are recorded in a bounded in-memory history (at most
    ``MAX_ALERT_HISTORY`` entries). The background loop de-duplicates
    alerts per agent with a cooldown window: while the window is open,
    only an alert of higher severity than the previous one is emitted.

    Example:
        monitor = RateLimitMonitor()

        # Add alert callback
        monitor.on_alert(lambda alert: print(f"Alert: {alert.message}"))

        # Start monitoring (async)
        await monitor.start_monitoring(interval_seconds=60)

        # Or get current state
        states = monitor.get_all_states()
    """

    # Default thresholds (percentage)
    WARNING_THRESHOLD = 70.0
    CRITICAL_THRESHOLD = 90.0

    # Maximum number of alerts kept in the in-memory history. Oldest
    # entries are dropped first so a long-running monitor cannot grow
    # without bound.
    MAX_ALERT_HISTORY = 1000

    # Ordering of alert levels, used to tell an escalation apart from a
    # repeat or a de-escalation while an agent is inside its cooldown.
    _SEVERITY_RANK = {"warning": 1, "critical": 2, "exhausted": 3}

    def __init__(
        self,
        warning_threshold: float = 70.0,
        critical_threshold: float = 90.0,
        readers: Optional[dict[str, CLIStateReader]] = None,
    ) -> None:
        """
        Initialize the rate limit monitor.

        Args:
            warning_threshold: Percentage for warning alerts
            critical_threshold: Percentage for critical alerts
            readers: Custom readers keyed by agent id. When omitted
                (``None``) all available readers are used; an explicitly
                passed empty dict is kept as-is.
        """
        # Instance attributes shadow the class-level defaults.
        self.WARNING_THRESHOLD = warning_threshold
        self.CRITICAL_THRESHOLD = critical_threshold

        # Compare against None rather than relying on truthiness, so that
        # an empty dict is not silently replaced by the default readers.
        self._readers = readers if readers is not None else get_all_readers()
        self._callbacks: list[Callable[[RateLimitAlert], None]] = []
        self._running = False
        self._last_states: dict[str, CLIStateSnapshot] = {}
        self._last_alerts: dict[str, RateLimitAlert] = {}  # agent_id -> last alert
        self._alert_history: list[RateLimitAlert] = []  # For get_recent_alerts
        self._alert_cooldown_seconds = 300  # 5 minutes between repeated alerts

    @property
    def warning_threshold(self) -> float:
        """Get warning threshold."""
        return self.WARNING_THRESHOLD

    @property
    def critical_threshold(self) -> float:
        """Get critical threshold."""
        return self.CRITICAL_THRESHOLD

    def register_reader(self, agent_id: str, reader: CLIStateReader) -> None:
        """Register (or replace) the reader used for ``agent_id``."""
        self._readers[agent_id] = reader

    def on_alert(self, callback: Callable[[RateLimitAlert], None]) -> None:
        """Register a callback that is invoked with every fired alert."""
        self._callbacks.append(callback)

    def clear_callbacks(self) -> None:
        """Clear all alert callbacks."""
        self._callbacks.clear()

    def check_agent(self, agent_id: str) -> Optional[RateLimitAlert]:
        """
        Check a specific agent and return an alert if thresholds are exceeded.

        Unlike the background loop, this check does not consult the
        cooldown: whenever the agent is above a threshold the alert is
        fired to the callbacks and returned.

        Args:
            agent_id: The agent to check

        Returns:
            RateLimitAlert if thresholds exceeded, None otherwise
        """
        state = self.get_state(agent_id)
        if not state or not state.rate_limit:
            return None

        pct = state.rate_limit.percentage_used
        threshold = self.get_threshold_status(pct)

        if threshold == "healthy":
            return None

        alert = RateLimitAlert(
            agent_id=agent_id,
            agent_type=state.agent_type,
            current_percentage=pct,
            threshold=threshold,
            reset_at=state.rate_limit.reset_at,
            message=self._format_alert_message(agent_id, state, threshold),
        )

        # Fire the alert
        self._fire_alert(alert)
        return alert

    def get_state(self, agent_id: str) -> Optional[CLIStateSnapshot]:
        """
        Get current state for a specific agent.

        Reads a fresh snapshot when a reader is registered for the agent
        and caches it; otherwise returns the last cached snapshot, if any.
        """
        reader = self._readers.get(agent_id)
        if reader:
            state = reader.get_snapshot(agent_id)
            self._last_states[agent_id] = state
            return state
        return self._last_states.get(agent_id)

    def get_all_states(self) -> dict[str, CLIStateSnapshot]:
        """
        Get current states for all agents.

        A reader that raises is logged and skipped, so one failing agent
        does not prevent the others from being reported.
        """
        states = {}
        for agent_id, reader in self._readers.items():
            try:
                state = reader.get_snapshot(agent_id)
                states[agent_id] = state
                self._last_states[agent_id] = state
            except Exception as e:
                logger.warning("Error getting state for %s: %s", agent_id, e)
        return states

    def get_health_status(self, agent_id: str) -> Optional[AgentHealthStatus]:
        """Get health status for an agent, or None when no state is known."""
        state = self.get_state(agent_id)
        if not state:
            return None

        return self._compute_health_status(state)

    def get_all_health_statuses(self) -> dict[str, AgentHealthStatus]:
        """Get health statuses for all agents."""
        states = self.get_all_states()
        return {
            agent_id: self._compute_health_status(state)
            for agent_id, state in states.items()
        }

    def _compute_health_status(self, state: CLIStateSnapshot) -> AgentHealthStatus:
        """
        Compute health status from state snapshot.

        Conditions are evaluated in priority order: unavailable, rate
        limit exhausted, critical threshold, warning threshold, waiting
        for input, and finally healthy.
        """
        rate_pct = state.rate_limit.percentage_used if state.rate_limit else 0.0
        is_rate_limited = state.rate_limit.is_exhausted if state.rate_limit else False
        is_waiting = state.session.waiting_for_input if state.session else False
        # Without a session there is no recorded activity; infinity is used.
        last_activity = state.session.idle_seconds if state.session else float("inf")

        # Determine status
        if not state.is_installed or not state.state_dir_exists:
            status = "unavailable"
            message = state.availability_reason
        elif is_rate_limited:
            status = "critical"
            message = f"Rate limit exhausted (resets at {state.rate_limit.reset_at})"
        elif rate_pct >= self.CRITICAL_THRESHOLD:
            status = "critical"
            message = f"Rate limit at {rate_pct:.1f}%"
        elif rate_pct >= self.WARNING_THRESHOLD:
            status = "warning"
            message = f"Rate limit at {rate_pct:.1f}%"
        elif is_waiting:
            status = "warning"
            message = f"Waiting for input: {state.session.waiting_state.value}"
        else:
            status = "healthy"
            message = "Available"

        return AgentHealthStatus(
            agent_id=state.agent_id,
            agent_type=state.agent_type,
            status=status,
            rate_limit_percentage=rate_pct,
            is_rate_limited=is_rate_limited,
            is_waiting_for_input=is_waiting,
            last_activity_seconds_ago=last_activity,
            message=message,
        )

    def get_threshold_status(self, percentage: float) -> str:
        """
        Get threshold status for a percentage value.

        Returns:
            "exhausted" at 100 or above, "critical" / "warning" at or
            above the respective configured threshold, else "healthy".
        """
        if percentage >= 100:
            return "exhausted"
        elif percentage >= self.CRITICAL_THRESHOLD:
            return "critical"
        elif percentage >= self.WARNING_THRESHOLD:
            return "warning"
        return "healthy"

    async def start_monitoring(self, interval_seconds: int = 60) -> None:
        """
        Start background monitoring loop.

        Runs until ``stop_monitoring`` clears the running flag. An error
        raised by a single check is logged with its traceback and the
        loop keeps going.

        Args:
            interval_seconds: Check interval in seconds
        """
        self._running = True
        logger.info(
            "Starting rate limit monitoring (interval: %ss)", interval_seconds
        )

        while self._running:
            try:
                await self._check_all_limits()
            except Exception as e:
                logger.exception("Error in monitoring loop: %s", e)

            await asyncio.sleep(interval_seconds)

    def stop_monitoring(self) -> None:
        """Stop the monitoring loop (takes effect at its next iteration)."""
        self._running = False
        logger.info("Stopping rate limit monitoring")

    async def _check_all_limits(self) -> None:
        """Check all limits and fire alerts if needed."""
        states = self.get_all_states()

        for agent_id, state in states.items():
            if not state.rate_limit:
                continue

            alert = self._check_state_for_alert(agent_id, state)
            if alert:
                self._fire_alert(alert)

    def _check_state_for_alert(
        self,
        agent_id: str,
        state: CLIStateSnapshot,
    ) -> Optional[RateLimitAlert]:
        """
        Check if state warrants an alert.

        Returns None when the state has no rate limit data, when usage is
        below the warning threshold, or when the agent is still inside
        its cooldown window and the new level is not more severe than the
        last alert fired for it.
        """
        if not state.rate_limit:
            return None

        pct = state.rate_limit.percentage_used
        threshold = self.get_threshold_status(pct)

        if threshold == "healthy":
            return None

        # Check cooldown
        last_alert = self._last_alerts.get(agent_id)
        if last_alert is not None:
            seconds_since = (datetime.now() - last_alert.created_at).total_seconds()
            if seconds_since < self._alert_cooldown_seconds:
                # Inside the cooldown only an escalation gets through;
                # repeats and de-escalations are suppressed. Unknown
                # levels rank lowest, so they never block a new alert.
                new_rank = self._SEVERITY_RANK.get(threshold, 0)
                last_rank = self._SEVERITY_RANK.get(last_alert.threshold, 0)
                if new_rank <= last_rank:
                    return None

        return RateLimitAlert(
            agent_id=agent_id,
            agent_type=state.agent_type,
            current_percentage=pct,
            threshold=threshold,
            reset_at=state.rate_limit.reset_at,
            message=self._format_alert_message(agent_id, state, threshold),
        )

    def _format_alert_message(
        self,
        agent_id: str,
        state: CLIStateSnapshot,
        threshold: str,
    ) -> str:
        """Format the alert message for the given level."""
        pct = state.rate_limit.percentage_used
        remaining = state.rate_limit.remaining
        reset = state.rate_limit.reset_at.strftime("%H:%M")

        if threshold == "exhausted":
            return f"{agent_id} rate limit exhausted. Resets at {reset}"
        elif threshold == "critical":
            return f"{agent_id} at {pct:.1f}% rate limit ({remaining} remaining). Resets at {reset}"
        else:
            return f"{agent_id} at {pct:.1f}% rate limit ({remaining} remaining)"

    def _fire_alert(self, alert: RateLimitAlert) -> None:
        """
        Record an alert and deliver it to all callbacks.

        The alert becomes the agent's "last alert" (which drives the
        cooldown) and is appended to the history. A callback that raises
        is logged and does not prevent the remaining callbacks from
        running.
        """
        self._last_alerts[alert.agent_id] = alert
        self._alert_history.append(alert)

        # Trim in place so the history never exceeds its cap.
        overflow = len(self._alert_history) - self.MAX_ALERT_HISTORY
        if overflow > 0:
            del self._alert_history[:overflow]

        logger.warning("Rate limit alert: %s", alert.message)

        for callback in self._callbacks:
            try:
                callback(alert)
            except Exception as e:
                logger.exception("Error in alert callback: %s", e)

    def get_recent_alerts(self, limit: int = 50) -> list[RateLimitAlert]:
        """
        Get recent alerts.

        Args:
            limit: Maximum number of alerts to return; a value of zero or
                less yields an empty list.

        Returns:
            List of recent alerts, most recent first
        """
        # A slice starting at -0 is the whole list, so non-positive limits
        # must be handled explicitly.
        if limit <= 0:
            return []
        return list(reversed(self._alert_history[-limit:]))

    def get_health_summary(self) -> dict[str, Any]:
        """Get health summary for all agents (alias of ``get_summary``)."""
        return self.get_summary()

    def get_summary(self) -> dict[str, Any]:
        """
        Get monitoring summary.

        Returns:
            Dictionary with the number of agents, whether the loop is
            running, a count of agents per status, the configured
            thresholds and the per-agent health status dictionaries.
        """
        health_statuses = self.get_all_health_statuses()

        status_counts = {"healthy": 0, "warning": 0, "critical": 0, "unavailable": 0}
        for status in health_statuses.values():
            status_counts[status.status] = status_counts.get(status.status, 0) + 1

        return {
            "total_agents": len(health_statuses),
            "monitoring_active": self._running,
            "status_counts": status_counts,
            "thresholds": {
                "warning": self.WARNING_THRESHOLD,
                "critical": self.CRITICAL_THRESHOLD,
            },
            "agents": {
                agent_id: status.to_dict()
                for agent_id, status in health_statuses.items()
            },
        }


# Module-level monitor instance. Starts as None; get_rate_limit_monitor()
# creates it on first use and set_rate_limit_monitor() replaces or clears it.
_monitor: Optional[RateLimitMonitor] = None


def get_rate_limit_monitor() -> RateLimitMonitor:
    """Return the shared module-level monitor, building it on first use.

    The instance is created with default constructor arguments and then
    stored, so later calls return the same object.
    """
    global _monitor
    monitor = _monitor
    if monitor is None:
        # First access (or the instance was cleared): create and remember it.
        monitor = _monitor = RateLimitMonitor()
    return monitor


def set_rate_limit_monitor(monitor: Optional[RateLimitMonitor]) -> None:
    """Install ``monitor`` as the shared module-level monitor.

    Args:
        monitor: The instance to store, or ``None`` to clear the stored
            instance.
    """
    global _monitor
    _monitor = monitor
