b/src/ml_models/bilstm/model.py

# Base Dependencies
# -----------------
from typing import Dict

# Package Dependencies
# --------------------
from .embeddings import Embeddings, RDEmbeddings
from .encoders import RecurrentEncoder
from .config import EmbeddingConfig, RDEmbeddingConfig, LSTMConfig

# Local Dependencies
# ------------------
from vocabulary import Vocabulary

# PyTorch Dependencies
# --------------------
from torch import nn, Tensor, concat, mean


# Model
# -----
class HasanModel(nn.Module):
    """
    Implementation of the BiLSTM model described in `Hasan et al. (2020) - Integrating
    Text Embedding with Traditional NLP Features for Clinical Relation Extraction`
    """

    def __init__(
        self,
        vocab: Vocabulary,
        lstm_config: LSTMConfig,
        bioword2vec_config: EmbeddingConfig,
        rd_config: RDEmbeddingConfig,
        pos_config: EmbeddingConfig,
        dep_config: EmbeddingConfig,
        iob_config: EmbeddingConfig,
        num_classes: int = 2,
        clf_dropout: float = 0.25,
    ):
        """Initializes the model

        Args:
            vocab (Vocabulary): vocabulary object
            lstm_config (LSTMConfig): configuration for the LSTM encoder
            bioword2vec_config (EmbeddingConfig): configuration for the BioWord2Vec embedding
            rd_config (RDEmbeddingConfig): configuration for the Relative Distance embedding
            pos_config (EmbeddingConfig): configuration for the POS embedding
            dep_config (EmbeddingConfig): configuration for the DEP embedding
            iob_config (EmbeddingConfig): configuration for the IOB embedding
            num_classes (int, optional): number of output classes. Defaults to 2.
            clf_dropout (float, optional): dropout rate. Defaults to 0.25.
        """
        super(HasanModel, self).__init__()

        # attributes
        self.vocab = vocab
        self.lstm_config = lstm_config
        self.bioword2vec_config = bioword2vec_config
        self.rd_config = rd_config
        self.pos_config = pos_config
        self.dep_config = dep_config
        self.iob_config = iob_config
        self.num_classes = num_classes
        self.num_directions = 2 if self.lstm_config.bidirectional else 1
        self.clf_hidden_dim = 64
        self.clf_dropout = clf_dropout

        # embedding layers
        self.wv_embedding = Embeddings(**self.bioword2vec_config.__dict__)
        self.rd_embedding = RDEmbeddings(**self.rd_config.__dict__)
        self.pos_embedding = Embeddings(**self.pos_config.__dict__)
        self.dep_embedding = Embeddings(**self.dep_config.__dict__)
        self.iob_embedding = Embeddings(**self.iob_config.__dict__)
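        # (wv: BioWord2Vec word vectors, rd: relative distances to the two
        #  entities, pos: POS tags, dep: DEP tags, iob: IOB tags)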
|
|
        # BiLSTM encoder
        self.lstm = RecurrentEncoder(rnn_type="lstm", **self.lstm_config.__dict__)

        # classifier
        self.fc = nn.Sequential(
            nn.Dropout(p=self.clf_dropout),
            nn.Linear(self.clf_input_dim, self.clf_hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=self.clf_dropout),
            nn.Linear(self.clf_hidden_dim, self.num_classes),
            nn.Sigmoid(),
        )

        # load pretrained embeddings
        self.wv_embedding.load_from_file(self.bioword2vec_config.emb_path, self.vocab)

    @property
    def clf_input_dim(self) -> int:
        """Input dimensions of the classifier"""
        return (self.num_directions * self.lstm_config.hidden_size) + (
            2 * self.wv_embedding.embedding_dim
        )
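
    # Worked example (illustrative numbers, not taken from any project config):
    # a bidirectional LSTM with hidden_size=128 and 200-dimensional BioWord2Vec
    # vectors gives clf_input_dim = (2 * 128) + (2 * 200) = 656, i.e. the
    # concatenated final hidden states plus the two mean-pooled entity embeddings.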
|
|
    def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
        """Forward pass of the model

        Args:
            inputs (Dict[str, Tensor]): input tensors

        Returns:
            Tensor: output tensor
        """
        e1: Tensor = inputs["e1"]  # [batch_size, max_len_e1]
        e2: Tensor = inputs["e2"]  # [batch_size, max_len_e2]
        sent: Tensor = inputs["sent"]  # [batch_size, max_len_seq]
        rd1: Tensor = inputs["rd1"]  # [batch_size, max_len_seq]
        rd2: Tensor = inputs["rd2"]  # [batch_size, max_len_seq]
        pos: Tensor = inputs["pos"]  # [batch_size, max_len_seq]
        dep: Tensor = inputs["dep"]  # [batch_size, max_len_seq]
        iob: Tensor = inputs["iob"]  # [batch_size, max_len_seq]
        seq_length: Tensor = inputs["seq_length"]  # [batch_size]

        assert len(e1.shape) == 2
        assert len(e2.shape) == 2
        assert len(sent.shape) == 2
        assert len(rd1.shape) == 2
        assert len(rd2.shape) == 2
        assert len(pos.shape) == 2
        assert len(dep.shape) == 2
        assert len(iob.shape) == 2
        assert len(seq_length.shape) == 1

        # embedded inputs
        e1_emb = mean(self.wv_embedding(e1), axis=1)  # [batch_size, wv_emb_dim]
        e2_emb = mean(self.wv_embedding(e2), axis=1)  # [batch_size, wv_emb_dim]
        sent_emb = self.wv_embedding(sent)  # [batch_size, seq_length, wv_emb_dim]
        rd1_emb = self.rd_embedding(rd1)  # [batch_size, seq_length, rd_emb_dim]
        rd2_emb = self.rd_embedding(rd2)  # [batch_size, seq_length, rd_emb_dim]
        pos_emb = self.pos_embedding(pos)  # [batch_size, seq_length, pos_emb_dim]
        dep_emb = self.dep_embedding(dep)  # [batch_size, seq_length, dep_emb_dim]
        iob_emb = self.iob_embedding(iob)  # [batch_size, seq_length, iob_emb_dim]

        # encode
        inputs_emb = concat((sent_emb, rd1_emb, rd2_emb, pos_emb, dep_emb, iob_emb), axis=2)
        outputs_emb, hidden_concat = self.lstm(inputs_emb, seq_length)
        outputs = concat((e1_emb, e2_emb, hidden_concat), axis=1)
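        # Shape summary (follows from the embedding comments and clf_input_dim):
        #   inputs_emb:    [batch_size, seq_length, wv + 2*rd + pos + dep + iob emb dims]
        #   hidden_concat: [batch_size, num_directions * hidden_size]
        #   outputs:       [batch_size, clf_input_dim]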
|
|
        # classify
        outputs = self.fc(outputs)

        return outputs
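

# ---------------------------------------------------------------------------
# Usage sketch (illustrative): the constructors of Vocabulary and the *Config
# classes live in sibling modules, so only the expected `forward` input format
# is sketched here. Keys and shapes follow the docstring comments above;
# the tensor contents and index ranges are random placeholders.
# ---------------------------------------------------------------------------
# import torch
#
# batch_size, max_len_e1, max_len_e2, max_len_seq = 4, 3, 3, 32
# inputs = {
#     "e1": torch.randint(0, 100, (batch_size, max_len_e1)),
#     "e2": torch.randint(0, 100, (batch_size, max_len_e2)),
#     "sent": torch.randint(0, 100, (batch_size, max_len_seq)),
#     "rd1": torch.randint(0, 10, (batch_size, max_len_seq)),
#     "rd2": torch.randint(0, 10, (batch_size, max_len_seq)),
#     "pos": torch.randint(0, 20, (batch_size, max_len_seq)),
#     "dep": torch.randint(0, 20, (batch_size, max_len_seq)),
#     "iob": torch.randint(0, 3, (batch_size, max_len_seq)),
#     "seq_length": torch.full((batch_size,), max_len_seq),
# }
# outputs = model(inputs)  # model: a HasanModel built with project-specific configs
# assert outputs.shape == (batch_size, model.num_classes)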