"""
Unit tests for the Swarm Intelligence module.

Tests:
- Task decomposition strategies
- Result aggregation strategies
- Swarm coordination
"""

import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock

from agent_orchestrator.swarm import (
    SwarmCoordinator,
    SwarmConfig,
    SwarmState,
    CoordinationStrategy,
    TaskDecomposer,
    SubTask,
    DecompositionStrategy,
    ResultAggregator,
    AggregatedResult,
    AggregationStrategy,
)
from agent_orchestrator.swarm.decomposer import auto_decompose


class TestSubTask:
    """Exercise the SubTask dataclass: construction, readiness, serialization."""

    def test_subtask_creation(self):
        """Constructor values are stored and the initial status is pending."""
        fields = {
            "id": "task-1",
            "name": "Test Task",
            "description": "A test task",
            "priority": 5,
        }
        subtask = SubTask(**fields)

        assert subtask.id == "task-1"
        assert subtask.name == "Test Task"
        assert subtask.priority == 5
        assert subtask.status == "pending"

    def test_subtask_is_ready(self):
        """is_ready reflects both outstanding dependencies and current status."""
        subtask = SubTask(id="task-1", name="Test", description="Test", dependencies=[])

        # Pending with nothing to wait on.
        assert subtask.is_ready is True

        # An outstanding dependency blocks readiness.
        subtask.dependencies = ["dep-1"]
        assert subtask.is_ready is False

        # A running task is not ready, even once its dependencies are cleared.
        subtask.status = "running"
        subtask.dependencies = []
        assert subtask.is_ready is False

    def test_subtask_to_dict(self):
        """to_dict exposes the id and the required capabilities."""
        serialized = SubTask(
            id="task-1",
            name="Test",
            description="Test description",
            required_capabilities=["code_edit"],
        ).to_dict()

        assert serialized["id"] == "task-1"
        assert serialized["required_capabilities"] == ["code_edit"]


