diff --git a/smac/env/starcraft2/starcraft2.py b/smac/env/starcraft2/starcraft2.py index 4dcae78d..66b2ffea 100644 --- a/smac/env/starcraft2/starcraft2.py +++ b/smac/env/starcraft2/starcraft2.py @@ -64,7 +64,7 @@ class StarCraft2Env(MultiAgentEnv): def __init__( self, map_name="8m", - step_mul=None, + step_mul=8, move_amount=2, difficulty="7", game_version=None, @@ -91,6 +91,7 @@ def __init__( replay_prefix="", window_size_x=1920, window_size_y=1200, + heuristic_ai=False, debug=False, ): """ @@ -102,7 +103,7 @@ def __init__( The name of the SC2 map to play (default is "8m"). The full list can be found by running bin/map_list. step_mul : int, optional - How many game steps per agent step (default is None). None + How many game steps per agent step (default is 8). None indicates to use the default map step_mul. move_amount : float, optional How far away units are ordered to move per step (default is 2). @@ -179,6 +180,8 @@ def __init__( The length of StarCraft II window size (default is 1920). window_size_y: int, optional The height of StarCraft II window size (default is 1200). + heuristic_ai: bool, optional + Whether or not to use a non-learning heuristic AI (default False). debug: bool, optional Log messages about observations, state, actions and rewards for debugging purposes (default is False). @@ -222,6 +225,7 @@ def __init__( self.game_version = game_version self.continuing_episode = continuing_episode self._seed = seed + self.heuristic_ai = heuristic_ai self.debug = debug self.window_size = (window_size_x, window_size_y) self.replay_dir = replay_dir @@ -348,6 +352,9 @@ def reset(self): self.last_action = np.zeros((self.n_agents, self.n_actions)) + if self.heuristic_ai: + self.heuristic_targets = [None] * self.n_agents + try: self._obs = self._controller.observe() self.init_units() @@ -389,7 +396,10 @@ def step(self, actions): logging.debug("Actions".center(60, "-")) for a_id, action in enumerate(actions): - agent_action = self.get_agent_action(a_id, action) + if not self.heuristic_ai: + agent_action = self.get_agent_action(a_id, action) + else: + agent_action = self.get_agent_action_heuristic(a_id, action) if agent_action: sc_actions.append(agent_action) @@ -548,6 +558,59 @@ def get_agent_action(self, a_id, action): sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)) return sc_action + def get_agent_action_heuristic(self, a_id, action): + unit = self.get_unit_by_id(a_id) + tag = unit.tag + + target = self.heuristic_targets[a_id] + if unit.unit_type == self.medivac_id: + if (target is None or self.agents[target].health == 0 or + self.agents[target].health == self.agents[target].health_max): + min_dist = math.hypot(self.max_distance_x, self.max_distance_y) + min_id = -1 + for al_id, al_unit in self.agents.items(): + if al_unit.unit_type == self.medivac_id: + continue + if (al_unit.health != 0 and + al_unit.health != al_unit.health_max): + dist = self.distance(unit.pos.x, unit.pos.y, + al_unit.pos.x, al_unit.pos.y) + if dist < min_dist: + min_dist = dist + min_id = al_id + self.heuristic_targets[a_id] = min_id + if min_id == -1: + self.heuristic_targets[a_id] = None + return None + action_id = actions['heal'] + target_tag = self.agents[self.heuristic_targets[a_id]].tag + else: + if target is None or self.enemies[target].health == 0: + min_dist = math.hypot(self.max_distance_x, self.max_distance_y) + min_id = -1 + for e_id, e_unit in self.enemies.items(): + if (unit.unit_type == self.marauder_id and + e_unit.unit_type == self.medivac_id): + continue + if e_unit.health > 0: + dist = self.distance(unit.pos.x, unit.pos.y, + e_unit.pos.x, e_unit.pos.y) + if dist < min_dist: + min_dist = dist + min_id = e_id + self.heuristic_targets[a_id] = min_id + action_id = actions['attack'] + target_tag = self.enemies[self.heuristic_targets[a_id]].tag + + cmd = r_pb.ActionRawUnitCommand( + ability_id = action_id, + target_unit_tag = target_tag, + unit_tags = [tag], + queue_command = False) + + sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd)) + return sc_action + def reward_battle(self): """Reward function when self.reward_spare==False. Returns accumulative hit/shield point damage dealt to the enemy