Skip to content

Commit

Permalink
hzb copy depth has multisampled shader
Browse files Browse the repository at this point in the history
  • Loading branch information
schell committed Oct 15, 2024
1 parent a5c2650 commit 26fa023
Show file tree
Hide file tree
Showing 18 changed files with 503 additions and 140 deletions.
8 changes: 5 additions & 3 deletions crates/example-culling/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! An example app showing (and verifying) how frustum culling works in `renderling`.
//! An example app showing (and verifying) how frustum culling works in
//! `renderling`.
use std::{any::Any, sync::Arc};

use example::{camera::CameraController, utils::*};
Expand Down Expand Up @@ -123,6 +124,7 @@ impl CullingExample {
app_camera.0.id(),
if BoundingSphere::from(aabb)
.is_inside_camera_view(&frustum_camera.0, transform.get())
.0
{
material_overlapping.id()
} else {
Expand Down Expand Up @@ -230,8 +232,8 @@ impl TestAppHandler for CullingExample {
let target = Vec3::ZERO;
let up = Vec3::Y;
let view = Mat4::look_at_rh(eye, target, up);
// let projection = Mat4::orthographic_rh(-10.0, 10.0, -10.0, 10.0, -10.0, 10.0);
// let view = Mat4::IDENTITY;
// let projection = Mat4::orthographic_rh(-10.0, 10.0, -10.0, 10.0, -10.0,
// 10.0);
// let view = Mat4::IDENTITY;
Camera::new(projection, view)
});

Expand Down
11 changes: 9 additions & 2 deletions crates/renderling/src/bvol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,20 @@ impl BoundingSphere {
}
}

pub fn is_inside_camera_view(&self, camera: &Camera, transform: Transform) -> bool {
/// Determine whether this sphere is inside the camera's view frustum after
/// being transformed by `transform`.
pub fn is_inside_camera_view(
&self,
camera: &Camera,
transform: Transform,
) -> (bool, BoundingSphere) {
let center = Mat4::from(transform).transform_point3(self.center);
let scale = Vec3::splat(transform.scale.max_element());
let radius = Mat4::from_scale(scale)
.transform_point3(Vec3::new(self.radius, 0.0, 0.0))
.distance(Vec3::ZERO);
BoundingSphere::new(center, radius).is_inside_frustum(camera.frustum())
let sphere = BoundingSphere::new(center, radius);
(sphere.is_inside_frustum(camera.frustum()), sphere)
}
}

Expand Down
11 changes: 1 addition & 10 deletions crates/renderling/src/camera.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Camera projection, view and utilities.
use crabslab::SlabItem;
use glam::{Mat4, Vec3, Vec4Swizzles};
use glam::{Mat4, Vec3};

use crate::bvol::{dist_bpp, Frustum};

Expand Down Expand Up @@ -88,15 +88,6 @@ impl Camera {
pub fn z_far(&self) -> f32 {
dist_bpp(&self.frustum.planes[5], self.position)
}

/// Linearize and normalize a depth value.
///
/// The camera's near and far distances are taken from [`Self::z_near`] and
/// [`Self::z_far`]. `depth` is first mapped through
/// `(2 * z_near) / (z_far + z_near - depth * (z_far - z_near))`, and the
/// result is then rescaled by `(z_linear - z_near) / (z_far - z_near)`.
///
/// NOTE(review): the expected range of `depth` (e.g. `[0, 1]` vs `[-1, 1]`)
/// is not established here -- confirm against callers.
pub fn linearize_depth_value(&self, depth: f32) -> f32 {
let z_near = self.z_near();
let z_far = self.z_far();
// Map the non-linear depth value using the near/far plane distances.
let z_linear = (2.0 * z_near) / (z_far + z_near - depth * (z_far - z_near));
// Normalize the linearized depth to [0, 1]
(z_linear - z_near) / (z_far - z_near)
}
}

/// Returns the projection and view matrices for a camera with default
Expand Down
144 changes: 101 additions & 43 deletions crates/renderling/src/cull.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
//! Frustum culling as explained in
//! [the vulkan guide](https://vkguide.dev/docs/gpudriven/compute_culling/).
use crabslab::{Array, Id, Slab, SlabItem};
use glam::{UVec2, UVec3, Vec3Swizzles};
use glam::{UVec2, UVec3, Vec3, Vec3Swizzles};
#[allow(unused_imports)]
use spirv_std::num_traits::Float;
use spirv_std::{
arch::IndexUnchecked,
image::{sample_with, Image, ImageWithMethods},
spirv,
spirv, Sampler,
};

