"""Параллельный прогон LLM-бенчмарка по casebook.csv.

Usage:
    python bench-runner.py casebook.csv results.jsonl

Читает casebook (case_id, kind, ground_truth, text_a, text_b, note),
прогоняет каждый кейс через все настроенные модели через asyncio.gather,
пишет результаты в JSONL — по одной строке на пару (модель × кейс), после завершения всех запросов.

Провайдеры подключаются автоматически по наличию ключа в .env.
Если ключа нет — модель пропускается без ошибки.

Кастомизация: измени список MODELS и словарь PRICING под свой набор.
Промпты читаются из system-prompt.md и user-template.md, лежащих рядом со скриптом.
"""
import asyncio
import csv
import json
import os
import re
import sys
import time
from pathlib import Path

import httpx
from dotenv import load_dotenv
from openai import AsyncOpenAI

# Populate the process environment from a .env file (python-dotenv);
# make_clients() decides which providers are enabled from these variables.
load_dotenv()

# Prompt files are looked up next to this script.
ROOT = Path(__file__).parent


def _read_prompt(filename):
    """Return the UTF-8 text of ROOT/filename with surrounding whitespace stripped."""
    return (ROOT / filename).read_text(encoding="utf-8").strip()


SYSTEM_PROMPT = _read_prompt("system-prompt.md")  # system message sent with every request
USER_TEMPLATE = _read_prompt("user-template.md")  # str.format template filled with text_a / text_b


# ---- models to benchmark ----
# Each entry is one benchmark "run" (model + mode). Fields:
# tag                — short name for the report; written as "model" in result rows
# provider           — key into CLIENTS (see make_clients); entries whose provider has no client are skipped
# model              — model_id sent to the API; also the lookup key into PRICING
# extra_body         — provider-specific request parameters (thinking toggles), or None
# thinking           — selects sampling temperature: 1.0 when True, 0.0 when False
# requires_json_kw   — optional; provider needs the word "json" in the prompt for JSON mode (DashScope)

MODELS = [
    # === US premium ===
    # NOTE: budget_tokens (3000) must stay below the completion cap set per request (5000).
    {"tag": "sonnet-default",         "provider": "anthropic", "model": "claude-sonnet-4-5",
     "extra_body": None,                                          "thinking": False},
    {"tag": "sonnet-thinking",        "provider": "anthropic", "model": "claude-sonnet-4-5",
     "extra_body": {"thinking": {"type": "enabled", "budget_tokens": 3000}}, "thinking": True},

    # === DeepSeek (top on Codeforces) ===
    # NOTE(review): tag says "v4-pro" but the request uses "deepseek-chat" with thinking
    # switched on via extra_body — confirm this is the intended model/mode.
    {"tag": "deepseek-v4-pro",        "provider": "deepseek",  "model": "deepseek-chat",
     "extra_body": {"thinking": {"type": "enabled"}},             "thinking": True},

    # === GLM 5.1 (Z.AI) — two modes ===
    # NOTE(review): tags say GLM 5.1 but the model id requested is "glm-4.6" — confirm.
    {"tag": "glm-5.1-thinking",       "provider": "zai",       "model": "glm-4.6",
     "extra_body": {"thinking": {"type": "enabled"}},             "thinking": True},
    {"tag": "glm-5.1-default",        "provider": "zai",       "model": "glm-4.6",
     "extra_body": {"thinking": {"type": "disabled"}},            "thinking": False},

    # === Qwen 3.6 Plus (DashScope International) — two modes ===
    # NOTE(review): tags say Qwen 3.6 Plus but the id is the "qwen-plus" alias — confirm
    # which snapshot it resolves to. Also confirm DashScope accepts JSON mode
    # (response_format) together with enable_thinking=True.
    {"tag": "qwen3.6-plus-thinking",  "provider": "dashscope", "model": "qwen-plus",
     "extra_body": {"enable_thinking": True},                     "thinking": True,
     "requires_json_kw": True},
    {"tag": "qwen3.6-plus-default",   "provider": "dashscope", "model": "qwen-plus",
     "extra_body": {"enable_thinking": False},                    "thinking": False,
     "requires_json_kw": True},

    # === Kimi via OpenRouter ===
    # NOTE(review): tagged and flagged as a thinking run (→ temperature 1.0), but it requests
    # "moonshotai/kimi-k2" with no thinking toggle — confirm whether a thinking variant was intended.
    {"tag": "kimi-k2-thinking",       "provider": "openrouter", "model": "moonshotai/kimi-k2",
     "extra_body": None,                                          "thinking": True},

    # === Russian ===
    # NOTE(review): Yandex's OpenAI-compatible API documents model ids as
    # gpt://<folder_id>/<model>; confirm the bare "yandexgpt/latest" is accepted or expanded.
    {"tag": "yandex-5.1-default",     "provider": "yandex",    "model": "yandexgpt/latest",
     "extra_body": None,                                          "thinking": False},
    # make_clients() registers only a placeholder (not an API client) for "gigachat".
    {"tag": "gigachat-2-max",         "provider": "gigachat",  "model": "GigaChat-2-Max",
     "extra_body": None,                                          "thinking": False},
]


