Skip to content

Commit 1a68f25

Browse files
author
Eric Waller
committed
gpu_kernels.rs WIP (syntax check pass, RunPod build)
1 parent df1bbdf commit 1a68f25

3 files changed

Lines changed: 102 additions & 48 deletions

File tree

src/bin/gpu_benchmark.rs

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,56 @@
1-
// GPU benchmark for L4 FP16 sin*cos evaluation
2-
// Measures throughput in ops/sec for energy comparison
3-
4-
use erock::gpu_kernels::GpuKernels;
5-
use std::time::Instant;
1+
// src/bin/gpu_benchmark.rs (entire file)
2+
use anyhow::Result;
3+
use half::f16;
64
use std::env;
5+
use std::time::Instant;
6+
7+
// Use the new GPU module type
8+
use erock::gpu_kernels::Fp16SincosModule;
9+
10+
/// Looks up a numeric command-line option of the form `--<name> <value>`.
///
/// The process arguments are scanned left to right as adjacent pairs. The
/// first pair whose first item equals `--<name>` and whose second item parses
/// as a `usize` wins. A matching flag followed by a non-numeric value is
/// skipped and the scan continues, so a later valid occurrence is still
/// honoured.
///
/// Returns `default` when no usable `--<name> <value>` pair is present.
fn parse_arg(name: &str, default: usize) -> usize {
    // Build the flag once instead of re-formatting it for every argument.
    let flag = format!("--{name}");
    let args: Vec<String> = env::args().collect();

    args.windows(2)
        .filter(|pair| pair[0] == flag)
        .find_map(|pair| pair[1].parse::<usize>().ok())
        .unwrap_or(default)
}
23+
24+
fn main() -> Result<()> {
25+
// Defaults (adjust via flags: --elements N --iters K)
26+
let elements = parse_arg("elements", 1_000_000);
27+
let iters = parse_arg("iters", 10);
728

8-
fn main() -> Result<(), Box<dyn std::error::Error>> {
9-
env_logger::init();
10-
11-
// Parse batch size from args (default 10M for L4 saturation)
12-
let batch_size: usize = env::args()
13-
.nth(1)
14-
.and_then(|s| s.parse().ok())
15-
.unwrap_or(10_000_000);
16-
17-
println!("Initializing GPU kernels...");
18-
let kernels = GpuKernels::new()?;
19-
20-
// Generate test data
21-
println!("Generating {} test values...", batch_size);
22-
let input: Vec<f32> = (0..batch_size)
23-
.map(|i| (i as f32) * 0.001)
24-
.collect();
25-
26-
// Warmup run
27-
println!("Warmup run...");
28-
let _ = kernels.eval_sincos_fp16(&input)?;
29-
30-
// Benchmark run
31-
println!("Running GPU benchmark...");
29+
println!("gpu_benchmark: elements={}, iters={}", elements, iters);
30+
31+
// Prepare input
32+
let input = vec![f16::from_f32(1.0); elements];
33+
34+
// Initialize GPU module
35+
let module = Fp16SincosModule::new()?;
36+
37+
// Warmup
38+
let _ = module.launch(&input, elements)?;
39+
40+
// Timed runs
3241
let start = Instant::now();
33-
let results = kernels.eval_sincos_fp16(&input)?;
34-
let elapsed = start.elapsed();
35-
36-
// Calculate metrics
37-
let ops = input.len() as f64;
38-
let ops_per_sec = ops / elapsed.as_secs_f64();
39-
let gops_per_sec = ops_per_sec / 1e9;
40-
41-
println!("\n=== GPU FP16 Benchmark Results ===");
42-
println!("Operations: {:>15}", input.len());
43-
println!("Time: {:>12.3} s", elapsed.as_secs_f64());
44-
println!("Throughput: {:>12.2} B ops/sec", gops_per_sec);
45-
println!("First result: {:>12.6}", results[0]);
46-
println!("Last result: {:>12.6}", results[results.len()-1]);
47-
println!("==================================\n");
48-
42+
for _ in 0..iters {
43+
let _out = module.launch(&input, elements)?;
44+
}
45+
let elapsed = start.elapsed().as_secs_f64();
46+
47+
// Very simple ops estimate: 2 results (sin, cos) per element per iter
48+
let total_outputs = (elements as f64) * (iters as f64) * 2.0;
49+
let ops_per_sec = total_outputs / elapsed.max(1e-9);
50+
51+
println!("elapsed_sec={:.6}", elapsed);
52+
println!("outputs={} (sin+cos per element per iter)", total_outputs as u64);
53+
println!("throughput_ops_per_sec={:.3}", ops_per_sec);
54+
4955
Ok(())
5056
}

src/gpu_kernels.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
1-
use anyhow::Result;
2-
use cudarc::driver::{CudaDevice, CudaSlice, CudaModule, CudaFunction, LaunchConfig};
1+
use anyhow::Result;
2+
use cudarc::driver::{CudaDevice, CudaSlice, CudaModule, CudaFunction, LaunchConfig};
33
use half::f16;
44
use std::sync::Arc;
55

6-
#[cfg
6+
/// GPU-side state for the FP16 sin/cos kernel.
///
/// Only compiled when the `gpu` cargo feature is enabled.
#[cfg(feature = "gpu")]
pub struct Fp16SincosModule {
    /// Shared handle to the CUDA device.
    device: Arc<CudaDevice>,
    /// CUDA module handle.
    // NOTE(review): presumably built from the fp16_sincos PTX; the constructor
    // is truncated in this view, so confirm.
    module: CudaModule,
    /// CUDA function handle.
    // NOTE(review): presumably the `fp16_sincos_kernel` entry point; confirm.
    func: CudaFunction,
}
712

13+
#[cfg(feature = "gpu")]
14+
impl Fp16SincosModule {
15+
pub fn new() -> Result<Self> {
16+
let device = Arc::new(CudaDevice::new(0)?);
17+
let ptx_src = include_str("

src/ptx/fp16_sincos_kernel.ptx

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//
// fp16_sincos_kernel
//
// One thread per input element. Each thread loads one FP16 value, widens it
// to FP32, evaluates sin.approx / cos.approx, narrows both results back to
// FP16 and stores them.
//
// Parameters:
//   param_0 (u64)  pointer to the FP16 input buffer  (n halves)
//   param_1 (u64)  pointer to the FP16 output buffer
//   param_2 (u32)  element count n
//
// Output layout: results are stored as interleaved (sin, cos) pairs, i.e.
// out[2*i] = sin(in[i]) and out[2*i + 1] = cos(in[i]).
// NOTE(review): this assumes the host allocates 2*n halves for the output and
// expects the interleaved layout -- confirm against the host-side launch code.
// Writing sin at half i and cos at half i+1 would make neighbouring threads
// race on the same slot.
//
// The sm_89 target requires PTX ISA 7.8 or newer.
//
.version 7.8
.target sm_89
.address_size 64

.visible .entry fp16_sincos_kernel(
    .param .u64 fp16_sincos_kernel_param_0,
    .param .u64 fp16_sincos_kernel_param_1,
    .param .u32 fp16_sincos_kernel_param_2
)
{
    .reg .pred  %p<2>;
    .reg .b16   %rs<4>;
    .reg .f32   %f<4>;
    .reg .b32   %r<6>;
    .reg .b64   %rd<9>;

    ld.param.u64        %rd1, [fp16_sincos_kernel_param_0];
    ld.param.u64        %rd2, [fp16_sincos_kernel_param_1];
    ld.param.u32        %r1, [fp16_sincos_kernel_param_2];

    // Global thread index: ctaid.x * ntid.x + tid.x.
    // Special registers are read through mov before being used as operands.
    mov.u32             %r2, %ctaid.x;
    mov.u32             %r3, %ntid.x;
    mov.u32             %r4, %tid.x;
    mad.lo.s32          %r5, %r2, %r3, %r4;

    // Threads at or beyond the element count do nothing.
    setp.ge.s32         %p1, %r5, %r1;
    @%p1 bra            $L__fp16_sincos_done;

    cvta.to.global.u64  %rd3, %rd1;
    cvta.to.global.u64  %rd4, %rd2;

    // Input address = in + 2 * idx (one half per element). Address operands
    // only accept register + immediate, so the offset is computed explicitly.
    mul.wide.s32        %rd5, %r5, 2;
    add.s64             %rd6, %rd3, %rd5;
    ld.global.u16       %rs1, [%rd6];

    // Widening f16 -> f32 is exact, so no rounding modifier is used.
    cvt.f32.f16         %f1, %rs1;
    sin.approx.f32      %f2, %f1;
    cos.approx.f32      %f3, %f1;

    // Output address = out + 4 * idx (two halves per element).
    mul.wide.s32        %rd7, %r5, 4;
    add.s64             %rd8, %rd4, %rd7;

    cvt.rn.f16.f32      %rs2, %f2;
    st.global.u16       [%rd8], %rs2;
    cvt.rn.f16.f32      %rs3, %f3;
    st.global.u16       [%rd8+2], %rs3;

$L__fp16_sincos_done:
    ret;
}

0 commit comments

Comments
 (0)