"""This module contains functionality for interacting with a PostgreSQL database. It will automatically handle error conditions (i.e. missing columns) without terminating the entire program. Use the :py:class:`DBConnector` class to handle database interactions, either as a standalone object or in a context manager.""" from __future__ import annotations import os import psycopg2 from psycopg2 import DatabaseError, OperationalError from psycopg2.errors import UndefinedColumn DB_ADDRESS = os.getenv('DB_ADDRESS', 'localhost') DB_PORT = os.getenv('DB_PORT', 5432) DB_USER = os.getenv('DB_USER', 'postgres') DB_PASSWORD = os.getenv('DB_PASSWORD', '') DB_NAME = os.getenv('DB_NAME', 'postgres') DB_TABLE = os.getenv('DB_TABLE', 'cables') class DBConnector: """Context managed database class. Use with statements to automatically open and close the database connection, like so: .. code-block:: python with DBConnector() as db: db.read() """ def _db_start(self): """Setup the database connection and cursor.""" try: self.conn = psycopg2.connect( f"host={DB_ADDRESS} port={DB_PORT} dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD}") self.cur = self.conn.cursor() except OperationalError as e: raise e def _db_stop(self): """Close the cursor and connection.""" self.cur.close() self.conn.close() def __init__(self): self._db_start() def __del__(self): self._db_stop() def __enter__(self): self._db_start() def __exit__(self): self._db_stop() def _get_cols(self) -> set[str]: """Get the list of columns in the database. :return: A list of column names.""" query = f"select COLUMN_NAME from information_schema.columns where table_name={DB_TABLE}" rows = {x["COLUMN_NAME"] for x in self._query(query)} return rows def _column_parity(self, columns: list[str] | set[str]) -> set[str]: """If the listed columns are not in the database, add them. :param columns: The columns we expect are in the database. :return: The list of columns in the database after querying.""" cols = set(columns) existing = self._get_cols() needs = cols.difference(existing.intersection(cols)) if len(needs) > 0: query = f"ALTER TABLE {DB_TABLE} {', '.join([f'ADD COLUMN {c}' for c in needs])}" self._query(query) existing = self._get_cols() return existing def _query(self, sql) -> list[dict]: """Basic function for running queries. :param sql: SQL query as plaintext. :return: Results of the query, or an empty list if none.""" result = [] try: self.cur.execute(sql) result = self._read_dict() except DatabaseError as e: print(f"ERROR {e.pgcode}: {e.pgerror}\n" f"Caused by query: {sql}") finally: return result def _read_dict(self) -> list[dict]: """Read the cursor as a list of dictionaries. psycopg2 defaults to using a list of tuples, so we want to convert each row into a dictionary before we return it.""" cols = [i.name for i in self.cur.description] results = [] for row in self.cur: row_dict = {} for i in range(0, len(row)): if row[i]: row_dict = {**row_dict, cols[i]: row[i]} results.append(row_dict) return results def read(self, **kwargs) -> list[dict]: """Read rows from a database that match the specified filters. :param kwargs: Column constraints; i.e. what value to filter by in what column. :returns: A list of dictionaries of all matching rows, or an empty list if no match.""" args = [] for kw in kwargs.keys(): args.append(f"{kw} ILIKE {kwargs['kw']}") query = f"SELECT * FROM {DB_TABLE}" if len(args) > 0: query += f" WHERE {' AND '.join(args)}" return self._query(query) def write(self, **kwargs) -> dict: """Write a row to the database. :param kwargs: Values to write for each database; specify each column separately! :returns: The row you just added.""" self._column_parity(set(kwargs.keys())) values = [] for val in kwargs.keys(): values.append(kwargs[val]) query = f"INSERT INTO {DB_TABLE} ({', '.join(kwargs.keys())}) VALUES ({', '.join(values)})" self._query(query) return kwargs def write_all(self, items: list[dict]) -> list[dict]: """Write multiple rows to the database. :param items: Rows to write, as a list of dictionaries. :returns: The rows that were added successfully.""" successes = [] for i in items: res0 = self.write(**i) if res0: successes.append(res0) return successes