Compare commits

..

No commits in common. "bb19158aa0d1f2b887b80f006f18232150ddded1" and "4a4e969e0f4a2b0c6d1d7acc1829bce464f21e06" have entirely different histories.

@ -1,8 +1,6 @@
"""This module contains functionality for interacting with a PostgreSQL database. It will automatically handle error
conditions (i.e. missing columns) without terminating the entire program. Use the :py:class:`DBConnector` class to
handle database interactions, either as a standalone object or in a context manager."""
from __future__ import annotations
import os
import psycopg2
from psycopg2 import DatabaseError, OperationalError
@ -24,9 +22,7 @@ class DBConnector:
with DBConnector() as db:
db.read()
"""
def _db_start(self):
"""Setup the database connection and cursor."""
try:
self.conn = psycopg2.connect(
f"host={DB_ADDRESS} port={DB_PORT} dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD}")
@ -35,7 +31,6 @@ class DBConnector:
raise e
def _db_stop(self):
"""Close the cursor and connection."""
self.cur.close()
self.conn.close()
@ -51,50 +46,14 @@ class DBConnector:
def __exit__(self):
self._db_stop()
def _get_cols(self) -> set[str]:
"""Get the list of columns in the database.
:return: A list of column names."""
query = f"select COLUMN_NAME from information_schema.columns where table_name={DB_TABLE}"
rows = {x["COLUMN_NAME"] for x in self._query(query)}
return rows
def _column_parity(self, columns: list[str] | set[str]) -> set[str]:
"""If the listed columns are not in the database, add them."""
cols = set(columns)
existing = self._get_cols()
needs = cols.difference(existing.intersection(cols))
query = f"ALTER TABLE {DB_TABLE} {', '.join([f'ADD COLUMN {c}' for c in needs])}"
self._query(query)
return self._get_cols()
def _query(self, sql) -> list[dict]:
"""Basic function for running queries.
:param sql: SQL query as plaintext.
:return: Results of the query, or an empty list if none."""
result = []
try:
self.cur.execute(sql)
result = self._read_dict()
result = self.cur.fetchall()
except DatabaseError as e:
print(f"ERROR {e.pgcode}: {e.pgerror}\n"
f"Caused by query: {sql}")
finally:
return result
def _read_dict(self) -> list[dict]:
"""Read the cursor as a list of dictionaries. psycopg2 defaults to using a list of tuples, so we want to convert
each row into a dictionary before we return it."""
cols = [i.name for i in self.cur.description]
results = []
for row in self.cur:
row_dict = {}
for i in range(0, len(row)):
if row[i]:
row_dict = {**row_dict, cols[i]: row[i]}
results.append(row_dict)
return results
print(f"DB ERROR [{e.pgcode}]: {e.pgerror}")
result = []
return result
def read(self, **kwargs) -> list[dict]:
"""Read rows from a database that match the specified filters.
@ -131,4 +90,4 @@ class DBConnector:
res0 = self.write(**i)
if res0:
successes.append(res0)
return successes
return successes