"""
CLI Usage Tracker - Enhanced tracking for interactive CLI agents.

Provides:
- Per-minute rate limiting for CLI agents
- Usage recording integration with BudgetEnforcer
- Agent prioritization based on current usage
- Real-time usage monitoring and alerts

Usage:
    from agent_orchestrator.budget.cli_usage_tracker import CLIUsageTracker

    tracker = CLIUsageTracker(budget_enforcer, db)

    # Before execution
    allowed, reason = await tracker.check_rate_limit("claude-code")

    # After execution
    await tracker.record_cli_usage("claude-code", response)
"""

import asyncio
import json
import logging
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Optional

from .agent_budget import BudgetEnforcer, BudgetStatus
from ..persistence.database import OrchestratorDB
from ..adapters.base import AgentResponse

logger = logging.getLogger(__name__)


@dataclass
class CLISession:
    """Bookkeeping record for a single CLI agent session.

    Holds the identifiers of the session plus running totals (requests,
    tokens, cost). All counters default to zero and the session starts
    out active.
    """

    # Identifier of this session.
    session_id: str
    # Agent the session belongs to.
    agent_id: str
    # Creation time; defaults to ``datetime.now()`` at construction.
    started_at: datetime = field(default_factory=datetime.now)
    # Running totals for the session.
    request_count: int = 0
    total_tokens: int = 0
    total_cost_usd: float = 0.0
    # Whether the session is still open.
    is_active: bool = True


@dataclass
class RateLimitConfig:
    """Per-agent rate limit settings.

    Only ``agent_id`` is required; every limit has a default.
    """

    # Agent these limits apply to.
    agent_id: str

    # Maximum requests per minute
    rpm_limit: int = 60

    # Maximum tokens per minute
    tpm_limit: int = 100_000

    # Burst allowance (short-term spike)
    burst_allowance: int = 10  # Extra requests allowed on top of rpm_limit

    # Length of the cooldown applied after repeatedly hitting the limit (seconds)
    cooldown_seconds: int = 60


@dataclass
class RateLimitState:
    """Mutable rate limit bookkeeping for one agent.

    Each instance gets its own history deques (built by ``default_factory``).
    """

    # Agent this state belongs to.
    agent_id: str

    # Recent request timestamps.
    # NOTE(review): the deque keeps at most 100 entries, so a count over this
    # history can never exceed 100 - confirm this is enough for the
    # configured limits.
    recent_requests: deque = field(default_factory=lambda: deque(maxlen=100))

    # Recent token usage as (timestamp, tokens) pairs; also capped at 100 entries.
    recent_tokens: deque = field(default_factory=lambda: deque(maxlen=100))

    # End of the current cooldown, or None when no cooldown has been set
    cooldown_until: Optional[datetime] = None

    # Number of consecutive limit hits
    consecutive_limit_hits: int = 0


@dataclass
class UsagePriority:
    """Routing priority computed for one agent from its current usage.

    All fields are required; there are no defaults.
    """

    # Agent the priority was computed for.
    agent_id: str
    # Overall score.
    priority_score: float  # 0-100, higher is better
    # Short human-readable explanation of the score.
    reason: str
    # Inputs that went into the score.
    budget_remaining_pct: float
    rate_limit_headroom: float
    error_rate: float
    # Time of the agent's most recent activity.
    last_activity: datetime


class CLIUsageTracker:
    """
    Enhanced usage tracker for CLI agents.

    Features:
    - Per-minute rate limiting (RPM and TPM)
    - Budget integration and enforcement
    - Agent prioritization for routing
    - Usage monitoring and alerts
    - Session persistence for state recovery
    """

    def __init__(
        self,
        budget_enforcer: BudgetEnforcer,
        db: OrchestratorDB,
        rate_configs: Optional[dict[str, RateLimitConfig]] = None,
        auto_restore_state: bool = True,
    ):
        """
        Initialize CLI usage tracker.

        Args:
            budget_enforcer: Budget enforcer instance
            db: Database for persistence
            rate_configs: Per-agent rate limit configs. The mapping is copied,
                so the built-in defaults added below never leak into the
                caller's dict.
            auto_restore_state: Whether to restore state from DB on init
        """
        self.budget_enforcer = budget_enforcer
        self.db = db

        # Rate limit configurations. Copy instead of aliasing the argument:
        # _init_default_configs() inserts entries into this dict.
        self._rate_configs: dict[str, RateLimitConfig] = (
            dict(rate_configs) if rate_configs else {}
        )
        self._rate_states: dict[str, RateLimitState] = {}

        # Session tracking (one active session per agent id)
        self._sessions: dict[str, CLISession] = {}

        # Default rate configs for known CLI agents
        self._init_default_configs()

        # Alert thresholds, as a fraction of the RPM limit
        self.high_usage_threshold = 0.85  # 85% of limit
        self.critical_usage_threshold = 0.95  # 95% of limit

        # Restore state from database if enabled
        if auto_restore_state:
            self._restore_all_state()

    def _init_default_configs(self) -> None:
        """Fill in built-in rate limits for known agents that have no config yet."""
        # (agent id, requests per minute, tokens per minute)
        builtin_limits = (
            ("claude-code", 50, 100_000),  # Conservative for CLI
            ("gemini-cli", 60, 150_000),  # Free tier: 60 RPM
            ("codex-cli", 50, 80_000),
        )

        for agent_id, rpm, tpm in builtin_limits:
            # Explicitly supplied configs always win over the defaults.
            if agent_id in self._rate_configs:
                continue
            self._rate_configs[agent_id] = RateLimitConfig(
                agent_id=agent_id,
                rpm_limit=rpm,
                tpm_limit=tpm,
            )

    def _get_rate_state(self, agent_id: str) -> RateLimitState:
        """Get or create rate limit state for an agent."""
        if agent_id not in self._rate_states:
            self._rate_states[agent_id] = RateLimitState(agent_id=agent_id)
        return self._rate_states[agent_id]

    def set_rate_config(self, agent_id: str, config: RateLimitConfig) -> None:
        """Install (or replace) the rate limit configuration used for an agent."""
        self._rate_configs.update({agent_id: config})
        message = f"Set rate limit for {agent_id}: {config.rpm_limit} RPM, {config.tpm_limit} TPM"
        logger.info(message)

    async def check_rate_limit(
        self,
        agent_id: str,
        estimated_tokens: int = 0,
    ) -> tuple[bool, Optional[str]]:
        """
        Check if agent is within rate limits.

        Args:
            agent_id: Agent to check
            estimated_tokens: Estimated tokens for request

        Returns:
            Tuple of (allowed, reason_if_denied)
        """
        config = self._rate_configs.get(agent_id)
        if not config:
            # No rate limit configured, allow
            return True, None

        state = self._get_rate_state(agent_id)
        now = datetime.now()

        # Check cooldown
        if state.cooldown_until and now < state.cooldown_until:
            remaining = (state.cooldown_until - now).total_seconds()
            return False, f"Rate limit cooldown: {remaining:.0f}s remaining"

        # Clean up old requests (older than 1 minute)
        cutoff = now - timedelta(minutes=1)

        # Count requests in last minute
        recent_count = sum(1 for ts in state.recent_requests if ts > cutoff)

        # Count tokens in last minute
        recent_token_sum = sum(
            tokens for ts, tokens in state.recent_tokens
            if ts > cutoff
        )

        # Check RPM limit
        effective_rpm = config.rpm_limit + config.burst_allowance
        if recent_count >= effective_rpm:
            state.consecutive_limit_hits += 1

            # Apply cooldown if consistently hitting limits
            if state.consecutive_limit_hits >= 3:
                state.cooldown_until = now + timedelta(seconds=config.cooldown_seconds)
                logger.warning(
                    f"Agent {agent_id} hit rate limit {state.consecutive_limit_hits} times, "
                    f"applying {config.cooldown_seconds}s cooldown"
                )

            return False, f"RPM limit exceeded: {recent_count}/{config.rpm_limit}"

        # Check TPM limit
        if recent_token_sum + estimated_tokens > config.tpm_limit:
            return False, f"TPM limit exceeded: {recent_token_sum + estimated_tokens}/{config.tpm_limit}"

        # Reset consecutive hits if we're within limits
        state.consecutive_limit_hits = 0

        return True, None

    async def record_cli_usage(
        self,
        agent_id: str,
        response: AgentResponse,
    ) -> None:
        """
        Account for one finished CLI agent request.

        Adds the request to the agent's rate limit history, updates the
        active session (if there is one), forwards the usage to the
        BudgetEnforcer, runs the usage alert checks, and saves the rate
        limit state on every 10th request of a session.

        Args:
            agent_id: Agent that made the request
            response: Response from the agent
        """
        recorded_at = datetime.now()

        # Rate limit history: every request counts, tokens only when non-zero.
        rate_state = self._get_rate_state(agent_id)
        rate_state.recent_requests.append(recorded_at)

        tokens_used = response.tokens_input + response.tokens_output
        if tokens_used > 0:
            rate_state.recent_tokens.append((recorded_at, tokens_used))

        # Session totals, when a session is open for this agent.
        active_session = self._sessions.get(agent_id)
        if active_session:
            active_session.request_count += 1
            active_session.total_tokens += tokens_used
            active_session.total_cost_usd += response.cost

        # Budget accounting.
        self.budget_enforcer.record_usage(
            agent_id=agent_id,
            input_tokens=response.tokens_input,
            output_tokens=response.tokens_output,
            cost=response.cost,
            success=response.success,
        )

        # Warn about high budget / rate usage.
        await self._check_usage_alerts(agent_id)

        # Checkpoint the rate limit state every 10 session requests.
        if active_session and active_session.request_count % 10 == 0:
            self._save_state(agent_id)

        logger.debug(
            f"Recorded usage for {agent_id}: "
            f"{response.tokens_input} in + {response.tokens_output} out = "
            f"${response.cost:.4f}"
        )

    async def _check_usage_alerts(self, agent_id: str) -> None:
        """
        Log budget and rate limit alerts for an agent.

        Budget: logs a warning when the enforcer reports WARNING and an
        error when it reports EXCEEDED. Rate limit: compares the requests
        of the trailing minute against the configured RPM limit and logs
        when the ratio reaches ``high_usage_threshold`` (info) or
        ``critical_usage_threshold`` (warning).

        Args:
            agent_id: Agent to check
        """
        # Check budget status
        budget_check = self.budget_enforcer.check_budget(agent_id)

        if budget_check.status == BudgetStatus.WARNING:
            logger.warning(
                f"Agent {agent_id} usage WARNING: "
                f"Tokens: {budget_check.input_token_percentage:.0%} in, "
                f"{budget_check.output_token_percentage:.0%} out, "
                f"Cost: {budget_check.cost_percentage:.0%}"
            )
        elif budget_check.status == BudgetStatus.EXCEEDED:
            logger.error(
                f"Agent {agent_id} EXCEEDED budget: {budget_check.reason}"
            )

        # Check rate limit headroom. A non-positive RPM limit has no usable
        # utilisation ratio, so skip it instead of dividing by zero
        # (get_usage_summary reports 0% for the same case).
        config = self._rate_configs.get(agent_id)
        if not config or config.rpm_limit <= 0:
            return

        state = self._get_rate_state(agent_id)
        cutoff = datetime.now() - timedelta(minutes=1)

        recent_count = sum(1 for ts in state.recent_requests if ts > cutoff)
        rpm_pct = recent_count / config.rpm_limit

        if rpm_pct >= self.critical_usage_threshold:
            logger.warning(
                f"Agent {agent_id} rate limit CRITICAL: "
                f"{recent_count}/{config.rpm_limit} RPM ({rpm_pct:.0%})"
            )
        elif rpm_pct >= self.high_usage_threshold:
            logger.info(
                f"Agent {agent_id} rate limit HIGH: "
                f"{recent_count}/{config.rpm_limit} RPM ({rpm_pct:.0%})"
            )

    def get_agent_priority(self, agent_id: str) -> UsagePriority:
        """
        Calculate priority score for an agent based on current usage.

        Higher score = better candidate for routing. The score is a weighted
        sum (50% budget headroom, 30% rate limit headroom, 20% success rate)
        and is forced to 0 when the budget is exceeded, the agent is paused,
        or a rate limit cooldown is active.

        Each component is clamped to [0, 1] before weighting so the score
        always stays within 0-100, even when usage overshoots a limit (for
        example while the burst allowance is in use). The headroom values
        reported on the result are left unclamped.

        Args:
            agent_id: Agent to evaluate

        Returns:
            UsagePriority with score and details
        """
        now = datetime.now()

        # Get budget info
        budget_check = self.budget_enforcer.check_budget(agent_id)
        usage = self.budget_enforcer._get_daily_usage(agent_id)

        # Budget headroom is driven by the most exhausted dimension.
        budget_remaining_pct = 1.0 - max(
            budget_check.input_token_percentage,
            budget_check.output_token_percentage,
            budget_check.cost_percentage,
        )

        # Rate limit headroom over the trailing minute. Agents without a
        # config, or with a non-positive RPM limit (no usable ratio), keep
        # full headroom; this also avoids dividing by zero.
        rate_limit_headroom = 1.0
        config = self._rate_configs.get(agent_id)
        if config and config.rpm_limit > 0:
            window_state = self._get_rate_state(agent_id)
            cutoff = now - timedelta(minutes=1)
            recent_count = sum(1 for ts in window_state.recent_requests if ts > cutoff)
            rate_limit_headroom = 1.0 - (recent_count / config.rpm_limit)

        # Share of failed requests among all recorded requests.
        error_rate = 0.0
        total_requests = usage.successful_requests + usage.failed_requests
        if total_requests > 0:
            error_rate = usage.failed_requests / total_requests

        def _unit(value: float) -> float:
            """Clamp a ratio into [0, 1]."""
            return min(1.0, max(0.0, value))

        # Weight: 50% budget, 30% rate limit, 20% error rate
        priority_score = (
            _unit(budget_remaining_pct) * 50
            + _unit(rate_limit_headroom) * 30
            + (1.0 - _unit(error_rate)) * 20
        )

        # Hard overrides: any of these makes the agent unroutable.
        rate_state = self._rate_states.get(agent_id)
        in_cooldown = bool(
            rate_state
            and rate_state.cooldown_until
            and now < rate_state.cooldown_until
        )

        if budget_check.status == BudgetStatus.EXCEEDED:
            priority_score = 0.0
            reason = "Budget exceeded"
        elif budget_check.status == BudgetStatus.PAUSED:
            priority_score = 0.0
            reason = "Agent paused"
        elif in_cooldown:
            priority_score = 0.0
            reason = "Rate limit cooldown"
        else:
            reason = f"Score: {priority_score:.1f}/100"

        return UsagePriority(
            agent_id=agent_id,
            priority_score=priority_score,
            reason=reason,
            budget_remaining_pct=budget_remaining_pct,
            rate_limit_headroom=rate_limit_headroom,
            error_rate=error_rate,
            last_activity=usage.last_request_at or now,
        )

    def rank_agents(self, agent_ids: list[str]) -> list[UsagePriority]:
        """
        Rank agents by priority for routing decisions.

        Args:
            agent_ids: List of agent IDs to rank

        Returns:
            List of UsagePriority sorted by score (highest first)
        """
        priorities = [self.get_agent_priority(agent_id) for agent_id in agent_ids]
        return sorted(priorities, key=lambda p: p.priority_score, reverse=True)

    def get_usage_summary(self, agent_id: str) -> dict[str, Any]:
        """
        Get comprehensive usage summary for an agent.

        Args:
            agent_id: Agent to summarize

        Returns:
            Dictionary with usage details
        """
        # Budget info
        budget_summary = self.budget_enforcer.get_usage_summary(agent_id)

        # Rate limit info
        rate_info = {"configured": False}
        config = self._rate_configs.get(agent_id)
        if config:
            state = self._get_rate_state(agent_id)
            now = datetime.now()
            cutoff = now - timedelta(minutes=1)

            recent_count = sum(1 for ts in state.recent_requests if ts > cutoff)
            recent_tokens = sum(
                tokens for ts, tokens in state.recent_tokens
                if ts > cutoff
            )

            rate_info = {
                "configured": True,
                "rpm_limit": config.rpm_limit,
                "rpm_current": recent_count,
                "rpm_percentage": recent_count / config.rpm_limit if config.rpm_limit > 0 else 0,
                "tpm_limit": config.tpm_limit,
                "tpm_current": recent_tokens,
                "tpm_percentage": recent_tokens / config.tpm_limit if config.tpm_limit > 0 else 0,
                "cooldown_active": state.cooldown_until and now < state.cooldown_until,
                "consecutive_hits": state.consecutive_limit_hits,
            }

        # Priority info
        priority = self.get_agent_priority(agent_id)

        return {
            "agent_id": agent_id,
            "timestamp": datetime.now().isoformat(),
            "budget": budget_summary,
            "rate_limits": rate_info,
            "priority": {
                "score": priority.priority_score,
                "reason": priority.reason,
                "budget_headroom": priority.budget_remaining_pct,
                "rate_headroom": priority.rate_limit_headroom,
                "error_rate": priority.error_rate,
            },
        }

    def reset_rate_limit(self, agent_id: str) -> None:
        """Manually reset rate limit state for an agent (for testing/admin)."""
        if agent_id in self._rate_states:
            self._rate_states[agent_id] = RateLimitState(agent_id=agent_id)
            logger.info(f"Reset rate limit state for {agent_id}")

    # =========================================================================
    # Session Management
    # =========================================================================

    def start_session(self, agent_id: str) -> CLISession:
        """
        Open a fresh CLI session for an agent.

        A session that is already registered for the agent is ended first.
        Failing to persist the new session is logged and otherwise ignored;
        the in-memory session is returned either way.

        Args:
            agent_id: Agent to start session for

        Returns:
            New CLISession object
        """
        # Only one session per agent: close the previous one first.
        if agent_id in self._sessions:
            self.end_session(agent_id)

        session_id = self.db.generate_session_id(agent_id)
        session = self._sessions[agent_id] = CLISession(
            session_id=session_id,
            agent_id=agent_id,
        )

        # Best-effort persistence.
        try:
            self.db.create_cli_session(session_id, agent_id)
            logger.info(f"Started CLI session {session_id} for {agent_id}")
        except Exception as e:
            logger.warning(f"Failed to persist session start: {e}")

        return session

    def end_session(self, agent_id: str) -> Optional[CLISession]:
        """
        End the current CLI session for an agent.

        The final rate limit state is saved while the session is still
        registered, so the saved record carries this session's id (saving
        after removal would always store ``session_id=None``). The session
        is then removed, marked inactive, and its totals are written to the
        database. A database failure is logged and otherwise ignored.

        Args:
            agent_id: Agent whose session to end

        Returns:
            The ended session, or None if no active session
        """
        session = self._sessions.get(agent_id)
        if session is None:
            return None

        session.is_active = False

        # Save final state before unregistering; always unregister afterwards,
        # even if saving raises unexpectedly.
        try:
            self._save_state(agent_id)
        finally:
            self._sessions.pop(agent_id, None)

        # Update database
        try:
            self.db.end_cli_session(
                session_id=session.session_id,
                request_count=session.request_count,
                total_tokens=session.total_tokens,
                total_cost_usd=session.total_cost_usd,
            )
            logger.info(
                f"Ended CLI session {session.session_id}: "
                f"{session.request_count} requests, "
                f"{session.total_tokens} tokens, "
                f"${session.total_cost_usd:.4f}"
            )
        except Exception as e:
            logger.warning(f"Failed to persist session end: {e}")

        return session

    def get_session(self, agent_id: str) -> Optional[CLISession]:
        """Get the current session for an agent."""
        return self._sessions.get(agent_id)

    def get_or_start_session(self, agent_id: str) -> CLISession:
        """Get existing session or start a new one."""
        if agent_id not in self._sessions:
            return self.start_session(agent_id)
        return self._sessions[agent_id]

    # =========================================================================
    # State Persistence
    # =========================================================================

    def _save_state(self, agent_id: str) -> None:
        """
        Persist an agent's rate limit state.

        Does nothing when the agent has no state. Only history from the last
        five minutes is written, serialized as JSON; the session id of the
        agent's registered session is attached when there is one. Database
        errors are logged and swallowed.

        Args:
            agent_id: Agent whose state to save
        """
        state = self._rate_states.get(agent_id)
        if not state:
            return

        active_session = self._sessions.get(agent_id)
        session_id = active_session.session_id if active_session else None

        # Keep last 5 minutes of history.
        horizon = datetime.now() - timedelta(minutes=5)

        request_stamps = []
        for ts in state.recent_requests:
            if ts > horizon:
                request_stamps.append(ts.isoformat())

        token_rows = []
        for ts, tokens in state.recent_tokens:
            if ts > horizon:
                token_rows.append([ts.isoformat(), tokens])

        try:
            self.db.save_rate_limit_state(
                agent_id=agent_id,
                session_id=session_id,
                recent_request_timestamps=json.dumps(request_stamps),
                recent_token_usage=json.dumps(token_rows),
                cooldown_until=state.cooldown_until,
                consecutive_limit_hits=state.consecutive_limit_hits,
            )
            logger.debug(f"Saved rate limit state for {agent_id}")
        except Exception as e:
            logger.warning(f"Failed to save rate limit state for {agent_id}: {e}")

    def _restore_state(self, agent_id: str) -> bool:
        """
        Restore rate limit state for an agent from the database.

        Restored history is appended to the agent's in-memory state.
        Individual malformed entries (bad timestamps, wrongly shaped token
        rows, an unparsable cooldown) are skipped and logged instead of
        failing the whole restore. A missing or NULL consecutive-hit
        counter is restored as 0 so later increments cannot fail.

        Args:
            agent_id: Agent whose state to restore

        Returns:
            True if state was restored, False if nothing was saved or the
            load failed
        """
        try:
            saved_state = self.db.load_rate_limit_state(agent_id)
            if not saved_state:
                return False

            state = self._get_rate_state(agent_id)

            # Parse recent requests
            raw_requests = saved_state.get("recent_request_timestamps")
            if raw_requests:
                for ts_str in json.loads(raw_requests):
                    try:
                        state.recent_requests.append(datetime.fromisoformat(ts_str))
                    except (ValueError, TypeError):
                        logger.debug(
                            f"Skipping malformed request timestamp for {agent_id}: {ts_str!r}"
                        )

            # Parse recent tokens: each item is expected to be [timestamp, tokens]
            raw_tokens = saved_state.get("recent_token_usage")
            if raw_tokens:
                for item in json.loads(raw_tokens):
                    try:
                        ts = datetime.fromisoformat(item[0])
                        tokens = int(item[1])
                    except (ValueError, TypeError, IndexError, KeyError):
                        logger.debug(
                            f"Skipping malformed token usage entry for {agent_id}: {item!r}"
                        )
                        continue
                    state.recent_tokens.append((ts, tokens))

            # Restore cooldown (stored either as an ISO string or a datetime)
            cooldown = saved_state.get("cooldown_until")
            if cooldown:
                if isinstance(cooldown, str):
                    try:
                        state.cooldown_until = datetime.fromisoformat(cooldown)
                    except ValueError:
                        logger.warning(
                            f"Ignoring malformed cooldown for {agent_id}: {cooldown!r}"
                        )
                elif isinstance(cooldown, datetime):
                    state.cooldown_until = cooldown

            # Restore consecutive hits; `or 0` covers a stored NULL/None.
            state.consecutive_limit_hits = int(
                saved_state.get("consecutive_limit_hits") or 0
            )

            logger.info(f"Restored rate limit state for {agent_id}")
            return True

        except Exception as e:
            logger.warning(f"Failed to restore rate limit state for {agent_id}: {e}")
            return False

    def _restore_all_state(self) -> None:
        """Restore state for all known agents from the database."""
        for agent_id in self._rate_configs.keys():
            self._restore_state(agent_id)

    def save_all_state(self) -> None:
        """Persist the rate limit state of every tracked agent (call on shutdown)."""
        for tracked_agent in self._rate_states:
            self._save_state(tracked_agent)

        logger.info(f"Saved state for {len(self._rate_states)} agents")

    def end_all_sessions(self) -> int:
        """
        End all active sessions (call on shutdown).

        Returns:
            Number of sessions ended
        """
        agent_ids = list(self._sessions.keys())
        for agent_id in agent_ids:
            self.end_session(agent_id)
        return len(agent_ids)

    def get_session_history(
        self,
        agent_id: str,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """
        Get recent session history for an agent.

        Args:
            agent_id: Agent to get history for
            limit: Maximum number of sessions to return

        Returns:
            List of session records
        """
        try:
            return self.db.get_cli_session_history(agent_id, limit)
        except Exception as e:
            logger.warning(f"Failed to get session history: {e}")
            return []


# Module-level tracker instance shared through the accessors below.
# Stays None until set_cli_usage_tracker() is called.
_cli_usage_tracker: Optional[CLIUsageTracker] = None


def get_cli_usage_tracker() -> Optional[CLIUsageTracker]:
    """Return the global CLI usage tracker, or None if none has been set."""
    return _cli_usage_tracker


def set_cli_usage_tracker(tracker: CLIUsageTracker) -> None:
    """Store ``tracker`` as the global CLI usage tracker, replacing any previous one."""
    global _cli_usage_tracker
    _cli_usage_tracker = tracker
