#!/usr/bin/env python3
"""
Embedding Generation Pipeline for Research Development Framework.

This script generates vector embeddings for document chunks:
1. Load chunks that need embeddings
2. Batch process through OpenAI API
3. Store embeddings in database for semantic search

IMPORTANT: This step is OPTIONAL. The system works without embeddings
using keyword-only search. Run this only if you have an OpenAI API key
and want to enable semantic (meaning-based) search.

Usage:
    python generate_embeddings.py                      # Process all chunks without embeddings
    python generate_embeddings.py --document DOC_001   # Process specific document
    python generate_embeddings.py --batch-size 50      # Custom batch size
    python generate_embeddings.py --estimate           # Estimate cost without processing
"""

import os
import sys
import time
import logging
from typing import List, Dict, Any, Tuple
import argparse

try:
    from openai import OpenAI
    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False

from config import (
    OPENAI_API_KEY,
    OPENAI_ENABLED,
    EMBEDDING_MODEL,
    EMBEDDING_DIMENSIONS,
    PROCESSING_CONFIG,
    LOGGING_CONFIG
)
from db_utils import (
    get_db_connection,
    execute_query,
    batch_update_embeddings,
    update_document_status,
    get_chunks_without_embeddings
)

# Configure logging from the project settings.
# LOGGING_CONFIG['level'] is used as an attribute name on the logging module
# to obtain the level value passed to basicConfig.
_log_level = getattr(logging, LOGGING_CONFIG['level'])
logging.basicConfig(level=_log_level, format=LOGGING_CONFIG['format'])

# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """Generate embeddings for chunks via the OpenAI API and store them.

    The instance keeps running counters in ``self.stats`` (chunks processed,
    tokens used, API calls, errors, retries) that accumulate across calls.
    """

    # USD per one million tokens, the rate quoted for text-embedding-3-small.
    # The same rate is applied whatever EMBEDDING_MODEL is configured, so cost
    # figures are only indicative for other models.
    COST_PER_MILLION_TOKENS_USD = 0.02

    def __init__(self, batch_size: int = None):
        """
        Create the API client and load processing settings.

        Args:
            batch_size: Number of chunks sent per API call. A falsy value
                (None or 0) falls back to
                PROCESSING_CONFIG['embedding_batch_size'].

        Raises:
            ImportError: If the openai package could not be imported.
            ValueError: If OPENAI_API_KEY is empty, or the resolved batch
                size is not positive.
        """
        if not HAS_OPENAI:
            raise ImportError("OpenAI package not installed")

        if not OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY not set in environment")

        resolved_batch_size = batch_size or PROCESSING_CONFIG['embedding_batch_size']
        # A negative step would make the batching range() empty, so nothing
        # would be embedded while the run still looked error-free.
        if resolved_batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {resolved_batch_size}"
            )

        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.model = EMBEDDING_MODEL
        self.dimensions = EMBEDDING_DIMENSIONS
        self.batch_size = resolved_batch_size
        self.retry_attempts = PROCESSING_CONFIG['retry_attempts']
        self.retry_delay = PROCESSING_CONFIG['retry_delay_seconds']

        self.stats = {
            'chunks_processed': 0,
            'tokens_used': 0,
            'api_calls': 0,
            'errors': 0,
            'retries': 0
        }

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for text (rough approximation: 4 chars per token)."""
        return len(text) // 4

    def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Makes one API call with no retry and does not update ``self.stats``.
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            dimensions=self.dimensions
        )
        return response.data[0].embedding

    def generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts, retrying on failure.

        The delay before each retry grows linearly:
        ``retry_delay * attempt_number`` seconds.

        Returns:
            List of embedding vectors in the same order as ``texts``.

        Raises:
            Exception: Whatever the last attempt raised, once all attempts
                are exhausted.
        """
        # Always try at least once; with retry_attempts <= 0 the loop would
        # otherwise be skipped and the method would silently return None.
        attempts = max(1, self.retry_attempts)

        for attempt in range(attempts):
            try:
                response = self.client.embeddings.create(
                    model=self.model,
                    input=texts,
                    dimensions=self.dimensions
                )

                self.stats['api_calls'] += 1
                self.stats['tokens_used'] += response.usage.total_tokens

                # Sort by index to ensure order matches input
                ordered = sorted(response.data, key=lambda item: item.index)
                return [item.embedding for item in ordered]

            except Exception as e:
                logger.warning(f"API error (attempt {attempt + 1}): {e}")

                if attempt >= attempts - 1:
                    # Final attempt failed: nothing left to retry.
                    raise

                # Count only attempts that are actually followed by a retry.
                self.stats['retries'] += 1
                time.sleep(self.retry_delay * (attempt + 1))

    def get_chunks_to_embed(self, document_id: str = None) -> List[Dict]:
        """
        Get chunks whose embedding column is NULL.

        Args:
            document_id: If given, restrict to that document's chunks.

        Returns:
            Rows with chunk_id, chunk_text and chunk_tokens, ordered by
            chunk_sequence (and by document_id first when unfiltered).
        """
        if document_id:
            return execute_query(
                """
                SELECT chunk_id, chunk_text, chunk_tokens
                FROM chunks
                WHERE document_id = %s AND embedding IS NULL
                ORDER BY chunk_sequence
                """,
                (document_id,),
                fetch='all'
            )

        return execute_query(
            """
            SELECT chunk_id, chunk_text, chunk_tokens
            FROM chunks
            WHERE embedding IS NULL
            ORDER BY document_id, chunk_sequence
            """,
            fetch='all'
        )

    def estimate_cost(self, chunks: List[Dict]) -> Dict[str, Any]:
        """
        Estimate processing cost for chunks.

        Uses each chunk's stored ``chunk_tokens`` when it is present and
        non-zero, otherwise falls back to ``estimate_tokens`` on the text.
        Pricing uses COST_PER_MILLION_TOKENS_USD.

        Returns:
            Dict with chunk_count, total_tokens, estimated_cost_usd (rounded
            to 4 decimals) and api_calls_needed.
        """
        total_tokens = sum(
            c.get('chunk_tokens', 0) or self.estimate_tokens(c['chunk_text'])
            for c in chunks
        )

        estimated_cost = (total_tokens / 1_000_000) * self.COST_PER_MILLION_TOKENS_USD

        return {
            'chunk_count': len(chunks),
            'total_tokens': total_tokens,
            'estimated_cost_usd': round(estimated_cost, 4),
            # Ceiling division: a partial final batch still needs one call.
            'api_calls_needed': (len(chunks) + self.batch_size - 1) // self.batch_size
        }

    def process_batch(self, chunks: List[Dict]) -> int:
        """
        Embed one batch of chunks and write the vectors to the database.

        Any failure (API, count mismatch, database) is logged and counted in
        ``stats['errors']`` rather than raised, so one bad batch does not
        abort the whole run.

        Returns:
            Number of chunks successfully processed (0 on failure or when
            ``chunks`` is empty).
        """
        if not chunks:
            return 0

        texts = [c['chunk_text'] for c in chunks]
        chunk_ids = [c['chunk_id'] for c in chunks]

        try:
            embeddings = self.generate_batch_embeddings(texts)

            # Refuse to pair ids with vectors unless the counts agree;
            # otherwise vectors could be dropped or misassigned.
            if len(embeddings) != len(chunk_ids):
                raise ValueError(
                    f"Expected {len(chunk_ids)} embeddings, got {len(embeddings)}"
                )

            updates = list(zip(chunk_ids, embeddings))

            # Batch update database
            batch_update_embeddings(updates)

            self.stats['chunks_processed'] += len(chunks)
            return len(chunks)

        except Exception as e:
            logger.error(f"Batch processing error: {e}")
            self.stats['errors'] += 1
            return 0

    def process_all(self, document_id: str = None, estimate_only: bool = False) -> Dict[str, Any]:
        """
        Process all chunks that need embeddings.

        Args:
            document_id: Optional filter by document
            estimate_only: If True, only estimate cost without processing

        Returns:
            The cumulative ``self.stats`` dict, or the ``estimate_cost``
            result when ``estimate_only`` is True and there are chunks.
        """
        chunks = self.get_chunks_to_embed(document_id)

        if not chunks:
            logger.info("No chunks need embeddings")
            return self.stats

        logger.info(f"Found {len(chunks)} chunks to embed")

        # Estimate cost
        estimate = self.estimate_cost(chunks)
        logger.info(f"Estimated cost: ${estimate['estimated_cost_usd']:.4f} "
                   f"({estimate['total_tokens']:,} tokens)")

        if estimate_only:
            return estimate

        # stats accumulate across calls; remember the starting error count so
        # the status decision below reflects this run only.
        errors_before = self.stats['errors']

        # Process in batches
        total_batches = (len(chunks) + self.batch_size - 1) // self.batch_size

        for batch_num, start in enumerate(range(0, len(chunks), self.batch_size), start=1):
            batch = chunks[start:start + self.batch_size]

            logger.info(f"Processing batch {batch_num}/{total_batches} "
                       f"({len(batch)} chunks)")

            processed = self.process_batch(batch)

            if processed == 0:
                logger.warning(f"Batch {batch_num} failed")

            # Rate limiting: small delay between batches (skipped after the last)
            if start + self.batch_size < len(chunks):
                time.sleep(0.5)

        # Mark the document embedded only if every batch in this run succeeded
        if document_id and self.stats['errors'] == errors_before:
            update_document_status(document_id, 'embedded')

        return self.stats


def _positive_int(value: str) -> int:
    """argparse ``type`` callable: parse *value* as an int and require >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def main():
    """
    Command-line entry point.

    Parses arguments, then either prints setup instructions (when
    OPENAI_ENABLED is false), prints a cost estimate (--estimate), or
    generates embeddings and prints a summary of the run.

    Exits with status 0 after printing setup instructions and status 1 if the
    openai package is not installed.
    """
    parser = argparse.ArgumentParser(
        description='Generate embeddings for the Research Development Framework',
        # Keep the hand-formatted epilog as written; the default formatter
        # would re-wrap it and lose the indentation and blank lines.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
NOTE: Embeddings are OPTIONAL. Without them, the system uses keyword-only search.
To enable semantic search, add your OpenAI API key to the .env file:

    OPENAI_API_KEY=sk-your-key-here

Get an API key from: https://platform.openai.com/api-keys
        '''
    )
    parser.add_argument(
        '--document',
        type=str,
        help='Process a specific document by ID'
    )
    parser.add_argument(
        '--batch-size',
        type=_positive_int,
        # None lets EmbeddingGenerator fall back to the configured batch size;
        # a hard-coded default here would always override the configuration.
        default=None,
        help='Number of chunks per API call '
             '(default: embedding_batch_size from the processing config)'
    )
    parser.add_argument(
        '--estimate',
        action='store_true',
        help='Only estimate cost without processing'
    )

    args = parser.parse_args()

    # Check for API key availability
    if not OPENAI_ENABLED:
        print("\n" + "=" * 60)
        print("EMBEDDING GENERATION - API KEY REQUIRED")
        print("=" * 60)
        print()
        print("This step is OPTIONAL but enables semantic search.")
        print()
        print("To generate embeddings, add your OpenAI API key to .env:")
        print()
        print("    1. Open: .env")
        print("    2. Add:  OPENAI_API_KEY=sk-your-key-here")
        print()
        print("Get an API key from: https://platform.openai.com/api-keys")
        print()
        print("Without embeddings, you can still use KEYWORD SEARCH,")
        print("which works well for exact phrases and specific terms.")
        print("=" * 60)
        # Embeddings are optional, so a missing key is not treated as a failure.
        sys.exit(0)

    if not HAS_OPENAI:
        print("Error: openai package not installed. Run: pip install openai")
        sys.exit(1)

    generator = EmbeddingGenerator(batch_size=args.batch_size)

    if args.estimate:
        chunks = generator.get_chunks_to_embed(args.document)
        if chunks:
            estimate = generator.estimate_cost(chunks)
            print("\n" + "=" * 50)
            print("EMBEDDING COST ESTIMATE")
            print("=" * 50)
            print(f"Chunks to process:   {estimate['chunk_count']:,}")
            print(f"Total tokens:        {estimate['total_tokens']:,}")
            print(f"API calls needed:    {estimate['api_calls_needed']}")
            print(f"Estimated cost:      ${estimate['estimated_cost_usd']:.4f} USD")
            print("=" * 50)
        else:
            print("No chunks need embeddings")
        return

    stats = generator.process_all(document_id=args.document)

    # Print summary
    print("\n" + "=" * 50)
    print("EMBEDDING GENERATION SUMMARY")
    print("=" * 50)
    print(f"Chunks Processed: {stats['chunks_processed']:,}")
    print(f"Tokens Used:      {stats['tokens_used']:,}")
    print(f"API Calls:        {stats['api_calls']}")
    print(f"Retries:          {stats['retries']}")
    print(f"Errors:           {stats['errors']}")

    if stats['tokens_used'] > 0:
        # $0.02 per 1M tokens (text-embedding-3-small rate)
        cost = (stats['tokens_used'] / 1_000_000) * 0.02
        print(f"Actual Cost:      ${cost:.4f} USD")

    print("=" * 50)

# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
