small updates

parent d911cc09
@@ -452,16 +452,17 @@ class ShrubVAE(VanillaVAE):
     def init_optimizer(self, optim_params, init_scheduler=True):
         optimization_mode = optim_params.get('optimize', 'full')
         if optimization_mode == 'recurrent':
-            alg = optim_params.get('optimizer', 'Adam')
-            optim_args = optim_params.get('optim_args', {'lr':1e-3})
-            if issubclass(type(optim_args['lr']), list):
-                optim_args['lr'] = optim_args['lr'][0]
-            parameters = nn.ParameterList(sum([list(d.parameters()) for d in self.encoders[1:]] + [list(d.parameters()) for d in self.decoders[1:]], []))
-            self.optimizers = {'default':getattr(torch.optim, alg)([{'params':parameters}], **optim_args)}
-            if init_scheduler:
-                self.init_scheduler(optim_params)
-        else:
-            super(ShrubVAE, self).init_optimizer(optim_params)
+            # alg = optim_params.get('optimizer', 'Adam')
+            # optim_args = optim_params.get('optim_args', {'lr':1e-3})
+            # if issubclass(type(optim_args['lr']), list):
+            #     optim_args['lr'] = optim_args['lr'][0]
+            # parameters = nn.ParameterList(sum([list(d.parameters()) for d in self.encoders[1:]] + [list(d.parameters()) for d in self.decoders[1:]], []))
+            # self.optimizers = {'default':getattr(torch.optim, alg)([{'params':parameters}], **optim_args)}
+            # if init_scheduler:
+            #     self.init_scheduler(optim_params)
+            optim_args = optim_params['optim_args']
+            optim_args['lr'] = [0.] + checklist(optim_args['lr'], n=len(self.platent)-1)
+        super(ShrubVAE, self).init_optimizer(optim_params)

     def init_scheduler(self, optim_params):
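For reference, the new recurrent branch boils down to building a per-layer learning-rate list whose first entry is pinned to 0 before delegating to the parent class, which then creates one torch.optim parameter group per latent layer. The sketch below is not the repository's code: `checklist` is re-implemented here as a stand-in for the repo helper (assumed to repeat a scalar into a list of length n), and the module sizes are made up for illustration.

import torch
import torch.nn as nn

def checklist(x, n=1):
    # hypothetical stand-in for the repo's checklist helper:
    # wrap a scalar in a list and repeat it to length n
    x = list(x) if isinstance(x, (list, tuple)) else [x]
    return x * n if len(x) == 1 else x

# three "layers" standing in for the per-layer encoders
encoders = nn.ModuleList([nn.Linear(16, 8), nn.Linear(8, 4), nn.Linear(4, 2)])

# freeze layer 0 by giving it a zero learning rate, train the remaining layers
lr = [0.] + checklist(1e-3, n=len(encoders) - 1)

param_groups = [{'params': enc.parameters(), 'lr': lr[l]}
                for l, enc in enumerate(encoders)]
optimizer = torch.optim.Adam(param_groups)

for group in optimizer.param_groups:
    print(group['lr'])   # 0.0, 0.001, 0.001

Giving a parameter group lr=0.0 keeps it registered with the optimizer while effectively excluding it from updates, which matches the intent of prepending 0. to the learning-rate list.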
@@ -150,6 +150,7 @@ class VanillaVAE(AbstractVAE):
         optim_args = checklist(optim_args, n=len(self.platent))
         param_groups = []
         for l in range(len(self.platent)):
+            print('layer %d : %s'%(l, optim_args[l]))
             if optimization_mode in ['full', 'encoder']:
                 param_groups.append({'params':self.encoders[l].parameters(), **optim_args[l]})
             if optimization_mode in ['full', 'decoder']:
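The surrounding VanillaVAE loop dispatches parameters into optimizer groups according to optimization_mode ('full', 'encoder' or 'decoder'). A minimal standalone version of that dispatch, with illustrative module sizes and learning rates that are not taken from the repository, could look like:

import torch
import torch.nn as nn

# stand-ins for the per-layer encoder/decoder stacks
encoders = nn.ModuleList([nn.Linear(16, 8), nn.Linear(8, 4)])
decoders = nn.ModuleList([nn.Linear(4, 8), nn.Linear(8, 16)])
optim_args = [{'lr': 1e-3}, {'lr': 1e-4}]   # one dict of optimizer kwargs per latent layer
optimization_mode = 'full'                  # or 'encoder' / 'decoder'

param_groups = []
for l in range(len(encoders)):
    # 'full' trains both stacks, 'encoder'/'decoder' restrict the groups to one side
    if optimization_mode in ['full', 'encoder']:
        param_groups.append({'params': encoders[l].parameters(), **optim_args[l]})
    if optimization_mode in ['full', 'decoder']:
        param_groups.append({'params': decoders[l].parameters(), **optim_args[l]})

optimizer = torch.optim.Adam(param_groups)
print(len(optimizer.param_groups))   # 4 groups when optimization_mode is 'full'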