Skip to content

Commit

Permalink
more modularity! MOREEEE
Browse files Browse the repository at this point in the history
  • Loading branch information
clstatham committed Aug 30, 2023
1 parent a6467cd commit d17512f
Show file tree
Hide file tree
Showing 11 changed files with 576 additions and 292 deletions.
1 change: 1 addition & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# linker = "clang"
# rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]
# runner = ".cargo/codelldb.sh"
# rustflags = ["-Zsanitizer=thread"]

[target.x86_64-pc-windows-msvc]
linker = "rust-lld.exe"
Expand Down
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ crossbeam-channel = "0.5.8"
tokio = { version = "1.32.0", features = ["sync"] }
futures-lite = "1.13.0"
bevy_framepace = "0.13.3"

[profile.dev.package."*"]
opt-level = 3
derive_more = "0.99.17"
# [profile.dev.package."*"]
# opt-level = 3

[profile.release]
debug = true
28 changes: 16 additions & 12 deletions src/brains/replay_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::VecDeque;

use bevy::prelude::Component;
use burn_tch::TchBackend;

use crate::{
Expand Down Expand Up @@ -62,6 +63,7 @@ pub struct PpoMetadata {
pub hiddens: Option<HiddenStates<TchBackend<f32>>>,
}

#[derive(Component)]
pub struct PpoBuffer<E: Env> {
pub obs: VecDeque<FrameStack<E::Observation>>,
pub action: VecDeque<E::Action>,
Expand Down Expand Up @@ -133,18 +135,20 @@ where
terminal,
} = step;

self.obs.push_back(obs);
self.action.push_back(action);
self.reward.push_back(reward);
self.terminal.push_back(terminal);
self.advantage.push_back(None);
self.returns.push_back(None);

self.current_trajectory_start += 1;
if let Some(max_len) = max_len {
if self.current_trajectory_start >= max_len {
self.current_trajectory_start = max_len;
self.finish_trajectory(); // in case one of them is an ABSOLUTE GAMER and doesn't die for like 100_000 frames
if action.metadata().hiddens.is_some() {
self.obs.push_back(obs);
self.action.push_back(action);
self.reward.push_back(reward);
self.terminal.push_back(terminal);
self.advantage.push_back(None);
self.returns.push_back(None);

self.current_trajectory_start += 1;
if let Some(max_len) = max_len {
if self.current_trajectory_start >= max_len {
self.current_trajectory_start = max_len;
self.finish_trajectory(); // in case one of them is an ABSOLUTE GAMER and doesn't die for like 100_000 frames
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/brains/thinkers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ impl Status for () {
}

pub trait Thinker<E: Env> {
type Metadata: Clone;
type Status: Status + Clone + Default;
type ActionMetadata: Clone + Default;
type Metadata: Clone + Send + Sync;
type Status: Status + Clone + Default + Send + Sync;
type ActionMetadata: Clone + Default + Send + Sync;
fn act(
&mut self,
obs: &FrameStack<E::Observation>,
Expand Down
Loading

0 comments on commit d17512f

Please sign in to comment.