src/ml_models/bilstm/model.py
# Base Dependencies
# -----------------
from typing import Dict

# Package Dependencies
# --------------------
from .embeddings import Embeddings, RDEmbeddings
from .encoders import RecurrentEncoder
from .config import EmbeddingConfig, RDEmbeddingConfig, LSTMConfig

# Local Dependencies
# ------------------
from vocabulary import Vocabulary

# PyTorch Dependencies
# --------------------
from torch import nn, Tensor, concat, mean


# Model
# -----
class HasanModel(nn.Module):
    """
    Implementation of the BiLSTM model described in `Hasan et al. (2020) - Integrating
    Text Embedding with Traditional NLP Features for Clinical Relation Extraction`
    """

    def __init__(
        self,
        vocab: Vocabulary,
        lstm_config: LSTMConfig,
        bioword2vec_config: EmbeddingConfig,
        rd_config: RDEmbeddingConfig,
        pos_config: EmbeddingConfig,
        dep_config: EmbeddingConfig,
        iob_config: EmbeddingConfig,
        num_classes: int = 2,
        clf_dropout: float = 0.25,
    ):
        """Initializes the model
41
42
        Args:
43
            vocab (Vocabulary): vocabulary object
44
            lstm_config (LSTMConfig): configuration for the LSTM encoder
45
            bioword2vec_config (EmbeddingConfig): configuration for the BioWord2Vec embedding
46
            rd_config (RDEmbeddingConfig): configuration for the Relative Distance embedding
47
            pos_config (EmbeddingConfig): configuration for the POS embedding
48
            dep_config (EmbeddingConfig): configuration for the DEP embedding
49
            iob_config (EmbeddingConfig): configuration for the IOB embedding
50
            num_classes (int, optional): number of output classes. Defaults to 2.
51
            clf_dropout (float, optional): dropout rate. Defaults to 0.1.
52
        """
53
        super(HasanModel, self).__init__()
54
55
        # attributes
56
        self.vocab = vocab
57
        self.lstm_config = lstm_config
58
        self.bioword2vec_config = bioword2vec_config
59
        self.rd_config = rd_config
60
        self.pos_config = pos_config
61
        self.dep_config = dep_config
62
        self.iob_config = iob_config
63
        self.num_classes = num_classes
64
        self.num_directions = 2 if self.lstm_config.bidirectional else 1
65
        self.clf_hidden_dim = 64
66
        self.clf_dropout = clf_dropout
67
68
        # embedding layers
69
        self.wv_embedding = Embeddings(**self.bioword2vec_config.__dict__)
70
        self.rd_embedding = RDEmbeddings(**self.rd_config.__dict__)
71
        self.pos_embedding = Embeddings(**self.pos_config.__dict__)
72
        self.dep_embedding = Embeddings(**self.dep_config.__dict__)
73
        self.iob_embedding = Embeddings(**self.iob_config.__dict__)
74
75
        # BiLSTM encoder
76
        self.lstm = RecurrentEncoder(rnn_type="lstm", **self.lstm_config.__dict__)
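        # RecurrentEncoder is assumed to return (outputs, hidden_concat), where
        # hidden_concat stacks the final hidden state of each direction, giving a
        # [batch_size, num_directions * hidden_size] tensor (see clf_input_dim and forward).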

        # classifier
        self.fc = nn.Sequential(
            nn.Dropout(p=self.clf_dropout),
            nn.Linear(self.clf_input_dim, self.clf_hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=self.clf_dropout),
            nn.Linear(self.clf_hidden_dim, self.num_classes),
            nn.Sigmoid(),
        )

        # load pretrained embeddings
        self.wv_embedding.load_from_file(self.bioword2vec_config.emb_path, self.vocab)

    @property
    def clf_input_dim(self) -> int:
        """Input dimensions of the classifier"""
        return (self.num_directions * self.lstm_config.hidden_size) + (
            2 * self.wv_embedding.embedding_dim
        )
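
    # Worked example (illustrative values, not taken from this repo's configs):
    # with a bidirectional LSTM of hidden_size 128 and 200-dimensional BioWord2Vec
    # vectors, clf_input_dim = 2 * 128 + 2 * 200 = 656.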

    def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
        """Forward pass of the model

        Args:
            inputs (Dict[str, Tensor]): batch of input tensors keyed by "e1", "e2",
                "sent", "rd1", "rd2", "pos", "dep", "iob" and "seq_length"

        Returns:
            Tensor: class scores of shape [batch_size, num_classes]
        """
        e1: Tensor = inputs["e1"]  # [batch_size, max_len_e1]
        e2: Tensor = inputs["e2"]  # [batch_size, max_len_e2]
        sent: Tensor = inputs["sent"]  # [batch_size, max_len_seq]
        rd1: Tensor = inputs["rd1"]  # [batch_size, max_len_seq]
        rd2: Tensor = inputs["rd2"]  # [batch_size, max_len_seq]
        pos: Tensor = inputs["pos"]  # [batch_size, max_len_seq]
        dep: Tensor = inputs["dep"]  # [batch_size, max_len_seq]
        iob: Tensor = inputs["iob"]  # [batch_size, max_len_seq]
        seq_length: Tensor = inputs["seq_length"]  # [batch_size]

        assert len(e1.shape) == 2
        assert len(e2.shape) == 2
        assert len(sent.shape) == 2
        assert len(rd1.shape) == 2
        assert len(rd2.shape) == 2
        assert len(pos.shape) == 2
        assert len(dep.shape) == 2
        assert len(iob.shape) == 2
        assert len(seq_length.shape) == 1

        # embedded inputs
        e1_emb = mean(self.wv_embedding(e1), dim=1)  # [batch_size, wv_emb_dim]
        e2_emb = mean(self.wv_embedding(e2), dim=1)  # [batch_size, wv_emb_dim]
        sent_emb = self.wv_embedding(sent)  # [batch_size, seq_length, wv_emb_dim]
        rd1_emb = self.rd_embedding(rd1)  # [batch_size, seq_length, rd_emb_dim]
        rd2_emb = self.rd_embedding(rd2)  # [batch_size, seq_length, rd_emb_dim]
        pos_emb = self.pos_embedding(pos)  # [batch_size, seq_length, pos_emb_dim]
        dep_emb = self.dep_embedding(dep)  # [batch_size, seq_length, dep_emb_dim]
        iob_emb = self.iob_embedding(iob)  # [batch_size, seq_length, iob_emb_dim]

        # encode
        inputs_emb = concat((sent_emb, rd1_emb, rd2_emb, pos_emb, dep_emb, iob_emb), dim=2)
        outputs_emb, hidden_concat = self.lstm(inputs_emb, seq_length)
        outputs = concat((e1_emb, e2_emb, hidden_concat), dim=1)

        # classify
        outputs = self.fc(outputs)

        return outputs
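

# Usage sketch (illustrative only; the vocabulary and config objects are assumed
# to be built elsewhere in this package, e.g. from defaults or a config file):
#
#   model = HasanModel(
#       vocab=vocab,
#       lstm_config=lstm_config,
#       bioword2vec_config=bioword2vec_config,
#       rd_config=rd_config,
#       pos_config=pos_config,
#       dep_config=dep_config,
#       iob_config=iob_config,
#       num_classes=2,
#   )
#   inputs = {
#       "e1": e1_ids,           # [batch_size, max_len_e1] token indices
#       "e2": e2_ids,           # [batch_size, max_len_e2] token indices
#       "sent": sent_ids,       # [batch_size, max_len_seq] token indices
#       "rd1": rd1_ids,         # [batch_size, max_len_seq] relative distances to e1
#       "rd2": rd2_ids,         # [batch_size, max_len_seq] relative distances to e2
#       "pos": pos_ids,         # [batch_size, max_len_seq] POS tag indices
#       "dep": dep_ids,         # [batch_size, max_len_seq] dependency label indices
#       "iob": iob_ids,         # [batch_size, max_len_seq] IOB tag indices
#       "seq_length": lengths,  # [batch_size] unpadded sequence lengths
#   }
#   probs = model(inputs)       # [batch_size, num_classes] sigmoid activations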