"""多 Agent 协作调度与评分体系。"""

from __future__ import annotations

import json
import math
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from backend.infra.logging import logger

from .base import AgentResult, BaseAgent
from .workflow import WorkflowContext


@dataclass
class ScoreEntry:
    """Running tally kept by the scoreboard for a single agent."""

    # Cumulative sum of every delta applied to this agent.
    points: float = 0.0
    # Count of scoring updates recorded for this agent.
    rounds: int = 0
    # Updates whose delta was strictly positive.
    wins: int = 0
    # Updates whose delta was strictly negative.
    losses: int = 0
    # Free-form notes attached to updates; each instance gets its own list.
    notes: List[str] = field(default_factory=list)


class AgentScoreBoard:
    """Maintain cumulative per-agent scores.

    Scores are keyed by an arbitrary agent label. Every ``update`` call counts
    as one scored "round" for that agent, so an agent updated several times in
    one orchestration pass accumulates several rounds.
    """

    # Number of most-recent notes exposed per agent by ``snapshot``.
    _SNAPSHOT_NOTES = 5

    def __init__(self, *, max_notes: int = 50) -> None:
        """Create an empty scoreboard.

        Args:
            max_notes: Maximum number of notes retained per agent. Older notes
                are dropped so a long-running process does not accumulate note
                history forever. Values below the snapshot window (5) are
                raised to it, so ``snapshot`` output is never affected.
        """
        self._scores: Dict[str, ScoreEntry] = {}
        self._max_notes = max(max_notes, self._SNAPSHOT_NOTES)

    def update(self, agent: str, delta: float, *, note: str = "") -> None:
        """Apply ``delta`` to ``agent`` and record the outcome.

        Creates the agent's entry on first use. A positive delta counts as a
        win, a negative one as a loss, and zero as neither. A non-empty
        ``note`` is appended to the agent's (bounded) note history.
        """
        entry = self._scores.setdefault(agent, ScoreEntry())
        entry.rounds += 1
        entry.points += delta
        if delta > 0:
            entry.wins += 1
        elif delta < 0:
            entry.losses += 1
        if note:
            entry.notes.append(note)
            # Drop the oldest notes once the retention cap is exceeded.
            overflow = len(entry.notes) - self._max_notes
            if overflow > 0:
                del entry.notes[:overflow]
        logger.info(
            "scoreboard:update",
            agent=agent,
            delta=delta,
            points=entry.points,
            rounds=entry.rounds,
        )

    def snapshot(self) -> List[Dict[str, object]]:
        """Return per-agent summaries sorted by points, highest first.

        ``points`` is rounded to 4 decimals (the sort uses the rounded value)
        and ``notes`` contains at most the 5 most recent notes. Ties keep
        insertion order because ``list.sort`` is stable.
        """
        data = [
            {
                "agent": agent,
                "points": round(entry.points, 4),
                "rounds": entry.rounds,
                "wins": entry.wins,
                "losses": entry.losses,
                "notes": entry.notes[-self._SNAPSHOT_NOTES:],
            }
            for agent, entry in self._scores.items()
        ]
        data.sort(key=lambda item: item["points"], reverse=True)
        return data


class ReflectionEngine:
    """Emit reflection prompts after a streak of risky or negatively reviewed rounds.

    A round is "negative" when the risk agent flagged it or the review score is
    below zero. Once ``threshold`` consecutive negative rounds are seen, one
    reflection note is returned and the streak counter restarts from zero. Any
    non-negative round also resets the streak.
    """

    def __init__(self, *, threshold: int = 2) -> None:
        # A threshold below 1 behaves exactly like 1 (fire on every negative
        # round); clamping keeps the rendered streak length in the note sensible.
        self._threshold = max(1, threshold)
        self._consecutive_negatives = 0

    def process(self, risk_flagged: bool, review_score: float) -> List[str]:
        """Record one round and return any reflection notes it triggers.

        Args:
            risk_flagged: Whether the risk agent flagged this round.
            review_score: Review score for this round; negative counts as bad.

        Returns:
            A one-element list with the reflection note when the streak reaches
            the threshold, otherwise an empty list.
        """
        negative = risk_flagged or review_score < 0
        if not negative:
            self._consecutive_negatives = 0
            return []
        self._consecutive_negatives += 1
        if self._consecutive_negatives < self._threshold:
            return []
        self._consecutive_negatives = 0
        return [self._build_note()]

    def _build_note(self) -> str:
        """Render the reflection prompt for the configured streak length."""
        # Keep the original wording ("两" = two) for the default threshold and
        # render the actual count otherwise, instead of always claiming two.
        count = "两" if self._threshold == 2 else str(self._threshold)
        return f"Reflection: 连续{count}轮出现风险或负面评分，请复盘策略假设、仓位管理与工具调用。"


@dataclass
class RoundOutcome:
    """Everything produced by one Market → Strategy → Risk → Review round."""

    # Raw result of the market-broadcast agent.
    broadcast: AgentResult
    # Raw result of the strategy agent.
    strategy: AgentResult
    # Raw result of the risk agent.
    risk: AgentResult
    # Raw result of the review agent.
    review: AgentResult
    # Upper-cased risk status parsed from the risk agent's output.
    risk_status: str
    # Warnings parsed from the risk agent's output.
    risk_warnings: List[str]
    # Numeric score parsed from the review agent's output.
    review_score: float
    # Reflection notes triggered during this round (possibly empty).
    reflections: List[str]
    # Scoreboard rows taken after this round's scores were applied.
    scoreboard: List[Dict[str, object]]


class MultiAgentOrchestrator:
    """Multi-agent scheduling of Market → Strategy → Risk → Review.

    Each round runs the four agents sequentially, forwarding earlier outputs as
    context, then parses the risk and review outputs for a JSON summary,
    updates the scoreboard and feeds reflection notes to the review agent.
    """

    def __init__(
        self,
        *,
        market_agent: BaseAgent,
        strategy_agent: BaseAgent,
        risk_agent: BaseAgent,
        review_agent: BaseAgent,
        scoreboard: Optional[AgentScoreBoard] = None,
        reflection_engine: Optional[ReflectionEngine] = None,
    ) -> None:
        self.market_agent = market_agent
        self.strategy_agent = strategy_agent
        self.risk_agent = risk_agent
        self.review_agent = review_agent
        # Explicit None checks so an injected collaborator is never swapped out
        # just because it happens to be falsy.
        self.scoreboard = scoreboard if scoreboard is not None else AgentScoreBoard()
        self.reflection_engine = (
            reflection_engine if reflection_engine is not None else ReflectionEngine()
        )

    async def run_round(
        self,
        snapshot: Dict[str, object],
        account: Dict[str, object],
    ) -> RoundOutcome:
        """Run one full Market → Strategy → Risk → Review round.

        Args:
            snapshot: Market snapshot forwarded to every agent.
            account: Account state forwarded to the strategy, risk and review agents.

        Returns:
            A ``RoundOutcome`` with the four agent results, the parsed risk
            status/warnings, the review score, reflections triggered this round
            and a scoreboard snapshot taken after scoring.
        """
        broadcast = await self.market_agent.run(
            "生成市场播报",
            context={"snapshot": snapshot},
        )

        strategy_context = {
            "snapshot": snapshot,
            "account": account,
            "broadcast": broadcast.final,
        }
        strategy = await self.strategy_agent.run(
            "制定当期交易计划",
            context=strategy_context,
        )

        risk_context = {
            "snapshot": snapshot,
            "account": account,
            "broadcast": broadcast.final,
            "strategy": strategy.final,
            "actions": strategy.actions,
        }
        risk = await self.risk_agent.run(
            "审核策略风险并输出状态",
            context=risk_context,
        )

        review_context = {
            "snapshot": snapshot,
            "account": account,
            "broadcast": broadcast.final,
            "strategy": strategy.final,
            "risk": risk.final,
            "actions": strategy.actions,
        }
        review = await self.review_agent.run(
            "复盘本轮并给出评分",
            context=review_context,
        )

        risk_summary = self._parse_json_from_result(risk)
        risk_status = self._coerce_status(risk_summary.get("status"))
        risk_warnings = self._coerce_warnings(risk_summary.get("warnings"))
        risk_flagged = risk_status == "FLAGGED"
        if not risk_status:
            # NOTE(review): a missing/unparseable status is reported as APPROVED
            # below (fail-open) and scored as approved; log it so a silent or
            # malformed risk agent is visible.
            logger.warning("orchestrator:risk_status_missing", agent="risk")

        review_summary = self._parse_json_from_result(review)
        review_score = self._coerce_score(review_summary.get("score"))

        self._apply_scores(risk_flagged, review_score)
        reflections = self.reflection_engine.process(risk_flagged, review_score)
        for note in reflections:
            self.review_agent.memory.append_text(note)

        return RoundOutcome(
            broadcast=broadcast,
            strategy=strategy,
            risk=risk,
            review=review,
            risk_status=risk_status or "APPROVED",
            risk_warnings=risk_warnings,
            review_score=review_score,
            reflections=reflections,
            scoreboard=self.scoreboard.snapshot(),
        )

    def _apply_scores(self, risk_flagged: bool, review_score: float) -> None:
        """Update the scoreboard from the round's risk verdict and review score.

        A flagged round penalises the strategy and rewards risk detection. An
        approved round rewards both. A non-zero review score is added to both
        strategy and review; a zero score gives review a small neutral credit.
        The market broadcaster always gets a small credit.
        """
        if risk_flagged:
            self.scoreboard.update("strategy", -2.0, note="risk_flagged")
            self.scoreboard.update("risk", 1.0, note="flagged_detection")
        else:
            self.scoreboard.update("strategy", 1.0, note="risk_approved")
            self.scoreboard.update("risk", 0.5, note="approved")

        if review_score:
            note = f"review_score={review_score}"
            self.scoreboard.update("strategy", review_score, note=note)
            self.scoreboard.update("review", review_score, note=note)
        else:
            self.scoreboard.update("review", 0.1, note="neutral_review")

        # Small credit for the market broadcast to encourage steady output.
        self.scoreboard.update("market-broadcast", 0.1, note="info_update")

    def _parse_json_from_result(self, result: AgentResult) -> Dict[str, object]:
        """Extract the first JSON object summary from an agent result.

        Checks ``result.final``, then ``result.raw_response``, then each tool
        execution's ``output`` in order. Dict values are used directly; other
        values are searched for an embedded JSON object. Returns ``{}`` when
        nothing usable is found.
        """
        for item in (result.final, result.raw_response):
            if not item:
                continue
            if isinstance(item, dict):
                # str(dict) yields a Python repr that JSON cannot parse.
                return item
            data = self._try_load_json(item if isinstance(item, str) else str(item))
            if data is not None:
                return data
        for execution in result.tool_results or ():
            output = execution.output
            if isinstance(output, dict):
                return output
            if isinstance(output, str):
                data = self._try_load_json(output)
                if data is not None:
                    return data
        return {}

    @staticmethod
    def _coerce_status(value: object) -> str:
        """Normalise a risk ``status`` field to a stripped upper-case string ("" if falsy)."""
        if not value:
            return ""
        return str(value).strip().upper()

    @staticmethod
    def _coerce_warnings(value: object) -> List[str]:
        """Normalise a risk ``warnings`` field to a new list.

        Lists/tuples are copied as-is. Any other truthy value is treated as a
        single warning; calling ``list()`` on a bare string would split it into
        characters.
        """
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value if isinstance(value, str) else str(value)]

    @staticmethod
    def _coerce_score(value: object) -> float:
        """Convert a review ``score`` field to a finite float, defaulting to 0.0.

        Non-numeric values (e.g. "8/10") and NaN/inf fall back to 0.0. Without
        this, the round would abort, or a NaN (truthy) would permanently poison
        the scoreboard totals.
        """
        if not value:
            return 0.0
        try:
            score = float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        return score if math.isfinite(score) else 0.0

    @staticmethod
    def _try_load_json(text: str) -> Optional[Dict[str, object]]:
        """Return the first JSON object embedded in ``text``, or ``None``."""
        text = text.strip()
        if not text:
            return None
        # Fast path: greedy span from the first "{" to the last "}".
        match = re.search(r"\{.*\}", text, flags=re.DOTALL)
        candidate = match.group(0) if match else text
        try:
            loaded = json.loads(candidate)
        except json.JSONDecodeError:
            loaded = None
        if isinstance(loaded, dict):
            return loaded
        # Fallback: the greedy span breaks when surrounding prose also contains
        # braces; decode a complete object starting at each "{" instead.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict):
                    return obj
            start = text.find("{", start + 1)
        return None


class MultiAgentWorkflowAdapter:
    """Convert a multi-agent round into a context consumable by TradingDayScheduler."""

    def __init__(self, orchestrator: MultiAgentOrchestrator) -> None:
        self._orchestrator = orchestrator

    async def run(self, *, snapshot: Dict[str, object], account: Dict[str, object]) -> WorkflowContext:
        """Run one orchestrator round and package it as a ``WorkflowContext``.

        The broadcast is passed as its ``final`` text, the strategy as the full
        agent result, and the risk section bundles status, warnings, review
        score and reflections.
        """
        round_result = await self._orchestrator.run_round(snapshot, account)
        risk_section = dict(
            status=round_result.risk_status,
            warnings=round_result.risk_warnings,
            score=round_result.review_score,
            reflections=round_result.reflections,
        )
        return WorkflowContext(
            snapshot=snapshot,
            account=account,
            broadcast=round_result.broadcast.final,
            strategy=round_result.strategy,
            risk=risk_section,
        )


# Public API of this module. ReflectionEngine is included because callers may
# pass a custom instance to MultiAgentOrchestrator(reflection_engine=...).
__all__ = [
    "MultiAgentOrchestrator",
    "MultiAgentWorkflowAdapter",
    "AgentScoreBoard",
    "ReflectionEngine",
    "ScoreEntry",
    "RoundOutcome",
]
