#!/usr/bin/env python3

import sacn
import time
import sys
import yaml
import math
import random
from util import fprint
import platform    # For getting the operating system name
import subprocess  # For executing a shell command
from util import win32
import cv2
import numpy as np
from uptime import uptime

# --- Module-level state shared by the functions in this file ---
sender = None            # sACN sender object; created in init()
debug = True             # when true, map() writes a scatter plot of the LED layout to map.png
config = None            # parsed contents of config.yml; loaded in map()
leds = None              # list of (x, y) LED positions indexed by LED number; built in map()
leds_size = None         # per-LED channel count (length of the controller "mode" entry); built in map()
leds_normalized = None   # LED positions shifted to a zero origin and scaled per axis; built in map()
controllers = None       # per-universe (ledstart, ledend + 1, ip) tuples; built in map()
data = None              # current output colour tuple for every LED; filled in init()
exactdata = None         # per-LED un-rounded colour while a fade is in progress, otherwise None
rings = None             # per-ring (x, y, first LED, one-past-last LED) tuples; built in map()
ringstatus = None        # per-ring [flag, remaining fade steps] pairs used by runmodes()
mode = "Startup"         # name of the animation mode currently executed by runmodes()
firstrun = True          # set by setmode() when the mode changes so the new mode can initialise
changecount = 0          # remaining steps of the animation currently running in runmodes()
animation_time = 0       # animation length in steps; read from config["animation_time"] in map()
start = uptime()         # time of the previous frame, used for frame pacing in mainloop()/mapimage()

def ping(host):
    """Return True if host (str) answers a single ping request.

    One echo request is sent with the system ``ping`` command.  The count and
    timeout flags differ between platforms: ``-n 1 -w 250`` when ``win32`` is
    set, ``-c 1 -W 0.25`` otherwise.  All output of the command is discarded;
    only its exit status is used.
    """
    if win32:
        count_flag, timeout_flag, timeout_value = '-n', '-w', '250'
    else:
        count_flag, timeout_flag, timeout_value = '-c', '-W', '0.25'

    # Resulting command looks like: "ping -c 1 -W 0.25 google.com"
    exit_code = subprocess.call(
        ['ping', count_flag, '1', timeout_flag, timeout_value, host],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
    )
    return exit_code == 0

def map():
    """Load ``config.yml`` and build the LED layout and controller tables.

    Fills the module globals ``config``, ``animation_time``, ``leds``,
    ``leds_size``, ``leds_normalized``, ``controllers``, ``rings`` and
    ``ringstatus``.  When ``debug`` is set, a scatter plot of the layout
    (one colour per controller) is written to ``map.png``.

    Layout shapes come from ``config["led"]["map"]``:

    * ``circle`` - ``size`` LEDs spread evenly around a ring of the given
      ``diameter``, starting at ``angle`` degrees.
    * ``strip`` - ``size`` LEDs along a line of ``length`` at ``angle`` degrees.

    LED numbers that no shape defines are placed at (0, 0) after a warning.

    NOTE: the function name shadows the ``map`` builtin for the rest of the
    module; it is kept because callers use this name.
    """
    global config
    global leds
    global leds_size
    global leds_normalized
    global controllers
    global rings
    global ringstatus
    global animation_time

    with open('config.yml', 'r') as fileread:
        config = yaml.safe_load(fileread)

    animation_time = config["animation_time"]
    leds = list()
    leds_size = list()
    controllers = list()
    # Placeholders (the index itself) until a circle shape claims the slot below.
    rings = list(range(len(config["position_map"])))
    ringstatus = list(range(len(config["position_map"])))

    # Developer switch: when enabled, dumps the ring centres as a plot and as YAML.
    generate_map = False
    ring_centres = list()
    for shape in config["led"]["map"]:
        if shape["type"] == "circle":
            # pos is stored as (y, x) in the config; swapped here
            # ("flip by 90 degrees when we changed layout").
            centre = (shape["pos"][1], shape["pos"][0])
            if generate_map:
                ring_centres.append(centre)
            anglediv = 360.0 / shape["size"]
            angle = 0
            radius = shape["diameter"] / 2
            lednum = shape["start"]
            # Register the ring under the index whose position matches this centre.
            for item in config['position_map']:
                if tuple(item['pos']) == centre:
                    # rings[index] = x, y, startpos, endpos
                    rings[item["index"]] = (centre[0], centre[1], lednum, lednum + shape["size"])
                    ringstatus[item["index"]] = [None, None]
                    break
            missing = lednum + shape["size"] - len(leds)
            if missing > 0:
                leds.extend([None] * missing)
                leds_size.extend([None] * missing)
            while angle < 359.999:
                tmpangle = angle + shape["angle"]
                x = math.cos(tmpangle * (math.pi / 180.0)) * radius + centre[0]
                y = math.sin(tmpangle * (math.pi / 180.0)) * radius + centre[1]
                leds[lednum] = (x, y)
                lednum = lednum + 1
                angle = angle + anglediv

        elif shape["type"] == "strip":
            angle = shape["angle"]
            lednum = shape["start"]
            length = shape["length"]
            distdiv = length / shape["size"]
            dist = distdiv / 2
            xmov = math.cos(angle * (math.pi / 180.0)) * distdiv
            ymov = math.sin(angle * (math.pi / 180.0)) * distdiv
            # Work on a copy: stepping along the strip must not rewrite the
            # position stored in the loaded config.
            pos = list(shape["pos"])
            missing = lednum + shape["size"] - len(leds)
            if missing > 0:
                leds.extend([None] * missing)
                leds_size.extend([None] * missing)

            while dist < length:
                leds[lednum] = (pos[0], pos[1])
                pos[0] += xmov
                pos[1] += ymov
                dist += distdiv
                lednum = lednum + 1

    if generate_map:
        ring_centres = sorted(ring_centres, key=lambda x: (-x[1], x[0]))
        print(ring_centres)
        import matplotlib.pyplot as plt
        plt.axis('equal')
        x, y = zip(*ring_centres)
        plt.scatter(x, y, s=12)
        for i, (x_pos, y_pos) in enumerate(ring_centres):
            plt.text(x_pos, y_pos, str(i), color="red", fontsize=12)
        plt.savefig("map2.png", dpi=600, bbox_inches="tight")
        data = {"map": [{"index": i, "pos": str(list(pos))} for i, pos in enumerate(ring_centres)]}
        yaml_str = yaml.dump(data, default_flow_style=False)
        print(yaml_str)

    print(rings)
    if any(led is None for led in leds):
        fprint("Warning: Imperfect LED map ordering. Hiding undefined lights.")
        for x in range(len(leds)):
            if leds[x] is None:
                leds[x] = (0, 0)

    # controller mapping: controllers[universe] = (first LED, one-past-last LED, ip)
    for ctrl in config["led"]["controllers"]:
        if len(controllers) < ctrl["universe"] + 1:
            controllers.extend([None] * (ctrl["universe"] + 1 - len(controllers)))

        controllers[ctrl["universe"]] = (ctrl["ledstart"], ctrl["ledend"] + 1, ctrl["ip"])
        for x in range(ctrl["ledstart"], ctrl["ledend"] + 1):
            leds_size[x] = len(ctrl["mode"])

    if debug:
        import matplotlib.pyplot as plt
        plt.axis('equal')
        for ctrl in controllers:
            if ctrl is None:
                continue  # universe number not used by any controller
            plt.scatter(*zip(*leds[ctrl[0]:ctrl[1]]), s=2)
        plt.savefig("map.png", dpi=600, bbox_inches="tight")

    # Shift the layout so it starts at (0, 0), then scale each axis to 0..1.
    # The extremes are computed once instead of once per LED.
    min_x = min(led[0] for led in leds)
    min_y = min(led[1] for led in leds)
    leds_adj = [(x - min_x, y - min_y) for x, y in leds]

    # "or 1.0" keeps a layout whose LEDs all share one coordinate from dividing by zero.
    max_x = max(led[0] for led in leds_adj) or 1.0
    max_y = max(led[1] for led in leds_adj) or 1.0
    leds_normalized = [(x / max_x, y / max_y) for x, y in leds_adj]

def init():
    """Build the LED map, start the sACN sender and run the start-up test.

    After map() has filled the layout tables, one sACN output is activated
    per entry in ``controllers`` (universe number = list index + 1) and
    pointed at that controller's IP.  ``data`` is filled with a start-up
    colour chosen by each LED's channel count, ``exactdata`` with None, the
    frame is sent, every LED is set to (0, 60, 144) and sent again, and
    finally everything is faded out with alloffsmooth().
    """
    global sender
    global config
    global leds
    global leds_size
    global controllers
    global data
    global exactdata
    map()
    sender = sacn.sACNsender(fps=config["led"]["fps"], universeDiscovery=False)
    sender.start()  # start the sending thread
    """for x in range(len(controllers)):
        print("Waiting for the controller at", controllers[x][2], "to be online...", end="")
        count = 0
        while not ping(controllers[x][2]):
            count = count + 1
            if count >= config["led"]["timeout"]:
                fprint(" ERROR: controller still offline after " + str(count) + " seconds, continuing...")
                break
        if count < config["led"]["timeout"]:
            fprint(" done")"""
    for index, ctrl in enumerate(controllers):
        universe = index + 1
        print("Activating controller", index, "at", ctrl[2], "with", ctrl[1]-ctrl[0], "LEDs.")
        sender.activate_output(universe)  # start sending out data
        sender[universe].destination = ctrl[2]
    sender.manual_flush = True

    # initialize global pixel data list; colour depends on the LED's channel count
    startup_colour = {3: (20,20,127), 4: (50,50,255,0)}
    data = []
    for index, _ in enumerate(leds):
        data.append(startup_colour.get(leds_size[index], (0,0,0)))
    exactdata = [None] * len(leds)
    sendall(data)

    fprint("Running start-up test sequence...")
    for index in range(len(leds)):
        setpixel(0,60,144,index)
    sendall(data)
    alloffsmooth()

def sendall(datain):
    """Send one frame of LED data to every controller and flush it.

    datain must hold one colour tuple per LED, e.g. ``[(R, G, B), ...]``,
    indexed by LED number.  Each controller receives the flattened channel
    values of its own LED range on universe ``index + 1``.

    A short sleep follows the flush.  The original author noted that the
    transfer was "100% reliable with 2 flushes, often fails with 1"; the
    double-flush variant is currently disabled.
    """
    sender.manual_flush = True
    for index, ctrl in enumerate(controllers):
        pixels = datain[ctrl[0]:ctrl[1]]
        # Flatten in a single pass.  sum(pixels, ()) would copy the growing
        # tuple once per pixel, which is quadratic in the pixel count.
        sender[index + 1].dmx_data = [channel for pixel in pixels for channel in pixel]

    sender.flush()
    time.sleep(0.002)

def fastsendall(datain):
    """Send one frame of LED data to every controller without the post-flush sleep.

    datain must hold one colour tuple per LED, e.g. ``[(R, G, B), ...]``,
    indexed by LED number.  Unlike sendall(), ``manual_flush`` is switched
    off before the data is assigned, and no sleep follows the flush.

    While ``debug`` is set, the pixel data of the first controller is printed
    for every frame.
    """
    sender.manual_flush = False
    # Diagnostic dump; skipped when there is no controller to index.
    if debug and controllers:
        print(datain[controllers[0][0]:controllers[0][1]])
    for index, ctrl in enumerate(controllers):
        pixels = datain[ctrl[0]:ctrl[1]]
        # Single-pass flatten; sum(pixels, ()) is quadratic in the pixel count.
        sender[index + 1].dmx_data = [channel for pixel in pixels for channel in pixel]

    sender.flush()

def senduniverse(datain, lednum):
    """Refresh the universe that contains LED ``lednum`` and flush.

    datain must hold one colour tuple per LED, e.g. ``[(R, G, B), ...]``,
    indexed by LED number.  Only the controller whose LED range contains
    ``lednum`` gets new channel data; the flush itself is issued on the
    sender as a whole, followed by a short sleep.
    """
    for index, ctrl in enumerate(controllers):
        if ctrl[0] <= lednum < ctrl[1]:
            pixels = datain[ctrl[0]:ctrl[1]]
            # Single-pass flatten; sum(pixels, ()) is quadratic in the pixel count.
            sender[index + 1].dmx_data = [channel for pixel in pixels for channel in pixel]

    sender.flush()
    time.sleep(0.004)

def alloff():
    """Send an all-zero frame to every controller.

    The global ``data`` list is left untouched; a temporary frame is built
    with four zero channels for 4-channel LEDs and three for all others.
    """
    blank = [
        (0,0,0,0) if leds_size[index] == 4 else (0,0,0)
        for index in range(len(leds))
    ]
    sendall(blank)

def allon():
    """Re-send the current contents of the global ``data`` list to all controllers."""
    current_frame = data
    sendall(current_frame)

def alloffsmooth():
    """Fade every LED to black, one unit per channel per frame, then switch all off.

    At most 256 frames are sent; the loop stops as soon as every pixel in
    ``data`` is already zero.  For 4-channel pixels setpixel() stores the
    shared white part in the fourth channel, so it is added back to the three
    colour channels before the decrement.  Without that the white channel
    would jump instead of fading.
    """
    for _ in range(256):
        if not any(any(pixel) for pixel in data):
            break  # everything is dark already; no point sending more frames
        for index, pixel in enumerate(data):
            white = pixel[3] if len(pixel) > 3 else 0
            setpixel(pixel[0] + white - 1, pixel[1] + white - 1, pixel[2] + white - 1, index)
        sendall(data)

    alloff()

def setpixelnow(r, g, b, num):
    """Set LED ``num`` to (r, g, b) and push the change out immediately.

    Slight optimization: only the universe containing the LED gets new data.
    Unfortunately this sACN library offers no way to manually flush data
    packets to only one controller.
    """
    setpixel(r, g, b, num)
    senduniverse(data, num)

def setmode(stmode, r=0,g=0,b=0):
    """Switch the global animation mode to ``stmode``.

    Passing None leaves everything unchanged.  ``firstrun`` is raised only
    when the new mode differs from the current one.  The r, g and b
    arguments are accepted but not used.
    """
    global mode
    global firstrun
    if stmode is None:
        return

    mode_changed = mode != stmode
    mode = stmode
    if mode_changed:
        firstrun = True
    

def setring(r,g,b,idx):
    """Set every LED of ring ``idx`` to (r, g, b) in ``data`` without sending."""
    first_led, stop_led = rings[idx][2], rings[idx][3]
    led = first_led
    while led < stop_led:
        setpixel(r, g, b, led)
        led += 1

def runmodes(ring = -1, speed = 1):
    """Advance the animation for the current global ``mode`` by one frame and send it.

    Modes:

    * ``Startup``      - ordered fade of all LEDs to (0, 50, 100) over
      ``animation_time * 3`` steps, then switches to ``Startup2``.
    * ``Startup2``     - rings whose status flag is set are held at
      (0, 50, 100); the others fade to (100, 0, 0).
    * ``StartupCheck`` - on entry all ring flags are cleared; flagged rings
      fade to (0, 50, 100), the others are held at (100, 0, 0).
    * ``GrabA`` / ``GrabB`` / ``GrabC`` - fade ring ``ring`` to a colour over
      ``animation_time`` steps, then move on (GrabA -> GrabB, others -> idle).
    * ``idle``         - nothing to animate.

    ``speed`` is accepted but not used.
    """
    global mode
    global firstrun
    global changecount

    # mode -> (fade function, target colour, mode to enter when finished)
    grab_steps = {
        "GrabA": (fadeall, (100,0,0), "GrabB"),
        "GrabB": (fadeorder, (0,100,0), "idle"),
        "GrabC": (fadeall, (0,50,100), "idle"),
    }

    fprint("Mode: " + str(mode))
    if mode == "Startup":
        # loading animation. cable check
        if firstrun:
            changecount = animation_time * 3
            firstrun = False
            for index in range(len(ringstatus)):
                ringstatus[index] = [True, animation_time]

        if changecount > 0:
            fprint(changecount)
            changecount = fadeorder(0,len(leds), changecount, 0,50,100)
        else:
            setmode("Startup2")

    elif mode == "Startup2":
        if firstrun:
            firstrun = False
        else:
            for index, status in enumerate(ringstatus):
                if status[0]:
                    setring(0, 50, 100, index)
                else:
                    # not ready
                    status[1] = fadeall(rings[index][2],rings[index][3], status[1], 100,0,0)

    elif mode == "StartupCheck":
        if firstrun:
            firstrun = False
            for index in range(len(ringstatus)):
                ringstatus[index] = [False, animation_time]
        else:
            for index, status in enumerate(ringstatus):
                if status[0]:
                    # ready
                    status[1] = fadeall(rings[index][2],rings[index][3], status[1], 0,50,100)
                else:
                    setring(100, 0, 0, index)

    elif mode in grab_steps:
        fader, colour, next_mode = grab_steps[mode]
        if firstrun:
            firstrun = False
            changecount = animation_time # 100hz
        if changecount > 0:
            changecount = fader(rings[ring][2],rings[ring][3], changecount, colour[0],colour[1],colour[2])
        else:
            setring(colour[0],colour[1],colour[2],ring)
            setmode(next_mode)

    elif mode == "idle":
        time.sleep(0)

    sendall(data)
            
