"""This module contains functionality for interacting with a PostgreSQL database. It handles error conditions (e.g.
missing columns) without terminating the entire program. Use the :py:class:`DBConnector` class to handle database
interactions, either as a standalone object or as a context manager."""
from __future__ import annotations

import os

import psycopg2
from psycopg2 import DatabaseError, OperationalError
from psycopg2.errors import UndefinedColumn
from psycopg2.sql import SQL, Identifier, Placeholder

# Connection settings, read from environment variables; the second argument is used when a variable is unset.
DB_ADDRESS = os.getenv('DB_ADDRESS', 'localhost')
# Default is a string so DB_PORT has the same type (str) whether or not DB_PORT is set in the environment.
DB_PORT = os.getenv('DB_PORT', '5432')
DB_USER = os.getenv('DB_USER', 'postgres')
DB_PASSWORD = os.getenv('DB_PASSWORD', '')
DB_NAME = os.getenv('DB_NAME', 'postgres')
# Table that DBConnector queries, alters and inserts into.
DB_TABLE = os.getenv('DB_TABLE', 'cables')


class DBConnector:
    """Database connection wrapper that can be used as a standalone object or as a context manager.

    A connection is opened as soon as the object is created. In a ``with`` statement the connection is reopened on
    entry if it was closed, and closed again on exit:

    .. code-block:: python

       with DBConnector() as db:
           db.read()

    The connection runs in autocommit mode, so every statement (including :py:meth:`write`) is committed
    immediately, and a failed statement does not leave the session stuck in an aborted transaction.
    """

    def _db_start(self):
        """Open the database connection in autocommit mode and create a cursor.

        :raises psycopg2.OperationalError: If the connection cannot be established."""
        # Keyword arguments are escaped by psycopg2, so values containing spaces or quotes (e.g. passwords) work.
        self.conn = psycopg2.connect(host=DB_ADDRESS, port=DB_PORT, dbname=DB_NAME, user=DB_USER,
                                     password=DB_PASSWORD)
        # Without autocommit nothing would ever be committed (writes are lost on close), and one failed statement
        # would make every later statement fail until a rollback.
        self.conn.autocommit = True
        self.cur = self.conn.cursor()

    def _db_stop(self):
        """Close the cursor and connection, skipping any that are missing or already closed."""
        # getattr: __del__ can run on an object whose __init__ failed before these attributes were assigned.
        cur = getattr(self, "cur", None)
        if cur is not None and not cur.closed:
            cur.close()
        conn = getattr(self, "conn", None)
        if conn is not None and not conn.closed:
            conn.close()

    def __init__(self):
        self._db_start()

    def __del__(self):
        self._db_stop()

    def __enter__(self) -> DBConnector:
        """Reopen the connection if it has been closed, then return this object for use in the ``with`` body."""
        conn = getattr(self, "conn", None)
        if conn is None or conn.closed:
            self._db_start()
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        """Close the connection when leaving a ``with`` block. Exceptions raised inside the block propagate."""
        self._db_stop()

    def _get_cols(self) -> set[str]:
        """Get the names of the columns of ``DB_TABLE``.

        :return: A set of column names (empty if the lookup failed)."""
        # The table name is bound as a parameter (it is a value here, not an identifier). PostgreSQL folds the
        # unquoted result column name to lower case, hence the "column_name" key.
        query = "SELECT column_name FROM information_schema.columns WHERE table_name = %s"
        return {row["column_name"] for row in self._query(query, (DB_TABLE,))}

    def _column_parity(self, columns: list[str] | set[str]) -> set[str]:
        """If the listed columns are not in the table, add them.

        New columns are created with type ``TEXT``: PostgreSQL requires a type in ``ADD COLUMN``, and text columns
        suit the ``ILIKE`` filters used by :py:meth:`read`.

        :param columns: The columns we expect are in the database.
        :return: The set of columns in the database after any additions."""
        existing = self._get_cols()
        missing = set(columns) - existing
        if missing:
            additions = SQL(", ").join(
                SQL("ADD COLUMN IF NOT EXISTS {} TEXT").format(Identifier(col)) for col in sorted(missing))
            self._query(SQL("ALTER TABLE {} {}").format(Identifier(DB_TABLE), additions))
            existing = self._get_cols()
        return existing

    def _query(self, sql, params=None) -> list[dict]:
        """Run a single statement and return the rows it produced.

        psycopg2 errors are printed instead of raised, so one bad query does not terminate the program.

        :param sql: The statement, as a plain string or a :py:mod:`psycopg2.sql` composed object.
        :param params: Optional values bound to the ``%s`` placeholders in ``sql``.
        :return: Result rows as dictionaries; an empty list on error or when the statement returns no rows."""
        try:
            self.cur.execute(sql, params)
        except psycopg2.Error as e:
            if isinstance(sql, str):
                shown = sql
            else:
                # Composed queries need a connection to render; fall back to repr if that fails too.
                try:
                    shown = sql.as_string(self.conn)
                except psycopg2.Error:
                    shown = repr(sql)
            print(f"ERROR {e.pgcode}: {e.pgerror or e}\n"
                  f"Caused by query: {shown}")
            return []
        return self._read_dict()

    def _read_dict(self) -> list[dict]:
        """Read the cursor as a list of dictionaries keyed by column name.

        psycopg2 returns rows as tuples, so each row is converted to a dictionary. ``NULL`` values are left out of
        the dictionary; other falsy values such as ``0``, ``False`` or ``''`` are kept.

        :return: The rows, or an empty list if the last statement produced no result set."""
        if self.cur.description is None:
            # e.g. ALTER TABLE, or INSERT without RETURNING: there are no rows to read.
            return []
        names = [column.name for column in self.cur.description]
        return [{name: value for name, value in zip(names, row) if value is not None} for row in self.cur]

    def read(self, **kwargs) -> list[dict]:
        """Read rows from the table that match the specified filters.

        Each filter is a case-insensitive ``ILIKE`` match, so ``%`` and ``_`` in a value act as wildcards.

        :param kwargs: Column constraints, mapping each column name to the pattern it must match.
        :returns: A list of dictionaries of all matching rows, or an empty list if no match."""
        query = SQL("SELECT * FROM {}").format(Identifier(DB_TABLE))
        params = None
        if kwargs:
            # Column names are quoted as identifiers and values are bound as parameters, so neither can inject SQL.
            conditions = SQL(" AND ").join(SQL("{} ILIKE %s").format(Identifier(col)) for col in kwargs)
            query = SQL("{} WHERE {}").format(query, conditions)
            # ILIKE compares text, so non-string filter values are matched by their string form.
            params = [str(value) for value in kwargs.values()]
        return self._query(query, params)

    def write(self, **kwargs) -> dict:
        """Write a row to the table, first adding any columns it does not have yet.

        :param kwargs: Value to write for each column; specify each column separately!
        :returns: The row you just added, or an empty dict if nothing was written (no columns given, or the
            insert failed)."""
        if not kwargs:
            return {}
        self._column_parity(set(kwargs))
        query = SQL("INSERT INTO {} ({}) VALUES ({}) RETURNING 1").format(
            Identifier(DB_TABLE),
            SQL(", ").join(Identifier(col) for col in kwargs),
            SQL(", ").join(Placeholder() * len(kwargs)),
        )
        # RETURNING produces a row only if the insert succeeded; _query returns [] on failure.
        if not self._query(query, list(kwargs.values())):
            return {}
        return kwargs

    def write_all(self, items: list[dict]) -> list[dict]:
        """Write multiple rows to the database.

        :param items: Rows to write, as a list of dictionaries.
        :returns: The rows that were added successfully."""
        written = (self.write(**item) for item in items)
        return [row for row in written if row]