"""
Session Replay - Replay and analyze recorded traces.

Provides:
- Timeline generation from traces
- Decision point extraction
- Execution graph visualization data
- Replay API for debugging
"""

import logging
from datetime import datetime
from typing import Any, Optional

from .models import Trace, Span, SpanKind, SpanStatus, TimelineEvent, DecisionPoint
from .storage import TraceStorage, get_trace_storage

logger = logging.getLogger(__name__)


class SessionReplay:
    """
    Replay recorded sessions step-by-step.

    Enables analysis and debugging of agent executions by providing
    timeline views, decision point extraction, and execution graphs.

    Inspired by AgentOps session replay and LangSmith trace visualization.

    The per-trace query methods look the trace up through the configured
    storage backend. They return an empty result (``[]``, ``{}``, an empty
    graph, or ``None``) when the trace is missing or has no root span.
    ``compare_traces`` instead returns a dict with an ``"error"`` key.

    Example:
        replay = SessionReplay()

        # Get timeline of events
        timeline = replay.get_execution_timeline("trace-123")
        for event in timeline:
            print(f"{event.timestamp}: {event.event_type} - {event.span_name}")

        # Get decision points
        decisions = replay.get_decision_points("trace-123")
        for decision in decisions:
            print(f"Decision: {decision.decision}")
            print(f"Rationale: {decision.rationale}")
    """

    def __init__(self, storage: Optional[TraceStorage] = None):
        """
        Initialize session replay.

        Args:
            storage: Trace storage backend. When omitted (or falsy), the
                backend returned by ``get_trace_storage()`` is used.
        """
        self._storage = storage or get_trace_storage()

    def get_trace(self, trace_id: str) -> Optional[Trace]:
        """Return the trace with ``trace_id`` from storage, or None if absent."""
        return self._storage.get_trace(trace_id)

    def get_execution_timeline(self, trace_id: str) -> list[TimelineEvent]:
        """
        Get chronological timeline of events.

        Every span contributes a ``"span_start"`` event. A ``"span_end"``
        event is added only when the span has an ``end_time``, so spans
        that have not finished appear with a start event only.

        Args:
            trace_id: Trace identifier

        Returns:
            List of timeline events sorted by timestamp. The sort is stable,
            so events with equal timestamps keep depth-first order: a parent
            starts before its children and ends after them.
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return []

        events: list[TimelineEvent] = []

        def flatten_spans(span: Span, depth: int = 0) -> None:
            # Start event
            events.append(TimelineEvent(
                timestamp=span.start_time,
                event_type="span_start",
                span_id=span.span_id,
                span_name=span.name,
                span_kind=span.kind,
                depth=depth,
                data={
                    "input_data": span.input_data,
                    "metadata": span.metadata,
                },
            ))

            # Visit children in start order so that ties in the final
            # (stable) sort still come out in a sensible nesting order.
            for child in sorted(span.children, key=lambda s: s.start_time):
                flatten_spans(child, depth + 1)

            # End event (only for spans that have finished)
            if span.end_time:
                events.append(TimelineEvent(
                    timestamp=span.end_time,
                    event_type="span_end",
                    span_id=span.span_id,
                    span_name=span.name,
                    span_kind=span.kind,
                    depth=depth,
                    data={
                        "status": span.status.value,
                        "output_data": span.output_data,
                        "error": span.error,
                        "latency_ms": span.latency_ms,
                        "tokens": span.total_tokens,
                        "cost_usd": span.cost_usd,
                    },
                ))

        flatten_spans(trace.root_span)

        # Sort by timestamp
        return sorted(events, key=lambda e: e.timestamp)

    def get_decision_points(self, trace_id: str) -> list[DecisionPoint]:
        """
        Get all decision points where agent made choices.

        Decision points include:
        - Tool calls: the decision is the tool span's name. Rationale,
          alternatives and confidence are read from the span metadata.
        - Approval requests: the decision is ``output_data["response"]``,
          or ``"unknown"`` when no response was recorded.
        - Agent handoffs: the decision is ``"handoff to <to_agent>"``,
          taken from the span metadata.

        Args:
            trace_id: Trace identifier

        Returns:
            List of decision points sorted by span start time
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return []

        decisions: list[DecisionPoint] = []

        def find_decisions(span: Span) -> None:
            # Treat a missing metadata/output mapping as empty so that one
            # incomplete span cannot abort extraction for the whole trace.
            # NOTE(review): assumes these may be None on some spans (e.g. an
            # approval that never received a response) — confirm in models.
            metadata = span.metadata or {}

            # Tool calls are decision points
            if span.kind == SpanKind.TOOL_CALL:
                decisions.append(DecisionPoint(
                    span_id=span.span_id,
                    span_kind=span.kind,
                    timestamp=span.start_time,
                    input_context=span.input_data,
                    decision=span.name,
                    rationale=metadata.get("rationale"),
                    alternatives=metadata.get("alternatives", []),
                    confidence=metadata.get("confidence", 1.0),
                ))

            # Approval requests are decision points
            elif span.kind == SpanKind.APPROVAL:
                output = span.output_data or {}
                decisions.append(DecisionPoint(
                    span_id=span.span_id,
                    span_kind=span.kind,
                    timestamp=span.start_time,
                    input_context=span.input_data,
                    decision=output.get("response", "unknown"),
                    rationale=metadata.get("rationale"),
                ))

            # Agent handoffs are decision points
            elif span.kind == SpanKind.HANDOFF:
                decisions.append(DecisionPoint(
                    span_id=span.span_id,
                    span_kind=span.kind,
                    timestamp=span.start_time,
                    input_context=span.input_data,
                    decision=f"handoff to {metadata.get('to_agent', 'unknown')}",
                    rationale=metadata.get("rationale"),
                ))

            # Recurse into children
            for child in span.children:
                find_decisions(child)

        find_decisions(trace.root_span)

        # Sort by timestamp
        return sorted(decisions, key=lambda d: d.timestamp)

    def get_execution_graph(self, trace_id: str) -> dict[str, Any]:
        """
        Get execution graph data for visualization.

        Returns a structure suitable for rendering as a DAG: one node per
        span and one parent -> child edge per nesting relationship.

        Args:
            trace_id: Trace identifier

        Returns:
            Graph structure with ``nodes`` and ``edges``. When the trace is
            found, it also includes ``trace_id``, ``total_nodes`` and
            ``max_depth`` (root depth is 0). When it is missing, only the
            two empty lists are returned.
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return {"nodes": [], "edges": []}

        nodes: list[dict[str, Any]] = []
        edges: list[dict[str, str]] = []

        def process_span(span: Span, depth: int = 0) -> None:
            # Add node
            nodes.append({
                "id": span.span_id,
                "name": span.name,
                "kind": span.kind.value,
                "status": span.status.value,
                "depth": depth,
                "latency_ms": span.latency_ms,
                "tokens": span.total_tokens,
                "cost_usd": span.cost_usd,
                "has_error": span.error is not None,
            })

            # Add edges to children
            for child in span.children:
                edges.append({
                    "source": span.span_id,
                    "target": child.span_id,
                })
                process_span(child, depth + 1)

        process_span(trace.root_span)

        return {
            "nodes": nodes,
            "edges": edges,
            "trace_id": trace_id,
            "total_nodes": len(nodes),
            "max_depth": max((n["depth"] for n in nodes), default=0),
        }

    def get_span_details(self, trace_id: str, span_id: str) -> Optional[dict[str, Any]]:
        """
        Get detailed information about a specific span.

        Args:
            trace_id: Trace identifier
            span_id: Span identifier

        Returns:
            Span details (timestamps as ISO-8601 strings), or None if the
            trace or span is not found
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return None

        def find_span(span: Span) -> Optional[Span]:
            # Depth-first search; returns the first span with a matching id.
            if span.span_id == span_id:
                return span
            for child in span.children:
                result = find_span(child)
                # Compare against None explicitly: a found span must never be
                # skipped because of how the Span object evaluates as a bool.
                if result is not None:
                    return result
            return None

        span = find_span(trace.root_span)
        if span is None:
            return None

        return {
            "span_id": span.span_id,
            "trace_id": span.trace_id,
            "parent_span_id": span.parent_span_id,
            "name": span.name,
            "kind": span.kind.value,
            "status": span.status.value,
            "start_time": span.start_time.isoformat(),
            "end_time": span.end_time.isoformat() if span.end_time else None,
            "latency_ms": span.latency_ms,
            "input_tokens": span.input_tokens,
            "output_tokens": span.output_tokens,
            "cost_usd": span.cost_usd,
            "input_data": span.input_data,
            "output_data": span.output_data,
            "metadata": span.metadata,
            "error": span.error,
            "child_count": len(span.children),
        }

    def get_error_spans(self, trace_id: str) -> list[dict[str, Any]]:
        """
        Get all spans that had errors.

        A span counts as failed when its status is ``SpanStatus.ERROR`` or
        its ``error`` field is truthy.

        Args:
            trace_id: Trace identifier

        Returns:
            List of error span details in depth-first (pre-order) traversal
            order, not sorted by time
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return []

        errors: list[dict[str, Any]] = []

        def find_errors(span: Span) -> None:
            if span.status == SpanStatus.ERROR or span.error:
                errors.append({
                    "span_id": span.span_id,
                    "name": span.name,
                    "kind": span.kind.value,
                    "error": span.error,
                    "timestamp": span.start_time.isoformat(),
                })
            for child in span.children:
                find_errors(child)

        find_errors(trace.root_span)
        return errors

    def get_cost_breakdown(self, trace_id: str) -> dict[str, Any]:
        """
        Get cost breakdown by span kind.

        Args:
            trace_id: Trace identifier

        Returns:
            ``{}`` if the trace is missing. Otherwise a dict with:
            ``total_cost_usd`` (read from the trace itself, not summed here),
            ``by_kind`` (summed ``cost_usd`` per span kind value) and
            ``by_name`` (summed ``cost_usd`` per span name; spans sharing a
            name are merged).
        """
        trace = self.get_trace(trace_id)
        if not trace or not trace.root_span:
            return {}

        by_kind: dict[str, float] = {}
        by_name: dict[str, float] = {}

        def aggregate_costs(span: Span) -> None:
            kind = span.kind.value
            by_kind[kind] = by_kind.get(kind, 0.0) + span.cost_usd
            by_name[span.name] = by_name.get(span.name, 0.0) + span.cost_usd
            for child in span.children:
                aggregate_costs(child)

        aggregate_costs(trace.root_span)

        return {
            "total_cost_usd": trace.total_cost_usd,
            "by_kind": by_kind,
            "by_name": by_name,
        }

    def compare_traces(
        self,
        trace_id_1: str,
        trace_id_2: str,
    ) -> dict[str, Any]:
        """
        Compare two traces for differences.

        Args:
            trace_id_1: First trace ID (the baseline)
            trace_id_2: Second trace ID

        Returns:
            Per-trace summaries under ``trace_1`` / ``trace_2`` plus a
            ``differences`` dict computed as ``trace_2 - trace_1``. A positive
            value means the second trace used more. If either trace is
            missing, returns ``{"error": ...}`` instead.
        """
        trace1 = self.get_trace(trace_id_1)
        trace2 = self.get_trace(trace_id_2)

        if not trace1 or not trace2:
            return {"error": "One or both traces not found"}

        return {
            "trace_1": {
                "trace_id": trace_id_1,
                "total_tokens": trace1.total_tokens,
                "total_cost_usd": trace1.total_cost_usd,
                "total_latency_ms": trace1.total_latency_ms,
                "span_count": trace1.span_count,
                "status": trace1.status.value,
            },
            "trace_2": {
                "trace_id": trace_id_2,
                "total_tokens": trace2.total_tokens,
                "total_cost_usd": trace2.total_cost_usd,
                "total_latency_ms": trace2.total_latency_ms,
                "span_count": trace2.span_count,
                "status": trace2.status.value,
            },
            "differences": {
                "tokens_diff": trace2.total_tokens - trace1.total_tokens,
                "cost_diff_usd": trace2.total_cost_usd - trace1.total_cost_usd,
                "latency_diff_ms": trace2.total_latency_ms - trace1.total_latency_ms,
                "span_count_diff": trace2.span_count - trace1.span_count,
            },
        }


# Shared SessionReplay used by the accessor functions below; created lazily.
_replay: Optional[SessionReplay] = None


def get_session_replay() -> SessionReplay:
    """
    Return the shared SessionReplay, creating a default one on first use.

    Returns:
        The instance previously stored via ``set_session_replay`` or
        created by an earlier call.
    """
    global _replay
    if _replay is not None:
        return _replay
    # Nothing stored yet (or it was cleared): build a default instance.
    _replay = SessionReplay()
    return _replay


def set_session_replay(replay: Optional[SessionReplay]) -> None:
    """
    Replace the shared SessionReplay instance.

    Args:
        replay: Instance to share. Passing ``None`` clears it, so the next
            ``get_session_replay()`` call constructs a fresh default.
    """
    global _replay
    _replay = replay
