"""
Control Actions - Actions the control loop can take on agents.

This module defines the actions available to the control loop:
1. AUTO_PROMPT: Send an unstick prompt to the agent
2. ESCALATE: Request human intervention
3. TERMINATE: Kill the agent process
4. REASSIGN: Move task to a different agent
5. CONTINUE: No action needed, let agent continue
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional

from .health import StuckReason, HealthCheckResult


# Module-level logger; ActionPolicy writes debug messages here while an
# auto-prompt is being held back by exponential backoff.
logger = logging.getLogger(__name__)


class ControlActionType(Enum):
    """Types of control actions.

    The string values are what ``ControlAction.__str__`` prints (it renders
    ``action_type.value``).
    """

    CONTINUE = "continue"  # No action, agent is healthy
    AUTO_PROMPT = "auto_prompt"  # Send unstick prompt
    ESCALATE = "escalate"  # Request human intervention
    TERMINATE = "terminate"  # Kill the agent
    REASSIGN = "reassign"  # Move task to different agent
    PAUSE = "pause"  # Temporarily pause the agent
    RESUME = "resume"  # Resume a paused agent


class EscalationLevel(Enum):
    """Escalation severity levels, declared from least to most severe.

    NOTE: this is a plain ``Enum``, so members do not support ordering
    comparisons (``<``, ``>``); compare them with ``==`` / ``is`` only.
    """

    INFO = "info"  # FYI, no action required
    WARN = "warn"  # Attention needed soon
    URGENT = "urgent"  # Needs immediate attention
    CRITICAL = "critical"  # System at risk


@dataclass
class ControlAction:
    """
    One decision produced by the control loop for a single agent.

    Only ``action_type`` and ``agent_id`` are required. The optional fields
    are filled in depending on the kind of action (see the trailing comment
    on each field).
    """

    action_type: ControlActionType
    agent_id: str
    task_id: Optional[str] = None
    reason: str = ""
    prompt: Optional[str] = None  # For AUTO_PROMPT
    escalation_level: Optional[EscalationLevel] = None  # For ESCALATE
    target_agent_id: Optional[str] = None  # For REASSIGN
    metadata: dict[str, Any] = field(default_factory=dict)  # extra free-form context
    created_at: datetime = field(default_factory=datetime.now)  # naive local time at construction

    def __str__(self) -> str:
        """Return a one-line summary showing the action kind, agent and reason."""
        kind = self.action_type.value
        agent = self.agent_id
        why = self.reason
        return f"ControlAction({kind}, agent={agent}, reason={why})"


@dataclass
class ActionResult:
    """Result of executing a control action.

    Nothing in this module constructs an ActionResult. Presumably whatever
    executes a ControlAction fills it in; verify against callers.
    """

    action: ControlAction  # the action that was executed
    success: bool  # whether execution succeeded
    message: str  # human-readable outcome summary
    executed_at: datetime = field(default_factory=datetime.now)  # naive local timestamp taken at construction
    error: Optional[str] = None  # presumably error text when success is False; None otherwise


class ActionPolicy:
    """
    Policy for deciding control actions based on health state.

    Configurable thresholds and rules for when to take action.
    Includes exponential backoff for auto-prompts to prevent flooding and a
    cooldown between escalations (see ``can_escalate``).

    Per-agent counters live in memory on this instance. They are advanced
    only by ``record_auto_prompt`` / ``record_escalation`` and cleared only
    by ``reset_agent``. ``decide_action`` never resets them, even when the
    agent reports healthy again.

    Note: ``decide_action`` does not consult ``can_escalate``. Callers that
    want the escalation cooldown enforced must check it themselves before
    acting on an ESCALATE action.
    """

    def __init__(
        self,
        auto_prompt_max_attempts: int = 3,
        escalate_after_auto_prompts: int = 2,
        terminate_after_escalations: int = 2,
        auto_terminate_on_timeout: bool = False,
        auto_prompt_base_delay_seconds: float = 30.0,
        auto_prompt_max_delay_seconds: float = 300.0,
        auto_prompt_backoff_multiplier: float = 2.0,
        escalation_cooldown_seconds: float = 60.0,
    ) -> None:
        """
        Initialize action policy.

        Args:
            auto_prompt_max_attempts: Max auto-prompts before escalating.
                ``decide_action`` checks ``escalate_after_auto_prompts``
                first, so this limit only binds when it is the smaller of
                the two (with the defaults, 2 < 3, it never binds).
            escalate_after_auto_prompts: Auto-prompts before escalation
            terminate_after_escalations: Escalations before termination
            auto_terminate_on_timeout: Auto-terminate on timeout (vs escalate)
            auto_prompt_base_delay_seconds: Initial delay between auto-prompts
            auto_prompt_max_delay_seconds: Maximum delay between auto-prompts
            auto_prompt_backoff_multiplier: Multiplier for exponential backoff
            escalation_cooldown_seconds: Minimum seconds between escalations
                for the same agent, as reported by ``can_escalate``
        """
        self.auto_prompt_max_attempts = auto_prompt_max_attempts
        self.escalate_after_auto_prompts = escalate_after_auto_prompts
        self.terminate_after_escalations = terminate_after_escalations
        self.auto_terminate_on_timeout = auto_terminate_on_timeout

        # Exponential backoff settings
        self.auto_prompt_base_delay = auto_prompt_base_delay_seconds
        self.auto_prompt_max_delay = auto_prompt_max_delay_seconds
        self.auto_prompt_backoff_multiplier = auto_prompt_backoff_multiplier

        # Track attempts per agent
        self._auto_prompt_counts: dict[str, int] = {}
        self._escalation_counts: dict[str, int] = {}

        # Track last auto-prompt time for backoff
        self._last_auto_prompt_times: dict[str, datetime] = {}

        # Track last escalation time and level per agent
        self._last_escalation_times: dict[str, datetime] = {}
        self._last_escalation_levels: dict[str, EscalationLevel] = {}

        # Minimum time between escalations (seconds)
        self._escalation_cooldown_seconds = escalation_cooldown_seconds

    @staticmethod
    def _enum_value(member: Optional[Enum]) -> Any:
        """Return ``member.value``, or None when ``member`` is None."""
        return None if member is None else member.value

    def decide_action(
        self,
        health_result: HealthCheckResult,
        task_id: Optional[str] = None,
    ) -> ControlAction:
        """
        Decide what action to take based on health check.

        Order of evaluation:
        1. Healthy and not stuck -> CONTINUE.
        2. TIMEOUT_EXCEEDED -> TERMINATE or URGENT escalation (per
           ``auto_terminate_on_timeout``).
        3. AWAITING_APPROVAL -> INFO escalation.
        4. RESOURCE_EXHAUSTED -> URGENT escalation.
        5. Otherwise the escalation ladder: terminate, then escalate, then
           auto-prompt (subject to backoff), then fall back to escalation.

        Args:
            health_result: Result from health check
            task_id: Current task ID

        Returns:
            ControlAction to execute. This method does not record the action;
            callers should call ``record_auto_prompt`` / ``record_escalation``
            after acting so the ladder advances.
        """
        agent_id = health_result.agent_id

        # Healthy agent - continue
        if health_result.is_healthy and not health_result.is_stuck:
            return ControlAction(
                action_type=ControlActionType.CONTINUE,
                agent_id=agent_id,
                task_id=task_id,
                reason="Agent is healthy",
            )

        # Get current attempt counts
        auto_prompts = self._auto_prompt_counts.get(agent_id, 0)
        escalations = self._escalation_counts.get(agent_id, 0)

        # Decide based on stuck reason. This may be None for an agent that is
        # unhealthy but not stuck; the ladder below handles that case.
        stuck_reason = health_result.stuck_reason

        # Timeout - special handling
        if stuck_reason == StuckReason.TIMEOUT_EXCEEDED:
            if self.auto_terminate_on_timeout:
                return self._create_terminate_action(
                    agent_id, task_id,
                    "Task exceeded timeout limit",
                    health_result
                )
            return self._create_escalate_action(
                agent_id, task_id,
                "Task exceeded timeout - needs decision",
                EscalationLevel.URGENT,
                health_result
            )

        # Awaiting approval - always escalate. stuck_details is used as the
        # reason, with a fallback so the reason is never None or empty.
        if stuck_reason == StuckReason.AWAITING_APPROVAL:
            return self._create_escalate_action(
                agent_id, task_id,
                health_result.stuck_details or "Agent awaiting approval",
                EscalationLevel.INFO,
                health_result
            )

        # Resource exhausted - escalate urgently
        if stuck_reason == StuckReason.RESOURCE_EXHAUSTED:
            return self._create_escalate_action(
                agent_id, task_id,
                "Agent resource exhausted",
                EscalationLevel.URGENT,
                health_result
            )

        # For other stuck reasons, follow escalation ladder
        if escalations >= self.terminate_after_escalations:
            return self._create_terminate_action(
                agent_id, task_id,
                f"Terminated after {escalations} escalations",
                health_result
            )

        if auto_prompts >= self.escalate_after_auto_prompts:
            return self._create_escalate_action(
                agent_id, task_id,
                f"Auto-prompting ineffective after {auto_prompts} attempts",
                EscalationLevel.WARN,
                health_result
            )

        if auto_prompts < self.auto_prompt_max_attempts:
            # Read the clock once: remaining > 0 is exactly "backoff not yet
            # elapsed". Separate can_send/time-until calls could disagree and
            # yield a "Waiting ... (0.0s remaining)" action.
            remaining = self.get_time_until_auto_prompt(agent_id)
            if remaining > 0.0:
                logger.debug(
                    "Auto-prompt backoff for %s: %.1fs remaining",
                    agent_id, remaining,
                )
                return ControlAction(
                    action_type=ControlActionType.CONTINUE,
                    agent_id=agent_id,
                    task_id=task_id,
                    reason=f"Waiting for backoff ({remaining:.1f}s remaining)",
                    metadata={"backoff_remaining_seconds": remaining},
                )

            return self._create_auto_prompt_action(
                agent_id, task_id,
                health_result
            )

        # Fallback to escalation (reachable only when
        # escalate_after_auto_prompts > auto_prompt_max_attempts)
        return self._create_escalate_action(
            agent_id, task_id,
            "Agent stuck - automatic remediation exhausted",
            EscalationLevel.WARN,
            health_result
        )

    def _create_auto_prompt_action(
        self,
        agent_id: str,
        task_id: Optional[str],
        health_result: HealthCheckResult,
    ) -> ControlAction:
        """Create an auto-prompt action carrying an unstick prompt."""
        # Imported at call time as in the original; the reason (e.g. avoiding
        # an import cycle) is not visible here.
        from .health import generate_unstick_prompt

        reason_value = self._enum_value(health_result.stuck_reason)
        prompt = generate_unstick_prompt(
            health_result.stuck_reason,
            health_result.stuck_details
        )

        return ControlAction(
            action_type=ControlActionType.AUTO_PROMPT,
            agent_id=agent_id,
            task_id=task_id,
            reason=f"Stuck: {reason_value}",
            prompt=prompt,
            metadata={
                "stuck_reason": reason_value,
                "stuck_details": health_result.stuck_details,
            },
        )

    def _create_escalate_action(
        self,
        agent_id: str,
        task_id: Optional[str],
        reason: str,
        level: EscalationLevel,
        health_result: HealthCheckResult,
    ) -> ControlAction:
        """Create an escalation action with stuck info and recommendations."""
        return ControlAction(
            action_type=ControlActionType.ESCALATE,
            agent_id=agent_id,
            task_id=task_id,
            reason=reason,
            escalation_level=level,
            metadata={
                "stuck_reason": self._enum_value(health_result.stuck_reason),
                "stuck_details": health_result.stuck_details,
                "recommendations": health_result.recommendations,
            },
        )

    def _create_terminate_action(
        self,
        agent_id: str,
        task_id: Optional[str],
        reason: str,
        health_result: HealthCheckResult,
    ) -> ControlAction:
        """Create a termination action recording the stuck reason and final state."""
        return ControlAction(
            action_type=ControlActionType.TERMINATE,
            agent_id=agent_id,
            task_id=task_id,
            reason=reason,
            metadata={
                "stuck_reason": self._enum_value(health_result.stuck_reason),
                "final_state": self._enum_value(health_result.state),
            },
        )

    def record_auto_prompt(self, agent_id: str) -> None:
        """Record that an auto-prompt was sent (bumps count, stamps time)."""
        self._auto_prompt_counts[agent_id] = self._auto_prompt_counts.get(agent_id, 0) + 1
        self._last_auto_prompt_times[agent_id] = datetime.now()

    def record_escalation(
        self,
        agent_id: str,
        level: Optional[EscalationLevel] = None,
    ) -> None:
        """
        Record that an escalation was made.

        Args:
            agent_id: Agent identifier
            level: Escalation level (for tracking); when None, the previously
                recorded level is kept
        """
        self._escalation_counts[agent_id] = self._escalation_counts.get(agent_id, 0) + 1
        self._last_escalation_times[agent_id] = datetime.now()
        if level is not None:
            self._last_escalation_levels[agent_id] = level

    def can_escalate(self, agent_id: str) -> bool:
        """
        Check if enough time has passed for another escalation.

        Prevents flooding escalations in rapid succession.

        Args:
            agent_id: Agent identifier

        Returns:
            True if escalation can be sent now
        """
        last_time = self._last_escalation_times.get(agent_id)
        if last_time is None:
            return True

        elapsed = (datetime.now() - last_time).total_seconds()
        return elapsed >= self._escalation_cooldown_seconds

    def get_escalation_state(self, agent_id: str) -> dict[str, Any]:
        """
        Get the current escalation state for an agent.

        Args:
            agent_id: Agent identifier

        Returns:
            Dictionary with escalation state info
        """
        last_time = self._last_escalation_times.get(agent_id)
        last_level = self._last_escalation_levels.get(agent_id)
        count = self._escalation_counts.get(agent_id, 0)

        return {
            "count": count,
            "can_escalate": self.can_escalate(agent_id),
            "last_escalation_time": last_time.isoformat() if last_time is not None else None,
            "last_escalation_level": self._enum_value(last_level),
            "at_termination_threshold": count >= self.terminate_after_escalations,
        }

    def reset_agent(self, agent_id: str) -> None:
        """Reset all counts and state for an agent (e.g., after task completion)."""
        self._auto_prompt_counts.pop(agent_id, None)
        self._escalation_counts.pop(agent_id, None)
        self._last_auto_prompt_times.pop(agent_id, None)
        self._last_escalation_times.pop(agent_id, None)
        self._last_escalation_levels.pop(agent_id, None)

    def get_auto_prompt_delay(self, agent_id: str) -> float:
        """
        Get the delay before the next auto-prompt (exponential backoff).

        The delay is ``base * multiplier ** (attempts - 1)``, capped at the
        configured maximum; it is 0 before the first auto-prompt.

        Args:
            agent_id: Agent identifier

        Returns:
            Delay in seconds before next auto-prompt should be sent
        """
        attempts = self._auto_prompt_counts.get(agent_id, 0)
        if attempts <= 0:
            return 0.0

        try:
            delay = self.auto_prompt_base_delay * (
                self.auto_prompt_backoff_multiplier ** (attempts - 1)
            )
        except OverflowError:
            # float ** int raises rather than returning inf once the result
            # leaves float range; any such delay is above the cap anyway.
            return self.auto_prompt_max_delay
        return min(delay, self.auto_prompt_max_delay)

    def can_send_auto_prompt(self, agent_id: str) -> bool:
        """
        Check if enough time has passed for another auto-prompt.

        Uses exponential backoff based on the number of previous attempts.

        Args:
            agent_id: Agent identifier

        Returns:
            True if auto-prompt can be sent now
        """
        last_time = self._last_auto_prompt_times.get(agent_id)
        if last_time is None:
            return True

        required_delay = self.get_auto_prompt_delay(agent_id)
        elapsed = (datetime.now() - last_time).total_seconds()
        return elapsed >= required_delay

    def get_time_until_auto_prompt(self, agent_id: str) -> float:
        """
        Get seconds remaining until next auto-prompt can be sent.

        Args:
            agent_id: Agent identifier

        Returns:
            Seconds until auto-prompt allowed, or 0 if allowed now
        """
        last_time = self._last_auto_prompt_times.get(agent_id)
        if last_time is None:
            return 0.0

        required_delay = self.get_auto_prompt_delay(agent_id)
        elapsed = (datetime.now() - last_time).total_seconds()
        return max(0.0, required_delay - elapsed)

    def get_agent_stats(self, agent_id: str) -> dict[str, Any]:
        """Get action stats for an agent including backoff and escalation info."""
        return {
            "auto_prompts": self._auto_prompt_counts.get(agent_id, 0),
            "escalations": self._escalation_counts.get(agent_id, 0),
            "can_auto_prompt": self.can_send_auto_prompt(agent_id),
            "backoff_remaining": self.get_time_until_auto_prompt(agent_id),
            "next_backoff_delay": self.get_auto_prompt_delay(agent_id),
            "escalation_state": self.get_escalation_state(agent_id),
        }


def create_reassign_action(
    agent_id: str,
    target_agent_id: str,
    task_id: str,
    reason: str,
) -> ControlAction:
    """
    Build a REASSIGN action that moves ``task_id`` to another agent.

    Args:
        agent_id: Agent currently holding the task
        target_agent_id: Agent that should take over the task
        task_id: Task being moved
        reason: Human-readable justification

    Returns:
        A ControlAction of type REASSIGN with ``target_agent_id`` set.
    """
    kind = ControlActionType.REASSIGN
    # Positional order matches the ControlAction fields:
    # action_type, agent_id, task_id, reason.
    return ControlAction(
        kind,
        agent_id,
        task_id,
        reason,
        target_agent_id=target_agent_id,
    )


def create_pause_action(agent_id: str, reason: str) -> ControlAction:
    """
    Build a PAUSE action for ``agent_id``.

    Args:
        agent_id: Agent to pause
        reason: Human-readable justification

    Returns:
        A ControlAction of type PAUSE with no task attached.
    """
    return ControlAction(ControlActionType.PAUSE, agent_id, reason=reason)


def create_resume_action(agent_id: str) -> ControlAction:
    """
    Build a RESUME action for a previously paused agent.

    Args:
        agent_id: Agent to resume

    Returns:
        A ControlAction of type RESUME with a fixed explanatory reason.
    """
    resume_reason = "Resuming paused agent"
    return ControlAction(ControlActionType.RESUME, agent_id, reason=resume_reason)
