Commit 77dd66d7 authored by Keunwoo Choi's avatar Keunwoo Choi
Browse files

first commit

parents
This diff is collapsed.
This diff is collapsed.
""" .py version of 'knn and svm - for many tasks.ipynb'
After using mid_layer features, it's diverging from the ipython notebook - the notebook is outdated now.
"""
import matplotlib
import matplotlib.pyplot as plt
import multiprocessing
plt.style.use('ggplot')
font = {'family': 'consolas',
'weight': 'light',
'size': 12}
matplotlib.rc('font', **font)
ggplot_colors = [plt.rcParams['axes.color_cycle'][i] for i in [0, 1, 2, 3, 4, 5, 6]]
import os
import sys
import numpy as np
import librosa
import time
import sklearn
import pdb
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC, SVR
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import PredefinedSplit
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from utils_featext import OptionalStandardScaler
from sklearn.pipeline import Pipeline
import pandas as pd
import logging
import cPickle as cP
PATH_CLS = 'data_classifiers/' # save classifiers
FOLDER_CSV = 'data_csv/'
FOLDER_FEATS = 'data_feats/'
FOLDER_RESULTS = 'result_transfer/'
try:
os.mkdir(FOLDER_RESULTS)
except:
pass
def load_xy_many(taskname, featname='mine', npy_suffix='', logger=None, mid_layer=4):
""" wrapper for load_xy() for loading and concatenating multiple of them. """
if featname == 'mfcc':
x, y = load_xy(taskname, featname, npy_suffix, logger, mid_layer=mid_layer)
elif featname == 'mine':
for l_idx, mid_layer_num in enumerate(mid_layer):
if l_idx == 0:
x, y = load_xy(taskname, featname, npy_suffix, logger, mid_layer=mid_layer_num)
else:
x_new, _ = load_xy(taskname, featname, npy_suffix, logger, mid_layer=mid_layer_num)
x = np.concatenate((x, x_new), axis=1)
elif featname == 'mfcc+12345':
x, _ = load_xy_many(taskname, 'mfcc', npy_suffix, logger, mid_layer)
x_12345, y = load_xy_many(taskname, 'mine', npy_suffix, logger, [0, 1, 2, 3, 4])
x = np.concatenate((x, x_12345), axis=1)
return x, y
def load_xy(task_name, feat_name='mine', npy_suffix='', logger=None, mid_layer=4):
"""
:param task_name:
:param feat_name:
:param npy_suffix:
:param logger:
:param mid_layer: ignired if 'mfcc' is used
:return:
"""
assert task_name in ('ballroom_extended', 'gtzan_genre', 'gtzan_speechmusic',
'emoMusic_a', 'emoMusic_v', 'jamendo_vd', 'urbansound')
# logger.info('load_xy({}, {}, mid_layer: {}, npy_suffix: {})...'.format(task_name, feat_name, mid_layer, npy_suffix))
# X
csv_filename = '{}.csv'.format(task_name)
if feat_name == 'mine':
if task_name.startswith('emoMusic'):
if mid_layer == 4: # For the last layer, use Max-Pooled one
npy_filename = '{}{}.npy'.format('emoMusic', npy_suffix)
else: # For the others, use Average-Pooled ones
npy_filename = '{}_layer_{}{}.npy'.format('emoMusic', mid_layer, npy_suffix)
else:
if mid_layer == 4:
npy_filename = '{}{}.npy'.format(task_name, npy_suffix)
else:
npy_filename = '{}_layer_{}{}.npy'.format(task_name, mid_layer, npy_suffix)
elif feat_name == 'mfcc':
if task_name.startswith('emoMusic'):
npy_filename = '{}_mfcc.npy'.format('emoMusic')
else:
npy_filename = '{}_mfcc.npy'.format(task_name)
x = np.load(os.path.join(FOLDER_FEATS, npy_filename))
# Y
if task_name == 'emoMusic_v':
csv_filename = '{}.csv'.format('emoMusic')
df = pd.DataFrame.from_csv(os.path.join(FOLDER_CSV, csv_filename))
y = df['label_valence']
elif task_name == 'emoMusic_a':
csv_filename = '{}.csv'.format('emoMusic')
df = pd.DataFrame.from_csv(os.path.join(FOLDER_CSV, csv_filename))
y = df['label_arousal']
else:
y = pd.DataFrame.from_csv(os.path.join(FOLDER_CSV, csv_filename))['label']
return x, y
def save_result(featname, taskname, classifiername, score):
"""featname: string, taskname:string, score:float"""
filename = 'T_{}_F_{}_CL_{}.npy'.format(taskname, featname, classifiername)
np.save(os.path.join(FOLDER_RESULTS, filename), score)
def cross_validate(featnames, tasknames, cvs, classifiers, gps, logger, n_jobs, npy_suffix='', mid_layer=4):
'''featnames: list of string, ['mine', 'mfcc']
- tasknames = list of stringm ['ballroom_extended', 'gtzan_genre', 'gtzan_speechmusic',
'emoMusic', 'jamendo_vc', 'urbansound']
- cvs: list of cv, 10 for rest, split arrays for urbansound and jamendo_vd
- classifier: list of classifier class, e.g [KNeighborsClassifier, SVC]
- gps: list of gp, e.g. [{"n_neighbors":[1, 2, 8, 12, 16]}, {"C":[0.1, 8.0], "kernel":['linear', 'rbf']}]
- mid_layer: scalar, or list of scalar .
'''
np.random.seed(1209)
if not isinstance(mid_layer, list):
mid_layer = [mid_layer]
logger.info('')
logger.info('--- Cross-validation started for {} ---'.format(''.join([str(i) for i in mid_layer])))
for featname in featnames:
logger.info(' * feat_name: {} ---'.format(featname))
for classifier, gp in zip(classifiers, gps):
clname = classifier.__name__
logger.info(' - classifier: {} ---'.format(clname))
for taskname, cv in zip(tasknames, cvs):
logger.info(' . task: {} ---'.format(taskname))
model_filename = 'clf_{}_{}_{}.cP'.format(featname, taskname, clname)
x, y = load_xy_many(taskname, featname, npy_suffix, logger, mid_layer=mid_layer)
estimators = [('stdd', OptionalStandardScaler()), ('clf', classifier())]
pipe = Pipeline(estimators)
if isinstance(gp, dict): # k-nn or svm with single kernel
params = {'stdd__on': [True, False]}
params.update({'clf__' + key: value for (key, value) in gp.iteritems()})
elif isinstance(gp, list): # svm: grid param can be a list of dictionaries
params = []
for dct in gp: # should be dict of list for e.g. svm
sub_params = {'stdd__on': [True, False]}
sub_params.update({'clf__' + key: value for (key, value) in dct.iteritems()})
params.append(sub_params)
clf = GridSearchCV(pipe, params, cv=cv, n_jobs=n_jobs, pre_dispatch='8*n_jobs').fit(x, y)
logger.info(' . best score {}'.format(clf.best_score_))
logger.info(clf.best_params_)
print('best score of {}, {}, {}: {}'.format(featname,
taskname,
clname,
clf.best_score_))
print(clf.best_params_)
cP.dump(clf, open(os.path.join(PATH_CLS, model_filename), 'w'))
featname_midlayer = '{}_{}'.format(featname, ''.join([str(i) for i in mid_layer]))
save_result(featname_midlayer, taskname, clname, clf.best_score_)
def get_cv_jamendo():
task_name = 'jamendo_vd'
csv_filename = '{}.csv'.format(task_name)
df = pd.DataFrame.from_csv(os.path.join(FOLDER_CSV, csv_filename))
splits = 0 * np.array([df['category'] == 'train']).astype(int) \
+ 0 * np.array([df['category'] == 'valid']).astype(int) \
+ 1 * np.array([df['category'] == 'test']).astype(int)
# PredefinedSplit(df['category'])
train_idxs = np.where(np.any(splits == 0, axis=0))[0]
test_idxs = np.where(np.any(splits == 1, axis=0))[0]
cv_iter = [(train_idxs, test_idxs)]
return cv_iter
def get_cv_urbansound():
task_name = 'urbansound'
csv_filename = '{}.csv'.format(task_name)
df = pd.DataFrame.from_csv(os.path.join(FOLDER_CSV, csv_filename))
ps = PredefinedSplit(df['fold'])
return ps
def get_logger(task_idx, system_name='', memo=''):
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# create a file handler
handler = logging.FileHandler('feature-transfer_task_{}_{}.log'.format(task_idx, system_name))
handler.setLevel(logging.INFO)
# create a logging format
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# add the handlers to the logger
logger.addHandler(handler)
logger.info('-' * 50)
logger.info('.' * 50)
logger.info('memo: {}'.format(memo))
logger.info('.' * 50)
logger.info('-' * 50)
return logger
def do_task_svm(task_idx, logger, n_jobs, which_emo='all'):
"""task_idx is 1 based, NOT zero-based.
"""
assert 1 <= task_idx <= 6
tasknames_cl = ['ballroom_extended', 'gtzan_genre', 'gtzan_speechmusic',
'jamendo_vd', 'urbansound']
cvs_cl = [10, 10, 10, get_cv_jamendo(), get_cv_urbansound()]
is_classification = True
if task_idx <= 3:
tasknames = tasknames_cl[task_idx - 1: task_idx]
cvs = cvs_cl[task_idx - 1: task_idx]
elif task_idx == 4: # Regression task.
is_classification = False
if which_emo == 'all':
tasknames = ['emoMusic_a', 'emoMusic_v']
cvs = [10]
elif which_emo == 'a':
tasknames = ['emoMusic_a']
cvs = [10]
elif which_emo == 'v':
tasknames = ['emoMusic_v']
cvs = [10, 10]
elif task_idx == 5:
tasknames = tasknames_cl[3:4]
cvs = cvs_cl[3:4]
elif task_idx == 6:
tasknames = tasknames_cl[4:5]
cvs = cvs_cl[4:5]
gps = [[{"C": [0.1, 2.0, 8.0, 32.0], "kernel": ['rbf'],
"gamma": [0.5 ** i for i in [3, 5, 7, 9, 11, 13]] + ['auto']},
{"C": [0.1, 2.0, 8.0, 32.0], "kernel": ['linear']}
]]
if is_classification:
classifiers = [SVC]
else:
classifiers = [SVR]
# FOR MFCC+12345 test,....
one_layers = [[i] for i in range(5)]
two_layers = [[i, j] for i in range(5) for j in range(i + 1, 5)]
three_layers = [[i, j, k] for i in range(5) for j in range(i + 1, 5) for k in range(j + 1, 5)]
four_layers = [range(4), range(1, 5)]
five_layers = [range(5)]
# all_layers = five_layers + four_layers + three_layers + two_layers + one_layers # 1, 2, 10, 10, 5
for mid_layer in all_layers:
cross_validate(['mine'], tasknames, cvs, classifiers, gps, logger, n_jobs, mid_layer=mid_layer)
# MFCC
cross_validate(['mfcc'], tasknames, cvs, classifiers, gps, logger, n_jobs, mid_layer=None)
cross_validate(['mfcc+12345'], tasknames, cvs, classifiers, gps, logger, n_jobs, mid_layer=None)
def main_all():
n_cpu = multiprocessing.cpu_count()
n_jobs = int(n_cpu * 0.8)
task_idxs = range(1, 7)
for task_idx in task_idxs:
logger = get_logger(task_idx)
do_task_svm(task_idx, logger, n_jobs)
if __name__ == '__main__':
main_all()
This diff is collapsed.
# transfer_learning_music
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
,id,filepath,label
0,bagpipe,gtzan_music_speech/music_speech/music_wav/bagpipe.wav,0
1,ballad,gtzan_music_speech/music_speech/music_wav/ballad.wav,0
2,bartok,gtzan_music_speech/music_speech/music_wav/bartok.wav,0
3,beat,gtzan_music_speech/music_speech/music_wav/beat.wav,0
4,beatles,gtzan_music_speech/music_speech/music_wav/beatles.wav,0
5,bigband,gtzan_music_speech/music_speech/music_wav/bigband.wav,0
6,birdland,gtzan_music_speech/music_speech/music_wav/birdland.wav,0
7,blues,gtzan_music_speech/music_speech/music_wav/blues.wav,0
8,bmarsalis,gtzan_music_speech/music_speech/music_wav/bmarsalis.wav,0
9,brahms,gtzan_music_speech/music_speech/music_wav/brahms.wav,0
10,canonaki,gtzan_music_speech/music_speech/music_wav/canonaki.wav,0
11,caravan,gtzan_music_speech/music_speech/music_wav/caravan.wav,0
12,chaka,gtzan_music_speech/music_speech/music_wav/chaka.wav,0
13,classical,gtzan_music_speech/music_speech/music_wav/classical.wav,0
14,classical1,gtzan_music_speech/music_speech/music_wav/classical1.wav,0
15,classical2,gtzan_music_speech/music_speech/music_wav/classical2.wav,0
16,copland,gtzan_music_speech/music_speech/music_wav/copland.wav,0
17,copland2,gtzan_music_speech/music_speech/music_wav/copland2.wav,0
18,corea,gtzan_music_speech/music_speech/music_wav/corea.wav,0
19,corea1,gtzan_music_speech/music_speech/music_wav/corea1.wav,0
20,cure,gtzan_music_speech/music_speech/music_wav/cure.wav,0
21,debussy,gtzan_music_speech/music_speech/music_wav/debussy.wav,0
22,deedee,gtzan_music_speech/music_speech/music_wav/deedee.wav,0
23,deedee1,gtzan_music_speech/music_speech/music_wav/deedee1.wav,0
24,duke,gtzan_music_speech/music_speech/music_wav/duke.wav,0
25,echoes,gtzan_music_speech/music_speech/music_wav/echoes.wav,0
26,eguitar,gtzan_music_speech/music_speech/music_wav/eguitar.wav,0
27,georose,gtzan_music_speech/music_speech/music_wav/georose.wav,0
28,gismonti,gtzan_music_speech/music_speech/music_wav/gismonti.wav,0
29,glass,gtzan_music_speech/music_speech/music_wav/glass.wav,0
30,glass1,gtzan_music_speech/music_speech/music_wav/glass1.wav,0
31,gravity,gtzan_music_speech/music_speech/music_wav/gravity.wav,0
32,gravity2,gtzan_music_speech/music_speech/music_wav/gravity2.wav,0
33,guitar,gtzan_music_speech/music_speech/music_wav/guitar.wav,0
34,hendrix,gtzan_music_speech/music_speech/music_wav/hendrix.wav,0
35,ipanema,gtzan_music_speech/music_speech/music_wav/ipanema.wav,0
36,jazz,gtzan_music_speech/music_speech/music_wav/jazz.wav,0
37,jazz1,gtzan_music_speech/music_speech/music_wav/jazz1.wav,0
38,led,gtzan_music_speech/music_speech/music_wav/led.wav,0
39,loreena,gtzan_music_speech/music_speech/music_wav/loreena.wav,0
40,madradeus,gtzan_music_speech/music_speech/music_wav/madradeus.wav,0
41,magkas,gtzan_music_speech/music_speech/music_wav/magkas.wav,0
42,march,gtzan_music_speech/music_speech/music_wav/march.wav,0
43,marlene,gtzan_music_speech/music_speech/music_wav/marlene.wav,0
44,mingus,gtzan_music_speech/music_speech/music_wav/mingus.wav,0
45,mingus1,gtzan_music_speech/music_speech/music_wav/mingus1.wav,0
46,misirlou,gtzan_music_speech/music_speech/music_wav/misirlou.wav,0
47,moanin,gtzan_music_speech/music_speech/music_wav/moanin.wav,0
48,narch,gtzan_music_speech/music_speech/music_wav/narch.wav,0
49,ncherry,gtzan_music_speech/music_speech/music_wav/ncherry.wav,0
50,nearhou,gtzan_music_speech/music_speech/music_wav/nearhou.wav,0
51,opera,gtzan_music_speech/music_speech/music_wav/opera.wav,0
52,opera1,gtzan_music_speech/music_speech/music_wav/opera1.wav,0
53,pop,gtzan_music_speech/music_speech/music_wav/pop.wav,0
54,prodigy,gtzan_music_speech/music_speech/music_wav/prodigy.wav,0
55,redhot,gtzan_music_speech/music_speech/music_wav/redhot.wav,0
56,rock,gtzan_music_speech/music_speech/music_wav/rock.wav,0
57,rock2,gtzan_music_speech/music_speech/music_wav/rock2.wav,0
58,russo,gtzan_music_speech/music_speech/music_wav/russo.wav,0
59,tony,gtzan_music_speech/music_speech/music_wav/tony.wav,0
60,u2,gtzan_music_speech/music_speech/music_wav/u2.wav,0
61,unpoco,gtzan_music_speech/music_speech/music_wav/unpoco.wav,0
62,vlobos,gtzan_music_speech/music_speech/music_wav/vlobos.wav,0
63,winds,gtzan_music_speech/music_speech/music_wav/winds.wav,0
64,acomic,gtzan_music_speech/music_speech/speech_wav/acomic.wav,1
65,acomic2,gtzan_music_speech/music_speech/speech_wav/acomic2.wav,1
66,allison,gtzan_music_speech/music_speech/speech_wav/allison.wav,1
67,amal,gtzan_music_speech/music_speech/speech_wav/amal.wav,1
68,austria,gtzan_music_speech/music_speech/speech_wav/austria.wav,1
69,bathroom1,gtzan_music_speech/music_speech/speech_wav/bathroom1.wav,1
70,chant,gtzan_music_speech/music_speech/speech_wav/chant.wav,1
71,charles,gtzan_music_speech/music_speech/speech_wav/charles.wav,1
72,china,gtzan_music_speech/music_speech/speech_wav/china.wav,1
73,comedy,gtzan_music_speech/music_speech/speech_wav/comedy.wav,1
74,comedy1,gtzan_music_speech/music_speech/speech_wav/comedy1.wav,1
75,conversion,gtzan_music_speech/music_speech/speech_wav/conversion.wav,1
76,danie,gtzan_music_speech/music_speech/speech_wav/danie.wav,1
77,danie1,gtzan_music_speech/music_speech/speech_wav/danie1.wav,1
78,dialogue,gtzan_music_speech/music_speech/speech_wav/dialogue.wav,1
79,dialogue1,gtzan_music_speech/music_speech/speech_wav/dialogue1.wav,1
80,dialogue2,gtzan_music_speech/music_speech/speech_wav/dialogue2.wav,1
81,diamond,gtzan_music_speech/music_speech/speech_wav/diamond.wav,1
82,ellhnika,gtzan_music_speech/music_speech/speech_wav/ellhnika.wav,1
83,emil,gtzan_music_speech/music_speech/speech_wav/emil.wav,1
84,fem_rock,gtzan_music_speech/music_speech/speech_wav/fem_rock.wav,1
85,female,gtzan_music_speech/music_speech/speech_wav/female.wav,1
86,fire,gtzan_music_speech/music_speech/speech_wav/fire.wav,1
87,geography,gtzan_music_speech/music_speech/speech_wav/geography.wav,1
88,geography1,gtzan_music_speech/music_speech/speech_wav/geography1.wav,1
89,georg,gtzan_music_speech/music_speech/speech_wav/georg.wav,1
90,god,gtzan_music_speech/music_speech/speech_wav/god.wav,1
91,greek,gtzan_music_speech/music_speech/speech_wav/greek.wav,1
92,greek1,gtzan_music_speech/music_speech/speech_wav/greek1.wav,1
93,india,gtzan_music_speech/music_speech/speech_wav/india.wav,1
94,jony,gtzan_music_speech/music_speech/speech_wav/jony.wav,1
95,jvoice,gtzan_music_speech/music_speech/speech_wav/jvoice.wav,1
96,kedar,gtzan_music_speech/music_speech/speech_wav/kedar.wav,1
97,kid,gtzan_music_speech/music_speech/speech_wav/kid.wav,1
98,lena,gtzan_music_speech/music_speech/speech_wav/lena.wav,1
99,male,gtzan_music_speech/music_speech/speech_wav/male.wav,1
100,my_voice,gtzan_music_speech/music_speech/speech_wav/my_voice.wav,1
101,nether,gtzan_music_speech/music_speech/speech_wav/nether.wav,1
102,news1,gtzan_music_speech/music_speech/speech_wav/news1.wav,1
103,news2,gtzan_music_speech/music_speech/speech_wav/news2.wav,1
104,nj105,gtzan_music_speech/music_speech/speech_wav/nj105.wav,1
105,nj105a,gtzan_music_speech/music_speech/speech_wav/nj105a.wav,1
106,oneday,gtzan_music_speech/music_speech/speech_wav/oneday.wav,1
107,psychic,gtzan_music_speech/music_speech/speech_wav/psychic.wav,1
108,pulp,gtzan_music_speech/music_speech/speech_wav/pulp.wav,1
109,pulp1,gtzan_music_speech/music_speech/speech_wav/pulp1.wav,1
110,pulp2,gtzan_music_speech/music_speech/speech_wav/pulp2.wav,1
111,relation,gtzan_music_speech/music_speech/speech_wav/relation.wav,1
112,serbian,gtzan_music_speech/music_speech/speech_wav/serbian.wav,1
113,shannon,gtzan_music_speech/music_speech/speech_wav/shannon.wav,1
114,sleep,gtzan_music_speech/music_speech/speech_wav/sleep.wav,1
115,smoke1,gtzan_music_speech/music_speech/speech_wav/smoke1.wav,1
116,smoking,gtzan_music_speech/music_speech/speech_wav/smoking.wav,1
117,stupid,gtzan_music_speech/music_speech/speech_wav/stupid.wav,1
118,teachers,gtzan_music_speech/music_speech/speech_wav/teachers.wav,1
119,teachers1,gtzan_music_speech/music_speech/speech_wav/teachers1.wav,1
120,teachers2,gtzan_music_speech/music_speech/speech_wav/teachers2.wav,1
121,thlui,gtzan_music_speech/music_speech/speech_wav/thlui.wav,1
122,undergrad,gtzan_music_speech/music_speech/speech_wav/undergrad.wav,1
123,vegetables,gtzan_music_speech/music_speech/speech_wav/vegetables.wav,1
124,vegetables1,gtzan_music_speech/music_speech/speech_wav/vegetables1.wav,1
125,vegetables2,gtzan_music_speech/music_speech/speech_wav/vegetables2.wav,1
126,voice,gtzan_music_speech/music_speech/speech_wav/voice.wav,1
127,voices,gtzan_music_speech/music_speech/speech_wav/voices.wav,1
This diff is collapsed.
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment