Commit

update
SuperBruceJia committed Dec 16, 2024
1 parent b791841 commit db7dee4
Showing 18 changed files with 140 additions and 39 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

12 changes: 12 additions & 0 deletions .idea/PodGPT.iml

63 changes: 63 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

7 changes: 7 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

6 changes: 6 additions & 0 deletions .idea/vcs.xml

2 changes: 1 addition & 1 deletion lib/model_loader_large.py
@@ -98,7 +98,7 @@ def model_loader(config):
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
-# Please note that the current vLLM is not supporting
+# Please note that the current vLLM is not supporting
# the modules "w1", "w2", "w3", and "gate" at this point (June 20, 2024)
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj"
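For orientation, the LoRA setup this hunk belongs to can be written with plain peft as below. This is a minimal sketch: the r, lora_alpha, and lora_dropout values are assumptions, not the repository's configured values.

```python
# Minimal sketch of the LoRA configuration this hunk edits.
# r / lora_alpha / lora_dropout are assumed values, not the repo's settings.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,               # assumed LoRA rank
    lora_alpha=32,      # assumed scaling factor
    lora_dropout=0.05,  # assumed dropout
    bias="none",
    # Per the comment in the diff: vLLM (as of June 20, 2024) does not
    # support adapters on "w1", "w2", "w3", or "gate", so only the
    # attention projections are targeted.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```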
7 changes: 4 additions & 3 deletions lib/model_loader_quantization.py
@@ -8,7 +8,7 @@
import os

from transformers import AutoTokenizer, TrainingArguments
-from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig, get_gptq_peft_model
+from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
from auto_gptq.utils.peft_utils import GPTQLoraConfig
from peft import TaskType
from trl import SFTTrainer
@@ -34,7 +34,7 @@ def model_initializer(config):
model = AutoGPTQForCausalLM.from_quantized(
model_name,
# Since we are using the auto-gptq==0.6.0,
-# We cannot use shard safetensors and here we just use the single 39.8GB single-safetensor checkpoint.
+# We cannot use shard safetensors and here we just use the single 39.8GB single-safetensor checkpoint.
# https://huggingface.co/shuyuej/Llama-3.3-70B-Instruct-GPTQ/tree/f77c1b3864179c38146f12656804b5b3dfd1e2a2
revision="f77c1b3",
use_safetensors=True,
@@ -51,7 +51,8 @@ def model_initializer(config):
model.warmup_triton()

# https://gist.github.com/eusip/de8fadb761741b56d5d9a6232bf979ed#file-oasst-pythia-12b-05-03-2023-py-L68-L87
-# NOTE: https://github.com/lvwerra/trl/blob/a2749d9e0c96198486b788875eda3b325f76a5c8/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py#L181
+# https://github.com/lvwerra/trl/blob/a2749d9e0c96198486b788875eda3b325f76a5c8/examples/sentiment/scripts/
+# gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py#L181
for param in model.parameters():
# freeze base model's layers
param.requires_grad = False
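A hedged sketch of how the imports above typically fit together in auto-gptq 0.6.0: load the pinned single-safetensor checkpoint, then wrap it for LoRA training. Only the revision pin comes from the diff; the LoRA hyperparameters, and the exact get_gptq_peft_model keyword arguments, are assumptions.

```python
# Sketch only: wiring a pre-quantized GPTQ checkpoint into LoRA training
# with auto-gptq 0.6.0. LoRA hyperparameters below are placeholders.
from auto_gptq import AutoGPTQForCausalLM, get_gptq_peft_model
from auto_gptq.utils.peft_utils import GPTQLoraConfig
from peft import TaskType

model = AutoGPTQForCausalLM.from_quantized(
    "shuyuej/Llama-3.3-70B-Instruct-GPTQ",
    revision="f77c1b3",   # pin the single-safetensor checkpoint, as in the diff
    use_safetensors=True,
)
peft_config = GPTQLoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,                 # assumed rank
    lora_alpha=32,        # assumed scaling factor
    lora_dropout=0.05,    # assumed dropout
)
# train_mode=True is assumed to be the right flag for training here
model = get_gptq_peft_model(model, peft_config=peft_config, train_mode=True)
```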
1 change: 0 additions & 1 deletion main_large.py
@@ -90,4 +90,3 @@ def main(config):
print(yaml.dump(config, default_flow_style=False), '\n\n')
main(config=config)
sys.stdout = sys.__stdout__
-

1 change: 0 additions & 1 deletion main_quantization.py
@@ -90,4 +90,3 @@ def main(config):
print(yaml.dump(config, default_flow_style=False), '\n\n')
main(config=config)
sys.stdout = sys.__stdout__
-

1 change: 0 additions & 1 deletion main_small.py
@@ -96,4 +96,3 @@ def main(config):
print(yaml.dump(config, default_flow_style=False), '\n\n')
main(config=config)
sys.stdout = sys.__stdout__
-

11 changes: 2 additions & 9 deletions quantization/quantization.py
@@ -25,7 +25,6 @@

####################################################################################

-import time
import os
import logging
import argparse
@@ -109,7 +108,6 @@ def quantization(model_dir, output_dir, quantdataset, bits, group_size, desc_act
raise ValueError(f"Unsupported dtype: {dtype}")

# Load the model with specified quantization settings
-logger.info(f"Loading model from {model_dir} with trust_remote_code={trust_remote_code} and dtype={torch_dtype}")
model = AutoGPTQForCausalLM.from_pretrained(
model_dir,
quantize_config=quantize_config,
@@ -119,15 +117,10 @@
)

# Perform the quantization process
-logger.info(f"Starting quantization to {output_dir} with use_triton={use_triton}")
-start_time = time.time()
model.quantize(quantdataset, use_triton=use_triton, batch_size=batch_size)
-logger.info(f"Time to quantize model at {output_dir} with use_triton={use_triton}: {time.time() - start_time:.2f}")

# Save the quantized model
-logger.info(f"Saving quantized model to {output_dir}")
model.save_quantized(output_dir, use_safetensors=True)
-logger.info("Done.")


def mian(args):
@@ -198,12 +191,12 @@ def mian(args):
logger.error(f"Aborted. Will delete {output_dir}")
os.rmdir(output_dir)
abort = True
-except:
+except Exception:
raise
finally:
count += 1
else:
-logger.error(f"Aborting - told to stop!")
+logger.error("Aborting - told to stop!")
break


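The hunks above strip timing and logging from a function that otherwise follows the standard auto-gptq flow. A condensed, hedged sketch of that load → quantize → save sequence; the model id, bit-width, group size, and calibration text are placeholders, not the script's actual arguments.

```python
# Condensed sketch of the quantize-and-save flow in quantization.py.
# Model id, bits, group_size, and the calibration text are assumptions.
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("org/base-model")
# quantdataset: tokenized calibration examples fed to the GPTQ solver
quantdataset = [tokenizer("example calibration text", return_tensors="pt")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=True)
model = AutoGPTQForCausalLM.from_pretrained("org/base-model", quantize_config=quantize_config)
model.quantize(quantdataset, use_triton=False, batch_size=1)
model.save_quantized("output_dir", use_safetensors=True)
```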
6 changes: 3 additions & 3 deletions quantization/quantization_GPTQModel.py
@@ -138,15 +138,15 @@ def mian(args):
)
except KeyboardInterrupt:
# Handle user interrupt
-logger.error(f"Aborted. Will delete {output_dir}")
+logger.error("Aborted. Will delete {output_dir}")
os.rmdir(output_dir)
abort = True
-except:
+except Exception:
raise
finally:
count += 1
else:
-logger.error(f"Aborting - told to stop!")
+logger.error("Aborting - told to stop!")
break


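Both quantization scripts replace a bare except with except Exception. The distinction matters: a bare except also traps BaseException subclasses such as KeyboardInterrupt and SystemExit, which the surrounding code either handles explicitly or wants to propagate. A self-contained illustration (risky_operation is a hypothetical stand-in, not code from the repo):

```python
# A bare "except:" would also swallow KeyboardInterrupt / SystemExit;
# "except Exception:" lets those propagate while covering ordinary errors.
def risky_operation():
    # Hypothetical stand-in for the quantization call in the scripts above
    raise ValueError("simulated failure")

try:
    risky_operation()
except KeyboardInterrupt:
    print("Aborted by user")         # mirrors the cleanup-and-abort branch
except Exception as exc:
    print(f"Ordinary error: {exc}")  # the scripts re-raise at this point
```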
7 changes: 4 additions & 3 deletions quantization/quantization_HF.py
@@ -6,12 +6,13 @@
# PodGPT: An Audio-augmented Large Language Model for Research and Education
# Copyright (C) 2024 Kolachalama Laboratory at Boston University

import os
import argparse
import json

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
from huggingface_hub import login

from utils.utils import load_config

@@ -104,7 +105,7 @@ def main(repo, bits, group_size, act_order, hf_read_token):
"weight_map": {key: "model.safetensors" for key in state_dict.keys()}, # Map all weights to a single file
}

-index_file_path = os.path.join(model_save_path, "model.safetensors.index.json")
+index_file_path = os.path.join(f"{repo}_{bits}bit", "model.safetensors.index.json")
with open(index_file_path, "w") as f:
json.dump(index, f, indent=2)
print("Saved index file to", index_file_path)
@@ -123,7 +124,7 @@
# Load the configuration
config = load_config(file_name="config_quantization.yml")
hf_read_token = config.get("hf_read_token")

# Conduct the GPTQ quantization
main(
config=config,
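quantization_HF.py goes through transformers' built-in GPTQ integration rather than calling auto-gptq directly. A hedged sketch of that path; the repo id, bit-width, group size, and calibration dataset are placeholders, not the script's actual arguments:

```python
# Hedged sketch of HF-native GPTQ quantization via transformers.GPTQConfig.
# The repo id, bits, group_size, and dataset here are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

repo = "org/base-model"
tokenizer = AutoTokenizer.from_pretrained(repo)
gptq_config = GPTQConfig(bits=4, group_size=128, dataset="c4", tokenizer=tokenizer)

# Quantization runs inside from_pretrained when a GPTQConfig is supplied.
model = AutoModelForCausalLM.from_pretrained(
    repo, quantization_config=gptq_config, device_map="auto"
)
model.save_pretrained("model_4bit")
tokenizer.save_pretrained("model_4bit")
```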
26 changes: 13 additions & 13 deletions utils/answer_utils.py
@@ -374,7 +374,7 @@ def extract_answer(completion, option_range="a-eA-E"):
re.compile(rf'would be[^{potential_letters}]*\{{([{option_range}])\}}'),
re.compile(rf'would be[^{potential_letters}]*([{option_range}])\)'),
re.compile(rf'would be[^{potential_letters}]*([{option_range}])$'),
-
+
# Matches "is (A)" and similar formats
re.compile(
rf'is[^{potential_letters}]*:+[^{potential_letters}]*\n+[^{potential_letters}]*\(([{option_range}])\)'
@@ -392,15 +392,15 @@ def extract_answer(completion, option_range="a-eA-E"):
rf'is[^{potential_letters}]*:+[^{potential_letters}]*\n+[^{potential_letters}]*([{option_range}])\)'
),
re.compile(rf'is[^{potential_letters}]*\n+[^{potential_letters}]*([{option_range}])\)'),
-
+
# Matches "be (A)" and similar formats
re.compile(rf'is[^{letter_and_num}]+([{option_range}])\)'),
re.compile(rf'be[^{letter_and_num}]+([{option_range}])\)'),
re.compile(rf'[^{letter_and_num}]+([{option_range}])\)[^{potential_letters}]*is'),
re.compile(rf'[^{letter_and_num}]+([{option_range}])\)[^{potential_letters}]*would'),
re.compile(rf'[^{letter_and_num}]+([{option_range}])\)[^{potential_letters}]*could'),
re.compile(rf'[^{letter_and_num}]+([{option_range}])\)[^{potential_letters}]*will'),
-
+
# Matches "(A)" followed by any other characters
re.compile(rf':+[^{letter_and_num}]*([{option_range}])\)[^{potential_letters}]'),
re.compile(rf':+[^{letter_and_num}]*([{option_range}])\)$'),
@@ -460,7 +460,7 @@ def extract_answer(completion, option_range="a-eA-E"):
additional_patterns = [
# Matches "A"
re.compile(rf"^[^{letter_and_num}]*([{option_range}])[^{letter_and_num}]*$"),
-
+
# Matches "(A) is", "[A] is", "{A} is", and similar formats
re.compile(rf'\(([{option_range}])\)[^{potential_letters}]*is'),
re.compile(rf'\[([{option_range}])\][^{potential_letters}]*is'),
@@ -472,7 +472,7 @@ def extract_answer(completion, option_range="a-eA-E"):
),
re.compile(rf'^([{option_range}])\)[^{potential_letters}]*is'),
re.compile(rf'^([{option_range}])[^{letter_and_num}][^{potential_letters}]*is'),
-
+
# Matches "(A) would", "[A] would", "{A} would", and similar formats
re.compile(rf'\(([{option_range}])\)[^{potential_letters}]*would'),
re.compile(rf'\[([{option_range}])\][^{potential_letters}]*would'),
@@ -484,7 +484,7 @@ def extract_answer(completion, option_range="a-eA-E"):
),
re.compile(rf'^([{option_range}])\)[^{potential_letters}]*would'),
re.compile(rf'^([{option_range}])[^{letter_and_num}][^{potential_letters}]*would'),
-
+
# Matches "(A) could", "[A] could", "{A} could", and similar formats
re.compile(rf'\(([{option_range}])\)[^{potential_letters}]*could'),
re.compile(rf'\[([{option_range}])\][^{potential_letters}]*could'),
@@ -496,7 +496,7 @@ def extract_answer(completion, option_range="a-eA-E"):
),
re.compile(rf'^([{option_range}])\)[^{potential_letters}]*could'),
re.compile(rf'^([{option_range}])[^{letter_and_num}][^{potential_letters}]*could'),
-
+
# Matches "(A) will", "[A] will", "{A} will", and similar formats
re.compile(rf'\(([{option_range}])\)[^{potential_letters}]*will'),
re.compile(rf'\[([{option_range}])\][^{potential_letters}]*will'),
@@ -508,7 +508,7 @@ def extract_answer(completion, option_range="a-eA-E"):
),
re.compile(rf'^([{option_range}])\)[^{potential_letters}]*will'),
re.compile(rf'^([{option_range}])[^{letter_and_num}][^{potential_letters}]*will'),
-
+
# Matches "option: (A)" and similar formats
re.compile(rf'[oO]ption:+[^{potential_letters}]*\(([{option_range}])\)'),
re.compile(rf'[oO]ption:+[^{potential_letters}]*\[([{option_range}])\]'),
@@ -531,7 +531,7 @@ def extract_answer(completion, option_range="a-eA-E"):
rf'{letter_and_num}]'
),
re.compile(rf'[oO]ption:+[^{potential_letters}]*[^{letter_and_num}]([{option_range}])$'),
-
+
# Matches "choice: (A)" and similar formats
re.compile(rf'[cC]hoice:+[^{potential_letters}]*\(([{option_range}])\)'),
re.compile(rf'[cC]hoice:+[^{potential_letters}]*\[([{option_range}])\]'),
@@ -554,7 +554,7 @@ def extract_answer(completion, option_range="a-eA-E"):
rf'{letter_and_num}]'
),
re.compile(rf'[cC]hoice:+[^{potential_letters}]*[^{letter_and_num}]([{option_range}])$'),
-
+
# Matches "answer: (A)" and similar formats
re.compile(rf' is[^{potential_letters}]+\(([{option_range}])\)[^{potential_letters}]'),
re.compile(rf' is[^{potential_letters}]+\[([{option_range}])\][^{potential_letters}]'),
@@ -580,7 +580,7 @@ def extract_answer(completion, option_range="a-eA-E"):
re.compile(rf' is[^{potential_letters}]+\{{([{option_range}])\}}'),
re.compile(rf' is[^{potential_letters}]*[^{letter_and_num}]([{option_range}])\)'),
re.compile(rf' is[^{letter_and_num}]*([{option_range}])\)'),
-
+
# Matches "choice (A)" and similar formats
re.compile(rf'[cC]hoice[^{potential_letters}]*\(([{option_range}])\)'),
re.compile(rf'[cC]hoice[^{potential_letters}]*\[([{option_range}])\]'),
@@ -603,7 +603,7 @@ def extract_answer(completion, option_range="a-eA-E"):
rf'{letter_and_num}]'
),
re.compile(rf'[cC]hoice[^{potential_letters}]*[^{letter_and_num}]([{option_range}])$'),
-
+
# Matches "answer (A)" and similar formats
re.compile(rf'[aA]nswer[^{potential_letters}]*\(([{option_range}])\)'),
re.compile(rf'[aA]nswer[^{potential_letters}]*\[([{option_range}])\]'),
@@ -625,7 +625,7 @@ def extract_answer(completion, option_range="a-eA-E"):
rf'{letter_and_num}]'
),
re.compile(rf'[aA]nswer[^{potential_letters}]*[^{letter_and_num}]([{option_range}])$'),
-
+
# Matches "option (A)" and similar formats
re.compile(rf'[Oo]ption[^{potential_letters}]*\(([{option_range}])\)'),
re.compile(rf'[Oo]ption[^{potential_letters}]*\[([{option_range}])\]'),
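All of these compiled patterns are attempted in priority order against a model completion, and the first capture group that matches is taken as the answer letter. A minimal sketch of that dispatch loop (the function name is illustrative, not the repository's API):

```python
# Minimal sketch: try each option-extraction pattern in priority order
# and return the first captured option letter. Illustrative only.
import re

def first_option(completion, patterns):
    for pattern in patterns:
        match = pattern.search(completion)
        if match:
            return match.group(1).upper()  # normalize to "A".."E"
    return None  # no pattern matched

# Example with a simplified pattern in the same spirit as those above:
patterns = [re.compile(r'answer is \(([A-E])\)')]
print(first_option("The answer is (C).", patterns))  # -> "C"
```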