Skip to content

Commit

Permalink
los and friendly fire
Browse files Browse the repository at this point in the history
  • Loading branch information
clstatham committed Sep 5, 2023
1 parent ab22fcc commit 1b56380
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 107 deletions.
2 changes: 1 addition & 1 deletion src/brains/models/deterministic_mlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl DeterministicMlpActor {
last_out = *next_out;
}
let optim = Mutex::new(adam(varmap.all_vars(), lr).unwrap());
let noise = Mutex::new(OuNoise::new(last_out, 0.0, 0.15, 0.3, 0.05, 1000000));
let noise = Mutex::new(OuNoise::new(last_out, 0.0, 0.15, 0.3, 0.05, 100000));
Self {
varmap,
layers,
Expand Down
112 changes: 56 additions & 56 deletions src/envs/maps/tdm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,65 +70,65 @@ impl Map for TdmMap {
.insert(ActiveEvents::all())
.insert(TransformBundle::from(Transform::from_xyz(500.0, 0.0, 0.0)));

// // right-middle wall
// commands
// .spawn(Collider::cuboid(100.0, 10.0))
// .insert(SpriteBundle {
// sprite: Sprite {
// color: Color::BLACK,
// custom_size: Some(Vec2::new(200.0, 20.0)),
// ..default()
// },
// ..default()
// })
// .insert(Wall)
// .insert(ActiveEvents::all())
// .insert(TransformBundle::from(Transform::from_xyz(400.0, 0.0, 0.0)));
// right-middle wall
commands
.spawn(Collider::cuboid(100.0, 10.0))
.insert(SpriteBundle {
sprite: Sprite {
color: Color::BLACK,
custom_size: Some(Vec2::new(200.0, 20.0)),
..default()
},
..default()
})
.insert(Wall)
.insert(ActiveEvents::all())
.insert(TransformBundle::from(Transform::from_xyz(400.0, 0.0, 0.0)));

// // left-middle wall
// commands
// .spawn(Collider::cuboid(100.0, 10.0))
// .insert(SpriteBundle {
// sprite: Sprite {
// color: Color::BLACK,
// custom_size: Some(Vec2::new(200.0, 20.0)),
// ..default()
// },
// ..default()
// })
// .insert(Wall)
// .insert(ActiveEvents::all())
// .insert(TransformBundle::from(Transform::from_xyz(-400.0, 0.0, 0.0)));
// left-middle wall
commands
.spawn(Collider::cuboid(100.0, 10.0))
.insert(SpriteBundle {
sprite: Sprite {
color: Color::BLACK,
custom_size: Some(Vec2::new(200.0, 20.0)),
..default()
},
..default()
})
.insert(Wall)
.insert(ActiveEvents::all())
.insert(TransformBundle::from(Transform::from_xyz(-400.0, 0.0, 0.0)));

// // top-middle wall
// commands
// .spawn(Collider::cuboid(10.0, 100.0))
// .insert(SpriteBundle {
// sprite: Sprite {
// color: Color::BLACK,
// custom_size: Some(Vec2::new(20.0, 200.0)),
// ..default()
// },
// ..default()
// })
// .insert(Wall)
// .insert(ActiveEvents::all())
// .insert(TransformBundle::from(Transform::from_xyz(0.0, 200.0, 0.0)));
// top-middle wall
commands
.spawn(Collider::cuboid(10.0, 100.0))
.insert(SpriteBundle {
sprite: Sprite {
color: Color::BLACK,
custom_size: Some(Vec2::new(20.0, 200.0)),
..default()
},
..default()
})
.insert(Wall)
.insert(ActiveEvents::all())
.insert(TransformBundle::from(Transform::from_xyz(0.0, 200.0, 0.0)));

// // bottom-middle wall
// commands
// .spawn(Collider::cuboid(10.0, 100.0))
// .insert(SpriteBundle {
// sprite: Sprite {
// color: Color::BLACK,
// custom_size: Some(Vec2::new(20.0, 200.0)),
// ..default()
// },
// ..default()
// })
// .insert(Wall)
// .insert(ActiveEvents::all())
// .insert(TransformBundle::from(Transform::from_xyz(0.0, -200.0, 0.0)));
// bottom-middle wall
commands
.spawn(Collider::cuboid(10.0, 100.0))
.insert(SpriteBundle {
sprite: Sprite {
color: Color::BLACK,
custom_size: Some(Vec2::new(20.0, 200.0)),
..default()
},
..default()
})
.insert(Wall)
.insert(ActiveEvents::all())
.insert(TransformBundle::from(Transform::from_xyz(0.0, -200.0, 0.0)));

// // bottom-left corner wall
// commands
Expand Down
2 changes: 1 addition & 1 deletion src/envs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ impl RunningReturn {
pub struct Terminal(pub bool);

#[derive(Component)]
pub struct Kills(pub usize);
pub struct Kills(pub isize);

#[derive(Component)]
pub struct Deaths(pub usize);
Expand Down
140 changes: 97 additions & 43 deletions src/envs/tdm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ fn observation(
),
With<Agent>,
>,
mut gizmos: Gizmos,
) {
queries.iter().for_each(
|(agent, agent_id, my_team, _action, velocity, transform, health)| {
Expand All @@ -372,7 +373,7 @@ fn observation(
enemies: vec![],
};
for (
_other,
other,
other_id,
other_team,
other_action,
Expand All @@ -381,34 +382,90 @@ fn observation(
other_health,
) in queries
.iter()
.filter(|q| q.0 != agent)
.filter(|(ent, ..)| *ent != agent)
.sorted_by_key(|(_, id, _, _, _, _t, _)| id.0)
{
if my_team.0 == other_team.0 {
// check line of sight
let f = |ent| queries.get(ent).is_err() || ent == other;
let filter = QueryFilter::new().predicate(&f);
if let Some((hit_ent, _toi)) = cx.cast_ray(
transform.translation.xy(),
(other_transform.translation.xy() - transform.translation.xy()).normalize(),
Real::MAX,
true,
filter,
) {
if hit_ent == other {
// gizmos.line_2d(
// transform.translation.xy(),
// other_transform.translation.xy(),
// Color::WHITE,
// );
if my_team.0 == other_team.0 {
my_state.teammates.push(TeammateObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
phys: RelativePhysicalProperties::new(transform, other_transform),
combat: CombatProperties {
health: other_health.0,
},
map_interaction: MapInteractionProperties::new(
other_transform,
&cx,
),
firing: other_action.combat.shoot > 0.0,
});
} else {
my_state.enemies.push(EnemyObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
phys: RelativePhysicalProperties::new(transform, other_transform),
combat: CombatProperties {
health: other_health.0,
},
map_interaction: MapInteractionProperties::new(
other_transform,
&cx,
),
firing: other_action.combat.shoot > 0.0,
});
}
} else if my_team.0 == other_team.0 {
my_state.teammates.push(TeammateObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
..Default::default()
});
} else {
my_state.enemies.push(EnemyObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
..Default::default()
});
}
} else if my_team.0 == other_team.0 {
my_state.teammates.push(TeammateObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
phys: RelativePhysicalProperties::new(transform, other_transform),
combat: CombatProperties {
health: other_health.0,
},
map_interaction: MapInteractionProperties::new(other_transform, &cx),
firing: other_action.combat.shoot > 0.0,
..Default::default()
});
} else {
my_state.enemies.push(EnemyObs {
ident: IdentityEmbedding::new(
other_id.0,
params.get_int("num_agents").unwrap() as usize,
),
phys: RelativePhysicalProperties::new(transform, other_transform),
combat: CombatProperties {
health: other_health.0,
},
map_interaction: MapInteractionProperties::new(other_transform, &cx),
firing: other_action.combat.shoot > 0.0,
..Default::default()
});
}
}
Expand Down Expand Up @@ -501,54 +558,51 @@ fn get_reward(
true,
filter,
) {
gizmos.line_2d(
ray_pos,
ray_pos + ray_dir * toi,
TEAM_COLORS[team_id.get(agent_ent).unwrap().0 as usize],
);
gizmos.line_2d(ray_pos, ray_pos + ray_dir * toi, Color::WHITE);

if let Ok(mut health) = health.get_mut(hit_entity) {
if let Ok(other_team) = team_id.get(hit_entity) {
if my_team.0 == other_team.0 {
// friendly fire!
rewards.get_component_mut::<Reward>(agent_ent).unwrap().0 +=
params.get_float("reward_for_friendly_fire").unwrap() as f32;
} else if health.0 > 0.0 {
if health.0 > 0.0 {
health.0 -= 5.0;
rewards.get_component_mut::<Reward>(agent_ent).unwrap().0 +=
params.get_float("reward_for_hit").unwrap() as f32;
rewards.get_component_mut::<Reward>(hit_entity).unwrap().0 +=
params.get_float("reward_for_getting_hit").unwrap() as f32;
if health.0 <= 0.0 {
rewards.get_mut(agent_ent).unwrap().0 .0 +=
params.get_float("reward_for_kill").unwrap() as f32;
if my_team.0 == other_team.0 {
// friendly fire!
rewards.get_mut(agent_ent).unwrap().0 .0 -=
params.get_float("reward_for_kill").unwrap() as f32;
kills.get_mut(agent_ent).unwrap().0 -= 1;
let msg = format!(
"{} killed their TEAMMATE {}! Not nice!",
&names.get(agent_ent).unwrap().0,
&names.get(hit_entity).unwrap().0
);
log.push(msg);
} else {
rewards.get_mut(agent_ent).unwrap().0 .0 +=
params.get_float("reward_for_kill").unwrap() as f32;
kills.get_mut(agent_ent).unwrap().0 += 1;
let msg = format!(
"{} killed {}! Nice!",
&names.get(agent_ent).unwrap().0,
&names.get(hit_entity).unwrap().0
);
log.push(msg);
}
rewards.get_mut(hit_entity).unwrap().0 .0 +=
params.get_float("reward_for_death").unwrap() as f32;
kills.get_mut(agent_ent).unwrap().0 += 1;
let msg = format!(
"{} killed {}! Nice!",
&names.get(agent_ent).unwrap().0,
&names.get(hit_entity).unwrap().0
);
log.push(msg);
}
}
}
} else {
// hit a wall
rewards.get_component_mut::<Reward>(agent_ent).unwrap().0 +=
params.get_float("reward_for_miss").unwrap() as f32;
}
} else {
gizmos.line_2d(
ray_pos,
ray_pos
+ ray_dir * params.get_float("agent_shoot_distance").unwrap() as f32,
TEAM_COLORS[team_id.get(agent_ent).unwrap().0 as usize],
Color::WHITE,
);
// hit nothing
rewards.get_component_mut::<Reward>(agent_ent).unwrap().0 +=
params.get_float("reward_for_miss").unwrap() as f32;
}
}

Expand Down
8 changes: 2 additions & 6 deletions tdm.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
params:
num_agents: 3
agents_per_team: 1
num_agents: 6
agents_per_team: 2
agent_hidden_dim: 64
actor_lr: 0.01
critic_lr: 0.01
Expand All @@ -16,8 +16,4 @@ params:
agent_shoot_distance: 1000.0
reward_for_kill: 100.0
reward_for_death: -100.0
reward_for_hit: 0.0
reward_for_getting_hit: 0.0
reward_for_miss: 0.0
reward_for_friendly_fire: 0.0

0 comments on commit 1b56380

Please sign in to comment.