"""
Training and testing for OmiEmbed
"""
import time
import warnings
from util import util
from params.train_test_params import TrainTestParams
from datasets import create_separate_dataloader
from models import create_model
from util.visualizer import Visualizer


if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    full_start_time = time.time()

    # Get parameters
    param = TrainTestParams().parse()
    if param.deterministic:
        util.setup_seed(param.seed)

    # Dataset related
    full_dataloader, train_dataloader, val_dataloader, test_dataloader = create_separate_dataloader(param)
    print('The size of training set is {}'.format(len(train_dataloader)))
    # Get sample list for the dataset
    param.sample_list = full_dataloader.get_sample_list()
    # Get the dimension of input omics data
    param.omics_dims = full_dataloader.get_omics_dims()
    if param.downstream_task in ['classification', 'multitask', 'alltask']:
        # Get the number of classes for the classification task
        if param.class_num == 0:
            param.class_num = full_dataloader.get_class_num()
        if param.downstream_task != 'alltask':
            print('The number of classes: {}'.format(param.class_num))
    if param.downstream_task in ['regression', 'multitask', 'alltask']:
        # Get the range of the target values and use the max as the scale unless one was given
        values_min = full_dataloader.get_values_min()
        values_max = full_dataloader.get_values_max()
        if param.regression_scale == 1:
            param.regression_scale = values_max
        print('The range of the target values is [{}, {}]'.format(values_min, values_max))
    if param.downstream_task in ['survival', 'multitask', 'alltask']:
        # Get the range of survival time T; -1 means "use the max observed in the data"
        survival_T_min = full_dataloader.get_survival_T_min()
        survival_T_max = full_dataloader.get_survival_T_max()
        if param.survival_T_max == -1:
            param.survival_T_max = survival_T_max
        print('The range of survival T is [{}, {}]'.format(survival_T_min, survival_T_max))

    # Model related
    model = create_model(param)     # Create a model given param.model and other parameters
    model.setup(param)              # Regular setup for the model: load and print networks, create schedulers
    visualizer = Visualizer(param)  # Create a visualizer to print results

    # Start the epoch loop
    visualizer.print_phase(model.phase)
    for epoch in range(param.epoch_count, param.epoch_num + 1):     # Outer loop over epochs
        epoch_start_time = time.time()                              # Start time of this epoch
        model.epoch = epoch

        # TRAINING
        model.set_train()                                           # Set train mode for training
        iter_load_start_time = time.time()                          # Start time of data loading for this iteration
        output_dict, losses_dict, metrics_dict = model.init_log_dict()  # Initialize the log dictionaries
        # Phase schedule: p1 for the first epoch_num_p1 epochs, then p2, then p3
        if epoch == param.epoch_num_p1 + 1:
            model.phase = 'p2'                                      # Change to phase 2 (supervised)
            visualizer.print_phase(model.phase)
        if epoch == param.epoch_num_p1 + param.epoch_num_p2 + 1:
            model.phase = 'p3'                                      # Change to phase 3 — TODO confirm intended semantics vs p2
            visualizer.print_phase(model.phase)

        # Start training loop
        train_dataset_size = len(train_dataloader)                  # Number of batches per epoch (loop invariant)
        load_time = 0.0                                             # Defensive default; set on every print_freq-th iteration
        for i, data in enumerate(train_dataloader):                 # Inner loop over iterations within one epoch
            model.iter = i
            actual_batch_size = len(data['index'])                  # Last batch may be smaller than param.batch_size
            iter_start_time = time.time()                           # Timer for computation per iteration
            if i % param.print_freq == 0:
                load_time = iter_start_time - iter_load_start_time  # Data loading time for this iteration
            model.set_input(data)                                   # Unpack input data from the output dictionary of the dataloader
            model.update()                                          # Calculate losses, gradients and update network parameters
            model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size)  # Update the log dictionaries
            if i % param.print_freq == 0:                           # Print training losses and save logging information to the disk
                comp_time = time.time() - iter_start_time           # Computational time for this iteration
                visualizer.print_train_log(epoch, i, losses_dict, metrics_dict, load_time, comp_time, param.batch_size, train_dataset_size)
            iter_load_start_time = time.time()

        # Model saving
        if param.save_model:
            if param.save_epoch_freq == -1:                         # Only save networks during last epoch
                if epoch == param.epoch_num:
                    print('Saving the model at the end of epoch {:d}'.format(epoch))
                    model.save_networks(str(epoch))
            elif epoch % param.save_epoch_freq == 0:                # Save the networks every <save_epoch_freq> epochs
                print('Saving the model at the end of epoch {:d}'.format(epoch))
                model.save_networks(str(epoch))

        train_time = time.time() - epoch_start_time
        current_lr = model.update_learning_rate()                   # Update learning rates at the end of each epoch
        visualizer.print_train_summary(epoch, losses_dict, output_dict, train_time, current_lr)

        # TESTING
        model.set_eval()                                            # Set eval mode for testing
        test_start_time = time.time()                               # Start time of testing
        output_dict, losses_dict, metrics_dict = model.init_log_dict()  # Re-initialize the log dictionaries for testing

        # Start testing loop
        test_dataset_size = len(test_dataloader)                    # Number of test batches (loop invariant)
        for i, data in enumerate(test_dataloader):
            actual_batch_size = len(data['index'])
            model.set_input(data)                                   # Unpack input data from the output dictionary of the dataloader
            model.test()                                            # Run forward to get the output tensors
            model.update_log_dict(output_dict, losses_dict, metrics_dict, actual_batch_size)  # Update the log dictionaries
            if i % param.print_freq == 0:                           # Print testing log
                visualizer.print_test_log(epoch, i, losses_dict, metrics_dict, param.batch_size, test_dataset_size)

        test_time = time.time() - test_start_time
        visualizer.print_test_summary(epoch, losses_dict, output_dict, test_time)
        if epoch == param.epoch_num:                                # Persist final-epoch outputs only
            visualizer.save_output_dict(output_dict)

    full_time = time.time() - full_start_time
    print('Full running time: {:.3f}s'.format(full_time))