class-dependent priors + updated perceptual regularization

parent f89f154e
@@ -8,6 +8,7 @@ from ..utils import checklist, apply, print_stats
def scale_prior(prior, size):
# TODO: implement to avoid passing unscaled priors to the KLD
return prior
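# The intended fix (per the TODO above) is presumably to broadcast the prior's
# parameters to the batch shape of `size` before the KLD is computed; for now
# the prior is returned unchanged.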
def regularize_logdets(logdets):
if logdets is None:
return 0
@@ -54,7 +55,7 @@ class ELBO(CriterionContainer):
scaled_factors = [min(((epoch+1)/(warmup[i]-1))**self.warmup_exp, 1.0)*beta[i] if warmup[i] != 0 and epoch is not None else beta[i] for i in range(len(latent_params))]
return scaled_factors
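# Illustration with hypothetical settings warmup=[100], beta=[1.0], warmup_exp=1:
# epoch 49 yields min(50/99, 1.0) * 1.0 ≈ 0.505, and the factor saturates at beta
# once epoch + 1 >= warmup - 1.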
-def get_reconstruction_params(self, model, out, target, epoch=None, callback=None):
+def get_reconstruction_params(self, model, out, target, epoch=None, callback=None, **kwargs):
callback = callback or self.reconstruction_loss
input_params = checklist(model.pinput); target = checklist(target)
rec_params = []
@@ -63,7 +64,7 @@ class ELBO(CriterionContainer):
rec_params.append((callback, {'params1': x_params[i], 'params2': model.format_input_data(target[i]), 'input_params': ip, 'epoch':epoch}, 1.0))
return rec_params
-def get_regularization_params(self, model, out, epoch=None, beta=None, warmup=None, callback=None):
+def get_regularization_params(self, model, out, epoch=None, beta=None, warmup=None, callback=None, **kwargs):
def parse_layer(latent_params, out, layer_index=0):
if issubclass(type(latent_params), list):
@@ -77,7 +78,8 @@ class ELBO(CriterionContainer):
# decoder parameters
prior = latent_params.get('prior') or None
if prior is not None:
-params2 = scale_prior(prior, out['z_enc'])
+params2 = scale_prior(prior, out['z_enc'])(**kwargs)
out2 = params2.rsample()
elif out.get('z_params_dec') is not None:
params2 = out['z_params_dec']
out2 = out["z_dec"]
@@ -108,11 +110,10 @@ class ELBO(CriterionContainer):
def loss(self, model = None, out = None, target = None, epoch = None, beta=None, warmup=None, *args, **kwargs):
assert model is not None and out is not None and target is not None, "ELBO loss needs a model, an output and an input"
# parse loss arguments
-reconstruction_params = self.get_reconstruction_params(model, out, target, epoch=epoch)
+reconstruction_params = self.get_reconstruction_params(model, out, target, epoch=epoch, **kwargs)
beta = beta or self.beta
-regularization_params = self.get_regularization_params(model, out, epoch=epoch, beta=beta, warmup=warmup)
+regularization_params = self.get_regularization_params(model, out, epoch=epoch, beta=beta, warmup=warmup, **kwargs)
logdets = tuple()
# get warmup coefficient
full_loss = 0; rec_errors=tuple(); reg_errors=tuple()
# get reconstruction error
......
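Taken together, the ELBO changes above thread extra keyword arguments (in practice the label dictionary y) from the trainer down to a callable prior. A hedged sketch of the call chain, with ClassPrior as defined further below:

# trainer:      self.losses.loss(model=self.models, out=out, target=x, y=y, ...)
# ELBO.loss:    self.get_regularization_params(model, out, ..., **kwargs)   # kwargs = {'y': y}
# parse_layer:  params2 = scale_prior(prior, out['z_enc'])(**kwargs)        # ClassPrior.__call__(y=y)
# ClassPrior:   builds a Normal from per-class means/stddevs; out2 = params2.rsample()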
@@ -8,24 +8,113 @@ Created on Fri Jul 6 20:09:45 2018
import pdb
import torch, numpy as np
from .criterion_criterion import Criterion
-from utils.onehot import fromOneHot
+from ..utils import fromOneHot
import os
from scipy.stats import norm
from ..monitor.visualize_dimred import MDS
equivalenceInstruments = ['Clarinet-Bb', 'Alto-Sax', 'Trumpet-C', 'Violoncello',
'French-Horn', 'Oboe', 'Flute', 'English-Horn',
'Bassoon', 'Tenor-Trombone', 'Piano', 'Violin']
def get_perceptual_centroids(dataset, mds_dims, timbre_path='timbre.npy', covariance=True, timbreNormalize=True,
timbreProcessing=True):
if timbreProcessing or not os.path.isfile('timbre_' + str(mds_dims) + '.npy'):
fullTimbreData = np.load(f"{os.path.dirname(__file__)}/{timbre_path}")[None][0]
# Names of the pre-extracted set of instruments (all with pairwise rates)
selectedInstruments = fullTimbreData['instruments']
# Full sets of ratings (i, j) = all ratings for instru. i vs. instru. j
detailedMatrix = fullTimbreData['ratings']
# Final matrices
nbIns = len(selectedInstruments)
meanRatings = np.zeros((nbIns, nbIns))
gaussMuRatings = np.zeros((nbIns, nbIns))
gaussStdRatings = np.zeros((nbIns, nbIns))
nbRatings = np.zeros((nbIns, nbIns))
# Fit Gaussians for each of the sets of pairwise instruments ratings
for i in range(nbIns):
for j in range(i + 1, nbIns):
nbRatings[i, j] = detailedMatrix[i, j].size
meanRatings[i, j] = np.mean(detailedMatrix[i, j])
# Model the gaussian distribution of ratings
mu, std = norm.fit(detailedMatrix[i, j])
# Fill parameters of the Gaussian
gaussMuRatings[i, j] = mu
gaussStdRatings[i, j] = std
print("%s vs. %s : mu = %.2f, std = %.2f" % (selectedInstruments[i], selectedInstruments[j], mu, std))
# Create square matrices
meanRatings += meanRatings.T
gaussMuRatings += gaussMuRatings.T
gaussStdRatings += gaussStdRatings.T
meanRatings = (meanRatings - np.min(meanRatings)) / np.max(meanRatings)
# Rescale means
gaussMuRatings = (gaussMuRatings - np.min(gaussMuRatings)) / np.max(gaussMuRatings)
# Rescale standard deviations
gaussStdRatings = (gaussStdRatings - np.min(gaussStdRatings)) / np.max(gaussStdRatings)
variance = np.mean(gaussStdRatings, axis=1)
if timbreNormalize:
variance = ((variance - (np.min(variance)) + 0.01) / np.max(variance)) * 2
# Compute MDS on Gaussian mean
seed = np.random.RandomState(seed=3)
mds = MDS(n_components=mds_dims, max_iter=3000, eps=1e-9, random_state=seed, dissimilarity="precomputed",
n_jobs=1)
position = mds.fit(gaussMuRatings).embedding_
# Store all computations here
fullTimbreData = {'instruments': selectedInstruments,
'ratings': detailedMatrix,
'gmean': gaussMuRatings,
'gstd': gaussStdRatings,
'pos': position,
'var': variance}
np.save('timbre_' + str(mds_dims) + '.npy', fullTimbreData)
else:
# Retrieve final data structure
fullTimbreData = np.load('timbre_' + str(mds_dims) + '.npy').item()
# Names of the pre-extracted set of instruments (all with pairwise rates)
selectedInstruments = fullTimbreData['instruments']
# Gaussian modelization of the ratings
gaussMuRatings = fullTimbreData['gmean']
gaussStdRatings = fullTimbreData['gstd']
# MDS modelization of the ratings
position = fullTimbreData['pos']
variance = fullTimbreData['var']
audioTimbreIDs = np.zeros(len(equivalenceInstruments)).astype('int')
# Parse through the list of instruments
for k, v in dataset.classes['instrument'].items():
if k != '_length':
audioTimbreIDs[v] = equivalenceInstruments.index(k)
# Class-dependent means and covariances
prior_mean = position[audioTimbreIDs]
prior_std = np.ones((len(equivalenceInstruments), mds_dims))
if covariance:
prior_std = prior_std * variance[audioTimbreIDs, np.newaxis]
prior_params = (prior_mean, prior_std)
# Same for full Gaussian
prior_gauss_params = (gaussMuRatings, gaussStdRatings)
return prior_params, prior_gauss_params
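# Shape sketch (assuming the 12 instruments above and, hypothetically, mds_dims=3):
#   prior_params       = (prior_mean, prior_std)            # (12, 3) each: per-class MDS position and spread
#   prior_gauss_params = (gaussMuRatings, gaussStdRatings)  # (12, 12) each: pairwise rating statistics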
class PerceptiveL2Loss(Criterion):
-def __init__(self, centroids, targetDims, options={'normalize':False}):
+def __init__(self, latent_params, dataset=None, normalize=False):
super(PerceptiveL2Loss, self).__init__()
+targetDims = latent_params['dim']
+_, (centroids, _) = get_perceptual_centroids(dataset, targetDims)  # pairwise rating means (nbIns x nbIns) as target distances
if issubclass(type(centroids), np.ndarray):
self.centroids = torch.from_numpy(centroids).type('torch.FloatTensor')
else:
self.centroids = centroids.type('torch.FloatTensor')
-self.targetDims = targetDims
-self.normalize = options.get('normalize', False)
+self.targetDims = np.arange(targetDims)
+self.normalize = normalize
def loss(self, model, out, y=None, layer=0, *args, **kwargs):
assert y is not None
z = out['z_enc'][layer]
y = y['instrument']
if y.dim() == 2:
y = fromOneHot(y.cpu())
# Create the target distance matrix
@@ -56,16 +145,18 @@ class PerceptiveL2Loss(Criterion):
class PerceptiveGaussianLoss(Criterion):
-def __init__(self, gaussianParams, targetDims, options={'normalize':False}):
+def __init__(self, latent_params, dataset=None, normalize=False):
super(PerceptiveGaussianLoss, self).__init__()
-latent_means, latent_stds = gaussianParams
+targetDims = latent_params['dim']
+_, (latent_means, latent_stds) = get_perceptual_centroids(dataset, targetDims)
self.latent_means = torch.from_numpy(latent_means).type('torch.FloatTensor') if issubclass(type(latent_means), np.ndarray) else latent_means.type('torch.FloatTensor')
self.latent_stds = torch.from_numpy(latent_stds).type('torch.FloatTensor') if issubclass(type(latent_stds), np.ndarray) else latent_stds.type('torch.FloatTensor')
-self.targetDims = targetDims
-self.normalize = options.get('normalize', False)
+self.targetDims = np.arange(targetDims)
+self.normalize = normalize
def loss(self, model, out, y=None, layer=0, *args, **kwargs):
z = out['z_enc'][layer]
y = y['instrument']
if y.dim() == 2:
y = fromOneHot(y)
# Create the target distance matrix
@@ -98,21 +189,24 @@ class PerceptiveGaussianLoss(Criterion):
class PerceptiveStudent(Criterion):
-def __init__(self, centroids, targetDims, options={'normalize':False}):
+def __init__(self, latent_params, dataset=None, normalize=False):
super(PerceptiveStudent, self).__init__()
+targetDims = latent_params['dim']
+_, (centroids, _) = get_perceptual_centroids(dataset, targetDims)  # pairwise rating means used as the target distance matrix
if issubclass(type(centroids), np.ndarray):
self.centroids = torch.from_numpy(centroids).type('torch.FloatTensor')
else:
self.centroids = centroids.type('torch.FloatTensor')
-self.targetDims = targetDims
-self.normalize = options.get('normalize', False)
+self.targetDims = np.arange(targetDims)
+self.normalize = normalize
def loss(self, model, out, y=None, layer=0, *args, **kwargs):
z = out['z_enc'][layer]
y = y['instrument']
if y.dim() == 2:
y = fromOneHot(y)
# Create the target distance matrix
targetDistMat = self.centroids[y, :][:, y]
targetDistMat.requires_grad_(False)
targetDistMat = torch.pow((1 + targetDistMat), -1)
targetDistMat = (targetDistMat / torch.sum(targetDistMat))
......
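The elided loss bodies compare these target matrices with statistics of the latent codes. As an illustrative sketch only (not the repository's exact code), a Student-t / t-SNE-style kernel over the latents could be matched against targetDistMat:

import torch
import torch.nn.functional as F

def student_kernel(z):
    # pairwise squared distances between latent codes, then (1 + d)^-1, normalized
    d = torch.cdist(z, z, p=2) ** 2
    q = (1 + d).reciprocal()
    return q / q.sum()

# hypothetical matching of latent and target kernels:
# loss = F.kl_div(student_kernel(z).log(), targetDistMat, reduction='sum')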
-import pdb
+import pdb, abc
import torch
from torch import zeros, ones, eye
-from . import Bernoulli, Normal, MultivariateNormal, RandomWalk
+from . import Bernoulli, Normal, MultivariateNormal, Categorical, RandomWalk
from . import Distribution
from .distribution_flow import Flow
@@ -17,6 +19,47 @@ def IsotropicMultivariateGaussian(batch_size, device="cpu", requires_grad=False)
covariance_matrix=eye(*batch_size, device=device, requires_grad=requires_grad))
class Prior(object):
@abc.abstractmethod
def __init__(self):
super(Prior, self).__init__()
@abc.abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ClassPrior(Prior):
def __init__(self, task, dist_type, init_args={}, device=None):
super(ClassPrior, self).__init__()
self.dist_type = dist_type
self.init_args = init_args
self.task = task
self.device = device
def get_prior(self, y):
params = {}
for k in self.init_args.keys():
params[k] = []
for i in range(y.shape[0]):
for k,v in self.init_args.items():
params[k].append(torch.from_numpy(v[y[i]]))
for k, v in params.items():
params[k] = torch.stack(v)
if self.device is not None:
params[k] = params[k].to(self.device)
if issubclass(self.dist_type, Normal):
return Normal(params['mean'], params['stddev'])
return self.dist_type(**params)
def __call__(self, *args, y=None, **kwargs):
assert y.get(self.task) is not None
return self.get_prior(y[self.task])
def get_default_distribution(distrib_type, batch_shape, device="cpu", requires_grad=False):
if issubclass(type(distrib_type), Flow):
distrib_type = distrib_type.dist
......
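A minimal usage sketch for ClassPrior with stand-in per-class parameters (hypothetical wiring, not the repository's configuration):

import numpy as np, torch

# stand-ins for get_perceptual_centroids(dataset, mds_dims)[0]
prior_mean = np.random.randn(12, 3).astype('float32')
prior_std = np.ones((12, 3), dtype='float32')

prior = ClassPrior(task='instrument', dist_type=Normal,
                   init_args={'mean': prior_mean, 'stddev': prior_std})
y = {'instrument': torch.tensor([0, 3, 7])}   # class indices for a batch of 3
dist = prior(y=y)                             # Normal with batch_shape (3, 3)
z = dist.rsample()                            # class-dependent sample used as params2 in the ELBO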
@@ -15,3 +15,8 @@ class ScaledSoftsign(nn.Module):
return (self.params['a'] * x)/(1 + torch.abs( self.params['b'] * x) )
class Swish(nn.Module):
def forward(self, x: torch.Tensor):
# Swish activation: x * sigmoid(x), with the sigmoid written out via torch.reciprocal
return x * torch.reciprocal(1 + torch.exp(-x))
@@ -38,7 +38,7 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
# compute loss
self.logger('data forwarded')
#pdb.set_trace()
-batch_loss, losses = self.losses.loss(model=self.models, out=out, target=x, epoch=epoch, plot=plot and not batch, period=period)
+batch_loss, losses = self.losses.loss(model=self.models, out=out, target=x, y=y, epoch=epoch, plot=plot and not batch, period=period)
train_losses['main_losses'].append(losses)
except NaNError:
pdb.set_trace()
......