"""
Result Aggregator - Combine results from distributed agent tasks.

Provides:
- Multiple aggregation strategies
- Conflict resolution
- Quality scoring
- Result validation
"""

import logging
import statistics
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Optional

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class AggregationStrategy(Enum):
    """
    Closed set of strategies for combining several agent results into one.

    Members (string value in parentheses):
        MERGE (``"merge"``):         merge all results into one structure.
        VOTE (``"vote"``):           take the most common result.
        FIRST (``"first"``):         take the first result.
        BEST (``"best"``):           take the highest-quality result.
        CONSENSUS (``"consensus"``): require an agreement threshold.
        WEIGHTED (``"weighted"``):   weight results by agent confidence.
        REDUCE (``"reduce"``):       apply a caller-supplied reduce function.
    """

    MERGE = "merge"
    VOTE = "vote"
    FIRST = "first"
    BEST = "best"
    CONSENSUS = "consensus"
    WEIGHTED = "weighted"
    REDUCE = "reduce"


@dataclass
class AgentResult:
    """
    One agent's output for one task, together with scoring and timing data.

    ``quality_score`` starts as ``None`` and is filled in only when a quality
    function is applied to this result.
    """

    agent_id: str
    task_id: str
    result: Any
    confidence: float = 1.0  # intended range: 0.0 .. 1.0
    quality_score: Optional[float] = None  # populated by a quality function, if any
    metadata: dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)  # naive local creation time
    duration_seconds: Optional[float] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, with ``timestamp`` rendered in ISO-8601."""
        passthrough = ("agent_id", "task_id", "result", "confidence", "quality_score", "metadata")
        payload: dict[str, Any] = {name: getattr(self, name) for name in passthrough}
        payload["timestamp"] = self.timestamp.isoformat()
        payload["duration_seconds"] = self.duration_seconds
        return payload


@dataclass
class AggregatedResult:
    """
    Outcome of combining several ``AgentResult`` objects with one strategy.

    The strategy that builds this object computes ``confidence`` and
    ``agreement_ratio``, so their exact meaning depends on the strategy.
    """

    final_result: Any
    strategy_used: AggregationStrategy
    source_results: list[AgentResult]
    confidence: float  # overall confidence in final_result
    agreement_ratio: float  # how much the source agents agreed
    metadata: dict[str, Any] = field(default_factory=dict)
    aggregated_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict summary; the source results are reduced to a count."""
        summary: dict[str, Any] = {"final_result": self.final_result}
        summary["strategy_used"] = self.strategy_used.value
        summary["source_count"] = len(self.source_results)
        summary.update(
            confidence=self.confidence,
            agreement_ratio=self.agreement_ratio,
            metadata=self.metadata,
        )
        summary["aggregated_at"] = self.aggregated_at.isoformat()
        return summary


