"""安全防护系统 - 工具调用权限、风险评估和异常处理。

这个模块提供：
1. 权限管理和访问控制
2. 风险评估和限制
3. 异常检测和处理
4. 安全审计和日志
5. 紧急停止机制
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Set, Tuple
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from enum import Enum
import json

from .function_calling import ToolCall, ToolFunction, PermissionLevel


class RiskLevel(Enum):
    """Risk tiers shared by rules, violations and risk assessments.

    Members are plain ``Enum`` values, so they support equality and identity
    checks but not ordering comparisons. Each value is the lower-case form of
    the member name.
    """
    VERY_LOW = "very_low"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class SecurityEvent(Enum):
    """Categories of security event recorded on a ``SecurityViolation``.

    The string values are what the audit-log export emits for ``event_type``.
    """
    PERMISSION_DENIED = "permission_denied"
    RISK_THRESHOLD_EXCEEDED = "risk_threshold_exceeded"
    SUSPICIOUS_ACTIVITY = "suspicious_activity"
    RATE_LIMIT_EXCEEDED = "rate_limit_exceeded"
    INVALID_PARAMETERS = "invalid_parameters"
    EMERGENCY_STOP = "emergency_stop"


@dataclass
class SecurityRule:
    """A named, switchable security rule with free-form parameters."""
    # Unique key: the guard stores rules by this name and selects the
    # matching check by comparing against it.
    name: str
    # Human-readable summary of what the rule limits.
    description: str
    # Severity tier associated with the rule.
    risk_level: RiskLevel
    # Disabled rules are skipped during risk assessment.
    enabled: bool = True
    # Rule-specific thresholds, looked up by key by the individual checks.
    parameters: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SecurityViolation:
    """One recorded security violation, kept for status and audit export."""
    # When the violation was recorded.
    timestamp: datetime
    # Category of the event.
    event_type: SecurityEvent
    # Name of the rule (or pseudo-rule) that the violation is attributed to.
    rule_name: str
    # Severity tier of the violation.
    severity: RiskLevel
    # Free-form context describing the violation.
    details: Dict[str, Any]
    # The tool call that triggered the violation, when there was one.
    tool_call: Optional[ToolCall] = None
    # Resolution flag; defaults to unresolved.
    resolved: bool = False


@dataclass
class RiskAssessment:
    """Outcome of assessing a single tool call."""
    # Risk tier derived from the accumulated score.
    overall_risk: RiskLevel
    # Accumulated score, intended range 0.0 - 1.0.
    risk_score: float
    # Human-readable reasons that contributed to the score.
    risk_factors: List[str]
    # Suggested follow-up actions for the caller.
    recommendations: List[str]
    # Whether the call may be executed.
    allowed: bool
    # Messages describing concrete limits that were hit.
    restrictions: List[str] = field(default_factory=list)


class SecurityGuard:
    """安全防护系统。"""
    
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.security_rules: Dict[str, SecurityRule] = {}
        self.violations: List[SecurityViolation] = []
        self.rate_limits: Dict[str, List[datetime]] = {}
        self.emergency_stop = False
        
        # 初始化默认安全规则
        self._initialize_default_rules()
    
    def _initialize_default_rules(self):
        """初始化默认安全规则。"""
        
        # 交易金额限制
        self.add_rule(SecurityRule(
            name="trading_amount_limit",
            description="限制单笔交易金额",
            risk_level=RiskLevel.HIGH,
            parameters={
                "max_amount_usd": 10000.0,  # 最大单笔交易金额
                "max_daily_amount_usd": 50000.0,  # 最大日交易金额
                "max_position_percentage": 0.2  # 最大单个持仓比例
            }
        ))
        
        # 交易频率限制
        self.add_rule(SecurityRule(
            name="trading_frequency_limit",
            description="限制交易频率",
            risk_level=RiskLevel.MEDIUM,
            parameters={
                "max_trades_per_hour": 10,
                "max_trades_per_day": 100,
                "min_interval_seconds": 60
            }
        ))
        
        # 风险敞口限制
        self.add_rule(SecurityRule(
            name="risk_exposure_limit",
            description="限制总体风险敞口",
            risk_level=RiskLevel.HIGH,
            parameters={
                "max_total_exposure": 0.8,  # 最大总敞口比例
                "max_leverage": 3.0,  # 最大杠杆倍数
                "max_correlation_risk": 0.7  # 最大相关性风险
            }
        ))
        
        # 市场条件限制
        self.add_rule(SecurityRule(
            name="market_condition_limit",
            description="在极端市场条件下限制交易",
            risk_level=RiskLevel.CRITICAL,
            parameters={
                "max_volatility": 0.1,  # 最大波动率
                "min_liquidity": 1000000,  # 最小流动性
                "fear_greed_threshold": 20  # 恐慌指数阈值
            }
        ))
        
        # API调用频率限制
        self.add_rule(SecurityRule(
            name="api_rate_limit",
            description="限制API调用频率",
            risk_level=RiskLevel.LOW,
            parameters={
                "max_calls_per_minute": 60,
                "max_calls_per_hour": 1000
            }
        ))
    
    def add_rule(self, rule: SecurityRule) -> None:
        """添加安全规则。"""
        self.security_rules[rule.name] = rule
        self.logger.info(f"Added security rule: {rule.name}")
    
    def remove_rule(self, rule_name: str) -> bool:
        """移除安全规则。"""
        if rule_name in self.security_rules:
            del self.security_rules[rule_name]
            self.logger.info(f"Removed security rule: {rule_name}")
            return True
        return False
    
    def enable_rule(self, rule_name: str) -> bool:
        """启用安全规则。"""
        if rule_name in self.security_rules:
            self.security_rules[rule_name].enabled = True
            return True
        return False
    
    def disable_rule(self, rule_name: str) -> bool:
        """禁用安全规则。"""
        if rule_name in self.security_rules:
            self.security_rules[rule_name].enabled = False
            return True
        return False
    
    def assess_tool_call_risk(
        self,
        tool_call: ToolCall,
        tool_function: ToolFunction,
        context: Dict[str, Any]
    ) -> RiskAssessment:
        """评估工具调用风险。"""
        
        if self.emergency_stop:
            return RiskAssessment(
                overall_risk=RiskLevel.CRITICAL,
                risk_score=1.0,
                risk_factors=["Emergency stop activated"],
                recommendations=["Wait for emergency stop to be cleared"],
                allowed=False,
                restrictions=["Emergency stop in effect"]
            )
        
        risk_factors = []
        risk_score = 0.0
        restrictions = []
        
        # 检查权限级别
        if tool_function.permission_level == PermissionLevel.ADMIN:
            risk_factors.append("Admin level permission required")
            risk_score += 0.3
        elif tool_function.permission_level == PermissionLevel.TRADING:
            risk_factors.append("Trading permission required")
            risk_score += 0.2
        
        # 检查工具风险级别
        tool_risk_map = {
            "LOW": 0.1,
            "MEDIUM": 0.3,
            "HIGH": 0.6,
            "CRITICAL": 0.9
        }
        tool_risk_score = tool_risk_map.get(tool_function.risk_level, 0.5)
        risk_score += tool_risk_score
        
        if tool_risk_score > 0.5:
            risk_factors.append(f"High risk tool: {tool_function.risk_level}")
        
        # 检查具体安全规则
        for rule_name, rule in self.security_rules.items():
            if not rule.enabled:
                continue
            
            violation = self._check_rule_violation(rule, tool_call, context)
            if violation:
                risk_factors.append(f"Rule violation: {rule_name}")
                risk_score += 0.2
                restrictions.append(violation)
        
        # 检查频率限制
        rate_limit_violation = self._check_rate_limits(tool_function.name)
        if rate_limit_violation:
            risk_factors.append("Rate limit exceeded")
            risk_score += 0.3
            restrictions.append(rate_limit_violation)
        
        # 确定整体风险级别
        if risk_score >= 0.9:
            overall_risk = RiskLevel.CRITICAL
        elif risk_score >= 0.7:
            overall_risk = RiskLevel.HIGH
        elif risk_score >= 0.5:
            overall_risk = RiskLevel.MEDIUM
        elif risk_score >= 0.3:
            overall_risk = RiskLevel.LOW
        else:
            overall_risk = RiskLevel.VERY_LOW
        
        # 生成建议
        recommendations = self._generate_recommendations(risk_factors, overall_risk)
        
        # 决定是否允许执行
        allowed = overall_risk in [RiskLevel.VERY_LOW, RiskLevel.LOW, RiskLevel.MEDIUM]
        if overall_risk == RiskLevel.HIGH and not restrictions:
            allowed = True  # 高风险但无具体限制时允许执行
        
        return RiskAssessment(
            overall_risk=overall_risk,
            risk_score=min(risk_score, 1.0),
            risk_factors=risk_factors,
            recommendations=recommendations,
            allowed=allowed,
            restrictions=restrictions
        )
    
    def _check_rule_violation(
        self,
        rule: SecurityRule,
        tool_call: ToolCall,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """检查规则违规。"""
        
        if rule.name == "trading_amount_limit":
            return self._check_trading_amount_limit(rule, tool_call, context)
        elif rule.name == "trading_frequency_limit":
            return self._check_trading_frequency_limit(rule, tool_call, context)
        elif rule.name == "risk_exposure_limit":
            return self._check_risk_exposure_limit(rule, tool_call, context)
        elif rule.name == "market_condition_limit":
            return self._check_market_condition_limit(rule, tool_call, context)
        
        return None
    
    def _check_trading_amount_limit(
        self,
        rule: SecurityRule,
        tool_call: ToolCall,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """检查交易金额限制。"""
        if tool_call.function_name not in ["place_order", "execute_trade"]:
            return None
        
        # 获取交易金额
        quantity = tool_call.parameters.get("quantity", 0)
        price = tool_call.parameters.get("price", 0)
        amount_usd = quantity * price
        
        max_amount = rule.parameters.get("max_amount_usd", 10000)
        if amount_usd > max_amount:
            return f"Trade amount ${amount_usd:.2f} exceeds limit ${max_amount:.2f}"
        
        return None
    
    def _check_trading_frequency_limit(
        self,
        rule: SecurityRule,
        tool_call: ToolCall,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """检查交易频率限制。"""
        if tool_call.function_name not in ["place_order", "execute_trade"]:
            return None
        
        now = datetime.now()
        hour_ago = now - timedelta(hours=1)
        
        # 统计最近一小时的交易次数
        recent_trades = [
            v for v in self.violations
            if v.timestamp > hour_ago and v.event_type == SecurityEvent.RATE_LIMIT_EXCEEDED
        ]
        
        max_trades_per_hour = rule.parameters.get("max_trades_per_hour", 10)
        if len(recent_trades) >= max_trades_per_hour:
            return f"Trading frequency limit exceeded: {len(recent_trades)}/{max_trades_per_hour} per hour"
        
        return None
    
    def _check_risk_exposure_limit(
        self,
        rule: SecurityRule,
        tool_call: ToolCall,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """检查风险敞口限制。"""
        # 这里需要根据实际的投资组合数据来计算
        # 暂时返回None，实际实现需要接入投资组合数据
        return None
    
    def _check_market_condition_limit(
        self,
        rule: SecurityRule,
        tool_call: ToolCall,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """检查市场条件限制。"""
        # 这里需要根据实际的市场数据来判断
        # 暂时返回None，实际实现需要接入市场数据
        return None
    
    def _check_rate_limits(self, tool_name: str) -> Optional[str]:
        """检查API调用频率限制。"""
        now = datetime.now()
        minute_ago = now - timedelta(minutes=1)
        
        # 初始化工具调用记录
        if tool_name not in self.rate_limits:
            self.rate_limits[tool_name] = []
        
        # 清理过期记录
        self.rate_limits[tool_name] = [
            timestamp for timestamp in self.rate_limits[tool_name]
            if timestamp > minute_ago
        ]
        
        # 检查频率限制
        calls_per_minute = len(self.rate_limits[tool_name])
        max_calls = 60  # 默认每分钟最多60次调用
        
        if calls_per_minute >= max_calls:
            return f"Rate limit exceeded for {tool_name}: {calls_per_minute}/{max_calls} per minute"
        
        # 记录本次调用
        self.rate_limits[tool_name].append(now)
        return None
    
    def _generate_recommendations(
        self,
        risk_factors: List[str],
        risk_level: RiskLevel
    ) -> List[str]:
        """生成安全建议。"""
        recommendations = []
        
        if risk_level == RiskLevel.CRITICAL:
            recommendations.append("立即停止操作，等待人工审核")
            recommendations.append("检查系统状态和市场条件")
        elif risk_level == RiskLevel.HIGH:
            recommendations.append("谨慎执行，建议降低交易规模")
            recommendations.append("增加风险监控频率")
        elif risk_level == RiskLevel.MEDIUM:
            recommendations.append("正常执行，保持风险监控")
        
        if "Rate limit exceeded" in risk_factors:
            recommendations.append("等待一段时间后重试")
        
        if "High risk tool" in str(risk_factors):
            recommendations.append("确认操作参数正确")
            recommendations.append("考虑使用更保守的策略")
        
        return recommendations
    
    def record_violation(
        self,
        event_type: SecurityEvent,
        rule_name: str,
        severity: RiskLevel,
        details: Dict[str, Any],
        tool_call: Optional[ToolCall] = None
    ) -> None:
        """记录安全违规。"""
        violation = SecurityViolation(
            timestamp=datetime.now(),
            event_type=event_type,
            rule_name=rule_name,
            severity=severity,
            details=details,
            tool_call=tool_call
        )
        
        self.violations.append(violation)
        
        # 记录日志
        self.logger.warning(
            f"Security violation: {event_type.value} - {rule_name} "
            f"(Severity: {severity.value})"
        )
        
        # 如果是严重违规，考虑触发紧急停止
        if severity == RiskLevel.CRITICAL:
            self.logger.critical("Critical security violation detected")
    
    def activate_emergency_stop(self, reason: str) -> None:
        """激活紧急停止。"""
        self.emergency_stop = True
        self.logger.critical(f"Emergency stop activated: {reason}")
        
        self.record_violation(
            SecurityEvent.EMERGENCY_STOP,
            "emergency_stop",
            RiskLevel.CRITICAL,
            {"reason": reason}
        )
    
    def deactivate_emergency_stop(self) -> None:
        """解除紧急停止。"""
        self.emergency_stop = False
        self.logger.info("Emergency stop deactivated")
    
    def get_security_status(self) -> Dict[str, Any]:
        """获取安全状态。"""
        recent_violations = [
            v for v in self.violations
            if v.timestamp > datetime.now() - timedelta(hours=24)
        ]
        
        return {
            "emergency_stop": self.emergency_stop,
            "total_rules": len(self.security_rules),
            "enabled_rules": len([r for r in self.security_rules.values() if r.enabled]),
            "total_violations": len(self.violations),
            "recent_violations": len(recent_violations),
            "critical_violations": len([
                v for v in recent_violations
                if v.severity == RiskLevel.CRITICAL
            ]),
            "last_violation": (
                self.violations[-1].timestamp.isoformat()
                if self.violations else None
            )
        }
    
    def export_audit_log(self, hours: int = 24) -> List[Dict[str, Any]]:
        """导出审计日志。"""
        cutoff_time = datetime.now() - timedelta(hours=hours)
        recent_violations = [
            v for v in self.violations
            if v.timestamp > cutoff_time
        ]
        
        return [
            {
                "timestamp": v.timestamp.isoformat(),
                "event_type": v.event_type.value,
                "rule_name": v.rule_name,
                "severity": v.severity.value,
                "details": v.details,
                "tool_call": (
                    {
                        "function_name": v.tool_call.function_name,
                        "parameters": v.tool_call.parameters
                    } if v.tool_call else None
                ),
                "resolved": v.resolved
            }
            for v in recent_violations
        ]


# Module-wide shared guard instance, created at import time (which also
# registers the default rules).
global_security_guard = SecurityGuard()


# Names exported by ``from <module> import *``.
__all__ = [
    "RiskLevel",
    "SecurityEvent", 
    "SecurityRule",
    "SecurityViolation",
    "RiskAssessment",
    "SecurityGuard",
    "global_security_guard"
]