# app/api/v1/endpoints/ws_chat.py
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, BackgroundTasks
from app.api import deps
from app.services import message_service
from app.services.notification_service import create_chat_notification
from app.services.media_service import MediaService, Settings
from app.models.message import Message
from app.models.user import User
from typing import Dict
from sqlalchemy.orm import Session
from app.core.logging_config import setup_logging
from app.core.time import now
import json
import traceback
import base64
import os
import uuid
import io
from fastapi import UploadFile

# Router on which the WebSocket endpoints of this module are registered.
router = APIRouter()

# Module logger, obtained from the project's logging setup helper.
logger = setup_logging()

# Directory name for chat images, joined onto the storage base path below.
CHAT_IMAGES_PATH = "chat_images"

# Media service built from the storage base path and base URL in Settings.
settings = Settings()
media_service = MediaService(settings.STORAGE_BASE_PATH, settings.BASE_URL)

# Create the chat image directory at import time if it does not exist yet.
_chat_images_dir = os.path.join(settings.STORAGE_BASE_PATH, CHAT_IMAGES_PATH)
os.makedirs(_chat_images_dir, exist_ok=True)

# Module-level user_id -> WebSocket mapping. Nothing visible in this module
# reads or writes it; ConnectionManager keeps its own dictionary.
active_connections: Dict[int, WebSocket] = {}

class ConnectionManager:
    """Track one active WebSocket per user and deliver JSON messages to it.

    Sockets are keyed by user id, so a later ``connect`` for the same user
    replaces the socket stored by an earlier one.
    """

    def __init__(self):
        # user_id -> WebSocket
        self.active_connections: Dict[int, WebSocket] = {}

    async def connect(self, websocket: WebSocket, user_id: int):
        """Accept the WebSocket handshake and register it for ``user_id``."""
        await websocket.accept()
        self.active_connections[user_id] = websocket
        logger.info(f"User {user_id} connected to WebSocket")

    def disconnect(self, user_id: int):
        """Forget the socket registered for ``user_id``.

        Does nothing (and logs nothing) when no socket is registered.
        """
        if self.active_connections.pop(user_id, None) is not None:
            logger.info(f"User {user_id} disconnected from WebSocket")

    async def send_message(self, message: dict, user_id: int):
        """Send ``message`` as JSON text to the socket registered for ``user_id``.

        Args:
            message: JSON-serializable dictionary to deliver.
            user_id: Id of the user who should receive the message.

        Returns:
            True if the message was written to the user's socket, False if
            the user has no registered socket or writing to it failed.

        Raises:
            TypeError: if ``message`` cannot be serialized to JSON.
        """
        websocket = self.active_connections.get(user_id)
        if websocket is None:
            logger.debug(f"User {user_id} not connected, message not delivered")
            return False

        # Serialize outside the try block so an unserializable message is
        # still reported as a programming error, not as a failed delivery.
        payload = json.dumps(message)
        try:
            await websocket.send_text(payload)
        except Exception as e:
            # The registered socket may already be closed without having been
            # unregistered yet. The exception type depends on the server
            # stack, so any failure here is treated as "not delivered": the
            # caller can fall back to its offline path instead of failing.
            logger.warning(f"Failed to deliver message to user {user_id}: {str(e)}")
            # Only drop the entry if it still refers to the socket that failed.
            if self.active_connections.get(user_id) is websocket:
                del self.active_connections[user_id]
            return False

        logger.debug(f"Message sent to user {user_id}")
        return True


# Shared manager instance used by the chat endpoint.
manager = ConnectionManager()

# WebSocket endpoint added for debugging
@router.websocket("/ws/debug")
async def websocket_debug(websocket: WebSocket):
    """
    Unauthenticated echo endpoint for debugging WebSocket connectivity.

    Accepts the connection and answers every received text frame with the
    same text prefixed by "Echo: ". The loop ends when the client
    disconnects or any other error is raised; both outcomes are logged.
    """
    logger.info("Debug WebSocket connection attempt")
    await websocket.accept()
    logger.info("Debug WebSocket connection accepted")

    try:
        while True:
            incoming = await websocket.receive_text()
            logger.info(f"Debug WebSocket received: {incoming}")
            echo = f"Echo: {incoming}"
            await websocket.send_text(echo)
    except WebSocketDisconnect:
        # Normal termination: the client closed the connection.
        logger.info("Debug WebSocket client disconnected")
    except Exception as exc:
        # Anything else is logged with its traceback and ends the handler.
        logger.error(f"Debug WebSocket error: {str(exc)}")
        stack_trace = traceback.format_exc()
        logger.error(stack_trace)

# Prefix that marks a message's content as an image URL.
_IMAGE_PREFIX = "[IMAGE]:"

# The client-supplied image format becomes part of the stored file name and
# path, so only short alphanumeric values are accepted.
_MAX_IMAGE_FORMAT_LENGTH = 10


def _parse_chat_payload(raw: str):
    """Split a raw text frame into ``(payload, content, message_type)``.

    A frame holding a JSON object supplies the ``content`` and ``type`` keys.
    Anything else (invalid JSON, or valid JSON that is not an object, such as
    ``123`` or ``"hi"``) is treated as a plain text message whose content is
    the raw frame; ``payload`` is then an empty dict.
    """
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        payload = None
    if not isinstance(payload, dict):
        return {}, raw, "text"
    return payload, payload.get("content", ""), payload.get("type", "text")


def _is_safe_image_format(image_format) -> bool:
    """Return True if ``image_format`` is safe to use as a file extension."""
    return (
        isinstance(image_format, str)
        and 0 < len(image_format) <= _MAX_IMAGE_FORMAT_LENGTH
        and image_format.isascii()
        and image_format.isalnum()
    )


def _store_chat_image(payload: dict):
    """Decode the base64 image in ``payload``, upload it and return its URL.

    Raises:
        ValueError: if ``image_format`` is not a short alphanumeric string.
        Exception: whatever base64 decoding or the media service raises.
    """
    image_format = payload.get("image_format", "jpeg")
    if not _is_safe_image_format(image_format):
        # Reject values such as "../x" before they reach the file path.
        raise ValueError("Unsupported image format")

    filename = f"{uuid.uuid4()}.{image_format}"
    file_path = f"{CHAT_IMAGES_PATH}/{filename}"

    image_binary = base64.b64decode(payload.get("image_data", ""))

    # Wrap the decoded bytes in a file-like object for UploadFile.
    upload_file = UploadFile(
        filename=filename,
        file=io.BytesIO(image_binary),
        content_type=f"image/{image_format}"
    )

    image_url = media_service.upload_file(upload_file, file_path)
    logger.info(f"Uploaded chat image via websocket to: {image_url}")
    return image_url


def _serialize_message(message, content: str) -> dict:
    """Build the JSON-serializable dict sent to both chat participants."""
    is_image = content.startswith(_IMAGE_PREFIX)
    return {
        "id": message.id,
        "sender_id": message.sender_id,
        "receiver_id": message.receiver_id,
        "content": message.content,
        "sent_at": message.sent_at.isoformat(),
        "is_read": message.is_read,
        "is_image": is_image,
        "image_url": content.replace(_IMAGE_PREFIX, "", 1) if is_image else None
    }


@router.websocket("/ws/chat/{receiver_id}")
async def websocket_chat(
    websocket: WebSocket,
    receiver_id: int,
    background_tasks: BackgroundTasks,
    db: Session = Depends(deps.get_db),
    current_user = Depends(deps.get_current_ws_user)
):
    """
    Real-time chat between the authenticated user and the user ``receiver_id``.

    Each received text frame is either a JSON object (keys ``content``,
    ``type`` and, for images, ``image_data`` / ``image_format``) or plain
    text. The message is stored, forwarded to the receiver if they have a
    registered socket, and echoed back to the sender as confirmation. When
    the receiver cannot be reached, a chat notification task is queued.

    Per-message failures are reported to the sender as ``{"error": ...}`` and
    do not close the connection.
    """
    logger.info(f"WebSocket connection attempt from user {current_user.id} to friend {receiver_id}")
    await manager.connect(websocket, current_user.id)

    try:
        while True:
            data = await websocket.receive_text()
            try:
                payload, content, message_type = _parse_chat_payload(data)

                # Handle image uploads via WebSocket
                if message_type == "image" and "image_data" in payload:
                    try:
                        content = f"{_IMAGE_PREFIX}{_store_chat_image(payload)}"
                    except Exception as e:
                        logger.error(f"Error processing image: {str(e)}")
                        await websocket.send_text(json.dumps({
                            "error": f"Failed to process image: {str(e)}"
                        }))
                        continue

                # Reject non-string content before anything is stored; the
                # code below relies on string methods of ``content``.
                if not isinstance(content, str):
                    await websocket.send_text(json.dumps({
                        "error": "Message content must be a string"
                    }))
                    continue

                # Create and save message
                message = Message(
                    sender_id=current_user.id,
                    receiver_id=receiver_id,
                    content=content,
                    sent_at=now(),
                    is_read=False
                )
                db.add(message)
                db.commit()
                db.refresh(message)

                message_dict = _serialize_message(message, content)

                # Try to send to receiver if connected
                was_sent = await manager.send_message(message_dict, receiver_id)

                # Queue a notification if the recipient is not connected
                if not was_sent:
                    sender_name = f"{current_user.first_name} {current_user.last_name}"

                    # Image messages get a fixed notification text.
                    notification_content = (
                        "Sent an image" if content.startswith(_IMAGE_PREFIX) else content
                    )

                    # NOTE(review): FastAPI documents BackgroundTasks as running
                    # after an HTTP response is sent. Confirm that tasks added
                    # here are actually executed for a WebSocket route, and
                    # that sharing ``db`` with the task is safe.
                    background_tasks.add_task(
                        create_chat_notification,
                        db=db,
                        user_id=receiver_id,
                        sender_name=sender_name,
                        message_content=notification_content,
                        message_id=message.id
                    )
                    logger.info(f"Added notification task for message {message.id} to user {receiver_id}")

                # Always send back to sender as confirmation
                await websocket.send_text(json.dumps(message_dict))

            except Exception as e:
                error_msg = f"Error processing message: {str(e)}"
                logger.error(error_msg)
                logger.error(traceback.format_exc())
                # The session outlives a single message; roll back so a
                # failed commit does not break every later message.
                db.rollback()
                await websocket.send_text(json.dumps({"error": error_msg}))

    except WebSocketDisconnect:
        logger.info(f"Client #{current_user.id} disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        logger.error(traceback.format_exc())
    finally:
        # Unregister only if the registry still points at this socket: the
        # same user may have opened a newer connection that replaced it.
        # ``finally`` also covers cancellation of the handler.
        if manager.active_connections.get(current_user.id) is websocket:
            manager.disconnect(current_user.id)