diff --git a/database.py b/database.py index 4c8de7c..c5b0448 100644 --- a/database.py +++ b/database.py @@ -3,7 +3,8 @@ conditions (i.e. missing columns) without terminating the entire program. Use th handle database interactions, either as a standalone object or in a context manager.""" import os import psycopg2 -from psycopg2 import DatabaseError +from psycopg2 import DatabaseError, OperationalError +from psycopg2.errors import UndefinedColumn DB_ADDRESS = os.getenv('DB_ADDRESS', 'localhost') DB_PORT = os.getenv('DB_PORT', 5432) @@ -26,8 +27,8 @@ class DBConnector: self.conn = psycopg2.connect( f"host={DB_ADDRESS} port={DB_PORT} dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD}") self.cur = self.conn.cursor() - except DatabaseError as e: - + except OperationalError as e: + raise e def _db_stop(self): self.cur.close() @@ -45,7 +46,7 @@ class DBConnector: def __exit__(self): self._db_stop() - def _query(self, sql): + def _query(self, sql) -> list[dict]: try: self.cur.execute(sql) result = self.cur.fetchall() @@ -54,7 +55,7 @@ class DBConnector: result = [] return result - def read(self, **kwargs): + def read(self, **kwargs) -> list[dict]: """Read rows from a database that match the specified filters. :param kwargs: Column constraints; i.e. what value to filter by in what column. @@ -67,8 +68,26 @@ class DBConnector: query += f" WHERE {' AND '.join(args)}" return self._query(query) - def write(self, **kwargs): + def write(self, **kwargs) -> dict: """Write a row to the database. - :param kwargs: Values to write for each database; specify each column separately!""" - pass \ No newline at end of file + :param kwargs: Values to write for each database; specify each column separately! + :returns: The row you just added.""" + values = [] + for val in kwargs.keys(): + values.append(kwargs[val]) + query = f"INSERT INTO {DB_TABLE} ({', '.join(kwargs.keys())}) VALUES ({', '.join(values)})" + self._query(query) + return kwargs + + def write_all(self, items: list[dict]) -> list[dict]: + """Write multiple rows to the database. + + :param items: Rows to write, as a list of dictionaries. + :returns: The rows that were added successfully.""" + successes = [] + for i in items: + res0 = self.write(**i) + if res0: + successes.append(res0) + return successes \ No newline at end of file