...
 
@@ -334,6 +334,17 @@ class Dataset(torch.utils.data.Dataset):
            ids = set(filter(lambda x: current_metadata[x] == m or m in checklist(current_metadata[x]), valid_ids)).union(ids)
        return np.array(list(ids))
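
# A toy, self-contained sketch of the id filter above. checklist() is assumed
# (it is not shown in this diff) to wrap scalars into a one-element list.
def checklist(x):
    return x if isinstance(x, list) else [x]

current_metadata = {0: 'a', 1: ['a', 'b'], 2: 'c'}
valid_ids, ids, m = [0, 1, 2], set(), 'a'
ids = set(filter(lambda x: current_metadata[x] == m or m in checklist(current_metadata[x]), valid_ids)).union(ids)
print(sorted(ids))  # [0, 1]
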
    def set_tasks(self, tasks):
        self.tasks = checklist(tasks)
        self.drop_tasks = list(filter(lambda x: x in self.tasks, self.drop_tasks))
        self.taskCallback = {}; self.metadataFiles = {}
        for t in range(len(self.tasks)):
            self.taskCallback[t] = self.retrieve_callback_from_path(self.metadataDirectory, self.tasks[t]) or metadataCallbacks["default"] or []
            # a concatenated path is always truthy, so `path or fallback` could
            # never pick the fallback; probe the task-specific file explicitly
            task_file = self.metadataDirectory + '/' + self.tasks[t] + '/metadata.txt'
            self.metadataFiles[t] = task_file if os.path.isfile(task_file) else self.metadataDirectory + '/metadata.txt'
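
# Hypothetical usage sketch of the new method (assumes a constructed `dataset`;
# the task names are placeholders, not from the repository):
dataset.set_tasks(['genre', 'instrument'])
for t in range(len(dataset.tasks)):
    print(dataset.tasks[t], dataset.metadataFiles[t], dataset.taskCallback[t])
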
    def retrieve(self, idx):
        """
        Returns a sub-dataset of the current one. If the main argument is a string, returns the sub-dataset of the
@@ -608,6 +619,7 @@ class Dataset(torch.utils.data.Dataset):
        curFile, curHash = len(self.files), -1
        testFileID = None
        classList = {"_length": 0}
        print(task)
        for line in fileCheck:
            line = line.rstrip('\n')
            # length guard first: `line[0]` on an empty line would raise IndexError
            if len(line) > 1 and line[0] != "#":
@@ -655,21 +667,21 @@ class Dataset(torch.utils.data.Dataset):
                metaList[i] = [-1]
                classList['None'] = -1
        self.metadata[task] = np.array(metaList)
        label_files = '/'.join(fileName.split('/')[:-1]) + '/classes.txt'
        # fall back to an identity class list if the scan produced none
        classList = classList or {str(k): k for k in set(self.metadata[task])}
        if os.path.isfile(label_files):
            # external classes.txt: one tab-separated pair per line
            classes_raw = open(label_files, 'r').read().split('\n')
            classes_raw = [tuple(c.split('\t')) for c in classes_raw]
            classes_raw = list(filter(lambda x: len(x) == 2, classes_raw))
            imported_class_dict = dict(classes_raw)
            # membership test rather than truthiness: remove the counter even when it is 0,
            # otherwise the comprehension below would hit the '_length' key
            if '_length' in classList:
                del classList['_length']
            self.classes[task] = {imported_class_dict[label]: classList[label] for label in classList.keys()}
            self.classes[task]['_length'] = len(classes_raw)
        else:
            # classList is guaranteed non-empty by the fallback above
            self.classes[task] = classList
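
# Minimal sketch of the mapping built from an external classes.txt. Each line is
# assumed to be a tab-separated pair; treating the first column as the raw label
# and the second as a display name is an assumption, not confirmed by this diff.
classes_raw = [("0", "piano"), ("1", "violin")]     # parsed classes.txt
imported_class_dict = dict(classes_raw)             # raw label -> name
classList = {"0": 0, "1": 1}                        # raw label -> class index
classes = {imported_class_dict[k]: classList[k] for k in classList}
classes['_length'] = len(classes_raw)
print(classes)  # {'piano': 0, 'violin': 1, '_length': 2}
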
    def import_metadata_tasks(self, sort=True):
        """
@@ -685,6 +697,7 @@ class Dataset(torch.utils.data.Dataset):
        for t in range(len(self.tasks)):
            # same fix as in set_tasks: a non-empty string is always truthy, so the
            # old `or` fallback could never trigger; probe the file explicitly
            task_file = self.metadataDirectory + '/' + self.tasks[t] + '/metadata.txt'
            self.metadataFiles[t] = task_file if os.path.isfile(task_file) else self.metadataDirectory + '/metadata.txt'
        for t in range(len(self.tasks)):
            self.import_metadata(self.metadataFiles[t], self.tasks[t], self.taskCallback[t])
        if sort:
            for t in self.tasks:
@@ -21,3 +21,4 @@ class MovingMNIST(Dataset):
        self.data = np.concatenate([train_data, test_data], axis=0)
        self.partitions = {'train': np.arange(train_data.shape[0]), 'test': train_data.shape[0] + np.arange(test_data.shape[0])}
        self.files = [None] * len(self.data)
        self.hash = {None: list(range(self.data.shape[0]))}
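
# The new `hash` attribute follows the task-less convention used elsewhere in
# Dataset: the None key indexes every example. Toy standalone check (split sizes
# and frame shape are placeholders, not the real MovingMNIST dimensions):
import numpy as np
train_n, test_n = 80, 20
data = np.zeros((train_n + test_n, 20, 64, 64), dtype=np.uint8)
partitions = {'train': np.arange(train_n), 'test': train_n + np.arange(test_n)}
hash_ = {None: list(range(data.shape[0]))}
assert partitions['test'][0] == train_n and len(hash_[None]) == train_n + test_n
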
@@ -286,13 +286,17 @@ def trajectory2audio(model, traj_types, transformOptions, n_trajectories=1, n_st
def interpolate_files(dataset, vae, n_files=None, files=None, n_interp=10, out=None, preprocessing=None, preprocess=False,
                      projections=None, transformType=None, window=None, transformOptions=None, predict=False, **kwargs):
    n_files = n_files or len(files)  # defaults to the number of explicit pairs
    for f in range(n_files):
        #sequence_length = loaded_data['script_args'].sequence
        #files_to_morph = random.choices(range(len(dataset.data)), k=2)
        if files is None:
            files_to_morph = choices(dataset.files, k=2)
        else:
            files_to_morph = files[f]
        data_outs = []
        projections = checklist(projections)
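
# Hypothetical call sketch for the new `files` argument (file names, output path
# and transform options are placeholders): explicit pairs pin down which items
# get morphed, while random pair sampling remains the default.
interpolate_files(dataset, vae,
                  files=[('flute_a4.wav', 'cello_a4.wav'),
                         ('flute_a4.wav', 'piano_a4.wav')],
                  n_interp=10, out='results/morphs', transformOptions=transformOptions)
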
@@ -352,10 +356,11 @@ def interpolate_files(dataset, vae, n_files=1, n_interp=10, out=None, preprocess
        else:
            raise NotImplementedError
        check_dir(out)
        check_dir(out + '/interpolations')
        for i in range(data_out.shape[0]):
            signal_out = inverseTransform(data_out[i].squeeze().cpu().detach().numpy(), 'stft', {'transformParameters': transformOptions}, iterations=30, method='griffin-lim')
            write_wav('%s/interpolations/morph_%d_%d_%d.wav' % (out, l, f, i), signal_out, transformOptions.get('resampleTo', 22050), norm=True)
        fig.savefig('%s/interpolations/morph_%d_%d.pdf' % (out, l, f), format="pdf")
        plt.close('all')
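
# The inversion step above uses Griffin-Lim: iteratively re-estimate the phase a
# magnitude spectrogram discards. A standalone sketch with librosa rather than
# the project's inverseTransform (all parameters are illustrative):
import numpy as np
import librosa
sr = 22050
y = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)        # 1 s test tone
mag = np.abs(librosa.stft(y, n_fft=1024))               # magnitude only
y_rec = librosa.griffinlim(mag, n_iter=30, n_fft=1024)  # 30 iterations, as above
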
@@ -745,7 +745,7 @@ def plot_latent3(dataset, model, transformation=None, n_points=None, preprocessi
    if tasks == [None]:
        full_ids.add(None, ids if ids is not None else range(len(dataset)))
        nclasses = {None: [None]}; class_ids = {None: full_ids.get_full_ids()}
    if tasks != [None] and tasks is not None:
        # fill full_ids objects with corresponding classes, get classes index hash
        class_ids = {}; nclasses = {}
        for t in tasks:
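
# What this loop goes on to assemble, shown for one task with toy labels: a hash
# from each class value to the example indices carrying it.
import numpy as np
metadata = np.array([0, 1, 0, 2, 1])
class_ids = {int(c): np.where(metadata == c)[0] for c in np.unique(metadata)}
print(class_ids)  # {0: array([0, 2]), 1: array([1, 4]), 2: array([3])}
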