# medicalbert/classifiers/util.py
#
# These functions allow us to delete layers from a model or, if that is not
# possible, replace them with an identity function, so that they have no
# trainable parameters or effect.
from torch import nn
5
import copy
6
7
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a copy of ``model`` that keeps only its first encoder layers.

    ``model`` itself is left unmodified, and the returned copy shares no
    submodules or parameters with it (the layers are selected from the deep
    copy, not from the original).

    Args:
        model: The full model; its encoder layers must be reachable as
            ``model.bert.encoder.layer`` (indexable and supporting ``len``).
        num_layers_to_keep: How many of the leading encoder layers to keep.
            A value <= 0 leaves the copy with an empty encoder layer list.

    Returns:
        A deep copy of ``model`` whose ``bert.encoder.layer`` is a new
        ``nn.ModuleList`` holding (copies of) the first
        ``num_layers_to_keep`` layers, in their original order.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of encoder
            layers in ``model``.
    """
    num_layers = len(model.bert.encoder.layer)
    if num_layers_to_keep > num_layers:
        # Fail before paying for a deep copy of the whole model.
        raise IndexError(
            f"cannot keep {num_layers_to_keep} encoder layers; "
            f"model only has {num_layers}"
        )

    # Copy first, then pick layers from the COPY. Picking them from the
    # original model would make the returned model share those layers (and
    # their trainable parameters) with the original.
    copyOfModel = copy.deepcopy(model)
    oldModuleList = copyOfModel.bert.encoder.layer

    # range() rather than slicing so that a negative count still yields an
    # empty list instead of dropping layers from the end.
    # NOTE(review): any layer-count stored in a model config is not updated
    # here — confirm no caller relies on it matching the new layer list.
    copyOfModel.bert.encoder.layer = nn.ModuleList(
        [oldModuleList[i] for i in range(num_layers_to_keep)]
    )
    return copyOfModel