#!/usr/bin/env python3

import sacn
import time
import sys
import yaml
import math
import random
from util import fprint
import platform    # For getting the operating system name
import subprocess  # For executing a shell command
from util import win32
import cv2
import numpy as np
from uptime import uptime



class LEDSystem():
    """Drive an sACN LED installation laid out by ``config.yml``.

    ``map()`` builds pixel positions, ring index ranges and controller
    universes from the config; ``init()`` starts the sACN sender, waits for the
    controllers and plays the start-up animation. Pixel colours live in
    ``data`` (one tuple per pixel, 3 or 4 channels depending on the controller
    mode) and are pushed out with ``sendall``/``fastsendall``.
    """
    sender = None           # sacn.sACNsender, created in init()
    debug = True            # when True, map() saves a scatter plot of the layout to map.png
    config = None           # parsed config.yml
    leds = None             # [(x, y)] position of every pixel, indexed by LED number
    leds_size = None        # channels per pixel: len() of the controller's "mode" value, or None
    leds_normalized = None  # leds rescaled to 0..1 on each axis (used by mapimage)
    controllers = None      # controllers[universe - 1] = (ledstart, ledend exclusive, ip)
    data = None             # current output tuple for every pixel
    exactdata = None        # float (r, g, b) state of in-progress fades, None when not fading
    rings = None            # per ring index: (x, y, first_led, end_led exclusive)
    ringstatus = None       # per ring index: [ready_flag, fade_frames_remaining]
    mode = "Startup"        # current animation mode handled by runmodes()
    firstrun = True         # True on the first frame after a mode change
    changecount = 0         # frames remaining in the current mode's fade
    animation_time = 0      # fade length in frames, from config["animation_time"]
    start = 0.0             # uptime() of the last frame; set for real in __init__

    def __init__(self):
        """Record the start time; call init() separately to load config and connect."""
        self.start = uptime()

    def ping(self, host):
        """Return True if ``host`` (str) answers a single ping within about 250 ms.

        Uses the system ``ping`` binary; the count/timeout flags differ between
        Windows (``-n``/``-w 250``) and other platforms (``-c``/``-W 0.25``).
        """
        if win32:
            count_flag, wait_flag, wait_value = '-n', '-w', '250'
        else:
            count_flag, wait_flag, wait_value = '-c', '-W', '0.25'
        # e.g. "ping -c 1 -W 0.25 192.168.1.10"
        command = ['ping', count_flag, '1', wait_flag, wait_value, host]
        return subprocess.call(command, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) == 0

    def map(self):
        """Load config.yml and build the LED, ring and controller tables.

        Populates ``config``, ``animation_time``, ``leds``, ``leds_size``,
        ``rings``, ``ringstatus``, ``controllers`` and ``leds_normalized``.
        LEDs that no shape defines are parked at (0, 0) with a warning.
        """
        with open('config.yml', 'r') as fileread:
            self.config = yaml.safe_load(fileread)

        self.animation_time = self.config["animation_time"]
        self.leds = list()
        self.leds_size = list()
        self.controllers = list()
        # Placeholders: entries stay plain ints until a circle shape matches them.
        self.rings = list(range(len(self.config["position_map"])))
        print("Setting ring status")
        self.ringstatus = list(range(len(self.config["position_map"])))

        generate_map = False  # developer switch: plot and dump the ring centres
        centers = list()
        for shape in self.config["led"]["map"]:
            if shape["type"] == "circle":
                # Config positions are flipped relative to x/y (the layout was rotated 90 degrees).
                center = (shape["pos"][1], shape["pos"][0])
                if generate_map:
                    centers.append(center)
                lednum = shape["start"]
                size = shape["size"]
                for item in self.config['position_map']:
                    if tuple(item['pos']) == center:
                        # rings[index] = x, y, startpos, endpos (exclusive)
                        self.rings[item["index"]] = (center[0], center[1], lednum, lednum + size)
                        self.ringstatus[item["index"]] = [None, None]
                        break
                self._grow_led_lists(lednum + size)
                radius = shape["diameter"] / 2
                anglediv = 360.0 / size
                # One LED per step around the circle, starting at shape["angle"] degrees.
                for step in range(size):
                    theta = (step * anglediv + shape["angle"]) * (math.pi / 180.0)
                    self.leds[lednum + step] = (math.cos(theta) * radius + center[0],
                                                math.sin(theta) * radius + center[1])

            elif shape["type"] == "strip":
                lednum = shape["start"]
                size = shape["size"]
                theta = shape["angle"] * (math.pi / 180.0)
                step_len = shape["length"] / size
                xmov = math.cos(theta) * step_len
                ymov = math.sin(theta) * step_len
                # Read the start point without mutating the config entry.
                start_x, start_y = shape["pos"][0], shape["pos"][1]
                self._grow_led_lists(lednum + size)
                for step in range(size):
                    self.leds[lednum + step] = (start_x + step * xmov, start_y + step * ymov)

        if generate_map:
            centers = sorted(centers, key=lambda c: (-c[1], c[0]))
            print(centers)
            import matplotlib.pyplot as plt
            plt.axis('equal')
            xs, ys = zip(*centers)
            plt.scatter(xs, ys, s=12)
            for i, (x_pos, y_pos) in enumerate(centers):
                plt.text(x_pos, y_pos, str(i), color="red", fontsize=12)
            plt.savefig("map2.png", dpi=600, bbox_inches="tight")
            dump = {"map": [{"index": i, "pos": str(list(pos))} for i, pos in enumerate(centers)]}
            print(yaml.dump(dump, default_flow_style=False))

        print(self.rings)
        if any(led is None for led in self.leds):
            fprint("Warning: Imperfect LED map ordering. Hiding undefined lights.")
            for num, led in enumerate(self.leds):
                if led is None:
                    self.leds[num] = (0, 0)

        # controller mapping: controllers[universe - 1] = (ledstart, ledend exclusive, ip)
        for ctrl in self.config["led"]["controllers"]:
            universe = ctrl["universe"]
            if len(self.controllers) < universe:
                print(universe)
                self.controllers.extend([None] * (universe - len(self.controllers)))
            self.controllers[universe - 1] = (ctrl["ledstart"], ctrl["ledend"] + 1, ctrl["ip"])
            channels = len(ctrl["mode"])
            for num in range(ctrl["ledstart"], ctrl["ledend"] + 1):
                self.leds_size[num] = channels

        if self.debug:
            import matplotlib.pyplot as plt
            plt.axis('equal')
            for ctrl in self.controllers:
                plt.scatter(*zip(*self.leds[ctrl[0]:ctrl[1]]), s=2)
            plt.savefig("map.png", dpi=600, bbox_inches="tight")

        self.leds_normalized = self._normalize_positions(self.leds)

    def _grow_led_lists(self, count):
        """Pad ``leds`` and ``leds_size`` with None so both hold at least ``count`` entries."""
        missing = count - len(self.leds)
        if missing > 0:
            self.leds.extend([None] * missing)
            self.leds_size.extend([None] * missing)

    @staticmethod
    def _normalize_positions(points):
        """Shift ``points`` to start at zero and scale each axis into 0..1.

        Min and max are computed once (linear time). An axis with zero span
        maps to 0 instead of dividing by zero. Empty input returns [].
        """
        if not points:
            return []
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        min_x, min_y = min(xs), min(ys)
        span_x = (max(xs) - min_x) or 1
        span_y = (max(ys) - min_y) or 1
        return [((x - min_x) / span_x, (y - min_y) / span_y) for x, y in points]

    def init(self):
        """Build the map, start the sACN sender, wait for controllers and run the start-up animation."""
        self.map()
        self.sender = sacn.sACNsender(fps=self.config["led"]["fps"], universeDiscovery=False)
        self.sender.start()  # start the sending thread
        for ctrl in self.controllers:
            ip = ctrl[2]
            print("Waiting for the controller at", ip, "to be online...", end="", flush=True)
            # NOTE(review): each failed attempt lasts at most the ping timeout (~0.25 s),
            # so `count` counts attempts rather than seconds; confirm the unit of led.timeout.
            count = 0
            while not self.ping(ip):
                count = count + 1
                if count >= self.config["led"]["timeout"]:
                    print(" ERROR: controller still offline after " + str(count) + " seconds, continuing...")
                    break
            else:
                print(" done")

        time.sleep(1)
        for universe, ctrl in enumerate(self.controllers, start=1):
            print("Activating controller", universe, "at", ctrl[2], "with", ctrl[1]-ctrl[0], "LEDs.")
            self.sender.activate_output(universe)  # start sending out data
            self.sender[universe].destination = ctrl[2]
        self.sender.manual_flush = True

        # initialize global pixel data list
        self.data = list()
        for num in range(len(self.leds)):
            if self.leds_size[num] == 3:
                self.data.append((20, 20, 127))
            elif self.leds_size[num] == 4:
                self.data.append((50, 50, 255, 0))
            else:
                self.data.append((0, 0, 0))
        self.exactdata = [None] * len(self.leds)
        self.sendall(self.data)
        self.startup_animation(show=False)

    @staticmethod
    def _flatten(pixels):
        """Concatenate per-pixel channel tuples into one flat channel list (linear time)."""
        return [channel for pixel in pixels for channel in pixel]

    def sendall(self, datain):
        """Send all LED data to all controllers, then pause 2 ms.

        ``datain`` holds one channel tuple per pixel, indexed by LED number.
        """
        self.fastsendall(datain)
        time.sleep(0.002)

    def fastsendall(self, datain):
        """Send all LED data to all controllers in a single flush, without pausing."""
        self.sender.manual_flush = True
        for universe, (first, end, _ip) in enumerate(self.controllers, start=1):
            self.sender[universe].dmx_data = self._flatten(datain[first:end])
        self.sender.flush()

    def senduniverse(self, datain, lednum):
        """Refresh only the universe that contains ``lednum``, flush, then pause 4 ms."""
        for universe, (first, end, _ip) in enumerate(self.controllers, start=1):
            if first <= lednum < end:
                self.sender[universe].dmx_data = self._flatten(datain[first:end])
        self.sender.flush()
        time.sleep(0.004)

    def alloff(self):
        """Send all-zero data (4 channels for 4-channel pixels) without modifying ``data``."""
        blank = [(0, 0, 0, 0) if self.leds_size[num] == 4 else (0, 0, 0)
                 for num in range(len(self.leds))]
        self.sendall(blank)
        return self

    def allon(self):
        """Re-send the current ``data``."""
        self.sendall(self.data)
        return self

    def alloffsmooth(self):
        """Dim every pixel by one RGB step per frame for 256 frames, then turn all off."""
        for _ in range(256):
            for num in range(len(self.data)):
                r, g, b = self._current_rgb(num)
                self.setpixel(r - 1, g - 1, b - 1, num)
            self.sendall(self.data)

        self.alloff()
        return self

    def setpixelnow(self, r, g, b, num):
        """Set one pixel and immediately send its universe."""
        # slight optimization: send only changed universe
        # unfortunately no way to manual flush data packets to only 1 controller with this sACN library
        self.setpixel(r, g, b, num)
        self.senduniverse(self.data, num)
        return self

    def setmode(self, stmode, r=0, g=0, b=0):
        """Switch to mode ``stmode`` (None keeps the current one); r, g, b are unused.

        Changing to a different mode sets ``firstrun`` so the mode can initialise.
        """
        if stmode is not None:
            if self.mode != stmode:
                self.firstrun = True
            self.mode = stmode
        return self

    def setrange(self, start, end, r, g, b):
        """Set pixels ``start`` .. ``end - 1`` to (r, g, b) via setpixel (clamped, 4-channel aware)."""
        for num in range(start, end):
            self.setpixel(r, g, b, num)

    def setallrings(self, r, g, b, exclude):
        """Set every mapped ring except ring ``exclude`` to (r, g, b).

        A negative ``exclude`` counts from the end, like list indexing. Ring
        entries that were never matched to a shape (still placeholder ints) are skipped.
        """
        skip = exclude
        if skip is not None and skip < 0:
            skip += len(self.rings)
        for idx, entry in enumerate(self.rings):
            if idx == skip or not isinstance(entry, tuple):
                continue
            self.setring(r, g, b, idx)
        return self

    def setring(self, r, g, b, idx):
        """Set every pixel of ring ``idx`` to (r, g, b) (not sent)."""
        ring = self.rings[idx]
        for pixel in range(ring[2], ring[3]):
            self.setpixel(r, g, b, pixel)
        return self

    def runmodes(self, ring = -1, arm_position = None):
        """Advance the current mode by one frame, then send ``data`` to the controllers.

        Args:
            ring: ring index used by the Grab*/Moving modes.
            arm_position: (x, y, radius) sequence, only read in "Moving" mode.
        """
        if self.mode == "Startup":
            # loading animation. cable check
            if self.firstrun:
                self.changecount = self.animation_time * 3
                self.firstrun = False
                for x in range(len(self.ringstatus)):
                    self.ringstatus[x] = [True, self.animation_time]

            if self.changecount > 0:
                self.changecount = self.fadeorder(0, len(self.leds), self.changecount, 0, 50, 100)
            else:
                self.setmode("Startup2")

        elif self.mode == "Startup2":
            if self.firstrun:
                self.firstrun = False
            else:
                for x in range(len(self.ringstatus)):
                    if self.ringstatus[x][0]:
                        self.setring(0, 50, 100, x)
                    else:
                        self.ringstatus[x][1] = self.fadeall(self.rings[x][2], self.rings[x][3], self.ringstatus[x][1], 100, 0, 0)  # not ready

        elif self.mode == "StartupCheck":
            if self.firstrun:
                self.firstrun = False
                for x in range(len(self.ringstatus)):
                    self.ringstatus[x] = [False, self.animation_time]
            else:
                for x in range(len(self.ringstatus)):
                    if self.ringstatus[x][0]:
                        self.ringstatus[x][1] = self.fadeall(self.rings[x][2], self.rings[x][3], self.ringstatus[x][1], 0, 50, 100)  # ready
                    else:
                        self.setring(100, 0, 0, x)

        elif self.mode == "GrabA":
            if self.firstrun:
                self.firstrun = False
                self.changecount = self.animation_time  # 100hz
            if self.changecount > 0:
                self.changecount = self.fadeall(self.rings[ring][2], self.rings[ring][3], self.changecount, 100, 0, 0)
            else:
                self.setring(100, 0, 0, ring)
                self.setmode("GrabB")
        elif self.mode == "GrabB":
            if self.firstrun:
                self.firstrun = False
                self.changecount = self.animation_time  # 100hz
            if self.changecount > 0:
                # Only the first 24 LEDs of the ring are filled in order.
                self.changecount = self.fadeorder(self.rings[ring][2], self.rings[ring][2]+24, self.changecount, 0, 100, 0)
            else:
                self.setring(0, 100, 0, ring)
                self.setmode("idle")
        elif self.mode == "GrabC":
            if self.firstrun:
                self.firstrun = False
                self.changecount = self.animation_time  # 100hz
            if self.changecount > 0:
                self.changecount = self.fadeall(self.rings[ring][2], self.rings[ring][3], self.changecount, 0, 50, 100)
            else:
                self.setring(0, 50, 100, ring)
                self.setmode("idle")

        elif self.mode == "Moving":
            if self.firstrun:
                self.firstrun = False
            self._paint_moving(ring, arm_position)

        elif self.mode == "idle":
            time.sleep(0)

        self.sendall(self.data)
        return self

    def _paint_moving(self, ring, arm_position):
        """Paint one "Moving" frame into ``data``.

        All rings except ``ring`` are reset to (0, 50, 100). Then every LED closer
        than the radius to the arm is blended from (0, 50, 100) toward
        (100, 100, 100) in proportion to its distance. ``arm_position`` is read as
        (x, y, radius) and multiplied by 1000 into LED-map units (presumably
        metres -> millimetres; TODO confirm against the caller).
        """
        px = int(arm_position[0] * 1000)
        py = int(arm_position[1] * 1000)
        radius = int(arm_position[2] * 1000)
        base = (0, 50, 100)
        target = (100, 100, 100)
        deltas = [target[ch] - base[ch] for ch in range(3)]
        # reset!
        self.setallrings(base[0], base[1], base[2], ring)
        # fade outwards: ratio 0 at the arm position, approaching 1 at the radius
        for idx, (lx, ly) in enumerate(self.leds):
            dist = int(math.hypot(int(px - lx), int(py - ly)))
            if dist < radius:
                ratio = dist / radius
                self.setpixel(base[0] + ratio * deltas[0],
                              base[1] + ratio * deltas[1],
                              base[2] + ratio * deltas[2], idx)

    def _current_rgb(self, num):
        """Return pixel ``num``'s colour from ``data`` as an (r, g, b) tuple.

        4-channel pixels are stored as (r - w, g - w, b - w, w) by setpixel, so the
        white channel is added back to recover the requested RGB value.
        """
        pixel = self.data[num]
        if len(pixel) == 4:
            white = pixel[3]
            return (pixel[0] + white, pixel[1] + white, pixel[2] + white)
        return tuple(pixel[:3])

    def fadeall(self, idxa, idxb, sizerem, r, g, b):
        """Fade LEDs ``idxa`` .. ``idxb - 1`` together toward (r, g, b).

        Each call moves every pixel 1/``sizerem`` of its remaining distance, keeping
        float state in ``exactdata``. Returns ``sizerem - 1`` (0 when finished). If
        nothing moved and more than two frames remain, the count is cut so that the
        next call is the final one, which clears ``exactdata`` for the range.
        """
        if sizerem < 1:
            return 0
        target = (r, g, b)
        total = 0
        for x in range(idxa, idxb):
            if self.exactdata[x] is None:
                self.exactdata[x] = self._current_rgb(x)
            old = self.exactdata[x]
            steps = [(target[ch] - old[ch]) / sizerem for ch in range(3)]
            total += abs(steps[0]) + abs(steps[1]) + abs(steps[2])
            new = (old[0] + steps[0], old[1] + steps[1], old[2] + steps[2])
            self.exactdata[x] = new
            self.setpixel(new[0], new[1], new[2], x)
            if sizerem == 1:
                self.exactdata[x] = None
        if total == 0 and sizerem > 2:
            sizerem = 2
        return sizerem - 1

    def fadeorder(self, idxa, idxb, sizerem, r, g, b):
        """Fade LEDs ``idxa`` .. ``idxb - 1`` toward (r, g, b), filling them in order.

        The total remaining colour distance of the range is split evenly over the
        ``sizerem`` frames left. Each frame's share is poured into the pixels one
        after another, so earlier pixels reach the target before later ones move.
        Returns ``sizerem - 1`` (0 when finished). If nothing is left to move and
        more than two frames remain, the count is cut so that the next call is the final one.
        """
        if sizerem < 1:
            return 0
        target = (r, g, b)
        # Total distance still to cover, per channel, across the whole range.
        remaining = [0, 0, 0]
        for x in range(idxa, idxb):
            if self.exactdata[x] is None:
                self.exactdata[x] = self._current_rgb(x)
            old = self.exactdata[x]
            for ch in range(3):
                remaining[ch] += target[ch] - old[ch]

        # This frame's share of the distance, per channel.
        budget = [delta / sizerem for delta in remaining]
        total = abs(budget[0]) + abs(budget[1]) + abs(budget[2])
        for x in range(idxa, idxb):
            old = self.exactdata[x]
            new = list(old)
            for ch in range(3):
                share = budget[ch]
                if share > 0:
                    if old[ch] + share > target[ch]:
                        # Pixel saturates; carry the leftover on to the next pixel.
                        new[ch] = target[ch]
                        budget[ch] -= target[ch] - old[ch]
                    else:
                        new[ch] = old[ch] + share
                        budget[ch] = 0
                elif share < 0:
                    if old[ch] + share < target[ch]:
                        new[ch] = target[ch]
                        budget[ch] -= target[ch] - old[ch]
                    else:
                        new[ch] = old[ch] + share
                        budget[ch] = 0
            # Write every pixel that moved, including the one that absorbed the
            # last of this frame's budget.
            if new != list(old):
                self.exactdata[x] = new
                self.setpixel(new[0], new[1], new[2], x)
            if sizerem == 1:
                self.exactdata[x] = None

        if total == 0 and sizerem > 2:
            sizerem = 2
        return sizerem - 1

    def setpixel(self, r, g, b, num):
        """Store (r, g, b), clamped to 0..255 and truncated to int, into ``data[num]``.

        4-channel pixels receive (r - w, g - w, b - w, w) with w = min(r, g, b), so
        the common grey component is shown on the white channel instead.
        """
        r = min(max(r, 0), 255)
        g = min(max(g, 0), 255)
        b = min(max(b, 0), 255)

        if self.leds_size[num] == 4:
            white = int(min(r, g, b))
            self.data[num] = (int(r) - white, int(g) - white, int(b) - white, white)
        else:
            self.data[num] = (int(r), int(g), int(b))
        return self

    def setpixelfast3(self, rgb, num):
        """Store ``rgb`` into ``data[num]`` as-is (no clamping or channel conversion)."""
        self.data[num] = rgb

    def close(self):
        """Pause 0.5 s, then stop the sACN sender."""
        time.sleep(0.5)
        self.sender.stop()
        return self

    def mapimage(self, image, fps=30):
        """Show ``image`` and sample it onto the LEDs at half brightness, paced to ``fps``.

        ``image`` is indexed as image[row, col] with channels in BGR order
        (OpenCV's default). The LED layout is fitted into the centred square of
        side min(height, width); LEDs outside the image are turned off.
        """
        self.start = uptime()
        height, width = image.shape[0], image.shape[1]
        minsize = min(height, width)
        if width > height:
            x_offset, y_offset = (width - height) / 2, 0
        else:
            x_offset, y_offset = 0, (height - width) / 2

        cv2.imshow("video", image)
        cv2.waitKey(1)

        for num, (nx, ny) in enumerate(self.leds_normalized):
            col = int(round(nx * minsize + x_offset))
            row = int(round(ny * minsize + y_offset))
            if col < width and row < height:
                color = tuple(image[row, col])
                # swap b & r (BGR -> RGB) and halve brightness
                self.setpixel(color[2] / 2, color[1] / 2, color[0] / 2, num)
            else:
                self.setpixel(0, 0, 0, num)
        self.fastsendall(self.data)
        while self.start + 1.0 / fps > uptime():
            time.sleep(0.00001)
        return self

    def mainloop(self, stmode, ring = -1, fps = 100, preview = False, arm_position = None):
        """Run one animation frame, paced to at most ``fps`` frames per second.

        Args:
            stmode: mode to switch to, or None to keep the current mode.
            ring: ring index passed through to runmodes().
            fps: frame-rate cap; polls until 1/fps has passed since the last frame.
            preview: when True, also plot the frame with drawdata().
            arm_position: passed through to runmodes() ("Moving" mode).
        """
        while uptime() - self.start < 1/fps:
            time.sleep(0.00001)
        self.start = uptime()
        if self.mode is not None:
            self.setmode(stmode)

        self.runmodes(ring, arm_position)
        if preview:
            self.drawdata()
        return self

    def drawdata(self):
        """Plot every LED at its position in its current colour and save map3.png."""
        import matplotlib.pyplot as plt
        xs = [led[0] for led in self.leds]
        ys = [led[1] for led in self.leds]
        colors_normalized = [tuple(channel / 255 for channel in self._current_rgb(num))
                             for num in range(len(self.data))]
        plt.scatter(xs, ys, c=colors_normalized)
        plt.grid(True)
        plt.title('Colored Points')
        plt.xlabel('X')
        plt.ylabel('Y')
        # Save before show(): closing the shown window can leave savefig an empty figure.
        plt.savefig("map3.png", dpi=50, bbox_inches="tight")
        plt.show()
        plt.clf()
        return self

    def startup_animation(self, show):
        """Play the start-up sequence and finish in "idle" mode.

        Sequence: the "Startup" fill, then every ring is marked not-ready one
        frame at a time (Startup2 fades not-ready rings to red), then
        ``animation_time`` more frames run before switching to idle.
        """
        self.mainloop("Startup", preview=show)
        while self.mode == "Startup":
            self.mainloop(None, preview=show)
        for idx in range(len(self.ringstatus)):
            self.ringstatus[idx][0] = False
            self.mainloop(None, preview=show)

        for _ in range(self.animation_time):
            self.mainloop(None, preview=show)
        self.clear_animations()
        self.mainloop("idle", preview=show)
        self.clear_animations()
        return self

    def clear_animations(self):
        """Drop all in-progress fade state."""
        for num in range(len(self.leds)):
            self.exactdata[num] = None
        return self

    def do_animation(self, stmode, ring=-1, show=False):
        """Start mode ``stmode`` on ``ring`` and block until the mode returns to "idle"."""
        self.mainloop(stmode, ring, preview=show)
        self.wait_for_animation(ring, show=show)
        return self

    def start_animation(self, stmode, ring=-1, show=False):
        """Start mode ``stmode`` on ``ring`` and run its first frame without waiting."""
        self.mainloop(stmode, ring, preview=show)
        return self

    def wait_for_animation(self, ring=-1, show=False):
        """Run frames until the current mode is "idle"."""
        while self.mode != "idle":
            self.mainloop(None, ring, preview=show)
        return self

