"""Tushare 数据客户端。

封装 Tushare Pro API 调用，提供重试、限流、错误处理等机制。
"""

from __future__ import annotations

import time
from datetime import date, datetime
from threading import Lock
from typing import Any, Dict, List, Optional

import pandas as pd
import structlog
import tushare as ts
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from laca.config.settings import get_settings

# Module-level structlog logger; every event in this module is emitted through it.
logger = structlog.get_logger(__name__)


class TushareAPIError(Exception):
    """Error raised when a call to the Tushare API fails.

    Used for failures other than frequency-limit violations, which have their
    own exception type.
    """


class TushareRateLimitError(Exception):
    """Error raised when the Tushare call-frequency limit is exceeded.

    Raised both by the client-side limiter and when the server reports that
    the access quota has been used up.
    """


class TushareClient:
    """Tushare Pro client.

    Exposes A-share quotes, the trading calendar, fundamentals and capital-flow
    endpoints of the Tushare Pro API, with a client-side sliding-window rate
    limiter and automatic retries on failed API calls.

    Attributes:
        token: Tushare API token used by this client.
        pro: Tushare Pro API instance created by ``ts.pro_api``.
        _rate_limit: Maximum number of calls allowed per sliding window.
        _call_times: ``time.monotonic()`` timestamps of the calls inside the
            current window.
        _lock: Lock guarding ``_call_times``.
    """

    # Length of the sliding rate-limit window in seconds ("calls per minute").
    _RATE_WINDOW_SECONDS: float = 60.0

    def __init__(self, token: Optional[str] = None, rate_limit: int = 200):
        """Initialize the Tushare client.

        Args:
            token: Tushare API token; read from the settings
                (``tushare_token``) when not provided.
            rate_limit: Maximum number of API calls per minute. Defaults to
                200; adjust it to match the account's permission tier.

        Raises:
            ValueError: If no token is configured or ``rate_limit`` is not
                positive.
        """
        settings = get_settings()
        self.token = token or settings.tushare_token

        if not self.token:
            raise ValueError(
                "Tushare token 未配置，请在 .env 中设置 TUSHARE_TOKEN"
            )
        if rate_limit <= 0:
            raise ValueError(f"rate_limit 必须为正整数，当前值: {rate_limit}")

        # Register the token with the SDK globally (kept for backward
        # compatibility), but hand it to the Pro API instance explicitly so this
        # client does not depend on that shared state.
        ts.set_token(self.token)
        self.pro = ts.pro_api(self.token)

        self._rate_limit = rate_limit
        self._call_times: List[float] = []
        self._lock = Lock()

        logger.info("tushare_client_initialized", rate_limit=self._rate_limit)

    def _check_rate_limit(self) -> None:
        """Record one API call against the sliding-window rate limit.

        Calls older than the window are discarded first. If the window is
        already full, the call is rejected instead of recorded.

        Raises:
            TushareRateLimitError: When the per-window call limit is reached.
        """
        with self._lock:
            # Monotonic time is immune to system clock adjustments, which would
            # otherwise corrupt the window (or yield negative wait times).
            now = time.monotonic()
            window = self._RATE_WINDOW_SECONDS
            self._call_times = [t for t in self._call_times if now - t < window]

            if len(self._call_times) >= self._rate_limit:
                # The oldest call in the window is the next one to expire.
                wait_time = window - (now - self._call_times[0])
                logger.warning(
                    "rate_limit_exceeded",
                    wait_time=wait_time,
                    calls_in_window=len(self._call_times),
                )
                raise TushareRateLimitError(
                    f"超过频率限制，请等待 {wait_time:.1f} 秒"
                )

            self._call_times.append(now)

    @retry(
        retry=retry_if_exception_type((TushareAPIError, ConnectionError)),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    def _call_api(
        self,
        api_name: str,
        fields: Optional[str] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Generic wrapper around ``self.pro.query``.

        Up to 3 attempts are made on ``TushareAPIError``/``ConnectionError``.
        ``TushareRateLimitError`` is not a subclass of ``TushareAPIError`` and is
        therefore not retried; it propagates immediately.

        Args:
            api_name: Tushare API endpoint name.
            fields: Comma-separated list of fields to return.
            **kwargs: Endpoint parameters.

        Returns:
            The resulting DataFrame, or an empty DataFrame when the API returns
            ``None`` or no rows.

        Raises:
            TushareRateLimitError: Client-side limit hit, or the server reported
                an exceeded quota.
            TushareAPIError: Any other API failure (after retries).
        """
        self._check_rate_limit()

        try:
            logger.debug(
                "calling_tushare_api",
                api_name=api_name,
                params=kwargs,
            )

            df = self.pro.query(api_name, fields=fields, **kwargs)

            if df is None or df.empty:
                logger.warning(
                    "tushare_api_empty_result",
                    api_name=api_name,
                    params=kwargs,
                )
                return pd.DataFrame()

            logger.debug(
                "tushare_api_success",
                api_name=api_name,
                rows=len(df),
            )
            return df

        except Exception as e:
            logger.error(
                "tushare_api_error",
                api_name=api_name,
                error=str(e),
                params=kwargs,
            )
            # The server signals quota exhaustion only through the message text.
            if "每分钟最多访问" in str(e) or "超过限额" in str(e):
                raise TushareRateLimitError(str(e)) from e
            raise TushareAPIError(f"API 调用失败: {e}") from e

    @staticmethod
    def _format_dates(**dates: Optional[date]) -> Dict[str, str]:
        """Return the non-empty date keyword arguments formatted as ``YYYYMMDD``."""
        return {
            name: value.strftime("%Y%m%d") for name, value in dates.items() if value
        }

    def get_trading_calendar(
        self,
        exchange: str = "SSE",
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch the trading calendar.

        Args:
            exchange: Exchange code (SSE Shanghai, SZSE Shenzhen).
            start_date: First calendar date to include.
            end_date: Last calendar date to include.

        Returns:
            DataFrame of the ``trade_cal`` endpoint (columns such as cal_date,
            is_open, pretrade_date).
        """
        params: Dict[str, Any] = {"exchange": exchange}
        params.update(self._format_dates(start_date=start_date, end_date=end_date))

        return self._call_api("trade_cal", **params)

    def get_stock_basic(
        self,
        list_status: str = "L",
        exchange: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch basic stock information.

        Args:
            list_status: Listing status (L listed, D delisted, P suspended).
            exchange: Exchange (SSE Shanghai, SZSE Shenzhen, BSE Beijing).

        Returns:
            DataFrame of the ``stock_basic`` endpoint.
        """
        params: Dict[str, Any] = {"list_status": list_status}
        if exchange:
            params["exchange"] = exchange

        return self._call_api("stock_basic", **params)

    def get_daily_quote(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[date] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch daily bars.

        Args:
            ts_code: Stock code (e.g. 000001.SZ).
            trade_date: Single trading date.
            start_date: Range start date.
            end_date: Range end date.

        Returns:
            DataFrame of the ``daily`` endpoint.
        """
        params: Dict[str, Any] = {}
        if ts_code:
            params["ts_code"] = ts_code
        params.update(
            self._format_dates(
                trade_date=trade_date, start_date=start_date, end_date=end_date
            )
        )

        return self._call_api("daily", **params)

    def get_daily_basic(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[date] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch daily valuation/share-capital indicators (``daily_basic``).

        Args:
            ts_code: Stock code.
            trade_date: Single trading date.
            start_date: Range start date.
            end_date: Range end date.

        Returns:
            DataFrame restricted to the fields listed below.
        """
        params: Dict[str, Any] = {}
        if ts_code:
            params["ts_code"] = ts_code
        params.update(
            self._format_dates(
                trade_date=trade_date, start_date=start_date, end_date=end_date
            )
        )

        fields = (
            "ts_code,trade_date,close,turnover_rate,turnover_rate_f,pe,pe_ttm,pb,ps,ps_ttm,"
            "total_share,float_share,free_share,total_mv,circ_mv"
        )
        return self._call_api("daily_basic", fields=fields, **params)

    def get_realtime_quote(self, ts_code: str) -> pd.DataFrame:
        """Fetch a quote snapshot for one stock (requires elevated permissions).

        Args:
            ts_code: Stock code.

        Returns:
            DataFrame returned by the endpoint queried below.

        Note:
            This endpoint needs elevated Tushare permissions; free accounts may
            not be able to call it.
        """
        # NOTE(review): ``stk_factor`` appears to be Tushare's technical-factor
        # endpoint rather than a realtime snapshot — confirm this is intended.
        return self._call_api("stk_factor", ts_code=ts_code)

    def get_minute_data(
        self,
        ts_code: str,
        freq: str = "5min",
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> pd.DataFrame:
        """Fetch minute bars via ``ts.pro_bar`` (requires elevated permissions).

        Args:
            ts_code: Stock code.
            freq: Bar frequency (1min, 5min, 15min, 30min, 60min).
            start_date: Start timestamp; the time of day is sent along.
            end_date: End timestamp; the time of day is sent along, so a
                midnight value excludes that day's bars.

        Returns:
            DataFrame of minute bars, or an empty DataFrame when ``pro_bar``
            returns ``None``.

        Raises:
            TushareRateLimitError: When the client-side rate limit is reached.
            TushareAPIError: When ``pro_bar`` raises.
        """
        bar_kwargs: Dict[str, Any] = {"ts_code": ts_code, "freq": freq}
        # pro_bar documents "YYYY-MM-DD HH:MM:SS" for minute data; a bare
        # YYYYMMDD string would silently drop the time of day.
        if start_date:
            bar_kwargs["start_date"] = start_date.strftime("%Y-%m-%d %H:%M:%S")
        if end_date:
            bar_kwargs["end_date"] = end_date.strftime("%Y-%m-%d %H:%M:%S")

        # pro_bar does not go through _call_api, so account for the call here.
        self._check_rate_limit()
        try:
            # api=self.pro makes pro_bar use this client's token instead of the
            # SDK's process-global default.
            df = ts.pro_bar(api=self.pro, **bar_kwargs)
        except Exception as e:
            logger.error("get_minute_data_error", error=str(e), ts_code=ts_code)
            raise TushareAPIError(f"获取分钟数据失败: {e}") from e
        return df if df is not None else pd.DataFrame()

    def get_limit_price(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch daily limit-up/limit-down prices (``stk_limit``).

        Args:
            ts_code: Stock code.
            trade_date: Trading date.

        Returns:
            DataFrame of limit prices.
        """
        params: Dict[str, Any] = {}
        if ts_code:
            params["ts_code"] = ts_code
        params.update(self._format_dates(trade_date=trade_date))

        return self._call_api("stk_limit", **params)

    def get_limit_list(
        self,
        trade_date: Optional[date] = None,
        ts_code: Optional[str] = None,
        limit_type: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch the list of stocks that hit a price limit (``limit_list``).

        Args:
            trade_date: Trading date.
            ts_code: Stock code.
            limit_type: Limit type filter passed through to the endpoint.

        Returns:
            DataFrame of limit-hit details.
        """
        params: Dict[str, Any] = self._format_dates(trade_date=trade_date)
        if ts_code:
            params["ts_code"] = ts_code
        if limit_type:
            params["limit_type"] = limit_type

        return self._call_api("limit_list", **params)

    def get_suspend_info(
        self,
        ts_code: Optional[str] = None,
        suspend_date: Optional[date] = None,
        resume_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch trading suspension/resumption records (``suspend_d``).

        Args:
            ts_code: Stock code.
            suspend_date: Suspension date.
            resume_date: Resumption date.

        Returns:
            DataFrame of suspension records.
        """
        params: Dict[str, Any] = {}
        if ts_code:
            params["ts_code"] = ts_code
        # NOTE(review): ``suspend_d`` may expect trade_date/start_date/end_date
        # rather than suspend_date/resume_date — verify against the Tushare docs.
        params.update(
            self._format_dates(suspend_date=suspend_date, resume_date=resume_date)
        )

        return self._call_api("suspend_d", **params)

    def get_stock_industry(
        self,
        ts_code: Optional[str] = None,
        src: str = "SW2021",
    ) -> pd.DataFrame:
        """Fetch stock industry classification.

        Args:
            ts_code: Stock code.
            src: Classification standard (SW2021 Shenwan, CS CSI).

        Returns:
            DataFrame returned by the endpoint queried below.
        """
        params: Dict[str, Any] = {"src": src}
        if ts_code:
            params["ts_code"] = ts_code

        # NOTE(review): ``stock_company`` does not appear to take ``src`` or return
        # industry classification; ``index_classify``/``index_member`` may be the
        # intended endpoints — verify.
        return self._call_api("stock_company", **params)

    def get_moneyflow(
        self,
        ts_code: Optional[str] = None,
        trade_date: Optional[date] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch per-stock capital flow data (``moneyflow``).

        Args:
            ts_code: Stock code.
            trade_date: Single trading date.
            start_date: Range start date.
            end_date: Range end date.

        Returns:
            DataFrame of capital flow records.
        """
        params: Dict[str, Any] = {}
        if ts_code:
            params["ts_code"] = ts_code
        params.update(
            self._format_dates(
                trade_date=trade_date, start_date=start_date, end_date=end_date
            )
        )

        return self._call_api("moneyflow", **params)

    def get_moneyflow_hsgt(
        self,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
    ) -> pd.DataFrame:
        """Fetch Shanghai/Shenzhen-Hong Kong Stock Connect flows (``moneyflow_hsgt``).

        Args:
            start_date: Range start date.
            end_date: Range end date.

        Returns:
            DataFrame of Stock Connect flow records.
        """
        params: Dict[str, Any] = self._format_dates(
            start_date=start_date, end_date=end_date
        )

        return self._call_api("moneyflow_hsgt", **params)

    def get_top_list(
        self,
        trade_date: Optional[date] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch the Dragon-Tiger list of stocks (``top_list``).

        Args:
            trade_date: Single trading date.
            start_date: Range start date.
            end_date: Range end date.
            ts_code: Stock code.

        Returns:
            DataFrame of listed stocks.
        """
        params: Dict[str, Any] = self._format_dates(
            trade_date=trade_date, start_date=start_date, end_date=end_date
        )
        if ts_code:
            params["ts_code"] = ts_code

        return self._call_api("top_list", **params)

    def get_top_institution(
        self,
        trade_date: Optional[date] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch Dragon-Tiger list institutional trade details (``top_inst``).

        Args:
            trade_date: Trading date.
            ts_code: Stock code.

        Returns:
            DataFrame of institutional trade details.
        """
        params: Dict[str, Any] = self._format_dates(trade_date=trade_date)
        if ts_code:
            params["ts_code"] = ts_code

        return self._call_api("top_inst", **params)

    def health_check(self) -> bool:
        """Verify API connectivity and permissions with a trivial query.

        Returns:
            True if today's trading-calendar query succeeds; False if it raises
            (the error is logged, not propagated).
        """
        try:
            df = self.get_trading_calendar(
                start_date=date.today(),
                end_date=date.today(),
            )
            logger.info("tushare_health_check_passed", rows=len(df))
            return True
        except Exception as e:
            # Deliberate best-effort boundary: report failure instead of raising.
            logger.error("tushare_health_check_failed", error=str(e))
            return False


# Public API of this module: the client and the two exception types it raises.
__all__ = ["TushareClient", "TushareAPIError", "TushareRateLimitError"]
