Diff of /R/generator_rds.R [000000] .. [409433]

Switch to side-by-side view

--- a
+++ b/R/generator_rds.R
@@ -0,0 +1,456 @@
+#' Rds data generator
+#' 
+#' Creates training batches from rds files. Rds files must contain a  
+#' list of length 2 (input/target) or of length 1 (for language model).
+#' If \code{target_len} is not NULL will take the last \code{target_len} entries of 
+#' the first list element as targets and the rest as input.    
+#' 
+#' @inheritParams generator_fasta_label_header_csv
+#' @param rds_folder Path to input files.
+#' @param target_len Number of target nucleotides for language model.
+#' @param delete_used_files Whether to delete file once used. Only applies for rds files. 
+#' @examplesIf reticulate::py_module_available("tensorflow")
+#' # create 3 rds files
+#' rds_folder <- tempfile()
+#' dir.create(rds_folder)
+#' batch_size <- 7
+#' maxlen <- 11
+#' voc_len <- 4
+#' for (i in 1:3) {
+#'   x <- sample(0:(voc_len-1), maxlen*batch_size, replace = TRUE)
+#'   x <- keras::to_categorical(x, num_classes = voc_len)
+#'   x <- array(x, dim = c(batch_size, maxlen, voc_len))
+#'   y <- sample(0:2, batch_size ,replace = TRUE)
+#'   y <- keras::to_categorical(y, num_classes = 3)
+#'   xy_list <- list(x, y)
+#'   file_name <- paste0(rds_folder, "/file_", i, ".rds")
+#'   saveRDS(xy_list, file_name) 
+#' }
+#' 
+#' # create generator
+#' gen <- generator_rds(rds_folder, batch_size = 2)
+#' z <- gen()
+#' x <- z[[1]]
+#' y <- z[[2]]
+#' x[1, , ]
+#' y[1, ]
+#' 
+#' @returns A generator function.  
+#' @export
+generator_rds <- function(rds_folder, batch_size, path_file_log = NULL,
+                          max_samples = NULL,
+                          proportion_per_seq = NULL,
+                          target_len = NULL,
+                          seed = NULL, delete_used_files = FALSE,
+                          reverse_complement = FALSE,
+                          sample_by_file_size = FALSE,
+                          n_gram = NULL, n_gram_stride = 1,
+                          reverse_complement_encoding = FALSE,
+                          add_noise = NULL,
+                          reshape_xy = NULL) {
+  
+  if (!is.null(reshape_xy)) {
+    reshape_xy_bool <- TRUE
+    reshape_x_bool <- ifelse(is.null(reshape_xy$x), FALSE, TRUE)
+    if (reshape_x_bool && !all(c('x', 'y') %in% names(formals(reshape_xy$x)))) {
+      stop("function reshape_xy$x needs to have arguments named x, y and sw")
+    }
+    reshape_y_bool <- ifelse(is.null(reshape_xy$y), FALSE, TRUE)
+    if (reshape_y_bool && !all(c('x', 'y') %in% names(formals(reshape_xy$y)))) {
+      stop("function reshape_xy$y needs to have arguments named x, y and sw")
+    }
+    reshape_sw_bool <- ifelse(is.null(reshape_xy$sw), FALSE, TRUE)
+    if (reshape_sw_bool && !all(c('x', 'y') %in% names(formals(reshape_xy$sw)))) {
+      stop("function reshape_xy$sw needs to have arguments named x, y and sw")
+    }
+  } else {
+    reshape_xy_bool <- FALSE
+  }
+  
+  if (!is.null(seed)) set.seed(seed)
+  is_lm <- !is.null(target_len)
+  
+  if (!is.null(n_gram) & is_lm && (target_len < n_gram)) {
+    stop("target_len needs to be at least as big as n_gram.")
+  }
+  
+  rds_files <- list_fasta_files(rds_folder, format = "rds", file_filter = NULL)
+  num_files <- length(rds_files)
+  
+  read_success <- FALSE
+  while (!read_success) {
+    tryCatch(
+      expr = {
+        if (sample_by_file_size) {
+          file_prob <- file.info(rds_files)$size/sum(file.info(rds_files)$size)
+          file_index <- sample(1:num_files, size = 1, prob = file_prob)
+        } else {
+          file_prob <- NULL
+          rds_files <- sample(rds_files)
+          file_index <- 1
+        }
+        rds_file <- readRDS(rds_files[file_index])
+        read_success <- TRUE
+      },
+      error = function(e){ 
+        
+      }
+    )
+  }
+  
+  if (delete_used_files) file.remove(rds_files[file_index])
+  
+  if (is.list(rds_file)) {
+    x_complete <- rds_file[[1]]
+    include_sw <- ifelse(length(rds_file) == 3, TRUE, FALSE) 
+  } else {
+    x_complete <- rds_file
+    include_sw <- FALSE
+  }
+  
+  if (!is_lm) y_complete <- rds_file[[2]]
+  # TODO: adjust for different input format (input mix of 3D and 1D etc.)
+  multi_input <- ifelse(is.list(x_complete), TRUE, FALSE)
+  multi_output <- ifelse(length(rds_file) > 1 && is.list(rds_file[[2]]), TRUE, FALSE)
+  
+  if (multi_input) {
+    x_dim_list <- list()
+    size_splits_in <- list()
+    for (i in 1:length(x_complete)) {
+      x_dim_list[[i]] <- dim(x_complete[[i]])
+      size_splits_in[[i]] <- x_dim_list[[i]][length(x_dim_list[[i]])] 
+      if (i > 1) {
+        if (length(x_dim_list[[i]]) != length(x_dim_list[[i-1]])) {
+          stop("rds generator only works if separate inputs have same dimension size")
+        }
+      }
+    }
+    x_complete <- tensorflow::tf$concat(x_complete, 
+                                        axis = as.integer(length(x_dim_list[[1]]) - 1)) %>% as.array()
+  } 
+  x_dim_start <- dim(x_complete)
+  
+  if (!is_lm) {
+    if (multi_output) {
+      y_dim_list <- list()
+      size_splits_out <- list()
+      for (i in 1:length(y_complete)) {
+        y_dim_list[[i]] <- dim(y_complete[[i]])
+        size_splits_out[[i]] <- y_dim_list[[i]][length(y_dim_list[[i]])] 
+        if (i > 1) {
+          if (length(y_dim_list[[i]]) != length(y_dim_list[[i-1]])) {
+            stop("rds generator only works if separate outputs have same dimension size")
+          }
+        }
+      }
+      y_complete <- tensorflow::tf$concat(y_complete,
+                                          axis = as.integer(length(y_dim_list[[1]]) - 1)) %>% as.array()
+    } 
+    y_dim_start <- dim(y_complete)
+    if (is.null(y_dim_start)) y_dim_start <- length(y_complete)
+    if (x_dim_start[1] != y_dim_start[1]) {
+      stop("Different number of samples for input and target")
+    }
+  }
+  
+  if (include_sw) {
+    sw_complete <- rds_file[[3]]
+    multi_sw <- is.list(sw_complete)
+    if (multi_sw) {
+      # TODO:  have temporal/non-temporal in same batch 
+      sw_temporal <- ifelse(dim(sw_complete[[1]])[2] == 1, FALSE, TRUE)
+      sw_dim <- lapply(sw_complete,dim)
+      
+      sw_dim_list <- list()
+      size_splits_sw <- list()
+      for (i in 1:length(sw_complete)) {
+        sw_dim_list[[i]] <- dim(sw_complete[[i]])
+        size_splits_sw[[i]] <- sw_dim_list[[i]][length(sw_dim_list[[i]])] 
+      }
+      sw_complete <- tensorflow::tf$concat(sw_complete, 
+                                           axis = as.integer(length(sw_dim_list[[1]]) - 1)) %>% as.array()
+      
+    } else {
+      sw_temporal <- ifelse(dim(sw_complete)[2] == 1, FALSE, TRUE)
+      sw_dim <- dim(sw_complete) 
+    }
+    sw_dim_start <- dim(sw_complete)
+  }
+  
+  sample_index <- 1:x_dim_start[1]
+  
+  if (!is.null(proportion_per_seq)) {
+    sample_index <- sample(sample_index, min(length(sample_index), length(sample_index) * proportion_per_seq))
+  }
+  
+  if (!is.null(max_samples)) {
+    sample_index <- sample(sample_index, min(length(sample_index), max_samples))
+  }
+  
+  if (!is.null(path_file_log)) {
+    utils::write.table(x = rds_files[1], file = path_file_log, row.names = FALSE, col.names = FALSE)
+  }
+  
+  x_dim <- x_dim_start
+  if (!is_lm) y_dim <- y_dim_start
+  if (include_sw) sw_dim <- sw_dim_start
+  
+  function() {
+    
+    # TODO: adjust for multi input/output
+    x_index <- 1
+    x <- array(0, c(batch_size, x_dim[-1]))
+    if (is_lm) {
+      y <- vector("list", target_len)
+    } else {
+      y <- array(0, c(batch_size, y_dim[2]))
+    }
+    
+    if (include_sw) {
+      sw <- array(0, c(batch_size, sw_dim[2]))
+    }
+    
+    while (x_index <= batch_size) {
+      
+      # go to next file if sample index empty
+      if (length(sample_index) == 0) {
+        
+        if (num_files == 1) {
+          
+          sample_index <<- 1:x_dim[1]
+          if (delete_used_files) {
+            file.remove(rds_files[file_index])
+            read_success <- FALSE
+            while (!read_success) {
+              tryCatch(
+                expr = {
+                  if (sample_by_file_size) {
+                    file_index <<- sample(1:num_files, size = 1, prob = file_prob)
+                  } else {
+                    file_index <<- file_index + 1
+                    if (file_index > num_files) {
+                      file_index <<- 1
+                      rds_files <<- sample(rds_files)
+                    }
+                  }
+                  rds_file <<- readRDS(rds_files[file_index])
+                  read_success <- TRUE
+                },
+                error = function(e){ 
+                  
+                }
+              )
+            }
+          } 
+          
+        } else {
+          
+          if (delete_used_files) file.remove(rds_files[file_index])
+          
+          read_success <- FALSE
+          while (!read_success) {
+            tryCatch(
+              expr = {
+                if (sample_by_file_size) {
+                  file_index <<- sample(1:num_files, size = 1, prob = file_prob)
+                } else {
+                  file_index <<- file_index + 1
+                  if (file_index > num_files) {
+                    file_index <<- 1
+                    rds_files <<- sample(rds_files)
+                  }
+                }
+                rds_file <<- readRDS(rds_files[file_index])
+                read_success <- TRUE
+              },
+              error = function(e){ 
+                
+              }
+            )
+          }
+          
+          if (multi_input) {
+            # combine inputs in one tensor
+            x_complete <<- tensorflow::tf$concat(rds_file[[1]], 
+                                                 axis = as.integer(length(x_dim_list[[1]]) - 1)) %>% as.array()
+          } else {
+            x_complete <<- rds_file[[1]]
+          } 
+          x_dim <<- dim(x_complete)
+          
+          if (include_sw) {
+            if (multi_sw) {
+              sw_complete <- tensorflow::tf$concat(rds_file[[3]], 
+                                                   axis = as.integer(length(sw_dim_list[[1]]) - 1)) %>% as.array()
+            } else {
+              sw_complete <<- rds_file[[3]]
+            }
+          }
+          
+          if (!is_lm) {
+            if (multi_output) {
+              y_complete <- tensorflow::tf$concat(rds_file[[2]], 
+                                                  axis = as.integer(length(y_dim_list[[1]]) - 1)) %>% as.array()
+            } else {
+              y_complete <<- rds_file[[2]]
+            }
+            y_dim <<- dim(y_complete)
+            #if (is.null(y_dim)) y_dim <<- length(y_complete)
+          }
+          
+          if (!is_lm && (x_dim[1] != y_dim[1])) {
+            print(x_dim)
+            print(y_dim)
+            stop("Different number of samples for input and target")
+          }
+          
+          sample_index <<- 1:x_dim[1]
+          if (!is.null(proportion_per_seq)) {
+            if (length(sample_index) > 1) {
+              sample_index <<- sample(sample_index, min(length(sample_index), max(1, floor(length(sample_index) * proportion_per_seq))))
+            }
+          }
+          
+          if (!is.null(max_samples)) {
+            if (length(sample_index) > 1) {
+              sample_index <<- sample(sample_index, min(length(sample_index), max_samples))
+            }
+          }
+        }
+        
+        # log file
+        if (!is.null(path_file_log)) {
+          utils::write.table(x = rds_files[file_index], file = path_file_log, append = TRUE, col.names = FALSE, row.names = FALSE)
+        }
+      }
+      
+      if (length(sample_index) == 0) next
+      if (length(sample_index) == 1) {
+        index <- sample_index
+      } else {
+        index <- sample(sample_index, min(batch_size - x_index + 1, length(sample_index)))
+      }
+      
+      #subsetting
+      subset_index <- x_index:(x_index + length(index) - 1)
+      
+      if (length(x_dim) == 4) {
+        x[subset_index, , , ] <- x_complete[index, , , ]
+      }
+      
+      if (length(x_dim) == 3) {
+        x[subset_index, , ] <- x_complete[index, , ]
+      }
+      
+      if (length(x_dim) == 2) {
+        x[subset_index, ] <- x_complete[index, ]
+      }
+      
+      if (!is_lm) {
+        y[subset_index, ] <- y_complete[index, ]
+      }
+      
+      if (include_sw) {
+        sw[subset_index, ] <- sw_complete[index, ]
+      }
+      
+      x_index <- x_index + length(index)
+      sample_index <<- setdiff(sample_index, index)
+      
+    }
+    
+    if (is_lm) {
+      for (m in 1:target_len) {
+        if (batch_size == 1) {
+          y[[m]] <- matrix(x[ , x_dim[2] - target_len + m, ], nrow = 1)
+        } else {
+          y[[m]] <- x[ , x_dim[2] - target_len + m, ]
+        }
+      }
+      
+      x <- x[ , 1:(x_dim[2] - target_len), ]
+      if (batch_size == 1) {
+        dim(x) <- c(1, dim(x))
+      }
+    }
+    
+    if (!is.null(n_gram) & is_lm) {
+      y <- do.call(rbind, y)
+      y_list <- list()
+      for (i in 1:batch_size) {
+        index <- (i-1)  + (1 + (0:(target_len-1)) * batch_size)
+        n_gram_matrix <- n_gram_of_matrix(input_matrix = y[index, ], n = n_gram)
+        y_list[[i]] <- n_gram_matrix
+      }
+      y_tensor <- keras::k_stack(y_list, axis = 1L) %>% keras::k_eval()
+      y <- vector("list", dim(y_tensor)[2])
+      for (i in 1:dim(y_tensor)[2]) {
+        y_subset <- y_tensor[ , i, ]
+        if (batch_size == 1) y_subset <- matrix(y_subset, nrow = 1)
+        y[[i]] <- y_subset
+      }
+      
+      if (is.list(y) & length(y) == 1) {
+        y <- y[[1]]
+      }
+      
+      if (n_gram_stride > 1 & is.list(y)) {
+        stride_index <- 0:(length(y)-1) %% n_gram_stride == 0
+        y <- y[stride_index]
+        
+      }
+    }
+    
+    if (!is.null(add_noise)) {
+      noise_args <- c(add_noise, list(x = x))
+      x <- do.call(add_noise_tensor, noise_args)
+    }
+    
+    if (reverse_complement_encoding){
+      x_1 <- x
+      x_2 <- array(x_1[ , (dim(x)[2]):1, 4:1], dim = dim(x))
+      x <- list(x_1, x_2)
+    }
+    
+    if (multi_input) {
+      x <- tensorflow::tf$split(x, num_or_size_splits = size_splits_in, axis = as.integer(length(x_dim)-1))
+    }
+    
+    if (multi_output) {
+      y <- tensorflow::tf$split(y, num_or_size_splits = size_splits_out, axis = as.integer(length(y_dim)-1))
+    }
+    
+    if (include_sw && multi_sw) {
+      sw <- tensorflow::tf$split(sw, num_or_size_splits = size_splits_sw, axis = as.integer(length(sw_dim)-1))
+    }
+    
+    if (include_sw) {
+      
+      if (reshape_xy_bool) {
+        l <- f_reshape(x = x, y = y,
+                       reshape_xy = reshape_xy,
+                       reshape_x_bool = reshape_x_bool,
+                       reshape_y_bool = reshape_y_bool,
+                       reshape_sw_bool = reshape_sw_bool,
+                       sw = sw)
+        return(l)
+      } 
+      
+      return(list(x, y, sw))
+      
+    } else {
+      
+      if (reshape_xy_bool) {
+        l <- f_reshape(x = x, y = y,
+                       reshape_xy = reshape_xy,
+                       reshape_x_bool = reshape_x_bool,
+                       reshape_y_bool = reshape_y_bool,
+                       reshape_sw_bool = reshape_sw_bool, sw = NULL)
+        return(l)
+      } 
+      
+      return(list(x, y))
+    }
+    
+  }
+}