fixed plots and evaluations

parent e8a64156
@@ -156,6 +156,9 @@ def resynthesize_files(dataset, model, transformOptions=None, transform=None, me
else:
files = choices(dataset, k=n_files)
if transformOptions is None:
print('[Warning] resynthesize_files : transformOptions keyword is required to resynthesize audio samples')
return None, []
transform = transform or transformOptions.get('transformType')
outputs = {}
for i, current_file in enumerate(files):
......
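For reference, a minimal sketch of the early-return guard added above. The simplified signature and return value are assumptions for illustration; the real function (with its model and metadata arguments) is the one in the diff.

from random import choices

def resynthesize_files(dataset, transform_options=None, transform=None, n_files=5):
    """Sketch: validate the required keyword before doing any work."""
    files = choices(dataset, k=n_files)
    if transform_options is None:
        # fail fast with an explicit message instead of hitting an
        # AttributeError on transform_options.get(...) further down
        print('[Warning] resynthesize_files : transformOptions keyword is required '
              'to resynthesize audio samples')
        return None, []
    transform = transform or transform_options.get('transformType')
    return transform, files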
import torch, torch.nn as nn
import torch, torch.nn as nn, numpy as np
from .. import distributions as dist
import sklearn.manifold as manifold
import sklearn.decomposition as decomposition
@@ -11,20 +11,18 @@ class DimRedBaseline(nn.Module):
super(DimRedBaseline, self).__init__()
self.pinput = input_params
latent_params = checklist(latent_params)
self.platent = checklist(latent_params)
self.platent = checklist(latent_params[-1])
self.dimred_module = self.dimred_class(n_components = self.platent[-1]['dim'], **kwargs)
def encode(self, x, **kwargs):
input_device = torch.device('cpu')
if torch.is_tensor(x):
input_device = x.device
if len(x.shape) > 2:
x = x[:, 0]
dimred_out = torch.from_numpy(self.dimred_module.fit_transform(x)).to(input_device).float()
dimred_dist = dist.Normal(dimred_out, torch.zeros_like(dimred_out)+1e-12)
return [{'out':dimred_out, 'out_params':dimred_dist}]
def decode(self, z,squeezed=False, **kwargs):
def decode(self, z, squeezed=False, **kwargs):
input_device = torch.device('cpu')
if torch.is_tensor(z):
input_device = z.device
@@ -36,16 +34,17 @@ class DimRedBaseline(nn.Module):
return [{'out':dimred_out, 'out_params':dimred_dist}]
def forward(self, x, y=None, **kwargs):
squeezed = False
if len(x.shape) > 2:
x = x[:, 0]
squeezed = True
input_car = len(checklist(self.pinput['dim']))+self.pinput.get('conv')
input_shape = x.shape[-input_car:]
batch_shape = x.shape[:-input_car]
x = x.view(np.cumprod(batch_shape)[-1], np.cumprod(input_shape)[-1])
z = self.encode(x, **kwargs)
reconstruction= self.decode(z[0]['out'], squeezed = squeezed, **kwargs)
return {'x_params': reconstruction[0]['out_params'],
'z_params_enc':[z[0]['out_params']],
'z_enc':[z[0]['out']],
'z_params_dec':[], 'z_dec':[]}
reconstruction= self.decode(z[0]['out'], **kwargs)
x_params = reconstruction[0]['out_params'].reshape(*batch_shape, *input_shape)
z_params_enc = [z[0]['out_params'].reshape(*batch_shape, self.platent[0]['dim'])]
z_enc = [z[0]['out'].reshape(*batch_shape, self.platent[0]['dim'])]
return {'x_params':x_params, 'z_params_enc':z_params_enc, 'z_params_dec':[], 'z_enc':z_enc}
class PCABaseline(DimRedBaseline):
......
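For context, a self-contained sketch of the pattern the DimRedBaseline changes converge on: flatten the batch dimensions, fit a scikit-learn reducer, wrap the projection in a near-deterministic Normal, and reshape everything back to the original batch shape. The function and shape names below are illustrative assumptions, not the module's actual parameter dictionaries.

import numpy as np
import torch
from torch import distributions as dist
from sklearn.decomposition import PCA

def pca_encode(x, latent_dim=2):
    """Sketch: flatten batch dims, fit/transform with PCA, restore batch shape."""
    batch_shape, feat_dim = x.shape[:-1], x.shape[-1]
    flat = x.reshape(int(np.prod(batch_shape)), feat_dim)
    pca = PCA(n_components=latent_dim)
    z = torch.from_numpy(pca.fit_transform(flat.cpu().numpy())).to(x.device).float()
    z = z.reshape(*batch_shape, latent_dim)
    # near-zero scale, so the Normal acts as a numerically safe point estimate
    z_dist = dist.Normal(z, torch.zeros_like(z) + 1e-12)
    return z, z_dist

# usage sketch
x = torch.randn(8, 4, 64)                     # (batch, sequence, features)
z, z_dist = pca_encode(x, latent_dim=3)
print(z.shape)                                # torch.Size([8, 4, 3])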
@@ -79,7 +79,7 @@ class EvaluationContainer(Evaluation):
def __repr__(self):
return "Evaluation : (%s)"%self._evaluations
def __init__(self, evaluations=[], **kwargs):
self._evaluations = []
self._evaluations = evaluations
self.output = None
if kwargs.get('out'):
self.output = kwargs.get('out')
......
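The one-line fix above makes the container keep the evaluations it is given; previously the constructor reset _evaluations to an empty list regardless of the argument. A reduced sketch of the intended behaviour, with the base class and output handling simplified away (copying into a fresh list also sidesteps the shared-mutable-default pitfall of evaluations=[]):

class EvaluationContainer:
    """Sketch: aggregate several evaluation callables and run them together."""

    def __init__(self, evaluations=None, **kwargs):
        # keep what the caller passed instead of silently discarding it
        self._evaluations = list(evaluations) if evaluations is not None else []
        self.output = kwargs.get('out')

    def __repr__(self):
        return "Evaluation : (%s)" % self._evaluations

    def evaluate(self, *args, **kwargs):
        # run every wrapped evaluation and collect the results by name
        return {type(e).__name__: e(*args, **kwargs) for e in self._evaluations}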
@@ -31,10 +31,7 @@ class LatentEvaluation(Evaluation):
p_l = latent_params[l].get('prior') or get_default_distribution(latent_params[l]['dist'], q_l.batch_shape)
else:
p_l = latent_params[l].get('prior') or get_default_distribution(latent_params[l]['dist'], q_l.batch_shape)
if out.get('logdets') is None:
inputs.append({'params1':q_l, 'params2':p_l})
else:
inputs.append({'params1':q_l, 'params2':p_l, 'logdets':out.get('logdets')[l]})
inputs.append({'params1':q_l, 'params2':p_l})
for div in self.divergences:
......
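The simplification above always compares each encoder posterior to its prior (or to a default distribution matched to the batch shape), dropping the separate log-determinant branch. A sketch of that comparison with torch.distributions, where the default prior is assumed to be a standard Normal:

import torch
from torch import distributions as dist

def latent_divergence(q, prior=None):
    """Sketch: KL between an encoder posterior q(z|x) and its prior p(z)."""
    if prior is None:
        # default prior: standard Normal with the same batch shape as q
        prior = dist.Normal(torch.zeros(q.batch_shape), torch.ones(q.batch_shape))
    return dist.kl_divergence(q, prior)

# usage sketch
q = dist.Normal(torch.randn(16, 8), torch.rand(16, 8) + 0.1)
kl = latent_divergence(q)
print(kl.shape)                               # torch.Size([16, 8])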
@@ -115,6 +115,9 @@ def plot_mean_1d(dist, x=None, preprocess=None, preprocessing=None, axes=None, m
# get distributions
dist_mean = dist.mean.cpu().detach().numpy(); dist_mean_inv=None
dist_variance = dist.variance.cpu().detach().numpy()
if x is None:
x = np.zeros_like(dist_mean)
if torch.is_tensor(x):
x = x.cpu().detach().numpy()
@@ -169,7 +172,7 @@ def plot_mean_1d(dist, x=None, preprocess=None, preprocessing=None, axes=None, m
return fig, axes
def plot_mean_2d(dist, x=None, preprocessing=None, preprocess="False", multihead=None, out=None, *args, **kwargs):
def plot_mean_2d(dist, x=None, preprocessing=None, preprocess=False, multihead=None, out=None, *args, **kwargs):
n_examples = dist.batch_shape[0]
n_rows, n_columns = get_divs(n_examples)
has_std = hasattr(dist, 'stddev')
@@ -182,10 +185,14 @@ def plot_mean_2d(dist, x=None, preprocessing=None, preprocess="False", multihead
if has_std:
fig_std, axes_std = plt.subplots(n_rows, 2*n_columns)
if axes.ndim == 1:
if n_rows == 1:
axes = axes[np.newaxis, :]
if has_std:
axes_std = axes_std[np.newaxis, :]
axes_std = axes_std[:, np.newaxis]
if n_columns == 1 and x is None:
axes = axes[:, np.newaxis]
if has_std:
axes_std = axes_std[:, np.newaxis]
dist_mean = dist.mean.cpu().detach().numpy()
if has_std:
@@ -205,8 +212,8 @@ def plot_mean_2d(dist, x=None, preprocessing=None, preprocess="False", multihead
axes[i,2*j+1].set_title('reconstruction')
else:
axes[i,j].set_title('data')
axes[i,j+1].imshow(dist_mean[i*n_columns+j], aspect='auto')
axes[i,j+1].set_title('reconstruction')
axes[i,j].imshow(dist_mean[i*n_columns+j], aspect='auto')
axes[i,j].set_title('reconstruction')
if hasattr(dist, "stddev"):
if x is not None:
axes_std[i,2*j].imshow(x[i*n_columns+j], aspect='auto')
@@ -215,8 +222,8 @@ def plot_mean_2d(dist, x=None, preprocessing=None, preprocess="False", multihead
axes_std[i,2*j+1].set_title('reconstruction')
else:
axes_std[i,j].set_title('data')
axes_std[i,j+1].imshow(dist_std[i*n_columns+j], vmin=0, vmax=1, aspect='auto')
axes_std[i,j+1].set_title('reconstruction')
axes_std[i,j].imshow(dist_std[i*n_columns+j], vmin=0, vmax=1, aspect='auto')
axes_std[i,j].set_title('reconstruction')
if multihead is not None:
fig = [fig]; axes = [axes];
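The np.newaxis fixes above account for plt.subplots returning a 1-D (or scalar) axes array when either grid dimension is 1. A small sketch of normalising the axes to a consistent 2-D array; it mirrors the intent of the fix rather than the module's exact code path:

import numpy as np
import matplotlib.pyplot as plt

def subplots_2d(n_rows, n_cols, **kwargs):
    """Sketch: always return the axes as a (n_rows, n_cols) array."""
    fig, axes = plt.subplots(n_rows, n_cols, **kwargs)
    axes = np.atleast_1d(axes)
    if axes.ndim == 1:
        # a single row comes back as shape (n_cols,), a single column as (n_rows,)
        axes = axes[np.newaxis, :] if n_rows == 1 else axes[:, np.newaxis]
    return fig, axes

# usage sketch: indexing is uniform whatever the grid shape
fig, axes = subplots_2d(1, 3)
axes[0, 2].set_title('reconstruction')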
@@ -249,14 +256,20 @@ def plot_mean(x, target=None, preprocessing=None, axes=None, *args, is_sequence=
x = type(x)(x.mean.squeeze(), x.stddev.squeeze())
else:
x = type(x)(x.mean.squeeze())
target = target.squeeze()
if target is not None:
target = target.squeeze()
if is_sequence:
fig = [None]*x.batch_shape[0]; axes = [None]*x.batch_shape[0]
for ex in range(x.batch_shape[0]):
if len(x.batch_shape) <= 3:
fig[ex], axes[ex] = plot_mean_1d(x[ex], x=target[ex], preprocessing=preprocessing, *args, **kwargs)
if target is not None:
fig[ex], axes[ex] = plot_mean_1d(x[ex], x=target[ex], preprocessing=preprocessing, *args, **kwargs)
else:
fig[ex], axes[ex] = plot_mean_1d(x[ex], x=None, preprocessing=preprocessing, *args, **kwargs)
elif len(x.batch_shape) == 4:
fig, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, *args, **kwargs)
if target is not None:
fig, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, *args, **kwargs)
else:
if len(x.batch_shape) <= 2 + is_sequence:
fig, axes = plot_mean_1d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
@@ -392,9 +405,9 @@ def plot_dims(current_z, meta=None, var=None, classes=None, class_ids=None, clas
ax = fig.gca(projection='3d')
if meta is None:
meta = np.zeros((current_z.shape[0]))
cmap = get_cmap(0, color_map=cmap)
cmap_hash = {None:None}
meta = np.zeros((current_z.shape[0])).astype(np.int)
cmap = get_cmap(1, color_map=cmap)
cmap_hash = {0:0}
else:
cmap = get_cmap(len(checklist(classes)), color_map=cmap)
cmap_hash = {classes[i]:i for i in range(len(checklist(classes)))}
@@ -493,9 +506,10 @@ def plot_3d(current_z, meta=None, var=None, classes=None, class_ids=None, class_
ax = fig.gca(projection='3d')
if meta is None:
meta = np.zeros((current_z.shape[0]))
cmap = get_cmap(0, color_map=cmap)
cmap_hash = {0:0}
meta = np.arange((current_z.shape[0]))
cmap = get_cmap(meta.shape[0], color_map=cmap)
cmap_hash = {x:x for x in meta}
legend = False
else:
cmap = get_cmap(0, color_map=cmap) if classes is None else get_cmap(len(classes), color_map=cmap)
cmap_hash = {None:None} if classes is None else {classes[i]:i for i in range(len(classes))}
@@ -503,7 +517,9 @@ def plot_3d(current_z, meta=None, var=None, classes=None, class_ids=None, class_
current_alpha = 0.06 if (centroids and not meta is None) else 1.0
current_var = var if not var is None else np.ones(current_z.shape[0])
current_var = (current_var - current_var.mean() / np.abs(current_var).max())+1
meta = np.array(meta).astype(np.int)
if meta is not None:
meta = np.array(meta).astype(np.int)
n_examples = current_z.shape[0]
if scramble:
......
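Both plot_dims and plot_3d above now build a colour map and a class-to-colour-index hash even when no metadata is given, falling back to a single dummy class or to per-point indices. A sketch of that colouring scheme using matplotlib's own colormap lookup; get_cmap in the diff is the library's helper, so the version below is an assumption:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3d projection on older matplotlib)

def scatter_by_class(z, meta=None, classes=None, cmap_name='plasma'):
    """Sketch: colour a 3-D scatter by integer class labels, with safe defaults."""
    if meta is None:
        meta = np.zeros(z.shape[0], dtype=int)    # single dummy class
        classes = [0]
    if classes is None:
        classes = sorted(set(np.asarray(meta).tolist()))
    cmap = plt.get_cmap(cmap_name)
    cmap_hash = {c: i for i, c in enumerate(classes)}
    colors = [cmap(cmap_hash[m] / max(len(classes) - 1, 1)) for m in meta]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(z[:, 0], z[:, 1], z[:, 2], c=colors)
    return fig, ax

# usage sketch
z = np.random.randn(100, 3)
labels = np.random.randint(0, 4, size=100)
fig, ax = scatter_by_class(z, meta=labels, classes=[0, 1, 2, 3])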
@@ -23,6 +23,7 @@ plot_hash = {'reconstructions': lplt.plot_reconstructions,
'latent_space': lplt.plot_latent3,
'latent_trajs': lplt.plot_latent_trajs,
'latent_dims': lplt.plot_latent_dim,
'sample': lplt.plot_samples,
'latent_consistency': lplt.plot_latent_consistency,
'statistics':lplt.plot_latent_stats,
'images':lplt.image_export,
......
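The plot_hash dictionary extended above maps plot names to plotting callables, so monitoring code can be driven by configuration strings such as 'sample'. A generic sketch of that registry-and-dispatch pattern; the placeholder functions stand in for the lplt plotting API:

def plot_reconstructions(**kwargs):
    print('plotting reconstructions', kwargs)

def plot_samples(**kwargs):
    print('plotting samples', kwargs)

# name -> callable registry, so plots can be selected from a config string
plot_hash = {
    'reconstructions': plot_reconstructions,
    'sample': plot_samples,
}

def run_plots(names, **kwargs):
    """Sketch: dispatch each requested plot, skipping unknown names."""
    for name in names:
        fn = plot_hash.get(name)
        if fn is None:
            print('[Warning] unknown plot type: %s' % name)
            continue
        fn(**kwargs)

run_plots(['reconstructions', 'sample', 'latent_space'])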
@@ -276,7 +276,7 @@ class ShrubVAE(VanillaVAE):
if from_layer < 0:
from_layer = len(self.platent) + from_layer
for i, z_tmp in enumerate(z):
z_all[from_layer - len(z) + 1 + i] = z[i]
z_all[from_layer - len(z) + 1 + i] = z[i].unsqueeze(1) if len(z[i].shape) == 2 else z[i]
current_z = z_all[from_layer]
outs = []; n_batch = z_all[from_layer].shape[0]; n_seq = z_all[from_layer].shape[1]
@@ -307,12 +307,12 @@ class ShrubVAE(VanillaVAE):
cum_size = int(ceil(cum_size / steps[-1]))
else:
if self.phidden[i].get('decoder'):
steps = [None] * len(self.phidden[i].get('decoder')) - 1
for i in range(1, len(self.phidden[i]['decoder'])):
steps = [None] * (len(self.phidden) - 1)
for i in range(1, len(self.phidden)):
steps[i-1] = self.phidden[i]['decoder'].get('path_length') or current_z.shape[1]
else:
steps = [None] * len(self.phidden[i]) - 1
for i in range(1, len(self.phidden[i])):
steps = [None] * (len(self.phidden) - 1)
for i in range(1, len(self.phidden)):
steps[i-1] = self.phidden[i].get('path_length') or current_z.shape[1]
# steps = [self.phidden[i].get('path_length', 1) for i in range(1, len(self.phidden))]
@@ -337,6 +337,8 @@ class ShrubVAE(VanillaVAE):
logger('start decoding')
for layer in reversed(range(1, from_layer+1)):
# get number of steps to be decoded by RVAE decoder
if len(current_z.shape)==2:
current_z = current_z.unsqueeze(1)
current_z = current_z.reshape(n_batch*n_seq, *current_z.shape[2:])
n_overlap = self.phidden[layer].get('path_overlap') or n_steps
# forward
......
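The ShrubVAE changes above size the per-layer steps list from self.phidden itself and give every latent entering the decoder an explicit sequence axis before it is flattened to (batch * seq, ...). A sketch of that shape handling in isolation; names and shapes are illustrative:

import torch

def flatten_sequence(z):
    """Sketch: ensure (batch, seq, dim), then merge batch and sequence axes."""
    if len(z.shape) == 2:
        z = z.unsqueeze(1)                     # (batch, dim) -> (batch, 1, dim)
    n_batch, n_seq = z.shape[0], z.shape[1]
    flat = z.reshape(n_batch * n_seq, *z.shape[2:])
    return flat, (n_batch, n_seq)

# usage sketch: both ranks end up with the same layout for the decoder
z_top = torch.randn(4, 16)                     # top latent without a time axis
flat, (n_batch, n_seq) = flatten_sequence(z_top)
print(flat.shape, n_batch, n_seq)              # torch.Size([4, 16]) 4 1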