# server.py
import asyncio
import uuid
from dataclasses import dataclass, field
from typing import Optional, Dict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# ASGI application object; the HTTP rendezvous routes, the WebSocket relay and
# the static web client below are all registered on it.
app = FastAPI(title="KISS-PGPfone rendezvous/relay + web")

# Permissive CORS for development: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is the
# combination the FastAPI CORS documentation says must not be used. None of
# the endpoints visible in this file read cookies or auth headers, so
# credentials can probably be switched off -- confirm against the web client
# before tightening this for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the web client: files from the relative "web" directory are exposed
# under the /web URL prefix.
app.mount("/web", StaticFiles(directory="web", html=True), name="web")

@app.get("/", include_in_schema=False)
async def index_root():
    """Answer requests for the site root with the web client's index page."""
    landing_page = "web/index.html"
    return FileResponse(landing_page)

# ---- Rendezvous/Relay ----
@dataclass
class Party:
    """State the server keeps for one side (caller or callee) of a session."""
    # Public key string posted by this party; None until it has been posted.
    # The value is stored as received -- nothing in this file validates it.
    pubkey_b64: Optional[str] = None
    # WebSocket registered for this party by the relay endpoint; None while
    # no socket is attached.
    ws: Optional[WebSocket] = None

@dataclass
class Session:
    """A rendezvous session pairing one caller with one callee."""
    # Session identifier; also the key under which the session is stored in
    # SESSIONS.
    id: str
    # Each session gets its own fresh Party objects (default_factory), so no
    # state is shared between sessions.
    caller: Party = field(default_factory=Party)
    callee: Party = field(default_factory=Party)

# In-memory registry of live sessions, keyed by session id.
# NOTE(review): in the code visible here an entry is only removed when a
# WebSocket handler finishes, so a session that never gets a WebSocket
# connection appears to stay in this dict indefinitely -- confirm whether an
# expiry mechanism exists elsewhere.
SESSIONS: Dict[str, Session] = {}

class PubKey(BaseModel):
    """Request body accepted by the offer and answer endpoints."""
    # Public key as sent by the client; stored unchanged on the session.
    pubkey: str

@app.get("/healthz", include_in_schema=False)
async def healthz():
    """Health-check endpoint: always replies with the plain-text body ``ok``."""
    body = "ok"
    return PlainTextResponse(body)

@app.post("/session")
def create_session():
    """Register a new, empty session and return its id.

    Returns:
        A dict of the form ``{"session_id": <32-character hex string>}``.
    """
    session = Session(id=uuid.uuid4().hex)
    SESSIONS[session.id] = session
    return {"session_id": session.id}

@app.post("/session/{sid}/offer")
def post_offer(sid: str, pk: PubKey):
    """Record the caller's public key on session *sid*.

    A later offer for the same session replaces the stored key.

    Raises:
        HTTPException: 404 when no session with that id is registered.
    """
    session = SESSIONS.get(sid)
    if session is None:
        raise HTTPException(404, "unknown session")

    session.caller.pubkey_b64 = pk.pubkey
    return {"status": "ok"}

@app.post("/session/{sid}/answer")
def post_answer(sid: str, pk: PubKey):
    """Record the callee's public key on session *sid*.

    A later answer for the same session replaces the stored key.

    Raises:
        HTTPException: 404 when no session with that id is registered.
    """
    session = SESSIONS.get(sid)
    if session is None:
        raise HTTPException(404, "unknown session")

    session.callee.pubkey_b64 = pk.pubkey
    return {"status": "ok"}

@app.get("/session/{sid}/peerkey")
def get_peerkey(sid: str, role: str):
    """Return the public key posted by the *other* party of session *sid*.

    Args:
        sid: Session id taken from the URL path.
        role: The requester's own role, ``"caller"`` or ``"callee"``.

    Returns:
        JSON ``{"peer_pubkey": ...}``; the value is null while the peer has
        not posted a key yet.

    Raises:
        HTTPException: 404 for an unknown session, 400 for any other role.
            The session is checked first, so an unknown session with a bad
            role reports 404.
    """
    session = SESSIONS.get(sid)
    if session is None:
        raise HTTPException(404, "unknown session")
    if role not in ("caller", "callee"):
        raise HTTPException(400, "role must be 'caller' or 'callee'")

    # A caller asks for the callee's key and vice versa.
    peer = session.callee if role == "caller" else session.caller
    return JSONResponse({"peer_pubkey": peer.pubkey_b64})

async def _poll_until(condition, interval: float = 0.05):
    """Sleep in *interval*-second steps until ``condition()`` returns true."""
    while not condition():
        await asyncio.sleep(interval)


@app.websocket("/ws/{sid}")
async def media_ws(ws: WebSocket, sid: str, role: str):
    """Relay binary WebSocket frames between the two parties of session *sid*.

    The socket is registered as the caller's or callee's connection, waits
    until the other party's socket is registered, and then forwards every
    frame it receives to that peer until either side goes away.

    Each handler forwards only the frames arriving on its *own* socket; the
    opposite direction is forwarded by the peer's handler. Starting both
    directions from both handlers would put two concurrent readers on every
    socket and two forwarders on every direction.

    Close codes used by this endpoint:
        4404: unknown session id.
        4400: ``role`` is neither ``"caller"`` nor ``"callee"``.
        4409: that role already has a live socket in this session.

    Args:
        ws: The incoming WebSocket connection.
        sid: Session id taken from the URL path.
        role: ``"caller"`` or ``"callee"``, supplied by the client.
    """
    s = SESSIONS.get(sid)
    if s is None:
        # NOTE(review): close() before accept() rejects the handshake, so the
        # client may never observe this custom code -- confirm that is intended.
        await ws.close(code=4404)
        return

    await ws.accept()
    if role == "caller":
        party, other = s.caller, s.callee
    elif role == "callee":
        party, other = s.callee, s.caller
    else:
        await ws.close(code=4400)
        return

    if party.ws is not None:
        # Overwriting the slot would orphan the socket that is already
        # registered (and possibly mid-call), so refuse the newcomer instead.
        await ws.close(code=4409)
        return

    party.ws = ws
    tasks = []
    try:
        # The wait is inside the try block so that a cancellation while
        # waiting still unregisters this socket.
        # NOTE(review): there is still no timeout here, and a client that
        # disconnects while waiting goes unnoticed because nothing reads its
        # socket until the peer arrives.
        await _poll_until(lambda: other.ws is not None)
        peer_ws = other.ws

        tasks = [
            # Forward our inbound frames to the peer.
            asyncio.create_task(pipe(ws, peer_ws)),
            # Finishes once the peer's handler has unregistered its socket.
            asyncio.create_task(_poll_until(lambda: other.ws is not peer_ws)),
        ]
        # Whichever ends first (our side failing or disconnecting, or the
        # peer leaving) ends the relay for this connection.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        if tasks:
            for t in tasks:
                t.cancel()
            # Drain the tasks so that cancellations and errors such as
            # WebSocketDisconnect are retrieved instead of being reported by
            # asyncio as never-retrieved task exceptions.
            await asyncio.gather(*tasks, return_exceptions=True)
        try:
            await ws.close()
        except Exception:
            # Deliberate best effort: the socket may already be closed.
            pass
        if party.ws is ws:
            party.ws = None
        if (s.caller.ws is None) and (s.callee.ws is None):
            SESSIONS.pop(sid, None)

async def pipe(src: WebSocket, dst: WebSocket):
    """Forward binary frames from *src* to *dst*, one at a time, forever.

    The loop has no exit of its own: the coroutine ends only when receiving
    from *src* or sending to *dst* raises, or when its task is cancelled.
    """
    receive_frame = src.receive_bytes
    send_frame = dst.send_bytes
    while True:
        await send_frame(await receive_frame())

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when this file is run
    # directly. uvicorn is imported only on this path, so importing the
    # module elsewhere does not need it.
    import uvicorn
    # host 0.0.0.0 listens on every network interface; the port is 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
