"""多 Agent 竞赛执行器。"""

from __future__ import annotations

import asyncio
import math
from dataclasses import dataclass, field
from datetime import date, datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from backend.agent import (
    AgentMemory,
    AgentRunRecorder,
    AgentScoreBoard,
    BaseAgent,
    ToolDefinition,
    create_market_agent,
    create_review_agent,
    create_risk_agent,
    create_strategy_agent,
)
from backend.agent.orchestrator import MultiAgentOrchestrator, RoundOutcome
from backend.agent.base import AgentResult
from backend.agent.providers import WorkflowDataProvider
from backend.competition.event_bus import CompetitionEventBus
from backend.competition.models import CompetitionRound, RoundStatus
from backend.competition.services import (
    AgentTurnPayload,
    CompetitionService,
    ParticipantProfile,
    RoundResultPayload,
)
from backend.infra.logging import logger
from backend.market.tushare_service import TushareService
from backend.trading.models import OrderSide, OrderType
from backend.trading.services import TradingAPIService
from backend.trading.services.risk_service import InstrumentSnapshot


@dataclass
class ParticipantState:
    """Runtime state cached for one competition participant.

    Instances are stored by ``CompetitionRunner`` keyed by participant id and
    reused on later rounds, so the agents and their orchestrator are built
    only once per participant.
    """

    # --- identity / configuration -------------------------------------------
    profile: ParticipantProfile
    # Default ts_code the participant analyses and trades.
    focus_symbol: str
    # Root directory for this participant's per-agent run logs.
    log_dir: Path

    # --- agent wiring ------------------------------------------------------
    # Scoreboard shared with ``orchestrator``.
    scoreboard: AgentScoreBoard
    market_agent: BaseAgent
    strategy_agent: BaseAgent
    risk_agent: BaseAgent
    review_agent: BaseAgent
    orchestrator: MultiAgentOrchestrator

    # --- per-round bookkeeping --------------------------------------------
    # Strategy-agent points seen after the previous round (baseline for the score delta).
    last_strategy_points: float = 0.0
    # Most recent market snapshot handed to the orchestrator; a fresh dict per instance.
    latest_snapshot: Dict[str, Any] = field(default_factory=dict)


class CompetitionRunner:
    """Drive every participant through one competition round and persist the results.

    A round consists of: loading the competition config, creating a round
    record, prefetching market snapshots per focus symbol, letting each
    participant's multi-agent orchestrator act (sequentially), and storing
    the per-participant results plus agent turns. Progress is published to an
    optional ``CompetitionEventBus``.
    """

    def __init__(
        self,
        *,
        competition_service: CompetitionService,
        session_factory: async_sessionmaker[AsyncSession],
        trading_api: TradingAPIService,
        tushare_service: TushareService,
        llm_service: Any,
        log_root: Path,
        default_symbol: str = "000001.SZ",
        event_bus: Optional[CompetitionEventBus] = None,
    ) -> None:
        """Store collaborators; ``log_root`` is resolved to an absolute path."""
        self._service = competition_service
        self._session_factory = session_factory
        self._trading_api = trading_api
        self._tushare_service = tushare_service
        self._llm_service = llm_service
        self._log_root = Path(log_root).resolve()
        self._default_symbol = default_symbol
        # Participant id -> cached agents/orchestrator, reused across rounds.
        self._participant_states: Dict[int, ParticipantState] = {}
        self._event_bus = event_bus

    async def run_round(self, trading_day: Optional[date] = None) -> CompetitionRound:
        """Execute one competition round and return its (re-fetched) record.

        Args:
            trading_day: Optional trading day the round is evaluated against.

        Returns:
            The ``CompetitionRound`` loaded from the database after the round ends.

        Raises:
            RuntimeError: If no participants are configured.
            Exception / asyncio.CancelledError: Anything raised while running the
                participants is re-raised after the round has been marked FAILED.
        """
        async with self._session_factory() as session:
            config = await self._service.load_config(session)
            if not config.participants:
                raise RuntimeError("竞赛未配置参赛选手，无法执行回合")

            # _next_round_index returns the highest existing index (0 if none).
            last_index = await self._next_round_index(session)
            round_record = await self._service.create_round(
                session,
                round_index=last_index + 1,
                trading_day=trading_day,
                snapshot_source={"focus_symbols": [self._focus_symbol(p) for p in config.participants]},
                extra={"participant_count": len(config.participants)},
            )
            round_record.status = RoundStatus.RUNNING.value
            round_record.started_at = datetime.now(timezone.utc)
            await session.commit()

        # NOTE(review): round_record is read after its session closed; this assumes
        # the session factory uses expire_on_commit=False — confirm.
        trading_day_iso = round_record.trading_day.isoformat() if round_record.trading_day else None

        await self._publish_event(
            "round_started",
            {
                "round_id": round_record.id,
                "round_index": round_record.round_index,
                "trading_day": trading_day_iso,
                "participant_count": len(config.participants),
            },
        )

        try:
            snapshot_cache = await self._prefetch_snapshots(config.participants, round_record.trading_day)
            for profile in config.participants:
                focus_symbol = self._focus_symbol(profile)
                await self._run_participant_round(
                    round_record,
                    profile,
                    snapshot_override=snapshot_cache.get(focus_symbol),
                )
        except (Exception, asyncio.CancelledError):
            # CancelledError is a BaseException since Python 3.8; catch it explicitly
            # so a cancelled round is not left in RUNNING status forever.
            logger.exception("competition:round_failed", round_id=round_record.id)
            await self._update_round_status(
                round_record.id,
                status=RoundStatus.FAILED.value,
                finished_at=datetime.now(timezone.utc),
            )
            await self._publish_event(
                "round_failed",
                {
                    "round_id": round_record.id,
                    "round_index": round_record.round_index,
                },
            )
            raise

        await self._update_round_status(
            round_record.id,
            status=RoundStatus.COMPLETED.value,
            finished_at=datetime.now(timezone.utc),
        )
        await self._publish_event(
            "round_completed",
            {
                "round_id": round_record.id,
                "round_index": round_record.round_index,
                "trading_day": trading_day_iso,
            },
        )
        return await self._fetch_round(round_record.id)

    async def _run_participant_round(
        self,
        round_record: CompetitionRound,
        profile: ParticipantProfile,
        *,
        snapshot_override: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run one participant's orchestrator turn and persist the outcome.

        The account overview is loaded before and after the orchestrator runs so
        that the P&L of this round can be computed in ``_persist_result``.
        """
        state = await self._ensure_participant_state(profile)
        provider = WorkflowDataProvider(
            tushare_service=self._tushare_service,
            trading_api=self._trading_api,
            session_factory=self._session_factory,
            strategy_account_id=profile.strategy_account_id,
            focus_symbol=state.focus_symbol,
            as_of_date=round_record.trading_day,
            preloaded_snapshot=snapshot_override,
        )

        snapshot = await provider.snapshot()
        account_context = await provider.account()
        # Kept so the trade tool can fall back to the latest close price.
        state.latest_snapshot = snapshot

        before = await self._load_account_overview(profile.strategy_account_id)
        outcome = await state.orchestrator.run_round(snapshot, account_context)
        after = await self._load_account_overview(profile.strategy_account_id)

        await self._persist_result(round_record, profile, state, outcome, before, after)

    async def _persist_result(
        self,
        round_record: CompetitionRound,
        profile: ParticipantProfile,
        state: ParticipantState,
        outcome: RoundOutcome,
        before: Dict[str, Any],
        after: Dict[str, Any],
    ) -> None:
        """Store the round result and agent turns, then publish ``participant_result``.

        ``pnl`` is the change in total equity between ``before`` and ``after``;
        ``pnl_pct`` is only computed when the starting equity is non-negligible.
        ``score_delta`` is the change in the strategy agent's scoreboard points
        since the previous round of this participant.
        """
        before_snapshot = before.get("snapshot") if before else None
        after_snapshot = after.get("snapshot") if after else None
        before_total = before_snapshot.total_equity if before_snapshot else None
        after_total = after_snapshot.total_equity if after_snapshot else None
        after_cash = after_snapshot.cash_available if after_snapshot else None
        pnl = None
        pnl_pct = None
        if before_total is not None and after_total is not None:
            pnl = after_total - before_total
            if abs(before_total) > 1e-6:
                pnl_pct = pnl / before_total

        strategy_points = self._extract_strategy_points(outcome)
        score_delta = strategy_points - state.last_strategy_points
        state.last_strategy_points = strategy_points

        payload = RoundResultPayload(
            round_id=round_record.id,
            participant_id=profile.id,
            strategy_account_id=profile.strategy_account_id,
            nav=after_total,
            total_equity=after_total,
            cash=after_cash,
            pnl=pnl,
            pnl_pct=pnl_pct,
            risk_status=outcome.risk_status,
            risk_warnings=outcome.risk_warnings,
            review_score=outcome.review_score,
            score_delta=score_delta,
            reflections=outcome.reflections,
            summary=outcome.review.final,
            actions=outcome.strategy.actions,
            log_path=str(state.log_dir.relative_to(self._log_root)) if state.log_dir else None,
        )

        turn_payloads = self._build_turn_payloads(profile, outcome)

        async with self._session_factory() as session:
            result_record = await self._service.record_round_result(session, payload)
            await self._service.append_agent_turns(
                session,
                round_result_id=result_record.id,
                participant_id=profile.id,
                turns=turn_payloads,
            )
            await session.commit()

        await self._publish_event(
            "participant_result",
            {
                "round_id": round_record.id,
                "round_index": round_record.round_index,
                "participant_id": profile.id,
                "participant": profile.display_name,
                "strategy_account_id": profile.strategy_account_id,
                "score_delta": score_delta,
                "review_score": outcome.review_score,
                "focus_symbol": state.focus_symbol,
            },
        )

    async def _ensure_participant_state(self, profile: ParticipantProfile) -> ParticipantState:
        """Return the cached state for ``profile``, building it on first use.

        First use creates ``<log_root>/participant_<slug>/<role>`` log dirs, the
        market/strategy/risk/review agents (per-role model override or the
        profile's primary model), registers the ``trade_order`` and
        ``market_intel`` tools and wires everything into an orchestrator.

        The config is reloaded every round, so on later calls the cached agents
        are reused but ``profile`` and ``focus_symbol`` are refreshed. This keeps
        the data provider consistent with the snapshot that ``run_round``
        prefetched for the current focus symbol. Model overrides only take
        effect when the state is first built.
        """
        cached = self._participant_states.get(profile.id)
        if cached is not None:
            cached.profile = profile
            cached.focus_symbol = self._focus_symbol(profile)
            return cached

        participant_id = profile.id
        focus_symbol = self._focus_symbol(profile)
        participant_dir = self._log_root / f"participant_{profile.slug}"
        participant_dir.mkdir(parents=True, exist_ok=True)

        def _recorder(name: str) -> AgentRunRecorder:
            path = participant_dir / name
            path.mkdir(parents=True, exist_ok=True)
            return AgentRunRecorder(path)

        def _current_profile() -> ParticipantProfile:
            # Resolved at call time so tool calls use the most recently loaded
            # profile (e.g. an updated strategy_account_id), not the first one.
            current = self._participant_states.get(participant_id)
            return current.profile if current is not None else profile

        overrides = profile.agent_overrides or {}
        model_overrides = overrides.get("models") or {}
        if not isinstance(model_overrides, dict):
            model_overrides = {}

        # Insertion order fixes the construction order: market, strategy, risk, review.
        factories = {
            "market": create_market_agent,
            "strategy": create_strategy_agent,
            "risk": create_risk_agent,
            "review": create_review_agent,
        }
        agents: Dict[str, BaseAgent] = {
            role: factory(
                self._llm_service,
                model=model_overrides.get(role) or profile.primary_model,
                memory=AgentMemory(),
                recorder=_recorder(role),
            )
            for role, factory in factories.items()
        }
        market_agent = agents["market"]
        strategy_agent = agents["strategy"]
        risk_agent = agents["risk"]
        review_agent = agents["review"]

        trade_tool = ToolDefinition(
            name="trade_order",
            description="提交模拟交易订单（ts_code/side/quantity/price）",
            handler=lambda payload: self._handle_trade_tool(_current_profile(), payload),
        )
        strategy_agent.register_tool(trade_tool)

        market_intel_tool = ToolDefinition(
            name="market_intel",
            description=(
                "查询最新新闻公告，可选参数：ts_code/news_limit/announcement_limit/detail。"
                "detail 默认为 brief，可设置为 headline 获取仅标题，或 full 获取完整摘要。"
            ),
            handler=lambda payload: self._handle_market_intel(_current_profile(), payload),
        )
        market_agent.register_tool(market_intel_tool)
        strategy_agent.register_tool(market_intel_tool)

        scoreboard = AgentScoreBoard()
        orchestrator = MultiAgentOrchestrator(
            market_agent=market_agent,
            strategy_agent=strategy_agent,
            risk_agent=risk_agent,
            review_agent=review_agent,
            scoreboard=scoreboard,
        )

        state = ParticipantState(
            profile=profile,
            focus_symbol=focus_symbol,
            log_dir=participant_dir,
            scoreboard=scoreboard,
            market_agent=market_agent,
            strategy_agent=strategy_agent,
            risk_agent=risk_agent,
            review_agent=review_agent,
            orchestrator=orchestrator,
        )
        self._participant_states[participant_id] = state
        return state

    async def _handle_trade_tool(self, profile: ParticipantProfile, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handle the ``trade_order`` tool: validate the payload and place a simulated order.

        The payload comes from an LLM. Invalid input is reported back as an
        ``{"error": ...}`` dict instead of raising, so the agent can correct
        itself. This covers non-numeric, non-finite or non-positive quantities
        and unknown sides or order types. Errors from ``place_order`` itself
        propagate.

        Returns:
            Order id, status, filled quantity, average price and the requested quantity.
        """
        ts_code = str(payload.get("ts_code") or self._focus_symbol(profile))
        side_value = str(payload.get("side", "BUY")).upper()
        order_type_value = str(payload.get("order_type", "MARKET")).upper()

        raw_quantity = payload.get("quantity", 0)
        quantity = self._parse_float(raw_quantity)
        if quantity is None or not math.isfinite(quantity):
            # Without this check NaN/inf would pass the `<= 0` test below.
            logger.warning(
                "competition:trade_invalid_quantity",
                participant=profile.display_name,
                quantity=raw_quantity,
            )
            return {
                "error": "invalid_quantity",
                "requested_quantity": str(raw_quantity),
            }
        if quantity <= 0:
            logger.warning(
                "competition:trade_invalid_quantity",
                participant=profile.display_name,
                quantity=quantity,
            )
            return {
                "error": "quantity must be positive",
                "requested_quantity": quantity,
            }

        try:
            side = OrderSide[side_value]
        except KeyError:  # pragma: no cover - input validation
            logger.warning(
                "competition:trade_invalid_side",
                participant=profile.display_name,
                side=side_value,
            )
            return {
                "error": "unsupported_side",
                "side": side_value,
            }
        try:
            order_type = OrderType[order_type_value]
        except KeyError:  # pragma: no cover - input validation
            logger.warning(
                "competition:trade_invalid_order_type",
                participant=profile.display_name,
                order_type=order_type_value,
            )
            return {
                "error": "unsupported_order_type",
                "order_type": order_type_value,
            }

        price = payload.get("price")
        execution_price = payload.get("execution_price")

        # Fall back to the close price of the latest snapshot fed to this participant.
        state = self._participant_states.get(profile.id)
        latest_snapshot = state.latest_snapshot if state else {}
        latest_price = latest_snapshot.get("latest_price") if isinstance(latest_snapshot, dict) else None
        snapshot_price = latest_price.get("close") if isinstance(latest_price, dict) else None
        if execution_price is None:
            execution_price = price or snapshot_price

        instrument_data = payload.get("instrument")
        if not isinstance(instrument_data, dict):
            instrument_data = {}
        instrument = InstrumentSnapshot(
            ts_code=ts_code,
            price=instrument_data.get("price", snapshot_price),
            limit_up=instrument_data.get("limit_up"),
            limit_down=instrument_data.get("limit_down"),
            is_suspended=bool(instrument_data.get("is_suspended", False)),
        )

        async with self._session_factory() as session:
            fill = await self._trading_api.place_order(
                session,
                strategy_account_id=profile.strategy_account_id,
                ts_code=ts_code,
                side=side,
                order_type=order_type,
                quantity=quantity,
                price=price,
                execution_price=execution_price,
                instrument=instrument,
            )
            await session.commit()

        return {
            "order_id": fill.order.id,
            "status": fill.order.status.value,
            "filled_quantity": fill.order.filled_quantity,
            "avg_price": fill.order.avg_filled_price,
            "requested_quantity": quantity,
        }

    async def _load_account_overview(self, strategy_account_id: int) -> Dict[str, Any]:
        """Return the account's snapshot, positions and open orders as a plain dict."""
        async with self._session_factory() as session:
            overview = await self._trading_api.get_account_overview(session, strategy_account_id)
        return {
            "snapshot": overview.snapshot,
            "positions": overview.positions,
            "open_orders": overview.open_orders,
        }

    async def _update_round_status(
        self,
        round_id: int,
        *,
        status: str,
        finished_at: Optional[datetime] = None,
    ) -> None:
        """Set the round's status (and optionally ``finished_at``); no-op if the round is missing."""
        async with self._session_factory() as session:
            record = await session.get(CompetitionRound, round_id)
            if record is None:
                return
            record.status = status
            if finished_at is not None:
                record.finished_at = finished_at
            await session.commit()

    async def _fetch_round(self, round_id: int) -> CompetitionRound:
        """Load and refresh the round record.

        Raises:
            ValueError: If no round with ``round_id`` exists.
        """
        async with self._session_factory() as session:
            record = await session.get(CompetitionRound, round_id)
            if record is None:
                raise ValueError(f"round {round_id} not found")
            await session.refresh(record)
            return record

    async def _next_round_index(self, session: AsyncSession) -> int:
        """Return the highest ``round_index`` of this competition (0 when there are no rounds).

        Despite the name, the caller adds 1 to obtain the next index.
        """
        stmt = select(func.coalesce(func.max(CompetitionRound.round_index), 0)).where(
            CompetitionRound.competition_id == self._service.competition_id
        )
        result = await session.execute(stmt)
        return int(result.scalar_one())

    def _focus_symbol(self, profile: ParticipantProfile) -> str:
        """Resolve the participant's focus symbol from its agent overrides.

        Priority: a non-empty ``focus_symbol``, then the first entry of a
        non-empty ``symbols`` list, then a non-empty ``symbol``, else the
        runner's default symbol. Always returns a string.
        """
        overrides = profile.agent_overrides or {}
        explicit = overrides.get("focus_symbol")
        if explicit:
            return str(explicit)
        symbols = overrides.get("symbols")
        if isinstance(symbols, list) and symbols:
            return str(symbols[0])
        fallback = overrides.get("symbol")
        return str(fallback) if fallback else self._default_symbol

    def _extract_strategy_points(self, outcome: RoundOutcome) -> float:
        """Return the strategy agent's points from the outcome scoreboard (0.0 if absent or null)."""
        for entry in outcome.scoreboard:
            if entry.get("agent") == "strategy":
                return float(entry.get("points") or 0.0)
        return 0.0

    def _build_turn_payloads(
        self,
        profile: ParticipantProfile,
        outcome: RoundOutcome,
    ) -> List[AgentTurnPayload]:
        """Convert the four agent results of a round into turn payloads, in pipeline order."""
        ordered_results = (
            ("market-broadcast", outcome.broadcast),
            ("strategy", outcome.strategy),
            ("risk-audit", outcome.risk),
            ("review", outcome.review),
        )
        return [self._to_turn_payload(role, result) for role, result in ordered_results]

    def _to_turn_payload(self, role: str, result: AgentResult) -> AgentTurnPayload:
        """Map one ``AgentResult`` to an ``AgentTurnPayload`` (role doubles as agent name)."""
        return AgentTurnPayload(
            agent_name=role,
            role=role,
            thoughts=result.thoughts,
            final_output=result.final,
            actions=result.actions,
            tool_results=[
                {
                    "name": tool.name,
                    "params": tool.params,
                    "output": tool.output,
                }
                for tool in result.tool_results
            ],
            raw_response=result.raw_response,
        )

    async def _prefetch_snapshots(
        self,
        participants: List[ParticipantProfile],
        trading_day: Optional[date],
    ) -> Dict[str, Dict[str, Any]]:
        """Build one market snapshot per distinct focus symbol.

        Runs the blocking ``build_snapshot`` in a worker thread. Symbols whose
        snapshot fails are logged and omitted from the result, so callers get
        ``None`` for them.
        """
        symbols = {self._focus_symbol(profile) for profile in participants}
        snapshots: Dict[str, Dict[str, Any]] = {}
        for symbol in symbols:
            try:
                snapshots[symbol] = await asyncio.to_thread(
                    self._tushare_service.build_snapshot,
                    symbol,
                    as_of=trading_day,
                )
            except Exception:  # noqa: BLE001
                logger.warning("competition:snapshot_prefetch_failed", symbol=symbol, exc_info=True)
        return snapshots

    async def _handle_market_intel(self, profile: ParticipantProfile, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handle the ``market_intel`` tool: fetch news and announcements for a symbol.

        ``news_limit`` (alias ``limit``) is clamped to 1..20 with default 5, and
        ``announcement_limit`` to 1..50 with default 10. Missing, zero or
        unparsable limits use the default. Fetch failures return an empty
        result with an ``error`` field instead of raising.
        """
        ts_code = str(payload.get("ts_code") or self._focus_symbol(profile))
        news_limit = self._bounded_int(
            payload.get("news_limit", payload.get("limit")), default=5, lower=1, upper=20
        )
        announcement_limit = self._bounded_int(
            payload.get("announcement_limit"), default=10, lower=1, upper=50
        )
        detail = str(payload.get("detail") or payload.get("mode") or "brief")
        try:
            return await asyncio.to_thread(
                self._tushare_service.build_intel,
                ts_code,
                news_limit=news_limit,
                announcement_limit=announcement_limit,
                detail=detail,
            )
        except Exception as exc:  # noqa: BLE001 - fallback when market intel fetch fails
            logger.warning(
                "competition:market_intel_failed",
                participant=profile.display_name,
                ts_code=ts_code,
                error=str(exc),
            )
            return {
                "ts_code": ts_code,
                "news": [],
                "announcements": [],
                "error": str(exc),
            }

    @staticmethod
    def _parse_float(value: Any) -> Optional[float]:
        """Return ``float(value)``, or ``None`` when the value cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _bounded_int(value: Any, *, default: int, lower: int, upper: int) -> int:
        """Parse ``value`` as an int clamped to ``[lower, upper]``.

        Missing, unparsable or zero values fall back to ``default``.
        """
        try:
            parsed = int(value)
        except (TypeError, ValueError, OverflowError):
            parsed = default
        return max(lower, min(parsed or default, upper))

    async def _publish_event(self, event: str, payload: Dict[str, Any]) -> None:
        """Publish ``{"event", "payload"}`` on the event bus; best-effort, failures are only logged."""
        if self._event_bus is None:
            return
        try:
            await self._event_bus.publish({"event": event, "payload": payload})
        except Exception:  # noqa: BLE001
            logger.warning("competition:event_publish_failed", event=event, exc_info=True)


# Public API of this module.
__all__ = [
    "CompetitionRunner",
    "ParticipantState",
]
