|
38 | 38 | C3AConfig, |
39 | 39 | DeloraConfig, |
40 | 40 | FourierFTConfig, |
| 41 | + GraloraConfig, |
41 | 42 | HRAConfig, |
42 | 43 | IA3Config, |
43 | 44 | LNTuningConfig, |
|
665 | 666 | "init_weights": True, |
666 | 667 | }, |
667 | 668 | ), |
| 669 | + ########### |
| 670 | + # GraLoRA # |
| 671 | + ########### |
| 672 | + ("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}), |
| 673 | + ("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}), |
| 674 | + ("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}), |
| 675 | + ("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}), |
| 676 | + ( |
| 677 | + "Vanilla MLP 5 GraLoRA", |
| 678 | + "MLP", |
| 679 | + GraloraConfig, |
| 680 | + {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, |
| 681 | + ), |
| 682 | + ( |
| 683 | + "Embedding + transformers Conv1D 1 GraLoRA", |
| 684 | + "EmbConv1D", |
| 685 | + GraloraConfig, |
| 686 | + {"target_modules": ["conv1d"], "gralora_k": 1}, |
| 687 | + ), |
668 | 688 | ########## |
669 | 689 | # VBLoRA # |
670 | 690 | ########## |
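
The block above wires `GraloraConfig` into the custom-model parametrizations, varying `target_modules`, `modules_to_save`, and `gralora_k`. A minimal sketch of what one of these cases does outside the test harness, assuming the layer sizes of the "MLP" fixture and that `GraloraConfig` is importable from `peft` like the other configs in this file:

```python
import torch.nn as nn
from peft import GraloraConfig, get_peft_model  # import path assumed, mirroring the other configs


class MLP(nn.Module):
    # stand-in for the "MLP" fixture named in the test parameters (sizes are assumptions)
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))


# mirrors "Vanilla MLP 1 GraLoRA": default GraLoRA settings targeting lin0
config = GraloraConfig(target_modules="lin0")
model = get_peft_model(MLP(), config)
model.print_trainable_parameters()
```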
|
973 | 993 | {"n_frequency": 10, "target_modules": ["lin0"]}, |
974 | 994 | {"n_frequency": 10, "target_modules": ["lin1"]}, |
975 | 995 | ), |
| 996 | + ( |
| 997 | + "GraLoRA Same", |
| 998 | + "gralora", |
| 999 | + GraloraConfig, |
| 1000 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1001 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1002 | + ), |
| 1003 | + ( |
| 1004 | + "GraLoRA Different", |
| 1005 | + "gralora", |
| 1006 | + GraloraConfig, |
| 1007 | + {"target_modules": ["lin0"], "init_weights": False}, |
| 1008 | + {"target_modules": ["lin1"], "init_weights": False}, |
| 1009 | + ), |
976 | 1010 | ( |
977 | 1011 | "SHiRA Same", |
978 | 1012 | "shira", |
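
The paired "GraLoRA Same" / "GraLoRA Different" entries feed two configs into the multi-adapter parametrization, targeting either the same module or different ones. A hedged sketch of that setup (the exact test driving these pairs is not shown in this hunk; `add_adapter` and `set_adapter` are the standard PEFT calls):

```python
from peft import GraloraConfig, get_peft_model

base = MLP()  # toy module from the sketch above
model = get_peft_model(
    base, GraloraConfig(target_modules=["lin0"], init_weights=False), adapter_name="adapter0"
)
# "GraLoRA Different": the second adapter targets lin1 instead of lin0
model.add_adapter("adapter1", GraloraConfig(target_modules=["lin1"], init_weights=False))
model.set_adapter("adapter1")
```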
|
1159 | 1193 | VeraConfig: "vera_lambda_", |
1160 | 1194 | RandLoraConfig: "randlora_", |
1161 | 1195 | FourierFTConfig: "fourierft_", |
| 1196 | + GraloraConfig: "gralora_", |
1162 | 1197 | C3AConfig: "c3a_", |
1163 | 1198 | HRAConfig: "hra_", |
1164 | 1199 | ShiraConfig: "shira_", |
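
The new mapping entry registers `"gralora_"` as the parameter-name prefix the tests use to locate a tuner's own weights. Roughly, outside the tests (only the prefix is given in this diff; the concrete GraLoRA parameter names are not):

```python
# continuing the sketch above: collect the injected GraLoRA parameters by prefix
gralora_param_names = [name for name, _ in model.named_parameters() if "gralora_" in name]
```
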
@@ -3104,12 +3139,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self): |
3104 | 3139 | cancelled_B = module.lora_B["cancelled"].weight.data |
3105 | 3140 |
|
3106 | 3141 | # The weights should be approximately zero (they cancel out) |
3107 | | - assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), ( |
3108 | | - f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}" |
3109 | | - ) |
3110 | | - assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), ( |
3111 | | - f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}" |
3112 | | - ) |
| 3142 | + assert torch.allclose( |
| 3143 | + cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5 |
| 3144 | + ), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}" |
| 3145 | + assert torch.allclose( |
| 3146 | + cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5 |
| 3147 | + ), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}" |
3113 | 3148 |
|
3114 | 3149 | def test_add_weighted_adapter_negative_weight_with_different_scaling(self): |
3115 | 3150 | # Test negative weights with different scaling factors (lora_alpha) |
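
The hunk above only restyles the assertion messages in `test_add_weighted_adapter_subtraction_with_negative_weights`; the behaviour under test is that combining two adapters with weights `+1` and `-1` via `add_weighted_adapter` cancels them, leaving near-zero A and B matrices in the merged "cancelled" adapter. A sketch of that call shape, assuming the toy MLP from above; the adapter names and `combination_type` are illustrative, and older PEFT versions may reject negative weights here (which is exactly what this test covers):

```python
import copy

from peft import LoraConfig, get_peft_model

cfg = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
model = get_peft_model(MLP(), cfg, adapter_name="adapter_a")  # MLP from the earlier sketch
model.add_adapter("adapter_b", copy.deepcopy(cfg))

# copy adapter_a's weights into adapter_b so the +1 / -1 combination cancels
lin0 = model.base_model.model.lin0
for mat in ("lora_A", "lora_B"):
    getattr(lin0, mat)["adapter_b"].weight.data.copy_(getattr(lin0, mat)["adapter_a"].weight.data)

model.add_weighted_adapter(
    adapters=["adapter_a", "adapter_b"],
    weights=[1.0, -1.0],
    adapter_name="cancelled",
    combination_type="linear",  # assumption: the test's actual combination type is not shown in this hunk
)
```
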
@@ -3515,9 +3550,9 @@ def test_multirank_2(self): |
3515 | 3550 | if isinstance(module, BaseTunerLayer): |
3516 | 3551 | rank_expected = rank_pattern.get(key, r) |
3517 | 3552 | rank_current = module.lora_A[adapter].weight.shape[0] |
3518 | | - assert rank_current == rank_expected, ( |
3519 | | - f"Rank {rank_current} is not equal to expected {rank_expected}" |
3520 | | - ) |
| 3553 | + assert ( |
| 3554 | + rank_current == rank_expected |
| 3555 | + ), f"Rank {rank_current} is not equal to expected {rank_expected}" |
3521 | 3556 |
|
3522 | 3557 |
|
3523 | 3558 | class TestLayerRepr: |
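
The final hunk likewise only rewraps an assertion in `test_multirank_2`; it checks that `rank_pattern` overrides the global rank `r` per matched module, reading the effective rank back from `lora_A[adapter].weight.shape[0]`. A minimal sketch of that mechanism, reusing the toy MLP from above (module names and ranks here are illustrative):

```python
from peft import LoraConfig, get_peft_model

config = LoraConfig(target_modules=["lin0", "lin1"], r=8, rank_pattern={"lin0": 4})
model = get_peft_model(MLP(), config)  # MLP from the earlier sketch
lin0 = model.base_model.model.lin0
lin1 = model.base_model.model.lin1
assert lin0.lora_A["default"].weight.shape[0] == 4  # overridden by rank_pattern
assert lin1.lora_A["default"].weight.shape[0] == 8  # falls back to the global r
```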
|