"""
CLI Agent Usage Tracker

Tracks session-based usage for interactive CLI agents (Claude Code, Gemini CLI, Codex).
Unlike API agents that have token-based billing, CLI agents have session/weekly limits.
"""

import json
import logging
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class AgentAvailability(Enum):
    """Availability status for CLI agents.

    Each member's value is the lowercase string reported in usage stats
    (``availability.value``).
    """

    # Ready to accept tasks.
    AVAILABLE = "available"
    # Currently processing a task.
    BUSY = "busy"
    # Session usage at or above the warning threshold (80% by default).
    LIMITED = "limited"
    # Temporarily rate limited; waiting out a cooldown.
    RATE_LIMITED = "rate_limited"
    # Session limit reached.
    EXHAUSTED = "exhausted"
    # Not responding or errored.
    UNAVAILABLE = "unavailable"
    # Manually paused.
    PAUSED = "paused"


@dataclass
class SessionUsage:
    """Request, token and timing counters for a single agent session.

    Every recorded request bumps ``total_requests`` and exactly one of the
    outcome buckets: rate-limited, successful, or failed.
    """

    session_id: str
    agent_id: str
    started_at: datetime = field(default_factory=datetime.now)

    # Request counters (total plus one bucket per outcome)
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    rate_limited_requests: int = 0

    # Token estimates, summed from the values callers report
    estimated_input_tokens: int = 0
    estimated_output_tokens: int = 0

    # Timing
    total_processing_time_seconds: float = 0.0
    last_request_at: Optional[datetime] = None

    # Limit signals
    limit_warnings: int = 0
    hit_session_limit: bool = False
    limit_hit_at: Optional[datetime] = None

    def record_request(
        self,
        success: bool,
        input_tokens: int = 0,
        output_tokens: int = 0,
        processing_time: float = 0.0,
        rate_limited: bool = False,
        limit_warning: bool = False,
    ) -> None:
        """Fold one request into this session's counters.

        A rate-limited request lands in the rate-limited bucket whatever
        ``success`` says; otherwise ``success`` selects successful vs failed.
        Tokens and processing time are accumulated in every case.
        """
        self.total_requests += 1
        self.last_request_at = datetime.now()

        if not rate_limited and success:
            self.successful_requests += 1
        elif not rate_limited:
            self.failed_requests += 1
        else:
            self.rate_limited_requests += 1

        self.estimated_input_tokens += input_tokens
        self.estimated_output_tokens += output_tokens
        self.total_processing_time_seconds += processing_time

        if limit_warning:
            self.limit_warnings += 1

    def mark_limit_hit(self) -> None:
        """Flag the session as having hit its limit and stamp the time."""
        self.hit_session_limit = True
        self.limit_hit_at = datetime.now()

    @property
    def total_tokens(self) -> int:
        """Estimated input plus output tokens."""
        return sum((self.estimated_input_tokens, self.estimated_output_tokens))

    @property
    def success_rate(self) -> float:
        """Percentage of requests that succeeded (100.0 before any request)."""
        if not self.total_requests:
            return 100.0
        fraction = self.successful_requests / self.total_requests
        return fraction * 100

    @property
    def avg_processing_time(self) -> float:
        """Mean processing time per request in seconds (0.0 before any request)."""
        count = self.total_requests
        return self.total_processing_time_seconds / count if count else 0.0


@dataclass
class WeeklyUsage:
    """Aggregated usage for one agent over one week.

    Sessions are folded in with ``add_session``; only their ids are kept,
    next to running totals of their counters.
    """

    agent_id: str
    week_start: datetime

    # Ids of the sessions folded into this week, and how many were added
    sessions: List[str] = field(default_factory=list)
    total_sessions: int = 0

    # Running totals across every folded session
    total_requests: int = 0
    total_tokens: int = 0
    total_processing_time: float = 0.0

    # Limit-related incidents
    sessions_with_limit_hit: int = 0
    rate_limit_incidents: int = 0

    def add_session(self, session: SessionUsage) -> None:
        """Fold a finished session's counters into this week's totals."""
        self.sessions.append(session.session_id)
        self.total_sessions += 1

        self.total_requests += session.total_requests
        self.total_tokens += session.total_tokens
        self.total_processing_time += session.total_processing_time_seconds
        self.rate_limit_incidents += session.rate_limited_requests

        if session.hit_session_limit:
            self.sessions_with_limit_hit += 1


@dataclass
class CLIAgentLimits:
    """Estimated usage limits for one type of CLI agent.

    Counts are request counts, thresholds are fractions of the matching
    limit (0.8 means 80%).
    """

    # Estimated requests per session before a new session is needed
    session_request_limit: int = 50
    # Estimated requests per week
    weekly_request_limit: int = 500
    # Expected request pace, and how long to back off after a rate-limit signal
    requests_per_minute: int = 10
    cooldown_minutes: int = 5
    # Fraction of a limit at which usage counts as "approaching the limit"
    session_warning_threshold: float = 0.8
    weekly_warning_threshold: float = 0.8


# Default limits keyed by agent type. claude-code and codex-cli use the same
# numbers, but each key still gets its own CLIAgentLimits instance.
DEFAULT_CLI_LIMITS: Dict[str, CLIAgentLimits] = {
    agent_type: CLIAgentLimits(
        session_request_limit=session_cap,
        weekly_request_limit=weekly_cap,
        requests_per_minute=per_minute,
        cooldown_minutes=cooldown,
    )
    for agent_type, session_cap, weekly_cap, per_minute, cooldown in (
        ("claude-code", 50, 500, 10, 5),
        ("gemini-cli", 100, 1000, 20, 2),  # more generous
        ("codex-cli", 50, 500, 10, 5),
    )
}


class CLIUsageTracker:
    """
    Tracks usage for CLI-based agents.

    CLI agents have session-based limits rather than token-based billing.
    This tracker monitors usage patterns to predict when limits will be hit
    and enables proactive load balancing.

    Limits are looked up in ``self.limits`` by ``agent_id``. Ids without an
    entry fall back to a default ``CLIAgentLimits()``.

    A status manually set to ``AgentAvailability.PAUSED`` survives automatic
    status updates (recorded requests, rate-limit expiry). It is cleared only
    by another manual ``set_availability`` call or by starting a new session
    via ``start_session`` / ``reset_session``.
    """

    def __init__(
        self,
        db: Any = None,
        persistence_path: Optional[Path] = None,
        limits: Optional[Dict[str, CLIAgentLimits]] = None,
    ):
        """
        Initialize the CLI usage tracker.

        Args:
            db: Database for persistence (currently only stored; not read by
                this class)
            persistence_path: Path for JSON persistence fallback
                (defaults to ``data/cli_usage.json``)
            limits: Custom limits keyed by agent id. When omitted or empty,
                a private copy of ``DEFAULT_CLI_LIMITS`` is used.
        """
        self.db = db
        self.persistence_path = persistence_path or Path("data/cli_usage.json")
        # Copy each CLIAgentLimits object, not just the dict. A shallow dict
        # copy shares the module-level default instances, so tuning one
        # tracker's limits would silently change the global defaults.
        self.limits = limits or {
            agent_type: replace(agent_limits)
            for agent_type, agent_limits in DEFAULT_CLI_LIMITS.items()
        }

        # Current session per agent id
        self._sessions: Dict[str, SessionUsage] = {}

        # Weekly aggregates keyed by "<agent_id>:<week start ISO timestamp>"
        self._weekly: Dict[str, WeeklyUsage] = {}

        # Availability status per agent id
        self._availability: Dict[str, AgentAvailability] = {}

        # Rate limit tracking
        self._last_request_time: Dict[str, datetime] = {}
        self._rate_limit_until: Dict[str, datetime] = {}

        # Load persisted state
        self._load_state()

    def _load_state(self) -> None:
        """Load persisted state from disk (best effort).

        This only checks that the file parses as JSON. Sessions and weekly
        data are not deserialized yet, so no state is actually restored.
        Failures are logged, not raised.
        """
        if not self.persistence_path.exists():
            return
        try:
            with open(self.persistence_path, encoding="utf-8") as f:
                json.load(f)  # TODO: Deserialize sessions and weekly data
        except Exception as e:
            logger.warning(f"Failed to load CLI usage state: {e}")
        else:
            logger.info("Loaded CLI usage state")

    def _save_state(self) -> None:
        """Persist state to disk (best effort; failures are logged, not raised).

        Currently writes only a timestamp and the ids of agents that have a
        current session.
        """
        try:
            self.persistence_path.parent.mkdir(parents=True, exist_ok=True)
            # TODO: Serialize sessions and weekly data
            with open(self.persistence_path, 'w', encoding="utf-8") as f:
                json.dump({
                    "saved_at": datetime.now().isoformat(),
                    "agents": list(self._sessions.keys()),
                }, f)
        except Exception as e:
            logger.warning(f"Failed to save CLI usage state: {e}")

    def start_session(self, agent_id: str) -> SessionUsage:
        """Start a new session for an agent, archiving any current one.

        The agent's availability is reset to AVAILABLE.
        """
        # NOTE(review): the id has one-second resolution, so two sessions
        # started for the same agent within a second share an id.
        session_id = f"sess-{agent_id}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
        session = SessionUsage(session_id=session_id, agent_id=agent_id)

        # Archive old session if exists
        if agent_id in self._sessions:
            old_session = self._sessions[agent_id]
            self._archive_session(old_session)

        self._sessions[agent_id] = session
        self._availability[agent_id] = AgentAvailability.AVAILABLE

        logger.info(f"Started new session {session_id} for {agent_id}")
        return session

    def _archive_session(self, session: SessionUsage) -> None:
        """Fold a completed session into the current week's aggregate.

        The bucket is chosen by the time of archiving, not by the session's
        start time.
        """
        agent_id = session.agent_id

        # Get or create weekly tracker
        week_start = self._get_week_start()
        weekly_key = f"{agent_id}:{week_start.isoformat()}"

        if weekly_key not in self._weekly:
            self._weekly[weekly_key] = WeeklyUsage(
                agent_id=agent_id,
                week_start=week_start,
            )

        self._weekly[weekly_key].add_session(session)

    def _get_week_start(self) -> datetime:
        """Get midnight at the start of the current week (Monday).

        The time of day is truncated, so every call within the same week
        returns the same value. This keeps the weekly bucket keys stable.
        """
        now = datetime.now()
        monday = now - timedelta(days=now.weekday())
        return monday.replace(hour=0, minute=0, second=0, microsecond=0)

    @staticmethod
    def _session_fraction(session: SessionUsage, limits: CLIAgentLimits) -> float:
        """Return the fraction of the session request limit used so far.

        A value of 1.0 means the limit is reached; it can exceed 1.0. Raises
        ZeroDivisionError if ``limits.session_request_limit`` is 0.
        """
        return session.total_requests / limits.session_request_limit

    def record_request(
        self,
        agent_id: str,
        success: bool,
        input_tokens: int = 0,
        output_tokens: int = 0,
        processing_time: float = 0.0,
        response_text: str = "",
    ) -> AgentAvailability:
        """
        Record a request to a CLI agent.

        A session is started if the agent has none. The response text is
        scanned for rate-limit, warning and limit-hit phrases, and the status
        is updated accordingly. A manual PAUSED status is preserved.

        Args:
            agent_id: The agent that handled the request
            success: Whether the request succeeded
            input_tokens: Estimated input tokens
            output_tokens: Estimated output tokens
            processing_time: Time taken in seconds
            response_text: Response text (to detect limit messages)

        Returns:
            Current availability status
        """
        # A manual pause must survive automatic status changes. For example,
        # an in-flight request may complete after the agent was paused.
        was_paused = self._availability.get(agent_id) is AgentAvailability.PAUSED

        # Ensure session exists
        if agent_id not in self._sessions:
            self.start_session(agent_id)

        session = self._sessions[agent_id]

        # Detect rate limiting or limit hit from response
        rate_limited = self._detect_rate_limit(response_text)
        limit_warning = self._detect_limit_warning(response_text)
        limit_hit = self._detect_limit_hit(response_text)

        # Record the request
        session.record_request(
            success=success,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            processing_time=processing_time,
            rate_limited=rate_limited,
            limit_warning=limit_warning,
        )

        if limit_hit:
            session.mark_limit_hit()
            self._availability[agent_id] = AgentAvailability.EXHAUSTED
        elif rate_limited:
            self._handle_rate_limit(agent_id)
        else:
            # Update availability based on usage
            self._update_availability(agent_id)

        if was_paused:
            # Usage, limit-hit and cooldown data were still recorded above;
            # only the reported status stays paused.
            self._availability[agent_id] = AgentAvailability.PAUSED

        self._last_request_time[agent_id] = datetime.now()
        self._save_state()

        return self._availability.get(agent_id, AgentAvailability.AVAILABLE)

    def _detect_rate_limit(self, response: str) -> bool:
        """Detect if response indicates rate limiting (case-insensitive)."""
        rate_limit_phrases = [
            "rate limit",
            "too many requests",
            "slow down",
            "try again later",
            "quota exceeded",
        ]
        response_lower = response.lower()
        return any(phrase in response_lower for phrase in rate_limit_phrases)

    def _detect_limit_warning(self, response: str) -> bool:
        """Detect if response contains a usage warning (case-insensitive)."""
        warning_phrases = [
            "approaching limit",
            "usage warning",
            "running low",
            "almost reached",
        ]
        response_lower = response.lower()
        return any(phrase in response_lower for phrase in warning_phrases)

    def _detect_limit_hit(self, response: str) -> bool:
        """Detect if response indicates limit has been hit (case-insensitive).

        The phrase "rate limit" is removed before matching. Otherwise transient
        throttling messages such as "Rate limit exceeded" would match
        "limit exceeded" and mark the whole session as exhausted.
        """
        limit_phrases = [
            "limit reached",
            "limit exceeded",
            "no more requests",
            "session expired",
            "upgrade required",
            "out of credits",
        ]
        response_lower = response.lower().replace("rate limit", "")
        return any(phrase in response_lower for phrase in limit_phrases)

    def _handle_rate_limit(self, agent_id: str) -> None:
        """Start a cooldown of ``cooldown_minutes`` and mark the agent RATE_LIMITED."""
        limits = self.limits.get(agent_id, CLIAgentLimits())
        cooldown_until = datetime.now() + timedelta(minutes=limits.cooldown_minutes)

        self._rate_limit_until[agent_id] = cooldown_until
        self._availability[agent_id] = AgentAvailability.RATE_LIMITED

        logger.warning(
            f"Agent {agent_id} rate limited until {cooldown_until.isoformat()}"
        )

    def _update_availability(self, agent_id: str) -> None:
        """Recompute availability from rate-limit state and session usage.

        A manually PAUSED agent is left untouched. Otherwise the rules are:

        - no session: UNAVAILABLE
        - active cooldown: RATE_LIMITED
        - limit hit, or requests at or over the session limit: EXHAUSTED
        - at or over the warning threshold: LIMITED
        - otherwise: AVAILABLE
        """
        if self._availability.get(agent_id) is AgentAvailability.PAUSED:
            return

        session = self._sessions.get(agent_id)
        if not session:
            self._availability[agent_id] = AgentAvailability.UNAVAILABLE
            return

        limits = self.limits.get(agent_id, CLIAgentLimits())

        # Check if rate limit has expired
        if agent_id in self._rate_limit_until:
            if datetime.now() < self._rate_limit_until[agent_id]:
                self._availability[agent_id] = AgentAvailability.RATE_LIMITED
                return
            else:
                del self._rate_limit_until[agent_id]

        session_fraction = self._session_fraction(session, limits)

        if session.hit_session_limit or session_fraction >= 1.0:
            self._availability[agent_id] = AgentAvailability.EXHAUSTED
        elif session_fraction >= limits.session_warning_threshold:
            self._availability[agent_id] = AgentAvailability.LIMITED
        else:
            self._availability[agent_id] = AgentAvailability.AVAILABLE

    def get_availability(self, agent_id: str) -> AgentAvailability:
        """Get current availability for an agent (UNAVAILABLE if never seen).

        An expired rate-limit cooldown is cleared here and the status is
        recomputed.
        """
        # Check for rate limit expiry
        if agent_id in self._rate_limit_until:
            if datetime.now() >= self._rate_limit_until[agent_id]:
                del self._rate_limit_until[agent_id]
                self._update_availability(agent_id)

        return self._availability.get(agent_id, AgentAvailability.UNAVAILABLE)

    def set_availability(
        self,
        agent_id: str,
        status: AgentAvailability
    ) -> None:
        """Manually set availability (e.g., for pause/resume)."""
        self._availability[agent_id] = status

    def get_session_usage(self, agent_id: str) -> Optional[SessionUsage]:
        """Get current session usage for an agent, or None if it has no session."""
        return self._sessions.get(agent_id)

    def get_usage_stats(self, agent_id: str) -> Dict[str, Any]:
        """Get usage stats for an agent as a JSON-friendly dict.

        Without a current session, only ``agent_id``, ``availability`` and
        ``has_session`` are returned.
        """
        session = self._sessions.get(agent_id)
        limits = self.limits.get(agent_id, CLIAgentLimits())
        availability = self.get_availability(agent_id)

        if not session:
            return {
                "agent_id": agent_id,
                "availability": availability.value,
                "has_session": False,
            }

        session_pct = self._session_fraction(session, limits) * 100

        return {
            "agent_id": agent_id,
            "availability": availability.value,
            "has_session": True,
            "session_id": session.session_id,
            "session_started": session.started_at.isoformat(),
            "total_requests": session.total_requests,
            "successful_requests": session.successful_requests,
            "failed_requests": session.failed_requests,
            "rate_limited_requests": session.rate_limited_requests,
            "estimated_tokens": session.total_tokens,
            "session_percentage": session_pct,
            "session_limit": limits.session_request_limit,
            "success_rate": session.success_rate,
            "avg_processing_time": session.avg_processing_time,
            "limit_warnings": session.limit_warnings,
            "hit_limit": session.hit_session_limit,
        }

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Get usage stats for every agent with a session or a known status."""
        all_agents = set(self._sessions.keys()) | set(self._availability.keys())
        return {agent_id: self.get_usage_stats(agent_id) for agent_id in all_agents}

    def get_best_available_agent(
        self,
        agent_ids: List[str],
        prefer_low_usage: bool = True,
    ) -> Optional[str]:
        """
        Get the best available agent from a list.

        Only AVAILABLE and LIMITED agents are candidates; AVAILABLE ranks
        first. Ties keep the input order.

        Args:
            agent_ids: List of agent IDs to consider
            prefer_low_usage: If True, break ties by lower session usage

        Returns:
            Best agent ID or None if none available
        """
        available = []

        for agent_id in agent_ids:
            status = self.get_availability(agent_id)
            if status in (AgentAvailability.AVAILABLE, AgentAvailability.LIMITED):
                session = self._sessions.get(agent_id)
                usage_fraction = 0
                if session:
                    limits = self.limits.get(agent_id, CLIAgentLimits())
                    usage_fraction = self._session_fraction(session, limits)
                available.append((agent_id, status, usage_fraction))

        if not available:
            return None

        # Sort by preference
        if prefer_low_usage:
            # Prefer AVAILABLE over LIMITED, then by usage fraction
            available.sort(key=lambda x: (
                0 if x[1] == AgentAvailability.AVAILABLE else 1,
                x[2]
            ))
        else:
            # Just prefer available
            available.sort(key=lambda x: 0 if x[1] == AgentAvailability.AVAILABLE else 1)

        return available[0][0]

    def reset_session(self, agent_id: str) -> SessionUsage:
        """Reset an agent's session (e.g., after re-authentication)."""
        logger.info(f"Resetting session for {agent_id}")
        return self.start_session(agent_id)


# Process-wide tracker shared by get_cli_tracker() and set_cli_tracker().
_cli_tracker: Optional[CLIUsageTracker] = None


def get_cli_tracker() -> CLIUsageTracker:
    """Return the global CLI usage tracker, creating it on first use.

    A lazily created tracker is built with CLIUsageTracker's default arguments.
    """
    global _cli_tracker
    if _cli_tracker is not None:
        return _cli_tracker
    _cli_tracker = CLIUsageTracker()
    return _cli_tracker


def set_cli_tracker(tracker: CLIUsageTracker) -> None:
    """Replace the global CLI usage tracker (e.g. to install a configured one)."""
    global _cli_tracker
    _cli_tracker = tracker