class ResultAggregator:
    """
    Aggregates results from multiple agents.

    Supports multiple aggregation strategies:
    - MERGE: Combine all results (for complementary outputs)
    - VOTE: Take most common result (for discrete outputs)
    - FIRST: Take first available result
    - BEST: Take highest quality result
    - CONSENSUS: Require minimum agreement
    - WEIGHTED: Weight by confidence
    - REDUCE: Custom aggregation function

    Notes:
        - ``quality_fn`` stores its score on the shared ``AgentResult`` objects
          (``quality_score``). A later ``BEST`` aggregation without a
          ``quality_fn`` reuses those stored scores.
        - ``AggregatedResult.source_results`` is a snapshot list, so later
          ``add_result()`` / ``clear()`` calls do not change it. The
          ``AgentResult`` objects inside it are shared, not copied.

    Example:
        aggregator = ResultAggregator()

        # Add results from agents
        aggregator.add_result("agent-1", "task-1", {"code": "impl1"}, confidence=0.9)
        aggregator.add_result("agent-2", "task-1", {"code": "impl2"}, confidence=0.8)

        # Aggregate with quality scoring
        result = aggregator.aggregate(
            strategy=AggregationStrategy.BEST,
            quality_fn=lambda r: len(r.get("code", ""))  # Prefer longer code
        )
    """

    def __init__(self) -> None:
        """Initialize an empty aggregator with a flat list and per-task/per-agent indexes."""
        self._results: list[AgentResult] = []
        self._by_task: dict[str, list[AgentResult]] = {}
        self._by_agent: dict[str, list[AgentResult]] = {}

    def add_result(
        self,
        agent_id: str,
        task_id: str,
        result: Any,
        confidence: float = 1.0,
        metadata: Optional[dict[str, Any]] = None,
        duration_seconds: Optional[float] = None,
    ) -> AgentResult:
        """
        Add a result from an agent.

        Args:
            agent_id: ID of the agent
            task_id: ID of the task
            result: The result data
            confidence: Agent's confidence in result; clamped to [0.0, 1.0]
            metadata: Additional metadata
            duration_seconds: How long the task took

        Returns:
            The created AgentResult
        """
        agent_result = AgentResult(
            agent_id=agent_id,
            task_id=task_id,
            result=result,
            confidence=min(1.0, max(0.0, confidence)),
            metadata=metadata or {},
            duration_seconds=duration_seconds,
        )

        self._results.append(agent_result)
        self._by_task.setdefault(task_id, []).append(agent_result)
        self._by_agent.setdefault(agent_id, []).append(agent_result)

        logger.debug("Added result from %s for task %s", agent_id, task_id)
        return agent_result

    def aggregate(
        self,
        strategy: AggregationStrategy = AggregationStrategy.MERGE,
        task_id: Optional[str] = None,
        quality_fn: Optional[Callable[[Any], float]] = None,
        reduce_fn: Optional[Callable[[list[Any]], Any]] = None,
        consensus_threshold: float = 0.67,
        result_key: Optional[str] = None,
    ) -> Optional[AggregatedResult]:
        """
        Aggregate results using the specified strategy.

        Args:
            strategy: Aggregation strategy to use
            task_id: Only aggregate results for this task (None = all tasks).
                A task id with no recorded results yields None.
            quality_fn: Function to score result quality (for BEST strategy).
                It receives ``result[result_key]`` for dict results when
                ``result_key`` is given, otherwise the raw result. If it raises,
                the score is set to 0.0 and a warning is logged.
            reduce_fn: Custom reduce function (for REDUCE strategy)
            consensus_threshold: Agreement ratio required (for CONSENSUS)
            result_key: Key to extract from results if they're dicts

        Returns:
            AggregatedResult or None if no results

        Raises:
            ValueError: for an unknown strategy, or REDUCE without reduce_fn.
        """
        if task_id is not None:
            # An unknown task has no results. Do not silently fall back to
            # aggregating every task's results.
            source = self._by_task.get(task_id, [])
        else:
            source = self._results
        # Snapshot so the returned AggregatedResult does not alias internal
        # storage that add_result()/clear() mutate later.
        results = list(source)

        if not results:
            logger.warning("No results to aggregate")
            return None

        # Apply quality scoring if provided
        if quality_fn is not None:
            for r in results:
                try:
                    val = r.result[result_key] if result_key and isinstance(r.result, dict) else r.result
                    r.quality_score = quality_fn(val)
                except Exception as e:
                    # Best-effort: a failing scorer ranks this result lowest
                    # instead of aborting the whole aggregation.
                    logger.warning("Quality function failed for %s: %s", r.agent_id, e)
                    r.quality_score = 0.0

        # Delegate to strategy handlers
        if strategy == AggregationStrategy.MERGE:
            return self._aggregate_merge(results)
        elif strategy == AggregationStrategy.VOTE:
            return self._aggregate_vote(results, result_key)
        elif strategy == AggregationStrategy.FIRST:
            return self._aggregate_first(results)
        elif strategy == AggregationStrategy.BEST:
            return self._aggregate_best(results)
        elif strategy == AggregationStrategy.CONSENSUS:
            return self._aggregate_consensus(results, consensus_threshold, result_key)
        elif strategy == AggregationStrategy.WEIGHTED:
            return self._aggregate_weighted(results, result_key)
        elif strategy == AggregationStrategy.REDUCE:
            return self._aggregate_reduce(results, reduce_fn, result_key)
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")

    @staticmethod
    def _group_key(r: AgentResult, result_key: Optional[str]) -> str:
        """Return the string VOTE, CONSENSUS and WEIGHTED use to compare results."""
        if result_key and isinstance(r.result, dict):
            return str(r.result.get(result_key, ""))
        return str(r.result)

    @staticmethod
    def _is_numeric(value: Any) -> bool:
        """Return True for int/float values, excluding bool (a subclass of int)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _aggregate_merge(self, results: list[AgentResult]) -> AggregatedResult:
        """
        Merge all results into a combined structure.

        Dict results are merged with ``dict.update``, so later results overwrite
        earlier ones on key collisions. List results are concatenated under
        ``"list_results"``. Any other result is stored under its agent_id, so a
        later result from the same agent overwrites an earlier one.
        """
        merged = {}

        for r in results:
            if isinstance(r.result, dict):
                merged.update(r.result)
            elif isinstance(r.result, list):
                if "list_results" not in merged:
                    merged["list_results"] = []
                merged["list_results"].extend(r.result)
            else:
                merged[r.agent_id] = r.result

        avg_confidence = statistics.mean(r.confidence for r in results)

        return AggregatedResult(
            final_result=merged,
            strategy_used=AggregationStrategy.MERGE,
            source_results=results,
            confidence=avg_confidence,
            agreement_ratio=1.0,  # Merge doesn't check agreement
            metadata={"merge_count": len(results)},
        )

    def _aggregate_vote(
        self,
        results: list[AgentResult],
        result_key: Optional[str],
    ) -> AggregatedResult:
        """
        Take the most common result (voting).

        Results are compared by their string form (or ``result[result_key]``).
        Each distinct value is represented by its most confident result. On a
        tie in vote counts, the value seen first wins.
        """
        vote_counts: dict[str, int] = {}
        vote_results: dict[str, AgentResult] = {}
        for r in results:
            key = self._group_key(r, result_key)
            vote_counts[key] = vote_counts.get(key, 0) + 1
            # Keep the most confident representative (as CONSENSUS and WEIGHTED
            # do) instead of whichever result happened to arrive last.
            if key not in vote_results or r.confidence > vote_results[key].confidence:
                vote_results[key] = r

        # Find winner
        winner_str = max(vote_counts, key=lambda k: vote_counts[k])
        winner_count = vote_counts[winner_str]
        winner_result = vote_results[winner_str]

        agreement = winner_count / len(results)

        return AggregatedResult(
            final_result=winner_result.result,
            strategy_used=AggregationStrategy.VOTE,
            source_results=results,
            confidence=agreement * winner_result.confidence,
            agreement_ratio=agreement,
            metadata={"votes": vote_counts, "winner_votes": winner_count},
        )

    def _aggregate_first(self, results: list[AgentResult]) -> AggregatedResult:
        """
        Take the result with the earliest timestamp.

        Ties keep list order. ``agreement_ratio`` is 1/len(results) because
        only a single result is used.
        """
        first = min(results, key=lambda r: r.timestamp)

        return AggregatedResult(
            final_result=first.result,
            strategy_used=AggregationStrategy.FIRST,
            source_results=results,
            confidence=first.confidence,
            agreement_ratio=1 / len(results),
            metadata={"selected_agent": first.agent_id},
        )

    def _aggregate_best(self, results: list[AgentResult]) -> AggregatedResult:
        """
        Take the highest-scoring result.

        The score is ``quality_score * confidence`` when a quality score has
        been stored, otherwise ``confidence`` alone. ``agreement_ratio`` is the
        mean score divided by the best score (0 when the best score is <= 0).
        """
        def get_score(r: AgentResult) -> float:
            if r.quality_score is not None:
                return r.quality_score * r.confidence
            return r.confidence

        best = max(results, key=get_score)
        max_score = get_score(best)
        avg_score = statistics.mean(get_score(r) for r in results)

        return AggregatedResult(
            final_result=best.result,
            strategy_used=AggregationStrategy.BEST,
            source_results=results,
            confidence=best.confidence,
            agreement_ratio=avg_score / max_score if max_score > 0 else 0,
            metadata={
                "selected_agent": best.agent_id,
                "quality_score": best.quality_score,
                "best_score": max_score,
            },
        )

    def _aggregate_consensus(
        self,
        results: list[AgentResult],
        threshold: float,
        result_key: Optional[str],
    ) -> Optional[AggregatedResult]:
        """
        Require consensus above threshold to return a result.

        If the largest group of equal results is below ``threshold``, an
        AggregatedResult with ``final_result=None`` and confidence 0.0 is
        returned (``metadata["consensus_failed"]`` is True).
        """
        # Group equal results (by string form / result_key)
        groups: dict[str, list[AgentResult]] = {}
        for r in results:
            groups.setdefault(self._group_key(r, result_key), []).append(r)

        # Find largest group
        largest_key = max(groups, key=lambda k: len(groups[k]))
        largest_group = groups[largest_key]
        agreement = len(largest_group) / len(results)

        if agreement < threshold:
            logger.warning("Consensus not reached: %.2f%% < %.2f%%", agreement * 100, threshold * 100)
            return AggregatedResult(
                final_result=None,
                strategy_used=AggregationStrategy.CONSENSUS,
                source_results=results,
                confidence=0.0,
                agreement_ratio=agreement,
                metadata={
                    "consensus_failed": True,
                    "threshold": threshold,
                    "groups": {k: len(v) for k, v in groups.items()},
                },
            )

        # Return consensus result
        best_in_group = max(largest_group, key=lambda r: r.confidence)

        return AggregatedResult(
            final_result=best_in_group.result,
            strategy_used=AggregationStrategy.CONSENSUS,
            source_results=results,
            confidence=agreement * best_in_group.confidence,
            agreement_ratio=agreement,
            metadata={
                "consensus_reached": True,
                "threshold": threshold,
                "group_size": len(largest_group),
            },
        )

    def _aggregate_weighted(
        self,
        results: list[AgentResult],
        result_key: Optional[str],
    ) -> AggregatedResult:
        """
        Weight results by confidence.

        If every result is an int/float (bools excluded), the result is the
        confidence-weighted mean. When all confidences are 0, it falls back to
        the plain mean. Otherwise a confidence-weighted vote is used.

        Raises:
            ValueError: if ``results`` is empty.
        """
        if not results:
            raise ValueError("No results to aggregate")

        # For numeric results, compute weighted average
        if all(self._is_numeric(r.result) for r in results):
            total_weight = sum(r.confidence for r in results)
            if total_weight > 0:
                final = sum(r.result * r.confidence for r in results) / total_weight
            else:
                # Weighted mean is undefined with zero total weight; use the
                # unweighted mean rather than an arbitrary constant.
                final = sum(r.result for r in results) / len(results)

            return AggregatedResult(
                final_result=final,
                strategy_used=AggregationStrategy.WEIGHTED,
                source_results=results,
                confidence=total_weight / len(results),
                agreement_ratio=1.0,
                metadata={"weighted_average": final},
            )

        # For non-numeric, use weighted voting
        weighted_votes: dict[str, float] = {}
        vote_results: dict[str, AgentResult] = {}

        for r in results:
            key = self._group_key(r, result_key)
            weighted_votes[key] = weighted_votes.get(key, 0) + r.confidence
            if key not in vote_results or r.confidence > vote_results[key].confidence:
                vote_results[key] = r

        winner_key = max(weighted_votes, key=lambda k: weighted_votes[k])
        total_weight = sum(weighted_votes.values())
        share = weighted_votes[winner_key] / total_weight if total_weight > 0 else 0

        return AggregatedResult(
            final_result=vote_results[winner_key].result,
            strategy_used=AggregationStrategy.WEIGHTED,
            source_results=results,
            confidence=share,
            agreement_ratio=share,
            metadata={"weighted_votes": weighted_votes},
        )

    def _aggregate_reduce(
        self,
        results: list[AgentResult],
        reduce_fn: Optional[Callable[[list[Any]], Any]],
        result_key: Optional[str],
    ) -> AggregatedResult:
        """
        Apply a custom reduce function to the extracted result values.

        For dict results with ``result_key``, the value is
        ``result.get(result_key)`` (None if missing).

        Raises:
            ValueError: if ``reduce_fn`` is not provided.
        """
        if reduce_fn is None:
            raise ValueError("reduce_fn required for REDUCE strategy")

        values = [
            r.result.get(result_key) if result_key and isinstance(r.result, dict) else r.result
            for r in results
        ]

        reduced = reduce_fn(values)
        avg_confidence = statistics.mean(r.confidence for r in results)

        return AggregatedResult(
            final_result=reduced,
            strategy_used=AggregationStrategy.REDUCE,
            source_results=results,
            confidence=avg_confidence,
            agreement_ratio=1.0,  # Custom reduce doesn't measure agreement
            metadata={"input_count": len(values)},
        )

    def get_results_for_task(self, task_id: str) -> list[AgentResult]:
        """Return a copy of the list of results recorded for ``task_id``."""
        return list(self._by_task.get(task_id, []))

    def get_results_for_agent(self, agent_id: str) -> list[AgentResult]:
        """Return a copy of the list of results recorded from ``agent_id``."""
        return list(self._by_agent.get(agent_id, []))

    def get_summary(self) -> dict[str, Any]:
        """Return counts of results/tasks/agents and the mean confidence (0 if empty)."""
        return {
            "total_results": len(self._results),
            "tasks_with_results": len(self._by_task),
            "agents_with_results": len(self._by_agent),
            "avg_confidence": (
                statistics.mean(r.confidence for r in self._results)
                if self._results else 0
            ),
        }

    def clear(self) -> None:
        """Remove all stored results and indexes."""
        self._results.clear()
        self._by_task.clear()
        self._by_agent.clear()
