Diff of /src/model/utils.py [000000] .. [780764]

Switch to side-by-side view

--- a
+++ b/src/model/utils.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def find_all_linear_names(model):
+    cls = torch.nn.Linear
+    lora_module_names = set()
+    multimodal_keywords = ["vision_tower", "multi_modal_projector"]
+    for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
+        if isinstance(module, cls):
+            names = name.split('.')
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+    if 'lm_head' in lora_module_names:
+        lora_module_names.remove("lm_head")
+    return list(lora_module_names)