plot updates

parent 350c73ad
......@@ -6,4 +6,4 @@
 	url = https://github.com/tychovdo/MovingMNIST.git
 [submodule "monitor/trajectories"]
 	path = monitor/trajectories
-	url = git@github.com:domkirke/trajectories.git
+	url = https://github.com/domkirke/trajectories.git
......@@ -20,3 +20,4 @@ class MovingMNIST(Dataset):
         self.data = np.concatenate([train_data, test_data], axis=0)
         self.partitions = {'train':np.arange(train_data.shape[0]), 'test':train_data.shape[0] + np.arange(test_data.shape[0])}
+        self.files = [None]*len(self.data)
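The hunk above keeps one concatenated array and addresses the original splits through integer index arrays, which is why a single `self.files` list can be allocated for the whole dataset. A minimal sketch of the indexing scheme (the shapes are made up):

```python
import numpy as np

# Hypothetical split sizes, for illustration only.
train_data = np.zeros((6, 20, 64, 64))
test_data = np.ones((4, 20, 64, 64))

data = np.concatenate([train_data, test_data], axis=0)
partitions = {'train': np.arange(train_data.shape[0]),
              'test': train_data.shape[0] + np.arange(test_data.shape[0])}

# The test partition indexes the tail of the concatenated array.
assert np.array_equal(data[partitions['test']], test_data)
```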
......@@ -41,6 +41,8 @@ class Evaluation(nn.Module):
             xs.append(decudify(x)); ys.append(decudify(y))
         if issubclass(type(outs[0]), list):
             outs = [merge_dicts([outs[i][j] for i in range(len(outs))]) for j in range(len(outs[0]))]
+        else:
+            outs = merge_dicts(outs)
         ys = merge_dicts(ys)
         xs = torch.cat(xs, dim=0)
         print(' evaluating...')
......
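The new `else` branch covers models whose forward pass returns a single dict per batch instead of one dict per layer. A sketch of the two merge paths, assuming `merge_dicts` concatenates same-key tensors (the real helper may differ):

```python
import torch

def merge_dicts(dicts):
    # Assumed behaviour of the project's merge_dicts helper (a sketch):
    # concatenate the per-batch tensors stored under each key.
    return {k: torch.cat([d[k] for d in dicts], dim=0) for k in dicts[0]}

# Two batches, two layers each: the list-of-lists case of the hunk above.
outs = [[{'loss': torch.ones(4)}, {'loss': torch.zeros(4)}],
        [{'loss': torch.ones(4)}, {'loss': torch.zeros(4)}]]
merged = [merge_dicts([outs[i][j] for i in range(len(outs))]) for j in range(len(outs[0]))]
assert merged[0]['loss'].shape[0] == 8   # one merged dict per layer
```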
......@@ -55,15 +55,16 @@ class LatentEvaluation(Evaluation):
             latent_std_std = latent_out.stddev.std(dim_shape[:-1])
             global_stats['enc'].append([latent_mean_mean, latent_mean_std, latent_std_mean, latent_std_std])
-        for l, latent_out in enumerate(outputs['z_params_dec']):
-            if latent_out is None:
-                continue
-            dim_shape = tuple(range(latent_out.mean.ndimension()))
-            latent_mean_mean = latent_out.mean.mean(dim_shape[:-1])
-            latent_mean_std = latent_out.mean.std(dim_shape[:-1])
-            latent_std_mean = latent_out.stddev.mean(dim_shape[:-1])
-            latent_std_std = latent_out.stddev.std(dim_shape[:-1])
-            global_stats['dec'].append([latent_mean_mean, latent_mean_std, latent_std_mean, latent_std_std])
+        if outputs.get('z_params_dec'):
+            for l, latent_out in enumerate(outputs['z_params_dec']):
+                if latent_out is None:
+                    continue
+                dim_shape = tuple(range(latent_out.mean.ndimension()))
+                latent_mean_mean = latent_out.mean.mean(dim_shape[:-1])
+                latent_mean_std = latent_out.mean.std(dim_shape[:-1])
+                latent_std_mean = latent_out.stddev.mean(dim_shape[:-1])
+                latent_std_std = latent_out.stddev.std(dim_shape[:-1])
+                global_stats['dec'].append([latent_mean_mean, latent_mean_std, latent_std_mean, latent_std_std])
         out = {**out, 'stats':global_stats}
         return out
......
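The guard `outputs.get('z_params_dec')` skips the decoder statistics for models that expose no decoder latents (and for an empty list). The statistics themselves reduce over every axis except the last, latent one; a sketch with made-up shapes:

```python
import torch

z = torch.randn(16, 10, 8)                # (batch, steps, latent dim), hypothetical
dim_shape = tuple(range(z.ndimension()))  # (0, 1, 2)
mean_per_dim = z.mean(dim_shape[:-1])     # reduce over batch and steps
std_per_dim = z.std(dim_shape[:-1])
assert mean_per_dim.shape == (8,)         # one statistic per latent dimension
```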
......@@ -260,22 +260,26 @@ def plot_mean(x, target=None, preprocessing=None, axes=None, *args, is_sequence=
     if target is not None:
         target = target.squeeze()
     if is_sequence:
-        fig = [None]*x.batch_shape[0]; axes = [None]*x.batch_shape[0]
+        figs = [None]*x.batch_shape[0]; axes = [None]*x.batch_shape[0]
         for ex in range(x.batch_shape[0]):
             if len(x.batch_shape) <= 3:
                 if target is not None:
-                    fig[ex], axes[ex] = plot_mean_1d(x[ex], x=target[ex], preprocessing=preprocessing, *args, **kwargs)
+                    figs[ex], axes[ex] = plot_mean_1d(x[ex], x=target[ex], preprocessing=preprocessing, *args, **kwargs)
                 else:
-                    fig[ex], axes[ex] = plot_mean_1d(x[ex], x=None, preprocessing=preprocessing, *args, **kwargs)
+                    figs[ex], axes[ex] = plot_mean_1d(x[ex], x=None, preprocessing=preprocessing, *args, **kwargs)
             elif len(x.batch_shape) == 4:
+                figs = []; axes = []
                 if target is not None:
-                    fig, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, *args, **kwargs)
+                    for i in range(target.shape[0]):
+                        kwargs['out'] = None if kwargs.get('out') is None else kwargs.get('out')+'_%d'%i
+                        fig, ax = plot_mean_2d(x[i], x=target[i], preprocessing=preprocessing, *args, **kwargs)
+                        figs.append(fig); axes.append(ax)
     else:
         if len(x.batch_shape) <= 2 + is_sequence:
-            fig, axes = plot_mean_1d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
+            figs, axes = plot_mean_1d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
         elif len(x.batch_shape) == 3 + is_sequence:
-            fig, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
-    return fig, axes
+            figs, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
+    return figs, axes

 def plot_empirical(x, *args, **kwargs):
     return plot_mean(dist.Normal(x.mean, x.stddev), *args, **kwargs)
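Since `plot_mean` now returns one figure per batch element in the sequence and 2-D batched cases, callers have to iterate when saving. A usage sketch, assuming `plot_mean` is imported from the patched module:

```python
import torch
from torch import distributions as dist

# Hypothetical sequence batch: 5 examples, 10 steps, 64x64 frames.
x_params = dist.Normal(torch.randn(5, 10, 64, 64), torch.ones(5, 10, 64, 64))
target = torch.randn(5, 10, 64, 64)

figs, axes = plot_mean(x_params, target=target, is_sequence=True)
for i, fig in enumerate(figs):
    fig.savefig('reconstruction_%d.pdf' % i, format='pdf')
```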
......@@ -507,11 +511,11 @@ def plot_3d(current_z, meta=None, var=None, classes=None, class_ids=None, class_
     if meta is None:
         meta = np.arange(current_z.shape[0])
-        cmap = get_cmap(meta.shape[0], color_map=cmap)
+        cmap = get_cmap(meta.shape[0])
         cmap_hash = {x:x for x in meta}
         legend = False
     else:
-        cmap = get_cmap(0, color_map=cmap) if classes is None else get_cmap(len(classes), color_map=cmap)
+        cmap = get_cmap(0) if classes is None else get_cmap(len(classes))
         cmap_hash = {None:None} if classes is None else {classes[i]:i for i in range(len(classes))}
     current_alpha = 0.06 if (centroids and not meta is None) else 1.0
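Both changed calls drop the `color_map` keyword, so the incoming `cmap` argument is ignored and the helper's default map is used. For orientation, a plausible stand-in for `get_cmap` (an assumption, not the project's actual code):

```python
import matplotlib.pyplot as plt

def get_cmap(n, color_map='hsv'):
    # Hypothetical helper: a colormap resampled to n entries, so the integer
    # class indices stored in cmap_hash can index it directly.
    return plt.get_cmap(color_map, max(n, 1))
```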
......@@ -538,8 +542,7 @@ def plot_3d(current_z, meta=None, var=None, classes=None, class_ids=None, class_
             ax.scatter(current_z[i,0,0], current_z[i,0,1], current_z[i,0,2], c=color, alpha=current_alpha, marker='o')
             ax.scatter(current_z[i,-1,0], current_z[i,-1,1], current_z[i,-1,2], c=color, alpha=current_alpha, marker='+')
     else:
-        cs = np.array([cmap_hash[m] for m in meta])
-        # pdb.set_trace()
+        cs = cmap(np.stack([cmap_hash[m] for m in meta]))
         if current_z.shape[1]==2:
             ax.scatter(current_z[index_ids, 0], current_z[:,1], np.zeros_like(current_z[index_ids,0]), c=cs[index_ids], alpha=current_alpha, s=current_var)
         else:
......@@ -548,8 +551,9 @@ def plot_3d(current_z, meta=None, var=None, classes=None, class_ids=None, class_
     if centroids and not meta is None:
         for i, cid in class_ids.items():
             centroid = np.mean(current_z[cid], axis=0)
-            ax.scatter(centroid[0], centroid[1], centroid[2], s=30, c=cmap(classes[i]))
-            ax.text(centroid[0], centroid[1], centroid[2], class_names[i], color=cmap(classes[i]), fontsize=10)
+            color = np.array(cmap(cmap_hash[i]))[np.newaxis]
+            ax.scatter(centroid[0], centroid[1], centroid[2], s=30, c=color)
+            ax.text(centroid[0], centroid[1], centroid[2], class_names[i], color=cmap(cmap_hash[i]), fontsize=10)
     # make legends
     if legend and not meta is None and not classes is None:
         handles = []
......
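The centroid branch now looks colours up through `cmap_hash` and adds `np.newaxis` before passing the single RGBA to `scatter`. The reshape is what makes the colour unambiguous to matplotlib:

```python
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

rgba = np.array(plt.get_cmap('viridis')(0.5))[np.newaxis]  # shape (1, 4)
# A (1, 4) row is read as one RGBA colour; a flat length-4 sequence for a
# single point could be mistaken for four scalar values to be colour-mapped.
ax.scatter([0.0], [0.0], [0.0], s=30, c=rgba)
```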
......@@ -88,11 +88,17 @@ def plot_reconstructions(dataset, model, label=None, n_points=10, out=None, prep
         if vae_out.get('x_params') is not None:
             fig_path = name if len(vae_out) == 1 else '%s_%d'%(name, i)
-            fig, ax = core.plot_distribution(vae_out['x_params'][i], target=data_pp[i], preprocessing=preprocessing, preprocess=preprocess, multihead=multihead_outputs, out=fig_path, **kwargs)
+            fig, ax = core.plot_distribution(vae_out['x_params'][i], target=data_pp[i], preprocessing=preprocessing, preprocess=preprocess, multihead=multihead_outputs, **kwargs)
             figs[os.path.basename(fig_path)] = fig; axes[os.path.basename(fig_path)] = ax
             if not out is None:
-                fig_path = f"{out}/{fig_path}{suffix}.pdf"
-                fig.savefig(fig_path, format="pdf")
+                if issubclass(type(fig), list):
+                    for f in range(len(fig)):
+                        fig_path_tmp = f"{out}/{fig_path}{suffix}_{f}.pdf"
+                        fig[f].savefig(fig_path_tmp, format="pdf")
+                else:
+                    fig_path = f"{out}/{fig_path}{suffix}.pdf"
+                    fig.savefig(fig_path, format="pdf")
     fig_reinforced = None
     if vae_out.get('x_reinforced') is not None:
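The added branch handles `plot_distribution` returning either a single figure or a list of figures (see the `plot_mean` change above). The same pattern as a small stand-alone helper (hypothetical, not part of the patch); note that `isinstance(fig, list)` is equivalent to `issubclass(type(fig), list)` and slightly more idiomatic:

```python
import os

def save_figures(fig, out, name, suffix=''):
    # Save one figure, or each figure of a list with an index suffix.
    if isinstance(fig, list):
        for i, f in enumerate(fig):
            f.savefig(os.path.join(out, f"{name}{suffix}_{i}.pdf"), format="pdf")
    else:
        fig.savefig(os.path.join(out, f"{name}{suffix}.pdf"), format="pdf")
```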
......@@ -830,7 +836,7 @@ def plot_latent3(dataset, model, transformation=None, n_points=None, preprocessi
     else:
         class_names = {} if task is None else {v:k for k, v in dataset.classes[task].items()}
         legend = False if task is None else legend
-        fig, ax = core.plot(full_z[full_ids.get_ids(task)], meta=meta, var=full_var[full_ids.get_ids(task)], classes=nclasses[task], class_ids=class_ids, class_names=class_names, centroids=centroids, legend=legend, sequence=sequence)
+        fig, ax = core.plot(full_z[full_ids.get_ids(task)], meta=meta, var=full_var[full_ids.get_ids(task)], classes=nclasses[task], class_ids=class_ids[task], class_names=class_names, centroids=centroids, legend=legend, sequence=sequence)
         # register and export
         fig_name = 'layer %d / task %s'%(layer, task) if task else 'layer%d / no task'%layer
         fig.suptitle(fig_name)
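The fix passes `class_ids[task]` instead of the whole dict: the centroid loop in `plot_3d` above iterates `for i, cid in class_ids.items()` and expects class indices mapped to point indices, not task names mapped to sub-dicts. Hypothetical shape of the structure:

```python
# Assumed layout (for illustration): one {class index: point indices} dict per task.
class_ids = {
    'instrument': {0: [0, 3, 5], 1: [1, 2, 4]},
    'pitch':      {0: [0, 1],    1: [2, 3, 4, 5]},
}
task = 'instrument'
per_task = class_ids[task]   # what core.plot receives after the fix
```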
......@@ -1339,8 +1345,6 @@ def plot_latent_trajs(dataset, model, n_points=None, preprocessing=None, label=N
         full_z = np.concatenate([x.mean.cpu().detach().numpy() for x in vae_out[layer]['out_params']], axis=0)
         full_var = np.concatenate([x.variance.cpu().detach().numpy() for x in vae_out[layer]['out_params']], axis=0)
         # iteration over tasks
         for task in tasks:
             print('-- plotting task %s'%task)
......
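For reference, the stacking in this hunk collects the moments of per-batch latent posteriors into flat numpy arrays; a self-contained sketch with made-up shapes:

```python
import numpy as np
import torch
from torch import distributions as dist

# Three batches of 16 latent posteriors over 8 dimensions (hypothetical).
out_params = [dist.Normal(torch.randn(16, 8), torch.ones(16, 8)) for _ in range(3)]

full_z = np.concatenate([p.mean.cpu().detach().numpy() for p in out_params], axis=0)
full_var = np.concatenate([p.variance.cpu().detach().numpy() for p in out_params], axis=0)
assert full_z.shape == (48, 8)
```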
......@@ -148,16 +148,18 @@ def concat_distrib(distrib_list, unsqueeze=True, dim=1):
         means = torch.cat(means, dim)
         stds = torch.cat(stds, dim)
         return type(distrib_list[0])(means, stds)

     def concat_categorical(distrib_list):
-        probs = [d.probs for d in distrib_list]
+        probs = [d.logits for d in distrib_list]
         probs = [p.unsqueeze(dim) for p in probs]
         probs = torch.cat(probs, dim)
         return dist.Categorical(probs=probs)

     assert(len(set([type(d) for d in distrib_list])) == 1)
     if type(distrib_list[0]) in (dist.Normal, dist.RandomWalk):
         return concat_normal(distrib_list)
-    elif type(distrib_list[0]) == dist.Categorical:
+    elif type(distrib_list[0]) in (dist.Bernoulli, dist.Categorical):
         return concat_categorical(distrib_list)
     else:
         raise Exception('cannot concatenate distribution of type %s'%type(distrib_list[0]))
......
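One thing worth flagging in this hunk: the loop now reads `d.logits`, but the result is still constructed with `probs=probs`, and the new `Bernoulli` case returns a `Categorical`. A sketch of what the concatenation presumably intends (an assumption, not the committed code):

```python
import torch
from torch import distributions as dist

def concat_categorical(distrib_list, dim=1):
    # Stack the per-step logits along a new axis; logits must be passed as
    # logits=, not probs=. Using type(...) keeps Bernoulli as Bernoulli.
    logits = torch.cat([d.logits.unsqueeze(dim) for d in distrib_list], dim)
    return type(distrib_list[0])(logits=logits)
```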
......@@ -298,7 +298,7 @@ class ShrubVAE(VanillaVAE):
         steps = [];
-        target_seq = z_all[0].shape[1]
+        if target_seq is None and z_all[0] is not None:
+            target_seq = z_all[0].shape[1]
         if target_seq:
             cum_size = target_seq
             for i in range(1, len(z_all)):
......@@ -427,7 +427,7 @@ class ShrubVAE(VanillaVAE):
         enc_out, true_lengths = self.encode(x, y=y, **kwargs)
         logger('data encoded')
         # decode
-        layers = range(len(self.platent)) if multi_decode else [-1]
+        layers = range(len(self.platent)) if multi_decode else [len(self.platent)-1]
         outs = []
         for l in layers:
             current_zs = [z['out'] for z in enc_out[:(l+1)]]
......
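The `[-1]` fix matters because the layer index is used arithmetically further down: with `l = -1`, the slice `enc_out[:(l+1)]` is `enc_out[:0]`, i.e. empty, so the decoder received no latents at all. Using the explicit last index restores the full slice:

```python
platent = ['layer0', 'layer1', 'layer2']   # stand-in for self.platent
enc_out = ['z0', 'z1', 'z2']

l = -1
assert enc_out[:(l + 1)] == []                   # old behaviour: empty slice

l = len(platent) - 1
assert enc_out[:(l + 1)] == ['z0', 'z1', 'z2']   # fixed: all layers up to the top
```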