"""A 股上下文构建工具。

负责在处理事件时收集 A 股市场的观测数据（账户快照、行情数据），
为 Reasoner 提供 A 股决策所需输入。
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import structlog

from laca.adapters.data.market_data_service import MarketDataService
from laca.agent.events import AgentEvent, EventType
from laca.agent.memory_system import TradingMemory
from laca.tools.ashare_portfolio_tool import (
    PortfolioSnapshot as AccountSnapshot,
    ASharePortfolioTool,
)
from laca.services.intelligence_service import IntelligenceService, MarketIntelligence, SymbolIntelligence

# Module-level structlog logger; every event this module logs is named "ashare_context.*".
logger = structlog.get_logger("ashare_context_builder")


@dataclass
class AShareMarketContext:
    """Market-side observations for a set of A-share stocks.

    Every container defaults to a fresh, empty object per instance.
    """

    # ts_code -> quote fields (the builder fills "price" and "ts_code").
    quotes: dict[str, dict[str, Any]] = field(default_factory=dict)
    # ts_code -> limit-price fields ("up_limit", "down_limit", optional "pct_chg").
    limit_prices: dict[str, dict[str, Any]] = field(default_factory=dict)
    # ts_codes reported as suspended.
    suspended_stocks: list[str] = field(default_factory=list)
    # ts_code -> error message for stocks whose lookup failed.
    errors: dict[str, str] = field(default_factory=dict)


@dataclass
class AShareObservation:
    """Everything the Reasoner receives for one A-share decision step."""

    # The event that triggered this observation.
    event: AgentEvent
    # Account snapshot (ASharePortfolioTool's PortfolioSnapshot).
    snapshot: AccountSnapshot
    # Market-side data; each instance gets its own empty context by default.
    market: AShareMarketContext = field(default_factory=AShareMarketContext)
    # Derived account metrics, keyed by metric name.
    metrics: dict[str, Any] = field(default_factory=dict)
    # Per-symbol intelligence; presumably keyed by ts_code — verify against IntelligenceService.
    intelligence: dict[str, SymbolIntelligence] = field(default_factory=dict)
    # Market-wide intelligence, if any was collected.
    market_intelligence: Optional[MarketIntelligence] = None
    # Creation timestamp (timezone-aware, UTC).
    collected_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    # Caller-supplied metadata.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Context text provided by the memory system.
    memory_context: str = ""


class AShareContextBuilder:
    """Build A-share observation contexts from ASharePortfolioTool and MarketDataService.

    Two short-lived caches avoid refetching on bursts of events:

    * the observation cache (``observation_cache_ttl_seconds``) keeps the last
      fully built :class:`AShareObservation` and is reused only for events whose
      stock-selection inputs and trade date match (see ``_observation_cache_key``);
    * the market cache (``market_cache_ttl_seconds``) keeps the last
      :class:`AShareMarketContext` and is reused only for the same set of stocks.

    A TTL of 0 (negative values are clamped to 0) disables that cache.
    """

    # Fallback watchlist used on SCHEDULED_TICK events when the account holds
    # no positions: large-cap SSE/SZSE sector leaders.
    _DEFAULT_WATCHLIST: tuple[str, ...] = (
        "600519.SH",  # Kweichow Moutai - baijiu leader
        "600036.SH",  # China Merchants Bank - banking leader
        "000858.SZ",  # Wuliangye - baijiu
        "000333.SZ",  # Midea Group - home appliance leader
        "600276.SH",  # Jiangsu Hengrui - pharmaceutical leader
    )

    def __init__(
        self,
        portfolio_tool: ASharePortfolioTool,
        *,
        market_data_service: Optional[MarketDataService] = None,
        intelligence_service: Optional[IntelligenceService] = None,
        memory_system: Optional[TradingMemory] = None,
        max_stocks: int = 10,
        market_cache_ttl_seconds: int = 30,
        observation_cache_ttl_seconds: int = 10,
    ) -> None:
        """Create a builder.

        Args:
            portfolio_tool: Source of account snapshots.
            market_data_service: Quote / limit-price / suspension source. When
                omitted, observations carry an empty market context.
            intelligence_service: Intelligence source; a default
                ``IntelligenceService`` bound to ``market_data_service`` is
                created when omitted.
            memory_system: Trading memory; a fresh ``TradingMemory`` is created
                when omitted.
            max_stocks: Upper bound on stocks queried per build (clamped to >= 1).
            market_cache_ttl_seconds: Market cache TTL; 0 disables it.
            observation_cache_ttl_seconds: Observation cache TTL; 0 disables it.
        """
        self._portfolio_tool = portfolio_tool
        self._market_data_service = market_data_service
        # ``is None`` rather than ``or``: a caller-supplied object that happens
        # to be falsy (e.g. an empty container-like memory) must be kept.
        self._intelligence_service = (
            intelligence_service
            if intelligence_service is not None
            else IntelligenceService(market_data_service=market_data_service)
        )
        self._memory_system = (
            memory_system if memory_system is not None else TradingMemory()
        )
        self._max_stocks = max(max_stocks, 1)
        self._market_cache_ttl = timedelta(seconds=max(market_cache_ttl_seconds, 0))
        self._observation_cache_ttl = timedelta(
            seconds=max(observation_cache_ttl_seconds, 0)
        )

        # Market cache: (stored_at, selected stocks, market context).
        self._market_cache: Optional[tuple[datetime, List[str], AShareMarketContext]] = (
            None
        )

        # Observation cache: (stored_at, event-derived key, observation).
        self._observation_cache: Optional[
            tuple[datetime, tuple[Any, ...], AShareObservation]
        ] = None

    async def build(
        self,
        event: AgentEvent,
        *,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> AShareObservation:
        """Build the A-share observation for ``event``.

        Args:
            event: Triggering agent event. Its type, AUTO_SIGNAL ``ts_code`` and
                creation date influence which stocks and intelligence are collected.
            metadata: Extra metadata attached to the result (``{}`` when omitted).

        Returns:
            A new :class:`AShareObservation`. On an observation-cache hit the
            snapshot, market data, metrics, intelligence, memory context and
            ``collected_at`` are reused from the cached observation, while
            ``event`` and ``metadata`` come from this call.
        """
        now = datetime.now(timezone.utc)
        cache_key = self._observation_cache_key(event)
        cached = self._observation_cache
        if (
            cached is not None
            and self._observation_cache_ttl.total_seconds() > 0
            and (now - cached[0]) < self._observation_cache_ttl
            and cached[1] == cache_key
        ):
            logger.debug(
                "ashare_context.observation.cache_hit",
                age_seconds=(now - cached[0]).total_seconds(),
            )
            obs = cached[2]
            # Carry over every collected field so a cache hit matches a fresh build.
            return AShareObservation(
                event=event,
                snapshot=obs.snapshot,
                market=obs.market,
                metrics=obs.metrics,
                intelligence=obs.intelligence,
                market_intelligence=obs.market_intelligence,
                collected_at=obs.collected_at,
                metadata=metadata or {},
                memory_context=obs.memory_context,
            )

        # Cache miss: rebuild everything from the live sources.
        snapshot = self._portfolio_tool.fetch_snapshot()
        market = await self._collect_market(snapshot, event)
        metrics = self._compute_metrics(snapshot)
        intelligence, market_intel = self._collect_intelligence(snapshot, event)
        memory_context = self._memory_system.get_full_context()

        observation = AShareObservation(
            event=event,
            snapshot=snapshot,
            market=market,
            metrics=metrics,
            metadata=metadata or {},
            memory_context=memory_context,
            intelligence=intelligence,
            market_intelligence=market_intel,
        )

        self._observation_cache = (now, cache_key, observation)
        return observation

    def _observation_cache_key(self, event: AgentEvent) -> tuple[Any, ...]:
        """Return the event-derived inputs that shape what ``build`` collects.

        Stock selection depends on the AUTO_SIGNAL ``ts_code`` and on whether the
        event is a SCHEDULED_TICK (default-watchlist fallback); intelligence
        depends on the trade date. Events with different keys must not share a
        cached observation.
        """
        return (
            self._signal_ts_code(event),
            event.event_type == EventType.SCHEDULED_TICK,
            self._trade_date(event),
        )

    @staticmethod
    def _signal_ts_code(event: AgentEvent) -> Optional[str]:
        """Return the normalized (stripped, upper-cased) ``ts_code`` of an AUTO_SIGNAL event.

        Returns ``None`` for other event types or when no usable code is present.
        """
        if event.event_type is EventType.AUTO_SIGNAL and event.payload:
            ts_code = event.payload.get("ts_code")
            if ts_code:
                normalized = str(ts_code).strip().upper()
                return normalized or None
        return None

    @staticmethod
    def _trade_date(event: AgentEvent) -> date:
        """Return the event's calendar date, falling back to ``date.today()``."""
        # NOTE(review): ``created_at.date()`` uses created_at's own timezone while
        # ``date.today()`` uses the host's local zone; A-share trade dates follow
        # Asia/Shanghai — confirm which the intelligence service expects.
        return event.created_at.date() if event.created_at else date.today()

    async def _collect_market(
        self, snapshot: AccountSnapshot, event: AgentEvent
    ) -> AShareMarketContext:
        """Collect quotes, limit prices and suspension flags for the selected stocks.

        Returns an empty context when no market data service is configured or
        no stock is selected. Failures are recorded per stock in ``errors`` and
        do not abort the remaining stocks.
        """
        if self._market_data_service is None:
            return AShareMarketContext()

        stocks = self._select_stocks(snapshot, event)
        if not stocks:
            return AShareMarketContext()

        now = datetime.now(timezone.utc)
        cached = self._market_cache
        if (
            cached is not None
            and self._market_cache_ttl.total_seconds() > 0
            and (now - cached[0]) < self._market_cache_ttl
            and set(cached[1]) == set(stocks)
        ):
            logger.debug(
                "ashare_context.market.cache_hit",
                age_seconds=(now - cached[0]).total_seconds(),
                stocks_count=len(stocks),
            )
            return cached[2]

        market_context = AShareMarketContext()
        # NOTE(review): these service calls are synchronous and run on the event
        # loop; confirm they are fast enough not to stall other tasks.
        for ts_code in stocks:
            # One try per stock: a failure in an earlier lookup skips the later
            # lookups for that stock only.
            try:
                price = self._market_data_service.get_latest_price(ts_code)
                if price:
                    market_context.quotes[ts_code] = {
                        "price": float(price),
                        "ts_code": ts_code,
                    }

                limit_price = self._market_data_service.get_limit_prices(ts_code)
                if limit_price:
                    limit_data = {
                        "up_limit": float(limit_price.up_limit),
                        "down_limit": float(limit_price.down_limit),
                    }
                    if limit_price.pct_chg is not None:
                        limit_data["pct_chg"] = float(limit_price.pct_chg)
                    market_context.limit_prices[ts_code] = limit_data

                if self._market_data_service.is_suspended(ts_code):
                    market_context.suspended_stocks.append(ts_code)

            except Exception as exc:
                logger.warning(
                    "ashare_context.market.stock_error",
                    ts_code=ts_code,
                    error=str(exc),
                )
                market_context.errors[ts_code] = str(exc)

        self._market_cache = (now, stocks, market_context)
        return market_context

    def _select_stocks(
        self, snapshot: AccountSnapshot, event: AgentEvent
    ) -> List[str]:
        """Choose which stocks to query market data for.

        Order: the AUTO_SIGNAL ``ts_code`` first, then held positions (without
        duplicates). If that yields nothing on a SCHEDULED_TICK, the default
        watchlist is used. The result is truncated to ``max_stocks``.
        """
        stocks: List[str] = []

        signal_code = self._signal_ts_code(event)
        if signal_code:
            stocks.append(signal_code)

        for position in snapshot.positions:
            if position.ts_code not in stocks:
                stocks.append(position.ts_code)

        if not stocks and event.event_type == EventType.SCHEDULED_TICK:
            stocks = list(self._DEFAULT_WATCHLIST[: self._max_stocks])
            logger.info(
                "ashare_context.using_default_watchlist",
                stocks_count=len(stocks),
                stocks=stocks,
                reason="无持仓且定时触发，使用默认关注股票池",
            )

        return stocks[: self._max_stocks]

    def _collect_intelligence(
        self, snapshot: AccountSnapshot, event: AgentEvent
    ) -> tuple[Dict[str, SymbolIntelligence], Optional[MarketIntelligence]]:
        """Collect per-position and market-wide intelligence for the event's trade date.

        Only held positions are queried per symbol; the batch call is skipped
        when there are none.
        """
        if self._intelligence_service is None:
            return {}, None

        trade_date = self._trade_date(event)
        symbols = [pos.ts_code for pos in snapshot.positions]

        intelligence = (
            self._intelligence_service.collect_batch(symbols, trade_date=trade_date)
            if symbols
            else {}
        )
        market_intel = self._intelligence_service.collect_market_intelligence(
            trade_date=trade_date
        )
        return intelligence, market_intel

    def _compute_metrics(self, snapshot: AccountSnapshot) -> Dict[str, Any]:
        """Return snapshot-derived account metrics.

        ``cash_ratio`` is ``available_cash / total_assets``, or ``None`` when
        ``total_assets`` is not positive.
        """
        total_assets = snapshot.total_assets
        metrics: Dict[str, Any] = {
            "total_assets": total_assets,
            "available_cash": snapshot.available_cash,
            "frozen_cash": snapshot.frozen_cash,
            "market_value": snapshot.market_value,
            "positions_count": len(snapshot.positions),
            "unrealized_pnl": snapshot.unrealized_pnl,
            "unrealized_pnl_pct": snapshot.unrealized_pnl_pct,
        }
        metrics["cash_ratio"] = (
            snapshot.available_cash / total_assets if total_assets > 0 else None
        )
        return metrics

    def clear_cache(self) -> None:
        """Drop both the observation cache and the market cache."""
        self._observation_cache = None
        self._market_cache = None
        logger.info("ashare_context.cache_cleared")

    @property
    def memory_system(self) -> TradingMemory:
        """The trading memory instance used for ``memory_context``."""
        return self._memory_system


# Public names exported by this module.
__all__ = ["AShareContextBuilder", "AShareObservation", "AShareMarketContext"]