class TestTaskDecomposer:
    """Tests for TaskDecomposer: decomposition strategies and task bookkeeping."""

    def test_decomposer_init(self):
        """A freshly constructed decomposer holds no tasks."""
        decomposer = TaskDecomposer()
        assert len(decomposer.get_all_tasks()) == 0

    def test_decompose_parallel(self):
        """PARALLEL decomposition yields one dependency-free subtask per hint."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Build project",
            description="Build and test",
            strategy=DecompositionStrategy.PARALLEL,
            hints={
                "subtasks": [
                    {"name": "Build", "capabilities": ["terminal"]},
                    {"name": "Lint", "capabilities": ["terminal"]},
                    {"name": "Test", "capabilities": ["run_tests"]},
                ]
            },
        )

        assert len(subtasks) == 3
        # All parallel tasks should have no dependencies
        for task in subtasks:
            assert task.dependencies == []

    def test_decompose_sequential(self):
        """SEQUENTIAL decomposition chains each subtask to its predecessor."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Deploy",
            description="Build then deploy",
            strategy=DecompositionStrategy.SEQUENTIAL,
            hints={
                "subtasks": [
                    {"name": "Build"},
                    {"name": "Test"},
                    {"name": "Deploy"},
                ]
            },
        )

        assert len(subtasks) == 3
        # First task has no dependencies
        assert subtasks[0].dependencies == []
        # Second depends on first
        assert subtasks[1].dependencies == [subtasks[0].id]
        # Third depends on second
        assert subtasks[2].dependencies == [subtasks[1].id]

    def test_decompose_map_reduce(self):
        """MAP_REDUCE yields the map tasks plus one reduce task depending on them."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Process data",
            description="Process chunks and combine",
            strategy=DecompositionStrategy.MAP_REDUCE,
            hints={
                "map_tasks": [
                    {"name": "Process chunk 1"},
                    {"name": "Process chunk 2"},
                    {"name": "Process chunk 3"},
                ],
                "reduce_task": {"name": "Combine results"},
            },
        )

        # 3 map + 1 reduce
        assert len(subtasks) == 4

        # Find the reduce task. A default of None turns a missing reduce task
        # into a readable assertion failure instead of a bare StopIteration.
        reduce_task = next(
            (t for t in subtasks if "reduce" in t.metadata.get("phase", "")),
            None,
        )
        assert reduce_task is not None, "no subtask is tagged with a reduce phase"
        # Reduce should depend on all map tasks
        assert len(reduce_task.dependencies) == 3

    def test_decompose_pipeline(self):
        """PIPELINE decomposition chains the stages in the order given."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="ETL",
            description="Extract, transform, load",
            strategy=DecompositionStrategy.PIPELINE,
            hints={
                "stages": [
                    {"name": "Extract"},
                    {"name": "Transform"},
                    {"name": "Load"},
                ]
            },
        )

        assert len(subtasks) == 3
        # Pipeline should have chained dependencies
        assert subtasks[0].dependencies == []
        assert subtasks[1].dependencies == [subtasks[0].id]
        assert subtasks[2].dependencies == [subtasks[1].id]

    def test_decompose_hierarchical(self):
        """HIERARCHICAL decomposition returns root tasks and their children."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Project",
            description="Multi-level project",
            strategy=DecompositionStrategy.HIERARCHICAL,
            hints={
                "subtasks": [
                    {
                        "name": "Frontend",
                        "children": [
                            {"name": "UI Components"},
                            {"name": "Styling"},
                        ],
                    },
                    {
                        "name": "Backend",
                        "children": [
                            {"name": "API"},
                            {"name": "Database"},
                        ],
                    },
                ]
            },
        )

        # 2 root + 4 children
        assert len(subtasks) == 6

    def test_get_ready_tasks(self):
        """Only the head of a sequential chain is ready before anything runs."""
        decomposer = TaskDecomposer()

        decomposer.decompose(
            name="Test",
            description="Test",
            strategy=DecompositionStrategy.SEQUENTIAL,
            hints={
                "subtasks": [
                    {"name": "First"},
                    {"name": "Second"},
                    {"name": "Third"},
                ]
            },
        )

        ready = decomposer.get_ready_tasks()
        assert len(ready) == 1
        assert ready[0].name == "First"

    def test_task_lifecycle(self):
        """A task moves through assigned -> running -> completed."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Test",
            description="Test",
            strategy=DecompositionStrategy.PARALLEL,
            hints={"subtasks": [{"name": "Task 1"}]},
        )

        task = subtasks[0]

        # Assign
        assert decomposer.assign_task(task.id, "agent-1")
        assert task.status == "assigned"
        assert task.assigned_to == "agent-1"

        # Start
        assert decomposer.start_task(task.id)
        assert task.status == "running"
        assert task.started_at is not None

        # Complete
        assert decomposer.complete_task(task.id, {"result": "done"})
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result == {"result": "done"}

    def test_task_failure(self):
        """fail_task marks the task failed and records the error in metadata."""
        decomposer = TaskDecomposer()

        subtasks = decomposer.decompose(
            name="Test",
            description="Test",
            strategy=DecompositionStrategy.PARALLEL,
            hints={"subtasks": [{"name": "Task 1"}]},
        )

        task = subtasks[0]
        decomposer.start_task(task.id)

        assert decomposer.fail_task(task.id, "Test error")
        assert task.status == "failed"
        assert task.metadata.get("error") == "Test error"

    def test_get_progress(self):
        """Progress reports the total task count and a per-status breakdown."""
        decomposer = TaskDecomposer()

        decomposer.decompose(
            name="Test",
            description="Test",
            strategy=DecompositionStrategy.PARALLEL,
            hints={"subtasks": [{"name": "T1"}, {"name": "T2"}, {"name": "T3"}]},
        )

        progress = decomposer.get_progress()
        assert progress["total_tasks"] == 3
        assert progress["status_counts"]["pending"] == 3


class TestAutoDecompose:
    """Check how auto_decompose splits free-form descriptions into subtasks."""

    def test_auto_decompose_numbered_list(self):
        """Each item of a numbered list becomes its own subtask."""
        numbered = "1. Build the project\n2. Run tests\n3. Deploy"
        pieces = auto_decompose(numbered)

        assert len(pieces) == 3
        assert "Build" in pieces[0].name

    def test_auto_decompose_bullet_list(self):
        """Each dash-prefixed bullet becomes its own subtask."""
        bullets = "- Create component\n- Add styling\n- Write tests"
        pieces = auto_decompose(bullets)

        assert len(pieces) == 3

    def test_auto_decompose_and_separated(self):
        """A sentence joined by 'and' is split into two subtasks."""
        sentence = "Build the frontend and deploy to staging"
        pieces = auto_decompose(sentence)

        assert len(pieces) == 2