# ---- prices in USD per 1M tokens: (input, output) ----
# Keyed by the MODELS "model" id; ids missing here are costed at (0.0, 0.0).
# Cost is computed from prompt_tokens and completion_tokens only; reasoning tokens are
# assumed to be counted inside completion_tokens (i.e. billed at the output rate) —
# confirm per provider. NOTE(review): hard-coded snapshot, verify against current price lists.
PRICING = {
    "claude-sonnet-4-5":      (3.00, 15.00),
    "deepseek-chat":          (0.435, 0.87),
    "glm-4.6":                (0.60, 2.20),
    "qwen-plus":              (0.40, 1.20),
    "moonshotai/kimi-k2":     (0.50, 2.00),
    "yandexgpt/latest":       (1.20, 1.20),
    "GigaChat-2-Max":         (0.0, 0.0),  # billed in RUB and not converted here: cost is reported as 0
}


# ---- per-provider API clients, created only when the provider's key is set ----
def make_clients():
    """Build the provider-name → client mapping from environment variables.

    Every real provider is reached through an OpenAI-compatible endpoint with
    ``AsyncOpenAI``. A provider is registered only when its API key (and, for
    Yandex, also the folder id) is set; providers left out of the mapping are
    the ones whose models get skipped. GigaChat is registered as the marker
    string ``"gigachat_special"`` instead of a client, because its OAuth2
    token flow is not implemented here.
    """
    env = os.getenv
    registry = {}

    def register(name, key_var, base_url, headers=None):
        # Add one AsyncOpenAI client if the env var holding its key is non-empty.
        api_key = env(key_var)
        if not api_key:
            return
        options = {"base_url": base_url, "api_key": api_key}
        if headers is not None:
            options["default_headers"] = headers
        registry[name] = AsyncOpenAI(**options)

    register("anthropic", "ANTHROPIC_API_KEY", "https://api.anthropic.com/v1/",
             headers={"anthropic-version": "2023-06-01"})
    register("deepseek", "DEEPSEEK_API_KEY", "https://api.deepseek.com/v1")
    register("zai", "ZAI_API_KEY", "https://api.z.ai/api/paas/v4")
    register("dashscope", "DASHSCOPE_API_KEY",
             "https://dashscope-intl.aliyuncs.com/compatible-mode/v1")
    register("openrouter", "OPENROUTER_API_KEY", "https://openrouter.ai/api/v1")

    # Yandex needs the folder id as well; it is sent as the x-folder-id header.
    folder_id = env("YANDEX_FOLDER_ID")
    if folder_id:
        register("yandex", "YANDEX_API_KEY", "https://llm.api.cloud.yandex.net/v1",
                 headers={"x-folder-id": folder_id})

    # GigaChat uses OAuth2 rather than a static bearer key (per the original note,
    # tokens need refreshing every 30 minutes in production) — store a marker only.
    if env("GIGACHAT_AUTH_KEY"):
        registry["gigachat"] = "gigachat_special"
    return registry


CLIENTS = make_clients()

# Per-provider cap on concurrent in-flight requests. Defaults to 6 and can be
# overridden with BENCH_CONCURRENCY (environment or .env). Clamped to >= 1:
# Semaphore(0) would make every request wait forever.
MAX_CONCURRENCY = max(1, int(os.getenv("BENCH_CONCURRENCY") or "6"))
# NOTE: these are created at import time, outside asyncio.run(). asyncio
# primitives bind to the running loop lazily only on Python 3.10+, so this
# assumes 3.10 or newer.
SEMAPHORES = {provider: asyncio.Semaphore(MAX_CONCURRENCY) for provider in CLIENTS}


# ---- parsing the model's JSON answer ----
def _extract_json_object(text):
    """Return the first JSON object (dict) found in *text*, or None.

    Tried in order:
    1. a ```json fenced block;
    2. the whole text;
    3. a scan that decodes starting at each "{" and accepts the first object
       that has a "verdict" key.
    Step 3 tolerates prose before or after the JSON, and braces inside string
    values.
    """
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    candidates = ((fenced.group(1),) if fenced else ()) + (text,)
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(data, dict):
            return data

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            data, _ = decoder.raw_decode(text, start)
        except ValueError:
            data = None
        if isinstance(data, dict) and "verdict" in data:
            return data
        start = text.find("{", start + 1)
    return None


def parse_response(text: str):
    """Extract ``(verdict, reason)`` from a model's raw answer text.

    verdict is "YES", "NO" or "MAYBE" (matched case- and whitespace-
    insensitively). It is "?" when a JSON object is found but its verdict is
    missing or unrecognized. It is "PARSE_ERR" when no JSON object can be
    found; the second item is then the first 80 chars of the stripped answer
    instead of a reason. The reason is truncated to 200 characters. A None
    *text* is treated as "".
    """
    text = (text or "").strip()
    data = _extract_json_object(text)
    if data is None:
        return "PARSE_ERR", text[:80]
    verdict = str(data.get("verdict", "?")).strip().upper()
    if verdict not in {"YES", "NO", "MAYBE"}:
        verdict = "?"
    return verdict, str(data.get("reason", ""))[:200]


# ---- one request to one model ----
async def call_model(model_cfg, case):
    """Send one casebook case to one configured model and return a result row.

    Returns None when the model's provider has no entry in CLIENTS (no API
    key), so the caller can drop it. Otherwise returns a dict with
    model/case_id/kind/ground_truth plus either:
    - on success: verdict, reason, latency, token counts and estimated cost;
    - on any API or parsing failure: an ``error`` string and the latency.
    Requests are throttled by the provider's semaphore.
    """
    provider = model_cfg["provider"]
    if provider not in CLIENTS:
        return None  # no key for this provider -> model is skipped

    client = CLIENTS[provider]
    # Identifying columns shared by success and error rows.
    row = {
        "model": model_cfg["tag"],
        "case_id": case["case_id"],
        "kind": case.get("kind"),
        "ground_truth": case.get("ground_truth"),
    }

    # make_clients() may store a placeholder string instead of an API client
    # (GigaChat, whose OAuth2 flow is not implemented). Report that explicitly
    # rather than failing with "'str' object has no attribute 'chat'".
    if not hasattr(client, "chat"):
        return {
            **row,
            "error": f"provider {provider!r} has no usable client "
                     f"(placeholder {client!r}; auth flow not implemented)",
            "latency_s": 0.0,
        }

    user = USER_TEMPLATE.format(text_a=case["text_a"], text_b=case["text_b"])
    # DashScope's JSON mode requires the word "json" somewhere in the messages.
    # Add it only when neither prompt already contains it, so the prompt stays
    # identical to the other models whenever possible.
    if model_cfg.get("requires_json_kw") and "json" not in (SYSTEM_PROMPT + user).lower():
        user += "\n\nReturn the answer as JSON."

    model_id = model_cfg["model"]
    # Yandex's OpenAI-compatible API addresses models as gpt://<folder_id>/<model>;
    # expand bare ids such as "yandexgpt/latest".
    if provider == "yandex" and not model_id.startswith("gpt://"):
        model_id = f"gpt://{os.getenv('YANDEX_FOLDER_ID')}/{model_id}"

    kwargs = {
        "model": model_id,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user},
        ],
        "max_completion_tokens": 5000,
        # Thinking runs use temperature 1.0, non-thinking runs 0.0.
        "temperature": 1.0 if model_cfg["thinking"] else 0.0,
    }
    if model_cfg.get("extra_body"):
        kwargs["extra_body"] = model_cfg["extra_body"]
    # JSON mode only for the providers where it is enabled.
    if provider in {"deepseek", "zai", "dashscope", "openrouter"}:
        kwargs["response_format"] = {"type": "json_object"}

    async with SEMAPHORES[provider]:
        t0 = time.monotonic()  # latency excludes time spent waiting on the semaphore
        try:
            resp = await client.chat.completions.create(**kwargs)
            dt = time.monotonic() - t0
            text = resp.choices[0].message.content or ""
            verdict, reason = parse_response(text)

            # usage (and its sub-fields) is optional in the OpenAI response schema.
            # Count missing values as 0 rather than failing an otherwise good call.
            usage = resp.usage
            in_tok = getattr(usage, "prompt_tokens", None) or 0
            out_tok = getattr(usage, "completion_tokens", None) or 0
            details = getattr(usage, "completion_tokens_details", None)
            reasoning_tok = getattr(details, "reasoning_tokens", None) or 0

            # Priced by the configured id (not the expanded Yandex URI).
            in_p, out_p = PRICING.get(model_cfg["model"], (0.0, 0.0))
            cost = (in_tok * in_p + out_tok * out_p) / 1_000_000

            return {
                **row,
                "verdict": verdict,
                "reason": reason,
                "latency_s": round(dt, 2),
                "input_tokens": in_tok,
                "output_tokens": out_tok,
                "reasoning_tokens": reasoning_tok,
                "cost_usd": round(cost, 5),
            }
        except Exception as e:  # per-request boundary: record the failure as an error row
            return {
                **row,
                "error": f"{type(e).__name__}: {e}"[:200],
                "latency_s": round(time.monotonic() - t0, 2),
            }


async def main(casebook_path: str, output_path: str):
    """Benchmark every configured model on every casebook row and write JSONL.

    Steps:
    1. Read the casebook CSV.
    2. List models whose provider has no client on stderr, and skip them.
    3. Run all (model, case) requests concurrently.
    4. Write one JSON object per result line to ``output_path``.
    5. Print a summary to stderr.
    """
    # newline="" is what the csv module expects for files it reads: quoted
    # fields containing line breaks (multi-line text_a / text_b) parse correctly.
    with open(casebook_path, encoding="utf-8", newline="") as f:
        cases = list(csv.DictReader(f))

    skipped = [m for m in MODELS if m["provider"] not in CLIENTS]
    if skipped:
        print(f"[skip] no key for: {', '.join(m['tag'] for m in skipped)}", file=sys.stderr)

    active = [m for m in MODELS if m["provider"] in CLIENTS]
    print(f"[run] {len(active)} models × {len(cases)} cases = {len(active) * len(cases)} requests", file=sys.stderr)

    pairs = [(m, c) for m in active for c in cases]
    t0 = time.monotonic()
    # return_exceptions=True: an exception escaping call_model (e.g. a casebook
    # row missing the text_a column) becomes an error row for that pair instead
    # of aborting gather() and discarding every result that already finished.
    outcomes = await asyncio.gather(
        *(call_model(m, c) for m, c in pairs), return_exceptions=True
    )
    dt = time.monotonic() - t0
    print(f"[done] {dt:.1f}s wall time", file=sys.stderr)

    results = []
    for (m, c), outcome in zip(pairs, outcomes):
        if isinstance(outcome, BaseException):
            outcome = {
                "model": m["tag"],
                "case_id": c.get("case_id"),
                "error": f"{type(outcome).__name__}: {outcome}"[:200],
            }
        results.append(outcome)

    with open(output_path, "w", encoding="utf-8") as f:
        for r in results:
            if r:  # None = model skipped because its provider has no client
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

    ok = sum(1 for r in results if r and "verdict" in r)
    err = sum(1 for r in results if r and "error" in r)
    print(f"[result] {ok} ok, {err} errors → {output_path}", file=sys.stderr)


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python bench-runner.py casebook.csv results.jsonl", file=sys.stderr)
        sys.exit(1)
    asyncio.run(main(sys.argv[1], sys.argv[2]))
