|
a |
|
b/medicalbert/classifiers/util.py |
|
|
1 |
## These functions allow us to delete layers from a model, |
|
|
2 |
## Or if not possible, replace with an identity function, |
|
|
3 |
## So that it has no trainable parameters/effect. |
|
|
4 |
from torch import nn |
|
|
5 |
import copy |
|
|
6 |
|
|
|
7 |
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of *model* that keeps only the first
    ``num_layers_to_keep`` encoder layers.

    The original model is left untouched, and the returned copy shares no
    modules or parameters with it.

    Args:
        model: a full BERT model exposing ``model.bert.encoder.layer``
            (an ``nn.ModuleList``).
        num_layers_to_keep: number of leading encoder layers to retain.

    Returns:
        A deep-copied model whose encoder holds only the first
        ``num_layers_to_keep`` layers.
    """
    # Deep-copy FIRST, then slice the copy's own layer list. The previous
    # version sliced the original model's layers and assigned them into the
    # copy, so the "copy" silently shared the kept layer modules (and their
    # parameters) with the source model.
    copyOfModel = copy.deepcopy(model)

    # Keep only the leading layers of the copied encoder. Slicing an
    # nn.ModuleList yields the kept modules; re-wrap to guarantee the
    # attribute stays a registered nn.ModuleList.
    copyOfModel.bert.encoder.layer = nn.ModuleList(
        copyOfModel.bert.encoder.layer[:num_layers_to_keep]
    )

    return copyOfModel