small updates

parent 8f2d36a7
@@ -48,6 +48,7 @@ class Adversarial(Criterion):
assert params1 is not None
assert params2 is not None
loss_gen = 0
if issubclass(type(params1), dist.Distribution):
z_fake = params1.rsample().float()
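Note on the hunk above: `params1` is only sampled when it is a distribution object, and via `rsample()` so the sample stays differentiable. A minimal sketch of that pattern, assuming a plain `torch.distributions` posterior stands in for the repository's `dist` wrapper (the `sample_latent` helper is hypothetical):

```python
import torch
import torch.distributions as dist

def sample_latent(params1):
    # Draw a reparameterized (differentiable) sample when the encoder
    # output is a Distribution; otherwise assume it is already a tensor.
    if isinstance(params1, dist.Distribution):
        return params1.rsample().float()
    return params1.float()

# Illustrative usage with a diagonal Gaussian posterior.
posterior = dist.Normal(torch.zeros(8, 16), torch.ones(8, 16))
z_fake = sample_latent(posterior)   # shape (8, 16), gradients flow through rsample
```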
@@ -105,13 +106,13 @@ class Adversarial(Criterion):
grad = torch.autograd.grad(outputs = self.out_interp, inputs = self.in_interp,
grad_outputs = grad_outputs, create_graph=True, retain_graph=True)[0]
norm_dims = tuple(range(len(self.in_interp.shape)))[1:]
grad_penalty = self.gradient_penalty*((grad.norm(2, dim=norm_dims) - 1 ) ** 2) * self.grad_penalty_weight
grad_penalty = self.gradient_penalty*((grad.norm(2, dim=norm_dims) - 1 ) ** 2)
grad_penalty = grad_penalty.mean()
self.adv_loss = self.adv_loss + grad_penalty
self.adv_loss = self.adv_loss + self.grad_penalty_weight * grad_penalty
losses = (*losses, grad_penalty.cpu().detach().numpy())
else:
losses = (*losses, array([0.]))
self.adv_loss.backward(retain_graph=True)
else:
losses = (*losses, array([0.]))
return loss_gen, losses
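This hunk moves `grad_penalty_weight` off the per-sample penalty and applies it once when the reduced penalty is added to `adv_loss`. A self-contained sketch of the same WGAN-GP computation, assuming a generic `critic` module and a hypothetical `gradient_penalty` helper (only the autograd pattern mirrors the diff):

```python
import torch

def gradient_penalty(critic, in_interp, weight=10.0):
    # in_interp must have requires_grad=True.
    out_interp = critic(in_interp)
    grad_outputs = torch.ones_like(out_interp)
    grad = torch.autograd.grad(outputs=out_interp, inputs=in_interp,
                               grad_outputs=grad_outputs,
                               create_graph=True, retain_graph=True)[0]
    # L2 norm over all non-batch dimensions, penalize deviation from 1.
    norm_dims = tuple(range(1, grad.dim()))
    penalty = ((grad.norm(2, dim=norm_dims) - 1) ** 2).mean()
    # Weight applied once, on the reduced penalty, as in the new version.
    return weight * penalty

# Illustrative usage.
in_interp = torch.randn(8, 16, requires_grad=True)
critic = torch.nn.Linear(16, 1)
gp = gradient_penalty(critic, in_interp)   # scalar tensor
```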
@@ -134,7 +135,7 @@ class Adversarial(Criterion):
# in case, compute gradient penalty
self.optimizer.step()
self.optimizer.zero_grad()
# Wasserstein Adversarial Loss
if self.clip:
for param in self.parameters():
param.data.clamp_(self.clip[0], self.clip[1])
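For the Wasserstein branch, weights are clamped right after the optimizer step, following the original WGAN recipe. A minimal illustration, assuming a stand-in `discriminator` module and placeholder `clip` bounds:

```python
import torch.nn as nn

discriminator = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
clip = (-0.01, 0.01)

# After optimizer.step() / zero_grad(), force every weight into [clip[0], clip[1]]
# so the critic stays roughly Lipschitz.
for param in discriminator.parameters():
    param.data.clamp_(clip[0], clip[1])
```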
@@ -256,6 +257,7 @@ class ALI(Adversarial):
x_real = target.float().to(device)
z_fake = self.latent_params.get('prior', dist.priors.get_default_distribution(self.latent_params['dist'],z_real.shape)).rsample().float().to(device)
z_fake.requires_grad = True
x_fake = out['x_params'].rsample().float()
# get generated loss
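The added `z_fake.requires_grad = True` makes the prior sample something gradients can be taken with respect to, which the interpolation-based penalty further down needs. A hedged sketch with a plain Gaussian prior standing in for `dist.priors.get_default_distribution`, which is repository-specific:

```python
import torch
import torch.distributions as D

z_real = torch.randn(8, 32)                     # e.g. encoder output
prior = D.Normal(torch.zeros_like(z_real), torch.ones_like(z_real))

z_fake = prior.rsample().float()                # reparameterized prior sample
z_fake.requires_grad = True                     # allow grads w.r.t. the sample
```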
@@ -287,13 +289,16 @@ class ALI(Adversarial):
loss_fake = torch.nn.functional.binary_cross_entropy(d_fake, torch.zeros(d_fake.shape, device=device), reduction="none")
self.adv_loss = self.reduce((loss_real+loss_fake)/2)
if self.gradient_penalty:
with torch.no_grad():
interp_factor = torch.FloatTensor(z_real.shape[0], *tuple([1]*len(z_real.shape[1:]))).repeat(1, *z_real.shape[1:])
interp_factor.uniform_(0, 1)
interp_factor = interp_factor.to(z_real.device)
self.in_interp = interp_factor * z_real + ((1 - interp_factor)*z_fake)
self.out_interp = torch.sigmoid(self.discriminator(self.hidden_module(self.in_interp)))
if z_real.requires_grad:
if self.gradient_penalty:
raise NotImplementedError
with torch.no_grad():
interp_factor = torch.FloatTensor(z_real.shape[0], *tuple([1]*len(z_real.shape[1:]))).repeat(1, *z_real.shape[1:])
interp_factor.uniform_(0, 1)
interp_factor = interp_factor.to(z_real.device)
self.in_interp = interp_factor * z_real + ((1 - interp_factor)*z_fake)
self.out_interp = torch.sigmoid(self.discriminator(self.hidden_module(self.in_interp)))
self.adv_loss.backward(retain_graph=True)
return loss_gen, (loss_gen.cpu().detach().numpy(), self.adv_loss.cpu().detach().numpy())
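In the new version the latent-space gradient penalty is gated on `z_real.requires_grad` and currently raises `NotImplementedError`; the interpolation it relies on mixes real and fake codes with one uniform factor per batch element. A standalone sketch of that step (the `interpolate_latents` helper is hypothetical; only the mixing formula follows the diff):

```python
import torch

def interpolate_latents(z_real, z_fake):
    # One uniform factor per batch element, broadcast over the latent dims.
    shape = (z_real.shape[0],) + (1,) * (z_real.dim() - 1)
    interp_factor = torch.rand(shape, device=z_real.device).expand_as(z_real)
    return interp_factor * z_real + (1 - interp_factor) * z_fake

z_real = torch.randn(8, 32, requires_grad=True)
z_fake = torch.randn(8, 32, requires_grad=True)
in_interp = interpolate_latents(z_real, z_fake)  # fed to the discriminator for the penalty
```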
@@ -138,6 +138,8 @@ class ELBO(CriterionContainer):
full_loss = full_loss - logdet_error
logdets = logdets + (float(logdet_error),)
pdb.set_trace()
return full_loss, (rec_errors, reg_errors, *logdets)
def get_named_losses(self, losses):
......
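For context on the ELBO hunk above: each flow log-determinant term is subtracted from the loss and also logged as a float. A hedged sketch of that bookkeeping, with all names besides the `full_loss - logdet_error` pattern being illustrative:

```python
import torch

def elbo_loss(rec_error, reg_error, logdet_terms=()):
    # Reconstruction + regularization, then credit each flow log|det J|
    # term back to the loss (i.e. subtract it) and keep a float trace.
    full_loss = rec_error + reg_error
    logdets = ()
    for logdet_error in logdet_terms:
        full_loss = full_loss - logdet_error
        logdets = logdets + (float(logdet_error),)
    return full_loss, logdets

loss, logdets = elbo_loss(torch.tensor(1.2), torch.tensor(0.3),
                          logdet_terms=(torch.tensor(0.05),))
```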
@@ -43,7 +43,7 @@ class MLPLayer(nn.Module):
"""
dump_patches = True
nn_lin = "ELU"
def __init__(self, input_dim, output_dim, nn_lin=None, batch_norm='batch', dropout=None, name_suffix="", bias=True, *args, **kwargs):
def __init__(self, input_dim, output_dim, nn_lin=None, nn_lin_args={}, batch_norm='batch', dropout=None, name_suffix="", bias=True, *args, **kwargs):
"""
:param input_dim: input dimension
:type input_dim: int
@@ -83,7 +83,7 @@ class MLPLayer(nn.Module):
# Non Linearity
self.nn_lin = nn_lin
if nn_lin:
modules["nnlin"+name_suffix] = getattr(nn, nn_lin)()
modules["nnlin"+name_suffix] = getattr(nn, nn_lin)(**nn_lin_args)
self.module = nn.Sequential(modules)
@flatten_seq_method
......
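The new `nn_lin_args` keyword forwards constructor arguments to the activation that `getattr(nn, nn_lin)` resolves. A small sketch of the pattern outside the class, assuming a hypothetical `build_layer` helper with the same defaults:

```python
import torch.nn as nn
from collections import OrderedDict

def build_layer(input_dim, output_dim, nn_lin="ELU", nn_lin_args={}):
    modules = OrderedDict()
    modules["hidden"] = nn.Linear(input_dim, output_dim)
    if nn_lin:
        # Resolve the activation class by name and forward its kwargs,
        # e.g. nn_lin="ELU", nn_lin_args={"alpha": 0.5}.
        modules["nnlin"] = getattr(nn, nn_lin)(**nn_lin_args)
    return nn.Sequential(modules)

layer = build_layer(64, 32, nn_lin="LeakyReLU", nn_lin_args={"negative_slope": 0.2})
```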