make adversarial great again

parent f5f16522
......@@ -35,5 +35,5 @@ from .criterion_spectral import SpectralLoss
from .criterion_elbo import *
from .criterion_scan import *
from .criterion_misc import *
from .criterion_adversarial import Adversarial
from .criterion_adversarial import Adversarial, AdversarialInfo, ALI
This diff is collapsed.
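The adversarial criteria behind these new exports (Adversarial, AdversarialInfo, ALI) sit in the collapsed file above, so their implementation is not visible here. As a rough, hypothetical reference only, not the library's actual classes, a GAN-style criterion typically owns a small discriminator and returns both the discriminator objective and the (non-saturating) generator objective:

# Hypothetical sketch of a GAN-style criterion (not this library's Adversarial class).
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdversarialSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1))

    def loss(self, x_real, x_fake):
        # discriminator: push real towards 1, fake towards 0 (fake detached so
        # only the discriminator is trained by d_loss)
        d_real = self.discriminator(x_real)
        d_fake = self.discriminator(x_fake.detach())
        d_loss = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) \
               + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
        # generator: non-saturating loss, gradients flow back through x_fake
        g_fake = self.discriminator(x_fake)
        g_loss = F.binary_cross_entropy_with_logits(g_fake, torch.ones_like(g_fake))
        return g_loss, d_loss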
......@@ -29,15 +29,21 @@ class KLD(Criterion):
    def __repr__(self):
        return "KLD"
    def loss(self, params1=None, params2=None, sample=False, compute_logdets=False, **kwargs):
        assert params1, params2
    def loss(self, params1=None, params2=None, out1=None, out2=None, sample=False, compute_logdets=False, **kwargs):
        sample = (params1 is None) or (params2 is None) or sample
        if not sample:
            assert (params1 is not None) and (params2 is not None)
            try:
                loss = clamp(kl.kl_divergence(params1, params2), min=eps)
            except NotImplementedError:
                loss = self.kl_sampled(params1, params2, **kwargs)
        else:
            loss = self.kl_sampled(params1, params2, **kwargs)
            if out1 is None:
                out1 = params1.rsample()
            if out2 is None:
                out2 = params2.rsample()
            assert (out1 is not None) and (out2 is not None)
            loss = self.kl_sampled(out1, out2, **kwargs)
        loss = reduce(loss, self.reduction);
        losses = (float(loss),)
......
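The reworked KLD.loss falls back to a sampled estimate (kl_sampled) whenever one of the parameter distributions is missing or kl.kl_divergence has no closed form for the pair. A minimal Monte-Carlo KL sketch, assuming torch.distributions objects; kl_sampled itself may have a different signature (the diff shows it also accepting raw samples out1/out2):

# Minimal sketch of a Monte-Carlo KL estimate; illustrative only.
import torch
from torch import distributions as dist

def kl_monte_carlo(q, p, n_samples=1):
    # KL(q || p) ~= E_q[ log q(z) - log p(z) ]; rsample() keeps the estimate
    # differentiable w.r.t. q's parameters.
    z = q.rsample((n_samples,))
    return (q.log_prob(z) - p.log_prob(z)).mean(0)

q = dist.Normal(torch.zeros(8), torch.ones(8))
p = dist.Normal(torch.ones(8) * 0.5, torch.ones(8))
print(kl_monte_carlo(q, p, n_samples=64).shape)   # per-dimension estimate
print(dist.kl_divergence(q, p).shape)             # closed form, for comparison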
......@@ -69,12 +69,12 @@ class ELBO(CriterionContainer):
    def parse_layer(latent_params, out, layer_index=0):
        if issubclass(type(latent_params), list):
            return [parse_layer(latent_params[i], utils.get_latent_out(out,i)) for i in range(len(latent_params))]
        #TODO if not z_params_enc, make montecarlo estimation
        params1 = out["z_params_enc"];
        if params1.requires_preflow:
            out1 = out['z_preflow_enc']
        else:
            out1 = out['z_enc']
        # encoder parameter
        params1 = out.get("z_params_enc");
        out1 = out['z_enc']
        if params1 is not None:
            if params1.requires_preflow:
                out1 = out['z_preflow_enc']
        # decoder parameters
        prior = latent_params.get('prior') or None
        if prior is not None:
......
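This hunk replaces the hard "z_params_enc" lookup with a .get() and only reads the pre-flow sample when the encoder distribution requires it. The same lookup-with-fallback pattern in isolation (key names from the diff, everything around them assumed):

# Sketch of the lookup-with-fallback pattern above; illustrative only.
def get_encoder_terms(out):
    params1 = out.get("z_params_enc")   # may be None -> sampled-KL path downstream
    out1 = out.get("z_enc")
    if params1 is not None and getattr(params1, "requires_preflow", False):
        out1 = out.get("z_preflow_enc", out1)
    return params1, out1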
......@@ -24,6 +24,8 @@ from ..utils.misc import checklist, checktuple, print_module_stats, flatten_seq_
class MLPReinforcementModule(torch.nn.Module):
    def __init__(self, in_params, module_params, latent_params=None, recurrent_z=False):
        super(MLPReinforcementModule, self).__init__()
        self.pinput = in_params
        self.platent = latent_params
        n_layers = module_params.get('nlayers', 1)
        layer = module_params.get('layer_class', bt.MLPGatedLayer)
        hidden_dims = checklist(module_params.get('dim', []), n_layers)
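The constructor now keeps the input and latent specifications around and still builds its stack from module_params (nlayers, layer_class, dim broadcast over layers by checklist). A hypothetical plain-PyTorch version of that construction pattern, with the broadcasting helper inlined:

# Hypothetical sketch of building an MLP stack from a config dict (names assumed).
import torch.nn as nn

def build_mlp(in_dim, module_params):
    n_layers = module_params.get('nlayers', 1)
    dims = module_params.get('dim', 256)
    dims = list(dims) if isinstance(dims, (list, tuple)) else [dims] * n_layers
    layers, current = [], in_dim
    for h in dims[:n_layers]:
        layers += [nn.Linear(current, h), nn.ReLU()]
        current = h
    return nn.Sequential(*layers)

mlp = build_mlp(64, {'nlayers': 2, 'dim': [128, 32]})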
......@@ -45,8 +47,13 @@ class MLPReinforcementModule(torch.nn.Module):
        #self.nn_lin = module_params.get('nn_lin')
    @flatten_seq_method
    def forward(self, input, z=None):
        old_shape = input.shape
        input_dim = len(checktuple(self.pinput['dim']))
        input = input.view(*tuple(input.shape[:-input_dim]), cumprod(input.shape[-input_dim:])[-1])
        if input.ndim > 2:
            input = input.view(cumprod(input.shape[:-1])[-1], input.shape[-1])
        if z is not None:
            z_input = torch.cat(checklist(z), dim=-1)
            # if len(z_input.shape) > 2:
......@@ -60,6 +67,8 @@ class MLPReinforcementModule(torch.nn.Module):
        if self.nn_lin:
            input = getattr(torch.nn.functional, self.nn_lin)(input)
        '''
        if old_shape is not None:
            input = input.view(*old_shape)
        return input
......
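forward() now remembers the incoming shape, flattens everything down to 2-D before the linear layers, and restores the leading dimensions at the end. A self-contained sketch of that pattern (names illustrative; recent nn.Linear broadcasts over leading dimensions by itself, but the explicit reshape matters for layers that do not):

# Self-contained sketch of the flatten-then-restore pattern; illustrative only.
import torch
import torch.nn as nn

def apply_flat(module, x):
    lead = x.shape[:-1]                       # e.g. (batch, seq)
    out = module(x.reshape(-1, x.shape[-1]))  # (batch*seq, features_out)
    return out.reshape(*lead, out.shape[-1])

layer = nn.Linear(16, 8)
y = apply_flat(layer, torch.randn(4, 10, 16))   # -> torch.Size([4, 10, 8])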
......@@ -118,6 +118,15 @@ class Log1pNormalize(Normalize):
    def invert(self, x):
        return np.exp(super(Log1pNormalize, self).invert(x) - 1)
class Binary(Normalize):
    def __init__(self, dataset=None):
        super(Binary, self).__init__(dataset=dataset, norm_type="minmax", mode="unipolar")
    def __call__(self, *args, **kwargs):
        normalized_data = super(Binary, self).__call__(*args, **kwargs)
        return (normalized_data >= 0.5).astype(np.int)
class Magnitude(object):
    log_threshold = 1e-3
......
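The new Binary transform min-max normalizes to [0, 1] ("unipolar") and thresholds at 0.5. Its effect, sketched standalone in NumPy; note that np.int is deprecated in recent NumPy releases, so plain int (or np.int64) is the safer dtype:

# Standalone sketch of the Binary transform's effect; illustrative only.
import numpy as np

def binarize(x):
    x = np.asarray(x, dtype=float)
    span = x.max() - x.min()
    unit = (x - x.min()) / span if span > 0 else np.zeros_like(x)
    return (unit >= 0.5).astype(int)

print(binarize([0.0, 0.2, 0.7, 1.0]))   # [0 0 1 1]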
......@@ -33,8 +33,6 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
        out = self.models.forward(x, y=y, epoch=epoch)
        if self.reinforcers:
            out = self.reinforcers.forward(out, target=x, optimize=False)
        #self.logger(log_dist("latent", out['z_params_enc'][-1]))
        #self.logger(log_dist("data", out['x_params']))
        # compute loss
        self.logger('data forwarded')
        #pdb.set_trace()
......@@ -42,28 +40,28 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
            train_losses['main_losses'].append(losses)
        except NaNError:
            pdb.set_trace()
        #except Exception as e:
        #    pdb.set_trace()
        # trace
        if self.trace_mode == "batch":
            if period is None:
                period = "train" if optimize else "test"
            apply_method(self.losses, "write", period, losses)
            apply_method(self.monitor, "update")
            self.logger("monitor updated")
        # learn
        self.logger('loss computed')
        if optimize:
            batch_loss.backward()
            self.optimize(self.models, batch_loss)
            batch_loss.backward(retain_graph=True)
            self.optimize(self.models, batch_loss, epoch=epoch, batch=batch)
        if self.reinforcers:
            _, reinforcement_losses = self.reinforcers(out, target=x, epoch=epoch, optimize=optimize)
            train_losses['reinforcement_losses'].append(reinforcement_losses)
        self.logger('optimization done')
        # trace
        if self.trace_mode == "batch":
            if period is None:
                period = "train" if optimize else "test"
            apply_method(self.losses, "write", period, losses)
            apply_method(self.monitor, "update")
            self.logger("monitor updated")
        # update loop
        named_losses = self.losses.get_named_losses(losses)
        if self.reinforcers:
......@@ -76,8 +74,6 @@ def run(self, loader, preprocessing=None, epoch=None, optimize=True, schedule=Fa
        current_loss += float(batch_loss)
        batch += 1
        #if batch % 1 == 90:
        #    torch.cuda.empty_cache()
        del out; del x
......
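The learning step now uses batch_loss.backward(retain_graph=True) and passes epoch/batch down to optimize(). Retaining the graph is what allows a later backward pass through the same forward graph within the step (for instance a criterion stepping its own discriminator), at the cost of keeping the intermediate buffers alive. A minimal, purely illustrative sketch:

# Why retain_graph=True is needed when two losses sharing one forward graph are
# backpropagated in separate calls; purely illustrative.
import torch

x = torch.randn(8, 4, requires_grad=True)
h = (x * 2).sum(dim=1)
loss_a = h.mean()
loss_b = (h ** 2).mean()

loss_a.backward(retain_graph=True)   # keep buffers for the second pass
loss_b.backward()                    # would raise RuntimeError without retain_graph above
print(x.grad.shape)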
......@@ -62,6 +62,7 @@ class SimpleTrainer(Trainer):
        self.tasks = kwargs.get('tasks', None)
        self.preprocessing = kwargs.get('preprocessing', None)
        self.dataloader_class = kwargs.get('dataloader', self.dataloader_class)
        self.optim_balance = kwargs.get('optim_balance')
        # additional args
        self.trace_mode = kwargs.get('trace_mode', 'epoch')
        self.device = kwargs.get('device')
......@@ -89,10 +90,16 @@ class SimpleTrainer(Trainer):
    def get_time(self):
        return process_time() - self.start_time
    def optimize(self, models, loss):
    def optimize(self, models, loss, epoch=None, batch=None):
        #pdb.set_trace()
        apply_method(self.models, 'step', loss)
        apply_method(self.losses, 'step', loss)
        update_model = True if self.optim_balance is None else batch % self.optim_balance[0] == 0
        update_loss = True if self.optim_balance is None else batch % self.optim_balance[1] == 0
        if update_model:
            print('model!')
            apply_method(self.models, 'step', loss)
        if update_loss:
            apply_method(self.losses, 'step', loss)
            print('loss!')
        # print_grad_stats(self.models)
    def train(self, partition=None, write=False, batch_size=64, tasks=None, batch_cache_size=1, **kwargs):
......
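optimize() now gates the model step and the criterion step on batch % optim_balance[0] and batch % optim_balance[1], the usual way to interleave, e.g., generator and discriminator updates. The schedule in isolation (illustrative names):

# Isolated sketch of the optim_balance schedule: with optim_balance = (k_model, k_loss),
# the model is stepped every k_model batches and the criterion every k_loss batches;
# None means both are stepped every batch. Illustrative only.
def update_flags(batch, optim_balance=None):
    if optim_balance is None:
        return True, True
    return batch % optim_balance[0] == 0, batch % optim_balance[1] == 0

for b in range(4):
    print(b, update_flags(b, optim_balance=(1, 2)))
# 0 (True, True)   1 (True, False)   2 (True, True)   3 (True, False)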
......@@ -46,4 +46,5 @@ from .vae_vanillaDLGM import VanillaDLGM
from .vae_ladderVAE import LadderVAE
from .vae_mxDLGM import AudioMixtureDLGM
from .vae_recurrentVAE import RVAE, VRNN, ShrubVAE
from .vae_predictiveVAE import PredictiveVAE, PredictiveVRNN, PredictiveShrubVAE
\ No newline at end of file
from .vae_predictiveVAE import PredictiveVAE, PredictiveVRNN, PredictiveShrubVAE
from .vae_adversarial import VanillaGAN, InfoGAN
\ No newline at end of file
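The VanillaGAN / InfoGAN classes exported here are not part of this diff. As a reference point for the InfoGAN idea only (not the library's classes), a minimal sketch in which an auxiliary head recovers the latent code from the generated sample, giving the mutual-information surrogate term:

# Minimal, hypothetical InfoGAN-style sketch; all names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoGANSketch(nn.Module):
    def __init__(self, noise_dim=16, code_dim=4, data_dim=64, hidden=128):
        super().__init__()
        self.noise_dim, self.code_dim = noise_dim, code_dim
        self.generator = nn.Sequential(
            nn.Linear(noise_dim + code_dim, hidden), nn.ReLU(), nn.Linear(hidden, data_dim))
        self.features = nn.Sequential(nn.Linear(data_dim, hidden), nn.LeakyReLU(0.2))
        self.adv_head = nn.Linear(hidden, 1)        # real / fake logit
        self.q_head = nn.Linear(hidden, code_dim)   # reconstructs the code c

    def forward(self, n):
        z = torch.randn(n, self.noise_dim)
        c = torch.randn(n, self.code_dim)
        x_fake = self.generator(torch.cat([z, c], dim=-1))
        feats = self.features(x_fake)
        # mutual-information surrogate: make q_head recover c from the sample
        info_loss = F.mse_loss(self.q_head(feats), c)
        return x_fake, self.adv_head(feats), info_loss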