prediction & various updates

parent 45d99868
......@@ -195,8 +195,8 @@ class CriterionContainer(Criterion):
def get_named_losses(self, losses):
named_losses=dict()
for i, l in enumerate(losses):
current_loss = self._criterions[i].get_named_losses(l)
for i, l in enumerate(self._criterions):
current_loss = l.get_named_losses(losses[i])
named_losses = {**named_losses, **current_loss}
return named_losses
......
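The hunk above makes get_named_losses iterate over the container's child criterions and merge each child's named dict into one. A minimal standalone sketch of that merge pattern (the dummy classes below are illustrative, not the project's Criterion classes):

    # Minimal sketch: merge per-criterion named losses into one dict.
    class NamedLoss:
        def __init__(self, name):
            self.name = name
        def get_named_losses(self, loss):
            return {self.name: float(loss)}

    criterions = [NamedLoss("rec"), NamedLoss("reg")]
    losses = (0.5, 1.2)
    named = {}
    for i, c in enumerate(criterions):
        named = {**named, **c.get_named_losses(losses[i])}
    print(named)   # {'rec': 0.5, 'reg': 1.2}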
......@@ -32,18 +32,13 @@ class KLD(Criterion):
def loss(self, params1=None, params2=None, out1=None, out2=None, sample=False, compute_logdets=False, **kwargs):
sample = (params1 is None) or (params2 is None) or sample
if not sample:
assert (params1 is not None) and (params2 is not None)
try:
loss = clamp(kl.kl_divergence(params1, params2), min=eps)
except NotImplementedError:
loss = self.kl_sampled(params1, params2, **kwargs)
loss = self.kl_sampled(params1, params2, out1, out2, **kwargs)
pass
else:
if out1 is None:
out1 = params1.rsample()
if out2 is None:
out2 = params2.rsample()
assert (out1 is not None) and (out2 is not None)
loss = self.kl_sampled(out1, out2, **kwargs)
loss = self.kl_sampled(params1, params2, out1, out2, **kwargs)
loss = reduce(loss, self.reduction);
losses = (float(loss),)
......@@ -51,7 +46,9 @@ class KLD(Criterion):
return loss, losses
def kl_sampled(self, params1, params2, out1, out2, **kwargs):
return params2.log_prob(out1) - params1.log_prob(out1)
if out1 is None:
out1 = params1.rsample()
return params1.log_prob(out1) - params2.log_prob(out1)
def get_named_losses(self, losses):
named_losses = {}
......
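The KLD hunk above falls back to a single-sample estimate whenever torch.distributions has no registered closed form. A minimal sketch of that fallback with stock torch distributions (kl_sampled mirrors the method above, but everything here is a stand-in, not the project's classes):

    import torch
    from torch.distributions import Normal, kl_divergence

    def kl_sampled(p, q, z=None):
        # Single-sample Monte Carlo estimate of KL(p || q): E_p[log p(z) - log q(z)]
        if z is None:
            z = p.rsample()
        return p.log_prob(z) - q.log_prob(z)

    p = Normal(torch.zeros(4), torch.ones(4))
    q = Normal(torch.ones(4), torch.ones(4))
    try:
        kl = kl_divergence(p, q)     # closed form, when registered
    except NotImplementedError:
        kl = kl_sampled(p, q)        # otherwise fall back to sampling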
......@@ -57,11 +57,16 @@ class ELBO(CriterionContainer):
def get_reconstruction_params(self, model, out, target, epoch=None, callback=None, **kwargs):
callback = callback or self.reconstruction_loss
input_params = checklist(model.pinput); target = checklist(target)
rec_params = []
x_params = checklist(out['x_params'])
for i, ip in enumerate(input_params):
rec_params.append((callback, {'params1': x_params[i], 'params2': model.format_input_data(target[i]), 'input_params': ip, 'epoch':epoch}, 1.0))
if issubclass(type(self.reconstruction_loss), list):
input_params = checklist(model.pinput); target = checklist(target)
rec_params = []
x_params = checklist(out['x_params'])
target = model.format_input_data(target)
for i, ip in enumerate(input_params):
rec_params.append((callback, {'params1': x_params[i], 'params2': target[i], 'input_params': ip, 'epoch':epoch}, 1.0))
else:
x_params = out['x_params']
rec_params = [(callback, {'params1': x_params, 'params2': target, 'input_params': model.pinput, 'epoch':epoch}, 1.0)]
return rec_params
def get_regularization_params(self, model, out, epoch=None, beta=None, warmup=None, callback=None, **kwargs):
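The reconstruction hunk builds a list of (callback, kwargs, weight) triplets, one per input head. A minimal sketch of how such triplets are consumed (the mse callback is a stand-in, not the project's reconstruction loss):

    import torch

    def mse(params1=None, params2=None, **kwargs):
        return ((params1 - params2) ** 2).mean()

    rec_params = [(mse, {'params1': torch.ones(4), 'params2': torch.zeros(4), 'epoch': 0}, 1.0)]
    total = sum(weight * cb(**kw) for cb, kw, weight in rec_params)
    print(float(total))   # 1.0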
......@@ -72,20 +77,20 @@ class ELBO(CriterionContainer):
# encoder parameter
params1 = out.get("z_params_enc");
out1 = out['z_enc']
if params1 is not None:
if params1.requires_preflow:
out1 = out['z_preflow_enc']
# decoder parameters
prior = latent_params.get('prior') or None
if prior is not None:
params2 = scale_prior(prior, out['z_enc'])(batch_size = out1.shape, **kwargs)
out2 = params2.rsample()
elif out.get('z_params_dec') is not None:
if out.get('z_params_dec') is not None:
params2 = out['z_params_dec']
out2 = out["z_dec"]
elif prior is not None:
params2 = scale_prior(prior, out['z_enc'])(batch_size = out1.shape, **kwargs)
out2 = params2.rsample()
else:
params2 = get_default_distribution(latent_params['dist'], out['z_params_enc'].batch_shape,
device=out['z_enc'].device)
if out.get('z_params_enc'):
prior_shape = out['z_params_enc'].batch_shape
else:
prior_shape = out['z_enc'].shape
params2 = get_default_distribution(latent_params['dist'], prior_shape, device=out['z_enc'].device)
out2 = params2.rsample()
#pdb.set_trace()
......
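The regularization hunk changes which distribution the KL is taken against: the decoder posterior if present, otherwise the (scaled) prior, otherwise a default distribution shaped from the encoder output. A minimal sketch of that fallback order with stock torch distributions (the helper name and the plain Normal default are illustrative, not the project's scale_prior / get_default_distribution):

    import torch
    from torch.distributions import Normal

    def pick_regularization_target(out, prior=None):
        # Priority used above: decoder posterior, then explicit prior, then a
        # default unit Gaussian shaped from the encoder output.
        if out.get('z_params_dec') is not None:
            return out['z_params_dec'], out['z_dec']
        if prior is not None:
            return prior, prior.rsample()
        if out.get('z_params_enc') is not None:
            shape = out['z_params_enc'].batch_shape
        else:
            shape = out['z_enc'].shape
        default = Normal(torch.zeros(shape), torch.ones(shape))
        return default, default.rsample()

    enc = Normal(torch.zeros(2, 8), torch.ones(2, 8))
    out = {'z_params_enc': enc, 'z_enc': enc.rsample()}
    params2, out2 = pick_regularization_target(out)
    print(params2.batch_shape, out2.shape)   # torch.Size([2, 8]) torch.Size([2, 8])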
......@@ -18,9 +18,9 @@ class LogDensity(Criterion):
if issubclass(type(input_params),list):
if not issubclass(type(params2), list):
x = [params2]
losses = tuple([self.loss(params1[i], params2[i], input_params[i]) for i in range(len(input_params))])
loss = sum(losses)
losses = tuple([l.detach().cpu().numpy() for l in losses])
losses_full = tuple([self.loss(params1[i], params2[i], input_params[i]) for i in range(len(input_params))])
loss = sum([l[0] for l in losses_full])
losses = tuple(sum([list(l[1]) for l in losses_full], []))
else:
#if len(target.shape)==2:
#target = target.squeeze(1)
......
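In the multi-input branch above, each recursive self.loss call now returns a (loss, named_losses) pair; the scalars are summed and the per-input tuples concatenated. A tiny standalone sketch of that aggregation:

    def aggregate(per_input_results):
        # per_input_results: one (loss, losses_tuple) pair per input head
        loss = sum(r[0] for r in per_input_results)
        losses = tuple(sum([list(r[1]) for r in per_input_results], []))
        return loss, losses

    print(aggregate([(1.0, (1.0,)), (2.5, (2.0, 0.5))]))   # (3.5, (1.0, 2.0, 0.5))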
......@@ -786,7 +786,7 @@ class Dataset(torch.utils.data.Dataset):
###################################
"""
def construct_partition(self, partitionNames, partitionPercent, tasks=None, balancedClass=True, equalClass=False):
def construct_partition(self, partitionNames, partitionPercent, tasks=[], balancedClass=True, equalClass=False):
"""
Construct a random/balanced partition set for each dataset
Only takes indices with valid metadatas for every task
......
......@@ -225,6 +225,70 @@ class Magnitude(object):
return new_data
class Phase(Preprocessing):
def __init__(self, normalize="bipolar", unwrap=True):
super(Phase, self).__init__()
self.normalize = normalize
if unwrap is not None:
assert type(unwrap) == int, "if not None, unwrap keyword must be a (int) dimension"
self.unwrap = unwrap
else:
self.unwrap = None
def __call__(self, data, **kwargs):
if issubclass(type(data), list):
return [self(x) for x in data]
if issubclass(type(data), np.ndarray):
if not np.iscomplexobj(data):
raise TypeError('Phase preprocessing needs complex np.ndarray')
transform = np.angle(data)
else:
raise TypeError('Phase preprocessing needs complex np.ndarray')
if self.unwrap is not None:
transform = np.unwrap(transform, axis=self.unwrap)
if self.normalize == "unipolar":
transform = np.clip((transform + np.pi)/(2*np.pi), 0, 1)
elif self.normalize == "bipolar":
transform = np.clip(transform / np.pi, -1, 1)
return transform
def invert(self, x):
if self.normalize == "unipolar":
x = (x * 2 * np.pi) - np.pi
elif self.normalize == "bipolar":
x = x * np.pi
return x
class InstantaneousPhase(Preprocessing):
def __init__(self):
super(InstantaneousPhase, self).__init__()
def __call__(self, data):
if issubclass(type(data), list):
return [self(x) for x in data]
class Polar(Preprocessing):
def __init__(self, mag_args, phase_args):
super(Polar, self).__init__()
self.magnitude_pp = Magnitude(**mag_args)
self.phase_pp = Phase(**phase_args)
def scale(self, data):
self.magnitude_pp.scale(data)
self.phase_pp.scale(data)
def __call__(self, data):
return [self.magnitude_pp(data), self.phase_pp(data)]
def invert(self, x, complex=False):
if complex:
return self.magnitude_pp.invert(x[0])*np.exp(1j*self.phase_pp.invert(x[1]))
else:
return [self.magnitude_pp.invert(x[0]), self.phase_pp.invert(x[1])]
class MuLaw(object):
"""
......
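Polar splits a complex array into magnitude and normalized phase, and inversion recombines them as magnitude * exp(1j * phase). A minimal NumPy-only round trip illustrating the bipolar normalization used above (independent of the Magnitude/Phase classes, with dummy data):

    import numpy as np

    stft = np.random.randn(8, 16) + 1j * np.random.randn(8, 16)
    mag, phase = np.abs(stft), np.angle(stft)

    # bipolar phase normalization and its inverse, as in Phase.__call__ / Phase.invert
    phase_norm = np.clip(phase / np.pi, -1, 1)
    phase_back = phase_norm * np.pi

    reconstructed = mag * np.exp(1j * phase_back)
    print(np.allclose(stft, reconstructed))   # True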
......@@ -6,9 +6,12 @@ from torch.distributions.utils import _sum_rightmost
class FlowDistribution(TransformedDistribution):
requires_preflow = True
def __init__(self, base_distribution, flow, validate_args=None):
super(FlowDistribution, self).__init__(base_distribution, flow.transforms, validate_args=validate_args)
in_selector = lambda _, x: x
def __init__(self, base_distribution, flow, unwrap_blocks=True, in_select=lambda x: x, validate_args=None):
super(FlowDistribution, self).__init__(in_select(base_distribution), flow.transforms, validate_args=validate_args)
self.base_distribution=base_distribution
self.flow = flow
self.unwrap_blocks = unwrap_blocks
def __repr__(self):
return "FlowDistribution(%s, %s)"%(self.base_dist, self.flow)
......@@ -19,34 +22,34 @@ class FlowDistribution(TransformedDistribution):
# else:
# return self.base_dist.__getattr__(item)
def sample(self, sample_shape=torch.Size(), aux_in = None, retain=False):
def sample(self, sample_shape=torch.Size(), aux_in = None):
with torch.no_grad():
x = self.base_dist.sample(sample_shape)
x_0 = x
if retain:
x_0 = self.base_distribution.sample(sample_shape)
x = self.in_selector(x_0)
if self.unwrap_blocks:
full_x = []
self.flow.amortization(x_0, aux=aux_in)
for i, flow in enumerate(self.flow.blocks):
x = flow(x)
if retain:
if self.unwrap_blocks:
full_x.append(x.unsqueeze(1))
if retain:
if self.unwrap_blocks:
return torch.cat(full_x, dim=1), x_0
else:
return x, x_0
def rsample(self, sample_shape=torch.Size(), aux_in=None, retain=False):
x = self.base_dist.rsample(sample_shape)
x_0 = x
if retain:
def rsample(self, sample_shape=torch.Size(), aux_in=None):
x_0 = self.base_distribution.rsample(sample_shape)
x = self.in_selector(x_0)
if self.unwrap_blocks:
full_x = []
self.flow.amortization(x_0, aux=aux_in)
for i, flow in enumerate(self.flow.blocks):
#pdb.set_trace()
x = flow(x)
if retain:
if self.unwrap_blocks:
full_x.append(x.unsqueeze(1))
if retain:
if self.unwrap_blocks:
return torch.cat(full_x, dim=1), x_0
else:
return x, x_0
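sample/rsample above now always return (x, x_0) and, when unwrap_blocks is set, stack every flow block's output along dim 1 instead of keeping only the final state. A minimal sketch of that pattern with toy callables standing in for flow blocks:

    import torch

    def sample_through_blocks(x0, blocks, unwrap_blocks=True):
        # Push x0 through each block; optionally keep every intermediate state.
        x, full_x = x0, []
        for block in blocks:
            x = block(x)
            if unwrap_blocks:
                full_x.append(x.unsqueeze(1))
        if unwrap_blocks:
            return torch.cat(full_x, dim=1), x0   # (batch, n_blocks, dim), (batch, dim)
        return x, x0

    blocks = [lambda x: 2 * x + 1, lambda x: x - 3]
    out, x0 = sample_through_blocks(torch.zeros(4, 8), blocks)
    print(out.shape)   # torch.Size([4, 2, 8])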
......@@ -80,5 +83,39 @@ class Flow(object):
return FlowDistribution(self._dist(*args, **kwargs), self._flow)
class SequenceFlowDistribution(FlowDistribution):
in_selector = lambda _, x: x[:, -1]
def __repr__(self):
return "SequenceFlowDistribution(%s, %s)"%(self.base_dist, self.flow)
@property
def batch_shape(self):
final_seq_length = self.base_distribution.batch_shape[1] + len(self.flow.bijectors)
return (self.base_distribution.batch_shape[0], final_seq_length, *self.base_distribution.batch_shape[2:])
def sample(self, *args, **kwargs):
x, x_0 = super(SequenceFlowDistribution, self).sample(*args, **kwargs)
x = torch.cat([x_0, x], axis=1)
return x, x_0
def rsample(self, *args, **kwargs):
x, x_0 = super(SequenceFlowDistribution, self).rsample(*args, **kwargs)
x = torch.cat([x_0, x], axis=1)
return x, x_0
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
n_seq = self.base_distribution.batch_shape[1]
log_prob_dist = self.base_distribution.log_prob(value[:, :-n_seq])
log_prob_flow = self.flow.bijectors.log_abs_det_jacobian(value[:, self.base_distribution.batch_shape[1]], matrix=True)
if len(log_prob_flow.shape) < 3 or log_prob_flow.shape[-1] == 1:
log_prob_flow = log_prob_flow.unsqueeze(-1) if len(log_prob_flow.shape) == 2 else log_prob_flow
# trick to fit shape of log jacobians
log_prob_flow = log_prob_flow.repeat(1, 1, value.shape[-1]) / value.shape[-1]
# cumulative sum and adding to last step from base distribution
log_prob_flow = - torch.cumsum(log_prob_flow, dim=1) + log_prob_dist[:, -1].unsqueeze(1)
return torch.cat([log_prob_dist, log_prob_flow], dim=1)
......@@ -60,12 +60,17 @@ class FlowList(nn.ModuleList, Flow):#, transform.ComposeTransform):
nn.ModuleList.append(self, flow)
#self.parts = list(self.parts) + [flow]
def log_abs_det_jacobian(self, z):
def log_abs_det_jacobian(self, z, matrix=False):
if not self:
return torch.zeros_like(z)
result = 0
result = [] if matrix else 0
for flow in self:
result = result + flow.log_abs_det_jacobian(z)
if matrix:
result.append(flow.log_abs_det_jacobian(z))
else:
result = result + flow.log_abs_det_jacobian(z)
if matrix:
result = torch.stack(result, dim=1)
return result
def n_parameters(self):
......
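The matrix flag added above stacks each flow's log-abs-det-Jacobian along a new dimension instead of summing them, which is the form SequenceFlowDistribution.log_prob consumes. A small sketch with stand-in flows:

    import torch

    def log_abs_det_jacobian(flows, z, matrix=False):
        # Sum the per-flow terms (default) or stack them along dim 1 (matrix=True).
        if not flows:
            return torch.zeros_like(z)
        if matrix:
            return torch.stack([f(z) for f in flows], dim=1)
        result = 0
        for f in flows:
            result = result + f(z)
        return result

    flows = [lambda z: torch.ones(z.shape[0]), lambda z: 2 * torch.ones(z.shape[0])]
    z = torch.zeros(5, 3)
    print(log_abs_det_jacobian(flows, z).shape)               # torch.Size([5])
    print(log_abs_det_jacobian(flows, z, matrix=True).shape)  # torch.Size([5, 2])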
......@@ -2,7 +2,7 @@ import torch, torch.nn as nn, pdb
from . import flow
from .modules_bottleneck import MLP
from .modules_recurrent import RecurrentModule, RNNLayer, GRULayer, LSTMLayer, VRNNEncoder
from ..distributions import Normal, Empirical, FlowDistribution
from ..distributions import Normal, Empirical, FlowDistribution, SequenceFlowDistribution
from ..utils import checklist, print_stats, flatten_seq_method, print_module_stats, oneHot
from . import Sequential
......@@ -135,6 +135,9 @@ class CPCPredictiveLayer(nn.Module):
predictors.append(predictor)
self.predictors = nn.ModuleList(predictors)
self.parametrization = prediction_params.get('parametrization', 'normal')
self.parametrization_params = prediction_params.get('parametrization_params', {})
'''
def get_density_ratio(self, z, context, make_negatives=True):
print_stats(context, 'context')
......@@ -186,7 +189,11 @@ class CPCPredictiveLayer(nn.Module):
#print_stats(contexts, 'contexts')
#pdb.set_trace()
#print(predictions.mean((0, 1)))
return {"out": predictions, "cpc_states":contexts}
out = {"out": torch.cat([z_in, predictions], dim=1), "cpc_states":contexts}
if self.parametrization == "normal":
scale = self.parametrization_params.get('normal_std', 1e-3)
out['out_params'] = Normal(out['out'], scale*torch.ones_like(out['out']))
return out
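With the "normal" parametrization, the layer now concatenates the observed codes with the predictions and wraps the result in a Normal with a small fixed std, so downstream criteria can score predictions with log-probabilities. A minimal sketch with torch's Normal (the tensors are dummies; the 1e-3 std is the default read from parametrization_params above):

    import torch
    from torch.distributions import Normal

    z_in = torch.randn(2, 6, 16)          # observed codes (batch, steps, dim)
    predictions = torch.randn(2, 3, 16)   # predicted future codes

    out = torch.cat([z_in, predictions], dim=1)
    scale = 1e-3
    out_params = Normal(out, scale * torch.ones_like(out))
    print(out_params.batch_shape)         # torch.Size([2, 9, 16])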
# Flow-based prediction modules
......@@ -245,11 +252,8 @@ class FlowPrediction(nn.Module):
#aux_in = self.recurrent_smoother(out['recurrent_out'][:, -1])
elif self.amortization in ['hidden']:
aux_in = out['hidden'][-1]
final_jacobs = []
z_outs = []
flow_dist = FlowDistribution(out['z_params_enc'][-1][:, -1], self.flow)
flow_dist = SequenceFlowDistribution(out['z_params_enc'][-1][:, :-self.n_predictions], self.flow, unwrap_blocks=True)
# for n in range(len(self.flow)):
# z_out, log_jacobians = self.flow[n](data_in[:, context_in], aux=aux_in)
# final_jacobs.append(torch.sum(torch.cat(log_jacobians, dim=-1).unsqueeze(1), dim=-1))
......@@ -258,7 +262,7 @@ class FlowPrediction(nn.Module):
# if len(previous_z.shape) > 2:
# previous_z = previous_z.squeeze(-2)
# preds = torch.cat(z_outs, dim=1)
out, out_preflow = flow_dist.rsample(retain=True, aux_in=aux_in)
out, out_preflow = flow_dist.rsample(aux_in=aux_in)
return {'out_params':flow_dist, 'out':out, 'out_preflow':out_preflow}
......@@ -287,8 +291,8 @@ class GPPrediction(nn.Module):
def __init__(self, input_params, prediction_params, **kwargs):
super(GPPrediction, self).__init__()
#self.register_parameter('sigma', nn.Parameter(torch.tensor(prediction_params.get('init_variance', 1e-5))))
self.sigma = 1e-2
self.timescale = 0.3
self.sigma = prediction_params.get('variance')
self.timescale = prediction_params.get('timescale')
#self.register_parameter('timescale', nn.Parameter(torch.tensor(prediction_params.get('timescale', 3e-1))))
self.n_predictions = prediction_params.get('n_predictions')
......@@ -302,10 +306,14 @@ class GPPrediction(nn.Module):
out.mean(1); cov = process_kernel(t_in, t_in)
k_pred = process_kernel(t_in, t_pred)
# predicted mean
k_mult = torch.mm(k_pred.t() , torch.inverse(cov + torch.eye(t_in.shape[0], device=out.device)*self.sigma))
z_pred = torch.bmm(k_mult.unsqueeze(0).repeat(out.shape[0],1,1), out[:, :context_in])
pdb.set_trace()
#print_stats(out, 'out')
#print_stats(z_pred, 'pred')
# predicted variance
z_sig_pred = process_kernel(t_pred, t_pred) - torch.mm(k_mult, k_pred)
return {'out':z_pred}
......
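The GP hunk now reads the noise variance and timescale from prediction_params and computes the usual GP regression posterior: mean k_pred^T (K + sigma*I)^-1 y and covariance K** - k_pred^T (K + sigma*I)^-1 k_pred. A self-contained sketch with an RBF kernel (the kernel, shapes and values here are assumptions, not the project's process_kernel):

    import math
    import torch

    def rbf_kernel(t1, t2, timescale=0.3):
        d = t1.unsqueeze(-1) - t2.unsqueeze(0)
        return torch.exp(-0.5 * (d / timescale) ** 2)

    def gp_predict(t_in, y, t_pred, sigma=1e-2, timescale=0.3):
        # Standard GP regression posterior mean and covariance.
        cov = rbf_kernel(t_in, t_in, timescale)
        k_pred = rbf_kernel(t_in, t_pred, timescale)
        k_mult = k_pred.t() @ torch.inverse(cov + sigma * torch.eye(t_in.shape[0]))
        mean = k_mult @ y
        var = rbf_kernel(t_pred, t_pred, timescale) - k_mult @ k_pred
        return mean, var

    t_in, t_pred = torch.linspace(0, 1, 8), torch.linspace(1.1, 1.4, 3)
    y = torch.sin(2 * math.pi * t_in)
    mean, var = gp_predict(t_in, y, t_pred)
    print(mean.shape, var.shape)   # torch.Size([3]) torch.Size([3, 3])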
......@@ -87,6 +87,12 @@ def get_divs(n):
def get_tensors_from_dist(distrib):
if issubclass(type(distrib), dist.Normal):
return {'mean':distrib.mean, 'std':distrib.stddev}
else:
raise NotImplementedError
# Plotting functions
......@@ -172,7 +178,7 @@ def plot_mean_1d(dist, x=None, preprocess=None, preprocessing=None, axes=None, m
return fig, axes
def plot_mean_2d(dist, x=None, preprocessing=None, preprocess=False, multihead=None, out=None, *args, **kwargs):
def plot_mean_2d(dist, x=None, preprocessing=None, preprocess=False, multihead=None, out=None, fig=None, *args, **kwargs):
n_examples = dist.batch_shape[0]
n_rows, n_columns = get_divs(n_examples)
has_std = hasattr(dist, 'stddev')
......@@ -226,7 +232,7 @@ def plot_mean_2d(dist, x=None, preprocessing=None, preprocess=False, multihead=N
axes_std[i,j].set_title('reconstruction')
if multihead is not None:
fig = [fig]; axes = [axes];
fig = [fig]; axes = [fig.axis];
fig_m, axes_m = plt.subplots(n_rows, n_columns, figsize=(10,10))
if len(axes_m.shape) == 1:
axes_m = axes_m[:, np.newaxis]
......@@ -239,8 +245,8 @@ def plot_mean_2d(dist, x=None, preprocessing=None, preprocess=False, multihead=N
if out is not None:
fig.savefig(out+".pdf", format="pdf")
if has_std:
fig_std.savefig(out+"_std.pdf", format="pdf")
# if has_std:
# fig_std.savefig(out+"_std.pdf", format="pdf")
return fig, axes
......
......@@ -33,6 +33,7 @@ plot_hash = {'reconstructions': lplt.plot_reconstructions,
'grid_latent':lplt.grid_latent,
'descriptors_2d':plot2Ddescriptors,
'descriptors_3d':plot3Ddescriptors,
'prediction':lplt.plot_prediction,
'audio_reconstructions': resynthesize_files,
'audio_interpolate': interpolate_files}
......
import torch, os, pdb, gc
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import librosa
from functools import reduce
......@@ -7,6 +8,7 @@ from torchvision.utils import make_grid
from mpl_toolkits.mplot3d import Axes3D
from . import visualize_dimred as dr
from ..utils.onehot import fromOneHot, oneHot
from ..distributions.distribution_priors import get_default_distribution
from ..utils import decudify, merge_dicts, CollapsedIds, check_dir, recgetitem, decudify
......@@ -167,6 +169,117 @@ def plot_samples(dataset, model, label=None, n_points=10, layer=None, priors=Non
return figs, axes
def import_axes(fig, ax_to_remove, ax_to_import, position=None):
ax_to_import.remove();
position = position or ax_to_remove.get_position();
ax_to_import.set_position(position);
subplotspec = ax_to_remove.get_subplotspec()
ax_to_import.set_subplotspec(subplotspec)
ax_to_import.figure = fig;
fig.axes.append(ax_to_import)
fig.add_axes(ax_to_import)
ax_to_remove.remove()
ax_to_import.set_title('')
def plot_prediction(dataset, model, label=None, n_points=10, out=None, preprocess=True, preprocessing=None, partition=None,
epoch=None, name=None, loader=None, ids=None, plot_multihead=False, reinforcers=None, **kwargs):
if partition is not None:
dataset = dataset.retrieve(partition)
n_rows, n_columns = core.get_divs(n_points)
if ids is None:
full_id_list = np.arange(len(dataset))
ids = full_id_list[np.random.permutation(len(full_id_list))[:n_points]] if ids is None else ids
# get data
Loader = loader if loader else DataLoader
loader = Loader(dataset, None, ids=ids, is_sequence=model.take_sequences, tasks=label)
data, metadata = next(loader.__iter__())
if preprocess:
if issubclass(type(model.pinput), list):
preprocessing = preprocessing if preprocessing else [None]*len(dataset.data)
data_pp = [None]*len(dataset.data)
if not issubclass(type(preprocessing), list):
preprocessing = [preprocessing]*len(dataset.data)
for i, pp in enumerate(preprocessing):
if not pp is None:
data_pp[i] = pp(data[i])
else:
data_pp = preprocessing(data) if preprocessing is not None else data
else:
data_pp = data
# forward
add_args = {}
if hasattr(model, 'prediction_params'):
add_args['n_preds'] = model.prediction_params['n_predictions']
add_args['epoch'] = kwargs.get('epoch')
with torch.no_grad():
encoder_out, prediction_out = model.encode(data_pp, y=metadata, predict=True)
encoder_out = {'z_params_enc':[encoder_out[i]['out_params'] for i in range(len(encoder_out))],
'z_enc':[encoder_out[i]['out'] for i in range(len(encoder_out))]}
decoder_out = model.decode(encoder_out['z_enc'], predict=False)
decoder_out_pred = model.decode(prediction_out['out'], predict=False)
fig_dd, axes = core.plot_distribution(decoder_out[0]['out_params'], target=data_pp, preprocessing=preprocessing, preprocess=preprocess)
if axes.shape[1] > 2:
axes_cc = np.concatenate([axes[:, 2*i] for i in range(axes.shape[1]//2)])
axes_rec = np.concatenate([axes[:, 2*i+1] for i in range(axes.shape[1]//2)])
axes = np.array([axes_cc, axes_rec]).T
_, axes_pred = core.plot_distribution(decoder_out_pred[0]['out_params'], preprocessing=preprocessing, preprocess=preprocess)
if axes_pred.shape[1] > 1:
axes_pred = axes_pred.reshape(np.cumprod(axes_pred.shape)[-1], 1)
fig, ax = plt.subplots(dpi=fig_dd.dpi, figsize=(9,7))
plot_cpc = (prediction_out.get('cpc_states') is not None)
n_figs = 3.0 + plot_cpc
for n in range(n_points):
x_idx = 0.
position = matplotlib.transforms.Bbox([[0.03 + x_idx / n_figs, n/n_points+0.03], [(x_idx + 1) / n_figs, (n+1)/n_points]])
print(position)
axes[n, 0].remove(); axes[n, 0].figure = fig; fig.add_axes(axes[n, 0]); axes[n, 0].set_position(position)
axes[n, 0].set_title(''); axes[n, 0].set_xticks([]); axes[n, 0].set_yticks([]);
x_idx = 1.
position = matplotlib.transforms.Bbox([[0.03 + x_idx / n_figs, n/n_points+0.03], [(x_idx + 1) / n_figs, (n+1)/n_points]])
print(position)
axes[n, 1].remove(); axes[n, 1].figure = fig; fig.add_axes(axes[n, 1]); axes[n, 1].set_position(position)
axes[n, 1].set_title(''); axes[n, 1].set_xticks([]); axes[n, 1].set_yticks([]);
x_idx = 2.
position = matplotlib.transforms.Bbox([[0.03 + x_idx / n_figs, n/n_points+0.03], [(x_idx + 1) / n_figs, (n+1)/n_points]])
print(position)
axes_pred[n, 0].remove(); axes_pred[n, 0].figure = fig; fig.add_axes(axes_pred[n, 0]); axes_pred[n, 0].set_position(position)
axes_pred[n, 0].set_title(''); axes_pred[n, 0].set_xticks([]); axes_pred[n, 0].set_yticks([]);
if plot_cpc:
x_idx = 3.
position = matplotlib.transforms.Bbox([[x_idx / n_figs+0.02, n/n_points+0.02], [(x_idx + 1) / n_figs-0.03, (n+1)/n_points-0.02]])
cpc_ax = fig.add_axes(position)
cpc = prediction_out.get('cpc_states').detach().cpu().numpy()
for c in range(cpc.shape[-1]):
cpc_ax.plot(cpc[n, :, c])
axes_pred[n, 0].set_xticks([]);
'''
ax[n, 0].set_position(position)
ax[n, 0].plot(np.arange(10), np.sin(np.arange(10)))
'''
#import_axes(fig, ax[n, 1], axes[n, 1])
#import_axes(fig, ax[n, 2], axes_pred[n, 0])
if out is not None:
current_name = name or f"pred_{partition or str()}"
if partition:
current_name += "_"+partition
out += "/prediction"
check_dir(out)
fig.savefig(out+'/%s_%d.pdf'%(current_name,epoch), format="pdf")
return [fig], [ax]
def get_plot_subdataset(dataset, n_points=None, partition=None, ids=None):
......
......@@ -15,51 +15,62 @@ logger = GPULogger(verbose=False)
# are defined separately from the object and dynamically added to a given class
def predictive_vae_init(self, *args, **kwargs):
# fist init VAE, then prediction
self.vae_class.__init__(self, *args, **kwargs)
self.prediction_class.__init__(self, *args, **kwargs)
def decode_prediction(self, z_predicted, *args, **kwargs):
return self.vae_class.decode(self, z_predicted)
def concat_prediction(self, vae_out, prediction_out, n_preds=None, teacher_prob=None, epoch=None, **kwargs):
z_predicted = prediction_out.get('out')
z_params_predicted = prediction_out.get('out_params')
# get teacher probability
teacher_prob = teacher_prob or self.teacher_prob or 0
if self.teacher_warmup != 0 and epoch is not None:
teacher_prob = (1 - min(1.0, epoch / self.teacher_warmup)*(1-teacher_prob))
# 1 means taking true encoded position
#print('teacher_prob', teacher_prob)
# 0 for prediction, 1 for original
mask = torch.distributions.Bernoulli(teacher_prob).sample(sample_shape=[z_predicted.shape[0]]).to(device=z_predicted.device)
# print(mask)
# just bypass if mask == 0
if torch.sum(mask) < z_predicted.shape[0]:
x_decoded = self.decode_prediction(z_predicted, target_seq=n_preds)
seq_length = z_predicted.shape[1]
n_preds = n_preds or self.prediction_module.n_predictions
x_decoded = self.decode_prediction(z_predicted[:, -n_preds:], target_seq=n_preds)
# merge encoding parameters (only first layer)
mask_int = torch.bernoulli(torch.tensor(float(teacher_prob)))
try:
vae_out['z_params_enc'][-1] = concat_distrib([vae_out['z_params_enc'][-1][:, :-seq_length],
dist_crossed_select(mask, vae_out['z_params_enc'][-1][:, -seq_length:], z_params_predicted)], dim=1, unsqueeze=False)
except NotImplementedError:
pass
vae_out['z_enc'][-1] = torch.cat([vae_out['z_enc'][-1][:, :-seq_length],
crossed_select(mask, vae_out['z_enc'][-1][:, -seq_length:], z_predicted)], dim=1)
# try to random batches
if z_params_predicted is not None:
vae_out['z_params_enc'][-1] = dist_crossed_select(mask, vae_out['z_params_enc'][-1], z_params_predicted)
if vae_out.get('z_enc') is not None:
vae_out['z_enc'][-1] = crossed_select(mask, z_predicted, vae_out['z_enc'][-1])
except NotImplementedError as e:
if mask_int == 0:
vae_out['z_params_enc'][-1] = z_params_predicted
if vae_out.get('z_enc') is not None:
vae_out['z_enc'][-1] = z_predicted
# merge decoding parameters (every layers)
for layer in range(len(x_decoded)-1):
seq_length = x_decoded[layer+1]['out'].shape[1]
vae_out['z_params_dec'][layer] = concat_distrib([vae_out['z_params_dec'][layer][:, :-seq_length],
dist_crossed_select(mask, vae_out['z_params_dec'][layer][:, -seq_length:], x_decoded[layer+1]['out_params'])], dim=1, unsqueeze=False)
vae_out['z_dec'][layer] = torch.cat([vae_out['z_dec'][layer][:, :-seq_length],
crossed_select(mask, vae_out['z_dec'][layer][:, -seq_length:], x_decoded[layer+1]['out'])], dim=1)
#TODO flows?
vae_out['x_params'] = concat_distrib([vae_out['x_params'][:, :-n_preds],
dist_crossed_select(mask, x_decoded[0]['out_params'], vae_out['x_params'][:, -n_preds:])], dim=1, unsqueeze=False)
if prediction_out.get('logdets'):
logdets = vae_out.get('logdets', [None]*len(self.platent)) or [None]*len(self.platent)
if logdets[-1] is None:
logdets[-1] = prediction_out['logdets']
else:
logdets[-1].append(prediction_out['logdets'])
vae_out['logdets'] = logdets
try:
vae_out['z_params_dec'][layer] = concat_distrib([vae_out['z_params_dec'][layer][:, :-seq_length],
dist_crossed_select(mask, vae_out['z_params_dec'][layer][:, -seq_length:], x_decoded[layer+1]['out_params'])], dim=1, unsqueeze=False)
if vae_out.get('z_dec') is not None:
vae_out['z_dec'][layer] = torch.cat([vae_out['z_dec'][layer][:, :-seq_length],
crossed_select(mask, vae_out['z_dec'][layer][:, -seq_length:], x_decoded[layer+1]['out'])], dim=1)
except NotImplementedError as e:
if mask_int == 0:
vae_out['z_params_dec'][layer] = x_decoded[layer+1]['out_params']
if vae_out.get('z_dec') is not None: