reworked params
clstatham committed Sep 4, 2023
1 parent 8cdca3a commit 8617a38
Showing 13 changed files with 325 additions and 630 deletions.
2 changes: 0 additions & 2 deletions .gitignore
@@ -1,6 +1,4 @@
/target
/training
profile.json
/burn
/dfdx
*.dot
9 changes: 2 additions & 7 deletions src/brains/learners/maddpg/mod.rs
@@ -1,19 +1,14 @@
use bevy::prelude::*;
use bevy_prng::ChaCha8Rng;
use bevy_rand::prelude::EntropyComponent;
use itertools::Itertools;

use crate::{
brains::models::{
CentralizedCritic, CompoundPolicy, CopyWeights, CriticWithTarget, Policy, PolicyWithTarget,
ValueEstimator,
},
brains::models::{CopyWeights, CriticWithTarget, Policy, PolicyWithTarget, ValueEstimator},
envs::{Action, Env},
};

use candle_core::{IndexOp, Tensor};

use self::replay_buffer::{MaddpgBuffer, MaddpgBufferInner};
use self::replay_buffer::MaddpgBuffer;

use super::{Learner, OffPolicyBuffer, Status, DEVICE};

22 changes: 10 additions & 12 deletions src/brains/learners/maddpg/replay_buffer.rs
@@ -1,6 +1,4 @@
use bevy::{core::FrameCount, prelude::*};
use bevy_prng::ChaCha8Rng;
use bevy_rand::prelude::EntropyComponent;
use itertools::Itertools;
use rand_distr::Distribution;

@@ -148,14 +146,14 @@ pub struct MaddpgMetadata {
}

impl StepMetadata for MaddpgMetadata {
fn calculate<E: Env, P: Policy, V: ValueEstimator>(
obs: &FrameStack<Box<[f32]>>,
action: &E::Action,
policy: &P,
value: &V,
fn calculate<A, P: Policy, V: ValueEstimator>(
_obs: &FrameStack<Box<[f32]>>,
_action: &A,
_policy: &P,
_value: &V,
) -> Self
where
E::Action: Action<E, Logits = P::Logits>,
A: Action<Logits = P::Logits>,
{
unimplemented!()
}
@@ -234,7 +232,7 @@ impl<E: Env> OffPolicyBuffer<E> for MaddpgBuffer<E> {
}

pub fn store_sarts<E: Env, P: Policy, V: ValueEstimator>(
params: Res<E::Params>,
params: Res<Params>,
observations: Query<&FrameStack<Box<[f32]>>, With<Agent>>,
actions: Query<&E::Action, With<Agent>>,
mut rewards: Query<(&mut Reward, &mut RmsNormalize), With<Agent>>,
@@ -244,10 +242,10 @@ pub fn store_sarts<E: Env, P: Policy, V: ValueEstimator>(
agent_ids: Query<&AgentId, With<Agent>>,
frame_count: Res<FrameCount>,
) where
E::Action: Action<E, Logits = P::Logits>,
E::Action: Action<Logits = P::Logits>,
{
if frame_count.0 as usize % params.agent_frame_stack_len() == 0 {
if frame_count.0 as usize > params.agent_warmup() {
if frame_count.0 as usize % params.get_int("agent_frame_stack_len").unwrap() as usize == 0 {
if frame_count.0 as usize > params.get_int("agent_warmup").unwrap() as usize {
for agent_ent in agents.iter() {
let (action, reward, terminal) = (
actions.get(agent_ent).unwrap().clone(),
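The change to store_sarts above is the core of this commit: the env-specific Res<E::Params> and its typed accessors (agent_frame_stack_len(), agent_warmup()) give way to a single string-keyed Params resource queried through get_int. The Params type itself is defined outside the hunks shown here, so the following is only a minimal sketch of the shape those call sites imply; the backing maps, the get_float companion, and the Option return are assumptions.

use std::collections::HashMap;

use bevy::prelude::Resource;

// Hypothetical string-keyed parameter store; the real `Params` type is
// not visible in this diff.
#[derive(Resource, Default)]
pub struct Params {
    ints: HashMap<String, i64>,
    floats: HashMap<String, f64>,
}

impl Params {
    // Returns the integer parameter registered under `key`, if any.
    pub fn get_int(&self, key: &str) -> Option<i64> {
        self.ints.get(key).copied()
    }

    // Returns the float parameter registered under `key`, if any.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.floats.get(key).copied()
    }
}

The trade-off: systems like store_sarts no longer carry E::Params in their signatures, but a typo in a key or a missing entry now panics at runtime through the unwrap instead of failing to compile.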
4 changes: 1 addition & 3 deletions src/brains/learners/mod.rs
@@ -1,6 +1,4 @@
use bevy::prelude::{Component, Resource};
use bevy_prng::ChaCha8Rng;
use bevy_rand::prelude::EntropyComponent;
use bevy::prelude::Resource;
use candle_core::Device;

use crate::{
9 changes: 6 additions & 3 deletions src/brains/mod.rs
@@ -1,12 +1,12 @@
use bevy::prelude::Resource;

use crate::envs::Action;

use self::{
learners::maddpg::Maddpg,
models::{
deterministic_mlp::{DeterministicMlpActor, DeterministicMlpCritic},
linear_resnet::{LinResActor, LinResCritic},
CentralizedCritic, CompoundPolicy, CriticWithTarget, Policy, PolicyWithTarget,
ValueEstimator,
CriticWithTarget, Policy, PolicyWithTarget, ValueEstimator,
},
};

@@ -20,5 +20,8 @@ pub type AgentLearner<E> = Maddpg<E>;
#[derive(Resource)]
pub struct Policies<P: Policy>(pub Vec<P>);

#[derive(Resource)]
pub struct Actions<A: Action>(pub Vec<A>);

#[derive(Resource)]
pub struct ValueEstimators<V: ValueEstimator>(pub Vec<V>);
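brains/mod.rs now imports Action directly and adds an Actions<A: Action> resource alongside Policies and ValueEstimators, and the bounds in replay_buffer.rs shrink from Action<E, Logits = P::Logits> to Action<Logits = P::Logits>. Together these suggest the Action trait dropped its env type parameter so actions can live in env-agnostic resources. A rough sketch of that shape; the real trait lives in src/envs, and everything beyond the Logits associated type and a Clone bound is an assumption (Send + Sync + 'static is only inferred from Actions deriving Resource):

// Hypothetical de-generified form of the `Action` trait; only `Logits`
// and `Clone` are actually visible in this diff's trait bounds.
pub trait Action: Clone + Send + Sync + 'static {
    type Logits;
}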
43 changes: 27 additions & 16 deletions src/brains/models/deterministic_mlp.rs
@@ -25,6 +25,7 @@ use super::{CopyWeights, Policy, PolicyWithTarget, ValueEstimator};
#[derive(Clone)]
pub struct DeterministicMlpActorStatus {
pub action: Box<[f32]>,
pub std: f32,
}

#[derive(Component)]
@@ -61,6 +62,7 @@ impl DeterministicMlpActor {
layers,
optim,
status: Mutex::new(DeterministicMlpActorStatus {
std: 1.0,
action: vec![0.0f32; last_out].into_boxed_slice(),
}),
}
Expand All @@ -78,22 +80,26 @@ impl Policy for DeterministicMlpActor {
for layer in self.layers[..n_layers - 1].iter() {
x = layer.forward(&x)?.tanh()?;
}
x = self.layers[n_layers - 1].forward(&x)?;
x = self.layers[n_layers - 1].forward(&x)?.tanh()?;
Ok(x)
}

fn act(&self, obs: &Tensor) -> Result<(Tensor, Self::Logits)> {
let logits = self.action_logits(obs)?;
self.status.lock().unwrap().action = logits.squeeze(0)?.to_vec1()?.into_boxed_slice();
let logits = (&logits + logits.randn_like(0.0, 0.1)?)?;
Ok((logits.clone(), logits))
let mut status = self.status.lock().unwrap();
status.action = logits.squeeze(0)?.to_vec1()?.into_boxed_slice();
let actions = (&logits + logits.randn_like(0.0, status.std as f64)?)?;
if status.std > 0.05 {
status.std *= 0.99998;
}
Ok((actions, logits))
}

fn log_prob(&self, logits: &Self::Logits, action: &Tensor) -> Result<Tensor> {
fn log_prob(&self, _logits: &Self::Logits, _action: &Tensor) -> Result<Tensor> {
unimplemented!()
}

fn entropy(&self, logits: &Self::Logits) -> Result<Tensor> {
fn entropy(&self, _logits: &Self::Logits) -> Result<Tensor> {
unimplemented!()
}
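Two things change in the hunk above: the final layer's output now passes through tanh as well, bounding the deterministic action to [-1, 1], and the fixed randn_like(0.0, 0.1) exploration noise becomes stateful — the actor's status carries a std that starts at 1.0, scales the sampled noise, and decays by a factor of 0.99998 per action until it reaches a floor around 0.05. The schedule in isolation (a standalone sketch; the helper type is illustrative, the constants come from the diff):

// Multiplicative noise decay with a floor, as in `act` above. Falling
// from 1.0 to 0.05 at 0.99998 per step takes ln(0.05)/ln(0.99998),
// roughly 150,000 actions.
struct NoiseSchedule {
    std: f32,
}

impl NoiseSchedule {
    fn new() -> Self {
        Self { std: 1.0 }
    }

    // Returns the std to use for this step, then decays it.
    fn step(&mut self) -> f32 {
        let current = self.std;
        if self.std > 0.05 {
            self.std *= 0.99998;
        }
        current
    }
}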

@@ -238,6 +244,7 @@ pub fn action_space_ui<E: Env>(
action.push_str(&format!(" {:.4}", a));
}
ui.label(action);
ui.label(format!("std: {}", status.std));
// });

// ui.horizontal_top(|ui| {
@@ -247,15 +254,14 @@
.iter()
.enumerate()
.map(|(i, a)| {
let m = Bar::new(i as f64, *a as f64)
// .fill(Color32::from_rgb(
// (rg.x * 255.0) as u8,
// (rg.y * 255.0) as u8,
// 0,
// ));
.fill(egui::Color32::RED);
m
// .width(1.0 - *std as f64 / 6.0)
let s = Line::new(vec![
[i as f64, *a as f64 - status.std as f64],
[i as f64, *a as f64 + status.std as f64],
])
.stroke(egui::Stroke::new(4.0, egui::Color32::LIGHT_GREEN));
let m =
Bar::new(i as f64, *a as f64).fill(egui::Color32::RED);
(m, s)
})
.collect_vec();

@@ -274,7 +280,12 @@
.show(
ui,
|plot| {
plot.bar_chart(BarChart::new(ms));
let (m, s): (Vec<_>, Vec<_>) =
ms.into_iter().multiunzip();
plot.bar_chart(BarChart::new(m));
for s in s {
plot.line(s);
}
},
);
});
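The UI hunks above replace the plain bar chart with a bar-plus-error-line pattern: each action index yields a (Bar, Line) pair — a red bar at the action's value and a light-green vertical line spanning value ± std — and Itertools::multiunzip splits the pairs so the bars feed one BarChart while each line is plotted separately. A condensed, self-contained sketch of the same pattern (assuming the egui::plot module as shipped in egui around this commit; the import paths are the main assumption):

use egui::plot::{Bar, BarChart, Line};
use itertools::Itertools;

// One red bar per action dimension at its value, plus a green vertical
// line spanning value ± std, mirroring the plotting code above.
fn bars_with_error_lines(actions: &[f32], std: f32) -> (BarChart, Vec<Line>) {
    let (bars, lines): (Vec<_>, Vec<_>) = actions
        .iter()
        .enumerate()
        .map(|(i, &a)| {
            let bar = Bar::new(i as f64, a as f64).fill(egui::Color32::RED);
            let line = Line::new(vec![
                [i as f64, (a - std) as f64],
                [i as f64, (a + std) as f64],
            ])
            .stroke(egui::Stroke::new(4.0, egui::Color32::LIGHT_GREEN));
            (bar, line)
        })
        .multiunzip();
    (BarChart::new(bars), lines)
}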
12 changes: 4 additions & 8 deletions src/brains/models/linear_resnet.rs
@@ -36,7 +36,7 @@ pub struct LinResActor {
common3: ResBlock,
mu_head: Linear,
cov_head: Linear,
varmap: VarMap,
_varmap: VarMap,
optim: Mutex<AdamW>,
status: Mutex<Option<LinearResnetStatus>>,
}
@@ -52,7 +52,7 @@ impl LinResActor {
let cov_head = linear(hidden_len, action_len, 0.01, vs.pp("cov_head"));
let optim = AdamW::new_lr(varmap.all_vars(), lr)?;
Ok(Self {
varmap,
_varmap: varmap,
common1,
common2,
common3,
@@ -155,7 +155,7 @@ impl LinResCritic {
}
}
impl ValueEstimator for LinResCritic {
fn estimate_value(&self, x: &Tensor, action: Option<&Tensor>) -> Result<Tensor> {
fn estimate_value(&self, x: &Tensor, _action: Option<&Tensor>) -> Result<Tensor> {
let x = x.flatten_from(1)?;
let x = self.l1.forward(&x)?.tanh()?;
let x = self.l2.forward(&x)?;
@@ -214,7 +214,7 @@ pub fn action_space_ui<E: Env>(
ui.vertical(|ui| {
ui.heading(&names.get(agent).unwrap().0);
// ui.group(|ui| {
let status = LinResActor::status(&policies.get(agent).unwrap());
let status = policies.get(agent).unwrap().status();
if let Some(status) = status {
let mut mu = "mu:".to_owned();
for m in status.mu.iter() {
@@ -237,10 +237,6 @@
.zip(status.cov.iter())
.enumerate()
.map(|(i, (mu, cov))| {
// https://www.desmos.com/calculator/rkoehr8rve
let scale = cov.sqrt() * 3.0;
let _rg = Vec2::new(scale.exp(), (1.0 / scale).exp())
.normalize();
let m = Bar::new(i as f64, *mu as f64)
// .fill(Color32::from_rgb(
// (rg.x * 255.0) as u8,