import asyncio
import base64
import html
import json
import os
import sys

import websockets
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, padding, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from PyQt6.QtCore import QThread, pyqtSignal, Qt
from PyQt6.QtWidgets import (
    QApplication,
    QHBoxLayout,
    QInputDialog,
    QLabel,
    QLineEdit,
    QMainWindow,
    QMessageBox,
    QPushButton,
    QStatusBar,
    QTextEdit,
    QVBoxLayout,
    QWidget,
)

from config import SERVER_HOST, SERVER_PORT

# Diffie-Hellman group shared by every client. Peers must all use identical
# parameters, otherwise their keys cannot be combined in exchange().
# Modulus: the 2048-bit MODP prime from RFC 3526, section 3, generator 2.
# The digits are laid out exactly as in the RFC so they can be audited; use a
# vetted standard group and never hand-edit these values.
_RFC3526_MODP_2048_PRIME = int(
    "".join(
        """
        FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
        29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
        EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
        E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
        EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
        C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
        83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
        670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
        E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
        DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
        15728E5A 8AACAA68 FFFFFFFF FFFFFFFF
        """.split()
    ),
    16,
)
_DH_GENERATOR = 2

# DHParameters object; ChatWindow derives its per-window key pair from it.
p = dh.DHParameterNumbers(_RFC3526_MODP_2048_PRIME, _DH_GENERATOR).parameters(default_backend())

# Module-level key pair. ChatWindow generates its own per-instance pair and does
# not use these; they are kept so any existing importer of this module still works.
private_key = p.generate_private_key()
public_key = private_key.public_key()

# Legacy module-level state. ChatWindow keeps the equivalent state on the
# instance; nothing in the visible code reads these names.
room_public_keys = {}
shared_aes_key = None
client_id_global = None
room_id_global = None

class WebSocketClient(QThread):
    """Own a WebSocket connection on a private asyncio loop inside a QThread.

    Each frame received is parsed as a JSON object and forwarded to the GUI
    thread through ``message_received``. Outgoing messages may be queued from
    any thread with :meth:`send_message`.

    Signals:
        message_received(dict): one JSON object received from the server.
        connected(): the connection was established.
        disconnected(): the connection ended or could not be established.
        error_occurred(str): human-readable description of a failure.
    """

    message_received = pyqtSignal(dict)
    connected = pyqtSignal()
    disconnected = pyqtSignal()
    error_occurred = pyqtSignal(str)

    # Upper bound (seconds) on the closing handshake. stop() blocks its
    # caller until this thread exits, so a silent server must not hang it.
    _CLOSE_TIMEOUT_S = 2.0

    def __init__(self, parent=None):
        super().__init__(parent)
        self.loop = asyncio.new_event_loop()
        self.websocket = None
        self.uri = f"ws://{SERVER_HOST}:{SERVER_PORT}"
        self.running = True
        # Task running connect_and_listen(); created on the worker thread in run().
        self._listen_task = None

    def run(self):
        """QThread entry point: drive connect_and_listen() on this thread's loop."""
        asyncio.set_event_loop(self.loop)
        self._listen_task = self.loop.create_task(self.connect_and_listen())
        try:
            self.loop.run_until_complete(self._listen_task)
        except asyncio.CancelledError:
            pass  # stop() cancelled the task: this is the normal shutdown path.

    async def connect_and_listen(self):
        """Connect, then forward incoming messages until the connection ends.

        Once connected, ``disconnected`` is emitted on every exit path. That
        includes a normal close, an abnormal close, and cancellation by stop().
        """
        try:
            self.websocket = await websockets.connect(self.uri)
        except Exception as e:
            self.error_occurred.emit(f"Failed to connect to WebSocket server: {e}")
            self.disconnected.emit()
            return

        self.connected.emit()
        try:
            await self._receive_loop()
        finally:
            await self._close_connection()
            self.disconnected.emit()

    async def _receive_loop(self):
        """Receive frames and emit them as dicts until closed or stopped."""
        while self.running:
            try:
                message_raw = await self.websocket.recv()
            except websockets.exceptions.ConnectionClosedOK:
                return
            except Exception as e:
                # Includes ConnectionClosedError, i.e. an abnormal close.
                self.error_occurred.emit(f"WebSocket receive error: {e}")
                return

            try:
                data = json.loads(message_raw)
            except ValueError:  # JSONDecodeError, or bytes that are not valid text
                self.error_occurred.emit(f"Invalid JSON received: {message_raw}")
                continue
            if not isinstance(data, dict):
                # message_received is declared to carry a dict; reject other JSON values.
                self.error_occurred.emit(f"Invalid JSON received: {message_raw}")
                continue
            self.message_received.emit(data)

    async def _close_connection(self):
        """Perform the closing handshake, bounded by _CLOSE_TIMEOUT_S."""
        try:
            await asyncio.wait_for(self.websocket.close(), timeout=self._CLOSE_TIMEOUT_S)
        except Exception as e:
            self.error_occurred.emit(f"Error while closing WebSocket: {e!r}")

    async def send_message_async(self, message: dict):
        """Serialize *message* as JSON and send it; report failures via error_occurred."""
        if self.websocket is None:
            self.error_occurred.emit("Failed to send message: not connected.")
            return
        try:
            # A closed connection raises ConnectionClosed here, which is reported below.
            await self.websocket.send(json.dumps(message))
        except Exception as e:
            self.error_occurred.emit(f"Failed to send message: {e}")

    def send_message(self, message: dict):
        """Queue *message* for sending; safe to call from any thread (e.g. the GUI)."""
        if self.loop and self.loop.is_running():
            asyncio.run_coroutine_threadsafe(self.send_message_async(message), self.loop)
        else:
            self.error_occurred.emit("WebSocket loop is not running to send message.")

    def _cancel_listen_task(self):
        """Cancel the listener task. This runs on the event-loop thread."""
        if self._listen_task is not None and not self._listen_task.done():
            self._listen_task.cancel()

    def stop(self):
        """Close the connection, wait for the worker thread, and close the loop.

        Call this from another thread (typically the GUI thread), because it
        blocks in wait(). Calling it again is harmless.
        """
        self.running = False
        if not self.loop.is_closed():
            # Cancel on the loop's own thread. This interrupts a pending
            # connect()/recv() and lets connect_and_listen() close the socket.
            self.loop.call_soon_threadsafe(self._cancel_listen_task)
        self.wait()
        if not self.loop.is_closed():
            self.loop.close()

class ChatWindow(QMainWindow):
    """Main window of the encrypted chat client.

    User input is parsed as slash commands (see process_command). Server
    traffic arrives through a WebSocketClient running on its own thread.

    Encryption scheme as implemented here:
      * Each window owns a DH key pair over the module-level group ``p``. It
        sends its public key (base64 of PEM) with /create and /join.
      * The room creator generates a random 32-byte AES session key. It sends
        that key to every peer, AES-CBC encrypted under a key derived with
        HKDF-SHA256 from the pairwise DH shared secret.
      * Chat messages are AES-CBC encrypted under the session key, with PKCS7
        padding and a fresh random 16-byte IV.

    NOTE(review): peer public keys are trusted exactly as the server relays
    them, and nothing authenticates them. CBC without a MAC also gives no
    integrity protection. A hostile server could therefore read or tamper
    with traffic.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Encrypted Chat")
        self.setGeometry(100, 100, 800, 600)

        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QVBoxLayout(self.central_widget)

        self.chat_display = QTextEdit()
        self.chat_display.setReadOnly(True)
        self.main_layout.addWidget(self.chat_display)

        self.input_layout = QHBoxLayout()
        self.message_input = QLineEdit()
        self.message_input.setPlaceholderText("Enter message or command (e.g., /create room1 pass1)")
        self.send_button = QPushButton("Send")
        self.send_button.clicked.connect(self.send_button_clicked)
        self.message_input.returnPressed.connect(self.send_button_clicked)  # Send on Enter key

        self.input_layout.addWidget(self.message_input)
        self.input_layout.addWidget(self.send_button)
        self.main_layout.addLayout(self.input_layout)

        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)
        self.status_bar.showMessage("Disconnected")

        # Optional terminal-style theme, resolved relative to the working directory.
        try:
            with open("terminal_theme.qss", "r", encoding="utf-8") as f:
                self.setStyleSheet(f.read())
        except FileNotFoundError:
            self.log_message("terminal_theme.qss not found. Using default theme.", "red")

        # Per-window DH key pair used for session-key transport.
        self.private_key = p.generate_private_key()
        self.public_key = self.private_key.public_key()

        # Per-room state. It is reset whenever /create or /join is issued.
        self.room_public_keys = {}  # peer client_id -> loaded DH public key
        self.shared_aes_key = None  # room session key (bytes) once known
        self.client_id = None  # taken from the server's "success" reply
        self.room_id = None
        self.is_room_creator = False  # only the creator generates the session key on join

        self.websocket_thread = WebSocketClient()
        self.websocket_thread.message_received.connect(self.handle_websocket_message)
        self.websocket_thread.connected.connect(self.on_websocket_connected)
        self.websocket_thread.disconnected.connect(self.on_websocket_disconnected)
        self.websocket_thread.error_occurred.connect(self.on_websocket_error)
        self.websocket_thread.start()

    # ----------------------------------------------------------------- status

    def on_websocket_connected(self):
        self.status_bar.showMessage("Connected to Server")
        self.log_message("Connected to WebSocket server.", "green")

    def on_websocket_disconnected(self):
        self.status_bar.showMessage("Disconnected from Server")
        self.log_message("Disconnected from WebSocket server.", "red")

    def on_websocket_error(self, error_message):
        self.log_message(f"WebSocket Error: {error_message}", "red")

    def log_message(self, message, color="white"):
        """Append *message* to the chat display as literal text in *color*.

        *color* is a CSS color name. A "bold " prefix (e.g. "bold green") also
        renders the line in bold.
        """
        if color.startswith("bold "):
            style = f"color:{color[len('bold '):]}; font-weight:bold;"
        else:
            style = f"color:{color};"
        # Escape so text from peers or the server is shown verbatim rather
        # than being interpreted as rich-text markup.
        self.chat_display.append(f"<span style=\"{style}\">{html.escape(str(message))}</span>")

    # ------------------------------------------------------------ user input

    def send_button_clicked(self):
        text = self.message_input.text().strip()
        if not text:
            return
        self.message_input.clear()
        self.process_command(text)

    def process_command(self, command):
        """Echo and execute one line of user input.

        Commands:
            /create <room_id> <password>
            /join <room_id> <password>
            /send <text>
            /distribute_key
            /exit

        The password is everything after the room id, so it may contain spaces.
        """
        self.log_message(f"> {command}", "gray")

        parts = command.split(" ", 2)
        action = parts[0]

        if action in ("/create", "/join"):
            if len(parts) == 3:
                self._enter_room(parts[1], parts[2], create=(action == "/create"))
            else:
                self.log_message(f"Usage: {action} <room_id> <password>", "yellow")
        elif action == "/send" and len(parts) >= 2:
            if self.room_id and self.shared_aes_key:
                # Everything after the first space, so multi-word messages survive.
                self.send_encrypted_message(command.split(" ", 1)[1])
            else:
                self.log_message("You must create or join a room and have a shared AES key first to send messages.", "red")
        elif action == "/distribute_key":
            if self.room_id and self.client_id:
                self.distribute_session_key()
            else:
                self.log_message("You must create a room and have a client ID to distribute the key.", "red")
        elif action == "/exit":
            self.close()
        else:
            self.log_message("Unknown command. Use /create, /join, /send, /distribute_key, or /exit.", "yellow")

    def _enter_room(self, room_id, password, *, create):
        """Reset per-room state and ask the server to create or join *room_id*."""
        # Key material and the creator role from a previous room must not
        # carry over into the new one.
        self.room_id = room_id
        self.is_room_creator = create
        self.shared_aes_key = None
        self.room_public_keys = {}
        message = {
            "action": "create_room" if create else "join_room",
            "room_id": room_id,
            "password": password,
            "public_key": self._encoded_public_key(),
        }
        self.websocket_thread.send_message(message)

    # ------------------------------------------------------- server messages

    def handle_websocket_message(self, data):
        """Dispatch one JSON object from the server by its "type" field."""
        handlers = {
            "message": self._handle_chat_message,
            "public_keys": self._handle_public_keys,
            "new_public_key": self._handle_new_public_key,
            "session_key_bundle": self._handle_session_key_bundle,
        }
        handler = handlers.get(data.get("type"), self._handle_status_message)
        try:
            handler(data)
        except Exception as e:
            # PyQt terminates the application on an exception escaping a slot,
            # so a malformed server message is reported instead.
            self.log_message(f"[ERROR] Failed to handle server message: {e}", "red")

    def _handle_chat_message(self, data):
        """Decrypt and display an encrypted chat message."""
        sender_id = data.get("sender_id", "Unknown")
        encrypted_content_b64 = data.get("content", "")
        iv_b64 = data.get("iv", "")
        if not (self.shared_aes_key and encrypted_content_b64 and iv_b64):
            self.log_message(f"[ERROR] Cannot decrypt message from {sender_id}: Shared AES key not available or missing content/IV.", "red")
            return
        try:
            plaintext = self._aes_cbc_decrypt(
                self.shared_aes_key,
                base64.b64decode(iv_b64),
                base64.b64decode(encrypted_content_b64),
            )
            decrypted_message = plaintext.decode("utf-8")
        except Exception as e:
            self.log_message(f"[ERROR] Failed to decrypt message from {sender_id}: {e}", "red")
            return
        self.log_message(f"[{sender_id}] {decrypted_message}", "cyan")

    def _handle_public_keys(self, data):
        """Store every peer public key from the server's room snapshot."""
        received_keys = data.get("public_keys") or {}
        for client, key_b64 in received_keys.items():
            if client == self.client_id:
                continue
            try:
                self.room_public_keys[client] = self._load_public_key(key_b64)
            except Exception as e:
                self.log_message(f"[ERROR] Invalid public key for client {client}: {e}", "red")
        self.log_message("[SERVER] Received all public keys.", "green")

    def _handle_new_public_key(self, data):
        """Store a newly joined peer's key; the creator then sends it the session key."""
        client = data.get("client_id")
        key_b64 = data.get("public_key")
        if not client or client == self.client_id:
            # Own key echoed back, or no sender: nothing to store or distribute.
            return
        try:
            self.room_public_keys[client] = self._load_public_key(key_b64)
        except Exception as e:
            self.log_message(f"[ERROR] Invalid public key for client {client}: {e}", "red")
            return
        self.log_message(f"[SERVER] Received new public key for client: {client}", "green")
        if self.shared_aes_key and self.room_id and self.is_room_creator:
            self.distribute_session_key_to_one(client)

    def _handle_session_key_bundle(self, data):
        """Decrypt this client's entry of a session-key bundle and adopt the key."""
        sender_id = data.get("sender_id")
        bundle = data.get("bundle") or {}
        encrypted_key_b64 = bundle.get(self.client_id)
        if encrypted_key_b64 is None:
            self.log_message("[SERVER] No session key for this client in the bundle.", "yellow")
            return
        try:
            sender_key = self.room_public_keys.get(sender_id)
            if sender_key is None:
                raise ValueError(f"no public key known for sender {sender_id}")
            encrypted_bundle = base64.b64decode(encrypted_key_b64)
            # Wire format: 16-byte IV followed by the CBC ciphertext.
            iv, encrypted_key = encrypted_bundle[:16], encrypted_bundle[16:]
            session_key = self._aes_cbc_decrypt(self._derive_transport_key(sender_key), iv, encrypted_key)
            if len(session_key) not in (16, 24, 32):
                raise ValueError(f"unexpected session key length {len(session_key)}")
        except Exception as e:
            self.log_message(f"[ERROR] Failed to decrypt session key: {e}", "red")
            return
        self.shared_aes_key = session_key
        self.log_message("[SERVER] Successfully received and decrypted AES session key.", "bold green")

    def _handle_status_message(self, data):
        """Show a server status reply; a "success" reply may carry our client_id."""
        message_text = data.get("message", str(data))
        status = data.get("status", "info")
        if status == "success":
            self.log_message(f"[SERVER] {message_text}", "green")
            if data.get("client_id"):
                self.client_id = data["client_id"]
            # The creator generates (and distributes) the session key once the room exists.
            if self.is_room_creator and not self.shared_aes_key and self.room_id:
                self.distribute_session_key()
        elif status == "error":
            self.log_message(f"[SERVER] Error: {message_text}", "red")
        else:
            self.log_message(f"[SERVER] {message_text}", "cyan")

    # ---------------------------------------------------------------- crypto

    def _encoded_public_key(self):
        """Return this window's DH public key as base64 of its PEM (SubjectPublicKeyInfo)."""
        pem = self.public_key.public_bytes(serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)
        return base64.b64encode(pem).decode("utf-8")

    @staticmethod
    def _load_public_key(key_b64):
        """Decode a peer public key sent as base64 of PEM."""
        return serialization.load_pem_public_key(base64.b64decode(key_b64), backend=default_backend())

    def _derive_transport_key(self, peer_public_key):
        """Derive the 32-byte key protecting session-key transport with one peer."""
        shared_secret = self.private_key.exchange(peer_public_key)
        return HKDF(algorithm=hashes.SHA256(), length=32, salt=None, info=b'aes key transport', backend=default_backend()).derive(shared_secret)

    @staticmethod
    def _aes_cbc_encrypt(key, plaintext):
        """Return (iv, ciphertext): AES-CBC with PKCS7 padding and a random 16-byte IV."""
        padder = padding.PKCS7(algorithms.AES.block_size).padder()
        padded = padder.update(plaintext) + padder.finalize()
        iv = os.urandom(16)
        encryptor = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).encryptor()
        return iv, encryptor.update(padded) + encryptor.finalize()

    @staticmethod
    def _aes_cbc_decrypt(key, iv, ciphertext):
        """Invert _aes_cbc_encrypt.

        Raises on malformed input, such as a wrong IV size, truncated
        ciphertext, or invalid padding.
        """
        decryptor = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()).decryptor()
        padded = decryptor.update(ciphertext) + decryptor.finalize()
        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
        return unpadder.update(padded) + unpadder.finalize()

    def _wrap_session_key_for(self, peer_public_key):
        """Return the session key encrypted for one peer, as base64 of IV || ciphertext."""
        iv, encrypted_key = self._aes_cbc_encrypt(self._derive_transport_key(peer_public_key), self.shared_aes_key)
        return base64.b64encode(iv + encrypted_key).decode("utf-8")

    def _send_session_key_bundle(self, bundle):
        """Send a {client_id: wrapped_key} bundle to the room via the server."""
        message = {"action": "session_key_distribution", "room_id": self.room_id, "encrypted_session_key_bundle": bundle}
        self.websocket_thread.send_message(message)

    # --------------------------------------------------------------- sending

    def send_encrypted_message(self, content: str):
        """Encrypt *content* under the room session key, send it, and echo it locally."""
        if not self.shared_aes_key:
            self.log_message("Cannot send message: Shared AES key not available.", "red")
            return

        iv, encrypted_content = self._aes_cbc_encrypt(self.shared_aes_key, content.encode("utf-8"))
        message = {
            "action": "message",
            "room_id": self.room_id,
            "content": base64.b64encode(encrypted_content).decode("utf-8"),
            "iv": base64.b64encode(iv).decode("utf-8"),
            "sender_id": self.client_id,
        }
        self.websocket_thread.send_message(message)
        self.log_message(f"[You] {content}", "white")  # Display own message locally

    def distribute_session_key(self):
        """Send the session key to every known peer, generating it first if needed."""
        if not self.shared_aes_key:
            self.shared_aes_key = os.urandom(32)
            self.log_message("Generated new AES session key.", "green")

        encrypted_session_key_bundle = {}
        for client, pub_key_obj in self.room_public_keys.items():
            if client == self.client_id:
                continue
            try:
                encrypted_session_key_bundle[client] = self._wrap_session_key_for(pub_key_obj)
            except Exception as e:
                self.log_message(f"Error encrypting session key for {client}: {e}", "red")

        if encrypted_session_key_bundle:
            self._send_session_key_bundle(encrypted_session_key_bundle)
            self.log_message("Distributed AES session key to other room members.", "green")
        else:
            self.log_message("No other clients in the room to distribute session key.", "yellow")

    def distribute_session_key_to_one(self, target_client_id: str):
        """Send the existing session key to a single peer (e.g. one that just joined)."""
        if not self.shared_aes_key:
            self.log_message("Cannot distribute session key: Shared AES key not available.", "red")
            return

        pub_key_obj = self.room_public_keys.get(target_client_id)
        if pub_key_obj is None:
            self.log_message(f"Public key for client {target_client_id} not available for session key distribution.", "red")
            return

        try:
            bundle = {target_client_id: self._wrap_session_key_for(pub_key_obj)}
        except Exception as e:
            self.log_message(f"Error distributing session key to {target_client_id}: {e}", "red")
            return
        self._send_session_key_bundle(bundle)
        self.log_message(f"Distributed AES session key to new client: {target_client_id}.", "green")

    def closeEvent(self, event):
        """Shut down the network thread before the window closes."""
        self.websocket_thread.stop()
        event.accept()

if __name__ == "__main__":
    # Start Qt with a single chat window. The process exits with the status
    # returned by the Qt event loop.
    app = QApplication(sys.argv)
    window = ChatWindow()
    window.show()
    sys.exit(app.exec())