Source code for zoviz.database

"""Class for handling zotero.sqlite database connection, queries, and visualization"""
# pylint: disable=too-few-public-methods,too-many-locals
import os
import sqlite3
from sys import platform
from itertools import combinations, product
import pandas as pd
import networkx as nx


class DB:
    """Interface layer for a Zotero ``zotero.sqlite`` database.

    Relies on the assumption that the entire database is small enough to load
    into memory without issue. Tables are loaded from disk the first time they
    are accessed (``db["items"]`` or ``db.items``) and cached afterwards.

    Usable as a context manager; the connection is closed on exit.
    """

    def __init__(self, db_path=None):
        """Open the database and index its tables and collection names.

        :param db_path: Path to zotero.sqlite; guessed from the OS when omitted
        :type db_path: str, optional
        :raises FileNotFoundError: If no file exists at the resolved path
        """
        # Assigned first so attribute lookups on a partially constructed
        # instance never fall through to __getattr__ for the cache itself.
        self._data = {}
        self.db_path = db_path or guess_db_path()
        validate_db_path(self.db_path)
        self._conn = sqlite3.connect(self.db_path)
        tables_query = \
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        table_names = self.query_df(tables_query)["name"].tolist()
        # Maps table name -> list of its column names
        self.tables = {n: self.query_table_columns(n) for n in table_names}
        collection_names_query = \
            "SELECT distinct collectionName FROM collections"
        self.collection_names = self.query_df(
            collection_names_query)["collectionName"].tolist()

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails. Private and dunder
        # names are never table names; refusing them keeps copy/pickle/getattr
        # probes working and prevents recursion through __getitem__.
        if key.startswith("_"):
            raise AttributeError(key)
        return self.__getitem__(key)

    def __getitem__(self, key):
        # Check the cache explicitly: dict.get(key, self.load_table(key))
        # would evaluate load_table eagerly and re-read the table every time.
        if key in self._data:
            return self._data[key]
        return self.load_table(key)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self._conn is not None:
            self.close()

    def close(self):
        """Close database connection"""
        self._conn.close()

    def load_table(self, name: str) -> pd.DataFrame:
        """Load a whole table from disk and cache it

        :param name: Name of table
        :type name: str
        :return: Table contents
        :rtype: pd.DataFrame
        :raises KeyError: If the database has no table called *name*
        """
        if name not in self.tables:
            raise KeyError("No such table: %s" % name)
        # Identifiers cannot be bound as parameters; *name* is whitelisted
        # against sqlite_master above and double-quoted as an identifier.
        quoted_name = '"%s"' % name.replace('"', '""')
        self._data[name] = self.query_df("select * from %s" % quoted_name)
        return self._data[name]

    def query_df(self, query: str, params=()) -> pd.DataFrame:
        """Return the result of a query as a DataFrame

        :param query: SQL query, optionally containing ``?`` placeholders
        :type query: str
        :param params: Values bound to the placeholders by sqlite3; use this
            instead of string formatting for any user-supplied value
        :type params: sequence or dict, optional
        :return: DataFrame of result, one column per selected column
        :rtype: pd.DataFrame
        """
        cursor = self._conn.cursor()
        try:
            cursor.execute(query, params)
            cols = [x[0] for x in cursor.description]
            rows = cursor.fetchall()
        finally:
            cursor.close()
        # NOTE(review): building one column per record and transposing is
        # kept from the original implementation because it determines the
        # resulting dtypes that callers currently see.
        return pd.DataFrame(zip(*rows), index=cols).transpose()

    def query_table_columns(self, name: str) -> list:
        """Get list of column names for a table

        :param name: Name of table
        :type name: str
        :return: List of columns
        :rtype: list
        """
        # PRAGMA arguments cannot be bound; escape quotes in the literal.
        query = "PRAGMA table_info('%s')" % name.replace("'", "''")
        return self.query_df(query)["name"].values.tolist()

    @staticmethod
    def _creator_label(first_name, last_name):
        """Format a creator as "F. Last", or just "Last" without a first name."""
        last_name = last_name or ""
        if first_name:
            return first_name[0] + ". " + last_name
        # Single-field creators (e.g. institutions) have no first name
        return last_name

    def build_creator_graph(self, collection: str) -> nx.MultiGraph:
        """Build a graph data structure of collaborative work

        Each item in the collection adds one edge between every pair of its
        distinct creators. Creators whose display labels coincide are then
        contracted into a single node whose stats are merged.

        :param collection: Zotero Collection name; if several collections
            share the name, the first row SQLite returns is used
        :type collection: str
        :return: MultiGraph with Creator objects as nodes and 1 edge per
            collaboration
        :rtype: nx.MultiGraph
        :raises KeyError: If no collection has the given name
        """
        # Get data (bound parameters handle names containing quotes)
        collection_ids = self.query_df(
            "select collectionID from collections where collectionName = ?",
            (collection,))["collectionID"].values
        if len(collection_ids) == 0:
            raise KeyError("No such collection: %s" % collection)
        # sqlite3 cannot bind numpy integers, so convert to a Python int
        collection_id = int(collection_ids[0])
        collection_items = self.query_df(
            "select distinct itemID from collectionItems where collectionID = ?",
            (collection_id,))["itemID"].values
        creators_table = self["creators"]
        creator_map = {
            creator_id: Creator(
                name=self._creator_label(first, last), creator_id=creator_id)
            for creator_id, first, last in zip(
                creators_table["creatorID"],
                creators_table["firstName"],
                creators_table["lastName"])
        }

        # Build graph
        g = nx.MultiGraph()
        for item_id in collection_items:
            # DISTINCT: a creator credited twice on one item (e.g. author and
            # editor) must not be double-counted or linked to itself.
            item_creator_ids = self.query_df(
                "select distinct creatorID from itemCreators where itemID = ?",
                (int(item_id),))["creatorID"].values
            item_creators = [creator_map[i] for i in item_creator_ids]
            for c in item_creators:
                c.count += 1
                if c not in g.nodes:
                    g.add_node(c)
            g.add_edges_from(combinations(item_creators, 2))

        # Contract nodes with duplicate names into the first node seen
        nodes_by_label = {}
        for node in g.nodes:
            nodes_by_label.setdefault(str(node), []).append(node)
        for nodes in nodes_by_label.values():
            keeper = nodes[0]
            for other in nodes[1:]:
                keeper += other  # Combine stats (Creator.__add__ returns self)
                g = nx.contracted_nodes(g, keeper, other, self_loops=False)
        return g
class Creator:
    """Content-creator data for use as a graph node

    Instances are hashed by identity (the default ``object`` behaviour), so two
    creators with the same display name stay distinct nodes until merged with
    ``+``.
    """

    def __init__(self, name, creator_id):
        self.name = name  # display label used for graph drawing
        self.id = creator_id  # creator identifier this node was built from
        self.count = 0  # usage counter, incremented by graph-building code
        self.contracted_ids = []  # ids of creators merged into this node

    def __str__(self):
        # For labeling in networkx embedding
        return self.name

    def __repr__(self):
        return "%s(name=%r, creator_id=%r)" % (
            type(self).__name__, self.name, self.id)

    def __add__(self, other):
        """Merge *other* into this creator in place and return ``self``.

        Used for contracting duplicates: counts are summed and *other*'s id,
        plus any ids *other* had already absorbed, are recorded.
        """
        if not isinstance(other, Creator):
            return NotImplemented
        self.count += other.count
        self.contracted_ids.append(other.id)
        # Keep ids from earlier merges so chained contractions lose nothing
        self.contracted_ids.extend(other.contracted_ids)
        return self
def guess_db_path(platform_name=None):
    """Guess location of zotero.sqlite based on operating system

    Zotero's default data directory is ``Zotero`` inside the user's home
    directory on Linux, macOS and Windows.

    :param platform_name: Platform identifier in ``sys.platform`` style
        (e.g. "linux", "darwin", "win32"); defaults to the running platform
    :type platform_name: str, optional
    :return: Guessed path to zotero.sqlite (existence is not checked)
    :rtype: str
    :raises NotImplementedError: If the platform is not recognised
    """
    os_name = platform_name or platform
    lowered = os_name.lower()
    # sys.platform is "win32" on Windows (never "windows"), so match the prefix.
    is_supported = (any(x in lowered for x in ["linux", "darwin"])
                    or lowered.startswith("win"))
    if not is_supported:
        error_string = "Unable to resolve database location on OS %s ." % os_name
        error_string += "\nYou will need to supply the path as an argument."
        raise NotImplementedError(error_string)
    # expanduser resolves the home directory on every supported OS,
    # including the drive letter on Windows (unlike %HOMEPATH%).
    return os.path.join(os.path.expanduser("~"), "Zotero", "zotero.sqlite")
def validate_db_path(db_path: str):
    """Ensure that a database file is present at the given location

    :param db_path: Path to zotero.sqlite
    :type db_path: str
    :raises FileNotFoundError: Raised if no regular file exists at *db_path*
    """
    if os.path.isfile(db_path):
        return
    message = "Did not locate zotero.sqlite database at %s ." % db_path
    hint = "\nYou will need to supply the path as an argument."
    raise FileNotFoundError(message + hint)