diff --git a/configs/phase2_llama.py b/configs/phase2_llama.py new file mode 100644 index 000000000..447efd4b0 --- /dev/null +++ b/configs/phase2_llama.py @@ -0,0 +1,54 @@ +from mmengine.config import read_base +from opencompass.models import HuggingFaceCausalLM +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_retriever import ZeroRetriever + +# Import BOTH custom classes +from custom_dataset import LocalGSM8K, SimpleGSM8KEvaluator + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='llama-3.2-1b-instruct', + path='meta-llama/Llama-3.2-1B-Instruct', + tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', + model_kwargs=dict(device_map='auto'), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + pad_token='<|end_of_text|>' + ), + max_out_len=256, + batch_size=16, + run_cfg=dict(num_gpus=1), + ) +] + +datasets = [ + dict( + abbr='gsm8k_sample', + type=LocalGSM8K, + path='json', + reader_cfg=dict( + input_columns=['question'], + output_column='answer', + train_split='train' + ), + infer_cfg=dict( + prompt_template=dict( + type='PromptTemplate', + template="Question: {question}\nLet's think step by step.\nAnswer:" + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer) + ), + # FIX: Use our local evaluator class + eval_cfg=dict( + evaluator=dict(type=SimpleGSM8KEvaluator), + # We don't need a post-processor dict here because + # our custom class handles the parsing internally. + ) + ) +] + +work_dir = './outputs/phase2' diff --git a/configs/phase3_llada.py b/configs/phase3_llada.py new file mode 100644 index 000000000..92c808815 --- /dev/null +++ b/configs/phase3_llada.py @@ -0,0 +1,83 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import GSM8KDataset, gsm8k_postprocess, gsm8k_dataset_postprocess, Gsm8kEvaluator +from opencompass.models.llada import LLaDA # Your custom model + +# ========================================================= +# 1. INLINED GSM8K CONFIGURATION (No external imports needed) +# ========================================================= + +gsm8k_reader_cfg = dict( + input_columns=['question'], + output_column='answer', + test_range='[0:5]' # <--- LIMIT APPLIED HERE (First 50 questions) +) + +gsm8k_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt="Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\nAnswer:"), + dict(role='BOT', prompt='Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n'), + dict(role='HUMAN', prompt="Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\nAnswer:"), + dict(role='BOT', prompt="Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n"), + dict(role='HUMAN', prompt="Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\nAnswer:"), + dict(role='BOT', prompt="When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n"), + dict(role='HUMAN', prompt="Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\nAnswer:"), + dict(role='BOT', prompt='For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n'), + dict(role='HUMAN', prompt="Question: {question}\nLet's think step by step\nAnswer:"), + ], + )), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512) +) + +gsm8k_eval_cfg = dict( + evaluator=dict(type=Gsm8kEvaluator), + pred_postprocessor=dict(type=gsm8k_postprocess), + dataset_postprocessor=dict(type=gsm8k_dataset_postprocess) +) + +gsm8k_datasets = [ + dict( + abbr='gsm8k_5', # Renamed for clarity + type=GSM8KDataset, + path='opencompass/gsm8k', + reader_cfg=gsm8k_reader_cfg, + infer_cfg=gsm8k_infer_cfg, + eval_cfg=gsm8k_eval_cfg + ) +] + +# Set the datasets variable required by OpenCompass +datasets = [*gsm8k_datasets] + +# ========================================================= +# 2. MODEL CONFIGURATION +# ========================================================= +models = [ + dict( + type=LLaDA, + abbr='llada-8b-instruct', + path='GSAI-ML/LLaDA-8B-Instruct', + tokenizer_path='GSAI-ML/LLaDA-8B-Instruct', + + # LLaDA Specifics + steps=32, + gen_length=512, + block_length=128, + + # OpenCompass/HF Configs + max_out_len=512, + max_seq_len=2048, + batch_size=1, + run_cfg=dict(num_gpus=1, num_procs=1), + model_kwargs=dict( + device_map='auto', + torch_dtype='torch.bfloat16' + ) + ) +] + diff --git a/custom_dataset.py b/custom_dataset.py new file mode 100644 index 000000000..6472ae225 --- /dev/null +++ b/custom_dataset.py @@ -0,0 +1,68 @@ +import re +from opencompass.datasets import BaseDataset +from datasets import load_dataset +from opencompass.openicl.icl_evaluator import BaseEvaluator + +class LocalGSM8K(BaseDataset): + """ + A custom wrapper to strictly load the local GSM8K sample file. + """ + # FIX: Change signature to accept anything (*args, **kwargs) + # This prevents the "missing positional argument" error. + def load(self, *args, **kwargs): + return load_dataset( + 'json', + data_files='/workspace/llada_test_run/opencompass/data/gsm8k_sample.jsonl', + split='train' + ) + + +class SimpleGSM8KEvaluator(BaseEvaluator): + def score(self, predictions, references): + if len(predictions) != len(references): + return {'error': 'pred_ref_length_mismatch'} + + correct = 0 + total = len(predictions) + + print(f"\n--- DEBUGGING EVALUATOR ({total} samples) ---") + + for i, (pred, ref) in enumerate(zip(predictions, references)): + if isinstance(ref, list): ref = ref[0] + + # 1. Clean Reference + clean_ref = str(ref).split("####")[-1].strip() + clean_ref = clean_ref.replace(',', '') + + # 2. Clean Prediction + pred_str = str(pred) + + # FIX: Improved Regex + # r'-?\d+(?:\.\d+)?' + # -? : Optional negative sign + # \d+ : One or more digits + # (?:\.\d+)? : Optional group: A dot FOLLOWED BY digits. + # This ignores "72." but captures "72.5" + numbers = re.findall(r'-?\d+(?:\.\d+)?', pred_str) + + clean_pred = numbers[-1] if numbers else "NO_NUMBER_FOUND" + + # 3. Compare + # Use float comparison for robustness (72.0 == 72) + try: + is_match = float(clean_pred) == float(clean_ref) + except ValueError: + is_match = (clean_pred == clean_ref) + + if is_match: + correct += 1 + print(f"[Sample {i} PASSED] {clean_pred} == {clean_ref}") + else: + print(f"[Sample {i} FAILED] Expected: '{clean_ref}' | Got: '{clean_pred}'") + + print("-------------------------------------------\n") + return {'accuracy': (correct / total) * 100} + + + + diff --git a/opencompass/models/llada.py b/opencompass/models/llada.py new file mode 100644 index 000000000..0f0569bb2 --- /dev/null +++ b/opencompass/models/llada.py @@ -0,0 +1,123 @@ +import os +import sys +import torch +from opencompass.models import HuggingFace + +# 1. Import the OFFICIAL LLaDA Generation Loop +LLADA_REPO_PATH = os.path.abspath("/workspace/llada_test_run/LLaDA") +if LLADA_REPO_PATH not in sys.path: + sys.path.append(LLADA_REPO_PATH) + +try: + from generate import generate as llada_generate +except ImportError: + print(f"CRITICAL: Could not find 'generate.py' in {LLADA_REPO_PATH}") + +class LLaDA(HuggingFace): + """ + OpenCompass Wrapper for LLaDA 1.5 (Diffusion LLM). + """ + def __init__(self, + steps=64, + gen_length=128, + block_length=128, + tokenizer_path=None, + tokenizer_kwargs=None, + *args, + **kwargs): + + # Save attributes BEFORE calling super().__init__ + self.steps = steps + self.gen_length = gen_length + self.block_length = block_length + self.tokenizer_path = tokenizer_path + self.tokenizer_kwargs = tokenizer_kwargs or {} + + # Re-inject them into kwargs for the super class + if tokenizer_path: + kwargs['tokenizer_path'] = tokenizer_path + if tokenizer_kwargs: + kwargs['tokenizer_kwargs'] = tokenizer_kwargs + + super().__init__(*args, **kwargs) + + def _load_model(self, path, **kwargs): + from transformers import AutoModel, AutoTokenizer + + # -------------------------------------------------------- + # 1. LOAD TOKENIZER + # -------------------------------------------------------- + if 'trust_remote_code' in self.tokenizer_kwargs: + self.tokenizer_kwargs.pop('trust_remote_code') + + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_path, + trust_remote_code=True, + **self.tokenizer_kwargs + ) + + # -------------------------------------------------------- + # 2. LOAD MODEL + # -------------------------------------------------------- + + # [CRITICAL FIX] Unpack 'model_kwargs' dictionary. + # OpenCompass passes configuration inside this key, but the + # model constructor expects flat arguments. + nested_model_kwargs = kwargs.pop('model_kwargs', {}) or {} + kwargs.update(nested_model_kwargs) + + # Clean up other OpenCompass keys that AutoModel doesn't recognize + kwargs.pop('peft_path', None) + kwargs.pop('peft_kwargs', None) + + # Prevent "multiple values" error for trust_remote_code + if 'trust_remote_code' in kwargs: + kwargs.pop('trust_remote_code') + + # Convert string torch_dtype (from config) to actual torch object + if 'torch_dtype' in kwargs and isinstance(kwargs['torch_dtype'], str): + dtype_str = kwargs['torch_dtype'] + if dtype_str == 'torch.float16': + kwargs['torch_dtype'] = torch.float16 + elif dtype_str == 'torch.bfloat16': + kwargs['torch_dtype'] = torch.bfloat16 + elif dtype_str == 'torch.float32': + kwargs['torch_dtype'] = torch.float32 + + self.model = AutoModel.from_pretrained( + path, + trust_remote_code=True, + **kwargs + ) + + self.model.eval() + + def generate(self, inputs, max_out_len, **kwargs): + # 1. Handle Input + prompt_text = inputs[0] if isinstance(inputs, list) else inputs + + # 2. Tokenize + input_ids = self.tokenizer( + prompt_text, + return_tensors="pt" + ).input_ids.to(self.model.device) + + # 3. Dynamic Canvas Sizing + current_gen_len = max_out_len if max_out_len else self.gen_length + + # 4. Run Diffusion + out = llada_generate( + model=self.model, + prompt=input_ids, + steps=self.steps, + gen_length=current_gen_len, + block_length=self.block_length, + temperature=0.0, + cfg_scale=0.0, + remasking='low_confidence' + ) + + # 5. Decode + return self.tokenizer.batch_decode(out, skip_special_tokens=True) + +