|
a |
|
b/data.lua |
|
|
1 |
require 'nn' |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
-- Returns a list of tensors for each line in the file |
|
|
5 |
local function loadData(name, maxLoad) |
|
|
6 |
print(name) |
|
|
7 |
if maxLoad == 0 then maxLoad = 1000000000 end |
|
|
8 |
local data = seq.lines(name):take(maxLoad):copy() |
|
|
9 |
function data:size() return #data end |
|
|
10 |
return data |
|
|
11 |
end |
|
|
12 |
|
|
|
13 |
function table_invert(t) |
|
|
14 |
local s={} |
|
|
15 |
for k,v in pairs(t) do s[v]=k end |
|
|
16 |
return s |
|
|
17 |
end |
|
|
18 |
|
|
|
19 |
function createDatasetOneHot(typ, opt) |
|
|
20 |
if (not opt.size) then opt.size = 0 end |
|
|
21 |
if (not opt.batch_size) then opt.batch_size = 1 end |
|
|
22 |
|
|
|
23 |
local seqs = loadData( path.join(opt.data_dir, typ..'.fa'), opt.size) |
|
|
24 |
local size |
|
|
25 |
if opt.size == 0 then size = seqs:size() end |
|
|
26 |
local inputs = {} |
|
|
27 |
local outputs = {} |
|
|
28 |
local alphabet = opt.alphabet |
|
|
29 |
local rev_lookup = {} |
|
|
30 |
for i = 1,#alphabet do |
|
|
31 |
rev_lookup[alphabet:sub(i,i)] = i |
|
|
32 |
end |
|
|
33 |
|
|
|
34 |
local class_labels = table_invert(opt.class_labels_table) |
|
|
35 |
|
|
|
36 |
setmetatable(inputs, {__index = function(self, ind) |
|
|
37 |
local upper_limit = math.min((ind+opt.batch_size),((size/2)+1)) |
|
|
38 |
local batch_len = upper_limit-ind |
|
|
39 |
local seq_len = opt.seq_length |
|
|
40 |
-- len = #seqs[ind*2] |
|
|
41 |
matrix = torch.zeros(batch_len, seq_len):fill(#alphabet+1) |
|
|
42 |
local batch_dim = 1 |
|
|
43 |
for b = ind,upper_limit-1 do |
|
|
44 |
local str=seqs[b*2] -- have to multiply by 2 because of the way the data is set up with >1 and then sequence |
|
|
45 |
for i = 1,math.min(#str,seq_len),1 do |
|
|
46 |
if rev_lookup[str:sub(i,i)] then |
|
|
47 |
matrix[batch_dim][i] = rev_lookup[str:sub(i,i)] |
|
|
48 |
end |
|
|
49 |
end |
|
|
50 |
batch_dim = batch_dim+1 |
|
|
51 |
end |
|
|
52 |
return matrix |
|
|
53 |
end}) |
|
|
54 |
|
|
|
55 |
|
|
|
56 |
setmetatable(outputs, {__index = function(self, ind) |
|
|
57 |
local upper_limit = math.min((ind+opt.batch_size),((size/2)+1)) |
|
|
58 |
local batch_len = upper_limit-ind |
|
|
59 |
local labels = torch.ones(batch_len) |
|
|
60 |
local batch_dim = 1 |
|
|
61 |
for i = ind,upper_limit-1 do |
|
|
62 |
local line = seqs[(i*2)-1]:gsub(' ','')--get label from i*2 -1 (because of FASTA format) |
|
|
63 |
local label = line:split('>')[1] |
|
|
64 |
label = class_labels[label] |
|
|
65 |
labels[batch_dim] = torch.Tensor({label}) |
|
|
66 |
batch_dim = batch_dim+1 |
|
|
67 |
end |
|
|
68 |
return labels |
|
|
69 |
end}) |
|
|
70 |
|
|
|
71 |
function inputs:size() return size/2 end |
|
|
72 |
function outputs:size() return size/2 end |
|
|
73 |
return {inputs=inputs, labels=outputs} |
|
|
74 |
end |