use crate::draw::DrawIndirectArgs;
Expand All @@ -18,9 +20,10 @@ mod cpu;
pub use cpu::*;

#[spirv(compute(threads(32)))]
pub fn compute_frustum_culling(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &[u32],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] args: &mut [DrawIndirectArgs],
pub fn compute_culling(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] stage_slab: &[u32],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] depth_pyramid_slab: &[u32],
#[spirv(storage_buffer, descriptor_set = 0, binding = 2)] args: &mut [DrawIndirectArgs],
#[spirv(global_invocation_id)] global_id: UVec3,
) {
let gid = global_id.x as usize;
Expand All @@ -31,21 +34,68 @@ pub fn compute_frustum_culling(
// Get the draw arg
let arg = unsafe { args.index_unchecked_mut(gid) };
// Get the renderlet using the draw arg's renderlet id
let renderlet = slab.read_unchecked(arg.first_instance);
let renderlet = stage_slab.read_unchecked(arg.first_instance);

arg.vertex_count = renderlet.get_vertex_count();
arg.instance_count = if renderlet.visible { 1 } else { 0 };

if renderlet.bounds.radius == 0.0 {
return;
}
let camera = slab.read(renderlet.camera_id);
let model = slab.read(renderlet.transform_id);
if !renderlet.bounds.is_inside_camera_view(&camera, model) {
let camera = stage_slab.read(renderlet.camera_id);
let model = stage_slab.read(renderlet.transform_id);
// Compute frustum culling, and then occlusion culling, if need be
let (renderlet_is_inside_frustum, sphere_in_world_coords) =
renderlet.bounds.is_inside_camera_view(&camera, model);
if renderlet_is_inside_frustum {
// Compute occlusion culling using the hierarchical z-buffer.
let hzb_desc = depth_pyramid_slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
let viewprojection = camera.view_projection();

// Find the center and radius of the bounding sphere in screen space, where
// (0, 0) is the top-left of the screen and (1, 1) is the bottom-right.
//
// z = 0 is near and z = 1 is far.
let center_ndc = viewprojection.project_point3(sphere_in_world_coords.center);
let center = Vec3::new(
(center_ndc.x + 1.0) * 0.5,
(center_ndc.y + 1.0) * -0.5,
center_ndc.z,
);
// Find the radius (in screen space)
let radius = viewprojection
.project_point3(Vec3::new(sphere_in_world_coords.radius, 0.0, 0.0))
.distance(Vec3::ZERO);
let size_max_element = if hzb_desc.size.x > hzb_desc.size.y {
hzb_desc.size.x
} else {
hzb_desc.size.y
} as f32;
let size_in_pixels = 2.0 * radius * size_max_element;
let mip_level = size_in_pixels.log2().floor() as u32;
let x = center.x * (hzb_desc.size.x >> mip_level) as f32;
let y = center.y * (hzb_desc.size.y >> mip_level) as f32;
let depth_id = hzb_desc.id_of_depth(
mip_level,
UVec2::new(x as u32, y as u32),
depth_pyramid_slab,
);

let depth_in_hzb = depth_pyramid_slab.read_unchecked(depth_id);
let depth_of_sphere = center.z - radius;
let renderlet_is_behind_something = depth_of_sphere > depth_in_hzb;

if renderlet_is_behind_something {
arg.instance_count = 0;
}
} else {
arg.instance_count = 0;
}
}

/// A hierarchical depth buffer.
///
/// AKA HZB
#[derive(Clone, Copy, Default, SlabItem)]
pub struct DepthPyramidDescriptor {
/// Size of the top layer mip.
Expand All @@ -68,6 +118,11 @@ impl DepthPyramidDescriptor {
!(global_invocation.x < current_size.x && global_invocation.y < current_size.y)
}

/// Return the size of the mip at `mip_level`, computed by shifting each
/// dimension of the top-level size right by `mip_level` bits.
#[cfg(test)]
fn size_at(&self, mip_level: u32) -> UVec2 {
    let UVec2 {
        x: width,
        y: height,
    } = self.size;
    UVec2::new(width >> mip_level, height >> mip_level)
}

/// Return the [`Id`] of the depth at the given `mip_level` and coordinate.
fn id_of_depth(&self, mip_level: u32, coord: UVec2, slab: &[u32]) -> Id<f32> {
let mip_array = slab.read(self.mip.at(mip_level as usize));
Expand All @@ -77,9 +132,8 @@ impl DepthPyramidDescriptor {
}
}

pub type DepthImage2d = Image!(2D, type=f32, sampled=true, depth=true);
pub type DepthPyramidImage = Image!(2D, format = r32f, sampled = true, depth = false);
pub type DepthPyramidImageMut = Image!(2D, format = r32f, depth = false);
pub type DepthImage2d = Image!(2D, type=f32, sampled, depth);
pub type DepthImage2dMultisampled = Image!(2D, type=f32, sampled, depth, multisampled);

/// Copies a depth texture to the top mip of a pyramid of mips.
///
Expand All @@ -103,14 +157,36 @@ pub fn compute_copy_depth_to_pyramid(
slab.write(dest_id, &depth);
}

/// Downsample from `DepthPyramidDescriptor::mip_level` into
/// `DepthPyramidDescriptor::mip_level + 1`.
/// Copies a multisampled depth texture into mip `0` of the depth pyramid.
///
/// A [`DepthPyramidDescriptor`] is read from index `0` of the given slab.
/// Invocations for which the descriptor's `should_skip_invocation` returns
/// `true` do nothing.
///
/// Only sample index `0` of each texel is fetched.
#[spirv(compute(threads(32, 32, 1)))]
pub fn compute_copy_depth_to_pyramid_multisampled(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] slab: &mut [u32],
    #[spirv(descriptor_set = 0, binding = 1)] depth_texture: &DepthImage2dMultisampled,
    #[spirv(global_invocation_id)] global_id: UVec3,
) {
    // The pyramid descriptor is read from the very start of the slab.
    let pyramid = slab.read_unchecked::<DepthPyramidDescriptor>(0u32.into());
    if pyramid.should_skip_invocation(global_id) {
        return;
    }

    // Each invocation handles exactly one texel of the source texture.
    let texel = global_id.xy();
    let sampled = depth_texture.fetch_with(texel, sample_with::sample_index(0));

    // Store the depth in the matching texel of mip level 0.
    let top_mip_texel_id = pyramid.id_of_depth(0, texel, slab);
    slab.write(top_mip_texel_id, &sampled.x);
}

/// Downsample from `DepthPyramidDescriptor::mip_level-1` into
/// `DepthPyramidDescriptor::mip_level`.
///
/// It is assumed that a [`DepthPyramidDescriptor`] is stored at index `0` in
/// the given slab.
///
/// The `DepthPyramidDescriptor`'s `mip_level` field will point to that of the
/// level being sampled.
/// mip level being downsampled to (the mip level being written into).
///
/// This shader should be called in a loop from `1..mip_count`.
#[spirv(compute(threads(32, 32, 1)))]
Expand All @@ -128,32 +204,14 @@ pub fn compute_downsample_depth_pyramid(
//
// a b
// c d
let a = slab.read(desc.id_of_depth(desc.mip_level, global_id.xy(), slab));
let b = slab.read(desc.id_of_depth(desc.mip_level, global_id.xy() + UVec2::new(1, 0), slab));
let c = slab.read(desc.id_of_depth(desc.mip_level, global_id.xy() + UVec2::new(0, 1), slab));
let d = slab.read(desc.id_of_depth(desc.mip_level, global_id.xy() + UVec2::new(1, 1), slab));
let depth_value = a.min(b).min(c).min(d);
// Write the texel in the current mip
let current_mip_id = desc.id_of_depth(desc.mip_level + 1, global_id.xy() / 2, slab);
slab.write(current_mip_id, &depth_value);
}

#[spirv(compute(threads(32)))]
pub fn compute_occlusion_culling(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &[u32],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] args: &mut [DrawIndirectArgs],
#[spirv(global_invocation_id)] global_id: UVec3,
) {
let gid = global_id.x as usize;
if gid >= args.len() {
return;
}

// Get the draw arg
let arg = unsafe { args.index_unchecked_mut(gid) };
// Get the renderlet using the draw arg's renderlet id
let renderlet = slab.read_unchecked(arg.first_instance);

arg.vertex_count = renderlet.get_vertex_count();
arg.instance_count = if renderlet.visible { 1 } else { 0 };
let a_coord = global_id.xy() * 2;
let a = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord, slab));
let b = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 0), slab));
let c = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(0, 1), slab));
let d = slab.read(desc.id_of_depth(desc.mip_level - 1, a_coord + UVec2::new(1, 1), slab));
// Take the maximum depth of the region (max depth means furthest away)
let depth_value = a.max(b).max(c).max(d);
// Write the texel in the next mip
let depth_id = desc.id_of_depth(desc.mip_level, global_id.xy(), slab);
slab.write(depth_id, &depth_value);
}
Loading

0 comments on commit 26fa023

Please sign in to comment.