"""
Merge Readiness - Pre-merge validation checks.

Implements checks that must pass before a merge is allowed:
- Tests pass
- No critical risks detected
- No pending approvals
- Status packet complete
- No merge conflicts
- Branch up to date
"""

import asyncio
import inspect
import logging
import subprocess
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class CheckStatus(Enum):
    """Outcome state reported by an individual readiness check."""

    PASSED = "passed"    # the check ran and succeeded
    FAILED = "failed"    # the check ran and found a problem
    SKIPPED = "skipped"  # the check did not run (disabled or not configured)
    PENDING = "pending"  # no check in this module produces this state itself
    WARNING = "warning"  # the check found a concern that never blocks a merge


class CheckSeverity(Enum):
    """How much weight a failing check carries in the merge decision."""

    BLOCKING = "blocking"  # a failure prevents the merge
    WARNING = "warning"    # a failure is reported, but the merge may proceed
    INFO = "info"          # informational only


@dataclass
class CheckResult:
    """
    Result of a single readiness check.

    When both ``started_at`` and ``completed_at`` are given and
    ``duration_ms`` is left at 0, the duration is derived from the two
    timestamps, so producers do not have to compute it themselves.
    """

    name: str  # identifier of the check, e.g. "tests_pass"
    status: CheckStatus
    severity: CheckSeverity = CheckSeverity.BLOCKING  # weight applied if the check fails

    message: str = ""  # human-readable summary of the outcome
    details: Dict[str, Any] = field(default_factory=dict)  # check-specific extra data

    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    duration_ms: int = 0  # wall-clock duration in milliseconds

    def __post_init__(self) -> None:
        # Fill in the duration from the timestamps when the caller did not
        # supply one; otherwise it would always be reported as 0.
        if (
            self.duration_ms == 0
            and self.started_at is not None
            and self.completed_at is not None
        ):
            elapsed = self.completed_at - self.started_at
            self.duration_ms = max(0, int(elapsed.total_seconds() * 1000))


@dataclass
class ReadinessReport:
    """
    Complete readiness assessment report.

    ``is_ready`` defaults to False (fail closed), so a report can be created
    first and its readiness decided after the checks have been tallied.
    """

    branch: str  # source branch being evaluated
    target_branch: str  # branch it would be merged into

    is_ready: bool = False
    checks: List[CheckResult] = field(default_factory=list)

    # Summary counts per check status
    passed_count: int = 0
    failed_count: int = 0
    warning_count: int = 0
    skipped_count: int = 0

    # "name: message" entries for checks that block the merge / only warn
    blocking_issues: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    total_duration_ms: int = 0


class MergeReadiness:
    """
    Evaluates whether a branch is ready to merge.

    Runs a series of checks and aggregates them into a ReadinessReport:
    git state (branch exists, no conflicts, up to date), tests, critical
    risks, pending approvals, status packet, and any registered custom
    checks. The branch is ready when no check FAILED with BLOCKING severity.

    A check whose data source is missing is reported as SKIPPED, not
    PASSED. This covers no database, no approval queue, and a backend
    object lacking the expected query method. As a result, the report
    never claims that a check succeeded when it did not actually run.
    """

    # Symbol shown for each check status in generate_report_text().
    _STATUS_SYMBOLS: Dict[CheckStatus, str] = {
        CheckStatus.PASSED: "✓",
        CheckStatus.FAILED: "✗",
        CheckStatus.WARNING: "⚠",
        CheckStatus.SKIPPED: "○",
        CheckStatus.PENDING: "...",
    }

    def __init__(
        self,
        db: Any = None,
        risk_policy: Any = None,
        approval_queue: Any = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize merge readiness checker.

        Args:
            db: Database for querying state. It is used through the optional
                methods ``get_risks_for_branch`` and ``get_latest_status_packet``.
            risk_policy: RiskPolicy for risk classification. Defaults to the
                global policy returned by ``get_risk_policy()``.
            approval_queue: ApprovalQueue for pending approvals. It is used
                through the optional method ``get_pending_approvals``.
            config: Configuration options. Recognised keys are
                ``test_command``, ``test_timeout_seconds``, ``require_tests``,
                ``require_status_packet`` and ``require_up_to_date``.
        """
        self.db = db
        # Default to global risk policy if none provided
        if risk_policy is None:
            from ..risk.policy import get_risk_policy
            self.risk_policy = get_risk_policy()
        else:
            self.risk_policy = risk_policy
        self.approval_queue = approval_queue
        self.config = config or {}

        # Default configuration
        self._test_command = self.config.get("test_command", "pytest")
        self._test_timeout = self.config.get("test_timeout_seconds", 300)
        self._require_tests = self.config.get("require_tests", True)
        self._require_status_packet = self.config.get("require_status_packet", True)
        self._require_up_to_date = self.config.get("require_up_to_date", True)

        # Custom check registry: name -> (check function, severity override)
        self._custom_checks: Dict[str, Tuple[Callable[..., Any], CheckSeverity]] = {}

    def register_check(
        self,
        name: str,
        check_fn: Callable,
        severity: CheckSeverity = CheckSeverity.BLOCKING,
    ) -> None:
        """
        Register a custom readiness check.

        Args:
            name: Check name. Registering the same name again replaces the
                earlier check.
            check_fn: Callable invoked as ``check_fn(branch, target_branch)``.
                It may be an async function or a plain function. Either way it
                must produce a CheckResult.
            severity: Severity assigned to the check's result. It overrides
                whatever severity the check itself set.
        """
        self._custom_checks[name] = (check_fn, severity)
        logger.info("Registered custom readiness check: %s", name)

    async def check_readiness(
        self,
        branch: str,
        target_branch: str = "main",
        run_tests: bool = True,
        working_dir: Optional[str] = None,
    ) -> ReadinessReport:
        """
        Run all readiness checks for a branch.

        Checks run sequentially. The test command runs in ``working_dir``,
        against whatever is currently checked out there.

        Args:
            branch: Source branch to check
            target_branch: Target branch for merge
            run_tests: Whether to run tests (also gated by ``require_tests``)
            working_dir: Working directory for git/test commands

        Returns:
            ReadinessReport with all check results
        """
        start_time = datetime.now()
        checks: List[CheckResult] = []

        # Built-in checks
        checks.append(await self._check_branch_exists(branch, working_dir))
        checks.append(await self._check_no_conflicts(branch, target_branch, working_dir))
        checks.append(await self._check_branch_up_to_date(branch, target_branch, working_dir))

        if run_tests and self._require_tests:
            checks.append(await self._check_tests_pass(working_dir))

        checks.append(await self._check_no_critical_risks(branch))
        checks.append(await self._check_no_pending_approvals(branch))

        if self._require_status_packet:
            checks.append(await self._check_status_packet_complete(branch))

        # Custom checks, in registration order
        for name, (check_fn, severity) in self._custom_checks.items():
            checks.append(
                await self._run_custom_check(name, check_fn, severity, branch, target_branch)
            )

        return self._build_report(branch, target_branch, checks, start_time)

    async def _run_custom_check(
        self,
        name: str,
        check_fn: Callable[..., Any],
        severity: CheckSeverity,
        branch: str,
        target_branch: str,
    ) -> CheckResult:
        """Run one registered custom check, turning any error into a FAILED result."""
        start = datetime.now()
        try:
            result = check_fn(branch, target_branch)
            # Accept both coroutine functions and plain callables.
            if inspect.isawaitable(result):
                result = await result
            if not isinstance(result, CheckResult):
                raise TypeError(
                    f"custom check returned {type(result).__name__}, expected CheckResult"
                )
        except Exception as e:
            logger.warning("Custom readiness check %r raised an error", name, exc_info=True)
            return self._make_result(
                name,
                CheckStatus.FAILED,
                f"Check error: {str(e)}",
                start,
                severity=severity,
            )

        result.severity = severity
        return result

    @staticmethod
    def _make_result(
        name: str,
        status: CheckStatus,
        message: str,
        started_at: datetime,
        severity: CheckSeverity = CheckSeverity.BLOCKING,
        details: Optional[Dict[str, Any]] = None,
    ) -> CheckResult:
        """Build a CheckResult stamped with its completion time and duration."""
        completed_at = datetime.now()
        return CheckResult(
            name=name,
            status=status,
            severity=severity,
            message=message,
            details=details if details is not None else {},
            started_at=started_at,
            completed_at=completed_at,
            duration_ms=int((completed_at - started_at).total_seconds() * 1000),
        )

    def _build_report(
        self,
        branch: str,
        target_branch: str,
        checks: List[CheckResult],
        start_time: datetime,
    ) -> ReadinessReport:
        """
        Build readiness report from check results.

        FAILED checks with BLOCKING severity become blocking issues. Other
        FAILED checks, and every WARNING check, become warnings. The branch
        is ready only when there are no blocking issues.
        """
        report = ReadinessReport(
            branch=branch,
            target_branch=target_branch,
            # Start fail-closed and decide below. Passing the field explicitly
            # keeps construction independent of any dataclass default.
            is_ready=False,
            checks=checks,
        )

        for check in checks:
            summary = f"{check.name}: {check.message}"
            if check.status == CheckStatus.PASSED:
                report.passed_count += 1
            elif check.status == CheckStatus.FAILED:
                report.failed_count += 1
                if check.severity == CheckSeverity.BLOCKING:
                    report.blocking_issues.append(summary)
                else:
                    report.warnings.append(summary)
            elif check.status == CheckStatus.WARNING:
                report.warning_count += 1
                report.warnings.append(summary)
            elif check.status == CheckStatus.SKIPPED:
                report.skipped_count += 1

        report.is_ready = not report.blocking_issues
        report.total_duration_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        return report

    async def _check_branch_exists(
        self,
        branch: str,
        working_dir: Optional[str],
    ) -> CheckResult:
        """Check that ``branch`` resolves via ``git rev-parse --verify``."""
        start = datetime.now()
        name = "branch_exists"

        try:
            result = await self._run_git_command(
                ["git", "rev-parse", "--verify", branch],
                working_dir,
            )
        except Exception as e:
            return self._make_result(
                name, CheckStatus.FAILED, f"Error checking branch: {str(e)}", start
            )

        if result.returncode == 0:
            return self._make_result(
                name, CheckStatus.PASSED, f"Branch '{branch}' exists", start
            )
        return self._make_result(
            name, CheckStatus.FAILED, f"Branch '{branch}' does not exist", start
        )

    async def _check_no_conflicts(
        self,
        branch: str,
        target_branch: str,
        working_dir: Optional[str],
    ) -> CheckResult:
        """
        Check that merging ``branch`` into ``target_branch`` won't conflict.

        Computes the merge base first. It then runs the three-argument
        (trivial-merge) form of ``git merge-tree`` and looks for conflict
        markers in the output. Both git calls are executed without a shell,
        so branch names are passed as literal arguments and never
        interpreted by a shell.
        """
        start = datetime.now()
        name = "no_conflicts"

        try:
            base_result = await self._run_git_command(
                ["git", "merge-base", target_branch, branch],
                working_dir,
            )
            merge_base = base_result.stdout.strip()
            if base_result.returncode != 0 or not merge_base:
                return self._make_result(
                    name,
                    CheckStatus.WARNING,
                    "Could not determine merge base",
                    start,
                    severity=CheckSeverity.WARNING,
                )

            # NOTE(review): the trivial-merge mode is documented as deprecated
            # in newer git; `git merge-tree --write-tree` (git >= 2.38) reports
            # conflicts more accurately via its exit status.
            tree_result = await self._run_git_command(
                ["git", "merge-tree", merge_base, target_branch, branch],
                working_dir,
            )
            if tree_result.returncode != 0:
                # A failed merge-tree must not be mistaken for "no conflicts".
                raise RuntimeError(
                    tree_result.stderr.strip()
                    or f"git merge-tree exited with status {tree_result.returncode}"
                )
        except Exception as e:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Could not check for conflicts: {str(e)}",
                start,
                severity=CheckSeverity.WARNING,
            )

        if "<<<<<<" in tree_result.stdout:
            return self._make_result(
                name,
                CheckStatus.FAILED,
                "Merge conflicts detected",
                start,
                details={"conflict_preview": tree_result.stdout[:500]},
            )

        return self._make_result(
            name, CheckStatus.PASSED, "No merge conflicts detected", start
        )

    async def _check_branch_up_to_date(
        self,
        branch: str,
        target_branch: str,
        working_dir: Optional[str],
    ) -> CheckResult:
        """
        Check that ``branch`` contains every commit of ``target_branch``.

        Being behind yields a WARNING (never blocking). If the git query
        fails, the result is also a WARNING; it is never reported as up to
        date.
        """
        start = datetime.now()
        name = "branch_up_to_date"

        if not self._require_up_to_date:
            return self._make_result(
                name, CheckStatus.SKIPPED, "Up-to-date check skipped", start
            )

        try:
            # Count commits in target that aren't in branch. rev-list --count
            # is unaffected by user log formatting config, unlike `git log`.
            result = await self._run_git_command(
                ["git", "rev-list", "--count", f"{branch}..{target_branch}"],
                working_dir,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    result.stderr.strip()
                    or f"git rev-list exited with status {result.returncode}"
                )
            commit_count = int(result.stdout.strip() or "0")
        except Exception as e:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Could not check if up to date: {str(e)}",
                start,
                severity=CheckSeverity.WARNING,
            )

        if commit_count > 0:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Branch is {commit_count} commits behind {target_branch}",
                start,
                severity=CheckSeverity.WARNING,
                details={"behind_by": commit_count},
            )

        return self._make_result(
            name, CheckStatus.PASSED, f"Branch is up to date with {target_branch}", start
        )

    async def _check_tests_pass(
        self,
        working_dir: Optional[str],
    ) -> CheckResult:
        """
        Run the configured test command and check it exits with status 0.

        The command runs through the shell in ``working_dir`` and is killed
        after ``test_timeout_seconds``. On failure, the tail of stdout and
        stderr is included in the result details.
        """
        start = datetime.now()
        name = "tests_pass"

        try:
            process = await asyncio.create_subprocess_shell(
                self._test_command,
                cwd=working_dir,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            try:
                stdout, stderr = await asyncio.wait_for(
                    process.communicate(),
                    timeout=self._test_timeout,
                )
            except asyncio.TimeoutError:
                # NOTE(review): kill() targets the spawned shell process; test
                # processes it started may keep running.
                try:
                    process.kill()
                except ProcessLookupError:
                    pass  # exited between the timeout and the kill
                # Reap the killed process so it does not linger as a zombie.
                await process.wait()
                return self._make_result(
                    name,
                    CheckStatus.FAILED,
                    f"Tests timed out after {self._test_timeout}s",
                    start,
                )

        except FileNotFoundError as e:
            # With a shell this is typically a missing working_dir (or shell).
            return self._make_result(
                name,
                CheckStatus.SKIPPED,
                f"Test command not found: {self._test_command}",
                start,
                severity=CheckSeverity.WARNING,
                details={"error": str(e)},
            )
        except Exception as e:
            return self._make_result(
                name, CheckStatus.FAILED, f"Error running tests: {str(e)}", start
            )

        if process.returncode == 0:
            return self._make_result(name, CheckStatus.PASSED, "All tests passed", start)

        # errors="replace" keeps non-UTF-8 test output from masking the failure.
        return self._make_result(
            name,
            CheckStatus.FAILED,
            "Tests failed",
            start,
            details={
                "stdout": stdout.decode(errors="replace")[-1000:] if stdout else "",
                "stderr": stderr.decode(errors="replace")[-1000:] if stderr else "",
                "returncode": process.returncode,
            },
        )

    async def _check_no_critical_risks(self, branch: str) -> CheckResult:
        """
        Check that no risk with level 'CRITICAL' is recorded for ``branch``.

        Risks are read from ``db.get_risks_for_branch(branch)``. The result
        is SKIPPED when that method is unavailable.
        """
        start = datetime.now()
        name = "no_critical_risks"

        if self.risk_policy is None:
            return self._make_result(
                name, CheckStatus.SKIPPED, "Risk policy not configured", start
            )

        get_risks = getattr(self.db, "get_risks_for_branch", None)
        if get_risks is None:
            return self._make_result(
                name, CheckStatus.SKIPPED, "Risk data source not configured", start
            )

        try:
            risks = get_risks(branch)
            critical_risks = [r for r in risks if r.get('level') == 'CRITICAL']
        except Exception as e:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Could not check risks: {str(e)}",
                start,
                severity=CheckSeverity.WARNING,
            )

        if critical_risks:
            return self._make_result(
                name,
                CheckStatus.FAILED,
                f"Found {len(critical_risks)} critical risks",
                start,
                details={"risks": critical_risks[:5]},
            )

        return self._make_result(
            name, CheckStatus.PASSED, "No critical risks detected", start
        )

    async def _check_no_pending_approvals(self, branch: str) -> CheckResult:
        """
        Check that no approval is pending for ``branch``.

        Pending items come from ``approval_queue.get_pending_approvals()``
        and are matched on their 'branch' key. The result is SKIPPED when
        that method is unavailable.
        """
        start = datetime.now()
        name = "no_pending_approvals"

        if self.approval_queue is None:
            return self._make_result(
                name, CheckStatus.SKIPPED, "Approval queue not configured", start
            )

        get_pending = getattr(self.approval_queue, "get_pending_approvals", None)
        if get_pending is None:
            return self._make_result(
                name,
                CheckStatus.SKIPPED,
                "Approval queue does not support pending-approval lookup",
                start,
            )

        try:
            branch_pending = [a for a in get_pending() if a.get('branch') == branch]
        except Exception as e:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Could not check approvals: {str(e)}",
                start,
                severity=CheckSeverity.WARNING,
            )

        if branch_pending:
            return self._make_result(
                name,
                CheckStatus.FAILED,
                f"Found {len(branch_pending)} pending approvals",
                start,
                details={"pending": branch_pending[:5]},
            )

        return self._make_result(name, CheckStatus.PASSED, "No pending approvals", start)

    async def _check_status_packet_complete(self, branch: str) -> CheckResult:
        """
        Check that the latest status packet for ``branch`` is complete.

        A missing packet fails the check (blocking). A packet with an empty
        or missing 'status', 'summary' or 'files_changed' produces a WARNING.
        """
        start = datetime.now()
        name = "status_packet_complete"

        if self.db is None:
            return self._make_result(
                name, CheckStatus.SKIPPED, "Database not configured", start
            )

        get_packet = getattr(self.db, "get_latest_status_packet", None)
        if get_packet is None:
            return self._make_result(
                name,
                CheckStatus.SKIPPED,
                "Database does not support status packet lookup",
                start,
            )

        try:
            packet = get_packet(branch)
            if packet is None:
                return self._make_result(
                    name, CheckStatus.FAILED, "No status packet found for branch", start
                )

            required_fields = ['status', 'summary', 'files_changed']
            missing = [f for f in required_fields if not packet.get(f)]
        except Exception as e:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Could not check status packet: {str(e)}",
                start,
                severity=CheckSeverity.WARNING,
            )

        if missing:
            return self._make_result(
                name,
                CheckStatus.WARNING,
                f"Status packet missing fields: {missing}",
                start,
                severity=CheckSeverity.WARNING,
            )

        return self._make_result(
            name, CheckStatus.PASSED, "Status packet is complete", start
        )

    async def _run_git_command(
        self,
        args: List[str],
        working_dir: Optional[str],
        shell: bool = False,
    ) -> subprocess.CompletedProcess:
        """
        Run a git command asynchronously and capture its output.

        Args:
            args: Command and arguments, e.g. ``["git", "merge-base", "a", "b"]``.
            working_dir: Directory to run in (None means the current directory).
            shell: If True, ``args`` are joined with spaces WITHOUT quoting and
                run through the shell. Only use this with trusted input; the
                checks in this class always run with ``shell=False``.

        Returns:
            CompletedProcess whose ``stdout``/``stderr`` are decoded text, with
            undecodable bytes replaced.
        """
        if shell:
            process = await asyncio.create_subprocess_shell(
                ' '.join(args),
                cwd=working_dir,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *args,
                cwd=working_dir,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

        stdout, stderr = await process.communicate()

        return subprocess.CompletedProcess(
            args=args,
            returncode=process.returncode,
            stdout=stdout.decode(errors="replace") if stdout else "",
            stderr=stderr.decode(errors="replace") if stderr else "",
        )

    def generate_report_text(self, report: ReadinessReport) -> str:
        """Render a ReadinessReport as a fixed-width plain-text summary."""
        lines = [
            "=" * 60,
            "MERGE READINESS REPORT",
            "=" * 60,
            f"Branch: {report.branch} → {report.target_branch}",
            f"Ready to merge: {'YES' if report.is_ready else 'NO'}",
            f"Generated: {report.created_at.strftime('%Y-%m-%d %H:%M:%S')}",
            "",
            "CHECK RESULTS:",
        ]

        for check in report.checks:
            status_symbol = self._STATUS_SYMBOLS.get(check.status, "?")
            lines.append(f"  {status_symbol} {check.name}: {check.message}")

        lines.append("")
        lines.append(f"Summary: {report.passed_count} passed, {report.failed_count} failed, "
                     f"{report.warning_count} warnings, {report.skipped_count} skipped")

        if report.blocking_issues:
            lines.append("")
            lines.append("BLOCKING ISSUES:")
            lines.extend(f"  - {issue}" for issue in report.blocking_issues)

        if report.warnings:
            lines.append("")
            lines.append("WARNINGS:")
            lines.extend(f"  - {warning}" for warning in report.warnings)

        lines.append("")
        lines.append(f"Total duration: {report.total_duration_ms}ms")
        lines.append("=" * 60)

        return "\n".join(lines)
