"""
Swarm Coordinator - Orchestrate distributed agent collaboration.

Provides:
- Swarm lifecycle management
- Task distribution and tracking
- Result collection and aggregation
- Multi-strategy coordination
"""

import asyncio
import inspect
import logging
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Optional
from uuid import uuid4

from .decomposer import TaskDecomposer, SubTask, DecompositionStrategy
from .aggregator import ResultAggregator, AggregatedResult, AggregationStrategy

# Module-level logger: records agent (un)registration, swarm state
# transitions, callback errors and subtask failures.
logger = logging.getLogger(__name__)


class CoordinationStrategy(Enum):
    """How a swarm decides which agent receives which subtask."""

    # Distribute evenly, rotating through the agents.
    ROUND_ROBIN = "round_robin"
    # Prefer agents whose capabilities match the subtask.
    CAPABILITY_MATCH = "capability_match"
    # Prefer agents with the lowest current load.
    LOAD_BALANCED = "load_balanced"
    # Send the subtask to every agent.
    BROADCAST = "broadcast"
    # Agents bid for subtasks.
    AUCTION = "auction"
    # Tiered agent structure.
    HIERARCHICAL = "hierarchical"


class SwarmState(Enum):
    """Lifecycle phase of a swarm operation."""

    IDLE = "idle"  # no operation in progress
    DECOMPOSING = "decomposing"  # splitting the task into subtasks
    DISTRIBUTING = "distributing"  # handing subtasks to agents
    EXECUTING = "executing"  # subtasks running on agents
    AGGREGATING = "aggregating"  # combining subtask results
    COMPLETED = "completed"  # terminal: a result was produced
    FAILED = "failed"  # terminal: the operation raised


@dataclass
class SwarmConfig:
    """
    Configuration for swarm coordination.

    Groups the settings for one swarm run: task decomposition, subtask
    distribution, result aggregation and retry of failed subtasks.
    """

    # Decomposition settings
    decomposition_strategy: DecompositionStrategy = DecompositionStrategy.PARALLEL
    max_subtasks: int = 20
    min_subtask_complexity: float = 0.1

    # Distribution settings
    coordination_strategy: CoordinationStrategy = CoordinationStrategy.CAPABILITY_MATCH
    max_tasks_per_agent: int = 3
    task_timeout_seconds: int = 300  # per-subtask timeout

    # Aggregation settings
    aggregation_strategy: AggregationStrategy = AggregationStrategy.BEST
    consensus_threshold: float = 0.67
    require_min_results: int = 1

    # Retry settings
    max_retries: int = 2  # retries allowed after a subtask's first failure
    retry_delay_seconds: int = 5  # pause before a failed subtask is re-queued

    # Metadata
    swarm_name: str = ""

    def to_dict(self) -> dict[str, Any]:
        """
        Convert to dictionary.

        Every field is included. Enum-valued settings are emitted as their
        ``.value``.

        Returns:
            Mapping of field name to value.
        """
        return {
            "decomposition_strategy": self.decomposition_strategy.value,
            "max_subtasks": self.max_subtasks,
            # Previously omitted: included so the dict fully describes the config.
            "min_subtask_complexity": self.min_subtask_complexity,
            "coordination_strategy": self.coordination_strategy.value,
            "max_tasks_per_agent": self.max_tasks_per_agent,
            "task_timeout_seconds": self.task_timeout_seconds,
            "aggregation_strategy": self.aggregation_strategy.value,
            "consensus_threshold": self.consensus_threshold,
            "require_min_results": self.require_min_results,
            "max_retries": self.max_retries,
            "retry_delay_seconds": self.retry_delay_seconds,
            "swarm_name": self.swarm_name,
        }


@dataclass
class AgentInfo:
    """
    Book-keeping record for one agent registered in a swarm.

    Holds the agent's declared capabilities and how many subtasks it is
    running versus its ceiling. It also records whether the agent is taking
    work and a performance score that the coordinator uses when ranking
    agents.
    """

    agent_id: str
    capabilities: list[str] = field(default_factory=list)
    current_load: int = 0  # subtasks currently in flight on this agent
    max_load: int = 3  # ceiling on concurrent subtasks
    available: bool = True
    performance_score: float = 1.0  # 0.0 to 1.0
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def can_accept_task(self) -> bool:
        """Whether the agent is available and below its ``max_load``."""
        if not self.available:
            return False
        return self.current_load < self.max_load


class SwarmCoordinator:
    """
    Coordinates distributed agent execution.

    Provides:
    - Task decomposition and distribution
    - Agent management and load balancing
    - Result collection and aggregation
    - Progress monitoring

    Subtasks are dispatched in rounds. Each round assigns the currently
    ready subtasks to agents that still have spare capacity, and an agent
    is never handed more subtasks in a round than its remaining
    ``max_load``. The batch runs concurrently, and the next round starts
    only after the whole batch has finished.

    Example:
        coordinator = SwarmCoordinator()

        # Register agents
        coordinator.register_agent("agent-1", ["code_edit", "git"])
        coordinator.register_agent("agent-2", ["code_edit", "search"])

        # Execute a swarm task
        result = await coordinator.execute(
            name="Refactor codebase",
            description="Refactor the auth module",
            subtask_hints=[
                {"name": "Analyze current code", "capabilities": ["code_edit"]},
                {"name": "Search for patterns", "capabilities": ["search"]},
                {"name": "Apply refactoring", "capabilities": ["code_edit"]},
            ],
            task_executor=my_executor_fn,
        )
    """

    def __init__(self, config: Optional[SwarmConfig] = None) -> None:
        """
        Initialize the swarm coordinator.

        Args:
            config: Swarm configuration (defaults to ``SwarmConfig()``)
        """
        self.config = config or SwarmConfig()
        self.swarm_id = str(uuid4())[:8]

        self._agents: dict[str, AgentInfo] = {}
        self._decomposer = TaskDecomposer()
        self._aggregator = ResultAggregator()

        self._state = SwarmState.IDLE
        self._started_at: Optional[datetime] = None
        self._completed_at: Optional[datetime] = None
        self._current_task: Optional[str] = None

        # Callbacks
        self._on_task_assigned: list[Callable[[SubTask, str], None]] = []
        self._on_task_completed: list[Callable[[SubTask, Any], None]] = []
        self._on_state_change: list[Callable[[SwarmState], None]] = []

    def register_agent(
        self,
        agent_id: str,
        capabilities: Optional[list[str]] = None,
        max_load: int = 3,
        metadata: Optional[dict[str, Any]] = None,
    ) -> AgentInfo:
        """
        Register an agent with the swarm.

        Re-registering an existing ``agent_id`` replaces its record.

        Args:
            agent_id: Unique agent identifier
            capabilities: Agent capabilities
            max_load: Maximum concurrent tasks
            metadata: Additional agent metadata

        Returns:
            AgentInfo for the registered agent
        """
        agent = AgentInfo(
            agent_id=agent_id,
            capabilities=capabilities or [],
            max_load=max_load,
            metadata=metadata or {},
        )
        self._agents[agent_id] = agent
        logger.info(f"Registered agent {agent_id} with capabilities: {capabilities}")
        return agent

    def unregister_agent(self, agent_id: str) -> bool:
        """
        Unregister an agent from the swarm.

        Returns True if agent was registered.
        """
        if agent_id in self._agents:
            del self._agents[agent_id]
            logger.info(f"Unregistered agent {agent_id}")
            return True
        return False

    def get_agent(self, agent_id: str) -> Optional[AgentInfo]:
        """Get agent info by ID."""
        return self._agents.get(agent_id)

    def list_agents(self) -> list[str]:
        """List all registered agent IDs."""
        return list(self._agents.keys())

    def get_available_agents(self) -> list[AgentInfo]:
        """Get agents that can accept tasks."""
        return [a for a in self._agents.values() if a.can_accept_task]

    def _set_state(self, state: SwarmState) -> None:
        """Update swarm state and notify listeners (callback errors are logged, not raised)."""
        old_state = self._state
        self._state = state
        logger.info(f"Swarm state: {old_state.value} -> {state.value}")

        for callback in self._on_state_change:
            try:
                callback(state)
            except Exception as e:
                logger.warning(f"State change callback error: {e}")

    async def execute(
        self,
        name: str,
        description: str,
        subtask_hints: Optional[list[dict[str, Any]]] = None,
        task_executor: Optional[Callable[[SubTask, str], Any]] = None,
        quality_fn: Optional[Callable[[Any], float]] = None,
    ) -> AggregatedResult:
        """
        Execute a task using the swarm.

        This is the main entry point for swarm operations. It:
        1. Decomposes the task into subtasks
        2. Distributes subtasks to agents
        3. Executes subtasks and collects results
        4. Aggregates results into final output

        Args:
            name: Task name
            description: Task description
            subtask_hints: Hints for task decomposition
            task_executor: Called as ``task_executor(subtask, agent_id)``.
                It may be an async function or return an awaitable. It is
                required: the default of None is kept only for signature
                compatibility.
            quality_fn: Function to score result quality

        Returns:
            Aggregated result from all subtasks

        Raises:
            ValueError: If ``task_executor`` is None, if decomposition
                produces no subtasks, or if aggregation produces no result.
            Exception: Any other error is re-raised after the state is set
                to FAILED.
        """
        self._started_at = datetime.now()
        self._current_task = name
        self._aggregator.clear()

        try:
            # Step 1: Decompose
            self._set_state(SwarmState.DECOMPOSING)
            subtasks = self._decomposer.decompose(
                name=name,
                description=description,
                strategy=self.config.decomposition_strategy,
                hints={"subtasks": subtask_hints} if subtask_hints else None,
            )

            if not subtasks:
                raise ValueError("Task decomposition produced no subtasks")

            # Step 2: Distribute and Execute
            self._set_state(SwarmState.DISTRIBUTING)
            await self._distribute_and_execute(subtasks, task_executor)

            # Step 3: Aggregate
            self._set_state(SwarmState.AGGREGATING)
            result = self._aggregator.aggregate(
                strategy=self.config.aggregation_strategy,
                quality_fn=quality_fn,
                consensus_threshold=self.config.consensus_threshold,
            )

            if result is None:
                raise ValueError("Aggregation produced no result")

            self._set_state(SwarmState.COMPLETED)
            self._completed_at = datetime.now()

            return result

        except Exception as e:
            self._set_state(SwarmState.FAILED)
            self._completed_at = datetime.now()
            logger.error(f"Swarm execution failed: {e}")
            raise

    async def _distribute_and_execute(
        self,
        subtasks: list[SubTask],
        task_executor: Optional[Callable[[SubTask, str], Any]],
    ) -> None:
        """
        Distribute subtasks to agents and execute them.

        Loops in rounds until the decomposer reports no subtask that is
        ready, running or assigned. While waiting for agents or for running
        work, the loop always yields to the event loop and never busy-spins.

        Raises:
            ValueError: If ``task_executor`` is None. Assigned subtasks
                would otherwise never run and the loop would wait forever.
        """
        if task_executor is None:
            raise ValueError("task_executor is required to execute subtasks")

        self._set_state(SwarmState.EXECUTING)

        # task_id -> number of retries already scheduled for that subtask
        pending_retries: dict[str, int] = {}

        while True:
            ready_tasks = self._decomposer.get_ready_tasks()
            if not ready_tasks:
                # Check if any tasks are still running
                running = self._decomposer.get_tasks_by_status("running")
                assigned = self._decomposer.get_tasks_by_status("assigned")
                if not running and not assigned:
                    break
                await asyncio.sleep(0.1)
                continue

            # Distribute ready tasks to available agents
            available_agents = self.get_available_agents()
            if not available_agents:
                await asyncio.sleep(0.1)
                continue

            # Assign tasks based on coordination strategy
            assignments = self._compute_assignments(ready_tasks, available_agents)
            if not assignments:
                # Nothing could be placed this round. Yield rather than spin,
                # so the loop never monopolizes the event loop.
                await asyncio.sleep(0.1)
                continue

            # Execute assignments in parallel
            async_tasks = []
            for task, agent in assignments:
                self._decomposer.assign_task(task.id, agent.agent_id)
                agent.current_load += 1

                # Notify callbacks
                for callback in self._on_task_assigned:
                    try:
                        callback(task, agent.agent_id)
                    except Exception as e:
                        logger.warning(f"Task assigned callback error: {e}")

                async_tasks.append(
                    self._execute_subtask(task, agent, task_executor, pending_retries)
                )

            # _execute_subtask handles its own errors. return_exceptions keeps
            # one stray failure from cancelling the rest of the batch.
            await asyncio.gather(*async_tasks, return_exceptions=True)

    async def _invoke_executor(
        self,
        executor: Callable[[SubTask, str], Any],
        task: SubTask,
        agent_id: str,
    ) -> Any:
        """
        Call ``executor`` and return its result.

        If the call returns an awaitable (async executor, or a sync wrapper
        returning a coroutine/future), the result is awaited under
        ``config.task_timeout_seconds``. A plain sync executor runs inline on
        the event loop, so the timeout cannot interrupt it. This replaces
        ``asyncio.coroutine``, which was removed in Python 3.11.

        Raises:
            asyncio.TimeoutError: If an awaitable result exceeds the timeout.
        """
        outcome = executor(task, agent_id)
        if inspect.isawaitable(outcome):
            outcome = await asyncio.wait_for(
                outcome,
                timeout=self.config.task_timeout_seconds,
            )
        return outcome

    async def _execute_subtask(
        self,
        task: SubTask,
        agent: AgentInfo,
        executor: Callable[[SubTask, str], Any],
        pending_retries: dict[str, int],
    ) -> None:
        """
        Execute a single subtask on ``agent`` and record the outcome.

        On success, the result is reported to the decomposer and the
        aggregator, and completion callbacks are fired. On timeout or error,
        the subtask goes to ``_handle_task_failure``. The agent's load is
        always released afterwards.
        """
        self._decomposer.start_task(task.id)

        try:
            result = await self._invoke_executor(executor, task, agent.agent_id)

            # Task completed successfully
            self._decomposer.complete_task(task.id, result)

            # Add to aggregator
            self._aggregator.add_result(
                agent_id=agent.agent_id,
                task_id=task.id,
                result=result,
                confidence=agent.performance_score,
                duration_seconds=task.duration,
            )

            # Notify callbacks
            for callback in self._on_task_completed:
                try:
                    callback(task, result)
                except Exception as e:
                    logger.warning(f"Task completed callback error: {e}")

        except asyncio.TimeoutError:
            logger.warning(f"Task {task.id} timed out on agent {agent.agent_id}")
            await self._handle_task_failure(task, agent, pending_retries, "Timeout")

        except Exception as e:
            logger.error(f"Task {task.id} failed on agent {agent.agent_id}: {e}")
            await self._handle_task_failure(task, agent, pending_retries, str(e))

        finally:
            agent.current_load = max(0, agent.current_load - 1)

    async def _handle_task_failure(
        self,
        task: SubTask,
        agent: AgentInfo,
        pending_retries: dict[str, int],
        error: str,
    ) -> None:
        """
        Handle a failed task with optional retry.

        While fewer than ``config.max_retries`` retries have been scheduled,
        the subtask is reset to "pending" and unassigned so a later round
        picks it up. That happens after ``config.retry_delay_seconds``.
        Otherwise the subtask is marked failed and the agent's performance
        score drops by 0.1, with a floor of 0.1.
        """
        retry_count = pending_retries.get(task.id, 0)

        if retry_count < self.config.max_retries:
            # Retry the task
            pending_retries[task.id] = retry_count + 1
            task.status = "pending"
            task.assigned_to = None
            logger.info(f"Scheduling retry {retry_count + 1} for task {task.id}")
            await asyncio.sleep(self.config.retry_delay_seconds)
        else:
            # Mark as failed
            self._decomposer.fail_task(task.id, error)
            agent.performance_score = max(0.1, agent.performance_score - 0.1)

    def _compute_assignments(
        self,
        tasks: list[SubTask],
        agents: list[AgentInfo],
    ) -> list[tuple[SubTask, AgentInfo]]:
        """
        Compute task-to-agent assignments based on coordination strategy.

        Capacity claimed by earlier assignments in the same round is counted
        against each agent, so no agent receives more subtasks than its
        remaining ``max_load`` allows. Subtasks that do not fit stay
        unassigned and are picked up in a later round.

        Args:
            tasks: Ready subtasks
            agents: Candidate agents

        Returns:
            (subtask, agent) pairs to dispatch this round. With BROADCAST the
            same subtask appears once per receiving agent.
        """
        strategy = self.config.coordination_strategy
        assignments: list[tuple[SubTask, AgentInfo]] = []
        # agent_id -> subtasks already handed to that agent in this round
        reserved: dict[str, int] = {}

        def effective_load(agent: AgentInfo) -> int:
            return agent.current_load + reserved.get(agent.agent_id, 0)

        def has_room(agent: AgentInfo) -> bool:
            return agent.available and effective_load(agent) < agent.max_load

        def assign(task: SubTask, agent: AgentInfo) -> None:
            assignments.append((task, agent))
            reserved[agent.agent_id] = reserved.get(agent.agent_id, 0) + 1

        if strategy == CoordinationStrategy.ROUND_ROBIN:
            # Task i goes to agent i mod N. The task is skipped this round if
            # that agent is full.
            for i, task in enumerate(tasks):
                agent = agents[i % len(agents)]
                if has_room(agent):
                    assign(task, agent)

        elif strategy == CoordinationStrategy.LOAD_BALANCED:
            # Each task goes to the currently least-loaded agent. Ties go to
            # the earliest agent in ``agents``.
            for task in tasks:
                candidates = [a for a in agents if has_room(a)]
                if not candidates:
                    break
                assign(task, min(candidates, key=effective_load))

        elif strategy == CoordinationStrategy.BROADCAST:
            # Only the first ready task is broadcast, to every agent with room.
            for task in tasks[:1]:
                for agent in agents:
                    if has_room(agent):
                        assign(task, agent)

        elif strategy == CoordinationStrategy.HIERARCHICAL:
            # Fill the highest-performing agents first.
            ranked = sorted(agents, key=lambda a: -a.performance_score)
            for task in tasks:
                chosen = next((a for a in ranked if has_room(a)), None)
                if chosen is None:
                    break
                assign(task, chosen)

        else:
            # CAPABILITY_MATCH. Strategies without a dedicated implementation
            # (e.g. AUCTION) fall back to it, because assigning nothing would
            # stall the dispatch loop forever.
            if strategy != CoordinationStrategy.CAPABILITY_MATCH:
                logger.warning(
                    f"Coordination strategy {strategy} not implemented; "
                    "falling back to capability matching"
                )
            for task in tasks:
                best_agent = self._find_best_agent_for_task(task, agents, reserved)
                if best_agent:
                    assign(task, best_agent)

        return assignments

    def _find_best_agent_for_task(
        self,
        task: SubTask,
        agents: list[AgentInfo],
        reserved: Optional[dict[str, int]] = None,
    ) -> Optional[AgentInfo]:
        """
        Find the best agent for a task based on capabilities.

        Args:
            task: Subtask to place
            agents: Candidate agents
            reserved: Optional mapping of agent_id to load already committed
                in the current round. It is counted on top of
                ``current_load``.

        Returns:
            If the task has no required capabilities, the least-loaded agent
            with room. Otherwise, the agent with room that maximizes
            ``10 * matched_capabilities + 5 * performance_score - load``.
            Agents matching none of the capabilities stay eligible as a
            fallback. Returns None only when no agent has room.
        """
        extra = reserved if reserved is not None else {}

        def load(agent: AgentInfo) -> int:
            return agent.current_load + extra.get(agent.agent_id, 0)

        candidates = [a for a in agents if a.available and load(a) < a.max_load]
        if not candidates:
            return None

        if not task.required_capabilities:
            # No requirements: least loaded available agent
            return min(candidates, key=load)

        best_agent: Optional[AgentInfo] = None
        # Start at -inf rather than -1 so some candidate is always chosen,
        # even when heavy load drives every score negative.
        best_score = float("-inf")

        for agent in candidates:
            # Count matching capabilities
            matches = sum(
                1 for cap in task.required_capabilities
                if cap in agent.capabilities
            )

            # Factor in performance and load
            score = matches * 10 + agent.performance_score * 5 - load(agent)

            if score > best_score:
                best_score = score
                best_agent = agent

        return best_agent

    def get_progress(self) -> dict[str, Any]:
        """Get current swarm progress."""
        decomposer_progress = self._decomposer.get_progress()
        aggregator_summary = self._aggregator.get_summary()

        return {
            "swarm_id": self.swarm_id,
            "state": self._state.value,
            "current_task": self._current_task,
            "started_at": self._started_at.isoformat() if self._started_at else None,
            "duration_seconds": (
                (self._completed_at or datetime.now()) - self._started_at
            ).total_seconds() if self._started_at else 0,
            "agents": {
                "total": len(self._agents),
                "available": len(self.get_available_agents()),
            },
            "tasks": decomposer_progress,
            "results": aggregator_summary,
        }

    def get_status_summary(self) -> dict[str, Any]:
        """Get detailed status summary (swarm identity, config, progress, per-agent load)."""
        return {
            "swarm": {
                "id": self.swarm_id,
                "name": self.config.swarm_name,
                "state": self._state.value,
            },
            "config": self.config.to_dict(),
            "progress": self.get_progress(),
            "agents": {
                agent_id: {
                    "capabilities": agent.capabilities,
                    "load": f"{agent.current_load}/{agent.max_load}",
                    "available": agent.can_accept_task,
                    "performance": agent.performance_score,
                }
                for agent_id, agent in self._agents.items()
            },
        }

    # Event hooks
    def on_task_assigned(
        self,
        callback: Callable[[SubTask, str], None],
    ) -> None:
        """Register callback for task assignment events."""
        self._on_task_assigned.append(callback)

    def on_task_completed(
        self,
        callback: Callable[[SubTask, Any], None],
    ) -> None:
        """Register callback for task completion events."""
        self._on_task_completed.append(callback)

    def on_state_change(
        self,
        callback: Callable[[SwarmState], None],
    ) -> None:
        """Register callback for state change events."""
        self._on_state_change.append(callback)

    def reset(self) -> None:
        """Reset the coordinator for a new operation (state set directly; listeners are not notified)."""
        self._decomposer.reset()
        self._aggregator.clear()
        self._state = SwarmState.IDLE
        self._started_at = None
        self._completed_at = None
        self._current_task = None

        # Reset agent loads
        for agent in self._agents.values():
            agent.current_load = 0