-- class_optimization.lua
require 'torch'
require 'nn'
require 'optim'
require 'model'
require 'image'
require 'cutorch'
require 'cunn'
require('gnuplot')
require('lfs')
require('data')

-- Script-wide settings.
-- NOTE(review): every name in this section is deliberately left global, as in
-- the original script; the project modules required above (`model`, `data`)
-- may read some of them -- confirm before converting any to `local`.

-- NOTE(review): this rebinds the global `data` to an empty table right after
-- require('data'); confirm it does not clobber something that module exported.
data = {}
-- Tensor type every model, criterion and input is cast to (GPU).
dtype = 'torch.CudaTensor'

-- specify directories
model_root = 'models/'
data_root = 'data/deepbind/'
viz_dir = 'visualization_results/'

-- ****************************************************************** --
-- ****************** CHANGE THESE FIELDS *************************** --
TFs = {'ATF1_K562_ATF1_-06-325-_Harvard'}
cnn_model_name = 'model=CNN,cnn_size=128,cnn_filters=9-5-3,dropout=0.5,learning_rate=0.01,batch_size=256'
rnn_model_name = 'model=RNN,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'
cnnrnn_model_name = 'model=CNN-RNN,cnn_size=128,cnn_filter=9,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'

-- Models are loaded from model_root/<model name>/<TF>/best_model.t7.
model_names = {rnn_model_name,cnn_model_name,cnnrnn_model_name}
-- ****************************************************************** --
-- ****************************************************************** --

-- Symbols written, in this order, as the column header of the matrix file.
alphabet = 'ACGT'
-- Calls the global `OneHot` (presumably defined by a required project module
-- -- TODO confirm) and rebinds the same global name to the resulting instance,
-- so the constructor is no longer reachable under this name afterwards.
OneHot = OneHot(#alphabet):type(dtype)
-- Criterion whose value is the data term of the optimization objective.
crit = nn.ClassNLLCriterion():type(dtype)

-- Not referenced by the optimization loop in this file.
start_pos = 1
end_pos = start_pos + 0

-- Weight of the squared-L2 penalty added to the loss in the objective.
lambda = 0.009
-- Passed to optim.rmsprop.
-- NOTE(review): rmsprop appears to read learningRate/alpha/epsilon/weightDecay
-- only, so `momentum` is probably ignored -- confirm against the optim docs.
config = {learningRate=.05,momentum=0.9}
-- Number of optimizer steps per model.
iterations = 1000

-- Class optimization driver.
-- For every entry of TFs: load each model named in model_names, optimize a
-- (1 x seq_len x #alphabet) input tensor against target class index 1 with an
-- L2 penalty, post-process it into per-position values that sum to 1, write it
-- as a position matrix text file and render it with the external `weblogo`
-- command. Reads the globals viz_dir, model_root, model_names, TFs, dtype,
-- alphabet, crit, lambda, config and iterations.

-- Number of positions in the optimized input (also passed to weblogo -n).
local seq_len = 101
-- Number of columns per position, one per alphabet symbol.
local n_symbols = #alphabet

for _, TF in ipairs(TFs) do
  print(TF)
  local save_path = viz_dir..TF..'/'
  -- -p also creates viz_dir when it is missing; a plain mkdir would fail
  -- silently there and the later io.open would then return nil.
  os.execute('mkdir -p '..save_path..' > /dev/null 2>&1')

  -- Load Models (all of them first, so a missing checkpoint fails before any
  -- optimization work is done)
  local models = {}
  for _, model_name in ipairs(model_names) do
    local load_path = model_root..model_name..'/'..TF..'/'
    local model = torch.load(load_path..'best_model.t7')
    model:evaluate()
    model.model:type(dtype)
    models[model_name] = model
  end

  --#######################################################################--
  --######################### CLASS OPTIMIZATION ##########################--
  --#######################################################################--

  -- Iterate model_names (not pairs(models)) so the processing order is fixed.
  for _, model_name in ipairs(model_names) do
    local model = models[model_name]
    print('\n ****** Optimizing '..model_name..' *******\n')
    print(model.model)
    -- Drops the first module of the network. Presumably that is the
    -- one-hot/embedding layer, so the continuous tensor below can be fed in
    -- directly -- TODO confirm against the model definition.
    model.model:remove(1)
    model:resetStates()

    -- Random starting point and the class index to optimize for.
    local motif = torch.rand(1, seq_len, n_symbols):type(dtype)
    local target = torch.Tensor({1}):type(dtype)

    -- motif weight update: returns the penalized loss and its gradient with
    -- respect to the input X.
    local feval = function(X)
      local output = model:forward(X)
      local loss = crit:forward(output[1], target)
      local df_do = crit:backward(output[1], target)
      -- Back-propagate through the same tensor that was forwarded.
      local inputGrads = model:backward(X, df_do)
      return (loss + lambda*(X:norm())^2), (inputGrads + X*2*lambda)
    end

    -- Fresh optimizer state per model: without an explicit state table
    -- rmsprop keeps its running statistics inside the shared `config`, which
    -- would carry them over from one model (and TF) to the next.
    local rmsprop_state = {}

    -- Optimization loop
    for _ = 1, iterations do
      local fx
      motif, fx = optim.rmsprop(feval, motif, config, rmsprop_state)
      print(fx[1])
    end

    -- resize: drop the leading batch dimension
    motif = motif[1]:type(dtype)

    -- clamp to values in [0,1]
    motif:clamp(0, 1)

    -- Scale by the global maximum; skipped when everything was clamped to 0.
    local peak = motif:max()
    if peak > 0 then
      motif:div(peak)
    end

    -- add smoothing constant so no entry is exactly zero
    motif:add(0.01)
    -- normalize every position so its entries sum to 1
    for i = 1, motif:size(1) do
      local row = motif[i] -- view into motif, modified in place
      row:div(row:sum())
    end

    -- Write the matrix: header "PO A C G T", then one numbered row per
    -- position.
    local s2l_filename = save_path..model_name..'_optimization.txt'
    local optimization_file = assert(io.open(s2l_filename, 'w'))
    optimization_file:write('PO ')
    for symbol in alphabet:gmatch('.') do
      optimization_file:write(symbol..' ')
    end
    optimization_file:write('\n')
    for i = 1, motif:size(1) do
      optimization_file:write(tostring(i)..' ')
      for j = 1, motif:size(2) do
        optimization_file:write(tostring(motif[i][j])..' ')
      end
      optimization_file:write('\n')
    end
    optimization_file:close()

    -- Render the matrix as a PNG logo next to the text file.
    local cmd = "weblogo -D transfac -F png -o "..save_path..model_name.."_optimization.png --errorbars NO --show-xaxis NO --show-yaxis NO -A dna --composition none -n "..seq_len.." --color '#00CC00' 'A' 'A' --color '#0000CC' 'C' 'C' --color '#FFB300' 'G' 'G' --color '#CC0000' 'T' 'T' < "..s2l_filename
    os.execute(cmd)
  end

  print('')
  print(lfs.currentdir()..'/'..save_path)
  -- Remove intermediate files; only the rendered images are kept.
  os.execute('rm '..save_path..'/*.csv > /dev/null 2>&1')
  os.execute('rm '..save_path..'/*.txt > /dev/null 2>&1')
end