#!/usr/bin/env python3

# Parse Belden catalog techdata datasheets 

import glob
import io
import json
import math
import os
import uuid

import camelot
import numpy as np
from PIL import Image
from PyPDF2 import PdfReader

from util import fprint
from util import run_cmd

def parse(filename, output_dir, partnum, dstype):
    """Parse a techdata PDF datasheet into a spec dictionary and a search JSON file.

    Steps:
    1. Extract every lattice (ruled) table from the PDF and name it from the
       text just above it.
    2. Write the brand logo (brand.png) and part photo (part.png) found on
       page 1 into ``output_dir``.
    3. Convert the tables into nested dicts.
    4. Delete any existing ``*.json`` in ``output_dir``, then write the
       flattened specs to ``<output_dir>/<uuid>.json``.

    Args:
        filename: Path of the PDF datasheet.
        output_dir: Directory that receives the images and the JSON file
            (it must already exist; files are opened directly inside it).
        partnum: Part number recorded in the output.
        dstype: Datasheet type. Only ``"Belden"`` enables merging of tables
            that continue onto a later page.

    Returns:
        dict with keys ``partnum``, ``id`` (random UUID string), ``fullspecs``
        (table name -> field dict) and ``searchspecs`` (``flatten(fullspecs)``
        plus ``partnum`` and ``id``). The dict is also printed to stdout.
    """
    reader = PdfReader(filename)
    named_tables = _extract_named_tables(filename, reader)

    # Basic details: brand logo and part image are taken from the first page.
    _save_images(reader.pages[0], output_dir)

    tables = _build_spec_tables(named_tables, dstype)

    spec_id = str(uuid.uuid4())
    output_table = {
        "partnum": partnum,
        "id": spec_id,
        "fullspecs": tables,
        "searchspecs": {"partnum": partnum, **flatten(tables)},
    }
    output_table["searchspecs"]["id"] = spec_id

    print(output_table)

    # Remove JSON left by earlier runs without a shell, so directory names
    # containing spaces or glob/shell metacharacters are handled safely.
    for stale in glob.glob(os.path.join(glob.escape(output_dir), "*.json")):
        os.remove(stale)
    with open(os.path.join(output_dir, spec_id + ".json"), "w") as json_file:
        json.dump(output_table["searchspecs"], json_file)

    return output_table


def _extract_named_tables(filename, reader):
    """Return ``{name: DataFrame}`` for every non-empty lattice table in *filename*.

    A table is named from the text on its page whose text-matrix y offset
    (``tm[5]``) lies strictly between the table's top edge and 10 units above
    it. A table with no such text is named by its running index ("0", "1", ...).
    Later stages treat index-named tables as possible continuations of the
    previous table. If two tables get the same name, the later DataFrame
    replaces the earlier one but keeps the earlier position.
    """
    # NOTE(review): `interations` looks like a misspelling of camelot's lattice
    # `iterations` option. Confirm how camelot treats the unknown key before
    # renaming it: enabling dilation would change which lines are detected.
    extracted = camelot.read_pdf(
        filename, pages="1-end", flavor='lattice', backend="poppler",
        split_text=False, line_scale=100, process_background=True,
        resolution=600, interations=1,
        layout_kwargs={'detect_vertical': False, 'char_margin': 0.5},
        shift_text=['r', 't'],
    )
    named = {}
    index = 0
    for table in extracted:
        df = table.df
        # Drop rows/columns that are entirely blank, then turn the remaining
        # NaN placeholders back into empty strings.
        df.replace('', np.nan, inplace=True)
        df.dropna(inplace=True, how="all")
        df.dropna(inplace=True, axis="columns", how="all")
        df.replace(np.nan, '', inplace=True)
        if df.empty:
            continue

        top = table.cells[0][0].lt[1]  # y of the top-left cell's top-left corner
        page = reader.pages[table.page - 1]
        name = _text_in_band(page, top, top + 10).strip('\n')
        named[name or str(index)] = df
        index += 1
    return named


def _text_in_band(page, ymin, ymax):
    """Concatenate the text fragments of *page* whose tm[5] is strictly between ymin and ymax."""
    fragments = []

    def visitor(text, cm, tm, font_dict, font_size):
        if ymin < tm[5] < ymax:
            fragments.append(text)

    page.extract_text(visitor_text=visitor)
    return "".join(fragments)


def _save_images(page, output_dir):
    """Write brand.png and part.png from the images embedded in *page*.

    The first ``img0.png`` is written as brand.png unless it is palette-based
    (mode "P"). In that case it is skipped and ``img1.png`` is written as
    brand.png instead. Every image that is exactly 430x430 pixels is written
    as part.png, except a skipped palette-based ``img0.png``.
    """
    brand_path = os.path.join(output_dir, "brand.png")
    use_img1_as_brand = False
    for image in page.images:
        # Decode once and close the PIL handle immediately; only the header
        # fields are needed.
        with Image.open(io.BytesIO(image.data)) as decoded:
            mode, size = decoded.mode, decoded.size
        if image.name == "img0.png" and not use_img1_as_brand:
            if mode == "P":
                use_img1_as_brand = True
                continue
            _write_bytes(brand_path, image.data)
        if size == (430, 430):
            _write_bytes(os.path.join(output_dir, "part.png"), image.data)
    if use_img1_as_brand:
        for image in page.images:
            if image.name == "img1.png":
                _write_bytes(brand_path, image.data)


def _write_bytes(path, data):
    """Write *data* to *path*, replacing any existing file."""
    with open(path, "wb") as fp:
        fp.write(data)


def _build_spec_tables(named_tables, dstype):
    """Convert ``{name: DataFrame}`` into ``{name: {field: value}}``.

    Vertical tables map first-column labels to second-column strings.
    Horizontal tables map header cells to tuples of the values below them.
    For ``dstype == "Belden"``, an index-named table (other than the first
    table) is merged into the table before it and removed. Consecutive
    continuation pages all merge into the same surviving main table.
    """
    # Keep the tables in an ordered list so that renaming one never mutates
    # a dict while it is being iterated. This also preserves document order.
    entries = list(named_tables.items())
    tables = {}
    previous_table = ""
    for position, (table_name, df) in enumerate(entries):
        vertical = _is_vertical(df)
        _rename_unnamed_successor(entries, position, df.iloc[-1, 0],
                                  "Specs " + str(len(tables)))
        tables[table_name] = _table_to_dict(df, vertical)

        # Multi-page table check.
        if dstype == "Belden" and table_name.isdigit() and len(tables) > 1:
            _merge_continuation(tables[previous_table], tables.pop(table_name), vertical)
            # The continuation no longer exists. The next page continues the
            # same main table.
            table_name = previous_table
        previous_table = table_name
    return _replace_newlines(tables)


def _is_vertical(df):
    """Return True when *df* is a label/value table rather than a header-row table."""
    rows, cols = df.shape
    if cols == 1:
        return False
    if rows == 1:
        return True
    if cols == 2:
        # Three or more rows of two columns are label/value pairs. For a 2x2
        # table the layout is ambiguous, so a label ending in ":" decides.
        return rows > 2 or df.iloc[0, 0].endswith(":")
    return False


def _table_to_dict(df, vertical):
    """Turn *df* into a field dict according to its orientation."""
    if vertical:
        return {
            row[0].replace("\n", " ").replace(":", ""): row[1]
            for row in df.itertuples(index=False, name=None)
        }
    out = {}
    for column in df.columns:
        values = tuple(df[column])
        out[values[0].replace("\n", " ")] = values[1:]
    return out


def _rename_unnamed_successor(entries, position, last_cell, new_name):
    """Rename the first later table whose name contains *last_cell*.

    Such a table presumably has no title of its own. Its name was picked up
    from the bottom row of the table at *position*, so it gets *new_name*
    instead. Empty cells are ignored, because every string contains ''.
    """
    if not isinstance(last_cell, str) or not last_cell:
        return
    for later in range(position + 1, len(entries)):
        name, df = entries[later]
        if last_cell in name:
            entries[later] = (new_name, df)
            return


def _merge_continuation(main, cont, vertical):
    """Fold a table that continues on a later page into *main*, in place.

    A horizontal continuation is assumed to have no header row of its own.
    Its parsed "headers" are therefore its first data row, so each one is
    appended to the main column at the same position, followed by that
    column's remaining values. Surplus continuation columns are dropped.
    A vertical continuation simply adds, or overwrites, its fields.
    """
    if vertical:
        main.update(cont)
        return
    for main_key, (first_value, rest) in zip(list(main), cont.items()):
        main[main_key] = main[main_key] + (first_value,) + rest


def _replace_newlines(node):
    """Replace newline characters with spaces in all strings of a nested dict.

    Covers plain string values and string items inside tuple values, and
    recurses into nested dicts. The dict is modified in place and returned.
    """
    for key, value in node.items():
        if isinstance(value, str):
            node[key] = value.replace("\n", " ")
        elif isinstance(value, tuple):
            node[key] = tuple(
                item.replace("\n", " ") if isinstance(item, str) else item
                for item in value
            )
        elif isinstance(value, dict):
            _replace_newlines(value)
    return node


def flatten(tables):
    """Flatten ``{table: {field: value}}`` into a single-level search dict.

    Keys become ``"<table>: <field>"``, with the field truncated to 64
    characters and every "." removed from the combined key. Scalar values
    and single-item tuples are kept. They are converted to int, or else to a
    finite float, when they parse as one; otherwise the original value is
    kept. Tuples with zero or several items are omitted. Each kept entry is
    also echoed to stdout in a JSON-like form.
    """
    def convert_to_number(value):
        try:
            return int(value)
        except (TypeError, ValueError):
            pass
        try:
            number = float(value)
        except (TypeError, ValueError):
            return value
        # "nan", "inf" and "Infinity" parse as floats, but json.dump would
        # write them as the non-standard tokens NaN/Infinity. Keep the text.
        return number if math.isfinite(number) else value

    out = {}
    print("{")
    for table, fields in tables.items():
        for key, value in fields.items():
            keyname = key[:64]
            fullkeyname = (table + ": " + keyname).replace(".", "")
            if isinstance(value, tuple):
                if len(value) != 1:
                    continue  # multi-valued columns are not searchable scalars
                value = value[0]
            out[fullkeyname] = convert_to_number(value)
            print("\"" + keyname + "\":", "\"" + str(out[fullkeyname]) + "\",")
    print("}")
    return out

    

if __name__ == "__main__":
    # parse() requires the datasheet type as its fourth argument. "Belden"
    # also enables merging of tables that continue across pages.
    parse("test2.pdf", "cables/10GXS13", "10GXS13", "Belden")