"""交易记忆系统 - 三层记忆架构。

实现短期、中期、长期记忆，为AI决策提供上下文连续性。
"""

from __future__ import annotations

import json
from collections import deque
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

import structlog

from lacopro.config.settings import Settings, get_settings
from lacopro.memory.store import Database

# Name under which this module's structured log events are emitted.
_LOGGER_NAME = "memory_system"
logger = structlog.get_logger(_LOGGER_NAME)


@dataclass
class DecisionContext:
    """Context of a single decision (the unit stored in short-term memory)."""

    timestamp: datetime
    decision_type: str  # BUY, SELL, HOLD
    symbol: str
    confidence: float
    risk_level: str

    # Decision rationale and analysis
    reasoning: str
    market_view: str  # overall judgement of the market
    opportunity_coins: List[str]  # coins identified as opportunities
    rejected_reason: Optional[str] = None  # for HOLD: why the trade was rejected

    # Tool-call information
    tools_used: List[str] = field(default_factory=list)
    reasoning_rounds: int = 0

    # Execution outcome (filled in after the decision is made)
    execution_result: Optional[str] = None  # SUCCESS, FAILED, PENDING
    actual_pnl: Optional[float] = None  # realized PnL once the position is closed

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of all fields.

        ``timestamp`` is rendered as an ISO-8601 string; every other field is
        deep-copied by ``dataclasses.asdict``.
        """
        data = asdict(self)
        data["timestamp"] = self.timestamp.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> DecisionContext:
        """Rebuild an instance from the output of :meth:`to_dict`.

        The caller's mapping is left untouched. ``timestamp`` may be an
        ISO-8601 string (as produced by ``to_dict``) or already a ``datetime``.
        """
        # Work on a shallow copy so the caller's dict is not mutated.
        kwargs = dict(data)
        timestamp = kwargs["timestamp"]
        if isinstance(timestamp, str):
            kwargs["timestamp"] = datetime.fromisoformat(timestamp)
        return cls(**kwargs)

    def to_summary(self) -> str:
        """Return a one-line ``HH:MM: ACTION ...`` summary for prompts.

        The rejection reason takes precedence over the market view, which
        takes precedence over the raw reasoning.
        """
        time_str = self.timestamp.strftime("%H:%M")
        action = f"{self.decision_type} {self.symbol}" if self.symbol != "NO_TRADE" else "HOLD"

        if self.rejected_reason:
            return f"{time_str}: {action} (拒绝: {self.rejected_reason[:50]}...)"

        if self.market_view:
            return f"{time_str}: {action} | 市场: {self.market_view[:40]}..."

        return f"{time_str}: {action} | {self.reasoning[:50]}..."


@dataclass
class DailySummary:
    """End-of-day trading recap (the unit stored in mid-term memory)."""

    date: str  # calendar day, formatted YYYY-MM-DD
    market_trend: str  # overall market trend verdict
    key_events: List[str]  # notable events of the day

    # Decision / trade counters
    total_decisions: int
    buy_count: int
    sell_count: int
    hold_count: int
    executed_trades: int

    # Performance metrics
    win_rate: float
    total_pnl: float
    best_trade: Optional[str] = None
    worst_trade: Optional[str] = None

    # Lessons learned (LLM-generated)
    lessons_learned: str = ""
    strategy_notes: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict (nested values deep-copied)."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> DailySummary:
        """Inverse of :meth:`to_dict`; keys must match the field names."""
        field_values = dict(data)
        return cls(**field_values)


@dataclass
class StrategyStatistics:
    """Aggregated strategy statistics (long-term memory)."""

    # Win rate keyed by market condition
    win_rate_by_condition: Dict[str, float] = field(default_factory=dict)

    # Best-performing patterns and patterns to avoid
    best_performing_patterns: List[str] = field(default_factory=list)
    avoid_patterns: List[str] = field(default_factory=list)

    # Per-tool effectiveness
    tool_effectiveness: Dict[str, float] = field(default_factory=dict)

    # Performance keyed by hour (int keys)
    performance_by_hour: Dict[int, float] = field(default_factory=dict)

    # Historical win rate per symbol
    symbol_win_rate: Dict[str, float] = field(default_factory=dict)

    last_updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict; ``last_updated`` becomes ISO-8601."""
        data = asdict(self)
        data["last_updated"] = self.last_updated.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> StrategyStatistics:
        """Rebuild an instance from :meth:`to_dict` output (possibly via JSON).

        The caller's mapping is not modified. ``performance_by_hour`` keys are
        converted back to ``int``; a missing/None ``last_updated`` falls back
        to the field default (now, UTC).

        Raises:
            ValueError: if an hour key or the timestamp string is malformed.
        """
        kwargs = dict(data)

        hours = kwargs.get("performance_by_hour")
        if hours:
            # JSON object keys are always strings, so a json.dump/json.load
            # round trip turns {9: 0.6} into {"9": 0.6}; restore int keys.
            kwargs["performance_by_hour"] = {int(hour): rate for hour, rate in hours.items()}

        raw_updated = kwargs.pop("last_updated", None)
        if isinstance(raw_updated, str):
            kwargs["last_updated"] = datetime.fromisoformat(raw_updated)
        elif raw_updated is not None:
            kwargs["last_updated"] = raw_updated
        return cls(**kwargs)


class TradingMemory:
    """Three-tier trading memory system.

    - Short-term: the most recent ``max_short_term`` decisions (default 5),
      held in a bounded in-memory deque.
    - Mid-term: one ``DailySummary`` per day, keeping at most
      ``max_mid_term_days`` entries (default 7).
    - Long-term: aggregated ``StrategyStatistics``.

    All three tiers are persisted as JSON files under ``~/.lacopro/memory``
    and reloaded when the object is constructed.
    """

    def __init__(
        self,
        database: Optional[Database] = None,
        settings: Optional[Settings] = None,
        max_short_term: int = 5,
        max_mid_term_days: int = 7,
    ) -> None:
        """Create the memory system and restore any state persisted on disk.

        Args:
            database: Optional database handle; stored on the instance (no
                method here reads it yet).
            settings: Application settings; defaults to ``get_settings()``.
            max_short_term: Capacity of the short-term deque; once full, the
                oldest decision is evicted on each append.
            max_mid_term_days: Maximum number of daily summaries retained.
        """
        self.database = database
        self.settings = settings or get_settings()
        self.max_short_term = max_short_term
        self.max_mid_term_days = max_mid_term_days

        # Short-term memory: bounded in-memory deque
        self.short_term: deque[DecisionContext] = deque(maxlen=max_short_term)

        # Mid-term memory: daily summaries, oldest first
        self.mid_term: List[DailySummary] = []

        # Long-term memory: aggregated statistics
        self.long_term: StrategyStatistics = StrategyStatistics()

        # Persistence location
        self.memory_dir = Path.home() / ".lacopro" / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)

        self._load_from_disk()

        logger.info(
            "memory_system.initialized",
            short_term_size=len(self.short_term),
            mid_term_size=len(self.mid_term),
            long_term_updated=self.long_term.last_updated.isoformat(),
        )

    def add_decision(self, context: DecisionContext) -> None:
        """Append a decision to short-term memory and persist all tiers."""
        self.short_term.append(context)
        logger.info(
            "memory_system.decision_added",
            decision=context.decision_type,
            symbol=context.symbol,
            short_term_size=len(self.short_term),
        )

        # Synchronous write-through: the files are updated before returning.
        self._save_to_disk()

    def get_short_term_summary(self) -> str:
        """Render short-term memory as a prompt section, one line per decision."""
        if not self.short_term:
            return "【短期记忆】\n无历史决策记录\n"

        body = "\n".join(ctx.to_summary() for ctx in self.short_term)
        return f"【短期记忆 - 最近{len(self.short_term)}个决策周期】\n{body}\n"

    def get_mid_term_summary(self) -> str:
        """Render the 3 most recent daily summaries as a prompt section."""
        if not self.mid_term:
            return "【中期记忆】\n暂无每日总结\n"

        # Only the last three summaries are included to keep the prompt short.
        recent_summaries = self.mid_term[-3:]
        lines = []

        for summary in recent_summaries:
            lines.append(f"• {summary.date}: {summary.market_trend}")
            if summary.lessons_learned:
                lines.append(f"  教训: {summary.lessons_learned[:80]}...")

        body = "\n".join(lines)
        return f"【中期记忆 - 最近{len(recent_summaries)}天总结】\n{body}\n"

    def get_long_term_summary(self) -> str:
        """Render long-term statistics as a prompt section.

        Shows up to 3 best patterns, 3 patterns to avoid, and the 3 symbols
        with the highest win rate; falls back to a placeholder line when all
        of those are empty.
        """
        stats = self.long_term

        lines = ["【长期统计 - 历史经验】"]

        if stats.best_performing_patterns:
            lines.append("✅ 高胜率模式:")
            for pattern in stats.best_performing_patterns[:3]:
                lines.append(f"  • {pattern}")

        if stats.avoid_patterns:
            lines.append("❌ 规避模式:")
            for pattern in stats.avoid_patterns[:3]:
                lines.append(f"  • {pattern}")

        if stats.symbol_win_rate:
            sorted_symbols = sorted(
                stats.symbol_win_rate.items(),
                key=lambda x: x[1],
                reverse=True,
            )[:3]
            lines.append("📊 币种历史表现:")
            for symbol, win_rate in sorted_symbols:
                lines.append(f"  • {symbol}: {win_rate:.1%}胜率")

        # Only the header was added: nothing to report yet.
        if len(lines) == 1:
            lines.append("暂无足够历史数据")

        return "\n".join(lines) + "\n"

    def get_full_context(self) -> str:
        """Concatenate the short-, mid- and long-term sections for the AI prompt."""
        return (
            f"{self.get_short_term_summary()}\n"
            f"{self.get_mid_term_summary()}\n"
            f"{self.get_long_term_summary()}"
        )

    async def generate_daily_summary(
        self,
        date: Optional[str] = None,
    ) -> DailySummary:
        """Create (or replace) the mid-term summary for ``date`` and persist it.

        Args:
            date: Day in ``YYYY-MM-DD`` form; defaults to today in UTC.

        Returns:
            The summary that was stored.

        Re-running for a date that already has a summary replaces that entry
        in place instead of appending a duplicate.
        """
        if date is None:
            date = datetime.now(timezone.utc).strftime("%Y-%m-%d")

        # TODO: query the database for all of the day's decisions and trade
        # results. Until then this is a placeholder built from the short-term
        # deque, which holds at most ``max_short_term`` decisions.
        decisions = list(self.short_term)
        summary = DailySummary(
            date=date,
            market_trend="震荡",
            key_events=[],
            total_decisions=len(decisions),
            buy_count=sum(1 for d in decisions if d.decision_type == "BUY"),
            sell_count=sum(1 for d in decisions if d.decision_type == "SELL"),
            hold_count=sum(1 for d in decisions if d.decision_type == "HOLD"),
            executed_trades=0,
            win_rate=0.0,
            total_pnl=0.0,
            lessons_learned="待生成",
        )

        # Replace an existing entry for the same day (keeping its position) so
        # repeated runs don't push older days out of the window.
        for index, existing in enumerate(self.mid_term):
            if existing.date == date:
                self.mid_term[index] = summary
                break
        else:
            self.mid_term.append(summary)

        self._trim_mid_term()
        self._save_to_disk()

        logger.info("memory_system.daily_summary_generated", date=date)
        return summary

    def update_long_term_statistics(self) -> None:
        """Refresh long-term statistics and persist them.

        Currently only bumps ``last_updated``; the statistics themselves are
        not recomputed yet.
        """
        # TODO: query historical trades from the database and compute metrics.
        self.long_term.last_updated = datetime.now(timezone.utc)
        self._save_to_disk()

        logger.info("memory_system.long_term_statistics_updated")

    def _trim_mid_term(self) -> None:
        """Drop the oldest daily summaries beyond ``max_mid_term_days``."""
        excess = len(self.mid_term) - self.max_mid_term_days
        if excess > 0:
            del self.mid_term[:excess]

    def _write_json(self, filename: str, payload: Any) -> None:
        """Atomically write ``payload`` as UTF-8 JSON to ``memory_dir/filename``.

        The data is written to a ``.tmp`` sibling first and then moved over the
        target with ``Path.replace``, so an interrupted write never leaves a
        truncated file in place of the previous good one.
        """
        target = self.memory_dir / filename
        tmp_path = target.with_name(target.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        tmp_path.replace(target)

    def _save_to_disk(self) -> None:
        """Persist all three tiers to JSON files (best-effort; errors are logged)."""
        try:
            self._write_json("short_term.json", [ctx.to_dict() for ctx in self.short_term])
            self._write_json("mid_term.json", [summary.to_dict() for summary in self.mid_term])
            self._write_json("long_term.json", self.long_term.to_dict())

            logger.debug("memory_system.saved_to_disk")

        except Exception as exc:
            logger.error("memory_system.save_failed", error=str(exc))

    def _load_json_file(self, filename: str, parser: Any) -> Any:
        """Read ``memory_dir/filename`` and return ``parser(decoded_json)``.

        Returns None when the file does not exist, or when reading/decoding/
        parsing fails (the failure is logged as a warning).
        """
        path = self.memory_dir / filename
        if not path.exists():
            return None
        try:
            with open(path, encoding="utf-8") as f:
                return parser(json.load(f))
        except Exception as exc:
            logger.warning("memory_system.load_failed", error=str(exc), file=filename)
            return None

    def _load_from_disk(self) -> None:
        """Restore all three tiers from disk.

        Each tier is loaded independently, so a missing or corrupt file only
        affects its own tier and the others still load.
        """
        short_term = self._load_json_file(
            "short_term.json",
            lambda data: [DecisionContext.from_dict(item) for item in data],
        )
        if short_term is not None:
            # deque(maxlen=...) keeps only the newest max_short_term entries.
            self.short_term.extend(short_term)

        mid_term = self._load_json_file(
            "mid_term.json",
            lambda data: [DailySummary.from_dict(item) for item in data],
        )
        if mid_term is not None:
            self.mid_term = mid_term
            # The file may predate a smaller max_mid_term_days setting.
            self._trim_mid_term()

        long_term = self._load_json_file("long_term.json", StrategyStatistics.from_dict)
        if long_term is not None:
            self.long_term = long_term

        logger.debug(
            "memory_system.loaded_from_disk",
            short_term=len(self.short_term),
            mid_term=len(self.mid_term),
        )


# Public API: the memory manager plus its three record types.
__all__ = ["TradingMemory", "DecisionContext", "DailySummary", "StrategyStatistics"]
