mnist sequential + small update

parent 6e1d5a6a
[submodule "data/toys/--init"]
path = data/toys/--init
url = https://github.com/tychovdo/MovingMNIST.git
[submodule "data/toys/MovingMNIST"]
path = data/toys/MovingMNIST
url = https://github.com/tychovdo/MovingMNIST.git
Subproject commit 6ff8f1d7042168e419c1efc60ab87791b6a263fb
Subproject commit 6ff8f1d7042168e419c1efc60ab87791b6a263fb
from ..data_generic import Dataset
from .MovingMNIST.MovingMNIST import MovingMNIST as MM
import os, numpy as np, torch
class MovingMNIST(Dataset):
    """Dataset wrapper around the tychovdo/MovingMNIST loader.

    Downloads the Moving MNIST sequences (if needed), concatenates the train
    and test splits into a single array, and records index partitions so
    downstream code can address either split by position.
    """

    def __init__(self, options):
        # Anchor everything under this module's directory instead of the
        # process CWD, so the dataset is found no matter where the script
        # is launched from.  (The original passed the CWD-relative path
        # "./MovingMNIST/moving_mnist" to MM while computing dataPrefix
        # from __file__ -- an inconsistency that broke runs started from
        # other directories.)
        dataPrefix = os.path.join(os.path.dirname(__file__), 'MovingMNIST')
        tasks = []
        super(MovingMNIST, self).__init__({'dataPrefix': dataPrefix, 'tasks': tasks})
        root = os.path.join(dataPrefix, 'moving_mnist')
        train_set = MM(root=root, train=True, download=True)
        # download=True on the train split already fetched the files for
        # this root, so the test split does not need to re-download.
        test_set = MM(root=root, train=False)
        train_data = train_set.train_data.numpy()
        test_data = test_set.test_data.numpy()
        # One contiguous array; the partitions index into it by offset.
        self.data = np.concatenate([train_data, test_data], axis=0)
        self.partitions = {'train': np.arange(train_data.shape[0]),
                           'test': train_data.shape[0] + np.arange(test_data.shape[0])}
......@@ -336,8 +336,7 @@ class BernoulliLayer1D(nn.Module):
def forward(self, ins, *args, **kwargs):
    """Run `ins` through the wrapped modules and return a Bernoulli distribution.

    The network output `mu` is treated as per-element probabilities.  When the
    event dimensions are not already unfolded, `mu` is reshaped so that its
    trailing axes match the declared output dimension while every leading axis
    (batch, sequence, ...) is preserved.
    """
    mu = self.modules_list(ins)
    # Reshape only if the event dims are still flattened into the last axis.
    # NOTE(review): the scraped diff showed both the old single-batch-axis
    # view (mu.view(mu.shape[0], ...)) and this replacement; the new form is
    # kept because it preserves arbitrary leading dimensions.
    if len(mu.shape) != len(checktuple(self.poutput['dim'])) + 1:
        mu = mu.view(*mu.shape[:-1], *checktuple(self.output_dim))
    return Bernoulli(probs=mu)
class BernoulliLayer2D(nn.Module):
......
......@@ -179,7 +179,7 @@ class SimpleTrainer(Trainer):
models = self.models
name = str(self.name)
epoch = kwargs.get('epoch')
print('-- saving model at %s'%'results/%s/%s.t7'%(results_folder, name))
print('-- saving model at %s'%'results/%s/%s.pth'%(results_folder, name))
if not issubclass(type(models), list):
models = [models]
datasets = self.datasets
......@@ -195,14 +195,14 @@ class SimpleTrainer(Trainer):
additional_args = {'loss':self.losses, 'partitions':partitions, **kwargs}
if self.reinforcers:
additional_args['reinforcers'] = self.reinforcers
models[i].save('%s/%s.t7'%(results_folder, current_name), **additional_args)
models[i].save('%s/%s.pth'%(results_folder, current_name), **additional_args)
# saving best model
best_model = self.best_model
if not issubclass(type(best_model), list):
best_model = [best_model]
if not self.best_model is None and save_best:
print('-- saving best model at %s'%'results/%s/%s_best.t7'%(results_folder, name))
print('-- saving best model at %s'%'results/%s/%s_best.pth'%(results_folder, name))
for i in range(len(best_model)):
current_name = name+'_best' if len(models) == 1 else '/vae_%d/%s_best'%(i, name)
torch.save({'preprocessing':self.preprocessing, 'loss':self.losses, 'partitions':partitions, **kwargs, **best_model[i]}, '%s/%s.pth'%(results_folder, current_name))
......
......@@ -145,8 +145,11 @@ class ShrubVAE(VanillaVAE):
super(ShrubVAE, self).init_modules(input_params, latent_params, hidden_params,
encoder = encoder, decoder = decoder, *args, **kwargs)
if hidden_params[0].get('load'):
loaded_data = torch.load(hidden_params[0]['load'], map_location="cpu")
vae = loaded_data['class'].load(loaded_data)
if issubclass(type(hidden_params[0]['load']), str):
loaded_data = torch.load(hidden_params[0]['load'], map_location="cpu")
vae = loaded_data['class'].load(loaded_data)
else:
vae = hidden_params[0]['load']
#if hidden_params[0].get('load') in ["encoder", "full"]:
self.encoders[0] = vae.encoders[0]; self.loaded_encoder=True
#elif hidden_params[0].get('load') == ["decoder", "full"]:
......@@ -397,7 +400,10 @@ class ShrubVAE(VanillaVAE):
current_out = self.decoders[0](current_z)
#current_out['out_params'] = current_out['out_params'].view(original_shape[0],original_shape[1],*current_out['out_params'].batch_shape[1:])
if current_out.get('out') is None:
current_out['out'] = apply_method(current_out['out_params'], 'rsample')
if current_out['out_params'].has_rsample:
current_out['out'] = apply_method(current_out['out_params'], 'rsample')
else:
current_out['out'] = apply_method(current_out['out_params'], 'sample')
logger('last layer decoded')
outs.append(current_out)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment