Skip to content

Commit

Permalink
TN: add visualize_tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 24, 2024
1 parent 95b4473 commit 429f249
Showing 1 changed file with 69 additions and 6 deletions.
75 changes: 69 additions & 6 deletions quimb/tensor/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,7 @@ def draw_tn(
# - user specified scaling
# - number of tensors
# - how flat the plot area is (flatter requires smaller nodes)
full_node_scale = (
0.2 * node_scale * node_packing_factor * plot_volume**0.5
)
full_node_scale = 0.2 * node_scale * node_packing_factor * plot_volume**0.5

default_outline_size = 6 * full_node_scale**0.5

Expand Down Expand Up @@ -1552,9 +1550,9 @@ def average_color(colors):
r, g, b, a = zip(*colors)

# then RMS average each channel
rm = (sum(ri**2 for ri in r) / len(r))**0.5
gm = (sum(gi**2 for gi in g) / len(g))**0.5
bm = (sum(bi**2 for bi in b) / len(b))**0.5
rm = (sum(ri**2 for ri in r) / len(r)) ** 0.5
gm = (sum(gi**2 for gi in g) / len(g)) ** 0.5
bm = (sum(bi**2 for bi in b) / len(b)) ** 0.5
am = sum(a) / len(a)

return (rm, gm, bm, am)
Expand Down Expand Up @@ -1703,3 +1701,68 @@ def visualize_tensor(tensor, **kwargs):
kwargs.setdefault("compass", True)
kwargs.setdefault("compass_labels", tensor.inds)
return xyz.visualize_tensor(tensor.data, **kwargs)


def choose_squarest_grid(x):
p = x**0.5
if p.is_integer():
m = n = int(p)
else:
m = int(round(p))
p = int(p)
n = p if m * p >= x else p + 1
return m, n


def visualize_tensors(
tn,
r=None,
r_scale=1.0,
figsize=None,
mode="network",
**visualize_opts,
):
from matplotlib import pyplot as plt

if figsize is None:
figsize = (2 * tn.num_tensors**0.4, 2 * tn.num_tensors**0.4)
if r is None:
r = 1.0 / tn.num_tensors**0.5
r *= r_scale

max_mag = None
visualize_opts.setdefault("max_mag", max_mag)
visualize_opts.setdefault("size_scale", r)

if mode == "network":
fig = plt.figure(figsize=figsize)
pos = tn.draw(get="pos")
for tid, (x, y) in pos.items():
if tid not in tn.tensor_map:
# hyper indez
continue
x = (x + 1) / 2 - r / 2
y = (y + 1) / 2 - r / 2
ax = fig.add_axes((x, y, r / 2, r / 2))
tn.tensor_map[tid].visualize(ax=ax, **visualize_opts)
else:
if mode == "grid":
px, py = choose_squarest_grid(tn.num_tensors)
elif mode == "row":
px, py = tn.num_tensors, 1
figsize = (2 * figsize[0], figsize[1] / 2)
elif mode == "col":
px, py = 1, tn.num_tensors
figsize = (figsize[0] / 2, 2 * figsize[1])

fig, axs = plt.subplots(py, px, figsize=figsize)
for i, t in enumerate(tn):
t.visualize(ax=axs.flat[i], **visualize_opts)
for ax in axs.flat[i:]:
ax.set_axis_off()

# transparent background
fig.patch.set_alpha(0.0)

plt.show()
plt.close()

0 comments on commit 429f249

Please sign in to comment.