diff --git a/quimb/tensor/drawing.py b/quimb/tensor/drawing.py index c1f5b2ba..381f0290 100644 --- a/quimb/tensor/drawing.py +++ b/quimb/tensor/drawing.py @@ -490,9 +490,7 @@ def draw_tn( # - user specified scaling # - number of tensors # - how flat the plot area is (flatter requires smaller nodes) - full_node_scale = ( - 0.2 * node_scale * node_packing_factor * plot_volume**0.5 - ) + full_node_scale = 0.2 * node_scale * node_packing_factor * plot_volume**0.5 default_outline_size = 6 * full_node_scale**0.5 @@ -1552,9 +1550,9 @@ def average_color(colors): r, g, b, a = zip(*colors) # then RMS average each channel - rm = (sum(ri**2 for ri in r) / len(r))**0.5 - gm = (sum(gi**2 for gi in g) / len(g))**0.5 - bm = (sum(bi**2 for bi in b) / len(b))**0.5 + rm = (sum(ri**2 for ri in r) / len(r)) ** 0.5 + gm = (sum(gi**2 for gi in g) / len(g)) ** 0.5 + bm = (sum(bi**2 for bi in b) / len(b)) ** 0.5 am = sum(a) / len(a) return (rm, gm, bm, am) @@ -1703,3 +1701,68 @@ def visualize_tensor(tensor, **kwargs): kwargs.setdefault("compass", True) kwargs.setdefault("compass_labels", tensor.inds) return xyz.visualize_tensor(tensor.data, **kwargs) + + +def choose_squarest_grid(x): + p = x**0.5 + if p.is_integer(): + m = n = int(p) + else: + m = int(round(p)) + p = int(p) + n = p if m * p >= x else p + 1 + return m, n + + +def visualize_tensors( + tn, + r=None, + r_scale=1.0, + figsize=None, + mode="network", + **visualize_opts, +): + from matplotlib import pyplot as plt + + if figsize is None: + figsize = (2 * tn.num_tensors**0.4, 2 * tn.num_tensors**0.4) + if r is None: + r = 1.0 / tn.num_tensors**0.5 + r *= r_scale + + max_mag = None + visualize_opts.setdefault("max_mag", max_mag) + visualize_opts.setdefault("size_scale", r) + + if mode == "network": + fig = plt.figure(figsize=figsize) + pos = tn.draw(get="pos") + for tid, (x, y) in pos.items(): + if tid not in tn.tensor_map: + # hyper indez + continue + x = (x + 1) / 2 - r / 2 + y = (y + 1) / 2 - r / 2 + ax = fig.add_axes((x, y, r / 2, r / 2)) + tn.tensor_map[tid].visualize(ax=ax, **visualize_opts) + else: + if mode == "grid": + px, py = choose_squarest_grid(tn.num_tensors) + elif mode == "row": + px, py = tn.num_tensors, 1 + figsize = (2 * figsize[0], figsize[1] / 2) + elif mode == "col": + px, py = 1, tn.num_tensors + figsize = (figsize[0] / 2, 2 * figsize[1]) + + fig, axs = plt.subplots(py, px, figsize=figsize) + for i, t in enumerate(tn): + t.visualize(ax=axs.flat[i], **visualize_opts) + for ax in axs.flat[i:]: + ax.set_axis_off() + + # transparent background + fig.patch.set_alpha(0.0) + + plt.show() + plt.close()