"""
Shared Memory - Shared state accessible by multiple agents.

Provides:
- Coroutine-safe key-value storage (asyncio locks; not safe across OS threads)
- Lock-based write coordination
- Change history tracking
- Wait-for-key capability
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


@dataclass
class StateChange:
    """A single recorded mutation of one SharedMemory key."""

    key: str
    old_value: Any
    new_value: Any
    agent_id: str
    timestamp: datetime = field(default_factory=datetime.now)
    operation: str = "write"  # one of: write, delete, clear

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this change.

        The timestamp is rendered with ``datetime.isoformat``; values are
        included as-is (not copied).
        """
        payload: dict[str, Any] = dict(
            key=self.key,
            old_value=self.old_value,
            new_value=self.new_value,
            agent_id=self.agent_id,
        )
        payload["timestamp"] = self.timestamp.isoformat()
        payload["operation"] = self.operation
        return payload


class SharedMemory:
    """
    Shared state accessible by all agents in a workflow.

    Access is coordinated with asyncio locks, so it is safe across coroutines
    running on one event loop (asyncio primitives are not thread-safe). Also
    provides change tracking, change callbacks, and wait-for-key capabilities.

    Inspired by Agent-MCP coordination patterns.

    Example:
        memory = SharedMemory("workflow-123")

        # Write value
        await memory.write("config", {"timeout": 30}, agent_id="claude-code")

        # Read value
        config = await memory.read("config")

        # Wait for another agent to set a value
        result = await memory.wait_for("analysis_complete", timeout=60.0)
    """

    def __init__(self, workflow_id: str):
        """
        Initialize shared memory for a workflow.

        Args:
            workflow_id: Unique workflow identifier
        """
        self.workflow_id = workflow_id
        self._state: dict[str, Any] = {}
        # Per-key locks, created lazily under _global_lock (see _get_lock).
        self._locks: dict[str, asyncio.Lock] = {}
        self._global_lock = asyncio.Lock()
        # Append-only log of every change; grows without bound.
        self._history: list[StateChange] = []
        # Events set by write() and awaited by wait_for(), keyed by state key.
        self._watchers: dict[str, list[asyncio.Event]] = {}
        self._callbacks: list[Callable[[StateChange], None]] = []

    async def _get_lock(self, key: str) -> asyncio.Lock:
        """Return the lock for *key*, creating it under the global lock if needed."""
        async with self._global_lock:
            lock = self._locks.get(key)
            if lock is None:
                lock = self._locks[key] = asyncio.Lock()
        return lock

    async def read(self, key: str, default: Any = None) -> Any:
        """
        Read value from shared memory.

        Args:
            key: Key to read
            default: Default value if key doesn't exist

        Returns:
            Value or default
        """
        return self._state.get(key, default)

    async def read_all(self) -> dict[str, Any]:
        """
        Read all values from shared memory.

        Returns:
            Shallow copy of entire state
        """
        return dict(self._state)

    async def write(
        self,
        key: str,
        value: Any,
        agent_id: str,
    ) -> None:
        """
        Write value to shared memory under the key's lock.

        Records the change in history, invokes change callbacks and wakes
        any coroutines blocked in ``wait_for(key)``.

        Args:
            key: Key to write
            value: Value to write
            agent_id: Agent performing the write
        """
        lock = await self._get_lock(key)

        async with lock:
            old_value = self._state.get(key)
            self._state[key] = value

            change = StateChange(
                key=key,
                old_value=old_value,
                new_value=value,
                agent_id=agent_id,
                operation="write",
            )
            self._history.append(change)
            self._notify_callbacks(change)
            self._notify_watchers(key)

            logger.debug("SharedMemory[%s]: %s wrote %s", self.workflow_id, agent_id, key)

    async def delete(self, key: str, agent_id: str) -> bool:
        """
        Delete a key from shared memory.

        Args:
            key: Key to delete
            agent_id: Agent performing the delete

        Returns:
            True if key existed
        """
        # Cheap fast path; re-checked under the lock below.
        if key not in self._state:
            return False

        lock = await self._get_lock(key)

        async with lock:
            if key not in self._state:
                return False

            old_value = self._state.pop(key)

            change = StateChange(
                key=key,
                old_value=old_value,
                new_value=None,
                agent_id=agent_id,
                operation="delete",
            )
            self._history.append(change)
            self._notify_callbacks(change)

            logger.debug("SharedMemory[%s]: %s deleted %s", self.workflow_id, agent_id, key)
            return True

    async def clear(self, agent_id: str) -> None:
        """
        Clear all state.

        One "clear" change per removed key is recorded in history and sent
        to the registered change callbacks, like write() and delete().

        Args:
            agent_id: Agent performing the clear
        """
        async with self._global_lock:
            old_state = dict(self._state)
            self._state.clear()

            for key, value in old_state.items():
                change = StateChange(
                    key=key,
                    old_value=value,
                    new_value=None,
                    agent_id=agent_id,
                    operation="clear",
                )
                self._history.append(change)
                # Previously callbacks were skipped for clears, unlike
                # write/delete; observers now see every state change.
                self._notify_callbacks(change)

            logger.debug("SharedMemory[%s]: %s cleared all state", self.workflow_id, agent_id)

    async def wait_for(
        self,
        key: str,
        timeout: float = 60.0,
        check_interval: float = 0.1,
    ) -> Any:
        """
        Wait for a key to be set.

        Args:
            key: Key to wait for
            timeout: Maximum wait time in seconds
            check_interval: Unused; kept for backward compatibility (the wait
                is event-driven, woken by write()).

        Returns:
            Value once set. If the key is deleted between the write that woke
            this waiter and the waiter resuming, None is returned.

        Raises:
            TimeoutError: If timeout exceeded
        """
        if key in self._state:
            return self._state[key]

        event = asyncio.Event()
        self._watchers.setdefault(key, []).append(event)

        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return self._state.get(key)
        except asyncio.TimeoutError:
            raise TimeoutError(f"Timeout waiting for key: {key}") from None
        finally:
            self._remove_watcher(key, event)

    def _remove_watcher(self, key: str, event: asyncio.Event) -> None:
        """Drop *event* from *key*'s watchers, deleting the entry once empty."""
        waiters = self._watchers.get(key)
        if waiters is None:
            return
        remaining = [e for e in waiters if e is not event]
        if remaining:
            self._watchers[key] = remaining
        else:
            # Avoid accumulating empty lists for keys that timed out.
            del self._watchers[key]

    async def wait_for_condition(
        self,
        predicate: Callable[[dict[str, Any]], bool],
        timeout: float = 60.0,
        check_interval: float = 0.5,
    ) -> bool:
        """
        Wait for a condition to be true by polling.

        Args:
            predicate: Function that takes the (live) state dict and returns bool
            timeout: Maximum wait time in seconds
            check_interval: Seconds to sleep between checks

        Returns:
            True if condition met, False if timeout
        """
        # Use the event loop's monotonic clock so wall-clock adjustments
        # cannot shorten or extend the wait.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if predicate(self._state):
                return True
            await asyncio.sleep(check_interval)
        return False

    def _notify_watchers(self, key: str) -> None:
        """Wake every coroutine waiting in wait_for() on *key*."""
        for event in self._watchers.get(key, ()):
            event.set()

    def _notify_callbacks(self, change: StateChange) -> None:
        """Invoke each registered callback; a failing callback is logged, not raised."""
        for callback in self._callbacks:
            try:
                callback(change)
            except Exception as e:
                # Best-effort delivery: keep notifying the remaining
                # callbacks, but keep the traceback for debugging.
                logger.exception("Error in shared memory callback: %s", e)

    def on_change(self, callback: Callable[[StateChange], None]) -> None:
        """Register callback for state changes."""
        self._callbacks.append(callback)

    def has_key(self, key: str) -> bool:
        """Check if key exists."""
        return key in self._state

    def keys(self) -> list[str]:
        """Get all keys."""
        return list(self._state.keys())

    def get_history(
        self,
        key: Optional[str] = None,
        agent_id: Optional[str] = None,
        limit: int = 100,
    ) -> list[StateChange]:
        """
        Get change history.

        Args:
            key: Filter by key (None means no filter; "" is a valid key)
            agent_id: Filter by agent (None means no filter)
            limit: Maximum results; the most recent ``limit`` changes are
                returned. A limit <= 0 returns an empty list.

        Returns:
            New list of state changes, oldest first
        """
        # results[-0:] would be the whole list, so handle non-positive limits explicitly.
        if limit <= 0:
            return []

        results = self._history

        if key is not None:
            results = [c for c in results if c.key == key]
        if agent_id is not None:
            results = [c for c in results if c.agent_id == agent_id]

        return results[-limit:]

    def get_stats(self) -> dict[str, Any]:
        """Get memory statistics."""
        return {
            "workflow_id": self.workflow_id,
            "key_count": len(self._state),
            "total_changes": len(self._history),
            "keys": list(self._state.keys()),
        }


class SharedMemoryManager:
    """
    Registry of SharedMemory instances, one per workflow ID.

    Instances are created on demand and kept until explicitly deleted.

    Example:
        manager = SharedMemoryManager()

        # Get or create memory for a workflow
        memory = manager.get_or_create("workflow-123")

        # Write value
        await memory.write("key", "value", "agent-1")
    """

    def __init__(self):
        """Start with an empty registry."""
        self._memories: dict[str, SharedMemory] = {}

    def get_or_create(self, workflow_id: str) -> SharedMemory:
        """
        Return the SharedMemory for *workflow_id*, creating it on first request.

        Args:
            workflow_id: Workflow identifier

        Returns:
            The (possibly newly created) SharedMemory instance
        """
        try:
            return self._memories[workflow_id]
        except KeyError:
            pass

        memory = self._memories[workflow_id] = SharedMemory(workflow_id)
        logger.info(f"Created shared memory for workflow {workflow_id}")
        return memory

    def get(self, workflow_id: str) -> Optional[SharedMemory]:
        """Return the registered SharedMemory, or None if there is none."""
        return self._memories.get(workflow_id)

    def delete(self, workflow_id: str) -> bool:
        """Unregister a workflow's memory; return whether it was registered."""
        if workflow_id not in self._memories:
            return False

        del self._memories[workflow_id]
        logger.info(f"Deleted shared memory for workflow {workflow_id}")
        return True

    def list_workflows(self) -> list[str]:
        """Return the IDs of all registered workflows."""
        return [*self._memories]

    def get_stats(self) -> dict[str, Any]:
        """Summarize the registry and each workflow's own statistics."""
        per_workflow: dict[str, Any] = {}
        for wf_id, memory in self._memories.items():
            per_workflow[wf_id] = memory.get_stats()

        return {
            "total_workflows": len(self._memories),
            "workflows": per_workflow,
        }


# Process-wide manager; created lazily by get_shared_memory_manager().
_manager: Optional[SharedMemoryManager] = None


def get_shared_memory_manager() -> SharedMemoryManager:
    """Return the module-wide SharedMemoryManager, creating it on first call."""
    global _manager
    if _manager is not None:
        return _manager
    _manager = SharedMemoryManager()
    return _manager


def set_shared_memory_manager(manager: Optional[SharedMemoryManager]) -> None:
    """
    Install *manager* as the module-wide SharedMemoryManager.

    Args:
        manager: Manager to install; None resets it so the next
            get_shared_memory_manager() call creates a fresh one.
    """
    global _manager
    _manager = manager
