Skip to content

Commit

Permalink
got cuda working
Browse files Browse the repository at this point in the history
  • Loading branch information
clstatham committed Aug 24, 2023
1 parent 55215d3 commit e1eb023
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 6 deletions.
23 changes: 23 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[target.x86_64-unknown-linux-gnu]
linker = "clang"
rustflags = ["-Clink-arg=-fuse-ld=lld", "-Zshare-generics=y"]

# NOTE: you must install [Mach-O LLD Port](https://lld.llvm.org/MachO/index.html) on mac. you can easily do this by installing llvm which includes lld with the "brew" package manager:
# `brew install llvm`
[target.x86_64-apple-darwin]
rustflags = [
"-C",
"link-arg=-fuse-ld=/usr/local/opt/llvm/bin/ld64.lld",
"-Zshare-generics=y",
]

[target.aarch64-apple-darwin]
rustflags = [
"-C",
"link-arg=-fuse-ld=/opt/homebrew/opt/llvm/bin/ld64.lld",
"-Zshare-generics=y",
]

[target.x86_64-pc-windows-msvc]
linker = "rust-lld.exe"
rustflags = ["-Zshare-generics=n"]
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bevy = "0.11.2"
bevy = { version = "0.11.2", features = ["dynamic_linking"] }
bevy_rapier2d = "0.22.0"
bevy_egui = "0.21"
chrono = "0.4.26"
Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"
17 changes: 12 additions & 5 deletions src/brains/thinkers/ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use burn::{
module::{ADModule, Module, ModuleMapper},
nn::{Linear, LinearConfig, ReLU},
optim::{
adaptor::OptimizerAdaptor, Adam, AdamConfig, GradientsParams, Optimizer, Sgd, SgdConfig,
SimpleOptimizer,
adaptor::OptimizerAdaptor, momentum::MomentumConfig, Adam, AdamConfig, GradientsParams,
Optimizer, Sgd, SgdConfig, SimpleOptimizer,
},
record::{BinGzFileRecorder, FullPrecisionSettings},
tensor::{backend::Backend, Tensor},
Expand Down Expand Up @@ -291,15 +291,22 @@ impl PpoThinker {
}
.init()
.fork(&TchDevice::Cuda(0));
dbg!(actor.devices());
let critic = PpoCriticConfig {
obs_len: OBS_LEN,
hidden_len: AGENT_HIDDEN_DIM,
}
.init()
.fork(&TchDevice::Cuda(0));
let actor_optim = SgdConfig::new().init();
let critic_optim = SgdConfig::new().init();
let actor_optim = SgdConfig::new()
.with_momentum(Some(
MomentumConfig::new().with_momentum(0.9).with_nesterov(true),
))
.init();
let critic_optim = SgdConfig::new()
.with_momentum(Some(
MomentumConfig::new().with_momentum(0.9).with_nesterov(true),
))
.init();
Self {
actor,
critic,
Expand Down

0 comments on commit e1eb023

Please sign in to comment.