From 4f1fb083261f4ae74014f229c80d1c223a6b8c2a Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 24 Mar 2026 17:17:51 -0400 Subject: [PATCH 1/9] Fix incident direction gradient and exclusive gradient --- .../3dgrt/kernels/cuda/3dgrtTracer.cuh | 6 +-- .../slang/models/shRadiativeParticles.slang | 17 ++++++--- .../kernels/cuda/referenceB2FSlangBwdOptix.cu | 2 +- .../kernels/cuda/referenceSlangBwdOptix.cu | 10 ++--- .../src/kernels/cuda/referenceSlangOptix.cu | 6 +-- .../models/shRadiativeGaussianParticles.cuh | 37 +++++++++++-------- .../kernels/cuda/renderers/gutProjector.cuh | 15 ++++---- .../slang/models/gaussianParticles.slang | 8 ++-- .../slang/models/shRadiativeParticles.slang | 17 ++++++--- 9 files changed, 69 insertions(+), 49 deletions(-) diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh index f799b647..b443977e 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh @@ -165,7 +165,7 @@ static __device__ __inline__ void traceVolumetricGS( rayOrigin, rayDirection, rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, &rayTransmittance, &rayData.hitDistance, #ifdef ENABLE_NORMALS @@ -178,7 +178,7 @@ static __device__ __inline__ void traceVolumetricGS( particleFeaturesIntegrateFwdFromBuffer(rayDirection, hitWeight, rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr}, params.sphDegree}, + {{(float3*)params.particleRadiance, nullptr, true}, params.sphDegree}, &rayData.radiance); rayLastHitDistance = fmaxf(rayLastHitDistance, rayHit.distance); @@ -205,7 +205,7 @@ static __device__ __inline__ void intersectVolumetricGS() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang index f2755a60..d6ef7d73 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang @@ -76,14 +76,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, [BackwardDifferentiable] [ForceInline] vector radianceFromBuffer(no_diff uint32_t particleIdx, - no_diff float3 incidentDirection, + in float3 incidentDirection, no_diff uint32_t sphDegree, no_diff ParametersBuffer parametersBuffer) { return sphericalHarmonics.decode( - sphDegree, - fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, - incidentDirection); + sphDegree, + fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, + incidentDirection); } [BackwardDifferentiable][ForceInline] @@ -248,13 +248,18 @@ inline void particleFeaturesIntegrateFwd(in float weight, in uint32_t particleIdx, shRadiativeParticle.CommonParameters commonParameters, in vector featuresGrad, - in float3 incidentDirection + in float3 incidentDirection, + inout float3 incidentDirectionGrad ) { + DifferentialPair incidentDirectionDiff = DifferentialPair(incidentDirection, incidentDirectionGrad); + bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, - incidentDirection, + incidentDirectionDiff, commonParameters.sphDegree, commonParameters.parametersBuffer, featuresGrad); + + incidentDirectionGrad += incidentDirectionDiff.getDifferential(); } diff --git a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu index 8dda03ca..7459ff42 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu @@ -185,7 +185,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu index 06dbe09b..2e983fc6 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu @@ -153,7 +153,7 @@ extern "C" __global__ void __raygen__rg() { if (particleDensityHit( rayOrigin, rayDirection, particleDensityParameters(rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}), + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}), &hitAlpha, &hitDistance, #ifdef ENABLE_NORMALS true, &hitNormal @@ -165,7 +165,7 @@ extern "C" __global__ void __raygen__rg() { ) { const float3 hitRadiance = particleFeaturesFromBuffer( rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad}, (int)params.sphDegree}, + {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad, false}, (int)params.sphDegree}, rayDirection); float hitAlphaGrad = 0.f; @@ -173,7 +173,7 @@ extern "C" __global__ void __raygen__rg() { hitAlpha, &hitAlphaGrad, rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad}, params.sphDegree}, + {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad, false}, params.sphDegree}, hitRadiance, &rayRadiance, &rayRadianceGrad); @@ -182,7 +182,7 @@ extern "C" __global__ void __raygen__rg() { rayDirection, rayHit.particleId, {{(gaussianParticle_RawParameters_0*)params.particleDensity, - (gaussianParticle_RawParameters_0*)params.particleDensityGrad}}, + (gaussianParticle_RawParameters_0*)params.particleDensityGrad, false}}, hitAlpha, hitAlphaGrad, &rayTransmittance, @@ -215,7 +215,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu index 86a82893..5593d11f 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu @@ -140,7 +140,7 @@ extern "C" __global__ void __raygen__rg() { rayOrigin, rayDirection, rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, &rayTransmittance, &rayHitDistance, #ifdef ENABLE_NORMALS @@ -153,7 +153,7 @@ extern "C" __global__ void __raygen__rg() { particleFeaturesIntegrateFwdFromBuffer(rayDirection, hitWeight, rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr}, params.sphDegree}, + {{(float3*)params.particleRadiance, nullptr, true}, params.sphDegree}, &rayRadiance); // NOTE(qi): Race condition here, but as we are writing the same value, it seems it is safe. @@ -197,7 +197,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh index a4aaed5f..d25c461d 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh @@ -65,7 +65,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams __forceinline__ __device__ DensityParameters fetchDensityParameters(uint32_t particleIdx) const { const auto parameters = particleDensityParameters( particleIdx, - {reinterpret_cast(m_densityRawParameters.ptr), nullptr}); + {{reinterpret_cast(m_densityRawParameters.ptr), nullptr, true}}); return *reinterpret_cast(¶meters); } @@ -218,14 +218,18 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams template __forceinline__ __device__ void densityIncidentDirectionBwdToBuffer(uint32_t particlesIdx, - const tcnn::vec3& sourcePosition) - + const tcnn::vec3& sourcePosition, + const tcnn::vec3& incidentDirectionGrad) { - particleDensityIncidentDirectionBwdToBuffer(particlesIdx, - {{reinterpret_cast(m_densityRawParameters.ptr), - reinterpret_cast(m_densityRawParameters.gradPtr), - exclusiveGradient}}, - *reinterpret_cast(&sourcePosition)); + if constexpr (TDifferentiable) { + particleDensityIncidentDirectionBwdToBuffer( + particlesIdx, + {{reinterpret_cast(m_densityRawParameters.ptr), + reinterpret_cast(m_densityRawParameters.gradPtr), + exclusiveGradient}}, + *reinterpret_cast(&sourcePosition), + *reinterpret_cast(&incidentDirectionGrad)); + } } using FeaturesParameters = shRadiativeParticle_Parameters_0; @@ -264,13 +268,16 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams template __forceinline__ __device__ void featuresBwdToBuffer(uint32_t particleIdx, - const TFeaturesVec& featuresGrad, - const tcnn::vec3& incidentDirection) const { - - particleFeaturesBwdToBuffer(particleIdx, - {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, - *reinterpret_cast(&featuresGrad), - *reinterpret_cast(&incidentDirection)); + const TFeaturesVec& featuresGrad, + const tcnn::vec3& incidentDirection, + tcnn::vec3& incidentDirectionGrad) const { + if constexpr (TDifferentiable) { + particleFeaturesBwdToBuffer(particleIdx, + {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, + *reinterpret_cast(&featuresGrad), + *reinterpret_cast(&incidentDirection), + reinterpret_cast(&incidentDirectionGrad)); + } } template diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh index 6319b96f..6ba66605 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh @@ -416,16 +416,17 @@ struct GUTProjector : Params, UTParams { Particles particles; particles.initializeDensity(parameters); + particles.initializeDensityGradient(parametersGradient); const tcnn::vec3 incidentDirection = tcnn::normalize(particles.fetchPosition(particleIdx) - sensorWorldPosition); + tcnn::vec3 incidentDirectionGrad = tcnn::vec3(0.0f); particles.initializeFeatures(parameters); particles.initializeFeaturesGradient(parametersGradient); - particles.featuresBwdCustomToBuffer( - particleIdx, - reinterpret_cast(particlesPrecomputedFeaturesPtr)[particleIdx], - reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], - incidentDirection); - particles.initializeDensityGradient(parametersGradient); - particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition); + particles.featuresBwdToBuffer(particleIdx, + reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], + incidentDirection, + incidentDirectionGrad); + + particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition, incidentDirectionGrad); } }; diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang index af44b9cf..2e938e26 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang @@ -313,7 +313,7 @@ float3 incidentDirectionFromParameters( } [BackwardDifferentiable][ForceInline] -no_diff float3 incidentDirectionFromBuffer( +float3 incidentDirectionFromBuffer( no_diff uint32_t particleIdx, no_diff RawParametersBuffer parametersBuffer, no_diff float3 sourcePosition @@ -525,12 +525,14 @@ bool particleDensityHitInstance( [CudaDeviceExport] void particleDensityIncidentDirectionBwdToBuffer( in uint32_t particleIdx, gaussianParticle.CommonParameters commonParameters, - in float3 sourcePosition + in float3 sourcePosition, + in float3 incidentDirectionGrad ) { bwd_diff(gaussianParticle.incidentDirectionFromBuffer)( particleIdx, commonParameters.parametersBuffer, - sourcePosition + sourcePosition, + incidentDirectionGrad ); } diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang index 66334e49..aefe138c 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang @@ -70,14 +70,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, [BackwardDifferentiable] [ForceInline] vector radianceFromBuffer(no_diff uint32_t particleIdx, - no_diff float3 incidentDirection, + in float3 incidentDirection, no_diff uint32_t sphDegree, no_diff ParametersBuffer parametersBuffer) { return sphericalHarmonics.decode( - sphDegree, - fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, - incidentDirection); + sphDegree, + fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, + incidentDirection); } [BackwardDifferentiable][ForceInline] @@ -242,13 +242,18 @@ inline void particleFeaturesIntegrateFwd(in float weight, in uint32_t particleIdx, shRadiativeParticle.CommonParameters commonParameters, in vector featuresGrad, - in float3 incidentDirection + in float3 incidentDirection, + inout float3 incidentDirectionGrad ) { + DifferentialPair incidentDirectionDiff = DifferentialPair(incidentDirection, incidentDirectionGrad); + bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, - incidentDirection, + incidentDirectionDiff, commonParameters.sphDegree, commonParameters.parametersBuffer, featuresGrad); + + incidentDirectionGrad += incidentDirectionDiff.getDifferential(); } From 9b4e6e4b7e9c1b2ba4e7870cc1c2e1187a02a560 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 24 Mar 2026 16:47:46 -0400 Subject: [PATCH 2/9] Fixes clamp and constexpr --- .../include/3dgrt/kernels/cuda/gaussianParticles.cuh | 2 +- .../include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh index 089e4c4c..ebf7c100 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh @@ -548,7 +548,7 @@ __device__ inline void processHitBwd( // => groRayHitGrd_j = -grd_j * dot(grdsRayHitGrd * gscl, grd) const float grdScaledDot = dot(grdsRayHitGrd * gscl, grd); float3 grdRayHitGrd, groRayHitGrd; - if constexpr (SurfelPrimitive) { + if (SurfelPrimitive) { const float h = -gro.z / grd.z; grdRayHitGrd = gscl * grdsRayHitGrd * h - make_float3(0.f, 0.f, (h / grd.z) * grdScaledDot); groRayHitGrd = make_float3(0.f, 0.f, -grdScaledDot / grd.z); diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index 1390f8d6..bf73ec9b 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -127,7 +127,7 @@ struct GUTKBufferRenderer : Params { TFeaturesVec particleFeaturesGradientVec = TFeaturesVec::zero(); particles.featuresIntegrateBwd(hitParticle.alpha, hitAlphaGrad, - particleFeatures[hitParticle.idx], + tcnn::max(particleFeatures[hitParticle.idx], 0.f), particleFeaturesGradientVec, ray.featuresBackward, ray.featuresGradient); From 753ae0fc5be33d697f834b579f1d5dc7cc501fed Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 21 Apr 2026 11:51:49 -0400 Subject: [PATCH 3/9] Neural Harmonic Textures integration for 3DGRT and 3DGUT --- .gitignore | 1 + configs/apps/colmap_3dgrt_mcmc_nht.yaml | 22 ++ configs/apps/colmap_3dgut_mcmc_nht.yaml | 18 + .../apps/nerf_synthetic_3dgrt_mcmc_nht.yaml | 22 ++ .../apps/nerf_synthetic_3dgut_mcmc_nht.yaml | 18 + configs/base_gs.yaml | 34 +- configs/render/3dgrt.yaml | 2 + .../3dgrt/kernels/cuda/3dgrtTracer.cuh | 43 ++- .../3dgrt/kernels/cuda/gaussianParticles.cuh | 9 + .../slang/models/gaussianParticles.slang | 33 +- .../neuralHarmonicFeaturesParticle.slang | 307 +++++++++++++++ .../slang/models/radiativeParticles.slang | 27 ++ .../slang/models/shRadiativeParticles.slang | 154 +++++--- .../include/3dgrt/pipelineParameters.h | 5 + threedgrt_tracer/setup_3dgrt.py | 49 ++- .../kernels/cuda/referenceSlangBwdOptix.cu | 39 +- .../src/kernels/cuda/referenceSlangOptix.cu | 23 +- threedgrt_tracer/src/optixTracer.cpp | 6 +- threedgrt_tracer/tracer.py | 27 +- threedgrut/model/feature_decoder.py | 212 +++++++++++ threedgrut/model/features.py | 168 +++++++++ threedgrut/model/model.py | 312 +++++++++++++--- threedgrut/render.py | 65 +++- threedgrut/trainer.py | 178 ++++++++- threedgrut/utils/gui.py | 2 +- threedgrut/utils/misc.py | 18 +- threedgrut/utils/render.py | 70 +++- threedgrut/utils/viser_gui_util.py | 2 +- threedgrut_playground/engine.py | 24 +- threedgrut_playground/tracer.py | 12 +- .../3dgut/kernels/cuda/common/rayPayload.cuh | 32 +- .../cuda/common/rayPayloadBackward.cuh | 35 +- .../models/shRadiativeGaussianParticles.cuh | 213 ++++++++--- .../cuda/renderers/gutKBufferRenderer.cuh | 255 +++++++++---- .../kernels/cuda/renderers/gutProjector.cuh | 47 ++- .../kernels/cuda/renderers/gutRenderer.cuh | 16 +- .../slang/models/gaussianParticles.slang | 22 +- .../neuralHarmonicFeaturesParticle.slang | 351 ++++++++++++++++++ .../slang/models/radiativeParticles.slang | 27 ++ .../slang/models/shRadiativeParticles.slang | 207 +++++++---- .../include/3dgut/renderer/gutRenderer.h | 16 +- threedgut_tracer/include/3dgut/threedgut.cuh | 12 +- .../include/3dgut/threedgut.slang | 9 +- threedgut_tracer/setup_3dgut.py | 36 +- threedgut_tracer/src/gutRenderer.cu | 20 +- threedgut_tracer/src/splatRaster.cpp | 14 +- threedgut_tracer/tracer.py | 48 +-- 47 files changed, 2729 insertions(+), 533 deletions(-) create mode 100644 configs/apps/colmap_3dgrt_mcmc_nht.yaml create mode 100644 configs/apps/colmap_3dgut_mcmc_nht.yaml create mode 100644 configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml create mode 100644 configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml create mode 100644 threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang create mode 100644 threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang create mode 100644 threedgrut/model/feature_decoder.py create mode 100644 threedgrut/model/features.py create mode 100644 threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang create mode 100644 threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang diff --git a/.gitignore b/.gitignore index feedbb5e..b389391e 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ thirdparty/kaolin/ threedgrt_tracer/.ninja_log threedgrt_tracer/include/3dgrt/kernels/slang/*.cuh* +threedgut_tracer/include/threedgutSlang.cuh *.egg-info .idea diff --git a/configs/apps/colmap_3dgrt_mcmc_nht.yaml b/configs/apps/colmap_3dgrt_mcmc_nht.yaml new file mode 100644 index 00000000..fb53696d --- /dev/null +++ b/configs/apps/colmap_3dgrt_mcmc_nht.yaml @@ -0,0 +1,22 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for colmap datasets with 3DGRT and MCMC + +defaults: + - /base_mcmc + - /dataset: colmap + - /initialization: colmap + - /render: 3dgrt + - _self_ + +model: + feature_type: "nht" + +render: + pipeline_type: referenceSlang + backward_pipeline_type: referenceSlangBwd + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/colmap_3dgut_mcmc_nht.yaml b/configs/apps/colmap_3dgut_mcmc_nht.yaml new file mode 100644 index 00000000..e06d8f0f --- /dev/null +++ b/configs/apps/colmap_3dgut_mcmc_nht.yaml @@ -0,0 +1,18 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for colmap datasets with 3DGUT and MCMC + +defaults: + - /base_mcmc + - /dataset: colmap + - /initialization: colmap + - /render: 3dgut + - _self_ + +model: + feature_type: "nht" + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml new file mode 100644 index 00000000..f398179f --- /dev/null +++ b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml @@ -0,0 +1,22 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for nerf_synthetic with 3DGRT and MCMC + +defaults: + - /base_mcmc + - /dataset: nerf + - /initialization: random + - /render: 3dgrt + - _self_ + +model: + feature_type: "nht" + +render: + pipeline_type: referenceSlang + backward_pipeline_type: referenceSlangBwd + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml new file mode 100644 index 00000000..7b8307c8 --- /dev/null +++ b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml @@ -0,0 +1,18 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for nerf_synthetic with 3DGUT and MCMC + +defaults: + - /base_mcmc + - /dataset: nerf + - /initialization: random + - /render: 3dgut + - _self_ + +model: + feature_type: "nht" + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 51248502..5bf31edf 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -56,17 +56,42 @@ model: optimize_density: true optimize_features_albedo: true optimize_features_specular: true + optimize_features: true optimize_position: true optimize_rotation: true optimize_scale: true bvh_update_frequency: 1 + feature_type: "sh" # one of [sh, nht] # sh: spherical harmonics radiance, nht: neural harmonics texture progressive_training: - feature_type: "sh" # one of [sh, mlp], currently only sh supported init_n_features: 0 # sh: initial sh deg | mlp: num of dims initially unmasked max_n_features: 3 # sh: maximum sh deg | mlp: total num of dims finally unmasked increase_frequency: 1000 # unmask more feature dimensions every N global steps increase_step: 1 # sh: how many degrees unmasked per step | mlp: how many dims unmasked per step + nht_features: + dim: 48 + activation: + type: "siren" + num_frequencies: 1 + interpolation_type: "barycentric" + + nht_decoder: + enabled: true + hidden_dim: 128 + num_layers: 3 + dir_encoding: "SphericalHarmonics" + dir_encoding_degree: 3 + sh_scale: 3.0 + output_activation: "Sigmoid" + learning_rate: 0.00068 + reg_weight: 0.0 + scheduler: + type: "cosine" + decay_final: 0.1 + max_steps: 30000 + ema_decay: 0.95 + ema_start_step: 0 + background: name: background-color # one of [skip-background, background-color] color: black # one of [black, white, random] needs to be defined if name == background-color @@ -89,6 +114,8 @@ optimizer: lr: 0.0025 # 3DGS value: 0.0025 features_specular: lr: ${div:${optimizer.params.features_albedo.lr},20} # 3DGS value 20x smaller than lr of features_albedo + features: + lr: 0.015 rotation: lr: 0.001 # 3DGS value: 0.001 scale: @@ -106,6 +133,11 @@ scheduler: density: type: skip + features: + type: cosine + decay_final: 0.1 + max_steps: 30000 + loss: use_l1: true lambda_l1: 0.8 diff --git a/configs/render/3dgrt.yaml b/configs/render/3dgrt.yaml index 655cd3a8..6f664727 100644 --- a/configs/render/3dgrt.yaml +++ b/configs/render/3dgrt.yaml @@ -14,3 +14,5 @@ max_consecutive_bvh_update: 15 enable_normals: false enable_hitcounts: true enable_kernel_timings: false +particle_feature_half: false +feature_output_half: false diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh index b443977e..8ce3ff60 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh @@ -38,7 +38,8 @@ struct RayHit { using RayPayload = RayHit[PipelineParameters::MaxNumHitPerTrace]; struct RayData { - float3 radiance; + float features[RAY_FEATURE_DIM]; + float density; float3 normal; float hitDistance; @@ -46,12 +47,16 @@ struct RayData { float hitCount; // TODO (operel): convert to uint32 __device__ void initialize() { - radiance = make_float3(0.0f); - density = 0.0; - normal = make_float3(0.f); - hitDistance = 0.f; + // Zero-initialize all features + #pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; ++i) { + radiance[i] = 0.0f; + } + density = 0.0; + normal = make_float3(0.f); + hitDistance = 0.f; rayLastHitDistance = 0.f; - hitCount = 0.f; + hitCount = 0.f; } }; @@ -126,7 +131,7 @@ static __device__ __inline__ void trace( rayPayload[15].distance = __uint_as_float(r31); } -/* Traces Gaussians along ray and accumulates radiance into rayData. +/* Traces Gaussians along ray and accumulates features into rayData. * ray is bounded by both (tmin, tmax) and launch params.aabb */ static __device__ __inline__ void traceVolumetricGS( @@ -142,10 +147,10 @@ static __device__ __inline__ void traceVolumetricGS( } float rayTransmittance = 1.0f - rayData.density; - float2 minMaxT = intersectAABB(params.aabb, rayOrigin, rayDirection); - minMaxT.x = fmaxf(minMaxT.x, tmin); - minMaxT.y = fminf(minMaxT.y, tmax); - constexpr float epsT = 1e-9; + float2 minMaxT = intersectAABB(params.aabb, rayOrigin, rayDirection); + minMaxT.x = fmaxf(minMaxT.x, tmin); + minMaxT.y = fminf(minMaxT.y, tmax); + constexpr float epsT = 1e-9; float rayLastHitDistance = fmaxf(0.0f, minMaxT.x - epsT); RayPayload rayPayload; @@ -175,11 +180,15 @@ static __device__ __inline__ void traceVolumetricGS( #endif ); - particleFeaturesIntegrateFwdFromBuffer(rayDirection, - hitWeight, - rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr, true}, params.sphDegree}, - &rayData.radiance); + // Call generic Slang wrapper (no conditionals) + // The wrapper handles CommonParameters construction internally + particleFeaturesIntegrateFwdGeneric( + rayDirection, + hitWeight, + rayHit.particleId, + params.particleRadiance, // void* - generic buffer pointer + params.sphDegree, // auxiliary parameter (sphDegree for SH, unused for learned) + rayData.features); // float* - generic output array rayLastHitDistance = fmaxf(rayLastHitDistance, rayHit.distance); @@ -190,7 +199,7 @@ static __device__ __inline__ void traceVolumetricGS( } } - rayData.density = 1 - rayTransmittance; + rayData.density = 1 - rayTransmittance; rayData.rayLastHitDistance = rayLastHitDistance; } diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh index ebf7c100..54a1fcbf 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh @@ -13,6 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This file contains SH-specific CUDA backward pass helpers. +// For learned features mode (FEATURE_TRANSFORM_TYPE != 0), this file compiles but provides no functions. +// Learned features use Slang autodiff instead of these CUDA helpers. + #include #include <3dgrt/mathUtils.h> @@ -22,6 +26,9 @@ typedef int int32_t; typedef unsigned int uint32_t; +// Only define SH-specific functions for SH mode +#if !defined(FEATURE_TRANSFORM_TYPE) || FEATURE_TRANSFORM_TYPE == 0 + void quaternionWXYZToMatrix(const float4& q, float33& ret) { const float r = q.x; const float x = q.y; @@ -723,3 +730,5 @@ __device__ inline void processHitBwd( transmittance = nextTransmit; } } + +#endif // FEATURE_TRANSFORM_TYPE == 0 diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang index 04e9d413..d8568b26 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang @@ -185,6 +185,17 @@ struct Parameters : IDifferentiable { return sqrt(dot(grds, grds)); } +[BackwardDifferentiable] [ForceInline] float canonicalRayIntersection( + float3 canonicalRayOrigin, + float3 canonicalRayDirection, + float3 scale, + out float3 canonicalIntersection) { + const float3 canonicalGrds = canonicalRayDirection * dot(canonicalRayDirection, -1 * canonicalRayOrigin); + canonicalIntersection = canonicalRayOrigin + canonicalGrds; + const float3 grds = scale * canonicalGrds; + return sqrt(dot(grds, grds)); +} + [BackwardDifferentiable][ForceInline] float3 canonicalRayNormal( float3 canonicalRayOrigin, float3 canonicalRayDirection, @@ -207,6 +218,7 @@ bool hit( Parameters parameters, out float alpha, inout float depth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 normal) { @@ -227,7 +239,7 @@ bool hit( const bool acceptHit = ((maxResponse > MinParticleKernelDensity) && (alpha > MinParticleAlpha)); if (acceptHit) { - depth = canonicalRayDistance(canonicalRayOrigin, canonicalRayDirection, parameters.scale); + depth = canonicalRayIntersection(canonicalRayOrigin, canonicalRayDirection, parameters.scale, canonicalIntersection); if (enableNormal) { normal = canonicalRayNormal(canonicalRayOrigin, canonicalRayDirection, parameters.scale, parameters.rotationT); @@ -276,6 +288,7 @@ float processHitFromBuffer( no_diff RawParametersBuffer parametersBuffer, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 integratedNormal) { @@ -287,6 +300,7 @@ float processHitFromBuffer( fetchParameters(particleIdx, parametersBuffer), alpha, depth, + canonicalIntersection, enableNormal, normal)) { @@ -320,7 +334,7 @@ float3 incidentDirectionFromParameters( } [BackwardDifferentiable][ForceInline] -float3 incidentDirectionFromBuffer( +no_diff float3 incidentDirectionFromBuffer( no_diff uint32_t particleIdx, no_diff RawParametersBuffer parametersBuffer, no_diff float3 sourcePosition @@ -353,6 +367,7 @@ inline bool particleDensityHit( gaussianParticle.Parameters parameters, out float alpha, out float depth, + out float3 canonicalIntersection, bool enableNormal, out float3 normal) { @@ -361,6 +376,7 @@ inline bool particleDensityHit( parameters, alpha, depth, + canonicalIntersection, enableNormal, normal); } @@ -392,6 +408,7 @@ inline float particleDensityProcessHitFwdFromBuffer( gaussianParticle.CommonParameters commonParameters, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, in bool enableNormal, inout float3 integratedNormal) { @@ -402,6 +419,7 @@ inline float particleDensityProcessHitFwdFromBuffer( commonParameters.parametersBuffer, transmittance, integratedDepth, + canonicalIntersection, enableNormal, integratedNormal); } @@ -419,6 +437,7 @@ void particleDensityProcessHitBwdToBuffer( in float depth, inout float integratedDepth, inout float integratedDepthGrad, + in float3 canonicalIntersectionGrad, bool enableNormal, in float3 normal, inout float3 integratedNormal, @@ -452,6 +471,7 @@ void particleDensityProcessHitBwdToBuffer( commonParameters.parametersBuffer, transmittanceDiff, integratedDepthDiff, + canonicalIntersectionGrad, enableNormal, integratedNormalDiff, alphaGrad); @@ -536,10 +556,7 @@ bool particleDensityHitInstance( in float3 incidentDirectionGrad ) { - bwd_diff(gaussianParticle.incidentDirectionFromBuffer)( - particleIdx, - commonParameters.parametersBuffer, - sourcePosition, - incidentDirectionGrad - ); + // 3DGRT: incidentDirectionFromBuffer returns no_diff; position gradient + // from incident direction is not propagated back in the 3DGRT path. } + diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang new file mode 100644 index 00000000..0b426f42 --- /dev/null +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -0,0 +1,307 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use it except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Neural Harmonic Features: per-particle K-dim features, output N (decoder input); decoder maps N -> RGB. +// Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. +// Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, +// PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. With activation: N = interpPointDim * num_frequencies (decoder input). + +namespace neuralHarmonicFeaturesParticle +{ +static const int ParticleFeatureDim = PARTICLE_FEATURE_DIM; // Total per-particle (K); not per-interpolation-point +static const int RayFeatureDim = RAY_FEATURE_DIM; // Decoder input N = INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES +static const int InterpPointFeatureDim = INTERP_POINT_FEATURE_DIM; // Per-interpolation-point dimension +static const int FeatureActivationType_None = 0; +static const int FeatureActivationType_Siren = 1; +static const int FeatureActivationType_Sincos = 2; +static const int FeatureActivationType_Relu = 3; +static const int FeatureActivationType = FEATURE_ACTIVATION_TYPE; +static const int FeatureActivationNumFrequencies = FEATURE_ACTIVATION_NUM_FREQUENCIES; + +// Interpolation type (compile-time from FEATURE_INTERPOLATION_TYPE) +static const int InterpolationType_Barycentric = 0; +static const int InterpolationType_Bezier = 1; // Not supported yet +static const int InterpolationType = FEATURE_INTERPOLATION_TYPE; + +// Interpolation support (compile-time from FEATURE_INTERPOLATION_SUPPORT) +static const int InterpolationSupport_Center = 0; +static const int InterpolationSupport_Tetrahedra = 1; +static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet +static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT; + +// Canonical regular tetrahedron that contains the unit sphere; each face is tangent to the unit sphere +// (insphere radius = 1). Same geometry as particlePrimitives.cu (enclosing tetrahedra). +// Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1. +static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24) +static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2 +static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3) +static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2) +static const float tetraHedraInRadius = 1.0f; // unit sphere +static const float3 canonicalTetraVerts[4] = { + float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), + float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f), + float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius), + float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f) +}; +// Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical. +static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE2 = canonicalTetraVerts[2] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE3 = canonicalTetraVerts[3] - canonicalTetraVerts[0]; +static const float3 canonicalTetraCrossE2E3 = cross(canonicalTetraE2, canonicalTetraE3); +static const float canonicalTetraDet = dot(canonicalTetraE1, canonicalTetraCrossE2E3); +static const float canonicalTetraInvDet = 1.0f / canonicalTetraDet; + +}; + +#if PARTICLE_FEATURE_HALF +typedef half feat_elem_t; +#else +typedef float feat_elem_t; +#endif + +namespace neuralHarmonicFeaturesParticle +{ + +struct ParametersBuffer +{ + Ptr _dataPtr; // [N_particles, K] flat; fp16 when PARTICLE_FEATURE_HALF + float *_gradPtr; + bool exclusiveGradient; +}; + +struct Parameters : IDifferentiable +{ + Array features; +}; + +[BackwardDifferentiable][ForceInline] +Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer) +{ + Parameters parameters; + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + parameters.features[i] = parametersBuffer._dataPtr[interpPointOffset + i]; + } + return parameters; +} + +[BackwardDerivativeOf(fetchParametersFromBuffer)][ForceInline] +void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer, + Parameters parametersGrad) +{ + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + const float grad = parametersGrad.features[i]; + if (parametersBuffer.exclusiveGradient) { + parametersBuffer._gradPtr[interpPointOffset + i] += grad; + } else { + InterlockedAdd(parametersBuffer._gradPtr[interpPointOffset + i], grad); + } + } +} + +// Barycentric coordinates for the canonical tetrahedron only. Uses precomputed static const e1,e2,e3,invDet (independent of P). +[BackwardDifferentiable][ForceInline] +float4 barycentricTetrahedronCanonical(float3 P) +{ + float3 d = P - canonicalTetraVerts[0]; + float4 weights; + weights.y = dot(d, canonicalTetraCrossE2E3) * canonicalTetraInvDet; + weights.z = dot(canonicalTetraE1, cross(d, canonicalTetraE3)) * canonicalTetraInvDet; + weights.w = dot(canonicalTetraE1, cross(canonicalTetraE2, d)) * canonicalTetraInvDet; + weights.x = 1.0f - weights.y - weights.z - weights.w; + return weights; +} + +// Encode and activate : none -> identity; siren -> sin(b*2^f); sincos -> sin(b*2^f)+cos(b*2^f); relu -> max(0,b). +[BackwardDifferentiable][ForceInline] +float encodeAndActivate(float baseVal, no_diff int f) +{ + if (FeatureActivationType == FeatureActivationType_None) + return baseVal; + if (FeatureActivationType == FeatureActivationType_Relu) + return max(0.0f, baseVal); + float freq = ldexp(1.0f, f); + float angle = baseVal * freq; + if (FeatureActivationType == FeatureActivationType_Siren) + return sin(angle); + return sin(angle) + cos(angle); +} + +// Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim]. +// canonicalPosition is differential (hit position in particle canonical space; gradient accumulated in API canonicalPositionGrad). +[BackwardDifferentiable][ForceInline] +void featuresFromParametersBuffer(ParametersBuffer parametersBuffer, + no_diff uint32_t particleIdx, + float3 canonicalPosition, + out Array features +) +{ + Array baseFeatures = fetchParametersFromBuffer(particleIdx, 0, parametersBuffer).features; + if (InterpolationSupport == InterpolationSupport_Tetrahedra && InterpolationType == InterpolationType_Barycentric) { + float4 barycentricWeights = barycentricTetrahedronCanonical(canonicalPosition); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] *= barycentricWeights[0]; + } + [ForceUnroll] for (int k = 1; k < 4; ++k) { + Parameters parameters = fetchParametersFromBuffer(particleIdx, k, parametersBuffer); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] += barycentricWeights[k] * parameters.features[n]; + } + } + } + + if (FeatureActivationType == FeatureActivationType_None) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = baseFeatures[i]; + } + } else { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + features[k * FeatureActivationNumFrequencies + f] = encodeAndActivate(baseFeatures[k], f); + } + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeatures(float weight, + in Array features, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + [ForceUnroll] for (int i = 0; i < RayFeatureDim; ++i) { + if (backToFront) + integratedFeatures[i] = lerp(integratedFeatures[i], features[i], weight); + else + integratedFeatures[i] += features[i] * weight; + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeaturesFromBuffer(float weight, + no_diff uint32_t particleIdx, + ParametersBuffer parametersBuffer, + float3 canonicalPosition, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + Array features; + featuresFromParametersBuffer(parametersBuffer, particleIdx, canonicalPosition, features); + integrateFeatures(weight, features, integratedFeatures); + } +} + +} // namespace neuralHarmonicFeaturesParticle + +// ------------------------------------------------------------------------------------------------------------------ +// Entry points - same CudaDeviceExport API as shRadiativeParticles.slang + +[CudaDeviceExport] +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.featuresFromParametersBuffer( + parametersBuffer, + particleIdx, + canonicalPosition, + features + ); +} + +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer( + weight, particleIdx, parametersBuffer, canonicalPosition, integratedFeatures); +} + +// canonicalPosition is differential; canonicalPositionGrad is inout and accumulated by this backward. +[CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, + in float alpha, + inout float alphaGrad, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + if (alpha > 0.0f) + { + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); + + const float weight = 1.0f / (1.0f - alpha); + [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + DifferentialPair integratedFeaturesDiff = + DifferentialPair(integratedFeatures, integratedFeaturesGrad); + + DifferentialPair canonicalPositionDiff = DifferentialPair(canonicalPosition, canonicalPositionGrad); + + bwd_diff(neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer)( + alphaDiff, + particleIdx, + parametersBuffer, + canonicalPositionDiff, + integratedFeaturesDiff); + + integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + canonicalPositionGrad += canonicalPositionDiff.getDifferential(); + alphaGrad += alphaDiff.getDifferential(); + } +} diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang new file mode 100644 index 00000000..970005eb --- /dev/null +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Radiative Particles Wrapper +// Conditionally includes the appropriate feature implementation based on FEATURE_TRANSFORM_TYPE + +#if FEATURE_TRANSFORM_TYPE == 0 + // Spherical Harmonics mode + #include <3dgrt/kernels/slang/models/shRadiativeParticles.slang> +#elif FEATURE_TRANSFORM_TYPE == 1 + // Post-MLP radiance mode + #include <3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang> +#else + #error "Unknown FEATURE_TRANSFORM_TYPE. Must be 0 (SH) or 1 (neural_harmonic_features)" +#endif diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang index d6ef7d73..1e110202 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang @@ -31,12 +31,6 @@ struct ParametersBuffer bool exclusiveGradient; //< true if the gradient maybe updated without atomics }; -struct CommonParameters -{ - ParametersBuffer parametersBuffer; - int sphDegree; -}; - struct Parameters : IDifferentiable { vector sphCoefficients[RadianceMaxNumSphCoefficients]; @@ -48,7 +42,7 @@ Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, { Parameters parameters; const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { parameters.sphCoefficients[i] = parametersBuffer._dataPtr[particleOffset + i]; } return parameters; @@ -60,14 +54,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, Parameters parametersGrad) { const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { const vector coeffs = parametersGrad.sphCoefficients[i]; if (parametersBuffer.exclusiveGradient) { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { parametersBuffer._gradPtr[particleOffset + i][j] += coeffs[j]; } } else { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { InterlockedAdd(parametersBuffer._gradPtr[particleOffset + i][j], coeffs[j]); } } @@ -76,14 +70,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, [BackwardDifferentiable] [ForceInline] vector radianceFromBuffer(no_diff uint32_t particleIdx, - in float3 incidentDirection, + no_diff float3 incidentDirection, no_diff uint32_t sphDegree, no_diff ParametersBuffer parametersBuffer) { return sphericalHarmonics.decode( - sphDegree, - fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, - incidentDirection); + sphDegree, + fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, + incidentDirection); } [BackwardDifferentiable][ForceInline] @@ -145,41 +139,58 @@ void integrateRadianceFromBuffer(no_diff float3 incident // Entry points [CudaDeviceExport] -inline vector particleFeaturesFromBuffer(in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in float3 incidentDirection) +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[shRadiativeParticle.Dim]) { - return sphericalHarmonics.decode( - commonParameters.sphDegree, - shRadiativeParticle.fetchParametersFromBuffer(particleIdx, commonParameters.parametersBuffer).sphCoefficients, + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector featuresVec = sphericalHarmonics.decode( + auxParam, + shRadiativeParticle.fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, incidentDirection); + + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + features[i] = featuresVec[i]; + } } [CudaDeviceExport] -inline void particleFeaturesIntegrateFwd(in float weight, - in vector features, - inout vector integratedFeatures) +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + inout float integratedFeatures[shRadiativeParticle.Dim]) { - shRadiativeParticle.integrateRadiance( - weight, - features, - integratedFeatures - ); -} + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; -[CudaDeviceExport] inline void particleFeaturesIntegrateFwdFromBuffer(in float3 incidentDirection, - in float weight, - in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - inout vector integratedFeatures) -{ + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + } shRadiativeParticle.integrateRadianceFromBuffer( incidentDirection, - commonParameters.sphDegree, + auxParam, weight, particleIdx, - commonParameters.parametersBuffer, - integratedFeatures); + parametersBuffer, + integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } [CudaDeviceExport] void particleFeaturesIntegrateBwd( @@ -214,52 +225,81 @@ inline void particleFeaturesIntegrateFwd(in float weight, [CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, in float alpha, inout float alphaGrad, in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector features, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if (alpha > 0.0f) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr, Access.ReadWrite>)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + vector integratedFeaturesVec; + vector integratedFeaturesGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesVec[i] = integratedFeatures[i]; + integratedFeaturesGradVec[i] = integratedFeaturesGrad[i]; + } DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedFeaturesVec, integratedFeaturesGradVec); bwd_diff(shRadiativeParticle.integrateRadianceFromBuffer)( incidentDirection, - commonParameters.sphDegree, + auxParam, alphaDiff, particleIdx, - commonParameters.parametersBuffer, + parametersBuffer, integratedFeaturesDiff); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); - alphaGrad = alphaDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + alphaGrad += alphaDiff.getDifferential(); } } [CudaDeviceExport] void particleFeaturesBwdToBuffer( in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector featuresGrad, + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float featuresGrad[shRadiativeParticle.Dim], in float3 incidentDirection, inout float3 incidentDirectionGrad ) { - DifferentialPair incidentDirectionDiff = DifferentialPair(incidentDirection, incidentDirectionGrad); + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr, Access.ReadWrite>)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + vector featuresGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) featuresGradVec[i] = featuresGrad[i]; + + // incidentDirection is no_diff in radianceFromBuffer for 3DGRT; pass as plain value. bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, - incidentDirectionDiff, - commonParameters.sphDegree, - commonParameters.parametersBuffer, - featuresGrad); - - incidentDirectionGrad += incidentDirectionDiff.getDifferential(); + incidentDirection, + auxParam, + parametersBuffer, + featuresGradVec); } diff --git a/threedgrt_tracer/include/3dgrt/pipelineParameters.h b/threedgrt_tracer/include/3dgrt/pipelineParameters.h index a7af073d..d9a267bf 100644 --- a/threedgrt_tracer/include/3dgrt/pipelineParameters.h +++ b/threedgrt_tracer/include/3dgrt/pipelineParameters.h @@ -46,6 +46,11 @@ struct PipelineParameters { float alphaMinThreshold; unsigned int sphDegree; + // Feature-based radiance dimensions + static constexpr unsigned int ParticleFeatureDim = PARTICLE_FEATURE_DIM; ///< Total feature dim per particle (buffer stride K) + static constexpr unsigned int RayFeatureDim = RAY_FEATURE_DIM; ///< Per-ray (decoder input N) + static constexpr unsigned int FeatureTransformType = FEATURE_TRANSFORM_TYPE; ///< 0=SH, 1=learned + uint2 frameBounds; unsigned int frameNumber; int gPrimNumTri; diff --git a/threedgrt_tracer/setup_3dgrt.py b/threedgrt_tracer/setup_3dgrt.py index fd282e36..e878089d 100644 --- a/threedgrt_tracer/setup_3dgrt.py +++ b/threedgrt_tracer/setup_3dgrt.py @@ -16,6 +16,7 @@ import os from threedgrut.utils import jit +from threedgrut.model.features import Features # ---------------------------------------------------------------------------- @@ -24,13 +25,49 @@ def setup_3dgrt(conf): def to_cpp_bool(value): return "true" if value else "false" + feat = Features(conf) + transform_defines = [ + f"-DPARTICLE_FEATURE_DIM={feat.particle_feature_dim}", + f"-DRAY_FEATURE_DIM={feat.ray_feature_dim}", + f"-DFEATURE_TRANSFORM_TYPE={feat.transform_type}", + ] + nht_defines = [ + f"-DFEATURE_INTERPOLATION_TYPE={feat.interpolation_type}", + f"-DFEATURE_INTERPOLATION_SUPPORT={feat.interpolation_support}", + f"-DFEATURE_ACTIVATION_TYPE={feat.activation_type}", + f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={feat.activation_num_frequencies}", + f"-DINTERP_POINT_FEATURE_DIM={feat.interp_point_feature_dim}", + ] + half_defines = [ + f"-DPARTICLE_FEATURE_HALF={1 if conf.render.particle_feature_half else 0}", + f"-DFEATURE_OUTPUT_HALF={1 if conf.render.feature_output_half else 0}", + ] + include_paths = [] include_paths.append(os.path.join(os.path.dirname(__file__), "include")) include_paths.append(os.path.join(os.path.dirname(__file__), "dependencies", "optix-dev", "include")) - # Compiler options. - cflags = [] - cuda_flags = [] + # Compiler options. Same -D for feature dims so pipelineParameters.h and JIT OptiX pipeline (generateDefines) stay in sync. + cflags = [ + *transform_defines, + ] + cuda_flags = [ + # Feature-based radiance dimensions (must match Slang compilation) + *transform_defines, + *nht_defines, + *half_defines, + # Other particle parameters + f"-DPARTICLE_RADIANCE_NUM_COEFFS={(conf.render.particle_radiance_sph_degree + 1) ** 2}", + f"-DGAUSSIAN_PARTICLE_KERNEL_DEGREE={conf.render.particle_kernel_degree}", + f"-DGAUSSIAN_PARTICLE_MIN_KERNEL_DENSITY={conf.render.particle_kernel_min_response}", + f"-DGAUSSIAN_PARTICLE_MIN_ALPHA={conf.render.particle_kernel_min_alpha}", + f"-DGAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD={conf.render.min_transmittance}", + ] + # When PARTICLE_FEATURE_HALF=1 the Slang-generated header uses __half types; + # the Slang prelude only pulls in and defines __half when + # SLANG_CUDA_ENABLE_HALF is set. + if conf.render.particle_feature_half or conf.render.feature_output_half: + cuda_flags.append("-DSLANG_CUDA_ENABLE_HALF=1") # List of sources. source_files = [ @@ -44,7 +81,7 @@ def to_cpp_bool(value): jit.compile_slang_kernel( kernel_files=[ f"{os.path.join(slang_build_dir,'models/gaussianParticles.slang')}", - f"{os.path.join(slang_build_dir,'models/shRadiativeParticles.slang')}", + f"{os.path.join(slang_build_dir,'models/radiativeParticles.slang')}", ], output_file=f"{os.path.join(slang_build_dir, 'gaussianParticles.cuh')}", defines=[ @@ -55,6 +92,10 @@ def to_cpp_bool(value): f"-DGAUSSIAN_PARTICLE_MAX_ALPHA={conf.render.particle_kernel_max_alpha}", f"-DGAUSSIAN_PARTICLE_ENABLE_NORMAL={to_cpp_bool(conf.render.enable_normals)}", f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type=='trisurfel')}", + # Feature-based radiance dimensions + *transform_defines, + *nht_defines, + *half_defines, ], include_paths=[ os.path.join(os.path.dirname(__file__), "include"), diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu index 2e983fc6..c9cc4f0b 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu @@ -110,7 +110,11 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + FixedArray rayRadiance; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + rayRadiance[i] = params.rayRadiance[idx.z][idx.y][idx.x][i]; + } float rayTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -119,7 +123,11 @@ extern "C" __global__ void __raygen__rg() { float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + FixedArray rayRadianceGrad; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + rayRadianceGrad[i] = params.rayRadianceGrad[idx.z][idx.y][idx.x][i]; + } float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -149,12 +157,13 @@ extern "C" __global__ void __raygen__rg() { // NB : processing front-to-back backToFrontBwd is equivalent processing back-to-front frontToBackBwd !! float hitAlpha, hitDistance; - float3 hitNormal; + float3 canonicalIntersection, hitNormal; if (particleDensityHit( rayOrigin, rayDirection, particleDensityParameters(rayHit.particleId, {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}), &hitAlpha, &hitDistance, + &canonicalIntersection, #ifdef ENABLE_NORMALS true, &hitNormal #else @@ -163,17 +172,27 @@ extern "C" __global__ void __raygen__rg() { ) ) { - const float3 hitRadiance = particleFeaturesFromBuffer( + FixedArray hitRadiance; + particleFeaturesFromBuffer( rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad, false}, (int)params.sphDegree}, - rayDirection); + const_cast(params.particleRadiance), + (int)params.sphDegree, + rayDirection, + canonicalIntersection, + &hitRadiance); float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f); particleFeaturesIntegrateBwdToBuffer(rayDirection, + canonicalIntersection, + &canonicalIntersectionGrad, hitAlpha, &hitAlphaGrad, rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad, false}, params.sphDegree}, + const_cast(params.particleRadiance), + const_cast(params.particleRadianceGrad), + (int)params.sphDegree, + false, // exclusiveGradient: multiple rays can hit same particle hitRadiance, &rayRadiance, &rayRadianceGrad); @@ -182,7 +201,8 @@ extern "C" __global__ void __raygen__rg() { rayDirection, rayHit.particleId, {{(gaussianParticle_RawParameters_0*)params.particleDensity, - (gaussianParticle_RawParameters_0*)params.particleDensityGrad, false}}, + (gaussianParticle_RawParameters_0*)params.particleDensityGrad, + false}}, hitAlpha, hitAlphaGrad, &rayTransmittance, @@ -190,10 +210,11 @@ extern "C" __global__ void __raygen__rg() { hitDistance, &rayHitDistance, &rayHitDistanceGrad, + canonicalIntersectionGrad, #ifdef ENABLE_NORMALS true, hitNormal, &rayNormal, &rayNormalGrad #else - false, hitNormal, nullptr, nullptr + false, make_float3(0.f, 0.f, 0.f), nullptr, nullptr #endif ); } diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu index 5593d11f..ed2b34ef 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu @@ -109,9 +109,14 @@ extern "C" __global__ void __raygen__rg() { float3 rayOrigin = params.rayWorldOrigin(idx); float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(0.0f); - float rayTransmittance = 1.0f; - float rayHitDistance = 0.f; + FixedArray rayRadiance; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + rayRadiance[i] = 0.0f; + } + float rayTransmittance = 1.0f; + float rayHitDistance = 0.f; + float3 canonicalIntersection = make_float3(0.f); #ifdef ENABLE_NORMALS float3 rayNormal = make_float3(0.f); #endif @@ -143,6 +148,7 @@ extern "C" __global__ void __raygen__rg() { {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, &rayTransmittance, &rayHitDistance, + &canonicalIntersection, #ifdef ENABLE_NORMALS true, &rayNormal #else @@ -151,9 +157,11 @@ extern "C" __global__ void __raygen__rg() { ); particleFeaturesIntegrateFwdFromBuffer(rayDirection, + canonicalIntersection, hitWeight, rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr, true}, params.sphDegree}, + const_cast(params.particleRadiance), + params.sphDegree, &rayRadiance); // NOTE(qi): Race condition here, but as we are writing the same value, it seems it is safe. @@ -170,9 +178,10 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + params.rayRadiance[idx.z][idx.y][idx.x][i] = rayRadiance[i]; + } params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; diff --git a/threedgrt_tracer/src/optixTracer.cpp b/threedgrt_tracer/src/optixTracer.cpp index e1e4a817..746b513f 100644 --- a/threedgrt_tracer/src/optixTracer.cpp +++ b/threedgrt_tracer/src/optixTracer.cpp @@ -215,6 +215,10 @@ std::vector OptixTracer::generateDefines( defines.emplace_back("-DSPH_MAX_NUM_COEFFS=" + std::to_string((_state->particleRadianceSphDegree + 1) * (_state->particleRadianceSphDegree + 1))); defines.emplace_back("-DPARTICLE_PRIMITIVE_TYPE=" + std::to_string(_state->gPrimType)); defines.emplace_back("-DPARTICLE_PRIMITIVE_CLAMPED=" + std::to_string(particleKernelDensityClamping ? 1 : 0)); + // Feature dims: use C++ defines so JIT OptiX pipeline sees same values as extension build + defines.emplace_back("-DPARTICLE_FEATURE_DIM=" + std::to_string(PipelineParameters::ParticleFeatureDim)); + defines.emplace_back("-DRAY_FEATURE_DIM=" + std::to_string(PipelineParameters::RayFeatureDim)); + defines.emplace_back("-DFEATURE_TRANSFORM_TYPE=" + std::to_string(PipelineParameters::FeatureTransformType)); } return defines; } @@ -863,7 +867,7 @@ OptixTracer::trace(uint32_t frameNumber, float minTransmittance) { const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - torch::Tensor rayRad = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 3}, opts); + torch::Tensor rayRad = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), static_cast(PipelineParameters::RayFeatureDim)}, opts); torch::Tensor rayDns = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 1}, opts); torch::Tensor rayHit = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 2}, opts); torch::Tensor rayNrm = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 3}, opts); diff --git a/threedgrt_tracer/tracer.py b/threedgrt_tracer/tracer.py index d62492c4..e143489c 100644 --- a/threedgrt_tracer/tracer.py +++ b/threedgrt_tracer/tracer.py @@ -21,6 +21,7 @@ import torch.utils.cpp_extension from threedgrut.datasets.protocols import Batch +from threedgrut.model.features import Features from threedgrut.utils.timer import CudaTimer logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def forward( min_transmittance, ): particle_density = torch.concat([mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1) - ray_radiance, ray_density, ray_hit_distance, ray_normals, hits_count, mog_visibility = tracer_wrapper.trace( + ray_features, ray_density, ray_hit_distance, ray_normals, hits_count, mog_visibility = tracer_wrapper.trace( frame_id, ray_to_world, ray_ori, @@ -81,7 +82,7 @@ def forward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, @@ -94,7 +95,7 @@ def forward( ctx.min_transmittance = min_transmittance ctx.tracer_wrapper = tracer_wrapper return ( - ray_radiance, + ray_features.float(), # always fp32 to caller; fp16 saved in ctx for trace_bwd ray_density, ray_hit_distance[:, :, :, 0:1], # return only the hit distance ray_normals, @@ -105,7 +106,7 @@ def forward( @staticmethod def backward( ctx, - ray_radiance_grd, + ray_features_grd, ray_density_grd, ray_hit_distance_grd, ray_normals_grd, @@ -116,7 +117,7 @@ def backward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, @@ -129,13 +130,13 @@ def backward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, particle_density, mog_sph, - ray_radiance_grd, + ray_features_grd, ray_density_grd, ray_hit_distance_grd, ray_normals_grd, @@ -167,10 +168,10 @@ class RenderOpts(IntEnum): DEFAULT = NONE def __init__(self, conf): - self.device = "cuda" self.conf = conf self.num_update_bvh = 0 + self.feature_transform_type = Features(conf).transform_type logger.info(f'🔆 Creating Optix tracing pipeline.. Using CUDA path: "{torch.utils.cpp_extension.CUDA_HOME}"') torch.zeros(1, device=self.device) # Create a dummy tensor to force cuda context init @@ -220,7 +221,7 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): if self.frame_timer is not None: self.frame_timer.start() - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count, mog_visibility = Tracer._Autograd.apply( + (pred_features, pred_opacity, pred_dist, pred_normals, hits_count, mog_visibility) = Tracer._Autograd.apply( self.tracer_wrapper, frame_id, gpu_batch.T_to_world.contiguous(), @@ -230,7 +231,7 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): gaussians.get_rotation().contiguous(), gaussians.get_scale().contiguous(), gaussians.get_density().contiguous(), - gaussians.get_features().contiguous(), + gaussians.get_features().contiguous(), # TODO: cast to .half() when conf.render.particle_feature_half Tracer.RenderOpts.DEFAULT, gaussians.n_active_features, self.conf.render.min_transmittance, @@ -239,15 +240,11 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): if self.frame_timer is not None: self.frame_timer.end() - pred_rgb, pred_opacity = gaussians.background( - gpu_batch.T_to_world.contiguous(), gpu_batch.rays_dir.contiguous(), pred_rgb, pred_opacity, train - ) - if self.frame_timer is not None: self.timings["forward_render"] = self.frame_timer.timing() return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": torch.nn.functional.normalize(pred_normals, dim=3), diff --git a/threedgrut/model/feature_decoder.py b/threedgrut/model/feature_decoder.py new file mode 100644 index 00000000..2680060e --- /dev/null +++ b/threedgrut/model/feature_decoder.py @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import tinycudann as tcnn + + +class FeatureDecoder(nn.Module): + """Transforms N-dimensional feature maps to RGB radiance using tiny-cuda-nn. + + Takes rendered feature maps and ray directions, encodes directions, concatenates + with features, and decodes to RGB via a tiny-cuda-nn MLP. + """ + + def __init__( + self, + ray_feature_dim: int, + hidden_dim: int = 128, + num_layers: int = 4, + dir_encoding: str = "SphericalHarmonics", + dir_encoding_degree: int = 3, + sh_scale: float = 1.0, + output_activation: str = "Sigmoid", + ema_decay: float = 0.0, + ema_start_step: int = 0, + ): + """Initialize the feature decoder. + + Args: + ray_feature_dim: Per-ray feature dimension (rendered features input to the decoder MLP) + hidden_dim: Hidden layer dimension for MLP decoder (default 128) + num_layers: Number of hidden layers in the MLP (default 4) + dir_encoding: Direction encoding type ("SphericalHarmonics" or "Frequency") + dir_encoding_degree: Degree for direction encoding (SH degree or frequency bands; default 3) + sh_scale: Scale applied to ray directions before encoding: (v*sh_scale+1)/2 maps to [0,1]. + sh_scale=1 is standard unit sphere coverage; sh_scale=3 extends coverage for + directions beyond the unit sphere. + output_activation: Output layer activation ("Sigmoid" for [0,1] RGB, or "ReLU") + ema_decay: If > 0, keep EMA shadow of parameters (decay factor). 0 = no EMA. + ema_start_step: Global step at which to start updating EMA. + """ + super().__init__() + self.ray_feature_dim = ray_feature_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.sh_scale = sh_scale + self.output_activation = output_activation + self._ema_decay = ema_decay + self._ema_start_step = ema_start_step + self._ema_shadow: dict[str, torch.Tensor] = {} + self._ema_backup: dict[str, torch.Tensor] = {} + + if dir_encoding == "SphericalHarmonics": + dir_enc = {"otype": "SphericalHarmonics", "degree": dir_encoding_degree, "n_dims_to_encode": 3} + elif dir_encoding == "Frequency": + dir_enc = {"otype": "Frequency", "n_frequencies": dir_encoding_degree, "n_dims_to_encode": 3} + else: + raise ValueError(f"Unknown dir_encoding: {dir_encoding}") + + composite_encoding_config = { + "otype": "Composite", + "nested": [ + {"otype": "Identity", "n_dims_to_encode": ray_feature_dim}, + dir_enc, + ], + } + + network_config = { + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": output_activation, + "n_neurons": hidden_dim, + "n_hidden_layers": num_layers, + } + + self.network = tcnn.NetworkWithInputEncoding( + n_input_dims=ray_feature_dim + 3, + n_output_dims=3, + encoding_config=composite_encoding_config, + network_config=network_config, + ) + + if self._ema_decay > 0: + for name, param in self.named_parameters(): + if param.requires_grad: + self._ema_shadow[name] = param.data.clone() + + def ema_update(self, global_step: int) -> None: + """Update EMA shadow when global_step >= ema_start_step. No-op if ema_decay <= 0.""" + if self._ema_decay <= 0 or global_step < self._ema_start_step: + return + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_shadow: + if self._ema_shadow[name].device != param.device: + self._ema_shadow[name] = self._ema_shadow[name].to(param.device) + self._ema_shadow[name].lerp_(param.data, 1.0 - self._ema_decay) + + def apply_ema_shadow(self) -> None: + """Use EMA weights for inference (e.g. validation). No-op if no EMA.""" + if not self._ema_shadow: + return + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_shadow: + self._ema_backup[name] = param.data.clone() + param.data.copy_(self._ema_shadow[name]) + + def restore_ema(self) -> None: + """Restore training weights after inference. No-op if no EMA.""" + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_backup: + param.data.copy_(self._ema_backup[name]) + self._ema_backup.clear() + + def ema_state_dict(self) -> dict: + """State dict of EMA shadow for checkpoint. Empty if no EMA.""" + return {k: v.clone() for k, v in self._ema_shadow.items()} + + def load_ema_state_dict(self, state_dict: dict) -> None: + """Load EMA shadow from checkpoint.""" + self._ema_shadow = {k: v.clone() for k, v in state_dict.items()} + + def forward( + self, + features: torch.Tensor, + ray_directions: torch.Tensor, + alpha: torch.Tensor | None = None, + ) -> torch.Tensor: + """Transform features and ray directions to RGB. + + Args: + features: Input features of shape [H*W, N] or [B, H, W, N] (alpha-blended). + ray_directions: Ray directions of shape [H*W, 3] or [B, H, W, 3]. + alpha: Optional [H*W] or [B, H, W] opacity. If given, un-multiply (features/alpha), + decode, then re-multiply (rgb*alpha) so the decoder sees unblended features. + + Returns: + RGB tensor of shape [H*W, 3] or [B, H, W, 3] + """ + features_shape = features.shape + ray_dirs_shape = ray_directions.shape + + if len(features_shape) == 4: # [B, H, W, N] + B, H, W, N = features_shape + assert ray_dirs_shape == (B, H, W, 3), f"Ray directions shape mismatch: expected {(B, H, W, 3)}, got {ray_dirs_shape}" + assert N == self.ray_feature_dim, f"Expected {self.ray_feature_dim} features, got {N}" + + features_flat = features.reshape(B * H * W, N) + ray_dirs_flat = ray_directions.reshape(B * H * W, 3) + alpha_flat = alpha.reshape(B * H * W, 1) if alpha is not None else None + + rgb_flat = self._process(features_flat, ray_dirs_flat, alpha_flat) + return rgb_flat.reshape(B, H, W, 3) + + elif len(features_shape) == 2: # [H*W, N] + HW, N = features_shape + assert ray_dirs_shape == (HW, 3), f"Ray directions shape mismatch: expected {(HW, 3)}, got {ray_dirs_shape}" + assert N == self.ray_feature_dim, f"Expected {self.ray_feature_dim} features, got {N}" + alpha_flat = alpha.reshape(HW, 1) if alpha is not None else None + return self._process(features, ray_directions, alpha_flat) + else: + raise ValueError(f"Expected input shape [B, H, W, N] or [H*W, N], got {features_shape}") + + def _process( + self, + features: torch.Tensor, + ray_directions: torch.Tensor, + alpha: torch.Tensor | None = None, + ) -> torch.Tensor: + if alpha is not None: + alpha_safe = alpha.clamp(min=1e-8) + features = features / alpha_safe + + # tcnn SH/Frequency encoding expects inputs in [0,1]: (v*sh_scale+1)/2 + dirs_unit_cube = (ray_directions * self.sh_scale + 1.0) * 0.5 + full_input = torch.cat([features, dirs_unit_cube], dim=-1) + rgb = self.network(full_input) + + if alpha is not None: + rgb = rgb * alpha_safe + + return rgb.float() + + def regularization_loss(self) -> torch.Tensor: + """Compute L2 regularization loss on decoder weights.""" + loss = torch.tensor(0.0, device=self.network.params.device) + loss = loss + torch.sum(self.network.params**2) + return loss + + def extra_repr(self) -> str: + return ( + f"ray_feature_dim={self.ray_feature_dim}, " + f"hidden_dim={self.hidden_dim}, " + f"num_layers={self.num_layers}, " + f"sh_scale={self.sh_scale}, " + f"output_activation={self.output_activation}" + ) diff --git a/threedgrut/model/features.py b/threedgrut/model/features.py new file mode 100644 index 00000000..aaf727a0 --- /dev/null +++ b/threedgrut/model/features.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import IntEnum + + +class Features: + """Enums and conf-driven getters for feature type, activation, and interpolation.""" + + class Type(IntEnum): + """Feature representation mode: integer value used directly in C preprocessor defines.""" + SH = 0 # Spherical harmonics + NHT = 1 # Neural harmonic texture + + @classmethod + def from_string(cls, value: str) -> "Features.Type": + value_lower = value.lower() + for member in cls: + if member.name.lower() == value_lower: + return member + raise ValueError(f"Invalid feature_type: '{value}'. Must be one of {[m.name.lower() for m in cls]}") + + class ActivationType(IntEnum): + NONE = 0 + SIREN = 1 + SINCOS = 2 + RELU = 3 + + class InterpolationType(IntEnum): + CENTER = 0 + BEZIER = 1 + + class InterpolationSupport(IntEnum): + CENTER = 0 + TETRAHEDRA = 1 + TRIANGLE = 2 + + def __init__(self, conf): + self._conf = conf + + @property + def transform_type(self): + """SH or NHT — integer value used directly in C preprocessor defines.""" + return Features.Type.from_string(self._conf.model.feature_type) + + @property + def activation_type(self): + """Feature activation type from nht_features.activation.type.""" + feature_type = self._conf.model.feature_type.lower() + if feature_type != "nht": + return Features.ActivationType.NONE + v = getattr(self._conf.model.nht_features, "activation", None) + if v is None: + return Features.ActivationType.NONE + t = getattr(v, "type", "none") + if isinstance(t, str): + t = t.lower() + if t == "none": + return Features.ActivationType.NONE + if t == "siren": + return Features.ActivationType.SIREN + if t == "sincos": + return Features.ActivationType.SINCOS + if t == "relu": + return Features.ActivationType.RELU + raise ValueError(f"Unknown nht_features.activation.type: {t}") + + @property + def activation_num_frequencies(self): + """Number of frequency bands. 1 when activation is none or relu.""" + if self.activation_type in (Features.ActivationType.NONE, Features.ActivationType.RELU): + return 1 + v = getattr(self._conf.model.nht_features, "activation", None) + return int(getattr(v, "num_frequencies", 1)) if v else 1 + + @property + def interpolation_type(self): + """CENTER (none/barycentric) or BEZIER.""" + feature_type = self._conf.model.feature_type.lower() + if feature_type != "nht": + return Features.InterpolationType.CENTER + v = getattr(self._conf.model.nht_features, "interpolation_type", "none").lower() + if v == "none": + return Features.InterpolationType.CENTER + if v == "barycentric": + return Features.InterpolationType.CENTER + if v == "bezier": + return Features.InterpolationType.BEZIER + raise ValueError(f"Unknown nht_features.interpolation_type: {v}") + + @property + def interpolation_support(self): + """CENTER, TETRAHEDRA (gaussian), or TRIANGLE (trisurfel).""" + feature_type = self._conf.model.feature_type.lower() + if feature_type != "nht": + return Features.InterpolationSupport.CENTER + v = getattr(self._conf.model.nht_features, "interpolation_type", "none").lower() + if v == "none": + return Features.InterpolationSupport.CENTER + if v == "barycentric": + primitive = getattr(self._conf.render, "primitive_type", "instances") + return Features.InterpolationSupport.TRIANGLE if primitive == "trisurfel" else Features.InterpolationSupport.TETRAHEDRA + raise ValueError(f"Unknown nht_features.interpolation_type: {v}") + + @property + def num_interpolation_points(self): + """1 for center support, 4 for barycentric (tetrahedra or trisurfel).""" + feature_type = self._conf.model.feature_type.lower() + if feature_type != "nht": + return 1 + if self.interpolation_support == Features.InterpolationSupport.CENTER: + return 1 + return 4 # barycentric: tetrahedra (4 verts) or trisurfel (2 coplanar triangles, 4 verts) + + @property + def particle_feature_dim(self): + """Total feature dim per particle (buffer stride). For NHT = nht_features.dim.""" + feature_type = self._conf.model.feature_type.lower() + if feature_type == "sh": + sh_degree = self._conf.model.progressive_training.max_n_features + return 3 * ((sh_degree + 1) ** 2) + elif feature_type == "nht": + return self._conf.model.nht_features.dim + raise ValueError(f"Unknown feature_type: {feature_type}") + + @property + def interp_point_feature_dim(self): + """Per-interpolation-point feature dim before activation.""" + feature_type = self._conf.model.feature_type.lower() + if feature_type != "nht": + return 3 + return self._conf.model.nht_features.dim // self.num_interpolation_points + + @property + def ray_feature_dim(self): + """Per-ray feature dim (decoder input): interp_point_dim * num_frequencies.""" + feature_type = self._conf.model.feature_type.lower() + if feature_type == "sh": + return 3 # RGB output + elif feature_type == "nht": + return self.interp_point_feature_dim * self.activation_num_frequencies + raise ValueError(f"Unknown feature_type: {feature_type}") + + @property + def feature_defines(self): + """C preprocessor defines for all feature-related kernel parameters.""" + return [ + f"-DPARTICLE_FEATURE_DIM={self.particle_feature_dim}", + f"-DRAY_FEATURE_DIM={self.ray_feature_dim}", + f"-DFEATURE_TRANSFORM_TYPE={self.transform_type}", + f"-DFEATURE_INTERPOLATION_TYPE={self.interpolation_type}", + f"-DFEATURE_INTERPOLATION_SUPPORT={self.interpolation_support}", + f"-DFEATURE_ACTIVATION_TYPE={self.activation_type}", + f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={self.activation_num_frequencies}", + f"-DINTERP_POINT_FEATURE_DIM={self.interp_point_feature_dim}", + ] diff --git a/threedgrut/model/model.py b/threedgrut/model/model.py index f5400c2a..460e97aa 100644 --- a/threedgrut/model/model.py +++ b/threedgrut/model/model.py @@ -42,9 +42,9 @@ to_np, to_torch, ) +from threedgrut.model.features import Features from threedgrut.utils.render import RGB2SH - class MixtureOfGaussians(torch.nn.Module, ExportableModel): """ """ @@ -54,10 +54,15 @@ def num_gaussians(self): def feature_fields(self) -> list[str]: """Returns a list of feature field names - subclasses can override""" - return [ - "features_albedo", - "features_specular", - ] + if self.feature_type == Features.Type.SH: + return [ + "features_albedo", + "features_specular", + ] + elif self.feature_type == Features.Type.NHT: + return ["features"] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}") def get_positions(self) -> torch.Tensor: return self.positions @@ -69,13 +74,24 @@ def get_n_active_features(self) -> int: return self.n_active_features def get_features_albedo(self) -> torch.Tensor: - return self.features_albedo + if self.feature_type == Features.Type.SH: + return self.features_albedo + else: + raise AttributeError(f"features_albedo not available in feature_type='{self.feature_type.name.lower()}' mode") def get_features_specular(self) -> torch.Tensor: - return self.features_specular + if self.feature_type == Features.Type.SH: + return self.features_specular + else: + raise AttributeError(f"features_specular not available in feature_type='{self.feature_type.name.lower()}' mode") def get_features(self): - return torch.cat((self.features_albedo, self.features_specular), dim=1) + if self.feature_type == Features.Type.SH: + return torch.cat((self.features_albedo, self.features_specular), dim=1) + elif self.feature_type == Features.Type.NHT: + return self.features # [N, K] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}") def get_scale(self, preactivation=False): if preactivation: @@ -124,21 +140,31 @@ def get_model_parameters(self) -> dict: # Add optimizer state dict "optimizer": self.optimizer.state_dict(), "config": self.conf, + # Feature type and dimensions + "feature_type": self.feature_type.name.lower(), # Store as string for serialization + "particle_feature_dim": self.particle_feature_dim, + "ray_feature_dim": self.ray_feature_dim, } if self.progressive_training: model_params["feature_dim_increase_interval"] = self.feature_dim_increase_interval model_params["feature_dim_increase_step"] = self.feature_dim_increase_step - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: model_params["features_albedo"] = self.features_albedo model_params["features_specular"] = self.features_specular + elif self.feature_type == Features.Type.NHT: + model_params["features"] = self.features return model_params def __init__(self, conf, scene_extent=None): super().__init__() + # Store config early - needed for feature type detection + self.conf = conf + self.scene_extent = scene_extent + sh_degree = conf.model.progressive_training.max_n_features render_sph_degree = conf.render.particle_radiance_sph_degree if sh_degree > render_sph_degree: @@ -157,16 +183,47 @@ def __init__(self, conf, scene_extent=None): ) # Rotation of each Gaussian represented as a unit quaternion [n_gaussians, 4] self.scale = torch.nn.Parameter(torch.empty([0, 3])) # Anisotropic scale of each Gaussian [n_gaussians, 3] self.density = torch.nn.Parameter(torch.empty([0, 1])) # Density of each Gaussian [n_gaussians, 1] - self.features_albedo = torch.nn.Parameter( - torch.empty([0, 3]) - ) # Feature vector of the 0th order SH coefficients [n_gaussians, 3] (We split it into two due to different learning rates) - self.features_specular = torch.nn.Parameter( - torch.empty([0, specular_dim]) - ) # Features of the higher order SH coefficients [n_gaussians, specular_dim] + + # Feature type configuration - determine feature storage mode + self.feature_type = Features.Type.from_string(self.conf.model.feature_type) + + primitive_type = (getattr(conf.render, "primitive_type", None) or "").lower() + if self.feature_type == Features.Type.NHT and primitive_type == "trisurfel": + raise ValueError( + "Trisurfels are not supported in NHT mode. Use primitive_type 'instances' or 'icosahedron'." + ) + + if self.feature_type == Features.Type.SH: + # Spherical harmonics mode: separate albedo and specular features + self.features_albedo = torch.nn.Parameter( + torch.empty([0, 3]) + ) # Feature vector of the 0th order SH coefficients [n_gaussians, 3] + self.features_specular = torch.nn.Parameter( + torch.empty([0, specular_dim]) + ) # Features of the higher order SH coefficients [n_gaussians, specular_dim] + self.particle_feature_dim = 3 + specular_dim # SH coeffs (input to tracer) + self.ray_feature_dim = 3 # RGB output from tracer + elif self.feature_type == Features.Type.NHT: + # NHT: per-particle feature vector, decoder maps rendered features -> RGB + feat = Features(conf) + num_points = feat.num_interpolation_points + nht_dim = int(conf.model.nht_features.dim) + if nht_dim % num_points != 0: + raise ValueError( + f"nht_features.dim={nht_dim} must be divisible by num_interpolation_points={num_points} " + f"(interpolation_type + primitive)" + ) + self.nht_num_interpolation_points = num_points + self.particle_feature_dim = feat.particle_feature_dim + self.ray_feature_dim = feat.ray_feature_dim + self.features = torch.nn.Parameter( + torch.empty([0, self.particle_feature_dim]) + ) # NHT features [n_gaussians, particle_feature_dim] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}. Must be 'sh' or 'nht'.") + self.max_sh_degree = sh_degree - self.conf = conf - self.scene_extent = scene_extent self.positions_gradient_norm = None self.device = "cuda" @@ -180,7 +237,6 @@ def __init__(self, conf, scene_extent=None): self.background = background.make(self.conf.model.background.name, self.conf.model.background) # Check if we would like to do progressive training - self.feature_type = self.conf.model.progressive_training.feature_type self.n_active_features = min(self.conf.model.progressive_training.init_n_features, sh_degree) self.max_n_features = ( sh_degree # For SH, this is the SH degree (clamped if > render.particle_radiance_sph_degree) @@ -191,7 +247,6 @@ def __init__(self, conf, scene_extent=None): self.feature_dim_increase_step = self.conf.model.progressive_training.increase_step self.progressive_training = True - # Rendering method if conf.render.method == "3dgrt": self.renderer = threedgrt_tracer.Tracer(conf) elif conf.render.method == "3dgut": @@ -219,8 +274,12 @@ def freeze_gaussians(self) -> None: self.rotation.requires_grad = False self.scale.requires_grad = False self.density.requires_grad = False - self.features_albedo.requires_grad = False - self.features_specular.requires_grad = False + + if self.feature_type == Features.Type.SH: + self.features_albedo.requires_grad = False + self.features_specular.requires_grad = False + elif self.feature_type == Features.Type.NHT: + self.features.requires_grad = False self._gaussians_frozen = True logger.info("❄️ [Distillation] Gaussian parameters frozen") @@ -232,12 +291,14 @@ def validate_fields(self): assert self.rotation.shape == (num_gaussians, 4) assert self.scale.shape == (num_gaussians, 3) - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: assert self.features_albedo.shape == (num_gaussians, 3) specular_sh_dims = sh_degree_to_specular_dim(self.max_n_features) assert self.features_specular.shape == (num_gaussians, specular_sh_dims) + elif self.feature_type == Features.Type.NHT: + assert self.features.shape == (num_gaussians, self.particle_feature_dim) else: - raise ValueError("Neural features not yet supported.") + raise ValueError(f"Unknown feature_type: {self.feature_type}") def init_from_colmap(self, root_path: str, observer_pts): # Special case for scannetpp dataset @@ -327,6 +388,8 @@ def init_from_fused_point_cloud(self, pc_path: str, observer_pts): ) def init_from_pretrained_point_cloud(self, pc_path: str, set_optimizable_parameters: bool = True): + if self.feature_type != Features.Type.SH: + raise NotImplementedError(f"init_from_pretrained_point_cloud only supports feature_type='sh', got '{self.feature_type.name.lower()}'") data = PlyData.read(pc_path) num_gaussians = len(data["vertex"]) self.positions = torch.nn.Parameter( @@ -477,14 +540,24 @@ def init_from_random_point_cloud( # sh albedo in [0, 0.0039] fused_color = torch.rand((num_gaussians, 3), dtype=dtype, device=self.device) / 255.0 - features_albedo = features_specular = None - if self.feature_type == "sh": + # Initialize features based on feature_type + if self.feature_type == Features.Type.SH: features_albedo = fused_color.contiguous() max_sh_degree = self.max_n_features num_specular_features = sh_degree_to_specular_dim(max_sh_degree) features_specular = torch.zeros( (num_gaussians, num_specular_features), dtype=dtype, device=self.device ).contiguous() + elif self.feature_type == Features.Type.NHT: + # Initialize learned features with uniform [-pi/2, pi/2] for SIREN + act_type = Features(self.conf).activation_type + if act_type in (Features.ActivationType.SIREN, Features.ActivationType.SINCOS): + features = ( + torch.rand((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) + * 3.141592653589793 - 1.5707963267948966 + ) + else: + features = torch.randn((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) * 0.1 dist = torch.clamp_min(nearest_neighbor_dist_cpuKD(fused_point_cloud), 1e-3) scales = torch.log(dist * self.conf.model.default_scale_factor)[..., None].repeat(1, 3) @@ -500,24 +573,82 @@ def init_from_random_point_cloud( self.rotation = torch.nn.Parameter(rots.to(dtype=dtype, device=self.device)) self.scale = torch.nn.Parameter(scales.to(dtype=dtype, device=self.device)) self.density = torch.nn.Parameter(opacities.to(dtype=dtype, device=self.device)) - self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) - self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + + if self.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) + self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + elif self.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(features.to(dtype=dtype, device=self.device)) if set_optimizable_parameters: self.set_optimizable_parameters() self.validate_fields() def init_from_checkpoint(self, checkpoint: dict, setup_optimizer=True): + # Backward compatibility: detect legacy checkpoints without feature_type + if "feature_type" not in checkpoint and "features_albedo" in checkpoint: + logger.info("Loading legacy checkpoint - auto-detecting feature_type='sh'") + checkpoint["feature_type"] = "sh" + checkpoint["particle_feature_dim"] = checkpoint["features_albedo"].shape[1] + checkpoint["features_specular"].shape[1] + checkpoint["ray_feature_dim"] = 3 + + # Load features based on feature_type (convert string to enum) + checkpoint_feature_type_str = checkpoint.get("feature_type", "sh") + checkpoint_feature_type = Features.Type.from_string(checkpoint_feature_type_str) + + # NHT: 3DGUT is compiled with PARTICLE_FEATURE_DIM / RAY_FEATURE_DIM from current config. + # Checkpoints must match those compile-time constants or CUDA will read past feature buffers. + if checkpoint_feature_type == Features.Type.NHT: + if "features" not in checkpoint: + raise ValueError("NHT checkpoint missing 'features' tensor") + feat = checkpoint["features"] + ck_pf = int(feat.shape[1]) + ck_rf = checkpoint.get("ray_feature_dim") + if ck_pf != self.particle_feature_dim: + raise ValueError( + f"NHT checkpoint features width is {ck_pf} but this build expects " + f"particle_feature_dim={self.particle_feature_dim} from config " + f"(model.nht_features.dim / interpolation). The 3DGUT CUDA extension was compiled for " + f"the config value; use the same nht_features (and render.primitive_type) as the run " + f"that produced the checkpoint, or train from scratch." + ) + if ck_rf is not None and int(ck_rf) != self.ray_feature_dim: + raise ValueError( + f"NHT checkpoint ray_feature_dim={ck_rf} does not match config ray_feature_dim=" + f"{self.ray_feature_dim}. Align model.nht_features.activation with the checkpoint run." + ) + if "particle_feature_dim" in checkpoint and int(checkpoint["particle_feature_dim"]) != ck_pf: + logger.warning( + f"Checkpoint particle_feature_dim={checkpoint['particle_feature_dim']} disagrees with " + f"features.shape[1]={ck_pf}; using tensor shape." + ) + + # Load basic parameters self.positions = checkpoint["positions"] self.rotation = checkpoint["rotation"] self.scale = checkpoint["scale"] self.density = checkpoint["density"] - self.features_albedo = checkpoint["features_albedo"] - self.features_specular = checkpoint["features_specular"] self.n_active_features = checkpoint["n_active_features"] self.max_n_features = checkpoint["max_n_features"] self.scene_extent = checkpoint["scene_extent"] + # Load feature dimensions. For NHT, keep config-derived dims (validated above vs checkpoint tensors); + # stale metadata keys must not override after a successful shape check. + if checkpoint_feature_type != Features.Type.NHT: + if "particle_feature_dim" in checkpoint: + self.particle_feature_dim = checkpoint["particle_feature_dim"] + if "ray_feature_dim" in checkpoint: + self.ray_feature_dim = checkpoint["ray_feature_dim"] + + if checkpoint_feature_type == Features.Type.SH: + self.features_albedo = checkpoint["features_albedo"] + self.features_specular = checkpoint["features_specular"] + elif checkpoint_feature_type == Features.Type.NHT: + self.features = checkpoint["features"] + self.nht_num_interpolation_points = Features(self.conf).num_interpolation_points + else: + raise ValueError(f"Unknown feature_type in checkpoint: {checkpoint_feature_type}") + if self.progressive_training: self.feature_dim_increase_interval = checkpoint["feature_dim_increase_interval"] self.feature_dim_increase_step = checkpoint["feature_dim_increase_step"] @@ -588,17 +719,41 @@ def default_initialize_from_points(self, pts, observer_pts, colors=None, use_obs if colors is None: colors = torch.randint(0, 256, (N, 3), dtype=torch.uint8, device=self.device, generator=rng) - features_albedo = to_torch(RGB2SH(to_np(colors.float() / 255.0)), device=self.device) - - num_specular_dims = sh_degree_to_specular_dim(self.max_n_features) - features_specular = torch.zeros((N, num_specular_dims)) + # Initialize features based on feature_type + if self.feature_type == Features.Type.SH: + features_albedo = to_torch(RGB2SH(to_np(colors.float() / 255.0)), device=self.device) + num_specular_dims = sh_degree_to_specular_dim(self.max_n_features) + features_specular = torch.zeros((N, num_specular_dims)) + elif self.feature_type == Features.Type.NHT: + act_type = Features(self.conf).activation_type + if act_type in (Features.ActivationType.SIREN, Features.ActivationType.SINCOS): + # Uniform [-pi/2, pi/2]: symmetric around 0 so sin activations start + # with zero mean and balanced positive/negative gradients. + features = ( + torch.rand((N, self.particle_feature_dim), dtype=dtype, device=self.device) + * 3.141592653589793 # pi + - 1.5707963267948966 # - pi/2 + ) + elif act_type == Features.ActivationType.RELU: + # Same as SH degree 0: init from RGB so relu(features)=radiance in [0,1]. + rgb = (colors.float() / 255.0).to(dtype=dtype, device=self.device) + if self.particle_feature_dim == 3: + features = rgb + else: + features = rgb.repeat(1, (self.particle_feature_dim + 2) // 3)[:, : self.particle_feature_dim] + else: + features = torch.randn((N, self.particle_feature_dim), dtype=dtype, device=self.device) * 0.01 self.positions = torch.nn.Parameter(positions.to(dtype=dtype, device=self.device)) self.rotation = torch.nn.Parameter(rots.to(dtype=dtype, device=self.device)) self.scale = torch.nn.Parameter(scales.to(dtype=dtype, device=self.device)) self.density = torch.nn.Parameter(opacities.to(dtype=dtype, device=self.device)) - self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) - self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + + if self.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) + self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + elif self.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(features.to(dtype=dtype, device=self.device)) self.set_optimizable_parameters() self.setup_optimizer() @@ -607,6 +762,11 @@ def default_initialize_from_points(self, pts, observer_pts, colors=None, use_obs def setup_optimizer(self, state_dict=None): params = [] for name, args in self.conf.optimizer.params.items(): + # Skip parameters that don't exist (e.g., 'features' in SH mode or 'features_albedo' in learned mode) + if not hasattr(self, name): + logger.info(f"Skipping optimizer parameter '{name}' - not present in {self.feature_type.name.lower()} mode") + continue + module = getattr(self, name) # If the module is a torch.nn.Module, we can add all of its trainable parameters to the optimizer @@ -644,15 +804,28 @@ def setup_optimizer(self, state_dict=None): def setup_scheduler(self): self.schedulers = {} for name, args in self.conf.scheduler.items(): - if args.type is not None and getattr(self, name).requires_grad: - if name == "positions": - self.schedulers[name] = get_scheduler(args.type)( - lr_init=args.lr_init * self.scene_extent, - lr_final=args.lr_final * self.scene_extent, - max_steps=args.max_steps, - ) - else: - self.schedulers[name] = get_scheduler(args.type)(**args) + if not hasattr(self, name): + continue + attr = getattr(self, name) + if not (hasattr(attr, "requires_grad") and attr.requires_grad): + continue + if args.type is None: + continue + if name == "positions": + self.schedulers[name] = get_scheduler(args.type)( + lr_init=args.lr_init * self.scene_extent, + lr_final=args.lr_final * self.scene_extent, + max_steps=args.max_steps, + ) + elif name == "features": + lr_init = getattr(self.conf.optimizer.params.features, "lr", 0.07) + decay_final = getattr(args, "decay_final", 0.001) + lr_final = lr_init * decay_final + self.schedulers[name] = get_scheduler(args.type)( + lr_init=lr_init, lr_final=lr_final, max_steps=args.max_steps + ) + else: + self.schedulers[name] = get_scheduler(args.type)(**args) def scheduler_step(self, step): for param_group in self.optimizer.param_groups: @@ -664,10 +837,6 @@ def scheduler_step(self, step): def set_optimizable_parameters(self): if not self.conf.model.optimize_density: self.density.requires_grad = False - if not self.conf.model.optimize_features_albedo: - self.features_albedo.requires_grad = False - if not self.conf.model.optimize_features_specular: - self.features_specular.requires_grad = False if not self.conf.model.optimize_rotation: self.rotation.requires_grad = False if not self.conf.model.optimize_scale: @@ -675,6 +844,17 @@ def set_optimizable_parameters(self): if not self.conf.model.optimize_position: self.positions.requires_grad = False + # Handle feature optimization based on feature_type + if self.feature_type == Features.Type.SH: + if not self.conf.model.optimize_features_albedo: + self.features_albedo.requires_grad = False + if not self.conf.model.optimize_features_specular: + self.features_specular.requires_grad = False + elif self.feature_type == Features.Type.NHT: + # For learned features, check if optimize_features config exists + if not self.conf.model.optimize_features: + self.features.requires_grad = False + def update_optimizable_parameters(self, optimizable_tensors: dict[str, torch.Tensor]): for name, value in optimizable_tensors.items(): setattr(self, name, value) @@ -683,7 +863,7 @@ def increase_num_active_features(self) -> None: self.n_active_features = min(self.max_n_features, self.n_active_features + self.feature_dim_increase_step) def get_active_feature_mask(self) -> torch.Tensor: - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: current_sh_degree = self.n_active_features max_sh_degree = self.max_n_features active_features = sh_degree_to_num_features(current_sh_degree) @@ -700,7 +880,9 @@ def clamp_density(self): optimizable_tensors = self.replace_tensor_to_optimizer(updated_densities, "density") self.density = optimizable_tensors["density"] - def forward(self, batch: Batch, train=False, frame_id=0) -> dict[str, torch.Tensor]: + def forward( + self, batch: Batch, train=False, frame_id=0 + ) -> dict[str, torch.Tensor]: """ Args: batch: a Batch structure containing the input data @@ -731,6 +913,8 @@ def export_ply(self, mogt_path: str): @torch.no_grad() def init_from_ply(self, mogt_path: str, init_model=True): + if self.feature_type != Features.Type.SH: + raise NotImplementedError(f"init_from_ply only supports feature_type='sh', got '{self.feature_type.name.lower()}'") plydata = PlyData.read(mogt_path) mogt_pos = np.stack( @@ -812,15 +996,21 @@ def copy_fields(self, other, deepcopy=False): self.rotation = torch.nn.Parameter(other.rotation.clone()) self.scale = torch.nn.Parameter(other.scale.clone()) self.density = torch.nn.Parameter(other.density.clone()) - self.features_albedo = torch.nn.Parameter(other.features_albedo.clone()) - self.features_specular = torch.nn.Parameter(other.features_specular.clone()) - else: # shared tensors + if other.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(other.features_albedo.clone()) + self.features_specular = torch.nn.Parameter(other.features_specular.clone()) + elif other.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(other.features.clone()) + else: self.positions = torch.nn.Parameter(other.positions) self.rotation = torch.nn.Parameter(other.rotation) self.scale = torch.nn.Parameter(other.scale) self.density = torch.nn.Parameter(other.density) - self.features_albedo = torch.nn.Parameter(other.features_albedo) - self.features_specular = torch.nn.Parameter(other.features_specular) + if other.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(other.features_albedo) + self.features_specular = torch.nn.Parameter(other.features_specular) + elif other.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(other.features) self.max_sh_degree = other.max_sh_degree self.n_active_features = other.n_active_features self.scene_extent = other.scene_extent @@ -828,6 +1018,11 @@ def copy_fields(self, other, deepcopy=False): self.feature_dim_increase_interval = other.feature_dim_increase_interval self.feature_dim_increase_step = other.feature_dim_increase_step self.background = other.background + self.feature_type = other.feature_type + self.particle_feature_dim = other.particle_feature_dim + self.ray_feature_dim = other.ray_feature_dim + if hasattr(other, "nht_num_interpolation_points"): + self.nht_num_interpolation_points = other.nht_num_interpolation_points self.validate_fields() def clone(self): @@ -842,8 +1037,11 @@ def __getitem__(self, idx): sliced.rotation = torch.nn.Parameter(sliced.rotation[idx]) sliced.scale = torch.nn.Parameter(sliced.scale[idx]) sliced.density = torch.nn.Parameter(sliced.density[idx]) - sliced.features_albedo = torch.nn.Parameter(sliced.features_albedo[idx]) - sliced.features_specular = torch.nn.Parameter(sliced.features_specular[idx]) + if self.feature_type == Features.Type.SH: + sliced.features_albedo = torch.nn.Parameter(sliced.features_albedo[idx]) + sliced.features_specular = torch.nn.Parameter(sliced.features_specular[idx]) + elif self.feature_type == Features.Type.NHT: + sliced.features = torch.nn.Parameter(sliced.features[idx]) return sliced def __len__(self): diff --git a/threedgrut/render.py b/threedgrut/render.py index 313cff66..18d69bb1 100644 --- a/threedgrut/render.py +++ b/threedgrut/render.py @@ -29,7 +29,7 @@ from threedgrut.utils.color_correct import color_correct_affine from threedgrut.utils.logger import logger from threedgrut.utils.misc import create_summary_writer -from threedgrut.utils.render import apply_post_processing +from threedgrut.utils.render import apply_background, apply_feature_decoder, apply_post_processing class Renderer: @@ -44,6 +44,7 @@ def __init__( writer=None, compute_extra_metrics=True, post_processing=None, + feature_decoder=None, ) -> None: if path: # Replace the path to the test data @@ -59,6 +60,7 @@ def __init__( self.writer = writer self.compute_extra_metrics = compute_extra_metrics self.post_processing = post_processing + self.feature_decoder = feature_decoder if conf.model.background.color == "black": self.bg_color = torch.zeros((3,), dtype=torch.float32, device="cuda") @@ -148,6 +150,40 @@ def from_checkpoint( num_frames = post_processing.exposure_params.shape[0] logger.info(f"📷 {method.upper()} loaded from checkpoint: {num_cameras} cameras, {num_frames} frames") + # Load feature decoder for nht models + feature_decoder = None + if "feature_decoder" in checkpoint: + from threedgrut.model.features import Features + from threedgrut.model.feature_decoder import FeatureDecoder + + if model.feature_type == Features.Type.NHT: + conf_model = conf.model + dec = conf_model.nht_decoder + hidden_dim = dec.hidden_dim + num_layers = getattr(dec, "num_layers", 4) + dir_encoding = getattr(dec, "dir_encoding", "SphericalHarmonics") + dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) + output_activation = getattr(dec, "output_activation", "Sigmoid") + ema_decay = getattr(dec, "ema_decay", 0.0) + ema_start_step = getattr(dec, "ema_start_step", 0) + feature_decoder = FeatureDecoder( + ray_feature_dim=model.ray_feature_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + dir_encoding=dir_encoding, + dir_encoding_degree=dir_encoding_degree, + output_activation=output_activation, + ema_decay=ema_decay, + ema_start_step=ema_start_step, + ).to("cuda") + feature_decoder.load_state_dict(checkpoint["feature_decoder"]["module"]) + ema_state = checkpoint["feature_decoder"].get("ema") + if ema_state is not None: + feature_decoder.load_ema_state_dict(ema_state) + feature_decoder.apply_ema_shadow() + feature_decoder.eval() + logger.info("🎨 Feature decoder loaded from checkpoint") + return Renderer( model=model, conf=conf, @@ -158,6 +194,7 @@ def from_checkpoint( writer=writer, compute_extra_metrics=computes_extra_metrics, post_processing=post_processing, + feature_decoder=feature_decoder, ) @classmethod @@ -171,6 +208,7 @@ def from_preloaded_model( global_step=None, compute_extra_metrics=False, post_processing=None, + feature_decoder=None, ): """Loads checkpoint for test path.""" @@ -188,6 +226,7 @@ def from_preloaded_model( writer=writer, compute_extra_metrics=compute_extra_metrics, post_processing=post_processing, + feature_decoder=feature_decoder, ) @torch.no_grad() @@ -236,20 +275,23 @@ def render_all(self): # Compute the outputs of a single batch outputs = self.model(gpu_batch) + if self.feature_decoder is not None: + outputs = apply_feature_decoder(self.feature_decoder, outputs, gpu_batch, training=False) + outputs = apply_background(self.model.background, outputs, gpu_batch, training=False) # Apply post-processing if self.post_processing is not None: outputs = apply_post_processing(self.post_processing, outputs, gpu_batch, training=False) - pred_rgb_full = outputs["pred_rgb"] + pred_features_full = outputs["pred_features"] rgb_gt_full = gpu_batch.rgb_gt # The values are already alpha composited with the background torchvision.utils.save_image( - pred_rgb_full.squeeze(0).permute(2, 0, 1), + pred_features_full.squeeze(0).permute(2, 0, 1), os.path.join(output_path_renders, "{0:05d}".format(iteration) + ".png"), ) - pred_img_to_write = pred_rgb_full[-1].clip(0, 1.0) + pred_img_to_write = pred_features_full[-1].clip(0, 1.0) gt_img_to_write = rgb_gt_full[-1].clip(0, 1.0) if self.save_gt: @@ -259,7 +301,7 @@ def render_all(self): ) # Compute the loss - psnr_single_img = criterions["psnr"](outputs["pred_rgb"], gpu_batch.rgb_gt).item() + psnr_single_img = criterions["psnr"](outputs["pred_features"], gpu_batch.rgb_gt).item() psnr.append(psnr_single_img) # evaluation on valid rays only logger.info(f"Frame {iteration}, PSNR: {psnr[-1]}") @@ -276,29 +318,29 @@ def render_all(self): # evaluate on full image ssim.append( criterions["ssim"]( - pred_rgb_full.permute(0, 3, 1, 2), + pred_features_full.permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) lpips.append( criterions["lpips"]( - pred_rgb_full.clip(0, 1).permute(0, 3, 1, 2), + pred_features_full.clip(0, 1).permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) # Color-corrected metrics - pred_rgb_cc = color_correct_affine(pred_rgb_full, rgb_gt_full) - cc_psnr.append(criterions["psnr"](pred_rgb_cc, rgb_gt_full).item()) + pred_features_cc = color_correct_affine(pred_features_full, rgb_gt_full) + cc_psnr.append(criterions["psnr"](pred_features_cc, rgb_gt_full).item()) cc_ssim.append( criterions["ssim"]( - pred_rgb_cc.permute(0, 3, 1, 2), + pred_features_cc.permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) cc_lpips.append( criterions["lpips"]( - pred_rgb_cc.clip(0, 1).permute(0, 3, 1, 2), + pred_features_cc.clip(0, 1).permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) @@ -340,6 +382,7 @@ def render_all(self): mean_cc_psnr=float(mean_cc_psnr), mean_cc_ssim=float(mean_cc_ssim), mean_cc_lpips=float(mean_cc_lpips), + mean_inference_time_ms=float(mean_inference_time), ) metrics_path = os.path.join(self.out_dir, "metrics.json") with open(metrics_path, "w") as f: diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 223fa539..88307454 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -39,7 +39,7 @@ from threedgrut.strategy.base import BaseStrategy from threedgrut.utils.logger import logger from threedgrut.utils.misc import check_step_condition, create_summary_writer, jet_map -from threedgrut.utils.render import apply_post_processing +from threedgrut.utils.render import apply_background, apply_feature_decoder, apply_post_processing from threedgrut.utils.timer import CudaTimer @@ -129,9 +129,11 @@ def __init__(self, conf: DictConfig, device=None): self.init_densification_and_pruning_strategy(conf) logger.log_rule("Setup Model Weights & Training") self.init_metrics() + # Feature decoder and post-processing must exist before setup_training so resume can load their state. + self.init_feature_decoder(conf) + self.init_post_processing(conf) self.setup_training(conf, self.model, self.train_dataset) self.init_experiments_tracking(conf) - self.init_post_processing(conf) self.init_gui(conf, self.model, self.train_dataset, self.val_dataset, self.scene_bbox) def init_dataloaders(self, conf: DictConfig): @@ -226,6 +228,18 @@ def setup_training( self.strategy.init_densification_buffer(checkpoint) global_step = checkpoint["global_step"] + # Restore feature decoder state (skip if architecture drifted vs checkpoint) + if "feature_decoder" in checkpoint and self.feature_decoder is not None: + fd_ckpt = checkpoint["feature_decoder"] + self.feature_decoder.load_state_dict(fd_ckpt["module"]) + self.feature_decoder_optimizer.load_state_dict(fd_ckpt["optimizer"]) + self.feature_decoder_scheduler.load_state_dict(fd_ckpt["scheduler"]) + ema_state = fd_ckpt.get("ema") + if ema_state is not None: + self.feature_decoder.load_ema_state_dict(ema_state) + self.feature_decoder.apply_ema_shadow() + logger.info("🎨 Feature decoder state restored from checkpoint") + # Restore post-processing state if "post_processing" in checkpoint and self.post_processing is not None: self.post_processing.load_state_dict(checkpoint["post_processing"]["module"]) @@ -240,6 +254,7 @@ def setup_training( ): sched.load_state_dict(sched_state) logger.info("📷 Post-processing state restored from checkpoint") + model.build_acc() elif conf.import_ply.enabled: ply_path = ( conf.import_ply.path @@ -329,15 +344,16 @@ def init_gui( ): gui = None + feature_decoder = getattr(self, "feature_decoder", None) if conf.with_gui: from threedgrut.utils.gui import GUI - gui = GUI(conf, model, train_dataset, val_dataset, scene_bbox) + gui = GUI(conf, model, train_dataset, val_dataset, scene_bbox, feature_decoder=feature_decoder) elif conf.with_viser_gui: from threedgrut.utils.viser_gui_util import ViserGUI - gui = ViserGUI(conf, model, train_dataset, val_dataset, scene_bbox) + gui = ViserGUI(conf, model, train_dataset, val_dataset, scene_bbox, feature_decoder=feature_decoder) self.gui = gui @@ -424,6 +440,82 @@ def init_post_processing(self, conf: DictConfig): else: raise ValueError(f"Unknown post-processing method: {method}") + def init_feature_decoder(self, conf: DictConfig): + """Initialize feature decoder for learned features mode.""" + from threedgrut.model.features import Features + + if self.model.feature_type != Features.Type.NHT: + self.feature_decoder = None + self.feature_decoder_optimizer = None + self.feature_decoder_scheduler = None + return + + dec_conf = conf.model.nht_decoder + if not getattr(dec_conf, "enabled", True): + self.feature_decoder = None + self.feature_decoder_optimizer = None + self.feature_decoder_scheduler = None + return + + from threedgrut.model.feature_decoder import FeatureDecoder + + ray_feature_dim = self.model.ray_feature_dim + dec = conf.model.nht_decoder + hidden_dim = dec.hidden_dim + num_layers = getattr(dec, "num_layers", 4) + dir_encoding = getattr(dec, "dir_encoding", "SphericalHarmonics") + dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) + sh_scale = getattr(dec, "sh_scale", 1.0) + output_activation = getattr(dec, "output_activation", "Sigmoid") + ema_decay = getattr(dec_conf, "ema_decay", 0.0) + ema_start_step = getattr(dec_conf, "ema_start_step", 0) + logger.info(f"Initializing FeatureDecoder: {ray_feature_dim} -> 3 RGB") + self.feature_decoder = FeatureDecoder( + ray_feature_dim=ray_feature_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + dir_encoding=dir_encoding, + dir_encoding_degree=dir_encoding_degree, + sh_scale=sh_scale, + output_activation=output_activation, + ema_decay=ema_decay, + ema_start_step=ema_start_step, + ).to(self.device) + + lr = dec.learning_rate + weight_decay = getattr(dec, "reg_weight", 0.0) + self.feature_decoder_optimizer = torch.optim.Adam( + self.feature_decoder.parameters(), + lr=lr, + weight_decay=weight_decay, + ) + + scheduler_conf = dec.scheduler + max_steps = getattr(conf, "n_iterations", 30000) + decay_final = float(getattr(scheduler_conf, "decay_final", 0.001)) + if scheduler_conf.type == "exponential": + gamma = decay_final ** (1.0 / max_steps) + self.feature_decoder_scheduler = torch.optim.lr_scheduler.ExponentialLR( + self.feature_decoder_optimizer, + gamma=gamma, + ) + elif scheduler_conf.type == "cosine": + eta_min = lr * decay_final + self.feature_decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.feature_decoder_optimizer, + T_max=max_steps, + eta_min=eta_min, + ) + else: + raise ValueError(f"Unknown scheduler type: {scheduler_conf.type}") + + if ema_decay > 0: + logger.info(f"🎨 FeatureDecoder EMA: decay={ema_decay}, start_step={ema_start_step}") + logger.info( + f"🎨 FeatureDecoder optimizer: lr={lr}, " + f"weight_decay={weight_decay}, scheduler={scheduler_conf.type}" + ) + @torch.cuda.nvtx.range("get_metrics") def get_metrics( self, @@ -448,7 +540,7 @@ def get_metrics( step = self.global_step rgb_gt = gpu_batch.rgb_gt - rgb_pred = outputs["pred_rgb"] + rgb_pred = outputs["pred_features"] psnr = self.criterions["psnr"] ssim = self.criterions["ssim"] @@ -471,18 +563,18 @@ def get_metrics( metrics["psnr"] = psnr(rgb_pred, rgb_gt).item() rgb_gt_full = rgb_gt.permute(0, 3, 1, 2) - pred_rgb_full = rgb_pred.permute(0, 3, 1, 2) - pred_rgb_full_clipped = rgb_pred.clip(0, 1).permute(0, 3, 1, 2) + pred_features_full = rgb_pred.permute(0, 3, 1, 2) + pred_features_full_clipped = rgb_pred.clip(0, 1).permute(0, 3, 1, 2) with torch.cuda.nvtx.range(f"criterions_ssim"): - metrics["ssim"] = ssim(pred_rgb_full, rgb_gt_full).item() + metrics["ssim"] = ssim(pred_features_full, rgb_gt_full).item() with torch.cuda.nvtx.range(f"criterions_lpips"): - metrics["lpips"] = lpips(pred_rgb_full_clipped, rgb_gt_full).item() + metrics["lpips"] = lpips(pred_features_full_clipped, rgb_gt_full).item() if iteration in self.conf.writer.log_image_views: metrics["img_hit_counts"] = jet_map(outputs["hits_count"][-1], self.conf.writer.max_num_hits) metrics["img_gt"] = gpu_batch.rgb_gt[-1].clip(0, 1.0) - metrics["img_pred"] = outputs["pred_rgb"][-1].clip(0, 1.0) + metrics["img_pred"] = outputs["pred_features"][-1].clip(0, 1.0) metrics["img_pred_dist"] = jet_map(outputs["pred_dist"][-1], 100) metrics["img_pred_opacity"] = jet_map(outputs["pred_opacity"][-1], 1) @@ -508,7 +600,7 @@ def get_losses( losses: dictionary of loss terms computed for current batch. """ rgb_gt = gpu_batch.rgb_gt - rgb_pred = outputs["pred_rgb"] + rgb_pred = outputs["pred_features"] mask = gpu_batch.mask # Mask out the invalid pixels if the mask is provided @@ -529,7 +621,7 @@ def get_losses( lambda_l2 = 0.0 if self.conf.loss.use_l2: with torch.cuda.nvtx.range(f"loss-l2"): - loss_l2 = torch.nn.functional.mse_loss(outputs["pred_rgb"], rgb_gt) + loss_l2 = torch.nn.functional.mse_loss(outputs["pred_features"], rgb_gt) lambda_l2 = self.conf.loss.lambda_l2 # DSSIM loss @@ -538,8 +630,8 @@ def get_losses( if self.conf.loss.use_ssim: with torch.cuda.nvtx.range(f"loss-ssim"): rgb_gt_full = torch.permute(rgb_gt, (0, 3, 1, 2)) - pred_rgb_full = torch.permute(rgb_pred, (0, 3, 1, 2)) - loss_ssim = 1.0 - ssim(pred_rgb_full, rgb_gt_full) + pred_features_full = torch.permute(rgb_pred, (0, 3, 1, 2)) + loss_ssim = 1.0 - ssim(pred_features_full, rgb_gt_full) lambda_ssim = self.conf.loss.lambda_ssim # Opacity regularization @@ -842,6 +934,7 @@ def on_training_end(self): global_step=self.global_step, compute_extra_metrics=conf.compute_extra_metrics, post_processing=self.post_processing, + feature_decoder=self.feature_decoder, ) renderer.render_all() @@ -860,6 +953,24 @@ def save_checkpoint(self, last_checkpoint: bool = False): strategy_parameters = self.strategy.get_strategy_parameters() parameters = {**parameters, **strategy_parameters} + # Add feature decoder state to checkpoint (module + optimizer + scheduler + EMA) + if self.feature_decoder is not None: + dec = self.feature_decoder + parameters["feature_decoder"] = { + "module": dec.state_dict(), + "optimizer": self.feature_decoder_optimizer.state_dict(), + "scheduler": self.feature_decoder_scheduler.state_dict(), + "arch": { + "ray_feature_dim": dec.ray_feature_dim, + "hidden_dim": dec.hidden_dim, + "num_layers": dec.num_layers, + "output_activation": dec.output_activation, + }, + } + ema_state = self.feature_decoder.ema_state_dict() + if ema_state: + parameters["feature_decoder"]["ema"] = ema_state + # Add post-processing state to checkpoint (module + optimizers + schedulers) if self.post_processing is not None: parameters["post_processing"] = { @@ -926,6 +1037,8 @@ def run_train_iter( with torch.cuda.nvtx.range(f"train_iter{global_step}_get_gpu_batch"): gpu_batch = self.train_dataset.get_gpu_batch_with_intrinsics(batch) + profilers["step_total"].start() + # Perform validation if required is_time_to_validate = (global_step > 0 or conf.validate_first) and (global_step % self.val_frequency == 0) if is_time_to_validate: @@ -937,6 +1050,14 @@ def run_train_iter( outputs = self.model(gpu_batch, train=True, frame_id=global_step) profilers["inference"].end() + # Apply feature decoder to convert N-dimensional features to RGB + if self.feature_decoder is not None: + with torch.cuda.nvtx.range(f"train_{global_step}_feature_decoder"): + profilers["feature_decoder"].start() + outputs = apply_feature_decoder(self.feature_decoder, outputs, gpu_batch, training=True) + profilers["feature_decoder"].end() + outputs = apply_background(self.model.background, outputs, gpu_batch, training=True) + # Apply post-processing to rendered output if self.post_processing is not None: with torch.cuda.nvtx.range(f"train_{global_step}_post_processing"): @@ -945,6 +1066,14 @@ def run_train_iter( # Compute the losses of a single batch with torch.cuda.nvtx.range(f"train_{global_step}_loss"): batch_losses = self.get_losses(gpu_batch, outputs) + + # Add feature decoder regularization loss + if self.feature_decoder is not None and "decoder_reg_loss" in outputs: + decoder_reg_weight = conf.model.nht_decoder.reg_weight + decoder_reg_loss = decoder_reg_weight * outputs["decoder_reg_loss"] + batch_losses["total_loss"] = batch_losses["total_loss"] + decoder_reg_loss + batch_losses["decoder_reg_loss"] = decoder_reg_loss + # Add post-processing regularization loss if self.post_processing is not None: post_processing_reg_loss = self.post_processing.get_regularization_loss() @@ -992,6 +1121,14 @@ def run_train_iter( with torch.cuda.nvtx.range(f"train_{global_step}_scheduler"): self.model.scheduler_step(global_step) + # Feature decoder optimizer/scheduler step + if self.feature_decoder_optimizer is not None: + with torch.cuda.nvtx.range(f"train_{global_step}_feature_decoder_opt"): + self.feature_decoder_optimizer.step() + self.feature_decoder_optimizer.zero_grad() + self.feature_decoder_scheduler.step() + self.feature_decoder.ema_update(global_step) + # Post-processing optimizer/scheduler step if self.post_processing_optimizers is not None: with torch.cuda.nvtx.range(f"train_{global_step}_post_processing_opt"): @@ -1026,6 +1163,8 @@ def run_train_iter( self.model.build_acc(rebuild=True) profilers["build_as"].end() + profilers["step_total"].end() + # Increment the global step global_step += 1 self.global_step = global_step @@ -1067,7 +1206,10 @@ def run_train_pass(self, conf: DictConfig): "inference": CudaTimer(enabled=self.conf.enable_frame_timings), "backward": CudaTimer(enabled=self.conf.enable_frame_timings), "build_as": CudaTimer(enabled=self.conf.enable_frame_timings), + "step_total": CudaTimer(enabled=self.conf.enable_frame_timings), } + if self.feature_decoder is not None: + profilers["feature_decoder"] = CudaTimer(enabled=self.conf.enable_frame_timings) for iter, batch in enumerate(self.train_dataloader): # Check if we have reached the maximum number of iterations @@ -1087,6 +1229,8 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: dictionary of metrics computed and aggregated over validation set. """ + if self.feature_decoder is not None: + self.feature_decoder.apply_ema_shadow() profilers = { "inference": CudaTimer(), } @@ -1106,6 +1250,10 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: with torch.cuda.nvtx.range(f"train.validation_step_{self.global_step}"): profilers["inference"].start() outputs = self.model(gpu_batch, train=False) + # Apply feature decoder to convert N-dimensional features to RGB + if self.feature_decoder is not None: + outputs = apply_feature_decoder(self.feature_decoder, outputs, gpu_batch, training=False) + outputs = apply_background(self.model.background, outputs, gpu_batch, training=False) # Apply post-processing for validation (novel view mode) if self.post_processing is not None: outputs = apply_post_processing(self.post_processing, outputs, gpu_batch, training=False) @@ -1125,6 +1273,8 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: metrics.append(batch_metrics) logger.end_progress(task_name="Validation") + if self.feature_decoder is not None: + self.feature_decoder.restore_ema() metrics = self._flatten_list_of_dicts(metrics) self.log_validation_pass(metrics) diff --git a/threedgrut/utils/gui.py b/threedgrut/utils/gui.py index 8801f3b9..58b6c498 100644 --- a/threedgrut/utils/gui.py +++ b/threedgrut/utils/gui.py @@ -187,7 +187,7 @@ def render_from_current_ps_view(self): self.render_height = window_h return ( - outputs["pred_rgb"], + outputs["pred_features"], outputs["pred_opacity"], outputs["pred_dist"], outputs["pred_normals"], diff --git a/threedgrut/utils/misc.py b/threedgrut/utils/misc.py index 3757a6f6..b2251404 100644 --- a/threedgrut/utils/misc.py +++ b/threedgrut/utils/misc.py @@ -97,14 +97,28 @@ def helper(step): return helper -def skip_scheduler(type=""): +def cosine_scheduler(lr_init, lr_final, max_steps=1000000, type=""): + """Cosine annealing: lr = lr_final + 0.5 * (lr_init - lr_final) * (1 + cos(pi * step / max_steps)).""" + + def helper(step): + t = np.clip(step / max_steps, 0, 1) + return float(lr_final + 0.5 * (lr_init - lr_final) * (1 + np.cos(np.pi * t))) + + return helper + + +def skip_scheduler(type="", **kwargs): def helper(step): return None return helper -SCHEDULER_DICT: dict[str, Callable] = {"exp": exponential_scheduler, "skip": skip_scheduler} +SCHEDULER_DICT: dict[str, Callable] = { + "exp": exponential_scheduler, + "cosine": cosine_scheduler, + "skip": skip_scheduler, +} def get_scheduler(scheduler: str) -> Callable: diff --git a/threedgrut/utils/render.py b/threedgrut/utils/render.py index 57c0f427..8d6a65ef 100644 --- a/threedgrut/utils/render.py +++ b/threedgrut/utils/render.py @@ -14,6 +14,9 @@ # limitations under the License. +import torch +import torch.nn as nn + ## NOTE: SPH code from gaussian-splatting, from plenoctree, from ??? C0 = 0.28209479177387814 C1 = 0.4886025119029199 @@ -48,6 +51,55 @@ def SH2RGB(sh): return sh * C0 + 0.5 +def apply_feature_decoder( + feature_decoder, + outputs: dict, + gpu_batch, + training: bool = False, +) -> dict: + """Apply feature decoder to N-dimensional feature map.""" + if feature_decoder is None: + return outputs + + feature_map = outputs["pred_features"] # [B, H, W, N] alpha-blended features + alpha = outputs["pred_opacity"] # [B, H, W] or [B, H, W, 1] + B, H, W, N = feature_map.shape + + R = gpu_batch.T_to_world[:, :3, :3] # [B, 3, 3] c2w rotation + rays_dir_cam = gpu_batch.rays_dir # [B, H, W, 3] + rays_dir_world = torch.einsum("bij,bhwj->bhwi", R, rays_dir_cam) + rays_dir_world = torch.nn.functional.normalize(rays_dir_world, dim=-1) + + features_flat = feature_map.contiguous().view(-1, N) + ray_dir_flat = rays_dir_world.contiguous().view(-1, 3) + if alpha.dim() == 3: + alpha = alpha.unsqueeze(-1) # [B, H, W, 1] + alpha_flat = alpha.contiguous().view(-1, 1) + + rgb_flat = feature_decoder(features_flat, ray_dir_flat, alpha=alpha_flat) + outputs["pred_features"] = rgb_flat.view(B, H, W, 3) + + if training and hasattr(feature_decoder, "regularization_loss"): + outputs["decoder_reg_loss"] = feature_decoder.regularization_loss() + + return outputs + + +def apply_background(background, outputs: dict, gpu_batch, training: bool = False) -> dict: + """Apply background to decoded RGB (3-channel). Call after apply_feature_decoder when using nht.""" + if background is None or outputs["pred_features"].shape[-1] != 3: + return outputs + pred_features, pred_opacity = background( + gpu_batch.T_to_world.contiguous(), + gpu_batch.rays_dir.contiguous(), + outputs["pred_features"], + outputs["pred_opacity"], + training, + ) + outputs["pred_features"] = pred_features + return outputs + + def apply_post_processing( post_processing, outputs: dict, @@ -58,28 +110,28 @@ def apply_post_processing( Args: post_processing: Post-processing module - outputs: Model outputs including pred_rgb + outputs: Model outputs including pred_features gpu_batch: Batch containing camera_idx, frame_idx, pixel_coords, exposure training: If True, use actual frame_idx; if False, use -1 for novel view mode Returns: - Updated outputs dict with post-processed pred_rgb + Updated outputs dict with post-processed pred_features """ - assert outputs["pred_rgb"].shape[0] == 1, "Post-processing requires batch_size=1" + assert outputs["pred_features"].shape[0] == 1, "Post-processing requires batch_size=1" - pred_rgb = outputs["pred_rgb"] + pred_features = outputs["pred_features"] camera_idx = gpu_batch.camera_idx frame_idx = gpu_batch.frame_idx if training else -1 - H, W = pred_rgb.shape[1], pred_rgb.shape[2] + H, W = pred_features.shape[1], pred_features.shape[2] # Flatten: [1, H, W, 3] -> [H*W, 3] # Ensure contiguous memory for CUDA kernels - pred_rgb_flat = pred_rgb.contiguous().view(-1, 3) + pred_features_flat = pred_features.contiguous().view(-1, 3) pixel_coords_flat = gpu_batch.pixel_coords.contiguous().view(-1, 2) # Apply post-processing - pred_rgb_pp = post_processing( - pred_rgb_flat, + pred_features_pp = post_processing( + pred_features_flat, pixel_coords_flat, resolution=(W, H), camera_idx=camera_idx, @@ -88,5 +140,5 @@ def apply_post_processing( ) # Reshape back: [H*W, 3] -> [1, H, W, 3] - outputs["pred_rgb"] = pred_rgb_pp.view(pred_rgb.shape) + outputs["pred_features"] = pred_features_pp.view(pred_features.shape) return outputs diff --git a/threedgrut/utils/viser_gui_util.py b/threedgrut/utils/viser_gui_util.py index c986aa0e..799f497b 100644 --- a/threedgrut/utils/viser_gui_util.py +++ b/threedgrut/utils/viser_gui_util.py @@ -252,7 +252,7 @@ def render_from_current_view( # points_plane = self.model.positions[u, v] # Return the same outputs as polyscope version return ( - outputs["pred_rgb"], + outputs["pred_features"], outputs["pred_opacity"], outputs["pred_dist"], outputs["pred_normals"], diff --git a/threedgrut_playground/engine.py b/threedgrut_playground/engine.py index 65b8da08..97be5d27 100644 --- a/threedgrut_playground/engine.py +++ b/threedgrut_playground/engine.py @@ -882,7 +882,7 @@ def _render_depth_of_field_buffer(self, rb, camera, rays): if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: dof_rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: @@ -890,7 +890,7 @@ def _render_depth_of_field_buffer(self, rb, camera, rays): else: dof_rb = self._render_playground_hybrid(dof_rays_ori, dof_rays_dir) - rb["rgb"] = self._accumulate_to_buffer(rb["rgb"], dof_rb["pred_rgb"], i, self.gamma_correction) + rb["rgb"] = self._accumulate_to_buffer(rb["rgb"], dof_rb["pred_features"], i, self.gamma_correction) rb["opacity"] = (rb["opacity"] * i + dof_rb["pred_opacity"]) / (i + 1) def _render_spp_buffer(self, rb, rays): @@ -902,14 +902,14 @@ def _render_spp_buffer(self, rb, rays): if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: spp_rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: spp_rb = self.scene_mog.trace(rays_o=rays.rays_ori, rays_d=rays.rays_dir) else: spp_rb = self._render_playground_hybrid(rays.rays_ori, rays.rays_dir) - batch_rgb = spp_rb["pred_rgb"].sum(dim=0).unsqueeze(0) + batch_rgb = spp_rb["pred_features"].sum(dim=0).unsqueeze(0) rb["rgb"] = self._accumulate_to_buffer( rb["rgb"], batch_rgb, i, self.gamma_correction, batch_size=self.spp.batch_size ) @@ -928,7 +928,7 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) Returns: Dict[str, torch.Tensor]: Rendering results containing: - - 'pred_rgb': RGB colors of shape (B, H, W, 3), range [0, 1] + - 'pred_features': RGB colors of shape (B, H, W, 3), range [0, 1] - 'pred_opacity': Opacity values of shape (B, H, W, 1), range [0, 1] - 'last_ray_d': Final ray directions for background computation - Additional buffers from tracer.render_playground() @@ -979,14 +979,14 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) max_pbr_bounces=self.max_pbr_bounces, ) - pred_rgb = rendered_results["pred_rgb"] + pred_features = rendered_results["pred_features"] pred_opacity = rendered_results["pred_opacity"] # If no envmap is used for background, saturate the color channels by blending the mog background if envmap is None or not self.environment.is_ignore_envmap(): poses = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=torch.float32) - pred_rgb, pred_opacity = mog.background( - poses.contiguous(), rendered_results["last_ray_d"].contiguous(), pred_rgb, pred_opacity, False + pred_features, pred_opacity = mog.background( + poses.contiguous(), rendered_results["last_ray_d"].contiguous(), pred_features, pred_opacity, False ) # Mark materials as uploaded @@ -995,9 +995,9 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) # Advance frame id (for i.e., random number generator) and avoid int32 overflow self.frame_id = self.frame_id + self.spp.batch_size if self.frame_id <= (2**31 - 1) else 0 - pred_rgb = torch.clamp(pred_rgb, 0.0, 1.0) # Make sure image pixels are in valid range + pred_features = torch.clamp(pred_features, 0.0, 1.0) # Make sure image pixels are in valid range - rendered_results["pred_rgb"] = pred_rgb + rendered_results["pred_features"] = pred_features return rendered_results @torch.cuda.nvtx.range("render_pass") @@ -1065,7 +1065,7 @@ def render_pass(self, camera: Camera, is_first_pass: bool) -> Dict[str, torch.Te if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: @@ -1073,7 +1073,7 @@ def render_pass(self, camera: Camera, is_first_pass: bool) -> Dict[str, torch.Te else: rb = self._render_playground_hybrid(rays.rays_ori, rays.rays_dir) - rb = dict(rgb=rb["pred_rgb"], opacity=rb["pred_opacity"]) + rb = dict(rgb=rb["pred_features"], opacity=rb["pred_opacity"]) rb["rgb"] = self.environment.tonemap(rb["rgb"]) rb["rgb"] = torch.pow(rb["rgb"], 1.0 / self.gamma_correction) rb["rgb"] = rb["rgb"].mean(dim=0).unsqueeze(0) diff --git a/threedgrut_playground/tracer.py b/threedgrut_playground/tracer.py index 77e03799..566ca93d 100644 --- a/threedgrut_playground/tracer.py +++ b/threedgrut_playground/tracer.py @@ -115,7 +115,7 @@ def render(self, gaussians, gpu_batch, train=False, frame_id=0): mog_scl = gaussians.get_scale().contiguous() particle_density = torch.concat([mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1) - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace( + pred_features, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace( frame_id, gpu_batch["poses"].contiguous(), gpu_batch["rays_o_cam"].contiguous(), @@ -127,12 +127,12 @@ def render(self, gaussians, gpu_batch, train=False, frame_id=0): ) # NOTE: disable background - pred_rgb, pred_opacity = gaussians.background( - gpu_batch["poses"].contiguous(), gpu_batch["rays_d_cam"].contiguous(), pred_rgb, pred_opacity, train + pred_features, pred_opacity = gaussians.background( + gpu_batch["poses"].contiguous(), gpu_batch["rays_d_cam"].contiguous(), pred_features, pred_opacity, train ) return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": torch.nn.functional.normalize(pred_normals, dim=3), @@ -224,7 +224,7 @@ def render_playground( min_transmittance = self.conf.render.min_transmittance envmap_offset = envmap_offset.contiguous() - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace_hybrid( + pred_features, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace_hybrid( frame_id, poses, ray_o, @@ -253,7 +253,7 @@ def render_playground( pred_dist = pred_dist[:, :, :, 0:1] # return only the hit distance return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": torch.nn.functional.normalize(pred_normals, dim=3), diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh index 8fca0f5e..bdc91a74 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh @@ -107,6 +107,20 @@ __device__ __inline__ RayPayloadT initializeRay(const threedgut::RenderParameter return ray; } +// Output element type for the feature+density buffer. +// FEATURE_OUTPUT_HALF=1: write __half (fp16) to halve memory bandwidth. +// FEATURE_OUTPUT_HALF=0: write float (fp32, default). +#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif + +#if FEATURE_OUTPUT_HALF +#include +using TFeatureDensityElem = __half; +#else +using TFeatureDensityElem = float; +#endif + // Initialize ray based on given pixel coordinates (load-balanced mode) template __device__ __inline__ RayPayloadT initializeRayPerPixel(const threedgut::RenderParameters& params, @@ -149,13 +163,27 @@ __device__ __inline__ void finalizeRay(const TRayPayload& ray, const tcnn::vec3* __restrict__ sensorRayOriginPtr, float* __restrict__ worldCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::mat4x3& sensorToWorldTransform) { if (!ray.isValid()) { return; } - radianceDensityPtr[ray.idx] = {ray.features[0], ray.features[1], ray.features[2], (1.0f - ray.transmittance)}; + static_assert(RAY_FEATURE_DIM == TRayPayload::FeatDim, "RAY_FEATURE_DIM must equal TRayPayload::FeatDim"); + const uint32_t base = ray.idx * (RAY_FEATURE_DIM + 1); +#if FEATURE_OUTPUT_HALF + #pragma unroll + for (int i = 0; i < TRayPayload::FeatDim; ++i) { + featureDensityPtr[base + i] = __float2half(ray.features[i]); + } + featureDensityPtr[base + RAY_FEATURE_DIM] = __float2half(1.0f - ray.transmittance); +#else + #pragma unroll + for (int i = 0; i < TRayPayload::FeatDim; ++i) { + featureDensityPtr[base + i] = ray.features[i]; + } + featureDensityPtr[base + RAY_FEATURE_DIM] = (1.0f - ray.transmittance); +#endif worldHitDistancePtr[ray.idx] = ray.hitT; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh index 67c44db4..ff37f159 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh @@ -33,8 +33,8 @@ __device__ __inline__ RayPayloadT initializeBackwardRay(const threedgut::RenderP const tcnn::vec3* __restrict__ sensorRayDirectionPtr, const float* __restrict__ worldHitDistancePtr, const float* __restrict__ worldHitDistanceGradientPtr, - const tcnn::vec* __restrict__ featuresDensityPtr, - const tcnn::vec* __restrict__ featuresDensityGradientPtr, + const TFeatureDensityElem* __restrict__ featuresDensityPtr, + const float* __restrict__ featuresDensityGradientPtr, const tcnn::mat4x3& sensorToWorldTransform) { // NB : no backpropagation through the forward ray initialization / finalization @@ -44,14 +44,29 @@ __device__ __inline__ RayPayloadT initializeBackwardRay(const threedgut::RenderP sensorToWorldTransform); if (ray.isAlive()) { - const tcnn::vec featuresDensity = featuresDensityPtr[ray.idx]; - const tcnn::vec featuresDensityGradient = featuresDensityGradientPtr[ray.idx]; - ray.transmittanceBackward = 1.f - featuresDensity[RayPayloadT::FeatDim]; - ray.transmittanceGradient = -1.f * featuresDensityGradient[RayPayloadT::FeatDim]; - ray.hitTBackward = worldHitDistancePtr[ray.idx]; - ray.hitTGradient = worldHitDistanceGradientPtr[ray.idx]; - ray.featuresBackward = threedgut::sliceVec<0, RayPayloadT::FeatDim>(featuresDensity); - ray.featuresGradient = threedgut::sliceVec<0, RayPayloadT::FeatDim>(featuresDensityGradient); + constexpr uint32_t stride = RayPayloadT::FeatDim + 1; + const uint32_t base = ray.idx * stride; + // Forward features: fp16 when FEATURE_OUTPUT_HALF=1, fp32 otherwise. + // Gradient buffer: always fp32 — keeps backward numerically stable regardless of forward dtype. +#if FEATURE_OUTPUT_HALF + #pragma unroll + for (int i = 0; i < RayPayloadT::FeatDim; ++i) { + ray.featuresBackward[i] = __half2float(featuresDensityPtr[base + i]); + ray.featuresGradient[i] = featuresDensityGradientPtr[base + i]; + } + ray.transmittanceBackward = 1.f - __half2float(featuresDensityPtr[base + RayPayloadT::FeatDim]); + ray.transmittanceGradient = -1.f * featuresDensityGradientPtr[base + RayPayloadT::FeatDim]; +#else + #pragma unroll + for (int i = 0; i < RayPayloadT::FeatDim; ++i) { + ray.featuresBackward[i] = featuresDensityPtr[base + i]; + ray.featuresGradient[i] = featuresDensityGradientPtr[base + i]; + } + ray.transmittanceBackward = 1.f - featuresDensityPtr[base + RayPayloadT::FeatDim]; + ray.transmittanceGradient = -1.f * featuresDensityGradientPtr[base + RayPayloadT::FeatDim]; +#endif + ray.hitTBackward = worldHitDistancePtr[ray.idx]; + ray.hitTGradient = worldHitDistanceGradientPtr[ray.idx]; } return ray; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh index d25c461d..fda4e956 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh @@ -17,6 +17,9 @@ #include <3dgut/kernels/cuda/models/gaussianParticles.cuh> #include <3dgut/renderer/renderParameters.h> +#if PARTICLE_FEATURE_HALF +#include +#endif template struct ShRadiativeGaussianParticlesBuffer { TBuffer* ptr = nullptr; @@ -65,7 +68,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams __forceinline__ __device__ DensityParameters fetchDensityParameters(uint32_t particleIdx) const { const auto parameters = particleDensityParameters( particleIdx, - {{reinterpret_cast(m_densityRawParameters.ptr), nullptr, true}}); + {reinterpret_cast(m_densityRawParameters.ptr), nullptr, false}); return *reinterpret_cast(¶meters); } @@ -95,13 +98,14 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams const DensityParameters& parameters, float& alpha, float& depth, + float3& canonicalIntersection, tcnn::vec3* normal = nullptr) const { - return particleDensityHit(*reinterpret_cast(&rayOrigin), *reinterpret_cast(&rayDirection), reinterpret_cast(parameters), &alpha, &depth, + &canonicalIntersection, normal != nullptr, reinterpret_cast(normal)); } @@ -127,12 +131,14 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams float& transmittance, float& integratedDepth, tcnn::vec3* integratedNormal = nullptr) const { + float3 unusedCanonicalIntersection = make_float3(0.f, 0.f, 0.f); return particleDensityProcessHitFwdFromBuffer(*reinterpret_cast(&rayOrigin), *reinterpret_cast(&rayDirection), particleIdx, {{reinterpret_cast(m_densityRawParameters.ptr), nullptr, true}}, &transmittance, &integratedDepth, + &unusedCanonicalIntersection, integratedNormal != nullptr, reinterpret_cast(integratedNormal)); } @@ -148,6 +154,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams float depth, float& integratedDepth, float& integratedDepthGrad, + const float3& canonicalIntersectionGrad, const tcnn::vec3* normal = nullptr, tcnn::vec3* integratedNormal = nullptr, tcnn::vec3* integratedNormalGrad = nullptr @@ -167,6 +174,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams depth, &integratedDepth, &integratedDepthGrad, + canonicalIntersectionGrad, normal != nullptr, normal == nullptr ? make_float3(0, 0, 0) : *reinterpret_cast(normal), reinterpret_cast(integratedNormal), @@ -232,51 +240,49 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams } } - using FeaturesParameters = shRadiativeParticle_Parameters_0; - using TFeaturesVec = typename tcnn::vec; + using TFeaturesVec = typename tcnn::vec; +#if PARTICLE_FEATURE_HALF + using TFeatureRawParamPtr = __half; +#else + using TFeatureRawParamPtr = float; +#endif inline __device__ void initializeFeatures(threedgut::MemoryHandles parameters) { - static_assert(ExtParams::FeaturesDim == 3, "Hardcoded 3-dimensional radiance because of Slang-Cuda interop"); - m_featureRawParameters.ptr = parameters.bufferPtr(Params::FeaturesRawParametersBufferIndex); + m_featureRawParameters.ptr = parameters.bufferPtr(Params::FeaturesRawParametersBufferIndex); m_featureActiveShDegree = *reinterpret_cast(parameters.bufferPtr(Params::GlobalParametersValueBufferIndex) + Params::FeatureShDegreeValueOffset); }; inline __device__ void initializeFeaturesGradient(threedgut::MemoryHandles parametersGradient) { if constexpr (TDifferentiable) { - m_featureRawParameters.gradPtr = parametersGradient.bufferPtr(Params::FeaturesRawParametersGradientBufferIndex); + m_featureRawParameters.gradPtr = parametersGradient.bufferPtr(Params::FeaturesRawParametersGradientBufferIndex); } }; __forceinline__ __device__ TFeaturesVec featuresFromBuffer(uint32_t particleIdx, - const tcnn::vec3& incidentDirection) const { - - const auto features = particleFeaturesFromBuffer(particleIdx, - {{m_featureRawParameters.ptr, nullptr, true}, m_featureActiveShDegree}, - *reinterpret_cast(&incidentDirection)); - return *reinterpret_cast(&features); + const tcnn::vec3& incidentDirection, + const float3& canonicalPosition) const { + TFeaturesVec result; + particleFeaturesFromBuffer(particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + m_featureActiveShDegree, + *reinterpret_cast(&incidentDirection), + canonicalPosition, + reinterpret_cast*>(&result)); + return result; } template __forceinline__ __device__ TFeaturesVec featuresCustomFromBuffer(uint32_t particleIdx, const tcnn::vec3& incidentDirection) const { - const float3 gradu = threedgut::radianceFromSpH(m_featureActiveShDegree, - reinterpret_cast(&m_featureRawParameters.ptr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients]), - *reinterpret_cast(&incidentDirection), - Clamped); - return *reinterpret_cast(&gradu); - } - template - __forceinline__ __device__ void featuresBwdToBuffer(uint32_t particleIdx, - const TFeaturesVec& featuresGrad, - const tcnn::vec3& incidentDirection, - tcnn::vec3& incidentDirectionGrad) const { - if constexpr (TDifferentiable) { - particleFeaturesBwdToBuffer(particleIdx, - {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, - *reinterpret_cast(&featuresGrad), - *reinterpret_cast(&incidentDirection), - reinterpret_cast(&incidentDirectionGrad)); + if constexpr (ExtParams::PerRayParticleFeatures) { + return TFeaturesVec::zero(); + } else { + const float3 gradu = threedgut::radianceFromSpH(m_featureActiveShDegree, + reinterpret_cast(&m_featureRawParameters.ptr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients * 3]), + *reinterpret_cast(&incidentDirection), + Clamped); + return *reinterpret_cast(&gradu); } } @@ -285,31 +291,50 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams const TFeaturesVec& features, const TFeaturesVec& featuresGrad, const tcnn::vec3& incidentDirection) const { - threedgut::radianceFromSpHBwd(m_featureActiveShDegree, - *reinterpret_cast(&incidentDirection), - *reinterpret_cast(&featuresGrad), - reinterpret_cast(&m_featureRawParameters.gradPtr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients]), - *reinterpret_cast(&features)); + if constexpr (!ExtParams::PerRayParticleFeatures) { + threedgut::radianceFromSpHBwd(m_featureActiveShDegree, + *reinterpret_cast(&incidentDirection), + *reinterpret_cast(&featuresGrad), + reinterpret_cast(&m_featureRawParameters.gradPtr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients * 3]), + *reinterpret_cast(&features)); + } + } + + template + __forceinline__ __device__ void featuresBwdToBuffer(uint32_t particleIdx, + const TFeaturesVec& featuresGrad, + const tcnn::vec3& incidentDirection, + tcnn::vec3& incidentDirectionGrad) const { + if constexpr (TDifferentiable) { + particleFeaturesBwdToBuffer(particleIdx, + reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)), + m_featureRawParameters.gradPtr, + m_featureActiveShDegree, + exclusiveGradient, + *reinterpret_cast*>(&featuresGrad), + *reinterpret_cast(&incidentDirection), + reinterpret_cast(&incidentDirectionGrad)); + } } __forceinline__ __device__ void featureIntegrateFwd(float weight, const TFeaturesVec& features, TFeaturesVec& integratedFeatures) const { - particleFeaturesIntegrateFwd(weight, - *reinterpret_cast(&features), - reinterpret_cast(&integratedFeatures)); + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures)); } __forceinline__ __device__ void featuresIntegrateFwdFromBuffer(const tcnn::vec3& incidentDirection, float weight, - uint32_t particleIdx, TFeaturesVec integratedFeatures) const { - + uint32_t particleIdx, TFeaturesVec& integratedFeatures) const { particleFeaturesIntegrateFwdFromBuffer(*reinterpret_cast(&incidentDirection), - weight, - particleIdx, - {{m_featureRawParameters.ptr, nullptr, true}, m_featureActiveShDegree}, - reinterpret_cast(&integratedFeatures)); + make_float3(0.f, 0.f, 0.f), + weight, + particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + m_featureActiveShDegree, + reinterpret_cast*>(&integratedFeatures)); } __forceinline__ __device__ void featuresIntegrateBwd(float alpha, @@ -321,15 +346,17 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams if (TDifferentiable) { particleFeaturesIntegrateBwd(alpha, &alphaGrad, - *reinterpret_cast(&features), - reinterpret_cast(&featuresGrad), - reinterpret_cast(&integratedFeatures), - reinterpret_cast(&integratedFeaturesGrad)); + *reinterpret_cast*>(&features), + reinterpret_cast*>(&featuresGrad), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); } } template __forceinline__ __device__ void featuresIntegrateBwdToBuffer(const tcnn::vec3& incidentDirection, + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, float alpha, float& alphaGrad, uint32_t particleIdx, @@ -339,13 +366,79 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams if (TDifferentiable) { particleFeaturesIntegrateBwdToBuffer(*reinterpret_cast(&incidentDirection), + canonicalIntersection, + &canonicalIntersectionGrad, alpha, &alphaGrad, particleIdx, - {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, - *reinterpret_cast(&features), - reinterpret_cast(&integratedFeatures), - reinterpret_cast(&integratedFeaturesGrad)); + reinterpret_cast(m_featureRawParameters.ptr), + m_featureRawParameters.gradPtr, + m_featureActiveShDegree, + exclusiveGradient, + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); + } + } + + // NHT warp reduction step 1: compute feature grad into a thread-private local buffer (no atomics). + // featureLocalGrad must be zero-initialized (size ExtParams::ParticleFeatureDim) before calling. + // Follow with featureLocalGradWarpReduceAndWrite (called by ALL warp threads) to write to global buffer. + __forceinline__ __device__ void featuresIntegrateBwdToLocalGrad(const tcnn::vec3& incidentDirection, + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, + float alpha, + float& alphaGrad, + uint32_t particleIdx, + const TFeaturesVec& features, + TFeaturesVec& integratedFeatures, + TFeaturesVec& integratedFeaturesGrad, + float* featureLocalGrad) const { + if constexpr (TDifferentiable) { + // Pointer offset trick: shift featureLocalGrad back by particleOffset so Slang's + // _gradPtr[interpPointOffset + i] writes land in featureLocalGrad[0..ParticleFeatureDim-1]. + // interpPointOffset = particleOffset + interpPointIdx*InterpPointFeatureDim, so: + // (featureLocalGrad - particleOffset)[interpPointOffset + i] + // = featureLocalGrad[interpPointIdx*InterpPointFeatureDim + i] + const uint32_t particleOffset = particleIdx * ExtParams::ParticleFeatureDim; + particleFeaturesIntegrateBwdToBuffer( + *reinterpret_cast(&incidentDirection), + canonicalIntersection, + &canonicalIntersectionGrad, + alpha, + &alphaGrad, + particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + featureLocalGrad - particleOffset, // shifted: writes to featureLocalGrad[0..ParticleFeatureDim-1] + m_featureActiveShDegree, + true, // exclusiveGradient=true → += without atomics + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); + } + } + + // NHT warp reduction step 2: warp-reduce featureLocalGrad and atomicAdd to global gradient buffer. + // MUST be called by ALL threads in the warp (including non-hitting threads with featureLocalGrad=0) + // to satisfy __shfl_xor_sync requirements. + __forceinline__ __device__ void featureLocalGradWarpReduceAndWrite(uint32_t particleIdx, + float* featureLocalGrad, + uint32_t tileThreadIdx) const { + if constexpr (TDifferentiable) { +#pragma unroll + for (int mask = 1; mask < warpSize; mask *= 2) { +#pragma unroll + for (int i = 0; i < ExtParams::ParticleFeatureDim; i++) { + featureLocalGrad[i] += __shfl_xor_sync(0xffffffff, featureLocalGrad[i], mask); + } + } + if ((tileThreadIdx & (warpSize - 1)) == 0) { + const uint32_t particleOffset = particleIdx * ExtParams::ParticleFeatureDim; +#pragma unroll + for (int i = 0; i < ExtParams::ParticleFeatureDim; i++) { + atomicAdd(&m_featureRawParameters.gradPtr[particleOffset + i], featureLocalGrad[i]); + } + } } } @@ -362,7 +455,8 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams reinterpret_cast(rayDirection), particleIdx, m_densityRawParameters.ptr, - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.ptr) : reinterpret_cast(particleFeaturesPtr), + PerRayRadiance ? reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)) + : reinterpret_cast(particleFeaturesPtr), ExtParams::MinParticleKernelDensity, ExtParams::AlphaThreshold, m_featureActiveShDegree, @@ -396,8 +490,9 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams particleIdx, reinterpret_cast(densityRawParameters), reinterpret_cast(densityRawParametersGrad), - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.ptr) : reinterpret_cast(particleFeatures.data()), - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.gradPtr) : reinterpret_cast(particleFeaturesGradPtr), + PerRayRadiance ? reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)) + : reinterpret_cast(particleFeatures.data()), + PerRayRadiance ? m_featureRawParameters.gradPtr : reinterpret_cast(particleFeaturesGradPtr), ExtParams::MinParticleKernelDensity, ExtParams::AlphaThreshold, ExtParams::MinTransmittanceThreshold, @@ -420,7 +515,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams #pragma unroll for (int mask = 1; mask < warpSize; mask *= 2) { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; ++i) { + for (int i = 0; i < ExtParams::RayFeatureDim; ++i) { featuresGrad[i] += __shfl_xor_sync(0xffffffff, featuresGrad[i], mask); } } @@ -428,13 +523,13 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams // First thread in the warp performs the atomic add if ((tileThreadIdx & (warpSize - 1)) == 0) { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; i++) { + for (int i = 0; i < ExtParams::RayFeatureDim; i++) { atomicAdd(&featuresGradSum[particleIdx][i], featuresGrad[i]); } } } else { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; ++i) { + for (int i = 0; i < ExtParams::RayFeatureDim; ++i) { atomicAdd(&featuresGradSum[particleIdx][i], featuresGrad[i]); } } @@ -493,5 +588,5 @@ private: m_densityRawParameters; int m_featureActiveShDegree = 0; - ShRadiativeGaussianParticlesBuffer m_featureRawParameters; + ShRadiativeGaussianParticlesBuffer m_featureRawParameters; // gradPtr always fp32; ptr cast to TFeatureRawParamPtr for reads }; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index bf73ec9b..5a432d9a 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -23,6 +23,7 @@ struct HitParticle { int idx = -1; float hitT = InvalidHitT; float alpha = 0.0f; + float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); }; template @@ -91,8 +92,8 @@ struct GUTKBufferRenderer : Params { using DensityRawParameters = typename Particles::DensityRawParameters; using TFeaturesVec = typename Particles::TFeaturesVec; - using TRayPayload = RayPayload; - using TRayPayloadBackward = RayPayloadBackward; + using TRayPayload = RayPayload; + using TRayPayloadBackward = RayPayloadBackward; struct PrefetchedParticleData { uint32_t idx; @@ -115,12 +116,15 @@ struct GUTKBufferRenderer : Params { if constexpr (Backward) { float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); if constexpr (Params::PerRayParticleFeatures) { particles.featuresIntegrateBwdToBuffer(ray.direction, + hitParticle.canonicalIntersection, + canonicalIntersectionGrad, hitParticle.alpha, hitAlphaGrad, hitParticle.idx, - particles.featuresFromBuffer(hitParticle.idx, ray.direction), + particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection), ray.featuresBackward, ray.featuresGradient); } else { @@ -132,7 +136,7 @@ struct GUTKBufferRenderer : Params { ray.featuresBackward, ray.featuresGradient); #pragma unroll - for (int i = 0; i < Particles::FeaturesDim; ++i) { + for (int i = 0; i < Particles::RayFeatureDim; ++i) { atomicAdd(&(particleFeaturesGradient[hitParticle.idx][i]), particleFeaturesGradientVec[i]); } } @@ -146,7 +150,8 @@ struct GUTKBufferRenderer : Params { ray.transmittanceGradient, hitParticle.hitT, ray.hitTBackward, - ray.hitTGradient); + ray.hitTGradient, + canonicalIntersectionGrad); ray.transmittance *= (1.0 - hitParticle.alpha); @@ -158,11 +163,10 @@ struct GUTKBufferRenderer : Params { ray.hitT); particles.featureIntegrateFwd(hitWeight, - Params::PerRayParticleFeatures ? particles.featuresFromBuffer(hitParticle.idx, ray.direction) : tcnn::max(particleFeatures[hitParticle.idx], 0.f), + Params::PerRayParticleFeatures ? particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection) : tcnn::max(particleFeatures[hitParticle.idx], 0.f), ray.features); - if (hitWeight > 0.0f) - ray.countHit(); + if (hitWeight > 0.0f) ray.countHit(); } if (ray.transmittance < Particles::MinTransmittanceThreshold) { @@ -265,7 +269,8 @@ struct GUTKBufferRenderer : Params { ray.direction, particleData.densityParameters, hitParticle.alpha, - hitParticle.hitT) && + hitParticle.hitT, + hitParticle.canonicalIntersection) && (hitParticle.hitT > ray.tMinMax.x) && (hitParticle.hitT < ray.tMinMax.y)) { @@ -349,24 +354,26 @@ struct GUTKBufferRenderer : Params { if (!ray.isAlive()) break; - float hitAlpha = 0.0f; - float hitT = 0.0f; + float hitAlpha = 0.0f; + float3 hitCanonicalIntersection = make_float3(0.f, 0.f, 0.f); + float hitT = 0.0f; TFeaturesVec hitFeatures = TFeaturesVec::zero(); - bool validHit = false; + bool validHit = false; // Step 1: Each thread tests one Gaussian intersection if (j < tileNumParticlesToProcess) { const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + j; - const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; if (particleIdx != GUTParameters::InvalidParticleIdx) { auto densityParams = particles.fetchDensityParameters(particleIdx); if (particles.densityHit(ray.origin, - ray.direction, - densityParams, - hitAlpha, - hitT) && + ray.direction, + densityParams, + hitAlpha, + hitT, + hitCanonicalIntersection) && (hitT > ray.tMinMax.x) && (hitT < ray.tMinMax.y)) { @@ -374,7 +381,7 @@ struct GUTKBufferRenderer : Params { // Get Gaussian features if constexpr (Params::PerRayParticleFeatures) { - hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction); + hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction, hitCanonicalIntersection); } else { hitFeatures = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); } @@ -430,15 +437,15 @@ struct GUTKBufferRenderer : Params { float hitWeight = hitAlpha * particleTransmittance; // Compute weighted contributions - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < Particles::RayFeatureDim; ++featIdx) { accumulatedFeatures[featIdx] = hitFeatures[featIdx] * hitWeight; } - accumulatedHitT = hitT * hitWeight; + accumulatedHitT = hitT * hitWeight; accumulatedHitCount = (hitWeight > 0.0f) ? 1 : 0; } // Step 6: Warp-level reduction (tree-based sum) - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < Particles::RayFeatureDim; ++featIdx) { for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) { accumulatedFeatures[featIdx] += __shfl_down_sync(WarpMask, accumulatedFeatures[featIdx], offset); } @@ -451,7 +458,7 @@ struct GUTKBufferRenderer : Params { // Step 7: Only lane 0 updates ray state (avoid race conditions) if (laneId == 0) { - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < Particles::RayFeatureDim; ++featIdx) { ray.features[featIdx] += accumulatedFeatures[featIdx]; } ray.hitT += accumulatedHitT; @@ -481,83 +488,175 @@ struct GUTKBufferRenderer : Params { static_assert(Backward && (Params::KHitBufferSize == 0), "Optimized path for backward pass with no KBuffer"); using namespace threedgut; - __shared__ PrefetchedRawParticleData prefetchedRawParticlesData[GUTParameters::Tiling::BlockSize]; - for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { + if constexpr (Params::PerRayParticleFeatures) { + // NHT path: features are re-evaluated per-ray from buffer (no pre-computed feature cache). + // Gradients accumulate into a thread-private local buffer, then warp-reduced before a single + // atomicAdd per feature dim — reduces global memory traffic by up to 32× vs. per-hit atomics. + __shared__ PrefetchedParticleData prefetchedParticlesData[GUTParameters::Tiling::BlockSize]; - if (__syncthreads_and(!ray.isAlive())) { - break; - } + for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { - // Collectively fetch particle data - const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; - if (toProcessSortedIndex < tileParticleRangeIndices.y) { - const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; - if (particleIdx != GUTParameters::InvalidParticleIdx) { - prefetchedRawParticlesData[tileThreadIdx].densityParameters = particles.fetchDensityRawParameters(particleIdx); - if constexpr (Params::PerRayParticleFeatures) { - prefetchedRawParticlesData[tileThreadIdx].features = TFeaturesVec::zero(); + if (__syncthreads_and(!ray.isAlive())) { + break; + } + + // Collectively fetch density parameters only (features fetched per-ray below) + const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; + if (toProcessSortedIndex < tileParticleRangeIndices.y) { + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + if (particleIdx != GUTParameters::InvalidParticleIdx) { + prefetchedParticlesData[tileThreadIdx] = {particleIdx, particles.fetchDensityParameters(particleIdx)}; } else { - prefetchedRawParticlesData[tileThreadIdx].features = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); + prefetchedParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } - prefetchedRawParticlesData[tileThreadIdx].idx = particleIdx; } else { - prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + prefetchedParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + } + __syncthreads(); + + // Process fetched particles + for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, tileNumParticlesToProcess); j++) { + + if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + break; + } + + const PrefetchedParticleData particleData = prefetchedParticlesData[j]; + if (particleData.idx == GUTParameters::InvalidParticleIdx) { + ray.kill(); + break; + } + + // Thread-private feature gradient buffer for warp reduction. + // Zero-initialized: non-hitting threads contribute 0 to the warp sum. + float featureLocalGrad[Particles::ParticleFeatureDim] = {}; + + if (ray.isAlive()) { + float hitAlpha = 0.f; + float hitT = 0.f; + float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); + + if (particles.densityHit(ray.origin, ray.direction, particleData.densityParameters, + hitAlpha, hitT, canonicalIntersection)) { + // Re-evaluate NHT features at canonical intersection point (cheap: barycentric interp) + const TFeaturesVec hitFeatures = particles.featuresFromBuffer(particleData.idx, ray.direction, canonicalIntersection); + + float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); + + // Write feature grad to thread-private local buffer (no atomics); warp reduction follows below. + particles.featuresIntegrateBwdToLocalGrad(ray.direction, + canonicalIntersection, + canonicalIntersectionGrad, + hitAlpha, + hitAlphaGrad, + particleData.idx, + hitFeatures, + ray.featuresBackward, + ray.featuresGradient, + featureLocalGrad); + + particles.template densityProcessHitBwdToBuffer(ray.origin, + ray.direction, + particleData.idx, + hitAlpha, + hitAlphaGrad, + ray.transmittanceBackward, + ray.transmittanceGradient, + hitT, + ray.hitTBackward, + ray.hitTGradient, + canonicalIntersectionGrad); + + ray.transmittance *= (1.0f - hitAlpha); + } + + if (ray.transmittance < Particles::MinTransmittanceThreshold) { + ray.kill(); + } + } + + // Warp reduction: all 32 threads (alive or not) participate in __shfl_xor_sync. + // Non-hitting threads contribute featureLocalGrad=0 → no spurious gradient. + // Reduces 32×ParticleFeatureDim atomics per particle to ParticleFeatureDim atomics. + particles.featureLocalGradWarpReduceAndWrite(particleData.idx, featureLocalGrad, tileThreadIdx); } - } else { - prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } - __syncthreads(); + } else { + // SH path: pre-computed features fetched from shared memory, gradients via precomputed buffer. + __shared__ PrefetchedRawParticleData prefetchedRawParticlesData[GUTParameters::Tiling::BlockSize]; - // Process fetched particles - for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, tileNumParticlesToProcess); j++) { + for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { - if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + if (__syncthreads_and(!ray.isAlive())) { break; } - const PrefetchedRawParticleData particleData = prefetchedRawParticlesData[j]; - if (particleData.idx == GUTParameters::InvalidParticleIdx) { - ray.kill(); - break; + // Collectively fetch particle data + const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; + if (toProcessSortedIndex < tileParticleRangeIndices.y) { + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + if (particleIdx != GUTParameters::InvalidParticleIdx) { + prefetchedRawParticlesData[tileThreadIdx].densityParameters = particles.fetchDensityRawParameters(particleIdx); + prefetchedRawParticlesData[tileThreadIdx].features = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); + prefetchedRawParticlesData[tileThreadIdx].idx = particleIdx; + } else { + prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + } + } else { + prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } + __syncthreads(); + + // Process fetched particles + for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, tileNumParticlesToProcess); j++) { - DensityRawParameters densityRawParametersGrad; - densityRawParametersGrad.density = 0.0f; - densityRawParametersGrad.position = make_float3(0.0f); - densityRawParametersGrad.quaternion = make_float4(0.0f); - densityRawParametersGrad.scale = make_float3(0.0f); - - TFeaturesVec featuresGrad = TFeaturesVec::zero(); - - if (ray.isAlive()) { - particles.processHitBwd( - ray.origin, - ray.direction, - particleData.idx, - particleData.densityParameters, - &densityRawParametersGrad, - particleData.features, - &featuresGrad, - ray.transmittance, - ray.transmittanceBackward, - ray.transmittanceGradient, - ray.features, - ray.featuresBackward, - ray.featuresGradient, - ray.hitT, - ray.hitTBackward, - ray.hitTGradient); - if (ray.transmittance < Particles::MinTransmittanceThreshold) { + if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + break; + } + + const PrefetchedRawParticleData particleData = prefetchedRawParticlesData[j]; + if (particleData.idx == GUTParameters::InvalidParticleIdx) { ray.kill(); + break; + } + + DensityRawParameters densityRawParametersGrad; + densityRawParametersGrad.density = 0.0f; + densityRawParametersGrad.position = make_float3(0.0f); + densityRawParametersGrad.quaternion = make_float4(0.0f); + densityRawParametersGrad.scale = make_float3(0.0f); + + TFeaturesVec featuresGrad = TFeaturesVec::zero(); + + if (ray.isAlive()) { + particles.processHitBwd( + ray.origin, + ray.direction, + particleData.idx, + particleData.densityParameters, + &densityRawParametersGrad, + particleData.features, + &featuresGrad, + ray.transmittance, + ray.transmittanceBackward, + ray.transmittanceGradient, + ray.features, + ray.featuresBackward, + ray.featuresGradient, + ray.hitT, + ray.hitTBackward, + ray.hitTGradient); + if (ray.transmittance < Particles::MinTransmittanceThreshold) { + ray.kill(); + } } - } - if constexpr (!Params::PerRayParticleFeatures) { particles.processHitBwdUpdateFeaturesGradient(particleData.idx, featuresGrad, particleFeaturesGradientBuffer, tileThreadIdx); + particles.processHitBwdUpdateDensityGradient(particleData.idx, densityRawParametersGrad, tileThreadIdx); } - particles.processHitBwdUpdateDensityGradient(particleData.idx, densityRawParametersGrad, tileThreadIdx); } } } diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh index 6ba66605..f335ef8c 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh @@ -402,31 +402,30 @@ struct GUTProjector : Params, UTParams { const float* __restrict__ particlesPrecomputedFeaturesPtr, const float* __restrict__ particlesPrecomputedFeaturesGradPtr, threedgut::MemoryHandles parametersGradient) { - if constexpr (Params::PerRayParticleFeatures) { - return; - } + if constexpr (!Params::PerRayParticleFeatures) { - const uint32_t particleIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (particleIdx >= numParticles) { - return; - } - if (particlesTilesCountPtr[particleIdx] == 0) { - return; - } + const uint32_t particleIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (particleIdx >= numParticles) { + return; + } + if (particlesTilesCountPtr[particleIdx] == 0) { + return; + } - Particles particles; - particles.initializeDensity(parameters); - particles.initializeDensityGradient(parametersGradient); - const tcnn::vec3 incidentDirection = tcnn::normalize(particles.fetchPosition(particleIdx) - sensorWorldPosition); - tcnn::vec3 incidentDirectionGrad = tcnn::vec3(0.0f); - - particles.initializeFeatures(parameters); - particles.initializeFeaturesGradient(parametersGradient); - particles.featuresBwdToBuffer(particleIdx, - reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], - incidentDirection, - incidentDirectionGrad); - - particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition, incidentDirectionGrad); + Particles particles; + particles.initializeDensity(parameters); + particles.initializeDensityGradient(parametersGradient); + const tcnn::vec3 incidentDirection = tcnn::normalize(particles.fetchPosition(particleIdx) - sensorWorldPosition); + tcnn::vec3 incidentDirectionGrad = tcnn::vec3(0.0f); + + particles.initializeFeatures(parameters); + particles.initializeFeaturesGradient(parametersGradient); + particles.featuresBwdToBuffer(particleIdx, + reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], + incidentDirection, + incidentDirectionGrad); + + particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition, incidentDirectionGrad); + } } }; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh index 8a451c4e..06aca8ee 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh @@ -88,7 +88,7 @@ __global__ void render(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, float* __restrict__ worldHitCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr, const float* __restrict__ particlesGlobalDepthPtr, @@ -111,7 +111,7 @@ __global__ void render(threedgut::RenderParameters params, // TGUTModel::eval(params, ray, {parameterMemoryHandles}); // NB : finalize ray is not differentiable (has to be no-op when used in a differentiable renderer) - finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, worldHitDistancePtr, radianceDensityPtr, sensorToWorldTransform); + finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, worldHitDistancePtr, featureDensityPtr, sensorToWorldTransform); } #if FINE_GRAINED_LOAD_BALANCING @@ -124,7 +124,7 @@ __global__ void renderBalanced(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, float* __restrict__ worldHitCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr, const float* __restrict__ particlesGlobalDepthPtr, @@ -210,7 +210,7 @@ __global__ void renderBalanced(threedgut::RenderParameters params, // Only lane 0 should write, as only it has accumulated the correct values if (laneId == 0) { finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, - worldHitDistancePtr, radianceDensityPtr, sensorToWorldTransform); + worldHitDistancePtr, featureDensityPtr, sensorToWorldTransform); } } } @@ -224,8 +224,8 @@ __global__ void renderBackward(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, const float* __restrict__ worldHitDistancePtr, const float* __restrict__ worldHitDistanceGradientPtr, - const tcnn::vec4* __restrict__ radianceDensityPtr, - const tcnn::vec4* __restrict__ radianceDensityGradientPtr, + const TFeatureDensityElem* __restrict__ featureDensityPtr, + const float* __restrict__ featureDensityGradientPtr, tcnn::vec3* __restrict__ /*worldRayOriginGradientPtr*/, tcnn::vec3* __restrict__ /*worldRayDirectionGradientPtr*/, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, @@ -244,8 +244,8 @@ __global__ void renderBackward(threedgut::RenderParameters params, sensorRayDirectionPtr, worldHitDistancePtr, worldHitDistanceGradientPtr, - radianceDensityPtr, - radianceDensityGradientPtr, + featureDensityPtr, + featureDensityGradientPtr, sensorToWorldTransform); // TGUTModel::evalBackward(params, ray, {parameterMemoryHandles}, {parameterGradientMemoryHandles}); diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang index 2e938e26..687acf8b 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang @@ -178,6 +178,17 @@ struct Parameters : IDifferentiable { return sqrt(dot(grds, grds)); } +[BackwardDifferentiable][ForceInline] float canonicalRayIntersection( + float3 canonicalRayOrigin, + float3 canonicalRayDirection, + float3 scale, + out float3 canonicalIntersection) { + const float3 canonicalGrds = canonicalRayDirection * dot(canonicalRayDirection, -1 * canonicalRayOrigin); + canonicalIntersection = canonicalRayOrigin + canonicalGrds; + const float3 grds = scale * canonicalGrds; + return sqrt(dot(grds, grds)); +} + [BackwardDifferentiable][ForceInline] float3 canonicalRayNormal( float3 canonicalRayOrigin, float3 canonicalRayDirection, @@ -200,6 +211,7 @@ bool hit( Parameters parameters, out float alpha, inout float depth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 normal) { @@ -220,7 +232,7 @@ bool hit( const bool acceptHit = ((maxResponse > MinParticleKernelDensity) && (alpha > MinParticleAlpha)); if (acceptHit) { - depth = canonicalRayDistance(canonicalRayOrigin, canonicalRayDirection, parameters.scale); + depth = canonicalRayIntersection(canonicalRayOrigin, canonicalRayDirection, parameters.scale, canonicalIntersection); if (enableNormal) { normal = canonicalRayNormal(canonicalRayOrigin, canonicalRayDirection, parameters.scale, parameters.rotationT); @@ -269,6 +281,7 @@ float processHitFromBuffer( no_diff RawParametersBuffer parametersBuffer, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 integratedNormal) { @@ -280,6 +293,7 @@ float processHitFromBuffer( fetchParameters(particleIdx, parametersBuffer), alpha, depth, + canonicalIntersection, enableNormal, normal)) { @@ -346,6 +360,7 @@ inline bool particleDensityHit( gaussianParticle.Parameters parameters, out float alpha, out float depth, + out float3 canonicalIntersection, bool enableNormal, out float3 normal) { @@ -354,6 +369,7 @@ inline bool particleDensityHit( parameters, alpha, depth, + canonicalIntersection, enableNormal, normal); } @@ -385,6 +401,7 @@ inline float particleDensityProcessHitFwdFromBuffer( gaussianParticle.CommonParameters commonParameters, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, in bool enableNormal, inout float3 integratedNormal) { @@ -395,6 +412,7 @@ inline float particleDensityProcessHitFwdFromBuffer( commonParameters.parametersBuffer, transmittance, integratedDepth, + canonicalIntersection, enableNormal, integratedNormal); } @@ -412,6 +430,7 @@ void particleDensityProcessHitBwdToBuffer( in float depth, inout float integratedDepth, inout float integratedDepthGrad, + in float3 canonicalIntersectionGrad, bool enableNormal, in float3 normal, inout float3 integratedNormal, @@ -445,6 +464,7 @@ void particleDensityProcessHitBwdToBuffer( commonParameters.parametersBuffer, transmittanceDiff, integratedDepthDiff, + canonicalIntersectionGrad, enableNormal, integratedNormalDiff, alphaGrad); diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang new file mode 100644 index 00000000..43ece455 --- /dev/null +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use it except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Neural Harmonic Features: per-particle K-dim features, output N (decoder input); decoder maps N -> RGB. +// Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. +// Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, +// PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. With activation: N = interpPointDim * num_frequencies (decoder input). + +namespace neuralHarmonicFeaturesParticle +{ +static const int ParticleFeatureDim = PARTICLE_FEATURE_DIM; // Total per-particle (K); not per-interpolation-point +static const int RayFeatureDim = RAY_FEATURE_DIM; // Decoder input N = INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES +static const int InterpPointFeatureDim = INTERP_POINT_FEATURE_DIM; // Per-interpolation-point dimension +static const int FeatureActivationType_None = 0; +static const int FeatureActivationType_Siren = 1; +static const int FeatureActivationType_Sincos = 2; +static const int FeatureActivationType_Relu = 3; +static const int FeatureActivationType = FEATURE_ACTIVATION_TYPE; +static const int FeatureActivationNumFrequencies = FEATURE_ACTIVATION_NUM_FREQUENCIES; + +// Interpolation type (compile-time from FEATURE_INTERPOLATION_TYPE) +static const int InterpolationType_Barycentric = 0; +static const int InterpolationType_Bezier = 1; // Not supported yet +static const int InterpolationType = FEATURE_INTERPOLATION_TYPE; + +// Interpolation support (compile-time from FEATURE_INTERPOLATION_SUPPORT) +static const int InterpolationSupport_Center = 0; +static const int InterpolationSupport_Tetrahedra = 1; +static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet +static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT; + +// Canonical regular tetrahedron that contains the unit sphere; each face is tangent to the unit sphere +// (insphere radius = 1). Same geometry as particlePrimitives.cu (enclosing tetrahedra). +// Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1. +static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24) +static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2 +static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3) +static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2) +static const float tetraHedraInRadius = 1.0f; // unit sphere +static const float3 canonicalTetraVerts[4] = { + float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), + float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f), + float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius), + float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f) +}; +// Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical. +static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE2 = canonicalTetraVerts[2] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE3 = canonicalTetraVerts[3] - canonicalTetraVerts[0]; +static const float3 canonicalTetraCrossE2E3 = cross(canonicalTetraE2, canonicalTetraE3); +static const float canonicalTetraDet = dot(canonicalTetraE1, canonicalTetraCrossE2E3); +static const float canonicalTetraInvDet = 1.0f / canonicalTetraDet; + +}; + +#if PARTICLE_FEATURE_HALF +typedef half feat_elem_t; +#else +typedef float feat_elem_t; +#endif + +namespace neuralHarmonicFeaturesParticle +{ + +struct ParametersBuffer +{ + Ptr _dataPtr; // [N_particles, K] flat; fp16 when PARTICLE_FEATURE_HALF + float *_gradPtr; + bool exclusiveGradient; +}; + +struct Parameters : IDifferentiable +{ + Array features; +}; + +[BackwardDifferentiable][ForceInline] +Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer) +{ + Parameters parameters; + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + parameters.features[i] = parametersBuffer._dataPtr[interpPointOffset + i]; + } + return parameters; +} + +[BackwardDerivativeOf(fetchParametersFromBuffer)][ForceInline] +void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer, + Parameters parametersGrad) +{ + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + const float grad = parametersGrad.features[i]; + if (parametersBuffer.exclusiveGradient) { + parametersBuffer._gradPtr[interpPointOffset + i] += grad; + } else { + InterlockedAdd(parametersBuffer._gradPtr[interpPointOffset + i], grad); + } + } +} + +// Barycentric coordinates for the canonical tetrahedron only. Uses precomputed static const e1,e2,e3,invDet (independent of P). +[BackwardDifferentiable][ForceInline] +float4 barycentricTetrahedronCanonical(float3 P) +{ + float3 d = P - canonicalTetraVerts[0]; + float4 weights; + weights.y = dot(d, canonicalTetraCrossE2E3) * canonicalTetraInvDet; + weights.z = dot(canonicalTetraE1, cross(d, canonicalTetraE3)) * canonicalTetraInvDet; + weights.w = dot(canonicalTetraE1, cross(canonicalTetraE2, d)) * canonicalTetraInvDet; + weights.x = 1.0f - weights.y - weights.z - weights.w; + return weights; +} + +// Encode and activate : none -> identity; siren -> sin(b*2^f); sincos -> sin(b*2^f)+cos(b*2^f); relu -> max(0,b). +[BackwardDifferentiable][ForceInline] +float encodeAndActivate(float baseVal, no_diff int f) +{ + if (FeatureActivationType == FeatureActivationType_None) + return baseVal; + if (FeatureActivationType == FeatureActivationType_Relu) + return max(0.0f, baseVal); + float freq = ldexp(1.0f, f); + float angle = baseVal * freq; + if (FeatureActivationType == FeatureActivationType_Siren) + return sin(angle); + return sin(angle) + cos(angle); +} + +// Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim]. +// canonicalPosition is differential (hit position in particle canonical space; gradient accumulated in API canonicalPositionGrad). +[BackwardDifferentiable][ForceInline] +void featuresFromParametersBuffer(ParametersBuffer parametersBuffer, + no_diff uint32_t particleIdx, + float3 canonicalPosition, + out Array features +) +{ + Array baseFeatures = fetchParametersFromBuffer(particleIdx, 0, parametersBuffer).features; + if (InterpolationSupport == InterpolationSupport_Tetrahedra && InterpolationType == InterpolationType_Barycentric) { + float4 barycentricWeights = barycentricTetrahedronCanonical(canonicalPosition); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] *= barycentricWeights[0]; + } + [ForceUnroll] for (int k = 1; k < 4; ++k) { + Parameters parameters = fetchParametersFromBuffer(particleIdx, k, parametersBuffer); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] += barycentricWeights[k] * parameters.features[n]; + } + } + } + + if (FeatureActivationType == FeatureActivationType_None) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = baseFeatures[i]; + } + } else { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + features[k * FeatureActivationNumFrequencies + f] = encodeAndActivate(baseFeatures[k], f); + } + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeatures(float weight, + in Array features, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + [ForceUnroll] for (int i = 0; i < RayFeatureDim; ++i) { + if (backToFront) + integratedFeatures[i] = lerp(integratedFeatures[i], features[i], weight); + else + integratedFeatures[i] += features[i] * weight; + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeaturesFromBuffer(float weight, + no_diff uint32_t particleIdx, + ParametersBuffer parametersBuffer, + float3 canonicalPosition, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + Array features; + featuresFromParametersBuffer(parametersBuffer, particleIdx, canonicalPosition, features); + integrateFeatures(weight, features, integratedFeatures); + } +} + +} // namespace neuralHarmonicFeaturesParticle + +// ------------------------------------------------------------------------------------------------------------------ +// Entry points - same CudaDeviceExport API as shRadiativeParticles.slang + +[CudaDeviceExport] +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.featuresFromParametersBuffer( + parametersBuffer, + particleIdx, + canonicalPosition, + features + ); +} + +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwd(in float weight, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.integrateFeatures( + weight, + (Array)features, + (Array)integratedFeatures + ); +} + +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer( + weight, particleIdx, parametersBuffer, canonicalPosition, integratedFeatures); +} + +// canonicalPosition is differential; canonicalPositionGrad is inout and accumulated by this backward. +[CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, + in float alpha, + inout float alphaGrad, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + if (alpha > 0.0f) + { + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); + + const float weight = 1.0f / (1.0f - alpha); + [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + DifferentialPair integratedFeaturesDiff = + DifferentialPair(integratedFeatures, integratedFeaturesGrad); + + DifferentialPair canonicalPositionDiff = DifferentialPair(canonicalPosition, canonicalPositionGrad); + + bwd_diff(neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer)( + alphaDiff, + particleIdx, + parametersBuffer, + canonicalPositionDiff, + integratedFeaturesDiff); + + integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + canonicalPositionGrad += canonicalPositionDiff.getDifferential(); + alphaGrad += alphaDiff.getDifferential(); + } +} + +[CudaDeviceExport] void particleFeaturesIntegrateBwd( + in float alpha, + inout float alphaGrad, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float featuresGrad[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + if (alpha > 0.0f) + { + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); + DifferentialPair featuresDiff = + DifferentialPair(features, featuresGrad); + + const float weight = 1.0f / (1.0f - alpha); + [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + DifferentialPair integratedFeaturesDiff = + DifferentialPair(integratedFeatures, integratedFeaturesGrad); + + bwd_diff(neuralHarmonicFeaturesParticle.integrateFeatures)( + alphaDiff, + featuresDiff, + integratedFeaturesDiff); + + alphaGrad = alphaDiff.getDifferential(); + featuresGrad = featuresDiff.getDifferential(); + integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + } +} diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang new file mode 100644 index 00000000..32d27dad --- /dev/null +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Radiative Particles Wrapper +// Conditionally includes the appropriate feature implementation based on FEATURE_TRANSFORM_TYPE + +#if FEATURE_TRANSFORM_TYPE == 0 + // Spherical Harmonics mode + #include <3dgut/kernels/slang/models/shRadiativeParticles.slang> +#elif FEATURE_TRANSFORM_TYPE == 1 + // Post-MLP radiance mode + #include <3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang> +#else + #error "Unknown FEATURE_TRANSFORM_TYPE. Must be 0 (SH) or 1 (neural_harmonic_features)" +#endif diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang index aefe138c..1ef1a7e7 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. +// you may not use it except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -15,22 +15,22 @@ #include <3dgut/kernels/slang/common/sphericalHarmonics.slang> +namespace shRadiativeParticle +{ +static const int RadianceMaxNumSphCoefficients = PARTICLE_RADIANCE_NUM_COEFFS; +static const int Dim = 3; +}; + namespace shRadiativeParticle { struct ParametersBuffer { Ptr, Access.Read> _dataPtr; - Ptr> _gradPtr; + vector *_gradPtr; bool exclusiveGradient; //< true if the gradient maybe updated without atomics }; -struct CommonParameters -{ - ParametersBuffer parametersBuffer; - int sphDegree; -}; - struct Parameters : IDifferentiable { vector sphCoefficients[RadianceMaxNumSphCoefficients]; @@ -42,7 +42,7 @@ Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, { Parameters parameters; const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { parameters.sphCoefficients[i] = parametersBuffer._dataPtr[particleOffset + i]; } return parameters; @@ -54,14 +54,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, Parameters parametersGrad) { const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { const vector coeffs = parametersGrad.sphCoefficients[i]; if (parametersBuffer.exclusiveGradient) { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { parametersBuffer._gradPtr[particleOffset + i][j] += coeffs[j]; } } else { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { InterlockedAdd(parametersBuffer._gradPtr[particleOffset + i][j], coeffs[j]); } } @@ -75,9 +75,9 @@ vector radianceFromBuffer(no_diff uint32_t particleIdx, no_diff ParametersBuffer parametersBuffer) { return sphericalHarmonics.decode( - sphDegree, - fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, - incidentDirection); + sphDegree, + fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, + incidentDirection); } [BackwardDifferentiable][ForceInline] @@ -139,61 +139,103 @@ void integrateRadianceFromBuffer(no_diff float3 incident // Entry points [CudaDeviceExport] -inline vector particleFeaturesFromBuffer(in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in float3 incidentDirection) +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[shRadiativeParticle.Dim]) { - return sphericalHarmonics.decode( - commonParameters.sphDegree, - shRadiativeParticle.fetchParametersFromBuffer(particleIdx, commonParameters.parametersBuffer).sphCoefficients, + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector featuresVec = sphericalHarmonics.decode( + auxParam, + shRadiativeParticle.fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, incidentDirection); + + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + features[i] = featuresVec[i]; + } } [CudaDeviceExport] inline void particleFeaturesIntegrateFwd(in float weight, - in vector features, - inout vector integratedFeatures) + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim]) { - shRadiativeParticle.integrateRadiance( - weight, - features, - integratedFeatures - ); + vector featuresVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresVec[i] = features[i]; + } + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + } + shRadiativeParticle.integrateRadiance(weight, featuresVec, integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } -[CudaDeviceExport] inline void particleFeaturesIntegrateFwdFromBuffer(in float3 incidentDirection, - in float weight, - in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - inout vector integratedFeatures) +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + inout float integratedFeatures[shRadiativeParticle.Dim]) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + } shRadiativeParticle.integrateRadianceFromBuffer( - incidentDirection, - commonParameters.sphDegree, - weight, - particleIdx, - commonParameters.parametersBuffer, - integratedFeatures); + incidentDirection, auxParam, weight, particleIdx, parametersBuffer, integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } [CudaDeviceExport] void particleFeaturesIntegrateBwd( in float alpha, inout float alphaGrad, - in vector features, - inout vector featuresGrad, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in float features[shRadiativeParticle.Dim], + inout float featuresGrad[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if(alpha > 0.0f) { + vector featuresVec, featuresGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresVec[i] = features[i]; + featuresGradVec[i] = featuresGrad[i]; + } + vector integratedVec, integratedGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + integratedGradVec[i] = integratedFeaturesGrad[i]; + } + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); DifferentialPair> featuresDiff = - DifferentialPair>(features, featuresGrad); + DifferentialPair>(featuresVec, featuresGradVec); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + integratedVec = (integratedVec - featuresVec * alpha) * weight; DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedVec, integratedGradVec); bwd_diff(shRadiativeParticle.integrateRadiance)( alphaDiff, @@ -201,59 +243,98 @@ inline void particleFeaturesIntegrateFwd(in float weight, integratedFeaturesDiff); alphaGrad = alphaDiff.getDifferential(); - featuresGrad = featuresDiff.getDifferential(); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresGrad[i] = featuresDiff.getDifferential()[i]; + } + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } } [CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, in float alpha, inout float alphaGrad, in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector features, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if (alpha > 0.0f) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (vector*)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + vector integratedFeaturesVec; + vector integratedFeaturesGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesVec[i] = integratedFeatures[i]; + integratedFeaturesGradVec[i] = integratedFeaturesGrad[i]; + } DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedFeaturesVec, integratedFeaturesGradVec); bwd_diff(shRadiativeParticle.integrateRadianceFromBuffer)( incidentDirection, - commonParameters.sphDegree, + auxParam, alphaDiff, particleIdx, - commonParameters.parametersBuffer, + parametersBuffer, integratedFeaturesDiff); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); - alphaGrad = alphaDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + alphaGrad += alphaDiff.getDifferential(); } } [CudaDeviceExport] void particleFeaturesBwdToBuffer( in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector featuresGrad, + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float featuresGrad[shRadiativeParticle.Dim], in float3 incidentDirection, inout float3 incidentDirectionGrad ) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (vector*)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + vector featuresGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) featuresGradVec[i] = featuresGrad[i]; + DifferentialPair incidentDirectionDiff = DifferentialPair(incidentDirection, incidentDirectionGrad); bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, incidentDirectionDiff, - commonParameters.sphDegree, - commonParameters.parametersBuffer, - featuresGrad); + auxParam, + parametersBuffer, + featuresGradVec); incidentDirectionGrad += incidentDirectionDiff.getDifferential(); } diff --git a/threedgut_tracer/include/3dgut/renderer/gutRenderer.h b/threedgut_tracer/include/3dgut/renderer/gutRenderer.h index 7abd81fa..8ad1919b 100644 --- a/threedgut_tracer/include/3dgut/renderer/gutRenderer.h +++ b/threedgut_tracer/include/3dgut/renderer/gutRenderer.h @@ -22,6 +22,16 @@ #include +#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif +#if FEATURE_OUTPUT_HALF +#include +using TFeatureDensityElem = __half; +#else +using TFeatureDensityElem = float; +#endif + namespace threedgut { class GUTRenderer { @@ -67,7 +77,7 @@ class GUTRenderer { const tcnn::vec3* sensorRayDirectionCudaPtr, float* worldHitCountCudaPtr, float* worldHitDistanceCudaPtr, - tcnn::vec4* radianceDensityCudaPtr, + TFeatureDensityElem* featureDensityCudaPtr, int* particlesVisibilityCudaPtr, Parameters& parameters, int cudaDeviceIndex, @@ -78,8 +88,8 @@ class GUTRenderer { const tcnn::vec3* sensorRayDirectionCudaPtr, const float* worldHitDistanceCudaPtr, const float* worldHitDistanceGradientCudaPtr, - const tcnn::vec4* radianceDensityCudaPtr, - const tcnn::vec4* radianceDensityGradientCudaPtr, + const TFeatureDensityElem* featureDensityCudaPtr, + const float* featureDensityGradientCudaPtr, tcnn::vec3* worldRayOriginGradientCudaPtr, tcnn::vec3* worldRayDirectionGradientCudaPtr, Parameters& parameters, diff --git a/threedgut_tracer/include/3dgut/threedgut.cuh b/threedgut_tracer/include/3dgut/threedgut.cuh index 3ab2b092..e01ca757 100644 --- a/threedgut_tracer/include/3dgut/threedgut.cuh +++ b/threedgut_tracer/include/3dgut/threedgut.cuh @@ -29,7 +29,13 @@ struct model_InternalParams { }; struct model_ExternalParams { - static constexpr int FeaturesDim = 3; + // Feature-based radiance dimensions (from compile-time defines) + static constexpr int ParticleFeatureDim = PARTICLE_FEATURE_DIM; + static constexpr int RayFeatureDim = RAY_FEATURE_DIM; + static constexpr int FeatureTransformType = FEATURE_TRANSFORM_TYPE; // 0=SH, 1=nht + // NHT requires per-ray evaluation (interpolation); SH can use precomputed features + static constexpr bool PerRayParticleFeatures = (FEATURE_TRANSFORM_TYPE != 0); + static constexpr float AlphaThreshold = GAUSSIAN_PARTICLE_MIN_ALPHA; // = 1.0/255.0 static constexpr float MinTransmittanceThreshold = GAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD; // = 0.0001 static constexpr int KernelDegree = GAUSSIAN_PARTICLE_KERNEL_DEGREE; @@ -52,7 +58,7 @@ struct TGUTProjectorParams { static constexpr bool TightOpacityBounding = GAUSSIAN_TIGHT_OPACITY_BOUNDING; static constexpr bool RectBounding = GAUSSIAN_RECT_BOUNDING; static constexpr bool TileCulling = GAUSSIAN_TILE_BASED_CULLING; - static constexpr bool PerRayParticleFeatures = false; + static constexpr bool PerRayParticleFeatures = model_ExternalParams::PerRayParticleFeatures; static constexpr float MaxDepthValue = 3.4028235e+38; static constexpr bool GlobalZOrder = GAUSSIAN_GLOBAL_Z_ORDER; static constexpr bool BackwardProjection = false; // m_settings.renderMode == Settings::Splat @@ -77,7 +83,7 @@ static_assert(TGUTProjectionParams::RequireAllSigmaPoints == false, "RequireAllS using TGUTProjector = GUTProjector; struct TGUTRendererParams { - static constexpr bool PerRayParticleFeatures = TGUTProjectorParams::PerRayParticleFeatures; + static constexpr bool PerRayParticleFeatures = model_ExternalParams::PerRayParticleFeatures; static constexpr int KHitBufferSize = GAUSSIAN_K_BUFFER_SIZE; static constexpr bool CustomBackward = false; }; diff --git a/threedgut_tracer/include/3dgut/threedgut.slang b/threedgut_tracer/include/3dgut/threedgut.slang index 554e7fca..cb7f0ecc 100644 --- a/threedgut_tracer/include/3dgut/threedgut.slang +++ b/threedgut_tracer/include/3dgut/threedgut.slang @@ -24,11 +24,4 @@ namespace gaussianParticle }; #include <3dgut/kernels/slang/models/gaussianParticles.slang> - -namespace shRadiativeParticle -{ - static const int RadianceMaxNumSphCoefficients = PARTICLE_RADIANCE_NUM_COEFFS; - static const int Dim = 3; -}; - -#include <3dgut/kernels/slang/models/shRadiativeParticles.slang> +#include <3dgut/kernels/slang/models/radiativeParticles.slang> diff --git a/threedgut_tracer/setup_3dgut.py b/threedgut_tracer/setup_3dgut.py index f88f443a..593936f6 100644 --- a/threedgut_tracer/setup_3dgut.py +++ b/threedgut_tracer/setup_3dgut.py @@ -19,6 +19,7 @@ import torch from threedgrut.utils import jit +from threedgrut.model.features import Features # ---------------------------------------------------------------------------- @@ -32,9 +33,6 @@ def setup_3dgut(conf): include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "include")) include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "dependencies")) include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "dependencies", "fmt", "include")) - include_paths.append(build_dir) - - # Compiler options. def to_cpp_bool(value): return "true" if value else "false" @@ -45,6 +43,24 @@ def to_cpp_bool(value): ut_kappa = conf.render.splat.ut_kappa ut_delta = math.sqrt(ut_alpha * ut_alpha * (ut_d + ut_kappa)) + feat = Features(conf) + transform_defines = [ + f"-DPARTICLE_FEATURE_DIM={feat.particle_feature_dim}", + f"-DRAY_FEATURE_DIM={feat.ray_feature_dim}", + f"-DFEATURE_TRANSFORM_TYPE={feat.transform_type}", + ] + nht_defines = [ + f"-DFEATURE_INTERPOLATION_TYPE={feat.interpolation_type}", + f"-DFEATURE_INTERPOLATION_SUPPORT={feat.interpolation_support}", + f"-DFEATURE_ACTIVATION_TYPE={feat.activation_type}", + f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={feat.activation_num_frequencies}", + f"-DINTERP_POINT_FEATURE_DIM={feat.interp_point_feature_dim}", + ] + half_defines = [ + f"-DPARTICLE_FEATURE_HALF={1 if conf.render.particle_feature_half else 0}", + f"-DFEATURE_OUTPUT_HALF={1 if conf.render.feature_output_half else 0}", + ] + defines = [ f"-DPARTICLE_RADIANCE_NUM_COEFFS={(conf.render.particle_radiance_sph_degree + 1) ** 2}", f"-DGAUSSIAN_PARTICLE_KERNEL_DEGREE={conf.render.particle_kernel_degree}", @@ -55,6 +71,11 @@ def to_cpp_bool(value): f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type=='trisurfel')}", f"-DGAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD={conf.render.min_transmittance}", f"-DGAUSSIAN_ENABLE_HIT_COUNT={to_cpp_bool(conf.render.enable_hitcounts)}", + # Feature-based radiance dimensions + *transform_defines, + *nht_defines, + # Feature buffer memory layout + *half_defines, # Specific to the 3DGUT renderer f"-DGAUSSIAN_N_ROLLING_SHUTTER_ITERATIONS={conf.render.splat.n_rolling_shutter_iterations}", f"-DGAUSSIAN_K_BUFFER_SIZE={conf.render.splat.k_buffer_size}", @@ -88,6 +109,11 @@ def to_cpp_bool(value): "-O3", *defines, ] + # When PARTICLE_FEATURE_HALF=1 the Slang-generated header uses __half types; + # the Slang prelude only pulls in and defines __half when + # SLANG_CUDA_ENABLE_HALF is set. + if conf.render.particle_feature_half or conf.render.feature_output_half: + cuda_cflags.append("-DSLANG_CUDA_ENABLE_HALF=1") # List of sources. source_files = [ @@ -99,9 +125,11 @@ def to_cpp_bool(value): # Compile slang kernels slang_build_inc_dir = os.path.join(os.path.dirname(__file__), "include", "3dgut") + slang_output_file = os.path.join(os.path.dirname(__file__), "include", "threedgutSlang.cuh") + jit.compile_slang_kernel( kernel_files=[f"{os.path.join(slang_build_inc_dir, 'threedgut.slang')}"], - output_file=f"{os.path.join(build_dir, 'threedgutSlang.cuh')}", + output_file=slang_output_file, include_paths=[ os.path.join(os.path.dirname(__file__), "include"), os.path.join(os.path.dirname(__file__), "..", "threedgrt_tracer", "include"), diff --git a/threedgut_tracer/src/gutRenderer.cu b/threedgut_tracer/src/gutRenderer.cu index 8c2472c3..5c361b63 100644 --- a/threedgut_tracer/src/gutRenderer.cu +++ b/threedgut_tracer/src/gutRenderer.cu @@ -39,7 +39,7 @@ namespace { using namespace threedgut; constexpr int featuresDim() { - return model_ExternalParams::FeaturesDim; + return model_ExternalParams::RayFeatureDim; } // identify tiles start/end indices in the sorted tile/depth keys buffer @@ -243,7 +243,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& const vec3* sensorRayDirectionCudaPtr, float* worldHitCountCudaPtr, float* worldHitDistanceCudaPtr, - vec4* radianceDensityCudaPtr, + TFeatureDensityElem* featureDensityCudaPtr, int* particlesVisibilityCudaPtr, Parameters& parameters, int cudaDeviceIndex, @@ -390,7 +390,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& sensorPoseToMat(sensorPoseInv), worldHitCountCudaPtr, worldHitDistanceCudaPtr, - radianceDensityCudaPtr, + featureDensityCudaPtr, (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), (const tcnn::vec4*)m_forwardContext->particlesProjectedConicOpacity.data(), (const float*)m_forwardContext->particlesGlobalDepth.data(), @@ -407,7 +407,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& sensorPoseToMat(sensorPoseInv), worldHitCountCudaPtr, worldHitDistanceCudaPtr, - radianceDensityCudaPtr, + featureDensityCudaPtr, (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), (const tcnn::vec4*)m_forwardContext->particlesProjectedConicOpacity.data(), (const float*)m_forwardContext->particlesGlobalDepth.data(), @@ -423,10 +423,10 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& threedgut::Status threedgut::GUTRenderer::renderBackward(const RenderParameters& params, const vec3* sensorRayOriginCudaPtr, const vec3* sensorRayDirectionCudaPtr, - const float* worldHitDistanceCudaPtr, // - const float* worldHitDistanceGradientCudaPtr, // TODO: not implemented yet - const vec4* radianceDensityCudaPtr, // - const vec4* radianceDensityGradientCudaPtr, // TODO: not implemented yet + const float* worldHitDistanceCudaPtr, + const float* worldHitDistanceGradientCudaPtr, + const TFeatureDensityElem* featureDensityCudaPtr, + const float* featureDensityGradientCudaPtr, vec3* worldRayOriginGradientCudaPtr, // TODO: not implemented yet vec3* worldRayDirectionGradientCudaPtr, // TODO: not implemented yet Parameters& parameters, @@ -477,8 +477,8 @@ threedgut::Status threedgut::GUTRenderer::renderBackward(const RenderParameters& sensorPoseToMat(sensorPoseInv), (const float*)worldHitDistanceCudaPtr, // (const float*)worldHitDistanceGradientCudaPtr, // TODO: not implemented yet - (const tcnn::vec4*)radianceDensityCudaPtr, // - (const tcnn::vec4*)radianceDensityGradientCudaPtr, // TODO: not implemented yet + featureDensityCudaPtr, + featureDensityGradientCudaPtr, (tcnn::vec3*)worldRayOriginGradientCudaPtr, // TODO: not implemented yet (tcnn::vec3*)worldRayDirectionGradientCudaPtr, // TODO: not implemented yet (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), diff --git a/threedgut_tracer/src/splatRaster.cpp b/threedgut_tracer/src/splatRaster.cpp index f277acdc..8a400d84 100644 --- a/threedgut_tracer/src/splatRaster.cpp +++ b/threedgut_tracer/src/splatRaster.cpp @@ -192,8 +192,14 @@ SplatRaster::trace(uint32_t frameNumber, int numActiveFeatures, const uint32_t numParticles = particleDensity.size(0); const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // Feature output dtype: fp16 halves memory bandwidth for NHT feature buffers +#if FEATURE_OUTPUT_HALF + const torch::TensorOptions featureOpts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); +#else + const torch::TensorOptions featureOpts = opts; +#endif - torch::Tensor rayRadianceDensity = torch::zeros({height, width, 4}, opts); + torch::Tensor rayRadianceDensity = torch::zeros({height, width, static_cast(RAY_FEATURE_DIM + 1)}, featureOpts); torch::Tensor rayHitDistance = torch::ones({height, width, 1}, opts).multiply(1e06f); torch::Tensor rayHitCount = torch::zeros({height, width, 1}, opts); torch::Tensor particleVisibility = torch::zeros({numParticles, 1}, opts); @@ -228,7 +234,7 @@ SplatRaster::trace(uint32_t frameNumber, int numActiveFeatures, reinterpret_cast(voidDataPtr(rayDirection)), reinterpret_cast(voidDataPtr(rayHitCount)), reinterpret_cast(voidDataPtr(rayHitDistance)), - reinterpret_cast(voidDataPtr(rayRadianceDensity)), + reinterpret_cast(voidDataPtr(rayRadianceDensity)), reinterpret_cast(voidDataPtr(particleVisibility)), m_parameters, cudaDeviceIndex, @@ -316,8 +322,8 @@ SplatRaster::traceBwd(uint32_t frameNumber, int numActiveFeatures, reinterpret_cast(voidDataPtr(rayDirection)), reinterpret_cast(voidDataPtr(rayHitDistance)), reinterpret_cast(voidDataPtr(rayHitDistanceGradient)), - reinterpret_cast(voidDataPtr(rayRadianceDensity)), - reinterpret_cast(voidDataPtr(rayRadianceDensityGradient)), + reinterpret_cast(voidDataPtr(rayRadianceDensity)), + reinterpret_cast(voidDataPtr(rayRadianceDensityGradient)), rayBackpropagation ? reinterpret_cast(voidDataPtr(rayOriginGradient)) : nullptr, rayBackpropagation ? reinterpret_cast(voidDataPtr(rayDirectionGradient)) : nullptr, m_parameters, cudaDeviceIndex, cudaStream); diff --git a/threedgut_tracer/tracer.py b/threedgut_tracer/tracer.py index e2a80e83..b59a664b 100644 --- a/threedgut_tracer/tracer.py +++ b/threedgut_tracer/tracer.py @@ -176,7 +176,7 @@ def forward( particle_density = torch.concat( [mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1 ).contiguous() - particle_radiance = mog_sph.contiguous() + particle_features = mog_sph.contiguous() # dtype set by caller (fp16 when particle_feature_half=true) ray_time = ( torch.ones( @@ -185,11 +185,11 @@ def forward( * sensor_poses.timestamps_us[0] ) - ray_radiance_density, ray_hit_distance, ray_hit_count, mog_visibility = tracer_wrapper.trace( + ray_features_density, ray_hit_distance, ray_hit_count, mog_visibility = tracer_wrapper.trace( frame_id, n_active_features, particle_density, - particle_radiance, + particle_features, ray_ori.contiguous(), ray_dir.contiguous(), ray_time.contiguous(), @@ -204,10 +204,10 @@ def forward( ray_ori, ray_dir, ray_time, - ray_radiance_density, + ray_features_density, ray_hit_distance, particle_density, - particle_radiance, + particle_features, ) ctx.frame_id = frame_id @@ -217,7 +217,7 @@ def forward( ctx.tracer_wrapper = tracer_wrapper return ( - ray_radiance_density, + ray_features_density.float(), # always fp32 to caller; fp16 saved in ctx for trace_bwd ray_hit_distance, ray_hit_count, mog_visibility, @@ -226,7 +226,7 @@ def forward( @staticmethod def backward( ctx, - ray_radiance_density_grd, + ray_features_density_grd, # always fp32 (gradient buffer is never fp16) ray_hit_distance_grd, ray_hit_count_grd_UNUSED, mog_visibility_grd_UNUSED, @@ -235,10 +235,10 @@ def backward( ray_ori, ray_dir, ray_time, - ray_radiance_density, + ray_features_density, ray_hit_distance, particle_density, - particle_radiance, + particle_features, ) = ctx.saved_variables frame_id = ctx.frame_id @@ -246,11 +246,11 @@ def backward( sensor_params = ctx.sensor_params sensor_poses = ctx.sensor_poses - particle_density_grd, particle_radiance_grd = ctx.tracer_wrapper.trace_bwd( + particle_density_grd, particle_features_grd = ctx.tracer_wrapper.trace_bwd( frame_id, n_active_features, particle_density, - particle_radiance, + particle_features, ray_ori, ray_dir, ray_time, @@ -259,8 +259,8 @@ def backward( sensor_poses.timestamps_us[1], sensor_poses.T_world_sensors[0], sensor_poses.T_world_sensors[1], - ray_radiance_density, - ray_radiance_density_grd, + ray_features_density, + ray_features_density_grd, ray_hit_distance, ray_hit_distance_grd, ) @@ -268,7 +268,7 @@ def backward( mog_pos_grd, mog_dns_grd, mog_rot_grd, mog_scl_grd, _ = torch.split( particle_density_grd, [3, 1, 4, 3, 1], dim=1 ) - mog_sph_grd = particle_radiance_grd + mog_sph_grd = particle_features_grd return ( None, # tracer_wrapper @@ -310,7 +310,7 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): num_gaussians = gaussians.num_gaussians with torch.cuda.nvtx.range(f"model.forward({num_gaussians} gaussians)"): ( - pred_rgba, + pred_features_alpha, pred_dist, hits_count, mog_visibility, @@ -324,27 +324,27 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): gaussians.get_rotation().contiguous(), gaussians.get_scale().contiguous(), gaussians.get_density().contiguous(), - gaussians.get_features().contiguous(), + gaussians.get_features().contiguous().half() + if self.conf.render.particle_feature_half + else gaussians.get_features().contiguous(), sensor, poses, ) - pred_rgb = pred_rgba[..., :3].unsqueeze(0).contiguous() - pred_opacity = pred_rgba[..., 3:].unsqueeze(0).contiguous() + # pred_features_alpha is [..., RAY_FEATURE_DIM + 1]: features (fp32) + density + ray_feature_dim = gaussians.ray_feature_dim + pred_features = pred_features_alpha[..., :ray_feature_dim].unsqueeze(0).contiguous() + pred_opacity = pred_features_alpha[..., ray_feature_dim:].unsqueeze(0).contiguous() pred_dist = pred_dist.unsqueeze(0).contiguous() hits_count = hits_count.unsqueeze(0).contiguous() - pred_rgb, pred_opacity = gaussians.background( - gpu_batch.T_to_world.contiguous(), rays_d, pred_rgb, pred_opacity, train - ) - timings = self.tracer_wrapper.collect_times() return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, - "pred_normals": torch.nn.functional.normalize(torch.ones_like(pred_rgb), dim=3), + "pred_normals": torch.nn.functional.normalize(torch.ones_like(pred_features), dim=3), "hits_count": hits_count, "frame_time_ms": timings["forward_render"] if "forward_render" in timings else 0.0, "mog_visibility": mog_visibility, From 9ab4e6e29074a46769e0ee5814832f5991544164 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 21 Apr 2026 14:07:27 -0400 Subject: [PATCH 4/9] Add Ray Features and Particle Features FP16 support for 3DGRT Fwd --- .../3dgrt/kernels/cuda/3dgrtTracer.cuh | 2 +- threedgrt_tracer/include/3dgrt/optixTracer.h | 8 +-- .../include/3dgrt/pipelineParameters.h | 38 +++++++++--- .../include/3dgrt/tensorBuffering.h | 7 ++- .../include/playground/pipelineParameters.h | 4 +- threedgrt_tracer/setup_3dgrt.py | 1 + .../kernels/cuda/barycentricSurfelsOptix.cu | 8 +-- .../kernels/cuda/referenceB2FSlangBwdOptix.cu | 8 +-- .../src/kernels/cuda/referenceBwdOptix.cu | 8 +-- .../src/kernels/cuda/referenceOptix.cu | 8 +-- .../kernels/cuda/referenceSlangBwdOptix.cu | 28 +++++---- .../src/kernels/cuda/referenceSlangOptix.cu | 14 +++-- threedgrt_tracer/src/optixTracer.cpp | 62 +++++++++++++------ threedgrt_tracer/tracer.py | 6 +- .../include/playground/hybridTracer.h | 2 +- .../include/playground/kernels/cuda/trace.cuh | 24 +++---- threedgrut_playground/src/hybridTracer.cpp | 6 +- .../src/kernels/cuda/3dgrtKernel.cu | 6 +- 18 files changed, 154 insertions(+), 86 deletions(-) diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh index 8ce3ff60..8bcbf56b 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh @@ -186,7 +186,7 @@ static __device__ __inline__ void traceVolumetricGS( rayDirection, hitWeight, rayHit.particleId, - params.particleRadiance, // void* - generic buffer pointer + params.particleFeatures, // void* - generic buffer pointer params.sphDegree, // auxiliary parameter (sphDegree for SH, unused for learned) rayData.features); // float* - generic output array diff --git a/threedgrt_tracer/include/3dgrt/optixTracer.h b/threedgrt_tracer/include/3dgrt/optixTracer.h index f84463b3..13469eb3 100644 --- a/threedgrt_tracer/include/3dgrt/optixTracer.h +++ b/threedgrt_tracer/include/3dgrt/optixTracer.h @@ -144,7 +144,7 @@ class OptixTracer { torch::Tensor rayOri, torch::Tensor rayDir, torch::Tensor particleDensity, - torch::Tensor particleRadiance, + torch::Tensor particleFeatures, uint32_t renderOpts, int sphDegree, float minTransmittance); @@ -153,13 +153,13 @@ class OptixTracer { torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor rayRad, + torch::Tensor rayFeat, torch::Tensor rayDns, torch::Tensor rayHit, torch::Tensor rayNrm, torch::Tensor particleDensity, - torch::Tensor particleRadiance, - torch::Tensor rayRadGrd, + torch::Tensor particleFeatures, + torch::Tensor rayFeatGrd, torch::Tensor rayDnsGrd, torch::Tensor rayHitGrd, torch::Tensor rayNrmGrd, diff --git a/threedgrt_tracer/include/3dgrt/pipelineParameters.h b/threedgrt_tracer/include/3dgrt/pipelineParameters.h index d9a267bf..ce8e6c1a 100644 --- a/threedgrt_tracer/include/3dgrt/pipelineParameters.h +++ b/threedgrt_tracer/include/3dgrt/pipelineParameters.h @@ -21,17 +21,41 @@ #include <3dgrt/pipelineDefinitions.h> #include <3dgrt/tensorAccessor.h> +// Per-particle feature storage element type. fp16 when PARTICLE_FEATURE_HALF=1, else fp32. +// Gradient buffer for per-particle features is always fp32. +#ifndef PARTICLE_FEATURE_HALF +#define PARTICLE_FEATURE_HALF 0 +#endif +// Per-ray integrated feature output element type. fp16 when FEATURE_OUTPUT_HALF=1, else fp32. +// Gradient buffer for the integrated feature output is always fp32. +#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif +#if PARTICLE_FEATURE_HALF || FEATURE_OUTPUT_HALF +#include +#endif +#if PARTICLE_FEATURE_HALF +using TParticleFeatureElem = __half; +#else +using TParticleFeatureElem = float; +#endif +#if FEATURE_OUTPUT_HALF +using TRayFeatureElem = __half; +#else +using TRayFeatureElem = float; +#endif + struct PipelineParameters { float4 rayToWorld[3]; ///< float3x4 ray to world transformation (row-major) PackedTensorAccessor32 rayOrigin; ///< ray origin PackedTensorAccessor32 rayDirection; ///< ray direction - const ParticleDensity* particleDensity; ///< position, scale, quaternions, density - const float* particleRadiance; ///< spherical harmonics coefficients - const void* particleExtendedData; ///< pipeline specific particle data - int32_t* particleVisibility; ///< pipeline specific particle data + const ParticleDensity* particleDensity; ///< position, scale, quaternions, density + const TParticleFeatureElem* particleFeatures; ///< per-particle features (fp16 when PARTICLE_FEATURE_HALF) + const void* particleExtendedData; ///< pipeline specific particle data + int32_t* particleVisibility; ///< pipeline specific particle data - PackedTensorAccessor32 rayRadiance; ///< output integrated ray radiance + PackedTensorAccessor32 rayFeatures; ///< integrated ray features (fp16 when FEATURE_OUTPUT_HALF) PackedTensorAccessor32 rayDensity; ///< output integrated ray density PackedTensorAccessor32 rayHitDistance; ///< output integrated ray hit distance PackedTensorAccessor32 rayNormal; ///< output integrated ray normal @@ -92,11 +116,11 @@ struct PipelineParameters { }; struct PipelineBackwardParameters : PipelineParameters { - PackedTensorAccessor32 rayRadianceGrad; ///< integrated ray radiance gradient + PackedTensorAccessor32 rayFeaturesGrad; ///< integrated ray features gradient (fp32) PackedTensorAccessor32 rayDensityGrad; ///< integrated ray density gradient PackedTensorAccessor32 rayHitDistanceGrad; ///< integrated ray hit distance gradient PackedTensorAccessor32 rayNormalGrad; ///< integrated ray hit distance gradient ParticleDensity* particleDensityGrad; ///< output position, scale, quaternions, density gradient - float* particleRadianceGrad; ///< output spherical harmonics coefficients gradient + float* particleFeaturesGrad; ///< per-particle features gradient (fp32) }; diff --git a/threedgrt_tracer/include/3dgrt/tensorBuffering.h b/threedgrt_tracer/include/3dgrt/tensorBuffering.h index 23454342..458d6ffe 100644 --- a/threedgrt_tracer/include/3dgrt/tensorBuffering.h +++ b/threedgrt_tracer/include/3dgrt/tensorBuffering.h @@ -33,13 +33,18 @@ inline scalar_t* getPtr(torch::Tensor tensor) { return reinterpret_cast(tensor.contiguous().data_ptr()); } else if (tensor.dtype() == torch::kFloat32) { return reinterpret_cast(tensor.contiguous().data_ptr()); + } else if (tensor.dtype() == torch::kHalf) { + return reinterpret_cast(tensor.contiguous().data_ptr()); } else { throw std::runtime_error("getPtr(tensor) received a tensor of unsupported type"); } } +// Note: `reinterpret_cast` is used on the raw `tensor.data_ptr()` so that this +// template instantiates for element types without a PyTorch scalar_type entry +// (e.g. `__half`). Dtype agreement is the caller's responsibility. template class PtrTraits = DefaultPtrTraits> PackedTensorAccessor32 packed_accessor32(torch::Tensor tensor) { - return PackedTensorAccessor32(static_cast::PtrType>(tensor.data_ptr()), + return PackedTensorAccessor32(reinterpret_cast::PtrType>(tensor.data_ptr()), tensor.sizes().data(), tensor.strides().data()); } \ No newline at end of file diff --git a/threedgrt_tracer/include/playground/pipelineParameters.h b/threedgrt_tracer/include/playground/pipelineParameters.h index d25a6ec8..3967a252 100644 --- a/threedgrt_tracer/include/playground/pipelineParameters.h +++ b/threedgrt_tracer/include/playground/pipelineParameters.h @@ -52,8 +52,8 @@ struct PlaygroundTracingParams : PipelineParameters { // rayOri -> rayOrigin // rayDir -> rayDirection // mogPos, mogRot, mogScl, mogDns -> particleDensity - // mogSph-> particleRadiance - // rayRad -> rayRadiance + // mogSph-> particleFeatures + // rayRad -> rayFeatures // rayDns -> rayDensity // rayHit -> rayHitDistance // rayHitsCount -> rayHitsCount diff --git a/threedgrt_tracer/setup_3dgrt.py b/threedgrt_tracer/setup_3dgrt.py index e878089d..97eb8531 100644 --- a/threedgrt_tracer/setup_3dgrt.py +++ b/threedgrt_tracer/setup_3dgrt.py @@ -50,6 +50,7 @@ def to_cpp_bool(value): # Compiler options. Same -D for feature dims so pipelineParameters.h and JIT OptiX pipeline (generateDefines) stay in sync. cflags = [ *transform_defines, + *half_defines, ] cuda_flags = [ # Feature-based radiance dimensions (must match Slang compilation) diff --git a/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu b/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu index 872b79ce..bf8291d1 100644 --- a/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu @@ -134,7 +134,7 @@ extern "C" __global__ void __raygen__rg() { float3 sphCoefficients[SPH_MAX_NUM_COEFFS]; fetchParticleSphCoefficients( rayHit.particleId, - params.particleRadiance, + params.particleFeatures, &sphCoefficients[0]); const float3 rayParticleRadiance = radianceFromSpH(params.sphDegree, &sphCoefficients[0], rayDirection); @@ -161,9 +161,9 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayRadiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayRadiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayRadiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; diff --git a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu index 7459ff42..f31cd6aa 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu @@ -101,7 +101,7 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + float3 rayRadiance = make_float3(params.rayFeatures[idx.z][idx.y][idx.x][0], params.rayFeatures[idx.z][idx.y][idx.x][1], params.rayFeatures[idx.z][idx.y][idx.x][2]); float rayTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -110,7 +110,7 @@ extern "C" __global__ void __raygen__rg() { float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + float3 rayRadianceGrad = make_float3(params.rayFeaturesGrad[idx.z][idx.y][idx.x][0], params.rayFeaturesGrad[idx.z][idx.y][idx.x][1], params.rayFeaturesGrad[idx.z][idx.y][idx.x][2]); float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -148,8 +148,8 @@ extern "C" __global__ void __raygen__rg() { // rayHit.particleId, // (ParticleDensity_0*)params.particleDensity, // (ParticleDensity_0*)params.particleDensityGrad, - // (float*)params.particleRadiance, - // (float*)params.particleRadianceGrad, + // (float*)params.particleFeatures, + // (float*)params.particleFeaturesGrad, // params.hitMinGaussianResponse, // params.alphaMinThreshold, // PipelineParameters::ParticleKernelDegree, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu index 57b2a50b..2fcd80fa 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu @@ -109,12 +109,12 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayIntegratedRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + float3 rayIntegratedRadiance = make_float3(params.rayFeatures[idx.z][idx.y][idx.x][0], params.rayFeatures[idx.z][idx.y][idx.x][1], params.rayFeatures[idx.z][idx.y][idx.x][2]); float rayIntegratedTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayIntegratedHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + float3 rayRadianceGrad = make_float3(params.rayFeaturesGrad[idx.z][idx.y][idx.x][0], params.rayFeaturesGrad[idx.z][idx.y][idx.x][1], params.rayFeaturesGrad[idx.z][idx.y][idx.x][2]); float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; @@ -147,8 +147,8 @@ extern "C" __global__ void __raygen__rg() { rayHit.particleId, params.particleDensity, params.particleDensityGrad, - params.particleRadiance, - params.particleRadianceGrad, + params.particleFeatures, + params.particleFeaturesGrad, params.hitMinGaussianResponse, params.alphaMinThreshold, params.minTransmittance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu index bfba5e91..caecaebf 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu @@ -141,7 +141,7 @@ extern "C" __global__ void __raygen__rg() { rayDirection, rayHit.particleId, params.particleDensity, - params.particleRadiance, + params.particleFeatures, params.hitMinGaussianResponse, params.alphaMinThreshold, params.sphDegree, @@ -169,9 +169,9 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayRadiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayRadiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayRadiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu index c9cc4f0b..1eb31a31 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu @@ -110,10 +110,14 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - FixedArray rayRadiance; + FixedArray rayFeatures; #pragma unroll for (int i = 0; i < RAY_FEATURE_DIM; i++) { - rayRadiance[i] = params.rayRadiance[idx.z][idx.y][idx.x][i]; +#if FEATURE_OUTPUT_HALF + rayFeatures[i] = __half2float(params.rayFeatures[idx.z][idx.y][idx.x][i]); +#else + rayFeatures[i] = params.rayFeatures[idx.z][idx.y][idx.x][i]; +#endif } float rayTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; @@ -123,10 +127,10 @@ extern "C" __global__ void __raygen__rg() { float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - FixedArray rayRadianceGrad; + FixedArray rayFeaturesGrad; #pragma unroll for (int i = 0; i < RAY_FEATURE_DIM; i++) { - rayRadianceGrad[i] = params.rayRadianceGrad[idx.z][idx.y][idx.x][i]; + rayFeaturesGrad[i] = params.rayFeaturesGrad[idx.z][idx.y][idx.x][i]; } float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; @@ -172,14 +176,14 @@ extern "C" __global__ void __raygen__rg() { ) ) { - FixedArray hitRadiance; + FixedArray hitFeatures; particleFeaturesFromBuffer( rayHit.particleId, - const_cast(params.particleRadiance), + const_cast(params.particleFeatures), (int)params.sphDegree, rayDirection, canonicalIntersection, - &hitRadiance); + &hitFeatures); float hitAlphaGrad = 0.f; float3 canonicalIntersectionGrad = make_float3(0.f); @@ -189,13 +193,13 @@ extern "C" __global__ void __raygen__rg() { hitAlpha, &hitAlphaGrad, rayHit.particleId, - const_cast(params.particleRadiance), - const_cast(params.particleRadianceGrad), + const_cast(params.particleFeatures), + const_cast(params.particleFeaturesGrad), (int)params.sphDegree, false, // exclusiveGradient: multiple rays can hit same particle - hitRadiance, - &rayRadiance, - &rayRadianceGrad); + hitFeatures, + &rayFeatures, + &rayFeaturesGrad); particleDensityProcessHitBwdToBuffer(rayOrigin, rayDirection, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu index ed2b34ef..a0fb6cb1 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu @@ -109,10 +109,10 @@ extern "C" __global__ void __raygen__rg() { float3 rayOrigin = params.rayWorldOrigin(idx); float3 rayDirection = params.rayWorldDirection(idx); - FixedArray rayRadiance; + FixedArray rayFeatures; #pragma unroll for (int i = 0; i < RAY_FEATURE_DIM; i++) { - rayRadiance[i] = 0.0f; + rayFeatures[i] = 0.0f; } float rayTransmittance = 1.0f; float rayHitDistance = 0.f; @@ -160,9 +160,9 @@ extern "C" __global__ void __raygen__rg() { canonicalIntersection, hitWeight, rayHit.particleId, - const_cast(params.particleRadiance), + const_cast(params.particleFeatures), params.sphDegree, - &rayRadiance); + &rayFeatures); // NOTE(qi): Race condition here, but as we are writing the same value, it seems it is safe. if (hitWeight > 0.f) { @@ -180,7 +180,11 @@ extern "C" __global__ void __raygen__rg() { #pragma unroll for (int i = 0; i < RAY_FEATURE_DIM; i++) { - params.rayRadiance[idx.z][idx.y][idx.x][i] = rayRadiance[i]; +#if FEATURE_OUTPUT_HALF + params.rayFeatures[idx.z][idx.y][idx.x][i] = __float2half(rayFeatures[i]); +#else + params.rayFeatures[idx.z][idx.y][idx.x][i] = rayFeatures[i]; +#endif } params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; diff --git a/threedgrt_tracer/src/optixTracer.cpp b/threedgrt_tracer/src/optixTracer.cpp index 746b513f..1c43e4b9 100644 --- a/threedgrt_tracer/src/optixTracer.cpp +++ b/threedgrt_tracer/src/optixTracer.cpp @@ -219,6 +219,15 @@ std::vector OptixTracer::generateDefines( defines.emplace_back("-DPARTICLE_FEATURE_DIM=" + std::to_string(PipelineParameters::ParticleFeatureDim)); defines.emplace_back("-DRAY_FEATURE_DIM=" + std::to_string(PipelineParameters::RayFeatureDim)); defines.emplace_back("-DFEATURE_TRANSFORM_TYPE=" + std::to_string(PipelineParameters::FeatureTransformType)); + // Half-precision flags: must match what the extension build sees so the + // PipelineParameters struct layout and the Slang-generated header agree + // with the NVRTC-compiled OptiX kernels. + defines.emplace_back("-DPARTICLE_FEATURE_HALF=" + std::to_string(PARTICLE_FEATURE_HALF)); + defines.emplace_back("-DFEATURE_OUTPUT_HALF=" + std::to_string(FEATURE_OUTPUT_HALF)); +#if PARTICLE_FEATURE_HALF || FEATURE_OUTPUT_HALF + // Enable `__half` in the Slang-generated CUDA prelude. + defines.emplace_back("-DSLANG_CUDA_ENABLE_HALF=1"); +#endif } return defines; } @@ -360,9 +369,21 @@ void OptixTracer::createPipeline(const OptixDeviceContext context, std::string optix_include_dir = dependencies_path + "/dependencies/optix-dev/include"; std::string cuda_include_dir = cuda_path + "/include"; + // Some CUDA distributions (e.g. conda's cuda-toolkit) place per-target + // SDK headers such as under $CUDA_HOME/targets//include + // rather than directly under $CUDA_HOME/include. Thread that extra + // search path to NVRTC when present so Slang-generated CUDA that uses + // __half compiles. + std::vector extra_includes_with_cuda_targets = extra_includes; +#if defined(__x86_64__) || defined(_M_X64) + extra_includes_with_cuda_targets.push_back(cuda_path + "/targets/x86_64-linux/include"); +#elif defined(__aarch64__) + extra_includes_with_cuda_targets.push_back(cuda_path + "/targets/sbsa-linux/include"); +#endif + const char* input = getInputData(shaderFile.c_str(), includeDir.c_str(), optix_include_dir.c_str(), cuda_include_dir.c_str(), kernel_name.c_str(), inputSize, defines, - (const char**)&log, extra_includes); + (const char**)&log, extra_includes_with_cuda_targets); size_t sizeof_log = sizeof(log); OPTIX_CHECK_LOG(optixModuleCreateFromPTX( @@ -861,13 +882,18 @@ OptixTracer::trace(uint32_t frameNumber, torch::Tensor rayOri, torch::Tensor rayDir, torch::Tensor particleDensity, - torch::Tensor particleRadiance, + torch::Tensor particleFeatures, uint32_t renderOpts, int sphDegree, float minTransmittance) { const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - torch::Tensor rayRad = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), static_cast(PipelineParameters::RayFeatureDim)}, opts); +#if FEATURE_OUTPUT_HALF + const torch::TensorOptions rayFeatOpts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); +#else + const torch::TensorOptions rayFeatOpts = opts; +#endif + torch::Tensor rayFeat = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), static_cast(PipelineParameters::RayFeatureDim)}, rayFeatOpts); torch::Tensor rayDns = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 1}, opts); torch::Tensor rayHit = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 2}, opts); torch::Tensor rayNrm = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 3}, opts); @@ -893,11 +919,11 @@ OptixTracer::trace(uint32_t frameNumber, paramsHost.rayDirection = packed_accessor32(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); paramsHost.particleVisibility = getPtr(particleVisibility); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayFeat); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); @@ -910,12 +936,12 @@ OptixTracer::trace(uint32_t frameNumber, reinterpret_cast(_state->paramsDevice), ¶msHost, sizeof(paramsHost), cudaMemcpyHostToDevice, cudaStream)); OPTIX_CHECK(optixLaunch(_state->pipelineTracingFwd, cudaStream, _state->paramsDevice, - sizeof(PipelineParameters), &_state->sbtTracingFwd, rayRad.size(2), - rayRad.size(1), rayRad.size(0))); + sizeof(PipelineParameters), &_state->sbtTracingFwd, rayFeat.size(2), + rayFeat.size(1), rayFeat.size(0))); CUDA_CHECK_LAST(); - return std::tuple(rayRad, rayDns, rayHit, rayNrm, rayHitsCount, particleVisibility); + return std::tuple(rayFeat, rayDns, rayHit, rayNrm, rayHitsCount, particleVisibility); } std::tuple @@ -923,13 +949,13 @@ OptixTracer::traceBwd(uint32_t frameNumber, torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor rayRad, + torch::Tensor rayFeat, torch::Tensor rayDns, torch::Tensor rayHit, torch::Tensor rayNrm, torch::Tensor particleDensity, - torch::Tensor particleRadiance, - torch::Tensor rayRadGrd, + torch::Tensor particleFeatures, + torch::Tensor rayFeatGrd, torch::Tensor rayDnsGrd, torch::Tensor rayHitGrd, torch::Tensor rayNrmGrd, @@ -939,7 +965,7 @@ OptixTracer::traceBwd(uint32_t frameNumber, const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor particleDensityGrad = torch::zeros({particleDensity.size(0), particleDensity.size(1)}, opts); - torch::Tensor particleRadianceGrad = torch::zeros({particleRadiance.size(0), particleRadiance.size(1)}, opts); + torch::Tensor particleFeaturesGrad = torch::zeros({particleFeatures.size(0), particleFeatures.size(1)}, opts); PipelineBackwardParameters paramsHost; paramsHost.handle = _state->gasHandle; @@ -960,18 +986,18 @@ OptixTracer::traceBwd(uint32_t frameNumber, paramsHost.rayDirection = packed_accessor32(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayFeat); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); paramsHost.particleDensityGrad = getPtr(particleDensityGrad); - paramsHost.particleRadianceGrad = getPtr(particleRadianceGrad); + paramsHost.particleFeaturesGrad = getPtr(particleFeaturesGrad); - paramsHost.rayRadianceGrad = packed_accessor32(rayRadGrd); + paramsHost.rayFeaturesGrad = packed_accessor32(rayFeatGrd); paramsHost.rayDensityGrad = packed_accessor32(rayDnsGrd); paramsHost.rayHitDistanceGrad = packed_accessor32(rayHitGrd); paramsHost.rayNormalGrad = packed_accessor32(rayNrmGrd); @@ -984,7 +1010,7 @@ OptixTracer::traceBwd(uint32_t frameNumber, OPTIX_CHECK(optixLaunch(_state->pipelineTracingBwd, cudaStream, _state->paramsDevice, sizeof(PipelineBackwardParameters), &_state->sbtTracingBwd, - rayRad.size(2), rayRad.size(1), rayRad.size(0))); + rayFeat.size(2), rayFeat.size(1), rayFeat.size(0))); - return std::tuple(particleDensityGrad, particleRadianceGrad); + return std::tuple(particleDensityGrad, particleFeaturesGrad); } diff --git a/threedgrt_tracer/tracer.py b/threedgrt_tracer/tracer.py index e143489c..349f84c2 100644 --- a/threedgrt_tracer/tracer.py +++ b/threedgrt_tracer/tracer.py @@ -231,7 +231,11 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): gaussians.get_rotation().contiguous(), gaussians.get_scale().contiguous(), gaussians.get_density().contiguous(), - gaussians.get_features().contiguous(), # TODO: cast to .half() when conf.render.particle_feature_half + ( + gaussians.get_features().half().contiguous() + if self.conf.render.particle_feature_half + else gaussians.get_features().contiguous() + ), Tracer.RenderOpts.DEFAULT, gaussians.n_active_features, self.conf.render.min_transmittance, diff --git a/threedgrut_playground/include/playground/hybridTracer.h b/threedgrut_playground/include/playground/hybridTracer.h index 07444f2c..87d2d769 100644 --- a/threedgrut_playground/include/playground/hybridTracer.h +++ b/threedgrut_playground/include/playground/hybridTracer.h @@ -129,7 +129,7 @@ class HybridOptixTracer : public OptixTracer { torch::Tensor> traceHybrid(uint32_t frameNumber, torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor particleDensity, torch::Tensor particleRadiance, + torch::Tensor particleDensity, torch::Tensor particleFeatures, int sphDegree, float minTransmittance, torch::Tensor rayMaxT, uint32_t playgroundOpts, torch::Tensor triangles, torch::Tensor vNormals, torch::Tensor vTangents, diff --git a/threedgrut_playground/include/playground/kernels/cuda/trace.cuh b/threedgrut_playground/include/playground/kernels/cuda/trace.cuh index 717b9a8f..85130499 100644 --- a/threedgrut_playground/include/playground/kernels/cuda/trace.cuh +++ b/threedgrut_playground/include/playground/kernels/cuda/trace.cuh @@ -119,9 +119,9 @@ static __device__ __forceinline__ void clearOutputBuffers() { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] = 0.0f; - params.rayRadiance[idx.z][ry][rx][1] = 0.0f; - params.rayRadiance[idx.z][ry][rx][2] = 0.0f; + params.rayFeatures[idx.z][ry][rx][0] = 0.0f; + params.rayFeatures[idx.z][ry][rx][1] = 0.0f; + params.rayFeatures[idx.z][ry][rx][2] = 0.0f; params.rayDensity[idx.z][ry][rx][0] = 0.0f; } @@ -133,9 +133,9 @@ writeRadianceDensityToOutputBuffer(float4 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] = radiance.x; - params.rayRadiance[idx.z][ry][rx][1] = radiance.y; - params.rayRadiance[idx.z][ry][rx][2] = radiance.z; + params.rayFeatures[idx.z][ry][rx][0] = radiance.x; + params.rayFeatures[idx.z][ry][rx][1] = radiance.y; + params.rayFeatures[idx.z][ry][rx][2] = radiance.z; params.rayDensity[idx.z][ry][rx][0] = radiance.w; } @@ -147,9 +147,9 @@ accumulateRadianceToOutputBuffer(float3 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] += radiance.x; - params.rayRadiance[idx.z][ry][rx][1] += radiance.y; - params.rayRadiance[idx.z][ry][rx][2] += radiance.z; + params.rayFeatures[idx.z][ry][rx][0] += radiance.x; + params.rayFeatures[idx.z][ry][rx][1] += radiance.y; + params.rayFeatures[idx.z][ry][rx][2] += radiance.z; } static __device__ __forceinline__ void @@ -160,9 +160,9 @@ accumulateRadianceDensityToOutputBuffer(float4 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] += radiance.x; - params.rayRadiance[idx.z][ry][rx][1] += radiance.y; - params.rayRadiance[idx.z][ry][rx][2] += radiance.z; + params.rayFeatures[idx.z][ry][rx][0] += radiance.x; + params.rayFeatures[idx.z][ry][rx][1] += radiance.y; + params.rayFeatures[idx.z][ry][rx][2] += radiance.z; params.rayDensity[idx.z][ry][rx][0] += radiance.w; } diff --git a/threedgrut_playground/src/hybridTracer.cpp b/threedgrut_playground/src/hybridTracer.cpp index 6cc95306..05683265 100644 --- a/threedgrut_playground/src/hybridTracer.cpp +++ b/threedgrut_playground/src/hybridTracer.cpp @@ -314,7 +314,7 @@ std::tuple(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayRad); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); diff --git a/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu b/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu index 9fdd2b4e..a48976cf 100644 --- a/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu +++ b/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu @@ -28,9 +28,9 @@ extern "C" __global__ void __raygen__rg() { traceVolumetricGS(rayData, rayOrigin, rayDirection, 1e-9, 1e9); - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayData.radiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayData.radiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayData.radiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayData.radiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayData.radiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayData.radiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = rayData.density; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayData.hitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayData.rayLastHitDistance; From bb2dc9c491104dc5bed845e0337d44fe008d5df4 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 21 Apr 2026 15:45:04 -0400 Subject: [PATCH 5/9] 3DGUT: drop canonicalIntersection from SH hit struct via templated specialization `HitParticleT` splits the k-buffer hit record into an SH base and an NHT derived that adds the `canonicalIntersection` float3. The renderer aliases pick the right one from `Params::PerRayParticleFeatures`, so the SH path pays no storage for a field it never reads. `densityHit` routes its out-arg through `canonicalIntersectionSlot` (struct field or stack scratch). The forward feature ternary becomes `if constexpr` so the SH specialization does not name-lookup a missing member. Made-with: Cursor --- .../cuda/renderers/gutKBufferRenderer.cuh | 77 ++++++++++++++++--- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index 5a432d9a..c948910b 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -18,17 +18,52 @@ #include <3dgut/kernels/cuda/common/rayPayloadBackward.cuh> #include <3dgut/renderer/gutRendererParameters.h> -struct HitParticle { +// HitParticle stores per-ray hit state inside the k-buffer. The +// `canonicalIntersection` is only consumed by the NHT feature path; for the +// SH path (`PerRayParticleFeatures=false`) we drop it from the struct so the +// k-buffer pays no storage cost in registers / shared memory per hit. +template +struct HitParticleT; + +// Base / SH case: common per-hit state only. +template <> +struct HitParticleT { static constexpr float InvalidHitT = -1.0f; int idx = -1; float hitT = InvalidHitT; float alpha = 0.0f; - float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); }; -template -struct HitParticleKBuffer { - __device__ HitParticleKBuffer() { +// NHT case extends with the per-ray canonical intersection point used by the +// feature evaluation. Plain (non-virtual) inheritance: layout is base fields +// then the extra `float3`, matching the pre-refactor combined struct exactly. +template <> +struct HitParticleT : HitParticleT { + float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); +}; + +// Helper returning a writable `float3&` slot usable as the `densityHit` +// canonical-intersection out-parameter. Routes to the struct field when +// present, otherwise to a caller-provided stack scratch (which the compiler +// elides when unused downstream). +template +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& hit, float3& scratch); + +template <> +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& hit, float3& /*scratch*/) { + return hit.canonicalIntersection; +} + +template <> +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& /*hit*/, float3& scratch) { + return scratch; +} + +template +struct HitParticleKBufferT { + using HitParticle = HitParticleT; + + __device__ HitParticleKBufferT() { m_numHits = 0; #pragma unroll for (int i = 0; i < K; ++i) { @@ -76,8 +111,9 @@ private: uint32_t m_numHits; }; -template <> -struct HitParticleKBuffer<0> { +template +struct HitParticleKBufferT<0, HasCanonical> { + using HitParticle = HitParticleT; constexpr inline __device__ void insert(HitParticle& hitParticle) const {} constexpr inline __device__ HitParticle operator[](int) const { return HitParticle(); } constexpr inline __device__ uint32_t numHits() const { return 0; } @@ -95,6 +131,11 @@ struct GUTKBufferRenderer : Params { using TRayPayload = RayPayload; using TRayPayloadBackward = RayPayloadBackward; + // Storage-optimized hit types: the per-ray canonical intersection is kept + // only when the feature model needs it (NHT / PerRayParticleFeatures). + using HitParticle = HitParticleT; + using HitParticleKBuffer = HitParticleKBufferT; + struct PrefetchedParticleData { uint32_t idx; DensityParameters densityParameters; @@ -162,9 +203,18 @@ struct GUTKBufferRenderer : Params { hitParticle.hitT, ray.hitT); - particles.featureIntegrateFwd(hitWeight, - Params::PerRayParticleFeatures ? particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection) : tcnn::max(particleFeatures[hitParticle.idx], 0.f), - ray.features); + // `if constexpr` branches so the SH specialization of Hit (which + // has no `canonicalIntersection` member) is not instantiated with + // a missing field reference. + if constexpr (Params::PerRayParticleFeatures) { + particles.featureIntegrateFwd(hitWeight, + particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection), + ray.features); + } else { + particles.featureIntegrateFwd(hitWeight, + tcnn::max(particleFeatures[hitParticle.idx], 0.f), + ray.features); + } if (hitWeight > 0.0f) ray.countHit(); } @@ -232,7 +282,7 @@ struct GUTKBufferRenderer : Params { using namespace threedgut; __shared__ PrefetchedParticleData prefetchedParticlesData[GUTParameters::Tiling::BlockSize]; - HitParticleKBuffer hitParticleKBuffer; + HitParticleKBuffer hitParticleKBuffer; for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { @@ -265,12 +315,15 @@ struct GUTKBufferRenderer : Params { HitParticle hitParticle; hitParticle.idx = particleData.idx; + // `canonicalScratch` is only written when the SH specialization + // of Hit has no canonicalIntersection field; compiler elides it. + float3 canonicalScratch = make_float3(0.f, 0.f, 0.f); if (particles.densityHit(ray.origin, ray.direction, particleData.densityParameters, hitParticle.alpha, hitParticle.hitT, - hitParticle.canonicalIntersection) && + canonicalIntersectionSlot(hitParticle, canonicalScratch)) && (hitParticle.hitT > ray.tMinMax.x) && (hitParticle.hitT < ray.tMinMax.y)) { From f8509098b2ec28d6de9d553a1cea59f1971f537d Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Mon, 27 Apr 2026 19:30:26 -0400 Subject: [PATCH 6/9] Align NHT reference behavior and refinement flow Updates NHT encoding, tetrahedral interpolation, alpha handling, initialization, and refinement behavior so benchmark runs are closer to the GSplat reference while preserving SH defaults. --- TODO.md | 23 ++ TODO_half_3dgrt.md | 144 +++++++++++ TODO_nht_cuda.md | 73 ++++++ configs/apps/colmap_3dgrt_mcmc_nht.yaml | 1 + configs/apps/colmap_3dgut_mcmc_nht.yaml | 3 + .../apps/nerf_synthetic_3dgrt_mcmc_nht.yaml | 1 + .../apps/nerf_synthetic_3dgut_mcmc_nht.yaml | 3 + configs/base_gs.yaml | 6 +- plan/nht_bwd_reg_reduction.md | 183 +++++++++++++ plan/nht_reference_results.md | 10 + .../3dgrt/kernels/cuda/gaussianParticles.cuh | 4 +- .../neuralHarmonicFeaturesParticle.slang | 31 ++- threedgrut/model/feature_decoder.py | 13 +- threedgrut/model/features.py | 5 +- threedgrut/model/model.py | 39 +-- threedgrut/render.py | 4 + threedgrut/trainer.py | 65 ++++- .../kernels/cuda/models/gaussianParticles.cuh | 241 ++++++++++++++++- .../models/shRadiativeGaussianParticles.cuh | 44 ++++ .../cuda/renderers/gutKBufferRenderer.cuh | 21 ++ .../kernels/cuda/renderers/gutRenderer.cuh | 13 +- .../neuralHarmonicFeaturesParticle.slang | 31 ++- threedgut_tracer/setup_3dgut.py | 7 + validate.py | 244 ++++++++++++++++++ 24 files changed, 1149 insertions(+), 60 deletions(-) create mode 100644 TODO.md create mode 100644 TODO_half_3dgrt.md create mode 100644 TODO_nht_cuda.md create mode 100644 plan/nht_bwd_reg_reduction.md create mode 100644 plan/nht_reference_results.md create mode 100644 validate.py diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..50839c85 --- /dev/null +++ b/TODO.md @@ -0,0 +1,23 @@ +# TODO + +## 3DGRT: half-precision particle features + +`conf.render.particle_feature_half` is compiled into the 3DGRT kernel via `-DPARTICLE_FEATURE_HALF` +but the Python-side cast is missing. In `threedgrt_tracer/tracer.py`, `gaussians.get_features()` +must be cast to `.half()` before being passed to `_Autograd.apply` when the flag is set, +matching what 3DGUT already does. + +See the `TODO` comment in `threedgrt_tracer/tracer.py`. + +## 3DGRT: NHT support in CUDA path (`gaussianParticles.cuh`) + +The NHT feature transform (`FEATURE_TRANSFORM_TYPE=1`) is implemented for the Slang path +(`gaussianParticles.slang`) but not yet in the CUDA path (`gaussianParticles.cuh`). +Full NHT support in 3DGRT requires extending `gaussianParticles.cuh` with the NHT +interpolation and activation logic currently only present in the Slang kernel. + +## 3DGUT: refactor `evalBackwardNoKBuffer` to share path with k-buffer backward + +`evalBackwardNoKBuffer` (`gutKBufferRenderer.cuh`) duplicates logic from the k-buffer backward +path. The two should be unified into a shared implementation to reduce code duplication and +ensure future fixes apply to both. diff --git a/TODO_half_3dgrt.md b/TODO_half_3dgrt.md new file mode 100644 index 00000000..db386c60 --- /dev/null +++ b/TODO_half_3dgrt.md @@ -0,0 +1,144 @@ +# 3DGRT half-precision feature support + +Goal: make `conf.render.particle_feature_half=true` and `conf.render.feature_output_half=true` +work end-to-end in `threedgrt_tracer`, matching the behavior already implemented in +`threedgut_tracer`. Gradient buffers remain fp32 on both paths. + +Semantics (mirroring 3dgut): +- `particle_feature_half=true`: storage for `particleRadiance` (per-particle feature buffer) + is fp16. Slang entry points already expect `feat_elem_t*` (`__half*` when the macro is set). + Gradient `particleRadianceGrad` stays fp32. +- `feature_output_half=true`: storage for the per-ray integrated feature buffer (`rayRadiance`) + is fp16. Gradient `rayRadianceGrad` stays fp32. Tracer `.forward()` casts fp16 back to fp32 + before returning, mirroring 3dgut. + +## Scope + +Files to touch (by layer): + +- C++ pipeline type layer + - `threedgrt_tracer/include/3dgrt/pipelineParameters.h` + Introduce `TFeatureDensityElem` (output/ray feature) and `TParticleFeatureElem` (particle + storage) typedefs, guarded on the two macros. Change `particleRadiance` from `const float*` + to `const TParticleFeatureElem*`, and `rayRadiance` from + `PackedTensorAccessor32` to `PackedTensorAccessor32`. + `particleRadianceGrad` and `rayRadianceGrad` stay fp32. +- OptiX raygen kernels + - `threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu` + - `threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu` + 1. Replace `const_cast(params.particleRadiance)` with + `const_cast(params.particleRadiance)`. + 2. FWD write to `rayRadiance`: wrap with `__float2half` when `FEATURE_OUTPUT_HALF`. + 3. BWD read from `rayRadiance`: wrap with `__half2float` when `FEATURE_OUTPUT_HALF`. + `rayRadianceGrad` reads stay fp32. +- Host launcher + - `threedgrt_tracer/src/optixTracer.cpp` + 1. `trace()`: allocate `rayRad` with `torch::kHalf` when `FEATURE_OUTPUT_HALF=1`, + build `packed_accessor32(rayRad)`, and + `getPtr(particleRadiance)`. + 2. `traceBwd()`: same dtype for the forward `rayRad` input; the grad tensors remain fp32. +- Python tracer + - `threedgrt_tracer/tracer.py` + 1. Cast `gaussians.get_features()` to `.half()` when `conf.render.particle_feature_half`. + 2. Keep `ray_features.float()` return to caller; `ray_features` saved in ctx may be fp16 + when `feature_output_half=true` (already saves the raw output, consistent with 3dgut). + +No changes required in Slang `.slang` or generated `.cuh`: the generalization already landed +and compiles correctly once `SLANG_CUDA_ENABLE_HALF=1` is set (done). + +## Task breakdown + +Each task is independently reviewable and testable (run validate.py for the relevant flag +combinations after each). + +### T1 — Introduce typedefs in `pipelineParameters.h` +- Add `TFeatureDensityElem` and `TParticleFeatureElem` (guarded by `FEATURE_OUTPUT_HALF` and + `PARTICLE_FEATURE_HALF`), include `cuda_fp16.h` when either is set. +- Change `particleRadiance` to `const TParticleFeatureElem*` and `rayRadiance` accessor to + `PackedTensorAccessor32`. +- No functional change when both macros are 0 (typedefs resolve to `float`). +- Test: build with both flags false (current default) → no-op rebuild; CI NeRF-Synthetic 3dgrt + smoke test still passes. + +### T2 — Update OptiX kernels for fp16 reads/writes +- Apply the `__float2half` / `__half2float` wrappers in `referenceSlangOptix.cu` and + `referenceSlangBwdOptix.cu` under `FEATURE_OUTPUT_HALF`. +- Update `const_cast` sites to `TParticleFeatureElem*`. +- Test: build with both flags false → identical numerical output to baseline (no wrappers + compiled in). + +### T3 — Host buffer allocation and accessor typing +- `optixTracer.cpp`: select dtype `kHalf` vs `kFloat32` for `rayRad`; use + `packed_accessor32(rayRad)`. +- `getPtr(particleRadiance)` for the particle buffer. +- Test: with flags false → unchanged; build-time assert that tensor dtype matches the + typedef via `TORCH_CHECK(rayRad.scalar_type() == ...)` in DEBUG. + +### T4 — Python cast for `particle_feature_half` +- `tracer.py`: mirror 3dgut's conditional `.half()` cast on `gaussians.get_features()`. +- Test: flags false → unchanged. + +### T5 — End-to-end validation with flags enabled +- Run `validate.py` with `render.particle_feature_half=true render.feature_output_half=true` + using an existing NHT config (e.g. `nerf_synthetic_3dgrt_mcmc_nht.yaml`). +- Compare PSNR after N iterations against the fp32 baseline — expected within 0.1 dB. +- Gradients: single backward pass on a fixed seed; check that + `particleRadianceGrad` and `rayRadianceGrad` are finite and within tolerance of the + fp32 reference. + +### T6 — Rename `*Radiance*` → `*Features*` in 3dgrt +Naming cleanup to align with the post-SH NHT feature abstraction. The legacy `Radiance` +suffix comes from the SH-only era; the buffers now carry arbitrary per-particle / per-ray +features. Purely mechanical rename, no behavioral change. Runs AFTER T1–T5 land so we are +not also chasing name drift during the fp16 functional work. + +Rename mapping (all scopes): +- `PipelineParameters::particleRadiance` → `particleFeatures` +- `PipelineParameters::rayRadiance` → `rayFeatures` +- `PipelineBackwardParameters::particleRadianceGrad` → `particleFeaturesGrad` +- `PipelineBackwardParameters::rayRadianceGrad` → `rayFeaturesGrad` +- `OptixTracer::trace(..., torch::Tensor particleRadiance, ...)` arg → `particleFeatures` +- `OptixTracer::traceBwd(..., torch::Tensor particleRadiance, rayRad, rayRadGrd, ...)` args + → `particleFeatures`, `rayFeat`, `rayFeatGrd` (local tensors + Python side kwargs). +- `particleRadianceGrad` local in `optixTracer.cpp::traceBwd` → `particleFeaturesGrad`. +- Python: `tracer.py` local variables `ray_features` / `ray_features_grd` are already + feature-named; cross-check that the pybind11 binding signature in `bindings.cpp` uses + the new C++ arg names. + +Out of scope for T6 (per resolved decisions above): +- `particleRadianceSphDegree` C++ field and `conf.render.particle_radiance_sph_degree` YAML. +- `shRadiativeParticles.slang` filename and internal `shRadiance*` identifiers (SH path). +- Any `*Radiance*` identifiers that only exist on the SH-specific code path. + +Test: +- Build + full `validate.py` run with fp32 flags (both false) → identical numerical + output to pre-T6 baseline (bit-identical expected since only identifier renames). +- Build + `validate.py` with fp16 flags (both true) → identical output to T5 result. + +## Tests to write up-front + +- `tests/test_3dgrt_half_flags.py` (new, small) + - Parametrize over `(particle_feature_half, feature_output_half) ∈ {(F,F),(T,F),(F,T),(T,T)}`. + - Forward only, single frame, fixed scene; compare `pred_features.float()` to the (F,F) + baseline with `atol=5e-3, rtol=1e-2`. + - Forward + backward; compare `mog_sph.grad` to the (F,F) baseline at the same tolerance. + +## Decisions (resolved with user) + +1. T5 validation ownership: user runs validation; the plan only needs to keep the hooks in + place (no tolerance tuning required from the implementer). +2. Gradient buffers stay fp32 end-to-end (no half-grad path). +3. T6 rename scope is restricted to identifiers naming buffers that can carry NHT features + (i.e. the per-particle feature storage and per-ray integrated feature output, plus their + fp32 gradients). Scalars and SH-specific paths are NOT renamed: + - keep `particleRadianceSphDegree` (C++ field) and `conf.render.particle_radiance_sph_degree` + (YAML) — scalar, shared with the SH path + - keep `shRadiativeParticles.slang` filename and its internal `shRadiance*` identifiers — + SH-only code path. +4. T6 runs AFTER T1–T5. + +## Non-goals + +- No changes to CUDA fallback path (`gaussianParticles.cuh`) — per the existing TODO that is + a separate workstream. +- No changes to `threedgrt_playground`. diff --git a/TODO_nht_cuda.md b/TODO_nht_cuda.md new file mode 100644 index 00000000..031799a1 --- /dev/null +++ b/TODO_nht_cuda.md @@ -0,0 +1,73 @@ +# Handwritten CUDA port of `featuresIntegrateBwdToLocalGrad` (NHT path) + +## Status +- [x] **T1** — Tetrahedron constants (`tetraV0`, `tetraN0..N3`) placed in + `nht_detail` namespace at the top of `shRadiativeGaussianParticles.cuh`. + Values derived from Slang's vertex ordering; verified via script + (w_k == 1 at v_k, 0 at other vertices). +- [x] **T2** — Method body replaced, gated by `#if NHT_FEATURES_BWD_LOCAL_GRAD_CUDA`. + Default = `1` (native CUDA). Flip to `0` in + `threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh` + to restore the Slang-autodiff path (kept unchanged in the `#else` branch). +- [ ] **T3** — Rebuild, run `validate.py` (or one training step) with the macro + at 1 vs 0. Compare: + - feature gradient buffer L2 (primary parity check) + - density / position gradients (sanity; should be identical since we don't + touch those paths) + - `renderBackward` ms in nsys. +- [ ] **T4** — If parity holds, keep default = 1. Otherwise flip to 0 and iterate. + +## What the handwritten CUDA does (semantics to match Slang exactly) + +Replicates the sequence inside Slang's `particleFeaturesIntegrateBwdToBuffer` +called with `exclusiveGradient=true` and the shifted `featureLocalGrad` buffer: + +1. Early-out when `alpha <= 0`. +2. Recover pre-hit accumulator: + `acc_prev[i] = (integratedFeatures[i] - features[i]*alpha) / (1-alpha)`. +3. VJP of back-to-front `y_i = (1-alpha)*acc_prev_i + alpha*f_i` against + incoming `dy = integratedFeaturesGrad`: + - `dFeatures[i] = alpha * dy_i` + - `alphaGrad += sum_i (features[i] - acc_prev[i]) * dy_i` + - `integratedFeaturesGrad[i] = (1-alpha) * dy_i` (new accumulator grad) +4. Barycentric weights `w[0..3]` from `canonicalIntersection` (Cramer form + matching Slang, precomputed `N_k` face normals). +5. Load 4 vertex feature blocks × `InterpPointFeatureDim` once + (`__half2float` when `PARTICLE_FEATURE_HALF=1`). +6. Activation backward → `dBase[InterpPointFeatureDim]`: + | Activation | Forward | Backward | + |---|---|---| + | None (0) | `out = base` | `dBase = dFeatures` | + | Siren (1) | `sin(base * 2^f)` | `dBase += cos(base*freq) * freq * dOut` | + | Sincos (2) | `sin + cos` | `dBase += (cos - sin) * freq * dOut` | + | Relu (3) | `max(0, base)` | `dBase = (features[i] > 0) ? dFeatures[i] : 0` | +7. Barycentric backward: + - `featureLocalGrad[k*IPFD + i] += w[k] * dBase[i]` (matches Slang's `+=` with exclusiveGradient=true) + - `canonicalIntersectionGrad += sum_k (sum_i vert[k][i] * dBase[i]) * N_k` + +## Guardrails +- `static_assert(FeatureTransformType == 1)` — NHT-only. +- `static_assert(FEATURE_INTERPOLATION_TYPE == 0)` — barycentric only. +- `static_assert(FEATURE_INTERPOLATION_SUPPORT == 1)` — tetrahedra only. +- `static_assert` on `RAY_FEATURE_DIM` / `INTERP_POINT_FEATURE_DIM` / activation consistency. +- `static_assert(4 * IPFD == ParticleFeatureDim)` — buffer layout. + +Any unsupported config fails at compile time — fallback is to flip the macro to 0. + +## Confidence + +- **Forward parity** (interpolation + integration, current config `activation=relu`): + high (see comparison with `neural-harmonic-textures/Interpolation.cuh` — same + tetrahedron geometry, different indexing; same integration math). +- **Backward numerical parity**: medium-high. The Relu path is trivial. The + (1-α)/α lerp VJP + barycentric VJP is standard. Main risk is a sign or + vertex-index swap — covered by T3 gradient diff. +- **Perf win**: medium-high. Expected 3–5× on this single kernel. + +## Open reference points + +- Slang source: `threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang` +- External CUDA ref: `/nv/dev/neural-harmonic-textures/gsplat/gsplat/cuda/csrc/RasterizeToPixelsFromWorldNHT3DGSBwd.cu` + (sincos activation; do NOT copy the activation bwd verbatim — see + "Caveats" in the forward-parity discussion: Slang's sincos sums into one + channel, ref's keeps them separate). diff --git a/configs/apps/colmap_3dgrt_mcmc_nht.yaml b/configs/apps/colmap_3dgrt_mcmc_nht.yaml index fb53696d..74b9fc55 100644 --- a/configs/apps/colmap_3dgrt_mcmc_nht.yaml +++ b/configs/apps/colmap_3dgrt_mcmc_nht.yaml @@ -14,6 +14,7 @@ model: render: pipeline_type: referenceSlang backward_pipeline_type: referenceSlangBwd + particle_kernel_max_alpha: 0.999 loss: use_opacity: true diff --git a/configs/apps/colmap_3dgut_mcmc_nht.yaml b/configs/apps/colmap_3dgut_mcmc_nht.yaml index e06d8f0f..551c8e0e 100644 --- a/configs/apps/colmap_3dgut_mcmc_nht.yaml +++ b/configs/apps/colmap_3dgut_mcmc_nht.yaml @@ -11,6 +11,9 @@ defaults: model: feature_type: "nht" +render: + particle_kernel_max_alpha: 0.999 + loss: use_opacity: true lambda_opacity: 0.02 diff --git a/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml index f398179f..9230ba43 100644 --- a/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml +++ b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml @@ -14,6 +14,7 @@ model: render: pipeline_type: referenceSlang backward_pipeline_type: referenceSlangBwd + particle_kernel_max_alpha: 0.999 loss: use_opacity: true diff --git a/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml index 7b8307c8..3b06d71d 100644 --- a/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml +++ b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml @@ -11,6 +11,9 @@ defaults: model: feature_type: "nht" +render: + particle_kernel_max_alpha: 0.999 + loss: use_opacity: true lambda_opacity: 0.02 diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 5bf31edf..73703d87 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -70,8 +70,10 @@ model: nht_features: dim: 48 + init_min: -1.5707963267948966 + init_max: 1.5707963267948966 activation: - type: "siren" + type: "sincos" num_frequencies: 1 interpolation_type: "barycentric" @@ -83,6 +85,7 @@ model: dir_encoding_degree: 3 sh_scale: 3.0 output_activation: "Sigmoid" + unpremultiply_alpha: false learning_rate: 0.00068 reg_weight: 0.0 scheduler: @@ -91,6 +94,7 @@ model: max_steps: 30000 ema_decay: 0.95 ema_start_step: 0 + color_refine_steps: 3000 background: name: background-color # one of [skip-background, background-color] diff --git a/plan/nht_bwd_reg_reduction.md b/plan/nht_bwd_reg_reduction.md new file mode 100644 index 00000000..4e8d3424 --- /dev/null +++ b/plan/nht_bwd_reg_reduction.md @@ -0,0 +1,183 @@ +# Plan — NHT backward register-pressure reduction + +## Context + +`renderBackward` (NHT + PARTICLE_FEATURE_HALF=1, FEATURE_OUTPUT_HALF=1) is 60 ms +while fwd is 4 ms (15×). Two diagnostics established the root cause: + +1. **A/B: commenting the body of `featuresIntegrateBwdToLocalGrad` changes nothing.** + Commenting `featureLocalGradWarpReduceAndWrite` makes the kernel 2× faster. + This is not an atomic-contention signal (reference + `RasterizeToPixelsFromWorldNHT3DGSBwd.cu` uses the same 48-atomic pattern + and is fast) — it is DCE freeing register pressure upstream. + +2. **ptxas stats (sm_90):** + + | kernel | regs/thread | spills | smem | barriers | + |------------------|-------------|--------|-----------|----------| + | `renderBackward` | **214** | 0 | 17 408 B | 1 | + | `render` (fwd) | 64 | 0 | 17 408 B | 1 | + + Actual `BlockSize = 16 × 16 = 256` (see `gutRendererParameters.h`), not 128 + as I first assumed. With 65 536 regs/SM on sm_90, reg cliffs at 256 + threads/block are: + - 255 regs → 1 block/SM (at 214 today → 1 block/SM, occupancy ~12.5 %) + - 128 regs → 2 blocks/SM (25 %) + - 85 regs → 3 blocks/SM (~37 %) + - 64 regs → 4 blocks/SM (50 %) + + At 214 regs the kernel fits ~1 block/SM → effective occupancy ~12.5 %, + memory latency cannot be hidden. + (If `FINE_GRAINED_LOAD_BALANCING=true`, blocks are 128 thr — shifts the + cliffs: 128r→4blk, 96r→5blk, 64r→8blk. Confirm which path is active.) + +## Goal + +Drive `renderBackward` regs/thread below **128** (ideally **≤ 96**) while +keeping gradients bit-equivalent to the current CUDA-integrated path +(compile-time toggle `NHT_FEATURES_BWD_LOCAL_GRAD_CUDA=1`). + +Measurable success criteria: +- Primary: `renderBackward` wall time reduced ≥ 25 % on the current scene. +- Secondary: regs/thread reported by ptxas ≤ 128. +- Parity: feature/density gradient L2 vs pre-change reference < 1e-6 rel. + +Non-goal: architectural rework of the renderer or Slang call surface. + +## Task list (strictly ordered, each task independently reversible) + +### T1 — Attribution of the register budget + +Goal: measure how much each piece of live state costs in regs, so we know +which fix is worth pursuing. Pure diagnostic, no behavioral change. + +- **T1.a**: build with `-Xptxas=-v -res-usage` permanently enabled behind an + env var (`export NHT_PTXAS_VERBOSE=1`) via `setup_3dgut.py`. Archive + the current 214-reg baseline. +- **T1.b**: temporarily gut `featureLocalGradWarpReduceAndWrite`'s body + (keep signature, compile-time `#if 0` the shuffles + atomics). Rebuild. + Record new regs + ms. +- **T1.c**: additionally gut our `featuresIntegrateBwdToLocalGrad` body. + Rebuild. Record. +- **T1.d**: additionally replace the Slang `densityProcessHitBwdToBuffer` + call with a no-op. Rebuild. Record. + +Deliverable: a 5-row table (baseline + T1.a .. T1.d) of (regs, smem, ms). +Tells us exactly where the 214 regs are parked. + +### T2 — Target fix 1: stage `featureLocalGrad[]` into `__shared__` + +Hypothesis: moving the 48-float per-thread scratch out of registers and +into per-thread shmem slots will free up to 48 regs/thread, with negligible +runtime overhead (shmem LD/ST ≈ register throughput for small arrays). + +- **T2.a** (implementation): add a `__shared__ float sFeatureLocalGrad[BlockSize][PARTICLE_FEATURE_DIM]` + in the KBuffer renderer's bwd path. Change the caller site in + `gutKBufferRenderer.cuh` to pass `sFeatureLocalGrad[tileThreadIdx]` + (plus a clear of that row at top of each `j` iteration). +- **T2.b** (integrate): update the callee signature in + `shRadiativeGaussianParticles.cuh::featuresIntegrateBwdToLocalGrad` and + in `threedgut::nht::featuresIntegrateBwdToLocalGrad` (already takes a + `float*`, so just documentation + assume shmem aliasing). +- **T2.c** (reduce): rewrite `featureLocalGradWarpReduceAndWrite` to reduce + from shmem rather than thread-local registers. Verify `__shfl_xor_sync` + still works (it does — operates on any register value; we first load + shmem row into a single thread-private scalar per iteration). +- **T2.d** (smem budget check): new smem = `17 408 + BlockSize × 48 × 4` + bytes. For `BlockSize=256` that is +49 152 B → 66 560 B/block. H100 has + 228 KB dynamic smem/SM → 3 blocks/SM by smem. Fine only if we hit ≥ 2 + blocks/SM by regs too (i.e. reg cliff ≤ 128). Record. +- **T2.e** (bench): rebuild, capture ptxas regs and `renderBackward` ms. + Table row. +- **T2.f** (parity): run `validate.py` or equivalent; dump feature + gradient buffer L2 vs baseline. Must be ≤ 1e-6 relative. + +Expected: regs 214 → ~166 (–48). Probably still above 128 cliff; need T3. + +### T3 — Target fix 2: `__launch_bounds__` + controlled spill + +Hypothesis: if T2 alone does not cross the 128-reg cliff, force the +compiler to cap regs with `__launch_bounds__(BlockSize, minBlocksPerSM)`. +This may introduce local-memory spills, but occupancy gain beats spill +cost when kernel is latency-bound (our case). + +- **T3.a**: apply `__launch_bounds__(256, 2)` to the `renderBackward` + kernel entry in `gutRenderer.cuh`. Regs forced ≤ 128. Rebuild, record. +- **T3.b**: try `__launch_bounds__(256, 3)` (regs ≤ 85). Rebuild, record. +- **T3.c**: try `__launch_bounds__(256, 4)` (regs ≤ 64). Rebuild, record. + (If `FINE_GRAINED_LOAD_BALANCING=true` the block size is 128; adjust + the first arg accordingly.) +- **T3.d** (parity): only behavioral change is scheduling → gradients + must be bit-identical. Confirm L2 = 0. + +Table rows: regs, spill stores, spill loads, ms per configuration. + +### T4 — Pick winner + +Decision table from T2/T3 data: + +| config | regs | spills | ms | parity | pick? | +|----------------|------|--------|-------|--------|-------| +| baseline | 214 | 0 | 60 | ref | — | +| T2 (shmem) | ? | 0 | ? | yes | ? | +| T3a (128,4) | 128 | ? | ? | yes | ? | +| T3b (128,5) | 96 | ? | ? | yes | ? | +| T2 + T3a | ? | ? | ? | yes | ? | +| T2 + T3b | ? | ? | ? | yes | ? | + +Pick the lowest ms with spill stores low enough to not dominate L1 traffic. + +### T5 — (conditional) Prefetch features into `__shared__` à la reference + +Only if T4 falls short of the 25 % target. This is the +`RasterizeToPixelsFromWorldNHT3DGSBwd.cu` pattern: load the 48-float +feature block for all particles in the batch into shmem **once**, then +per-hit reads are shmem broadcasts. Cuts redundant global loads (currently +done twice per hit — inside `featuresFromBuffer` and again inside our +`featuresIntegrateBwdToLocalGrad`) and reduces register churn from callee +boundaries. + +Bigger change: ~60 lines in `gutKBufferRenderer.cuh` inner batch loop, +plus new shmem size `BlockSize × 48 × sizeof(TFeatElem)` (= 12 KB @ fp16, +24 KB @ fp32). Defer until T4 signals it is needed. + +### T6 — Validation & cleanup + +- **T6.a**: final regs + ms comparison table. +- **T6.b**: full gradient parity on ≥ 2 training steps. +- **T6.c**: revert `NHT_PTXAS_VERBOSE` default (env-gated, stays available). +- **T6.d**: document the choice + knob(s) in `TODO_nht_cuda.md`. + +## Test harness (used by every Tx) + +One repeatable command invocation that captures both ptxas output and +kernel ms. Pseudo-code for `bench.sh`: + +``` +rm ~/.cache/torch_extensions/.../gutRenderer.cuda.o +NHT_PTXAS_VERBOSE=1 python validate.py --iters 50 --nsys ... +grep 'registers\|spill\|renderBackward' /tmp/build_and_run.log +``` + +Produces one row of the comparison table per Tx invocation. + +## Confidence + +- T1 (diagnostic): 95 %. Already have one data point, filling the matrix is mechanical. +- T2 alone buys 30–48 regs: 70 %. +- T2 alone reaches 128-reg cliff: 40 %. +- T3 forces the cliff and gains 25 %+: 60 %. +- Combined (T2+T3) reaches 50 % speedup: 50 %. + +## Open questions (for your review before execution) + +1. Is the validation harness OK as above, or do you have a preferred + benchmarking script I should wire into `bench.sh`? +2. `BlockSize` — I assumed 128. Worth checking the actual tiling config + used by the NHT path; the smem budget math depends on it. +3. Parity tolerance: 1e-6 relative is arbitrary. Tighter (bit-exact) is + achievable for T2 and T3 since they do not change the math. Want me + to require bit-exact? +4. If register pressure is mostly coming from the Slang-exported + `densityProcessHitBwdToBuffer` (T1.d will tell), we may have to + port that too. That is outside this plan — flag as follow-up if so. diff --git a/plan/nht_reference_results.md b/plan/nht_reference_results.md new file mode 100644 index 00000000..3f2f7c3b --- /dev/null +++ b/plan/nht_reference_results.md @@ -0,0 +1,10 @@ +# NHT Reference Results + +## Bonsai Comparison + +| Run | Source | PSNR | SSIM | LPIPS | CC PSNR | CC SSIM | CC LPIPS | Time (ms/image) | +| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | +| GSplat reference | GSplat NHT reference run | 34.163 | 0.9542 | 0.235 | - | - | - | 11.000 | +| 3DGRUT previous | Before benchmark-parity/color-refinement fixes | 33.427 | 0.949 | 0.248 | 33.559 | 0.947 | 0.247 | 2.489 | +| 3DGRUT updated | After benchmark-parity/color-refinement fixes | 33.734 | 0.951 | 0.246 | 33.908 | 0.949 | 0.246 | 2.455 | + diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh index 54a1fcbf..73afa688 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh @@ -372,7 +372,7 @@ __device__ inline bool processHit( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); const bool acceptHit = (gres > minParticleKernelDensity) && (galpha > minParticleAlpha); if (acceptHit) { @@ -513,7 +513,7 @@ __device__ inline void processHitBwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); if ((gres > minParticleKernelDensity) && (galpha > minParticleAlpha)) { ParticleDensity& particleDensityGrad = particleDensityGradPtr[particleIdx]; diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang index 0b426f42..efce40a9 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -17,7 +17,8 @@ // Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. // Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, // PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). -// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. With activation: N = interpPointDim * num_frequencies (decoder input). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. +// With sincos activation: N = interpPointDim * num_frequencies * 2 (separate sin/cos channels). namespace neuralHarmonicFeaturesParticle { @@ -42,8 +43,8 @@ static const int InterpolationSupport_Tetrahedra = 1; static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT; -// Canonical regular tetrahedron that contains the unit sphere; each face is tangent to the unit sphere -// (insphere radius = 1). Same geometry as particlePrimitives.cu (enclosing tetrahedra). +// Canonical regular tetrahedron matching the GSplat NHT reference layout: +// p0=(sqrt(6),-sqrt(2),-1), p1=(-sqrt(6),-sqrt(2),-1), p2=(0,2*sqrt(2),-1), p3=(0,0,3). // Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1. static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24) static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2 @@ -51,10 +52,10 @@ static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3) static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2) static const float tetraHedraInRadius = 1.0f; // unit sphere static const float3 canonicalTetraVerts[4] = { + float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f), - float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius), - float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f) + float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius) }; // Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical. static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0]; @@ -132,7 +133,7 @@ float4 barycentricTetrahedronCanonical(float3 P) return weights; } -// Encode and activate : none -> identity; siren -> sin(b*2^f); sincos -> sin(b*2^f)+cos(b*2^f); relu -> max(0,b). +// Encode and activate : none -> identity; siren -> sin(b*2^f); relu -> max(0,b). [BackwardDifferentiable][ForceInline] float encodeAndActivate(float baseVal, no_diff int f) { @@ -142,9 +143,7 @@ float encodeAndActivate(float baseVal, no_diff int f) return max(0.0f, baseVal); float freq = ldexp(1.0f, f); float angle = baseVal * freq; - if (FeatureActivationType == FeatureActivationType_Siren) - return sin(angle); - return sin(angle) + cos(angle); + return sin(angle); } // Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim]. @@ -174,6 +173,20 @@ void featuresFromParametersBuffer(ParametersBuffer parametersBuffer, [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { features[i] = baseFeatures[i]; } + } else if (FeatureActivationType == FeatureActivationType_Relu) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = max(0.0f, baseFeatures[i]); + } + } else if (FeatureActivationType == FeatureActivationType_Sincos) { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + float freq = float(f + 1); + float angle = baseFeatures[k] * freq; + int outIdx = k * FeatureActivationNumFrequencies * 2 + f * 2; + features[outIdx + 0] = sin(angle); + features[outIdx + 1] = cos(angle); + } + } } else { [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { diff --git a/threedgrut/model/feature_decoder.py b/threedgrut/model/feature_decoder.py index 2680060e..351275e0 100644 --- a/threedgrut/model/feature_decoder.py +++ b/threedgrut/model/feature_decoder.py @@ -36,6 +36,7 @@ def __init__( output_activation: str = "Sigmoid", ema_decay: float = 0.0, ema_start_step: int = 0, + unpremultiply_alpha: bool = False, ): """Initialize the feature decoder. @@ -51,6 +52,7 @@ def __init__( output_activation: Output layer activation ("Sigmoid" for [0,1] RGB, or "ReLU") ema_decay: If > 0, keep EMA shadow of parameters (decay factor). 0 = no EMA. ema_start_step: Global step at which to start updating EMA. + unpremultiply_alpha: If True, decode features / alpha then multiply RGB by alpha. """ super().__init__() self.ray_feature_dim = ray_feature_dim @@ -58,6 +60,7 @@ def __init__( self.num_layers = num_layers self.sh_scale = sh_scale self.output_activation = output_activation + self.unpremultiply_alpha = unpremultiply_alpha self._ema_decay = ema_decay self._ema_start_step = ema_start_step self._ema_shadow: dict[str, torch.Tensor] = {} @@ -146,8 +149,7 @@ def forward( Args: features: Input features of shape [H*W, N] or [B, H, W, N] (alpha-blended). ray_directions: Ray directions of shape [H*W, 3] or [B, H, W, 3]. - alpha: Optional [H*W] or [B, H, W] opacity. If given, un-multiply (features/alpha), - decode, then re-multiply (rgb*alpha) so the decoder sees unblended features. + alpha: Optional opacity used only when unpremultiply_alpha is enabled. Returns: RGB tensor of shape [H*W, 3] or [B, H, W, 3] @@ -182,7 +184,7 @@ def _process( ray_directions: torch.Tensor, alpha: torch.Tensor | None = None, ) -> torch.Tensor: - if alpha is not None: + if self.unpremultiply_alpha and alpha is not None: alpha_safe = alpha.clamp(min=1e-8) features = features / alpha_safe @@ -191,7 +193,7 @@ def _process( full_input = torch.cat([features, dirs_unit_cube], dim=-1) rgb = self.network(full_input) - if alpha is not None: + if self.unpremultiply_alpha and alpha is not None: rgb = rgb * alpha_safe return rgb.float() @@ -208,5 +210,6 @@ def extra_repr(self) -> str: f"hidden_dim={self.hidden_dim}, " f"num_layers={self.num_layers}, " f"sh_scale={self.sh_scale}, " - f"output_activation={self.output_activation}" + f"output_activation={self.output_activation}, " + f"unpremultiply_alpha={self.unpremultiply_alpha}" ) diff --git a/threedgrut/model/features.py b/threedgrut/model/features.py index aaf727a0..e8de2ca6 100644 --- a/threedgrut/model/features.py +++ b/threedgrut/model/features.py @@ -145,12 +145,13 @@ def interp_point_feature_dim(self): @property def ray_feature_dim(self): - """Per-ray feature dim (decoder input): interp_point_dim * num_frequencies.""" + """Per-ray feature dim after optional harmonic expansion.""" feature_type = self._conf.model.feature_type.lower() if feature_type == "sh": return 3 # RGB output elif feature_type == "nht": - return self.interp_point_feature_dim * self.activation_num_frequencies + expansion = 2 if self.activation_type == Features.ActivationType.SINCOS else 1 + return self.interp_point_feature_dim * self.activation_num_frequencies * expansion raise ValueError(f"Unknown feature_type: {feature_type}") @property diff --git a/threedgrut/model/model.py b/threedgrut/model/model.py index 460e97aa..b08e9cc9 100644 --- a/threedgrut/model/model.py +++ b/threedgrut/model/model.py @@ -549,15 +549,12 @@ def init_from_random_point_cloud( (num_gaussians, num_specular_features), dtype=dtype, device=self.device ).contiguous() elif self.feature_type == Features.Type.NHT: - # Initialize learned features with uniform [-pi/2, pi/2] for SIREN - act_type = Features(self.conf).activation_type - if act_type in (Features.ActivationType.SIREN, Features.ActivationType.SINCOS): - features = ( - torch.rand((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) - * 3.141592653589793 - 1.5707963267948966 - ) - else: - features = torch.randn((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) * 0.1 + init_min = float(getattr(self.conf.model.nht_features, "init_min", -5.0)) + init_max = float(getattr(self.conf.model.nht_features, "init_max", 5.0)) + features = ( + torch.rand((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) + * (init_max - init_min) + init_min + ) dist = torch.clamp_min(nearest_neighbor_dist_cpuKD(fused_point_cloud), 1e-3) scales = torch.log(dist * self.conf.model.default_scale_factor)[..., None].repeat(1, 3) @@ -725,24 +722,12 @@ def default_initialize_from_points(self, pts, observer_pts, colors=None, use_obs num_specular_dims = sh_degree_to_specular_dim(self.max_n_features) features_specular = torch.zeros((N, num_specular_dims)) elif self.feature_type == Features.Type.NHT: - act_type = Features(self.conf).activation_type - if act_type in (Features.ActivationType.SIREN, Features.ActivationType.SINCOS): - # Uniform [-pi/2, pi/2]: symmetric around 0 so sin activations start - # with zero mean and balanced positive/negative gradients. - features = ( - torch.rand((N, self.particle_feature_dim), dtype=dtype, device=self.device) - * 3.141592653589793 # pi - - 1.5707963267948966 # - pi/2 - ) - elif act_type == Features.ActivationType.RELU: - # Same as SH degree 0: init from RGB so relu(features)=radiance in [0,1]. - rgb = (colors.float() / 255.0).to(dtype=dtype, device=self.device) - if self.particle_feature_dim == 3: - features = rgb - else: - features = rgb.repeat(1, (self.particle_feature_dim + 2) // 3)[:, : self.particle_feature_dim] - else: - features = torch.randn((N, self.particle_feature_dim), dtype=dtype, device=self.device) * 0.01 + init_min = float(getattr(self.conf.model.nht_features, "init_min", -5.0)) + init_max = float(getattr(self.conf.model.nht_features, "init_max", 5.0)) + features = ( + torch.rand((N, self.particle_feature_dim), dtype=dtype, device=self.device, generator=rng) + * (init_max - init_min) + init_min + ) self.positions = torch.nn.Parameter(positions.to(dtype=dtype, device=self.device)) self.rotation = torch.nn.Parameter(rots.to(dtype=dtype, device=self.device)) diff --git a/threedgrut/render.py b/threedgrut/render.py index 18d69bb1..ee2f7fa0 100644 --- a/threedgrut/render.py +++ b/threedgrut/render.py @@ -163,7 +163,9 @@ def from_checkpoint( num_layers = getattr(dec, "num_layers", 4) dir_encoding = getattr(dec, "dir_encoding", "SphericalHarmonics") dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) + sh_scale = getattr(dec, "sh_scale", 1.0) output_activation = getattr(dec, "output_activation", "Sigmoid") + unpremultiply_alpha = getattr(dec, "unpremultiply_alpha", False) ema_decay = getattr(dec, "ema_decay", 0.0) ema_start_step = getattr(dec, "ema_start_step", 0) feature_decoder = FeatureDecoder( @@ -172,9 +174,11 @@ def from_checkpoint( num_layers=num_layers, dir_encoding=dir_encoding, dir_encoding_degree=dir_encoding_degree, + sh_scale=sh_scale, output_activation=output_activation, ema_decay=ema_decay, ema_start_step=ema_start_step, + unpremultiply_alpha=unpremultiply_alpha, ).to("cuda") feature_decoder.load_state_dict(checkpoint["feature_decoder"]["module"]) ema_state = checkpoint["feature_decoder"].get("ema") diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 88307454..30f03abf 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -85,6 +85,9 @@ class Trainer3DGRUT: _distillation_start_step: int = -1 """ Step at which distillation starts (-1 means disabled) """ + _color_refine_frozen_param_names = frozenset(("positions", "scale", "rotation", "density")) + """ Gaussian optimizer parameter groups frozen during NHT color refinement """ + @staticmethod def create_from_checkpoint(resume: str, conf: DictConfig): """Create a new trainer from a checkpoint file""" @@ -119,6 +122,10 @@ def __init__(self, conf: DictConfig, device=None): """ Total number of train epochs / passes, e.g. single pass over the dataset.""" self.val_frequency = conf.val_frequency """ Validation frequency, in terms on global steps """ + self._color_refine_start_step = self._get_color_refine_start_step(conf) + """ Step at which NHT color refinement starts """ + self._in_color_refine = False + """ Whether NHT color refinement is active """ # Setup the trainer and components logger.log_rule("Load Datasets") @@ -136,6 +143,50 @@ def __init__(self, conf: DictConfig, device=None): self.init_experiments_tracking(conf) self.init_gui(conf, self.model, self.train_dataset, self.val_dataset, self.scene_bbox) + def _get_color_refine_start_step(self, conf: DictConfig) -> int: + """Return the first step of the NHT color-only refinement phase.""" + feature_type = str(OmegaConf.select(conf, "model.feature_type", default="sh")).lower() + if feature_type != "nht": + return conf.n_iterations + + color_refine_steps = int(OmegaConf.select(conf, "model.nht_decoder.color_refine_steps", default=0) or 0) + if color_refine_steps <= 0: + return conf.n_iterations + + return max(0, conf.n_iterations - color_refine_steps) + + def _is_color_refine_active(self, global_step: int) -> bool: + return global_step >= self._color_refine_start_step and self._color_refine_start_step < self.conf.n_iterations + + def _apply_color_refine_freeze(self, global_step: int) -> None: + """Freeze Gaussian geometry/opacity optimizer groups while colors keep training.""" + if not self._is_color_refine_active(global_step): + return + + if not self._in_color_refine: + self._in_color_refine = True + self.strategy.suspend() + logger.info( + f"🎨 [step {global_step}] Entering NHT color refinement: " + "freezing geometry + opacity and disabling scale/opacity regularization." + ) + + if self.model.optimizer is None: + return + + for param_group in self.model.optimizer.param_groups: + if param_group.get("name") in self._color_refine_frozen_param_names: + param_group["lr"] = 0.0 + + def _zero_color_refine_frozen_grads(self) -> None: + if not self._in_color_refine or self.model.optimizer is None: + return + + for param_group in self.model.optimizer.param_groups: + if param_group.get("name") in self._color_refine_frozen_param_names: + for param in param_group["params"]: + param.grad = None + def init_dataloaders(self, conf: DictConfig): from threedgrut.datasets.utils import configure_dataloader_for_platform @@ -467,6 +518,7 @@ def init_feature_decoder(self, conf: DictConfig): dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) sh_scale = getattr(dec, "sh_scale", 1.0) output_activation = getattr(dec, "output_activation", "Sigmoid") + unpremultiply_alpha = getattr(dec, "unpremultiply_alpha", False) ema_decay = getattr(dec_conf, "ema_decay", 0.0) ema_start_step = getattr(dec_conf, "ema_start_step", 0) logger.info(f"Initializing FeatureDecoder: {ray_feature_dim} -> 3 RGB") @@ -480,6 +532,7 @@ def init_feature_decoder(self, conf: DictConfig): output_activation=output_activation, ema_decay=ema_decay, ema_start_step=ema_start_step, + unpremultiply_alpha=unpremultiply_alpha, ).to(self.device) lr = dec.learning_rate @@ -637,7 +690,7 @@ def get_losses( # Opacity regularization loss_opacity = torch.zeros(1, device=self.device) lambda_opacity = 0.0 - if self.conf.loss.use_opacity: + if self.conf.loss.use_opacity and not self._in_color_refine: with torch.cuda.nvtx.range(f"loss-opacity"): loss_opacity = torch.abs(self.model.get_density()).mean() lambda_opacity = self.conf.loss.lambda_opacity @@ -645,7 +698,7 @@ def get_losses( # Scale regularization loss_scale = torch.zeros(1, device=self.device) lambda_scale = 0.0 - if self.conf.loss.use_scale: + if self.conf.loss.use_scale and not self._in_color_refine: with torch.cuda.nvtx.range(f"loss-scale"): loss_scale = torch.abs(self.model.get_scale()).mean() lambda_scale = self.conf.loss.lambda_scale @@ -805,6 +858,8 @@ def log_training_iter( post_processing_reg_loss, global_step, ) + if self._color_refine_start_step < self.conf.n_iterations: + writer.add_scalar("train/color_refine", float(self._in_color_refine), global_step) if "psnr" in batch_metrics: writer.add_scalar("psnr/train", batch_metrics["psnr"], self.global_step) if "ssim" in batch_metrics: @@ -964,7 +1019,9 @@ def save_checkpoint(self, last_checkpoint: bool = False): "ray_feature_dim": dec.ray_feature_dim, "hidden_dim": dec.hidden_dim, "num_layers": dec.num_layers, + "sh_scale": dec.sh_scale, "output_activation": dec.output_activation, + "unpremultiply_alpha": dec.unpremultiply_alpha, }, } ema_state = self.feature_decoder.ema_state_dict() @@ -1028,6 +1085,8 @@ def run_train_iter( metrics: list, conf: DictConfig, ): + self._apply_color_refine_freeze(global_step) + # Freeze Gaussians and suspend strategy when distillation starts if self._distillation_start_step >= 0 and global_step >= self._distillation_start_step: self.model.freeze_gaussians() @@ -1108,6 +1167,7 @@ def run_train_iter( # Optimizer step with torch.cuda.nvtx.range(f"train_{global_step}_backprop"): + self._zero_color_refine_frozen_grads() if isinstance(self.model.optimizer, SelectiveAdam): assert ( outputs["mog_visibility"].shape == self.model.density.shape @@ -1120,6 +1180,7 @@ def run_train_iter( # Scheduler step with torch.cuda.nvtx.range(f"train_{global_step}_scheduler"): self.model.scheduler_step(global_step) + self._apply_color_refine_freeze(global_step) # Feature decoder optimizer/scheduler step if self.feature_decoder_optimizer is not None: diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh index 0d311552..7200ae48 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh @@ -15,6 +15,10 @@ #include <3dgut/kernels/cuda/common/mathUtils.cuh> +#if PARTICLE_FEATURE_HALF +#include +#endif + namespace threedgut { struct ParticleDensity { @@ -382,7 +386,7 @@ __device__ inline bool processHitFwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); const bool acceptHit = (gres > minParticleKernelDensity) && (galpha > minParticleAlpha); if (acceptHit) { @@ -525,7 +529,7 @@ __device__ inline void processHitBwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); if ((gres > minParticleKernelDensity) && (galpha > minParticleAlpha)) { @@ -746,4 +750,237 @@ __device__ inline void processHitBwd( } } +// ======================================================================================= +// Neural Harmonic Textures — handwritten CUDA backward. +// +// Mirrors `particleFeaturesIntegrateBwdToBuffer` from neuralHarmonicFeaturesParticle.slang +// called with `exclusiveGradient=true` and a shifted local-grad buffer (see +// `featuresIntegrateBwdToLocalGrad` in shRadiativeGaussianParticles.cuh for context). +// +// Scope: tetrahedral 4-vertex barycentric interpolation + {None, Siren, Sincos, Relu} +// activation, fp16/fp32 particle feature storage, fp32 gradients. +// ======================================================================================= +namespace nht { + +// Canonical tetrahedron matching the GSplat NHT reference vertex ordering. +// v0=(sqrt(6),-sqrt(2),-1), v1=(-sqrt(6),-sqrt(2),-1), v2=(0,2*sqrt(2),-1), v3=(0,0,3). +__forceinline__ __device__ float3 tetraV0() { + return make_float3(2.449489742783178f, -1.4142135623730951f, -1.0f); +} +// Scaled inward face normals N_k = d(w_k)/dP from the GSplat reference tetrahedron. +// Verified: w_k = 1 exactly at v_k, 0 at the other vertices, sum_k w_k = 1. +__forceinline__ __device__ float3 tetraN0() { + return make_float3( 0.20412414523193154f, -0.11785113019775792f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN1() { + return make_float3(-0.20412414523193154f, -0.11785113019775792f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN2() { + return make_float3( 0.00000000000000000f, 0.23570226039551587f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN3() { + return make_float3( 0.00000000000000000f, 0.00000000000000000f, 0.25000000000000000f); +} + +// Mirrors neuralHarmonicFeaturesParticle.slang's FeatureActivationType_*. +enum ActivationType : int { + ActivationNone = 0, + ActivationSiren = 1, + ActivationSincos = 2, + ActivationRelu = 3, +}; + +// Single-element load promoting fp16 to fp32 when needed; identity on float. +// Uses __half's device `operator float()` (cuda_fp16.h) when TFeatElem is __half. +template +__forceinline__ __device__ float loadFeatureElem(const TFeatElem* p) { + return static_cast(*p); +} + +/** + * Handwritten reverse of Slang's `particleFeaturesIntegrateBwdToBuffer` for the NHT path. + * Writes per-vertex gradient contributions directly into a thread-private `featureLocalGrad` + * scratch buffer of size `4*InterpPointFeatureDim` (caller warp-reduces + atomic-adds + * downstream — see `featureLocalGradWarpReduceAndWrite`). + * + * Buffer layout assumption: `vertex_k[i] = particleFeatureBufPtr[particleIdx*4*IPFD + k*IPFD + i]`. + * + * Walk order: renderer traverses hits front-to-back (near → far) for both fwd and bwd. + * Compositing algebra: Slang decomposes fwd as the equivalent back-to-front lerp + * C_i = alpha_i * f_i + (1 - alpha_i) * C_{i+1}, C_0 = forward-final color, C_N = 0. + * Each call to this function unwinds one lerp step of the current particle, advancing + * the stored accumulator from C_i to C_{i+1} — consistent with the front-to-back walk + * starting from `integratedFeatures = C_0` (set by initializeBackwardRay). + * + * Mutations (match Slang `particleFeaturesIntegrateBwdToBuffer` exactly): + * - integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) / (1 - alpha) + * (recovers C_{i+1} from C_i; seen by the next hit) + * - integratedFeaturesGrad[i] = (1 - alpha) * d(acc_post)_i + * (d(C_i)/d(C_{i+1}) from the lerp VJP) + * - alphaGrad += Σ_i (features[i] - acc_prev_i) * d(acc_post)_i + * - canonicalIntersectionGrad += barycentric VJP + * - featureLocalGrad[k*IPFD+i]+= w_k * dBase[i] (matches `exclusiveGradient=true`) + * + * No-op when `alpha <= 0`. + */ +template +__forceinline__ __device__ void featuresIntegrateBwdToLocalGrad( + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, + float alpha, + float& alphaGrad, + uint32_t particleIdx, + const float* features, // [RayFeatureDim] read-only + float* integratedFeatures, // [RayFeatureDim] inout (→ acc_prev) + float* integratedFeaturesGrad, // [RayFeatureDim] inout (→ d(acc_prev)) + const TFeatElem* particleFeatureBufPtr, // [N * 4*IPFD] + float* featureLocalGrad) { // [4*IPFD] inout (+= accumulator) + + static_assert(ActivationType >= 0 && ActivationType <= 3, + "unsupported FEATURE_ACTIVATION_TYPE for NHT CUDA backward"); + constexpr int kIPFD = InterpPointFeatureDim; + constexpr int kRFD = (ActivationType == ActivationNone || ActivationType == ActivationRelu) + ? kIPFD + : (ActivationType == ActivationSincos ? kIPFD * NumFrequencies * 2 : kIPFD * NumFrequencies); + constexpr int kPFD = 4 * kIPFD; + + if (alpha <= 0.0f) { + return; + } + + const float oneMinusAlpha = 1.0f - alpha; + const float invOneMinusAlpha = 1.0f / oneMinusAlpha; + + // Step 1+2: advance stored accumulator C_i -> C_{i+1} (one lerp unwind) AND apply the lerp VJP. + // fwd (lerp form): y_i = (1-alpha)*acc_prev_i + alpha*f_i + // unwind: acc_prev_i = (y_i - alpha*f_i) / (1-alpha) + // VJP: d(acc_prev)_i = (1-alpha) * dy_i + // d(f)_i = alpha * dy_i + // d(alpha) += sum_i (f_i - acc_prev_i) * dy_i + float dFeatures[kRFD]; + float dAlphaAcc = 0.0f; +#pragma unroll + for (int i = 0; i < kRFD; ++i) { + const float dy_i = integratedFeaturesGrad[i]; + const float accPrev_i = (integratedFeatures[i] - features[i] * alpha) * invOneMinusAlpha; + dFeatures[i] = alpha * dy_i; + dAlphaAcc += (features[i] - accPrev_i) * dy_i; + integratedFeatures[i] = accPrev_i; + integratedFeaturesGrad[i] = oneMinusAlpha * dy_i; + } + alphaGrad += dAlphaAcc; + + // Barycentric weights (Slang Cramer form). + float w[4]; + { + const float3 v0 = tetraV0(); + const float3 N1 = tetraN1(); + const float3 N2 = tetraN2(); + const float3 N3 = tetraN3(); + const float3 d = make_float3(canonicalIntersection.x - v0.x, + canonicalIntersection.y - v0.y, + canonicalIntersection.z - v0.z); + w[1] = d.x * N1.x + d.y * N1.y + d.z * N1.z; + w[2] = d.x * N2.x + d.y * N2.y + d.z * N2.z; + w[3] = d.x * N3.x + d.y * N3.y + d.z * N3.z; + w[0] = 1.0f - w[1] - w[2] - w[3]; + } + + // Load all 4 vertex feature blocks once (fp16 → fp32 if applicable). + const uint32_t particleOffset = particleIdx * kPFD; + float vert[4][kIPFD]; +#pragma unroll + for (int k = 0; k < 4; ++k) { + const uint32_t vkOff = particleOffset + k * kIPFD; +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + vert[k][i] = loadFeatureElem(&particleFeatureBufPtr[vkOff + i]); + } + } + + // Step 3: activation backward → dBase[kIPFD]. + float dBase[kIPFD]; + if constexpr (ActivationType == ActivationNone) { +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = dFeatures[i]; + } + } else if constexpr (ActivationType == ActivationRelu) { + // features[i] = max(0, base[i]); mask from features[i] > 0 (bit-identical to base[i] > 0). +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = (features[i] > 0.0f) ? dFeatures[i] : 0.0f; + } + } else { + // Siren / Sincos: need baseFeatures for the trig derivative. + float baseFeatures[kIPFD] = {}; +#pragma unroll + for (int k = 0; k < 4; ++k) { +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + baseFeatures[i] += w[k] * vert[k][i]; + } + } +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = 0.0f; + } + if constexpr (ActivationType == ActivationSiren) { + // features[k*NFreq+f] = sin(base[k] * 2^f) → dBase[k] += cos(angle)*freq*dOut +#pragma unroll + for (int k = 0; k < kIPFD; ++k) { +#pragma unroll + for (int f = 0; f < NumFrequencies; ++f) { + const float freq = ldexpf(1.0f, f); + const float angle = baseFeatures[k] * freq; + dBase[k] += cosf(angle) * freq * dFeatures[k * NumFrequencies + f]; + } + } + } else { + // ActivationSincos follows GSplat: separate sin/cos channels per frequency. +#pragma unroll + for (int k = 0; k < kIPFD; ++k) { +#pragma unroll + for (int f = 0; f < NumFrequencies; ++f) { + const float freq = static_cast(f + 1); + const float angle = baseFeatures[k] * freq; + float s, c; + __sincosf(angle, &s, &c); + const int outIdx = k * NumFrequencies * 2 + f * 2; + dBase[k] += (c * dFeatures[outIdx] - s * dFeatures[outIdx + 1]) * freq; + } + } + } + } + + // Step 4: barycentric backward. + // d(vertex_k[i]) = w_k * dBase[i] → featureLocalGrad (+= matches Slang exclusiveGradient=true) + // d(canonicalPos) += sum_k (sum_i vert_k[i]*dBase[i]) * N_k + float3 dP = make_float3(0.0f, 0.0f, 0.0f); +#pragma unroll + for (int k = 0; k < 4; ++k) { + float dw_k = 0.0f; +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + featureLocalGrad[k * kIPFD + i] += w[k] * dBase[i]; + dw_k += vert[k][i] * dBase[i]; + } + const float3 Nk = (k == 0) ? tetraN0() + : (k == 1) ? tetraN1() + : (k == 2) ? tetraN2() + : tetraN3(); + dP.x += dw_k * Nk.x; + dP.y += dw_k * Nk.y; + dP.z += dw_k * Nk.z; + } + canonicalIntersectionGrad.x += dP.x; + canonicalIntersectionGrad.y += dP.y; + canonicalIntersectionGrad.z += dP.z; +} + +} // namespace nht + } // namespace threedgut \ No newline at end of file diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh index fda4e956..497f4f61 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh @@ -20,6 +20,15 @@ #if PARTICLE_FEATURE_HALF #include #endif + +// Select backend for the NHT-path feature integration local-grad backward. +// 1 → handwritten CUDA (threedgut::nht::featuresIntegrateBwdToLocalGrad), +// 0 → Slang autodiff (threedgutSlang.cuh particleFeaturesIntegrateBwdToBuffer). +// Only effective when FEATURE_TRANSFORM_TYPE == 1 (NHT). A/B switch for perf work. +#ifndef NHT_FEATURES_BWD_LOCAL_GRAD_CUDA +#define NHT_FEATURES_BWD_LOCAL_GRAD_CUDA 1 +#endif + template struct ShRadiativeGaussianParticlesBuffer { TBuffer* ptr = nullptr; @@ -394,6 +403,40 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams TFeaturesVec& integratedFeatures, TFeaturesVec& integratedFeaturesGrad, float* featureLocalGrad) const { +#if NHT_FEATURES_BWD_LOCAL_GRAD_CUDA + if constexpr (TDifferentiable) { + static_assert(ExtParams::FeatureTransformType == 1, + "NHT CUDA backward is only valid on the NHT path (FEATURE_TRANSFORM_TYPE==1)"); + static_assert(FEATURE_INTERPOLATION_TYPE == 0, + "only barycentric interpolation supported"); + static_assert(FEATURE_INTERPOLATION_SUPPORT == 1, + "only tetrahedral interpolation support supported"); + static_assert(4 * INTERP_POINT_FEATURE_DIM == ExtParams::ParticleFeatureDim, + "NHT buffer layout must be 4 vertices * InterpPointFeatureDim"); + static_assert(((FEATURE_ACTIVATION_TYPE == 0 || FEATURE_ACTIVATION_TYPE == 3) && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM) || + (FEATURE_ACTIVATION_TYPE == 1 && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES) || + (FEATURE_ACTIVATION_TYPE == 2 && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES * 2), + "RAY_FEATURE_DIM / INTERP_POINT_FEATURE_DIM / activation mismatch"); + + threedgut::nht::featuresIntegrateBwdToLocalGrad< + INTERP_POINT_FEATURE_DIM, + FEATURE_ACTIVATION_NUM_FREQUENCIES, + FEATURE_ACTIVATION_TYPE>( + canonicalIntersection, + canonicalIntersectionGrad, + alpha, + alphaGrad, + particleIdx, + reinterpret_cast(&features), + reinterpret_cast(&integratedFeatures), + reinterpret_cast(&integratedFeaturesGrad), + reinterpret_cast(m_featureRawParameters.ptr), + featureLocalGrad); + } +#else if constexpr (TDifferentiable) { // Pointer offset trick: shift featureLocalGrad back by particleOffset so Slang's // _gradPtr[interpPointOffset + i] writes land in featureLocalGrad[0..ParticleFeatureDim-1]. @@ -416,6 +459,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams reinterpret_cast*>(&integratedFeatures), reinterpret_cast*>(&integratedFeaturesGrad)); } +#endif // NHT_FEATURES_BWD_LOCAL_GRAD_CUDA } // NHT warp reduction step 2: warp-reduce featureLocalGrad and atomicAdd to global gradient buffer. diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index c948910b..2665df25 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -18,6 +18,19 @@ #include <3dgut/kernels/cuda/common/rayPayloadBackward.cuh> #include <3dgut/renderer/gutRendererParameters.h> +// NHT backward diagnostic toggle — register-pressure attribution (plan T1). +// Mode selects which components of the NHT bwd inner loop are live. +// The modes are strictly nested so each one also disables everything below it. +// 0 = baseline, full pipeline (default) +// 1 = disable featureLocalGradWarpReduceAndWrite (warp reduce + atomics) +// 2 = also disable featuresIntegrateBwdToLocalGrad (local-grad compute) +// 3 = also disable densityProcessHitBwdToBuffer (Slang density bwd) +// Lower-level (higher-mode) toggles help isolate which symbol is parked in +// the 214-reg live set. Behavior is intentionally broken under modes > 0. +#ifndef NHT_BWD_DIAG_MODE +#define NHT_BWD_DIAG_MODE 0 +#endif + // HitParticle stores per-ray hit state inside the k-buffer. The // `canonicalIntersection` is only consumed by the NHT feature path; for the // SH path (`PerRayParticleFeatures=false`) we drop it from the struct so the @@ -599,6 +612,7 @@ struct GUTKBufferRenderer : Params { float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); // Write feature grad to thread-private local buffer (no atomics); warp reduction follows below. +#if NHT_BWD_DIAG_MODE < 2 particles.featuresIntegrateBwdToLocalGrad(ray.direction, canonicalIntersection, canonicalIntersectionGrad, @@ -609,7 +623,9 @@ struct GUTKBufferRenderer : Params { ray.featuresBackward, ray.featuresGradient, featureLocalGrad); +#endif +#if NHT_BWD_DIAG_MODE < 3 particles.template densityProcessHitBwdToBuffer(ray.origin, ray.direction, particleData.idx, @@ -621,6 +637,7 @@ struct GUTKBufferRenderer : Params { ray.hitTBackward, ray.hitTGradient, canonicalIntersectionGrad); +#endif ray.transmittance *= (1.0f - hitAlpha); } @@ -633,7 +650,11 @@ struct GUTKBufferRenderer : Params { // Warp reduction: all 32 threads (alive or not) participate in __shfl_xor_sync. // Non-hitting threads contribute featureLocalGrad=0 → no spurious gradient. // Reduces 32×ParticleFeatureDim atomics per particle to ParticleFeatureDim atomics. +#if NHT_BWD_DIAG_MODE < 1 particles.featureLocalGradWarpReduceAndWrite(particleData.idx, featureLocalGrad, tileThreadIdx); +#else + (void)featureLocalGrad; +#endif } } } else { diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh index 06aca8ee..364a94f6 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh @@ -216,7 +216,18 @@ __global__ void renderBalanced(threedgut::RenderParameters params, } #endif // FINE_GRAINED_LOAD_BALANCING -__global__ void renderBackward(threedgut::RenderParameters params, +// Optional register-cap on the backward kernel (plan T3). +// Set at build time with both args, e.g. +// -DNHT_BWD_LB_THREADS=256 -DNHT_BWD_LB_MIN_BLOCKS=2 +// to force <=128 regs/thread at BlockSize=256. Default leaves the compiler +// free, matching baseline behavior (regs/thread reported by ptxas). +#if defined(NHT_BWD_LB_THREADS) && defined(NHT_BWD_LB_MIN_BLOCKS) +#define NHT_BWD_LB __launch_bounds__(NHT_BWD_LB_THREADS, NHT_BWD_LB_MIN_BLOCKS) +#else +#define NHT_BWD_LB +#endif + +__global__ NHT_BWD_LB void renderBackward(threedgut::RenderParameters params, const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr, const uint32_t* __restrict__ sortedTileDataPtr, const tcnn::vec3* __restrict__ sensorRayOriginPtr, diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang index 43ece455..af653a28 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -17,7 +17,8 @@ // Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. // Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, // PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). -// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. With activation: N = interpPointDim * num_frequencies (decoder input). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. +// With sincos activation: N = interpPointDim * num_frequencies * 2 (separate sin/cos channels). namespace neuralHarmonicFeaturesParticle { @@ -42,8 +43,8 @@ static const int InterpolationSupport_Tetrahedra = 1; static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT; -// Canonical regular tetrahedron that contains the unit sphere; each face is tangent to the unit sphere -// (insphere radius = 1). Same geometry as particlePrimitives.cu (enclosing tetrahedra). +// Canonical regular tetrahedron matching the GSplat NHT reference layout: +// p0=(sqrt(6),-sqrt(2),-1), p1=(-sqrt(6),-sqrt(2),-1), p2=(0,2*sqrt(2),-1), p3=(0,0,3). // Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1. static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24) static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2 @@ -51,10 +52,10 @@ static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3) static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2) static const float tetraHedraInRadius = 1.0f; // unit sphere static const float3 canonicalTetraVerts[4] = { + float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f), - float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius), - float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f) + float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius) }; // Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical. static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0]; @@ -132,7 +133,7 @@ float4 barycentricTetrahedronCanonical(float3 P) return weights; } -// Encode and activate : none -> identity; siren -> sin(b*2^f); sincos -> sin(b*2^f)+cos(b*2^f); relu -> max(0,b). +// Encode and activate : none -> identity; siren -> sin(b*2^f); relu -> max(0,b). [BackwardDifferentiable][ForceInline] float encodeAndActivate(float baseVal, no_diff int f) { @@ -142,9 +143,7 @@ float encodeAndActivate(float baseVal, no_diff int f) return max(0.0f, baseVal); float freq = ldexp(1.0f, f); float angle = baseVal * freq; - if (FeatureActivationType == FeatureActivationType_Siren) - return sin(angle); - return sin(angle) + cos(angle); + return sin(angle); } // Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim]. @@ -174,6 +173,20 @@ void featuresFromParametersBuffer(ParametersBuffer parametersBuffer, [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { features[i] = baseFeatures[i]; } + } else if (FeatureActivationType == FeatureActivationType_Relu) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = max(0.0f, baseFeatures[i]); + } + } else if (FeatureActivationType == FeatureActivationType_Sincos) { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + float freq = float(f + 1); + float angle = baseFeatures[k] * freq; + int outIdx = k * FeatureActivationNumFrequencies * 2 + f * 2; + features[outIdx + 0] = sin(angle); + features[outIdx + 1] = cos(angle); + } + } } else { [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { diff --git a/threedgut_tracer/setup_3dgut.py b/threedgut_tracer/setup_3dgut.py index 593936f6..5ba539af 100644 --- a/threedgut_tracer/setup_3dgut.py +++ b/threedgut_tracer/setup_3dgut.py @@ -109,6 +109,13 @@ def to_cpp_bool(value): "-O3", *defines, ] + # Diagnostic: dump ptxas register / smem / spill stats per kernel. + # Enable with `export NHT_PTXAS_VERBOSE=1` before launching training. + if os.environ.get("NHT_PTXAS_VERBOSE", "0") == "1": + cuda_cflags += [ + "-Xptxas=-v", + "--resource-usage", + ] # When PARTICLE_FEATURE_HALF=1 the Slang-generated header uses __half types; # the Slang prelude only pulls in and defines __half when # SLANG_CUDA_ENABLE_HALF is set. diff --git a/validate.py b/validate.py new file mode 100644 index 00000000..88b162ed --- /dev/null +++ b/validate.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +Validation script: run 3DGUT and 3DGRT on the drums NeRF-Synthetic scene +(MCMC, 100k particles, 15k iterations) for SH and optionally NHT features, +then produce a report with PSNR, SSIM, training time, and render time. + +Usage +----- + python validate.py --data /path/to/nerf_synthetic/drums [OPTIONS] + +Options +------- + --data PATH Path to the drums scene directory (required) + --out-dir PATH Root output directory (default: runs/validate) + --nht Also run NHT experiments (requires NHT support in the + codebase; silently skipped if unavailable) + --iterations N Training iterations per experiment (default: 15000) + --particles N Initial particle count (default: 100000) + --skip-existing Skip training if the checkpoint already exists +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + + +# --------------------------------------------------------------------------- +# NHT availability check +# --------------------------------------------------------------------------- + +def _nht_available() -> bool: + """Return True if this codebase has NHT support compiled in.""" + try: + from threedgrut.model.features import Features + _ = Features.Type.NHT + return True + except (ImportError, AttributeError): + return False + + +# --------------------------------------------------------------------------- +# Experiment definitions +# --------------------------------------------------------------------------- + +def _experiments(args, nht_ok: bool): + """Yield (name, renderer, feature_type, app_config) tuples.""" + base = [ + ("3dgut_sh", "3dgut", "sh", "apps/nerf_synthetic_3dgut_mcmc_nht"), + ("3dgrt_sh", "3dgrt", "sh", "apps/nerf_synthetic_3dgrt_mcmc_nht"), + ] + nht = [ + ("3dgut_nht", "3dgut", "nht", "apps/nerf_synthetic_3dgut_mcmc_nht"), + ("3dgrt_nht", "3dgrt", "nht", "apps/nerf_synthetic_3dgrt_mcmc_nht"), + ] + for entry in base: + yield entry + if args.nht and nht_ok: + for entry in nht: + yield entry + + +# --------------------------------------------------------------------------- +# Train +# --------------------------------------------------------------------------- + +def _find_latest(directory: Path, pattern: str): + """Return the most recently modified file matching pattern under directory, or None.""" + matches = sorted(directory.glob(pattern), key=lambda p: p.stat().st_mtime) + return matches[-1] if matches else None + + +def _train(name: str, app_config: str, feature_type: str, args) -> tuple[float, str]: + """Run training and return (wall_time_seconds, checkpoint_path).""" + exp_dir = Path(args.out_dir) / name + + # Trainer saves under exp_dir/-/ckpt_last.pt + if args.skip_existing: + ckpt = _find_latest(exp_dir, "*/ckpt_last.pt") + if ckpt is not None: + print(f" [skip] checkpoint already exists: {ckpt}") + return 0.0, str(ckpt) + + exp_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "train.py", + f"--config-name={app_config}", + f"path={args.data}", + f"out_dir={args.out_dir}", + f"experiment_name={name}", + f"n_iterations={args.iterations}", + f"initialization.num_gaussians={args.particles}", + f"model.feature_type={feature_type}", + # disable feature_output_half for SH (it's set to true in NHT app configs) + f"render.feature_output_half={'true' if feature_type == 'nht' else 'false'}", + # save checkpoint only at the end; intermediate checkpoints not needed here + f"checkpoint.iterations=[{args.iterations}]", + # disable GUI + "with_gui=false", + "with_viser_gui=false", + ] + + print(f" $ {' '.join(cmd)}") + t0 = time.time() + subprocess.run(cmd, check=True) + elapsed = time.time() - t0 + + ckpt = _find_latest(exp_dir, "*/ckpt_last.pt") + if ckpt is None: + raise FileNotFoundError(f"Expected checkpoint not found under: {exp_dir}/*/ckpt_last.pt") + + return elapsed, str(ckpt) + + +# --------------------------------------------------------------------------- +# Render / evaluate +# --------------------------------------------------------------------------- + +def _render(name: str, ckpt: str, args) -> dict: + """Run render.py and return the metrics dict.""" + eval_dir = Path(args.out_dir) / name / "eval" + eval_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "render.py", + "--checkpoint", ckpt, + "--out-dir", str(eval_dir), + ] + + print(f" $ {' '.join(cmd)}") + subprocess.run(cmd, check=True) + + # Renderer saves under eval_dir//-/metrics.json + metrics_path = _find_latest(eval_dir, "**/metrics.json") + if metrics_path is None: + raise FileNotFoundError(f"metrics.json not found after render under: {eval_dir}") + + with open(metrics_path) as f: + return json.load(f) + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- + +_HEADER = ( + "| Experiment | PSNR (dB) | SSIM | LPIPS | Train (min) | Render (ms/f) |\n" + "|------------------|-----------|--------|--------|-------------|---------------|\n" +) + + +def _row(name: str, m: dict, train_sec: float) -> str: + psnr = f"{m.get('mean_psnr', float('nan')):.2f}" + ssim = f"{m.get('mean_ssim', float('nan')):.4f}" + lpips = f"{m.get('mean_lpips', float('nan')):.4f}" + t_min = f"{train_sec / 60:.1f}" if train_sec > 0 else "—" + r_ms = f"{m.get('mean_inference_time_ms', float('nan')):.2f}" + return f"| {name:<16} | {psnr:>9} | {ssim:>6} | {lpips:>6} | {t_min:>11} | {r_ms:>13} |\n" + + +def _write_report(rows: list[tuple], args, nht_ok: bool) -> str: + scene = Path(args.data).name + lines = [ + f"# Validation Report: {scene}\n\n", + f"Scene: `{args.data}` \n", + f"Iterations: {args.iterations} \n", + f"Particles: {args.particles:,} \n", + f"Strategy: MCMC \n", + f"NHT requested: {args.nht} ", + f"{'(supported)' if nht_ok else '(not available in this build — skipped)'} \n\n", + "## Results\n\n", + _HEADER, + ] + for name, metrics, train_sec in rows: + lines.append(_row(name, metrics, train_sec)) + + report = "".join(lines) + + report_path = Path(args.out_dir) / "report.md" + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text(report) + return report + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--data", required=True, help="Path to the drums scene directory") + parser.add_argument("--out-dir", default="runs/validate", help="Root output directory") + parser.add_argument("--nht", action="store_true", help="Also run NHT experiments") + parser.add_argument("--iterations", type=int, default=15000, help="Training iterations per experiment") + parser.add_argument("--particles", type=int, default=100000, help="Initial particle count") + parser.add_argument("--skip-existing", action="store_true", help="Skip training if checkpoint exists") + args = parser.parse_args() + + nht_ok = _nht_available() + if args.nht and not nht_ok: + print("WARNING: --nht requested but NHT is not available in this build; NHT experiments will be skipped.") + + rows = [] + for name, renderer, feature_type, app_config in _experiments(args, nht_ok): + print(f"\n{'='*60}") + print(f"Experiment: {name} (renderer={renderer}, features={feature_type})") + print(f"{'='*60}") + + print("\n[1/2] Training ...") + try: + train_sec, ckpt = _train(name, app_config, feature_type, args) + except subprocess.CalledProcessError as e: + print(f"ERROR: training failed for {name}: {e}") + rows.append((name, {}, 0.0)) + continue + + print(f"\n[2/2] Evaluating ...") + try: + metrics = _render(name, ckpt, args) + except subprocess.CalledProcessError as e: + print(f"ERROR: render failed for {name}: {e}") + rows.append((name, {}, train_sec)) + continue + + rows.append((name, metrics, train_sec)) + print(f"\n PSNR={metrics.get('mean_psnr', '?'):.2f} dB " + f"SSIM={metrics.get('mean_ssim', '?'):.4f} " + f"LPIPS={metrics.get('mean_lpips', '?'):.4f} " + f"render={metrics.get('mean_inference_time_ms', '?'):.2f} ms/frame") + + print(f"\n{'='*60}") + print("REPORT") + print(f"{'='*60}") + report = _write_report(rows, args, nht_ok) + print(report) + print(f"Report saved to: {Path(args.out_dir) / 'report.md'}") + + +if __name__ == "__main__": + main() From 5013f4bb673b07ff971373b8c4af3088a56d143f Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 28 Apr 2026 07:33:30 -0400 Subject: [PATCH 7/9] Remove NHT backward diagnostic switches Keeps the production NHT backward path unconditional so diagnostic-only broken modes cannot affect training builds. --- .../cuda/renderers/gutKBufferRenderer.cuh | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index 2665df25..c948910b 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -18,19 +18,6 @@ #include <3dgut/kernels/cuda/common/rayPayloadBackward.cuh> #include <3dgut/renderer/gutRendererParameters.h> -// NHT backward diagnostic toggle — register-pressure attribution (plan T1). -// Mode selects which components of the NHT bwd inner loop are live. -// The modes are strictly nested so each one also disables everything below it. -// 0 = baseline, full pipeline (default) -// 1 = disable featureLocalGradWarpReduceAndWrite (warp reduce + atomics) -// 2 = also disable featuresIntegrateBwdToLocalGrad (local-grad compute) -// 3 = also disable densityProcessHitBwdToBuffer (Slang density bwd) -// Lower-level (higher-mode) toggles help isolate which symbol is parked in -// the 214-reg live set. Behavior is intentionally broken under modes > 0. -#ifndef NHT_BWD_DIAG_MODE -#define NHT_BWD_DIAG_MODE 0 -#endif - // HitParticle stores per-ray hit state inside the k-buffer. The // `canonicalIntersection` is only consumed by the NHT feature path; for the // SH path (`PerRayParticleFeatures=false`) we drop it from the struct so the @@ -612,7 +599,6 @@ struct GUTKBufferRenderer : Params { float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); // Write feature grad to thread-private local buffer (no atomics); warp reduction follows below. -#if NHT_BWD_DIAG_MODE < 2 particles.featuresIntegrateBwdToLocalGrad(ray.direction, canonicalIntersection, canonicalIntersectionGrad, @@ -623,9 +609,7 @@ struct GUTKBufferRenderer : Params { ray.featuresBackward, ray.featuresGradient, featureLocalGrad); -#endif -#if NHT_BWD_DIAG_MODE < 3 particles.template densityProcessHitBwdToBuffer(ray.origin, ray.direction, particleData.idx, @@ -637,7 +621,6 @@ struct GUTKBufferRenderer : Params { ray.hitTBackward, ray.hitTGradient, canonicalIntersectionGrad); -#endif ray.transmittance *= (1.0f - hitAlpha); } @@ -650,11 +633,7 @@ struct GUTKBufferRenderer : Params { // Warp reduction: all 32 threads (alive or not) participate in __shfl_xor_sync. // Non-hitting threads contribute featureLocalGrad=0 → no spurious gradient. // Reduces 32×ParticleFeatureDim atomics per particle to ParticleFeatureDim atomics. -#if NHT_BWD_DIAG_MODE < 1 particles.featureLocalGradWarpReduceAndWrite(particleData.idx, featureLocalGrad, tileThreadIdx); -#else - (void)featureLocalGrad; -#endif } } } else { From 74be1eed6f0a051811e2c750b14961e22b81e5da Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 28 Apr 2026 08:47:19 -0400 Subject: [PATCH 8/9] Fix NHT eval EMA and backward hit gating Aligns final evaluation with EMA semantics, keeps resumed training on live decoder weights, and mirrors forward hit bounds in the 3DGUT no-k-buffer backward path. --- plan/nht_correction_plan.md | 115 ++++++++++++++++++ plan/nht_reference_results.md | 10 +- threedgrut/trainer.py | 31 +++-- .../cuda/renderers/gutKBufferRenderer.cuh | 4 +- 4 files changed, 144 insertions(+), 16 deletions(-) create mode 100644 plan/nht_correction_plan.md diff --git a/plan/nht_correction_plan.md b/plan/nht_correction_plan.md new file mode 100644 index 00000000..00aca20c --- /dev/null +++ b/plan/nht_correction_plan.md @@ -0,0 +1,115 @@ +# NHT Correction Plan + +## Goal + +Close the correctness gaps found in the NHT branch review against the GSplat NHT reference, without changing SH baseline defaults. + +## P0 - Correctness Bugs + +### T1 - Match Forward Depth Gating In 3DGUT No-K-Buffer Backward + +- File: `threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh` +- Problem: `evalBackwardNoKBuffer` calls `densityHit` but does not apply the same `hitT > ray.tMinMax.x && hitT < ray.tMinMax.y` gate used by forward. +- Change: add the depth-slab predicate around the backward hit processing path. +- Test: + - Build/run a short NHT validation with default `k_buffer_size: 0`. + - Add or run a targeted kernel smoke case where a Gaussian outside `ray.tMinMax` contributes no feature/density gradient. + +### T2 - Use FeatureDecoder EMA For Final Test Evaluation + +- Files: `threedgrut/trainer.py`, optionally `threedgrut/render.py` +- Problem: validation and checkpoint eval use EMA weights, but train-end `on_training_end()` test uses live decoder weights. +- Change: apply `feature_decoder.apply_ema_shadow()` around `Renderer.from_preloaded_model(...).render_all()` and restore after, mirroring validation and GSplat eval. +- Test: + - Re-run checkpoint eval and train-end eval on the same checkpoint; metrics should agree within numerical noise. + +### T3 - Do Not Apply EMA To Trainable Decoder On Resume + +- File: `threedgrut/trainer.py` +- Problem: resume loads decoder EMA shadow and then copies EMA into live trainable parameters, while optimizer state still belongs to non-EMA weights. +- Change: load EMA shadow only; do not call `apply_ema_shadow()` in the resume path. +- Test: + - Save a checkpoint with EMA, resume, and verify `feature_decoder.state_dict()` equals checkpoint `"module"` immediately after load. + - Verify eval still uses EMA through the eval-only swap path. + +## P1 - Metric And Config Semantics + +### T4 - Fix `Renderer.render_all()` Extra-Metrics Guard + +- File: `threedgrut/render.py` +- Problem: `compute_extra_metrics=False` omits SSIM/LPIPS criterions but `render_all()` always uses them. +- Change: either always construct metrics needed by the table, or guard SSIM/LPIPS/color-corrected metrics and output only PSNR when disabled. +- Test: + - Run `Renderer.from_preloaded_model(..., compute_extra_metrics=False).render_all()` on a small checkpoint without `KeyError`. + +### T5 - Correct Benchmark Result Table Units + +- File: `plan/nht_reference_results.md` +- Problem: current 3DGRUT `2.455` / `2.489` values are `std_psnr`, not render time. +- Change: rename that column for 3DGRUT rows or replace with real `mean_inference_time_ms` from a timing-enabled eval. +- Test: + - Confirm `metrics.json` and terminal table fields map one-to-one to the document columns. + +### T6 - Add L2 Loss To Total Or Disable The Flag + +- File: `threedgrut/trainer.py` +- Problem: `loss.use_l2` computes/logs L2 but does not affect `total_loss`. +- Change: add `lambda_l2 * loss_l2` to `total_loss`, or remove/mark unsupported if unused by intended configs. +- Test: + - Unit/smoke check with `loss.use_l2=true`, `lambda_l1=lambda_ssim=lambda_opacity=lambda_scale=0`, verify nonzero `total_loss`. + +### T7 - Split Decoder Weight Decay From Explicit Decoder Regularization + +- Files: `configs/base_gs.yaml`, `threedgrut/trainer.py`, `threedgrut/model/feature_decoder.py` +- Problem: `nht_decoder.reg_weight` drives both Adam `weight_decay` and explicit `params^2` loss. +- Change: keep GSplat parity by using optimizer `weight_decay` only, or introduce separate config keys if both are desired. +- Test: + - With default `reg_weight: 0.0`, behavior is unchanged. + - With nonzero regularization, verify only the chosen mechanism contributes. + +## P2 - Latent / Cleanup Items + +### T8 - Make `rays_in_world_space` Consistent In Feature Decode + +- File: `threedgrut/utils/render.py` +- Problem: tracer respects world-space rays, decoder always rotates directions by `T_to_world`. +- Change: if `gpu_batch.rays_in_world_space` is true, normalize `gpu_batch.rays_dir` directly. +- Test: + - Create a small `Batch` with world-space rays and non-identity `T_to_world`; verify decoder direction input is not double-rotated. + +### T9 - Remove Or Fully Reject Unsupported Bezier NHT Config + +- File: `threedgrut/model/features.py` +- Problem: `interpolation_type` accepts `"bezier"` but `interpolation_support` and kernels do not support it. +- Change: raise a clear `NotImplementedError` for `"bezier"` at `interpolation_type`, or implement end-to-end later. +- Test: + - Config with `model.nht_features.interpolation_type=bezier` fails early with a clear message. + +### T10 - Clarify NHT Progressive Feature Bookkeeping + +- Files: `threedgrut/model/model.py`, `threedgrut/trainer.py`, docs/comments +- Problem: `n_active_features` progresses for NHT but kernels do not use it to mask NHT feature dimensions. +- Change: disable progression for NHT or document it as unused and avoid misleading logs/exports. +- Test: + - NHT training logs do not imply progressive feature activation unless implemented. + +## P3 - Test Coverage + +### T11 - Add NHT Smoke/Parity Tests + +- Files: new tests or `validate.py` +- Problem: `validate.py` only checks that Python NHT symbols exist, not that CUDA NHT kernels work. +- Change: + - Add a minimal NHT render/backward smoke test for 3DGUT. + - Add shape checks for `particle_feature_dim`, `ray_feature_dim`, and sincos expansion. + - Add a resume/EMA behavior check. +- Test: + - Run the new tests in the `3dgrut-nht` environment. + +## Suggested Order + +1. T1, T2, T3: correctness and benchmark parity. +2. T4, T5, T6: metric/reporting correctness. +3. T7, T8, T9, T10: config semantics and latent bugs. +4. T11: regression coverage. + diff --git a/plan/nht_reference_results.md b/plan/nht_reference_results.md index 3f2f7c3b..962b85aa 100644 --- a/plan/nht_reference_results.md +++ b/plan/nht_reference_results.md @@ -2,9 +2,15 @@ ## Bonsai Comparison -| Run | Source | PSNR | SSIM | LPIPS | CC PSNR | CC SSIM | CC LPIPS | Time (ms/image) | +| Run | Source | PSNR | SSIM | LPIPS | CC PSNR | CC SSIM | CC LPIPS | Std PSNR | | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | -| GSplat reference | GSplat NHT reference run | 34.163 | 0.9542 | 0.235 | - | - | - | 11.000 | +| GSplat reference | GSplat NHT reference run | 34.163 | 0.9542 | 0.235 | - | - | - | - | | 3DGRUT previous | Before benchmark-parity/color-refinement fixes | 33.427 | 0.949 | 0.248 | 33.559 | 0.947 | 0.247 | 2.489 | | 3DGRUT updated | After benchmark-parity/color-refinement fixes | 33.734 | 0.951 | 0.246 | 33.908 | 0.949 | 0.246 | 2.455 | +| 3DGRUT T1-T3 | After depth-gate and EMA eval fixes | 33.702 | 0.951 | 0.246 | 33.853 | 0.949 | 0.246 | 2.520 | + +## Timing Notes + +- GSplat reference render time: `11.000 ms/image`. +- 3DGRUT timing needs a timing-enabled eval; the table's last 3DGRUT column is `std_psnr`, not time. diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 30f03abf..24e8d741 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -288,7 +288,6 @@ def setup_training( ema_state = fd_ckpt.get("ema") if ema_state is not None: self.feature_decoder.load_ema_state_dict(ema_state) - self.feature_decoder.apply_ema_shadow() logger.info("🎨 Feature decoder state restored from checkpoint") # Restore post-processing state @@ -980,18 +979,24 @@ def on_training_end(self): logger.log_rule("Evaluation on Test Set") # Renderer test split - renderer = Renderer.from_preloaded_model( - model=self.model, - out_dir=out_dir, - path=conf.path, - save_gt=False, - writer=self.tracking.writer, - global_step=self.global_step, - compute_extra_metrics=conf.compute_extra_metrics, - post_processing=self.post_processing, - feature_decoder=self.feature_decoder, - ) - renderer.render_all() + if self.feature_decoder is not None: + self.feature_decoder.apply_ema_shadow() + try: + renderer = Renderer.from_preloaded_model( + model=self.model, + out_dir=out_dir, + path=conf.path, + save_gt=False, + writer=self.tracking.writer, + global_step=self.global_step, + compute_extra_metrics=conf.compute_extra_metrics, + post_processing=self.post_processing, + feature_decoder=self.feature_decoder, + ) + renderer.render_all() + finally: + if self.feature_decoder is not None: + self.feature_decoder.restore_ema() @torch.cuda.nvtx.range(f"save_checkpoint") def save_checkpoint(self, last_checkpoint: bool = False): diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index c948910b..c3d02345 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -591,7 +591,9 @@ struct GUTKBufferRenderer : Params { float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); if (particles.densityHit(ray.origin, ray.direction, particleData.densityParameters, - hitAlpha, hitT, canonicalIntersection)) { + hitAlpha, hitT, canonicalIntersection) && + (hitT > ray.tMinMax.x) && + (hitT < ray.tMinMax.y)) { // Re-evaluate NHT features at canonical intersection point (cheap: barycentric interp) const TFeaturesVec hitFeatures = particles.featuresFromBuffer(particleData.idx, ray.direction, canonicalIntersection); From 5bf69463c469f00a69b9cfa39511b3d026e7df05 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 28 Apr 2026 09:22:07 -0400 Subject: [PATCH 9/9] Ignore local planning notes --- .gitignore | 1 + plan/nht_bwd_reg_reduction.md | 183 ---------------------------------- plan/nht_correction_plan.md | 115 --------------------- plan/nht_reference_results.md | 16 --- 4 files changed, 1 insertion(+), 314 deletions(-) delete mode 100644 plan/nht_bwd_reg_reduction.md delete mode 100644 plan/nht_correction_plan.md delete mode 100644 plan/nht_reference_results.md diff --git a/.gitignore b/.gitignore index b389391e..c6bafc66 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ outputs/ extra_info/ eval/ extra_info/ +plan/ debug_** diff --git a/plan/nht_bwd_reg_reduction.md b/plan/nht_bwd_reg_reduction.md deleted file mode 100644 index 4e8d3424..00000000 --- a/plan/nht_bwd_reg_reduction.md +++ /dev/null @@ -1,183 +0,0 @@ -# Plan — NHT backward register-pressure reduction - -## Context - -`renderBackward` (NHT + PARTICLE_FEATURE_HALF=1, FEATURE_OUTPUT_HALF=1) is 60 ms -while fwd is 4 ms (15×). Two diagnostics established the root cause: - -1. **A/B: commenting the body of `featuresIntegrateBwdToLocalGrad` changes nothing.** - Commenting `featureLocalGradWarpReduceAndWrite` makes the kernel 2× faster. - This is not an atomic-contention signal (reference - `RasterizeToPixelsFromWorldNHT3DGSBwd.cu` uses the same 48-atomic pattern - and is fast) — it is DCE freeing register pressure upstream. - -2. **ptxas stats (sm_90):** - - | kernel | regs/thread | spills | smem | barriers | - |------------------|-------------|--------|-----------|----------| - | `renderBackward` | **214** | 0 | 17 408 B | 1 | - | `render` (fwd) | 64 | 0 | 17 408 B | 1 | - - Actual `BlockSize = 16 × 16 = 256` (see `gutRendererParameters.h`), not 128 - as I first assumed. With 65 536 regs/SM on sm_90, reg cliffs at 256 - threads/block are: - - 255 regs → 1 block/SM (at 214 today → 1 block/SM, occupancy ~12.5 %) - - 128 regs → 2 blocks/SM (25 %) - - 85 regs → 3 blocks/SM (~37 %) - - 64 regs → 4 blocks/SM (50 %) - - At 214 regs the kernel fits ~1 block/SM → effective occupancy ~12.5 %, - memory latency cannot be hidden. - (If `FINE_GRAINED_LOAD_BALANCING=true`, blocks are 128 thr — shifts the - cliffs: 128r→4blk, 96r→5blk, 64r→8blk. Confirm which path is active.) - -## Goal - -Drive `renderBackward` regs/thread below **128** (ideally **≤ 96**) while -keeping gradients bit-equivalent to the current CUDA-integrated path -(compile-time toggle `NHT_FEATURES_BWD_LOCAL_GRAD_CUDA=1`). - -Measurable success criteria: -- Primary: `renderBackward` wall time reduced ≥ 25 % on the current scene. -- Secondary: regs/thread reported by ptxas ≤ 128. -- Parity: feature/density gradient L2 vs pre-change reference < 1e-6 rel. - -Non-goal: architectural rework of the renderer or Slang call surface. - -## Task list (strictly ordered, each task independently reversible) - -### T1 — Attribution of the register budget - -Goal: measure how much each piece of live state costs in regs, so we know -which fix is worth pursuing. Pure diagnostic, no behavioral change. - -- **T1.a**: build with `-Xptxas=-v -res-usage` permanently enabled behind an - env var (`export NHT_PTXAS_VERBOSE=1`) via `setup_3dgut.py`. Archive - the current 214-reg baseline. -- **T1.b**: temporarily gut `featureLocalGradWarpReduceAndWrite`'s body - (keep signature, compile-time `#if 0` the shuffles + atomics). Rebuild. - Record new regs + ms. -- **T1.c**: additionally gut our `featuresIntegrateBwdToLocalGrad` body. - Rebuild. Record. -- **T1.d**: additionally replace the Slang `densityProcessHitBwdToBuffer` - call with a no-op. Rebuild. Record. - -Deliverable: a 5-row table (baseline + T1.a .. T1.d) of (regs, smem, ms). -Tells us exactly where the 214 regs are parked. - -### T2 — Target fix 1: stage `featureLocalGrad[]` into `__shared__` - -Hypothesis: moving the 48-float per-thread scratch out of registers and -into per-thread shmem slots will free up to 48 regs/thread, with negligible -runtime overhead (shmem LD/ST ≈ register throughput for small arrays). - -- **T2.a** (implementation): add a `__shared__ float sFeatureLocalGrad[BlockSize][PARTICLE_FEATURE_DIM]` - in the KBuffer renderer's bwd path. Change the caller site in - `gutKBufferRenderer.cuh` to pass `sFeatureLocalGrad[tileThreadIdx]` - (plus a clear of that row at top of each `j` iteration). -- **T2.b** (integrate): update the callee signature in - `shRadiativeGaussianParticles.cuh::featuresIntegrateBwdToLocalGrad` and - in `threedgut::nht::featuresIntegrateBwdToLocalGrad` (already takes a - `float*`, so just documentation + assume shmem aliasing). -- **T2.c** (reduce): rewrite `featureLocalGradWarpReduceAndWrite` to reduce - from shmem rather than thread-local registers. Verify `__shfl_xor_sync` - still works (it does — operates on any register value; we first load - shmem row into a single thread-private scalar per iteration). -- **T2.d** (smem budget check): new smem = `17 408 + BlockSize × 48 × 4` - bytes. For `BlockSize=256` that is +49 152 B → 66 560 B/block. H100 has - 228 KB dynamic smem/SM → 3 blocks/SM by smem. Fine only if we hit ≥ 2 - blocks/SM by regs too (i.e. reg cliff ≤ 128). Record. -- **T2.e** (bench): rebuild, capture ptxas regs and `renderBackward` ms. - Table row. -- **T2.f** (parity): run `validate.py` or equivalent; dump feature - gradient buffer L2 vs baseline. Must be ≤ 1e-6 relative. - -Expected: regs 214 → ~166 (–48). Probably still above 128 cliff; need T3. - -### T3 — Target fix 2: `__launch_bounds__` + controlled spill - -Hypothesis: if T2 alone does not cross the 128-reg cliff, force the -compiler to cap regs with `__launch_bounds__(BlockSize, minBlocksPerSM)`. -This may introduce local-memory spills, but occupancy gain beats spill -cost when kernel is latency-bound (our case). - -- **T3.a**: apply `__launch_bounds__(256, 2)` to the `renderBackward` - kernel entry in `gutRenderer.cuh`. Regs forced ≤ 128. Rebuild, record. -- **T3.b**: try `__launch_bounds__(256, 3)` (regs ≤ 85). Rebuild, record. -- **T3.c**: try `__launch_bounds__(256, 4)` (regs ≤ 64). Rebuild, record. - (If `FINE_GRAINED_LOAD_BALANCING=true` the block size is 128; adjust - the first arg accordingly.) -- **T3.d** (parity): only behavioral change is scheduling → gradients - must be bit-identical. Confirm L2 = 0. - -Table rows: regs, spill stores, spill loads, ms per configuration. - -### T4 — Pick winner - -Decision table from T2/T3 data: - -| config | regs | spills | ms | parity | pick? | -|----------------|------|--------|-------|--------|-------| -| baseline | 214 | 0 | 60 | ref | — | -| T2 (shmem) | ? | 0 | ? | yes | ? | -| T3a (128,4) | 128 | ? | ? | yes | ? | -| T3b (128,5) | 96 | ? | ? | yes | ? | -| T2 + T3a | ? | ? | ? | yes | ? | -| T2 + T3b | ? | ? | ? | yes | ? | - -Pick the lowest ms with spill stores low enough to not dominate L1 traffic. - -### T5 — (conditional) Prefetch features into `__shared__` à la reference - -Only if T4 falls short of the 25 % target. This is the -`RasterizeToPixelsFromWorldNHT3DGSBwd.cu` pattern: load the 48-float -feature block for all particles in the batch into shmem **once**, then -per-hit reads are shmem broadcasts. Cuts redundant global loads (currently -done twice per hit — inside `featuresFromBuffer` and again inside our -`featuresIntegrateBwdToLocalGrad`) and reduces register churn from callee -boundaries. - -Bigger change: ~60 lines in `gutKBufferRenderer.cuh` inner batch loop, -plus new shmem size `BlockSize × 48 × sizeof(TFeatElem)` (= 12 KB @ fp16, -24 KB @ fp32). Defer until T4 signals it is needed. - -### T6 — Validation & cleanup - -- **T6.a**: final regs + ms comparison table. -- **T6.b**: full gradient parity on ≥ 2 training steps. -- **T6.c**: revert `NHT_PTXAS_VERBOSE` default (env-gated, stays available). -- **T6.d**: document the choice + knob(s) in `TODO_nht_cuda.md`. - -## Test harness (used by every Tx) - -One repeatable command invocation that captures both ptxas output and -kernel ms. Pseudo-code for `bench.sh`: - -``` -rm ~/.cache/torch_extensions/.../gutRenderer.cuda.o -NHT_PTXAS_VERBOSE=1 python validate.py --iters 50 --nsys ... -grep 'registers\|spill\|renderBackward' /tmp/build_and_run.log -``` - -Produces one row of the comparison table per Tx invocation. - -## Confidence - -- T1 (diagnostic): 95 %. Already have one data point, filling the matrix is mechanical. -- T2 alone buys 30–48 regs: 70 %. -- T2 alone reaches 128-reg cliff: 40 %. -- T3 forces the cliff and gains 25 %+: 60 %. -- Combined (T2+T3) reaches 50 % speedup: 50 %. - -## Open questions (for your review before execution) - -1. Is the validation harness OK as above, or do you have a preferred - benchmarking script I should wire into `bench.sh`? -2. `BlockSize` — I assumed 128. Worth checking the actual tiling config - used by the NHT path; the smem budget math depends on it. -3. Parity tolerance: 1e-6 relative is arbitrary. Tighter (bit-exact) is - achievable for T2 and T3 since they do not change the math. Want me - to require bit-exact? -4. If register pressure is mostly coming from the Slang-exported - `densityProcessHitBwdToBuffer` (T1.d will tell), we may have to - port that too. That is outside this plan — flag as follow-up if so. diff --git a/plan/nht_correction_plan.md b/plan/nht_correction_plan.md deleted file mode 100644 index 00aca20c..00000000 --- a/plan/nht_correction_plan.md +++ /dev/null @@ -1,115 +0,0 @@ -# NHT Correction Plan - -## Goal - -Close the correctness gaps found in the NHT branch review against the GSplat NHT reference, without changing SH baseline defaults. - -## P0 - Correctness Bugs - -### T1 - Match Forward Depth Gating In 3DGUT No-K-Buffer Backward - -- File: `threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh` -- Problem: `evalBackwardNoKBuffer` calls `densityHit` but does not apply the same `hitT > ray.tMinMax.x && hitT < ray.tMinMax.y` gate used by forward. -- Change: add the depth-slab predicate around the backward hit processing path. -- Test: - - Build/run a short NHT validation with default `k_buffer_size: 0`. - - Add or run a targeted kernel smoke case where a Gaussian outside `ray.tMinMax` contributes no feature/density gradient. - -### T2 - Use FeatureDecoder EMA For Final Test Evaluation - -- Files: `threedgrut/trainer.py`, optionally `threedgrut/render.py` -- Problem: validation and checkpoint eval use EMA weights, but train-end `on_training_end()` test uses live decoder weights. -- Change: apply `feature_decoder.apply_ema_shadow()` around `Renderer.from_preloaded_model(...).render_all()` and restore after, mirroring validation and GSplat eval. -- Test: - - Re-run checkpoint eval and train-end eval on the same checkpoint; metrics should agree within numerical noise. - -### T3 - Do Not Apply EMA To Trainable Decoder On Resume - -- File: `threedgrut/trainer.py` -- Problem: resume loads decoder EMA shadow and then copies EMA into live trainable parameters, while optimizer state still belongs to non-EMA weights. -- Change: load EMA shadow only; do not call `apply_ema_shadow()` in the resume path. -- Test: - - Save a checkpoint with EMA, resume, and verify `feature_decoder.state_dict()` equals checkpoint `"module"` immediately after load. - - Verify eval still uses EMA through the eval-only swap path. - -## P1 - Metric And Config Semantics - -### T4 - Fix `Renderer.render_all()` Extra-Metrics Guard - -- File: `threedgrut/render.py` -- Problem: `compute_extra_metrics=False` omits SSIM/LPIPS criterions but `render_all()` always uses them. -- Change: either always construct metrics needed by the table, or guard SSIM/LPIPS/color-corrected metrics and output only PSNR when disabled. -- Test: - - Run `Renderer.from_preloaded_model(..., compute_extra_metrics=False).render_all()` on a small checkpoint without `KeyError`. - -### T5 - Correct Benchmark Result Table Units - -- File: `plan/nht_reference_results.md` -- Problem: current 3DGRUT `2.455` / `2.489` values are `std_psnr`, not render time. -- Change: rename that column for 3DGRUT rows or replace with real `mean_inference_time_ms` from a timing-enabled eval. -- Test: - - Confirm `metrics.json` and terminal table fields map one-to-one to the document columns. - -### T6 - Add L2 Loss To Total Or Disable The Flag - -- File: `threedgrut/trainer.py` -- Problem: `loss.use_l2` computes/logs L2 but does not affect `total_loss`. -- Change: add `lambda_l2 * loss_l2` to `total_loss`, or remove/mark unsupported if unused by intended configs. -- Test: - - Unit/smoke check with `loss.use_l2=true`, `lambda_l1=lambda_ssim=lambda_opacity=lambda_scale=0`, verify nonzero `total_loss`. - -### T7 - Split Decoder Weight Decay From Explicit Decoder Regularization - -- Files: `configs/base_gs.yaml`, `threedgrut/trainer.py`, `threedgrut/model/feature_decoder.py` -- Problem: `nht_decoder.reg_weight` drives both Adam `weight_decay` and explicit `params^2` loss. -- Change: keep GSplat parity by using optimizer `weight_decay` only, or introduce separate config keys if both are desired. -- Test: - - With default `reg_weight: 0.0`, behavior is unchanged. - - With nonzero regularization, verify only the chosen mechanism contributes. - -## P2 - Latent / Cleanup Items - -### T8 - Make `rays_in_world_space` Consistent In Feature Decode - -- File: `threedgrut/utils/render.py` -- Problem: tracer respects world-space rays, decoder always rotates directions by `T_to_world`. -- Change: if `gpu_batch.rays_in_world_space` is true, normalize `gpu_batch.rays_dir` directly. -- Test: - - Create a small `Batch` with world-space rays and non-identity `T_to_world`; verify decoder direction input is not double-rotated. - -### T9 - Remove Or Fully Reject Unsupported Bezier NHT Config - -- File: `threedgrut/model/features.py` -- Problem: `interpolation_type` accepts `"bezier"` but `interpolation_support` and kernels do not support it. -- Change: raise a clear `NotImplementedError` for `"bezier"` at `interpolation_type`, or implement end-to-end later. -- Test: - - Config with `model.nht_features.interpolation_type=bezier` fails early with a clear message. - -### T10 - Clarify NHT Progressive Feature Bookkeeping - -- Files: `threedgrut/model/model.py`, `threedgrut/trainer.py`, docs/comments -- Problem: `n_active_features` progresses for NHT but kernels do not use it to mask NHT feature dimensions. -- Change: disable progression for NHT or document it as unused and avoid misleading logs/exports. -- Test: - - NHT training logs do not imply progressive feature activation unless implemented. - -## P3 - Test Coverage - -### T11 - Add NHT Smoke/Parity Tests - -- Files: new tests or `validate.py` -- Problem: `validate.py` only checks that Python NHT symbols exist, not that CUDA NHT kernels work. -- Change: - - Add a minimal NHT render/backward smoke test for 3DGUT. - - Add shape checks for `particle_feature_dim`, `ray_feature_dim`, and sincos expansion. - - Add a resume/EMA behavior check. -- Test: - - Run the new tests in the `3dgrut-nht` environment. - -## Suggested Order - -1. T1, T2, T3: correctness and benchmark parity. -2. T4, T5, T6: metric/reporting correctness. -3. T7, T8, T9, T10: config semantics and latent bugs. -4. T11: regression coverage. - diff --git a/plan/nht_reference_results.md b/plan/nht_reference_results.md deleted file mode 100644 index 962b85aa..00000000 --- a/plan/nht_reference_results.md +++ /dev/null @@ -1,16 +0,0 @@ -# NHT Reference Results - -## Bonsai Comparison - -| Run | Source | PSNR | SSIM | LPIPS | CC PSNR | CC SSIM | CC LPIPS | Std PSNR | -| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | -| GSplat reference | GSplat NHT reference run | 34.163 | 0.9542 | 0.235 | - | - | - | - | -| 3DGRUT previous | Before benchmark-parity/color-refinement fixes | 33.427 | 0.949 | 0.248 | 33.559 | 0.947 | 0.247 | 2.489 | -| 3DGRUT updated | After benchmark-parity/color-refinement fixes | 33.734 | 0.951 | 0.246 | 33.908 | 0.949 | 0.246 | 2.455 | -| 3DGRUT T1-T3 | After depth-gate and EMA eval fixes | 33.702 | 0.951 | 0.246 | 33.853 | 0.949 | 0.246 | 2.520 | - -## Timing Notes - -- GSplat reference render time: `11.000 ms/image`. -- 3DGRUT timing needs a timing-enabled eval; the table's last 3DGRUT column is `std_psnr`, not time. -