From d0a167a78df970727640aeeb2a89002ac9a50a7b Mon Sep 17 00:00:00 2001
From: Ppaddington <3475635952@qq.com>
Date: Thu, 19 Mar 2026 11:11:55 +0800
Subject: [PATCH] feat: integrate SoLA for LLM SVD decomposition
---
.../compress/LLM_Decomposition/README.md | 30 +
.../LLM_Decomposition/requirements.txt | 355 ++++++++
.../LLM_Decomposition/sola/hook_sola.py | 380 ++++++++
.../compress/LLM_Decomposition/sola/optim.py | 183 ++++
.../LLM_Decomposition/sola/playground_sola.py | 324 +++++++
.../LLM_Decomposition/sola/playground_sola.sh | 3 +
.../LLM_Decomposition/sola/training_sola.py | 160 ++++
.../compress/LLM_Decomposition/sola/utils.py | 811 ++++++++++++++++++
.../LLM_Decomposition/sola_neuron_idx/hook.py | 305 +++++++
.../sola_neuron_idx/playground.py | 155 ++++
.../sola_neuron_idx/playground.sh | 3 +
.../sola_neuron_idx/training.py | 29 +
.../sola_neuron_idx/utils.py | 390 +++++++++
13 files changed, 3128 insertions(+)
create mode 100644 flagscale/compress/LLM_Decomposition/README.md
create mode 100644 flagscale/compress/LLM_Decomposition/requirements.txt
create mode 100644 flagscale/compress/LLM_Decomposition/sola/hook_sola.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola/optim.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola/playground_sola.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola/playground_sola.sh
create mode 100644 flagscale/compress/LLM_Decomposition/sola/training_sola.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola/utils.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/hook.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.sh
create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/training.py
create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/utils.py
diff --git a/flagscale/compress/LLM_Decomposition/README.md b/flagscale/compress/LLM_Decomposition/README.md
new file mode 100644
index 0000000000..51fec25453
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/README.md
@@ -0,0 +1,30 @@
+
+
SoLA
+
+[AAAI 2025] SoLA: Leveraging Soft Activation Sparsity and Low-Rank Decomposition for Large Language Model Compression
+
+
+
+
+
+
+
+## Usage
+
+### Construct Calibration Data
+    Run get_wikitext2() in sola_neuron_idx/utils.py to produce data/wiki_256_4096.json
+
+### Compute Neurons Norm
+ bash sola_neuron_idx/playground.sh --> data/hitter_dict_15_256_4096_wiki_13b.json
+
+### Low-Rank Decomposition and Evaluation
+ bash sola/playground_sola.sh
+
+ Llama-2-7B/13B low-rank decomposition and evaluation
+
+    Remember to modify the file paths to match your own setup.
+
+### Others
+    llama_self_attn() in sola_neuron_idx/hook.py relies on past_key_value.update_get(), which is added in transformers-4.37.1/src/transformers/cache_utils.py
+
+ You can install modified transformers package: cd transformers-4.37.1; pip install -e .
diff --git a/flagscale/compress/LLM_Decomposition/requirements.txt b/flagscale/compress/LLM_Decomposition/requirements.txt
new file mode 100644
index 0000000000..83f9e70e41
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/requirements.txt
@@ -0,0 +1,355 @@
+about-time==4.2.1
+absl-py==2.1.0
+accelerate==1.6.0
+addict==2.4.0
+aiofiles==23.2.1
+aiohttp==3.9.5
+aiosignal==1.3.1
+alive-progress==3.2.0
+altair==5.5.0
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.4.0
+appdirs==1.4.4
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==3.0.0
+async-lru==2.0.4
+async-timeout==4.0.3
+attrs==23.2.0
+auto_gptq==0.7.1
+autograd==1.7.0
+babel==2.17.0
+beautifulsoup4==4.13.3
+bitsandbytes==0.43.2.dev0
+bleach==6.2.0
+boto3==1.35.90
+botocore==1.35.90
+certifi==2024.2.2
+cffi==1.17.1
+chardet==5.2.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==2.2.1
+cma==3.2.2
+colorama==0.4.6
+comm==0.2.2
+ConfigSpace==0.6.0
+contextlib2==21.6.0
+contourpy==1.2.1
+cpm-kernels==1.0.11
+cycler==0.12.1
+Cython==3.0.10
+dataclasses-json==0.6.7
+DataProperty==1.0.1
+datasets==3.2.0
+debugpy==1.8.12
+decorator==5.2.1
+deepeval==2.0
+defusedxml==0.7.1
+Deprecated==1.2.15
+dill==0.3.8
+distro==1.9.0
+dnspython==2.6.1
+docker==7.1.0
+docker-pycreds==0.4.0
+docstring_parser==0.16
+docx2txt==0.8
+einops==0.8.0
+email_validator==2.2.0
+emcee==3.1.6
+et_xmlfile==2.0.0
+eval_type_backport==0.2.2
+evaluate==0.4.6
+exceptiongroup==1.2.1
+execnet==2.1.1
+executing==2.2.0
+fastapi==0.111.1
+fastapi-cli==0.0.4
+fastjsonschema==2.21.1
+ffmpy==0.3.2
+filelock==3.14.0
+fire==0.7.1
+flash-attn==2.5.8
+fonttools==4.51.0
+fqdn==1.5.1
+frozenlist==1.4.1
+fschat==0.2.24
+fsspec==2024.3.1
+func-timeout==4.3.5
+fuzzywuzzy==0.18.0
+gekko==1.3.0
+gitdb==4.0.11
+GitPython==3.1.43
+google-pasta==0.2.0
+googleapis-common-protos==1.66.0
+gradio==3.50.2
+gradio_client==0.6.1
+grapheme==0.6.0
+greenlet==3.1.1
+grpcio==1.63.2
+h11==0.14.0
+h5py==3.13.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.2
+httpx-sse==0.4.0
+huggingface-hub==0.30.1
+-e git+https://github.com/openai/human-eval@6d43fb980f9fee3c892a914eda09951f772ad10d#egg=human_eval
+idna==3.7
+immutabledict==4.2.1
+importlib-metadata==6.11.0
+importlib_resources==6.4.0
+iniconfig==2.0.0
+inquirerpy==0.3.4
+ipykernel==6.29.5
+ipython==8.18.1
+isoduration==20.11.0
+jedi==0.19.2
+jieba==0.42.1
+Jinja2==3.1.3
+jiter==0.8.0
+jmespath==1.0.1
+joblib==1.4.2
+json5==0.10.0
+jsonlines==4.0.0
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter-events==0.12.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+jupyter_server==2.15.0
+jupyter_server_terminals==0.5.3
+jupyterlab==4.3.5
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
+kiwisolver==1.4.5
+langchain==0.3.9
+langchain-community==0.3.8
+langchain-core==0.3.21
+langchain-openai==0.2.10
+langchain-text-splitters==0.3.2
+langsmith==0.1.147
+latex2mathml==3.77.0
+Levenshtein==0.27.1
+lion-pytorch==0.1.4
+llm-blender @ git+https://github.com/yuchenlin/LLM-Blender.git@33204d2712944b6b17996f7c079e74cd963ccc7c
+-e git+https://github.com/EleutherAI/lm-evaluation-harness@9d4a04a0aac37fff88ca1a04d667012e1502c21d#egg=lm_eval
+lxml==5.2.2
+markdown-it-py==3.0.0
+markdown2==2.5.3
+MarkupSafe==2.1.5
+marshmallow==3.23.1
+matplotlib==3.8.4
+matplotlib-inline==0.1.7
+mbstrdecoder==1.1.3
+mdurl==0.1.2
+mistune==3.1.2
+ml-collections==0.1.1
+mmengine-lite==0.10.7
+mock==4.0.3
+more-itertools==10.3.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+mypy-extensions==1.0.0
+narwhals==1.34.1
+nbclient==0.10.2
+nbconvert==7.16.6
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.2.1
+nh3==0.2.21
+ninja==1.11.1.1
+nltk==3.8.1
+notebook==7.3.2
+notebook_shim==0.2.4
+numexpr==2.10.1
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.1.105
+omegaconf==2.2.3
+openai==1.55.3
+openbox==0.8.3
+OpenCC==1.1.9
+-e git+ssh://git@github.com/xinhaoH/KV_C.git@c899b84206943dbded60678a15425ce34b4bcb38#egg=opencompass&subdirectory=opencompass
+opencv-python-headless==4.11.0.86
+openpyxl==3.1.5
+opentelemetry-api==1.24.0
+opentelemetry-exporter-otlp-proto-common==1.24.0
+opentelemetry-exporter-otlp-proto-grpc==1.24.0
+opentelemetry-proto==1.24.0
+opentelemetry-sdk==1.24.0
+opentelemetry-semantic-conventions==0.45b0
+optimum==2.0.0
+orjson==3.10.6
+overrides==7.7.0
+packaging==24.0
+pandas==1.5.3
+pandocfilters==1.5.1
+parso==0.8.4
+pathos==0.3.3
+pathvalidate==3.2.0
+patsy==0.5.6
+# NOTE: machine-local editable install (peft==0.10.0) — replace the path below with your own checkout before use
+-e /home/xinhao/peft-0.10.0
+pexpect==4.9.0
+pfzy==0.3.4
+pillow==10.3.0
+platformdirs==4.2.2
+Platypus-Opt==1.0.4
+pluggy==1.5.0
+portalocker==2.10.1
+pox==0.3.5
+ppft==1.7.6.9
+prettytable==3.10.2
+prometheus_client==0.21.1
+prompt_toolkit==3.0.50
+protobuf==4.25.5
+psutil==5.9.8
+ptflops==0.7.3
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyaml==24.7.0
+pyarrow==16.0.0
+pyarrow-hotfix==0.6
+pybind11==2.13.1
+pycountry==24.6.1
+pycparser==2.22
+pydantic==1.10.21
+pydantic-settings==2.6.1
+pydantic_core==2.20.1
+pydub==0.25.1
+pyext==0.7
+Pygments==2.18.0
+pymoo==0.6.0
+pyparsing==3.1.2
+pysbd==0.3.4
+pytablewriter==1.2.0
+pytest==8.2.0
+pytest-repeat==0.9.3
+pytest-xdist==3.6.1
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-json-logger==3.2.1
+python-Levenshtein==0.27.1
+python-multipart==0.0.9
+pytz==2024.1
+PyYAML==6.0.1
+pyzmq==26.2.1
+ragas==0.2.6
+rank-bm25==0.2.2
+RapidFuzz==3.12.2
+referencing==0.35.1
+regex==2024.4.28
+requests==2.32.3
+requests-toolbelt==1.0.0
+retrying==1.3.4
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==13.7.1
+rouge==1.0.1
+rouge-chinese==1.0.3
+rouge-score==0.1.2
+rpds-py==0.22.3
+ruff==0.5.5
+s3fs==0.4.2
+s3transfer==0.10.4
+sacrebleu==2.4.2
+safetensors==0.4.3
+sagemaker==2.237.1
+sagemaker-core==1.0.17
+schema==0.7.7
+scikit-learn==1.5.0
+scikit-optimize==0.10.2
+scipy==1.10.1
+seaborn==0.13.2
+semantic-version==2.10.0
+Send2Trash==1.8.3
+sentence-transformers==4.0.1
+sentencepiece==0.2.0
+sentry-sdk==2.11.0
+setproctitle==1.3.3
+shellingham==1.5.4
+shortuuid==1.0.13
+shtab==1.7.1
+six==1.16.0
+smdebug-rulesconfig==1.0.1
+smmap==5.0.1
+sniffio==1.3.1
+sortedcontainers==2.4.0
+soupsieve==2.6
+SQLAlchemy==2.0.35
+sqlitedict==2.1.0
+stack-data==0.6.3
+starlette==0.37.2
+statsmodels==0.14.2
+svgwrite==1.4.3
+sympy==1.12
+syne_tune==0.13.0
+tabledata==1.3.3
+tabulate==0.9.0
+tblib==3.0.0
+tcolorpy==0.1.6
+tenacity==8.4.2
+termcolor==3.0.1
+terminado==0.18.1
+threadpoolctl==3.5.0
+tiktoken==0.8.0
+timeout-decorator==0.5.0
+tinycss2==1.4.0
+tokenizers==0.15.2
+tomli==2.0.1
+tomlkit==0.12.0
+torch==2.3.0
+tornado==6.4.2
+tqdm==4.66.4
+tqdm-multiprocess==0.0.11
+traitlets==5.14.3
+-e git+ssh://git@github.com/xinhaoH/SOLA_xinhao.git@0a46232d27a29d73794599c3530fb07157c60b0c#egg=transformers&subdirectory=transformers-4.37.1
+tree-sitter==0.21.3
+tree-sitter-languages==1.10.2
+triton==2.3.0
+trl @ git+https://github.com/huggingface/trl.git@eab175d434b9bb9badee20335c7945991a26dfac
+typeguard==4.4.2
+typepy==1.3.2
+typer==0.12.3
+types-python-dateutil==2.9.0.20241206
+typing-inspect==0.9.0
+typing_extensions==4.13.2
+tyro==0.9.18
+tzdata==2024.1
+ujson==5.10.0
+uri-template==1.3.0
+urllib3==1.26.20
+uvicorn==0.30.3
+uvloop==0.19.0
+wandb==0.17.5
+watchfiles==0.22.0
+wavedrom==2.0.3.post3
+wcwidth==0.2.13
+webcolors==24.11.1
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==11.0.3
+word2number==1.1
+wrapt==1.17.0
+xxhash==3.4.1
+yapf==0.43.0
+yarl==1.9.4
+zipp==3.18.1
+zstandard==0.23.0
diff --git a/flagscale/compress/LLM_Decomposition/sola/hook_sola.py b/flagscale/compress/LLM_Decomposition/sola/hook_sola.py
new file mode 100644
index 0000000000..692736e6fb
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/hook_sola.py
@@ -0,0 +1,380 @@
+import warnings
+import functools
+import math
+import logging
+import torch
+from torch import nn
+import torch.nn.functional as F
+from typing import Dict, List, Tuple, Optional
+
+from transformers import LlamaPreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
+
+from utils import (
+ HelperState,
+ HelperCollectState,
+ set_helper_state,
+ HELPER_SUPPORT_MODEL_LIST,
+ HELPER_SUPPORT_MODEL_TYPES
+)
+
+logger = logging.getLogger(__name__)
+
+_HELPER_HOOK_KEY = "HelperHook"
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+def add_training_hook_to_llama(model: LlamaPreTrainedModel,
+ dest: Dict[str, Dict[str, List[torch.Tensor]]],
+ intermediate_size: int,
+ hidden_size: int) -> int:
+ set_helper_state(model, HelperState.Collecting)
+ hooks = []
+ last_layer = 0
+
+ def forward_hook_get_XXT(layer_idx, name, module, inp, out):
+ inp = inp[0].detach().float()
+ if inp.shape[1] > 1:
+ adds = torch.matmul(inp.transpose(1,2), inp)
+ adds_sum = torch.sum(adds, dim=0).cpu()
+
+ raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}')
+ raw_scaling_diag_matrix += adds_sum
+
+ inp = adds = adds_sum = out = None
+ del inp, adds, adds_sum, out
+ torch.cuda.empty_cache()
+
+ for name, module in model.named_modules():
+ suffix = name.split(".")[-1]
+ if suffix not in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "o_proj", "v_proj"]:
+ continue
+ layer_idx = int(name.split(".")[-3])
+ if suffix == "down_proj":
+ setattr(module, f"raw_scaling_diag_matrix_{layer_idx}", torch.zeros(intermediate_size, intermediate_size))
+ else:
+ setattr(module, f"raw_scaling_diag_matrix_{layer_idx}", torch.zeros(hidden_size, hidden_size))
+ handle_pre_forward_collect_hook = module.register_forward_hook(
+ functools.partial(
+ forward_hook_get_XXT,
+ layer_idx,
+ name
+ )
+ )
+ hooks.append(handle_pre_forward_collect_hook)
+
+ setattr(model, _HELPER_HOOK_KEY, hooks)
+ return last_layer
+
+
+def add_training_hook(model: HELPER_SUPPORT_MODEL_TYPES,
+ dest: Dict[str, Dict[str, List[torch.Tensor]]],
+ intermediate_size: int,
+ hidden_size: int) -> int:
+ if isinstance(model, LlamaPreTrainedModel):
+ print("+++++++++ add Llama training hook +++++++++")
+ return add_training_hook_to_llama(model, dest, intermediate_size, hidden_size)
+ else:
+ raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")
+
+
+def remove_training_hook(model: HELPER_SUPPORT_MODEL_TYPES):
+ hooks = getattr(model, _HELPER_HOOK_KEY)
+ for handle in hooks:
+ handle.remove()
+
+ setattr(model, _HELPER_HOOK_KEY, None)
+
+
+def llama_mlp_forward_sola(module, inp, **kwargs):
+    if module.config.pretraining_tp > 1:  # NOTE(review): this guard makes the pretraining_tp > 1 branch below unreachable dead code
+        raise NotImplementedError
+
+ if module.config.pretraining_tp > 1:
+ slice = module.intermediate_size // module.config.pretraining_tp
+ gate_proj_slices = module.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = module.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = module.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat(
+ [F.linear(inp, gate_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1
+ )
+ up_proj = torch.cat([F.linear(inp, up_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (module.act_fn(gate_proj) * up_proj)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(module.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']:
+ down_proj = module.down_proj(module.act_fn(module.gate_proj(inp)) * module.up_proj(inp))
+ else:
+ # prime
+ h_gate_heavy = module.act_fn(torch.nn.functional.linear(inp, module.gate_weight_heavy))
+ h_up_heavy = torch.nn.functional.linear(inp, module.up_weight_heavy)
+ intermediate_states_heavy = h_gate_heavy * h_up_heavy
+ down_proj_heavy = torch.nn.functional.linear(intermediate_states_heavy, module.down_weight_heavy)
+
+ # marginal
+ if module.gate_proj_use > 0:
+ h_gate = module.act_fn(module.gate_proj(inp))
+ else:
+ if inp.device != module.gate_weight_U_top.device:
+ module.gate_weight_U_top = module.gate_weight_U_top.to(inp.device)
+ tmp = torch.nn.functional.linear(inp, module.gate_weight_U_top)
+ if tmp.device != module.gate_weight_SVh_top.device:
+ module.gate_weight_SVh_top = module.gate_weight_SVh_top.to(tmp.device)
+ h_gate = module.act_fn(torch.nn.functional.linear(tmp, module.gate_weight_SVh_top))
+
+ if module.up_proj_use > 0:
+ h_up = module.up_proj(inp)
+ else:
+ if inp.device != module.up_weight_U_top.device:
+ module.up_weight_U_top = module.up_weight_U_top.to(inp.device)
+ tmp = torch.nn.functional.linear(inp, module.up_weight_U_top)
+ if tmp.device != module.up_weight_SVh_top.device:
+ module.up_weight_SVh_top = module.up_weight_SVh_top.to(tmp.device)
+ h_up = torch.nn.functional.linear(tmp, module.up_weight_SVh_top)
+
+ if h_gate.device != h_up.device:
+ h_gate = h_gate.to(h_up.device)
+ intermediate_states = h_gate * h_up
+
+ if module.down_proj_use > 0:
+ down_proj = module.down_proj(intermediate_states)
+ else:
+ if intermediate_states.device != module.down_weight_U_top.device:
+ module.down_weight_U_top = module.down_weight_U_top.to(intermediate_states.device)
+ tmp = torch.nn.functional.linear(intermediate_states, module.down_weight_U_top)
+ if tmp.device != module.down_weight_SVh_top.device:
+ module.down_weight_SVh_top = module.down_weight_SVh_top.to(tmp.device)
+ down_proj_light = torch.nn.functional.linear(tmp, module.down_weight_SVh_top)
+
+ down_proj = down_proj_heavy + down_proj_light
+
+ return down_proj
+
+def llama_attn_forward_sola(module: torch.nn.Module,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs):
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if module.config.pretraining_tp > 1:
+ key_value_slicing = (module.num_key_value_heads * module.head_dim) // module.config.pretraining_tp
+ query_slices = module.q_proj.weight.split(
+ (module.num_heads * module.head_dim) // module.config.pretraining_tp, dim=0
+ )
+ key_slices = module.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = module.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(module.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(module.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(module.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ ##### Attn q/k decomposition #####
+ value_states = module.v_proj(hidden_states)
+ if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']:
+ query_states = module.q_proj(hidden_states)
+ key_states = module.k_proj(hidden_states)
+ else:
+ if module.q_proj_use > 0:
+ query_states = module.q_proj(hidden_states)
+ else:
+ if hidden_states.device != module.q_weight_U_top.device:
+ module.q_weight_U_top = module.q_weight_U_top.to(hidden_states.device)
+ tmp = torch.nn.functional.linear(hidden_states, module.q_weight_U_top)
+ if tmp.device != module.q_weight_SVh_top.device:
+ module.q_weight_SVh_top = module.q_weight_SVh_top.to(tmp.device)
+ query_states = torch.nn.functional.linear(tmp, module.q_weight_SVh_top)
+
+ if module.k_proj_use > 0:
+ key_states = module.k_proj(hidden_states)
+ else:
+ if hidden_states.device != module.k_weight_U_top.device:
+ module.k_weight_U_top = module.k_weight_U_top.to(hidden_states.device)
+ tmp = torch.nn.functional.linear(hidden_states, module.k_weight_U_top)
+ if tmp.device != module.k_weight_SVh_top.device:
+ module.k_weight_SVh_top = module.k_weight_SVh_top.to(tmp.device)
+ key_states = torch.nn.functional.linear(tmp, module.k_weight_SVh_top)
+
+ query_states = query_states.view(bsz, q_len, module.num_heads, module.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if module.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, module.layer_idx)
+
+ cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len)
+ if cos.device != position_ids.device:
+ position_ids = position_ids.to(cos.device)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, module.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, module.num_key_value_groups)
+ value_states = repeat_kv(value_states, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
+
+ if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ if attn_weights.device != attention_mask.device:
+ attn_weights = attn_weights.to(attention_mask.device)
+ attn_weights = attn_weights + attention_mask
+
+    # NOTE: softmax is computed in bfloat16 here (upstream transformers upcasts to fp32) — confirm this precision trade-off is intended
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.bfloat16).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
+ if attn_weights.device != value_states.device:
+ attn_weights = attn_weights.to(value_states.device)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, module.num_heads, q_len, module.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, module.hidden_size)
+
+ if module.config.pretraining_tp > 1:
+ attn_output = attn_output.split(module.hidden_size // module.config.pretraining_tp, dim=2)
+ o_proj_slices = module.o_proj.weight.split(module.hidden_size // module.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(module.config.pretraining_tp)])
+ else:
+ ##### Attn o decomposition #####
+ if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']:
+ attn_output = module.o_proj(attn_output)
+ else:
+            if module.o_proj_use > 0:  # NOTE(review): no 'else' before the U_top/SVh_top path below, unlike the q/k/gate/up/down branches — verify the low-rank projection is meant to also run on the o_proj output
+ attn_output = module.o_proj(attn_output)
+ if attn_output.device != module.o_weight_U_top.device:
+ module.o_weight_U_top = module.o_weight_U_top.to(attn_output.device)
+ tmp = torch.nn.functional.linear(attn_output, module.o_weight_U_top)
+ if tmp.device != module.o_weight_SVh_top.device:
+ module.o_weight_SVh_top = module.o_weight_SVh_top.to(tmp.device)
+ attn_output = torch.nn.functional.linear(tmp, module.o_weight_SVh_top)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def add_inference_hook_to_llama_sola(pruned_layer_idx_list,
+ model: HELPER_SUPPORT_MODEL_TYPES):
+ set_helper_state(model, HelperState.Inference)
+ hooks = []
+
+ for name, module in model.named_modules():
+ if not isinstance(module, (LlamaMLP, LlamaAttention)):
+ continue
+ layer_idx = int(name.split(".")[-2])
+
+ if isinstance(module, LlamaMLP):
+ module.forward = functools.partial(
+ llama_mlp_forward_sola,
+ module,
+ layer_idx=layer_idx,
+ pruned_layer_idx_list=pruned_layer_idx_list,
+ module_name=name
+ )
+ elif isinstance(module, LlamaAttention):
+ hitter_dict_cur_layer_q = None
+ hitter_dict_cur_layer_k = None
+
+ module.forward = functools.partial(
+ llama_attn_forward_sola,
+ module,
+ layer_idx=layer_idx,
+ pruned_layer_idx_list=pruned_layer_idx_list,
+ hitter_dict_q=hitter_dict_cur_layer_q,
+ hitter_dict_k=hitter_dict_cur_layer_k
+ )
+
+def add_inference_hook(pruned_layer_idx_list,
+ model: HELPER_SUPPORT_MODEL_TYPES):
+ if isinstance(model, LlamaPreTrainedModel):
+ add_inference_hook_to_llama_sola(pruned_layer_idx_list, model)
+ else:
+ raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")
+
diff --git a/flagscale/compress/LLM_Decomposition/sola/optim.py b/flagscale/compress/LLM_Decomposition/sola/optim.py
new file mode 100644
index 0000000000..8bd4e24fb4
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/optim.py
@@ -0,0 +1,183 @@
+import copy
+import heapq
+import math
+from typing import Dict, Tuple, List, Callable
+
+
+class Optimizer:
+ def __init__(self,
+ model_params: int,
+ target_rate: float,
+ granularity: int = 32,
+ start: float = 99,
+ end: float = 90):
+ self.granularity = granularity
+ self.start = start
+ self.end = end
+
+ self.concurrent_budget = 3
+
+ self.core = {}
+
+ # model information
+ self.model_params = model_params
+ self.target_rate = target_rate
+
+ # global state
+ self.perf_degrade_clamp = 0.1 # %, in percentage
+ self.perf_degrade_relaxation = 0.02 # %, in percentage
+
+ self.eff_score = 0
+ self.acc_score = 0
+
+ def add_item(self,
+ name: str,
+ cursor: int,
+ shape: Tuple[int, int],
+ perf_var: List[float],
+ grad: List[float]):
+ m, n = shape
+ self.core.setdefault(name, {})
+ self.core[name]["cur"] = cursor
+ self.core[name]["start_idx"] = cursor
+ self.core[name]["perf_var"] = perf_var
+ self.core[name]["shape"] = shape
+ self.core[name]["est_params"] = (m + n) * (cursor + 1) * self.granularity
+ self.core[name]["tot_params"] = m * n
+ self.core[name]["grad"] = grad
+ self.core[name]["accum_drop"] = 0
+ self.core[name]["is_excluded"] = False
+
+ def update_item(self, state: Dict, name: str, new_cursor: int):
+ m, n = state[name]["shape"]
+ start_idx = state[name]["start_idx"]
+
+ state[name]["cur"] = new_cursor
+ state[name]["est_params"] = (m + n) * (new_cursor + 1) * self.granularity
+ state[name]["accum_drop"] = state[name]["perf_var"][start_idx] - state[name]["perf_var"][new_cursor]
+
+ def _eval_state(self, state: Dict):
+ est_params = 0
+ tot_params = 0
+ tot_acc = 0
+ n_layers = len(state.keys())
+ for key, state_dict in state.items():
+ tot_params += state_dict["tot_params"]
+ est_params += state_dict["est_params"]
+ tot_acc += state_dict["perf_var"][state_dict["cur"]]
+
+ return tot_params, est_params, tot_acc / n_layers
+
+ def constringe(self, do_excluding: bool = True):
+ print(f"Target compression rate: {self.target_rate:.2f}%")
+ previous_state = copy.deepcopy(self.core)
+ epoch = 0
+ eff_scores = 0
+ while True:
+ loop = 0
+ while True:
+ current_state, has_modified = self._inner_loop(previous_state)
+ previous_state = current_state
+
+ tot_params, est_params, acc_scores = self._eval_state(current_state)
+ eff_scores = (1 - (tot_params - est_params) / self.model_params) * 100 # in percentage
+ print(f"{epoch=}, {loop=}, {tot_params=}, {est_params=}, {eff_scores=:.2f}%, {acc_scores=:.2f}")
+ if eff_scores <= self.target_rate or not has_modified:
+ break
+
+ # gradually relax the clamp
+ self.perf_degrade_clamp += self.perf_degrade_relaxation # in percentage
+
+ loop += 1
+
+ if not do_excluding:
+ break
+
+ # check if there's any layer needing to be excluded
+ any_excludes = False
+ for layer, layer_state_dict in current_state.items():
+ mat_shape = layer_state_dict["shape"]
+ desired_rank = (layer_state_dict["cur"] + 1) * self.granularity
+ max_rank = int(math.floor(math.prod(mat_shape) / sum(mat_shape) / self.granularity) - 5) * self.granularity
+ if desired_rank >= max_rank:
+ print(f"excluding {layer}")
+ any_excludes = True
+ self.core[layer]["is_excluded"] = True
+
+ if not any_excludes:
+ break
+
+ # redo after excluding
+ previous_state = copy.deepcopy(self.core)
+ for layer, layer_state_dict in self.core.items():
+ if layer_state_dict["is_excluded"]:
+ previous_state.pop(layer)
+
+ epoch += 1
+
+ print(f"Optimization finished.")
+ if eff_scores > self.target_rate:
+ print(f"Failed to reach the target goal {self.target_rate} vs. {eff_scores}")
+ return current_state
+
+ def _inner_loop(self, previous_state: Dict):
+ current_state = copy.deepcopy(previous_state)
+ all_layers = previous_state.keys()
+
+ loop = 0
+ has_modified = False
+ while True:
+ # Iterative update
+ candidates = []
+ for layer in all_layers:
+ cur = current_state[layer]["cur"]
+
+ perf_drop = current_state[layer]["grad"][cur]
+ moved_cur = cur - 1
+
+ # skip if it exceeds the supremum
+ if perf_drop > self.perf_degrade_clamp:
+ continue
+
+ if current_state[layer]["perf_var"][cur] < self.end:
+ continue
+
+ if cur == 0:
+ continue
+
+ mat_shape = current_state[layer]["shape"]
+ assert len(mat_shape) == 2
+ gain = perf_drop / math.prod(mat_shape)
+
+ # collect the current descending performance
+ heapq.heappush(candidates, (gain, layer, cur, moved_cur))
+
+ if len(candidates) == 0:
+ return current_state, has_modified
+
+ for _ in range(self.concurrent_budget):
+ if len(candidates) == 0:
+ break
+ perf_drop, key, cur, moved_cur = heapq.heappop(candidates)
+ self.update_item(current_state, key, moved_cur)
+ has_modified = True
+
+ if has_modified:
+ tot_params, est_params, acc_scores = self._eval_state(current_state)
+ eff_scores = (1 - (tot_params - est_params) / self.model_params) * 100 # in percentage
+ if eff_scores <= self.target_rate:
+ return current_state, has_modified
+
+ loop += 1
+
+ def export(self):
+ deft_config = {}
+ for layer, layer_data in self.core.items():
+ cur = layer_data["cur"]
+ deft_config[layer] = {
+ "shape": tuple(layer_data["shape"]),
+ "desired_rank": (cur + 1) * self.granularity,
+ "perf_score": layer_data["perf_var"][cur],
+ "is_updated": False,
+ }
+ return deft_config
diff --git a/flagscale/compress/LLM_Decomposition/sola/playground_sola.py b/flagscale/compress/LLM_Decomposition/sola/playground_sola.py
new file mode 100644
index 0000000000..1e0c48ab79
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/playground_sola.py
@@ -0,0 +1,324 @@
+import os
+import random
+import click
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ LlamaTokenizer,
+ pipeline
+)
+import torch
+import numpy as np
+import pandas as pd
+
+import utils
+from sola.training_sola import Helper
+
+import json
+import time
+import logging
+logger = logging.getLogger(__name__)
+
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
+
+import lm_eval
+from lm_eval.models.huggingface import HFLM
+
def setup_seed(seed):
    """Seed every RNG in use (Python, NumPy, torch CPU and all CUDA devices)
    and force deterministic cuDNN kernels for reproducible evaluation."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
+
+
@click.command()
@click.option("-m", "--model", type=click.Path(file_okay=True), help="path to model file", default=None)
def cli(**kwargs):
    """End-to-end SoLA pipeline for a local Llama-2 checkpoint.

    Steps: load model/tokenizer, collect X X^T calibration statistics on
    wiki prompts via the Helper hooks, run whitened SVD on every
    attention/MLP projection, pick per-module ranks, replace the dense
    weights with low-rank factors, then evaluate (perplexity, lm-eval
    tasks, MMLU).

    NOTE(review): the calibration/dump paths under /data are hardcoded and
    partly 13b-specific — confirm them before running on other checkpoints.
    """
    args = utils.EasyDict(**kwargs)
    print(args)

    model_path = args.model
    max_memory = "80000MB"
    max_memory = {i: max_memory for i in range(1)}

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir=None,
        device_map="auto",
        quantization_config=None,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    print("Model created")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        padding_side="right",
        use_fast=False,  # Fast tokenizer giving issues.
        tokenizer_type='llama' if 'ama' in model_path else None,  # Needed for HF name change
    )
    if tokenizer._pad_token is None:
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=utils.DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'ama' in model_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id and model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    print("Tokenizer created")

    generation_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
    print("Pipeline created")

    helper_params = {'intermediate_size': model.config.intermediate_size,
                     'hidden_size': model.config.hidden_size}
    helper = Helper(model, torch.bfloat16, **helper_params)
    print("Helper created")

    # Calibration data
    prompts_path = '/data/wiki_256_4096.json'
    with open(prompts_path, 'r') as file:
        prompts = json.load(file)

    # Compute XX^T: the Helper context installs forward hooks that accumulate
    # the activation Gram matrices while we generate over the prompts.
    t_start_time = time.time()
    with helper:
        for text in prompts:
            prompt_token_count = len(generation_pipeline.tokenizer.encode(text, return_tensors="pt")[0])
            generation_pipeline(text, max_length=int(prompt_token_count), pad_token_id=tokenizer.eos_token_id, truncation=True)
    t_end_time = time.time()
    t_duration = t_end_time - t_start_time
    print(f"Collect training data costs avg: {t_duration/len(prompts): .5f} s, all: {t_duration/60: .2f} min, {t_duration: .5f} s. ")
    print('Collect training data Done')

    # Low-Rank Decomposition
    # Fix: paths are derived from the model size instead of being hardcoded to
    # 13b, so the same files written here are read back below on either model.
    model_size = '13b' if '13b' in model_path else '7b'
    hot_ratio = 15
    dump_dest = f'/data/model_params/wiki/{model_size}/light'
    # Fix: the hitter-index JSON was re-read from disk for every module inside
    # the loop below; load it once up front.
    # NOTE(review): the hitter dict filename is 13b-specific — confirm the
    # expected file for 7b checkpoints.
    with open(f'/data/hitter_dict_{hot_ratio}_256_4096_wiki_13b.json', 'r') as json_file:
        hitter_dict = json.load(json_file)
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]
        if suffix not in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "v_proj", "o_proj"]:
            continue
        layer_idx = int(name.split(".")[-3])

        raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}').double().to(model.device)

        light_75p_hitter = torch.tensor(hitter_dict[f'{layer_idx}']['light_85p_neuron_idx']).to(model.device)
        if suffix == "down_proj":
            # down_proj consumes intermediate activations, so its Gram matrix
            # is restricted to the "light" neuron subset on both axes.
            raw_scaling_diag_matrix = raw_scaling_diag_matrix[light_75p_hitter, :][:, light_75p_hitter]

        try:
            # Whitening factor: Cholesky of the activation Gram matrix.
            scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(model.device)
        except Exception as e:
            print(name, "Warning: eigen scaling_diag_matrix is not positive!")
            if torch.isnan(raw_scaling_diag_matrix).any():
                print("Warning: scaling_diag_matrix contains NaN!")
            elif torch.isinf(raw_scaling_diag_matrix).any():
                print("Warning: scaling_diag_matrix contains Inf!")
            if not torch.equal(raw_scaling_diag_matrix, raw_scaling_diag_matrix.T):
                print("Warning: scaling_diag_matrix is not a symmetric matrix!")
            # Shift the spectrum just above zero to restore positive
            # definiteness, then retry the factorization.
            eigenvalues = torch.linalg.eigvalsh(raw_scaling_diag_matrix)
            raw_scaling_diag_matrix += (- eigenvalues[0] + 1e-3) * torch.eye(raw_scaling_diag_matrix.shape[0]).to(model.device)
            scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(model.device)

        try:
            scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
        except Exception as e:
            print(name, "Warning: scaling_diag_matrix is not full rank!")
            # Regularize the diagonal and invert again.
            scaling_diag_matrix += 1e-3 * torch.eye(scaling_diag_matrix.shape[0]).to(model.device)
            scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix).to(model.device)

        W = module.weight.float()

        # MLP beta weight matrix decomposition: keep only the light-neuron
        # rows (gate/up) or columns (down) before factorizing.
        if suffix in ["gate_proj", "up_proj", "down_proj"]:
            if light_75p_hitter.device != W.device:
                light_75p_hitter = light_75p_hitter.to(W.device)
            if suffix in ["gate_proj", "up_proj"]:
                W = W[light_75p_hitter, :].to(model.device)
            else:
                W = W[:, light_75p_hitter].to(model.device)

        if W.device != scaling_diag_matrix.device:
            scaling_diag_matrix = scaling_diag_matrix.to(W.device)
        W_scale = torch.matmul(W, scaling_diag_matrix)
        if layer_idx == 0:
            print('W scale shape: ', suffix, W_scale.shape)

        u, s, v = torch.linalg.svd(W_scale, full_matrices=False)  # The singular values are returned in descending order.
        if layer_idx == 0:
            print('decomposition: ', name, u.shape, s.shape, v.shape, W.shape, scaling_matrix_inv.shape)

        torch.save(u, os.path.join(dump_dest, f"{name}.u"))
        torch.save(s, os.path.join(dump_dest, f"{name}.s"))
        print(name, 'u s v saved.', dump_dest)

        if v.device != scaling_matrix_inv.device:
            v = v.to(scaling_matrix_inv.device)
        # Undo the whitening on the right factor so u @ v_inv approximates W.
        v_inv = v @ scaling_matrix_inv
        torch.save(v_inv, os.path.join(dump_dest, f"{name}.v_inv"))
        print(name, 'v_inv saved.', dump_dest)

    # Compute the rank of each component
    target_rate = 0.7
    if target_rate >= 0.8:
        target_modules = ["q_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
    else:
        target_modules = ["q_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # config = utils.get_rank_config(model, target_modules, target_rate, args.model, start=99.9, end=90)
    # Precomputed rank tables for the supported (model, target_rate) pairs.
    config = None
    if '13b' in model_path:  # Llama-2-13b
        if target_rate == 0.8:  # 20% compression rate
            config = {"q_proj": 1440, "k_proj": 1440,
                      "gate_proj": 2048, "up_proj": 2816, "down_proj": 2976}
        elif target_rate == 0.7:  # 30% compression rate
            config = {"q_proj": 960, "k_proj": 640, "o_proj": 1920,
                      "gate_proj": 2080, "up_proj": 2400, "down_proj": 2560}
        elif target_rate == 0.6:  # 40% compression rate
            config = {"q_proj": 480, "k_proj": 480, "o_proj": 1440,
                      "gate_proj": 1120, "up_proj": 2080, "down_proj": 2240}
        elif target_rate == 0.5:  # 50% compression rate
            config = {"q_proj": 320, "k_proj": 320, "o_proj": 768,
                      "gate_proj": 1280, "up_proj": 1280, "down_proj": 1440}
    else:  # Llama-2-7b
        if target_rate == 0.8:  # 20% compression rate
            config = {"q_proj": 1120, "k_proj": 800,
                      "gate_proj": 1760, "up_proj": 2400, "down_proj": 2240}
        elif target_rate == 0.7:  # 30% compression rate
            config = {"q_proj": 640, "k_proj": 640, "o_proj": 1440,
                      "gate_proj": 1760, "up_proj": 1920, "down_proj": 1760}
    if config is None:
        # Fix: previously fell through with `config` unbound and crashed later.
        raise ValueError(f"No precomputed rank config for target_rate={target_rate} on {model_path}")
    print(config)

    desired_ranks = {}
    layer_num = 40 if '13b' in model_path else 32
    for layer_idx in range(layer_num):
        # Fix: iterate the suffixes actually present in `config`; the old
        # hardcoded list raised KeyError when target_rate >= 0.8 (no o_proj).
        for suffix in config.keys():
            if f'{layer_idx}' not in desired_ranks.keys():
                desired_ranks[f'{layer_idx}'] = {suffix: (config[suffix], None)}
            else:
                desired_ranks[f'{layer_idx}'][suffix] = (config[suffix], None)

    # Keep the first two and last two layers dense (not pruned).
    if '13b' in model_path:
        pruned_layer_idx_list = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                                 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37]
    elif '7b' in model_path:
        pruned_layer_idx_list = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                                 26, 27, 28, 29]
    else:
        # Fix: previously left unbound for other checkpoints.
        raise ValueError(f"Unsupported model for layer pruning: {model_path}")
    print('pruned layer number: ', len(pruned_layer_idx_list))

    # Reduce memory through low rank decomposition
    if '13b' in model_path:
        active_params = model_params = 13343959040
    else:
        active_params = model_params = 7000842240
    # Fix: `save_dest` was never assigned on the 7b path, crashing at
    # os.listdir() below; both destinations now follow the model size.
    save_dest = f'/data/model_params/wiki/{model_size}/uv_fils_{hot_ratio}_30'
    # Remove stale factor files from a previous run.
    for filename in os.listdir(save_dest):
        if filename.split('.')[-1] not in ['wu', 'wv']:
            continue
        file_path = os.path.join(save_dest, filename)
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f'{file_path} deletion.')
    for name, module in model.named_modules():
        if not isinstance(module, (LlamaMLP, LlamaAttention)):
            continue
        layer_idx = int(name.split(".")[-2])
        if layer_idx not in pruned_layer_idx_list:
            continue

        suffix_list = ["q_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        for suffix in suffix_list:
            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.u"), map_location=torch.device('cuda'))
            s = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.s"), map_location=torch.device('cuda'))
            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.v_inv"), map_location=torch.device('cuda'))
            k = desired_ranks[f'{layer_idx}'][suffix][0]
            u, v = utils.get_uv(u, s, v, k)

            # Parameter accounting: dense size of the replaced weight matrix.
            if suffix in ["gate_proj", "up_proj", "down_proj"]:
                module_weight_numel = model.config.intermediate_size * model.config.hidden_size
                module_weight_numel = int(0.85 * module_weight_numel)  # light-neuron subset only
            elif suffix in ["q_proj", "k_proj"]:
                module_weight_numel = model.config.hidden_size * \
                    (model.config.hidden_size /
                     (model.config.num_attention_heads / model.config.num_key_value_heads))
            elif suffix in ["o_proj"]:
                module_weight_numel = model.config.hidden_size * model.config.hidden_size
            active_params -= module_weight_numel - v.numel() - u.numel()

            torch.save(u, os.path.join(save_dest, f"{name}.{suffix}.wu"))
            torch.save(v, os.path.join(save_dest, f"{name}.{suffix}.wv"))
            print(f"{name}.{suffix} {k} wu wv saved.")

            u = s = v = None
            del u, s, v
            utils.clear_torch_cache()
    print(f"Estimated compression rate: {1 - active_params/model_params:.4f}")

    # Swap the low-rank factors into the model in place of the dense weights.
    helper.apply_sola_to_model(pruned_layer_idx_list, desired_ranks, hot_ratio, save_dest, model)
    utils.clear_torch_cache()
    print(f'Appling Done')

    model.eval()
    setup_seed(42)

    # Evaluate perplexity
    ppl = utils.eval_ppl(model, tokenizer)
    print('ppl: ', ppl)

    # Evaluate lm eval accuracy
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
    task_names = ['piqa', 'hellaswag', 'boolq', 'winogrande', 'arc_easy', 'arc_challenge', 'openbookqa']
    results = lm_eval.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=8)[
        'results'
    ]
    print(results)
    metric_vals = {task: round(result.get(utils.TASK_METRIC_MAP[task]), 4) for task, result in results.items()}
    acc_avg = utils.calculate_avg_accuracy(task_names, results)
    metric_vals['average'] = round(acc_avg, 4)
    print(metric_vals)

    # Evaluate mmlu accuracy
    utils.eval_mmlu(model, tokenizer, 5, "data/mmlu-data")
    print('Eval MMLU Done \n')
+
+
if __name__ == "__main__":
    # Script entry point: click parses the CLI options and runs the pipeline.
    cli()
diff --git a/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh b/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh
new file mode 100644
index 0000000000..6dcc57af0c
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh
@@ -0,0 +1,3 @@
#!/bin/bash

# Run the SoLA playground pipeline on GPU 0 against a local Llama-2-13b
# checkpoint. Invoke from the LLM_Decomposition directory — presumably the
# `sola/` relative path is resolved from there (confirm).
CUDA_VISIBLE_DEVICES="0" python sola/playground_sola.py --model="/data/llama-2/Llama-2-13b-hf"
\ No newline at end of file
diff --git a/flagscale/compress/LLM_Decomposition/sola/training_sola.py b/flagscale/compress/LLM_Decomposition/sola/training_sola.py
new file mode 100644
index 0000000000..e08868bf72
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/training_sola.py
@@ -0,0 +1,160 @@
+import os
+import torch
+import contextlib
+from typing import Dict, List
+
+from utils import HELPER_SUPPORT_MODEL_LIST, HELPER_SUPPORT_MODEL_TYPES
+
+from sola.hook_sola import (
+ add_training_hook,
+ remove_training_hook,
+ add_inference_hook,
+)
+import utils
+
+
class Helper(contextlib.ContextDecorator):
    """Context manager that collects calibration statistics via forward hooks
    and later rewires the model's MLP/attention projections onto the
    precomputed low-rank SoLA factors.

    Usage: enter the context while running calibration prompts (installs the
    training hooks), then call :meth:`apply_sola_to_model` with the dumped
    factor files to replace the dense weights.
    """

    def __init__(self, model: HELPER_SUPPORT_MODEL_TYPES, compute_type, **kwargs):
        # kwargs must supply `hidden_size` and `intermediate_size`
        # (taken from model.config by the caller).
        self.model = model
        self.device = model.device
        self.compute_type = compute_type
        self.hidden_size = kwargs["hidden_size"]
        self.intermediate_size = kwargs["intermediate_size"]
        # Accumulator the training hooks write into.
        self.training_data: Dict[str, Dict[str, List[torch.Tensor]]] = {}

        if not isinstance(model, HELPER_SUPPORT_MODEL_LIST):
            raise NotImplementedError("Unsupported model")

    def __enter__(self):
        # Install the calibration hooks; they populate self.training_data.
        self.model_last_layer = add_training_hook(self.model, self.training_data, self.intermediate_size, self.hidden_size)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always detach the hooks, even if calibration raised.
        remove_training_hook(self.model)

    def apply_sola_to_model(self, pruned_layer_idx_list, desired_rank_pref, hot_ratio, usv_dump_dest, model: HELPER_SUPPORT_MODEL_TYPES):
        """Replace dense projections of the pruned layers with SoLA buffers.

        Args:
            pruned_layer_idx_list: layer indices whose projections get replaced.
            desired_rank_pref: per-layer {suffix: (rank, score)} mapping; a
                suffix missing from a layer's entry means "keep it dense".
            hot_ratio: heavy-hitter percentage used in the hitter-dict filename.
            usv_dump_dest: directory holding the dumped `.wu`/`.wv` factors.
            model: the model to rewrite in place.
        """
        import json
        # NOTE(review): this path is 13b-specific regardless of the model
        # passed in — confirm the expected file for 7b checkpoints.
        hitter_dict_path = f'/data/hitter_dict_{hot_ratio}_256_4096_wiki_13b.json'
        with open(hitter_dict_path, 'r') as json_file:
            hitter_dict = json.load(json_file)

        def infer_device() -> torch.device:
            # Pick the CUDA device with the most free memory, else CPU.
            if not torch.cuda.is_available():
                return torch.device("cpu")
            max_free_memory = -1
            best_device_index = -1
            for i in range(torch.cuda.device_count()):
                current_device = torch.device(f"cuda:{i}")
                torch.cuda.set_device(current_device)
                # memory_allocated() reports the just-selected current device.
                free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated()
                if free_memory > max_free_memory:
                    max_free_memory = free_memory
                    best_device_index = i
            if best_device_index == -1:
                return torch.device("cpu")
            else:
                return torch.device(f"cuda:{best_device_index}")

        def regis_mlp_attn_func(pruned_layer_idx_list, model, hitter_dict):
            # Register the heavy/light/low-rank buffers on every pruned
            # LlamaMLP / LlamaAttention module and drop the dense weights.
            from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
            for name, module in model.named_modules():
                if not isinstance(module, (LlamaMLP, LlamaAttention)):
                    continue
                layer_idx = int(name.split(".")[-2])
                if layer_idx not in pruned_layer_idx_list:
                    continue

                dump_dest = usv_dump_dest

                if isinstance(module, (LlamaMLP)):
                    # prime: keep the heavy-hitter neuron slices as dense
                    # bf16 buffers, then drop the original Linear modules.
                    hitter_dict_cur_layer = hitter_dict[f'{layer_idx}']
                    heavy_25p_hitter = torch.tensor(hitter_dict_cur_layer['heavy_15p_neuron_idx'])
                    gate_weight_heavy = module.gate_proj.weight[heavy_25p_hitter, :]
                    up_weight_heavy = module.up_proj.weight[heavy_25p_hitter, :]
                    down_weight_heavy = module.down_proj.weight[:, heavy_25p_hitter]
                    module.register_buffer('gate_weight_heavy', gate_weight_heavy.to(torch.bfloat16))
                    module.register_buffer('up_weight_heavy', up_weight_heavy.to(torch.bfloat16))
                    module.register_buffer('down_weight_heavy', down_weight_heavy.to(torch.bfloat16))
                    module.gate_proj = module.up_proj = module.down_proj = None
                    del module.gate_proj
                    del module.up_proj
                    del module.down_proj
                    utils.clear_torch_cache()

                    # marginal: light-neuron part, either kept dense (no rank
                    # configured) or replaced by the low-rank U/SVh factors.
                    suffix_list = ["gate_proj", "up_proj", "down_proj"]
                    for suffix in suffix_list:
                        if suffix not in desired_rank_pref[f'{layer_idx}'].keys():
                            # NOTE(review): torch.Tensor([True]) creates a FLOAT
                            # tensor holding 1.0, not a bool tensor — confirm the
                            # inference hook interprets it as intended.
                            module.register_buffer(f'{suffix}_use', torch.Tensor([True]))
                            print(f"{suffix} not in desired rank {desired_rank_pref[f'{layer_idx}'].keys()}.")
                            # NOTE(review): module.gate_proj/up_proj/down_proj were
                            # already set to None and deleted in the "prime" block
                            # above, so this branch would raise AttributeError if
                            # ever taken — it only works if every pruned layer has
                            # all three suffixes in desired_rank_pref. Confirm.
                            light_idx = torch.tensor(hitter_dict_cur_layer['light_85p_neuron_idx'])
                            if suffix == "gate_proj":
                                gate_weight_light = module.gate_proj.weight[light_idx, :]
                                module.register_buffer('gate_weight_light', gate_weight_light.to(torch.bfloat16))
                            elif suffix == "up_proj":
                                up_weight_light = module.up_proj.weight[light_idx, :]
                                module.register_buffer('up_weight_light', up_weight_light.to(torch.bfloat16))
                            else:
                                down_weight_light = module.down_proj.weight[:, light_idx]
                                module.register_buffer('down_weight_light', down_weight_light.to(torch.bfloat16))
                        else:
                            module.register_buffer(f'{suffix}_use', torch.Tensor([False]))
                            # Dumped factors: u was saved as the left factor and
                            # v as the (inverse-whitened) right factor; they are
                            # transposed on registration.
                            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wu"), map_location=torch.device(infer_device()))
                            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wv"), map_location=torch.device(infer_device()))
                            print('get u v: ', name, suffix, u.shape, v.shape, u.device, v.device)
                            if suffix == "gate_proj":
                                module.register_buffer('gate_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('gate_weight_SVh_top', u.t().to(torch.bfloat16))
                            elif suffix == "up_proj":
                                module.register_buffer('up_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('up_weight_SVh_top', u.t().to(torch.bfloat16))
                            else:
                                module.register_buffer('down_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('down_weight_SVh_top', u.t().to(torch.bfloat16))
                            u = s = v = None
                            del u, s, v
                            utils.clear_torch_cache()
                else:
                    # Attention: q/k/o are either left dense (flag True) or
                    # replaced by their low-rank factors; v_proj stays dense.
                    suffix_list = ["q_proj", "k_proj", "o_proj"]
                    for suffix in suffix_list:
                        if suffix not in desired_rank_pref[f'{layer_idx}'].keys():
                            print(f"{suffix} not in {desired_rank_pref[f'{layer_idx}'].keys()}.")
                            module.register_buffer(f'{suffix}_use', torch.Tensor([True]))
                        else:
                            module.register_buffer(f'{suffix}_use', torch.Tensor([False]))
                            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wu"), map_location=torch.device(infer_device()))
                            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wv"), map_location=torch.device(infer_device()))
                            print('attn get u v: ', name, suffix, u.shape, v.shape, u.device, v.device)

                            if suffix == "q_proj":
                                module.register_buffer('q_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('q_weight_SVh_top', u.t().to(torch.bfloat16))
                                module.q_proj = None
                                del module.q_proj
                                utils.clear_torch_cache()
                            elif suffix == "k_proj":
                                module.register_buffer('k_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('k_weight_SVh_top', u.t().to(torch.bfloat16))
                                module.k_proj = None
                                del module.k_proj
                                utils.clear_torch_cache()
                            elif suffix == "o_proj":
                                module.register_buffer('o_weight_U_top', v.t().to(torch.bfloat16))
                                module.register_buffer('o_weight_SVh_top', u.t().to(torch.bfloat16))
                                module.o_proj = None
                                del module.o_proj
                                utils.clear_torch_cache()
                            # elif suffix == "v_proj":
                            #     module.register_buffer('v_weight_U_top', v.t().to(torch.bfloat16))
                            #     module.register_buffer('v_weight_SVh_top', u.t().to(torch.bfloat16))
                            #     module.v_proj = None
                            #     del module.v_proj
                            #     utils.clear_torch_cache()
                            u = s = v = None
                            del u, s, v
                            utils.clear_torch_cache()

        regis_mlp_attn_func(pruned_layer_idx_list, model, hitter_dict)

        # Install the inference-time forward hooks that use the new buffers.
        add_inference_hook(pruned_layer_idx_list, model)
+
\ No newline at end of file
diff --git a/flagscale/compress/LLM_Decomposition/sola/utils.py b/flagscale/compress/LLM_Decomposition/sola/utils.py
new file mode 100644
index 0000000000..c77ece4dff
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola/utils.py
@@ -0,0 +1,811 @@
+import os
+import re
+import gc
+import time
+import json
+import pickle
+import torch
+import transformers
+import lm_eval
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from enum import Enum
+from datasets import load_dataset
+from typing import Any, Union, Dict, TypeVar, Generic, Iterable, List, Iterator
+from transformers import OPTPreTrainedModel, LlamaPreTrainedModel
+from optim import Optimizer
+import copy
+import itertools
+import bisect
+import warnings
+
+
T_co = TypeVar('T_co', covariant=True)


class Dataset(Generic[T_co]):
    """Abstract map-style dataset: concrete subclasses provide item lookup."""

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        """``dataset_a + dataset_b`` yields a ConcatDataset over both."""
        return ConcatDataset([self, other])

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
+
+
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    # NOTE(review): appears vendored verbatim from torch.utils.data.ConcatDataset
    # (including the misspelled compat property below) — consider importing the
    # torch implementation instead of keeping a local copy.
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. sizes [3, 5] -> [3, 8].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Support negative indices relative to the total length.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Binary-search which constituent dataset holds this global index,
        # then convert to an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Backwards-compat alias for the historically misspelled attribute.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
+
class IterableDataset(Dataset[T_co]):
    """Abstract stream-style dataset: concrete subclasses provide __iter__."""

    def __add__(self, other: Dataset[T_co]):
        """``stream_a + stream_b`` yields a ChainDataset streaming both in order."""
        return ChainDataset([self, other])

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
+
class ChainDataset(IterableDataset):
    r"""Dataset for chainning multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chainning operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    # NOTE(review): appears vendored verbatim from torch.utils.data.ChainDataset —
    # consider importing the torch implementation instead of a local copy.
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        # Stream each constituent dataset to exhaustion, in order.
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        # Sum of constituent lengths; each dataset must itself be sized.
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            # Cannot verify that all self.datasets are Sized
            total += len(d) # type: ignore
        return total
+
def get_test_data(name, tokenizer, seq_len=2048, batch_size=4):
    """Build a DataLoader of fixed-length token chunks from a test corpus.

    Args:
        name: dataset identifier; must contain 'wikitext2', 'ptb' or 'c4'.
        tokenizer: HF tokenizer used to encode the raw text.
        seq_len: length of each evaluation chunk (trailing remainder dropped).
        batch_size: DataLoader batch size.

    Returns:
        torch.utils.data.DataLoader yielding token chunks of shape (seq_len,).

    Raises:
        ValueError: if `name` matches none of the supported datasets
            (previously this fell through to an UnboundLocalError).
    """
    class IndexDataset(Dataset):
        # Thin map-style wrapper so a stacked tensor can feed a DataLoader.
        def __init__(self, tensors):
            self.tensors = tensors

        def __getitem__(self, index):
            return self.tensors[index]

        def __len__(self):
            return len(self.tensors)

    def process_data(samples, tokenizer, seq_len, field_name):
        # Join all documents, tokenize once, slice into seq_len-sized chunks.
        test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
        test_ids_batch = []
        nsamples = test_ids.numel() // seq_len

        for i in range(nsamples):
            batch = test_ids[(i * seq_len):((i + 1) * seq_len)]
            test_ids_batch.append(batch)
        test_ids_batch = torch.stack(test_ids_batch)
        return IndexDataset(tensors=test_ids_batch)

    # Fix: consistent elif chain (the original mixed bare `if`s) plus an
    # explicit error for unknown names.
    if 'wikitext2' in name:
        test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    elif 'ptb' in name:
        test_data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
        test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence')
    elif 'c4' in name:
        test_data = load_dataset("json", data_files="utils/c4-validation.json")['train']
        test_dataset = process_data(test_data[0:2000], tokenizer, seq_len, 'text')
    else:
        raise ValueError(f"Unsupported test dataset name: {name!r}")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader
+
@torch.no_grad()
def eff_eval(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"):
    """Measure generation throughput in tokens/second.

    The first ``3 * factor`` batches serve as warm-up: the timer is restarted
    when they finish and their tokens are excluded from the count. Generation
    itself runs for every batch, warm-up or not.
    """
    model.eval()
    token_num = 0
    # Smaller batches get a longer warm-up and measurement window.
    factor = 1 if batch_size > 32 else 3
    num_batches_to_fetch = 13 * factor
    test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size=batch_size)
    start_time = time.time()
    progress_bar = tqdm(enumerate(itertools.islice(test_loader, num_batches_to_fetch)))
    for batch_idx, batch_data in progress_bar:
        batch = batch_data.to(device)
        if batch_idx == 3 * factor:
            # Warm-up done: restart the clock.
            start_time = time.time()
        if batch_idx >= 3 * factor:
            # Count only post-warm-up generated tokens.
            token_num += batch.shape[0] * generated_len
            progress_bar.set_postfix_str(f"{token_num=}")
        generation_output = model.generate(
            input_ids=batch,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            use_cache=True,
            max_length=original_len+generated_len,
        )
    # Ensure all queued CUDA work is finished before stopping the timer.
    torch.cuda.synchronize()
    end_time = time.time()
    total_time = end_time - start_time
    throughput = token_num / total_time
    return throughput
+
+
def eff_eval_v2(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"):
    """Measure generation throughput (tokens/sec), two-pass variant.

    The first loop generates over batches 0-29 purely as GPU warm-up (it
    breaks at index 30); the second loop skips those 30 batches and times
    generation over the remainder of the first ``num_batches_to_fetch``.
    """
    model.eval()
    throughput = 0
    token_num = 0
    num_batches_to_fetch = 130
    test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size = batch_size)

    # Warm-up pass: untimed generation over the first 30 batches.
    for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)):
        batch = batch_data.to(device)
        if batch_idx >= 30:
            break
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(0)
        torch.cuda.synchronize()
        generation_output = model.generate(
            input_ids=batch,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            use_cache=True,
            top_k=50,
            max_length=original_len+generated_len,
            top_p=0.95,
            temperature=1,
        )
        torch.cuda.synchronize()

    # Reset memory/caches once more before the timed pass.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(0)
    torch.cuda.synchronize()

    start_time = time.time()
    # Timed pass: skip the 30 warm-up batches, count every generated token.
    for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)):
        if batch_idx < 30:
            continue
        batch = batch_data.to(device)
        token_num += batch.shape[0] * generated_len
        generation_output = model.generate(
            input_ids=batch,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            use_cache=True,
            top_k=50,
            max_length=original_len+generated_len,
            top_p=0.95,
            temperature=1,
        )
        torch.cuda.synchronize()
    end_time = time.time()
    # NOTE(review): despite its name, `throughput` holds the elapsed seconds;
    # the returned token_num / throughput is the actual tokens-per-second.
    throughput = end_time - start_time
    print("time: {}".format(end_time - start_time))
    print("Throughput: {} tokens/sec".format(token_num / throughput))
    return token_num / throughput
+
+
def clear_torch_cache() -> None:
    """Run the garbage collector, then release torch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()
+
def get_uv(u, s, v, k):
    """Fold the top-k singular values into a pair of rank-k factor matrices.

    Given a full SVD (u, s, v) with singular values in descending order,
    returns ``(wu, wv)`` where ``wu = (U_k sqrt(S_k))^T`` and
    ``wv = (sqrt(S_k) V_k)^T``, so that ``wu.T @ wv.T`` reconstructs the
    rank-k approximation ``U_k S_k V_k``.
    """
    top_u = u[:, :k]
    top_s = s[:k]
    top_v = v[:k, :]
    s_half = torch.diag(top_s.sqrt())
    # Factors may have been loaded onto different devices; co-locate them.
    if top_u.device != s_half.device:
        print('svd u s device: ', top_u.device, s_half.device)
        top_u = top_u.to(s_half.device)
    if s_half.device != top_v.device:
        print('svd s v device: ', s_half.device, top_v.device)
        top_v = top_v.to(s_half.device)
    clear_torch_cache()
    return (top_u @ s_half).T, (s_half @ top_v).T
+
def get_rank_config(model, target_modules, target_rate, model_path, start=99.9, end=90):
    """Derive the per-projection rank configuration from an optimizer run.

    Runs the rank-search optimizer, collects the optimized per-layer ranks,
    then condenses them into one rank per projection type by taking a
    fixed percentile from the top and rounding up to a multiple of 160.

    Args:
        model: model whose named modules are matched against the optimizer state.
        target_modules: initial projection-name list handed to the optimizer.
        target_rate: compression target in [0, 1] (e.g. 0.7 keeps ~70% params).
        model_path: checkpoint path; must contain '13b' or '7b'.
        start / end: performance-score window forwarded to the optimizer.

    Returns:
        dict mapping projection suffix -> chosen rank.

    Raises:
        ValueError: if `model_path` names neither a 13b nor a 7b model
            (previously `down_r` was silently left unbound).
    """
    # Bug fix: start/end were hardcoded to 99.9/90 in this call, silently
    # ignoring the caller's arguments.
    optimizer = constitute_mapping(model, target_modules, target_rate, model_path, start=start, end=end)
    optimized_state = optimizer.constringe()

    # At target_rate >= 0.8 the o_proj layers stay dense. (Hoisted out of the
    # loop below — the original recomputed this list on every module.)
    if target_rate >= 0.8:
        target_modules = ["q_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
    else:
        target_modules = ["q_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    rank = {}
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]
        if suffix not in target_modules or name not in optimized_state.keys():
            print(f"{name} skipped.")
            continue
        cur = optimized_state[name]["cur"]
        desired_rank = (cur + 1) * 32  # rank index scaled by granularity 32
        accum = optimized_state[name]["perf_var"][cur]  # perf score
        layer_idx = int(name.split(".")[-3])
        if layer_idx not in rank.keys():
            rank[layer_idx] = {suffix: (desired_rank, accum)}
        else:
            rank[layer_idx][suffix] = (desired_rank, accum)
        print(f"{name=}, {desired_rank=}, {accum:.2f}%.")

    # Collect the per-layer ranks of each projection type.
    proj_ranks = {s: [] for s in ["q_proj", "k_proj", "o_proj",
                                  "gate_proj", "up_proj", "down_proj"]}
    for key in rank.keys():
        for s in proj_ranks:
            if s in rank[key].keys():
                proj_ranks[s].append(rank[key][s][0])

    def _percentile_rank(values, frac):
        # Value `frac` of the way down from the top of the sorted ranks,
        # rounded up to the next multiple of 160.
        r = sorted(values)[-int(len(values) * frac)]
        return ((r + 159) // 160) * 160

    q_r = _percentile_rank(proj_ranks["q_proj"], 0.3)
    k_r = _percentile_rank(proj_ranks["k_proj"], 0.3)
    print('q:', q_r, int(len(proj_ranks["q_proj"]) * 0.3))
    print('k:', k_r, int(len(proj_ranks["k_proj"]) * 0.3))
    if target_rate < 0.8:
        o_r = _percentile_rank(proj_ranks["o_proj"], 0.55)
        print('o:', o_r, int(len(proj_ranks["o_proj"]) * 0.55))
    gate_frac = 0.3 if target_rate == 0.7 else 0.7
    gate_r = _percentile_rank(proj_ranks["gate_proj"], gate_frac)
    up_r = _percentile_rank(proj_ranks["up_proj"], 0.6)
    # Log fix: the original printed indices computed with a hardcoded 0.5 that
    # did not match the fractions actually used for selection.
    print('gate:', gate_r, int(len(proj_ranks["gate_proj"]) * gate_frac))
    print('up:', up_r, int(len(proj_ranks["up_proj"]) * 0.6))

    # down_proj rank from the closed-form parameter budget of a truncated SVD
    # over the light-neuron (85%) slice of the intermediate dimension.
    if '13b' in model_path:
        down_r = int((5120 * int(13824 * 0.85)) / (5120 + int(13824 * 0.85)) * target_rate)
    elif '7b' in model_path:
        down_r = int((4096 * int(11008 * 0.85)) / (4096 + int(11008 * 0.85)) * target_rate)
    else:
        # Bug fix: down_r was previously left unbound for unknown models.
        raise ValueError(f"Unsupported model for rank config: {model_path}")
    down_r = ((down_r + 159) // 160) * 160
    print('down:', down_r)

    if target_rate < 0.8:
        config = {"q_proj": q_r, "k_proj": k_r, "o_proj": o_r,
                  "gate_proj": gate_r, "up_proj": up_r, "down_proj": down_r}
    else:
        config = {"q_proj": q_r, "k_proj": k_r,
                  "gate_proj": gate_r, "up_proj": up_r, "down_proj": down_r}

    return config
+
def constitute_mapping(model, target_modules, target_rate, model_name,
                       granularity=32, start=99, end=90,
                       dump_dest="/home/xinhao/sep-peft/model_params/wiki/svd_llm/light",
                       dump_dest_attn="/home/xinhao/sep-peft/model_params/wiki/svd_llm/attn"):
    """Collect per-module SVD spectra and register them with an Optimizer.

    For every target linear module (outside ``skip_layers``) the singular
    values previously dumped to ``dump_dest`` / ``dump_dest_attn`` are loaded,
    the cumulative explained-variance curve is sampled every ``granularity``
    ranks, and the samples are handed to :class:`Optimizer` for rank
    allocation.

    Args:
        model: The (Llama) model whose named modules are inspected.
        target_modules: Module-name suffixes (e.g. ``"q_proj"``) to consider.
        target_rate: Target compression rate in [0, 1]; converted to percent.
        model_name: HF model identifier; must match a Llama-2 7b/13b config.
        granularity: Rank step used when sampling the variance curve.
        start: Explained-variance percentage marking the first candidate rank.
        end: Passed through to :class:`Optimizer`.
        dump_dest: Directory holding sigma files for MLP projections.
        dump_dest_attn: Directory holding sigma files for attention projections.

    Returns:
        The populated :class:`Optimizer` instance.

    Raises:
        ValueError: If ``model_name`` is not a supported Llama-2 checkpoint
            (previously this fell through to an obscure NameError).
    """
    if "Llama-2-13b-hf" in model_name:
        config = {
            "meta-llama/Llama-2-13b-hf": {
                "architectures": ["LlamaForCausalLM"],
                "hidden_size": 5120,
                "intermediate_size": 13824,
                "num_attention_heads": 40,
                "num_hidden_layers": 40,
                "vocab_size": 32000,
                "params": 13343959040,
                "lora": 125173760,
                "optim": 250347520,
                "grad": 125173760,
            }
        }
        skip_layers = [0, 1, 38, 39]  # keep first/last two decoder layers intact
    elif "Llama-2-7b-hf" in model_name:
        config = {
            "meta-llama/Llama-2-7b-hf": {
                "architectures": ["LlamaForCausalLM"],
                "hidden_size": 4096,
                "intermediate_size": 11008,
                "num_attention_heads": 32,
                "num_hidden_layers": 32,
                "params": 7000842240,
                "lora": 79953920,
                "optim": 159907840,
                "grad": 79953920,
            }
        }
        skip_layers = [0, 1, 29, 30]  # keep first/last two decoder layers intact
    else:
        raise ValueError(
            f"Unsupported model_name: {model_name!r}; expected a "
            "Llama-2-7b-hf or Llama-2-13b-hf checkpoint."
        )
    # NOTE(review): config is keyed by the canonical "meta-llama/..." id, so
    # model_name must be exactly that id (not a local path) — confirm callers.
    model_params = config[model_name]["params"]
    target_rate *= 100  # the Optimizer works in percentage space
    optimizer = Optimizer(model_params, target_rate, granularity, start, end)

    # Called during adding hook for updating
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]

        if suffix not in target_modules:
            print(f"[constitute-mapping] skipping {name} due to configuration")
            continue

        layer_idx = int(name.split(".")[-3])
        if layer_idx in skip_layers:
            print(f"[constitute-mapping] skipping {layer_idx} due to configuration")
            continue

        # o_proj is only treated as an attention target for aggressive (<80%)
        # compression rates.
        if target_rate >= 80:
            suffix_attn_list = ["q_proj", "k_proj"]
        else:
            suffix_attn_list = ["q_proj", "k_proj", "o_proj"]

        sigma_dir = dump_dest_attn if suffix in suffix_attn_list else dump_dest
        sigma_path = os.path.join(sigma_dir, f"{name}.s")
        if not os.path.exists(sigma_path):
            print(f"[constitute-mapping] skipping {name} since it cannot find sigma file")
            continue
        s = torch.load(sigma_path, map_location="cuda:0")

        m, n = module.weight.shape
        # Break-even rank: beyond this, the factorization stores more params
        # than the dense weight.
        max_trunc = int((m * n) / (m + n))
        sigma_square = s ** 2
        total = sigma_square.sum()

        perf_var = []
        grad = []
        start_idx = 10000000  # sentinel: first sample whose variance >= `start`
        for trunc in range(granularity, max_trunc, granularity):
            perf = (sigma_square[:trunc].sum() / total * 100).item()  # in percentage
            if perf >= start:
                start_idx = min(len(perf_var), start_idx)
            idx = len(perf_var)
            perf_var.append(perf)
            grad.append(0 if idx == 0 else perf - perf_var[idx - 1])

        start_idx = min(start_idx, len(perf_var) - 1)

        # collect
        optimizer.add_item(name, start_idx, tuple([m, n]), perf_var, grad)

    return optimizer
+
+
class HelperState(Enum):
    """Lifecycle state of the helper attached to a model.

    The enum (and its label assignments) appeared twice verbatim in the
    original file; the duplicate definition has been removed.
    """
    KEY = 10000  # attribute name marker (see `label` below)

    Collecting = 0
    Inference = 1

    Invalid = 9999


HelperState.KEY.label = "HelperState"
HelperState.Collecting.label = "Helper-Data-Collection"  # hook forward() to collect data
HelperState.Inference.label = "Helper-Ready-Inference"  # with updated forward()
+
+
class HelperCollectState(Enum):
    """Sub-state used while the helper is collecting calibration data."""

    KEY = 10001

    Pre = 0
    Post = 1
    End = 2

    Invalid = 9999


# Human-readable labels attached to the members (Invalid stays unlabeled).
_COLLECT_LABELS = {
    HelperCollectState.KEY: "HelperCollectState",
    HelperCollectState.Pre: "HelperCollectState-Pre",
    HelperCollectState.Post: "HelperCollectState-Post",
    HelperCollectState.End: "HelperCollectState-End",
}
for _member, _text in _COLLECT_LABELS.items():
    _member.label = _text
+
def set_helper_state(model, state: HelperState) -> None:
    """Record the helper's current lifecycle state as an attribute on *model*."""
    state_attr = HelperState.KEY.label
    setattr(model, state_attr, state)
+
+
# Model classes the helper supports. Kept as a real one-element tuple so it is
# both iterable and directly usable with isinstance(); the original
# `(LlamaPreTrainedModel)` was just a parenthesized class, not a tuple.
HELPER_SUPPORT_MODEL_LIST = (LlamaPreTrainedModel,)
HELPER_SUPPORT_MODEL_TYPES = Union[LlamaPreTrainedModel]
+
+
# https://pypi.org/project/lm-eval/0.0.1/
# Maps an lm-eval task name to the metric key to read from its results dict.
# Length-normalized accuracy ("acc_norm") is used where lm-eval reports it;
# plain accuracy ("acc") otherwise.
TASK_METRIC_MAP = {
    "piqa": "acc_norm,none",
    "arc_challenge": "acc_norm,none",
    "arc_easy": "acc_norm,none",
    "hellaswag": "acc_norm,none",
    "winogrande": "acc,none",
    "boolq": "acc,none",
    'wsc': 'acc,none',
    "openbookqa": "acc_norm,none"
}
+
def calculate_avg_accuracy(task_names: list, results: dict) -> float:
    """Average the chosen metric over tasks, folding all mmlu-* tasks into one.

    Non-mmlu tasks each count once; all mmlu subtasks are combined into a
    single question-count-weighted accuracy that then counts as one task.

    Args:
        task_names: Names of the evaluated tasks (iterable of str).
        results: Mapping of task name -> lm-eval results dict.

    Returns:
        The averaged accuracy.
    """
    n_tasks = len(task_names)
    # NOTE(review): result.get() yields None when the metric key is missing,
    # which would make sum() raise — assumes every result carries its metric.
    # `lm_eval` is expected to be imported at module level.
    acc_cumul = sum(result.get(TASK_METRIC_MAP[task]) for task, result in results.items() if 'mmlu' not in task)

    # Number of test questions per mmlu subtask, used as weights below.
    questions_per_mmlu_task = {
        task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
        for task_name in task_names
        if 'mmlu' in task_name
    }

    if not questions_per_mmlu_task:
        return acc_cumul / n_tasks

    # Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
    acc_mmlu = sum(
        result.get(TASK_METRIC_MAP[task]) * questions_per_mmlu_task[task]
        for task, result in results.items()
        if 'mmlu' in task
    )
    acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())

    # All mmlu subtasks collapse into a single averaged entry in the mean.
    return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)
+
+
def easy_dump(obj, dest, label):
    """Pickle *obj* to ``<dest>/<label>.pkl``; dicts also get a ``.json`` copy."""
    pickle_path = os.path.join(dest, f"{label}.pkl")
    with open(pickle_path, "wb") as fh:
        pickle.dump(obj, fh)

    if not isinstance(obj, dict):
        return

    json_path = os.path.join(dest, f"{label}.json")
    with open(json_path, "w") as fh:
        fh.write(json.dumps(obj, indent=4))
+
def make_run_dir(outdir: Union[str, os.PathLike], desc: str) -> str:
    """Create and return the next numbered run directory ``<id:05d>-<desc>``.

    Existing subdirectories whose names start with digits determine the next
    id; numbering starts at 00000.
    """
    existing_ids = []
    if os.path.isdir(outdir):  # sanity check, but click.Path() should clear this one
        for entry in os.listdir(outdir):
            if not os.path.isdir(os.path.join(outdir, entry)):
                continue
            match = re.match(r'^\d+', entry)
            if match is not None:
                existing_ids.append(int(match.group()))
    next_id = max(existing_ids, default=-1) + 1  # start with 00000
    run_dir = os.path.join(outdir, f'{next_id:05d}-{desc}')
    os.makedirs(run_dir, exist_ok=False)  # make sure it doesn't already exist
    return run_dir
+
class EasyDict(dict):
    """Dict subclass whose items are also reachable via attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
+
DEFAULT_PAD_TOKEN = "[PAD]"  # pad token added when the tokenizer defines none
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # add_special_tokens returns how many tokens were actually new.
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        output_embeddings_data = model.get_output_embeddings().weight.data

        # Initialize the new rows with the mean of the pre-existing embeddings.
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        # In-place writes: slicing .data mutates the (already resized) weights.
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
+
# MMLU subject -> subcategory tag(s) (MMLU subject taxonomy).
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Coarse category -> list of subcategory tags it aggregates.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# Multiple-choice answer letters, in the column order of the MMLU CSVs.
choices = ["A", "B", "C", "D"]
+
+
def format_subject(subject):
    """Turn e.g. "college_biology" into " college biology" (note leading space)."""
    return " " + " ".join(subject.split("_"))
+
def format_example(df, idx, include_answer=True):
    """Render row *idx* of an MMLU dataframe as a prompt string.

    Column layout: question, one column per answer choice, then the label.
    """
    num_choices = df.shape[1] - 2
    parts = [df.iloc[idx, 0]]
    for choice_idx in range(num_choices):
        parts.append("\n{}. {}".format(choices[choice_idx], df.iloc[idx, choice_idx + 1]))
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(" {}\n\n".format(df.iloc[idx, num_choices + 1]))
    return "".join(parts)
+
+
def gen_prompt(train_df, subject, k=-1):
    """Build the k-shot MMLU header prompt for *subject* from train_df rows.

    ``k == -1`` means "use every row of train_df".
    """
    num_shots = train_df.shape[0] if k == -1 else k
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    shots = [format_example(train_df, row_idx) for row_idx in range(num_shots)]
    return header + "".join(shots)
+
+
@torch.no_grad()
def evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df):
    """Run k-shot MMLU evaluation for a single subject.

    Args:
        subject: MMLU subject name (used in the prompt header and logging).
        model: Causal LM producing next-token logits.
        tokenizer: Tokenizer matching the model.
        ntrain: Number of few-shot examples to start with.
        dev_df: Few-shot example dataframe.
        test_df: Test question dataframe.

    Returns:
        Tuple of (per-question correctness array, mean accuracy,
        per-question A/B/C/D probability array).
    """
    cors = []
    all_probs = []
    # NOTE(review): `answers` is computed but never used below.
    answers = choices[: test_df.shape[1] - 2]

    for i in tqdm(range(test_df.shape[0]), desc=subject):
        # get prompt and make sure it fits
        k = ntrain
        prompt_end = format_example(test_df, i, include_answer=False)
        train_prompt = gen_prompt(dev_df, subject, k)
        prompt = train_prompt + prompt_end

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        # Drop few-shot examples until the prompt fits the 2048-token budget.
        while input_ids.shape[-1] > 2048:
            k -= 1
            train_prompt = gen_prompt(dev_df, subject, k)
            prompt = train_prompt + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
                model.device
            )

        label = test_df.iloc[i, test_df.shape[1] - 1]

        # Only the logits of the final position (the answer token) matter.
        logits = model(input_ids=input_ids).logits[0, -1]

        # Softmax over the four answer-letter logits only.
        probs = (
            torch.nn.functional.softmax(
                torch.tensor(
                    [
                        logits[tokenizer("A").input_ids[-1]],
                        logits[tokenizer("B").input_ids[-1]],
                        logits[tokenizer("C").input_ids[-1]],
                        logits[tokenizer("D").input_ids[-1]],
                    ]
                ).float(),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]

        cor = pred == label
        cors.append(cor)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs
+
def eval_mmlu(model, tokenizer, ntrain, data_dir):
    """Evaluate the model on the full MMLU benchmark from local CSVs.

    Args:
        model: Causal LM under evaluation.
        tokenizer: Matching tokenizer.
        ntrain: Number of few-shot dev examples per subject.
        data_dir: Directory containing MMLU "dev" and "test" CSV folders.

    Returns:
        Dict with per-subcategory / per-category accuracies, the overall
        weighted accuracy, and the wall-clock cost in seconds.
    """
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None
        )[: ntrain]
        test_df = pd.read_csv(
            os.path.join(data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, probs = evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df)
        subcats = subcategories[subject]
        for subcat in subcats:
            subcat_cors[subcat].append(cors)
            # Category accumulation is nested on purpose: each subcategory tag
            # feeds every coarse category that contains it.
            for key in categories.keys():
                if subcat in categories[key]:
                    cat_cors[key].append(cors)
        all_cors.append(cors)

    results = {"subcategories": {}, "categories": {}}
    for subcat in subcat_cors:
        subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
        results["subcategories"][subcat] = subcat_acc
        print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat in cat_cors:
        cat_acc = np.mean(np.concatenate(cat_cors[cat]))
        results["categories"][cat] = cat_acc
        print("Average accuracy {:.3f} - {}".format(cat_acc, cat))
    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    end_time = time.time()
    results["cost_time"] = end_time - start_time

    return results
+
def eval_ppl(model, tokenizer):
    """Compute perplexity on the wikitext-2-raw test split.

    The concatenated text is scored in non-overlapping windows
    (stride == max_length); per-window mean NLLs are averaged and
    exponentiated.

    Returns:
        Perplexity as a scalar tensor.
    """
    model.eval()
    max_length = model.config.max_position_embeddings
    stride = max_length
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
    nlls = []
    encodings = tokenizer("\n\n".join(test), return_tensors="pt")
    seq_len = encodings.input_ids.size(1)
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Mask out (label -100) tokens already scored in a previous window.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl
diff --git a/flagscale/compress/LLM_Decomposition/sola_neuron_idx/hook.py b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/hook.py
new file mode 100644
index 0000000000..7dc4c34d83
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/hook.py
@@ -0,0 +1,305 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import LlamaPreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaDecoderLayer, LlamaAttention
+from transformers.cache_utils import Cache, DynamicCache
+
+import warnings
+import functools
+import math
+import logging
+logger = logging.getLogger(__name__)
+
+from typing import List, Dict, Tuple, Optional
+
+from utils import (
+ HelperState,
+ set_helper_state,
+ HELPER_SUPPORT_MODEL_LIST,
+ HELPER_SUPPORT_MODEL_TYPES
+)
+
+
+_HELPER_HOOK_KEY = "HelperHook"
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head *n_rep* times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_attention_heads, seqlen, head_dim).
    """
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
+
def rotate_half(x):
    """Rotate half of the last dimension: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table of the rotary embedding.
        sin: Sine table of the rotary embedding.
        position_ids: Position indices of the tokens (supports KV-cache
            offsets).
        unsqueeze_dim: Dimension along which ``cos[position_ids]`` and
            ``sin[position_ids]`` are unsqueezed so they broadcast against
            q/k — use 1 for (batch, heads, seq, dim) layouts, 2 for
            (batch, seq, heads, dim).

    Returns:
        Tuple of the rotated query and key tensors.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = q * cos_sel + rotate_half(q) * sin_sel
    k_embed = k * cos_sel + rotate_half(k) * sin_sel
    return q_embed, k_embed
+
def llama_mlp(module: torch.nn.Module, x: torch.Tensor):
    """Recompute a LlamaMLP forward pass, also exposing intermediate tensors.

    Args:
        module: A ``LlamaMLP``-like module with gate/up/down projections,
            ``act_fn``, ``intermediate_size`` and ``config.pretraining_tp``.
        x: Input hidden states.

    Returns:
        Tuple ``(down_proj, intermediate_states, h_gate, h_up)`` where
        ``intermediate_states = act_fn(gate(x)) * up(x)`` is the tensor fed
        into the down projection.
    """
    if module.config.pretraining_tp > 1:
        slice_size = module.intermediate_size // module.config.pretraining_tp
        gate_proj_slices = module.gate_proj.weight.split(slice_size, dim=0)
        up_proj_slices = module.up_proj.weight.split(slice_size, dim=0)
        down_proj_slices = module.down_proj.weight.split(slice_size, dim=1)

        gate_proj = torch.cat(
            [F.linear(x, gate_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1
        )
        up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1)

        # Fix: define the same intermediates as the tp == 1 branch. Previously
        # h_gate / h_up were never bound here, so the return statement below
        # raised NameError, and intermediate_states was a tuple of splits
        # rather than the full tensor.
        h_gate = module.act_fn(gate_proj)
        h_up = up_proj
        intermediate_states = h_gate * h_up
        split_states = intermediate_states.split(slice_size, dim=2)
        down_proj = [
            F.linear(split_states[i], down_proj_slices[i]) for i in range(module.config.pretraining_tp)
        ]
        down_proj = sum(down_proj)
    else:
        h_gate = module.act_fn(module.gate_proj(x))
        h_up = module.up_proj(x)
        intermediate_states = h_gate * h_up
        down_proj = module.down_proj(intermediate_states)

    return down_proj, intermediate_states, h_gate, h_up
+
+
def llama_self_attn(module: torch.nn.Module,
                    hidden_states: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.LongTensor] = None,
                    past_key_value: Optional[Tuple[torch.Tensor]] = None,
                    output_attentions: Optional[bool] = False,
                    use_cache: Optional[bool] = False,
                    **kwargs
                    ):
    """Re-run a LlamaAttention forward pass while collecting statistics.

    Mirrors the eager-attention path of transformers' LlamaAttention but
    additionally returns per-head output norms and the L2 norms of the raw
    query/key projections (used for rank profiling).

    Returns:
        Tuple of (head_norms, attn_output, attn_weights, past_key_value,
        q_norm, k_norm).
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    # print("collect hidden states shape: ", {hidden_states.shape})
    bsz, q_len, _ = hidden_states.size()

    if module.config.pretraining_tp > 1:
        # Tensor-parallel path: apply each projection slice separately.
        key_value_slicing = (module.num_key_value_heads * module.head_dim) // module.config.pretraining_tp
        query_slices = module.q_proj.weight.split(
            (module.num_heads * module.head_dim) // module.config.pretraining_tp, dim=0
        )
        key_slices = module.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = module.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(module.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(module.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(module.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = module.q_proj(hidden_states)
        key_states = module.k_proj(hidden_states)
        value_states = module.v_proj(hidden_states)

    # L2 norms of the flat q/k projections over dim 0 (batch), captured
    # before the heads are split out.
    q_norm = torch.norm(query_states, p=2, dim=0)
    k_norm = torch.norm(key_states, p=2, dim=0)

    query_states = query_states.view(bsz, q_len, module.num_heads, module.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if module.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, module.layer_idx)
    cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len)
    if cos.device != position_ids.device:
        position_ids = position_ids.to(cos.device)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, module.layer_idx, cache_kwargs)
        # NOTE(review): `update_get` is not part of the standard transformers
        # Cache API — this assumes a patched cache class; confirm availability.
        key_states, value_states = past_key_value.update_get(key_states, value_states, module.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, module.num_key_value_groups)
    value_states = repeat_kv(value_states, module.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)

    if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        if attn_weights.device != attention_mask.device:
            attn_weights = attn_weights.to(attention_mask.device)
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    # NOTE(review): despite the comment above, softmax runs in bfloat16 here
    # (dtype=torch.bfloat16), not fp32 — confirm this is intentional.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.bfloat16).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
    if attn_weights.device != value_states.device:
        attn_weights = attn_weights.to(value_states.device)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, module.num_heads, q_len, module.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    # Per-head L2 norm of the attention output (before the output projection).
    head_norms = torch.norm(attn_output, dim = -1)

    attn_output = attn_output.reshape(bsz, q_len, module.hidden_size)

    if module.config.pretraining_tp > 1:
        attn_output = attn_output.split(module.hidden_size // module.config.pretraining_tp, dim=2)
        o_proj_slices = module.o_proj.weight.split(module.hidden_size // module.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(module.config.pretraining_tp)])
    else:
        attn_output = module.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return head_norms, attn_output, attn_weights, past_key_value, q_norm, k_norm
+
+
def pre_decoder_forward_hook_dejavu_collect(layer_idx: int,
                                            model_type: str,
                                            dest: Dict[str, Dict[str, List[torch.Tensor]]],
                                            module: torch.nn.Module,
                                            inp: Tuple,
                                            *args,
                                            **kwargs):
    """Forward *pre*-hook on a LlamaDecoderLayer that accumulates MLP neuron norms.

    Registered with ``with_kwargs=True``, so the hook is called as
    ``hook(module, args, kwargs)``: here ``inp`` is the positional-args tuple
    and ``args[0]`` is the keyword-args dict of the layer's forward call.

    It re-runs the layer's attention + MLP on the incoming hidden states
    purely to observe intermediate activations (the layer's real forward still
    executes afterwards), adding per-neuron L2 norms into the
    ``neurons_dict_<layer_idx>`` accumulator stored on the module.
    ``model_type`` and ``dest`` are currently unused in the body.
    """
    x = inp[0]

    residual = x
    hidden_states = module.input_layernorm(x)

    # Convert a legacy tuple cache into the DynamicCache API if needed.
    use_legacy_cache = not isinstance(args[0]["past_key_value"], Cache) and args[0]["past_key_value"] is not None
    if use_legacy_cache:
        args[0]["past_key_value"] = DynamicCache.from_legacy_cache(args[0]["past_key_value"])

    _, hidden_states, _, _, _, _ = llama_self_attn(module.self_attn,
                                                   hidden_states,
                                                   args[0]["attention_mask"],
                                                   args[0]["position_ids"],
                                                   args[0]["past_key_value"],
                                                   args[0]["output_attentions"],
                                                   args[0]["use_cache"])

    if hidden_states.device != residual.device:
        hidden_states = hidden_states.to(residual.device)
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = module.post_attention_layernorm(hidden_states)

    hidden_states, intermediate_states, _, _ = llama_mlp(module.mlp, hidden_states)

    ##### the output norm of different neurons #####
    # Only accumulate on multi-token (prefill) passes, skipping single-token
    # decoding steps.
    if hidden_states.shape[1] > 1:
        neurons_dict_cur_layer = getattr(module, f'neurons_dict_{layer_idx}')
        mlp_inter_norm = torch.norm(intermediate_states, p=2, dim=0)
        if mlp_inter_norm.device != neurons_dict_cur_layer['norm'].device:
            mlp_inter_norm = mlp_inter_norm.to(neurons_dict_cur_layer['norm'].device)
        neurons_dict_cur_layer['norm'] += mlp_inter_norm.sum(dim=0)
        neurons_dict_cur_layer['token_num'] += mlp_inter_norm.shape[0]
+
+
def add_training_hook_to_llama(model: LlamaPreTrainedModel,
                               dest: Dict[str, Dict[str, List[torch.Tensor]]],
                               intermediate_size: int) -> int:
    """Register neuron-norm collection pre-hooks on every Llama decoder layer.

    Each LlamaDecoderLayer gets a per-layer accumulator dict
    (``neurons_dict_<idx>``) and a forward pre-hook that fills it; the hook
    handles are stored on the model under ``_HELPER_HOOK_KEY`` so
    ``remove_training_hook`` can detach them later.

    Args:
        model: The Llama model to instrument.
        dest: Shared training-data container handed to each hook.
        intermediate_size: MLP intermediate width (size of the norm buffer).

    Returns:
        The highest decoder-layer index seen.
    """
    set_helper_state(model, HelperState.Collecting)
    hooks = []
    last_layer = 0

    for name, module in model.named_modules():
        if not isinstance(module, (LlamaDecoderLayer, LlamaAttention, LlamaMLP)):
            continue

        # The layer index sits one path component higher for attention/MLP
        # submodules than for the decoder layer itself.
        if isinstance(module, LlamaDecoderLayer):
            layer_idx = int(name.split(".")[-1])
        else:
            layer_idx = int(name.split(".")[-2])

        last_layer = max(layer_idx, last_layer)

        if isinstance(module, LlamaDecoderLayer):
            setattr(module, f"neurons_dict_{layer_idx}", {'norm': torch.zeros(intermediate_size).to(model.device), 'token_num': 0})
            handle_dejavu_collect = module.register_forward_pre_hook(
                functools.partial(
                    pre_decoder_forward_hook_dejavu_collect,
                    layer_idx,
                    "llama",
                    dest,
                ),
                with_kwargs=True
            )
            hooks.append(handle_dejavu_collect)

    setattr(model, _HELPER_HOOK_KEY, hooks)
    return last_layer
+
+
def add_training_hook(model: HELPER_SUPPORT_MODEL_TYPES,
                      dest: Dict[str, Dict[str, List[torch.Tensor]]],
                      intermediate_size: int) -> int:
    """Dispatch hook registration by model family; returns the last layer index."""
    if not isinstance(model, LlamaPreTrainedModel):
        raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")
    return add_training_hook_to_llama(model, dest, intermediate_size)
+
+
def remove_training_hook(model: HELPER_SUPPORT_MODEL_TYPES):
    """Detach every hook installed by add_training_hook and clear the registry."""
    for handle in getattr(model, _HELPER_HOOK_KEY):
        handle.remove()

    setattr(model, _HELPER_HOOK_KEY, None)
diff --git a/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.py b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.py
new file mode 100644
index 0000000000..0ce587d6c3
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.py
@@ -0,0 +1,155 @@
+import os
+import random
+import click
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ LlamaTokenizer,
+ pipeline
+)
+import torch
+import numpy as np
+import pandas as pd
+
+import utils
+from training import Helper
+
+import json
+import time
+import logging
+logger = logging.getLogger(__name__)
+
+
def setup_seed(seed):
    """Seed python, numpy and torch (CPU + all GPUs); force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
+
+
@click.command()
@click.option("-m", "--model", type=click.Path(file_okay=True), help="path to model file", default=None)
def cli(**kwargs):
    # Entry point: load a Llama checkpoint, run calibration prompts through it
    # with neuron-norm collection hooks installed, then dump per-layer neuron
    # norms and the prime/marginal neuron index split to disk.
    args = utils.EasyDict(**kwargs)
    print(args)

    model_path = args.model

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir=None,
        device_map="auto",
        quantization_config = None,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # the collection hooks rely on the eager path
    )
    print("Model created")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        padding_side="right",
        use_fast=False, # Fast tokenizer giving issues.
        tokenizer_type='llama' if 'ama' in model_path else None, # Needed for HF name change
    )
    # Ensure a pad token exists; adding one requires resizing the embeddings.
    if tokenizer._pad_token is None:
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=utils.DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'ama' in model_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id and model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    print("Tokenizer created")

    generation_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
    print("Pipeline created")

    setup_seed(42)

    # NOTE(review): intermediate_size stays undefined for any checkpoint whose
    # path contains neither '13b' nor '7b' — consider an explicit error.
    if '13b' in model_path:
        intermediate_size = 13824
    elif '7b' in model_path:
        intermediate_size = 11008
    helper_params = {'intermediate_size': intermediate_size}
    helper = Helper(model, torch.bfloat16, **helper_params)
    print("Helper created")

    # Construct calibration data
    # wiki_dataset = utils.get_wikitext2(256, 3, 4096, tokenizer, 'wiki')
    # with open('/data/wiki_256_4096.json', 'w') as json_file:
    #     json.dump(wiki_dataset, json_file, ensure_ascii=False, indent=4)
    with open('/data/wiki_256_4096.json', 'r') as file:
        prompts_all = json.load(file)

    # calibration prompt number
    prompt_num_list = [256]
    for prompt_num in prompt_num_list:
        prompts = prompts_all[:prompt_num]
        print('prompt number: ', len(prompts))

        # Compute neurons norm
        t_start_time = time.time()
        with helper:
            for text in prompts:
                # max_length == prompt length: no new tokens are generated,
                # the forward pass only feeds the collection hooks.
                prompt_token_count = len(generation_pipeline.tokenizer.encode(text, return_tensors="pt")[0])
                generation_pipeline(text, max_length=prompt_token_count, pad_token_id=tokenizer.eos_token_id, truncation=True)
        t_end_time = time.time()
        t_duration = t_end_time - t_start_time
        print(f"Collect training data costs avg: {t_duration/len(prompts): .5f} s, all: {t_duration/60: .2f} min, {t_duration: .5f} s. ")
        print('Collect training data Done')

        # Record neurons norm
        import csv
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        for name, module in model.named_modules():
            if not isinstance(module, LlamaDecoderLayer):
                continue
            layer_idx = int(name.split(".")[-1])

            neurons_dict_cur_layer = getattr(module, f'neurons_dict_{layer_idx}')
            # Mean per-neuron norm over all calibration tokens.
            average_values = neurons_dict_cur_layer['norm'] / neurons_dict_cur_layer['token_num']
            csv_data = [(k, average_values[k].item()) for k in range(average_values.shape[0])]
            # NOTE(review): the file name hardcodes "13b" even when a 7b model
            # is loaded, and the directory must already exist — confirm.
            csv_file = f'/data/mlp_neurons_norm/sample_{prompt_num}/13b_{layer_idx}.csv'

            with open(csv_file, mode='w', newline='') as file:
                writer = csv.writer(file)
                writer.writerows(csv_data)
            print(layer_idx, csv_file, 'write done')
        print(prompt_num, 'Neurons Norm Write Done')

        ###### Record prime/marginal neurons index ######
        dir_name = f'/data/mlp_neurons_norm/sample_{prompt_num}'
        checkpoint_path_list = os.listdir(dir_name)
        hitter_dict = {}
        for checkpoint_path in checkpoint_path_list:
            layer_idx = int(checkpoint_path.split('_')[-1].split('.')[0])
            data_df = pd.read_csv(os.path.join(dir_name, checkpoint_path), index_col=False, header=None)
            data_df = data_df.sort_values(by=1, ascending=False)
            num_top = int(len(data_df) * 0.15)  # prime neurons 15%
            heavy_neron_idx = data_df[0][:num_top].tolist()
            light_neron_idx = data_df[0][num_top:].tolist()
            hitter_dict[layer_idx] = {'heavy_15p_neuron_idx': heavy_neron_idx,
                                      'light_85p_neuron_idx': light_neron_idx}

        with open(f'{dir_name}/hitter_dict_15_{prompt_num}_4096_wiki_13b.json', 'w') as json_file:
            json.dump(hitter_dict, json_file, ensure_ascii=False, indent=4)
        print(prompt_num, 'Prime Dict Done')


if __name__ == "__main__":
    cli()
diff --git a/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.sh b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.sh
new file mode 100644
index 0000000000..28ed8df3bb
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.sh
@@ -0,0 +1,3 @@
#!/bin/bash
# Launch the neuron-index playground on a single GPU.
#
# Usage: playground.sh [MODEL_PATH]
#   MODEL_PATH  HF checkpoint directory (default: /data/llama-2/Llama-2-13b-hf)
# CUDA_VISIBLE_DEVICES may be pre-set in the environment to pick another GPU.

MODEL_PATH="${1:-/data/llama-2/Llama-2-13b-hf}"

CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" python sola_neuron_idx/playground.py --model="$MODEL_PATH"
diff --git a/flagscale/compress/LLM_Decomposition/sola_neuron_idx/training.py b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/training.py
new file mode 100644
index 0000000000..32945a552f
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/training.py
@@ -0,0 +1,29 @@
+import torch
+import contextlib
+from typing import Dict, List
+
+from utils import HELPER_SUPPORT_MODEL_LIST, HELPER_SUPPORT_MODEL_TYPES
+
+from sola_neuron_idx.hook import (
+ add_training_hook,
+ remove_training_hook
+)
+
+
class Helper(contextlib.ContextDecorator):
    """Context manager that installs neuron-recording hooks on *model*.

    On ``__enter__`` the hooks from ``sola_neuron_idx.hook`` are attached;
    they presumably populate ``self.training_data`` during forward passes
    (TODO confirm against add_training_hook). On ``__exit__`` the hooks are
    removed, restoring the model's plain forward pass.
    """

    def __init__(self, model: HELPER_SUPPORT_MODEL_TYPES, compute_type, **kwargs):
        # Validate the model family before storing any state, so a failed
        # construction leaves no half-initialized object behind.
        if not isinstance(model, HELPER_SUPPORT_MODEL_LIST):
            raise NotImplementedError("Unsupported model")

        self.model = model
        self.device = model.device
        self.compute_type = compute_type
        # Width of the MLP hidden layer whose neurons are tracked; required kwarg.
        self.intermediate_size = kwargs["intermediate_size"]
        # Filled in while the context is active.
        self.training_data: Dict[str, Dict[str, List[torch.Tensor]]] = {}

    def __enter__(self):
        self.model_last_layer = add_training_hook(self.model, self.training_data, self.intermediate_size)
        # Return self so `with Helper(...) as helper:` exposes the collected
        # data (the original returned None, making the `as` target useless).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        remove_training_hook(self.model)
+
\ No newline at end of file
diff --git a/flagscale/compress/LLM_Decomposition/sola_neuron_idx/utils.py b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/utils.py
new file mode 100644
index 0000000000..22f0217891
--- /dev/null
+++ b/flagscale/compress/LLM_Decomposition/sola_neuron_idx/utils.py
@@ -0,0 +1,390 @@
+import os
+import re
+import time
+import json
+import pickle
+import torch
+import transformers
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from enum import Enum
+from datasets import load_dataset
+from typing import Any, Union, Dict
+from transformers import OPTPreTrainedModel, LlamaPreTrainedModel
+
+
def get_wikitext2(nsamples, seed, seqlen, tokenizer, data_name, dataset_cache_dir=None):
    """Sample `nsamples` random text chunks of exactly `seqlen` tokens.

    Args:
        nsamples: number of chunks to return.
        seed: RNG seed for reproducible sampling positions.
        seqlen: target length in tokens of each chunk.
        tokenizer: HF tokenizer used to measure/truncate and re-decode chunks.
        data_name: 'wiki' for wikitext-2, anything else for the C4 validation shard.
        dataset_cache_dir: unused; kept for interface compatibility.

    Returns:
        List of `nsamples` decoded strings, each tokenizing to `seqlen` tokens.
    """
    import random
    random.seed(seed)

    if data_name == 'wiki':
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    else:
        data_files = {"train": "en/c4-validation.*.json.gz"}
        traindata = load_dataset("allenai/c4", data_files=data_files, split="train")
    tot_text = "\n\n".join(traindata["text"])

    traindataset = []
    while len(traindataset) < nsamples:
        i = random.randint(0, len(tot_text) - seqlen - 1)
        # Take a generous character window: ~10 chars per token over-covers
        # seqlen tokens for typical English text.
        trainenc = tokenizer(tot_text[i:i + seqlen * 10], return_tensors="pt")
        # Retry on a too-short window. (The original decremented the `for`
        # loop variable, which is a no-op in Python, so short draws silently
        # reduced the number of returned samples.)
        if trainenc.input_ids.shape[1] < seqlen:
            continue
        inp = trainenc.input_ids[:, :seqlen]
        traindataset.append(tokenizer.decode(inp[0].tolist(), skip_special_tokens=True))

    return traindataset
+
+
class HelperState(Enum):
    """Lifecycle state of a hooked model: collecting data vs. ready to infer."""

    KEY = 10000  # sentinel member whose label names the model attribute

    Collecting = 0
    Inference = 1

    Invalid = 9999


# Attach human-readable labels to the members (Invalid intentionally has none).
for _member, _text in (
    (HelperState.KEY, "HelperState"),
    (HelperState.Collecting, "Helper-Data-Collection"),  # hook forward() to collect data
    (HelperState.Inference, "Helper-Ready-Inference"),   # with updated forward()
):
    _member.label = _text
del _member, _text
+
+
class HelperCollectState(Enum):
    """Phase of a data-collection pass: before, after, or finished."""

    KEY = 10001  # sentinel member whose label names the model attribute

    Pre = 0
    Post = 1
    End = 2

    Invalid = 9999


# Attach human-readable labels to the members (Invalid intentionally has none).
for _member, _text in (
    (HelperCollectState.KEY, "HelperCollectState"),
    (HelperCollectState.Pre, "HelperCollectState-Pre"),
    (HelperCollectState.Post, "HelperCollectState-Post"),
    (HelperCollectState.End, "HelperCollectState-End"),
):
    _member.label = _text
del _member, _text
+
def set_helper_state(model, state: HelperState) -> None:
    """Stamp *state* onto *model* under the HelperState key attribute."""
    attr_name = HelperState.KEY.label
    setattr(model, attr_name, state)
+
+
# Model families the Helper supports. Written as a real one-element tuple:
# the original `(LlamaPreTrainedModel)` was just a parenthesized expression,
# which worked with isinstance() but contradicted the LIST naming and would
# break anything iterating over it. OPTPreTrainedModel is imported above but
# not yet wired in.
HELPER_SUPPORT_MODEL_LIST = (LlamaPreTrainedModel,)
HELPER_SUPPORT_MODEL_TYPES = Union[LlamaPreTrainedModel]
+
+
+
def easy_dump(obj, dest, label):
    """Pickle *obj* to <dest>/<label>.pkl; dicts also get a pretty JSON copy."""
    pkl_path = os.path.join(dest, f"{label}.pkl")
    with open(pkl_path, "wb") as fh:
        pickle.dump(obj, fh)

    # Non-dicts get only the pickle; JSON is a convenience for inspection.
    if not isinstance(obj, dict):
        return
    json_path = os.path.join(dest, f"{label}.json")
    with open(json_path, "w") as fh:
        fh.write(json.dumps(obj, indent=4))
+
def make_run_dir(outdir: Union[str, os.PathLike], desc: str) -> str:
    """Create and return <outdir>/<NNNNN>-<desc>.

    NNNNN is one past the highest numeric prefix among existing
    subdirectories of *outdir* (so the first run gets 00000).
    """
    existing_ids = []
    if os.path.isdir(outdir):  # sanity check, but click.Path() should clear this one
        for entry in os.listdir(outdir):
            if not os.path.isdir(os.path.join(outdir, entry)):
                continue
            match = re.match(r'^\d+', entry)
            if match is not None:
                existing_ids.append(int(match.group()))
    next_id = max(existing_ids, default=-1) + 1
    run_dir = os.path.join(outdir, f'{next_id:05d}-{desc}')
    os.makedirs(run_dir, exist_ok=False)  # fail loudly on a collision
    return run_dir
+
class EasyDict(dict):
    """dict subclass exposing keys as attributes: ``d.x`` reads ``d["x"]``."""

    def __getattr__(self, name: str) -> Any:
        # Only called for names not found through normal attribute lookup,
        # so dict methods and real attributes keep priority.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
+
# Default token to register when the tokenizer lacks a pad token.
DEFAULT_PAD_TOKEN = "[PAD]"
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Adds the given special tokens to the tokenizer, grows the model's
    input/output embedding matrices to match, and initializes each new row
    with the mean of the pre-existing embeddings. Mutates both arguments
    in place.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # add_special_tokens returns how many tokens were actually new.
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        output_embeddings_data = model.get_output_embeddings().weight.data

        # Mean of the original rows (everything but the freshly appended tail)
        # gives the new tokens a sensible starting point in embedding space.
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        # In-place write into the appended rows of both embedding matrices.
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
+
# MMLU subject -> subcategory mapping, following the grouping used by the
# official MMLU evaluation harness.
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Top-level report buckets -> the subcategories each one aggregates.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# Answer letters for MMLU's 4-way multiple-choice format.
choices = ["A", "B", "C", "D"]
+
+
def format_subject(subject):
    """Turn an underscore-joined subject name into a spaced phrase with a
    leading space, e.g. "college_biology" -> " college biology"."""
    return " " + subject.replace("_", " ")
+
def format_example(df, idx, include_answer=True):
    """Render row *idx* of *df* as one question block: question text, the
    lettered options, then "Answer:" (optionally followed by the gold letter
    and a blank line, as used for few-shot examples)."""
    num_options = df.shape[1] - 2  # first column = question, last column = answer
    parts = [df.iloc[idx, 0]]
    for opt in range(num_options):
        parts.append("\n{}. {}".format(choices[opt], df.iloc[idx, opt + 1]))
    parts.append("\nAnswer:")
    if include_answer:
        parts.append(" {}\n\n".format(df.iloc[idx, num_options + 1]))
    return "".join(parts)
+
+
def gen_prompt(train_df, subject, k=-1):
    """Build a k-shot prompt: instruction header plus the first k rows of
    *train_df* as worked examples (k == -1 means use every row)."""
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    num_shots = train_df.shape[0] if k == -1 else k
    return header + "".join(format_example(train_df, i) for i in range(num_shots))
+
+
@torch.no_grad()
def evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df):
    """Evaluate *model* on one MMLU subject with an ntrain-shot prompt.

    For every test row, builds a few-shot prompt from *dev_df*, reads the
    logits of the final position, and takes softmax over the four answer
    letters' token logits.

    Returns:
        cors: np.ndarray of per-question booleans (prediction == gold label).
        acc: float, mean accuracy over the subject.
        all_probs: np.ndarray (num_questions, 4) of answer probabilities.
    """
    cors = []
    all_probs = []
    # NOTE(review): `answers` is computed but never used below.
    answers = choices[: test_df.shape[1] - 2]

    for i in tqdm(range(test_df.shape[0]), desc=subject):
        # get prompt and make sure it fits
        k = ntrain
        prompt_end = format_example(test_df, i, include_answer=False)
        train_prompt = gen_prompt(dev_df, subject, k)
        prompt = train_prompt + prompt_end

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        # Drop few-shot examples one at a time until the prompt fits the
        # 2048-token budget.
        while input_ids.shape[-1] > 2048:
            k -= 1
            train_prompt = gen_prompt(dev_df, subject, k)
            prompt = train_prompt + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
                model.device
            )

        # Gold answer letter sits in the last column.
        label = test_df.iloc[i, test_df.shape[1] - 1]

        # Logits at the last position of the (single) sequence.
        logits = model(input_ids=input_ids).logits[0, -1]

        # Probability over the four answer letters, using the last token id
        # each letter tokenizes to (handles tokenizers that prepend BOS).
        probs = (
            torch.nn.functional.softmax(
                torch.tensor(
                    [
                        logits[tokenizer("A").input_ids[-1]],
                        logits[tokenizer("B").input_ids[-1]],
                        logits[tokenizer("C").input_ids[-1]],
                        logits[tokenizer("D").input_ids[-1]],
                    ]
                ).float(),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]

        cor = pred == label
        cors.append(cor)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs
+
def eval_mmlu(label, model, tokenizer, ntrain, data_dir, save_dir):
    """Run the full MMLU benchmark and dump per-subject / aggregate results.

    Args:
        label: run identifier used in output column names and file names.
        model, tokenizer: the model under test and its tokenizer.
        ntrain: number of few-shot dev examples per prompt.
        data_dir: MMLU root containing "dev/" and "test/" CSV folders.
        save_dir: parent directory; a fresh numbered run dir is created inside.

    Returns:
        dict with per-subcategory / per-category accuracies, the overall
        weighted accuracy, and the wall-clock cost in seconds.
    """
    mmlu_out_dest = make_run_dir(save_dir, "mmlu_eval")

    # One subject per "<subject>_test.csv" file in the test folder.
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        # Few-shot pool truncated to ntrain rows; full test split.
        dev_df = pd.read_csv(
            os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None
        )[: ntrain]
        test_df = pd.read_csv(
            os.path.join(data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, probs = evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df)
        # Route this subject's per-question results into its subcategory and
        # every category bucket that contains that subcategory.
        subcats = subcategories[subject]
        for subcat in subcats:
            subcat_cors[subcat].append(cors)
            for key in categories.keys():
                if subcat in categories[key]:
                    cat_cors[key].append(cors)
        all_cors.append(cors)

        # Persist per-question correctness and answer probabilities as CSV.
        test_df["{}_correct".format(label)] = cors
        for j in range(probs.shape[1]):
            choice = choices[j]
            test_df["{}_choice{}_probs".format(label, choice)] = probs[:, j]
        os.makedirs(os.path.join(mmlu_out_dest, "results_{}".format(label.split("/")[-1])), exist_ok=True)
        test_df.to_csv(
            os.path.join(
                mmlu_out_dest, "results_{}".format(label.split("/")[-1]), "{}.csv".format(subject)
            ),
            index=None,
        )

    # Aggregate: per-question results are concatenated so each accuracy is
    # weighted by question count, not averaged per subject.
    results = {"subcategories": {}, "categories": {}}
    for subcat in subcat_cors:
        subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
        results["subcategories"][subcat] = subcat_acc
        print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat in cat_cors:
        cat_acc = np.mean(np.concatenate(cat_cors[cat]))
        results["categories"][cat] = cat_acc
        print("Average accuracy {:.3f} - {}".format(cat_acc, cat))
    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    end_time = time.time()
    results["cost_time"] = end_time - start_time
    easy_dump(results, mmlu_out_dest, "accuracies_{}".format(label.replace("/", "_")))
    return results
+
def eval_ppl(model, tokenizer):
    """Compute perplexity on the wikitext-2 test split.

    Uses non-overlapping windows of the model's maximum context length
    (stride == max_length), masking already-scored positions with -100 so
    each token contributes exactly once to the loss.

    Returns:
        torch scalar tensor: exp of the mean per-window negative log-likelihood.
    """
    model.eval()
    max_length = model.config.max_position_embeddings
    stride = max_length
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
    nlls = []
    # Tokenize the whole corpus once as a single long sequence.
    encodings = tokenizer("\n\n".join(test), return_tensors="pt")
    seq_len = encodings.input_ids.size(1)
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Only the final trg_len positions are scored; earlier positions serve
        # purely as context (-100 is ignored by the loss).
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    # Perplexity = exp of the mean NLL across windows.
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl