"""
Task Decomposer - Break complex tasks into distributable subtasks.

Provides:
- Multiple decomposition strategies
- Dependency tracking between subtasks
- Priority-based task ordering
- Capability-based task assignment hints
"""

import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Optional
from uuid import uuid4

logger = logging.getLogger(__name__)


class DecompositionStrategy(Enum):
    """
    How a task is split into subtasks.

    Members (values are the lowercase strings shown):
        SEQUENTIAL ("sequential"): each subtask depends on the one before it.
        PARALLEL ("parallel"): subtasks are independent and may run together.
        HIERARCHICAL ("hierarchical"): subtasks form a parent/child tree.
        MAP_REDUCE ("map_reduce"): independent map subtasks feed one reduce subtask.
        PIPELINE ("pipeline"): each stage consumes the previous stage's output.
    """

    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"
    HIERARCHICAL = "hierarchical"
    MAP_REDUCE = "map_reduce"
    PIPELINE = "pipeline"


@dataclass
class SubTask:
    """
    A single unit of work produced by decomposing a larger task.

    Lifecycle fields (``status``, ``assigned_to``, ``result`` and the
    timestamps) are updated in place as the task moves through
    pending -> assigned -> running -> completed / failed.
    """

    id: str
    name: str
    description: str
    parent_id: Optional[str] = None  # Parent task ID (set by hierarchical decomposition)
    dependencies: list[str] = field(default_factory=list)  # IDs of tasks this depends on
    priority: int = 0  # Higher = more important
    estimated_complexity: float = 1.0  # Relative complexity, nominally 0.0-10.0 (not enforced)
    required_capabilities: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
    status: str = "pending"  # pending, assigned, running, completed, failed
    assigned_to: Optional[str] = None  # Agent ID
    result: Optional[Any] = None
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict[str, Any]:
        """
        Return a snapshot of this task as a plain dictionary.

        Timestamps are rendered with ``isoformat()`` (``None`` when unset);
        ``result`` is not included. List and dict fields are shallow copies,
        so the snapshot does not change when the task's dependency list is
        later pruned in place, and callers cannot mutate the task through it.
        """
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "parent_id": self.parent_id,
            "dependencies": list(self.dependencies),
            "priority": self.priority,
            "estimated_complexity": self.estimated_complexity,
            "required_capabilities": list(self.required_capabilities),
            "metadata": dict(self.metadata),
            "status": self.status,
            "assigned_to": self.assigned_to,
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @property
    def is_ready(self) -> bool:
        """
        True when the task is pending and its dependency list is empty.

        This does not look up dependency status; it relies on completed
        dependencies having been removed from ``dependencies``.
        """
        return self.status == "pending" and len(self.dependencies) == 0

    @property
    def duration(self) -> Optional[float]:
        """Seconds between ``started_at`` and ``completed_at``, or None if either is unset."""
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return None


class TaskDecomposer:
    """
    Decomposes complex tasks into subtasks for swarm processing.

    Supports multiple decomposition strategies:
    - SEQUENTIAL: Tasks run one after another
    - PARALLEL: Tasks run simultaneously
    - HIERARCHICAL: Tree structure with recursive subtasks
    - MAP_REDUCE: Split data, process chunks, combine results
    - PIPELINE: Chain of transformations

    ``decompose`` may be called repeatedly on one instance (e.g. to break a
    subtask down further via ``parent_id``). Each call gets its own task-ID
    namespace, so tasks registered by earlier calls are never overwritten.

    Example:
        decomposer = TaskDecomposer()

        # Decompose a task into parallel subtasks
        subtasks = decomposer.decompose(
            name="Build and test project",
            description="Build the project and run all tests",
            strategy=DecompositionStrategy.PARALLEL,
            hints={
                "subtasks": [
                    {"name": "Build", "capabilities": ["terminal"]},
                    {"name": "Run unit tests", "capabilities": ["run_tests"]},
                    {"name": "Run integration tests", "capabilities": ["run_tests"]},
                ]
            }
        )

        # Get ready tasks
        ready = decomposer.get_ready_tasks()
    """

    def __init__(self) -> None:
        """Initialize the task decomposer."""
        self._tasks: dict[str, SubTask] = {}
        self._decomposition_id = str(uuid4())[:8]
        # Number of decompose() calls since creation/reset; gives each call a
        # distinct ID prefix so repeated calls cannot collide.
        self._batch_count = 0

    def _next_id_prefix(self) -> str:
        """
        Reserve and return the task-ID prefix for one decompose() call.

        The first call uses the bare decomposition ID (IDs look like
        ``<id>-par-0``); later calls insert ``-b<N>`` (``<id>-b1-par-0``).
        """
        batch = self._batch_count
        self._batch_count = batch + 1
        if batch == 0:
            return self._decomposition_id
        return f"{self._decomposition_id}-b{batch}"

    def decompose(
        self,
        name: str,
        description: str,
        strategy: DecompositionStrategy = DecompositionStrategy.PARALLEL,
        hints: Optional[dict[str, Any]] = None,
        parent_id: Optional[str] = None,
    ) -> list[SubTask]:
        """
        Decompose a task into subtasks and register them with this decomposer.

        Args:
            name: Task name
            description: Task description
            strategy: Decomposition strategy
            hints: Hints for decomposition. Keys read by the strategies:
                "subtasks" (all), "map_tasks" / "reduce_task" (MAP_REDUCE),
                "stages" (PIPELINE). Each subtask hint may carry "name",
                "description", "priority", "capabilities" and "metadata".
            parent_id: Parent task ID for hierarchical decomposition

        Returns:
            List of created subtasks
        """
        hints = hints or {}

        builder = {
            DecompositionStrategy.SEQUENTIAL: self._decompose_sequential,
            DecompositionStrategy.PARALLEL: self._decompose_parallel,
            DecompositionStrategy.HIERARCHICAL: self._decompose_hierarchical,
            DecompositionStrategy.MAP_REDUCE: self._decompose_map_reduce,
            DecompositionStrategy.PIPELINE: self._decompose_pipeline,
        }.get(strategy)

        subtasks: list[SubTask] = []
        if builder is not None:
            subtasks = builder(name, description, hints, parent_id, prefix=self._next_id_prefix())

        # Register all subtasks
        for task in subtasks:
            self._tasks[task.id] = task

        logger.info("Decomposed '%s' into %d subtasks using %s", name, len(subtasks), strategy.value)
        return subtasks

    def _decompose_sequential(
        self,
        name: str,
        description: str,
        hints: dict[str, Any],
        parent_id: Optional[str],
        *,
        prefix: Optional[str] = None,
    ) -> list[SubTask]:
        """Decompose into sequential subtasks, each depending on the previous one."""
        prefix = prefix or self._decomposition_id
        subtask_hints = hints.get("subtasks", [])
        subtasks = []
        prev_id = None

        for i, hint in enumerate(subtask_hints):
            task_id = f"{prefix}-seq-{i}"
            task = SubTask(
                id=task_id,
                name=hint.get("name", f"Step {i+1}"),
                description=hint.get("description", f"Sequential step {i+1} of {name}"),
                parent_id=parent_id,
                dependencies=[prev_id] if prev_id else [],
                # Earlier steps get a higher default priority.
                priority=hint.get("priority", len(subtask_hints) - i),
                # Copy so tasks never share mutable state with the caller's hints.
                required_capabilities=list(hint.get("capabilities") or []),
                metadata=dict(hint.get("metadata") or {}),
            )
            subtasks.append(task)
            prev_id = task_id

        return subtasks

    def _decompose_parallel(
        self,
        name: str,
        description: str,
        hints: dict[str, Any],
        parent_id: Optional[str],
        *,
        prefix: Optional[str] = None,
    ) -> list[SubTask]:
        """Decompose into parallel subtasks with no dependencies."""
        prefix = prefix or self._decomposition_id
        subtask_hints = hints.get("subtasks", [])
        subtasks = []

        for i, hint in enumerate(subtask_hints):
            task = SubTask(
                id=f"{prefix}-par-{i}",
                name=hint.get("name", f"Task {i+1}"),
                description=hint.get("description", f"Parallel task {i+1} of {name}"),
                parent_id=parent_id,
                dependencies=[],  # No dependencies for parallel
                priority=hint.get("priority", 0),
                required_capabilities=list(hint.get("capabilities") or []),
                metadata=dict(hint.get("metadata") or {}),
            )
            subtasks.append(task)

        return subtasks

    def _decompose_hierarchical(
        self,
        name: str,
        description: str,
        hints: dict[str, Any],
        parent_id: Optional[str],
        *,
        prefix: Optional[str] = None,
    ) -> list[SubTask]:
        """
        Decompose into a tree: each hint may nest further hints under "children".

        Nodes are emitted depth-first (parent before its children); each child's
        ``parent_id`` is its parent node's ID and ``metadata["depth"]`` records
        its depth (roots are depth 0).
        """
        prefix = prefix or self._decomposition_id
        subtasks: list[SubTask] = []

        def process_node(node_hints: dict[str, Any], node_parent: Optional[str], depth: int) -> None:
            task_id = f"{prefix}-hier-{len(subtasks)}"
            task = SubTask(
                id=task_id,
                name=node_hints.get("name", f"Node {len(subtasks)+1}"),
                description=node_hints.get("description", ""),
                parent_id=node_parent,
                dependencies=[],
                priority=node_hints.get("priority", 0),
                required_capabilities=list(node_hints.get("capabilities") or []),
                metadata={"depth": depth, **(node_hints.get("metadata") or {})},
            )
            subtasks.append(task)

            for child in node_hints.get("children", []):
                process_node(child, task_id, depth + 1)

        for node in hints.get("subtasks", []):
            process_node(node, parent_id, 0)

        return subtasks

    def _decompose_map_reduce(
        self,
        name: str,
        description: str,
        hints: dict[str, Any],
        parent_id: Optional[str],
        *,
        prefix: Optional[str] = None,
    ) -> list[SubTask]:
        """Decompose into map-reduce: independent map tasks plus one reduce task depending on all of them."""
        prefix = prefix or self._decomposition_id
        subtasks = []

        # Map phase
        map_hints = hints.get("map_tasks", hints.get("subtasks", []))
        map_ids = []

        for i, hint in enumerate(map_hints):
            task_id = f"{prefix}-map-{i}"
            task = SubTask(
                id=task_id,
                name=hint.get("name", f"Map {i+1}"),
                description=hint.get("description", f"Process chunk {i+1}"),
                parent_id=parent_id,
                dependencies=[],
                priority=hint.get("priority", 1),
                required_capabilities=list(hint.get("capabilities") or []),
                metadata={"phase": "map", "index": i, **(hint.get("metadata") or {})},
            )
            subtasks.append(task)
            map_ids.append(task_id)

        # Reduce phase (created even when there are no map tasks)
        reduce_hint = hints.get("reduce_task", {})
        reduce_task = SubTask(
            id=f"{prefix}-reduce",
            name=reduce_hint.get("name", "Reduce"),
            description=reduce_hint.get("description", "Combine results from map phase"),
            parent_id=parent_id,
            dependencies=map_ids,  # Depends on all map tasks
            priority=reduce_hint.get("priority", 0),
            required_capabilities=list(reduce_hint.get("capabilities") or []),
            metadata={"phase": "reduce", **(reduce_hint.get("metadata") or {})},
        )
        subtasks.append(reduce_task)

        return subtasks

    def _decompose_pipeline(
        self,
        name: str,
        description: str,
        hints: dict[str, Any],
        parent_id: Optional[str],
        *,
        prefix: Optional[str] = None,
    ) -> list[SubTask]:
        """Decompose into pipeline stages chained by dependencies."""
        prefix = prefix or self._decomposition_id
        stage_hints = hints.get("stages", hints.get("subtasks", []))
        subtasks = []
        prev_id = None

        for i, hint in enumerate(stage_hints):
            task_id = f"{prefix}-pipe-{i}"
            task = SubTask(
                id=task_id,
                name=hint.get("name", f"Stage {i+1}"),
                description=hint.get("description", f"Pipeline stage {i+1}"),
                parent_id=parent_id,
                dependencies=[prev_id] if prev_id else [],
                priority=hint.get("priority", len(stage_hints) - i),
                required_capabilities=list(hint.get("capabilities") or []),
                metadata={
                    "stage": i,
                    "input_transform": hint.get("input_transform"),
                    **(hint.get("metadata") or {}),
                },
            )
            subtasks.append(task)
            prev_id = task_id

        return subtasks

    def get_task(self, task_id: str) -> Optional[SubTask]:
        """Get a task by ID."""
        return self._tasks.get(task_id)

    def get_all_tasks(self) -> list[SubTask]:
        """Get all tasks."""
        return list(self._tasks.values())

    def _dependency_completed(self, dep_id: str) -> bool:
        """True if *dep_id* is a registered task whose status is "completed"."""
        dep = self._tasks.get(dep_id)
        return dep is not None and dep.status == "completed"

    def get_ready_tasks(self) -> list[SubTask]:
        """
        Return pending tasks whose dependencies are all completed, highest priority first.

        A dependency ID not registered with this decomposer counts as unsatisfied.
        """
        ready = [
            task
            for task in self._tasks.values()
            if task.status == "pending" and all(self._dependency_completed(d) for d in task.dependencies)
        ]
        return sorted(ready, key=lambda t: -t.priority)

    def get_tasks_by_status(self, status: str) -> list[SubTask]:
        """Get tasks by status."""
        return [t for t in self._tasks.values() if t.status == status]

    def assign_task(self, task_id: str, agent_id: str) -> bool:
        """
        Assign a pending task to an agent.

        Dependencies are not checked here. Returns True if assignment successful.
        """
        task = self._tasks.get(task_id)
        if not task or task.status != "pending":
            return False

        task.status = "assigned"
        task.assigned_to = agent_id
        logger.info("Assigned task %s to agent %s", task_id, agent_id)
        return True

    def start_task(self, task_id: str) -> bool:
        """
        Mark a pending or assigned task as running.

        Returns True if status updated.
        """
        task = self._tasks.get(task_id)
        if not task or task.status not in ("pending", "assigned"):
            return False

        task.status = "running"
        task.started_at = datetime.now()
        logger.info("Started task %s", task_id)
        return True

    def complete_task(self, task_id: str, result: Any = None) -> bool:
        """
        Mark a running task as completed and store its result.

        The task's ID is removed from every other task's ``dependencies`` so
        that ``SubTask.is_ready`` reflects the new state.

        Returns True if status updated.
        """
        task = self._tasks.get(task_id)
        if not task or task.status != "running":
            return False

        task.status = "completed"
        task.completed_at = datetime.now()
        task.result = result

        for other_task in self._tasks.values():
            if task_id in other_task.dependencies:
                other_task.dependencies.remove(task_id)

        logger.info("Completed task %s", task_id)
        return True

    def fail_task(self, task_id: str, error: str = "") -> bool:
        """
        Mark a task as failed and record *error* in its metadata.

        Completed tasks cannot be failed: their dependents have already been
        unblocked, so flipping them to failed would leave the graph
        inconsistent. Tasks depending on a failed task stay pending and are
        never returned by ``get_ready_tasks``.

        Returns True if status updated.
        """
        task = self._tasks.get(task_id)
        if not task or task.status == "completed":
            return False

        task.status = "failed"
        task.completed_at = datetime.now()
        task.metadata["error"] = error
        logger.warning("Failed task %s: %s", task_id, error)
        return True

    def get_progress(self) -> dict[str, Any]:
        """
        Get a decomposition progress summary.

        ``progress_percent`` is weighted by ``estimated_complexity`` of
        completed tasks (0 when the total complexity is 0).
        """
        status_counts = {"pending": 0, "assigned": 0, "running": 0, "completed": 0, "failed": 0}
        total_complexity = 0.0
        completed_complexity = 0.0

        for task in self._tasks.values():
            status_counts[task.status] = status_counts.get(task.status, 0) + 1
            total_complexity += task.estimated_complexity
            if task.status == "completed":
                completed_complexity += task.estimated_complexity

        return {
            "total_tasks": len(self._tasks),
            "status_counts": status_counts,
            "progress_percent": (completed_complexity / total_complexity * 100) if total_complexity > 0 else 0,
            "ready_tasks": len(self.get_ready_tasks()),
        }

    def reset(self) -> None:
        """Reset the decomposer, clearing all tasks and starting a new ID namespace."""
        self._tasks.clear()
        self._decomposition_id = str(uuid4())[:8]
        self._batch_count = 0


