a b/src/linearsvm_cancer.py
1
import time
2
from datetime import datetime
3
import csv
4
import numpy as np
5
from sklearn import svm
6
from sklearn.decomposition import PCA
7
from sklearn.preprocessing import StandardScaler
8
from sklearn.pipeline import Pipeline
9
from sklearn.cross_validation import train_test_split
10
from sklearn.cross_validation import StratifiedShuffleSplit
11
from sklearn.grid_search import GridSearchCV
12
from sklearn.cross_validation import StratifiedKFold
13
14
15
# Linear-SVM cancer-type classifier.
#
# Loads the sample matrix, shuffles it reproducibly, and holds out a stratified
# 25% test split. It then grid-searches the PCA size and the SVM C with 10-fold
# CV and prints the test score plus the best CV result. Finally it writes the
# mean CV score of every grid point to a timestamped CSV file.
#
# print() is always called with exactly one argument. Python 2 (print
# statement with a parenthesised expression) and Python 3 (print function)
# therefore produce the same output.
import sys  # local import: only needed to choose the csv file mode below

print("Script start at  " + datetime.now().isoformat())

# Columns 0-2 are metadata (patient_id, cancer_type, tissue_type).
# Columns 3+ are the rpm feature values.
X = np.load('F:/NYU/Hackathon/numpy_array.npy')
Y = X[:, :3]
X = X[:, 3:]

# Shuffle rows reproducibly. The permutation size now comes from the data
# instead of the old hard-coded 678. A fixed count silently drops rows when the
# file has more, and raises IndexError when it has fewer. For a 678-row file
# the permutation is exactly the same as before.
RS = np.random.RandomState(90)
perm = RS.permutation(len(X))
Y = Y[perm]
X = X[perm]

# Hold out 25% for the final test. Stratify on cancer_type (column 1) so the
# class proportions match between the train and test splits.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y[:, 1], test_size=0.25, random_state=30, stratify=Y[:, 1])

# NOTE(review): PCA runs before StandardScaler, so PCA sees unscaled rpm values
# and high-variance features dominate the components. The usual order is to
# scale first. Kept as-is to preserve the experiment; confirm it is intended.
pipe = Pipeline([
    ('pca', PCA()),
    ('scaled', StandardScaler()),
    ('svm_linear', svm.SVC(kernel='linear', C=1, class_weight='balanced')),
])

# C grid: 2**-9 .. 2**9. The values and types match the original literal list:
# floats for negative exponents, ints otherwise.
cval = [2 ** k for k in range(-9, 10)]

# NOTE(review): at 678 rows (the count the script originally hard-coded), the
# 25% hold-out leaves 508 training rows. 10-fold CV then trains on roughly 457
# rows per fold, so n_components=1046 may exceed min(n_samples, n_features).
# Newer scikit-learn PCA versions reject that; confirm against the installed
# version.
pca_val = [1, 2, 4, 13, 1046]

# 10-fold CV over every (n_components, C) pair. verbose=100 logs each fit.
gs = GridSearchCV(pipe, dict(pca__n_components=pca_val, svm_linear__C=cval),
                  cv=10, verbose=100)
gs.fit(X_train, Y_train)

# Accuracy on the held-out split, using the estimator chosen by the search.
score = gs.score(X_test, Y_test)

print(score)
print('best_score')
print(gs.best_score_)
print('best_estimator')
print(gs.best_estimator_)
print('best_params')
print(gs.best_params_)


def _open_csv_for_writing(path):
    """Open *path* for csv.writer correctly on both Python 2 and Python 3.

    Python 2's csv module expects a binary-mode file. Python 3's expects a
    text-mode file opened with newline=''. Using plain "w" on Windows inserts
    a blank line after every row.
    """
    if sys.version_info[0] < 3:
        return open(path, "wb")
    return open(path, "w", newline="")


outfile = "grid_linearsvm_cancer_search_scores_{0}.out".format(int(time.time()))

# One row per grid point: mean CV score, then that point's parameter values.
# grid_scores_ is the pre-0.20 scikit-learn API, matching this file's
# sklearn.cross_validation / sklearn.grid_search imports.
with _open_csv_for_writing(outfile) as scoreFile:
    writer = csv.writer(scoreFile, delimiter=",")
    # Sort the parameter names so the column order does not depend on dict
    # ordering.
    paramKeys = sorted(gs.grid_scores_[0].parameters)
    writer.writerow(['mean'] + paramKeys)
    for entry in gs.grid_scores_:
        writer.writerow([entry.mean_validation_score] +
                        [entry.parameters.get(k) for k in paramKeys])

print("Script end at  " + datetime.now().isoformat())