from __future__ import annotations

from typing import Optional
from time import perf_counter

from backend.infra import get_async_session
from backend.market.toolkit.toolkit import TushareToolkit
from backend.market.storage import (
    upsert_hsgt_moneyflow,
    upsert_index_daily,
    upsert_index_weight,
    upsert_stock_basic,
    upsert_stock_daily,
    upsert_stock_daily_basic,
    upsert_stock_moneyflow,
    upsert_financial_indicator,
)
from backend.infra.logging import log_context, logger
from backend.infra.metrics import increment_counter, observe_histogram
from tenacity import retry, stop_after_attempt, wait_exponential


class MarketDataSyncService:
    """Coordinate the Tushare toolkit with database storage to sync market data.

    Each ``sync_*`` method fetches one or more DataFrames from the toolkit,
    upserts them through the storage helpers inside a single session, commits,
    and records a ``market_sync_duration_seconds`` histogram sample.

    NOTE(review): the toolkit methods are invoked without ``await``; if they
    perform network I/O they will run synchronously inside these coroutines --
    confirm this is acceptable for the hosting event loop.
    """

    def __init__(self, toolkit: TushareToolkit, *, session_factory=None) -> None:
        """Store the toolkit and the session factory used by every sync method.

        Args:
            toolkit: Data source used to fetch market DataFrames.
            session_factory: Callable returning an async context manager that
                yields a session. When falsy, the result of
                ``get_async_session()`` is used instead.
        """
        self.toolkit = toolkit
        self.session_factory = session_factory or get_async_session()
        # Bucket boundaries for the duration histogram; elapsed values come
        # from ``perf_counter`` and are therefore expressed in seconds.
        self._duration_buckets = (0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 30.0, 60.0)

    def _observe_duration(self, elapsed: float, **labels: str) -> None:
        """Record ``elapsed`` seconds in the shared sync-duration histogram."""
        observe_histogram(
            "market_sync_duration_seconds",
            elapsed,
            buckets=self._duration_buckets,
            **labels,
        )

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10))
    async def sync_stock_master(self) -> None:
        """Fetch the stock master list and upsert it.

        Attempted up to three times with exponential backoff. Every attempt
        logs, bumps the success/failure counter and records its duration.

        Raises:
            Exception: Any fetch/storage error is logged and re-raised so the
                retry decorator can act on it.
        """
        with log_context(task="sync_stock_master"):
            start_time = perf_counter()
            logger.info("sync_stock_master:start")
            try:
                df = self.toolkit.fetch_stock_basic()
                async with self.session_factory() as session:
                    await upsert_stock_basic(session, df)
                    await session.commit()
                logger.info("sync_stock_master:done", rows=len(df))
                increment_counter("market_sync_master_success_total")
            except Exception as exc:  # noqa: BLE001
                logger.exception("sync_stock_master:error", error=str(exc))
                increment_counter("market_sync_master_failed_total")
                raise
            finally:
                self._observe_duration(perf_counter() - start_time, task="stock_master")

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10))
    async def sync_stock_daily(self, ts_code: str, days: int = 120) -> None:
        """Sync daily quotes, daily basics, money flow and indicators for one stock.

        All four datasets are fetched first, then the non-empty ones are
        upserted in a single session and committed together.

        Args:
            ts_code: Stock code passed through to the toolkit.
            days: Look-back window; daily-basic rows are capped at 60 and
                financial-indicator rows at 8.

        Raises:
            Exception: Any fetch/storage error is logged and re-raised so the
                retry decorator can act on it.
        """
        with log_context(task="sync_stock_daily", ts_code=ts_code):
            logger.info("sync_stock_daily:start", ts_code=ts_code, days=days)
            start_time = perf_counter()
            try:
                daily = self.toolkit.get_recent_daily(ts_code, days)
                basic = self.toolkit.get_daily_basic(ts_code, limit=min(days, 60))
                moneyflow = self.toolkit.get_moneyflow(ts_code, days=days)
                indicators = self.toolkit.get_financial_indicator(ts_code, limit=min(days, 8))
                async with self.session_factory() as session:
                    if not daily.empty:
                        await upsert_stock_daily(session, daily)
                    if not basic.empty:
                        await upsert_stock_daily_basic(session, basic)
                    if not moneyflow.empty:
                        await upsert_stock_moneyflow(session, moneyflow)
                    if not indicators.empty:
                        await upsert_financial_indicator(session, indicators)
                    await session.commit()
                logger.info(
                    "sync_stock_daily:done",
                    ts_code=ts_code,
                    daily_rows=len(daily),
                    basic_rows=len(basic),
                    moneyflow_rows=len(moneyflow),
                    indicator_rows=len(indicators),
                )
                increment_counter("market_sync_stock_success_total", ts_code=ts_code)
            except Exception as exc:  # noqa: BLE001
                logger.exception("sync_stock_daily:error", ts_code=ts_code, error=str(exc))
                increment_counter("market_sync_stock_failed_total", ts_code=ts_code)
                raise
            finally:
                elapsed = perf_counter() - start_time
                self._observe_duration(elapsed, task="stock_daily", ts_code=ts_code)
                # The same elapsed value is also reported under the
                # "financial_indicator" task label; kept as-is so existing
                # series keep receiving samples.
                self._observe_duration(elapsed, task="financial_indicator", ts_code=ts_code)

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10))
    async def sync_index_daily(self, ts_code: str, days: int = 120) -> None:
        """Sync daily quotes for one index.

        Args:
            ts_code: Index code passed through to the toolkit.
            days: Look-back window passed through to the toolkit.

        Raises:
            Exception: Any fetch/storage error is logged and re-raised so the
                retry decorator can act on it.
        """
        with log_context(task="sync_index_daily", ts_code=ts_code):
            logger.info("sync_index_daily:start", ts_code=ts_code, days=days)
            start_time = perf_counter()
            try:
                index_df = self.toolkit.get_index_daily(ts_code, days)
                async with self.session_factory() as session:
                    if not index_df.empty:
                        await upsert_index_daily(session, index_df)
                    await session.commit()
                logger.info("sync_index_daily:done", ts_code=ts_code, rows=len(index_df))
                increment_counter("market_sync_index_success_total", ts_code=ts_code)
            except Exception as exc:  # noqa: BLE001
                logger.exception("sync_index_daily:error", ts_code=ts_code, error=str(exc))
                increment_counter("market_sync_index_failed_total", ts_code=ts_code)
                raise
            finally:
                self._observe_duration(
                    perf_counter() - start_time,
                    task="index_daily",
                    ts_code=ts_code,
                )

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10))
    async def sync_hsgt_moneyflow(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> None:
        """Sync HSGT money-flow rows for an optional date range.

        Args:
            start_date: Range start, forwarded unchanged to the toolkit.
            end_date: Range end, forwarded unchanged to the toolkit.

        Raises:
            Exception: Any fetch/storage error is logged and re-raised so the
                retry decorator can act on it.
        """
        with log_context(task="sync_hsgt_moneyflow", start=start_date, end=end_date):
            logger.info("sync_hsgt_moneyflow:start", start=start_date, end=end_date)
            start_time = perf_counter()
            try:
                df = self.toolkit.get_hsgt_moneyflow(start_date, end_date)
                async with self.session_factory() as session:
                    if not df.empty:
                        await upsert_hsgt_moneyflow(session, df)
                    await session.commit()
                logger.info("sync_hsgt_moneyflow:done", rows=len(df))
                increment_counter("market_sync_hsgt_success_total")
            except Exception as exc:  # noqa: BLE001
                logger.exception("sync_hsgt_moneyflow:error", error=str(exc))
                increment_counter("market_sync_hsgt_failed_total")
                raise
            finally:
                self._observe_duration(perf_counter() - start_time, task="hsgt")

    async def sync_multi_stocks(
        self,
        ts_codes: list[str],
        days: int = 120,
        *,
        batch_size: int = 10,
    ) -> None:
        """Sync daily quotes, daily basics and money flow for many stocks in batches.

        Each batch shares one session and one commit, so a failure anywhere in
        a batch leaves that batch uncommitted and aborts the remaining ones.

        NOTE(review): unlike ``sync_stock_daily`` this path does not sync
        financial indicators and is not retried -- confirm that is intended.

        Args:
            ts_codes: Stock codes to sync.
            days: Look-back window; daily-basic rows are capped at 60.
            batch_size: Number of stocks per session/commit; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Any fetch/storage error is logged and re-raised.
        """
        if batch_size < 1:
            # A zero step makes range() fail with an opaque message and a
            # negative step would silently sync nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        for offset in range(0, len(ts_codes), batch_size):
            batch = ts_codes[offset : offset + batch_size]
            with log_context(task="sync_multi_stocks", batch=batch):
                logger.info("sync_multi_stocks:batch_start", batch=batch)
                start_time = perf_counter()
                try:
                    async with self.session_factory() as session:
                        for ts_code in batch:
                            daily = self.toolkit.get_recent_daily(ts_code, days)
                            basic = self.toolkit.get_daily_basic(ts_code, limit=min(days, 60))
                            moneyflow = self.toolkit.get_moneyflow(ts_code, days=days)
                            if not daily.empty:
                                await upsert_stock_daily(session, daily)
                            if not basic.empty:
                                await upsert_stock_daily_basic(session, basic)
                            if not moneyflow.empty:
                                await upsert_stock_moneyflow(session, moneyflow)
                        await session.commit()
                    logger.info("sync_multi_stocks:batch_done", batch=batch)
                except Exception as exc:  # noqa: BLE001
                    logger.exception("sync_multi_stocks:batch_error", batch=batch, error=str(exc))
                    raise
                finally:
                    # Record the duration for failed batches too, matching the
                    # single-task sync methods.
                    self._observe_duration(
                        perf_counter() - start_time,
                        task="multi_stock_batch",
                        batch_size=str(len(batch)),
                    )

    async def sync_all(
        self,
        *,
        index_codes: Optional[list[str]] = None,
        stock_codes: Optional[list[str]] = None,
        days: int = 120,
        batch_size: int = 10,
    ) -> None:
        """Run the full sync pipeline sequentially.

        Order: stock master, then the given stocks (batched), then daily
        quotes and weights for the given indexes, then HSGT money flow. The
        first step that raises aborts the rest.

        Args:
            index_codes: Index codes to sync; skipped when empty or ``None``.
            stock_codes: Stock codes to sync; skipped when empty or ``None``.
            days: Look-back window forwarded to the stock and index syncs.
            batch_size: Batch size forwarded to ``sync_multi_stocks``.
        """
        with log_context(task="sync_all"):
            logger.info("sync_all:start", stocks=len(stock_codes or []), indexes=len(index_codes or []))
            await self.sync_stock_master()
            if stock_codes:
                await self.sync_multi_stocks(stock_codes, days, batch_size=batch_size)
            if index_codes:
                for code in index_codes:
                    await self.sync_index_daily(code, days)
                await self.sync_index_weights(index_codes)
            await self.sync_hsgt_moneyflow()
            logger.info("sync_all:done")

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10))
    async def sync_index_weights(
        self,
        index_codes: list[str],
        *,
        trade_date: Optional[str] = None,
    ) -> None:
        """Sync constituent weights for the given indexes on one trade date.

        All indexes share one session and one commit. Indexes for which the
        toolkit returns an empty frame are logged and skipped.

        Args:
            index_codes: Index codes to sync; an empty list is a no-op.
            trade_date: Reference date string. When ``None``, the toolkit's
                last trading day is used, formatted as ``%Y%m%d``; if the
                toolkit reports none, a warning is logged and nothing is synced.

        Raises:
            Exception: Any fetch/storage error is logged and re-raised so the
                retry decorator can act on it.
        """
        if not index_codes:
            return
        with log_context(task="sync_index_weight", index_codes=index_codes):
            logger.info("sync_index_weight:start", count=len(index_codes))
            start_time = perf_counter()
            try:
                reference = trade_date
                if reference is None:
                    last_day = self.toolkit.last_trading_day()
                    if last_day is None:
                        logger.warning("sync_index_weight:no_trade_day")
                        return
                    reference = last_day.strftime("%Y%m%d")
                async with self.session_factory() as session:
                    for code in index_codes:
                        weights = self.toolkit.get_index_weight(code, reference)
                        if weights.empty:
                            logger.info("sync_index_weight:empty", index_code=code, trade_date=reference)
                            continue
                        await upsert_index_weight(session, weights)
                    await session.commit()
                logger.info("sync_index_weight:done", trade_date=reference)
            except Exception as exc:  # noqa: BLE001
                logger.exception("sync_index_weight:error", error=str(exc))
                raise
            finally:
                # Observed on every exit path so failed attempts are timed as
                # well, matching the other sync methods.
                self._observe_duration(perf_counter() - start_time, task="index_weight")


async def bootstrap_default_service(token: str) -> MarketDataSyncService:
    """Create a sync service for ``token`` and run an initial stock-master sync.

    The service is built with its default session factory. Any error raised by
    the initial ``sync_stock_master`` call propagates to the caller.
    """
    service = MarketDataSyncService(TushareToolkit(token))
    await service.sync_stock_master()
    return service
