parent 278b08ab
from . import VanillaVAE
import torch.nn as nn
from ..distributions.distribution_priors import get_default_distribution
from ..utils import denest_dict, checktuple
# this is a variational library, so we will consider GAN as auto-encoders deprived of encoders.
# Eh, it's quite the truth, right?
class VanillaGAN(VanillaVAE):
def make_encoders(self, input_params, latent_params, hidden_params, *args, **kwargs):
return nn.ModuleList([None]*len(latent_params))
def encode(self, x, y=None, sample=True, from_layer=0, *args, **kwargs):
raise NotImplementedError
def forward(self, x, z=None, y=None, batch_size=None, options={}, *args, **kwargs):
# x in useless here. A specific latent vector can be given using the z keyword.
#TODO verify Empirical distiribution in case of input z
if z is None:
batch_size = batch_size or x.shape[0] or 64
prior = self.platent[-1].get('prior') or get_default_distribution(self.platent[-1]['dist'], (batch_size, *checktuple(self.platent[-1]['dim'])))
#TODO ça va poser un problème ça
z = prior.rsample()
dec_out = self.decode(z)
x_params = dec_out[0]['out_params']
dec_out = denest_dict(dec_out[1:]) if len(dec_out) > 1 else {}
return {'x_params':x_params, 'z_enc':z, 'z_params_enc':prior,
'z_params_dec':prior, 'z_dec':dec_out.get('out')}
# Trick to build the encoder of InfoGAN, while keeping the random sampling of the forward method used during training.
# Also enables the encode method, to use the InfoGAN as an auto-encoder
class InfoGAN(VanillaGAN):
def make_encoders(self, *args, **kwargs):
return VanillaVAE.make_encoders(self, *args, **kwargs)
def encode(self, *args, **kwargs):
return VanillaVAE.encode(self, *args, **kwargs)
\ No newline at end of file
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment