"""A 股风控规则集合。

提供 A 股市场特有的风控规则，包括涨跌停、集中度、资金使用等。
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import date, datetime, timedelta
from typing import List, Optional

import structlog

from laca.adapters.broker import OrderSide, OrderStatus, SimulatedBrokerClient
from laca.adapters.data.market_data_service import MarketDataService
from laca.scheduler.trade_calendar import TradeCalendar
from laca.tools.ashare_portfolio_tool import ASharePortfolioTool, PortfolioSnapshot

from .models import RiskAlert, RiskSeverity

# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)


@dataclass
class LimitPriceRule:
    """Validate an order price against the daily up/down limit prices.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        market_data_service: Service queried for limit prices. A default
            ``MarketDataService`` is created when none is supplied.
        price_buffer: Fraction of the limit price regarded as "close to the
            limit" (default 0.01).
    """

    rule_id: str = "limit_price"
    market_data_service: Optional[MarketDataService] = None
    price_buffer: float = 0.01  # 1% buffer

    def __post_init__(self):
        """Create a default market data service when none was injected."""
        if self.market_data_service is None:
            self.market_data_service = MarketDataService()

    def validate_price(
        self,
        ts_code: str,
        side: OrderSide,
        price: float,
        trade_date: Optional[date] = None,
    ) -> List[RiskAlert]:
        """Check whether an order price is inside the limit-price band.

        A buy above the up limit, or a sell below the down limit, yields a
        CRITICAL alert. A price within ``price_buffer`` of the relevant limit
        yields a WARNING. When limit prices cannot be obtained a single
        WARNING is returned.

        Args:
            ts_code: Stock code.
            side: Order direction.
            price: Order price.
            trade_date: Trading date; defaults to today.

        Returns:
            List of risk alerts (empty when nothing was flagged).
        """
        effective_date = date.today() if trade_date is None else trade_date

        limits = self.market_data_service.get_limit_prices(ts_code, effective_date)
        if not limits:
            # No limit data: report it and skip the range checks.
            return [
                RiskAlert(
                    rule_id=self.rule_id,
                    severity=RiskSeverity.WARNING,
                    message=f"{ts_code} 无法获取涨跌停价格",
                    details={"ts_code": ts_code, "price": price, "side": side.value},
                )
            ]

        if side == OrderSide.BUY:
            # Buys are compared against the up limit.
            ceiling = limits.up_limit
            near_ceiling = ceiling * (1 - self.price_buffer)

            if price > ceiling:
                severity = RiskSeverity.CRITICAL
                message = f"{ts_code} 买入价 {price:.2f} 超过涨停价 {ceiling:.2f}"
                details = {
                    "ts_code": ts_code,
                    "price": price,
                    "up_limit": ceiling,
                    "down_limit": limits.down_limit,
                }
            elif price >= near_ceiling:
                severity = RiskSeverity.WARNING
                message = f"{ts_code} 买入价 {price:.2f} 接近涨停价 {ceiling:.2f}"
                details = {
                    "ts_code": ts_code,
                    "price": price,
                    "up_limit": ceiling,
                    "buffer_price": near_ceiling,
                }
            else:
                return []
        elif side == OrderSide.SELL:
            # Sells are compared against the down limit.
            floor = limits.down_limit
            near_floor = floor * (1 + self.price_buffer)

            if price < floor:
                severity = RiskSeverity.CRITICAL
                message = f"{ts_code} 卖出价 {price:.2f} 低于跌停价 {floor:.2f}"
                details = {
                    "ts_code": ts_code,
                    "price": price,
                    "up_limit": limits.up_limit,
                    "down_limit": floor,
                }
            elif price <= near_floor:
                severity = RiskSeverity.WARNING
                message = f"{ts_code} 卖出价 {price:.2f} 接近跌停价 {floor:.2f}"
                details = {
                    "ts_code": ts_code,
                    "price": price,
                    "down_limit": floor,
                    "buffer_price": near_floor,
                }
            else:
                return []
        else:
            # Any other side value is not checked by this rule.
            return []

        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=severity,
                message=message,
                details=details,
            )
        ]


@dataclass
class ConcentrationRule:
    """Limit how much of total assets a single stock may occupy.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        max_single_position_ratio: Maximum share of total assets for one
            stock (default 20%).
        warning_ratio: Share above which a warning is raised (default 15%).
    """

    rule_id: str = "concentration"
    max_single_position_ratio: float = 0.20
    warning_ratio: float = 0.15

    def evaluate(self, snapshot: PortfolioSnapshot) -> List[RiskAlert]:
        """Evaluate position concentration for every holding in a snapshot.

        Args:
            snapshot: Account snapshot.

        Returns:
            List of risk alerts: CRITICAL for holdings above the hard limit,
            WARNING for holdings above the warning threshold. Empty when
            total assets are not positive.
        """
        findings: List[RiskAlert] = []

        # A ratio is meaningless without positive total assets.
        if snapshot.total_assets <= 0:
            return findings

        for holding in snapshot.positions:
            share = holding.market_value / snapshot.total_assets

            if share > self.max_single_position_ratio:
                level = RiskSeverity.CRITICAL
                text = f"{holding.ts_code} 持仓占比 {share:.2%} 超过上限 {self.max_single_position_ratio:.2%}"
            elif share > self.warning_ratio:
                level = RiskSeverity.WARNING
                text = f"{holding.ts_code} 持仓占比 {share:.2%} 接近上限 {self.max_single_position_ratio:.2%}"
            else:
                continue

            findings.append(
                RiskAlert(
                    rule_id=self.rule_id,
                    severity=level,
                    message=text,
                    details={
                        "ts_code": holding.ts_code,
                        "position_ratio": share,
                        "market_value": holding.market_value,
                        "total_assets": snapshot.total_assets,
                    },
                )
            )

        return findings


@dataclass
class CashUsageRule:
    """Check that enough available cash is kept relative to total assets.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        min_cash_ratio: Minimum cash ratio (default 10%).
        warning_cash_ratio: Cash ratio below which a warning is raised
            (default 20%).
    """

    rule_id: str = "cash_usage"
    min_cash_ratio: float = 0.10
    warning_cash_ratio: float = 0.20

    def evaluate(self, snapshot: PortfolioSnapshot) -> List[RiskAlert]:
        """Evaluate the available-cash ratio of a snapshot.

        Args:
            snapshot: Account snapshot.

        Returns:
            List with at most one alert: CRITICAL below ``min_cash_ratio``,
            WARNING below ``warning_cash_ratio``. Empty when total assets are
            not positive or the ratio is sufficient.
        """
        if snapshot.total_assets <= 0:
            return []

        # Share of total assets currently held as available cash.
        ratio = snapshot.available_cash / snapshot.total_assets

        if ratio < self.min_cash_ratio:
            level = RiskSeverity.CRITICAL
            text = f"可用现金比例 {ratio:.2%} 低于最低要求 {self.min_cash_ratio:.2%}"
        elif ratio < self.warning_cash_ratio:
            level = RiskSeverity.WARNING
            text = f"可用现金比例 {ratio:.2%} 偏低，建议保持 {self.warning_cash_ratio:.2%} 以上"
        else:
            return []

        context = {
            "available_cash": snapshot.available_cash,
            "frozen_cash": snapshot.frozen_cash,
            "total_assets": snapshot.total_assets,
            "cash_ratio": ratio,
        }
        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=level,
                message=text,
                details=context,
            )
        ]


@dataclass
class TradingTimeRule:
    """Check whether a point in time falls inside a trading session.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        trade_calendar: Trading calendar. A default ``TradeCalendar`` is
            created when none is supplied.
        allow_call_auction: Whether call-auction sessions are acceptable
            (default True).
    """

    rule_id: str = "trading_time"
    trade_calendar: Optional[TradeCalendar] = None
    allow_call_auction: bool = True

    def __post_init__(self):
        """Create a default trade calendar when none was injected."""
        if self.trade_calendar is None:
            self.trade_calendar = TradeCalendar()

    def validate_trading_time(
        self,
        check_time: Optional[datetime] = None,
    ) -> List[RiskAlert]:
        """Validate that ``check_time`` lies within a trading session.

        Args:
            check_time: Time to check; defaults to the current time.

        Returns:
            List with at most one alert: CRITICAL on a non-trading day or
            outside any session, WARNING during a call auction when call
            auctions are not allowed. Empty otherwise.
        """
        moment = datetime.now() if check_time is None else check_time

        # Non-trading days are rejected before looking at the session.
        if not self.trade_calendar.is_trading_day(moment.date()):
            return [
                RiskAlert(
                    rule_id=self.rule_id,
                    severity=RiskSeverity.CRITICAL,
                    message=f"{moment.date()} 不是交易日",
                    details={"check_time": moment.isoformat()},
                )
            ]

        phase = self.trade_calendar.get_trading_session(moment)

        # Regular continuous-trading sessions: nothing to report.
        if phase in {"morning", "afternoon"}:
            return []

        if phase in {"morning_call_auction", "afternoon_call_auction"}:
            # Call auction: only flagged when explicitly disallowed.
            if self.allow_call_auction:
                return []
            return [
                RiskAlert(
                    rule_id=self.rule_id,
                    severity=RiskSeverity.WARNING,
                    message=f"当前为集合竞价时段（{phase}），不建议交易",
                    details={
                        "check_time": moment.isoformat(),
                        "session": phase,
                    },
                )
            ]

        # Anything else is treated as outside trading hours.
        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=RiskSeverity.CRITICAL,
                message=f"当前非交易时段（{phase}），无法交易",
                details={
                    "check_time": moment.isoformat(),
                    "session": phase,
                },
            )
        ]


@dataclass
class SuspensionRule:
    """Check whether a stock is suspended from trading.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        market_data_service: Service queried for suspension status. A
            default ``MarketDataService`` is created when none is supplied.
    """

    rule_id: str = "suspension"
    market_data_service: Optional[MarketDataService] = None

    def __post_init__(self):
        """Create a default market data service when none was injected."""
        if self.market_data_service is None:
            self.market_data_service = MarketDataService()

    def validate_suspension(
        self,
        ts_code: str,
        check_date: Optional[date] = None,
    ) -> List[RiskAlert]:
        """Validate that a stock is not suspended on the given date.

        Args:
            ts_code: Stock code.
            check_date: Date to check; defaults to today.

        Returns:
            List with one CRITICAL alert when the stock is suspended,
            otherwise an empty list.
        """
        day = check_date if check_date is not None else date.today()

        if not self.market_data_service.is_suspended(ts_code, day):
            return []

        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=RiskSeverity.CRITICAL,
                message=f"{ts_code} 停牌中，无法交易",
                details={"ts_code": ts_code, "check_date": day.isoformat()},
            )
        ]


@dataclass
class T1SellRule:
    """Enforce the T+1 sell restriction.

    Shares bought today may not be sold today, so a sell order is limited to
    the sellable quantity reported by the portfolio tool.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        portfolio_tool: Portfolio tool. A default ``ASharePortfolioTool`` is
            created when none is supplied.
    """

    rule_id: str = "t1_sell"
    portfolio_tool: Optional[ASharePortfolioTool] = None

    def __post_init__(self):
        """Create a default portfolio tool when none was injected."""
        if self.portfolio_tool is None:
            self.portfolio_tool = ASharePortfolioTool()

    def validate_sell_quantity(
        self,
        ts_code: str,
        quantity: int,
        check_date: Optional[date] = None,
    ) -> List[RiskAlert]:
        """Validate that a sell quantity does not exceed the sellable amount.

        Args:
            ts_code: Stock code.
            quantity: Quantity to sell.
            check_date: Date to check; defaults to today.

        Returns:
            List with one CRITICAL alert when ``quantity`` exceeds the
            sellable quantity, otherwise an empty list.
        """
        day = check_date if check_date is not None else date.today()

        sellable = self.portfolio_tool.check_sellable_quantity(ts_code, day)
        if not quantity > sellable:
            return []

        # The position is looked up only to enrich the alert details.
        held = self.portfolio_tool.get_position(ts_code)
        if held:
            total_held = held.total_quantity
            last_buy = held.last_buy_date.isoformat() if held.last_buy_date else None
        else:
            total_held = 0
            last_buy = None

        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=RiskSeverity.CRITICAL,
                message=f"{ts_code} 可卖数量不足：需要 {quantity}，可卖 {sellable}（T+1 规则）",
                details={
                    "ts_code": ts_code,
                    "requested_quantity": quantity,
                    "sellable_quantity": sellable,
                    "total_quantity": total_held,
                    "last_buy_date": last_buy,
                },
            )
        ]


@dataclass
class PendingOrderTimeoutRule:
    """Raise a warning for orders that stay pending for too long.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        timeout_minutes: Timeout in minutes.
        broker_client: Broker client queried for pending orders. When it is
            not set the check returns no alerts.
    """

    rule_id: str = "pending_order_timeout"
    timeout_minutes: int = 30
    broker_client: Optional[SimulatedBrokerClient] = None

    def check_pending_orders(self) -> List[RiskAlert]:
        """Check whether any pending order has exceeded the timeout.

        Returns:
            List of WARNING alerts, one per pending order whose age is
            greater than ``timeout_minutes``. Empty when no broker client is
            configured.
        """
        alerts: List[RiskAlert] = []

        if not self.broker_client:
            return alerts

        pending_orders = self.broker_client.get_orders(status=OrderStatus.PENDING)

        now = datetime.now()
        timeout_threshold = timedelta(minutes=self.timeout_minutes)

        for order in pending_orders:
            elapsed = now - order.created_at

            if elapsed > timeout_threshold:
                # Use total_seconds(): timedelta.seconds excludes whole days,
                # which would under-report orders pending for over 24 hours.
                elapsed_minutes = int(elapsed.total_seconds() // 60)

                alerts.append(
                    RiskAlert(
                        rule_id=self.rule_id,
                        severity=RiskSeverity.WARNING,
                        message=f"订单 {order.order_id} 超时未成交（{elapsed_minutes} 分钟）",
                        details={
                            "order_id": order.order_id,
                            "ts_code": order.ts_code,
                            "side": order.side.value,
                            "price": order.price,
                            "quantity": order.quantity,
                            "created_at": order.created_at.isoformat(),
                            "elapsed_minutes": elapsed_minutes,
                        },
                    )
                )

        return alerts


@dataclass
class DailyTradeCountRule:
    """Limit the number of filled orders in a single day.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        max_daily_trades: Maximum number of trades per day.
        broker_client: Broker client queried for orders. When it is not set
            the check returns no alerts.
    """

    rule_id: str = "daily_trade_count"
    max_daily_trades: int = 10
    broker_client: Optional[SimulatedBrokerClient] = None

    def check_daily_trades(self, check_date: Optional[date] = None) -> List[RiskAlert]:
        """Check the number of orders filled on a given day.

        Args:
            check_date: Date to check; defaults to today.

        Returns:
            List with one WARNING alert when the count of orders filled on
            ``check_date`` has reached ``max_daily_trades``, otherwise an
            empty list.
        """
        if not self.broker_client:
            return []

        day = check_date if check_date is not None else date.today()

        # Count orders that are filled and whose fill time falls on `day`.
        filled_today = 0
        for order in self.broker_client.get_orders():
            if order.status == OrderStatus.FILLED and order.filled_at:
                if order.filled_at.date() == day:
                    filled_today += 1

        if not filled_today >= self.max_daily_trades:
            return []

        return [
            RiskAlert(
                rule_id=self.rule_id,
                severity=RiskSeverity.WARNING,
                message=f"单日交易次数 {filled_today} 已达上限 {self.max_daily_trades}",
                details={
                    "trade_count": filled_today,
                    "max_trades": self.max_daily_trades,
                    "check_date": day.isoformat(),
                },
            )
        ]


@dataclass
class AbnormalVolatilityRule:
    """Raise warnings when a stock's price moves abnormally.

    Attributes:
        rule_id: Identifier stamped on every alert produced by this rule.
        volatility_threshold: Volatility threshold (default 5%).
        market_data_service: Service queried for daily quotes. A default
            ``MarketDataService`` is created when none is supplied.
    """

    rule_id: str = "abnormal_volatility"
    volatility_threshold: float = 0.05  # 5%
    market_data_service: Optional[MarketDataService] = None

    def __post_init__(self):
        """Create a default market data service when none was injected."""
        if self.market_data_service is None:
            self.market_data_service = MarketDataService()

    def check_volatility(
        self,
        ts_code: str,
        check_date: Optional[date] = None,
    ) -> List[RiskAlert]:
        """Check the most recent daily quote for abnormal movement.

        Two independent checks are made: the intraday range
        ``(high - low) / open`` against ``volatility_threshold``, and the
        absolute ``pct_chg`` against ``volatility_threshold * 100``.

        Any exception raised during the check is logged as a warning and
        the alerts collected so far are returned.

        Args:
            ts_code: Stock code.
            check_date: Date to check; defaults to today.

        Returns:
            List of WARNING alerts (possibly empty).
        """
        findings: List[RiskAlert] = []
        day = check_date if check_date is not None else date.today()

        try:
            # Only the latest quote up to `day` is requested.
            bars = self.market_data_service.get_daily_quotes(
                ts_code=ts_code,
                end_date=day,
                limit=1,
            )
            if not bars:
                return findings

            bar = bars[0]

            # Intraday range relative to the open; skipped for a non-positive open.
            if bar.open > 0:
                swing = (bar.high - bar.low) / bar.open
                if swing > self.volatility_threshold:
                    swing_details = {
                        "ts_code": ts_code,
                        "trade_date": bar.trade_date.isoformat(),
                        "open": bar.open,
                        "high": bar.high,
                        "low": bar.low,
                        "close": bar.close,
                        "intraday_volatility": swing,
                        "threshold": self.volatility_threshold,
                    }
                    findings.append(
                        RiskAlert(
                            rule_id=self.rule_id,
                            severity=RiskSeverity.WARNING,
                            message=f"{ts_code} 出现异常波动，日内波动率 {swing:.2%}",
                            details=swing_details,
                        )
                    )

            # The threshold is scaled by 100 before comparing with pct_chg.
            if abs(bar.pct_chg) > self.volatility_threshold * 100:
                change_details = {
                    "ts_code": ts_code,
                    "trade_date": bar.trade_date.isoformat(),
                    "pct_chg": bar.pct_chg,
                    "close": bar.close,
                    "pre_close": bar.pre_close,
                }
                findings.append(
                    RiskAlert(
                        rule_id=self.rule_id,
                        severity=RiskSeverity.WARNING,
                        message=f"{ts_code} 涨跌幅异常，当日涨跌 {bar.pct_chg:.2f}%",
                        details=change_details,
                    )
                )

        except Exception as exc:
            # Best-effort check: failures are logged, never raised.
            logger.warning(
                "volatility_check_failed",
                ts_code=ts_code,
                error=str(exc),
            )

        return findings


# Public API of this module: the rule classes exposed to wildcard imports.
__all__ = [
    "LimitPriceRule",
    "ConcentrationRule",
    "CashUsageRule",
    "TradingTimeRule",
    "SuspensionRule",
    "T1SellRule",
    "PendingOrderTimeoutRule",
    "DailyTradeCountRule",
    "AbnormalVolatilityRule",
]
