Diff of /class_optimization.lua [000000] .. [6d0c6b]

Switch to unified view

a b/class_optimization.lua
1
require 'torch'
2
require 'nn'
3
require 'optim'
4
require 'model'
5
require 'image'
6
require 'cutorch'
7
require 'cunn'
8
require('gnuplot')
9
require('lfs')
10
require('data')
11
data = {}
12
dtype = 'torch.CudaTensor'
13
14
-- specify directories
15
model_root = 'models/'
16
data_root = 'data/deepbind/'
17
viz_dir = 'visualization_results/'
18
19
-- ****************************************************************** --
20
-- ****************** CHANGE THESE FIELDS *************************** --
21
TFs = {'ATF1_K562_ATF1_-06-325-_Harvard'}
22
cnn_model_name = 'model=CNN,cnn_size=128,cnn_filters=9-5-3,dropout=0.5,learning_rate=0.01,batch_size=256'
23
rnn_model_name = 'model=RNN,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'
24
cnnrnn_model_name = 'model=CNN-RNN,cnn_size=128,cnn_filter=9,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'
25
26
model_names = {rnn_model_name,cnn_model_name,cnnrnn_model_name}
27
-- ****************************************************************** --
28
-- ****************************************************************** --
29
30
alphabet = 'ACGT'
31
OneHot = OneHot(#alphabet):type(dtype)
32
crit = nn.ClassNLLCriterion():type(dtype) --c
33
34
35
start_pos = 1
36
end_pos = start_pos + 0
37
38
39
lambda = 0.009
40
config = {learningRate=.05,momentum=0.9}
41
iterations = 1000
42
43
44
for _,TF in pairs(TFs) do
45
  print(TF)
46
  save_path = viz_dir..TF..'/'
47
  os.execute('mkdir '..save_path..' > /dev/null 2>&1')
48
49
  -- Load Models
50
  models = {}
51
  for _,model_name in pairs(model_names) do
52
    load_path = model_root..model_name..'/'..TF..'/'
53
    model = torch.load(load_path..'best_model.t7')
54
    model:evaluate()
55
    model.model:type(dtype)
56
57
    models[model_name] = model
58
  end
59
60
  --#######################################################################--
61
  --######################### CLASS OPTIMIZATION ##########################--
62
  --#######################################################################--
63
64
  for model_name, model in pairs(models) do
65
    print('\n ****** Optimizing '..model_name..' *******\n')
66
    print(model.model)
67
    model.model:remove(1)
68
    model:resetStates()
69
70
    motif = torch.rand(1,101,4):type(dtype)
71
    target = torch.Tensor({1}):type(dtype)
72
73
    -- motif weight update
74
    feval = function(X)
75
      local output = model:forward(X)
76
      local loss = crit:forward(output[1], target)
77
      local df_do = crit:backward(output[1], target)
78
      local inputGrads = model:backward(motif, df_do)
79
      return (loss + lambda*(X:norm())^2), (inputGrads + X*2*lambda)
80
    end
81
82
    -- SGD Loop
83
    for i =  1,iterations do
84
      motif,f = optim.rmsprop(feval,motif,config)
85
      print(f[1])
86
    end
87
88
    -- resize
89
    motif = motif[1]:type(dtype)
90
91
    -- clamp to values in (0,1)
92
    motif:clamp(0,1)
93
94
    max = motif:max()
95
    for i = 1,101 do
96
      sum = motif[i]:sum()
97
      if sum == 0 then
98
        motif[i] = torch.zeros(4)
99
      else
100
        for j = 1,4 do motif[i][j] = motif[i][j]/max end
101
      end
102
    end
103
104
    for i = 1,101 do
105
      --add smoothing constant
106
      for j = 1,4 do motif[i][j] = motif[i][j]+0.01 end
107
      --normalize
108
      sum = motif[i]:sum()
109
      for j = 1,4 do motif[i][j] = motif[i][j]/sum end
110
    end
111
112
113
    s2l_filename = save_path..model_name..'_optimization.txt'
114
    optimization_file = io.open(s2l_filename, 'w')
115
    optimization_file:write('PO ')
116
    alphabet:gsub(".",function(c) optimization_file:write(tostring(c)..' ') end)
117
    optimization_file:write('\n')
118
    for i=1,motif:size(1) do
119
      optimization_file:write(tostring(i)..' ')
120
      for j=1,motif:size(2) do
121
        optimization_file:write(tostring(motif[i][j])..' ')
122
      end
123
      optimization_file:write('\n')
124
    end
125
    optimization_file:close()
126
    cmd = "weblogo -D transfac -F png -o "..save_path..model_name.."_optimization.png --errorbars NO --show-xaxis NO --show-yaxis NO -A dna --composition none -n 101 --color '#00CC00' 'A' 'A' --color '#0000CC' 'C' 'C' --color '#FFB300' 'G' 'G' --color '#CC0000' 'T' 'T' < "..s2l_filename
127
    os.execute(cmd)
128
129
  end
130
131
132
133
  print('')
134
  print(lfs.currentdir()..'/'..save_path)
135
  os.execute('rm '..save_path..'/*.csv > /dev/null 2>&1')
136
  os.execute('rm '..save_path..'/*.txt > /dev/null 2>&1')
137
138
end