diff --git a/README.md b/README.md index f4da8492..d3857e7c 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,14 @@ python train.py --config-name apps/colmap_3dgrt.yaml path=data/mipnerf360/bonsai python train.py --config-name apps/colmap_3dgut.yaml path=data/mipnerf360/bonsai out_dir=runs experiment_name=bonsai_3dgut dataset.downsample_factor=2 optimizer.type=selective_adam ``` +### Post-processing (linear-to-sRGB and PPISP) + +Hydra key: ``post_processing.method``. Values: + +- **null** (default): no change to rendered RGB before the loss. +- **linear-to-srgb**: **IEC 61966-2-1** piecewise linear-to-sRGB encoding on ``pred_rgb`` (same rule as ``thirdparty/tiny-cuda-nn/scripts/common.py`` ``linear_to_srgb``). See ``threedgrut/utils/post_processing_linear_to_srgb.py``. Example: ``post_processing.method=linear-to-srgb``. +- **ppisp**: per-frame camera corrections; requires the ``ppisp`` package (see ``requirements.txt``) and uses the other ``post_processing.*`` fields in ``configs/base_gs.yaml``. + If you use MCMC and Selective Adam in your research, please cite [3dgs-mcmc](https://github.com/ubc-vision/3dgs-mcmc), [taming-3dgs](https://github.com/humansensinglab/taming-3dgs), and the [gSplat](https://github.com/nerfstudio-project/gsplat/tree/main) library from which the code was adopted (links to the code are provided in the source files). 
diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 51248502..18321f45 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -42,11 +42,36 @@ export_usd: enabled: false path: "" apply_normalizing_transform: true - format: standard # "nurec" for Omniverse USDVol internal format, "standard" for USDVol ParticleField3DGaussianSplat + format: standard # "nurec" for internal USDVol format, "standard" for USDVol ParticleField3DGaussianSplat half_precision: false export_cameras: true export_background: true + # zDepth | cameraDistance | rayHitDistance sorting_mode_hint: cameraDistance + # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display + linear_srgb: false + # Enable post-processing export when the checkpoint contains a supported module. + # Defaults to true; post-processing-export-mode controls how the effect is exported. + export_post_processing: true + # baked-sh fits a fixed post-processing transform into Gaussian SH coefficients. + # omni-native uses the module-specific Omniverse-native path; currently PPISP SPG. + # baked-sh | omni-native + post-processing-export-mode: baked-sh + # Optional fixed PPISP camera/frame for omni-native export. When frame is set, + # exposure/color are exported as static shader inputs instead of animation. + post-processing-export-camera-id: null + post-processing-export-frame-id: null + # Number of sequential passes over the train/reference set used for fitting. + post-processing-bake-epochs: 3 + post-processing-bake-learning-rate: 0.001 + post-processing-bake-camera-id: 0 + post-processing-bake-frame-id: 0 + # none: disable PPISP vignetting during bake. achromatic-fit: chromatic PPISP reference + # with an achromatic fit-only vignette; the achromatic vignette is not exported. 
+ # none | achromatic-fit + ppisp-bake-vignetting-mode: achromatic-fit + # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed + frames_per_second: 1.0 model: density_activation: sigmoid @@ -124,7 +149,10 @@ loss: # Post-processing configuration post_processing: - method: null # Possible values: null, "ppisp" + # null | "ppisp" | "linear-to-srgb" + # - linear-to-srgb: IEC piecewise linear-to-sRGB on pred_rgb (same as tiny-cuda-nn common.linear_to_srgb). + # No extra deps; no trainable weights. See threedgrut/utils/post_processing_linear_to_srgb.py. + method: null # Enable the controller for predicting per-frame corrections for novel views. # When false, zero corrections are used for novel views. use_controller: true diff --git a/docs/ppisp-controller-export-plan.md b/docs/ppisp-controller-export-plan.md new file mode 100644 index 00000000..67ba67cc --- /dev/null +++ b/docs/ppisp-controller-export-plan.md @@ -0,0 +1,162 @@ +# PPISP Controller SPG Export — Feasibility & Plan + +Scope: extend the existing PPISP SPG export to include the **controller** — +the per-camera CNN+MLP that predicts per-frame `exposure` and 8-d +`color_latents` from the rendered image. + +Status: feasible. This document records the design before implementation. + +--- + +## 1. Controller summary (from `ppisp.PPISP._PPISPController`) + +Fixed architecture, one instance per camera. Inputs are the raw rendered +HDR image and an optional `prior_exposure` scalar. + +``` +Conv2d(3→16, 1×1, +bias) +MaxPool2d(3, stride=3) +ReLU +Conv2d(16→32, 1×1, +bias) +ReLU +Conv2d(32→64, 1×1, +bias) +AdaptiveAvgPool2d((5, 5)) +Flatten # → 1600 features +concat(prior_exposure) # → 1601 +Linear(1601 → 128) + ReLU +Linear(128 → 128) + ReLU +Linear(128 → 128) + ReLU +exposure_head: Linear(128 → 1) +color_head: Linear(128 → 8) +``` + +Total weights ≈ **240 K floats per camera**. + +The output is **nine scalar values for the whole image** (1 exposure, 8 colour +latents). 
Those nine numbers replace the current static USD time-samples on +the PPISP shader. + +--- + +## 2. SPG capabilities used + +The 3DGRUT SPG pipeline already uses **Slang** in the SPG runtime +(`*.slang`, `*.slang.lua`, `*.slang.usda`) — the public docs that describe +only CUDA kernels are out-of-date; the existing PPISP shader proves Slang +launch is supported. + +Confirmed primitives: + +- `slang.dispatch{ stage="compute", numthreads=…, grid=…, bind={…} }` per + shader prim. +- `slang.ParameterBlock(...)` for grouped scalar/vector inputs that map to + USD attributes. +- `slang.Texture2D / slang.RWTexture2D / slang.empty(shape, dtype)` for + bound textures and lua-allocated outputs. +- Shader-to-shader chaining via USD `omni:rtx:aov` connections on + `RenderVar` prims (the existing `LdrColor` → `PPISP` wiring uses this). + +What we **do not** rely on: +- Multi-dispatch within one Lua launcher (one dispatch per shader prim). +- CooperativeVector / coopvec — not assumed available in the target Kit. +- Non-2D output buffers — only 2D images via `slang.empty`. + +--- + +## 3. Two challenges and how we solve them + +### 3.1 Adaptive avg pooling on a runtime-sized input + +PyTorch's `AdaptiveAvgPool2d((5,5))` partitions the input into exactly 25 +near-equal cells. The cell bounds are: + +``` +i = 0..4 (output row) +start_h = floor(i * H_in / 5) +end_h = ceil((i + 1) * H_in / 5) +``` + +Each Slang thread group computes one output cell `(i, j)` by reading every +input pixel in `[start_h, end_h) × [start_w, end_w)`, applying the +per-pixel CNN forward (3×3 max-pool fused with the surrounding 1×1 +convolutions), and reducing the sum / divide-by-count in shared memory. + +This works for arbitrary input resolution because the cell bounds are +computed inside the shader from `H_in, W_in`. + +### 3.2 Baking MLP and CNN weights into Slang + +Each camera's controller has unique weights. 
We generate **one Slang file +per camera** at export time, with all weights emitted as +`static const float[]` arrays. Slang's compiler can fold these into +constant memory, and there is no runtime upload step. + +The generated file `ppisp_controller_<camera>.slang` includes a fixed +shared template (CNN forward, pool, MLP) and only differs in the weight +constants. The matching `*.slang.lua` and `*.slang.usda` are emitted per +camera as well so each `RenderProduct` references its own controller. + +If weights ever exceed Slang's static-data limits we can fall back to +USD `float[]` inputs bound as a `StructuredBuffer<float>`, but for the +default architecture (~240 K floats) static arrays are fine. + +--- + +## 4. SPG graph + +``` +HdrColor (RenderVar) + │ + ▼ (omni:rtx:aov connection) +PPISPController_<camera> Slang compute, single thread group + │ outputs ControllerParams (1×9 float image) + ▼ +PPISP Slang compute, grid sized to image + │ reads HdrColor + ControllerParams + static vignetting/CRF + ▼ outputs PPISPColor +LdrColor (RenderVar) +``` + +The existing `ppisp_writer.py` builds the second half. The new +controller writer creates the first stage and connects its output as +an additional input to the PPISP shader. + +The PPISP slang shader is **generalised** to read the exposure and the 8 +colour latents from a 1×9 single-channel float texture when one is bound, +falling back to its `ParameterBlock` defaults otherwise. This keeps the +legacy "static parameters per frame" path unchanged — important for users +who train without a controller. + +--- + +## 5. Testing + +Two-pronged approach: + +1. **Unit-level Python check** that the generated Slang reproduces the + PyTorch controller's outputs to within a tight tolerance, using + `slangpy` to dispatch the controller shader against a reference image. + +2. 
**Tool: `tools/render_ppisp_spg/`** — a slangpy-based runner + that opens an exported USD/USDZ, walks `/Render/` prims, finds + their SPG shader chain, and replays the chain on a supplied HDR input + for every authored time sample. Useful for visual regression and for + cross-checking that the PPISP USD asset produces the same image + sequence as the in-process `apply_post_processing` path used during + training. + +The render tool intentionally does not try to reproduce Kit's full +RenderProduct pipeline; it executes only the SPG `compute` stages so it +remains independent of Kit and useful in headless CI. + +--- + +## 6. Out of scope for this iteration + +- Multi-dispatch optimisation of the controller (currently one slow but + correct compute pass). +- CoopVec acceleration of the MLP matmul. +- Quantising weights to fp16/bf16 to reduce shader source size. +- Runtime weight upload (large `float[]` USD inputs). + +These can be added later if the basic export proves correct. diff --git a/docs/ppisp-to-rtx-pp-plan.md b/docs/ppisp-to-rtx-pp-plan.md new file mode 100644 index 00000000..9822ddc9 --- /dev/null +++ b/docs/ppisp-to-rtx-pp-plan.md @@ -0,0 +1,630 @@ +# PPISP Omniverse USD Post-Processing Fallback Plan + +Scope: investigate whether the PPISP effect currently planned for SPG export can +also be approximated by existing Omniverse USD post-processing settings. + +Goal: provide an Omniverse USD post-processing fallback for Kit versions where SPG is unavailable, +not supported in the target deployment, or affected by SPG bugs. This is not a +replacement for the exact SPG export path. + +User-facing control: expose the fallback as an export parameter named +`ov-post-processing`. The implementation should live in a dedicated USD writer +file and stay separate from the SPG PPISP writer. + +This document is intentionally a plan only. No implementation is approved here. 
+ +Status: `BLOCKED_ON_SPG_IMPLEMENTATION` + +Implementation gate: the SPG PPISP export path is being implemented first. The +`ov-post-processing` fallback should be implemented on top of that work, after +the shared camera grouping and `/Render`/`RenderProduct` authoring are available. + +--- + +## 1. Context + +The current PPISP USD export plan in `docs/ppisp-controller-export-plan.md` uses a custom +SPG shader on each `RenderProduct` because PPISP is a post-blend image-space +operator: + +1. Exposure: per-frame scalar `rgb *= 2**e`. +2. Vignetting: per-camera, per-channel, per-pixel multiplicative falloff. +3. Color correction: per-frame 3x3 homography in RGI space with intensity + renormalisation. +4. CRF: per-camera, per-channel 4-parameter toe/shoulder/gamma curve. + +The question here is whether a secondary PPISP export path can map enough of the +effect onto existing Omniverse USD post-processing controls to be useful when +SPG is not viable in a given Kit runtime. + +Investigation target: + +- Kit rendering post-processing implementation. +- Generated USD render-settings schema. +- Existing USD stages that author post-processing attributes on `RenderProduct` + prims. + +--- + +## 2. USD post-processing surface found + +Kit exposes post-processing in two layers. + +The C++ renderer layer: + +- `Postprocessing::addPostprocessing` +- `Postprocessing::addTonemapping` +- `Postprocessing::addTvNoise` +- `Postprocessing::addRegisteredCompositing` + +The USD/render-settings layer: + +- `RenderProduct` prims can apply post-processing settings API schemas. +- Example stages author post-processing attributes directly on + `/Render/`. +- The generated schema exposes: + - camera exposure settings. + - tonemapping settings. + - color grading settings. + - vignette settings. + +Relevant setting families: + +- Camera exposure: + - `exposure:time` + - `exposure:fStop` + - `exposure:iso` + - `exposure:responsivity` +- Tonemapping: + - tonemap operator. + - tonemap dither. 
+ - advanced carb-backed controls such as `exposureKey`, `whiteScale`, + `maxWhiteLuminance`, `whitepoint`, and `enableSrgbToGamma` +- Color grading: + - grade enabled. + - `blackPoint`, `whitePoint`, `contrast`, `lift`, `gain`, `multiply`, + `offset`, `gamma`, `saturation` +- TV-noise vignette: + - effect enabled. + - vignetting enabled. + - vignetting size. + - vignetting strength. + +Important observed shader behavior: + +- Tonemapping applies exposure through `computeExposureScale`, then one of the + built-in operators: raw/clamp, linear, Reinhard, modified Reinhard, + Hejl-Hable, Hable UC2, ACES approximation, or Iray Reinhard. +- Color correction/grading can run before tonemapping in ACES mode or after + tonemapping in Standard mode. +- TV-noise vignetting uses one scalar radial-ish function: + `pow(uv.x * (1 - uv.x) * uv.y * (1 - uv.y) * (size + 14), strength)`. + +--- + +## 3. Mapping assessment + +### 3.1 Exposure + +Assessment: exact scalar mapping is likely possible. + +Reasoning: + +- PPISP exposure is a scalar multiply by `2**exposure_params[frame_idx]`. +- USD exposure scale is proportional to `exposure:time`, `exposure:iso`, and + `exposure:responsivity`, and inversely proportional to the square of `exposure:fStop`. +- If all other exposure parameters are held fixed, time-sampling + `exposure:time` as `baseExposureTime * 2**e` should reproduce the scalar + exposure factor. + +Recommended USD mapping: + +- Apply the camera exposure API schema to each camera prim, or author the + equivalent camera exposure attributes if already supported by the target USD + version. +- Time-sample `exposure:time` per frame. +- Keep `exposure:fStop`, `exposure:iso`, and `exposure:responsivity` fixed. +- Disable auto exposure and histogram adaptation for validation. + +Risks: + +- Exposure is embedded in tonemapping. Exactness only holds if the rest + of the tone pipeline is configured so the exposure scale is not folded into a + different nonlinear look. 
+- Kit has a Gaussian-specific skip-tonemapping path. If active, it may bypass + the tone pass for Gaussian primary hits and therefore bypass the intended + exposure mapping. + +Confidence: 0.75. + +### 3.2 Vignetting + +Assessment: approximate only. + +Reasoning: + +- PPISP vignetting is per-camera, per-channel, and parameterized by five values + per channel. +- USD vignette control is a scalar function shared across RGB, controlled by + only `size` and `strength`. +- The vignette function is centered in normalized screen space and does not expose + per-channel coefficients or arbitrary polynomial/radial terms. + +Recommended USD mapping: + +- Use the USD vignette API schema only for an approximation path. +- Enable TV noise and vignetting, but disable film grain, scanlines, ghosting, + scrolling, random splotches, wave distortion, vertical lines, and flicker. +- Fit one scalar vignette curve per physical camera to the luminance average + of the PPISP RGB vignette map. +- Record the per-channel residual, because color-dependent vignetting cannot be + represented by this scalar control. + +Risks: + +- The TV-noise vignette pass is semantically part of an analog TV effect, not a + calibrated camera response model. +- It may run after tonemapping, while PPISP vignetting is before color + correction and CRF. This changes the result when later nonlinear operations + are enabled. + +Confidence: 0.45. + +### 3.3 Color Correction + +Assessment: approximate only, and likely weak for scenes with cross-channel +mixing. + +Reasoning: + +- PPISP color correction is a per-frame 3x3 homography in RGI space with + intensity renormalisation. +- USD color correction and color grading expose channel-wise saturation, + contrast, gain, gamma, offset, lift, multiply, black point, and white point. +- These controls do not expose a general 3x3 matrix or a homography with + intensity renormalisation. 
+ +Recommended USD mapping: + +- Prefer the USD color-grading API schema over legacy color-correction carb + settings because it is present in the generated USD schema and examples. +- Use Standard mode for validation if the desired fit is after a linear + tonemap, and ACES mode only if validation shows the color space conversion is + closer to PPISP's RGI-space transform. +- Fit per-frame `gain`, `offset`, `gamma`, `contrast`, and `saturation` to + sampled RGB pairs generated by the trained PPISP transform. +- Treat any off-diagonal color coupling in the PPISP homography as residual + error, not as exportable data. + +Risks: + +- The generated USD schema exposes color grading attributes, but not every + advanced carb setting is necessarily intended for portable USD authoring. +- Time-sampled `RenderProduct` attributes should be verified in Kit, because the + schema examples are mostly static. + +Confidence: 0.35. + +### 3.4 CRF + +Assessment: no exact mapping in existing USD post-processing. + +Reasoning: + +- PPISP CRF is per-camera, per-channel, and has four learned parameters per + channel. +- USD tonemapping provides a small set of global operators. Iray Reinhard adds + `crushBlacks`, `burnHighlights`, and saturation, but not per-channel + toe/shoulder/gamma parameters. +- USD color grading `gamma` is per-channel, but it is not a learned + toe/shoulder CRF. + +Recommended USD mapping: + +- Use the built-in tonemapper only as a coarse approximation. +- Evaluate two candidate fits: + - `operator = "none"` or `"raw"` plus color grading gamma/gain/offset. + - `operator = "iray"` plus Iray Reinhard crush/burn controls and color + grading compensation. +- Fit per camera, not per frame, because PPISP CRF is per camera. + +Risks: + +- A fitted USD tonemapper may interact with exposure and color grading in ways + that make individual PPISP components hard to validate independently. +- Per-channel CRF differences are not representable by global tonemap + operators. 
+ +Confidence: 0.25. + +--- + +## 4. Candidate architectures + +### Option R0 — SPG-only export + +Keep the SPG plan from `docs/ppisp-controller-export-plan.md` as the only PPISP-preserving +export path. + +Use USD post-processing only for user-authored artistic settings unrelated to +PPISP. + +Recommendation: best default path when the target Kit version has reliable SPG +support. + +### Option R1 — Exposure-only USD fallback + +Export only PPISP exposure through time-sampled camera exposure attributes. +Leave vignetting, color correction, and CRF unexported or keep them in SPG. + +Recommendation: useful as the lowest-risk fallback for older Kit versions where +SPG is unavailable but some PPISP brightness matching is better than no PPISP +signal. + +### Option R2 — USD post-processing fallback + +Fit the full PPISP effect into existing USD settings: + +- Exposure via `exposure:time`. +- Vignetting via TV-noise vignette. +- Color correction via color grading. +- CRF via tonemap plus color grading. + +Recommendation: primary USD fallback candidate for Kit versions with no SPG +support or known SPG bugs. It should be advertised as approximate and version +gated. + +### Option R3 — Hybrid export for validation and migration + +Export both: + +- Exact PPISP SPG `RenderVar` path for validation and high fidelity. +- Approximate USD post-processing attributes for viewers that do not support the + custom SPG shader. + +Recommendation: best investigation mode when validating the USD fallback against +SPG in newer Kit versions, or when the same asset must run across mixed Kit +deployments. + +--- + +## 5. Proposed reviewable tasks + +### T-R0 — Build a PPISP reference response sampler + +Purpose: create a test-only numeric reference for comparing PPISP against USD +approximations. + +Inputs: + +- A trained or synthetic PPISP instance. +- A small set of RGB sample grids. +- Camera index and frame index. 
+ +Output: + +- Per-stage reference outputs: + - after exposure + - after vignetting + - after color correction + - after CRF + - final + +Test write-up: + +- Use identity PPISP parameters and assert output equals input. +- Enable only exposure and assert output equals `rgb * 2**e`. +- Enable one non-identity operation at a time and store deterministic numeric + fixtures. + +### T-R1 — Validate USD exposure equivalence + +Purpose: prove whether `exposure:time = baseExposureTime * 2**e` matches PPISP +exposure under controlled USD settings. + +Test write-up: + +- Create a USD stage with one camera and one `RenderProduct`. +- Disable auto exposure, dither, color grading, TV noise, and nonlinear + tonemapping. +- Render a known flat-color target at several exposure values. +- Compare captured output ratios against `2**e`. +- Repeat with Gaussian skip-tonemapping enabled and disabled. + +Pass criterion: + +- Relative error below a chosen tolerance, proposed initial threshold: `1e-3` + for linear floating-point captures. + +### T-R2 — Fit and validate USD vignette + +Purpose: quantify how close the built-in USD vignette can get to PPISP +vignetting. + +Test write-up: + +- Generate PPISP vignetting maps for each camera. +- Fit `vignetting:size` and `vignetting:strength` to the luminance-average + PPISP map. +- Render a flat-color image through the vignette pass with all + other TV effects disabled. +- Compare spatial error and per-channel residual. + +Pass criterion: + +- Report RMSE and max error. Do not enforce pass/fail until real datasets are + sampled. + +### T-R3 — Fit USD color grading to PPISP color correction + +Purpose: determine whether the color grading controls can approximate the PPISP +3x3 RGI homography acceptably. + +Test write-up: + +- Sample RGB values across the training color range. +- Apply PPISP color correction for selected frames. +- Fit USD grade controls to minimize color error. +- Validate on held-out RGB samples and on rendered frames. 
+ +Pass criterion: + +- Report `meanDeltaRgb`, `p95DeltaRgb`, and max channel error. +- Flag frames where off-diagonal homography terms dominate the residual. + +### T-R4 — Fit USD tonemap/grade to PPISP CRF + +Purpose: quantify CRF approximation quality with built-in USD tone operators. + +Test write-up: + +- For each camera, sample the PPISP per-channel CRF curves. +- Fit candidate USD settings: + - raw or none tonemap plus grade gamma/gain/offset + - Iray Reinhard plus grade compensation +- Validate per-channel curve error and final image error. + +Pass criterion: + +- Report per-camera curve RMSE and max error. +- Reject USD-only export for cameras whose CRF fit exceeds the selected + threshold. + +### T-R5 — Author a minimal USD post-processing prototype + +Purpose: verify the USD authoring model without touching the production exporter. + +Expected prototype shape: + +- `/Render/` `RenderProduct` +- Applied schemas: + - tonemapping API schema. + - color-grading API schema. + - vignette API schema. +- Camera exposure attributes on the referenced camera prim. +- `orderedVars` containing `LdrColor`. + +Test write-up: + +- Open the generated USD in Kit with the required render-settings schema enabled. +- Verify authored attributes appear in the active render settings context. +- Capture output and compare against the PPISP reference sampler. + +### T-R6 — Define USD fallback policy + +Purpose: choose when the USD post-processing fallback should be offered after +the numeric validation tasks. + +Decision points: + +- Identify the minimum Kit version where SPG is reliable enough to prefer + the exact SPG path (`ov-post-processing: none`). +- Identify older Kit versions or known SPG bug IDs where the USD fallback should + be available. +- If only exposure is accurate, add an exposure-only USD fallback mode. +- If fitted error is acceptable for target datasets, add an approximate USD + post-processing fallback mode. 
+- If errors are high, keep USD post-processing as an explicit degraded fallback + and document SPG as required for fidelity. + +Test write-up: + +- Produce a short validation report with per-stage errors and example captures. +- Require explicit approval before implementing any exporter changes. + +--- + +## 6. Feasibility Report + +Assessment: feasible with moderate implementation risk. + +The standard USD exporter already has the right integration points: + +- `configs/base_gs.yaml` contains an `export_usd` block for export parameters. +- `USDExporter.from_config` centralizes conversion from config to exporter + constructor arguments. +- `USDExporter.export` already has access to `model`, `dataset`, `conf`, and + `background`. +- `trainer.py` already chooses `USDExporter` when `export_usd.format` is + `standard`. + +The main missing dependency is not the export parameter itself, but access to the +trained PPISP module during USD export. Today `trainer.py` calls: + +```text +exporter.export(..., dataset=self.train_dataset, conf=conf, background=...) +``` + +For any PPISP-derived fallback, this call must also pass +`post_processing=self.post_processing` when `post_processing.method == "ppisp"`. +That is already required by the SPG export plan, so the fallback should reuse +the same exporter-facing data path. + +Recommended export parameters: + +```yaml +export_usd: + export_ppisp: false + ov-post-processing: none +``` + +`export_ppisp` is the gate for PPISP export. If it is `false`, no PPISP effect +is exported and `ov-post-processing` must be `none`. + +Allowed `ov-post-processing` values when `export_ppisp` is `true`: + +- `none`: use the full SPG PPISP path and do not author fallback settings. +- `ppisp-exposure-fallback`: export only PPISP exposure through USD camera exposure. +- `ppisp-fitted-post-processing-fallback`: export the fitted USD post-processing approximation. 
+- `ppisp-spg-plus-fitted-post-processing-fallback`: author the fitted fallback attributes alongside the SPG + path for validation or mixed Kit deployments. + +Implementation note: because `ov-post-processing` is hyphenated, Python code +should read it with `export_conf.get("ov-post-processing", "none")`, not +`conf.export_usd.ov-post-processing` or normal dot access. + +Dedicated file recommendation: + +```text +threedgrut/export/usd/writers/ov_post_processing.py +``` + +Suggested public API: + +```python +def add_ov_post_processing( + stage, + render_product_entries, + post_processing, + dataset, + mode: str, +) -> None: + ... +``` + +Responsibilities of the dedicated file: + +- Validate that `mode` is one of the supported `ov-post-processing` values. +- Validate that `post_processing` is a PPISP instance for PPISP-derived modes. +- Author USD post-processing API schemas and attributes on `RenderProduct` + prims. +- Author camera exposure attributes for the exposure fallback. +- Fit or consume fitted parameters for vignette, color grading, and CRF + approximation. +- Log an explicit warning when falling back to degraded behavior. + +Responsibilities that should stay outside the dedicated file: + +- Camera prim creation. +- `/Render` scope and `RenderProduct` creation. +- SPG shader authoring and sidecar packaging. +- Exporter config parsing beyond passing the selected mode. + +Feasibility by mode: + +- `none`: high feasibility. Uses the existing full SPG PPISP path when + `export_ppisp` is true. +- `ppisp-exposure-fallback`: high feasibility. Requires camera prims and time-sampled + exposure authoring only. +- `ppisp-fitted-post-processing-fallback`: medium feasibility. USD authoring is straightforward, but + fitting PPISP vignetting/color/CRF into USD controls needs validation and may + have visible residuals. +- `ppisp-spg-plus-fitted-post-processing-fallback`: high feasibility after SPG and fallback paths both + exist. 
It is mostly orchestration and validation. + +Primary risks: + +- Current `USDExporter` exports one camera per frame; the SPG plan already notes + this must become one prim per physical camera with time-sampled transforms + before per-camera `RenderProduct` post-processing can be cleanly authored. +- The current exporter does not create `/Render` or `RenderProduct` prims. + The fallback depends on the same `render_product.py` foundation as the SPG + plan. +- Time-sampled `RenderProduct` USD attributes need validation in the target Kit + versions. +- The fallback is approximate by design. Documentation and logs must make this + visible so users do not mistake it for SPG fidelity. + +Overall recommendation: implement the feature as a small orchestrated extension +after the shared camera and `RenderProduct` groundwork from the SPG plan. Keep +`export_ppisp` disabled by default. When `export_ppisp` is enabled, use +`ov-post-processing` to choose between full SPG and explicit fallback modes. + +Execution dependency: wait for the SPG implementation to land, then add the OV +post-processing writer as a follow-up layer that reuses the SPG path's camera, +time-code, and `RenderProduct` infrastructure. + +Confidence: 0.8 for exporter/config feasibility, 0.45 for final visual fidelity +of the full PPISP approximation. + +--- + +## 7. Recommended architecture if approved later + +Use `export_ppisp` as the PPISP export gate and `ov-post-processing` as the +implementation selector: + +```text +export_usd: + export_ppisp: false + ov-post-processing: none +``` + +Behavior: + +- `export_ppisp: false`: no PPISP export. `ov-post-processing` must be `none`. +- `export_ppisp: true`, `ov-post-processing: none`: export PPISP through the + full SPG path. +- `export_ppisp: true`, `ov-post-processing: ppisp-exposure-fallback`: export + PPISP through camera exposure fallback only. 
+- `export_ppisp: true`, `ov-post-processing: ppisp-fitted-post-processing-fallback`: + export PPISP through fitted Omniverse USD post-processing fallback only. +- `export_ppisp: true`, `ov-post-processing: ppisp-spg-plus-fitted-post-processing-fallback`: + write both the SPG exact path and the fitted USD fallback attributes for + validation or mixed-version deployment. + +Keep the approximation code isolated from the exact SPG writer: + +```text +threedgrut/export/usd/ + writers/ + ov_post_processing.py + ppisp_writer.py +``` + +Rationale: + +- The USD mapping is a fitted approximation, not a semantic equivalent of + PPISP. +- The USD fallback exists for deployment compatibility with older or buggy Kit + SPG support, not to displace the high-fidelity SPG path. +- Keeping a separate backend makes review easier and prevents silent quality + regressions in the exact export path. + +--- + +## 8. Open questions + +- Which Kit versions need the USD fallback because SPG is unavailable? +- Which known SPG bugs should trigger or recommend the USD fallback path? +- What error threshold is acceptable for a degraded fallback export? +- Should the approximation target linear floating-point `LdrColor`, gamma + output, or Kit viewport screenshots? +- Are time-sampled `RenderProduct` post-processing attributes supported and + stable in the target Kit version? +- Should Gaussian skip-tonemapping be disabled for PPISP USD approximation, or + is the exported Gaussian material already authored for that path? +--- + +## 9. Current recommendation + +Use USD post-processing only as an explicit fallback alternative for older +Kit versions or known SPG failure modes. + +Exact PPISP export should remain SPG-based for Kit versions where SPG is +available and reliable. The USD fallback path should be version-gated, labeled +approximate, and validated against SPG/reference PPISP before use on target +datasets. 
+ +Recommended next step: review this document and edit the task list or thresholds +before any implementation begins. diff --git a/pyproject.toml b/pyproject.toml index 797435ba..0cde926b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dependencies = [ # # slangtorch on amd64 only # "slangtorch==1.3.18; sys_platform == 'linux' and platform_machine == 'x86_64'", # usd-core only available for amd64 - "usd-core>=26.3; sys_platform == 'linux' and platform_machine == 'x86_64'", + "usd-core>=26.5; sys_platform == 'linux' and platform_machine == 'x86_64'", ] [project.optional-dependencies] @@ -63,6 +63,7 @@ dev = [ "clang-format==18.1.8", ] gui = [ + "flip-evaluator", "libigl", "polyscope>=2.6.0", "viser", diff --git a/requirements.txt b/requirements.txt index 928e9c01..d0c73377 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ libigl pygltflib # --find-links https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.2_cu118.html # kaolin==0.17.0 -usd-core>=26.3 +usd-core>=26.5 ppisp @ git+https://github.com/nv-tlabs/ppisp@v1.0.1 # NCore dataset support (https://github.com/NVIDIA/ncore) nvidia-ncore>=19.0.0 @@ -39,3 +39,4 @@ isort==5.13.0 # graphics user interfaces polyscope>=2.3.0 viser +flip-evaluator diff --git a/threedgrut/datasets/__init__.py b/threedgrut/datasets/__init__.py index 64317f59..bd450072 100644 --- a/threedgrut/datasets/__init__.py +++ b/threedgrut/datasets/__init__.py @@ -184,6 +184,73 @@ def make(name: str, config, ray_jitter): return train_dataset, val_dataset +def make_train(name: str, config, ray_jitter=None): + match name: + case "nerf": + dataset = NeRFDataset( + config.path, + split="train", + bg_color=config.model.background.color, + ray_jitter=ray_jitter, + ) + case "colmap": + # Load EXIF exposure data if enabled + if config.dataset.get("load_exif", True): + exif_exposures = _load_colmap_exif_exposures( + config.path, + config.dataset.downsample_factor, + ) + else: + exif_exposures = None + + dataset = 
ColmapDataset( + config.path, + split="train", + downsample_factor=config.dataset.downsample_factor, + test_split_interval=config.dataset.test_split_interval, + ray_jitter=ray_jitter, + exif_exposures=exif_exposures, + ) + case "scannetpp": + dataset = ScannetppDataset( + config.path, + split="train", + ray_jitter=ray_jitter, + downsample_factor=config.dataset.downsample_factor, + test_split_interval=config.dataset.test_split_interval, + ) + case "ncore": + dataset = NCoreDataset( + datapath=config.path, + device="cuda", + split="train", + camera_ids=config.dataset.get("camera_ids", None), + lidar_ids=config.dataset.get("lidar_ids", None), + downsample=config.dataset.get("downsample", 1.0), + sample_full_image=config.dataset.train.get("sample_full_image", True), + window_size=config.dataset.train.get("window_size", 256), + n_samples_per_epoch=config.dataset.train.get("n_samples_per_epoch", 1000), + n_train_sample_timepoints=config.dataset.train.get("n_train_sample_timepoints", 1), + n_train_sample_camera_rays=config.dataset.train.get("n_train_sample_camera_rays", 4096), + n_val_image_subsample=config.dataset.get("n_val_image_subsample", 1), + val_frame_interval=config.dataset.get("val_frame_interval", 8), + seek_offset_sec=config.dataset.train.get("seek_offset_sec", 0.0), + duration_sec=config.dataset.train.get("duration_sec", None), + poses_component_group=config.dataset.get("poses_component_group", "default"), + intrinsics_component_group=config.dataset.get("intrinsics_component_group", "default"), + masks_component_group=config.dataset.get("masks_component_group", "default"), + jpeg_backend_cpu=config.dataset.get("jpeg_backend_cpu", "simplejpeg"), + simplejpeg_fastdct=config.dataset.get("simplejpeg_fastdct", False), + simplejpeg_fastupsample=config.dataset.get("simplejpeg_fastupsample", False), + lidar_color_generic_data_name=config.dataset.get("lidar_color_generic_data_name", "rgb"), + ) + case _: + raise ValueError( + f'Unsupported dataset type: 
def _discover_gaussians_nodes_prefixes(state: dict) -> list[str]:
    """Return the sorted state_dict prefixes that carry a complete Gaussian tensor set.

    A prefix (e.g. '.gaussians_nodes.background') qualifies when:

    * ``state`` contains the string key ``<prefix>.positions``,
    * the prefix starts with ``_GAUSSIANS_NODES_PREFIX``, and
    * every suffix in ``_REQUIRED_GAUSSIAN_NODE_KEYS`` maps to a non-``None``
      entry under ``<prefix>.<suffix>``.

    Non-string keys are ignored. Duplicates are removed before sorting.
    """
    marker = ".positions"

    def _has_all_tensors(node_prefix: str) -> bool:
        # A node is usable only if none of the required per-node entries is missing.
        for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS:
            if state.get(f"{node_prefix}.{suffix}") is None:
                return False
        return True

    # Strip the trailing ".positions" from every matching key to get candidate prefixes.
    candidates = {
        key[: -len(marker)]
        for key in state
        if isinstance(key, str) and key.endswith(marker)
    }
    return sorted(
        node_prefix
        for node_prefix in candidates
        if node_prefix.startswith(_GAUSSIANS_NODES_PREFIX) and _has_all_tensors(node_prefix)
    )
blocks if present.""" + prefixes = _discover_gaussians_nodes_prefixes(state) + if not prefixes: + positions = _tensor_from_state(state, _STATE_POSITIONS) + rotations = _tensor_from_state(state, _STATE_ROTATIONS) + scales = _tensor_from_state(state, _STATE_SCALES) + densities = _tensor_from_state(state, _STATE_DENSITIES) + features_albedo = _tensor_from_state(state, _STATE_FEATURES_ALBEDO) + features_specular = _tensor_from_state(state, _STATE_FEATURES_SPECULAR) + n_active = state.get(_STATE_N_ACTIVE) + n_active_vals = None + if n_active is not None: + n_active_vals = [int(np.frombuffer(n_active, dtype=np.int64)[0])] + return ( + positions, + rotations, + scales, + densities, + features_albedo, + features_specular, + n_active_vals, + ) + + chunks: dict[str, list[np.ndarray]] = {k: [] for k in _REQUIRED_GAUSSIAN_NODE_KEYS} + n_active_per_node: list[int] = [] + counts: list[tuple[str, int]] = [] + for pref in prefixes: + for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS: + chunks[suffix].append(_tensor_from_state(state, f"{pref}.{suffix}")) + na_key = f"{pref}.n_active_features" + na_raw = state.get(na_key) + if na_raw is not None: + n_active_per_node.append(int(np.frombuffer(na_raw, dtype=np.int64)[0])) + counts.append((pref, int(chunks["positions"][-1].shape[0]))) + + for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS: + ref_tail = chunks[suffix][0].shape[1:] + for i, arr in enumerate(chunks[suffix][1:], start=1): + if arr.shape[1:] != ref_tail: + raise ValueError( + f"NuRec state_dict: incompatible '{suffix}' trailing dims across nodes " + f"({prefixes[0]} {ref_tail} vs {prefixes[i]} {arr.shape[1:]})" + ) + + merged = {suffix: np.concatenate(chunks[suffix], axis=0) for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS} + logger.info( + "NuRec: merged %d Gaussian node(s) %s", + len(prefixes), + ", ".join(f"{p}={n}" for p, n in counts), + ) + return ( + merged["positions"], + merged["rotations"], + merged["scales"], + merged["densities"], + merged["features_albedo"], + 
merged["features_specular"], + n_active_per_node if n_active_per_node else None, + ) + + def _rotation_matrix_to_quat_wxyz(R: np.ndarray) -> np.ndarray: """Convert 3x3 rotation matrix to wxyz quaternion (one quat).""" trace = R[0, 0] + R[1, 1] + R[2, 2] @@ -248,16 +337,15 @@ def _load_stage(self, stage_path: Path, resolution_root: Path) -> Tuple[Gaussian raw = _load_nurec_bytes(resolution_root, nurec_path) state = _decode_state_dict(raw) - positions = _tensor_from_state(state, _STATE_POSITIONS) - rotations = _tensor_from_state(state, _STATE_ROTATIONS) - scales = _tensor_from_state(state, _STATE_SCALES) - densities = _tensor_from_state(state, _STATE_DENSITIES) - features_albedo = _tensor_from_state(state, _STATE_FEATURES_ALBEDO) - features_specular = _tensor_from_state(state, _STATE_FEATURES_SPECULAR) + positions, rotations, scales, densities, features_albedo, features_specular, n_active_list = ( + _load_merged_gaussian_tensors_from_state(state) + ) - n_active = state.get(_STATE_N_ACTIVE) - if n_active is not None: - sh_degree = int(np.frombuffer(n_active, dtype=np.int64)[0]) + if n_active_list is not None: + unique_deg = set(n_active_list) + if len(unique_deg) > 1: + logger.warning("NuRec nodes disagree on n_active_features %s; using max", n_active_list) + sh_degree = max(n_active_list) else: # Infer from features_specular shape: (N, (degree+1)^2 - 1) * 3 n_spec = features_specular.shape[1] diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index a696aabb..ce030396 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -21,10 +21,12 @@ python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz # Export with NuRec format (Omniverse compatibility) - python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz --format nurec + python -m threedgrut.export.scripts.export_usd --checkpoint 
path/to/checkpoint.pt \ + --output output.usdz --format nurec # Export without cameras/background - python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz --no-cameras --no-background + python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt \ + --output output.usdz --no-cameras --no-background """ import argparse @@ -35,6 +37,10 @@ import torch from threedgrut.export import NuRecExporter, USDExporter +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + PARTICLE_FIELD_SORTING_MODE_HINTS, +) from threedgrut.utils.logger import logger @@ -122,6 +128,144 @@ def parse_args(): action="store_true", help="Set prim color space to lin_rec709_scene (linear). Default is srgb_rec709_display.", ) + parser.add_argument( + "--sorting-mode-hint", + type=str, + choices=PARTICLE_FIELD_SORTING_MODE_HINTS, + default=None, + help=( + "ParticleField sortingModeHint for standard USD export. " + "Use rayHitDistance for ray-tracing renderers that support ray-hit sorting." + ), + ) + post_processing_group = parser.add_mutually_exclusive_group() + post_processing_group.add_argument( + "--export-post-processing", + dest="export_post_processing", + action="store_true", + default=None, + help="Export post-processing effects when the checkpoint contains a supported post-processing module.", + ) + post_processing_group.add_argument( + "--no-export-post-processing", + dest="export_post_processing", + action="store_false", + help="Skip post-processing export even when the checkpoint contains a supported post-processing module.", + ) + parser.add_argument( + "--post-processing-export-mode", + type=str, + choices=["baked-sh", "omni-native"], + default=None, + help="Post-processing export mode. 
'omni-native' uses PPISP SPG and Omniverse material authoring.", + ) + parser.add_argument( + "--post-processing-export-camera-id", + type=int, + default=None, + help="Optional PPISP camera id to use for every RenderProduct in omni-native export.", + ) + parser.add_argument( + "--post-processing-export-frame-id", + type=int, + default=None, + help="Optional PPISP frame id to write as static omni-native shader inputs instead of animation.", + ) + parser.add_argument( + "--ignore-ppisp-controller", + action="store_true", + help=( + "If the checkpoint contains trained PPISP controllers, ignore them and " + "export the optimized per-frame exposure/color parameters as time-sampled " + "USD attributes instead. Has no effect when the checkpoint has no controllers." + ), + ) + parser.add_argument( + "--post-processing-bake-epochs", + type=int, + default=None, + help="Number of sequential passes over the train/reference set for post-processing baked-SH export.", + ) + parser.add_argument( + "--post-processing-bake-learning-rate", + type=float, + default=None, + help="Adam learning rate for features_albedo (default 2.5e-3, matches 3DGS).", + ) + parser.add_argument( + "--post-processing-bake-learning-rate-specular", + type=float, + default=None, + help="Adam learning rate for features_specular (default = albedo lr / 20, matches 3DGS).", + ) + parser.add_argument( + "--post-processing-bake-learning-rate-density", + type=float, + default=None, + help="Adam learning rate for density (default 5e-2, matches 3DGS).", + ) + parser.add_argument( + "--post-processing-bake-camera-id", + type=int, + default=None, + help="Camera id used by the fixed post-processing baked-SH export.", + ) + parser.add_argument( + "--post-processing-bake-frame-id", + type=int, + default=None, + help="Frame id used by the fixed post-processing baked-SH export.", + ) + parser.add_argument( + "--ppisp-bake-vignetting-mode", + type=str, + choices=["none", "achromatic-fit"], + default=None, + help=( + 
"Vignetting handling for PPISP baked-SH fitting. 'none' disables PPISP vignetting; " + "'achromatic-fit' uses chromatic PPISP reference and an achromatic fit-only vignette." + ), + ) + parser.add_argument( + "--post-processing-bake-view-mode", + type=str, + choices=["training", "trajectory"], + default=None, + help=( + "Which views the bake fit sees per step. 'training' (default) iterates the train " + "dataloader. 'trajectory' orders views along an NN+2-opt camera path and samples " + "random t in [0,1] -- useful when training views are sparse." + ), + ) + parser.add_argument( + "--post-processing-bake-view-seed", + type=int, + default=None, + help="Optional RNG seed for the interpolation samplers (None = non-deterministic).", + ) + parser.add_argument( + "--post-processing-bake-trajectory-weight-position", + type=float, + default=None, + help="Trajectory mode only: weight on the (mean-normalised) position term in pose distance.", + ) + parser.add_argument( + "--post-processing-bake-trajectory-weight-rotation", + type=float, + default=None, + help="Trajectory mode only: weight on the (1 - cos(angle)) rotation term in pose distance.", + ) + parser.add_argument( + "--output-scale", + type=float, + default=None, + help=( + "Multiplicative scale applied to the SH-evaluated RGB output of the " + "exported asset. Default 1.0 (no-op). The DC offset is compensated so " + "rendered output equals output-scale x original eval. Useful for " + "matching downstream tonemap exposure." 
def _load_ppisp_from_checkpoint(checkpoint, conf):
    """Rebuild the trained PPISP module stored in a checkpoint for USD export.

    Args:
        checkpoint: Loaded checkpoint mapping. The PPISP state is read from
            ``checkpoint["post_processing"]["module"]``.
        conf: Training config. Its ``post_processing`` section and
            ``n_iterations`` value are read.

    Returns:
        The PPISP module moved to the CPU, or ``None`` when the checkpoint has
        no ``"post_processing"`` entry, the config has no ``post_processing``
        section, the configured method is not ``"ppisp"``, or the ``ppisp``
        package cannot be imported.
    """
    pp_settings = getattr(conf, "post_processing", None)
    if "post_processing" not in checkpoint:
        return None
    if pp_settings is None:
        return None
    if getattr(pp_settings, "method", None) != "ppisp":
        return None

    # ppisp is an optional dependency, so import it only once we know it is needed.
    try:
        from ppisp import PPISP, PPISPConfig
    except ImportError:
        logger.warning("Checkpoint contains PPISP state, but ppisp is not available; skipping PPISP USD export")
        return None

    controller_enabled = pp_settings.get("use_controller", True)
    distillation_steps = pp_settings.get("n_distillation_steps", 5000)

    distill_controller = False
    if not controller_enabled:
        activation_ratio = 0.0
    elif distillation_steps > 0:
        # The controller activates after the main (non-distillation) share of training.
        activation_ratio = (conf.n_iterations - distillation_steps) / conf.n_iterations
        distill_controller = True
    else:
        activation_ratio = 0.8

    module_state = checkpoint["post_processing"]["module"]
    restored = PPISP.from_state_dict(
        module_state,
        config=PPISPConfig(
            use_controller=controller_enabled,
            controller_distillation=distill_controller,
            controller_activation_ratio=activation_ratio,
        ),
    ).to("cpu")
    logger.info("Loaded PPISP post-processing state for USD export")
    return restored
def _get_export_post_processing_default(export_conf):
    """Return the configured default for post-processing export.

    The ``export_post_processing`` attribute is looked up first and falls back
    to ``True``. When ``export_conf`` also exposes a ``get`` method, the dashed
    key ``"export-post-processing"`` takes precedence and the attribute value
    serves as its fallback.
    """
    has_mapping_api = hasattr(export_conf, "get")
    attribute_value = getattr(export_conf, "export_post_processing", True)
    if not has_mapping_api:
        return attribute_value
    return export_conf.get("export-post-processing", attribute_value)
Is 3DGRUT properly installed?") @@ -197,9 +404,24 @@ def main(): traceback.print_exc() sys.exit(1) - # Load dataset for camera export + export_conf = getattr(conf, "export_usd", None) or conf + if args.export_post_processing is not None: + export_post_processing = args.export_post_processing + elif post_processing is not None: + export_post_processing = True + else: + export_post_processing = bool(_get_export_post_processing_default(export_conf)) + post_processing_export_mode = _arg_or_conf( + args.post_processing_export_mode, + export_conf, + "post-processing-export-mode", + "post_processing_export_mode", + "baked-sh", + ) + # Load dataset for camera export and for train-split post-processing SH baking. dataset = None - if not args.no_cameras: + needs_dataset = not args.no_cameras or (post_processing is not None and export_post_processing) + if needs_dataset: try: import threedgrut.datasets as datasets @@ -214,15 +436,16 @@ def main(): elif not hasattr(conf, "dataset") or not hasattr(conf.dataset, "type"): logger.warning("No dataset type in checkpoint config. 
Cannot load dataset for camera export.") else: - dataset = datasets.make_test(name=conf.dataset.type, config=conf) + dataset = datasets.make_train(name=conf.dataset.type, config=conf, ray_jitter=None) split = getattr(dataset, "split", "unknown") logger.info(f"Loaded dataset with {len(dataset)} frames for camera export (split={split})") except Exception as e: - logger.warning(f"Failed to load dataset for camera export: {e}") + logger.error(f"Failed to load dataset for camera export: {e}") if args.verbose: import traceback traceback.print_exc() + sys.exit(1) # Create exporter based on format if args.format == "nurec": @@ -237,18 +460,132 @@ def main(): export_cameras=not args.no_cameras, export_background=not args.no_background, apply_normalizing_transform=not args.no_transform, - linear_srgb=args.linear_srgb, + sorting_mode_hint=_arg_or_conf( + args.sorting_mode_hint, + export_conf, + "sorting-mode-hint", + "sorting_mode_hint", + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + ), + linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), + export_post_processing=export_post_processing, + post_processing_export_mode=post_processing_export_mode, + post_processing_export_camera_id=_arg_or_conf( + args.post_processing_export_camera_id, + export_conf, + "post-processing-export-camera-id", + "post_processing_export_camera_id", + None, + ), + post_processing_export_frame_id=_arg_or_conf( + args.post_processing_export_frame_id, + export_conf, + "post-processing-export-frame-id", + "post_processing_export_frame_id", + None, + ), + ignore_ppisp_controller=args.ignore_ppisp_controller, + post_processing_bake_epochs=_arg_or_conf( + args.post_processing_bake_epochs, + export_conf, + "post-processing-bake-epochs", + "post_processing_bake_epochs", + 7, + ), + post_processing_bake_learning_rate=_arg_or_conf( + args.post_processing_bake_learning_rate, + export_conf, + "post-processing-bake-learning-rate", + "post_processing_bake_learning_rate", + 2.5e-3, + ), + 
post_processing_bake_learning_rate_specular=_arg_or_conf( + args.post_processing_bake_learning_rate_specular, + export_conf, + "post-processing-bake-learning-rate-specular", + "post_processing_bake_learning_rate_specular", + None, + ), + post_processing_bake_learning_rate_density=_arg_or_conf( + args.post_processing_bake_learning_rate_density, + export_conf, + "post-processing-bake-learning-rate-density", + "post_processing_bake_learning_rate_density", + 5.0e-2, + ), + post_processing_bake_camera_id=_arg_or_conf( + args.post_processing_bake_camera_id, + export_conf, + "post-processing-bake-camera-id", + "post_processing_bake_camera_id", + 0, + ), + post_processing_bake_frame_id=_arg_or_conf( + args.post_processing_bake_frame_id, + export_conf, + "post-processing-bake-frame-id", + "post_processing_bake_frame_id", + 0, + ), + ppisp_bake_vignetting_mode=_arg_or_conf( + args.ppisp_bake_vignetting_mode, + export_conf, + "ppisp-bake-vignetting-mode", + "ppisp_bake_vignetting_mode", + "achromatic-fit", + ), + post_processing_bake_view_mode=_arg_or_conf( + args.post_processing_bake_view_mode, + export_conf, + "post-processing-bake-view-mode", + "post_processing_bake_view_mode", + "trajectory", + ), + post_processing_bake_view_seed=_arg_or_conf( + args.post_processing_bake_view_seed, + export_conf, + "post-processing-bake-view-seed", + "post_processing_bake_view_seed", + None, + ), + post_processing_bake_trajectory_weight_position=_arg_or_conf( + args.post_processing_bake_trajectory_weight_position, + export_conf, + "post-processing-bake-trajectory-weight-position", + "post_processing_bake_trajectory_weight_position", + 1.0, + ), + post_processing_bake_trajectory_weight_rotation=_arg_or_conf( + args.post_processing_bake_trajectory_weight_rotation, + export_conf, + "post-processing-bake-trajectory-weight-rotation", + "post_processing_bake_trajectory_weight_rotation", + 0.5, + ), + output_scale=_arg_or_conf( + args.output_scale, + export_conf, + "output-scale", + 
"output_scale", + 1.0, + ), + frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) logger.info("Using ParticleField3DGaussianSplat schema (standard)") # Export try: + export_kw = {} + if args.format == "standard": + export_kw["validate_usd"] = not args.no_usd_validate exporter.export( model=model, output_path=output_path, dataset=dataset, conf=conf, background=background, + post_processing=post_processing, + **export_kw, ) logger.info(f"Export successful: {output_path}") except Exception as e: diff --git a/threedgrut/export/scripts/post_processing_sh_bake_validation.py b/threedgrut/export/scripts/post_processing_sh_bake_validation.py new file mode 100644 index 00000000..69cedff5 --- /dev/null +++ b/threedgrut/export/scripts/post_processing_sh_bake_validation.py @@ -0,0 +1,493 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate baking one fixed PPISP transform into Gaussian SH coefficients. + +The reference is the checkpoint render followed by PPISP from one camera/frame, +including that camera's chromatic vignetting. The fitted method optimizes only a +cloned model's SH coefficients, with a temporary achromatic vignette applied in +the fitting loss to isolate chromatic vignette effects. 
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Dict, Iterable + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from torchmetrics import PeakSignalNoiseRatio +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + +import threedgrut.datasets as datasets +from threedgrut.render import Renderer +from threedgrut.datasets.utils import configure_dataloader_for_platform +from threedgrut.export.usd.post_processing_sh_bake import ( + MODE_PPISP_BAKE_VIGNETTING_NONE, + FixedPPISP, + apply_achromatic_vignetting, + normalize_ppisp_bake_vignetting_mode, +) +from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake +from threedgrut.utils.logger import logger +from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb +from threedgrut.utils.render import apply_post_processing + +BAKE_FLAVOR_FIT = "fit" +BAKE_FLAVOR_SIMPLE = "simple" +BAKE_FLAVOR_SIMPLE_HIGHER_ORDER = "simple-higher-order" +BAKE_FLAVOR_ALL = "all" + + +def _setShFitParameters(model) -> Iterable[torch.nn.Parameter]: + for parameter in model.parameters(): + parameter.requires_grad_(False) + + fitParameters = [] + for fieldName in ("features_albedo", "features_specular"): + parameter = getattr(model, fieldName) + parameter.requires_grad_(True) + fitParameters.append(parameter) + return fitParameters + + +def _renderReference(referenceModel, fixedPpisp, gpuBatch) -> torch.Tensor: + with torch.no_grad(): + outputs = referenceModel(gpuBatch) + outputs = apply_post_processing(fixedPpisp, outputs, gpuBatch, training=True) + return outputs["pred_rgb"].detach() + + +def _applyAchromaticVignetting(rgb: torch.Tensor, fixedPpisp, gpuBatch, vignettingMode: str) -> torch.Tensor: + if vignettingMode == 
def _createTrainDataloader(conf):
    """Create the train-split dataset and a DataLoader over it.

    The dataset comes from ``datasets.make_train`` with ray jitter disabled.
    The loader options (batch size 1, shuffling, pinned memory, and persistent
    workers only when ``conf.num_workers`` is positive) are passed through
    ``configure_dataloader_for_platform`` before the loader is constructed.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    trainSplit = datasets.make_train(name=conf.dataset.type, config=conf, ray_jitter=None)

    workerCount = conf.num_workers
    requestedOptions = dict(
        num_workers=workerCount,
        batch_size=1,
        shuffle=True,
        pin_memory=True,
        # Persistent workers are only requested when worker processes exist.
        persistent_workers=bool(workerCount > 0),
    )
    platformOptions = configure_dataloader_for_platform(requestedOptions)
    return trainSplit, torch.utils.data.DataLoader(trainSplit, **platformOptions)
logger.end_progress(task_name="Fitting baked SH") + + +@torch.no_grad() +def _evaluateBakedSh( + referenceModel, + bakedModel, + simpleBakedModels: Dict[str, nn.Module], + fixedPpisp, + fullFixedPpisp, + dataset, + dataloader, + outputRoot: Path, + computeExtraMetrics: bool, + vignettingMode: str, +) -> dict: + criterions = {"psnr": PeakSignalNoiseRatio(data_range=1).to("cuda")} + if computeExtraMetrics: + criterions |= { + "ssim": StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda"), + "lpips": LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to("cuda"), + } + + fullReferencePath = outputRoot / "full_ppisp_reference" + referencePath = outputRoot / "reference" + unfittedPath = outputRoot / "unfitted" + fullReferencePath.mkdir(parents=True, exist_ok=True) + referencePath.mkdir(parents=True, exist_ok=True) + unfittedPath.mkdir(parents=True, exist_ok=True) + bakedPath = outputRoot / "baked" if bakedModel is not None else None + assistedPath = outputRoot / "baked_assisted" if bakedModel is not None else None + if bakedPath is not None: + bakedPath.mkdir(parents=True, exist_ok=True) + if assistedPath is not None: + assistedPath.mkdir(parents=True, exist_ok=True) + simplePaths = {name: outputRoot / f"{name}_baked" for name in simpleBakedModels} + for simplePath in simplePaths.values(): + simplePath.mkdir(parents=True, exist_ok=True) + + unfittedPsnrValues = [] + psnrValues = [] + ssimValues = [] + lpipsValues = [] + assistedPsnrValues = [] + assistedSsimValues = [] + assistedLpipsValues = [] + inferenceTimeValues = [] + simpleMetricValues = { + name: { + "psnr": [], + "ssim": [], + "lpips": [], + } + for name in simpleBakedModels + } + + logger.start_progress(task_name="Evaluating baked SH", total_steps=len(dataloader), color="orange1") + for iteration, batch in enumerate(dataloader): + gpuBatch = dataset.get_gpu_batch_with_intrinsics(batch) + + fullReferenceRgb = _renderReference(referenceModel, fullFixedPpisp, gpuBatch) + referenceRgb = 
_renderReference(referenceModel, fixedPpisp, gpuBatch) + unfittedOutputs = referenceModel(gpuBatch) + unfittedRgb = unfittedOutputs["pred_rgb"] + + torchvision.utils.save_image( + fullReferenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + fullReferencePath / f"{iteration:05d}.png", + ) + torchvision.utils.save_image( + referenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + referencePath / f"{iteration:05d}.png", + ) + torchvision.utils.save_image( + unfittedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + unfittedPath / f"{iteration:05d}.png", + ) + + unfittedPsnrValues.append(criterions["psnr"](unfittedRgb, referenceRgb).item()) + + if bakedModel is not None: + bakedOutputs = bakedModel(gpuBatch) + bakedRgb = torch.clamp(linear_to_srgb(bakedOutputs["pred_rgb"]), 0, 1) + assistedRgb = torch.clamp( + linear_to_srgb( + _applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode) + ), + 0, + 1, + ) + torchvision.utils.save_image( + bakedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + bakedPath / f"{iteration:05d}.png", + ) + torchvision.utils.save_image( + assistedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + assistedPath / f"{iteration:05d}.png", + ) + + psnrValues.append(criterions["psnr"](bakedRgb, referenceRgb).item()) + assistedPsnrValues.append(criterions["psnr"](assistedRgb, referenceRgb).item()) + if computeExtraMetrics: + ssimValues.append( + criterions["ssim"](bakedRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + lpipsValues.append( + criterions["lpips"]( + bakedRgb.clip(0, 1).permute(0, 3, 1, 2), referenceRgb.clip(0, 1).permute(0, 3, 1, 2) + ).item() + ) + assistedSsimValues.append( + criterions["ssim"](assistedRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + assistedLpipsValues.append( + criterions["lpips"]( + assistedRgb.clip(0, 1).permute(0, 3, 1, 2), referenceRgb.clip(0, 1).permute(0, 3, 1, 2) + ).item() + ) + + if "frame_time_ms" in bakedOutputs: + 
inferenceTimeValues.append(bakedOutputs["frame_time_ms"]) + + for simpleName, simpleModel in simpleBakedModels.items(): + simpleOutputs = simpleModel(gpuBatch) + simpleRgb = torch.clamp(simpleOutputs["pred_rgb"], 0, 1) + torchvision.utils.save_image( + simpleRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + simplePaths[simpleName] / f"{iteration:05d}.png", + ) + simpleValues = simpleMetricValues[simpleName] + simpleValues["psnr"].append(criterions["psnr"](simpleRgb, referenceRgb).item()) + if computeExtraMetrics: + simpleValues["ssim"].append( + criterions["ssim"](simpleRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + simpleValues["lpips"].append( + criterions["lpips"]( + simpleRgb.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + + progressPsnr = psnrValues[-1] if psnrValues else unfittedPsnrValues[-1] + logger.log_progress(task_name="Evaluating baked SH", advance=1, iteration=str(iteration), psnr=progressPsnr) + logger.end_progress(task_name="Evaluating baked SH") + + metrics = { + "vignetting_mode": vignettingMode, + "unfitted_mean_psnr": float(np.mean(unfittedPsnrValues)), + "unfitted_std_psnr": float(np.std(unfittedPsnrValues)), + } + if psnrValues: + metrics |= { + "mean_psnr": float(np.mean(psnrValues)), + "std_psnr": float(np.std(psnrValues)), + "assisted_mean_psnr": float(np.mean(assistedPsnrValues)), + "assisted_std_psnr": float(np.std(assistedPsnrValues)), + } + if computeExtraMetrics: + if ssimValues: + metrics |= { + "mean_ssim": float(np.mean(ssimValues)), + "mean_lpips": float(np.mean(lpipsValues)), + "assisted_mean_ssim": float(np.mean(assistedSsimValues)), + "assisted_mean_lpips": float(np.mean(assistedLpipsValues)), + } + for simpleName, simpleValues in simpleMetricValues.items(): + metrics[f"{simpleName}_mean_psnr"] = float(np.mean(simpleValues["psnr"])) + metrics[f"{simpleName}_std_psnr"] = float(np.std(simpleValues["psnr"])) + if computeExtraMetrics: + metrics |= { + 
f"{simpleName}_mean_ssim": float(np.mean(simpleValues["ssim"])), + f"{simpleName}_mean_lpips": float(np.mean(simpleValues["lpips"])), + } + if inferenceTimeValues: + metrics["mean_inference_time"] = f"{np.mean(inferenceTimeValues):.2f} ms/frame" + + with open(outputRoot / "metrics.json", "w") as file: + json.dump(metrics, file, indent=2) + + psnrMetrics = {key: value for key, value in metrics.items() if "psnr" in key} + logger.log_table("Post-Processing SH Bake Validation PSNR", record=psnrMetrics) + return metrics + + +def _validateArguments(args, ppisp: nn.Module) -> None: + if not hasattr(ppisp, "vignetting_params"): + raise ValueError("Checkpoint post-processing is not PPISP-like: missing vignetting_params.") + if not hasattr(ppisp, "exposure_params") or not hasattr(ppisp, "crf_params"): + raise ValueError("Checkpoint post-processing is not PPISP-like: missing exposure_params or crf_params.") + + numFrames = int(ppisp.exposure_params.shape[0]) + numCameras = int(ppisp.crf_params.shape[0]) + if args.frameId < 0 or args.frameId >= numFrames: + raise ValueError(f"frameId must be in [0, {numFrames - 1}], got {args.frameId}.") + if args.cameraId < 0 or args.cameraId >= numCameras: + raise ValueError(f"cameraId must be in [0, {numCameras - 1}], got {args.cameraId}.") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", required=True, type=str, help="Path to the pretrained checkpoint.") + parser.add_argument("--path", type=str, default="", help="Path to test data, if not provided taken from ckpt.") + parser.add_argument("--out-dir", dest="outDir", required=True, type=str, help="Output path.") + parser.add_argument("--camera-id", dest="cameraId", default=0, type=int, help="PPISP camera id to bake.") + parser.add_argument("--frame-id", dest="frameId", default=0, type=int, help="PPISP frame id to bake.") + parser.add_argument( + "--fit-epochs", + dest="fitEpochs", + default=1, + type=int, + help="Number of sequential passes 
over the train/reference set.", + ) + parser.add_argument("--learning-rate", dest="learningRate", default=1.0e-3, type=float, help="SH fitting LR.") + parser.add_argument( + "--bake-flavor", + dest="bakeFlavor", + choices=[ + BAKE_FLAVOR_FIT, + BAKE_FLAVOR_SIMPLE, + BAKE_FLAVOR_SIMPLE_HIGHER_ORDER, + BAKE_FLAVOR_ALL, + ], + default=BAKE_FLAVOR_FIT, + help=( + "Bake flavor to evaluate. 'fit' optimizes SH; 'simple' one-shot bakes DC SH; " + "'simple-higher-order' also linearizes higher-order SH; 'all' compares every flavor." + ), + ) + parser.add_argument( + "--vignetting-mode", + dest="vignettingMode", + choices=["none", "achromatic-fit"], + default="achromatic-fit", + help=( + "Vignetting handling for the bake. 'none' disables PPISP vignetting; " + "'achromatic-fit' uses chromatic PPISP reference and an achromatic fit-only vignette." + ), + ) + parser.add_argument( + "--compute-extra-metrics", + dest="computeExtraMetrics", + action="store_false", + help="If set, extra image metrics will not be computed [True by default].", + ) + args = parser.parse_args() + + renderer = Renderer.from_checkpoint( + checkpoint_path=args.checkpoint, + path=args.path, + out_dir=args.outDir, + save_gt=False, + computes_extra_metrics=args.computeExtraMetrics, + ) + if renderer.post_processing is None: + raise ValueError("Checkpoint does not contain PPISP post-processing.") + + _validateArguments(args, renderer.post_processing) + vignettingMode = normalize_ppisp_bake_vignetting_mode(args.vignettingMode) + fixedPpisp = FixedPPISP( + renderer.post_processing, + args.cameraId, + args.frameId, + "cuda", + include_vignetting=vignettingMode != MODE_PPISP_BAKE_VIGNETTING_NONE, + ).eval() + fullFixedPpisp = FixedPPISP( + renderer.post_processing, + args.cameraId, + args.frameId, + "cuda", + include_vignetting=True, + ).eval() + + referenceModel = renderer.model.eval() + + outputRoot = Path(renderer.out_dir) / f"post_processing_sh_bake_ci{args.cameraId}_fi{args.frameId}" + 
outputRoot.mkdir(parents=True, exist_ok=True) + + trainDataset, trainDataloader = _createTrainDataloader(renderer.conf) + + runFit = args.bakeFlavor in (BAKE_FLAVOR_FIT, BAKE_FLAVOR_ALL) + simpleFlavorHigherOrderFlags = [] + if args.bakeFlavor in (BAKE_FLAVOR_SIMPLE, BAKE_FLAVOR_ALL): + simpleFlavorHigherOrderFlags.append(("simple", False)) + if args.bakeFlavor in (BAKE_FLAVOR_SIMPLE_HIGHER_ORDER, BAKE_FLAVOR_ALL): + simpleFlavorHigherOrderFlags.append(("simple_higher_order", True)) + + bakedModel = None + if runFit: + bakedModel = renderer.model.clone().eval() + bakedModel.build_acc() + logger.info(f"Fitting SH coefficients to fixed PPISP camera={args.cameraId} frame={args.frameId}") + _fitBakedSh( + referenceModel=referenceModel, + bakedModel=bakedModel, + fixedPpisp=fixedPpisp, + dataset=trainDataset, + dataloader=trainDataloader, + fitEpochs=args.fitEpochs, + learningRate=args.learningRate, + vignettingMode=vignettingMode, + ) + + simpleBakedModels = {} + for simpleName, higherOrder in simpleFlavorHigherOrderFlags: + simpleModel = renderer.model.clone().eval() + logger.info( + f"Simple-baking SH for camera_id={args.cameraId} " + f"frame_id={args.frameId} (fixed exposure/color; higher_order={higherOrder})" + ) + exposure, color = simple_bake( + model=simpleModel, + ppisp=renderer.post_processing, + camera_id=args.cameraId, + frame_id=args.frameId, + higher_order=higherOrder, + ) + simpleModel.build_acc() + simpleBakedModels[simpleName] = simpleModel + logger.info( + f"{simpleName} bake done. 
exposure={exposure:.6f}; " f"color={[float(value) for value in color.tolist()]}" + ) + + _evaluateBakedSh( + referenceModel=referenceModel, + bakedModel=bakedModel, + simpleBakedModels=simpleBakedModels, + fixedPpisp=fixedPpisp, + fullFixedPpisp=fullFixedPpisp, + dataset=renderer.dataset, + dataloader=renderer.dataloader, + outputRoot=outputRoot, + computeExtraMetrics=args.computeExtraMetrics, + vignettingMode=vignettingMode, + ) + + +if __name__ == "__main__": + main() diff --git a/threedgrut/export/scripts/transcode.py b/threedgrut/export/scripts/transcode.py index 9c19261b..c0b34a2d 100644 --- a/threedgrut/export/scripts/transcode.py +++ b/threedgrut/export/scripts/transcode.py @@ -25,6 +25,12 @@ python -m threedgrut.export.scripts.transcode input.ply -o output.usdz --format lightfield python -m threedgrut.export.scripts.transcode input.usdz -o output.ply python -m threedgrut.export.scripts.transcode nurec.usd -o lightfield.usdz --format lightfield + +USD/USDZ → LightField: source /World prims (e.g. rig_trajectories) and /Render +merge into default.usda at the same paths; referenced layers are bundled unchanged +(preserves camera animation curves and authored render products). +/World/Gaussians is skipped by default; use --copy-source-include-gaussians to merge it too. +Use --no-copy-source-prims to disable. 
""" import argparse @@ -32,6 +38,7 @@ import sys import tempfile import zipfile +from contextlib import nullcontext from pathlib import Path from typing import Optional, Tuple @@ -44,8 +51,13 @@ PLYImporter, USDImporter, ) +from threedgrut.export.usd.camera_copy import usd_stage_path_context_for_camera_copy from threedgrut.export.usd.exporter import USDExporter from threedgrut.export.usd.nurec.exporter import NuRecExporter +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + PARTICLE_FIELD_SORTING_MODE_HINTS, +) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -166,7 +178,11 @@ def get_exporter( export_cameras=False, export_background=False, apply_normalizing_transform=False, - sorting_mode_hint=render_order_hint if render_order_hint is not None else "cameraDistance", + sorting_mode_hint=( + render_order_hint + if render_order_hint is not None + else DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT + ), linear_srgb=linear_srgb, ), False, @@ -205,6 +221,9 @@ def transcode( apply_coordinate_transform: bool = False, render_order_hint: Optional[str] = None, linear_srgb: bool = False, + copy_cameras_source: Optional[Tuple[Path, Path]] = None, + copy_source_skip_subtrees: Optional[Tuple] = None, + validate_usd: bool = True, ) -> None: """Transcode between Gaussian splatting formats. @@ -219,6 +238,9 @@ def transcode( apply_coordinate_transform: Apply 3DGRUT-to-USDZ transform (for both lightfield and nurec) render_order_hint: If set, force sortingModeHint for lightfield only; ignored for other formats (warning logged). linear_srgb: If True, set prim color space to lin_rec709_scene (lightfield only). + copy_cameras_source: If set, (root_usd_path, asset_resolution_dir) to copy source /World prims from. + copy_source_skip_subtrees: Optional tuple of Sdf.Path roots to skip under /World (None = default skip Gaussians). 
+ validate_usd: If True and output is lightfield, run OpenUSD stage validation after export. """ if render_order_hint is not None and output_format != "lightfield": logger.warning( @@ -264,7 +286,14 @@ def transcode( # Export logger.info(f"Exporting to {output_path}...") - exporter.export(adapter, output_path, apply_coordinate_transform=apply_coordinate_transform) + exporter.export( + adapter, + output_path, + apply_coordinate_transform=apply_coordinate_transform, + copy_cameras_source=copy_cameras_source, + copy_source_skip_subtrees=copy_source_skip_subtrees, + validate_usd=validate_usd if output_format == "lightfield" else False, + ) logger.info(f"Transcode complete: {input_path} -> {output_path}") @@ -343,21 +372,47 @@ def parse_args(): parser.add_argument( "--render-order-hint", type=str, + choices=PARTICLE_FIELD_SORTING_MODE_HINTS, default=None, metavar="MODE", - help="Force sortingModeHint for lightfield export (e.g. cameraDistance, zDepth). Ignored with --format ply/nurec (warning only).", + help=( + "Force sortingModeHint for lightfield export " + "(zDepth, cameraDistance, rayHitDistance). Ignored with --format ply/nurec (warning only)." + ), ) parser.add_argument( "--linear-srgb", action="store_true", help="Set prim color space to lin_rec709_scene (lightfield only). 
Default is srgb_rec709_display.", ) + parser.add_argument( + "--no-copy-source-prims", + action="store_true", + dest="no_copy_source_prims", + help="When input is USD/USDZ and output is LightField, do not merge source /World prims into default.usda.", + ) + parser.add_argument( + "--no-copy-source-cameras", + action="store_true", + dest="no_copy_source_prims", + help="Deprecated alias for --no-copy-source-prims.", + ) + parser.add_argument( + "--copy-source-include-gaussians", + action="store_true", + help="Also copy /World/Gaussians from the source (duplicates old LightField data; can be very large).", + ) parser.add_argument( "-v", "--verbose", action="store_true", help="Enable verbose logging", ) + parser.add_argument( + "--no-usd-validate", + action="store_true", + help="Skip OpenUSD stage validation after lightfield (.usd/.usdz) export", + ) return parser.parse_args() @@ -392,19 +447,32 @@ def main(): # Create output directory if needed output_path.parent.mkdir(parents=True, exist_ok=True) + suffix_in = input_path.suffix.lower() + use_camera_copy_ctx = ( + output_format == "lightfield" + and suffix_in in (".usd", ".usda", ".usdc", ".usdz") + and not args.no_copy_source_prims + ) + camera_ctx = usd_stage_path_context_for_camera_copy(input_path) if use_camera_copy_ctx else nullcontext(None) + try: - transcode( - input_path=input_path, - output_path=output_path, - output_format=output_format, - max_sh_degree=args.max_sh_degree, - half_precision=args.half, - half_geometry=args.half_geometry, - half_features=args.half_features, - apply_coordinate_transform=args.apply_coordinate_transform, - render_order_hint=args.render_order_hint, - linear_srgb=args.linear_srgb, - ) + with camera_ctx as copy_cameras_source: + skip_subtrees = () if args.copy_source_include_gaussians else None + transcode( + input_path=input_path, + output_path=output_path, + output_format=output_format, + max_sh_degree=args.max_sh_degree, + half_precision=args.half, + 
half_geometry=args.half_geometry, + half_features=args.half_features, + apply_coordinate_transform=args.apply_coordinate_transform, + render_order_hint=args.render_order_hint, + linear_srgb=args.linear_srgb, + copy_cameras_source=copy_cameras_source, + copy_source_skip_subtrees=skip_subtrees, + validate_usd=not args.no_usd_validate, + ) except Exception as e: logger.error(f"Transcode failed: {e}") if args.verbose: diff --git a/threedgrut/export/tests/test_export_import.py b/threedgrut/export/tests/test_export_import.py index eae8769f..421866f4 100644 --- a/threedgrut/export/tests/test_export_import.py +++ b/threedgrut/export/tests/test_export_import.py @@ -26,7 +26,7 @@ import numpy as np import pytest import torch -from pxr import Usd, UsdValidation +from pxr import Usd from threedgrut.export.base import ExportableModel from threedgrut.export.formats import PLYExporter @@ -34,18 +34,6 @@ from threedgrut.export.usd.exporter import USDExporter -def _validate_stage(stage: Usd.Stage) -> list: - """Run usd-core stage validators (StageMetadataChecker, CompositionErrorTest). 
Returns list of ValidationError.""" - validators = UsdValidation.ValidationRegistry().GetOrLoadValidatorsByName( - ["usdValidation:StageMetadataChecker", "usdValidation:CompositionErrorTest"] - ) - if not validators: - return [] - ctx = UsdValidation.ValidationContext(validators) - result = ctx.Validate(stage) - return list(result) if result else [] - - class MockGaussianModel(ExportableModel): """Mock ExportableModel with known test data for verification.""" @@ -123,6 +111,24 @@ def get_features_specular(self) -> torch.Tensor: return self._specular +class MockCameraDataset: + """Minimal dataset exposing camera poses for USD camera export tests.""" + + def __len__(self) -> int: + return 2 + + def get_poses(self) -> np.ndarray: + poses = np.repeat(np.eye(4, dtype=np.float64)[None, :, :], len(self), axis=0) + poses[1, 0, 3] = 1.0 + return poses + + def get_camera_names(self): + return ["camera_0000"] + + def get_camera_idx(self, frame_idx: int) -> int: + return 0 + + class TestPLYExportImport: """Test PLY export from ExportableModel and import back.""" @@ -404,7 +410,7 @@ def test_ply_usd_positions_match(self): ) def test_usd_export_passes_usd_validation(self): - """Exported USD stage passes usd-core schema/stage validators.""" + """Exported USD stage passes OpenUSD stage validators (run inside USDExporter.export).""" model = MockGaussianModel(num_gaussians=5, sh_degree=0) with tempfile.TemporaryDirectory() as tmpdir: usd_path = Path(tmpdir) / "test.usdz" @@ -414,10 +420,6 @@ def test_usd_export_passes_usd_validation(self): export_background=False, apply_normalizing_transform=False, ).export(model, usd_path) - stage = Usd.Stage.Open(str(usd_path)) - assert stage, "Failed to open exported stage" - errors = _validate_stage(stage) - assert not errors, "USD validation failed:\n" + "\n".join(e.GetMessage() for e in errors) def _find_prim_with_color_space_api(stage: Usd.Stage): @@ -488,6 +490,52 @@ def test_usd_export_color_space_from_config(self): api = 
Usd.ColorSpaceAPI(prim) assert api.GetColorSpaceNameAttr().Get() == "lin_rec709_scene" + def test_usdz_export_camera_is_composed_from_root_stage(self): + """USDZ camera prims are authored where the package root composes them.""" + model = MockGaussianModel(num_gaussians=5, sh_degree=3) + dataset = MockCameraDataset() + with tempfile.TemporaryDirectory() as tmpdir: + usd_path = Path(tmpdir) / "test.usdz" + USDExporter( + half_precision=False, + export_cameras=True, + export_background=False, + apply_normalizing_transform=False, + ).export(model, usd_path, dataset=dataset) + stage = Usd.Stage.Open(str(usd_path)) + assert stage + assert stage.GetPrimAtPath("/World/Cameras/camera_0000").IsValid() + assert not stage.GetPrimAtPath("/World/gaussians/Cameras/camera_0000").IsValid() + assert stage.GetStartTimeCode() == 0.0 + assert stage.GetEndTimeCode() == 1.0 + + +class TestUSDExportSortingModeHint: + """Test ParticleField sortingModeHint authoring.""" + + def test_usd_export_sorting_mode_hint_ray_hit_distance(self): + """Export can author the usd-core 26.5 rayHitDistance sorting hint.""" + model = MockGaussianModel(num_gaussians=5, sh_degree=3) + with tempfile.TemporaryDirectory() as tmpdir: + usd_path = Path(tmpdir) / "test.usdz" + USDExporter( + half_precision=False, + export_cameras=False, + export_background=False, + apply_normalizing_transform=False, + sorting_mode_hint="rayHitDistance", + ).export(model, usd_path) + stage = Usd.Stage.Open(str(usd_path)) + assert stage + prim = _find_prim_with_color_space_api(stage) + assert prim is not None, "No Gaussian particle prim found" + assert prim.GetAttribute("sortingModeHint").Get() == "rayHitDistance" + + def test_usd_export_sorting_mode_hint_rejects_unknown_token(self): + """Unsupported sorting hints fail before authoring invalid USD.""" + with pytest.raises(ValueError, match="Unsupported ParticleField sortingModeHint"): + USDExporter(sorting_mode_hint="frontToBack") + if __name__ == "__main__": pytest.main([__file__, 
"-v"]) diff --git a/threedgrut/export/usd/camera_copy.py b/threedgrut/export/usd/camera_copy.py new file mode 100644 index 00000000..737b826f --- /dev/null +++ b/threedgrut/export/usd/camera_copy.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Copy prims from a source USD stage into an export stage (transcode USD → LightField).""" + +import logging +import tempfile +import zipfile +from contextlib import contextmanager +from pathlib import Path +from typing import Collection, Iterator, List, Optional, Set, Tuple + +from pxr import Sdf, UsdGeom + +from threedgrut.export.usd.stage_utils import NamedSerialized + +logger = logging.getLogger(__name__) + +UsdStagePathPair = Tuple[Path, Path] + +# Default: do not duplicate LightField Gaussian root (large); new splats live at /World/Gaussians. +_DEFAULT_SKIP_SUBTREES = (Sdf.Path("/World/Gaussians"),) + + +def _path_is_under_skipped(src_path: Sdf.Path, skip_roots: Collection[Sdf.Path]) -> bool: + for root in skip_roots: + if src_path == root: + return True + # Children of root (e.g. 
/World/Gaussians/gaussians) + prefix = str(root) + "/" + if str(src_path).startswith(prefix): + return True + return False + + +def _copy_prim_spec_recursive( + src_layer: Sdf.Layer, + dst_layer: Sdf.Layer, + src_path: Sdf.Path, + dst_path: Sdf.Path, +) -> int: + """Copy one prim spec and all descendants. Returns number of prims copied.""" + src_spec = src_layer.GetPrimAtPath(src_path) + if not src_spec or not src_spec.active: + return 0 + Sdf.CopySpec(src_layer, src_path, dst_layer, dst_path) + count = 1 + for child_spec in src_spec.nameChildren: + name = child_spec.name + count += _copy_prim_spec_recursive( + src_layer, + dst_layer, + src_path.AppendChild(name), + dst_path.AppendChild(name), + ) + return count + + +def merge_source_world_at_same_paths( + dest_stage, + source_stage, + skip_source_subtrees: Optional[Collection[Sdf.Path]] = None, +) -> int: + """ + Merge each top-level child of ``/World`` from the source **root** layer onto ``dest_stage``'s + root layer at the **same path** as the source (e.g. ``/World/rig_trajectories``), using + ``Sdf.CopySpec``. References and payloads are copied as-authored so sibling layers (e.g. + ``rig_trajectories.usda``) keep all time samples when those files are bundled unchanged. + + Skips subtrees in ``skip_source_subtrees`` (default: ``/World/Gaussians`` for LightField). + Skips any path where the destination root layer already has a prim spec (e.g. export's + ``/World/gaussians`` reference prim). 
+ """ + skips = tuple(skip_source_subtrees) if skip_source_subtrees is not None else _DEFAULT_SKIP_SUBTREES + src_layer = source_stage.GetRootLayer() + dst_layer = dest_stage.GetRootLayer() + + world_spec = src_layer.GetPrimAtPath("/World") + if not world_spec: + logger.info("Source USD has no /World prim; nothing to merge") + return 0 + + total = 0 + for child_spec in world_spec.nameChildren: + name = child_spec.name + path = Sdf.Path("/World").AppendChild(name) + if _path_is_under_skipped(path, skips): + logger.info("Skipping source subtree %s (transcode merge skip list)", path) + continue + if dst_layer.GetPrimAtPath(path): + logger.info("Keeping destination prim %s; not overwriting with source", path) + continue + total += _copy_prim_spec_recursive(src_layer, dst_layer, path, path) + + if total == 0: + logger.info("No source /World prims merged (empty or all skipped / already present)") + else: + logger.info("Merged %d source prim subtree(s) at original /World paths", total) + + return total + + +def merge_source_prim_at_same_path(dest_stage, source_stage, prim_path: str) -> int: + """ + Copy one source root-layer prim subtree to the destination at the same path. + + This preserves non-geometry export data, such as `/Render`, during USD to + USD transcode without regenerating renderer state from Python objects. 
+ """ + src_layer = source_stage.GetRootLayer() + dst_layer = dest_stage.GetRootLayer() + path = Sdf.Path(prim_path) + + if not src_layer.GetPrimAtPath(path): + logger.info("Source USD has no %s prim; nothing to merge", prim_path) + return 0 + if dst_layer.GetPrimAtPath(path): + logger.info("Keeping destination prim %s; not overwriting with source", prim_path) + return 0 + + count = _copy_prim_spec_recursive(src_layer, dst_layer, path, path) + logger.info("Merged source %s subtree with %d prim(s)", prim_path, count) + return count + + +def copy_authored_time_settings_from_source(source_stage, dest_stage) -> None: + """Copy authored time code range and FPS from source to destination stage when set.""" + try: + if getattr(source_stage, "HasAuthoredTimeCodeRange", None) and source_stage.HasAuthoredTimeCodeRange(): + dest_stage.SetStartTimeCode(source_stage.GetStartTimeCode()) + dest_stage.SetEndTimeCode(source_stage.GetEndTimeCode()) + tps = source_stage.GetTimeCodesPerSecond() + if tps is not None and float(tps) > 0.0: + dest_stage.SetTimeCodesPerSecond(tps) + except Exception as ex: + logger.debug("Could not copy time settings from source stage: %s", ex) + + +# Filenames we always author in LightField USDZ export (never pull from source package). 
+_OUTPUT_AUTHORED_NAMES = frozenset({"gaussians.usdc", "default.usda"}) + + +def _basename_packaged_ref(asset_path: str) -> Optional[str]: + """USDZ-flat basename for a relative layer/asset reference, or None if not packagable.""" + if not asset_path: + return None + s = asset_path.strip().strip("@") + if not s or "://" in s or s.startswith("/"): + return None + return Path(s.replace("\\", "/")).name + + +def _gather_ref_payload_basenames_from_prim_spec(spec: Sdf.PrimSpec) -> Set[str]: + out: Set[str] = set() + if not spec: + return out + ref_list = spec.referenceList + for item in list(ref_list.prependedItems) + list(ref_list.appendedItems): + bn = _basename_packaged_ref(getattr(item, "assetPath", "") or "") + if bn: + out.add(bn) + pay_list = getattr(spec, "payloadList", None) + if pay_list is not None: + for item in list(pay_list.prependedItems) + list(pay_list.appendedItems): + bn = _basename_packaged_ref(getattr(item, "assetPath", "") or "") + if bn: + out.add(bn) + for prop in spec.properties: + default_value = getattr(prop, "default", None) + asset_path = getattr(default_value, "path", None) or getattr( + default_value, + "assetPath", + None, + ) + if asset_path: + bn = _basename_packaged_ref(asset_path) + if bn: + out.add(bn) + return out + + +def _companion_sidecar_basenames(basename: str) -> Set[str]: + """Additional package files implied by a referenced asset.""" + if basename.endswith(".slang"): + return {f"{basename}.lua"} + return set() + + +def _walk_prim_subtree(layer: Sdf.Layer, root_path: Sdf.Path): + """Depth-first active prims under root_path (inclusive).""" + spec = layer.GetPrimAtPath(root_path) + if not spec or not spec.active: + return + yield root_path + for child_spec in spec.nameChildren: + yield from _walk_prim_subtree(layer, root_path.AppendChild(child_spec.name)) + + +def _gather_refs_from_layer_subtree(layer: Sdf.Layer, path_prefix: str) -> Set[str]: + """Collect referenced basenames from all prims under path_prefix on this layer.""" 
+ needed: Set[str] = set() + root = Sdf.Path(path_prefix) + if not layer.GetPrimAtPath(root): + return needed + for path in _walk_prim_subtree(layer, root): + spec = layer.GetPrimAtPath(path) + needed |= _gather_ref_payload_basenames_from_prim_spec(spec) + return needed + + +def _walk_entire_layer(layer: Sdf.Layer): + """All active prim paths (excluding absolute root pseudo-prim).""" + root = Sdf.Path("/") + spec = layer.GetPrimAtPath(root) + if not spec: + return + for child_spec in spec.nameChildren: + yield from _walk_prim_subtree(layer, root.AppendChild(child_spec.name)) + + +def collect_transitive_sidecars_for_subtree( + dest_layer: Sdf.Layer, + res_root: Path, + path_prefix: str, + extra_skip_names: Optional[Collection[str]] = None, +) -> List[NamedSerialized]: + """ + Resolve layer/asset references under ``path_prefix`` and bundle files from + ``res_root`` into the output USDZ (flat layout). + + Follows references/payloads transitively through USD layers. Skips names in + ``_OUTPUT_AUTHORED_NAMES`` and ``extra_skip_names`` (e.g. source root default + file). 
+ """ + skip: Set[str] = set(_OUTPUT_AUTHORED_NAMES) + if extra_skip_names: + skip.update(extra_skip_names) + + seed = _gather_refs_from_layer_subtree(dest_layer, path_prefix) + queue: Set[str] = {n for n in seed if n not in skip} + done: Set[str] = set(skip) + result: List[NamedSerialized] = [] + + while queue: + name = queue.pop() + if name in done: + continue + done.add(name) + path = res_root / name + if not path.is_file(): + logger.warning("Referenced package file missing under %s: %s", res_root, name) + continue + try: + data = path.read_bytes() + except OSError as e: + logger.warning("Could not read sidecar %s: %s", path, e) + continue + result.append(NamedSerialized(filename=name, serialized=data)) + for companion in _companion_sidecar_basenames(name): + if companion not in done: + queue.add(companion) + + suf = path.suffix.lower() + if suf not in (".usd", ".usda", ".usdc"): + continue + sub = Sdf.Layer.FindOrOpen(str(path)) + if not sub: + logger.warning("Could not open referenced layer for sidecar walk: %s", path) + continue + for p in _walk_entire_layer(sub): + spec = sub.GetPrimAtPath(p) + for bn in _gather_ref_payload_basenames_from_prim_spec(spec): + if bn and bn not in done: + queue.add(bn) + + if result: + logger.info( + "Bundled %d sidecar file(s) from %s for %s references", + len(result), + res_root, + path_prefix, + ) + return result + + +def collect_transitive_sidecars_for_world_subtree( + dest_layer: Sdf.Layer, + res_root: Path, + world_prefix: str = "/World", + extra_skip_names: Optional[Collection[str]] = None, +) -> List[NamedSerialized]: + return collect_transitive_sidecars_for_subtree( + dest_layer, + res_root, + path_prefix=world_prefix, + extra_skip_names=extra_skip_names, + ) + + +@contextmanager +def usd_stage_path_context_for_camera_copy(usd_path: Path) -> Iterator[Optional[UsdStagePathPair]]: + """ + Yield (root_stage_path, asset_resolution_dir) for opening a USD/USDZ with correct asset paths. 
+ + For USDZ, extracts to a temporary directory (deleted on exit). + """ + path = usd_path.resolve() + suffix = path.suffix.lower() + if suffix not in (".usd", ".usda", ".usdc", ".usdz"): + yield None + return + + if suffix == ".usdz": + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + with zipfile.ZipFile(path, "r") as zf: + zf.extractall(tmp_path) + usd_files = list(tmp_path.glob("*.usd*")) + root_file = None + for f in usd_files: + if f.stem == "default": + root_file = f + break + if root_file is None and usd_files: + root_file = usd_files[0] + if root_file is None: + logger.warning("USDZ has no USD root for source prim copy: %s", path) + yield None + return + yield (root_file.resolve(), tmp_path.resolve()) + return + + yield (path, path.parent.resolve()) diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 71e3276e..437d6082 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -26,6 +26,7 @@ import numpy as np import torch +from pxr import Usd from ncore.data import ( OpenCVFisheyeCameraModelParameters, OpenCVPinholeCameraModelParameters, @@ -47,11 +48,64 @@ ) from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import create_gaussian_writer +from threedgrut.export.usd.camera_copy import ( + collect_transitive_sidecars_for_subtree, + copy_authored_time_settings_from_source, + merge_source_prim_at_same_path, + merge_source_world_at_same_paths, +) +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + normalize_particle_field_sorting_mode_hint, +) +from threedgrut.export.usd.post_processing_sh_bake import ( + MODE_PPISP_BAKE_VIGNETTING_NONE, + scale_sh_output, +) from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) +_GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING = "rtx:rtpt:gaussian:skipTonemapping:enabled" 
+MODE_POST_PROCESSING_EXPORT_BAKED_SH = "baked-sh" +MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE = "omni-native" +POST_PROCESSING_EXPORT_MODES = { + MODE_POST_PROCESSING_EXPORT_BAKED_SH, + MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE, +} + + +def _set_render_setting(stage: Usd.Stage, key: str, value: Any) -> None: + render_settings = dict(stage.GetRootLayer().customLayerData.get("renderSettings", {}) or {}) + render_settings[key] = value + stage.SetMetadataByDictKey("customLayerData", "renderSettings", render_settings) + + +def _is_ppisp_post_processing(post_processing: Any) -> bool: + post_processing_type = type(post_processing) + return ( + post_processing_type.__name__ == "PPISP" + and post_processing_type.__module__.split(".", maxsplit=1)[0] == "ppisp" + ) + + +def normalize_post_processing_export_mode(mode: str | None) -> str: + normalized = MODE_POST_PROCESSING_EXPORT_BAKED_SH if mode is None else str(mode).strip().lower() + if normalized not in POST_PROCESSING_EXPORT_MODES: + raise ValueError( + f"Unsupported post-processing export mode '{mode}'. " + f"Expected one of: {sorted(POST_PROCESSING_EXPORT_MODES)}" + ) + return normalized + + +def _get_export_config_value(export_conf, hyphen_name: str, attr_name: str, default: Any) -> Any: + if hasattr(export_conf, "get"): + return export_conf.get(hyphen_name, getattr(export_conf, attr_name, default)) + return getattr(export_conf, attr_name, default) + + def _extract_camera_params_from_dataset(dataset) -> Optional[List]: """ Extract per-frame camera parameters from a dataset. 
@@ -80,7 +134,7 @@ def _extract_camera_params_from_dataset(dataset) -> Optional[List]: camera_params.append(None) continue - params_dict, _, _, camera_name = params_tuple + params_dict, _, _, camera_name, *_ = params_tuple # Reconstruct CameraModelParameters from dict if camera_name == "OpenCVPinholeCameraModelParameters": @@ -155,6 +209,52 @@ def _extract_camera_params_from_dataset(dataset) -> Optional[List]: return None +def _extract_camera_grouping(dataset): + """Extract camera grouping info from a dataset. + + Returns: + (camera_names, frame_to_camera) where camera_names is a list of logical + camera names and frame_to_camera maps frame_idx → camera_idx. + """ + camera_names = None + frame_to_camera = None + + if hasattr(dataset, "get_camera_names"): + camera_names = dataset.get_camera_names() + if hasattr(dataset, "get_camera_idx"): + frame_to_camera = [dataset.get_camera_idx(i) for i in range(len(dataset))] + + if camera_names is None: + camera_names = ["camera_0000"] + if frame_to_camera is None: + frame_to_camera = [0] * len(dataset) + + return camera_names, frame_to_camera + + +def _extract_camera_resolutions(camera_params: List, camera_names: List[str], frame_to_camera: List[int]): + """Extract per-camera resolution from the first valid frame of each camera.""" + result = {} + num_cameras = len(camera_names) + first_frame: Dict[int, int] = {} + for frame_idx, cam_idx in enumerate(frame_to_camera): + if cam_idx not in first_frame and 0 <= cam_idx < num_cameras: + first_frame[cam_idx] = frame_idx + + for cam_idx, cam_name in enumerate(camera_names): + frame_idx = first_frame.get(cam_idx) + if frame_idx is None or camera_params is None: + continue + params = camera_params[frame_idx] if frame_idx < len(camera_params) else None + if params is None: + continue + if hasattr(params, "resolution"): + w, h = int(params.resolution[0]), int(params.resolution[1]) + result[cam_name] = (w, h) + + return result + + class USDExporter(ModelExporter): """ Exporter for 
OpenUSD format using ParticleField3DGaussianSplat schema. @@ -164,11 +264,12 @@ class USDExporter(ModelExporter): Features: - ParticleField3DGaussianSplat schema (standard OpenUSD) - - Optional camera export with full intrinsics + - One Camera prim per physical camera with time-sampled transforms - Background/environment export as DomeLight + - Optional baked-SH post-processing export or PPISP Omniverse native export - USDZ packaging (default output) - For Omniverse/NuRec compatibility, use NuRecExporter instead. + For NuRec compatibility, use NuRecExporter instead. """ def __init__( @@ -179,8 +280,26 @@ def __init__( export_cameras: bool = True, export_background: bool = True, apply_normalizing_transform: bool = True, - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, + export_post_processing: bool = True, + post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, + post_processing_export_camera_id: int | None = None, + post_processing_export_frame_id: int | None = None, + ignore_ppisp_controller: bool = False, + post_processing_bake_epochs: int = 7, + post_processing_bake_learning_rate: float = 2.5e-3, + post_processing_bake_learning_rate_specular: float | None = None, + post_processing_bake_learning_rate_density: float = 5.0e-2, + post_processing_bake_camera_id: int = 0, + post_processing_bake_frame_id: int = 0, + ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_NONE, + post_processing_bake_view_mode: str = "trajectory", + post_processing_bake_view_seed: int | None = None, + post_processing_bake_trajectory_weight_position: float = 1.0, + post_processing_bake_trajectory_weight_rotation: float = 0.5, + output_scale: float = 1.0, + frames_per_second: float = 1.0, ): """ Initialize the USD exporter. @@ -189,11 +308,53 @@ def __init__( half_precision: If True, use half for both geometry and features (backward compat). 
half_geometry: Use half precision for positions, orientations, scales (LightField). half_features: Use half precision for opacities and SH coefficients (LightField). - export_cameras: Include camera poses in export - export_background: Include background/environment in export - apply_normalizing_transform: Apply transform to normalize scene orientation - sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth" per UsdVol schema) - linear_srgb: If True, set prim color space to lin_rec709_scene; else srgb_rec709_display + export_cameras: Include camera poses in export. + export_background: Include background/environment in export. + apply_normalizing_transform: Apply transform to normalize scene orientation. + sorting_mode_hint: Sorting hint for rendering ("zDepth", "cameraDistance", "rayHitDistance"). + linear_srgb: If True, set prim color space to lin_rec709_scene. + export_post_processing: If True, export the checkpoint post-processing + module with the selected export mode. + post_processing_export_mode: "baked-sh" bakes one fixed transform + into Gaussian SH coefficients. "omni-native" uses the module's + Omniverse-native path; currently PPISP SPG. + post_processing_export_camera_id: Optional PPISP camera index to use + for every RenderProduct in omni-native mode. + ignore_ppisp_controller: If True, skip the PPISP controller export + even when the checkpoint has trained controllers, and fall back + to time-sampled exposure / colour USD attributes derived from + ``ppisp.exposure_params`` and ``ppisp.color_params``. No effect + on checkpoints that were trained without a controller. + post_processing_export_frame_id: Optional PPISP frame index to write + as static exposure/color inputs in omni-native mode. + post_processing_bake_epochs: Number of sequential passes over the train/reference set. + post_processing_bake_learning_rate: Adam learning rate for features_albedo + (default 2.5e-3, matches 3DGS). 
+ post_processing_bake_learning_rate_specular: Adam learning rate for + features_specular. Defaults to ``learning_rate / 20`` (the 3DGS ratio). + post_processing_bake_learning_rate_density: Adam learning rate for density + (default 5e-2, matches 3DGS). Optimising density alongside SH absorbs + spatial frequencies the SH alone aliases as colour rainbow fringes. + post_processing_bake_camera_id: Camera index for the fixed baked transform. + post_processing_bake_frame_id: Frame index for the fixed baked transform. + ppisp_bake_vignetting_mode: "none" (default) -- bake produces gamma-space + SH coefficients with no vignetting; the asset format aligns with + no-PPISP exports. "achromatic-fit" is retained for backwards + compatibility but no longer the recommended mode. + post_processing_bake_view_mode: which views the bake fit sees per step. + "training" iterates the train dataloader (default). "trajectory" + orders the training views along an NN+2-opt camera path, parameterises + arc-length on [0, 1], and samples a random t per step (helpful when + training views are sparse). + post_processing_bake_view_seed: optional RNG seed for the interpolation + samplers. None (default) leaves it non-deterministic. + post_processing_bake_trajectory_weight_position: trajectory mode only. + Weight on the (mean-normalised) position term in the pose distance. + post_processing_bake_trajectory_weight_rotation: trajectory mode only. + Weight on the (1 - cos(angle)) rotation term in the pose distance. + frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always + bare frame indices (float(frame_idx)), so this controls playback speed. + Default 1.0 means 1 frame per second of real time. 
""" if half_precision: half_geometry = True @@ -203,27 +364,53 @@ def __init__( self.export_cameras = export_cameras self.export_background = export_background self.apply_normalizing_transform = apply_normalizing_transform - self.sorting_mode_hint = sorting_mode_hint + self.sorting_mode_hint = normalize_particle_field_sorting_mode_hint(sorting_mode_hint) self.linear_srgb = linear_srgb + self.export_post_processing = export_post_processing + self.post_processing_export_mode = normalize_post_processing_export_mode(post_processing_export_mode) + self.post_processing_export_camera_id = ( + None if post_processing_export_camera_id is None else int(post_processing_export_camera_id) + ) + self.post_processing_export_frame_id = ( + None if post_processing_export_frame_id is None else int(post_processing_export_frame_id) + ) + self.ignore_ppisp_controller = bool(ignore_ppisp_controller) + self.post_processing_bake_epochs = int(post_processing_bake_epochs) + self.post_processing_bake_learning_rate = float(post_processing_bake_learning_rate) + self.post_processing_bake_learning_rate_specular = ( + None if post_processing_bake_learning_rate_specular is None + else float(post_processing_bake_learning_rate_specular) + ) + self.post_processing_bake_learning_rate_density = float( + post_processing_bake_learning_rate_density + ) + self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) + self.post_processing_bake_frame_id = int(post_processing_bake_frame_id) + self.ppisp_bake_vignetting_mode = str(ppisp_bake_vignetting_mode) + self.post_processing_bake_view_mode = str(post_processing_bake_view_mode) + self.post_processing_bake_view_seed = ( + None if post_processing_bake_view_seed is None else int(post_processing_bake_view_seed) + ) + self.post_processing_bake_trajectory_weight_position = float( + post_processing_bake_trajectory_weight_position + ) + self.post_processing_bake_trajectory_weight_rotation = float( + post_processing_bake_trajectory_weight_rotation + ) 
+ self.output_scale = float(output_scale) + self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: """ Create a default.usda that references the data stages. - - Args: - referenced_stages: List of stages to reference (e.g., gaussians.usdc) - - Returns: - NamedUSDStage for default.usda """ stage = initialize_usd_stage(up_axis="Y") + stage.SetTimeCodesPerSecond(self.frames_per_second) for ref_stage in referenced_stages: - # Create a reference prim for each stage filename_stem = Path(ref_stage.filename).stem prim_path = f"/World/{filename_stem}" prim = stage.OverridePrim(prim_path) - # Reference the file (bare filename for in-package resolution; same as NuRec) prim.GetReferences().AddReference(ref_stage.filename) return NamedUSDStage(filename="default.usda", stage=stage) @@ -242,24 +429,81 @@ def export( Export the model to a USDZ file. Args: - model: The model to export (must implement ExportableModel) - output_path: Path where the USDZ file will be saved - dataset: Optional dataset for camera poses - conf: Configuration parameters - background: Optional background model for environment export - **kwargs: Additional parameters + model: The model to export (must implement ExportableModel). + output_path: Path where the USDZ file will be saved. + dataset: Optional dataset for camera poses. + conf: Configuration parameters. + background: Optional background model for environment export. + **kwargs: + post_processing: checkpoint post-processing module to bake or export natively. + validate_usd (default True): run OpenUSD stage validators. + apply_coordinate_transform (bool): apply 3DGRUT→USDZ coordinate flip. + copy_source_usd: (stage_path, res_root) for prim merge. + copy_source_skip_subtrees: subtrees to skip during prim merge. 
""" output_path = Path(output_path) logger.info(f"Exporting USD file to {output_path}...") + post_processing = kwargs.get("post_processing") + has_ppisp_module = _is_ppisp_post_processing(post_processing) + uses_baked_post_processing_export = ( + post_processing is not None + and self.export_post_processing + and self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_BAKED_SH + ) + uses_omni_native_post_processing_export = ( + post_processing is not None + and self.export_post_processing + and self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE + ) + + if uses_baked_post_processing_export: + from threedgrut.export.usd.post_processing_sh_bake import ( + PPISPPostProcessingBakeAdapter, + bake_post_processing_into_sh, + ) + + if not has_ppisp_module: + raise ValueError("Baked-SH post-processing export currently supports PPISP post-processing only.") + adapter = PPISPPostProcessingBakeAdapter( + camera_id=self.post_processing_bake_camera_id, + frame_id=self.post_processing_bake_frame_id, + vignetting_mode=self.ppisp_bake_vignetting_mode, + ) + logger.info( + "Baking post-processing into Gaussian SH coefficients before export " + f"(camera={self.post_processing_bake_camera_id}, frame={self.post_processing_bake_frame_id})" + ) + model = bake_post_processing_into_sh( + model=model, + post_processing=post_processing, + train_dataset=dataset, + conf=conf, + adapter=adapter, + epochs=self.post_processing_bake_epochs, + learning_rate=self.post_processing_bake_learning_rate, + learning_rate_specular=self.post_processing_bake_learning_rate_specular, + learning_rate_density=self.post_processing_bake_learning_rate_density, + view_sampling_mode=self.post_processing_bake_view_mode, + interpolated_views_seed=self.post_processing_bake_view_seed, + trajectory_weight_position=self.post_processing_bake_trajectory_weight_position, + trajectory_weight_rotation=self.post_processing_bake_trajectory_weight_rotation, + ) + if 
uses_omni_native_post_processing_export and not has_ppisp_module: + raise ValueError("Omniverse-native post-processing export currently supports PPISP post-processing only.") + + # User-requested constant brightness scale, applied uniformly to the + # SH output regardless of bake / colour-space mode. The DC offset + # baked into RGB2SH is compensated so a forward eval reproduces + # output_scale * (original SH-evaluated RGB). + if self.output_scale != 1.0: + scale_sh_output(model, self.output_scale) # Get model data via accessor - # LightField expects post-activation values (opacity in [0,1], actual scales) accessor = GaussianExportAccessor(model, conf) attrs = accessor.get_attributes(preactivation=False) caps = accessor.get_capabilities() logger.info(f"Schema: LightField (post-activation)") - logger.info(f"Exporting {attrs.num_gaussians} Gaussians, SH degree {caps.sh_degree}") # Compute normalizing transform if enabled @@ -272,13 +516,14 @@ def export( except (AttributeError, ValueError) as e: logger.warning(f"Failed to compute normalizing transform: {e}") - # Create main USD stage + # Create main USD stage with the configured time code rate stage = initialize_usd_stage(up_axis="Y") + stage.SetTimeCodesPerSecond(self.frames_per_second) apply_coordinate_transform = kwargs.get("apply_coordinate_transform", False) coordinate_transform = get_3dgrut_to_usdz_coordinate_transform() if apply_coordinate_transform else None - # Create Gaussian content root with optional normalizing and coordinate transform + # Create Gaussian content root gaussians_root = create_gaussian_model_root( stage, flip_x_axis=False, @@ -289,7 +534,7 @@ def export( coordinate_transform=coordinate_transform, ) - # Create Gaussian writer (LightField schema) + # Write Gaussians writer = create_gaussian_writer( stage=stage, capabilities=caps, @@ -298,43 +543,86 @@ def export( half_features=self.half_features, sorting_mode_hint=self.sorting_mode_hint, linear_srgb=self.linear_srgb, + 
omni_usd=uses_omni_native_post_processing_export, + has_post_processing=uses_omni_native_post_processing_export, ) - - # Write Gaussians writer.create_prim(attrs.num_gaussians) writer.write_attributes(attrs) writer.finalize(attrs.positions) - # Collect stages and files for USDZ - stages: List[NamedUSDStage] = [] + suffix = output_path.suffix.lower() + package_as_usdz = suffix == ".usdz" or suffix not in (".usd", ".usda", ".usdc") + + gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) + default_stage_wrapped: Optional[NamedUSDStage] = None + if package_as_usdz: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + scene_stage = default_stage_wrapped.stage if default_stage_wrapped is not None else stage + files: List[NamedSerialized] = [] - # Export cameras if requested and dataset available + copy_source_usd = kwargs.get("copy_source_usd") + if copy_source_usd is None: + copy_source_usd = kwargs.get("copy_cameras_source") + if copy_source_usd is not None: + stage_path, res_root = copy_source_usd + try: + src_stage = Usd.Stage.Open(str(stage_path)) + if not src_stage: + logger.warning("Could not open source USD for prim merge: %s", stage_path) + else: + skip = kwargs.get("copy_source_skip_subtrees") + merge_target = scene_stage + merge_source_world_at_same_paths(merge_target, src_stage, skip_source_subtrees=skip) + merge_source_prim_at_same_path(merge_target, src_stage, "/Render") + copy_authored_time_settings_from_source(src_stage, merge_target) + if package_as_usdz and res_root is not None and res_root.is_dir(): + for path_prefix in ("/World", "/Render"): + sidecars = collect_transitive_sidecars_for_subtree( + merge_target.GetRootLayer(), + res_root, + path_prefix=path_prefix, + extra_skip_names={Path(stage_path).name}, + ) + for entry in sidecars: + if not any(f.filename == entry.filename for f in files): + files.append(entry) + except Exception as e: + logger.warning("Failed to merge source USD prims: %s", e) + + # Extract 
camera grouping from dataset (used by both camera export and PPISP) + camera_names = None + frame_to_camera = None + camera_prim_paths: Dict[str, str] = {} + camera_params = None + + if dataset is not None: + camera_names, frame_to_camera = _extract_camera_grouping(dataset) + + # Export cameras — one prim per physical camera with time-sampled transforms if self.export_cameras and dataset is not None: try: poses = dataset.get_poses() - # When we apply normalizing transform to the Gaussian root, cameras must be in the - # same coordinate system: apply normalizing transform to each c2w (world → normalized). if self.apply_normalizing_transform: poses = np.einsum("ij,njk->nik", normalizing_transform, poses) - # Extract per-frame camera parameters from dataset camera_params = _extract_camera_params_from_dataset(dataset) - if camera_params is not None: logger.info(f"Extracted camera params for {len(camera_params)} frames") else: logger.warning("Could not extract camera intrinsics from dataset, using default") - export_cameras_to_usd( - stage=stage, + camera_prim_paths = export_cameras_to_usd( + stage=scene_stage, poses=poses, + camera_names=camera_names, + frame_to_camera=frame_to_camera, camera_params=camera_params, root_path="/World/Cameras", visible=False, ) - logger.info(f"Exported {len(poses)} cameras") + logger.info(f"Exported {len(camera_prim_paths)} camera(s) from {len(poses)} frames") except (AttributeError, KeyError, ValueError) as e: logger.warning(f"Failed to export cameras: {e}") @@ -343,7 +631,7 @@ def export( if self.export_background and background is not None: try: _, envmap_bytes = export_background_to_usd( - stage=stage, + stage=scene_stage, background=background, conf=conf, root_path="/World/Environment", @@ -355,42 +643,193 @@ def export( except (AttributeError, ValueError, ImportError) as e: logger.warning(f"Failed to export background: {e}") - # Determine output format - suffix = output_path.suffix.lower() + if not self.export_post_processing and 
_is_ppisp_post_processing(post_processing): + logger.warning( + "PPISP post-processing module is present but export_usd.export_post_processing=false; " + "PPISP effects will not be exported. Set export_usd.export_post_processing=true to export them." + ) + if self.export_post_processing and post_processing is None: + logger.info("Post-processing export requested but no post_processing module is available; skipping bake") + + if uses_omni_native_post_processing_export: + render_product_entries = self._create_ppisp_render_products( + stage=scene_stage, + dataset=dataset, + camera_names=camera_names, + frame_to_camera=frame_to_camera, + camera_prim_paths=camera_prim_paths, + camera_params=camera_params, + ) + if render_product_entries is not None: + _set_render_setting(scene_stage, _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING, False) + logger.info("Disabled Gaussian skip-tonemapping render setting for PPISP Omniverse-native export") + self._export_ppisp( + stage=scene_stage, + dataset=dataset, + camera_names=camera_names, + post_processing=post_processing, + files=files, + fixed_camera_id=self.post_processing_export_camera_id, + fixed_frame_id=self.post_processing_export_frame_id, + ) + + # Package if suffix == ".usdz": - # Package as USDZ with composition: - # - default.usda (text) references gaussians.usdc (binary) - gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) - default_stage = self._create_default_stage([gaussians_stage]) - # default.usda must be first in USDZ - write_to_usdz(output_path, [default_stage, gaussians_stage], files if files else None) + if default_stage_wrapped is None: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + write_to_usdz(output_path, [default_stage_wrapped, gaussians_stage], files if files else None) + written_path = output_path elif suffix in [".usda", ".usd", ".usdc"]: - # Export as plain USD (format determined by extension) stage.Export(str(output_path)) - # Also export envmap if present 
if envmap_bytes is not None: envmap_path = output_path.parent / "envmap.png" with open(envmap_path, "wb") as f: f.write(envmap_bytes) + written_path = output_path else: - # Default to USDZ usdz_path = output_path.with_suffix(".usdz") - gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) - default_stage = self._create_default_stage([gaussians_stage]) - write_to_usdz(usdz_path, [default_stage, gaussians_stage], files if files else None) + if default_stage_wrapped is None: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + write_to_usdz(usdz_path, [default_stage_wrapped, gaussians_stage], files if files else None) + written_path = usdz_path + + if kwargs.get("validate_usd", True): + from threedgrut.export.usd.validation import validate_exported_usd_stage + + validate_exported_usd_stage(written_path) logger.info(f"USD export complete: {output_path}") + def _create_ppisp_render_products( + self, + stage, + dataset, + camera_names, + frame_to_camera, + camera_prim_paths: Dict[str, str], + camera_params, + ): + """Create /Render RenderProducts for PPISP Omniverse-native export.""" + if dataset is None or not camera_prim_paths: + logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") + return None + + from threedgrut.export.usd.writers.render_product import create_render_products + + resolutions = _extract_camera_resolutions(camera_params, camera_names, frame_to_camera) + camera_entries = {} + for cam_name, cam_path in camera_prim_paths.items(): + w, h = resolutions.get(cam_name, (0, 0)) + camera_entries[cam_name] = (cam_path, w, h) + + try: + create_render_products(stage=stage, camera_entries=camera_entries) + except Exception as e: + logger.warning(f"Failed to create RenderProducts: {e}") + return None + + return camera_entries + + def _export_ppisp( + self, + stage, + dataset, + camera_names, + post_processing, + files: List[NamedSerialized], + fixed_camera_id: int | None = None, + fixed_frame_id: 
int | None = None, + ) -> None: + """Attach PPISP SPG shaders to existing RenderProducts.""" + try: + from ppisp import PPISP # type: ignore[import-not-found] + except ImportError: + logger.warning("ppisp package not available, skipping PPISP export") + return + + if not isinstance(post_processing, PPISP): + logger.warning( + f"export_post_processing=True but post_processing is {type(post_processing).__name__}, " + "expected ppisp.PPISP - skipping" + ) + return + + ppisp_config = getattr(post_processing, "config", None) + controllers = getattr(post_processing, "controllers", None) + has_controller = ( + bool(getattr(ppisp_config, "use_controller", False)) and controllers is not None and len(controllers) > 0 + ) + # The static-frame override modes (fixed_frame_id) intentionally bypass + # the controller because the goal is to bake one specific frame's + # corrections, not to predict them at runtime. + # ignore_ppisp_controller forces the same fall-back even with animation, + # so consumers that don't want runtime controller dispatch can ship the + # optimized per-frame exposure / colour USD attributes instead. + use_controller = ( + has_controller + and fixed_frame_id is None + and not self.ignore_ppisp_controller + ) + if has_controller and fixed_frame_id is not None: + logger.info( + "PPISP controller present but fixed_frame_id is set; using static " + "exposure/color from frame %d instead of the controller.", + fixed_frame_id, + ) + elif has_controller and self.ignore_ppisp_controller: + logger.info( + "PPISP controller present but ignore_ppisp_controller is set; " + "exporting time-sampled exposure/color from optimized PPISP parameters " + "instead of the runtime controller." 
+ ) + + from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files, get_ppisp_spg_dyn_files + from threedgrut.export.usd.writers.ppisp_writer import ( + add_ppisp_to_all_render_products, + build_camera_frame_mapping, + ) + + _, camera_frame_mapping = build_camera_frame_mapping(dataset) + + try: + add_ppisp_to_all_render_products( + stage=stage, + ppisp=post_processing, + camera_names=camera_names, + camera_frame_mapping=camera_frame_mapping, + fixed_camera_index=fixed_camera_id, + fixed_frame_index=fixed_frame_id, + use_controller=use_controller, + ) + except Exception as e: + logger.warning(f"Failed to add PPISP shaders: {e}") + return + + if use_controller: + spg_files = list(get_ppisp_spg_dyn_files()) + from threedgrut.export.usd.writers.ppisp_controller_writer import ( + get_controller_sidecars, + ) + for s in get_controller_sidecars(): + if not any(f.filename == s.filename for f in spg_files): + spg_files.append(s) + else: + spg_files = get_ppisp_spg_files() + + for spg_file in spg_files: + if not any(f.filename == spg_file.filename for f in files): + files.append(spg_file) + + logger.info( + "PPISP Omniverse-native export complete: %d sidecar(s) added (controller=%s)", + len(files), + use_controller, + ) + @classmethod def from_config(cls, conf) -> "USDExporter": """ Create USDExporter from configuration. 
- - Args: - conf: Configuration object with export_usd section - - Returns: - Configured USDExporter instance """ export_conf = getattr(conf, "export_usd", None) or conf half_precision = getattr(export_conf, "half_precision", False) @@ -405,6 +844,109 @@ def from_config(cls, conf) -> "USDExporter": export_cameras=getattr(export_conf, "export_cameras", True), export_background=getattr(export_conf, "export_background", True), apply_normalizing_transform=getattr(export_conf, "apply_normalizing_transform", True), - sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), + sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT), linear_srgb=getattr(export_conf, "linear_srgb", False), + export_post_processing=_get_export_config_value( + export_conf, + "export-post-processing", + "export_post_processing", + True, + ), + post_processing_export_mode=_get_export_config_value( + export_conf, + "post-processing-export-mode", + "post_processing_export_mode", + MODE_POST_PROCESSING_EXPORT_BAKED_SH, + ), + post_processing_export_camera_id=_get_export_config_value( + export_conf, + "post-processing-export-camera-id", + "post_processing_export_camera_id", + None, + ), + post_processing_export_frame_id=_get_export_config_value( + export_conf, + "post-processing-export-frame-id", + "post_processing_export_frame_id", + None, + ), + ignore_ppisp_controller=_get_export_config_value( + export_conf, + "ignore-ppisp-controller", + "ignore_ppisp_controller", + False, + ), + post_processing_bake_epochs=_get_export_config_value( + export_conf, + "post-processing-bake-epochs", + "post_processing_bake_epochs", + 7, + ), + post_processing_bake_learning_rate=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate", + "post_processing_bake_learning_rate", + 2.5e-3, + ), + post_processing_bake_learning_rate_specular=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate-specular", + 
"post_processing_bake_learning_rate_specular", + None, + ), + post_processing_bake_learning_rate_density=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate-density", + "post_processing_bake_learning_rate_density", + 5.0e-2, + ), + post_processing_bake_camera_id=_get_export_config_value( + export_conf, + "post-processing-bake-camera-id", + "post_processing_bake_camera_id", + 0, + ), + post_processing_bake_frame_id=_get_export_config_value( + export_conf, + "post-processing-bake-frame-id", + "post_processing_bake_frame_id", + 0, + ), + ppisp_bake_vignetting_mode=_get_export_config_value( + export_conf, + "ppisp-bake-vignetting-mode", + "ppisp_bake_vignetting_mode", + MODE_PPISP_BAKE_VIGNETTING_NONE, + ), + post_processing_bake_view_mode=_get_export_config_value( + export_conf, + "post-processing-bake-view-mode", + "post_processing_bake_view_mode", + "training", + ), + post_processing_bake_view_seed=_get_export_config_value( + export_conf, + "post-processing-bake-view-seed", + "post_processing_bake_view_seed", + None, + ), + post_processing_bake_trajectory_weight_position=_get_export_config_value( + export_conf, + "post-processing-bake-trajectory-weight-position", + "post_processing_bake_trajectory_weight_position", + 1.0, + ), + post_processing_bake_trajectory_weight_rotation=_get_export_config_value( + export_conf, + "post-processing-bake-trajectory-weight-rotation", + "post_processing_bake_trajectory_weight_rotation", + 0.5, + ), + output_scale=_get_export_config_value( + export_conf, + "output-scale", + "output_scale", + 1.0, + ), + frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/particle_field_hints.py b/threedgrut/export/usd/particle_field_hints.py new file mode 100644 index 00000000..fc815162 --- /dev/null +++ b/threedgrut/export/usd/particle_field_hints.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ParticleField schema hint tokens supported by usd-core 26.5+.""" + +DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT = "cameraDistance" + +PARTICLE_FIELD_SORTING_MODE_HINTS = ( + "zDepth", + "cameraDistance", + "rayHitDistance", +) + + +def normalize_particle_field_sorting_mode_hint(value: str) -> str: + """Normalize and validate a ParticleField sortingModeHint token.""" + normalized = str(value).strip() + if normalized not in PARTICLE_FIELD_SORTING_MODE_HINTS: + raise ValueError( + f"Unsupported ParticleField sortingModeHint '{value}'. " + f"Expected one of: {list(PARTICLE_FIELD_SORTING_MODE_HINTS)}" + ) + return normalized diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py new file mode 100644 index 00000000..2bab8d6b --- /dev/null +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def scale_sh_output(model, scale: float) -> None:
    """Rescale, in place, the RGB that the model's SH coefficients evaluate to.

    The SH evaluation is
    ``rgb = features_albedo * C0 + 0.5 + sum_k Y_k * features_specular_k``.
    Multiplying that result by ``s`` therefore needs

    * ``features_specular <- s * features_specular`` (purely linear bands), and
    * ``features_albedo <- s * features_albedo + (s - 1) * 0.5 / C0`` so the
      constant ``0.5`` offset of the DC band is scaled as well.

    A ``scale`` equal to ``1.0`` is a no-op: nothing is touched or logged.
    """
    if scale == 1.0:
        return None

    factor = float(scale)
    with torch.no_grad():
        model.features_specular.mul_(factor)
        dc = model.features_albedo
        dc.mul_(factor)
        # Compensate the fixed 0.5 offset that SH evaluation adds after the DC term.
        dc.add_((factor - 1.0) * 0.5 / C0)
    logger.info("Scaled SH output by %.4f (DC offset compensated)", factor)
    return None
def _set_sh_fit_parameters(model) -> Iterable[torch.nn.Parameter]:
    """Freeze every parameter of ``model`` except its two SH feature tensors.

    All parameters first get ``requires_grad=False``; ``features_albedo`` and
    ``features_specular`` are then re-enabled, in that order.

    Returns:
        A list ``[features_albedo, features_specular]`` holding the very
        parameter objects owned by ``model``.
    """
    for frozen in model.parameters():
        frozen.requires_grad_(False)

    # requires_grad_ returns the tensor it was called on, so the list ends up
    # holding the model's own parameters.
    return [
        getattr(model, field_name).requires_grad_(True)
        for field_name in ("features_albedo", "features_specular")
    ]
+ + Three parameter groups are co-optimised, mirroring 3DGS training defaults: + + * ``features_albedo`` at ``learning_rate`` (default 2.5e-3) + * ``features_specular`` at ``learning_rate_specular`` (default = lr/20) + * ``density`` at ``learning_rate_density`` (default 5e-2) + + Letting density breathe absorbs spatial frequencies the SH alone can't + capture without aliasing -- on harder scenes (caterpillar) this is + worth +5 dB worst-case PSNR over fitting only colour coefficients. + + ``view_sampling_mode`` controls what the optimizer sees each step: + + * ``"training"`` (default) -- iterate the training dataloader as usual. + * ``"trajectory"`` -- order the training views along an approximate + Hamiltonian path (NN + 2-opt on a position+direction metric), + arc-length-parameterise the path on ``[0, 1]``, sample random + ``t ∈ [0, 1]``, slerp inside the bracketing segment. Helpful on + datasets with sparse view coverage. + + The trajectory mode synthesises a ``Batch`` per step from the + template of the first training batch, replacing ``T_to_world`` with + the interpolated pose. ``steps_per_epoch`` matches + ``len(train_dataloader)`` so total step count is unchanged. + """ + from threedgrut.export.usd.post_processing_view_interpolation import ( + InterpolatedViewSampler, + VIEW_SAMPLING_TRAINING, + normalize_view_sampling_mode, + ) + + if not hasattr(model, "clone"): + raise TypeError("Post-processing SH bake export requires a cloneable MixtureOfGaussians model.") + if train_dataset is None: + raise ValueError("Post-processing SH bake export requires a train dataset. 
Pass --dataset if it is missing.") + if post_processing is None: + raise ValueError("Post-processing SH bake export requires a post_processing module.") + if epochs < 1: + raise ValueError(f"epochs must be >= 1, got {epochs}.") + view_sampling_mode = normalize_view_sampling_mode(view_sampling_mode) + + adapter.validate(post_processing) + reference_model = model.to(device).eval() + reference_model.build_acc() + baked_model = model.clone().to(device).eval() + baked_model.build_acc() + fixed_post_processing = adapter.create_fixed_post_processing(post_processing, device) + + # Warm-start the cloned SH state with the adapter's closed-form bake. + # PPISPPostProcessingBakeAdapter writes display-referred (gamma-space) + # DC; Adam takes over from there. Reduces the iterations needed and + # avoids fitting from a checkpoint state far from the optimum. + adapter.initialize_fit(baked_model, post_processing) + + if learning_rate_specular is None: + learning_rate_specular = learning_rate / 20.0 # 3DGS default ratio + + _set_sh_fit_parameters(baked_model) + baked_model.density.requires_grad_(True) + optimizer = torch.optim.Adam([ + {"params": [baked_model.features_albedo], "lr": learning_rate}, + {"params": [baked_model.features_specular], "lr": learning_rate_specular}, + {"params": [baked_model.density], "lr": learning_rate_density}, + ]) + train_dataloader = _create_train_dataloader(conf, train_dataset) + steps_per_epoch = len(train_dataloader) + + sampler: InterpolatedViewSampler | None = None + if view_sampling_mode != VIEW_SAMPLING_TRAINING: + # Cache one real training batch to seed the synthetic sampler with + # valid intrinsics / rays / pixel coords; only T_to_world rotates + # per step. 
class FixedPPISP(nn.Module):
    """Pin a PPISP-like module to one camera/frame pair.

    The wrapped module is a deep copy moved to ``device`` and put in eval
    mode, so the instance passed in by the caller is never modified.
    """

    def __init__(
        self,
        ppisp: nn.Module,
        camera_id: int,
        frame_id: int,
        device: str,
        include_vignetting: bool = True,
    ) -> None:
        """Copy ``ppisp`` and fix the ids later used by :meth:`forward`.

        Args:
            ppisp: Module to wrap; it is deep-copied.
            camera_id: Camera index forwarded on every call.
            frame_id: Frame index forwarded on every call.
            device: Device the copy is moved to.
            include_vignetting: When false and the copy exposes
                ``vignetting_params``, that tensor is zeroed in place.
        """
        super().__init__()
        fixed_camera = int(camera_id)
        fixed_frame = int(frame_id)
        frozen = copy.deepcopy(ppisp).to(device).eval()

        # Switch off the copy's ``use_controller`` flag when it has one.
        if hasattr(getattr(frozen, "config", None), "use_controller"):
            frozen.config.use_controller = False

        if hasattr(frozen, "vignetting_params") and not include_vignetting:
            with torch.no_grad():
                frozen.vignetting_params.zero_()

        self.camera_id = fixed_camera
        self.frame_id = fixed_frame
        self.ppisp = frozen

    def forward(
        self,
        rgb: torch.Tensor,
        pixel_coords: torch.Tensor,
        resolution: tuple[int, int],
        camera_idx=None,
        frame_idx=None,
        exposure_prior=None,
    ) -> torch.Tensor:
        """Run the wrapped module with the fixed camera/frame ids.

        ``camera_idx``, ``frame_idx`` and ``exposure_prior`` are accepted for
        call compatibility and ignored; the wrapped module always receives the
        stored ids and ``exposure_prior=None``.
        """
        fixed_kwargs = {
            "resolution": resolution,
            "camera_idx": self.camera_id,
            "frame_idx": self.frame_id,
            "exposure_prior": None,
        }
        return self.ppisp(rgb, pixel_coords, **fixed_kwargs)
def estimate_achromatic_vignetting(
    ppisp: nn.Module,
    camera_id: int,
    pixel_coords: torch.Tensor,
    resolution: tuple[int, int],
) -> torch.Tensor:
    """Estimate luminance falloff from PPISP's chromatic camera vignette.

    For each of the three colour channels a radial polynomial
    ``1 + a1*r^2 + a2*r^4 + a3*r^6`` is evaluated around that channel's
    centre and clamped to ``[0, 1]``; the three results are then combined
    with the weights ``(0.2126, 0.7152, 0.0722)``.

    Args:
        ppisp: Object exposing ``vignetting_params``; the entry for
            ``camera_id`` is read as ``[channel, (cx, cy, a1, a2, a3)]``.
        camera_id: Index into ``vignetting_params``.
        pixel_coords: Tensor whose last axis holds ``(x, y)`` pixel positions.
        resolution: ``(width, height)`` of the image.

    Returns:
        Tensor shaped like ``pixel_coords[..., :1]`` holding the falloff.

    Raises:
        ValueError: If ``ppisp`` has no ``vignetting_params`` attribute.
    """
    if not hasattr(ppisp, "vignetting_params"):
        raise ValueError("PPISP-like module is missing vignetting_params.")

    width, height = resolution
    target = {"device": pixel_coords.device, "dtype": pixel_coords.dtype}
    camera_params = ppisp.vignetting_params[int(camera_id)].to(**target)

    # Both axes are centred on the image and divided by the WIDTH.
    # NOTE(review): confirm this matches PPISP's own normalisation for images
    # that are taller than they are wide.
    inv_scale = float(width)
    uv = torch.stack(
        [
            (pixel_coords[..., 0] - float(width) * 0.5) / inv_scale,
            (pixel_coords[..., 1] - float(height) * 0.5) / inv_scale,
        ],
        dim=-1,
    )

    def _falloff(params: torch.Tensor) -> torch.Tensor:
        # params = (cx, cy, a1, a2, a3) for one colour channel.
        offset = uv - params[0:2]
        r2 = torch.sum(offset * offset, dim=-1)
        poly = 1.0 + params[2] * r2 + params[3] * r2 * r2 + params[4] * r2 * r2 * r2
        return torch.clamp(poly, 0.0, 1.0)

    rgb_falloff = torch.stack([_falloff(camera_params[channel]) for channel in range(3)], dim=-1)
    luminance_weights = torch.tensor([0.2126, 0.7152, 0.0722], **target)
    return torch.sum(rgb_falloff * luminance_weights, dim=-1, keepdim=True)
def validate(self, post_processing: nn.Module) -> None:
    """Check that ``post_processing`` can serve the configured camera/frame.

    The module must expose ``exposure_params`` and ``crf_params``. Their
    leading dimensions are taken as the number of frames and cameras
    respectively, and ``self.frame_id`` / ``self.camera_id`` must index into
    them.

    Raises:
        ValueError: If either attribute is missing or an id is out of range.
    """
    required = ("exposure_params", "crf_params")
    if not all(hasattr(post_processing, attribute) for attribute in required):
        raise ValueError("PPISP SH bake export requires a PPISP-like post_processing module.")

    num_frames = int(post_processing.exposure_params.shape[0])
    num_cameras = int(post_processing.crf_params.shape[0])

    if not 0 <= self.frame_id < num_frames:
        raise ValueError(f"frame_id must be in [0, {num_frames - 1}], got {self.frame_id}.")
    if not 0 <= self.camera_id < num_cameras:
        raise ValueError(f"camera_id must be in [0, {num_cameras - 1}], got {self.camera_id}.")
Aligns the + baked-SH USD asset format with no-PPISP exports. + + The trained ``features_specular`` is left untouched: a higher-order + Jacobian rotation gives a slightly better starting PSNR but Adam + takes much longer to recover from the rotated specular (~7 dB at 9 + epochs on bonsai, see tools/ppisp_export benchmark). + """ + # Late import: avoid pulling ppisp into modules that don't need it. + from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake + + logger.info( + "PPISP SH bake init: applying simple_bake (camera=%d, frame=%d, " + "higher_order=False, apply_srgb_to_linear=False) before fitting.", + self.camera_id, self.frame_id, + ) + simple_bake( + baked_model, + post_processing, + camera_id=self.camera_id, + frame_id=self.frame_id, + higher_order=False, + apply_srgb_to_linear=False, + ) + + def log_context(self) -> str: + return f" camera={self.camera_id} frame={self.frame_id} vignetting={self.vignetting_mode}" diff --git a/threedgrut/export/usd/post_processing_sh_simple_bake.py b/threedgrut/export/usd/post_processing_sh_simple_bake.py new file mode 100644 index 00000000..a9267d80 --- /dev/null +++ b/threedgrut/export/usd/post_processing_sh_simple_bake.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _bake_dc_through_ppisp(
    dc_rgb_linear: torch.Tensor,
    ppisp: PPISP,
    camera_id: int,
    exposure: float,
    color: torch.Tensor,
) -> torch.Tensor:
    """Apply PPISP with no vignetting to each Gaussian DC RGB color.

    Every tensor handed to ``ppisp_apply`` is placed on the device and dtype
    of ``dc_rgb_linear``. That includes ``ppisp.crf_params``: the PPISP
    module passed in is not guaranteed to live on the same device as the
    model being baked, and mixing devices in one call would fail.

    Args:
        dc_rgb_linear: ``(num_gaussians, 3)`` DC colors, one row per Gaussian.
        ppisp: Module supplying ``vignetting_params`` (shape only; values are
            replaced by zeros) and ``crf_params``.
        camera_id: Camera index passed through as ``camera_idx``.
        exposure: Exposure value of the fixed frame.
        color: Color parameters of the fixed frame (no leading frame axis).

    Returns:
        Whatever ``ppisp_apply`` returns for the DC colors.
    """
    device = dc_rgb_linear.device
    dtype = dc_rgb_linear.dtype
    num_gaussians = dc_rgb_linear.shape[0]

    # Exposure and color are packed as a single-frame table, hence frame_idx=0 below.
    exposure_params = torch.tensor([exposure], device=device, dtype=dtype)
    color_params = color.to(device=device, dtype=dtype).unsqueeze(0)
    # All-zero vignetting coefficients: this bake deliberately excludes vignetting.
    vignetting_params = torch.zeros_like(ppisp.vignetting_params, device=device, dtype=dtype)
    # ``.to`` returns the same tensor when device and dtype already match.
    crf_params = ppisp.crf_params.to(device=device, dtype=dtype)
    # Gaussians have no pixel position; a dummy 1x1 image with zero coords is used.
    pixel_coords = torch.zeros(num_gaussians, 2, device=device, dtype=dtype)

    return ppisp_apply(
        exposure_params=exposure_params,
        vignetting_params=vignetting_params,
        color_params=color_params,
        crf_params=crf_params,
        rgb_in=dc_rgb_linear.contiguous(),
        pixel_coords=pixel_coords,
        resolution_w=1,
        resolution_h=1,
        camera_idx=camera_id,
        frame_idx=0,
    )
def simple_bake(
    model,
    ppisp: PPISP,
    camera_id: int,
    frame_id: int,
    higher_order: bool = False,
    apply_srgb_to_linear: bool = False,
) -> Tuple[float, torch.Tensor]:
    """Overwrite the model's SH coefficients with one fixed PPISP transform.

    The DC colour of every Gaussian (``SH2RGB(features_albedo)``) is pushed
    through PPISP for the given camera/frame and written back with
    ``RGB2SH``. With ``higher_order=True`` the per-Gaussian 3x3 Jacobian of
    that mapping is additionally applied to ``features_specular``; otherwise
    ``features_specular`` is left untouched.

    Per the original rationale, PPISP output is display-referred, so storing
    it as-is in SH meant to be linear leaves the asset double-encoded once a
    consumer applies its own linear-to-sRGB step. ``apply_srgb_to_linear=True``
    runs ``srgb_to_linear`` on the PPISP output before ``RGB2SH`` to avoid that.

    Args:
        model: Object exposing ``features_albedo`` and ``features_specular``;
            both may be modified in place.
        ppisp: PPISP module providing the transform.
        camera_id: Camera whose response is baked.
        frame_id: Frame whose exposure/color parameters are baked.
        higher_order: Also transform ``features_specular`` by the Jacobian.
        apply_srgb_to_linear: Convert PPISP output with ``srgb_to_linear``.

    Returns:
        ``(exposure, color)`` as returned by ``get_fixed_frame_params``.
    """
    exposure, color = get_fixed_frame_params(ppisp, frame_id)
    fixed_frame = {
        "ppisp": ppisp,
        "camera_id": camera_id,
        "exposure": exposure,
        "color": color,
    }

    if not higher_order:
        # DC-only path: no gradients are needed anywhere.
        with torch.no_grad():
            source_rgb = SH2RGB(model.features_albedo.detach())
            encoded_rgb = _bake_dc_through_ppisp(dc_rgb_linear=source_rgb, **fixed_frame)
            if apply_srgb_to_linear:
                encoded_rgb = srgb_to_linear(encoded_rgb)
            model.features_albedo.copy_(RGB2SH(encoded_rgb))
        return exposure, color

    # The Jacobian is obtained through autograd, so grad mode must be on even
    # if the caller disabled it.
    with torch.enable_grad():
        source_rgb = SH2RGB(model.features_albedo).detach()
        baked_rgb, rgb_jacobian = _bake_dc_with_jacobian_through_ppisp(
            dc_rgb_linear=source_rgb,
            apply_srgb_to_linear=apply_srgb_to_linear,
            **fixed_frame,
        )

    with torch.no_grad():
        # baked_rgb already went through srgb_to_linear when that was
        # requested, so it can be converted to SH directly.
        model.features_albedo.copy_(RGB2SH(baked_rgb))
        _apply_jacobian_to_specular(model.features_specular, rgb_jacobian)

    return exposure, color
The ``trajectory`` sampler +orders the training views along a smooth path (nearest-neighbour + 2-opt +on position+direction), arc-length-parameterises the path on ``[0, 1]``, +then samples random ``t in [0, 1]`` and slerps inside the bracketing +segment. Useful when training views are sparse and a residual fit needs +to generalise to nearby novel views. + +The sampler reuses the dataset's per-intrinsic camera-space rays and +pixel-coordinate grid -- only ``T_to_world`` changes per sample. PPISP's +``FixedPPISP`` ignores the per-frame indices on the synthetic batch, so +camera/frame indices on the template are kept as-is. +""" + +from __future__ import annotations + +import logging +import math +from dataclasses import replace +from typing import Iterator, List, Optional, Tuple + +import numpy as np +import torch + +from threedgrut.datasets.protocols import Batch + +logger = logging.getLogger(__name__) + + +VIEW_SAMPLING_TRAINING = "training" +VIEW_SAMPLING_TRAJECTORY = "trajectory" +VIEW_SAMPLING_MODES = { + VIEW_SAMPLING_TRAINING, + VIEW_SAMPLING_TRAJECTORY, +} + + +def normalize_view_sampling_mode(mode: Optional[str]) -> str: + normalized = VIEW_SAMPLING_TRAINING if mode is None else str(mode).strip().lower() + if normalized not in VIEW_SAMPLING_MODES: + raise ValueError( + f"Unsupported view sampling mode '{mode}'. 
def _R_to_quat(R: np.ndarray) -> np.ndarray:
    """Convert a 3x3 rotation matrix to a unit quaternion ``[w, x, y, z]``.

    Uses the branch structure the original labels as Shepperd's method: when
    the trace is positive ``w`` is recovered first, otherwise the largest
    diagonal entry decides which of ``x``, ``y``, ``z`` is recovered first.
    The result is normalised before being returned.
    """
    rot = np.asarray(R, dtype=np.float64)
    m00, m11, m22 = rot[0, 0], rot[1, 1], rot[2, 2]
    trace = m00 + m11 + m22

    if trace > 0.0:
        s = 2.0 * math.sqrt(trace + 1.0)
        components = (
            0.25 * s,
            (rot[2, 1] - rot[1, 2]) / s,
            (rot[0, 2] - rot[2, 0]) / s,
            (rot[1, 0] - rot[0, 1]) / s,
        )
    elif m00 > m11 and m00 > m22:
        s = 2.0 * math.sqrt(1.0 + m00 - m11 - m22)
        components = (
            (rot[2, 1] - rot[1, 2]) / s,
            0.25 * s,
            (rot[0, 1] + rot[1, 0]) / s,
            (rot[0, 2] + rot[2, 0]) / s,
        )
    elif m11 > m22:
        s = 2.0 * math.sqrt(1.0 + m11 - m00 - m22)
        components = (
            (rot[0, 2] - rot[2, 0]) / s,
            (rot[0, 1] + rot[1, 0]) / s,
            0.25 * s,
            (rot[1, 2] + rot[2, 1]) / s,
        )
    else:
        s = 2.0 * math.sqrt(1.0 + m22 - m00 - m11)
        components = (
            (rot[1, 0] - rot[0, 1]) / s,
            (rot[0, 2] + rot[2, 0]) / s,
            (rot[1, 2] + rot[2, 1]) / s,
            0.25 * s,
        )

    quat = np.array(components, dtype=np.float64)
    return quat / np.linalg.norm(quat)
def _pose_distance_matrix(
    poses: np.ndarray,
    weight_position: float,
    weight_rotation: float,
) -> np.ndarray:
    """Return the pairwise pose distance matrix used to order views.

    ``D[i, j] = weight_position * |p_i - p_j| / mean_pos
    + weight_rotation * (1 - cos(angle between forward axes))``

    where ``mean_pos`` is the mean of the strictly positive pairwise position
    distances (``1.0`` when all positions coincide). Dividing by it puts the
    position term on a scale comparable to the rotation term, which lies in
    ``[0, 2]``.

    Args:
        poses: ``(N, 4, 4)`` pose matrices; translation is read from column 3
            and the forward axis from column 2.
        weight_position: Weight of the normalised position term.
        weight_rotation: Weight of the forward-direction term.

    Returns:
        ``(N, N)`` symmetric matrix with a zero diagonal.
    """
    pos = poses[:, :3, 3]  # (N, 3)
    fwd = poses[:, :3, 2]  # (N, 3) RDF: +Z = forward
    # Guard against a degenerate (zero-length) forward axis.
    fwd = fwd / np.maximum(np.linalg.norm(fwd, axis=1, keepdims=True), 1e-12)

    d_pos = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=2)
    positive = d_pos > 0
    mean_pos = max(float(d_pos[positive].mean()) if positive.any() else 1.0, 1e-9)

    # Clip so rounding cannot push the cosine outside [-1, 1].
    d_rot = 1.0 - np.clip(fwd @ fwd.T, -1.0, 1.0)

    return weight_position * (d_pos / mean_pos) + weight_rotation * d_rot
def _two_opt(order: List[int], D: np.ndarray, max_passes: int = 50) -> List[int]:
    """Improve a path with 2-opt segment reversals, mutating ``order`` in place.

    The first and last entries of ``order`` are never moved. A pass tries
    every interior segment ``order[i..j]`` and reverses it as soon as doing so
    shortens the two edges touching the segment. Passes repeat until one
    brings no improvement or ``max_passes`` have run. Paths with fewer than
    four nodes are returned unchanged.

    Returns:
        The same list object that was passed in.
    """
    count = len(order)
    if count < 4:
        return order

    for _pass in range(max_passes):
        changed = False
        for first in range(1, count - 2):
            # order[first - 1] sits outside every segment reversed below, so
            # it is constant for the whole inner loop.
            before = order[first - 1]
            for last in range(first + 1, count - 1):
                seg_head, seg_tail = order[first], order[last]
                after = order[last + 1]
                current = D[before, seg_head] + D[seg_tail, after]
                candidate = D[before, seg_tail] + D[seg_head, after]
                # The epsilon keeps floating-point ties from counting as gains.
                if candidate + 1e-12 < current:
                    order[first:last + 1] = order[last:first - 1:-1]
                    changed = True
        if not changed:
            break
    return order
def order_views_along_trajectory(
    poses: np.ndarray,
    *,
    weight_position: float = 1.0,
    weight_rotation: float = 0.5,
    start_index: int = 0,
    two_opt_passes: int = 50,
) -> Tuple[List[int], np.ndarray]:
    """Order ``poses`` along an approximate Hamiltonian path.

    The path is seeded by a nearest-neighbour walk that starts at
    ``start_index`` and is then refined with 2-opt; both steps use the
    pairwise matrix returned by ``_pose_distance_matrix``.

    Args:
        poses: array-like convertible to a ``(N, 4, 4)`` float64 array.
        weight_position: forwarded to ``_pose_distance_matrix``.
        weight_rotation: forwarded to ``_pose_distance_matrix``.
        start_index: pose the path starts from. Negative values count from
            the end and are normalised, so the returned order only holds
            indices in ``[0, N)``.
        two_opt_passes: upper bound on the number of 2-opt sweeps.

    Returns:
        ``(ordered_indices, cum_t)`` where ``cum_t[k]`` is the normalised
        arc-length parameter at the k-th ordered pose. ``cum_t[0] = 0`` and,
        whenever the path has non-zero length, ``cum_t[-1] = 1``. If every
        consecutive distance is zero, ``cum_t`` is all zeros. For ``N < 2``
        the identity order is returned with an all-zero ``cum_t`` of length
        ``max(N, 1)``.

    Raises:
        ValueError: if ``poses`` is not shaped ``(N, 4, 4)``.
        IndexError: if ``start_index`` is outside ``[-N, N)``.
    """
    poses = np.asarray(poses, dtype=np.float64)
    if poses.ndim != 3 or poses.shape[-2:] != (4, 4):
        raise ValueError(f"poses must be (N, 4, 4), got {poses.shape}")
    n = poses.shape[0]
    if n < 2:
        return list(range(n)), np.zeros(max(n, 1), dtype=np.float64)

    # Reject out-of-range seeds with a clear message, and fold negative
    # seeds so a negative alias never appears in the returned order.
    if not -n <= start_index < n:
        raise IndexError(f"start_index {start_index} is out of range for {n} poses")
    if start_index < 0:
        start_index += n

    D = _pose_distance_matrix(poses, weight_position, weight_rotation)
    order = _nearest_neighbour_order(D, start=start_index)
    order = _two_opt(order, D, max_passes=two_opt_passes)

    # Cumulative length of consecutive hops along the final order.
    idx = np.asarray(order)
    cum = np.zeros(n, dtype=np.float64)
    cum[1:] = np.cumsum(D[idx[:-1], idx[1:]], dtype=np.float64)
    if cum[-1] > 0:
        cum = cum / cum[-1]
    return order, cum


# ---------------------------------------------------------------------------
# Sampler driving the fit loop
# ---------------------------------------------------------------------------


class InterpolatedViewSampler:
    """Yields ``Batch`` objects with synthetic interpolated poses.

    Every emitted batch is a copy of ``template_gpu_batch`` in which only
    ``T_to_world`` and ``T_to_world_end`` are replaced. Both are set to the
    same sampled pose -- there is no rolling shutter on synthetic poses.
    All other fields of the template are reused unchanged.

    Args:
        train_dataset: must expose ``get_poses()`` returning ``(N, 4, 4)``
            poses with ``N >= 2``.
        template_gpu_batch: ``Batch`` supplied by the caller; its non-pose
            fields are reused for every sample.
            NOTE(review): presumably built with the dataset's
            ``get_gpu_batch_with_intrinsics`` -- confirm against the caller.
        mode: view-sampling mode, passed through
            ``normalize_view_sampling_mode``. The training mode is rejected.
        steps_per_epoch: how many synthetic batches to emit per pass.
        seed: optional RNG seed for reproducibility.
        weight_position / weight_rotation: trajectory distance weights.
        start_index: trajectory NN seed index.
        two_opt_passes: upper bound on 2-opt sweeps when building the
            trajectory.

    Raises:
        ValueError: for the training mode, a negative ``steps_per_epoch``,
            wrongly shaped poses, or fewer than two poses.
        TypeError: if the dataset lacks ``get_poses`` or the template is not
            a ``Batch``.
    """

    def __init__(
        self,
        train_dataset,
        template_gpu_batch: Batch,
        mode: str,
        steps_per_epoch: int,
        *,
        seed: Optional[int] = None,
        weight_position: float = 1.0,
        weight_rotation: float = 0.5,
        start_index: int = 0,
        two_opt_passes: int = 50,
    ) -> None:
        mode = normalize_view_sampling_mode(mode)
        if mode == VIEW_SAMPLING_TRAINING:
            raise ValueError("InterpolatedViewSampler is only for non-training modes.")
        if not hasattr(train_dataset, "get_poses"):
            raise TypeError(
                "InterpolatedViewSampler requires a dataset exposing get_poses(); "
                f"got {type(train_dataset).__name__}."
            )
        if not isinstance(template_gpu_batch, Batch):
            raise TypeError("template_gpu_batch must be a threedgrut Batch instance.")
        steps = int(steps_per_epoch)
        if steps < 0:
            # A negative count would otherwise only fail later inside len().
            raise ValueError(f"steps_per_epoch must be >= 0, got {steps_per_epoch!r}")
        self.dataset = train_dataset
        self.mode = mode
        self.steps_per_epoch = steps
        self._rng = np.random.default_rng(seed)
        self._template = template_gpu_batch

        poses = np.asarray(train_dataset.get_poses(), dtype=np.float64)
        if poses.ndim != 3 or poses.shape[-2:] != (4, 4):
            raise ValueError(f"dataset.get_poses() must be (N, 4, 4), got {poses.shape}")
        if poses.shape[0] < 2:
            raise ValueError("Need at least 2 training views to interpolate.")
        self._poses = poses

        self._ordered_indices, self._cum_t = order_views_along_trajectory(
            poses,
            weight_position=weight_position,
            weight_rotation=weight_rotation,
            start_index=start_index,
            two_opt_passes=two_opt_passes,
        )
        logger.info(
            "Built %d-view trajectory (NN + 2-opt) for SH-bake interpolation.",
            len(self._ordered_indices),
        )

    # ------------------------------------------------------------------
    # Pose sampling
    # ------------------------------------------------------------------

    def _sample_pose(self) -> np.ndarray:
        """Draw a uniform arc-length parameter and interpolate the pose there."""
        t = float(self._rng.random())
        cum = self._cum_t
        # Find segment k s.t. cum[k-1] <= t <= cum[k] (with cum[0]=0).
        k = int(np.searchsorted(cum, t, side="left"))
        k = max(1, min(k, len(cum) - 1))
        denom = max(cum[k] - cum[k - 1], 1e-12)
        local_s = float((t - cum[k - 1]) / denom)
        # On a zero-length trajectory cum is all zeros and the ratio above is
        # t / 1e-12; clamp so the interpolation never extrapolates.
        local_s = min(max(local_s, 0.0), 1.0)
        a = self._ordered_indices[k - 1]
        b = self._ordered_indices[k]
        return slerp_pose(self._poses[a], self._poses[b], local_s)

    # ------------------------------------------------------------------
    # Batch construction
    # ------------------------------------------------------------------

    def _make_batch(self, pose_np: np.ndarray) -> Batch:
        """Return a copy of the template batch that carries ``pose_np``."""
        device = self._template.T_to_world.device
        dtype = self._template.T_to_world.dtype
        # unsqueeze(0) adds a leading axis of size one in front of the pose.
        T = torch.from_numpy(pose_np).to(device=device, dtype=dtype).unsqueeze(0)
        # Same pose for start and end -- no rolling shutter on synthetic views.
        return replace(self._template, T_to_world=T, T_to_world_end=T)

    # ------------------------------------------------------------------
    # Iterator protocol
    # ------------------------------------------------------------------

    def __iter__(self) -> Iterator[Batch]:
        """Yield ``steps_per_epoch`` freshly sampled synthetic batches."""
        for _ in range(self.steps_per_epoch):
            yield self._make_batch(self._sample_pose())

    def __len__(self) -> int:
        """Number of batches produced by one pass over the sampler."""
        return self.steps_per_epoch
def _load_files(filenames) -> List[NamedSerialized]:
    """Read the named sidecar files that live next to this module.

    Loading is best-effort: a file that is missing is reported with a
    warning and skipped, so the returned list may be shorter than
    ``filenames``. Entries keep the order of ``filenames``.

    Args:
        filenames: iterable of file names relative to this package directory.

    Returns:
        One ``NamedSerialized`` per file that could be read, holding the
        file name and its raw bytes.
    """
    result: List[NamedSerialized] = []
    for filename in filenames:
        path = _SPG_DIR / filename
        # Read first and handle absence afterwards: an exists()-then-read
        # sequence can still fail if the file disappears in between.
        try:
            payload = path.read_bytes()
        except FileNotFoundError:
            log.warning("PPISP SPG sidecar not found: %s", path)
            continue
        result.append(NamedSerialized(filename=filename, serialized=payload))
        log.debug("Loaded PPISP SPG sidecar: %s", filename)
    return result


def get_ppisp_spg_files() -> List[NamedSerialized]:
    """Load static-parameter PPISP SPG sidecar files (controller-free path)."""
    return _load_files(_SPG_STATIC_FILES)


def get_ppisp_spg_dyn_files() -> List[NamedSerialized]:
    """Load controller-aware PPISP SPG sidecar files.

    These accompany the PPISP controller shader and read ``exposureOffset``
    and the colour latents from the controller output.
    """
    return _load_files(_SPG_DYN_FILES)
+// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +// PPISP Controller SPG Shader. +// +// Generic compute shader: weights are bound at dispatch time as a flat +// ``StructuredBuffer``. Per-camera variation lives entirely in +// USD attributes; the Slang/Lua/USDA assets are shared. +// +// Architecture mirrors ppisp._PPISPController (default config): +// +// Conv1x1(3->16, +bias) +// MaxPool 3x3 stride 3 +// ReLU +// Conv1x1(16->32, +bias) +// ReLU +// Conv1x1(32->64, +bias) +// AdaptiveAvgPool2d((5,5)) +// Flatten -> 1600 +// concat prior_exposure -> 1601 +// MLP: 1601 -> 128 -> 128 -> 128, ReLU after each hidden layer +// exposure_head: 128 -> 1 +// color_head: 128 -> 8 +// +// Output texture (1x9 float, RWTexture2D): +// pixel (0,0): exposureOffset +// pixel (1,0)..(8,0): color latents +// [colorBlue.x, colorBlue.y, +// colorRed.x, colorRed.y, +// colorGreen.x, colorGreen.y, +// colorNeutral.x, colorNeutral.y] + +// --------------------------------------------------------------------------- +// Architecture sizes (must match ``_PPISPController`` defaults). 
+// --------------------------------------------------------------------------- +static const int CNN_FEATURE_DIM = 64; +static const int POOL_GRID_H = 5; +static const int POOL_GRID_W = 5; +static const int POOL_CELL_COUNT = POOL_GRID_H * POOL_GRID_W; // 25 +static const int POOL_FEATURE_LEN = POOL_CELL_COUNT * CNN_FEATURE_DIM; // 1600 +static const int MLP_INPUT_DIM = POOL_FEATURE_LEN + 1; // 1601 +static const int MLP_HIDDEN_DIM = 128; +static const int COLOR_PARAMS_PER_FRAME = 8; +static const int INPUT_DOWNSAMPLING = 3; +static const int THREAD_GROUP_SIZE = 32; + +// --------------------------------------------------------------------------- +// Weight buffer offsets (the Python writer flattens weights in this order +// into a single float buffer that gets bound as weights). +// --------------------------------------------------------------------------- +static const int OFF_CONV1_W = 0; // 16 * 3 = 48 +static const int OFF_CONV1_B = OFF_CONV1_W + 16 * 3; // + 16 = 64 +static const int OFF_CONV2_W = OFF_CONV1_B + 16; // + 32 * 16 = 576 +static const int OFF_CONV2_B = OFF_CONV2_W + 32 * 16; // + 32 = 608 +static const int OFF_CONV3_W = OFF_CONV2_B + 32; // + 64 * 32 = 2656 +static const int OFF_CONV3_B = OFF_CONV3_W + 64 * 32; // + 64 = 2720 +static const int OFF_TRUNK0_W = OFF_CONV3_B + 64; // + 128 * 1601 = 207648 +static const int OFF_TRUNK0_B = OFF_TRUNK0_W + 128 * MLP_INPUT_DIM; +static const int OFF_TRUNK1_W = OFF_TRUNK0_B + 128; +static const int OFF_TRUNK1_B = OFF_TRUNK1_W + 128 * 128; +static const int OFF_TRUNK2_W = OFF_TRUNK1_B + 128; +static const int OFF_TRUNK2_B = OFF_TRUNK2_W + 128 * 128; +static const int OFF_EXP_W = OFF_TRUNK2_B + 128; +static const int OFF_EXP_B = OFF_EXP_W + 128; +static const int OFF_COL_W = OFF_EXP_B + 1; +static const int OFF_COL_B = OFF_COL_W + 8 * 128; +static const int TOTAL_WEIGHTS = OFF_COL_B + 8; + +// --------------------------------------------------------------------------- +// Bindings +// 
// SPG resolves USD ``inputs:foo`` attributes against fields of the slang
// ParameterBlock -- its reflection lookup is ``params:foo``. Putting
// ``weights`` directly inside the ParameterBlock struct lets SPG's
// auto-binding find it by name, and silences the per-dispatch warning
// "Failed to find parameter 'params:weights' in shader reflection".
//
// NOTE(review): the generic type arguments below (<float>, <float4>,
// <PPISPControllerParams>) were reconstructed from how each resource is
// used in this file; confirm them against the authored shader.
struct PPISPControllerParams
{
    float priorExposure;
    StructuredBuffer<float> weights;
};

[[vk::binding(0, 1)]] ParameterBlock<PPISPControllerParams> g_Params;
[[vk::binding(1, 1)]] Texture2D<float4> g_InTex;
[[vk::binding(2, 1)]] RWTexture2D<float> g_OutTex;

// ---------------------------------------------------------------------------
// Per-pixel CNN building blocks
// ---------------------------------------------------------------------------

// Conv1x1 3 -> 16 with bias. Weights are indexed as [out * 3 + in].
void conv1Forward(float3 rgb, out float feat[16])
{
    [unroll] for (int o = 0; o < 16; ++o)
    {
        float v = g_Params.weights[OFF_CONV1_B + o];
        v += rgb.r * g_Params.weights[OFF_CONV1_W + o * 3 + 0];
        v += rgb.g * g_Params.weights[OFF_CONV1_W + o * 3 + 1];
        v += rgb.b * g_Params.weights[OFF_CONV1_W + o * 3 + 2];
        feat[o] = v;
    }
}

// Conv1x1 16 -> 32 with bias. Weights are indexed as [out * 16 + in].
void conv2Forward(float fin[16], out float fout[32])
{
    [unroll] for (int o = 0; o < 32; ++o)
    {
        float v = g_Params.weights[OFF_CONV2_B + o];
        [unroll] for (int i = 0; i < 16; ++i)
            v += fin[i] * g_Params.weights[OFF_CONV2_W + o * 16 + i];
        fout[o] = v;
    }
}

// Conv1x1 32 -> 64 with bias. Weights are indexed as [out * 32 + in].
void conv3Forward(float fin[32], out float fout[64])
{
    [unroll] for (int o = 0; o < CNN_FEATURE_DIM; ++o)
    {
        float v = g_Params.weights[OFF_CONV3_B + o];
        [unroll] for (int i = 0; i < 32; ++i)
            v += fin[i] * g_Params.weights[OFF_CONV3_W + o * 32 + i];
        fout[o] = v;
    }
}

// Runs conv1 -> max-pool -> ReLU -> conv2 -> ReLU -> conv3 for one pixel
// (dx, dy) of the downsampled grid. The pooling window is the
// INPUT_DOWNSAMPLING x INPUT_DOWNSAMPLING block of input pixels starting at
// (dx, dy) * INPUT_DOWNSAMPLING, clipped to the input size.
void cnnForwardAtDownsampledPixel(
    int inW,
    int inH,
    int dx,
    int dy,
    out float feat64[64])
{
    int x0 = dx * INPUT_DOWNSAMPLING;
    int y0 = dy * INPUT_DOWNSAMPLING;
    int x1 = min(x0 + INPUT_DOWNSAMPLING, inW);
    int y1 = min(y0 + INPUT_DOWNSAMPLING, inH);

    // Start the running maximum from a very large negative float so the
    // first sample always replaces it.
    float pooled[16];
    [unroll] for (int c = 0; c < 16; ++c)
        pooled[c] = -3.402823e+38;

    for (int yy = y0; yy < y1; ++yy)
    {
        for (int xx = x0; xx < x1; ++xx)
        {
            float4 sample = g_InTex.Load(int3(xx, yy, 0));
            float conv1Out[16];
            conv1Forward(sample.rgb, conv1Out);
            [unroll] for (int c = 0; c < 16; ++c)
                pooled[c] = max(pooled[c], conv1Out[c]);
        }
    }

    // ReLU after the max-pool.
    [unroll] for (int c = 0; c < 16; ++c)
        pooled[c] = max(0.0, pooled[c]);

    float feat32[32];
    conv2Forward(pooled, feat32);
    [unroll] for (int c = 0; c < 32; ++c)
        feat32[c] = max(0.0, feat32[c]);

    // No ReLU after conv3: its output feeds the average pool directly.
    conv3Forward(feat32, feat64);
}

// Averages the 64-channel CNN output over one cell (gx, gy) of the
// POOL_GRID_W x POOL_GRID_H adaptive pooling grid. Cell bounds use
// floor(start) / ceil(end), so neighbouring cells may overlap when the
// downsampled size is not a multiple of the grid size.
void adaptiveCellAverage(
    int inW,
    int inH,
    int dsW,
    int dsH,
    int gx,
    int gy,
    out float cellFeat[64])
{
    int hStart = (gy * dsH) / POOL_GRID_H;
    int hEnd = ((gy + 1) * dsH + POOL_GRID_H - 1) / POOL_GRID_H;
    int wStart = (gx * dsW) / POOL_GRID_W;
    int wEnd = ((gx + 1) * dsW + POOL_GRID_W - 1) / POOL_GRID_W;
    hEnd = min(hEnd, dsH);
    wEnd = min(wEnd, dsW);

    [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c)
        cellFeat[c] = 0.0;

    int count = 0;
    for (int dy = hStart; dy < hEnd; ++dy)
    {
        for (int dx = wStart; dx < wEnd; ++dx)
        {
            float feat64[CNN_FEATURE_DIM];
            cnnForwardAtDownsampledPixel(inW, inH, dx, dy, feat64);
            [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c)
                cellFeat[c] += feat64[c];
            count += 1;
        }
    }

    // An empty cell yields zeros instead of dividing by zero.
    float invCount = (count > 0) ? (1.0 / float(count)) : 0.0;
    [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c)
        cellFeat[c] *= invCount;
}

groupshared float gsPooled[POOL_FEATURE_LEN]; // 1600 floats
groupshared float gsHiddenA[MLP_HIDDEN_DIM];  // 128 floats
groupshared float gsHiddenB[MLP_HIDDEN_DIM];  // 128 floats

// Entry point. One thread group of THREAD_GROUP_SIZE threads evaluates the
// whole controller and writes nine floats to g_OutTex:
// exposureOffset at (0, 0), then the eight colour latents at (1..8, 0).
[shader("compute")]
[numthreads(THREAD_GROUP_SIZE, 1, 1)]
void controllerProcess(uint3 gtid : SV_GroupThreadID)
{
    uint inW = 0, inH = 0;
    g_InTex.GetDimensions(inW, inH);

    // Downsampled grid size, at least 1 x 1. The casts make the
    // uint -> int conversion explicit.
    int dsW = int(max(1u, inW / uint(INPUT_DOWNSAMPLING)));
    int dsH = int(max(1u, inH / uint(INPUT_DOWNSAMPLING)));

    // Phase 1: pool cells. With THREAD_GROUP_SIZE=32 threads and 25 cells,
    // only the first 25 are active in this phase.
    //
    // Layout note: PyTorch's nn.Flatten on the [N, C, H, W] CNN output
    // produces a *channel-major* flat layout -- feat[c * H*W + h*W + w].
    // The trunk0 weight matrix was trained against that layout, so
    // gsPooled MUST be stored channel-major as well, i.e.
    // gsPooled[c * POOL_CELL_COUNT + cell].
    // (cell-major would silently permute every controller output.)
    int cell = int(gtid.x);
    if (cell < POOL_CELL_COUNT)
    {
        int gy = cell / POOL_GRID_W;
        int gx = cell % POOL_GRID_W;

        float cellFeat[CNN_FEATURE_DIM];
        adaptiveCellAverage(int(inW), int(inH), dsW, dsH, gx, gy, cellFeat);

        [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c)
            gsPooled[c * POOL_CELL_COUNT + cell] = cellFeat[c];
    }
    GroupMemoryBarrierWithGroupSync();

    // Phase 2: trunk0 (1601 -> 128). 128 output rows are distributed
    // across the THREAD_GROUP_SIZE threads. The last input column of each
    // row multiplies priorExposure.
    for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE)
    {
        float v = g_Params.weights[OFF_TRUNK0_B + o];
        for (int i = 0; i < POOL_FEATURE_LEN; ++i)
            v += gsPooled[i] * g_Params.weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i];
        v += g_Params.priorExposure
             * g_Params.weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN];
        gsHiddenA[o] = max(0.0, v);
    }
    GroupMemoryBarrierWithGroupSync();

    // Phase 3: trunk1 (128 -> 128). gsHiddenA -> gsHiddenB.
    for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE)
    {
        float v = g_Params.weights[OFF_TRUNK1_B + o];
        for (int i = 0; i < MLP_HIDDEN_DIM; ++i)
            v += gsHiddenA[i] * g_Params.weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i];
        gsHiddenB[o] = max(0.0, v);
    }
    GroupMemoryBarrierWithGroupSync();

    // Phase 4: trunk2 (128 -> 128). gsHiddenB -> gsHiddenA.
    for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE)
    {
        float v = g_Params.weights[OFF_TRUNK2_B + o];
        for (int i = 0; i < MLP_HIDDEN_DIM; ++i)
            v += gsHiddenB[i] * g_Params.weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i];
        gsHiddenA[o] = max(0.0, v);
    }
    GroupMemoryBarrierWithGroupSync();

    // Phase 5: heads. Neither head applies an activation. Thread 0 writes
    // the exposure; threads 0..7 each write one colour latent.
    if (gtid.x == 0)
    {
        float v = g_Params.weights[OFF_EXP_B];
        for (int i = 0; i < MLP_HIDDEN_DIM; ++i)
            v += gsHiddenA[i] * g_Params.weights[OFF_EXP_W + i];
        g_OutTex[int2(0, 0)] = v;
    }
    if (gtid.x < uint(COLOR_PARAMS_PER_FRAME))
    {
        int o = int(gtid.x);
        float v = g_Params.weights[OFF_COL_B + o];
        for (int i = 0; i < MLP_HIDDEN_DIM; ++i)
            v += gsHiddenA[i] * g_Params.weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i];
        g_OutTex[int2(1 + o, 0)] = v;
    }
}
-- Bind the controller weight buffer using whichever buffer-helper SPG's
-- slang lua API exposes. The trained weights live as a USD float[]
-- attribute (params["weights"]) and the slang shader reads them as a
-- read-only StructuredBuffer.
local function bind_weights(w)
    -- Probe a long list of plausible names, in priority order. The first
    -- one that resolves to a non-nil value wins.
    local candidates = {
        "StructuredBuffer", "RWStructuredBuffer",
        "Buffer", "RWBuffer",
        "ByteAddressBuffer", "RWByteAddressBuffer",
        "ConstantBuffer",
        "buffer", "Array", "array",
        "float_array", "FloatArray", "floatArray",
        "FloatBuffer", "floatBuffer",
        "image", "Image",
        "uniform", "Uniform",
        "list", "List",
    }
    for _, name in ipairs(candidates) do
        local helper = slang[name]
        if helper ~= nil then
            return helper(w)
        end
    end
    -- No buffer helper. Report every direct slang.* key plus every
    -- candidate that was tried; the message goes to Kit's log.
    -- pairs() is only valid on a plain table, so guard it: otherwise the
    -- enumeration itself would raise and hide the diagnostic below.
    local direct = {}
    if type(slang) == "table" then
        for k, _ in pairs(slang) do
            direct[#direct + 1] = tostring(k)
        end
        table.sort(direct)
    end
    error("ppisp_controller: no slang buffer-binding helper found. " ..
          "Tried: " .. table.concat(candidates, ",") ..
          " | direct keys = " .. table.concat(direct, ","))
end

-- SPG entry point for the controller shader. Allocates the 1x9 output
-- image and dispatches a single thread group.
function controllerProcess(inputs, outputs, params)
    local in_rgba = inputs["HdrColor"]
    assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture")

    local weights = params["weights"]
    assert(weights, "controllerProcess needs the inputs:weights attribute")

    -- 1x9 single-channel float image holding [exposure, color latents].
    outputs["ControllerParams"] = slang.empty({ 1, 9 }, slang.float)

    return slang.dispatch({
        stage = "compute",
        -- Must stay equal to THREAD_GROUP_SIZE in ppisp_controller.slang.
        numthreads = { 32, 1, 1 },
        grid = { 1, 1, 1 },
        bind = {
            -- weights live inside the ParameterBlock struct so SPG's
            -- reflection finds them under "params:weights".
            slang.ParameterBlock(
                slang.float(params["priorExposure"] or 0.0),
                bind_weights(weights)
            ),
            slang.Texture2D(in_rgba),
            slang.RWTexture2D(outputs["ControllerParams"]),
        },
    })
end
Defaults to zero so the controller + # behaves identically to training-time inference when no prior is wired. + float inputs:priorExposure = 0.0 + + # Flat float buffer holding all controller weights in the layout + # encoded by ppisp_controller.slang's OFF_* offsets: + # conv1_weight (16x3) | conv1_bias (16) | + # conv2_weight (32x16) | conv2_bias (32) | + # conv3_weight (64x32) | conv3_bias (64) | + # trunk0_weight (128x1601) | trunk0_bias (128) | + # trunk1_weight (128x128) | trunk1_bias (128) | + # trunk2_weight (128x128) | trunk2_bias (128) | + # exposure_head_weight (128) | exposure_head_bias (1) | + # color_head_weight (8x128) | color_head_bias (8) + # = 241,961 floats per camera. + float[] inputs:weights = [] + + opaque inputs:HdrColor + opaque outputs:ControllerParams +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang new file mode 100644 index 00000000..2e20fc55 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang @@ -0,0 +1,253 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// PPISP (Physically Plausible Image Signal Processing) SPG Shader +// +// Implements the ISP pipeline for USD RenderProducts: +// 1. Exposure compensation +// 2. Vignetting correction (per-channel) +// 3. 
// NOTE: All parameters use flat naming to match USD inputs: attributes (UsdShade-compatible).
// SPG requires the Slang struct field names to match USD input names.
// The Lua launcher binds these fields positionally, so the field order
// below must not change.

struct PPISPParams
{
    // User-overridable per-camera responsivity, premultiplied to the
    // input HdrColor before the rest of the pipeline runs. Defaults to
    // (1, 1, 1) so the override is a no-op unless explicitly authored.
    float responsivityR;
    float responsivityG;
    float responsivityB;

    // Exposure
    float exposureOffset;

    // Vignetting R channel
    float2 vignettingCenterR;
    float vignettingAlpha1R;
    float vignettingAlpha2R;
    float vignettingAlpha3R;

    // Vignetting G channel
    float2 vignettingCenterG;
    float vignettingAlpha1G;
    float vignettingAlpha2G;
    float vignettingAlpha3G;

    // Vignetting B channel
    float2 vignettingCenterB;
    float vignettingAlpha1B;
    float vignettingAlpha2B;
    float vignettingAlpha3B;

    // Color correction: 4 control-point latent offsets (Blue, Red, Green, Neutral)
    float2 colorLatentBlue;
    float2 colorLatentRed;
    float2 colorLatentGreen;
    float2 colorLatentNeutral;

    // CRF R channel (raw params: activations applied at runtime)
    float crfToeR;
    float crfShoulderR;
    float crfGammaR;
    float crfCenterR;

    // CRF G channel
    float crfToeG;
    float crfShoulderG;
    float crfGammaG;
    float crfCenterG;

    // CRF B channel
    float crfToeB;
    float crfShoulderB;
    float crfGammaB;
    float crfCenterB;
};

// NOTE(review): the generic type arguments (<PPISPParams>, <float4>) were
// reconstructed from how each resource is used below; confirm them against
// the authored shader, in particular the element type of the output image.
[[vk::binding(0, 1)]] ParameterBlock<PPISPParams> g_Params;
[[vk::binding(1, 1)]] Texture2D<float4> g_InTex;
[[vk::binding(2, 1)]] RWTexture2D<float4> g_OutTex;

// ZCA pinv 2x2 blocks (constant, matching ppisp_math.cuh COLOR_PINV_BLOCKS)
static const float2x2 ZCA_BLUE    = float2x2( 0.0480542, -0.0043631, -0.0043631, 0.0481283);
static const float2x2 ZCA_RED     = float2x2( 0.0580570, -0.0179872, -0.0179872, 0.0431061);
static const float2x2 ZCA_GREEN   = float2x2( 0.0433336, -0.0180537, -0.0180537, 0.0580500);
static const float2x2 ZCA_NEUTRAL = float2x2( 0.0128369, -0.0034654, -0.0034654, 0.0128158);

// Compute 3x3 homography from ZCA latent offsets (port of compute_homography from ppisp_math.cuh)
float3x3 computeHomography(float2 bLat, float2 rLat, float2 gLat, float2 nLat)
{
    float2 bd = mul(ZCA_BLUE, bLat);
    float2 rd = mul(ZCA_RED, rLat);
    float2 gd = mul(ZCA_GREEN, gLat);
    float2 nd = mul(ZCA_NEUTRAL, nLat);

    // Target chromaticities: source + offset. Source = (r,g,I) for pure B,R,G,gray
    float3 tB = float3(0.0 + bd.x, 0.0 + bd.y, 1.0);
    float3 tR = float3(1.0 + rd.x, 0.0 + rd.y, 1.0);
    float3 tG = float3(0.0 + gd.x, 1.0 + gd.y, 1.0);
    float3 tGray = float3(1.0 / 3.0 + nd.x, 1.0 / 3.0 + nd.y, 1.0);

    // T = [tB | tR | tG] as columns (row-major: row i = [tB[i], tR[i], tG[i]])
    float3x3 T = float3x3(tB.x, tR.x, tG.x,
                          tB.y, tR.y, tG.y,
                          tB.z, tR.z, tG.z);

    // Skew-symmetric matrix [tGray]_x
    float3x3 skew = float3x3(0.0, -tGray.z, tGray.y,
                             tGray.z, 0.0, -tGray.x,
                             -tGray.y, tGray.x, 0.0);

    float3x3 M = mul(skew, T);

    // Null-space vector via cross product of two rows of M. Fall back to
    // another row pair when the first pair is (nearly) parallel.
    float3 r0 = M[0];
    float3 r1 = M[1];
    float3 r2 = M[2];

    float3 lam = cross(r0, r1);
    if (dot(lam, lam) < 1.0e-20)
    {
        lam = cross(r0, r2);
        if (dot(lam, lam) < 1.0e-20)
            lam = cross(r1, r2);
    }

    // S_inv = [[-1,-1,1],[1,0,0],[0,1,0]]
    float3x3 Sinv = float3x3(-1.0, -1.0, 1.0,
                             1.0, 0.0, 0.0,
                             0.0, 1.0, 0.0);

    // D = diag(lam)
    float3x3 D = float3x3(lam.x, 0.0, 0.0,
                          0.0, lam.y, 0.0,
                          0.0, 0.0, lam.z);

    // H = T * D * S_inv
    float3x3 H = mul(mul(T, D), Sinv);

    // Normalize so H[2][2] = 1 (skipped when that entry is ~0).
    float s = H[2][2];
    if (abs(s) > 1.0e-20)
        H = H * (1.0 / s);

    return H;
}

// Multiplies `value` by a radial falloff 1 + a1*r^2 + a2*r^4 + a3*r^6,
// clamped to [0, 1], where r is the distance of `uv` from `opticalCenter`.
float applyVignetting(float value, float2 uv, float2 opticalCenter, float alpha1, float alpha2, float alpha3)
{
    float2 delta = uv - opticalCenter;
    float r2 = dot(delta, delta);

    float falloff = 1.0;
    float r2Pow = r2;
    falloff += alpha1 * r2Pow;
    r2Pow *= r2;
    falloff += alpha2 * r2Pow;
    r2Pow *= r2;
    falloff += alpha3 * r2Pow;

    falloff = clamp(falloff, 0.0, 1.0);
    return value * falloff;
}

// minValue + softplus(raw).
// softplus is evaluated as max(raw, 0) + log(1 + exp(-|raw|)), which equals
// log(1 + exp(raw)) but never passes a positive argument to exp(), so a
// large raw value cannot overflow to +inf.
float boundedSoftplus(float raw, float minValue)
{
    return minValue + max(raw, 0.0) + log(1.0 + exp(-abs(raw)));
}

float sigmoid(float raw)
{
    return 1.0 / (1.0 + exp(-raw));
}

// 4-param toe-shoulder CRF (port of apply_crf_ppisp from ppisp_math.cuh)
float applyCRF(float x, float toeRaw, float shoulderRaw, float gammaRaw, float centerRaw)
{
    x = clamp(x, 0.0, 1.0);

    float toe = boundedSoftplus(toeRaw, 0.3);
    float shoulder = boundedSoftplus(shoulderRaw, 0.3);
    float gamma = boundedSoftplus(gammaRaw, 0.1);
    float center = sigmoid(centerRaw);

    // toe >= 0.3, shoulder >= 0.3, center in (0,1) -- divisions are safe
    float lerpVal = (shoulder - toe) * center + toe;
    float a = (shoulder * center) / lerpVal;
    float b = 1.0 - a;

    float y;
    if (x <= center)
        y = a * pow(x / center, toe);
    else
        y = 1.0 - b * pow((1.0 - x) / (1.0 - center), shoulder);

    return pow(max(0.0, y), gamma);
}

// Applies homography H in (r, g, intensity) space, rescales so the total
// intensity r+g+b is preserved, and converts back to RGB.
float3 applyColorCorrection(float3 rgb, float3x3 H)
{
    float intensity = rgb.x + rgb.y + rgb.z;
    float3 rgi = float3(rgb.x, rgb.y, intensity);

    rgi = mul(H, rgi);

    // The small epsilon keeps the rescale finite when rgi.z is ~0.
    rgi = rgi * (intensity / (rgi.z + 1.0e-5));
    return float3(rgi.x, rgi.y, rgi.z - rgi.x - rgi.y);
}

// Entry point: one thread per output pixel. Applies responsivity, exposure,
// vignetting, colour correction and the CRF, in that order.
[shader("compute")]
[numthreads(16, 16, 1)]
void ppispProcess(uint3 tid : SV_DispatchThreadID)
{
    uint w = 0, h = 0;
    g_InTex.GetDimensions(w, h);
    // The dispatch grid is rounded up to 16x16 tiles; drop the overhang.
    if (tid.x >= w || tid.y >= h)
        return;

    float4 pixel = g_InTex.Load(int3(tid.xy, 0));
    float3 rgb = pixel.rgb;
    rgb *= float3(g_Params.responsivityR, g_Params.responsivityG, g_Params.responsivityB);

    // Normalize to [-0.5, 0.5] range based on max dimension (matching CUDA kernel)
    float maxRes = max(float(w), float(h));
    float2 uv = float2(tid.x + 0.5 - float(w) * 0.5, tid.y + 0.5 - float(h) * 0.5) / maxRes;

    // 1. Exposure
    rgb = rgb * exp2(g_Params.exposureOffset);

    // 2. Vignetting (per-channel)
    rgb.r = applyVignetting(rgb.r, uv, g_Params.vignettingCenterR,
                            g_Params.vignettingAlpha1R, g_Params.vignettingAlpha2R, g_Params.vignettingAlpha3R);
    rgb.g = applyVignetting(rgb.g, uv, g_Params.vignettingCenterG,
                            g_Params.vignettingAlpha1G, g_Params.vignettingAlpha2G, g_Params.vignettingAlpha3G);
    rgb.b = applyVignetting(rgb.b, uv, g_Params.vignettingCenterB,
                            g_Params.vignettingAlpha1B, g_Params.vignettingAlpha2B, g_Params.vignettingAlpha3B);

    // 3. Color correction (ZCA-based homography)
    float3x3 H = computeHomography(g_Params.colorLatentBlue, g_Params.colorLatentRed,
                                   g_Params.colorLatentGreen, g_Params.colorLatentNeutral);
    rgb = applyColorCorrection(rgb, H);

    // 4. CRF (per-channel, 4-param toe-shoulder)
    rgb.r = applyCRF(rgb.r, g_Params.crfToeR, g_Params.crfShoulderR, g_Params.crfGammaR, g_Params.crfCenterR);
    rgb.g = applyCRF(rgb.g, g_Params.crfToeG, g_Params.crfShoulderG, g_Params.crfGammaG, g_Params.crfCenterG);
    rgb.b = applyCRF(rgb.b, g_Params.crfToeB, g_Params.crfShoulderB, g_Params.crfGammaB, g_Params.crfCenterB);

    g_OutTex[tid.xy] = float4(saturate(rgb), 1.0);
}
-- SPG entry point for the static-parameter PPISP shader. Allocates the
-- output image and dispatches one thread per pixel in 16x16 tiles.
function ppispProcess(inputs, outputs, params)
    local hdr = inputs["HdrColor"]
    assert(hdr and hdr.rank == 2, "HdrColor input must be a 2D texture")

    -- LdrColor expects an RGBA8 output, even when the input is HdrColor.
    local height = hdr.shape[1]
    local width = hdr.shape[2]
    outputs["PPISPColor"] = slang.empty({height, width}, slang.uchar4)

    -- Scalar parameter: the authored value, or `fallback` when unset.
    local function scalar(name, fallback)
        return slang.float(params[name] or fallback)
    end

    -- float2 parameter: the authored value is passed through as-is to
    -- preserve __fullName for shader reflection matching; (0, 0) when unset.
    local function vec2(name)
        local authored = params[name]
        return authored and slang.float2(authored) or slang.float2(0.0, 0.0)
    end

    -- Raw CRF defaults give an identity curve:
    -- boundedSoftplus(0.013659, 0.3) = 1, boundedSoftplus(0.378165, 0.1) = 1,
    -- sigmoid(0) = 0.5.
    local tiles_x = math.ceil(width / 16)
    local tiles_y = math.ceil(height / 16)

    -- The argument order below must follow the field order of PPISPParams
    -- in ppisp_usd_spg.slang.
    local block = slang.ParameterBlock(
        -- Per-camera responsivity (premultiplied to input HDR)
        scalar("responsivityR", 1.0),
        scalar("responsivityG", 1.0),
        scalar("responsivityB", 1.0),

        -- Exposure
        scalar("exposureOffset", 0.0),

        -- Vignetting R
        vec2("vignettingCenterR"),
        scalar("vignettingAlpha1R", 0.0),
        scalar("vignettingAlpha2R", 0.0),
        scalar("vignettingAlpha3R", 0.0),

        -- Vignetting G
        vec2("vignettingCenterG"),
        scalar("vignettingAlpha1G", 0.0),
        scalar("vignettingAlpha2G", 0.0),
        scalar("vignettingAlpha3G", 0.0),

        -- Vignetting B
        vec2("vignettingCenterB"),
        scalar("vignettingAlpha1B", 0.0),
        scalar("vignettingAlpha2B", 0.0),
        scalar("vignettingAlpha3B", 0.0),

        -- Color latent offsets (4 control points)
        vec2("colorLatentBlue"),
        vec2("colorLatentRed"),
        vec2("colorLatentGreen"),
        vec2("colorLatentNeutral"),

        -- CRF R
        scalar("crfToeR", 0.013659),
        scalar("crfShoulderR", 0.013659),
        scalar("crfGammaR", 0.378165),
        scalar("crfCenterR", 0.0),

        -- CRF G
        scalar("crfToeG", 0.013659),
        scalar("crfShoulderG", 0.013659),
        scalar("crfGammaG", 0.378165),
        scalar("crfCenterG", 0.0),

        -- CRF B
        scalar("crfToeB", 0.013659),
        scalar("crfShoulderB", 0.013659),
        scalar("crfGammaB", 0.378165),
        scalar("crfCenterB", 0.0)
    )

    return slang.dispatch({
        stage = "compute",
        numthreads = { 16, 16, 1 },
        grid = { tiles_x, tiles_y, 1 },
        bind = {
            block,
            slang.Texture2D(hdr),
            slang.RWTexture2D(outputs["PPISPColor"]),
        },
    })
end
+ float inputs:responsivityR = 1.0 + float inputs:responsivityG = 1.0 + float inputs:responsivityB = 1.0 + + # Exposure parameter + float inputs:exposureOffset = 0.0 + + # Vignetting parameters (per channel: R, G, B) + float2 inputs:vignettingCenterR = (0.0, 0.0) + float inputs:vignettingAlpha1R = 0.0 + float inputs:vignettingAlpha2R = 0.0 + float inputs:vignettingAlpha3R = 0.0 + + float2 inputs:vignettingCenterG = (0.0, 0.0) + float inputs:vignettingAlpha1G = 0.0 + float inputs:vignettingAlpha2G = 0.0 + float inputs:vignettingAlpha3G = 0.0 + + float2 inputs:vignettingCenterB = (0.0, 0.0) + float inputs:vignettingAlpha1B = 0.0 + float inputs:vignettingAlpha2B = 0.0 + float inputs:vignettingAlpha3B = 0.0 + + # Color correction latent offsets (ZCA-based, 4 control points x 2D) + float2 inputs:colorLatentBlue = (0.0, 0.0) + float2 inputs:colorLatentRed = (0.0, 0.0) + float2 inputs:colorLatentGreen = (0.0, 0.0) + float2 inputs:colorLatentNeutral = (0.0, 0.0) + + # CRF raw parameters (per channel: R, G, B) + # Activations: boundedSoftplus(raw, min) for toe/shoulder/gamma, sigmoid(raw) for center + # Defaults produce identity CRF: toe=1, shoulder=1, gamma=1, center=0.5 + float inputs:crfToeR = 0.013659 + float inputs:crfShoulderR = 0.013659 + float inputs:crfGammaR = 0.378165 + float inputs:crfCenterR = 0.0 + + float inputs:crfToeG = 0.013659 + float inputs:crfShoulderG = 0.013659 + float inputs:crfGammaG = 0.378165 + float inputs:crfCenterG = 0.0 + + float inputs:crfToeB = 0.013659 + float inputs:crfShoulderB = 0.013659 + float inputs:crfGammaB = 0.378165 + float inputs:crfCenterB = 0.0 + + # Image inputs/outputs + opaque inputs:HdrColor + opaque outputs:PPISPColor +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang new file mode 100644 index 00000000..0e582d74 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang @@ -0,0 +1,220 @@ +// SPDX-FileCopyrightText: Copyright 
// (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0

// PPISP (Physically Plausible ISP) SPG Shader — controller-aware variant.
//
// Identical to ppisp_usd_spg.slang in maths. The only difference is that
// `exposureOffset` and the eight colour latents come from a 1x9 single-
// channel float texture produced by the per-camera PPISP controller
// shader (ppisp_controller.slang), instead of from
// time-sampled USD attributes.
//
// Texel layout of the controller output (matches ppisp_controller.slang):
//   (0, 0):    exposureOffset
//   (1..2, 0): colorLatentBlue.xy
//   (3..4, 0): colorLatentRed.xy
//   (5..6, 0): colorLatentGreen.xy
//   (7..8, 0): colorLatentNeutral.xy

// Static per-camera parameters. Field order matters: the launcher fills the
// parameter block positionally.
struct PPISPDynParams
{
    // User-overridable per-camera responsivity, premultiplied to the
    // input HdrColor before the rest of the pipeline runs. Defaults to
    // (1, 1, 1) so the override is a no-op unless explicitly authored.
    float responsivityR;
    float responsivityG;
    float responsivityB;

    float2 vignettingCenterR;
    float vignettingAlpha1R;
    float vignettingAlpha2R;
    float vignettingAlpha3R;

    float2 vignettingCenterG;
    float vignettingAlpha1G;
    float vignettingAlpha2G;
    float vignettingAlpha3G;

    float2 vignettingCenterB;
    float vignettingAlpha1B;
    float vignettingAlpha2B;
    float vignettingAlpha3B;

    float crfToeR;
    float crfShoulderR;
    float crfGammaR;
    float crfCenterR;

    float crfToeG;
    float crfShoulderG;
    float crfGammaG;
    float crfCenterG;

    float crfToeB;
    float crfShoulderB;
    float crfGammaB;
    float crfCenterB;
};

// Resource bindings. The element types are spelled out explicitly:
// g_Params must be typed for the g_Params.<field> accesses below to resolve,
// and g_ControllerOut is read as a scalar (single-channel float texture).
[[vk::binding(0, 1)]] ParameterBlock<PPISPDynParams> g_Params;
[[vk::binding(1, 1)]] Texture2D<float4> g_InTex;
[[vk::binding(2, 1)]] Texture2D<float> g_ControllerOut;
[[vk::binding(3, 1)]] RWTexture2D<float4> g_OutTex;

// 2x2 matrices applied to the raw colour latents before they are used as
// chromaticity offsets (one per control point).
static const float2x2 ZCA_BLUE = float2x2( 0.0480542, -0.0043631, -0.0043631, 0.0481283);
static const float2x2 ZCA_RED = float2x2( 0.0580570, -0.0179872, -0.0179872, 0.0431061);
static const float2x2 ZCA_GREEN = float2x2( 0.0433336, -0.0180537, -0.0180537, 0.0580500);
static const float2x2 ZCA_NEUTRAL = float2x2( 0.0128369, -0.0034654, -0.0034654, 0.0128158);

// Builds the 3x3 colour homography from the four latent offsets.
// The target points are (0,0), (1,0), (0,1) and (1/3,1/3) displaced by the
// ZCA-transformed latents; the result is normalised so that H[2][2] == 1
// whenever that entry is not (numerically) zero.
float3x3 computeHomography(float2 bLat, float2 rLat, float2 gLat, float2 nLat)
{
    float2 bd = mul(ZCA_BLUE, bLat);
    float2 rd = mul(ZCA_RED, rLat);
    float2 gd = mul(ZCA_GREEN, gLat);
    float2 nd = mul(ZCA_NEUTRAL, nLat);

    float3 tB = float3(0.0 + bd.x, 0.0 + bd.y, 1.0);
    float3 tR = float3(1.0 + rd.x, 0.0 + rd.y, 1.0);
    float3 tG = float3(0.0 + gd.x, 1.0 + gd.y, 1.0);
    float3 tGray = float3(1.0 / 3.0 + nd.x, 1.0 / 3.0 + nd.y, 1.0);

    // Columns of T are the three displaced primary targets.
    float3x3 T = float3x3(tB.x, tR.x, tG.x,
                          tB.y, tR.y, tG.y,
                          tB.z, tR.z, tG.z);

    // Cross-product (skew-symmetric) matrix of the displaced neutral target.
    float3x3 skew = float3x3(0.0, -tGray.z, tGray.y,
                             tGray.z, 0.0, -tGray.x,
                             -tGray.y, tGray.x, 0.0);

    float3x3 M = mul(skew, T);

    float3 r0 = M[0];
    float3 r1 = M[1];
    float3 r2 = M[2];

    // lam is taken orthogonal to two rows of M; fall back to another row
    // pair when the first cross product is degenerate.
    float3 lam = cross(r0, r1);
    if (dot(lam, lam) < 1.0e-20)
    {
        lam = cross(r0, r2);
        if (dot(lam, lam) < 1.0e-20)
            lam = cross(r1, r2);
    }

    float3x3 Sinv = float3x3(-1.0, -1.0, 1.0,
                             1.0, 0.0, 0.0,
                             0.0, 1.0, 0.0);

    float3x3 D = float3x3(lam.x, 0.0, 0.0,
                          0.0, lam.y, 0.0,
                          0.0, 0.0, lam.z);

    float3x3 H = mul(mul(T, D), Sinv);

    // Normalise by the bottom-right entry unless it is effectively zero.
    float s = H[2][2];
    if (abs(s) > 1.0e-20)
        H = H * (1.0 / s);

    return H;
}

// Scales `value` by 1 + a1*r^2 + a2*r^4 + a3*r^6 (clamped to [0, 1]), where
// r is the distance of `uv` from `opticalCenter`.
float applyVignetting(float value, float2 uv, float2 opticalCenter, float a1, float a2, float a3)
{
    float2 d = uv - opticalCenter;
    float r2 = dot(d, d);

    float falloff = 1.0;
    float r2Pow = r2;
    falloff += a1 * r2Pow;
    r2Pow *= r2;
    falloff += a2 * r2Pow;
    r2Pow *= r2;
    falloff += a3 * r2Pow;

    return value * clamp(falloff, 0.0, 1.0);
}

// Softplus shifted up by minValue, so the result is always > minValue.
float boundedSoftplus(float raw, float minValue) { return minValue + log(1.0 + exp(raw)); }
// Logistic sigmoid.
float sigmoidF(float raw) { return 1.0 / (1.0 + exp(-raw)); }

// Applies the per-channel camera response curve to x (clamped to [0, 1]).
// Raw parameters are activated first: boundedSoftplus for toe/shoulder/gamma,
// sigmoid for center. The curve is piecewise around `center`, followed by a
// gamma power.
float applyCRF(float x, float toeRaw, float shoulderRaw, float gammaRaw, float centerRaw)
{
    x = clamp(x, 0.0, 1.0);
    float toe = boundedSoftplus(toeRaw, 0.3);
    float shoulder = boundedSoftplus(shoulderRaw, 0.3);
    float gamma = boundedSoftplus(gammaRaw, 0.1);
    float center = sigmoidF(centerRaw);

    float lerpVal = (shoulder - toe) * center + toe;
    float a = (shoulder * center) / lerpVal;
    float b = 1.0 - a;

    float y;
    if (x <= center)
        y = a * pow(x / center, toe);
    else
        y = 1.0 - b * pow((1.0 - x) / (1.0 - center), shoulder);
    return pow(max(0.0, y), gamma);
}

// Applies H in (r, g, intensity) space, where intensity = r + g + b, rescales
// so the original intensity is kept, and converts back to RGB.
float3 applyColorCorrection(float3 rgb, float3x3 H)
{
    float intensity = rgb.x + rgb.y + rgb.z;
    float3 rgi = float3(rgb.x, rgb.y, intensity);
    rgi = mul(H, rgi);
    // Small epsilon guards the division when the transformed intensity is ~0.
    rgi = rgi * (intensity / (rgi.z + 1.0e-5));
    return float3(rgi.x, rgi.y, rgi.z - rgi.x - rgi.y);
}

// Entry point: one thread per pixel. Pipeline order is responsivity,
// exposure, vignetting, colour correction, CRF, then a saturated write with
// alpha = 1.
[shader("compute")]
[numthreads(16, 16, 1)]
void ppispProcessDyn(uint3 tid : SV_DispatchThreadID)
{
    uint w = 0, h = 0;
    g_InTex.GetDimensions(w, h);
    // Threads of partial tiles that fall outside the image do nothing.
    if (tid.x >= w || tid.y >= h)
        return;

    float4 pixel = g_InTex.Load(int3(tid.xy, 0));
    float3 rgb = pixel.rgb;
    rgb *= float3(g_Params.responsivityR, g_Params.responsivityG, g_Params.responsivityB);

    // Pixel-centre coordinates relative to the image centre, normalised by
    // the larger image dimension.
    float maxRes = max(float(w), float(h));
    float2 uv = float2(tid.x + 0.5 - float(w) * 0.5,
                       tid.y + 0.5 - float(h) * 0.5) / maxRes;

    // Read controller output (1x9 single-channel float texture).
    float exposureOffset = g_ControllerOut.Load(int3(0, 0, 0));
    float2 colorLatentBlue = float2(g_ControllerOut.Load(int3(1, 0, 0)),
                                    g_ControllerOut.Load(int3(2, 0, 0)));
    float2 colorLatentRed = float2(g_ControllerOut.Load(int3(3, 0, 0)),
                                   g_ControllerOut.Load(int3(4, 0, 0)));
    float2 colorLatentGreen = float2(g_ControllerOut.Load(int3(5, 0, 0)),
                                     g_ControllerOut.Load(int3(6, 0, 0)));
    float2 colorLatentNeutral = float2(g_ControllerOut.Load(int3(7, 0, 0)),
                                       g_ControllerOut.Load(int3(8, 0, 0)));

    // Exposure offset is expressed in stops.
    rgb = rgb * exp2(exposureOffset);

    rgb.r = applyVignetting(rgb.r, uv, g_Params.vignettingCenterR,
                            g_Params.vignettingAlpha1R, g_Params.vignettingAlpha2R, g_Params.vignettingAlpha3R);
    rgb.g = applyVignetting(rgb.g, uv, g_Params.vignettingCenterG,
                            g_Params.vignettingAlpha1G, g_Params.vignettingAlpha2G, g_Params.vignettingAlpha3G);
    rgb.b = applyVignetting(rgb.b, uv, g_Params.vignettingCenterB,
                            g_Params.vignettingAlpha1B, g_Params.vignettingAlpha2B, g_Params.vignettingAlpha3B);

    float3x3 H = computeHomography(colorLatentBlue, colorLatentRed,
                                   colorLatentGreen, colorLatentNeutral);
    rgb = applyColorCorrection(rgb, H);

    rgb.r = applyCRF(rgb.r, g_Params.crfToeR, g_Params.crfShoulderR, g_Params.crfGammaR, g_Params.crfCenterR);
    rgb.g = applyCRF(rgb.g, g_Params.crfToeG, g_Params.crfShoulderG, g_Params.crfGammaG, g_Params.crfCenterG);
    rgb.b = applyCRF(rgb.b, g_Params.crfToeB, g_Params.crfShoulderB, g_Params.crfGammaB, g_Params.crfCenterB);

    g_OutTex[tid.xy] = float4(saturate(rgb), 1.0);
}
-- SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

-- PPISP SPG Launcher (controller-aware variant).
--
-- ``exposureOffset`` and the eight colour latents are not bound here: the
-- shader reads them from the controller's output texture. The static USD
-- inputs only carry the per-camera responsivity, vignetting and CRF
-- parameters. The HdrColor input still comes from the RenderProduct's
-- primary AOV.

-- inputs:  must contain a rank-2 "HdrColor" texture and a "ControllerParams"
--          texture.
-- outputs: receives a newly allocated "PPISPColor" texture (uchar4) with the
--          same height/width as HdrColor.
-- params:  shader inputs looked up by name; missing entries fall back to the
--          defaults written below.
--
-- The ParameterBlock arguments are positional, so their order must stay in
-- step with the shader-side parameter struct.
function ppispProcessDyn(inputs, outputs, params)
    local hdr = inputs["HdrColor"]
    assert(hdr and hdr.rank == 2, "HdrColor input must be a 2D texture")

    local controllerTex = inputs["ControllerParams"]
    assert(controllerTex, "ppispProcessDyn needs a ControllerParams input texture")

    local height, width = hdr.shape[1], hdr.shape[2]
    outputs["PPISPColor"] = slang.empty({ height, width }, slang.uchar4)

    -- Scalar input, or the literal fallback when it is not authored.
    local function scalar(name, fallback)
        return slang.float(params[name] or fallback)
    end

    -- float2 input with a (0, 0) fallback.
    local function pair(name)
        local value = params[name]
        local bound = value and slang.float2(value)
        return bound or slang.float2(0.0, 0.0)
    end

    local block = slang.ParameterBlock(
        -- Per-camera responsivity
        scalar("responsivityR", 1.0),
        scalar("responsivityG", 1.0),
        scalar("responsivityB", 1.0),

        -- Vignetting R
        pair("vignettingCenterR"),
        scalar("vignettingAlpha1R", 0.0),
        scalar("vignettingAlpha2R", 0.0),
        scalar("vignettingAlpha3R", 0.0),

        -- Vignetting G
        pair("vignettingCenterG"),
        scalar("vignettingAlpha1G", 0.0),
        scalar("vignettingAlpha2G", 0.0),
        scalar("vignettingAlpha3G", 0.0),

        -- Vignetting B
        pair("vignettingCenterB"),
        scalar("vignettingAlpha1B", 0.0),
        scalar("vignettingAlpha2B", 0.0),
        scalar("vignettingAlpha3B", 0.0),

        -- CRF R
        scalar("crfToeR", 0.013659),
        scalar("crfShoulderR", 0.013659),
        scalar("crfGammaR", 0.378165),
        scalar("crfCenterR", 0.0),

        -- CRF G
        scalar("crfToeG", 0.013659),
        scalar("crfShoulderG", 0.013659),
        scalar("crfGammaG", 0.378165),
        scalar("crfCenterG", 0.0),

        -- CRF B
        scalar("crfToeB", 0.013659),
        scalar("crfShoulderB", 0.013659),
        scalar("crfGammaB", 0.378165),
        scalar("crfCenterB", 0.0)
    )

    -- One 16x16 thread group per tile; partial tiles are rounded up.
    return slang.dispatch({
        stage = "compute",
        numthreads = { 16, 16, 1 },
        grid = { math.ceil(width / 16), math.ceil(height / 16), 1 },
        bind = {
            block,
            slang.Texture2D(hdr),
            slang.Texture2D(controllerTex),
            slang.RWTexture2D(outputs["PPISPColor"]),
        },
    })
end
+ float inputs:responsivityR = 1.0 + float inputs:responsivityG = 1.0 + float inputs:responsivityB = 1.0 + + # Vignetting (per channel: R, G, B) + float2 inputs:vignettingCenterR = (0.0, 0.0) + float inputs:vignettingAlpha1R = 0.0 + float inputs:vignettingAlpha2R = 0.0 + float inputs:vignettingAlpha3R = 0.0 + + float2 inputs:vignettingCenterG = (0.0, 0.0) + float inputs:vignettingAlpha1G = 0.0 + float inputs:vignettingAlpha2G = 0.0 + float inputs:vignettingAlpha3G = 0.0 + + float2 inputs:vignettingCenterB = (0.0, 0.0) + float inputs:vignettingAlpha1B = 0.0 + float inputs:vignettingAlpha2B = 0.0 + float inputs:vignettingAlpha3B = 0.0 + + # CRF raw parameters (per channel: R, G, B) + float inputs:crfToeR = 0.013659 + float inputs:crfShoulderR = 0.013659 + float inputs:crfGammaR = 0.378165 + float inputs:crfCenterR = 0.0 + + float inputs:crfToeG = 0.013659 + float inputs:crfShoulderG = 0.013659 + float inputs:crfGammaG = 0.378165 + float inputs:crfCenterG = 0.0 + + float inputs:crfToeB = 0.013659 + float inputs:crfShoulderB = 0.013659 + float inputs:crfGammaB = 0.378165 + float inputs:crfCenterB = 0.0 + + # Image inputs/outputs + opaque inputs:HdrColor + opaque inputs:ControllerParams + opaque outputs:PPISPColor +} diff --git a/threedgrut/export/usd/stage_utils.py b/threedgrut/export/usd/stage_utils.py index 9b70bcfe..e555e9a2 100644 --- a/threedgrut/export/usd/stage_utils.py +++ b/threedgrut/export/usd/stage_utils.py @@ -22,6 +22,7 @@ import logging import os +import struct import tempfile import zipfile from dataclasses import dataclass @@ -38,6 +39,31 @@ # Constants DEFAULT_FRAME_RATE = 24.0 USD_WORLD_PATH = "/World" +_USDZ_ALIGNMENT = 64 +_USDZ_PADDING_EXTRA_ID = 0x1986 + + +def _write_usdz_entry(zip_file: zipfile.ZipFile, filename: str, data: Union[str, bytes]) -> None: + if isinstance(data, str): + data = data.encode("utf-8") + + header_offset = zip_file.fp.tell() + filename_size = len(filename.encode("utf-8")) + unpadded_data_offset = header_offset + 30 + 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OpenUSD validation helpers for exported ParticleField / LightField stages."""

from __future__ import annotations

import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Stage-wide checks used by export tests; ParticleField-specific validators may be added as USD exposes them.
_LIGHTFIELD_VALIDATOR_NAMES = (
    "usdValidation:StageMetadataChecker",
    "usdValidation:CompositionErrorTest",
)


def validate_exported_usd_stage(path: Path) -> None:
    """
    Run OpenUSD validation on a written .usd / .usda / .usdc / .usdz.

    Intended for outputs from :class:`~threedgrut.export.usd.exporter.USDExporter`
    (ParticleField3DGaussianSplat / LightField). NuRec exports are not validated here.

    If ``UsdValidation`` is missing, validators fail to load, or the registry API is
    unavailable, this function logs at DEBUG and returns without error.

    Args:
        path: Path to the package root file on disk.

    Raises:
        ValueError: Stage cannot be opened, or validators reported errors.
    """
    path = Path(path)
    try:
        from pxr import Tf, Usd, UsdValidation
    except ImportError:
        logger.debug("pxr not available; skipping USD validation for %s", path)
        return

    # Validation is best-effort: any failure to reach the registry or load the
    # named validators downgrades to a skip rather than failing the export.
    try:
        registry = UsdValidation.ValidationRegistry()
        validators = registry.GetOrLoadValidatorsByName(list(_LIGHTFIELD_VALIDATOR_NAMES))
    except Exception as exc:
        logger.debug("UsdValidation unavailable (%s); skipping USD validation for %s", exc, path)
        return

    if not validators:
        logger.debug("No USD validators loaded; skipping validation for %s", path)
        return

    # Usd.Stage.Open reports an unreadable layer by raising Tf.ErrorException
    # from Python; translate it so callers see the documented ValueError.
    try:
        stage = Usd.Stage.Open(str(path))
    except Tf.ErrorException as exc:
        raise ValueError(f"USD validation could not open stage: {path}") from exc
    if not stage:
        raise ValueError(f"USD validation could not open stage: {path}")

    logger.info("Running OpenUSD stage validation on %s", path)
    ctx = UsdValidation.ValidationContext(validators)
    result = ctx.Validate(stage)
    errors = list(result) if result else []
    if errors:
        msg = "\n".join(e.GetMessage() for e in errors)
        raise ValueError(f"USD validation failed for {path}:\n{msg}")
    logger.info("OpenUSD stage validation passed for %s", path)
/Render scope with per-camera RenderProducts +- add_ppisp_to_all_render_products: PPISP SPG shader on RenderProducts """ from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import GaussianUSDWriter, create_gaussian_writer from threedgrut.export.usd.writers.camera import export_cameras_to_usd from threedgrut.export.usd.writers.lightfield import GaussianLightFieldWriter +from threedgrut.export.usd.writers.ppisp_writer import add_ppisp_to_all_render_products +from threedgrut.export.usd.writers.render_product import create_render_products __all__ = [ "GaussianUSDWriter", @@ -31,4 +36,6 @@ "create_gaussian_writer", "export_cameras_to_usd", "export_background_to_usd", + "create_render_products", + "add_ppisp_to_all_render_products", ] diff --git a/threedgrut/export/usd/writers/background.py b/threedgrut/export/usd/writers/background.py index 7b91fc9a..a681fd82 100644 --- a/threedgrut/export/usd/writers/background.py +++ b/threedgrut/export/usd/writers/background.py @@ -41,9 +41,7 @@ def _tensor_to_tuple(color: torch.Tensor) -> Tuple[float, float, float]: """Convert a torch tensor color to a tuple of floats.""" - if color.is_cuda: - color = color.cpu() - arr = color.numpy() + arr = color.detach().cpu().numpy() return tuple(float(c) for c in arr[:3]) diff --git a/threedgrut/export/usd/writers/base.py b/threedgrut/export/usd/writers/base.py index ae06159c..4b794e03 100644 --- a/threedgrut/export/usd/writers/base.py +++ b/threedgrut/export/usd/writers/base.py @@ -27,6 +27,7 @@ from pxr import Gf, Usd, Vt from threedgrut.export.accessor import GaussianAttributes, ModelCapabilities +from threedgrut.export.usd.particle_field_hints import DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT logger = logging.getLogger(__name__) @@ -50,11 +51,15 @@ def __init__( capabilities: ModelCapabilities, content_root_path: str = "/World/Gaussians", linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, 
): self.stage = stage self.capabilities = capabilities self.content_root_path = content_root_path self.linear_srgb = linear_srgb + self.omni_usd = omni_usd + self.has_post_processing = has_post_processing self.prim: Optional[Usd.Prim] = None def apply_color_space_to_prim(self, prim: Usd.Prim) -> None: @@ -128,8 +133,10 @@ def create_gaussian_writer( content_root_path: str = "/World/Gaussians", half_geometry: bool = False, half_features: bool = False, - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, ) -> GaussianUSDWriter: """Factory function to create USD Gaussian writer. @@ -141,6 +148,8 @@ def create_gaussian_writer( half_features: Use half precision for opacities and SH coefficients (LightField) sorting_mode_hint: Sorting mode hint for LightField schema linear_srgb: If True, set prim color space to lin_rec709_scene; else srgb_rec709_display + omni_usd: If True, author Omniverse-specific USD features. + has_post_processing: If True, configure Omniverse material for external post-processing. Returns: Configured GaussianUSDWriter instance (LightField schema) @@ -155,4 +164,6 @@ def create_gaussian_writer( half_features=half_features, sorting_mode_hint=sorting_mode_hint, linear_srgb=linear_srgb, + omni_usd=omni_usd, + has_post_processing=has_post_processing, ) diff --git a/threedgrut/export/usd/writers/camera.py b/threedgrut/export/usd/writers/camera.py index f6dc6c6b..4234017d 100644 --- a/threedgrut/export/usd/writers/camera.py +++ b/threedgrut/export/usd/writers/camera.py @@ -16,53 +16,57 @@ """ Camera USD writer for exporting camera poses and intrinsics. -Exports camera poses with full intrinsics support for OpenCVPinhole and OpenCVFisheye -camera models, following the pattern established in NRE's rig_trajectories.py. 
+Exports one Camera prim per physical camera with time-sampled transforms +and static intrinsics, following the pattern established in NRE's +rig_trajectories.py. """ import logging -from typing import List, Optional +from typing import Dict, List, Optional import numpy as np from ncore.data import ( OpenCVFisheyeCameraModelParameters, OpenCVPinholeCameraModelParameters, ) -from pxr import Gf, Sdf, Usd, UsdGeom, Vt +from pxr import Gf, Sdf, Tf, Usd, UsdGeom from threedgrut.export.transforms import column_vector_4x4_to_usd_matrix logger = logging.getLogger(__name__) -# Default clipping range for cameras DEFAULT_NEAR_CLIP = 0.001 DEFAULT_FAR_CLIP = 10000000.0 +# Coordinate transform from 3DGRUT (right-down-front) to USD camera (right-up-back) +_CAMERA_COORD_FLIP = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64 +) + + +def _make_usd_prim_name(name: str) -> str: + """Convert an arbitrary string to a valid USD prim identifier.""" + return Tf.MakeValidIdentifier(name) + def _add_opencv_pinhole_camera_intrinsics( camera_prim: Usd.Prim, params: OpenCVPinholeCameraModelParameters, ) -> None: - """Add OpenCV pinhole camera intrinsics to USD camera prim.""" - # Camera projection type - camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set(Vt.Token("pinholeOpenCV")) + camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set("pinholeOpenCV") - # Resolution resolution_list = params.resolution.tolist() camera_prim.CreateAttribute("fthetaWidth", Sdf.ValueTypeNames.Float).Set(float(resolution_list[0])) camera_prim.CreateAttribute("fthetaHeight", Sdf.ValueTypeNames.Float).Set(float(resolution_list[1])) - # Principal point principal_point_list = params.principal_point.tolist() camera_prim.CreateAttribute("fthetaCx", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[0])) camera_prim.CreateAttribute("fthetaCy", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[1])) - # Focal 
length focal_length_list = params.focal_length.tolist() camera_prim.CreateAttribute("openCVFx", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[0])) camera_prim.CreateAttribute("openCVFy", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[1])) - # Radial distortion coefficients [k1,k2,k3,k4,k5,k6] radial_coeffs_list = params.radial_coeffs.tolist() camera_prim.CreateAttribute("fthetaPolyA", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[0])) camera_prim.CreateAttribute("fthetaPolyB", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[1])) @@ -71,12 +75,10 @@ def _add_opencv_pinhole_camera_intrinsics( camera_prim.CreateAttribute("fthetaPolyE", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[4])) camera_prim.CreateAttribute("fthetaPolyF", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[5])) - # Tangential distortion coefficients [p1,p2] tangential_coeffs_list = params.tangential_coeffs.tolist() camera_prim.CreateAttribute("p0", Sdf.ValueTypeNames.Float).Set(float(tangential_coeffs_list[0])) camera_prim.CreateAttribute("p1", Sdf.ValueTypeNames.Float).Set(float(tangential_coeffs_list[1])) - # Thin prism distortion coefficients [s1,s2,s3,s4] thin_prism_coeffs_list = params.thin_prism_coeffs.tolist() camera_prim.CreateAttribute("s0", Sdf.ValueTypeNames.Float).Set(float(thin_prism_coeffs_list[0])) camera_prim.CreateAttribute("s1", Sdf.ValueTypeNames.Float).Set(float(thin_prism_coeffs_list[1])) @@ -88,152 +90,124 @@ def _add_opencv_fisheye_camera_intrinsics( camera_prim: Usd.Prim, params: OpenCVFisheyeCameraModelParameters, ) -> None: - """Add OpenCV fisheye camera intrinsics to USD camera prim.""" - # Camera projection type - camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set(Vt.Token("fisheyeOpenCV")) + camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set("fisheyeOpenCV") - # Resolution resolution_list = params.resolution.tolist() camera_prim.CreateAttribute("fthetaWidth", 
Sdf.ValueTypeNames.Float).Set(float(resolution_list[0])) camera_prim.CreateAttribute("fthetaHeight", Sdf.ValueTypeNames.Float).Set(float(resolution_list[1])) - # Principal point principal_point_list = params.principal_point.tolist() camera_prim.CreateAttribute("fthetaCx", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[0])) camera_prim.CreateAttribute("fthetaCy", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[1])) - # Focal length focal_length_list = params.focal_length.tolist() camera_prim.CreateAttribute("openCVFx", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[0])) camera_prim.CreateAttribute("openCVFy", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[1])) - # Radial distortion coefficients [k1,k2,k3,k4] radial_coeffs_list = params.radial_coeffs.tolist() camera_prim.CreateAttribute("fthetaPolyA", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[0])) camera_prim.CreateAttribute("fthetaPolyB", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[1])) camera_prim.CreateAttribute("fthetaPolyC", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[2])) camera_prim.CreateAttribute("fthetaPolyD", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[3])) - # Max FoV (convert from radians to degrees, x2 for full FoV) - camera_prim.CreateAttribute("fthetaMaxFov", Sdf.ValueTypeNames.Float).Set(float(2.0 * np.rad2deg(params.max_angle))) - - -def _add_simple_pinhole_intrinsics( - camera_prim: Usd.Prim, - intrinsics: List[float], - resolution: List[int], -) -> None: - """Add simple pinhole intrinsics [fx, fy, cx, cy] without distortion.""" - fx, fy, cx, cy = intrinsics - - # Use standard USD pinhole camera attributes - # Compute horizontal aperture from resolution and focal length - # USD uses mm for aperture, assuming sensor is 36mm (full-frame) - sensor_width_mm = 36.0 - focal_length_mm = (fx / resolution[0]) * sensor_width_mm - - camera_prim.GetFocalLengthAttr().Set(focal_length_mm) - 
camera_prim.GetHorizontalApertureAttr().Set(sensor_width_mm) - camera_prim.GetVerticalApertureAttr().Set(sensor_width_mm * resolution[1] / resolution[0]) - - # Principal point offset from center - horizontal_offset = ((cx / resolution[0]) - 0.5) * sensor_width_mm - vertical_offset = ((cy / resolution[1]) - 0.5) * (sensor_width_mm * resolution[1] / resolution[0]) - camera_prim.GetHorizontalApertureOffsetAttr().Set(horizontal_offset) - camera_prim.GetVerticalApertureOffsetAttr().Set(vertical_offset) + camera_prim.CreateAttribute("fthetaMaxFov", Sdf.ValueTypeNames.Float).Set( + float(2.0 * np.rad2deg(params.max_angle)) + ) def export_cameras_to_usd( stage: Usd.Stage, poses: np.ndarray, - intrinsics: Optional[List] = None, + camera_names: List[str], + frame_to_camera: List[int], camera_params: Optional[List] = None, - resolutions: Optional[List[np.ndarray]] = None, root_path: str = "/World/Cameras", - camera_prefix: str = "camera", visible: bool = False, -) -> str: +) -> Dict[str, str]: """ - Export camera poses with intrinsics to USD stage. + Export camera poses with intrinsics to a USD stage. - Supports multiple camera model types: - - OpenCVPinholeCameraModelParameters: Full pinhole with distortion - - OpenCVFisheyeCameraModelParameters: Fisheye with distortion - - Simple intrinsics: [fx, fy, cx, cy] list for basic pinhole + Creates one Camera prim per physical camera with time-sampled transforms + and static intrinsics. The time code for frame i is float(i), so + stage.GetTimeCodesPerSecond() controls real-time playback speed. Args: - stage: USD stage to export to - poses: Camera poses [N, 4, 4] in 3DGRUT convention (right-down-front) - intrinsics: Optional list of [fx, fy, cx, cy] for simple pinhole - camera_params: Optional list of camera model parameters (OpenCVPinhole/Fisheye) - resolutions: Optional list of resolutions [[w, h], ...] 
for simple intrinsics - root_path: USD path for camera root xform - camera_prefix: Prefix for camera names - visible: Whether cameras should be visible in viewport + stage: USD stage to export to. + poses: Camera-to-world transforms [N_frames, 4, 4] in 3DGRUT convention + (right-down-front). + camera_names: Logical name for each physical camera, indexed by camera_idx. + frame_to_camera: Per-frame camera index mapping, length N_frames. + camera_params: Per-frame CameraModelParameters (OpenCVPinhole / Fisheye). + Intrinsics are taken from the first frame of each camera. + root_path: USD path for the camera root Xform. + visible: Whether camera prims should be visible in the viewport. Returns: - Root path of the cameras + Mapping {camera_name: usd_prim_path} for every exported camera. """ - num_cameras = poses.shape[0] + num_cameras = len(camera_names) + + # Group frame indices by camera + camera_frames: Dict[int, List[int]] = {i: [] for i in range(num_cameras)} + for frame_idx, cam_idx in enumerate(frame_to_camera): + if 0 <= cam_idx < num_cameras: + camera_frames[cam_idx].append(frame_idx) - # Create root xform for cameras UsdGeom.Xform.Define(stage, root_path) - # Coordinate transform from 3DGRUT (right-down-front) to USD camera (right-up-back) - # 3DGRUT: X=right, Y=down, Z=front - # USD: X=right, Y=up, Z=back - camera_coord_flip = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64) + result: Dict[str, str] = {} + usd_start_time_code = float("inf") + usd_end_time_code = float("-inf") + + for cam_idx, cam_name in enumerate(camera_names): + frame_indices = camera_frames[cam_idx] + if not frame_indices: + logger.warning(f"Camera '{cam_name}' (idx {cam_idx}) has no frames, skipping") + continue - for i in range(num_cameras): - camera_name = f"{camera_prefix}_{i:04d}" - camera_path = f"{root_path}/{camera_name}" + prim_name = _make_usd_prim_name(cam_name) + camera_path = f"{root_path}/{prim_name}" - # Define camera prim camera_prim = 
stage.DefinePrim(camera_path, "Camera") camera = UsdGeom.Camera(camera_prim) - - # Set clipping range camera.GetClippingRangeAttr().Set(Gf.Vec2f(DEFAULT_NEAR_CLIP, DEFAULT_FAR_CLIP)) - # Add intrinsics based on available data - if camera_params is not None and i < len(camera_params) and camera_params[i] is not None: - params = camera_params[i] + # Static intrinsics from first frame of this camera + first_frame = frame_indices[0] + if camera_params is not None and first_frame < len(camera_params) and camera_params[first_frame] is not None: + params = camera_params[first_frame] if isinstance(params, OpenCVPinholeCameraModelParameters): _add_opencv_pinhole_camera_intrinsics(camera_prim, params) elif isinstance(params, OpenCVFisheyeCameraModelParameters): _add_opencv_fisheye_camera_intrinsics(camera_prim, params) else: - # Fallback to default focal length camera.GetFocalLengthAttr().Set(24.0) - logger.warning(f"Unsupported camera model for camera {i}, using default intrinsics") - elif intrinsics is not None and resolutions is not None: - # Simple pinhole from intrinsics list - if i < len(resolutions): - resolution = resolutions[i].tolist() if isinstance(resolutions[i], np.ndarray) else resolutions[i] - else: - resolution = resolutions[0].tolist() if isinstance(resolutions[0], np.ndarray) else resolutions[0] - _add_simple_pinhole_intrinsics(camera_prim, intrinsics, resolution) + logger.warning(f"Unsupported camera model for '{cam_name}', using default focal length") else: - # Fallback to default focal length camera.GetFocalLengthAttr().Set(24.0) - # Set camera transform (pose) - # Apply coordinate system transform: 3DGRUT -> USD camera, then build USD matrix via Gf API - pose = poses[i] - usd_pose = pose @ camera_coord_flip - usd_matrix = column_vector_4x4_to_usd_matrix(usd_pose) - + # Time-sampled transforms — one sample per frame belonging to this camera xformable = UsdGeom.Xformable(camera_prim) transform_op = xformable.AddTransformOp() - transform_op.Set(usd_matrix) 
+ for frame_idx in frame_indices: + usd_pose = poses[frame_idx] @ _CAMERA_COORD_FLIP + transform_op.Set(column_vector_4x4_to_usd_matrix(usd_pose), float(frame_idx)) + usd_start_time_code = min(usd_start_time_code, float(frame_idx)) + usd_end_time_code = max(usd_end_time_code, float(frame_idx)) - # Set visibility imageable = UsdGeom.Imageable(camera_prim) - visibility = "inherited" if visible else "invisible" - imageable.CreateVisibilityAttr().Set(visibility) + imageable.CreateVisibilityAttr().Set("inherited" if visible else "invisible") - logger.info(f"Exported {num_cameras} cameras to {root_path}") - return root_path + result[cam_name] = camera_path + + if usd_start_time_code <= usd_end_time_code: + stage.SetStartTimeCode(usd_start_time_code) + stage.SetEndTimeCode(usd_end_time_code) + + logger.info( + f"Exported {len(result)} camera(s) ({len(poses)} total frames) to {root_path}" + ) + return result def export_camera_rig_with_timestamps( @@ -267,42 +241,30 @@ def export_camera_rig_with_timestamps( """ num_frames = poses.shape[0] - # Create rig xform rig_prim = stage.DefinePrim(root_path, "Xform") rig_xform = UsdGeom.Xformable(rig_prim) - # Coordinate transform - camera_coord_flip = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64) - - # USD time code setup usd_time_code_per_second = stage.GetTimeCodesPerSecond() - usd_timestamp_scale = usd_time_code_per_second * 1e-06 # microseconds to time codes + usd_timestamp_scale = usd_time_code_per_second * 1e-06 - # Create transform op for rig rig_transform_op = rig_xform.AddTransformOp() usd_start_time_code = float("inf") usd_end_time_code = 0.0 - # Add time-sampled transforms for i in range(num_frames): - pose = poses[i] - usd_pose = pose @ camera_coord_flip + usd_pose = poses[i] @ _CAMERA_COORD_FLIP usd_matrix = column_vector_4x4_to_usd_matrix(usd_pose) if timestamps_us is not None: - timestamp = timestamps_us[i] - usd_time_code = usd_timestamp_scale * (timestamp - 
timestamp_offset_us) - usd_start_time_code = min(usd_start_time_code, usd_time_code) - usd_end_time_code = max(usd_end_time_code, usd_time_code) + usd_time_code = usd_timestamp_scale * (timestamps_us[i] - timestamp_offset_us) else: usd_time_code = float(i) - usd_start_time_code = min(usd_start_time_code, usd_time_code) - usd_end_time_code = max(usd_end_time_code, usd_time_code) + usd_start_time_code = min(usd_start_time_code, usd_time_code) + usd_end_time_code = max(usd_end_time_code, usd_time_code) rig_transform_op.Set(usd_matrix, usd_time_code) - # Set time metadata if usd_start_time_code <= usd_end_time_code: stage.SetMetadata("startTimeCode", usd_start_time_code) stage.SetMetadata("endTimeCode", usd_end_time_code) @@ -310,15 +272,11 @@ def export_camera_rig_with_timestamps( if timestamps_us is not None: stage.SetMetadataByDictKey("customLayerData", "absoluteTimeOffsetMicroSec", timestamp_offset_us) - # Create camera prim under rig (static relative to rig) camera_path = f"{root_path}/{camera_name}" camera_prim = stage.DefinePrim(camera_path, "Camera") camera = UsdGeom.Camera(camera_prim) - - # Set default clipping range camera.GetClippingRangeAttr().Set(Gf.Vec2f(DEFAULT_NEAR_CLIP, DEFAULT_FAR_CLIP)) - # Add intrinsics if provided if camera_params is not None and len(camera_params) > 0: params = camera_params[0] if isinstance(params, OpenCVPinholeCameraModelParameters): @@ -330,15 +288,12 @@ def export_camera_rig_with_timestamps( else: camera.GetFocalLengthAttr().Set(24.0) - # Camera is at identity transform relative to rig (transform is on rig itself) xformable = UsdGeom.Xformable(camera_prim) transform_op = xformable.AddTransformOp() transform_op.Set(Gf.Matrix4d(1.0)) - # Set visibility imageable = UsdGeom.Imageable(camera_prim) - visibility = "inherited" if visible else "invisible" - imageable.CreateVisibilityAttr().Set(visibility) + imageable.CreateVisibilityAttr().Set("inherited" if visible else "invisible") logger.info(f"Exported camera rig with 
{num_frames} frames to {root_path}") return root_path diff --git a/threedgrut/export/usd/writers/lightfield.py b/threedgrut/export/usd/writers/lightfield.py index 4f207ac5..33016ba7 100644 --- a/threedgrut/export/usd/writers/lightfield.py +++ b/threedgrut/export/usd/writers/lightfield.py @@ -27,6 +27,10 @@ from pxr import Gf, Sdf, Usd, UsdGeom, UsdVol, Vt from threedgrut.export.accessor import GaussianAttributes, ModelCapabilities +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + normalize_particle_field_sorting_mode_hint, +) from threedgrut.export.usd.writers.base import GaussianUSDWriter logger = logging.getLogger(__name__) @@ -49,14 +53,23 @@ def __init__( half_geometry: bool = False, half_features: bool = False, projection_mode_hint: str = "perspective", - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, ) -> None: - super().__init__(stage, capabilities, content_root_path, linear_srgb=linear_srgb) + super().__init__( + stage, + capabilities, + content_root_path, + linear_srgb=linear_srgb, + omni_usd=omni_usd, + has_post_processing=has_post_processing, + ) self.half_geometry = half_geometry self.half_features = half_features self.projection_mode_hint = projection_mode_hint - self.sorting_mode_hint = sorting_mode_hint + self.sorting_mode_hint = normalize_particle_field_sorting_mode_hint(sorting_mode_hint) # Use surflet kernel for surfel models, ellipsoid for 3DGS self.use_surflet_kernel = capabilities.is_surfel @@ -91,6 +104,14 @@ def create_prim(self, num_gaussians: int) -> Usd.Prim: self._set_rendering_hints() self.apply_color_space_to_prim(self.prim) + if self.omni_usd: + from threedgrut.export.usd.writers.omni_material import bind_particlefield_emissive_material + + bind_particlefield_emissive_material( + stage=self.stage, + prim=self.prim, + 
has_post_processing=self.has_post_processing, + ) return self.prim def _apply_surflet_kernel_schemas(self) -> None: diff --git a/threedgrut/export/usd/writers/omni_material.py b/threedgrut/export/usd/writers/omni_material.py new file mode 100644 index 00000000..0c24ad81 --- /dev/null +++ b/threedgrut/export/usd/writers/omni_material.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def bind_particlefield_emissive_material(
    stage: Usd.Stage,
    prim: Usd.Prim,
    has_post_processing: bool = False,
) -> None:
    """Author the ParticleFieldEmissive MDL material and bind it to ``prim``.

    Defines the Looks scope (when it is not already valid), a Material prim and
    a Shader prim at the module-level ``USD_*`` paths, points the shader at the
    MDL source asset, connects the shader's ``out`` output to the material's
    three MDL terminals, and binds the material to ``prim``.

    Args:
        stage: Stage that receives the material prims.
        prim: Prim the material is bound to.
        has_post_processing: When true, the shader inputs
            ``apply_srgb_linear`` and ``apply_inverse_tonemap`` are authored
            with the value ``False``; otherwise they are not authored at all.
    """
    # Only define the Looks scope when the stage does not already have one.
    if not stage.GetPrimAtPath(USD_LOOKS_PATH).IsValid():
        stage.DefinePrim(USD_LOOKS_PATH, "Scope")

    material_prim = stage.DefinePrim(USD_PARTICLEFIELD_MATERIAL_PATH, "Material")
    shader_prim = stage.DefinePrim(USD_PARTICLEFIELD_SHADER_PATH, "Shader")

    # Uniform, non-custom attributes identifying the MDL source of the shader.
    source_attributes = (
        ("info:implementationSource", Sdf.ValueTypeNames.Token, "sourceAsset"),
        (
            "info:mdl:sourceAsset",
            Sdf.ValueTypeNames.Asset,
            Sdf.AssetPath(PARTICLEFIELD_MATERIAL_MDL_FILE),
        ),
        (
            "info:mdl:sourceAsset:subIdentifier",
            Sdf.ValueTypeNames.Token,
            PARTICLEFIELD_MATERIAL_NAME,
        ),
    )
    for attr_name, type_name, value in source_attributes:
        shader_prim.CreateAttribute(
            attr_name,
            type_name,
            custom=False,
            variability=Sdf.VariabilityUniform,
        ).Set(value)

    if has_post_processing:
        for flag_name in ("inputs:apply_srgb_linear", "inputs:apply_inverse_tonemap"):
            shader_prim.CreateAttribute(flag_name, Sdf.ValueTypeNames.Bool).Set(False)

    shader_prim.CreateAttribute("outputs:out", Sdf.ValueTypeNames.Token).SetMetadata(
        "renderType", "material"
    )

    material = UsdShade.Material(material_prim)
    shader = UsdShade.Shader(shader_prim)
    for terminal_name in ("mdl:displacement", "mdl:surface", "mdl:volume"):
        terminal = material.CreateOutput(terminal_name, Sdf.ValueTypeNames.Token)
        terminal.ConnectToSource(shader.GetOutput("out"))

    UsdShade.MaterialBindingAPI(prim).Bind(
        material,
        bindingStrength=UsdShade.Tokens.weakerThanDescendants,
    )
+ +The flatten layout must match ``ppisp_controller.slang``'s ``OFF_*`` +constants: + + conv1_weight (16 x 3) | conv1_bias (16) + conv2_weight (32 x 16) | conv2_bias (32) + conv3_weight (64 x 32) | conv3_bias (64) + trunk0_weight (128 x 1601) | trunk0_bias (128) + trunk1_weight (128 x 128) | trunk1_bias (128) + trunk2_weight (128 x 128) | trunk2_bias (128) + exposure_head_weight (128) | exposure_head_bias (1) + color_head_weight (8 x 128)| color_head_bias (8) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Sequence + +import numpy as np + +from pxr import Sdf, Usd, UsdShade, Vt + +from threedgrut.export.usd.stage_utils import NamedSerialized + +if TYPE_CHECKING: + import torch.nn as nn # noqa: F401 + +log = logging.getLogger(__name__) + + +# Names must match ppisp_controller.slang's bindings and ppisp_controller.slang.usda. +CONTROLLER_INPUT_RENDER_VAR = "HdrColor" +CONTROLLER_OUTPUT_NAME = "ControllerParams" +PRIOR_EXPOSURE_INPUT = "priorExposure" +WEIGHTS_INPUT = "weights" + +CONTROLLER_USDA_FILE = "ppisp_controller.slang.usda" +CONTROLLER_SLANG_FILE = "ppisp_controller.slang" + +# Architecture sizes (mirror ppisp._PPISPController defaults / shader constants). +EXPECTED_SIZES = { + "cnn_feature_dim": 64, + "pool_grid_h": 5, + "pool_grid_w": 5, + "mlp_hidden_dim": 128, + "color_params_per_frame": 8, + "input_downsampling": 3, +} + +# Total weight count. This *must* match ppisp_controller.slang::TOTAL_WEIGHTS. 
def _validate_controller_shape(controller) -> None:
    """Check that ``controller`` has the layer sizes the exported shader expects.

    Reads ``controller.cnn_encoder`` (index 0, 3 and 5 are treated as the three
    convolutions, index 1 as the max-pool and index 6 as the adaptive average
    pool), ``controller.mlp_trunk``, ``controller.exposure_head`` and
    ``controller.color_head`` and compares them with ``EXPECTED_SIZES``.

    All three convolutions must be 1x1: the weight flattening in this module
    reshapes each conv weight to ``[out_channels, in_channels]``, which is only
    possible for a 1x1 kernel.

    Raises:
        ValueError: On the first layer that differs from the expected shape.
    """

    def _pair(value):
        # Sizes may be spelled as a scalar or as a 2-sequence depending on how
        # the module was constructed; compare them in one canonical form.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    def _check_conv(name, conv, in_channels, out_channels) -> None:
        if conv.in_channels != in_channels or conv.out_channels != out_channels:
            raise ValueError(
                f"controller {name} must be {in_channels}->{out_channels}, "
                f"got {conv.in_channels}->{conv.out_channels}"
            )
        if _pair(conv.kernel_size) != (1, 1):
            raise ValueError(f"controller {name} kernel must be 1x1, got {conv.kernel_size}")

    cnn_encoder = controller.cnn_encoder
    conv1 = cnn_encoder[0]
    maxpool = cnn_encoder[1]
    conv2 = cnn_encoder[3]
    conv3 = cnn_encoder[5]
    avgpool = cnn_encoder[6]

    feature_dim = EXPECTED_SIZES["cnn_feature_dim"]
    downsampling = EXPECTED_SIZES["input_downsampling"]
    hidden_dim = EXPECTED_SIZES["mlp_hidden_dim"]
    color_params = EXPECTED_SIZES["color_params_per_frame"]
    expected_grid = (EXPECTED_SIZES["pool_grid_h"], EXPECTED_SIZES["pool_grid_w"])

    _check_conv("conv1", conv1, 3, 16)
    _check_conv("conv2", conv2, 16, 32)
    _check_conv("conv3", conv3, 32, feature_dim)

    if _pair(maxpool.kernel_size) != (downsampling, downsampling):
        raise ValueError(f"controller maxpool kernel must be {downsampling}, got {maxpool.kernel_size}")
    if _pair(maxpool.stride) != (downsampling, downsampling):
        raise ValueError(f"controller maxpool stride must be {downsampling}, got {maxpool.stride}")

    actual_grid = _pair(avgpool.output_size)
    if actual_grid != expected_grid:
        raise ValueError(f"controller AdaptiveAvgPool2d must be {expected_grid}, got {actual_grid}")

    # Linear layers are recognised by their 2-D weight; other trunk entries
    # (anything without a 2-D ``weight``) are ignored.
    linear_layers = [
        m for m in controller.mlp_trunk if hasattr(m, "weight") and m.weight.dim() == 2
    ]
    if len(linear_layers) != 3:
        raise ValueError(f"controller MLP trunk must have 3 Linear layers, got {len(linear_layers)}")

    # Pooled CNN features plus one extra scalar input.
    expected_input_dim = expected_grid[0] * expected_grid[1] * feature_dim + 1
    if linear_layers[0].in_features != expected_input_dim:
        raise ValueError(
            f"controller trunk[0].in_features must be {expected_input_dim}, got {linear_layers[0].in_features}"
        )
    for idx, layer in enumerate(linear_layers):
        if layer.out_features != hidden_dim:
            raise ValueError(
                f"controller trunk[{idx}].out_features must be {hidden_dim}, "
                f"got {layer.out_features}"
            )

    if controller.exposure_head.out_features != 1:
        raise ValueError("controller exposure_head must produce one output")
    if controller.color_head.out_features != color_params:
        raise ValueError(f"controller color_head must produce {color_params} outputs")
def add_controller_shader_to_render_product(
    stage: Usd.Stage,
    render_product_path: str,
    camera_index: int,
    controller,
    *,
    prior_exposure: float | None = None,
) -> UsdShade.Shader:
    """Author the controller Shader prim and connect ``HdrColor`` → ``ControllerParams``.

    The shader is defined at ``<render_product_path>/PPISPController_<camera_index>``
    and its output is exposed through a sibling ``ControllerParams`` RenderVar,
    which is added to the RenderProduct's ``orderedVars`` when that
    relationship exists.

    Args:
        stage: Stage holding the RenderProduct.
        render_product_path: Path of an existing RenderProduct prim.
        camera_index: Index used to name the controller shader prim.
        controller: Trained controller module; its weights are flattened with
            ``flatten_controller_weights`` and stored in ``inputs:weights``.
        prior_exposure: Value for the ``priorExposure`` input; ``None`` is
            written as ``0.0``.

    Returns:
        The created Shader so the caller can wire its output into the PPISP
        shader. The PPISP shader is responsible for *consuming* the output via
        its dynamic-controller binding.

    Raises:
        ValueError: If no valid prim exists at ``render_product_path``.
            Errors raised by ``flatten_controller_weights`` propagate unchanged.
    """
    render_product = stage.GetPrimAtPath(render_product_path)
    if not render_product.IsValid():
        raise ValueError(f"RenderProduct not found at path: {render_product_path}")

    # Flatten (and thereby validate) the weights before touching the stage so
    # a rejected controller does not leave a half-authored shader behind.
    weights = flatten_controller_weights(controller)

    # Mark HdrColor RenderVar input as an opaque AOV (no connection needed here).
    input_var_path = f"{render_product_path}/{CONTROLLER_INPUT_RENDER_VAR}"
    input_var_prim = stage.GetPrimAtPath(input_var_path)
    if input_var_prim.IsValid():
        input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False)

    shader_prim_name = f"PPISPController_{camera_index}"
    shader_path = f"{render_product_path}/{shader_prim_name}"
    shader = UsdShade.Shader.Define(stage, shader_path)
    shader_prim = shader.GetPrim()
    shader_prim.GetReferences().AddReference(CONTROLLER_USDA_FILE)
    shader_prim.CreateAttribute(
        "info:implementationSource", Sdf.ValueTypeNames.Token, custom=False
    ).Set("sourceAsset")
    shader_prim.CreateAttribute(
        "info:spg:sourceAsset", Sdf.ValueTypeNames.Asset, custom=False
    ).Set(Sdf.AssetPath(CONTROLLER_SLANG_FILE))
    shader_prim.CreateAttribute(
        "info:spg:sourceAsset:subIdentifier", Sdf.ValueTypeNames.Token, custom=False
    ).Set("controllerProcess")

    hdr_input = shader.CreateInput(CONTROLLER_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque)
    hdr_input.GetAttr().SetConnections([Sdf.Path(f"../{CONTROLLER_INPUT_RENDER_VAR}.omni:rtx:aov")])

    shader.CreateOutput(CONTROLLER_OUTPUT_NAME, Sdf.ValueTypeNames.Opaque)

    prior_input = shader.CreateInput(PRIOR_EXPOSURE_INPUT, Sdf.ValueTypeNames.Float)
    prior_input.Set(0.0 if prior_exposure is None else float(prior_exposure))

    weights_input = shader.CreateInput(WEIGHTS_INPUT, Sdf.ValueTypeNames.FloatArray)
    weights_input.Set(Vt.FloatArray.FromNumpy(weights))

    # Route the controller output through a RenderVar with omni:rtx:aov, so
    # SPG resolves it the same way it resolves HdrColor / LdrColor. Direct
    # Shader -> Shader connections work in slangpy but Kit's runtime walks
    # AOV connections, not arbitrary UsdShade outputs.
    var_path = f"{render_product_path}/{CONTROLLER_OUTPUT_NAME}"
    render_var = stage.DefinePrim(var_path, "RenderVar")
    render_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(CONTROLLER_OUTPUT_NAME)
    aov_attr = render_var.CreateAttribute(
        "omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False
    )
    aov_attr.SetConnections([
        shader.GetPath().AppendProperty(f"outputs:{CONTROLLER_OUTPUT_NAME}")
    ])

    # Add the intermediate var to RenderProduct.orderedVars so SPG discovers it.
    # Membership is tested with the RenderVar's absolute path: GetTargets()
    # reports absolute paths, so comparing against the bare relative name would
    # never match and repeated calls would append the same target again.
    ordered_vars_rel = render_product.GetRelationship("orderedVars")
    if ordered_vars_rel:
        targets = list(ordered_vars_rel.GetTargets())
        render_var_path = render_var.GetPath()
        if render_var_path not in targets:
            targets.append(render_var_path)
            ordered_vars_rel.SetTargets(targets)

    log.debug(
        "Authored PPISP controller shader at %s (camera %d, %d weights), "
        "AOV RenderVar at %s",
        shader_path, camera_index, weights.size, var_path,
    )
    return shader
+ """ + from threedgrut.export.usd.ppisp_spg import _SPG_DIR + filenames = [CONTROLLER_SLANG_FILE, CONTROLLER_SLANG_FILE + ".lua", CONTROLLER_USDA_FILE] + out: List[NamedSerialized] = [] + for name in filenames: + path = _SPG_DIR / name + if path.exists(): + out.append(NamedSerialized(filename=name, serialized=path.read_bytes())) + return out diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py new file mode 100644 index 00000000..83771153 --- /dev/null +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -0,0 +1,540 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PPISP USD Writer. + +Export PPISP (Physically Plausible Image Signal Processing) as a UsdShade +Shader prim on each camera's RenderProduct. Adapted from +nre-fermat/nre/utils/io/export/ppisp_usd_writer.py, replacing the +rig/timestamp frame-mapping with 3DGRUT integer frame indices. + +PPISP pipeline stages: +1. Exposure compensation (per-frame, time-sampled) +2. Vignetting correction (per-camera, static) +3. Color correction via ZCA-based homography (per-frame, time-sampled) +4. 
def build_camera_frame_mapping(dataset) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build per-camera frame lists from a 3DGRUT dataset.

    ``dataset`` must support ``len()``. Two optional accessors are used when
    present:

    * ``get_camera_names()`` supplies the camera names; without it a single
      camera named ``"camera_0"`` is assumed.
    * ``get_camera_idx(frame_idx)`` supplies the camera index of a frame;
      without it every frame is assigned to index 0.

    Frames whose camera index falls outside ``[0, len(camera_names))`` are left
    out of the mapping; a single warning reports how many were skipped.

    Returns:
        (camera_names, {camera_name: [frame_idx, ...]}) where frame_idx values
        are the global training indices used as USD time codes. Every camera
        name has an entry, possibly an empty list.
    """
    num_frames = len(dataset)

    get_camera_names = getattr(dataset, "get_camera_names", None)
    camera_names = get_camera_names() if get_camera_names is not None else ["camera_0"]

    # The accessor's presence cannot change between frames, so look it up once
    # instead of probing the dataset on every iteration.
    get_camera_idx = getattr(dataset, "get_camera_idx", None)

    num_cameras = len(camera_names)
    camera_frames: Dict[str, List[int]] = {name: [] for name in camera_names}
    skipped = 0

    for frame_idx in range(num_frames):
        cam_idx = get_camera_idx(frame_idx) if get_camera_idx is not None else 0
        if 0 <= cam_idx < num_cameras:
            camera_frames[camera_names[cam_idx]].append(frame_idx)
        else:
            skipped += 1

    if skipped:
        logging.getLogger(__name__).warning(
            "Skipped %d of %d frame(s) with a camera index outside [0, %d)",
            skipped, num_frames, num_cameras,
        )

    return camera_names, camera_frames
def _create_shader_prim(
    stage: Usd.Stage,
    render_product_path: str,
    *,
    controller_shader: UsdShade.Shader | None = None,
) -> UsdShade.Shader:
    """Create the PPISP Shader prim on a RenderProduct.

    When ``controller_shader`` is None, the static SPG variant is used and
    ``exposureOffset`` / colour latents must be authored as USD attributes
    on the returned Shader. When ``controller_shader`` is provided, the
    dynamic variant is used: the controller's ``ControllerParams`` output is
    wired into a new opaque input on the PPISP shader, and the per-frame
    exposure / colour params are sourced from the controller at runtime.
    Only the presence of ``controller_shader`` is inspected here; the
    connection itself targets the sibling ``ControllerParams`` RenderVar.

    Wires HdrColor → PPISP → LdrColor (and ControllerParams → PPISP when a
    controller is present) and makes sure LdrColor is listed in orderedVars
    when that relationship exists.

    Returns:
        The UsdShade.Shader for parameter setting.

    Raises:
        ValueError: If no valid prim exists at ``render_product_path``.
    """
    render_product = stage.GetPrimAtPath(render_product_path)
    if not render_product.IsValid():
        raise ValueError(f"RenderProduct not found at path: {render_product_path}")

    use_dynamic = controller_shader is not None
    usda_file = PPISP_SPG_DYN_USDA_FILE if use_dynamic else PPISP_SPG_USDA_FILE
    slang_file = PPISP_SPG_DYN_SLANG_FILE if use_dynamic else PPISP_SPG_SLANG_FILE
    sub_identifier = "ppispProcessDyn" if use_dynamic else "ppispProcess"

    # Mark HdrColor RenderVar input as an opaque AOV (no connection needed here)
    input_var_path = f"{render_product_path}/{PPISP_INPUT_RENDER_VAR}"
    input_var_prim = stage.GetPrimAtPath(input_var_path)
    if input_var_prim.IsValid():
        input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False)

    # PPISP Shader prim referencing the SPG asset definition
    ppisp_shader_path = f"{render_product_path}/PPISP"
    shader = UsdShade.Shader.Define(stage, ppisp_shader_path)
    shader_prim = shader.GetPrim()
    shader_prim.GetReferences().AddReference(usda_file)
    # Duplicate the source metadata on the instance. Some Kit SPG/Fabric paths
    # do not resolve referenced shader metadata when opening packaged USDZ files.
    shader_prim.CreateAttribute("info:implementationSource", Sdf.ValueTypeNames.Token, custom=False).Set(
        "sourceAsset"
    )
    shader_prim.CreateAttribute("info:spg:sourceAsset", Sdf.ValueTypeNames.Asset, custom=False).Set(
        Sdf.AssetPath(slang_file)
    )
    shader_prim.CreateAttribute("info:spg:sourceAsset:subIdentifier", Sdf.ValueTypeNames.Token, custom=False).Set(
        sub_identifier
    )

    # HdrColor opaque input wired to the input RenderVar's AOV
    hdr_input = shader.CreateInput(PPISP_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque)
    hdr_input.GetAttr().SetConnections([Sdf.Path(f"../{PPISP_INPUT_RENDER_VAR}.omni:rtx:aov")])

    if use_dynamic:
        controller_input = shader.CreateInput(PPISP_CONTROLLER_INPUT, Sdf.ValueTypeNames.Opaque)
        # Route through the controller's sibling RenderVar's omni:rtx:aov,
        # mirroring how PPISP reads HdrColor. SPG only resolves AOV
        # connections, not direct Shader -> Shader output references.
        controller_input.GetAttr().SetConnections(
            [Sdf.Path(f"../{PPISP_CONTROLLER_INPUT}.omni:rtx:aov")]
        )

    # PPISPColor opaque output
    shader.CreateOutput(PPISP_OUTPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque)

    # LdrColor RenderVar connected to the PPISP output. This intentionally
    # replaces the display AOV with PPISP's LDR output.
    ppisp_output_path = shader.GetPath().AppendProperty(f"outputs:{PPISP_OUTPUT_RENDER_VAR}")
    _add_ldr_color_render_var(stage, render_product_path, ppisp_output_path)

    # List LdrColor in orderedVars exactly once. The check uses the RenderVar's
    # absolute path because GetTargets() reports absolute paths; appending
    # unconditionally would list LdrColor twice whenever the RenderProduct
    # already orders it (or when this function runs again on the same prim).
    ordered_vars_rel = render_product.GetRelationship("orderedVars")
    if ordered_vars_rel:
        ldr_var_path = render_product.GetPath().AppendChild(LDR_COLOR_RENDER_VAR)
        targets = list(ordered_vars_rel.GetTargets())
        if ldr_var_path not in targets:
            targets.append(ldr_var_path)
            ordered_vars_rel.SetTargets(targets)

    return shader
def _set_animated_exposure_params(
    shader: UsdShade.Shader,
    ppisp: PPISP,
    frame_indices: List[int],
) -> None:
    """Author ``exposureOffset`` as a default value plus per-frame time samples.

    ``ppisp.exposure_params`` has shape ``[num_frames]``. The attribute's
    default (non-time-sampled) value is the mean exposure over this camera's
    valid frames, or ``0.0`` when there are none. Each valid frame then gets
    a time sample at time code ``float(frame_idx)``.

    Args:
        shader: PPISP shader that receives the ``exposureOffset`` input.
        ppisp: Trained PPISP module providing ``exposure_params``.
        frame_indices: Global frame indices belonging to this camera.
            Indices outside ``[0, num_frames)`` are ignored.
    """
    exposure = ppisp.exposure_params.detach().cpu().numpy()  # [num_frames]
    num_frames = len(exposure)

    # Reject negative indices as well as too-large ones: NumPy would wrap a
    # negative index to a frame counted from the end, and the sample would be
    # authored at a negative time code holding another frame's exposure.
    valid = [i for i in frame_indices if 0 <= i < num_frames]
    mean_val = float(np.mean(exposure[valid])) if valid else 0.0

    attr = shader.CreateInput("exposureOffset", Sdf.ValueTypeNames.Float).GetAttr()
    attr.Set(mean_val)  # default value, used when no time sample applies

    for frame_idx in valid:
        attr.Set(float(exposure[frame_idx]), float(frame_idx))
def add_ppisp_shader_to_render_product(
    stage: Usd.Stage,
    render_product_path: str,
    camera_index: int,
    ppisp: PPISP,
    frame_indices: List[int],
    fixed_frame_index: int | None = None,
    controller_shader: UsdShade.Shader | None = None,
) -> Usd.Prim:
    """Add a PPISP Shader to a RenderProduct for one physical camera.

    Per-camera parameters (responsivity, vignetting, CRF) are written as
    static USD attributes. Per-frame parameters (exposure, color latents)
    are handled in one of three ways:

    - ``controller_shader`` given: nothing is authored here; the values are
      read at runtime from the upstream controller shader.
    - ``fixed_frame_index`` given (and no controller): that single frame's
      state is written as static inputs.
    - otherwise: mean-based defaults plus one time sample per frame.

    Args:
        stage: USD stage containing the RenderProduct.
        render_product_path: Path to the RenderProduct prim.
        camera_index: Index of this camera in the PPISP model.
        ppisp: Trained PPISP module.
        frame_indices: Global frame indices belonging to this camera.
        fixed_frame_index: If set, write this one PPISP frame state as static
            shader inputs instead of authoring animated time samples.
        controller_shader: Optional upstream controller Shader whose
            ``ControllerParams`` output supplies exposure / colour latents.

    Returns:
        The created PPISP Shader prim. When there is nothing to author (no
        frames, no fixed frame and no controller) a warning is logged and the
        stage's pseudo-root prim is returned instead; callers can detect this
        with ``prim.IsPseudoRoot()``.

    Raises:
        ValueError: If ``camera_index`` is outside
            ``[0, ppisp.num_cameras - 1]``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, and a bad index would then only fail after the shader
    # prim has already been authored. Negative indices are rejected too,
    # because tensor indexing would silently wrap them to another camera.
    if not 0 <= camera_index < ppisp.num_cameras:
        raise ValueError(f"camera_index must be in [0, {ppisp.num_cameras - 1}], got {camera_index}.")

    uses_controller = controller_shader is not None
    if not frame_indices and fixed_frame_index is None and not uses_controller:
        log.warning(f"No frames for camera {camera_index} at {render_product_path}, skipping")
        return stage.GetPseudoRoot()

    shader = _create_shader_prim(stage, render_product_path, controller_shader=controller_shader)
    _set_responsivity_params(shader)
    _set_vignetting_params(shader, ppisp, camera_index)
    _set_crf_params(shader, ppisp, camera_index)

    # With a controller, exposure / colour latents are computed at runtime,
    # so neither static nor time-sampled values are authored here.
    if not uses_controller:
        if fixed_frame_index is None:
            _set_animated_exposure_params(shader, ppisp, frame_indices)
            _set_animated_color_params(shader, ppisp, frame_indices)
        else:
            _set_static_exposure_params(shader, ppisp, fixed_frame_index)
            _set_static_color_params(shader, ppisp, fixed_frame_index)

    controller_suffix = ", controller" if uses_controller else ""
    log.info(
        f"Added PPISP shader to {render_product_path} "
        f"(camera {camera_index}, {len(frame_indices)} frame(s){controller_suffix})"
    )
    return shader.GetPrim()
def add_ppisp_to_all_render_products(
    stage: Usd.Stage,
    ppisp: PPISP,
    camera_names: List[str],
    camera_frame_mapping: Dict[str, List[int]],
    render_scope_path: str = "/Render",
    fixed_camera_index: int | None = None,
    fixed_frame_index: int | None = None,
    use_controller: bool = False,
) -> List[Usd.Prim]:
    """Add PPISP shaders to every RenderProduct in the Render scope.

    Each ``RenderProduct`` child of ``render_scope_path`` is matched to a
    camera name through its prim name, gets a PPISP camera override, an
    optional controller shader, and finally the PPISP shader itself.

    Args:
        stage: USD stage with a populated /Render scope.
        ppisp: Trained PPISP module.
        camera_names: Ordered list of camera names (index = camera_idx in ppisp).
        camera_frame_mapping: ``{camera_name: [frame_idx, ...]}`` from
            :func:`build_camera_frame_mapping`.
        render_scope_path: Path to the /Render Scope (default ``/Render``).
        fixed_camera_index: If set, use this PPISP camera state for every
            RenderProduct instead of matching the RenderProduct camera.
        fixed_frame_index: If set, use this PPISP frame state as static shader
            inputs instead of authoring animated exposure/color samples.
        use_controller: If True, author a per-camera PPISP controller shader
            and wire its output into the PPISP shader, replacing the static /
            time-sampled exposure & colour inputs. Requires the controller
            sidecars to be packaged alongside the USD output.

    Returns:
        The PPISP Shader prims that were actually authored. RenderProducts
        for which the per-camera writer skipped authoring are not included.
        An empty list is returned when the Render scope does not exist.

    Raises:
        ValueError: If the camera index used for a RenderProduct is outside
            ``[0, ppisp.num_cameras - 1]``.
    """
    from threedgrut.export.usd.writers.camera import _make_usd_prim_name

    if use_controller:
        from threedgrut.export.usd.writers.ppisp_controller_writer import (
            add_controller_shader_to_render_product,
        )

    render_scope = stage.GetPrimAtPath(render_scope_path)
    if not render_scope.IsValid():
        log.warning(f"Render scope not found at {render_scope_path}, skipping PPISP export")
        return []

    camera_name_to_index = {name: idx for idx, name in enumerate(camera_names)}

    # RenderProduct prims are named _make_usd_prim_name(camera_name). Build
    # the reverse table once instead of re-sanitising every name for every
    # RenderProduct. setdefault keeps the first camera name when two names
    # sanitise to the same prim name, i.e. first match wins.
    prim_name_to_camera_name: Dict[str, str] = {}
    for name in camera_names:
        prim_name_to_camera_name.setdefault(_make_usd_prim_name(name), name)

    controllers = getattr(ppisp, "controllers", None) if use_controller else None
    created: List[Usd.Prim] = []

    for child in render_scope.GetChildren():
        if child.GetTypeName() != "RenderProduct":
            continue

        prim_name = child.GetName()
        camera_name = prim_name_to_camera_name.get(prim_name)
        if camera_name is None:
            log.warning(f"RenderProduct '{prim_name}' has no matching camera name, skipping")
            continue

        if fixed_camera_index is not None:
            camera_index = fixed_camera_index
            index_label = "fixed_camera_index"
        else:
            # camera_name was taken from camera_names, so this lookup cannot miss.
            camera_index = camera_name_to_index[camera_name]
            index_label = f"camera index for '{camera_name}'"
        if not 0 <= camera_index < ppisp.num_cameras:
            raise ValueError(f"{index_label} must be in [0, {ppisp.num_cameras - 1}], got {camera_index}.")

        frame_indices = camera_frame_mapping.get(camera_name, [])
        # NOTE(review): the camera override is authored before we know whether
        # the shader will be skipped below; confirm that is intended.
        _create_ppisp_camera(stage, child)

        controller_shader = None
        if use_controller:
            if controllers is None or int(camera_index) >= len(controllers):
                log.warning(
                    "PPISP controllers missing for camera %s (idx=%d); falling back to "
                    "static parameters for this RenderProduct.",
                    camera_name,
                    int(camera_index),
                )
            else:
                controller_shader = add_controller_shader_to_render_product(
                    stage=stage,
                    render_product_path=str(child.GetPath()),
                    camera_index=int(camera_index),
                    controller=controllers[int(camera_index)],
                )

        shader_prim = add_ppisp_shader_to_render_product(
            stage=stage,
            render_product_path=str(child.GetPath()),
            camera_index=camera_index,
            ppisp=ppisp,
            frame_indices=frame_indices,
            fixed_frame_index=fixed_frame_index,
            controller_shader=controller_shader,
        )
        # The per-camera writer returns the stage pseudo-root when it skips a
        # RenderProduct; that sentinel is not a shader and must not be
        # reported or counted as one.
        if shader_prim.IsPseudoRoot():
            continue
        created.append(shader_prim)

    log.info(f"Added PPISP shaders to {len(created)} RenderProduct(s)")
    return created
def create_render_products(
    stage: Usd.Stage,
    camera_entries: Dict[str, Tuple[str, int, int]],
    render_scope_path: str = _RENDER_SCOPE_PATH,
) -> None:
    """Author the Render scope and one RenderProduct per camera.

    For every entry a ``RenderProduct`` prim is defined under
    ``render_scope_path``, named after the sanitized camera name, carrying:

    - a ``resolution`` attribute (``int2``),
    - a ``camera`` relationship targeting the USD camera prim,
    - a child ``RenderVar`` named ``HdrColor`` with ``sourceName`` set to
      ``"HdrColor"`` and an opaque ``omni:rtx:aov`` attribute,
    - an ``orderedVars`` relationship targeting that RenderVar.

    Args:
        stage: USD stage that already contains the camera prims.
        camera_entries: ``{camera_name: (usd_camera_path, width, height)}``.
            ``camera_name`` is passed through ``_make_usd_prim_name`` to get
            the RenderProduct prim name.
        render_scope_path: Root path for the Render scope (default ``/Render``).
    """
    from threedgrut.export.usd.writers.camera import _make_usd_prim_name

    stage.DefinePrim(render_scope_path, "Scope")

    for camera_name, entry in camera_entries.items():
        camera_path, width, height = entry
        product_path = f"{render_scope_path}/{_make_usd_prim_name(camera_name)}"
        product_prim = stage.DefinePrim(product_path, "RenderProduct")

        # Output resolution in pixels.
        resolution_attr = product_prim.CreateAttribute("resolution", Sdf.ValueTypeNames.Int2)
        resolution_attr.Set(Gf.Vec2i(int(width), int(height)))

        # Which camera this product renders from.
        product_prim.CreateRelationship("camera").SetTargets([Sdf.Path(camera_path)])

        # HdrColor RenderVar; its AOV attribute is declared but left without a value.
        hdr_var = stage.DefinePrim(f"{product_path}/{_HDR_COLOR_VAR}", "RenderVar")
        hdr_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(_HDR_COLOR_VAR)
        hdr_var.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False)

        # Relative target: the RenderVar is a direct child of the product.
        product_prim.CreateRelationship("orderedVars").SetTargets([Sdf.Path(_HDR_COLOR_VAR)])

        log.debug(f"Created RenderProduct at {product_path} → camera {camera_path} ({width}×{height})")

    log.info(f"Created {len(camera_entries)} RenderProduct(s) under {render_scope_path}")
post-processing loaded from checkpoint") + elif "post_processing" in checkpoint and method == "ppisp": from ppisp import PPISP, PPISPConfig # Derive config from training settings to match trainer.py diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 223fa539..accb3430 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -421,6 +421,13 @@ def init_post_processing(self, conf: DictConfig): ) logger.info(f"📷 {method.upper()} initialized: {num_cameras} cameras, {num_frames} frames") + elif method == "linear-to-srgb": + from threedgrut.utils.post_processing_linear_to_srgb import LinearToSrgbPostProcessing + + self.post_processing = LinearToSrgbPostProcessing().to(self.device) + self.post_processing_optimizers = [] + self.post_processing_schedulers = [] + logger.info("Post-processing: linear-to-sRGB (no trainable parameters)") else: raise ValueError(f"Unknown post-processing method: {method}") @@ -801,6 +808,7 @@ def on_training_end(self): dataset=self.train_dataset, conf=conf, background=getattr(self, "background", None), + post_processing=getattr(self, "post_processing", None), ) # Export post-processing report (PPISP-based) diff --git a/threedgrut/utils/post_processing_linear_to_srgb.py b/threedgrut/utils/post_processing_linear_to_srgb.py new file mode 100644 index 00000000..657e8245 --- /dev/null +++ b/threedgrut/utils/post_processing_linear_to_srgb.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
    """Encode linear RGB as sRGB nonlinear light (IEC 61966-2-1 style piecewise).

    Same branch structure as ``linear_to_srgb`` in
    ``thirdparty/tiny-cuda-nn/scripts/common.py``:

    .. code-block:: python

        np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)

    with ``limit = 0.0031308``. Linear values above ``1`` can yield encoded
    values above ``1`` (HDR). Values at or below ``limit`` (including
    negatives) take the linear ``12.92 * x`` branch.

    Args:
        x: Linear RGB tensor (any shape).

    Returns:
        Encoded values, same shape / dtype / device as ``x``.
    """
    limit = 0.0031308
    # torch.where evaluates both branches, so the power branch must stay
    # well-behaved for inputs it is not selected for. Clamping at the break
    # point (rather than at a tiny epsilon that rounds to 0 in float16) keeps
    # the pow derivative finite everywhere. Selected values (x > limit) pass
    # through the clamp unchanged, so the forward result is the same.
    safe_x = torch.clamp(x, min=limit)
    return torch.where(
        x > limit,
        1.055 * torch.pow(safe_x, 1.0 / 2.4) - 0.055,
        12.92 * x,
    )


def srgb_to_linear(x: torch.Tensor) -> torch.Tensor:
    """Inverse of :func:`linear_to_srgb`: sRGB encoded values back to linear.

    Piecewise IEC 61966-2-1 with break point at ``0.04045``:

    .. code-block:: python

        np.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

    Encoded HDR values (``x > 1``) go through the power branch like any
    other value at or above the break point.

    Args:
        x: sRGB-encoded tensor (any shape).

    Returns:
        Linear values, same shape / dtype / device as ``x``.
    """
    limit = 0.04045
    # Keep the pow base positive for inputs that end up on the linear branch.
    positive_x = torch.clamp(x + 0.055, min=1e-08)
    return torch.where(
        x < limit,
        x / 12.92,
        torch.pow(positive_x / 1.055, 2.4),
    )


class LinearToSrgbPostProcessing(nn.Module):
    """``nn.Module`` wrapper so linear-to-sRGB can plug into the shared post-processing path.

    ``forward`` receives flattened RGB ``[N, 3]`` from ``apply_post_processing`` plus PPISP-style
    metadata (pixel coordinates, resolution, camera / frame indices, exposure). Only
    ``pred_rgb_flat`` is used; other arguments exist for API compatibility with PPISP.

    There are **no learnable parameters**. The only module state is the scalar
    ``_reg_loss_zero`` buffer. It is a persistent buffer, so ``state_dict()`` is
    not empty: it contains exactly that one key, and checkpoints written with
    this method store it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered as a buffer so it follows ``.to(device)``. Left persistent
        # on purpose: the saved module state is restored with
        # ``load_state_dict`` and must keep matching this key.
        self.register_buffer("_reg_loss_zero", torch.tensor(0.0))

    def forward(
        self,
        pred_rgb_flat: torch.Tensor,
        pixel_coords_flat: torch.Tensor,
        resolution=None,
        camera_idx=None,
        frame_idx=None,
        exposure_prior=None,
    ) -> torch.Tensor:
        """Encode ``pred_rgb_flat`` with :func:`linear_to_srgb`.

        Args:
            pred_rgb_flat: ``[H*W, 3]`` linear RGB.
            pixel_coords_flat: Unused (PPISP contract).
            resolution: Unused.
            camera_idx: Unused.
            frame_idx: Unused.
            exposure_prior: Unused.

        Returns:
            Encoded tensor with the same shape as ``pred_rgb_flat``.
        """
        return linear_to_srgb(pred_rgb_flat)

    def get_regularization_loss(self) -> torch.Tensor:
        """Return the scalar zero buffer; required by the trainer alongside PPISP."""
        return self._reg_loss_zero
+ +## Usage + +Compare two images: + +```bash +python tools/image_comparison/image_comparison.py --images path/to/a.png path/to/b.png +``` + +Compare two folders: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b +``` + +Optional arguments: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b --port 8080 --target_fps 20 +``` + +Serve on all network interfaces for another host: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b --host 0.0.0.0 --port 8080 +``` + +Then open `http://server-host:8080` from the other host, replacing `server-host` with the name or IP address of the machine running the viewer. If direct access is blocked, use SSH forwarding: + +```bash +ssh -L 8080:localhost:8080 user@server-host +``` + +## Viewer Modes + +- `Display Mode = fit_largest_dimension`: scales the image so one dimension fills the viewport while preserving the image aspect ratio. The other dimension is smaller than or equal to the viewport. This is the default. +- `Display Mode = fit`: stretches the image to fill both viewport dimensions. +- `slider`: displays both images in the same frame, split by a vertical or horizontal slider. The split can be changed from the `Slider Position` GUI control. +- `checkerboard`: alternates images with a checkerboard mask. +- `diff`: displays a selectable difference map with a `JET` colormap and a scale slider. + +When folder mode is used, images are matched by file name. Duplicate file names inside one folder are rejected so the comparison target is unambiguous. +Use `Previous Image` and `Next Image` to cycle through matched image pairs. + +## Metrics + +The `Metrics` panel displays readable text blocks for the current image pair and the global folder mean. `PSNR`, `SSIM`, and `FLIP` are computed automatically. + +The `Diff Metric` dropdown supports: + +- `l1`: per-pixel mean absolute RGB difference. +- `l2`: per-pixel RGB root mean squared difference.
+- `psnr`: per-pixel PSNR-derived error, where lower PSNR is brighter. +- `ssim`: local `1 - SSIM` dissimilarity. +- `flip`: FLIP error map when `flip-evaluator` is installed. + +`FLIP` depends on the `flip-evaluator` package, which provides the `flip_evaluator` Python module. Do not install the unrelated `flip` package. diff --git a/tools/image_comparison/image_comparison.py b/tools/image_comparison/image_comparison.py new file mode 100644 index 00000000..b5158254 --- /dev/null +++ b/tools/image_comparison/image_comparison.py @@ -0,0 +1,787 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def collect_images_by_name(folder: Path) -> Dict[str, Path]:
    """Index every image under ``folder`` (recursively) by its file name.

    Paths are visited in sorted order and filtered with ``is_image_path``.
    File names must be unique across the whole tree so that a name
    identifies exactly one image.

    Args:
        folder: Root directory to scan.

    Returns:
        Mapping ``{file_name: path}`` for every image found.

    Raises:
        ValueError: If the same file name occurs more than once; the message
            lists the clashing names in sorted order.
    """
    by_name: Dict[str, Path] = {}
    clashing_names = set()

    for candidate in sorted(folder.rglob("*")):
        if is_image_path(candidate):
            if candidate.name in by_name:
                clashing_names.add(candidate.name)
            else:
                by_name[candidate.name] = candidate

    if clashing_names:
        duplicate_list = ", ".join(sorted(clashing_names))
        raise ValueError(f"Duplicate image names found in {folder}: {duplicate_list}")

    return by_name
def build_folder_image_pairs(folder_a_path: Path, folder_b_path: Path) -> List[ImagePair]:
    """Pair up images that carry the same file name in two folders.

    Both folders are scanned with ``collect_images_by_name``; only names
    present in both produce a pair. Pairs are returned sorted by name.

    Args:
        folder_a_path: Folder providing the "A" images.
        folder_b_path: Folder providing the "B" images.

    Returns:
        One ``ImagePair`` per shared file name.

    Raises:
        ValueError: If either path is not a directory, or if the two folders
            share no image names.
    """
    for folder in (folder_a_path, folder_b_path):
        if not folder.is_dir():
            raise ValueError(f"Invalid folder path: {folder}")

    images_a = collect_images_by_name(folder_a_path)
    images_b = collect_images_by_name(folder_b_path)
    shared_names = sorted(images_a.keys() & images_b.keys())

    if not shared_names:
        raise ValueError(f"No matching image names found between {folder_a_path} and {folder_b_path}")

    pairs: List[ImagePair] = []
    for shared_name in shared_names:
        pairs.append(
            ImagePair(
                name=shared_name,
                image_a_path=images_a[shared_name],
                image_b_path=images_b[shared_name],
            )
        )
    return pairs
def render_checkerboard_comparison(image_a: np.ndarray, image_b: np.ndarray, checker_size: int) -> np.ndarray:
    """Interleave two images in a checkerboard pattern.

    The image is divided into square cells of ``checker_size`` pixels
    (values below 1 are treated as 1). Cells whose row-cell index plus
    column-cell index is even show ``image_a``; the others show ``image_b``.
    The top-left cell therefore always comes from ``image_a``.

    Args:
        image_a: First image; its first two dimensions define the grid.
        image_b: Second image, expected to be broadcast-compatible with ``image_a``.
        checker_size: Edge length of one checker cell in pixels.

    Returns:
        The composited image.
    """
    cell = max(1, checker_size)
    rows, cols = image_a.shape[:2]

    # Cell index per row / column; broadcasting the sum yields the full grid.
    row_cells = np.arange(rows)[:, None] // cell
    col_cells = np.arange(cols)[None, :] // cell
    take_a = (col_cells + row_cells) % 2 == 0

    return np.where(take_a[..., None], image_a, image_b)
def box_filter(image: np.ndarray, kernel_size: int) -> np.ndarray:
    """Return the local mean of ``image`` over a square, edge-padded window.

    The window sum is evaluated with a summed-area table (2-D cumulative sum),
    so the cost does not depend on ``kernel_size``.

    Args:
        image: Array whose first two axes are (height, width). Any trailing
            axes (e.g. channels) are filtered independently.
        kernel_size: Window edge length. Values below 1 are treated as 1 and
            even values are bumped to the next odd value so the window is
            centred on each pixel.

    Returns:
        An array with the same shape as ``image``. Floating-point inputs keep
        their dtype; integer/bool inputs produce ``float64``.
    """
    kernel_size = max(1, int(kernel_size))
    if kernel_size % 2 == 0:
        kernel_size += 1
    radius = kernel_size // 2

    image = np.asarray(image)
    result_dtype = image.dtype if np.issubdtype(image.dtype, np.inexact) else np.dtype(np.float64)
    # Accumulate in (at least) float64: a float32 summed-area table grows to
    # ~height*width, where its rounding error is comparable to the window sum
    # itself and corrupts the means/variances derived from it.
    accumulator_dtype = np.result_type(image.dtype, np.float64)

    # Only the two spatial axes are padded; trailing axes are left untouched.
    untouched_axes = ((0, 0),) * (image.ndim - 2)
    padded = np.pad(image, ((radius, radius), (radius, radius)) + untouched_axes, mode="edge")
    # A leading zero row/column lets every window be expressed as
    # I[y1, x1] - I[y0, x1] - I[y1, x0] + I[y0, x0].
    integral = np.pad(padded, ((1, 0), (1, 0)) + untouched_axes, mode="constant")
    integral = np.cumsum(np.cumsum(integral, axis=0, dtype=accumulator_dtype), axis=1)

    summed = (
        integral[kernel_size:, kernel_size:]
        - integral[:-kernel_size, kernel_size:]
        - integral[kernel_size:, :-kernel_size]
        + integral[:-kernel_size, :-kernel_size]
    )
    return (summed / float(kernel_size * kernel_size)).astype(result_dtype, copy=False)
def format_metric_value(value: Optional[float], precision: int = 5) -> str:
    """Format a metric for display in the metrics panel.

    Args:
        value: Metric value, or ``None`` when the metric could not be computed.
        precision: Number of digits after the decimal point for finite values.

    Returns:
        ``"unavailable"`` for ``None``, ``"nan"`` for NaN, ``"inf"`` / ``"-inf"``
        for infinities, otherwise the fixed-point representation of ``value``.
    """
    if value is None:
        return "unavailable"
    if math.isnan(value):
        return "nan"
    if math.isinf(value):
        # Keep the sign: collapsing -inf to "inf" would report the opposite extreme.
        return "inf" if value > 0 else "-inf"
    return f"{value:.{precision}f}"
def get_image_canvas_rect(
    image_size: Tuple[int, int],
    canvas_size: Tuple[int, int],
    display_mode: str,
) -> Tuple[float, float, float, float]:
    """Compute where an image is drawn inside a canvas.

    Args:
        image_size: ``(width, height)`` of the image in pixels.
        canvas_size: ``(width, height)`` of the canvas in pixels.
        display_mode: ``"fit"`` fills the whole canvas; any other value scales
            the image uniformly so it fits inside the canvas and centres it.

    Returns:
        ``(x0, y0, width, height)`` of the displayed image in canvas pixels.
    """
    image_width, image_height = image_size
    canvas_width, canvas_height = canvas_size
    canvas_width = max(1, canvas_width)
    canvas_height = max(1, canvas_height)

    if display_mode == "fit":
        return 0.0, 0.0, float(canvas_width), float(canvas_height)

    # Clamp degenerate image sizes the same way the canvas is clamped, so a
    # zero-sized image cannot trigger a division by zero below.
    image_width = max(1, image_width)
    image_height = max(1, image_height)

    scale = min(canvas_width / image_width, canvas_height / image_height)
    display_width = image_width * scale
    display_height = image_height * scale
    image_x0 = 0.5 * (canvas_width - display_width)
    image_y0 = 0.5 * (canvas_height - display_height)
    return image_x0, image_y0, display_width, display_height
class ImageComparisonViewer:
    """Interactive A/B image comparison served through a viser GUI.

    The viewer owns a viser server, the GUI controls, and four caches: decoded
    image pairs, per-pair metrics, the aggregated folder metrics, and per-pair
    error maps. ``run`` polls ``need_update`` and, when it is set, re-renders
    the comparison and pushes it to every connected client as the scene
    background image.
    """

    def __init__(self, image_pairs: List[ImagePair], host: str, port: int, target_fps: float) -> None:
        """Start the viser server and build the GUI.

        Args:
            image_pairs: Pairs offered in the "Image Pair" dropdown; must not be empty.
            host: Interface the viser server binds to.
            port: Port the viser server listens on.
            target_fps: Upper bound on how often ``run`` polls for updates.

        Raises:
            ValueError: If ``image_pairs`` is empty.
        """
        if not image_pairs:
            # init_ui() selects the first pair as the initial dropdown value;
            # fail with a clear message before a server is bound.
            raise ValueError("ImageComparisonViewer requires at least one image pair")

        self.image_pairs = image_pairs
        self.host = host
        self.port = port
        self.target_fps = target_fps
        self.viser = import_viser()
        self.server = self.viser.ViserServer(host=self.host, port=self.port)
        self.need_update = True
        # Keyed by pair name: (image_a, image_b, status text).
        self.image_cache: Dict[str, Tuple[np.ndarray, np.ndarray, str]] = {}
        # Keyed by pair name.
        self.metric_cache: Dict[str, MetricResults] = {}
        # Single-entry cache; the only key ever used is ``True``.
        self.global_metric_cache: Dict[bool, MetricResults] = {}
        # Keyed by (pair name, diff metric); stores the unscaled error map.
        self.error_map_cache: Dict[Tuple[str, str], np.ndarray] = {}

        self.image_pair_dropdown = None
        self.display_mode_dropdown = None
        self.mode_dropdown = None
        self.slider_direction_dropdown = None
        self.slider_position_slider = None
        self.checker_size_slider = None
        self.diff_metric_dropdown = None
        self.diff_scale_slider = None
        self.current_metrics_markdown = None
        self.global_metrics_markdown = None
        self.status_text = None

        self.init_ui()

        @self.server.on_client_connect
        def _(client) -> None:
            # A newly connected client has no background image yet.
            self.need_update = True

    def init_ui(self) -> None:
        """Create all GUI controls and wire their callbacks to ``need_update``."""
        with self.server.gui.add_folder("Image Comparison"):
            image_pair_names = [image_pair.name for image_pair in self.image_pairs]
            self.image_pair_dropdown = self.server.gui.add_dropdown(
                "Image Pair",
                options=image_pair_names,
                initial_value=image_pair_names[0],
            )
            previous_image_button = self.server.gui.add_button("Previous Image")
            next_image_button = self.server.gui.add_button("Next Image")
            self.display_mode_dropdown = self.server.gui.add_dropdown(
                "Display Mode",
                options=DISPLAY_MODES,
                initial_value=DISPLAY_MODES[0],
            )
            self.mode_dropdown = self.server.gui.add_dropdown(
                "Mode",
                options=COMPARISON_MODES,
                initial_value=COMPARISON_MODES[0],
            )
            self.slider_direction_dropdown = self.server.gui.add_dropdown(
                "Slider Direction",
                options=SLIDER_DIRECTIONS,
                initial_value=SLIDER_DIRECTIONS[0],
            )
            self.slider_position_slider = self.server.gui.add_slider(
                "Slider Position",
                min=0.0,
                max=1.0,
                step=0.01,
                initial_value=0.5,
            )
            self.checker_size_slider = self.server.gui.add_slider(
                "Checker Size",
                min=4,
                max=256,
                step=1,
                initial_value=32,
            )
            self.diff_metric_dropdown = self.server.gui.add_dropdown(
                "Diff Metric",
                options=DIFF_METRICS,
                initial_value=DIFF_METRICS[0],
            )
            self.diff_scale_slider = self.server.gui.add_slider(
                "Diff Scale",
                min=0.1,
                max=20.0,
                step=0.1,
                initial_value=4.0,
            )
            reload_button = self.server.gui.add_button("Reload Images")
            self.status_text = self.server.gui.add_text("Status", initial_value="Loading", disabled=True)

        with self.server.gui.add_folder("Metrics"):
            empty_metrics = MetricResults(psnr=None, ssim=None, flip=None)
            self.current_metrics_markdown = self.server.gui.add_markdown(
                format_metric_markdown("Current Image", empty_metrics)
            )
            self.global_metrics_markdown = self.server.gui.add_markdown(
                format_metric_markdown("Folder Mean", empty_metrics, count=len(self.image_pairs))
            )

        controls = [
            self.image_pair_dropdown,
            self.display_mode_dropdown,
            self.mode_dropdown,
            self.slider_direction_dropdown,
            self.slider_position_slider,
            self.checker_size_slider,
            self.diff_metric_dropdown,
            self.diff_scale_slider,
        ]

        for control in controls:

            @control.on_update
            def _(_) -> None:
                self.need_update = True

        @reload_button.on_click
        def _(_) -> None:
            # Drop everything derived from the files on disk.
            self.image_cache.clear()
            self.metric_cache.clear()
            self.global_metric_cache.clear()
            self.error_map_cache.clear()
            self.need_update = True

        @previous_image_button.on_click
        def _(_) -> None:
            self.select_relative_image_pair(offset=-1)

        @next_image_button.on_click
        def _(_) -> None:
            self.select_relative_image_pair(offset=1)

    def select_relative_image_pair(self, offset: int) -> None:
        """Move the dropdown selection by ``offset`` entries, wrapping around."""
        selected_name = self.image_pair_dropdown.value
        image_pair_names = [image_pair.name for image_pair in self.image_pairs]
        try:
            selected_index = image_pair_names.index(selected_name)
        except ValueError:
            # Unknown selection: treat it as the first entry.
            selected_index = 0

        next_index = (selected_index + offset) % len(self.image_pairs)
        self.image_pair_dropdown.value = image_pair_names[next_index]
        self.need_update = True

    def get_selected_pair(self) -> ImagePair:
        """Return the pair named by the dropdown, or the first pair if none matches."""
        selected_name = self.image_pair_dropdown.value
        for image_pair in self.image_pairs:
            if image_pair.name == selected_name:
                return image_pair

        return self.image_pairs[0]

    def get_aligned_pair(self, image_pair: ImagePair) -> Tuple[np.ndarray, np.ndarray, str]:
        """Return ``(image_a, image_b, status)`` for a pair, loading and caching on first use."""
        if image_pair.name not in self.image_cache:
            self.image_cache[image_pair.name] = load_aligned_pair(image_pair)
        return self.image_cache[image_pair.name]

    def get_metric_results(self, image_pair: ImagePair, image_a: np.ndarray, image_b: np.ndarray) -> MetricResults:
        """Return the metrics for a pair, computing and caching them on first use."""
        if image_pair.name not in self.metric_cache:
            self.metric_cache[image_pair.name] = compute_metric_results(image_a=image_a, image_b=image_b)
        return self.metric_cache[image_pair.name]

    def get_global_metric_results(self) -> MetricResults:
        """Return the metrics aggregated over every pair, computed once and cached.

        Pairs that have not been displayed yet are decoded only for the
        duration of their metric computation; they are not added to
        ``image_cache``, so aggregating a large folder does not keep every
        decoded image in memory.
        """
        if True in self.global_metric_cache:
            return self.global_metric_cache[True]

        per_pair_results = []
        for image_pair in self.image_pairs:
            metric_results = self.metric_cache.get(image_pair.name)
            if metric_results is None:
                loaded = self.image_cache.get(image_pair.name)
                if loaded is None:
                    # Transient load: only the (small) metric result is retained.
                    loaded = load_aligned_pair(image_pair)
                image_a, image_b, _ = loaded
                metric_results = self.get_metric_results(image_pair=image_pair, image_a=image_a, image_b=image_b)
            per_pair_results.append(metric_results)

        self.global_metric_cache[True] = aggregate_metric_results(per_pair_results)
        return self.global_metric_cache[True]

    def update_metric_widgets(
        self,
        current_metric_results: MetricResults,
        global_metric_results: MetricResults,
    ) -> None:
        """Refresh the two markdown panels in the "Metrics" folder."""
        self.current_metrics_markdown.content = format_metric_markdown("Current Image", current_metric_results)
        self.global_metrics_markdown.content = format_metric_markdown(
            "Folder Mean",
            global_metric_results,
            count=len(self.image_pairs),
        )

    def render_current_diff(
        self,
        image_pair: ImagePair,
        image_a: np.ndarray,
        image_b: np.ndarray,
    ) -> np.ndarray:
        """Render the colour-mapped error image for the selected diff metric.

        The unscaled error map is cached per (pair, metric), so moving the
        "Diff Scale" slider only repeats the scaling and colour mapping.
        """
        diff_metric = self.diff_metric_dropdown.value
        cache_key = (image_pair.name, diff_metric)
        if cache_key not in self.error_map_cache:
            self.error_map_cache[cache_key] = compute_error_map(
                image_a=image_a,
                image_b=image_b,
                diff_metric=diff_metric,
            )

        scaled_error = np.clip(self.error_map_cache[cache_key] * float(self.diff_scale_slider.value), 0.0, 1.0)
        return apply_jet_colormap(scaled_error)

    def render_current_comparison(self) -> np.ndarray:
        """Render the current comparison and refresh the metric and status widgets.

        Returns:
            The comparison image converted to ``uint8``.
        """
        image_pair = self.get_selected_pair()
        image_a, image_b, status = self.get_aligned_pair(image_pair)
        metric_results = self.get_metric_results(image_pair=image_pair, image_a=image_a, image_b=image_b)
        global_metric_results = self.get_global_metric_results()
        self.update_metric_widgets(
            current_metric_results=metric_results,
            global_metric_results=global_metric_results,
        )
        mode = self.mode_dropdown.value

        if mode == "checkerboard":
            output = render_checkerboard_comparison(
                image_a=image_a,
                image_b=image_b,
                checker_size=int(self.checker_size_slider.value),
            )
        elif mode == "diff":
            output = self.render_current_diff(image_pair=image_pair, image_a=image_a, image_b=image_b)
        else:
            # Any other mode value falls back to the slider comparison.
            output = render_slider_comparison(
                image_a=image_a,
                image_b=image_b,
                slider_position=float(self.slider_position_slider.value),
                slider_direction=self.slider_direction_dropdown.value,
            )

        status_line = f"{status} | mode: {mode}"
        if mode == "diff":
            diff_metric = self.diff_metric_dropdown.value
            warning = self.get_metric_warning(metric_results=metric_results, diff_metric=diff_metric)
            status_line = f"{status} | mode: {mode} | diff: {diff_metric}{warning}"
        self.status_text.value = status_line
        return float_image_to_uint8(output)

    def get_metric_warning(self, metric_results: MetricResults, diff_metric: str) -> str:
        """Return a status suffix describing a FLIP failure, or an empty string."""
        if diff_metric == "flip" and metric_results.flip_error is not None:
            return f" | {metric_results.flip_error}"
        return ""

    def display_output(self, output: np.ndarray) -> None:
        """Fit ``output`` to each client's canvas and set it as the background image."""
        display_mode = self.display_mode_dropdown.value
        for client in self.server.get_clients().values():
            # Fall back to the image's own size when the client reports no canvas size.
            canvas_width = int(client.camera.image_width or output.shape[1])
            canvas_height = int(client.camera.image_height or output.shape[0])
            display_image = fit_image_to_canvas(
                image=output,
                canvas_size=(canvas_width, canvas_height),
                display_mode=display_mode,
            )
            client.scene.set_background_image(display_image, format="jpeg")

    def update(self) -> None:
        """Re-render and redisplay if an update has been requested."""
        if not self.need_update:
            return

        # Clear the flag *before* rendering: a GUI callback that sets it while
        # the render is in progress must trigger another pass instead of being
        # overwritten once this one finishes.
        self.need_update = False

        output = self.render_current_comparison()
        self.display_output(output)

    def run(self) -> None:
        """Print the server URL(s) and poll for updates forever."""
        print_server_urls(host=self.host, port=self.port)
        while True:
            self.update()
            target_fps = self.target_fps
            # A non-positive rate has no meaningful period; poll at the 1 ms floor
            # instead of dividing by zero.
            time.sleep(max(0.001, 1.0 / target_fps) if target_fps > 0 else 0.001)
def parse_args() -> argparse.Namespace:
    """Parse and validate the viewer's command-line arguments.

    Exactly one of ``--images`` or ``--folders`` is required, each taking two
    paths. ``--target_fps`` must be a positive number because the viewer loop
    derives its sleep interval as ``1 / target_fps``.

    Returns:
        The parsed namespace with ``images``, ``folders``, ``host``, ``port``
        and ``target_fps`` attributes.

    Raises:
        SystemExit: On invalid arguments (argparse prints usage and exits).
    """
    parser = argparse.ArgumentParser(description="Viser based image comparison viewer.")

    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--images",
        nargs=2,
        metavar=("IMAGE_A", "IMAGE_B"),
        help="Compare two specific images.",
    )
    input_group.add_argument(
        "--folders",
        nargs=2,
        metavar=("FOLDER_A", "FOLDER_B"),
        help="Compare matching image names from two folders.",
    )

    parser.add_argument("--host", type=str, default="0.0.0.0", help="Viser server host/interface.")
    parser.add_argument("--port", type=int, default=8080, help="Viser server port.")
    parser.add_argument("--target_fps", type=float, default=20.0, help="Maximum UI refresh rate.")
    args = parser.parse_args()

    # ``not (x > 0)`` also rejects NaN, which ``x <= 0`` would let through.
    if not args.target_fps > 0.0:
        parser.error("--target_fps must be a positive number")
    return args
+# SPDX-License-Identifier: Apache-2.0 diff --git a/tools/ppisp_export/bake_modes_benchmark/benchmark.py b/tools/ppisp_export/bake_modes_benchmark/benchmark.py new file mode 100644 index 00000000..0d1d902f --- /dev/null +++ b/tools/ppisp_export/bake_modes_benchmark/benchmark.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Sweep PPISP SH-bake modes on a trained checkpoint and report aggregated metrics. + +The bake fits gamma-space (display-referred) SH coefficients against the +PPISP forward output of the trained model, matching the colour space of +the no-PPISP export. Modes vary along two axes: + +* ``simple`` flavours skip optimisation and write only the DC band. +* ``fit`` flavours run :func:`bake_post_processing_into_sh` -- Adam over + features_albedo, features_specular, and (optionally) density. View + sampling is either ``training`` (iterate the dataloader) or + ``trajectory`` (NN+2-opt arc-length-parameterised slerp through training + poses; useful when training views are sparse). + +Per-frame validation: + reference = full PPISP applied to reference-model render at val pose + baked = baked-model render (already display-referred) clipped to [0, 1] + +Metrics: per-frame PSNR (+ optional SSIM / LPIPS), aggregated mean / +median / min / max across the val split. Raw per-frame numbers are +persisted to ``/metrics.json``. 
def _build_simple(*, model, ppisp, camera_id, frame_id, higher_order,
                  dataset=None, conf=None):
    """Build a baked model with the one-shot (no optimisation) bake.

    The input model is cloned and switched to eval mode, ``simple_bake`` is
    applied to the clone for the given camera/frame, and the clone's
    acceleration structure is rebuilt via ``build_acc``.

    ``dataset`` and ``conf`` are accepted so that every builder can be called
    with the same keyword set; this flavour does not use them.

    Returns:
        The baked clone; ``model`` itself is not passed to ``simple_bake``.
    """
    bake_options = {
        "camera_id": camera_id,
        "frame_id": frame_id,
        "higher_order": higher_order,
        "apply_srgb_to_linear": False,
    }
    baked_clone = model.clone().eval()
    simple_bake(baked_clone, ppisp, **bake_options)
    baked_clone.build_acc()
    return baked_clone
def all_modes(*, fit_epochs: int, fit_lr: float, view_seed: int) -> List[BakeMode]:
    """Return the full catalogue of bake modes, in sweep order.

    Each mode's builder accepts the common keyword set supplied by the caller
    and forwards it to ``_build_simple`` or ``_build_fit`` together with the
    mode-specific settings.

    Args:
        fit_epochs: Epoch count forwarded to every ``fit`` flavour.
        fit_lr: Learning rate forwarded to every ``fit`` flavour.
        view_seed: Seed forwarded to every ``fit`` flavour as ``view_seed``.
    """

    def simple_builder(higher_order):
        def build(**common):
            return _build_simple(**common, higher_order=higher_order)

        return build

    def fit_builder(view_mode, optimize_density):
        def build(**common):
            return _build_fit(
                **common, view_mode=view_mode, view_seed=view_seed,
                epochs=fit_epochs, learning_rate=fit_lr, optimize_density=optimize_density,
            )

        return build

    catalogue = (
        (
            "simple",
            "one-shot DC-only bake (no fit, gamma SH)",
            simple_builder(False),
        ),
        (
            "simple-higher-order",
            "one-shot DC + Jacobian-rotated specular (no fit)",
            simple_builder(True),
        ),
        (
            "fit-color-only",
            "Adam fit on features_albedo + features_specular only, training views",
            fit_builder("training", False),
        ),
        (
            "fit",
            "Adam fit on albedo + specular + density, training views (production default)",
            fit_builder("training", True),
        ),
        (
            "fit-trajectory",
            "Adam fit on albedo + specular + density, trajectory views (NN+2-opt slerp)",
            fit_builder("trajectory", True),
        ),
    )
    return [BakeMode(name, description, builder) for name, description, builder in catalogue]
field(default_factory=list) + lpips: List[float] = field(default_factory=list) + + +def _stats(values: List[float]) -> Dict[str, float]: + if not values: + return {"mean": float("nan"), "median": float("nan"), + "min": float("nan"), "max": float("nan")} + arr = np.asarray(values, dtype=np.float64) + return { + "mean": float(np.mean(arr)), + "median": float(np.median(arr)), + "min": float(np.min(arr)), + "max": float(np.max(arr)), + "n": len(values), + } + + +def _evaluate_mode( + baked_model, + reference_model, + fixed_pp, + dataset, + dataloader, + criteria, + max_frames: Optional[int], +) -> FrameMetrics: + fm = FrameMetrics() + with torch.no_grad(): + for i, batch in enumerate(dataloader): + if max_frames is not None and i >= max_frames: + break + gpu_batch = dataset.get_gpu_batch_with_intrinsics(batch) + + # reference: full per-frame PPISP applied to reference render + ref_outputs = reference_model(gpu_batch) + ref_outputs = apply_post_processing(fixed_pp, ref_outputs, gpu_batch, training=False) + ref_rgb = ref_outputs["pred_rgb"].clip(0, 1) + + # baked: SH eval is already display-referred (gamma); just clip. 
+ baked_outputs = baked_model(gpu_batch) + baked_rgb = torch.clamp(baked_outputs["pred_rgb"], 0, 1) + + fm.psnr.append(criteria["psnr"](baked_rgb, ref_rgb).item()) + if "ssim" in criteria: + fm.ssim.append(criteria["ssim"]( + baked_rgb.permute(0, 3, 1, 2), ref_rgb.permute(0, 3, 1, 2), + ).item()) + if "lpips" in criteria: + fm.lpips.append(criteria["lpips"]( + baked_rgb.clip(0, 1).permute(0, 3, 1, 2), + ref_rgb.clip(0, 1).permute(0, 3, 1, 2), + ).item()) + return fm + + +# --------------------------------------------------------------------------- +# Reporting +# --------------------------------------------------------------------------- + + +def _print_table(rows: Dict[str, Dict[str, Dict[str, float]]]) -> None: + """Print one table per metric (PSNR / SSIM / LPIPS), sorted by mean.""" + for metric in ("psnr", "ssim", "lpips"): + any_data = any(metric in r for r in rows.values()) + if not any_data: + continue + print(f"\n=== {metric.upper()} (val split, {next(iter(rows.values())).get(metric, {}).get('n', '?')} frames) ===") + if metric == "psnr": + print(f"{'mode':<28} {'mean':>9} {'median':>9} {'min':>9} {'max':>9}") + else: + print(f"{'mode':<28} {'mean':>9} {'median':>9} {'min':>9} {'max':>9}") + sorted_modes = sorted( + rows.items(), + key=lambda kv: -kv[1].get(metric, {}).get("mean", float("-inf")) + if metric == "psnr" or metric == "ssim" + else kv[1].get(metric, {}).get("mean", float("inf")), + ) + for mode_name, metrics in sorted_modes: + s = metrics.get(metric) + if s is None: + continue + fmt = "%.3f" if metric != "psnr" else "%6.3f" + print( + f"{mode_name:<28} " + f"{s['mean']:>9.4f} {s['median']:>9.4f} " + f"{s['min']:>9.4f} {s['max']:>9.4f}" + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint", type=Path, 
required=True) + parser.add_argument("--data-path", type=str, default="") + parser.add_argument("--out-dir", type=Path, required=True) + parser.add_argument("--camera-id", type=int, default=0) + parser.add_argument("--frame-id", type=int, default=0) + parser.add_argument("--fit-epochs", type=int, default=1) + parser.add_argument("--fit-lr", type=float, default=1.0e-3) + parser.add_argument("--view-seed", type=int, default=0) + parser.add_argument("--max-frames", type=int, default=None, + help="Limit val frames for quick smoke checks.") + parser.add_argument("--modes", nargs="*", default=None, + help="Subset of mode names to run (default: all).") + parser.add_argument("--no-extra-metrics", action="store_true", + help="Skip SSIM/LPIPS (PSNR only).") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA required.") + + args.out_dir.mkdir(parents=True, exist_ok=True) + + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=args.data_path, + out_dir=str(args.out_dir / "_renderer"), + save_gt=False, computes_extra_metrics=not args.no_extra_metrics, + ) + if renderer.post_processing is None: + raise SystemExit("Checkpoint does not contain PPISP.") + ppisp = renderer.post_processing + if not hasattr(ppisp, "vignetting_params"): + raise SystemExit("Checkpoint post-processing is not PPISP-like.") + + # The bake target is PPISP-without-vignetting (matches the production + # MODE_PPISP_BAKE_VIGNETTING_NONE adapter); both reference and baked + # sides therefore live in the same display-referred space. + fixed_pp = FixedPPISP( + ppisp, args.camera_id, args.frame_id, "cuda", include_vignetting=False, + ).eval() + + # Train dataset for the fit modes (interpolated samplers need it for poses). 
+ from threedgrut.export.usd.post_processing_sh_bake import _create_train_dataloader + train_dataset = renderer.dataset.__class__ # type: ignore + # Re-create train dataset from the loader's dataset reference: easier to + # use renderer.conf-based factory. + import threedgrut.datasets as datasets + train_ds = datasets.make_train(name=renderer.conf.dataset.type, config=renderer.conf, ray_jitter=None) + + from torchmetrics import PeakSignalNoiseRatio + criteria: Dict[str, nn.Module] = {"psnr": PeakSignalNoiseRatio(data_range=1).to("cuda")} + if not args.no_extra_metrics: + from torchmetrics.image import StructuralSimilarityIndexMeasure + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + criteria["ssim"] = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda") + criteria["lpips"] = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to("cuda") + + catalogue = all_modes( + fit_epochs=args.fit_epochs, fit_lr=args.fit_lr, view_seed=args.view_seed, + ) + if args.modes is not None: + wanted = set(args.modes) + catalogue = [m for m in catalogue if m.name in wanted] + if not catalogue: + raise SystemExit(f"No modes match {sorted(wanted)}") + + rows: Dict[str, Dict[str, Dict[str, float]]] = {} + timings: Dict[str, float] = {} + + for mode in catalogue: + logger.info("=" * 60) + logger.info("MODE %s -- %s", mode.name, mode.description) + t0 = time.time() + baked = mode.builder( + model=renderer.model, ppisp=ppisp, dataset=train_ds, conf=renderer.conf, + camera_id=args.camera_id, frame_id=args.frame_id, + ) + build_time = time.time() - t0 + logger.info(" built in %.2fs", build_time) + + fm = _evaluate_mode( + baked, renderer.model, fixed_pp, + renderer.dataset, renderer.dataloader, criteria, + max_frames=args.max_frames, + ) + row = {"psnr": _stats(fm.psnr)} + if not args.no_extra_metrics: + row["ssim"] = _stats(fm.ssim) + row["lpips"] = _stats(fm.lpips) + rows[mode.name] = row + timings[mode.name] = build_time + logger.info( + " 
%s: PSNR mean=%.3f median=%.3f (n=%d)", + mode.name, row["psnr"]["mean"], row["psnr"]["median"], row["psnr"]["n"], + ) + + _print_table(rows) + print("\n=== Build time (seconds) ===") + for name, t in sorted(timings.items(), key=lambda kv: kv[1]): + print(f" {name:<28} {t:>7.2f} s") + + # Persist raw per-frame numbers for offline analysis. + serial = { + name: { + "build_time_s": timings[name], + **{ + metric: rows[name][metric] + for metric in ("psnr", "ssim", "lpips") if metric in rows[name] + }, + } + for name in rows + } + with open(args.out_dir / "metrics.json", "w") as f: + json.dump(serial, f, indent=2) + logger.info("metrics.json saved to %s", args.out_dir / "metrics.json") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/README.md b/tools/render_ppisp_spg/README.md new file mode 100644 index 00000000..16716f2b --- /dev/null +++ b/tools/render_ppisp_spg/README.md @@ -0,0 +1,50 @@ +# render_ppisp_spg + +Headless **slangpy** harness for the PPISP SPG sidecars. Lets you +validate the exported `*.slang` / `*.slang.lua` chain end-to-end without +booting Omniverse Kit. + +## What it does + +- Loads `ppisp_controller.slang` (and the dynamic / static + `ppisp_usd_spg*.slang`) directly from the on-disk SPG sidecar set. +- Strips the `[[vk::binding(*, *)]]` annotations that Kit's SPG layer + consumes (slangpy uses its own auto-binding) and dispatches the same + compute kernel. +- Reads time-sampled USD attributes off a PPISP-bearing + `RenderProduct` and walks frame-by-frame against an HDR input dir. + +## Three entry points + +| Function | Use | +| --- | --- | +| `run_controller(slang, hdr, weights, prior=0)` | Returns the 9-float controller output: `[exposureOffset, blue.xy, red.xy, green.xy, neutral.xy]`. | +| `run_ppisp_dyn(slang, hdr, ctrl_out, vignette, crf)` | Reads colour / exposure from a controller output texture and returns an LDR uint8 image. 
| +| `run_ppisp_static(slang, hdr, exposure, color_latents, vignette, crf)` | The legacy controller-free path; reads exposure / colour from explicit args. | + +## CLI + +``` +python tools/render_ppisp_spg/render_renderproduct.py \ + out.usdz hdr_inputs/ ldr_outputs/ +``` + +The HDR input layout is one folder per camera-name, with files named +`<frame_index>.{npy,exr,png}`. + +## Validation + +``` +python tools/render_ppisp_spg/validate_controller.py --tol 1e-4 +``` + +Generates a synthetic torch `_PPISPController`, bakes its weights via +`flatten_controller_weights`, dispatches the SPG controller shader, and +compares the 9-element result against the torch reference. Typical max +abs diff is around 4e-6. + +## Dependencies + +`slangpy`, `numpy`, `Pillow`, `usd-core`, and (only for +`validate_controller.py`) `torch`. `OpenEXR`/`Imath` are optional and +only loaded when an `.exr` HDR input is encountered. diff --git a/tools/render_ppisp_spg/__init__.py b/tools/render_ppisp_spg/__init__.py new file mode 100644 index 00000000..52a7a9da --- /dev/null +++ b/tools/render_ppisp_spg/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tools/render_ppisp_spg/diagnose_controller.py b/tools/render_ppisp_spg/diagnose_controller.py new file mode 100644 index 00000000..daf73d41 --- /dev/null +++ b/tools/render_ppisp_spg/diagnose_controller.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Triage helper for "controller-driven Omniverse render disagrees with the +optimized-params render". + +Runs three independent checks in order, each isolating one of the three +hypotheses you flagged. 
+ +H1 -- *Did the controller actually learn the optimized params?* + For a sample of frames, compare in-process + exposure, color = controller(gaussian_render(frame), prior=0) + against the trained per-frame parameters + ppisp.exposure_params[frame_idx], ppisp.color_params[frame_idx]. + Pure PyTorch, no SPG/slang. If these disagree, the controller has not + converged to the per-frame state and any SPG export will inherit that. + +H2 -- *Does the slang controller match the PyTorch controller?* + Run controller(rgb, prior) twice on the same input: + a. PyTorch (ppisp.controllers[c]). + b. slangpy on ppisp_controller.slang. + These should agree to ~1e-6 (we measured 3e-7 on bonsai). A larger + delta means the slang shader, the weight flatten, or the buffer + upload disagrees with the trained controller. + +H3 -- *Is the SPG controller -> PPISP plumbing sound?* + Two ways to drive the dynamic PPISP shader on the same HDR: + a. dynamic path: controller slang writes ControllerParams texture, + ppisp_usd_spg_dyn.slang reads it. + b. static path: feed the *same* 9 floats (taken from the PyTorch + controller in step H1/H2) as USD attributes into the legacy + ppisp_usd_spg.slang shader. + These should produce byte-for-byte the same LDR image. If they + disagree, the dynamic shader's texture binding or layout is wrong. 
def _vignette_for_camera(ppisp, camera_idx: int) -> VignetteParams:
    """Build the vignetting parameter set for one camera.

    Row ``camera_idx`` of ``ppisp.vignetting_params`` is detached, moved to
    the host and read channel by channel (rows 0, 1, 2 -> r, g, b).  Within a
    channel row, columns 0-1 become the ``center_*`` pair and columns 2-4
    become ``alpha1_*``, ``alpha2_*`` and ``alpha3_*``; every value is
    converted to a plain Python ``float``.
    """
    table = ppisp.vignetting_params[camera_idx].detach().cpu().numpy()
    values = {}
    for row_idx, suffix in enumerate("rgb"):
        row = table[row_idx]
        values[f"center_{suffix}"] = (float(row[0]), float(row[1]))
        values[f"alpha1_{suffix}"] = float(row[2])
        values[f"alpha2_{suffix}"] = float(row[3])
        values[f"alpha3_{suffix}"] = float(row[4])
    return VignetteParams(**values)
torch.from_numpy(rgb_np).float().to("cuda") + pe = torch.tensor([prior], dtype=torch.float32, device="cuda") + with torch.no_grad(): + e, c = controller(rgb, pe) + return np.concatenate([ + np.array([float(e)], dtype=np.float32), + c.detach().cpu().numpy().astype(np.float32), + ]) + + +def _gather_frames(renderer: Renderer, max_frames: int): + """For each batch yield (frame_idx, camera_idx, hdr_np).""" + out = [] + for i, batch in enumerate(renderer.dataloader): + if i >= max_frames: + break + gpu_batch = renderer.dataset.get_gpu_batch_with_intrinsics(batch) + with torch.no_grad(): + outputs = renderer.model(gpu_batch) + hdr = outputs["pred_rgb"][0].detach().cpu().numpy().astype(np.float32) + cam = (renderer.dataset.get_camera_idx(i) if hasattr(renderer.dataset, "get_camera_idx") else 0) + out.append((i, int(cam), hdr)) + return out + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--data-path", type=Path, default=None) + parser.add_argument("--max-frames", type=int, default=4) + parser.add_argument("--prior", type=float, default=0.0, + help="priorExposure value to use at inference. 
Match what " + "you pass at export time (default 0.0).") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA required.") + + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=str(args.data_path) if args.data_path else "", + out_dir="/tmp/_diag_unused", save_gt=False, computes_extra_metrics=False, + ) + pp = renderer.post_processing + if pp is None or type(pp).__name__ != "PPISP": + raise SystemExit("Checkpoint has no PPISP module.") + if not pp.controllers or len(pp.controllers) == 0: + raise SystemExit("Checkpoint PPISP has no controllers.") + + frames = _gather_frames(renderer, args.max_frames) + if not frames: + raise SystemExit("No frames produced by the dataloader.") + + # ------------------------------------------------------------------ + # H1: trained per-frame params vs in-process controller prediction + # ------------------------------------------------------------------ + print("\n=== H1: PyTorch controller(rgb) vs trained per-frame params ===") + print(f"{'frame':>5} {'cam':>3} {'exp_train':>10} {'exp_pred':>10} " + f"{'Δexp(stops)':>12} {'col_max|Δ|':>11}") + h1_max_exp_diff = 0.0 + h1_max_col_diff = 0.0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + pred = _torch_controller(ctrl, hdr, args.prior) + exp_pred = float(pred[0]) + col_pred = pred[1:] + if fidx >= int(pp.exposure_params.shape[0]): + print(f" frame {fidx}: out of range for exposure_params (size {pp.exposure_params.shape[0]})") + continue + exp_train = float(pp.exposure_params[fidx].detach().cpu()) + col_train = pp.color_params[fidx].detach().cpu().numpy().astype(np.float32) + d_exp = exp_pred - exp_train + d_col = float(np.max(np.abs(col_pred - col_train))) + h1_max_exp_diff = max(h1_max_exp_diff, abs(d_exp)) + 
h1_max_col_diff = max(h1_max_col_diff, d_col) + print(f"{fidx:>5} {cam:>3} {exp_train:>+10.4f} {exp_pred:>+10.4f} " + f"{d_exp:>+12.3f} {d_col:>11.4f}") + print(f" H1 worst: Δexposure = {h1_max_exp_diff:.3f} stops max|Δcolor| = {h1_max_col_diff:.4f}") + print(f" Interpretation: if Δexposure > ~0.3 stops or Δcolor > ~0.05, the controller has") + print(f" not converged to the optimized per-frame state. The static-export path uses the") + print(f" trained values directly, so it will look 'less exposed' than the controller path.") + + # ------------------------------------------------------------------ + # H2: PyTorch controller vs slang controller + # ------------------------------------------------------------------ + print("\n=== H2: PyTorch controller vs slang controller (same HDR) ===") + print(f"{'frame':>5} {'cam':>3} {'max|Δ|':>11}") + h2_max = 0.0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + torch_out = _torch_controller(ctrl, hdr, args.prior) + weights = flatten_controller_weights(ctrl) + slang_out = run_controller(CONTROLLER_SLANG, hdr, weights, prior_exposure=args.prior) + d = float(np.max(np.abs(torch_out - slang_out))) + h2_max = max(h2_max, d) + print(f"{fidx:>5} {cam:>3} {d:>11.3e}") + print(f" H2 worst: max|Δ| = {h2_max:.3e}") + print(f" Interpretation: should be ~3e-7. Anything > 1e-3 means the slang shader,") + print(f" weight flatten, or buffer upload disagrees with the trained controller.") + + # ------------------------------------------------------------------ + # H3: dynamic shader (reads texture) vs static shader (USD attrs), + # both fed the same 9 floats from PyTorch. 
+ # ------------------------------------------------------------------ + print("\n=== H3: slang dyn (reads ControllerParams texture) vs slang static (USD attrs) ===") + print(f"{'frame':>5} {'cam':>3} {'max|Δ|_u8':>11} {'mean|Δ|_u8':>12}") + h3_max_diff = 0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + ctrl_out = _torch_controller(ctrl, hdr, args.prior) # 9-float ground truth + vig = _vignette_for_camera(pp, cam) + crf = _crf_for_camera(pp, cam) + + # Dynamic path: controller-output 9-float buffer fed via texture. + ldr_dyn = run_ppisp_dyn(PPISP_DYN_SLANG, hdr, ctrl_out, vig, crf) + + # Static path: same 9 floats as USD attributes. Splits ctrl_out into + # exposure (1) + 4x float2 colour latents in declared order. + exposure = float(ctrl_out[0]) + color = list(ctrl_out[1:].astype(float)) + ldr_stat = run_ppisp_static(PPISP_STATIC_SLANG, hdr, exposure, color, vig, crf) + + diff = np.abs(ldr_dyn[..., :3].astype(int) - ldr_stat[..., :3].astype(int)) + max_d = int(diff.max()); mean_d = float(diff.mean()) + h3_max_diff = max(h3_max_diff, max_d) + print(f"{fidx:>5} {cam:>3} {max_d:>11d} {mean_d:>12.4f}") + print(f" H3 worst: max|Δ|_u8 = {h3_max_diff}") + print(f" Interpretation: should be 0 (or 1 from dispatch ordering). > a few means the") + print(f" dynamic shader's texture binding / texel layout disagrees with what the") + print(f" controller writes — i.e. 
the SPG plumbing has a bug.") + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + print("\n=== Summary ===") + h1_bad = h1_max_exp_diff > 0.3 or h1_max_col_diff > 0.05 + h2_bad = h2_max > 1e-3 + h3_bad = h3_max_diff > 3 + verdict = [] + if h1_bad: verdict.append("H1 fails -- controller did not learn the per-frame params (training).") + if h2_bad: verdict.append("H2 fails -- slang controller != PyTorch controller (shader / flatten).") + if h3_bad: verdict.append("H3 fails -- dyn shader != static shader on the same 9 floats (plumbing).") + if not verdict: + print(" All three checks pass within thresholds. If Omniverse still disagrees") + print(" with Python, suspect: (i) Kit applies camera exposure to HdrColor before") + print(" the SPG dispatch, or (ii) Kit's HdrColor scale != gaussian-renderer scale,") + print(" or (iii) priorExposure mismatch between training and the USD attribute.") + else: + for v in verdict: + print(f" - {v}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/render_renderproduct.py b/tools/render_ppisp_spg/render_renderproduct.py new file mode 100644 index 00000000..aabdc7ba --- /dev/null +++ b/tools/render_ppisp_spg/render_renderproduct.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Render a PPISP-bearing USD RenderProduct chain via slangpy. 
+ +Given a USD/USDZ that was exported by ``threedgrut.export.usd.exporter`` +with PPISP Omniverse-native mode (and optionally the controller), and a +folder of HDR input images, this tool walks each ``/Render/<camera>`` +RenderProduct, finds its ``PPISP[+ Controller]`` Shader prims, resolves +their parameter values for every authored time sample, and dispatches +the matching ``.slang`` files via :mod:`tools.render_ppisp_spg.spg_runtime`. + +For a controllerless export the per-frame exposure / colour latents are +read off the time-sampled USD attributes. With a controller, the +``priorExposure`` value is read once and the controller shader is +dispatched per frame against the supplied HDR input. + +Required layout for the input HDR images: + + <hdr_dir>/<camera>/<frame_index>.exr|.png|.npy + +Outputs are written to ``<out_dir>/<camera>/<frame_index>.png``. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from dataclasses import asdict +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple + +import numpy as np +from PIL import Image + +from pxr import Sdf, Usd, UsdShade + +# Allow running as a script without installing the tool package. 
+sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from tools.render_ppisp_spg.spg_runtime import ( # noqa: E402 + CrfParams, + VignetteParams, + run_controller, + run_ppisp_dyn, + run_ppisp_static, +) + +logger = logging.getLogger("render_ppisp_spg") + + +CHANNELS = ("R", "G", "B") + + +# --------------------------------------------------------------------------- +# USD parsing +# --------------------------------------------------------------------------- + + +def _resolve_attr_at_time(prim: Usd.Prim, attr_name: str, t: Usd.TimeCode): + attr = prim.GetAttribute(attr_name) + if attr is None or not attr.IsValid(): + return None + return attr.Get(t) + + +def _read_vignetting(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> VignetteParams: + p = VignetteParams() + + def _f(name: str, default: float) -> float: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + return float(v) if v is not None else default + + def _f2(name: str, default: Tuple[float, float]) -> Tuple[float, float]: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + if v is None: + return default + return (float(v[0]), float(v[1])) + + for ch in CHANNELS: + setattr(p, f"center_{ch.lower()}", _f2(f"vignettingCenter{ch}", (0.0, 0.0))) + setattr(p, f"alpha1_{ch.lower()}", _f(f"vignettingAlpha1{ch}", 0.0)) + setattr(p, f"alpha2_{ch.lower()}", _f(f"vignettingAlpha2{ch}", 0.0)) + setattr(p, f"alpha3_{ch.lower()}", _f(f"vignettingAlpha3{ch}", 0.0)) + return p + + +def _read_crf(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> CrfParams: + c = CrfParams() + + def _f(name: str, default: float) -> float: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + return float(v) if v is not None else default + + for ch in CHANNELS: + chl = ch.lower() + setattr(c, f"toe_{chl}", _f(f"crfToe{ch}", getattr(c, f"toe_{chl}"))) + setattr(c, f"shoulder_{chl}", _f(f"crfShoulder{ch}", getattr(c, f"shoulder_{chl}"))) + setattr(c, f"gamma_{chl}", _f(f"crfGamma{ch}", getattr(c, f"gamma_{chl}"))) + 
setattr(c, f"center_{chl}", _f(f"crfCenter{ch}", getattr(c, f"center_{chl}"))) + return c + + +def _read_color_latents(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> List[float]: + out: List[float] = [] + for name in ("colorLatentBlue", "colorLatentRed", "colorLatentGreen", "colorLatentNeutral"): + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + if v is None: + out.extend([0.0, 0.0]) + else: + out.extend([float(v[0]), float(v[1])]) + return out + + +def _read_exposure(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> float: + v = _resolve_attr_at_time(ppisp_prim, "inputs:exposureOffset", t) + return float(v) if v is not None else 0.0 + + +def _slang_asset_path(prim: Usd.Prim) -> Optional[str]: + attr = prim.GetAttribute("info:spg:sourceAsset") + if not attr or not attr.IsValid(): + return None + val = attr.Get() + if val is None: + return None + return val.path if hasattr(val, "path") else str(val) + + +def _find_render_products(stage: Usd.Stage) -> List[Usd.Prim]: + render_scope = stage.GetPrimAtPath("/Render") + if not render_scope.IsValid(): + return [] + return [c for c in render_scope.GetChildren() if c.GetTypeName() == "RenderProduct"] + + +def _find_ppisp_and_controller(rp: Usd.Prim) -> Tuple[Optional[Usd.Prim], Optional[Usd.Prim]]: + ppisp = None + controller = None + for child in rp.GetChildren(): + if child.GetName() == "PPISP": + ppisp = child + elif child.GetName().startswith("PPISPController"): + controller = child + return ppisp, controller + + +def _frame_indices_for_prim(prim: Usd.Prim) -> List[float]: + """Union of authored time samples over the animated PPISP attributes.""" + samples: set = set() + for attr_name in ( + "inputs:exposureOffset", + "inputs:colorLatentBlue", + "inputs:colorLatentRed", + "inputs:colorLatentGreen", + "inputs:colorLatentNeutral", + ): + attr = prim.GetAttribute(attr_name) + if attr and attr.IsValid(): + samples.update(attr.GetTimeSamples() or []) + return sorted(samples) + + +# 
--------------------------------------------------------------------------- +# HDR image I/O +# --------------------------------------------------------------------------- + + +def _load_hdr(path: Path) -> np.ndarray: + if path.suffix.lower() == ".npy": + arr = np.load(path) + return arr.astype(np.float32) + if path.suffix.lower() in (".png", ".jpg", ".jpeg"): + img = Image.open(path).convert("RGB") + return (np.asarray(img).astype(np.float32) / 255.0) + if path.suffix.lower() == ".exr": + try: + import OpenEXR # type: ignore[import-not-found] + import Imath # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError(f"OpenEXR/Imath required to read {path}: {e}") + f = OpenEXR.InputFile(str(path)) + dw = f.header()["dataWindow"] + w = dw.max.x - dw.min.x + 1 + h = dw.max.y - dw.min.y + 1 + pt = Imath.PixelType(Imath.PixelType.FLOAT) + r, g, b = (np.frombuffer(f.channel(c, pt), dtype=np.float32).reshape(h, w) + for c in ("R", "G", "B")) + return np.stack([r, g, b], axis=-1) + raise RuntimeError(f"unsupported HDR format: {path.suffix}") + + +def _save_png(out_path: Path, image_rgba: np.ndarray) -> None: + out_path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(image_rgba, mode="RGBA").save(out_path) + + +# --------------------------------------------------------------------------- +# Per-camera execution +# --------------------------------------------------------------------------- + + +def _process_render_product( + rp: Usd.Prim, + usd_dir: Path, + hdr_dir: Path, + out_dir: Path, + *, + frames: Optional[Iterable[int]] = None, +) -> None: + cam_name = rp.GetName() + ppisp_prim, controller_prim = _find_ppisp_and_controller(rp) + if ppisp_prim is None: + logger.warning("RenderProduct %s has no PPISP shader prim, skipping", cam_name) + return + + ppisp_slang = _slang_asset_path(ppisp_prim) + ctrl_slang = _slang_asset_path(controller_prim) if controller_prim is not None else None + if ppisp_slang is None: + logger.warning("RenderProduct 
%s PPISP shader has no info:spg:sourceAsset", cam_name) + return + + ppisp_slang_path = (usd_dir / ppisp_slang).resolve() + if not ppisp_slang_path.exists(): + logger.error("PPISP slang sidecar not found at %s", ppisp_slang_path) + return + ctrl_slang_path = None + if ctrl_slang is not None: + ctrl_slang_path = (usd_dir / ctrl_slang).resolve() + if not ctrl_slang_path.exists(): + logger.error("Controller slang sidecar not found at %s", ctrl_slang_path) + return + + hdr_cam_dir = hdr_dir / cam_name + if not hdr_cam_dir.exists(): + logger.warning("No HDR inputs for camera %s under %s, skipping", cam_name, hdr_dir) + return + + sample_times = _frame_indices_for_prim(ppisp_prim) + if not sample_times and controller_prim is not None: + # Controller-only path: time samples are encoded in the HDR folder names. + sample_times = sorted( + int(p.stem) for p in hdr_cam_dir.iterdir() if p.stem.isdigit() + ) + if frames is not None: + sample_times = [t for t in sample_times if int(t) in set(int(f) for f in frames)] + if not sample_times: + logger.warning("Camera %s has no frames to render", cam_name) + return + + logger.info("Rendering %s (%d frames%s)", + cam_name, len(sample_times), + " + controller" if ctrl_slang_path else "") + + for t in sample_times: + frame_index = int(t) + candidates = [ + hdr_cam_dir / f"{frame_index}.npy", + hdr_cam_dir / f"{frame_index}.exr", + hdr_cam_dir / f"{frame_index}.png", + ] + hdr_path = next((c for c in candidates if c.exists()), None) + if hdr_path is None: + logger.warning("No HDR input for %s frame %d", cam_name, frame_index) + continue + + hdr_image = _load_hdr(hdr_path) + timecode = Usd.TimeCode(float(t)) + vignette = _read_vignetting(ppisp_prim, timecode) + crf = _read_crf(ppisp_prim, timecode) + + if ctrl_slang_path is not None: + prior = _resolve_attr_at_time(controller_prim, "inputs:priorExposure", timecode) or 0.0 + weights_attr = controller_prim.GetAttribute("inputs:weights") + weights_val = weights_attr.Get(timecode) if 
weights_attr and weights_attr.IsValid() else None + if weights_val is None: + logger.error("Controller for %s has no inputs:weights value, skipping frame", cam_name) + continue + # USD's VtArray-backed ndarray comes back read-only / OWNDATA=False; + # slangpy.create_buffer rejects those, so force a writable copy. + weights = np.array(weights_val, dtype=np.float32, copy=True) + controller_out = run_controller(ctrl_slang_path, hdr_image, weights, + prior_exposure=float(prior)) + ldr = run_ppisp_dyn(ppisp_slang_path, hdr_image, controller_out, + vignette=vignette, crf=crf) + else: + exposure = _read_exposure(ppisp_prim, timecode) + color_latents = _read_color_latents(ppisp_prim, timecode) + ldr = run_ppisp_static(ppisp_slang_path, hdr_image, + exposure_offset=exposure, + color_latents=color_latents, + vignette=vignette, crf=crf) + + _save_png(out_dir / cam_name / f"{frame_index}.png", ldr) + logger.debug(" wrote %s/%s/%d.png", out_dir, cam_name, frame_index) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _resolve_usd_dir(usd_path: Path) -> Path: + """Slang/usda asset paths are relative to the USD file. 
For ``.usdz`` we + extract a temporary copy because the SPG sidecars are stored inside the + archive and slangpy needs them on disk.""" + if usd_path.suffix.lower() != ".usdz": + return usd_path.parent + + import tempfile + import zipfile + + target = Path(tempfile.mkdtemp(prefix="ppisp_usdz_")) + with zipfile.ZipFile(usd_path) as zf: + zf.extractall(target) + logger.info("Extracted %s → %s", usd_path, target) + return target + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("usd", type=Path, help="USD or USDZ file from the PPISP exporter") + parser.add_argument("hdr_dir", type=Path, + help="Directory of HDR inputs, organised as /.{npy,exr,png}") + parser.add_argument("out_dir", type=Path, + help="Where to write LDR PNG outputs") + parser.add_argument("--cameras", nargs="*", default=None, + help="Optional list of camera (RenderProduct) names to render") + parser.add_argument("--frames", nargs="*", type=int, default=None, + help="Optional list of frame indices to render") + parser.add_argument("--verbose", "-v", action="count", default=0, + help="Increase logging verbosity") + args = parser.parse_args(argv) + + logging.basicConfig(level=logging.WARNING - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not args.usd.exists(): + logger.error("USD not found: %s", args.usd) + return 2 + + usd_dir = _resolve_usd_dir(args.usd) + if args.usd.suffix.lower() == ".usdz": + # Find the actual default scene file inside the extracted dir. 
+ default_scene = next((p for p in usd_dir.glob("*.usd*") if p.suffix in (".usd", ".usda", ".usdc")), + None) + if default_scene is None: + logger.error("No top-level usd/usda/usdc inside %s", args.usd) + return 2 + usd_path = default_scene + else: + usd_path = args.usd + + stage = Usd.Stage.Open(str(usd_path)) + if stage is None: + logger.error("Failed to open USD stage at %s", usd_path) + return 2 + + products = _find_render_products(stage) + if not products: + logger.error("No RenderProducts found under /Render") + return 1 + + target_names = set(args.cameras) if args.cameras else None + for rp in products: + if target_names is not None and rp.GetName() not in target_names: + continue + _process_render_product( + rp, + usd_dir=usd_dir, + hdr_dir=args.hdr_dir, + out_dir=args.out_dir, + frames=args.frames, + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/spg_runtime.py b/tools/render_ppisp_spg/spg_runtime.py new file mode 100644 index 00000000..d0aec5fb --- /dev/null +++ b/tools/render_ppisp_spg/spg_runtime.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +""" +Headless slangpy runtime for the PPISP SPG shader chain. + +This is *not* a full Kit SPG simulator. It executes the compute stages of +the PPISP SPG sidecars directly against a supplied HDR image so the +exported asset can be validated end-to-end without booting Omniverse. + +The harness uses slangpy's low-level pipeline API: load a Slang module, +create a compute pipeline from a chosen entry point, and dispatch with +resources bound through a ``ShaderCursor`` over the root +``ShaderObject``. 
@dataclasses.dataclass
class VignetteParams:
    """Per-camera vignetting parameters in shader storage order.

    Fields come in three identical groups, one per colour channel (``_r``,
    ``_g``, ``_b``).  Each group holds a two-component ``center_*`` and the
    three scalar coefficients ``alpha1_*`` .. ``alpha3_*``.  Every field
    defaults to zero.
    """
    # Red channel.
    center_r: Tuple[float, float] = (0.0, 0.0)
    alpha1_r: float = 0.0
    alpha2_r: float = 0.0
    alpha3_r: float = 0.0
    # Green channel.
    center_g: Tuple[float, float] = (0.0, 0.0)
    alpha1_g: float = 0.0
    alpha2_g: float = 0.0
    alpha3_g: float = 0.0
    # Blue channel.
    center_b: Tuple[float, float] = (0.0, 0.0)
    alpha1_b: float = 0.0
    alpha2_b: float = 0.0
    alpha3_b: float = 0.0
spy.Device, slang_path: Path, entry_point_name: str): + """Compile a Slang file and return its compute pipeline. + + The PPISP SPG shaders carry ``[[vk::binding(slot, set)]]`` annotations + that match Kit's SPG descriptor layout. Slangpy uses its own automatic + binding scheme, and the explicit annotations make resource binding + silently miss (the dispatch runs but reads zeroed buffers). We strip + the annotations *for slangpy dispatch only*; the on-disk slang file + used by SPG keeps them. + """ + session = device.slang_session + src = _VK_BINDING_RE.sub("", slang_path.read_text()) + module = session.load_module_from_source( + slang_path.stem, + src, + path=str(slang_path), + ) + entry_point = module.entry_point(entry_point_name) + program = session.link_program([module], [entry_point]) + pipeline = device.create_compute_pipeline(program) + return pipeline, program + + +def _create_hdr_input_texture(device: spy.Device, hdr_image: np.ndarray) -> spy.Texture: + if hdr_image.ndim != 3 or hdr_image.shape[2] not in (3, 4): + raise ValueError(f"hdr_image must be HxWx3 or HxWx4, got shape {hdr_image.shape}") + if hdr_image.dtype != np.float32: + hdr_image = hdr_image.astype(np.float32, copy=False) + h, w, c = hdr_image.shape + if c == 3: + rgba = np.empty((h, w, 4), dtype=np.float32) + rgba[..., :3] = hdr_image + rgba[..., 3] = 1.0 + hdr_image = rgba + return device.create_texture( + width=w, + height=h, + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + data=np.ascontiguousarray(hdr_image), + ) + + +def _create_controller_texture(device: spy.Device, values: np.ndarray) -> spy.Texture: + flat = np.asarray(values, dtype=np.float32).reshape(-1) + if flat.size != 9: + raise ValueError(f"controller values must be 9 floats, got {flat.size}") + # 9x1 single-channel float texture, indexed at (0..8, 0). 
def _ceildiv(a: int, b: int) -> int:
    """Return ``(a + b - 1) // b``.

    For an integer ``a`` and a positive integer ``b`` this is ``a / b``
    rounded up to the next integer.
    """
    quotient, _remainder = divmod(a + b - 1, b)
    return quotient
+ """ + slang_path = Path(slang_path) + if device is None: + device = _make_device(slang_path.parent) + + pipeline, _ = _build_pipeline(device, slang_path, "controllerProcess") + in_tex = _create_hdr_input_texture(device, hdr_image) + out_tex = device.create_texture( + width=9, + height=1, + format=spy.Format.r32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + ) + flat_weights = np.ascontiguousarray(weights.astype(np.float32, copy=False).reshape(-1)) + weights_buf = device.create_buffer( + element_count=int(flat_weights.size), + struct_size=4, + usage=spy.BufferUsage.shader_resource, + data=flat_weights, + ) + + encoder = device.create_command_encoder() + with encoder.begin_compute_pass() as cp: + shader_obj = cp.bind_pipeline(pipeline) + cur = spy.ShaderCursor(shader_obj) + # weights live inside the g_Params ParameterBlock now, so SPG's + # reflection finds them under "params:weights" -- silences the + # "Failed to find parameter 'params:weights'" warning in Kit. 
def _vignette_dict(v: "VignetteParams") -> dict:
    """Map vignetting parameters to their shader field names.

    Returns a dict keyed by the ``vignetting*`` shader parameter names, in
    red/green/blue order. The ``center_*`` values are copied into fresh
    lists; the ``alpha*`` values are passed through unchanged.
    """
    # (shader field name, attribute on ``v``), in insertion order.
    layout = (
        ("vignettingCenterR", "center_r"),
        ("vignettingAlpha1R", "alpha1_r"),
        ("vignettingAlpha2R", "alpha2_r"),
        ("vignettingAlpha3R", "alpha3_r"),
        ("vignettingCenterG", "center_g"),
        ("vignettingAlpha1G", "alpha1_g"),
        ("vignettingAlpha2G", "alpha2_g"),
        ("vignettingAlpha3G", "alpha3_g"),
        ("vignettingCenterB", "center_b"),
        ("vignettingAlpha1B", "alpha1_b"),
        ("vignettingAlpha2B", "alpha2_b"),
        ("vignettingAlpha3B", "alpha3_b"),
    )
    shader_fields = {}
    for shader_name, attr in layout:
        value = getattr(v, attr)
        # Centers are 2-element sequences and are handed over as lists.
        shader_fields[shader_name] = list(value) if attr.startswith("center_") else value
    return shader_fields
def run_ppisp_static(
    slang_path: str | Path,
    hdr_image: np.ndarray,
    exposure_offset: float,
    color_latents: Sequence[float],
    vignette: VignetteParams,
    crf: CrfParams,
    *,
    device: spy.Device | None = None,
) -> np.ndarray:
    """Run ``ppisp_usd_spg.slang`` (no controller) and return an LDR uint8 image.

    Args:
        slang_path: path to the shader file; its parent directory is used
            as the include path when a device has to be created.
        hdr_image: HDR input image; its first two dimensions are taken as
            height and width.
        exposure_offset: value bound to ``exposureOffset``.
        color_latents: eight floats, consumed pairwise as blue, red, green
            and neutral colour latents.
        vignette: vignetting parameters bound into ``g_Params``.
        crf: CRF parameters bound into ``g_Params``.
        device: optional existing slangpy device; one is created when omitted.

    Returns:
        The shader output read back as an HxWx4 uint8 array.

    Raises:
        ValueError: if ``color_latents`` does not have exactly 8 entries.
    """
    shader_file = Path(slang_path)
    if len(color_latents) != 8:
        raise ValueError(f"color_latents must have 8 entries, got {len(color_latents)}")
    gpu = _make_device(shader_file.parent) if device is None else device

    pipeline, _program = _build_pipeline(gpu, shader_file, "ppispProcess")

    height, width = hdr_image.shape[:2]
    source_tex = _create_hdr_input_texture(gpu, hdr_image)
    target_tex = gpu.create_texture(
        width=width,
        height=height,
        format=spy.Format.rgba8_unorm,
        usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access,
    )

    # Assemble the g_Params fields: exposure, four colour latent pairs,
    # then the vignetting and CRF parameters.
    params = {"exposureOffset": float(exposure_offset)}
    latent_fields = (
        "colorLatentBlue",
        "colorLatentRed",
        "colorLatentGreen",
        "colorLatentNeutral",
    )
    for slot, field_name in enumerate(latent_fields):
        first = float(color_latents[2 * slot])
        second = float(color_latents[2 * slot + 1])
        params[field_name] = [first, second]
    params.update(_vignette_dict(vignette))
    params.update(_crf_dict(crf))

    encoder = gpu.create_command_encoder()
    with encoder.begin_compute_pass() as compute_pass:
        cursor = spy.ShaderCursor(compute_pass.bind_pipeline(pipeline))
        param_block = cursor["g_Params"]
        for field_name, value in params.items():
            param_block[field_name] = value
        cursor["g_InTex"] = source_tex
        cursor["g_OutTex"] = target_tex
        # Thread counts are rounded up to a multiple of 16 per axis
        # (presumably matching a 16x16 thread group in the shader -- confirm).
        compute_pass.dispatch(spy.math.uint3(_ceildiv(width, 16) * 16,
                                             _ceildiv(height, 16) * 16, 1))
    gpu.submit_command_buffer(encoder.finish())
    gpu.wait()

    return _read_rgba8(target_tex, height, width)
+for _pkg in ( + "threedgrut", + "threedgrut.export", + "threedgrut.export.usd", + "threedgrut.export.usd.writers", +): + if _pkg not in sys.modules: + sys.modules[_pkg] = _types.ModuleType(_pkg) + +# Stub stage_utils so the writer can import NamedSerialized. +_stage_utils_stub = _types.ModuleType("threedgrut.export.usd.stage_utils") +import dataclasses as _dc + + +@_dc.dataclass +class _NamedSerialized: + filename: str + serialized: bytes + + +_stage_utils_stub.NamedSerialized = _NamedSerialized +sys.modules["threedgrut.export.usd.stage_utils"] = _stage_utils_stub + +# Stub ppisp_spg package so get_controller_sidecars() can resolve _SPG_DIR. +_ppisp_spg_stub = _types.ModuleType("threedgrut.export.usd.ppisp_spg") +_ppisp_spg_stub._SPG_DIR = ( + Path(__file__).resolve().parents[2] / "threedgrut/export/usd/ppisp_spg" +) +sys.modules["threedgrut.export.usd.ppisp_spg"] = _ppisp_spg_stub + +_writer_path = ( + Path(__file__).resolve().parents[2] + / "threedgrut/export/usd/writers/ppisp_controller_writer.py" +) +_spec = _ilu.spec_from_file_location( + "threedgrut.export.usd.writers.ppisp_controller_writer", str(_writer_path) +) +_writer_mod = _ilu.module_from_spec(_spec) +sys.modules["threedgrut.export.usd.writers.ppisp_controller_writer"] = _writer_mod +_spec.loader.exec_module(_writer_mod) +EXPECTED_SIZES = _writer_mod.EXPECTED_SIZES +flatten_controller_weights = _writer_mod.flatten_controller_weights + +from tools.render_ppisp_spg.spg_runtime import run_controller # noqa: E402 + + +logger = logging.getLogger("validate_controller") + + +def _make_test_controller(seed: int = 0): + """Build a torch module with the same architecture as + ``ppisp._PPISPController``. 
Importing the real one is preferred but + we duplicate it here so the validator runs without the ppisp package.""" + import torch + from torch import nn + + class _Controller(nn.Module): + def __init__(self): + super().__init__() + cfd = EXPECTED_SIZES["cnn_feature_dim"] + grid = (EXPECTED_SIZES["pool_grid_h"], EXPECTED_SIZES["pool_grid_w"]) + self.cnn_encoder = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=1), + nn.MaxPool2d(EXPECTED_SIZES["input_downsampling"], + stride=EXPECTED_SIZES["input_downsampling"]), + nn.ReLU(inplace=True), + nn.Conv2d(16, 32, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(32, cfd, kernel_size=1), + nn.AdaptiveAvgPool2d(grid), + nn.Flatten(), + ) + in_dim = cfd * grid[0] * grid[1] + 1 + hd = EXPECTED_SIZES["mlp_hidden_dim"] + self.mlp_trunk = nn.Sequential( + nn.Linear(in_dim, hd), nn.ReLU(inplace=True), + nn.Linear(hd, hd), nn.ReLU(inplace=True), + nn.Linear(hd, hd), nn.ReLU(inplace=True), + ) + self.exposure_head = nn.Linear(hd, 1) + self.color_head = nn.Linear(hd, EXPECTED_SIZES["color_params_per_frame"]) + + def forward(self, rgb: torch.Tensor, prior_exposure: torch.Tensor): + features = self.cnn_encoder(rgb.permute(2, 0, 1).unsqueeze(0).detach()) + features = torch.cat([features.squeeze(0), prior_exposure], dim=0) + hidden = self.mlp_trunk(features) + return self.exposure_head(hidden).squeeze(-1), self.color_head(hidden) + + torch.manual_seed(seed) + ctrl = _Controller().eval() + # Mostly-zero weights with a tiny perturbation so outputs are non-trivial. 
def _torch_reference(ctrl, hdr_image: np.ndarray, prior_exposure: float) -> np.ndarray:
    """Evaluate the torch controller and pack its outputs into one vector.

    ``ctrl`` is called with the image as a float32 tensor and the prior
    exposure as a one-element float32 tensor, under ``torch.no_grad()``.
    It must return ``(exposure, color)``.

    Returns:
        A flat float32 array: the exposure value followed by the colour
        outputs.
    """
    import torch

    image_tensor = torch.from_numpy(hdr_image).float()
    prior_tensor = torch.tensor([prior_exposure], dtype=torch.float32)
    with torch.no_grad():
        outputs = ctrl(image_tensor, prior_tensor)
    exposure, color = outputs

    exposure_part = np.array([float(exposure)], dtype=np.float32)
    color_part = color.cpu().numpy().astype(np.float32)
    return np.concatenate([exposure_part, color_part])
SystemExit(main()) diff --git a/tools/render_ppisp_spg/validate_e2e.py b/tools/render_ppisp_spg/validate_e2e.py new file mode 100644 index 00000000..4f6e4b33 --- /dev/null +++ b/tools/render_ppisp_spg/validate_e2e.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end PPISP SPG export + render validation. + +The pipeline this script exercises: + +1. Build a ``ppisp.PPISP`` module with non-trivial random weights and a + handful of synthetic HDR frames. +2. Author a USD stage with one ``RenderProduct`` per camera, attach the + PPISP shader chain (controller + dynamic PPISP) using the in-repo + writer, and save the USD plus the SPG sidecars to disk. +3. Run the slangpy CLI (`render_renderproduct.py`) against the saved USD + and the synthetic HDR frames to produce LDR PNGs through the slang + shaders. +4. Apply the same PPISP module *in PyTorch* to the same HDR frames, save + them as the reference LDR PNGs. +5. Compare slangpy vs PyTorch images per-frame; report PSNR / max abs + diff. Pass / fail on a configurable PSNR threshold. + +The "training" step is replaced with a perturbed PPISP module because +the validation question is "does the SPG asset reproduce the in-process +PPISP forward pass for these (camera, frame) pairs", not "is the trained +model good". A real trained checkpoint would give the same answer +because the path through both runtimes is identical. 
+""" + +from __future__ import annotations + +import argparse +import logging +import math +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +from PIL import Image + +import torch +from pxr import Gf, Sdf, Usd, UsdGeom + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.writers.ppisp_writer import ( # noqa: E402 + add_ppisp_to_all_render_products, +) +from threedgrut.export.usd.writers.ppisp_controller_writer import ( # noqa: E402 + get_controller_sidecars, +) +from threedgrut.export.usd.ppisp_spg import ( # noqa: E402 + get_ppisp_spg_dyn_files, +) +from ppisp import PPISP, DEFAULT_PPISP_CONFIG # noqa: E402 + +logger = logging.getLogger("validate_e2e") + + +def _make_perturbed_ppisp(num_cameras: int, num_frames: int, seed: int) -> PPISP: + """Build a PPISP module with non-trivial parameters for every stage.""" + torch.manual_seed(seed) + cfg = DEFAULT_PPISP_CONFIG + ppisp = PPISP(num_cameras=num_cameras, num_frames=num_frames, config=cfg).eval() + with torch.no_grad(): + ppisp.exposure_params.normal_(mean=0.0, std=0.5) + ppisp.color_params.normal_(mean=0.0, std=0.05) + ppisp.vignetting_params.normal_(mean=0.0, std=0.02) + # Keep CRF near identity so the comparison isn't dominated by huge + # nonlinearities; the math is identical between paths regardless. + ppisp.crf_params.add_(torch.randn_like(ppisp.crf_params) * 0.05) + # Perturb every controller's weights so the per-frame override has + # work to do during the dynamic-PPISP path. 
class _SyntheticDataset:
    """Tiny in-memory dataset stand-in.

    Exposes only ``__len__``, ``get_camera_names`` and ``get_camera_idx``,
    i.e. what ``build_camera_frame_mapping`` reads. Both input sequences
    are copied, so later changes to the caller's lists do not affect it.
    """

    def __init__(self, frame_to_camera: List[int], camera_names: List[str]):
        # Snapshot the inputs; index i of _f2c is the camera of frame i.
        self._f2c = [*frame_to_camera]
        self._names = [*camera_names]

    def __len__(self) -> int:
        """Return the number of frames."""
        frame_count = len(self._f2c)
        return frame_count

    def get_camera_names(self) -> List[str]:
        """Return a fresh list of the camera names."""
        return [*self._names]

    def get_camera_idx(self, frame_idx: int) -> int:
        """Return the camera index of ``frame_idx`` as a plain ``int``."""
        camera = self._f2c[frame_idx]
        return int(camera)
def _psnr(a: np.ndarray, b: np.ndarray) -> float:
    """Return the PSNR in dB between two images, using a peak value of 255.

    Both inputs are converted to float32 before subtraction. Returns
    ``inf`` when the mean squared error is not positive (e.g. identical
    images).
    """
    error = a.astype(np.float32) - b.astype(np.float32)
    mse = float(np.mean(error * error))
    if mse > 0:
        return 20.0 * math.log10(255.0 / math.sqrt(mse))
    return float("inf")
def main(argv=None) -> int:
    """Run the synthetic end-to-end PPISP SPG validation.

    Builds a perturbed PPISP module and synthetic HDR frames, writes the
    PyTorch reference PNGs, authors the USD stage plus sidecars, runs
    ``render_renderproduct.py`` in a subprocess and compares its PNGs
    against the references frame by frame.

    Args:
        argv: command-line arguments passed to ``parser.parse_args``.

    Returns:
        0 when every frame has a slangpy output and meets
        ``--psnr-threshold``; 1 otherwise.

    Raises:
        SystemExit: if CUDA is unavailable or the CLI subprocess exits
            with a non-zero code.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--num-cameras", type=int, default=2)
    parser.add_argument("--frames-per-camera", type=int, default=2)
    parser.add_argument("--width", type=int, default=128)
    parser.add_argument("--height", type=int, default=96)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--psnr-threshold", type=float, default=40.0,
                        help="Per-frame minimum PSNR (dB) for pass.")
    parser.add_argument("--keep", type=Path, default=None,
                        help="Keep working dir at this path instead of a tmpdir.")
    parser.add_argument("--verbose", "-v", action="count", default=0)
    args = parser.parse_args(argv)
    # Each -v lowers the log level by one step (INFO -> DEBUG -> ...).
    logging.basicConfig(level=logging.INFO - 10 * args.verbose,
                        format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    if not torch.cuda.is_available():
        raise SystemExit("PPISP forward requires CUDA.")

    # Only a directory we created ourselves (no --keep) is removed at the end.
    work_dir = args.keep
    cleanup = work_dir is None
    if work_dir is None:
        work_dir = Path(tempfile.mkdtemp(prefix="ppisp_e2e_"))
    work_dir.mkdir(parents=True, exist_ok=True)
    usd_dir = work_dir / "usd"
    hdr_dir = work_dir / "hdr"
    ref_dir = work_dir / "reference"
    slang_dir = work_dir / "slangpy"

    try:
        # ----------------------------------------------------------------
        # 1. Build a non-trivial PPISP and a synthetic frame plan.
        # ----------------------------------------------------------------
        cam_names = [f"cam_{i}" for i in range(args.num_cameras)]
        # Frames are laid out camera by camera: frames_per_camera
        # consecutive frame indices for each camera.
        frame_to_camera: List[int] = []
        for cam_idx in range(args.num_cameras):
            frame_to_camera.extend([cam_idx] * args.frames_per_camera)
        num_frames = len(frame_to_camera)
        resolutions = {n: (args.width, args.height) for n in cam_names}
        ppisp = _make_perturbed_ppisp(args.num_cameras, num_frames, seed=args.seed)

        # ----------------------------------------------------------------
        # 2. Synthesise HDR inputs and the PyTorch reference LDR images.
        # ----------------------------------------------------------------
        rng = np.random.default_rng(args.seed)
        for frame_idx, cam_idx in enumerate(frame_to_camera):
            cam_name = cam_names[cam_idx]
            # Smooth HDR with a few high-frequency components so the
            # controller and the vignetting see real spatial variation.
            yy, xx = np.mgrid[0:args.height, 0:args.width].astype(np.float32)
            base = 0.4 + 0.4 * rng.random((3,), dtype=np.float32)
            hdr = (
                base[None, None, :]
                + 0.15 * np.cos((xx / args.width * 4 + frame_idx) * 2 * np.pi)[..., None]
                + 0.15 * np.sin((yy / args.height * 4 + cam_idx) * 2 * np.pi)[..., None]
            ).astype(np.float32)
            hdr += rng.normal(scale=0.02, size=hdr.shape).astype(np.float32)
            hdr = np.clip(hdr, 0.0, 1.5)
            # Inputs are stored as <hdr_dir>/<camera>/<frame>.npy; the
            # reference PNGs below use the same layout under ref_dir.
            _save_npy_hdr(hdr_dir / cam_name / f"{frame_idx}.npy", hdr)

            ref = _torch_reference_ldr(ppisp, hdr, cam_idx, frame_idx)
            _save_png(ref_dir / cam_name / f"{frame_idx}.png", ref)

        # ----------------------------------------------------------------
        # 3. Author the USD stage + sidecars on disk.
        # ----------------------------------------------------------------
        usd_path = _author_stage(
            usd_dir, ppisp, cam_names, frame_to_camera, resolutions
        )

        # ----------------------------------------------------------------
        # 4. Run the slangpy CLI against the saved USD.
        # ----------------------------------------------------------------
        cli = Path(__file__).resolve().parent / "render_renderproduct.py"
        cmd = [
            sys.executable, str(cli),
            str(usd_path), str(hdr_dir), str(slang_dir),
            "-vv",
        ]
        logger.info("Running slangpy CLI: %s", " ".join(cmd))
        # Output is captured and only echoed when the CLI fails.
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            print(proc.stdout)
            print(proc.stderr, file=sys.stderr)
            raise SystemExit(f"render_renderproduct.py failed (exit {proc.returncode})")

        # ----------------------------------------------------------------
        # 5. Compare per-frame reference vs slangpy outputs.
        # ----------------------------------------------------------------
        worst_psnr = float("inf")
        worst_pair = None
        all_pass = True
        for frame_idx, cam_idx in enumerate(frame_to_camera):
            cam_name = cam_names[cam_idx]
            ref_img = np.asarray(Image.open(ref_dir / cam_name / f"{frame_idx}.png").convert("RGBA"))
            sl_path = slang_dir / cam_name / f"{frame_idx}.png"
            # A missing output counts as a failure but does not stop the
            # comparison of the remaining frames.
            if not sl_path.exists():
                logger.error("slangpy output missing: %s", sl_path)
                all_pass = False
                continue
            sl_img = np.asarray(Image.open(sl_path).convert("RGBA"))
            # Only the RGB channels are compared; alpha is ignored.
            psnr = _psnr(ref_img[..., :3], sl_img[..., :3])
            max_abs = int(np.max(np.abs(ref_img[..., :3].astype(int) - sl_img[..., :3].astype(int))))
            ok = psnr >= args.psnr_threshold
            print(f" cam={cam_name} frame={frame_idx} "
                  f"PSNR={psnr:7.3f} dB max|Δ|={max_abs} {'OK' if ok else 'FAIL'}")
            if psnr < worst_psnr:
                worst_psnr = psnr
                worst_pair = (cam_name, frame_idx)
            if not ok:
                all_pass = False

        print()
        print(f"worst frame: {worst_pair} at {worst_psnr:.3f} dB "
              f"(threshold {args.psnr_threshold} dB)")
        return 0 if all_pass else 1

    finally:
        if cleanup:
            shutil.rmtree(work_dir, ignore_errors=True)
def main(argv=None) -> int:
    """Compare the real PPISP controller in PyTorch against the slangpy dispatch.

    Builds a ``ppisp.PPISP`` module, perturbs the weights of its first
    controller, evaluates that controller on a random HDR image in
    PyTorch, runs the same weights through ``run_controller`` and prints
    the per-output absolute difference.

    Args:
        argv: command-line arguments passed to ``parser.parse_args``.

    Returns:
        0 when the largest absolute difference is within ``--tol``, else 1.

    Raises:
        SystemExit: if the PPISP module has no controllers.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # (flag, type, default), registered in this order.
    option_table = (
        ("--width", int, 64),
        ("--height", int, 48),
        ("--prior", float, 0.0),
        ("--seed", int, 0),
        ("--tol", float, 1.0e-4),
        ("--num-cameras", int, 1),
        ("--num-frames", int, 1),
    )
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args(argv)
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s: %(message)s")

    import torch
    from ppisp import PPISP, DEFAULT_PPISP_CONFIG

    torch.manual_seed(args.seed)
    rng = np.random.default_rng(args.seed)

    ppisp = PPISP(
        num_cameras=args.num_cameras,
        num_frames=args.num_frames,
        config=DEFAULT_PPISP_CONFIG,
    ).eval()
    if not ppisp.controllers or len(ppisp.controllers) == 0:
        raise SystemExit("PPISP has no controllers — config.use_controller must be True.")
    controller = ppisp.controllers[0]

    # Randomise the controller weights in place so both runtimes produce
    # non-trivial outputs; all-zero outputs would make the check vacuous.
    with torch.no_grad():
        for weight in controller.parameters():
            weight.normal_(0.0, 0.01)

    hdr = (rng.random((args.height, args.width, 3), dtype=np.float32) * 0.6 + 0.2)

    # Evaluate on whatever device the controller's weights live on.
    controller_device = controller.exposure_head.weight.device
    rgb_t = torch.from_numpy(hdr).float().to(controller_device)
    pe_t = torch.tensor([args.prior], dtype=torch.float32, device=rgb_t.device)
    with torch.no_grad():
        exposure, color = controller(rgb_t, pe_t)
    exposure_part = np.array([float(exposure)], dtype=np.float32)
    color_part = color.detach().cpu().numpy().astype(np.float32)
    expected = np.concatenate([exposure_part, color_part])

    weights = flatten_controller_weights(controller)
    repo_root = Path(__file__).resolve().parents[2]
    slang_path = repo_root / "threedgrut/export/usd/ppisp_spg/ppisp_controller.slang"
    actual = run_controller(slang_path, hdr, weights, prior_exposure=args.prior)

    diff = np.abs(actual - expected)
    worst = diff.max()
    print(f"reference: {expected}")
    print(f"slangpy: {actual}")
    print(f"abs diff: {diff}")
    print(f"max abs diff: {worst:.6g} (tol={args.tol})")

    if worst <= args.tol:
        return 0
    return 1
def _to_rgba8(rgb: np.ndarray) -> np.ndarray:
    """Quantise a float HxWxC colour image to an HxWx4 uint8 RGBA image.

    Values are clipped to [0, 1], scaled by 255 and offset by 0.5 before
    the cast to uint8. The alpha channel is set to 255.
    """
    quantised = (np.clip(rgb, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    height, width, _channels = quantised.shape
    # Start fully opaque, then overwrite the colour channels.
    out = np.full((height, width, 4), 255, dtype=np.uint8)
    out[..., :3] = quantised
    return out
Outputs are kept here.") + parser.add_argument("--psnr-threshold", type=float, default=35.0) + parser.add_argument("--max-frames", type=int, default=None, + help="Limit number of val frames processed.") + parser.add_argument("--use-train", action="store_true", + help="Use train frames instead of val (model has overfit, " + "so the gaussian renderer produces non-trivial HDR even after short runs).") + parser.add_argument("--save-hdr-png", action="store_true", + help="Save the HDR render as a normalised PNG for inspection.") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required.") + + work = args.out_dir or Path(tempfile.mkdtemp(prefix="ppisp_trained_")) + work.mkdir(parents=True, exist_ok=True) + hdr_dir = work / "hdr" + ref_dir = work / "reference" + usd_dir = work / "usd" + slang_dir = work / "slangpy" + + # ------------------------------------------------------------------ + # 1. Load checkpoint via Renderer.from_checkpoint (uses val dataset). + # ------------------------------------------------------------------ + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=str(args.data_path) if args.data_path else "", + out_dir=str(work / "_renderer_unused"), + save_gt=False, + computes_extra_metrics=False, + ) + model = renderer.model + post_processing = renderer.post_processing + if post_processing is None or type(post_processing).__name__ != "PPISP": + raise SystemExit("Checkpoint has no PPISP post-processing module.") + if not getattr(post_processing.config, "use_controller", False): + raise SystemExit("PPISP was trained without a controller; nothing to validate.") + if args.use_train: + # Pull the train dataloader by re-creating it (Renderer doesn't keep one). 
+ from threedgrut.datasets.utils import configure_dataloader_for_platform + from threedgrut import datasets as ds + conf_for_train = renderer.conf + train_dataset, _ = ds.make(name=conf_for_train.dataset.type, + config=conf_for_train, ray_jitter=None) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + **configure_dataloader_for_platform( + {"num_workers": 0, "batch_size": 1, "shuffle": False, "pin_memory": True} + ), + ) + val_dataset = train_dataset + val_dataloader = train_dataloader + logger.info("Using TRAIN frames (model overfit) for validation.") + else: + val_dataset = renderer.dataset + val_dataloader = renderer.dataloader + + cam_names: List[str] + if hasattr(val_dataset, "get_camera_names"): + cam_names = list(val_dataset.get_camera_names()) + else: + cam_names = ["cam_0"] + + frame_to_camera: List[int] = [] + resolutions: Dict[str, Tuple[int, int]] = {} + + # ------------------------------------------------------------------ + # 2-3. Render HDR + PyTorch reference LDR for each val frame. + # ------------------------------------------------------------------ + from threedgrut.utils.render import apply_post_processing # noqa: E402 + + seen = 0 + for frame_idx, batch in enumerate(val_dataloader): + if args.max_frames is not None and seen >= args.max_frames: + break + + gpu_batch = val_dataset.get_gpu_batch_with_intrinsics(batch) + with torch.no_grad(): + outputs = model(gpu_batch) + hdr_tensor = outputs["pred_rgb"][0] # [H, W, 3] + + # If the gaussian render is degenerate (e.g. short training, mismatched + # camera frames), the "HDR" is all zeros and the slang/PyTorch + # comparison becomes vacuous. Fall back to the dataset GT image so the + # comparison still exercises the full PPISP pipeline on real-world + # spatial / colour variation. 
+ if hdr_tensor.abs().max().item() < 1e-6: + logger.warning("Gaussian render is degenerate (all zero); " + "substituting GT image as HDR input.") + gt = gpu_batch.rgb_gt + hdr_tensor = gt[0] if gt.dim() == 4 else gt + + hdr_np = hdr_tensor.detach().cpu().numpy().astype(np.float32) + h, w = hdr_np.shape[:2] + + cam_idx = (val_dataset.get_camera_idx(frame_idx) + if hasattr(val_dataset, "get_camera_idx") else 0) + cam_name = cam_names[cam_idx] if cam_idx < len(cam_names) else "cam_0" + resolutions[cam_name] = (w, h) + frame_to_camera.append(cam_idx) + + _save_npy(hdr_dir / cam_name / f"{frame_idx}.npy", hdr_np) + + # Reference path: same PPISP that the slang shader will execute. + # PPISP.forward picks the controller branch when frame_idx=-1, so + # we mirror that here for an apples-to-apples comparison. + with torch.no_grad(): + outputs_ref = dict(outputs) + outputs_ref["pred_rgb"] = hdr_tensor.unsqueeze(0) + # apply_post_processing expects a batch dim. + ref = apply_post_processing( + post_processing, outputs_ref, gpu_batch, training=False + )["pred_rgb"][0] + ref_np = ref.detach().cpu().numpy().astype(np.float32) + _save_png(ref_dir / cam_name / f"{frame_idx}.png", _to_rgba8(ref_np)) + # Save the pre-quantization float reference so we can quantify the + # numerical drift independent of the rgba8_unorm round-trip. + np.save(ref_dir / cam_name / f"{frame_idx}.npy", ref_np) + seen += 1 + + if not frame_to_camera: + raise SystemExit("No validation frames found in dataset.") + logger.info("Rendered %d val frame(s)", len(frame_to_camera)) + + # ------------------------------------------------------------------ + # 4. Author the controller-aware USD + sidecars. 
+ # ------------------------------------------------------------------ + usd_dir.mkdir(parents=True, exist_ok=True) + stage = Usd.Stage.CreateNew(str(usd_dir / "scene.usda")) + stage.SetMetadata("upAxis", UsdGeom.Tokens.y) + stage.DefinePrim("/World", "Xform") + stage.DefinePrim("/Render", "Scope") + for cam_name, (w, h) in resolutions.items(): + _build_render_product(stage, cam_name, w, h) + + dataset_stub = _StubDataset(frame_to_camera, cam_names) + cam_names_built, mapping = build_camera_frame_mapping(dataset_stub) + add_ppisp_to_all_render_products( + stage=stage, + ppisp=post_processing, + camera_names=cam_names_built, + camera_frame_mapping=mapping, + use_controller=True, + ) + stage.GetRootLayer().Save() + + for s in get_ppisp_spg_dyn_files(): + (usd_dir / s.filename).write_bytes(s.serialized) + for s in get_controller_sidecars(): + (usd_dir / s.filename).write_bytes(s.serialized) + + # ------------------------------------------------------------------ + # 5. Run the slangpy CLI. + # ------------------------------------------------------------------ + cli = Path(__file__).resolve().parent / "render_renderproduct.py" + cmd = [ + sys.executable, str(cli), + str(usd_dir / "scene.usda"), str(hdr_dir), str(slang_dir), "-vv", + ] + logger.info("Running slangpy CLI: %s", " ".join(cmd)) + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + print(proc.stdout); print(proc.stderr, file=sys.stderr) + raise SystemExit(f"render_renderproduct.py failed (exit {proc.returncode})") + + # ------------------------------------------------------------------ + # 5b. Probe: run the controller alone via slangpy on each saved HDR + # and compare its 9-element output to PyTorch. This isolates whether + # any drift in the final image originates in the controller (CNN+MLP) + # or further downstream in the PPISP shader. 
+ # ------------------------------------------------------------------ + from threedgrut.export.usd.writers.ppisp_controller_writer import ( + flatten_controller_weights, + ) + from tools.render_ppisp_spg.spg_runtime import run_controller as _run_ctrl + print("\nController (9-float) drift, slang vs torch:") + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + hdr_np = np.load(hdr_dir / cam_name / f"{frame_idx}.npy") + ctrl = post_processing.controllers[cam_idx] + # Torch reference + rgb_t = torch.from_numpy(hdr_np).float().to("cuda") + pe_t = torch.zeros(1, dtype=torch.float32, device="cuda") + with torch.no_grad(): + exposure, color = ctrl(rgb_t, pe_t) + torch_out = np.concatenate([ + np.array([float(exposure)], dtype=np.float32), + color.detach().cpu().numpy().astype(np.float32), + ]) + # Slang + weights = flatten_controller_weights(ctrl) + slang_out = _run_ctrl(usd_dir / "ppisp_controller.slang", hdr_np, weights, prior_exposure=0.0) + diff = np.abs(slang_out - torch_out) + print(f" frame={frame_idx} torch={torch_out} slang={slang_out} max|Δ|={diff.max():.4g}") + + # ------------------------------------------------------------------ + # 6. Compare images. 
+ # ------------------------------------------------------------------ + all_pass = True + worst = (None, float("inf")) + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + ref_path = ref_dir / cam_name / f"{frame_idx}.png" + sl_path = slang_dir / cam_name / f"{frame_idx}.png" + if not sl_path.exists(): + print(f" cam={cam_name} frame={frame_idx} MISSING slangpy output") + all_pass = False + continue + ref = np.asarray(Image.open(ref_path).convert("RGBA")) + sl = np.asarray(Image.open(sl_path).convert("RGBA")) + psnr = _psnr(ref[..., :3], sl[..., :3]) + max_abs = int(np.max(np.abs(ref[..., :3].astype(int) - sl[..., :3].astype(int)))) + ok = psnr >= args.psnr_threshold + + # Also report a float-domain diff: the slang shader writes through + # rgba8_unorm, so its output is already quantized; we re-quantize the + # PyTorch reference with the same rule and compare the float values + # of the reference to that re-quantized form. This shows whether the + # shader is matching the *post-quantization* spec exactly. 
+ ref_float_path = ref_dir / cam_name / f"{frame_idx}.npy" + if ref_float_path.exists(): + ref_float = np.clip(np.load(ref_float_path), 0.0, 1.0) + sl_float = sl[..., :3].astype(np.float32) / 255.0 + float_diff = ref_float - sl_float + mean_abs = float(np.mean(np.abs(float_diff))) + max_abs_f = float(np.max(np.abs(float_diff))) + print(f" cam={cam_name} frame={frame_idx} PSNR={psnr:7.3f} dB " + f"max|Δ|_u8={max_abs} max|Δ|_float={max_abs_f:.4f} " + f"mean|Δ|_float={mean_abs:.5f} " + f"{'OK' if ok else 'FAIL'}") + else: + print(f" cam={cam_name} frame={frame_idx} PSNR={psnr:7.3f} dB " + f"max|Δ|={max_abs} {'OK' if ok else 'FAIL'}") + if psnr < worst[1]: + worst = ((cam_name, frame_idx), psnr) + if not ok: + all_pass = False + + print() + print(f"worst frame: {worst[0]} @ {worst[1]:.3f} dB (threshold {args.psnr_threshold} dB)") + print(f"work dir: {work}") + return 0 if all_pass else 1 + + +if __name__ == "__main__": + raise SystemExit(main())