Skip to content

Commit b07e5b9

Browse files
daveey and relh authored
Add retry logic for task generation and improve goal observation tokens (#4081)
Co-authored-by: Richard Higgins <[email protected]>
1 parent 33ebeba commit b07e5b9

File tree

16 files changed

+179
-42
lines changed

16 files changed

+179
-42
lines changed

metta/cogworks/curriculum/curriculum_env.py

Lines changed: 34 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -65,10 +65,23 @@ def _add_curriculum_stats_to_info(self, info_dict: dict) -> None:
6565
def reset(self, *args, **kwargs):
6666
"""Reset the environment and get a new task from curriculum."""
6767

68-
# Get a new task from curriculum
69-
self._current_task = self._curriculum.get_task()
70-
self._env.set_mg_config(self._current_task.get_env_cfg())
71-
obs, info = self._env.reset(*args, **kwargs)
68+
# Try to get a valid task and build the map
69+
max_retries = 10
70+
for attempt in range(max_retries):
71+
try:
72+
# Get a new task from curriculum
73+
self._current_task = self._curriculum.get_task()
74+
# Create the env config and build the map in try-catch
75+
self._env.set_mg_config(self._current_task.get_env_cfg())
76+
obs, info = self._env.reset(*args, **kwargs)
77+
break
78+
except Exception:
79+
# If config is invalid or map building fails, request a new task
80+
if attempt == max_retries - 1:
81+
# If we've exhausted retries, raise the exception
82+
raise
83+
# Otherwise, try again with a new task
84+
continue
7285

7386
# Invalidate stats cache on reset
7487
self._stats_cache_valid = False
@@ -96,8 +109,23 @@ def step(self, *args, **kwargs):
96109
self._current_task.complete(mean_reward)
97110
# Update the curriculum algorithm with task performance for learning progress
98111
self._curriculum.update_task_performance(self._current_task._task_id, mean_reward)
99-
self._current_task = self._curriculum.get_task()
100-
self._env.set_mg_config(self._current_task.get_env_cfg())
112+
113+
# Try to get a valid task and build the map
114+
max_retries = 10
115+
for attempt in range(max_retries):
116+
try:
117+
self._current_task = self._curriculum.get_task()
118+
# Create the env config and build the map in try-catch
119+
self._env.set_mg_config(self._current_task.get_env_cfg())
120+
break
121+
except Exception:
122+
# If config is invalid or map building fails, return 0 reward and request a new task
123+
if attempt == max_retries - 1:
124+
# If we've exhausted retries, set rewards to 0 and continue
125+
rewards = rewards * 0
126+
break
127+
# Otherwise, try again with a new task
128+
continue
101129

102130
# Invalidate stats cache when task changes
103131
self._stats_cache_valid = False

metta/cogworks/curriculum/learning_progress_algorithm.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,16 @@ def stats(self, prefix: str = "") -> Dict[str, float]:
9898
# Get base stats (required)
9999
stats = self.get_base_stats()
100100

101+
# Always include learning progress stats (not just when detailed logging is enabled)
102+
if self.hypers.use_bidirectional:
103+
lp_stats = self._get_bidirectional_detailed_stats()
104+
else:
105+
lp_stats = self._get_basic_detailed_stats()
106+
107+
# Add lp/ prefix to learning progress stats
108+
for key, value in lp_stats.items():
109+
stats[f"lp/{key}"] = value
110+
101111
if self.enable_detailed_logging:
102112
detailed = self.get_detailed_stats()
103113
stats.update(detailed)
@@ -378,20 +388,13 @@ def on_task_created(self, task: CurriculumTask) -> None:
378388
self.invalidate_cache()
379389

380390
def get_detailed_stats(self) -> Dict[str, float]:
381-
"""Get detailed stats including learning progress and slice distribution analysis."""
382-
stats = super().get_detailed_stats() # Gets slice analyzer stats
391+
"""Get detailed stats including slice distribution analysis.
383392
384-
# Always include learning progress stats (not just when detailed logging is enabled)
385-
if self.hypers.use_bidirectional:
386-
lp_stats = self._get_bidirectional_detailed_stats()
387-
else:
388-
lp_stats = self._get_basic_detailed_stats()
389-
390-
# Add lp/ prefix to learning progress stats
391-
for key, value in lp_stats.items():
392-
stats[f"lp/{key}"] = value
393-
394-
return stats
393+
Note: Learning progress stats are always included in stats() regardless of
394+
enable_detailed_logging, so they are not included here to avoid duplication.
395+
"""
396+
# Only return slice analyzer stats (LP stats are always included in stats())
397+
return super().get_detailed_stats() # Gets slice analyzer stats
395398

396399
def _get_bidirectional_detailed_stats(self) -> Dict[str, float]:
397400
"""Get detailed bidirectional learning progress statistics."""

packages/cogames/src/cogames/cogs_vs_clips/evals/diagnostic_evals.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -433,12 +433,23 @@ def configure_env(self, cfg: MettaGridConfig) -> None:
433433

434434
assembler = cfg.game.objects.get("assembler")
435435
if isinstance(assembler, AssemblerConfig):
436+
# Check if a protocol with ["gear"] vibe already exists to avoid duplicates
437+
has_gear_protocol = any(p.vibes == ["gear"] and p.min_agents == 0 for p in assembler.protocols)
438+
436439
updated_protocols: list[ProtocolConfig] = []
440+
modified_decoder = False
437441
for proto in assembler.protocols:
438442
if proto.output_resources.get("decoder", 0) > 0:
439-
inputs = {res: 1 for res in non_clipped}
440-
updated_proto = proto.model_copy(update={"vibes": ["gear"], "input_resources": inputs})
441-
updated_protocols.append(updated_proto)
443+
if has_gear_protocol:
444+
# Preserve existing decoder recipe when a gear protocol already exists
445+
updated_protocols.append(proto)
446+
elif not modified_decoder:
447+
# Only modify the first decoder protocol to avoid duplicates
448+
inputs = {res: 1 for res in non_clipped}
449+
updated_proto = proto.model_copy(update={"vibes": ["gear"], "input_resources": inputs})
450+
updated_protocols.append(updated_proto)
451+
modified_decoder = True
452+
# Skip subsequent decoder protocols to avoid duplicates
442453
else:
443454
updated_protocols.append(proto)
444455
assembler.protocols = updated_protocols

packages/cogames/src/cogames/cogs_vs_clips/evals/difficulty_variants.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -270,8 +270,9 @@ def _add_gear_protocol() -> None:
270270
if asm is None or not isinstance(asm, AssemblerConfig):
271271
return
272272

273-
# Check if ['gear'] protocol already exists
274-
if any(p.vibes == ["gear"] for p in asm.protocols):
273+
# Check if ['gear'] protocol with same min_agents already exists
274+
# C++ doesn't allow duplicate protocols with same vibes and min_agents
275+
if any(p.vibes == ["gear"] and p.min_agents == 0 for p in asm.protocols):
275276
return # Already added
276277

277278
# Add the ONE generic gear protocol for this variant

packages/cogames/src/cogames/cogs_vs_clips/stations.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,9 +119,9 @@ def station_cfg(self) -> AssemblerConfig:
119119
)
120120

