Source code for zoviz.visualization
""" Visualization functions specialized for Zotero data """
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
[docs]def draw_multigraph(g: nx.MultiGraph, pos: dict, radius_mult=0.1,
directed=False, **kwargs) -> plt.Axes:
"""A way to draw multiple edges per node pair
:param g: A multigraph, probably returned by zoviz.DB.build_creator_graph()
:type g: nx.MultiGraph
:param pos: networkx layout
:type pos: dict
:param directed: Draw arrows?, defaults to False
:type directed: bool, optional
:return: return axis handle
:rtype: plt.Axes
"""
ax = plt.gca()
# Draw nodes & labels
label_kwargs = {k: v for k, v in kwargs.items() if "font" in k}
node_kwargs = {k: v for k, v in kwargs.items() if "font" not in k}
nx.draw_networkx_labels(g, pos, **label_kwargs)
nx.draw_networkx_nodes(g, pos, **node_kwargs)
# Draw edges w/ different arc radius for each duplicate
arrowstyle = "->" if directed else "-"
for e in g.edges:
startpos = pos[e[0]]
endpos = pos[e[1]]
edge_index = e[2] # I am the Nth edge between these nodes
radius = radius_mult * float(edge_index)
connectionstyle = "arc3,rad=%f" % radius
arrowprops = {"arrowstyle": arrowstyle,
"connectionstyle": connectionstyle,
"color": kwargs.get("edge_color", 'k'),
"alpha": kwargs.get("edge_alpha", 0.3),
"linewidth": kwargs.get("edge_linewidth", 1.)}
ax.annotate("", xy=startpos, xytext=endpos,
xycoords="data", arrowprops=arrowprops, zorder=-1)
return ax
[docs]def draw_community_graph(g: nx.MultiGraph, fig=None, **kwargs) -> plt.Figure:
"""A quick, deterministic embedding for the creator graph
:param collection: Zotero Collection name
:type collection: str
:return: figure handle
:rtype: plt.Figure
"""
if len(g.nodes) < 2:
raise ValueError("Community graph has less than 2 members; unable to draw")
fig = fig or plt.figure(figsize=(10, 10), dpi=120.)
node_degrees = [g.degree(n) for n in g.nodes]
cmap = sns.cubehelix_palette(start=0, rot=3., dark=0.6,
as_cmap=True, reverse=True)
colors = cmap([float(x) / max(max(node_degrees), 1) for x in node_degrees])
counts = [float(x.count) for x in g.nodes]
max_count = max(counts)
sizes_normed = [x * 5e3 / max_count for x in counts]
layout = nx.spring_layout(g, pos=nx.circular_layout(g),
k=len(g.nodes) / 16., iterations=50)
kwargs["alpha"] = kwargs.get("alpha", 0.9)
kwargs["font_size"] = kwargs.get("font_size", 6)
# kwargs["with_labels"] = kwargs.get("with_labels", True)
draw_multigraph(g, pos=layout, node_color=colors,
node_size=sizes_normed, **kwargs)
return fig, layout