|
a |
|
b/code/3_train.py |
|
|
1 |
import config |
|
|
2 |
from utils_model import train_resnet |
|
|
3 |
|
|
|
4 |
# Training the ResNet. |
|
|
5 |
print("\n\n+++++ Running 3_train.py +++++") |
|
|
6 |
train_resnet(batch_size=config.args.batch_size, |
|
|
7 |
checkpoints_folder=config.args.checkpoints_folder, |
|
|
8 |
classes=config.classes, |
|
|
9 |
color_jitter_brightness=config.args.color_jitter_brightness, |
|
|
10 |
color_jitter_contrast=config.args.color_jitter_contrast, |
|
|
11 |
color_jitter_hue=config.args.color_jitter_hue, |
|
|
12 |
color_jitter_saturation=config.args.color_jitter_saturation, |
|
|
13 |
device=config.device, |
|
|
14 |
learning_rate=config.args.learning_rate, |
|
|
15 |
learning_rate_decay=config.args.learning_rate_decay, |
|
|
16 |
log_csv=config.log_csv, |
|
|
17 |
num_classes=config.num_classes, |
|
|
18 |
num_layers=config.args.num_layers, |
|
|
19 |
num_workers=config.args.num_workers, |
|
|
20 |
path_mean=config.path_mean, |
|
|
21 |
path_std=config.path_std, |
|
|
22 |
pretrain=config.args.pretrain, |
|
|
23 |
resume_checkpoint=config.args.resume_checkpoint, |
|
|
24 |
resume_checkpoint_path=config.resume_checkpoint_path, |
|
|
25 |
save_interval=config.args.save_interval, |
|
|
26 |
num_epochs=config.args.num_epochs, |
|
|
27 |
train_folder=config.args.train_folder, |
|
|
28 |
weight_decay=config.args.weight_decay) |
|
|
29 |
print("+++++ Finished running 3_train.py +++++\n\n") |