Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions reasoning_gym/games/countdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[s
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]

# Create symbols for building expression
syms = symbols(f"x:{num_terms}")
syms = symbols(f"x_{0}:{num_terms}")

# Build random expression
expr = syms[0]
Expand Down Expand Up @@ -162,7 +162,23 @@ def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[s
# Fallback to addition for zero
expr = expr + syms[i]

return expr, numbers, syms
# Safely replace symbols with numbers (to avoid name conflicts)
expr_str = str(expr)

# Create a list of replacements: [(symbol_name, number_string), ...]
replacements = []
for i, sym in enumerate(syms):
sym_name = str(sym)
replacements.append((sym_name, str(numbers[i])))

# Sort by symbol name length in descending order (replace longer names first)
replacements.sort(key=lambda x: len(x[0]), reverse=True)

# Perform the safe replacement
for sym_name, num_str in replacements:
expr_str = expr_str.replace(sym_name, num_str)

return expr, numbers, syms, expr_str

def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:
"""Generate a valid expression and its result
Expand All @@ -175,14 +191,13 @@ def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:
max_attempts = 100
for attempt in range(max_attempts):
try:
expr, numbers, syms = self._generate_candidate_expression(rng, num_terms)
expr, numbers, syms, expr_str = self._generate_candidate_expression(rng, num_terms)

# Substitute actual numbers to get target
subs = {sym: num for sym, num in zip(syms, numbers)}
target = int(expr.subs(subs))

# Convert to string expression
expr_str = str(expr)
for i, sym in enumerate(syms):
expr_str = expr_str.replace(str(sym), str(numbers[i]))

Expand Down
10 changes: 10 additions & 0 deletions tests/test_countdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ def test_edge_cases_2():
assert dataset.score_answer(answer=answer, entry=item) != 1.0


def test_countdown_more_numbers():
    """Check expression/target consistency when more than 10 numbers are used.

    Builds a dataset whose items always use exactly 11 numbers and verifies,
    for every item, that evaluating the stored expression string reproduces
    the stored target value.
    """
    # min_numbers == max_numbers == 11 pins every item to exactly 11 numbers.
    config = CountdownConfig(min_numbers=11, max_numbers=11, shuffle=False, size=5, seed=42)
    dataset = CountdownDataset(config)

    for entry in dataset:
        metadata = entry["metadata"]
        # The expression is produced by the dataset itself, so eval() here
        # only ever sees generated arithmetic, not external input.
        assert metadata["target"] == int(eval(metadata["expression"]))


def test_countdown_game_randomization():
"""Test number randomization configuration"""
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing
Expand Down