"""多 Agent 竞赛数据服务。"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date
from typing import Any, Dict, List, Optional, Sequence

from sqlalchemy import Select, case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.infra.logging import logger
from backend.trading.services.reporting import (
    fetch_nav_history,
    fetch_positions,
    fetch_recent_trades,
)

from .models import (
    AgentTurnLog,
    Competition,
    CompetitionParticipant,
    CompetitionRound,
    RoundStatus,
    ParticipantRoundResult,
)


@dataclass
class ParticipantProfile:
    """Base configuration of one competition participant.

    Instances are built by ``CompetitionService.list_participants`` from
    ``CompetitionParticipant`` rows. Field order defines the positional
    constructor signature, so new fields must be appended with defaults.
    """

    # Id of the underlying CompetitionParticipant row.
    id: int
    display_name: str
    slug: str
    # Trading account whose NAV history, positions and trades are reported for this participant.
    strategy_account_id: int
    # Presumably the primary LLM identifier for this participant's agents -- TODO confirm.
    primary_model: Optional[str]
    # Alternate model identifiers (ordering semantics are not visible in this module).
    fallback_models: List[str] = field(default_factory=list)
    # Free-form per-participant overrides; CompetitionService reads the keys
    # "focus_symbol", "symbols" and "symbol" from it to derive a focus symbol.
    agent_overrides: Dict[str, Any] = field(default_factory=dict)
    # Presumably the LLM sampling temperature -- TODO confirm against the agent runner.
    temperature: float = 0.65
    # Parsed from the comma-separated ``tags`` column, blank tokens dropped.
    tags: List[str] = field(default_factory=list)
    description: Optional[str] = None
    is_active: bool = True


@dataclass
class CompetitionConfig:
    """A competition row bundled with its participant profiles.

    ``CompetitionService.load_config`` fills ``participants`` with active
    participants only.
    """

    competition: Competition
    participants: List[ParticipantProfile]


@dataclass
class RoundResultPayload:
    """Input for ``CompetitionService.record_round_result``.

    ``(round_id, participant_id)`` identifies the ``ParticipantRoundResult``
    row to insert or update; the remaining fields are copied onto that row.
    List fields left as ``None`` are stored as empty lists.
    """

    round_id: int
    participant_id: int
    strategy_account_id: int
    nav: Optional[float] = None
    total_equity: Optional[float] = None
    cash: Optional[float] = None
    pnl: Optional[float] = None
    pnl_pct: Optional[float] = None
    # Read-side code counts the value "FLAGGED" as a risk flag; other values are opaque here.
    risk_status: Optional[str] = None
    risk_warnings: Optional[List[str]] = None
    review_score: Optional[float] = None
    # Summed across rounds to produce the leaderboard "score".
    score_delta: Optional[float] = None
    reflections: Optional[List[str]] = None
    summary: Optional[str] = None
    actions: Optional[List[Dict[str, Any]]] = None
    log_path: Optional[str] = None


@dataclass
class AgentTurnPayload:
    """Log data for a single agent turn.

    Persisted as an ``AgentTurnLog`` row by
    ``CompetitionService.append_agent_turns``; ``actions`` and
    ``tool_results`` left as ``None`` are stored as empty lists.
    """

    agent_name: str
    role: str
    thoughts: Optional[str] = None
    final_output: Optional[str] = None
    actions: Optional[List[Dict[str, Any]]] = None
    tool_results: Optional[List[Dict[str, Any]]] = None
    raw_response: Optional[str] = None
    # Latency in milliseconds, per the field name.
    latency_ms: Optional[float] = None


class CompetitionService:
    """Persistence and read-model helpers for one competition.

    Competition, participant and round lookups are filtered by
    ``self.competition_id``. Every method takes a caller-owned
    ``AsyncSession``; the write helpers call ``session.flush()`` but never
    commit, so transaction boundaries belong to the caller.
    """

    # Fallback when a participant row has no temperature; mirrors the default
    # of ``ParticipantProfile.temperature``.
    _DEFAULT_TEMPERATURE = 0.65

    def __init__(self, competition_id: int) -> None:
        """Bind the service to the competition with primary key *competition_id*."""
        self.competition_id = competition_id

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _iso(value: Any) -> Optional[str]:
        """Return ``value.isoformat()``, or ``None`` when *value* is ``None``."""
        return value.isoformat() if value is not None else None

    @staticmethod
    def _opt_float(value: Any) -> Optional[float]:
        """Convert a nullable numeric value (e.g. ``Decimal``) to ``float``, keeping ``None``."""
        return float(value) if value is not None else None

    @classmethod
    def _to_profile(cls, item: CompetitionParticipant) -> ParticipantProfile:
        """Convert a ``CompetitionParticipant`` row into a ``ParticipantProfile``.

        ``tags`` is stored as a comma-separated string; blank tokens are
        dropped. A ``NULL`` temperature falls back to the default, while an
        explicit ``0.0`` is kept (an ``or`` fallback would discard it).
        """
        tags = [token.strip() for token in (item.tags or "").split(",") if token.strip()]
        temperature = item.temperature if item.temperature is not None else cls._DEFAULT_TEMPERATURE
        return ParticipantProfile(
            id=item.id,
            display_name=item.display_name,
            slug=item.slug,
            strategy_account_id=item.strategy_account_id,
            primary_model=item.primary_model,
            fallback_models=list(item.fallback_models or []),
            agent_overrides=dict(item.agent_overrides or {}),
            temperature=temperature,
            tags=tags,
            description=item.description,
            is_active=item.is_active,
        )

    @classmethod
    def _round_row_to_dict(cls, row: Any) -> Dict[str, Any]:
        """Serialize one row of the participant round-history query.

        *row* carries a ``ParticipantRoundResult`` entity plus the joined
        ``round_index``, ``trading_day``, ``finished_at`` and ``round_status``
        columns. ``"nav"`` is taken from ``total_equity`` (not the ``nav``
        column), matching the payload built by ``fetch_round_detail``.
        """
        result: ParticipantRoundResult = row.ParticipantRoundResult  # type: ignore[attr-defined]
        return {
            "round_id": result.round_id,
            "round_index": row.round_index,
            "trading_day": cls._iso(row.trading_day),
            "finished_at": cls._iso(row.finished_at),
            "status": getattr(row, "round_status", None),
            "nav": result.total_equity,
            "pnl": result.pnl,
            "pnl_pct": result.pnl_pct,
            "risk_status": result.risk_status,
            "review_score": result.review_score,
            "score_delta": result.score_delta,
            "summary": result.summary,
            "actions": result.actions,
            "reflections": result.reflections,
            "log_path": result.log_path,
            "created_at": cls._iso(result.created_at),
        }

    @staticmethod
    def _is_flat_position(pos: Any) -> bool:
        """Return True when a position row is effectively empty and should be hidden.

        Quantities are compared against 1e-6 and monetary values against 1e-2.
        """
        return (
            abs(pos.quantity) < 1e-6
            and abs(pos.frozen_quantity) < 1e-6
            and abs(pos.market_value) < 1e-2
            and abs(pos.unrealized_pnl) < 1e-2
        )

    @classmethod
    def _position_to_dict(cls, pos: Any) -> Dict[str, Any]:
        """Serialize a position row returned by ``fetch_positions``."""
        return {
            "ts_code": pos.ts_code,
            "quantity": pos.quantity,
            "frozen_quantity": pos.frozen_quantity,
            "avg_cost": pos.avg_cost,
            "market_value": pos.market_value,
            "unrealized_pnl": pos.unrealized_pnl,
            "realized_pnl": pos.realized_pnl,
            "updated_at": cls._iso(pos.updated_at),
        }

    @classmethod
    def _trade_to_dict(cls, trade: Any, order: Any) -> Dict[str, Any]:
        """Serialize a ``(trade, order)`` pair from ``fetch_recent_trades``.

        Enum-like fields are unwrapped through ``.value`` when they have one.
        """
        return {
            "trade_id": trade.id,
            "order_id": order.id,
            "ts_code": trade.ts_code,
            "side": getattr(trade.side, "value", trade.side),
            "price": trade.price,
            "quantity": trade.quantity,
            "amount": trade.amount,
            "order_type": getattr(order.order_type, "value", order.order_type),
            "order_status": getattr(order.status, "value", order.status),
            "trade_time": cls._iso(trade.trade_time),
        }

    @staticmethod
    def _is_sample_nav(
        *,
        has_nav: bool,
        has_trades: bool,
        has_positions: bool,
        latest_round_nav: Any,
        latest_nav_total: Any,
    ) -> bool:
        """Decide whether the NAV curve should be labelled as sample data.

        The curve is flagged when NAV points exist but the account has neither
        trades nor open positions. The flag is cleared again when the latest
        round's recorded equity and the latest NAV point's ``total_equity``
        are both non-zero and equal within 1e-6.
        """
        if not has_nav or has_trades or has_positions:
            return False
        if not latest_round_nav or not latest_nav_total:
            return True
        round_nav = float(latest_round_nav)
        nav_total = float(latest_nav_total)
        matches_round = (
            abs(round_nav) > 1e-6
            and abs(nav_total) > 1e-6
            and abs(round_nav - nav_total) < 1e-6
        )
        return not matches_round

    # ------------------------------------------------------------------
    # Read side
    # ------------------------------------------------------------------

    async def get_competition(self, session: AsyncSession) -> Competition:
        """Load this service's ``Competition`` row.

        Raises:
            ValueError: if no competition with ``self.competition_id`` exists.
        """
        stmt: Select[Competition] = select(Competition).where(Competition.id == self.competition_id)
        result = await session.execute(stmt)
        competition = result.scalar_one_or_none()
        if competition is None:
            raise ValueError(f"competition {self.competition_id} not found")
        return competition

    async def list_participants(self, session: AsyncSession, *, include_inactive: bool = False) -> List[ParticipantProfile]:
        """Return the competition's participants ordered by id.

        Args:
            session: Caller-owned async session.
            include_inactive: When False (default), only rows with
                ``is_active`` true are returned.
        """
        stmt: Select[CompetitionParticipant] = (
            select(CompetitionParticipant)
            .where(CompetitionParticipant.competition_id == self.competition_id)
            .order_by(CompetitionParticipant.id)
        )
        if not include_inactive:
            stmt = stmt.where(CompetitionParticipant.is_active.is_(True))
        result = await session.execute(stmt)
        return [self._to_profile(item) for item in result.scalars().all()]

    async def load_config(self, session: AsyncSession) -> CompetitionConfig:
        """Return the competition together with its active participants.

        Raises:
            ValueError: if the competition does not exist.
        """
        competition = await self.get_competition(session)
        participants = await self.list_participants(session)
        return CompetitionConfig(competition=competition, participants=participants)

    async def fetch_overview(self, session: AsyncSession) -> Dict[str, Any]:
        """Summarize competition metadata, participant counts and round progress.

        Timestamps are ISO-8601 strings; ``last_round_index`` and
        ``last_finished_at`` are ``None`` until a round provides them.

        Raises:
            ValueError: if the competition does not exist.
        """
        competition = await self.get_competition(session)

        in_competition = CompetitionParticipant.competition_id == self.competition_id
        participant_count = await session.scalar(
            select(func.count(CompetitionParticipant.id)).where(in_competition)
        )
        active_count = await session.scalar(
            select(func.count(CompetitionParticipant.id)).where(
                in_competition,
                CompetitionParticipant.is_active.is_(True),
            )
        )

        round_row = await session.execute(
            select(
                func.count(CompetitionRound.id).label("total"),
                func.sum(
                    case(
                        (CompetitionRound.status == RoundStatus.COMPLETED.value, 1),
                        else_=0,
                    )
                ).label("completed"),
                func.sum(
                    case(
                        (CompetitionRound.status == RoundStatus.RUNNING.value, 1),
                        else_=0,
                    )
                ).label("running"),
                func.max(CompetitionRound.finished_at).label("last_finished_at"),
                func.max(CompetitionRound.round_index).label("last_round_index"),
            ).where(CompetitionRound.competition_id == self.competition_id)
        )
        round_stats = round_row.first()
        has_stats = round_stats is not None

        # Aggregates over an empty table yield NULL sums/max, hence the "or 0" guards.
        total_rounds = int(round_stats.total or 0) if has_stats else 0
        completed_rounds = int(round_stats.completed or 0) if has_stats else 0
        running_rounds = int(round_stats.running or 0) if has_stats else 0
        last_round_index = (
            int(round_stats.last_round_index)
            if has_stats and round_stats.last_round_index is not None
            else None
        )
        last_finished_at = self._iso(round_stats.last_finished_at) if has_stats else None

        return {
            "competition_id": competition.id,
            "name": competition.name,
            "season": competition.season,
            "status": competition.status,
            "description": competition.description,
            "participant_count": int(participant_count or 0),
            "active_participants": int(active_count or 0),
            "round_total": total_rounds,
            "round_completed": completed_rounds,
            "round_running": running_rounds,
            "last_round_index": last_round_index,
            "last_finished_at": last_finished_at,
            "started_at": self._iso(competition.started_at),
            "updated_at": self._iso(competition.updated_at),
        }

    async def fetch_leaderboard(
        self,
        session: AsyncSession,
        *,
        limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Rank active participants.

        Entries are sorted descending by cumulative ``score_delta`` ("score"),
        then by the latest round's ``total_equity``. ``risk_flags`` counts
        results whose ``risk_status`` is ``"FLAGGED"``.

        Args:
            session: Caller-owned async session.
            limit: When given, only the first *limit* entries are returned.
        """
        participants = await self.list_participants(session, include_inactive=False)
        if not participants:
            return []
        participant_ids = [p.id for p in participants]
        in_scope = ParticipantRoundResult.participant_id.in_(participant_ids)

        stats_rows = await session.execute(
            select(
                ParticipantRoundResult.participant_id,
                func.count(ParticipantRoundResult.id).label("rounds"),
                func.sum(ParticipantRoundResult.pnl).label("total_pnl"),
                func.sum(ParticipantRoundResult.score_delta).label("score"),
                func.sum(
                    case(
                        (ParticipantRoundResult.risk_status == "FLAGGED", 1),
                        else_=0,
                    )
                ).label("risk_flags"),
            )
            .where(in_scope)
            .group_by(ParticipantRoundResult.participant_id)
        )
        stats_map: Dict[int, Dict[str, Any]] = {
            row.participant_id: {
                "rounds": int(row.rounds or 0),
                "total_pnl": float(row.total_pnl or 0.0),
                "score": float(row.score or 0.0),
                "risk_flags": int(row.risk_flags or 0),
            }
            for row in stats_rows
        }

        latest_sub = (
            select(
                ParticipantRoundResult.participant_id.label("participant_id"),
                ParticipantRoundResult.round_id.label("round_id"),
                ParticipantRoundResult.total_equity.label("total_equity"),
                ParticipantRoundResult.pnl.label("pnl"),
                ParticipantRoundResult.pnl_pct.label("pnl_pct"),
                ParticipantRoundResult.review_score.label("review_score"),
                ParticipantRoundResult.risk_status.label("risk_status"),
                ParticipantRoundResult.created_at.label("created_at"),
                func.row_number()
                .over(
                    partition_by=ParticipantRoundResult.participant_id,
                    # id breaks created_at ties so the "latest" row is deterministic.
                    order_by=(
                        ParticipantRoundResult.created_at.desc(),
                        ParticipantRoundResult.id.desc(),
                    ),
                )
                .label("rn"),
            )
            .where(in_scope)
            .subquery()
        )
        latest_rows = await session.execute(select(latest_sub).where(latest_sub.c.rn == 1))
        latest_map: Dict[int, Dict[str, Any]] = {
            row.participant_id: {
                "round_id": row.round_id,
                "total_equity": self._opt_float(row.total_equity),
                "pnl": self._opt_float(row.pnl),
                "pnl_pct": self._opt_float(row.pnl_pct),
                "review_score": self._opt_float(row.review_score),
                "risk_status": row.risk_status,
                "created_at": row.created_at,
            }
            for row in latest_rows
        }

        leaderboard: List[Dict[str, Any]] = []
        for profile in participants:
            stats = stats_map.get(profile.id, {})
            latest = latest_map.get(profile.id, {})
            leaderboard.append(
                {
                    "participant_id": profile.id,
                    "display_name": profile.display_name,
                    "slug": profile.slug,
                    "strategy_account_id": profile.strategy_account_id,
                    "primary_model": profile.primary_model,
                    "tags": profile.tags,
                    "rounds": stats.get("rounds", 0),
                    "total_pnl": stats.get("total_pnl", 0.0),
                    "score": stats.get("score", 0.0),
                    "risk_flags": stats.get("risk_flags", 0),
                    "last_round_id": latest.get("round_id"),
                    "last_round_time": self._iso(latest.get("created_at")),
                    "total_equity": latest.get("total_equity"),
                    "pnl": latest.get("pnl"),
                    "pnl_pct": latest.get("pnl_pct"),
                    "review_score": latest.get("review_score"),
                    "focus_symbol": self._extract_focus_symbol(profile),
                }
            )

        leaderboard.sort(key=lambda item: (item.get("score", 0.0), item.get("total_equity") or 0.0), reverse=True)
        if limit is not None:
            leaderboard = leaderboard[:limit]
        return leaderboard

    async def fetch_participant_detail(
        self,
        session: AsyncSession,
        participant_id: int,
        *,
        rounds_limit: int = 10,
    ) -> Dict[str, Any]:
        """Build the detail view of one participant (active or not).

        The payload contains the profile, aggregate metrics, the latest round,
        up to *rounds_limit* most recent rounds (newest first) and the trading
        account's NAV history (up to 180 points), non-empty positions and up to
        40 recent trades.

        Raises:
            ValueError: if the participant does not belong to this competition.
        """
        record = await session.scalar(
            select(CompetitionParticipant).where(
                CompetitionParticipant.competition_id == self.competition_id,
                CompetitionParticipant.id == participant_id,
            )
        )
        if record is None:
            raise ValueError(f"participant {participant_id} not found in competition {self.competition_id}")
        profile = self._to_profile(record)

        stats_row = await session.execute(
            select(
                func.count(ParticipantRoundResult.id).label("rounds"),
                func.sum(ParticipantRoundResult.pnl).label("total_pnl"),
                func.sum(ParticipantRoundResult.score_delta).label("score"),
                func.sum(
                    case(
                        (ParticipantRoundResult.risk_status == "FLAGGED", 1),
                        else_=0,
                    )
                ).label("risk_flags"),
                func.max(ParticipantRoundResult.created_at).label("last_activity"),
            ).where(ParticipantRoundResult.participant_id == participant_id)
        )
        stats = stats_row.first()

        # A single ordered query serves both "latest round" and "recent rounds",
        # so the two can never disagree. At least one row is fetched so the
        # latest round is known even when rounds_limit <= 0.
        history_limit = max(rounds_limit, 0)
        round_rows = (
            await session.execute(
                select(
                    ParticipantRoundResult,
                    CompetitionRound.round_index,
                    CompetitionRound.trading_day,
                    CompetitionRound.finished_at,
                    CompetitionRound.status.label("round_status"),
                )
                .join(CompetitionRound, ParticipantRoundResult.round_id == CompetitionRound.id)
                .where(ParticipantRoundResult.participant_id == participant_id)
                .order_by(ParticipantRoundResult.created_at.desc(), ParticipantRoundResult.id.desc())
                .limit(max(history_limit, 1))
            )
        ).all()
        history = [self._round_row_to_dict(row) for row in round_rows[:history_limit]]
        latest_dict = self._round_row_to_dict(round_rows[0]) if round_rows else None

        metrics = {
            "rounds": int(stats.rounds or 0) if stats else 0,
            "total_pnl": float(stats.total_pnl or 0.0) if stats else 0.0,
            "score": float(stats.score or 0.0) if stats else 0.0,
            "risk_flags": int(stats.risk_flags or 0) if stats else 0,
            "last_activity": self._iso(stats.last_activity) if stats else None,
        }

        nav_rows = await fetch_nav_history(session, profile.strategy_account_id, limit=180)
        nav_points = [
            {
                "date": row.date.isoformat(),
                "nav": row.nav,
                "total_equity": row.total_equity,
                "cash": row.cash,
                "pnl_daily": row.pnl_daily,
                "pnl_total": row.pnl_total,
            }
            for row in nav_rows
        ]
        # Assumes fetch_nav_history returns points oldest-first, so the last
        # one is the most recent -- TODO confirm against reporting service.
        latest_nav = nav_points[-1] if nav_points else None

        position_rows = await fetch_positions(session, profile.strategy_account_id)
        positions_payload = [
            self._position_to_dict(pos) for pos in position_rows if not self._is_flat_position(pos)
        ]

        trades_rows = await fetch_recent_trades(session, profile.strategy_account_id, limit=40)
        trades_payload = [self._trade_to_dict(trade, order) for trade, order in trades_rows]

        is_sample_nav = self._is_sample_nav(
            has_nav=bool(nav_points),
            has_trades=bool(trades_payload),
            has_positions=bool(positions_payload),
            latest_round_nav=latest_dict.get("nav") if latest_dict else None,
            latest_nav_total=latest_nav.get("total_equity") if latest_nav else None,
        )
        account_payload = {
            "summary": {
                "current_nav": latest_nav.get("nav") if latest_nav else None,
                "total_equity": latest_nav.get("total_equity") if latest_nav else None,
                "cash": latest_nav.get("cash") if latest_nav else None,
                "pnl_daily": latest_nav.get("pnl_daily") if latest_nav else None,
                "pnl_total": latest_nav.get("pnl_total") if latest_nav else None,
                "last_date": latest_nav.get("date") if latest_nav else None,
            },
            "nav_points": nav_points,
            "positions": positions_payload,
            "trades": trades_payload,
            "is_sample_nav": is_sample_nav,
        }

        return {
            "participant": {
                "id": profile.id,
                "display_name": profile.display_name,
                "slug": profile.slug,
                "description": profile.description,
                "tags": profile.tags,
                "primary_model": profile.primary_model,
                "fallback_models": profile.fallback_models,
                "agent_overrides": profile.agent_overrides,
                "temperature": profile.temperature,
                "strategy_account_id": profile.strategy_account_id,
                "focus_symbol": self._extract_focus_symbol(profile),
                "is_active": profile.is_active,
            },
            "metrics": metrics,
            "latest_round": latest_dict,
            "recent_rounds": history,
            "account": account_payload,
        }

    async def fetch_round_detail(
        self,
        session: AsyncSession,
        round_id: int,
    ) -> Dict[str, Any]:
        """Return one round with every participant's result and agent turn logs.

        Participants are ordered by id; each participant's turns are ordered
        by ``created_at`` then ``id`` (turns flushed together can share a
        timestamp).

        Raises:
            ValueError: if the round does not exist or belongs to another competition.
        """
        round_record = await session.get(CompetitionRound, round_id)
        if round_record is None or round_record.competition_id != self.competition_id:
            raise ValueError(f"round {round_id} not found in competition {self.competition_id}")

        results_rows = await session.execute(
            select(
                ParticipantRoundResult,
                CompetitionParticipant.display_name,
                CompetitionParticipant.slug,
                CompetitionParticipant.strategy_account_id,
            )
            .join(CompetitionParticipant, ParticipantRoundResult.participant_id == CompetitionParticipant.id)
            .where(ParticipantRoundResult.round_id == round_id)
            .order_by(CompetitionParticipant.id)
        )
        result_records = results_rows.all()
        result_ids = [row.ParticipantRoundResult.id for row in result_records]

        turn_map: Dict[int, List[Dict[str, Any]]] = {}
        if result_ids:  # no results -> no turns; skip the empty IN query
            turns_rows = await session.execute(
                select(AgentTurnLog)
                .where(AgentTurnLog.round_result_id.in_(result_ids))
                .order_by(AgentTurnLog.participant_id, AgentTurnLog.created_at, AgentTurnLog.id)
            )
            for log in turns_rows.scalars():
                turn_map.setdefault(log.round_result_id, []).append(
                    {
                        "agent_name": log.agent_name,
                        "role": log.role,
                        "thoughts": log.thoughts,
                        "final_output": log.final_output,
                        "actions": log.actions,
                        "tool_results": log.tool_results,
                        "raw_response": log.raw_response,
                        "latency_ms": log.latency_ms,
                        "created_at": self._iso(log.created_at),
                    }
                )

        participants: List[Dict[str, Any]] = []
        for row in result_records:
            result = row.ParticipantRoundResult
            participants.append(
                {
                    "participant_id": result.participant_id,
                    "display_name": row.display_name,
                    "slug": row.slug,
                    "strategy_account_id": row.strategy_account_id,
                    "nav": result.total_equity,
                    "pnl": result.pnl,
                    "pnl_pct": result.pnl_pct,
                    "risk_status": result.risk_status,
                    "review_score": result.review_score,
                    "score_delta": result.score_delta,
                    "summary": result.summary,
                    "actions": result.actions,
                    "reflections": result.reflections,
                    "log_path": result.log_path,
                    "agent_turns": turn_map.get(result.id, []),
                }
            )

        return {
            "round": {
                "id": round_record.id,
                "round_index": round_record.round_index,
                "status": round_record.status,
                "trading_day": self._iso(round_record.trading_day),
                "started_at": self._iso(round_record.started_at),
                "finished_at": self._iso(round_record.finished_at),
                "snapshot_source": round_record.snapshot_source,
                "metadata": round_record.extra,
            },
            "participants": participants,
        }

    # ------------------------------------------------------------------
    # Write side
    # ------------------------------------------------------------------

    async def create_round(
        self,
        session: AsyncSession,
        *,
        round_index: int,
        trading_day: Optional[date] = None,
        snapshot_source: Optional[Dict[str, Any]] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> CompetitionRound:
        """Add a new round to this competition and flush so its id is assigned.

        ``snapshot_source`` and ``extra`` default to empty dicts.
        """
        record = CompetitionRound(
            competition_id=self.competition_id,
            round_index=round_index,
            trading_day=trading_day,
            snapshot_source=snapshot_source or {},
            extra=extra or {},
        )
        session.add(record)
        await session.flush()
        return record

    async def record_round_result(
        self,
        session: AsyncSession,
        payload: RoundResultPayload,
    ) -> ParticipantRoundResult:
        """Insert or update the result row for ``(payload.round_id, payload.participant_id)``.

        List fields left as ``None`` in *payload* are stored as empty lists.
        NOTE(review): this is select-then-write, not an atomic upsert; two
        concurrent calls for the same pair could both insert unless a unique
        constraint exists on the table (not visible here).
        """
        stmt: Select[ParticipantRoundResult] = select(ParticipantRoundResult).where(
            ParticipantRoundResult.round_id == payload.round_id,
            ParticipantRoundResult.participant_id == payload.participant_id,
        )
        existing = (await session.execute(stmt)).scalar_one_or_none()

        fields = {
            "strategy_account_id": payload.strategy_account_id,
            "nav": payload.nav,
            "total_equity": payload.total_equity,
            "cash": payload.cash,
            "pnl": payload.pnl,
            "pnl_pct": payload.pnl_pct,
            "risk_status": payload.risk_status,
            "risk_warnings": payload.risk_warnings or [],
            "review_score": payload.review_score,
            "score_delta": payload.score_delta,
            "reflections": payload.reflections or [],
            "summary": payload.summary,
            "actions": payload.actions or [],
            "log_path": payload.log_path,
        }

        if existing is None:
            record = ParticipantRoundResult(
                round_id=payload.round_id,
                participant_id=payload.participant_id,
                **fields,
            )
            session.add(record)
            await session.flush()
            return record

        for key, value in fields.items():
            setattr(existing, key, value)
        await session.flush()
        return existing

    async def append_agent_turns(
        self,
        session: AsyncSession,
        *,
        round_result_id: int,
        participant_id: int,
        turns: Sequence[AgentTurnPayload],
    ) -> None:
        """Persist *turns* as ``AgentTurnLog`` rows and flush; no-op when empty.

        ``actions`` / ``tool_results`` left as ``None`` are stored as empty lists.
        """
        if not turns:
            return
        session.add_all(
            [
                AgentTurnLog(
                    round_result_id=round_result_id,
                    participant_id=participant_id,
                    agent_name=turn.agent_name,
                    role=turn.role,
                    thoughts=turn.thoughts,
                    final_output=turn.final_output,
                    actions=turn.actions or [],
                    tool_results=turn.tool_results or [],
                    raw_response=turn.raw_response,
                    latency_ms=turn.latency_ms,
                )
                for turn in turns
            ]
        )
        await session.flush()
        logger.info(
            "competition:append_agent_turns",
            participant_id=participant_id,
            count=len(turns),
        )

    def _extract_focus_symbol(self, profile: ParticipantProfile) -> Optional[str]:
        """Pick the participant's focus symbol from ``agent_overrides``.

        Precedence: a non-empty ``focus_symbol``, then the first entry of a
        non-empty ``symbols`` list, then a non-empty ``symbol``; ``None`` when
        none is set. Empty or ``None`` values are skipped instead of being
        stringified (``str(None)`` would yield ``"None"``).
        """
        overrides = profile.agent_overrides or {}
        focus = overrides.get("focus_symbol")
        if focus:
            return str(focus)
        symbols = overrides.get("symbols")
        if isinstance(symbols, list) and symbols:
            return str(symbols[0])
        symbol = overrides.get("symbol")
        if symbol:
            return str(symbol)
        return None


# Public API of this module: the service plus the dataclasses it accepts/returns.
__all__ = [
    "CompetitionService",
    "CompetitionConfig",
    "ParticipantProfile",
    "RoundResultPayload",
    "AgentTurnPayload",
]
