| #!/usr/bin/env python3 |
| # |
| # Copyright (c) 2021 Project CHIP Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """Wrapper and utility functions around sqlite3""" |
| |
| import sqlite3 |
| from typing import List, Optional |
| |
| import pandas as pd # type: ignore |
| from memdf import Config, ConfigDescription |
| |
# Configuration entries declared by this module.
CONFIG: ConfigDescription = {
    # Group header for the database-related options.
    Config.group_def('database'): {'title': 'database options'},
    # Sqlite3 file name; no default, and `--db` is registered as an alias.
    'database.file': {
        'help': 'Sqlite3 file',
        'metavar': 'FILENAME',
        'default': None,
        'argparse': {'alias': ['--db']},
    },
}
| |
| |
class Database:
    """Wrapper and utility functions around a single sqlite3 connection.

    Instances can be used as context managers: the connection is opened on
    entry and closed on exit. Subclasses may set ``on_open`` and
    ``on_writable`` to lists of SQL statements that are executed when a new
    connection is established.

    NOTE: table and column names passed to ``store`` / ``get_matching`` are
    interpolated into the SQL text (identifiers cannot be bound as
    parameters), so they must come from trusted code, never from user input.
    """
    # SQL statements executed on every newly opened connection.
    on_open: Optional[List[str]] = None
    # SQL statements additionally executed when the database is writable.
    on_writable: Optional[List[str]] = None

    def __init__(self, filename: str, writable: bool = True):
        """Remember the file name and access mode; does not connect yet."""
        self.filename = filename
        self.writable = writable
        self.con: Optional[sqlite3.Connection] = None

    def __enter__(self):
        return self.open()

    def __exit__(self, et, ev, traceback):
        self.close()
        # Never suppress an exception raised inside the `with` body.
        return False

    def open(self):
        """Open and initialize the database connection.

        Does nothing if a connection is already open. If one of the
        ``on_open`` / ``on_writable`` statements fails, the new connection
        is closed again and the exception propagates, so the instance is
        never left holding a half-initialized connection.

        Returns:
            self, to allow call chaining.
        """
        if not self.con:
            db = 'file:' + self.filename
            if not self.writable:
                db += '?mode=ro'
            con = sqlite3.connect(db, uri=True)
            try:
                for statement in self.on_open or ():
                    con.execute(statement)
                if self.writable:
                    for statement in self.on_writable or ():
                        con.execute(statement)
            except Exception:
                con.close()
                raise
            # Publish the connection only once it is fully initialized.
            self.con = con
        return self

    def close(self):
        """Close the connection if it is open; returns self."""
        if self.con:
            self.con.close()
            self.con = None
        return self

    def connection(self) -> sqlite3.Connection:
        """Return the open connection.

        Raises:
            sqlite3.ProgrammingError: if the database has not been opened.
        """
        # An explicit raise (rather than `assert`) survives `python -O`.
        if self.con is None:
            raise sqlite3.ProgrammingError(
                f'database {self.filename!r} is not open')
        return self.con

    def execute(self, query, parameters=None):
        """Execute `query`, optionally with bound `parameters`.

        Returns:
            The sqlite3 cursor produced by the statement.
        """
        con = self.connection()
        if parameters:
            return con.execute(query, parameters)
        return con.execute(query)

    def commit(self):
        """Commit the current transaction; returns self."""
        self.connection().commit()
        return self

    def store(self, table: str, **kwargs):
        """Insert the data if it does not already exist.

        Each keyword names a column of `table` and supplies its value.
        Conflicting rows are silently left untouched.
        """
        q = (f"INSERT INTO {table} ({','.join(kwargs.keys())})"
             f" VALUES ({','.join('?' * len(kwargs))})"
             f" ON CONFLICT DO NOTHING")
        v = list(kwargs.values())
        self.connection().execute(q, v)

    def get_matching(self, table: str, columns: List[str], **kwargs):
        """Select `columns` from rows of `table` matching every keyword.

        Each keyword names a column and the value it must equal. A value of
        ``None`` matches SQL NULL. With no keywords, all rows are selected.

        Returns:
            The sqlite3 cursor over the matching rows.
        """
        conditions = []
        values = []
        for column, value in kwargs.items():
            if value is None:
                # `column = NULL` is never true in SQL; NULL needs IS NULL.
                conditions.append(f'{column} IS NULL')
            else:
                conditions.append(f'{column}=?')
                values.append(value)
        q = f"SELECT {','.join(columns)} FROM {table}"
        if conditions:
            q += ' WHERE ' + ' AND '.join(conditions)
        return self.connection().execute(q, values)

    def get_matching_id(self, table: str, **kwargs):
        """Return the `id` of the first matching row, or None if none."""
        cur = self.get_matching(table, ['id'], **kwargs)
        row = cur.fetchone()
        if row:
            return row[0]
        return None

    def store_and_return_id(self, table: str, **kwargs) -> Optional[int]:
        """Store the row if needed, then look up and return its `id`."""
        self.store(table, **kwargs)
        return self.get_matching_id(table, **kwargs)

    def data_frame(self, query, parameters=None) -> pd.DataFrame:
        """Return the results of a query as a DataFrame.

        The transaction is committed after the rows have been fetched, and
        the query text is recorded in ``df.attrs['title']``. A statement
        that produces no result set yields an empty DataFrame.
        """
        cur = self.execute(query, parameters)
        # `description` is None for statements that return no rows.
        columns = [i[0] for i in cur.description or ()]
        df = pd.DataFrame(cur.fetchall(), columns=columns)
        self.commit()
        df.attrs = {'title': query}
        return df