#' @title Create LSTM/CNN network
#'
#' @description Creates a network consisting of an arbitrary number of CNN, LSTM and dense layers.
#' The last layer is a dense layer.
#'
#' @param maxlen Length of the predictor sequence.
#' @param dropout_lstm Fraction of the units to drop for inputs.
#' @param recurrent_dropout_lstm Fraction of the units to drop for the recurrent state.
#' @param layer_lstm Number of cells per network layer. Can be a scalar or vector.
#' @param layer_dense Vector specifying the number of neurons per dense layer after the last LSTM or CNN layer (if no LSTM is used).
#' @param dropout_dense Dropout rates between dense layers. No dropout if `NULL`.
#' @param solver Optimization method; options are `"adam"`, `"adagrad"`, `"rmsprop"` or `"sgd"`.
#' @param learning_rate Learning rate for the optimizer.
#' @param bidirectional Whether to use a bidirectional wrapper for LSTM layers.
#' @param vocabulary_size Number of unique characters in the vocabulary.
#' @param stateful Boolean. Whether to use stateful LSTM layers.
#' @param batch_size Number of samples used for one network update. Only used if \code{stateful = TRUE}.
#' @param compile Whether to compile the model.
#' @param kernel_size Size of the 1D convolution kernels. For multiple layers, assign a vector (e.g., `rep(3, 2)` for two layers with kernel size 3).
#' @param filters Number of filters. For multiple layers, assign a vector.
#' @param strides Stride values. For multiple layers, assign a vector.
#' @param pool_size Integer, size of the max pooling windows. For multiple layers, assign a vector.
#' @param padding Padding of CNN layers, e.g. `"same"`, `"valid"` or `"causal"`.
#' @param dilation_rate Integer, the dilation rate to use for dilated convolution.
#' @param gap Whether to apply global average pooling after the last CNN layer.
#' @param use_bias Boolean. Whether to use a bias term in CNN layers.
#' @param residual_block Boolean. If `TRUE`, residual connections are used in the CNN. They are not applied to the first convolutional layer.
#' @param residual_block_length Integer. Determines how many convolutional layers (or triplets when `size_reduction_1Dconv` is `TRUE`) exist
#' between the legs of a residual connection. For example, if the length of `kernel_size`/`filters` is 7 and `residual_block_length` is 2, there are
#' 1+(7-1)*2 convolutional layers in the model when `size_reduction_1Dconv` is `FALSE` and 1+(7-1)*2*3 convolutional layers when `size_reduction_1Dconv` is `TRUE`.
#' @param size_reduction_1Dconv Boolean. When `TRUE`, the number of filters in the convolutional layers is first reduced to 1/4 of the number of filters of
#' the original layer by a convolution layer with kernel size 1, and increased back to the original value by another convolution layer with kernel size 1
#' after the convolution with the original kernel size (applied with the reduced number of filters).
#' @param label_input Integer or `NULL`. If not `NULL`, adds an additional input layer of \code{label_input} size.
#' @param zero_mask Boolean, whether to apply zero masking before the LSTM layer. Only used if the model does not use any CNN layers.
#' @param label_smoothing Float in \[0, 1\]. If 0, no smoothing is applied. If greater than 0, the loss is computed between the predicted
#' labels and a smoothed version of the true labels, where the smoothing squeezes the labels towards 0.5.
#' The closer the argument is to 1, the more the labels get smoothed; for example, with two classes and
#' `label_smoothing = 0.1`, a one-hot label of `c(0, 1)` becomes `c(0.05, 0.95)`.
#' @param label_noise_matrix Matrix of label noise. Each row corresponds to one true class and each column to the fraction of labels
#' assigned to that class. For example, if the first class contains 5 percent wrong labels and the second class no noise, then
#'
#' \code{label_noise_matrix <- matrix(c(0.95, 0.05, 0, 1), nrow = 2, byrow = TRUE)}
#' @param last_layer_activation Activation function of the output layer(s), for example `"sigmoid"` or `"softmax"`.
#' @param loss_fn Either `"categorical_crossentropy"` or `"binary_crossentropy"`. If `label_noise_matrix` is given, a custom `"noisy_loss"` will be used.
#' @param num_output_layers Number of output layers.
#' @param auc_metric Whether to add an AUC metric.
#' @param f1_metric Whether to add an F1 metric.
#' @param bal_acc Whether to add a balanced accuracy metric.
#' @param verbose Boolean. Whether to print the model summary.
#' @param batch_norm_momentum Momentum for the moving mean and the moving variance.
#' @param model_seed Set seed for model parameters in tensorflow if not `NULL`.
#' @param mixed_precision Whether to use mixed precision (https://www.tensorflow.org/guide/mixed_precision).
#' @param mirrored_strategy Whether to use a distributed mirrored strategy. If `NULL`, the mirrored strategy is used only if more than one GPU is available.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' create_model_lstm_cnn(
#'   maxlen = 500,
#'   vocabulary_size = 4,
#'   kernel_size = c(8, 8, 8),
#'   filters = c(16, 32, 64),
#'   pool_size = c(3, 3, 3),
#'   layer_lstm = c(32, 64),
#'   layer_dense = c(128, 4),
#'   learning_rate = 0.001)
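#'
#' # A second, illustrative configuration (a sketch; layer sizes are
#' # assumptions, not recommendations): an LSTM-only model with dropout
#' # applied before each dense layer.
#' create_model_lstm_cnn(
#'   maxlen = 100,
#'   vocabulary_size = 4,
#'   layer_lstm = c(64),
#'   layer_dense = c(32, 4),
#'   dropout_dense = c(0.2, 0.2),
#'   learning_rate = 0.001)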
#'
#' @returns A keras model that stacks CNN, LSTM and dense layers.
#' @export
create_model_lstm_cnn <- function(
    maxlen = 50,
    dropout_lstm = 0,
    recurrent_dropout_lstm = 0,
    layer_lstm = NULL,
    layer_dense = c(4),
    dropout_dense = NULL,
    kernel_size = NULL,
    filters = NULL,
    strides = NULL,
    pool_size = NULL,
    solver = "adam",
    learning_rate = 0.001,
    vocabulary_size = 4,
    bidirectional = FALSE,
    stateful = FALSE,
    batch_size = NULL,
    compile = TRUE,
    padding = "same",
    dilation_rate = NULL,
    gap = FALSE,
    use_bias = TRUE,
    residual_block = FALSE,
    residual_block_length = 1,
    size_reduction_1Dconv = FALSE,
    label_input = NULL,
    zero_mask = FALSE,
    label_smoothing = 0,
    label_noise_matrix = NULL,
    last_layer_activation = "softmax",
    loss_fn = "categorical_crossentropy",
    num_output_layers = 1,
    auc_metric = FALSE,
    f1_metric = FALSE,
    bal_acc = FALSE,
    verbose = TRUE,
    batch_norm_momentum = 0.99,
    model_seed = NULL,
    mixed_precision = FALSE,
    mirrored_strategy = NULL) {

  if (mixed_precision) tensorflow::tf$keras$mixed_precision$set_global_policy("mixed_float16")

  if (is.null(mirrored_strategy)) mirrored_strategy <- count_gpu() > 1
  if (mirrored_strategy) {
    # build the model inside the strategy scope by calling this function recursively
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      argg <- as.list(environment())
      argg$mirrored_strategy <- FALSE
      model <- do.call(create_model_lstm_cnn, argg)
    })
    return(model)
  }

  layer_dense <- as.integer(layer_dense)
  if (!is.null(model_seed)) tensorflow::tf$random$set_seed(model_seed)
  num_targets <- layer_dense[length(layer_dense)]
  layers.lstm <- length(layer_lstm)
  use.cnn <- !is.null(kernel_size)

  if (!is.null(layer_lstm)) {
    stopifnot(length(layer_lstm) == 1 | (length(layer_lstm) == layers.lstm))
  }

  if (layers.lstm == 0 & !use.cnn) {
    stop("Model does not use LSTM or CNN layers.")
  }

  if (is.null(strides)) strides <- rep(1L, length(filters))
  if (is.null(dilation_rate) & use.cnn) dilation_rate <- rep(1L, length(filters))

  if (use.cnn) {
    same_length <- (length(kernel_size) == length(filters)) &
      (length(filters) == length(strides)) &
      (length(strides) == length(dilation_rate))
    if (!same_length) {
      stop("kernel_size, filters, dilation_rate and strides must have the same length")
    }
    if (residual_block & (padding != "same")) {
      stop("Padding option must be 'same' when residual blocks are used.")
    }
  }

  stopifnot(maxlen > 0)
  stopifnot(dropout_lstm <= 1 & dropout_lstm >= 0)
  stopifnot(recurrent_dropout_lstm <= 1 & recurrent_dropout_lstm >= 0)

  if (length(layer_lstm) == 1) {
    layer_lstm <- rep(layer_lstm, layers.lstm)
  }

  if (stateful) {
    input_tensor <- keras::layer_input(batch_shape = c(batch_size, maxlen, vocabulary_size))
  } else {
    input_tensor <- keras::layer_input(shape = c(maxlen, vocabulary_size))
  }

  if (use.cnn) {
    for (i in 1:length(filters)) {
      if (i == 1) {
        output_tensor <- input_tensor %>%
          keras::layer_conv_1d(
            kernel_size = kernel_size[i],
            padding = padding,
            activation = "relu",
            filters = filters[i],
            strides = strides[i],
            dilation_rate = dilation_rate[i],
            input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
            use_bias = use_bias
          )
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor <- output_tensor %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
      } else {
        if (residual_block) {
          # guard against pool_size = NULL, consistent with the max-pooling branches below
          pool_factor <- ifelse(is.null(pool_size), 1L, pool_size[i])
          if ((strides[i] > 1) | (pool_factor > 1)) {
            residual_layer <- output_tensor %>% keras::layer_average_pooling_1d(pool_size = strides[i] * pool_factor)
          } else {
            residual_layer <- output_tensor
          }
          if (filters[i - 1] != filters[i]) {
            residual_layer <- residual_layer %>%
              keras::layer_conv_1d(
                kernel_size = 1,
                padding = padding,
                activation = "relu",
                filters = filters[i],
                strides = 1,
                dilation_rate = dilation_rate[i],
                input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                use_bias = use_bias
              )
            residual_layer <- residual_layer %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
          }
          if (residual_block_length > 1) {
            for (j in 1:(residual_block_length - 1)) {
              if (size_reduction_1Dconv) {
                output_tensor <- output_tensor %>%
                  keras::layer_conv_1d(
                    kernel_size = 1,
                    padding = padding,
                    activation = "relu",
                    filters = filters[i]/4,
                    strides = 1,
                    dilation_rate = dilation_rate[i],
                    input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                    use_bias = use_bias
                  )
                output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

                output_tensor <- output_tensor %>%
                  keras::layer_conv_1d(
                    kernel_size = kernel_size[i],
                    padding = padding,
                    activation = "relu",
                    filters = filters[i]/4,
                    strides = 1,
                    dilation_rate = dilation_rate[i],
                    input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                    use_bias = use_bias
                  )
                output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

                output_tensor <- output_tensor %>%
                  keras::layer_conv_1d(
                    kernel_size = 1,
                    padding = padding,
                    activation = "relu",
                    filters = filters[i],
                    strides = 1,
                    dilation_rate = dilation_rate[i],
                    input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                    use_bias = use_bias
                  )
                output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

              } else {
                output_tensor <- output_tensor %>%
                  keras::layer_conv_1d(
                    kernel_size = kernel_size[i],
                    padding = padding,
                    activation = "relu",
                    filters = filters[i],
                    strides = 1,
                    dilation_rate = dilation_rate[i],
                    input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                    use_bias = use_bias
                  )
                output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
              }
            }
          }
        }
        if (size_reduction_1Dconv) {
          output_tensor <- output_tensor %>%
            keras::layer_conv_1d(
              kernel_size = 1,
              padding = padding,
              activation = "relu",
              filters = filters[i]/4,
              strides = 1,
              dilation_rate = dilation_rate[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              use_bias = use_bias
            )
          output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

          output_tensor <- output_tensor %>%
            keras::layer_conv_1d(
              kernel_size = kernel_size[i],
              padding = padding,
              activation = "relu",
              filters = filters[i]/4,
              strides = strides[i],
              dilation_rate = dilation_rate[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              use_bias = use_bias
            )
          output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

          output_tensor <- output_tensor %>%
            keras::layer_conv_1d(
              kernel_size = 1,
              padding = padding,
              activation = "relu",
              filters = filters[i],
              strides = 1,
              dilation_rate = dilation_rate[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              use_bias = use_bias
            )
          output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)

        } else {
          output_tensor <- output_tensor %>%
            keras::layer_conv_1d(
              kernel_size = kernel_size[i],
              padding = padding,
              activation = "relu",
              filters = filters[i],
              strides = strides[i],
              dilation_rate = dilation_rate[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              use_bias = use_bias
            )
          output_tensor <- output_tensor %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
        }
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor <- output_tensor %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        if (residual_block) {
          output_tensor <- keras::layer_add(list(output_tensor, residual_layer))
        }
      }
    }
  } else {
    if (zero_mask) {
      output_tensor <- input_tensor %>% keras::layer_masking()
    } else {
      output_tensor <- input_tensor
    }
  }
  # LSTM layers
  if (layers.lstm > 0) {
    if (layers.lstm > 1) {
      if (bidirectional) {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor <- output_tensor %>%
            keras::bidirectional(
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              keras::layer_lstm(
                units = layer_lstm[i],
                return_sequences = TRUE,
                dropout = dropout_lstm,
                recurrent_dropout = recurrent_dropout_lstm,
                stateful = stateful,
                recurrent_activation = "sigmoid"
              )
            )
        }
      } else {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor <- output_tensor %>%
            keras::layer_lstm(
              units = layer_lstm[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              return_sequences = TRUE,
              dropout = dropout_lstm,
              recurrent_dropout = recurrent_dropout_lstm,
              stateful = stateful,
              recurrent_activation = "sigmoid"
            )
        }
      }
    }
    # last LSTM layer
    if (bidirectional) {
      output_tensor <- output_tensor %>%
        keras::bidirectional(
          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
          keras::layer_lstm(units = layer_lstm[length(layer_lstm)], dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm,
                            stateful = stateful, recurrent_activation = "sigmoid")
        )
    } else {
      output_tensor <- output_tensor %>%
        keras::layer_lstm(units = layer_lstm[length(layer_lstm)],
                          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                          dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm, stateful = stateful,
                          recurrent_activation = "sigmoid")
    }
  }

  if (gap) {
    if (layers.lstm != 0) {
      stop("Global average pooling is not compatible with LSTM layers.")
    }
    output_tensor <- output_tensor %>% keras::layer_global_average_pooling_1d()
  } else {
    if (layers.lstm == 0) {
      output_tensor <- output_tensor %>% keras::layer_flatten()
    }
  }

  if (!is.null(label_input)) {
    input_label_list <- list()
    for (i in 1:length(label_input)) {
      if (!stateful) {
        eval(parse(text = paste0("label_input_layer_", as.character(i), " <- keras::layer_input(c(label_input[i]))")))
      } else {
        eval(parse(text = paste0("label_input_layer_", as.character(i), " <- keras::layer_input(batch_shape = c(batch_size, label_input[i]))")))
      }
      input_label_list[[i]] <- eval(parse(text = paste0("label_input_layer_", as.character(i))))
    }
    output_tensor <- keras::layer_concatenate(c(input_label_list, output_tensor))
  }

  if (length(layer_dense) > 1) {
    for (i in 1:(length(layer_dense) - 1)) {
      if (!is.null(dropout_dense)) output_tensor <- output_tensor %>% keras::layer_dropout(dropout_dense[i])
      output_tensor <- output_tensor %>% keras::layer_dense(units = layer_dense[i], activation = "relu")
    }
  }

  if (num_output_layers == 1) {
    if (!is.null(dropout_dense)) output_tensor <- output_tensor %>% keras::layer_dropout(dropout_dense[length(dropout_dense)])
    output_tensor <- output_tensor %>%
      keras::layer_dense(units = num_targets, activation = last_layer_activation, dtype = "float32")
  } else {
    output_list <- list()
    for (i in 1:num_output_layers) {
      layer_name <- paste0("output_", i, "_", num_output_layers)
      if (!is.null(dropout_dense)) {
        output_list[[i]] <- output_tensor %>% keras::layer_dropout(dropout_dense[length(dropout_dense)])
        output_list[[i]] <- output_list[[i]] %>%
          keras::layer_dense(units = num_targets, activation = last_layer_activation, name = layer_name, dtype = "float32")
      } else {
        output_list[[i]] <- output_tensor %>%
          keras::layer_dense(units = num_targets, activation = last_layer_activation, name = layer_name, dtype = "float32")
      }
    }
  }

  if (!is.null(label_input)) {
    label_inputs <- list()
    for (i in 1:length(label_input)) {
      eval(parse(text = paste0("label_inputs$label_input_layer_", as.character(i), " <- label_input_layer_", as.character(i))))
    }
    if (num_output_layers == 1) {
      model <- keras::keras_model(inputs = list(label_inputs, input_tensor), outputs = output_tensor)
    } else {
      model <- keras::keras_model(inputs = list(label_inputs, input_tensor), outputs = output_list)
    }
  } else {
    if (num_output_layers == 1) {
      model <- keras::keras_model(inputs = input_tensor, outputs = output_tensor)
    } else {
      model <- keras::keras_model(inputs = input_tensor, outputs = output_list)
    }
  }

  if (compile) {
    model <- compile_model(model = model, label_smoothing = label_smoothing, layer_dense = layer_dense,
                           solver = solver, learning_rate = learning_rate, loss_fn = loss_fn,
                           num_output_layers = num_output_layers, label_noise_matrix = label_noise_matrix,
                           bal_acc = bal_acc, f1_metric = f1_metric, auc_metric = auc_metric)
  }

  argg <- c(as.list(environment()))
  model <- add_hparam_list(model, argg)

  if (verbose) model$summary()
  return(model)
}

#' @title Create LSTM/CNN network to predict the middle part of a sequence
#'
#' @description
#' Creates a network consisting of an arbitrary number of CNN, LSTM and dense layers.
#' The function creates two subnetworks, each consisting of (optional) CNN layers followed by an arbitrary number of LSTM layers. Afterwards the outputs
#' of the last LSTM layers are concatenated and followed by one or more dense layers. The last layer is a dense layer.
#' The network tries to predict the target in the middle of a sequence. If the input is AACCTAAGG, the input tensors should correspond to x1 = AACC, x2 = GGAA and y = T.
#'
#' @inheritParams create_model_lstm_cnn
#' @examplesIf reticulate::py_module_available("tensorflow")
#' create_model_lstm_cnn_target_middle(
#'   maxlen = 500,
#'   vocabulary_size = 4,
#'   kernel_size = c(8, 8, 8),
#'   filters = c(16, 32, 64),
#'   pool_size = c(3, 3, 3),
#'   layer_lstm = c(32, 64),
#'   layer_dense = c(128, 4),
#'   learning_rate = 0.001)
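#'
#' # Sketch of how inputs could be prepared for this model (illustrative only,
#' # not the package's data generator; `seq_split` is a hypothetical helper):
#' # x1 is the first half, x2 the reversed second half and y the middle character.
#' seq_split <- function(s) {
#'   chars <- strsplit(s, "")[[1]]
#'   mid <- ceiling(length(chars) / 2)
#'   list(x1 = chars[1:(mid - 1)],
#'        x2 = rev(chars[(mid + 1):length(chars)]),
#'        y = chars[mid])
#' }
#' seq_split("AACCTAAGG") # x1 = A,A,C,C; x2 = G,G,A,A; y = T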
#'
#' @returns A keras model with two input layers. Consists of LSTM, CNN and dense layers.
#' @export
create_model_lstm_cnn_target_middle <- function(
    maxlen = 50,
    dropout_lstm = 0,
    recurrent_dropout_lstm = 0,
    layer_lstm = 128,
    solver = "adam",
    learning_rate = 0.001,
    vocabulary_size = 4,
    bidirectional = FALSE,
    stateful = FALSE,
    batch_size = NULL,
    padding = "same",
    compile = TRUE,
    layer_dense = NULL,
    kernel_size = NULL,
    filters = NULL,
    pool_size = NULL,
    strides = NULL,
    label_input = NULL,
    zero_mask = FALSE,
    label_smoothing = 0,
    label_noise_matrix = NULL,
    last_layer_activation = "softmax",
    loss_fn = "categorical_crossentropy",
    num_output_layers = 1,
    f1_metric = FALSE,
    auc_metric = FALSE,
    bal_acc = FALSE,
    verbose = TRUE,
    batch_norm_momentum = 0.99,
    model_seed = NULL,
    mixed_precision = FALSE,
    mirrored_strategy = NULL) {

  if (mixed_precision) tensorflow::tf$keras$mixed_precision$set_global_policy("mixed_float16")

  if (is.null(mirrored_strategy)) mirrored_strategy <- count_gpu() > 1
  if (mirrored_strategy) {
    # build the model inside the strategy scope by calling this function recursively
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      argg <- as.list(environment())
      argg$mirrored_strategy <- FALSE
      model <- do.call(create_model_lstm_cnn_target_middle, argg)
    })
    return(model)
  }

  layer_dense <- as.integer(layer_dense)
  if (!is.null(model_seed)) tensorflow::tf$random$set_seed(model_seed)
  use.cnn <- !is.null(kernel_size)
  num_targets <- layer_dense[length(layer_dense)]
  layers.lstm <- length(layer_lstm)

  stopifnot(length(layer_lstm) == 1 | (length(layer_lstm) == layers.lstm))
  stopifnot(maxlen > 0)
  stopifnot(dropout_lstm <= 1 & dropout_lstm >= 0)
  stopifnot(recurrent_dropout_lstm <= 1 & recurrent_dropout_lstm >= 0)

  if (is.null(strides)) {
    strides <- rep(1L, length(filters))
  }

  if (use.cnn) {
    same_length <- (length(kernel_size) == length(filters)) & (length(filters) == length(strides))
    if (!same_length) {
      stop("kernel_size, filters and strides must have the same length")
    }
  }

  # length of split sequences
  maxlen_1 <- ceiling(maxlen/2)
  maxlen_2 <- floor(maxlen/2)
  if (stateful) {
    input_tensor_1 <- keras::layer_input(batch_shape = c(batch_size, maxlen_1, vocabulary_size))
  } else {
    input_tensor_1 <- keras::layer_input(shape = c(maxlen_1, vocabulary_size))
  }

  if (use.cnn) {
    for (i in 1:length(filters)) {
      if (i == 1) {
        output_tensor_1 <- input_tensor_1 %>%
          keras::layer_conv_1d(
            kernel_size = kernel_size[i],
            padding = padding,
            activation = "relu",
            filters = filters[i],
            strides = strides[i],
            input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL)
          )
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor_1 <- output_tensor_1 %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        output_tensor_1 <- output_tensor_1 %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
      } else {
        output_tensor_1 <- output_tensor_1 %>%
          keras::layer_conv_1d(
            kernel_size = kernel_size[i],
            padding = padding,
            activation = "relu",
            strides = strides[i],
            filters = filters[i]
          )
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor_1 <- output_tensor_1 %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        output_tensor_1 <- output_tensor_1 %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
      }
    }
  } else {
    if (zero_mask) {
      output_tensor_1 <- input_tensor_1 %>% keras::layer_masking()
    } else {
      output_tensor_1 <- input_tensor_1
    }
  }

  # LSTM layers
  if (!is.null(layers.lstm) && layers.lstm > 0) {
    if (layers.lstm > 1) {
      if (bidirectional) {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor_1 <- output_tensor_1 %>%
            keras::bidirectional(
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              keras::layer_lstm(
                units = layer_lstm[i],
                return_sequences = TRUE,
                dropout = dropout_lstm,
                recurrent_dropout = recurrent_dropout_lstm,
                stateful = stateful,
                recurrent_activation = "sigmoid"
              )
            )
        }
      } else {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor_1 <- output_tensor_1 %>%
            keras::layer_lstm(
              units = layer_lstm[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              return_sequences = TRUE,
              dropout = dropout_lstm,
              recurrent_dropout = recurrent_dropout_lstm,
              stateful = stateful,
              recurrent_activation = "sigmoid"
            )
        }
      }
    }

    # last LSTM layer
    if (bidirectional) {
      output_tensor_1 <- output_tensor_1 %>%
        keras::bidirectional(
          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
          keras::layer_lstm(units = layer_lstm[length(layer_lstm)], dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm,
                            stateful = stateful, recurrent_activation = "sigmoid")
        )
    } else {
      output_tensor_1 <- output_tensor_1 %>%
        keras::layer_lstm(units = layer_lstm[length(layer_lstm)],
                          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                          dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm, stateful = stateful,
                          recurrent_activation = "sigmoid")
    }
  }

  if (stateful) {
    input_tensor_2 <- keras::layer_input(batch_shape = c(batch_size, maxlen_2, vocabulary_size))
  } else {
    input_tensor_2 <- keras::layer_input(shape = c(maxlen_2, vocabulary_size))
  }

  if (use.cnn) {
    for (i in 1:length(filters)) {
      if (i == 1) {
        output_tensor_2 <- input_tensor_2 %>%
          keras::layer_conv_1d(
            kernel_size = kernel_size[i],
            padding = padding,
            activation = "relu",
            filters = filters[i],
            strides = strides[i],
            input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL)
          )
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor_2 <- output_tensor_2 %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        output_tensor_2 <- output_tensor_2 %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
      } else {
        output_tensor_2 <- output_tensor_2 %>%
          keras::layer_conv_1d(
            kernel_size = kernel_size[i],
            padding = padding,
            activation = "relu",
            strides = strides[i],
            filters = filters[i]
          )
        if (!is.null(pool_size) && pool_size[i] > 1) {
          output_tensor_2 <- output_tensor_2 %>% keras::layer_max_pooling_1d(pool_size = pool_size[i])
        }
        output_tensor_2 <- output_tensor_2 %>% keras::layer_batch_normalization(momentum = batch_norm_momentum)
      }
    }
  } else {
    if (zero_mask) {
      output_tensor_2 <- input_tensor_2 %>% keras::layer_masking()
    } else {
      output_tensor_2 <- input_tensor_2
    }
  }

  # LSTM layers
  if (!is.null(layers.lstm) && layers.lstm > 0) {
    if (layers.lstm > 1) {
      if (bidirectional) {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor_2 <- output_tensor_2 %>%
            keras::bidirectional(
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              keras::layer_lstm(
                units = layer_lstm[i],
                return_sequences = TRUE,
                dropout = dropout_lstm,
                recurrent_dropout = recurrent_dropout_lstm,
                stateful = stateful,
                recurrent_activation = "sigmoid"
              )
            )
        }
      } else {
        for (i in 1:(layers.lstm - 1)) {
          output_tensor_2 <- output_tensor_2 %>%
            keras::layer_lstm(
              units = layer_lstm[i],
              input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
              return_sequences = TRUE,
              dropout = dropout_lstm,
              recurrent_dropout = recurrent_dropout_lstm,
              stateful = stateful,
              recurrent_activation = "sigmoid"
            )
        }
      }
    }

    # last LSTM layer
    if (bidirectional) {
      output_tensor_2 <- output_tensor_2 %>%
        keras::bidirectional(
          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
          keras::layer_lstm(units = layer_lstm[length(layer_lstm)], dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm,
                            stateful = stateful, recurrent_activation = "sigmoid")
        )
    } else {
      output_tensor_2 <- output_tensor_2 %>%
        keras::layer_lstm(units = layer_lstm[length(layer_lstm)],
                          input_shape = switch(stateful + 1, c(maxlen, vocabulary_size), NULL),
                          dropout = dropout_lstm, recurrent_dropout = recurrent_dropout_lstm, stateful = stateful,
                          recurrent_activation = "sigmoid")
    }
  }

  output_tensor <- keras::layer_concatenate(list(output_tensor_1, output_tensor_2))

  if (layers.lstm == 0) {
    output_tensor <- output_tensor %>% keras::layer_flatten()
  }

  if (!is.null(label_input)) {
    input_label_list <- list()
    for (i in 1:length(label_input)) {
      if (!stateful) {
        eval(parse(text = paste0("label_input_layer_", as.character(i), " <- keras::layer_input(c(label_input[i]))")))
      } else {
        eval(parse(text = paste0("label_input_layer_", as.character(i), " <- keras::layer_input(batch_shape = c(batch_size, label_input[i]))")))
      }
      input_label_list[[i]] <- eval(parse(text = paste0("label_input_layer_", as.character(i))))
    }
    output_tensor <- keras::layer_concatenate(c(input_label_list, output_tensor))
  }

  if (length(layer_dense) > 1) {
    for (i in 1:(length(layer_dense) - 1)) {
      output_tensor <- output_tensor %>% keras::layer_dense(units = layer_dense[i], activation = "relu")
    }
  }

  if (num_output_layers == 1) {
    output_tensor <- output_tensor %>%
      keras::layer_dense(units = num_targets, activation = last_layer_activation, dtype = "float32")
  } else {
    output_list <- list()
    for (i in 1:num_output_layers) {
      layer_name <- paste0("output_", i, "_", num_output_layers)
      output_list[[i]] <- output_tensor %>%
        keras::layer_dense(units = num_targets, activation = last_layer_activation, name = layer_name, dtype = "float32")
    }
  }

  # define model inputs/outputs
  if (!is.null(label_input)) {
    label_inputs <- list()
    for (i in 1:length(label_input)) {
      eval(parse(text = paste0("label_inputs$label_input_layer_", as.character(i), " <- label_input_layer_", as.character(i))))
    }
    model <- keras::keras_model(inputs = c(label_inputs, input_tensor_1, input_tensor_2), outputs = output_tensor)
  } else {
    model <- keras::keras_model(inputs = list(input_tensor_1, input_tensor_2), outputs = output_tensor)
  }

  if (compile) {
    model <- compile_model(model = model, label_smoothing = label_smoothing, layer_dense = layer_dense,
                           solver = solver, learning_rate = learning_rate, loss_fn = loss_fn,
                           num_output_layers = num_output_layers, label_noise_matrix = label_noise_matrix,
                           bal_acc = bal_acc, f1_metric = f1_metric, auc_metric = auc_metric)
  }

  argg <- c(as.list(environment()))
  model <- add_hparam_list(model, argg)
  reticulate::py_set_attr(x = model, name = "hparam", value = model$hparam)

  if (verbose) model$summary()
  return(model)
}