Switch to unified view

a b/temporal_output_values.lua
1
require 'torch'
2
require 'nn'
3
require 'optim'
4
require 'model'
5
include('util/auRoc.lua')
6
require 'lfs'
7
8
local cmd = torch.CmdLine()
9
10
11
-- GPU
12
cmd:option('-gpu', 1) -- set to 0 if no GPU
13
14
-- Dataset options
15
cmd:option('-data_root', 'data') -- data root directory
16
cmd:option('-dataset', 'deepbind') -- dataset
17
cmd:option('-seq_length', 101) --length of DNA sequences
18
cmd:option('-TF', 'ATF1_K562_ATF1_-06-325-_Harvard') -- change for different TF
19
cmd:option('-alphabet', 'ACGT')
20
cmd:option('-size', 0) -- how much of each dataset to load. 0 = full
21
cmd:option('-batch_size', 1)
22
cmd:option('class_labels','1,0') --specify positive label first
23
24
25
local opt = cmd:parse(arg)
26
27
opt.class_labels_table = opt.class_labels:split(',')
28
opt.num_classes = #opt.class_labels_table
29
opt.alphabet_size = #opt.alphabet
30
31
local data_dir = opt.data_root..'/'..opt.dataset..'/'
32
33
34
-- Set up GPU stuff
35
local dtype = 'torch.FloatTensor'
36
if opt.gpu > 0  then
37
  collectgarbage()
38
  require 'cutorch'
39
  require 'cunn'
40
  cutorch.setDevice(opt.gpu )
41
  dtype = 'torch.CudaTensor'
42
  print(string.format('Running with CUDA on GPU %d', opt.gpu))
43
else
44
  print 'Running in CPU mode'
45
end
46
47
48
49
local data_dir = opt.data_root..'/'..opt.dataset..'/'
50
51
opt.TF = TF or opt.TF
52
opt.data_dir = data_dir..opt.TF
53
54
55
-- specify directories
56
model_root = 'models'
57
data_root = 'data/deepbind/'
58
viz_dir = 'visualization_results/'
59
60
-- ****************************************************************** --
61
-- ****************** CHANGE THESE FIELDS *************************** --
62
TFs = {'ATF1_K562_ATF1_-06-325-_Harvard'}
63
rnn_model = 'model=RNN,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'
64
cnnrnn_model = 'model=CNN-RNN,cnn_size=128,cnn_filter=9,rnn_size=32,rnn_layers=1,dropout=0.5,learning_rate=0.01,batch_size=256'
65
66
model_names = {rnn_model,cnnrnn_model} --add or remove to this
67
68
-- which sequences in the test set to show temporal outputs for
69
start_seq = 1
70
end_seq = start_seq + 0
71
-- ****************************************************************** --
72
-- ****************************************************************** --
73
74
75
alphabet = opt.alphabet
76
rev_dictionary = {}
77
dictionary = {}
78
for i = 1,#alphabet do
79
  rev_dictionary[i] = alphabet:sub(i,i)
80
  dictionary[alphabet:sub(i,i)] = i