def fadeall(idxa,idxb,sizerem,r,g,b):
    """Move LEDs ``idxa`` .. ``idxb - 1`` one step towards (r, g, b), all at once.

    Every LED covers ``1 / sizerem`` of its remaining distance to the target,
    so after ``sizerem`` calls with the returned counter the target is
    reached.  Un-rounded intermediate colours are kept in ``exactdata``
    while the fade runs and cleared on the final step.

    Returns the number of steps left (0 when ``sizerem`` is below 1).  When
    no channel of any LED changed, the counter is cut short so the fade ends
    after at most one more call.
    """
    if sizerem < 1:
        return 0
    total_change = 0
    for x in range(idxa,idxb):
        if exactdata[x] is None:
            # Seed from the current output.  4-channel pixels keep their shared
            # white part in the fourth channel (see setpixel), so add it back
            # to get the colour that was originally requested.
            pixel = data[x]
            white = pixel[3] if len(pixel) > 3 else 0
            exactdata[x] = (pixel[0] + white, pixel[1] + white, pixel[2] + white)
        old = exactdata[x]
        step_r = (r - old[0])/sizerem
        step_g = (g - old[1])/sizerem
        step_b = (b - old[2])/sizerem
        # Sum the per-channel steps (not the resulting values) so that a fade
        # with nothing left to do is detected for every channel, blue included.
        total_change += abs(step_r) + abs(step_g) + abs(step_b)
        new = (old[0] + step_r, old[1] + step_g, old[2] + step_b)
        exactdata[x] = new
        setpixel(new[0], new[1], new[2], x)
        if sizerem == 1:
            exactdata[x] = None
    if total_change == 0 and sizerem > 2:
        sizerem = 2
    return sizerem - 1
        
def fadeorder(idxa,idxb,sizerem,r,g,b):
    """Move LEDs ``idxa`` .. ``idxb - 1`` one step towards (r, g, b), in LED order.

    The total remaining change of the whole range is divided by ``sizerem``
    to get this step's budget per channel.  The budget is then spent LED by
    LED: each LED is filled up to the target until the budget runs out, which
    produces a wipe along the range.  Un-rounded intermediate colours are
    kept in ``exactdata`` and cleared on the final step.

    Returns the number of steps left (0 when ``sizerem`` is below 1).  When
    the budget is zero on every channel, the counter is cut short so the fade
    ends after at most one more call.
    """
    if sizerem < 1:
        return 0
    target = (r, g, b)

    # Pass 1: total remaining change per channel over the whole range.
    budget = [0, 0, 0]
    for x in range(idxa,idxb):
        if exactdata[x] is None:
            # Seed from the current output.  4-channel pixels keep their shared
            # white part in the fourth channel (see setpixel), so add it back.
            pixel = data[x]
            white = pixel[3] if len(pixel) > 3 else 0
            exactdata[x] = (pixel[0] + white, pixel[1] + white, pixel[2] + white)
        old = exactdata[x]
        for ch in range(3):
            budget[ch] += target[ch] - old[ch]

    budget = [total / sizerem for total in budget]
    total_change = abs(budget[0]) + abs(budget[1]) + abs(budget[2])
    if debug:
        print(budget[0], budget[1], budget[2])

    # Pass 2: spend the budget in LED order.
    for x in range(idxa,idxb):
        # Decide before spending: the LED that uses up the last of the budget
        # (including the final LED on the last step) must still be written.
        had_budget = budget[0] != 0 or budget[1] != 0 or budget[2] != 0
        old = exactdata[x]
        new = list(old)
        for ch in range(3):
            left = budget[ch]
            if left > 0:
                if old[ch] + left > target[ch]:
                    new[ch] = target[ch]
                    # max() guards against rounding pushing the leftover below
                    # zero, which would make every later LED snap to the target.
                    left = max(left - (target[ch] - old[ch]), 0)
                else:
                    new[ch] = old[ch] + left
                    left = 0
            elif left < 0:
                if old[ch] + left < target[ch]:
                    new[ch] = target[ch]
                    left = min(left - (target[ch] - old[ch]), 0)
                else:
                    new[ch] = old[ch] + left
                    left = 0
            budget[ch] = left

        if had_budget:
            exactdata[x] = new
            setpixel(new[0],new[1],new[2],x)

        if sizerem == 1:
            exactdata[x] = None

    if total_change == 0 and sizerem > 2:
        sizerem = 2
    return sizerem - 1


def setpixel(r, g, b, num):
    """Store colour (r, g, b) for LED ``num`` in the global ``data`` list.

    Each channel is limited to 0..255 and truncated to int.  For 4-channel
    LEDs the part common to all three channels is cut out and moved to the
    fourth (white) channel instead.  Nothing is sent to the controllers.
    """
    global data
    global leds_size

    # constrain values
    r = 0 if r < 0 else (255 if r > 255 else r)
    g = 0 if g < 0 else (255 if g > 255 else g)
    b = 0 if b < 0 else (255 if b > 255 else b)

    if leds_size[num] == 4:
        white = int(min(r, g, b))
        pixel = (int(r) - white, int(g) - white, int(b) - white, white)
    else:
        pixel = (int(r), int(g), int(b))
    data[num] = pixel
    

def close():
    """Wait half a second, then stop the sACN sender."""
    settle_seconds = 0.5
    time.sleep(settle_seconds)
    sender.stop()

def mapimage(image, fps=90):
    """Show ``image`` in a preview window and project it onto the LEDs.

    The call first waits until ``1 / fps`` seconds have passed since the
    previous frame and reports the achieved rate.  Each LED then samples the
    pixel at its normalised position scaled by the shorter image side; the
    sampled channels are halved and reordered from (b, g, r) to (r, g, b).
    LEDs that fall outside the image are set to black.  The frame is sent
    with fastsendall().
    """
    global start
    global data
    frame_period = 1/fps
    while uptime() - start < frame_period:
        time.sleep(0.00001)
    fprint(1 / (uptime() - start))
    start = uptime()

    height, width = image.shape[0], image.shape[1]
    scale = min(image.shape[0:2])

    cv2.imshow("video", image)
    cv2.waitKey(1)

    for index, (norm_x, norm_y) in enumerate(leds_normalized):
        col = int(round(norm_x * scale))
        row = int(round(norm_y * scale))
        if col >= width or row >= height:
            setpixel(0,0,0,index)
            continue
        sample = tuple(image[row, col])
        setpixel(sample[2]/2,sample[1]/2,sample[0]/2,index) # swap b & r

    fastsendall(data)

def mainloop(stmode, ring = -1, fps = 100, preview = False):
    """Run one paced frame of the animation state machine.

    Waits until ``1 / fps`` seconds have passed since the previous frame,
    reports the achieved frame rate, optionally switches to mode ``stmode``
    and then lets runmodes() advance and send the frame.

    stmode:  mode name to switch to, or None to keep the current mode.
    ring:    ring index passed on to runmodes().
    preview: when true, drawdata() is called after the frame is sent.
    """
    global start
    while uptime() - start < 1/fps:
        time.sleep(0.00001)
    fprint(1 / (uptime() - start))
    start = uptime()
    # Test the requested mode, not the global ``mode`` (which is never None).
    if stmode is not None:
        setmode(stmode)
    runmodes(ring)
    if preview:
        drawdata()

def drawdata():
    """Plot the current LED colours at their layout positions.

    The plot is saved to ``map3.png`` and shown, then the figure is cleared.
    Only the first three channels of each pixel in ``data`` are used for the
    colour.
    """
    # Imported here because the module itself only imports pyplot inside
    # map() and the __main__ block; without this the function fails with a
    # NameError when the file is used as a library.
    import matplotlib.pyplot as plt

    xs = [led[0] for led in leds]
    ys = [led[1] for led in leds]
    colors_normalized = [(pixel[0]/255, pixel[1]/255, pixel[2]/255) for pixel in data]
    # Plot the points
    plt.scatter(xs, ys, c=colors_normalized)

    # Optional: add grid, title, and labels
    plt.grid(True)
    plt.title('Colored Points')
    plt.xlabel('X')
    plt.ylabel('Y')
    # Save before show(): with an interactive backend the figure may already
    # be gone once show() returns, which would leave an empty image on disk.
    plt.savefig("map3.png", dpi=50, bbox_inches="tight")
    plt.show()
    plt.clf()

def startup_animation(show):
    """Play the start-up sequence.

    Runs the ``Startup`` mode until it hands over to the next mode, then
    clears the status flag of every ring one frame at a time, keeps running
    for ``animation_time`` more frames and finally switches to
    ``StartupCheck``.

    show: passed to mainloop() as ``preview``.
    """
    mainloop("Startup", preview=show)
    while mode == "Startup":
        mainloop(None, preview=show)

    # Walk over every configured ring instead of assuming there are exactly 54.
    for x in range(len(ringstatus)):
        ringstatus[x][0] = False
        mainloop(None, preview=show)

    for x in range(animation_time):
        mainloop(None, preview=show)
    clear_animations()
    mainloop("StartupCheck", preview=show)
    clear_animations()

def clear_animations():
    """Forget the un-rounded fade colour of every LED (one ``exactdata`` slot per entry in ``leds``)."""
    for index, _ in enumerate(leds):
        exactdata[index] = None

def _preview_flag(show):
    """Return ``show``, or the module-level ``show`` (False if undefined) when it is None."""
    if show is None:
        return globals().get("show", False)
    return show


def do_animation(stmode, ring=-1, show=None):
    """Start animation ``stmode`` on ``ring`` and run frames until the mode is ``idle``.

    show: preview flag for mainloop().  When omitted, the module-level
    ``show`` set by the ``__main__`` script is used if it exists, otherwise
    previews are off.  This keeps the function usable when the file is
    imported rather than run as a script.
    """
    show = _preview_flag(show)
    mainloop(stmode, ring, preview=show)
    wait_for_animation(ring, show)


def start_animation(stmode, ring=-1, show=None):
    """Start animation ``stmode`` on ``ring`` and run a single frame without waiting.

    show: preview flag for mainloop(); see do_animation() for the default.
    """
    mainloop(stmode, ring, preview=_preview_flag(show))


def wait_for_animation(ring=-1, show=None):
    """Run frames for ``ring`` until the global mode becomes ``idle``.

    show: preview flag for mainloop(); see do_animation() for the default.
    """
    show = _preview_flag(show)
    while mode != "idle":
        mainloop(None, ring, preview=show)

if __name__ == "__main__":
    # Demo / manual test sequence: start-up animation, mark every ring as
    # ready, then play a few grab animations on rings 1 and 5.
    init()
    import matplotlib.pyplot as plt
    """cap = cv2.VideoCapture('badapple.mp4')
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        mapimage(frame, fps=30)"""
    show = True  # module-level preview flag, also read by the animation helpers
    ring = 1
    startup_animation(show)
    # Flag every configured ring (not a hard-coded 54), one per frame.
    for x in range(len(ringstatus)):
        ringstatus[x][0] = True
        mainloop(None, preview=show)
    for x in range(animation_time):
        mainloop(None, preview=show)

    do_animation("GrabA", 1)

    do_animation("GrabA", 5)
    start_animation("GrabC", 1)

    wait_for_animation(1)
    do_animation("GrabC", 5)

    close()
    #sys.exit(0)


    # blue : default
    # green : target
    # yellow : crosshair
    # red : missing 
    # uninitialized : red/purple?