from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import json
import asyncio
from typing import List
from datetime import datetime

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from database import init_database, get_db, Message
from agents import OfficeManager  
from data_sources import AStockDataManager
from scheduler import ConversationScheduler
import config

app = FastAPI(title="A股AI分析室", version="1.0.0")

# CORS settings.
# Allowed origins can be narrowed by defining CORS_ALLOW_ORIGINS (a list of
# origin strings) in config; when it is absent the original allow-everything
# behaviour is kept.
# NOTE(review): a "*" origin list combined with allow_credentials=True lets any
# site issue credentialed requests (Starlette echoes the caller's Origin in
# that case) — restrict CORS_ALLOW_ORIGINS for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=getattr(config, "CORS_ALLOW_ORIGINS", ["*"]),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class WebSocketManager:
    """Track connected WebSocket clients and broadcast JSON messages to them."""

    def __init__(self):
        # Clients whose handshake was accepted in ``connect``.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register ``websocket`` for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Client connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket``; calling it for an unknown socket is a no-op."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            # Log only when something was actually removed, so a socket dropped
            # by broadcast cleanup and then by its endpoint is logged once.
            print(f"Client disconnected. Total connections: {len(self.active_connections)}")

    async def broadcast_message(self, message: dict):
        """Send ``message`` as JSON text to every registered client.

        Non-JSON values are stringified (``default=str``) and non-ASCII text is
        kept as-is. Clients whose send raises are unregistered afterwards.
        """
        if not self.active_connections:
            return

        message_json = json.dumps(message, ensure_ascii=False, default=str)
        disconnected = []

        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where other tasks may call connect()/disconnect() and mutate the live
        # list, which would make a direct iteration skip clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message_json)
            except Exception as e:
                print(f"Error sending message to client: {e}")
                disconnected.append(connection)

        # Drop the clients that failed.
        for connection in disconnected:
            self.disconnect(connection)

# Global components: module-level singletons shared by every HTTP handler and
# WebSocket connection in this file.
# NOTE(review): these are constructed at import time, so any I/O done by their
# constructors happens on import — confirm that is intended.
office_manager = OfficeManager()
data_manager = AStockDataManager()
websocket_manager = WebSocketManager()
# The scheduler receives the office and data managers here; its
# websocket_manager attribute is attached later, in startup_event.
conversation_scheduler = ConversationScheduler(office_manager, data_manager)

@app.on_event("startup")
async def startup_event():
    """Initialise the application when the server starts.

    Creates the database, wires the WebSocket manager into the conversation
    scheduler and awaits the base market/concept data refresh. A refresh
    failure is logged rather than aborting startup. The background task that
    loads extra data and starts the conversation scheduler is launched
    whether or not the base refresh succeeded.
    """
    print("启动A股AI分析室...")

    # Initialize database
    init_database()

    # Let the scheduler broadcast to connected WebSocket clients.
    conversation_scheduler.websocket_manager = websocket_manager

    # Load the initial A-share data; failure is non-fatal.
    try:
        await data_manager.refresh_market_data()
        await data_manager.refresh_concept_data()
        print("A股基础数据加载成功")
    except Exception as e:
        print(f"Warning: Could not load initial data: {e}")

    # Launch outside the try above: this task also starts the conversation
    # scheduler, which must not depend on the initial refresh succeeding.
    # Keep a strong reference, because the event loop only holds weak
    # references to tasks and an unreferenced task may be garbage-collected
    # before it finishes.
    app.state.additional_data_task = asyncio.create_task(load_additional_data())

async def load_additional_data():
    """Fetch the market summary in the background, then start the scheduler.

    A failure to load the summary is only logged. The conversation scheduler
    is started either way.
    """
    try:
        await data_manager.get_market_summary()
    except Exception as exc:
        print(f"Warning: Could not load additional data: {exc}")
    else:
        print("A股扩展数据加载完成")

    conversation_scheduler.start()
    print("A股AI分析室启动完成！")

@app.on_event("shutdown")
async def shutdown_event():
    """Stop the conversation scheduler as the server shuts down."""
    scheduler = conversation_scheduler
    print("Shutting down Multi-Agent AI Office...")
    scheduler.stop()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client for the lifetime of its connection.

    Registers the client, sends the initial snapshot, then reads text frames.
    Each frame is decoded as JSON and dispatched to ``handle_client_message``.
    Frames that are not valid JSON objects are logged and skipped instead of
    tearing the connection down. The client is unregistered exactly once,
    whatever ends the loop.
    """
    await websocket_manager.connect(websocket)

    try:
        # Send initial data
        await send_initial_data(websocket)

        # Keep connection alive and handle client messages
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError as e:
                print(f"Ignoring malformed client message: {e}")
                continue
            # handle_client_message reads fields with .get(), so it needs a dict.
            if not isinstance(message, dict):
                print("Ignoring client message that is not a JSON object")
                continue

            await handle_client_message(message, websocket)

    except WebSocketDisconnect:
        pass  # normal client hang-up; cleanup happens in finally
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        websocket_manager.disconnect(websocket)

async def send_initial_data(websocket: WebSocket):
    """Bootstrap a newly connected client with the current office snapshot.

    Sends three JSON text frames, in order: ``agents_info`` (one entry per
    agent), ``market_data`` (the data manager's latest market data) and
    ``chat_history`` (the 20 most recent stored messages, oldest first).
    Errors are logged and swallowed so a failed bootstrap does not propagate
    to the caller.
    """
    try:
        # Agent roster. Optional attributes are read with getattr so agents
        # that lack them still serialise.
        agents_data = {}
        for role, agent in office_manager.agents.items():
            last_spoke = getattr(agent, 'last_spoke_time', None)
            agents_data[role] = {
                "name": agent.name,
                "role": role,
                "model": agent.model,
                "personality": agent.personality,
                "expertise": agent.expertise,
                "quirks": getattr(agent, 'quirks', []),
                "languages": getattr(agent, 'languages', None),
                "confidence_level": getattr(agent, 'confidence_level', None),
                "assertiveness": getattr(agent, 'current_assertiveness', None),
                "last_spoke": last_spoke.isoformat() if last_spoke else None,
                "status": "active"
            }

        agents_info = {
            "type": "agents_info",
            "data": agents_data
        }

        print(f"Sending agents info: {len(agents_data)} agents")
        for role, data in agents_data.items():
            print(f"  - {role}: {data['name']} ({data['model']})")
        await websocket.send_text(json.dumps(agents_info, ensure_ascii=False, default=str))

        # Latest market data
        market_data = await data_manager.get_latest_market_data()
        market_message = {
            "type": "market_data",
            "data": market_data
        }
        await websocket.send_text(json.dumps(market_message, ensure_ascii=False, default=str))

        # Recent history goes out as a single batch frame. The session is
        # closed in ``finally`` (and before the network send), so a failing
        # query or serialisation can no longer leak it.
        db = next(get_db())
        try:
            recent_messages = (
                db.query(Message)
                .order_by(Message.timestamp.desc())
                .limit(20)
                .all()
            )
            history = [
                {
                    "agent_name": m.agent_name,
                    "agent_role": m.agent_role,
                    "content": m.content,
                    "message_type": m.message_type,
                    "timestamp": m.timestamp.isoformat() if m.timestamp else None
                }
                for m in reversed(recent_messages)
            ]
        finally:
            db.close()

        history_payload = {
            "type": "chat_history",
            "data": history
        }
        # default=str, like the other frames, so a non-JSON column value
        # cannot abort the whole frame.
        await websocket.send_text(json.dumps(history_payload, ensure_ascii=False, default=str))

    except Exception as e:
        print(f"Error sending initial data: {e}")

async def handle_client_message(message: dict, websocket: WebSocket):
    """Dispatch one decoded client message according to its ``type`` field.

    Supported types:
    - ``manual_trigger``: start a conversation about ``topics``.
    - ``get_status``: reply to this client with an ``office_status`` frame.
    - ``translate_message``: translate ``content`` into ``target_language``
      with the office translator (when one is configured) and reply with a
      ``translation_result`` frame.
    - ``user_input``: store the user's text, broadcast it to every client as a
      ``chat_message`` and trigger an agent discussion about it.
    Messages of any other type are ignored.
    """
    message_type = message.get("type")

    if message_type == "manual_trigger":
        # Manual conversation trigger
        context = {
            "trigger_type": "manual",
            "urgency_level": 0.6,
            "topics": message.get("topics", ["general_discussion"])
        }
        await conversation_scheduler._trigger_conversation(context)

    elif message_type == "get_status":
        # Report the office status back to the requesting client only.
        status = {
            "type": "office_status",
            "data": {
                "office_state": office_manager.office_state,
                "agents_count": len(office_manager.agents),
                "conversation_active": conversation_scheduler.conversation_active,
                "last_message_time": conversation_scheduler.last_message_time.isoformat() if conversation_scheduler.last_message_time else None
            }
        }
        await websocket.send_text(json.dumps(status, ensure_ascii=False, default=str))

    elif message_type == "translate_message":
        # Handle translation request
        content = message.get("content", "")
        target_language = message.get("target_language", "")
        message_id = message.get("message_id", "")

        if office_manager.translator and content and target_language:
            translation_result = await office_manager.translator.translate_message(content, target_language)
            response = {
                "type": "translation_result",
                "data": {
                    "message_id": message_id,
                    "original_content": content,
                    "translated_content": translation_result["translated_content"],
                    "target_language": target_language,
                    "status": translation_result["status"]
                }
            }
            await websocket.send_text(json.dumps(response, ensure_ascii=False, default=str))

    elif message_type == "user_input":
        # Handle user input trigger
        user_content = message.get("content", "")
        input_type = message.get("input_type", "general")  # news, coin_info, general, etc.

        # Ignore empty or non-text payloads; .strip() would crash on non-str.
        if not isinstance(user_content, str) or not user_content.strip():
            return

        # Save the user message; always release the session, even if commit fails.
        db = next(get_db())
        try:
            user_message = Message(
                agent_name="User",
                agent_role="user",
                content=user_content,
                message_type="user_input",
                context_data=json.dumps({"input_type": input_type}),
                interest_score=1.0
            )
            db.add(user_message)
            db.commit()
        finally:
            db.close()

        # Broadcast user message to all clients
        user_msg_data = {
            "type": "chat_message",
            "data": {
                "agent_name": "User",
                "agent_role": "user",
                "content": user_content,
                "message_type": "user_input",
                "timestamp": datetime.utcnow().isoformat()
            }
        }
        await websocket_manager.broadcast_message(user_msg_data)

        # Trigger agent discussion about user input
        context = {
            "trigger_type": "user_input",
            "user_content": user_content,
            "input_type": input_type,
            "urgency_level": 0.8,
            "topics": ["user_topic", "discussion", input_type]
        }
        await conversation_scheduler._trigger_conversation(context)

@app.get("/")
async def root():
    """Health check: report service status, version and agent count.

    The version is read from the FastAPI app, so it cannot drift from the
    value declared in ``FastAPI(version=...)``. The previous hard-coded
    "0.5.0" contradicted it.
    """
    return {
        "message": "Multi-Agent AI Office API",
        "version": app.version,
        "status": "running",
        "agents_count": len(office_manager.agents),
        "timestamp": datetime.utcnow().isoformat()
    }

@app.get("/api/agents")
async def get_agents():
    """Return public profile information for every agent, keyed by role.

    Optional runtime attributes (confidence, assertiveness, last spoke time)
    are read with ``getattr`` defaults, so an agent missing one of them yields
    ``None`` for that field instead of a 500 for the whole response.
    """
    agents = {}
    for role, agent in office_manager.agents.items():
        last_spoke = getattr(agent, "last_spoke_time", None)
        agents[role] = {
            "name": agent.name,
            "role": role,
            "personality": agent.personality,
            "expertise": agent.expertise,
            "confidence_level": getattr(agent, "confidence_level", None),
            "assertiveness": getattr(agent, "current_assertiveness", None),
            "last_spoke": last_spoke.isoformat() if last_spoke else None,
        }
    return {"agents": agents}

@app.get("/api/messages")
async def get_messages(limit: int = 50, db=Depends(get_db)):
    """Return up to ``limit`` of the most recent stored messages, oldest first.

    ``limit`` is clamped to ``[0, 1000]``. Some backends (e.g. SQLite) treat a
    negative LIMIT as "no upper bound", and an unbounded request would
    serialise the entire message table in one response.
    """
    max_limit = 1000
    limit = max(0, min(limit, max_limit))

    messages = db.query(Message).order_by(Message.timestamp.desc()).limit(limit).all()

    return {
        "messages": [
            {
                "id": msg.id,
                "agent_name": msg.agent_name,
                "agent_role": msg.agent_role,
                "content": msg.content,
                "message_type": msg.message_type,
                "interest_score": msg.interest_score,
                "timestamp": msg.timestamp.isoformat() if msg.timestamp else None
            }
            # Query is newest-first; reverse so the client gets chronological order.
            for msg in reversed(messages)
        ]
    }

@app.get("/api/market-data")
async def get_market_data():
    """Return the data manager's latest market data under ``market_data``."""
    return {"market_data": await data_manager.get_latest_market_data()}

@app.get("/api/news")
async def get_news():
    """Return the data manager's latest news items under ``news``."""
    return {"news": await data_manager.get_latest_news()}

@app.get("/api/office-state")
async def get_office_state():
    """Report the office state, conversation activity and silence duration."""
    scheduler = conversation_scheduler
    last_time = scheduler.last_message_time
    return {
        "office_state": office_manager.office_state,
        "conversation_active": scheduler.conversation_active,
        "last_message_time": None if not last_time else last_time.isoformat(),
        "silence_duration": scheduler._get_silence_duration(),
    }

@app.post("/api/trigger-conversation")
async def trigger_conversation(request: dict):
    """Start a conversation from an HTTP request and echo the context used.

    ``topics`` defaults to ``["general_discussion"]`` and ``urgency_level``
    to ``0.6`` when absent from the request body.
    """
    context = dict(
        trigger_type="api_trigger",
        urgency_level=request.get("urgency_level", 0.6),
        topics=request.get("topics", ["general_discussion"]),
    )
    await conversation_scheduler._trigger_conversation(context)
    return {"message": "Conversation triggered", "context": context}

@app.get("/api/low-cost-mode")
async def get_low_cost_mode():
    """Report whether low cost mode is on and which model it uses."""
    is_on = config.LOW_COST_MODE
    state_word = "enabled" if is_on else "disabled"
    return {
        "low_cost_mode": is_on,
        "low_cost_model": config.LOW_COST_MODEL,
        "message": f"Low cost mode is {state_word}",
    }

@app.post("/api/low-cost-mode")
async def toggle_low_cost_mode(request: dict):
    """Switch low cost mode on or off and re-point every agent's API client.

    ``request["enabled"]`` is normalised to a real ``bool`` before it is stored
    in ``config.LOW_COST_MODE``. String flags such as ``"false"`` or ``"0"``
    therefore disable the mode instead of being treated as truthy. Each agent's
    ``actual_model`` becomes ``config.LOW_COST_MODEL`` when enabled, or its own
    ``model`` otherwise. Its OpenAI-compatible client is rebuilt against the
    Volces endpoint for the Volces-hosted models listed below, or against the
    Yunwu endpoint for everything else.
    """
    from openai import OpenAI

    raw_enabled = request.get("enabled", False)
    if isinstance(raw_enabled, str):
        enabled = raw_enabled.strip().lower() in {"1", "true", "yes", "on"}
    else:
        enabled = bool(raw_enabled)

    # Update the config module
    config.LOW_COST_MODE = enabled

    # Models served by the Volces endpoint; anything else goes to Yunwu.
    volces_models = {
        "doubao-seed-1-6-thinking-250615",
        "deepseek-r1-250528",
        "doubao-seed-1-6-flash",
        "doubao-seed-1-6-flash-250615",
    }

    for agent in office_manager.agents.values():
        agent.actual_model = config.LOW_COST_MODEL if enabled else agent.model

        if agent.actual_model in volces_models:
            api_key, base_url = config.VOLCES_API_KEY, config.VOLCES_API_BASE
        else:
            api_key, base_url = config.YUNWU_API_KEY, config.YUNWU_API_BASE
        agent.client = OpenAI(api_key=api_key, base_url=base_url)

    return {
        "message": f"Low cost mode {'enabled' if enabled else 'disabled'}",
        "low_cost_mode": enabled,
        "low_cost_model": config.LOW_COST_MODEL
    }

if __name__ == "__main__":
    # Run the app directly with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)