--- a +++ b/medicalbert/classifiers/util.py @@ -0,0 +1,19 @@ +##These functions allow us to delete layers from a model, +## Or if not possible, replace with an identity function, +## So that it has no trainable parameters/effect. +from torch import nn +import copy + +def deleteEncodingLayers(model, num_layers_to_keep): # must pass in the full bert model + oldModuleList = model.bert.encoder.layer + newModuleList = nn.ModuleList() + + # Now iterate over all layers, only keepign only the relevant layers. + for i in range(0, num_layers_to_keep): + newModuleList.append(oldModuleList[i]) + + # create a copy of the model, modify it with the new list, and return + copyOfModel = copy.deepcopy(model) + copyOfModel.bert.encoder.layer = newModuleList + + return copyOfModel \ No newline at end of file