54 changes: 54 additions & 0 deletions docs/examples/budget_forcing/budget_forcing_example.py
@@ -0,0 +1,54 @@
"""
Example usage of budget forcing in long-chain-of-thought reasoning tasks.

To run this script from the root of the Mellea source tree, use the command:
```
uv run python docs/examples/budget_forcing/budget_forcing_example.py
```
"""

from mellea import MelleaSession, start_session
from mellea.backends import ModelOption
from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B
from mellea.stdlib.sampling.budget_forcing import BudgetForcingSamplingStrategy


def solve_on_budget(
m_session: MelleaSession, prompt: str, thinking_budget: int = 512
) -> str:
"""Solves the problem in `prompt`, force-stopping thinking at `thinking_budget` tokens
(if reached), and returns the solution"""
# Sampling strategy for budget forcing: pass the thinking budget here
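    # If the model is still thinking after roughly `think_max_tokens` thinking tokens,
    # the thinking section is cut off at the end-think token and `answer_suffix` then
    # cues the final answer.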
strategy = BudgetForcingSamplingStrategy(
think_max_tokens=thinking_budget,
start_think_token="<think>",
end_think_token="</think>",
answer_suffix="The final answer is:",
requirements=None,
)

# Perform greedy decoding, not exceeding the thinking token budget
result = m_session.instruct(
prompt, strategy=strategy, model_options={ModelOption.TEMPERATURE: 0}
)
    # The solution contains (a) a thinking section within <think> and </think>
    # (possibly incomplete due to budget forcing), and (b) a final answer.
    output_str = str(result)

return output_str


# Create a Mellea session for granite-4.0-micro with an Ollama backend
m_session = start_session(backend_name="ollama", model_id=IBM_GRANITE_4_MICRO_3B)

# Demonstrate granite solving the same problem on various thinking budgets
prompt = "To double your investment in 5 years, what must your annual return be? Put your final answer within \\boxed{}."
different_thinking_budgets = [256, 64, 16] # max number of thinking tokens allowed
for thinking_budget in different_thinking_budgets:
solution = solve_on_budget(m_session, prompt, thinking_budget=thinking_budget)
header = f"MAX THINKING BUDGET: {thinking_budget} tokens"
print(f"{'-' * len(header)}\n{header}\n{'-' * len(header)}")
print(f"PROMPT: {prompt}")
print(f"\nSOLUTION: {solution}")
print(f"\n\nSOLUTION LENGTH: {len(solution)} characters")
print(f"{'-' * len(header)}\n\n")
83 changes: 83 additions & 0 deletions docs/examples/majority_voting/mbrd_example.py
@@ -0,0 +1,83 @@
"""
Example usage of minimum Bayes risk decoding (MBRD).

To run this script from the root of the Mellea source tree, use the command:
```
uv run python docs/examples/majority_voting/mbrd_example.py
```
"""

from mellea import MelleaSession
from mellea.backends.vllm import LocalVLLMBackend
from mellea.backends.types import ModelOption
from mellea.backends.model_ids import QWEN3_0_6B
from mellea.stdlib.sampling.majority_voting import MBRDRougeLStrategy

import os

# Opt out of vLLM's V1 engine (i.e., use the legacy V0 engine) for this example.
os.environ["VLLM_USE_V1"] = "0"


def solve_using_mbrd(
m_session: MelleaSession, prompt: str, num_samples: int = 8
) -> str:
"""Solves the problem in `prompt` by generating `num_samples` solutions and
selecting the one with the highest average RougeL with the rest"""
# Generate and select the MBR solution
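    # MBR selection: among the sampled candidates, pick the one with the highest
    # average ROUGE-L similarity to the other samples (the "most central" answer),
    # rather than trusting any single generation.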
result = m_session.instruct(
prompt,
strategy=MBRDRougeLStrategy(number_of_samples=num_samples, loop_budget=1),
model_options={
ModelOption.MAX_NEW_TOKENS: 1024,
ModelOption.SYSTEM_PROMPT: "Answer in English.",
},
return_sampling_results=True,
)
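    # With return_sampling_results=True the call returns a sampling-results object;
    # its `.result` field holds the sample selected by the MBRD strategy.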
raw_output = str(result.result).strip()

# Do any required post-processing (can be model-specific) and extract the final response
def postprocess(raw_output: str) -> str:
        # If the raw output starts with a thinking section, remove it so that the user
        # only sees the actual response that follows the closing `</think>` token
if "</think>" in raw_output:
return raw_output.split("</think>")[1].strip()
return raw_output

output = postprocess(raw_output)
return output


# Create a Mellea session for the target use case
max_samples = 8  # upper bound on parallel MBRD samples (also used for vLLM's max_num_seqs below)
backend = LocalVLLMBackend(
model_id=QWEN3_0_6B,
model_options={
"gpu_memory_utilization": 0.8,
"trust_remote_code": True,
"max_model_len": 2048,
"max_num_seqs": max_samples,
},
)
m_session = MelleaSession(backend)

# A few example prompts to test
a_science_prompt = "Why does metal feel colder to the touch than wood?"
a_psycholing_prompt = (
"Three reasons why children are better at learning languages than adults."
)
a_history_prompt = "Why was the Great Wall built?"
an_email_prompt = "We have an applicant for an intern position named Olivia Smith. I want to schedule a phone interview with her. Please draft a short email asking her about her availability."

# Let's use the email prompt in this demo
prompt = an_email_prompt

# Demonstrate how to use the MBRD feature
output = solve_using_mbrd(m_session, prompt, num_samples=8)
print(f"\n\nOutput:\n{output}")

# # Cleanup to avoid torch warning unrelated to MBRD (if needed)
# def torch_destroy_process_group():
# import torch.distributed as dist
# if dist.is_initialized():
# dist.destroy_process_group()
# torch_destroy_process_group()
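
For intuition, here is a minimal, self-contained sketch of the selection rule that minimum Bayes risk decoding applies, using a simple token-overlap score as a stand-in for ROUGE-L. This is an illustration only, not the implementation inside `MBRDRougeLStrategy`.

```python
# Illustrative MBR selection: pick the candidate closest on average to the others.
def similarity(a: str, b: str) -> float:
    # Token-overlap (Jaccard) proxy for ROUGE-L, purely for illustration.
    ta, tb = set(a.lower().split()), set(b.lower().split())
    return len(ta & tb) / max(len(ta | tb), 1)


def select_mbr(candidates: list[str]) -> str:
    # The MBR choice maximizes the average similarity to the remaining candidates.
    def mean_sim(c: str) -> float:
        others = [o for o in candidates if o is not c]
        return sum(similarity(c, o) for o in others) / max(len(others), 1)

    return max(candidates, key=mean_sim)


# The outlier candidate is not selected; one of the two agreeing answers wins.
print(select_mbr(["The answer is 42.", "42 is the answer.", "I refuse to answer."]))
```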