#!/usr/bin/env python3
"""
Document Pinning System - Prioritize documents in research.

Pinned documents receive higher priority in:
- Search result ranking
- Context selection for synthesis
- Research agent source selection

Usage:
    python document_pinning.py --pin DOC_001 DOC_002
    python document_pinning.py --unpin DOC_001
    python document_pinning.py --list
    python document_pinning.py --set-priority DOC_001 --priority 10
"""

import argparse
import sys
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))

from db_utils import execute_query

# Configure logging. NOTE: this runs at import time, so any process that
# imports this module gets the root logger configured (basicConfig is a
# no-op if the root logger already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# =============================================================================
# CONFIGURATION
# =============================================================================

# Upper bound on simultaneously pinned documents; pin_document() refuses new
# pins once this many documents are pinned.
MAX_PINNED_DOCS = 50


# =============================================================================
# PINNING FUNCTIONS
# =============================================================================

def pin_document(
    document_id: str,
    priority: int = 0,
    notes: str = None
) -> bool:
    """
    Mark a document as pinned so research tooling can prioritise it.

    If the document is already pinned it stays pinned. When a non-zero
    priority or non-empty notes value is supplied, the stored values are
    refreshed; a priority of 0 or notes of None leaves the corresponding
    stored column unchanged. New pins are refused once MAX_PINNED_DOCS
    documents are already pinned.

    Args:
        document_id: ID of the document to pin
        priority: Pin priority (higher = more important)
        notes: Optional free-text reason for pinning

    Returns:
        True if the document is pinned after the call; False if it does not
        exist or the pin limit has been reached
    """
    lookup_params = (document_id,)
    row = execute_query(
        "SELECT document_id, title, is_pinned FROM documents WHERE document_id = %s",
        lookup_params,
        fetch='one'
    )

    if not row:
        logger.error(f"Document not found: {document_id}")
        return False

    if row['is_pinned']:
        logger.info(f"Document already pinned: {document_id}")
        # Refresh stored values only when the caller supplied something;
        # a NULL parameter makes COALESCE keep the existing column value.
        if priority or notes:
            execute_query(
                """
                UPDATE documents
                SET pin_priority = COALESCE(%s, pin_priority),
                    pin_notes = COALESCE(%s, pin_notes)
                WHERE document_id = %s
                """,
                (priority or None, notes, document_id)
            )
        return True

    # Enforce the global cap before creating a brand-new pin.
    pinned_row = execute_query(
        "SELECT COUNT(*) as count FROM documents WHERE is_pinned = TRUE",
        fetch='one'
    )
    limit_reached = bool(pinned_row) and pinned_row['count'] >= MAX_PINNED_DOCS
    if limit_reached:
        logger.warning(f"Maximum pinned documents reached ({MAX_PINNED_DOCS})")
        return False

    execute_query(
        """
        UPDATE documents
        SET is_pinned = TRUE,
            pin_priority = %s,
            pinned_at = CURRENT_TIMESTAMP,
            pin_notes = %s
        WHERE document_id = %s
        """,
        (priority, notes, document_id)
    )

    logger.info(f"Pinned document: {document_id} (priority: {priority})")
    return True


def pin_documents(
    document_ids: List[str],
    priority: int = 0,
    notes: str = None
) -> Dict[str, bool]:
    """
    Pin several documents, one at a time, with a shared priority and notes.

    Each ID is passed to pin_document() in order; if an ID appears more than
    once, the outcome of its last attempt is what gets reported.

    Args:
        document_ids: IDs of the documents to pin
        priority: Priority applied to every document
        notes: Optional notes applied to every document

    Returns:
        Mapping of each document ID to whether pin_document() succeeded
    """
    return {
        doc_id: pin_document(doc_id, priority, notes)
        for doc_id in document_ids
    }


