#!/usr/bin/env python3

from alive_progress import alive_bar
import get_specs
import traceback
#import logging
import yaml
from multiprocessing import Process, Manager, Pool, TimeoutError, active_children, log_to_stderr, Pipe, Queue
from multiprocessing.pool import Pool
import multiprocessing
from time import sleep
from util import fprint
from util import run_cmd
import sys
import ur5_control
from ur5_control import Rob
import os
import signal
import socket
from flask import Flask, render_template, request
import requests
from led_control import LEDSystem
import server
import asyncio
import json
import process_video
import search
from search import JukeboxSearch
#multiprocessing.set_start_method('spawn', True)


# ---- Shared runtime state (module globals) ---------------------------------
# Parsed config.yml; loaded in the __main__ block.
config = None
# Main-loop condition in the __main__ block; nothing visible here sets it False.
keeprunning = True
# Readiness flags, flipped to True by the *_callback functions below when the
# corresponding pool task completes (some are also forced True in setup_server).
arm_ready = False
led_ready = False
# Also reused by mainloop_server as "a QR read has finished".
camera_ready = False
sensor_ready = False
vm_ready = False
cable_search_ready = False
# Replaced in __main__ by a Manager Value; mainloop_server calls killall() once it is > 0.
killme = None
#pool = None
# Background Process: start_server_socket (server mode) or start_client_socket (client mode).
serverproc = None
# process_video.qr_reader instance (server mode).
camera = None
# LEDSystem instance; replaced by the return value of its init task (led_start_callback).
ledsys = None
# ur5_control.Rob instance (server mode).
arm = None
# Queues handed to server.start_websocket_server. send_data puts
# (client_id, json_string) tuples on to_server_queue ("*" addresses every client);
# start_server_socket reads (client_id, message) tuples from from_server_queue.
to_server_queue = Queue()
from_server_queue = Queue()
# mainloop_server state: "Startup" -> "Parsing" -> "Idle".
mode = "Startup"
# Index of the cable slot currently being scanned.
counter = 0
# JukeboxSearch instance holding the parsed cable documents (server mode).
jbs = None
# Last QR-read result (set by camera_start_callback).
scan_value = None
# Scan sub-state in mainloop_server: None, "GET" or "RETURN".
arm_state = None
# Per-slot part numbers (False for an empty/unreadable slot).
cable_list = list()
# (success, partnums) result of get_specs.get_multi (set by cable_search_callback).
parse_res = None

def arm_start_callback(res):
    """Pool callback for arm tasks: mark the UR5 arm as ready.

    The task's return value ``res`` is not used.
    """
    global arm_ready
    arm_ready = True

def led_start_callback(res):
    """Pool callback for ``ledsys.init``: store its result, then flag LEDs ready.

    ``ledsys`` is assigned before ``led_ready`` is set. The callback runs on
    the pool's result-handler thread while other code polls ``led_ready``,
    so anyone who sees the flag must also see the new ``ledsys``.

    Args:
        res: return value of the LED init task; becomes the module's ``ledsys``.
    """
    global ledsys, led_ready
    ledsys = res
    led_ready = True

def camera_start_callback(res):
    """Pool callback for camera QR reads: store the result, then flag it ready.

    ``scan_value`` is assigned before ``camera_ready`` is set. The callback
    runs on the pool's result-handler thread, while ``mainloop_server`` polls
    ``camera_ready`` and then reads ``scan_value``. Setting the flag first
    could let the main loop consume a stale value.

    Args:
        res: the QR read result (a string, or False when nothing was read).
    """
    global scan_value, camera_ready
    scan_value = res
    camera_ready = True
    
def sensor_start_callback(res):
    """Pool callback for sensor initialization: mark the sensors as ready.

    The task's return value ``res`` is not used.
    """
    global sensor_ready
    sensor_ready = True

def vm_start_callback(res):
    """Pool callback for ``check_server_online``: mark the server VM as reachable.

    The task's return value ``res`` is not used.
    """
    global vm_ready
    vm_ready = True

def cable_search_callback(res):
    """Pool callback for ``get_specs.get_multi``: store the result, then flag it ready.

    ``parse_res`` is assigned before ``cable_search_ready`` is set.
    ``mainloop_server`` unpacks ``parse_res`` as soon as it sees the flag.
    With flag-first ordering on the pool's result-handler thread, it could
    see ``None`` and fail to unpack.

    Args:
        res: the search/parse result, unpacked by mainloop_server as
            ``(success, partnums)``.
    """
    global parse_res, cable_search_ready
    parse_res = res
    cable_search_ready = True

def wait_for(val, name, poll_interval=0.1):
    """Block until a readiness flag stops being ``False``.

    The flags in this module are plain booleans rebound by pool callbacks.
    Passing such a flag by value hands this function a snapshot that can
    never change. The original loop therefore spun forever whenever the flag
    was still False. Pass a callable that re-reads the flag instead.

    Args:
        val: a zero-argument callable returning the current flag value,
            e.g. ``lambda: vm_ready``. For backward compatibility, a plain
            non-False value returns immediately.
        name: human-readable description used in the log message.
        poll_interval: seconds to sleep between polls.

    Raises:
        TypeError: if ``val`` is the plain value ``False``, which could never
            become True inside this function.
    """
    if not callable(val):
        if val is False:
            raise TypeError(
                "wait_for(" + repr(name) + "): pass a callable such as "
                "`lambda: flag`; a plain False snapshot can never become True"
            )
        return
    if val() is False:
        fprint("waiting for " + name + " to complete...")
        while val() is False:
            sleep(poll_interval)

def send_data(type, call, data, client_id="*"):
    """Queue a message for the WebSocket server.

    Wraps ``data`` in the ``{"type", "call", "data"}`` envelope, JSON-encodes
    it, and puts ``(client_id, json_string)`` on ``to_server_queue``.
    A ``client_id`` of ``"*"`` addresses every client; otherwise it names
    one client.
    """
    envelope = {"type": type, "call": call, "data": data}
    encoded = json.dumps(envelope)
    to_server_queue.put((client_id, encoded))

def start_server_socket():
    global jbs
    """app = Flask(__name__)

    @app.route('/report_ip', methods=['POST'])
    def report_ip():
        client_ip = request.json.get('ip')
        fprint(f"Received IP: {client_ip}")
        # You can store or process the IP address as needed
        return "IP Received", 200
    
    app.run(host='0.0.0.0', port=5000)"""
    global to_server_queue
    global from_server_queue
    fprint("Starting WebSocket server...")
    websocket_process = server.start_websocket_server(to_server_queue, from_server_queue)
    
    # Example
    #to_server_queue.put("Hello, WebSocket clients!")
    

    while True:
        #print("HI")

        # Handeling Server Requests Loop, will run forever

        if not from_server_queue.empty():
            client_id, message = from_server_queue.get()
            fprint(f"Message from client {client_id}: {message}")

            # Message handler
            try:
                decoded = json.loads(message)
            except:
                fprint("Non-JSON message recieved")
                continue

            if "type" not in decoded:
                fprint("Missing \"type\" field.")
                continue
            if "call" not in decoded:
                fprint("Missing \"call\" field.")
                continue
            if "data" not in decoded:
                fprint("Missing \"data\" field.")
                continue
            # if we get here, we have a "valid" data packet
            data = decoded["data"]
            call = decoded["call"]
            match decoded["type"]:
                case "log":
                    fprint("log message")
                    if call == "send":
                        fprint("webapp: " + str(data), sendqueue=to_server_queue)
                    elif call == "request":
                        fprint("")
                case "cable_map":
                    fprint("cable_map message")
                    if call == "send":
                        fprint("")
                    elif call == "request":
                        fprint("")
                case "ping":
                    fprint("Pong!!!")
                # Lucas' notes
                # Add a ping pong :) response/handler
                # Add a get cable response/handler
                #       this will tell the robot arm to move
                # Call for turning off everything
                # TODO Helper for converting Python Dictionaries to JSON
                # make function: pythonData --> { { "type": "...", "call": "...", "data": pythonData } }
                        
                # to send: to_server_queue.put(("*", "JSON STRING HERE")) # replace * with UUID of client to send to one specific location
                
                case "cable_details":
                    fprint("cable_details message")
                    if call == "send":
                        fprint("")
                    elif call == "request":
                        fprint("")
                        dataout = dict()
                        dataout["cables"] = list()
                        print(data)
                        if "part_number" in data:
                            for part in data["part_number"]:
                                #print(part)
                                #print(jbs.get_partnum(part))
                                dataout["cables"].append(jbs.get_partnum(part)["fullspecs"])
                        if "position" in data:
                            for pos in data["position"]:
                                #print(pos)
                                #print(jbs.get_position(str(pos)))
                                dataout["cables"].append(jbs.get_position(str(pos))["fullspecs"])
                        send_data(decoded["type"], "send", dataout, client_id)
                            
                case "cable_search":
                    fprint("cable_search message")
                    if call == "send":
                        fprint("")
                    elif call == "request":
                        fprint("")
                case "keyboard":
                    fprint("keyboard message")
                    if call == "send":
                        fprint("")
                    elif call == "request":
                        fprint("")
                        if data["enabled"] == True:
                            # todo : send this to client
                            p = Process(target=run_cmd, args=("./keyboard-up.ps1",))
                            p.start()
                        elif data["enabled"] == False:
                            p = Process(target=run_cmd, args=("./keyboard-down.ps1",))
                            p.start()
                case "machine_settings":
                    fprint("machine_settings message")
                    if call == "send":
                        fprint("")
                    elif call == "request":
                        fprint("")
                case _:
                    fprint("Unknown/unimplemented data type: " + decoded["type"])


            


        sleep(0.001)  # Sleep to prevent tight loop
    


def start_client_socket():
    """Serve the client's control endpoint: a Flask app on 0.0.0.0:6000.

    ``POST /control_client`` logs the "message" field of the JSON body and
    replies ``"Message received", 200``. Blocks in Flask's server loop.
    """
    app = Flask(__name__)

    @app.route('/control_client', methods=['POST'])
    def message_from_server():
        # Handle message from server. silent=True returns None for a missing or
        # invalid JSON body instead of failing the request. A non-object body
        # (e.g. a JSON list) is tolerated too instead of raising on .get().
        payload = request.get_json(silent=True)
        message = payload.get('message') if isinstance(payload, dict) else None
        fprint(f"Message from server: {message}")
        return "Message received", 200

    app.run(host='0.0.0.0', port=6000)


def check_server_online(serverip, clientip):
    """Report this client's IP to the server, retrying every second until it answers.

    POSTs ``{"ip": clientip}`` to ``http://<serverip>:5000/report_ip`` with a
    1 s timeout. Any HTTP response counts as success. Request errors (connection
    failures, timeouts) are logged and retried after a 1 s pause. This blocks
    until the server responds.

    Returns:
        True, once the server has responded.
    """
    report_url = 'http://' + serverip + ':5000/report_ip'
    while True:
        try:
            reply = requests.post(report_url, json={'ip': clientip}, timeout=1)
            fprint(f"Server response: {reply.text}")
        except requests.exceptions.RequestException as exc:
            fprint(f"Error sending IP to server: {exc}")
            sleep(1)
        else:
            break

    fprint("Successfully connected to server.")
    return True
    
        
