"""
Retry Logic - Automatic retry for transient failures.

Implements:
- Retry decorator with exponential backoff
- Transient error detection
- Fallback model configuration
- Circuit breaker pattern
"""

import asyncio
import logging
import random
import re
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import wraps
from typing import Optional, Dict, List, Any, Callable, Type, TypeVar, Union

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic type variable for the result of the callables wrapped by the retry
# helpers in this module.
T = TypeVar('T')


# Exception class names treated as transient, grouped by provider key.
# Entries are bare class names that ``is_transient_exception`` compares with
# ``type(exc).__name__``; the "default" group is consulted for every provider
# and also stands in for providers that have no group of their own.
# NOTE(review): ``__name__`` carries no module prefix, so the dotted
# "asyncio.TimeoutError" entry presumably never matches by itself (the plain
# "TimeoutError" entry covers that class) - confirm before relying on it.
TRANSIENT_EXCEPTIONS = dict(
    anthropic=[
        "RateLimitError",
        "APIConnectionError",
        "APITimeoutError",
        "InternalServerError",
        "ServiceUnavailableError",
    ],
    openai=[
        "RateLimitError",
        "APIConnectionError",
        "Timeout",
        "InternalServerError",
        "ServiceUnavailableError",
    ],
    google=[
        "ResourceExhausted",
        "ServiceUnavailable",
        "DeadlineExceeded",
    ],
    default=[
        "RateLimitError",
        "ConnectionError",
        "TimeoutError",
        "asyncio.TimeoutError",
    ],
)


# Default exception class names for RetryConfig.retry_exceptions.
_DEFAULT_RETRY_EXCEPTIONS = (
    "RateLimitError",
    "ConnectionError",
    "TimeoutError",
    "ServiceUnavailable",
)

# Default exception class names for RetryConfig.fatal_exceptions.
_DEFAULT_FATAL_EXCEPTIONS = (
    "AuthenticationError",
    "InvalidRequestError",
    "PermissionDenied",
)


@dataclass
class RetryConfig:
    """
    Configuration for retry behavior.

    Attributes:
        max_retries: Number of retries allowed after the first attempt
        initial_delay: Delay before the first retry, in seconds
        max_delay: Upper bound for the delay between retries, in seconds
        exponential_base: Factor the delay is multiplied by after each retry
        jitter: Flag intended to enable randomized delays
        retry_exceptions: Names of exception classes to retry
        fatal_exceptions: Names of exception classes that must always fail
    """

    # Backoff schedule
    max_retries: int = 3
    initial_delay: float = 1.0
    max_delay: float = 60.0
    exponential_base: float = 2.0
    jitter: bool = True

    # Which exceptions to retry (every instance gets its own list)
    retry_exceptions: List[str] = field(
        default_factory=lambda: list(_DEFAULT_RETRY_EXCEPTIONS)
    )

    # Always fail on these (every instance gets its own list)
    fatal_exceptions: List[str] = field(
        default_factory=lambda: list(_DEFAULT_FATAL_EXCEPTIONS)
    )


@dataclass
class FallbackModel:
    """
    Configuration for a fallback model.

    Describes one alternative provider/model pair together with the kinds of
    failure for which it may be offered as a replacement.
    """

    # Provider and model name of the fallback target
    provider: str
    model: str
    priority: int = 0  # Lower is higher priority

    # Constraints
    # NOTE(review): neither field is read by the code visible in this module;
    # presumably consumed by callers - confirm.
    max_tokens: Optional[int] = None
    cost_multiplier: float = 1.0

    # When to use: offer this fallback after a rate-limit error, after a
    # timeout, or (on_error) after any error
    on_rate_limit: bool = True
    on_timeout: bool = True
    on_error: bool = False


@dataclass
class FallbackConfig:
    """
    Configuration for fallback behavior.

    Groups the fallback candidates configured for one primary model with the
    switches that control whether they are used.
    """

    # Ordered list of fallback models
    fallbacks: List[FallbackModel] = field(default_factory=list)

    # Whether to use fallbacks; when False no fallback is offered
    enabled: bool = True

    # Max fallback attempts
    # NOTE(review): not read by the code visible in this module; presumably
    # enforced by the caller that re-issues requests - confirm.
    max_fallback_attempts: int = 2


