from __future__ import annotations

import asyncio
from collections import defaultdict

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

# Router holding this module's WebSocket endpoint(s); every route on it is tagged "ws".
router = APIRouter(tags=["ws"])


class RoomWsManager:
    """Track which WebSocket connections are subscribed to each room and fan out snapshots.

    Sockets are grouped per ``room_id``. A room's entry is dropped as soon as its
    last socket is removed, so the mapping only holds rooms with at least one
    subscriber. An ``asyncio.Lock`` serialises every mutation and read of the
    mapping across coroutines.
    """

    def __init__(self) -> None:
        # room_id -> sockets currently subscribed to that room.
        self._room_sockets: dict[str, set[WebSocket]] = defaultdict(set)
        self._lock = asyncio.Lock()

    def _discard_locked(self, room_id: str, ws: WebSocket) -> None:
        """Drop *ws* from *room_id*; the caller must already hold ``self._lock``."""
        sockets = self._room_sockets.get(room_id)
        if not sockets:
            return
        sockets.discard(ws)
        if not sockets:
            # Forget empty rooms so the mapping does not accumulate dead keys.
            self._room_sockets.pop(room_id, None)

    async def add(self, room_id: str, ws: WebSocket) -> None:
        """Subscribe *ws* to *room_id* (a no-op if it is already subscribed)."""
        async with self._lock:
            self._room_sockets[room_id].add(ws)

    async def remove(self, room_id: str, ws: WebSocket) -> None:
        """Unsubscribe *ws* from *room_id*.

        Safe to call repeatedly or for a room/socket that is not registered.
        """
        async with self._lock:
            self._discard_locked(room_id, ws)

    @staticmethod
    async def _try_send(ws: WebSocket, message: dict) -> bool:
        """Send *message* as JSON over *ws*; return False instead of raising on failure."""
        try:
            await ws.send_json(message)
        except Exception:
            # Best effort: a failed send marks the socket as dead; the caller prunes it.
            return False
        return True

    async def broadcast_snapshot(self, room_id: str) -> None:
        """Build the current snapshot of *room_id* and push it to every subscriber.

        The message sent is ``{"type": "room_snapshot", "payload": <snapshot>}``.
        Sends run concurrently, so one slow client cannot delay delivery to the
        rest of the room. Any socket whose send raises is unsubscribed, and that
        error is not propagated. Exceptions from ``build_room_snapshot`` itself
        do propagate to the caller.
        """
        # Imported lazily -- presumably to avoid a circular import with
        # app.routers.stream; confirm before hoisting to module level.
        from app.routers.stream import build_room_snapshot

        snap = await build_room_snapshot(room_id)
        message = {"type": "room_snapshot", "payload": snap}
        async with self._lock:
            # Snapshot the membership so it may change while the sends are awaited.
            sockets = list(self._room_sockets.get(room_id, ()))
        if not sockets:
            return

        delivered = await asyncio.gather(
            *(self._try_send(ws, message) for ws in sockets)
        )
        dead = [ws for ws, ok in zip(sockets, delivered) if not ok]
        if dead:
            # One lock acquisition for all failures instead of one per socket.
            async with self._lock:
                for ws in dead:
                    self._discard_locked(room_id, ws)


# Module-level registry used by ws_room() and broadcast_room_snapshot(). Its state
# is an in-memory dict, so it only knows about sockets connected to this process.
ws_manager = RoomWsManager()


async def broadcast_room_snapshot(room_id: str) -> None:
    """Push a freshly built snapshot of *room_id* to all of its WebSocket subscribers.

    Convenience entry point for other modules, so they do not need to reach
    into the shared ``ws_manager`` instance directly.
    """
    manager = ws_manager
    await manager.broadcast_snapshot(room_id)


@router.websocket("/ws/room/{room_id}")
async def ws_room(websocket: WebSocket, room_id: str):
    """Subscribe a client to *room_id* and hold the connection until it goes away.

    After accepting, the socket is registered with ``ws_manager`` and a fresh
    snapshot is broadcast to the whole room, including the newcomer. Incoming
    text frames are read and discarded. The socket is always unregistered on
    exit, whether the client disconnects or an error occurs.
    """
    await websocket.accept()
    await ws_manager.add(room_id, websocket)
    try:
        # Inside the try so a failure while building/sending the initial snapshot
        # still unregisters this socket in the finally clause below.
        await ws_manager.broadcast_snapshot(room_id)
        while True:
            # Clients may optionally send pings; for now we only need to keep the
            # connection open, and reading is what surfaces the disconnect.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        await ws_manager.remove(room_id, websocket)
