"""
Configurable Risk Policy - Load risk patterns from configuration.

Provides:
- Configuration-driven risk patterns
- Organization-specific rules
- Runtime pattern updates
- Pattern import/export
"""

import json
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

from .policy import (
    RiskLevel,
    RiskClassification,
    RiskPolicy,
    CRITICAL_COMMAND_PATTERNS,
    CRITICAL_FILE_PATTERNS,
    HIGH_RISK_COMMAND_PATTERNS,
    HIGH_RISK_FILE_PATTERNS,
    MEDIUM_RISK_COMMAND_PATTERNS,
    MEDIUM_RISK_FILE_PATTERNS,
    LOW_RISK_COMMAND_PATTERNS,
    LOW_RISK_FILE_PATTERNS,
)

# Module-level logger; used below for pattern-change and config load/save messages.
logger = logging.getLogger(__name__)


@dataclass
class RiskPattern:
    """One regex-based risk rule together with its metadata.

    The class only carries the rule's data and converts it to and from
    plain dictionaries; compiling ``pattern`` is done by the policy.
    """

    pattern: str  # Regular expression source text
    description: str  # Human-readable explanation of the rule
    level: RiskLevel  # Risk level reported for matches
    enabled: bool = True  # Disabled rules are skipped when config patterns are compiled
    source: str = "default"  # e.g. "default", "custom", "organization", "runtime"
    category: str = "general"  # Free-form grouping label

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of the rule.

        The risk level is emitted as its ``value`` so the result can be
        handed straight to ``json``.
        """
        return dict(
            pattern=self.pattern,
            description=self.description,
            level=self.level.value,
            enabled=self.enabled,
            source=self.source,
            category=self.category,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RiskPattern":
        """Build a rule from its plain-dict form.

        ``pattern``, ``description`` and ``level`` are required keys (a
        missing one raises ``KeyError``). ``enabled`` defaults to ``True``,
        ``source`` to ``"custom"`` and ``category`` to ``"general"``.
        """
        required = (
            data["pattern"],
            data["description"],
            RiskLevel(data["level"]),
        )
        return cls(
            *required,
            enabled=data.get("enabled", True),
            source=data.get("source", "custom"),
            category=data.get("category", "general"),
        )


@dataclass
class RiskPolicyConfig:
    """Serializable settings that drive a configurable risk policy.

    Instances round-trip through plain dictionaries via ``to_dict`` and
    ``from_dict``; risk levels are stored as their enum ``value`` in the
    dictionary form.
    """

    # Whether the default patterns from the base policy module are loaded
    include_defaults: bool = True

    # Extra patterns supplied by configuration or added at runtime
    command_patterns: list[RiskPattern] = field(default_factory=list)
    file_patterns: list[RiskPattern] = field(default_factory=list)

    # Level adjustments for default patterns: pattern string -> level
    level_overrides: dict[str, RiskLevel] = field(default_factory=dict)

    # Levels reported when no pattern matches
    default_file_level: RiskLevel = RiskLevel.MEDIUM
    default_command_level: RiskLevel = RiskLevel.MEDIUM

    # Organization-specific settings
    organization: str = ""
    allowed_paths: list[str] = field(default_factory=list)  # Always allow these paths
    blocked_paths: list[str] = field(default_factory=list)  # Always block these paths

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary suitable for JSON serialization.

        The path lists are copied so that later changes to this config do
        not silently alter a dictionary that was already handed out.
        """
        return {
            "include_defaults": self.include_defaults,
            "command_patterns": [p.to_dict() for p in self.command_patterns],
            "file_patterns": [p.to_dict() for p in self.file_patterns],
            "level_overrides": {k: v.value for k, v in self.level_overrides.items()},
            "default_file_level": self.default_file_level.value,
            "default_command_level": self.default_command_level.value,
            "organization": self.organization,
            "allowed_paths": list(self.allowed_paths),
            "blocked_paths": list(self.blocked_paths),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RiskPolicyConfig":
        """Create a config from its dictionary form.

        Every key is optional. Collection-valued keys that are missing or
        explicitly ``None`` (JSON ``null``) are treated as empty, and the
        path lists are copied so the new config never shares list objects
        with ``data``.
        """
        return cls(
            include_defaults=data.get("include_defaults", True),
            command_patterns=[
                RiskPattern.from_dict(p) for p in data.get("command_patterns") or []
            ],
            file_patterns=[
                RiskPattern.from_dict(p) for p in data.get("file_patterns") or []
            ],
            level_overrides={
                k: RiskLevel(v)
                for k, v in (data.get("level_overrides") or {}).items()
            },
            default_file_level=RiskLevel(data.get("default_file_level", "medium")),
            default_command_level=RiskLevel(data.get("default_command_level", "medium")),
            organization=data.get("organization", ""),
            allowed_paths=list(data.get("allowed_paths") or []),
            blocked_paths=list(data.get("blocked_paths") or []),
        )


class ConfigurableRiskPolicy:
    """
    Risk policy with configurable patterns.

    Builds on the default patterns from the base policy module (without
    subclassing RiskPolicy), adding:
    - Configuration file loading
    - Custom pattern addition
    - Runtime pattern updates
    - Organization-specific rules

    Example:
        # Load from config file
        policy = ConfigurableRiskPolicy.from_file("risk_policy.json")

        # Add custom patterns
        policy.add_command_pattern(
            pattern=r"deploy-to-production",
            description="Production deployment",
            level=RiskLevel.CRITICAL,
        )

        # Override a pattern's level
        policy.override_level(r"git push", RiskLevel.HIGH)
    """

    def __init__(self, config: Optional[RiskPolicyConfig] = None) -> None:
        """
        Initialize the configurable risk policy.

        Args:
            config: Optional configuration (uses defaults if not provided)
        """
        self.config = config or RiskPolicyConfig()
        self._compiled_patterns: dict[str, list[tuple[re.Pattern, str, RiskLevel]]] = {
            "critical_commands": [],
            "critical_files": [],
            "high_commands": [],
            "high_files": [],
            "medium_commands": [],
            "medium_files": [],
            "low_commands": [],
            "low_files": [],
        }

        # Compile allowed/blocked path patterns
        self._allowed_paths = [re.compile(p) for p in self.config.allowed_paths]
        self._blocked_paths = [re.compile(p) for p in self.config.blocked_paths]

        self._compile_patterns()

    def _compile_patterns(self) -> None:
        """Compile all patterns into regex objects."""
        # Clear existing
        for key in self._compiled_patterns:
            self._compiled_patterns[key].clear()

        # Add default patterns if enabled
        if self.config.include_defaults:
            self._add_default_patterns()

        # Add custom patterns from config
        for pattern in self.config.command_patterns:
            if pattern.enabled:
                self._add_compiled_pattern("command", pattern)

        for pattern in self.config.file_patterns:
            if pattern.enabled:
                self._add_compiled_pattern("file", pattern)

    def _add_default_patterns(self) -> None:
        """
        Compile the default patterns from the base policy module.

        Each default pattern's level is looked up in the configured level
        overrides. The compiled entry is stored in the bucket of its
        *effective* level, so an overridden pattern is consulted with the
        precedence of its new level and is reported under that level. This
        matches how configured patterns are bucketed by their own level.
        If the effective level has no bucket, the entry stays in the
        bucket of its default level.
        """
        # (bucket prefix, default level, bucket suffix, pattern table)
        default_sets = (
            ("critical", RiskLevel.CRITICAL, "commands", CRITICAL_COMMAND_PATTERNS),
            ("critical", RiskLevel.CRITICAL, "files", CRITICAL_FILE_PATTERNS),
            ("high", RiskLevel.HIGH, "commands", HIGH_RISK_COMMAND_PATTERNS),
            ("high", RiskLevel.HIGH, "files", HIGH_RISK_FILE_PATTERNS),
            ("medium", RiskLevel.MEDIUM, "commands", MEDIUM_RISK_COMMAND_PATTERNS),
            ("medium", RiskLevel.MEDIUM, "files", MEDIUM_RISK_FILE_PATTERNS),
            ("low", RiskLevel.LOW, "commands", LOW_RISK_COMMAND_PATTERNS),
            ("low", RiskLevel.LOW, "files", LOW_RISK_FILE_PATTERNS),
        )

        for prefix, base_level, kind, table in default_sets:
            home_bucket = self._compiled_patterns[f"{prefix}_{kind}"]
            for regex_src, desc in table:
                level = self._get_override_level(regex_src, base_level)
                # Without an override this resolves to the home bucket
                bucket = self._compiled_patterns.get(
                    f"{level.value}_{kind}", home_bucket
                )
                bucket.append((re.compile(regex_src, re.IGNORECASE), desc, level))

    def _get_override_level(self, pattern: str, default: RiskLevel) -> RiskLevel:
        """Get level for pattern, checking for overrides."""
        return self.config.level_overrides.get(pattern, default)

    def _add_compiled_pattern(self, pattern_type: str, pattern: RiskPattern) -> None:
        """Add a compiled pattern to the appropriate list."""
        key = f"{pattern.level.value}_{pattern_type}s"
        if key in self._compiled_patterns:
            compiled = re.compile(pattern.pattern, re.IGNORECASE)
            self._compiled_patterns[key].append(
                (compiled, pattern.description, pattern.level)
            )

    def classify_file(self, file_path: str) -> RiskClassification:
        """
        Determine the risk level for a file path.

        The blocked list is consulted first (CRITICAL), then the allowed
        list (LOW), then the compiled file patterns; the first match wins.
        If nothing matches, the configured default file level is returned.

        Args:
            file_path: Path to classify

        Returns:
            RiskClassification carrying the level, the reason and, when a
            pattern matched, that pattern's source text
        """
        # Blocked paths take priority over everything else
        blocked_hit = next(
            (rx for rx in self._blocked_paths if rx.search(file_path)), None
        )
        if blocked_hit is not None:
            return RiskClassification(
                level=RiskLevel.CRITICAL,
                reason="Path in blocked list",
                pattern_matched=blocked_hit.pattern,
            )

        # Allowed paths short-circuit the pattern buckets entirely
        allowed_hit = next(
            (rx for rx in self._allowed_paths if rx.search(file_path)), None
        )
        if allowed_hit is not None:
            return RiskClassification(
                level=RiskLevel.LOW,
                reason="Path in allowed list",
                pattern_matched=allowed_hit.pattern,
            )

        # Buckets are scanned in the order critical, high, low, medium.
        # NOTE(review): low is consulted before medium, so a low pattern wins
        # over a medium one matching the same path - confirm this is intended.
        ranked_entries = (
            entry
            for tier in ("critical", "high", "low", "medium")
            for entry in self._compiled_patterns[f"{tier}_files"]
        )
        for regex, description, risk in ranked_entries:
            if regex.search(file_path):
                return RiskClassification(
                    level=risk,
                    reason=description,
                    pattern_matched=regex.pattern,
                )

        # Nothing matched: fall back to the configured default
        return RiskClassification(
            level=self.config.default_file_level,
            reason="Unknown file type, using default level",
        )

    def classify_command(self, command: str) -> RiskClassification:
        """
        Determine the risk level for a command string.

        The compiled command patterns are scanned bucket by bucket and the
        first match wins. If nothing matches, the configured default
        command level is returned.

        Args:
            command: Command string to classify

        Returns:
            RiskClassification carrying the level, the reason and, when a
            pattern matched, that pattern's source text
        """
        # Buckets are scanned in the order critical, high, low, medium.
        # NOTE(review): low is consulted before medium, so a low pattern wins
        # over a medium one matching the same command - confirm this is intended.
        for tier in ("critical", "high", "low", "medium"):
            for regex, description, risk in self._compiled_patterns[tier + "_commands"]:
                if regex.search(command) is None:
                    continue
                return RiskClassification(
                    level=risk,
                    reason=description,
                    pattern_matched=regex.pattern,
                )

        # Nothing matched: fall back to the configured default
        return RiskClassification(
            level=self.config.default_command_level,
            reason="Unknown command, using default level",
        )

    def add_command_pattern(
        self,
        pattern: str,
        description: str,
        level: RiskLevel,
        category: str = "custom",
    ) -> None:
        """
        Add a custom command pattern at runtime.

        Args:
            pattern: Regex pattern
            description: Human-readable description
            level: Risk level for matching commands
            category: Category for organization
        """
        risk_pattern = RiskPattern(
            pattern=pattern,
            description=description,
            level=level,
            source="runtime",
            category=category,
        )
        self.config.command_patterns.append(risk_pattern)
        self._add_compiled_pattern("command", risk_pattern)
        logger.info(f"Added command pattern: {pattern} ({level.value})")

    def add_file_pattern(
        self,
        pattern: str,
        description: str,
        level: RiskLevel,
        category: str = "custom",
    ) -> None:
        """
        Add a custom file pattern at runtime.

        Args:
            pattern: Regex pattern
            description: Human-readable description
            level: Risk level for matching files
            category: Category for organization
        """
        risk_pattern = RiskPattern(
            pattern=pattern,
            description=description,
            level=level,
            source="runtime",
            category=category,
        )
        self.config.file_patterns.append(risk_pattern)
        self._add_compiled_pattern("file", risk_pattern)
        logger.info(f"Added file pattern: {pattern} ({level.value})")

    def override_level(self, pattern: str, new_level: RiskLevel) -> None:
        """
        Record a level override for a pattern and rebuild the compiled set.

        Overrides are looked up by exact pattern string when the default
        patterns are compiled.

        Args:
            pattern: Pattern string whose level should change
            new_level: Level to report for that pattern
        """
        overrides = self.config.level_overrides
        overrides[pattern] = new_level

        # Levels are fixed at compile time, so a full rebuild is needed
        self._compile_patterns()
        logger.info(f"Overrode pattern level: {pattern} -> {new_level.value}")

    def add_allowed_path(self, pattern: str) -> None:
        """Add a path pattern to the allowed list."""
        self.config.allowed_paths.append(pattern)
        self._allowed_paths.append(re.compile(pattern))
        logger.info(f"Added allowed path: {pattern}")

    def add_blocked_path(self, pattern: str) -> None:
        """Add a path pattern to the blocked list."""
        self.config.blocked_paths.append(pattern)
        self._blocked_paths.append(re.compile(pattern))
        logger.info(f"Added blocked path: {pattern}")

    def get_patterns_by_level(
        self, level: RiskLevel, pattern_type: str = "all"
    ) -> list[tuple[str, str]]:
        """
        Get all patterns of a specific level.

        Args:
            level: Risk level to filter by
            pattern_type: "command", "file", or "all"

        Returns:
            List of (pattern, description) tuples
        """
        results = []
        level_str = level.value

        if pattern_type in ("command", "all"):
            key = f"{level_str}_commands"
            for pattern, desc, _ in self._compiled_patterns.get(key, []):
                results.append((pattern.pattern, desc))

        if pattern_type in ("file", "all"):
            key = f"{level_str}_files"
            for pattern, desc, _ in self._compiled_patterns.get(key, []):
                results.append((pattern.pattern, desc))

        return results

    def get_pattern_count(self) -> dict[str, int]:
        """Get count of patterns by category."""
        counts = {}
        for key, patterns in self._compiled_patterns.items():
            counts[key] = len(patterns)
        counts["allowed_paths"] = len(self._allowed_paths)
        counts["blocked_paths"] = len(self._blocked_paths)
        return counts

    @classmethod
    def from_file(cls, path: Path | str) -> "ConfigurableRiskPolicy":
        """
        Load a risk policy from a JSON file.

        A missing file is not an error: a warning is logged and a policy
        with the default configuration is returned. Malformed JSON or an
        invalid configuration propagates as the underlying exception.

        Args:
            path: Path to the config file

        Returns:
            ConfigurableRiskPolicy instance
        """
        path = Path(path)
        try:
            # Open directly instead of checking existence first (no window
            # for the file to vanish in between), and decode as UTF-8 rather
            # than the platform's locale encoding.
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, NotADirectoryError):
            logger.warning(f"Config file not found: {path}, using defaults")
            return cls()

        config = RiskPolicyConfig.from_dict(data)
        return cls(config)

    def save_to_file(self, path: Path | str) -> None:
        """
        Save the current configuration to a JSON file.

        The configuration is serialized before the file is opened, so a
        serialization failure cannot truncate an existing config file.

        Args:
            path: Path to save the config file
        """
        path = Path(path)
        serialized = json.dumps(self.config.to_dict(), indent=2)
        # Explicit UTF-8 keeps the file readable by from_file on any platform
        with open(path, "w", encoding="utf-8") as f:
            f.write(serialized)
        logger.info(f"Saved risk policy config to {path}")

    @classmethod
    def create_template_config(cls) -> dict[str, Any]:
        """
        Create a template configuration for customization.

        Returns:
            Dictionary with example configuration
        """
        return {
            "include_defaults": True,
            "organization": "example-org",
            "default_file_level": "medium",
            "default_command_level": "medium",
            "allowed_paths": [
                "/tmp/safe-workspace/.*",
            ],
            "blocked_paths": [
                "/production/.*",
                ".*\\.prod\\..*",
            ],
            "command_patterns": [
                {
                    "pattern": r"deploy-to-staging",
                    "description": "Staging deployment",
                    "level": "high",
                    "enabled": True,
                    "category": "deployment",
                },
            ],
            "file_patterns": [
                {
                    "pattern": r"internal/secrets/.*",
                    "description": "Internal secrets directory",
                    "level": "critical",
                    "enabled": True,
                    "category": "security",
                },
            ],
            "level_overrides": {
                r"git push": "high",
            },
        }


# Module-level configurable policy instance shared by the accessor functions
# below. Starts as None; get_configurable_policy() creates it on first use and
# set_configurable_policy() replaces it.
_configurable_policy: Optional[ConfigurableRiskPolicy] = None


def get_configurable_policy() -> ConfigurableRiskPolicy:
    """Return the global policy, creating a default one on first use."""
    global _configurable_policy
    policy = _configurable_policy
    if policy is None:
        # First access: build a policy with the default configuration
        policy = ConfigurableRiskPolicy()
        _configurable_policy = policy
    return policy


def set_configurable_policy(policy: ConfigurableRiskPolicy) -> None:
    """Replace the global policy instance with ``policy``.

    Later calls to ``get_configurable_policy`` return this object.
    """
    global _configurable_policy
    _configurable_policy = policy


def load_policy_from_file(path: Path | str) -> ConfigurableRiskPolicy:
    """Build a policy from the config file at ``path`` and install it globally.

    Returns:
        The newly loaded policy (the same object that was installed)
    """
    loaded = ConfigurableRiskPolicy.from_file(path)
    set_configurable_policy(loaded)
    return loaded
