minor updates

parent 5385a14f
......@@ -128,6 +128,7 @@ class ELBO(CriterionContainer):
reg_loss, reg_losses = reg_args[0](**reg_args[1], reduction=self.reduction, is_sequence=model.take_sequences, **kwargs)
#print(reg_loss, reg_args[2])
full_loss = full_loss + reg_args[2]*reg_loss if reg_args[2] != 0. else full_loss
print(f"{out['z_params_enc'][0].mean.mean()}, {out['z_params_enc'][0].mean.std()}, {out['z_params_enc'][0].mean.min()}, {out['z_params_enc'][0].mean.max()}")
reg_errors = reg_errors + (reg_losses,)
if out.get('logdets') is not None:
if out['logdets'][i] is None:
......
......@@ -71,16 +71,16 @@ class DatasetAudio(generic.Dataset):
# Type of audio-related augmentations
self.augmentationCallbacks = [];
self.preprocessing = None
# if self.importType == "asynchronous":
# self.flattenData = self.flattenDataAsynchronous
"""
def __getitem__(self, *args, **kwargs):
data, metadata = super(DatasetAudio, self).__getitem__(*args, **kwargs)
if self.preprocessing:
data = self.preprocessing(data)
return data, metadata
"""
"""
###################################
# Import functions
......@@ -536,7 +536,7 @@ class DatasetAudio(generic.Dataset):
# Retrieve a sub-dataset by index and copy audio-specific transform
# attributes onto it before returning it.
# NOTE(review): this is a diff fragment with +/- markers stripped and
# indentation lost in extraction. The two `dataset.preprocessing` lines
# below are the removed/added pair of the same statement — the commit
# appears to move preprocessing handling into the base Dataset class
# (see the `self.preprocessing` hunks elsewhere in this commit); confirm
# against the repository history before relying on either line.
def retrieve(self, idx):
dataset = super(DatasetAudio, self).retrieve(idx)
dataset.transformOptions = self.transformOptions
dataset.preprocessing = self.preprocessing
#dataset.preprocessing = self.preprocessing
dataset.transformName = self.transformName
return dataset
......
......@@ -210,6 +210,7 @@ class Dataset(torch.utils.data.Dataset):
self.metadata = {}
self.labels = []
self.data = []
self.preprocessing = None
self.metadataFiles = [None] * len(self.tasks)
for t in range(len(self.tasks)):
self.taskCallback[t] = (options.get("taskCallback") and options["taskCallback"][t]) or self.retrieve_callback_from_path(self.metadataDirectory, self.tasks[t]) or metadataCallbacks["default"] or []
......@@ -260,6 +261,9 @@ class Dataset(torch.utils.data.Dataset):
data = np.concatenate(data, axis=0)
else:
data = self._get_padded_data(data, self.padded_dims, self.padded_lengths)
if self.preprocessing is not None:
data = self.preprocessing(data)
metadata = {}
# Get corresponding metadata
if self.drop_tasks:
......@@ -397,6 +401,7 @@ class Dataset(torch.utils.data.Dataset):
newDataset.padded_dims = self.padded_dims
newDataset.drop_tasks = self.drop_tasks
newDataset.has_sequences = self.has_sequences
newDataset.preprocessing = self.preprocessing
#newDataset.transformOptions = self.transformOptionclass_ids)s
if len(self.partitions) != 0:
......
......@@ -281,8 +281,12 @@ class DeconvLayer(ConvLayer):
pool_modules = {1:nn.MaxUnpool1d, 2:nn.MaxUnpool2d, 3:nn.MaxUnpool3d}
conv_modules = {1: torch.nn.ConvTranspose1d, 2:torch.nn.ConvTranspose2d, 3:torch.nn.ConvTranspose3d}
@flatten_seq_method
def forward(self, x, *args, indices=None, output_size=None, **kwargs):
batch_shape = checktuple(x.shape[0])
if len(x.shape) > self.conv_dim + 2:
batch_shape = x.shape[:2]
x = x.view(np.cumprod(x.shape[:-(len(conv_dim)+1)])[-1], *x.shape[-(len(conv_dim)+1):])
if not output_size is None:
output_size = [int(o) for o in output_size]
if self.pooling:
......@@ -302,6 +306,10 @@ class DeconvLayer(ConvLayer):
if self.nn_lin:
current_out = self._modules['nnlin_module'](current_out)
if len(batch_shape) > 1:
current_out = current_out.view(*batch_shape, *current_out.shape[-(len(conv_dim)+1):])
return current_out
def init_pooling(self, dim, kernel_size):
......@@ -751,7 +759,7 @@ class ConvolutionalLatent(nn.Module):
# make flattening modules
transfer_modules = [None]*len(self.phidden); transfered_sizes = [None]*len(self.phidden)
output_size = [p.get_output_conv_length()[1][-1] for p in self.conv_modules]
output_size = [p.get_output_conv_length(input_dim=self.pins[i]['dim'])[1][-1] for i, p in enumerate(self.conv_modules)]
for i in range(len(self.phidden)):
transfer_modules[i], transfered_sizes[i] = self.get_flattening_module(output_size[i], self.phidden[i])
self.transfer_modules = nn.ModuleList(transfer_modules)
......@@ -765,7 +773,7 @@ class ConvolutionalLatent(nn.Module):
def get_flattening_module(input_dim, phidden, *args, **kwargs):
transfer_mode = phidden.get('transfer', 'unflatten')
if transfer_mode == "unflatten":
return Flatten(), input_dim*phidden['channels'][-1]
return Flatten(), int(np.cumprod(input_dim)[-1])*phidden['channels'][-1]
elif transfer_mode == "conv1x1":
n_channels = phidden['channels'][-1]
return Sequential(ConvLayer.conv_modules[phidden['conv_dim']](n_channels, 1, kernel_size=1), Flatten()), input_dim
......
......@@ -282,23 +282,24 @@ class RVAEDecoder(HiddenModule):
has_hidden = phidden is not None and phidden != {}
self.precurrent = recurrent_params
self.pouts = pouts
super(RVAEDecoder, self).__init__(pins, phidden, pouts=pouts, precurrent=recurrent_params, *args, **kwargs)
super(RVAEDecoder, self).__init__(pins, phidden, precurrent=recurrent_params, *args, **kwargs)
if pouts:
if has_hidden:
self.out_modules = self.make_output_layers(phidden, pouts, is_seq=True)
else:
self.out_modules = self.make_output_layers(recurrent_params, pouts, is_seq=True)
self.out_modules = self.make_output_layers(recurrent_params, pouts, is_seq=True)
# Build the decoder's hidden stack: a recurrent layer combined with the
# parent class's hidden layers inside a Sequential container.
# NOTE(review): diff fragment with +/- markers stripped — both the removed
# body (first three statements ending in the first `return`) and the added
# body (the lines after it, two of which are kept commented out) appear
# below; only one belongs in the real file. The added version presumably
# builds the recurrent layer from `phidden` instead of `precurrent` and
# swaps the order to Sequential(hidden_module, recurrent_module) —
# confirm against the repository diff.
# NOTE(review): a mutable dict is used as the `phidden` default argument;
# it is shared across calls if ever mutated — consider `phidden=None`
# with an in-body default instead.
def make_hidden_layers(self, pins, phidden={"dim": 800, "nlayers": 2, 'label': None, 'conditioning': 'concat'},
*args, **kwargs):
precurrent = kwargs.get('precurrent') or self.precurrent
hidden_module = super().make_hidden_layers(precurrent, phidden, *args, **kwargs)
recurrent_module = self.make_recurrent_layer(pins, precurrent)
return Sequential(recurrent_module, hidden_module)
kwargs['precurrent'] = kwargs.get('precurrent') or self.precurrent
#hidden_module = super().make_hidden_layers(precurrent, phidden, *args, **kwargs)
#recurrent_module = self.make_recurrent_layer(pins, precurrent)
hidden_module = super().make_hidden_layers(pins, phidden, *args, **kwargs)
recurrent_module = self.make_recurrent_layer(phidden, *args, **kwargs)
return Sequential(hidden_module, recurrent_module)
@property
def hidden_out_params(self, hidden_modules=None):
return self.precurrent
"""
hidden_modules = hidden_modules or self._hidden_modules
if issubclass(type(hidden_modules), nn.ModuleList):
params = []
......@@ -310,9 +311,9 @@ class RVAEDecoder(HiddenModule):
return params
else:
if hasattr(hidden_modules, 'phidden'):
return checklist(hidden_modules[1].phidden)[-1]
else:
g else:
return checklist(checklist(self.phidden)[0])[-1]
"""
def make_recurrent_layer(self, phidden, precurrent, *args, **kwargs):
if issubclass(type(precurrent), list):
......@@ -331,7 +332,7 @@ class RVAEDecoder(HiddenModule):
return hidden_out
# Reset the stateful recurrent sub-module's hidden state between sequences.
# NOTE(review): diff fragment with +/- markers stripped — the two lines
# below are the removed/added pair of one statement. The index change
# (0 -> 1) presumably tracks the reordering of the Sequential container
# in make_hidden_layers from (recurrent, hidden) to (hidden, recurrent) —
# confirm against the repository diff; only one line belongs in the file.
def clear_recurrent(self):
self.hidden_modules[0].clear()
self.hidden_modules[1].clear()
def forward(self, x, n_steps=100, y=None, sample=True, clear=True, return_hidden=False, *args, **kwargs):
if clear:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment