from __future__ import annotations

import hashlib
import json
import secrets
from datetime import timedelta
from typing import Any, Optional

from sqlalchemy import delete, distinct, func
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.config import settings
from app.models.entities import AudienceSession, JoinToken, Participant, Room
from app.services.time_utils import utc_now


def hash_line_user_id(line_user_id: str) -> str:
    """Return the hex SHA-256 digest of a LINE user id, or "" when there is none.

    Surrounding whitespace is stripped before hashing, so " U1 " and "U1"
    hash identically. ``None``, empty and whitespace-only inputs all map to "".
    """
    cleaned = line_user_id.strip() if line_user_id else ""
    if not cleaned:
        return ""
    digest = hashlib.sha256(cleaned.encode("utf-8"))
    return digest.hexdigest()


async def ensure_room_exists(session: AsyncSession, room_id: str) -> Room:
    """Return the Room whose primary key is ``room_id``, creating it if absent.

    A missing room is created with its name set to the id and flushed so
    that it is visible to later queries in the same session.
    """
    found = await session.get(Room, room_id)
    if not found:
        found = Room(id=room_id, name=room_id)
        session.add(found)
        await session.flush()
    return found


async def get_or_create_participant(
    session: AsyncSession,
    *,
    room_id: str,
    line_user_id: str = "",
    source: str = "",
    display_name: str = "",
) -> Participant:
    """Return the participant for ``line_user_id`` in ``room_id``, creating one if needed.

    The room is created on demand. When a non-empty LINE user id is given,
    an existing participant with the same hashed id in the same room is
    reused. Its ``last_seen_at`` is bumped, and ``display_name``/``source``
    are overwritten only when non-empty values are passed.

    Without a LINE user id, a new participant is always created, with a
    random ``anonymous:`` placeholder hash. Anonymous participants are
    therefore never reused by this function.

    Args:
        session: Active async DB session; changes are flushed, not committed.
        room_id: Room the participant belongs to.
        line_user_id: Raw LINE user id (hashed before storage/lookup).
        source: Optional origin tag to record.
        display_name: Optional display name to record.

    Returns:
        The existing or newly created ``Participant``.
    """
    line_user_id_hash = hash_line_user_id(line_user_id)
    await ensure_room_exists(session, room_id)
    now = utc_now()

    participant: Optional[Participant] = None
    if line_user_id_hash:
        result = await session.exec(
            select(Participant).where(
                Participant.room_id == room_id,
                Participant.line_user_id_hash == line_user_id_hash,
            )
        )
        participant = result.first()

    if participant:
        participant.last_seen_at = now
        if display_name:
            participant.display_name = display_name
        if source:
            participant.source = source
        session.add(participant)
        await session.flush()
        return participant

    participant = Participant(
        id=f"p-{secrets.token_hex(8)}",
        room_id=room_id,
        line_user_id_hash=line_user_id_hash or f"anonymous:{secrets.token_hex(12)}",
        source=source,
        display_name=display_name,
        # One timestamp for both fields so a new participant's join and
        # last-seen times are exactly equal, not microseconds apart.
        joined_at=now,
        last_seen_at=now,
    )
    session.add(participant)
    await session.flush()
    return participant


async def get_session_record(
    session: AsyncSession,
    session_id: Optional[str],
) -> Optional[AudienceSession]:
    """Look up a live audience session and slide its expiry window forward.

    Returns ``None`` when ``session_id`` is empty or unknown. An expired
    record is deleted (and flushed) and ``None`` is returned. Otherwise the
    record's ``last_seen_at`` is set to now, its ``expires_at`` is extended
    by ``settings.audience_session_ttl_seconds``, and it is returned.
    """
    if not session_id:
        return None
    record = await session.get(AudienceSession, session_id)
    if not record:
        return None

    current = utc_now()
    if record.expires_at > current:
        # Still valid: refresh activity and push expiry out by a full TTL.
        lifetime = timedelta(seconds=settings.audience_session_ttl_seconds)
        record.last_seen_at = current
        record.expires_at = current + lifetime
        session.add(record)
        await session.flush()
        return record

    # Expired: clean it up eagerly instead of leaving it for a purge.
    await session.delete(record)
    await session.flush()
    return None


async def create_or_reuse_session(
    session: AsyncSession,
    *,
    participant: Participant,
    existing_session_id: Optional[str] = None,
    source: str = "",
) -> AudienceSession:
    """Return an audience session for ``participant``, reusing the presented one if it matches.

    If ``existing_session_id`` refers to an unexpired session owned by the
    same participant in the same room, that session is reused. It is
    refreshed once (``last_seen_at`` = now, ``expires_at`` = now + TTL, and
    ``source`` when non-empty). An expired presented session is deleted.

    A session that belongs to someone else is left untouched. Its expiry
    is not extended, and a new session with a random id is created instead.

    Args:
        session: Active async DB session; changes are flushed, not committed.
        participant: Participant the session must belong to.
        existing_session_id: Session id presented by the client, if any.
        source: Optional origin tag to record.

    Returns:
        The reused or newly created ``AudienceSession``.
    """
    now = utc_now()
    ttl = timedelta(seconds=settings.audience_session_ttl_seconds)

    # Look the record up directly rather than through get_session_record,
    # which would refresh (and extend) it before we know it is ours.
    existing: Optional[AudienceSession] = None
    if existing_session_id:
        existing = await session.get(AudienceSession, existing_session_id)
    if existing is not None and existing.expires_at <= now:
        await session.delete(existing)
        await session.flush()
        existing = None

    if (
        existing is not None
        and existing.room_id == participant.room_id
        and existing.participant_id == participant.id
    ):
        existing.last_seen_at = now
        existing.expires_at = now + ttl
        if source:
            existing.source = source
        session.add(existing)
        await session.flush()
        return existing

    rec = AudienceSession(
        session_id=secrets.token_urlsafe(24),
        participant_id=participant.id,
        room_id=participant.room_id,
        line_user_id_hash=participant.line_user_id_hash,
        source=source,
        created_at=now,
        last_seen_at=now,
        expires_at=now + ttl,
    )
    session.add(rec)
    await session.flush()
    return rec


async def issue_join_token(
    session: AsyncSession,
    *,
    participant: Participant,
    event_id: str = "",
    source: str = "",
    metadata: Optional[dict[str, Any]] = None,
) -> JoinToken:
    """Create, persist (flush) and return a new random join token for ``participant``.

    The token inherits the participant's id, room and hashed LINE id.
    ``metadata`` is stored as a JSON string (``{}`` when omitted), with
    non-ASCII characters kept verbatim. The token expires after
    ``settings.join_token_ttl_seconds``.
    """
    issued_at = utc_now()
    lifetime = timedelta(seconds=settings.join_token_ttl_seconds)
    join_token = JoinToken(
        token=secrets.token_urlsafe(32),
        participant_id=participant.id,
        room_id=participant.room_id,
        line_user_id_hash=participant.line_user_id_hash,
        source=source,
        event_id=event_id,
        metadata_json=json.dumps(metadata or {}, ensure_ascii=False),
        created_at=issued_at,
        expires_at=issued_at + lifetime,
    )
    session.add(join_token)
    await session.flush()
    return join_token


async def consume_join_token(session: AsyncSession, token_value: str) -> tuple[JoinToken, Participant]:
    """Record a use of a join token and return it together with its participant.

    The token is not invalidated. Its ``use_count`` is incremented and
    ``last_used_at`` is stamped, so the same token can be consumed again
    until it expires.

    Raises:
        ValueError("invalid_join_token"): token unknown or expired.
        ValueError("participant_not_found"): token's participant is missing.
    """
    join_token = await session.get(JoinToken, token_value)
    checked_at = utc_now()
    if not (join_token and join_token.expires_at > checked_at):
        raise ValueError("invalid_join_token")

    owner = await session.get(Participant, join_token.participant_id)
    if not owner:
        raise ValueError("participant_not_found")

    join_token.last_used_at = checked_at
    join_token.use_count += 1
    session.add(join_token)

    owner.last_seen_at = checked_at
    session.add(owner)

    await session.flush()
    return join_token, owner


def _active_session_filters(room_id: str):
    """Build WHERE conditions selecting a room's currently active audience sessions.

    "Active" means not yet expired and seen within the last
    ``settings.audience_session_active_window_seconds``.

    Returns:
        ``(now, conditions)``, where ``now`` is the reference timestamp
        used and ``conditions`` is a list of SQL expressions to AND together.
    """
    now = utc_now()
    window = timedelta(seconds=settings.audience_session_active_window_seconds)
    conditions = [
        AudienceSession.room_id == room_id,
        AudienceSession.expires_at > now,
        AudienceSession.last_seen_at >= now - window,
    ]
    return now, conditions


async def purge_expired_sessions(session: AsyncSession, room_id: Optional[str] = None) -> None:
    """Bulk-delete audience sessions whose ``expires_at`` is at or before now.

    When ``room_id`` is non-empty, only that room's sessions are purged;
    otherwise expired sessions in every room are removed. Flushes afterwards.
    """
    conditions = [AudienceSession.expires_at <= utc_now()]
    if room_id:
        conditions.append(AudienceSession.room_id == room_id)
    await session.exec(delete(AudienceSession).where(*conditions))
    await session.flush()


async def delete_session_record(session: AsyncSession, session_id: Optional[str]) -> Optional[AudienceSession]:
    """Delete the audience session with ``session_id`` and return the removed record.

    Returns ``None`` (deleting nothing) when ``session_id`` is empty or
    no such session exists. Expiry is not checked.
    """
    doomed = await session.get(AudienceSession, session_id) if session_id else None
    if not doomed:
        return None
    await session.delete(doomed)
    await session.flush()
    return doomed


async def count_room_sessions(session: AsyncSession, room_id: str) -> int:
    """Return the number of active audience sessions in ``room_id``.

    Expired sessions for the room are purged first. "Active" is defined by
    ``_active_session_filters``: not expired and seen within the configured
    activity window.
    """
    await purge_expired_sessions(session, room_id)
    _, filters = _active_session_filters(room_id)
    # COUNT in the database instead of loading every matching row just to len() it.
    result = await session.exec(
        select(func.count(AudienceSession.session_id)).where(*filters)
    )
    return int(result.one())


async def count_room_active_participants(session: AsyncSession, room_id: str) -> int:
    """Return the number of distinct participants with an active session in ``room_id``.

    Expired sessions for the room are purged first. A participant holding
    several active sessions is counted once.
    """
    await purge_expired_sessions(session, room_id)
    _, filters = _active_session_filters(room_id)
    # COUNT(DISTINCT ...) in SQL rather than fetching every id and deduping in
    # Python. NOTE(review): unlike set(), this ignores NULL participant_ids;
    # sessions created in this module always set one.
    result = await session.exec(
        select(func.count(distinct(AudienceSession.participant_id))).where(*filters)
    )
    return int(result.one())
