a b/R/n_gram.R
1
#' Get distribution of n-grams
#'
#' Get distribution of next character given previous n nucleotides.
#'
#' @inheritParams generator_fasta_lm
#' @param path_input Path to folder containing fasta files or single fasta file.
#' @param n Size of n gram.
#' @param vocabulary Vector of allowed characters, samples outside vocabulary get discarded.
#' @param file_sample If integer, size of random sample of files in \code{path_input}.
#' @param nuc_dist Currently unused; kept for backward compatibility.
#' @examples
#' temp_dir <- tempfile()
#' dir.create(temp_dir)
#' create_dummy_data(file_path = temp_dir,
#'                   num_files = 3,
#'                   seq_length = 80,
#'                   vocabulary = c("A", "C", "G", "T"),
#'                   num_seq = 2)
#'
#' m <- n_gram_dist(path_input = temp_dir,
#'                  n = 3,
#'                  step = 1,
#'                  nuc_dist = FALSE)
#' head(round(m, 2))
#' @returns A data frame with one row per n-gram over \code{vocabulary}
#'   (the n-grams are the row names) and one column per character of
#'   \code{vocabulary}. Each row holds the relative frequency of the next
#'   character given that n-gram; rows of n-grams that never occurred are
#'   all zero.
#' @export
n_gram_dist <- function(path_input,
                        n = 2,
                        vocabulary = c("A", "C", "G", "T"),
                        format = "fasta",
                        file_sample = NULL,
                        step = 1,
                        nuc_dist = FALSE) {

  # only these two readers are supported below; fail early with a clear message
  if (!(format %in% c("fasta", "fastq"))) {
    stop("`format` must be \"fasta\" or \"fastq\".", call. = FALSE)
  }

  # a single file, or every *.<format> file in a directory
  if (endsWith(path_input, paste0(".", format))) {
    fasta_files <- path_input
  } else {
    fasta_files <- list.files(path = path_input,
                              pattern = paste0("\\.", format, "$"),
                              full.names = TRUE)
  }

  # take random subset of files
  if (!is.null(file_sample)) {
    num_keep <- min(file_sample, length(fasta_files))
    fasta_files <- sample(fasta_files)[seq_len(num_keep)]
  }

  if (length(fasta_files) == 0) {
    stop("No .", format, " files found at ", path_input, call. = FALSE)
  }

  # all n-grams over `vocabulary`; expand.grid varies its first column
  # fastest, which matches the gram ordering of table() on a factor with
  # levels = labels
  grid <- expand.grid(rep(list(vocabulary), n), stringsAsFactors = FALSE)
  labels <- Reduce(paste0, grid)

  # long count table: gram varies fastest, target slowest (same layout as
  # as.data.frame(table(gram, targets)))
  freq_df <- data.frame(gram = rep(labels, times = length(vocabulary)),
                        targets = rep(vocabulary, each = length(labels)),
                        freq = 0,
                        stringsAsFactors = FALSE)

  for (fasta_path in fasta_files) {

    fasta_file <- if (format == "fasta") {
      microseq::readFasta(fasta_path)
    } else {
      microseq::readFastq(fasta_path)
    }

    seq_vector <- fasta_file$Sequence
    start_ind <- get_start_ind(seq_vector = seq_vector,
                               length_vector = nchar(seq_vector),
                               maxlen = n, step = step, train_mode = "lm")
    nuc_seq <- paste(seq_vector, collapse = "")
    split_seq <- strsplit(nuc_seq, "")[[1]]
    nuc_seq_length <- nchar(nuc_seq)

    # gram[k] holds characters k..(k + n - 1); targets[k] is character k + n
    gram <- split_seq[1:(nuc_seq_length - n)]
    if (n > 1) {
      for (j in 2:n) {
        gram <- paste0(gram, split_seq[j:(nuc_seq_length - n + j - 1)])
      }
    }
    targets <- split_seq[(n + 1):nuc_seq_length]

    # restrict to the window starts chosen by get_start_ind (intended to drop
    # windows overlapping two fasta entries)
    gram <- gram[start_ind]
    targets <- targets[start_ind]

    # n-grams / targets containing characters outside `vocabulary` are not
    # valid factor levels, become NA and are dropped
    gram_df <- data.frame(gram = factor(gram, levels = labels),
                          targets = factor(targets, levels = vocabulary))
    gram_df <- gram_df[stats::complete.cases(gram_df), , drop = FALSE]
    table_df <- as.data.frame(table(gram_df))

    stopifnot(all(freq_df$gram == table_df$gram),
              all(freq_df$targets == table_df$targets))

    freq_df$freq <- freq_df$freq + table_df$Freq
  }

  df_to_distribution_matrix(freq_df, vocabulary = vocabulary)
}
123
124
# Convert a long (gram, targets, freq) count table into a wide table of
# conditional frequencies.
#
# `freq_df` must have exactly the columns `gram`, `targets` and `freq`, and
# the rows for every target character must list the n-grams in the same
# order. Returns a data frame with one row per distinct n-gram (row names are
# the first `nlevels(gram)` entries of `freq_df$gram`) and one column per
# character of `vocabulary`. Each row with a non-zero total is divided by
# that total. All-zero rows stay zero. The result is always a data frame,
# even for a single-character vocabulary.
df_to_distribution_matrix <- function(freq_df, vocabulary = c("A", "C", "G", "T")) {

  stopifnot(identical(names(freq_df), c("gram", "targets", "freq")))

  num_grams <- nlevels(factor(freq_df$gram))
  gram_names <- as.character(freq_df$gram[seq_len(num_grams)])

  dist_matrix <- as.data.frame(
    matrix(0, nrow = num_grams, ncol = length(vocabulary))
  )
  rownames(dist_matrix) <- gram_names
  colnames(dist_matrix) <- vocabulary

  for (nuc in vocabulary) {
    rows <- which(freq_df$targets == nuc)
    # every target block must list the n-grams in the row-name order
    stopifnot(identical(as.character(freq_df$gram[rows]), gram_names))
    dist_matrix[[nuc]] <- freq_df$freq[rows]
  }

  # normalise each row by its total; all-zero rows are divided by 1 and so
  # remain zero instead of becoming NaN
  row_sums <- rowSums(dist_matrix)
  row_sums[row_sums == 0] <- 1
  dist_matrix[] <- lapply(dist_matrix, function(col) col / row_sums)

  dist_matrix
}
146
147
#' Predict the next nucleotide using n-gram
#'
#' Predict the next nucleotide using n-gram.
#'
#' @inheritParams generator_fasta_lm
#' @param path_input Path to folder containing fasta files or single fasta file.
#' @param distribution_matrix A data frame containing frequency of next nucleotide given the previous n nucleotides (output of \code{\link{n_gram_dist}} function).
#' Its columns are assumed to be in the same order as \code{vocabulary}.
#' @param default_pred Either character from vocabulary or `"random"`. Will be used as prediction if certain n-gram did not appear before
#' (i.e. its row in \code{distribution_matrix} is all zero).
#' If `"random"` assign random prediction from \code{vocabulary}.
#' @param vocabulary Vector of allowed characters, samples outside vocabulary get discarded.
#' @param file_sample If integer, size of random sample of files in \code{path_input}.
#' @param return_data_frames Logical, whether to return data frame with input, predictions, target position and true target.
#'
#' @examples
#' # create dummy fasta files
#' temp_dir <- tempfile()
#' dir.create(temp_dir)
#' create_dummy_data(file_path = temp_dir,
#'                   num_files = 3,
#'                   seq_length = 8,
#'                   vocabulary = c("A", "C", "G", "T"),
#'                   num_seq = 2)
#'
#' m <- n_gram_dist(path_input = temp_dir,
#'                  n = 3,
#'                  step = 1,
#'                  nuc_dist = FALSE)
#'
#' # use distribution matrix to make predictions for one file
#' predictions <- predict_with_n_gram(path_input = list.files(temp_dir, full.names = TRUE)[1],
#'                                    distribution_matrix = m)
#'
#' # show accuracy
#' predictions[[1]]
#'
#' @returns A list with one element per file. Each element is a list holding
#'   \code{accuracy}, the share of positions whose next character was
#'   predicted correctly. If \code{return_data_frames = TRUE}, it also holds,
#'   as its first (unnamed) element, a data frame with columns \code{gram},
#'   \code{true}, \code{target_pos} and \code{pred}.
#' @export
predict_with_n_gram <- function(path_input, distribution_matrix, default_pred = "random", vocabulary = c("A", "C", "G", "T"),
                                file_sample = NULL, format = "fasta", return_data_frames = FALSE, step = 1) {

  if (!(format %in% c("fasta", "fastq"))) {
    stop("`format` must be \"fasta\" or \"fastq\".", call. = FALSE)
  }
  use_random <- identical(default_pred, "random")
  if (!use_random && !(default_pred %in% vocabulary)) {
    stop("`default_pred` must be \"random\" or an element of `vocabulary`.",
         call. = FALSE)
  }

  labels <- rownames(distribution_matrix)
  n <- nchar(labels[1])

  # index (into `vocabulary`) of the most frequent next character per n-gram;
  # ties resolve to the first column
  pred_int <- apply(distribution_matrix, 1, which.max)

  # n-grams never seen when building the distribution have an all-zero row,
  # for which which.max() would silently pick the first column; use
  # `default_pred` (or a random character) instead
  zero_rows <- which(rowSums(distribution_matrix) == 0)
  if (use_random) {
    pred_int[zero_rows] <- sample.int(length(vocabulary), length(zero_rows),
                                      replace = TRUE)
  } else {
    pred_int[zero_rows] <- match(default_pred, vocabulary)
  }

  model <- data.frame(gram = labels, pred = vocabulary[pred_int])

  # a single file, or every *.<format> file in a directory
  if (endsWith(path_input, paste0(".", format))) {
    fasta_files <- path_input
  } else {
    fasta_files <- list.files(path = path_input,
                              pattern = paste0("\\.", format, "$"),
                              full.names = TRUE)
  }

  # take random subset of files
  if (!is.null(file_sample)) {
    num_keep <- min(file_sample, length(fasta_files))
    fasta_files <- sample(fasta_files)[seq_len(num_keep)]
  }

  if (length(fasta_files) == 0) {
    stop("No .", format, " files found at ", path_input, call. = FALSE)
  }

  pred_df_list <- vector("list", length(fasta_files))

  for (i in seq_along(fasta_files)) {

    fasta_file <- if (format == "fasta") {
      microseq::readFasta(fasta_files[i])
    } else {
      microseq::readFastq(fasta_files[i])
    }

    seq_vector <- fasta_file$Sequence
    start_ind <- get_start_ind(seq_vector = seq_vector,
                               length_vector = nchar(seq_vector),
                               maxlen = n, step = step, train_mode = "lm")
    nuc_seq <- paste(seq_vector, collapse = "")
    split_seq <- strsplit(nuc_seq, "")[[1]]
    nuc_seq_length <- nchar(nuc_seq)

    # gram[k] holds characters k..(k + n - 1); targets[k] is character k + n
    gram <- split_seq[1:(nuc_seq_length - n)]
    if (n > 1) {
      for (j in 2:n) {
        gram <- paste0(gram, split_seq[j:(nuc_seq_length - n + j - 1)])
      }
    }
    targets <- split_seq[(n + 1):nuc_seq_length]

    # restrict to the window starts chosen by get_start_ind (intended to drop
    # windows overlapping two fasta entries)
    gram <- gram[start_ind]
    targets <- targets[start_ind]
    gram_df <- data.frame(gram = factor(gram, levels = labels),
                          targets = factor(targets, levels = vocabulary),
                          target_pos = start_ind + n)

    # n-grams / targets with characters outside `vocabulary` are NA factor
    # values; drop those rows
    gram_df <- gram_df[stats::complete.cases(gram_df), ]

    pred_df <- dplyr::left_join(gram_df, model, by = "gram")
    names(pred_df)[2] <- "true"
    accuracy <- sum(pred_df$true == pred_df$pred) / nrow(pred_df)

    pred_df_list[[i]] <- if (return_data_frames) {
      list(pred_df, accuracy = accuracy)
    } else {
      list(accuracy = accuracy)
    }
  }

  pred_df_list
}