class TestResultAggregator:
    """Cover ResultAggregator bookkeeping and every aggregation strategy."""

    def test_aggregator_init(self):
        """A new aggregator reports zero collected results."""
        assert ResultAggregator().get_summary()["total_results"] == 0

    def test_add_result(self):
        """add_result returns the stored entry and files it under its task."""
        agg = ResultAggregator()

        entry = agg.add_result(
            agent_id="agent-1",
            task_id="task-1",
            result={"code": "impl1"},
            confidence=0.9,
        )

        assert entry.agent_id == "agent-1"
        assert entry.confidence == 0.9
        assert len(agg.get_results_for_task("task-1")) == 1

    def test_aggregate_merge(self):
        """MERGE combines the keys of all dict results into one result."""
        agg = ResultAggregator()
        for agent, payload in (("agent-1", {"a": 1}), ("agent-2", {"b": 2})):
            agg.add_result(agent, "task-1", payload)

        outcome = agg.aggregate(strategy=AggregationStrategy.MERGE)

        assert outcome is not None
        assert outcome.final_result["a"] == 1
        assert outcome.final_result["b"] == 2
        assert outcome.strategy_used == AggregationStrategy.MERGE

    def test_aggregate_vote(self):
        """VOTE picks the most common answer and reports the agreement ratio."""
        agg = ResultAggregator()

        # Three identical answers against a single dissenting one.
        ballots = ["answer-A"] * 3 + ["answer-B"]
        for number, answer in enumerate(ballots, start=1):
            agg.add_result(f"agent-{number}", "task-1", answer)

        outcome = agg.aggregate(strategy=AggregationStrategy.VOTE)

        assert outcome is not None
        assert outcome.final_result == "answer-A"
        assert outcome.agreement_ratio == 0.75

    def test_aggregate_first(self):
        """FIRST returns the earliest result that was added."""
        agg = ResultAggregator()
        for agent, answer in (("agent-1", "first-result"), ("agent-2", "second-result")):
            agg.add_result(agent, "task-1", answer)

        outcome = agg.aggregate(strategy=AggregationStrategy.FIRST)

        assert outcome is not None
        assert outcome.final_result == "first-result"

    def test_aggregate_best_with_quality(self):
        """BEST follows the supplied quality function when one is given."""
        agg = ResultAggregator()
        agg.add_result("agent-1", "task-1", "short", confidence=0.9)
        agg.add_result("agent-2", "task-1", "longer result", confidence=0.7)

        def score_by_length(value):
            # Longer strings score higher; non-strings score zero.
            if isinstance(value, str):
                return len(value)
            return 0

        outcome = agg.aggregate(
            strategy=AggregationStrategy.BEST,
            quality_fn=score_by_length,
        )

        assert outcome is not None
        assert outcome.final_result == "longer result"

    def test_aggregate_consensus_reached(self):
        """CONSENSUS returns the majority answer once the threshold is met."""
        agg = ResultAggregator()

        # 3 of 4 agree: 75%, which clears the 67% threshold.
        answers = ["agreed"] * 3 + ["different"]
        for number, answer in enumerate(answers, start=1):
            agg.add_result(f"agent-{number}", "task-1", answer)

        outcome = agg.aggregate(
            strategy=AggregationStrategy.CONSENSUS,
            consensus_threshold=0.67,
        )

        assert outcome is not None
        assert outcome.final_result == "agreed"
        assert outcome.metadata.get("consensus_reached") is True

    def test_aggregate_consensus_not_reached(self):
        """CONSENSUS yields no final result when every agent disagrees."""
        agg = ResultAggregator()
        for number, answer in enumerate(("answer-A", "answer-B", "answer-C"), start=1):
            agg.add_result(f"agent-{number}", "task-1", answer)

        outcome = agg.aggregate(
            strategy=AggregationStrategy.CONSENSUS,
            consensus_threshold=0.67,
        )

        assert outcome is not None
        assert outcome.final_result is None
        assert outcome.metadata.get("consensus_failed") is True

    def test_aggregate_weighted_numeric(self):
        """WEIGHTED pulls a numeric result toward the high-confidence value."""
        agg = ResultAggregator()
        agg.add_result("agent-1", "task-1", 100, confidence=0.9)
        agg.add_result("agent-2", "task-1", 50, confidence=0.1)

        outcome = agg.aggregate(strategy=AggregationStrategy.WEIGHTED)

        assert outcome is not None
        # The 0.9-confidence value of 100 should dominate the average.
        assert outcome.final_result > 90

    def test_aggregate_reduce_custom(self):
        """REDUCE hands the collected values to the caller's reduce function."""
        agg = ResultAggregator()
        agg.add_result("agent-1", "task-1", [1, 2])
        agg.add_result("agent-2", "task-1", [3, 4])

        def concat_lists(values):
            # Concatenate every list-valued entry, skipping anything else.
            return [item for value in values if isinstance(value, list) for item in value]

        outcome = agg.aggregate(
            strategy=AggregationStrategy.REDUCE,
            reduce_fn=concat_lists,
        )

        assert outcome is not None
        assert outcome.final_result == [1, 2, 3, 4]

    def test_get_results_by_task(self):
        """get_results_for_task returns only the entries of the requested task."""
        agg = ResultAggregator()
        submissions = (
            ("agent-1", "task-1", "result-1"),
            ("agent-2", "task-1", "result-2"),
            ("agent-1", "task-2", "result-3"),
        )
        for agent, task, value in submissions:
            agg.add_result(agent, task, value)

        assert len(agg.get_results_for_task("task-1")) == 2
        assert len(agg.get_results_for_task("task-2")) == 1


