"""
Tests for Subscription Tier Tracking functionality.
"""

import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch

from agent_orchestrator.subscriptions import (
    Provider,
    SubscriptionTier,
    TierLimits,
    SubscriptionConfig,
    TIER_CONFIGURATIONS,
    get_tier_limits,
    get_provider_for_tier,
    get_default_tier_for_provider,
    list_tiers_for_provider,
    AgentSubscription,
    SubscriptionManager,
    get_subscription_manager,
    set_subscription_manager,
    TierDetector,
    detect_tier_from_error,
    get_tier_detector,
)
from agent_orchestrator.tracking import RateLimitState, WindowType


class TestProvider:
    """Verify the string value backing each Provider member."""

    def test_provider_values(self):
        """Each provider member maps to its lowercase vendor name."""
        expected_values = {
            Provider.ANTHROPIC: "anthropic",
            Provider.OPENAI: "openai",
            Provider.GOOGLE: "google",
        }
        for member, text in expected_values.items():
            assert member.value == text


class TestSubscriptionTier:
    """Check the serialized string value of every SubscriptionTier member."""

    @staticmethod
    def _check_values(pairs):
        """Assert that every (member, expected_value) pair matches."""
        for member, wanted in pairs:
            assert member.value == wanted

    def test_claude_tiers(self):
        """Claude tiers serialize with a ``claude_`` prefix."""
        self._check_values([
            (SubscriptionTier.CLAUDE_FREE, "claude_free"),
            (SubscriptionTier.CLAUDE_PRO, "claude_pro"),
            (SubscriptionTier.CLAUDE_MAX, "claude_max"),
        ])

    def test_chatgpt_tiers(self):
        """ChatGPT tiers serialize with a ``chatgpt_`` prefix."""
        self._check_values([
            (SubscriptionTier.CHATGPT_FREE, "chatgpt_free"),
            (SubscriptionTier.CHATGPT_PLUS, "chatgpt_plus"),
            (SubscriptionTier.CHATGPT_PRO, "chatgpt_pro"),
        ])

    def test_gemini_tiers(self):
        """Gemini tiers serialize with a ``gemini_`` prefix."""
        self._check_values([
            (SubscriptionTier.GEMINI_FREE, "gemini_free"),
            (SubscriptionTier.GEMINI_PRO, "gemini_pro"),
            (SubscriptionTier.GEMINI_ULTRA, "gemini_ultra"),
        ])


class TestTierLimits:
    """Exercise construction, derived rate, and serialization of TierLimits."""

    def test_create_tier_limits(self):
        """Explicitly supplied fields are stored unchanged."""
        tl = TierLimits(
            messages_per_window=100,
            window_hours=5,
            context_window=128000,
            max_output_tokens=4096,
        )
        observed = (tl.messages_per_window, tl.window_hours, tl.context_window)
        assert observed == (100, 5, 128000)

    def test_messages_per_hour(self):
        """100 messages over a 5-hour window averages to 20 per hour."""
        rate = TierLimits(messages_per_window=100, window_hours=5).messages_per_hour
        assert rate == 20.0

    def test_messages_per_hour_zero_window(self):
        """A zero-length window reports an infinite hourly rate."""
        rate = TierLimits(messages_per_window=100, window_hours=0).messages_per_hour
        assert rate == float("inf")

    def test_to_dict(self):
        """to_dict exposes the configured fields as plain values."""
        serialized = TierLimits(
            messages_per_window=100,
            window_hours=5,
            supports_computer_use=True,
            available_models=["model-1", "model-2"],
        ).to_dict()
        assert serialized["messages_per_window"] == 100
        assert serialized["supports_computer_use"] is True
        assert len(serialized["available_models"]) == 2


class TestSubscriptionConfig:
    """Exercise SubscriptionConfig construction and serialization."""

    @staticmethod
    def _pro_config(**extra):
        """Build a CLAUDE_PRO / ANTHROPIC config limited to 100 messages per 5h."""
        return SubscriptionConfig(
            tier=SubscriptionTier.CLAUDE_PRO,
            provider=Provider.ANTHROPIC,
            limits=TierLimits(messages_per_window=100, window_hours=5),
            **extra,
        )

    def test_create_config(self):
        """Tier, provider and account e-mail are stored as given."""
        cfg = self._pro_config(account_email="test@example.com")
        assert cfg.tier == SubscriptionTier.CLAUDE_PRO
        assert cfg.provider == Provider.ANTHROPIC
        assert cfg.account_email == "test@example.com"

    def test_to_dict(self):
        """Enum fields serialize to their string values."""
        payload = self._pro_config().to_dict()
        assert payload["tier"] == "claude_pro"
        assert payload["provider"] == "anthropic"


class TestTierConfigurations:
    """Spot-check the static TIER_CONFIGURATIONS table."""

    def test_all_tiers_configured(self):
        """Every tier except UNKNOWN has an entry in the table."""
        unconfigured = [
            tier
            for tier in SubscriptionTier
            if tier not in TIER_CONFIGURATIONS and tier != SubscriptionTier.UNKNOWN
        ]
        assert unconfigured == []

    def test_claude_free_limits(self):
        """Claude Free allows 5 messages per 5-hour window."""
        cfg = TIER_CONFIGURATIONS[SubscriptionTier.CLAUDE_FREE]
        assert (cfg.messages_per_window, cfg.window_hours) == (5, 5)

    def test_claude_max_limits(self):
        """Claude Max allows 225 messages and enables computer use."""
        cfg = TIER_CONFIGURATIONS[SubscriptionTier.CLAUDE_MAX]
        assert cfg.messages_per_window == 225
        assert cfg.supports_computer_use is True

    def test_chatgpt_pro_limits(self):
        """ChatGPT Pro allows 200 messages and includes o1-pro."""
        cfg = TIER_CONFIGURATIONS[SubscriptionTier.CHATGPT_PRO]
        assert cfg.messages_per_window == 200
        assert "o1-pro" in cfg.available_models

    def test_gemini_ultra_limits(self):
        """Gemini Ultra has a 2M-token context and includes gemini-ultra."""
        cfg = TIER_CONFIGURATIONS[SubscriptionTier.GEMINI_ULTRA]
        assert cfg.context_window == 2000000
        assert "gemini-ultra" in cfg.available_models


class TestTierFunctions:
    """Exercise the module-level tier lookup helpers."""

    def test_get_tier_limits(self):
        """Claude Pro resolves to a 45-message window."""
        assert get_tier_limits(SubscriptionTier.CLAUDE_PRO).messages_per_window == 45

    def test_get_tier_limits_unknown(self):
        """The UNKNOWN tier resolves to the 10-message fallback."""
        assert get_tier_limits(SubscriptionTier.UNKNOWN).messages_per_window == 10

    def test_get_provider_for_tier(self):
        """Each tier maps back to the vendor that sells it."""
        expectations = (
            (SubscriptionTier.CLAUDE_PRO, Provider.ANTHROPIC),
            (SubscriptionTier.CHATGPT_PLUS, Provider.OPENAI),
            (SubscriptionTier.GEMINI_PRO, Provider.GOOGLE),
        )
        for tier, vendor in expectations:
            assert get_provider_for_tier(tier) == vendor

    def test_get_default_tier_for_provider(self):
        """Each provider defaults to its mid-level paid tier."""
        expectations = (
            (Provider.ANTHROPIC, SubscriptionTier.CLAUDE_PRO),
            (Provider.OPENAI, SubscriptionTier.CHATGPT_PLUS),
            (Provider.GOOGLE, SubscriptionTier.GEMINI_PRO),
        )
        for vendor, tier in expectations:
            assert get_default_tier_for_provider(vendor) == tier

    def test_list_tiers_for_provider(self):
        """Anthropic's tier list contains only Claude tiers."""
        anthropic_tiers = list_tiers_for_provider(Provider.ANTHROPIC)
        for wanted in (
            SubscriptionTier.CLAUDE_FREE,
            SubscriptionTier.CLAUDE_PRO,
            SubscriptionTier.CLAUDE_MAX,
        ):
            assert wanted in anthropic_tiers
        assert SubscriptionTier.CHATGPT_PLUS not in anthropic_tiers


class TestAgentSubscription:
    """Exercise the AgentSubscription dataclass."""

    @staticmethod
    def _make_subscription():
        """Build an unverified CLAUDE_PRO subscription for agent ``claude-1``."""
        pro_config = SubscriptionConfig(
            tier=SubscriptionTier.CLAUDE_PRO,
            provider=Provider.ANTHROPIC,
            limits=TierLimits(messages_per_window=100, window_hours=5),
        )
        return AgentSubscription(agent_id="claude-1", config=pro_config)

    def test_create_subscription(self):
        """New subscriptions keep their agent id and start unverified."""
        subscription = self._make_subscription()
        assert subscription.agent_id == "claude-1"
        assert subscription.is_verified is False

    def test_to_dict(self):
        """Serialization carries the agent id and verification flag."""
        payload = self._make_subscription().to_dict()
        assert payload["agent_id"] == "claude-1"
        assert payload["is_verified"] is False


class TestSubscriptionManager:
    """Tests for SubscriptionManager."""

    def setup_method(self):
        """Reset manager before each test."""
        set_subscription_manager(None)

    def teardown_method(self):
        """Reset the global manager after each test.

        ``test_global_manager_instance`` lazily creates a process-wide manager;
        clearing it here keeps that state from leaking into later tests or
        other test modules.
        """
        set_subscription_manager(None)

    def test_register_subscription(self):
        """Test registering a subscription."""
        manager = SubscriptionManager()
        sub = manager.register(
            agent_id="claude-1",
            tier=SubscriptionTier.CLAUDE_PRO,
            account_email="test@example.com",
        )
        assert sub.agent_id == "claude-1"
        assert sub.config.tier == SubscriptionTier.CLAUDE_PRO

    def test_get_subscription(self):
        """Test getting a subscription."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        sub = manager.get_subscription("claude-1")
        assert sub is not None
        assert sub.agent_id == "claude-1"

    def test_get_subscription_not_found(self):
        """Test getting non-existent subscription."""
        manager = SubscriptionManager()
        sub = manager.get_subscription("unknown")
        assert sub is None

    def test_update_tier(self):
        """Test updating subscription tier."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        sub = manager.update_tier("claude-1", SubscriptionTier.CLAUDE_MAX)
        assert sub is not None
        assert sub.config.tier == SubscriptionTier.CLAUDE_MAX

    def test_update_tier_not_found(self):
        """Test updating tier for non-existent subscription."""
        manager = SubscriptionManager()
        sub = manager.update_tier("unknown", SubscriptionTier.CLAUDE_MAX)
        assert sub is None

    def test_unregister(self):
        """Test unregistering a subscription."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        result = manager.unregister("claude-1")
        assert result is True
        assert manager.get_subscription("claude-1") is None

    def test_unregister_not_found(self):
        """Test unregistering non-existent subscription."""
        manager = SubscriptionManager()
        result = manager.unregister("unknown")
        assert result is False

    def test_can_make_request(self):
        """Test can_make_request method (CLAUDE_PRO allows 45 per window)."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        # Can make request with usage below limit
        assert manager.can_make_request("claude-1", current_usage=10) is True

        # Boundary: one below the limit must still be allowed
        assert manager.can_make_request("claude-1", current_usage=44) is True

        # Cannot make request at limit
        assert manager.can_make_request("claude-1", current_usage=45) is False

    def test_can_make_request_no_limits(self):
        """Test can_make_request with no subscription."""
        manager = SubscriptionManager()
        # No limits configured, should allow
        assert manager.can_make_request("unknown", current_usage=1000) is True

    def test_get_remaining_requests(self):
        """Test get_remaining_requests method."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        remaining = manager.get_remaining_requests("claude-1", current_usage=10)
        assert remaining == 35  # 45 - 10

    def test_list_agents(self):
        """Test listing all agents."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)
        manager.register("claude-2", SubscriptionTier.CLAUDE_MAX)

        agents = manager.list_agents()
        assert "claude-1" in agents
        assert "claude-2" in agents

    def test_list_agents_by_tier(self):
        """Test listing agents by tier."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)
        manager.register("claude-2", SubscriptionTier.CLAUDE_MAX)
        manager.register("claude-3", SubscriptionTier.CLAUDE_PRO)

        pro_agents = manager.list_agents_by_tier(SubscriptionTier.CLAUDE_PRO)
        assert "claude-1" in pro_agents
        assert "claude-3" in pro_agents
        assert "claude-2" not in pro_agents

    def test_list_agents_by_provider(self):
        """Test listing agents by provider."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)
        manager.register("codex-1", SubscriptionTier.CHATGPT_PLUS)

        anthropic_agents = manager.list_agents_by_provider(Provider.ANTHROPIC)
        assert "claude-1" in anthropic_agents
        assert "codex-1" not in anthropic_agents

    def test_mark_verified(self):
        """Test marking subscription as verified."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        result = manager.mark_verified("claude-1")
        assert result is True

        sub = manager.get_subscription("claude-1")
        # Fail with a clear assertion rather than an AttributeError on None.
        assert sub is not None
        assert sub.is_verified is True
        assert sub.last_verified_at is not None

    def test_on_change_callback(self):
        """Test subscription change callback."""
        callback_calls = []

        def callback(agent_id, subscription):
            callback_calls.append((agent_id, subscription))

        manager = SubscriptionManager()
        manager.on_change(callback)
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)

        assert len(callback_calls) == 1
        assert callback_calls[0][0] == "claude-1"

    def test_get_summary(self):
        """Test getting subscription summary."""
        manager = SubscriptionManager()
        manager.register("claude-1", SubscriptionTier.CLAUDE_PRO)
        manager.register("claude-2", SubscriptionTier.CLAUDE_MAX)
        manager.mark_verified("claude-1")

        summary = manager.get_summary()
        assert summary["total_agents"] == 2
        assert summary["verified_count"] == 1
        assert "claude_pro" in summary["by_tier"]

    def test_global_manager_instance(self):
        """Test global manager instance management."""
        manager1 = get_subscription_manager()
        manager2 = get_subscription_manager()
        assert manager1 is not None
        assert manager1 is manager2


class TestTierDetector:
    """Tests for TierDetector."""

    @staticmethod
    def _fresh_rate_limit(requests_limit, window_hours, window_type):
        """Return an unused RateLimitState that resets ``window_hours`` from now.

        Centralizes the fixture shared by the rate-limit tests. It also replaces
        the per-test ``from datetime import datetime, timedelta``, which shadowed
        the module-level ``datetime`` import.
        """
        from datetime import timedelta

        return RateLimitState(
            requests_used=0,
            requests_limit=requests_limit,
            reset_at=datetime.now() + timedelta(hours=window_hours),
            window_type=window_type,
        )

    def test_detect_claude_tier_from_rate_limit(self):
        """Test detecting Claude tier from rate limit."""
        detector = TierDetector()
        rate_limit = self._fresh_rate_limit(225, 5, WindowType.FIVE_HOUR)

        tier, confidence = detector.detect_tier("claude", rate_limit=rate_limit)
        assert tier == SubscriptionTier.CLAUDE_MAX
        assert confidence >= 0.9

    def test_detect_claude_free_from_rate_limit(self):
        """Test detecting Claude Free from rate limit."""
        detector = TierDetector()
        rate_limit = self._fresh_rate_limit(5, 5, WindowType.FIVE_HOUR)

        tier, confidence = detector.detect_tier("claude", rate_limit=rate_limit)
        assert tier == SubscriptionTier.CLAUDE_FREE
        assert confidence >= 0.8

    def test_detect_claude_from_models(self):
        """Test detecting Claude tier from available models."""
        detector = TierDetector()
        tier, confidence = detector.detect_tier(
            "claude",
            available_models=["claude-opus-4", "claude-sonnet-4"],
        )
        assert tier in [SubscriptionTier.CLAUDE_PRO, SubscriptionTier.CLAUDE_MAX]
        assert confidence >= 0.7

    def test_detect_chatgpt_tier(self):
        """Test detecting ChatGPT tier."""
        detector = TierDetector()
        rate_limit = self._fresh_rate_limit(200, 3, WindowType.THREE_HOUR)

        tier, confidence = detector.detect_tier("chatgpt", rate_limit=rate_limit)
        assert tier == SubscriptionTier.CHATGPT_PRO
        assert confidence >= 0.9

    def test_detect_gemini_tier(self):
        """Test detecting Gemini tier."""
        detector = TierDetector()
        # 24 hours == the original timedelta(days=1) daily window.
        rate_limit = self._fresh_rate_limit(2000, 24, WindowType.DAILY)

        tier, confidence = detector.detect_tier("gemini", rate_limit=rate_limit)
        assert tier == SubscriptionTier.GEMINI_ULTRA
        assert confidence >= 0.9

    def test_detect_unknown_agent(self):
        """Test detecting tier for unknown agent."""
        detector = TierDetector()
        tier, confidence = detector.detect_tier("unknown_agent")
        assert tier == SubscriptionTier.UNKNOWN
        assert confidence == 0.0

    def test_cache_detection(self):
        """Test caching detection results."""
        detector = TierDetector()
        detector.cache_detection("claude-1", SubscriptionTier.CLAUDE_PRO, 0.9)

        cached = detector.get_cached_tier("claude-1")
        assert cached == SubscriptionTier.CLAUDE_PRO

    def test_clear_cache(self):
        """Test clearing detection cache."""
        detector = TierDetector()
        detector.cache_detection("claude-1", SubscriptionTier.CLAUDE_PRO, 0.9)
        detector.clear_cache("claude-1")

        cached = detector.get_cached_tier("claude-1")
        assert cached is None


class TestDetectTierFromError:
    """Exercise tier inference from provider error messages."""

    def test_detect_free_from_upgrade_prompt(self):
        """An upgrade prompt that mentions Claude points to Claude Free."""
        # The message contains "claude"; that is what should steer detection
        # toward a Claude tier rather than a ChatGPT one.
        message = "Claude: Please upgrade to Pro for more features"
        assert detect_tier_from_error(message) == SubscriptionTier.CLAUDE_FREE

    def test_detect_from_model_access_error(self):
        """Lacking Opus access implies Claude Free."""
        message = "opus not available in your plan"
        assert detect_tier_from_error(message) == SubscriptionTier.CLAUDE_FREE

    def test_detect_chatgpt_plus_from_model_error(self):
        """An o1-pro denial for Plus users implies ChatGPT Plus."""
        message = "o1-pro not available for Plus users"
        assert detect_tier_from_error(message) == SubscriptionTier.CHATGPT_PLUS

    def test_unknown_error_message(self):
        """Unrecognized messages yield no tier."""
        assert detect_tier_from_error("Some random error") is None


class TestGlobalTierDetector:
    """Check the process-wide TierDetector accessor."""

    def test_get_tier_detector(self):
        """Repeated calls hand back the very same detector object."""
        first = get_tier_detector()
        second = get_tier_detector()
        assert second is first