def setup_server(pool):
    """Initialize the Linux server side and block until the LED init reports ready.

    Steps:
    - start the UR5 arm and LED initializers asynchronously on ``pool``;
    - build the JukeboxSearch instance;
    - launch ``start_server_socket`` in its own Process;
    - construct the QR camera reader synchronously;
    - poll every 0.1 s until the LED init callback fires.

    Progress messages go through ``fprint(..., sendqueue=to_server_queue)``,
    which presumably forwards them to the web clients.

    Args:
        pool: the LoggingPool used for asynchronous initialization.

    Returns:
        True (always, in the visible code).
    """
    # linux server setup
    global config
    global counter  # NOTE(review): declared but unused in this function
    global sensor_ready
    global camera_ready
    global led_ready
    global arm_ready
    global serverproc
    global camera
    global arm
    global jbs
    arm = Rob(config)
    # NOTE(review): pool is a multiprocessing Pool, so init_arm runs on a pickled
    # copy of `arm` in a worker process. State it sets on that copy does not reach
    # this process's `arm`; confirm Rob is meant to work this way.
    pool.apply_async(arm.init_arm, callback=arm_start_callback)
    global ledsys
    ledsys = LEDSystem()
    # led_start_callback replaces `ledsys` with whatever ledsys.init returns
    # (presumably the initialized LEDSystem; confirm).
    pool.apply_async(ledsys.init, callback=led_start_callback)
    #pool.apply_async(sensor_control.init, callback=sensor_start_callback)
    # jbs is created before serverproc starts because the child uses it for
    # cable_details lookups. Presumably the child inherits it via the default
    # 'fork' start method (the 'spawn' override near the imports is commented out).
    jbs = JukeboxSearch()
    serverproc = Process(target=start_server_socket)
    serverproc.start()
    
    
    # NOTE(review): if ledsys.init raises, LogExceptions logs and re-raises, the
    # callback never runs (LoggingPool passes no error_callback), and this loop
    # waits forever.
    if led_ready is False:
        fprint("waiting for " + "LED controller initialization" + " to complete...", sendqueue=to_server_queue)
        while led_ready is False:
            sleep(0.1)
    fprint("LED controllers initialized.", sendqueue=to_server_queue)
    #to_server_queue.put("[log] LED controllers initialized.")
    # Sensor init is commented out above, so the flag is forced True and the
    # wait below is skipped.
    sensor_ready = True
    if sensor_ready is False:
        fprint("waiting for " + "Sensor Initialization" + " to complete...", sendqueue=to_server_queue)
        while sensor_ready is False:
            sleep(0.1)
    fprint("Sensors initialized.", sendqueue=to_server_queue)

    # camera_ready is only set by camera_start_callback, which is first used by
    # mainloop_server. It is therefore False here and the QR reader is always
    # constructed synchronously. mainloop_server later reuses the flag to
    # signal a finished QR read.
    if camera_ready is False:
        fprint("waiting for " + "Camera initilization" + " to complete...", sendqueue=to_server_queue)
        camera = process_video.qr_reader(config["cameras"]["banner"]["ip"], config["cameras"]["banner"]["port"])

    fprint("Camera initialized.", sendqueue=to_server_queue)

    # NOTE(review): forces the flag True without waiting for init_arm's callback,
    # so "Arm initialized." is logged even if arm init has not finished.
    arm_ready = True
    if arm_ready is False:
        fprint("waiting for " + "UR5 initilization" + " to complete...", sendqueue=to_server_queue)
        while arm_ready is False:
            sleep(0.1)
    fprint("Arm initialized.", sendqueue=to_server_queue)


    
    return True



def mainloop_server(pool):
    """Advance the server state machine by one step; called back-to-back by __main__.

    Modes (module-global ``mode``):

    * ``"Startup"``: scan each cable slot with the arm and the QR camera (the
      arm moves are still stubbed out). Scanning is currently bypassed by
      forcing ``counter`` to 54, so the demo cable list below is used. Then
      ``get_specs.get_multi`` is queued and the mode becomes ``"Parsing"``.
    * ``"Parsing"``: wait for ``cable_search_callback``. On success, load each
      cable's ``search.json`` / ``specs.json`` into ``jbs`` and go ``"Idle"``.
      The result is consumed once.
    * ``"Idle"``: queue a move-to-home if the arm is not marked ready.

    If ``killme.value > 0``, everything is terminated via ``killall()``. Each
    call ends with a 10 ms sleep so the caller's tight loop does not peg a
    CPU core (each call also reads the Manager-backed ``killme``).

    Args:
        pool: the worker pool used for asynchronous arm/camera/search jobs.
    """
    global counter, mode, scan_value, cable_list
    global arm_ready, arm_state, camera_ready, cable_search_ready

    if killme.value > 0:
        killall()

    if mode == "Startup":
        counter = 54  # NOTE: bypasses physical scanning; remove for real demo
        if counter < 54:
            # scanning cables
            if arm_state is None:
                # TODO: move the arm to holder `counter` and bring the cable to
                # the camera, e.g. via ur5_control.goto_holder_index(arm) in the
                # pool with callback=arm_start_callback.
                fprint("Getting cable index " + str(counter) + " and scanning...")
                arm_state = "GET"

            elif arm_ready and arm_state == "GET":
                fprint("Looking for QR code...")
                pool.apply_async(camera.read_qr, (30,), callback=camera_start_callback)
                arm_ready = False

            elif camera_ready:
                fprint("Adding cable to list...")
                if scan_value is not False:
                    marker = "bldn.app/"
                    marker_at = scan_value.find(marker)
                    if marker_at > -1:
                        # Keep only the part number after the URL prefix.
                        scan_value = scan_value[marker_at + len(marker):]
                # Append every result (False for a failed read) so that
                # cable_list[i] stays aligned with slot i. Previously a URL-form
                # scan was stripped but never appended.
                cable_list.append(scan_value)
                fprint(scan_value)
                # TODO: return the cable to its holder in the pool with
                # callback=arm_start_callback. Until then nothing sets arm_ready
                # back to True, so the "RETURN" branch below cannot fire.
                arm_state = "RETURN"
                camera_ready = False

            elif arm_ready and arm_state == "RETURN":
                counter += 1
                arm_state = None
            # else: wait until the arm/camera is ready
        else:
            # Scanned everything (or scanning bypassed). Demo data: the cables
            # actually in the Jukebox.
            tmp = [
                    # Actual cables in Jukebox
                    "BLTF-1LF-006-RS5",
                    "BLTF-SD9-006-RI5",
                    "BLTT-SLG-024-HTN",
                    "BLFISX012W0",
                    "BLFI4X012W0",
                    "BLSPE101",
                    "BLSPE102",
                    "BL7922A",
                    "BL7958A",
                    "BLIOP6U",
                    "BL10GXW13",
                    "BL10GXW53",
                    "BL29501F",
                    "BL29512",
                    "BL3106A",
                    "BL9841",
                    "BL3105A",
                    "BL3092A",
                    "BL8760",
                    "BL6300UE",
                    "BL6300FE",
                    "BLRA500P",
                    "AW86104CY",
                    "AW3050",
                    "AW6714",
                    "AW1172C",
                    "AWFIT-221-1_4"
                ]
            tmp += [False] * (54 - len(tmp))  # must have 54 entries
            cable_list = tmp  # remove for real demo
            pool.apply_async(get_specs.get_multi, (tmp, 0.3), callback=cable_search_callback)
            mode = "Parsing"
            fprint("All cables scanned. Finding & parsing datasheets...")

    if mode == "Parsing" and cable_search_ready:
        # Consume the result exactly once. Without this reset, a failed parse
        # re-ran this block (and re-printed) on every loop iteration.
        cable_search_ready = False
        success, partnums = parse_res
        for idx, found in enumerate(partnums):
            cable_list[idx] = found[0].replace("/", "_") if found is not False else False

        print(partnums)
        if success:
            fprint("All cables inventoried and parsed.")
            fprint("Adding to database...")
            for idx, partnum in enumerate(cable_list):
                if partnum is False:
                    continue
                with open("cables/" + partnum + "/search.json", "rb") as f:
                    searchdata = json.load(f)
                searchdata["position"] = idx
                with open("cables/" + partnum + "/specs.json", "rb") as f:
                    specs = json.load(f)
                searchdata["fullspecs"] = specs
                searchdata["fullspecs"]["position"] = idx
                jbs.add_document(searchdata)

            fprint("All cables added to database.")
            mode = "Idle"
        else:
            # TODO: manual input
            fprint("Not all cables could be parsed; manual input required.")

    if mode == "Idle":
        if arm_ready is False:
            pool.apply_async(ur5_control.move_to_home, (arm,), callback=arm_start_callback)
            # Mark ready immediately so the move is queued only once.
            arm_ready = True
        # else: TODO LED idle animation

    # Yield the CPU between iterations of the caller's tight loop.
    sleep(0.01)

            


def run_loading_app():
    """Serve the ``index.html`` template at ``/`` on port 7000 as the client's loading page.

    Blocks in Flask's built-in server (reloader disabled).

    NOTE(review): debug=True enables the Werkzeug debugger. That is tolerable
    only because app.run binds to localhost by default; keep it that way.
    """
    loading_app = Flask(__name__)

    def index():
        return render_template('index.html')

    # Same as the @app.route('/') decorator; the endpoint name "index" comes
    # from the view function's name.
    loading_app.add_url_rule('/', view_func=index)
    loading_app.run(debug=True, use_reloader=False, port=7000)

def setup_client(pool):
    """Bring up the Windows kiosk client and point it at the server once the server is reachable.

    Steps:
    1. Open a fullscreen Firefox on the local loading page (port 7000).
    2. Start the server VM(s) when ``core.server`` is "Hyper-V".
    3. Start ``start_client_socket`` in its own Process.
    4. Poll the server via ``check_server_online`` in ``pool`` until
       ``vm_start_callback`` sets ``vm_ready``.
    5. Stop the loading page and navigate to ``http://<serverip>:8000``.

    Args:
        pool: the worker pool used for the background reachability check.

    Returns:
        True.
    """
    global config, vm_ready, serverproc

    # Windows client setup
    fprint("Opening browser...")
    browser = webdriver.Firefox()
    browser.fullscreen_window()

    # Serve the local loading webpage while the server comes up.
    loading_proc = Process(target=run_loading_app)
    loading_proc.start()
    browser.get('http://localhost:7000')

    core_cfg = config["core"]
    # start Linux server VM
    if core_cfg["server"] == "Hyper-V":
        run_cmd("Start-VM -Name Jukebox*")  # any and all VMs starting with "Jukebox"

    serverproc = Process(target=start_client_socket)
    serverproc.start()

    # Wait for the VM to start and be reachable over the network.
    server_ip = core_cfg["serverip"]
    pool.apply_async(check_server_online, (server_ip, core_cfg["clientip"]), callback=vm_start_callback)

    if vm_ready is False:
        fprint("waiting for VM Startup to complete...")
        while vm_ready is False:
            sleep(0.1)

    loading_proc.terminate()
    browser.get("http://" + server_ip + ":8000")
    return True
    

def mainloop_client(pool):
    """Run one idle tick of the client main loop: sleep for 100 ms.

    Placeholder: the client does not yet act on commands from the server VM.
    Planned work includes shutdown and possibly Wi-Fi setup. ``pool`` is
    currently unused.
    """
    idle_seconds = 0.1
    sleep(idle_seconds)


# Disabled stdout/stderr tee logger, kept as a string literal so the class is
# never defined. It belongs to the commented-out `sys.stdout = Logger(...)`
# lines in the __main__ block.
"""class Logger(object):
    def __init__(self, filename="output.log"):
        self.log = open(filename, "a")
        self.terminal = sys.stdout

    def write(self, message):
        self.log.write(message)
        #close(filename)
        #self.log = open(filename, "a")
        try:
            self.terminal.write(message)
        except:
            sleep(0)
        
    def flush(self):
        print("",end="")"""

def killall():
    """Kill every live child process, then hard-kill this process. Never returns.

    The literal signal number 9 is used instead of ``signal.SIGKILL`` because
    SIGKILL does not exist on Windows (client mode). On Windows, ``os.kill``
    with any value other than the CTRL events terminates the process.
    """
    children = active_children()
    for child in children:
        child.kill()
    fprint("All child processes killed")
    own_pid = os.getpid()
    os.kill(own_pid, 9)  # dirty kill of self

    
def killall_signal(a, b):
    """SIGINT handler: stop the Hyper-V server VM(s) if configured, then kill everything.

    The config lookup is defensive and ``killall()`` runs in a ``finally``.
    Without this, a config lacking ``core.server`` (or no config at all)
    made the handler raise before terminating, so Ctrl+C left the program
    running.

    Args:
        a: signal number (unused).
        b: current stack frame (unused).
    """
    try:
        core = (config or {}).get("core") or {}
        if core.get("server") == "Hyper-V":
            run_cmd("Stop-VM -Name Jukebox*")  # any and all VMs starting with "Jukebox"
    finally:
        killall()

def error(msg, *args):
    """Log ``msg % args`` at ERROR level on the multiprocessing package logger."""
    mp_logger = multiprocessing.get_logger()
    return mp_logger.error(msg, *args)

class LogExceptions(object):
    """Callable wrapper that logs a worker exception's traceback before re-raising.

    This file never calls ``.get()`` on the pool's AsyncResult objects, so
    without this wrapper an exception raised in a pool task would go
    unreported. The traceback is logged through ``error`` (the multiprocessing
    logger).
    """

    def __init__(self, callable):
        self.__callable = callable

    def __call__(self, *args, **kwargs):
        try:
            outcome = self.__callable(*args, **kwargs)
        except Exception:
            # Log here, where the traceback is still available, then re-raise
            # the original exception so the Pool worker can clean up.
            error(traceback.format_exc())
            raise
        else:
            return outcome

class LoggingPool(Pool):
    """``multiprocessing.pool.Pool`` whose ``apply_async`` logs worker exceptions.

    Every submitted function is wrapped in ``LogExceptions``, so a failing task
    leaves a traceback in the multiprocessing log even though the AsyncResult
    is never inspected.
    """

    def apply_async(self, func, args=(), kwds=None, callback=None, error_callback=None):
        """Like ``Pool.apply_async``, but run ``func`` under ``LogExceptions``.

        ``kwds`` defaults to None (meaning no keyword arguments) rather than a
        shared mutable ``{}``. ``error_callback`` is forwarded to the base
        class; it was previously not accepted at all.
        """
        if kwds is None:
            kwds = {}
        return super().apply_async(LogExceptions(func), args, kwds, callback, error_callback)


if __name__ == "__main__":
    #sys.stdout = Logger(filename="output.log")
    #sys.stderr = Logger(filename="output.log")
    #log_to_stderr(logging.DEBUG)
    fprint("Starting Jukebox control system...")
    with open('config.yml', 'r') as fileread:
        config = yaml.safe_load(fileread)
    fprint("Config loaded.")
    with Manager() as manager:
        fprint("Spawning threads...")
        pool = LoggingPool(processes=10)
        counter = 0
        # Shared kill flag; mainloop_server calls killall() once it is > 0.
        killme = manager.Value('d', 0)
        signal.signal(signal.SIGINT, killall_signal)
        # Read the run mode defensively: a missing "core"/"mode" entry must reach
        # the "assuming server" fallback below instead of raising KeyError.
        # (Named run_mode so it does not clobber the state-machine global `mode`.)
        run_mode = ((config or {}).get("core") or {}).get("mode")
        if run_mode == "winclient":
            fprint("Starting in client mode.")
            # Still module scope, so this binds the global `webdriver` that
            # setup_client uses; only the client path needs selenium.
            from selenium import webdriver
            if setup_client(pool):
                fprint("Entering main loop...")
                while keeprunning:
                    mainloop_client(pool)
        else:
            if run_mode != "linuxserver":
                fprint("Mode unspecified - assuming server")
            fprint("Starting in server mode.")
            if setup_server(pool):
                fprint("Entering main loop...")
                while keeprunning:
                    mainloop_server(pool)