if __name__ == "__main__":
    # Manual demo: play a video on the LED layout, then run the ring animations.
    # `plt` is bound at module scope here because drawdata() may rely on a global `plt`.
    import matplotlib.pyplot as plt
    ledsys = LEDSystem()
    ledsys.init()

    cap = cv2.VideoCapture('badapple.mp4')
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            ledsys.mapimage(frame, fps=120)
    finally:
        # Release the capture and close the "video" window shown by mapimage().
        cap.release()
        cv2.destroyAllWindows()

    # Kept at module scope: the animation helpers may read a global `show`.
    show = False

    ledsys.startup_animation(show)
    # Mark every ring as ready, advancing one frame per ring.
    for idx in range(len(ledsys.ringstatus)):
        ledsys.ringstatus[idx][0] = True
        ledsys.mainloop(None, preview=show)
    for _ in range(ledsys.animation_time):
        ledsys.mainloop(None, preview=show)

    ledsys.do_animation("GrabA", 1)

    ledsys.do_animation("GrabA", 5)
    ledsys.start_animation("GrabC", 1)

    ledsys.wait_for_animation(1)
    ledsys.do_animation("GrabC", 5)

    ledsys.close()

    # Colour legend:
    # blue : default
    # green : target
    # yellow : crosshair
    # red : missing
    # uninitialized : red/purple?