#!/usr/bin/env python3
"""
Draft Validation Tool

Validates factual claims in draft text against the research library.
Designed as a tool for Claude Code to use in a single call.

Usage:
    # Validate a markdown file
    python validate_draft.py chapter_03.md

    # Validate specific text
    python validate_draft.py --text "Cagliostro visited London in 1776..."

    # With options
    python validate_draft.py chapter.md --top-k 5 --format claude

Output:
    Structured report showing:
    - Supported claims (with source citations)
    - Unverified claims (no matching sources)
    - Contradicted claims (sources disagree)
"""

import argparse
import json
import logging
import re
import sys
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field, asdict
from concurrent.futures import ThreadPoolExecutor, as_completed

# Setup path for imports
sys.path.insert(0, str(Path(__file__).parent))

from config import LOGGING_CONFIG, OPENAI_ENABLED
from db_utils import hybrid_search_with_rerank, execute_query

# Setup logging
LOG_LEVEL = LOGGING_CONFIG.get('level', 'INFO')
# Resolve the configured level defensively: a lowercase name such as 'debug'
# would otherwise resolve to the function logging.debug (not a level), and a
# numeric level cannot be used as an attribute name at all.
if isinstance(LOG_LEVEL, int):
    _log_level = LOG_LEVEL
else:
    _log_level = getattr(logging, str(LOG_LEVEL).upper(), None)
    if not isinstance(_log_level, int):
        # Unknown or malformed level name: fall back to INFO.
        _log_level = logging.INFO
logging.basicConfig(
    level=_log_level,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class ExtractedClaim:
    """One checkable statement pulled out of the draft under review."""
    # The claim text
    text: str
    # Line in source document
    line_number: int
    # Surrounding text for context
    context: str
    # factual, definitional, temporal, etc.
    claim_type: str = "factual"

    def to_dict(self) -> Dict[str, Any]:
        """Return the claim's fields as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping


@dataclass
class SourceMatch:
    """A library passage retrieved as potentially relevant to a claim."""
    # Identification of the document the passage comes from
    document_id: str
    title: str
    author: str
    year: Optional[int]
    # The retrieved passage and, when known, where it sits in the document
    chunk_text: str
    page: Optional[str]
    # Score assigned by the search step
    relevance_score: float

    def to_dict(self) -> Dict[str, Any]:
        """Return the match's fields as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping


@dataclass
class ClaimValidation:
    """Outcome of checking one claim against the retrieved sources."""
    claim: ExtractedClaim
    # supported, unverified, contradicted, misattributed
    status: str
    # 0.0 to 1.0
    confidence: float
    sources: List[SourceMatch] = field(default_factory=list)
    explanation: str = ""
    suggestion: str = ""
    # For misattributed: the correct source
    correct_attribution: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dictionary of this result.

        The 'correct_attribution' key is only included when that field
        is non-empty.
        """
        payload: Dict[str, Any] = dict(
            claim=self.claim.to_dict(),
            status=self.status,
            confidence=self.confidence,
            sources=[match.to_dict() for match in self.sources],
            explanation=self.explanation,
            suggestion=self.suggestion,
        )
        if not self.correct_attribution:
            return payload
        payload['correct_attribution'] = self.correct_attribution
        return payload


@dataclass
class ValidationReport:
    """Aggregated result of validating every claim found in one document."""
    source_file: str
    validated_at: str
    # Per-status tallies supplied by whoever builds the report
    total_claims: int
    supported: int
    unverified: int
    contradicted: int
    misattributed: int = 0
    validations: List[ClaimValidation] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dictionary with the tallies nested under 'summary'."""
        summary = {
            'total_claims': self.total_claims,
            'supported': self.supported,
            'unverified': self.unverified,
            'contradicted': self.contradicted,
            'misattributed': self.misattributed,
        }
        serialized = [entry.to_dict() for entry in self.validations]
        return {
            'source_file': self.source_file,
            'validated_at': self.validated_at,
            'summary': summary,
            'validations': serialized,
        }


# =============================================================================
# LLM INTERFACE (Lightweight)
# =============================================================================

class ValidationLLM:
    """Thin wrapper around an OpenAI chat client for extraction and comparison."""

    def __init__(self):
        # Remains None whenever OpenAI is disabled or could not be set up.
        self.client = None
        if not OPENAI_ENABLED:
            return
        try:
            from openai import OpenAI
            from config import OPENAI_API_KEY
            self.client = OpenAI(api_key=OPENAI_API_KEY)
        except Exception as e:
            logger.warning(f"OpenAI not available: {e}")

    def chat(self, messages: List[Dict], temperature: float = 0.2) -> str:
        """Run one chat completion and return the first choice's message content.

        Raises:
            RuntimeError: if no client was configured.
        """
        if not self.client:
            raise RuntimeError("LLM not available - OpenAI API key required")

        request = {
            "model": "gpt-4o-mini",  # Use mini for speed/cost
            "messages": messages,
            "temperature": temperature,
        }
        completion = self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content


# =============================================================================
# CLAIM EXTRACTION
# =============================================================================

# Prompt template for LLM-based claim extraction. It is filled in with
# str.format(), and {text} is its only placeholder, so any literal braces added
# to the template later would have to be doubled. The model is asked for a bare
# JSON array of objects with "text", "type" and "search_query" keys.
CLAIM_EXTRACTION_PROMPT = """You are analyzing text to extract factual claims that can be verified against a research library.

Extract claims that are:
- Specific factual assertions (dates, names, events, definitions)
- Historical claims (X happened in year Y, X was in location Z)
- Definitional claims (X is defined as Y, X means Y)
- Attribution claims (X said/wrote Y, according to X)

Do NOT extract:
- Opinions or interpretations
- General statements without specifics
- Rhetorical questions
- Hedged statements ("might have", "possibly")

For each claim, output a JSON array with objects containing:
- "text": the exact claim text (quote from the source)
- "type": one of "factual", "temporal", "definitional", "attribution"
- "search_query": a concise search query to verify this claim (3-6 words)

TEXT TO ANALYZE:
{text}

OUTPUT (JSON array only, no markdown):"""


def extract_claims_with_llm(text: str, llm: ValidationLLM) -> List[Dict]:
    """Use the LLM to extract factual claims from text.

    Args:
        text: Draft text to analyze.
        llm: Wrapper whose ``chat`` method performs the completion.

    Returns:
        A list of claim dicts as produced by the model. Only entries that are
        dicts with a non-empty string ``"text"`` are kept, so callers can rely
        on ``claim["text"]``. Returns ``[]`` if the call fails, the reply is
        not valid JSON, or the reply is not a JSON array.
    """
    messages = [
        {"role": "system", "content": "You extract factual claims from text. Output valid JSON only."},
        {"role": "user", "content": CLAIM_EXTRACTION_PROMPT.format(text=text)}
    ]

    try:
        # The completion content may be None; treat that as an empty reply.
        response = (llm.chat(messages, temperature=0.1) or "").strip()
        # Clean response - remove markdown code blocks if present
        if response.startswith("```"):
            response = re.sub(r'^```\w*\n?', '', response)
            response = re.sub(r'\n?```$', '', response)

        claims = json.loads(response)
    except Exception as e:
        logger.error(f"Claim extraction failed: {e}")
        return []

    if not isinstance(claims, list):
        return []

    # Drop malformed entries (bare strings, nulls, objects without text):
    # downstream code calls .get() on each item and slices its text.
    return [
        item for item in claims
        if isinstance(item, dict)
        and isinstance(item.get('text'), str)
        and item['text'].strip()
    ]


def extract_claims_simple(text: str) -> List[Dict]:
    """Heuristic, regex-based claim finder (fallback when LLM unavailable).

    The text is split into sentences; every sentence of at least 20
    characters that matches one of the indicator patterns becomes a claim
    dict with 'text', 'type' and 'search_query' keys.
    """
    # Patterns that indicate factual claims
    indicator_patterns = (
        r'\b(in \d{4})\b',                    # Year references
        r'\b(was born|died|established|founded|created|wrote|said)\b',
        r'\b(is defined as|means|refers to)\b',
        r'\b(according to|stated that|argued that)\b',
        r'\b(first|originally|initially)\b',
    )

    found = []
    # Sentence boundaries: whitespace that follows '.', '!' or '?'
    for candidate in re.split(r'(?<=[.!?])\s+', text):
        candidate = candidate.strip()
        if len(candidate) < 20:
            continue

        looks_factual = any(
            re.search(pattern, candidate, re.IGNORECASE)
            for pattern in indicator_patterns
        )
        if not looks_factual:
            continue

        # Capitalized words (proper nouns) make up the query; otherwise
        # fall back to the start of the sentence.
        proper_nouns = re.findall(r'\b[A-Z][a-z]+\b', candidate)
        if proper_nouns:
            query = ' '.join(proper_nouns[:4])
        else:
            query = candidate[:50]

        found.append({
            'text': candidate,
            'type': 'factual',
            'search_query': query,
        })

    return found


def extract_claims(text: str, line_offset: int = 0, use_llm: bool = True) -> List[ExtractedClaim]:
    """Extract claims from text, with line numbers.

    LLM extraction is tried first when ``use_llm`` is true and a client is
    available; if it yields nothing, the regex heuristic is used instead.

    Args:
        text: Draft text to analyze.
        line_offset: Added to every reported line number. A claim whose text
            cannot be located on a single line is reported at ``line_offset``
            itself.
        use_llm: Whether to attempt LLM-based extraction.

    Returns:
        One ``ExtractedClaim`` per usable raw claim. Raw entries that are not
        dicts, or whose text is missing, empty or not a string, are skipped.
    """
    llm = ValidationLLM() if use_llm else None

    # Try LLM extraction first
    raw_claims = []
    if llm and llm.client:
        try:
            raw_claims = extract_claims_with_llm(text, llm)
        except Exception as e:
            logger.warning(f"LLM extraction failed, using fallback: {e}")

    # Fallback to simple extraction
    if not raw_claims:
        raw_claims = extract_claims_simple(text)

    # Convert to ExtractedClaim objects with line numbers
    claims = []
    lines = text.split('\n')

    for raw in raw_claims:
        # LLM output is untrusted: skip anything that is not a claim object
        # with real text. An empty text would also "match" the first line,
        # because '' is a substring of every string.
        if not isinstance(raw, dict):
            continue
        claim_text = raw.get('text')
        if not isinstance(claim_text, str) or not claim_text.strip():
            continue

        # Find line number: first line containing the claim's first 50 chars
        probe = claim_text[:50]
        line_num = line_offset
        for i, line in enumerate(lines):
            if probe in line:
                line_num = line_offset + i + 1
                break

        # Get context (surrounding lines)
        relative = line_num - line_offset
        context_start = max(0, relative - 2)
        context_end = min(len(lines), relative + 2)
        context = ' '.join(lines[context_start:context_end])

        claims.append(ExtractedClaim(
            text=claim_text,
            line_number=line_num,
            context=context[:500],
            # A missing or null type falls back to the default category.
            claim_type=raw.get('type') or 'factual',
        ))

    return claims


# =============================================================================
# SOURCE SEARCH
# =============================================================================

def search_for_claim(claim: ExtractedClaim, top_k: int = 5) -> List[SourceMatch]:
    """Search the library for sources related to a claim.

    Args:
        claim: The claim to look up; its text (first 200 characters) is the query.
        top_k: Maximum number of results requested from the search.

    Returns:
        One ``SourceMatch`` per result row, or ``[]`` when the search raises
        or returns nothing. Missing or null row values are replaced by
        neutral defaults so that later formatting and score comparisons do
        not fail.
    """
    # Use the claim text as the search query
    query = claim.text[:200]  # Truncate very long claims

    try:
        results = hybrid_search_with_rerank(
            query_text=query,
            limit=top_k,
            use_rerank=True,
        )
    except Exception as e:
        logger.error(f"Search failed for claim: {e}")
        return []

    sources = []
    for r in results or []:
        # Prefer the rerank score, then the RRF score. A key that is present
        # but null must not leak through as None, since callers compare the
        # score numerically.
        score = r.get('rerank_score')
        if score is None:
            score = r.get('rrf_score')
        if score is None:
            score = 0.0

        page_start = r.get('page_start')

        sources.append(SourceMatch(
            document_id=r.get('document_id', ''),
            title=r.get('title') or 'Unknown',
            author=r.get('author') or 'Unknown',
            year=r.get('publication_year'),
            chunk_text=(r.get('chunk_text') or '')[:500],
            page=str(page_start) if page_start else None,
            relevance_score=score,
        ))

    return sources


# =============================================================================
# CLAIM VALIDATION
# =============================================================================

# Prompt template for LLM-based claim validation. It is filled in with
# str.format() using the {claim} and {sources} placeholders, so any literal
# braces added to the template later would have to be doubled. The model is
# asked for a single JSON object; "key_quote" and "correct_attribution" are
# requested here in addition to status, confidence and explanation.
VALIDATION_PROMPT = """You are validating a factual claim against source material from a research library.

CLAIM TO VALIDATE:
"{claim}"

SOURCE MATERIAL:
{sources}

Determine if the source material:
1. SUPPORTS the claim (source confirms the claim)
2. CONTRADICTS the claim (source disagrees with the claim)
3. MISATTRIBUTED the claim (quote or idea is real but attributed to wrong author/work/date)
4. UNVERIFIED (sources don't address this specific claim)

Output a JSON object with:
- "status": one of "supported", "contradicted", "misattributed", "unverified"
- "confidence": 0.0 to 1.0 (how confident are you in this assessment)
- "explanation": brief explanation of your assessment (1-2 sentences)
- "key_quote": if supported/contradicted/misattributed, quote the relevant part from sources
- "correct_attribution": if misattributed, provide correct source (author, work, date if known)

OUTPUT (JSON only, no markdown):"""


def validate_claim_with_llm(
    claim: ExtractedClaim,
    sources: List[SourceMatch],
    llm: ValidationLLM
) -> Tuple[str, float, str]:
    """Use the LLM to compare a claim against its top sources.

    Args:
        claim: The claim being checked.
        sources: Candidate sources; only the first three are sent to the model.
        llm: Wrapper whose ``chat`` method performs the completion.

    Returns:
        ``(status, confidence, explanation)``. ``status`` is always one of
        "supported", "contradicted", "misattributed" or "unverified", and
        ``confidence`` is clamped to the range 0.0-1.0. If the call or the
        parsing fails, a low-confidence "unverified" result is returned.
    """
    if not sources:
        return "unverified", 0.5, "No relevant sources found in library."

    # Format sources for prompt
    formatted = []
    for i, src in enumerate(sources[:3], 1):  # Use top 3 sources
        header = f"\n[Source {i}] {src.author}, {src.title}"
        if src.year:
            header += f" ({src.year})"
        if src.page:
            header += f", p.{src.page}"
        formatted.append(f"{header}\n{src.chunk_text}\n")
    sources_text = "".join(formatted)

    messages = [
        {"role": "system", "content": "You validate factual claims against sources. Output valid JSON only."},
        {"role": "user", "content": VALIDATION_PROMPT.format(
            claim=claim.text,
            sources=sources_text
        )}
    ]

    known_statuses = ("supported", "contradicted", "misattributed", "unverified")

    try:
        # The completion content may be None; treat that as an empty reply.
        response = (llm.chat(messages, temperature=0.1) or "").strip()
        # Clean response
        if response.startswith("```"):
            response = re.sub(r'^```\w*\n?', '', response)
            response = re.sub(r'\n?```$', '', response)

        result = json.loads(response)
        if not isinstance(result, dict):
            raise ValueError("expected a JSON object")

        # Normalize the status. Anything outside the known set would be
        # counted in no summary bucket and shown in no report section.
        status = str(result.get('status', 'unverified')).strip().lower()
        if status not in known_statuses:
            logger.warning(f"Unexpected validation status {status!r}; treating as unverified")
            status = "unverified"

        # A malformed confidence should not discard an otherwise valid verdict.
        try:
            confidence = float(result.get('confidence', 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        confidence = min(1.0, max(0.0, confidence))

        return status, confidence, str(result.get('explanation') or '')
    except Exception as e:
        logger.error(f"Validation failed: {e}")
        # Fallback: check if sources are relevant based on score
        top_score = sources[0].relevance_score
        if top_score is not None and top_score > 0.5:
            return "unverified", 0.3, "Could not determine - manual review recommended."
        return "unverified", 0.2, "No clearly relevant sources found."


def validate_claim_simple(claim: ExtractedClaim, sources: List[SourceMatch]) -> Tuple[str, float, str]:
    """Validate a claim without an LLM, using keyword overlap only.

    Words of four or more characters are compared between the claim and each
    source chunk. The best overlap ratio, together with the first source's
    relevance score, decides the result.

    Returns:
        (status, confidence, explanation)
    """
    if not sources:
        return "unverified", 0.3, "No sources found."

    word_pattern = r'\b\w{4,}\b'
    claim_terms = set(re.findall(word_pattern, claim.text.lower()))
    # Guard against division by zero for claims with no qualifying words.
    term_count = max(len(claim_terms), 1)

    ratios = [
        len(claim_terms & set(re.findall(word_pattern, candidate.chunk_text.lower()))) / term_count
        for candidate in sources
    ]
    best_overlap = max([0, *ratios])

    top_source = sources[0]
    if best_overlap > 0.5 and top_source.relevance_score > 0.6:
        return "supported", best_overlap, f"High keyword overlap with {top_source.title}"
    if best_overlap > 0.3:
        return "unverified", 0.4, "Some overlap but manual verification needed."
    return "unverified", 0.2, "Low relevance to sources found."


def validate_claim(
    claim: ExtractedClaim,
    sources: List[SourceMatch],
    use_llm: bool = True,
    llm: Optional[ValidationLLM] = None,
) -> ClaimValidation:
    """Validate a claim against its sources.

    Args:
        claim: The claim to validate.
        sources: Candidate sources for the claim (may be empty).
        use_llm: Whether LLM-based comparison may be used.
        llm: Optional pre-built LLM wrapper to reuse across calls. When
            omitted and ``use_llm`` is true, a new one is created for this
            call. Ignored when ``use_llm`` is false.

    Returns:
        A ``ClaimValidation`` holding the status, confidence, explanation,
        a status-dependent suggestion and the top three sources. Keyword
        heuristics are used when no LLM client is available or there are no
        sources.
    """
    if not use_llm:
        llm = None
    elif llm is None:
        llm = ValidationLLM()

    if llm and llm.client and sources:
        status, confidence, explanation = validate_claim_with_llm(claim, sources, llm)
    else:
        status, confidence, explanation = validate_claim_simple(claim, sources)

    # Generate suggestion based on status
    if status == "contradicted":
        suggestion = f"Verify against {sources[0].title if sources else 'original source'}"
    elif status == "misattributed":
        # Same wording the markdown report uses, so JSON consumers get it too.
        suggestion = "Verify and correct the attribution"
    elif status == "unverified":
        suggestion = "Consider adding a source or softening the claim"
    else:
        suggestion = ""

    return ClaimValidation(
        claim=claim,
        status=status,
        confidence=confidence,
        sources=sources[:3],  # Keep top 3 sources
        explanation=explanation,
        suggestion=suggestion,
    )


# =============================================================================
# BATCH VALIDATION
# =============================================================================

def validate_document(
    text: str,
    source_file: str = "draft",
    top_k: int = 5,
    use_llm: bool = True,
    max_workers: int = 4,
) -> ValidationReport:
    """Validate all claims in a document.

    Claims are extracted, searched for in parallel, then validated one at a
    time. Results keep the order in which the claims were extracted, and
    every extracted claim yields exactly one validation, including claims
    whose text is identical.

    Args:
        text: Draft text to validate.
        source_file: Label recorded in the report.
        top_k: Number of sources requested per claim.
        use_llm: Whether LLM extraction/validation may be used.
        max_workers: Thread count for the parallel search phase.

    Returns:
        A ``ValidationReport`` with per-status counts and all validations.
    """
    logger.info(f"Extracting claims from {source_file}...")

    # Extract claims
    claims = extract_claims(text, use_llm=use_llm)
    logger.info(f"Found {len(claims)} claims to validate")

    if not claims:
        return ValidationReport(
            source_file=source_file,
            validated_at=datetime.now().isoformat(),
            total_claims=0,
            supported=0,
            unverified=0,
            contradicted=0,
        )

    # Use parallel search for speed. Futures are kept in claim order rather
    # than keyed by claim text, so duplicates are not merged and the report
    # order does not depend on which search finishes first.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        search_futures = [
            executor.submit(search_for_claim, claim, top_k)
            for claim in claims
        ]

    # Leaving the with-block waits for every search to finish.
    claim_sources = []
    for claim, future in zip(claims, search_futures):
        try:
            sources = future.result()
        except Exception as e:
            logger.error(f"Search failed: {e}")
            sources = []
        claim_sources.append((claim, sources))

    # Validate each claim (sequential to manage LLM rate limits)
    logger.info("Validating claims against sources...")
    validations = [
        validate_claim(claim, sources, use_llm=use_llm)
        for claim, sources in claim_sources
    ]

    # Count results
    tally = {"supported": 0, "contradicted": 0, "misattributed": 0, "unverified": 0}
    for validation in validations:
        if validation.status in tally:
            tally[validation.status] += 1

    return ValidationReport(
        source_file=source_file,
        validated_at=datetime.now().isoformat(),
        total_claims=len(claims),
        supported=tally["supported"],
        unverified=tally["unverified"],
        contradicted=tally["contradicted"],
        misattributed=tally["misattributed"],
        validations=validations,
    )


# =============================================================================
# OUTPUT FORMATTING
# =============================================================================

def _clip_text(text: str, limit: int) -> str:
    """Return text cut to limit characters, adding '...' only when something was cut."""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def format_claude_output(report: ValidationReport) -> str:
    """Format report for Claude Code consumption (structured markdown).

    Sections appear in order of urgency: contradictions, misattributions,
    unverified claims (first 10), supported claims (first 5), then a
    numbered list of recommended actions. Quoted text is shortened with a
    trailing '...' only when it was actually truncated.

    Args:
        report: The validation report to render.

    Returns:
        The markdown document as a single string.
    """
    lines = [
        f"## Validation Report: {report.source_file}",
        "",
        "### Summary",
        f"- **Claims analyzed:** {report.total_claims}",
        f"- **Supported:** {report.supported}",
        f"- **Unverified:** {report.unverified}",
        f"- **Contradicted:** {report.contradicted}",
        f"- **Misattributed:** {report.misattributed}",
        "",
    ]

    # Group by status
    contradicted = [v for v in report.validations if v.status == "contradicted"]
    misattributed = [v for v in report.validations if v.status == "misattributed"]
    unverified = [v for v in report.validations if v.status == "unverified"]
    supported = [v for v in report.validations if v.status == "supported"]

    # Show contradictions first (most important)
    if contradicted:
        lines.extend(["### Contradictions Found", ""])
        for v in contradicted:
            lines.append(f"#### Line {v.claim.line_number}")
            lines.append(f'**Your text:** "{_clip_text(v.claim.text, 100)}"')
            if v.sources:
                src = v.sources[0]
                lines.append(f"**Library says:** {_clip_text(src.chunk_text, 200)}")
                lines.append(f"  → Source: {src.author}, *{src.title}*" +
                             (f" ({src.year})" if src.year else "") +
                             (f", p.{src.page}" if src.page else ""))
            lines.append(f"**Suggestion:** {v.suggestion}")
            lines.append("")

    # Show misattributed (important - wrong source/author)
    if misattributed:
        lines.extend(["### Misattributed Claims", ""])
        for v in misattributed:
            lines.append(f"#### Line {v.claim.line_number}")
            lines.append(f'**Your text:** "{_clip_text(v.claim.text, 100)}"')
            lines.append(f"**Issue:** {v.explanation}")
            if v.correct_attribution:
                lines.append(f"**Correct attribution:** {v.correct_attribution}")
            if v.sources:
                src = v.sources[0]
                lines.append(f"  → Found in: {src.author}, *{src.title}*" +
                             (f" ({src.year})" if src.year else ""))
            lines.append("**Suggestion:** Verify and correct the attribution")
            lines.append("")

    # Show unverified
    if unverified:
        lines.extend(["### Unverified Claims", ""])
        for v in unverified[:10]:  # Limit display
            lines.append(f'- **Line {v.claim.line_number}:** "{_clip_text(v.claim.text, 80)}"')
            lines.append(f"  → {v.explanation}")
        if len(unverified) > 10:
            lines.append(f"  *... and {len(unverified) - 10} more unverified claims*")
        lines.append("")

    # Show supported (brief)
    if supported:
        lines.extend(["### Supported Claims", ""])
        for v in supported[:5]:  # Show just a few
            src_info = ""
            if v.sources:
                src = v.sources[0]
                src_info = f" → {src.author}, *{src.title}*"
            lines.append(f'- "{_clip_text(v.claim.text, 60)}"{src_info}')
        if len(supported) > 5:
            lines.append(f"  *... and {len(supported) - 5} more supported claims*")
        lines.append("")

    # Next steps
    lines.extend(["### Recommended Actions", ""])
    action_num = 1
    if contradicted:
        lines.append(f"{action_num}. **Review {len(contradicted)} contradiction(s)** - verify facts against original sources")
        action_num += 1
    if misattributed:
        lines.append(f"{action_num}. **Fix {len(misattributed)} misattribution(s)** - correct the author/source references")
        action_num += 1
    if unverified:
        lines.append(f"{action_num}. **Add sources for {len(unverified)} unverified claim(s)** - or soften the language")
        action_num += 1
    if not report.validations:
        # Nothing was checked, so claiming everything is supported would mislead.
        lines.append("No verifiable claims were found in the text.")
    elif not contradicted and not unverified and not misattributed:
        lines.append("All claims are well-supported by your library.")

    return "\n".join(lines)


def format_json_output(report: ValidationReport) -> str:
    """Serialize the report's dictionary form as JSON with two-space indentation."""
    payload = report.to_dict()
    return json.dumps(payload, indent=2)


# =============================================================================
# CLI
# =============================================================================

def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the validate_draft tool."""
    usage_examples = """
Examples:
  # Validate a markdown file
  python validate_draft.py chapter_03.md

  # Validate specific text
  python validate_draft.py --text "Cagliostro visited London in 1776..."

  # JSON output for programmatic use
  python validate_draft.py chapter.md --format json

For Claude Code usage, the default 'claude' format provides structured markdown.
        """

    cli = argparse.ArgumentParser(
        prog='validate_draft',
        description='Validate factual claims in draft text against the research library',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # (flags, options) pairs, registered in this order so --help lists them
    # in the same sequence.
    argument_table = [
        (('file',), dict(
            nargs='?',
            type=str,
            help='Path to markdown/text file to validate',
        )),
        (('--text', '-t'), dict(
            type=str,
            help='Validate specific text instead of a file',
        )),
        (('--format', '-f'), dict(
            choices=['claude', 'json'],
            default='claude',
            help='Output format (default: claude)',
        )),
        (('--top-k', '-k'), dict(
            type=int,
            default=5,
            help='Number of sources to retrieve per claim (default: 5)',
        )),
        (('--no-llm',), dict(
            action='store_true',
            help='Disable LLM for extraction/validation (use heuristics)',
        )),
        (('--verbose', '-v'), dict(
            action='store_true',
            help='Enable verbose output',
        )),
    ]
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)

    return cli


def main() -> int:
    """Main entry point.

    Parses the command line, validates the given file or text and prints the
    report to stdout. Error messages go to stderr so that stdout carries
    only the report (relevant for ``--format json``).

    Returns:
        Process exit code:
        0 - no contradicted, misattributed or unverified claims;
        1 - usage error, unreadable input, validation failure, or at least
            one unverified claim;
        2 - at least one contradicted or misattributed claim.
    """
    parser = create_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Get text to validate
    if args.text:
        text = args.text
        source_file = "text_input"
    elif args.file:
        path = Path(args.file)
        # is_file() also rejects directories, which exists() would accept and
        # read_text() would then fail on.
        if not path.is_file():
            print(f"Error: File not found: {path}", file=sys.stderr)
            return 1
        try:
            text = path.read_text(encoding='utf-8')
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error: Could not read {path}: {e}", file=sys.stderr)
            return 1
        source_file = path.name
    else:
        print("Error: Provide a file or --text argument", file=sys.stderr)
        return 1

    # Validate
    try:
        report = validate_document(
            text=text,
            source_file=source_file,
            top_k=args.top_k,
            use_llm=not args.no_llm,
        )
    except Exception as e:
        logger.exception("Validation failed")
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # Output
    if args.format == 'json':
        print(format_json_output(report))
    else:
        print(format_claude_output(report))

    # Exit code based on findings. Misattributions are reported as problems
    # to fix, so they must not produce a "clean" exit status.
    if report.contradicted > 0 or report.misattributed > 0:
        return 2  # Has contradictions or misattributions
    elif report.unverified > 0:
        return 1  # Has unverified
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
