# client.py
import argparse
import asyncio
import base64
import hmac
import struct
import sys

import httpx
import numpy as np
import sounddevice as sd
import websockets
from cryptography.hazmat.primitives.asymmetric.x25519 import (
    X25519PrivateKey, X25519PublicKey
)
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

# Default base URL of the server; used as the default value of the --api
# command-line option.
API = "http://127.0.0.1:8000"

def b64e(b: bytes) -> str:
    """Return *b* encoded as a standard-alphabet Base64 text string."""
    encoded = base64.b64encode(b)
    return encoded.decode()
def b64d(s: str) -> bytes:
    """Return the raw bytes encoded by the standard-alphabet Base64 string *s*."""
    decoded = base64.b64decode(s)
    return decoded

def derive_keys(my_pub: bytes, peer_pub: bytes, shared: bytes):
    """Derive per-direction keys and nonce prefixes from the shared secret.

    The two public keys are sorted before being mixed into the HKDF ``info``
    string, so both peers expand the same key material regardless of which
    side they are. The peer whose public key sorts first takes the first
    key/prefix for sending and the second for receiving; the other peer takes
    them the opposite way round.

    Args:
        my_pub: This side's raw public key bytes.
        peer_pub: The other side's raw public key bytes.
        shared: The key-agreement output shared by both sides.

    Returns:
        A 4-tuple ``(send_key, recv_key, send_nonce_prefix,
        recv_nonce_prefix)``: two 32-byte keys and two 4-byte prefixes.
    """
    first, second = sorted((my_pub, peer_pub))
    context = b"KISS-PGPfone-v1|" + first + b"|" + second
    kdf = HKDF(algorithm=hashes.SHA256(), length=80, salt=None, info=context)
    okm = kdf.derive(shared)

    # Layout of the expanded material: two 32-byte keys, then two 4-byte
    # nonce prefixes. Bytes 72..79 are derived but not used.
    key_a, key_b = okm[:32], okm[32:64]
    prefix_a, prefix_b = okm[64:68], okm[68:72]

    if my_pub != first:
        return key_b, key_a, prefix_b, prefix_a
    return key_a, key_b, prefix_a, prefix_b

def make_sas(shared: bytes, my_pub: bytes, peer_pub: bytes) -> str:
    """Compute the short authentication string both users compare aloud.

    The string is the first 10 Base32 characters of an HMAC-SHA256, keyed
    with *shared*, over a fixed label followed by the two public keys in
    sorted order. Sorting makes the result the same on both sides.

    Args:
        shared: The key-agreement output shared by both sides.
        my_pub: This side's raw public key bytes.
        peer_pub: The other side's raw public key bytes.

    Returns:
        A 10-character Base32 string.
    """
    first, second = sorted((my_pub, peer_pub))
    transcript = b"KISS-PGPfone-SAS|" + first + b"|" + second
    tag = hmac.new(shared, transcript, digestmod="sha256").digest()
    head = base64.b32encode(tag)[:10]
    return head.decode().rstrip("=")

# Audio framing shared by the capture and playback streams: 16 kHz PCM,
# 320 samples per frame (20 ms at this rate), 2 bytes per int16 sample.
SAMPLE_RATE = 16_000
SAMPLES_PER_FRAME = 320
BYTES_PER_FRAME = SAMPLES_PER_FRAME * 2

def record_frames(q: asyncio.Queue):
    """Start microphone capture and feed raw int16 PCM frames into *q*.

    sounddevice invokes the callback on its own audio thread, and an
    ``asyncio.Queue`` must only be touched from the event-loop thread. When
    this function is called while an event loop is running, every frame is
    therefore handed to that loop with ``call_soon_threadsafe``. Without a
    running loop the queue is filled directly from the callback, as before.

    If the queue is full, the oldest queued frame is discarded to make room.
    An exception escaping the callback would end the stream, so a slow
    consumer must not be allowed to raise ``QueueFull`` there.

    Args:
        q: Queue that receives one ``bytes`` object per captured block.

    Returns:
        The started ``sd.InputStream``. The caller should keep a reference
        to it and close it when capture is no longer needed.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None  # no loop in this thread: fall back to direct puts

    def enqueue(frame: bytes) -> None:
        # Prefer fresh audio over stale audio when the consumer falls behind.
        if q.full():
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
        try:
            q.put_nowait(frame)
        except asyncio.QueueFull:
            pass  # still no room (e.g. a waiting producer got in first): drop

    def cb(indata, frames, time, status):
        # Copy now: the input buffer is only valid during the callback.
        frame = bytes(indata)
        if loop is None:
            enqueue(frame)
            return
        try:
            loop.call_soon_threadsafe(enqueue, frame)
        except RuntimeError:
            # The event loop has been closed; nobody will consume more audio.
            raise sd.CallbackStop from None

    stream = sd.InputStream(
        channels=1,
        samplerate=SAMPLE_RATE,
        dtype="int16",
        blocksize=SAMPLES_PER_FRAME,
        callback=cb,
    )
    stream.start()
    return stream

def playback_frames(q: asyncio.Queue):
    """Start audio playback, pulling one PCM frame from *q* per output block.

    Each callback takes one queued ``bytes`` frame and plays it as mono
    int16 samples. When the queue is empty, silence is played instead. A
    frame whose length differs from what the output block needs is
    truncated or zero-padded. Otherwise the NumPy shape mismatch (or an
    odd byte count) would raise inside the callback and end the stream.

    Args:
        q: Queue holding ``bytes`` frames of int16 mono PCM.

    Returns:
        The started ``sd.OutputStream``. The caller should keep a reference
        to it and close it when playback is no longer needed.
    """
    sample_width = np.dtype(np.int16).itemsize

    def cb(outdata, frames, time, status):
        wanted = frames * sample_width
        # NOTE(review): this runs on the audio thread, while asyncio.Queue is
        # documented as not thread-safe. A coroutine blocked in `await q.put()`
        # may be woken late. Consider a thread-safe buffer for this direction.
        try:
            chunk = q.get_nowait()
        except asyncio.QueueEmpty:
            chunk = b""  # underrun: padded to a full block of silence below
        if len(chunk) != wanted:
            chunk = chunk[:wanted].ljust(wanted, b"\x00")
        outdata[:] = np.frombuffer(chunk, dtype=np.int16).reshape(-1, 1)

    stream = sd.OutputStream(
        channels=1,
        samplerate=SAMPLE_RATE,
        dtype="int16",
        blocksize=SAMPLES_PER_FRAME,
        callback=cb,
    )
    stream.start()
    return stream

class E2EE:
    """ChaCha20-Poly1305 framing with separate send and receive keys.

    A sealed frame is an 8-byte big-endian counter followed by the AEAD
    output for the payload. The nonce for each frame is the direction's
    nonce prefix followed by the same 8-byte counter. The prefix must be
    4 bytes so that the nonce has the 12 bytes ChaCha20-Poly1305 requires.

    Received counters must be strictly increasing. A frame whose counter
    was already accepted, or is lower than one already accepted, is
    rejected, so a party relaying the traffic cannot replay or reorder
    authentic frames. This assumes an ordered transport.
    """

    # Lowest receive counter that is still acceptable. Also defined at class
    # level so the attribute always exists.
    _rx_floor = 0

    def __init__(self, send_key, recv_key, send_npfx, recv_npfx):
        """Create the send/receive ciphers.

        Args:
            send_key: Key used to encrypt outgoing frames.
            recv_key: Key used to decrypt incoming frames.
            send_npfx: Nonce prefix for outgoing frames.
            recv_npfx: Nonce prefix for incoming frames.
        """
        self.tx = ChaCha20Poly1305(send_key)
        self.rx = ChaCha20Poly1305(recv_key)
        self.sfx = send_npfx
        self.rfx = recv_npfx
        self.ctr = 0
        self._rx_floor = 0

    def seal(self, pt: bytes) -> bytes:
        """Encrypt *pt* and return ``counter || ciphertext``.

        Raises:
            struct.error: If the send counter no longer fits in 64 bits.
        """
        header = struct.pack(">Q", self.ctr)
        # Advance before encrypting so a counter value is never used twice,
        # even if the encrypt call raises.
        self.ctr += 1
        return header + self.tx.encrypt(self.sfx + header, pt, None)

    def open(self, frame: bytes) -> bytes:
        """Authenticate and decrypt a frame produced by the peer's ``seal``.

        Raises:
            struct.error: If *frame* is shorter than the 8-byte counter.
            ValueError: If the counter is not higher than every counter
                accepted so far (replayed or reordered frame).
            Exception: Whatever the cipher raises when authentication fails.
        """
        header = frame[:8]
        (ctr,) = struct.unpack(">Q", header)
        if ctr < self._rx_floor:
            raise ValueError("replayed or out-of-order frame")
        pt = self.rx.decrypt(self.rfx + bytes(header), frame[8:], None)
        # Only an authenticated frame may move the floor. Otherwise a forged
        # frame with a huge counter could lock out all genuine traffic.
        self._rx_floor = ctr + 1
        return pt

async def _poll_peer_key(http, api: str, session_id: str, role: str) -> str:
    """Poll the server every 0.5 s until it returns the peer's Base64 public key."""
    while True:
        r = await http.get(f"{api}/session/{session_id}/peerkey", params={"role": role})
        r.raise_for_status()
        peer_pk_b64 = r.json().get("peer_pubkey")
        if peer_pk_b64:
            return peer_pk_b64
        await asyncio.sleep(0.5)


async def run(role: str, api: str, session_id: str | None):
    """Run one side of an encrypted voice call until the connection ends.

    Steps: publish a fresh X25519 public key (creating the session first if
    the caller has none), wait for the peer's key, derive the directional
    keys, show the SAS for the user to confirm, then stream encrypted
    microphone frames over the WebSocket while playing back the frames
    received from the peer.

    Args:
        role: ``"caller"`` publishes an offer; any other value is treated as
            the callee and publishes an answer.
        api: Base HTTP(S) URL of the server.
        session_id: Existing session to join. Required for the callee; the
            caller creates a new session when it is missing.

    Raises:
        SystemExit: With status 2 if the callee has no session id.
        httpx.HTTPStatusError: If any server request returns an error status.
    """
    is_caller = role == "caller"
    if not is_caller and not session_id:
        print("--session-id required", file=sys.stderr)
        sys.exit(2)

    sk = X25519PrivateKey.generate()
    pk = sk.public_key().public_bytes(
        encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
    )

    # The HTTP client is only needed for the key exchange, so it is closed
    # before the (long-lived) media phase starts.
    async with httpx.AsyncClient(timeout=10) as http:
        if is_caller and not session_id:
            r = await http.post(f"{api}/session")
            r.raise_for_status()
            session_id = r.json()["session_id"]
            print("Session:", session_id)
        endpoint = "offer" if is_caller else "answer"
        r = await http.post(f"{api}/session/{session_id}/{endpoint}", json={"pubkey": b64e(pk)})
        # Fail here rather than poll forever for a peer that can never see
        # our key.
        r.raise_for_status()
        if is_caller:
            print("Waiting for callee...")
        peer_pk_b64 = await _poll_peer_key(
            http, api, session_id, "caller" if is_caller else "callee"
        )

    peer_pk = b64d(peer_pk_b64)
    shared = sk.exchange(X25519PublicKey.from_public_bytes(peer_pk))
    send_key, recv_key, send_npfx, recv_npfx = derive_keys(pk, peer_pk, shared)
    sas = make_sas(shared, pk, peer_pk)
    print("SAS:", sas)
    # Blocking the event loop is acceptable here: nothing else is scheduled
    # until the user has confirmed the SAS.
    input("If matched, press Enter... ")
    e2ee = E2EE(send_key, recv_key, send_npfx, recv_npfx)

    # Swap only the scheme (http -> ws, https -> wss). A blanket
    # str.replace would also rewrite "http" inside the host or path.
    ws_base = "ws" + api[len("http"):] if api.startswith("http") else api
    uri = f"{ws_base}/ws/{session_id}?role={role}"

    async with websockets.connect(uri, max_size=None) as ws:
        rec_q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=50)
        play_q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=50)

        async def sender():
            while True:
                await ws.send(e2ee.seal(await rec_q.get()))

        async def receiver():
            while True:
                msg = await ws.recv()
                if not isinstance(msg, (bytes, bytearray)):
                    continue
                try:
                    pcm = e2ee.open(msg)
                except Exception:
                    # Deliberate best-effort: a frame that fails to parse or
                    # authenticate is dropped without ending the call.
                    continue
                await play_q.put(pcm)

        streams = []
        tasks = []
        try:
            # Keep references so the streams stay alive for the whole call
            # and can be closed afterwards.
            streams.append(record_frames(rec_q))
            streams.append(playback_frames(play_q))
            tasks = [asyncio.create_task(sender()), asyncio.create_task(receiver())]
            await asyncio.gather(*tasks)
        finally:
            # When one direction fails, stop the other one too instead of
            # leaving it running against a socket that is being closed.
            for task in tasks:
                task.cancel()
            for stream in streams:
                stream.close()

if __name__ == "__main__":
    # Command-line entry point: parse the role and options, then run the
    # client until it finishes or the user interrupts it with Ctrl-C.
    parser = argparse.ArgumentParser()
    parser.add_argument("role", choices=["caller", "callee"])
    parser.add_argument("--api", default=API)
    parser.add_argument("--session-id")
    args = parser.parse_args()
    try:
        asyncio.run(run(args.role, args.api, args.session_id))
    except KeyboardInterrupt:
        pass
