--- a +++ b/src/model/utils.py @@ -0,0 +1,16 @@ +import torch + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["vision_tower", "multi_modal_projector"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + if 'lm_head' in lora_module_names: + lora_module_names.remove("lm_head") + return list(lora_module_names)