multilayer lr

parent e9f1e3a1
@@ -67,7 +67,11 @@ class TensorboardHandler(object):
     def add_model_grads(self, model, epoch):
         for n, p in model.named_parameters():
-            self.writer.add_histogram('model/grads/'+n, p.grad.detach().cpu().numpy(), epoch)
-            print(n)
+            if p.grad is not None:
+                self.writer.add_histogram('model/grads/'+n, p.grad.detach().cpu().numpy(), epoch)
+            else:
+                print('{Warning} %s has no gradient'%n)
     def add_loss(self, losses, epoch):
         for loss in losses:
......
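The hunk above guards the histogram call against parameters whose .grad is still None (for example frozen or unused submodules). Below is a minimal, self-contained sketch of that behaviour, not part of the commit; the toy model and log directory are hypothetical stand-ins.

# Hedged sketch: log gradient histograms only for parameters that actually
# received a gradient, mirroring add_model_grads above.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

toy_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # hypothetical model
writer = SummaryWriter(log_dir='/tmp/grad_logging_demo')      # hypothetical log dir

# Only the first layer takes part in the backward pass, so the second
# layer's parameters keep p.grad == None.
toy_model[0](torch.randn(3, 4)).sum().backward()

epoch = 0
for n, p in toy_model.named_parameters():
    if p.grad is not None:
        writer.add_histogram('model/grads/' + n, p.grad.detach().cpu().numpy(), epoch)
    else:
        print('{Warning} %s has no gradient' % n)
writer.close()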
@@ -439,6 +439,8 @@ class ShrubVAE(VanillaVAE):
         if optimization_mode == 'recurrent':
             alg = optim_params.get('optimizer', 'Adam')
             optim_args = optim_params.get('optim_args', {'lr':1e-3})
+            if issubclass(type(optim_args['lr']), list):
+                optim_args['lr'] = optim_args['lr'][0]
             parameters = nn.ParameterList(sum([list(d.parameters()) for d in self.encoders[1:]] + [list(d.parameters()) for d in self.decoders[1:]], []))
             self.optimizers = {'default':getattr(torch.optim, alg)([{'params':parameters}], **optim_args)}
             if init_scheduler:
......
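In ShrubVAE's 'recurrent' mode all upper-level encoder and decoder parameters are handed to a single optimizer, so a list-valued learning rate cannot be splatted through **optim_args; the hunk above keeps only its first entry. A small sketch of that collapse, with hypothetical parameters:

# Hedged sketch, hypothetical tensors: a list-valued lr is reduced to its
# first element before being passed to a single-group optimizer.
import torch
import torch.nn as nn

optim_args = {'lr': [1e-3, 1e-4]}
if isinstance(optim_args['lr'], list):       # stands in for issubclass(type(...), list)
    optim_args['lr'] = optim_args['lr'][0]

parameters = nn.ParameterList([nn.Parameter(torch.randn(3)), nn.Parameter(torch.randn(2))])
optimizer = torch.optim.Adam([{'params': parameters}], **optim_args)
print(optimizer.param_groups[0]['lr'])       # -> 0.001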
@@ -11,7 +11,7 @@ import torch.nn as nn
 import torch.optim
 from ..modules.modules_hidden import HiddenModule
-from ..utils.misc import GPULogger, denest_dict, apply, apply_method, apply_distribution, print_stats, flatten_seq_method
+from ..utils.misc import GPULogger, denest_dict, apply, apply_method, apply_distribution, print_stats, flatten_seq_method, checklist
 from . import AbstractVAE
 logger = GPULogger(verbose=False)
@@ -140,10 +140,22 @@ class VanillaVAE(AbstractVAE):
         alg = optim_params.get('optimizer', 'Adam')
         optim_args = optim_params.get('optim_args', {'lr':1e-3})
         optimization_mode = optim_params.get('mode', 'full')
+        if issubclass(type(optim_args['lr']), list):
+            if len(optim_args['lr']) != len(self.platent):
+                optim_args['lr'] = optim_args['lr'][0]
+            else:
+                optim_args = [{**optim_args, 'lr':optim_args['lr'][i]} for i in range(len(self.platent))]
+        optim_args = checklist(optim_args, n=len(self.platent))
+        param_groups = []
+        for l in range(len(self.platent)):
+            if optimization_mode in ['full', 'encoder']:
+                param_groups.append({'params':self.encoders[l].parameters(), **optim_args[l]})
+            if optimization_mode in ['full', 'decoder']:
+                param_groups.append({'params':self.decoders[l].parameters(), **optim_args[l]})
+        optimizer = getattr(torch.optim, alg)(param_groups)
-        optimizer = getattr(torch.optim, alg)([{'params':self.encoders.parameters()}], **optim_args)
-        if optimization_mode == 'full':
-            optimizer.add_param_group({'params':self.decoders.parameters()})
         self.optimizers = {'default':optimizer}
         if init_scheduler:
             self.init_scheduler(optim_params)
......
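Taken together, the VanillaVAE hunk lets optim_args['lr'] be a list with one learning rate per latent layer: when the list length matches len(self.platent), each layer's encoder and decoder becomes its own parameter group with its own lr; otherwise the first value is used everywhere. The sketch below replays that logic outside the class. It assumes checklist wraps a non-list value into a list of length n and leaves a list unchanged; the latent sizes, encoders, and decoders are hypothetical stand-ins.

# Hedged sketch, not the repository code: per-layer learning rates mapped to
# torch.optim parameter groups, mirroring the hunk above.
import torch
import torch.nn as nn

def checklist(x, n=1):
    # assumed behaviour of ..utils.misc.checklist
    return list(x) if isinstance(x, list) else [x] * n

platent = [None, None]                                        # two latent layers (hypothetical)
encoders = nn.ModuleList([nn.Linear(8, 4), nn.Linear(4, 2)])
decoders = nn.ModuleList([nn.Linear(4, 8), nn.Linear(2, 4)])

optim_params = {'optimizer': 'Adam', 'optim_args': {'lr': [1e-3, 1e-4]}, 'mode': 'full'}
alg = optim_params.get('optimizer', 'Adam')
optim_args = optim_params.get('optim_args', {'lr': 1e-3})
optimization_mode = optim_params.get('mode', 'full')

if isinstance(optim_args['lr'], list):
    if len(optim_args['lr']) != len(platent):
        optim_args['lr'] = optim_args['lr'][0]                # length mismatch: fall back to one lr
    else:
        optim_args = [{**optim_args, 'lr': optim_args['lr'][i]} for i in range(len(platent))]
optim_args = checklist(optim_args, n=len(platent))            # always a list, one dict per layer

param_groups = []
for l in range(len(platent)):
    if optimization_mode in ['full', 'encoder']:
        param_groups.append({'params': encoders[l].parameters(), **optim_args[l]})
    if optimization_mode in ['full', 'decoder']:
        param_groups.append({'params': decoders[l].parameters(), **optim_args[l]})

optimizer = getattr(torch.optim, alg)(param_groups)
print([g['lr'] for g in optimizer.param_groups])              # [0.001, 0.001, 0.0001, 0.0001]

Because each group carries its own 'lr' entry, no global lr keyword is needed when constructing the optimizer; any group that omits it would simply use the optimizer's default.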