Diff of /src/model.py [000000] .. [b798eb]


b/src/model.py
## ====== Torch imports ======
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torchmetrics.classification import F1Score


class LModel(pl.LightningModule):
    def __init__(self,
                 dim: int,
                 output_dim: int = 8,
                 batch_size: int = 20,
                 lr: float = 1e-2,
                 weight_decay: float = 1e-3,
                 ):
        super().__init__()
        self.init_dim = dim
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.output_dim = output_dim
        # Binary task, so BCE over sigmoid probabilities; nn.CrossEntropyLoss
        # is the multi-class alternative. (nn.BCEWithLogitsLoss on raw logits
        # would be the numerically safer variant.)
        self.loss_fun = nn.BCELoss()
        self.f1 = F1Score(task='binary', average='micro')
        # Placeholders for collecting test-time outputs; not populated yet.
        self.test_preds = None
        self.test_labels = None

        # MLP that compresses the input to an output_dim-sized embedding.
        self.embedder = nn.Sequential(
            nn.Linear(dim, 2**12),
            nn.LeakyReLU(),
            nn.Linear(2**12, 2**8),
            nn.LeakyReLU(),
            nn.Linear(2**8, 2**6),
            nn.LeakyReLU(),
            nn.Linear(2**6, output_dim),
            nn.LeakyReLU()
        )
        # Binary head: a single logit squashed to a probability. Sigmoid,
        # not Softmax: softmax over a one-element dimension is always 1.0,
        # so BCELoss would see a constant prediction.
        self.classifier = nn.Sequential(
            nn.Linear(output_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.embedder(x)
        x = self.classifier(x)
        return x

    def on_train_epoch_start(self) -> None:
        # Running-loss accumulator (currently unused by the steps below).
        self.train_loss = 0

    def training_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        # BCELoss needs float targets of shape (batch, 1).
        y = y_batch.unsqueeze(1).float()
        X_embedded = self.embedder(X_batch)
        y_pred = self.classifier(X_embedded)
        loss = self.loss_fun(y_pred, y)
        f1 = self.f1(y_pred, y)
        # (An earlier variant added manual L1/L2 penalties over the embedder
        # weights; AdamW's weight_decay now supplies the L2 term.)
        self.log("train_loss",
                 loss,
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True)
        self.log("train_f1",
                 f1,
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        y = y_batch.unsqueeze(1).float()
        X_embedded = self.embedder(X_batch)
        y_pred = self.classifier(X_embedded)
        loss = self.loss_fun(y_pred, y)
        f1 = self.f1(y_pred, y)
        self.log("val_loss",
                 loss,
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True)
        self.log("val_f1",
                 f1,
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True)
        return loss

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # self.parameters() already covers both the embedder and the
        # classifier submodules.
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.lr,
                                      weight_decay=self.weight_decay)
        return optimizer

    def test_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        y = y_batch.unsqueeze(1).float()
        y_pred = self(X_batch)
        loss = self.loss_fun(y_pred, y)
        f1 = self.f1(y_pred, y)
        self.log("test_loss",
                 loss,
                 on_step=False,
                 on_epoch=True)
        self.log("test_f1",
                 f1,
                 on_step=False,
                 on_epoch=True)
        return {'loss': loss,
                'f1score': f1}
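
For reference, a minimal training sketch (not part of this diff). It assumes LModel is importable as src.model, picks dim=128 arbitrarily, and stands in synthetic random tensors for the real dataset; substitute your own DataLoaders.

# Minimal usage sketch (assumptions: src.model is importable, dim=128,
# synthetic 0/1 targets in place of the real labels).
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset

from src.model import LModel

X = torch.randn(200, 128)                 # 200 samples, 128 features (assumed)
y = torch.randint(0, 2, (200,)).float()   # float 0/1 targets, as BCELoss expects
loader = DataLoader(TensorDataset(X, y), batch_size=20)

model = LModel(dim=128)
trainer = pl.Trainer(max_epochs=5, logger=False, enable_checkpointing=False)
trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
trainer.test(model, dataloaders=loader)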