# Built-in fallback chains, keyed by the name of the primary model.
# Within a chain, a lower ``priority`` value marks the model to prefer.
DEFAULT_FALLBACKS: Dict[str, FallbackConfig] = {
    "claude-3-opus": FallbackConfig([
        FallbackModel("anthropic", "claude-3-sonnet", priority=1),
        FallbackModel("anthropic", "claude-3-haiku", priority=2),
    ]),
    "claude-3-sonnet": FallbackConfig([
        FallbackModel("anthropic", "claude-3-haiku", priority=1),
        FallbackModel("openai", "gpt-4o", priority=2, on_rate_limit=True),
    ]),
    "gpt-4o": FallbackConfig([
        FallbackModel("openai", "gpt-4o-mini", priority=1),
        FallbackModel("anthropic", "claude-3-sonnet", priority=2),
    ]),
    "gemini-1.5-pro": FallbackConfig([
        FallbackModel("google", "gemini-1.5-flash", priority=1),
    ]),
}


def is_transient_exception(exception: Exception, provider: str = "default") -> bool:
    """
    Decide whether an exception looks transient, i.e. worth retrying.

    The exception's class name is looked up in the provider's entry of
    ``TRANSIENT_EXCEPTIONS`` (unknown providers use the "default" entry) and
    in the "default" entry itself. If neither contains it, the lower-cased
    message is searched for well-known transient phrases.

    Args:
        exception: The exception to classify
        provider: Key into ``TRANSIENT_EXCEPTIONS``

    Returns:
        True if the class name or the message marks the error as transient
    """
    class_name = type(exception).__name__
    default_names = TRANSIENT_EXCEPTIONS["default"]
    name_groups = (
        TRANSIENT_EXCEPTIONS.get(provider, default_names),
        default_names,
    )
    if any(class_name in group for group in name_groups):
        return True

    # No match by class name: fall back to substring matching on the message.
    message = str(exception).lower()
    return any(
        marker in message
        for marker in (
            "rate limit",
            "too many requests",
            "temporarily unavailable",
            "service unavailable",
            "connection reset",
            "connection refused",
            "timed out",
            "timeout",
            "overloaded",
            "capacity",
        )
    )


def is_rate_limit_exception(exception: Exception) -> bool:
    """
    Check if exception is specifically a rate limit error.

    An exception counts as a rate limit error when either
    - its class name contains a rate-limit marker (e.g. ``RateLimitError``,
      ``ResourceExhausted``, ``TooManyRequests``), or
    - its message contains a rate-limit phrase, or
    - its message contains the status code 429 as a standalone number.

    Args:
        exception: The exception to classify

    Returns:
        True if the exception looks like a rate limit / quota error
    """
    exception_name = type(exception).__name__.lower()
    error_msg = str(exception).lower()

    # Class-name forms of the message phrases below, so typed errors are
    # recognized even when their message does not spell the phrase out.
    name_markers = (
        "ratelimit",
        "toomanyrequests",
        "quotaexceeded",
        "resourceexhausted",
    )
    if any(marker in exception_name for marker in name_markers):
        return True

    rate_limit_patterns = (
        "rate limit",
        "too many requests",
        "quota exceeded",
        "resource exhausted",
    )
    if any(pattern in error_msg for pattern in rate_limit_patterns):
        return True

    # Match 429 only when it is not part of a longer number; a plain
    # substring test would also fire on e.g. "14290 tokens".
    return re.search(r"(?<!\d)429(?!\d)", error_msg) is not None