class TestSwarmConfig:
    """Verify SwarmConfig defaults and dictionary serialization."""

    def test_config_defaults(self):
        """An argument-less config uses the expected strategies and retry count."""
        defaults = SwarmConfig()

        assert defaults.decomposition_strategy == DecompositionStrategy.PARALLEL
        assert defaults.coordination_strategy == CoordinationStrategy.CAPABILITY_MATCH
        assert defaults.aggregation_strategy == AggregationStrategy.BEST
        assert defaults.max_retries == 2

    def test_config_to_dict(self):
        """to_dict carries over explicitly supplied settings."""
        as_dict = SwarmConfig(swarm_name="test-swarm", max_subtasks=10).to_dict()

        assert as_dict["swarm_name"] == "test-swarm"
        assert as_dict["max_subtasks"] == 10


class TestSwarmCoordinator:
    """Synchronous tests for SwarmCoordinator: agent registry, progress, callbacks."""

    def test_coordinator_init(self):
        """A new coordinator is idle and has no registered agents."""
        coordinator = SwarmCoordinator()
        assert coordinator._state == SwarmState.IDLE
        assert len(coordinator.list_agents()) == 0

    def test_register_agent(self):
        """register_agent returns a record holding the supplied settings."""
        coordinator = SwarmCoordinator()

        agent = coordinator.register_agent(
            "agent-1",
            capabilities=["code_edit", "git"],
            max_load=5,
        )

        assert agent.agent_id == "agent-1"
        assert "code_edit" in agent.capabilities
        assert agent.max_load == 5

    def test_unregister_agent(self):
        """Unregistering the only agent succeeds and leaves the registry empty."""
        coordinator = SwarmCoordinator()

        coordinator.register_agent("agent-1")
        assert coordinator.unregister_agent("agent-1")

        remaining = coordinator.list_agents()
        assert "agent-1" not in remaining
        # The membership check above would pass vacuously if list_agents()
        # returned agent objects rather than ID strings, so also require that
        # nothing is left after removing the only registered agent.
        assert len(remaining) == 0

    def test_get_available_agents(self):
        """Agents flagged as unavailable are excluded from the available list."""
        coordinator = SwarmCoordinator()

        coordinator.register_agent("agent-1")
        coordinator.register_agent("agent-2")

        # Make agent-2 unavailable
        coordinator._agents["agent-2"].available = False

        available = coordinator.get_available_agents()
        assert len(available) == 1
        assert available[0].agent_id == "agent-1"

    def test_agent_can_accept_task(self):
        """An agent stops accepting tasks once its load reaches max_load."""
        coordinator = SwarmCoordinator()

        agent = coordinator.register_agent("agent-1", max_load=2)
        assert agent.can_accept_task is True

        # Fill up the agent
        agent.current_load = 2
        assert agent.can_accept_task is False

    def test_get_progress(self):
        """Progress reports the swarm state and the total agent count."""
        coordinator = SwarmCoordinator()
        coordinator.register_agent("agent-1")

        progress = coordinator.get_progress()
        assert progress["state"] == "idle"
        assert progress["agents"]["total"] == 1

    def test_get_status_summary(self):
        """The status summary includes the swarm name and registered agents."""
        config = SwarmConfig(swarm_name="test-swarm")
        coordinator = SwarmCoordinator(config)
        coordinator.register_agent("agent-1", capabilities=["code_edit"])

        summary = coordinator.get_status_summary()
        assert summary["swarm"]["name"] == "test-swarm"
        assert "agent-1" in summary["agents"]

    def test_state_change_callback(self):
        """A registered state-change callback receives the new state."""
        coordinator = SwarmCoordinator()
        states = []

        coordinator.on_state_change(lambda s: states.append(s))
        coordinator._set_state(SwarmState.EXECUTING)

        assert SwarmState.EXECUTING in states


