R/create_model_transformer.R
#' Create transformer model
#'
#' Creates a transformer network for classification. The model can consist of several stacked attention blocks.
#'
#' @inheritParams keras::layer_multi_head_attention
#' @inheritParams create_model_lstm_cnn
#' @param pos_encoding Either `"sinusoid"` or `"embedding"`. How to add positional information.
#' If `"sinusoid"`, sine waves of different frequencies are added to the input (see 'Details' for a sketch).
#' If `"embedding"`, the model learns a positional embedding.
#' @param embed_dim Dimension of the token embedding. No embedding is used if set to 0. Should be used when the input is not
#' one-hot encoded (i.e. an integer sequence).
#' @param head_size Dimensions of the attention key (one entry per attention block).
#' @param n Frequency of the sine waves for the positional encoding. Only applied if `pos_encoding = "sinusoid"`.
#' @param ff_dim Units of the first dense layer after each attention block.
#' @param dropout Vector of dropout rates after the attention block(s).
#' @param dropout_dense Dropout rate for the dense layers.
#' @param flatten_method How to process the output of the last attention block. Can be `"max_ch_first"`, `"max_ch_last"`, `"average_ch_first"`,
#' `"average_ch_last"`, `"both_ch_first"`, `"both_ch_last"`, `"all"`, `"none"` or `"flatten"`.
#' If `"average_ch_last"` / `"max_ch_last"` or `"average_ch_first"` / `"max_ch_first"`, global average/max pooling is applied.
#' `_ch_first` / `_ch_last` determines along which axis to pool. `"both_ch_first"` / `"both_ch_last"` apply max and average pooling together, and `"all"` applies
#' all 4 global pooling options together. If `"flatten"`, the output of the last attention block is flattened. If `"none"`, no flattening or pooling is applied.
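#' @details
#' The positional-encoding layers are provided by `layer_pos_sinusoid_wrapper()` and
#' `layer_pos_embedding_wrapper()`. As a rough illustration only (not the package implementation),
#' the `"sinusoid"` option follows the usual transformer scheme: position `pos` and embedding
#' dimension `i` are mapped to sine/cosine waves whose frequencies are controlled by `n`.
#'
#' ```r
#' # Illustrative sketch of sinusoidal positional encoding; assumes an even embed_dim
#' # and concatenates the sine and cosine parts. The wrapper layer may differ in detail.
#' sinusoid_encoding <- function(maxlen, embed_dim, n = 10000) {
#'   pos <- 0:(maxlen - 1)
#'   i <- 0:(embed_dim / 2 - 1)
#'   angle <- outer(pos, 1 / n^(2 * i / embed_dim)) # maxlen x (embed_dim/2)
#'   cbind(sin(angle), cos(angle))                  # maxlen x embed_dim
#' }
#' ```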
#' @examples
#'
#' maxlen <- 50
#' \donttest{
#' library(keras)
#' model <- create_model_transformer(maxlen = maxlen,
#'                                   head_size = c(10, 12),
#'                                   num_heads = c(7, 8),
#'                                   ff_dim = c(5, 9),
#'                                   dropout = c(0.3, 0.5))
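#'
#' # A second, purely illustrative call (hypothetical hyperparameter values): integer-encoded
#' # input with a learned token embedding (embed_dim > 0), sinusoidal positional encoding and
#' # global average pooling instead of flattening.
#' model_int <- create_model_transformer(maxlen = maxlen,
#'                                       embed_dim = 32,
#'                                       pos_encoding = "sinusoid",
#'                                       head_size = 16,
#'                                       num_heads = 4,
#'                                       ff_dim = 32,
#'                                       dropout = 0.1,
#'                                       flatten_method = "average_ch_last")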
#' }
#' @returns A keras model implementing the transformer architecture.
#' @export
create_model_transformer <- function(maxlen,
                                     vocabulary_size = 4,
                                     embed_dim = 64,
                                     pos_encoding = "embedding",
                                     head_size = 4L,
                                     num_heads = 5L,
                                     ff_dim = 8,
                                     dropout = 0,
                                     n = 10000, # frequency for sinusoidal positional encoding
                                     layer_dense = 2,
                                     dropout_dense = NULL,
                                     flatten_method = "flatten",
                                     last_layer_activation = "softmax",
                                     loss_fn = "categorical_crossentropy",
                                     solver = "adam",
                                     learning_rate = 0.01,
                                     label_noise_matrix = NULL,
                                     bal_acc = FALSE,
                                     f1_metric = FALSE,
                                     auc_metric = FALSE,
                                     label_smoothing = 0,
                                     verbose = TRUE,
                                     model_seed = NULL,
                                     mixed_precision = FALSE,
                                     mirrored_strategy = NULL) {
  
  if (mixed_precision) tensorflow::tf$keras$mixed_precision$set_global_policy("mixed_float16")

  # Use a mirrored (multi-GPU) strategy by default when more than one GPU is available.
  if (is.null(mirrored_strategy)) mirrored_strategy <- count_gpu() > 1
  if (mirrored_strategy) {
    # Rebuild the model inside the strategy scope by calling this function again with the
    # same arguments, with mirrored_strategy disabled to avoid infinite recursion.
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      argg <- as.list(environment())
      argg$mirrored_strategy <- FALSE
      model <- do.call(create_model_transformer, argg)
    })
    return(model)
  }
  
  # Per-block hyperparameters must have one entry per attention block.
  stopifnot(length(head_size) == length(num_heads))
  stopifnot(length(head_size) == length(dropout))
  stopifnot(length(head_size) == length(ff_dim))
  stopifnot(flatten_method %in% c("max_ch_first", "max_ch_last", "average_ch_first",
                                  "average_ch_last", "both_ch_first", "both_ch_last", "all", "none", "flatten"))
  stopifnot(pos_encoding %in% c("sinusoid", "embedding"))
  num_dense_layers <- length(layer_dense)
  head_size <- as.integer(head_size)
  num_heads <- as.integer(num_heads)
  maxlen <- as.integer(maxlen)
  num_attention_blocks <- length(num_heads)
  vocabulary_size <- as.integer(vocabulary_size)
  if (!is.null(model_seed)) tensorflow::tf$random$set_seed(model_seed)

  # Input is a one-hot matrix when no token embedding is used, otherwise an integer sequence.
  if (embed_dim == 0) {
    input_tensor <- keras::layer_input(shape = c(maxlen, vocabulary_size))
  } else {
    input_tensor <- keras::layer_input(shape = c(maxlen))
  }
  
  # positional encoding
  if (pos_encoding == "sinusoid") {
    pos_enc_layer <- layer_pos_sinusoid_wrapper(maxlen = maxlen, vocabulary_size = vocabulary_size,
                                                n = n, embed_dim = embed_dim)
  }
  if (pos_encoding == "embedding") {
    pos_enc_layer <- layer_pos_embedding_wrapper(maxlen = maxlen, vocabulary_size = vocabulary_size,
                                                 embed_dim = embed_dim)
  }
  output_tensor <- input_tensor %>% pos_enc_layer

  # stacked attention blocks
  for (i in seq_len(num_attention_blocks)) {
    attn_block <- layer_transformer_block_wrapper(
      num_heads = num_heads[i],
      head_size = head_size[i],
      dropout_rate = dropout[i],
      ff_dim = ff_dim[i],
      embed_dim = embed_dim,
      vocabulary_size = vocabulary_size,
      load_r6 = FALSE)
    output_tensor <- output_tensor %>% attn_block
  }

  # pool or flatten the output of the last attention block
  if (flatten_method != "none") {
    output_tensor <- pooling_flatten(global_pooling = flatten_method, output_tensor = output_tensor)
  }
  
  # hidden dense layers (all but the last entry of layer_dense)
  if (num_dense_layers > 1) {
    for (i in seq_len(num_dense_layers - 1)) {
      output_tensor <- output_tensor %>% keras::layer_dense(units = layer_dense[i], activation = "relu")
      if (!is.null(dropout_dense)) {
        output_tensor <- output_tensor %>% keras::layer_dropout(rate = dropout_dense[i])
      }
    }
  }

  # output layer; kept in float32 so predictions stay in full precision when mixed precision is enabled
  output_tensor <- output_tensor %>%
    keras::layer_dense(units = layer_dense[length(layer_dense)], activation = last_layer_activation, dtype = "float32")
  
  # create and compile model
  model <- keras::keras_model(inputs = input_tensor, outputs = output_tensor)

  model <- compile_model(model = model, label_smoothing = label_smoothing, layer_dense = layer_dense,
                         solver = solver, learning_rate = learning_rate, loss_fn = loss_fn,
                         num_output_layers = 1, label_noise_matrix = label_noise_matrix,
                         bal_acc = bal_acc, f1_metric = f1_metric, auc_metric = auc_metric)

  if (verbose) print(model)

  # store the call's hyperparameters on the model object
  argg <- as.list(environment())
  model <- add_hparam_list(model, argg)
  reticulate::py_set_attr(x = model, name = "hparam", value = model$hparam)

  model

}