Diff of /train_and_eval.py [000000] .. [6536f9]

Switch to unified view

a b/train_and_eval.py
1
import time
2
import torch
3
from torchmetrics import Accuracy
4
5
6
def train_model(model: torch.nn.Module,
7
               data_loader: torch.utils.data.DataLoader, 
8
               loss_fn: torch.nn.Module, #criterion
9
               optimizer: torch.optim.Optimizer,
10
               device: torch.device,
11
               num_epochs,
12
               output_shape):
13
    
14
    start_time = time.time()
15
    
16
    accuracy_metric = Accuracy(num_classes= output_shape, task='multiclass').to(device)
17
    for epoch in range(num_epochs):
18
        print(f'Epoch {epoch + 1}/{num_epochs}')
19
20
        model.train() 
21
        train_loss = 0
22
        for signal, class_label in data_loader: 
23
            signal, class_label = signal.to(device), class_label.to(device) #
24
            train_pred = model(signal)
25
            loss = loss_fn(train_pred, class_label)
26
            train_loss += loss.item() 
27
28
            accuracy_metric(train_pred, class_label)
29
30
            optimizer.zero_grad()
31
            loss.backward()
32
            optimizer.step()
33
34
        train_acc = accuracy_metric.compute() * 100  
35
        print(f"Train loss: {train_loss / len(data_loader):.5f} | Train accuracy: {train_acc:.2f}%")
36
        accuracy_metric.reset()
37
38
    total_time = (time.time() - start_time)
39
    print(f"\nTotal training time: {total_time} seconds")
40
    return total_time 
41
42
43
def evaluate_model(model: torch.nn.Module,
44
                   test_loader: torch.utils.data.DataLoader, 
45
                   loss_fn: torch.nn.Module, #criterion
46
                   device: torch.device,
47
                   output_shape):
48
    
49
    start_time = time.time()
50
51
    test_loss = 0
52
    accuracy_metric = Accuracy(num_classes=output_shape, task='multiclass').to(device)
53
54
    model.eval()
55
    with torch.inference_mode(): 
56
        for signal, class_label in test_loader:
57
            signal, class_label = signal.to(device), class_label.to(device)
58
            test_pred = model(signal)
59
            loss = loss_fn(test_pred, class_label)
60
            test_loss +=loss.item()
61
62
            accuracy_metric(test_pred, class_label)
63
64
    test_acc = accuracy_metric.compute() * 100  
65
    print(f"\nTest loss: {test_loss/len(test_loader):.5f} | Test accuracy: {test_acc:.2f}%")
66
    accuracy_metric.reset() 
67
68
    total_time = (time.time() - start_time)
69
    print(f"Total evaluation time: {total_time} seconds\n")
70
71
    return test_acc.item(), total_time