#!/usr/bin/env python3
"""
Web Search Integration - Tavily API Interface

Provides web search capabilities to fill research gaps by fetching external sources.
Integrates with the Research Agent for automatic gap-filling and supports manual
searches from the TUI Dashboard.

Features:
- Advanced search with full content extraction
- Batch search for multiple gap queries
- Results formatting for display and project storage
- Credit tracking with configurable limits
- Cost estimation before searches
- Graceful degradation when API unavailable
"""

import os
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field

# Logger named after this module.
logger = logging.getLogger(__name__)

# Tavily settings are read from the project's ``config`` module. If that
# module (or any of these names) cannot be imported, web search is disabled
# and the defaults below are used instead.
try:
    from config import TAVILY_API_KEY, TAVILY_ENABLED, TAVILY_CONFIG
except ImportError:
    TAVILY_API_KEY, TAVILY_ENABLED = "", False
    TAVILY_CONFIG = dict(
        search_depth='advanced',
        max_results=10,
        include_answer=True,
        include_raw_content=True,
        timeout=30,
    )

# =============================================================================
# CREDIT TRACKING (PostgreSQL-backed for multi-user/multi-machine support)
# =============================================================================

# Credits charged per search, keyed by search depth (approximate Tavily pricing).
CREDIT_COSTS = dict(basic=1, advanced=2)

# Limit used when neither the caller, the config, nor stored state supplies one.
DEFAULT_CREDIT_LIMIT = 1000

# Legacy JSON store (<this file's grandparent dir>/data/tavily_credits.json).
# CreditTracker reads/writes it when the database is unavailable.
CREDIT_TRACKING_FILE = Path(__file__).parent.parent.joinpath('data', 'tavily_credits.json')


@dataclass
class CreditUsage:
    """Tracks Tavily credit usage.

    ``history`` retains at most ``_HISTORY_LIMIT`` of the most recent
    per-search records, both in memory and when serialized, so a
    long-running process does not accumulate an unbounded list.
    """
    total_used: int = 0
    limit: int = DEFAULT_CREDIT_LIMIT
    searches_count: int = 0
    last_search_at: Optional[str] = None
    history: List[Dict[str, Any]] = field(default_factory=list)

    # Cap on retained history entries. Deliberately unannotated so that
    # @dataclass treats it as a plain class attribute, not a field.
    _HISTORY_LIMIT = 100

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (history trimmed to the cap)."""
        return {
            'total_used': self.total_used,
            'limit': self.limit,
            'searches_count': self.searches_count,
            'last_search_at': self.last_search_at,
            'history': self.history[-self._HISTORY_LIMIT:],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CreditUsage':
        """Build an instance from a dict produced by ``to_dict``.

        Missing keys fall back to the dataclass defaults. The history is
        copied so the instance never aliases the caller's list, and a
        ``null`` history becomes an empty list instead of ``None``.
        """
        history = data.get('history') or []
        return cls(
            total_used=data.get('total_used', 0),
            limit=data.get('limit', DEFAULT_CREDIT_LIMIT),
            searches_count=data.get('searches_count', 0),
            last_search_at=data.get('last_search_at'),
            history=list(history)[-cls._HISTORY_LIMIT:],
        )

    @property
    def remaining(self) -> int:
        """Credits remaining before limit (never negative)."""
        return max(0, self.limit - self.total_used)

    @property
    def at_limit(self) -> bool:
        """Check if at or over limit."""
        return self.total_used >= self.limit

    def can_afford(self, cost: int) -> bool:
        """Check if we can afford a search with given cost."""
        return self.remaining >= cost

    def record_search(self, query: str, cost: int, results_count: int) -> None:
        """Record a search and its cost.

        Args:
            query: Search query (only the first 100 characters are kept).
            cost: Credits charged for the search.
            results_count: Number of results returned.
        """
        self.total_used += cost
        self.searches_count += 1
        self.last_search_at = datetime.now().isoformat()
        self.history.append({
            'query': query[:100],
            'cost': cost,
            'results': results_count,
            'timestamp': self.last_search_at,
        })
        # Drop the oldest entries beyond the cap (no-op while under it).
        del self.history[:-self._HISTORY_LIMIT]


class CreditTracker:
    """Manages Tavily credit tracking with PostgreSQL persistence.

    Falls back to local JSON file if database is unavailable.

    In database mode ``usage`` is re-read from PostgreSQL before limit
    checks, summaries and recording, so several processes share one
    counter. A limit passed to the constructor (or configured via
    ``TAVILY_CONFIG['credit_limit']``) is re-applied after every re-read,
    so it is not silently replaced by the stored limit.
    """

    def __init__(self, limit: Optional[int] = None):
        """Initialize credit tracker.

        Args:
            limit: Credit limit override. When None,
                ``TAVILY_CONFIG['credit_limit']`` is used if set; otherwise the
                stored limit applies (database or file, defaulting to
                DEFAULT_CREDIT_LIMIT).
        """
        self._use_database = self._check_database()
        if limit is None:
            limit = TAVILY_CONFIG.get('credit_limit')
        # Locally configured limit; kept so database refreshes don't drop it.
        self._limit_override: Optional[int] = limit
        self.usage = self._load_usage()
        self._apply_limit_override()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _connect():
        """Open a new connection via the project's ``db_utils`` helper."""
        from db_utils import get_db_connection
        return get_db_connection()

    def _apply_limit_override(self) -> None:
        """Re-apply the locally configured limit, if any, to ``usage``."""
        if self._limit_override is not None:
            self.usage.limit = self._limit_override

    def _check_database(self) -> bool:
        """Check if database credit tracking is available.

        Returns:
            True when a connection succeeds and a ``system_settings`` table
            exists; False on any error or when the table is missing.
        """
        try:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_name = 'system_settings'
                    )
                """)
                exists = cur.fetchone()[0]
            finally:
                # Always release the connection, even if the query failed.
                conn.close()
        except Exception as e:
            logger.debug(f"Database credit tracking unavailable: {e}")
            return False

        if exists:
            logger.debug("Using PostgreSQL for credit tracking")
            return True
        return False

    def _load_usage(self) -> CreditUsage:
        """Load usage from database or disk."""
        if self._use_database:
            return self._load_from_database()
        return self._load_from_file()

    def _fetch_from_database(self) -> CreditUsage:
        """Read usage from PostgreSQL; raises on any database error.

        Returns a default CreditUsage when no ``tavily_credits`` row exists.
        """
        conn = self._connect()
        try:
            cur = conn.cursor()
            cur.execute("""
                SELECT setting_value
                FROM system_settings
                WHERE setting_key = 'tavily_credits'
            """)
            row = cur.fetchone()
            if not row:
                return CreditUsage()

            data = row[0]
            # Depending on the column type the driver may return JSON text
            # rather than a decoded object.
            if isinstance(data, (str, bytes)):
                data = json.loads(data)
            data = data or {}

            usage = CreditUsage(
                total_used=data.get('used', 0),
                limit=data.get('limit', DEFAULT_CREDIT_LIMIT),
            )

            # Search count and last-search time come from usage_logs.
            cur.execute("""
                SELECT COUNT(*), MAX(created_at)
                FROM usage_logs
                WHERE service = 'tavily' AND success = TRUE
            """)
            count_row = cur.fetchone()
            if count_row:
                usage.searches_count = count_row[0] or 0
                if count_row[1]:
                    usage.last_search_at = count_row[1].isoformat()
            return usage
        finally:
            conn.close()

    def _load_from_database(self) -> CreditUsage:
        """Load usage from PostgreSQL, returning defaults on any failure."""
        try:
            return self._fetch_from_database()
        except Exception as e:
            logger.warning(f"Could not load credits from database: {e}")
            return CreditUsage()

    def _refresh_from_database(self) -> None:
        """Replace in-memory usage with the latest database state.

        On failure the current in-memory usage is kept. Substituting a zeroed
        CreditUsage here would reset the counter, and the next save would
        overwrite the real total in the database.
        """
        try:
            self.usage = self._fetch_from_database()
        except Exception as e:
            logger.warning(f"Could not refresh credits from database: {e}")
        self._apply_limit_override()

    def _load_from_file(self) -> CreditUsage:
        """Load usage from local JSON file (fallback)."""
        try:
            CREDIT_TRACKING_FILE.parent.mkdir(parents=True, exist_ok=True)
            if CREDIT_TRACKING_FILE.exists():
                with open(CREDIT_TRACKING_FILE, encoding='utf-8') as f:
                    return CreditUsage.from_dict(json.load(f))
        except Exception as e:
            logger.warning(f"Could not load credit usage from file: {e}")
        return CreditUsage()

    def _save_usage(self) -> None:
        """Save usage to database or disk."""
        if self._use_database:
            self._save_to_database()
        else:
            self._save_to_file()

    def _save_to_database(self) -> None:
        """Upsert usage into PostgreSQL; falls back to the JSON file on error."""
        payload = json.dumps({
            'used': self.usage.total_used,
            'limit': self.usage.limit,
            'reset_date': None,
        })
        try:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute("""
                    INSERT INTO system_settings (setting_key, setting_value, description)
                    VALUES ('tavily_credits', %s, 'Tavily API credit tracking')
                    ON CONFLICT (setting_key) DO UPDATE
                    SET setting_value = EXCLUDED.setting_value,
                        updated_at = CURRENT_TIMESTAMP
                """, (payload,))
                conn.commit()
            finally:
                conn.close()
        except Exception as e:
            logger.warning(f"Could not save credits to database: {e}")
            # Fallback to file
            self._save_to_file()

    def _save_to_file(self) -> None:
        """Save usage to local JSON file."""
        try:
            CREDIT_TRACKING_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(CREDIT_TRACKING_FILE, 'w', encoding='utf-8') as f:
                json.dump(self.usage.to_dict(), f, indent=2)
        except Exception as e:
            logger.warning(f"Could not save credit usage to file: {e}")

    def _log_to_database(self, query: str, cost: int, results_count: int,
                         session_id: Optional[str] = None,
                         project_id: Optional[int] = None) -> None:
        """Log search to usage_logs table (no-op outside database mode)."""
        if not self._use_database:
            return

        try:
            conn = self._connect()
            try:
                cur = conn.cursor()
                cur.execute("""
                    INSERT INTO usage_logs
                    (service, operation, query_text, result_count, credits_used, session_id, project_id)
                    VALUES ('tavily', 'search', %s, %s, %s, %s, %s)
                """, (query[:500], results_count, cost, session_id, project_id))
                conn.commit()
            finally:
                conn.close()
        except Exception as e:
            logger.warning(f"Could not log usage to database: {e}")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def estimate_cost(self, search_depth: str = 'advanced', num_searches: int = 1) -> int:
        """Estimate cost for planned searches.

        Args:
            search_depth: 'basic' or 'advanced' (unknown depths cost 2)
            num_searches: Number of searches planned

        Returns:
            Estimated credit cost
        """
        cost_per_search = CREDIT_COSTS.get(search_depth, 2)
        return cost_per_search * num_searches

    def can_search(self, search_depth: str = 'advanced', num_searches: int = 1) -> Tuple[bool, str]:
        """Check if searches can be performed within limit.

        Args:
            search_depth: 'basic' or 'advanced'
            num_searches: Number of searches planned

        Returns:
            Tuple of (can_proceed, reason_message)
        """
        # Refresh so limits reflect other users/machines (multi-user support).
        if self._use_database:
            self._refresh_from_database()

        cost = self.estimate_cost(search_depth, num_searches)

        if self.usage.at_limit:
            return False, f"Credit limit reached ({self.usage.total_used}/{self.usage.limit})"

        if not self.usage.can_afford(cost):
            return False, (
                f"Insufficient credits: need {cost}, have {self.usage.remaining} "
                f"({self.usage.total_used}/{self.usage.limit} used)"
            )

        return True, f"OK: {cost} credits (remaining: {self.usage.remaining - cost})"

    def record_search(self, query: str, search_depth: str, results_count: int,
                      session_id: Optional[str] = None,
                      project_id: Optional[int] = None) -> None:
        """Record a completed search.

        Args:
            query: The search query
            search_depth: 'basic' or 'advanced'
            results_count: Number of results returned
            session_id: Optional session identifier
            project_id: Optional project ID
        """
        if self._use_database:
            # Re-read just before incrementing to shrink the window in which
            # another process's update could be overwritten. This is still a
            # read-modify-write, not an atomic increment.
            self._refresh_from_database()
        cost = self.estimate_cost(search_depth, 1)
        self.usage.record_search(query, cost, results_count)
        self._save_usage()
        self._log_to_database(query, cost, results_count, session_id, project_id)
        logger.info(f"Tavily credit used: {cost} (total: {self.usage.total_used}/{self.usage.limit})")

    def set_limit(self, new_limit: int) -> None:
        """Update the credit limit.

        Args:
            new_limit: New credit limit
        """
        # Keep a local override in sync so later refreshes don't revert it.
        if self._limit_override is not None:
            self._limit_override = new_limit
        self.usage.limit = new_limit
        self._save_usage()
        logger.info(f"Tavily credit limit set to: {new_limit}")

    def reset_usage(self) -> None:
        """Reset usage counter (e.g., at start of new billing period)."""
        self.usage.total_used = 0
        self.usage.searches_count = 0
        self.usage.history = []
        self._save_usage()

        # Also reset in database if available
        if self._use_database:
            try:
                conn = self._connect()
                try:
                    cur = conn.cursor()
                    cur.execute("SELECT reset_tavily_credits()")
                    conn.commit()
                finally:
                    conn.close()
            except Exception as e:
                logger.warning(f"Could not reset credits in database: {e}")

        logger.info("Tavily credit usage reset")

    def get_summary(self) -> Dict[str, Any]:
        """Get usage summary.

        Returns:
            Dictionary with usage statistics
        """
        # Refresh from database for latest state
        if self._use_database:
            self._refresh_from_database()

        return {
            'used': self.usage.total_used,
            'limit': self.usage.limit,
            'remaining': self.usage.remaining,
            'searches': self.usage.searches_count,
            'at_limit': self.usage.at_limit,
            'last_search': self.usage.last_search_at,
            'storage': 'database' if self._use_database else 'file',
        }


# Lazily created, process-wide tracker shared by the module-level helpers.
_credit_tracker: Optional[CreditTracker] = None


def get_credit_tracker() -> CreditTracker:
    """Return the shared CreditTracker, constructing it on first use."""
    global _credit_tracker
    if _credit_tracker is not None:
        return _credit_tracker
    _credit_tracker = CreditTracker()
    return _credit_tracker


@dataclass
class WebSearchResult:
    """Represents a single web search result."""
    url: str
    title: str
    content: str
    raw_content: Optional[str] = None
    score: float = 0.0
    query: str = ""
    retrieved_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for storage."""
        return {
            'url': self.url,
            'title': self.title,
            'content': self.content,
            'raw_content': self.raw_content,
            'score': self.score,
            'query': self.query,
            'retrieved_at': self.retrieved_at,
        }

    @classmethod
    def from_tavily_result(cls, result: Dict, query: str = "") -> 'WebSearchResult':
        """Create from a Tavily API response item.

        Missing or ``null`` values fall back to defaults ('' for url/content,
        'Untitled' for an empty title, 0.0 for score). ``dict.get(key, default)``
        alone would let an explicit ``null`` through as None and later break
        formatting.

        Args:
            result: One entry from the response's ``results`` list.
            query: The query that produced this result.
        """
        score = result.get('score')
        return cls(
            url=result.get('url') or '',
            title=result.get('title') or 'Untitled',
            content=result.get('content') or '',
            raw_content=result.get('raw_content'),
            score=score if score is not None else 0.0,
            query=query,
        )


@dataclass
class WebSearchResponse:
    """Outcome of one web search: the query, optional AI answer, hits, error and timing."""
    query: str
    answer: Optional[str] = None
    results: List[WebSearchResult] = field(default_factory=list)
    error: Optional[str] = None
    search_time_ms: int = 0

    @property
    def result_count(self) -> int:
        """How many results were returned."""
        return len(self.results)

    @property
    def success(self) -> bool:
        """True only when no error is set and at least one result came back."""
        if self.error is not None:
            return False
        return self.result_count > 0


class TavilySearchInterface:
    """Interface for Tavily web search API.

    Provides methods for:
    - Single query search
    - Batch search for multiple gaps
    - Result formatting for display and storage
    """

    def __init__(self):
        """Initialize the Tavily client if search is enabled.

        Sets ``enabled`` to False (leaving ``client`` as None) when the
        ``tavily`` package is missing or the client cannot be constructed.
        """
        self.enabled = TAVILY_ENABLED
        self.config = TAVILY_CONFIG
        self.client = None

        if self.enabled:
            try:
                from tavily import TavilyClient
                self.client = TavilyClient(api_key=TAVILY_API_KEY)
                logger.info("Tavily web search client initialized")
            except ImportError:
                logger.warning(
                    "tavily-python not installed. "
                    "Install with: pip install tavily-python"
                )
                self.enabled = False
            except Exception as e:
                logger.error(f"Failed to initialize Tavily client: {e}")
                self.enabled = False
        else:
            logger.info("Tavily web search disabled (no API key configured)")

    def search(
        self,
        query: str,
        max_results: Optional[int] = None,
        search_depth: Optional[str] = None,
        include_answer: Optional[bool] = None,
        check_credits: bool = True,
    ) -> 'WebSearchResponse':
        """Execute a web search query.

        Args:
            query: The search query string
            max_results: Maximum number of results (default from config)
            search_depth: 'basic' or 'advanced' (default from config)
            include_answer: Include AI-generated answer (default from config)
            check_credits: Check credit limits before searching (default: True)

        Returns:
            WebSearchResponse with results, or with ``error`` set when search is
            unavailable, blocked by the credit limit, or the API call fails.
        """
        if not self.enabled or self.client is None:
            return WebSearchResponse(
                query=query,
                error="Tavily web search is not available. Configure TAVILY_API_KEY in .env"
            )

        # Fall back to config, then to built-in defaults, so a TAVILY_CONFIG
        # that omits a key cannot raise KeyError here.
        max_results = max_results or self.config.get('max_results', 10)
        search_depth = search_depth or self.config.get('search_depth', 'advanced')
        if include_answer is None:
            include_answer = self.config.get('include_answer', True)

        # Check credit limits before searching
        if check_credits:
            tracker = get_credit_tracker()
            can_proceed, reason = tracker.can_search(search_depth, 1)
            if not can_proceed:
                logger.warning(f"Tavily search blocked: {reason}")
                return WebSearchResponse(
                    query=query,
                    error=f"Credit limit: {reason}"
                )

        try:
            import time
            # perf_counter is monotonic, so the measured duration is not
            # affected by wall-clock adjustments during the request.
            start_time = time.perf_counter()

            response = self.client.search(
                query=query,
                search_depth=search_depth,
                max_results=max_results,
                include_answer=include_answer,
                include_raw_content=self.config.get('include_raw_content', True),
            )

            elapsed_ms = int((time.perf_counter() - start_time) * 1000)

            # A null "results" field is treated like an empty list.
            results = [
                WebSearchResult.from_tavily_result(r, query)
                for r in (response.get('results') or [])
            ]

            # Record credit usage only after the API call succeeded.
            get_credit_tracker().record_search(query, search_depth, len(results))

            return WebSearchResponse(
                query=query,
                answer=response.get('answer'),
                results=results,
                search_time_ms=elapsed_ms,
            )

        except Exception as e:
            logger.error(f"Tavily search failed for '{query}': {e}")
            return WebSearchResponse(
                query=query,
                error=str(e)
            )

    def search_for_gaps(
        self,
        gaps: List[str],
        max_results_per_gap: int = 5,
        check_credits: bool = True,
    ) -> 'Tuple[Dict[str, WebSearchResponse], Dict[str, Any]]':
        """Search for multiple gap queries.

        Args:
            gaps: List of gap descriptions/queries
            max_results_per_gap: Maximum results per gap query
            check_credits: Check credit limits before searching

        Returns:
            Tuple of:
            - Dictionary mapping gap text to WebSearchResponse
            - Credit summary with 'before', 'after', 'searched', 'total_gaps'
              and 'credits_used' keys; a blocked batch also carries 'error'.
        """
        search_depth = self.config.get('search_depth', 'advanced')
        tracker = get_credit_tracker()

        # Pre-search credit check
        credit_summary_before = tracker.get_summary()

        if check_credits:
            can_proceed, reason = tracker.can_search(search_depth, len(gaps))
            if not can_proceed:
                logger.warning(f"Gap search blocked: {reason}")
                # Same keys as the success path so callers can read the
                # summary uniformly; nothing was spent.
                return {}, {
                    'error': reason,
                    'before': credit_summary_before,
                    'after': credit_summary_before,
                    'searched': 0,
                    'total_gaps': len(gaps),
                    'credits_used': 0,
                }

        results = {}
        searched_count = 0

        for gap in gaps:
            # Re-check credits for each search (in case we run out mid-batch)
            if check_credits:
                can_proceed, _ = tracker.can_search(search_depth, 1)
                if not can_proceed:
                    logger.warning(f"Credit limit reached during gap search, stopped at {searched_count}/{len(gaps)}")
                    break

            # Clean up gap text for search
            search_query = self._gap_to_query(gap)
            results[gap] = self.search(
                query=search_query,
                max_results=max_results_per_gap,
                search_depth=search_depth,  # same depth the credit check used
                check_credits=False,  # Already checked above
            )
            searched_count += 1

        credit_summary_after = tracker.get_summary()

        return results, {
            'before': credit_summary_before,
            'after': credit_summary_after,
            'searched': searched_count,
            'total_gaps': len(gaps),
            'credits_used': credit_summary_after['used'] - credit_summary_before['used'],
        }

    def _gap_to_query(self, gap: str) -> str:
        """Convert a gap description to a search query.

        Strips the first matching boilerplate prefix (case-insensitively) and
        capitalizes the first character. The rest of the text keeps its
        original casing so acronyms and proper nouns survive.

        Args:
            gap: Gap description from analysis

        Returns:
            Cleaned search query string
        """
        # Remove common prefixes from gap analysis output
        prefixes_to_remove = (
            "need more information about",
            "missing information on",
            "research gap:",
            "gap:",
            "need to explore",
            "requires further research on",
        )

        query = gap.strip()
        for prefix in prefixes_to_remove:
            # Compare only the prefix-length slice so the match is
            # case-insensitive without lowercasing the whole query.
            if query[:len(prefix)].lower() == prefix:
                query = query[len(prefix):].strip()
                break

        # Capitalize first letter
        if query:
            query = query[0].upper() + query[1:]

        return query

    def format_result_as_note(self, result: 'WebSearchResult', max_raw_chars: int = 5000) -> str:
        """Format a single result as a project note.

        Args:
            result: WebSearchResult to format
            max_raw_chars: Truncate the full-content section beyond this many
                characters (default 5000)

        Returns:
            Formatted note string with metadata
        """
        content = result.content or ""
        note_lines = [
            f"## {result.title}",
            "",
            f"**URL:** {result.url}",
            f"**Retrieved:** {result.retrieved_at}",
            f"**Search Query:** {result.query}",
            # None (not "") so the join filter below drops the line entirely.
            f"**Relevance Score:** {result.score:.2f}" if result.score else None,
            "",
            "### Content",
            "",
            content,
        ]

        # Add raw content if available and different from snippet
        raw = result.raw_content
        if raw and raw != content:
            if len(raw) > max_raw_chars:
                raw = raw[:max_raw_chars] + "\n\n... [truncated]"

            note_lines.extend([
                "",
                "### Full Content",
                "",
                raw,
            ])

        return "\n".join(line for line in note_lines if line is not None)

    def format_results_markdown(
        self,
        response: 'WebSearchResponse',
        include_answer: bool = True
    ) -> str:
        """Format search results as markdown for display.

        Args:
            response: WebSearchResponse to format
            include_answer: Include the AI-generated answer if available

        Returns:
            Markdown formatted string
        """
        lines = [
            "# Web Search Results",
            "",
            f"**Query:** {response.query}",
            f"**Results:** {response.result_count} found",
            f"**Search Time:** {response.search_time_ms}ms",
            "",
        ]

        if response.error:
            lines.extend([
                "## Error",
                "",
                f"*{response.error}*",
                "",
            ])
            return "\n".join(lines)

        # Add AI answer if available
        if include_answer and response.answer:
            lines.extend([
                "## AI Summary",
                "",
                response.answer,
                "",
                "---",
                "",
            ])

        # Add individual results
        for i, result in enumerate(response.results, 1):
            lines.extend([
                f"### {i}. {result.title}",
                "",
                f"**URL:** `{result.url}`",
                "",
                result.content or "",
                "",
                # None (not "") so the join filter below drops the line.
                f"*Relevance: {result.score:.2f}*" if result.score else None,
                "",
                "---",
                "",
            ])

        return "\n".join(line for line in lines if line is not None)

    def format_gap_results_markdown(
        self,
        gap_results: 'Dict[str, WebSearchResponse]',
        max_sources_per_gap: int = 5,
    ) -> str:
        """Format results from gap search as markdown.

        Args:
            gap_results: Dictionary from search_for_gaps()
            max_sources_per_gap: Maximum source links listed per gap (default 5)

        Returns:
            Markdown formatted string organized by gap
        """
        lines = [
            "# Web Research for Identified Gaps",
            "",
        ]

        for gap, response in gap_results.items():
            lines.extend([
                f"## Gap: {gap}",
                "",
            ])

            if response.error:
                lines.extend([
                    f"*Search failed: {response.error}*",
                    "",
                ])
                continue

            if response.answer:
                lines.extend([
                    "**Summary:**",
                    response.answer,
                    "",
                ])

            if response.results:
                lines.append("**Sources:**")
                lines.append("")
                for result in response.results[:max_sources_per_gap]:
                    lines.append(f"- [{result.title}]({result.url})")
                lines.append("")

            lines.extend(["---", ""])

        return "\n".join(lines)


