Diff of /data.lua [000000] .. [6d0c6b]

Switch to side-by-side view

--- a
+++ b/data.lua
@@ -0,0 +1,74 @@
+require 'nn'
+
+
-- Load up to maxLoad lines from the file `name` into a list-like object.
-- A maxLoad of 0 means "no limit" (internally capped at 1e9 lines).
-- NOTE(review): relies on a global `seq` (penlight sequence library) being
-- loaded elsewhere — confirm the entry point requires it.
local function loadData(name, maxLoad)
    print(name)
    local limit = (maxLoad == 0) and 1000000000 or maxLoad
    local data = seq.lines(name):take(limit):copy()
    -- Attach a size() method so the result quacks like a torch dataset.
    function data:size() return #data end
    return data
end
+
-- Invert a table: each value of `t` becomes a key mapping back to its key.
-- (Intentionally global: used by createDatasetOneHot below.)
function table_invert(t)
   local inverted = {}
   for key, value in pairs(t) do
      inverted[value] = key
   end
   return inverted
end
+
-- Build a lazily-indexed dataset from a FASTA file for one-hot encoding.
-- typ: basename of the data split (loads <opt.data_dir>/<typ>.fa).
-- opt: table with fields
--   data_dir, seq_length, alphabet, class_labels_table,
--   size (0 = load everything; default 0), batch_size (default 1).
-- Returns {inputs=..., labels=...}: two virtual tables whose __index
-- metamethods materialize one batch (a tensor) on demand.
function createDatasetOneHot(typ, opt)
    if (not opt.size) then opt.size = 0 end
    if (not opt.batch_size) then opt.batch_size = 1 end

    local seqs = loadData( path.join(opt.data_dir, typ..'.fa'), opt.size)
    -- BUGFIX: the original assigned `size` only when opt.size == 0, leaving
    -- it nil for any explicit opt.size and crashing on `size/2` in the
    -- metamethods below. loadData already caps how many lines are read, so
    -- seqs:size() is the correct count in every case.
    local size = seqs:size()
    local inputs = {}
    local outputs = {}
    local alphabet = opt.alphabet
    local rev_lookup = {}
    for i = 1,#alphabet do
      rev_lookup[alphabet:sub(i,i)] = i
    end

    local class_labels = table_invert(opt.class_labels_table)

    -- inputs[ind] -> (batch_len x seq_length) tensor of alphabet indices.
    -- Positions beyond a sequence's end, or characters outside the alphabet,
    -- keep the padding value #alphabet+1.
    setmetatable(inputs, {__index = function(self, ind)
      local upper_limit = math.min((ind+opt.batch_size),((size/2)+1))
      local batch_len = upper_limit-ind
      local seq_len = opt.seq_length
      -- BUGFIX: `matrix` leaked into the global namespace; declare it local.
      local matrix = torch.zeros(batch_len, seq_len):fill(#alphabet+1)
      local batch_dim = 1
      for b = ind,upper_limit-1 do
        -- Sequences sit on even lines: FASTA alternates ">header" / sequence.
        local str = seqs[b*2]
        for i = 1,math.min(#str,seq_len),1 do
          if rev_lookup[str:sub(i,i)] then
            matrix[batch_dim][i] = rev_lookup[str:sub(i,i)]
          end
        end
        batch_dim = batch_dim+1
      end
      return matrix
    end})

    -- outputs[ind] -> length-batch_len vector of numeric class labels parsed
    -- from the FASTA header lines (odd lines).
    setmetatable(outputs, {__index = function(self, ind)
      local upper_limit = math.min((ind+opt.batch_size),((size/2)+1))
      local batch_len = upper_limit-ind
      local labels = torch.ones(batch_len)
      local batch_dim = 1
      for i = ind,upper_limit-1 do
        -- Header is on line i*2-1 (FASTA layout); strip spaces, drop the '>'.
        -- NOTE(review): string:split is a torch/penlight extension, not core Lua.
        local line = seqs[(i*2)-1]:gsub(' ','')
        local name = line:split('>')[1]
        local label = class_labels[name]
        -- BUGFIX: an unknown class previously failed with a cryptic nil
        -- assignment (wrapped in a needless torch.Tensor); fail loudly instead.
        if label == nil then
          error(string.format("createDatasetOneHot: unknown class label '%s' in %s.fa",
                              tostring(name), typ))
        end
        labels[batch_dim] = label
        batch_dim = batch_dim+1
      end
      return labels
    end})

    -- Number of (header, sequence) pairs in the file.
    function inputs:size() return size/2 end
    function outputs:size() return size/2 end
    return {inputs=inputs, labels=outputs}
end