From c822264329c3e124de75b4083dd91eb36365762b Mon Sep 17 00:00:00 2001 From: Matthew Gordon Date: Fri, 3 Jan 2025 15:46:37 -0400 Subject: [PATCH] Raycasting works! --- src/app/dem_renderer/dem_renderer.wgsl | 8 +- src/app/dem_renderer/mod.rs | 9 +- src/app/dem_renderer/ray_intersection.wgsl | 64 ++++++--- src/app/dem_renderer/tests.rs | 153 +++++++++++++++++++++ src/app/dem_renderer/tests.wgsl | 33 +++++ src/app/mod.rs | 3 +- 6 files changed, 242 insertions(+), 28 deletions(-) create mode 100644 src/app/dem_renderer/tests.rs create mode 100644 src/app/dem_renderer/tests.wgsl diff --git a/src/app/dem_renderer/dem_renderer.wgsl b/src/app/dem_renderer/dem_renderer.wgsl index 6faf98b..fe1f759 100644 --- a/src/app/dem_renderer/dem_renderer.wgsl +++ b/src/app/dem_renderer/dem_renderer.wgsl @@ -42,12 +42,16 @@ fn fs_solid(vertex: VertexOutput) -> @location(0) vec4 { root_node.index = vec2(0); root_node.level = textureNumLevels(dembvh_texture) - 1; + var hit_index = vec2(0); if intersect_ray_with_node(dembvh_texture, uniforms.dem_min_corner, uniforms.dem_cell_size, + uniforms.dem_z_range, ray, - root_node) { - return vec4(1.0); + root_node, + &hit_index) { + let v_i = textureLoad(dem_texture, hit_index, 0).r; + return vec4(vec3(v_i) / 65535.0, 1.0); } else { discard; } diff --git a/src/app/dem_renderer/mod.rs b/src/app/dem_renderer/mod.rs index 0eee1f5..12e1351 100644 --- a/src/app/dem_renderer/mod.rs +++ b/src/app/dem_renderer/mod.rs @@ -175,9 +175,9 @@ impl DemRenderer { let mut uniforms = UniformBufferManager::new(device); uniforms.set_dem_min_corner(glam::Vec2::new(source.x_min, source.y_min)); uniforms.set_dem_cell_size( - glam::Vec2::new(source.x_max, source.y_max) - - glam::Vec2::new(source.x_min, source.y_min) - / glam::Vec2::new(source.num_cells_x as f32, source.num_cells_y as f32), + (glam::Vec2::new(source.x_max, source.y_max) + - glam::Vec2::new(source.x_min, source.y_min)) + / glam::Vec2::new(source.num_cells_x as f32, source.num_cells_y as f32), ); uniforms.set_dem_z_range(glam::Vec2::new(source.z_min, source.z_max)); @@ -489,3 +489,6 @@ fn get_animated_camera_position(animation_time: std::time::Duration, dem_size: f dem_size * 0.7, ) } + +#[cfg(test)] +mod tests; diff --git a/src/app/dem_renderer/ray_intersection.wgsl b/src/app/dem_renderer/ray_intersection.wgsl index 397d68f..86d845a 100644 --- a/src/app/dem_renderer/ray_intersection.wgsl +++ b/src/app/dem_renderer/ray_intersection.wgsl @@ -12,25 +12,37 @@ struct AABB { struct Ray { origin: vec3, direction: vec3 }; -fn intersect_ray_with_aabb(ray: Ray, aabb: AABB) -> bool { - let t1 = (aabb.min_corner - ray.origin) / ray.direction; - let t2 = (aabb.max_corner - ray.origin) / ray.direction; +fn invert_ray_direction(v: vec3 ) -> vec3 { + return select(vec3(1.0e+30), + vec3(1.0) / v, + vec3(v)); +} + +fn intersect_ray_with_aabb(ray: Ray, aabb: AABB) -> f32 { + return intersect_ray_with_aabb_optimized(ray.origin, + invert_ray_direction(ray.direction), + aabb); +} + +fn intersect_ray_with_aabb_optimized(ray_origin: vec3, inv_ray_direction: vec3, aabb: AABB) -> f32 { + let t1 = (aabb.min_corner - ray_origin) * inv_ray_direction; + let t2 = (aabb.max_corner - ray_origin) * inv_ray_direction; let t_mins = min(t1, t2); - let t_min = min(t_mins.x, min(t_mins.y, t_mins.z)); + let t_min = max(t_mins.x, max(t_mins.y, t_mins.z)); let t_maxs = max(t1, t2); - let t_max = max(t_maxs.x, max(t_maxs.y, t_maxs.z)); - return t_min <= t_max; + let t_max = min(t_maxs.x, min(t_maxs.y, t_maxs.z)); + return select(-1.0, max(t_min, 0.0f), t_min <= t_max); } fn get_xy_min_for_node(dem_min_corner: vec2, - dem_cell_size: vec2, - node: BoundingNode) -> vec2 { + dem_cell_size: vec2, + node: BoundingNode) -> vec2 { return dem_min_corner.xy + dem_cell_size * vec2(node.index) * exp2(f32(node.level)); } fn get_xy_max_for_node(dem_min_corner: vec2, - dem_cell_size: vec2, - node: BoundingNode) -> vec2 { + dem_cell_size: vec2, + node: BoundingNode) -> vec2 { return dem_min_corner.xy + dem_cell_size * vec2(node.index + 1) * exp2(f32(node.level)); } @@ -52,31 +64,39 @@ fn pop_node_stack(stack: ptr) -> BoundingNode { fn intersect_ray_with_node(tree_texture: texture_2d, dem_min_corner: vec2, dem_cell_size: vec2, + dem_z_range: vec2, ray: Ray, - root_node: BoundingNode) -> bool { + root_node: BoundingNode, + hit_cell: ptr>) -> bool { + let inv_ray_direction = invert_ray_direction(ray.direction); var node_stack: NodeStack; node_stack.count = 0u; push_node_stack(&node_stack, root_node); + var closest_hit_distance = 1.0e+30f; while node_stack.count > 0 { let node = pop_node_stack(&node_stack); let minmax_z = textureLoad(tree_texture, node.index, i32(node.level)).rg; if minmax_z.r == 0 { return false; } - let min_z = scale_u16(minmax_z.r, uniforms.dem_z_range); - let max_z = scale_u16(minmax_z.g, uniforms.dem_z_range); + let min_z = scale_u16(minmax_z.r, dem_z_range); + let max_z = scale_u16(minmax_z.g, dem_z_range); var aabb: AABB ; aabb.min_corner = vec3(get_xy_min_for_node(dem_min_corner, - dem_cell_size, - node), - min_z); + dem_cell_size, + node), + min_z); aabb.max_corner = vec3(get_xy_max_for_node(dem_min_corner, - dem_cell_size, - node), - max_z); - if intersect_ray_with_aabb(ray, aabb) { + dem_cell_size, + node), + max_z); + let hit_distance = intersect_ray_with_aabb_optimized(ray.origin, inv_ray_direction, aabb); + if hit_distance >= 0.0 { if node.level == 0 { - return true; + if hit_distance < closest_hit_distance { + closest_hit_distance = hit_distance; + *hit_cell = node.index; + } } else { let next_index = node.index * 2; var next_node: BoundingNode; @@ -92,5 +112,5 @@ fn intersect_ray_with_node(tree_texture: texture_2d, } } } - return false; + return closest_hit_distance < 1.0e+20; } diff --git a/src/app/dem_renderer/tests.rs b/src/app/dem_renderer/tests.rs new file mode 100644 index 0000000..e2876b2 --- /dev/null +++ b/src/app/dem_renderer/tests.rs @@ -0,0 +1,153 @@ +use futures::executor::block_on; +use std::sync::mpsc::channel; + +use wgsl_shader_assembler::wgsl_module; + +use super::*; + +async fn run_compute_shader_test( + shader_module: wgpu::ShaderModuleDescriptor<'_>, + test_function: impl AsRef, + input: &[u8], + output_size: usize, +) -> Vec { + let wgpu_instance = wgpu::Instance::default(); + let adapter = wgpu_instance + .request_adapter(&wgpu::RequestAdapterOptions::default()) + .await + .unwrap(); + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + required_features: wgpu::Features::empty(), + required_limits: wgpu::Limits::downlevel_defaults(), + memory_hints: wgpu::MemoryHints::MemoryUsage, + }, + None, + ) + .await + .unwrap(); + + let shader_module = device.create_shader_module(shader_module); + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: None, + contents: input, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &shader_module, + entry_point: Some(test_function.as_ref()), + compilation_options: Default::default(), + cache: None, + }); + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: output_buffer.as_entire_binding(), + }, + ], + }); + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(1, 1, 1); + } + encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size as u64); + queue.submit(Some(encoder.finish())); + + let (tx, rx) = channel(); + let buffer_slice = staging_buffer.slice(..); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| tx.send(v).unwrap()); + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + if let Ok(Ok(())) = rx.recv() { + buffer_slice.get_mapped_range().to_vec() + } else { + panic!("failed to run test on GPU") + } +} + +#[repr(C, align(16))] +#[derive(Clone, Copy, Pod, Zeroable)] +struct Vec3 { + elements: [f32; 3], + _padding: f32 +} + +#[repr(C)] +#[derive(Clone, Copy, Pod, Zeroable)] +struct TestInput { + ray_origin: Vec3, + ray_direction: Vec3, + aabb_min_corner: Vec3, + aabb_max_corner: Vec3, +} + +fn vec3(x: f32, y: f32, z: f32) -> Vec3 { + Vec3 { + elements: [x, y, z], + _padding: 0.0 + } +} + +#[test] +fn test_shaders() { + let input_buffer = vec![ + TestInput { + ray_origin: vec3(-5.0, 0.0, 0.0), + ray_direction: vec3(1.0, 0.0, 0.0), + aabb_min_corner: vec3(-1.0, -1.0, -1.0), + aabb_max_corner: vec3(1.0, 1.0, 1.0), + }, + TestInput { + ray_origin: vec3(-5.0, 0.0, 0.0), + ray_direction: vec3(1.0, 0.0, 0.0), + aabb_min_corner: vec3(-1.0, 0.5, -1.0), + aabb_max_corner: vec3(1.0, 1.0, 1.0), + }, + TestInput { + ray_origin: vec3(-5.0, 0.0, 0.0), + ray_direction: vec3(1.0, 0.15, 0.0), + aabb_min_corner: vec3(-1.5, 0.5, -1.0), + aabb_max_corner: vec3(1.0, 1.0, 1.0), + }, + ]; + let output_buffer: Vec = bytemuck::cast_slice(&block_on(run_compute_shader_test( + wgsl_module!("tests.wgsl"), + "test_intersect_ray_with_aabb", + bytemuck::cast_slice(&input_buffer), + input_buffer.len() * 4, + ))) + .to_vec(); + assert_eq!(output_buffer[0], 1); + assert_eq!(output_buffer[1], 0); + assert_eq!(output_buffer[2], 1); +} diff --git a/src/app/dem_renderer/tests.wgsl b/src/app/dem_renderer/tests.wgsl new file mode 100644 index 0000000..de0cf5d --- /dev/null +++ b/src/app/dem_renderer/tests.wgsl @@ -0,0 +1,33 @@ +@group(0) +@binding(0) +var input_data: array; +@group(0) +@binding(1) +var output_data: array; + +//#include ray_intersection.wgsl + +struct Input { + ray_origin: vec3, + ray_direction: vec3, + aabb_min_corner: vec3, + aabb_max_corner: vec3 +} + +@compute +@workgroup_size(1) +fn test_intersect_ray_with_aabb() { + for(var i=0u; i < arrayLength(&input_data); i++) { + var ray: Ray; + ray.origin = input_data[i].ray_origin; + ray.direction = input_data[i].ray_direction; + var aabb: AABB; + aabb.min_corner = input_data[i].aabb_min_corner; + aabb.max_corner = input_data[i].aabb_max_corner; + if intersect_ray_with_aabb(ray, aabb) >= 0.0 { + output_data[i] = 1u; + }else { + output_data[i] = 0u; + } + } +} diff --git a/src/app/mod.rs b/src/app/mod.rs index 5a3588f..824039d 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -46,12 +46,13 @@ impl MvuApp for App { async fn init(&mut self, instance: &Instance, surface: Surface<'static>, size: Size2i) { let adapter = instance .request_adapter(&wgpu::RequestAdapterOptions { - power_preference: wgpu::PowerPreference::default(), + power_preference: wgpu::PowerPreference::HighPerformance, force_fallback_adapter: false, compatible_surface: Some(&surface), }) .await .expect("Failed to find an appropriate adapter"); + eprintln!("Using {}", adapter.get_info().name); let (device, queue) = adapter .request_device(