adversarial fixes

parent e3954ae3
@@ -6,7 +6,7 @@ from .. import distributions as dist
from ..utils import checktuple, checklist
from ..modules.modules_bottleneck import MLP
from ..modules.modules_distribution import get_module_from_density
from numpy import cumprod
from numpy import cumprod, array
class Adversarial(Criterion):
@@ -49,7 +49,6 @@ class Adversarial(Criterion):
assert params2 is not None
loss_gen = 0
#pdb.set_trace()
if issubclass(type(params1), dist.Distribution):
z_fake = params1.rsample().float()
else:
@@ -72,7 +71,7 @@
assert d_gen.min() >= 0 and d_gen.max() <= 1
except AssertionError:
pdb.set_trace()
loss_gen = self.reduce(torch.nn.functional.binary_cross_entropy(d_gen, torch.ones(d_gen.shape, device=device), reduction="none"))
loss_gen = self.reduce(torch.nn.functional.binary_cross_entropy(d_gen, torch.ones(d_gen.shape, device=device, requires_grad=False), reduction="none"))
# get discriminative loss
d_real = torch.sigmoid(self.discriminator(self.hidden_module(z_real)))
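Note: the loss_gen lines above are the before/after of the generator objective: the sigmoid discriminator score of the fake code is pushed towards a target of ones, now built with requires_grad=False. As a standalone illustration (discriminator and z_fake are placeholders, not the repo's API), the non-saturating generator loss boils down to:

import torch
import torch.nn.functional as F

def generator_loss(discriminator, z_fake):
    # Non-saturating generator objective: push D(z_fake) towards 1.
    d_gen = torch.sigmoid(discriminator(z_fake))   # probabilities in [0, 1]
    target = torch.ones_like(d_gen)                # constant target; no gradient flows into it
    return F.binary_cross_entropy(d_gen, target, reduction="mean")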
@@ -89,43 +88,50 @@ class Adversarial(Criterion):
except AssertionError:
pdb.set_trace()
loss_real = torch.nn.functional.binary_cross_entropy(d_real, torch.ones(d_real.shape, device=device), reduction="none")
loss_fake = torch.nn.functional.binary_cross_entropy(d_fake, torch.zeros(d_fake.shape, device=device), reduction="none")
loss_real = torch.nn.functional.binary_cross_entropy(d_real, torch.ones(d_real.shape, device=device, requires_grad=False), reduction="none")
loss_fake = torch.nn.functional.binary_cross_entropy(d_fake, torch.zeros(d_fake.shape, device=device, requires_grad=False), reduction="none")
self.adv_loss = self.reduce((loss_real+loss_fake)/2)
if self.gradient_penalty:
with torch.no_grad():
interp_factor = torch.FloatTensor(z_real.shape[0], *tuple([1]*len(z_real.shape[1:]))).repeat(1, *z_real.shape[1:])
interp_factor.uniform_(0, 1)
interp_factor = interp_factor.to(z_real.device)
self.in_interp = interp_factor * z_real + ((1 - interp_factor)*z_fake)
self.out_interp = torch.sigmoid(self.discriminator(self.hidden_module(self.in_interp)))
losses = (loss_gen.cpu().detach().numpy(), self.adv_loss.cpu().detach().numpy())
if z_fake.requires_grad:
if self.gradient_penalty:
with torch.no_grad():
interp_factor = torch.FloatTensor(z_real.shape[0], *tuple([1]*len(z_real.shape[1:]))).repeat(1, *z_real.shape[1:])
interp_factor.uniform_(0, 1)
interp_factor = interp_factor.to(z_real.device)
self.in_interp = interp_factor * z_real + ((1 - interp_factor)*z_fake)
self.out_interp = torch.sigmoid(self.discriminator(self.hidden_module(self.in_interp)))
grad_outputs = torch.ones(self.out_interp.size()).to(self.out_interp.device)
grad = torch.autograd.grad(outputs = self.out_interp, inputs = self.in_interp,
grad_outputs = grad_outputs, create_graph=True, retain_graph=True)[0]
norm_dims = tuple(range(len(self.in_interp.shape)))[1:]
grad_penalty = self.gradient_penalty*((grad.norm(2, dim=norm_dims) - 1 ) ** 2) * self.grad_penalty_weight
grad_penalty = grad_penalty.mean()
self.adv_loss = self.adv_loss + grad_penalty
losses = (*losses, grad_penalty.cpu().detach().numpy())
else:
losses = (*losses, array([0.]))
self.adv_loss.backward(retain_graph=True)
return loss_gen, (loss_gen.cpu().detach().numpy(), self.adv_loss.cpu().detach().numpy())
return loss_gen, losses
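Note: the new block computes the gradient penalty only when z_fake carries gradients and appends either the penalty or a zero placeholder to the returned losses tuple. For reference, a self-contained sketch of the standard WGAN-GP penalty this implements, assuming an arbitrary discriminator module (names are illustrative):

import torch

def gradient_penalty(discriminator, z_real, z_fake, weight=10.0):
    """WGAN-GP style penalty: (||grad D(z_hat)||_2 - 1)^2 at points interpolated between real and fake."""
    # one interpolation factor per example, broadcast over the remaining dims
    alpha = torch.rand(z_real.shape[0], *([1] * (z_real.dim() - 1)), device=z_real.device)
    z_interp = (alpha * z_real + (1 - alpha) * z_fake).detach().requires_grad_(True)
    d_interp = discriminator(z_interp)
    grad = torch.autograd.grad(outputs=d_interp, inputs=z_interp,
                               grad_outputs=torch.ones_like(d_interp),
                               create_graph=True, retain_graph=True)[0]
    grad_norm = grad.flatten(1).norm(2, dim=1)     # per-example gradient norm
    return weight * ((grad_norm - 1) ** 2).mean()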
def get_named_losses(self, losses):
if issubclass(type(losses[0]), (tuple, list)):
outs = {}
for i,l in enumerate(losses):
outs = {**outs, 'gen_loss_%d'%i:l[0], 'adv_loss_%d'%i:l[1]}
if self.gradient_penalty:
outs = {**outs, 'gen_loss_%d'%i:l[0], 'adv_loss_%d'%i:l[1], 'grad_penalty_%d'%i:l[2]}
else:
outs = {**outs, 'gen_loss_%d'%i:l[0], 'adv_loss_%d'%i:l[1]}
return outs
else:
return {'gen_loss':losses[0], 'adv_loss':losses[1]}
if self.gradient_penalty:
return {'gen_loss':losses[0], 'adv_loss':losses[1], 'grad_penalty':losses[2]}
else:
return {'gen_loss':losses[0], 'adv_loss':losses[1]}
def step(self, *args, retain=False, **kwargs):
# in case, compute gradient penalty
grad_penalty = 0
self.optimizer.zero_grad()
if self.gradient_penalty:
grad = torch.autograd.grad(outputs = self.out_interp, inputs = self.in_interp,
grad_outputs = torch.ones(self.out_interp.size()),
create_graph=True, retain_graph=True)[0]
norm_dims = tuple(range(self.in_interp.ndim))[1:]
grad_penalty = self.gradient_penalty*((grad.norm(2, dim=norm_dims) - 1 ) ** 2) * self.grad_penalty_weight
grad_penalty = grad_penalty.mean()
self.adv_loss = self.adv_loss + grad_penalty
self.adv_loss.backward(retain_graph=True)
self.optimizer.step()
self.optimizer.zero_grad()
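Note: step() pairs the adversarial loss with the usual zero_grad / backward / optimizer step sequence. A toy, self-contained version of that discriminator update, with illustrative modules and shapes rather than the repo's classes:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy discriminator update: build the adversarial loss, then run the optimizer mechanics.
disc = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
opt = torch.optim.Adam(disc.parameters(), lr=1e-4)

z_real, z_fake = torch.randn(32, 8), torch.randn(32, 8)
d_real = torch.sigmoid(disc(z_real))
d_fake = torch.sigmoid(disc(z_fake))
adv_loss = 0.5 * (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
                  + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
# ... a gradient penalty term would be added to adv_loss here (see the sketch above)

opt.zero_grad()
adv_loss.backward()
opt.step()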
# Wasserstein Adversarial Loss
@@ -171,12 +177,12 @@ class AdversarialInfo(Adversarial):
assert not model is None and not out is None and not target is None, "ELBO loss needs a model, an output and an input"
#pdb.set_trace()
x_real = target
z_real = out['z_enc'][-1]
z_real = out['z_enc'][-1]
x_fake = out['x_params'].rsample().float()
device = z_real.device
x_real = target.float().to(device)
device = x_real.device
# get generated loss
out_hidden = self.hidden_module(x_fake)
d_gen = torch.sigmoid(self.discriminator(out_hidden))
@@ -244,13 +250,14 @@ class ALI(Adversarial):
loss_gen = 0
#pdb.set_trace()
z_real = out['z_params_enc'][-1].rsample().float()
x_real = target
device = z_real.device
x_real = target.float().to(device)
z_fake = self.latent_params.get('prior', dist.priors.get_default_distribution(self.latent_params['dist'],z_real.shape)).rsample().float()
z_fake = self.latent_params.get('prior', dist.priors.get_default_distribution(self.latent_params['dist'],z_real.shape)).rsample().float().to(device)
x_fake = out['x_params'].rsample().float()
device = z_fake.device
# get generated loss
d_gen = torch.sigmoid(self.discriminator(self.hidden_module([x_fake, z_fake])))
@@ -258,6 +265,7 @@ class ALI(Adversarial):
assert d_gen.min() >= 0 and d_gen.max() <= 1
except AssertionError:
pdb.set_trace()
loss_gen = self.reduce(torch.nn.functional.binary_cross_entropy(d_gen, torch.ones(d_gen.shape, device=device), reduction="none"))
# get discriminative loss
......
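Note: in the ALI variant the discriminator scores joint (x, z) pairs, real data with its encoding against a decoded sample with a prior draw, and the fix above mainly moves the prior sample onto the encoder's device. A rough sketch of such a joint discriminator, assuming flat batched inputs (illustrative, not the repo's hidden_module/discriminator pair):

import torch
import torch.nn as nn

class JointDiscriminator(nn.Module):
    """Scores (x, z) pairs as in ALI/BiGAN: real = (data, encoding), fake = (decoded sample, prior draw)."""
    def __init__(self, x_dim, z_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim + z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1))

    def forward(self, x, z):
        # concatenate the flattened data and latent code, return a probability per pair
        return torch.sigmoid(self.net(torch.cat([x.flatten(1), z.flatten(1)], dim=-1)))

# usage (hypothetical shapes): d = JointDiscriminator(784, 16); score = d(x_batch, z_batch)  # (batch, 1) in [0, 1]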
@@ -206,7 +206,7 @@ class PerceptiveStudent(Criterion):
if y.dim() == 2:
y = fromOneHot(y);
# Create the target distance matrix
_, targetDistMat = self.centroids[y, :][:, y]
targetDistMat = self.centroids[y, :][:, y]
targetDistMat.requires_grad_(False)
targetDistMat = torch.pow((1 + targetDistMat), -1)
targetDistMat = (targetDistMat / torch.sum(targetDistMat))
......
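Note: the PerceptiveStudent target above indexes a class-to-class distance table with the batch labels, maps distances to Student-t style similarities 1/(1+d), and normalizes them to sum to one. A compact sketch of that construction (hypothetical helper, not the repo's method):

import torch

def target_similarities(centroid_dists, labels):
    """Batch-level target similarity matrix from a class-to-class distance table.

    centroid_dists: (n_classes, n_classes) pairwise distances between class centroids.
    labels:         (batch,) integer class labels.
    """
    with torch.no_grad():
        d = centroid_dists[labels, :][:, labels]   # (batch, batch) distances between the examples' classes
        sim = 1.0 / (1.0 + d)                      # Student-t style kernel
        return sim / sim.sum()                     # normalize to a probability-like matrix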
@@ -125,6 +125,17 @@ class Empirical(Distribution):
def __init__(self, tensor):
self.tensor = tensor
def __repr__(self):
return f"Empirical({self.tensor.shape})"
@property
def mean(self):
return self.tensor
@property
def stddev(self):
return torch.zeros_like(self.tensor)
def log_prob(self, value):
if value != self.tensor:
return 0.
......
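Note: the new Empirical distribution wraps a fixed tensor and exposes it as mean with zero stddev; its log_prob is terse (it returns 0. when value differs from the stored tensor), so the sketch below shows one common Dirac-style reading, zero log-density at the stored value and -inf elsewhere, purely as an illustration rather than the repo's definition:

import torch

class DiracEmpirical:
    """Illustrative Dirac-like distribution around a fixed tensor (not the repo's class)."""
    def __init__(self, tensor):
        self.tensor = tensor

    @property
    def mean(self):
        return self.tensor

    @property
    def stddev(self):
        return torch.zeros_like(self.tensor)

    def rsample(self, *args, **kwargs):
        # the "sample" is always the stored tensor itself
        return self.tensor

    def log_prob(self, value):
        # log-density 0 where the value matches the stored tensor, -inf elsewhere
        same = torch.isclose(value, self.tensor)
        return torch.where(same, torch.zeros_like(value), torch.full_like(value, float("-inf")))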
@@ -311,7 +311,13 @@ class GaussianLayer2D(nn.Module):
#### Bernoulli layers
def BernoulliLayer(pinput, poutput, **kwargs):
if issubclass(type(poutput['dim']), tuple) or poutput.get('conv'):
take_conv = 0
# if issubclass(type(poutput['dim']), (tuple, list)):
# take_conv = len(poutput['dim']) != 1
if poutput.get('conv'):
take_conv = poutput['conv']
#take_conv = take_conv or issubclass(type(poutput['dim']), tuple)
if take_conv:
return BernoulliLayer2D(pinput, poutput, **kwargs)
else:
return BernoulliLayer1D(pinput, poutput, **kwargs)
@@ -322,9 +328,10 @@ class BernoulliLayer1D(nn.Module):
def __init__(self, pinput, poutput, **kwargs):
super(BernoulliLayer1D , self).__init__()
self.pinput = pinput; self.poutput = poutput
input_dim = checklist(self.pinput['dim'])[-1]
output_dim = cumprod(checklist(poutput['dim']))[-1]
self.modules_list = nn.Sequential(nn.Linear(input_dim, output_dim), nn.Sigmoid())
self.input_dim = checklist(self.pinput['dim'])[-1]
self.output_dim = checktuple(poutput['dim'])
cum_output_dim = cumprod(self.output_dim)[-1]
self.modules_list = nn.Sequential(nn.Linear(self.input_dim, cum_output_dim), nn.Sigmoid())
init_module(self.modules_list, 'Sigmoid')
def forward(self, ins, *args, **kwargs):
......
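Note: the rewritten BernoulliLayer1D keeps the (possibly multi-dimensional) output shape and flattens it with cumprod to size the Linear layer. A minimal sketch of that pattern, a flat Linear plus Sigmoid reshaped back to the target dims and handed to a Bernoulli distribution (illustrative class, not the repo's layer):

import torch
import torch.nn as nn
from numpy import cumprod

class BernoulliOut1D(nn.Module):
    """One Linear over the flattened target shape, Sigmoid for probabilities, reshape back."""
    def __init__(self, input_dim, output_shape):
        super().__init__()
        self.output_shape = tuple(output_shape)
        flat_dim = int(cumprod(self.output_shape)[-1])      # product of the output dims
        self.net = nn.Sequential(nn.Linear(input_dim, flat_dim), nn.Sigmoid())

    def forward(self, h):
        probs = self.net(h).view(h.shape[0], *self.output_shape)
        return torch.distributions.Bernoulli(probs=probs)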
@@ -264,6 +264,8 @@ def plot_mean(x, target=None, preprocessing=None, axes=None, *args, is_sequence=
fig, axes = plot_mean_2d(x, x=target, preprocessing=preprocessing, axes=axes, *args, **kwargs)
return fig, axes
def plot_empirical(x, *args, **kwargs):
return plot_mean(dist.Normal(x.mean, x.stddev), *args, **kwargs)
def plot_probs(x, target=None, preprocessing=None, *args, **kwargs):
n_examples = x.batch_shape[0]
@@ -295,7 +297,8 @@ def plot_probs(x, target=None, preprocessing=None, *args, **kwargs):
plotting_hash = {torch.Tensor: plot_dirac,
dist.Normal: plot_mean,
dist.Bernoulli: plot_mean,
dist.Categorical: plot_probs}
dist.Categorical: plot_probs,
dist.Empirical: plot_empirical}
def plot_distribution(dists, *args, **kwargs):
if issubclass(type(dists), list):
......
@@ -30,6 +30,7 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
x = sample_normalize(x)
self.logger('data preprocessed')
try:
print(torch.cuda.memory_allocated(), torch.cuda.memory_cached())
out = self.models.forward(x, y=y, epoch=epoch)
if self.reinforcers:
out = self.reinforcers.forward(out, target=x, optimize=False)
@@ -45,7 +46,7 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
# learn
self.logger('loss computed')
if optimize:
batch_loss.backward(retain_graph=True)
batch_loss.backward(retain_graph=False)
self.optimize(self.models, batch_loss, epoch=epoch, batch=batch)
if self.reinforcers:
......
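Note: the run() hunk switches batch_loss.backward(retain_graph=True) to retain_graph=False: the graph only needs to be retained while another backward pass will still traverse it. A tiny standalone illustration of that rule:

import torch

x = torch.randn(8, 4, requires_grad=True)
h = (x * 2).sum(dim=1)
loss_a, loss_b = h.mean(), h.pow(2).mean()
loss_a.backward(retain_graph=True)    # another backward will reuse this graph
loss_b.backward(retain_graph=False)   # last use: let the graph be freed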
@@ -95,11 +95,9 @@ class SimpleTrainer(Trainer):
update_model = True if self.optim_balance is None else batch % self.optim_balance[0] == 0
update_loss = True if self.optim_balance is None else batch % self.optim_balance[1] == 0
if update_model:
print('model!')
apply_method(self.models, 'step', loss)
if update_loss:
apply_method(self.losses, 'step', loss)
print('loss!')
# print_grad_stats(self.models)
def train(self, partition=None, write=False, batch_size=64, tasks=None, batch_cache_size=1, **kwargs):
......
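Note: update_model and update_loss gate the two optimizers on the batch index modulo optim_balance, so model and adversarial parameters can be stepped at different rates. A small sketch of that schedule (hypothetical helper, not the trainer's API):

def should_update(batch_idx, optim_balance=None):
    """Returns (update_model, update_loss) flags for the current batch.

    optim_balance = (k_model, k_loss): model parameters step every k_model batches,
    the adversarial/loss parameters every k_loss batches; None means update both each batch.
    """
    if optim_balance is None:
        return True, True
    k_model, k_loss = optim_balance
    return batch_idx % k_model == 0, batch_idx % k_loss == 0

# e.g. should_update(3, (1, 2)) -> (True, False): the model steps, the adversarial loss waits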
from . import VanillaVAE
import torch.nn as nn
from .. import distributions as dist
from ..distributions.distribution_priors import get_default_distribution
from ..utils import denest_dict, checktuple
# this is a variational library, so we will consider GAN as auto-encoders deprived of encoders.
@@ -21,10 +23,17 @@ class VanillaGAN(VanillaVAE):
batch_size = batch_size or x.shape[0] or 64
prior = self.platent[-1].get('prior') or get_default_distribution(self.platent[-1]['dist'], (batch_size, *checktuple(self.platent[-1]['dim'])))
#TODO this is going to be a problem
z = prior.rsample()
z = prior.rsample().to(next(self.parameters()).device)
z.requires_grad = True
dec_out = self.decode(z)
x_params = dec_out[0]['out_params']
#TODO make with EmpiricalLayer
if issubclass(type(x_params), dist.Bernoulli):
x_params = dist.Empirical(x_params.probs)
elif issubclass(type(x_params), dist.Normal):
x_params = dist.Empirical(x_params.mean)
dec_out = denest_dict(dec_out[1:]) if len(dec_out) > 1 else {}
return {'x_params':x_params, 'z_enc':z, 'z_params_enc':prior,
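Note: the reworked VanillaGAN.forward samples z from the latent prior, moves it onto the model's device, decodes, and swaps the decoder's Bernoulli/Normal output for an Empirical wrapper around its probs/mean. A rough standalone sketch of that flow with torch.distributions; the decoder here is an assumed nn.Module returning a distribution, not the repo's decode():

import torch

def gan_generator_forward(decoder, z_dim, batch_size=64):
    """Sample from a standard-normal prior, decode, and take a point estimate of the output."""
    device = next(decoder.parameters()).device
    prior = torch.distributions.Normal(torch.zeros(batch_size, z_dim, device=device),
                                       torch.ones(batch_size, z_dim, device=device))
    z = prior.rsample()
    z.requires_grad_(True)                       # lets adversarial gradients reach the generator input
    x_params = decoder(z)                        # assumed to return a torch.distributions object
    if isinstance(x_params, torch.distributions.Bernoulli):
        x_point = x_params.probs                 # plays the role of Empirical(probs) in the diff
    elif isinstance(x_params, torch.distributions.Normal):
        x_point = x_params.mean                  # plays the role of Empirical(mean)
    else:
        x_point = x_params
    return {'x': x_point, 'z': z, 'prior': prior}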
@@ -39,4 +48,4 @@ class InfoGAN(VanillaGAN):
return VanillaVAE.make_encoders(self, *args, **kwargs)
def encode(self, *args, **kwargs):
return VanillaVAE.encode(self, *args, **kwargs)
\ No newline at end of file
return VanillaVAE.encode(self, *args, **kwargs)