criterion update

parent 30bd6f05
......@@ -34,8 +34,8 @@ class Criterion(nn.modules.loss._Loss):
super(Criterion, self).__init__(size_average=size_average, reduce=reduce, reduction=reduction)
self._weight = weight
self._pow = pow
self._loss_history = {}
self._name = name
self.reset()
def loss(self, *args, **kwargs):
return 0.
......@@ -47,6 +47,9 @@ class Criterion(nn.modules.loss._Loss):
def loss_history(self):
return self._loss_history
def reset(self):
self._loss_history = {}
def get_named_losses(self, losses):
name = 'dummy' if self._name is None else '%s_dummy'%self._name
return {name : losses}
......@@ -149,6 +152,10 @@ class CriterionContainer(Criterion):
def __init__(self, criterions=[], options={}, weight=1.0, **kwargs):
super().__init__(**kwargs)
self._criterions = criterions
def __getitem__(self, item):
return self._criterions[item]
def loss(self, *args, **kwargs):
loss = 0.; losses = list()
......@@ -164,6 +171,10 @@ class CriterionContainer(Criterion):
def __repr__(self):
return "CriterionContainer(%s)"%self._criterions
def reset(self):
for c in self._criterions:
c.reset()
@property
def loss_history(self):
histories = [c.loss_history for c in self._criterions]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment