
No instruct dp #20

Open · wants to merge 8 commits into base: main

99 changes: 74 additions & 25 deletions cik_benchmark/baselines/direct_prompt.py
@@ -2,6 +2,7 @@
Direct prompt method

"""

import inspect
import logging
import numpy as np
@@ -96,23 +97,35 @@ def huggingface_instruct_model_client(

def constrained_decoding_regex(required_timestamps):
"""
Generates a regular expression to force the model output
to satisfy the required format and provide values for
all required timestamps

Generates a regex pattern for constrained decoding such that:
- <forecast> occurs at start (on its own line).
- For each required timestamp ts, the model must produce
(ts,NUMBER)
with NO extra whitespace, exactly as shown:
open paren, timestamp literal, comma, numeric value, close paren
- Then </forecast> at the end (on its own line).
"""
timestamp_regex = "".join(
[
r"\(\s*{}\s*,\s*[-+]?\d+(\.\d+)?\)\n".format(re.escape(ts))
for ts in required_timestamps
]
)
return r"<forecast>\n{}<\/forecast>".format(timestamp_regex)

# Build one pattern line per required timestamp:
# (YYYY-MM-DD HH:MM:SS,[-+]?\d{1,20}(?:\.\d{0,20})?)
# No spaces allowed anywhere, so everything is literally "fixed"
# except for the numeric portion.
lines = [
rf"\({re.escape(ts)},[-+]?\d{1,20}(?:\.\d{0,20})?\)"
for ts in required_timestamps
]

# Join lines with exactly one "\n".
body = r"\n".join(lines)

# Return the full pattern, ensuring a single newline
# after <forecast> and before </forecast>.
return rf"<forecast>\n{body}\n</forecast>"
Comment on lines +108 to +123

Collaborator:
Is there a need to replace the original function? This might break reproducibility.

Collaborator (Author):
The original function allowed for returns such as "\n \n \n \n \n \n \n ...", so I fixed it to this while working with Hymba. I think it would only break reproducibility if this code somehow fed into the RNG for the LLM or for the tasks, right?

Collaborator (Author):
@marcotet what are your thoughts?

Collaborator:
Hmm, all the results in our paper were produced with the previous code, so I guess it was fine; we didn't get any errors.

If you think it's better, it might be good to reproduce a model's results on CiK to confirm this doesn't break reproducibility. What do you think?

Collaborator:
Does mamba need this change?
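
To make the thread above concrete, here is a small, self-contained check (not part of the PR) comparing the two patterns on two hypothetical timestamps. It only assumes Python's standard re module; the timestamps and sample outputs are made up for illustration.

# Illustrative check only, not part of the diff.
import re

timestamps = ["2024-01-01 00:00:00", "2024-01-02 00:00:00"]

# Old pattern: the \s* groups tolerate arbitrary whitespace (including newlines)
# inside each (timestamp,value) tuple.
old_body = "".join(
    r"\(\s*{}\s*,\s*[-+]?\d+(\.\d+)?\)\n".format(re.escape(ts)) for ts in timestamps
)
old_pattern = r"<forecast>\n{}<\/forecast>".format(old_body)

# New pattern: everything is literal except the numeric value.
new_body = r"\n".join(
    rf"\({re.escape(ts)},[-+]?\d{{1,20}}(?:\.\d{{0,20}})?\)" for ts in timestamps
)
new_pattern = rf"<forecast>\n{new_body}\n</forecast>"

clean = "<forecast>\n(2024-01-01 00:00:00,3.5)\n(2024-01-02 00:00:00,-4)\n</forecast>"
padded = "<forecast>\n(\n2024-01-01 00:00:00 ,\n3.5)\n(2024-01-02 00:00:00,-4)\n</forecast>"

assert re.fullmatch(old_pattern, clean) and re.fullmatch(new_pattern, clean)
assert re.fullmatch(old_pattern, padded)          # old pattern accepts the padded output
assert re.fullmatch(new_pattern, padded) is None  # new pattern rejects it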


# Make generation pipeline
pipe = pipeline(
task="text-generation",
model=llm,
model=llm.cuda().to(torch.bfloat16),
tokenizer=tokenizer,
device_map="auto",
)
@@ -124,19 +137,55 @@ def constrained_decoding_regex(required_timestamps):
)
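
The prefix_function passed to the pipeline here is what enforces the regex during generation. As a rough sketch (not part of the diff), it could be built with lm-format-enforcer, which is listed in requirements.txt; the exact construction in the collapsed lines above may differ, and the call below assumes that library's transformers integration.

# Sketch only, assuming lm-format-enforcer's transformers integration;
# `tokenizer` and `required_timestamps` are taken from the surrounding code.
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn,
)

parser = RegexParser(constrained_decoding_regex(required_timestamps))
prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)
# At each decoding step, prefix_allowed_tokens_fn restricts sampling to tokens
# that can still complete the forecast pattern.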

# Now extract the assistant's reply
choices = []
for response in pipe(
[messages] * n,
max_length=max_tokens,
temperature=temperature,
prefix_allowed_tokens_fn=prefix_function,
batch_size=n,
):
# Create a message object
message = SimpleNamespace(content=response[0]["generated_text"][-1]["content"])
# Create a choice object
choice = SimpleNamespace(message=message)
choices.append(choice)

# TODO: verify
# If the pipeline's tokenizer has no chat template, concatenate the message
# contents into a plain prompt and complete that prompt with the model.
if getattr(pipe.tokenizer, "chat_template", None) is not None:
start_time = time.time()
choices = []
for response in pipe(
[messages] * n,
max_length=max_tokens,
temperature=temperature,
prefix_allowed_tokens_fn=prefix_function,
batch_size=n,
):
# Create a message object
message = SimpleNamespace(
content=response[0]["generated_text"][-1]["content"]
)
# Create a choice object
choice = SimpleNamespace(message=message)
choices.append(choice)
print(f"Time taken for completion: {time.time() - start_time}")

else:
# Get the context from the messages
context = ""
for message in messages:
context += message["content"] + " " # directly concatenate the context

# Generate completions
choices = []
start_time = time.time()
responses = pipe(
[context] * n,
temperature=temperature,
prefix_allowed_tokens_fn=prefix_function,
batch_size=n,
max_new_tokens=max_tokens,
)
print(f"Time taken for completion: {time.time() - start_time}")

for response in responses:
# Create a message object
message = SimpleNamespace(
content=response[0]["generated_text"][len(context) :].strip()
)
# Create a choice object
choice = SimpleNamespace(message=message)
choices.append(choice)

# Create a usage object (we can estimate tokens)
usage = SimpleNamespace(
67 changes: 67 additions & 0 deletions cik_benchmark/baselines/hf_utils/dp_hf_api.py
@@ -5,6 +5,10 @@
AutoTokenizer,
LlamaTokenizerFast,
MistralForCausalLM,
# MambaConfig,
# MambaForCausalLM,
# Mamba2Config,
# Mamba2ForCausalLM,
)
import torch
import gc
@@ -35,6 +39,29 @@
"Mistral-7B-Instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
"falcon-7b-instruct": "tiiuae/falcon-7b-instruct",
"falcon-40b-instruct": "tiiuae/falcon-40b-instruct",
# SSMs
# "Hymba-1.5B-Base": "nvidia/Hymba-1.5B-Base", # needs FlexAttention installation
# "Hymba-1.5B-Instruct": "nvidia/Hymba-1.5B-Instruct", # needs FlexAttention installation
# "mamba-2.8B": "state-spaces/mamba-2.8b-hf",
# "mamba-1.4B": "state-spaces/mamba-1.4b-hf",
# "mamba-790m": "state-spaces/mamba-790m-hf",
# "mamba-370m": "state-spaces/mamba-370m-hf",
# "mamba-130m": "state-spaces/mamba-130m-hf",
# "mamba2-2.7B": "state-spaces/mamba2-2.7b",
# "mamba2-1.3B": "state-spaces/mamba2-1.3b",
# "mamba2-780m": "state-spaces/mamba2-780m",
# "mamba2-370m": "state-spaces/mamba2-370m",
# "mamba2-130m": "state-spaces/mamba2-130m",
# "Zamba-7B-v1": "Zyphra/Zamba-7B-v1",
# "Zamba2-7B": "Zyphra/Zamba2-7B",
# "Zamba2-2.7B": "Zyphra/Zamba2-2.7B",
# "Zamba2-1.2B": "Zyphra/Zamba2-1.2B",
# "Zamba2-7B-Instruct": "Zyphra/Zamba2-7B-Instruct",
# "Zamba2-2.7B-Instruct": "Zyphra/Zamba2-2.7B-Instruct",
# "Zamba2-1.2B-Instruct": "Zyphra/Zamba2-1.2B-Instruct",
# "Falcon3-Mamba-7B-Base": "tiiuae/Falcon3-Mamba-7B-Base",
# "Falcon3-Mamba-7B-Instruct": "tiiuae/Falcon3-Mamba-7B-Instruct",
# "Bamba-9B": "ibm-fms/Bamba-9B",
}


@@ -72,6 +99,16 @@ def get_tokenizer(llm_path, llm_type):
tokenizer = AutoTokenizer.from_pretrained(llm_path)
elif "falcon" in llm_type:
tokenizer = AutoTokenizer.from_pretrained(llm_path)
# elif "mamba-" in llm_type:
# tokenizer = AutoTokenizer.from_pretrained(llm_path)
# elif "Zamba2-" in llm_type:
# tokenizer = AutoTokenizer.from_pretrained(llm_path)
# elif "Zamba-" in llm_type:
# tokenizer = AutoTokenizer.from_pretrained(llm_path)
# elif "Hymba-" in llm_type:
# tokenizer = AutoTokenizer.from_pretrained(
# llm_path, trust_remote_code=True, parallelism="none"
# )
else:
assert False

@@ -133,6 +170,36 @@ def get_model_and_tokenizer(llm_path, llm_type):
model = MistralForCausalLM.from_pretrained(
llm_path, torch_dtype=torch.bfloat16, device_map="auto"
)
# elif "mamba-" in llm_type:
# model = MambaForCausalLM.from_pretrained(
# llm_path,
# device_map="auto",
# torch_dtype=torch.float16,
# )
# elif "mamba2-" in llm_type:
# model = Mamba2ForCausalLM.from_pretrained(
# llm_path,
# device_map="auto",
# torch_dtype=torch.float16,
# )
# elif "Zamba2-" in llm_type:
# model = AutoModelForCausalLM.from_pretrained(
# llm_path,
# device_map="auto",
# torch_dtype=torch.bfloat16,
# )
# elif "Zamba-" in llm_type:
# model = AutoModelForCausalLM.from_pretrained(
# llm_path,
# device_map="auto",
# torch_dtype=torch.bfloat16,
# )
# elif "Hymba-" in llm_type:
# model = AutoModelForCausalLM.from_pretrained(
# llm_path,
# trust_remote_code=True,
# )
# model = model.cuda().to(torch.bfloat16)
else:
assert False

14 changes: 11 additions & 3 deletions experiments/direct-prompt-models/qwen_7b_instruct_ctx_g2.json
@@ -1,4 +1,12 @@
[
{"label": "CC-Qwen-2.5-7B-Instruct (ctx)", "method": "directprompt", "llm": "qwen2.5-7B-Instruct", "use_context": true, "temperature": 1.0,
"batch_size_on_retry":10, "batch_size":10, "n_retries": 10}
]
{
"label": "CC-Qwen-2.5-7B-Instruct (ctx)",
"method": "directprompt",
"llm": "qwen2.5-7B-Instruct",
"use_context": true,
"temperature": 1.0,
"batch_size_on_retry": 10,
"batch_size": 10,
"n_retries": 10
}
]
7 changes: 5 additions & 2 deletions experiments/statistical-models/statsmodels_c40.json
@@ -1,3 +1,6 @@
[
{"label": "Statsmodels", "method": "statsmodels"}
]
{
"label": "Statsmodels",
"method": "statsmodels"
}
]
1 change: 1 addition & 0 deletions requirements.txt
@@ -19,5 +19,6 @@ termcolor
tenacity
h5py
transformers>4.4.1
tokenizers

Collaborator:
Was this necessary for mamba or something?

Collaborator (Author):
Yeah, I can remove it, though, since I didn't add the mamba models.

Collaborator:
Yeah, that'd be great cause it might break the already-fragile requirements otherwise :D Unless you tested that it didn't :)

sentencepiece
lm-format-enforcer
6 changes: 5 additions & 1 deletion run_baselines.py
@@ -10,8 +10,10 @@
import numpy as np
import pandas as pd


from collections import defaultdict
from pathlib import Path

from cik_benchmark.baselines.direct_prompt import DirectPrompt
from cik_benchmark.baselines.lag_llama import lag_llama
from cik_benchmark.baselines.chronos import ChronosForecaster
@@ -207,7 +209,9 @@ def experiment_directprompt(
dp_forecaster = DirectPrompt(
model=llm,
use_context=use_context,
token_cost=openai_costs[llm] if llm in openai_costs else {"input": 0.0, "output": 0.0}, # Cost only used for OpenAI models
token_cost=(
openai_costs[llm] if llm in openai_costs else {"input": 0.0, "output": 0.0}
), # Cost only used for OpenAI models
batch_size=batch_size,
batch_size_on_retry=batch_size_on_retry,
n_retries=n_retries,