"""
Model cascades: parallel batches, key rotation, cross-provider race, JSON validation.
"""

import asyncio
import logging
from collections.abc import Callable
from typing import Any

import httpx

from app.providers import providers
from app.utils.constants import (
    ENABLE_CROSS_PROVIDER_RACE,
    EXTRACTION_PROVIDER_ORDER,
    EXTRACTION_TIMEOUT,
    FAST_MODE_MAX_BATCHES,
    FAST_MODE_MAX_KEYS,
    INTERPRETATION_PROVIDER_ORDER,
    INTERPRETATION_TIMEOUT,
    PARALLEL_MODEL_COUNT,
    PROVIDER_CASCADES,
    RACE_MAX_KEYS,
    RATE_LIMIT_RETRY_DELAY,
)
from app.utils.helpers import (
    is_gemma_model,
    try_parse_json,
    validate_extraction_json,
    validate_interpretation_json,
)

logger = logging.getLogger("lab_analyzer")

# provider name -> (call function whose result is awaited by _call_single_model,
#                   zero-arg function returning that provider's API keys)
PROVIDER_REGISTRY: dict[str, tuple[Callable[..., Any], Callable[[], list[str]]]] = dict(
    gemini=(providers.call_gemini, providers.get_gemini_keys),
    groq=(providers.call_groq, providers.get_groq_keys),
    openrouter=(providers.call_openrouter, providers.get_openrouter_keys),
)


def _log_success(provider: str, model: str, key_num: int) -> None:
    """Log at INFO which provider, model and key number produced an accepted answer."""
    template = "[SUCCESS] %s | %s | key #%s"
    logger.info(template, provider, model, key_num)


def _log_failed(provider: str, reason: str, model: str = "") -> None:
    """Log at WARNING that an attempt failed; the model is included only when given."""
    if not model:
        logger.warning("[FAILED] %s | %s", provider, reason)
        return
    logger.warning("[FAILED] %s | %s | %s", provider, model, reason)


def _log_invalid_json(provider: str, model: str) -> None:
    """Log at WARNING that ``model``'s response was classified as invalid JSON."""
    template = "[INVALID_JSON] %s | %s"
    logger.warning(template, provider, model)


def _log_retry(message: str) -> None:
    """Log at INFO a retry/fallback step, prefixed with ``[RETRY]``."""
    retry_format = "[RETRY] %s"
    logger.info(retry_format, message)


def _log_batch(provider: str, batch_num: int, models: list[str], key_num: int) -> None:
    """Log at INFO the start of one batch round: provider, round, key number and models."""
    model_list = ", ".join(models)
    logger.info(
        "[BATCH] %s | round %s | key #%s | models: %s",
        provider, batch_num, key_num, model_list,
    )


def _classify_error(exc: Exception) -> str:
    if isinstance(exc, asyncio.TimeoutError):
        return "timeout"
    if isinstance(exc, httpx.TimeoutException):
        return "timeout"
    if isinstance(exc, httpx.ConnectError):
        return "connection error"
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code == 429:
            return "rate limit"
        if code in (401, 403):
            return "auth/blocked"
        if code >= 500:
            return "provider unavailable"
        return f"http {code}"
    msg = str(exc).lower()
    if "rate" in msg and "limit" in msg:
        return "rate limit"
    if "invalid json" in msg or "json" in msg:
        return "invalid json"
    if "empty" in msg:
        return "empty response"
    if "blocked" in msg or "safety" in msg:
        return "blocked prompt"
    return str(exc)[:120]


def _usable_keys(keys_fn: Callable[[], list[str]]) -> list[str]:
    return [k for k in keys_fn() if k and not k.startswith("YOUR_")]


def _provider_ready(provider: str) -> bool:
    """Report whether ``provider`` is registered, has a usable key, and has a model cascade."""
    entry = PROVIDER_REGISTRY.get(provider)
    if entry is None:
        return False
    keys_fn = entry[1]
    # Keys are checked first; the cascade lookup only happens when keys exist.
    if not _usable_keys(keys_fn):
        return False
    return bool(PROVIDER_CASCADES.get(provider))


def _chunk_cascade(cascade: list[str], size: int = PARALLEL_MODEL_COUNT) -> list[list[str]]:
    """Split ``cascade`` into consecutive slices of at most ``size`` models, keeping order.

    The last slice may be shorter. ``size`` must be non-zero (``range`` rejects a zero step).
    """
    chunks: list[list[str]] = []
    for start in range(0, len(cascade), size):
        chunks.append(cascade[start : start + size])
    return chunks


async def _call_single_model(
    provider: str,
    model: str,
    api_key: str,
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
) -> dict[str, Any]:
    """Call one provider/model once and return its parsed, validated JSON.

    Strict-JSON mode is requested for every model except Gemma models. The
    call is bounded by ``asyncio.wait_for`` with the same ``timeout`` that is
    also passed to the provider function.

    Raises:
        ValueError: unknown provider, empty/whitespace response, unparsable
            JSON, or a payload rejected by ``validate_fn``.
        asyncio.TimeoutError: if the call exceeds ``timeout``.
    """
    entry = PROVIDER_REGISTRY.get(provider)
    if entry is None:
        raise ValueError(f"Unknown provider: {provider}")
    call_fn = entry[0]

    pending_call = call_fn(
        api_key,
        model,
        system_prompt,
        user_prompt,
        timeout,
        strict_json=not is_gemma_model(model),
    )
    raw = await asyncio.wait_for(pending_call, timeout=timeout)

    if not (raw and raw.strip()):
        raise ValueError("empty response")

    parsed = try_parse_json(raw)
    if parsed is None:
        raise ValueError("invalid json: could not parse response")
    if validate_fn(parsed):
        return parsed  # type: ignore[return-value]
    raise ValueError("invalid json: failed structure validation")


async def _race_parallel_batch(
    provider: str,
    models: list[str],
    api_key: str,
    key_num: int,
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
) -> tuple[str, dict[str, Any]]:
    """Call every model in ``models`` concurrently with one key; return the first success.

    Args:
        provider: Key into ``PROVIDER_REGISTRY``.
        models: Model names raced against each other.
        api_key: API key shared by every call in this batch.
        key_num: Key number, used only for logging.
        system_prompt: System prompt forwarded to each call.
        user_prompt: User prompt forwarded to each call.
        timeout: Per-call timeout forwarded to ``_call_single_model``.
        validate_fn: Structural validator applied to each parsed response.

    Returns:
        ``(model_name, parsed_json)`` from the first call that completes
        successfully.

    Raises:
        RuntimeError: If every model fails. The message joins
            ``"model: reason"`` entries with ``"; "``, or is ``"batch failed"``
            when there was nothing to report.

    Cancellation of the caller is propagated, not swallowed. All in-flight
    calls are cancelled and awaited before this coroutine exits.
    """
    # Keep an explicit task -> model map instead of re-parsing task names.
    task_models = {
        asyncio.create_task(
            _call_single_model(
                provider, model, api_key,
                system_prompt, user_prompt, timeout, validate_fn,
            ),
            name=f"{provider}|{model}",
        ): model
        for model in models
    }

    errors: list[str] = []
    try:
        pending = set(task_models)
        while pending:
            done, pending = await asyncio.wait(
                pending,
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                if task.cancelled():
                    continue
                model_name = task_models[task]
                # Only task.result() belongs in this try. Awaiting anything
                # here would let an outer CancelledError be mistaken for a
                # per-model outcome and swallowed.
                try:
                    data = task.result()
                except Exception as exc:
                    reason = _classify_error(exc)
                    if "invalid json" in reason or "invalid json" in str(exc).lower():
                        _log_invalid_json(provider, model_name)
                    else:
                        _log_failed(provider, reason, model_name)
                    errors.append(f"{model_name}: {reason}")
                    continue
                _log_success(provider, model_name, key_num)
                # The losing tasks are cancelled and awaited in `finally`.
                return model_name, data

        raise RuntimeError("; ".join(errors) if errors else "batch failed")
    finally:
        # Runs on success, failure and outer cancellation alike, so no
        # request outlives this call.
        for task in task_models:
            if not task.done():
                task.cancel()
        await asyncio.gather(*task_models, return_exceptions=True)


async def _run_provider_cascade(
    provider: str,
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
    max_batches: int | None = None,
    max_keys: int | None = None,
) -> tuple[str, dict[str, Any]]:
    """Walk one provider's model cascade batch by batch, rotating API keys inside each batch.

    The cascade is split into batches of ``PARALLEL_MODEL_COUNT`` models,
    optionally truncated to ``max_batches``. Keys are optionally truncated to
    ``max_keys``. For each batch, every key is tried in turn. When a rate-limit
    failure occurs and another key remains, the code sleeps
    ``RATE_LIMIT_RETRY_DELAY`` before switching keys.

    Returns:
        ``("provider|model|keyN", parsed_json)`` for the first successful batch.

    Raises:
        ValueError: unknown provider.
        RuntimeError: no usable keys, empty cascade, or every batch/key
            combination failed. In the last case the message includes the
            six most recent failure reasons.
    """
    registry_entry = PROVIDER_REGISTRY.get(provider)
    if registry_entry is None:
        raise ValueError(f"Unknown provider: {provider}")

    keys_fn = registry_entry[1]
    cascade = PROVIDER_CASCADES.get(provider, [])
    keys = _usable_keys(keys_fn)

    if not keys:
        raise RuntimeError(f"{provider}: no valid API keys configured")
    if not cascade:
        raise RuntimeError(f"{provider}: empty model cascade")

    batches = _chunk_cascade(cascade, PARALLEL_MODEL_COUNT)
    if max_batches is not None:
        batches = batches[:max_batches]
    keys_to_use = keys if max_keys is None else keys[:max_keys]

    total_batches = len(batches)
    total_keys = len(keys_to_use)
    failures: list[str] = []

    for round_no, model_batch in enumerate(batches, start=1):
        for key_no, api_key in enumerate(keys_to_use, start=1):
            _log_batch(provider, round_no, model_batch, key_no)
            try:
                model, data = await _race_parallel_batch(
                    provider=provider,
                    models=model_batch,
                    api_key=api_key,
                    key_num=key_no,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    timeout=timeout,
                    validate_fn=validate_fn,
                )
            except Exception as exc:
                reason = _classify_error(exc)
                _log_failed(provider, reason)
                failures.append(f"batch{round_no}/key{key_no}: {reason}")

                if key_no < total_keys:
                    if reason == "rate limit":
                        _log_retry(f"Switching API key after rate limit ({provider})...")
                        await asyncio.sleep(RATE_LIMIT_RETRY_DELAY)
                    else:
                        _log_retry(f"Switching API key ({provider})...")
            else:
                return f"{provider}|{model}|key{key_no}", data

        if round_no < total_batches:
            # round_no is 1-based, so batches[round_no] is the upcoming batch.
            upcoming = batches[round_no]
            _log_retry(
                f"Moving to next model batch ({provider}): "
                f"{', '.join(upcoming[:3])}"
            )

    raise RuntimeError(
        f"{provider} cascade exhausted. " + "; ".join(failures[-6:])
    )


async def _race_providers_first_batch(
    provider_order: list[str],
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
) -> tuple[str, dict[str, Any]]:
    """Race the ready providers concurrently, each limited to its first batch.

    Each provider in ``provider_order`` that passes ``_provider_ready`` runs
    ``_run_provider_cascade`` with ``max_batches=1`` and
    ``max_keys=RACE_MAX_KEYS``. The first provider to succeed wins, and the
    others are cancelled.

    Returns:
        ``(label, parsed_json)`` from the winning provider cascade.

    Raises:
        RuntimeError: if no provider is ready, or every racing provider fails.
            In the latter case the message joins ``"provider: reason"`` entries.

    Cancellation of the caller is propagated, not swallowed. All racing tasks
    are cancelled and awaited before this coroutine exits.
    """
    eligible = [p for p in provider_order if _provider_ready(p)]
    if not eligible:
        raise RuntimeError("No providers configured with valid keys")

    async def _attempt(provider: str) -> tuple[str, dict[str, Any]]:
        return await _run_provider_cascade(
            provider=provider,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            timeout=timeout,
            validate_fn=validate_fn,
            max_batches=1,
            max_keys=RACE_MAX_KEYS,
        )

    # Explicit task -> provider map instead of re-parsing task names.
    task_providers = {
        asyncio.create_task(_attempt(provider), name=f"provider|{provider}"): provider
        for provider in eligible
    }

    errors: list[str] = []
    try:
        pending = set(task_providers)
        while pending:
            done, pending = await asyncio.wait(
                pending,
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                if task.cancelled():
                    continue
                provider_name = task_providers[task]
                # Keep only task.result() in the try so an outer CancelledError
                # raised while awaiting can never be swallowed here.
                try:
                    label, data = task.result()
                except Exception as exc:
                    reason = _classify_error(exc)
                    _log_failed(provider_name, f"race: {reason}")
                    errors.append(f"{provider_name}: {reason}")
                    continue
                logger.info("[RACE_WIN] %s", label)
                # The losing providers are cancelled and awaited in `finally`.
                return label, data

        raise RuntimeError("; ".join(errors) if errors else "provider race failed")
    finally:
        for task in task_providers:
            if not task.done():
                task.cancel()
        await asyncio.gather(*task_providers, return_exceptions=True)


async def _run_limited_cascade(
    provider_order: list[str],
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
    task_label: str,
) -> tuple[str, dict[str, Any]]:
    """Try ready providers one after another, each capped by the fast-mode limits.

    Providers failing ``_provider_ready`` are skipped. Each remaining provider
    runs ``_run_provider_cascade`` with ``FAST_MODE_MAX_BATCHES`` and
    ``FAST_MODE_MAX_KEYS``. The first success is returned.

    Raises:
        RuntimeError: when every provider failed, chained to the last failure.
            If no provider was ready at all, the message says so instead of
            reporting ``None``.
    """
    last_error: Exception | None = None

    for provider in provider_order:
        if not _provider_ready(provider):
            continue
        try:
            logger.info("[CASCADE] %s cascade: %s", task_label, provider)
            return await _run_provider_cascade(
                provider=provider,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                timeout=timeout,
                validate_fn=validate_fn,
                max_batches=FAST_MODE_MAX_BATCHES,
                max_keys=FAST_MODE_MAX_KEYS,
            )
        except Exception as exc:
            last_error = exc
            _log_failed(provider, f"entire cascade failed: {_classify_error(exc)}")

    # Without this, an all-unready provider list produced "...failed: None".
    detail = (
        last_error
        if last_error is not None
        else "no providers configured with valid keys"
    )
    raise RuntimeError(f"All {task_label} providers failed: {detail}") from last_error


async def _run_smart_cascade(
    provider_order: list[str],
    system_prompt: str,
    user_prompt: str,
    timeout: float,
    validate_fn: Callable[[dict[str, Any] | None], bool],
    task_label: str,
) -> tuple[str, dict[str, Any]]:
    """Try the cross-provider race (if enabled), then fall back to the sequential limited cascade.

    Any ``Exception`` from the race is logged as a warning and triggers the
    fallback. The fallback's own exceptions propagate to the caller.
    """
    shared_kwargs = dict(
        provider_order=provider_order,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        timeout=timeout,
        validate_fn=validate_fn,
    )

    if ENABLE_CROSS_PROVIDER_RACE:
        try:
            logger.info("[CASCADE] %s — cross-provider race", task_label)
            return await _race_providers_first_batch(**shared_kwargs)
        except Exception as race_exc:
            logger.warning(
                "[CASCADE] %s race failed, using sequential fallback: %s",
                task_label,
                race_exc,
            )

    return await _run_limited_cascade(**shared_kwargs, task_label=task_label)


async def run_cascade_extraction(
    system_prompt: str,
    user_prompt: str,
) -> tuple[str, dict[str, Any]]:
    """Run the extraction cascade.

    Uses ``EXTRACTION_PROVIDER_ORDER``, ``EXTRACTION_TIMEOUT`` and
    ``validate_extraction_json``. Returns ``(label, parsed_json)``, where the
    label has the form ``"provider|model|keyN"``.
    """
    return await _run_smart_cascade(
        task_label="extraction",
        provider_order=EXTRACTION_PROVIDER_ORDER,
        validate_fn=validate_extraction_json,
        timeout=EXTRACTION_TIMEOUT,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
    )


async def run_cascade_interpretation(
    system_prompt: str,
    user_prompt: str,
) -> tuple[str, dict[str, Any]]:
    """Run the interpretation cascade.

    Uses ``INTERPRETATION_PROVIDER_ORDER``, ``INTERPRETATION_TIMEOUT`` and
    ``validate_interpretation_json``. Returns ``(label, parsed_json)``, where
    the label has the form ``"provider|model|keyN"``.
    """
    return await _run_smart_cascade(
        task_label="interpretation",
        provider_order=INTERPRETATION_PROVIDER_ORDER,
        validate_fn=validate_interpretation_json,
        timeout=INTERPRETATION_TIMEOUT,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
    )
