|
a |
|
b/train_and_evaluate_model.py |
|
|
1 |
from tensorflow.keras.callbacks import EarlyStopping |
|
|
2 |
|
|
|
3 |
def train_and_evaluate_model(model, train_sequences, train_labels, test_sequences, test_labels): |
|
|
4 |
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy']) |
|
|
5 |
|
|
|
6 |
early_stopping = EarlyStopping(patience=3, monitor='val_loss', restore_best_weights=True) |
|
|
7 |
|
|
|
8 |
model.fit(train_sequences, train_labels, validation_data=(test_sequences, test_labels), |
|
|
9 |
epochs=10, batch_size=32, callbacks=[early_stopping]) |
|
|
10 |
|
|
|
11 |
_, accuracy = model.evaluate(test_sequences, test_labels) |
|
|
12 |
|
|
|
13 |
return accuracy |