Compare commits
	
		
			2 Commits
		
	
	
		
			aadb6ba24d
			...
			0f2c19e811
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0f2c19e811 | |||
| b18355fc14 | 
							
								
								
									
										140
									
								
								database.py
									
									
									
									
									
								
							
							
						
						
									
										140
									
								
								database.py
									
									
									
									
									
								
							| @@ -1,140 +0,0 @@ | |||||||
| """This module contains functionality for interacting with a PostgreSQL database. It will automatically handle error |  | ||||||
| conditions (i.e. missing columns) without terminating the entire program. Use the :py:class:`DBConnector` class to |  | ||||||
| handle database interactions, either as a standalone object or in a context manager.""" |  | ||||||
| from __future__ import annotations |  | ||||||
|  |  | ||||||
| import os |  | ||||||
| import psycopg2 |  | ||||||
| from psycopg2 import DatabaseError, OperationalError |  | ||||||
| from psycopg2.errors import UndefinedColumn |  | ||||||
|  |  | ||||||
| DB_ADDRESS = os.getenv('DB_ADDRESS', 'localhost') |  | ||||||
| DB_PORT = os.getenv('DB_PORT', 5432) |  | ||||||
| DB_USER = os.getenv('DB_USER', 'postgres') |  | ||||||
| DB_PASSWORD = os.getenv('DB_PASSWORD', '') |  | ||||||
| DB_NAME = os.getenv('DB_NAME', 'postgres') |  | ||||||
| DB_TABLE = os.getenv('DB_TABLE', 'cables') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DBConnector: |  | ||||||
|     """Context managed database class. Use with statements to automatically open and close the database connection, like |  | ||||||
|     so: |  | ||||||
|  |  | ||||||
|     .. code-block:: python |  | ||||||
|        with DBConnector() as db: |  | ||||||
|            db.read() |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def _db_start(self): |  | ||||||
|         """Setup the database connection and cursor.""" |  | ||||||
|         try: |  | ||||||
|             self.conn = psycopg2.connect( |  | ||||||
|                 f"host={DB_ADDRESS} port={DB_PORT} dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD}") |  | ||||||
|             self.cur = self.conn.cursor() |  | ||||||
|         except OperationalError as e: |  | ||||||
|             raise e |  | ||||||
|  |  | ||||||
|     def _db_stop(self): |  | ||||||
|         """Close the cursor and connection.""" |  | ||||||
|         self.cur.close() |  | ||||||
|         self.conn.close() |  | ||||||
|  |  | ||||||
|     def __init__(self): |  | ||||||
|         self._db_start() |  | ||||||
|  |  | ||||||
|     def __del__(self): |  | ||||||
|         self._db_stop() |  | ||||||
|  |  | ||||||
|     def __enter__(self): |  | ||||||
|         self._db_start() |  | ||||||
|  |  | ||||||
|     def __exit__(self): |  | ||||||
|         self._db_stop() |  | ||||||
|  |  | ||||||
|     def _get_cols(self) -> set[str]: |  | ||||||
|         """Get the list of columns in the database. |  | ||||||
|  |  | ||||||
|         :return: A list of column names.""" |  | ||||||
|         query = f"select COLUMN_NAME from information_schema.columns where table_name={DB_TABLE}" |  | ||||||
|         rows = {x["COLUMN_NAME"] for x in self._query(query)} |  | ||||||
|         return rows |  | ||||||
|  |  | ||||||
|     def _column_parity(self, columns: list[str] | set[str]) -> set[str]: |  | ||||||
|         """If the listed columns are not in the database, add them. |  | ||||||
|  |  | ||||||
|         :param columns: The columns we expect are in the database. |  | ||||||
|         :return: The list of columns in the database after querying.""" |  | ||||||
|         cols = set(columns) |  | ||||||
|         existing = self._get_cols() |  | ||||||
|         needs = cols.difference(existing.intersection(cols)) |  | ||||||
|         if len(needs) > 0: |  | ||||||
|             query = f"ALTER TABLE {DB_TABLE} {', '.join([f'ADD COLUMN {c}' for c in needs])}" |  | ||||||
|             self._query(query) |  | ||||||
|             existing = self._get_cols() |  | ||||||
|         return existing |  | ||||||
|  |  | ||||||
|     def _query(self, sql) -> list[dict]: |  | ||||||
|         """Basic function for running queries. |  | ||||||
|  |  | ||||||
|         :param sql: SQL query as plaintext. |  | ||||||
|         :return: Results of the query, or an empty list if none.""" |  | ||||||
|         result = [] |  | ||||||
|         try: |  | ||||||
|             self.cur.execute(sql) |  | ||||||
|             result = self._read_dict() |  | ||||||
|         except DatabaseError as e: |  | ||||||
|             print(f"ERROR {e.pgcode}: {e.pgerror}\n" |  | ||||||
|                   f"Caused by query: {sql}") |  | ||||||
|         finally: |  | ||||||
|             return result |  | ||||||
|  |  | ||||||
|     def _read_dict(self) -> list[dict]: |  | ||||||
|         """Read the cursor as a list of dictionaries. psycopg2 defaults to using a list of tuples, so we want to convert |  | ||||||
|         each row into a dictionary before we return it.""" |  | ||||||
|         cols = [i.name for i in self.cur.description] |  | ||||||
|         results = [] |  | ||||||
|         for row in self.cur: |  | ||||||
|             row_dict = {} |  | ||||||
|             for i in range(0, len(row)): |  | ||||||
|                 if row[i]: |  | ||||||
|                     row_dict = {**row_dict, cols[i]: row[i]} |  | ||||||
|             results.append(row_dict) |  | ||||||
|         return results |  | ||||||
|  |  | ||||||
|     def read(self, **kwargs) -> list[dict]: |  | ||||||
|         """Read rows from a database that match the specified filters. |  | ||||||
|  |  | ||||||
|         :param kwargs: Column constraints; i.e. what value to filter by in what column. |  | ||||||
|         :returns: A list of dictionaries of all matching rows, or an empty list if no match.""" |  | ||||||
|         args = [] |  | ||||||
|         for kw in kwargs.keys(): |  | ||||||
|             args.append(f"{kw} ILIKE {kwargs['kw']}") |  | ||||||
|         query = f"SELECT * FROM {DB_TABLE}" |  | ||||||
|         if len(args) > 0: |  | ||||||
|             query += f" WHERE {' AND '.join(args)}" |  | ||||||
|         return self._query(query) |  | ||||||
|  |  | ||||||
|     def write(self, **kwargs) -> dict: |  | ||||||
|         """Write a row to the database. |  | ||||||
|  |  | ||||||
|         :param kwargs: Values to write for each database; specify each column separately! |  | ||||||
|         :returns: The row you just added.""" |  | ||||||
|         self._column_parity(set(kwargs.keys())) |  | ||||||
|         values = [] |  | ||||||
|         for val in kwargs.keys(): |  | ||||||
|             values.append(kwargs[val]) |  | ||||||
|         query = f"INSERT INTO {DB_TABLE} ({', '.join(kwargs.keys())}) VALUES ({', '.join(values)})" |  | ||||||
|         self._query(query) |  | ||||||
|         return kwargs |  | ||||||
|  |  | ||||||
|     def write_all(self, items: list[dict]) -> list[dict]: |  | ||||||
|         """Write multiple rows to the database. |  | ||||||
|  |  | ||||||
|         :param items: Rows to write, as a list of dictionaries. |  | ||||||
|         :returns: The rows that were added successfully.""" |  | ||||||
|         successes = [] |  | ||||||
|         for i in items: |  | ||||||
|             res0 = self.write(**i) |  | ||||||
|             if res0: |  | ||||||
|                 successes.append(res0) |  | ||||||
|         return successes |  | ||||||
		Reference in New Issue
	
	Block a user