diff --git a/.gitignore b/.gitignore index feedbb5e..c6bafc66 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ outputs/ extra_info/ eval/ extra_info/ +plan/ debug_** @@ -24,6 +25,7 @@ thirdparty/kaolin/ threedgrt_tracer/.ninja_log threedgrt_tracer/include/3dgrt/kernels/slang/*.cuh* +threedgut_tracer/include/threedgutSlang.cuh *.egg-info .idea diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..50839c85 --- /dev/null +++ b/TODO.md @@ -0,0 +1,23 @@ +# TODO + +## 3DGRT: half-precision particle features + +`conf.render.particle_feature_half` is compiled into the 3DGRT kernel via `-DPARTICLE_FEATURE_HALF` +but the Python-side cast is missing. In `threedgrt_tracer/tracer.py`, `gaussians.get_features()` +must be cast to `.half()` before being passed to `_Autograd.apply` when the flag is set, +matching what 3DGUT already does. + +See the `TODO` comment in `threedgrt_tracer/tracer.py`. + +## 3DGRT: NHT support in CUDA path (`gaussianParticles.cuh`) + +The NHT feature transform (`FEATURE_TRANSFORM_TYPE=1`) is implemented for the Slang path +(`gaussianParticles.slang`) but not yet in the CUDA path (`gaussianParticles.cuh`). +Full NHT support in 3DGRT requires extending `gaussianParticles.cuh` with the NHT +interpolation and activation logic currently only present in the Slang kernel. + +## 3DGUT: refactor `evalBackwardNoKBuffer` to share path with k-buffer backward + +`evalBackwardNoKBuffer` (`gutKBufferRenderer.cuh`) duplicates logic from the k-buffer backward +path. The two should be unified into a shared implementation to reduce code duplication and +ensure future fixes apply to both. 
diff --git a/TODO_half_3dgrt.md b/TODO_half_3dgrt.md new file mode 100644 index 00000000..db386c60 --- /dev/null +++ b/TODO_half_3dgrt.md @@ -0,0 +1,144 @@ +# 3DGRT half-precision feature support + +Goal: make `conf.render.particle_feature_half=true` and `conf.render.feature_output_half=true` +work end-to-end in `threedgrt_tracer`, matching the behavior already implemented in +`threedgut_tracer`. Gradient buffers remain fp32 on both paths. + +Semantics (mirroring 3dgut): +- `particle_feature_half=true`: storage for `particleRadiance` (per-particle feature buffer) + is fp16. Slang entry points already expect `feat_elem_t*` (`__half*` when the macro is set). + Gradient `particleRadianceGrad` stays fp32. +- `feature_output_half=true`: storage for the per-ray integrated feature buffer (`rayRadiance`) + is fp16. Gradient `rayRadianceGrad` stays fp32. Tracer `.forward()` casts fp16 back to fp32 + before returning, mirroring 3dgut. + +## Scope + +Files to touch (by layer): + +- C++ pipeline type layer + - `threedgrt_tracer/include/3dgrt/pipelineParameters.h` + Introduce `TFeatureDensityElem` (output/ray feature) and `TParticleFeatureElem` (particle + storage) typedefs, guarded on the two macros. Change `particleRadiance` from `const float*` + to `const TParticleFeatureElem*`, and `rayRadiance` from + `PackedTensorAccessor32<float, N>` to `PackedTensorAccessor32<TFeatureDensityElem, N>` (rank `N` unchanged). + `particleRadianceGrad` and `rayRadianceGrad` stay fp32. +- OptiX raygen kernels + - `threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu` + - `threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu` + 1. Replace `const_cast<float*>(params.particleRadiance)` with + `const_cast<TParticleFeatureElem*>(params.particleRadiance)`. + 2. FWD write to `rayRadiance`: wrap with `__float2half` when `FEATURE_OUTPUT_HALF`. + 3. BWD read from `rayRadiance`: wrap with `__half2float` when `FEATURE_OUTPUT_HALF`. + `rayRadianceGrad` reads stay fp32. +- Host launcher + - `threedgrt_tracer/src/optixTracer.cpp` + 1. 
`trace()`: allocate `rayRad` with `torch::kHalf` when `FEATURE_OUTPUT_HALF=1`, + build `packed_accessor32<TFeatureDensityElem, N>(rayRad)` (`N` = existing rank), and + `getPtr<const TParticleFeatureElem>(particleRadiance)`. + 2. `traceBwd()`: same dtype for the forward `rayRad` input; the grad tensors remain fp32. +- Python tracer + - `threedgrt_tracer/tracer.py` + 1. Cast `gaussians.get_features()` to `.half()` when `conf.render.particle_feature_half`. + 2. Keep `ray_features.float()` return to caller; `ray_features` saved in ctx may be fp16 + when `feature_output_half=true` (already saves the raw output, consistent with 3dgut). + +No changes required in Slang `.slang` or generated `.cuh`: the generalization already landed +and compiles correctly once `SLANG_CUDA_ENABLE_HALF=1` is set (done). + +## Task breakdown + +Each task is independently reviewable and testable (run validate.py for the relevant flag +combinations after each). + +### T1 — Introduce typedefs in `pipelineParameters.h` +- Add `TFeatureDensityElem` and `TParticleFeatureElem` (guarded by `FEATURE_OUTPUT_HALF` and + `PARTICLE_FEATURE_HALF`), include `cuda_fp16.h` when either is set. +- Change `particleRadiance` to `const TParticleFeatureElem*` and `rayRadiance` accessor to + `PackedTensorAccessor32<TFeatureDensityElem, N>` (rank `N` unchanged). +- No functional change when both macros are 0 (typedefs resolve to `float`). +- Test: build with both flags false (current default) → no-op rebuild; CI NeRF-Synthetic 3dgrt + smoke test still passes. + +### T2 — Update OptiX kernels for fp16 reads/writes +- Apply the `__float2half` / `__half2float` wrappers in `referenceSlangOptix.cu` and + `referenceSlangBwdOptix.cu` under `FEATURE_OUTPUT_HALF`. +- Update `const_cast` sites to `TParticleFeatureElem*`. +- Test: build with both flags false → identical numerical output to baseline (no wrappers + compiled in). + +### T3 — Host buffer allocation and accessor typing +- `optixTracer.cpp`: select dtype `kHalf` vs `kFloat32` for `rayRad`; use + `packed_accessor32<TFeatureDensityElem, N>(rayRad)`. +- `getPtr<const TParticleFeatureElem>(particleRadiance)` for the particle buffer. 
+- Test: with flags false → unchanged; build-time assert that tensor dtype matches the + typedef via `TORCH_CHECK(rayRad.scalar_type() == ...)` in DEBUG. + +### T4 — Python cast for `particle_feature_half` +- `tracer.py`: mirror 3dgut's conditional `.half()` cast on `gaussians.get_features()`. +- Test: flags false → unchanged. + +### T5 — End-to-end validation with flags enabled +- Run `validate.py` with `render.particle_feature_half=true render.feature_output_half=true` + using an existing NHT config (e.g. `nerf_synthetic_3dgrt_mcmc_nht.yaml`). +- Compare PSNR after N iterations against the fp32 baseline — expected within 0.1 dB. +- Gradients: single backward pass on a fixed seed; check that + `particleRadianceGrad` and `rayRadianceGrad` are finite and within tolerance of the + fp32 reference. + +### T6 — Rename `*Radiance*` → `*Features*` in 3dgrt +Naming cleanup to align with the post-SH NHT feature abstraction. The legacy `Radiance` +suffix comes from the SH-only era; the buffers now carry arbitrary per-particle / per-ray +features. Purely mechanical rename, no behavioral change. Runs AFTER T1–T5 land so we are +not also chasing name drift during the fp16 functional work. + +Rename mapping (all scopes): +- `PipelineParameters::particleRadiance` → `particleFeatures` +- `PipelineParameters::rayRadiance` → `rayFeatures` +- `PipelineBackwardParameters::particleRadianceGrad` → `particleFeaturesGrad` +- `PipelineBackwardParameters::rayRadianceGrad` → `rayFeaturesGrad` +- `OptixTracer::trace(..., torch::Tensor particleRadiance, ...)` arg → `particleFeatures` +- `OptixTracer::traceBwd(..., torch::Tensor particleRadiance, rayRad, rayRadGrd, ...)` args + → `particleFeatures`, `rayFeat`, `rayFeatGrd` (local tensors + Python side kwargs). +- `particleRadianceGrad` local in `optixTracer.cpp::traceBwd` → `particleFeaturesGrad`. 
+- Python: `tracer.py` local variables `ray_features` / `ray_features_grd` are already + feature-named; cross-check that the pybind11 binding signature in `bindings.cpp` uses + the new C++ arg names. + +Out of scope for T6 (per resolved decisions above): +- `particleRadianceSphDegree` C++ field and `conf.render.particle_radiance_sph_degree` YAML. +- `shRadiativeParticles.slang` filename and internal `shRadiance*` identifiers (SH path). +- Any `*Radiance*` identifiers that only exist on the SH-specific code path. + +Test: +- Build + full `validate.py` run with fp32 flags (both false) → identical numerical + output to pre-T6 baseline (bit-identical expected since only identifier renames). +- Build + `validate.py` with fp16 flags (both true) → identical output to T5 result. + +## Tests to write up-front + +- `tests/test_3dgrt_half_flags.py` (new, small) + - Parametrize over `(particle_feature_half, feature_output_half) ∈ {(F,F),(T,F),(F,T),(T,T)}`. + - Forward only, single frame, fixed scene; compare `pred_features.float()` to the (F,F) + baseline with `atol=5e-3, rtol=1e-2`. + - Forward + backward; compare `mog_sph.grad` to the (F,F) baseline at the same tolerance. + +## Decisions (resolved with user) + +1. T5 validation ownership: user runs validation; the plan only needs to keep the hooks in + place (no tolerance tuning required from the implementer). +2. Gradient buffers stay fp32 end-to-end (no half-grad path). +3. T6 rename scope is restricted to identifiers naming buffers that can carry NHT features + (i.e. the per-particle feature storage and per-ray integrated feature output, plus their + fp32 gradients). Scalars and SH-specific paths are NOT renamed: + - keep `particleRadianceSphDegree` (C++ field) and `conf.render.particle_radiance_sph_degree` + (YAML) — scalar, shared with the SH path + - keep `shRadiativeParticles.slang` filename and its internal `shRadiance*` identifiers — + SH-only code path. +4. T6 runs AFTER T1–T5. 
+ +## Non-goals + +- No changes to CUDA fallback path (`gaussianParticles.cuh`) — per the existing TODO that is + a separate workstream. +- No changes to `threedgrt_playground`. diff --git a/TODO_nht_cuda.md b/TODO_nht_cuda.md new file mode 100644 index 00000000..031799a1 --- /dev/null +++ b/TODO_nht_cuda.md @@ -0,0 +1,73 @@ +# Handwritten CUDA port of `featuresIntegrateBwdToLocalGrad` (NHT path) + +## Status +- [x] **T1** — Tetrahedron constants (`tetraV0`, `tetraN0..N3`) placed in + `nht_detail` namespace at the top of `shRadiativeGaussianParticles.cuh`. + Values derived from Slang's vertex ordering; verified via script + (w_k == 1 at v_k, 0 at other vertices). +- [x] **T2** — Method body replaced, gated by `#if NHT_FEATURES_BWD_LOCAL_GRAD_CUDA`. + Default = `1` (native CUDA). Flip to `0` in + `threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh` + to restore the Slang-autodiff path (kept unchanged in the `#else` branch). +- [ ] **T3** — Rebuild, run `validate.py` (or one training step) with the macro + at 1 vs 0. Compare: + - feature gradient buffer L2 (primary parity check) + - density / position gradients (sanity; should be identical since we don't + touch those paths) + - `renderBackward` ms in nsys. +- [ ] **T4** — If parity holds, keep default = 1. Otherwise flip to 0 and iterate. + +## What the handwritten CUDA does (semantics to match Slang exactly) + +Replicates the sequence inside Slang's `particleFeaturesIntegrateBwdToBuffer` +called with `exclusiveGradient=true` and the shifted `featureLocalGrad` buffer: + +1. Early-out when `alpha <= 0`. +2. Recover pre-hit accumulator: + `acc_prev[i] = (integratedFeatures[i] - features[i]*alpha) / (1-alpha)`. +3. 
VJP of back-to-front `y_i = (1-alpha)*acc_prev_i + alpha*f_i` against + incoming `dy = integratedFeaturesGrad`: + - `dFeatures[i] = alpha * dy_i` + - `alphaGrad += sum_i (features[i] - acc_prev[i]) * dy_i` + - `integratedFeaturesGrad[i] = (1-alpha) * dy_i` (new accumulator grad) +4. Barycentric weights `w[0..3]` from `canonicalIntersection` (Cramer form + matching Slang, precomputed `N_k` face normals). +5. Load 4 vertex feature blocks × `InterpPointFeatureDim` once + (`__half2float` when `PARTICLE_FEATURE_HALF=1`). +6. Activation backward → `dBase[InterpPointFeatureDim]`: + | Activation | Forward | Backward | + |---|---|---| + | None (0) | `out = base` | `dBase = dFeatures` | + | Siren (1) | `sin(base * 2^f)` | `dBase += cos(base*freq) * freq * dOut` | + | Sincos (2) | `sin + cos` | `dBase += (cos - sin) * freq * dOut` | + | Relu (3) | `max(0, base)` | `dBase = (features[i] > 0) ? dFeatures[i] : 0` | +7. Barycentric backward: + - `featureLocalGrad[k*IPFD + i] += w[k] * dBase[i]` (matches Slang's `+=` with exclusiveGradient=true) + - `canonicalIntersectionGrad += sum_k (sum_i vert[k][i] * dBase[i]) * N_k` + +## Guardrails +- `static_assert(FeatureTransformType == 1)` — NHT-only. +- `static_assert(FEATURE_INTERPOLATION_TYPE == 0)` — barycentric only. +- `static_assert(FEATURE_INTERPOLATION_SUPPORT == 1)` — tetrahedra only. +- `static_assert` on `RAY_FEATURE_DIM` / `INTERP_POINT_FEATURE_DIM` / activation consistency. +- `static_assert(4 * IPFD == ParticleFeatureDim)` — buffer layout. + +Any unsupported config fails at compile time — fallback is to flip the macro to 0. + +## Confidence + +- **Forward parity** (interpolation + integration, current config `activation=relu`): + high (see comparison with `neural-harmonic-textures/Interpolation.cuh` — same + tetrahedron geometry, different indexing; same integration math). +- **Backward numerical parity**: medium-high. The Relu path is trivial. The + (1-α)/α lerp VJP + barycentric VJP is standard. 
Main risk is a sign or + vertex-index swap — covered by T3 gradient diff. +- **Perf win**: medium-high. Expected 3–5× on this single kernel. + +## Open reference points + +- Slang source: `threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang` +- External CUDA ref: `/nv/dev/neural-harmonic-textures/gsplat/gsplat/cuda/csrc/RasterizeToPixelsFromWorldNHT3DGSBwd.cu` + (sincos activation; do NOT copy the activation bwd verbatim — see + "Caveats" in the forward-parity discussion: Slang's sincos sums into one + channel, ref's keeps them separate). diff --git a/configs/apps/colmap_3dgrt_mcmc_nht.yaml b/configs/apps/colmap_3dgrt_mcmc_nht.yaml new file mode 100644 index 00000000..74b9fc55 --- /dev/null +++ b/configs/apps/colmap_3dgrt_mcmc_nht.yaml @@ -0,0 +1,23 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for colmap datasets with 3DGRT and MCMC + +defaults: + - /base_mcmc + - /dataset: colmap + - /initialization: colmap + - /render: 3dgrt + - _self_ + +model: + feature_type: "nht" + +render: + pipeline_type: referenceSlang + backward_pipeline_type: referenceSlangBwd + particle_kernel_max_alpha: 0.999 + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/colmap_3dgut_mcmc_nht.yaml b/configs/apps/colmap_3dgut_mcmc_nht.yaml new file mode 100644 index 00000000..551c8e0e --- /dev/null +++ b/configs/apps/colmap_3dgut_mcmc_nht.yaml @@ -0,0 +1,21 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for colmap datasets with 3DGUT and MCMC + +defaults: + - /base_mcmc + - /dataset: colmap + - /initialization: colmap + - /render: 3dgut + - _self_ + +model: + feature_type: "nht" + +render: + particle_kernel_max_alpha: 0.999 + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml new file mode 100644 index 
00000000..9230ba43 --- /dev/null +++ b/configs/apps/nerf_synthetic_3dgrt_mcmc_nht.yaml @@ -0,0 +1,23 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for nerf_synthetic with 3DGRT and MCMC + +defaults: + - /base_mcmc + - /dataset: nerf + - /initialization: random + - /render: 3dgrt + - _self_ + +model: + feature_type: "nht" + +render: + pipeline_type: referenceSlang + backward_pipeline_type: referenceSlangBwd + particle_kernel_max_alpha: 0.999 + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml new file mode 100644 index 00000000..3b06d71d --- /dev/null +++ b/configs/apps/nerf_synthetic_3dgut_mcmc_nht.yaml @@ -0,0 +1,21 @@ +# @package _global_ +# NHT (Neural Harmonic Textures) variant for nerf_synthetic with 3DGUT and MCMC + +defaults: + - /base_mcmc + - /dataset: nerf + - /initialization: random + - /render: 3dgut + - _self_ + +model: + feature_type: "nht" + +render: + particle_kernel_max_alpha: 0.999 + +loss: + use_opacity: true + lambda_opacity: 0.02 + use_scale: true + lambda_scale: 0.005 diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 51248502..73703d87 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -56,17 +56,46 @@ model: optimize_density: true optimize_features_albedo: true optimize_features_specular: true + optimize_features: true optimize_position: true optimize_rotation: true optimize_scale: true bvh_update_frequency: 1 + feature_type: "sh" # one of [sh, nht] # sh: spherical harmonics radiance, nht: neural harmonics texture progressive_training: - feature_type: "sh" # one of [sh, mlp], currently only sh supported init_n_features: 0 # sh: initial sh deg | mlp: num of dims initially unmasked max_n_features: 3 # sh: maximum sh deg | mlp: total num of dims finally unmasked increase_frequency: 1000 # unmask more feature dimensions every N global steps 
increase_step: 1 # sh: how many degrees unmasked per step | mlp: how many dims unmasked per step + nht_features: + dim: 48 + init_min: -1.5707963267948966 + init_max: 1.5707963267948966 + activation: + type: "sincos" + num_frequencies: 1 + interpolation_type: "barycentric" + + nht_decoder: + enabled: true + hidden_dim: 128 + num_layers: 3 + dir_encoding: "SphericalHarmonics" + dir_encoding_degree: 3 + sh_scale: 3.0 + output_activation: "Sigmoid" + unpremultiply_alpha: false + learning_rate: 0.00068 + reg_weight: 0.0 + scheduler: + type: "cosine" + decay_final: 0.1 + max_steps: 30000 + ema_decay: 0.95 + ema_start_step: 0 + color_refine_steps: 3000 + background: name: background-color # one of [skip-background, background-color] color: black # one of [black, white, random] needs to be defined if name == background-color @@ -89,6 +118,8 @@ optimizer: lr: 0.0025 # 3DGS value: 0.0025 features_specular: lr: ${div:${optimizer.params.features_albedo.lr},20} # 3DGS value 20x smaller than lr of features_albedo + features: + lr: 0.015 rotation: lr: 0.001 # 3DGS value: 0.001 scale: @@ -106,6 +137,11 @@ scheduler: density: type: skip + features: + type: cosine + decay_final: 0.1 + max_steps: 30000 + loss: use_l1: true lambda_l1: 0.8 diff --git a/configs/render/3dgrt.yaml b/configs/render/3dgrt.yaml index 655cd3a8..6f664727 100644 --- a/configs/render/3dgrt.yaml +++ b/configs/render/3dgrt.yaml @@ -14,3 +14,5 @@ max_consecutive_bvh_update: 15 enable_normals: false enable_hitcounts: true enable_kernel_timings: false +particle_feature_half: false +feature_output_half: false diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh index f799b647..8bcbf56b 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/3dgrtTracer.cuh @@ -38,7 +38,8 @@ struct RayHit { using RayPayload = RayHit[PipelineParameters::MaxNumHitPerTrace]; struct RayData { 
- float3 radiance; + float features[RAY_FEATURE_DIM]; + float density; float3 normal; float hitDistance; @@ -46,12 +47,16 @@ struct RayData { float hitCount; // TODO (operel): convert to uint32 __device__ void initialize() { - radiance = make_float3(0.0f); - density = 0.0; - normal = make_float3(0.f); - hitDistance = 0.f; + // Zero-initialize all RAY_FEATURE_DIM entries of the per-ray feature accumulator + #pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; ++i) { + features[i] = 0.0f; + } + density = 0.0; + normal = make_float3(0.f); + hitDistance = 0.f; rayLastHitDistance = 0.f; - hitCount = 0.f; + hitCount = 0.f; } }; @@ -126,7 +131,7 @@ static __device__ __inline__ void trace( rayPayload[15].distance = __uint_as_float(r31); } -/* Traces Gaussians along ray and accumulates radiance into rayData. +/* Traces Gaussians along ray and accumulates features into rayData. * ray is bounded by both (tmin, tmax) and launch params.aabb */ static __device__ __inline__ void traceVolumetricGS( @@ -142,10 +147,10 @@ static __device__ __inline__ void traceVolumetricGS( } float rayTransmittance = 1.0f - rayData.density; - float2 minMaxT = intersectAABB(params.aabb, rayOrigin, rayDirection); - minMaxT.x = fmaxf(minMaxT.x, tmin); - minMaxT.y = fminf(minMaxT.y, tmax); - constexpr float epsT = 1e-9; + float2 minMaxT = intersectAABB(params.aabb, rayOrigin, rayDirection); + minMaxT.x = fmaxf(minMaxT.x, tmin); + minMaxT.y = fminf(minMaxT.y, tmax); + constexpr float epsT = 1e-9; float rayLastHitDistance = fmaxf(0.0f, minMaxT.x - epsT); RayPayload rayPayload; @@ -165,7 +170,7 @@ static __device__ __inline__ void traceVolumetricGS( rayOrigin, rayDirection, rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, &rayTransmittance, &rayData.hitDistance, #ifdef ENABLE_NORMALS @@ -175,11 +180,15 @@ static __device__ __inline__ void traceVolumetricGS( #endif ); - particleFeaturesIntegrateFwdFromBuffer(rayDirection, - 
hitWeight, - rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr}, params.sphDegree}, - &rayData.radiance); + // Call generic Slang wrapper (no conditionals) + // The wrapper handles CommonParameters construction internally + particleFeaturesIntegrateFwdGeneric( + rayDirection, + hitWeight, + rayHit.particleId, + params.particleFeatures, // void* - generic buffer pointer + params.sphDegree, // auxiliary parameter (sphDegree for SH, unused for learned) + rayData.features); // float* - generic output array rayLastHitDistance = fmaxf(rayLastHitDistance, rayHit.distance); @@ -190,7 +199,7 @@ static __device__ __inline__ void traceVolumetricGS( } } - rayData.density = 1 - rayTransmittance; + rayData.density = 1 - rayTransmittance; rayData.rayLastHitDistance = rayLastHitDistance; } @@ -205,7 +214,7 @@ static __device__ __inline__ void intersectVolumetricGS() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh index 089e4c4c..73afa688 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh +++ b/threedgrt_tracer/include/3dgrt/kernels/cuda/gaussianParticles.cuh @@ -13,6 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This file contains SH-specific CUDA backward pass helpers. +// For learned features mode (FEATURE_TRANSFORM_TYPE != 0), this file compiles but provides no functions. +// Learned features use Slang autodiff instead of these CUDA helpers. 
+ #include #include <3dgrt/mathUtils.h> @@ -22,6 +26,9 @@ typedef int int32_t; typedef unsigned int uint32_t; +// Only define SH-specific functions for SH mode +#if !defined(FEATURE_TRANSFORM_TYPE) || FEATURE_TRANSFORM_TYPE == 0 + void quaternionWXYZToMatrix(const float4& q, float33& ret) { const float r = q.x; const float x = q.y; @@ -365,7 +372,7 @@ __device__ inline bool processHit( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); const bool acceptHit = (gres > minParticleKernelDensity) && (galpha > minParticleAlpha); if (acceptHit) { @@ -506,7 +513,7 @@ __device__ inline void processHitBwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); if ((gres > minParticleKernelDensity) && (galpha > minParticleAlpha)) { ParticleDensity& particleDensityGrad = particleDensityGradPtr[particleIdx]; @@ -548,7 +555,7 @@ __device__ inline void processHitBwd( // => groRayHitGrd_j = -grd_j * dot(grdsRayHitGrd * gscl, grd) const float grdScaledDot = dot(grdsRayHitGrd * gscl, grd); float3 grdRayHitGrd, groRayHitGrd; - if constexpr (SurfelPrimitive) { + if (SurfelPrimitive) { const float h = -gro.z / grd.z; grdRayHitGrd = gscl * grdsRayHitGrd * h - make_float3(0.f, 0.f, (h / grd.z) * grdScaledDot); groRayHitGrd = make_float3(0.f, 0.f, -grdScaledDot / grd.z); @@ -723,3 +730,5 @@ __device__ inline void processHitBwd( transmittance = nextTransmit; } } + +#endif // FEATURE_TRANSFORM_TYPE == 0 diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang index 04e9d413..d8568b26 100644 --- 
a/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/gaussianParticles.slang @@ -185,6 +185,17 @@ struct Parameters : IDifferentiable { return sqrt(dot(grds, grds)); } +[BackwardDifferentiable] [ForceInline] float canonicalRayIntersection( + float3 canonicalRayOrigin, + float3 canonicalRayDirection, + float3 scale, + out float3 canonicalIntersection) { + const float3 canonicalGrds = canonicalRayDirection * dot(canonicalRayDirection, -1 * canonicalRayOrigin); + canonicalIntersection = canonicalRayOrigin + canonicalGrds; + const float3 grds = scale * canonicalGrds; + return sqrt(dot(grds, grds)); +} + [BackwardDifferentiable][ForceInline] float3 canonicalRayNormal( float3 canonicalRayOrigin, float3 canonicalRayDirection, @@ -207,6 +218,7 @@ bool hit( Parameters parameters, out float alpha, inout float depth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 normal) { @@ -227,7 +239,7 @@ bool hit( const bool acceptHit = ((maxResponse > MinParticleKernelDensity) && (alpha > MinParticleAlpha)); if (acceptHit) { - depth = canonicalRayDistance(canonicalRayOrigin, canonicalRayDirection, parameters.scale); + depth = canonicalRayIntersection(canonicalRayOrigin, canonicalRayDirection, parameters.scale, canonicalIntersection); if (enableNormal) { normal = canonicalRayNormal(canonicalRayOrigin, canonicalRayDirection, parameters.scale, parameters.rotationT); @@ -276,6 +288,7 @@ float processHitFromBuffer( no_diff RawParametersBuffer parametersBuffer, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 integratedNormal) { @@ -287,6 +300,7 @@ float processHitFromBuffer( fetchParameters(particleIdx, parametersBuffer), alpha, depth, + canonicalIntersection, enableNormal, normal)) { @@ -320,7 +334,7 @@ float3 incidentDirectionFromParameters( } [BackwardDifferentiable][ForceInline] -float3 
incidentDirectionFromBuffer( +no_diff float3 incidentDirectionFromBuffer( no_diff uint32_t particleIdx, no_diff RawParametersBuffer parametersBuffer, no_diff float3 sourcePosition @@ -353,6 +367,7 @@ inline bool particleDensityHit( gaussianParticle.Parameters parameters, out float alpha, out float depth, + out float3 canonicalIntersection, bool enableNormal, out float3 normal) { @@ -361,6 +376,7 @@ inline bool particleDensityHit( parameters, alpha, depth, + canonicalIntersection, enableNormal, normal); } @@ -392,6 +408,7 @@ inline float particleDensityProcessHitFwdFromBuffer( gaussianParticle.CommonParameters commonParameters, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, in bool enableNormal, inout float3 integratedNormal) { @@ -402,6 +419,7 @@ inline float particleDensityProcessHitFwdFromBuffer( commonParameters.parametersBuffer, transmittance, integratedDepth, + canonicalIntersection, enableNormal, integratedNormal); } @@ -419,6 +437,7 @@ void particleDensityProcessHitBwdToBuffer( in float depth, inout float integratedDepth, inout float integratedDepthGrad, + in float3 canonicalIntersectionGrad, bool enableNormal, in float3 normal, inout float3 integratedNormal, @@ -452,6 +471,7 @@ void particleDensityProcessHitBwdToBuffer( commonParameters.parametersBuffer, transmittanceDiff, integratedDepthDiff, + canonicalIntersectionGrad, enableNormal, integratedNormalDiff, alphaGrad); @@ -536,10 +556,7 @@ bool particleDensityHitInstance( in float3 incidentDirectionGrad ) { - bwd_diff(gaussianParticle.incidentDirectionFromBuffer)( - particleIdx, - commonParameters.parametersBuffer, - sourcePosition, - incidentDirectionGrad - ); + // 3DGRT: incidentDirectionFromBuffer returns no_diff; position gradient + // from incident direction is not propagated back in the 3DGRT path. 
} + diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang new file mode 100644 index 00000000..efce40a9 --- /dev/null +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -0,0 +1,320 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use it except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Neural Harmonic Features: per-particle K-dim features, output N (decoder input); decoder maps N -> RGB. +// Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. +// Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, +// PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. +// With sincos activation: N = interpPointDim * num_frequencies * 2 (separate sin/cos channels). 
+
+// Compile-time configuration of the neural harmonic features particle model.
+// Every constant in this namespace is fixed by a preprocessor define.
+namespace neuralHarmonicFeaturesParticle
+{
+static const int ParticleFeatureDim = PARTICLE_FEATURE_DIM; // Total per-particle (K); not per-interpolation-point. Used as the per-particle stride when indexing the feature buffer.
+static const int RayFeatureDim = RAY_FEATURE_DIM; // Decoder input N. The activation code writes INTERP_POINT_FEATURE_DIM entries for none/relu, INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES for siren, and twice that for sincos.
+static const int InterpPointFeatureDim = INTERP_POINT_FEATURE_DIM; // Per-interpolation-point dimension
+// Feature activation selector values (compared against FEATURE_ACTIVATION_TYPE).
+static const int FeatureActivationType_None = 0;
+static const int FeatureActivationType_Siren = 1;
+static const int FeatureActivationType_Sincos = 2;
+static const int FeatureActivationType_Relu = 3;
+static const int FeatureActivationType = FEATURE_ACTIVATION_TYPE;
+static const int FeatureActivationNumFrequencies = FEATURE_ACTIVATION_NUM_FREQUENCIES;
+
+// Interpolation type (compile-time from FEATURE_INTERPOLATION_TYPE)
+static const int InterpolationType_Barycentric = 0;
+static const int InterpolationType_Bezier = 1; // Not supported yet
+static const int InterpolationType = FEATURE_INTERPOLATION_TYPE;
+
+// Interpolation support (compile-time from FEATURE_INTERPOLATION_SUPPORT)
+static const int InterpolationSupport_Center = 0;
+static const int InterpolationSupport_Tetrahedra = 1;
+static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet
+static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT;
+
+// Canonical regular tetrahedron matching the GSplat NHT reference layout:
+// p0=(sqrt(6),-sqrt(2),-1), p1=(-sqrt(6),-sqrt(2),-1), p2=(0,2*sqrt(2),-1), p3=(0,0,3).
+// Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1.
+static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24)
+static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2
+static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3)
+static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2)
+static const float tetraHedraInRadius = 1.0f; // unit sphere
+// Vertices p0..p3 expressed through the constants above; the base triangle lies in z=-1.
+static const float3 canonicalTetraVerts[4] = {
+    float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f),
+    float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f),
+    float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f),
+    float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius)
+};
+// Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical.
+// E1..E3 are the edges from vertex 0; Det is the scalar triple product E1 . (E2 x E3).
+static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0];
+static const float3 canonicalTetraE2 = canonicalTetraVerts[2] - canonicalTetraVerts[0];
+static const float3 canonicalTetraE3 = canonicalTetraVerts[3] - canonicalTetraVerts[0];
+static const float3 canonicalTetraCrossE2E3 = cross(canonicalTetraE2, canonicalTetraE3);
+static const float canonicalTetraDet = dot(canonicalTetraE1, canonicalTetraCrossE2E3);
+static const float canonicalTetraInvDet = 1.0f / canonicalTetraDet;
+
+};
+
+// Element type of the per-particle feature buffer: half when PARTICLE_FEATURE_HALF is non-zero, float otherwise.
+#if PARTICLE_FEATURE_HALF
+typedef half feat_elem_t;
+#else
+typedef float feat_elem_t;
+#endif
+
+namespace neuralHarmonicFeaturesParticle
+{
+
+// Handle to the per-particle feature storage and its gradient accumulator.
+struct ParametersBuffer
+{
+    Ptr _dataPtr; // [N_particles, K] flat; fp16 when PARTICLE_FEATURE_HALF
+    float *_gradPtr; // gradient accumulator, indexed with the same offsets as _dataPtr; always float
+    bool exclusiveGradient; // true: plain += into _gradPtr; false: InterlockedAdd
+};
+
+// Features of a single interpolation point (the differentiable value fetched from the buffer).
+struct Parameters : IDifferentiable
+{
+    Array features;
+};
+
+// Reads the InterpPointFeatureDim features of interpolation point `interpPointIdx` of particle `particleIdx`.
+// Element offset: particleIdx * ParticleFeatureDim + interpPointIdx * InterpPointFeatureDim.
+[BackwardDifferentiable][ForceInline]
+Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx,
+                                     no_diff int interpPointIdx,
+                                     no_diff ParametersBuffer parametersBuffer)
+{
+    Parameters parameters;
+    const uint32_t particleOffset = particleIdx * ParticleFeatureDim;
+    const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim;
+    [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) {
+        parameters.features[i] = parametersBuffer._dataPtr[interpPointOffset + i];
+    }
+    return parameters;
+}
+
+// Custom backward of fetchParametersFromBuffer: adds the incoming feature gradient into
+// parametersBuffer._gradPtr at the same offsets the forward pass read from.
+[BackwardDerivativeOf(fetchParametersFromBuffer)][ForceInline]
+void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx,
+                                  no_diff int interpPointIdx,
+                                  no_diff ParametersBuffer parametersBuffer,
+                                  Parameters parametersGrad)
+{
+    const uint32_t particleOffset = particleIdx * ParticleFeatureDim;
+    const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim;
+    [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) {
+        const float grad = parametersGrad.features[i];
+        if (parametersBuffer.exclusiveGradient) {
+            // Non-atomic accumulation; presumably the caller guarantees exclusive access to this slot (confirm).
+            parametersBuffer._gradPtr[interpPointOffset + i] += grad;
+        } else {
+            InterlockedAdd(parametersBuffer._gradPtr[interpPointOffset + i], grad);
+        }
+    }
+}
+
+// Barycentric coordinates for the canonical tetrahedron only. Uses precomputed static const e1,e2,e3,invDet (independent of P).
+// Solves P - v0 = y*E1 + z*E2 + w*E3 by Cramer's rule; x = 1 - y - z - w is the weight of vertex 0.
+[BackwardDifferentiable][ForceInline]
+float4 barycentricTetrahedronCanonical(float3 P)
+{
+    float3 d = P - canonicalTetraVerts[0];
+    float4 weights;
+    weights.y = dot(d, canonicalTetraCrossE2E3) * canonicalTetraInvDet;
+    weights.z = dot(canonicalTetraE1, cross(d, canonicalTetraE3)) * canonicalTetraInvDet;
+    weights.w = dot(canonicalTetraE1, cross(canonicalTetraE2, d)) * canonicalTetraInvDet;
+    weights.x = 1.0f - weights.y - weights.z - weights.w;
+    return weights;
+}
+
+// Encode and activate: none -> identity; siren -> sin(b*2^f); relu -> max(0,b).
+[BackwardDifferentiable][ForceInline]
+float encodeAndActivate(float baseVal, no_diff int f)
+{
+    if (FeatureActivationType == FeatureActivationType_None)
+        return baseVal;
+    if (FeatureActivationType == FeatureActivationType_Relu)
+        return max(0.0f, baseVal);
+    // Every remaining activation type reaches this point: sin(baseVal * 2^f).
+    // Note the frequency here is 2^f, whereas the sincos path in featuresFromParametersBuffer uses (f + 1).
+    float freq = ldexp(1.0f, f);
+    float angle = baseVal * freq;
+    return sin(angle);
+}
+
+// Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim].
+// canonicalPosition is differential (hit position in particle canonical space; gradient accumulated in API canonicalPositionGrad).
+// Output layout written per activation type:
+//   none / relu : features[i],                         i < InterpPointFeatureDim
+//   sincos      : features[(k * F + f) * 2 + {0,1}] = {sin, cos}(base[k] * (f + 1))
+//   otherwise   : features[k * F + f] = encodeAndActivate(base[k], f)
+// where F = FeatureActivationNumFrequencies. Entries of `features` beyond those indices are not written here.
+[BackwardDifferentiable][ForceInline]
+void featuresFromParametersBuffer(ParametersBuffer parametersBuffer,
+                                  no_diff uint32_t particleIdx,
+                                  float3 canonicalPosition,
+                                  out Array features
+)
+{
+    // Start from interpolation point 0; without the tetrahedra/barycentric branch below this is the final base feature.
+    Array baseFeatures = fetchParametersFromBuffer(particleIdx, 0, parametersBuffer).features;
+    if (InterpolationSupport == InterpolationSupport_Tetrahedra && InterpolationType == InterpolationType_Barycentric) {
+        // Weighted sum of the four vertex features using the barycentric weights of canonicalPosition.
+        float4 barycentricWeights = barycentricTetrahedronCanonical(canonicalPosition);
+        [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) {
+            baseFeatures[n] *= barycentricWeights[0];
+        }
+        [ForceUnroll] for (int k = 1; k < 4; ++k) {
+            Parameters parameters = fetchParametersFromBuffer(particleIdx, k, parametersBuffer);
+            [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) {
+                baseFeatures[n] += barycentricWeights[k] * parameters.features[n];
+            }
+        }
+    }
+
+    if (FeatureActivationType == FeatureActivationType_None) {
+        [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) {
+            features[i] = baseFeatures[i];
+        }
+    } else if (FeatureActivationType == FeatureActivationType_Relu) {
+        [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) {
+            features[i] = max(0.0f, baseFeatures[i]);
+        }
+    } else if (FeatureActivationType == FeatureActivationType_Sincos) {
+        [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) {
+            [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) {
+                // Linear frequencies 1..F; sin and cos are stored in adjacent channels.
+                float freq = float(f + 1);
+                float angle = baseFeatures[k] * freq;
+                int outIdx = k * FeatureActivationNumFrequencies * 2 + f * 2;
+                features[outIdx + 0] = sin(angle);
+                features[outIdx + 1] = cos(angle);
+            }
+        }
+    } else {
+        [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) {
+            [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) {
+                features[k * FeatureActivationNumFrequencies + f] = encodeAndActivate(baseFeatures[k], f);
+            }
+        }
+    }
+}
+
+// Blends one particle's features into the running per-ray accumulator; a no-op when weight <= 0.
+// backToFront is not declared in the visible part of this file -- presumably a compile-time flag defined elsewhere (confirm).
+[BackwardDifferentiable][ForceInline]
+void integrateFeatures(float weight,
+                       in Array features,
+                       inout float integratedFeatures[RayFeatureDim])
+{
+    if (weight > 0.0f) {
+        [ForceUnroll] for (int i = 0; i < RayFeatureDim; ++i) {
+            if (backToFront)
+                // integrated = integrated * (1 - weight) + features * weight
+                integratedFeatures[i] = lerp(integratedFeatures[i], features[i], weight);
+            else
+                integratedFeatures[i] += features[i] * weight;
+        }
+    }
+}
+
+// Fetches and activates the particle's features, then integrates them with `weight`; skipped entirely when weight <= 0.
+[BackwardDifferentiable][ForceInline]
+void integrateFeaturesFromBuffer(float weight,
+                                 no_diff uint32_t particleIdx,
+                                 ParametersBuffer parametersBuffer,
+                                 float3 canonicalPosition,
+                                 inout float integratedFeatures[RayFeatureDim])
+{
+    if (weight > 0.0f) {
+        Array features;
+        featuresFromParametersBuffer(parametersBuffer, particleIdx, canonicalPosition, features);
+        integrateFeatures(weight, features, integratedFeatures);
+    }
+}
+
+} // namespace neuralHarmonicFeaturesParticle
+
+// ------------------------------------------------------------------------------------------------------------------
+// Entry points - same CudaDeviceExport API as shRadiativeParticles.slang
+
+// Forward: evaluate the activated features of one particle at canonicalPosition.
+// auxParam and incidentDirection are accepted but not read by this implementation.
+[CudaDeviceExport]
+inline void particleFeaturesFromBuffer(
+    in uint32_t particleIdx,
+    in feat_elem_t *featuresBufferPtr,
+    in int auxParam,
+    in float3 incidentDirection,
+    in float3 canonicalPosition,
+    out float features[neuralHarmonicFeaturesParticle.RayFeatureDim])
+{
+    neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer;
+    parametersBuffer._dataPtr = (Ptr)featuresBufferPtr;
+    // Forward-only call: no gradient buffer is attached.
+    parametersBuffer._gradPtr = nullptr;
+    parametersBuffer.exclusiveGradient = false;
+
+    neuralHarmonicFeaturesParticle.featuresFromParametersBuffer(
+        parametersBuffer,
+        particleIdx,
+        canonicalPosition,
+        features
+    );
+}
+
+// Forward: evaluate one particle's features and integrate them into integratedFeatures with `weight`.
+// auxParam and incidentDirection are accepted but not read by this implementation.
+[CudaDeviceExport]
+inline void particleFeaturesIntegrateFwdFromBuffer(
+    in float3 incidentDirection,
+    in float3 canonicalPosition,
+    in float weight,
+    in uint32_t particleIdx,
+    in feat_elem_t *featuresBufferPtr,
+    in int auxParam,
+    inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim])
+{
+    neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer;
+    parametersBuffer._dataPtr = (Ptr)featuresBufferPtr;
+    // Forward-only call: no gradient buffer is attached.
+    parametersBuffer._gradPtr = nullptr;
+    parametersBuffer.exclusiveGradient = false;
+
+    neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer(
+        weight, particleIdx, parametersBuffer, canonicalPosition, integratedFeatures);
+}
+
+// canonicalPosition is differential; canonicalPositionGrad is inout and accumulated by this backward.
+// Backward of the per-particle feature integration step. When alpha > 0 it:
+//   1. rewinds integratedFeatures to its value before this particle was blended in,
+//   2. runs bwd_diff of integrateFeaturesFromBuffer, which accumulates the feature gradient into featuresBufferGradPtr,
+//   3. overwrites integratedFeaturesGrad and adds to canonicalPositionGrad and alphaGrad.
+// incidentDirection and auxParam are accepted but not read by this implementation.
+[CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer(
+    in float3 incidentDirection,
+    in float3 canonicalPosition,
+    inout float3 canonicalPositionGrad,
+    in float alpha,
+    inout float alphaGrad,
+    in uint32_t particleIdx,
+    in feat_elem_t *featuresBufferPtr,
+    in float *featuresBufferGradPtr,
+    in int auxParam,
+    in bool exclusiveGradient,
+    in float features[neuralHarmonicFeaturesParticle.RayFeatureDim],
+    inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim],
+    inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim])
+{
+    if (alpha > 0.0f)
+    {
+        neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer;
+        parametersBuffer._dataPtr = (Ptr)featuresBufferPtr;
+        parametersBuffer._gradPtr = (Ptr)featuresBufferGradPtr;
+        parametersBuffer.exclusiveGradient = exclusiveGradient;
+
+        DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad);
+
+        // Invert integrated = prev * (1 - alpha) + features * alpha to recover prev.
+        // Divides by (1 - alpha): assumes alpha < 1 -- presumably clamped upstream (confirm).
+        const float weight = 1.0f / (1.0f - alpha);
+        [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) {
+            integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight;
+        }
+
+        DifferentialPair integratedFeaturesDiff =
+            DifferentialPair(integratedFeatures, integratedFeaturesGrad);
+
+        DifferentialPair canonicalPositionDiff = DifferentialPair(canonicalPosition, canonicalPositionGrad);
+
+        // alpha takes the role of the `weight` argument of integrateFeaturesFromBuffer.
+        bwd_diff(neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer)(
+            alphaDiff,
+            particleIdx,
+            parametersBuffer,
+            canonicalPositionDiff,
+            integratedFeaturesDiff);
+
+        // integratedFeaturesGrad is replaced; the position and alpha gradients are added to the caller's running values.
+        integratedFeaturesGrad = integratedFeaturesDiff.getDifferential();
+        canonicalPositionGrad += canonicalPositionDiff.getDifferential();
+        alphaGrad += alphaDiff.getDifferential();
+    }
+}
diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang
new file mode 100644
index 00000000..970005eb
--- /dev/null
+++
b/threedgrt_tracer/include/3dgrt/kernels/slang/models/radiativeParticles.slang @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Radiative Particles Wrapper +// Conditionally includes the appropriate feature implementation based on FEATURE_TRANSFORM_TYPE + +#if FEATURE_TRANSFORM_TYPE == 0 + // Spherical Harmonics mode + #include <3dgrt/kernels/slang/models/shRadiativeParticles.slang> +#elif FEATURE_TRANSFORM_TYPE == 1 + // Post-MLP radiance mode + #include <3dgrt/kernels/slang/models/neuralHarmonicFeaturesParticle.slang> +#else + #error "Unknown FEATURE_TRANSFORM_TYPE. 
Must be 0 (SH) or 1 (neural_harmonic_features)" +#endif diff --git a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang index f2755a60..1e110202 100644 --- a/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgrt_tracer/include/3dgrt/kernels/slang/models/shRadiativeParticles.slang @@ -31,12 +31,6 @@ struct ParametersBuffer bool exclusiveGradient; //< true if the gradient maybe updated without atomics }; -struct CommonParameters -{ - ParametersBuffer parametersBuffer; - int sphDegree; -}; - struct Parameters : IDifferentiable { vector sphCoefficients[RadianceMaxNumSphCoefficients]; @@ -48,7 +42,7 @@ Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, { Parameters parameters; const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { parameters.sphCoefficients[i] = parametersBuffer._dataPtr[particleOffset + i]; } return parameters; @@ -60,14 +54,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, Parameters parametersGrad) { const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { const vector coeffs = parametersGrad.sphCoefficients[i]; if (parametersBuffer.exclusiveGradient) { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { parametersBuffer._gradPtr[particleOffset + i][j] += coeffs[j]; } } else { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { InterlockedAdd(parametersBuffer._gradPtr[particleOffset + i][j], coeffs[j]); } } @@ -145,41 +139,58 @@ void integrateRadianceFromBuffer(no_diff 
float3 incident // Entry points [CudaDeviceExport] -inline vector particleFeaturesFromBuffer(in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in float3 incidentDirection) +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[shRadiativeParticle.Dim]) { - return sphericalHarmonics.decode( - commonParameters.sphDegree, - shRadiativeParticle.fetchParametersFromBuffer(particleIdx, commonParameters.parametersBuffer).sphCoefficients, + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector featuresVec = sphericalHarmonics.decode( + auxParam, + shRadiativeParticle.fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, incidentDirection); + + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + features[i] = featuresVec[i]; + } } [CudaDeviceExport] -inline void particleFeaturesIntegrateFwd(in float weight, - in vector features, - inout vector integratedFeatures) +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + inout float integratedFeatures[shRadiativeParticle.Dim]) { - shRadiativeParticle.integrateRadiance( - weight, - features, - integratedFeatures - ); -} + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; -[CudaDeviceExport] inline void particleFeaturesIntegrateFwdFromBuffer(in float3 incidentDirection, - in float weight, - in uint32_t particleIdx, - 
shRadiativeParticle.CommonParameters commonParameters, - inout vector integratedFeatures) -{ + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + } shRadiativeParticle.integrateRadianceFromBuffer( incidentDirection, - commonParameters.sphDegree, + auxParam, weight, particleIdx, - commonParameters.parametersBuffer, - integratedFeatures); + parametersBuffer, + integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } [CudaDeviceExport] void particleFeaturesIntegrateBwd( @@ -214,47 +225,81 @@ inline void particleFeaturesIntegrateFwd(in float weight, [CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, in float alpha, inout float alphaGrad, in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector features, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if (alpha > 0.0f) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr, Access.ReadWrite>)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + 
vector integratedFeaturesVec; + vector integratedFeaturesGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesVec[i] = integratedFeatures[i]; + integratedFeaturesGradVec[i] = integratedFeaturesGrad[i]; + } DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedFeaturesVec, integratedFeaturesGradVec); bwd_diff(shRadiativeParticle.integrateRadianceFromBuffer)( incidentDirection, - commonParameters.sphDegree, + auxParam, alphaDiff, particleIdx, - commonParameters.parametersBuffer, + parametersBuffer, integratedFeaturesDiff); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); - alphaGrad = alphaDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + alphaGrad += alphaDiff.getDifferential(); } } [CudaDeviceExport] void particleFeaturesBwdToBuffer( in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector featuresGrad, - in float3 incidentDirection + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float featuresGrad[shRadiativeParticle.Dim], + in float3 incidentDirection, + inout float3 incidentDirectionGrad ) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr, Access.ReadWrite>)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + vector featuresGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) featuresGradVec[i] = featuresGrad[i]; + + // incidentDirection is no_diff in radianceFromBuffer for 3DGRT; pass as plain value. 
bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, incidentDirection, - commonParameters.sphDegree, - commonParameters.parametersBuffer, - featuresGrad); + auxParam, + parametersBuffer, + featuresGradVec); } diff --git a/threedgrt_tracer/include/3dgrt/optixTracer.h b/threedgrt_tracer/include/3dgrt/optixTracer.h index f84463b3..13469eb3 100644 --- a/threedgrt_tracer/include/3dgrt/optixTracer.h +++ b/threedgrt_tracer/include/3dgrt/optixTracer.h @@ -144,7 +144,7 @@ class OptixTracer { torch::Tensor rayOri, torch::Tensor rayDir, torch::Tensor particleDensity, - torch::Tensor particleRadiance, + torch::Tensor particleFeatures, uint32_t renderOpts, int sphDegree, float minTransmittance); @@ -153,13 +153,13 @@ class OptixTracer { torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor rayRad, + torch::Tensor rayFeat, torch::Tensor rayDns, torch::Tensor rayHit, torch::Tensor rayNrm, torch::Tensor particleDensity, - torch::Tensor particleRadiance, - torch::Tensor rayRadGrd, + torch::Tensor particleFeatures, + torch::Tensor rayFeatGrd, torch::Tensor rayDnsGrd, torch::Tensor rayHitGrd, torch::Tensor rayNrmGrd, diff --git a/threedgrt_tracer/include/3dgrt/pipelineParameters.h b/threedgrt_tracer/include/3dgrt/pipelineParameters.h index a7af073d..ce8e6c1a 100644 --- a/threedgrt_tracer/include/3dgrt/pipelineParameters.h +++ b/threedgrt_tracer/include/3dgrt/pipelineParameters.h @@ -21,17 +21,41 @@ #include <3dgrt/pipelineDefinitions.h> #include <3dgrt/tensorAccessor.h> +// Per-particle feature storage element type. fp16 when PARTICLE_FEATURE_HALF=1, else fp32. +// Gradient buffer for per-particle features is always fp32. +#ifndef PARTICLE_FEATURE_HALF +#define PARTICLE_FEATURE_HALF 0 +#endif +// Per-ray integrated feature output element type. fp16 when FEATURE_OUTPUT_HALF=1, else fp32. +// Gradient buffer for the integrated feature output is always fp32. 
+#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif +#if PARTICLE_FEATURE_HALF || FEATURE_OUTPUT_HALF +#include +#endif +#if PARTICLE_FEATURE_HALF +using TParticleFeatureElem = __half; +#else +using TParticleFeatureElem = float; +#endif +#if FEATURE_OUTPUT_HALF +using TRayFeatureElem = __half; +#else +using TRayFeatureElem = float; +#endif + struct PipelineParameters { float4 rayToWorld[3]; ///< float3x4 ray to world transformation (row-major) PackedTensorAccessor32 rayOrigin; ///< ray origin PackedTensorAccessor32 rayDirection; ///< ray direction - const ParticleDensity* particleDensity; ///< position, scale, quaternions, density - const float* particleRadiance; ///< spherical harmonics coefficients - const void* particleExtendedData; ///< pipeline specific particle data - int32_t* particleVisibility; ///< pipeline specific particle data + const ParticleDensity* particleDensity; ///< position, scale, quaternions, density + const TParticleFeatureElem* particleFeatures; ///< per-particle features (fp16 when PARTICLE_FEATURE_HALF) + const void* particleExtendedData; ///< pipeline specific particle data + int32_t* particleVisibility; ///< pipeline specific particle data - PackedTensorAccessor32 rayRadiance; ///< output integrated ray radiance + PackedTensorAccessor32 rayFeatures; ///< integrated ray features (fp16 when FEATURE_OUTPUT_HALF) PackedTensorAccessor32 rayDensity; ///< output integrated ray density PackedTensorAccessor32 rayHitDistance; ///< output integrated ray hit distance PackedTensorAccessor32 rayNormal; ///< output integrated ray normal @@ -46,6 +70,11 @@ struct PipelineParameters { float alphaMinThreshold; unsigned int sphDegree; + // Feature-based radiance dimensions + static constexpr unsigned int ParticleFeatureDim = PARTICLE_FEATURE_DIM; ///< Total feature dim per particle (buffer stride K) + static constexpr unsigned int RayFeatureDim = RAY_FEATURE_DIM; ///< Per-ray (decoder input N) + static constexpr unsigned int 
FeatureTransformType = FEATURE_TRANSFORM_TYPE; ///< 0=SH, 1=learned + uint2 frameBounds; unsigned int frameNumber; int gPrimNumTri; @@ -87,11 +116,11 @@ struct PipelineParameters { }; struct PipelineBackwardParameters : PipelineParameters { - PackedTensorAccessor32 rayRadianceGrad; ///< integrated ray radiance gradient + PackedTensorAccessor32 rayFeaturesGrad; ///< integrated ray features gradient (fp32) PackedTensorAccessor32 rayDensityGrad; ///< integrated ray density gradient PackedTensorAccessor32 rayHitDistanceGrad; ///< integrated ray hit distance gradient PackedTensorAccessor32 rayNormalGrad; ///< integrated ray hit distance gradient ParticleDensity* particleDensityGrad; ///< output position, scale, quaternions, density gradient - float* particleRadianceGrad; ///< output spherical harmonics coefficients gradient + float* particleFeaturesGrad; ///< per-particle features gradient (fp32) }; diff --git a/threedgrt_tracer/include/3dgrt/tensorBuffering.h b/threedgrt_tracer/include/3dgrt/tensorBuffering.h index 23454342..458d6ffe 100644 --- a/threedgrt_tracer/include/3dgrt/tensorBuffering.h +++ b/threedgrt_tracer/include/3dgrt/tensorBuffering.h @@ -33,13 +33,18 @@ inline scalar_t* getPtr(torch::Tensor tensor) { return reinterpret_cast(tensor.contiguous().data_ptr()); } else if (tensor.dtype() == torch::kFloat32) { return reinterpret_cast(tensor.contiguous().data_ptr()); + } else if (tensor.dtype() == torch::kHalf) { + return reinterpret_cast(tensor.contiguous().data_ptr()); } else { throw std::runtime_error("getPtr(tensor) received a tensor of unsupported type"); } } +// Note: `reinterpret_cast` is used on the raw `tensor.data_ptr()` so that this +// template instantiates for element types without a PyTorch scalar_type entry +// (e.g. `__half`). Dtype agreement is the caller's responsibility. 
template class PtrTraits = DefaultPtrTraits> PackedTensorAccessor32 packed_accessor32(torch::Tensor tensor) { - return PackedTensorAccessor32(static_cast::PtrType>(tensor.data_ptr()), + return PackedTensorAccessor32(reinterpret_cast::PtrType>(tensor.data_ptr()), tensor.sizes().data(), tensor.strides().data()); } \ No newline at end of file diff --git a/threedgrt_tracer/include/playground/pipelineParameters.h b/threedgrt_tracer/include/playground/pipelineParameters.h index d25a6ec8..3967a252 100644 --- a/threedgrt_tracer/include/playground/pipelineParameters.h +++ b/threedgrt_tracer/include/playground/pipelineParameters.h @@ -52,8 +52,8 @@ struct PlaygroundTracingParams : PipelineParameters { // rayOri -> rayOrigin // rayDir -> rayDirection // mogPos, mogRot, mogScl, mogDns -> particleDensity - // mogSph-> particleRadiance - // rayRad -> rayRadiance + // mogSph-> particleFeatures + // rayRad -> rayFeatures // rayDns -> rayDensity // rayHit -> rayHitDistance // rayHitsCount -> rayHitsCount diff --git a/threedgrt_tracer/setup_3dgrt.py b/threedgrt_tracer/setup_3dgrt.py index fd282e36..97eb8531 100644 --- a/threedgrt_tracer/setup_3dgrt.py +++ b/threedgrt_tracer/setup_3dgrt.py @@ -16,6 +16,7 @@ import os from threedgrut.utils import jit +from threedgrut.model.features import Features # ---------------------------------------------------------------------------- @@ -24,13 +25,50 @@ def setup_3dgrt(conf): def to_cpp_bool(value): return "true" if value else "false" + feat = Features(conf) + transform_defines = [ + f"-DPARTICLE_FEATURE_DIM={feat.particle_feature_dim}", + f"-DRAY_FEATURE_DIM={feat.ray_feature_dim}", + f"-DFEATURE_TRANSFORM_TYPE={feat.transform_type}", + ] + nht_defines = [ + f"-DFEATURE_INTERPOLATION_TYPE={feat.interpolation_type}", + f"-DFEATURE_INTERPOLATION_SUPPORT={feat.interpolation_support}", + f"-DFEATURE_ACTIVATION_TYPE={feat.activation_type}", + f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={feat.activation_num_frequencies}", + 
f"-DINTERP_POINT_FEATURE_DIM={feat.interp_point_feature_dim}", + ] + half_defines = [ + f"-DPARTICLE_FEATURE_HALF={1 if conf.render.particle_feature_half else 0}", + f"-DFEATURE_OUTPUT_HALF={1 if conf.render.feature_output_half else 0}", + ] + include_paths = [] include_paths.append(os.path.join(os.path.dirname(__file__), "include")) include_paths.append(os.path.join(os.path.dirname(__file__), "dependencies", "optix-dev", "include")) - # Compiler options. - cflags = [] - cuda_flags = [] + # Compiler options. Same -D for feature dims so pipelineParameters.h and JIT OptiX pipeline (generateDefines) stay in sync. + cflags = [ + *transform_defines, + *half_defines, + ] + cuda_flags = [ + # Feature-based radiance dimensions (must match Slang compilation) + *transform_defines, + *nht_defines, + *half_defines, + # Other particle parameters + f"-DPARTICLE_RADIANCE_NUM_COEFFS={(conf.render.particle_radiance_sph_degree + 1) ** 2}", + f"-DGAUSSIAN_PARTICLE_KERNEL_DEGREE={conf.render.particle_kernel_degree}", + f"-DGAUSSIAN_PARTICLE_MIN_KERNEL_DENSITY={conf.render.particle_kernel_min_response}", + f"-DGAUSSIAN_PARTICLE_MIN_ALPHA={conf.render.particle_kernel_min_alpha}", + f"-DGAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD={conf.render.min_transmittance}", + ] + # When PARTICLE_FEATURE_HALF=1 the Slang-generated header uses __half types; + # the Slang prelude only pulls in <cuda_fp16.h> and defines __half when + # SLANG_CUDA_ENABLE_HALF is set. + if conf.render.particle_feature_half or conf.render.feature_output_half: + cuda_flags.append("-DSLANG_CUDA_ENABLE_HALF=1") # List of sources.
source_files = [ @@ -44,7 +82,7 @@ def to_cpp_bool(value): jit.compile_slang_kernel( kernel_files=[ f"{os.path.join(slang_build_dir,'models/gaussianParticles.slang')}", - f"{os.path.join(slang_build_dir,'models/shRadiativeParticles.slang')}", + f"{os.path.join(slang_build_dir,'models/radiativeParticles.slang')}", ], output_file=f"{os.path.join(slang_build_dir, 'gaussianParticles.cuh')}", defines=[ @@ -55,6 +93,10 @@ def to_cpp_bool(value): f"-DGAUSSIAN_PARTICLE_MAX_ALPHA={conf.render.particle_kernel_max_alpha}", f"-DGAUSSIAN_PARTICLE_ENABLE_NORMAL={to_cpp_bool(conf.render.enable_normals)}", f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type=='trisurfel')}", + # Feature-based radiance dimensions + *transform_defines, + *nht_defines, + *half_defines, ], include_paths=[ os.path.join(os.path.dirname(__file__), "include"), diff --git a/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu b/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu index 872b79ce..bf8291d1 100644 --- a/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/barycentricSurfelsOptix.cu @@ -134,7 +134,7 @@ extern "C" __global__ void __raygen__rg() { float3 sphCoefficients[SPH_MAX_NUM_COEFFS]; fetchParticleSphCoefficients( rayHit.particleId, - params.particleRadiance, + params.particleFeatures, &sphCoefficients[0]); const float3 rayParticleRadiance = radianceFromSpH(params.sphDegree, &sphCoefficients[0], rayDirection); @@ -161,9 +161,9 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayRadiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayRadiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayRadiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; 
params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; diff --git a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu index 8dda03ca..f31cd6aa 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceB2FSlangBwdOptix.cu @@ -101,7 +101,7 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + float3 rayRadiance = make_float3(params.rayFeatures[idx.z][idx.y][idx.x][0], params.rayFeatures[idx.z][idx.y][idx.x][1], params.rayFeatures[idx.z][idx.y][idx.x][2]); float rayTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -110,7 +110,7 @@ extern "C" __global__ void __raygen__rg() { float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + float3 rayRadianceGrad = make_float3(params.rayFeaturesGrad[idx.z][idx.y][idx.x][0], params.rayFeaturesGrad[idx.z][idx.y][idx.x][1], params.rayFeaturesGrad[idx.z][idx.y][idx.x][2]); float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -148,8 +148,8 @@ extern "C" __global__ void __raygen__rg() { // rayHit.particleId, // (ParticleDensity_0*)params.particleDensity, // (ParticleDensity_0*)params.particleDensityGrad, - // 
(float*)params.particleRadiance, - // (float*)params.particleRadianceGrad, + // (float*)params.particleFeatures, + // (float*)params.particleFeaturesGrad, // params.hitMinGaussianResponse, // params.alphaMinThreshold, // PipelineParameters::ParticleKernelDegree, @@ -185,7 +185,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu index 57b2a50b..2fcd80fa 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceBwdOptix.cu @@ -109,12 +109,12 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayIntegratedRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + float3 rayIntegratedRadiance = make_float3(params.rayFeatures[idx.z][idx.y][idx.x][0], params.rayFeatures[idx.z][idx.y][idx.x][1], params.rayFeatures[idx.z][idx.y][idx.x][2]); float rayIntegratedTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayIntegratedHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + float3 rayRadianceGrad = make_float3(params.rayFeaturesGrad[idx.z][idx.y][idx.x][0], 
params.rayFeaturesGrad[idx.z][idx.y][idx.x][1], params.rayFeaturesGrad[idx.z][idx.y][idx.x][2]); float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; @@ -147,8 +147,8 @@ extern "C" __global__ void __raygen__rg() { rayHit.particleId, params.particleDensity, params.particleDensityGrad, - params.particleRadiance, - params.particleRadianceGrad, + params.particleFeatures, + params.particleFeaturesGrad, params.hitMinGaussianResponse, params.alphaMinThreshold, params.minTransmittance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu index bfba5e91..caecaebf 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceOptix.cu @@ -141,7 +141,7 @@ extern "C" __global__ void __raygen__rg() { rayDirection, rayHit.particleId, params.particleDensity, - params.particleRadiance, + params.particleFeatures, params.hitMinGaussianResponse, params.alphaMinThreshold, params.sphDegree, @@ -169,9 +169,9 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayRadiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayRadiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayRadiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu index 06dbe09b..1eb31a31 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu +++ 
b/threedgrt_tracer/src/kernels/cuda/referenceSlangBwdOptix.cu @@ -110,7 +110,15 @@ extern "C" __global__ void __raygen__rg() { const float3 rayOrigin = params.rayWorldOrigin(idx); const float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(params.rayRadiance[idx.z][idx.y][idx.x][0], params.rayRadiance[idx.z][idx.y][idx.x][1], params.rayRadiance[idx.z][idx.y][idx.x][2]); + FixedArray rayFeatures; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { +#if FEATURE_OUTPUT_HALF + rayFeatures[i] = __half2float(params.rayFeatures[idx.z][idx.y][idx.x][i]); +#else + rayFeatures[i] = params.rayFeatures[idx.z][idx.y][idx.x][i]; +#endif + } float rayTransmittance = 1.0f - params.rayDensity[idx.z][idx.y][idx.x][0]; float rayHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -119,7 +127,11 @@ extern "C" __global__ void __raygen__rg() { float rayMaxHitDistance = params.rayHitDistance[idx.z][idx.y][idx.x][1]; - float3 rayRadianceGrad = make_float3(params.rayRadianceGrad[idx.z][idx.y][idx.x][0], params.rayRadianceGrad[idx.z][idx.y][idx.x][1], params.rayRadianceGrad[idx.z][idx.y][idx.x][2]); + FixedArray rayFeaturesGrad; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + rayFeaturesGrad[i] = params.rayFeaturesGrad[idx.z][idx.y][idx.x][i]; + } float rayTransmittanceGrad = -1.0f * params.rayDensityGrad[idx.z][idx.y][idx.x][0]; float rayHitDistanceGrad = params.rayHitDistanceGrad[idx.z][idx.y][idx.x][0]; #ifdef ENABLE_NORMALS @@ -149,12 +161,13 @@ extern "C" __global__ void __raygen__rg() { // NB : processing front-to-back backToFrontBwd is equivalent processing back-to-front frontToBackBwd !! 
float hitAlpha, hitDistance; - float3 hitNormal; + float3 canonicalIntersection, hitNormal; if (particleDensityHit( rayOrigin, rayDirection, particleDensityParameters(rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}), + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}), &hitAlpha, &hitDistance, + &canonicalIntersection, #ifdef ENABLE_NORMALS true, &hitNormal #else @@ -163,26 +176,37 @@ extern "C" __global__ void __raygen__rg() { ) ) { - const float3 hitRadiance = particleFeaturesFromBuffer( + FixedArray hitFeatures; + particleFeaturesFromBuffer( rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad}, (int)params.sphDegree}, - rayDirection); + const_cast(params.particleFeatures), + (int)params.sphDegree, + rayDirection, + canonicalIntersection, + &hitFeatures); float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f); particleFeaturesIntegrateBwdToBuffer(rayDirection, + canonicalIntersection, + &canonicalIntersectionGrad, hitAlpha, &hitAlphaGrad, rayHit.particleId, - {{(float3*)params.particleRadiance, (float3*)params.particleRadianceGrad}, params.sphDegree}, - hitRadiance, - &rayRadiance, - &rayRadianceGrad); + const_cast(params.particleFeatures), + const_cast(params.particleFeaturesGrad), + (int)params.sphDegree, + false, // exclusiveGradient: multiple rays can hit same particle + hitFeatures, + &rayFeatures, + &rayFeaturesGrad); particleDensityProcessHitBwdToBuffer(rayOrigin, rayDirection, rayHit.particleId, {{(gaussianParticle_RawParameters_0*)params.particleDensity, - (gaussianParticle_RawParameters_0*)params.particleDensityGrad}}, + (gaussianParticle_RawParameters_0*)params.particleDensityGrad, + false}}, hitAlpha, hitAlphaGrad, &rayTransmittance, @@ -190,10 +214,11 @@ extern "C" __global__ void __raygen__rg() { hitDistance, &rayHitDistance, &rayHitDistanceGrad, + canonicalIntersectionGrad, #ifdef ENABLE_NORMALS true, hitNormal, 
&rayNormal, &rayNormalGrad #else - false, hitNormal, nullptr, nullptr + false, make_float3(0.f, 0.f, 0.f), nullptr, nullptr #endif ); } @@ -215,7 +240,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu index 86a82893..a0fb6cb1 100644 --- a/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu +++ b/threedgrt_tracer/src/kernels/cuda/referenceSlangOptix.cu @@ -109,9 +109,14 @@ extern "C" __global__ void __raygen__rg() { float3 rayOrigin = params.rayWorldOrigin(idx); float3 rayDirection = params.rayWorldDirection(idx); - float3 rayRadiance = make_float3(0.0f); - float rayTransmittance = 1.0f; - float rayHitDistance = 0.f; + FixedArray rayFeatures; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { + rayFeatures[i] = 0.0f; + } + float rayTransmittance = 1.0f; + float rayHitDistance = 0.f; + float3 canonicalIntersection = make_float3(0.f); #ifdef ENABLE_NORMALS float3 rayNormal = make_float3(0.f); #endif @@ -140,9 +145,10 @@ extern "C" __global__ void __raygen__rg() { rayOrigin, rayDirection, rayHit.particleId, - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, &rayTransmittance, &rayHitDistance, + &canonicalIntersection, #ifdef ENABLE_NORMALS true, &rayNormal #else @@ -151,10 +157,12 @@ extern "C" __global__ void __raygen__rg() { ); particleFeaturesIntegrateFwdFromBuffer(rayDirection, + canonicalIntersection, hitWeight, rayHit.particleId, - {{(float3*)params.particleRadiance, nullptr}, 
params.sphDegree}, - &rayRadiance); + const_cast(params.particleFeatures), + params.sphDegree, + &rayFeatures); // NOTE(qi): Race condition here, but as we are writing the same value, it seems it is safe. if (hitWeight > 0.f) { @@ -170,9 +178,14 @@ extern "C" __global__ void __raygen__rg() { } } - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayRadiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayRadiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayRadiance.z; +#pragma unroll + for (int i = 0; i < RAY_FEATURE_DIM; i++) { +#if FEATURE_OUTPUT_HALF + params.rayFeatures[idx.z][idx.y][idx.x][i] = __float2half(rayFeatures[i]); +#else + params.rayFeatures[idx.z][idx.y][idx.x][i] = rayFeatures[i]; +#endif + } params.rayDensity[idx.z][idx.y][idx.x][0] = 1 - rayTransmittance; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayHitDistance; params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayLastHitDistance; @@ -197,7 +210,7 @@ extern "C" __global__ void __intersection__is() { : particleDensityHitCustom(optixGetWorldRayOrigin(), optixGetWorldRayDirection(), optixGetPrimitiveIndex(), - {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr}}, + {{(gaussianParticle_RawParameters_0*)params.particleDensity, nullptr, true}}, optixGetRayTmin(), optixGetRayTmax(), params.hitMaxParticleSquaredDistance, diff --git a/threedgrt_tracer/src/optixTracer.cpp b/threedgrt_tracer/src/optixTracer.cpp index e1e4a817..1c43e4b9 100644 --- a/threedgrt_tracer/src/optixTracer.cpp +++ b/threedgrt_tracer/src/optixTracer.cpp @@ -215,6 +215,19 @@ std::vector OptixTracer::generateDefines( defines.emplace_back("-DSPH_MAX_NUM_COEFFS=" + std::to_string((_state->particleRadianceSphDegree + 1) * (_state->particleRadianceSphDegree + 1))); defines.emplace_back("-DPARTICLE_PRIMITIVE_TYPE=" + std::to_string(_state->gPrimType)); defines.emplace_back("-DPARTICLE_PRIMITIVE_CLAMPED=" + std::to_string(particleKernelDensityClamping ? 
1 : 0)); + // Feature dims: use C++ defines so JIT OptiX pipeline sees same values as extension build + defines.emplace_back("-DPARTICLE_FEATURE_DIM=" + std::to_string(PipelineParameters::ParticleFeatureDim)); + defines.emplace_back("-DRAY_FEATURE_DIM=" + std::to_string(PipelineParameters::RayFeatureDim)); + defines.emplace_back("-DFEATURE_TRANSFORM_TYPE=" + std::to_string(PipelineParameters::FeatureTransformType)); + // Half-precision flags: must match what the extension build sees so the + // PipelineParameters struct layout and the Slang-generated header agree + // with the NVRTC-compiled OptiX kernels. + defines.emplace_back("-DPARTICLE_FEATURE_HALF=" + std::to_string(PARTICLE_FEATURE_HALF)); + defines.emplace_back("-DFEATURE_OUTPUT_HALF=" + std::to_string(FEATURE_OUTPUT_HALF)); +#if PARTICLE_FEATURE_HALF || FEATURE_OUTPUT_HALF + // Enable `__half` in the Slang-generated CUDA prelude. + defines.emplace_back("-DSLANG_CUDA_ENABLE_HALF=1"); +#endif } return defines; } @@ -356,9 +369,21 @@ void OptixTracer::createPipeline(const OptixDeviceContext context, std::string optix_include_dir = dependencies_path + "/dependencies/optix-dev/include"; std::string cuda_include_dir = cuda_path + "/include"; + // Some CUDA distributions (e.g. conda's cuda-toolkit) place per-target + // SDK headers such as <cuda_fp16.h> under $CUDA_HOME/targets/<target>/include + // rather than directly under $CUDA_HOME/include. Thread that extra + // search path to NVRTC when present so Slang-generated CUDA that uses + // __half compiles.
+ std::vector extra_includes_with_cuda_targets = extra_includes; +#if defined(__x86_64__) || defined(_M_X64) + extra_includes_with_cuda_targets.push_back(cuda_path + "/targets/x86_64-linux/include"); +#elif defined(__aarch64__) + extra_includes_with_cuda_targets.push_back(cuda_path + "/targets/sbsa-linux/include"); +#endif + const char* input = getInputData(shaderFile.c_str(), includeDir.c_str(), optix_include_dir.c_str(), cuda_include_dir.c_str(), kernel_name.c_str(), inputSize, defines, - (const char**)&log, extra_includes); + (const char**)&log, extra_includes_with_cuda_targets); size_t sizeof_log = sizeof(log); OPTIX_CHECK_LOG(optixModuleCreateFromPTX( @@ -857,13 +882,18 @@ OptixTracer::trace(uint32_t frameNumber, torch::Tensor rayOri, torch::Tensor rayDir, torch::Tensor particleDensity, - torch::Tensor particleRadiance, + torch::Tensor particleFeatures, uint32_t renderOpts, int sphDegree, float minTransmittance) { const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - torch::Tensor rayRad = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 3}, opts); +#if FEATURE_OUTPUT_HALF + const torch::TensorOptions rayFeatOpts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); +#else + const torch::TensorOptions rayFeatOpts = opts; +#endif + torch::Tensor rayFeat = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), static_cast(PipelineParameters::RayFeatureDim)}, rayFeatOpts); torch::Tensor rayDns = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 1}, opts); torch::Tensor rayHit = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 2}, opts); torch::Tensor rayNrm = torch::empty({rayOri.size(0), rayOri.size(1), rayOri.size(2), 3}, opts); @@ -889,11 +919,11 @@ OptixTracer::trace(uint32_t frameNumber, paramsHost.rayDirection = packed_accessor32(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + 
paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); paramsHost.particleVisibility = getPtr(particleVisibility); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayFeat); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); @@ -906,12 +936,12 @@ OptixTracer::trace(uint32_t frameNumber, reinterpret_cast(_state->paramsDevice), &paramsHost, sizeof(paramsHost), cudaMemcpyHostToDevice, cudaStream)); OPTIX_CHECK(optixLaunch(_state->pipelineTracingFwd, cudaStream, _state->paramsDevice, - sizeof(PipelineParameters), &_state->sbtTracingFwd, rayRad.size(2), - rayRad.size(1), rayRad.size(0))); + sizeof(PipelineParameters), &_state->sbtTracingFwd, rayFeat.size(2), + rayFeat.size(1), rayFeat.size(0))); CUDA_CHECK_LAST(); - return std::tuple(rayRad, rayDns, rayHit, rayNrm, rayHitsCount, particleVisibility); + return std::tuple(rayFeat, rayDns, rayHit, rayNrm, rayHitsCount, particleVisibility); } std::tuple @@ -919,13 +949,13 @@ OptixTracer::traceBwd(uint32_t frameNumber, torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor rayRad, + torch::Tensor rayFeat, torch::Tensor rayDns, torch::Tensor rayHit, torch::Tensor rayNrm, torch::Tensor particleDensity, - torch::Tensor particleRadiance, - torch::Tensor rayRadGrd, + torch::Tensor particleFeatures, + torch::Tensor rayFeatGrd, torch::Tensor rayDnsGrd, torch::Tensor rayHitGrd, torch::Tensor rayNrmGrd, @@ -935,7 +965,7 @@ OptixTracer::traceBwd(uint32_t frameNumber, const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor particleDensityGrad = torch::zeros({particleDensity.size(0), particleDensity.size(1)}, opts); - torch::Tensor particleRadianceGrad = torch::zeros({particleRadiance.size(0), 
particleRadiance.size(1)}, opts); + torch::Tensor particleFeaturesGrad = torch::zeros({particleFeatures.size(0), particleFeatures.size(1)}, opts); PipelineBackwardParameters paramsHost; paramsHost.handle = _state->gasHandle; @@ -956,18 +986,18 @@ OptixTracer::traceBwd(uint32_t frameNumber, paramsHost.rayDirection = packed_accessor32(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayFeat); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); paramsHost.particleDensityGrad = getPtr(particleDensityGrad); - paramsHost.particleRadianceGrad = getPtr(particleRadianceGrad); + paramsHost.particleFeaturesGrad = getPtr(particleFeaturesGrad); - paramsHost.rayRadianceGrad = packed_accessor32(rayRadGrd); + paramsHost.rayFeaturesGrad = packed_accessor32(rayFeatGrd); paramsHost.rayDensityGrad = packed_accessor32(rayDnsGrd); paramsHost.rayHitDistanceGrad = packed_accessor32(rayHitGrd); paramsHost.rayNormalGrad = packed_accessor32(rayNrmGrd); @@ -980,7 +1010,7 @@ OptixTracer::traceBwd(uint32_t frameNumber, OPTIX_CHECK(optixLaunch(_state->pipelineTracingBwd, cudaStream, _state->paramsDevice, sizeof(PipelineBackwardParameters), &_state->sbtTracingBwd, - rayRad.size(2), rayRad.size(1), rayRad.size(0))); + rayFeat.size(2), rayFeat.size(1), rayFeat.size(0))); - return std::tuple(particleDensityGrad, particleRadianceGrad); + return std::tuple(particleDensityGrad, particleFeaturesGrad); } diff --git a/threedgrt_tracer/tracer.py b/threedgrt_tracer/tracer.py index d62492c4..349f84c2 100644 --- a/threedgrt_tracer/tracer.py +++ b/threedgrt_tracer/tracer.py @@ -21,6 +21,7 @@ import 
torch.utils.cpp_extension from threedgrut.datasets.protocols import Batch +from threedgrut.model.features import Features from threedgrut.utils.timer import CudaTimer logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def forward( min_transmittance, ): particle_density = torch.concat([mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1) - ray_radiance, ray_density, ray_hit_distance, ray_normals, hits_count, mog_visibility = tracer_wrapper.trace( + ray_features, ray_density, ray_hit_distance, ray_normals, hits_count, mog_visibility = tracer_wrapper.trace( frame_id, ray_to_world, ray_ori, @@ -81,7 +82,7 @@ def forward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, @@ -94,7 +95,7 @@ def forward( ctx.min_transmittance = min_transmittance ctx.tracer_wrapper = tracer_wrapper return ( - ray_radiance, + ray_features.float(), # always fp32 to caller; fp16 saved in ctx for trace_bwd ray_density, ray_hit_distance[:, :, :, 0:1], # return only the hit distance ray_normals, @@ -105,7 +106,7 @@ def forward( @staticmethod def backward( ctx, - ray_radiance_grd, + ray_features_grd, ray_density_grd, ray_hit_distance_grd, ray_normals_grd, @@ -116,7 +117,7 @@ def backward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, @@ -129,13 +130,13 @@ def backward( ray_to_world, ray_ori, ray_dir, - ray_radiance, + ray_features, ray_density, ray_hit_distance, ray_normals, particle_density, mog_sph, - ray_radiance_grd, + ray_features_grd, ray_density_grd, ray_hit_distance_grd, ray_normals_grd, @@ -167,10 +168,10 @@ class RenderOpts(IntEnum): DEFAULT = NONE def __init__(self, conf): - self.device = "cuda" self.conf = conf self.num_update_bvh = 0 + self.feature_transform_type = Features(conf).transform_type logger.info(f'🔆 Creating Optix tracing pipeline.. 
Using CUDA path: "{torch.utils.cpp_extension.CUDA_HOME}"') torch.zeros(1, device=self.device) # Create a dummy tensor to force cuda context init @@ -220,7 +221,7 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): if self.frame_timer is not None: self.frame_timer.start() - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count, mog_visibility = Tracer._Autograd.apply( + (pred_features, pred_opacity, pred_dist, pred_normals, hits_count, mog_visibility) = Tracer._Autograd.apply( self.tracer_wrapper, frame_id, gpu_batch.T_to_world.contiguous(), @@ -230,7 +231,11 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): gaussians.get_rotation().contiguous(), gaussians.get_scale().contiguous(), gaussians.get_density().contiguous(), - gaussians.get_features().contiguous(), + ( + gaussians.get_features().half().contiguous() + if self.conf.render.particle_feature_half + else gaussians.get_features().contiguous() + ), Tracer.RenderOpts.DEFAULT, gaussians.n_active_features, self.conf.render.min_transmittance, @@ -239,15 +244,11 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): if self.frame_timer is not None: self.frame_timer.end() - pred_rgb, pred_opacity = gaussians.background( - gpu_batch.T_to_world.contiguous(), gpu_batch.rays_dir.contiguous(), pred_rgb, pred_opacity, train - ) - if self.frame_timer is not None: self.timings["forward_render"] = self.frame_timer.timing() return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": torch.nn.functional.normalize(pred_normals, dim=3), diff --git a/threedgrut/model/feature_decoder.py b/threedgrut/model/feature_decoder.py new file mode 100644 index 00000000..351275e0 --- /dev/null +++ b/threedgrut/model/feature_decoder.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import tinycudann as tcnn + + +class FeatureDecoder(nn.Module): + """Transforms N-dimensional feature maps to RGB radiance using tiny-cuda-nn. + + Takes rendered feature maps and ray directions, encodes directions, concatenates + with features, and decodes to RGB via a tiny-cuda-nn MLP. + """ + + def __init__( + self, + ray_feature_dim: int, + hidden_dim: int = 128, + num_layers: int = 4, + dir_encoding: str = "SphericalHarmonics", + dir_encoding_degree: int = 3, + sh_scale: float = 1.0, + output_activation: str = "Sigmoid", + ema_decay: float = 0.0, + ema_start_step: int = 0, + unpremultiply_alpha: bool = False, + ): + """Initialize the feature decoder. + + Args: + ray_feature_dim: Per-ray feature dimension (rendered features input to the decoder MLP) + hidden_dim: Hidden layer dimension for MLP decoder (default 128) + num_layers: Number of hidden layers in the MLP (default 4) + dir_encoding: Direction encoding type ("SphericalHarmonics" or "Frequency") + dir_encoding_degree: Degree for direction encoding (SH degree or frequency bands; default 3) + sh_scale: Scale applied to ray directions before encoding: (v*sh_scale+1)/2 maps to [0,1]. + sh_scale=1 is standard unit sphere coverage; sh_scale=3 extends coverage for + directions beyond the unit sphere. 
+ output_activation: Output layer activation ("Sigmoid" for [0,1] RGB, or "ReLU") + ema_decay: If > 0, keep EMA shadow of parameters (decay factor). 0 = no EMA. + ema_start_step: Global step at which to start updating EMA. + unpremultiply_alpha: If True, decode features / alpha then multiply RGB by alpha. + """ + super().__init__() + self.ray_feature_dim = ray_feature_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.sh_scale = sh_scale + self.output_activation = output_activation + self.unpremultiply_alpha = unpremultiply_alpha + self._ema_decay = ema_decay + self._ema_start_step = ema_start_step + self._ema_shadow: dict[str, torch.Tensor] = {} + self._ema_backup: dict[str, torch.Tensor] = {} + + if dir_encoding == "SphericalHarmonics": + dir_enc = {"otype": "SphericalHarmonics", "degree": dir_encoding_degree, "n_dims_to_encode": 3} + elif dir_encoding == "Frequency": + dir_enc = {"otype": "Frequency", "n_frequencies": dir_encoding_degree, "n_dims_to_encode": 3} + else: + raise ValueError(f"Unknown dir_encoding: {dir_encoding}") + + composite_encoding_config = { + "otype": "Composite", + "nested": [ + {"otype": "Identity", "n_dims_to_encode": ray_feature_dim}, + dir_enc, + ], + } + + network_config = { + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": output_activation, + "n_neurons": hidden_dim, + "n_hidden_layers": num_layers, + } + + self.network = tcnn.NetworkWithInputEncoding( + n_input_dims=ray_feature_dim + 3, + n_output_dims=3, + encoding_config=composite_encoding_config, + network_config=network_config, + ) + + if self._ema_decay > 0: + for name, param in self.named_parameters(): + if param.requires_grad: + self._ema_shadow[name] = param.data.clone() + + def ema_update(self, global_step: int) -> None: + """Update EMA shadow when global_step >= ema_start_step. 
No-op if ema_decay <= 0.""" + if self._ema_decay <= 0 or global_step < self._ema_start_step: + return + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_shadow: + if self._ema_shadow[name].device != param.device: + self._ema_shadow[name] = self._ema_shadow[name].to(param.device) + self._ema_shadow[name].lerp_(param.data, 1.0 - self._ema_decay) + + def apply_ema_shadow(self) -> None: + """Use EMA weights for inference (e.g. validation). No-op if no EMA.""" + if not self._ema_shadow: + return + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_shadow: + self._ema_backup[name] = param.data.clone() + param.data.copy_(self._ema_shadow[name]) + + def restore_ema(self) -> None: + """Restore training weights after inference. No-op if no EMA.""" + with torch.no_grad(): + for name, param in self.named_parameters(): + if param.requires_grad and name in self._ema_backup: + param.data.copy_(self._ema_backup[name]) + self._ema_backup.clear() + + def ema_state_dict(self) -> dict: + """State dict of EMA shadow for checkpoint. Empty if no EMA.""" + return {k: v.clone() for k, v in self._ema_shadow.items()} + + def load_ema_state_dict(self, state_dict: dict) -> None: + """Load EMA shadow from checkpoint.""" + self._ema_shadow = {k: v.clone() for k, v in state_dict.items()} + + def forward( + self, + features: torch.Tensor, + ray_directions: torch.Tensor, + alpha: torch.Tensor | None = None, + ) -> torch.Tensor: + """Transform features and ray directions to RGB. + + Args: + features: Input features of shape [H*W, N] or [B, H, W, N] (alpha-blended). + ray_directions: Ray directions of shape [H*W, 3] or [B, H, W, 3]. + alpha: Optional opacity used only when unpremultiply_alpha is enabled. 
+ + Returns: + RGB tensor of shape [H*W, 3] or [B, H, W, 3] + """ + features_shape = features.shape + ray_dirs_shape = ray_directions.shape + + if len(features_shape) == 4: # [B, H, W, N] + B, H, W, N = features_shape + assert ray_dirs_shape == (B, H, W, 3), f"Ray directions shape mismatch: expected {(B, H, W, 3)}, got {ray_dirs_shape}" + assert N == self.ray_feature_dim, f"Expected {self.ray_feature_dim} features, got {N}" + + features_flat = features.reshape(B * H * W, N) + ray_dirs_flat = ray_directions.reshape(B * H * W, 3) + alpha_flat = alpha.reshape(B * H * W, 1) if alpha is not None else None + + rgb_flat = self._process(features_flat, ray_dirs_flat, alpha_flat) + return rgb_flat.reshape(B, H, W, 3) + + elif len(features_shape) == 2: # [H*W, N] + HW, N = features_shape + assert ray_dirs_shape == (HW, 3), f"Ray directions shape mismatch: expected {(HW, 3)}, got {ray_dirs_shape}" + assert N == self.ray_feature_dim, f"Expected {self.ray_feature_dim} features, got {N}" + alpha_flat = alpha.reshape(HW, 1) if alpha is not None else None + return self._process(features, ray_directions, alpha_flat) + else: + raise ValueError(f"Expected input shape [B, H, W, N] or [H*W, N], got {features_shape}") + + def _process( + self, + features: torch.Tensor, + ray_directions: torch.Tensor, + alpha: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.unpremultiply_alpha and alpha is not None: + alpha_safe = alpha.clamp(min=1e-8) + features = features / alpha_safe + + # tcnn SH/Frequency encoding expects inputs in [0,1]: (v*sh_scale+1)/2 + dirs_unit_cube = (ray_directions * self.sh_scale + 1.0) * 0.5 + full_input = torch.cat([features, dirs_unit_cube], dim=-1) + rgb = self.network(full_input) + + if self.unpremultiply_alpha and alpha is not None: + rgb = rgb * alpha_safe + + return rgb.float() + + def regularization_loss(self) -> torch.Tensor: + """Compute L2 regularization loss on decoder weights.""" + loss = torch.tensor(0.0, device=self.network.params.device) + loss = 
from enum import IntEnum


class Features:
    """Enums and conf-driven getters for feature type, activation, and interpolation.

    Every property is derived on access from the configuration object passed
    to the constructor; nothing is cached. Only ``conf.model`` and
    ``conf.render`` are read, and ``conf.model.nht_features`` is touched only
    when ``conf.model.feature_type`` is "nht".
    """

    class Type(IntEnum):
        """Feature representation mode: integer value used directly in C preprocessor defines."""

        SH = 0  # Spherical harmonics
        NHT = 1  # Neural harmonic texture

        @classmethod
        def from_string(cls, value: str) -> "Features.Type":
            """Return the member whose name equals ``value`` ignoring case.

            Raises:
                ValueError: If no member matches.
            """
            value_lower = value.lower()
            for member in cls:
                if member.name.lower() == value_lower:
                    return member
            raise ValueError(f"Invalid feature_type: '{value}'. Must be one of {[m.name.lower() for m in cls]}")

    class ActivationType(IntEnum):
        NONE = 0
        SIREN = 1
        SINCOS = 2
        RELU = 3

    class InterpolationType(IntEnum):
        CENTER = 0
        BEZIER = 1

    class InterpolationSupport(IntEnum):
        CENTER = 0
        TETRAHEDRA = 1
        TRIANGLE = 2

    def __init__(self, conf):
        self._conf = conf

    # ------------------------------------------------------------------ helpers

    def _feature_type_name(self) -> str:
        """Lower-cased ``conf.model.feature_type``."""
        return self._conf.model.feature_type.lower()

    def _is_nht(self) -> bool:
        """True when the configured feature type is NHT."""
        return self._feature_type_name() == "nht"

    def _interpolation_name(self) -> str:
        """Lower-cased ``nht_features.interpolation_type``; missing or None reads as "none"."""
        raw = getattr(self._conf.model.nht_features, "interpolation_type", "none")
        return "none" if raw is None else raw.lower()

    # --------------------------------------------------------------- properties

    @property
    def transform_type(self):
        """SH or NHT — integer value used directly in C preprocessor defines."""
        return Features.Type.from_string(self._conf.model.feature_type)

    @property
    def activation_type(self):
        """Feature activation type from nht_features.activation.type.

        Returns NONE outside NHT mode, when no ``activation`` section exists,
        or when its ``type`` is missing/None.

        Raises:
            ValueError: If the configured type is not one of
                none / siren / sincos / relu (case-insensitive).
        """
        if not self._is_nht():
            return Features.ActivationType.NONE
        activation = getattr(self._conf.model.nht_features, "activation", None)
        if activation is None:
            return Features.ActivationType.NONE
        kind = getattr(activation, "type", "none")
        if kind is None:
            # A null entry is treated like an absent one.
            return Features.ActivationType.NONE
        if isinstance(kind, str):
            kind = kind.lower()
            by_name = {member.name.lower(): member for member in Features.ActivationType}
            if kind in by_name:
                return by_name[kind]
        raise ValueError(f"Unknown nht_features.activation.type: {kind}")

    @property
    def activation_num_frequencies(self):
        """Number of frequency bands. 1 when activation is none or relu."""
        if self.activation_type in (Features.ActivationType.NONE, Features.ActivationType.RELU):
            return 1
        activation = getattr(self._conf.model.nht_features, "activation", None)
        return int(getattr(activation, "num_frequencies", 1)) if activation else 1

    @property
    def interpolation_type(self):
        """CENTER (none/barycentric) or BEZIER.

        Raises:
            ValueError: For any other configured interpolation_type.
        """
        if not self._is_nht():
            return Features.InterpolationType.CENTER
        name = self._interpolation_name()
        if name in ("none", "barycentric"):
            return Features.InterpolationType.CENTER
        if name == "bezier":
            return Features.InterpolationType.BEZIER
        raise ValueError(f"Unknown nht_features.interpolation_type: {name}")

    @property
    def interpolation_support(self):
        """CENTER, TETRAHEDRA (gaussian), or TRIANGLE (trisurfel).

        Raises:
            ValueError: For any interpolation_type other than none/barycentric.
        """
        if not self._is_nht():
            return Features.InterpolationSupport.CENTER
        name = self._interpolation_name()
        if name == "none":
            return Features.InterpolationSupport.CENTER
        if name == "barycentric":
            # Compare case-insensitively and tolerate a null value, so that
            # "Trisurfel" is not silently treated as the tetrahedra case.
            primitive = getattr(self._conf.render, "primitive_type", "instances")
            primitive = (primitive or "").lower()
            if primitive == "trisurfel":
                return Features.InterpolationSupport.TRIANGLE
            return Features.InterpolationSupport.TETRAHEDRA
        # NOTE(review): "bezier" is accepted by interpolation_type but rejected
        # here, which also makes feature_defines raise for it — confirm whether
        # that is intentional.
        raise ValueError(f"Unknown nht_features.interpolation_type: {name}")

    @property
    def num_interpolation_points(self):
        """1 for center support, 4 for barycentric (tetrahedra or trisurfel)."""
        if not self._is_nht():
            return 1
        if self.interpolation_support == Features.InterpolationSupport.CENTER:
            return 1
        return 4  # barycentric: tetrahedra (4 verts) or trisurfel (2 coplanar triangles, 4 verts)

    @property
    def particle_feature_dim(self):
        """Total feature dim per particle (buffer stride). For NHT = nht_features.dim.

        Raises:
            ValueError: If feature_type is neither "sh" nor "nht".
        """
        feature_type = self._feature_type_name()
        if feature_type == "sh":
            sh_degree = self._conf.model.progressive_training.max_n_features
            return 3 * ((sh_degree + 1) ** 2)
        if feature_type == "nht":
            # Cast so the value is a plain int wherever it is formatted or used as a size.
            return int(self._conf.model.nht_features.dim)
        raise ValueError(f"Unknown feature_type: {feature_type}")

    @property
    def interp_point_feature_dim(self):
        """Per-interpolation-point feature dim before activation.

        Uses floor division; this property does not check that the division is exact.
        """
        if not self._is_nht():
            return 3
        return int(self._conf.model.nht_features.dim) // self.num_interpolation_points

    @property
    def ray_feature_dim(self):
        """Per-ray feature dim after optional harmonic expansion.

        Raises:
            ValueError: If feature_type is neither "sh" nor "nht".
        """
        feature_type = self._feature_type_name()
        if feature_type == "sh":
            return 3  # RGB output
        if feature_type == "nht":
            # SINCOS doubles the width relative to the other activations.
            expansion = 2 if self.activation_type == Features.ActivationType.SINCOS else 1
            return self.interp_point_feature_dim * self.activation_num_frequencies * expansion
        raise ValueError(f"Unknown feature_type: {feature_type}")

    @property
    def feature_defines(self):
        """C preprocessor defines for all feature-related kernel parameters."""
        # Enum members are converted with int() so the emitted text is always
        # the numeric value, independent of how enums format themselves.
        return [
            f"-DPARTICLE_FEATURE_DIM={self.particle_feature_dim}",
            f"-DRAY_FEATURE_DIM={self.ray_feature_dim}",
            f"-DFEATURE_TRANSFORM_TYPE={int(self.transform_type)}",
            f"-DFEATURE_INTERPOLATION_TYPE={int(self.interpolation_type)}",
            f"-DFEATURE_INTERPOLATION_SUPPORT={int(self.interpolation_support)}",
            f"-DFEATURE_ACTIVATION_TYPE={int(self.activation_type)}",
            f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={self.activation_num_frequencies}",
            f"-DINTERP_POINT_FEATURE_DIM={self.interp_point_feature_dim}",
        ]
RGB2SH - class MixtureOfGaussians(torch.nn.Module, ExportableModel): """ """ @@ -54,10 +54,15 @@ def num_gaussians(self): def feature_fields(self) -> list[str]: """Returns a list of feature field names - subclasses can override""" - return [ - "features_albedo", - "features_specular", - ] + if self.feature_type == Features.Type.SH: + return [ + "features_albedo", + "features_specular", + ] + elif self.feature_type == Features.Type.NHT: + return ["features"] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}") def get_positions(self) -> torch.Tensor: return self.positions @@ -69,13 +74,24 @@ def get_n_active_features(self) -> int: return self.n_active_features def get_features_albedo(self) -> torch.Tensor: - return self.features_albedo + if self.feature_type == Features.Type.SH: + return self.features_albedo + else: + raise AttributeError(f"features_albedo not available in feature_type='{self.feature_type.name.lower()}' mode") def get_features_specular(self) -> torch.Tensor: - return self.features_specular + if self.feature_type == Features.Type.SH: + return self.features_specular + else: + raise AttributeError(f"features_specular not available in feature_type='{self.feature_type.name.lower()}' mode") def get_features(self): - return torch.cat((self.features_albedo, self.features_specular), dim=1) + if self.feature_type == Features.Type.SH: + return torch.cat((self.features_albedo, self.features_specular), dim=1) + elif self.feature_type == Features.Type.NHT: + return self.features # [N, K] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}") def get_scale(self, preactivation=False): if preactivation: @@ -124,21 +140,31 @@ def get_model_parameters(self) -> dict: # Add optimizer state dict "optimizer": self.optimizer.state_dict(), "config": self.conf, + # Feature type and dimensions + "feature_type": self.feature_type.name.lower(), # Store as string for serialization + "particle_feature_dim": self.particle_feature_dim, + 
"ray_feature_dim": self.ray_feature_dim, } if self.progressive_training: model_params["feature_dim_increase_interval"] = self.feature_dim_increase_interval model_params["feature_dim_increase_step"] = self.feature_dim_increase_step - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: model_params["features_albedo"] = self.features_albedo model_params["features_specular"] = self.features_specular + elif self.feature_type == Features.Type.NHT: + model_params["features"] = self.features return model_params def __init__(self, conf, scene_extent=None): super().__init__() + # Store config early - needed for feature type detection + self.conf = conf + self.scene_extent = scene_extent + sh_degree = conf.model.progressive_training.max_n_features render_sph_degree = conf.render.particle_radiance_sph_degree if sh_degree > render_sph_degree: @@ -157,16 +183,47 @@ def __init__(self, conf, scene_extent=None): ) # Rotation of each Gaussian represented as a unit quaternion [n_gaussians, 4] self.scale = torch.nn.Parameter(torch.empty([0, 3])) # Anisotropic scale of each Gaussian [n_gaussians, 3] self.density = torch.nn.Parameter(torch.empty([0, 1])) # Density of each Gaussian [n_gaussians, 1] - self.features_albedo = torch.nn.Parameter( - torch.empty([0, 3]) - ) # Feature vector of the 0th order SH coefficients [n_gaussians, 3] (We split it into two due to different learning rates) - self.features_specular = torch.nn.Parameter( - torch.empty([0, specular_dim]) - ) # Features of the higher order SH coefficients [n_gaussians, specular_dim] + + # Feature type configuration - determine feature storage mode + self.feature_type = Features.Type.from_string(self.conf.model.feature_type) + + primitive_type = (getattr(conf.render, "primitive_type", None) or "").lower() + if self.feature_type == Features.Type.NHT and primitive_type == "trisurfel": + raise ValueError( + "Trisurfels are not supported in NHT mode. Use primitive_type 'instances' or 'icosahedron'." 
+ ) + + if self.feature_type == Features.Type.SH: + # Spherical harmonics mode: separate albedo and specular features + self.features_albedo = torch.nn.Parameter( + torch.empty([0, 3]) + ) # Feature vector of the 0th order SH coefficients [n_gaussians, 3] + self.features_specular = torch.nn.Parameter( + torch.empty([0, specular_dim]) + ) # Features of the higher order SH coefficients [n_gaussians, specular_dim] + self.particle_feature_dim = 3 + specular_dim # SH coeffs (input to tracer) + self.ray_feature_dim = 3 # RGB output from tracer + elif self.feature_type == Features.Type.NHT: + # NHT: per-particle feature vector, decoder maps rendered features -> RGB + feat = Features(conf) + num_points = feat.num_interpolation_points + nht_dim = int(conf.model.nht_features.dim) + if nht_dim % num_points != 0: + raise ValueError( + f"nht_features.dim={nht_dim} must be divisible by num_interpolation_points={num_points} " + f"(interpolation_type + primitive)" + ) + self.nht_num_interpolation_points = num_points + self.particle_feature_dim = feat.particle_feature_dim + self.ray_feature_dim = feat.ray_feature_dim + self.features = torch.nn.Parameter( + torch.empty([0, self.particle_feature_dim]) + ) # NHT features [n_gaussians, particle_feature_dim] + else: + raise ValueError(f"Unknown feature_type: {self.feature_type}. 
Must be 'sh' or 'nht'.") + self.max_sh_degree = sh_degree - self.conf = conf - self.scene_extent = scene_extent self.positions_gradient_norm = None self.device = "cuda" @@ -180,7 +237,6 @@ def __init__(self, conf, scene_extent=None): self.background = background.make(self.conf.model.background.name, self.conf.model.background) # Check if we would like to do progressive training - self.feature_type = self.conf.model.progressive_training.feature_type self.n_active_features = min(self.conf.model.progressive_training.init_n_features, sh_degree) self.max_n_features = ( sh_degree # For SH, this is the SH degree (clamped if > render.particle_radiance_sph_degree) @@ -191,7 +247,6 @@ def __init__(self, conf, scene_extent=None): self.feature_dim_increase_step = self.conf.model.progressive_training.increase_step self.progressive_training = True - # Rendering method if conf.render.method == "3dgrt": self.renderer = threedgrt_tracer.Tracer(conf) elif conf.render.method == "3dgut": @@ -219,8 +274,12 @@ def freeze_gaussians(self) -> None: self.rotation.requires_grad = False self.scale.requires_grad = False self.density.requires_grad = False - self.features_albedo.requires_grad = False - self.features_specular.requires_grad = False + + if self.feature_type == Features.Type.SH: + self.features_albedo.requires_grad = False + self.features_specular.requires_grad = False + elif self.feature_type == Features.Type.NHT: + self.features.requires_grad = False self._gaussians_frozen = True logger.info("❄️ [Distillation] Gaussian parameters frozen") @@ -232,12 +291,14 @@ def validate_fields(self): assert self.rotation.shape == (num_gaussians, 4) assert self.scale.shape == (num_gaussians, 3) - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: assert self.features_albedo.shape == (num_gaussians, 3) specular_sh_dims = sh_degree_to_specular_dim(self.max_n_features) assert self.features_specular.shape == (num_gaussians, specular_sh_dims) + elif self.feature_type == 
Features.Type.NHT: + assert self.features.shape == (num_gaussians, self.particle_feature_dim) else: - raise ValueError("Neural features not yet supported.") + raise ValueError(f"Unknown feature_type: {self.feature_type}") def init_from_colmap(self, root_path: str, observer_pts): # Special case for scannetpp dataset @@ -327,6 +388,8 @@ def init_from_fused_point_cloud(self, pc_path: str, observer_pts): ) def init_from_pretrained_point_cloud(self, pc_path: str, set_optimizable_parameters: bool = True): + if self.feature_type != Features.Type.SH: + raise NotImplementedError(f"init_from_pretrained_point_cloud only supports feature_type='sh', got '{self.feature_type.name.lower()}'") data = PlyData.read(pc_path) num_gaussians = len(data["vertex"]) self.positions = torch.nn.Parameter( @@ -477,14 +540,21 @@ def init_from_random_point_cloud( # sh albedo in [0, 0.0039] fused_color = torch.rand((num_gaussians, 3), dtype=dtype, device=self.device) / 255.0 - features_albedo = features_specular = None - if self.feature_type == "sh": + # Initialize features based on feature_type + if self.feature_type == Features.Type.SH: features_albedo = fused_color.contiguous() max_sh_degree = self.max_n_features num_specular_features = sh_degree_to_specular_dim(max_sh_degree) features_specular = torch.zeros( (num_gaussians, num_specular_features), dtype=dtype, device=self.device ).contiguous() + elif self.feature_type == Features.Type.NHT: + init_min = float(getattr(self.conf.model.nht_features, "init_min", -5.0)) + init_max = float(getattr(self.conf.model.nht_features, "init_max", 5.0)) + features = ( + torch.rand((num_gaussians, self.particle_feature_dim), dtype=dtype, device=self.device) + * (init_max - init_min) + init_min + ) dist = torch.clamp_min(nearest_neighbor_dist_cpuKD(fused_point_cloud), 1e-3) scales = torch.log(dist * self.conf.model.default_scale_factor)[..., None].repeat(1, 3) @@ -500,24 +570,82 @@ def init_from_random_point_cloud( self.rotation = 
torch.nn.Parameter(rots.to(dtype=dtype, device=self.device)) self.scale = torch.nn.Parameter(scales.to(dtype=dtype, device=self.device)) self.density = torch.nn.Parameter(opacities.to(dtype=dtype, device=self.device)) - self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) - self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + + if self.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) + self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + elif self.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(features.to(dtype=dtype, device=self.device)) if set_optimizable_parameters: self.set_optimizable_parameters() self.validate_fields() def init_from_checkpoint(self, checkpoint: dict, setup_optimizer=True): + # Backward compatibility: detect legacy checkpoints without feature_type + if "feature_type" not in checkpoint and "features_albedo" in checkpoint: + logger.info("Loading legacy checkpoint - auto-detecting feature_type='sh'") + checkpoint["feature_type"] = "sh" + checkpoint["particle_feature_dim"] = checkpoint["features_albedo"].shape[1] + checkpoint["features_specular"].shape[1] + checkpoint["ray_feature_dim"] = 3 + + # Load features based on feature_type (convert string to enum) + checkpoint_feature_type_str = checkpoint.get("feature_type", "sh") + checkpoint_feature_type = Features.Type.from_string(checkpoint_feature_type_str) + + # NHT: 3DGUT is compiled with PARTICLE_FEATURE_DIM / RAY_FEATURE_DIM from current config. + # Checkpoints must match those compile-time constants or CUDA will read past feature buffers. 
+ if checkpoint_feature_type == Features.Type.NHT: + if "features" not in checkpoint: + raise ValueError("NHT checkpoint missing 'features' tensor") + feat = checkpoint["features"] + ck_pf = int(feat.shape[1]) + ck_rf = checkpoint.get("ray_feature_dim") + if ck_pf != self.particle_feature_dim: + raise ValueError( + f"NHT checkpoint features width is {ck_pf} but this build expects " + f"particle_feature_dim={self.particle_feature_dim} from config " + f"(model.nht_features.dim / interpolation). The 3DGUT CUDA extension was compiled for " + f"the config value; use the same nht_features (and render.primitive_type) as the run " + f"that produced the checkpoint, or train from scratch." + ) + if ck_rf is not None and int(ck_rf) != self.ray_feature_dim: + raise ValueError( + f"NHT checkpoint ray_feature_dim={ck_rf} does not match config ray_feature_dim=" + f"{self.ray_feature_dim}. Align model.nht_features.activation with the checkpoint run." + ) + if "particle_feature_dim" in checkpoint and int(checkpoint["particle_feature_dim"]) != ck_pf: + logger.warning( + f"Checkpoint particle_feature_dim={checkpoint['particle_feature_dim']} disagrees with " + f"features.shape[1]={ck_pf}; using tensor shape." + ) + + # Load basic parameters self.positions = checkpoint["positions"] self.rotation = checkpoint["rotation"] self.scale = checkpoint["scale"] self.density = checkpoint["density"] - self.features_albedo = checkpoint["features_albedo"] - self.features_specular = checkpoint["features_specular"] self.n_active_features = checkpoint["n_active_features"] self.max_n_features = checkpoint["max_n_features"] self.scene_extent = checkpoint["scene_extent"] + # Load feature dimensions. For NHT, keep config-derived dims (validated above vs checkpoint tensors); + # stale metadata keys must not override after a successful shape check. 
+ if checkpoint_feature_type != Features.Type.NHT: + if "particle_feature_dim" in checkpoint: + self.particle_feature_dim = checkpoint["particle_feature_dim"] + if "ray_feature_dim" in checkpoint: + self.ray_feature_dim = checkpoint["ray_feature_dim"] + + if checkpoint_feature_type == Features.Type.SH: + self.features_albedo = checkpoint["features_albedo"] + self.features_specular = checkpoint["features_specular"] + elif checkpoint_feature_type == Features.Type.NHT: + self.features = checkpoint["features"] + self.nht_num_interpolation_points = Features(self.conf).num_interpolation_points + else: + raise ValueError(f"Unknown feature_type in checkpoint: {checkpoint_feature_type}") + if self.progressive_training: self.feature_dim_increase_interval = checkpoint["feature_dim_increase_interval"] self.feature_dim_increase_step = checkpoint["feature_dim_increase_step"] @@ -588,17 +716,29 @@ def default_initialize_from_points(self, pts, observer_pts, colors=None, use_obs if colors is None: colors = torch.randint(0, 256, (N, 3), dtype=torch.uint8, device=self.device, generator=rng) - features_albedo = to_torch(RGB2SH(to_np(colors.float() / 255.0)), device=self.device) - - num_specular_dims = sh_degree_to_specular_dim(self.max_n_features) - features_specular = torch.zeros((N, num_specular_dims)) + # Initialize features based on feature_type + if self.feature_type == Features.Type.SH: + features_albedo = to_torch(RGB2SH(to_np(colors.float() / 255.0)), device=self.device) + num_specular_dims = sh_degree_to_specular_dim(self.max_n_features) + features_specular = torch.zeros((N, num_specular_dims)) + elif self.feature_type == Features.Type.NHT: + init_min = float(getattr(self.conf.model.nht_features, "init_min", -5.0)) + init_max = float(getattr(self.conf.model.nht_features, "init_max", 5.0)) + features = ( + torch.rand((N, self.particle_feature_dim), dtype=dtype, device=self.device, generator=rng) + * (init_max - init_min) + init_min + ) self.positions = 
torch.nn.Parameter(positions.to(dtype=dtype, device=self.device)) self.rotation = torch.nn.Parameter(rots.to(dtype=dtype, device=self.device)) self.scale = torch.nn.Parameter(scales.to(dtype=dtype, device=self.device)) self.density = torch.nn.Parameter(opacities.to(dtype=dtype, device=self.device)) - self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) - self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + + if self.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(features_albedo.to(dtype=dtype, device=self.device)) + self.features_specular = torch.nn.Parameter(features_specular.to(dtype=dtype, device=self.device)) + elif self.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(features.to(dtype=dtype, device=self.device)) self.set_optimizable_parameters() self.setup_optimizer() @@ -607,6 +747,11 @@ def default_initialize_from_points(self, pts, observer_pts, colors=None, use_obs def setup_optimizer(self, state_dict=None): params = [] for name, args in self.conf.optimizer.params.items(): + # Skip parameters that don't exist (e.g., 'features' in SH mode or 'features_albedo' in learned mode) + if not hasattr(self, name): + logger.info(f"Skipping optimizer parameter '{name}' - not present in {self.feature_type.name.lower()} mode") + continue + module = getattr(self, name) # If the module is a torch.nn.Module, we can add all of its trainable parameters to the optimizer @@ -644,15 +789,28 @@ def setup_optimizer(self, state_dict=None): def setup_scheduler(self): self.schedulers = {} for name, args in self.conf.scheduler.items(): - if args.type is not None and getattr(self, name).requires_grad: - if name == "positions": - self.schedulers[name] = get_scheduler(args.type)( - lr_init=args.lr_init * self.scene_extent, - lr_final=args.lr_final * self.scene_extent, - max_steps=args.max_steps, - ) - else: - self.schedulers[name] = 
get_scheduler(args.type)(**args) + if not hasattr(self, name): + continue + attr = getattr(self, name) + if not (hasattr(attr, "requires_grad") and attr.requires_grad): + continue + if args.type is None: + continue + if name == "positions": + self.schedulers[name] = get_scheduler(args.type)( + lr_init=args.lr_init * self.scene_extent, + lr_final=args.lr_final * self.scene_extent, + max_steps=args.max_steps, + ) + elif name == "features": + lr_init = getattr(self.conf.optimizer.params.features, "lr", 0.07) + decay_final = getattr(args, "decay_final", 0.001) + lr_final = lr_init * decay_final + self.schedulers[name] = get_scheduler(args.type)( + lr_init=lr_init, lr_final=lr_final, max_steps=args.max_steps + ) + else: + self.schedulers[name] = get_scheduler(args.type)(**args) def scheduler_step(self, step): for param_group in self.optimizer.param_groups: @@ -664,10 +822,6 @@ def scheduler_step(self, step): def set_optimizable_parameters(self): if not self.conf.model.optimize_density: self.density.requires_grad = False - if not self.conf.model.optimize_features_albedo: - self.features_albedo.requires_grad = False - if not self.conf.model.optimize_features_specular: - self.features_specular.requires_grad = False if not self.conf.model.optimize_rotation: self.rotation.requires_grad = False if not self.conf.model.optimize_scale: @@ -675,6 +829,17 @@ def set_optimizable_parameters(self): if not self.conf.model.optimize_position: self.positions.requires_grad = False + # Handle feature optimization based on feature_type + if self.feature_type == Features.Type.SH: + if not self.conf.model.optimize_features_albedo: + self.features_albedo.requires_grad = False + if not self.conf.model.optimize_features_specular: + self.features_specular.requires_grad = False + elif self.feature_type == Features.Type.NHT: + # For learned features, check if optimize_features config exists + if not self.conf.model.optimize_features: + self.features.requires_grad = False + def 
update_optimizable_parameters(self, optimizable_tensors: dict[str, torch.Tensor]): for name, value in optimizable_tensors.items(): setattr(self, name, value) @@ -683,7 +848,7 @@ def increase_num_active_features(self) -> None: self.n_active_features = min(self.max_n_features, self.n_active_features + self.feature_dim_increase_step) def get_active_feature_mask(self) -> torch.Tensor: - if self.feature_type == "sh": + if self.feature_type == Features.Type.SH: current_sh_degree = self.n_active_features max_sh_degree = self.max_n_features active_features = sh_degree_to_num_features(current_sh_degree) @@ -700,7 +865,9 @@ def clamp_density(self): optimizable_tensors = self.replace_tensor_to_optimizer(updated_densities, "density") self.density = optimizable_tensors["density"] - def forward(self, batch: Batch, train=False, frame_id=0) -> dict[str, torch.Tensor]: + def forward( + self, batch: Batch, train=False, frame_id=0 + ) -> dict[str, torch.Tensor]: """ Args: batch: a Batch structure containing the input data @@ -731,6 +898,8 @@ def export_ply(self, mogt_path: str): @torch.no_grad() def init_from_ply(self, mogt_path: str, init_model=True): + if self.feature_type != Features.Type.SH: + raise NotImplementedError(f"init_from_ply only supports feature_type='sh', got '{self.feature_type.name.lower()}'") plydata = PlyData.read(mogt_path) mogt_pos = np.stack( @@ -812,15 +981,21 @@ def copy_fields(self, other, deepcopy=False): self.rotation = torch.nn.Parameter(other.rotation.clone()) self.scale = torch.nn.Parameter(other.scale.clone()) self.density = torch.nn.Parameter(other.density.clone()) - self.features_albedo = torch.nn.Parameter(other.features_albedo.clone()) - self.features_specular = torch.nn.Parameter(other.features_specular.clone()) - else: # shared tensors + if other.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(other.features_albedo.clone()) + self.features_specular = torch.nn.Parameter(other.features_specular.clone()) + elif 
other.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(other.features.clone()) + else: self.positions = torch.nn.Parameter(other.positions) self.rotation = torch.nn.Parameter(other.rotation) self.scale = torch.nn.Parameter(other.scale) self.density = torch.nn.Parameter(other.density) - self.features_albedo = torch.nn.Parameter(other.features_albedo) - self.features_specular = torch.nn.Parameter(other.features_specular) + if other.feature_type == Features.Type.SH: + self.features_albedo = torch.nn.Parameter(other.features_albedo) + self.features_specular = torch.nn.Parameter(other.features_specular) + elif other.feature_type == Features.Type.NHT: + self.features = torch.nn.Parameter(other.features) self.max_sh_degree = other.max_sh_degree self.n_active_features = other.n_active_features self.scene_extent = other.scene_extent @@ -828,6 +1003,11 @@ def copy_fields(self, other, deepcopy=False): self.feature_dim_increase_interval = other.feature_dim_increase_interval self.feature_dim_increase_step = other.feature_dim_increase_step self.background = other.background + self.feature_type = other.feature_type + self.particle_feature_dim = other.particle_feature_dim + self.ray_feature_dim = other.ray_feature_dim + if hasattr(other, "nht_num_interpolation_points"): + self.nht_num_interpolation_points = other.nht_num_interpolation_points self.validate_fields() def clone(self): @@ -842,8 +1022,11 @@ def __getitem__(self, idx): sliced.rotation = torch.nn.Parameter(sliced.rotation[idx]) sliced.scale = torch.nn.Parameter(sliced.scale[idx]) sliced.density = torch.nn.Parameter(sliced.density[idx]) - sliced.features_albedo = torch.nn.Parameter(sliced.features_albedo[idx]) - sliced.features_specular = torch.nn.Parameter(sliced.features_specular[idx]) + if self.feature_type == Features.Type.SH: + sliced.features_albedo = torch.nn.Parameter(sliced.features_albedo[idx]) + sliced.features_specular = torch.nn.Parameter(sliced.features_specular[idx]) + elif 
self.feature_type == Features.Type.NHT: + sliced.features = torch.nn.Parameter(sliced.features[idx]) return sliced def __len__(self): diff --git a/threedgrut/render.py b/threedgrut/render.py index 313cff66..ee2f7fa0 100644 --- a/threedgrut/render.py +++ b/threedgrut/render.py @@ -29,7 +29,7 @@ from threedgrut.utils.color_correct import color_correct_affine from threedgrut.utils.logger import logger from threedgrut.utils.misc import create_summary_writer -from threedgrut.utils.render import apply_post_processing +from threedgrut.utils.render import apply_background, apply_feature_decoder, apply_post_processing class Renderer: @@ -44,6 +44,7 @@ def __init__( writer=None, compute_extra_metrics=True, post_processing=None, + feature_decoder=None, ) -> None: if path: # Replace the path to the test data @@ -59,6 +60,7 @@ def __init__( self.writer = writer self.compute_extra_metrics = compute_extra_metrics self.post_processing = post_processing + self.feature_decoder = feature_decoder if conf.model.background.color == "black": self.bg_color = torch.zeros((3,), dtype=torch.float32, device="cuda") @@ -148,6 +150,44 @@ def from_checkpoint( num_frames = post_processing.exposure_params.shape[0] logger.info(f"📷 {method.upper()} loaded from checkpoint: {num_cameras} cameras, {num_frames} frames") + # Load feature decoder for nht models + feature_decoder = None + if "feature_decoder" in checkpoint: + from threedgrut.model.features import Features + from threedgrut.model.feature_decoder import FeatureDecoder + + if model.feature_type == Features.Type.NHT: + conf_model = conf.model + dec = conf_model.nht_decoder + hidden_dim = dec.hidden_dim + num_layers = getattr(dec, "num_layers", 4) + dir_encoding = getattr(dec, "dir_encoding", "SphericalHarmonics") + dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) + sh_scale = getattr(dec, "sh_scale", 1.0) + output_activation = getattr(dec, "output_activation", "Sigmoid") + unpremultiply_alpha = getattr(dec, "unpremultiply_alpha", 
False) + ema_decay = getattr(dec, "ema_decay", 0.0) + ema_start_step = getattr(dec, "ema_start_step", 0) + feature_decoder = FeatureDecoder( + ray_feature_dim=model.ray_feature_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + dir_encoding=dir_encoding, + dir_encoding_degree=dir_encoding_degree, + sh_scale=sh_scale, + output_activation=output_activation, + ema_decay=ema_decay, + ema_start_step=ema_start_step, + unpremultiply_alpha=unpremultiply_alpha, + ).to("cuda") + feature_decoder.load_state_dict(checkpoint["feature_decoder"]["module"]) + ema_state = checkpoint["feature_decoder"].get("ema") + if ema_state is not None: + feature_decoder.load_ema_state_dict(ema_state) + feature_decoder.apply_ema_shadow() + feature_decoder.eval() + logger.info("🎨 Feature decoder loaded from checkpoint") + return Renderer( model=model, conf=conf, @@ -158,6 +198,7 @@ def from_checkpoint( writer=writer, compute_extra_metrics=computes_extra_metrics, post_processing=post_processing, + feature_decoder=feature_decoder, ) @classmethod @@ -171,6 +212,7 @@ def from_preloaded_model( global_step=None, compute_extra_metrics=False, post_processing=None, + feature_decoder=None, ): """Loads checkpoint for test path.""" @@ -188,6 +230,7 @@ def from_preloaded_model( writer=writer, compute_extra_metrics=compute_extra_metrics, post_processing=post_processing, + feature_decoder=feature_decoder, ) @torch.no_grad() @@ -236,20 +279,23 @@ def render_all(self): # Compute the outputs of a single batch outputs = self.model(gpu_batch) + if self.feature_decoder is not None: + outputs = apply_feature_decoder(self.feature_decoder, outputs, gpu_batch, training=False) + outputs = apply_background(self.model.background, outputs, gpu_batch, training=False) # Apply post-processing if self.post_processing is not None: outputs = apply_post_processing(self.post_processing, outputs, gpu_batch, training=False) - pred_rgb_full = outputs["pred_rgb"] + pred_features_full = outputs["pred_features"] rgb_gt_full = 
gpu_batch.rgb_gt # The values are already alpha composited with the background torchvision.utils.save_image( - pred_rgb_full.squeeze(0).permute(2, 0, 1), + pred_features_full.squeeze(0).permute(2, 0, 1), os.path.join(output_path_renders, "{0:05d}".format(iteration) + ".png"), ) - pred_img_to_write = pred_rgb_full[-1].clip(0, 1.0) + pred_img_to_write = pred_features_full[-1].clip(0, 1.0) gt_img_to_write = rgb_gt_full[-1].clip(0, 1.0) if self.save_gt: @@ -259,7 +305,7 @@ def render_all(self): ) # Compute the loss - psnr_single_img = criterions["psnr"](outputs["pred_rgb"], gpu_batch.rgb_gt).item() + psnr_single_img = criterions["psnr"](outputs["pred_features"], gpu_batch.rgb_gt).item() psnr.append(psnr_single_img) # evaluation on valid rays only logger.info(f"Frame {iteration}, PSNR: {psnr[-1]}") @@ -276,29 +322,29 @@ def render_all(self): # evaluate on full image ssim.append( criterions["ssim"]( - pred_rgb_full.permute(0, 3, 1, 2), + pred_features_full.permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) lpips.append( criterions["lpips"]( - pred_rgb_full.clip(0, 1).permute(0, 3, 1, 2), + pred_features_full.clip(0, 1).permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) # Color-corrected metrics - pred_rgb_cc = color_correct_affine(pred_rgb_full, rgb_gt_full) - cc_psnr.append(criterions["psnr"](pred_rgb_cc, rgb_gt_full).item()) + pred_features_cc = color_correct_affine(pred_features_full, rgb_gt_full) + cc_psnr.append(criterions["psnr"](pred_features_cc, rgb_gt_full).item()) cc_ssim.append( criterions["ssim"]( - pred_rgb_cc.permute(0, 3, 1, 2), + pred_features_cc.permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) cc_lpips.append( criterions["lpips"]( - pred_rgb_cc.clip(0, 1).permute(0, 3, 1, 2), + pred_features_cc.clip(0, 1).permute(0, 3, 1, 2), rgb_gt_full.permute(0, 3, 1, 2), ).item() ) @@ -340,6 +386,7 @@ def render_all(self): mean_cc_psnr=float(mean_cc_psnr), mean_cc_ssim=float(mean_cc_ssim), 
mean_cc_lpips=float(mean_cc_lpips), + mean_inference_time_ms=float(mean_inference_time), ) metrics_path = os.path.join(self.out_dir, "metrics.json") with open(metrics_path, "w") as f: diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 223fa539..24e8d741 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -39,7 +39,7 @@ from threedgrut.strategy.base import BaseStrategy from threedgrut.utils.logger import logger from threedgrut.utils.misc import check_step_condition, create_summary_writer, jet_map -from threedgrut.utils.render import apply_post_processing +from threedgrut.utils.render import apply_background, apply_feature_decoder, apply_post_processing from threedgrut.utils.timer import CudaTimer @@ -85,6 +85,9 @@ class Trainer3DGRUT: _distillation_start_step: int = -1 """ Step at which distillation starts (-1 means disabled) """ + _color_refine_frozen_param_names = frozenset(("positions", "scale", "rotation", "density")) + """ Gaussian optimizer parameter groups frozen during NHT color refinement """ + @staticmethod def create_from_checkpoint(resume: str, conf: DictConfig): """Create a new trainer from a checkpoint file""" @@ -119,6 +122,10 @@ def __init__(self, conf: DictConfig, device=None): """ Total number of train epochs / passes, e.g. single pass over the dataset.""" self.val_frequency = conf.val_frequency """ Validation frequency, in terms on global steps """ + self._color_refine_start_step = self._get_color_refine_start_step(conf) + """ Step at which NHT color refinement starts """ + self._in_color_refine = False + """ Whether NHT color refinement is active """ # Setup the trainer and components logger.log_rule("Load Datasets") @@ -129,11 +136,57 @@ def __init__(self, conf: DictConfig, device=None): self.init_densification_and_pruning_strategy(conf) logger.log_rule("Setup Model Weights & Training") self.init_metrics() + # Feature decoder and post-processing must exist before setup_training so resume can load their state. 
+ self.init_feature_decoder(conf) + self.init_post_processing(conf) self.setup_training(conf, self.model, self.train_dataset) self.init_experiments_tracking(conf) - self.init_post_processing(conf) self.init_gui(conf, self.model, self.train_dataset, self.val_dataset, self.scene_bbox) + def _get_color_refine_start_step(self, conf: DictConfig) -> int: + """Return the first step of the NHT color-only refinement phase.""" + feature_type = str(OmegaConf.select(conf, "model.feature_type", default="sh")).lower() + if feature_type != "nht": + return conf.n_iterations + + color_refine_steps = int(OmegaConf.select(conf, "model.nht_decoder.color_refine_steps", default=0) or 0) + if color_refine_steps <= 0: + return conf.n_iterations + + return max(0, conf.n_iterations - color_refine_steps) + + def _is_color_refine_active(self, global_step: int) -> bool: + return global_step >= self._color_refine_start_step and self._color_refine_start_step < self.conf.n_iterations + + def _apply_color_refine_freeze(self, global_step: int) -> None: + """Freeze Gaussian geometry/opacity optimizer groups while colors keep training.""" + if not self._is_color_refine_active(global_step): + return + + if not self._in_color_refine: + self._in_color_refine = True + self.strategy.suspend() + logger.info( + f"🎨 [step {global_step}] Entering NHT color refinement: " + "freezing geometry + opacity and disabling scale/opacity regularization." 
+ ) + + if self.model.optimizer is None: + return + + for param_group in self.model.optimizer.param_groups: + if param_group.get("name") in self._color_refine_frozen_param_names: + param_group["lr"] = 0.0 + + def _zero_color_refine_frozen_grads(self) -> None: + if not self._in_color_refine or self.model.optimizer is None: + return + + for param_group in self.model.optimizer.param_groups: + if param_group.get("name") in self._color_refine_frozen_param_names: + for param in param_group["params"]: + param.grad = None + def init_dataloaders(self, conf: DictConfig): from threedgrut.datasets.utils import configure_dataloader_for_platform @@ -226,6 +279,17 @@ def setup_training( self.strategy.init_densification_buffer(checkpoint) global_step = checkpoint["global_step"] + # Restore feature decoder state (skip if architecture drifted vs checkpoint) + if "feature_decoder" in checkpoint and self.feature_decoder is not None: + fd_ckpt = checkpoint["feature_decoder"] + self.feature_decoder.load_state_dict(fd_ckpt["module"]) + self.feature_decoder_optimizer.load_state_dict(fd_ckpt["optimizer"]) + self.feature_decoder_scheduler.load_state_dict(fd_ckpt["scheduler"]) + ema_state = fd_ckpt.get("ema") + if ema_state is not None: + self.feature_decoder.load_ema_state_dict(ema_state) + logger.info("🎨 Feature decoder state restored from checkpoint") + # Restore post-processing state if "post_processing" in checkpoint and self.post_processing is not None: self.post_processing.load_state_dict(checkpoint["post_processing"]["module"]) @@ -240,6 +304,7 @@ def setup_training( ): sched.load_state_dict(sched_state) logger.info("📷 Post-processing state restored from checkpoint") + model.build_acc() elif conf.import_ply.enabled: ply_path = ( conf.import_ply.path @@ -329,15 +394,16 @@ def init_gui( ): gui = None + feature_decoder = getattr(self, "feature_decoder", None) if conf.with_gui: from threedgrut.utils.gui import GUI - gui = GUI(conf, model, train_dataset, val_dataset, scene_bbox) + gui = 
GUI(conf, model, train_dataset, val_dataset, scene_bbox, feature_decoder=feature_decoder) elif conf.with_viser_gui: from threedgrut.utils.viser_gui_util import ViserGUI - gui = ViserGUI(conf, model, train_dataset, val_dataset, scene_bbox) + gui = ViserGUI(conf, model, train_dataset, val_dataset, scene_bbox, feature_decoder=feature_decoder) self.gui = gui @@ -424,6 +490,84 @@ def init_post_processing(self, conf: DictConfig): else: raise ValueError(f"Unknown post-processing method: {method}") + def init_feature_decoder(self, conf: DictConfig): + """Initialize feature decoder for learned features mode.""" + from threedgrut.model.features import Features + + if self.model.feature_type != Features.Type.NHT: + self.feature_decoder = None + self.feature_decoder_optimizer = None + self.feature_decoder_scheduler = None + return + + dec_conf = conf.model.nht_decoder + if not getattr(dec_conf, "enabled", True): + self.feature_decoder = None + self.feature_decoder_optimizer = None + self.feature_decoder_scheduler = None + return + + from threedgrut.model.feature_decoder import FeatureDecoder + + ray_feature_dim = self.model.ray_feature_dim + dec = conf.model.nht_decoder + hidden_dim = dec.hidden_dim + num_layers = getattr(dec, "num_layers", 4) + dir_encoding = getattr(dec, "dir_encoding", "SphericalHarmonics") + dir_encoding_degree = getattr(dec, "dir_encoding_degree", 3) + sh_scale = getattr(dec, "sh_scale", 1.0) + output_activation = getattr(dec, "output_activation", "Sigmoid") + unpremultiply_alpha = getattr(dec, "unpremultiply_alpha", False) + ema_decay = getattr(dec_conf, "ema_decay", 0.0) + ema_start_step = getattr(dec_conf, "ema_start_step", 0) + logger.info(f"Initializing FeatureDecoder: {ray_feature_dim} -> 3 RGB") + self.feature_decoder = FeatureDecoder( + ray_feature_dim=ray_feature_dim, + hidden_dim=hidden_dim, + num_layers=num_layers, + dir_encoding=dir_encoding, + dir_encoding_degree=dir_encoding_degree, + sh_scale=sh_scale, + output_activation=output_activation, 
+ ema_decay=ema_decay, + ema_start_step=ema_start_step, + unpremultiply_alpha=unpremultiply_alpha, + ).to(self.device) + + lr = dec.learning_rate + weight_decay = getattr(dec, "reg_weight", 0.0) + self.feature_decoder_optimizer = torch.optim.Adam( + self.feature_decoder.parameters(), + lr=lr, + weight_decay=weight_decay, + ) + + scheduler_conf = dec.scheduler + max_steps = getattr(conf, "n_iterations", 30000) + decay_final = float(getattr(scheduler_conf, "decay_final", 0.001)) + if scheduler_conf.type == "exponential": + gamma = decay_final ** (1.0 / max_steps) + self.feature_decoder_scheduler = torch.optim.lr_scheduler.ExponentialLR( + self.feature_decoder_optimizer, + gamma=gamma, + ) + elif scheduler_conf.type == "cosine": + eta_min = lr * decay_final + self.feature_decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.feature_decoder_optimizer, + T_max=max_steps, + eta_min=eta_min, + ) + else: + raise ValueError(f"Unknown scheduler type: {scheduler_conf.type}") + + if ema_decay > 0: + logger.info(f"🎨 FeatureDecoder EMA: decay={ema_decay}, start_step={ema_start_step}") + logger.info( + f"🎨 FeatureDecoder optimizer: lr={lr}, " + f"weight_decay={weight_decay}, scheduler={scheduler_conf.type}" + ) + @torch.cuda.nvtx.range("get_metrics") def get_metrics( self, @@ -448,7 +592,7 @@ def get_metrics( step = self.global_step rgb_gt = gpu_batch.rgb_gt - rgb_pred = outputs["pred_rgb"] + rgb_pred = outputs["pred_features"] psnr = self.criterions["psnr"] ssim = self.criterions["ssim"] @@ -471,18 +615,18 @@ def get_metrics( metrics["psnr"] = psnr(rgb_pred, rgb_gt).item() rgb_gt_full = rgb_gt.permute(0, 3, 1, 2) - pred_rgb_full = rgb_pred.permute(0, 3, 1, 2) - pred_rgb_full_clipped = rgb_pred.clip(0, 1).permute(0, 3, 1, 2) + pred_features_full = rgb_pred.permute(0, 3, 1, 2) + pred_features_full_clipped = rgb_pred.clip(0, 1).permute(0, 3, 1, 2) with torch.cuda.nvtx.range(f"criterions_ssim"): - metrics["ssim"] = ssim(pred_rgb_full, rgb_gt_full).item() + 
metrics["ssim"] = ssim(pred_features_full, rgb_gt_full).item() with torch.cuda.nvtx.range(f"criterions_lpips"): - metrics["lpips"] = lpips(pred_rgb_full_clipped, rgb_gt_full).item() + metrics["lpips"] = lpips(pred_features_full_clipped, rgb_gt_full).item() if iteration in self.conf.writer.log_image_views: metrics["img_hit_counts"] = jet_map(outputs["hits_count"][-1], self.conf.writer.max_num_hits) metrics["img_gt"] = gpu_batch.rgb_gt[-1].clip(0, 1.0) - metrics["img_pred"] = outputs["pred_rgb"][-1].clip(0, 1.0) + metrics["img_pred"] = outputs["pred_features"][-1].clip(0, 1.0) metrics["img_pred_dist"] = jet_map(outputs["pred_dist"][-1], 100) metrics["img_pred_opacity"] = jet_map(outputs["pred_opacity"][-1], 1) @@ -508,7 +652,7 @@ def get_losses( losses: dictionary of loss terms computed for current batch. """ rgb_gt = gpu_batch.rgb_gt - rgb_pred = outputs["pred_rgb"] + rgb_pred = outputs["pred_features"] mask = gpu_batch.mask # Mask out the invalid pixels if the mask is provided @@ -529,7 +673,7 @@ def get_losses( lambda_l2 = 0.0 if self.conf.loss.use_l2: with torch.cuda.nvtx.range(f"loss-l2"): - loss_l2 = torch.nn.functional.mse_loss(outputs["pred_rgb"], rgb_gt) + loss_l2 = torch.nn.functional.mse_loss(outputs["pred_features"], rgb_gt) lambda_l2 = self.conf.loss.lambda_l2 # DSSIM loss @@ -538,14 +682,14 @@ def get_losses( if self.conf.loss.use_ssim: with torch.cuda.nvtx.range(f"loss-ssim"): rgb_gt_full = torch.permute(rgb_gt, (0, 3, 1, 2)) - pred_rgb_full = torch.permute(rgb_pred, (0, 3, 1, 2)) - loss_ssim = 1.0 - ssim(pred_rgb_full, rgb_gt_full) + pred_features_full = torch.permute(rgb_pred, (0, 3, 1, 2)) + loss_ssim = 1.0 - ssim(pred_features_full, rgb_gt_full) lambda_ssim = self.conf.loss.lambda_ssim # Opacity regularization loss_opacity = torch.zeros(1, device=self.device) lambda_opacity = 0.0 - if self.conf.loss.use_opacity: + if self.conf.loss.use_opacity and not self._in_color_refine: with torch.cuda.nvtx.range(f"loss-opacity"): loss_opacity = 
torch.abs(self.model.get_density()).mean() lambda_opacity = self.conf.loss.lambda_opacity @@ -553,7 +697,7 @@ def get_losses( # Scale regularization loss_scale = torch.zeros(1, device=self.device) lambda_scale = 0.0 - if self.conf.loss.use_scale: + if self.conf.loss.use_scale and not self._in_color_refine: with torch.cuda.nvtx.range(f"loss-scale"): loss_scale = torch.abs(self.model.get_scale()).mean() lambda_scale = self.conf.loss.lambda_scale @@ -713,6 +857,8 @@ def log_training_iter( post_processing_reg_loss, global_step, ) + if self._color_refine_start_step < self.conf.n_iterations: + writer.add_scalar("train/color_refine", float(self._in_color_refine), global_step) if "psnr" in batch_metrics: writer.add_scalar("psnr/train", batch_metrics["psnr"], self.global_step) if "ssim" in batch_metrics: @@ -833,17 +979,24 @@ def on_training_end(self): logger.log_rule("Evaluation on Test Set") # Renderer test split - renderer = Renderer.from_preloaded_model( - model=self.model, - out_dir=out_dir, - path=conf.path, - save_gt=False, - writer=self.tracking.writer, - global_step=self.global_step, - compute_extra_metrics=conf.compute_extra_metrics, - post_processing=self.post_processing, - ) - renderer.render_all() + if self.feature_decoder is not None: + self.feature_decoder.apply_ema_shadow() + try: + renderer = Renderer.from_preloaded_model( + model=self.model, + out_dir=out_dir, + path=conf.path, + save_gt=False, + writer=self.tracking.writer, + global_step=self.global_step, + compute_extra_metrics=conf.compute_extra_metrics, + post_processing=self.post_processing, + feature_decoder=self.feature_decoder, + ) + renderer.render_all() + finally: + if self.feature_decoder is not None: + self.feature_decoder.restore_ema() @torch.cuda.nvtx.range(f"save_checkpoint") def save_checkpoint(self, last_checkpoint: bool = False): @@ -860,6 +1013,26 @@ def save_checkpoint(self, last_checkpoint: bool = False): strategy_parameters = self.strategy.get_strategy_parameters() parameters = 
{**parameters, **strategy_parameters} + # Add feature decoder state to checkpoint (module + optimizer + scheduler + EMA) + if self.feature_decoder is not None: + dec = self.feature_decoder + parameters["feature_decoder"] = { + "module": dec.state_dict(), + "optimizer": self.feature_decoder_optimizer.state_dict(), + "scheduler": self.feature_decoder_scheduler.state_dict(), + "arch": { + "ray_feature_dim": dec.ray_feature_dim, + "hidden_dim": dec.hidden_dim, + "num_layers": dec.num_layers, + "sh_scale": dec.sh_scale, + "output_activation": dec.output_activation, + "unpremultiply_alpha": dec.unpremultiply_alpha, + }, + } + ema_state = self.feature_decoder.ema_state_dict() + if ema_state: + parameters["feature_decoder"]["ema"] = ema_state + # Add post-processing state to checkpoint (module + optimizers + schedulers) if self.post_processing is not None: parameters["post_processing"] = { @@ -917,6 +1090,8 @@ def run_train_iter( metrics: list, conf: DictConfig, ): + self._apply_color_refine_freeze(global_step) + # Freeze Gaussians and suspend strategy when distillation starts if self._distillation_start_step >= 0 and global_step >= self._distillation_start_step: self.model.freeze_gaussians() @@ -926,6 +1101,8 @@ def run_train_iter( with torch.cuda.nvtx.range(f"train_iter{global_step}_get_gpu_batch"): gpu_batch = self.train_dataset.get_gpu_batch_with_intrinsics(batch) + profilers["step_total"].start() + # Perform validation if required is_time_to_validate = (global_step > 0 or conf.validate_first) and (global_step % self.val_frequency == 0) if is_time_to_validate: @@ -937,6 +1114,14 @@ def run_train_iter( outputs = self.model(gpu_batch, train=True, frame_id=global_step) profilers["inference"].end() + # Apply feature decoder to convert N-dimensional features to RGB + if self.feature_decoder is not None: + with torch.cuda.nvtx.range(f"train_{global_step}_feature_decoder"): + profilers["feature_decoder"].start() + outputs = apply_feature_decoder(self.feature_decoder, outputs, 
gpu_batch, training=True) + profilers["feature_decoder"].end() + outputs = apply_background(self.model.background, outputs, gpu_batch, training=True) + # Apply post-processing to rendered output if self.post_processing is not None: with torch.cuda.nvtx.range(f"train_{global_step}_post_processing"): @@ -945,6 +1130,14 @@ def run_train_iter( # Compute the losses of a single batch with torch.cuda.nvtx.range(f"train_{global_step}_loss"): batch_losses = self.get_losses(gpu_batch, outputs) + + # Add feature decoder regularization loss + if self.feature_decoder is not None and "decoder_reg_loss" in outputs: + decoder_reg_weight = conf.model.nht_decoder.reg_weight + decoder_reg_loss = decoder_reg_weight * outputs["decoder_reg_loss"] + batch_losses["total_loss"] = batch_losses["total_loss"] + decoder_reg_loss + batch_losses["decoder_reg_loss"] = decoder_reg_loss + # Add post-processing regularization loss if self.post_processing is not None: post_processing_reg_loss = self.post_processing.get_regularization_loss() @@ -979,6 +1172,7 @@ def run_train_iter( # Optimizer step with torch.cuda.nvtx.range(f"train_{global_step}_backprop"): + self._zero_color_refine_frozen_grads() if isinstance(self.model.optimizer, SelectiveAdam): assert ( outputs["mog_visibility"].shape == self.model.density.shape @@ -991,6 +1185,15 @@ def run_train_iter( # Scheduler step with torch.cuda.nvtx.range(f"train_{global_step}_scheduler"): self.model.scheduler_step(global_step) + self._apply_color_refine_freeze(global_step) + + # Feature decoder optimizer/scheduler step + if self.feature_decoder_optimizer is not None: + with torch.cuda.nvtx.range(f"train_{global_step}_feature_decoder_opt"): + self.feature_decoder_optimizer.step() + self.feature_decoder_optimizer.zero_grad() + self.feature_decoder_scheduler.step() + self.feature_decoder.ema_update(global_step) # Post-processing optimizer/scheduler step if self.post_processing_optimizers is not None: @@ -1026,6 +1229,8 @@ def run_train_iter( 
self.model.build_acc(rebuild=True) profilers["build_as"].end() + profilers["step_total"].end() + # Increment the global step global_step += 1 self.global_step = global_step @@ -1067,7 +1272,10 @@ def run_train_pass(self, conf: DictConfig): "inference": CudaTimer(enabled=self.conf.enable_frame_timings), "backward": CudaTimer(enabled=self.conf.enable_frame_timings), "build_as": CudaTimer(enabled=self.conf.enable_frame_timings), + "step_total": CudaTimer(enabled=self.conf.enable_frame_timings), } + if self.feature_decoder is not None: + profilers["feature_decoder"] = CudaTimer(enabled=self.conf.enable_frame_timings) for iter, batch in enumerate(self.train_dataloader): # Check if we have reached the maximum number of iterations @@ -1087,6 +1295,8 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: dictionary of metrics computed and aggregated over validation set. """ + if self.feature_decoder is not None: + self.feature_decoder.apply_ema_shadow() profilers = { "inference": CudaTimer(), } @@ -1106,6 +1316,10 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: with torch.cuda.nvtx.range(f"train.validation_step_{self.global_step}"): profilers["inference"].start() outputs = self.model(gpu_batch, train=False) + # Apply feature decoder to convert N-dimensional features to RGB + if self.feature_decoder is not None: + outputs = apply_feature_decoder(self.feature_decoder, outputs, gpu_batch, training=False) + outputs = apply_background(self.model.background, outputs, gpu_batch, training=False) # Apply post-processing for validation (novel view mode) if self.post_processing is not None: outputs = apply_post_processing(self.post_processing, outputs, gpu_batch, training=False) @@ -1125,6 +1339,8 @@ def run_validation_pass(self, conf: DictConfig) -> dict[str, Any]: metrics.append(batch_metrics) logger.end_progress(task_name="Validation") + if self.feature_decoder is not None: + self.feature_decoder.restore_ema() metrics = 
self._flatten_list_of_dicts(metrics) self.log_validation_pass(metrics) diff --git a/threedgrut/utils/gui.py b/threedgrut/utils/gui.py index 8801f3b9..58b6c498 100644 --- a/threedgrut/utils/gui.py +++ b/threedgrut/utils/gui.py @@ -187,7 +187,7 @@ def render_from_current_ps_view(self): self.render_height = window_h return ( - outputs["pred_rgb"], + outputs["pred_features"], outputs["pred_opacity"], outputs["pred_dist"], outputs["pred_normals"], diff --git a/threedgrut/utils/misc.py b/threedgrut/utils/misc.py index 3757a6f6..b2251404 100644 --- a/threedgrut/utils/misc.py +++ b/threedgrut/utils/misc.py @@ -97,14 +97,28 @@ def helper(step): return helper -def skip_scheduler(type=""): +def cosine_scheduler(lr_init, lr_final, max_steps=1000000, type=""): + """Cosine annealing: lr = lr_final + 0.5 * (lr_init - lr_final) * (1 + cos(pi * step / max_steps)).""" + + def helper(step): + t = np.clip(step / max_steps, 0, 1) + return float(lr_final + 0.5 * (lr_init - lr_final) * (1 + np.cos(np.pi * t))) + + return helper + + +def skip_scheduler(type="", **kwargs): def helper(step): return None return helper -SCHEDULER_DICT: dict[str, Callable] = {"exp": exponential_scheduler, "skip": skip_scheduler} +SCHEDULER_DICT: dict[str, Callable] = { + "exp": exponential_scheduler, + "cosine": cosine_scheduler, + "skip": skip_scheduler, +} def get_scheduler(scheduler: str) -> Callable: diff --git a/threedgrut/utils/render.py b/threedgrut/utils/render.py index 57c0f427..8d6a65ef 100644 --- a/threedgrut/utils/render.py +++ b/threedgrut/utils/render.py @@ -14,6 +14,9 @@ # limitations under the License. +import torch +import torch.nn as nn + ## NOTE: SPH code from gaussian-splatting, from plenoctree, from ??? 
C0 = 0.28209479177387814 C1 = 0.4886025119029199 @@ -48,6 +51,55 @@ def SH2RGB(sh): return sh * C0 + 0.5 +def apply_feature_decoder( + feature_decoder, + outputs: dict, + gpu_batch, + training: bool = False, +) -> dict: + """Apply feature decoder to N-dimensional feature map.""" + if feature_decoder is None: + return outputs + + feature_map = outputs["pred_features"] # [B, H, W, N] alpha-blended features + alpha = outputs["pred_opacity"] # [B, H, W] or [B, H, W, 1] + B, H, W, N = feature_map.shape + + R = gpu_batch.T_to_world[:, :3, :3] # [B, 3, 3] c2w rotation + rays_dir_cam = gpu_batch.rays_dir # [B, H, W, 3] + rays_dir_world = torch.einsum("bij,bhwj->bhwi", R, rays_dir_cam) + rays_dir_world = torch.nn.functional.normalize(rays_dir_world, dim=-1) + + features_flat = feature_map.contiguous().view(-1, N) + ray_dir_flat = rays_dir_world.contiguous().view(-1, 3) + if alpha.dim() == 3: + alpha = alpha.unsqueeze(-1) # [B, H, W, 1] + alpha_flat = alpha.contiguous().view(-1, 1) + + rgb_flat = feature_decoder(features_flat, ray_dir_flat, alpha=alpha_flat) + outputs["pred_features"] = rgb_flat.view(B, H, W, 3) + + if training and hasattr(feature_decoder, "regularization_loss"): + outputs["decoder_reg_loss"] = feature_decoder.regularization_loss() + + return outputs + + +def apply_background(background, outputs: dict, gpu_batch, training: bool = False) -> dict: + """Apply background to decoded RGB (3-channel). 
Call after apply_feature_decoder when using nht.""" + if background is None or outputs["pred_features"].shape[-1] != 3: + return outputs + pred_features, pred_opacity = background( + gpu_batch.T_to_world.contiguous(), + gpu_batch.rays_dir.contiguous(), + outputs["pred_features"], + outputs["pred_opacity"], + training, + ) + outputs["pred_features"] = pred_features + return outputs + + def apply_post_processing( post_processing, outputs: dict, @@ -58,28 +110,28 @@ def apply_post_processing( Args: post_processing: Post-processing module - outputs: Model outputs including pred_rgb + outputs: Model outputs including pred_features gpu_batch: Batch containing camera_idx, frame_idx, pixel_coords, exposure training: If True, use actual frame_idx; if False, use -1 for novel view mode Returns: - Updated outputs dict with post-processed pred_rgb + Updated outputs dict with post-processed pred_features """ - assert outputs["pred_rgb"].shape[0] == 1, "Post-processing requires batch_size=1" + assert outputs["pred_features"].shape[0] == 1, "Post-processing requires batch_size=1" - pred_rgb = outputs["pred_rgb"] + pred_features = outputs["pred_features"] camera_idx = gpu_batch.camera_idx frame_idx = gpu_batch.frame_idx if training else -1 - H, W = pred_rgb.shape[1], pred_rgb.shape[2] + H, W = pred_features.shape[1], pred_features.shape[2] # Flatten: [1, H, W, 3] -> [H*W, 3] # Ensure contiguous memory for CUDA kernels - pred_rgb_flat = pred_rgb.contiguous().view(-1, 3) + pred_features_flat = pred_features.contiguous().view(-1, 3) pixel_coords_flat = gpu_batch.pixel_coords.contiguous().view(-1, 2) # Apply post-processing - pred_rgb_pp = post_processing( - pred_rgb_flat, + pred_features_pp = post_processing( + pred_features_flat, pixel_coords_flat, resolution=(W, H), camera_idx=camera_idx, @@ -88,5 +140,5 @@ def apply_post_processing( ) # Reshape back: [H*W, 3] -> [1, H, W, 3] - outputs["pred_rgb"] = pred_rgb_pp.view(pred_rgb.shape) + outputs["pred_features"] = 
pred_features_pp.view(pred_features.shape) return outputs diff --git a/threedgrut/utils/viser_gui_util.py b/threedgrut/utils/viser_gui_util.py index c986aa0e..799f497b 100644 --- a/threedgrut/utils/viser_gui_util.py +++ b/threedgrut/utils/viser_gui_util.py @@ -252,7 +252,7 @@ def render_from_current_view( # points_plane = self.model.positions[u, v] # Return the same outputs as polyscope version return ( - outputs["pred_rgb"], + outputs["pred_features"], outputs["pred_opacity"], outputs["pred_dist"], outputs["pred_normals"], diff --git a/threedgrut_playground/engine.py b/threedgrut_playground/engine.py index 65b8da08..97be5d27 100644 --- a/threedgrut_playground/engine.py +++ b/threedgrut_playground/engine.py @@ -882,7 +882,7 @@ def _render_depth_of_field_buffer(self, rb, camera, rays): if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: dof_rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: @@ -890,7 +890,7 @@ def _render_depth_of_field_buffer(self, rb, camera, rays): else: dof_rb = self._render_playground_hybrid(dof_rays_ori, dof_rays_dir) - rb["rgb"] = self._accumulate_to_buffer(rb["rgb"], dof_rb["pred_rgb"], i, self.gamma_correction) + rb["rgb"] = self._accumulate_to_buffer(rb["rgb"], dof_rb["pred_features"], i, self.gamma_correction) rb["opacity"] = (rb["opacity"] * i + dof_rb["pred_opacity"]) / (i + 1) def _render_spp_buffer(self, rb, rays): @@ -902,14 +902,14 @@ def _render_spp_buffer(self, rb, rays): if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: spp_rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: spp_rb = self.scene_mog.trace(rays_o=rays.rays_ori, rays_d=rays.rays_dir) else: spp_rb = 
self._render_playground_hybrid(rays.rays_ori, rays.rays_dir) - batch_rgb = spp_rb["pred_rgb"].sum(dim=0).unsqueeze(0) + batch_rgb = spp_rb["pred_features"].sum(dim=0).unsqueeze(0) rb["rgb"] = self._accumulate_to_buffer( rb["rgb"], batch_rgb, i, self.gamma_correction, batch_size=self.spp.batch_size ) @@ -928,7 +928,7 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) Returns: Dict[str, torch.Tensor]: Rendering results containing: - - 'pred_rgb': RGB colors of shape (B, H, W, 3), range [0, 1] + - 'pred_features': RGB colors of shape (B, H, W, 3), range [0, 1] - 'pred_opacity': Opacity values of shape (B, H, W, 1), range [0, 1] - 'last_ray_d': Final ray directions for background computation - Additional buffers from tracer.render_playground() @@ -979,14 +979,14 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) max_pbr_bounces=self.max_pbr_bounces, ) - pred_rgb = rendered_results["pred_rgb"] + pred_features = rendered_results["pred_features"] pred_opacity = rendered_results["pred_opacity"] # If no envmap is used for background, saturate the color channels by blending the mog background if envmap is None or not self.environment.is_ignore_envmap(): poses = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=torch.float32) - pred_rgb, pred_opacity = mog.background( - poses.contiguous(), rendered_results["last_ray_d"].contiguous(), pred_rgb, pred_opacity, False + pred_features, pred_opacity = mog.background( + poses.contiguous(), rendered_results["last_ray_d"].contiguous(), pred_features, pred_opacity, False ) # Mark materials as uploaded @@ -995,9 +995,9 @@ def _render_playground_hybrid(self, rays_o: torch.Tensor, rays_d: torch.Tensor) # Advance frame id (for i.e., random number generator) and avoid int32 overflow self.frame_id = self.frame_id + self.spp.batch_size if self.frame_id <= (2**31 - 1) else 0 - pred_rgb = torch.clamp(pred_rgb, 0.0, 1.0) # Make sure image pixels are in valid range + 
pred_features = torch.clamp(pred_features, 0.0, 1.0) # Make sure image pixels are in valid range - rendered_results["pred_rgb"] = pred_rgb + rendered_results["pred_features"] = pred_features return rendered_results @torch.cuda.nvtx.range("render_pass") @@ -1065,7 +1065,7 @@ def render_pass(self, camera: Camera, is_first_pass: bool) -> Dict[str, torch.Te if not self.primitives.enabled or not self.primitives.has_visible_objects(): if self.disable_gaussian_tracing: rb = dict( - pred_rgb=torch.zeros_like(rays.rays_ori), + pred_features=torch.zeros_like(rays.rays_ori), pred_opacity=torch.zeros_like(rays.rays_ori[:, :, 0:1]), ) else: @@ -1073,7 +1073,7 @@ def render_pass(self, camera: Camera, is_first_pass: bool) -> Dict[str, torch.Te else: rb = self._render_playground_hybrid(rays.rays_ori, rays.rays_dir) - rb = dict(rgb=rb["pred_rgb"], opacity=rb["pred_opacity"]) + rb = dict(rgb=rb["pred_features"], opacity=rb["pred_opacity"]) rb["rgb"] = self.environment.tonemap(rb["rgb"]) rb["rgb"] = torch.pow(rb["rgb"], 1.0 / self.gamma_correction) rb["rgb"] = rb["rgb"].mean(dim=0).unsqueeze(0) diff --git a/threedgrut_playground/include/playground/hybridTracer.h b/threedgrut_playground/include/playground/hybridTracer.h index 07444f2c..87d2d769 100644 --- a/threedgrut_playground/include/playground/hybridTracer.h +++ b/threedgrut_playground/include/playground/hybridTracer.h @@ -129,7 +129,7 @@ class HybridOptixTracer : public OptixTracer { torch::Tensor> traceHybrid(uint32_t frameNumber, torch::Tensor rayToWorld, torch::Tensor rayOri, torch::Tensor rayDir, - torch::Tensor particleDensity, torch::Tensor particleRadiance, + torch::Tensor particleDensity, torch::Tensor particleFeatures, int sphDegree, float minTransmittance, torch::Tensor rayMaxT, uint32_t playgroundOpts, torch::Tensor triangles, torch::Tensor vNormals, torch::Tensor vTangents, diff --git a/threedgrut_playground/include/playground/kernels/cuda/trace.cuh b/threedgrut_playground/include/playground/kernels/cuda/trace.cuh 
index 717b9a8f..85130499 100644 --- a/threedgrut_playground/include/playground/kernels/cuda/trace.cuh +++ b/threedgrut_playground/include/playground/kernels/cuda/trace.cuh @@ -119,9 +119,9 @@ static __device__ __forceinline__ void clearOutputBuffers() { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] = 0.0f; - params.rayRadiance[idx.z][ry][rx][1] = 0.0f; - params.rayRadiance[idx.z][ry][rx][2] = 0.0f; + params.rayFeatures[idx.z][ry][rx][0] = 0.0f; + params.rayFeatures[idx.z][ry][rx][1] = 0.0f; + params.rayFeatures[idx.z][ry][rx][2] = 0.0f; params.rayDensity[idx.z][ry][rx][0] = 0.0f; } @@ -133,9 +133,9 @@ writeRadianceDensityToOutputBuffer(float4 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] = radiance.x; - params.rayRadiance[idx.z][ry][rx][1] = radiance.y; - params.rayRadiance[idx.z][ry][rx][2] = radiance.z; + params.rayFeatures[idx.z][ry][rx][0] = radiance.x; + params.rayFeatures[idx.z][ry][rx][1] = radiance.y; + params.rayFeatures[idx.z][ry][rx][2] = radiance.z; params.rayDensity[idx.z][ry][rx][0] = radiance.w; } @@ -147,9 +147,9 @@ accumulateRadianceToOutputBuffer(float3 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] += radiance.x; - params.rayRadiance[idx.z][ry][rx][1] += radiance.y; - params.rayRadiance[idx.z][ry][rx][2] += radiance.z; + params.rayFeatures[idx.z][ry][rx][0] += radiance.x; + params.rayFeatures[idx.z][ry][rx][1] += radiance.y; + params.rayFeatures[idx.z][ry][rx][2] += radiance.z; } static __device__ __forceinline__ void @@ -160,9 +160,9 @@ accumulateRadianceDensityToOutputBuffer(float4 radiance) { const int ry = fminf(idx.y, params.frameBounds.y); // Ray coordinates in pixels - params.rayRadiance[idx.z][ry][rx][0] += radiance.x; - params.rayRadiance[idx.z][ry][rx][1] += radiance.y; - 
params.rayRadiance[idx.z][ry][rx][2] += radiance.z; + params.rayFeatures[idx.z][ry][rx][0] += radiance.x; + params.rayFeatures[idx.z][ry][rx][1] += radiance.y; + params.rayFeatures[idx.z][ry][rx][2] += radiance.z; params.rayDensity[idx.z][ry][rx][0] += radiance.w; } diff --git a/threedgrut_playground/src/hybridTracer.cpp b/threedgrut_playground/src/hybridTracer.cpp index 6cc95306..05683265 100644 --- a/threedgrut_playground/src/hybridTracer.cpp +++ b/threedgrut_playground/src/hybridTracer.cpp @@ -314,7 +314,7 @@ std::tuple(rayDir); paramsHost.particleDensity = getPtr(particleDensity); - paramsHost.particleRadiance = getPtr(particleRadiance); + paramsHost.particleFeatures = getPtr(particleFeatures); paramsHost.particleExtendedData = reinterpret_cast(_state->gPipelineParticleData); - paramsHost.rayRadiance = packed_accessor32(rayRad); + paramsHost.rayFeatures = packed_accessor32(rayRad); paramsHost.rayDensity = packed_accessor32(rayDns); paramsHost.rayHitDistance = packed_accessor32(rayHit); paramsHost.rayNormal = packed_accessor32(rayNrm); diff --git a/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu b/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu index 9fdd2b4e..a48976cf 100644 --- a/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu +++ b/threedgrut_playground/src/kernels/cuda/3dgrtKernel.cu @@ -28,9 +28,9 @@ extern "C" __global__ void __raygen__rg() { traceVolumetricGS(rayData, rayOrigin, rayDirection, 1e-9, 1e9); - params.rayRadiance[idx.z][idx.y][idx.x][0] = rayData.radiance.x; - params.rayRadiance[idx.z][idx.y][idx.x][1] = rayData.radiance.y; - params.rayRadiance[idx.z][idx.y][idx.x][2] = rayData.radiance.z; + params.rayFeatures[idx.z][idx.y][idx.x][0] = rayData.radiance.x; + params.rayFeatures[idx.z][idx.y][idx.x][1] = rayData.radiance.y; + params.rayFeatures[idx.z][idx.y][idx.x][2] = rayData.radiance.z; params.rayDensity[idx.z][idx.y][idx.x][0] = rayData.density; params.rayHitDistance[idx.z][idx.y][idx.x][0] = rayData.hitDistance; 
params.rayHitDistance[idx.z][idx.y][idx.x][1] = rayData.rayLastHitDistance; diff --git a/threedgrut_playground/tracer.py b/threedgrut_playground/tracer.py index 77e03799..566ca93d 100644 --- a/threedgrut_playground/tracer.py +++ b/threedgrut_playground/tracer.py @@ -115,7 +115,7 @@ def render(self, gaussians, gpu_batch, train=False, frame_id=0): mog_scl = gaussians.get_scale().contiguous() particle_density = torch.concat([mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1) - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace( + pred_features, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace( frame_id, gpu_batch["poses"].contiguous(), gpu_batch["rays_o_cam"].contiguous(), @@ -127,12 +127,12 @@ def render(self, gaussians, gpu_batch, train=False, frame_id=0): ) # NOTE: disable background - pred_rgb, pred_opacity = gaussians.background( - gpu_batch["poses"].contiguous(), gpu_batch["rays_d_cam"].contiguous(), pred_rgb, pred_opacity, train + pred_features, pred_opacity = gaussians.background( + gpu_batch["poses"].contiguous(), gpu_batch["rays_d_cam"].contiguous(), pred_features, pred_opacity, train ) return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": torch.nn.functional.normalize(pred_normals, dim=3), @@ -224,7 +224,7 @@ def render_playground( min_transmittance = self.conf.render.min_transmittance envmap_offset = envmap_offset.contiguous() - pred_rgb, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace_hybrid( + pred_features, pred_opacity, pred_dist, pred_normals, hits_count = self.tracer_wrapper.trace_hybrid( frame_id, poses, ray_o, @@ -253,7 +253,7 @@ def render_playground( pred_dist = pred_dist[:, :, :, 0:1] # return only the hit distance return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, "pred_normals": 
torch.nn.functional.normalize(pred_normals, dim=3), diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh index 8fca0f5e..bdc91a74 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh @@ -107,6 +107,20 @@ __device__ __inline__ RayPayloadT initializeRay(const threedgut::RenderParameter return ray; } +// Output element type for the feature+density buffer. +// FEATURE_OUTPUT_HALF=1: write __half (fp16) to halve memory bandwidth. +// FEATURE_OUTPUT_HALF=0: write float (fp32, default). +#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif + +#if FEATURE_OUTPUT_HALF +#include +using TFeatureDensityElem = __half; +#else +using TFeatureDensityElem = float; +#endif + // Initialize ray based on given pixel coordinates (load-balanced mode) template __device__ __inline__ RayPayloadT initializeRayPerPixel(const threedgut::RenderParameters& params, @@ -149,13 +163,27 @@ __device__ __inline__ void finalizeRay(const TRayPayload& ray, const tcnn::vec3* __restrict__ sensorRayOriginPtr, float* __restrict__ worldCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::mat4x3& sensorToWorldTransform) { if (!ray.isValid()) { return; } - radianceDensityPtr[ray.idx] = {ray.features[0], ray.features[1], ray.features[2], (1.0f - ray.transmittance)}; + static_assert(RAY_FEATURE_DIM == TRayPayload::FeatDim, "RAY_FEATURE_DIM must equal TRayPayload::FeatDim"); + const uint32_t base = ray.idx * (RAY_FEATURE_DIM + 1); +#if FEATURE_OUTPUT_HALF + #pragma unroll + for (int i = 0; i < TRayPayload::FeatDim; ++i) { + featureDensityPtr[base + i] = __float2half(ray.features[i]); + } + featureDensityPtr[base + RAY_FEATURE_DIM] = __float2half(1.0f - ray.transmittance); +#else + #pragma unroll + for 
(int i = 0; i < TRayPayload::FeatDim; ++i) { + featureDensityPtr[base + i] = ray.features[i]; + } + featureDensityPtr[base + RAY_FEATURE_DIM] = (1.0f - ray.transmittance); +#endif worldHitDistancePtr[ray.idx] = ray.hitT; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh index 67c44db4..ff37f159 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayloadBackward.cuh @@ -33,8 +33,8 @@ __device__ __inline__ RayPayloadT initializeBackwardRay(const threedgut::RenderP const tcnn::vec3* __restrict__ sensorRayDirectionPtr, const float* __restrict__ worldHitDistancePtr, const float* __restrict__ worldHitDistanceGradientPtr, - const tcnn::vec* __restrict__ featuresDensityPtr, - const tcnn::vec* __restrict__ featuresDensityGradientPtr, + const TFeatureDensityElem* __restrict__ featuresDensityPtr, + const float* __restrict__ featuresDensityGradientPtr, const tcnn::mat4x3& sensorToWorldTransform) { // NB : no backpropagation through the forward ray initialization / finalization @@ -44,14 +44,29 @@ __device__ __inline__ RayPayloadT initializeBackwardRay(const threedgut::RenderP sensorToWorldTransform); if (ray.isAlive()) { - const tcnn::vec featuresDensity = featuresDensityPtr[ray.idx]; - const tcnn::vec featuresDensityGradient = featuresDensityGradientPtr[ray.idx]; - ray.transmittanceBackward = 1.f - featuresDensity[RayPayloadT::FeatDim]; - ray.transmittanceGradient = -1.f * featuresDensityGradient[RayPayloadT::FeatDim]; - ray.hitTBackward = worldHitDistancePtr[ray.idx]; - ray.hitTGradient = worldHitDistanceGradientPtr[ray.idx]; - ray.featuresBackward = threedgut::sliceVec<0, RayPayloadT::FeatDim>(featuresDensity); - ray.featuresGradient = threedgut::sliceVec<0, RayPayloadT::FeatDim>(featuresDensityGradient); + constexpr uint32_t stride = RayPayloadT::FeatDim + 1; + const 
uint32_t base = ray.idx * stride; + // Forward features: fp16 when FEATURE_OUTPUT_HALF=1, fp32 otherwise. + // Gradient buffer: always fp32 — keeps backward numerically stable regardless of forward dtype. +#if FEATURE_OUTPUT_HALF + #pragma unroll + for (int i = 0; i < RayPayloadT::FeatDim; ++i) { + ray.featuresBackward[i] = __half2float(featuresDensityPtr[base + i]); + ray.featuresGradient[i] = featuresDensityGradientPtr[base + i]; + } + ray.transmittanceBackward = 1.f - __half2float(featuresDensityPtr[base + RayPayloadT::FeatDim]); + ray.transmittanceGradient = -1.f * featuresDensityGradientPtr[base + RayPayloadT::FeatDim]; +#else + #pragma unroll + for (int i = 0; i < RayPayloadT::FeatDim; ++i) { + ray.featuresBackward[i] = featuresDensityPtr[base + i]; + ray.featuresGradient[i] = featuresDensityGradientPtr[base + i]; + } + ray.transmittanceBackward = 1.f - featuresDensityPtr[base + RayPayloadT::FeatDim]; + ray.transmittanceGradient = -1.f * featuresDensityGradientPtr[base + RayPayloadT::FeatDim]; +#endif + ray.hitTBackward = worldHitDistancePtr[ray.idx]; + ray.hitTGradient = worldHitDistanceGradientPtr[ray.idx]; } return ray; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh index 0d311552..7200ae48 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/gaussianParticles.cuh @@ -15,6 +15,10 @@ #include <3dgut/kernels/cuda/common/mathUtils.cuh> +#if PARTICLE_FEATURE_HALF +#include +#endif + namespace threedgut { struct ParticleDensity { @@ -382,7 +386,7 @@ __device__ inline bool processHitFwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); const bool acceptHit = (gres > 
minParticleKernelDensity) && (galpha > minParticleAlpha); if (acceptHit) { @@ -525,7 +529,7 @@ __device__ inline void processHitBwd( const float grayDist = dot(gcrod, gcrod); const float gres = particleResponse(grayDist); - const float galpha = fminf(0.99f, gres * particleDensity); + const float galpha = fminf(GAUSSIAN_PARTICLE_MAX_ALPHA, gres * particleDensity); if ((gres > minParticleKernelDensity) && (galpha > minParticleAlpha)) { @@ -746,4 +750,237 @@ __device__ inline void processHitBwd( } } +// ======================================================================================= +// Neural Harmonic Textures — handwritten CUDA backward. +// +// Mirrors `particleFeaturesIntegrateBwdToBuffer` from neuralHarmonicFeaturesParticle.slang +// called with `exclusiveGradient=true` and a shifted local-grad buffer (see +// `featuresIntegrateBwdToLocalGrad` in shRadiativeGaussianParticles.cuh for context). +// +// Scope: tetrahedral 4-vertex barycentric interpolation + {None, Siren, Sincos, Relu} +// activation, fp16/fp32 particle feature storage, fp32 gradients. +// ======================================================================================= +namespace nht { + +// Canonical tetrahedron matching the GSplat NHT reference vertex ordering. +// v0=(sqrt(6),-sqrt(2),-1), v1=(-sqrt(6),-sqrt(2),-1), v2=(0,2*sqrt(2),-1), v3=(0,0,3). +__forceinline__ __device__ float3 tetraV0() { + return make_float3(2.449489742783178f, -1.4142135623730951f, -1.0f); +} +// Scaled inward face normals N_k = d(w_k)/dP from the GSplat reference tetrahedron. +// Verified: w_k = 1 exactly at v_k, 0 at the other vertices, sum_k w_k = 1. 
+__forceinline__ __device__ float3 tetraN0() { + return make_float3( 0.20412414523193154f, -0.11785113019775792f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN1() { + return make_float3(-0.20412414523193154f, -0.11785113019775792f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN2() { + return make_float3( 0.00000000000000000f, 0.23570226039551587f, -0.08333333333333333f); +} +__forceinline__ __device__ float3 tetraN3() { + return make_float3( 0.00000000000000000f, 0.00000000000000000f, 0.25000000000000000f); +} + +// Mirrors neuralHarmonicFeaturesParticle.slang's FeatureActivationType_*. +enum ActivationType : int { + ActivationNone = 0, + ActivationSiren = 1, + ActivationSincos = 2, + ActivationRelu = 3, +}; + +// Single-element load promoting fp16 to fp32 when needed; identity on float. +// Uses __half's device `operator float()` (cuda_fp16.h) when TFeatElem is __half. +template +__forceinline__ __device__ float loadFeatureElem(const TFeatElem* p) { + return static_cast(*p); +} + +/** + * Handwritten reverse of Slang's `particleFeaturesIntegrateBwdToBuffer` for the NHT path. + * Writes per-vertex gradient contributions directly into a thread-private `featureLocalGrad` + * scratch buffer of size `4*InterpPointFeatureDim` (caller warp-reduces + atomic-adds + * downstream — see `featureLocalGradWarpReduceAndWrite`). + * + * Buffer layout assumption: `vertex_k[i] = particleFeatureBufPtr[particleIdx*4*IPFD + k*IPFD + i]`. + * + * Walk order: renderer traverses hits front-to-back (near → far) for both fwd and bwd. + * Compositing algebra: Slang decomposes fwd as the equivalent back-to-front lerp + * C_i = alpha_i * f_i + (1 - alpha_i) * C_{i+1}, C_0 = forward-final color, C_N = 0. 
+ * Each call to this function unwinds one lerp step of the current particle, advancing + * the stored accumulator from C_i to C_{i+1} — consistent with the front-to-back walk + * starting from `integratedFeatures = C_0` (set by initializeBackwardRay). + * + * Mutations (match Slang `particleFeaturesIntegrateBwdToBuffer` exactly): + * - integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) / (1 - alpha) + * (recovers C_{i+1} from C_i; seen by the next hit) + * - integratedFeaturesGrad[i] = (1 - alpha) * d(acc_post)_i + * (d(C_i)/d(C_{i+1}) from the lerp VJP) + * - alphaGrad += Σ_i (features[i] - acc_prev_i) * d(acc_post)_i + * - canonicalIntersectionGrad += barycentric VJP + * - featureLocalGrad[k*IPFD+i]+= w_k * dBase[i] (matches `exclusiveGradient=true`) + * + * No-op when `alpha <= 0`. + */ +template +__forceinline__ __device__ void featuresIntegrateBwdToLocalGrad( + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, + float alpha, + float& alphaGrad, + uint32_t particleIdx, + const float* features, // [RayFeatureDim] read-only + float* integratedFeatures, // [RayFeatureDim] inout (→ acc_prev) + float* integratedFeaturesGrad, // [RayFeatureDim] inout (→ d(acc_prev)) + const TFeatElem* particleFeatureBufPtr, // [N * 4*IPFD] + float* featureLocalGrad) { // [4*IPFD] inout (+= accumulator) + + static_assert(ActivationType >= 0 && ActivationType <= 3, + "unsupported FEATURE_ACTIVATION_TYPE for NHT CUDA backward"); + constexpr int kIPFD = InterpPointFeatureDim; + constexpr int kRFD = (ActivationType == ActivationNone || ActivationType == ActivationRelu) + ? kIPFD + : (ActivationType == ActivationSincos ? kIPFD * NumFrequencies * 2 : kIPFD * NumFrequencies); + constexpr int kPFD = 4 * kIPFD; + + if (alpha <= 0.0f) { + return; + } + + const float oneMinusAlpha = 1.0f - alpha; + const float invOneMinusAlpha = 1.0f / oneMinusAlpha; + + // Step 1+2: advance stored accumulator C_i -> C_{i+1} (one lerp unwind) AND apply the lerp VJP. 
+ // fwd (lerp form): y_i = (1-alpha)*acc_prev_i + alpha*f_i + // unwind: acc_prev_i = (y_i - alpha*f_i) / (1-alpha) + // VJP: d(acc_prev)_i = (1-alpha) * dy_i + // d(f)_i = alpha * dy_i + // d(alpha) += sum_i (f_i - acc_prev_i) * dy_i + float dFeatures[kRFD]; + float dAlphaAcc = 0.0f; +#pragma unroll + for (int i = 0; i < kRFD; ++i) { + const float dy_i = integratedFeaturesGrad[i]; + const float accPrev_i = (integratedFeatures[i] - features[i] * alpha) * invOneMinusAlpha; + dFeatures[i] = alpha * dy_i; + dAlphaAcc += (features[i] - accPrev_i) * dy_i; + integratedFeatures[i] = accPrev_i; + integratedFeaturesGrad[i] = oneMinusAlpha * dy_i; + } + alphaGrad += dAlphaAcc; + + // Barycentric weights (Slang Cramer form). + float w[4]; + { + const float3 v0 = tetraV0(); + const float3 N1 = tetraN1(); + const float3 N2 = tetraN2(); + const float3 N3 = tetraN3(); + const float3 d = make_float3(canonicalIntersection.x - v0.x, + canonicalIntersection.y - v0.y, + canonicalIntersection.z - v0.z); + w[1] = d.x * N1.x + d.y * N1.y + d.z * N1.z; + w[2] = d.x * N2.x + d.y * N2.y + d.z * N2.z; + w[3] = d.x * N3.x + d.y * N3.y + d.z * N3.z; + w[0] = 1.0f - w[1] - w[2] - w[3]; + } + + // Load all 4 vertex feature blocks once (fp16 → fp32 if applicable). + const uint32_t particleOffset = particleIdx * kPFD; + float vert[4][kIPFD]; +#pragma unroll + for (int k = 0; k < 4; ++k) { + const uint32_t vkOff = particleOffset + k * kIPFD; +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + vert[k][i] = loadFeatureElem(&particleFeatureBufPtr[vkOff + i]); + } + } + + // Step 3: activation backward → dBase[kIPFD]. + float dBase[kIPFD]; + if constexpr (ActivationType == ActivationNone) { +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = dFeatures[i]; + } + } else if constexpr (ActivationType == ActivationRelu) { + // features[i] = max(0, base[i]); mask from features[i] > 0 (bit-identical to base[i] > 0). 
+#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = (features[i] > 0.0f) ? dFeatures[i] : 0.0f; + } + } else { + // Siren / Sincos: need baseFeatures for the trig derivative. + float baseFeatures[kIPFD] = {}; +#pragma unroll + for (int k = 0; k < 4; ++k) { +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + baseFeatures[i] += w[k] * vert[k][i]; + } + } +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + dBase[i] = 0.0f; + } + if constexpr (ActivationType == ActivationSiren) { + // features[k*NFreq+f] = sin(base[k] * 2^f) → dBase[k] += cos(angle)*freq*dOut +#pragma unroll + for (int k = 0; k < kIPFD; ++k) { +#pragma unroll + for (int f = 0; f < NumFrequencies; ++f) { + const float freq = ldexpf(1.0f, f); + const float angle = baseFeatures[k] * freq; + dBase[k] += cosf(angle) * freq * dFeatures[k * NumFrequencies + f]; + } + } + } else { + // ActivationSincos follows GSplat: separate sin/cos channels per frequency. +#pragma unroll + for (int k = 0; k < kIPFD; ++k) { +#pragma unroll + for (int f = 0; f < NumFrequencies; ++f) { + const float freq = static_cast(f + 1); + const float angle = baseFeatures[k] * freq; + float s, c; + __sincosf(angle, &s, &c); + const int outIdx = k * NumFrequencies * 2 + f * 2; + dBase[k] += (c * dFeatures[outIdx] - s * dFeatures[outIdx + 1]) * freq; + } + } + } + } + + // Step 4: barycentric backward. + // d(vertex_k[i]) = w_k * dBase[i] → featureLocalGrad (+= matches Slang exclusiveGradient=true) + // d(canonicalPos) += sum_k (sum_i vert_k[i]*dBase[i]) * N_k + float3 dP = make_float3(0.0f, 0.0f, 0.0f); +#pragma unroll + for (int k = 0; k < 4; ++k) { + float dw_k = 0.0f; +#pragma unroll + for (int i = 0; i < kIPFD; ++i) { + featureLocalGrad[k * kIPFD + i] += w[k] * dBase[i]; + dw_k += vert[k][i] * dBase[i]; + } + const float3 Nk = (k == 0) ? tetraN0() + : (k == 1) ? tetraN1() + : (k == 2) ? 
tetraN2() + : tetraN3(); + dP.x += dw_k * Nk.x; + dP.y += dw_k * Nk.y; + dP.z += dw_k * Nk.z; + } + canonicalIntersectionGrad.x += dP.x; + canonicalIntersectionGrad.y += dP.y; + canonicalIntersectionGrad.z += dP.z; +} + +} // namespace nht + } // namespace threedgut \ No newline at end of file diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh index a4aaed5f..497f4f61 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/models/shRadiativeGaussianParticles.cuh @@ -17,6 +17,18 @@ #include <3dgut/kernels/cuda/models/gaussianParticles.cuh> #include <3dgut/renderer/renderParameters.h> +#if PARTICLE_FEATURE_HALF +#include +#endif + +// Select backend for the NHT-path feature integration local-grad backward. +// 1 → handwritten CUDA (threedgut::nht::featuresIntegrateBwdToLocalGrad), +// 0 → Slang autodiff (threedgutSlang.cuh particleFeaturesIntegrateBwdToBuffer). +// Only effective when FEATURE_TRANSFORM_TYPE == 1 (NHT). A/B switch for perf work. 
+#ifndef NHT_FEATURES_BWD_LOCAL_GRAD_CUDA +#define NHT_FEATURES_BWD_LOCAL_GRAD_CUDA 1 +#endif + template struct ShRadiativeGaussianParticlesBuffer { TBuffer* ptr = nullptr; @@ -65,7 +77,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams __forceinline__ __device__ DensityParameters fetchDensityParameters(uint32_t particleIdx) const { const auto parameters = particleDensityParameters( particleIdx, - {reinterpret_cast(m_densityRawParameters.ptr), nullptr}); + {reinterpret_cast(m_densityRawParameters.ptr), nullptr, false}); return *reinterpret_cast(¶meters); } @@ -95,13 +107,14 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams const DensityParameters& parameters, float& alpha, float& depth, + float3& canonicalIntersection, tcnn::vec3* normal = nullptr) const { - return particleDensityHit(*reinterpret_cast(&rayOrigin), *reinterpret_cast(&rayDirection), reinterpret_cast(parameters), &alpha, &depth, + &canonicalIntersection, normal != nullptr, reinterpret_cast(normal)); } @@ -127,12 +140,14 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams float& transmittance, float& integratedDepth, tcnn::vec3* integratedNormal = nullptr) const { + float3 unusedCanonicalIntersection = make_float3(0.f, 0.f, 0.f); return particleDensityProcessHitFwdFromBuffer(*reinterpret_cast(&rayOrigin), *reinterpret_cast(&rayDirection), particleIdx, {{reinterpret_cast(m_densityRawParameters.ptr), nullptr, true}}, &transmittance, &integratedDepth, + &unusedCanonicalIntersection, integratedNormal != nullptr, reinterpret_cast(integratedNormal)); } @@ -148,6 +163,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams float depth, float& integratedDepth, float& integratedDepthGrad, + const float3& canonicalIntersectionGrad, const tcnn::vec3* normal = nullptr, tcnn::vec3* integratedNormal = nullptr, tcnn::vec3* integratedNormalGrad = nullptr @@ -167,6 +183,7 @@ struct 
ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams depth, &integratedDepth, &integratedDepthGrad, + canonicalIntersectionGrad, normal != nullptr, normal == nullptr ? make_float3(0, 0, 0) : *reinterpret_cast(normal), reinterpret_cast(integratedNormal), @@ -218,59 +235,64 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams template __forceinline__ __device__ void densityIncidentDirectionBwdToBuffer(uint32_t particlesIdx, - const tcnn::vec3& sourcePosition) - + const tcnn::vec3& sourcePosition, + const tcnn::vec3& incidentDirectionGrad) { - particleDensityIncidentDirectionBwdToBuffer(particlesIdx, - {{reinterpret_cast(m_densityRawParameters.ptr), - reinterpret_cast(m_densityRawParameters.gradPtr), - exclusiveGradient}}, - *reinterpret_cast(&sourcePosition)); + if constexpr (TDifferentiable) { + particleDensityIncidentDirectionBwdToBuffer( + particlesIdx, + {{reinterpret_cast(m_densityRawParameters.ptr), + reinterpret_cast(m_densityRawParameters.gradPtr), + exclusiveGradient}}, + *reinterpret_cast(&sourcePosition), + *reinterpret_cast(&incidentDirectionGrad)); + } } - using FeaturesParameters = shRadiativeParticle_Parameters_0; - using TFeaturesVec = typename tcnn::vec; + using TFeaturesVec = typename tcnn::vec; +#if PARTICLE_FEATURE_HALF + using TFeatureRawParamPtr = __half; +#else + using TFeatureRawParamPtr = float; +#endif inline __device__ void initializeFeatures(threedgut::MemoryHandles parameters) { - static_assert(ExtParams::FeaturesDim == 3, "Hardcoded 3-dimensional radiance because of Slang-Cuda interop"); - m_featureRawParameters.ptr = parameters.bufferPtr(Params::FeaturesRawParametersBufferIndex); + m_featureRawParameters.ptr = parameters.bufferPtr(Params::FeaturesRawParametersBufferIndex); m_featureActiveShDegree = *reinterpret_cast(parameters.bufferPtr(Params::GlobalParametersValueBufferIndex) + Params::FeatureShDegreeValueOffset); }; inline __device__ void 
initializeFeaturesGradient(threedgut::MemoryHandles parametersGradient) { if constexpr (TDifferentiable) { - m_featureRawParameters.gradPtr = parametersGradient.bufferPtr(Params::FeaturesRawParametersGradientBufferIndex); + m_featureRawParameters.gradPtr = parametersGradient.bufferPtr(Params::FeaturesRawParametersGradientBufferIndex); } }; __forceinline__ __device__ TFeaturesVec featuresFromBuffer(uint32_t particleIdx, - const tcnn::vec3& incidentDirection) const { - - const auto features = particleFeaturesFromBuffer(particleIdx, - {{m_featureRawParameters.ptr, nullptr, true}, m_featureActiveShDegree}, - *reinterpret_cast(&incidentDirection)); - return *reinterpret_cast(&features); + const tcnn::vec3& incidentDirection, + const float3& canonicalPosition) const { + TFeaturesVec result; + particleFeaturesFromBuffer(particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + m_featureActiveShDegree, + *reinterpret_cast(&incidentDirection), + canonicalPosition, + reinterpret_cast*>(&result)); + return result; } template __forceinline__ __device__ TFeaturesVec featuresCustomFromBuffer(uint32_t particleIdx, const tcnn::vec3& incidentDirection) const { - const float3 gradu = threedgut::radianceFromSpH(m_featureActiveShDegree, - reinterpret_cast(&m_featureRawParameters.ptr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients]), - *reinterpret_cast(&incidentDirection), - Clamped); - return *reinterpret_cast(&gradu); - } - - template - __forceinline__ __device__ void featuresBwdToBuffer(uint32_t particleIdx, - const TFeaturesVec& featuresGrad, - const tcnn::vec3& incidentDirection) const { - particleFeaturesBwdToBuffer(particleIdx, - {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, - *reinterpret_cast(&featuresGrad), - *reinterpret_cast(&incidentDirection)); + if constexpr (ExtParams::PerRayParticleFeatures) { + return TFeaturesVec::zero(); + } else { + const float3 gradu = 
threedgut::radianceFromSpH(m_featureActiveShDegree, + reinterpret_cast(&m_featureRawParameters.ptr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients * 3]), + *reinterpret_cast(&incidentDirection), + Clamped); + return *reinterpret_cast(&gradu); + } } template @@ -278,31 +300,50 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams const TFeaturesVec& features, const TFeaturesVec& featuresGrad, const tcnn::vec3& incidentDirection) const { - threedgut::radianceFromSpHBwd(m_featureActiveShDegree, - *reinterpret_cast(&incidentDirection), - *reinterpret_cast(&featuresGrad), - reinterpret_cast(&m_featureRawParameters.gradPtr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients]), - *reinterpret_cast(&features)); + if constexpr (!ExtParams::PerRayParticleFeatures) { + threedgut::radianceFromSpHBwd(m_featureActiveShDegree, + *reinterpret_cast(&incidentDirection), + *reinterpret_cast(&featuresGrad), + reinterpret_cast(&m_featureRawParameters.gradPtr[particleIdx * ExtParams::RadianceMaxNumSphCoefficients * 3]), + *reinterpret_cast(&features)); + } + } + + template + __forceinline__ __device__ void featuresBwdToBuffer(uint32_t particleIdx, + const TFeaturesVec& featuresGrad, + const tcnn::vec3& incidentDirection, + tcnn::vec3& incidentDirectionGrad) const { + if constexpr (TDifferentiable) { + particleFeaturesBwdToBuffer(particleIdx, + reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)), + m_featureRawParameters.gradPtr, + m_featureActiveShDegree, + exclusiveGradient, + *reinterpret_cast*>(&featuresGrad), + *reinterpret_cast(&incidentDirection), + reinterpret_cast(&incidentDirectionGrad)); + } } __forceinline__ __device__ void featureIntegrateFwd(float weight, const TFeaturesVec& features, TFeaturesVec& integratedFeatures) const { - particleFeaturesIntegrateFwd(weight, - *reinterpret_cast(&features), - reinterpret_cast(&integratedFeatures)); + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures)); } 
__forceinline__ __device__ void featuresIntegrateFwdFromBuffer(const tcnn::vec3& incidentDirection, float weight, - uint32_t particleIdx, TFeaturesVec integratedFeatures) const { - + uint32_t particleIdx, TFeaturesVec& integratedFeatures) const { particleFeaturesIntegrateFwdFromBuffer(*reinterpret_cast(&incidentDirection), - weight, - particleIdx, - {{m_featureRawParameters.ptr, nullptr, true}, m_featureActiveShDegree}, - reinterpret_cast(&integratedFeatures)); + make_float3(0.f, 0.f, 0.f), + weight, + particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + m_featureActiveShDegree, + reinterpret_cast*>(&integratedFeatures)); } __forceinline__ __device__ void featuresIntegrateBwd(float alpha, @@ -314,15 +355,17 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams if (TDifferentiable) { particleFeaturesIntegrateBwd(alpha, &alphaGrad, - *reinterpret_cast(&features), - reinterpret_cast(&featuresGrad), - reinterpret_cast(&integratedFeatures), - reinterpret_cast(&integratedFeaturesGrad)); + *reinterpret_cast*>(&features), + reinterpret_cast*>(&featuresGrad), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); } } template __forceinline__ __device__ void featuresIntegrateBwdToBuffer(const tcnn::vec3& incidentDirection, + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, float alpha, float& alphaGrad, uint32_t particleIdx, @@ -332,13 +375,114 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams if (TDifferentiable) { particleFeaturesIntegrateBwdToBuffer(*reinterpret_cast(&incidentDirection), + canonicalIntersection, + &canonicalIntersectionGrad, alpha, &alphaGrad, particleIdx, - {{m_featureRawParameters.ptr, m_featureRawParameters.gradPtr, exclusiveGradient}, m_featureActiveShDegree}, - *reinterpret_cast(&features), - reinterpret_cast(&integratedFeatures), - reinterpret_cast(&integratedFeaturesGrad)); + 
reinterpret_cast(m_featureRawParameters.ptr), + m_featureRawParameters.gradPtr, + m_featureActiveShDegree, + exclusiveGradient, + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); + } + } + + // NHT warp reduction step 1: compute feature grad into a thread-private local buffer (no atomics). + // featureLocalGrad must be zero-initialized (size ExtParams::ParticleFeatureDim) before calling. + // Follow with featureLocalGradWarpReduceAndWrite (called by ALL warp threads) to write to global buffer. + __forceinline__ __device__ void featuresIntegrateBwdToLocalGrad(const tcnn::vec3& incidentDirection, + const float3& canonicalIntersection, + float3& canonicalIntersectionGrad, + float alpha, + float& alphaGrad, + uint32_t particleIdx, + const TFeaturesVec& features, + TFeaturesVec& integratedFeatures, + TFeaturesVec& integratedFeaturesGrad, + float* featureLocalGrad) const { +#if NHT_FEATURES_BWD_LOCAL_GRAD_CUDA + if constexpr (TDifferentiable) { + static_assert(ExtParams::FeatureTransformType == 1, + "NHT CUDA backward is only valid on the NHT path (FEATURE_TRANSFORM_TYPE==1)"); + static_assert(FEATURE_INTERPOLATION_TYPE == 0, + "only barycentric interpolation supported"); + static_assert(FEATURE_INTERPOLATION_SUPPORT == 1, + "only tetrahedral interpolation support supported"); + static_assert(4 * INTERP_POINT_FEATURE_DIM == ExtParams::ParticleFeatureDim, + "NHT buffer layout must be 4 vertices * InterpPointFeatureDim"); + static_assert(((FEATURE_ACTIVATION_TYPE == 0 || FEATURE_ACTIVATION_TYPE == 3) && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM) || + (FEATURE_ACTIVATION_TYPE == 1 && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES) || + (FEATURE_ACTIVATION_TYPE == 2 && + RAY_FEATURE_DIM == INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES * 2), + "RAY_FEATURE_DIM / INTERP_POINT_FEATURE_DIM / activation mismatch"); + + 
threedgut::nht::featuresIntegrateBwdToLocalGrad< + INTERP_POINT_FEATURE_DIM, + FEATURE_ACTIVATION_NUM_FREQUENCIES, + FEATURE_ACTIVATION_TYPE>( + canonicalIntersection, + canonicalIntersectionGrad, + alpha, + alphaGrad, + particleIdx, + reinterpret_cast(&features), + reinterpret_cast(&integratedFeatures), + reinterpret_cast(&integratedFeaturesGrad), + reinterpret_cast(m_featureRawParameters.ptr), + featureLocalGrad); + } +#else + if constexpr (TDifferentiable) { + // Pointer offset trick: shift featureLocalGrad back by particleOffset so Slang's + // _gradPtr[interpPointOffset + i] writes land in featureLocalGrad[0..ParticleFeatureDim-1]. + // interpPointOffset = particleOffset + interpPointIdx*InterpPointFeatureDim, so: + // (featureLocalGrad - particleOffset)[interpPointOffset + i] + // = featureLocalGrad[interpPointIdx*InterpPointFeatureDim + i] + const uint32_t particleOffset = particleIdx * ExtParams::ParticleFeatureDim; + particleFeaturesIntegrateBwdToBuffer( + *reinterpret_cast(&incidentDirection), + canonicalIntersection, + &canonicalIntersectionGrad, + alpha, + &alphaGrad, + particleIdx, + reinterpret_cast(m_featureRawParameters.ptr), + featureLocalGrad - particleOffset, // shifted: writes to featureLocalGrad[0..ParticleFeatureDim-1] + m_featureActiveShDegree, + true, // exclusiveGradient=true → += without atomics + *reinterpret_cast*>(&features), + reinterpret_cast*>(&integratedFeatures), + reinterpret_cast*>(&integratedFeaturesGrad)); + } +#endif // NHT_FEATURES_BWD_LOCAL_GRAD_CUDA + } + + // NHT warp reduction step 2: warp-reduce featureLocalGrad and atomicAdd to global gradient buffer. + // MUST be called by ALL threads in the warp (including non-hitting threads with featureLocalGrad=0) + // to satisfy __shfl_xor_sync requirements. 
+ __forceinline__ __device__ void featureLocalGradWarpReduceAndWrite(uint32_t particleIdx, + float* featureLocalGrad, + uint32_t tileThreadIdx) const { + if constexpr (TDifferentiable) { +#pragma unroll + for (int mask = 1; mask < warpSize; mask *= 2) { +#pragma unroll + for (int i = 0; i < ExtParams::ParticleFeatureDim; i++) { + featureLocalGrad[i] += __shfl_xor_sync(0xffffffff, featureLocalGrad[i], mask); + } + } + if ((tileThreadIdx & (warpSize - 1)) == 0) { + const uint32_t particleOffset = particleIdx * ExtParams::ParticleFeatureDim; +#pragma unroll + for (int i = 0; i < ExtParams::ParticleFeatureDim; i++) { + atomicAdd(&m_featureRawParameters.gradPtr[particleOffset + i], featureLocalGrad[i]); + } + } } } @@ -355,7 +499,8 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams reinterpret_cast(rayDirection), particleIdx, m_densityRawParameters.ptr, - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.ptr) : reinterpret_cast(particleFeaturesPtr), + PerRayRadiance ? reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)) + : reinterpret_cast(particleFeaturesPtr), ExtParams::MinParticleKernelDensity, ExtParams::AlphaThreshold, m_featureActiveShDegree, @@ -389,8 +534,9 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams particleIdx, reinterpret_cast(densityRawParameters), reinterpret_cast(densityRawParametersGrad), - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.ptr) : reinterpret_cast(particleFeatures.data()), - PerRayRadiance ? reinterpret_cast(m_featureRawParameters.gradPtr) : reinterpret_cast(particleFeaturesGradPtr), + PerRayRadiance ? reinterpret_cast(reinterpret_cast(m_featureRawParameters.ptr)) + : reinterpret_cast(particleFeatures.data()), + PerRayRadiance ? 
m_featureRawParameters.gradPtr : reinterpret_cast(particleFeaturesGradPtr), ExtParams::MinParticleKernelDensity, ExtParams::AlphaThreshold, ExtParams::MinTransmittanceThreshold, @@ -413,7 +559,7 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams #pragma unroll for (int mask = 1; mask < warpSize; mask *= 2) { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; ++i) { + for (int i = 0; i < ExtParams::RayFeatureDim; ++i) { featuresGrad[i] += __shfl_xor_sync(0xffffffff, featuresGrad[i], mask); } } @@ -421,13 +567,13 @@ struct ShRadiativeGaussianVolumetricFeaturesParticles : Params, public ExtParams // First thread in the warp performs the atomic add if ((tileThreadIdx & (warpSize - 1)) == 0) { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; i++) { + for (int i = 0; i < ExtParams::RayFeatureDim; i++) { atomicAdd(&featuresGradSum[particleIdx][i], featuresGrad[i]); } } } else { #pragma unroll - for (int i = 0; i < ExtParams::FeaturesDim; ++i) { + for (int i = 0; i < ExtParams::RayFeatureDim; ++i) { atomicAdd(&featuresGradSum[particleIdx][i], featuresGrad[i]); } } @@ -486,5 +632,5 @@ private: m_densityRawParameters; int m_featureActiveShDegree = 0; - ShRadiativeGaussianParticlesBuffer m_featureRawParameters; + ShRadiativeGaussianParticlesBuffer m_featureRawParameters; // gradPtr always fp32; ptr cast to TFeatureRawParamPtr for reads }; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh index 1390f8d6..c3d02345 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh @@ -18,16 +18,52 @@ #include <3dgut/kernels/cuda/common/rayPayloadBackward.cuh> #include <3dgut/renderer/gutRendererParameters.h> -struct HitParticle { +// HitParticle stores per-ray hit state inside the k-buffer. 
The +// `canonicalIntersection` is only consumed by the NHT feature path; for the +// SH path (`PerRayParticleFeatures=false`) we drop it from the struct so the +// k-buffer pays no storage cost in registers / shared memory per hit. +template +struct HitParticleT; + +// Base / SH case: common per-hit state only. +template <> +struct HitParticleT { static constexpr float InvalidHitT = -1.0f; int idx = -1; float hitT = InvalidHitT; float alpha = 0.0f; }; -template -struct HitParticleKBuffer { - __device__ HitParticleKBuffer() { +// NHT case extends with the per-ray canonical intersection point used by the +// feature evaluation. Plain (non-virtual) inheritance: layout is base fields +// then the extra `float3`, matching the pre-refactor combined struct exactly. +template <> +struct HitParticleT : HitParticleT { + float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); +}; + +// Helper returning a writable `float3&` slot usable as the `densityHit` +// canonical-intersection out-parameter. Routes to the struct field when +// present, otherwise to a caller-provided stack scratch (which the compiler +// elides when unused downstream). 
+template +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& hit, float3& scratch); + +template <> +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& hit, float3& /*scratch*/) { + return hit.canonicalIntersection; +} + +template <> +__forceinline__ __device__ float3& canonicalIntersectionSlot(HitParticleT& /*hit*/, float3& scratch) { + return scratch; +} + +template +struct HitParticleKBufferT { + using HitParticle = HitParticleT; + + __device__ HitParticleKBufferT() { m_numHits = 0; #pragma unroll for (int i = 0; i < K; ++i) { @@ -75,8 +111,9 @@ private: uint32_t m_numHits; }; -template <> -struct HitParticleKBuffer<0> { +template +struct HitParticleKBufferT<0, HasCanonical> { + using HitParticle = HitParticleT; constexpr inline __device__ void insert(HitParticle& hitParticle) const {} constexpr inline __device__ HitParticle operator[](int) const { return HitParticle(); } constexpr inline __device__ uint32_t numHits() const { return 0; } @@ -91,8 +128,13 @@ struct GUTKBufferRenderer : Params { using DensityRawParameters = typename Particles::DensityRawParameters; using TFeaturesVec = typename Particles::TFeaturesVec; - using TRayPayload = RayPayload; - using TRayPayloadBackward = RayPayloadBackward; + using TRayPayload = RayPayload; + using TRayPayloadBackward = RayPayloadBackward; + + // Storage-optimized hit types: the per-ray canonical intersection is kept + // only when the feature model needs it (NHT / PerRayParticleFeatures). 
+ using HitParticle = HitParticleT; + using HitParticleKBuffer = HitParticleKBufferT; struct PrefetchedParticleData { uint32_t idx; @@ -115,24 +157,27 @@ struct GUTKBufferRenderer : Params { if constexpr (Backward) { float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); if constexpr (Params::PerRayParticleFeatures) { particles.featuresIntegrateBwdToBuffer(ray.direction, + hitParticle.canonicalIntersection, + canonicalIntersectionGrad, hitParticle.alpha, hitAlphaGrad, hitParticle.idx, - particles.featuresFromBuffer(hitParticle.idx, ray.direction), + particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection), ray.featuresBackward, ray.featuresGradient); } else { TFeaturesVec particleFeaturesGradientVec = TFeaturesVec::zero(); particles.featuresIntegrateBwd(hitParticle.alpha, hitAlphaGrad, - particleFeatures[hitParticle.idx], + tcnn::max(particleFeatures[hitParticle.idx], 0.f), particleFeaturesGradientVec, ray.featuresBackward, ray.featuresGradient); #pragma unroll - for (int i = 0; i < Particles::FeaturesDim; ++i) { + for (int i = 0; i < Particles::RayFeatureDim; ++i) { atomicAdd(&(particleFeaturesGradient[hitParticle.idx][i]), particleFeaturesGradientVec[i]); } } @@ -146,7 +191,8 @@ struct GUTKBufferRenderer : Params { ray.transmittanceGradient, hitParticle.hitT, ray.hitTBackward, - ray.hitTGradient); + ray.hitTGradient, + canonicalIntersectionGrad); ray.transmittance *= (1.0 - hitParticle.alpha); @@ -157,12 +203,20 @@ struct GUTKBufferRenderer : Params { hitParticle.hitT, ray.hitT); - particles.featureIntegrateFwd(hitWeight, - Params::PerRayParticleFeatures ? particles.featuresFromBuffer(hitParticle.idx, ray.direction) : tcnn::max(particleFeatures[hitParticle.idx], 0.f), - ray.features); + // `if constexpr` branches so the SH specialization of Hit (which + // has no `canonicalIntersection` member) is not instantiated with + // a missing field reference. 
+ if constexpr (Params::PerRayParticleFeatures) { + particles.featureIntegrateFwd(hitWeight, + particles.featuresFromBuffer(hitParticle.idx, ray.direction, hitParticle.canonicalIntersection), + ray.features); + } else { + particles.featureIntegrateFwd(hitWeight, + tcnn::max(particleFeatures[hitParticle.idx], 0.f), + ray.features); + } - if (hitWeight > 0.0f) - ray.countHit(); + if (hitWeight > 0.0f) ray.countHit(); } if (ray.transmittance < Particles::MinTransmittanceThreshold) { @@ -228,7 +282,7 @@ struct GUTKBufferRenderer : Params { using namespace threedgut; __shared__ PrefetchedParticleData prefetchedParticlesData[GUTParameters::Tiling::BlockSize]; - HitParticleKBuffer hitParticleKBuffer; + HitParticleKBuffer hitParticleKBuffer; for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { @@ -261,11 +315,15 @@ struct GUTKBufferRenderer : Params { HitParticle hitParticle; hitParticle.idx = particleData.idx; + // `canonicalScratch` is only written when the SH specialization + // of Hit has no canonicalIntersection field; compiler elides it. 
+ float3 canonicalScratch = make_float3(0.f, 0.f, 0.f); if (particles.densityHit(ray.origin, ray.direction, particleData.densityParameters, hitParticle.alpha, - hitParticle.hitT) && + hitParticle.hitT, + canonicalIntersectionSlot(hitParticle, canonicalScratch)) && (hitParticle.hitT > ray.tMinMax.x) && (hitParticle.hitT < ray.tMinMax.y)) { @@ -349,24 +407,26 @@ struct GUTKBufferRenderer : Params { if (!ray.isAlive()) break; - float hitAlpha = 0.0f; - float hitT = 0.0f; + float hitAlpha = 0.0f; + float3 hitCanonicalIntersection = make_float3(0.f, 0.f, 0.f); + float hitT = 0.0f; TFeaturesVec hitFeatures = TFeaturesVec::zero(); - bool validHit = false; + bool validHit = false; // Step 1: Each thread tests one Gaussian intersection if (j < tileNumParticlesToProcess) { const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + j; - const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; if (particleIdx != GUTParameters::InvalidParticleIdx) { auto densityParams = particles.fetchDensityParameters(particleIdx); if (particles.densityHit(ray.origin, - ray.direction, - densityParams, - hitAlpha, - hitT) && + ray.direction, + densityParams, + hitAlpha, + hitT, + hitCanonicalIntersection) && (hitT > ray.tMinMax.x) && (hitT < ray.tMinMax.y)) { @@ -374,7 +434,7 @@ struct GUTKBufferRenderer : Params { // Get Gaussian features if constexpr (Params::PerRayParticleFeatures) { - hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction); + hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction, hitCanonicalIntersection); } else { hitFeatures = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); } @@ -430,15 +490,15 @@ struct GUTKBufferRenderer : Params { float hitWeight = hitAlpha * particleTransmittance; // Compute weighted contributions - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < 
Particles::RayFeatureDim; ++featIdx) { accumulatedFeatures[featIdx] = hitFeatures[featIdx] * hitWeight; } - accumulatedHitT = hitT * hitWeight; + accumulatedHitT = hitT * hitWeight; accumulatedHitCount = (hitWeight > 0.0f) ? 1 : 0; } // Step 6: Warp-level reduction (tree-based sum) - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < Particles::RayFeatureDim; ++featIdx) { for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) { accumulatedFeatures[featIdx] += __shfl_down_sync(WarpMask, accumulatedFeatures[featIdx], offset); } @@ -451,7 +511,7 @@ struct GUTKBufferRenderer : Params { // Step 7: Only lane 0 updates ray state (avoid race conditions) if (laneId == 0) { - for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) { + for (int featIdx = 0; featIdx < Particles::RayFeatureDim; ++featIdx) { ray.features[featIdx] += accumulatedFeatures[featIdx]; } ray.hitT += accumulatedHitT; @@ -481,83 +541,177 @@ struct GUTKBufferRenderer : Params { static_assert(Backward && (Params::KHitBufferSize == 0), "Optimized path for backward pass with no KBuffer"); using namespace threedgut; - __shared__ PrefetchedRawParticleData prefetchedRawParticlesData[GUTParameters::Tiling::BlockSize]; - for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { + if constexpr (Params::PerRayParticleFeatures) { + // NHT path: features are re-evaluated per-ray from buffer (no pre-computed feature cache). + // Gradients accumulate into a thread-private local buffer, then warp-reduced before a single + // atomicAdd per feature dim — reduces global memory traffic by up to 32× vs. per-hit atomics. 
+ __shared__ PrefetchedParticleData prefetchedParticlesData[GUTParameters::Tiling::BlockSize]; - if (__syncthreads_and(!ray.isAlive())) { - break; - } + for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { - // Collectively fetch particle data - const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; - if (toProcessSortedIndex < tileParticleRangeIndices.y) { - const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; - if (particleIdx != GUTParameters::InvalidParticleIdx) { - prefetchedRawParticlesData[tileThreadIdx].densityParameters = particles.fetchDensityRawParameters(particleIdx); - if constexpr (Params::PerRayParticleFeatures) { - prefetchedRawParticlesData[tileThreadIdx].features = TFeaturesVec::zero(); + if (__syncthreads_and(!ray.isAlive())) { + break; + } + + // Collectively fetch density parameters only (features fetched per-ray below) + const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; + if (toProcessSortedIndex < tileParticleRangeIndices.y) { + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + if (particleIdx != GUTParameters::InvalidParticleIdx) { + prefetchedParticlesData[tileThreadIdx] = {particleIdx, particles.fetchDensityParameters(particleIdx)}; } else { - prefetchedRawParticlesData[tileThreadIdx].features = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); + prefetchedParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } - prefetchedRawParticlesData[tileThreadIdx].idx = particleIdx; } else { - prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + prefetchedParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + } + __syncthreads(); + + // Process fetched particles + for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, 
tileNumParticlesToProcess); j++) { + + if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + break; + } + + const PrefetchedParticleData particleData = prefetchedParticlesData[j]; + if (particleData.idx == GUTParameters::InvalidParticleIdx) { + ray.kill(); + break; + } + + // Thread-private feature gradient buffer for warp reduction. + // Zero-initialized: non-hitting threads contribute 0 to the warp sum. + float featureLocalGrad[Particles::ParticleFeatureDim] = {}; + + if (ray.isAlive()) { + float hitAlpha = 0.f; + float hitT = 0.f; + float3 canonicalIntersection = make_float3(0.f, 0.f, 0.f); + + if (particles.densityHit(ray.origin, ray.direction, particleData.densityParameters, + hitAlpha, hitT, canonicalIntersection) && + (hitT > ray.tMinMax.x) && + (hitT < ray.tMinMax.y)) { + // Re-evaluate NHT features at canonical intersection point (cheap: barycentric interp) + const TFeaturesVec hitFeatures = particles.featuresFromBuffer(particleData.idx, ray.direction, canonicalIntersection); + + float hitAlphaGrad = 0.f; + float3 canonicalIntersectionGrad = make_float3(0.f, 0.f, 0.f); + + // Write feature grad to thread-private local buffer (no atomics); warp reduction follows below. + particles.featuresIntegrateBwdToLocalGrad(ray.direction, + canonicalIntersection, + canonicalIntersectionGrad, + hitAlpha, + hitAlphaGrad, + particleData.idx, + hitFeatures, + ray.featuresBackward, + ray.featuresGradient, + featureLocalGrad); + + particles.template densityProcessHitBwdToBuffer(ray.origin, + ray.direction, + particleData.idx, + hitAlpha, + hitAlphaGrad, + ray.transmittanceBackward, + ray.transmittanceGradient, + hitT, + ray.hitTBackward, + ray.hitTGradient, + canonicalIntersectionGrad); + + ray.transmittance *= (1.0f - hitAlpha); + } + + if (ray.transmittance < Particles::MinTransmittanceThreshold) { + ray.kill(); + } + } + + // Warp reduction: all 32 threads (alive or not) participate in __shfl_xor_sync. 
+ // Non-hitting threads contribute featureLocalGrad=0 → no spurious gradient. + // Reduces 32×ParticleFeatureDim atomics per particle to ParticleFeatureDim atomics. + particles.featureLocalGradWarpReduceAndWrite(particleData.idx, featureLocalGrad, tileThreadIdx); } - } else { - prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } - __syncthreads(); + } else { + // SH path: pre-computed features fetched from shared memory, gradients via precomputed buffer. + __shared__ PrefetchedRawParticleData prefetchedRawParticlesData[GUTParameters::Tiling::BlockSize]; - // Process fetched particles - for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, tileNumParticlesToProcess); j++) { + for (uint32_t i = 0; i < tileNumBlocksToProcess; i++, tileNumParticlesToProcess -= GUTParameters::Tiling::BlockSize) { - if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + if (__syncthreads_and(!ray.isAlive())) { break; } - const PrefetchedRawParticleData particleData = prefetchedRawParticlesData[j]; - if (particleData.idx == GUTParameters::InvalidParticleIdx) { - ray.kill(); - break; + // Collectively fetch particle data + const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + i * GUTParameters::Tiling::BlockSize + tileThreadIdx; + if (toProcessSortedIndex < tileParticleRangeIndices.y) { + const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex]; + if (particleIdx != GUTParameters::InvalidParticleIdx) { + prefetchedRawParticlesData[tileThreadIdx].densityParameters = particles.fetchDensityRawParameters(particleIdx); + prefetchedRawParticlesData[tileThreadIdx].features = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f); + prefetchedRawParticlesData[tileThreadIdx].idx = particleIdx; + } else { + prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; + } + } else { + prefetchedRawParticlesData[tileThreadIdx].idx = GUTParameters::InvalidParticleIdx; } + __syncthreads(); - 
DensityRawParameters densityRawParametersGrad; - densityRawParametersGrad.density = 0.0f; - densityRawParametersGrad.position = make_float3(0.0f); - densityRawParametersGrad.quaternion = make_float4(0.0f); - densityRawParametersGrad.scale = make_float3(0.0f); - - TFeaturesVec featuresGrad = TFeaturesVec::zero(); - - if (ray.isAlive()) { - particles.processHitBwd( - ray.origin, - ray.direction, - particleData.idx, - particleData.densityParameters, - &densityRawParametersGrad, - particleData.features, - &featuresGrad, - ray.transmittance, - ray.transmittanceBackward, - ray.transmittanceGradient, - ray.features, - ray.featuresBackward, - ray.featuresGradient, - ray.hitT, - ray.hitTBackward, - ray.hitTGradient); - if (ray.transmittance < Particles::MinTransmittanceThreshold) { + // Process fetched particles + for (int j = 0; j < min(GUTParameters::Tiling::BlockSize, tileNumParticlesToProcess); j++) { + + if (__all_sync(GUTParameters::Tiling::WarpMask, !ray.isAlive())) { + break; + } + + const PrefetchedRawParticleData particleData = prefetchedRawParticlesData[j]; + if (particleData.idx == GUTParameters::InvalidParticleIdx) { ray.kill(); + break; + } + + DensityRawParameters densityRawParametersGrad; + densityRawParametersGrad.density = 0.0f; + densityRawParametersGrad.position = make_float3(0.0f); + densityRawParametersGrad.quaternion = make_float4(0.0f); + densityRawParametersGrad.scale = make_float3(0.0f); + + TFeaturesVec featuresGrad = TFeaturesVec::zero(); + + if (ray.isAlive()) { + particles.processHitBwd( + ray.origin, + ray.direction, + particleData.idx, + particleData.densityParameters, + &densityRawParametersGrad, + particleData.features, + &featuresGrad, + ray.transmittance, + ray.transmittanceBackward, + ray.transmittanceGradient, + ray.features, + ray.featuresBackward, + ray.featuresGradient, + ray.hitT, + ray.hitTBackward, + ray.hitTGradient); + if (ray.transmittance < Particles::MinTransmittanceThreshold) { + ray.kill(); + } } - } - if constexpr 
(!Params::PerRayParticleFeatures) { particles.processHitBwdUpdateFeaturesGradient(particleData.idx, featuresGrad, particleFeaturesGradientBuffer, tileThreadIdx); + particles.processHitBwdUpdateDensityGradient(particleData.idx, densityRawParametersGrad, tileThreadIdx); } - particles.processHitBwdUpdateDensityGradient(particleData.idx, densityRawParametersGrad, tileThreadIdx); } } } diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh index 6319b96f..f335ef8c 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutProjector.cuh @@ -402,30 +402,30 @@ struct GUTProjector : Params, UTParams { const float* __restrict__ particlesPrecomputedFeaturesPtr, const float* __restrict__ particlesPrecomputedFeaturesGradPtr, threedgut::MemoryHandles parametersGradient) { - if constexpr (Params::PerRayParticleFeatures) { - return; - } + if constexpr (!Params::PerRayParticleFeatures) { - const uint32_t particleIdx = blockIdx.x * blockDim.x + threadIdx.x; - if (particleIdx >= numParticles) { - return; - } - if (particlesTilesCountPtr[particleIdx] == 0) { - return; - } + const uint32_t particleIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (particleIdx >= numParticles) { + return; + } + if (particlesTilesCountPtr[particleIdx] == 0) { + return; + } - Particles particles; - particles.initializeDensity(parameters); - const tcnn::vec3 incidentDirection = tcnn::normalize(particles.fetchPosition(particleIdx) - sensorWorldPosition); - - particles.initializeFeatures(parameters); - particles.initializeFeaturesGradient(parametersGradient); - particles.featuresBwdCustomToBuffer( - particleIdx, - reinterpret_cast(particlesPrecomputedFeaturesPtr)[particleIdx], - reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], - incidentDirection); - 
particles.initializeDensityGradient(parametersGradient); - particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition); + Particles particles; + particles.initializeDensity(parameters); + particles.initializeDensityGradient(parametersGradient); + const tcnn::vec3 incidentDirection = tcnn::normalize(particles.fetchPosition(particleIdx) - sensorWorldPosition); + tcnn::vec3 incidentDirectionGrad = tcnn::vec3(0.0f); + + particles.initializeFeatures(parameters); + particles.initializeFeaturesGradient(parametersGradient); + particles.featuresBwdToBuffer(particleIdx, + reinterpret_cast(particlesPrecomputedFeaturesGradPtr)[particleIdx], + incidentDirection, + incidentDirectionGrad); + + particles.template densityIncidentDirectionBwdToBuffer(particleIdx, sensorWorldPosition, incidentDirectionGrad); + } } }; diff --git a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh index 8a451c4e..364a94f6 100644 --- a/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh +++ b/threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh @@ -88,7 +88,7 @@ __global__ void render(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, float* __restrict__ worldHitCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr, const float* __restrict__ particlesGlobalDepthPtr, @@ -111,7 +111,7 @@ __global__ void render(threedgut::RenderParameters params, // TGUTModel::eval(params, ray, {parameterMemoryHandles}); // NB : finalize ray is not differentiable (has to be no-op when used in a differentiable renderer) - finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, worldHitDistancePtr, radianceDensityPtr, 
sensorToWorldTransform); + finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, worldHitDistancePtr, featureDensityPtr, sensorToWorldTransform); } #if FINE_GRAINED_LOAD_BALANCING @@ -124,7 +124,7 @@ __global__ void renderBalanced(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, float* __restrict__ worldHitCountPtr, float* __restrict__ worldHitDistancePtr, - tcnn::vec4* __restrict__ radianceDensityPtr, + TFeatureDensityElem* __restrict__ featureDensityPtr, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr, const float* __restrict__ particlesGlobalDepthPtr, @@ -210,13 +210,24 @@ __global__ void renderBalanced(threedgut::RenderParameters params, // Only lane 0 should write, as only it has accumulated the correct values if (laneId == 0) { finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, - worldHitDistancePtr, radianceDensityPtr, sensorToWorldTransform); + worldHitDistancePtr, featureDensityPtr, sensorToWorldTransform); } } } #endif // FINE_GRAINED_LOAD_BALANCING -__global__ void renderBackward(threedgut::RenderParameters params, +// Optional register-cap on the backward kernel (plan T3). +// Set at build time with both args, e.g. +// -DNHT_BWD_LB_THREADS=256 -DNHT_BWD_LB_MIN_BLOCKS=2 +// to force <=128 regs/thread at BlockSize=256. Default leaves the compiler +// free, matching baseline behavior (regs/thread reported by ptxas). 
+#if defined(NHT_BWD_LB_THREADS) && defined(NHT_BWD_LB_MIN_BLOCKS) +#define NHT_BWD_LB __launch_bounds__(NHT_BWD_LB_THREADS, NHT_BWD_LB_MIN_BLOCKS) +#else +#define NHT_BWD_LB +#endif + +__global__ NHT_BWD_LB void renderBackward(threedgut::RenderParameters params, const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr, const uint32_t* __restrict__ sortedTileDataPtr, const tcnn::vec3* __restrict__ sensorRayOriginPtr, @@ -224,8 +235,8 @@ __global__ void renderBackward(threedgut::RenderParameters params, tcnn::mat4x3 sensorToWorldTransform, const float* __restrict__ worldHitDistancePtr, const float* __restrict__ worldHitDistanceGradientPtr, - const tcnn::vec4* __restrict__ radianceDensityPtr, - const tcnn::vec4* __restrict__ radianceDensityGradientPtr, + const TFeatureDensityElem* __restrict__ featureDensityPtr, + const float* __restrict__ featureDensityGradientPtr, tcnn::vec3* __restrict__ /*worldRayOriginGradientPtr*/, tcnn::vec3* __restrict__ /*worldRayDirectionGradientPtr*/, const tcnn::vec2* __restrict__ particlesProjectedPositionPtr, @@ -244,8 +255,8 @@ __global__ void renderBackward(threedgut::RenderParameters params, sensorRayDirectionPtr, worldHitDistancePtr, worldHitDistanceGradientPtr, - radianceDensityPtr, - radianceDensityGradientPtr, + featureDensityPtr, + featureDensityGradientPtr, sensorToWorldTransform); // TGUTModel::evalBackward(params, ray, {parameterMemoryHandles}, {parameterGradientMemoryHandles}); diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang index af44b9cf..687acf8b 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/gaussianParticles.slang @@ -178,6 +178,17 @@ struct Parameters : IDifferentiable { return sqrt(dot(grds, grds)); } +[BackwardDifferentiable][ForceInline] float canonicalRayIntersection( + float3 canonicalRayOrigin, + 
float3 canonicalRayDirection, + float3 scale, + out float3 canonicalIntersection) { + const float3 canonicalGrds = canonicalRayDirection * dot(canonicalRayDirection, -1 * canonicalRayOrigin); + canonicalIntersection = canonicalRayOrigin + canonicalGrds; + const float3 grds = scale * canonicalGrds; + return sqrt(dot(grds, grds)); +} + [BackwardDifferentiable][ForceInline] float3 canonicalRayNormal( float3 canonicalRayOrigin, float3 canonicalRayDirection, @@ -200,6 +211,7 @@ bool hit( Parameters parameters, out float alpha, inout float depth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 normal) { @@ -220,7 +232,7 @@ bool hit( const bool acceptHit = ((maxResponse > MinParticleKernelDensity) && (alpha > MinParticleAlpha)); if (acceptHit) { - depth = canonicalRayDistance(canonicalRayOrigin, canonicalRayDirection, parameters.scale); + depth = canonicalRayIntersection(canonicalRayOrigin, canonicalRayDirection, parameters.scale, canonicalIntersection); if (enableNormal) { normal = canonicalRayNormal(canonicalRayOrigin, canonicalRayDirection, parameters.scale, parameters.rotationT); @@ -269,6 +281,7 @@ float processHitFromBuffer( no_diff RawParametersBuffer parametersBuffer, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, no_diff bool enableNormal, inout float3 integratedNormal) { @@ -280,6 +293,7 @@ float processHitFromBuffer( fetchParameters(particleIdx, parametersBuffer), alpha, depth, + canonicalIntersection, enableNormal, normal)) { @@ -313,7 +327,7 @@ float3 incidentDirectionFromParameters( } [BackwardDifferentiable][ForceInline] -no_diff float3 incidentDirectionFromBuffer( +float3 incidentDirectionFromBuffer( no_diff uint32_t particleIdx, no_diff RawParametersBuffer parametersBuffer, no_diff float3 sourcePosition @@ -346,6 +360,7 @@ inline bool particleDensityHit( gaussianParticle.Parameters parameters, out float alpha, out float depth, + out float3 canonicalIntersection, bool enableNormal, 
out float3 normal) { @@ -354,6 +369,7 @@ inline bool particleDensityHit( parameters, alpha, depth, + canonicalIntersection, enableNormal, normal); } @@ -385,6 +401,7 @@ inline float particleDensityProcessHitFwdFromBuffer( gaussianParticle.CommonParameters commonParameters, inout float transmittance, inout float integratedDepth, + out float3 canonicalIntersection, in bool enableNormal, inout float3 integratedNormal) { @@ -395,6 +412,7 @@ inline float particleDensityProcessHitFwdFromBuffer( commonParameters.parametersBuffer, transmittance, integratedDepth, + canonicalIntersection, enableNormal, integratedNormal); } @@ -412,6 +430,7 @@ void particleDensityProcessHitBwdToBuffer( in float depth, inout float integratedDepth, inout float integratedDepthGrad, + in float3 canonicalIntersectionGrad, bool enableNormal, in float3 normal, inout float3 integratedNormal, @@ -445,6 +464,7 @@ void particleDensityProcessHitBwdToBuffer( commonParameters.parametersBuffer, transmittanceDiff, integratedDepthDiff, + canonicalIntersectionGrad, enableNormal, integratedNormalDiff, alphaGrad); @@ -525,12 +545,14 @@ bool particleDensityHitInstance( [CudaDeviceExport] void particleDensityIncidentDirectionBwdToBuffer( in uint32_t particleIdx, gaussianParticle.CommonParameters commonParameters, - in float3 sourcePosition + in float3 sourcePosition, + in float3 incidentDirectionGrad ) { bwd_diff(gaussianParticle.incidentDirectionFromBuffer)( particleIdx, commonParameters.parametersBuffer, - sourcePosition + sourcePosition, + incidentDirectionGrad ); } diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang new file mode 100644 index 00000000..af653a28 --- /dev/null +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Neural Harmonic Features: per-particle K-dim features, output N (decoder input); decoder maps N -> RGB. +// Same CudaDeviceExport API as shRadiativeParticles.slang for drop-in replacement. +// Compile-time: FEATURE_INTERPOLATION_TYPE, FEATURE_INTERPOLATION_SUPPORT, FEATURE_ACTIVATION_TYPE, FEATURE_ACTIVATION_NUM_FREQUENCIES, +// PARTICLE_FEATURE_DIM (total K per particle = buffer stride), INTERP_POINT_FEATURE_DIM (per-interpolation-point = K/num_points). +// Support: center -> K=interpPointDim; tetrahedra -> K=4*interpPointDim. +// With sincos activation: N = interpPointDim * num_frequencies * 2 (separate sin/cos channels). 
+ +namespace neuralHarmonicFeaturesParticle +{ +static const int ParticleFeatureDim = PARTICLE_FEATURE_DIM; // Total per-particle (K); not per-interpolation-point +static const int RayFeatureDim = RAY_FEATURE_DIM; // Decoder input N; expected to be INTERP_POINT_FEATURE_DIM (none/relu), INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES (siren), or INTERP_POINT_FEATURE_DIM * FEATURE_ACTIVATION_NUM_FREQUENCIES * 2 (sincos) +static const int InterpPointFeatureDim = INTERP_POINT_FEATURE_DIM; // Per-interpolation-point dimension +static const int FeatureActivationType_None = 0; +static const int FeatureActivationType_Siren = 1; +static const int FeatureActivationType_Sincos = 2; +static const int FeatureActivationType_Relu = 3; +static const int FeatureActivationType = FEATURE_ACTIVATION_TYPE; +static const int FeatureActivationNumFrequencies = FEATURE_ACTIVATION_NUM_FREQUENCIES; + +// Interpolation type (compile-time from FEATURE_INTERPOLATION_TYPE) +static const int InterpolationType_Barycentric = 0; +static const int InterpolationType_Bezier = 1; // Not supported yet +static const int InterpolationType = FEATURE_INTERPOLATION_TYPE; + +// Interpolation support (compile-time from FEATURE_INTERPOLATION_SUPPORT) +static const int InterpolationSupport_Center = 0; +static const int InterpolationSupport_Tetrahedra = 1; +static const int InterpolationSupport_CoTriangles = 2; // 2 coplanar triangles / Not supported yet +static const int InterpolationSupport = FEATURE_INTERPOLATION_SUPPORT; + +// Canonical regular tetrahedron matching the GSplat NHT reference layout: +// p0=(sqrt(6),-sqrt(2),-1), p1=(-sqrt(6),-sqrt(2),-1), p2=(0,2*sqrt(2),-1), p3=(0,0,3). +// Incenter at origin; base z=-1, apex z=3; edge s=sqrt(24), height h=4, inradius r=1. 
+static const float tetraHedraEdge = 4.898979485566356f; // sqrt(24) +static const float tetraHedraFaceHeight = 4.242640687119285f; // s*sqrt(3)/2 +static const float tetraHedraHeight = 4.0f; // s*sqrt(2/3) +static const float tetraHedraFaceInRadius = 1.4142135623730951f; // s*sqrt(3)/6 = sqrt(2) +static const float tetraHedraInRadius = 1.0f; // unit sphere +static const float3 canonicalTetraVerts[4] = { + float3(0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), + float3(-0.5f * tetraHedraEdge, -tetraHedraFaceInRadius, -1.0f), + float3(0.0f, tetraHedraFaceHeight - tetraHedraFaceInRadius, -1.0f), + float3(0.0f, 0.0f, tetraHedraHeight - tetraHedraInRadius) +}; +// Cramer terms for canonical tetrahedron (independent of P); used by barycentricTetrahedronCanonical. +static const float3 canonicalTetraE1 = canonicalTetraVerts[1] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE2 = canonicalTetraVerts[2] - canonicalTetraVerts[0]; +static const float3 canonicalTetraE3 = canonicalTetraVerts[3] - canonicalTetraVerts[0]; +static const float3 canonicalTetraCrossE2E3 = cross(canonicalTetraE2, canonicalTetraE3); +static const float canonicalTetraDet = dot(canonicalTetraE1, canonicalTetraCrossE2E3); +static const float canonicalTetraInvDet = 1.0f / canonicalTetraDet; + +}; + +#if PARTICLE_FEATURE_HALF +typedef half feat_elem_t; +#else +typedef float feat_elem_t; +#endif + +namespace neuralHarmonicFeaturesParticle +{ + +struct ParametersBuffer +{ + Ptr _dataPtr; // [N_particles, K] flat; fp16 when PARTICLE_FEATURE_HALF + float *_gradPtr; + bool exclusiveGradient; +}; + +struct Parameters : IDifferentiable +{ + Array features; +}; + +[BackwardDifferentiable][ForceInline] +Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer) +{ + Parameters parameters; + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + 
interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + parameters.features[i] = parametersBuffer._dataPtr[interpPointOffset + i]; + } + return parameters; +} + +[BackwardDerivativeOf(fetchParametersFromBuffer)][ForceInline] +void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, + no_diff int interpPointIdx, + no_diff ParametersBuffer parametersBuffer, + Parameters parametersGrad) +{ + const uint32_t particleOffset = particleIdx * ParticleFeatureDim; + const uint32_t interpPointOffset = particleOffset + interpPointIdx * InterpPointFeatureDim; + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + const float grad = parametersGrad.features[i]; + if (parametersBuffer.exclusiveGradient) { + parametersBuffer._gradPtr[interpPointOffset + i] += grad; + } else { + InterlockedAdd(parametersBuffer._gradPtr[interpPointOffset + i], grad); + } + } +} + +// Barycentric coordinates for the canonical tetrahedron only. Uses precomputed static const e1,e2,e3,invDet (independent of P). +[BackwardDifferentiable][ForceInline] +float4 barycentricTetrahedronCanonical(float3 P) +{ + float3 d = P - canonicalTetraVerts[0]; + float4 weights; + weights.y = dot(d, canonicalTetraCrossE2E3) * canonicalTetraInvDet; + weights.z = dot(canonicalTetraE1, cross(d, canonicalTetraE3)) * canonicalTetraInvDet; + weights.w = dot(canonicalTetraE1, cross(canonicalTetraE2, d)) * canonicalTetraInvDet; + weights.x = 1.0f - weights.y - weights.z - weights.w; + return weights; +} + +// Encode and activate : none -> identity; siren -> sin(b*2^f); relu -> max(0,b). 
+[BackwardDifferentiable][ForceInline] +float encodeAndActivate(float baseVal, no_diff int f) +{ + if (FeatureActivationType == FeatureActivationType_None) + return baseVal; + if (FeatureActivationType == FeatureActivationType_Relu) + return max(0.0f, baseVal); + float freq = ldexp(1.0f, f); + float angle = baseVal * freq; + return sin(angle); +} + +// Compute blended features into baseFeatures[INTERP_POINT_FEATURE_DIM], then optionally expand by activation to features[RayFeatureDim]. +// canonicalPosition is differential (hit position in particle canonical space; gradient accumulated in API canonicalPositionGrad). +[BackwardDifferentiable][ForceInline] +void featuresFromParametersBuffer(ParametersBuffer parametersBuffer, + no_diff uint32_t particleIdx, + float3 canonicalPosition, + out Array features +) +{ + Array baseFeatures = fetchParametersFromBuffer(particleIdx, 0, parametersBuffer).features; + if (InterpolationSupport == InterpolationSupport_Tetrahedra && InterpolationType == InterpolationType_Barycentric) { + float4 barycentricWeights = barycentricTetrahedronCanonical(canonicalPosition); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] *= barycentricWeights[0]; + } + [ForceUnroll] for (int k = 1; k < 4; ++k) { + Parameters parameters = fetchParametersFromBuffer(particleIdx, k, parametersBuffer); + [ForceUnroll] for (int n = 0; n < InterpPointFeatureDim; ++n) { + baseFeatures[n] += barycentricWeights[k] * parameters.features[n]; + } + } + } + + if (FeatureActivationType == FeatureActivationType_None) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = baseFeatures[i]; + } + } else if (FeatureActivationType == FeatureActivationType_Relu) { + [ForceUnroll] for (int i = 0; i < InterpPointFeatureDim; ++i) { + features[i] = max(0.0f, baseFeatures[i]); + } + } else if (FeatureActivationType == FeatureActivationType_Sincos) { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + 
[ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + float freq = float(f + 1); + float angle = baseFeatures[k] * freq; + int outIdx = k * FeatureActivationNumFrequencies * 2 + f * 2; + features[outIdx + 0] = sin(angle); + features[outIdx + 1] = cos(angle); + } + } + } else { + [ForceUnroll] for (int k = 0; k < InterpPointFeatureDim; ++k) { + [ForceUnroll] for (int f = 0; f < FeatureActivationNumFrequencies; ++f) { + features[k * FeatureActivationNumFrequencies + f] = encodeAndActivate(baseFeatures[k], f); + } + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeatures(float weight, + in Array features, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + [ForceUnroll] for (int i = 0; i < RayFeatureDim; ++i) { + if (backToFront) + integratedFeatures[i] = lerp(integratedFeatures[i], features[i], weight); + else + integratedFeatures[i] += features[i] * weight; + } + } +} + +[BackwardDifferentiable][ForceInline] +void integrateFeaturesFromBuffer(float weight, + no_diff uint32_t particleIdx, + ParametersBuffer parametersBuffer, + float3 canonicalPosition, + inout float integratedFeatures[RayFeatureDim]) +{ + if (weight > 0.0f) { + Array features; + featuresFromParametersBuffer(parametersBuffer, particleIdx, canonicalPosition, features); + integrateFeatures(weight, features, integratedFeatures); + } +} + +} // namespace neuralHarmonicFeaturesParticle + +// ------------------------------------------------------------------------------------------------------------------ +// Entry points - same CudaDeviceExport API as shRadiativeParticles.slang + +[CudaDeviceExport] +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + 
parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.featuresFromParametersBuffer( + parametersBuffer, + particleIdx, + canonicalPosition, + features + ); +} + +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwd(in float weight, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.integrateFeatures( + weight, + (Array)features, + (Array)integratedFeatures + ); +} + +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in int auxParam, + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer( + weight, particleIdx, parametersBuffer, canonicalPosition, integratedFeatures); +} + +// canonicalPosition is differential; canonicalPositionGrad is inout and accumulated by this backward. 
+[CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, + in float alpha, + inout float alphaGrad, + in uint32_t particleIdx, + in feat_elem_t *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + if (alpha > 0.0f) + { + neuralHarmonicFeaturesParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr)featuresBufferPtr; + parametersBuffer._gradPtr = (Ptr)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); + + const float weight = 1.0f / (1.0f - alpha); + [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + DifferentialPair integratedFeaturesDiff = + DifferentialPair(integratedFeatures, integratedFeaturesGrad); + + DifferentialPair canonicalPositionDiff = DifferentialPair(canonicalPosition, canonicalPositionGrad); + + bwd_diff(neuralHarmonicFeaturesParticle.integrateFeaturesFromBuffer)( + alphaDiff, + particleIdx, + parametersBuffer, + canonicalPositionDiff, + integratedFeaturesDiff); + + integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + canonicalPositionGrad += canonicalPositionDiff.getDifferential(); + alphaGrad += alphaDiff.getDifferential(); + } +} + +[CudaDeviceExport] void particleFeaturesIntegrateBwd( + in float alpha, + inout float alphaGrad, + in float features[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float featuresGrad[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float 
integratedFeatures[neuralHarmonicFeaturesParticle.RayFeatureDim], + inout float integratedFeaturesGrad[neuralHarmonicFeaturesParticle.RayFeatureDim]) +{ + if (alpha > 0.0f) + { + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); + DifferentialPair featuresDiff = + DifferentialPair(features, featuresGrad); + + const float weight = 1.0f / (1.0f - alpha); + [ForceUnroll] for (int i = 0; i < neuralHarmonicFeaturesParticle.RayFeatureDim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + DifferentialPair integratedFeaturesDiff = + DifferentialPair(integratedFeatures, integratedFeaturesGrad); + + bwd_diff(neuralHarmonicFeaturesParticle.integrateFeatures)( + alphaDiff, + featuresDiff, + integratedFeaturesDiff); + + alphaGrad = alphaDiff.getDifferential(); + featuresGrad = featuresDiff.getDifferential(); + integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + } +} diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang new file mode 100644 index 00000000..32d27dad --- /dev/null +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/radiativeParticles.slang @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Radiative Particles Wrapper +// Conditionally includes the appropriate feature implementation based on FEATURE_TRANSFORM_TYPE + +#if FEATURE_TRANSFORM_TYPE == 0 + // Spherical Harmonics mode + #include <3dgut/kernels/slang/models/shRadiativeParticles.slang> +#elif FEATURE_TRANSFORM_TYPE == 1 + // Post-MLP radiance mode + #include <3dgut/kernels/slang/models/neuralHarmonicFeaturesParticle.slang> +#else + #error "Unknown FEATURE_TRANSFORM_TYPE. Must be 0 (SH) or 1 (neural_harmonic_features)" +#endif diff --git a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang index 66334e49..1ef1a7e7 100644 --- a/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang +++ b/threedgut_tracer/include/3dgut/kernels/slang/models/shRadiativeParticles.slang @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. +// you may not use it except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -15,22 +15,22 @@ #include <3dgut/kernels/slang/common/sphericalHarmonics.slang> +namespace shRadiativeParticle +{ +static const int RadianceMaxNumSphCoefficients = PARTICLE_RADIANCE_NUM_COEFFS; +static const int Dim = 3; +}; + namespace shRadiativeParticle { struct ParametersBuffer { Ptr, Access.Read> _dataPtr; - Ptr> _gradPtr; + vector *_gradPtr; bool exclusiveGradient; //< true if the gradient maybe updated without atomics }; -struct CommonParameters -{ - ParametersBuffer parametersBuffer; - int sphDegree; -}; - struct Parameters : IDifferentiable { vector sphCoefficients[RadianceMaxNumSphCoefficients]; @@ -42,7 +42,7 @@ Parameters fetchParametersFromBuffer(no_diff uint32_t particleIdx, { Parameters parameters; const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { parameters.sphCoefficients[i] = parametersBuffer._dataPtr[particleOffset + i]; } return parameters; @@ -54,14 +54,14 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, Parameters parametersGrad) { const uint32_t particleOffset = particleIdx * RadianceMaxNumSphCoefficients; - [unroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { + [ForceUnroll] for (int i = 0; i < RadianceMaxNumSphCoefficients; ++i) { const vector coeffs = parametersGrad.sphCoefficients[i]; if (parametersBuffer.exclusiveGradient) { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { parametersBuffer._gradPtr[particleOffset + i][j] += coeffs[j]; } } else { - [unroll] for (int j = 0; j < Dim; ++j) { + [ForceUnroll] for (int j = 0; j < Dim; ++j) { InterlockedAdd(parametersBuffer._gradPtr[particleOffset + i][j], coeffs[j]); } } @@ -70,7 +70,7 @@ void fetchParametersFromBufferBwd(no_diff uint32_t particleIdx, 
[BackwardDifferentiable] [ForceInline] vector radianceFromBuffer(no_diff uint32_t particleIdx, - no_diff float3 incidentDirection, + in float3 incidentDirection, no_diff uint32_t sphDegree, no_diff ParametersBuffer parametersBuffer) { @@ -139,61 +139,103 @@ void integrateRadianceFromBuffer(no_diff float3 incident // Entry points [CudaDeviceExport] -inline vector particleFeaturesFromBuffer(in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in float3 incidentDirection) +inline void particleFeaturesFromBuffer( + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + in float3 incidentDirection, + in float3 canonicalPosition, + out float features[shRadiativeParticle.Dim]) { - return sphericalHarmonics.decode( - commonParameters.sphDegree, - shRadiativeParticle.fetchParametersFromBuffer(particleIdx, commonParameters.parametersBuffer).sphCoefficients, + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector featuresVec = sphericalHarmonics.decode( + auxParam, + shRadiativeParticle.fetchParametersFromBuffer(particleIdx, parametersBuffer).sphCoefficients, incidentDirection); + + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + features[i] = featuresVec[i]; + } } [CudaDeviceExport] inline void particleFeaturesIntegrateFwd(in float weight, - in vector features, - inout vector integratedFeatures) -{ - shRadiativeParticle.integrateRadiance( - weight, - features, - integratedFeatures - ); + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim]) +{ + vector featuresVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresVec[i] = features[i]; + } + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = 
integratedFeatures[i]; + } + shRadiativeParticle.integrateRadiance(weight, featuresVec, integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } -[CudaDeviceExport] inline void particleFeaturesIntegrateFwdFromBuffer(in float3 incidentDirection, - in float weight, - in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - inout vector integratedFeatures) +[CudaDeviceExport] +inline void particleFeaturesIntegrateFwdFromBuffer( + in float3 incidentDirection, + in float3 canonicalPosition, + in float weight, + in uint32_t particleIdx, + in float *radianceBufferPtr, + in int auxParam, + inout float integratedFeatures[shRadiativeParticle.Dim]) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)radianceBufferPtr; + parametersBuffer._gradPtr = nullptr; + parametersBuffer.exclusiveGradient = false; + + vector integratedVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + } shRadiativeParticle.integrateRadianceFromBuffer( - incidentDirection, - commonParameters.sphDegree, - weight, - particleIdx, - commonParameters.parametersBuffer, - integratedFeatures); + incidentDirection, auxParam, weight, particleIdx, parametersBuffer, integratedVec); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } [CudaDeviceExport] void particleFeaturesIntegrateBwd( in float alpha, inout float alphaGrad, - in vector features, - inout vector featuresGrad, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in float features[shRadiativeParticle.Dim], + inout float featuresGrad[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if(alpha > 0.0f) { + vector featuresVec, featuresGradVec; + 
[ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresVec[i] = features[i]; + featuresGradVec[i] = featuresGrad[i]; + } + vector integratedVec, integratedGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedVec[i] = integratedFeatures[i]; + integratedGradVec[i] = integratedFeaturesGrad[i]; + } + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); DifferentialPair> featuresDiff = - DifferentialPair>(features, featuresGrad); + DifferentialPair>(featuresVec, featuresGradVec); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + integratedVec = (integratedVec - featuresVec * alpha) * weight; DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedVec, integratedGradVec); bwd_diff(shRadiativeParticle.integrateRadiance)( alphaDiff, @@ -201,54 +243,98 @@ inline void particleFeaturesIntegrateFwd(in float weight, integratedFeaturesDiff); alphaGrad = alphaDiff.getDifferential(); - featuresGrad = featuresDiff.getDifferential(); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + featuresGrad[i] = featuresDiff.getDifferential()[i]; + } + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = integratedVec[i]; + } } } [CudaDeviceExport] void particleFeaturesIntegrateBwdToBuffer( in float3 incidentDirection, + in float3 canonicalPosition, + inout float3 canonicalPositionGrad, in float alpha, inout float alphaGrad, in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector features, - inout vector integratedFeatures, - inout vector integratedFeaturesGrad) + in 
float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float features[shRadiativeParticle.Dim], + inout float integratedFeatures[shRadiativeParticle.Dim], + inout float integratedFeaturesGrad[shRadiativeParticle.Dim]) { if (alpha > 0.0f) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (vector*)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + DifferentialPair alphaDiff = DifferentialPair(alpha, alphaGrad); const float weight = 1.0f / (1.0f - alpha); - integratedFeatures = (integratedFeatures - features * alpha) * weight; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeatures[i] = (integratedFeatures[i] - features[i] * alpha) * weight; + } + + vector integratedFeaturesVec; + vector integratedFeaturesGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesVec[i] = integratedFeatures[i]; + integratedFeaturesGradVec[i] = integratedFeaturesGrad[i]; + } DifferentialPair> integratedFeaturesDiff = - DifferentialPair>(integratedFeatures, integratedFeaturesGrad); + DifferentialPair>(integratedFeaturesVec, integratedFeaturesGradVec); bwd_diff(shRadiativeParticle.integrateRadianceFromBuffer)( incidentDirection, - commonParameters.sphDegree, + auxParam, alphaDiff, particleIdx, - commonParameters.parametersBuffer, + parametersBuffer, integratedFeaturesDiff); - integratedFeaturesGrad = integratedFeaturesDiff.getDifferential(); - alphaGrad = alphaDiff.getDifferential(); + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) { + integratedFeaturesGrad[i] = integratedFeaturesDiff.getDifferential()[i]; + } + alphaGrad += alphaDiff.getDifferential(); } } [CudaDeviceExport] void particleFeaturesBwdToBuffer( in uint32_t particleIdx, - shRadiativeParticle.CommonParameters commonParameters, - in vector 
featuresGrad, - in float3 incidentDirection + in float *featuresBufferPtr, + in float *featuresBufferGradPtr, + in int auxParam, + in bool exclusiveGradient, + in float featuresGrad[shRadiativeParticle.Dim], + in float3 incidentDirection, + inout float3 incidentDirectionGrad ) { + shRadiativeParticle.ParametersBuffer parametersBuffer; + parametersBuffer._dataPtr = (Ptr, Access.Read>)featuresBufferPtr; + parametersBuffer._gradPtr = (vector*)featuresBufferGradPtr; + parametersBuffer.exclusiveGradient = exclusiveGradient; + + vector featuresGradVec; + [ForceUnroll] for (int i = 0; i < shRadiativeParticle.Dim; ++i) featuresGradVec[i] = featuresGrad[i]; + + DifferentialPair incidentDirectionDiff = DifferentialPair(incidentDirection, incidentDirectionGrad); + bwd_diff(shRadiativeParticle.radianceFromBuffer)( particleIdx, - incidentDirection, - commonParameters.sphDegree, - commonParameters.parametersBuffer, - featuresGrad); + incidentDirectionDiff, + auxParam, + parametersBuffer, + featuresGradVec); + + incidentDirectionGrad += incidentDirectionDiff.getDifferential(); } diff --git a/threedgut_tracer/include/3dgut/renderer/gutRenderer.h b/threedgut_tracer/include/3dgut/renderer/gutRenderer.h index 7abd81fa..8ad1919b 100644 --- a/threedgut_tracer/include/3dgut/renderer/gutRenderer.h +++ b/threedgut_tracer/include/3dgut/renderer/gutRenderer.h @@ -22,6 +22,16 @@ #include +#ifndef FEATURE_OUTPUT_HALF +#define FEATURE_OUTPUT_HALF 0 +#endif +#if FEATURE_OUTPUT_HALF +#include +using TFeatureDensityElem = __half; +#else +using TFeatureDensityElem = float; +#endif + namespace threedgut { class GUTRenderer { @@ -67,7 +77,7 @@ class GUTRenderer { const tcnn::vec3* sensorRayDirectionCudaPtr, float* worldHitCountCudaPtr, float* worldHitDistanceCudaPtr, - tcnn::vec4* radianceDensityCudaPtr, + TFeatureDensityElem* featureDensityCudaPtr, int* particlesVisibilityCudaPtr, Parameters& parameters, int cudaDeviceIndex, @@ -78,8 +88,8 @@ class GUTRenderer { const tcnn::vec3* 
sensorRayDirectionCudaPtr, const float* worldHitDistanceCudaPtr, const float* worldHitDistanceGradientCudaPtr, - const tcnn::vec4* radianceDensityCudaPtr, - const tcnn::vec4* radianceDensityGradientCudaPtr, + const TFeatureDensityElem* featureDensityCudaPtr, + const float* featureDensityGradientCudaPtr, tcnn::vec3* worldRayOriginGradientCudaPtr, tcnn::vec3* worldRayDirectionGradientCudaPtr, Parameters& parameters, diff --git a/threedgut_tracer/include/3dgut/threedgut.cuh b/threedgut_tracer/include/3dgut/threedgut.cuh index 3ab2b092..e01ca757 100644 --- a/threedgut_tracer/include/3dgut/threedgut.cuh +++ b/threedgut_tracer/include/3dgut/threedgut.cuh @@ -29,7 +29,13 @@ struct model_InternalParams { }; struct model_ExternalParams { - static constexpr int FeaturesDim = 3; + // Feature-based radiance dimensions (from compile-time defines) + static constexpr int ParticleFeatureDim = PARTICLE_FEATURE_DIM; + static constexpr int RayFeatureDim = RAY_FEATURE_DIM; + static constexpr int FeatureTransformType = FEATURE_TRANSFORM_TYPE; // 0=SH, 1=nht + // NHT requires per-ray evaluation (interpolation); SH can use precomputed features + static constexpr bool PerRayParticleFeatures = (FEATURE_TRANSFORM_TYPE != 0); + static constexpr float AlphaThreshold = GAUSSIAN_PARTICLE_MIN_ALPHA; // = 1.0/255.0 static constexpr float MinTransmittanceThreshold = GAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD; // = 0.0001 static constexpr int KernelDegree = GAUSSIAN_PARTICLE_KERNEL_DEGREE; @@ -52,7 +58,7 @@ struct TGUTProjectorParams { static constexpr bool TightOpacityBounding = GAUSSIAN_TIGHT_OPACITY_BOUNDING; static constexpr bool RectBounding = GAUSSIAN_RECT_BOUNDING; static constexpr bool TileCulling = GAUSSIAN_TILE_BASED_CULLING; - static constexpr bool PerRayParticleFeatures = false; + static constexpr bool PerRayParticleFeatures = model_ExternalParams::PerRayParticleFeatures; static constexpr float MaxDepthValue = 3.4028235e+38; static constexpr bool GlobalZOrder = GAUSSIAN_GLOBAL_Z_ORDER; static 
constexpr bool BackwardProjection = false; // m_settings.renderMode == Settings::Splat @@ -77,7 +83,7 @@ static_assert(TGUTProjectionParams::RequireAllSigmaPoints == false, "RequireAllS using TGUTProjector = GUTProjector; struct TGUTRendererParams { - static constexpr bool PerRayParticleFeatures = TGUTProjectorParams::PerRayParticleFeatures; + static constexpr bool PerRayParticleFeatures = model_ExternalParams::PerRayParticleFeatures; static constexpr int KHitBufferSize = GAUSSIAN_K_BUFFER_SIZE; static constexpr bool CustomBackward = false; }; diff --git a/threedgut_tracer/include/3dgut/threedgut.slang b/threedgut_tracer/include/3dgut/threedgut.slang index 554e7fca..cb7f0ecc 100644 --- a/threedgut_tracer/include/3dgut/threedgut.slang +++ b/threedgut_tracer/include/3dgut/threedgut.slang @@ -24,11 +24,4 @@ namespace gaussianParticle }; #include <3dgut/kernels/slang/models/gaussianParticles.slang> - -namespace shRadiativeParticle -{ - static const int RadianceMaxNumSphCoefficients = PARTICLE_RADIANCE_NUM_COEFFS; - static const int Dim = 3; -}; - -#include <3dgut/kernels/slang/models/shRadiativeParticles.slang> +#include <3dgut/kernels/slang/models/radiativeParticles.slang> diff --git a/threedgut_tracer/setup_3dgut.py b/threedgut_tracer/setup_3dgut.py index f88f443a..5ba539af 100644 --- a/threedgut_tracer/setup_3dgut.py +++ b/threedgut_tracer/setup_3dgut.py @@ -19,6 +19,7 @@ import torch from threedgrut.utils import jit +from threedgrut.model.features import Features # ---------------------------------------------------------------------------- @@ -32,9 +33,6 @@ def setup_3dgut(conf): include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "include")) include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "dependencies")) include_paths.append(os.path.join(prefix, "..", "thirdparty", "tiny-cuda-nn", "dependencies", "fmt", "include")) - include_paths.append(build_dir) - - # Compiler options. 
def to_cpp_bool(value): return "true" if value else "false" @@ -45,6 +43,24 @@ def to_cpp_bool(value): ut_kappa = conf.render.splat.ut_kappa ut_delta = math.sqrt(ut_alpha * ut_alpha * (ut_d + ut_kappa)) + feat = Features(conf) + transform_defines = [ + f"-DPARTICLE_FEATURE_DIM={feat.particle_feature_dim}", + f"-DRAY_FEATURE_DIM={feat.ray_feature_dim}", + f"-DFEATURE_TRANSFORM_TYPE={feat.transform_type}", + ] + nht_defines = [ + f"-DFEATURE_INTERPOLATION_TYPE={feat.interpolation_type}", + f"-DFEATURE_INTERPOLATION_SUPPORT={feat.interpolation_support}", + f"-DFEATURE_ACTIVATION_TYPE={feat.activation_type}", + f"-DFEATURE_ACTIVATION_NUM_FREQUENCIES={feat.activation_num_frequencies}", + f"-DINTERP_POINT_FEATURE_DIM={feat.interp_point_feature_dim}", + ] + half_defines = [ + f"-DPARTICLE_FEATURE_HALF={1 if conf.render.particle_feature_half else 0}", + f"-DFEATURE_OUTPUT_HALF={1 if conf.render.feature_output_half else 0}", + ] + defines = [ f"-DPARTICLE_RADIANCE_NUM_COEFFS={(conf.render.particle_radiance_sph_degree + 1) ** 2}", f"-DGAUSSIAN_PARTICLE_KERNEL_DEGREE={conf.render.particle_kernel_degree}", @@ -55,6 +71,11 @@ def to_cpp_bool(value): f"-DGAUSSIAN_PARTICLE_SURFEL={to_cpp_bool(conf.render.primitive_type=='trisurfel')}", f"-DGAUSSIAN_MIN_TRANSMITTANCE_THRESHOLD={conf.render.min_transmittance}", f"-DGAUSSIAN_ENABLE_HIT_COUNT={to_cpp_bool(conf.render.enable_hitcounts)}", + # Feature-based radiance dimensions + *transform_defines, + *nht_defines, + # Feature buffer memory layout + *half_defines, # Specific to the 3DGUT renderer f"-DGAUSSIAN_N_ROLLING_SHUTTER_ITERATIONS={conf.render.splat.n_rolling_shutter_iterations}", f"-DGAUSSIAN_K_BUFFER_SIZE={conf.render.splat.k_buffer_size}", @@ -88,6 +109,18 @@ def to_cpp_bool(value): "-O3", *defines, ] + # Diagnostic: dump ptxas register / smem / spill stats per kernel. + # Enable with `export NHT_PTXAS_VERBOSE=1` before launching training. 
+ if os.environ.get("NHT_PTXAS_VERBOSE", "0") == "1": + cuda_cflags += [ + "-Xptxas=-v", + "--resource-usage", + ] + # When PARTICLE_FEATURE_HALF=1 the Slang-generated header uses __half types; + # the Slang prelude only pulls in and defines __half when + # SLANG_CUDA_ENABLE_HALF is set. + if conf.render.particle_feature_half or conf.render.feature_output_half: + cuda_cflags.append("-DSLANG_CUDA_ENABLE_HALF=1") # List of sources. source_files = [ @@ -99,9 +132,11 @@ def to_cpp_bool(value): # Compile slang kernels slang_build_inc_dir = os.path.join(os.path.dirname(__file__), "include", "3dgut") + slang_output_file = os.path.join(os.path.dirname(__file__), "include", "threedgutSlang.cuh") + jit.compile_slang_kernel( kernel_files=[f"{os.path.join(slang_build_inc_dir, 'threedgut.slang')}"], - output_file=f"{os.path.join(build_dir, 'threedgutSlang.cuh')}", + output_file=slang_output_file, include_paths=[ os.path.join(os.path.dirname(__file__), "include"), os.path.join(os.path.dirname(__file__), "..", "threedgrt_tracer", "include"), diff --git a/threedgut_tracer/src/gutRenderer.cu b/threedgut_tracer/src/gutRenderer.cu index 8c2472c3..5c361b63 100644 --- a/threedgut_tracer/src/gutRenderer.cu +++ b/threedgut_tracer/src/gutRenderer.cu @@ -39,7 +39,7 @@ namespace { using namespace threedgut; constexpr int featuresDim() { - return model_ExternalParams::FeaturesDim; + return model_ExternalParams::RayFeatureDim; } // identify tiles start/end indices in the sorted tile/depth keys buffer @@ -243,7 +243,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& const vec3* sensorRayDirectionCudaPtr, float* worldHitCountCudaPtr, float* worldHitDistanceCudaPtr, - vec4* radianceDensityCudaPtr, + TFeatureDensityElem* featureDensityCudaPtr, int* particlesVisibilityCudaPtr, Parameters& parameters, int cudaDeviceIndex, @@ -390,7 +390,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& sensorPoseToMat(sensorPoseInv), 
worldHitCountCudaPtr, worldHitDistanceCudaPtr, - radianceDensityCudaPtr, + featureDensityCudaPtr, (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), (const tcnn::vec4*)m_forwardContext->particlesProjectedConicOpacity.data(), (const float*)m_forwardContext->particlesGlobalDepth.data(), @@ -407,7 +407,7 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& sensorPoseToMat(sensorPoseInv), worldHitCountCudaPtr, worldHitDistanceCudaPtr, - radianceDensityCudaPtr, + featureDensityCudaPtr, (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), (const tcnn::vec4*)m_forwardContext->particlesProjectedConicOpacity.data(), (const float*)m_forwardContext->particlesGlobalDepth.data(), @@ -423,10 +423,10 @@ threedgut::Status threedgut::GUTRenderer::renderForward(const RenderParameters& threedgut::Status threedgut::GUTRenderer::renderBackward(const RenderParameters& params, const vec3* sensorRayOriginCudaPtr, const vec3* sensorRayDirectionCudaPtr, - const float* worldHitDistanceCudaPtr, // - const float* worldHitDistanceGradientCudaPtr, // TODO: not implemented yet - const vec4* radianceDensityCudaPtr, // - const vec4* radianceDensityGradientCudaPtr, // TODO: not implemented yet + const float* worldHitDistanceCudaPtr, + const float* worldHitDistanceGradientCudaPtr, + const TFeatureDensityElem* featureDensityCudaPtr, + const float* featureDensityGradientCudaPtr, vec3* worldRayOriginGradientCudaPtr, // TODO: not implemented yet vec3* worldRayDirectionGradientCudaPtr, // TODO: not implemented yet Parameters& parameters, @@ -477,8 +477,8 @@ threedgut::Status threedgut::GUTRenderer::renderBackward(const RenderParameters& sensorPoseToMat(sensorPoseInv), (const float*)worldHitDistanceCudaPtr, // (const float*)worldHitDistanceGradientCudaPtr, // TODO: not implemented yet - (const tcnn::vec4*)radianceDensityCudaPtr, // - (const tcnn::vec4*)radianceDensityGradientCudaPtr, // TODO: not implemented yet + featureDensityCudaPtr, 
+ featureDensityGradientCudaPtr, (tcnn::vec3*)worldRayOriginGradientCudaPtr, // TODO: not implemented yet (tcnn::vec3*)worldRayDirectionGradientCudaPtr, // TODO: not implemented yet (const tcnn::vec2*)m_forwardContext->particlesProjectedPosition.data(), diff --git a/threedgut_tracer/src/splatRaster.cpp b/threedgut_tracer/src/splatRaster.cpp index f277acdc..8a400d84 100644 --- a/threedgut_tracer/src/splatRaster.cpp +++ b/threedgut_tracer/src/splatRaster.cpp @@ -192,8 +192,14 @@ SplatRaster::trace(uint32_t frameNumber, int numActiveFeatures, const uint32_t numParticles = particleDensity.size(0); const torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // Feature output dtype: fp16 halves memory bandwidth for NHT feature buffers +#if FEATURE_OUTPUT_HALF + const torch::TensorOptions featureOpts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA); +#else + const torch::TensorOptions featureOpts = opts; +#endif - torch::Tensor rayRadianceDensity = torch::zeros({height, width, 4}, opts); + torch::Tensor rayRadianceDensity = torch::zeros({height, width, static_cast(RAY_FEATURE_DIM + 1)}, featureOpts); torch::Tensor rayHitDistance = torch::ones({height, width, 1}, opts).multiply(1e06f); torch::Tensor rayHitCount = torch::zeros({height, width, 1}, opts); torch::Tensor particleVisibility = torch::zeros({numParticles, 1}, opts); @@ -228,7 +234,7 @@ SplatRaster::trace(uint32_t frameNumber, int numActiveFeatures, reinterpret_cast(voidDataPtr(rayDirection)), reinterpret_cast(voidDataPtr(rayHitCount)), reinterpret_cast(voidDataPtr(rayHitDistance)), - reinterpret_cast(voidDataPtr(rayRadianceDensity)), + reinterpret_cast(voidDataPtr(rayRadianceDensity)), reinterpret_cast(voidDataPtr(particleVisibility)), m_parameters, cudaDeviceIndex, @@ -316,8 +322,8 @@ SplatRaster::traceBwd(uint32_t frameNumber, int numActiveFeatures, reinterpret_cast(voidDataPtr(rayDirection)), reinterpret_cast(voidDataPtr(rayHitDistance)), 
reinterpret_cast(voidDataPtr(rayHitDistanceGradient)), - reinterpret_cast(voidDataPtr(rayRadianceDensity)), - reinterpret_cast(voidDataPtr(rayRadianceDensityGradient)), + reinterpret_cast(voidDataPtr(rayRadianceDensity)), + reinterpret_cast(voidDataPtr(rayRadianceDensityGradient)), rayBackpropagation ? reinterpret_cast(voidDataPtr(rayOriginGradient)) : nullptr, rayBackpropagation ? reinterpret_cast(voidDataPtr(rayDirectionGradient)) : nullptr, m_parameters, cudaDeviceIndex, cudaStream); diff --git a/threedgut_tracer/tracer.py b/threedgut_tracer/tracer.py index e2a80e83..b59a664b 100644 --- a/threedgut_tracer/tracer.py +++ b/threedgut_tracer/tracer.py @@ -176,7 +176,7 @@ def forward( particle_density = torch.concat( [mog_pos, mog_dns, mog_rot, mog_scl, torch.zeros_like(mog_dns)], dim=1 ).contiguous() - particle_radiance = mog_sph.contiguous() + particle_features = mog_sph.contiguous() # dtype set by caller (fp16 when particle_feature_half=true) ray_time = ( torch.ones( @@ -185,11 +185,11 @@ def forward( * sensor_poses.timestamps_us[0] ) - ray_radiance_density, ray_hit_distance, ray_hit_count, mog_visibility = tracer_wrapper.trace( + ray_features_density, ray_hit_distance, ray_hit_count, mog_visibility = tracer_wrapper.trace( frame_id, n_active_features, particle_density, - particle_radiance, + particle_features, ray_ori.contiguous(), ray_dir.contiguous(), ray_time.contiguous(), @@ -204,10 +204,10 @@ def forward( ray_ori, ray_dir, ray_time, - ray_radiance_density, + ray_features_density, ray_hit_distance, particle_density, - particle_radiance, + particle_features, ) ctx.frame_id = frame_id @@ -217,7 +217,7 @@ def forward( ctx.tracer_wrapper = tracer_wrapper return ( - ray_radiance_density, + ray_features_density.float(), # always fp32 to caller; fp16 saved in ctx for trace_bwd ray_hit_distance, ray_hit_count, mog_visibility, @@ -226,7 +226,7 @@ def forward( @staticmethod def backward( ctx, - ray_radiance_density_grd, + ray_features_density_grd, # always fp32 
(gradient buffer is never fp16) ray_hit_distance_grd, ray_hit_count_grd_UNUSED, mog_visibility_grd_UNUSED, @@ -235,10 +235,10 @@ def backward( ray_ori, ray_dir, ray_time, - ray_radiance_density, + ray_features_density, ray_hit_distance, particle_density, - particle_radiance, + particle_features, ) = ctx.saved_variables frame_id = ctx.frame_id @@ -246,11 +246,11 @@ def backward( sensor_params = ctx.sensor_params sensor_poses = ctx.sensor_poses - particle_density_grd, particle_radiance_grd = ctx.tracer_wrapper.trace_bwd( + particle_density_grd, particle_features_grd = ctx.tracer_wrapper.trace_bwd( frame_id, n_active_features, particle_density, - particle_radiance, + particle_features, ray_ori, ray_dir, ray_time, @@ -259,8 +259,8 @@ def backward( sensor_poses.timestamps_us[1], sensor_poses.T_world_sensors[0], sensor_poses.T_world_sensors[1], - ray_radiance_density, - ray_radiance_density_grd, + ray_features_density, + ray_features_density_grd, ray_hit_distance, ray_hit_distance_grd, ) @@ -268,7 +268,7 @@ def backward( mog_pos_grd, mog_dns_grd, mog_rot_grd, mog_scl_grd, _ = torch.split( particle_density_grd, [3, 1, 4, 3, 1], dim=1 ) - mog_sph_grd = particle_radiance_grd + mog_sph_grd = particle_features_grd return ( None, # tracer_wrapper @@ -310,7 +310,7 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): num_gaussians = gaussians.num_gaussians with torch.cuda.nvtx.range(f"model.forward({num_gaussians} gaussians)"): ( - pred_rgba, + pred_features_alpha, pred_dist, hits_count, mog_visibility, @@ -324,27 +324,27 @@ def render(self, gaussians, gpu_batch: Batch, train=False, frame_id=0): gaussians.get_rotation().contiguous(), gaussians.get_scale().contiguous(), gaussians.get_density().contiguous(), - gaussians.get_features().contiguous(), + gaussians.get_features().contiguous().half() + if self.conf.render.particle_feature_half + else gaussians.get_features().contiguous(), sensor, poses, ) - pred_rgb = pred_rgba[..., :3].unsqueeze(0).contiguous() - 
pred_opacity = pred_rgba[..., 3:].unsqueeze(0).contiguous() + # pred_features_alpha is [..., RAY_FEATURE_DIM + 1]: features (fp32) + density + ray_feature_dim = gaussians.ray_feature_dim + pred_features = pred_features_alpha[..., :ray_feature_dim].unsqueeze(0).contiguous() + pred_opacity = pred_features_alpha[..., ray_feature_dim:].unsqueeze(0).contiguous() pred_dist = pred_dist.unsqueeze(0).contiguous() hits_count = hits_count.unsqueeze(0).contiguous() - pred_rgb, pred_opacity = gaussians.background( - gpu_batch.T_to_world.contiguous(), rays_d, pred_rgb, pred_opacity, train - ) - timings = self.tracer_wrapper.collect_times() return { - "pred_rgb": pred_rgb, + "pred_features": pred_features, "pred_opacity": pred_opacity, "pred_dist": pred_dist, - "pred_normals": torch.nn.functional.normalize(torch.ones_like(pred_rgb), dim=3), + "pred_normals": torch.nn.functional.normalize(torch.ones_like(pred_features), dim=3), "hits_count": hits_count, "frame_time_ms": timings["forward_render"] if "forward_render" in timings else 0.0, "mog_visibility": mog_visibility, diff --git a/validate.py b/validate.py new file mode 100644 index 00000000..88b162ed --- /dev/null +++ b/validate.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +Validation script: run 3DGUT and 3DGRT on the drums NeRF-Synthetic scene +(MCMC, 100k particles, 15k iterations) for SH and optionally NHT features, +then produce a report with PSNR, SSIM, training time, and render time. 
+ +Usage +----- + python validate.py --data /path/to/nerf_synthetic/drums [OPTIONS] + +Options +------- + --data PATH Path to the drums scene directory (required) + --out-dir PATH Root output directory (default: runs/validate) + --nht Also run NHT experiments (requires NHT support in the + codebase; silently skipped if unavailable) + --iterations N Training iterations per experiment (default: 15000) + --particles N Initial particle count (default: 100000) + --skip-existing Skip training if the checkpoint already exists +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + + +# --------------------------------------------------------------------------- +# NHT availability check +# --------------------------------------------------------------------------- + +def _nht_available() -> bool: + """Return True if this codebase has NHT support compiled in.""" + try: + from threedgrut.model.features import Features + _ = Features.Type.NHT + return True + except (ImportError, AttributeError): + return False + + +# --------------------------------------------------------------------------- +# Experiment definitions +# --------------------------------------------------------------------------- + +def _experiments(args, nht_ok: bool): + """Yield (name, renderer, feature_type, app_config) tuples.""" + base = [ + ("3dgut_sh", "3dgut", "sh", "apps/nerf_synthetic_3dgut_mcmc_nht"), + ("3dgrt_sh", "3dgrt", "sh", "apps/nerf_synthetic_3dgrt_mcmc_nht"), + ] + nht = [ + ("3dgut_nht", "3dgut", "nht", "apps/nerf_synthetic_3dgut_mcmc_nht"), + ("3dgrt_nht", "3dgrt", "nht", "apps/nerf_synthetic_3dgrt_mcmc_nht"), + ] + for entry in base: + yield entry + if args.nht and nht_ok: + for entry in nht: + yield entry + + +# --------------------------------------------------------------------------- +# Train +# --------------------------------------------------------------------------- + +def _find_latest(directory: Path, pattern: str): 
+ """Return the most recently modified file matching pattern under directory, or None.""" + matches = sorted(directory.glob(pattern), key=lambda p: p.stat().st_mtime) + return matches[-1] if matches else None + + +def _train(name: str, app_config: str, feature_type: str, args) -> tuple[float, str]: + """Run training and return (wall_time_seconds, checkpoint_path).""" + exp_dir = Path(args.out_dir) / name + + # Trainer saves under exp_dir/-/ckpt_last.pt + if args.skip_existing: + ckpt = _find_latest(exp_dir, "*/ckpt_last.pt") + if ckpt is not None: + print(f" [skip] checkpoint already exists: {ckpt}") + return 0.0, str(ckpt) + + exp_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "train.py", + f"--config-name={app_config}", + f"path={args.data}", + f"out_dir={args.out_dir}", + f"experiment_name={name}", + f"n_iterations={args.iterations}", + f"initialization.num_gaussians={args.particles}", + f"model.feature_type={feature_type}", + # disable feature_output_half for SH (it's set to true in NHT app configs) + f"render.feature_output_half={'true' if feature_type == 'nht' else 'false'}", + # save checkpoint only at the end; intermediate checkpoints not needed here + f"checkpoint.iterations=[{args.iterations}]", + # disable GUI + "with_gui=false", + "with_viser_gui=false", + ] + + print(f" $ {' '.join(cmd)}") + t0 = time.time() + subprocess.run(cmd, check=True) + elapsed = time.time() - t0 + + ckpt = _find_latest(exp_dir, "*/ckpt_last.pt") + if ckpt is None: + raise FileNotFoundError(f"Expected checkpoint not found under: {exp_dir}/*/ckpt_last.pt") + + return elapsed, str(ckpt) + + +# --------------------------------------------------------------------------- +# Render / evaluate +# --------------------------------------------------------------------------- + +def _render(name: str, ckpt: str, args) -> dict: + """Run render.py and return the metrics dict.""" + eval_dir = Path(args.out_dir) / name / "eval" + eval_dir.mkdir(parents=True, exist_ok=True) 
+ + cmd = [ + sys.executable, "render.py", + "--checkpoint", ckpt, + "--out-dir", str(eval_dir), + ] + + print(f" $ {' '.join(cmd)}") + subprocess.run(cmd, check=True) + + # Renderer saves under eval_dir//-/metrics.json + metrics_path = _find_latest(eval_dir, "**/metrics.json") + if metrics_path is None: + raise FileNotFoundError(f"metrics.json not found after render under: {eval_dir}") + + with open(metrics_path) as f: + return json.load(f) + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- + +_HEADER = ( + "| Experiment | PSNR (dB) | SSIM | LPIPS | Train (min) | Render (ms/f) |\n" + "|------------------|-----------|--------|--------|-------------|---------------|\n" +) + + +def _row(name: str, m: dict, train_sec: float) -> str: + psnr = f"{m.get('mean_psnr', float('nan')):.2f}" + ssim = f"{m.get('mean_ssim', float('nan')):.4f}" + lpips = f"{m.get('mean_lpips', float('nan')):.4f}" + t_min = f"{train_sec / 60:.1f}" if train_sec > 0 else "—" + r_ms = f"{m.get('mean_inference_time_ms', float('nan')):.2f}" + return f"| {name:<16} | {psnr:>9} | {ssim:>6} | {lpips:>6} | {t_min:>11} | {r_ms:>13} |\n" + + +def _write_report(rows: list[tuple], args, nht_ok: bool) -> str: + scene = Path(args.data).name + lines = [ + f"# Validation Report: {scene}\n\n", + f"Scene: `{args.data}` \n", + f"Iterations: {args.iterations} \n", + f"Particles: {args.particles:,} \n", + f"Strategy: MCMC \n", + f"NHT requested: {args.nht} ", + f"{'(supported)' if nht_ok else '(not available in this build — skipped)'} \n\n", + "## Results\n\n", + _HEADER, + ] + for name, metrics, train_sec in rows: + lines.append(_row(name, metrics, train_sec)) + + report = "".join(lines) + + report_path = Path(args.out_dir) / "report.md" + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text(report) + return report + + +# 
--------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--data", required=True, help="Path to the drums scene directory") + parser.add_argument("--out-dir", default="runs/validate", help="Root output directory") + parser.add_argument("--nht", action="store_true", help="Also run NHT experiments") + parser.add_argument("--iterations", type=int, default=15000, help="Training iterations per experiment") + parser.add_argument("--particles", type=int, default=100000, help="Initial particle count") + parser.add_argument("--skip-existing", action="store_true", help="Skip training if checkpoint exists") + args = parser.parse_args() + + nht_ok = _nht_available() + if args.nht and not nht_ok: + print("WARNING: --nht requested but NHT is not available in this build; NHT experiments will be skipped.") + + rows = [] + for name, renderer, feature_type, app_config in _experiments(args, nht_ok): + print(f"\n{'='*60}") + print(f"Experiment: {name} (renderer={renderer}, features={feature_type})") + print(f"{'='*60}") + + print("\n[1/2] Training ...") + try: + train_sec, ckpt = _train(name, app_config, feature_type, args) + except subprocess.CalledProcessError as e: + print(f"ERROR: training failed for {name}: {e}") + rows.append((name, {}, 0.0)) + continue + + print(f"\n[2/2] Evaluating ...") + try: + metrics = _render(name, ckpt, args) + except subprocess.CalledProcessError as e: + print(f"ERROR: render failed for {name}: {e}") + rows.append((name, {}, train_sec)) + continue + + rows.append((name, metrics, train_sec)) + print(f"\n PSNR={metrics.get('mean_psnr', '?'):.2f} dB " + f"SSIM={metrics.get('mean_ssim', '?'):.4f} " + f"LPIPS={metrics.get('mean_lpips', '?'):.4f} " + f"render={metrics.get('mean_inference_time_ms', '?'):.2f} 
ms/frame") + + print(f"\n{'='*60}") + print("REPORT") + print(f"{'='*60}") + report = _write_report(rows, args, nht_ok) + print(report) + print(f"Report saved to: {Path(args.out_dir) / 'report.md'}") + + +if __name__ == "__main__": + main()