
test/diseasedb/linear_model.py
import torch
from torch import nn


class LinearModel(nn.Module):

    def __init__(self, label_count, embedding_type, embedding, freeze_embedding=True):
        super().__init__()
        self.embedding_type = embedding_type
        if self.embedding_type in ["word", "cui"]:
            # `freeze` controls whether the pretrained embedding weights stay fixed.
            self.embedding = nn.Embedding.from_pretrained(embedding, freeze=freeze_embedding)
            self.input_dim = self.embedding.weight.shape[1]
        elif self.embedding_type == "bert":
            self.embedding = embedding
            self.input_dim = 768  # hidden size of BERT-base
            if freeze_embedding:
                for param in self.embedding.parameters():
                    param.requires_grad = False
        else:
            raise ValueError(f"unsupported embedding_type: {embedding_type!r}")
        # The two inputs are embedded separately and concatenated, hence * 2.
        self.linear = nn.Linear(self.input_dim * 2, label_count)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x0, x1, length_0=None, length_1=None, label=None):
        count = x0.shape[0]
        # Embed both inputs in a single batch, then split the result afterwards.
        x = torch.cat((x0, x1), dim=0)
        emb = self.embedding(x)

        if self.embedding_type == "word":
            # Mean-pool the token embeddings, dividing by the true sequence lengths.
            emb = torch.sum(emb, dim=1)
            length = torch.cat((length_0, length_1)).reshape(-1, 1).expand_as(emb)
            emb = emb / length
        elif self.embedding_type == "bert":
            # Assumes the BERT module returns a (sequence_output, pooled_output)
            # tuple; keep the pooled output.
            emb = emb[1]
        # "cui" embeddings are already one vector per input and pass through as-is.

        emb_0 = emb[0:count]
        emb_1 = emb[count:]
        feature = torch.cat((emb_0, emb_1), dim=1)
        pred = self.linear(feature)

        if label is not None:
            loss = self.loss_fn(pred, label)
            return pred, loss
        return pred, 0.
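
Below is a minimal sketch of exercising the model in "word" mode; the vocabulary size, embedding dimension, batch size, and sequence length are illustrative assumptions rather than values from this repository.

import torch

# Hypothetical sizes for a quick smoke test.
vocab_size, emb_dim, label_count = 100, 16, 2
pretrained = torch.randn(vocab_size, emb_dim)        # stand-in embedding table
model = LinearModel(label_count, "word", pretrained)

batch, seq_len = 4, 7
x0 = torch.randint(0, vocab_size, (batch, seq_len))  # token ids, first input
x1 = torch.randint(0, vocab_size, (batch, seq_len))  # token ids, second input
length_0 = torch.full((batch,), float(seq_len))      # true lengths for mean pooling
length_1 = torch.full((batch,), float(seq_len))
label = torch.randint(0, label_count, (batch,))

pred, loss = model(x0, x1, length_0, length_1, label)
print(pred.shape, loss.item())                       # torch.Size([4, 2]) and a scalar loss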