"""A 股工具注册中心 - 注册可供 LLM 调用的 A 股交易工具。

这个模块提供 LLM 可用的 A 股工具函数：
1. 市场数据查询工具（Tushare）
2. 投资组合查询工具（模拟券商）
3. 风险评估工具（A 股规则）
4. 交易时间判断工具
"""

from datetime import date
from typing import Any, Dict, List, Optional

import structlog

from .function_calling import auto_register_tool, global_registry

logger = structlog.get_logger("ashare_tools_registry")

# Backing objects for the tool functions below. They start unset and are
# injected by the Orchestrator through set_ashare_tool_instances(); each tool
# raises RuntimeError while the instance it depends on is still unset.
_portfolio_tool = _market_data_service = _trade_calendar = None


def set_ashare_tool_instances(
    portfolio_tool=None, market_data_service=None, trade_calendar=None
):
    """Inject the objects the A-share tool functions delegate to.

    Intended to be called by the Orchestrator during initialisation. Every
    module-level instance is overwritten, so an omitted argument resets the
    corresponding instance back to ``None``.

    Args:
        portfolio_tool: Object backing the portfolio / position / risk tools.
        market_data_service: Object backing the market-data tools.
        trade_calendar: Object backing ``check_trading_time``.
    """
    global _portfolio_tool, _market_data_service, _trade_calendar
    _portfolio_tool, _market_data_service, _trade_calendar = (
        portfolio_tool,
        market_data_service,
        trade_calendar,
    )

    # Log only which instances are present, never the objects themselves.
    logger.info(
        "ashare_tools_instances.set",
        has_portfolio=_portfolio_tool is not None,
        has_market_data=_market_data_service is not None,
        has_calendar=_trade_calendar is not None,
    )


@auto_register_tool(
    name="get_stock_info",
    description="获取 A 股股票的基本信息（名称、行业、上市日期）",
    category="market_data",
    risk_level="low",
)
def get_stock_info(ts_code: str) -> Dict[str, Any]:
    """Look up basic metadata for a single A-share stock.

    Args:
        ts_code: Stock code, e.g. ``000001.SZ``.

    Returns:
        A dict with ``ts_code``, ``symbol``, ``name``, ``industry``,
        ``market`` and ``list_date``. ``list_date`` is the ISO string of the
        service's value, or None when that value is falsy. When the service
        returns nothing for the code, a dict with a single ``error`` key is
        returned instead.

    Raises:
        RuntimeError: if the market data service has not been injected, or
            the lookup fails (the original exception is chained).
    """
    logger.info("tool.get_stock_info", ts_code=ts_code)

    if not _market_data_service:
        raise RuntimeError("MarketDataService 未初始化")

    try:
        info = _market_data_service.get_stock_info(ts_code)
        if not info:
            return {"error": f"未找到股票 {ts_code}"}

        # NOTE(review): assumes list_date is a date-like object with
        # .isoformat(); confirm against the service's return type.
        listed_on = info.list_date
        return {
            "ts_code": info.ts_code,
            "symbol": info.symbol,
            "name": info.name,
            "industry": info.industry,
            "market": info.market,
            "list_date": listed_on.isoformat() if listed_on else None,
        }
    except Exception as exc:
        logger.error("tool.get_stock_info.error", ts_code=ts_code, error=str(exc))
        raise RuntimeError(f"获取股票信息失败: {str(exc)}") from exc


@auto_register_tool(
    name="get_current_price",
    description="获取 A 股的最新价格（实时行情）",
    category="market_data",
    risk_level="low",
)
def get_current_price(ts_code: str, trade_date: Optional[str] = None) -> Dict[str, Any]:
    """Return a stock's latest price, or its close on ``trade_date``.

    Args:
        ts_code: Stock code, e.g. ``000001.SZ``.
        trade_date: Trading date as ``YYYYMMDD``. When omitted or empty, the
            service's latest price is returned. When given, the close of that
            day's daily quote is returned.

    Returns:
        ``{"ts_code", "price", "trade_date"}`` on success. ``trade_date`` is
        the quote's own trade date when a date was requested, otherwise the
        placeholder ``"最近交易日"``. When no price data is available, a dict
        with a single ``error`` key is returned.

    Raises:
        RuntimeError: if the market data service is not initialised, or the
            lookup fails (the original exception is chained).
    """
    logger.info("tool.get_current_price", ts_code=ts_code, trade_date=trade_date)

    if not _market_data_service:
        raise RuntimeError("MarketDataService 未初始化")

    try:
        if trade_date:
            # trade_date used to be ignored: the latest price was returned but
            # labelled with the requested date. Read that day's daily bar so
            # the reported price really belongs to the reported date.
            quotes = _market_data_service.get_daily_quotes(
                ts_code, trade_date, trade_date, limit=1
            )
            quote = next(iter(quotes), None) if quotes else None
            if quote is None:
                return {"error": f"未获取到 {ts_code} 的价格数据"}
            return {
                "ts_code": ts_code,
                "price": float(quote.close),
                "trade_date": quote.trade_date,
            }

        price = _market_data_service.get_latest_price(ts_code)
        if price is None:
            return {"error": f"未获取到 {ts_code} 的价格数据"}

        return {
            "ts_code": ts_code,
            "price": float(price),
            "trade_date": "最近交易日",
        }
    except Exception as e:
        logger.error("tool.get_current_price.error", ts_code=ts_code, error=str(e))
        raise RuntimeError(f"获取价格失败: {str(e)}") from e


@auto_register_tool(
    name="get_limit_prices",
    description="获取 A 股的涨跌停价格",
    category="market_data",
    risk_level="low",
)
def get_limit_prices(ts_code: str, trade_date: Optional[str] = None) -> Dict[str, Any]:
    """Fetch the daily upper/lower price limits for a stock.

    Args:
        ts_code: Stock code.
        trade_date: Trading date as ``YYYYMMDD``; passed straight through to
            the market data service.

    Returns:
        A dict with ``ts_code``, ``trade_date``, ``up_limit`` and
        ``down_limit``, plus ``pct_chg`` when the service supplies one. When
        the service returns nothing, a dict with a single ``error`` key.

    Raises:
        RuntimeError: if the market data service is not initialised, or the
            lookup fails (the original exception is chained).
    """
    logger.info("tool.get_limit_prices", ts_code=ts_code, trade_date=trade_date)

    if not _market_data_service:
        raise RuntimeError("MarketDataService 未初始化")

    try:
        limits = _market_data_service.get_limit_prices(ts_code, trade_date)
        if not limits:
            return {"error": f"未获取到 {ts_code} 的涨跌停价格"}

        payload: Dict[str, Any] = {
            "ts_code": ts_code,
            "trade_date": limits.trade_date,
            "up_limit": float(limits.up_limit),
            "down_limit": float(limits.down_limit),
        }

        # pct_chg is optional; only include it when the service provides it.
        change = limits.pct_chg
        if change is not None:
            payload["pct_chg"] = float(change)

        return payload
    except Exception as exc:
        logger.error("tool.get_limit_prices.error", ts_code=ts_code, error=str(exc))
        raise RuntimeError(f"获取涨跌停价格失败: {str(exc)}") from exc


@auto_register_tool(
    name="is_suspended",
    description="检查 A 股是否停牌",
    category="market_data",
    risk_level="low",
)
def is_suspended(ts_code: str, check_date: Optional[str] = None) -> Dict[str, Any]:
    """Report whether a stock is suspended from trading.

    Args:
        ts_code: Stock code.
        check_date: Date as ``YYYYMMDD``; forwarded to the market data
            service unchanged.

    Returns:
        A dict with ``ts_code``, the service's ``is_suspended`` answer and
        ``check_date``. ``check_date`` is the placeholder ``"今日"`` when no
        date was given.

    Raises:
        RuntimeError: if the market data service is not initialised, or the
            lookup fails (the original exception is chained).
    """
    logger.info("tool.is_suspended", ts_code=ts_code, check_date=check_date)

    if not _market_data_service:
        raise RuntimeError("MarketDataService 未初始化")

    try:
        halted = _market_data_service.is_suspended(ts_code, check_date)
    except Exception as exc:
        logger.error("tool.is_suspended.error", ts_code=ts_code, error=str(exc))
        raise RuntimeError(f"检查停牌状态失败: {str(exc)}") from exc

    return {
        "ts_code": ts_code,
        "is_suspended": halted,
        "check_date": check_date or "今日",
    }


@auto_register_tool(
    name="get_portfolio_summary",
    description="获取当前 A 股投资组合摘要（总资产、现金、持仓）",
    category="portfolio",
    risk_level="low",
)
def get_portfolio_summary() -> Dict[str, Any]:
    """Summarise the current portfolio snapshot.

    Returns:
        A dict with the account totals as floats: ``total_assets``,
        ``available_cash``, ``frozen_cash``, ``market_value``,
        ``cost_value``, ``unrealized_pnl`` and ``unrealized_pnl_pct``. It
        also holds ``positions`` (one dict per snapshot position) and
        ``positions_count``.

    Raises:
        RuntimeError: if the portfolio tool is not initialised, or the
            snapshot cannot be fetched/converted (the original exception is
            chained).
    """
    logger.info("tool.get_portfolio_summary")

    if not _portfolio_tool:
        raise RuntimeError("ASharePortfolioTool 未初始化")

    try:
        snap = _portfolio_tool.fetch_snapshot()

        # Monetary fields are float()-ed so the result holds plain numbers.
        holdings = [
            {
                "ts_code": p.ts_code,
                "total_quantity": p.total_quantity,
                "available_quantity": p.available_quantity,
                "cost_price": float(p.cost_price),
                "market_value": float(p.market_value),
                "unrealized_pnl": float(p.unrealized_pnl),
                "unrealized_pnl_pct": float(p.unrealized_pnl_pct),
            }
            for p in snap.positions
        ]

        return {
            "total_assets": float(snap.total_assets),
            "available_cash": float(snap.available_cash),
            "frozen_cash": float(snap.frozen_cash),
            "market_value": float(snap.market_value),
            "cost_value": float(snap.cost_value),
            "unrealized_pnl": float(snap.unrealized_pnl),
            "unrealized_pnl_pct": float(snap.unrealized_pnl_pct),
            "positions": holdings,
            "positions_count": len(holdings),
        }
    except Exception as exc:
        logger.error("tool.get_portfolio_summary.error", error=str(exc))
        raise RuntimeError(f"获取投资组合数据失败: {str(exc)}") from exc


@auto_register_tool(
    name="calculate_position_size",
    description="智能计算 A 股建议仓位大小（基于风险等级和置信度）",
    category="portfolio",
    risk_level="low",
)
def calculate_position_size(
    ts_code: str,
    price: float,
    risk_level: str,
    confidence: float = 0.7,
) -> Dict[str, Any]:
    """Ask the portfolio tool for a suggested position size.

    Args:
        ts_code: Stock code.
        price: Current price used for sizing.
        risk_level: Risk level, documented as LOW / MEDIUM / HIGH.
        confidence: Confidence in the signal, documented as 0-1.

    Returns:
        A dict that reshapes the portfolio tool's sizing result.
        ``suggested_amount`` and ``position_ratio`` are converted to float;
        ``suggested_quantity``, ``max_positions_reached`` and ``reason`` are
        passed through as-is.

    Raises:
        RuntimeError: if the portfolio tool is not initialised, or the
            calculation fails (the original exception is chained).
    """
    logger.info(
        "tool.calculate_position_size",
        ts_code=ts_code,
        price=price,
        risk_level=risk_level,
        confidence=confidence,
    )

    if not _portfolio_tool:
        raise RuntimeError("ASharePortfolioTool 未初始化")

    try:
        sizing = _portfolio_tool.calculate_position_size(
            ts_code=ts_code,
            price=price,
            risk_level=risk_level,
            confidence=confidence,
        )

        suggestion: Dict[str, Any] = {"ts_code": ts_code}
        suggestion["suggested_quantity"] = sizing["suggested_quantity"]
        suggestion["suggested_amount"] = float(sizing["suggested_amount"])
        suggestion["position_ratio"] = float(sizing["position_ratio"])
        suggestion["max_positions_reached"] = sizing["max_positions_reached"]
        suggestion["reason"] = sizing["reason"]
        return suggestion
    except Exception as exc:
        logger.error("tool.calculate_position_size.error", ts_code=ts_code, error=str(exc))
        raise RuntimeError(f"计算仓位失败: {str(exc)}") from exc


@auto_register_tool(
    name="evaluate_risk_level",
    description="评估当前 A 股投资组合的风险级别",
    category="risk",
    risk_level="low",
)
def evaluate_risk_level() -> Dict[str, Any]:
    """Assess the overall risk level of the current A-share portfolio.

    The level starts at ``"LOW"`` and is escalated by three checks:

    * available cash below 10% of total assets -> at least ``"MEDIUM"``;
    * ``get_concentration_risk()`` reports ``is_high_risk`` -> ``"HIGH"``;
    * 3 or more positions in the snapshot -> at least ``"MEDIUM"``.

    Returns:
        A dict with:

        * ``risk_level``: ``"LOW"``, ``"MEDIUM"`` or ``"HIGH"``.
        * ``cash_ratio``: float rounded to 3 dp; 1.0 when total assets are
          not positive.
        * ``position_count``.
        * ``concentration_ratio``: float rounded to 3 dp.
        * ``factors``: human-readable reasons, or ``["无明显风险"]``.
        * ``recommendation``.

    Raises:
        RuntimeError: if the portfolio tool is not initialised, or the
            snapshot / concentration lookup fails (the original exception is
            chained).
    """
    logger.info("tool.evaluate_risk_level")

    if not _portfolio_tool:
        raise RuntimeError("ASharePortfolioTool 未初始化")

    min_cash_ratio = 0.1  # cash share of total assets below which cash is "insufficient"
    max_positions = 3  # position count treated as the cap
    recommendations = {
        "LOW": "可适度建仓",
        "MEDIUM": "谨慎操作",
        "HIGH": "建议降低仓位",
    }

    try:
        snapshot = _portfolio_tool.fetch_snapshot()

        # The snapshot amounts may be Decimal (the other tools float() them).
        # Convert so cash_ratio is always a float: before, it was a Decimal
        # when total_assets > 0 but the float 1.0 otherwise, and round() kept
        # the Decimal in the returned dict.
        cash_ratio = (
            float(snapshot.available_cash / snapshot.total_assets)
            if snapshot.total_assets > 0
            else 1.0
        )
        position_count = len(snapshot.positions)

        # Concentration risk as computed by the portfolio tool.
        concentration = _portfolio_tool.get_concentration_risk()
        concentration_ratio = float(concentration["concentration_ratio"])

        risk_factors: List[str] = []
        overall_risk = "LOW"

        if cash_ratio < min_cash_ratio:
            risk_factors.append(f"现金余额不足 {min_cash_ratio:.0%}")
            overall_risk = "MEDIUM"

        if concentration["is_high_risk"]:
            risk_factors.append(
                f"持仓集中度过高：前 3 大占比 {concentration_ratio * 100:.1f}%"
            )
            overall_risk = "HIGH"

        if position_count >= max_positions:
            risk_factors.append(f"持仓数量已达上限（{max_positions} 个）")
            # Escalate only; never downgrade an existing HIGH.
            if overall_risk == "LOW":
                overall_risk = "MEDIUM"

        return {
            "risk_level": overall_risk,
            "cash_ratio": round(cash_ratio, 3),
            "position_count": position_count,
            "concentration_ratio": round(concentration_ratio, 3),
            "factors": risk_factors or ["无明显风险"],
            "recommendation": recommendations[overall_risk],
        }
    except Exception as e:
        logger.error("tool.evaluate_risk_level.error", error=str(e))
        raise RuntimeError(f"风险评估失败: {str(e)}") from e


@auto_register_tool(
    name="check_trading_time",
    description="检查当前是否为 A 股交易时间",
    category="market_data",
    risk_level="low",
)
def check_trading_time(check_date: Optional[str] = None) -> Dict[str, Any]:
    """Report the trading-calendar status of a date.

    Args:
        check_date: Date as ``YYYYMMDD``. When omitted, ``date.today()`` is
            used. NOTE(review): that is the host's local date, which may
            differ from the exchange's (Asia/Shanghai) date on a non-CST
            host; confirm the deployment timezone.

    Returns:
        A dict with ``date`` (ISO string), ``is_trading_day``,
        ``trading_session`` (the session's ``.value``, or ``"CLOSED"`` when
        the calendar returns none), and ``previous_trading_day`` /
        ``next_trading_day`` (ISO strings or None). Whether the session
        reflects the current time of day depends on
        ``TradeCalendar.get_trading_session``, which only receives a date
        here.

    Raises:
        RuntimeError: if the trade calendar is not initialised, ``check_date``
            does not match ``YYYYMMDD``, or any calendar call fails (the
            original exception is chained).
    """
    logger.info("tool.check_trading_time", check_date=check_date)

    if not _trade_calendar:
        raise RuntimeError("TradeCalendar 未初始化")

    try:
        if check_date:
            from datetime import datetime

            day = datetime.strptime(check_date, "%Y%m%d").date()
        else:
            day = date.today()

        trading_day = _trade_calendar.is_trading_day(day)
        session = _trade_calendar.get_trading_session(day)
        before = _trade_calendar.get_previous_trading_day(day)
        after = _trade_calendar.get_next_trading_day(day)

        status: Dict[str, Any] = {
            "date": day.isoformat(),
            "is_trading_day": trading_day,
            "trading_session": session.value if session else "CLOSED",
        }
        status["previous_trading_day"] = before.isoformat() if before else None
        status["next_trading_day"] = after.isoformat() if after else None
        return status
    except Exception as exc:
        logger.error("tool.check_trading_time.error", check_date=check_date, error=str(exc))
        raise RuntimeError(f"检查交易时间失败: {str(exc)}") from exc


@auto_register_tool(
    name="get_t1_sellable_quantity",
    description="检查 A 股持仓的 T+1 可卖数量（当日买入次日可卖）",
    category="portfolio",
    risk_level="low",
)
def get_t1_sellable_quantity(
    ts_code: str, check_date: Optional[str] = None
) -> Dict[str, Any]:
    """Report how many shares of a holding can be sold under the T+1 rule.

    Args:
        ts_code: Stock code.
        check_date: Date as ``YYYYMMDD``; defaults to ``date.today()``, the
            host's local date.

    Returns:
        A dict with ``ts_code``, ``check_date`` (ISO string),
        ``sellable_quantity`` (as returned by the portfolio tool),
        ``t1_restriction`` and a fixed explanatory ``note``.
        ``t1_restriction`` is True whenever the sellable quantity is 0.
        NOTE(review): that includes the case where nothing is held at all.

    Raises:
        RuntimeError: if the portfolio tool is not initialised, ``check_date``
            does not match ``YYYYMMDD``, or the lookup fails (the original
            exception is chained).
    """
    logger.info("tool.get_t1_sellable_quantity", ts_code=ts_code, check_date=check_date)

    if not _portfolio_tool:
        raise RuntimeError("ASharePortfolioTool 未初始化")

    try:
        if check_date:
            from datetime import datetime

            day = datetime.strptime(check_date, "%Y%m%d").date()
        else:
            day = date.today()

        qty = _portfolio_tool.check_sellable_quantity(ts_code, day)
        blocked = qty == 0

        return {
            "ts_code": ts_code,
            "check_date": day.isoformat(),
            "sellable_quantity": qty,
            "t1_restriction": blocked,
            "note": "当日买入的股票次日才能卖出（T+1 规则）",
        }
    except Exception as exc:
        logger.error(
            "tool.get_t1_sellable_quantity.error", ts_code=ts_code, error=str(exc)
        )
        raise RuntimeError(f"检查 T+1 可卖数量失败: {str(exc)}") from exc


@auto_register_tool(
    name="get_daily_quote",
    description="获取 A 股的日线行情数据（开高低收、成交量、涨跌幅）",
    category="market_data",
    risk_level="low",
)
def get_daily_quote(
    ts_code: str, start_date: Optional[str] = None, end_date: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch daily OHLCV bars for a stock.

    Args:
        ts_code: Stock code.
        start_date: Start date as ``YYYYMMDD``; forwarded to the service.
        end_date: End date as ``YYYYMMDD``; forwarded to the service.

    Returns:
        A dict with ``ts_code``, ``quotes`` (one dict per bar, numeric fields
        as floats) and ``count``. ``limit=5`` is always passed to the
        service, presumably capping the result at five bars; confirm the
        service's semantics. When no bars come back, a dict with a single
        ``error`` key is returned.

    Raises:
        RuntimeError: if the market data service is not initialised, or the
            lookup/conversion fails (the original exception is chained).
    """
    logger.info(
        "tool.get_daily_quote",
        ts_code=ts_code,
        start_date=start_date,
        end_date=end_date,
    )

    if not _market_data_service:
        raise RuntimeError("MarketDataService 未初始化")

    try:
        bars = _market_data_service.get_daily_quotes(
            ts_code, start_date, end_date, limit=5
        )
        if not bars:
            return {"error": f"未获取到 {ts_code} 的行情数据"}

        rows = [
            {
                "trade_date": bar.trade_date,
                "open": float(bar.open),
                "high": float(bar.high),
                "low": float(bar.low),
                "close": float(bar.close),
                "volume": float(bar.volume),
                "amount": float(bar.amount),
                "pct_chg": float(bar.pct_chg),
            }
            for bar in bars
        ]

        return {"ts_code": ts_code, "quotes": rows, "count": len(rows)}
    except Exception as exc:
        logger.error("tool.get_daily_quote.error", ts_code=ts_code, error=str(exc))
        raise RuntimeError(f"获取日线行情失败: {str(exc)}") from exc


def register_ashare_tools():
    """Log the tools currently held by the global registry.

    No registration happens here. The tool functions in this module are
    registered by their ``@auto_register_tool`` decorators when the module
    is imported; this function only reports the result. The logged count
    covers every tool in ``global_registry``, not only the ones defined in
    this module.
    """
    # Take one snapshot so the logged count and the per-tool entries agree;
    # previously list_tools() was called twice.
    tools = list(global_registry.list_tools())
    logger.info("ashare_tools_registry.initialized", tools_count=len(tools))

    for tool in tools:
        logger.debug(
            "ashare_tool.registered",
            name=tool.name,
            category=tool.category,
            risk_level=tool.risk_level,
        )


# Runs once at import time. The @auto_register_tool decorators above have
# already executed by this point, so this call only logs what the global
# registry holds.
register_ashare_tools()


# Public API of this module: the two setup helpers plus every tool function.
__all__ = [
    "register_ashare_tools",
    "set_ashare_tool_instances",
    "get_stock_info",
    "get_current_price",
    "get_limit_prices",
    "is_suspended",
    "get_portfolio_summary",
    "calculate_position_size",
    "evaluate_risk_level",
    "check_trading_time",
    "get_t1_sellable_quantity",
    "get_daily_quote",
]