def retry(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    retry_exceptions: Optional[List[Type[Exception]]] = None,
    on_retry: Optional[Callable[[Exception, int], None]] = None,
):
    """
    Build a decorator that retries an async function with exponential backoff.

    The wrapped coroutine is awaited up to ``max_retries + 1`` times. After a
    retryable failure the wrapper logs a warning, invokes ``on_retry`` and
    sleeps; the sleep starts at ``initial_delay`` and is multiplied by
    ``exponential_base`` after every retry, capped at ``max_delay``. A
    non-retryable failure, or a failure on the last attempt, is re-raised.

    Args:
        max_retries: Maximum number of retry attempts
        initial_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries
        exponential_base: Base for exponential backoff
        retry_exceptions: Exception types to retry on; when empty or None,
            ``is_transient_exception`` decides
        on_retry: Callback called before each retry with the exception and
            the 1-based retry number

    Returns:
        A decorator for async callables
    """
    def is_retryable(exc: Exception) -> bool:
        # An explicit (non-empty) type list takes precedence over the heuristic.
        if retry_exceptions:
            return isinstance(exc, tuple(retry_exceptions))
        return is_transient_exception(exc)

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            failure = None
            backoff = initial_delay

            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    failure = exc

                    retryable = is_retryable(exc)
                    if attempt >= max_retries or not retryable:
                        raise

                    retry_number = attempt + 1
                    logger.warning(
                        f"Retry {retry_number}/{max_retries} for {func.__name__} "
                        f"after {type(exc).__name__}: {str(exc)[:100]}"
                    )

                    if on_retry:
                        on_retry(exc, retry_number)

                    # Back off, then grow the delay for the next round.
                    await asyncio.sleep(backoff)
                    backoff = min(backoff * exponential_base, max_delay)

            # Only reached when the loop body never ran.
            raise failure

        return wrapper
    return decorator


class CircuitBreaker:
    """
    Circuit breaker pattern for preventing cascading failures.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Failures exceeded threshold, requests fail fast
    - HALF_OPEN: Testing if service recovered

    Usage: ask ``can_execute()`` before a request and report the outcome with
    ``record_success()`` or ``record_failure()``. State changes are made while
    holding an ``asyncio.Lock``.
    """

    class State:
        """String constants naming the breaker states."""

        CLOSED = "closed"
        OPEN = "open"
        HALF_OPEN = "half_open"

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_requests: int = 3,
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Failures before opening circuit
            recovery_timeout: Seconds before trying half-open
            half_open_requests: Requests to try in half-open state; also the
                number of successes required to close the circuit again
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests

        self._state = self.State.CLOSED
        self._failures = 0
        self._successes = 0
        self._last_failure_time: Optional[datetime] = None
        self._half_open_count = 0
        self._lock = asyncio.Lock()

    @property
    def state(self) -> str:
        """Get current circuit state."""
        return self._state

    async def can_execute(self) -> bool:
        """
        Check if request can proceed.

        In OPEN state the breaker switches to HALF_OPEN once
        ``recovery_timeout`` seconds have passed since the last recorded
        failure. In HALF_OPEN state at most ``half_open_requests`` requests
        are admitted, including the one that triggered the switch.

        Returns:
            True if request should proceed
        """
        async with self._lock:
            if self._state == self.State.OPEN:
                if self._last_failure_time is None:
                    return False

                # NOTE(review): elapsed time is taken from the naive local
                # wall clock, so clock adjustments shift the recovery window;
                # a monotonic clock would avoid that.
                elapsed = (datetime.now() - self._last_failure_time).total_seconds()
                if elapsed < self.recovery_timeout:
                    return False

                self._state = self.State.HALF_OPEN
                # The request admitted right here is itself the first probe.
                # Starting the counter at 0 would let half_open_requests + 1
                # probes through.
                self._half_open_count = 1
                logger.info("Circuit breaker entering half-open state")
                return True

            if self._state == self.State.HALF_OPEN:
                if self._half_open_count < self.half_open_requests:
                    self._half_open_count += 1
                    return True
                return False

            # CLOSED (or any unknown state): let the request through.
            return True

    async def record_success(self) -> None:
        """
        Record successful request.

        In HALF_OPEN state the success is counted and the circuit closes once
        ``half_open_requests`` successes have been seen. In CLOSED state the
        failure counter is decreased by one (never below zero).
        """
        async with self._lock:
            if self._state == self.State.HALF_OPEN:
                self._successes += 1
                if self._successes >= self.half_open_requests:
                    self._state = self.State.CLOSED
                    self._failures = 0
                    self._successes = 0
                    logger.info("Circuit breaker closed after recovery")
            elif self._state == self.State.CLOSED:
                # Each success offsets one earlier failure; this is a gradual
                # decay, not a full reset.
                self._failures = max(0, self._failures - 1)

    async def record_failure(self) -> None:
        """
        Record failed request.

        Always bumps the failure counter and the last-failure timestamp. A
        failure in HALF_OPEN state reopens the circuit immediately; in CLOSED
        state the circuit opens once ``failure_threshold`` is reached.
        """
        async with self._lock:
            self._failures += 1
            self._last_failure_time = datetime.now()

            if self._state == self.State.HALF_OPEN:
                # Failure in half-open - back to open
                self._state = self.State.OPEN
                self._successes = 0
                logger.warning("Circuit breaker reopened after half-open failure")

            elif self._state == self.State.CLOSED:
                if self._failures >= self.failure_threshold:
                    self._state = self.State.OPEN
                    logger.warning(
                        f"Circuit breaker opened after {self._failures} failures"
                    )

    def get_stats(self) -> Dict[str, Any]:
        """
        Get circuit breaker statistics.

        Returns:
            Dict with the keys "state", "failures", "successes" and
            "last_failure" (ISO-8601 string, or None if no failure was
            recorded yet)
        """
        last_failure = self._last_failure_time
        return {
            "state": self._state,
            "failures": self._failures,
            "successes": self._successes,
            "last_failure": last_failure.isoformat() if last_failure else None,
        }


class RetryManager:
    """
    Manages retry logic and fallbacks for multiple providers/models.

    Combines three mechanisms:
    - exponential backoff driven by a ``RetryConfig``
    - one ``CircuitBreaker`` per provider, created on first use
    - fallback lookup per model; a fallback is only signalled (by raising
      ``FallbackRequired``), the caller has to re-issue the request itself
    """

    def __init__(
        self,
        retry_config: Optional[RetryConfig] = None,
        fallback_configs: Optional[Dict[str, FallbackConfig]] = None,
    ):
        """
        Initialize retry manager.

        Args:
            retry_config: Default retry configuration; ``RetryConfig()`` is
                used when omitted
            fallback_configs: Fallback configurations per model; a shallow
                copy of ``DEFAULT_FALLBACKS`` is used when omitted. An empty
                mapping is honored and means "no fallbacks".
        """
        self.retry_config = retry_config if retry_config is not None else RetryConfig()
        # Compare against None instead of relying on truthiness, otherwise an
        # explicitly empty mapping would silently be replaced by the defaults.
        self.fallback_configs = (
            DEFAULT_FALLBACKS.copy() if fallback_configs is None else fallback_configs
        )

        # Circuit breakers per provider
        self._circuit_breakers: Dict[str, CircuitBreaker] = {}

    def get_circuit_breaker(self, provider: str) -> CircuitBreaker:
        """Get or create circuit breaker for provider."""
        if provider not in self._circuit_breakers:
            self._circuit_breakers[provider] = CircuitBreaker()
        return self._circuit_breakers[provider]

    @staticmethod
    def _is_timeout_exception(exception: Exception) -> bool:
        """Tell whether the class name or the message marks a timeout."""
        # The class name matters because timeout errors often carry an empty
        # message, e.g. ``str(asyncio.TimeoutError()) == ""``.
        name = type(exception).__name__.lower()
        if "timeout" in name or "deadlineexceeded" in name:
            return True
        message = str(exception).lower()
        return any(
            marker in message
            for marker in ("timeout", "timed out", "deadline exceeded")
        )

    def _should_retry(self, exception: Exception, provider: str) -> bool:
        """Apply the configured fatal/retry name lists, then the transient heuristic."""
        name = type(exception).__name__
        if name in self.retry_config.fatal_exceptions:
            return False
        if name in self.retry_config.retry_exceptions:
            return True
        return is_transient_exception(exception, provider)

    def _sleep_duration(self, delay: float) -> float:
        """Return the time to sleep for a backoff ``delay``, jittered if configured."""
        if not self.retry_config.jitter:
            return delay
        # "Equal jitter": keep at least half of the backoff and randomize the
        # rest, so concurrent clients do not retry in lockstep and the result
        # never exceeds the configured delay.
        return random.uniform(delay / 2, delay)

    def get_fallback_model(
        self,
        model: str,
        exception: Exception,
    ) -> Optional[FallbackModel]:
        """
        Get fallback model for a failed request.

        Fallbacks are examined in ascending ``priority`` order; the first one
        whose flags apply to the failure is returned: ``on_rate_limit`` for
        rate-limit errors, ``on_timeout`` for timeouts, ``on_error`` for any
        error.

        Args:
            model: The model that failed
            exception: The exception that occurred

        Returns:
            FallbackModel to try, or None
        """
        config = self.fallback_configs.get(model)
        if not config or not config.enabled:
            return None

        is_rate_limit = is_rate_limit_exception(exception)
        is_timeout = self._is_timeout_exception(exception)

        for fallback in sorted(config.fallbacks, key=lambda f: f.priority):
            if (
                (is_rate_limit and fallback.on_rate_limit)
                or (is_timeout and fallback.on_timeout)
                or fallback.on_error
            ):
                return fallback

        return None

    async def execute_with_retry(
        self,
        func: Callable[..., T],
        *args,
        provider: str = "default",
        model: Optional[str] = None,
        **kwargs,
    ) -> T:
        """
        Execute function with retry logic and fallbacks.

        Every attempt is gated by the provider's circuit breaker and its
        outcome is reported back to it. Retryable failures are retried with
        exponential backoff (randomized when ``retry_config.jitter`` is set)
        until ``retry_config.max_retries`` retries have been used.

        Args:
            func: Async function to execute
            provider: Provider name, selects the circuit breaker and the
                transient-error list
            model: Model name for fallback lookup
            *args, **kwargs: Arguments for func

        Returns:
            Result of func

        Raises:
            FallbackRequired: Retries are exhausted and a fallback model is
                configured for ``model`` and applies to the last error
            RuntimeError: The circuit breaker rejected the very first attempt
            Exception: The last error raised by ``func`` in all other cases
        """
        circuit_breaker = self.get_circuit_breaker(provider)
        config = self.retry_config
        # A negative max_retries still gets one attempt; otherwise the loop
        # would not run and there would be neither a result nor an error.
        max_retries = max(config.max_retries, 0)
        last_exception: Optional[Exception] = None
        delay = config.initial_delay

        for attempt in range(max_retries + 1):
            # Check circuit breaker
            if not await circuit_breaker.can_execute():
                logger.warning(f"Circuit breaker open for {provider}, failing fast")
                if last_exception is not None:
                    raise last_exception
                raise RuntimeError(f"Circuit breaker open for {provider}")

            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                await circuit_breaker.record_failure()

                # Check if we should retry
                if not self._should_retry(e, provider):
                    raise

                if attempt >= max_retries:
                    # Try fallback
                    if model:
                        fallback = self.get_fallback_model(model, e)
                        if fallback:
                            logger.info(
                                f"Attempting fallback to {fallback.provider}/{fallback.model}"
                            )
                            # Note: caller would need to re-invoke with fallback model
                            raise FallbackRequired(fallback) from e
                    raise

                logger.warning(
                    f"Retry {attempt + 1}/{max_retries} "
                    f"for {provider}: {type(e).__name__}"
                )

                # Wait before retry, then grow the base delay
                await asyncio.sleep(self._sleep_duration(delay))
                delay = min(delay * config.exponential_base, config.max_delay)
            else:
                await circuit_breaker.record_success()
                return result

        # Unreachable: the last iteration always returns or raises above.
        raise RuntimeError(f"Retry loop for {provider} ended without a result")


class FallbackRequired(Exception):
    """
    Signals that the request should be re-issued against a fallback model.

    Attributes:
        fallback: The ``FallbackModel`` the caller is expected to switch to
    """

    def __init__(self, fallback: FallbackModel):
        self.fallback = fallback
        target = f"{fallback.provider}/{fallback.model}"
        super().__init__(f"Fallback required to {target}")


# Module-wide RetryManager singleton; stays None until get_retry_manager()
# creates it or set_retry_manager() installs one.
_retry_manager: Optional[RetryManager] = None


def get_retry_manager() -> RetryManager:
    """
    Return the global retry manager, creating it on first use.

    Returns:
        The stored ``RetryManager``; when none is stored yet, a
        default-constructed one is created, stored and returned
    """
    global _retry_manager
    manager = _retry_manager
    if manager is None:
        manager = RetryManager()
        _retry_manager = manager
    return manager


def set_retry_manager(manager: RetryManager) -> None:
    """
    Set the global retry manager.

    Replaces whatever instance is currently stored, so later
    ``get_retry_manager()`` calls return ``manager``.

    Args:
        manager: The instance to install as the global retry manager
    """
    global _retry_manager
    _retry_manager = manager
