|
a |
|
b/src/model/utils.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
def find_all_linear_names(model): |
|
|
5 |
cls = torch.nn.Linear |
|
|
6 |
lora_module_names = set() |
|
|
7 |
multimodal_keywords = ["vision_tower", "multi_modal_projector"] |
|
|
8 |
for name, module in model.named_modules(): |
|
|
9 |
if any(mm_keyword in name for mm_keyword in multimodal_keywords): |
|
|
10 |
continue |
|
|
11 |
if isinstance(module, cls): |
|
|
12 |
names = name.split('.') |
|
|
13 |
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) |
|
|
14 |
if 'lm_head' in lora_module_names: |
|
|
15 |
lora_module_names.remove("lm_head") |
|
|
16 |
return list(lora_module_names) |