perceptual updates

parent 9b7b8a55
......@@ -103,7 +103,7 @@ class PerceptiveL2Loss(Criterion):
def __init__(self, latent_params, dataset=None, normalize=False):
super(PerceptiveL2Loss, self).__init__()
targetDims = latent_params['dim']
_, centroids = get_perceptual_centroids(dataset, latent_params['dim'])[0]
_, (centroids, _) = get_perceptual_centroids(dataset, latent_params['dim'])
if issubclass(type(centroids), np.ndarray):
self.centroids = torch.from_numpy(centroids).type('torch.FloatTensor')
else:
......@@ -192,7 +192,7 @@ class PerceptiveStudent(Criterion):
def __init__(self, latent_params, dataset=None, normalize=False):
super(PerceptiveStudent, self).__init__()
targetDims = latent_params['dim']
centroids = get_perceptual_centroids(dataset, targetDims)[0]
_, (centroids, _) = get_perceptual_centroids(dataset, targetDims)
if issubclass(type(centroids), np.ndarray):
self.centroids = torch.from_numpy(centroids).type('torch.FloatTensor')
else:
......
......@@ -163,7 +163,8 @@ def resynthesize_files(dataset, model, transformOptions=None, transform=None, me
if transform is not None:
if issubclass(type(transform), (list, tuple)):
transform = transform[0]
current_transform = computeTransform([current_file], transform, transformOptions)[0]
pdb.set_trace()
current_transform = computeTransform([current_file], transform, transformOptions)
current_transform = np.array(current_transform)
# ct = np.copy(current_transform)
else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment