diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index c8da71a5..202b2e1c 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -127,7 +127,7 @@ def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[s numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)] # Create symbols for building expression - syms = symbols(f"x:{num_terms}") + syms = symbols(f"x_{0}:{num_terms}") # Build random expression expr = syms[0] @@ -162,7 +162,23 @@ def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[s # Fallback to addition for zero expr = expr + syms[i] - return expr, numbers, syms + # Safely replace symbols with numbers (to avoid name conflicts) + expr_str = str(expr) + + # Create a list of replacements: [(symbol_name, number_string), ...] + replacements = [] + for i, sym in enumerate(syms): + sym_name = str(sym) + replacements.append((sym_name, str(numbers[i]))) + + # Sort by symbol name length in descending order (replace longer names first) + replacements.sort(key=lambda x: len(x[0]), reverse=True) + + # Perform the safe replacement + for sym_name, num_str in replacements: + expr_str = expr_str.replace(sym_name, num_str) + + return expr, numbers, syms, expr_str def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]: """Generate a valid expression and its result @@ -175,14 +191,13 @@ def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]: max_attempts = 100 for attempt in range(max_attempts): try: - expr, numbers, syms = self._generate_candidate_expression(rng, num_terms) + expr, numbers, syms, expr_str = self._generate_candidate_expression(rng, num_terms) # Substitute actual numbers to get target subs = {sym: num for sym, num in zip(syms, numbers)} target = int(expr.subs(subs)) # Convert to string expression - expr_str = str(expr) for i, sym in enumerate(syms): expr_str = expr_str.replace(str(sym), str(numbers[i])) diff --git a/tests/test_countdown.py b/tests/test_countdown.py index 6e3a1e0a..ab32506c 100644 --- a/tests/test_countdown.py +++ b/tests/test_countdown.py @@ -114,6 +114,16 @@ def test_edge_cases_2(): assert dataset.score_answer(answer=answer, entry=item) != 1.0 +def test_countdown_more_numbers(): + """Test when min_numbers exceed 10""" + dataset = CountdownDataset( + CountdownConfig(min_numbers=11, max_numbers=11, shuffle=False, size=5, seed=42) + ) # Set 11 engaged numbers for testing + + for item in dataset: + assert item["metadata"]["target"] == int(eval(item["metadata"]["expression"])) + + def test_countdown_game_randomization(): """Test number randomization configuration""" config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing