Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
outputs/

# Prime-RL (separate repo with trained models)
prime-rl/

# CSV data exports
*.csv

# Generated plots
plots/

# Eval results
eval_results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
Expand Down
72 changes: 72 additions & 0 deletions configs/medcalc_bench_8b_nothink_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# MedCalc-Bench RL Training Configuration (8B policy, no <think> tags)
# Scaled-up sibling of configs/medcalc_bench_rl.toml.

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-verified-8b"
name = "medcalc-bench-verified-8b-nothink"

[ckpt]
# Checkpoint at the end of training

[model]
# Fix: there is no "Qwen/Qwen3-8B-Instruct-2507" release — the 2507 Instruct
# line only ships 4B / 30B-A3B / 235B-A22B sizes. Use the hybrid Qwen3-8B
# checkpoint; reasoning is already disabled below via use_think = false,
# matching this config's "nothink" intent.
name = "Qwen/Qwen3-8B"

[orchestrator]
batch_size = 128 # Reduced from 512 (4B config) for the 8B model
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
[orchestrator.env.args]
one_shot = false
add_python_tool = false
use_think = false # No thinking tags
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 1024
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

# Eval env args mirror the training env args above
[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = false # No thinking tags
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy decoding for deterministic evaluation
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0 # Conventional 2x-rank scaling

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Tensor-parallel across 2 GPUs per replica
dp = 3 # 3 replicas -> uses all 6 inference GPUs
73 changes: 73 additions & 0 deletions configs/medcalc_bench_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# MedCalc-Bench RL Training Configuration
# Clinical calculator reasoning with numeric/date outputs

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-rl-1"
name = "medcalc-bench-rl-1"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 512
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
# NOTE(review): unlike the sibling configs, use_verified_dataset is not set
# here — presumably this run trains on the original (unverified) dataset.
# Confirm the environment's default is the intended one.
[orchestrator.env.args]
one_shot = false # Zero-shot prompting
add_python_tool = false # No code execution
use_think = true # Enable <think> reasoning tags
answer_format = "xml" # Use XML format for answers

[orchestrator.sampling]
max_tokens = 2048 # Larger budget than the short config; <think> traces are long
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

# Eval env args mirror the training env args above
[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = true
answer_format = "xml"

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy sampling for evaluation
max_tokens = 2048

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0 # Conventional 2x-rank scaling

[trainer.optim]
lr = 1e-5 # LoRA typically benefits from higher LR

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Split model across 2 GPUs for faster per-query latency
dp = 3 # 3 parallel instances -> uses all 6 inference GPUs
72 changes: 72 additions & 0 deletions configs/medcalc_bench_short_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# MedCalc-Bench RL Training Configuration (short-context variant)
# Same shape as medcalc_bench_rl.toml, but on MedCalc-Bench-Verified with a
# halved 1024-token sampling budget (hence the "2048tok"-total run name below).

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 4096

[wandb]
project = "medcalc-bench-verified"
name = "medcalc-bench-verified-2048tok"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 256
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench
[orchestrator.env.args]
one_shot = false
add_python_tool = false
use_think = true # Keep <think> tags despite the smaller token budget
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 1024 # Half the budget of medcalc_bench_rl.toml
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

# Eval env args mirror the training env args above
[orchestrator.eval.env.args]
one_shot = false
add_python_tool = false
use_think = true
answer_format = "xml"
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy decoding for deterministic evaluation
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0 # Conventional 2x-rank scaling

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Tensor-parallel across 2 GPUs per replica
dp = 3 # 3 replicas -> uses all 6 inference GPUs
77 changes: 77 additions & 0 deletions configs/medcalc_bench_tools_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# MedCalc-Bench RL Training Configuration (with Python Tool)
# Clinical calculator reasoning with code execution

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 8192 # Longer than the no-tool configs (4096) for multi-turn tool use

[wandb]
project = "medcalc-bench-tools"
name = "medcalc-bench-tools-rl"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 256 # Smaller batch for multi-turn (more tokens per example)
rollouts_per_example = 8 # Halved vs. the no-tool configs (16) for the same reason

[[orchestrator.env]]
id = "medcalc_bench"

# Environment arguments for MedCalc-Bench with tools
[orchestrator.env.args]
one_shot = false
add_python_tool = true # Enable Python tool for calculations
use_think = true
answer_format = "xml"
max_turns = 10 # Allow multiple tool calls
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.sampling]
max_tokens = 2048 # Per-turn budget; total context bounded by seq_len above
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcalc_bench"
num_examples = 300
rollouts_per_example = 1

# Eval env args mirror the training env args above
[orchestrator.eval.env.args]
one_shot = false
add_python_tool = true # Tools enabled for eval too
use_think = true
answer_format = "xml"
max_turns = 10
use_verified_dataset = true # Use MedCalc-Bench-Verified

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy decoding for deterministic evaluation
max_tokens = 2048

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0 # Conventional 2x-rank scaling

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Tensor-parallel across 2 GPUs per replica
dp = 3 # 3 replicas -> uses all 6 inference GPUs
72 changes: 72 additions & 0 deletions configs/medcasereasoning_rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# MedCaseReasoning RL Training Configuration
# Medical diagnosis from case presentations using LLM-as-a-Judge

inference_gpu_ids = [0, 1, 2, 3, 4, 5]
trainer_gpu_ids = [6, 7]

max_steps = 300
seq_len = 2048 # Prompts are shorter (max ~1034 tokens)

[wandb]
project = "medcasereasoning-rl"
name = "medcasereasoning-rl"

[ckpt]
# Checkpoint at the end of training

[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

[orchestrator]
batch_size = 512
rollouts_per_example = 16

[[orchestrator.env]]
id = "medcasereasoning"

# Environment arguments - uses LLM-as-a-Judge for reward
[orchestrator.env.args]
judge_model = "gpt-5-nano"
reasoning_effort = "low"
# judge_base_url = "http://localhost:8001/v1" # Uncomment for local judge
# judge_api_key = "your-api-key" # Optional, uses OPENAI_API_KEY by default

[orchestrator.sampling]
max_tokens = 1024
temperature = 1.0

# Evaluation configuration
[orchestrator.eval]
interval = 10 # Evaluate every 10 steps
eval_base_model = true # Also evaluate the base model at step 0

[[orchestrator.eval.env]]
id = "medcasereasoning"
num_examples = 100 # Smaller eval set (LLM judge is slow/expensive)
rollouts_per_example = 1

# NOTE(review): eval_split is set only here, not in the training env args —
# presumably training uses the env's default (train) split; verify.
[orchestrator.eval.env.args]
judge_model = "gpt-5-nano"
reasoning_effort = "low"
eval_split = "test" # Use test set for evaluation

[orchestrator.eval.sampling]
temperature = 0.0 # Greedy decoding for deterministic evaluation
max_tokens = 1024

[trainer]
# Default trainer config (GRPO)

[trainer.model.lora]
rank = 32
alpha = 64.0 # Conventional 2x-rank scaling

[trainer.optim]
lr = 1e-5

[inference]
# Inference server config

[inference.parallel]
tp = 2 # Tensor-parallel across 2 GPUs per replica
dp = 3 # 3 replicas -> uses all 6 inference GPUs
Loading