"""
MCP Budget - Cost management for MCP servers and tools.

Implements:
- Per-MCP-server budget limits
- Per-tool budget limits
- Usage tracking for MCP tool calls
- Budget enforcement before tool execution
"""

import copy
import logging
from dataclasses import dataclass, field
from datetime import datetime, date
from enum import Enum
from typing import Optional, Dict, List, Any, Set

logger = logging.getLogger(__name__)


class MCPBudgetStatus(Enum):
    """
    Outcome category of an MCP budget check (``MCPBudgetCheckResult.status``).

    In MCPUsageTracker.check_budget, BLOCKED and EXCEEDED accompany
    ``allowed=False``; WITHIN_BUDGET and WARNING accompany ``allowed=True``.
    """
    WITHIN_BUDGET = "within_budget"
    WARNING = "warning"  # Approaching limit: highest usage ratio >= the server's warning_threshold
    EXCEEDED = "exceeded"  # A call-count or cost limit has been reached
    BLOCKED = "blocked"  # Server/tool is blocked


@dataclass
class MCPServerBudget:
    """
    Budget configuration for an MCP server.

    Field order defines the generated ``__init__`` positional order, so new
    fields should be appended with defaults rather than inserted.
    """

    server_name: str

    # Call limits. MCPUsageTracker.check_budget rejects a call once usage is
    # >= the limit, so a value of 0 rejects every call.
    daily_call_limit: int = 1000  # Max calls per calendar day
    hourly_call_limit: int = 100  # Max calls within the current clock hour

    # Cost limits (USD) - for paid APIs
    daily_cost_limit: float = 10.0  # 0 disables the daily cost check
    # NOTE(review): monthly_cost_limit is not read anywhere in this module.
    monthly_cost_limit: float = 200.0

    # Rate limiting
    # Min ms between calls. Surfaced as MCPBudgetCheckResult.wait_ms; it does
    # not by itself make check_budget reject a call.
    min_call_interval_ms: int = 100

    # Warning thresholds
    # Fraction of a limit (e.g. 0.8 = 80%) at which check_budget reports
    # WARNING; record_call also logs a warning when the daily call ratio
    # reaches it.
    warning_threshold: float = 0.8

    # State (toggled by MCPUsageTracker.block_server / unblock_server)
    is_blocked: bool = False
    block_reason: Optional[str] = None

    # Metadata (naive local timestamps from datetime.now)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)


@dataclass
class ToolBudget:
    """
    Budget configuration for a specific tool within an MCP server.

    MCPUsageTracker stores these keyed by ``"<server_name>:<tool_name>"``.
    Field order defines the generated ``__init__`` positional order.
    """

    server_name: str
    tool_name: str

    # Call limits
    daily_call_limit: int = 500  # checked against the tool's call count for the day
    # NOTE(review): hourly_call_limit is not enforced anywhere in this module.
    hourly_call_limit: int = 50

    # Cost per call (for estimation)
    # NOTE(review): not read in this module; record_call takes the cost explicitly.
    cost_per_call: float = 0.0  # Free by default

    # Risk settings
    requires_approval: bool = False  # consulted by MCPUsageTracker.requires_approval
    risk_level: str = "low"  # low, medium, high (not read in this module)

    # State (toggled by MCPUsageTracker.block_tool / unblock_tool)
    is_blocked: bool = False
    block_reason: Optional[str] = None

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)


@dataclass
class MCPUsageRecord:
    """
    Usage record for one MCP server on one calendar day.

    In this module, only the call counters, total_cost, tool_calls and
    tool_costs are passed to the database; hourly_calls, first_call_at and
    last_call_at are kept in memory only.
    """

    server_name: str
    date: date  # day covered; the field name coincides with the imported ``date`` type

    # Call counts
    total_calls: int = 0  # incremented on every record_call, success or failure
    successful_calls: int = 0
    failed_calls: int = 0
    rejected_calls: int = 0  # incremented by record_rejection; not part of total_calls

    # Cost tracking
    total_cost: float = 0.0  # sum of positive costs passed to record_call

    # Per-tool breakdown (keyed by tool name)
    tool_calls: Dict[str, int] = field(default_factory=dict)
    tool_costs: Dict[str, float] = field(default_factory=dict)

    # Hourly tracking (for the hourly call limit)
    hourly_calls: Dict[int, int] = field(default_factory=dict)  # hour of day (0-23) -> count

    # Timestamps
    first_call_at: Optional[datetime] = None
    last_call_at: Optional[datetime] = None


@dataclass
class MCPBudgetCheckResult:
    """
    Result of an MCP budget check, as returned by MCPUsageTracker.check_budget.

    Fields a particular result does not set keep their zero/None defaults.
    """

    status: MCPBudgetStatus
    allowed: bool  # False for BLOCKED / EXCEEDED results

    # Server info
    server_name: str
    tool_name: Optional[str] = None

    # Usage info (today's totals and the current clock hour's call count)
    daily_calls: int = 0
    hourly_calls: int = 0
    daily_cost: float = 0.0

    # Limits
    daily_call_limit: int = 0
    hourly_call_limit: int = 0
    daily_cost_limit: float = 0.0

    # Percentages: usage / limit as a fraction (e.g. 0.8), 0 when the limit is not positive
    daily_call_percentage: float = 0.0
    hourly_call_percentage: float = 0.0
    cost_percentage: float = 0.0

    # Reason if not allowed
    reason: Optional[str] = None

    # Rate limit info
    wait_ms: int = 0  # Ms to wait before next call; advisory, does not affect `allowed`


# Default MCP server budgets, one entry per well-known server name, in the
# order below. Fields not listed in the table (monthly cost, min call
# interval, warning threshold, block state) keep MCPServerBudget's defaults.
# The MCPServerBudget objects are mutable; treat this mapping as a template.
DEFAULT_MCP_BUDGETS: Dict[str, MCPServerBudget] = {
    name: MCPServerBudget(
        server_name=name,
        daily_call_limit=calls_per_day,
        hourly_call_limit=calls_per_hour,
        daily_cost_limit=usd_per_day,
    )
    for name, calls_per_day, calls_per_hour, usd_per_day in (
        # (server, calls/day, calls/hour, USD/day)
        ("filesystem", 10000, 1000, 0.0),  # File ops are frequent; free
        ("github", 5000, 500, 0.0),  # GitHub API limits
        ("web-search", 100, 20, 5.0),  # Search APIs are expensive
        ("database", 1000, 100, 0.0),
        ("browser", 500, 50, 1.0),
    )
}

# Tools that always require approval, keyed by MCP server name
# (consulted by MCPUsageTracker.requires_approval). Values are plain,
# mutable sets of tool names.
HIGH_RISK_TOOLS: Dict[str, Set[str]] = dict(
    filesystem={"delete_file", "delete_directory", "write_file"},
    github={"create_pull_request", "merge_pull_request", "delete_branch"},
    database={"execute_query", "drop_table", "truncate_table"},
    shell={"execute_command"},
)


class MCPUsageTracker:
    """
    Track and enforce MCP server and tool budgets.

    Usage is aggregated per server per calendar day in MCPUsageRecord
    objects cached in memory. When ``db`` provides ``get_mcp_usage`` /
    ``upsert_mcp_usage``, records are loaded from and written back to it;
    otherwise tracking is purely in memory.

    ``check_budget`` enforces, in order: server block flag, tool block flag,
    server daily calls, server calls in the current clock hour, server daily
    cost, and (when a ToolBudget is configured) tool daily calls.

    NOTE(review): ``MCPServerBudget.monthly_cost_limit``,
    ``ToolBudget.hourly_call_limit`` and ``ToolBudget.cost_per_call`` are not
    used here. Hourly counts are not persisted, so the hourly limit starts
    from zero again after a process restart.
    """

    def __init__(
        self,
        db: Any,
        server_budgets: Optional[Dict[str, MCPServerBudget]] = None,
        tool_budgets: Optional[Dict[str, ToolBudget]] = None,
    ):
        """
        Initialize MCP usage tracker.

        Args:
            db: Database instance for persistence. The optional methods
                ``get_mcp_usage`` / ``upsert_mcp_usage`` are used if present.
            server_budgets: Custom server budget configurations. If None or
                empty, a private copy of DEFAULT_MCP_BUDGETS is used.
            tool_budgets: Custom tool budget configurations, keyed
                ``"<server>:<tool>"``.
        """
        self.db = db
        # Deep copy: MCPServerBudget objects are mutable (block_server flips
        # is_blocked on them). A shallow dict copy would share those objects
        # with DEFAULT_MCP_BUDGETS, so blocking a server in one tracker would
        # also block it in the module defaults and in every other tracker.
        self._server_budgets: Dict[str, MCPServerBudget] = (
            server_budgets or copy.deepcopy(DEFAULT_MCP_BUDGETS)
        )
        self._tool_budgets: Dict[str, ToolBudget] = tool_budgets or {}
        # Today's usage records, keyed "<server>:<YYYY-MM-DD>".
        self._usage: Dict[str, MCPUsageRecord] = {}
        # Time of the most recent recorded call per server (min-interval checks).
        self._last_call_times: Dict[str, datetime] = {}

    @staticmethod
    def _tool_key(server_name: str, tool_name: str) -> str:
        """Build the ``"<server>:<tool>"`` key used in ``_tool_budgets``."""
        return f"{server_name}:{tool_name}"

    def get_server_budget(self, server_name: str) -> MCPServerBudget:
        """Get budget for an MCP server, creating (and storing) a default if needed."""
        if server_name not in self._server_budgets:
            self._server_budgets[server_name] = MCPServerBudget(server_name=server_name)
        return self._server_budgets[server_name]

    def set_server_budget(self, server_name: str, budget: MCPServerBudget) -> None:
        """Set or update budget for an MCP server and stamp its ``updated_at``."""
        self._server_budgets[server_name] = budget
        budget.updated_at = datetime.now()

    def get_tool_budget(self, server_name: str, tool_name: str) -> Optional[ToolBudget]:
        """Get budget for a specific tool, or None if none is configured."""
        return self._tool_budgets.get(self._tool_key(server_name, tool_name))

    def set_tool_budget(self, budget: ToolBudget) -> None:
        """Set or update budget for a specific tool."""
        self._tool_budgets[self._tool_key(budget.server_name, budget.tool_name)] = budget

    def _get_usage(self, server_name: str) -> MCPUsageRecord:
        """Get or create the usage record for ``server_name`` for today."""
        today = date.today()
        key = f"{server_name}:{today.isoformat()}"

        usage = self._usage.get(key)
        if usage is None:
            self._prune_stale_usage(today)
            # Try the database first; fall back to a fresh, empty record.
            usage = self._load_usage_from_db(server_name, today) or MCPUsageRecord(
                server_name=server_name,
                date=today,
            )
            self._usage[key] = usage

        return usage

    def _prune_stale_usage(self, today: date) -> None:
        """Drop cached records for days other than ``today``.

        Records are only ever looked up by today's key and are written to the
        database on every update, so older entries are dead weight that would
        otherwise accumulate one per server per day in a long-running process.
        """
        stale_keys = [key for key, record in self._usage.items() if record.date != today]
        for key in stale_keys:
            del self._usage[key]

    def _load_usage_from_db(self, server_name: str, day: date) -> Optional[MCPUsageRecord]:
        """Load a usage record from the database, or None if unavailable.

        Assumes ``get_mcp_usage`` returns a mapping (or a falsy value when no
        row exists) -- TODO confirm against the DB layer.
        """
        if not hasattr(self.db, 'get_mcp_usage'):
            return None

        record = self.db.get_mcp_usage(server_name, day.isoformat())
        if not record:
            return None

        # ``or`` guards against NULL columns coming back as None; the dicts
        # are copied so in-place updates never mutate the DB layer's objects.
        return MCPUsageRecord(
            server_name=server_name,
            date=day,
            total_calls=record.get('total_calls') or 0,
            successful_calls=record.get('successful_calls') or 0,
            failed_calls=record.get('failed_calls') or 0,
            rejected_calls=record.get('rejected_calls') or 0,
            total_cost=record.get('total_cost') or 0.0,
            tool_calls=dict(record.get('tool_calls') or {}),
            tool_costs=dict(record.get('tool_costs') or {}),
        )

    def _save_usage_to_db(self, usage: MCPUsageRecord) -> None:
        """Save a usage record to the database, if it supports it.

        hourly_calls and the first/last call timestamps are not persisted.
        """
        if not hasattr(self.db, 'upsert_mcp_usage'):
            return

        self.db.upsert_mcp_usage(
            server_name=usage.server_name,
            date=usage.date.isoformat(),
            total_calls=usage.total_calls,
            successful_calls=usage.successful_calls,
            failed_calls=usage.failed_calls,
            rejected_calls=usage.rejected_calls,
            total_cost=usage.total_cost,
            tool_calls=usage.tool_calls,
            tool_costs=usage.tool_costs,
        )

    def _get_current_hour_calls(self, usage: MCPUsageRecord) -> int:
        """Get the call count for the current clock hour (not a rolling 60 minutes)."""
        return usage.hourly_calls.get(datetime.now().hour, 0)

    def _check_rate_limit(self, server_name: str) -> int:
        """Return ms still to wait to honour ``min_call_interval_ms`` (0 if none)."""
        budget = self.get_server_budget(server_name)
        last_call = self._last_call_times.get(server_name)

        if last_call is None:
            return 0

        elapsed_ms = int((datetime.now() - last_call).total_seconds() * 1000)
        if elapsed_ms < budget.min_call_interval_ms:
            return budget.min_call_interval_ms - elapsed_ms

        return 0

    def check_budget(
        self,
        server_name: str,
        tool_name: Optional[str] = None,
    ) -> MCPBudgetCheckResult:
        """
        Check if one more MCP call is within budget.

        The first failing check wins. A limit counts as reached when usage is
        ``>=`` the limit, so a call limit of 0 rejects every call. A
        ``daily_cost_limit`` of 0 disables the cost check.

        Args:
            server_name: MCP server name
            tool_name: Optional tool name for tool-specific checks

        Returns:
            MCPBudgetCheckResult with status and details. Every non-BLOCKED
            result carries the full usage/limit snapshot. On an allowed
            result, ``wait_ms > 0`` means the caller should wait that long
            to honour ``min_call_interval_ms``; it does not reject the call.
        """
        server_budget = self.get_server_budget(server_name)
        usage = self._get_usage(server_name)

        # Check if server is blocked
        if server_budget.is_blocked:
            return MCPBudgetCheckResult(
                status=MCPBudgetStatus.BLOCKED,
                allowed=False,
                server_name=server_name,
                tool_name=tool_name,
                reason=server_budget.block_reason or "MCP server is blocked",
            )

        # Falsy tool names (None or "") skip all tool-level checks.
        tool_budget = self.get_tool_budget(server_name, tool_name) if tool_name else None
        if tool_budget is not None and tool_budget.is_blocked:
            return MCPBudgetCheckResult(
                status=MCPBudgetStatus.BLOCKED,
                allowed=False,
                server_name=server_name,
                tool_name=tool_name,
                reason=tool_budget.block_reason or f"Tool {tool_name} is blocked",
            )

        wait_ms = self._check_rate_limit(server_name)

        # Current counts
        daily_calls = usage.total_calls
        hourly_calls = self._get_current_hour_calls(usage)
        daily_cost = usage.total_cost

        daily_call_limit = server_budget.daily_call_limit
        hourly_call_limit = server_budget.hourly_call_limit
        daily_cost_limit = server_budget.daily_cost_limit

        # Usage ratios; a non-positive limit reports 0 instead of dividing by zero.
        daily_call_pct = daily_calls / daily_call_limit if daily_call_limit > 0 else 0
        hourly_call_pct = hourly_calls / hourly_call_limit if hourly_call_limit > 0 else 0
        cost_pct = daily_cost / daily_cost_limit if daily_cost_limit > 0 else 0

        # Snapshot attached to every result below, so a rejection reports all
        # current figures rather than only those of the limit that tripped.
        snapshot: Dict[str, Any] = dict(
            daily_calls=daily_calls,
            hourly_calls=hourly_calls,
            daily_cost=daily_cost,
            daily_call_limit=daily_call_limit,
            hourly_call_limit=hourly_call_limit,
            daily_cost_limit=daily_cost_limit,
            daily_call_percentage=daily_call_pct,
            hourly_call_percentage=hourly_call_pct,
            cost_percentage=cost_pct,
        )

        def exceeded(reason: str) -> MCPBudgetCheckResult:
            """Build a rejected EXCEEDED result carrying the usage snapshot."""
            return MCPBudgetCheckResult(
                status=MCPBudgetStatus.EXCEEDED,
                allowed=False,
                server_name=server_name,
                tool_name=tool_name,
                reason=reason,
                **snapshot,
            )

        if daily_calls >= daily_call_limit:
            return exceeded(f"Daily call limit exceeded ({daily_calls}/{daily_call_limit})")

        if hourly_calls >= hourly_call_limit:
            return exceeded(f"Hourly call limit exceeded ({hourly_calls}/{hourly_call_limit})")

        if daily_cost_limit > 0 and daily_cost >= daily_cost_limit:
            return exceeded(
                f"Daily cost limit exceeded (${daily_cost:.2f}/${daily_cost_limit:.2f})"
            )

        if tool_budget is not None:
            tool_calls = usage.tool_calls.get(tool_name, 0)
            if tool_calls >= tool_budget.daily_call_limit:
                return exceeded(
                    f"Tool daily limit exceeded ({tool_calls}/{tool_budget.daily_call_limit})"
                )

        # Warn when any ratio reaches the server's threshold.
        max_pct = max(daily_call_pct, hourly_call_pct, cost_pct)
        status = (
            MCPBudgetStatus.WARNING
            if max_pct >= server_budget.warning_threshold
            else MCPBudgetStatus.WITHIN_BUDGET
        )

        return MCPBudgetCheckResult(
            status=status,
            allowed=True,
            server_name=server_name,
            tool_name=tool_name,
            wait_ms=wait_ms,
            **snapshot,
        )

    def requires_approval(self, server_name: str, tool_name: str) -> bool:
        """Return True if the tool is listed in HIGH_RISK_TOOLS or its ToolBudget requires approval."""
        if tool_name in HIGH_RISK_TOOLS.get(server_name, ()):
            return True

        tool_budget = self.get_tool_budget(server_name, tool_name)
        return bool(tool_budget and tool_budget.requires_approval)

    def record_call(
        self,
        server_name: str,
        tool_name: str,
        cost: float = 0.0,
        success: bool = True,
    ) -> None:
        """
        Record an MCP tool call and persist the updated usage.

        Args:
            server_name: MCP server name
            tool_name: Tool that was called
            cost: Cost of the call; only positive values are accumulated
            success: Whether the call succeeded
        """
        usage = self._get_usage(server_name)
        now = datetime.now()

        # Update call counts
        usage.total_calls += 1
        if success:
            usage.successful_calls += 1
        else:
            usage.failed_calls += 1

        # Update tool-specific counts. Using .get/setdefault tolerates records
        # (e.g. loaded from the DB) whose tool_costs lacks a key in tool_calls.
        usage.tool_calls[tool_name] = usage.tool_calls.get(tool_name, 0) + 1
        usage.tool_costs.setdefault(tool_name, 0.0)

        if cost > 0:
            usage.total_cost += cost
            usage.tool_costs[tool_name] += cost

        # Hourly tracking, bucketed by clock hour of the day.
        usage.hourly_calls[now.hour] = usage.hourly_calls.get(now.hour, 0) + 1

        # Update timestamps
        if usage.first_call_at is None:
            usage.first_call_at = now
        usage.last_call_at = now

        # Update last call time for min-interval rate limiting
        self._last_call_times[server_name] = now

        self._save_usage_to_db(usage)

        # Log a warning once the daily call ratio reaches the threshold.
        budget = self.get_server_budget(server_name)
        daily_pct = usage.total_calls / budget.daily_call_limit if budget.daily_call_limit > 0 else 0

        if daily_pct >= budget.warning_threshold:
            logger.warning(
                f"MCP server {server_name} approaching daily call limit: "
                f"{usage.total_calls}/{budget.daily_call_limit} ({daily_pct:.1%})"
            )

    def record_rejection(self, server_name: str, tool_name: str) -> None:
        """Record a rejected MCP call (server-level counter; ``tool_name`` is not used)."""
        usage = self._get_usage(server_name)
        usage.rejected_calls += 1
        self._save_usage_to_db(usage)

    def block_server(self, server_name: str, reason: str) -> None:
        """Block an MCP server so check_budget returns BLOCKED for it."""
        budget = self.get_server_budget(server_name)
        budget.is_blocked = True
        budget.block_reason = reason
        budget.updated_at = datetime.now()
        logger.info(f"Blocked MCP server {server_name}: {reason}")

    def unblock_server(self, server_name: str) -> None:
        """Unblock an MCP server."""
        budget = self.get_server_budget(server_name)
        budget.is_blocked = False
        budget.block_reason = None
        budget.updated_at = datetime.now()
        logger.info(f"Unblocked MCP server {server_name}")

    def block_tool(self, server_name: str, tool_name: str, reason: str) -> None:
        """Block a specific tool, creating a default ToolBudget for it if needed."""
        tool_key = self._tool_key(server_name, tool_name)
        tool_budget = self._tool_budgets.get(tool_key)
        if tool_budget is None:
            tool_budget = ToolBudget(server_name=server_name, tool_name=tool_name)
            self._tool_budgets[tool_key] = tool_budget

        tool_budget.is_blocked = True
        tool_budget.block_reason = reason
        logger.info(f"Blocked tool {server_name}:{tool_name}: {reason}")

    def unblock_tool(self, server_name: str, tool_name: str) -> None:
        """Unblock a specific tool (no-op if it has no ToolBudget)."""
        tool_budget = self._tool_budgets.get(self._tool_key(server_name, tool_name))
        if tool_budget is not None:
            tool_budget.is_blocked = False
            tool_budget.block_reason = None
            logger.info(f"Unblocked tool {server_name}:{tool_name}")

    def get_server_summary(self, server_name: str) -> Dict[str, Any]:
        """Get today's usage summary for an MCP server."""
        budget = self.get_server_budget(server_name)
        usage = self._get_usage(server_name)

        return {
            "server_name": server_name,
            "date": usage.date.isoformat(),
            "budget": {
                "daily_call_limit": budget.daily_call_limit,
                "hourly_call_limit": budget.hourly_call_limit,
                "daily_cost_limit": budget.daily_cost_limit,
            },
            "usage": {
                "total_calls": usage.total_calls,
                "successful_calls": usage.successful_calls,
                "failed_calls": usage.failed_calls,
                "rejected_calls": usage.rejected_calls,
                "total_cost": usage.total_cost,
            },
            # Copy so callers cannot mutate the live usage counters.
            "tool_breakdown": dict(usage.tool_calls),
            "status": "blocked" if budget.is_blocked else "active",
        }

    def get_all_servers_summary(self) -> List[Dict[str, Any]]:
        """Get usage summaries for every server that has a budget entry."""
        return [self.get_server_summary(name) for name in list(self._server_budgets)]


# Singleton instance
# Module-level slot for the tracker used by get_mcp_tracker()/check_mcp_budget();
# None means no tracker is registered and check_mcp_budget() allows every call.
_mcp_tracker: Optional[MCPUsageTracker] = None


def get_mcp_tracker() -> Optional[MCPUsageTracker]:
    """
    Look up the globally registered MCP tracker.

    Returns:
        The tracker last passed to ``set_mcp_tracker``, or None if none has
        been registered.
    """
    registered = _mcp_tracker
    return registered


def set_mcp_tracker(tracker: Optional[MCPUsageTracker]) -> None:
    """
    Register the global MCP tracker instance.

    Args:
        tracker: Tracker used by module-level helpers such as
            ``check_mcp_budget``. Pass ``None`` to clear the registration,
            which turns enforcement off again (e.g. between tests).
    """
    global _mcp_tracker
    _mcp_tracker = tracker


def check_mcp_budget(server_name: str, tool_name: Optional[str] = None) -> bool:
    """
    Return whether an MCP call may proceed under the global tracker.

    Args:
        server_name: MCP server name.
        tool_name: Optional tool name for tool-specific checks.

    Returns:
        ``check_budget(...).allowed`` from the registered tracker, or True
        when no tracker has been registered.
    """
    active = get_mcp_tracker()
    if active is not None:
        return active.check_budget(server_name, tool_name).allowed
    return True  # No enforcement configured
