R/create_model_transformer.R

#' Create transformer model
#'
#' Creates a transformer network for classification. The model can consist of several stacked attention blocks.
#'
#' @inheritParams keras::layer_multi_head_attention
#' @inheritParams create_model_lstm_cnn
#' @param pos_encoding Either `"sinusoid"` or `"embedding"`. How to add positional information.
#' If `"sinusoid"`, will add sine waves of different frequencies to the input.
#' If `"embedding"`, the model learns a positional embedding.
#' @param embed_dim Dimension of the token embedding. No embedding if set to 0. Should be used when the input is not one-hot encoded
#' (i.e. an integer sequence).
#' @param head_size Dimension of the attention key (one entry per attention block).
#' @param n Frequency of the sine waves for positional encoding. Only applied if `pos_encoding = "sinusoid"`.
#' @param ff_dim Units of the first dense layer after attention in each block (one entry per attention block).
#' @param dropout Vector of dropout rates after the attention block(s).
#' @param dropout_dense Dropout rate for dense layers.
#' @param flatten_method How to process the output of the last attention block. Can be `"max_ch_first"`, `"max_ch_last"`, `"average_ch_first"`,
#' `"average_ch_last"`, `"both_ch_first"`, `"both_ch_last"`, `"all"`, `"none"` or `"flatten"`.
#' If `"average_ch_last"` / `"max_ch_last"` or `"average_ch_first"` / `"max_ch_first"`, global average/max pooling is applied;
#' the `_ch_first` / `_ch_last` suffix decides along which axis. `"both_ch_first"` / `"both_ch_last"` use max and average pooling together, and `"all"` uses all 4
#' global pooling options together. If `"flatten"`, the output of the last attention block is flattened. If `"none"`, no flattening is applied.
#' @examples
#'
#' maxlen <- 50
#' \donttest{
#' library(keras)
#' model <- create_model_transformer(maxlen = maxlen,
#'                                   head_size = c(10, 12),
#'                                   num_heads = c(7, 8),
#'                                   ff_dim = c(5, 9),
#'                                   dropout = c(0.3, 0.5))
#' }
#' @returns A keras model implementing transformer architecture.
#' @export
create_model_transformer <- function(maxlen,
                                     vocabulary_size = 4,
                                     embed_dim = 64,
                                     pos_encoding = "embedding",
                                     head_size = 4L,
                                     num_heads = 5L,
                                     ff_dim = 8,
                                     dropout = 0,
                                     n = 10000, # frequency of sine waves for sinusoidal positional encoding
                                     layer_dense = 2,
                                     dropout_dense = NULL,
                                     flatten_method = "flatten",
                                     last_layer_activation = "softmax",
                                     loss_fn = "categorical_crossentropy",
                                     solver = "adam",
                                     learning_rate = 0.01,
                                     label_noise_matrix = NULL,
                                     bal_acc = FALSE,
                                     f1_metric = FALSE,
                                     auc_metric = FALSE,
                                     label_smoothing = 0,
                                     verbose = TRUE,
                                     model_seed = NULL,
                                     mixed_precision = FALSE,
                                     mirrored_strategy = NULL) {

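  # optional mixed precision policy (float16 compute); the final output layer below is kept in float32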
  if (mixed_precision) tensorflow::tf$keras$mixed_precision$set_global_policy("mixed_float16")

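  # with more than one GPU available, build the model inside a MirroredStrategy scope for multi-GPU training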
  if (is.null(mirrored_strategy)) mirrored_strategy <- ifelse(count_gpu() > 1, TRUE, FALSE)
  if (mirrored_strategy) {
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      argg <- as.list(environment())
      argg$mirrored_strategy <- FALSE
      model <- do.call(create_model_transformer, argg)
    })
    return(model)
  }

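  # argument checks: head_size, num_heads, dropout and ff_dim need one entry per attention block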
  stopifnot(length(head_size) == length(num_heads))
  stopifnot(length(head_size) == length(dropout))
  stopifnot(length(head_size) == length(ff_dim))
  stopifnot(flatten_method %in% c("max_ch_first", "max_ch_last", "average_ch_first",
                                  "average_ch_last", "both_ch_first", "both_ch_last", "all", "none", "flatten"))
  stopifnot(pos_encoding %in% c("sinusoid", "embedding"))
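  # derive layer counts, coerce sizes to integer and optionally fix the random seed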
  num_dense_layers <- length(layer_dense)
  head_size <- as.integer(head_size)
  num_heads <- as.integer(num_heads)
  maxlen <- as.integer(maxlen)
  num_attention_blocks <- length(num_heads)
  vocabulary_size <- as.integer(vocabulary_size)
  if (!is.null(model_seed)) tensorflow::tf$random$set_seed(model_seed)

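  # input layer: one-hot matrix of shape (maxlen, vocabulary_size) if embed_dim == 0,
  # otherwise an integer sequence of length maxlen that gets embedded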
  if (embed_dim == 0) {
    input_tensor <- keras::layer_input(shape = c(maxlen, vocabulary_size))
  } else {
    input_tensor <- keras::layer_input(shape = c(maxlen))
  }

  # positional encoding
  if (pos_encoding == "sinusoid") {
    pos_enc_layer <- layer_pos_sinusoid_wrapper(maxlen = maxlen, vocabulary_size = vocabulary_size,
                                                n = n, embed_dim = embed_dim)
  }
  if (pos_encoding == "embedding") {
    pos_enc_layer <- layer_pos_embedding_wrapper(maxlen = maxlen, vocabulary_size = vocabulary_size,
                                                 embed_dim = embed_dim)
  }
  output_tensor <- input_tensor %>% pos_enc_layer

  # attention blocks
  for (i in 1:num_attention_blocks) {
    attn_block <- layer_transformer_block_wrapper(
      num_heads = num_heads[i],
      head_size = head_size[i],
      dropout_rate = dropout[i],
      ff_dim = ff_dim[i],
      embed_dim = embed_dim,
      vocabulary_size = vocabulary_size,
      load_r6 = FALSE)
    output_tensor <- output_tensor %>% attn_block
  }

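  # unless flatten_method is "none", reduce the attention output to a flat feature vector
  # (global pooling and/or flattening, see flatten_method)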
  if (flatten_method != "none") {
    output_tensor <- pooling_flatten(global_pooling = flatten_method, output_tensor = output_tensor)
  }

  # dense layers
  if (num_dense_layers > 1) {
    for (i in 1:(num_dense_layers - 1)) {
      output_tensor <- output_tensor %>% keras::layer_dense(units = layer_dense[i], activation = "relu")
      if (!is.null(dropout_dense)) {
        output_tensor <- output_tensor %>% keras::layer_dropout(rate = dropout_dense[i])
      }
    }
  }

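  # output layer; dtype = "float32" keeps the final activation in full precision when mixed precision is enabled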
  output_tensor <- output_tensor %>%
    keras::layer_dense(units = layer_dense[length(layer_dense)], activation = last_layer_activation, dtype = "float32")

  # create model
  model <- keras::keras_model(inputs = input_tensor, outputs = output_tensor)

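  # compile with the chosen optimizer, learning rate, loss and optional extra metrics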
  model <- compile_model(model = model, label_smoothing = label_smoothing, layer_dense = layer_dense,
                         solver = solver, learning_rate = learning_rate, loss_fn = loss_fn,
                         num_output_layers = 1, label_noise_matrix = label_noise_matrix,
                         bal_acc = bal_acc, f1_metric = f1_metric, auc_metric = auc_metric)

  if (verbose) print(model)

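  # attach the full argument list as hyperparameters to the model object (also set as a python-side attribute)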
  argg <- c(as.list(environment()))
  model <- add_hparam_list(model, argg)
  reticulate::py_set_attr(x = model, name = "hparam", value = model$hparam)

  model

}
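
# Illustrative call (not run; kept as a comment so nothing executes when this file is sourced),
# assuming integer-encoded input and sinusoidal positional encoding, a variant not covered by the
# roxygen example above; the argument values are placeholders:
# model <- create_model_transformer(maxlen = 100, pos_encoding = "sinusoid",
#                                   head_size = 16, num_heads = 4, ff_dim = 32, dropout = 0.1)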