"""行情数据工具。"""

from __future__ import annotations

import asyncio
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

import httpx
import structlog
from tenacity import (
    AsyncRetrying,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_fixed,
)

from lacopro.config.settings import Settings, get_settings
from lacopro.tools.base import AsyncTool

# Structured logger shared by every MarketDataTool instance in this module.
logger = structlog.get_logger("market_data_tool")

# Binance REST endpoint paths. They are relative, so they rely on the HTTP
# client's base_url (the clients created in this module use https://api.binance.com).
KLINES_ENDPOINT = "/api/v3/klines"
TICKER_ENDPOINT = "/api/v3/ticker/24hr"


class MarketDataError(Exception):
    """Raised when market data cannot be retrieved or its payload is malformed."""


@dataclass
class Kline:
    """A single candlestick built from one row of the klines endpoint."""

    # Candle start/end, as produced by MarketDataTool (UTC-aware datetimes).
    open_time: datetime
    close_time: datetime
    # OHLC prices, converted with float().
    open_price: float
    high_price: float
    low_price: float
    close_price: float
    # Traded volume for the candle, converted with float().
    volume: float
    # Untouched source row, stored by the tool as {"binance": row}.
    raw: Dict[str, Any]


@dataclass
class Ticker:
    """Rolling 24-hour statistics for one symbol from the 24hr ticker endpoint."""

    # Upper-cased trading pair, e.g. "BTCUSDT".
    symbol: str
    # Numeric statistics, each converted with float() from the payload.
    price_change_percent: float
    last_price: float
    high_price: float
    low_price: float
    volume: float
    # Derived from the payload's millisecond "closeTime", as a UTC datetime.
    close_time: datetime
    # The complete decoded JSON payload the ticker was parsed from.
    raw: Dict[str, Any]


@dataclass
class CacheStats:
    """Running cache counters for one MarketDataTool instance."""

    # Ticker lookups served from a fresh cache entry / lookups that needed a request.
    ticker_hits: int = 0
    ticker_misses: int = 0
    # Same counters for kline lookups.
    kline_hits: int = 0
    kline_misses: int = 0
    # Times an expired entry was served because refreshing it failed.
    stale_fallbacks: int = 0

    @staticmethod
    def _ratio(hits: int, misses: int) -> float:
        """Return hits / (hits + misses), or 0.0 when there were no lookups."""
        lookups = hits + misses
        if lookups <= 0:
            return 0.0
        return hits / lookups

    @property
    def ticker_hit_rate(self) -> float:
        """Fraction of ticker lookups answered from fresh cache."""
        return self._ratio(self.ticker_hits, self.ticker_misses)

    @property
    def kline_hit_rate(self) -> float:
        """Fraction of kline lookups answered from fresh cache."""
        return self._ratio(self.kline_hits, self.kline_misses)


class MarketDataTool(AsyncTool):
    """Binance market-data tool (tickers, klines, indicators) with LRU caching.

    Ticker and kline responses are cached per instance. Entries younger than
    the fresh TTL are served without a request. When a refresh fails, entries
    younger than the stale TTL may be served instead ("stale fallback").
    Each cache holds at most ``max_cache_size`` entries and evicts the least
    recently used one first.
    """

    def __init__(
        self,
        *,
        settings: Optional[Settings] = None,
        client: Optional[httpx.AsyncClient] = None,
        cache_ttl_seconds: int = 30,
        min_fetch_interval_seconds: int = 2,
        max_cache_size: int = 200,
        stale_cache_ttl_seconds: int = 300,
        enable_stale_fallback: bool = True,
    ) -> None:
        """Create the tool.

        Args:
            settings: Application settings; ``get_settings()`` is used when omitted.
            client: Optional shared async HTTP client. A client passed in is not
                closed by :meth:`close`; one created here is.
            cache_ttl_seconds: Freshness window for cached entries (clamped to >= 1).
            min_fetch_interval_seconds: Minimum spacing that :meth:`fetch_tickers`
                enforces between network batches (clamped to >= 0; 0 disables it).
            max_cache_size: Maximum entries per cache before LRU eviction (>= 1).
            stale_cache_ttl_seconds: Maximum age of an entry served as a fallback
                when a request fails; never shorter than the freshness window.
            enable_stale_fallback: Whether stale entries may be served on failure.
        """
        self.settings = settings or get_settings()
        self._client_provided = client is not None
        self._client = client or httpx.AsyncClient(timeout=10.0, base_url="https://api.binance.com")

        # OrderedDict supplies O(1) move_to_end/popitem, which is all an LRU needs.
        self._ticker_cache: OrderedDict[str, Tuple[Ticker, datetime]] = OrderedDict()
        self._kline_cache: OrderedDict[str, Tuple[List[Kline], datetime]] = OrderedDict()

        self._cache_ttl = timedelta(seconds=max(cache_ttl_seconds, 1))
        # Compare against the *clamped* fresh TTL so the stale window can never be
        # shorter than the fresh one (e.g. cache_ttl_seconds=0 with stale=0).
        self._stale_cache_ttl = max(timedelta(seconds=stale_cache_ttl_seconds), self._cache_ttl)
        self._min_fetch_interval = timedelta(seconds=max(min_fetch_interval_seconds, 0))
        self._max_cache_size = max(max_cache_size, 1)  # keep at least one entry
        self._enable_stale_fallback = enable_stale_fallback
        self._last_bulk_fetch_at: Optional[datetime] = None

        self._stats = CacheStats()

    # ------------------------------------------------------------------ cache helpers

    def _store_in_cache(
        self,
        cache: OrderedDict[str, Tuple[Any, datetime]],
        key: str,
        entry: Tuple[Any, datetime],
    ) -> None:
        """Insert ``entry`` as most recently used, then evict LRU entries over the cap."""
        cache[key] = entry
        cache.move_to_end(key)
        while len(cache) > self._max_cache_size:
            cache.popitem(last=False)

    def _try_stale_fallback(
        self,
        cached: Optional[Tuple[Any, datetime]],
        now: datetime,
        event: str,
        **log_fields: Any,
    ) -> Any:
        """Return the cached value if stale fallback may serve it, else ``None``.

        On success this increments ``stale_fallbacks`` and logs ``event`` at
        warning level with ``age_seconds`` plus ``log_fields``.
        """
        if not self._enable_stale_fallback or cached is None:
            return None
        value, cached_at = cached
        age = now - cached_at
        if age >= self._stale_cache_ttl:
            return None
        self._stats.stale_fallbacks += 1
        logger.warning(event, age_seconds=age.total_seconds(), **log_fields)
        return value

    # ------------------------------------------------------------------ klines

    async def fetch_klines(
        self,
        symbol: str,
        interval: str = "1m",
        limit: int = 100,
    ) -> List[Kline]:
        """Fetch candlesticks for ``symbol``, using the cache when fresh.

        Args:
            symbol: Trading pair, case-insensitive.
            interval: Kline interval string passed to the API (e.g. ``"1m"``).
            limit: Number of klines requested; part of the cache key.

        Returns:
            Klines in API order. A cache hit returns the cached list object itself,
            so callers must not mutate it.

        Raises:
            MarketDataError: If the payload or one of its rows cannot be parsed.
            Exception: The request error, when no usable stale entry exists.
        """
        symbol = symbol.upper()
        cache_key = f"{symbol}_{interval}_{limit}"
        now = datetime.now(timezone.utc)

        cached = self._kline_cache.get(cache_key)
        if cached is not None and now - cached[1] < self._cache_ttl:
            self._stats.kline_hits += 1
            self._kline_cache.move_to_end(cache_key)
            return cached[0]

        self._stats.kline_misses += 1

        params = {"symbol": symbol, "interval": interval, "limit": limit}
        try:
            data = await self._request_json(KLINES_ENDPOINT, params=params)
        except Exception as exc:
            fallback = self._try_stale_fallback(
                cached, now, "market_data.kline.stale_fallback", symbol=symbol, error=str(exc)
            )
            if fallback is None:
                raise
            return fallback

        if not isinstance(data, list):
            raise MarketDataError(f"Invalid kline payload: {data}")
        klines = [self._parse_kline(item) for item in data]

        self._store_in_cache(self._kline_cache, cache_key, (klines, now))
        return klines

    @staticmethod
    def _parse_kline(item: Any) -> Kline:
        """Convert one raw kline row into a :class:`Kline`.

        Indices follow the row layout this module relies on: 0 open time (ms),
        1-4 open/high/low/close, 5 volume, 6 close time (ms).

        Raises:
            MarketDataError: If the row is too short, non-numeric or out of range.
        """
        try:
            return Kline(
                open_time=datetime.fromtimestamp(item[0] / 1000, tz=timezone.utc),
                close_time=datetime.fromtimestamp(item[6] / 1000, tz=timezone.utc),
                open_price=float(item[1]),
                high_price=float(item[2]),
                low_price=float(item[3]),
                close_price=float(item[4]),
                volume=float(item[5]),
                raw={"binance": item},
            )
        # The payload is untrusted: short rows raise IndexError, dict rows KeyError,
        # and absurd timestamps OverflowError/OSError from fromtimestamp().
        except (TypeError, ValueError, IndexError, KeyError, OverflowError, OSError) as exc:
            raise MarketDataError(f"Invalid kline payload: {item}") from exc

    # ------------------------------------------------------------------ tickers

    async def fetch_ticker(self, symbol: str) -> Ticker:
        """Fetch the 24h ticker for ``symbol``, using the cache when fresh.

        A successful network fetch also restarts the throttle window used by
        :meth:`fetch_tickers`.

        Raises:
            MarketDataError: If the payload cannot be parsed.
            Exception: The request error, when no usable stale entry exists.
        """
        symbol = symbol.upper()
        now = datetime.now(timezone.utc)

        cached = self._ticker_cache.get(symbol)
        if cached is not None and now - cached[1] < self._cache_ttl:
            self._stats.ticker_hits += 1
            self._ticker_cache.move_to_end(symbol)
            return cached[0]

        self._stats.ticker_misses += 1

        try:
            data = await self._request_json(TICKER_ENDPOINT, params={"symbol": symbol})
        except Exception as exc:
            fallback = self._try_stale_fallback(
                cached, now, "market_data.ticker.stale_fallback", symbol=symbol, error=str(exc)
            )
            if fallback is None:
                raise
            return fallback

        ticker = self._parse_ticker(symbol, data)
        self._store_in_cache(self._ticker_cache, symbol, (ticker, now))
        self._last_bulk_fetch_at = datetime.now(timezone.utc)
        return ticker

    async def _maybe_throttle(self, now: datetime) -> None:
        """Sleep so that network batches start at least ``min_fetch_interval`` apart."""
        if self._min_fetch_interval.total_seconds() == 0:
            return
        last_fetch = self._last_bulk_fetch_at
        if last_fetch and (now - last_fetch) < self._min_fetch_interval:
            await asyncio.sleep((self._min_fetch_interval - (now - last_fetch)).total_seconds())
        self._last_bulk_fetch_at = datetime.now(timezone.utc)

    async def fetch_tickers(self, symbols: List[str]) -> Dict[str, Ticker]:
        """Fetch 24h tickers for several symbols, tolerating per-symbol failures.

        Fresh cache entries are used directly. The remaining symbols are fetched
        one by one through :meth:`fetch_ticker`, after a single throttle wait.
        When a symbol's fetch fails, its stale cache entry is used if allowed;
        otherwise the symbol is left out of the result. Failures are logged,
        not raised.

        Args:
            symbols: Trading pairs, case-insensitive; duplicates are ignored.

        Returns:
            Mapping of upper-cased symbol to :class:`Ticker` for every symbol served.
        """
        results: Dict[str, Ticker] = {}
        now = datetime.now(timezone.utc)
        to_fetch: List[str] = []
        # dict.fromkeys de-duplicates while keeping the caller's order, so a
        # repeated (possibly failing) symbol does not trigger a second request.
        for symbol in dict.fromkeys(s.upper() for s in symbols):
            cached = self._ticker_cache.get(symbol)
            if cached is not None and now - cached[1] < self._cache_ttl:
                self._stats.ticker_hits += 1
                self._ticker_cache.move_to_end(symbol)
                results[symbol] = cached[0]
            else:
                to_fetch.append(symbol)

        if not to_fetch:
            return results

        await self._maybe_throttle(now)
        for symbol in to_fetch:
            try:
                results[symbol] = await self.fetch_ticker(symbol)
            except Exception as exc:  # noqa: BLE001 - one bad symbol must not fail the batch
                logger.warning("market_data.fetch_tickers.error", symbol=symbol, error=str(exc))
                # fetch_ticker already falls back for request failures; this path also
                # covers failures it does not handle, such as an unparseable payload.
                fallback = self._try_stale_fallback(
                    self._ticker_cache.get(symbol),
                    datetime.now(timezone.utc),
                    "market_data.fetch_tickers.stale_fallback",
                    symbol=symbol,
                )
                if fallback is not None:
                    results[symbol] = fallback
        return results

    # ------------------------------------------------------------------ cache management

    def get_cache_stats(self) -> CacheStats:
        """Return a snapshot copy of the cache counters."""
        return CacheStats(
            ticker_hits=self._stats.ticker_hits,
            ticker_misses=self._stats.ticker_misses,
            kline_hits=self._stats.kline_hits,
            kline_misses=self._stats.kline_misses,
            stale_fallbacks=self._stats.stale_fallbacks,
        )

    def reset_cache_stats(self) -> None:
        """Reset all cache counters to zero."""
        self._stats = CacheStats()

    def clear_cache(self) -> None:
        """Drop every cached ticker and kline entry."""
        self._ticker_cache.clear()
        self._kline_cache.clear()
        logger.info("market_data.cache_cleared")

    # ------------------------------------------------------------------ market scan

    def discover_opportunity_coins(
        self,
        min_volatility_percent: float = 3.0,
        min_volume_usdt: float = 5_000_000.0,
        top_n: int = 15,
    ) -> Dict[str, Any]:
        """Scan every USDT pair's 24h ticker and group market opportunities.

        This method is synchronous: it makes a blocking request through a
        short-lived ``httpx.Client`` (10 s timeout) rather than the async client.
        Async callers should run it off the event loop (e.g. ``asyncio.to_thread``)
        so they do not stall other tasks.

        Args:
            min_volatility_percent: Minimum absolute 24h change (%) for the rising
                and falling lists. Defaults to 3%.
            min_volume_usdt: Minimum 24h quote volume (USDT) for those lists.
                Defaults to 5 million.
            top_n: Maximum entries returned in each of the rising/falling lists.

        Returns:
            Dict with "主流币种" (mainstream pairs that moved at least 1%),
            "上涨机会" (rising), "下跌警示" (falling), "市场趋势" (breadth
            statistics over all parsed USDT pairs) and "分析建议" (advice strings).

        Raises:
            MarketDataError: If the request fails or the response cannot be processed.
        """
        logger.info(
            "market_data.discover_opportunity_coins",
            min_volatility=min_volatility_percent,
            min_volume=min_volume_usdt,
            top_n=top_n,
        )

        try:
            # A synchronous client avoids event-loop conflicts for callers that
            # are not running inside the tool's loop.
            with httpx.Client(timeout=10.0, base_url="https://api.binance.com") as sync_client:
                response = sync_client.get(TICKER_ENDPOINT)  # all symbols' 24h tickers
                response.raise_for_status()
                all_tickers = response.json()

            excluded_stablecoins = {"USDCUSDT", "FDUSDUSDT", "TUSDUSDT"}
            mainstream_symbols = {"BTCUSDT", "ETHUSDT", "BNBUSDT", "SOLUSDT"}

            # Mainstream entries are paired with their numeric change for sorting.
            mainstream: List[Tuple[float, Dict[str, Any]]] = []
            rising_opportunities: List[Dict[str, Any]] = []
            falling_warnings: List[Dict[str, Any]] = []

            total_count = 0
            rising_count = 0
            falling_count = 0
            total_change = 0.0

            for ticker_data in all_tickers:
                # Filtering and parsing share one try, so a single malformed entry
                # is skipped instead of aborting the whole scan.
                try:
                    symbol = ticker_data["symbol"]
                    if not symbol.endswith("USDT") or symbol in excluded_stablecoins:
                        continue
                    price_change_pct = float(ticker_data["priceChangePercent"])
                    volume_usdt = float(ticker_data["quoteVolume"])
                    last_price = float(ticker_data["lastPrice"])
                except (KeyError, TypeError, ValueError, AttributeError) as e:
                    logger.warning(
                        "market_data.parse_opportunity_failed",
                        symbol=ticker_data.get("symbol") if isinstance(ticker_data, dict) else None,
                        error=str(e),
                    )
                    continue

                total_count += 1
                total_change += price_change_pct
                # Unchanged pairs count toward the total but in neither direction.
                if price_change_pct > 0:
                    rising_count += 1
                elif price_change_pct < 0:
                    falling_count += 1

                # 1. Mainstream coins use a lower bar: any move of at least 1%.
                if symbol in mainstream_symbols and abs(price_change_pct) >= 1.0:
                    mainstream.append((price_change_pct, {
                        "symbol": symbol,
                        "24h涨跌": f"{price_change_pct:+.2f}%",
                        "当前价格": last_price,
                        "成交量": self._format_volume(volume_usdt),
                        "趋势": "上涨" if price_change_pct > 0 else "下跌",
                    }))

                # Opportunity/warning lists only consider sufficiently liquid pairs.
                if volume_usdt < min_volume_usdt:
                    continue

                # 2. Rising opportunities: up by at least the threshold.
                if price_change_pct >= min_volatility_percent:
                    rising_opportunities.append({
                        "symbol": symbol,
                        # ":+" prints the sign itself, so a non-positive change
                        # (possible with threshold <= 0) never renders as "+-x%".
                        "24h涨跌": f"{price_change_pct:+.2f}%",
                        "当前价格": last_price,
                        "成交量": self._format_volume(volume_usdt),
                        "波动率": abs(price_change_pct),
                    })

                # 3. Falling warnings: down by at least the threshold.
                if price_change_pct <= -min_volatility_percent:
                    falling_warnings.append({
                        "symbol": symbol,
                        "24h涨跌": f"{price_change_pct:.2f}%",
                        "当前价格": last_price,
                        "成交量": self._format_volume(volume_usdt),
                        "波动率": abs(price_change_pct),
                    })

            rising_opportunities.sort(key=lambda x: x["波动率"], reverse=True)
            falling_warnings.sort(key=lambda x: x["波动率"], reverse=True)
            # Sort on the numeric change rather than re-parsing the formatted label.
            mainstream.sort(key=lambda pair: pair[0], reverse=True)
            mainstream_coins = [entry for _, entry in mainstream]

            # Overall trend: the mean 24h change beyond +/-1% decides the direction.
            avg_change = total_change / total_count if total_count > 0 else 0.0
            if avg_change > 1.0:
                market_trend = "偏多（整体上涨）"
            elif avg_change < -1.0:
                market_trend = "偏空（整体下跌）"
            else:
                market_trend = "震荡（涨跌参半）"

            result = {
                "主流币种": mainstream_coins,
                "上涨机会": rising_opportunities[:top_n],
                "下跌警示": falling_warnings[:top_n],
                "市场趋势": {
                    "上涨币种数": rising_count,
                    "下跌币种数": falling_count,
                    "总币种数": total_count,
                    "平均涨跌": f"{avg_change:+.2f}%",
                    "整体趋势": market_trend,
                    "涨跌比": f"{rising_count}:{falling_count}",
                },
                "分析建议": self._generate_market_advice(
                    len(mainstream_coins),
                    len(rising_opportunities),
                    len(falling_warnings),
                    market_trend,
                ),
            }

            logger.info(
                "market_data.opportunity_coins_found",
                mainstream=len(mainstream_coins),
                rising=len(rising_opportunities),
                falling=len(falling_warnings),
                market_trend=market_trend,
            )

            return result

        except Exception as e:
            logger.exception("market_data.discover_opportunity_failed", error=str(e))
            raise MarketDataError(f"Failed to discover opportunity coins: {e}") from e

    def _format_volume(self, volume_usdt: float) -> str:
        """Format a volume compactly: billions "B" (2 dp), millions "M" / else "K" (1 dp)."""
        if volume_usdt >= 1_000_000_000:
            return f"{volume_usdt / 1_000_000_000:.2f}B"
        if volume_usdt >= 1_000_000:
            return f"{volume_usdt / 1_000_000:.1f}M"
        return f"{volume_usdt / 1_000:.1f}K"

    def _generate_market_advice(
        self,
        mainstream_count: int,
        rising_count: int,
        falling_count: int,
        market_trend: str,
    ) -> List[str]:
        """Build human-readable advice lines from the scan counts and trend label.

        ``market_trend`` is matched by substring ("偏空" / "偏多"), so it must be one
        of the labels produced by :meth:`discover_opportunity_coins`.
        """
        advice: List[str] = []

        if mainstream_count == 0:
            advice.append("⚠️ 主流币种波动率不足1%，市场可能缺乏方向")
        else:
            advice.append(f"✅ 发现{mainstream_count}个主流币种活跃，可作为稳健配置参考")

        if rising_count > 10:
            advice.append(f"📈 发现{rising_count}个上涨机会，市场情绪偏多")
        elif rising_count > 0:
            advice.append(f"💡 有{rising_count}个币种上涨，但整体机会有限")
        else:
            advice.append("⚠️ 没有符合条件的上涨机会，建议谨慎")

        if falling_count > 10:
            advice.append(f"⚠️ {falling_count}个币种大幅下跌，市场风险较高")

        if "偏空" in market_trend:
            advice.append("🛡️ 市场整体下跌，建议控制仓位或等待企稳信号")
        elif "偏多" in market_trend:
            advice.append("🚀 市场整体上涨，可适当参与强势品种")
        else:
            advice.append("⚖️ 市场震荡，建议精选个股而非追涨杀跌")

        return advice

    def _parse_ticker(self, symbol: str, data: Dict[str, Any]) -> Ticker:
        """Convert a decoded 24h-ticker payload into a :class:`Ticker`.

        Raises:
            MarketDataError: If a field is missing, non-numeric or out of range,
                or ``data`` is not a mapping.
        """
        try:
            close_time = datetime.fromtimestamp(int(data["closeTime"]) / 1000, tz=timezone.utc)
            return Ticker(
                symbol=symbol,
                price_change_percent=float(data["priceChangePercent"]),
                last_price=float(data["lastPrice"]),
                high_price=float(data["highPrice"]),
                low_price=float(data["lowPrice"]),
                volume=float(data["volume"]),
                close_time=close_time,
                raw=data,
            )
        # TypeError covers None values and non-dict payloads; OverflowError/OSError
        # come from fromtimestamp() on absurd timestamps.
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as exc:
            raise MarketDataError(f"Invalid ticker payload: {data}") from exc

    @staticmethod
    def _is_retryable_error(exc: BaseException) -> bool:
        """Return True for failures worth retrying: transport errors, HTTP 429 and 5xx.

        Other HTTP status errors (e.g. 400 for an unknown symbol) are permanent for
        the same request, so retrying them only adds latency and load.
        """
        if isinstance(exc, httpx.HTTPStatusError):
            status = exc.response.status_code
            return status == 429 or status >= 500
        return isinstance(exc, httpx.HTTPError)

    async def _request_json(self, endpoint: str, *, params: Optional[Dict[str, Any]] = None) -> Any:
        """GET ``endpoint`` and return the decoded JSON body.

        Makes up to 3 attempts, 1 s apart, for retryable errors
        (see :meth:`_is_retryable_error`).

        Raises:
            httpx.HTTPError: A non-retryable error immediately, or the last error
                once attempts are exhausted.
            ValueError: If the body is not valid JSON (not retried).
        """
        async for attempt in AsyncRetrying(
            reraise=True,
            stop=stop_after_attempt(3),
            wait=wait_fixed(1),
            retry=retry_if_exception(self._is_retryable_error),
        ):
            with attempt:
                response = await self._client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()

    # ------------------------------------------------------------------ indicators

    async def get_ma(self, symbol: str, period: int = 20, interval: str = "1m") -> float:
        """Compute the simple moving average of close prices.

        NOTE(review): the most recent kline returned by the exchange may still be
        open; confirm whether it should be part of the average.

        Args:
            symbol: Trading pair.
            period: Number of klines to average (>= 1).
            interval: Kline interval.

        Returns:
            Mean close price of the last ``period`` klines.

        Raises:
            ValueError: If ``period`` < 1.
            MarketDataError: If fewer than ``period`` klines are returned.
        """
        if period < 1:
            raise ValueError(f"period must be >= 1, got {period}")

        klines = await self.fetch_klines(symbol, interval=interval, limit=period)
        if len(klines) < period:
            raise MarketDataError(f"K线数量不足，需要{period}根，实际{len(klines)}根")

        close_prices = [k.close_price for k in klines[-period:]]
        ma = sum(close_prices) / period

        logger.debug(
            "market_data.ma_calculated",
            symbol=symbol,
            period=period,
            ma=ma,
            klines_count=len(klines),
        )

        return ma

    async def get_rsi(self, symbol: str, period: int = 14, interval: str = "1m") -> float:
        """Compute the Relative Strength Index over ``period`` close-to-close changes.

        Gains and losses are averaged with simple arithmetic means over the window
        (no Wilder smoothing), so values can differ from tools using Wilder's RSI.

        Args:
            symbol: Trading pair.
            period: Number of price changes in the window (>= 1).
            interval: Kline interval.

        Returns:
            RSI in [0, 100]: 100 if the window has gains but no losses, 0 if it has
            losses but no gains, and a neutral 50 if prices did not move at all.

        Raises:
            ValueError: If ``period`` < 1.
            MarketDataError: If fewer than ``period + 1`` klines are returned.
        """
        if period < 1:
            raise ValueError(f"period must be >= 1, got {period}")

        # period price changes require period + 1 closes.
        klines = await self.fetch_klines(symbol, interval=interval, limit=period + 1)
        if len(klines) < period + 1:
            raise MarketDataError(f"K线数量不足，需要{period+1}根，实际{len(klines)}根")

        close_prices = [k.close_price for k in klines[-(period + 1):]]
        price_changes = [curr - prev for prev, curr in zip(close_prices, close_prices[1:])]

        avg_gain = sum(change for change in price_changes if change > 0) / period
        avg_loss = sum(-change for change in price_changes if change < 0) / period

        if avg_loss == 0:
            # No losses: all-gain windows are maximally strong, but a perfectly
            # flat window carries no signal and must not read as "overbought".
            rsi = 50.0 if avg_gain == 0 else 100.0
        else:
            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))

        logger.debug(
            "market_data.rsi_calculated",
            symbol=symbol,
            period=period,
            rsi=rsi,
            avg_gain=avg_gain,
            avg_loss=avg_loss,
        )

        return rsi

    async def close(self) -> None:
        """Close the HTTP client if this instance created it."""
        if not self._client_provided:
            await self._client.aclose()


# Names exported by ``from <module> import *``.
__all__ = ["MarketDataTool", "MarketDataError", "Kline", "Ticker", "CacheStats"]
