"""This module contains functionality for interacting with a PostgreSQL database. It handles query-level error
conditions (e.g. missing columns) without terminating the entire program. Use the :py:class:`DBConnector` class to
handle database interactions, either as a standalone object or in a context manager.

Requires the ``psycopg2`` package."""
from __future__ import annotations

import os
import psycopg2
from psycopg2 import DatabaseError, OperationalError
from psycopg2 import sql as pg_sql
from psycopg2.errors import UndefinedColumn

# Connection settings, overridable through environment variables.
DB_ADDRESS = os.getenv('DB_ADDRESS', 'localhost')
# NOTE: this is a ``str`` when read from the environment and the ``int`` default otherwise; both are accepted by
# ``psycopg2.connect``.
DB_PORT = os.getenv('DB_PORT', 5432)
DB_USER = os.getenv('DB_USER', 'postgres')
DB_PASSWORD = os.getenv('DB_PASSWORD', '')
DB_NAME = os.getenv('DB_NAME', 'postgres')
DB_TABLE = os.getenv('DB_TABLE', 'cables')


class DBConnector:
    """Context managed database class. Use with statements to automatically open and close the database connection, like
    so:

    .. code-block:: python

        with DBConnector() as db:
            db.read()

    Table and column names are always sent to the server as quoted identifiers, so they are matched case-sensitively.
    """

    # PostgreSQL requires a data type in ``ADD COLUMN``. TEXT is used because :py:meth:`read` filters with ILIKE,
    # which operates on text. NOTE(review): confirm TEXT is the desired type for auto-created columns.
    _NEW_COLUMN_TYPE = "TEXT"

    def __init__(self):
        # Define the attributes before connecting so that ``_db_stop`` (and therefore ``__del__``) stays safe even
        # when the connection attempt raises.
        self.conn = None
        self.cur = None
        self._db_start()

    def __del__(self):
        self._db_stop()

    def __enter__(self) -> DBConnector:
        """Return this connector, reconnecting first if the connection was closed by an earlier ``with`` block."""
        self._db_start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the cursor and connection. Exceptions raised inside the ``with`` block are not suppressed."""
        self._db_stop()

    def _is_open(self) -> bool:
        """Tell whether a connection exists and has not been closed."""
        conn = getattr(self, "conn", None)
        # ``connection.closed`` is 0 while the connection is open.
        return conn is not None and not conn.closed

    def _db_start(self):
        """Setup the database connection and cursor. Does nothing if a connection is already open.

        :raises OperationalError: If the server cannot be reached or rejects the credentials."""
        if self._is_open():
            return
        # Keyword arguments let psycopg2 escape each value, unlike a hand-built DSN string that breaks on spaces
        # or quotes (e.g. in the password).
        self.conn = psycopg2.connect(
            host=DB_ADDRESS, port=DB_PORT, dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD)
        self.cur = self.conn.cursor()

    def _db_stop(self):
        """Close the cursor and connection. Safe to call repeatedly, or when connecting never succeeded."""
        cur = getattr(self, "cur", None)
        if cur is not None and not cur.closed:
            cur.close()
        if self._is_open():
            self.conn.close()

    def _get_cols(self) -> set[str]:
        """Get the set of columns in the database table.

        :return: A set of column names."""
        rows = self._query(
            "SELECT column_name FROM information_schema.columns WHERE table_name = %s", (DB_TABLE,))
        # PostgreSQL folds the unquoted output name to lower case, so the key is ``column_name``.
        return {row["column_name"] for row in rows}

    def _column_parity(self, columns: list[str] | set[str]) -> set[str]:
        """If the listed columns are not in the database, add them.

        :param columns: The columns we expect are in the database.
        :return: The set of columns in the database after querying."""
        existing = self._get_cols()
        missing = set(columns) - existing
        if missing:
            column_type = pg_sql.SQL(self._NEW_COLUMN_TYPE)
            additions = pg_sql.SQL(", ").join(
                pg_sql.SQL("ADD COLUMN {} {}").format(pg_sql.Identifier(name), column_type)
                for name in sorted(missing))
            self._query(pg_sql.SQL("ALTER TABLE {} {}").format(pg_sql.Identifier(DB_TABLE), additions))
            existing = self._get_cols()
        return existing

    def _query_text(self, query) -> str:
        """Render a query (plain string or composed SQL object) as text for error messages."""
        if isinstance(query, pg_sql.Composable):
            try:
                return query.as_string(self.conn)
            except psycopg2.Error:
                # Rendering needs the connection for quoting; fall back to the object's repr if it is unusable.
                return repr(query)
        return str(query)

    def _execute(self, query, params=None) -> list[dict] | None:
        """Run one statement in its own transaction.

        :param query: SQL as plaintext or as a ``psycopg2.sql`` composable.
        :param params: Values for the ``%s`` placeholders in the query, if any.
        :return: The resulting rows (possibly empty) on success, or ``None`` if the database rejected the statement."""
        try:
            self.cur.execute(query, params)
            rows = self._read_dict()
            # psycopg2 opens a transaction implicitly; without a commit every change is discarded on close.
            self.conn.commit()
        except DatabaseError as e:
            print(f"ERROR {e.pgcode}: {e.pgerror}\n"
                  f"Caused by query: {self._query_text(query)}")
            try:
                # Leave the aborted transaction, otherwise every later statement on this connection fails too.
                self.conn.rollback()
            except psycopg2.Error:
                pass  # The connection itself is unusable; there is nothing left to roll back.
            return None
        return rows

    def _query(self, sql, params=None) -> list[dict]:
        """Basic function for running queries. Database errors are printed, not raised.

        :param sql: SQL query as plaintext (or as a ``psycopg2.sql`` composable).
        :param params: Optional values for the ``%s`` placeholders in the query.
        :return: Results of the query, or an empty list if none."""
        rows = self._execute(sql, params)
        return [] if rows is None else rows

    def _read_dict(self) -> list[dict]:
        """Read the cursor as a list of dictionaries. psycopg2 defaults to using a list of tuples, so we want to convert
        each row into a dictionary before we return it. Columns whose value is NULL are left out of the row's
        dictionary.

        :return: One dictionary per row, or an empty list if the last statement produced no result set."""
        if self.cur.description is None:
            # Statements such as INSERT or ALTER TABLE produce no result set to read.
            return []
        names = [column.name for column in self.cur.description]
        # Compare against None explicitly so legitimate falsy values (0, False, '') are kept.
        return [
            {name: value for name, value in zip(names, row) if value is not None}
            for row in self.cur
        ]

    def read(self, **kwargs) -> list[dict]:
        """Read rows from a database that match the specified filters.

        :param kwargs: Column constraints; i.e. what value to filter by in what column. Each value is matched with
            ILIKE, so it is treated as a case-insensitive text pattern.
        :returns: A list of dictionaries of all matching rows, or an empty list if no match."""
        query = pg_sql.SQL("SELECT * FROM {}").format(pg_sql.Identifier(DB_TABLE))
        if kwargs:
            conditions = pg_sql.SQL(" AND ").join(
                pg_sql.SQL("{} ILIKE %s").format(pg_sql.Identifier(column)) for column in kwargs)
            query = pg_sql.SQL("{} WHERE {}").format(query, conditions)
        # Values travel as bound parameters so they are quoted and escaped by the driver.
        return self._query(query, list(kwargs.values()) or None)

    def write(self, **kwargs) -> dict:
        """Write a row to the database. Columns that do not exist yet are created first.

        :param kwargs: Values to write for each column; specify each column separately!
        :returns: The row you just added, or an empty dictionary if nothing was written."""
        if not kwargs:
            return {}
        self._column_parity(set(kwargs))
        columns = list(kwargs)
        query = pg_sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
            pg_sql.Identifier(DB_TABLE),
            pg_sql.SQL(", ").join(pg_sql.Identifier(column) for column in columns),
            pg_sql.SQL(", ").join(pg_sql.Placeholder() * len(columns)),
        )
        if self._execute(query, [kwargs[column] for column in columns]) is None:
            # Report the failure with a falsy value so callers such as write_all can tell it apart from success.
            return {}
        return kwargs

    def write_all(self, items: list[dict]) -> list[dict]:
        """Write multiple rows to the database.

        :param items: Rows to write, as a list of dictionaries.
        :returns: The rows that were added successfully."""
        successes = []
        for item in items:
            written = self.write(**item)
            if written:
                successes.append(written)
        return successes