# Lazily created, process-wide search interface shared by the helpers below.
_search_interface: Optional[TavilySearchInterface] = None


def get_search_interface() -> TavilySearchInterface:
    """Return the shared TavilySearchInterface, creating it on first call."""
    global _search_interface
    if _search_interface is not None:
        return _search_interface
    _search_interface = TavilySearchInterface()
    return _search_interface


def web_search(query: str, max_results: int = 10) -> WebSearchResponse:
    """Run a one-off web search through the shared search interface.

    Args:
        query: Text to search for
        max_results: Upper bound on the number of results

    Returns:
        WebSearchResponse holding the results or an error
    """
    interface = get_search_interface()
    return interface.search(query, max_results=max_results)


def search_gaps(gaps: List[str]) -> Tuple[Dict[str, WebSearchResponse], Dict[str, Any]]:
    """Search the web for each gap description using the shared interface.

    Args:
        gaps: Gap descriptions to look up

    Returns:
        (mapping of gap -> response, credit usage summary)
    """
    interface = get_search_interface()
    return interface.search_for_gaps(gaps)


def estimate_gap_search_cost(num_gaps: int, search_depth: str = 'advanced') -> Dict[str, Any]:
    """Project the credit cost of a gap search against current usage.

    Args:
        num_gaps: How many gap queries would be searched
        search_depth: 'basic' or 'advanced'

    Returns:
        Dict with the estimated credits, current used/limit, remaining credits
        after the search (floored at 0), affordability, depth and search count.
    """
    tracker = get_credit_tracker()
    summary = tracker.get_summary()
    cost = tracker.estimate_cost(search_depth, num_gaps)
    remaining = summary['remaining']

    return {
        'estimated_credits': cost,
        'current_used': summary['used'],
        'current_limit': summary['limit'],
        'remaining_after': max(0, remaining - cost),
        'can_afford': remaining >= cost,
        'search_depth': search_depth,
        'num_searches': num_gaps,
    }


def get_credit_summary() -> Dict[str, Any]:
    """Return the shared tracker's current usage statistics."""
    tracker = get_credit_tracker()
    return tracker.get_summary()


def set_credit_limit(new_limit: int) -> None:
    """Change the credit limit on the shared tracker.

    Args:
        new_limit: Limit to apply and persist
    """
    tracker = get_credit_tracker()
    tracker.set_limit(new_limit)


def reset_credits() -> None:
    """Zero the shared tracker's usage (e.g. when a new billing period starts)."""
    tracker = get_credit_tracker()
    tracker.reset_usage()


# Public API of this module (controls ``from ... import *``).
__all__ = [
    # Search interface and result types
    'TavilySearchInterface',
    'WebSearchResult',
    'WebSearchResponse',
    # Credit tracking types
    'CreditTracker',
    'CreditUsage',
    # Singleton accessors
    'get_search_interface',
    'get_credit_tracker',
    # Convenience functions
    'web_search',
    'search_gaps',
    'estimate_gap_search_cost',
    'get_credit_summary',
    'set_credit_limit',
    'reset_credits',
    # Constants
    'TAVILY_ENABLED',
    'CREDIT_COSTS',
    'DEFAULT_CREDIT_LIMIT',
]
