[dd9da8]: / code / 3_train.py

Download this file

30 lines (28 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import config
from utils_model import train_resnet
# Training the ResNet.
# Every setting is read from the `config` module (either directly or from
# its `args` namespace) and forwarded to `train_resnet` by keyword.
print("\n\n+++++ Running 3_train.py +++++")

_args = config.args
_train_kwargs = {
    "batch_size": _args.batch_size,
    "checkpoints_folder": _args.checkpoints_folder,
    "classes": config.classes,
    "color_jitter_brightness": _args.color_jitter_brightness,
    "color_jitter_contrast": _args.color_jitter_contrast,
    "color_jitter_hue": _args.color_jitter_hue,
    "color_jitter_saturation": _args.color_jitter_saturation,
    "device": config.device,
    "learning_rate": _args.learning_rate,
    "learning_rate_decay": _args.learning_rate_decay,
    "log_csv": config.log_csv,
    "num_classes": config.num_classes,
    "num_layers": _args.num_layers,
    "num_workers": _args.num_workers,
    "path_mean": config.path_mean,
    "path_std": config.path_std,
    "pretrain": _args.pretrain,
    "resume_checkpoint": _args.resume_checkpoint,
    "resume_checkpoint_path": config.resume_checkpoint_path,
    "save_interval": _args.save_interval,
    "num_epochs": _args.num_epochs,
    "train_folder": _args.train_folder,
    "weight_decay": _args.weight_decay,
}

# Keyword unpacking keeps the call equivalent to spelling out each argument.
train_resnet(**_train_kwargs)

print("+++++ Finished running 3_train.py +++++\n\n")