class TestSwarmCoordinatorAsync:
    """Asynchronous tests that drive SwarmCoordinator.execute end to end."""

    @pytest.mark.asyncio
    async def test_execute_simple_task(self):
        """Two capable agents run two hinted subtasks to completion."""
        coordinator = SwarmCoordinator()

        # Both agents share the capability the subtasks ask for.
        for agent_name in ("agent-1", "agent-2"):
            coordinator.register_agent(agent_name, capabilities=["code_edit"])

        async def echo_executor(task: SubTask, agent_id: str):
            # Report which agent handled which subtask.
            return f"Result from {agent_id} for {task.name}"

        hints = [
            {"name": "Subtask 1", "capabilities": ["code_edit"]},
            {"name": "Subtask 2", "capabilities": ["code_edit"]},
        ]
        outcome = await coordinator.execute(
            name="Test Task",
            description="A simple test",
            subtask_hints=hints,
            task_executor=echo_executor,
        )

        assert outcome is not None
        assert coordinator._state == SwarmState.COMPLETED

    @pytest.mark.asyncio
    async def test_execute_with_quality_function(self):
        """execute accepts a quality function and still returns a result."""
        coordinator = SwarmCoordinator()
        coordinator.register_agent("agent-1")

        async def fixed_executor(task: SubTask, agent_id: str):
            return "high quality result with details"

        def score_by_length(value):
            # Longer strings score higher; non-strings score zero.
            if isinstance(value, str):
                return len(value)
            return 0

        outcome = await coordinator.execute(
            name="Quality Test",
            description="Test quality",
            subtask_hints=[{"name": "Task"}],
            task_executor=fixed_executor,
            quality_fn=score_by_length,
        )

        assert outcome is not None

    @pytest.mark.asyncio
    async def test_execute_no_agents(self):
        """With no agents registered, execute must not finish successfully."""
        coordinator = SwarmCoordinator()

        async def fixed_executor(task: SubTask, agent_id: str):
            return "result"

        pending_run = coordinator.execute(
            name="No Agent Test",
            description="Test",
            subtask_hints=[{"name": "Task"}],
            task_executor=fixed_executor,
        )

        # Nothing can be assigned without agents, so the run is expected either
        # to raise on its own or to be cut off by the 2-second wait_for timeout;
        # in both cases an exception must surface here.
        with pytest.raises(Exception):
            await asyncio.wait_for(pending_run, timeout=2.0)


class TestSwarmIntegration:
    """Integration-style tests that combine several swarm components."""

    def test_full_swarm_workflow_sync(self):
        """Set up a configured coordinator with three agents and inspect it."""
        coordinator = SwarmCoordinator(
            SwarmConfig(
                decomposition_strategy=DecompositionStrategy.PARALLEL,
                coordination_strategy=CoordinationStrategy.CAPABILITY_MATCH,
                aggregation_strategy=AggregationStrategy.MERGE,
            )
        )

        # A mixed roster: (agent id, capabilities, maximum load).
        roster = (
            ("code-agent", ["code_edit", "git"], 3),
            ("test-agent", ["run_tests", "debug"], 2),
            ("docs-agent", ["documentation"], 1),
        )
        for agent_name, skills, load_limit in roster:
            coordinator.register_agent(
                agent_name,
                capabilities=skills,
                max_load=load_limit,
            )

        # Registration alone must not move the swarm out of IDLE.
        assert len(coordinator.list_agents()) == 3
        assert coordinator._state == SwarmState.IDLE

        agent_stats = coordinator.get_progress()["agents"]
        assert agent_stats["total"] == 3
        assert agent_stats["available"] == 3

    def test_decomposer_aggregator_integration(self):
        """Feed the results of decomposed subtasks into an aggregator."""
        decomposer = TaskDecomposer()
        aggregator = ResultAggregator()

        parts = decomposer.decompose(
            name="Integration Test",
            description="Test integration",
            strategy=DecompositionStrategy.PARALLEL,
            hints={"subtasks": [{"name": f"Part {number}"} for number in (1, 2, 3)]},
        )

        # Pretend each part ran: start it, complete it, then record its output.
        for part in parts:
            decomposer.start_task(part.id)
            output = f"Result for {part.name}"
            decomposer.complete_task(part.id, output)

            aggregator.add_result(
                agent_id=f"agent-{part.id}",
                task_id=part.id,
                result=output,
                confidence=0.9,
            )

        # Every part should now be counted as completed.
        assert decomposer.get_progress()["status_counts"]["completed"] == 3

        # The aggregator should have seen one result per part.
        assert aggregator.get_summary()["total_results"] == 3

        merged = aggregator.aggregate(strategy=AggregationStrategy.MERGE)
        assert merged is not None
        assert len(merged.source_results) == 3