# Patterns used by auto_decompose. List items must start a line; a number
# must not be followed by another digit (so "3.11" is not item 3) and a bullet
# must be followed by whitespace (so "front-end" or "**bold**" are not items).
_NUMBERED_ITEM_RE = re.compile(r"^[ \t]*\d+[.)](?!\d)[ \t]*(\S[^\n]*)", re.MULTILINE)
_BULLET_ITEM_RE = re.compile(r"^[ \t]*[-*\u2022][ \t]+(\S[^\n]*)", re.MULTILINE)
_AND_SPLIT_RE = re.compile(r"\s+and\s+", re.IGNORECASE)
_MIN_CLAUSE_LEN = 4  # "and" clauses of 3 characters or fewer are dropped
_MAX_CLAUSE_NAME_LEN = 50  # subtask names derived from clauses are truncated to this


def _extract_subtask_hints(description: str) -> list[dict[str, Any]]:
    """
    Derive subtask hints from the structure of *description*.

    Tried in order; the first non-empty result wins: line-leading numbered
    items ("1. ..." / "1) ..."), line-leading bullet items ("- ...", "* ...",
    "\u2022 ..."), then clauses separated by the word "and". Returns an empty
    list when none of these yield a subtask.
    """
    for pattern in (_NUMBERED_ITEM_RE, _BULLET_ITEM_RE):
        items = [text.strip() for text in pattern.findall(description)]
        if items:
            return [{"name": text, "description": text} for text in items]

    parts = [part.strip() for part in _AND_SPLIT_RE.split(description)]
    if len(parts) > 1:
        return [
            {"name": part[:_MAX_CLAUSE_NAME_LEN], "description": part}
            for part in parts
            if len(part) >= _MIN_CLAUSE_LEN
        ]
    return []


def auto_decompose(
    description: str,
    strategy: DecompositionStrategy = DecompositionStrategy.PARALLEL,
) -> list[SubTask]:
    """
    Attempt to auto-decompose a task description into subtasks.

    This is a simple heuristic-based decomposition (numbered lists, bullet
    lists, then "and" clauses). For complex tasks, use explicit hints with
    the TaskDecomposer. A fresh TaskDecomposer is used for every call.

    Args:
        description: Natural language task description
        strategy: Decomposition strategy

    Returns:
        List of subtasks. When no structure is found, a single
        "Execute task" subtask carrying the full description.
    """
    subtask_hints = _extract_subtask_hints(description) or [
        {"name": "Execute task", "description": description}
    ]

    return TaskDecomposer().decompose(
        name="Auto-decomposed task",
        description=description,
        strategy=strategy,
        hints={"subtasks": subtask_hints},
    )