def unpin_document(document_id: str) -> bool:
    """
    Clear the pin state of a single document.

    The flag, priority, timestamp and notes are reset on the row matching
    document_id, whether or not it was pinned beforehand.

    Args:
        document_id: ID of the document to unpin

    Returns:
        True if a row with that ID existed and was updated, False otherwise
    """
    updated = execute_query(
        """
        UPDATE documents
        SET is_pinned = FALSE,
            pin_priority = 0,
            pinned_at = NULL,
            pin_notes = NULL
        WHERE document_id = %s
        RETURNING document_id
        """,
        (document_id,),
        fetch='one'
    )

    if not updated:
        logger.warning(f"Document not found: {document_id}")
        return False

    logger.info(f"Unpinned document: {document_id}")
    return True


def unpin_all() -> int:
    """
    Unpin all documents.

    Clears the pin flag, priority, timestamp and notes on every currently
    pinned document in a single statement.

    Returns:
        Number of documents unpinned (0 if none were pinned)
    """
    # A data-modifying CTE performs the UPDATE and reports how many rows it
    # touched in one round-trip. It uses the same "UPDATE ... RETURNING" plus
    # fetch='one' pattern as unpin_document(). Previously this function
    # always returned 0 and ran an unrelated COUNT query whose result was
    # discarded.
    result = execute_query(
        """
        WITH unpinned AS (
            UPDATE documents
            SET is_pinned = FALSE,
                pin_priority = 0,
                pinned_at = NULL,
                pin_notes = NULL
            WHERE is_pinned = TRUE
            RETURNING document_id
        )
        SELECT COUNT(*) AS count FROM unpinned
        """,
        fetch='one'
    )

    unpinned_count = result['count'] if result else 0
    logger.info("Unpinned all documents")
    return unpinned_count


def set_priority(document_id: str, priority: int) -> bool:
    """
    Change the pin priority of a document that is already pinned.

    Only rows with is_pinned = TRUE are updated, so this never pins a
    document by itself.

    Args:
        document_id: ID of the pinned document
        priority: Priority value to store (higher sorts first in listings)

    Returns:
        True when a pinned document with that ID was updated, False otherwise
    """
    updated = execute_query(
        """
        UPDATE documents
        SET pin_priority = %s
        WHERE document_id = %s AND is_pinned = TRUE
        RETURNING document_id
        """,
        (priority, document_id),
        fetch='one'
    )

    if not updated:
        logger.warning(f"Document not pinned or not found: {document_id}")
        return False

    logger.info(f"Set priority {priority} for {document_id}")
    return True


# =============================================================================
# QUERY FUNCTIONS
# =============================================================================

def get_pinned_documents() -> List[Dict[str, Any]]:
    """
    Return every pinned document along with its author and pin metadata.

    Rows are ordered by pin_priority (highest first), then by pinned_at
    (earliest first). Because authors is LEFT JOINed, documents without a
    matching author row are still included, with a NULL author.

    Returns:
        List of plain dicts with keys document_id, title, author,
        publication_year, pin_priority, pinned_at and pin_notes; empty when
        nothing is pinned
    """
    rows = execute_query(
        """
        SELECT d.document_id, d.title, a.name as author, d.publication_year,
               d.pin_priority, d.pinned_at, d.pin_notes
        FROM documents d
        LEFT JOIN authors a ON d.author_id = a.author_id
        WHERE d.is_pinned = TRUE
        ORDER BY d.pin_priority DESC, d.pinned_at ASC
        """,
        fetch='all'
    )

    if not rows:
        return []
    return list(map(dict, rows))


def get_pinned_document_ids() -> List[str]:
    """
    Get IDs of all pinned documents, highest priority first.

    Ties on priority are broken by pin time (earliest first). This matches
    the ordering used by get_pinned_documents(), so callers see the same
    deterministic order. Without a tiebreak, SQL leaves the order of
    equal-priority rows unspecified.

    Returns:
        List of document IDs (empty if nothing is pinned)
    """
    results = execute_query(
        """
        SELECT document_id
        FROM documents
        WHERE is_pinned = TRUE
        ORDER BY pin_priority DESC, pinned_at ASC
        """,
        fetch='all'
    )

    return [r['document_id'] for r in results] if results else []


def is_pinned(document_id: str) -> bool:
    """
    Report whether a document is currently pinned.

    Args:
        document_id: ID of the document to look up

    Returns:
        True if the document exists and its is_pinned column is truthy;
        False if it is missing or not pinned
    """
    row = execute_query(
        "SELECT is_pinned FROM documents WHERE document_id = %s",
        (document_id,),
        fetch='one'
    )

    if not row:
        return False
    return bool(row['is_pinned'])


def get_pin_stats() -> Dict[str, Any]:
    """
    Summarise pin usage across the documents table.

    Returns:
        Dict with keys:
            total_documents: number of rows in documents
            pinned_documents: number of rows with is_pinned = TRUE
            max_allowed: the MAX_PINNED_DOCS limit
            by_priority: mapping of pin_priority to the number of pinned
                documents at that priority
    """
    def count_from(row):
        # COUNT(*) rows expose their value under the 'count' alias.
        return row['count'] if row else 0

    total_row = execute_query(
        "SELECT COUNT(*) as count FROM documents",
        fetch='one'
    )
    pinned_row = execute_query(
        "SELECT COUNT(*) as count FROM documents WHERE is_pinned = TRUE",
        fetch='one'
    )
    priority_rows = execute_query(
        """
        SELECT pin_priority, COUNT(*) as count
        FROM documents
        WHERE is_pinned = TRUE
        GROUP BY pin_priority
        ORDER BY pin_priority DESC
        """,
        fetch='all'
    )

    breakdown = {}
    for row in priority_rows or ():
        breakdown[row['pin_priority']] = row['count']

    return {
        'total_documents': count_from(total_row),
        'pinned_documents': count_from(pinned_row),
        'max_allowed': MAX_PINNED_DOCS,
        'by_priority': breakdown,
    }


# =============================================================================
# SEARCH INTEGRATION
# =============================================================================

def boost_pinned_results(
    results: List[Dict[str, Any]],
    boost_factor: float = 1.5
) -> List[Dict[str, Any]]:
    """
    Boost scores for results from pinned documents.

    The list and its dicts are modified in place:
    - each result gets a boolean 'pinned' key;
    - pinned results have 'score' multiplied by boost_factor (a missing
      score counts as 0);
    - the list is re-sorted by score, highest first.

    The same list object is returned for convenience.

    Args:
        results: Search results with 'document_id' and 'score' keys
        boost_factor: Multiplier for pinned document scores

    Returns:
        Results with boosted scores, re-sorted
    """
    # Nothing to boost or sort: skip the database round-trip entirely.
    if not results:
        return results

    pinned_ids = set(get_pinned_document_ids())

    for result in results:
        from_pinned_doc = result.get('document_id') in pinned_ids
        if from_pinned_doc:
            result['score'] = result.get('score', 0) * boost_factor
        result['pinned'] = from_pinned_doc

    # Re-sort by (possibly boosted) score; list.sort is stable, so ties keep
    # their incoming order.
    results.sort(key=lambda item: item.get('score', 0), reverse=True)

    return results


# =============================================================================
# CLI
# =============================================================================

def main():
    """
    Command-line entry point for the document pinning tool.

    Exactly one action flag (--pin, --unpin, --unpin-all, --list, --stats,
    --set-priority, --check) may be given per invocation; --priority and
    --notes modify --pin / --set-priority. With no action, help is printed.
    """
    parser = argparse.ArgumentParser(
        description='Document Pinning System',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Pin documents
  %(prog)s --pin DOC_001 DOC_002

  # Pin with priority
  %(prog)s --pin DOC_001 --priority 10

  # Pin with notes
  %(prog)s --pin DOC_001 --notes "Key source for chapter 3"

  # Unpin document
  %(prog)s --unpin DOC_001

  # List pinned documents
  %(prog)s --list

  # Show statistics
  %(prog)s --stats

  # Set priority for existing pin
  %(prog)s --set-priority DOC_001 --priority 5

  # Unpin all
  %(prog)s --unpin-all
        """
    )

    # Actions. Only one is ever executed (see the if/elif chain below). They
    # are mutually exclusive so argparse rejects combinations instead of
    # silently ignoring every action after the first match.
    actions = parser.add_mutually_exclusive_group()
    actions.add_argument('--pin', nargs='+', metavar='DOC_ID',
                         help='Pin one or more documents')
    actions.add_argument('--unpin', nargs='+', metavar='DOC_ID',
                         help='Unpin one or more documents')
    actions.add_argument('--unpin-all', action='store_true',
                         help='Unpin all documents')
    actions.add_argument('--list', action='store_true',
                         help='List all pinned documents')
    actions.add_argument('--stats', action='store_true',
                         help='Show pinning statistics')
    actions.add_argument('--set-priority', metavar='DOC_ID',
                         help='Set priority for a document')
    actions.add_argument('--check', metavar='DOC_ID',
                         help='Check if document is pinned')

    # Options
    parser.add_argument('--priority', type=int, default=0,
                        help='Priority level (higher = more important)')
    parser.add_argument('--notes', type=str,
                        help='Notes about why document is pinned')

    args = parser.parse_args()

    # Handle actions
    if args.pin:
        results = pin_documents(args.pin, args.priority, args.notes)
        for doc_id, success in results.items():
            status = "Pinned" if success else "FAILED"
            print(f"  {status}: {doc_id}")

    elif args.unpin:
        for doc_id in args.unpin:
            success = unpin_document(doc_id)
            status = "Unpinned" if success else "NOT FOUND"
            print(f"  {status}: {doc_id}")

    elif args.unpin_all:
        try:
            confirm = input("Unpin ALL documents? (yes/no): ")
        except EOFError:
            # No interactive stdin (e.g. piped/cron): treat as "no".
            confirm = ''
        if confirm.strip().lower() == 'yes':
            unpin_all()
            print("All documents unpinned")
        else:
            print("Cancelled")

    elif args.list:
        docs = get_pinned_documents()
        print(f"\nPinned Documents ({len(docs)}):")
        print("=" * 70)

        if not docs:
            print("  No documents pinned")
        else:
            for doc in docs:
                pinned_date = doc['pinned_at'].strftime('%Y-%m-%d') if doc['pinned_at'] else 'N/A'
                # Guard against a NULL title and only show an ellipsis when
                # the title was actually truncated.
                title = doc['title'] or ''
                shown_title = title if len(title) <= 50 else f"{title[:50]}..."
                print(f"\n  [{doc['pin_priority']:2d}] {doc['document_id']}")
                print(f"      {shown_title}")
                if doc['author']:
                    print(f"      by {doc['author']}")
                print(f"      Pinned: {pinned_date}")
                if doc['pin_notes']:
                    print(f"      Notes: {doc['pin_notes']}")

    elif args.stats:
        stats = get_pin_stats()
        print("\nPinning Statistics:")
        print("=" * 40)
        print(f"  Total documents: {stats['total_documents']}")
        print(f"  Pinned: {stats['pinned_documents']} / {stats['max_allowed']}")
        if stats['by_priority']:
            print("\n  By Priority:")
            for priority, count in sorted(stats['by_priority'].items(), reverse=True):
                print(f"    Priority {priority}: {count} documents")

    elif args.set_priority:
        success = set_priority(args.set_priority, args.priority)
        if success:
            print(f"Set priority {args.priority} for {args.set_priority}")
        else:
            print("Failed - document not pinned or not found")

    elif args.check:
        pinned = is_pinned(args.check)
        status = "PINNED" if pinned else "not pinned"
        print(f"{args.check}: {status}")

    else:
        parser.print_help()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
