Diff of /data.lua [000000] .. [6d0c6b]

b/data.lua

require 'torch'
require 'nn'
require 'pl'  -- Penlight: supplies the global `seq` and `path` modules used below

-- Returns an indexable list of the lines in the file, at most maxLoad of them
-- (maxLoad == 0 means "no limit").
local function loadData(name, maxLoad)
    print(name)
    if maxLoad == 0 then maxLoad = 1000000000 end
    local data = seq.lines(name):take(maxLoad):copy()
    function data:size() return #data end
    return data
end
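
-- A sketch of the input this loader assumes: two-line FASTA records with no
-- wrapped sequence lines, so headers land on odd line indices and sequences
-- on even ones -- which is what the b*2 / i*2-1 indexing below relies on:
--
--   >0           (header line carrying the class label)
--   ACGTGGTCA    (sequence line)
--   >1
--   TTACGGA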

-- Returns a table with keys and values swapped.
function table_invert(t)
    local s = {}
    for k, v in pairs(t) do s[v] = k end
    return s
end
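
-- For example (assumed shape -- the real table arrives via opt): with
-- opt.class_labels_table = {'0', '1'}, table_invert gives {['0']=1, ['1']=2},
-- i.e. a map from a header's label text to a 1-based class index.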

-- Builds a lazily-evaluated dataset from <data_dir>/<typ>.fa: indexing the
-- returned inputs/labels tables at i materializes the batch starting at
-- record i.
function createDatasetOneHot(typ, opt)
    if (not opt.size) then opt.size = 0 end
    if (not opt.batch_size) then opt.batch_size = 1 end

    local seqs = loadData(path.join(opt.data_dir, typ..'.fa'), opt.size)
    -- loadData already caps the line count at opt.size, so the loaded count is
    -- authoritative either way (the original left `size` nil whenever
    -- opt.size ~= 0, which crashed at the size/2 arithmetic below).
    local size = seqs:size()
    local inputs = {}
    local outputs = {}

    -- Map each alphabet character to its 1-based index.
    local alphabet = opt.alphabet
    local rev_lookup = {}
    for i = 1, #alphabet do
        rev_lookup[alphabet:sub(i, i)] = i
    end

    -- Label text -> class index.
    local class_labels = table_invert(opt.class_labels_table)

    -- inputs[ind]: a batch_len x seq_length matrix of character indices for
    -- the batch of sequences starting at record ind. Positions outside the
    -- alphabet, or past the end of a sequence, hold the padding index
    -- #alphabet+1.
    setmetatable(inputs, {__index = function(self, ind)
        local upper_limit = math.min(ind + opt.batch_size, (size / 2) + 1)
        local batch_len = upper_limit - ind
        local seq_len = opt.seq_length
        local matrix = torch.zeros(batch_len, seq_len):fill(#alphabet + 1)
        local batch_dim = 1
        for b = ind, upper_limit - 1 do
            -- Sequence b lives on line b*2: FASTA alternates '>' header lines
            -- and sequence lines.
            local str = seqs[b * 2]
            for i = 1, math.min(#str, seq_len) do
                if rev_lookup[str:sub(i, i)] then
                    matrix[batch_dim][i] = rev_lookup[str:sub(i, i)]
                end
            end
            batch_dim = batch_dim + 1
        end
        return matrix
    end})

    -- labels[ind]: a batch_len vector of class indices for the same batch.
    setmetatable(outputs, {__index = function(self, ind)
        local upper_limit = math.min(ind + opt.batch_size, (size / 2) + 1)
        local batch_len = upper_limit - ind
        local labels = torch.ones(batch_len)
        local batch_dim = 1
        for i = ind, upper_limit - 1 do
            -- The header for record i is line i*2-1; strip spaces and take the
            -- text after '>' as the label (this pattern match replaces the
            -- original split('>')[1], which depended on an empty-field-skipping
            -- string.split being loaded elsewhere).
            local line = seqs[(i * 2) - 1]:gsub(' ', '')
            local label = line:match('^>(.*)')
            labels[batch_dim] = class_labels[label]
            batch_dim = batch_dim + 1
        end
        return labels
    end})

    -- Number of records: each FASTA record is two lines (header + sequence).
    function inputs:size() return size / 2 end
    function outputs:size() return size / 2 end

    return {inputs = inputs, labels = outputs}
end
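
A minimal usage sketch, assuming a data/train.fa file and made-up option values (everything here except createDatasetOneHot itself is a placeholder, not taken from this diff). Since the dataset is lazy, a training loop just steps the record index by batch_size:

local opt = {
    data_dir = 'data',                -- hypothetical: directory holding train.fa
    alphabet = 'ACGT',                -- chars map to 1..4; 5 acts as padding
    seq_length = 100,                 -- batches are batch_size x seq_length
    batch_size = 16,
    size = 0,                         -- 0 = load every line of the file
    class_labels_table = {'0', '1'},  -- hypothetical label strings
}

local dataset = createDatasetOneHot('train', opt)
for i = 1, dataset.inputs:size(), opt.batch_size do
    local x = dataset.inputs[i]   -- built on demand by the __index above
    local y = dataset.labels[i]
    -- hand x, y to a model/criterion here
end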