"""
Cost Optimizer - Analyze usage and recommend cost optimizations.

Provides:
- Usage pattern analysis
- Cost inefficiency detection
- Optimization recommendations
- Cost projections
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class InefficiencyType(Enum):
    """Types of cost inefficiencies.

    Each member's value is the string identifier written under the ``"type"``
    key when a finding is serialized. NOTE(review): CostOptimizer in this
    module currently only emits HIGH_RETRY_RATE, OVERPROVISIONED_MODEL and
    EXCESSIVE_TOKENS; the other members are not produced here yet.
    """

    HIGH_RETRY_RATE = "high_retry_rate"  # an agent retries too often
    OVERPROVISIONED_MODEL = "overprovisioned_model"  # expensive model on a low-complexity task
    EXCESSIVE_TOKENS = "excessive_tokens"  # a record uses far more tokens than average
    PEAK_HOUR_USAGE = "peak_hour_usage"
    REDUNDANT_OPERATIONS = "redundant_operations"
    UNUSED_CAPACITY = "unused_capacity"
    RATE_LIMIT_WASTE = "rate_limit_waste"


@dataclass
class UsageRecord:
    """A single usage record.

    Holds token counts, cost and metadata for one unit of usage
    (presumably one model call — confirm with the code that produces these).
    """

    record_id: str  # identifier of this record; not read by the analysis in this module
    agent_id: str  # costs and retry rates are grouped per agent
    task_id: str  # records sharing a task_id count as one task in per-task averages
    model: str  # model name; matched case-insensitively against CostOptimizer.MODEL_PRICING
    input_tokens: int
    output_tokens: int
    cost_usd: float  # cost of this record in US dollars
    timestamp: datetime  # compared with naive datetime.now() when selecting analysis windows
    latency_ms: float = 0.0  # not read by the analysis in this module
    is_retry: bool = False  # counts toward the agent's retry rate
    task_complexity: str = "medium"  # low, medium, high
    success: bool = True  # not read by the analysis in this module


@dataclass
class Inefficiency:
    """
    One cost problem detected during analysis, with advice for fixing it.
    """

    inefficiency_type: InefficiencyType
    severity: str  # "low", "medium" or "high"
    metric_value: float  # measurement that triggered the finding; meaning depends on the type
    recommendation: str  # human-readable advice
    potential_savings_usd: float = 0.0  # estimated savings if the advice is followed
    agent_id: Optional[str] = None  # populated for agent-level findings
    task_id: Optional[str] = None  # populated for task-level findings
    details: dict[str, Any] = field(default_factory=dict)  # extra type-specific data

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict, exposing the enum as its string value under "type"."""
        payload: dict[str, Any] = {"type": self.inefficiency_type.value}
        passthrough = (
            "severity",
            "metric_value",
            "recommendation",
            "potential_savings_usd",
            "agent_id",
            "task_id",
            "details",
        )
        for attr in passthrough:
            payload[attr] = getattr(self, attr)
        return payload


@dataclass
class CostAnalysis:
    """
    Aggregated outcome of analysing usage over a time window.

    Cost breakdowns are keyed by agent id, model name, task complexity and
    hour of day respectively.
    """

    total_cost_usd: float
    period_days: int
    by_agent: dict[str, float]
    by_model: dict[str, float]
    by_task_type: dict[str, float]
    by_hour: dict[int, float]
    peak_hours: list[int]
    inefficiencies: list[Inefficiency]
    average_cost_per_task: float
    total_tokens: int
    total_tasks: int

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view; nested inefficiencies use their own to_dict()."""
        field_names = (
            "total_cost_usd",
            "period_days",
            "by_agent",
            "by_model",
            "by_task_type",
            "by_hour",
            "peak_hours",
            "inefficiencies",
            "average_cost_per_task",
            "total_tokens",
            "total_tasks",
        )
        as_dict = {name: getattr(self, name) for name in field_names}
        # Overwriting an existing key keeps its position, so key order is unchanged.
        as_dict["inefficiencies"] = [entry.to_dict() for entry in self.inefficiencies]
        return as_dict


@dataclass
class CostProjection:
    """Forecast of future spending derived from recent usage."""

    projected_daily_usd: float
    projected_weekly_usd: float
    projected_monthly_usd: float
    trend: str  # "increasing", "stable" or "decreasing"
    confidence: float  # in the range 0..1
    factors: list[str]  # short notes explaining the forecast

    def to_dict(self) -> dict[str, Any]:
        """Return the projection as a plain dict keyed by field name."""
        ordered_keys = (
            "projected_daily_usd",
            "projected_weekly_usd",
            "projected_monthly_usd",
            "trend",
            "confidence",
            "factors",
        )
        return {key: getattr(self, key) for key in ordered_keys}


class CostOptimizer:
    """
    Analyze usage and recommend cost optimizations.

    Examines usage patterns to identify inefficiencies and provide
    actionable recommendations for reducing costs.

    Records are held in memory in insertion order. Every analysis selects
    records by comparing their timestamps with naive ``datetime.now()``, so
    record timestamps are expected to be naive local datetimes too (mixing
    naive and timezone-aware datetimes raises ``TypeError``).

    Example:
        optimizer = CostOptimizer()

        # Add usage data
        optimizer.add_record(UsageRecord(...))

        # Analyze
        analysis = optimizer.analyze_usage(days=7)
        for inefficiency in analysis.inefficiencies:
            print(f"{inefficiency.recommendation}")
            print(f"Potential savings: ${inefficiency.potential_savings_usd:.2f}")
    """

    # Model list pricing in USD per 1K tokens (approximate). Keys are lowercase;
    # lookups lowercase the record's model name first.
    MODEL_PRICING = {
        "opus": {"input": 0.015, "output": 0.075},
        "sonnet": {"input": 0.003, "output": 0.015},
        "haiku": {"input": 0.00025, "output": 0.00125},
        "gpt-4o": {"input": 0.005, "output": 0.015},
        "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
        "o1": {"input": 0.015, "output": 0.060},
        "gemini-pro": {"input": 0.00125, "output": 0.005},
        "gemini-flash": {"input": 0.000075, "output": 0.0003},
    }

    # Thresholds for inefficiency detection
    RETRY_RATE_THRESHOLD = 0.15  # 15% retry rate is concerning
    EXCESSIVE_TOKEN_MULTIPLIER = 2.0  # 2x average is excessive
    PEAK_HOUR_PREMIUM = 1.3  # an hour is "peak" when its cost exceeds 1.3x the hourly mean

    def __init__(self):
        """Initialize the cost optimizer with an empty record list."""
        self._records: list[UsageRecord] = []

    def add_record(self, record: UsageRecord) -> None:
        """Add a usage record."""
        self._records.append(record)

    def add_records(self, records: list[UsageRecord]) -> None:
        """Add multiple usage records, preserving their order."""
        self._records.extend(records)

    def clear_records(self) -> None:
        """Clear all records."""
        self._records.clear()

    def analyze_usage(self, days: int = 7) -> CostAnalysis:
        """
        Analyze recent usage patterns.

        Only records whose timestamp is within the last ``days`` days of
        ``datetime.now()`` are considered.

        Args:
            days: Number of days to analyze

        Returns:
            CostAnalysis with cost broken down by agent, model, task
            complexity and hour of day, plus peak hours and detected
            inefficiencies. An all-zero analysis is returned when no records
            fall inside the window.
        """
        since = datetime.now() - timedelta(days=days)
        records = [r for r in self._records if r.timestamp >= since]

        if not records:
            return CostAnalysis(
                total_cost_usd=0,
                period_days=days,
                by_agent={},
                by_model={},
                by_task_type={},
                by_hour={},
                peak_hours=[],
                inefficiencies=[],
                average_cost_per_task=0,
                total_tokens=0,
                total_tasks=0,
            )

        # Aggregate by dimensions
        by_agent = self._group_cost_by(records, lambda r: r.agent_id)
        by_model = self._group_cost_by(records, lambda r: r.model)
        by_task_type = self._group_cost_by(records, lambda r: r.task_complexity)
        by_hour = self._group_cost_by(records, lambda r: r.timestamp.hour)

        # Peak hours: compare each hour against the mean over all 24 hours of
        # the day (hours without usage contribute zero). by_hour is non-empty
        # here because records is non-empty.
        avg_hourly = sum(by_hour.values()) / 24
        peak_hours = [h for h, c in by_hour.items() if c > avg_hourly * self.PEAK_HOUR_PREMIUM]

        # Detect inefficiencies
        inefficiencies = self._find_inefficiencies(records, by_agent, by_model)

        total_cost = sum(r.cost_usd for r in records)
        total_tokens = sum(r.input_tokens + r.output_tokens for r in records)
        unique_tasks = len({r.task_id for r in records})

        return CostAnalysis(
            total_cost_usd=total_cost,
            period_days=days,
            by_agent=by_agent,
            by_model=by_model,
            by_task_type=by_task_type,
            by_hour=by_hour,  # keys are already ints (datetime.hour)
            peak_hours=peak_hours,
            inefficiencies=inefficiencies,
            average_cost_per_task=total_cost / unique_tasks if unique_tasks > 0 else 0,
            total_tokens=total_tokens,
            total_tasks=unique_tasks,
        )

    def _group_cost_by(self, records: list[UsageRecord], key_fn) -> dict[Any, float]:
        """
        Sum ``cost_usd`` per group key.

        Args:
            records: Records to aggregate.
            key_fn: Callable mapping a record to its group key.

        Returns:
            Dict of key -> summed cost, with keys in first-seen order.
        """
        result: dict[Any, float] = {}
        for record in records:
            key = key_fn(record)
            result[key] = result.get(key, 0) + record.cost_usd
        return result

    def _find_inefficiencies(
        self,
        records: list[UsageRecord],
        by_agent: dict[str, float],
        by_model: dict[str, float],
    ) -> list[Inefficiency]:
        """
        Find cost inefficiencies.

        Args:
            records: Records inside the analysis window.
            by_agent: Cost per agent; its keys (in order) select the agents
                checked for high retry rates.
            by_model: Cost per model. Currently unused; kept for signature
                compatibility.

        Returns:
            Retry-rate findings (one per agent), then overprovisioned-model
            findings and excessive-token findings (one per record each).
            Empty when ``records`` is empty.
        """
        if not records:
            return []
        return (
            self._find_retry_inefficiencies(records, by_agent)
            + self._find_overprovisioned_models(records)
            + self._find_excessive_tokens(records)
        )

    def _find_retry_inefficiencies(
        self,
        records: list[UsageRecord],
        by_agent: dict[str, float],
    ) -> list[Inefficiency]:
        """Flag agents with at least 5 records whose retry rate exceeds RETRY_RATE_THRESHOLD."""
        # Group once; re-filtering all records for every agent was O(agents * records).
        records_by_agent: dict[str, list[UsageRecord]] = {}
        for record in records:
            records_by_agent.setdefault(record.agent_id, []).append(record)

        found: list[Inefficiency] = []
        for agent_id in by_agent:
            agent_records = records_by_agent.get(agent_id, [])
            if len(agent_records) < 5:  # Need minimum sample
                continue

            retries = [r for r in agent_records if r.is_retry]
            retry_rate = len(retries) / len(agent_records)
            if retry_rate <= self.RETRY_RATE_THRESHOLD:
                continue

            retry_cost = sum(r.cost_usd for r in retries)
            found.append(Inefficiency(
                inefficiency_type=InefficiencyType.HIGH_RETRY_RATE,
                severity="high" if retry_rate > 0.25 else "medium",
                metric_value=retry_rate,
                recommendation=f"Agent '{agent_id}' has {retry_rate:.0%} retry rate. "
                               f"Investigate failures to reduce retries.",
                potential_savings_usd=retry_cost * 0.5,  # Assume 50% reducible
                agent_id=agent_id,
            ))
        return found

    def _find_overprovisioned_models(self, records: list[UsageRecord]) -> list[Inefficiency]:
        """Flag each record that ran an expensive model on a low-complexity task."""
        found: list[Inefficiency] = []
        for record in records:
            if not self._is_overprovisioned(record):
                continue
            cheaper_model = self._suggest_cheaper_model(record)
            savings = self._calculate_model_savings(record, cheaper_model)

            found.append(Inefficiency(
                inefficiency_type=InefficiencyType.OVERPROVISIONED_MODEL,
                severity="medium",
                metric_value=record.cost_usd,
                recommendation=f"Task '{record.task_id}' used {record.model} for "
                               f"{record.task_complexity} complexity. Consider {cheaper_model}.",
                potential_savings_usd=savings,
                task_id=record.task_id,
                details={"current_model": record.model, "suggested_model": cheaper_model},
            ))
        return found

    def _find_excessive_tokens(self, records: list[UsageRecord]) -> list[Inefficiency]:
        """Flag each record using more than EXCESSIVE_TOKEN_MULTIPLIER x the mean token count."""
        if not records:
            return []
        avg_tokens = sum(r.input_tokens + r.output_tokens for r in records) / len(records)
        threshold = avg_tokens * self.EXCESSIVE_TOKEN_MULTIPLIER

        found: list[Inefficiency] = []
        for record in records:
            total_tokens = record.input_tokens + record.output_tokens
            if total_tokens > threshold:
                found.append(Inefficiency(
                    inefficiency_type=InefficiencyType.EXCESSIVE_TOKENS,
                    severity="low",
                    metric_value=total_tokens,
                    recommendation=f"Task '{record.task_id}' used {total_tokens:,} tokens, "
                                   f"significantly above average ({avg_tokens:,.0f}). "
                                   f"Consider prompt optimization.",
                    potential_savings_usd=record.cost_usd * 0.3,  # Assume 30% reducible
                    task_id=record.task_id,
                ))
        return found

    def _is_overprovisioned(self, record: UsageRecord) -> bool:
        """Return True when an expensive model (exact, case-insensitive match) served a low-complexity task."""
        expensive_models = ("opus", "o1", "gpt-4o")
        return (
            record.model.lower() in expensive_models
            and record.task_complexity == "low"
        )

    def _suggest_cheaper_model(self, record: UsageRecord) -> str:
        """
        Suggest a cheaper model for the task.

        Returns the record's own model when no cheaper alternative is known.
        """
        model = record.model.lower()
        if "opus" in model:
            return "sonnet" if record.task_complexity == "medium" else "haiku"
        elif "o1" in model:
            return "gpt-4o-mini"
        elif "gpt-4o" in model and "mini" not in model:
            return "gpt-4o-mini"
        return record.model

    def _calculate_model_savings(self, record: UsageRecord, cheaper_model: str) -> float:
        """
        Estimate savings from re-running ``record`` on ``cheaper_model``.

        Both costs use MODEL_PRICING list prices, not ``record.cost_usd``.
        Models missing from the table are priced at zero. The result is
        never negative.
        """
        zero_pricing = {"input": 0, "output": 0}
        current_pricing = self.MODEL_PRICING.get(record.model.lower(), zero_pricing)
        cheaper_pricing = self.MODEL_PRICING.get(cheaper_model.lower(), zero_pricing)

        def list_cost(pricing: dict[str, float]) -> float:
            return (
                (record.input_tokens / 1000) * pricing["input"]
                + (record.output_tokens / 1000) * pricing["output"]
            )

        return max(0, list_cost(current_pricing) - list_cost(cheaper_pricing))

    def _history_span_days(self) -> float:
        """Return the age in days of the oldest stored record, or 0.0 when there are none."""
        if not self._records:
            return 0.0
        oldest = min(r.timestamp for r in self._records)
        return (datetime.now() - oldest).total_seconds() / 86400

    def project_costs(self, days_ahead: int = 30) -> CostProjection:
        """
        Project future costs based on historical patterns.

        The projected daily spend is the average over the last 7 days. It is
        scaled by 1.1 when that average is more than 20% above the long-term
        daily average, and by 0.9 when it is more than 20% below. The
        long-term average covers up to 30 days of available history.

        Args:
            days_ahead: Number of days to project. NOTE: currently unused; the
                result always carries daily, 7-day and 30-day figures. Kept
                for backward compatibility.

        Returns:
            CostProjection with estimates. If there was no usage in the last
            7 days, an all-zero projection with confidence 0.0 is returned.
        """
        # Analyze recent trends
        week_analysis = self.analyze_usage(days=7)
        month_analysis = self.analyze_usage(days=30)

        if week_analysis.total_tasks == 0:
            return CostProjection(
                projected_daily_usd=0,
                projected_weekly_usd=0,
                projected_monthly_usd=0,
                trend="stable",
                confidence=0.0,
                factors=["No recent usage data"],
            )

        # Calculate daily averages
        daily_avg_week = week_analysis.total_cost_usd / 7

        # Divide the 30-day total by the days of history that actually exist,
        # clamped to 7..30. A fixed divisor of 30 made any history shorter
        # than ~25 days look like an upward trend, which also inflated the
        # projection via the 1.1 multiplier.
        long_window_days = min(30.0, max(7.0, self._history_span_days()))
        if month_analysis.total_cost_usd > 0:
            daily_avg_month = month_analysis.total_cost_usd / long_window_days
        else:
            daily_avg_month = daily_avg_week

        # Determine trend
        if daily_avg_week > daily_avg_month * 1.2:
            trend = "increasing"
            multiplier = 1.1
        elif daily_avg_week < daily_avg_month * 0.8:
            trend = "decreasing"
            multiplier = 0.9
        else:
            trend = "stable"
            multiplier = 1.0

        # Project
        projected_daily = daily_avg_week * multiplier
        projected_weekly = projected_daily * 7
        projected_monthly = projected_daily * 30

        # Confidence grows with the total number of stored records, capped at 1.0
        confidence = min(1.0, len(self._records) / 100)

        factors = []
        if trend == "increasing":
            factors.append("Usage trending upward")
        elif trend == "decreasing":
            factors.append("Usage trending downward")

        if week_analysis.inefficiencies:
            factors.append(f"{len(week_analysis.inefficiencies)} inefficiencies detected")

        return CostProjection(
            projected_daily_usd=projected_daily,
            projected_weekly_usd=projected_weekly,
            projected_monthly_usd=projected_monthly,
            trend=trend,
            confidence=confidence,
            factors=factors,
        )

    def get_recommendations(self, limit: int = 5) -> list[dict[str, Any]]:
        """
        Get top cost optimization recommendations.

        Runs a 30-day analysis and returns its inefficiencies as plain dicts.

        Args:
            limit: Maximum recommendations to return

        Returns:
            Up to ``limit`` dicts with keys ``type``, ``recommendation``,
            ``potential_savings_usd``, ``severity``, ``agent_id`` and
            ``task_id``. They are sorted by potential savings, highest first;
            ties keep detection order because the sort is stable.
        """
        analysis = self.analyze_usage(days=30)

        recommendations = [
            {
                "type": ineff.inefficiency_type.value,
                "recommendation": ineff.recommendation,
                "potential_savings_usd": ineff.potential_savings_usd,
                "severity": ineff.severity,
                "agent_id": ineff.agent_id,
                "task_id": ineff.task_id,
            }
            for ineff in analysis.inefficiencies
        ]

        # Sort by potential savings
        recommendations.sort(key=lambda r: r["potential_savings_usd"], reverse=True)
        return recommendations[:limit]


# Shared module-level optimizer; created lazily by get_cost_optimizer() and
# replaceable via set_cost_optimizer().
_optimizer: Optional[CostOptimizer] = None


def get_cost_optimizer() -> CostOptimizer:
    """Return the shared optimizer, constructing a new CostOptimizer on first use."""
    global _optimizer
    if _optimizer is not None:
        return _optimizer
    _optimizer = CostOptimizer()
    return _optimizer


def set_cost_optimizer(optimizer: Optional[CostOptimizer]) -> None:
    """Set the global cost optimizer.

    Args:
        optimizer: Instance to share. Passing ``None`` clears it, so the next
            ``get_cost_optimizer()`` call constructs a fresh ``CostOptimizer``.
    """
    global _optimizer
    _optimizer = optimizer
