diff --git a/database.py b/database.py index cbff64d..a21809d 100644 --- a/database.py +++ b/database.py @@ -51,6 +51,23 @@ class DBConnector: def __exit__(self): self._db_stop() + def _get_cols(self) -> set[str]: + """Get the list of columns in the database. + + :return: A list of column names.""" + query = f"select COLUMN_NAME from information_schema.columns where table_name={DB_TABLE}" + rows = {x["COLUMN_NAME"] for x in self._query(query)} + return rows + + def _column_parity(self, columns: list[str] | set[str]) -> set[str]: + """If the listed columns are not in the database, add them.""" + cols = set(columns) + existing = self._get_cols() + needs = cols.difference(existing.intersection(cols)) + query = f"ALTER TABLE {DB_TABLE} {', '.join([f'ADD COLUMN {c}' for c in needs])}" + self._query(query) + return self._get_cols() + def _query(self, sql) -> list[dict]: """Basic function for running queries.