"""
Working Memory - Per-task scratch space.

This module provides:
- Task-scoped temporary storage
- Transient notes and observations
- Tool output caching
- Summarization for promotion to long-term memory

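Working memory is always discarded at task end; call ``summarize()`` beforehand
(``WorkingMemory.discard`` does this and returns the result) to keep a record
suitable for promotion to long-term memory.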

Usage:
    from agent_orchestrator.memory.working import WorkingMemory

    wm = WorkingMemory()
    session = wm.create_session("task-123", "claude-code")

    session.add_note("Tried approach A, failed with error X")
    session.add_tool_output("pytest", "FAILED: 3 errors")

    summary = session.summarize()
    wm.discard(session.session_id)  # also returns the session summary
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Optional
import uuid


logger = logging.getLogger(__name__)


@dataclass
class ToolOutput:
    """Record of one tool invocation's captured output within a working session."""

    # Name of the tool that produced the output (e.g. "pytest").
    tool_name: str
    # The command line that was run, when the caller supplied one.
    command: Optional[str]
    # Raw output text as captured.
    output: str
    # Process exit status, or None when not applicable / not reported.
    exit_code: Optional[int]
    # Local (naive) time at which the record was created.
    timestamp: datetime = field(default_factory=datetime.now)


@dataclass
class WorkingNote:
    """A short-lived note recorded during a task (not persisted on its own)."""

    # Free-form note text.
    content: str
    # Conventional values used by WorkingSession helpers:
    # "observation" (default), "attempt", "decision", "blocker".
    category: str = "observation"
    # Local (naive) creation time.
    timestamp: datetime = field(default_factory=datetime.now)


@dataclass
class WorkingSession:
    """
    A task-scoped working memory session.

    Each task gets its own session, which is discarded when the task ends.
    Call ``summarize()`` before disposal to keep a compact record of what
    happened (note counts, tools used, blockers, recent notes, errors).
    """

    session_id: str
    task_id: str
    agent_id: str
    created_at: datetime = field(default_factory=datetime.now)
    notes: list[WorkingNote] = field(default_factory=list)
    tool_outputs: list[ToolOutput] = field(default_factory=list)
    key_values: dict[str, Any] = field(default_factory=dict)
    # Set to False by WorkingMemory.discard(); this class does not check it.
    is_active: bool = True

    @staticmethod
    def _tail(items: list, count: int) -> list:
        """
        Return the last ``count`` items of ``items`` (a new list).

        ``count <= 0`` yields an empty list. A bare ``items[-count:]`` would
        return the *entire* list for ``count == 0`` because ``items[-0:]`` is
        ``items[0:]``.
        """
        if count <= 0:
            return []
        return items[-count:]

    def add_note(self, content: str, category: str = "observation") -> None:
        """
        Add a working note.

        Args:
            content: Note content
            category: Note category (observation, attempt, decision, blocker)
        """
        self.notes.append(WorkingNote(content=content, category=category))
        logger.debug("[%s] Note added: %s", self.session_id, category)

    def add_attempt(self, description: str) -> None:
        """Record an attempted approach (category "attempt")."""
        self.add_note(f"Attempted: {description}", "attempt")

    def add_blocker(self, description: str) -> None:
        """Record a blocker or issue (category "blocker")."""
        self.add_note(f"Blocker: {description}", "blocker")

    def add_decision(self, description: str) -> None:
        """Record a decision made during the task (category "decision")."""
        self.add_note(f"Decided: {description}", "decision")

    def add_tool_output(
        self,
        tool_name: str,
        output: str,
        command: Optional[str] = None,
        exit_code: Optional[int] = None,
    ) -> None:
        """
        Add a tool output.

        Args:
            tool_name: Name of the tool
            output: Output text
            command: Command that was run
            exit_code: Exit code if applicable
        """
        self.tool_outputs.append(
            ToolOutput(
                tool_name=tool_name,
                command=command,
                output=output,
                exit_code=exit_code,
            )
        )
        logger.debug("[%s] Tool output added: %s", self.session_id, tool_name)

    def set_value(self, key: str, value: Any) -> None:
        """Set a key-value pair in working memory."""
        self.key_values[key] = value

    def get_value(self, key: str, default: Any = None) -> Any:
        """Get a value from working memory, or ``default`` if the key is absent."""
        return self.key_values.get(key, default)

    def get_context(
        self,
        max_notes: int = 10,
        max_outputs: int = 5,
        max_output_chars: int = 500,
    ) -> dict[str, Any]:
        """
        Get the current working context.

        Args:
            max_notes: Maximum number of most-recent notes to include
                (0 or less includes none)
            max_outputs: Maximum number of most-recent tool outputs to include
                (0 or less includes none)
            max_output_chars: Each tool output's text is truncated to this
                many characters

        Returns:
            Context dictionary. ``key_values`` is a shallow copy, so adding or
            removing keys in it does not affect the session.
        """
        recent_notes = self._tail(self.notes, max_notes)
        recent_outputs = self._tail(self.tool_outputs, max_outputs)

        return {
            "session_id": self.session_id,
            "task_id": self.task_id,
            "agent_id": self.agent_id,
            "notes": [
                {"content": n.content, "category": n.category}
                for n in recent_notes
            ],
            "tool_outputs": [
                {
                    "tool": o.tool_name,
                    "command": o.command,
                    "output": o.output[:max_output_chars],
                    "exit_code": o.exit_code,
                }
                for o in recent_outputs
            ],
            "key_values": dict(self.key_values),
        }

    def get_notes_by_category(self, category: str) -> list[WorkingNote]:
        """Get notes filtered by category, in insertion order."""
        return [n for n in self.notes if n.category == category]

    def get_blockers(self) -> list[str]:
        """Get the content of all recorded blocker notes, oldest first."""
        return [b.content for b in self.get_notes_by_category("blocker")]

    def get_attempts(self) -> list[str]:
        """Get the content of all recorded attempt notes, oldest first."""
        return [a.content for a in self.get_notes_by_category("attempt")]

    def summarize(self) -> dict[str, Any]:
        """
        Summarize the working session for potential promotion.

        Returns:
            Summary dictionary suitable for creating a task summary. A tool
            output counts as an error when its exit code is set and non-zero;
            ``None`` (unknown) is not counted.
        """
        note_counts: dict[str, int] = {}
        for note in self.notes:
            note_counts[note.category] = note_counts.get(note.category, 0) + 1

        # dict.fromkeys de-duplicates while keeping first-use order, so the
        # result is stable across runs (iteration order of a set of str is
        # not, because string hashing is randomized per process).
        tools_used = list(dict.fromkeys(o.tool_name for o in self.tool_outputs))

        # Last few notes approximate the final state of the task.
        final_notes = self.notes[-3:]

        error_count = sum(
            1 for o in self.tool_outputs if o.exit_code not in (None, 0)
        )

        return {
            "session_id": self.session_id,
            "task_id": self.task_id,
            "agent_id": self.agent_id,
            "duration_seconds": (datetime.now() - self.created_at).total_seconds(),
            "note_counts": note_counts,
            "total_notes": len(self.notes),
            "total_tool_outputs": len(self.tool_outputs),
            "tools_used": tools_used,
            "blockers_encountered": self.get_blockers(),
            "final_notes": [n.content for n in final_notes],
            "had_errors": error_count > 0,
            "error_count": error_count,
        }

    def to_prompt_string(self, max_notes: int = 5, max_blockers: int = 3) -> str:
        """
        Convert working memory to a string for prompt injection.

        Args:
            max_notes: Maximum number of most-recent notes to include
                (0 or less includes none)
            max_blockers: Maximum number of most-recent blockers to include
                (0 or less includes none)

        Returns:
            Formatted string; always starts with the section header.
        """
        parts = ["## Working Memory (this session)"]

        recent_notes = self._tail(self.notes, max_notes)
        if recent_notes:
            parts.append("\nRecent observations:")
            parts.extend(f"- [{note.category}] {note.content}" for note in recent_notes)

        recent_blockers = self._tail(self.get_blockers(), max_blockers)
        if recent_blockers:
            parts.append("\nCurrent blockers:")
            parts.extend(f"- {blocker}" for blocker in recent_blockers)

        return "\n".join(parts)


class WorkingMemory:
    """
    Manager for working memory sessions.

    Creates and tracks per-task scratch sessions, keyed by session ID, plus an
    index from task ID to the most recently created session for that task.
    A discarded session is removed from both maps.
    """

    def __init__(self):
        """Initialize the Working Memory manager with no sessions."""
        self._sessions: dict[str, WorkingSession] = {}
        # task_id -> session_id of the most recently created session for it.
        self._task_to_session: dict[str, str] = {}

    def create_session(self, task_id: str, agent_id: str) -> WorkingSession:
        """
        Create a new working session for a task.

        If the task already has a live session, the new one replaces it in the
        task index; the earlier session remains retrievable by its ID and must
        still be discarded explicitly.

        Args:
            task_id: The task ID
            agent_id: The agent ID

        Returns:
            New WorkingSession
        """
        session_id = f"ws-{uuid.uuid4().hex[:12]}"
        session = WorkingSession(
            session_id=session_id,
            task_id=task_id,
            agent_id=agent_id,
        )

        previous_id = self._task_to_session.get(task_id)
        if previous_id is not None and previous_id in self._sessions:
            logger.warning(
                "Task %s already has working session %s; task now maps to %s",
                task_id,
                previous_id,
                session_id,
            )

        self._sessions[session_id] = session
        self._task_to_session[task_id] = session_id

        logger.info("Created working session %s for task %s", session_id, task_id)

        return session

    def get_session(self, session_id: str) -> Optional[WorkingSession]:
        """Get a session by ID, or None if unknown or already discarded."""
        return self._sessions.get(session_id)

    def get_session_for_task(self, task_id: str) -> Optional[WorkingSession]:
        """Get the most recent live session for a task, or None."""
        session_id = self._task_to_session.get(task_id)
        if session_id:
            return self._sessions.get(session_id)
        return None

    def discard(self, session_id: str) -> Optional[dict[str, Any]]:
        """
        Discard a working session.

        Args:
            session_id: Session to discard

        Returns:
            Summary of the session before disposal, or None if the session ID
            is unknown (e.g. already discarded)
        """
        session = self._sessions.get(session_id)
        if session is None:
            return None

        # Generate summary before discarding
        summary = session.summarize()

        session.is_active = False
        del self._sessions[session_id]

        # Only drop the task index entry if it still points at *this* session;
        # if a newer session was created for the same task, its mapping must
        # survive so get_session_for_task/discard_for_task keep working.
        if self._task_to_session.get(session.task_id) == session_id:
            del self._task_to_session[session.task_id]

        logger.info("Discarded working session %s", session_id)

        return summary

    def discard_for_task(self, task_id: str) -> Optional[dict[str, Any]]:
        """Discard the most recent session for a task; returns its summary or None."""
        session_id = self._task_to_session.get(task_id)
        if session_id:
            return self.discard(session_id)
        return None

    def get_active_sessions(self) -> list[WorkingSession]:
        """Get all tracked sessions whose ``is_active`` flag is set."""
        return [s for s in self._sessions.values() if s.is_active]

    def cleanup_old_sessions(self, max_age_hours: float = 24) -> int:
        """
        Discard sessions created more than ``max_age_hours`` ago.

        Args:
            max_age_hours: Maximum age in hours (fractional values allowed)

        Returns:
            Number of sessions cleaned up
        """
        now = datetime.now()
        # Collect IDs first: discard() mutates self._sessions.
        stale_ids = [
            session_id
            for session_id, session in self._sessions.items()
            if (now - session.created_at).total_seconds() / 3600 > max_age_hours
        ]

        for session_id in stale_ids:
            self.discard(session_id)

        if stale_ids:
            logger.info("Cleaned up %d old working sessions", len(stale_ids))

        return len(stale_ids)
