From d0a167a78df970727640aeeb2a89002ac9a50a7b Mon Sep 17 00:00:00 2001 From: Ppaddington <3475635952@qq.com> Date: Thu, 19 Mar 2026 11:11:55 +0800 Subject: [PATCH] feat: integrate SoLA for LLM SVD decomposition --- .../compress/LLM_Decomposition/README.md | 30 + .../LLM_Decomposition/requirements.txt | 355 ++++++++ .../LLM_Decomposition/sola/hook_sola.py | 380 ++++++++ .../compress/LLM_Decomposition/sola/optim.py | 183 ++++ .../LLM_Decomposition/sola/playground_sola.py | 324 +++++++ .../LLM_Decomposition/sola/playground_sola.sh | 3 + .../LLM_Decomposition/sola/training_sola.py | 160 ++++ .../compress/LLM_Decomposition/sola/utils.py | 811 ++++++++++++++++++ .../LLM_Decomposition/sola_neuron_idx/hook.py | 305 +++++++ .../sola_neuron_idx/playground.py | 155 ++++ .../sola_neuron_idx/playground.sh | 3 + .../sola_neuron_idx/training.py | 29 + .../sola_neuron_idx/utils.py | 390 +++++++++ 13 files changed, 3128 insertions(+) create mode 100644 flagscale/compress/LLM_Decomposition/README.md create mode 100644 flagscale/compress/LLM_Decomposition/requirements.txt create mode 100644 flagscale/compress/LLM_Decomposition/sola/hook_sola.py create mode 100644 flagscale/compress/LLM_Decomposition/sola/optim.py create mode 100644 flagscale/compress/LLM_Decomposition/sola/playground_sola.py create mode 100644 flagscale/compress/LLM_Decomposition/sola/playground_sola.sh create mode 100644 flagscale/compress/LLM_Decomposition/sola/training_sola.py create mode 100644 flagscale/compress/LLM_Decomposition/sola/utils.py create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/hook.py create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.py create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/playground.sh create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/training.py create mode 100644 flagscale/compress/LLM_Decomposition/sola_neuron_idx/utils.py diff --git a/flagscale/compress/LLM_Decomposition/README.md 
b/flagscale/compress/LLM_Decomposition/README.md new file mode 100644 index 0000000000..51fec25453 --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/README.md @@ -0,0 +1,30 @@ +
+

SoLA

+

+[AAAI 2025] SoLA: Leveraging Soft Activation Sparsity and Low-Rank Decomposition for Large Language Model Compression +

+

+ +

+image +

+ +## Usage + +### Construct Calibration Data + sola_neuron_idx/utils.py get_wikitext2() --> data/wiki_256_4096.json + +### Compute Neurons Norm + bash sola_neuron_idx/playground.sh --> data/hitter_dict_15_256_4096_wiki_13b.json + +### Low-Rank Decomposition and Evaluation + bash sola/playground_sola.sh + + Llama-2-7B/13B low-rank decomposition and evaluation + + Remember modify file path to your own + +### Others + sola_neuron_idx/hook.py llama_self_attn() past_key_value.update_get() in transformers-4.37.1/src/transformers/cache_utils.py + + You can install modified transformers package: cd transformers-4.37.1; pip install -e . diff --git a/flagscale/compress/LLM_Decomposition/requirements.txt b/flagscale/compress/LLM_Decomposition/requirements.txt new file mode 100644 index 0000000000..83f9e70e41 --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/requirements.txt @@ -0,0 +1,355 @@ +about-time==4.2.1 +absl-py==2.1.0 +accelerate==1.6.0 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.9.5 +aiosignal==1.3.1 +alive-progress==3.2.0 +altair==5.5.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.4.0 +appdirs==1.4.4 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==3.0.0 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.2.0 +auto_gptq==0.7.1 +autograd==1.7.0 +babel==2.17.0 +beautifulsoup4==4.13.3 +bitsandbytes==0.43.2.dev0 +bleach==6.2.0 +boto3==1.35.90 +botocore==1.35.90 +certifi==2024.2.2 +cffi==1.17.1 +chardet==5.2.0 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==2.2.1 +cma==3.2.2 +colorama==0.4.6 +comm==0.2.2 +ConfigSpace==0.6.0 +contextlib2==21.6.0 +contourpy==1.2.1 +cpm-kernels==1.0.11 +cycler==0.12.1 +Cython==3.0.10 +dataclasses-json==0.6.7 +DataProperty==1.0.1 +datasets==3.2.0 +debugpy==1.8.12 +decorator==5.2.1 +deepeval==2.0 +defusedxml==0.7.1 +Deprecated==1.2.15 +dill==0.3.8 +distro==1.9.0 +dnspython==2.6.1 +docker==7.1.0 +docker-pycreds==0.4.0 +docstring_parser==0.16 +docx2txt==0.8 +einops==0.8.0 
+email_validator==2.2.0 +emcee==3.1.6 +et_xmlfile==2.0.0 +eval_type_backport==0.2.2 +evaluate==0.4.6 +exceptiongroup==1.2.1 +execnet==2.1.1 +executing==2.2.0 +fastapi==0.111.1 +fastapi-cli==0.0.4 +fastjsonschema==2.21.1 +ffmpy==0.3.2 +filelock==3.14.0 +fire==0.7.1 +flash-attn==2.5.8 +fonttools==4.51.0 +fqdn==1.5.1 +frozenlist==1.4.1 +fschat==0.2.24 +fsspec==2024.3.1 +func-timeout==4.3.5 +fuzzywuzzy==0.18.0 +gekko==1.3.0 +gitdb==4.0.11 +GitPython==3.1.43 +google-pasta==0.2.0 +googleapis-common-protos==1.66.0 +gradio==3.50.2 +gradio_client==0.6.1 +grapheme==0.6.0 +greenlet==3.1.1 +grpcio==1.63.2 +h11==0.14.0 +h5py==3.13.0 +httpcore==1.0.5 +httptools==0.6.1 +httpx==0.27.2 +httpx-sse==0.4.0 +huggingface-hub==0.30.1 +-e git+https://github.com/openai/human-eval@6d43fb980f9fee3c892a914eda09951f772ad10d#egg=human_eval +idna==3.7 +immutabledict==4.2.1 +importlib-metadata==6.11.0 +importlib_resources==6.4.0 +iniconfig==2.0.0 +inquirerpy==0.3.4 +ipykernel==6.29.5 +ipython==8.18.1 +isoduration==20.11.0 +jedi==0.19.2 +jieba==0.42.1 +Jinja2==3.1.3 +jiter==0.8.0 +jmespath==1.0.1 +joblib==1.4.2 +json5==0.10.0 +jsonlines==4.0.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +jupyter-events==0.12.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +jupyter_server==2.15.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.3.5 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +kiwisolver==1.4.5 +langchain==0.3.9 +langchain-community==0.3.8 +langchain-core==0.3.21 +langchain-openai==0.2.10 +langchain-text-splitters==0.3.2 +langsmith==0.1.147 +latex2mathml==3.77.0 +Levenshtein==0.27.1 +lion-pytorch==0.1.4 +llm-blender @ git+https://github.com/yuchenlin/LLM-Blender.git@33204d2712944b6b17996f7c079e74cd963ccc7c +-e git+https://github.com/EleutherAI/lm-evaluation-harness@9d4a04a0aac37fff88ca1a04d667012e1502c21d#egg=lm_eval +lxml==5.2.2 +markdown-it-py==3.0.0 +markdown2==2.5.3 +MarkupSafe==2.1.5 +marshmallow==3.23.1 
+matplotlib==3.8.4 +matplotlib-inline==0.1.7 +mbstrdecoder==1.1.3 +mdurl==0.1.2 +mistune==3.1.2 +ml-collections==0.1.1 +mmengine-lite==0.10.7 +mock==4.0.3 +more-itertools==10.3.0 +mpmath==1.3.0 +multidict==6.0.5 +multiprocess==0.70.16 +mypy-extensions==1.0.0 +narwhals==1.34.1 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.2.1 +nh3==0.2.21 +ninja==1.11.1.1 +nltk==3.8.1 +notebook==7.3.2 +notebook_shim==0.2.4 +numexpr==2.10.1 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.2.3 +openai==1.55.3 +openbox==0.8.3 +OpenCC==1.1.9 +-e git+ssh://git@github.com/xinhaoH/KV_C.git@c899b84206943dbded60678a15425ce34b4bcb38#egg=opencompass&subdirectory=opencompass +opencv-python-headless==4.11.0.86 +openpyxl==3.1.5 +opentelemetry-api==1.24.0 +opentelemetry-exporter-otlp-proto-common==1.24.0 +opentelemetry-exporter-otlp-proto-grpc==1.24.0 +opentelemetry-proto==1.24.0 +opentelemetry-sdk==1.24.0 +opentelemetry-semantic-conventions==0.45b0 +optimum==2.0.0 +orjson==3.10.6 +overrides==7.7.0 +packaging==24.0 +pandas==1.5.3 +pandocfilters==1.5.1 +parso==0.8.4 +pathos==0.3.3 +pathvalidate==3.2.0 +patsy==0.5.6 +# Editable install with no version control (peft==0.10.0) +-e /home/xinhao/peft-0.10.0 +pexpect==4.9.0 +pfzy==0.3.4 +pillow==10.3.0 +platformdirs==4.2.2 +Platypus-Opt==1.0.4 +pluggy==1.5.0 +portalocker==2.10.1 +pox==0.3.5 +ppft==1.7.6.9 +prettytable==3.10.2 +prometheus_client==0.21.1 +prompt_toolkit==3.0.50 +protobuf==4.25.5 +psutil==5.9.8 +ptflops==0.7.3 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyaml==24.7.0 +pyarrow==16.0.0 +pyarrow-hotfix==0.6 +pybind11==2.13.1 +pycountry==24.6.1 
+pycparser==2.22 +pydantic==1.10.21 +pydantic-settings==2.6.1 +pydantic_core==2.20.1 +pydub==0.25.1 +pyext==0.7 +Pygments==2.18.0 +pymoo==0.6.0 +pyparsing==3.1.2 +pysbd==0.3.4 +pytablewriter==1.2.0 +pytest==8.2.0 +pytest-repeat==0.9.3 +pytest-xdist==3.6.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +python-json-logger==3.2.1 +python-Levenshtein==0.27.1 +python-multipart==0.0.9 +pytz==2024.1 +PyYAML==6.0.1 +pyzmq==26.2.1 +ragas==0.2.6 +rank-bm25==0.2.2 +RapidFuzz==3.12.2 +referencing==0.35.1 +regex==2024.4.28 +requests==2.32.3 +requests-toolbelt==1.0.0 +retrying==1.3.4 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.1 +rouge==1.0.1 +rouge-chinese==1.0.3 +rouge-score==0.1.2 +rpds-py==0.22.3 +ruff==0.5.5 +s3fs==0.4.2 +s3transfer==0.10.4 +sacrebleu==2.4.2 +safetensors==0.4.3 +sagemaker==2.237.1 +sagemaker-core==1.0.17 +schema==0.7.7 +scikit-learn==1.5.0 +scikit-optimize==0.10.2 +scipy==1.10.1 +seaborn==0.13.2 +semantic-version==2.10.0 +Send2Trash==1.8.3 +sentence-transformers==4.0.1 +sentencepiece==0.2.0 +sentry-sdk==2.11.0 +setproctitle==1.3.3 +shellingham==1.5.4 +shortuuid==1.0.13 +shtab==1.7.1 +six==1.16.0 +smdebug-rulesconfig==1.0.1 +smmap==5.0.1 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.6 +SQLAlchemy==2.0.35 +sqlitedict==2.1.0 +stack-data==0.6.3 +starlette==0.37.2 +statsmodels==0.14.2 +svgwrite==1.4.3 +sympy==1.12 +syne_tune==0.13.0 +tabledata==1.3.3 +tabulate==0.9.0 +tblib==3.0.0 +tcolorpy==0.1.6 +tenacity==8.4.2 +termcolor==3.0.1 +terminado==0.18.1 +threadpoolctl==3.5.0 +tiktoken==0.8.0 +timeout-decorator==0.5.0 +tinycss2==1.4.0 +tokenizers==0.15.2 +tomli==2.0.1 +tomlkit==0.12.0 +torch==2.3.0 +tornado==6.4.2 +tqdm==4.66.4 +tqdm-multiprocess==0.0.11 +traitlets==5.14.3 +-e git+ssh://git@github.com/xinhaoH/SOLA_xinhao.git@0a46232d27a29d73794599c3530fb07157c60b0c#egg=transformers&subdirectory=transformers-4.37.1 +tree-sitter==0.21.3 +tree-sitter-languages==1.10.2 +triton==2.3.0 +trl @ 
git+https://github.com/huggingface/trl.git@eab175d434b9bb9badee20335c7945991a26dfac +typeguard==4.4.2 +typepy==1.3.2 +typer==0.12.3 +types-python-dateutil==2.9.0.20241206 +typing-inspect==0.9.0 +typing_extensions==4.13.2 +tyro==0.9.18 +tzdata==2024.1 +ujson==5.10.0 +uri-template==1.3.0 +urllib3==1.26.20 +uvicorn==0.30.3 +uvloop==0.19.0 +wandb==0.17.5 +watchfiles==0.22.0 +wavedrom==2.0.3.post3 +wcwidth==0.2.13 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.8.0 +websockets==11.0.3 +word2number==1.1 +wrapt==1.17.0 +xxhash==3.4.1 +yapf==0.43.0 +yarl==1.9.4 +zipp==3.18.1 +zstandard==0.23.0 diff --git a/flagscale/compress/LLM_Decomposition/sola/hook_sola.py b/flagscale/compress/LLM_Decomposition/sola/hook_sola.py new file mode 100644 index 0000000000..692736e6fb --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/hook_sola.py @@ -0,0 +1,380 @@ +import warnings +import functools +import math +import logging +import torch +from torch import nn +import torch.nn.functional as F +from typing import Dict, List, Tuple, Optional + +from transformers import LlamaPreTrainedModel +from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention + +from utils import ( + HelperState, + HelperCollectState, + set_helper_state, + HELPER_SUPPORT_MODEL_LIST, + HELPER_SUPPORT_MODEL_TYPES +) + +logger = logging.getLogger(__name__) + +_HELPER_HOOK_KEY = "HelperHook" + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def add_training_hook_to_llama(model: LlamaPreTrainedModel, + dest: Dict[str, Dict[str, List[torch.Tensor]]], + intermediate_size: int, + hidden_size: int) -> int: + set_helper_state(model, HelperState.Collecting) + hooks = [] + last_layer = 0 + + def forward_hook_get_XXT(layer_idx, name, module, inp, out): + inp = inp[0].detach().float() + if inp.shape[1] > 1: + adds = torch.matmul(inp.transpose(1,2), inp) + adds_sum = torch.sum(adds, dim=0).cpu() + + raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}') + raw_scaling_diag_matrix += adds_sum + + inp = adds = adds_sum = out = None + del inp, adds, adds_sum, out + torch.cuda.empty_cache() + + for name, module in model.named_modules(): + suffix = name.split(".")[-1] + if suffix not in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "o_proj", "v_proj"]: + continue + layer_idx = int(name.split(".")[-3]) + if suffix == "down_proj": + setattr(module, f"raw_scaling_diag_matrix_{layer_idx}", torch.zeros(intermediate_size, intermediate_size)) + else: + setattr(module, f"raw_scaling_diag_matrix_{layer_idx}", torch.zeros(hidden_size, hidden_size)) + handle_pre_forward_collect_hook = module.register_forward_hook( + functools.partial( + forward_hook_get_XXT, + layer_idx, + name + ) + ) + hooks.append(handle_pre_forward_collect_hook) + + setattr(model, _HELPER_HOOK_KEY, hooks) + return last_layer + + +def add_training_hook(model: HELPER_SUPPORT_MODEL_TYPES, + dest: Dict[str, Dict[str, List[torch.Tensor]]], + intermediate_size: int, + hidden_size: int) -> int: + if isinstance(model, LlamaPreTrainedModel): + print("+++++++++ add Llama training hook +++++++++") + return add_training_hook_to_llama(model, dest, intermediate_size, hidden_size) + else: + raise 
NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.") + + +def remove_training_hook(model: HELPER_SUPPORT_MODEL_TYPES): + hooks = getattr(model, _HELPER_HOOK_KEY) + for handle in hooks: + handle.remove() + + setattr(model, _HELPER_HOOK_KEY, None) + + +def llama_mlp_forward_sola(module, inp, **kwargs): + if module.config.pretraining_tp > 1: + raise NotImplementedError + + if module.config.pretraining_tp > 1: + slice = module.intermediate_size // module.config.pretraining_tp + gate_proj_slices = module.gate_proj.weight.split(slice, dim=0) + up_proj_slices = module.up_proj.weight.split(slice, dim=0) + down_proj_slices = module.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(inp, gate_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(inp, up_proj_slices[i]) for i in range(module.config.pretraining_tp)], dim=-1) + + intermediate_states = (module.act_fn(gate_proj) * up_proj) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(module.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']: + down_proj = module.down_proj(module.act_fn(module.gate_proj(inp)) * module.up_proj(inp)) + else: + # prime + h_gate_heavy = module.act_fn(torch.nn.functional.linear(inp, module.gate_weight_heavy)) + h_up_heavy = torch.nn.functional.linear(inp, module.up_weight_heavy) + intermediate_states_heavy = h_gate_heavy * h_up_heavy + down_proj_heavy = torch.nn.functional.linear(intermediate_states_heavy, module.down_weight_heavy) + + # marginal + if module.gate_proj_use > 0: + h_gate = module.act_fn(module.gate_proj(inp)) + else: + if inp.device != module.gate_weight_U_top.device: + module.gate_weight_U_top = module.gate_weight_U_top.to(inp.device) + tmp = torch.nn.functional.linear(inp, module.gate_weight_U_top) + if tmp.device != module.gate_weight_SVh_top.device: + module.gate_weight_SVh_top 
= module.gate_weight_SVh_top.to(tmp.device) + h_gate = module.act_fn(torch.nn.functional.linear(tmp, module.gate_weight_SVh_top)) + + if module.up_proj_use > 0: + h_up = module.up_proj(inp) + else: + if inp.device != module.up_weight_U_top.device: + module.up_weight_U_top = module.up_weight_U_top.to(inp.device) + tmp = torch.nn.functional.linear(inp, module.up_weight_U_top) + if tmp.device != module.up_weight_SVh_top.device: + module.up_weight_SVh_top = module.up_weight_SVh_top.to(tmp.device) + h_up = torch.nn.functional.linear(tmp, module.up_weight_SVh_top) + + if h_gate.device != h_up.device: + h_gate = h_gate.to(h_up.device) + intermediate_states = h_gate * h_up + + if module.down_proj_use > 0: + down_proj = module.down_proj(intermediate_states) + else: + if intermediate_states.device != module.down_weight_U_top.device: + module.down_weight_U_top = module.down_weight_U_top.to(intermediate_states.device) + tmp = torch.nn.functional.linear(intermediate_states, module.down_weight_U_top) + if tmp.device != module.down_weight_SVh_top.device: + module.down_weight_SVh_top = module.down_weight_SVh_top.to(tmp.device) + down_proj_light = torch.nn.functional.linear(tmp, module.down_weight_SVh_top) + + down_proj = down_proj_heavy + down_proj_light + + return down_proj + +def llama_attn_forward_sola(module: torch.nn.Module, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if module.config.pretraining_tp > 1: + key_value_slicing = (module.num_key_value_heads * module.head_dim) // module.config.pretraining_tp + query_slices = module.q_proj.weight.split( + (module.num_heads * module.head_dim) // module.config.pretraining_tp, dim=0 + ) + key_slices = module.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = module.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(module.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(module.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(module.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + ##### Attn q/k decomposition ##### + value_states = module.v_proj(hidden_states) + if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']: + query_states = module.q_proj(hidden_states) + key_states = module.k_proj(hidden_states) + else: + if module.q_proj_use > 0: + query_states = module.q_proj(hidden_states) + else: + if hidden_states.device != module.q_weight_U_top.device: + module.q_weight_U_top = module.q_weight_U_top.to(hidden_states.device) + tmp = torch.nn.functional.linear(hidden_states, module.q_weight_U_top) + if tmp.device != module.q_weight_SVh_top.device: + module.q_weight_SVh_top = module.q_weight_SVh_top.to(tmp.device) + query_states = torch.nn.functional.linear(tmp, module.q_weight_SVh_top) + + if module.k_proj_use > 0: + key_states = module.k_proj(hidden_states) + else: + if hidden_states.device != module.k_weight_U_top.device: + module.k_weight_U_top = module.k_weight_U_top.to(hidden_states.device) + tmp = torch.nn.functional.linear(hidden_states, module.k_weight_U_top) + if tmp.device 
!= module.k_weight_SVh_top.device: + module.k_weight_SVh_top = module.k_weight_SVh_top.to(tmp.device) + key_states = torch.nn.functional.linear(tmp, module.k_weight_SVh_top) + + query_states = query_states.view(bsz, q_len, module.num_heads, module.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if module.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, module.layer_idx) + + cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len) + if cos.device != position_ids.device: + position_ids = position_ids.to(cos.device) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, module.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, module.num_key_value_groups) + value_states = repeat_kv(value_states, module.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size 
{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + if attn_weights.device != attention_mask.device: + attn_weights = attn_weights.to(attention_mask.device) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.bfloat16).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + if attn_weights.device != value_states.device: + attn_weights = attn_weights.to(value_states.device) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, module.num_heads, q_len, module.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, module.hidden_size) + + if module.config.pretraining_tp > 1: + attn_output = attn_output.split(module.hidden_size // module.config.pretraining_tp, dim=2) + o_proj_slices = module.o_proj.weight.split(module.hidden_size // module.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(module.config.pretraining_tp)]) + else: + ##### Attn o decomposition ##### + if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']: + attn_output = module.o_proj(attn_output) + else: + if module.o_proj_use > 0: + attn_output = module.o_proj(attn_output) + if attn_output.device != module.o_weight_U_top.device: + module.o_weight_U_top = module.o_weight_U_top.to(attn_output.device) + tmp = torch.nn.functional.linear(attn_output, module.o_weight_U_top) + if tmp.device != module.o_weight_SVh_top.device: + module.o_weight_SVh_top = module.o_weight_SVh_top.to(tmp.device) + attn_output = torch.nn.functional.linear(tmp, module.o_weight_SVh_top) + + if not output_attentions: + attn_weights = None 
+ + return attn_output, attn_weights, past_key_value + + +def add_inference_hook_to_llama_sola(pruned_layer_idx_list, + model: HELPER_SUPPORT_MODEL_TYPES): + set_helper_state(model, HelperState.Inference) + hooks = [] + + for name, module in model.named_modules(): + if not isinstance(module, (LlamaMLP, LlamaAttention)): + continue + layer_idx = int(name.split(".")[-2]) + + if isinstance(module, LlamaMLP): + module.forward = functools.partial( + llama_mlp_forward_sola, + module, + layer_idx=layer_idx, + pruned_layer_idx_list=pruned_layer_idx_list, + module_name=name + ) + elif isinstance(module, LlamaAttention): + hitter_dict_cur_layer_q = None + hitter_dict_cur_layer_k = None + + module.forward = functools.partial( + llama_attn_forward_sola, + module, + layer_idx=layer_idx, + pruned_layer_idx_list=pruned_layer_idx_list, + hitter_dict_q=hitter_dict_cur_layer_q, + hitter_dict_k=hitter_dict_cur_layer_k + ) + +def add_inference_hook(pruned_layer_idx_list, + model: HELPER_SUPPORT_MODEL_TYPES): + if isinstance(model, LlamaPreTrainedModel): + add_inference_hook_to_llama_sola(pruned_layer_idx_list, model) + else: + raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.") + diff --git a/flagscale/compress/LLM_Decomposition/sola/optim.py b/flagscale/compress/LLM_Decomposition/sola/optim.py new file mode 100644 index 0000000000..8bd4e24fb4 --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/optim.py @@ -0,0 +1,183 @@ +import copy +import heapq +import math +from typing import Dict, Tuple, List, Callable + + +class Optimizer: + def __init__(self, + model_params: int, + target_rate: float, + granularity: int = 32, + start: float = 99, + end: float = 90): + self.granularity = granularity + self.start = start + self.end = end + + self.concurrent_budget = 3 + + self.core = {} + + # model information + self.model_params = model_params + self.target_rate = target_rate + + # global state + self.perf_degrade_clamp = 0.1 # %, in percentage + 
self.perf_degrade_relaxation = 0.02 # %, in percentage + + self.eff_score = 0 + self.acc_score = 0 + + def add_item(self, + name: str, + cursor: int, + shape: Tuple[int, int], + perf_var: List[float], + grad: List[float]): + m, n = shape + self.core.setdefault(name, {}) + self.core[name]["cur"] = cursor + self.core[name]["start_idx"] = cursor + self.core[name]["perf_var"] = perf_var + self.core[name]["shape"] = shape + self.core[name]["est_params"] = (m + n) * (cursor + 1) * self.granularity + self.core[name]["tot_params"] = m * n + self.core[name]["grad"] = grad + self.core[name]["accum_drop"] = 0 + self.core[name]["is_excluded"] = False + + def update_item(self, state: Dict, name: str, new_cursor: int): + m, n = state[name]["shape"] + start_idx = state[name]["start_idx"] + + state[name]["cur"] = new_cursor + state[name]["est_params"] = (m + n) * (new_cursor + 1) * self.granularity + state[name]["accum_drop"] = state[name]["perf_var"][start_idx] - state[name]["perf_var"][new_cursor] + + def _eval_state(self, state: Dict): + est_params = 0 + tot_params = 0 + tot_acc = 0 + n_layers = len(state.keys()) + for key, state_dict in state.items(): + tot_params += state_dict["tot_params"] + est_params += state_dict["est_params"] + tot_acc += state_dict["perf_var"][state_dict["cur"]] + + return tot_params, est_params, tot_acc / n_layers + + def constringe(self, do_excluding: bool = True): + print(f"Target compression rate: {self.target_rate:.2f}%") + previous_state = copy.deepcopy(self.core) + epoch = 0 + eff_scores = 0 + while True: + loop = 0 + while True: + current_state, has_modified = self._inner_loop(previous_state) + previous_state = current_state + + tot_params, est_params, acc_scores = self._eval_state(current_state) + eff_scores = (1 - (tot_params - est_params) / self.model_params) * 100 # in percentage + print(f"{epoch=}, {loop=}, {tot_params=}, {est_params=}, {eff_scores=:.2f}%, {acc_scores=:.2f}") + if eff_scores <= self.target_rate or not has_modified: + break 
+ + # gradually relax the clamp + self.perf_degrade_clamp += self.perf_degrade_relaxation # in percentage + + loop += 1 + + if not do_excluding: + break + + # check if there's any layer needing to be excluded + any_excludes = False + for layer, layer_state_dict in current_state.items(): + mat_shape = layer_state_dict["shape"] + desired_rank = (layer_state_dict["cur"] + 1) * self.granularity + max_rank = int(math.floor(math.prod(mat_shape) / sum(mat_shape) / self.granularity) - 5) * self.granularity + if desired_rank >= max_rank: + print(f"excluding {layer}") + any_excludes = True + self.core[layer]["is_excluded"] = True + + if not any_excludes: + break + + # redo after excluding + previous_state = copy.deepcopy(self.core) + for layer, layer_state_dict in self.core.items(): + if layer_state_dict["is_excluded"]: + previous_state.pop(layer) + + epoch += 1 + + print(f"Optimization finished.") + if eff_scores > self.target_rate: + print(f"Failed to reach the target goal {self.target_rate} vs. 
{eff_scores}") + return current_state + + def _inner_loop(self, previous_state: Dict): + current_state = copy.deepcopy(previous_state) + all_layers = previous_state.keys() + + loop = 0 + has_modified = False + while True: + # Iterative update + candidates = [] + for layer in all_layers: + cur = current_state[layer]["cur"] + + perf_drop = current_state[layer]["grad"][cur] + moved_cur = cur - 1 + + # skip if it exceeds the supremum + if perf_drop > self.perf_degrade_clamp: + continue + + if current_state[layer]["perf_var"][cur] < self.end: + continue + + if cur == 0: + continue + + mat_shape = current_state[layer]["shape"] + assert len(mat_shape) == 2 + gain = perf_drop / math.prod(mat_shape) + + # collect the current descending performance + heapq.heappush(candidates, (gain, layer, cur, moved_cur)) + + if len(candidates) == 0: + return current_state, has_modified + + for _ in range(self.concurrent_budget): + if len(candidates) == 0: + break + perf_drop, key, cur, moved_cur = heapq.heappop(candidates) + self.update_item(current_state, key, moved_cur) + has_modified = True + + if has_modified: + tot_params, est_params, acc_scores = self._eval_state(current_state) + eff_scores = (1 - (tot_params - est_params) / self.model_params) * 100 # in percentage + if eff_scores <= self.target_rate: + return current_state, has_modified + + loop += 1 + + def export(self): + deft_config = {} + for layer, layer_data in self.core.items(): + cur = layer_data["cur"] + deft_config[layer] = { + "shape": tuple(layer_data["shape"]), + "desired_rank": (cur + 1) * self.granularity, + "perf_score": layer_data["perf_var"][cur], + "is_updated": False, + } + return deft_config diff --git a/flagscale/compress/LLM_Decomposition/sola/playground_sola.py b/flagscale/compress/LLM_Decomposition/sola/playground_sola.py new file mode 100644 index 0000000000..1e0c48ab79 --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/playground_sola.py @@ -0,0 +1,324 @@ +import os +import random +import click 
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    pipeline
)
import torch
import numpy as np
import pandas as pd

import utils
from sola.training_sola import Helper

import json
import time
import logging
logger = logging.getLogger(__name__)

from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention

import lm_eval
from lm_eval.models.huggingface import HFLM


def setup_seed(seed):
    """Seed torch/CUDA/numpy/random and force deterministic cuDNN so eval runs are reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


@click.command()
@click.option("-m", "--model", type=click.Path(file_okay=True), help="path to model file", default=None)
def cli(**kwargs):
    """SoLA end-to-end driver.

    Loads a Llama checkpoint, collects XX^T calibration statistics on WikiText
    prompts, whitened-SVD-decomposes the target projections, rewires the model
    with truncated factors, and evaluates (perplexity / lm-eval / MMLU).

    Fixes vs. the original:
      * ``dump_dest``/``save_dest`` are now assigned on the 7b branch (was a
        NameError at ``os.listdir(save_dest)``);
      * the neuron-index JSON is loaded once instead of once per module;
      * ``desired_ranks`` only records suffixes present in ``config`` (the
        0.8 table has no "o_proj" entry and used to raise KeyError);
      * the U/V reload loop picks the suffix list per module type, so it no
        longer requests files that were never written
        (e.g. ``<self_attn>.gate_proj.u``).
    """
    args = utils.EasyDict(**kwargs)
    print(args)

    model_path = args.model

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir=None,
        device_map="auto",
        quantization_config=None,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    print("Model created")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        padding_side="right",
        use_fast=False,  # Fast tokenizer giving issues.
        tokenizer_type='llama' if 'ama' in model_path else None,  # Needed for HF name change
    )
    if tokenizer._pad_token is None:
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=utils.DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'ama' in model_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id and model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    print("Tokenizer created")

    generation_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
    print("Pipeline created")

    helper_params = {'intermediate_size': model.config.intermediate_size,
                     'hidden_size': model.config.hidden_size}
    helper = Helper(model, torch.bfloat16, **helper_params)
    print("Helper created")

    # Calibration data
    prompts_path = '/data/wiki_256_4096.json'
    with open(prompts_path, 'r') as file:
        prompts = json.load(file)

    # Compute XX^T: run every prompt through the hooked model so the
    # per-projection scaling matrices accumulate.
    t_start_time = time.time()
    with helper:
        for text in prompts:
            prompt_token_count = len(generation_pipeline.tokenizer.encode(text, return_tensors="pt")[0])
            generation_pipeline(text, max_length=int(prompt_token_count), pad_token_id=tokenizer.eos_token_id, truncation=True)
    t_end_time = time.time()
    t_duration = t_end_time - t_start_time
    print(f"Collect training data costs avg: {t_duration/len(prompts): .5f} s, all: {t_duration/60: .2f} min, {t_duration: .5f} s.")
    print('Collect training data Done')

    # Low-Rank Decomposition: whitened SVD of each target projection.
    dump_dest = f'/data/model_params/wiki/13b/light'  # NOTE(review): 13b-specific path — confirm for other sizes
    # Load the heavy/light neuron split once (was re-read inside the loop).
    with open('/data/hitter_dict_15_256_4096_wiki_13b.json', 'r') as json_file:
        hitter_dict = json.load(json_file)
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]
        if suffix not in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "v_proj", "o_proj"]:
            continue
        layer_idx = int(name.split(".")[-3])

        raw_scaling_diag_matrix = getattr(module, f'raw_scaling_diag_matrix_{layer_idx}').double().to(model.device)

        light_75p_hitter = torch.tensor(hitter_dict[f'{layer_idx}']['light_85p_neuron_idx']).to(model.device)
        if suffix == "down_proj":
            # down_proj consumes intermediate activations -> restrict XX^T to the light neurons
            raw_scaling_diag_matrix = raw_scaling_diag_matrix[light_75p_hitter, :][:, light_75p_hitter]

        try:
            scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(model.device)
        except Exception:
            print(name, "Warning: eigen scaling_diag_matrix is not positive!")
            if torch.isnan(raw_scaling_diag_matrix).any():
                print("Warning: scaling_diag_matrix contains NaN!")
            elif torch.isinf(raw_scaling_diag_matrix).any():
                print("Warning: scaling_diag_matrix contains Inf!")
            if not torch.equal(raw_scaling_diag_matrix, raw_scaling_diag_matrix.T):
                print("Warning: scaling_diag_matrix is not a symmetric matrix!")
            # shift the spectrum to make the matrix positive definite, then retry
            eigenvalues = torch.linalg.eigvalsh(raw_scaling_diag_matrix)
            raw_scaling_diag_matrix += (- eigenvalues[0] + 1e-3) * torch.eye(raw_scaling_diag_matrix.shape[0]).to(model.device)
            scaling_diag_matrix = torch.linalg.cholesky(raw_scaling_diag_matrix).float().to(model.device)

        try:
            scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix)
        except Exception:
            print(name, "Warning: scaling_diag_matrix is not full rank!")
            scaling_diag_matrix += 1e-3 * torch.eye(scaling_diag_matrix.shape[0]).to(model.device)
            scaling_matrix_inv = torch.linalg.inv(scaling_diag_matrix).to(model.device)

        W = module.weight.float()

        # MLP weights: keep only the light ("marginal") neurons before the SVD.
        if suffix in ["gate_proj", "up_proj", "down_proj"]:
            if light_75p_hitter.device != W.device:
                light_75p_hitter = light_75p_hitter.to(W.device)
            if suffix in ["gate_proj", "up_proj"]:
                W = W[light_75p_hitter, :].to(model.device)
            else:
                W = W[:, light_75p_hitter].to(model.device)

        if W.device != scaling_diag_matrix.device:
            scaling_diag_matrix = scaling_diag_matrix.to(W.device)
        W_scale = torch.matmul(W, scaling_diag_matrix)
        if layer_idx == 0:
            print('W scale shape: ', suffix, W_scale.shape)

        u, s, v = torch.linalg.svd(W_scale, full_matrices=False)  # The singular values are returned in descending order.
        if layer_idx == 0:
            print('decomposition: ', name, u.shape, s.shape, v.shape, W.shape, scaling_matrix_inv.shape)

        torch.save(u, os.path.join(dump_dest, f"{name}.u"))
        torch.save(s, os.path.join(dump_dest, f"{name}.s"))
        print(name, 'u s v saved.', dump_dest)

        # undo the whitening on the right factor before saving
        if v.device != scaling_matrix_inv.device:
            v = v.to(scaling_matrix_inv.device)
        v_inv = v @ scaling_matrix_inv
        torch.save(v_inv, os.path.join(dump_dest, f"{name}.v_inv"))
        print(name, 'v_inv saved.', dump_dest)

    # Compute the rank of each component
    target_rate = 0.7
    if target_rate >= 0.8:
        target_modules = ["q_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
    else:
        target_modules = ["q_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # config = utils.get_rank_config(model, target_modules, target_rate, args.model, start=99.9, end=90)
    if '13b' in model_path:  # Llama-2-13b
        if target_rate == 0.8:  # 20% compression rate
            config = {"q_proj": 1440, "k_proj": 1440,
                      "gate_proj": 2048, "up_proj": 2816, "down_proj": 2976}
        elif target_rate == 0.7:  # 30% compression rate
            config = {"q_proj": 960, "k_proj": 640, "o_proj": 1920,
                      "gate_proj": 2080, "up_proj": 2400, "down_proj": 2560}
        elif target_rate == 0.6:  # 40% compression rate
            config = {"q_proj": 480, "k_proj": 480, "o_proj": 1440,
                      "gate_proj": 1120, "up_proj": 2080, "down_proj": 2240}
        elif target_rate == 0.5:  # 50% compression rate
            config = {"q_proj": 320, "k_proj": 320, "o_proj": 768,
                      "gate_proj": 1280, "up_proj": 1280, "down_proj": 1440}
    else:  # Llama-2-7b — NOTE(review): only 0.8/0.7 are tabulated for 7b
        if target_rate == 0.8:  # 20% compression rate
            config = {"q_proj": 1120, "k_proj": 800,
                      "gate_proj": 1760, "up_proj": 2400, "down_proj": 2240}
        elif target_rate == 0.7:  # 30% compression rate
            config = {"q_proj": 640, "k_proj": 640, "o_proj": 1440,
                      "gate_proj": 1760, "up_proj": 1920, "down_proj": 1760}
    print(config)

    desired_ranks = {}
    layer_num = 40 if '13b' in model_path else 32
    for layer_idx in range(layer_num):
        for suffix in ["gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "o_proj"]:
            if suffix not in config:  # e.g. "o_proj" is absent at target_rate >= 0.8
                continue
            desired_ranks.setdefault(f'{layer_idx}', {})[suffix] = (config[suffix], None)

    # First/last layers are kept dense.
    if '13b' in model_path:
        pruned_layer_idx_list = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                                 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37]
    elif '7b' in model_path:
        pruned_layer_idx_list = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                                 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                                 26, 27, 28, 29]
    print('pruned layer number: ', len(pruned_layer_idx_list))

    # Reduce memory through low rank decomposition
    hot_ratio = 15
    if '13b' in model_path:
        active_params = model_params = 13343959040
        # edit file path
        dump_dest = '/data/model_params/wiki/13b/light'
        save_dest = f'/data/model_params/wiki/13b/uv_fils_{hot_ratio}_30'
    elif '7b' in model_path:
        active_params = model_params = 7000842240
        # FIX: these were never assigned on the 7b branch (NameError below);
        # mirrors the 13b layout — TODO(review): confirm the 7b artifact paths.
        dump_dest = '/data/model_params/wiki/7b/light'
        save_dest = f'/data/model_params/wiki/7b/uv_fils_{hot_ratio}_30'
    # Drop stale truncated factors from a previous run.
    for filename in os.listdir(save_dest):
        if filename.split('.')[-1] not in ['wu', 'wv']:
            continue
        file_path = os.path.join(save_dest, filename)
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f'{file_path} deletion.')
    for name, module in model.named_modules():
        if not isinstance(module, (LlamaMLP, LlamaAttention)):
            continue
        layer_idx = int(name.split(".")[-2])
        if layer_idx not in pruned_layer_idx_list:
            continue

        # FIX: pick the projections that belong to this module type; the flat
        # six-entry list used to request files that were never produced
        # (e.g. "<self_attn>.gate_proj.u").
        if isinstance(module, LlamaMLP):
            suffix_list = ["gate_proj", "up_proj", "down_proj"]
        else:
            suffix_list = ["q_proj", "k_proj", "o_proj"]
        for suffix in suffix_list:
            if suffix not in desired_ranks[f'{layer_idx}']:
                continue
            u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.u"), map_location=torch.device('cuda'))
            s = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.s"), map_location=torch.device('cuda'))
            v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.v_inv"), map_location=torch.device('cuda'))
            k = desired_ranks[f'{layer_idx}'][suffix][0]
            u, v = utils.get_uv(u, s, v, k)

            # Bookkeeping for the estimated compression rate.
            if suffix in ["gate_proj", "up_proj", "down_proj"]:
                module_weight_numel = model.config.intermediate_size * model.config.hidden_size
                module_weight_numel = int(0.85 * module_weight_numel)  # only the light 85% is factorized
            elif suffix in ["q_proj", "k_proj"]:
                # NOTE(review): exact for Llama-2 (MHA, heads == kv heads); revisit for GQA checkpoints
                module_weight_numel = model.config.hidden_size * \
                    (model.config.hidden_size /
                     (model.config.num_attention_heads / model.config.num_key_value_heads))
            elif suffix in ["o_proj"]:
                module_weight_numel = model.config.hidden_size * model.config.hidden_size
            active_params -= module_weight_numel - v.numel() - u.numel()

            torch.save(u, os.path.join(save_dest, f"{name}.{suffix}.wu"))
            torch.save(v, os.path.join(save_dest, f"{name}.{suffix}.wv"))
            print(f"{name}.{suffix} {k} wu wv saved.")

            u = s = v = None
            del u, s, v
            utils.clear_torch_cache()
    print(f"Estimated compression rate: {1 - active_params/model_params:.4f}")

    helper.apply_sola_to_model(pruned_layer_idx_list, desired_ranks, hot_ratio, save_dest, model)
    utils.clear_torch_cache()
    print('Applying Done')  # FIX: typo "Appling"

    model.eval()
    setup_seed(42)

    # Evaluate perplexity
    ppl = utils.eval_ppl(model, tokenizer)
    print('ppl: ', ppl)

    # Evaluate lm eval accuracy
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
    task_names = ['piqa', 'hellaswag', 'boolq', 'winogrande', 'arc_easy', 'arc_challenge', 'openbookqa']
    results = lm_eval.simple_evaluate(hflm, tasks=task_names, num_fewshot=0, batch_size=8)[
        'results'
    ]
    print(results)
    metric_vals = {task: round(result.get(utils.TASK_METRIC_MAP[task]), 4) for task, result in results.items()}
    acc_avg = utils.calculate_avg_accuracy(task_names, results)
    metric_vals['average'] = round(acc_avg, 4)
    print(metric_vals)

    # Evaluate mmlu accuracy
    utils.eval_mmlu(model, tokenizer, 5, "data/mmlu-data")
print('Eval MMLU Done \n') + + +if __name__ == "__main__": + cli() diff --git a/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh b/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh new file mode 100644 index 0000000000..6dcc57af0c --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/playground_sola.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES="0" python sola/playground_sola.py --model="/data/llama-2/Llama-2-13b-hf" \ No newline at end of file diff --git a/flagscale/compress/LLM_Decomposition/sola/training_sola.py b/flagscale/compress/LLM_Decomposition/sola/training_sola.py new file mode 100644 index 0000000000..e08868bf72 --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/training_sola.py @@ -0,0 +1,160 @@ +import os +import torch +import contextlib +from typing import Dict, List + +from utils import HELPER_SUPPORT_MODEL_LIST, HELPER_SUPPORT_MODEL_TYPES + +from sola.hook_sola import ( + add_training_hook, + remove_training_hook, + add_inference_hook, +) +import utils + + +class Helper(contextlib.ContextDecorator): + def __init__(self, model: HELPER_SUPPORT_MODEL_TYPES, compute_type, **kwargs): + self.model = model + self.device = model.device + self.compute_type = compute_type + self.hidden_size = kwargs["hidden_size"] + self.intermediate_size = kwargs["intermediate_size"] + self.training_data: Dict[str, Dict[str, List[torch.Tensor]]] = {} + + if not isinstance(model, HELPER_SUPPORT_MODEL_LIST): + raise NotImplementedError("Unsupported model") + + def __enter__(self): + self.model_last_layer = add_training_hook(self.model, self.training_data, self.intermediate_size, self.hidden_size) + + def __exit__(self, exc_type, exc_val, exc_tb): + remove_training_hook(self.model) + + def apply_sola_to_model(self, pruned_layer_idx_list, desired_rank_pref, hot_ratio, usv_dump_dest, model: HELPER_SUPPORT_MODEL_TYPES): + import json + hitter_dict_path = f'/data/hitter_dict_{hot_ratio}_256_4096_wiki_13b.json' + with 
open(hitter_dict_path, 'r') as json_file: + hitter_dict = json.load(json_file) + + def infer_device() -> torch.device: + if not torch.cuda.is_available(): + return torch.device("cpu") + max_free_memory = -1 + best_device_index = -1 + for i in range(torch.cuda.device_count()): + current_device = torch.device(f"cuda:{i}") + torch.cuda.set_device(current_device) + free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated() + if free_memory > max_free_memory: + max_free_memory = free_memory + best_device_index = i + if best_device_index == -1: + return torch.device("cpu") + else: + return torch.device(f"cuda:{best_device_index}") + + def regis_mlp_attn_func(pruned_layer_idx_list, model, hitter_dict): + from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention + for name, module in model.named_modules(): + if not isinstance(module, (LlamaMLP, LlamaAttention)): + continue + layer_idx = int(name.split(".")[-2]) + if layer_idx not in pruned_layer_idx_list: + continue + + dump_dest = usv_dump_dest + + if isinstance(module, (LlamaMLP)): + # prime + hitter_dict_cur_layer = hitter_dict[f'{layer_idx}'] + heavy_25p_hitter = torch.tensor(hitter_dict_cur_layer['heavy_15p_neuron_idx']) + gate_weight_heavy = module.gate_proj.weight[heavy_25p_hitter, :] + up_weight_heavy = module.up_proj.weight[heavy_25p_hitter, :] + down_weight_heavy = module.down_proj.weight[:, heavy_25p_hitter] + module.register_buffer('gate_weight_heavy', gate_weight_heavy.to(torch.bfloat16)) + module.register_buffer('up_weight_heavy', up_weight_heavy.to(torch.bfloat16)) + module.register_buffer('down_weight_heavy', down_weight_heavy.to(torch.bfloat16)) + module.gate_proj = module.up_proj = module.down_proj = None + del module.gate_proj + del module.up_proj + del module.down_proj + utils.clear_torch_cache() + + # marginal + suffix_list = ["gate_proj", "up_proj", "down_proj"] + for suffix in suffix_list: + if suffix not in 
desired_rank_pref[f'{layer_idx}'].keys(): + module.register_buffer(f'{suffix}_use', torch.Tensor([True])) + print(f"{suffix} not in desired rank {desired_rank_pref[f'{layer_idx}'].keys()}.") + light_idx = torch.tensor(hitter_dict_cur_layer['light_85p_neuron_idx']) + if suffix == "gate_proj": + gate_weight_light = module.gate_proj.weight[light_idx, :] + module.register_buffer('gate_weight_light', gate_weight_light.to(torch.bfloat16)) + elif suffix == "up_proj": + up_weight_light = module.up_proj.weight[light_idx, :] + module.register_buffer('up_weight_light', up_weight_light.to(torch.bfloat16)) + else: + down_weight_light = module.down_proj.weight[:, light_idx] + module.register_buffer('down_weight_light', down_weight_light.to(torch.bfloat16)) + else: + module.register_buffer(f'{suffix}_use', torch.Tensor([False])) + u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wu"), map_location=torch.device(infer_device())) + v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wv"), map_location=torch.device(infer_device())) + print('get u v: ', name, suffix, u.shape, v.shape, u.device, v.device) + if suffix == "gate_proj": + module.register_buffer('gate_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('gate_weight_SVh_top', u.t().to(torch.bfloat16)) + elif suffix == "up_proj": + module.register_buffer('up_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('up_weight_SVh_top', u.t().to(torch.bfloat16)) + else: + module.register_buffer('down_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('down_weight_SVh_top', u.t().to(torch.bfloat16)) + u = s = v = None + del u, s, v + utils.clear_torch_cache() + else: + suffix_list = ["q_proj", "k_proj", "o_proj"] + for suffix in suffix_list: + if suffix not in desired_rank_pref[f'{layer_idx}'].keys(): + print(f"{suffix} not in {desired_rank_pref[f'{layer_idx}'].keys()}.") + module.register_buffer(f'{suffix}_use', torch.Tensor([True])) + else: + 
module.register_buffer(f'{suffix}_use', torch.Tensor([False])) + u = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wu"), map_location=torch.device(infer_device())) + v = torch.load(os.path.join(dump_dest, f"{name}.{suffix}.wv"), map_location=torch.device(infer_device())) + print('attn get u v: ', name, suffix, u.shape, v.shape, u.device, v.device) + + if suffix == "q_proj": + module.register_buffer('q_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('q_weight_SVh_top', u.t().to(torch.bfloat16)) + module.q_proj = None + del module.q_proj + utils.clear_torch_cache() + elif suffix == "k_proj": + module.register_buffer('k_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('k_weight_SVh_top', u.t().to(torch.bfloat16)) + module.k_proj = None + del module.k_proj + utils.clear_torch_cache() + elif suffix == "o_proj": + module.register_buffer('o_weight_U_top', v.t().to(torch.bfloat16)) + module.register_buffer('o_weight_SVh_top', u.t().to(torch.bfloat16)) + module.o_proj = None + del module.o_proj + utils.clear_torch_cache() + # elif suffix == "v_proj": + # module.register_buffer('v_weight_U_top', v.t().to(torch.bfloat16)) + # module.register_buffer('v_weight_SVh_top', u.t().to(torch.bfloat16)) + # module.v_proj = None + # del module.v_proj + # utils.clear_torch_cache() + u = s = v = None + del u, s, v + utils.clear_torch_cache() + + regis_mlp_attn_func(pruned_layer_idx_list, model, hitter_dict) + + add_inference_hook(pruned_layer_idx_list, model) + \ No newline at end of file diff --git a/flagscale/compress/LLM_Decomposition/sola/utils.py b/flagscale/compress/LLM_Decomposition/sola/utils.py new file mode 100644 index 0000000000..c77ece4dff --- /dev/null +++ b/flagscale/compress/LLM_Decomposition/sola/utils.py @@ -0,0 +1,811 @@ +import os +import re +import gc +import time +import json +import pickle +import torch +import transformers +import lm_eval +import numpy as np +import pandas as pd +from tqdm import tqdm +from enum import Enum 
+from datasets import load_dataset +from typing import Any, Union, Dict, TypeVar, Generic, Iterable, List, Iterator +from transformers import OPTPreTrainedModel, LlamaPreTrainedModel +from optim import Optimizer +import copy +import itertools +import bisect +import warnings + + +T_co = TypeVar('T_co', covariant=True) +class Dataset(Generic[T_co]): + def __getitem__(self, index) -> T_co: + raise NotImplementedError + + def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': + return ConcatDataset([self, other]) + + +class ConcatDataset(Dataset[T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + datasets: List[Dataset[T_co]] + cumulative_sizes: List[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ConcatDataset, self).__init__() + # Cannot verify that datasets is Sized + assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore + self.datasets = list(datasets) + for d in self.datasets: + assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def cummulative_sizes(self): + warnings.warn("cummulative_sizes attribute is renamed to " + "cumulative_sizes", DeprecationWarning, 
stacklevel=2) + return self.cumulative_sizes + +class IterableDataset(Dataset[T_co]): + def __iter__(self) -> Iterator[T_co]: + raise NotImplementedError + + def __add__(self, other: Dataset[T_co]): + return ChainDataset([self, other]) + +class ChainDataset(IterableDataset): + r"""Dataset for chainning multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chainning operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + def __init__(self, datasets: Iterable[Dataset]) -> None: + super(ChainDataset, self).__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" + for x in d: + yield x + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset" + # Cannot verify that all self.datasets are Sized + total += len(d) # type: ignore + return total + +def get_test_data(name, tokenizer, seq_len=2048, batch_size=4): + class IndexDataset(Dataset): + def __init__(self, tensors): + self.tensors = tensors + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + #### + def process_data(samples, tokenizer, seq_len, field_name): + test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0] + test_ids_batch = [] + nsamples = test_ids.numel() // seq_len + + for i in range(nsamples): + batch = test_ids[(i * seq_len):((i + 1) * seq_len)] + test_ids_batch.append(batch) + test_ids_batch = torch.stack(test_ids_batch) + return IndexDataset(tensors=test_ids_batch) + #### + if 'wikitext2' in name: + test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + test_dataset = 
process_data(test_data, tokenizer, seq_len, 'text') + if 'ptb' in name: + test_data = load_dataset('ptb_text_only', 'penn_treebank', split='test') + test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence') + elif 'c4' in name: + test_data = load_dataset("json", data_files="utils/c4-validation.json")['train'] + test_dataset = process_data(test_data[0:2000], tokenizer, seq_len, 'text') + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + return test_loader + +@torch.no_grad() +def eff_eval(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"): + model.eval() + token_num = 0 + factor = 1 if batch_size > 32 else 3 + num_batches_to_fetch = 13 * factor + test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size=batch_size) + start_time = time.time() + progress_bar = tqdm(enumerate(itertools.islice(test_loader, num_batches_to_fetch))) + for batch_idx, batch_data in progress_bar: + batch = batch_data.to(device) + if batch_idx == 3 * factor: + start_time = time.time() + if batch_idx >= 3 * factor: + token_num += batch.shape[0] * generated_len + progress_bar.set_postfix_str(f"{token_num=}") + generation_output = model.generate( + input_ids=batch, + pad_token_id=tokenizer.eos_token_id, + do_sample=True, + use_cache=True, + max_length=original_len+generated_len, + ) + torch.cuda.synchronize() + end_time = time.time() + total_time = end_time - start_time + throughput = token_num / total_time + return throughput + + +def eff_eval_v2(model, tokenizer, dataset='wikitext2', original_len=4, generated_len=128, batch_size=1, device="cuda"): + model.eval() + throughput = 0 + token_num = 0 + num_batches_to_fetch = 130 + test_loader = get_test_data(dataset, tokenizer, seq_len=original_len, batch_size = batch_size) + + for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)): + batch = batch_data.to(device) + if batch_idx >= 
30: + break + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(0) + torch.cuda.synchronize() + generation_output = model.generate( + input_ids=batch, + pad_token_id=tokenizer.eos_token_id, + do_sample=True, + use_cache=True, + top_k=50, + max_length=original_len+generated_len, + top_p=0.95, + temperature=1, + ) + torch.cuda.synchronize() + + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(0) + torch.cuda.synchronize() + + start_time = time.time() + for batch_idx, batch_data in enumerate(itertools.islice(test_loader, num_batches_to_fetch)): + if batch_idx < 30: + continue + batch = batch_data.to(device) + token_num += batch.shape[0] * generated_len + generation_output = model.generate( + input_ids=batch, + pad_token_id=tokenizer.eos_token_id, + do_sample=True, + use_cache=True, + top_k=50, + max_length=original_len+generated_len, + top_p=0.95, + temperature=1, + ) + torch.cuda.synchronize() + end_time = time.time() + throughput = end_time - start_time + print("time: {}".format(end_time - start_time)) + print("Throughput: {} tokens/sec".format(token_num / throughput)) + return token_num / throughput + + +def clear_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + +def get_uv(u, s, v, k): + svd_u = u[:, :k] + svd_s = s[:k] + svd_v = v[:k, :] + sqrt_s = torch.diag(torch.sqrt(svd_s)) + if svd_u.device != sqrt_s.device: + print('svd u s device: ', svd_u.device, sqrt_s.device) + svd_u = svd_u.to(sqrt_s.device) + if sqrt_s.device != svd_v.device: + print('svd s v device: ', sqrt_s.device, svd_v.device) + svd_v = svd_v.to(sqrt_s.device) + clear_torch_cache() + u=(svd_u @ sqrt_s).T + v=(sqrt_s @ svd_v).T + return u, v + +def get_rank_config(model, target_modules, target_rate, model_path, start=99.9, end=90): + optimizer = constitute_mapping(model, target_modules, target_rate, model_path, start=99.9, end=90) + optimized_state = optimizer.constringe() + rank = {} + for name, module in model.named_modules(): + suffix = 
name.split(".")[-1] + if target_rate >= 0.8: + target_modules = ["q_proj", "k_proj", "gate_proj", "up_proj", "down_proj"] + else: + target_modules = ["q_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + if suffix not in target_modules or name not in optimized_state.keys(): + print(f"{name} skipped.") + continue + cur = optimized_state[name]["cur"] + desired_rank = (cur + 1) * 32 # rank + accum = optimized_state[name]["perf_var"][cur] # perf score + layer_idx = int(name.split(".")[-3]) + if layer_idx not in rank.keys(): + rank[layer_idx] = {suffix: (desired_rank, accum)} + else: + rank[layer_idx][suffix] = (desired_rank, accum) + print(f"{name=}, {desired_rank=}, {accum:.2f}%.") + + q_proj_values_non_zero = [] + k_proj_values_non_zero = [] + o_proj_values_non_zero = [] + gate_proj_values_non_zero = [] + up_proj_values_non_zero = [] + down_proj_values_non_zero = [] + for key in rank.keys(): + if 'q_proj' in rank[key].keys(): + q_proj_values_non_zero.append(rank[key]['q_proj'][0]) + if 'k_proj' in rank[key].keys(): + k_proj_values_non_zero.append(rank[key]['k_proj'][0]) + if 'o_proj' in rank[key].keys(): + o_proj_values_non_zero.append(rank[key]['o_proj'][0]) + if 'gate_proj' in rank[key].keys(): + gate_proj_values_non_zero.append(rank[key]['gate_proj'][0]) + if 'up_proj' in rank[key].keys(): + up_proj_values_non_zero.append(rank[key]['up_proj'][0]) + if 'down_proj' in rank[key].keys(): + down_proj_values_non_zero.append(rank[key]['down_proj'][0]) + + q_r = sorted(q_proj_values_non_zero)[-int(len(sorted(q_proj_values_non_zero)) * 0.3)] + k_r = sorted(k_proj_values_non_zero)[-int(len(sorted(k_proj_values_non_zero)) * 0.3)] + q_r = ((q_r + 159) // 160) * 160 + k_r = ((k_r + 159) // 160) * 160 + print('q:', q_r, int(len(sorted(q_proj_values_non_zero)) * 0.3)) + print('k:', k_r, int(len(sorted(k_proj_values_non_zero)) * 0.3)) + if target_rate < 0.8: + o_r = sorted(o_proj_values_non_zero)[-int(len(sorted(o_proj_values_non_zero)) * 0.55)] + o_r = ((o_r + 
def constitute_mapping(model, target_modules, target_rate, model_name,
                       granularity=32, start=99, end=90,
                       dump_dest="/home/xinhao/sep-peft/model_params/wiki/svd_llm/light",
                       dump_dest_attn="/home/xinhao/sep-peft/model_params/wiki/svd_llm/attn"):
    """Build an Optimizer over per-module SVD truncation candidates.

    For every target linear module in ``model``, load its precomputed singular
    values (``<name>.s`` under ``dump_dest`` / ``dump_dest_attn``), compute the
    preserved-energy curve over truncation ranks, and register the module with
    the rank Optimizer.

    Args:
        model: the HF model whose ``named_modules()`` are scanned.
        target_modules: projection suffixes eligible for decomposition.
        target_rate: parameter-keep ratio in [0, 1]; scaled to percent here.
        model_name: path or id; must contain "Llama-2-13b-hf" or "Llama-2-7b-hf".
        granularity: rank step between truncation candidates.
        start/end: percent thresholds forwarded to the Optimizer.
        dump_dest / dump_dest_attn: directories holding MLP / attention sigma files.

    Returns:
        The populated Optimizer.

    Raises:
        ValueError: if ``model_name`` matches no supported model.
    """
    if "Llama-2-13b-hf" in model_name:
        model_config = {
            "architectures": ["LlamaForCausalLM"],
            "hidden_size": 5120,
            "intermediate_size": 13824,
            "num_attention_heads": 40,
            "num_hidden_layers": 40,
            "vocab_size": 32000,
            "params": 13343959040,
            "lora": 125173760,
            "optim": 250347520,
            "grad": 125173760,
        }
        skip_layers = [0, 1, 38, 39]
    elif "Llama-2-7b-hf" in model_name:
        model_config = {
            "architectures": ["LlamaForCausalLM"],
            "hidden_size": 4096,
            "intermediate_size": 11008,
            "num_attention_heads": 32,
            "num_hidden_layers": 32,
            "params": 7000842240,
            "lora": 79953920,
            "optim": 159907840,
            "grad": 79953920,
        }
        skip_layers = [0, 1, 29, 30]
    else:
        # Previously fell through and raised NameError on `config`; fail clearly.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    # Bug fix: the original indexed config[model_name] although the dict key was
    # the canonical "meta-llama/..." id, which raised KeyError for local paths.
    model_params = model_config["params"]
    target_rate *= 100  # work in percent from here on
    optimizer = Optimizer(model_params, target_rate, granularity, start, end)

    # Called during adding hook for updating
    for name, module in model.named_modules():
        suffix = name.split(".")[-1]

        if suffix not in target_modules:
            print(f"[constitute-mapping] skipping {name} due to configuration")
            continue

        layer_idx = int(name.split(".")[-3])
        if layer_idx in skip_layers:
            print(f"[constitute-mapping] skipping {layer_idx} due to configuration")
            continue

        # Attention sigma files live in a separate dump directory; o_proj is
        # only decomposed for compression targets below 80%.
        if target_rate >= 80:
            suffix_attn_list = ["q_proj", "k_proj"]
        else:
            suffix_attn_list = ["q_proj", "k_proj", "o_proj"]
        sigma_dir = dump_dest_attn if suffix in suffix_attn_list else dump_dest
        sigma_path = os.path.join(sigma_dir, f"{name}.s")
        if not os.path.exists(sigma_path):
            print(f"[constitute-mapping] skipping {name} since it cannot find sigma file")
            continue
        s = torch.load(sigma_path, map_location="cuda:0")

        m, n = module.weight.shape
        # Max rank at which the factorized form still has fewer params than W.
        max_trunc = int((m * n) / (m + n))
        sigma_square = s ** 2
        total = sigma_square.sum()

        perf_var = []   # preserved energy (%) per truncation candidate
        grad = []       # first difference of perf_var
        start_idx = 10000000  # sentinel: first candidate reaching `start` percent
        for trunc in range(granularity, max_trunc, granularity):
            perf = (sigma_square[:trunc].sum() / total * 100).item()  # in percentage
            if perf >= start:
                start_idx = min(len(perf_var), start_idx)
            idx = len(perf_var)
            perf_var.append(perf)
            grad.append(0 if idx == 0 else perf - perf_var[idx - 1])

        # Edge case: max_trunc <= granularity produces no candidates at all.
        if not perf_var:
            print(f"[constitute-mapping] skipping {name} for trivial profit")
            continue

        start_idx = min(start_idx, len(perf_var) - 1)

        # collect
        optimizer.add_item(name, start_idx, tuple([m, n]), perf_var, grad)

    return optimizer
class HelperState(Enum):
    """Lifecycle state the helper attaches to a model instance.

    NOTE: the original file defined this class (and its label assignments)
    twice, verbatim; the duplicate has been removed.
    """
    KEY = 10000

    Collecting = 0
    Inference = 1

    Invalid = 9999


# Human-readable labels attached to the members (also used as attribute names
# on the model object, see set_helper_state below).
HelperState.KEY.label = "HelperState"
HelperState.Collecting.label = "Helper-Data-Collection"  # hook forward() to collect data
HelperState.Inference.label = "Helper-Ready-Inference"   # with updated forward()


class HelperCollectState(Enum):
    """Sub-state used while the helper is collecting data."""
    KEY = 10001

    Pre = 0
    Post = 1
    End = 2

    Invalid = 9999


HelperCollectState.KEY.label = "HelperCollectState"
HelperCollectState.Pre.label = "HelperCollectState-Pre"
HelperCollectState.Post.label = "HelperCollectState-Post"
HelperCollectState.End.label = "HelperCollectState-End"


def set_helper_state(model, state: HelperState) -> None:
    """Store the helper state on the model under the HelperState key label."""
    setattr(model, HelperState.KEY.label, state)


# Bug fix: `(LlamaPreTrainedModel)` is just the class, not a tuple — the
# trailing comma makes the intent (a tuple usable with isinstance) explicit.
HELPER_SUPPORT_MODEL_LIST = (LlamaPreTrainedModel,)
HELPER_SUPPORT_MODEL_TYPES = Union[LlamaPreTrainedModel]


# Metric key to read from lm-eval results per task.
# https://pypi.org/project/lm-eval/0.0.1/
TASK_METRIC_MAP = {
    "piqa": "acc_norm,none",
    "arc_challenge": "acc_norm,none",
    "arc_easy": "acc_norm,none",
    "hellaswag": "acc_norm,none",
    "winogrande": "acc,none",
    "boolq": "acc,none",
    "wsc": "acc,none",
    "openbookqa": "acc_norm,none",
}
def calculate_avg_accuracy(task_names: list, results: dict) -> float:
    """Average accuracy across lm-eval tasks, with all MMLU subtasks collapsed
    into one question-count-weighted entry.

    Args:
        task_names: task identifiers that were evaluated (fix: annotated as a
            list, not str — it is iterated and len()'d).
        results: lm-eval results dict keyed by task name.
    """
    n_tasks = len(task_names)
    # Sum the headline metric of every non-MMLU task.
    acc_cumul = sum(result.get(TASK_METRIC_MAP[task]) for task, result in results.items() if 'mmlu' not in task)

    questions_per_mmlu_task = {
        task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
        for task_name in task_names
        if 'mmlu' in task_name
    }

    if not questions_per_mmlu_task:
        return acc_cumul / n_tasks

    # Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
    acc_mmlu = sum(
        result.get(TASK_METRIC_MAP[task]) * questions_per_mmlu_task[task]
        for task, result in results.items()
        if 'mmlu' in task
    )
    acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())

    # All MMLU subtasks count as a single task in the final average.
    return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)


def easy_dump(obj, dest, label):
    """Pickle ``obj`` to ``<dest>/<label>.pkl``; dicts are also mirrored to JSON."""
    with open(os.path.join(dest, f"{label}.pkl"), "wb") as f:
        pickle.dump(obj, f)

    # also dump as json if it is a dict
    if isinstance(obj, dict):
        with open(os.path.join(dest, f"{label}.json"), "w") as f:
            f.write(json.dumps(obj, indent=4))


def make_run_dir(outdir: Union[str, os.PathLike], desc: str) -> str:
    """Create and return the next sequential run directory ``NNNNN-desc``."""
    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):  # sanity check, but click.Path() should clear this one
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1  # start with 00000
    run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
    os.makedirs(run_dir, exist_ok=False)  # make sure it doesn't already exist
    return run_dir


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


DEFAULT_PAD_TOKEN = "[PAD]"
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings_data = model.get_input_embeddings().weight.data + output_embeddings_data = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings_data[-num_new_tokens:] = input_embeddings_avg + output_embeddings_data[-num_new_tokens:] = output_embeddings_avg + +subcategories = { + "abstract_algebra": ["math"], + "anatomy": ["health"], + "astronomy": ["physics"], + "business_ethics": ["business"], + "clinical_knowledge": ["health"], + "college_biology": ["biology"], + "college_chemistry": ["chemistry"], + "college_computer_science": ["computer science"], + "college_mathematics": ["math"], + "college_medicine": ["health"], + "college_physics": ["physics"], + "computer_security": ["computer science"], + "conceptual_physics": ["physics"], + "econometrics": ["economics"], + "electrical_engineering": ["engineering"], + "elementary_mathematics": ["math"], + "formal_logic": ["philosophy"], + "global_facts": ["other"], + "high_school_biology": ["biology"], + "high_school_chemistry": ["chemistry"], + "high_school_computer_science": ["computer science"], + "high_school_european_history": ["history"], + "high_school_geography": ["geography"], + "high_school_government_and_politics": ["politics"], + "high_school_macroeconomics": ["economics"], + "high_school_mathematics": ["math"], + "high_school_microeconomics": ["economics"], + "high_school_physics": ["physics"], + "high_school_psychology": ["psychology"], + "high_school_statistics": ["math"], + "high_school_us_history": ["history"], + "high_school_world_history": ["history"], + "human_aging": ["health"], + "human_sexuality": ["culture"], + "international_law": ["law"], + 
"jurisprudence": ["law"], + "logical_fallacies": ["philosophy"], + "machine_learning": ["computer science"], + "management": ["business"], + "marketing": ["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], +} + +categories = { + "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"], + "humanities": ["history", "philosophy", "law"], + "social sciences": ["politics", "culture", "economics", "geography", "psychology"], + "other (business, health, misc.)": ["other", "business", "health"], +} + +choices = ["A", "B", "C", "D"] + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. 
{}".format(choices[j], df.iloc[idx, j + 1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format( + format_subject(subject) + ) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +@torch.no_grad() +def evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df): + cors = [] + all_probs = [] + answers = choices[: test_df.shape[1] - 2] + + for i in tqdm(range(test_df.shape[0]), desc=subject): + # get prompt and make sure it fits + k = ntrain + prompt_end = format_example(test_df, i, include_answer=False) + train_prompt = gen_prompt(dev_df, subject, k) + prompt = train_prompt + prompt_end + + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) + + while input_ids.shape[-1] > 2048: + k -= 1 + train_prompt = gen_prompt(dev_df, subject, k) + prompt = train_prompt + prompt_end + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to( + model.device + ) + + label = test_df.iloc[i, test_df.shape[1] - 1] + + logits = model(input_ids=input_ids).logits[0, -1] + + probs = ( + torch.nn.functional.softmax( + torch.tensor( + [ + logits[tokenizer("A").input_ids[-1]], + logits[tokenizer("B").input_ids[-1]], + logits[tokenizer("C").input_ids[-1]], + logits[tokenizer("D").input_ids[-1]], + ] + ).float(), + dim=0, + ) + .detach() + .cpu() + .numpy() + ) + pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] + + cor = pred == label + cors.append(cor) + all_probs.append(probs) + + acc = np.mean(cors) + cors = np.array(cors) + + all_probs = np.array(all_probs) + print("Average accuracy {:.3f} - {}".format(acc, subject)) + + return cors, acc, all_probs + +def eval_mmlu(model, tokenizer, ntrain, data_dir): + subjects = sorted( + [ + f.split("_test.csv")[0] + for f in 
def eval_mmlu(model, tokenizer, ntrain, data_dir):
    """Run the full MMLU benchmark from CSVs under ``data_dir``.

    Expects ``<data_dir>/dev/<subject>_dev.csv`` (few-shot examples, truncated
    to ``ntrain`` rows) and ``<data_dir>/test/<subject>_test.csv``.

    Returns a dict with per-subcategory accuracies, per-category accuracies,
    the overall weighted accuracy, and wall-clock cost.
    """
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    all_cors = []
    # One bucket per fine-grained subcategory and per broad category.
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None
        )[: ntrain]
        test_df = pd.read_csv(
            os.path.join(data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, probs = evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df)
        # File correctness arrays into every subcategory/category the subject maps to.
        subcats = subcategories[subject]
        for subcat in subcats:
            subcat_cors[subcat].append(cors)
            for key in categories.keys():
                if subcat in categories[key]:
                    cat_cors[key].append(cors)
        all_cors.append(cors)

    results = {"subcategories": {}, "categories": {}}
    for subcat in subcat_cors:
        # Concatenating first weights each subject by its question count.
        subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
        results["subcategories"][subcat] = subcat_acc
        print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat in cat_cors:
        cat_acc = np.mean(np.concatenate(cat_cors[cat]))
        results["categories"][cat] = cat_acc
        print("Average accuracy {:.3f} - {}".format(cat_acc, cat))
    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    end_time = time.time()
    results["cost_time"] = end_time - start_time

    return results
def eval_ppl(model, tokenizer):
    """Compute WikiText-2 test perplexity with the standard HF sliding-window
    recipe (stride equal to the model's max context, so windows do not overlap)."""
    model.eval()
    max_length = model.config.max_position_embeddings
    stride = max_length
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
    nlls = []
    encodings = tokenizer("\n\n".join(test), return_tensors="pt")
    seq_len = encodings.input_ids.size(1)
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Mask out context positions so loss is only taken over the new tokens.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    # Perplexity = exp of the mean negative log-likelihood over all windows.
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head dimension.

    Equivalent of ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the
    hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    if n_rep == 1:
        return hidden_states
    return torch.repeat_interleave(hidden_states, n_rep, dim=1)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q, k: query/key tensors of shape [batch, heads, seq, head_dim]
            (or [batch, seq, heads, head_dim] with ``unsqueeze_dim=2``).
        cos, sin: rotary tables indexed by position along their first dim.
        position_ids: per-token position indices used to gather cos/sin.
        unsqueeze_dim: dimension along which the gathered cos/sin are
            unsqueezed so they broadcast over the head dimension of q and k.

    Returns:
        Tuple of the rotated (q_embed, k_embed).
    """
    cos_g = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_g = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos_g) + (rotate_half(q) * sin_g)
    k_embed = (k * cos_g) + (rotate_half(k) * sin_g)
    return q_embed, k_embed


def llama_mlp(module: torch.nn.Module, x: torch.Tensor):
    """Replicates LlamaMLP.forward but also returns the intermediate
    activations needed for neuron profiling.

    Returns:
        (down_proj, intermediate_states, h_gate, h_up) where
        intermediate_states is act(gate) * up (split into TP shards on the
        tensor-parallel path).
    """
    tp = module.config.pretraining_tp
    if tp > 1:
        shard = module.intermediate_size // tp
        gate_shards = module.gate_proj.weight.split(shard, dim=0)
        up_shards = module.up_proj.weight.split(shard, dim=0)
        down_shards = module.down_proj.weight.split(shard, dim=1)

        gate_proj = torch.cat([F.linear(x, w) for w in gate_shards], dim=-1)
        up_proj = torch.cat([F.linear(x, w) for w in up_shards], dim=-1)

        # Bug fix: the original never assigned h_gate / h_up on this branch,
        # so the return statement raised NameError whenever pretraining_tp > 1.
        h_gate = module.act_fn(gate_proj)
        h_up = up_proj

        intermediate_states = (h_gate * h_up).split(shard, dim=2)
        down_proj = sum(
            F.linear(h, w) for h, w in zip(intermediate_states, down_shards)
        )
    else:
        h_gate = module.act_fn(module.gate_proj(x))
        h_up = module.up_proj(x)
        intermediate_states = h_gate * h_up
        down_proj = module.down_proj(intermediate_states)

    return down_proj, intermediate_states, h_gate, h_up
def llama_self_attn(module: torch.nn.Module,
                    hidden_states: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.LongTensor] = None,
                    past_key_value: Optional[Tuple[torch.Tensor]] = None,
                    output_attentions: Optional[bool] = False,
                    use_cache: Optional[bool] = False,
                    **kwargs
                    ):
    """Re-implementation of LlamaAttention.forward that additionally returns
    per-head output norms and per-feature q/k norms for profiling.

    Returns:
        (head_norms, attn_output, attn_weights, past_key_value, q_norm, k_norm).
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    # print("collect hidden states shape: ", {hidden_states.shape})
    bsz, q_len, _ = hidden_states.size()

    if module.config.pretraining_tp > 1:
        # Tensor-parallel path: project with manually split weight shards.
        key_value_slicing = (module.num_key_value_heads * module.head_dim) // module.config.pretraining_tp
        query_slices = module.q_proj.weight.split(
            (module.num_heads * module.head_dim) // module.config.pretraining_tp, dim=0
        )
        key_slices = module.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = module.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(module.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(module.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(module.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = module.q_proj(hidden_states)
        key_states = module.k_proj(hidden_states)
        value_states = module.v_proj(hidden_states)

    # Per-feature L2 norms taken over dim 0 (the batch dim, before the head
    # reshape). NOTE(review): presumably feature-importance statistics for the
    # decomposition — confirm the intended reduction dimension.
    q_norm = torch.norm(query_states, p=2, dim=0)
    k_norm = torch.norm(key_states, p=2, dim=0)

    query_states = query_states.view(bsz, q_len, module.num_heads, module.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if module.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, module.layer_idx)
    cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len)
    # Device reconciliation for multi-GPU device_map placements.
    if cos.device != position_ids.device:
        position_ids = position_ids.to(cos.device)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, module.layer_idx, cache_kwargs)
        # NOTE(review): `update_get` is not a standard transformers Cache
        # method — presumably a custom cache subclass elsewhere; confirm.
        key_states, value_states = past_key_value.update_get(key_states, value_states, module.layer_idx, cache_kwargs)

    # Grouped-query attention: expand KV heads to match the query head count.
    key_states = repeat_kv(key_states, module.num_key_value_groups)
    value_states = repeat_kv(value_states, module.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)

    if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        if attn_weights.device != attention_mask.device:
            attn_weights = attn_weights.to(attention_mask.device)
        attn_weights = attn_weights + attention_mask

    # NOTE(review): softmax runs in bfloat16 here although the inherited
    # comment said "upcast to fp32" (upstream Llama upcasts to float32);
    # confirm the reduced precision is intentional.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.bfloat16).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
    if attn_weights.device != value_states.device:
        attn_weights = attn_weights.to(value_states.device)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, module.num_heads, q_len, module.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()

    # Per-head L2 norm of the attention output (reduced over head_dim).
    head_norms = torch.norm(attn_output, dim = -1)

    attn_output = attn_output.reshape(bsz, q_len, module.hidden_size)

    if module.config.pretraining_tp > 1:
        attn_output = attn_output.split(module.hidden_size // module.config.pretraining_tp, dim=2)
        o_proj_slices = module.o_proj.weight.split(module.hidden_size // module.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(module.config.pretraining_tp)])
    else:
        attn_output = module.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return head_norms, attn_output, attn_weights, past_key_value, q_norm, k_norm
def pre_decoder_forward_hook_dejavu_collect(layer_idx: int,
                                            model_type: str,
                                            dest: Dict[str, Dict[str, List[torch.Tensor]]],
                                            module: torch.nn.Module,
                                            inp: Tuple,
                                            *args,
                                            **kwargs):
    """Forward pre-hook on a LlamaDecoderLayer that re-runs the layer's math
    to accumulate per-neuron L2 norms of the MLP intermediate activations.

    Registered with ``with_kwargs=True``, so ``args[0]`` is the kwargs dict of
    the layer's forward call. ``model_type`` and ``dest`` are currently unused
    and kept only for interface compatibility with the registration site.
    """
    x = inp[0]

    residual = x
    hidden_states = module.input_layernorm(x)

    # Normalize legacy tuple caches into the new Cache API before reuse.
    use_legacy_cache = not isinstance(args[0]["past_key_value"], Cache) and args[0]["past_key_value"] is not None
    if use_legacy_cache:
        args[0]["past_key_value"] = DynamicCache.from_legacy_cache(args[0]["past_key_value"])

    _, hidden_states, _, _, _, _ = llama_self_attn(module.self_attn,
                                                   hidden_states,
                                                   args[0]["attention_mask"],
                                                   args[0]["position_ids"],
                                                   args[0]["past_key_value"],
                                                   args[0]["output_attentions"],
                                                   args[0]["use_cache"])

    if hidden_states.device != residual.device:
        hidden_states = hidden_states.to(residual.device)
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = module.post_attention_layernorm(hidden_states)

    hidden_states, intermediate_states, _, _ = llama_mlp(module.mlp, hidden_states)

    ##### the output norm of different neurons #####
    # Only prefill passes (seq len > 1) contribute; single-token decode steps skipped.
    if hidden_states.shape[1] > 1:
        neurons_dict_cur_layer = getattr(module, f'neurons_dict_{layer_idx}')
        mlp_inter_norm = torch.norm(intermediate_states, p=2, dim=0)
        if mlp_inter_norm.device != neurons_dict_cur_layer['norm'].device:
            mlp_inter_norm = mlp_inter_norm.to(neurons_dict_cur_layer['norm'].device)
        neurons_dict_cur_layer['norm'] += mlp_inter_norm.sum(dim=0)
        neurons_dict_cur_layer['token_num'] += mlp_inter_norm.shape[0]


def add_training_hook_to_llama(model: LlamaPreTrainedModel,
                               dest: Dict[str, Dict[str, List[torch.Tensor]]],
                               intermediate_size: int) -> int:
    """Attach the neuron-norm collection pre-hook to every LlamaDecoderLayer
    and allocate a per-layer accumulator on the layer module.

    Returns the largest decoder-layer index encountered.
    """
    set_helper_state(model, HelperState.Collecting)
    hooks = []
    last_layer = 0

    for name, module in model.named_modules():
        if not isinstance(module, (LlamaDecoderLayer, LlamaAttention, LlamaMLP)):
            continue

        # Layer index is the last path component for decoder layers, the
        # second-to-last for their attention/MLP submodules.
        if isinstance(module, LlamaDecoderLayer):
            layer_idx = int(name.split(".")[-1])
        else:
            layer_idx = int(name.split(".")[-2])

        last_layer = max(layer_idx, last_layer)

        if isinstance(module, LlamaDecoderLayer):
            # Accumulator: summed neuron norms plus the token count divisor.
            setattr(module, f"neurons_dict_{layer_idx}", {'norm': torch.zeros(intermediate_size).to(model.device), 'token_num': 0})
            handle_dejavu_collect = module.register_forward_pre_hook(
                functools.partial(
                    pre_decoder_forward_hook_dejavu_collect,
                    layer_idx,
                    "llama",
                    dest,
                ),
                with_kwargs=True
            )
            hooks.append(handle_dejavu_collect)

    setattr(model, _HELPER_HOOK_KEY, hooks)
    return last_layer


def add_training_hook(model: HELPER_SUPPORT_MODEL_TYPES,
                      dest: Dict[str, Dict[str, List[torch.Tensor]]],
                      intermediate_size: int) -> int:
    """Dispatch hook installation by model family (currently Llama only)."""
    if isinstance(model, LlamaPreTrainedModel):
        return add_training_hook_to_llama(model, dest, intermediate_size)
    else:
        raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")


def remove_training_hook(model: HELPER_SUPPORT_MODEL_TYPES):
    """Detach all previously installed collection hooks.

    Robustness fix: a plain ``getattr`` raised AttributeError when called on a
    model that never had hooks installed; this is now a safe no-op.
    """
    hooks = getattr(model, _HELPER_HOOK_KEY, None) or []
    for handle in hooks:
        handle.remove()

    setattr(model, _HELPER_HOOK_KEY, None)
def setup_seed(seed):
    """Seed every RNG source used here (stdlib random, NumPy, torch CPU and
    all CUDA devices) and force deterministic cuDNN kernels so repeated runs
    produce identical results."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
@click.command()
@click.option("-m", "--model", type=click.Path(file_okay=True), help="path to model file", default=None)
def cli(**kwargs):
    """Load a Llama model, run calibration prompts through it while the neuron
    hooks are active, dump per-layer neuron-norm CSVs, then derive the
    prime/marginal (heavy/light) neuron index dictionary.

    NOTE(review): several paths (/data/wiki_256_4096.json,
    /data/mlp_neurons_norm/...) and file-name suffixes ("13b_") are
    hard-coded — confirm they match the deployment layout.
    """
    args = utils.EasyDict(**kwargs)
    print(args)

    model_path = args.model

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir=None,
        device_map="auto",
        quantization_config = None,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    print("Model created")

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=None,
        padding_side="right",
        use_fast=False,  # Fast tokenizer giving issues.
        tokenizer_type='llama' if 'ama' in model_path else None,  # Needed for HF name change
    )
    # NOTE(review): `_pad_token` is a private tokenizer attribute — confirm it
    # is still present on the pinned transformers version.
    if tokenizer._pad_token is None:
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=utils.DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if 'ama' in model_path or isinstance(tokenizer, LlamaTokenizer):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token.
        print('Adding special tokens.')
        tokenizer.add_special_tokens({
            "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
            "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
            "unk_token": tokenizer.convert_ids_to_tokens(
                model.config.pad_token_id if model.config.pad_token_id and model.config.pad_token_id != -1 else tokenizer.pad_token_id
            ),
        })
    print("Tokenizer created")

    generation_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
    print("Pipeline created")

    setup_seed(42)

    # NOTE(review): if the model path contains neither '13b' nor '7b',
    # `intermediate_size` is never assigned and the next line raises NameError.
    if '13b' in model_path:
        intermediate_size = 13824
    elif '7b' in model_path:
        intermediate_size = 11008
    helper_params = {'intermediate_size': intermediate_size}
    helper = Helper(model, torch.bfloat16, **helper_params)
    print("Helper created")

    # Construct calibration data
    # wiki_dataset = utils.get_wikitext2(256, 3, 4096, tokenizer, 'wiki')
    # with open('/data/wiki_256_4096.json', 'w') as json_file:
    #     json.dump(wiki_dataset, json_file, ensure_ascii=False, indent=4)
    with open('/data/wiki_256_4096.json', 'r') as file:
        prompts_all = json.load(file)

    # calibration prompt number
    prompt_num_list = [256]
    for prompt_num in prompt_num_list:
        prompts = prompts_all[:prompt_num]
        print('prompt number: ', len(prompts))

        # Compute neurons norm
        # The Helper context installs the collection hooks; each prompt is run
        # through the pipeline only to trigger forward passes (max_length equals
        # the prompt length, so effectively no new tokens are generated).
        t_start_time = time.time()
        with helper:
            for text in prompts:
                prompt_token_count = len(generation_pipeline.tokenizer.encode(text, return_tensors="pt")[0])
                generation_pipeline(text, max_length=prompt_token_count, pad_token_id=tokenizer.eos_token_id, truncation=True)
        t_end_time = time.time()
        t_duration = t_end_time - t_start_time
        print(f"Collect training data costs avg: {t_duration/len(prompts): .5f} s, all: {t_duration/60: .2f} min, {t_duration: .5f} s. ")
        print('Collect training data Done')

        # Record neurons norm
        import csv
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        for name, module in model.named_modules():
            if not isinstance(module, LlamaDecoderLayer):
                continue
            layer_idx = int(name.split(".")[-1])

            # Average accumulated neuron norms by the number of tokens seen.
            neurons_dict_cur_layer = getattr(module, f'neurons_dict_{layer_idx}')
            average_values = neurons_dict_cur_layer['norm'] / neurons_dict_cur_layer['token_num']
            csv_data = [(k, average_values[k].item()) for k in range(average_values.shape[0])]
            csv_file = f'/data/mlp_neurons_norm/sample_{prompt_num}/13b_{layer_idx}.csv'

            with open(csv_file, mode='w', newline='') as file:
                writer = csv.writer(file)
                writer.writerows(csv_data)
            print(layer_idx, csv_file, 'write done')
        print(prompt_num, 'Neurons Norm Write Done')

        ###### Record prime/marginal neurons index ######
        # Rank neurons per layer by average norm; top 15% are "heavy" (prime),
        # the rest "light" (marginal).
        dir_name = f'/data/mlp_neurons_norm/sample_{prompt_num}'
        checkpoint_path_list = os.listdir(dir_name)
        hitter_dict = {}
        for checkpoint_path in checkpoint_path_list:
            layer_idx = int(checkpoint_path.split('_')[-1].split('.')[0])
            data_df = pd.read_csv(os.path.join(dir_name, checkpoint_path), index_col=False, header=None)
            data_df = data_df.sort_values(by=1, ascending=False)
            num_top = int(len(data_df) * 0.15)  # prime neurons 15%
            heavy_neron_idx = data_df[0][:num_top].tolist()
            light_neron_idx = data_df[0][num_top:].tolist()
            hitter_dict[layer_idx] = {'heavy_15p_neuron_idx': heavy_neron_idx,
                                      'light_85p_neuron_idx': light_neron_idx}

        with open(f'{dir_name}/hitter_dict_15_{prompt_num}_4096_wiki_13b.json', 'w') as json_file:
            json.dump(hitter_dict, json_file, ensure_ascii=False, indent=4)
        print(prompt_num, 'Prime Dict Done')


if __name__ == "__main__":
    cli()
import contextlib
from typing import Dict, List

import torch

from utils import HELPER_SUPPORT_MODEL_LIST, HELPER_SUPPORT_MODEL_TYPES

from sola_neuron_idx.hook import (
    add_training_hook,
    remove_training_hook
)


class Helper(contextlib.ContextDecorator):
    """Context manager that installs per-layer training hooks on a supported
    model for the duration of a ``with`` block and removes them on exit.

    While the context is active, the hooks (see ``sola_neuron_idx.hook``)
    populate ``self.training_data`` as the model runs forward passes.
    """

    def __init__(self, model: HELPER_SUPPORT_MODEL_TYPES, compute_type, **kwargs):
        """
        Args:
            model: a Llama-family ``PreTrainedModel`` (see
                ``HELPER_SUPPORT_MODEL_LIST``); anything else raises
                ``NotImplementedError``.
            compute_type: dtype used downstream (e.g. ``torch.bfloat16``).
            **kwargs: must contain ``intermediate_size`` — the MLP hidden
                width of the model variant being hooked.
        """
        # Fail fast BEFORE storing anything, so an unsupported model never
        # leaves a half-initialized Helper behind (original checked last).
        if not isinstance(model, HELPER_SUPPORT_MODEL_LIST):
            raise NotImplementedError("Unsupported model")

        self.model = model
        self.device = model.device
        self.compute_type = compute_type
        self.intermediate_size = kwargs["intermediate_size"]
        # Filled in by the hooks while the context is active.
        self.training_data: Dict[str, Dict[str, List[torch.Tensor]]] = {}

    def __enter__(self):
        self.model_last_layer = add_training_hook(
            self.model, self.training_data, self.intermediate_size
        )
        # BUG FIX: the original returned None, which silently breaks
        # ``with Helper(...) as helper:``. Returning self is backward
        # compatible with bare ``with helper:`` usage.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always detach the hooks, even if the body raised; the implicit
        # None return lets any exception propagate.
        remove_training_hook(self.model)
def get_wikitext2(nsamples, seed, seqlen, tokenizer, data_name, dataset_cache_dir=None):
    """Sample `nsamples` calibration texts of exactly `seqlen` tokens each.

    Args:
        nsamples: number of text samples to return.
        seed: RNG seed for reproducible sampling positions.
        seqlen: required token length of every sample.
        tokenizer: HF tokenizer used to tokenize, truncate and decode.
        data_name: 'wiki' selects wikitext-2-raw-v1; anything else selects
            the C4 en validation shards.
        dataset_cache_dir: unused; kept for interface compatibility.

    Returns:
        List of `nsamples` decoded strings, each `seqlen` tokens long.
    """
    import random
    random.seed(seed)

    if data_name == 'wiki':
        traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    else:
        data_files = {"train": "en/c4-validation.*.json.gz"}
        traindata = load_dataset("allenai/c4", data_files=data_files, split="train")
    tot_text = "\n\n".join(traindata["text"])

    # BUG FIX: the original `for s in range(nsamples + 1)` loop tried to retry
    # too-short draws via `s = s - 1`, but rebinding a for-loop variable never
    # adds iterations, so the function could silently return fewer than
    # `nsamples` samples. It also filled the list with a needless
    # one-iteration lag through `original_text`. A while-loop guarantees
    # exactly `nsamples` samples.
    traindataset = []
    while len(traindataset) < nsamples:
        i = random.randint(0, len(tot_text) - seqlen - 1)
        # Over-sample ~10x the character budget so the window almost always
        # tokenizes to at least `seqlen` tokens.
        trainenc = tokenizer(tot_text[i:i + seqlen * 10], return_tensors="pt")
        if trainenc.input_ids.shape[1] < seqlen:
            continue  # unlucky window; redraw
        inp = trainenc.input_ids[:, :seqlen]
        traindataset.append(tokenizer.decode(inp[0].tolist(), skip_special_tokens=True))

    return traindataset


class HelperState(Enum):
    """Top-level state of the hook helper, stored on the model object
    under the attribute named by ``HelperState.KEY.label``."""

    KEY = 10000  # sentinel member; its .label is the model attribute name

    Collecting = 0  # forward() hooked to collect data
    Inference = 1   # with updated forward(), ready for inference

    Invalid = 9999


HelperState.KEY.label = "HelperState"
HelperState.Collecting.label = "Helper-Data-Collection"  # hook forward() to collect data
HelperState.Inference.label = "Helper-Ready-Inference"  # with updated forward()


class HelperCollectState(Enum):
    """Sub-state of a collection pass (pre-hook, post-hook, finished)."""

    KEY = 10001  # sentinel member, mirrors HelperState.KEY

    Pre = 0
    Post = 1
    End = 2

    Invalid = 9999


HelperCollectState.KEY.label = "HelperCollectState"
HelperCollectState.Pre.label = "HelperCollectState-Pre"
HelperCollectState.Post.label = "HelperCollectState-Post"
HelperCollectState.End.label = "HelperCollectState-End"


def set_helper_state(model, state: HelperState) -> None:
    """Record the helper `state` on `model` under the HelperState key attribute."""
    setattr(model, HelperState.KEY.label, state)


# BUG FIX: was `(LlamaPreTrainedModel)` — parentheses without a trailing comma
# do not make a tuple. isinstance() happened to accept the bare class, but the
# name promises a tuple that new supported classes can be appended to.
HELPER_SUPPORT_MODEL_LIST = (LlamaPreTrainedModel,)
HELPER_SUPPORT_MODEL_TYPES = Union[LlamaPreTrainedModel]


def easy_dump(obj, dest, label):
    """Pickle `obj` to `<dest>/<label>.pkl`; dicts are also mirrored as JSON
    for easy human inspection."""
    with open(os.path.join(dest, f"{label}.pkl"), "wb") as f:
        pickle.dump(obj, f)

    # also dump as json if it is a dict
    if isinstance(obj, dict):
        with open(os.path.join(dest, f"{label}.json"), "w") as f:
            f.write(json.dumps(obj, indent=4))


def make_run_dir(outdir: Union[str, os.PathLike], desc: str) -> str:
    """Create and return a fresh run directory `<outdir>/<NNNNN>-<desc>`.

    The numeric prefix is one greater than the highest existing prefix found
    in `outdir` (starting at 00000), so successive runs never collide.

    Raises:
        FileExistsError: if the computed directory somehow already exists
            (os.makedirs is called with exist_ok=False on purpose).
    """
    prev_run_dirs = []
    if os.path.isdir(outdir):  # sanity check; absent outdir means no prior runs
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    matches = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(m.group()) for m in matches if m is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1  # start with 00000
    run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
    os.makedirs(run_dir, exist_ok=False)  # make sure it doesn't already exist
    return run_dir


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the
    attribute syntax (d.key <-> d['key'])."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            # Missing keys must surface as AttributeError so hasattr() works.
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


DEFAULT_PAD_TOKEN = "[PAD]"


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Newly added embedding rows (input AND output) are initialized to the mean
    of the pre-existing rows, a better starting point than random init.

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        output_embeddings_data = model.get_output_embeddings().weight.data

        # Mean over the original rows only (the new rows sit at the tail).
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg


# MMLU subject -> subcategory mapping (standard MMLU taxonomy).
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Subcategory -> top-level category grouping (standard MMLU taxonomy).
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

choices = ["A", "B", "C", "D"]


def format_subject(subject):
    """'high_school_math' -> ' high school math' (note the leading space,
    which gen_prompt relies on when splicing the subject into a sentence)."""
    # Equivalent to the original accumulation loop: every part is prefixed
    # with a single space, so the result keeps its leading space.
    return "".join(" " + entry for entry in subject.split("_"))


def format_example(df, idx, include_answer=True):
    """Format row `idx` of an MMLU frame as a question with lettered choices.

    Frame layout: column 0 is the question, columns 1..k are the choices,
    column k+1 is the gold answer letter.
    """
    prompt = df.iloc[idx, 0]
    k = df.shape[1] - 2  # number of answer choices present in this frame
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1])
    return prompt


def gen_prompt(train_df, subject, k=-1):
    """Build a k-shot MMLU prompt header from `train_df` (k=-1 uses all rows)."""
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)
    return prompt


@torch.no_grad()
def evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df):
    """Few-shot evaluate `model` on one MMLU subject.

    Returns:
        (cors, acc, all_probs): per-question correctness array, mean
        accuracy, and the softmaxed A/B/C/D probabilities per question.
    """
    cors = []
    all_probs = []

    for i in tqdm(range(test_df.shape[0]), desc=subject):
        # Shrink the number of in-context examples until the prompt fits the
        # 2048-token context window.
        k = ntrain
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = gen_prompt(dev_df, subject, k) + prompt_end
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        while input_ids.shape[-1] > 2048:
            k -= 1
            prompt = gen_prompt(dev_df, subject, k) + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        label = test_df.iloc[i, test_df.shape[1] - 1]

        # Next-token logits after the prompt; compare only the four option
        # letters (last sub-token of each) and softmax over them.
        logits = model(input_ids=input_ids).logits[0, -1]
        option_logits = [logits[tokenizer(c).input_ids[-1]] for c in choices]
        probs = (
            torch.nn.functional.softmax(torch.tensor(option_logits).float(), dim=0)
            .detach()
            .cpu()
            .numpy()
        )
        pred = choices[int(np.argmax(probs))]

        cors.append(pred == label)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs


def eval_mmlu(label, model, tokenizer, ntrain, data_dir, save_dir):
    """Run the full MMLU benchmark and dump per-subject CSVs plus a summary.

    Args:
        label: run label used in output file/dir names ('/' is sanitized).
        model: HF model under test.
        tokenizer: its tokenizer.
        ntrain: number of few-shot examples taken from each dev split.
        data_dir: MMLU csv root containing 'dev' and 'test' subdirectories.
        save_dir: parent directory for the auto-numbered run directory.

    Returns:
        dict with per-subcategory, per-category and weighted accuracies,
        plus total wall-clock cost; also pickled/JSON-dumped via easy_dump.
    """
    mmlu_out_dest = make_run_dir(save_dir, "mmlu_eval")

    subjects = sorted(
        f.split("_test.csv")[0]
        for f in os.listdir(os.path.join(data_dir, "test"))
        if "_test.csv" in f
    )

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None
        )[:ntrain]
        test_df = pd.read_csv(
            os.path.join(data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, probs = evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df)

        # Bucket this subject's per-question results into its subcategories
        # and their parent categories.
        for subcat in subcategories[subject]:
            subcat_cors[subcat].append(cors)
            for key in categories:
                if subcat in categories[key]:
                    cat_cors[key].append(cors)
        all_cors.append(cors)

        # Persist per-question correctness and option probabilities.
        test_df["{}_correct".format(label)] = cors
        for j in range(probs.shape[1]):
            choice = choices[j]
            test_df["{}_choice{}_probs".format(label, choice)] = probs[:, j]
        results_dir = os.path.join(mmlu_out_dest, "results_{}".format(label.split("/")[-1]))
        os.makedirs(results_dir, exist_ok=True)
        test_df.to_csv(os.path.join(results_dir, "{}.csv".format(subject)), index=None)

    results = {"subcategories": {}, "categories": {}}
    for subcat in subcat_cors:
        subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
        results["subcategories"][subcat] = subcat_acc
        print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat in cat_cors:
        cat_acc = np.mean(np.concatenate(cat_cors[cat]))
        results["categories"][cat] = cat_acc
        print("Average accuracy {:.3f} - {}".format(cat_acc, cat))
    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    results["cost_time"] = time.time() - start_time
    easy_dump(results, mmlu_out_dest, "accuracies_{}".format(label.replace("/", "_")))
    return results


def eval_ppl(model, tokenizer):
    """Compute wikitext-2 test perplexity with a non-overlapping sliding window.

    The stride equals the model's max context length, so every token is
    scored exactly once; tokens before the window's fresh region are masked
    with -100 so they do not contribute to the loss.
    """
    model.eval()
    max_length = model.config.max_position_embeddings
    stride = max_length
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
    encodings = tokenizer("\n\n".join(test), return_tensors="pt")
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # mask tokens already scored

        with torch.no_grad():
            # loss is calculated using CrossEntropyLoss which averages over
            # valid labels. N.B. the model only calculates loss over
            # trg_len - 1 labels, because it internally shifts the labels to
            # the left by 1.
            outputs = model(input_ids, labels=target_ids)
        nlls.append(outputs.loss)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    return torch.exp(torch.stack(nlls).mean())