121121

122-
# Rare regenerates slowly. More cogs increase the amount extracted.
122+
# Rare, single-use. More cogs increase the amount extracted.
123123
class GermaniumExtractorConfig(ExtractorConfig):
124-
max_uses: int = Field(default=5)
124+
max_uses: int = Field(default=1)
125125
synergy: int = 50
126126

127127
def station_cfg(self) -> AssemblerConfig:

packages/cogames/src/cogames/cogs_vs_clips/variants.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -287,9 +287,19 @@ def modify_env(self, mission, env) -> None:
287287
if not isinstance(assembler_cfg, AssemblerConfig):
288288
raise TypeError("Expected 'assembler' to be AssemblerConfig")
289289
gear_outputs = {"decoder", "modulator", "scrambler", "resonator"}
290-
for protocol in assembler_cfg.protocols:
291-
if any(k in protocol.output_resources for k in gear_outputs):
292-
protocol.vibes = ["gear"]
290+
291+
# Check if a protocol with ["gear"] vibe already exists
292+
# If so, don't modify any protocols to avoid creating duplicates
293+
has_gear_protocol = any(p.vibes == ["gear"] and p.min_agents == 0 for p in assembler_cfg.protocols)
294+
295+
if has_gear_protocol:
296+
return
297+
298+
# Rewrite all gear recipes to use only the ["gear"] vibe, keep order intact.
299+
assembler_cfg.protocols = [
300+
p.model_copy(update={"vibes": ["gear"]}) if any(k in p.output_resources for k in gear_outputs) else p
301+
for p in assembler_cfg.protocols
302+
]
293303

294304

295305
class InventoryHeartTuneVariant(MissionVariant):

packages/mettagrid/cpp/bindings/mettagrid_c.cpp

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <iostream>
99
#include <numeric>
1010
#include <random>
11+
#include <string>
1112
#include <unordered_set>
1213
#include <vector>
1314

@@ -347,6 +348,35 @@ void MettaGrid::_compute_observation(GridCoord observer_row,
347348
global_tokens.push_back({ObservationFeature::LastReward, reward_int});
348349
}
349350

351+
// Add goal tokens for rewarding resources when enabled
352+
if (_global_obs_config.goal_obs) {
353+
auto& agent = _agents[agent_idx];
354+
// Track which resources we've already added goal tokens for
355+
std::unordered_set<std::string> added_resources;
356+
// Iterate through stat_rewards to find rewarding resources
357+
for (const auto& [stat_name, reward_value] : agent->stat_rewards) {
358+
// Extract resource name from stat name (e.g., "carbon.amount" -> "carbon", "carbon.gained" -> "carbon")
359+
size_t dot_pos = stat_name.find('.');
360+
if (dot_pos != std::string::npos) {
361+
std::string resource_name = stat_name.substr(0, dot_pos);
362+
// Only add one goal token per resource
363+
if (added_resources.find(resource_name) == added_resources.end()) {
364+
// Find the resource index in resource_names
365+
for (size_t i = 0; i < resource_names.size(); i++) {
366+
if (resource_names[i] == resource_name) {
367+
// Get the inventory feature ID for this resource
368+
ObservationType inventory_feature_id = _obs_encoder->get_inventory_feature_id(static_cast<InventoryItem>(i));
369+
// Add a goal token with the resource's inventory feature ID as the value
370+
global_tokens.push_back({ObservationFeature::Goal, inventory_feature_id});
371+
added_resources.insert(resource_name);
372+
break;
373+
}
374+
}
375+
}
376+
}
377+
}
378+
}
379+
350380
// Global tokens are always at the center of the observation.
351381
uint8_t global_location =
352382
PackedCoordinate::pack(static_cast<uint8_t>(obs_height_radius), static_cast<uint8_t>(obs_width_radius));

packages/mettagrid/cpp/include/mettagrid/config/mettagrid_config.hpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ struct GlobalObsConfig {
2424
bool last_action = true;
2525
bool last_reward = true;
2626
bool compass = false;
27+
bool goal_obs = false;
2728
};
2829

2930
struct GameConfig {
@@ -57,15 +58,17 @@ namespace py = pybind11;
5758
inline void bind_global_obs_config(py::module& m) {
5859
py::class_<GlobalObsConfig>(m, "GlobalObsConfig")
5960
.def(py::init<>())
60-
.def(py::init<bool, bool, bool, bool>(),
61+
.def(py::init<bool, bool, bool, bool, bool>(),
6162
py::arg("episode_completion_pct") = true,
6263
py::arg("last_action") = true,
6364
py::arg("last_reward") = true,
64-
py::arg("compass") = false)
65+
py::arg("compass") = false,
66+
py::arg("goal_obs") = false)
6567
.def_readwrite("episode_completion_pct", &GlobalObsConfig::episode_completion_pct)
6668
.def_readwrite("last_action", &GlobalObsConfig::last_action)
6769
.def_readwrite("last_reward", &GlobalObsConfig::last_reward)
68-
.def_readwrite("compass", &GlobalObsConfig::compass);
70+
.def_readwrite("compass", &GlobalObsConfig::compass)
71+
.def_readwrite("goal_obs", &GlobalObsConfig::goal_obs);
6972
}
7073

7174
inline void bind_game_config(py::module& m) {

packages/mettagrid/cpp/include/mettagrid/config/observation_features.hpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ class ObservationFeaturesImpl {
2727
_cooldown_remaining = get("cooldown_remaining");
2828
_clipped = get("clipped");
2929
_remaining_uses = get("remaining_uses");
30+
_goal = get("goal");
3031

3132
// Initialize public members (must be done AFTER private members are set above)
3233
Group = _group;
@@ -40,6 +41,7 @@ class ObservationFeaturesImpl {
4041
CooldownRemaining = _cooldown_remaining;
4142
Clipped = _clipped;
4243
RemainingUses = _remaining_uses;
44+
Goal = _goal;
4345
}
4446

4547
// Get feature ID by name (throws if not found)
@@ -68,6 +70,7 @@ class ObservationFeaturesImpl {
6870
ObservationType CooldownRemaining;
6971
ObservationType Clipped;
7072
ObservationType RemainingUses;
73+
ObservationType Goal;
7174

7275
private:
7376
std::unordered_map<std::string, ObservationType> _name_to_id;
@@ -84,6 +87,7 @@ class ObservationFeaturesImpl {
8487
ObservationType _cooldown_remaining;
8588
ObservationType _clipped;
8689
ObservationType _remaining_uses;
90+
ObservationType _goal;
8791
};
8892

8993
// Global singleton instance
@@ -107,6 +111,7 @@ extern ObservationType Tag;
107111
extern ObservationType CooldownRemaining;
108112
extern ObservationType Clipped;
109113
extern ObservationType RemainingUses;
114+
extern ObservationType Goal;
110115
} // namespace ObservationFeature
111116

112117
#endif // PACKAGES_METTAGRID_CPP_INCLUDE_METTAGRID_CONFIG_OBSERVATION_FEATURES_HPP_

packages/mettagrid/cpp/src/mettagrid/config/observation_features.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ ObservationType Tag;
1616
ObservationType CooldownRemaining;
1717
ObservationType Clipped;
1818
ObservationType RemainingUses;
19+
ObservationType Goal;
1920

2021
void Initialize(const std::unordered_map<std::string, ObservationType>& feature_ids) {
2122
_instance = std::make_shared<ObservationFeaturesImpl>(feature_ids);
@@ -32,5 +33,6 @@ void Initialize(const std::unordered_map<std::string, ObservationType>& feature_
3233
CooldownRemaining = _instance->CooldownRemaining;
3334
Clipped = _instance->Clipped;
3435
RemainingUses = _instance->RemainingUses;
36+
Goal = _instance->Goal;
3537
}
3638
} // namespace ObservationFeature

0 commit comments

Comments
 (0)