81
end
82
83
OneHot = OneHot(#alphabet):type(dtype)
84
crit = nn.ClassNLLCriterion():type(dtype)
85
86
87
for _,TF in pairs(TFs) do
88
  print(TF)
89
  save_path = viz_dir..TF..'/'
90
  os.execute('mkdir '..save_path..' > /dev/null 2>&1')
91
  -- os.execute('rm '..save_path..'/*.csv > /dev/null 2>&1')
92
  -- os.execute('rm '..save_path..'*.png > /dev/null 2>&1')
93
94
95
  data_dir = data_root..TF
96
  opt.data_dir = data_dir
97
98
  require('data')
99
  data = {}
100
  test_seqs = createDatasetOneHot("test", opt)
101
102
103
  -- Load Models into models table
104
  models = {}
105
  for _,model_name in pairs(model_names) do
106
    print()
107
    load_path = model_root..'/'..model_name..'/'..TF..'/'
108
    model = torch.load(load_path..'best_model.t7')
109
    model.model:remove(1)
110
    model:evaluate()
111
    model.model:type(dtype)
112
    models[model_name] = model
113
  end
114
115
116
  for t = start_seq,end_seq do
117
    x = test_seqs.inputs[t]:type(dtype)
118
    X = OneHot:forward(x)
119
    y = test_seqs.labels[t]:type(dtype)
120
121
    --####################### CREATE SEQ LOGO ###############################--
122
    s2l_filename = save_path..'sequence_'..t..'.txt'
123
    f = io.open(s2l_filename, 'w')
124
    print(s2l_filename)
125
    f:write('PO ')
126
    alphabet:gsub(".",function(c) f:write(tostring(c)..' ') end)
127
    f:write('\n')
128
    for i=1,X[1]:size(1) do
129
      f:write(tostring(i)..' ')
130
      for j=1,X[1]:size(2) do
131
        f:write(tostring(X[1][i][j])..' ')
132
      end
133
      f:write('\n')
134
    end
135
    f:close()
136
    cmd = "weblogo -D transfac -F png -o "..save_path.."sequence_"..t..".png --errorbars NO --show-xaxis NO --show-yaxis NO -A dna --composition none -n 101 --color '#00CC00' 'A' 'A' --color '#0000CC' 'C' 'C' --color '#FFB300' 'G' 'G' --color '#CC0000' 'T' 'T' < "..s2l_filename
137
    os.execute(cmd)
138
139
    --####################### TEMPORAL OUTPUT ###############################--
140
    for model_name, model in pairs(models) do
141
      print('***** '..model_name..' *****')
142
      out_file_fwd = io.open(save_path..model_name..'_output_values_fwd_'..t..'.csv', 'w')
143
      out_file_bwd = io.open(save_path..model_name..'_output_values_bwd_'..t..'.csv', 'w')
144
145
      -- need to get CNN output column vectors to be fed into RNN output
146
      if string.match(model.model:__tostring__(),'Convolution') then
147
        CNN = model.model:get(1)
148
        model.model:remove(1)
149
        X_in = CNN:forward(X)
150
      else
151
        X_in = X
152
      end
153
154
      -- FORWARD
155
      for i = 1,X_in:size(2) do
156
        model:resetStates()
157
        model:zeroGradParameters()
158
        output = model:forward(X_in[{{1,1},{1,i}}])
159
        pos_sent_value = torch.exp(output[1])[1]
160
        out_file_fwd:write(rev_dictionary[x[1][i]]..',')
161
        out_file_fwd:write(pos_sent_value..',\n')
162
      end
163
      -- REVERSE
164
      for i = 1,X_in:size(2) do
165
        model:resetStates()
166
        model:zeroGradParameters()
167
        output = model:forward(X_in[{{1,1},{i,X_in:size(2)}}])
168
        pos_sent_value = torch.exp(output[1])[1]
169
        out_file_bwd:write(rev_dictionary[x[1][i]]..',')
170
        out_file_bwd:write(pos_sent_value..',\n')
171
      end
172
173
      out_file_fwd:write('\n')
174
      out_file_bwd:write('\n')
175
      out_file_fwd:close()
176
      out_file_bwd:close()
177
178
179
      cmd = 'Rscript ./heatmap_scripts/heatmap_temporal.R '..save_path..model_name..'_output_values_fwd_'..t..'.csv '..save_path..model_name..'_output_values_fwd_'..t..'.png -20'
180
      os.execute(cmd..' > /dev/null 2>&1')
181
      cmd = 'Rscript ./heatmap_scripts/heatmap_temporal.R '..save_path..model_name..'_output_values_bwd_'..t..'.csv '..save_path..model_name..'_output_values_bwd_'..t..'.png -20'
182
      os.execute(cmd..' > /dev/null 2>&1')
183
    end -- model in models
184
185
  end -- test sequences
186
187
  print('')
188
  print(lfs.currentdir()..'/'..save_path)
189
  os.execute('rm '..save_path..'/*.csv > /dev/null 2>&1')
190
  os.execute('rm '..save_path..'/*.txt > /dev/null 2>&1')
191
192
end -- TFs