"""交易日调度：盘前同步、盘中工作流、收盘结算。"""

from __future__ import annotations

import asyncio
from datetime import date, datetime, time
from typing import Awaitable, Callable, Optional

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.date import DateTrigger

from backend.infra.logging import logger
from backend.infra.metrics import increment_counter
from backend.agent.workflow import AgentWorkflow, WorkflowContext
from backend.agent.providers import AccountProvider, SnapshotProvider
from backend.infra.notifications import FeishuWebhookNotifier
from backend.market.toolkit.toolkit import TushareToolkit
from .factory import create_async_scheduler

# Callable signatures injected into TradingDayScheduler.
NowFunc = Callable[[], datetime]  # clock source; TradingDayScheduler defaults it to datetime.now
AsyncVoid = Callable[[], Awaitable[None]]  # zero-argument coroutine function (market_sync_job)
AsyncBool = Callable[[], Awaitable[bool]]  # zero-argument coroutine returning a flag (not referenced in this file)
AsyncSettlement = Callable[[date], Awaitable[None]]  # settlement_job, called with the settlement date
GetNextTradingDay = Callable[[date], Optional[date]]  # next trading day after a reference date, or None


class TradingDayScheduler:
    """Coordinate the key jobs of a trading day.

    Three jobs are driven by this class:

    * pre-market data sync at 07:30 (``market_sync_job``);
    * agent workflow runs every 20 minutes from 09:00 to 14:40;
    * settlement at 15:30 (``settlement_job``).

    :meth:`start` registers Monday-Friday cron jobs, while
    :meth:`start_with_calendar` asks the trading calendar for the current or
    next trading day and registers one-shot jobs for that day only. Every job
    re-checks ``is_trading_day`` before doing any work.
    """

    # Daily job times. They are naive ``time`` values combined with the target
    # date; presumably interpreted in the scheduler's configured timezone.
    _MARKET_SYNC_TIME = time(7, 30)
    _SETTLEMENT_TIME = time(15, 30)
    # 09:00, 09:20, ..., 14:40 -- the same instants as the cron expression
    # ``hour="9-14", minute="*/20"`` registered by start().
    _WORKFLOW_TIMES = tuple(time(hour, minute) for hour in range(9, 15) for minute in (0, 20, 40))

    def __init__(
        self,
        *,
        agent_workflow: AgentWorkflow,
        snapshot_provider: SnapshotProvider,
        account_provider: AccountProvider,
        market_sync_job: AsyncVoid,
        settlement_job: AsyncSettlement,
        is_trading_day: Callable[[], bool],
        notifier: Optional[FeishuWebhookNotifier] = None,
        scheduler: Optional[AsyncIOScheduler] = None,
        get_next_trading_day: Optional[GetNextTradingDay] = None,
        now_fn: Optional[NowFunc] = None,
    ) -> None:
        """Store the collaborators used by the scheduled jobs.

        Args:
            agent_workflow: Workflow whose ``run(snapshot=..., account=...)`` is awaited.
            snapshot_provider: Awaitable factory producing the market snapshot.
            account_provider: Awaitable factory producing the account state.
            market_sync_job: Coroutine function run by the pre-market sync.
            settlement_job: Coroutine function called with the settlement date.
            is_trading_day: Zero-argument predicate checked before every job.
            notifier: Optional alert sink; alerts are skipped when ``None``.
            scheduler: APScheduler instance; built via ``create_async_scheduler()`` when omitted.
            get_next_trading_day: Calendar lookup; required by :meth:`start_with_calendar`.
            now_fn: Clock used for planning; defaults to ``datetime.now``.
        """
        self._workflow = agent_workflow
        self._snapshot_provider = snapshot_provider
        self._account_provider = account_provider
        self._market_sync_job = market_sync_job
        self._settlement_job = settlement_job
        self._is_trading_day = is_trading_day
        self._notifier = notifier
        self._scheduler = scheduler or create_async_scheduler()
        self._get_next_trading_day = get_next_trading_day
        self._calendar_job_ids: list[str] = []
        self._now: NowFunc = now_fn or datetime.now

    def start(self) -> None:
        """Register Monday-Friday cron jobs, then block running the event loop."""
        self._scheduler.add_job(
            self._schedule_market_sync,
            "cron",
            day_of_week="0-4",
            hour=self._MARKET_SYNC_TIME.hour,
            minute=self._MARKET_SYNC_TIME.minute,
        )
        self._scheduler.add_job(
            self._schedule_workflow_run, "cron", day_of_week="0-4", hour="9-14", minute="*/20"
        )
        self._scheduler.add_job(
            self._schedule_settlement,
            "cron",
            day_of_week="0-4",
            hour=self._SETTLEMENT_TIME.hour,
            minute=self._SETTLEMENT_TIME.minute,
        )
        self._run_forever()

    def start_with_calendar(self) -> None:
        """Plan jobs from the trading calendar, re-planning daily at 06:00, then block.

        Raises:
            ValueError: If no ``get_next_trading_day`` callable was supplied.
        """
        if not self._get_next_trading_day:
            raise ValueError("get_next_trading_day 未提供，无法启用交易日历模式")
        self._schedule_next_trading_day()
        self._scheduler.add_job(self._schedule_next_trading_day, "cron", hour=6, minute=0)
        self._run_forever()

    def _run_forever(self) -> None:
        """Start the scheduler and block on the event loop until Ctrl+C."""
        self._scheduler.start()
        try:
            # NOTE(review): asyncio.get_event_loop() without a running loop is
            # deprecated in recent Python versions; kept because APScheduler 3.x
            # AsyncIOScheduler presumably binds to the loop returned by the same call.
            asyncio.get_event_loop().run_forever()
        except KeyboardInterrupt:
            self._scheduler.shutdown()

    def plan_for_trading_day(self, trading_day: date) -> None:
        """Replace previously planned calendar jobs with one-shot jobs for ``trading_day``.

        Time points that are already at or before ``now_fn()`` are skipped, so
        planning mid-day only registers the remaining jobs.
        """
        self._clear_calendar_jobs()
        now = self._now()
        schedule_points = [
            (self._schedule_market_sync, self._MARKET_SYNC_TIME),
            (self._schedule_settlement, self._SETTLEMENT_TIME),
        ]
        schedule_points.extend((self._schedule_workflow_run, t) for t in self._WORKFLOW_TIMES)

        for callback, clock_time in schedule_points:
            run_dt = datetime.combine(trading_day, clock_time)
            if run_dt <= now:
                continue
            job = self._scheduler.add_job(callback, trigger=DateTrigger(run_date=run_dt))
            self._calendar_job_ids.append(job.id)
        logger.info(
            "scheduler:calendar_planned",
            trading_day=str(trading_day),
            jobs=len(self._calendar_job_ids),
        )

    def _resolve_target_day(self, now: datetime) -> Optional[date]:
        """Return the trading day whose jobs should be planned at ``now``.

        Today is chosen while it is a trading day and its last job (settlement)
        is still ahead. Otherwise the result is ``get_next_trading_day(today)``,
        which may be ``None``. Requires ``get_next_trading_day`` to be set.
        """
        today = now.date()
        # Cut off at settlement time rather than market close: a restart between
        # the close and settlement must still plan today's settlement job.
        # NOTE(review): is_trading_day takes no date, so it presumably checks the
        # real current date -- keep now_fn aligned with the wall clock.
        if self._is_trading_day() and now.time() < self._SETTLEMENT_TIME:
            return today
        return self._get_next_trading_day(today)

    def _schedule_next_trading_day(self) -> None:
        """Plan jobs for the current or next trading day (run at start-up and daily at 06:00)."""
        if not self._get_next_trading_day:
            return
        now = self._now()
        target_day = self._resolve_target_day(now)
        if not target_day:
            logger.warning("scheduler:calendar_missing_next", reference=str(now.date()))
            return
        self.plan_for_trading_day(target_day)

    def _clear_calendar_jobs(self) -> None:
        """Remove every one-shot job registered by plan_for_trading_day."""
        for job_id in self._calendar_job_ids:
            try:
                self._scheduler.remove_job(job_id)
            except Exception:  # noqa: BLE001
                # Presumably the job already fired and the scheduler dropped it;
                # nothing left to remove.
                continue
        self._calendar_job_ids.clear()

    async def _send_alert(self, title: str, message: str, *, level: str) -> None:
        """Forward an alert to the notifier, if any; a notifier failure is logged, not raised.

        Alerting is best-effort so that a broken webhook cannot mask the
        original failure or turn a handled failure into an exception.
        """
        if not self._notifier:
            return
        try:
            await self._notifier.send_alert(title, message, level=level)
        except Exception as exc:  # noqa: BLE001
            logger.exception("scheduler:alert_failed", title=title, error=str(exc))

    async def run_workflow_once(self) -> Optional[WorkflowContext]:
        """Fetch a snapshot and account state, then run the agent workflow once.

        Returns:
            The workflow result, or ``None`` when today is not a trading day.

        Raises:
            Whatever the providers or the workflow raise. The failure is logged,
            counted and alerted before being re-raised.
        """
        if not self._is_trading_day():
            logger.info("scheduler:workflow_skip", reason="non_trading_day")
            increment_counter("scheduler_workflow_skipped_total")
            return None
        try:
            # Provider failures are workflow failures too: report them the same way.
            snapshot = await self._snapshot_provider()
            account = await self._account_provider()
            result = await self._workflow.run(snapshot=snapshot, account=account)
        except Exception as exc:  # noqa: BLE001
            logger.exception("scheduler:workflow_failed", error=str(exc))
            increment_counter("scheduler_workflow_failed_total")
            await self._send_alert("Trading workflow failed", str(exc), level="ERROR")
            raise
        risk = result.risk
        risk_status = risk.get("status") if risk else None
        logger.info(
            "scheduler:workflow_done",
            broadcast=bool(result.broadcast),
            risk_status=risk_status,
        )
        increment_counter("scheduler_workflow_run_total")
        if risk_status == "FLAGGED":
            increment_counter("scheduler_workflow_flagged_total")
            warnings = ", ".join(risk.get("warnings") or [])
            await self._send_alert(
                "Workflow risk flagged",
                warnings or "Risk node returned FLAGGED",
                level="WARNING",
            )
        return result

    async def run_market_sync_once(self) -> bool:
        """Run the market sync job; return ``True`` on success, ``False`` if skipped or failed."""
        if not self._is_trading_day():
            logger.info("scheduler:market_sync_skip", reason="non_trading_day")
            increment_counter("scheduler_market_sync_skipped_total")
            return False
        try:
            await self._market_sync_job()
        except Exception as exc:  # noqa: BLE001
            logger.exception("scheduler:market_sync_failed", error=str(exc))
            increment_counter("scheduler_market_sync_failed_total")
            await self._send_alert("Market sync failed", str(exc), level="ERROR")
            return False
        logger.info("scheduler:market_sync_done")
        increment_counter("scheduler_market_sync_total")
        return True

    async def run_settlement_once(self) -> bool:
        """Settle today's date (per ``now_fn``); return ``True`` on success, ``False`` if skipped or failed."""
        today = self._now().date()
        if not self._is_trading_day():
            logger.info("scheduler:settlement_skip", reason="non_trading_day")
            increment_counter("scheduler_settlement_skipped_total")
            return False
        try:
            await self._settlement_job(today)
        except Exception as exc:  # noqa: BLE001
            logger.exception("scheduler:settlement_failed", error=str(exc))
            increment_counter("scheduler_settlement_failed_total")
            await self._send_alert("Settlement failed", str(exc), level="ERROR")
            return False
        logger.info("scheduler:settlement_done", date=str(today))
        increment_counter("scheduler_settlement_total")
        return True

    @classmethod
    def from_toolkit(
        cls,
        *,
        toolkit: TushareToolkit,
        agent_workflow: AgentWorkflow,
        snapshot_provider: SnapshotProvider,
        account_provider: AccountProvider,
        market_sync_job: AsyncVoid,
        settlement_job: AsyncSettlement,
        notifier: Optional[FeishuWebhookNotifier] = None,
        scheduler: Optional[AsyncIOScheduler] = None,
        get_next_trading_day: Optional[GetNextTradingDay] = None,
        now_fn: Optional[NowFunc] = None,
    ) -> "TradingDayScheduler":
        """Build a scheduler whose calendar checks are delegated to ``toolkit``.

        ``get_next_trading_day`` overrides the toolkit lookup when given.
        """

        def is_trading_day() -> bool:
            return toolkit.is_trading_day()

        def next_trading_day(reference: date) -> Optional[date]:
            return toolkit.next_trading_day(reference, include_today=False)

        return cls(
            agent_workflow=agent_workflow,
            snapshot_provider=snapshot_provider,
            account_provider=account_provider,
            market_sync_job=market_sync_job,
            settlement_job=settlement_job,
            is_trading_day=is_trading_day,
            notifier=notifier,
            scheduler=scheduler,
            get_next_trading_day=get_next_trading_day or next_trading_day,
            now_fn=now_fn,
        )

    # Scheduler entry points. They are coroutine functions on purpose: APScheduler's
    # AsyncIOExecutor (the AsyncIOScheduler default) awaits coroutine functions on
    # the event loop but runs plain functions in a worker thread, where
    # asyncio.create_task has no running loop. Exceptions propagate to APScheduler,
    # which logs them as job errors.

    async def _schedule_market_sync(self) -> None:
        await self.run_market_sync_once()

    async def _schedule_workflow_run(self) -> None:
        await self.run_workflow_once()

    async def _schedule_settlement(self) -> None:
        await self.run_settlement_once()


__all__ = [
    "TradingDayScheduler",
]
