"""
WorkflowEngine - Executes JSON-defined workflows.

Provides:
- Sequential and parallel step execution
- Conditional branching evaluation
- Variable interpolation
- Integration with task system and voting
"""

import asyncio
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Optional

from .models import (
    Workflow,
    WorkflowStep,
    StepType,
    StepStatus,
    WorkflowStatus,
)

logger = logging.getLogger(__name__)


@dataclass
class WorkflowContext:
    """Mutable runtime state shared by the steps of one workflow execution.

    Holds the workflow's variables, a per-step result record, the list of
    error messages collected so far and the start/end timestamps.
    """

    workflow_id: str
    # Variables readable through ``get``/``interpolate`` and writable via ``set``.
    variables: dict[str, Any] = field(default_factory=dict)
    # step id -> {"status", "result", "error", "completed_at"} (see set_step_result).
    step_results: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Human-readable error messages appended by the engine.
    errors: list[str] = field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Matches ``${path}`` placeholders. Deliberately un-annotated so that
    # ``@dataclass`` does not treat it as an instance field.
    _PLACEHOLDER = re.compile(r"\$\{([^}]+)\}")

    @staticmethod
    def _walk(value: Any, keys: list[str], default: Any) -> Any:
        """Follow ``keys`` through nested dicts starting at ``value``.

        Returns ``default`` as soon as a non-dict is reached with keys left
        over, or when the final value is ``None``.
        """
        for key in keys:
            if not isinstance(value, dict):
                return default
            value = value.get(key)
        return default if value is None else value

    def get(self, path: str, default: Any = None) -> Any:
        """Look up a value using dot notation.

        Resolution order:

        1. ``steps.<step_id>[.<key>...]`` - descends into the recorded result
           of a known step.
        2. ``vars.<name>[.<key>...]`` - reads a variable and, when further
           segments are given, descends into nested dicts.
        3. ``<path>`` as an exact variable name.
        4. ``<name>.<key>...`` - descends into the variable ``<name>``.

        Args:
            path: Dotted lookup path.
            default: Value returned when nothing is found.

        Returns:
            The resolved value, or ``default``.
        """
        parts = path.split(".")
        head, rest = parts[0], parts[1:]

        if head == "steps" and rest and rest[0] in self.step_results:
            return self._walk(self.step_results[rest[0]], rest[1:], default)

        if head == "vars" and rest:
            name = rest[0]
            if len(rest) == 1:
                return self.variables.get(name, default)
            if name not in self.variables:
                return default
            # Previously the segments after the variable name were ignored and
            # the whole variable was returned; honour the full path instead.
            return self._walk(self.variables[name], rest[1:], default)

        # An exact variable name wins, so keys that themselves contain dots
        # keep resolving the way they always did.
        if path in self.variables:
            return self.variables[path]
        if rest and head in self.variables:
            return self._walk(self.variables[head], rest, default)
        return default

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the workflow variables."""
        self.variables[key] = value

    def set_step_result(
        self,
        step_id: str,
        status: str,
        result: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """Record (or overwrite) the outcome of a step.

        Args:
            step_id: Identifier the record is stored under.
            status: Status string to record.
            result: Result payload; a falsy value is stored as ``{}``.
            error: Error message, if any.
        """
        self.step_results[step_id] = {
            "status": status,
            "result": result or {},
            "error": error,
            "completed_at": datetime.now().isoformat(),
        }

    def interpolate(self, text: str) -> str:
        """Replace ``${path}`` placeholders in ``text`` with context values.

        Each placeholder is resolved with :meth:`get`; placeholders that do
        not resolve are left in the text unchanged. A non-string ``text`` is
        returned as-is.
        """
        if not isinstance(text, str):
            return text

        def substitute(match: re.Match) -> str:
            placeholder = match.group(0)
            value = self.get(match.group(1), placeholder)
            return placeholder if value is None else str(value)

        return self._PLACEHOLDER.sub(substitute, text)


class WorkflowEngine:
    """
    Executes workflows defined in JSON/YAML.

    Handles step execution, parallelism, conditions, and error handling.
    """

    def __init__(
        self,
        task_executor: Optional[Callable] = None,
        approval_handler: Optional[Callable] = None,
        vote_handler: Optional[Callable] = None,
    ) -> None:
        """
        Initialize the workflow engine.

        Args:
            task_executor: Async function to execute tasks
            approval_handler: Async function to handle approvals
            vote_handler: Async function to handle voting
        """
        self.task_executor = task_executor
        self.approval_handler = approval_handler
        self.vote_handler = vote_handler
        self._running_workflows: dict[str, WorkflowContext] = {}

    async def execute(
        self,
        workflow: Workflow,
        initial_variables: Optional[dict[str, Any]] = None,
    ) -> WorkflowContext:
        """
        Execute a workflow.

        Args:
            workflow: The workflow to execute
            initial_variables: Additional variables to merge

        Returns:
            WorkflowContext with execution results
        """
        # Create context
        context = WorkflowContext(
            workflow_id=workflow.id,
            variables={**workflow.variables, **(initial_variables or {})},
            started_at=datetime.now(),
        )

        self._running_workflows[workflow.id] = context

        # Update workflow status
        workflow.status = WorkflowStatus.RUNNING.value
        workflow.started_at = datetime.now()

        logger.info(f"Starting workflow {workflow.id}: {workflow.name}")

        try:
            # Execute steps
            await self._execute_steps(workflow, context)

            # Determine final status
            if workflow.has_failures():
                workflow.status = WorkflowStatus.FAILED.value
                workflow.error = "One or more steps failed"
            else:
                workflow.status = WorkflowStatus.COMPLETED.value

        except asyncio.CancelledError:
            workflow.status = WorkflowStatus.CANCELLED.value
            logger.info(f"Workflow {workflow.id} cancelled")
        except Exception as e:
            workflow.status = WorkflowStatus.FAILED.value
            workflow.error = str(e)
            context.errors.append(str(e))
            logger.error(f"Workflow {workflow.id} failed: {e}")
        finally:
            workflow.completed_at = datetime.now()
            context.completed_at = datetime.now()
            del self._running_workflows[workflow.id]

        logger.info(
            f"Workflow {workflow.id} finished with status: {workflow.status}"
        )

        return context

    async def _execute_steps(
        self,
        workflow: Workflow,
        context: WorkflowContext,
    ) -> None:
        """Execute all steps in the workflow."""
        max_iterations = len(workflow.steps) * 3  # Safety limit
        iterations = 0

        while not workflow.is_complete() and iterations < max_iterations:
            iterations += 1

            # Get ready steps
            ready_steps = workflow.get_ready_steps()
            if not ready_steps:
                # Check if we're blocked
                pending = [s for s in workflow.steps if s.status == "pending"]
                if pending:
                    logger.warning(
                        f"Workflow {workflow.id} blocked: "
                        f"{len(pending)} steps pending with unmet dependencies"
                    )
                    break
                continue

            # Execute steps (respecting concurrency limit)
            batch = ready_steps[: workflow.max_concurrent_steps]
            await self._execute_step_batch(batch, workflow, context)

            # Check for failure handling
            if workflow.has_failures() and workflow.on_failure == "stop":
                logger.info(f"Workflow {workflow.id} stopping due to failure")
                break

    async def _execute_step_batch(
        self,
        steps: list[WorkflowStep],
        workflow: Workflow,
        context: WorkflowContext,
    ) -> None:
        """Execute a batch of steps, potentially in parallel."""
        if len(steps) == 1:
            await self._execute_step(steps[0], workflow, context)
        else:
            # Execute in parallel
            tasks = [
                self._execute_step(step, workflow, context)
                for step in steps
            ]
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _execute_step(
        self,
        step: WorkflowStep,
        workflow: Workflow,
        context: WorkflowContext,
    ) -> None:
        """Execute a single step."""
        logger.info(f"Executing step {step.id}: {step.name}")

        step.status = StepStatus.RUNNING.value
        step.started_at = datetime.now()
        workflow.current_step_id = step.id

        try:
            if step.step_type == StepType.TASK.value:
                await self._execute_task_step(step, context)
            elif step.step_type == StepType.PARALLEL.value:
                await self._execute_parallel_step(step, workflow, context)
            elif step.step_type == StepType.CONDITIONAL.value:
                await self._execute_conditional_step(step, workflow, context)
            elif step.step_type == StepType.WAIT.value:
                await self._execute_wait_step(step, context)
            elif step.step_type == StepType.APPROVAL.value:
                await self._execute_approval_step(step, context)
            elif step.step_type == StepType.VOTE.value:
                await self._execute_vote_step(step, context)
            else:
                raise ValueError(f"Unknown step type: {step.step_type}")

            step.status = StepStatus.COMPLETED.value
            logger.info(f"Step {step.id} completed successfully")

        except Exception as e:
            step.status = StepStatus.FAILED.value
            step.error = str(e)
            context.errors.append(f"Step {step.id}: {e}")
            logger.error(f"Step {step.id} failed: {e}")

            # Handle retry
            if step.retry_count < step.max_retries:
                step.retry_count += 1
                step.status = StepStatus.PENDING.value
                logger.info(
                    f"Retrying step {step.id} "
                    f"(attempt {step.retry_count}/{step.max_retries})"
                )
                await asyncio.sleep(step.retry_delay_seconds)
        finally:
            step.completed_at = datetime.now()

            # Record result in context
            context.set_step_result(
                step.id,
                step.status,
                step.result,
                step.error,
            )

    async def _execute_task_step(
        self,
        step: WorkflowStep,
        context: WorkflowContext,
    ) -> None:
        """Execute a task step."""
        if not self.task_executor:
            # Simulate task execution
            logger.info(f"Simulating task: {step.task_description}")
            step.result = {"simulated": True}
            return

        # Interpolate task description
        description = context.interpolate(step.task_description or "")

        # Execute via callback
        result = await self.task_executor(
            task_type=step.task_type,
            description=description,
            agent_id=step.agent_id,
            timeout_minutes=step.timeout_minutes,
        )

        step.result = result

        # Extract outputs
        if step.outputs and result:
            for name, path in step.outputs.items():
                value = _extract_value(result, path)
                context.set(name, value)

    async def _execute_parallel_step(
        self,
        step: WorkflowStep,
        workflow: Workflow,
        context: WorkflowContext,
    ) -> None:
        """Execute parallel sub-steps."""
        if not step.parallel_steps:
            return

        tasks = [
            self._execute_step(sub_step, workflow, context)
            for sub_step in step.parallel_steps
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Check for failures
        failures = [
            r for r in results
            if isinstance(r, Exception)
        ]
        if failures:
            raise Exception(
                f"{len(failures)} parallel steps failed: {failures[0]}"
            )

        step.result = {
            "parallel_results": [
                {"id": s.id, "status": s.status}
                for s in step.parallel_steps
            ]
        }

    async def _execute_conditional_step(
        self,
        step: WorkflowStep,
        workflow: Workflow,
        context: WorkflowContext,
    ) -> None:
        """Execute a conditional step."""
        if not step.condition:
            raise ValueError("Conditional step requires a condition")

        # Build evaluation context
        eval_context = {
            "vars": context.variables,
            "steps": context.step_results,
        }

        # Evaluate condition
        result = step.condition.evaluate(eval_context)
        logger.info(
            f"Condition '{step.condition.variable} {step.condition.operator} "
            f"{step.condition.value}' evaluated to {result}"
        )

        # Execute appropriate branch
        if result and step.on_true:
            await self._execute_step(step.on_true, workflow, context)
        elif not result and step.on_false:
            await self._execute_step(step.on_false, workflow, context)

        step.result = {"condition_result": result}

    async def _execute_wait_step(
        self,
        step: WorkflowStep,
        context: WorkflowContext,
    ) -> None:
        """Execute a wait step."""
        if not step.wait_for:
            raise ValueError("Wait step requires 'wait_for' variable")

        timeout = step.wait_timeout_minutes * 60
        start = datetime.now()

        while True:
            value = context.get(step.wait_for)
            if value is not None:
                step.result = {"waited_for": step.wait_for, "value": value}
                return

            elapsed = (datetime.now() - start).total_seconds()
            if elapsed >= timeout:
                raise TimeoutError(
                    f"Wait for '{step.wait_for}' timed out after "
                    f"{step.wait_timeout_minutes} minutes"
                )

            await asyncio.sleep(1)

    async def _execute_approval_step(
        self,
        step: WorkflowStep,
        context: WorkflowContext,
    ) -> None:
        """Execute an approval step."""
        if not self.approval_handler:
            logger.info(f"Auto-approving: {step.approval_message}")
            step.result = {"approved": True, "auto": True}
            return

        message = context.interpolate(step.approval_message or "Approval needed")

        result = await self.approval_handler(
            message=message,
            approvers=step.approvers,
        )

        step.result = result
        if not result.get("approved"):
            raise Exception("Approval rejected")

    async def _execute_vote_step(
        self,
        step: WorkflowStep,
        context: WorkflowContext,
    ) -> None:
        """Execute a voting step."""
        if not self.vote_handler:
            # Simulate voting
            winner = step.vote_options[0] if step.vote_options else None
            step.result = {"winner": winner, "simulated": True}
            return

        topic = context.interpolate(step.vote_topic or "Vote required")

        result = await self.vote_handler(
            topic=topic,
            options=step.vote_options,
            quorum=step.vote_quorum,
        )

        step.result = result

        if not result.get("winner"):
            raise Exception("Voting did not reach consensus")

    def get_running_workflows(self) -> list[str]:
        """Get IDs of currently running workflows."""
        return list(self._running_workflows.keys())

    def get_context(self, workflow_id: str) -> Optional[WorkflowContext]:
        """Get the context for a running workflow."""
        return self._running_workflows.get(workflow_id)

    async def cancel(self, workflow_id: str) -> bool:
        """Cancel a running workflow."""
        if workflow_id not in self._running_workflows:
            return False

        # The workflow will handle cancellation in its exception handler
        logger.info(f"Cancellation requested for workflow {workflow_id}")
        return True

    def validate_workflow(self, workflow: Workflow) -> list[str]:
        """
        Validate a workflow definition.

        Returns list of validation errors (empty if valid).
        """
        errors = []

        if not workflow.id:
            errors.append("Workflow must have an ID")
        if not workflow.name:
            errors.append("Workflow must have a name")
        if not workflow.steps:
            errors.append("Workflow must have at least one step")

        # Check for duplicate step IDs
        step_ids = set()
        for step in workflow.steps:
            if step.id in step_ids:
                errors.append(f"Duplicate step ID: {step.id}")
            step_ids.add(step.id)

            # Validate step dependencies
            for dep in step.depends_on:
                if dep not in step_ids and dep != step.id:
                    # Check if it's defined later
                    later_ids = {s.id for s in workflow.steps}
                    if dep not in later_ids:
                        errors.append(
                            f"Step {step.id} depends on unknown step: {dep}"
                        )

        # Check for circular dependencies
        circular = self._check_circular_dependencies(workflow.steps)
        if circular:
            errors.append(f"Circular dependency detected: {circular}")

        return errors

    def _check_circular_dependencies(
        self, steps: list[WorkflowStep]
    ) -> Optional[str]:
        """Check for circular dependencies in steps."""
        # Build dependency graph
        graph = {s.id: set(s.depends_on) for s in steps}

        # Topological sort to detect cycles
        visited = set()
        rec_stack = set()
        path = []

        def visit(node: str) -> Optional[str]:
            if node in rec_stack:
                # Found cycle
                cycle_start = path.index(node)
                cycle = " -> ".join(path[cycle_start:] + [node])
                return cycle
            if node in visited:
                return None

            visited.add(node)
            rec_stack.add(node)
            path.append(node)

            for dep in graph.get(node, []):
                result = visit(dep)
                if result:
                    return result

            path.pop()
            rec_stack.remove(node)
            return None

        for step_id in graph:
            result = visit(step_id)
            if result:
                return result

        return None


def _extract_value(data: dict[str, Any], path: str) -> Any:
    """Pull a nested value out of ``data`` by a dotted / bracketed path.

    Brackets are treated like dots, so ``"items[0].name"`` and
    ``"items.0.name"`` are equivalent. Dict levels are looked up by key and
    list levels by non-negative integer index. Returns None when ``path`` or
    ``data`` is empty, when a segment cannot be resolved, or when a None
    value is met along the way.
    """
    if not path or not data:
        return None

    # "a[0].b" -> "a.0.b"; empty segments (e.g. from "a..b") are skipped.
    normalized = path.replace("[", ".").replace("]", "")

    node = data
    for token in filter(None, normalized.split(".")):
        if isinstance(node, dict):
            node = node.get(token)
        elif isinstance(node, list):
            try:
                position = int(token)
            except ValueError:
                return None
            node = node[position] if 0 <= position < len(node) else None
        else:
            # Scalars cannot be descended into.
            return None

        if node is None:
            return None

    return node
