From 915a15dd1ace2a2c1f2c3513b3d21ce4d06b556b Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 01/42] fix(transcode): import NuRec USD as LightField Decode NuRec USD/USDZ inputs into LightField export data so transcode can convert NuRec packages into standard ParticleField USD. Made-with: Cursor --- threedgrut/export/importers/nurec_usd.py | 106 ++++++++- threedgrut/export/scripts/transcode.py | 72 ++++-- threedgrut/export/usd/camera_copy.py | 280 +++++++++++++++++++++++ threedgrut/export/usd/exporter.py | 64 ++++-- 4 files changed, 484 insertions(+), 38 deletions(-) create mode 100644 threedgrut/export/usd/camera_copy.py diff --git a/threedgrut/export/importers/nurec_usd.py b/threedgrut/export/importers/nurec_usd.py index a9254e22..e9cc49de 100644 --- a/threedgrut/export/importers/nurec_usd.py +++ b/threedgrut/export/importers/nurec_usd.py @@ -41,6 +41,17 @@ _STATE_N_ACTIVE = ".gaussians_nodes.gaussians.n_active_features" _STATE_EXTRA_SIGNAL = ".gaussians_nodes.gaussians.extra_signal" +_GAUSSIANS_NODES_PREFIX = ".gaussians_nodes." +# Per-node tensor suffixes (same layout as fill_3dgut_template / static gaussians). +_REQUIRED_GAUSSIAN_NODE_KEYS = ( + "positions", + "rotations", + "scales", + "densities", + "features_albedo", + "features_specular", +) + def _find_nurec_volume_prim(stage: Usd.Stage) -> Optional[Usd.Prim]: """Find the NuRec Volume prim (UsdVol::Volume with omni:nurec:isNuRecVolume).""" @@ -108,6 +119,84 @@ def _tensor_from_state(state: dict, key: str, dtype=np.float16, shape_key: Optio return arr.astype(np.float32) +def _discover_gaussians_nodes_prefixes(state: dict) -> list[str]: + """Find state_dict prefixes like '.gaussians_nodes.background' that hold full Gaussian tensors.""" + found: set[str] = set() + for k in state: + if not isinstance(k, str) or not k.endswith(".positions"): + continue + prefix = k[: -len(".positions")] + if not prefix.startswith(_GAUSSIANS_NODES_PREFIX): + continue + if all(state.get(f"{prefix}.{suffix}") is not None for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS): + found.add(prefix) + return sorted(found) + + +def _load_merged_gaussian_tensors_from_state( + state: dict, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[list[int]]]: + """Load positions…specular from state_dict, merging multiple .gaussians_nodes. blocks if present.""" + prefixes = _discover_gaussians_nodes_prefixes(state) + if not prefixes: + positions = _tensor_from_state(state, _STATE_POSITIONS) + rotations = _tensor_from_state(state, _STATE_ROTATIONS) + scales = _tensor_from_state(state, _STATE_SCALES) + densities = _tensor_from_state(state, _STATE_DENSITIES) + features_albedo = _tensor_from_state(state, _STATE_FEATURES_ALBEDO) + features_specular = _tensor_from_state(state, _STATE_FEATURES_SPECULAR) + n_active = state.get(_STATE_N_ACTIVE) + n_active_vals = None + if n_active is not None: + n_active_vals = [int(np.frombuffer(n_active, dtype=np.int64)[0])] + return ( + positions, + rotations, + scales, + densities, + features_albedo, + features_specular, + n_active_vals, + ) + + chunks: dict[str, list[np.ndarray]] = {k: [] for k in _REQUIRED_GAUSSIAN_NODE_KEYS} + n_active_per_node: list[int] = [] + counts: list[tuple[str, int]] = [] + for pref in prefixes: + for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS: + chunks[suffix].append(_tensor_from_state(state, f"{pref}.{suffix}")) + na_key = f"{pref}.n_active_features" + na_raw = state.get(na_key) + if na_raw is not None: + n_active_per_node.append(int(np.frombuffer(na_raw, dtype=np.int64)[0])) + counts.append((pref, int(chunks["positions"][-1].shape[0]))) + + for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS: + ref_tail = chunks[suffix][0].shape[1:] + for i, arr in enumerate(chunks[suffix][1:], start=1): + if arr.shape[1:] != ref_tail: + raise ValueError( + f"NuRec state_dict: incompatible '{suffix}' trailing dims across nodes " + f"({prefixes[0]} {ref_tail} vs {prefixes[i]} {arr.shape[1:]})" + ) + + merged = {suffix: np.concatenate(chunks[suffix], axis=0) for suffix in _REQUIRED_GAUSSIAN_NODE_KEYS} + logger.info( + "NuRec: merged %d Gaussian node(s) %s", + len(prefixes), + ", ".join(f"{p}={n}" for p, n in counts), + ) + return ( + merged["positions"], + merged["rotations"], + merged["scales"], + merged["densities"], + merged["features_albedo"], + merged["features_specular"], + n_active_per_node if n_active_per_node else None, + ) + + def _rotation_matrix_to_quat_wxyz(R: np.ndarray) -> np.ndarray: """Convert 3x3 rotation matrix to wxyz quaternion (one quat).""" trace = R[0, 0] + R[1, 1] + R[2, 2] @@ -248,16 +337,15 @@ def _load_stage(self, stage_path: Path, resolution_root: Path) -> Tuple[Gaussian raw = _load_nurec_bytes(resolution_root, nurec_path) state = _decode_state_dict(raw) - positions = _tensor_from_state(state, _STATE_POSITIONS) - rotations = _tensor_from_state(state, _STATE_ROTATIONS) - scales = _tensor_from_state(state, _STATE_SCALES) - densities = _tensor_from_state(state, _STATE_DENSITIES) - features_albedo = _tensor_from_state(state, _STATE_FEATURES_ALBEDO) - features_specular = _tensor_from_state(state, _STATE_FEATURES_SPECULAR) + positions, rotations, scales, densities, features_albedo, features_specular, n_active_list = ( + _load_merged_gaussian_tensors_from_state(state) + ) - n_active = state.get(_STATE_N_ACTIVE) - if n_active is not None: - sh_degree = int(np.frombuffer(n_active, dtype=np.int64)[0]) + if n_active_list is not None: + unique_deg = set(n_active_list) + if len(unique_deg) > 1: + logger.warning("NuRec nodes disagree on n_active_features %s; using max", n_active_list) + sh_degree = max(n_active_list) else: # Infer from features_specular shape: (N, (degree+1)^2 - 1) * 3 n_spec = features_specular.shape[1] diff --git a/threedgrut/export/scripts/transcode.py b/threedgrut/export/scripts/transcode.py index 9c19261b..0f8b7b8a 100644 --- a/threedgrut/export/scripts/transcode.py +++ b/threedgrut/export/scripts/transcode.py @@ -25,6 +25,11 @@ python -m threedgrut.export.scripts.transcode input.ply -o output.usdz --format lightfield python -m threedgrut.export.scripts.transcode input.usdz -o output.ply python -m threedgrut.export.scripts.transcode nurec.usd -o lightfield.usdz --format lightfield + +USD/USDZ → LightField: source /World prims (e.g. rig_trajectories) merge into default.usda at the +same paths; referenced layers are bundled unchanged (preserves camera animation curves). +/World/Gaussians is skipped by default; use --copy-source-include-gaussians to merge it too. +Use --no-copy-source-prims to disable. """ import argparse @@ -32,6 +37,7 @@ import sys import tempfile import zipfile +from contextlib import nullcontext from pathlib import Path from typing import Optional, Tuple @@ -44,6 +50,7 @@ PLYImporter, USDImporter, ) +from threedgrut.export.usd.camera_copy import usd_stage_path_context_for_camera_copy from threedgrut.export.usd.exporter import USDExporter from threedgrut.export.usd.nurec.exporter import NuRecExporter @@ -205,6 +212,8 @@ def transcode( apply_coordinate_transform: bool = False, render_order_hint: Optional[str] = None, linear_srgb: bool = False, + copy_cameras_source: Optional[Tuple[Path, Path]] = None, + copy_source_skip_subtrees: Optional[Tuple] = None, ) -> None: """Transcode between Gaussian splatting formats. @@ -219,6 +228,8 @@ def transcode( apply_coordinate_transform: Apply 3DGRUT-to-USDZ transform (for both lightfield and nurec) render_order_hint: If set, force sortingModeHint for lightfield only; ignored for other formats (warning logged). linear_srgb: If True, set prim color space to lin_rec709_scene (lightfield only). + copy_cameras_source: If set, (root_usd_path, asset_resolution_dir) to copy source /World prims from. + copy_source_skip_subtrees: Optional tuple of Sdf.Path roots to skip under /World (None = default skip Gaussians). """ if render_order_hint is not None and output_format != "lightfield": logger.warning( @@ -264,7 +275,13 @@ def transcode( # Export logger.info(f"Exporting to {output_path}...") - exporter.export(adapter, output_path, apply_coordinate_transform=apply_coordinate_transform) + exporter.export( + adapter, + output_path, + apply_coordinate_transform=apply_coordinate_transform, + copy_cameras_source=copy_cameras_source, + copy_source_skip_subtrees=copy_source_skip_subtrees, + ) logger.info(f"Transcode complete: {input_path} -> {output_path}") @@ -352,6 +369,23 @@ def parse_args(): action="store_true", help="Set prim color space to lin_rec709_scene (lightfield only). Default is srgb_rec709_display.", ) + parser.add_argument( + "--no-copy-source-prims", + action="store_true", + dest="no_copy_source_prims", + help="When input is USD/USDZ and output is LightField, do not merge source /World prims into default.usda.", + ) + parser.add_argument( + "--no-copy-source-cameras", + action="store_true", + dest="no_copy_source_prims", + help="Deprecated alias for --no-copy-source-prims.", + ) + parser.add_argument( + "--copy-source-include-gaussians", + action="store_true", + help="Also copy /World/Gaussians from the source (duplicates old LightField data; can be very large).", + ) parser.add_argument( "-v", "--verbose", @@ -392,19 +426,31 @@ def main(): # Create output directory if needed output_path.parent.mkdir(parents=True, exist_ok=True) + suffix_in = input_path.suffix.lower() + use_camera_copy_ctx = ( + output_format == "lightfield" + and suffix_in in (".usd", ".usda", ".usdc", ".usdz") + and not args.no_copy_source_prims + ) + camera_ctx = usd_stage_path_context_for_camera_copy(input_path) if use_camera_copy_ctx else nullcontext(None) + try: - transcode( - input_path=input_path, - output_path=output_path, - output_format=output_format, - max_sh_degree=args.max_sh_degree, - half_precision=args.half, - half_geometry=args.half_geometry, - half_features=args.half_features, - apply_coordinate_transform=args.apply_coordinate_transform, - render_order_hint=args.render_order_hint, - linear_srgb=args.linear_srgb, - ) + with camera_ctx as copy_cameras_source: + skip_subtrees = () if args.copy_source_include_gaussians else None + transcode( + input_path=input_path, + output_path=output_path, + output_format=output_format, + max_sh_degree=args.max_sh_degree, + half_precision=args.half, + half_geometry=args.half_geometry, + half_features=args.half_features, + apply_coordinate_transform=args.apply_coordinate_transform, + render_order_hint=args.render_order_hint, + linear_srgb=args.linear_srgb, + copy_cameras_source=copy_cameras_source, + copy_source_skip_subtrees=skip_subtrees, + ) except Exception as e: logger.error(f"Transcode failed: {e}") if args.verbose: diff --git a/threedgrut/export/usd/camera_copy.py b/threedgrut/export/usd/camera_copy.py new file mode 100644 index 00000000..01328b00 --- /dev/null +++ b/threedgrut/export/usd/camera_copy.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Copy prims from a source USD stage into an export stage (transcode USD → LightField).""" + +import logging +import tempfile +import zipfile +from contextlib import contextmanager +from pathlib import Path +from typing import Collection, Iterator, List, Optional, Set, Tuple + +from pxr import Sdf, UsdGeom + +from threedgrut.export.usd.stage_utils import NamedSerialized + +logger = logging.getLogger(__name__) + +UsdStagePathPair = Tuple[Path, Path] + +# Default: do not duplicate LightField Gaussian root (large); new splats live at /World/Gaussians. +_DEFAULT_SKIP_SUBTREES = (Sdf.Path("/World/Gaussians"),) + + +def _path_is_under_skipped(src_path: Sdf.Path, skip_roots: Collection[Sdf.Path]) -> bool: + for root in skip_roots: + if src_path == root: + return True + # Children of root (e.g. /World/Gaussians/gaussians) + prefix = str(root) + "/" + if str(src_path).startswith(prefix): + return True + return False + + +def _copy_prim_spec_recursive( + src_layer: Sdf.Layer, + dst_layer: Sdf.Layer, + src_path: Sdf.Path, + dst_path: Sdf.Path, +) -> int: + """Copy one prim spec and all descendants. Returns number of prims copied.""" + src_spec = src_layer.GetPrimAtPath(src_path) + if not src_spec or not src_spec.active: + return 0 + Sdf.CopySpec(src_layer, src_path, dst_layer, dst_path) + count = 1 + for child_spec in src_spec.nameChildren: + name = child_spec.name + count += _copy_prim_spec_recursive( + src_layer, + dst_layer, + src_path.AppendChild(name), + dst_path.AppendChild(name), + ) + return count + + +def merge_source_world_at_same_paths( + dest_stage, + source_stage, + skip_source_subtrees: Optional[Collection[Sdf.Path]] = None, +) -> int: + """ + Merge each top-level child of ``/World`` from the source **root** layer onto ``dest_stage``'s + root layer at the **same path** as the source (e.g. ``/World/rig_trajectories``), using + ``Sdf.CopySpec``. References and payloads are copied as-authored so sibling layers (e.g. + ``rig_trajectories.usda``) keep all time samples when those files are bundled unchanged. + + Skips subtrees in ``skip_source_subtrees`` (default: ``/World/Gaussians`` for LightField). + Skips any path where the destination root layer already has a prim spec (e.g. export's + ``/World/gaussians`` reference prim). + """ + skips = tuple(skip_source_subtrees) if skip_source_subtrees is not None else _DEFAULT_SKIP_SUBTREES + src_layer = source_stage.GetRootLayer() + dst_layer = dest_stage.GetRootLayer() + + world_spec = src_layer.GetPrimAtPath("/World") + if not world_spec: + logger.info("Source USD has no /World prim; nothing to merge") + return 0 + + total = 0 + for child_spec in world_spec.nameChildren: + name = child_spec.name + path = Sdf.Path("/World").AppendChild(name) + if _path_is_under_skipped(path, skips): + logger.info("Skipping source subtree %s (transcode merge skip list)", path) + continue + if dst_layer.GetPrimAtPath(path): + logger.info("Keeping destination prim %s; not overwriting with source", path) + continue + total += _copy_prim_spec_recursive(src_layer, dst_layer, path, path) + + if total == 0: + logger.info("No source /World prims merged (empty or all skipped / already present)") + else: + logger.info("Merged %d source prim subtree(s) at original /World paths", total) + + return total + + +def copy_authored_time_settings_from_source(source_stage, dest_stage) -> None: + """Copy authored time code range and FPS from source to destination stage when set.""" + try: + if getattr(source_stage, "HasAuthoredTimeCodeRange", None) and source_stage.HasAuthoredTimeCodeRange(): + dest_stage.SetStartTimeCode(source_stage.GetStartTimeCode()) + dest_stage.SetEndTimeCode(source_stage.GetEndTimeCode()) + tps = source_stage.GetTimeCodesPerSecond() + if tps is not None and float(tps) > 0.0: + dest_stage.SetTimeCodesPerSecond(tps) + except Exception as ex: + logger.debug("Could not copy time settings from source stage: %s", ex) + + +# Filenames we always author in LightField USDZ export (never pull from source package). +_OUTPUT_AUTHORED_NAMES = frozenset({"gaussians.usdc", "default.usda"}) + + +def _basename_packaged_ref(asset_path: str) -> Optional[str]: + """USDZ-flat basename for a relative layer/asset reference, or None if not packagable.""" + if not asset_path: + return None + s = asset_path.strip().strip("@") + if not s or "://" in s or s.startswith("/"): + return None + return Path(s.replace("\\", "/")).name + + +def _gather_ref_payload_basenames_from_prim_spec(spec: Sdf.PrimSpec) -> Set[str]: + out: Set[str] = set() + if not spec: + return out + ref_list = spec.referenceList + for item in list(ref_list.prependedItems) + list(ref_list.appendedItems): + bn = _basename_packaged_ref(getattr(item, "assetPath", "") or "") + if bn: + out.add(bn) + pay_list = getattr(spec, "payloadList", None) + if pay_list is not None: + for item in list(pay_list.prependedItems) + list(pay_list.appendedItems): + bn = _basename_packaged_ref(getattr(item, "assetPath", "") or "") + if bn: + out.add(bn) + return out + + +def _walk_prim_subtree(layer: Sdf.Layer, root_path: Sdf.Path): + """Depth-first active prims under root_path (inclusive).""" + spec = layer.GetPrimAtPath(root_path) + if not spec or not spec.active: + return + yield root_path + for child_spec in spec.nameChildren: + yield from _walk_prim_subtree(layer, root_path.AppendChild(child_spec.name)) + + +def _gather_refs_from_layer_subtree(layer: Sdf.Layer, path_prefix: str) -> Set[str]: + """Collect referenced basenames from all prims under path_prefix on this layer.""" + needed: Set[str] = set() + root = Sdf.Path(path_prefix) + if not layer.GetPrimAtPath(root): + return needed + for path in _walk_prim_subtree(layer, root): + spec = layer.GetPrimAtPath(path) + needed |= _gather_ref_payload_basenames_from_prim_spec(spec) + return needed + + +def _walk_entire_layer(layer: Sdf.Layer): + """All active prim paths (excluding absolute root pseudo-prim).""" + root = Sdf.Path("/") + spec = layer.GetPrimAtPath(root) + if not spec: + return + for child_spec in spec.nameChildren: + yield from _walk_prim_subtree(layer, root.AppendChild(child_spec.name)) + + +def collect_transitive_sidecars_for_world_subtree( + dest_layer: Sdf.Layer, + res_root: Path, + world_prefix: str = "/World", + extra_skip_names: Optional[Collection[str]] = None, +) -> List[NamedSerialized]: + """ + Resolve layer/asset references under ``world_prefix`` and bundle files from ``res_root`` into the + output USDZ (flat layout). Follows references/payloads transitively through USD layers. + + Skips names in ``_OUTPUT_AUTHORED_NAMES`` and ``extra_skip_names`` (e.g. source root default file). + """ + skip: Set[str] = set(_OUTPUT_AUTHORED_NAMES) + if extra_skip_names: + skip.update(extra_skip_names) + + seed = _gather_refs_from_layer_subtree(dest_layer, world_prefix) + queue: Set[str] = {n for n in seed if n not in skip} + done: Set[str] = set(skip) + result: List[NamedSerialized] = [] + + while queue: + name = queue.pop() + if name in done: + continue + done.add(name) + path = res_root / name + if not path.is_file(): + logger.warning("Referenced package file missing under %s: %s", res_root, name) + continue + try: + data = path.read_bytes() + except OSError as e: + logger.warning("Could not read sidecar %s: %s", path, e) + continue + result.append(NamedSerialized(filename=name, serialized=data)) + + suf = path.suffix.lower() + if suf not in (".usd", ".usda", ".usdc"): + continue + sub = Sdf.Layer.FindOrOpen(str(path)) + if not sub: + logger.warning("Could not open referenced layer for sidecar walk: %s", path) + continue + for p in _walk_entire_layer(sub): + spec = sub.GetPrimAtPath(p) + for bn in _gather_ref_payload_basenames_from_prim_spec(spec): + if bn and bn not in done: + queue.add(bn) + + if result: + logger.info("Bundled %d sidecar file(s) from %s for /World references", len(result), res_root) + return result + + +@contextmanager +def usd_stage_path_context_for_camera_copy(usd_path: Path) -> Iterator[Optional[UsdStagePathPair]]: + """ + Yield (root_stage_path, asset_resolution_dir) for opening a USD/USDZ with correct asset paths. + + For USDZ, extracts to a temporary directory (deleted on exit). + """ + path = usd_path.resolve() + suffix = path.suffix.lower() + if suffix not in (".usd", ".usda", ".usdc", ".usdz"): + yield None + return + + if suffix == ".usdz": + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + with zipfile.ZipFile(path, "r") as zf: + zf.extractall(tmp_path) + usd_files = list(tmp_path.glob("*.usd*")) + root_file = None + for f in usd_files: + if f.stem == "default": + root_file = f + break + if root_file is None and usd_files: + root_file = usd_files[0] + if root_file is None: + logger.warning("USDZ has no USD root for source prim copy: %s", path) + yield None + return + yield (root_file.resolve(), tmp_path.resolve()) + return + + yield (path, path.parent.resolve()) diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 71e3276e..9ced7dc7 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -26,6 +26,7 @@ import numpy as np import torch +from pxr import Usd from ncore.data import ( OpenCVFisheyeCameraModelParameters, OpenCVPinholeCameraModelParameters, @@ -47,6 +48,11 @@ ) from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import create_gaussian_writer +from threedgrut.export.usd.camera_copy import ( + collect_transitive_sidecars_for_world_subtree, + copy_authored_time_settings_from_source, + merge_source_world_at_same_paths, +) from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) @@ -305,10 +311,43 @@ def export( writer.write_attributes(attrs) writer.finalize(attrs.positions) - # Collect stages and files for USDZ - stages: List[NamedUSDStage] = [] + suffix = output_path.suffix.lower() + package_as_usdz = suffix == ".usdz" or suffix not in (".usd", ".usda", ".usdc") + + gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) + default_stage_wrapped: Optional[NamedUSDStage] = None + if package_as_usdz: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + files: List[NamedSerialized] = [] + copy_source_usd = kwargs.get("copy_source_usd") + if copy_source_usd is None: + copy_source_usd = kwargs.get("copy_cameras_source") + if copy_source_usd is not None: + stage_path, res_root = copy_source_usd + try: + src_stage = Usd.Stage.Open(str(stage_path)) + if not src_stage: + logger.warning("Could not open source USD for prim merge: %s", stage_path) + else: + skip = kwargs.get("copy_source_skip_subtrees") + merge_target = default_stage_wrapped.stage if default_stage_wrapped is not None else stage + merge_source_world_at_same_paths(merge_target, src_stage, skip_source_subtrees=skip) + copy_authored_time_settings_from_source(src_stage, merge_target) + if package_as_usdz and res_root is not None and res_root.is_dir(): + sidecars = collect_transitive_sidecars_for_world_subtree( + merge_target.GetRootLayer(), + res_root, + world_prefix="/World", + extra_skip_names={Path(stage_path).name}, + ) + for entry in sidecars: + if not any(f.filename == entry.filename for f in files): + files.append(entry) + except Exception as e: + logger.warning("Failed to merge source USD prims: %s", e) + # Export cameras if requested and dataset available if self.export_cameras and dataset is not None: try: @@ -355,29 +394,22 @@ def export( except (AttributeError, ValueError, ImportError) as e: logger.warning(f"Failed to export background: {e}") - # Determine output format - suffix = output_path.suffix.lower() + # Package: gaussians_stage / default_stage_wrapped were built before source merge. if suffix == ".usdz": - # Package as USDZ with composition: - # - default.usda (text) references gaussians.usdc (binary) - gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) - default_stage = self._create_default_stage([gaussians_stage]) - # default.usda must be first in USDZ - write_to_usdz(output_path, [default_stage, gaussians_stage], files if files else None) + if default_stage_wrapped is None: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + write_to_usdz(output_path, [default_stage_wrapped, gaussians_stage], files if files else None) elif suffix in [".usda", ".usd", ".usdc"]: - # Export as plain USD (format determined by extension) stage.Export(str(output_path)) - # Also export envmap if present if envmap_bytes is not None: envmap_path = output_path.parent / "envmap.png" with open(envmap_path, "wb") as f: f.write(envmap_bytes) else: - # Default to USDZ usdz_path = output_path.with_suffix(".usdz") - gaussians_stage = NamedUSDStage(filename="gaussians.usdc", stage=stage) - default_stage = self._create_default_stage([gaussians_stage]) - write_to_usdz(usdz_path, [default_stage, gaussians_stage], files if files else None) + if default_stage_wrapped is None: + default_stage_wrapped = self._create_default_stage([gaussians_stage]) + write_to_usdz(usdz_path, [default_stage_wrapped, gaussians_stage], files if files else None) logger.info(f"USD export complete: {output_path}") From 50243d9c8e648b72415cb477bb472d85146dbfa8 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 02/42] feat(export): validate written LightField USD stages Run OpenUSD validation after standard USD export so composition and stage metadata errors are caught before returning success. Made-with: Cursor --- threedgrut/export/scripts/export_usd.py | 9 +++ threedgrut/export/scripts/transcode.py | 9 +++ threedgrut/export/tests/test_export_import.py | 20 +---- threedgrut/export/usd/exporter.py | 12 ++- threedgrut/export/usd/validation.py | 77 +++++++++++++++++++ 5 files changed, 108 insertions(+), 19 deletions(-) create mode 100644 threedgrut/export/usd/validation.py diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index a696aabb..9cff811a 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -138,6 +138,11 @@ def parse_args(): action="store_true", help="Enable verbose logging", ) + parser.add_argument( + "--no-usd-validate", + action="store_true", + help="Skip OpenUSD stage validation after standard (ParticleField) export", + ) return parser.parse_args() @@ -243,12 +248,16 @@ def main(): # Export try: + export_kw = {} + if args.format == "standard": + export_kw["validate_usd"] = not args.no_usd_validate exporter.export( model=model, output_path=output_path, dataset=dataset, conf=conf, background=background, + **export_kw, ) logger.info(f"Export successful: {output_path}") except Exception as e: diff --git a/threedgrut/export/scripts/transcode.py b/threedgrut/export/scripts/transcode.py index 0f8b7b8a..e3d8bd75 100644 --- a/threedgrut/export/scripts/transcode.py +++ b/threedgrut/export/scripts/transcode.py @@ -214,6 +214,7 @@ def transcode( linear_srgb: bool = False, copy_cameras_source: Optional[Tuple[Path, Path]] = None, copy_source_skip_subtrees: Optional[Tuple] = None, + validate_usd: bool = True, ) -> None: """Transcode between Gaussian splatting formats. @@ -230,6 +231,7 @@ def transcode( linear_srgb: If True, set prim color space to lin_rec709_scene (lightfield only). copy_cameras_source: If set, (root_usd_path, asset_resolution_dir) to copy source /World prims from. copy_source_skip_subtrees: Optional tuple of Sdf.Path roots to skip under /World (None = default skip Gaussians). + validate_usd: If True and output is lightfield, run OpenUSD stage validation after export. """ if render_order_hint is not None and output_format != "lightfield": logger.warning( @@ -281,6 +283,7 @@ def transcode( apply_coordinate_transform=apply_coordinate_transform, copy_cameras_source=copy_cameras_source, copy_source_skip_subtrees=copy_source_skip_subtrees, + validate_usd=validate_usd if output_format == "lightfield" else False, ) logger.info(f"Transcode complete: {input_path} -> {output_path}") @@ -392,6 +395,11 @@ def parse_args(): action="store_true", help="Enable verbose logging", ) + parser.add_argument( + "--no-usd-validate", + action="store_true", + help="Skip OpenUSD stage validation after lightfield (.usd/.usdz) export", + ) return parser.parse_args() @@ -450,6 +458,7 @@ def main(): linear_srgb=args.linear_srgb, copy_cameras_source=copy_cameras_source, copy_source_skip_subtrees=skip_subtrees, + validate_usd=not args.no_usd_validate, ) except Exception as e: logger.error(f"Transcode failed: {e}") diff --git a/threedgrut/export/tests/test_export_import.py b/threedgrut/export/tests/test_export_import.py index eae8769f..0beff9a8 100644 --- a/threedgrut/export/tests/test_export_import.py +++ b/threedgrut/export/tests/test_export_import.py @@ -26,7 +26,7 @@ import numpy as np import pytest import torch -from pxr import Usd, UsdValidation +from pxr import Usd from threedgrut.export.base import ExportableModel from threedgrut.export.formats import PLYExporter @@ -34,18 +34,6 @@ from threedgrut.export.usd.exporter import USDExporter -def _validate_stage(stage: Usd.Stage) -> list: - """Run usd-core stage validators (StageMetadataChecker, CompositionErrorTest). Returns list of ValidationError.""" - validators = UsdValidation.ValidationRegistry().GetOrLoadValidatorsByName( - ["usdValidation:StageMetadataChecker", "usdValidation:CompositionErrorTest"] - ) - if not validators: - return [] - ctx = UsdValidation.ValidationContext(validators) - result = ctx.Validate(stage) - return list(result) if result else [] - - class MockGaussianModel(ExportableModel): """Mock ExportableModel with known test data for verification.""" @@ -404,7 +392,7 @@ def test_ply_usd_positions_match(self): ) def test_usd_export_passes_usd_validation(self): - """Exported USD stage passes usd-core schema/stage validators.""" + """Exported USD stage passes OpenUSD stage validators (run inside USDExporter.export).""" model = MockGaussianModel(num_gaussians=5, sh_degree=0) with tempfile.TemporaryDirectory() as tmpdir: usd_path = Path(tmpdir) / "test.usdz" @@ -414,10 +402,6 @@ def test_usd_export_passes_usd_validation(self): export_background=False, apply_normalizing_transform=False, ).export(model, usd_path) - stage = Usd.Stage.Open(str(usd_path)) - assert stage, "Failed to open exported stage" - errors = _validate_stage(stage) - assert not errors, "USD validation failed:\n" + "\n".join(e.GetMessage() for e in errors) def _find_prim_with_color_space_api(stage: Usd.Stage): diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 9ced7dc7..2a6403dc 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -253,7 +253,9 @@ def export( dataset: Optional dataset for camera poses conf: Configuration parameters background: Optional background model for environment export - **kwargs: Additional parameters + **kwargs: Additional parameters. ``validate_usd`` (default True): run OpenUSD + stage validators on the written file (ParticleField / LightField only; no-op if + ``UsdValidation`` is unavailable). """ output_path = Path(output_path) logger.info(f"Exporting USD file to {output_path}...") @@ -399,17 +401,25 @@ def export( if default_stage_wrapped is None: default_stage_wrapped = self._create_default_stage([gaussians_stage]) write_to_usdz(output_path, [default_stage_wrapped, gaussians_stage], files if files else None) + written_path = output_path elif suffix in [".usda", ".usd", ".usdc"]: stage.Export(str(output_path)) if envmap_bytes is not None: envmap_path = output_path.parent / "envmap.png" with open(envmap_path, "wb") as f: f.write(envmap_bytes) + written_path = output_path else: usdz_path = output_path.with_suffix(".usdz") if default_stage_wrapped is None: default_stage_wrapped = self._create_default_stage([gaussians_stage]) write_to_usdz(usdz_path, [default_stage_wrapped, gaussians_stage], files if files else None) + written_path = usdz_path + + if kwargs.get("validate_usd", True): + from threedgrut.export.usd.validation import validate_exported_usd_stage + + validate_exported_usd_stage(written_path) logger.info(f"USD export complete: {output_path}") diff --git a/threedgrut/export/usd/validation.py b/threedgrut/export/usd/validation.py new file mode 100644 index 00000000..e198b0d3 --- /dev/null +++ b/threedgrut/export/usd/validation.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenUSD validation helpers for exported ParticleField / LightField stages.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Stage-wide checks used by export tests; ParticleField-specific validators may be added as USD exposes them. +_LIGHTFIELD_VALIDATOR_NAMES = ( + "usdValidation:StageMetadataChecker", + "usdValidation:CompositionErrorTest", +) + + +def validate_exported_usd_stage(path: Path) -> None: + """ + Run OpenUSD validation on a written .usd / .usda / .usdc / .usdz. + + Intended for outputs from :class:`~threedgrut.export.usd.exporter.USDExporter` + (ParticleField3DGaussianSplat / LightField). NuRec exports are not validated here. + + If ``UsdValidation`` is missing, validators fail to load, or the registry API is + unavailable, this function logs at DEBUG and returns without error. + + Args: + path: Path to the package root file on disk. + + Raises: + ValueError: Stage cannot be opened, or validators reported errors. + """ + path = Path(path) + try: + from pxr import Usd, UsdValidation + except ImportError: + logger.debug("pxr not available; skipping USD validation for %s", path) + return + + try: + registry = UsdValidation.ValidationRegistry() + validators = registry.GetOrLoadValidatorsByName(list(_LIGHTFIELD_VALIDATOR_NAMES)) + except Exception as exc: + logger.debug("UsdValidation unavailable (%s); skipping USD validation for %s", exc, path) + return + + if not validators: + logger.debug("No USD validators loaded; skipping validation for %s", path) + return + + stage = Usd.Stage.Open(str(path)) + if not stage: + raise ValueError(f"USD validation could not open stage: {path}") + + logger.info("Running OpenUSD stage validation on %s", path) + ctx = UsdValidation.ValidationContext(validators) + result = ctx.Validate(stage) + errors = list(result) if result else [] + if errors: + msg = "\n".join(e.GetMessage() for e in errors) + raise ValueError(f"USD validation failed for {path}:\n{msg}") + logger.info("OpenUSD stage validation passed for %s", path) From 2bc90d26c944327eff64970f46e059413046ebab Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 03/42] feat(post-processing): add linear-to-sRGB module Add a no-parameter post-processing module for training and rendering when the renderer outputs linear RGB against sRGB targets. Made-with: Cursor --- README.md | 8 ++ configs/base_gs.yaml | 5 +- threedgrut/render.py | 9 +- threedgrut/trainer.py | 7 ++ .../utils/post_processing_linear_to_srgb.py | 111 ++++++++++++++++++ threedgrut/utils/render.py | 3 + 6 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 threedgrut/utils/post_processing_linear_to_srgb.py diff --git a/README.md b/README.md index f4da8492..d3857e7c 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,14 @@ python train.py --config-name apps/colmap_3dgrt.yaml path=data/mipnerf360/bonsai python train.py --config-name apps/colmap_3dgut.yaml path=data/mipnerf360/bonsai out_dir=runs experiment_name=bonsai_3dgut dataset.downsample_factor=2 optimizer.type=selective_adam ``` +### Post-processing (linear-to-sRGB and PPISP) + +Hydra key: ``post_processing.method``. Values: + +- **null** (default): no change to rendered RGB before the loss. +- **linear-to-srgb**: **IEC 61966-2-1** piecewise linear-to-sRGB encoding on ``pred_rgb`` (same rule as ``thirdparty/tiny-cuda-nn/scripts/common.py`` ``linear_to_srgb``). See ``threedgrut/utils/post_processing_linear_to_srgb.py``. Example: ``post_processing.method=linear-to-srgb``. +- **ppisp**: per-frame camera corrections; requires the ``ppisp`` package (see ``requirements.txt``) and uses the other ``post_processing.*`` fields in ``configs/base_gs.yaml``. + If you use MCMC and Selective Adam in your research, please cite [3dgs-mcmc](https://github.com/ubc-vision/3dgs-mcmc), [taming-3dgs](https://github.com/humansensinglab/taming-3dgs), and the [gSplat](https://github.com/nerfstudio-project/gsplat/tree/main) library from which the code was adopted (links to the code are provided in the source files). diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 51248502..d8b738ee 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -124,7 +124,10 @@ loss: # Post-processing configuration post_processing: - method: null # Possible values: null, "ppisp" + # null | "ppisp" | "linear-to-srgb" + # - linear-to-srgb: IEC piecewise linear-to-sRGB on pred_rgb (same as tiny-cuda-nn common.linear_to_srgb). + # No extra deps; no trainable weights. See threedgrut/utils/post_processing_linear_to_srgb.py. + method: null # Enable the controller for predicting per-frame corrections for novel views. # When false, zero corrections are used for novel views. use_controller: true diff --git a/threedgrut/render.py b/threedgrut/render.py index 313cff66..877a5922 100644 --- a/threedgrut/render.py +++ b/threedgrut/render.py @@ -119,7 +119,14 @@ def from_checkpoint( # Load post-processing if present in checkpoint post_processing = None method = conf.post_processing.method - if "post_processing" in checkpoint and method == "ppisp": + if "post_processing" in checkpoint and method == "linear-to-srgb": + from threedgrut.utils.post_processing_linear_to_srgb import LinearToSrgbPostProcessing + + post_processing = LinearToSrgbPostProcessing() + post_processing.load_state_dict(checkpoint["post_processing"]["module"]) + post_processing = post_processing.to("cuda") + logger.info("Linear-to-sRGB post-processing loaded from checkpoint") + elif "post_processing" in checkpoint and method == "ppisp": from ppisp import PPISP, PPISPConfig # Derive config from training settings to match trainer.py diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index 223fa539..bd275773 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -421,6 +421,13 @@ def init_post_processing(self, conf: DictConfig): ) logger.info(f"📷 {method.upper()} initialized: {num_cameras} cameras, {num_frames} frames") + elif method == "linear-to-srgb": + from threedgrut.utils.post_processing_linear_to_srgb import LinearToSrgbPostProcessing + + self.post_processing = LinearToSrgbPostProcessing().to(self.device) + self.post_processing_optimizers = [] + self.post_processing_schedulers = [] + logger.info("Post-processing: linear-to-sRGB (no trainable parameters)") else: raise ValueError(f"Unknown post-processing method: {method}") diff --git a/threedgrut/utils/post_processing_linear_to_srgb.py b/threedgrut/utils/post_processing_linear_to_srgb.py new file mode 100644 index 00000000..462196f5 --- /dev/null +++ b/threedgrut/utils/post_processing_linear_to_srgb.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Linear-to-sRGB post-processing for training and inference. + +This module implements ``post_processing.method: "linear-to-srgb"`` (see ``configs/base_gs.yaml``). +The trainer applies it to ``pred_rgb`` after the forward render and **before** photometric loss, +so use it when **ground-truth images are sRGB / display-referred** and the **renderer output is +linear scene-referred RGB** (typical for splatting). + +Integration: + +- **Training:** ``Trainer3DGRUT.init_post_processing`` builds :class:`LinearToSrgbPostProcessing` + when ``conf.post_processing.method == "linear-to-srgb"``. No optimizers; regularization term is + always zero (:meth:`get_regularization_loss`). +- **Inference:** ``Renderer.from_checkpoint`` restores the module from the checkpoint when the + saved config uses the same method. + +The forward signature matches ``threedgrut.utils.render.apply_post_processing``; unused arguments are ignored. + +The piecewise rule matches ``thirdparty/tiny-cuda-nn/scripts/common.py`` ``linear_to_srgb`` +(NumPy); this file uses the same math in PyTorch (no NumPy dependency on that script at runtime). +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + + +def linear_to_srgb(x: torch.Tensor) -> torch.Tensor: + """Linear RGB to sRGB nonlinear light (IEC 61966-2-1 style piecewise). + + Same branch structure as ``linear_to_srgb`` in ``thirdparty/tiny-cuda-nn/scripts/common.py``: + + .. code-block:: python + + np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) + + with ``limit = 0.0031308``. Linear values above ``1`` can yield encoded values above ``1`` (HDR). + + Args: + x: Linear RGB tensor (any shape). + + Returns: + Encoded values, same shape / dtype / device as ``x``. + """ + limit = 0.0031308 + x = torch.clamp(x, min=1e-8, max=1.0) + return torch.where( + x > limit, + 1.055 * torch.pow(x, 1.0 / 2.4) - 0.055, + 12.92 * x, + ) + + +class LinearToSrgbPostProcessing(nn.Module): + """``nn.Module`` wrapper so linear-to-sRGB can plug into the shared post-processing path. + + ``forward`` receives flattened RGB ``[N, 3]`` from ``apply_post_processing`` plus PPISP-style + metadata (pixel coordinates, resolution, camera / frame indices, exposure). Only + ``pred_rgb_flat`` is used; other arguments exist for API compatibility with PPISP. + + There are **no learnable parameters**. Checkpoints still store an (empty) ``state_dict`` for + this module when training with this method. + """ + + def __init__(self) -> None: + super().__init__() + self.register_buffer("_reg_loss_zero", torch.tensor(0.0)) + + def forward( + self, + pred_rgb_flat: torch.Tensor, + pixel_coords_flat: torch.Tensor, + resolution=None, + camera_idx=None, + frame_idx=None, + exposure_prior=None, + ) -> torch.Tensor: + """Encode ``pred_rgb_flat`` with :func:`linear_to_srgb`. + + Args: + pred_rgb_flat: ``[H*W, 3]`` linear RGB (contiguous, batch size 1 upstream). + pixel_coords_flat: Unused (PPISP contract). + resolution: Unused. + camera_idx: Unused. + frame_idx: Unused. + exposure_prior: Unused. + + Returns: + Same shape as ``pred_rgb_flat`` (piecewise IEC-style encode; see :func:`linear_to_srgb`). + """ + del pixel_coords_flat, resolution, camera_idx, frame_idx, exposure_prior + return linear_to_srgb(pred_rgb_flat) + + def get_regularization_loss(self) -> torch.Tensor: + """Scalar zero on the module device; required by the trainer alongside PPISP.""" + return self._reg_loss_zero diff --git a/threedgrut/utils/render.py b/threedgrut/utils/render.py index 57c0f427..f253dfd7 100644 --- a/threedgrut/utils/render.py +++ b/threedgrut/utils/render.py @@ -56,6 +56,9 @@ def apply_post_processing( ) -> dict: """Apply post-processing to rendered output. + ``post_processing`` is typically PPISP or :class:`~threedgrut.utils.post_processing_linear_to_srgb.LinearToSrgbPostProcessing`; + both follow the same ``__call__`` contract (flat RGB plus metadata). + Args: post_processing: Post-processing module outputs: Model outputs including pred_rgb From 18b5a185e52a48ebdf981c55d01f03dfe732bff7 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 04/42] feat(export): add linear USD color-space toggle Expose export_usd.linear_srgb so LightField exports can author lin_rec709_scene instead of the default display-referred color space. Made-with: Cursor --- configs/base_gs.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index d8b738ee..16b2d8a8 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -47,6 +47,8 @@ export_usd: export_cameras: true export_background: true sorting_mode_hint: cameraDistance + # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display + linear_srgb: false model: density_activation: sigmoid From bb4375d3f7c66e8df8b1dcaa21cc3a2eb98b765a Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 05/42] feat(export): add PPISP SPG USD export Author per-camera RenderProducts and attach the PPISP SPG shader assets so standard USD exports can preserve the trained PPISP effect. Made-with: Cursor --- configs/base_gs.yaml | 4 + threedgrut/export/usd/exporter.py | 226 ++++++++--- threedgrut/export/usd/ppisp_spg/__init__.py | 53 +++ .../export/usd/ppisp_spg/ppisp_usd_spg.slang | 245 ++++++++++++ .../usd/ppisp_spg/ppisp_usd_spg.slang.lua | 90 +++++ .../usd/ppisp_spg/ppisp_usd_spg.slang.usda | 58 +++ threedgrut/export/usd/writers/__init__.py | 7 + threedgrut/export/usd/writers/camera.py | 203 ++++------ threedgrut/export/usd/writers/ppisp_writer.py | 364 ++++++++++++++++++ .../export/usd/writers/render_product.py | 85 ++++ threedgrut/trainer.py | 1 + 11 files changed, 1165 insertions(+), 171 deletions(-) create mode 100644 threedgrut/export/usd/ppisp_spg/__init__.py create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda create mode 100644 threedgrut/export/usd/writers/ppisp_writer.py create mode 100644 threedgrut/export/usd/writers/render_product.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 16b2d8a8..ce0b5344 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -49,6 +49,10 @@ export_usd: sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false + # PPISP post-processing export as SPG shader on per-camera RenderProducts + export_ppisp: false + # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed + frames_per_second: 1.0 model: density_activation: sigmoid diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 2a6403dc..62b41ab4 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -161,6 +161,57 @@ def _extract_camera_params_from_dataset(dataset) -> Optional[List]: return None +def _extract_camera_grouping(dataset): + """Extract camera grouping info from a dataset. + + Returns: + (camera_names, frame_to_camera) where camera_names is a list of logical + camera names and frame_to_camera maps frame_idx → camera_idx. + """ + camera_names = None + frame_to_camera = None + + if hasattr(dataset, "get_camera_names"): + camera_names = dataset.get_camera_names() + if hasattr(dataset, "get_camera_idx"): + frame_to_camera = [dataset.get_camera_idx(i) for i in range(len(dataset))] + + if camera_names is None: + camera_names = ["camera_0000"] + if frame_to_camera is None: + frame_to_camera = [0] * len(dataset) + + return camera_names, frame_to_camera + + +def _extract_camera_resolutions(camera_params: List, camera_names: List[str], frame_to_camera: List[int]): + """Extract per-camera resolution from the first valid frame of each camera. + + Returns: + {camera_name: (width, height)} or empty dict on failure. + """ + result = {} + num_cameras = len(camera_names) + # Build first-frame-per-camera map + first_frame: Dict[int, int] = {} + for frame_idx, cam_idx in enumerate(frame_to_camera): + if cam_idx not in first_frame and 0 <= cam_idx < num_cameras: + first_frame[cam_idx] = frame_idx + + for cam_idx, cam_name in enumerate(camera_names): + frame_idx = first_frame.get(cam_idx) + if frame_idx is None or camera_params is None: + continue + params = camera_params[frame_idx] if frame_idx < len(camera_params) else None + if params is None: + continue + if hasattr(params, "resolution"): + w, h = int(params.resolution[0]), int(params.resolution[1]) + result[cam_name] = (w, h) + + return result + + class USDExporter(ModelExporter): """ Exporter for OpenUSD format using ParticleField3DGaussianSplat schema. @@ -170,8 +221,9 @@ class USDExporter(ModelExporter): Features: - ParticleField3DGaussianSplat schema (standard OpenUSD) - - Optional camera export with full intrinsics + - One Camera prim per physical camera with time-sampled transforms - Background/environment export as DomeLight + - Optional PPISP SPG shader on per-camera RenderProducts - USDZ packaging (default output) For Omniverse/NuRec compatibility, use NuRecExporter instead. @@ -187,6 +239,8 @@ def __init__( apply_normalizing_transform: bool = True, sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, + export_ppisp: bool = False, + frames_per_second: float = 1.0, ): """ Initialize the USD exporter. @@ -195,11 +249,16 @@ def __init__( half_precision: If True, use half for both geometry and features (backward compat). half_geometry: Use half precision for positions, orientations, scales (LightField). half_features: Use half precision for opacities and SH coefficients (LightField). - export_cameras: Include camera poses in export - export_background: Include background/environment in export - apply_normalizing_transform: Apply transform to normalize scene orientation - sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth" per UsdVol schema) - linear_srgb: If True, set prim color space to lin_rec709_scene; else srgb_rec709_display + export_cameras: Include camera poses in export. + export_background: Include background/environment in export. + apply_normalizing_transform: Apply transform to normalize scene orientation. + sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth"). + linear_srgb: If True, set prim color space to lin_rec709_scene. + export_ppisp: If True, add PPISP SPG shaders on per-camera RenderProducts. + Requires post_processing kwarg to be a ppisp.PPISP instance. + frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always + bare frame indices (float(frame_idx)), so this controls playback speed. + Default 1.0 means 1 frame per second of real time. """ if half_precision: half_geometry = True @@ -211,25 +270,19 @@ def __init__( self.apply_normalizing_transform = apply_normalizing_transform self.sorting_mode_hint = sorting_mode_hint self.linear_srgb = linear_srgb + self.export_ppisp = export_ppisp + self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: """ Create a default.usda that references the data stages. - - Args: - referenced_stages: List of stages to reference (e.g., gaussians.usdc) - - Returns: - NamedUSDStage for default.usda """ stage = initialize_usd_stage(up_axis="Y") for ref_stage in referenced_stages: - # Create a reference prim for each stage filename_stem = Path(ref_stage.filename).stem prim_path = f"/World/{filename_stem}" prim = stage.OverridePrim(prim_path) - # Reference the file (bare filename for in-package resolution; same as NuRec) prim.GetReferences().AddReference(ref_stage.filename) return NamedUSDStage(filename="default.usda", stage=stage) @@ -248,26 +301,28 @@ def export( Export the model to a USDZ file. Args: - model: The model to export (must implement ExportableModel) - output_path: Path where the USDZ file will be saved - dataset: Optional dataset for camera poses - conf: Configuration parameters - background: Optional background model for environment export - **kwargs: Additional parameters. ``validate_usd`` (default True): run OpenUSD - stage validators on the written file (ParticleField / LightField only; no-op if - ``UsdValidation`` is unavailable). + model: The model to export (must implement ExportableModel). + output_path: Path where the USDZ file will be saved. + dataset: Optional dataset for camera poses. + conf: Configuration parameters. + background: Optional background model for environment export. + **kwargs: + post_processing: ppisp.PPISP instance for SPG export (used when + export_ppisp=True). + validate_usd (default True): run OpenUSD stage validators. + apply_coordinate_transform (bool): apply 3DGRUT→USDZ coordinate flip. + copy_source_usd: (stage_path, res_root) for prim merge. + copy_source_skip_subtrees: subtrees to skip during prim merge. """ output_path = Path(output_path) logger.info(f"Exporting USD file to {output_path}...") # Get model data via accessor - # LightField expects post-activation values (opacity in [0,1], actual scales) accessor = GaussianExportAccessor(model, conf) attrs = accessor.get_attributes(preactivation=False) caps = accessor.get_capabilities() logger.info(f"Schema: LightField (post-activation)") - logger.info(f"Exporting {attrs.num_gaussians} Gaussians, SH degree {caps.sh_degree}") # Compute normalizing transform if enabled @@ -280,13 +335,14 @@ def export( except (AttributeError, ValueError) as e: logger.warning(f"Failed to compute normalizing transform: {e}") - # Create main USD stage + # Create main USD stage with the configured time code rate stage = initialize_usd_stage(up_axis="Y") + stage.SetTimeCodesPerSecond(self.frames_per_second) apply_coordinate_transform = kwargs.get("apply_coordinate_transform", False) coordinate_transform = get_3dgrut_to_usdz_coordinate_transform() if apply_coordinate_transform else None - # Create Gaussian content root with optional normalizing and coordinate transform + # Create Gaussian content root gaussians_root = create_gaussian_model_root( stage, flip_x_axis=False, @@ -297,7 +353,7 @@ def export( coordinate_transform=coordinate_transform, ) - # Create Gaussian writer (LightField schema) + # Write Gaussians writer = create_gaussian_writer( stage=stage, capabilities=caps, @@ -307,8 +363,6 @@ def export( sorting_mode_hint=self.sorting_mode_hint, linear_srgb=self.linear_srgb, ) - - # Write Gaussians writer.create_prim(attrs.num_gaussians) writer.write_attributes(attrs) writer.finalize(attrs.positions) @@ -350,32 +404,38 @@ def export( except Exception as e: logger.warning("Failed to merge source USD prims: %s", e) - # Export cameras if requested and dataset available + # Extract camera grouping from dataset (used by both camera export and PPISP) + camera_names = None + frame_to_camera = None + camera_prim_paths: Dict[str, str] = {} + + if dataset is not None: + camera_names, frame_to_camera = _extract_camera_grouping(dataset) + + # Export cameras — one prim per physical camera with time-sampled transforms if self.export_cameras and dataset is not None: try: poses = dataset.get_poses() - # When we apply normalizing transform to the Gaussian root, cameras must be in the - # same coordinate system: apply normalizing transform to each c2w (world → normalized). if self.apply_normalizing_transform: poses = np.einsum("ij,njk->nik", normalizing_transform, poses) - # Extract per-frame camera parameters from dataset camera_params = _extract_camera_params_from_dataset(dataset) - if camera_params is not None: logger.info(f"Extracted camera params for {len(camera_params)} frames") else: logger.warning("Could not extract camera intrinsics from dataset, using default") - export_cameras_to_usd( + camera_prim_paths = export_cameras_to_usd( stage=stage, poses=poses, + camera_names=camera_names, + frame_to_camera=frame_to_camera, camera_params=camera_params, root_path="/World/Cameras", visible=False, ) - logger.info(f"Exported {len(poses)} cameras") + logger.info(f"Exported {len(camera_prim_paths)} camera(s) from {len(poses)} frames") except (AttributeError, KeyError, ValueError) as e: logger.warning(f"Failed to export cameras: {e}") @@ -396,7 +456,20 @@ def export( except (AttributeError, ValueError, ImportError) as e: logger.warning(f"Failed to export background: {e}") - # Package: gaussians_stage / default_stage_wrapped were built before source merge. + # Export PPISP as SPG shaders on RenderProducts + if self.export_ppisp: + self._export_ppisp( + stage=stage, + dataset=dataset, + camera_names=camera_names, + frame_to_camera=frame_to_camera, + camera_prim_paths=camera_prim_paths, + camera_params=_extract_camera_params_from_dataset(dataset) if dataset is not None else None, + post_processing=kwargs.get("post_processing"), + files=files, + ) + + # Package if suffix == ".usdz": if default_stage_wrapped is None: default_stage_wrapped = self._create_default_stage([gaussians_stage]) @@ -423,16 +496,81 @@ def export( logger.info(f"USD export complete: {output_path}") + def _export_ppisp( + self, + stage, + dataset, + camera_names, + frame_to_camera, + camera_prim_paths: Dict[str, str], + camera_params, + post_processing, + files: List[NamedSerialized], + ) -> None: + """Create /Render RenderProducts and attach PPISP SPG shaders.""" + try: + from ppisp import PPISP # type: ignore[import-not-found] + except ImportError: + logger.warning("ppisp package not available, skipping PPISP export") + return + + if not isinstance(post_processing, PPISP): + logger.warning( + f"export_ppisp=True but post_processing is {type(post_processing).__name__}, " + "expected ppisp.PPISP — skipping" + ) + return + + if dataset is None or not camera_prim_paths: + logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") + return + + from threedgrut.export.usd.writers.render_product import create_render_products + from threedgrut.export.usd.writers.ppisp_writer import ( + add_ppisp_to_all_render_products, + build_camera_frame_mapping, + ) + from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files + + # Build camera_entries: camera_name → (usd_camera_path, width, height) + resolutions = _extract_camera_resolutions(camera_params, camera_names, frame_to_camera) + camera_entries = {} + for cam_name, cam_path in camera_prim_paths.items(): + w, h = resolutions.get(cam_name, (0, 0)) + camera_entries[cam_name] = (cam_path, w, h) + + try: + create_render_products(stage=stage, camera_entries=camera_entries) + except Exception as e: + logger.warning(f"Failed to create RenderProducts: {e}") + return + + # Build frame mapping from dataset + _, camera_frame_mapping = build_camera_frame_mapping(dataset) + + try: + add_ppisp_to_all_render_products( + stage=stage, + ppisp=post_processing, + camera_names=camera_names, + camera_frame_mapping=camera_frame_mapping, + ) + except Exception as e: + logger.warning(f"Failed to add PPISP shaders: {e}") + return + + # Add SPG sidecars to the USDZ package + spg_files = get_ppisp_spg_files() + for spg_file in spg_files: + if not any(f.filename == spg_file.filename for f in files): + files.append(spg_file) + + logger.info(f"PPISP SPG export complete: {len(spg_files)} sidecar(s) added") + @classmethod def from_config(cls, conf) -> "USDExporter": """ Create USDExporter from configuration. - - Args: - conf: Configuration object with export_usd section - - Returns: - Configured USDExporter instance """ export_conf = getattr(conf, "export_usd", None) or conf half_precision = getattr(export_conf, "half_precision", False) @@ -449,4 +587,6 @@ def from_config(cls, conf) -> "USDExporter": apply_normalizing_transform=getattr(export_conf, "apply_normalizing_transform", True), sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), + export_ppisp=getattr(export_conf, "export_ppisp", False), + frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/ppisp_spg/__init__.py b/threedgrut/export/usd/ppisp_spg/__init__.py new file mode 100644 index 00000000..7aa37f5b --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/__init__.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PPISP SPG shader assets for USD RenderProduct post-processing. + +Provides loader for the three SPG sidecar files (Slang shader, Lua launcher, +USDA definition) that must be packaged alongside the exported USDZ. +""" + +import logging +from pathlib import Path +from typing import List + +from threedgrut.export.usd.stage_utils import NamedSerialized + +log = logging.getLogger(__name__) + +_SPG_DIR = Path(__file__).parent +_SPG_FILES = [ + "ppisp_usd_spg.slang", + "ppisp_usd_spg.slang.lua", + "ppisp_usd_spg.slang.usda", +] + + +def get_ppisp_spg_files() -> List[NamedSerialized]: + """Load all PPISP SPG sidecar files as serialized data for USDZ packaging. + + Returns: + List of NamedSerialized for each SPG file (slang, lua, usda). + """ + result: List[NamedSerialized] = [] + for filename in _SPG_FILES: + path = _SPG_DIR / filename + if path.exists(): + result.append(NamedSerialized(filename=filename, serialized=path.read_bytes())) + log.debug(f"Loaded PPISP SPG sidecar: {filename}") + else: + log.warning(f"PPISP SPG sidecar not found: {path}") + return result diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang new file mode 100644 index 00000000..02fc2fc3 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang @@ -0,0 +1,245 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// PPISP (Physically Plausible Image Signal Processing) SPG Shader +// +// Implements the ISP pipeline for USD RenderProducts: +// 1. Exposure compensation +// 2. Vignetting correction (per-channel) +// 3. Color correction via ZCA-based homography +// 4. Camera Response Function (per-channel, 4-param toe-shoulder curve) +// +// NOTE: All parameters use flat naming to match USD inputs: attributes (UsdShade-compatible). +// SPG requires the Slang struct field names to match USD input names. + +struct PPISPParams +{ + // Exposure + float exposureOffset; + + // Vignetting R channel + float2 vignettingCenterR; + float vignettingAlpha1R; + float vignettingAlpha2R; + float vignettingAlpha3R; + + // Vignetting G channel + float2 vignettingCenterG; + float vignettingAlpha1G; + float vignettingAlpha2G; + float vignettingAlpha3G; + + // Vignetting B channel + float2 vignettingCenterB; + float vignettingAlpha1B; + float vignettingAlpha2B; + float vignettingAlpha3B; + + // Color correction: 4 control-point latent offsets (Blue, Red, Green, Neutral) + float2 colorLatentBlue; + float2 colorLatentRed; + float2 colorLatentGreen; + float2 colorLatentNeutral; + + // CRF R channel (raw params: activations applied at runtime) + float crfToeR; + float crfShoulderR; + float crfGammaR; + float crfCenterR; + + // CRF G channel + float crfToeG; + float crfShoulderG; + float crfGammaG; + float crfCenterG; + + // CRF B channel + float crfToeB; + float crfShoulderB; + float crfGammaB; + float crfCenterB; +}; + +[[vk::binding(0, 1)]] ParameterBlock g_Params; +[[vk::binding(1, 1)]] Texture2D g_InTex; +[[vk::binding(2, 1)]] RWTexture2D g_OutTex; + +// ZCA pinv 2x2 blocks (constant, matching ppisp_math.cuh COLOR_PINV_BLOCKS) +static const float2x2 ZCA_BLUE = float2x2( 0.0480542, -0.0043631, -0.0043631, 0.0481283); +static const float2x2 ZCA_RED = float2x2( 0.0580570, -0.0179872, -0.0179872, 0.0431061); +static const float2x2 ZCA_GREEN = float2x2( 0.0433336, -0.0180537, -0.0180537, 0.0580500); +static const float2x2 ZCA_NEUTRAL = float2x2( 0.0128369, -0.0034654, -0.0034654, 0.0128158); + +// Compute 3x3 homography from ZCA latent offsets (port of compute_homography from ppisp_math.cuh) +float3x3 computeHomography(float2 bLat, float2 rLat, float2 gLat, float2 nLat) +{ + float2 bd = mul(ZCA_BLUE, bLat); + float2 rd = mul(ZCA_RED, rLat); + float2 gd = mul(ZCA_GREEN, gLat); + float2 nd = mul(ZCA_NEUTRAL, nLat); + + // Target chromaticities: source + offset. Source = (r,g,I) for pure B,R,G,gray + float3 tB = float3(0.0 + bd.x, 0.0 + bd.y, 1.0); + float3 tR = float3(1.0 + rd.x, 0.0 + rd.y, 1.0); + float3 tG = float3(0.0 + gd.x, 1.0 + gd.y, 1.0); + float3 tGray = float3(1.0 / 3.0 + nd.x, 1.0 / 3.0 + nd.y, 1.0); + + // T = [tB | tR | tG] as columns (row-major: row i = [tB[i], tR[i], tG[i]]) + float3x3 T = float3x3(tB.x, tR.x, tG.x, + tB.y, tR.y, tG.y, + tB.z, tR.z, tG.z); + + // Skew-symmetric matrix [tGray]_x + float3x3 skew = float3x3(0.0, -tGray.z, tGray.y, + tGray.z, 0.0, -tGray.x, + -tGray.y, tGray.x, 0.0); + + float3x3 M = mul(skew, T); + + // Null-space vector via cross product of first two rows + float3 r0 = M[0]; + float3 r1 = M[1]; + float3 r2 = M[2]; + + float3 lam = cross(r0, r1); + if (dot(lam, lam) < 1.0e-20) + { + lam = cross(r0, r2); + if (dot(lam, lam) < 1.0e-20) + lam = cross(r1, r2); + } + + // S_inv = [[-1,-1,1],[1,0,0],[0,1,0]] + float3x3 Sinv = float3x3(-1.0, -1.0, 1.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0); + + // D = diag(lam) + float3x3 D = float3x3(lam.x, 0.0, 0.0, + 0.0, lam.y, 0.0, + 0.0, 0.0, lam.z); + + // H = T * D * S_inv + float3x3 H = mul(mul(T, D), Sinv); + + // Normalize so H[2][2] = 1 + float s = H[2][2]; + if (abs(s) > 1.0e-20) + H = H * (1.0 / s); + + return H; +} + +float applyVignetting(float value, float2 uv, float2 opticalCenter, float alpha1, float alpha2, float alpha3) +{ + float2 delta = uv - opticalCenter; + float r2 = dot(delta, delta); + + float falloff = 1.0; + float r2Pow = r2; + falloff += alpha1 * r2Pow; + r2Pow *= r2; + falloff += alpha2 * r2Pow; + r2Pow *= r2; + falloff += alpha3 * r2Pow; + + falloff = clamp(falloff, 0.0, 1.0); + return value * falloff; +} + +float boundedSoftplus(float raw, float minValue) +{ + return minValue + log(1.0 + exp(raw)); +} + +float sigmoid(float raw) +{ + return 1.0 / (1.0 + exp(-raw)); +} + +// 4-param toe-shoulder CRF (port of apply_crf_ppisp from ppisp_math.cuh) +float applyCRF(float x, float toeRaw, float shoulderRaw, float gammaRaw, float centerRaw) +{ + x = clamp(x, 0.0, 1.0); + + float toe = boundedSoftplus(toeRaw, 0.3); + float shoulder = boundedSoftplus(shoulderRaw, 0.3); + float gamma = boundedSoftplus(gammaRaw, 0.1); + float center = sigmoid(centerRaw); + + // toe >= 0.3, shoulder >= 0.3, center in (0,1) — divisions are safe + float lerpVal = (shoulder - toe) * center + toe; + float a = (shoulder * center) / lerpVal; + float b = 1.0 - a; + + float y; + if (x <= center) + y = a * pow(x / center, toe); + else + y = 1.0 - b * pow((1.0 - x) / (1.0 - center), shoulder); + + return pow(max(0.0, y), gamma); +} + +float3 applyColorCorrection(float3 rgb, float3x3 H) +{ + float intensity = rgb.x + rgb.y + rgb.z; + float3 rgi = float3(rgb.x, rgb.y, intensity); + + rgi = mul(H, rgi); + + rgi = rgi * (intensity / (rgi.z + 1.0e-5)); + return float3(rgi.x, rgi.y, rgi.z - rgi.x - rgi.y); +} + +[shader("compute")] +[numthreads(16, 16, 1)] +void ppispProcess(uint3 tid : SV_DispatchThreadID) +{ + uint w = 0, h = 0; + g_InTex.GetDimensions(w, h); + if (tid.x >= w || tid.y >= h) + return; + + float4 pixel = g_InTex.Load(int3(tid.xy, 0)); + float3 rgb = pixel.rgb; + + // Normalize to [-0.5, 0.5] range based on max dimension (matching CUDA kernel) + float maxRes = max(float(w), float(h)); + float2 uv = float2(tid.x + 0.5 - float(w) * 0.5, tid.y + 0.5 - float(h) * 0.5) / maxRes; + + // 1. Exposure + rgb = rgb * exp2(g_Params.exposureOffset); + + // 2. Vignetting (per-channel) + rgb.r = applyVignetting(rgb.r, uv, g_Params.vignettingCenterR, + g_Params.vignettingAlpha1R, g_Params.vignettingAlpha2R, g_Params.vignettingAlpha3R); + rgb.g = applyVignetting(rgb.g, uv, g_Params.vignettingCenterG, + g_Params.vignettingAlpha1G, g_Params.vignettingAlpha2G, g_Params.vignettingAlpha3G); + rgb.b = applyVignetting(rgb.b, uv, g_Params.vignettingCenterB, + g_Params.vignettingAlpha1B, g_Params.vignettingAlpha2B, g_Params.vignettingAlpha3B); + + // 3. Color correction (ZCA-based homography) + float3x3 H = computeHomography(g_Params.colorLatentBlue, g_Params.colorLatentRed, + g_Params.colorLatentGreen, g_Params.colorLatentNeutral); + rgb = applyColorCorrection(rgb, H); + + // 4. CRF (per-channel, 4-param toe-shoulder) + rgb.r = applyCRF(rgb.r, g_Params.crfToeR, g_Params.crfShoulderR, g_Params.crfGammaR, g_Params.crfCenterR); + rgb.g = applyCRF(rgb.g, g_Params.crfToeG, g_Params.crfShoulderG, g_Params.crfGammaG, g_Params.crfCenterG); + rgb.b = applyCRF(rgb.b, g_Params.crfToeB, g_Params.crfShoulderB, g_Params.crfGammaB, g_Params.crfCenterB); + + g_OutTex[tid.xy] = float4(rgb, pixel.a); +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua new file mode 100644 index 00000000..5f3c0f48 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua @@ -0,0 +1,90 @@ +-- SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +-- SPDX-License-Identifier: Apache-2.0 +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- PPISP (Physically Plausible Image Signal Processing) SPG Launcher +-- +-- Binds PPISP parameters and dispatches the compute shader for +-- USD RenderProduct post-processing. +-- +-- NOTE: Uses flat parameter names matching USD inputs: attributes (UsdShade-compatible). + +function ppispProcess(inputs, outputs, params) + local in_rgba = inputs["HdrColor"] + assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") + + -- Output texture mirrors input shape and dtype + local height = in_rgba.shape[1] + local width = in_rgba.shape[2] + outputs["PPISPColor"] = slang.empty({height, width}, in_rgba.dtype) + + -- Pass params directly to preserve __fullName for shader reflection matching. + local function getFloat2(name) + local p = params[name] + return p and slang.float2(p) or slang.float2(0.0, 0.0) + end + + return slang.dispatch({ + bind = { + slang.ParameterBlock( + -- Exposure + slang.float(params["exposureOffset"] or 0.0), + + -- Vignetting R + getFloat2("vignettingCenterR"), + slang.float(params["vignettingAlpha1R"] or 0.0), + slang.float(params["vignettingAlpha2R"] or 0.0), + slang.float(params["vignettingAlpha3R"] or 0.0), + + -- Vignetting G + getFloat2("vignettingCenterG"), + slang.float(params["vignettingAlpha1G"] or 0.0), + slang.float(params["vignettingAlpha2G"] or 0.0), + slang.float(params["vignettingAlpha3G"] or 0.0), + + -- Vignetting B + getFloat2("vignettingCenterB"), + slang.float(params["vignettingAlpha1B"] or 0.0), + slang.float(params["vignettingAlpha2B"] or 0.0), + slang.float(params["vignettingAlpha3B"] or 0.0), + + -- Color latent offsets (4 control points) + getFloat2("colorLatentBlue"), + getFloat2("colorLatentRed"), + getFloat2("colorLatentGreen"), + getFloat2("colorLatentNeutral"), + + -- CRF R (defaults = identity: boundedSoftplus(0.013659,0.3)=1, sigmoid(0)=0.5) + slang.float(params["crfToeR"] or 0.013659), + slang.float(params["crfShoulderR"] or 0.013659), + slang.float(params["crfGammaR"] or 0.378165), + slang.float(params["crfCenterR"] or 0.0), + + -- CRF G + slang.float(params["crfToeG"] or 0.013659), + slang.float(params["crfShoulderG"] or 0.013659), + slang.float(params["crfGammaG"] or 0.378165), + slang.float(params["crfCenterG"] or 0.0), + + -- CRF B + slang.float(params["crfToeB"] or 0.013659), + slang.float(params["crfShoulderB"] or 0.013659), + slang.float(params["crfGammaB"] or 0.378165), + slang.float(params["crfCenterB"] or 0.0) + ), + slang.Texture2D(in_rgba), + slang.RWTexture2D(outputs["PPISPColor"]), + }, + }) +end diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda new file mode 100644 index 00000000..b423a28e --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda @@ -0,0 +1,58 @@ +#usda 1.0 +( + defaultPrim = "SlangPPISP" +) + +def Shader "SlangPPISP" +{ + uniform token info:implementationSource = "sourceAsset" + uniform asset info:spg:sourceAsset = @ppisp_usd_spg.slang@ + uniform token info:spg:sourceAsset:subIdentifier = "ppispProcess" + + # Exposure parameter + float inputs:exposureOffset = 0.0 + + # Vignetting parameters (per channel: R, G, B) + float2 inputs:vignettingCenterR = (0.0, 0.0) + float inputs:vignettingAlpha1R = 0.0 + float inputs:vignettingAlpha2R = 0.0 + float inputs:vignettingAlpha3R = 0.0 + + float2 inputs:vignettingCenterG = (0.0, 0.0) + float inputs:vignettingAlpha1G = 0.0 + float inputs:vignettingAlpha2G = 0.0 + float inputs:vignettingAlpha3G = 0.0 + + float2 inputs:vignettingCenterB = (0.0, 0.0) + float inputs:vignettingAlpha1B = 0.0 + float inputs:vignettingAlpha2B = 0.0 + float inputs:vignettingAlpha3B = 0.0 + + # Color correction latent offsets (ZCA-based, 4 control points x 2D) + float2 inputs:colorLatentBlue = (0.0, 0.0) + float2 inputs:colorLatentRed = (0.0, 0.0) + float2 inputs:colorLatentGreen = (0.0, 0.0) + float2 inputs:colorLatentNeutral = (0.0, 0.0) + + # CRF raw parameters (per channel: R, G, B) + # Activations: boundedSoftplus(raw, min) for toe/shoulder/gamma, sigmoid(raw) for center + # Defaults produce identity CRF: toe=1, shoulder=1, gamma=1, center=0.5 + float inputs:crfToeR = 0.013659 + float inputs:crfShoulderR = 0.013659 + float inputs:crfGammaR = 0.378165 + float inputs:crfCenterR = 0.0 + + float inputs:crfToeG = 0.013659 + float inputs:crfShoulderG = 0.013659 + float inputs:crfGammaG = 0.378165 + float inputs:crfCenterG = 0.0 + + float inputs:crfToeB = 0.013659 + float inputs:crfShoulderB = 0.013659 + float inputs:crfGammaB = 0.378165 + float inputs:crfCenterB = 0.0 + + # Image inputs/outputs + opaque inputs:HdrColor + opaque outputs:PPISPColor +} diff --git a/threedgrut/export/usd/writers/__init__.py b/threedgrut/export/usd/writers/__init__.py index a23669d1..92531563 100644 --- a/threedgrut/export/usd/writers/__init__.py +++ b/threedgrut/export/usd/writers/__init__.py @@ -18,12 +18,17 @@ Provides schema-agnostic interface for writing Gaussian data to USD: - GaussianLightFieldWriter: ParticleField3DGaussianSplat schema +- export_cameras_to_usd: one Camera prim per physical camera, animated xforms +- create_render_products: /Render scope with per-camera RenderProducts +- add_ppisp_to_all_render_products: PPISP SPG shader on RenderProducts """ from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import GaussianUSDWriter, create_gaussian_writer from threedgrut.export.usd.writers.camera import export_cameras_to_usd from threedgrut.export.usd.writers.lightfield import GaussianLightFieldWriter +from threedgrut.export.usd.writers.ppisp_writer import add_ppisp_to_all_render_products +from threedgrut.export.usd.writers.render_product import create_render_products __all__ = [ "GaussianUSDWriter", @@ -31,4 +36,6 @@ "create_gaussian_writer", "export_cameras_to_usd", "export_background_to_usd", + "create_render_products", + "add_ppisp_to_all_render_products", ] diff --git a/threedgrut/export/usd/writers/camera.py b/threedgrut/export/usd/writers/camera.py index f6dc6c6b..172c8b84 100644 --- a/threedgrut/export/usd/writers/camera.py +++ b/threedgrut/export/usd/writers/camera.py @@ -16,53 +16,57 @@ """ Camera USD writer for exporting camera poses and intrinsics. -Exports camera poses with full intrinsics support for OpenCVPinhole and OpenCVFisheye -camera models, following the pattern established in NRE's rig_trajectories.py. +Exports one Camera prim per physical camera with time-sampled transforms +and static intrinsics, following the pattern established in NRE's +rig_trajectories.py. """ import logging -from typing import List, Optional +from typing import Dict, List, Optional import numpy as np from ncore.data import ( OpenCVFisheyeCameraModelParameters, OpenCVPinholeCameraModelParameters, ) -from pxr import Gf, Sdf, Usd, UsdGeom, Vt +from pxr import Gf, Sdf, Tf, Usd, UsdGeom from threedgrut.export.transforms import column_vector_4x4_to_usd_matrix logger = logging.getLogger(__name__) -# Default clipping range for cameras DEFAULT_NEAR_CLIP = 0.001 DEFAULT_FAR_CLIP = 10000000.0 +# Coordinate transform from 3DGRUT (right-down-front) to USD camera (right-up-back) +_CAMERA_COORD_FLIP = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64 +) + + +def _make_usd_prim_name(name: str) -> str: + """Convert an arbitrary string to a valid USD prim identifier.""" + return Tf.MakeValidIdentifier(name) + def _add_opencv_pinhole_camera_intrinsics( camera_prim: Usd.Prim, params: OpenCVPinholeCameraModelParameters, ) -> None: - """Add OpenCV pinhole camera intrinsics to USD camera prim.""" - # Camera projection type - camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set(Vt.Token("pinholeOpenCV")) + camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set("pinholeOpenCV") - # Resolution resolution_list = params.resolution.tolist() camera_prim.CreateAttribute("fthetaWidth", Sdf.ValueTypeNames.Float).Set(float(resolution_list[0])) camera_prim.CreateAttribute("fthetaHeight", Sdf.ValueTypeNames.Float).Set(float(resolution_list[1])) - # Principal point principal_point_list = params.principal_point.tolist() camera_prim.CreateAttribute("fthetaCx", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[0])) camera_prim.CreateAttribute("fthetaCy", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[1])) - # Focal length focal_length_list = params.focal_length.tolist() camera_prim.CreateAttribute("openCVFx", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[0])) camera_prim.CreateAttribute("openCVFy", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[1])) - # Radial distortion coefficients [k1,k2,k3,k4,k5,k6] radial_coeffs_list = params.radial_coeffs.tolist() camera_prim.CreateAttribute("fthetaPolyA", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[0])) camera_prim.CreateAttribute("fthetaPolyB", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[1])) @@ -71,12 +75,10 @@ def _add_opencv_pinhole_camera_intrinsics( camera_prim.CreateAttribute("fthetaPolyE", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[4])) camera_prim.CreateAttribute("fthetaPolyF", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[5])) - # Tangential distortion coefficients [p1,p2] tangential_coeffs_list = params.tangential_coeffs.tolist() camera_prim.CreateAttribute("p0", Sdf.ValueTypeNames.Float).Set(float(tangential_coeffs_list[0])) camera_prim.CreateAttribute("p1", Sdf.ValueTypeNames.Float).Set(float(tangential_coeffs_list[1])) - # Thin prism distortion coefficients [s1,s2,s3,s4] thin_prism_coeffs_list = params.thin_prism_coeffs.tolist() camera_prim.CreateAttribute("s0", Sdf.ValueTypeNames.Float).Set(float(thin_prism_coeffs_list[0])) camera_prim.CreateAttribute("s1", Sdf.ValueTypeNames.Float).Set(float(thin_prism_coeffs_list[1])) @@ -88,152 +90,116 @@ def _add_opencv_fisheye_camera_intrinsics( camera_prim: Usd.Prim, params: OpenCVFisheyeCameraModelParameters, ) -> None: - """Add OpenCV fisheye camera intrinsics to USD camera prim.""" - # Camera projection type - camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set(Vt.Token("fisheyeOpenCV")) + camera_prim.CreateAttribute("cameraProjectionType", Sdf.ValueTypeNames.Token).Set("fisheyeOpenCV") - # Resolution resolution_list = params.resolution.tolist() camera_prim.CreateAttribute("fthetaWidth", Sdf.ValueTypeNames.Float).Set(float(resolution_list[0])) camera_prim.CreateAttribute("fthetaHeight", Sdf.ValueTypeNames.Float).Set(float(resolution_list[1])) - # Principal point principal_point_list = params.principal_point.tolist() camera_prim.CreateAttribute("fthetaCx", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[0])) camera_prim.CreateAttribute("fthetaCy", Sdf.ValueTypeNames.Float).Set(float(principal_point_list[1])) - # Focal length focal_length_list = params.focal_length.tolist() camera_prim.CreateAttribute("openCVFx", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[0])) camera_prim.CreateAttribute("openCVFy", Sdf.ValueTypeNames.Float).Set(float(focal_length_list[1])) - # Radial distortion coefficients [k1,k2,k3,k4] radial_coeffs_list = params.radial_coeffs.tolist() camera_prim.CreateAttribute("fthetaPolyA", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[0])) camera_prim.CreateAttribute("fthetaPolyB", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[1])) camera_prim.CreateAttribute("fthetaPolyC", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[2])) camera_prim.CreateAttribute("fthetaPolyD", Sdf.ValueTypeNames.Float).Set(float(radial_coeffs_list[3])) - # Max FoV (convert from radians to degrees, x2 for full FoV) - camera_prim.CreateAttribute("fthetaMaxFov", Sdf.ValueTypeNames.Float).Set(float(2.0 * np.rad2deg(params.max_angle))) - - -def _add_simple_pinhole_intrinsics( - camera_prim: Usd.Prim, - intrinsics: List[float], - resolution: List[int], -) -> None: - """Add simple pinhole intrinsics [fx, fy, cx, cy] without distortion.""" - fx, fy, cx, cy = intrinsics - - # Use standard USD pinhole camera attributes - # Compute horizontal aperture from resolution and focal length - # USD uses mm for aperture, assuming sensor is 36mm (full-frame) - sensor_width_mm = 36.0 - focal_length_mm = (fx / resolution[0]) * sensor_width_mm - - camera_prim.GetFocalLengthAttr().Set(focal_length_mm) - camera_prim.GetHorizontalApertureAttr().Set(sensor_width_mm) - camera_prim.GetVerticalApertureAttr().Set(sensor_width_mm * resolution[1] / resolution[0]) - - # Principal point offset from center - horizontal_offset = ((cx / resolution[0]) - 0.5) * sensor_width_mm - vertical_offset = ((cy / resolution[1]) - 0.5) * (sensor_width_mm * resolution[1] / resolution[0]) - camera_prim.GetHorizontalApertureOffsetAttr().Set(horizontal_offset) - camera_prim.GetVerticalApertureOffsetAttr().Set(vertical_offset) + camera_prim.CreateAttribute("fthetaMaxFov", Sdf.ValueTypeNames.Float).Set( + float(2.0 * np.rad2deg(params.max_angle)) + ) def export_cameras_to_usd( stage: Usd.Stage, poses: np.ndarray, - intrinsics: Optional[List] = None, + camera_names: List[str], + frame_to_camera: List[int], camera_params: Optional[List] = None, - resolutions: Optional[List[np.ndarray]] = None, root_path: str = "/World/Cameras", - camera_prefix: str = "camera", visible: bool = False, -) -> str: +) -> Dict[str, str]: """ - Export camera poses with intrinsics to USD stage. + Export camera poses with intrinsics to a USD stage. - Supports multiple camera model types: - - OpenCVPinholeCameraModelParameters: Full pinhole with distortion - - OpenCVFisheyeCameraModelParameters: Fisheye with distortion - - Simple intrinsics: [fx, fy, cx, cy] list for basic pinhole + Creates one Camera prim per physical camera with time-sampled transforms + and static intrinsics. The time code for frame i is float(i), so + stage.GetTimeCodesPerSecond() controls real-time playback speed. Args: - stage: USD stage to export to - poses: Camera poses [N, 4, 4] in 3DGRUT convention (right-down-front) - intrinsics: Optional list of [fx, fy, cx, cy] for simple pinhole - camera_params: Optional list of camera model parameters (OpenCVPinhole/Fisheye) - resolutions: Optional list of resolutions [[w, h], ...] for simple intrinsics - root_path: USD path for camera root xform - camera_prefix: Prefix for camera names - visible: Whether cameras should be visible in viewport + stage: USD stage to export to. + poses: Camera-to-world transforms [N_frames, 4, 4] in 3DGRUT convention + (right-down-front). + camera_names: Logical name for each physical camera, indexed by camera_idx. + frame_to_camera: Per-frame camera index mapping, length N_frames. + camera_params: Per-frame CameraModelParameters (OpenCVPinhole / Fisheye). + Intrinsics are taken from the first frame of each camera. + root_path: USD path for the camera root Xform. + visible: Whether camera prims should be visible in the viewport. Returns: - Root path of the cameras + Mapping {camera_name: usd_prim_path} for every exported camera. """ - num_cameras = poses.shape[0] + num_cameras = len(camera_names) + + # Group frame indices by camera + camera_frames: Dict[int, List[int]] = {i: [] for i in range(num_cameras)} + for frame_idx, cam_idx in enumerate(frame_to_camera): + if 0 <= cam_idx < num_cameras: + camera_frames[cam_idx].append(frame_idx) - # Create root xform for cameras UsdGeom.Xform.Define(stage, root_path) - # Coordinate transform from 3DGRUT (right-down-front) to USD camera (right-up-back) - # 3DGRUT: X=right, Y=down, Z=front - # USD: X=right, Y=up, Z=back - camera_coord_flip = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64) + result: Dict[str, str] = {} + + for cam_idx, cam_name in enumerate(camera_names): + frame_indices = camera_frames[cam_idx] + if not frame_indices: + logger.warning(f"Camera '{cam_name}' (idx {cam_idx}) has no frames, skipping") + continue - for i in range(num_cameras): - camera_name = f"{camera_prefix}_{i:04d}" - camera_path = f"{root_path}/{camera_name}" + prim_name = _make_usd_prim_name(cam_name) + camera_path = f"{root_path}/{prim_name}" - # Define camera prim camera_prim = stage.DefinePrim(camera_path, "Camera") camera = UsdGeom.Camera(camera_prim) - - # Set clipping range camera.GetClippingRangeAttr().Set(Gf.Vec2f(DEFAULT_NEAR_CLIP, DEFAULT_FAR_CLIP)) - # Add intrinsics based on available data - if camera_params is not None and i < len(camera_params) and camera_params[i] is not None: - params = camera_params[i] + # Static intrinsics from first frame of this camera + first_frame = frame_indices[0] + if camera_params is not None and first_frame < len(camera_params) and camera_params[first_frame] is not None: + params = camera_params[first_frame] if isinstance(params, OpenCVPinholeCameraModelParameters): _add_opencv_pinhole_camera_intrinsics(camera_prim, params) elif isinstance(params, OpenCVFisheyeCameraModelParameters): _add_opencv_fisheye_camera_intrinsics(camera_prim, params) else: - # Fallback to default focal length camera.GetFocalLengthAttr().Set(24.0) - logger.warning(f"Unsupported camera model for camera {i}, using default intrinsics") - elif intrinsics is not None and resolutions is not None: - # Simple pinhole from intrinsics list - if i < len(resolutions): - resolution = resolutions[i].tolist() if isinstance(resolutions[i], np.ndarray) else resolutions[i] - else: - resolution = resolutions[0].tolist() if isinstance(resolutions[0], np.ndarray) else resolutions[0] - _add_simple_pinhole_intrinsics(camera_prim, intrinsics, resolution) + logger.warning(f"Unsupported camera model for '{cam_name}', using default focal length") else: - # Fallback to default focal length camera.GetFocalLengthAttr().Set(24.0) - # Set camera transform (pose) - # Apply coordinate system transform: 3DGRUT -> USD camera, then build USD matrix via Gf API - pose = poses[i] - usd_pose = pose @ camera_coord_flip - usd_matrix = column_vector_4x4_to_usd_matrix(usd_pose) - + # Time-sampled transforms — one sample per frame belonging to this camera xformable = UsdGeom.Xformable(camera_prim) transform_op = xformable.AddTransformOp() - transform_op.Set(usd_matrix) + for frame_idx in frame_indices: + usd_pose = poses[frame_idx] @ _CAMERA_COORD_FLIP + transform_op.Set(column_vector_4x4_to_usd_matrix(usd_pose), float(frame_idx)) - # Set visibility imageable = UsdGeom.Imageable(camera_prim) - visibility = "inherited" if visible else "invisible" - imageable.CreateVisibilityAttr().Set(visibility) + imageable.CreateVisibilityAttr().Set("inherited" if visible else "invisible") - logger.info(f"Exported {num_cameras} cameras to {root_path}") - return root_path + result[cam_name] = camera_path + + logger.info( + f"Exported {len(result)} camera(s) ({len(poses)} total frames) to {root_path}" + ) + return result def export_camera_rig_with_timestamps( @@ -267,42 +233,30 @@ def export_camera_rig_with_timestamps( """ num_frames = poses.shape[0] - # Create rig xform rig_prim = stage.DefinePrim(root_path, "Xform") rig_xform = UsdGeom.Xformable(rig_prim) - # Coordinate transform - camera_coord_flip = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=np.float64) - - # USD time code setup usd_time_code_per_second = stage.GetTimeCodesPerSecond() - usd_timestamp_scale = usd_time_code_per_second * 1e-06 # microseconds to time codes + usd_timestamp_scale = usd_time_code_per_second * 1e-06 - # Create transform op for rig rig_transform_op = rig_xform.AddTransformOp() usd_start_time_code = float("inf") usd_end_time_code = 0.0 - # Add time-sampled transforms for i in range(num_frames): - pose = poses[i] - usd_pose = pose @ camera_coord_flip + usd_pose = poses[i] @ _CAMERA_COORD_FLIP usd_matrix = column_vector_4x4_to_usd_matrix(usd_pose) if timestamps_us is not None: - timestamp = timestamps_us[i] - usd_time_code = usd_timestamp_scale * (timestamp - timestamp_offset_us) - usd_start_time_code = min(usd_start_time_code, usd_time_code) - usd_end_time_code = max(usd_end_time_code, usd_time_code) + usd_time_code = usd_timestamp_scale * (timestamps_us[i] - timestamp_offset_us) else: usd_time_code = float(i) - usd_start_time_code = min(usd_start_time_code, usd_time_code) - usd_end_time_code = max(usd_end_time_code, usd_time_code) + usd_start_time_code = min(usd_start_time_code, usd_time_code) + usd_end_time_code = max(usd_end_time_code, usd_time_code) rig_transform_op.Set(usd_matrix, usd_time_code) - # Set time metadata if usd_start_time_code <= usd_end_time_code: stage.SetMetadata("startTimeCode", usd_start_time_code) stage.SetMetadata("endTimeCode", usd_end_time_code) @@ -310,15 +264,11 @@ def export_camera_rig_with_timestamps( if timestamps_us is not None: stage.SetMetadataByDictKey("customLayerData", "absoluteTimeOffsetMicroSec", timestamp_offset_us) - # Create camera prim under rig (static relative to rig) camera_path = f"{root_path}/{camera_name}" camera_prim = stage.DefinePrim(camera_path, "Camera") camera = UsdGeom.Camera(camera_prim) - - # Set default clipping range camera.GetClippingRangeAttr().Set(Gf.Vec2f(DEFAULT_NEAR_CLIP, DEFAULT_FAR_CLIP)) - # Add intrinsics if provided if camera_params is not None and len(camera_params) > 0: params = camera_params[0] if isinstance(params, OpenCVPinholeCameraModelParameters): @@ -330,15 +280,12 @@ def export_camera_rig_with_timestamps( else: camera.GetFocalLengthAttr().Set(24.0) - # Camera is at identity transform relative to rig (transform is on rig itself) xformable = UsdGeom.Xformable(camera_prim) transform_op = xformable.AddTransformOp() transform_op.Set(Gf.Matrix4d(1.0)) - # Set visibility imageable = UsdGeom.Imageable(camera_prim) - visibility = "inherited" if visible else "invisible" - imageable.CreateVisibilityAttr().Set(visibility) + imageable.CreateVisibilityAttr().Set("inherited" if visible else "invisible") logger.info(f"Exported camera rig with {num_frames} frames to {root_path}") return root_path diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py new file mode 100644 index 00000000..8947c3ca --- /dev/null +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PPISP USD Writer. + +Export PPISP (Physically Plausible Image Signal Processing) as a UsdShade +Shader prim on each camera's RenderProduct. Adapted from +nre-fermat/nre/utils/io/export/ppisp_usd_writer.py, replacing the +rig/timestamp frame-mapping with 3DGRUT integer frame indices. + +PPISP pipeline stages: +1. Exposure compensation (per-frame, time-sampled) +2. Vignetting correction (per-camera, static) +3. Color correction via ZCA-based homography (per-frame, time-sampled) +4. Camera Response Function (per-camera, static) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Dict, List, Tuple + +import numpy as np + +from pxr import Gf, Sdf, Usd, UsdShade + +if TYPE_CHECKING: + from ppisp import PPISP # type: ignore[import-not-found] + +log = logging.getLogger(__name__) + +NUM_CHANNELS = 3 +COLOR_PARAMS_PER_FRAME = 8 +CHANNEL_SUFFIXES = ["R", "G", "B"] + +PPISP_SPG_USDA_FILE = "ppisp_usd_spg.slang.usda" +PPISP_INPUT_RENDER_VAR = "HdrColor" +PPISP_OUTPUT_RENDER_VAR = "PPISPColor" +LDR_COLOR_RENDER_VAR = "LdrColor" + + +# --------------------------------------------------------------------------- +# Dataset frame-mapping helpers +# --------------------------------------------------------------------------- + + +def build_camera_frame_mapping(dataset) -> Tuple[List[str], Dict[str, List[int]]]: + """Build per-camera frame lists from a 3DGRUT dataset. + + Returns: + (camera_names, {camera_name: [frame_idx, ...]}) where frame_idx values + are the global training indices used as USD time codes. + """ + num_frames = len(dataset) + + camera_names: List[str] + if hasattr(dataset, "get_camera_names"): + camera_names = dataset.get_camera_names() + else: + camera_names = ["camera_0"] + + camera_frames: Dict[str, List[int]] = {name: [] for name in camera_names} + + for frame_idx in range(num_frames): + if hasattr(dataset, "get_camera_idx"): + cam_idx = dataset.get_camera_idx(frame_idx) + else: + cam_idx = 0 + if 0 <= cam_idx < len(camera_names): + camera_frames[camera_names[cam_idx]].append(frame_idx) + + return camera_names, camera_frames + + +# --------------------------------------------------------------------------- +# Shader prim creation +# --------------------------------------------------------------------------- + + +def _add_ldr_color_render_var( + stage: Usd.Stage, + render_product_path: str, + ppisp_output_path: Sdf.Path, +) -> str: + """Create a LdrColor RenderVar wired to the PPISP output.""" + ldr_var_path = f"{render_product_path}/{LDR_COLOR_RENDER_VAR}" + ldr_var = stage.DefinePrim(ldr_var_path, "RenderVar") + ldr_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(LDR_COLOR_RENDER_VAR) + aov_attr = ldr_var.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque) + aov_attr.SetConnections([ppisp_output_path]) + return ldr_var_path + + +def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade.Shader: + """Create the PPISP Shader prim on a RenderProduct. + + Wires HdrColor → PPISP → LdrColor and appends LdrColor to orderedVars. + Returns the UsdShade.Shader for parameter setting. + """ + render_product = stage.GetPrimAtPath(render_product_path) + if not render_product.IsValid(): + raise ValueError(f"RenderProduct not found at path: {render_product_path}") + + # Mark HdrColor RenderVar input as an opaque AOV (no connection needed here) + input_var_path = f"{render_product_path}/{PPISP_INPUT_RENDER_VAR}" + input_var_prim = stage.GetPrimAtPath(input_var_path) + if input_var_prim.IsValid(): + input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque) + + # PPISP Shader prim referencing the SPG asset definition + ppisp_shader_path = f"{render_product_path}/PPISP" + shader = UsdShade.Shader.Define(stage, ppisp_shader_path) + shader.GetPrim().GetReferences().AddReference(PPISP_SPG_USDA_FILE) + + # HdrColor opaque input wired to the input RenderVar's AOV + hdr_input = shader.CreateInput(PPISP_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) + hdr_input.GetAttr().SetConnections([Sdf.Path(f"../{PPISP_INPUT_RENDER_VAR}.omni:rtx:aov")]) + + # PPISPColor opaque output + shader.CreateOutput(PPISP_OUTPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) + + # LdrColor RenderVar connected to the output + ppisp_output_path = shader.GetPath().AppendProperty(f"outputs:{PPISP_OUTPUT_RENDER_VAR}") + ldr_var_path = _add_ldr_color_render_var(stage, render_product_path, ppisp_output_path) + + # Append LdrColor to orderedVars + ordered_vars_rel = render_product.GetRelationship("orderedVars") + if ordered_vars_rel: + targets = list(ordered_vars_rel.GetTargets()) + targets.append(Sdf.Path(ldr_var_path)) + ordered_vars_rel.SetTargets(targets) + + return shader + + +# --------------------------------------------------------------------------- +# Static parameter setters (per-camera) +# --------------------------------------------------------------------------- + + +def _set_vignetting_params(shader: UsdShade.Shader, ppisp: PPISP, camera_index: int) -> None: + """Set per-camera vignetting parameters (static). + + ppisp.vignetting_params[camera_index] has shape [3, 5]: + [cx, cy, alpha1, alpha2, alpha3] per channel. + """ + vig = ppisp.vignetting_params[camera_index].cpu().numpy() # [3, 5] + for ch in range(NUM_CHANNELS): + s = CHANNEL_SUFFIXES[ch] + shader.CreateInput(f"vignettingCenter{s}", Sdf.ValueTypeNames.Float2).Set( + Gf.Vec2f(float(vig[ch, 0]), float(vig[ch, 1])) + ) + shader.CreateInput(f"vignettingAlpha1{s}", Sdf.ValueTypeNames.Float).Set(float(vig[ch, 2])) + shader.CreateInput(f"vignettingAlpha2{s}", Sdf.ValueTypeNames.Float).Set(float(vig[ch, 3])) + shader.CreateInput(f"vignettingAlpha3{s}", Sdf.ValueTypeNames.Float).Set(float(vig[ch, 4])) + + +def _set_crf_params(shader: UsdShade.Shader, ppisp: PPISP, camera_index: int) -> None: + """Set per-camera CRF raw parameters (static). + + ppisp.crf_params[camera_index] has shape [3, 4]: + [toe, shoulder, gamma, center] per channel (raw, activations applied in shader). + """ + crf = ppisp.crf_params[camera_index].cpu().numpy() # [3, 4] + for ch in range(NUM_CHANNELS): + s = CHANNEL_SUFFIXES[ch] + shader.CreateInput(f"crfToe{s}", Sdf.ValueTypeNames.Float).Set(float(crf[ch, 0])) + shader.CreateInput(f"crfShoulder{s}", Sdf.ValueTypeNames.Float).Set(float(crf[ch, 1])) + shader.CreateInput(f"crfGamma{s}", Sdf.ValueTypeNames.Float).Set(float(crf[ch, 2])) + shader.CreateInput(f"crfCenter{s}", Sdf.ValueTypeNames.Float).Set(float(crf[ch, 3])) + + +# --------------------------------------------------------------------------- +# Animated parameter setters (per-frame, time-sampled) +# --------------------------------------------------------------------------- + + +def _set_animated_exposure_params( + shader: UsdShade.Shader, + ppisp: PPISP, + frame_indices: List[int], +) -> None: + """Write time-sampled exposure offset; default = mean across this camera's frames. + + ppisp.exposure_params has shape [num_frames]. + Time code = float(frame_idx). + """ + exposure = ppisp.exposure_params.cpu().numpy() # [num_frames] + + valid = [i for i in frame_indices if i < len(exposure)] + mean_val = float(np.mean(exposure[valid])) if valid else 0.0 + + exposure_input = shader.CreateInput("exposureOffset", Sdf.ValueTypeNames.Float) + attr = exposure_input.GetAttr() + attr.Set(mean_val) + + for frame_idx in valid: + attr.Set(float(exposure[frame_idx]), float(frame_idx)) + + +def _set_animated_color_params( + shader: UsdShade.Shader, + ppisp: PPISP, + frame_indices: List[int], +) -> None: + """Write time-sampled color latent offsets; default = mean across this camera's frames. + + ppisp.color_params has shape [num_frames, 8]: + [db_r, db_g, dr_r, dr_g, dg_r, dg_g, dgray_r, dgray_g]. + Written as 4 float2 attributes. + Time code = float(frame_idx). + """ + color = ppisp.color_params.cpu().numpy() # [num_frames, 8] + + valid = [i for i in frame_indices if i < len(color)] + mean_color = np.mean(color[valid], axis=0) if valid else np.zeros(8) + + control_point_names = ["colorLatentBlue", "colorLatentRed", "colorLatentGreen", "colorLatentNeutral"] + attrs = [] + for i, name in enumerate(control_point_names): + inp = shader.CreateInput(name, Sdf.ValueTypeNames.Float2) + attr = inp.GetAttr() + attr.Set(Gf.Vec2f(float(mean_color[i * 2]), float(mean_color[i * 2 + 1]))) + attrs.append(attr) + + for frame_idx in valid: + frame_color = color[frame_idx] + for i, attr in enumerate(attrs): + attr.Set( + Gf.Vec2f(float(frame_color[i * 2]), float(frame_color[i * 2 + 1])), + float(frame_idx), + ) + + +# --------------------------------------------------------------------------- +# Per-camera entry point +# --------------------------------------------------------------------------- + + +def add_ppisp_shader_to_render_product( + stage: Usd.Stage, + render_product_path: str, + camera_index: int, + ppisp: PPISP, + frame_indices: List[int], +) -> Usd.Prim: + """Add a PPISP Shader to a RenderProduct for one physical camera. + + Per-camera parameters (vignetting, CRF) are written as static USD + attributes. Per-frame parameters (exposure, color latents) are written + with a mean-based default value and one time sample per training frame + at time_code = float(frame_idx). + + Args: + stage: USD stage containing the RenderProduct. + render_product_path: Path to the RenderProduct prim. + camera_index: Index of this camera in the PPISP model. + ppisp: Trained PPISP module. + frame_indices: Global frame indices belonging to this camera. + + Returns: + The created PPISP Shader prim. + """ + assert camera_index < ppisp.num_cameras, ( + f"camera_index {camera_index} >= ppisp.num_cameras {ppisp.num_cameras}" + ) + if not frame_indices: + log.warning(f"No frames for camera {camera_index} at {render_product_path}, skipping") + return stage.GetPseudoRoot() + + shader = _create_shader_prim(stage, render_product_path) + _set_vignetting_params(shader, ppisp, camera_index) + _set_crf_params(shader, ppisp, camera_index) + _set_animated_exposure_params(shader, ppisp, frame_indices) + _set_animated_color_params(shader, ppisp, frame_indices) + + log.info( + f"Added PPISP shader to {render_product_path} " + f"(camera {camera_index}, {len(frame_indices)} frame(s))" + ) + return shader.GetPrim() + + +# --------------------------------------------------------------------------- +# Batch export over all RenderProducts +# --------------------------------------------------------------------------- + + +def add_ppisp_to_all_render_products( + stage: Usd.Stage, + ppisp: PPISP, + camera_names: List[str], + camera_frame_mapping: Dict[str, List[int]], + render_scope_path: str = "/Render", +) -> List[Usd.Prim]: + """Add PPISP shaders to every RenderProduct in the Render scope. + + Args: + stage: USD stage with a populated /Render scope. + ppisp: Trained PPISP module. + camera_names: Ordered list of camera names (index = camera_idx in ppisp). + camera_frame_mapping: ``{camera_name: [frame_idx, ...]}`` from + :func:`build_camera_frame_mapping`. + render_scope_path: Path to the /Render Scope (default ``/Render``). + + Returns: + List of created PPISP Shader prims. + """ + from threedgrut.export.usd.writers.camera import _make_usd_prim_name + + render_scope = stage.GetPrimAtPath(render_scope_path) + if not render_scope.IsValid(): + log.warning(f"Render scope not found at {render_scope_path}, skipping PPISP export") + return [] + + camera_name_to_index = {name: idx for idx, name in enumerate(camera_names)} + created: List[Usd.Prim] = [] + + for child in render_scope.GetChildren(): + if child.GetTypeName() != "RenderProduct": + continue + + # RenderProduct prim name matches _make_usd_prim_name(camera_name) + prim_name = child.GetName() + # Reverse-lookup original camera_name by prim name + camera_name = next( + (n for n in camera_names if _make_usd_prim_name(n) == prim_name), + None, + ) + if camera_name is None: + log.warning(f"RenderProduct '{prim_name}' has no matching camera name, skipping") + continue + + camera_index = camera_name_to_index.get(camera_name) + if camera_index is None: + log.warning(f"Camera '{camera_name}' not in camera_names list, skipping") + continue + + frame_indices = camera_frame_mapping.get(camera_name, []) + + shader_prim = add_ppisp_shader_to_render_product( + stage=stage, + render_product_path=str(child.GetPath()), + camera_index=camera_index, + ppisp=ppisp, + frame_indices=frame_indices, + ) + created.append(shader_prim) + + log.info(f"Added PPISP shaders to {len(created)} RenderProduct(s)") + return created diff --git a/threedgrut/export/usd/writers/render_product.py b/threedgrut/export/usd/writers/render_product.py new file mode 100644 index 00000000..760ec1ab --- /dev/null +++ b/threedgrut/export/usd/writers/render_product.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +USD RenderProduct writer. + +Creates a /Render Scope with one RenderProduct per camera, each holding an +HdrColor RenderVar and the camera relationship required by downstream +post-processing shaders (e.g. PPISP). +""" + +import logging +from typing import Dict, Tuple + +from pxr import Sdf, Usd, UsdGeom + +log = logging.getLogger(__name__) + +_HDR_COLOR_VAR = "HdrColor" +_RENDER_SCOPE_PATH = "/Render" + + +def create_render_products( + stage: Usd.Stage, + camera_entries: Dict[str, Tuple[str, int, int]], + render_scope_path: str = _RENDER_SCOPE_PATH, +) -> None: + """Create a /Render Scope with one RenderProduct per camera. + + Each RenderProduct is named after its camera and contains: + - ``camera`` relationship pointing to the USD camera prim. + - ``resolution`` attribute. + - ``orderedVars`` relationship → [.../HdrColor]. + - Child ``RenderVar`` ``HdrColor`` with ``sourceName = "HdrColor"``. + + Args: + stage: USD stage that already contains the camera prims. + camera_entries: Mapping ``{camera_name: (usd_camera_path, width, height)}``. + The camera_name is used as the RenderProduct prim name (after USD + identifier sanitization to match what export_cameras_to_usd produced). + render_scope_path: Root path for the Render scope (default ``/Render``). + """ + from threedgrut.export.usd.writers.camera import _make_usd_prim_name + + stage.DefinePrim(render_scope_path, "Scope") + + for camera_name, (camera_path, width, height) in camera_entries.items(): + prim_name = _make_usd_prim_name(camera_name) + product_path = f"{render_scope_path}/{prim_name}" + + product_prim = stage.DefinePrim(product_path, "RenderProduct") + + # Resolution + product_prim.CreateAttribute( + "resolution", Sdf.ValueTypeNames.Int2 + ).Set((width, height)) + + # Camera relationship + camera_rel = product_prim.CreateRelationship("camera") + camera_rel.SetTargets([Sdf.Path(camera_path)]) + + # HdrColor RenderVar + hdr_var_path = f"{product_path}/{_HDR_COLOR_VAR}" + hdr_var = stage.DefinePrim(hdr_var_path, "RenderVar") + hdr_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(_HDR_COLOR_VAR) + + # orderedVars relationship + ordered_vars_rel = product_prim.CreateRelationship("orderedVars") + ordered_vars_rel.SetTargets([Sdf.Path(hdr_var_path)]) + + log.debug(f"Created RenderProduct at {product_path} → camera {camera_path} ({width}×{height})") + + log.info(f"Created {len(camera_entries)} RenderProduct(s) under {render_scope_path}") diff --git a/threedgrut/trainer.py b/threedgrut/trainer.py index bd275773..accb3430 100644 --- a/threedgrut/trainer.py +++ b/threedgrut/trainer.py @@ -808,6 +808,7 @@ def on_training_end(self): dataset=self.train_dataset, conf=conf, background=getattr(self, "background", None), + post_processing=getattr(self, "post_processing", None), ) # Export post-processing report (PPISP-based) From dc0d2a3e22fc92526b9ff38783eebce4cbcac862 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 06/42] feat(export): add PPISP Omniverse fallback writer Add a dedicated writer that maps PPISP exposure and fitted post-processing parameters to Omniverse USD attributes for deployments without reliable SPG support. Made-with: Cursor --- configs/base_gs.yaml | 3 + threedgrut/export/usd/exporter.py | 133 ++++-- threedgrut/export/usd/writers/__init__.py | 3 + .../export/usd/writers/ov_post_processing.py | 381 ++++++++++++++++++ 4 files changed, 494 insertions(+), 26 deletions(-) create mode 100644 threedgrut/export/usd/writers/ov_post_processing.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index ce0b5344..7c84c817 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -51,6 +51,9 @@ export_usd: linear_srgb: false # PPISP post-processing export as SPG shader on per-camera RenderProducts export_ppisp: false + # Omniverse RTX post-processing workaround for Kit versions without reliable SPG. + # none | ppisp-exposure-war | ppisp-approx-war | ppisp-hybrid-war + ov-post-processing: none # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed frames_per_second: 1.0 diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 62b41ab4..f42522de 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -53,6 +53,11 @@ copy_authored_time_settings_from_source, merge_source_world_at_same_paths, ) +from threedgrut.export.usd.writers.ov_post_processing import ( + MODE_NONE, + MODE_PPISP_HYBRID_WAR, + normalize_ov_post_processing_mode, +) from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) @@ -240,6 +245,7 @@ def __init__( sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, export_ppisp: bool = False, + ov_post_processing: str = MODE_NONE, frames_per_second: float = 1.0, ): """ @@ -256,6 +262,7 @@ def __init__( linear_srgb: If True, set prim color space to lin_rec709_scene. export_ppisp: If True, add PPISP SPG shaders on per-camera RenderProducts. Requires post_processing kwarg to be a ppisp.PPISP instance. + ov_post_processing: Omniverse RTX post-processing workaround mode. frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always bare frame indices (float(frame_idx)), so this controls playback speed. Default 1.0 means 1 frame per second of real time. @@ -271,6 +278,7 @@ def __init__( self.sorting_mode_hint = sorting_mode_hint self.linear_srgb = linear_srgb self.export_ppisp = export_ppisp + self.ov_post_processing = normalize_ov_post_processing_mode(ov_post_processing) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -308,7 +316,7 @@ def export( background: Optional background model for environment export. **kwargs: post_processing: ppisp.PPISP instance for SPG export (used when - export_ppisp=True). + export_ppisp=True or ov-post-processing is enabled). validate_usd (default True): run OpenUSD stage validators. apply_coordinate_transform (bool): apply 3DGRUT→USDZ coordinate flip. copy_source_usd: (stage_path, res_root) for prim merge. @@ -408,6 +416,7 @@ def export( camera_names = None frame_to_camera = None camera_prim_paths: Dict[str, str] = {} + camera_params = None if dataset is not None: camera_names, frame_to_camera = _extract_camera_grouping(dataset) @@ -456,19 +465,40 @@ def export( except (AttributeError, ValueError, ImportError) as e: logger.warning(f"Failed to export background: {e}") - # Export PPISP as SPG shaders on RenderProducts - if self.export_ppisp: - self._export_ppisp( + render_product_entries = None + export_spg_ppisp = self.export_ppisp or self.ov_post_processing == MODE_PPISP_HYBRID_WAR + needs_ppisp_render_products = export_spg_ppisp or self.ov_post_processing != MODE_NONE + if needs_ppisp_render_products: + render_product_entries = self._create_ppisp_render_products( stage=stage, dataset=dataset, camera_names=camera_names, frame_to_camera=frame_to_camera, camera_prim_paths=camera_prim_paths, - camera_params=_extract_camera_params_from_dataset(dataset) if dataset is not None else None, + camera_params=camera_params, + ) + + # Export PPISP as SPG shaders on RenderProducts + if export_spg_ppisp and render_product_entries is not None: + self._export_ppisp( + stage=stage, + dataset=dataset, + camera_names=camera_names, post_processing=kwargs.get("post_processing"), files=files, ) + # Export PPISP approximation as Omniverse RTX post-processing settings + if self.ov_post_processing != MODE_NONE and render_product_entries is not None: + self._export_ov_post_processing( + stage=stage, + camera_names=camera_names, + camera_prim_paths=camera_prim_paths, + render_product_entries=render_product_entries, + dataset=dataset, + post_processing=kwargs.get("post_processing"), + ) + # Package if suffix == ".usdz": if default_stage_wrapped is None: @@ -496,7 +526,7 @@ def export( logger.info(f"USD export complete: {output_path}") - def _export_ppisp( + def _create_ppisp_render_products( self, stage, dataset, @@ -504,10 +534,37 @@ def _export_ppisp( frame_to_camera, camera_prim_paths: Dict[str, str], camera_params, + ): + """Create /Render RenderProducts shared by SPG and OV PPISP exports.""" + if dataset is None or not camera_prim_paths: + logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") + return None + + from threedgrut.export.usd.writers.render_product import create_render_products + + resolutions = _extract_camera_resolutions(camera_params, camera_names, frame_to_camera) + camera_entries = {} + for cam_name, cam_path in camera_prim_paths.items(): + w, h = resolutions.get(cam_name, (0, 0)) + camera_entries[cam_name] = (cam_path, w, h) + + try: + create_render_products(stage=stage, camera_entries=camera_entries) + except Exception as e: + logger.warning(f"Failed to create RenderProducts: {e}") + return None + + return camera_entries + + def _export_ppisp( + self, + stage, + dataset, + camera_names, post_processing, files: List[NamedSerialized], ) -> None: - """Create /Render RenderProducts and attach PPISP SPG shaders.""" + """Attach PPISP SPG shaders to existing RenderProducts.""" try: from ppisp import PPISP # type: ignore[import-not-found] except ImportError: @@ -521,31 +578,12 @@ def _export_ppisp( ) return - if dataset is None or not camera_prim_paths: - logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") - return - - from threedgrut.export.usd.writers.render_product import create_render_products from threedgrut.export.usd.writers.ppisp_writer import ( add_ppisp_to_all_render_products, build_camera_frame_mapping, ) from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files - # Build camera_entries: camera_name → (usd_camera_path, width, height) - resolutions = _extract_camera_resolutions(camera_params, camera_names, frame_to_camera) - camera_entries = {} - for cam_name, cam_path in camera_prim_paths.items(): - w, h = resolutions.get(cam_name, (0, 0)) - camera_entries[cam_name] = (cam_path, w, h) - - try: - create_render_products(stage=stage, camera_entries=camera_entries) - except Exception as e: - logger.warning(f"Failed to create RenderProducts: {e}") - return - - # Build frame mapping from dataset _, camera_frame_mapping = build_camera_frame_mapping(dataset) try: @@ -567,6 +605,46 @@ def _export_ppisp( logger.info(f"PPISP SPG export complete: {len(spg_files)} sidecar(s) added") + def _export_ov_post_processing( + self, + stage, + camera_names, + camera_prim_paths, + render_product_entries, + dataset, + post_processing, + ) -> None: + """Attach Omniverse RTX post-processing WAR attributes to RenderProducts.""" + try: + from ppisp import PPISP # type: ignore[import-not-found] + except ImportError: + logger.warning("ppisp package not available, skipping OV post-processing export") + return + + if not isinstance(post_processing, PPISP): + logger.warning( + f"ov-post-processing={self.ov_post_processing} but post_processing is " + f"{type(post_processing).__name__}, expected ppisp.PPISP — skipping" + ) + return + + from threedgrut.export.usd.writers.ov_post_processing import add_ov_post_processing + from threedgrut.export.usd.writers.ppisp_writer import build_camera_frame_mapping + + _, camera_frame_mapping = build_camera_frame_mapping(dataset) + try: + add_ov_post_processing( + stage=stage, + camera_names=camera_names, + camera_prim_paths=camera_prim_paths, + camera_frame_mapping=camera_frame_mapping, + render_product_entries=render_product_entries, + post_processing=post_processing, + mode=self.ov_post_processing, + ) + except Exception as e: + logger.warning(f"Failed to add OV post-processing workaround: {e}") + @classmethod def from_config(cls, conf) -> "USDExporter": """ @@ -588,5 +666,8 @@ def from_config(cls, conf) -> "USDExporter": sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), export_ppisp=getattr(export_conf, "export_ppisp", False), + ov_post_processing=export_conf.get("ov-post-processing", MODE_NONE) + if hasattr(export_conf, "get") + else getattr(export_conf, "ov_post_processing", MODE_NONE), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/writers/__init__.py b/threedgrut/export/usd/writers/__init__.py index 92531563..e3572bcc 100644 --- a/threedgrut/export/usd/writers/__init__.py +++ b/threedgrut/export/usd/writers/__init__.py @@ -21,12 +21,14 @@ - export_cameras_to_usd: one Camera prim per physical camera, animated xforms - create_render_products: /Render scope with per-camera RenderProducts - add_ppisp_to_all_render_products: PPISP SPG shader on RenderProducts +- add_ov_post_processing: Omniverse RTX post-processing PPISP workaround """ from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import GaussianUSDWriter, create_gaussian_writer from threedgrut.export.usd.writers.camera import export_cameras_to_usd from threedgrut.export.usd.writers.lightfield import GaussianLightFieldWriter +from threedgrut.export.usd.writers.ov_post_processing import add_ov_post_processing from threedgrut.export.usd.writers.ppisp_writer import add_ppisp_to_all_render_products from threedgrut.export.usd.writers.render_product import create_render_products @@ -38,4 +40,5 @@ "export_background_to_usd", "create_render_products", "add_ppisp_to_all_render_products", + "add_ov_post_processing", ] diff --git a/threedgrut/export/usd/writers/ov_post_processing.py b/threedgrut/export/usd/writers/ov_post_processing.py new file mode 100644 index 00000000..4c4fd307 --- /dev/null +++ b/threedgrut/export/usd/writers/ov_post_processing.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Omniverse RTX post-processing workaround writer for PPISP exports. + +This writer is a degraded fallback for Kit versions where SPG is unavailable or +unreliable. It authors standard Omniverse RTX render settings only; exact PPISP +export remains the SPG path. +""" + +from __future__ import annotations + +import logging +from typing import Dict, Iterable, List, Tuple + +import numpy as np +from pxr import Gf, Sdf, Usd + +from threedgrut.export.usd.writers.camera import _make_usd_prim_name + +log = logging.getLogger(__name__) + +MODE_NONE = "none" +MODE_PPISP_EXPOSURE_WAR = "ppisp-exposure-war" +MODE_PPISP_APPROX_WAR = "ppisp-approx-war" +MODE_PPISP_HYBRID_WAR = "ppisp-hybrid-war" + +OV_POST_PROCESSING_MODES = { + MODE_NONE, + MODE_PPISP_EXPOSURE_WAR, + MODE_PPISP_APPROX_WAR, + MODE_PPISP_HYBRID_WAR, +} + +_BASE_EXPOSURE_TIME_SECONDS = 0.02 +_DEFAULT_EXPOSURE_FSTOP = 5.0 +_DEFAULT_EXPOSURE_ISO = 100.0 +_DEFAULT_EXPOSURE_RESPONSIVITY = 1.10267091 +_DEFAULT_RENDER_RESOLUTION = (1280, 720) + +_CAMERA_EXPOSURE_APIS = ["OmniRtxCameraExposureAPI_1"] +_RENDER_PRODUCT_APIS = [ + "OmniRtxPostTonemapIrayReinhardAPI_1", + "OmniRtxPostColorGradingAPI_1", + "OmniRtxPostTvNoiseAPI_1", +] + +_ZCA_BLUE = np.array([[0.0480542, -0.0043631], [-0.0043631, 0.0481283]], dtype=np.float64) +_ZCA_RED = np.array([[0.0580570, -0.0179872], [-0.0179872, 0.0431061]], dtype=np.float64) +_ZCA_GREEN = np.array([[0.0433336, -0.0180537], [-0.0180537, 0.0580500]], dtype=np.float64) +_ZCA_NEUTRAL = np.array([[0.0128369, -0.0034654], [-0.0034654, 0.0128158]], dtype=np.float64) + + +def normalize_ov_post_processing_mode(mode: str | None) -> str: + """Normalize and validate the ``export_usd.ov-post-processing`` value.""" + normalized = MODE_NONE if mode is None else str(mode).strip().lower() + if normalized not in OV_POST_PROCESSING_MODES: + raise ValueError( + f"Unsupported ov-post-processing mode '{mode}'. " + f"Expected one of: {sorted(OV_POST_PROCESSING_MODES)}" + ) + return normalized + + +def _as_numpy(value) -> np.ndarray: + if hasattr(value, "detach"): + value = value.detach() + if hasattr(value, "cpu"): + value = value.cpu() + if hasattr(value, "numpy"): + return value.numpy() + return np.asarray(value) + + +def _prepend_api_schemas(prim: Usd.Prim, schemas: Iterable[str]) -> None: + """Apply schemas by authoring the same listOp shape used by Kit examples.""" + schemas = [schema for schema in schemas if schema] + if not schemas: + return + prim.SetMetadata("apiSchemas", Sdf.TokenListOp.Create(prependedItems=schemas)) + + +def _create_float_attr(prim: Usd.Prim, name: str, value: float): + attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Float) + attr.Set(float(value)) + return attr + + +def _create_bool_attr(prim: Usd.Prim, name: str, value: bool): + attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Bool) + attr.Set(bool(value)) + return attr + + +def _create_color_attr(prim: Usd.Prim, name: str, value) -> Usd.Attribute: + vec = Gf.Vec3f(float(value[0]), float(value[1]), float(value[2])) + attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Color3f) + attr.Set(vec) + return attr + + +def _compute_homography(color_latent: np.ndarray) -> np.ndarray: + """Compute PPISP's RGI homography from one 8-float color latent vector.""" + b_lat = color_latent[0:2] + r_lat = color_latent[2:4] + g_lat = color_latent[4:6] + n_lat = color_latent[6:8] + + bd = _ZCA_BLUE @ b_lat + rd = _ZCA_RED @ r_lat + gd = _ZCA_GREEN @ g_lat + nd = _ZCA_NEUTRAL @ n_lat + + t_blue = np.array([bd[0], bd[1], 1.0], dtype=np.float64) + t_red = np.array([1.0 + rd[0], rd[1], 1.0], dtype=np.float64) + t_green = np.array([gd[0], 1.0 + gd[1], 1.0], dtype=np.float64) + t_gray = np.array([1.0 / 3.0 + nd[0], 1.0 / 3.0 + nd[1], 1.0], dtype=np.float64) + + target = np.stack([t_blue, t_red, t_green], axis=1) + skew = np.array( + [ + [0.0, -t_gray[2], t_gray[1]], + [t_gray[2], 0.0, -t_gray[0]], + [-t_gray[1], t_gray[0], 0.0], + ], + dtype=np.float64, + ) + matrix = skew @ target + lam = np.cross(matrix[0], matrix[1]) + if np.dot(lam, lam) < 1.0e-20: + lam = np.cross(matrix[0], matrix[2]) + if np.dot(lam, lam) < 1.0e-20: + lam = np.cross(matrix[1], matrix[2]) + + source_inv = np.array([[-1.0, -1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float64) + homography = target @ np.diag(lam) @ source_inv + if abs(homography[2, 2]) > 1.0e-20: + homography = homography / homography[2, 2] + return homography + + +def _apply_color_homography(rgb: np.ndarray, homography: np.ndarray) -> np.ndarray: + intensity = np.sum(rgb, axis=1) + rgi = np.stack([rgb[:, 0], rgb[:, 1], intensity], axis=1) + corrected = rgi @ homography.T + corrected = corrected * (intensity / (corrected[:, 2] + 1.0e-5))[:, None] + return np.stack([corrected[:, 0], corrected[:, 1], corrected[:, 2] - corrected[:, 0] - corrected[:, 1]], axis=1) + + +def _fit_grade_gain(color_latent: np.ndarray) -> np.ndarray: + """Fit RTX color grade gain to PPISP's cross-channel homography.""" + homography = _compute_homography(color_latent) + values = np.linspace(0.05, 1.0, 5, dtype=np.float64) + rgb = np.array(np.meshgrid(values, values, values), dtype=np.float64).T.reshape(-1, 3) + target = np.clip(_apply_color_homography(rgb, homography), 0.0, 4.0) + denom = np.maximum(np.sum(rgb * rgb, axis=0), 1.0e-8) + gain = np.sum(target * rgb, axis=0) / denom + return np.clip(gain, 0.25, 4.0) + + +def _bounded_softplus(raw: np.ndarray, min_value: float) -> np.ndarray: + return min_value + np.log1p(np.exp(raw)) + + +def _sigmoid(raw: np.ndarray) -> np.ndarray: + return 1.0 / (1.0 + np.exp(-raw)) + + +def _apply_crf(x: np.ndarray, raw_params: np.ndarray) -> np.ndarray: + toe = _bounded_softplus(raw_params[0], 0.3) + shoulder = _bounded_softplus(raw_params[1], 0.3) + gamma = _bounded_softplus(raw_params[2], 0.1) + center = _sigmoid(raw_params[3]) + lerp_val = (shoulder - toe) * center + toe + a = (shoulder * center) / lerp_val + b = 1.0 - a + y = np.where( + x <= center, + a * np.power(x / center, toe), + 1.0 - b * np.power((1.0 - x) / (1.0 - center), shoulder), + ) + return np.power(np.maximum(0.0, y), gamma) + + +def _fit_grade_gamma(crf_params: np.ndarray) -> np.ndarray: + """Fit RTX grade gamma to PPISP's per-channel CRF.""" + x = np.linspace(0.02, 0.98, 96, dtype=np.float64) + candidates = np.linspace(0.25, 4.0, 128, dtype=np.float64) + result = [] + for channel in range(3): + target = _apply_crf(x, crf_params[channel]) + errors = [np.mean((np.power(x, 1.0 / gamma) - target) ** 2) for gamma in candidates] + result.append(float(candidates[int(np.argmin(errors))])) + return np.asarray(result, dtype=np.float64) + + +def _ppisp_vignette_luminance(vig_params: np.ndarray, width: int, height: int) -> Tuple[np.ndarray, np.ndarray]: + sample_w = 48 + sample_h = max(8, int(round(sample_w * max(height, 1) / max(width, 1)))) + xs = (np.arange(sample_w, dtype=np.float64) + 0.5 - sample_w * 0.5) / sample_w + ys = (np.arange(sample_h, dtype=np.float64) + 0.5 - sample_h * 0.5) / sample_w + grid_x, grid_y = np.meshgrid(xs, ys) + uv = np.stack([grid_x, grid_y], axis=-1) + + rgb_falloff = [] + for channel in range(3): + center = vig_params[channel, 0:2] + delta = uv - center + r2 = np.sum(delta * delta, axis=-1) + falloff = 1.0 + vig_params[channel, 2] * r2 + vig_params[channel, 3] * r2**2 + vig_params[channel, 4] * r2**3 + rgb_falloff.append(np.clip(falloff, 0.0, 1.0)) + rgb_falloff = np.stack(rgb_falloff, axis=-1) + luminance = np.dot(rgb_falloff, np.array([0.2126, 0.7152, 0.0722], dtype=np.float64)) + + org_u = (np.arange(sample_w, dtype=np.float64) + 0.5) / sample_w + org_v = (np.arange(sample_h, dtype=np.float64) + 0.5) / sample_h + org_x, org_y = np.meshgrid(org_u, org_v) + org_uv = np.stack([org_x, org_y], axis=-1) + return luminance, org_uv + + +def _fit_tv_vignette(vig_params: np.ndarray, width: int, height: int) -> Tuple[bool, float, float]: + target, org_uv = _ppisp_vignette_luminance(vig_params, width, height) + if float(np.max(np.abs(target - 1.0))) < 1.0e-3: + return False, 107.0, 0.7 + + uv2 = org_uv * (1.0 - org_uv) + base = uv2[..., 0] * uv2[..., 1] + best_error = float("inf") + best_size = 107.0 + best_strength = 0.7 + + for size in np.linspace(1.0, 180.0, 72): + raw = np.maximum(base * (size + 14.0), 1.0e-8) + for strength in np.linspace(0.2, 2.0, 73): + candidate = np.power(raw, strength) + error = float(np.mean((candidate - target) ** 2)) + if error < best_error: + best_error = error + best_size = float(size) + best_strength = float(strength) + + return True, best_size, best_strength + + +def _author_camera_exposure( + stage: Usd.Stage, + camera_path: str, + frame_indices: List[int], + exposure_params: np.ndarray, +) -> None: + camera_prim = stage.GetPrimAtPath(camera_path) + if not camera_prim.IsValid(): + log.warning("Cannot author OV exposure: missing camera prim %s", camera_path) + return + + _prepend_api_schemas(camera_prim, _CAMERA_EXPOSURE_APIS) + _create_float_attr(camera_prim, "exposure:fStop", _DEFAULT_EXPOSURE_FSTOP) + _create_float_attr(camera_prim, "exposure:iso", _DEFAULT_EXPOSURE_ISO) + _create_float_attr(camera_prim, "exposure:responsivity", _DEFAULT_EXPOSURE_RESPONSIVITY) + + valid = [frame_idx for frame_idx in frame_indices if frame_idx < len(exposure_params)] + exposure_values = np.exp2(exposure_params[valid]) * _BASE_EXPOSURE_TIME_SECONDS if valid else np.asarray([]) + default_value = float(np.mean(exposure_values)) if len(exposure_values) else _BASE_EXPOSURE_TIME_SECONDS + + exposure_time = camera_prim.CreateAttribute("exposure:time", Sdf.ValueTypeNames.Float) + exposure_time.Set(default_value) + for frame_idx, value in zip(valid, exposure_values): + exposure_time.Set(float(value), float(frame_idx)) + + +def _author_tv_vignette(render_product: Usd.Prim, vig_params: np.ndarray, width: int, height: int) -> None: + enabled, size, strength = _fit_tv_vignette(vig_params, width, height) + _create_bool_attr(render_product, "omni:rtx:post:tvNoise:enabled", enabled) + _create_bool_attr(render_product, "omni:rtx:post:tvNoise:vignetting:enabled", enabled) + _create_float_attr(render_product, "omni:rtx:post:tvNoise:vignetting:size", size) + _create_float_attr(render_product, "omni:rtx:post:tvNoise:vignetting:strength", strength) + + for attr_name in ( + "omni:rtx:post:tvNoise:filmGrain:enabled", + "omni:rtx:post:tvNoise:ghostFlickering:enabled", + "omni:rtx:post:tvNoise:randomSplotches:enabled", + "omni:rtx:post:tvNoise:scanlines:enabled", + "omni:rtx:post:tvNoise:scrollBug:enabled", + "omni:rtx:post:tvNoise:verticalLines:enabled", + "omni:rtx:post:tvNoise:vignetting:flickering:enabled", + "omni:rtx:post:tvNoise:waveDistortion:enabled", + ): + _create_bool_attr(render_product, attr_name, False) + + +def _author_color_grade( + render_product: Usd.Prim, + frame_indices: List[int], + color_params: np.ndarray, + crf_params: np.ndarray, +) -> None: + valid = [frame_idx for frame_idx in frame_indices if frame_idx < len(color_params)] + gains = np.stack([_fit_grade_gain(color_params[frame_idx]) for frame_idx in valid], axis=0) if valid else np.ones((0, 3)) + default_gain = np.mean(gains, axis=0) if len(gains) else np.ones(3, dtype=np.float64) + gamma = _fit_grade_gamma(crf_params) + + grade_enabled = bool(np.max(np.abs(default_gain - 1.0)) > 1.0e-3 or np.max(np.abs(gamma - 1.0)) > 1.0e-3) + _create_bool_attr(render_product, "omni:rtx:post:grade:enabled", grade_enabled) + gain_attr = _create_color_attr(render_product, "omni:rtx:post:grade:gain", default_gain) + _create_color_attr(render_product, "omni:rtx:post:grade:gamma", gamma) + _create_color_attr(render_product, "omni:rtx:post:grade:offset", (0.0, 0.0, 0.0)) + _create_color_attr(render_product, "omni:rtx:post:grade:contrast", (1.0, 1.0, 1.0)) + _create_color_attr(render_product, "omni:rtx:post:grade:saturation", (1.0, 1.0, 1.0)) + + for frame_idx, gain in zip(valid, gains): + gain_attr.Set(Gf.Vec3f(float(gain[0]), float(gain[1]), float(gain[2])), float(frame_idx)) + + +def add_ov_post_processing( + stage: Usd.Stage, + camera_names: List[str], + camera_prim_paths: Dict[str, str], + camera_frame_mapping: Dict[str, List[int]], + render_product_entries: Dict[str, Tuple[str, int, int]], + post_processing, + mode: str, + render_scope_path: str = "/Render", +) -> None: + """Author Omniverse RTX post-processing settings for PPISP WAR export.""" + normalized_mode = normalize_ov_post_processing_mode(mode) + if normalized_mode == MODE_NONE: + return + + exposure_params = _as_numpy(post_processing.exposure_params) + color_params = _as_numpy(post_processing.color_params) + vignetting_params = _as_numpy(post_processing.vignetting_params) + crf_params = _as_numpy(post_processing.crf_params) + + camera_name_to_index = {name: idx for idx, name in enumerate(camera_names)} + approximate_full_ppisp = normalized_mode in {MODE_PPISP_APPROX_WAR, MODE_PPISP_HYBRID_WAR} + + for camera_name in camera_names: + frame_indices = camera_frame_mapping.get(camera_name, []) + camera_path = camera_prim_paths.get(camera_name) + if camera_path is None: + log.warning("Skipping OV post-processing for %s: missing camera prim", camera_name) + continue + + _author_camera_exposure(stage, camera_path, frame_indices, exposure_params) + + if not approximate_full_ppisp: + continue + + camera_index = camera_name_to_index[camera_name] + render_product_name = _make_usd_prim_name(camera_name) + render_product_path = f"{render_scope_path}/{render_product_name}" + render_product = stage.GetPrimAtPath(render_product_path) + if not render_product.IsValid(): + log.warning("Skipping OV post-processing for %s: missing RenderProduct", camera_name) + continue + + _prepend_api_schemas(render_product, _RENDER_PRODUCT_APIS) + _, width, height = render_product_entries.get(camera_name, ("", *_DEFAULT_RENDER_RESOLUTION)) + width = width or _DEFAULT_RENDER_RESOLUTION[0] + height = height or _DEFAULT_RENDER_RESOLUTION[1] + + _author_tv_vignette(render_product, vignetting_params[camera_index], width, height) + _author_color_grade(render_product, frame_indices, color_params, crf_params[camera_index]) + + log.warning( + "Authored OV RTX post-processing PPISP workaround mode '%s'. This is approximate and not SPG-fidelity.", + normalized_mode, + ) From 1239be71b9451caf29d930230ee27ee1ab168b55 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 07/42] feat(export): refine PPISP fallback selection Make export_ppisp the master PPISP export gate and use ov-post-processing only to select SPG, exposure fallback, fitted fallback, or combined SPG-plus-fallback output. Made-with: Cursor --- configs/base_gs.yaml | 9 +- docs/ppsip-to-rtx-pp-plan.md | 633 ++++++++++++++++++ threedgrut/export/scripts/transcode.py | 5 +- threedgrut/export/usd/camera_copy.py | 79 ++- threedgrut/export/usd/exporter.py | 67 +- threedgrut/export/usd/writers/__init__.py | 2 +- .../export/usd/writers/ov_post_processing.py | 48 +- 7 files changed, 782 insertions(+), 61 deletions(-) create mode 100644 docs/ppsip-to-rtx-pp-plan.md diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 7c84c817..e0ec8a5a 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -42,17 +42,18 @@ export_usd: enabled: false path: "" apply_normalizing_transform: true - format: standard # "nurec" for Omniverse USDVol internal format, "standard" for USDVol ParticleField3DGaussianSplat + format: standard # "nurec" for internal USDVol format, "standard" for USDVol ParticleField3DGaussianSplat half_precision: false export_cameras: true export_background: true sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false - # PPISP post-processing export as SPG shader on per-camera RenderProducts + # Enable PPISP post-processing export. ov-post-processing selects the implementation. export_ppisp: false - # Omniverse RTX post-processing workaround for Kit versions without reliable SPG. - # none | ppisp-exposure-war | ppisp-approx-war | ppisp-hybrid-war + # none uses the full SPG path when export_ppisp is true. + # Other values use Omniverse USD post-processing fallback modes. + # none | ppisp-exposure-fallback | ppisp-fitted-post-processing-fallback | ppisp-spg-plus-fitted-post-processing-fallback ov-post-processing: none # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed frames_per_second: 1.0 diff --git a/docs/ppsip-to-rtx-pp-plan.md b/docs/ppsip-to-rtx-pp-plan.md new file mode 100644 index 00000000..a2d0c792 --- /dev/null +++ b/docs/ppsip-to-rtx-pp-plan.md @@ -0,0 +1,633 @@ +# PPISP Omniverse USD Post-Processing Fallback Plan + +Scope: investigate whether the PPISP effect currently planned for SPG export can +also be approximated by existing Omniverse USD post-processing settings. + +Goal: provide an Omniverse USD post-processing fallback for Kit versions where SPG is unavailable, +not supported in the target deployment, or affected by SPG bugs. This is not a +replacement for the exact SPG export path. + +User-facing control: expose the fallback as an export parameter named +`ov-post-processing`. The implementation should live in a dedicated USD writer +file and stay separate from the SPG PPISP writer. + +This document is intentionally a plan only. No implementation is approved here. + +Status: `BLOCKED_ON_SPG_IMPLEMENTATION` + +Implementation gate: the SPG PPISP export path is being implemented first. The +`ov-post-processing` fallback should be implemented on top of that work, after +the shared camera grouping and `/Render`/`RenderProduct` authoring are available. + +--- + +## 1. Context + +The current PPISP USD export plan in `docs/ppisp-export-plan.md` uses a custom +SPG shader on each `RenderProduct` because PPISP is a post-blend image-space +operator: + +1. Exposure: per-frame scalar `rgb *= 2**e`. +2. Vignetting: per-camera, per-channel, per-pixel multiplicative falloff. +3. Color correction: per-frame 3x3 homography in RGI space with intensity + renormalisation. +4. CRF: per-camera, per-channel 4-parameter toe/shoulder/gamma curve. + +The question here is whether a secondary PPISP export path can map enough of the +effect onto existing Omniverse USD post-processing controls to be useful when +SPG is not viable in a given Kit runtime. + +Investigation target: + +- Kit rendering post-processing implementation. +- Generated USD render-settings schema. +- Existing USD stages that author post-processing attributes on `RenderProduct` + prims. + +--- + +## 2. USD post-processing surface found + +Kit exposes post-processing in two layers. + +The C++ renderer layer: + +- `Postprocessing::addPostprocessing` +- `Postprocessing::addTonemapping` +- `Postprocessing::addTvNoise` +- `Postprocessing::addRegisteredCompositing` + +The USD/render-settings layer: + +- `RenderProduct` prims can apply post-processing settings API schemas. +- Example stages author post-processing attributes directly on + `/Render/`. +- The generated schema exposes: + - camera exposure settings. + - tonemapping settings. + - color grading settings. + - vignette settings. + +Relevant setting families: + +- Camera exposure: + - `exposure:time` + - `exposure:fStop` + - `exposure:iso` + - `exposure:responsivity` +- Tonemapping: + - tonemap operator. + - tonemap dither. + - advanced carb-backed controls such as `exposureKey`, `whiteScale`, + `maxWhiteLuminance`, `whitepoint`, and `enableSrgbToGamma` +- Color grading: + - grade enabled. + - `blackPoint`, `whitePoint`, `contrast`, `lift`, `gain`, `multiply`, + `offset`, `gamma`, `saturation` +- TV-noise vignette: + - effect enabled. + - vignetting enabled. + - vignetting size. + - vignetting strength. + +Important observed shader behavior: + +- Tonemapping applies exposure through `computeExposureScale`, then one of the + built-in operators: raw/clamp, linear, Reinhard, modified Reinhard, + Hejl-Hable, Hable UC2, ACES approximation, or Iray Reinhard. +- Color correction/grading can run before tonemapping in ACES mode or after + tonemapping in Standard mode. +- TV-noise vignetting uses one scalar radial-ish function: + `pow(uv.x * (1 - uv.x) * uv.y * (1 - uv.y) * (size + 14), strength)`. + +--- + +## 3. Mapping assessment + +### 3.1 Exposure + +Assessment: exact scalar mapping is likely possible. + +Reasoning: + +- PPISP exposure is a scalar multiply by `2**exposure_params[frame_idx]`. +- USD exposure scale is proportional to `exposure:time`, `filmIso`, and + `responsivity`, and inversely proportional to `fStop**2`. +- If all other exposure parameters are held fixed, time-sampling + `exposure:time` as `baseExposureTime * 2**e` should reproduce the scalar + exposure factor. + +Recommended USD mapping: + +- Apply the camera exposure API schema to each camera prim, or author the + equivalent camera exposure attributes if already supported by the target USD + version. +- Time-sample `exposure:time` per frame. +- Keep `exposure:fStop`, `exposure:iso`, and `exposure:responsivity` fixed. +- Disable auto exposure and histogram adaptation for validation. + +Risks: + +- Exposure is embedded in tonemapping. Exactness only holds if the rest + of the tone pipeline is configured so the exposure scale is not folded into a + different nonlinear look. +- Kit has a Gaussian-specific skip-tonemapping path. If active, it may bypass + the tone pass for Gaussian primary hits and therefore bypass the intended + exposure mapping. + +Confidence: 0.75. + +### 3.2 Vignetting + +Assessment: approximate only. + +Reasoning: + +- PPISP vignetting is per-camera, per-channel, and parameterized by five values + per channel. +- USD vignette control is a scalar function shared across RGB, controlled by + only `size` and `strength`. +- The vignette function is centered in normalized screen space and does not expose + per-channel coefficients or arbitrary polynomial/radial terms. + +Recommended USD mapping: + +- Use the USD vignette API schema only for an approximation path. +- Enable TV noise and vignetting, but disable film grain, scanlines, ghosting, + scrolling, random splotches, wave distortion, vertical lines, and flicker. +- Fit one scalar vignette curve per physical camera to the luminance average + of the PPISP RGB vignette map. +- Record the per-channel residual, because color-dependent vignetting cannot be + represented by this scalar control. + +Risks: + +- The TV-noise vignette pass is semantically part of an analog TV effect, not a + calibrated camera response model. +- It may run after tonemapping, while PPISP vignetting is before color + correction and CRF. This changes the result when later nonlinear operations + are enabled. + +Confidence: 0.45. + +### 3.3 Color Correction + +Assessment: approximate only, and likely weak for scenes with cross-channel +mixing. + +Reasoning: + +- PPISP color correction is a per-frame 3x3 homography in RGI space with + intensity renormalisation. +- USD color correction and color grading expose channel-wise saturation, + contrast, gain, gamma, offset, lift, multiply, black point, and white point. +- These controls do not expose a general 3x3 matrix or a homography with + intensity renormalisation. + +Recommended USD mapping: + +- Prefer the USD color-grading API schema over legacy color-correction carb + settings because it is present in the generated USD schema and examples. +- Use Standard mode for validation if the desired fit is after a linear + tonemap, and ACES mode only if validation shows the color space conversion is + closer to PPISP's RGI-space transform. +- Fit per-frame `gain`, `offset`, `gamma`, `contrast`, and `saturation` to + sampled RGB pairs generated by the trained PPISP transform. +- Treat any off-diagonal color coupling in the PPISP homography as residual + error, not as exportable data. + +Risks: + +- The generated USD schema exposes color grading attributes, but not every + advanced carb setting is necessarily intended for portable USD authoring. +- Time-sampled `RenderProduct` attributes should be verified in Kit, because the + schema examples are mostly static. + +Confidence: 0.35. + +### 3.4 CRF + +Assessment: no exact mapping in existing USD post-processing. + +Reasoning: + +- PPISP CRF is per-camera, per-channel, and has four learned parameters per + channel. +- USD tonemapping provides a small set of global operators. Iray Reinhard adds + `crushBlacks`, `burnHighlights`, and saturation, but not per-channel + toe/shoulder/gamma parameters. +- USD color grading `gamma` is per-channel, but it is not a learned + toe/shoulder CRF. + +Recommended USD mapping: + +- Use the built-in tonemapper only as a coarse approximation. +- Evaluate two candidate fits: + - `operator = "none"` or `"raw"` plus color grading gamma/gain/offset. + - `operator = "iray"` plus Iray Reinhard crush/burn controls and color + grading compensation. +- Fit per camera, not per frame, because PPISP CRF is per camera. + +Risks: + +- A fitted USD tonemapper may interact with exposure and color grading in ways + that make individual PPISP components hard to validate independently. +- Per-channel CRF differences are not representable by global tonemap + operators. + +Confidence: 0.25. + +--- + +## 4. Candidate architectures + +### Option R0 — SPG-only export + +Keep the SPG plan from `docs/ppisp-export-plan.md` as the only PPISP-preserving +export path. + +Use USD post-processing only for user-authored artistic settings unrelated to +PPISP. + +Recommendation: best default path when the target Kit version has reliable SPG +support. + +### Option R1 — Exposure-only USD fallback + +Export only PPISP exposure through time-sampled camera exposure attributes. +Leave vignetting, color correction, and CRF unexported or keep them in SPG. + +Recommendation: useful as the lowest-risk fallback for older Kit versions where +SPG is unavailable but some PPISP brightness matching is better than no PPISP +signal. + +### Option R2 — USD post-processing fallback + +Fit the full PPISP effect into existing USD settings: + +- Exposure via `exposure:time`. +- Vignetting via TV-noise vignette. +- Color correction via color grading. +- CRF via tonemap plus color grading. + +Recommendation: primary USD fallback candidate for Kit versions with no SPG +support or known SPG bugs. It should be advertised as approximate and version +gated. + +### Option R3 — Hybrid export for validation and migration + +Export both: + +- Exact PPISP SPG `RenderVar` path for validation and high fidelity. +- Approximate USD post-processing attributes for viewers that do not support the + custom SPG shader. + +Recommendation: best investigation mode when validating the USD fallback against +SPG in newer Kit versions, or when the same asset must run across mixed Kit +deployments. + +--- + +## 5. Proposed reviewable tasks + +### T-R0 — Build a PPISP reference response sampler + +Purpose: create a test-only numeric reference for comparing PPISP against USD +approximations. + +Inputs: + +- A trained or synthetic PPISP instance. +- A small set of RGB sample grids. +- Camera index and frame index. + +Output: + +- Per-stage reference outputs: + - after exposure + - after vignetting + - after color correction + - after CRF + - final + +Test write-up: + +- Use identity PPISP parameters and assert output equals input. +- Enable only exposure and assert output equals `rgb * 2**e`. +- Enable one non-identity operation at a time and store deterministic numeric + fixtures. + +### T-R1 — Validate USD exposure equivalence + +Purpose: prove whether `exposure:time = baseExposureTime * 2**e` matches PPISP +exposure under controlled USD settings. + +Test write-up: + +- Create a USD stage with one camera and one `RenderProduct`. +- Disable auto exposure, dither, color grading, TV noise, and nonlinear + tonemapping. +- Render a known flat-color target at several exposure values. +- Compare captured output ratios against `2**e`. +- Repeat with Gaussian skip-tonemapping enabled and disabled. + +Pass criterion: + +- Relative error below a chosen tolerance, proposed initial threshold: `1e-3` + for linear floating-point captures. + +### T-R2 — Fit and validate USD vignette + +Purpose: quantify how close the built-in USD vignette can get to PPISP +vignetting. + +Test write-up: + +- Generate PPISP vignetting maps for each camera. +- Fit `vignetting:size` and `vignetting:strength` to the luminance-average + PPISP map. +- Render a flat-color image through the vignette pass with all + other TV effects disabled. +- Compare spatial error and per-channel residual. + +Pass criterion: + +- Report RMSE and max error. Do not enforce pass/fail until real datasets are + sampled. + +### T-R3 — Fit USD color grading to PPISP color correction + +Purpose: determine whether the color grading controls can approximate the PPISP +3x3 RGI homography acceptably. + +Test write-up: + +- Sample RGB values across the training color range. +- Apply PPISP color correction for selected frames. +- Fit USD grade controls to minimize color error. +- Validate on held-out RGB samples and on rendered frames. + +Pass criterion: + +- Report `meanDeltaRgb`, `p95DeltaRgb`, and max channel error. +- Flag frames where off-diagonal homography terms dominate the residual. + +### T-R4 — Fit USD tonemap/grade to PPISP CRF + +Purpose: quantify CRF approximation quality with built-in USD tone operators. + +Test write-up: + +- For each camera, sample the PPISP per-channel CRF curves. +- Fit candidate USD settings: + - raw or none tonemap plus grade gamma/gain/offset + - Iray Reinhard plus grade compensation +- Validate per-channel curve error and final image error. + +Pass criterion: + +- Report per-camera curve RMSE and max error. +- Reject USD-only export for cameras whose CRF fit exceeds the selected + threshold. + +### T-R5 — Author a minimal USD post-processing prototype + +Purpose: verify the USD authoring model without touching the production exporter. + +Expected prototype shape: + +- `/Render/` `RenderProduct` +- Applied schemas: + - tonemapping API schema. + - color-grading API schema. + - vignette API schema. +- Camera exposure attributes on the referenced camera prim. +- `orderedVars` containing `LdrColor`. + +Test write-up: + +- Open the generated USD in Kit with the required render-settings schema enabled. +- Verify authored attributes appear in the active render settings context. +- Capture output and compare against the PPISP reference sampler. + +### T-R6 — Define USD fallback policy + +Purpose: choose when the USD post-processing fallback should be offered after +the numeric validation tasks. + +Decision points: + +- Identify the minimum Kit version where SPG is reliable enough to prefer + `spgExact`. +- Identify older Kit versions or known SPG bug IDs where the USD fallback should + be available. +- If only exposure is accurate, add an exposure-only USD fallback mode. +- If fitted error is acceptable for target datasets, add an approximate USD + post-processing fallback mode. +- If errors are high, keep USD post-processing as an explicit degraded fallback + and document SPG as required for fidelity. + +Test write-up: + +- Produce a short validation report with per-stage errors and example captures. +- Require explicit approval before implementing any exporter changes. + +--- + +## 6. Feasibility Report + +Assessment: feasible with moderate implementation risk. + +The standard USD exporter already has the right integration points: + +- `configs/base_gs.yaml` contains an `export_usd` block for export parameters. +- `USDExporter.from_config` centralizes conversion from config to exporter + constructor arguments. +- `USDExporter.export` already has access to `model`, `dataset`, `conf`, and + `background`. +- `trainer.py` already chooses `USDExporter` when `export_usd.format` is + `standard`. + +The main missing dependency is not the export parameter itself, but access to the +trained PPISP module during USD export. Today `trainer.py` calls: + +```text +exporter.export(..., dataset=self.train_dataset, conf=conf, background=...) +``` + +For any PPISP-derived fallback, this call must also pass +`post_processing=self.post_processing` when `post_processing.method == "ppisp"`. +That is already required by the SPG export plan, so the fallback should reuse +the same exporter-facing data path. + +Recommended export parameters: + +```yaml +export_usd: + export_ppisp: false + ov-post-processing: none +``` + +`export_ppisp` is the gate for PPISP export. If it is `false`, no PPISP effect +is exported and `ov-post-processing` must be `none`. + +Allowed `ov-post-processing` values when `export_ppisp` is `true`: + +- `none`: use the full SPG PPISP path and do not author fallback settings. +- `ppisp-exposure-fallback`: export only PPISP exposure through USD camera exposure. +- `ppisp-fitted-post-processing-fallback`: export the fitted USD post-processing approximation. +- `ppisp-spg-plus-fitted-post-processing-fallback`: author the fitted fallback attributes alongside the SPG + path for validation or mixed Kit deployments. + +Implementation note: because `ov-post-processing` is hyphenated, Python code +should read it with `export_conf.get("ov-post-processing", "none")`, not +`conf.export_usd.ov-post-processing` or normal dot access. + +Dedicated file recommendation: + +```text +threedgrut/export/usd/writers/ov_post_processing.py +``` + +Suggested public API: + +```python +def add_ov_post_processing( + stage, + render_product_entries, + post_processing, + dataset, + mode: str, +) -> None: + ... +``` + +Responsibilities of the dedicated file: + +- Validate that `mode` is one of the supported `ov-post-processing` values. +- Validate that `post_processing` is a PPISP instance for PPISP-derived modes. +- Author USD post-processing API schemas and attributes on `RenderProduct` + prims. +- Author camera exposure attributes for the exposure fallback. +- Fit or consume fitted parameters for vignette, color grading, and CRF + approximation. +- Log an explicit warning when falling back to degraded behavior. + +Responsibilities that should stay outside the dedicated file: + +- Camera prim creation. +- `/Render` scope and `RenderProduct` creation. +- SPG shader authoring and sidecar packaging. +- Exporter config parsing beyond passing the selected mode. + +Feasibility by mode: + +- `none`: high feasibility. Uses the existing full SPG PPISP path when + `export_ppisp` is true. +- `ppisp-exposure-fallback`: high feasibility. Requires camera prims and time-sampled + exposure authoring only. +- `ppisp-fitted-post-processing-fallback`: medium feasibility. USD authoring is straightforward, but + fitting PPISP vignetting/color/CRF into USD controls needs validation and may + have visible residuals. +- `ppisp-spg-plus-fitted-post-processing-fallback`: high feasibility after SPG and fallback paths both + exist. It is mostly orchestration and validation. + +Primary risks: + +- Current `USDExporter` exports one camera per frame; the SPG plan already notes + this must become one prim per physical camera with time-sampled transforms + before per-camera `RenderProduct` post-processing can be cleanly authored. +- The current exporter does not create `/Render` or `RenderProduct` prims. + The fallback depends on the same `render_product.py` foundation as the SPG + plan. +- Time-sampled `RenderProduct` USD attributes need validation in the target Kit + versions. +- The fallback is approximate by design. Documentation and logs must make this + visible so users do not mistake it for SPG fidelity. + +Overall recommendation: implement the feature as a small orchestrated extension +after the shared camera and `RenderProduct` groundwork from the SPG plan. Keep +`export_ppisp` disabled by default. When `export_ppisp` is enabled, use +`ov-post-processing` to choose between full SPG and explicit fallback modes. + +Execution dependency: wait for the SPG implementation to land, then add the OV +post-processing writer as a follow-up layer that reuses the SPG path's camera, +time-code, and `RenderProduct` infrastructure. + +Confidence: 0.8 for exporter/config feasibility, 0.45 for final visual fidelity +of the full PPISP approximation. + +--- + +## 7. Recommended architecture if approved later + +Use `export_ppisp` as the PPISP export gate and `ov-post-processing` as the +implementation selector: + +```text +export_usd: + export_ppisp: false + ov-post-processing: none +``` + +Behavior: + +- `export_ppisp: false`: no PPISP export. `ov-post-processing` must be `none`. +- `export_ppisp: true`, `ov-post-processing: none`: export PPISP through the + full SPG path. +- `export_ppisp: true`, `ov-post-processing: ppisp-exposure-fallback`: export + PPISP through camera exposure fallback only. +- `export_ppisp: true`, `ov-post-processing: ppisp-fitted-post-processing-fallback`: + export PPISP through fitted Omniverse USD post-processing fallback only. +- `export_ppisp: true`, `ov-post-processing: ppisp-spg-plus-fitted-post-processing-fallback`: + write both the SPG exact path and the fitted USD fallback attributes for + validation or mixed-version deployment. + +Keep the approximation code isolated from the exact SPG writer: + +```text +threedgrut/export/usd/ + writers/ + ov_post_processing.py + ppisp_writer.py +``` + +Rationale: + +- The USD mapping is a fitted approximation, not a semantic equivalent of + PPISP. +- The USD fallback exists for deployment compatibility with older or buggy Kit + SPG support, not to displace the high-fidelity SPG path. +- Keeping a separate backend makes review easier and prevents silent quality + regressions in the exact export path. + +--- + +## 8. Open questions + +- Which Kit versions need the USD fallback because SPG is unavailable? +- Which known SPG bugs should trigger or recommend the USD fallback path? +- What error threshold is acceptable for a degraded workaround export? +- Should the approximation target linear floating-point `LdrColor`, gamma + output, or Kit viewport screenshots? +- Are time-sampled `RenderProduct` post-processing attributes supported and + stable in the target Kit version? +- Should Gaussian skip-tonemapping be disabled for PPISP USD approximation, or + is the exported Gaussian material already authored for that path? +- Should the requested file name keep the `ppsip` spelling, or should a follow-up + rename to `ppisp-to-usd-post-processing-plan.md` be made? + +--- + +## 9. Current recommendation + +Use USD post-processing only as an explicit fallback alternative for older +Kit versions or known SPG failure modes. + +Exact PPISP export should remain SPG-based for Kit versions where SPG is +available and reliable. The USD fallback path should be version-gated, labeled +approximate, and validated against SPG/reference PPISP before use on target +datasets. + +Recommended next step: review this document and edit the task list or thresholds +before any implementation begins. diff --git a/threedgrut/export/scripts/transcode.py b/threedgrut/export/scripts/transcode.py index e3d8bd75..c7e6bf54 100644 --- a/threedgrut/export/scripts/transcode.py +++ b/threedgrut/export/scripts/transcode.py @@ -26,8 +26,9 @@ python -m threedgrut.export.scripts.transcode input.usdz -o output.ply python -m threedgrut.export.scripts.transcode nurec.usd -o lightfield.usdz --format lightfield -USD/USDZ → LightField: source /World prims (e.g. rig_trajectories) merge into default.usda at the -same paths; referenced layers are bundled unchanged (preserves camera animation curves). +USD/USDZ → LightField: source /World prims (e.g. rig_trajectories) and /Render +merge into default.usda at the same paths; referenced layers are bundled unchanged +(preserves camera animation curves and authored render products). /World/Gaussians is skipped by default; use --copy-source-include-gaussians to merge it too. Use --no-copy-source-prims to disable. """ diff --git a/threedgrut/export/usd/camera_copy.py b/threedgrut/export/usd/camera_copy.py index 01328b00..737b826f 100644 --- a/threedgrut/export/usd/camera_copy.py +++ b/threedgrut/export/usd/camera_copy.py @@ -112,6 +112,29 @@ def merge_source_world_at_same_paths( return total +def merge_source_prim_at_same_path(dest_stage, source_stage, prim_path: str) -> int: + """ + Copy one source root-layer prim subtree to the destination at the same path. + + This preserves non-geometry export data, such as `/Render`, during USD to + USD transcode without regenerating renderer state from Python objects. + """ + src_layer = source_stage.GetRootLayer() + dst_layer = dest_stage.GetRootLayer() + path = Sdf.Path(prim_path) + + if not src_layer.GetPrimAtPath(path): + logger.info("Source USD has no %s prim; nothing to merge", prim_path) + return 0 + if dst_layer.GetPrimAtPath(path): + logger.info("Keeping destination prim %s; not overwriting with source", prim_path) + return 0 + + count = _copy_prim_spec_recursive(src_layer, dst_layer, path, path) + logger.info("Merged source %s subtree with %d prim(s)", prim_path, count) + return count + + def copy_authored_time_settings_from_source(source_stage, dest_stage) -> None: """Copy authored time code range and FPS from source to destination stage when set.""" try: @@ -154,9 +177,27 @@ def _gather_ref_payload_basenames_from_prim_spec(spec: Sdf.PrimSpec) -> Set[str] bn = _basename_packaged_ref(getattr(item, "assetPath", "") or "") if bn: out.add(bn) + for prop in spec.properties: + default_value = getattr(prop, "default", None) + asset_path = getattr(default_value, "path", None) or getattr( + default_value, + "assetPath", + None, + ) + if asset_path: + bn = _basename_packaged_ref(asset_path) + if bn: + out.add(bn) return out +def _companion_sidecar_basenames(basename: str) -> Set[str]: + """Additional package files implied by a referenced asset.""" + if basename.endswith(".slang"): + return {f"{basename}.lua"} + return set() + + def _walk_prim_subtree(layer: Sdf.Layer, root_path: Sdf.Path): """Depth-first active prims under root_path (inclusive).""" spec = layer.GetPrimAtPath(root_path) @@ -189,23 +230,25 @@ def _walk_entire_layer(layer: Sdf.Layer): yield from _walk_prim_subtree(layer, root.AppendChild(child_spec.name)) -def collect_transitive_sidecars_for_world_subtree( +def collect_transitive_sidecars_for_subtree( dest_layer: Sdf.Layer, res_root: Path, - world_prefix: str = "/World", + path_prefix: str, extra_skip_names: Optional[Collection[str]] = None, ) -> List[NamedSerialized]: """ - Resolve layer/asset references under ``world_prefix`` and bundle files from ``res_root`` into the - output USDZ (flat layout). Follows references/payloads transitively through USD layers. + Resolve layer/asset references under ``path_prefix`` and bundle files from + ``res_root`` into the output USDZ (flat layout). - Skips names in ``_OUTPUT_AUTHORED_NAMES`` and ``extra_skip_names`` (e.g. source root default file). + Follows references/payloads transitively through USD layers. Skips names in + ``_OUTPUT_AUTHORED_NAMES`` and ``extra_skip_names`` (e.g. source root default + file). """ skip: Set[str] = set(_OUTPUT_AUTHORED_NAMES) if extra_skip_names: skip.update(extra_skip_names) - seed = _gather_refs_from_layer_subtree(dest_layer, world_prefix) + seed = _gather_refs_from_layer_subtree(dest_layer, path_prefix) queue: Set[str] = {n for n in seed if n not in skip} done: Set[str] = set(skip) result: List[NamedSerialized] = [] @@ -225,6 +268,9 @@ def collect_transitive_sidecars_for_world_subtree( logger.warning("Could not read sidecar %s: %s", path, e) continue result.append(NamedSerialized(filename=name, serialized=data)) + for companion in _companion_sidecar_basenames(name): + if companion not in done: + queue.add(companion) suf = path.suffix.lower() if suf not in (".usd", ".usda", ".usdc"): @@ -240,10 +286,29 @@ def collect_transitive_sidecars_for_world_subtree( queue.add(bn) if result: - logger.info("Bundled %d sidecar file(s) from %s for /World references", len(result), res_root) + logger.info( + "Bundled %d sidecar file(s) from %s for %s references", + len(result), + res_root, + path_prefix, + ) return result +def collect_transitive_sidecars_for_world_subtree( + dest_layer: Sdf.Layer, + res_root: Path, + world_prefix: str = "/World", + extra_skip_names: Optional[Collection[str]] = None, +) -> List[NamedSerialized]: + return collect_transitive_sidecars_for_subtree( + dest_layer, + res_root, + path_prefix=world_prefix, + extra_skip_names=extra_skip_names, + ) + + @contextmanager def usd_stage_path_context_for_camera_copy(usd_path: Path) -> Iterator[Optional[UsdStagePathPair]]: """ diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index f42522de..c60b3857 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -49,13 +49,14 @@ from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import create_gaussian_writer from threedgrut.export.usd.camera_copy import ( - collect_transitive_sidecars_for_world_subtree, + collect_transitive_sidecars_for_subtree, copy_authored_time_settings_from_source, + merge_source_prim_at_same_path, merge_source_world_at_same_paths, ) from threedgrut.export.usd.writers.ov_post_processing import ( - MODE_NONE, - MODE_PPISP_HYBRID_WAR, + MODE_PPISP_OMNI_FALLBACK_NONE, + MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, normalize_ov_post_processing_mode, ) from threedgrut.export.usd.writers.camera import export_cameras_to_usd @@ -231,7 +232,7 @@ class USDExporter(ModelExporter): - Optional PPISP SPG shader on per-camera RenderProducts - USDZ packaging (default output) - For Omniverse/NuRec compatibility, use NuRecExporter instead. + For NuRec compatibility, use NuRecExporter instead. """ def __init__( @@ -245,7 +246,7 @@ def __init__( sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, export_ppisp: bool = False, - ov_post_processing: str = MODE_NONE, + ov_post_processing: str = MODE_PPISP_OMNI_FALLBACK_NONE, frames_per_second: float = 1.0, ): """ @@ -260,9 +261,11 @@ def __init__( apply_normalizing_transform: Apply transform to normalize scene orientation. sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth"). linear_srgb: If True, set prim color space to lin_rec709_scene. - export_ppisp: If True, add PPISP SPG shaders on per-camera RenderProducts. - Requires post_processing kwarg to be a ppisp.PPISP instance. - ov_post_processing: Omniverse RTX post-processing workaround mode. + export_ppisp: If True, export PPISP using SPG or the selected + Omniverse USD fallback mode. Requires post_processing kwarg to + be a ppisp.PPISP instance. + ov_post_processing: PPISP export implementation selector. "none" + uses the full SPG path when export_ppisp is enabled. frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always bare frame indices (float(frame_idx)), so this controls playback speed. Default 1.0 means 1 frame per second of real time. @@ -279,6 +282,12 @@ def __init__( self.linear_srgb = linear_srgb self.export_ppisp = export_ppisp self.ov_post_processing = normalize_ov_post_processing_mode(ov_post_processing) + if not self.export_ppisp and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: + raise ValueError( + "export_usd.ov-post-processing requires export_usd.export_ppisp=true. " + "Set export_ppisp=true to export PPISP through an Omniverse USD fallback, " + "or set ov-post-processing=none." + ) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -398,17 +407,19 @@ def export( skip = kwargs.get("copy_source_skip_subtrees") merge_target = default_stage_wrapped.stage if default_stage_wrapped is not None else stage merge_source_world_at_same_paths(merge_target, src_stage, skip_source_subtrees=skip) + merge_source_prim_at_same_path(merge_target, src_stage, "/Render") copy_authored_time_settings_from_source(src_stage, merge_target) if package_as_usdz and res_root is not None and res_root.is_dir(): - sidecars = collect_transitive_sidecars_for_world_subtree( - merge_target.GetRootLayer(), - res_root, - world_prefix="/World", - extra_skip_names={Path(stage_path).name}, - ) - for entry in sidecars: - if not any(f.filename == entry.filename for f in files): - files.append(entry) + for path_prefix in ("/World", "/Render"): + sidecars = collect_transitive_sidecars_for_subtree( + merge_target.GetRootLayer(), + res_root, + path_prefix=path_prefix, + extra_skip_names={Path(stage_path).name}, + ) + for entry in sidecars: + if not any(f.filename == entry.filename for f in files): + files.append(entry) except Exception as e: logger.warning("Failed to merge source USD prims: %s", e) @@ -466,8 +477,14 @@ def export( logger.warning(f"Failed to export background: {e}") render_product_entries = None - export_spg_ppisp = self.export_ppisp or self.ov_post_processing == MODE_PPISP_HYBRID_WAR - needs_ppisp_render_products = export_spg_ppisp or self.ov_post_processing != MODE_NONE + export_spg_ppisp = self.export_ppisp and ( + self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE + or self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING + ) + export_omni_ppisp_fallback = ( + self.export_ppisp and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE + ) + needs_ppisp_render_products = self.export_ppisp if needs_ppisp_render_products: render_product_entries = self._create_ppisp_render_products( stage=stage, @@ -488,8 +505,8 @@ def export( files=files, ) - # Export PPISP approximation as Omniverse RTX post-processing settings - if self.ov_post_processing != MODE_NONE and render_product_entries is not None: + # Export fitted PPISP approximation as Omniverse USD post-processing settings + if export_omni_ppisp_fallback and render_product_entries is not None: self._export_ov_post_processing( stage=stage, camera_names=camera_names, @@ -614,7 +631,7 @@ def _export_ov_post_processing( dataset, post_processing, ) -> None: - """Attach Omniverse RTX post-processing WAR attributes to RenderProducts.""" + """Attach Omniverse USD post-processing fallback attributes to RenderProducts.""" try: from ppisp import PPISP # type: ignore[import-not-found] except ImportError: @@ -643,7 +660,7 @@ def _export_ov_post_processing( mode=self.ov_post_processing, ) except Exception as e: - logger.warning(f"Failed to add OV post-processing workaround: {e}") + logger.warning(f"Failed to add OV post-processing fallback: {e}") @classmethod def from_config(cls, conf) -> "USDExporter": @@ -666,8 +683,8 @@ def from_config(cls, conf) -> "USDExporter": sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), export_ppisp=getattr(export_conf, "export_ppisp", False), - ov_post_processing=export_conf.get("ov-post-processing", MODE_NONE) + ov_post_processing=export_conf.get("ov-post-processing", MODE_PPISP_OMNI_FALLBACK_NONE) if hasattr(export_conf, "get") - else getattr(export_conf, "ov_post_processing", MODE_NONE), + else getattr(export_conf, "ov_post_processing", MODE_PPISP_OMNI_FALLBACK_NONE), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/writers/__init__.py b/threedgrut/export/usd/writers/__init__.py index e3572bcc..0111460f 100644 --- a/threedgrut/export/usd/writers/__init__.py +++ b/threedgrut/export/usd/writers/__init__.py @@ -21,7 +21,7 @@ - export_cameras_to_usd: one Camera prim per physical camera, animated xforms - create_render_products: /Render scope with per-camera RenderProducts - add_ppisp_to_all_render_products: PPISP SPG shader on RenderProducts -- add_ov_post_processing: Omniverse RTX post-processing PPISP workaround +- add_ov_post_processing: Omniverse USD post-processing PPISP fallback """ from threedgrut.export.usd.writers.background import export_background_to_usd diff --git a/threedgrut/export/usd/writers/ov_post_processing.py b/threedgrut/export/usd/writers/ov_post_processing.py index 4c4fd307..609cfe81 100644 --- a/threedgrut/export/usd/writers/ov_post_processing.py +++ b/threedgrut/export/usd/writers/ov_post_processing.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Omniverse RTX post-processing workaround writer for PPISP exports. +"""Omniverse USD post-processing fallback writer for PPISP exports. This writer is a degraded fallback for Kit versions where SPG is unavailable or -unreliable. It authors standard Omniverse RTX render settings only; exact PPISP -export remains the SPG path. +unreliable. It authors Omniverse USD render settings only; exact PPISP export +remains the SPG path. """ from __future__ import annotations @@ -32,16 +32,16 @@ log = logging.getLogger(__name__) -MODE_NONE = "none" -MODE_PPISP_EXPOSURE_WAR = "ppisp-exposure-war" -MODE_PPISP_APPROX_WAR = "ppisp-approx-war" -MODE_PPISP_HYBRID_WAR = "ppisp-hybrid-war" +MODE_PPISP_OMNI_FALLBACK_NONE = "none" +MODE_PPISP_OMNI_FALLBACK_EXPOSURE = "ppisp-exposure-fallback" +MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING = "ppisp-fitted-post-processing-fallback" +MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING = "ppisp-spg-plus-fitted-post-processing-fallback" -OV_POST_PROCESSING_MODES = { - MODE_NONE, - MODE_PPISP_EXPOSURE_WAR, - MODE_PPISP_APPROX_WAR, - MODE_PPISP_HYBRID_WAR, +PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES = { + MODE_PPISP_OMNI_FALLBACK_NONE, + MODE_PPISP_OMNI_FALLBACK_EXPOSURE, + MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING, + MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, } _BASE_EXPOSURE_TIME_SECONDS = 0.02 @@ -65,11 +65,11 @@ def normalize_ov_post_processing_mode(mode: str | None) -> str: """Normalize and validate the ``export_usd.ov-post-processing`` value.""" - normalized = MODE_NONE if mode is None else str(mode).strip().lower() - if normalized not in OV_POST_PROCESSING_MODES: + normalized = MODE_PPISP_OMNI_FALLBACK_NONE if mode is None else str(mode).strip().lower() + if normalized not in PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES: raise ValueError( f"Unsupported ov-post-processing mode '{mode}'. " - f"Expected one of: {sorted(OV_POST_PROCESSING_MODES)}" + f"Expected one of: {sorted(PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES)}" ) return normalized @@ -160,7 +160,7 @@ def _apply_color_homography(rgb: np.ndarray, homography: np.ndarray) -> np.ndarr def _fit_grade_gain(color_latent: np.ndarray) -> np.ndarray: - """Fit RTX color grade gain to PPISP's cross-channel homography.""" + """Fit USD color grade gain to PPISP's cross-channel homography.""" homography = _compute_homography(color_latent) values = np.linspace(0.05, 1.0, 5, dtype=np.float64) rgb = np.array(np.meshgrid(values, values, values), dtype=np.float64).T.reshape(-1, 3) @@ -195,7 +195,7 @@ def _apply_crf(x: np.ndarray, raw_params: np.ndarray) -> np.ndarray: def _fit_grade_gamma(crf_params: np.ndarray) -> np.ndarray: - """Fit RTX grade gamma to PPISP's per-channel CRF.""" + """Fit USD grade gamma to PPISP's per-channel CRF.""" x = np.linspace(0.02, 0.98, 96, dtype=np.float64) candidates = np.linspace(0.25, 4.0, 128, dtype=np.float64) result = [] @@ -334,9 +334,9 @@ def add_ov_post_processing( mode: str, render_scope_path: str = "/Render", ) -> None: - """Author Omniverse RTX post-processing settings for PPISP WAR export.""" + """Author Omniverse USD post-processing settings for PPISP fallback export.""" normalized_mode = normalize_ov_post_processing_mode(mode) - if normalized_mode == MODE_NONE: + if normalized_mode == MODE_PPISP_OMNI_FALLBACK_NONE: return exposure_params = _as_numpy(post_processing.exposure_params) @@ -345,7 +345,10 @@ def add_ov_post_processing( crf_params = _as_numpy(post_processing.crf_params) camera_name_to_index = {name: idx for idx, name in enumerate(camera_names)} - approximate_full_ppisp = normalized_mode in {MODE_PPISP_APPROX_WAR, MODE_PPISP_HYBRID_WAR} + writes_fitted_post_processing = normalized_mode in { + MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING, + MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, + } for camera_name in camera_names: frame_indices = camera_frame_mapping.get(camera_name, []) @@ -356,7 +359,7 @@ def add_ov_post_processing( _author_camera_exposure(stage, camera_path, frame_indices, exposure_params) - if not approximate_full_ppisp: + if not writes_fitted_post_processing: continue camera_index = camera_name_to_index[camera_name] @@ -376,6 +379,7 @@ def add_ov_post_processing( _author_color_grade(render_product, frame_indices, color_params, crf_params[camera_index]) log.warning( - "Authored OV RTX post-processing PPISP workaround mode '%s'. This is approximate and not SPG-fidelity.", + "Authored Omniverse USD post-processing PPISP fallback mode '%s'. " + "This is approximate and not SPG-fidelity.", normalized_mode, ) From a62f7c711a8f660e98f10903791b166a556c7c46 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Sun, 26 Apr 2026 19:34:43 -0400 Subject: [PATCH 08/42] fix(export): compose USDZ scene data from root stage Keep cameras, render products, and PPISP post-processing authored where the packaged root stage composes them, and align standalone USD export with PPISP configuration. Made-with: Cursor --- docs/ppsip-to-rtx-pp-plan.md | 2 +- threedgrut/export/scripts/export_usd.py | 61 +++++++++++++++++-- threedgrut/export/tests/test_export_import.py | 35 +++++++++++ threedgrut/export/usd/exporter.py | 22 ++++--- .../export/usd/writers/ov_post_processing.py | 6 +- .../utils/post_processing_linear_to_srgb.py | 4 +- 6 files changed, 109 insertions(+), 21 deletions(-) diff --git a/docs/ppsip-to-rtx-pp-plan.md b/docs/ppsip-to-rtx-pp-plan.md index a2d0c792..55e36543 100644 --- a/docs/ppsip-to-rtx-pp-plan.md +++ b/docs/ppsip-to-rtx-pp-plan.md @@ -607,7 +607,7 @@ Rationale: - Which Kit versions need the USD fallback because SPG is unavailable? - Which known SPG bugs should trigger or recommend the USD fallback path? -- What error threshold is acceptable for a degraded workaround export? +- What error threshold is acceptable for a degraded fallback export? - Should the approximation target linear floating-point `LdrColor`, gamma output, or Kit viewport screenshots? - Are time-sampled `RenderProduct` post-processing attributes supported and diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 9cff811a..96a46a45 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -21,10 +21,12 @@ python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz # Export with NuRec format (Omniverse compatibility) - python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz --format nurec + python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt \ + --output output.usdz --format nurec # Export without cameras/background - python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt --output output.usdz --no-cameras --no-background + python -m threedgrut.export.scripts.export_usd --checkpoint path/to/checkpoint.pt \ + --output output.usdz --no-cameras --no-background """ import argparse @@ -147,6 +149,48 @@ def parse_args(): return parser.parse_args() +def _load_ppisp_from_checkpoint(checkpoint, conf): + """Load trained PPISP state for USD export when available.""" + post_conf = getattr(conf, "post_processing", None) + if "post_processing" not in checkpoint or post_conf is None or getattr(post_conf, "method", None) != "ppisp": + return None + + try: + from ppisp import PPISP, PPISPConfig + except ImportError: + logger.warning("Checkpoint contains PPISP state, but ppisp is not available; skipping PPISP USD export") + return None + + use_controller = post_conf.get("use_controller", True) + n_distillation_steps = post_conf.get("n_distillation_steps", 5000) + if use_controller and n_distillation_steps > 0: + main_training_steps = conf.n_iterations - n_distillation_steps + controller_activation_ratio = main_training_steps / conf.n_iterations + controller_distillation = True + elif use_controller: + controller_activation_ratio = 0.8 + controller_distillation = False + else: + controller_activation_ratio = 0.0 + controller_distillation = False + + ppisp_config = PPISPConfig( + use_controller=use_controller, + controller_distillation=controller_distillation, + controller_activation_ratio=controller_activation_ratio, + ) + post_processing = PPISP.from_state_dict(checkpoint["post_processing"]["module"], config=ppisp_config) + post_processing = post_processing.to("cpu") + logger.info("Loaded PPISP post-processing state for USD export") + return post_processing + + +def _get_export_conf_value(export_conf, dashed_name: str, attr_name: str, default): + if hasattr(export_conf, "get"): + return export_conf.get(dashed_name, getattr(export_conf, attr_name, default)) + return getattr(export_conf, attr_name, default) + + def load_model_from_checkpoint(checkpoint_path: str): """Load a 3DGRUT model from checkpoint.""" from threedgrut.model.model import MixtureOfGaussians @@ -169,7 +213,8 @@ def load_model_from_checkpoint(checkpoint_path: str): model.init_from_checkpoint(checkpoint, setup_optimizer=False) model.eval() - return model, conf, model.background + post_processing = _load_ppisp_from_checkpoint(checkpoint, conf) + return model, conf, model.background, post_processing def main(): @@ -189,7 +234,7 @@ def main(): # Load model from checkpoint try: - model, conf, background = load_model_from_checkpoint(str(checkpoint_path)) + model, conf, background, post_processing = load_model_from_checkpoint(str(checkpoint_path)) logger.info(f"Loaded model with {model.get_positions().shape[0]} Gaussians") except ImportError: logger.error("Failed to import model class. Is 3DGRUT properly installed?") @@ -236,13 +281,18 @@ def main(): else: half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half + export_conf = getattr(conf, "export_usd", None) or conf exporter = USDExporter( half_geometry=half_geometry, half_features=half_features, export_cameras=not args.no_cameras, export_background=not args.no_background, apply_normalizing_transform=not args.no_transform, - linear_srgb=args.linear_srgb, + sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), + linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), + export_ppisp=getattr(export_conf, "export_ppisp", False), + ov_post_processing=_get_export_conf_value(export_conf, "ov-post-processing", "ov_post_processing", "none"), + frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) logger.info("Using ParticleField3DGaussianSplat schema (standard)") @@ -257,6 +307,7 @@ def main(): dataset=dataset, conf=conf, background=background, + post_processing=post_processing, **export_kw, ) logger.info(f"Export successful: {output_path}") diff --git a/threedgrut/export/tests/test_export_import.py b/threedgrut/export/tests/test_export_import.py index 0beff9a8..b4ec3613 100644 --- a/threedgrut/export/tests/test_export_import.py +++ b/threedgrut/export/tests/test_export_import.py @@ -111,6 +111,24 @@ def get_features_specular(self) -> torch.Tensor: return self._specular +class MockCameraDataset: + """Minimal dataset exposing camera poses for USD camera export tests.""" + + def __len__(self) -> int: + return 2 + + def get_poses(self) -> np.ndarray: + poses = np.repeat(np.eye(4, dtype=np.float64)[None, :, :], len(self), axis=0) + poses[1, 0, 3] = 1.0 + return poses + + def get_camera_names(self): + return ["camera_0000"] + + def get_camera_idx(self, frame_idx: int) -> int: + return 0 + + class TestPLYExportImport: """Test PLY export from ExportableModel and import back.""" @@ -472,6 +490,23 @@ def test_usd_export_color_space_from_config(self): api = Usd.ColorSpaceAPI(prim) assert api.GetColorSpaceNameAttr().Get() == "lin_rec709_scene" + def test_usdz_export_camera_is_composed_from_root_stage(self): + """USDZ camera prims are authored where the package root composes them.""" + model = MockGaussianModel(num_gaussians=5, sh_degree=3) + dataset = MockCameraDataset() + with tempfile.TemporaryDirectory() as tmpdir: + usd_path = Path(tmpdir) / "test.usdz" + USDExporter( + half_precision=False, + export_cameras=True, + export_background=False, + apply_normalizing_transform=False, + ).export(model, usd_path, dataset=dataset) + stage = Usd.Stage.Open(str(usd_path)) + assert stage + assert stage.GetPrimAtPath("/World/Cameras/camera_0000").IsValid() + assert not stage.GetPrimAtPath("/World/gaussians/Cameras/camera_0000").IsValid() + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index c60b3857..8ca6c481 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -295,6 +295,7 @@ def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> Named Create a default.usda that references the data stages. """ stage = initialize_usd_stage(up_axis="Y") + stage.SetTimeCodesPerSecond(self.frames_per_second) for ref_stage in referenced_stages: filename_stem = Path(ref_stage.filename).stem @@ -391,6 +392,7 @@ def export( default_stage_wrapped: Optional[NamedUSDStage] = None if package_as_usdz: default_stage_wrapped = self._create_default_stage([gaussians_stage]) + scene_stage = default_stage_wrapped.stage if default_stage_wrapped is not None else stage files: List[NamedSerialized] = [] @@ -405,7 +407,7 @@ def export( logger.warning("Could not open source USD for prim merge: %s", stage_path) else: skip = kwargs.get("copy_source_skip_subtrees") - merge_target = default_stage_wrapped.stage if default_stage_wrapped is not None else stage + merge_target = scene_stage merge_source_world_at_same_paths(merge_target, src_stage, skip_source_subtrees=skip) merge_source_prim_at_same_path(merge_target, src_stage, "/Render") copy_authored_time_settings_from_source(src_stage, merge_target) @@ -447,7 +449,7 @@ def export( logger.warning("Could not extract camera intrinsics from dataset, using default") camera_prim_paths = export_cameras_to_usd( - stage=stage, + stage=scene_stage, poses=poses, camera_names=camera_names, frame_to_camera=frame_to_camera, @@ -464,7 +466,7 @@ def export( if self.export_background and background is not None: try: _, envmap_bytes = export_background_to_usd( - stage=stage, + stage=scene_stage, background=background, conf=conf, root_path="/World/Environment", @@ -487,7 +489,7 @@ def export( needs_ppisp_render_products = self.export_ppisp if needs_ppisp_render_products: render_product_entries = self._create_ppisp_render_products( - stage=stage, + stage=scene_stage, dataset=dataset, camera_names=camera_names, frame_to_camera=frame_to_camera, @@ -498,17 +500,17 @@ def export( # Export PPISP as SPG shaders on RenderProducts if export_spg_ppisp and render_product_entries is not None: self._export_ppisp( - stage=stage, + stage=scene_stage, dataset=dataset, camera_names=camera_names, post_processing=kwargs.get("post_processing"), files=files, ) - # Export fitted PPISP approximation as Omniverse USD post-processing settings + # Export PPISP through fitted Omniverse USD post-processing settings. if export_omni_ppisp_fallback and render_product_entries is not None: self._export_ov_post_processing( - stage=stage, + stage=scene_stage, camera_names=camera_names, camera_prim_paths=camera_prim_paths, render_product_entries=render_product_entries, @@ -552,7 +554,7 @@ def _create_ppisp_render_products( camera_prim_paths: Dict[str, str], camera_params, ): - """Create /Render RenderProducts shared by SPG and OV PPISP exports.""" + """Create /Render RenderProducts shared by SPG and Omniverse fallback PPISP exports.""" if dataset is None or not camera_prim_paths: logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") return None @@ -635,7 +637,7 @@ def _export_ov_post_processing( try: from ppisp import PPISP # type: ignore[import-not-found] except ImportError: - logger.warning("ppisp package not available, skipping OV post-processing export") + logger.warning("ppisp package not available, skipping Omniverse post-processing fallback export") return if not isinstance(post_processing, PPISP): @@ -660,7 +662,7 @@ def _export_ov_post_processing( mode=self.ov_post_processing, ) except Exception as e: - logger.warning(f"Failed to add OV post-processing fallback: {e}") + logger.warning(f"Failed to add Omniverse post-processing fallback: {e}") @classmethod def from_config(cls, conf) -> "USDExporter": diff --git a/threedgrut/export/usd/writers/ov_post_processing.py b/threedgrut/export/usd/writers/ov_post_processing.py index 609cfe81..977ec640 100644 --- a/threedgrut/export/usd/writers/ov_post_processing.py +++ b/threedgrut/export/usd/writers/ov_post_processing.py @@ -263,7 +263,7 @@ def _author_camera_exposure( ) -> None: camera_prim = stage.GetPrimAtPath(camera_path) if not camera_prim.IsValid(): - log.warning("Cannot author OV exposure: missing camera prim %s", camera_path) + log.warning("Cannot author Omniverse fallback exposure: missing camera prim %s", camera_path) return _prepend_api_schemas(camera_prim, _CAMERA_EXPOSURE_APIS) @@ -354,7 +354,7 @@ def add_ov_post_processing( frame_indices = camera_frame_mapping.get(camera_name, []) camera_path = camera_prim_paths.get(camera_name) if camera_path is None: - log.warning("Skipping OV post-processing for %s: missing camera prim", camera_name) + log.warning("Skipping Omniverse post-processing fallback for %s: missing camera prim", camera_name) continue _author_camera_exposure(stage, camera_path, frame_indices, exposure_params) @@ -367,7 +367,7 @@ def add_ov_post_processing( render_product_path = f"{render_scope_path}/{render_product_name}" render_product = stage.GetPrimAtPath(render_product_path) if not render_product.IsValid(): - log.warning("Skipping OV post-processing for %s: missing RenderProduct", camera_name) + log.warning("Skipping Omniverse post-processing fallback for %s: missing RenderProduct", camera_name) continue _prepend_api_schemas(render_product, _RENDER_PRODUCT_APIS) diff --git a/threedgrut/utils/post_processing_linear_to_srgb.py b/threedgrut/utils/post_processing_linear_to_srgb.py index 462196f5..f93e160e 100644 --- a/threedgrut/utils/post_processing_linear_to_srgb.py +++ b/threedgrut/utils/post_processing_linear_to_srgb.py @@ -58,10 +58,10 @@ def linear_to_srgb(x: torch.Tensor) -> torch.Tensor: Encoded values, same shape / dtype / device as ``x``. """ limit = 0.0031308 - x = torch.clamp(x, min=1e-8, max=1.0) + positive_x = torch.clamp(x, min=0.0) return torch.where( x > limit, - 1.055 * torch.pow(x, 1.0 / 2.4) - 0.055, + 1.055 * torch.pow(positive_x, 1.0 / 2.4) - 0.055, 12.92 * x, ) From e216e9e223d66887ef78563857f30daf8ffba805 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Mon, 27 Apr 2026 09:38:26 -0400 Subject: [PATCH 09/42] fix(export): make PPISP USD export usable from checkpoints Ensure checkpoint exports carry the PPISP SPG graph, camera timeline, and Gaussian render settings needed for Kit to render the post-processed LdrColor output. Made-with: Cursor --- configs/base_gs.yaml | 4 +- threedgrut/datasets/__init__.py | 67 +++++++++++++++++++ threedgrut/export/scripts/export_usd.py | 8 ++- threedgrut/export/tests/test_export_import.py | 2 + threedgrut/export/usd/exporter.py | 42 ++++++++++-- .../export/usd/ppisp_spg/ppisp_usd_spg.slang | 2 +- .../usd/ppisp_spg/ppisp_usd_spg.slang.lua | 19 ++---- threedgrut/export/usd/writers/camera.py | 8 +++ threedgrut/export/usd/writers/ppisp_writer.py | 29 +++++--- .../export/usd/writers/render_product.py | 11 +-- 10 files changed, 153 insertions(+), 39 deletions(-) diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index e0ec8a5a..65b915e9 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -49,8 +49,8 @@ export_usd: sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false - # Enable PPISP post-processing export. ov-post-processing selects the implementation. - export_ppisp: false + # Enable PPISP post-processing export when post_processing.method is "ppisp". + export_ppisp: true # none uses the full SPG path when export_ppisp is true. # Other values use Omniverse USD post-processing fallback modes. # none | ppisp-exposure-fallback | ppisp-fitted-post-processing-fallback | ppisp-spg-plus-fitted-post-processing-fallback diff --git a/threedgrut/datasets/__init__.py b/threedgrut/datasets/__init__.py index 64317f59..bd450072 100644 --- a/threedgrut/datasets/__init__.py +++ b/threedgrut/datasets/__init__.py @@ -184,6 +184,73 @@ def make(name: str, config, ray_jitter): return train_dataset, val_dataset +def make_train(name: str, config, ray_jitter=None): + match name: + case "nerf": + dataset = NeRFDataset( + config.path, + split="train", + bg_color=config.model.background.color, + ray_jitter=ray_jitter, + ) + case "colmap": + # Load EXIF exposure data if enabled + if config.dataset.get("load_exif", True): + exif_exposures = _load_colmap_exif_exposures( + config.path, + config.dataset.downsample_factor, + ) + else: + exif_exposures = None + + dataset = ColmapDataset( + config.path, + split="train", + downsample_factor=config.dataset.downsample_factor, + test_split_interval=config.dataset.test_split_interval, + ray_jitter=ray_jitter, + exif_exposures=exif_exposures, + ) + case "scannetpp": + dataset = ScannetppDataset( + config.path, + split="train", + ray_jitter=ray_jitter, + downsample_factor=config.dataset.downsample_factor, + test_split_interval=config.dataset.test_split_interval, + ) + case "ncore": + dataset = NCoreDataset( + datapath=config.path, + device="cuda", + split="train", + camera_ids=config.dataset.get("camera_ids", None), + lidar_ids=config.dataset.get("lidar_ids", None), + downsample=config.dataset.get("downsample", 1.0), + sample_full_image=config.dataset.train.get("sample_full_image", True), + window_size=config.dataset.train.get("window_size", 256), + n_samples_per_epoch=config.dataset.train.get("n_samples_per_epoch", 1000), + n_train_sample_timepoints=config.dataset.train.get("n_train_sample_timepoints", 1), + n_train_sample_camera_rays=config.dataset.train.get("n_train_sample_camera_rays", 4096), + n_val_image_subsample=config.dataset.get("n_val_image_subsample", 1), + val_frame_interval=config.dataset.get("val_frame_interval", 8), + seek_offset_sec=config.dataset.train.get("seek_offset_sec", 0.0), + duration_sec=config.dataset.train.get("duration_sec", None), + poses_component_group=config.dataset.get("poses_component_group", "default"), + intrinsics_component_group=config.dataset.get("intrinsics_component_group", "default"), + masks_component_group=config.dataset.get("masks_component_group", "default"), + jpeg_backend_cpu=config.dataset.get("jpeg_backend_cpu", "simplejpeg"), + simplejpeg_fastdct=config.dataset.get("simplejpeg_fastdct", False), + simplejpeg_fastupsample=config.dataset.get("simplejpeg_fastupsample", False), + lidar_color_generic_data_name=config.dataset.get("lidar_color_generic_data_name", "rgb"), + ) + case _: + raise ValueError( + f'Unsupported dataset type: {config.dataset.type}. Choose between: ["colmap", "nerf", "scannetpp", "ncore"].' + ) + return dataset + + def make_test(name: str, config): match name: case "nerf": diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 96a46a45..7f0c398c 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -264,15 +264,16 @@ def main(): elif not hasattr(conf, "dataset") or not hasattr(conf.dataset, "type"): logger.warning("No dataset type in checkpoint config. Cannot load dataset for camera export.") else: - dataset = datasets.make_test(name=conf.dataset.type, config=conf) + dataset = datasets.make_train(name=conf.dataset.type, config=conf, ray_jitter=None) split = getattr(dataset, "split", "unknown") logger.info(f"Loaded dataset with {len(dataset)} frames for camera export (split={split})") except Exception as e: - logger.warning(f"Failed to load dataset for camera export: {e}") + logger.error(f"Failed to load dataset for camera export: {e}") if args.verbose: import traceback traceback.print_exc() + sys.exit(1) # Create exporter based on format if args.format == "nurec": @@ -282,6 +283,7 @@ def main(): half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half export_conf = getattr(conf, "export_usd", None) or conf + export_ppisp = bool(getattr(export_conf, "export_ppisp", False) or post_processing is not None) exporter = USDExporter( half_geometry=half_geometry, half_features=half_features, @@ -290,7 +292,7 @@ def main(): apply_normalizing_transform=not args.no_transform, sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), - export_ppisp=getattr(export_conf, "export_ppisp", False), + export_ppisp=export_ppisp, ov_post_processing=_get_export_conf_value(export_conf, "ov-post-processing", "ov_post_processing", "none"), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/tests/test_export_import.py b/threedgrut/export/tests/test_export_import.py index b4ec3613..5c54d349 100644 --- a/threedgrut/export/tests/test_export_import.py +++ b/threedgrut/export/tests/test_export_import.py @@ -506,6 +506,8 @@ def test_usdz_export_camera_is_composed_from_root_stage(self): assert stage assert stage.GetPrimAtPath("/World/Cameras/camera_0000").IsValid() assert not stage.GetPrimAtPath("/World/gaussians/Cameras/camera_0000").IsValid() + assert stage.GetStartTimeCode() == 0.0 + assert stage.GetEndTimeCode() == 1.0 if __name__ == "__main__": diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 8ca6c481..bbba990d 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -64,6 +64,15 @@ logger = logging.getLogger(__name__) +_GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING = "rtx:rtpt:gaussian:skipTonemapping:enabled" + + +def _set_render_setting(stage: Usd.Stage, key: str, value: Any) -> None: + render_settings = dict(stage.GetRootLayer().customLayerData.get("renderSettings", {}) or {}) + render_settings[key] = value + stage.SetMetadataByDictKey("customLayerData", "renderSettings", render_settings) + + def _extract_camera_params_from_dataset(dataset) -> Optional[List]: """ Extract per-frame camera parameters from a dataset. @@ -92,7 +101,7 @@ def _extract_camera_params_from_dataset(dataset) -> Optional[List]: camera_params.append(None) continue - params_dict, _, _, camera_name = params_tuple + params_dict, _, _, camera_name, *_ = params_tuple # Reconstruct CameraModelParameters from dict if camera_name == "OpenCVPinholeCameraModelParameters": @@ -479,14 +488,18 @@ def export( logger.warning(f"Failed to export background: {e}") render_product_entries = None - export_spg_ppisp = self.export_ppisp and ( + post_processing = kwargs.get("post_processing") + has_ppisp_export_source = self.export_ppisp and post_processing is not None + export_spg_ppisp = has_ppisp_export_source and ( self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE or self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING ) export_omni_ppisp_fallback = ( - self.export_ppisp and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE + has_ppisp_export_source and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE ) - needs_ppisp_render_products = self.export_ppisp + needs_ppisp_render_products = has_ppisp_export_source + if self.export_ppisp and post_processing is None: + logger.info("PPISP export requested but no post_processing module is available; skipping /Render export") if needs_ppisp_render_products: render_product_entries = self._create_ppisp_render_products( stage=scene_stage, @@ -499,11 +512,15 @@ def export( # Export PPISP as SPG shaders on RenderProducts if export_spg_ppisp and render_product_entries is not None: + _set_render_setting(scene_stage, _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING, False) + logger.info( + "Disabled Gaussian skip-tonemapping render setting for PPISP SPG export" + ) self._export_ppisp( stage=scene_stage, dataset=dataset, camera_names=camera_names, - post_processing=kwargs.get("post_processing"), + post_processing=post_processing, files=files, ) @@ -515,7 +532,7 @@ def export( camera_prim_paths=camera_prim_paths, render_product_entries=render_product_entries, dataset=dataset, - post_processing=kwargs.get("post_processing"), + post_processing=post_processing, ) # Package @@ -597,6 +614,19 @@ def _export_ppisp( ) return + ppisp_config = getattr(post_processing, "config", None) + controllers = getattr(post_processing, "controllers", None) + has_controller = ( + bool(getattr(ppisp_config, "use_controller", False)) + and controllers is not None + and len(controllers) > 0 + ) + if has_controller: + logger.warning( + "PPISP controller export is not implemented yet; SPG export uses only " + "stored exposure/color parameters, vignetting, and CRF." + ) + from threedgrut.export.usd.writers.ppisp_writer import ( add_ppisp_to_all_render_products, build_camera_frame_mapping, diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang index 02fc2fc3..b721cb6d 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang @@ -241,5 +241,5 @@ void ppispProcess(uint3 tid : SV_DispatchThreadID) rgb.g = applyCRF(rgb.g, g_Params.crfToeG, g_Params.crfShoulderG, g_Params.crfGammaG, g_Params.crfCenterG); rgb.b = applyCRF(rgb.b, g_Params.crfToeB, g_Params.crfShoulderB, g_Params.crfGammaB, g_Params.crfCenterB); - g_OutTex[tid.xy] = float4(rgb, pixel.a); + g_OutTex[tid.xy] = float4(saturate(rgb), 1.0); } diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua index 5f3c0f48..971be9ea 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua @@ -1,17 +1,5 @@ -- SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -- SPDX-License-Identifier: Apache-2.0 --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. -- PPISP (Physically Plausible Image Signal Processing) SPG Launcher -- @@ -24,10 +12,10 @@ function ppispProcess(inputs, outputs, params) local in_rgba = inputs["HdrColor"] assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") - -- Output texture mirrors input shape and dtype + -- LdrColor expects an RGBA8 output, even when the input is HdrColor. local height = in_rgba.shape[1] local width = in_rgba.shape[2] - outputs["PPISPColor"] = slang.empty({height, width}, in_rgba.dtype) + outputs["PPISPColor"] = slang.empty({height, width}, slang.uchar4) -- Pass params directly to preserve __fullName for shader reflection matching. local function getFloat2(name) @@ -36,6 +24,9 @@ function ppispProcess(inputs, outputs, params) end return slang.dispatch({ + stage = "compute", + numthreads = { 16, 16, 1 }, + grid = { math.ceil(width / 16), math.ceil(height / 16), 1 }, bind = { slang.ParameterBlock( -- Exposure diff --git a/threedgrut/export/usd/writers/camera.py b/threedgrut/export/usd/writers/camera.py index 172c8b84..4234017d 100644 --- a/threedgrut/export/usd/writers/camera.py +++ b/threedgrut/export/usd/writers/camera.py @@ -156,6 +156,8 @@ def export_cameras_to_usd( UsdGeom.Xform.Define(stage, root_path) result: Dict[str, str] = {} + usd_start_time_code = float("inf") + usd_end_time_code = float("-inf") for cam_idx, cam_name in enumerate(camera_names): frame_indices = camera_frames[cam_idx] @@ -190,12 +192,18 @@ def export_cameras_to_usd( for frame_idx in frame_indices: usd_pose = poses[frame_idx] @ _CAMERA_COORD_FLIP transform_op.Set(column_vector_4x4_to_usd_matrix(usd_pose), float(frame_idx)) + usd_start_time_code = min(usd_start_time_code, float(frame_idx)) + usd_end_time_code = max(usd_end_time_code, float(frame_idx)) imageable = UsdGeom.Imageable(camera_prim) imageable.CreateVisibilityAttr().Set("inherited" if visible else "invisible") result[cam_name] = camera_path + if usd_start_time_code <= usd_end_time_code: + stage.SetStartTimeCode(usd_start_time_code) + stage.SetEndTimeCode(usd_end_time_code) + logger.info( f"Exported {len(result)} camera(s) ({len(poses)} total frames) to {root_path}" ) diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index 8947c3ca..59dd8084 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -47,6 +47,7 @@ CHANNEL_SUFFIXES = ["R", "G", "B"] PPISP_SPG_USDA_FILE = "ppisp_usd_spg.slang.usda" +PPISP_SPG_SLANG_FILE = "ppisp_usd_spg.slang" PPISP_INPUT_RENDER_VAR = "HdrColor" PPISP_OUTPUT_RENDER_VAR = "PPISPColor" LDR_COLOR_RENDER_VAR = "LdrColor" @@ -96,12 +97,12 @@ def _add_ldr_color_render_var( ppisp_output_path: Sdf.Path, ) -> str: """Create a LdrColor RenderVar wired to the PPISP output.""" - ldr_var_path = f"{render_product_path}/{LDR_COLOR_RENDER_VAR}" - ldr_var = stage.DefinePrim(ldr_var_path, "RenderVar") - ldr_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(LDR_COLOR_RENDER_VAR) - aov_attr = ldr_var.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque) + render_var_path = f"{render_product_path}/{LDR_COLOR_RENDER_VAR}" + render_var = stage.DefinePrim(render_var_path, "RenderVar") + render_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(LDR_COLOR_RENDER_VAR) + aov_attr = render_var.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) aov_attr.SetConnections([ppisp_output_path]) - return ldr_var_path + return render_var_path def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade.Shader: @@ -118,12 +119,23 @@ def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade. input_var_path = f"{render_product_path}/{PPISP_INPUT_RENDER_VAR}" input_var_prim = stage.GetPrimAtPath(input_var_path) if input_var_prim.IsValid(): - input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque) + input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) # PPISP Shader prim referencing the SPG asset definition ppisp_shader_path = f"{render_product_path}/PPISP" shader = UsdShade.Shader.Define(stage, ppisp_shader_path) shader.GetPrim().GetReferences().AddReference(PPISP_SPG_USDA_FILE) + # Duplicate the source metadata on the instance. Some Kit SPG/Fabric paths + # do not resolve referenced shader metadata when opening packaged USDZ files. + shader.GetPrim().CreateAttribute("info:implementationSource", Sdf.ValueTypeNames.Token, custom=False).Set( + "sourceAsset" + ) + shader.GetPrim().CreateAttribute("info:spg:sourceAsset", Sdf.ValueTypeNames.Asset, custom=False).Set( + Sdf.AssetPath(PPISP_SPG_SLANG_FILE) + ) + shader.GetPrim().CreateAttribute("info:spg:sourceAsset:subIdentifier", Sdf.ValueTypeNames.Token, custom=False).Set( + "ppispProcess" + ) # HdrColor opaque input wired to the input RenderVar's AOV hdr_input = shader.CreateInput(PPISP_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) @@ -132,7 +144,8 @@ def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade. # PPISPColor opaque output shader.CreateOutput(PPISP_OUTPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) - # LdrColor RenderVar connected to the output + # LdrColor RenderVar connected to the PPISP output. This intentionally + # replaces the display AOV with PPISP's LDR output. ppisp_output_path = shader.GetPath().AppendProperty(f"outputs:{PPISP_OUTPUT_RENDER_VAR}") ldr_var_path = _add_ldr_color_render_var(stage, render_product_path, ppisp_output_path) @@ -140,7 +153,7 @@ def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade. ordered_vars_rel = render_product.GetRelationship("orderedVars") if ordered_vars_rel: targets = list(ordered_vars_rel.GetTargets()) - targets.append(Sdf.Path(ldr_var_path)) + targets.append(Sdf.Path(LDR_COLOR_RENDER_VAR)) ordered_vars_rel.SetTargets(targets) return shader diff --git a/threedgrut/export/usd/writers/render_product.py b/threedgrut/export/usd/writers/render_product.py index 760ec1ab..89d114e4 100644 --- a/threedgrut/export/usd/writers/render_product.py +++ b/threedgrut/export/usd/writers/render_product.py @@ -24,7 +24,7 @@ import logging from typing import Dict, Tuple -from pxr import Sdf, Usd, UsdGeom +from pxr import Gf, Sdf, Usd, UsdGeom log = logging.getLogger(__name__) @@ -63,9 +63,9 @@ def create_render_products( product_prim = stage.DefinePrim(product_path, "RenderProduct") # Resolution - product_prim.CreateAttribute( - "resolution", Sdf.ValueTypeNames.Int2 - ).Set((width, height)) + product_prim.CreateAttribute("resolution", Sdf.ValueTypeNames.Int2).Set( + Gf.Vec2i(int(width), int(height)) + ) # Camera relationship camera_rel = product_prim.CreateRelationship("camera") @@ -75,10 +75,11 @@ def create_render_products( hdr_var_path = f"{product_path}/{_HDR_COLOR_VAR}" hdr_var = stage.DefinePrim(hdr_var_path, "RenderVar") hdr_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(_HDR_COLOR_VAR) + hdr_var.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) # orderedVars relationship ordered_vars_rel = product_prim.CreateRelationship("orderedVars") - ordered_vars_rel.SetTargets([Sdf.Path(hdr_var_path)]) + ordered_vars_rel.SetTargets([Sdf.Path(_HDR_COLOR_VAR)]) log.debug(f"Created RenderProduct at {product_path} → camera {camera_path} ({width}×{height})") From c81bd516b08973cf4495bccf58c4ce457e44c9e4 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Mon, 27 Apr 2026 15:13:17 -0400 Subject: [PATCH 10/42] feat(export): gate Omniverse USD authoring Add an explicit omni-usd opt-in so default USD exports remain neutral while Kit-specific PPISP SPG and ParticleField MDL material authoring are only emitted when requested. Made-with: Cursor --- configs/base_gs.yaml | 3 + threedgrut/export/scripts/export_usd.py | 9 ++- threedgrut/export/usd/exporter.py | 54 ++++++++++++-- threedgrut/export/usd/stage_utils.py | 30 +++++++- threedgrut/export/usd/writers/base.py | 10 +++ threedgrut/export/usd/writers/lightfield.py | 19 ++++- .../export/usd/writers/omni_material.py | 72 +++++++++++++++++++ threedgrut/export/usd/writers/ppisp_writer.py | 26 ++++++- 8 files changed, 212 insertions(+), 11 deletions(-) create mode 100644 threedgrut/export/usd/writers/omni_material.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 65b915e9..108ec317 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -49,7 +49,10 @@ export_usd: sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false + # Enable Omniverse-specific USD authoring, including PPISP SPG and MDL material binding. + omni-usd: false # Enable PPISP post-processing export when post_processing.method is "ppisp". + # Requires omni-usd=true when the checkpoint contains a PPISP module. export_ppisp: true # none uses the full SPG path when export_ppisp is true. # Other values use Omniverse USD post-processing fallback modes. diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 7f0c398c..e1683597 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -124,6 +124,11 @@ def parse_args(): action="store_true", help="Set prim color space to lin_rec709_scene (linear). Default is srgb_rec709_display.", ) + parser.add_argument( + "--omni-usd", + action="store_true", + help="Enable Omniverse-specific USD authoring such as PPISP SPG and MDL material binding.", + ) # Dataset path (optional, overrides checkpoint's dataset path) parser.add_argument( @@ -283,7 +288,8 @@ def main(): half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half export_conf = getattr(conf, "export_usd", None) or conf - export_ppisp = bool(getattr(export_conf, "export_ppisp", False) or post_processing is not None) + export_ppisp = bool(getattr(export_conf, "export_ppisp", True)) + omni_usd = bool(args.omni_usd or _get_export_conf_value(export_conf, "omni-usd", "omni_usd", False)) exporter = USDExporter( half_geometry=half_geometry, half_features=half_features, @@ -292,6 +298,7 @@ def main(): apply_normalizing_transform=not args.no_transform, sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), + omni_usd=omni_usd, export_ppisp=export_ppisp, ov_post_processing=_get_export_conf_value(export_conf, "ov-post-processing", "ov_post_processing", "none"), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index bbba990d..be13b2e8 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -73,6 +73,20 @@ def _set_render_setting(stage: Usd.Stage, key: str, value: Any) -> None: stage.SetMetadataByDictKey("customLayerData", "renderSettings", render_settings) +def _is_ppisp_post_processing(post_processing: Any) -> bool: + post_processing_type = type(post_processing) + return ( + post_processing_type.__name__ == "PPISP" + and post_processing_type.__module__.split(".", maxsplit=1)[0] == "ppisp" + ) + + +def _get_export_config_value(export_conf, hyphen_name: str, attr_name: str, default: Any) -> Any: + if hasattr(export_conf, "get"): + return export_conf.get(hyphen_name, getattr(export_conf, attr_name, default)) + return getattr(export_conf, attr_name, default) + + def _extract_camera_params_from_dataset(dataset) -> Optional[List]: """ Extract per-frame camera parameters from a dataset. @@ -254,7 +268,8 @@ def __init__( apply_normalizing_transform: bool = True, sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, - export_ppisp: bool = False, + omni_usd: bool = False, + export_ppisp: bool = True, ov_post_processing: str = MODE_PPISP_OMNI_FALLBACK_NONE, frames_per_second: float = 1.0, ): @@ -270,6 +285,8 @@ def __init__( apply_normalizing_transform: Apply transform to normalize scene orientation. sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth"). linear_srgb: If True, set prim color space to lin_rec709_scene. + omni_usd: If True, author Omniverse-specific USD features such as + ParticleFieldEmissive MDL binding and PPISP SPG graphs. export_ppisp: If True, export PPISP using SPG or the selected Omniverse USD fallback mode. Requires post_processing kwarg to be a ppisp.PPISP instance. @@ -289,6 +306,7 @@ def __init__( self.apply_normalizing_transform = apply_normalizing_transform self.sorting_mode_hint = sorting_mode_hint self.linear_srgb = linear_srgb + self.omni_usd = omni_usd self.export_ppisp = export_ppisp self.ov_post_processing = normalize_ov_post_processing_mode(ov_post_processing) if not self.export_ppisp and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: @@ -297,6 +315,12 @@ def __init__( "Set export_ppisp=true to export PPISP through an Omniverse USD fallback, " "or set ov-post-processing=none." ) + if not self.omni_usd and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: + raise ValueError( + "export_usd.ov-post-processing requires export_usd.omni-usd=true. " + "Set omni-usd=true to author Omniverse USD post-processing fallback features, " + "or set ov-post-processing=none." + ) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -343,6 +367,14 @@ def export( """ output_path = Path(output_path) logger.info(f"Exporting USD file to {output_path}...") + post_processing = kwargs.get("post_processing") + has_ppisp_module = _is_ppisp_post_processing(post_processing) + if has_ppisp_module and self.export_ppisp and not self.omni_usd: + raise ValueError( + "PPISP USD export requires export_usd.omni-usd=true because the current PPISP " + "implementation uses Omniverse SPG. Re-run with export_usd.omni-usd=true, " + "or set export_usd.export_ppisp=false to export the model without PPISP effects." + ) # Get model data via accessor accessor = GaussianExportAccessor(model, conf) @@ -389,6 +421,8 @@ def export( half_features=self.half_features, sorting_mode_hint=self.sorting_mode_hint, linear_srgb=self.linear_srgb, + omni_usd=self.omni_usd, + has_post_processing=has_ppisp_module and self.export_ppisp, ) writer.create_prim(attrs.num_gaussians) writer.write_attributes(attrs) @@ -488,7 +522,11 @@ def export( logger.warning(f"Failed to export background: {e}") render_product_entries = None - post_processing = kwargs.get("post_processing") + if not self.export_ppisp and _is_ppisp_post_processing(post_processing): + logger.warning( + "PPISP post-processing module is present but export_usd.export_ppisp=false; " + "PPISP effects will not be exported. Set export_usd.export_ppisp=true to export them." + ) has_ppisp_export_source = self.export_ppisp and post_processing is not None export_spg_ppisp = has_ppisp_export_source and ( self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE @@ -714,9 +752,13 @@ def from_config(cls, conf) -> "USDExporter": apply_normalizing_transform=getattr(export_conf, "apply_normalizing_transform", True), sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), - export_ppisp=getattr(export_conf, "export_ppisp", False), - ov_post_processing=export_conf.get("ov-post-processing", MODE_PPISP_OMNI_FALLBACK_NONE) - if hasattr(export_conf, "get") - else getattr(export_conf, "ov_post_processing", MODE_PPISP_OMNI_FALLBACK_NONE), + omni_usd=_get_export_config_value(export_conf, "omni-usd", "omni_usd", False), + export_ppisp=getattr(export_conf, "export_ppisp", True), + ov_post_processing=_get_export_config_value( + export_conf, + "ov-post-processing", + "ov_post_processing", + MODE_PPISP_OMNI_FALLBACK_NONE, + ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/stage_utils.py b/threedgrut/export/usd/stage_utils.py index 9b70bcfe..e555e9a2 100644 --- a/threedgrut/export/usd/stage_utils.py +++ b/threedgrut/export/usd/stage_utils.py @@ -22,6 +22,7 @@ import logging import os +import struct import tempfile import zipfile from dataclasses import dataclass @@ -38,6 +39,31 @@ # Constants DEFAULT_FRAME_RATE = 24.0 USD_WORLD_PATH = "/World" +_USDZ_ALIGNMENT = 64 +_USDZ_PADDING_EXTRA_ID = 0x1986 + + +def _write_usdz_entry(zip_file: zipfile.ZipFile, filename: str, data: Union[str, bytes]) -> None: + if isinstance(data, str): + data = data.encode("utf-8") + + header_offset = zip_file.fp.tell() + filename_size = len(filename.encode("utf-8")) + unpadded_data_offset = header_offset + 30 + filename_size + padding_size = (-unpadded_data_offset) % _USDZ_ALIGNMENT + + # ZIP extra fields need a 4-byte header. If the needed padding is smaller, + # add one full alignment period and keep the same modulo. + if 0 < padding_size < 4: + padding_size += _USDZ_ALIGNMENT + + zip_info = zipfile.ZipInfo(filename) + zip_info.compress_type = zipfile.ZIP_STORED + if padding_size: + zip_info.extra = struct.pack(" Usd.Stage: diff --git a/threedgrut/export/usd/writers/base.py b/threedgrut/export/usd/writers/base.py index ae06159c..48641df8 100644 --- a/threedgrut/export/usd/writers/base.py +++ b/threedgrut/export/usd/writers/base.py @@ -50,11 +50,15 @@ def __init__( capabilities: ModelCapabilities, content_root_path: str = "/World/Gaussians", linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, ): self.stage = stage self.capabilities = capabilities self.content_root_path = content_root_path self.linear_srgb = linear_srgb + self.omni_usd = omni_usd + self.has_post_processing = has_post_processing self.prim: Optional[Usd.Prim] = None def apply_color_space_to_prim(self, prim: Usd.Prim) -> None: @@ -130,6 +134,8 @@ def create_gaussian_writer( half_features: bool = False, sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, ) -> GaussianUSDWriter: """Factory function to create USD Gaussian writer. @@ -141,6 +147,8 @@ def create_gaussian_writer( half_features: Use half precision for opacities and SH coefficients (LightField) sorting_mode_hint: Sorting mode hint for LightField schema linear_srgb: If True, set prim color space to lin_rec709_scene; else srgb_rec709_display + omni_usd: If True, author Omniverse-specific USD features. + has_post_processing: If True, configure Omniverse material for external post-processing. Returns: Configured GaussianUSDWriter instance (LightField schema) @@ -155,4 +163,6 @@ def create_gaussian_writer( half_features=half_features, sorting_mode_hint=sorting_mode_hint, linear_srgb=linear_srgb, + omni_usd=omni_usd, + has_post_processing=has_post_processing, ) diff --git a/threedgrut/export/usd/writers/lightfield.py b/threedgrut/export/usd/writers/lightfield.py index 4f207ac5..130697b9 100644 --- a/threedgrut/export/usd/writers/lightfield.py +++ b/threedgrut/export/usd/writers/lightfield.py @@ -51,8 +51,17 @@ def __init__( projection_mode_hint: str = "perspective", sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, + omni_usd: bool = False, + has_post_processing: bool = False, ) -> None: - super().__init__(stage, capabilities, content_root_path, linear_srgb=linear_srgb) + super().__init__( + stage, + capabilities, + content_root_path, + linear_srgb=linear_srgb, + omni_usd=omni_usd, + has_post_processing=has_post_processing, + ) self.half_geometry = half_geometry self.half_features = half_features self.projection_mode_hint = projection_mode_hint @@ -91,6 +100,14 @@ def create_prim(self, num_gaussians: int) -> Usd.Prim: self._set_rendering_hints() self.apply_color_space_to_prim(self.prim) + if self.omni_usd: + from threedgrut.export.usd.writers.omni_material import bind_particlefield_emissive_material + + bind_particlefield_emissive_material( + stage=self.stage, + prim=self.prim, + has_post_processing=self.has_post_processing, + ) return self.prim def _apply_surflet_kernel_schemas(self) -> None: diff --git a/threedgrut/export/usd/writers/omni_material.py b/threedgrut/export/usd/writers/omni_material.py new file mode 100644 index 00000000..0c24ad81 --- /dev/null +++ b/threedgrut/export/usd/writers/omni_material.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Omniverse-specific USD material authoring for Gaussian ParticleFields.""" + +from pxr import Sdf, Usd, UsdShade + +USD_LOOKS_PATH = "/World/Looks" +USD_PARTICLEFIELD_MATERIAL_PATH = f"{USD_LOOKS_PATH}/ParticleFieldEmissive" +USD_PARTICLEFIELD_SHADER_PATH = f"{USD_PARTICLEFIELD_MATERIAL_PATH}/Shader" +PARTICLEFIELD_MATERIAL_MDL_FILE = "ParticleFieldEmissive.mdl" +PARTICLEFIELD_MATERIAL_NAME = "ParticleFieldEmissive" + + +def bind_particlefield_emissive_material( + stage: Usd.Stage, + prim: Usd.Prim, + has_post_processing: bool = False, +) -> None: + """Bind Kit's ParticleFieldEmissive MDL material to a Gaussian ParticleField.""" + looks_prim = stage.GetPrimAtPath(USD_LOOKS_PATH) + if not looks_prim.IsValid(): + stage.DefinePrim(USD_LOOKS_PATH, "Scope") + + material_prim = stage.DefinePrim(USD_PARTICLEFIELD_MATERIAL_PATH, "Material") + shader_prim = stage.DefinePrim(USD_PARTICLEFIELD_SHADER_PATH, "Shader") + shader_prim.CreateAttribute( + "info:implementationSource", + Sdf.ValueTypeNames.Token, + custom=False, + variability=Sdf.VariabilityUniform, + ).Set("sourceAsset") + shader_prim.CreateAttribute( + "info:mdl:sourceAsset", + Sdf.ValueTypeNames.Asset, + custom=False, + variability=Sdf.VariabilityUniform, + ).Set(Sdf.AssetPath(PARTICLEFIELD_MATERIAL_MDL_FILE)) + shader_prim.CreateAttribute( + "info:mdl:sourceAsset:subIdentifier", + Sdf.ValueTypeNames.Token, + custom=False, + variability=Sdf.VariabilityUniform, + ).Set(PARTICLEFIELD_MATERIAL_NAME) + + if has_post_processing: + shader_prim.CreateAttribute("inputs:apply_srgb_linear", Sdf.ValueTypeNames.Bool).Set(False) + shader_prim.CreateAttribute("inputs:apply_inverse_tonemap", Sdf.ValueTypeNames.Bool).Set(False) + + output_attr = shader_prim.CreateAttribute("outputs:out", Sdf.ValueTypeNames.Token) + output_attr.SetMetadata("renderType", "material") + + material = UsdShade.Material(material_prim) + shader = UsdShade.Shader(shader_prim) + for output_name in ("mdl:displacement", "mdl:surface", "mdl:volume"): + output = material.CreateOutput(output_name, Sdf.ValueTypeNames.Token) + output.ConnectToSource(shader.GetOutput("out")) + + binding_api = UsdShade.MaterialBindingAPI(prim) + binding_api.Bind(material, bindingStrength=UsdShade.Tokens.weakerThanDescendants) diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index 59dd8084..f8326ec8 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -35,7 +35,7 @@ import numpy as np -from pxr import Gf, Sdf, Usd, UsdShade +from pxr import Gf, Sdf, Usd, UsdGeom, UsdShade if TYPE_CHECKING: from ppisp import PPISP # type: ignore[import-not-found] @@ -51,6 +51,7 @@ PPISP_INPUT_RENDER_VAR = "HdrColor" PPISP_OUTPUT_RENDER_VAR = "PPISPColor" LDR_COLOR_RENDER_VAR = "LdrColor" +PPISP_CAMERA_EXPOSURE = 1.0 # --------------------------------------------------------------------------- @@ -307,6 +308,28 @@ def add_ppisp_shader_to_render_product( return shader.GetPrim() +def _force_ppisp_camera_exposure(stage: Usd.Stage, render_product: Usd.Prim) -> None: + camera_rel = render_product.GetRelationship("camera") + camera_targets = camera_rel.GetTargets() if camera_rel else [] + if not camera_targets: + log.warning( + "RenderProduct %s has no camera target; skipping PPISP camera exposure", + render_product.GetPath(), + ) + return + + camera_prim = stage.GetPrimAtPath(camera_targets[0]) + if not camera_prim.IsValid(): + log.warning( + "RenderProduct %s targets missing camera %s; skipping PPISP camera exposure", + render_product.GetPath(), + camera_targets[0], + ) + return + + UsdGeom.Camera(camera_prim).CreateExposureAttr().Set(PPISP_CAMERA_EXPOSURE) + + # --------------------------------------------------------------------------- # Batch export over all RenderProducts # --------------------------------------------------------------------------- @@ -363,6 +386,7 @@ def add_ppisp_to_all_render_products( continue frame_indices = camera_frame_mapping.get(camera_name, []) + _force_ppisp_camera_exposure(stage, child) shader_prim = add_ppisp_shader_to_render_product( stage=stage, From 5c40e3f1dd88f0ddefba38b6c6b37540dc1552b5 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Mon, 27 Apr 2026 16:28:00 -0400 Subject: [PATCH 11/42] fix(export): refine PPISP Omniverse export behavior Default checkpoint exports to include loaded PPISP modules unless explicitly disabled, and scope neutral camera exposure to a hidden RenderProduct-local camera for PPISP SPG. Made-with: Cursor --- threedgrut/export/scripts/export_usd.py | 21 ++++++++++- threedgrut/export/usd/exporter.py | 3 +- threedgrut/export/usd/writers/ppisp_writer.py | 35 ++++++++++++++----- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index e1683597..01995caf 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -129,6 +129,20 @@ def parse_args(): action="store_true", help="Enable Omniverse-specific USD authoring such as PPISP SPG and MDL material binding.", ) + ppisp_group = parser.add_mutually_exclusive_group() + ppisp_group.add_argument( + "--export-ppisp", + dest="export_ppisp", + action="store_true", + default=None, + help="Export PPISP effects when the checkpoint contains a PPISP module.", + ) + ppisp_group.add_argument( + "--no-export-ppisp", + dest="export_ppisp", + action="store_false", + help="Skip PPISP export even when the checkpoint contains a PPISP module.", + ) # Dataset path (optional, overrides checkpoint's dataset path) parser.add_argument( @@ -288,7 +302,12 @@ def main(): half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half export_conf = getattr(conf, "export_usd", None) or conf - export_ppisp = bool(getattr(export_conf, "export_ppisp", True)) + if args.export_ppisp is not None: + export_ppisp = args.export_ppisp + elif post_processing is not None: + export_ppisp = True + else: + export_ppisp = bool(getattr(export_conf, "export_ppisp", True)) omni_usd = bool(args.omni_usd or _get_export_conf_value(export_conf, "omni-usd", "omni_usd", False)) exporter = USDExporter( half_geometry=half_geometry, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index be13b2e8..4ca96cc0 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -373,7 +373,8 @@ def export( raise ValueError( "PPISP USD export requires export_usd.omni-usd=true because the current PPISP " "implementation uses Omniverse SPG. Re-run with export_usd.omni-usd=true, " - "or set export_usd.export_ppisp=false to export the model without PPISP effects." + "or set export_usd.export_ppisp=false / pass --no-export-ppisp to export the " + "model without PPISP effects." ) # Get model data via accessor diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index f8326ec8..6c05d555 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -51,7 +51,11 @@ PPISP_INPUT_RENDER_VAR = "HdrColor" PPISP_OUTPUT_RENDER_VAR = "PPISPColor" LDR_COLOR_RENDER_VAR = "LdrColor" -PPISP_CAMERA_EXPOSURE = 1.0 +PPISP_CAMERA_EXPOSURE = 0.0 +PPISP_CAMERA_EXPOSURE_FSTOP = 1.0 +PPISP_CAMERA_EXPOSURE_ISO = 100.0 +PPISP_CAMERA_EXPOSURE_RESPONSIVITY = 1.0 +PPISP_CAMERA_EXPOSURE_TIME = 1.0 # --------------------------------------------------------------------------- @@ -308,26 +312,39 @@ def add_ppisp_shader_to_render_product( return shader.GetPrim() -def _force_ppisp_camera_exposure(stage: Usd.Stage, render_product: Usd.Prim) -> None: +def _create_ppisp_camera(stage: Usd.Stage, render_product: Usd.Prim) -> None: camera_rel = render_product.GetRelationship("camera") camera_targets = camera_rel.GetTargets() if camera_rel else [] if not camera_targets: log.warning( - "RenderProduct %s has no camera target; skipping PPISP camera exposure", + "RenderProduct %s has no camera target; skipping PPISP camera override", render_product.GetPath(), ) return - camera_prim = stage.GetPrimAtPath(camera_targets[0]) - if not camera_prim.IsValid(): + source_camera_path = camera_targets[0] + source_camera_prim = stage.GetPrimAtPath(source_camera_path) + if not source_camera_prim.IsValid(): log.warning( - "RenderProduct %s targets missing camera %s; skipping PPISP camera exposure", + "RenderProduct %s targets missing camera %s; skipping PPISP camera override", render_product.GetPath(), - camera_targets[0], + source_camera_path, ) return - UsdGeom.Camera(camera_prim).CreateExposureAttr().Set(PPISP_CAMERA_EXPOSURE) + ppisp_camera_path = render_product.GetPath().AppendChild(f"{source_camera_path.name}_no_isp") + ppisp_camera_prim = stage.DefinePrim(ppisp_camera_path, "Camera") + ppisp_camera_prim.SetHidden(True) + UsdGeom.Imageable(ppisp_camera_prim).CreateVisibilityAttr().Set("invisible") + ppisp_camera_prim.GetInherits().AddInherit(source_camera_path) + ppisp_camera_prim.CreateAttribute("exposure", Sdf.ValueTypeNames.Float).Set(PPISP_CAMERA_EXPOSURE) + ppisp_camera_prim.CreateAttribute("exposure:fStop", Sdf.ValueTypeNames.Float).Set(PPISP_CAMERA_EXPOSURE_FSTOP) + ppisp_camera_prim.CreateAttribute("exposure:iso", Sdf.ValueTypeNames.Float).Set(PPISP_CAMERA_EXPOSURE_ISO) + ppisp_camera_prim.CreateAttribute("exposure:responsivity", Sdf.ValueTypeNames.Float).Set( + PPISP_CAMERA_EXPOSURE_RESPONSIVITY + ) + ppisp_camera_prim.CreateAttribute("exposure:time", Sdf.ValueTypeNames.Float).Set(PPISP_CAMERA_EXPOSURE_TIME) + camera_rel.SetTargets([ppisp_camera_path]) # --------------------------------------------------------------------------- @@ -386,7 +403,7 @@ def add_ppisp_to_all_render_products( continue frame_indices = camera_frame_mapping.get(camera_name, []) - _force_ppisp_camera_exposure(stage, child) + _create_ppisp_camera(stage, child) shader_prim = add_ppisp_shader_to_render_product( stage=stage, From f1a23273d28fe872750fe119faeb694ec3ee1bfe Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 28 Apr 2026 14:47:11 -0400 Subject: [PATCH 12/42] feat(export): bake post-processing into SH Add a generic post-processing SH bake path with a PPISP adapter so standard USD exports can bake a fixed post-processing transform into Gaussian coefficients. --- configs/base_gs.yaml | 22 +- threedgrut/export/scripts/export_usd.py | 142 ++++++- .../post_processing_sh_bake_validation.py | 400 ++++++++++++++++++ threedgrut/export/usd/exporter.py | 167 ++++++-- .../export/usd/post_processing_sh_bake.py | 315 ++++++++++++++ 5 files changed, 996 insertions(+), 50 deletions(-) create mode 100644 threedgrut/export/scripts/post_processing_sh_bake_validation.py create mode 100644 threedgrut/export/usd/post_processing_sh_bake.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 108ec317..568bb42a 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -51,11 +51,23 @@ export_usd: linear_srgb: false # Enable Omniverse-specific USD authoring, including PPISP SPG and MDL material binding. omni-usd: false - # Enable PPISP post-processing export when post_processing.method is "ppisp". - # Requires omni-usd=true when the checkpoint contains a PPISP module. - export_ppisp: true - # none uses the full SPG path when export_ppisp is true. - # Other values use Omniverse USD post-processing fallback modes. + # Enable post-processing export when the checkpoint contains a supported module. + # Defaults to true; post-processing-export-mode controls how the effect is exported. + export_post_processing: true + # baked-sh fits a fixed post-processing transform into Gaussian SH coefficients. + # native uses the module-specific native path; PPISP native export requires omni-usd=true. + # baked-sh | native + post-processing-export-mode: baked-sh + # Number of sequential passes over the train/reference set used for fitting. + post-processing-bake-epochs: 1 + post-processing-bake-learning-rate: 0.001 + post-processing-bake-camera-id: 0 + post-processing-bake-frame-id: 0 + # none: disable PPISP vignetting during bake. achromatic-fit: chromatic PPISP reference + # with an achromatic fit-only vignette; the achromatic vignette is not exported. + # none | achromatic-fit + ppisp-bake-vignetting-mode: achromatic-fit + # Omniverse USD post-processing fallback modes. These require omni-usd=true. # none | ppisp-exposure-fallback | ppisp-fitted-post-processing-fallback | ppisp-spg-plus-fitted-post-processing-fallback ov-post-processing: none # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 01995caf..5f661b7f 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -129,19 +129,60 @@ def parse_args(): action="store_true", help="Enable Omniverse-specific USD authoring such as PPISP SPG and MDL material binding.", ) - ppisp_group = parser.add_mutually_exclusive_group() - ppisp_group.add_argument( - "--export-ppisp", - dest="export_ppisp", + post_processing_group = parser.add_mutually_exclusive_group() + post_processing_group.add_argument( + "--export-post-processing", + dest="export_post_processing", action="store_true", default=None, - help="Export PPISP effects when the checkpoint contains a PPISP module.", + help="Export post-processing effects when the checkpoint contains a supported post-processing module.", ) - ppisp_group.add_argument( - "--no-export-ppisp", - dest="export_ppisp", + post_processing_group.add_argument( + "--no-export-post-processing", + dest="export_post_processing", action="store_false", - help="Skip PPISP export even when the checkpoint contains a PPISP module.", + help="Skip post-processing export even when the checkpoint contains a supported post-processing module.", + ) + parser.add_argument( + "--post-processing-export-mode", + type=str, + choices=["baked-sh", "native"], + default=None, + help="Post-processing export mode. Default is baked-sh when post-processing export is enabled.", + ) + parser.add_argument( + "--post-processing-bake-epochs", + type=int, + default=None, + help="Number of sequential passes over the train/reference set for post-processing baked-SH export.", + ) + parser.add_argument( + "--post-processing-bake-learning-rate", + type=float, + default=None, + help="Adam learning rate for post-processing baked-SH export.", + ) + parser.add_argument( + "--post-processing-bake-camera-id", + type=int, + default=None, + help="Camera id used by the fixed post-processing baked-SH export.", + ) + parser.add_argument( + "--post-processing-bake-frame-id", + type=int, + default=None, + help="Frame id used by the fixed post-processing baked-SH export.", + ) + parser.add_argument( + "--ppisp-bake-vignetting-mode", + type=str, + choices=["none", "achromatic-fit"], + default=None, + help=( + "Vignetting handling for PPISP baked-SH fitting. 'none' disables PPISP vignetting; " + "'achromatic-fit' uses chromatic PPISP reference and an achromatic fit-only vignette." + ), ) # Dataset path (optional, overrides checkpoint's dataset path) @@ -210,6 +251,21 @@ def _get_export_conf_value(export_conf, dashed_name: str, attr_name: str, defaul return getattr(export_conf, attr_name, default) +def _get_export_post_processing_default(export_conf): + if hasattr(export_conf, "get"): + return export_conf.get( + "export-post-processing", + getattr(export_conf, "export_post_processing", True), + ) + return getattr(export_conf, "export_post_processing", True) + + +def _arg_or_conf(cli_value, export_conf, dashed_name: str, attr_name: str, default): + if cli_value is not None: + return cli_value + return _get_export_conf_value(export_conf, dashed_name, attr_name, default) + + def load_model_from_checkpoint(checkpoint_path: str): """Load a 3DGRUT model from checkpoint.""" from threedgrut.model.model import MixtureOfGaussians @@ -266,9 +322,28 @@ def main(): traceback.print_exc() sys.exit(1) - # Load dataset for camera export + export_conf = getattr(conf, "export_usd", None) or conf + if args.export_post_processing is not None: + export_post_processing = args.export_post_processing + elif post_processing is not None: + export_post_processing = True + else: + export_post_processing = bool(_get_export_post_processing_default(export_conf)) + post_processing_export_mode = _arg_or_conf( + args.post_processing_export_mode, + export_conf, + "post-processing-export-mode", + "post_processing_export_mode", + "baked-sh", + ) + + # Load dataset for camera export and for train-split post-processing SH baking. dataset = None - if not args.no_cameras: + needs_dataset = ( + not args.no_cameras + or (post_processing is not None and export_post_processing and post_processing_export_mode == "baked-sh") + ) + if needs_dataset: try: import threedgrut.datasets as datasets @@ -301,13 +376,6 @@ def main(): else: half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half - export_conf = getattr(conf, "export_usd", None) or conf - if args.export_ppisp is not None: - export_ppisp = args.export_ppisp - elif post_processing is not None: - export_ppisp = True - else: - export_ppisp = bool(getattr(export_conf, "export_ppisp", True)) omni_usd = bool(args.omni_usd or _get_export_conf_value(export_conf, "omni-usd", "omni_usd", False)) exporter = USDExporter( half_geometry=half_geometry, @@ -318,7 +386,43 @@ def main(): sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), omni_usd=omni_usd, - export_ppisp=export_ppisp, + export_post_processing=export_post_processing, + post_processing_export_mode=post_processing_export_mode, + post_processing_bake_epochs=_arg_or_conf( + args.post_processing_bake_epochs, + export_conf, + "post-processing-bake-epochs", + "post_processing_bake_epochs", + 1, + ), + post_processing_bake_learning_rate=_arg_or_conf( + args.post_processing_bake_learning_rate, + export_conf, + "post-processing-bake-learning-rate", + "post_processing_bake_learning_rate", + 1.0e-3, + ), + post_processing_bake_camera_id=_arg_or_conf( + args.post_processing_bake_camera_id, + export_conf, + "post-processing-bake-camera-id", + "post_processing_bake_camera_id", + 0, + ), + post_processing_bake_frame_id=_arg_or_conf( + args.post_processing_bake_frame_id, + export_conf, + "post-processing-bake-frame-id", + "post_processing_bake_frame_id", + 0, + ), + ppisp_bake_vignetting_mode=_arg_or_conf( + args.ppisp_bake_vignetting_mode, + export_conf, + "ppisp-bake-vignetting-mode", + "ppisp_bake_vignetting_mode", + "achromatic-fit", + ), ov_post_processing=_get_export_conf_value(export_conf, "ov-post-processing", "ov_post_processing", "none"), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/scripts/post_processing_sh_bake_validation.py b/threedgrut/export/scripts/post_processing_sh_bake_validation.py new file mode 100644 index 00000000..f4a0b564 --- /dev/null +++ b/threedgrut/export/scripts/post_processing_sh_bake_validation.py @@ -0,0 +1,400 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate baking one fixed PPISP transform into Gaussian SH coefficients. + +The reference is the checkpoint render followed by PPISP from one camera/frame, +including that camera's chromatic vignetting. The fitted method optimizes only a +cloned model's SH coefficients, with a temporary achromatic vignette applied in +the fitting loss to isolate chromatic vignette effects. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Iterable + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from torchmetrics import PeakSignalNoiseRatio +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + +import threedgrut.datasets as datasets +from threedgrut.render import Renderer +from threedgrut.datasets.utils import configure_dataloader_for_platform +from threedgrut.export.usd.post_processing_sh_bake import ( + MODE_PPISP_BAKE_VIGNETTING_NONE, + FixedPPISP, + apply_achromatic_vignetting, + normalize_ppisp_bake_vignetting_mode, +) +from threedgrut.utils.color_correct import color_correct_affine +from threedgrut.utils.logger import logger +from threedgrut.utils.render import apply_post_processing + + +def _setShFitParameters(model) -> Iterable[torch.nn.Parameter]: + for parameter in model.parameters(): + parameter.requires_grad_(False) + + fitParameters = [] + for fieldName in ("features_albedo", "features_specular"): + parameter = getattr(model, fieldName) + parameter.requires_grad_(True) + fitParameters.append(parameter) + return fitParameters + + +def _renderReference(referenceModel, fixedPpisp, gpuBatch) -> torch.Tensor: + with torch.no_grad(): + outputs = referenceModel(gpuBatch) + outputs = apply_post_processing(fixedPpisp, outputs, gpuBatch, training=True) + return outputs["pred_rgb"].detach() + + +def _applyAchromaticVignetting(rgb: torch.Tensor, fixedPpisp, gpuBatch, vignettingMode: str) -> torch.Tensor: + if vignettingMode == MODE_PPISP_BAKE_VIGNETTING_NONE: + return rgb + _, height, width, _ = rgb.shape + return apply_achromatic_vignetting( + rgb=rgb, + ppisp=fixedPpisp.ppisp, + camera_id=fixedPpisp.camera_id, + pixel_coords=gpuBatch.pixel_coords, + resolution=(width, height), + ) + + +def _createTrainDataloader(conf): + trainDataset = datasets.make_train(name=conf.dataset.type, config=conf, ray_jitter=None) + dataloaderKwargs = configure_dataloader_for_platform( + { + "num_workers": conf.num_workers, + "batch_size": 1, + "shuffle": True, + "pin_memory": True, + "persistent_workers": True if conf.num_workers > 0 else False, + } + ) + trainDataloader = torch.utils.data.DataLoader(trainDataset, **dataloaderKwargs) + return trainDataset, trainDataloader + + +def _fitBakedSh( + referenceModel, + bakedModel, + fixedPpisp, + dataset, + dataloader, + fitEpochs: int, + learningRate: float, + vignettingMode: str, +) -> None: + if fitEpochs < 1: + raise ValueError(f"fitEpochs must be >= 1, got {fitEpochs}.") + + fitParameters = list(_setShFitParameters(bakedModel)) + optimizer = torch.optim.Adam(fitParameters, lr=learningRate) + + totalSteps = fitEpochs * len(dataloader) + logger.start_progress(task_name="Fitting baked SH", total_steps=totalSteps, color="cyan") + globalStep = 0 + for fitEpoch in range(fitEpochs): + for batch in dataloader: + globalStep += 1 + gpuBatch = dataset.get_gpu_batch_with_intrinsics(batch) + referenceRgb = _renderReference(referenceModel, fixedPpisp, gpuBatch) + + optimizer.zero_grad(set_to_none=True) + bakedOutputs = bakedModel(gpuBatch) + fittedRgb = _applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode) + loss = torch.nn.functional.mse_loss(fittedRgb, referenceRgb) + + loss.backward() + optimizer.step() + + logger.log_progress( + task_name="Fitting baked SH", + advance=1, + iteration=f"{fitEpoch + 1}/{fitEpochs}:{globalStep}", + loss=float(loss.detach().item()), + ) + logger.end_progress(task_name="Fitting baked SH") + + +@torch.no_grad() +def _evaluateBakedSh( + referenceModel, + bakedModel, + fixedPpisp, + dataset, + dataloader, + outputRoot: Path, + computeExtraMetrics: bool, + vignettingMode: str, +) -> dict: + criterions = {"psnr": PeakSignalNoiseRatio(data_range=1).to("cuda")} + if computeExtraMetrics: + criterions |= { + "ssim": StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda"), + "lpips": LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to("cuda"), + } + + referencePath = outputRoot / "reference" + bakedPath = outputRoot / "baked" + assistedPath = outputRoot / "baked_assisted" + referencePath.mkdir(parents=True, exist_ok=True) + bakedPath.mkdir(parents=True, exist_ok=True) + assistedPath.mkdir(parents=True, exist_ok=True) + + psnrValues = [] + ssimValues = [] + lpipsValues = [] + ccPsnrValues = [] + ccSsimValues = [] + ccLpipsValues = [] + assistedPsnrValues = [] + assistedSsimValues = [] + assistedLpipsValues = [] + assistedCcPsnrValues = [] + assistedCcSsimValues = [] + assistedCcLpipsValues = [] + inferenceTimeValues = [] + + logger.start_progress(task_name="Evaluating baked SH", total_steps=len(dataloader), color="orange1") + for iteration, batch in enumerate(dataloader): + gpuBatch = dataset.get_gpu_batch_with_intrinsics(batch) + + referenceRgb = _renderReference(referenceModel, fixedPpisp, gpuBatch) + bakedOutputs = bakedModel(gpuBatch) + bakedRgb = bakedOutputs["pred_rgb"] + assistedRgb = _applyAchromaticVignetting(bakedRgb, fixedPpisp, gpuBatch, vignettingMode) + + torchvision.utils.save_image( + referenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + referencePath / f"{iteration:05d}.png", + ) + torchvision.utils.save_image( + bakedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + bakedPath / f"{iteration:05d}.png", + ) + torchvision.utils.save_image( + assistedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + assistedPath / f"{iteration:05d}.png", + ) + + psnrValues.append(criterions["psnr"](bakedRgb, referenceRgb).item()) + assistedPsnrValues.append(criterions["psnr"](assistedRgb, referenceRgb).item()) + if computeExtraMetrics: + ssimValues.append( + criterions["ssim"]( + bakedRgb.permute(0, 3, 1, 2), + referenceRgb.permute(0, 3, 1, 2), + ).item() + ) + lpipsValues.append( + criterions["lpips"]( + bakedRgb.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + assistedSsimValues.append( + criterions["ssim"]( + assistedRgb.permute(0, 3, 1, 2), + referenceRgb.permute(0, 3, 1, 2), + ).item() + ) + assistedLpipsValues.append( + criterions["lpips"]( + assistedRgb.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + + bakedRgbCc = color_correct_affine(bakedRgb, referenceRgb) + ccPsnrValues.append(criterions["psnr"](bakedRgbCc, referenceRgb).item()) + ccSsimValues.append( + criterions["ssim"]( + bakedRgbCc.permute(0, 3, 1, 2), + referenceRgb.permute(0, 3, 1, 2), + ).item() + ) + ccLpipsValues.append( + criterions["lpips"]( + bakedRgbCc.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + assistedRgbCc = color_correct_affine(assistedRgb, referenceRgb) + assistedCcPsnrValues.append(criterions["psnr"](assistedRgbCc, referenceRgb).item()) + assistedCcSsimValues.append( + criterions["ssim"]( + assistedRgbCc.permute(0, 3, 1, 2), + referenceRgb.permute(0, 3, 1, 2), + ).item() + ) + assistedCcLpipsValues.append( + criterions["lpips"]( + assistedRgbCc.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + + if "frame_time_ms" in bakedOutputs: + inferenceTimeValues.append(bakedOutputs["frame_time_ms"]) + + logger.log_progress(task_name="Evaluating baked SH", advance=1, iteration=str(iteration), psnr=psnrValues[-1]) + logger.end_progress(task_name="Evaluating baked SH") + + metrics = { + "vignetting_mode": vignettingMode, + "mean_psnr": float(np.mean(psnrValues)), + "std_psnr": float(np.std(psnrValues)), + "assisted_mean_psnr": float(np.mean(assistedPsnrValues)), + "assisted_std_psnr": float(np.std(assistedPsnrValues)), + } + if computeExtraMetrics: + metrics |= { + "mean_ssim": float(np.mean(ssimValues)), + "mean_lpips": float(np.mean(lpipsValues)), + "mean_cc_psnr": float(np.mean(ccPsnrValues)), + "mean_cc_ssim": float(np.mean(ccSsimValues)), + "mean_cc_lpips": float(np.mean(ccLpipsValues)), + "assisted_mean_ssim": float(np.mean(assistedSsimValues)), + "assisted_mean_lpips": float(np.mean(assistedLpipsValues)), + "assisted_mean_cc_psnr": float(np.mean(assistedCcPsnrValues)), + "assisted_mean_cc_ssim": float(np.mean(assistedCcSsimValues)), + "assisted_mean_cc_lpips": float(np.mean(assistedCcLpipsValues)), + } + if inferenceTimeValues: + metrics["mean_inference_time"] = f"{np.mean(inferenceTimeValues):.2f} ms/frame" + + with open(outputRoot / "metrics.json", "w") as file: + json.dump(metrics, file, indent=2) + + logger.log_table("Post-Processing SH Bake Validation", record=metrics) + return metrics + + +def _validateArguments(args, ppisp: nn.Module) -> None: + if not hasattr(ppisp, "vignetting_params"): + raise ValueError("Checkpoint post-processing is not PPISP-like: missing vignetting_params.") + if not hasattr(ppisp, "exposure_params") or not hasattr(ppisp, "crf_params"): + raise ValueError("Checkpoint post-processing is not PPISP-like: missing exposure_params or crf_params.") + + numFrames = int(ppisp.exposure_params.shape[0]) + numCameras = int(ppisp.crf_params.shape[0]) + if args.frameId < 0 or args.frameId >= numFrames: + raise ValueError(f"frameId must be in [0, {numFrames - 1}], got {args.frameId}.") + if args.cameraId < 0 or args.cameraId >= numCameras: + raise ValueError(f"cameraId must be in [0, {numCameras - 1}], got {args.cameraId}.") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", required=True, type=str, help="Path to the pretrained checkpoint.") + parser.add_argument("--path", type=str, default="", help="Path to test data, if not provided taken from ckpt.") + parser.add_argument("--out-dir", dest="outDir", required=True, type=str, help="Output path.") + parser.add_argument("--camera-id", dest="cameraId", default=0, type=int, help="PPISP camera id to bake.") + parser.add_argument("--frame-id", dest="frameId", default=0, type=int, help="PPISP frame id to bake.") + parser.add_argument( + "--fit-epochs", + dest="fitEpochs", + default=1, + type=int, + help="Number of sequential passes over the train/reference set.", + ) + parser.add_argument("--learning-rate", dest="learningRate", default=1.0e-3, type=float, help="SH fitting LR.") + parser.add_argument( + "--vignetting-mode", + dest="vignettingMode", + choices=["none", "achromatic-fit"], + default="achromatic-fit", + help=( + "Vignetting handling for the bake. 'none' disables PPISP vignetting; " + "'achromatic-fit' uses chromatic PPISP reference and an achromatic fit-only vignette." + ), + ) + parser.add_argument( + "--compute-extra-metrics", + dest="computeExtraMetrics", + action="store_false", + help="If set, extra image metrics will not be computed [True by default].", + ) + args = parser.parse_args() + + renderer = Renderer.from_checkpoint( + checkpoint_path=args.checkpoint, + path=args.path, + out_dir=args.outDir, + save_gt=False, + computes_extra_metrics=args.computeExtraMetrics, + ) + if renderer.post_processing is None: + raise ValueError("Checkpoint does not contain PPISP post-processing.") + + _validateArguments(args, renderer.post_processing) + vignettingMode = normalize_ppisp_bake_vignetting_mode(args.vignettingMode) + fixedPpisp = FixedPPISP( + renderer.post_processing, + args.cameraId, + args.frameId, + "cuda", + include_vignetting=vignettingMode != MODE_PPISP_BAKE_VIGNETTING_NONE, + ).eval() + + referenceModel = renderer.model.eval() + bakedModel = renderer.model.clone().eval() + bakedModel.build_acc() + + outputRoot = Path(renderer.out_dir) / f"post_processing_sh_bake_ci{args.cameraId}_fi{args.frameId}" + outputRoot.mkdir(parents=True, exist_ok=True) + + trainDataset, trainDataloader = _createTrainDataloader(renderer.conf) + + logger.info(f"Fitting SH coefficients to fixed PPISP camera={args.cameraId} frame={args.frameId}") + _fitBakedSh( + referenceModel=referenceModel, + bakedModel=bakedModel, + fixedPpisp=fixedPpisp, + dataset=trainDataset, + dataloader=trainDataloader, + fitEpochs=args.fitEpochs, + learningRate=args.learningRate, + vignettingMode=vignettingMode, + ) + + _evaluateBakedSh( + referenceModel=referenceModel, + bakedModel=bakedModel, + fixedPpisp=fixedPpisp, + dataset=renderer.dataset, + dataloader=renderer.dataloader, + outputRoot=outputRoot, + computeExtraMetrics=args.computeExtraMetrics, + vignettingMode=vignettingMode, + ) + + +if __name__ == "__main__": + main() diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 4ca96cc0..82ae2dab 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -59,12 +59,19 @@ MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, normalize_ov_post_processing_mode, ) +from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING = "rtx:rtpt:gaussian:skipTonemapping:enabled" +MODE_POST_PROCESSING_EXPORT_BAKED_SH = "baked-sh" +MODE_POST_PROCESSING_EXPORT_NATIVE = "native" +POST_PROCESSING_EXPORT_MODES = { + MODE_POST_PROCESSING_EXPORT_BAKED_SH, + MODE_POST_PROCESSING_EXPORT_NATIVE, +} def _set_render_setting(stage: Usd.Stage, key: str, value: Any) -> None: @@ -81,6 +88,16 @@ def _is_ppisp_post_processing(post_processing: Any) -> bool: ) +def normalize_post_processing_export_mode(mode: str | None) -> str: + normalized = MODE_POST_PROCESSING_EXPORT_BAKED_SH if mode is None else str(mode).strip().lower() + if normalized not in POST_PROCESSING_EXPORT_MODES: + raise ValueError( + f"Unsupported post-processing export mode '{mode}'. " + f"Expected one of: {sorted(POST_PROCESSING_EXPORT_MODES)}" + ) + return normalized + + def _get_export_config_value(export_conf, hyphen_name: str, attr_name: str, default: Any) -> Any: if hasattr(export_conf, "get"): return export_conf.get(hyphen_name, getattr(export_conf, attr_name, default)) @@ -269,7 +286,13 @@ def __init__( sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, omni_usd: bool = False, - export_ppisp: bool = True, + export_post_processing: bool = True, + post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, + post_processing_bake_epochs: int = 1, + post_processing_bake_learning_rate: float = 1.0e-3, + post_processing_bake_camera_id: int = 0, + post_processing_bake_frame_id: int = 0, + ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, ov_post_processing: str = MODE_PPISP_OMNI_FALLBACK_NONE, frames_per_second: float = 1.0, ): @@ -287,11 +310,20 @@ def __init__( linear_srgb: If True, set prim color space to lin_rec709_scene. omni_usd: If True, author Omniverse-specific USD features such as ParticleFieldEmissive MDL binding and PPISP SPG graphs. - export_ppisp: If True, export PPISP using SPG or the selected - Omniverse USD fallback mode. Requires post_processing kwarg to - be a ppisp.PPISP instance. + export_post_processing: If True, export the checkpoint post-processing + module using the selected export mode. + post_processing_export_mode: "baked-sh" bakes one fixed + post-processing transform into Gaussian SH coefficients. + "native" uses the module-specific native export path. + post_processing_bake_epochs: Number of sequential passes over the train/reference set. + post_processing_bake_learning_rate: Adam learning rate for baked SH. + post_processing_bake_camera_id: Camera index for the fixed baked transform. + post_processing_bake_frame_id: Frame index for the fixed baked transform. + ppisp_bake_vignetting_mode: "none" disables vignetting in the PPISP + reference. "achromatic-fit" keeps chromatic PPISP vignetting in + the reference and applies an achromatic estimate only in the fit loss. ov_post_processing: PPISP export implementation selector. "none" - uses the full SPG path when export_ppisp is enabled. + uses the native PPISP path when export_post_processing is enabled. frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always bare frame indices (float(frame_idx)), so this controls playback speed. Default 1.0 means 1 frame per second of real time. @@ -307,12 +339,18 @@ def __init__( self.sorting_mode_hint = sorting_mode_hint self.linear_srgb = linear_srgb self.omni_usd = omni_usd - self.export_ppisp = export_ppisp + self.export_post_processing = export_post_processing + self.post_processing_export_mode = normalize_post_processing_export_mode(post_processing_export_mode) + self.post_processing_bake_epochs = int(post_processing_bake_epochs) + self.post_processing_bake_learning_rate = float(post_processing_bake_learning_rate) + self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) + self.post_processing_bake_frame_id = int(post_processing_bake_frame_id) + self.ppisp_bake_vignetting_mode = str(ppisp_bake_vignetting_mode) self.ov_post_processing = normalize_ov_post_processing_mode(ov_post_processing) - if not self.export_ppisp and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: + if not self.export_post_processing and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: raise ValueError( - "export_usd.ov-post-processing requires export_usd.export_ppisp=true. " - "Set export_ppisp=true to export PPISP through an Omniverse USD fallback, " + "export_usd.ov-post-processing requires export_usd.export_post_processing=true. " + "Set export_post_processing=true to export PPISP through an Omniverse USD fallback, " "or set ov-post-processing=none." ) if not self.omni_usd and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: @@ -358,8 +396,7 @@ def export( conf: Configuration parameters. background: Optional background model for environment export. **kwargs: - post_processing: ppisp.PPISP instance for SPG export (used when - export_ppisp=True or ov-post-processing is enabled). + post_processing: checkpoint post-processing module to bake or export natively. validate_usd (default True): run OpenUSD stage validators. apply_coordinate_transform (bool): apply 3DGRUT→USDZ coordinate flip. copy_source_usd: (stage_path, res_root) for prim merge. @@ -369,12 +406,44 @@ def export( logger.info(f"Exporting USD file to {output_path}...") post_processing = kwargs.get("post_processing") has_ppisp_module = _is_ppisp_post_processing(post_processing) - if has_ppisp_module and self.export_ppisp and not self.omni_usd: + uses_baked_post_processing_export = ( + post_processing is not None + and self.export_post_processing + and self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_BAKED_SH + and self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE + ) + if has_ppisp_module and self.export_post_processing and not uses_baked_post_processing_export and not self.omni_usd: raise ValueError( - "PPISP USD export requires export_usd.omni-usd=true because the current PPISP " - "implementation uses Omniverse SPG. Re-run with export_usd.omni-usd=true, " - "or set export_usd.export_ppisp=false / pass --no-export-ppisp to export the " - "model without PPISP effects." + "PPISP SPG/fallback USD export requires export_usd.omni-usd=true. " + "Use post_processing_export_mode=baked-sh for standard USD baked-SH export, " + "or set export_usd.export_post_processing=false / pass --no-export-post-processing." + ) + + if uses_baked_post_processing_export: + from threedgrut.export.usd.post_processing_sh_bake import ( + PPISPPostProcessingBakeAdapter, + bake_post_processing_into_sh, + ) + + if not has_ppisp_module: + raise ValueError("Baked-SH post-processing export currently supports PPISP post-processing only.") + adapter = PPISPPostProcessingBakeAdapter( + camera_id=self.post_processing_bake_camera_id, + frame_id=self.post_processing_bake_frame_id, + vignetting_mode=self.ppisp_bake_vignetting_mode, + ) + logger.info( + "Baking post-processing into Gaussian SH coefficients before export " + f"(camera={self.post_processing_bake_camera_id}, frame={self.post_processing_bake_frame_id})" + ) + model = bake_post_processing_into_sh( + model=model, + post_processing=post_processing, + train_dataset=dataset, + conf=conf, + adapter=adapter, + epochs=self.post_processing_bake_epochs, + learning_rate=self.post_processing_bake_learning_rate, ) # Get model data via accessor @@ -423,7 +492,7 @@ def export( sorting_mode_hint=self.sorting_mode_hint, linear_srgb=self.linear_srgb, omni_usd=self.omni_usd, - has_post_processing=has_ppisp_module and self.export_ppisp, + has_post_processing=has_ppisp_module and self.export_post_processing and not uses_baked_post_processing_export, ) writer.create_prim(attrs.num_gaussians) writer.write_attributes(attrs) @@ -523,22 +592,27 @@ def export( logger.warning(f"Failed to export background: {e}") render_product_entries = None - if not self.export_ppisp and _is_ppisp_post_processing(post_processing): + if not self.export_post_processing and _is_ppisp_post_processing(post_processing): logger.warning( - "PPISP post-processing module is present but export_usd.export_ppisp=false; " - "PPISP effects will not be exported. Set export_usd.export_ppisp=true to export them." + "PPISP post-processing module is present but export_usd.export_post_processing=false; " + "PPISP effects will not be exported. Set export_usd.export_post_processing=true to export them." ) - has_ppisp_export_source = self.export_ppisp and post_processing is not None + has_ppisp_export_source = ( + self.export_post_processing and post_processing is not None and not uses_baked_post_processing_export + ) export_spg_ppisp = has_ppisp_export_source and ( - self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE + ( + self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_NATIVE + and self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE + ) or self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING ) export_omni_ppisp_fallback = ( has_ppisp_export_source and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE ) needs_ppisp_render_products = has_ppisp_export_source - if self.export_ppisp and post_processing is None: - logger.info("PPISP export requested but no post_processing module is available; skipping /Render export") + if self.export_post_processing and post_processing is None: + logger.info("Post-processing export requested but no post_processing module is available; skipping /Render export") if needs_ppisp_render_products: render_product_entries = self._create_ppisp_render_products( stage=scene_stage, @@ -648,7 +722,7 @@ def _export_ppisp( if not isinstance(post_processing, PPISP): logger.warning( - f"export_ppisp=True but post_processing is {type(post_processing).__name__}, " + f"export_post_processing=True but post_processing is {type(post_processing).__name__}, " "expected ppisp.PPISP — skipping" ) return @@ -754,7 +828,48 @@ def from_config(cls, conf) -> "USDExporter": sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), omni_usd=_get_export_config_value(export_conf, "omni-usd", "omni_usd", False), - export_ppisp=getattr(export_conf, "export_ppisp", True), + export_post_processing=_get_export_config_value( + export_conf, + "export-post-processing", + "export_post_processing", + True, + ), + post_processing_export_mode=_get_export_config_value( + export_conf, + "post-processing-export-mode", + "post_processing_export_mode", + MODE_POST_PROCESSING_EXPORT_BAKED_SH, + ), + post_processing_bake_epochs=_get_export_config_value( + export_conf, + "post-processing-bake-epochs", + "post_processing_bake_epochs", + 1, + ), + post_processing_bake_learning_rate=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate", + "post_processing_bake_learning_rate", + 1.0e-3, + ), + post_processing_bake_camera_id=_get_export_config_value( + export_conf, + "post-processing-bake-camera-id", + "post_processing_bake_camera_id", + 0, + ), + post_processing_bake_frame_id=_get_export_config_value( + export_conf, + "post-processing-bake-frame-id", + "post_processing_bake_frame_id", + 0, + ), + ppisp_bake_vignetting_mode=_get_export_config_value( + export_conf, + "ppisp-bake-vignetting-mode", + "ppisp_bake_vignetting_mode", + MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + ), ov_post_processing=_get_export_config_value( export_conf, "ov-post-processing", diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py new file mode 100644 index 00000000..ac1a3acf --- /dev/null +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fit fixed post-processing transforms into Gaussian SH coefficients for export.""" + +from __future__ import annotations + +import copy +import logging +from typing import Iterable + +import torch +import torch.nn as nn + +from threedgrut.datasets.utils import configure_dataloader_for_platform +from threedgrut.utils.render import apply_post_processing + +logger = logging.getLogger(__name__) + + +class PostProcessingBakeAdapter: + """Adapter interface for baking one fixed post-processing transform.""" + + name = "post-processing" + + def validate(self, post_processing: nn.Module) -> None: + del post_processing + + def create_fixed_post_processing(self, post_processing: nn.Module, device: str) -> nn.Module: + return copy.deepcopy(post_processing).to(device).eval() + + def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Module, gpu_batch) -> torch.Tensor: + del fixed_post_processing, gpu_batch + return rgb + + def log_context(self) -> str: + return "" + + +def _set_sh_fit_parameters(model) -> Iterable[torch.nn.Parameter]: + for parameter in model.parameters(): + parameter.requires_grad_(False) + + fit_parameters = [] + for field_name in ("features_albedo", "features_specular"): + parameter = getattr(model, field_name) + parameter.requires_grad_(True) + fit_parameters.append(parameter) + return fit_parameters + + +def _create_train_dataloader(conf, train_dataset): + num_workers = int(getattr(conf, "num_workers", 8)) + dataloader_kwargs = configure_dataloader_for_platform( + { + "num_workers": num_workers, + "batch_size": 1, + "shuffle": True, + "pin_memory": True, + "persistent_workers": True if num_workers > 0 else False, + } + ) + return torch.utils.data.DataLoader(train_dataset, **dataloader_kwargs) + + +def _render_reference(reference_model, fixed_post_processing, gpu_batch) -> torch.Tensor: + with torch.no_grad(): + outputs = reference_model(gpu_batch) + outputs = apply_post_processing(fixed_post_processing, outputs, gpu_batch, training=True) + return outputs["pred_rgb"].detach() + + +def bake_post_processing_into_sh( + model, + post_processing: nn.Module, + train_dataset, + conf, + *, + adapter: PostProcessingBakeAdapter, + epochs: int = 1, + learning_rate: float = 1.0e-3, + device: str = "cuda", +): + """Return a cloned model whose SH coefficients approximate fixed post-processing output.""" + if not hasattr(model, "clone"): + raise TypeError("Post-processing SH bake export requires a cloneable MixtureOfGaussians model.") + if train_dataset is None: + raise ValueError("Post-processing SH bake export requires a train dataset. Pass --dataset if it is missing.") + if post_processing is None: + raise ValueError("Post-processing SH bake export requires a post_processing module.") + if epochs < 1: + raise ValueError(f"epochs must be >= 1, got {epochs}.") + + adapter.validate(post_processing) + reference_model = model.to(device).eval() + reference_model.build_acc() + baked_model = model.clone().to(device).eval() + baked_model.build_acc() + fixed_post_processing = adapter.create_fixed_post_processing(post_processing, device) + + fit_parameters = list(_set_sh_fit_parameters(baked_model)) + optimizer = torch.optim.Adam(fit_parameters, lr=learning_rate) + train_dataloader = _create_train_dataloader(conf, train_dataset) + + logger.info( + "Fitting %s SH bake on train split: epochs=%s frames_per_epoch=%s%s", + adapter.name, + epochs, + len(train_dataloader), + adapter.log_context(), + ) + with torch.enable_grad(): + global_step = 0 + total_steps = epochs * len(train_dataloader) + for epoch in range(epochs): + for batch in train_dataloader: + global_step += 1 + gpu_batch = train_dataset.get_gpu_batch_with_intrinsics(batch) + reference_rgb = _render_reference(reference_model, fixed_post_processing, gpu_batch) + + optimizer.zero_grad(set_to_none=True) + baked_outputs = baked_model(gpu_batch) + fitted_rgb = adapter.apply_fit_transform( + baked_outputs["pred_rgb"], + fixed_post_processing, + gpu_batch, + ) + loss = torch.nn.functional.mse_loss(fitted_rgb, reference_rgb) + + loss.backward() + optimizer.step() + + if global_step == 1 or global_step % 50 == 0 or global_step == total_steps: + logger.info( + "%s SH bake epoch %s/%s step %s/%s loss=%.6g", + adapter.name, + epoch + 1, + epochs, + global_step, + total_steps, + float(loss.detach()), + ) + + for parameter in baked_model.parameters(): + parameter.requires_grad_(False) + baked_model.eval() + logger.info("%s SH bake complete", adapter.name) + return baked_model + + +MODE_PPISP_BAKE_VIGNETTING_NONE = "none" +MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT = "achromatic-fit" +PPISP_BAKE_VIGNETTING_MODES = { + MODE_PPISP_BAKE_VIGNETTING_NONE, + MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, +} + + +class FixedPPISP(nn.Module): + """Wrap PPISP as one fixed camera/frame color transform.""" + + def __init__( + self, + ppisp: nn.Module, + camera_id: int, + frame_id: int, + device: str, + include_vignetting: bool = True, + ) -> None: + super().__init__() + self.camera_id = int(camera_id) + self.frame_id = int(frame_id) + self.ppisp = copy.deepcopy(ppisp).to(device).eval() + + if hasattr(self.ppisp, "config") and hasattr(self.ppisp.config, "use_controller"): + self.ppisp.config.use_controller = False + if not include_vignetting and hasattr(self.ppisp, "vignetting_params"): + with torch.no_grad(): + self.ppisp.vignetting_params.zero_() + + def forward( + self, + rgb: torch.Tensor, + pixel_coords: torch.Tensor, + resolution: tuple[int, int], + camera_idx=None, + frame_idx=None, + exposure_prior=None, + ) -> torch.Tensor: + del camera_idx, frame_idx, exposure_prior + return self.ppisp( + rgb, + pixel_coords, + resolution=resolution, + camera_idx=self.camera_id, + frame_idx=self.frame_id, + exposure_prior=None, + ) + + +def normalize_ppisp_bake_vignetting_mode(mode: str | None) -> str: + normalized = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT if mode is None else str(mode).strip().lower() + if normalized not in PPISP_BAKE_VIGNETTING_MODES: + raise ValueError( + f"Unsupported PPISP bake vignetting mode '{mode}'. " + f"Expected one of: {sorted(PPISP_BAKE_VIGNETTING_MODES)}" + ) + return normalized + + +def estimate_achromatic_vignetting( + ppisp: nn.Module, + camera_id: int, + pixel_coords: torch.Tensor, + resolution: tuple[int, int], +) -> torch.Tensor: + """Estimate luminance falloff from PPISP's chromatic camera vignette.""" + if not hasattr(ppisp, "vignetting_params"): + raise ValueError("PPISP-like module is missing vignetting_params.") + + width, height = resolution + del height + vig_params = ppisp.vignetting_params[int(camera_id)].to(device=pixel_coords.device, dtype=pixel_coords.dtype) + + u = (pixel_coords[..., 0] - float(width) * 0.5) / float(width) + v = (pixel_coords[..., 1] - float(resolution[1]) * 0.5) / float(width) + uv = torch.stack([u, v], dim=-1) + + channel_falloff = [] + for channel in range(3): + center = vig_params[channel, 0:2] + delta = uv - center + r2 = torch.sum(delta * delta, dim=-1) + falloff = ( + 1.0 + + vig_params[channel, 2] * r2 + + vig_params[channel, 3] * r2 * r2 + + vig_params[channel, 4] * r2 * r2 * r2 + ) + channel_falloff.append(torch.clamp(falloff, 0.0, 1.0)) + + rgb_falloff = torch.stack(channel_falloff, dim=-1) + luminance_weights = torch.tensor([0.2126, 0.7152, 0.0722], device=pixel_coords.device, dtype=pixel_coords.dtype) + return torch.sum(rgb_falloff * luminance_weights, dim=-1, keepdim=True) + + +def apply_achromatic_vignetting( + rgb: torch.Tensor, + ppisp: nn.Module, + camera_id: int, + pixel_coords: torch.Tensor, + resolution: tuple[int, int], +) -> torch.Tensor: + return rgb * estimate_achromatic_vignetting(ppisp, camera_id, pixel_coords, resolution) + + +class PPISPPostProcessingBakeAdapter(PostProcessingBakeAdapter): + name = "PPISP post-processing" + + def __init__( + self, + camera_id: int = 0, + frame_id: int = 0, + vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + ) -> None: + self.camera_id = int(camera_id) + self.frame_id = int(frame_id) + self.vignetting_mode = normalize_ppisp_bake_vignetting_mode(vignetting_mode) + + def validate(self, post_processing: nn.Module) -> None: + if not hasattr(post_processing, "exposure_params") or not hasattr(post_processing, "crf_params"): + raise ValueError("PPISP SH bake export requires a PPISP-like post_processing module.") + + num_frames = int(post_processing.exposure_params.shape[0]) + num_cameras = int(post_processing.crf_params.shape[0]) + if self.frame_id < 0 or self.frame_id >= num_frames: + raise ValueError(f"frame_id must be in [0, {num_frames - 1}], got {self.frame_id}.") + if self.camera_id < 0 or self.camera_id >= num_cameras: + raise ValueError(f"camera_id must be in [0, {num_cameras - 1}], got {self.camera_id}.") + + def create_fixed_post_processing(self, post_processing: nn.Module, device: str) -> nn.Module: + return FixedPPISP( + post_processing, + self.camera_id, + self.frame_id, + device, + include_vignetting=self.vignetting_mode == MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + ).eval() + + def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Module, gpu_batch) -> torch.Tensor: + if self.vignetting_mode == MODE_PPISP_BAKE_VIGNETTING_NONE: + return rgb + _, height, width, _ = rgb.shape + return apply_achromatic_vignetting( + rgb=rgb, + ppisp=fixed_post_processing.ppisp, + camera_id=fixed_post_processing.camera_id, + pixel_coords=gpu_batch.pixel_coords, + resolution=(width, height), + ) + + def log_context(self) -> str: + return f" camera={self.camera_id} frame={self.frame_id} vignetting={self.vignetting_mode}" From 67d557e0a0c7387414b9256ebfc30e49a3f59db3 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 28 Apr 2026 15:09:41 -0400 Subject: [PATCH 13/42] refactor(export): clarify post-processing native mode Remove the separate omni-usd and ov-post-processing controls while keeping PPISP native export behind the explicit omni-native post-processing mode. --- configs/base_gs.yaml | 9 +- threedgrut/export/scripts/export_usd.py | 15 +- threedgrut/export/usd/exporter.py | 175 ++------ threedgrut/export/usd/writers/__init__.py | 3 - .../export/usd/writers/ov_post_processing.py | 385 ------------------ 5 files changed, 39 insertions(+), 548 deletions(-) delete mode 100644 threedgrut/export/usd/writers/ov_post_processing.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 568bb42a..9b36a34e 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -49,14 +49,12 @@ export_usd: sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false - # Enable Omniverse-specific USD authoring, including PPISP SPG and MDL material binding. - omni-usd: false # Enable post-processing export when the checkpoint contains a supported module. # Defaults to true; post-processing-export-mode controls how the effect is exported. export_post_processing: true # baked-sh fits a fixed post-processing transform into Gaussian SH coefficients. - # native uses the module-specific native path; PPISP native export requires omni-usd=true. - # baked-sh | native + # omni-native uses the module-specific Omniverse-native path; currently PPISP SPG. + # baked-sh | omni-native post-processing-export-mode: baked-sh # Number of sequential passes over the train/reference set used for fitting. post-processing-bake-epochs: 1 @@ -67,9 +65,6 @@ export_usd: # with an achromatic fit-only vignette; the achromatic vignette is not exported. # none | achromatic-fit ppisp-bake-vignetting-mode: achromatic-fit - # Omniverse USD post-processing fallback modes. These require omni-usd=true. - # none | ppisp-exposure-fallback | ppisp-fitted-post-processing-fallback | ppisp-spg-plus-fitted-post-processing-fallback - ov-post-processing: none # USD timeCodesPerSecond; time codes are bare frame indices so this sets playback speed frames_per_second: 1.0 diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 5f661b7f..26479f0a 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -124,11 +124,6 @@ def parse_args(): action="store_true", help="Set prim color space to lin_rec709_scene (linear). Default is srgb_rec709_display.", ) - parser.add_argument( - "--omni-usd", - action="store_true", - help="Enable Omniverse-specific USD authoring such as PPISP SPG and MDL material binding.", - ) post_processing_group = parser.add_mutually_exclusive_group() post_processing_group.add_argument( "--export-post-processing", @@ -146,9 +141,9 @@ def parse_args(): parser.add_argument( "--post-processing-export-mode", type=str, - choices=["baked-sh", "native"], + choices=["baked-sh", "omni-native"], default=None, - help="Post-processing export mode. Default is baked-sh when post-processing export is enabled.", + help="Post-processing export mode. 'omni-native' uses PPISP SPG and Omniverse material authoring.", ) parser.add_argument( "--post-processing-bake-epochs", @@ -336,12 +331,11 @@ def main(): "post_processing_export_mode", "baked-sh", ) - # Load dataset for camera export and for train-split post-processing SH baking. dataset = None needs_dataset = ( not args.no_cameras - or (post_processing is not None and export_post_processing and post_processing_export_mode == "baked-sh") + or (post_processing is not None and export_post_processing) ) if needs_dataset: try: @@ -376,7 +370,6 @@ def main(): else: half_geometry = args.half_geometry or args.half half_features = args.half_features or args.half - omni_usd = bool(args.omni_usd or _get_export_conf_value(export_conf, "omni-usd", "omni_usd", False)) exporter = USDExporter( half_geometry=half_geometry, half_features=half_features, @@ -385,7 +378,6 @@ def main(): apply_normalizing_transform=not args.no_transform, sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), - omni_usd=omni_usd, export_post_processing=export_post_processing, post_processing_export_mode=post_processing_export_mode, post_processing_bake_epochs=_arg_or_conf( @@ -423,7 +415,6 @@ def main(): "ppisp_bake_vignetting_mode", "achromatic-fit", ), - ov_post_processing=_get_export_conf_value(export_conf, "ov-post-processing", "ov_post_processing", "none"), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) logger.info("Using ParticleField3DGaussianSplat schema (standard)") diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 82ae2dab..eb9e301a 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -54,11 +54,6 @@ merge_source_prim_at_same_path, merge_source_world_at_same_paths, ) -from threedgrut.export.usd.writers.ov_post_processing import ( - MODE_PPISP_OMNI_FALLBACK_NONE, - MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, - normalize_ov_post_processing_mode, -) from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT from threedgrut.export.usd.writers.camera import export_cameras_to_usd @@ -67,10 +62,10 @@ _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING = "rtx:rtpt:gaussian:skipTonemapping:enabled" MODE_POST_PROCESSING_EXPORT_BAKED_SH = "baked-sh" -MODE_POST_PROCESSING_EXPORT_NATIVE = "native" +MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE = "omni-native" POST_PROCESSING_EXPORT_MODES = { MODE_POST_PROCESSING_EXPORT_BAKED_SH, - MODE_POST_PROCESSING_EXPORT_NATIVE, + MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE, } @@ -231,14 +226,9 @@ def _extract_camera_grouping(dataset): def _extract_camera_resolutions(camera_params: List, camera_names: List[str], frame_to_camera: List[int]): - """Extract per-camera resolution from the first valid frame of each camera. - - Returns: - {camera_name: (width, height)} or empty dict on failure. - """ + """Extract per-camera resolution from the first valid frame of each camera.""" result = {} num_cameras = len(camera_names) - # Build first-frame-per-camera map first_frame: Dict[int, int] = {} for frame_idx, cam_idx in enumerate(frame_to_camera): if cam_idx not in first_frame and 0 <= cam_idx < num_cameras: @@ -269,7 +259,7 @@ class USDExporter(ModelExporter): - ParticleField3DGaussianSplat schema (standard OpenUSD) - One Camera prim per physical camera with time-sampled transforms - Background/environment export as DomeLight - - Optional PPISP SPG shader on per-camera RenderProducts + - Optional baked-SH post-processing export or PPISP Omniverse native export - USDZ packaging (default output) For NuRec compatibility, use NuRecExporter instead. @@ -285,7 +275,6 @@ def __init__( apply_normalizing_transform: bool = True, sorting_mode_hint: str = "cameraDistance", linear_srgb: bool = False, - omni_usd: bool = False, export_post_processing: bool = True, post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, post_processing_bake_epochs: int = 1, @@ -293,7 +282,6 @@ def __init__( post_processing_bake_camera_id: int = 0, post_processing_bake_frame_id: int = 0, ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - ov_post_processing: str = MODE_PPISP_OMNI_FALLBACK_NONE, frames_per_second: float = 1.0, ): """ @@ -308,13 +296,11 @@ def __init__( apply_normalizing_transform: Apply transform to normalize scene orientation. sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth"). linear_srgb: If True, set prim color space to lin_rec709_scene. - omni_usd: If True, author Omniverse-specific USD features such as - ParticleFieldEmissive MDL binding and PPISP SPG graphs. export_post_processing: If True, export the checkpoint post-processing - module using the selected export mode. - post_processing_export_mode: "baked-sh" bakes one fixed - post-processing transform into Gaussian SH coefficients. - "native" uses the module-specific native export path. + module with the selected export mode. + post_processing_export_mode: "baked-sh" bakes one fixed transform + into Gaussian SH coefficients. "omni-native" uses the module's + Omniverse-native path; currently PPISP SPG. post_processing_bake_epochs: Number of sequential passes over the train/reference set. post_processing_bake_learning_rate: Adam learning rate for baked SH. post_processing_bake_camera_id: Camera index for the fixed baked transform. @@ -322,8 +308,6 @@ def __init__( ppisp_bake_vignetting_mode: "none" disables vignetting in the PPISP reference. "achromatic-fit" keeps chromatic PPISP vignetting in the reference and applies an achromatic estimate only in the fit loss. - ov_post_processing: PPISP export implementation selector. "none" - uses the native PPISP path when export_post_processing is enabled. frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always bare frame indices (float(frame_idx)), so this controls playback speed. Default 1.0 means 1 frame per second of real time. @@ -338,7 +322,6 @@ def __init__( self.apply_normalizing_transform = apply_normalizing_transform self.sorting_mode_hint = sorting_mode_hint self.linear_srgb = linear_srgb - self.omni_usd = omni_usd self.export_post_processing = export_post_processing self.post_processing_export_mode = normalize_post_processing_export_mode(post_processing_export_mode) self.post_processing_bake_epochs = int(post_processing_bake_epochs) @@ -346,19 +329,6 @@ def __init__( self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) self.post_processing_bake_frame_id = int(post_processing_bake_frame_id) self.ppisp_bake_vignetting_mode = str(ppisp_bake_vignetting_mode) - self.ov_post_processing = normalize_ov_post_processing_mode(ov_post_processing) - if not self.export_post_processing and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: - raise ValueError( - "export_usd.ov-post-processing requires export_usd.export_post_processing=true. " - "Set export_post_processing=true to export PPISP through an Omniverse USD fallback, " - "or set ov-post-processing=none." - ) - if not self.omni_usd and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE: - raise ValueError( - "export_usd.ov-post-processing requires export_usd.omni-usd=true. " - "Set omni-usd=true to author Omniverse USD post-processing fallback features, " - "or set ov-post-processing=none." - ) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -410,14 +380,12 @@ def export( post_processing is not None and self.export_post_processing and self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_BAKED_SH - and self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE ) - if has_ppisp_module and self.export_post_processing and not uses_baked_post_processing_export and not self.omni_usd: - raise ValueError( - "PPISP SPG/fallback USD export requires export_usd.omni-usd=true. " - "Use post_processing_export_mode=baked-sh for standard USD baked-SH export, " - "or set export_usd.export_post_processing=false / pass --no-export-post-processing." - ) + uses_omni_native_post_processing_export = ( + post_processing is not None + and self.export_post_processing + and self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_OMNI_NATIVE + ) if uses_baked_post_processing_export: from threedgrut.export.usd.post_processing_sh_bake import ( @@ -445,6 +413,8 @@ def export( epochs=self.post_processing_bake_epochs, learning_rate=self.post_processing_bake_learning_rate, ) + if uses_omni_native_post_processing_export and not has_ppisp_module: + raise ValueError("Omniverse-native post-processing export currently supports PPISP post-processing only.") # Get model data via accessor accessor = GaussianExportAccessor(model, conf) @@ -491,8 +461,8 @@ def export( half_features=self.half_features, sorting_mode_hint=self.sorting_mode_hint, linear_srgb=self.linear_srgb, - omni_usd=self.omni_usd, - has_post_processing=has_ppisp_module and self.export_post_processing and not uses_baked_post_processing_export, + omni_usd=uses_omni_native_post_processing_export, + has_post_processing=uses_omni_native_post_processing_export, ) writer.create_prim(attrs.num_gaussians) writer.write_attributes(attrs) @@ -591,29 +561,15 @@ def export( except (AttributeError, ValueError, ImportError) as e: logger.warning(f"Failed to export background: {e}") - render_product_entries = None if not self.export_post_processing and _is_ppisp_post_processing(post_processing): logger.warning( "PPISP post-processing module is present but export_usd.export_post_processing=false; " "PPISP effects will not be exported. Set export_usd.export_post_processing=true to export them." ) - has_ppisp_export_source = ( - self.export_post_processing and post_processing is not None and not uses_baked_post_processing_export - ) - export_spg_ppisp = has_ppisp_export_source and ( - ( - self.post_processing_export_mode == MODE_POST_PROCESSING_EXPORT_NATIVE - and self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_NONE - ) - or self.ov_post_processing == MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING - ) - export_omni_ppisp_fallback = ( - has_ppisp_export_source and self.ov_post_processing != MODE_PPISP_OMNI_FALLBACK_NONE - ) - needs_ppisp_render_products = has_ppisp_export_source if self.export_post_processing and post_processing is None: - logger.info("Post-processing export requested but no post_processing module is available; skipping /Render export") - if needs_ppisp_render_products: + logger.info("Post-processing export requested but no post_processing module is available; skipping bake") + + if uses_omni_native_post_processing_export: render_product_entries = self._create_ppisp_render_products( stage=scene_stage, dataset=dataset, @@ -622,31 +578,16 @@ def export( camera_prim_paths=camera_prim_paths, camera_params=camera_params, ) - - # Export PPISP as SPG shaders on RenderProducts - if export_spg_ppisp and render_product_entries is not None: - _set_render_setting(scene_stage, _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING, False) - logger.info( - "Disabled Gaussian skip-tonemapping render setting for PPISP SPG export" - ) - self._export_ppisp( - stage=scene_stage, - dataset=dataset, - camera_names=camera_names, - post_processing=post_processing, - files=files, - ) - - # Export PPISP through fitted Omniverse USD post-processing settings. - if export_omni_ppisp_fallback and render_product_entries is not None: - self._export_ov_post_processing( - stage=scene_stage, - camera_names=camera_names, - camera_prim_paths=camera_prim_paths, - render_product_entries=render_product_entries, - dataset=dataset, - post_processing=post_processing, - ) + if render_product_entries is not None: + _set_render_setting(scene_stage, _GAUSSIAN_SKIP_TONEMAPPING_RENDER_SETTING, False) + logger.info("Disabled Gaussian skip-tonemapping render setting for PPISP Omniverse-native export") + self._export_ppisp( + stage=scene_stage, + dataset=dataset, + camera_names=camera_names, + post_processing=post_processing, + files=files, + ) # Package if suffix == ".usdz": @@ -684,7 +625,7 @@ def _create_ppisp_render_products( camera_prim_paths: Dict[str, str], camera_params, ): - """Create /Render RenderProducts shared by SPG and Omniverse fallback PPISP exports.""" + """Create /Render RenderProducts for PPISP Omniverse-native export.""" if dataset is None or not camera_prim_paths: logger.warning("No camera prims available for PPISP RenderProduct wiring, skipping") return None @@ -723,7 +664,7 @@ def _export_ppisp( if not isinstance(post_processing, PPISP): logger.warning( f"export_post_processing=True but post_processing is {type(post_processing).__name__}, " - "expected ppisp.PPISP — skipping" + "expected ppisp.PPISP - skipping" ) return @@ -740,11 +681,11 @@ def _export_ppisp( "stored exposure/color parameters, vignetting, and CRF." ) + from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files from threedgrut.export.usd.writers.ppisp_writer import ( add_ppisp_to_all_render_products, build_camera_frame_mapping, ) - from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files _, camera_frame_mapping = build_camera_frame_mapping(dataset) @@ -759,53 +700,12 @@ def _export_ppisp( logger.warning(f"Failed to add PPISP shaders: {e}") return - # Add SPG sidecars to the USDZ package spg_files = get_ppisp_spg_files() for spg_file in spg_files: if not any(f.filename == spg_file.filename for f in files): files.append(spg_file) - logger.info(f"PPISP SPG export complete: {len(spg_files)} sidecar(s) added") - - def _export_ov_post_processing( - self, - stage, - camera_names, - camera_prim_paths, - render_product_entries, - dataset, - post_processing, - ) -> None: - """Attach Omniverse USD post-processing fallback attributes to RenderProducts.""" - try: - from ppisp import PPISP # type: ignore[import-not-found] - except ImportError: - logger.warning("ppisp package not available, skipping Omniverse post-processing fallback export") - return - - if not isinstance(post_processing, PPISP): - logger.warning( - f"ov-post-processing={self.ov_post_processing} but post_processing is " - f"{type(post_processing).__name__}, expected ppisp.PPISP — skipping" - ) - return - - from threedgrut.export.usd.writers.ov_post_processing import add_ov_post_processing - from threedgrut.export.usd.writers.ppisp_writer import build_camera_frame_mapping - - _, camera_frame_mapping = build_camera_frame_mapping(dataset) - try: - add_ov_post_processing( - stage=stage, - camera_names=camera_names, - camera_prim_paths=camera_prim_paths, - camera_frame_mapping=camera_frame_mapping, - render_product_entries=render_product_entries, - post_processing=post_processing, - mode=self.ov_post_processing, - ) - except Exception as e: - logger.warning(f"Failed to add Omniverse post-processing fallback: {e}") + logger.info(f"PPISP Omniverse-native export complete: {len(spg_files)} sidecar(s) added") @classmethod def from_config(cls, conf) -> "USDExporter": @@ -827,7 +727,6 @@ def from_config(cls, conf) -> "USDExporter": apply_normalizing_transform=getattr(export_conf, "apply_normalizing_transform", True), sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), linear_srgb=getattr(export_conf, "linear_srgb", False), - omni_usd=_get_export_config_value(export_conf, "omni-usd", "omni_usd", False), export_post_processing=_get_export_config_value( export_conf, "export-post-processing", @@ -870,11 +769,5 @@ def from_config(cls, conf) -> "USDExporter": "ppisp_bake_vignetting_mode", MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, ), - ov_post_processing=_get_export_config_value( - export_conf, - "ov-post-processing", - "ov_post_processing", - MODE_PPISP_OMNI_FALLBACK_NONE, - ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/writers/__init__.py b/threedgrut/export/usd/writers/__init__.py index 0111460f..92531563 100644 --- a/threedgrut/export/usd/writers/__init__.py +++ b/threedgrut/export/usd/writers/__init__.py @@ -21,14 +21,12 @@ - export_cameras_to_usd: one Camera prim per physical camera, animated xforms - create_render_products: /Render scope with per-camera RenderProducts - add_ppisp_to_all_render_products: PPISP SPG shader on RenderProducts -- add_ov_post_processing: Omniverse USD post-processing PPISP fallback """ from threedgrut.export.usd.writers.background import export_background_to_usd from threedgrut.export.usd.writers.base import GaussianUSDWriter, create_gaussian_writer from threedgrut.export.usd.writers.camera import export_cameras_to_usd from threedgrut.export.usd.writers.lightfield import GaussianLightFieldWriter -from threedgrut.export.usd.writers.ov_post_processing import add_ov_post_processing from threedgrut.export.usd.writers.ppisp_writer import add_ppisp_to_all_render_products from threedgrut.export.usd.writers.render_product import create_render_products @@ -40,5 +38,4 @@ "export_background_to_usd", "create_render_products", "add_ppisp_to_all_render_products", - "add_ov_post_processing", ] diff --git a/threedgrut/export/usd/writers/ov_post_processing.py b/threedgrut/export/usd/writers/ov_post_processing.py deleted file mode 100644 index 977ec640..00000000 --- a/threedgrut/export/usd/writers/ov_post_processing.py +++ /dev/null @@ -1,385 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Omniverse USD post-processing fallback writer for PPISP exports. - -This writer is a degraded fallback for Kit versions where SPG is unavailable or -unreliable. It authors Omniverse USD render settings only; exact PPISP export -remains the SPG path. -""" - -from __future__ import annotations - -import logging -from typing import Dict, Iterable, List, Tuple - -import numpy as np -from pxr import Gf, Sdf, Usd - -from threedgrut.export.usd.writers.camera import _make_usd_prim_name - -log = logging.getLogger(__name__) - -MODE_PPISP_OMNI_FALLBACK_NONE = "none" -MODE_PPISP_OMNI_FALLBACK_EXPOSURE = "ppisp-exposure-fallback" -MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING = "ppisp-fitted-post-processing-fallback" -MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING = "ppisp-spg-plus-fitted-post-processing-fallback" - -PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES = { - MODE_PPISP_OMNI_FALLBACK_NONE, - MODE_PPISP_OMNI_FALLBACK_EXPOSURE, - MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING, - MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, -} - -_BASE_EXPOSURE_TIME_SECONDS = 0.02 -_DEFAULT_EXPOSURE_FSTOP = 5.0 -_DEFAULT_EXPOSURE_ISO = 100.0 -_DEFAULT_EXPOSURE_RESPONSIVITY = 1.10267091 -_DEFAULT_RENDER_RESOLUTION = (1280, 720) - -_CAMERA_EXPOSURE_APIS = ["OmniRtxCameraExposureAPI_1"] -_RENDER_PRODUCT_APIS = [ - "OmniRtxPostTonemapIrayReinhardAPI_1", - "OmniRtxPostColorGradingAPI_1", - "OmniRtxPostTvNoiseAPI_1", -] - -_ZCA_BLUE = np.array([[0.0480542, -0.0043631], [-0.0043631, 0.0481283]], dtype=np.float64) -_ZCA_RED = np.array([[0.0580570, -0.0179872], [-0.0179872, 0.0431061]], dtype=np.float64) -_ZCA_GREEN = np.array([[0.0433336, -0.0180537], [-0.0180537, 0.0580500]], dtype=np.float64) -_ZCA_NEUTRAL = np.array([[0.0128369, -0.0034654], [-0.0034654, 0.0128158]], dtype=np.float64) - - -def normalize_ov_post_processing_mode(mode: str | None) -> str: - """Normalize and validate the ``export_usd.ov-post-processing`` value.""" - normalized = MODE_PPISP_OMNI_FALLBACK_NONE if mode is None else str(mode).strip().lower() - if normalized not in PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES: - raise ValueError( - f"Unsupported ov-post-processing mode '{mode}'. " - f"Expected one of: {sorted(PPISP_OMNI_POST_PROCESSING_FALLBACK_MODES)}" - ) - return normalized - - -def _as_numpy(value) -> np.ndarray: - if hasattr(value, "detach"): - value = value.detach() - if hasattr(value, "cpu"): - value = value.cpu() - if hasattr(value, "numpy"): - return value.numpy() - return np.asarray(value) - - -def _prepend_api_schemas(prim: Usd.Prim, schemas: Iterable[str]) -> None: - """Apply schemas by authoring the same listOp shape used by Kit examples.""" - schemas = [schema for schema in schemas if schema] - if not schemas: - return - prim.SetMetadata("apiSchemas", Sdf.TokenListOp.Create(prependedItems=schemas)) - - -def _create_float_attr(prim: Usd.Prim, name: str, value: float): - attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Float) - attr.Set(float(value)) - return attr - - -def _create_bool_attr(prim: Usd.Prim, name: str, value: bool): - attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Bool) - attr.Set(bool(value)) - return attr - - -def _create_color_attr(prim: Usd.Prim, name: str, value) -> Usd.Attribute: - vec = Gf.Vec3f(float(value[0]), float(value[1]), float(value[2])) - attr = prim.CreateAttribute(name, Sdf.ValueTypeNames.Color3f) - attr.Set(vec) - return attr - - -def _compute_homography(color_latent: np.ndarray) -> np.ndarray: - """Compute PPISP's RGI homography from one 8-float color latent vector.""" - b_lat = color_latent[0:2] - r_lat = color_latent[2:4] - g_lat = color_latent[4:6] - n_lat = color_latent[6:8] - - bd = _ZCA_BLUE @ b_lat - rd = _ZCA_RED @ r_lat - gd = _ZCA_GREEN @ g_lat - nd = _ZCA_NEUTRAL @ n_lat - - t_blue = np.array([bd[0], bd[1], 1.0], dtype=np.float64) - t_red = np.array([1.0 + rd[0], rd[1], 1.0], dtype=np.float64) - t_green = np.array([gd[0], 1.0 + gd[1], 1.0], dtype=np.float64) - t_gray = np.array([1.0 / 3.0 + nd[0], 1.0 / 3.0 + nd[1], 1.0], dtype=np.float64) - - target = np.stack([t_blue, t_red, t_green], axis=1) - skew = np.array( - [ - [0.0, -t_gray[2], t_gray[1]], - [t_gray[2], 0.0, -t_gray[0]], - [-t_gray[1], t_gray[0], 0.0], - ], - dtype=np.float64, - ) - matrix = skew @ target - lam = np.cross(matrix[0], matrix[1]) - if np.dot(lam, lam) < 1.0e-20: - lam = np.cross(matrix[0], matrix[2]) - if np.dot(lam, lam) < 1.0e-20: - lam = np.cross(matrix[1], matrix[2]) - - source_inv = np.array([[-1.0, -1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float64) - homography = target @ np.diag(lam) @ source_inv - if abs(homography[2, 2]) > 1.0e-20: - homography = homography / homography[2, 2] - return homography - - -def _apply_color_homography(rgb: np.ndarray, homography: np.ndarray) -> np.ndarray: - intensity = np.sum(rgb, axis=1) - rgi = np.stack([rgb[:, 0], rgb[:, 1], intensity], axis=1) - corrected = rgi @ homography.T - corrected = corrected * (intensity / (corrected[:, 2] + 1.0e-5))[:, None] - return np.stack([corrected[:, 0], corrected[:, 1], corrected[:, 2] - corrected[:, 0] - corrected[:, 1]], axis=1) - - -def _fit_grade_gain(color_latent: np.ndarray) -> np.ndarray: - """Fit USD color grade gain to PPISP's cross-channel homography.""" - homography = _compute_homography(color_latent) - values = np.linspace(0.05, 1.0, 5, dtype=np.float64) - rgb = np.array(np.meshgrid(values, values, values), dtype=np.float64).T.reshape(-1, 3) - target = np.clip(_apply_color_homography(rgb, homography), 0.0, 4.0) - denom = np.maximum(np.sum(rgb * rgb, axis=0), 1.0e-8) - gain = np.sum(target * rgb, axis=0) / denom - return np.clip(gain, 0.25, 4.0) - - -def _bounded_softplus(raw: np.ndarray, min_value: float) -> np.ndarray: - return min_value + np.log1p(np.exp(raw)) - - -def _sigmoid(raw: np.ndarray) -> np.ndarray: - return 1.0 / (1.0 + np.exp(-raw)) - - -def _apply_crf(x: np.ndarray, raw_params: np.ndarray) -> np.ndarray: - toe = _bounded_softplus(raw_params[0], 0.3) - shoulder = _bounded_softplus(raw_params[1], 0.3) - gamma = _bounded_softplus(raw_params[2], 0.1) - center = _sigmoid(raw_params[3]) - lerp_val = (shoulder - toe) * center + toe - a = (shoulder * center) / lerp_val - b = 1.0 - a - y = np.where( - x <= center, - a * np.power(x / center, toe), - 1.0 - b * np.power((1.0 - x) / (1.0 - center), shoulder), - ) - return np.power(np.maximum(0.0, y), gamma) - - -def _fit_grade_gamma(crf_params: np.ndarray) -> np.ndarray: - """Fit USD grade gamma to PPISP's per-channel CRF.""" - x = np.linspace(0.02, 0.98, 96, dtype=np.float64) - candidates = np.linspace(0.25, 4.0, 128, dtype=np.float64) - result = [] - for channel in range(3): - target = _apply_crf(x, crf_params[channel]) - errors = [np.mean((np.power(x, 1.0 / gamma) - target) ** 2) for gamma in candidates] - result.append(float(candidates[int(np.argmin(errors))])) - return np.asarray(result, dtype=np.float64) - - -def _ppisp_vignette_luminance(vig_params: np.ndarray, width: int, height: int) -> Tuple[np.ndarray, np.ndarray]: - sample_w = 48 - sample_h = max(8, int(round(sample_w * max(height, 1) / max(width, 1)))) - xs = (np.arange(sample_w, dtype=np.float64) + 0.5 - sample_w * 0.5) / sample_w - ys = (np.arange(sample_h, dtype=np.float64) + 0.5 - sample_h * 0.5) / sample_w - grid_x, grid_y = np.meshgrid(xs, ys) - uv = np.stack([grid_x, grid_y], axis=-1) - - rgb_falloff = [] - for channel in range(3): - center = vig_params[channel, 0:2] - delta = uv - center - r2 = np.sum(delta * delta, axis=-1) - falloff = 1.0 + vig_params[channel, 2] * r2 + vig_params[channel, 3] * r2**2 + vig_params[channel, 4] * r2**3 - rgb_falloff.append(np.clip(falloff, 0.0, 1.0)) - rgb_falloff = np.stack(rgb_falloff, axis=-1) - luminance = np.dot(rgb_falloff, np.array([0.2126, 0.7152, 0.0722], dtype=np.float64)) - - org_u = (np.arange(sample_w, dtype=np.float64) + 0.5) / sample_w - org_v = (np.arange(sample_h, dtype=np.float64) + 0.5) / sample_h - org_x, org_y = np.meshgrid(org_u, org_v) - org_uv = np.stack([org_x, org_y], axis=-1) - return luminance, org_uv - - -def _fit_tv_vignette(vig_params: np.ndarray, width: int, height: int) -> Tuple[bool, float, float]: - target, org_uv = _ppisp_vignette_luminance(vig_params, width, height) - if float(np.max(np.abs(target - 1.0))) < 1.0e-3: - return False, 107.0, 0.7 - - uv2 = org_uv * (1.0 - org_uv) - base = uv2[..., 0] * uv2[..., 1] - best_error = float("inf") - best_size = 107.0 - best_strength = 0.7 - - for size in np.linspace(1.0, 180.0, 72): - raw = np.maximum(base * (size + 14.0), 1.0e-8) - for strength in np.linspace(0.2, 2.0, 73): - candidate = np.power(raw, strength) - error = float(np.mean((candidate - target) ** 2)) - if error < best_error: - best_error = error - best_size = float(size) - best_strength = float(strength) - - return True, best_size, best_strength - - -def _author_camera_exposure( - stage: Usd.Stage, - camera_path: str, - frame_indices: List[int], - exposure_params: np.ndarray, -) -> None: - camera_prim = stage.GetPrimAtPath(camera_path) - if not camera_prim.IsValid(): - log.warning("Cannot author Omniverse fallback exposure: missing camera prim %s", camera_path) - return - - _prepend_api_schemas(camera_prim, _CAMERA_EXPOSURE_APIS) - _create_float_attr(camera_prim, "exposure:fStop", _DEFAULT_EXPOSURE_FSTOP) - _create_float_attr(camera_prim, "exposure:iso", _DEFAULT_EXPOSURE_ISO) - _create_float_attr(camera_prim, "exposure:responsivity", _DEFAULT_EXPOSURE_RESPONSIVITY) - - valid = [frame_idx for frame_idx in frame_indices if frame_idx < len(exposure_params)] - exposure_values = np.exp2(exposure_params[valid]) * _BASE_EXPOSURE_TIME_SECONDS if valid else np.asarray([]) - default_value = float(np.mean(exposure_values)) if len(exposure_values) else _BASE_EXPOSURE_TIME_SECONDS - - exposure_time = camera_prim.CreateAttribute("exposure:time", Sdf.ValueTypeNames.Float) - exposure_time.Set(default_value) - for frame_idx, value in zip(valid, exposure_values): - exposure_time.Set(float(value), float(frame_idx)) - - -def _author_tv_vignette(render_product: Usd.Prim, vig_params: np.ndarray, width: int, height: int) -> None: - enabled, size, strength = _fit_tv_vignette(vig_params, width, height) - _create_bool_attr(render_product, "omni:rtx:post:tvNoise:enabled", enabled) - _create_bool_attr(render_product, "omni:rtx:post:tvNoise:vignetting:enabled", enabled) - _create_float_attr(render_product, "omni:rtx:post:tvNoise:vignetting:size", size) - _create_float_attr(render_product, "omni:rtx:post:tvNoise:vignetting:strength", strength) - - for attr_name in ( - "omni:rtx:post:tvNoise:filmGrain:enabled", - "omni:rtx:post:tvNoise:ghostFlickering:enabled", - "omni:rtx:post:tvNoise:randomSplotches:enabled", - "omni:rtx:post:tvNoise:scanlines:enabled", - "omni:rtx:post:tvNoise:scrollBug:enabled", - "omni:rtx:post:tvNoise:verticalLines:enabled", - "omni:rtx:post:tvNoise:vignetting:flickering:enabled", - "omni:rtx:post:tvNoise:waveDistortion:enabled", - ): - _create_bool_attr(render_product, attr_name, False) - - -def _author_color_grade( - render_product: Usd.Prim, - frame_indices: List[int], - color_params: np.ndarray, - crf_params: np.ndarray, -) -> None: - valid = [frame_idx for frame_idx in frame_indices if frame_idx < len(color_params)] - gains = np.stack([_fit_grade_gain(color_params[frame_idx]) for frame_idx in valid], axis=0) if valid else np.ones((0, 3)) - default_gain = np.mean(gains, axis=0) if len(gains) else np.ones(3, dtype=np.float64) - gamma = _fit_grade_gamma(crf_params) - - grade_enabled = bool(np.max(np.abs(default_gain - 1.0)) > 1.0e-3 or np.max(np.abs(gamma - 1.0)) > 1.0e-3) - _create_bool_attr(render_product, "omni:rtx:post:grade:enabled", grade_enabled) - gain_attr = _create_color_attr(render_product, "omni:rtx:post:grade:gain", default_gain) - _create_color_attr(render_product, "omni:rtx:post:grade:gamma", gamma) - _create_color_attr(render_product, "omni:rtx:post:grade:offset", (0.0, 0.0, 0.0)) - _create_color_attr(render_product, "omni:rtx:post:grade:contrast", (1.0, 1.0, 1.0)) - _create_color_attr(render_product, "omni:rtx:post:grade:saturation", (1.0, 1.0, 1.0)) - - for frame_idx, gain in zip(valid, gains): - gain_attr.Set(Gf.Vec3f(float(gain[0]), float(gain[1]), float(gain[2])), float(frame_idx)) - - -def add_ov_post_processing( - stage: Usd.Stage, - camera_names: List[str], - camera_prim_paths: Dict[str, str], - camera_frame_mapping: Dict[str, List[int]], - render_product_entries: Dict[str, Tuple[str, int, int]], - post_processing, - mode: str, - render_scope_path: str = "/Render", -) -> None: - """Author Omniverse USD post-processing settings for PPISP fallback export.""" - normalized_mode = normalize_ov_post_processing_mode(mode) - if normalized_mode == MODE_PPISP_OMNI_FALLBACK_NONE: - return - - exposure_params = _as_numpy(post_processing.exposure_params) - color_params = _as_numpy(post_processing.color_params) - vignetting_params = _as_numpy(post_processing.vignetting_params) - crf_params = _as_numpy(post_processing.crf_params) - - camera_name_to_index = {name: idx for idx, name in enumerate(camera_names)} - writes_fitted_post_processing = normalized_mode in { - MODE_PPISP_OMNI_FALLBACK_FITTED_POST_PROCESSING, - MODE_PPISP_OMNI_FALLBACK_SPG_PLUS_FITTED_POST_PROCESSING, - } - - for camera_name in camera_names: - frame_indices = camera_frame_mapping.get(camera_name, []) - camera_path = camera_prim_paths.get(camera_name) - if camera_path is None: - log.warning("Skipping Omniverse post-processing fallback for %s: missing camera prim", camera_name) - continue - - _author_camera_exposure(stage, camera_path, frame_indices, exposure_params) - - if not writes_fitted_post_processing: - continue - - camera_index = camera_name_to_index[camera_name] - render_product_name = _make_usd_prim_name(camera_name) - render_product_path = f"{render_scope_path}/{render_product_name}" - render_product = stage.GetPrimAtPath(render_product_path) - if not render_product.IsValid(): - log.warning("Skipping Omniverse post-processing fallback for %s: missing RenderProduct", camera_name) - continue - - _prepend_api_schemas(render_product, _RENDER_PRODUCT_APIS) - _, width, height = render_product_entries.get(camera_name, ("", *_DEFAULT_RENDER_RESOLUTION)) - width = width or _DEFAULT_RENDER_RESOLUTION[0] - height = height or _DEFAULT_RENDER_RESOLUTION[1] - - _author_tv_vignette(render_product, vignetting_params[camera_index], width, height) - _author_color_grade(render_product, frame_indices, color_params, crf_params[camera_index]) - - log.warning( - "Authored Omniverse USD post-processing PPISP fallback mode '%s'. " - "This is approximate and not SPG-fidelity.", - normalized_mode, - ) From 5d033154b7cbd11ad37e13236fe62589c33c6af2 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Wed, 29 Apr 2026 11:10:06 -0400 Subject: [PATCH 14/42] feat(export): support fixed PPISP native export Add fixed camera/frame controls for PPISP omni-native export and extend bake validation with L1 fitting plus reference baseline outputs. --- configs/base_gs.yaml | 6 ++- threedgrut/export/scripts/export_usd.py | 26 +++++++++ .../post_processing_sh_bake_validation.py | 45 ++++++++++++++-- threedgrut/export/usd/exporter.py | 30 +++++++++++ .../export/usd/post_processing_sh_bake.py | 2 +- threedgrut/export/usd/writers/ppisp_writer.py | 54 +++++++++++++++++-- 6 files changed, 153 insertions(+), 10 deletions(-) diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 9b36a34e..12fb96b4 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -56,8 +56,12 @@ export_usd: # omni-native uses the module-specific Omniverse-native path; currently PPISP SPG. # baked-sh | omni-native post-processing-export-mode: baked-sh + # Optional fixed PPISP camera/frame for omni-native export. When frame is set, + # exposure/color are exported as static shader inputs instead of animation. + post-processing-export-camera-id: null + post-processing-export-frame-id: null # Number of sequential passes over the train/reference set used for fitting. - post-processing-bake-epochs: 1 + post-processing-bake-epochs: 3 post-processing-bake-learning-rate: 0.001 post-processing-bake-camera-id: 0 post-processing-bake-frame-id: 0 diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 26479f0a..1ce3696e 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -145,6 +145,18 @@ def parse_args(): default=None, help="Post-processing export mode. 'omni-native' uses PPISP SPG and Omniverse material authoring.", ) + parser.add_argument( + "--post-processing-export-camera-id", + type=int, + default=None, + help="Optional PPISP camera id to use for every RenderProduct in omni-native export.", + ) + parser.add_argument( + "--post-processing-export-frame-id", + type=int, + default=None, + help="Optional PPISP frame id to write as static omni-native shader inputs instead of animation.", + ) parser.add_argument( "--post-processing-bake-epochs", type=int, @@ -380,6 +392,20 @@ def main(): linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), export_post_processing=export_post_processing, post_processing_export_mode=post_processing_export_mode, + post_processing_export_camera_id=_arg_or_conf( + args.post_processing_export_camera_id, + export_conf, + "post-processing-export-camera-id", + "post_processing_export_camera_id", + None, + ), + post_processing_export_frame_id=_arg_or_conf( + args.post_processing_export_frame_id, + export_conf, + "post-processing-export-frame-id", + "post_processing_export_frame_id", + None, + ), post_processing_bake_epochs=_arg_or_conf( args.post_processing_bake_epochs, export_conf, diff --git a/threedgrut/export/scripts/post_processing_sh_bake_validation.py b/threedgrut/export/scripts/post_processing_sh_bake_validation.py index f4a0b564..39615c19 100644 --- a/threedgrut/export/scripts/post_processing_sh_bake_validation.py +++ b/threedgrut/export/scripts/post_processing_sh_bake_validation.py @@ -50,6 +50,7 @@ ) from threedgrut.utils.color_correct import color_correct_affine from threedgrut.utils.logger import logger +from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb from threedgrut.utils.render import apply_post_processing @@ -127,8 +128,8 @@ def _fitBakedSh( optimizer.zero_grad(set_to_none=True) bakedOutputs = bakedModel(gpuBatch) - fittedRgb = _applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode) - loss = torch.nn.functional.mse_loss(fittedRgb, referenceRgb) + fittedRgb = torch.clamp(linear_to_srgb(_applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode)), 0, 1) + loss = torch.nn.functional.l1_loss(fittedRgb, referenceRgb) loss.backward() optimizer.step() @@ -147,6 +148,7 @@ def _evaluateBakedSh( referenceModel, bakedModel, fixedPpisp, + fullFixedPpisp, dataset, dataloader, outputRoot: Path, @@ -160,13 +162,18 @@ def _evaluateBakedSh( "lpips": LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to("cuda"), } + fullReferencePath = outputRoot / "full_ppisp_reference" referencePath = outputRoot / "reference" + unfittedPath = outputRoot / "unfitted" bakedPath = outputRoot / "baked" assistedPath = outputRoot / "baked_assisted" + fullReferencePath.mkdir(parents=True, exist_ok=True) referencePath.mkdir(parents=True, exist_ok=True) + unfittedPath.mkdir(parents=True, exist_ok=True) bakedPath.mkdir(parents=True, exist_ok=True) assistedPath.mkdir(parents=True, exist_ok=True) + unfittedPsnrValues = [] psnrValues = [] ssimValues = [] lpipsValues = [] @@ -185,15 +192,34 @@ def _evaluateBakedSh( for iteration, batch in enumerate(dataloader): gpuBatch = dataset.get_gpu_batch_with_intrinsics(batch) + fullReferenceRgb = _renderReference(referenceModel, fullFixedPpisp, gpuBatch) referenceRgb = _renderReference(referenceModel, fixedPpisp, gpuBatch) + unfittedOutputs = referenceModel(gpuBatch) + unfittedRgb = unfittedOutputs["pred_rgb"] bakedOutputs = bakedModel(gpuBatch) - bakedRgb = bakedOutputs["pred_rgb"] - assistedRgb = _applyAchromaticVignetting(bakedRgb, fixedPpisp, gpuBatch, vignettingMode) + bakedRgb = torch.clamp( + linear_to_srgb(bakedOutputs["pred_rgb"]), + 0, + 1, + ) + assistedRgb = torch.clamp( + linear_to_srgb(_applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode)), + 0, + 1, + ) + torchvision.utils.save_image( + fullReferenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + fullReferencePath / f"{iteration:05d}.png", + ) torchvision.utils.save_image( referenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), referencePath / f"{iteration:05d}.png", ) + torchvision.utils.save_image( + unfittedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + unfittedPath / f"{iteration:05d}.png", + ) torchvision.utils.save_image( bakedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), bakedPath / f"{iteration:05d}.png", @@ -203,6 +229,7 @@ def _evaluateBakedSh( assistedPath / f"{iteration:05d}.png", ) + unfittedPsnrValues.append(criterions["psnr"](unfittedRgb, referenceRgb).item()) psnrValues.append(criterions["psnr"](bakedRgb, referenceRgb).item()) assistedPsnrValues.append(criterions["psnr"](assistedRgb, referenceRgb).item()) if computeExtraMetrics: @@ -268,6 +295,8 @@ def _evaluateBakedSh( metrics = { "vignetting_mode": vignettingMode, + "unfitted_mean_psnr": float(np.mean(unfittedPsnrValues)), + "unfitted_std_psnr": float(np.std(unfittedPsnrValues)), "mean_psnr": float(np.mean(psnrValues)), "std_psnr": float(np.std(psnrValues)), "assisted_mean_psnr": float(np.mean(assistedPsnrValues)), @@ -362,6 +391,13 @@ def main() -> None: "cuda", include_vignetting=vignettingMode != MODE_PPISP_BAKE_VIGNETTING_NONE, ).eval() + fullFixedPpisp = FixedPPISP( + renderer.post_processing, + args.cameraId, + args.frameId, + "cuda", + include_vignetting=True, + ).eval() referenceModel = renderer.model.eval() bakedModel = renderer.model.clone().eval() @@ -388,6 +424,7 @@ def main() -> None: referenceModel=referenceModel, bakedModel=bakedModel, fixedPpisp=fixedPpisp, + fullFixedPpisp=fullFixedPpisp, dataset=renderer.dataset, dataloader=renderer.dataloader, outputRoot=outputRoot, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index eb9e301a..8fc1b14a 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -277,6 +277,8 @@ def __init__( linear_srgb: bool = False, export_post_processing: bool = True, post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, + post_processing_export_camera_id: int | None = None, + post_processing_export_frame_id: int | None = None, post_processing_bake_epochs: int = 1, post_processing_bake_learning_rate: float = 1.0e-3, post_processing_bake_camera_id: int = 0, @@ -301,6 +303,10 @@ def __init__( post_processing_export_mode: "baked-sh" bakes one fixed transform into Gaussian SH coefficients. "omni-native" uses the module's Omniverse-native path; currently PPISP SPG. + post_processing_export_camera_id: Optional PPISP camera index to use + for every RenderProduct in omni-native mode. + post_processing_export_frame_id: Optional PPISP frame index to write + as static exposure/color inputs in omni-native mode. post_processing_bake_epochs: Number of sequential passes over the train/reference set. post_processing_bake_learning_rate: Adam learning rate for baked SH. post_processing_bake_camera_id: Camera index for the fixed baked transform. @@ -324,6 +330,12 @@ def __init__( self.linear_srgb = linear_srgb self.export_post_processing = export_post_processing self.post_processing_export_mode = normalize_post_processing_export_mode(post_processing_export_mode) + self.post_processing_export_camera_id = ( + None if post_processing_export_camera_id is None else int(post_processing_export_camera_id) + ) + self.post_processing_export_frame_id = ( + None if post_processing_export_frame_id is None else int(post_processing_export_frame_id) + ) self.post_processing_bake_epochs = int(post_processing_bake_epochs) self.post_processing_bake_learning_rate = float(post_processing_bake_learning_rate) self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) @@ -587,6 +599,8 @@ def export( camera_names=camera_names, post_processing=post_processing, files=files, + fixed_camera_id=self.post_processing_export_camera_id, + fixed_frame_id=self.post_processing_export_frame_id, ) # Package @@ -653,6 +667,8 @@ def _export_ppisp( camera_names, post_processing, files: List[NamedSerialized], + fixed_camera_id: int | None = None, + fixed_frame_id: int | None = None, ) -> None: """Attach PPISP SPG shaders to existing RenderProducts.""" try: @@ -695,6 +711,8 @@ def _export_ppisp( ppisp=post_processing, camera_names=camera_names, camera_frame_mapping=camera_frame_mapping, + fixed_camera_index=fixed_camera_id, + fixed_frame_index=fixed_frame_id, ) except Exception as e: logger.warning(f"Failed to add PPISP shaders: {e}") @@ -739,6 +757,18 @@ def from_config(cls, conf) -> "USDExporter": "post_processing_export_mode", MODE_POST_PROCESSING_EXPORT_BAKED_SH, ), + post_processing_export_camera_id=_get_export_config_value( + export_conf, + "post-processing-export-camera-id", + "post_processing_export_camera_id", + None, + ), + post_processing_export_frame_id=_get_export_config_value( + export_conf, + "post-processing-export-frame-id", + "post_processing_export_frame_id", + None, + ), post_processing_bake_epochs=_get_export_config_value( export_conf, "post-processing-bake-epochs", diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index ac1a3acf..8d7c37e6 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -137,7 +137,7 @@ def bake_post_processing_into_sh( fixed_post_processing, gpu_batch, ) - loss = torch.nn.functional.mse_loss(fitted_rgb, reference_rgb) + loss = torch.nn.functional.l1_loss(fitted_rgb, reference_rgb) loss.backward() optimizer.step() diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index 6c05d555..61ef041c 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -229,6 +229,18 @@ def _set_animated_exposure_params( attr.Set(float(exposure[frame_idx]), float(frame_idx)) +def _set_static_exposure_params( + shader: UsdShade.Shader, + ppisp: PPISP, + frame_index: int, +) -> None: + """Write one fixed exposure offset without USD time samples.""" + exposure = ppisp.exposure_params.cpu().numpy() + if frame_index < 0 or frame_index >= len(exposure): + raise ValueError(f"frame_index must be in [0, {len(exposure) - 1}], got {frame_index}.") + shader.CreateInput("exposureOffset", Sdf.ValueTypeNames.Float).Set(float(exposure[frame_index])) + + def _set_animated_color_params( shader: UsdShade.Shader, ppisp: PPISP, @@ -263,6 +275,24 @@ def _set_animated_color_params( ) +def _set_static_color_params( + shader: UsdShade.Shader, + ppisp: PPISP, + frame_index: int, +) -> None: + """Write one fixed color latent state without USD time samples.""" + color = ppisp.color_params.cpu().numpy() + if frame_index < 0 or frame_index >= len(color): + raise ValueError(f"frame_index must be in [0, {len(color) - 1}], got {frame_index}.") + + frame_color = color[frame_index] + control_point_names = ["colorLatentBlue", "colorLatentRed", "colorLatentGreen", "colorLatentNeutral"] + for i, name in enumerate(control_point_names): + shader.CreateInput(name, Sdf.ValueTypeNames.Float2).Set( + Gf.Vec2f(float(frame_color[i * 2]), float(frame_color[i * 2 + 1])) + ) + + # --------------------------------------------------------------------------- # Per-camera entry point # --------------------------------------------------------------------------- @@ -274,6 +304,7 @@ def add_ppisp_shader_to_render_product( camera_index: int, ppisp: PPISP, frame_indices: List[int], + fixed_frame_index: int | None = None, ) -> Usd.Prim: """Add a PPISP Shader to a RenderProduct for one physical camera. @@ -288,6 +319,8 @@ def add_ppisp_shader_to_render_product( camera_index: Index of this camera in the PPISP model. ppisp: Trained PPISP module. frame_indices: Global frame indices belonging to this camera. + fixed_frame_index: If set, write this one PPISP frame state as static + shader inputs instead of authoring animated time samples. Returns: The created PPISP Shader prim. @@ -295,15 +328,19 @@ def add_ppisp_shader_to_render_product( assert camera_index < ppisp.num_cameras, ( f"camera_index {camera_index} >= ppisp.num_cameras {ppisp.num_cameras}" ) - if not frame_indices: + if not frame_indices and fixed_frame_index is None: log.warning(f"No frames for camera {camera_index} at {render_product_path}, skipping") return stage.GetPseudoRoot() shader = _create_shader_prim(stage, render_product_path) _set_vignetting_params(shader, ppisp, camera_index) _set_crf_params(shader, ppisp, camera_index) - _set_animated_exposure_params(shader, ppisp, frame_indices) - _set_animated_color_params(shader, ppisp, frame_indices) + if fixed_frame_index is None: + _set_animated_exposure_params(shader, ppisp, frame_indices) + _set_animated_color_params(shader, ppisp, frame_indices) + else: + _set_static_exposure_params(shader, ppisp, fixed_frame_index) + _set_static_color_params(shader, ppisp, fixed_frame_index) log.info( f"Added PPISP shader to {render_product_path} " @@ -358,6 +395,8 @@ def add_ppisp_to_all_render_products( camera_names: List[str], camera_frame_mapping: Dict[str, List[int]], render_scope_path: str = "/Render", + fixed_camera_index: int | None = None, + fixed_frame_index: int | None = None, ) -> List[Usd.Prim]: """Add PPISP shaders to every RenderProduct in the Render scope. @@ -368,6 +407,10 @@ def add_ppisp_to_all_render_products( camera_frame_mapping: ``{camera_name: [frame_idx, ...]}`` from :func:`build_camera_frame_mapping`. render_scope_path: Path to the /Render Scope (default ``/Render``). + fixed_camera_index: If set, use this PPISP camera state for every + RenderProduct instead of matching the RenderProduct camera. + fixed_frame_index: If set, use this PPISP frame state as static shader + inputs instead of authoring animated exposure/color samples. Returns: List of created PPISP Shader prims. @@ -397,10 +440,12 @@ def add_ppisp_to_all_render_products( log.warning(f"RenderProduct '{prim_name}' has no matching camera name, skipping") continue - camera_index = camera_name_to_index.get(camera_name) + camera_index = fixed_camera_index if fixed_camera_index is not None else camera_name_to_index.get(camera_name) if camera_index is None: log.warning(f"Camera '{camera_name}' not in camera_names list, skipping") continue + if camera_index < 0 or camera_index >= ppisp.num_cameras: + raise ValueError(f"fixed_camera_index must be in [0, {ppisp.num_cameras - 1}], got {camera_index}.") frame_indices = camera_frame_mapping.get(camera_name, []) _create_ppisp_camera(stage, child) @@ -411,6 +456,7 @@ def add_ppisp_to_all_render_products( camera_index=camera_index, ppisp=ppisp, frame_indices=frame_indices, + fixed_frame_index=fixed_frame_index, ) created.append(shader_prim) From 6f2d4ef0a062323c7b952fa982a952d54c0024a2 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Wed, 29 Apr 2026 11:42:38 -0400 Subject: [PATCH 15/42] feat(tools): add image comparison viewer Add a Viser-based image comparison tool for paired images and matched folders, with visual diff modes and metrics support. --- .cursor/rules/formatter-defaults.mdc | 10 + pyproject.toml | 1 + requirements.txt | 1 + tools/image_comparison/README.md | 61 ++ tools/image_comparison/image_comparison.py | 790 +++++++++++++++++++++ 5 files changed, 863 insertions(+) create mode 100644 .cursor/rules/formatter-defaults.mdc create mode 100644 tools/image_comparison/README.md create mode 100644 tools/image_comparison/image_comparison.py diff --git a/.cursor/rules/formatter-defaults.mdc b/.cursor/rules/formatter-defaults.mdc new file mode 100644 index 00000000..a6e41aaf --- /dev/null +++ b/.cursor/rules/formatter-defaults.mdc @@ -0,0 +1,10 @@ +--- +description: Formatter defaults for Cursor agents +alwaysApply: true +--- + +# Formatter Defaults + +- Do not run `bazel run //:format` by default. +- Prefer the smallest relevant formatter or validation command for the files changed, such as Python syntax checks, lints, or file-scoped formatters. +- Ask before running repository-wide formatting commands, especially when they may take a long time or touch unrelated files. diff --git a/pyproject.toml b/pyproject.toml index 797435ba..d5a2e87f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dev = [ "clang-format==18.1.8", ] gui = [ + "flip-evaluator", "libigl", "polyscope>=2.6.0", "viser", diff --git a/requirements.txt b/requirements.txt index 928e9c01..55c62999 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,3 +39,4 @@ isort==5.13.0 # graphics user interfaces polyscope>=2.3.0 viser +flip-evaluator diff --git a/tools/image_comparison/README.md b/tools/image_comparison/README.md new file mode 100644 index 00000000..65ef590f --- /dev/null +++ b/tools/image_comparison/README.md @@ -0,0 +1,61 @@ +# Image Comparison + +Viser based image comparison viewer for either two specific images or two folders of matching image names. + +## Usage + +Compare two images: + +```bash +python tools/image_comparison/image_comparison.py --images path/to/a.png path/to/b.png +``` + +Compare two folders: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b +``` + +Optional arguments: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b --port 8080 --target_fps 20 +``` + +Serve on all network interfaces for another host: + +```bash +python tools/image_comparison/image_comparison.py --folders path/to/folder_a path/to/folder_b --host 0.0.0.0 --port 8080 +``` + +Then open `http://:8080` from the other host. If direct access is blocked, use SSH forwarding: + +```bash +ssh -L 8080:localhost:8080 user@server-host +``` + +## Viewer Modes + +- `Display Mode = fit_largest_dimension`: scales the image so one dimension fills the viewport while preserving the image aspect ratio. The other dimension is smaller than or equal to the viewport. This is the default. +- `Display Mode = fit`: stretches the image to fill both viewport dimensions. +- `slider`: displays both images in the same frame, split by a vertical or horizontal slider. The split can be changed from the `Slider Position` GUI control. +- `checkerboard`: alternates images with a checkerboard mask. +- `diff`: displays a selectable difference map with a `JET` colormap and a scale slider. + +When folder mode is used, images are matched by file name. Duplicate file names inside one folder are rejected so the comparison target is unambiguous. +Use `Previous Image` and `Next Image` to cycle through matched image pairs. + +## Metrics + +The `Metrics` panel displays scalar `PSNR`, `SSIM`, `LPIPS`, and `FLIP` values for the selected image pair. `PSNR` and `SSIM` are computed automatically. Use `Compute LPIPS / FLIP` to compute the heavier perceptual metrics on demand. + +The `Diff Metric` dropdown supports: + +- `l1`: per-pixel mean absolute RGB difference. +- `l2`: per-pixel RGB root mean squared difference. +- `psnr`: per-pixel PSNR-derived error, where lower PSNR is brighter. +- `ssim`: local `1 - SSIM` dissimilarity. +- `lpips`: scalar LPIPS displayed as a uniform heatmap. +- `flip`: FLIP error map when `flip-evaluator` is installed. + +`LPIPS` depends on `torchmetrics` and its model weights. `FLIP` depends on the `flip-evaluator` package, which provides the `flip_evaluator` Python module. Do not install the unrelated `flip` package. diff --git a/tools/image_comparison/image_comparison.py b/tools/image_comparison/image_comparison.py new file mode 100644 index 00000000..dc853cbb --- /dev/null +++ b/tools/image_comparison/image_comparison.py @@ -0,0 +1,790 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import socket +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +from PIL import Image + +SUPPORTED_IMAGE_EXTENSIONS = { + ".bmp", + ".jpeg", + ".jpg", + ".png", + ".tif", + ".tiff", + ".webp", +} + +COMPARISON_MODES = ["slider", "checkerboard", "diff"] +SLIDER_DIRECTIONS = ["vertical", "horizontal"] +DIFF_METRICS = ["l1", "l2", "psnr", "ssim", "lpips", "flip"] +DISPLAY_MODES = ["fit_largest_dimension", "fit"] + + +@dataclass +class MetricResults: + psnr: Optional[float] + ssim: Optional[float] + lpips: Optional[float] + flip: Optional[float] + lpips_error: Optional[str] = None + flip_error: Optional[str] = None + + +def import_viser(): + try: + import viser + except ImportError: + print('viser not installed, please install the gui extra or run "pip install viser"') + sys.exit(1) + + return viser + + +@dataclass(frozen=True) +class ImagePair: + name: str + image_a_path: Path + image_b_path: Path + + +def is_image_path(path: Path) -> bool: + return path.is_file() and path.suffix.lower() in SUPPORTED_IMAGE_EXTENSIONS + + +def collect_images_by_name(folder: Path) -> Dict[str, Path]: + images: Dict[str, Path] = {} + duplicate_names: List[str] = [] + + for path in sorted(folder.rglob("*")): + if not is_image_path(path): + continue + + image_name = path.name + if image_name in images: + duplicate_names.append(image_name) + continue + + images[image_name] = path + + if duplicate_names: + duplicate_list = ", ".join(sorted(set(duplicate_names))) + raise ValueError(f"Duplicate image names found in {folder}: {duplicate_list}") + + return images + + +def build_specific_image_pair(image_a_path: Path, image_b_path: Path) -> List[ImagePair]: + if not is_image_path(image_a_path): + raise ValueError(f"Invalid image path: {image_a_path}") + if not is_image_path(image_b_path): + raise ValueError(f"Invalid image path: {image_b_path}") + + return [ + ImagePair( + name=f"{image_a_path.name} <-> {image_b_path.name}", + image_a_path=image_a_path, + image_b_path=image_b_path, + ) + ] + + +def build_folder_image_pairs(folder_a_path: Path, folder_b_path: Path) -> List[ImagePair]: + if not folder_a_path.is_dir(): + raise ValueError(f"Invalid folder path: {folder_a_path}") + if not folder_b_path.is_dir(): + raise ValueError(f"Invalid folder path: {folder_b_path}") + + folder_a_images = collect_images_by_name(folder_a_path) + folder_b_images = collect_images_by_name(folder_b_path) + matched_names = sorted(set(folder_a_images).intersection(folder_b_images)) + + if not matched_names: + raise ValueError(f"No matching image names found between {folder_a_path} and {folder_b_path}") + + return [ + ImagePair( + name=image_name, + image_a_path=folder_a_images[image_name], + image_b_path=folder_b_images[image_name], + ) + for image_name in matched_names + ] + + +def load_image_rgb(path: Path) -> np.ndarray: + with Image.open(path) as image: + return np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0 + + +def resize_image(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: + width, height = size + resample_filter = getattr(Image, "Resampling", Image).LANCZOS + resized = Image.fromarray(float_image_to_uint8(image)).resize((width, height), resample_filter) + return np.asarray(resized, dtype=np.float32) / 255.0 + + +def load_aligned_pair(image_pair: ImagePair) -> Tuple[np.ndarray, np.ndarray, str]: + image_a = load_image_rgb(image_pair.image_a_path) + image_b = load_image_rgb(image_pair.image_b_path) + + if image_a.shape == image_b.shape: + status = f"{image_pair.name}: {image_a.shape[1]}x{image_a.shape[0]}" + return image_a, image_b, status + + image_b = resize_image(image_b, (image_a.shape[1], image_a.shape[0])) + status = ( + f"{image_pair.name}: A {image_a.shape[1]}x{image_a.shape[0]}, " + f"B resized to match from {image_pair.image_b_path.name}" + ) + return image_a, image_b, status + + +def render_slider_comparison( + image_a: np.ndarray, + image_b: np.ndarray, + slider_position: float, + slider_direction: str, +) -> np.ndarray: + output = image_b.copy() + height, width = image_a.shape[:2] + slider_position = float(np.clip(slider_position, 0.0, 1.0)) + + if slider_direction == "horizontal": + split_row = int(round(height * slider_position)) + output[:split_row, :] = image_a[:split_row, :] + if 0 < split_row < height: + output[max(0, split_row - 1) : min(height, split_row + 1), :] = 1.0 + else: + split_col = int(round(width * slider_position)) + output[:, :split_col] = image_a[:, :split_col] + if 0 < split_col < width: + output[:, max(0, split_col - 1) : min(width, split_col + 1)] = 1.0 + + return output + + +def render_checkerboard_comparison(image_a: np.ndarray, image_b: np.ndarray, checker_size: int) -> np.ndarray: + checker_size = max(1, checker_size) + height, width = image_a.shape[:2] + y_indices, x_indices = np.indices((height, width)) + checker_mask = ((x_indices // checker_size) + (y_indices // checker_size)) % 2 == 0 + return np.where(checker_mask[..., None], image_a, image_b) + + +def render_diff_comparison(image_a: np.ndarray, image_b: np.ndarray, diff_scale: float) -> np.ndarray: + return render_diff_metric(image_a=image_a, image_b=image_b, diff_metric="l1", diff_scale=diff_scale) + + +def render_diff_metric(image_a: np.ndarray, image_b: np.ndarray, diff_metric: str, diff_scale: float) -> np.ndarray: + error_map = compute_error_map(image_a=image_a, image_b=image_b, diff_metric=diff_metric) + scaled_error = np.clip(error_map * diff_scale, 0.0, 1.0) + return apply_jet_colormap(scaled_error) + + +def compute_error_map(image_a: np.ndarray, image_b: np.ndarray, diff_metric: str) -> np.ndarray: + if diff_metric == "l2": + return np.sqrt(np.mean(np.square(image_a - image_b), axis=-1)) + if diff_metric == "psnr": + return compute_psnr_error_map(image_a=image_a, image_b=image_b) + if diff_metric == "ssim": + return 1.0 - compute_ssim_map(image_a=image_a, image_b=image_b) + if diff_metric == "lpips": + lpips_value, _ = compute_lpips_metric(image_a=image_a, image_b=image_b) + return scalar_metric_to_map(lpips_value, image_a.shape[:2], higher_is_worse=True) + if diff_metric == "flip": + flip_map, flip_value, _ = compute_flip_metric(image_a=image_a, image_b=image_b) + if flip_map is not None: + return normalize_error_map(flip_map) + return scalar_metric_to_map(flip_value, image_a.shape[:2], higher_is_worse=True) + + return np.mean(np.abs(image_a - image_b), axis=-1) + + +def compute_psnr(image_a: np.ndarray, image_b: np.ndarray) -> float: + mse = float(np.mean(np.square(image_a - image_b))) + if mse <= 1.0e-12: + return math.inf + return 10.0 * math.log10(1.0 / mse) + + +def compute_psnr_error_map(image_a: np.ndarray, image_b: np.ndarray) -> np.ndarray: + mse = np.mean(np.square(image_a - image_b), axis=-1) + psnr_map = -10.0 * np.log10(np.maximum(mse, 1.0e-12)) + return 1.0 - np.clip(psnr_map / 60.0, 0.0, 1.0) + + +def compute_ssim(image_a: np.ndarray, image_b: np.ndarray) -> float: + return float(np.mean(compute_ssim_map(image_a=image_a, image_b=image_b))) + + +def compute_ssim_map(image_a: np.ndarray, image_b: np.ndarray, kernel_size: int = 11) -> np.ndarray: + image_a = np.clip(image_a, 0.0, 1.0) + image_b = np.clip(image_b, 0.0, 1.0) + + mu_a = box_filter(image_a, kernel_size=kernel_size) + mu_b = box_filter(image_b, kernel_size=kernel_size) + mu_a_squared = np.square(mu_a) + mu_b_squared = np.square(mu_b) + mu_ab = mu_a * mu_b + + sigma_a_squared = box_filter(np.square(image_a), kernel_size=kernel_size) - mu_a_squared + sigma_b_squared = box_filter(np.square(image_b), kernel_size=kernel_size) - mu_b_squared + sigma_ab = box_filter(image_a * image_b, kernel_size=kernel_size) - mu_ab + + c1 = 0.01**2 + c2 = 0.03**2 + numerator = (2.0 * mu_ab + c1) * (2.0 * sigma_ab + c2) + denominator = (mu_a_squared + mu_b_squared + c1) * (sigma_a_squared + sigma_b_squared + c2) + ssim = numerator / np.maximum(denominator, 1.0e-12) + return np.clip(np.mean(ssim, axis=-1), 0.0, 1.0) + + +def box_filter(image: np.ndarray, kernel_size: int) -> np.ndarray: + kernel_size = max(1, int(kernel_size)) + if kernel_size % 2 == 0: + kernel_size += 1 + + radius = kernel_size // 2 + padded = np.pad(image, ((radius, radius), (radius, radius), (0, 0)), mode="edge") + integral = np.pad(padded, ((1, 0), (1, 0), (0, 0)), mode="constant") + integral = np.cumsum(np.cumsum(integral, axis=0), axis=1) + summed = ( + integral[kernel_size:, kernel_size:] + - integral[:-kernel_size, kernel_size:] + - integral[kernel_size:, :-kernel_size] + + integral[:-kernel_size, :-kernel_size] + ) + return summed / float(kernel_size * kernel_size) + + +def scalar_metric_to_map(value: Optional[float], shape: Tuple[int, int], higher_is_worse: bool) -> np.ndarray: + if value is None or not np.isfinite(value): + return np.zeros(shape, dtype=np.float32) + + if higher_is_worse: + normalized_value = float(np.clip(value, 0.0, 1.0)) + else: + normalized_value = 1.0 - float(np.clip(value, 0.0, 1.0)) + + return np.full(shape, normalized_value, dtype=np.float32) + + +def normalize_error_map(error_map: np.ndarray) -> np.ndarray: + if error_map.ndim == 3: + error_map = np.mean(error_map[..., :3], axis=-1) + return np.clip(error_map.astype(np.float32), 0.0, 1.0) + + +def image_to_torch_tensor(image: np.ndarray): + import torch + + return torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).unsqueeze(0).float() + + +def compute_lpips_metric(image_a: np.ndarray, image_b: np.ndarray) -> Tuple[Optional[float], Optional[str]]: + try: + import torch + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + except ImportError as exc: + return None, f"LPIPS unavailable: {exc}" + + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to(device) + with torch.no_grad(): + value = metric(image_to_torch_tensor(image_a).to(device), image_to_torch_tensor(image_b).to(device)) + return float(value.detach().cpu().item()), None + except Exception as exc: + return None, f"LPIPS failed: {exc}" + + +def compute_flip_metric( + image_a: np.ndarray, image_b: np.ndarray +) -> Tuple[Optional[np.ndarray], Optional[float], Optional[str]]: + try: + import flip_evaluator + except ImportError as exc: + try: + import nbflip as flip_evaluator + except ImportError: + return ( + None, + None, + f"FLIP unavailable: install NVIDIA FLIP with `pip install flip-evaluator`, not `flip`: {exc}", + ) + + try: + flip_map, mean_flip, _ = flip_evaluator.evaluate( + np.ascontiguousarray(image_a.astype(np.float32)), + np.ascontiguousarray(image_b.astype(np.float32)), + "ldr", + True, + False, + True, + {}, + ) + return normalize_error_map(flip_map), float(mean_flip), None + except Exception as exc: + return None, None, f"FLIP failed: {exc}" + + +def compute_metric_results( + image_a: np.ndarray, + image_b: np.ndarray, + include_perceptual_metrics: bool = False, +) -> MetricResults: + if include_perceptual_metrics: + lpips_value, lpips_error = compute_lpips_metric(image_a=image_a, image_b=image_b) + _, flip_value, flip_error = compute_flip_metric(image_a=image_a, image_b=image_b) + else: + lpips_value = None + flip_value = None + lpips_error = "LPIPS not computed" + flip_error = "FLIP not computed" + + return MetricResults( + psnr=compute_psnr(image_a=image_a, image_b=image_b), + ssim=compute_ssim(image_a=image_a, image_b=image_b), + lpips=lpips_value, + flip=flip_value, + lpips_error=lpips_error, + flip_error=flip_error, + ) + + +def format_metric_value(value: Optional[float], precision: int = 5) -> str: + if value is None: + return "unavailable" + if math.isinf(value): + return "inf" + if math.isnan(value): + return "nan" + return f"{value:.{precision}f}" + + +def apply_jet_colormap(value: np.ndarray) -> np.ndarray: + value = np.clip(value, 0.0, 1.0) + red = np.clip(1.5 - np.abs(4.0 * value - 3.0), 0.0, 1.0) + green = np.clip(1.5 - np.abs(4.0 * value - 2.0), 0.0, 1.0) + blue = np.clip(1.5 - np.abs(4.0 * value - 1.0), 0.0, 1.0) + return np.stack((red, green, blue), axis=-1) + + +def float_image_to_uint8(image: np.ndarray) -> np.ndarray: + return (np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8) + + +def resize_uint8_image(image: np.ndarray, size: Tuple[int, int]) -> np.ndarray: + width, height = size + resample_filter = getattr(Image, "Resampling", Image).LANCZOS + resized = Image.fromarray(image).resize((max(1, width), max(1, height)), resample_filter) + return np.asarray(resized, dtype=np.uint8) + + +def get_image_canvas_rect( + image_size: Tuple[int, int], + canvas_size: Tuple[int, int], + display_mode: str, +) -> Tuple[float, float, float, float]: + image_width, image_height = image_size + canvas_width, canvas_height = canvas_size + canvas_width = max(1, canvas_width) + canvas_height = max(1, canvas_height) + + if display_mode == "fit": + return 0.0, 0.0, float(canvas_width), float(canvas_height) + + scale = min(canvas_width / image_width, canvas_height / image_height) + display_width = image_width * scale + display_height = image_height * scale + image_x0 = 0.5 * (canvas_width - display_width) + image_y0 = 0.5 * (canvas_height - display_height) + return image_x0, image_y0, display_width, display_height + + +def fit_image_to_canvas( + image: np.ndarray, + canvas_size: Tuple[int, int], + display_mode: str, +) -> np.ndarray: + canvas_width, canvas_height = canvas_size + canvas_width = max(1, canvas_width) + canvas_height = max(1, canvas_height) + image_height, image_width = image.shape[:2] + canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8) + + if display_mode == "fit": + return resize_uint8_image(image, (canvas_width, canvas_height)) + + _, _, display_width, display_height = get_image_canvas_rect( + image_size=(image_width, image_height), + canvas_size=(canvas_width, canvas_height), + display_mode=display_mode, + ) + resized_width = max(1, int(round(display_width))) + resized_height = max(1, int(round(display_height))) + resized = resize_uint8_image(image, (resized_width, resized_height)) + + source_x0 = max(0, (resized_width - canvas_width) // 2) + source_y0 = max(0, (resized_height - canvas_height) // 2) + target_x0 = max(0, (canvas_width - resized_width) // 2) + target_y0 = max(0, (canvas_height - resized_height) // 2) + copy_width = min(resized_width - source_x0, canvas_width - target_x0) + copy_height = min(resized_height - source_y0, canvas_height - target_y0) + + canvas[target_y0 : target_y0 + copy_height, target_x0 : target_x0 + copy_width] = resized[ + source_y0 : source_y0 + copy_height, + source_x0 : source_x0 + copy_width, + ] + return canvas + + +class ImageComparisonViewer: + def __init__(self, image_pairs: List[ImagePair], host: str, port: int, target_fps: float) -> None: + self.image_pairs = image_pairs + self.host = host + self.port = port + self.target_fps = target_fps + self.viser = import_viser() + self.server = self.viser.ViserServer(host=self.host, port=self.port) + self.need_update = True + self.image_cache: Dict[str, Tuple[np.ndarray, np.ndarray, str]] = {} + self.metric_cache: Dict[str, MetricResults] = {} + self.error_map_cache: Dict[Tuple[str, str], np.ndarray] = {} + + self.image_pair_dropdown = None + self.display_mode_dropdown = None + self.mode_dropdown = None + self.slider_direction_dropdown = None + self.slider_position_slider = None + self.checker_size_slider = None + self.diff_metric_dropdown = None + self.diff_scale_slider = None + self.psnr_text = None + self.ssim_text = None + self.lpips_text = None + self.flip_text = None + self.compute_perceptual_metrics_button = None + self.status_text = None + + self.init_ui() + + @self.server.on_client_connect + def _(client) -> None: + self.need_update = True + + def init_ui(self) -> None: + with self.server.gui.add_folder("Image Comparison"): + image_pair_names = [image_pair.name for image_pair in self.image_pairs] + self.image_pair_dropdown = self.server.gui.add_dropdown( + "Image Pair", + options=image_pair_names, + initial_value=image_pair_names[0], + ) + previous_image_button = self.server.gui.add_button("Previous Image") + next_image_button = self.server.gui.add_button("Next Image") + self.display_mode_dropdown = self.server.gui.add_dropdown( + "Display Mode", + options=DISPLAY_MODES, + initial_value=DISPLAY_MODES[0], + ) + self.mode_dropdown = self.server.gui.add_dropdown( + "Mode", + options=COMPARISON_MODES, + initial_value=COMPARISON_MODES[0], + ) + self.slider_direction_dropdown = self.server.gui.add_dropdown( + "Slider Direction", + options=SLIDER_DIRECTIONS, + initial_value=SLIDER_DIRECTIONS[0], + ) + self.slider_position_slider = self.server.gui.add_slider( + "Slider Position", + min=0.0, + max=1.0, + step=0.01, + initial_value=0.5, + ) + self.checker_size_slider = self.server.gui.add_slider( + "Checker Size", + min=4, + max=256, + step=1, + initial_value=32, + ) + self.diff_metric_dropdown = self.server.gui.add_dropdown( + "Diff Metric", + options=DIFF_METRICS, + initial_value=DIFF_METRICS[0], + ) + self.diff_scale_slider = self.server.gui.add_slider( + "Diff Scale", + min=0.1, + max=20.0, + step=0.1, + initial_value=4.0, + ) + reload_button = self.server.gui.add_button("Reload Images") + self.status_text = self.server.gui.add_text("Status", initial_value="Loading", disabled=True) + + with self.server.gui.add_folder("Metrics"): + self.psnr_text = self.server.gui.add_text("PSNR", initial_value="unavailable", disabled=True) + self.ssim_text = self.server.gui.add_text("SSIM", initial_value="unavailable", disabled=True) + self.lpips_text = self.server.gui.add_text("LPIPS", initial_value="unavailable", disabled=True) + self.flip_text = self.server.gui.add_text("FLIP", initial_value="unavailable", disabled=True) + self.compute_perceptual_metrics_button = self.server.gui.add_button("Compute LPIPS / FLIP") + + controls = [ + self.image_pair_dropdown, + self.display_mode_dropdown, + self.mode_dropdown, + self.slider_direction_dropdown, + self.slider_position_slider, + self.checker_size_slider, + self.diff_metric_dropdown, + self.diff_scale_slider, + ] + + for control in controls: + + @control.on_update + def _(_) -> None: + self.need_update = True + + @reload_button.on_click + def _(_) -> None: + self.image_cache.clear() + self.metric_cache.clear() + self.error_map_cache.clear() + self.need_update = True + + @previous_image_button.on_click + def _(_) -> None: + self.select_relative_image_pair(offset=-1) + + @next_image_button.on_click + def _(_) -> None: + self.select_relative_image_pair(offset=1) + + @self.compute_perceptual_metrics_button.on_click + def _(_) -> None: + self.compute_selected_perceptual_metrics() + self.need_update = True + + def select_relative_image_pair(self, offset: int) -> None: + selected_name = self.image_pair_dropdown.value + image_pair_names = [image_pair.name for image_pair in self.image_pairs] + try: + selected_index = image_pair_names.index(selected_name) + except ValueError: + selected_index = 0 + + next_index = (selected_index + offset) % len(self.image_pairs) + self.image_pair_dropdown.value = image_pair_names[next_index] + self.need_update = True + + def get_selected_pair(self) -> ImagePair: + selected_name = self.image_pair_dropdown.value + for image_pair in self.image_pairs: + if image_pair.name == selected_name: + return image_pair + + return self.image_pairs[0] + + def get_aligned_pair(self, image_pair: ImagePair) -> Tuple[np.ndarray, np.ndarray, str]: + if image_pair.name not in self.image_cache: + self.image_cache[image_pair.name] = load_aligned_pair(image_pair) + return self.image_cache[image_pair.name] + + def get_metric_results(self, image_pair: ImagePair, image_a: np.ndarray, image_b: np.ndarray) -> MetricResults: + if image_pair.name not in self.metric_cache: + self.metric_cache[image_pair.name] = compute_metric_results(image_a=image_a, image_b=image_b) + return self.metric_cache[image_pair.name] + + def compute_selected_perceptual_metrics(self) -> None: + image_pair = self.get_selected_pair() + image_a, image_b, _ = self.get_aligned_pair(image_pair) + self.metric_cache[image_pair.name] = compute_metric_results( + image_a=image_a, + image_b=image_b, + include_perceptual_metrics=True, + ) + self.update_metric_widgets(metric_results=self.metric_cache[image_pair.name]) + + def update_metric_widgets(self, metric_results: MetricResults) -> None: + self.psnr_text.value = format_metric_value(metric_results.psnr, precision=4) + self.ssim_text.value = format_metric_value(metric_results.ssim, precision=5) + self.lpips_text.value = format_metric_value(metric_results.lpips, precision=5) + self.flip_text.value = format_metric_value(metric_results.flip, precision=5) + + def render_current_diff( + self, + image_pair: ImagePair, + image_a: np.ndarray, + image_b: np.ndarray, + ) -> np.ndarray: + diff_metric = self.diff_metric_dropdown.value + cache_key = (image_pair.name, diff_metric) + if cache_key not in self.error_map_cache: + self.error_map_cache[cache_key] = compute_error_map( + image_a=image_a, + image_b=image_b, + diff_metric=diff_metric, + ) + + scaled_error = np.clip(self.error_map_cache[cache_key] * float(self.diff_scale_slider.value), 0.0, 1.0) + return apply_jet_colormap(scaled_error) + + def render_current_comparison(self) -> np.ndarray: + image_pair = self.get_selected_pair() + image_a, image_b, status = self.get_aligned_pair(image_pair) + metric_results = self.get_metric_results(image_pair=image_pair, image_a=image_a, image_b=image_b) + self.update_metric_widgets(metric_results=metric_results) + mode = self.mode_dropdown.value + + if mode == "checkerboard": + output = render_checkerboard_comparison( + image_a=image_a, + image_b=image_b, + checker_size=int(self.checker_size_slider.value), + ) + elif mode == "diff": + output = self.render_current_diff(image_pair=image_pair, image_a=image_a, image_b=image_b) + else: + output = render_slider_comparison( + image_a=image_a, + image_b=image_b, + slider_position=float(self.slider_position_slider.value), + slider_direction=self.slider_direction_dropdown.value, + ) + + if mode == "slider": + self.status_text.value = f"{status} | mode: {mode}" + elif mode == "diff": + diff_metric = self.diff_metric_dropdown.value + warning = self.get_metric_warning(metric_results=metric_results, diff_metric=diff_metric) + self.status_text.value = f"{status} | mode: {mode} | diff: {diff_metric}{warning}" + else: + self.status_text.value = f"{status} | mode: {mode}" + return float_image_to_uint8(output) + + def get_metric_warning(self, metric_results: MetricResults, diff_metric: str) -> str: + if diff_metric == "lpips" and metric_results.lpips_error is not None: + return f" | {metric_results.lpips_error}" + if diff_metric == "flip" and metric_results.flip_error is not None: + return f" | {metric_results.flip_error}" + return "" + + def display_output(self, output: np.ndarray) -> None: + display_mode = self.display_mode_dropdown.value + for client in self.server.get_clients().values(): + canvas_width = int(client.camera.image_width or output.shape[1]) + canvas_height = int(client.camera.image_height or output.shape[0]) + display_image = fit_image_to_canvas( + image=output, + canvas_size=(canvas_width, canvas_height), + display_mode=display_mode, + ) + client.scene.set_background_image(display_image, format="jpeg") + + def update(self) -> None: + if not self.need_update: + return + + output = self.render_current_comparison() + self.display_output(output) + + self.need_update = False + + def run(self) -> None: + print_server_urls(host=self.host, port=self.port) + while True: + self.update() + time.sleep(max(0.001, 1.0 / self.target_fps)) + + +def get_candidate_host_addresses() -> List[str]: + addresses = ["127.0.0.1"] + try: + hostname = socket.gethostname() + for address_info in socket.getaddrinfo(hostname, None, family=socket.AF_INET): + address = address_info[4][0] + if address not in addresses and not address.startswith("127."): + addresses.append(address) + except OSError: + pass + return addresses + + +def print_server_urls(host: str, port: int) -> None: + if host in ("0.0.0.0", "::"): + print("Viser is listening on all interfaces. Try these URLs:") + for address in get_candidate_host_addresses(): + print(f" http://{address}:{port}") + else: + print(f"Viser URL: http://{host}:{port}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Viser based image comparison viewer.") + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--images", + nargs=2, + metavar=("IMAGE_A", "IMAGE_B"), + help="Compare two specific images.", + ) + input_group.add_argument( + "--folders", + nargs=2, + metavar=("FOLDER_A", "FOLDER_B"), + help="Compare matching image names from two folders.", + ) + + parser.add_argument("--host", type=str, default="0.0.0.0", help="Viser server host/interface.") + parser.add_argument("--port", type=int, default=8080, help="Viser server port.") + parser.add_argument("--target_fps", type=float, default=20.0, help="Maximum UI refresh rate.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + if args.images is not None: + image_pairs = build_specific_image_pair(Path(args.images[0]), Path(args.images[1])) + else: + image_pairs = build_folder_image_pairs(Path(args.folders[0]), Path(args.folders[1])) + + viewer = ImageComparisonViewer( + image_pairs=image_pairs, + host=args.host, + port=args.port, + target_fps=args.target_fps, + ) + viewer.run() + + +if __name__ == "__main__": + main() From 2b6414825513287ad26ce9d1304cad617deb65e4 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Wed, 29 Apr 2026 11:44:52 -0400 Subject: [PATCH 16/42] fix(validation): avoid NaNs in linear to sRGB Clamp the power branch away from zero so validation image conversion remains finite. --- threedgrut/utils/post_processing_linear_to_srgb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/threedgrut/utils/post_processing_linear_to_srgb.py b/threedgrut/utils/post_processing_linear_to_srgb.py index f93e160e..93c2000c 100644 --- a/threedgrut/utils/post_processing_linear_to_srgb.py +++ b/threedgrut/utils/post_processing_linear_to_srgb.py @@ -58,7 +58,7 @@ def linear_to_srgb(x: torch.Tensor) -> torch.Tensor: Encoded values, same shape / dtype / device as ``x``. """ limit = 0.0031308 - positive_x = torch.clamp(x, min=0.0) + positive_x = torch.clamp(x, min=1e-08) return torch.where( x > limit, 1.055 * torch.pow(positive_x, 1.0 / 2.4) - 0.055, From ab1108ff7e85a644f841659f021a376c51ac1eb3 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Thu, 30 Apr 2026 08:40:41 -0400 Subject: [PATCH 17/42] fix(tools): improve image comparison metrics Display metrics as readable text, add folder-level aggregates, and compute FLIP by default while removing LPIPS from the viewer. --- tools/image_comparison/README.md | 5 +- tools/image_comparison/image_comparison.py | 151 ++++++++++----------- 2 files changed, 76 insertions(+), 80 deletions(-) diff --git a/tools/image_comparison/README.md b/tools/image_comparison/README.md index 65ef590f..f9e84293 100644 --- a/tools/image_comparison/README.md +++ b/tools/image_comparison/README.md @@ -47,7 +47,7 @@ Use `Previous Image` and `Next Image` to cycle through matched image pairs. ## Metrics -The `Metrics` panel displays scalar `PSNR`, `SSIM`, `LPIPS`, and `FLIP` values for the selected image pair. `PSNR` and `SSIM` are computed automatically. Use `Compute LPIPS / FLIP` to compute the heavier perceptual metrics on demand. +The `Metrics` panel displays readable text blocks for the current image pair and the global folder mean. `PSNR`, `SSIM`, and `FLIP` are computed automatically. The `Diff Metric` dropdown supports: @@ -55,7 +55,6 @@ The `Diff Metric` dropdown supports: - `l2`: per-pixel RGB root mean squared difference. - `psnr`: per-pixel PSNR-derived error, where lower PSNR is brighter. - `ssim`: local `1 - SSIM` dissimilarity. -- `lpips`: scalar LPIPS displayed as a uniform heatmap. - `flip`: FLIP error map when `flip-evaluator` is installed. -`LPIPS` depends on `torchmetrics` and its model weights. `FLIP` depends on the `flip-evaluator` package, which provides the `flip_evaluator` Python module. Do not install the unrelated `flip` package. +`FLIP` depends on the `flip-evaluator` package, which provides the `flip_evaluator` Python module. Do not install the unrelated `flip` package. diff --git a/tools/image_comparison/image_comparison.py b/tools/image_comparison/image_comparison.py index dc853cbb..b5158254 100644 --- a/tools/image_comparison/image_comparison.py +++ b/tools/image_comparison/image_comparison.py @@ -37,7 +37,7 @@ COMPARISON_MODES = ["slider", "checkerboard", "diff"] SLIDER_DIRECTIONS = ["vertical", "horizontal"] -DIFF_METRICS = ["l1", "l2", "psnr", "ssim", "lpips", "flip"] +DIFF_METRICS = ["l1", "l2", "psnr", "ssim", "flip"] DISPLAY_MODES = ["fit_largest_dimension", "fit"] @@ -45,9 +45,7 @@ class MetricResults: psnr: Optional[float] ssim: Optional[float] - lpips: Optional[float] flip: Optional[float] - lpips_error: Optional[str] = None flip_error: Optional[str] = None @@ -209,9 +207,6 @@ def compute_error_map(image_a: np.ndarray, image_b: np.ndarray, diff_metric: str return compute_psnr_error_map(image_a=image_a, image_b=image_b) if diff_metric == "ssim": return 1.0 - compute_ssim_map(image_a=image_a, image_b=image_b) - if diff_metric == "lpips": - lpips_value, _ = compute_lpips_metric(image_a=image_a, image_b=image_b) - return scalar_metric_to_map(lpips_value, image_a.shape[:2], higher_is_worse=True) if diff_metric == "flip": flip_map, flip_value, _ = compute_flip_metric(image_a=image_a, image_b=image_b) if flip_map is not None: @@ -296,29 +291,6 @@ def normalize_error_map(error_map: np.ndarray) -> np.ndarray: return np.clip(error_map.astype(np.float32), 0.0, 1.0) -def image_to_torch_tensor(image: np.ndarray): - import torch - - return torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).unsqueeze(0).float() - - -def compute_lpips_metric(image_a: np.ndarray, image_b: np.ndarray) -> Tuple[Optional[float], Optional[str]]: - try: - import torch - from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity - except ImportError as exc: - return None, f"LPIPS unavailable: {exc}" - - try: - device = "cuda" if torch.cuda.is_available() else "cpu" - metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to(device) - with torch.no_grad(): - value = metric(image_to_torch_tensor(image_a).to(device), image_to_torch_tensor(image_b).to(device)) - return float(value.detach().cpu().item()), None - except Exception as exc: - return None, f"LPIPS failed: {exc}" - - def compute_flip_metric( image_a: np.ndarray, image_b: np.ndarray ) -> Tuple[Optional[np.ndarray], Optional[float], Optional[str]]: @@ -349,30 +321,36 @@ def compute_flip_metric( return None, None, f"FLIP failed: {exc}" -def compute_metric_results( - image_a: np.ndarray, - image_b: np.ndarray, - include_perceptual_metrics: bool = False, -) -> MetricResults: - if include_perceptual_metrics: - lpips_value, lpips_error = compute_lpips_metric(image_a=image_a, image_b=image_b) - _, flip_value, flip_error = compute_flip_metric(image_a=image_a, image_b=image_b) - else: - lpips_value = None - flip_value = None - lpips_error = "LPIPS not computed" - flip_error = "FLIP not computed" - +def compute_metric_results(image_a: np.ndarray, image_b: np.ndarray) -> MetricResults: + _, flip_value, flip_error = compute_flip_metric(image_a=image_a, image_b=image_b) return MetricResults( psnr=compute_psnr(image_a=image_a, image_b=image_b), ssim=compute_ssim(image_a=image_a, image_b=image_b), - lpips=lpips_value, flip=flip_value, - lpips_error=lpips_error, flip_error=flip_error, ) +def mean_metric_value(values: List[Optional[float]]) -> Optional[float]: + finite_values = [value for value in values if value is not None and np.isfinite(value)] + if finite_values: + return float(np.mean(finite_values)) + + if any(value is not None and math.isinf(value) for value in values): + return math.inf + + return None + + +def aggregate_metric_results(metric_results: List[MetricResults]) -> MetricResults: + return MetricResults( + psnr=mean_metric_value([metrics.psnr for metrics in metric_results]), + ssim=mean_metric_value([metrics.ssim for metrics in metric_results]), + flip=mean_metric_value([metrics.flip for metrics in metric_results]), + flip_error=None if any(metrics.flip is not None for metrics in metric_results) else "FLIP not computed", + ) + + def format_metric_value(value: Optional[float], precision: int = 5) -> str: if value is None: return "unavailable" @@ -383,6 +361,18 @@ def format_metric_value(value: Optional[float], precision: int = 5) -> str: return f"{value:.{precision}f}" +def format_metric_markdown(title: str, metric_results: MetricResults, count: Optional[int] = None) -> str: + count_text = "" if count is None else f" ({count} images)" + return "\n".join( + [ + f"### {title}{count_text}", + f"- PSNR: **{format_metric_value(metric_results.psnr, precision=4)}**", + f"- SSIM: **{format_metric_value(metric_results.ssim, precision=5)}**", + f"- FLIP: **{format_metric_value(metric_results.flip, precision=5)}**", + ] + ) + + def apply_jet_colormap(value: np.ndarray) -> np.ndarray: value = np.clip(value, 0.0, 1.0) red = np.clip(1.5 - np.abs(4.0 * value - 3.0), 0.0, 1.0) @@ -471,6 +461,7 @@ def __init__(self, image_pairs: List[ImagePair], host: str, port: int, target_fp self.need_update = True self.image_cache: Dict[str, Tuple[np.ndarray, np.ndarray, str]] = {} self.metric_cache: Dict[str, MetricResults] = {} + self.global_metric_cache: Dict[bool, MetricResults] = {} self.error_map_cache: Dict[Tuple[str, str], np.ndarray] = {} self.image_pair_dropdown = None @@ -481,11 +472,8 @@ def __init__(self, image_pairs: List[ImagePair], host: str, port: int, target_fp self.checker_size_slider = None self.diff_metric_dropdown = None self.diff_scale_slider = None - self.psnr_text = None - self.ssim_text = None - self.lpips_text = None - self.flip_text = None - self.compute_perceptual_metrics_button = None + self.current_metrics_markdown = None + self.global_metrics_markdown = None self.status_text = None self.init_ui() @@ -549,11 +537,13 @@ def init_ui(self) -> None: self.status_text = self.server.gui.add_text("Status", initial_value="Loading", disabled=True) with self.server.gui.add_folder("Metrics"): - self.psnr_text = self.server.gui.add_text("PSNR", initial_value="unavailable", disabled=True) - self.ssim_text = self.server.gui.add_text("SSIM", initial_value="unavailable", disabled=True) - self.lpips_text = self.server.gui.add_text("LPIPS", initial_value="unavailable", disabled=True) - self.flip_text = self.server.gui.add_text("FLIP", initial_value="unavailable", disabled=True) - self.compute_perceptual_metrics_button = self.server.gui.add_button("Compute LPIPS / FLIP") + empty_metrics = MetricResults(psnr=None, ssim=None, flip=None) + self.current_metrics_markdown = self.server.gui.add_markdown( + format_metric_markdown("Current Image", empty_metrics) + ) + self.global_metrics_markdown = self.server.gui.add_markdown( + format_metric_markdown("Folder Mean", empty_metrics, count=len(self.image_pairs)) + ) controls = [ self.image_pair_dropdown, @@ -576,6 +566,7 @@ def _(_) -> None: def _(_) -> None: self.image_cache.clear() self.metric_cache.clear() + self.global_metric_cache.clear() self.error_map_cache.clear() self.need_update = True @@ -587,11 +578,6 @@ def _(_) -> None: def _(_) -> None: self.select_relative_image_pair(offset=1) - @self.compute_perceptual_metrics_button.on_click - def _(_) -> None: - self.compute_selected_perceptual_metrics() - self.need_update = True - def select_relative_image_pair(self, offset: int) -> None: selected_name = self.image_pair_dropdown.value image_pair_names = [image_pair.name for image_pair in self.image_pairs] @@ -622,21 +608,30 @@ def get_metric_results(self, image_pair: ImagePair, image_a: np.ndarray, image_b self.metric_cache[image_pair.name] = compute_metric_results(image_a=image_a, image_b=image_b) return self.metric_cache[image_pair.name] - def compute_selected_perceptual_metrics(self) -> None: - image_pair = self.get_selected_pair() - image_a, image_b, _ = self.get_aligned_pair(image_pair) - self.metric_cache[image_pair.name] = compute_metric_results( - image_a=image_a, - image_b=image_b, - include_perceptual_metrics=True, - ) - self.update_metric_widgets(metric_results=self.metric_cache[image_pair.name]) + def get_global_metric_results(self) -> MetricResults: + if True in self.global_metric_cache: + return self.global_metric_cache[True] + + global_metric_results = [] + for image_pair in self.image_pairs: + image_a, image_b, _ = self.get_aligned_pair(image_pair) + metric_results = self.get_metric_results(image_pair=image_pair, image_a=image_a, image_b=image_b) + global_metric_results.append(metric_results) - def update_metric_widgets(self, metric_results: MetricResults) -> None: - self.psnr_text.value = format_metric_value(metric_results.psnr, precision=4) - self.ssim_text.value = format_metric_value(metric_results.ssim, precision=5) - self.lpips_text.value = format_metric_value(metric_results.lpips, precision=5) - self.flip_text.value = format_metric_value(metric_results.flip, precision=5) + self.global_metric_cache[True] = aggregate_metric_results(global_metric_results) + return self.global_metric_cache[True] + + def update_metric_widgets( + self, + current_metric_results: MetricResults, + global_metric_results: MetricResults, + ) -> None: + self.current_metrics_markdown.content = format_metric_markdown("Current Image", current_metric_results) + self.global_metrics_markdown.content = format_metric_markdown( + "Folder Mean", + global_metric_results, + count=len(self.image_pairs), + ) def render_current_diff( self, @@ -660,7 +655,11 @@ def render_current_comparison(self) -> np.ndarray: image_pair = self.get_selected_pair() image_a, image_b, status = self.get_aligned_pair(image_pair) metric_results = self.get_metric_results(image_pair=image_pair, image_a=image_a, image_b=image_b) - self.update_metric_widgets(metric_results=metric_results) + global_metric_results = self.get_global_metric_results() + self.update_metric_widgets( + current_metric_results=metric_results, + global_metric_results=global_metric_results, + ) mode = self.mode_dropdown.value if mode == "checkerboard": @@ -690,8 +689,6 @@ def render_current_comparison(self) -> np.ndarray: return float_image_to_uint8(output) def get_metric_warning(self, metric_results: MetricResults, diff_metric: str) -> str: - if diff_metric == "lpips" and metric_results.lpips_error is not None: - return f" | {metric_results.lpips_error}" if diff_metric == "flip" and metric_results.flip_error is not None: return f" | {metric_results.flip_error}" return "" From 7e41b455258f2d7789c4340c41576a70d94fa9c6 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Thu, 30 Apr 2026 08:42:10 -0400 Subject: [PATCH 18/42] feat(validation): add simple SH bake comparison Add simple fixed-frame PPISP baking variants to the validation workflow and keep the fitted bake path aligned with the current L2 objective. --- threedgrut/export/scripts/export_usd.py | 5 +- .../post_processing_sh_bake_validation.py | 292 +++++++++++------- threedgrut/export/usd/exporter.py | 4 +- .../export/usd/post_processing_sh_bake.py | 7 +- .../usd/post_processing_sh_simple_bake.py | 151 +++++++++ 5 files changed, 329 insertions(+), 130 deletions(-) create mode 100644 threedgrut/export/usd/post_processing_sh_simple_bake.py diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 1ce3696e..95f4a197 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -345,10 +345,7 @@ def main(): ) # Load dataset for camera export and for train-split post-processing SH baking. dataset = None - needs_dataset = ( - not args.no_cameras - or (post_processing is not None and export_post_processing) - ) + needs_dataset = not args.no_cameras or (post_processing is not None and export_post_processing) if needs_dataset: try: import threedgrut.datasets as datasets diff --git a/threedgrut/export/scripts/post_processing_sh_bake_validation.py b/threedgrut/export/scripts/post_processing_sh_bake_validation.py index 39615c19..69cedff5 100644 --- a/threedgrut/export/scripts/post_processing_sh_bake_validation.py +++ b/threedgrut/export/scripts/post_processing_sh_bake_validation.py @@ -27,7 +27,7 @@ import json import sys from pathlib import Path -from typing import Iterable +from typing import Dict, Iterable sys.path.insert(0, str(Path(__file__).resolve().parents[3])) @@ -48,11 +48,16 @@ apply_achromatic_vignetting, normalize_ppisp_bake_vignetting_mode, ) -from threedgrut.utils.color_correct import color_correct_affine +from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake from threedgrut.utils.logger import logger from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb from threedgrut.utils.render import apply_post_processing +BAKE_FLAVOR_FIT = "fit" +BAKE_FLAVOR_SIMPLE = "simple" +BAKE_FLAVOR_SIMPLE_HIGHER_ORDER = "simple-higher-order" +BAKE_FLAVOR_ALL = "all" + def _setShFitParameters(model) -> Iterable[torch.nn.Parameter]: for parameter in model.parameters(): @@ -128,8 +133,14 @@ def _fitBakedSh( optimizer.zero_grad(set_to_none=True) bakedOutputs = bakedModel(gpuBatch) - fittedRgb = torch.clamp(linear_to_srgb(_applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode)), 0, 1) - loss = torch.nn.functional.l1_loss(fittedRgb, referenceRgb) + fittedRgb = torch.clamp( + linear_to_srgb( + _applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode) + ), + 0, + 1, + ) + loss = torch.nn.functional.mse_loss(fittedRgb, referenceRgb) loss.backward() optimizer.step() @@ -147,6 +158,7 @@ def _fitBakedSh( def _evaluateBakedSh( referenceModel, bakedModel, + simpleBakedModels: Dict[str, nn.Module], fixedPpisp, fullFixedPpisp, dataset, @@ -165,28 +177,35 @@ def _evaluateBakedSh( fullReferencePath = outputRoot / "full_ppisp_reference" referencePath = outputRoot / "reference" unfittedPath = outputRoot / "unfitted" - bakedPath = outputRoot / "baked" - assistedPath = outputRoot / "baked_assisted" fullReferencePath.mkdir(parents=True, exist_ok=True) referencePath.mkdir(parents=True, exist_ok=True) unfittedPath.mkdir(parents=True, exist_ok=True) - bakedPath.mkdir(parents=True, exist_ok=True) - assistedPath.mkdir(parents=True, exist_ok=True) + bakedPath = outputRoot / "baked" if bakedModel is not None else None + assistedPath = outputRoot / "baked_assisted" if bakedModel is not None else None + if bakedPath is not None: + bakedPath.mkdir(parents=True, exist_ok=True) + if assistedPath is not None: + assistedPath.mkdir(parents=True, exist_ok=True) + simplePaths = {name: outputRoot / f"{name}_baked" for name in simpleBakedModels} + for simplePath in simplePaths.values(): + simplePath.mkdir(parents=True, exist_ok=True) unfittedPsnrValues = [] psnrValues = [] ssimValues = [] lpipsValues = [] - ccPsnrValues = [] - ccSsimValues = [] - ccLpipsValues = [] assistedPsnrValues = [] assistedSsimValues = [] assistedLpipsValues = [] - assistedCcPsnrValues = [] - assistedCcSsimValues = [] - assistedCcLpipsValues = [] inferenceTimeValues = [] + simpleMetricValues = { + name: { + "psnr": [], + "ssim": [], + "lpips": [], + } + for name in simpleBakedModels + } logger.start_progress(task_name="Evaluating baked SH", total_steps=len(dataloader), color="orange1") for iteration, batch in enumerate(dataloader): @@ -196,17 +215,6 @@ def _evaluateBakedSh( referenceRgb = _renderReference(referenceModel, fixedPpisp, gpuBatch) unfittedOutputs = referenceModel(gpuBatch) unfittedRgb = unfittedOutputs["pred_rgb"] - bakedOutputs = bakedModel(gpuBatch) - bakedRgb = torch.clamp( - linear_to_srgb(bakedOutputs["pred_rgb"]), - 0, - 1, - ) - assistedRgb = torch.clamp( - linear_to_srgb(_applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode)), - 0, - 1, - ) torchvision.utils.save_image( fullReferenceRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), @@ -220,108 +228,111 @@ def _evaluateBakedSh( unfittedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), unfittedPath / f"{iteration:05d}.png", ) - torchvision.utils.save_image( - bakedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), - bakedPath / f"{iteration:05d}.png", - ) - torchvision.utils.save_image( - assistedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), - assistedPath / f"{iteration:05d}.png", - ) unfittedPsnrValues.append(criterions["psnr"](unfittedRgb, referenceRgb).item()) - psnrValues.append(criterions["psnr"](bakedRgb, referenceRgb).item()) - assistedPsnrValues.append(criterions["psnr"](assistedRgb, referenceRgb).item()) - if computeExtraMetrics: - ssimValues.append( - criterions["ssim"]( - bakedRgb.permute(0, 3, 1, 2), - referenceRgb.permute(0, 3, 1, 2), - ).item() - ) - lpipsValues.append( - criterions["lpips"]( - bakedRgb.clip(0, 1).permute(0, 3, 1, 2), - referenceRgb.clip(0, 1).permute(0, 3, 1, 2), - ).item() - ) - assistedSsimValues.append( - criterions["ssim"]( - assistedRgb.permute(0, 3, 1, 2), - referenceRgb.permute(0, 3, 1, 2), - ).item() - ) - assistedLpipsValues.append( - criterions["lpips"]( - assistedRgb.clip(0, 1).permute(0, 3, 1, 2), - referenceRgb.clip(0, 1).permute(0, 3, 1, 2), - ).item() - ) - bakedRgbCc = color_correct_affine(bakedRgb, referenceRgb) - ccPsnrValues.append(criterions["psnr"](bakedRgbCc, referenceRgb).item()) - ccSsimValues.append( - criterions["ssim"]( - bakedRgbCc.permute(0, 3, 1, 2), - referenceRgb.permute(0, 3, 1, 2), - ).item() - ) - ccLpipsValues.append( - criterions["lpips"]( - bakedRgbCc.clip(0, 1).permute(0, 3, 1, 2), - referenceRgb.clip(0, 1).permute(0, 3, 1, 2), - ).item() + if bakedModel is not None: + bakedOutputs = bakedModel(gpuBatch) + bakedRgb = torch.clamp(linear_to_srgb(bakedOutputs["pred_rgb"]), 0, 1) + assistedRgb = torch.clamp( + linear_to_srgb( + _applyAchromaticVignetting(bakedOutputs["pred_rgb"], fixedPpisp, gpuBatch, vignettingMode) + ), + 0, + 1, ) - assistedRgbCc = color_correct_affine(assistedRgb, referenceRgb) - assistedCcPsnrValues.append(criterions["psnr"](assistedRgbCc, referenceRgb).item()) - assistedCcSsimValues.append( - criterions["ssim"]( - assistedRgbCc.permute(0, 3, 1, 2), - referenceRgb.permute(0, 3, 1, 2), - ).item() + torchvision.utils.save_image( + bakedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + bakedPath / f"{iteration:05d}.png", ) - assistedCcLpipsValues.append( - criterions["lpips"]( - assistedRgbCc.clip(0, 1).permute(0, 3, 1, 2), - referenceRgb.clip(0, 1).permute(0, 3, 1, 2), - ).item() + torchvision.utils.save_image( + assistedRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + assistedPath / f"{iteration:05d}.png", ) - if "frame_time_ms" in bakedOutputs: - inferenceTimeValues.append(bakedOutputs["frame_time_ms"]) - - logger.log_progress(task_name="Evaluating baked SH", advance=1, iteration=str(iteration), psnr=psnrValues[-1]) + psnrValues.append(criterions["psnr"](bakedRgb, referenceRgb).item()) + assistedPsnrValues.append(criterions["psnr"](assistedRgb, referenceRgb).item()) + if computeExtraMetrics: + ssimValues.append( + criterions["ssim"](bakedRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + lpipsValues.append( + criterions["lpips"]( + bakedRgb.clip(0, 1).permute(0, 3, 1, 2), referenceRgb.clip(0, 1).permute(0, 3, 1, 2) + ).item() + ) + assistedSsimValues.append( + criterions["ssim"](assistedRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + assistedLpipsValues.append( + criterions["lpips"]( + assistedRgb.clip(0, 1).permute(0, 3, 1, 2), referenceRgb.clip(0, 1).permute(0, 3, 1, 2) + ).item() + ) + + if "frame_time_ms" in bakedOutputs: + inferenceTimeValues.append(bakedOutputs["frame_time_ms"]) + + for simpleName, simpleModel in simpleBakedModels.items(): + simpleOutputs = simpleModel(gpuBatch) + simpleRgb = torch.clamp(simpleOutputs["pred_rgb"], 0, 1) + torchvision.utils.save_image( + simpleRgb.squeeze(0).permute(2, 0, 1).clip(0, 1), + simplePaths[simpleName] / f"{iteration:05d}.png", + ) + simpleValues = simpleMetricValues[simpleName] + simpleValues["psnr"].append(criterions["psnr"](simpleRgb, referenceRgb).item()) + if computeExtraMetrics: + simpleValues["ssim"].append( + criterions["ssim"](simpleRgb.permute(0, 3, 1, 2), referenceRgb.permute(0, 3, 1, 2)).item() + ) + simpleValues["lpips"].append( + criterions["lpips"]( + simpleRgb.clip(0, 1).permute(0, 3, 1, 2), + referenceRgb.clip(0, 1).permute(0, 3, 1, 2), + ).item() + ) + + progressPsnr = psnrValues[-1] if psnrValues else unfittedPsnrValues[-1] + logger.log_progress(task_name="Evaluating baked SH", advance=1, iteration=str(iteration), psnr=progressPsnr) logger.end_progress(task_name="Evaluating baked SH") metrics = { "vignetting_mode": vignettingMode, "unfitted_mean_psnr": float(np.mean(unfittedPsnrValues)), "unfitted_std_psnr": float(np.std(unfittedPsnrValues)), - "mean_psnr": float(np.mean(psnrValues)), - "std_psnr": float(np.std(psnrValues)), - "assisted_mean_psnr": float(np.mean(assistedPsnrValues)), - "assisted_std_psnr": float(np.std(assistedPsnrValues)), } - if computeExtraMetrics: + if psnrValues: metrics |= { - "mean_ssim": float(np.mean(ssimValues)), - "mean_lpips": float(np.mean(lpipsValues)), - "mean_cc_psnr": float(np.mean(ccPsnrValues)), - "mean_cc_ssim": float(np.mean(ccSsimValues)), - "mean_cc_lpips": float(np.mean(ccLpipsValues)), - "assisted_mean_ssim": float(np.mean(assistedSsimValues)), - "assisted_mean_lpips": float(np.mean(assistedLpipsValues)), - "assisted_mean_cc_psnr": float(np.mean(assistedCcPsnrValues)), - "assisted_mean_cc_ssim": float(np.mean(assistedCcSsimValues)), - "assisted_mean_cc_lpips": float(np.mean(assistedCcLpipsValues)), + "mean_psnr": float(np.mean(psnrValues)), + "std_psnr": float(np.std(psnrValues)), + "assisted_mean_psnr": float(np.mean(assistedPsnrValues)), + "assisted_std_psnr": float(np.std(assistedPsnrValues)), } + if computeExtraMetrics: + if ssimValues: + metrics |= { + "mean_ssim": float(np.mean(ssimValues)), + "mean_lpips": float(np.mean(lpipsValues)), + "assisted_mean_ssim": float(np.mean(assistedSsimValues)), + "assisted_mean_lpips": float(np.mean(assistedLpipsValues)), + } + for simpleName, simpleValues in simpleMetricValues.items(): + metrics[f"{simpleName}_mean_psnr"] = float(np.mean(simpleValues["psnr"])) + metrics[f"{simpleName}_std_psnr"] = float(np.std(simpleValues["psnr"])) + if computeExtraMetrics: + metrics |= { + f"{simpleName}_mean_ssim": float(np.mean(simpleValues["ssim"])), + f"{simpleName}_mean_lpips": float(np.mean(simpleValues["lpips"])), + } if inferenceTimeValues: metrics["mean_inference_time"] = f"{np.mean(inferenceTimeValues):.2f} ms/frame" with open(outputRoot / "metrics.json", "w") as file: json.dump(metrics, file, indent=2) - logger.log_table("Post-Processing SH Bake Validation", record=metrics) + psnrMetrics = {key: value for key, value in metrics.items() if "psnr" in key} + logger.log_table("Post-Processing SH Bake Validation PSNR", record=psnrMetrics) return metrics @@ -354,6 +365,21 @@ def main() -> None: help="Number of sequential passes over the train/reference set.", ) parser.add_argument("--learning-rate", dest="learningRate", default=1.0e-3, type=float, help="SH fitting LR.") + parser.add_argument( + "--bake-flavor", + dest="bakeFlavor", + choices=[ + BAKE_FLAVOR_FIT, + BAKE_FLAVOR_SIMPLE, + BAKE_FLAVOR_SIMPLE_HIGHER_ORDER, + BAKE_FLAVOR_ALL, + ], + default=BAKE_FLAVOR_FIT, + help=( + "Bake flavor to evaluate. 'fit' optimizes SH; 'simple' one-shot bakes DC SH; " + "'simple-higher-order' also linearizes higher-order SH; 'all' compares every flavor." + ), + ) parser.add_argument( "--vignetting-mode", dest="vignettingMode", @@ -400,29 +426,59 @@ def main() -> None: ).eval() referenceModel = renderer.model.eval() - bakedModel = renderer.model.clone().eval() - bakedModel.build_acc() outputRoot = Path(renderer.out_dir) / f"post_processing_sh_bake_ci{args.cameraId}_fi{args.frameId}" outputRoot.mkdir(parents=True, exist_ok=True) trainDataset, trainDataloader = _createTrainDataloader(renderer.conf) - logger.info(f"Fitting SH coefficients to fixed PPISP camera={args.cameraId} frame={args.frameId}") - _fitBakedSh( - referenceModel=referenceModel, - bakedModel=bakedModel, - fixedPpisp=fixedPpisp, - dataset=trainDataset, - dataloader=trainDataloader, - fitEpochs=args.fitEpochs, - learningRate=args.learningRate, - vignettingMode=vignettingMode, - ) + runFit = args.bakeFlavor in (BAKE_FLAVOR_FIT, BAKE_FLAVOR_ALL) + simpleFlavorHigherOrderFlags = [] + if args.bakeFlavor in (BAKE_FLAVOR_SIMPLE, BAKE_FLAVOR_ALL): + simpleFlavorHigherOrderFlags.append(("simple", False)) + if args.bakeFlavor in (BAKE_FLAVOR_SIMPLE_HIGHER_ORDER, BAKE_FLAVOR_ALL): + simpleFlavorHigherOrderFlags.append(("simple_higher_order", True)) + + bakedModel = None + if runFit: + bakedModel = renderer.model.clone().eval() + bakedModel.build_acc() + logger.info(f"Fitting SH coefficients to fixed PPISP camera={args.cameraId} frame={args.frameId}") + _fitBakedSh( + referenceModel=referenceModel, + bakedModel=bakedModel, + fixedPpisp=fixedPpisp, + dataset=trainDataset, + dataloader=trainDataloader, + fitEpochs=args.fitEpochs, + learningRate=args.learningRate, + vignettingMode=vignettingMode, + ) + + simpleBakedModels = {} + for simpleName, higherOrder in simpleFlavorHigherOrderFlags: + simpleModel = renderer.model.clone().eval() + logger.info( + f"Simple-baking SH for camera_id={args.cameraId} " + f"frame_id={args.frameId} (fixed exposure/color; higher_order={higherOrder})" + ) + exposure, color = simple_bake( + model=simpleModel, + ppisp=renderer.post_processing, + camera_id=args.cameraId, + frame_id=args.frameId, + higher_order=higherOrder, + ) + simpleModel.build_acc() + simpleBakedModels[simpleName] = simpleModel + logger.info( + f"{simpleName} bake done. exposure={exposure:.6f}; " f"color={[float(value) for value in color.tolist()]}" + ) _evaluateBakedSh( referenceModel=referenceModel, bakedModel=bakedModel, + simpleBakedModels=simpleBakedModels, fixedPpisp=fixedPpisp, fullFixedPpisp=fullFixedPpisp, dataset=renderer.dataset, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 8fc1b14a..234603d0 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -687,9 +687,7 @@ def _export_ppisp( ppisp_config = getattr(post_processing, "config", None) controllers = getattr(post_processing, "controllers", None) has_controller = ( - bool(getattr(ppisp_config, "use_controller", False)) - and controllers is not None - and len(controllers) > 0 + bool(getattr(ppisp_config, "use_controller", False)) and controllers is not None and len(controllers) > 0 ) if has_controller: logger.warning( diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index 8d7c37e6..08ce2d76 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -137,7 +137,7 @@ def bake_post_processing_into_sh( fixed_post_processing, gpu_batch, ) - loss = torch.nn.functional.l1_loss(fitted_rgb, reference_rgb) + loss = torch.nn.functional.mse_loss(fitted_rgb, reference_rgb) loss.backward() optimizer.step() @@ -244,10 +244,7 @@ def estimate_achromatic_vignetting( delta = uv - center r2 = torch.sum(delta * delta, dim=-1) falloff = ( - 1.0 - + vig_params[channel, 2] * r2 - + vig_params[channel, 3] * r2 * r2 - + vig_params[channel, 4] * r2 * r2 * r2 + 1.0 + vig_params[channel, 2] * r2 + vig_params[channel, 3] * r2 * r2 + vig_params[channel, 4] * r2 * r2 * r2 ) channel_falloff.append(torch.clamp(falloff, 0.0, 1.0)) diff --git a/threedgrut/export/usd/post_processing_sh_simple_bake.py b/threedgrut/export/usd/post_processing_sh_simple_bake.py new file mode 100644 index 00000000..58ba76f2 --- /dev/null +++ b/threedgrut/export/usd/post_processing_sh_simple_bake.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""One-shot bake of a fixed PPISP transform into Gaussian SH coefficients.""" + +from __future__ import annotations + +from typing import Tuple + +import torch +from ppisp import PPISP, ppisp_apply + +from threedgrut.utils.render import RGB2SH, SH2RGB + + +def get_fixed_frame_params( + ppisp: PPISP, + frame_id: int, +) -> Tuple[float, torch.Tensor]: + """Return exposure offset and color params for one fixed PPISP frame.""" + num_frames = int(ppisp.exposure_params.shape[0]) + if frame_id < 0 or frame_id >= num_frames: + raise ValueError(f"frame_id must be in [0, {num_frames - 1}], got {frame_id}.") + exposure = float(ppisp.exposure_params[frame_id].item()) + color = ppisp.color_params[frame_id].detach() + return exposure, color + + +def _bake_dc_through_ppisp( + dc_rgb_linear: torch.Tensor, + ppisp: PPISP, + camera_id: int, + exposure: float, + color: torch.Tensor, +) -> torch.Tensor: + """Apply PPISP with no vignetting to each Gaussian DC RGB color.""" + device = dc_rgb_linear.device + dtype = dc_rgb_linear.dtype + num_gaussians = dc_rgb_linear.shape[0] + + exposure_params = torch.tensor([exposure], device=device, dtype=dtype) + color_params = color.to(device=device, dtype=dtype).unsqueeze(0) + vignetting_params = torch.zeros_like(ppisp.vignetting_params, device=device, dtype=dtype) + pixel_coords = torch.zeros(num_gaussians, 2, device=device, dtype=dtype) + + return ppisp_apply( + exposure_params=exposure_params, + vignetting_params=vignetting_params, + color_params=color_params, + crf_params=ppisp.crf_params, + rgb_in=dc_rgb_linear.contiguous(), + pixel_coords=pixel_coords, + resolution_w=1, + resolution_h=1, + camera_idx=camera_id, + frame_idx=0, + ) + + +def _bake_dc_with_jacobian_through_ppisp( + dc_rgb_linear: torch.Tensor, + ppisp: PPISP, + camera_id: int, + exposure: float, + color: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Run PPISP forward and extract per-Gaussian RGB Jacobians.""" + rgb_in = dc_rgb_linear.detach().clone().requires_grad_(True) + rgb_out = _bake_dc_through_ppisp( + dc_rgb_linear=rgb_in, + ppisp=ppisp, + camera_id=camera_id, + exposure=exposure, + color=color, + ) + + num_gaussians = rgb_in.shape[0] + jacobian = torch.empty(num_gaussians, 3, 3, device=rgb_in.device, dtype=rgb_in.dtype) + for channel in range(3): + grad_out = torch.zeros_like(rgb_out) + grad_out[:, channel] = 1.0 + (grads,) = torch.autograd.grad( + outputs=rgb_out, + inputs=rgb_in, + grad_outputs=grad_out, + retain_graph=(channel < 2), + ) + jacobian[:, channel, :] = grads + + return rgb_out.detach(), jacobian.detach() + + +def _apply_jacobian_to_specular(features_specular: torch.nn.Parameter, jacobian: torch.Tensor) -> None: + """In-place linearization of higher-order SH coefficients by ``jacobian``.""" + num_gaussians, total = features_specular.shape + if total % 3 != 0: + raise ValueError(f"features_specular last-dim ({total}) must be divisible by 3.") + num_sh_coeffs = total // 3 + specular_rgb = features_specular.view(num_gaussians, num_sh_coeffs, 3) + transformed = torch.einsum("nij,nkj->nki", jacobian, specular_rgb) + specular_rgb.copy_(transformed) + + +def simple_bake( + model, + ppisp: PPISP, + camera_id: int, + frame_id: int, + higher_order: bool = False, +) -> Tuple[float, torch.Tensor]: + """Mutate SH coefficients with one fixed PPISP camera/frame transform.""" + exposure, color = get_fixed_frame_params(ppisp, frame_id) + + if higher_order: + with torch.enable_grad(): + dc_rgb_linear = SH2RGB(model.features_albedo).detach() + dc_rgb_baked, jacobian = _bake_dc_with_jacobian_through_ppisp( + dc_rgb_linear=dc_rgb_linear, + ppisp=ppisp, + camera_id=camera_id, + exposure=exposure, + color=color, + ) + with torch.no_grad(): + model.features_albedo.copy_(RGB2SH(dc_rgb_baked)) + _apply_jacobian_to_specular(model.features_specular, jacobian) + else: + with torch.no_grad(): + dc_rgb_linear = SH2RGB(model.features_albedo.detach()) + dc_rgb_baked = _bake_dc_through_ppisp( + dc_rgb_linear=dc_rgb_linear, + ppisp=ppisp, + camera_id=camera_id, + exposure=exposure, + color=color, + ) + model.features_albedo.copy_(RGB2SH(dc_rgb_baked)) + + return exposure, color From 6e835afbb410eaca174b0ed3294f066e12d28602 Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 5 May 2026 15:05:53 -0400 Subject: [PATCH 19/42] feat(export): add ParticleField sorting hint option Expose and validate ParticleField sortingModeHint values so USD exports can target ray-hit sorting while preserving the existing cameraDistance default. --- configs/base_gs.yaml | 1 + pyproject.toml | 2 +- requirements.txt | 2 +- threedgrut/export/scripts/export_usd.py | 22 +++++++++++- threedgrut/export/scripts/transcode.py | 16 +++++++-- threedgrut/export/tests/test_export_import.py | 27 ++++++++++++++ threedgrut/export/usd/exporter.py | 12 ++++--- threedgrut/export/usd/particle_field_hints.py | 35 +++++++++++++++++++ threedgrut/export/usd/writers/background.py | 4 +-- threedgrut/export/usd/writers/base.py | 3 +- threedgrut/export/usd/writers/lightfield.py | 8 +++-- threedgrut/export/usd/writers/ppisp_writer.py | 21 +++++------ 12 files changed, 125 insertions(+), 28 deletions(-) create mode 100644 threedgrut/export/usd/particle_field_hints.py diff --git a/configs/base_gs.yaml b/configs/base_gs.yaml index 12fb96b4..18321f45 100644 --- a/configs/base_gs.yaml +++ b/configs/base_gs.yaml @@ -46,6 +46,7 @@ export_usd: half_precision: false export_cameras: true export_background: true + # zDepth | cameraDistance | rayHitDistance sorting_mode_hint: cameraDistance # If true, Gaussian prim ColorSpaceAPI uses lin_rec709_scene; else srgb_rec709_display linear_srgb: false diff --git a/pyproject.toml b/pyproject.toml index d5a2e87f..0cde926b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dependencies = [ # # slangtorch on amd64 only # "slangtorch==1.3.18; sys_platform == 'linux' and platform_machine == 'x86_64'", # usd-core only available for amd64 - "usd-core>=26.3; sys_platform == 'linux' and platform_machine == 'x86_64'", + "usd-core>=26.5; sys_platform == 'linux' and platform_machine == 'x86_64'", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 55c62999..d0c73377 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ libigl pygltflib # --find-links https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.2_cu118.html # kaolin==0.17.0 -usd-core>=26.3 +usd-core>=26.5 ppisp @ git+https://github.com/nv-tlabs/ppisp@v1.0.1 # NCore dataset support (https://github.com/NVIDIA/ncore) nvidia-ncore>=19.0.0 diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 95f4a197..e3ccf002 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -37,6 +37,10 @@ import torch from threedgrut.export import NuRecExporter, USDExporter +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + PARTICLE_FIELD_SORTING_MODE_HINTS, +) from threedgrut.utils.logger import logger @@ -124,6 +128,16 @@ def parse_args(): action="store_true", help="Set prim color space to lin_rec709_scene (linear). Default is srgb_rec709_display.", ) + parser.add_argument( + "--sorting-mode-hint", + type=str, + choices=PARTICLE_FIELD_SORTING_MODE_HINTS, + default=None, + help=( + "ParticleField sortingModeHint for standard USD export. " + "Use rayHitDistance for ray-tracing renderers that support ray-hit sorting." + ), + ) post_processing_group = parser.add_mutually_exclusive_group() post_processing_group.add_argument( "--export-post-processing", @@ -385,7 +399,13 @@ def main(): export_cameras=not args.no_cameras, export_background=not args.no_background, apply_normalizing_transform=not args.no_transform, - sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), + sorting_mode_hint=_arg_or_conf( + args.sorting_mode_hint, + export_conf, + "sorting-mode-hint", + "sorting_mode_hint", + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + ), linear_srgb=args.linear_srgb or getattr(export_conf, "linear_srgb", False), export_post_processing=export_post_processing, post_processing_export_mode=post_processing_export_mode, diff --git a/threedgrut/export/scripts/transcode.py b/threedgrut/export/scripts/transcode.py index c7e6bf54..c0b34a2d 100644 --- a/threedgrut/export/scripts/transcode.py +++ b/threedgrut/export/scripts/transcode.py @@ -54,6 +54,10 @@ from threedgrut.export.usd.camera_copy import usd_stage_path_context_for_camera_copy from threedgrut.export.usd.exporter import USDExporter from threedgrut.export.usd.nurec.exporter import NuRecExporter +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + PARTICLE_FIELD_SORTING_MODE_HINTS, +) logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -174,7 +178,11 @@ def get_exporter( export_cameras=False, export_background=False, apply_normalizing_transform=False, - sorting_mode_hint=render_order_hint if render_order_hint is not None else "cameraDistance", + sorting_mode_hint=( + render_order_hint + if render_order_hint is not None + else DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT + ), linear_srgb=linear_srgb, ), False, @@ -364,9 +372,13 @@ def parse_args(): parser.add_argument( "--render-order-hint", type=str, + choices=PARTICLE_FIELD_SORTING_MODE_HINTS, default=None, metavar="MODE", - help="Force sortingModeHint for lightfield export (e.g. cameraDistance, zDepth). Ignored with --format ply/nurec (warning only).", + help=( + "Force sortingModeHint for lightfield export " + "(zDepth, cameraDistance, rayHitDistance). Ignored with --format ply/nurec (warning only)." + ), ) parser.add_argument( "--linear-srgb", diff --git a/threedgrut/export/tests/test_export_import.py b/threedgrut/export/tests/test_export_import.py index 5c54d349..421866f4 100644 --- a/threedgrut/export/tests/test_export_import.py +++ b/threedgrut/export/tests/test_export_import.py @@ -510,5 +510,32 @@ def test_usdz_export_camera_is_composed_from_root_stage(self): assert stage.GetEndTimeCode() == 1.0 +class TestUSDExportSortingModeHint: + """Test ParticleField sortingModeHint authoring.""" + + def test_usd_export_sorting_mode_hint_ray_hit_distance(self): + """Export can author the usd-core 26.5 rayHitDistance sorting hint.""" + model = MockGaussianModel(num_gaussians=5, sh_degree=3) + with tempfile.TemporaryDirectory() as tmpdir: + usd_path = Path(tmpdir) / "test.usdz" + USDExporter( + half_precision=False, + export_cameras=False, + export_background=False, + apply_normalizing_transform=False, + sorting_mode_hint="rayHitDistance", + ).export(model, usd_path) + stage = Usd.Stage.Open(str(usd_path)) + assert stage + prim = _find_prim_with_color_space_api(stage) + assert prim is not None, "No Gaussian particle prim found" + assert prim.GetAttribute("sortingModeHint").Get() == "rayHitDistance" + + def test_usd_export_sorting_mode_hint_rejects_unknown_token(self): + """Unsupported sorting hints fail before authoring invalid USD.""" + with pytest.raises(ValueError, match="Unsupported ParticleField sortingModeHint"): + USDExporter(sorting_mode_hint="frontToBack") + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 234603d0..a836e9d1 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -54,6 +54,10 @@ merge_source_prim_at_same_path, merge_source_world_at_same_paths, ) +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + normalize_particle_field_sorting_mode_hint, +) from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT from threedgrut.export.usd.writers.camera import export_cameras_to_usd @@ -273,7 +277,7 @@ def __init__( export_cameras: bool = True, export_background: bool = True, apply_normalizing_transform: bool = True, - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, export_post_processing: bool = True, post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, @@ -296,7 +300,7 @@ def __init__( export_cameras: Include camera poses in export. export_background: Include background/environment in export. apply_normalizing_transform: Apply transform to normalize scene orientation. - sorting_mode_hint: Sorting hint for rendering ("cameraDistance", "zDepth"). + sorting_mode_hint: Sorting hint for rendering ("zDepth", "cameraDistance", "rayHitDistance"). linear_srgb: If True, set prim color space to lin_rec709_scene. export_post_processing: If True, export the checkpoint post-processing module with the selected export mode. @@ -326,7 +330,7 @@ def __init__( self.export_cameras = export_cameras self.export_background = export_background self.apply_normalizing_transform = apply_normalizing_transform - self.sorting_mode_hint = sorting_mode_hint + self.sorting_mode_hint = normalize_particle_field_sorting_mode_hint(sorting_mode_hint) self.linear_srgb = linear_srgb self.export_post_processing = export_post_processing self.post_processing_export_mode = normalize_post_processing_export_mode(post_processing_export_mode) @@ -741,7 +745,7 @@ def from_config(cls, conf) -> "USDExporter": export_cameras=getattr(export_conf, "export_cameras", True), export_background=getattr(export_conf, "export_background", True), apply_normalizing_transform=getattr(export_conf, "apply_normalizing_transform", True), - sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", "cameraDistance"), + sorting_mode_hint=getattr(export_conf, "sorting_mode_hint", DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT), linear_srgb=getattr(export_conf, "linear_srgb", False), export_post_processing=_get_export_config_value( export_conf, diff --git a/threedgrut/export/usd/particle_field_hints.py b/threedgrut/export/usd/particle_field_hints.py new file mode 100644 index 00000000..fc815162 --- /dev/null +++ b/threedgrut/export/usd/particle_field_hints.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ParticleField schema hint tokens supported by usd-core 26.5+.""" + +DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT = "cameraDistance" + +PARTICLE_FIELD_SORTING_MODE_HINTS = ( + "zDepth", + "cameraDistance", + "rayHitDistance", +) + + +def normalize_particle_field_sorting_mode_hint(value: str) -> str: + """Normalize and validate a ParticleField sortingModeHint token.""" + normalized = str(value).strip() + if normalized not in PARTICLE_FIELD_SORTING_MODE_HINTS: + raise ValueError( + f"Unsupported ParticleField sortingModeHint '{value}'. " + f"Expected one of: {list(PARTICLE_FIELD_SORTING_MODE_HINTS)}" + ) + return normalized diff --git a/threedgrut/export/usd/writers/background.py b/threedgrut/export/usd/writers/background.py index 7b91fc9a..a681fd82 100644 --- a/threedgrut/export/usd/writers/background.py +++ b/threedgrut/export/usd/writers/background.py @@ -41,9 +41,7 @@ def _tensor_to_tuple(color: torch.Tensor) -> Tuple[float, float, float]: """Convert a torch tensor color to a tuple of floats.""" - if color.is_cuda: - color = color.cpu() - arr = color.numpy() + arr = color.detach().cpu().numpy() return tuple(float(c) for c in arr[:3]) diff --git a/threedgrut/export/usd/writers/base.py b/threedgrut/export/usd/writers/base.py index 48641df8..4b794e03 100644 --- a/threedgrut/export/usd/writers/base.py +++ b/threedgrut/export/usd/writers/base.py @@ -27,6 +27,7 @@ from pxr import Gf, Usd, Vt from threedgrut.export.accessor import GaussianAttributes, ModelCapabilities +from threedgrut.export.usd.particle_field_hints import DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ def create_gaussian_writer( content_root_path: str = "/World/Gaussians", half_geometry: bool = False, half_features: bool = False, - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, omni_usd: bool = False, has_post_processing: bool = False, diff --git a/threedgrut/export/usd/writers/lightfield.py b/threedgrut/export/usd/writers/lightfield.py index 130697b9..33016ba7 100644 --- a/threedgrut/export/usd/writers/lightfield.py +++ b/threedgrut/export/usd/writers/lightfield.py @@ -27,6 +27,10 @@ from pxr import Gf, Sdf, Usd, UsdGeom, UsdVol, Vt from threedgrut.export.accessor import GaussianAttributes, ModelCapabilities +from threedgrut.export.usd.particle_field_hints import ( + DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, + normalize_particle_field_sorting_mode_hint, +) from threedgrut.export.usd.writers.base import GaussianUSDWriter logger = logging.getLogger(__name__) @@ -49,7 +53,7 @@ def __init__( half_geometry: bool = False, half_features: bool = False, projection_mode_hint: str = "perspective", - sorting_mode_hint: str = "cameraDistance", + sorting_mode_hint: str = DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, linear_srgb: bool = False, omni_usd: bool = False, has_post_processing: bool = False, @@ -65,7 +69,7 @@ def __init__( self.half_geometry = half_geometry self.half_features = half_features self.projection_mode_hint = projection_mode_hint - self.sorting_mode_hint = sorting_mode_hint + self.sorting_mode_hint = normalize_particle_field_sorting_mode_hint(sorting_mode_hint) # Use surflet kernel for surfel models, ellipsoid for 3DGS self.use_surflet_kernel = capabilities.is_surfel diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index 61ef041c..fe9a766f 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -175,7 +175,7 @@ def _set_vignetting_params(shader: UsdShade.Shader, ppisp: PPISP, camera_index: ppisp.vignetting_params[camera_index] has shape [3, 5]: [cx, cy, alpha1, alpha2, alpha3] per channel. """ - vig = ppisp.vignetting_params[camera_index].cpu().numpy() # [3, 5] + vig = ppisp.vignetting_params[camera_index].detach().cpu().numpy() # [3, 5] for ch in range(NUM_CHANNELS): s = CHANNEL_SUFFIXES[ch] shader.CreateInput(f"vignettingCenter{s}", Sdf.ValueTypeNames.Float2).Set( @@ -192,7 +192,7 @@ def _set_crf_params(shader: UsdShade.Shader, ppisp: PPISP, camera_index: int) -> ppisp.crf_params[camera_index] has shape [3, 4]: [toe, shoulder, gamma, center] per channel (raw, activations applied in shader). """ - crf = ppisp.crf_params[camera_index].cpu().numpy() # [3, 4] + crf = ppisp.crf_params[camera_index].detach().cpu().numpy() # [3, 4] for ch in range(NUM_CHANNELS): s = CHANNEL_SUFFIXES[ch] shader.CreateInput(f"crfToe{s}", Sdf.ValueTypeNames.Float).Set(float(crf[ch, 0])) @@ -216,7 +216,7 @@ def _set_animated_exposure_params( ppisp.exposure_params has shape [num_frames]. Time code = float(frame_idx). """ - exposure = ppisp.exposure_params.cpu().numpy() # [num_frames] + exposure = ppisp.exposure_params.detach().cpu().numpy() # [num_frames] valid = [i for i in frame_indices if i < len(exposure)] mean_val = float(np.mean(exposure[valid])) if valid else 0.0 @@ -235,7 +235,7 @@ def _set_static_exposure_params( frame_index: int, ) -> None: """Write one fixed exposure offset without USD time samples.""" - exposure = ppisp.exposure_params.cpu().numpy() + exposure = ppisp.exposure_params.detach().cpu().numpy() if frame_index < 0 or frame_index >= len(exposure): raise ValueError(f"frame_index must be in [0, {len(exposure) - 1}], got {frame_index}.") shader.CreateInput("exposureOffset", Sdf.ValueTypeNames.Float).Set(float(exposure[frame_index])) @@ -253,7 +253,7 @@ def _set_animated_color_params( Written as 4 float2 attributes. Time code = float(frame_idx). """ - color = ppisp.color_params.cpu().numpy() # [num_frames, 8] + color = ppisp.color_params.detach().cpu().numpy() # [num_frames, 8] valid = [i for i in frame_indices if i < len(color)] mean_color = np.mean(color[valid], axis=0) if valid else np.zeros(8) @@ -281,7 +281,7 @@ def _set_static_color_params( frame_index: int, ) -> None: """Write one fixed color latent state without USD time samples.""" - color = ppisp.color_params.cpu().numpy() + color = ppisp.color_params.detach().cpu().numpy() if frame_index < 0 or frame_index >= len(color): raise ValueError(f"frame_index must be in [0, {len(color) - 1}], got {frame_index}.") @@ -325,9 +325,7 @@ def add_ppisp_shader_to_render_product( Returns: The created PPISP Shader prim. """ - assert camera_index < ppisp.num_cameras, ( - f"camera_index {camera_index} >= ppisp.num_cameras {ppisp.num_cameras}" - ) + assert camera_index < ppisp.num_cameras, f"camera_index {camera_index} >= ppisp.num_cameras {ppisp.num_cameras}" if not frame_indices and fixed_frame_index is None: log.warning(f"No frames for camera {camera_index} at {render_product_path}, skipping") return stage.GetPseudoRoot() @@ -342,10 +340,7 @@ def add_ppisp_shader_to_render_product( _set_static_exposure_params(shader, ppisp, fixed_frame_index) _set_static_color_params(shader, ppisp, fixed_frame_index) - log.info( - f"Added PPISP shader to {render_product_path} " - f"(camera {camera_index}, {len(frame_indices)} frame(s))" - ) + log.info(f"Added PPISP shader to {render_product_path} " f"(camera {camera_index}, {len(frame_indices)} frame(s))") return shader.GetPrim() From 36a0d2d7a82793b33cc0536da558867c439abf40 Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 20:00:04 +0000 Subject: [PATCH 20/42] feat(export): PPISP controller SPG export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the per-camera PPISP controller (CNN + adaptive avg pool + 3-layer MLP) to the Omniverse-native USD export so the runtime SPG pipeline predicts exposure/colour at frame time instead of relying on time-sampled USD attributes baked at training. Pipeline authored on each RenderProduct: HdrColor -> PPISPController_ -> ControllerParams (1x9 float) -> PPISP (dyn) -> PPISPColor -> LdrColor The controller weights flatten into a single 241,961-element float[] attribute on the Shader prim; the SPG slang reads them via a StructuredBuffer. The dynamic PPISP variant (ppisp_usd_spg_dyn.slang) reads exposure/colour from the controller's 1x9 output texture instead of from USD inputs. Added a slangpy-based headless harness under tools/render_ppisp_spg/ that compiles and dispatches the same shaders without booting Kit, plus several validators (synthetic, real ppisp module, trained checkpoint) that compare the slang result against the in-process PyTorch PPISP forward pass. End-to-end on a trained bonsai checkpoint: PSNR 63 dB, controller 9-float drift ~3e-7 vs torch — within fp32 rounding of the rgba8_unorm output format. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/ppisp-controller-export-plan.md | 162 ++++++++ threedgrut/export/usd/exporter.py | 34 +- threedgrut/export/usd/ppisp_spg/__init__.py | 30 +- .../usd/ppisp_spg/ppisp_controller.slang | 289 +++++++++++++ .../usd/ppisp_spg/ppisp_controller.slang.lua | 33 ++ .../usd/ppisp_spg/ppisp_controller.slang.usda | 31 ++ .../usd/ppisp_spg/ppisp_usd_spg_dyn.slang | 212 ++++++++++ .../usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua | 68 +++ .../ppisp_spg/ppisp_usd_spg_dyn.slang.usda | 48 +++ .../usd/writers/ppisp_controller_writer.py | 280 +++++++++++++ threedgrut/export/usd/writers/ppisp_writer.py | 93 ++++- tools/render_ppisp_spg/README.md | 50 +++ tools/render_ppisp_spg/__init__.py | 2 + .../render_ppisp_spg/render_renderproduct.py | 388 ++++++++++++++++++ tools/render_ppisp_spg/spg_runtime.py | 362 ++++++++++++++++ tools/render_ppisp_spg/validate_controller.py | 179 ++++++++ tools/render_ppisp_spg/validate_e2e.py | 335 +++++++++++++++ tools/render_ppisp_spg/validate_real_ppisp.py | 95 +++++ tools/render_ppisp_spg/validate_trained.py | 366 +++++++++++++++++ 19 files changed, 3030 insertions(+), 27 deletions(-) create mode 100644 docs/ppisp-controller-export-plan.md create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_controller.slang create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.usda create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua create mode 100644 threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda create mode 100644 threedgrut/export/usd/writers/ppisp_controller_writer.py create mode 100644 tools/render_ppisp_spg/README.md create mode 100644 tools/render_ppisp_spg/__init__.py create mode 100644 tools/render_ppisp_spg/render_renderproduct.py create mode 100644 tools/render_ppisp_spg/spg_runtime.py create mode 100644 tools/render_ppisp_spg/validate_controller.py create mode 100644 tools/render_ppisp_spg/validate_e2e.py create mode 100644 tools/render_ppisp_spg/validate_real_ppisp.py create mode 100644 tools/render_ppisp_spg/validate_trained.py diff --git a/docs/ppisp-controller-export-plan.md b/docs/ppisp-controller-export-plan.md new file mode 100644 index 00000000..eee951f9 --- /dev/null +++ b/docs/ppisp-controller-export-plan.md @@ -0,0 +1,162 @@ +# PPISP Controller SPG Export — Feasibility & Plan + +Scope: extend the existing PPISP SPG export to include the **controller** — +the per-camera CNN+MLP that predicts per-frame `exposure` and 8-d +`color_latents` from the rendered image. + +Status: feasible. This document records the design before implementation. + +--- + +## 1. Controller summary (from `ppisp.PPISP._PPISPController`) + +Fixed architecture, one instance per camera. Inputs are the raw rendered +HDR image and an optional `prior_exposure` scalar. + +``` +Conv2d(3→16, 1×1, +bias) +MaxPool2d(3, stride=3) +ReLU +Conv2d(16→32, 1×1, +bias) +ReLU +Conv2d(32→64, 1×1, +bias) +AdaptiveAvgPool2d((5, 5)) +Flatten # → 1600 features +concat(prior_exposure) # → 1601 +Linear(1601 → 128) + ReLU +Linear(128 → 128) + ReLU +Linear(128 → 128) + ReLU +exposure_head: Linear(128 → 1) +color_head: Linear(128 → 8) +``` + +Total weights ≈ **240 K floats per camera**. + +The output is **two scalar values for the whole image** (1 exposure, 8 colour +latents). Those nine numbers replace the current static USD time-samples on +the PPISP shader. + +--- + +## 2. SPG capabilities used + +The 3DGRUT SPG pipeline already uses **Slang** in the SPG runtime +(`*.slang`, `*.slang.lua`, `*.slang.usda`) — the public docs that describe +only CUDA kernels are out-of-date; the existing PPISP shader proves Slang +launch is supported. + +Confirmed primitives: + +- `slang.dispatch{ stage="compute", numthreads=…, grid=…, bind={…} }` per + shader prim. +- `slang.ParameterBlock(...)` for grouped scalar/vector inputs that map to + USD attributes. +- `slang.Texture2D / slang.RWTexture2D / slang.empty(shape, dtype)` for + bound textures and lua-allocated outputs. +- Shader-to-shader chaining via USD `omni:rtx:aov` connections on + `RenderVar` prims (the existing `LdrColor` → `PPISP` wiring uses this). + +What we **do not** rely on: +- Multi-dispatch within one Lua launcher (one dispatch per shader prim). +- CooperativeVector / coopvec — not assumed available in the target Kit. +- Non-2D output buffers — only 2D images via `slang.empty`. + +--- + +## 3. Two challenges and how we solve them + +### 3.1 Adaptive avg pooling on a runtime-sized input + +PyTorch's `AdaptiveAvgPool2d((5,5))` partitions the input into exactly 25 +near-equal cells. The cell bounds are: + +``` +i = 0..4 (output row) +start_h = floor(i * H_in / 5) +end_h = ceil((i + 1) * H_in / 5) +``` + +Each Slang thread group computes one output cell `(i, j)` by reading every +input pixel in `[start_h, end_h) × [start_w, end_w)`, applying the +per-pixel CNN forward (3×3 max-pool fused with the surrounding 1×1 +convolutions), and reducing the sum / divide-by-count in shared memory. + +This works for arbitrary input resolution because the cell bounds are +computed inside the shader from `H_in, W_in`. + +### 3.2 Baking MLP and CNN weights into Slang + +Each camera's controller has unique weights. We generate **one Slang file +per camera** at export time, with all weights emitted as +`static const float[]` arrays. Slang's compiler can fold these into +constant memory, and there is no runtime upload step. + +The generated file `ppisp_controller_.slang` includes a fixed +shared template (CNN forward, pool, MLP) and only differs in the weight +constants. The matching `*.slang.lua` and `*.slang.usda` are emitted per +camera as well so each `RenderProduct` references its own controller. + +If weights ever exceed Slang's static-data limits we can fall back to +USD `float[]` inputs bound as a `StructuredBuffer`, but for the +default architecture (~240 K floats) static arrays are fine. + +--- + +## 4. SPG graph + +``` +HdrColor (RenderVar) + │ + ▼ (omni:rtx:aov connection) +PPISPController_ Slang compute, single thread group + │ outputs ControllerParams (1×9 float image) + ▼ +PPISP Slang compute, grid sized to image + │ reads HdrColor + ControllerParams + static vignetting/CRF + ▼ outputs PPISPColor +LdrColor (RenderVar) +``` + +The existing `ppisp_writer.py` builds the second half. The new +controller writer creates the first stage and connects its output as +an additional input to the PPISP shader. + +The PPISP slang shader is **generalised** to read the exposure and the 8 +colour latents from a 1×9 single-channel float texture when one is bound, +falling back to its `ParameterBlock` defaults otherwise. This keeps the +legacy "static parameters per frame" path unchanged — important for users +who train without a controller. + +--- + +## 5. Testing + +Two-pronged approach: + +1. **Unit-level Python check** that the generated Slang reproduces the + PyTorch controller's outputs to within a tight tolerance, using + `slangpy` to dispatch the controller shader against a reference image. + +2. **Tool: `tools/render_usd_renderproduct/`** — a slangpy-based runner + that opens an exported USD/USDZ, walks `/Render/` prims, finds + their SPG shader chain, and replays the chain on a supplied HDR input + for every authored time sample. Useful for visual regression and for + cross-checking that the PPISP USD asset produces the same image + sequence as the in-process `apply_post_processing` path used during + training. + +The render tool intentionally does not try to reproduce Kit's full +RenderProduct pipeline; it executes only the SPG `compute` stages so it +remains independent of Kit and useful in headless CI. + +--- + +## 6. Out of scope for this iteration + +- Multi-dispatch optimisation of the controller (currently one slow but + correct compute pass). +- CoopVec acceleration of the MLP matmul. +- Quantising weights to fp16/bf16 to reduce shader source size. +- Runtime weight upload (large `float[]` USD inputs). + +These can be added later if the basic export proves correct. diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index a836e9d1..15f2a4a9 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -693,13 +693,18 @@ def _export_ppisp( has_controller = ( bool(getattr(ppisp_config, "use_controller", False)) and controllers is not None and len(controllers) > 0 ) - if has_controller: - logger.warning( - "PPISP controller export is not implemented yet; SPG export uses only " - "stored exposure/color parameters, vignetting, and CRF." + # The static-frame override modes (fixed_frame_id) intentionally bypass + # the controller because the goal is to bake one specific frame's + # corrections, not to predict them at runtime. + use_controller = has_controller and fixed_frame_id is None + if has_controller and fixed_frame_id is not None: + logger.info( + "PPISP controller present but fixed_frame_id is set; using static " + "exposure/color from frame %d instead of the controller.", + fixed_frame_id, ) - from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files + from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files, get_ppisp_spg_dyn_files from threedgrut.export.usd.writers.ppisp_writer import ( add_ppisp_to_all_render_products, build_camera_frame_mapping, @@ -715,17 +720,32 @@ def _export_ppisp( camera_frame_mapping=camera_frame_mapping, fixed_camera_index=fixed_camera_id, fixed_frame_index=fixed_frame_id, + use_controller=use_controller, ) except Exception as e: logger.warning(f"Failed to add PPISP shaders: {e}") return - spg_files = get_ppisp_spg_files() + if use_controller: + spg_files = list(get_ppisp_spg_dyn_files()) + from threedgrut.export.usd.writers.ppisp_controller_writer import ( + get_controller_sidecars, + ) + for s in get_controller_sidecars(): + if not any(f.filename == s.filename for f in spg_files): + spg_files.append(s) + else: + spg_files = get_ppisp_spg_files() + for spg_file in spg_files: if not any(f.filename == spg_file.filename for f in files): files.append(spg_file) - logger.info(f"PPISP Omniverse-native export complete: {len(spg_files)} sidecar(s) added") + logger.info( + "PPISP Omniverse-native export complete: %d sidecar(s) added (controller=%s)", + len(files), + use_controller, + ) @classmethod def from_config(cls, conf) -> "USDExporter": diff --git a/threedgrut/export/usd/ppisp_spg/__init__.py b/threedgrut/export/usd/ppisp_spg/__init__.py index 7aa37f5b..2ac26bc4 100644 --- a/threedgrut/export/usd/ppisp_spg/__init__.py +++ b/threedgrut/export/usd/ppisp_spg/__init__.py @@ -29,21 +29,21 @@ log = logging.getLogger(__name__) _SPG_DIR = Path(__file__).parent -_SPG_FILES = [ +_SPG_STATIC_FILES = [ "ppisp_usd_spg.slang", "ppisp_usd_spg.slang.lua", "ppisp_usd_spg.slang.usda", ] +_SPG_DYN_FILES = [ + "ppisp_usd_spg_dyn.slang", + "ppisp_usd_spg_dyn.slang.lua", + "ppisp_usd_spg_dyn.slang.usda", +] -def get_ppisp_spg_files() -> List[NamedSerialized]: - """Load all PPISP SPG sidecar files as serialized data for USDZ packaging. - - Returns: - List of NamedSerialized for each SPG file (slang, lua, usda). - """ +def _load_files(filenames) -> List[NamedSerialized]: result: List[NamedSerialized] = [] - for filename in _SPG_FILES: + for filename in filenames: path = _SPG_DIR / filename if path.exists(): result.append(NamedSerialized(filename=filename, serialized=path.read_bytes())) @@ -51,3 +51,17 @@ def get_ppisp_spg_files() -> List[NamedSerialized]: else: log.warning(f"PPISP SPG sidecar not found: {path}") return result + + +def get_ppisp_spg_files() -> List[NamedSerialized]: + """Load static-parameter PPISP SPG sidecar files (controller-free path).""" + return _load_files(_SPG_STATIC_FILES) + + +def get_ppisp_spg_dyn_files() -> List[NamedSerialized]: + """Load controller-aware PPISP SPG sidecar files. + + These accompany the per-camera ``ppisp_controller_.slang`` and read + ``exposureOffset`` and the colour latents from the controller output. + """ + return _load_files(_SPG_DYN_FILES) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang new file mode 100644 index 00000000..9ead0090 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang @@ -0,0 +1,289 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +// PPISP Controller SPG Shader. +// +// Generic compute shader: weights are bound at dispatch time as a flat +// ``StructuredBuffer``. Per-camera variation lives entirely in +// USD attributes; the Slang/Lua/USDA assets are shared. +// +// Architecture mirrors ppisp._PPISPController (default config): +// +// Conv1x1(3->16, +bias) +// MaxPool 3x3 stride 3 +// ReLU +// Conv1x1(16->32, +bias) +// ReLU +// Conv1x1(32->64, +bias) +// AdaptiveAvgPool2d((5,5)) +// Flatten -> 1600 +// concat prior_exposure -> 1601 +// MLP: 1601 -> 128 -> 128 -> 128, ReLU after each hidden layer +// exposure_head: 128 -> 1 +// color_head: 128 -> 8 +// +// Output texture (1x9 float, RWTexture2D): +// pixel (0,0): exposureOffset +// pixel (1,0)..(8,0): color latents +// [colorBlue.x, colorBlue.y, +// colorRed.x, colorRed.y, +// colorGreen.x, colorGreen.y, +// colorNeutral.x, colorNeutral.y] + +// --------------------------------------------------------------------------- +// Architecture sizes (must match ``_PPISPController`` defaults). +// --------------------------------------------------------------------------- +static const int CNN_FEATURE_DIM = 64; +static const int POOL_GRID_H = 5; +static const int POOL_GRID_W = 5; +static const int POOL_CELL_COUNT = POOL_GRID_H * POOL_GRID_W; // 25 +static const int POOL_FEATURE_LEN = POOL_CELL_COUNT * CNN_FEATURE_DIM; // 1600 +static const int MLP_INPUT_DIM = POOL_FEATURE_LEN + 1; // 1601 +static const int MLP_HIDDEN_DIM = 128; +static const int COLOR_PARAMS_PER_FRAME = 8; +static const int INPUT_DOWNSAMPLING = 3; +static const int THREAD_GROUP_SIZE = 32; + +// --------------------------------------------------------------------------- +// Weight buffer offsets (the Python writer flattens weights in this order +// into a single float buffer that gets bound as g_Weights). +// --------------------------------------------------------------------------- +static const int OFF_CONV1_W = 0; // 16 * 3 = 48 +static const int OFF_CONV1_B = OFF_CONV1_W + 16 * 3; // + 16 = 64 +static const int OFF_CONV2_W = OFF_CONV1_B + 16; // + 32 * 16 = 576 +static const int OFF_CONV2_B = OFF_CONV2_W + 32 * 16; // + 32 = 608 +static const int OFF_CONV3_W = OFF_CONV2_B + 32; // + 64 * 32 = 2656 +static const int OFF_CONV3_B = OFF_CONV3_W + 64 * 32; // + 64 = 2720 +static const int OFF_TRUNK0_W = OFF_CONV3_B + 64; // + 128 * 1601 = 207648 +static const int OFF_TRUNK0_B = OFF_TRUNK0_W + 128 * MLP_INPUT_DIM; +static const int OFF_TRUNK1_W = OFF_TRUNK0_B + 128; +static const int OFF_TRUNK1_B = OFF_TRUNK1_W + 128 * 128; +static const int OFF_TRUNK2_W = OFF_TRUNK1_B + 128; +static const int OFF_TRUNK2_B = OFF_TRUNK2_W + 128 * 128; +static const int OFF_EXP_W = OFF_TRUNK2_B + 128; +static const int OFF_EXP_B = OFF_EXP_W + 128; +static const int OFF_COL_W = OFF_EXP_B + 1; +static const int OFF_COL_B = OFF_COL_W + 8 * 128; +static const int TOTAL_WEIGHTS = OFF_COL_B + 8; + +// --------------------------------------------------------------------------- +// Bindings +// --------------------------------------------------------------------------- + +struct PPISPControllerParams +{ + float priorExposure; +}; + +[[vk::binding(0, 1)]] ParameterBlock g_Params; +[[vk::binding(1, 1)]] Texture2D g_InTex; +[[vk::binding(2, 1)]] StructuredBuffer g_Weights; +[[vk::binding(3, 1)]] RWTexture2D g_OutTex; + +// --------------------------------------------------------------------------- +// Per-pixel CNN building blocks +// --------------------------------------------------------------------------- + +void conv1Forward(float3 rgb, out float feat[16]) +{ + [unroll] for (int o = 0; o < 16; ++o) + { + float v = g_Weights[OFF_CONV1_B + o]; + v += rgb.r * g_Weights[OFF_CONV1_W + o * 3 + 0]; + v += rgb.g * g_Weights[OFF_CONV1_W + o * 3 + 1]; + v += rgb.b * g_Weights[OFF_CONV1_W + o * 3 + 2]; + feat[o] = v; + } +} + +void conv2Forward(float fin[16], out float fout[32]) +{ + [unroll] for (int o = 0; o < 32; ++o) + { + float v = g_Weights[OFF_CONV2_B + o]; + [unroll] for (int i = 0; i < 16; ++i) + v += fin[i] * g_Weights[OFF_CONV2_W + o * 16 + i]; + fout[o] = v; + } +} + +void conv3Forward(float fin[32], out float fout[64]) +{ + [unroll] for (int o = 0; o < CNN_FEATURE_DIM; ++o) + { + float v = g_Weights[OFF_CONV3_B + o]; + [unroll] for (int i = 0; i < 32; ++i) + v += fin[i] * g_Weights[OFF_CONV3_W + o * 32 + i]; + fout[o] = v; + } +} + +void cnnForwardAtDownsampledPixel( + int inW, + int inH, + int dx, + int dy, + out float feat64[64]) +{ + int x0 = dx * INPUT_DOWNSAMPLING; + int y0 = dy * INPUT_DOWNSAMPLING; + int x1 = min(x0 + INPUT_DOWNSAMPLING, inW); + int y1 = min(y0 + INPUT_DOWNSAMPLING, inH); + + float pooled[16]; + [unroll] for (int c = 0; c < 16; ++c) + pooled[c] = -3.402823e+38; + + for (int yy = y0; yy < y1; ++yy) + { + for (int xx = x0; xx < x1; ++xx) + { + float4 sample = g_InTex.Load(int3(xx, yy, 0)); + float conv1Out[16]; + conv1Forward(sample.rgb, conv1Out); + [unroll] for (int c = 0; c < 16; ++c) + pooled[c] = max(pooled[c], conv1Out[c]); + } + } + + [unroll] for (int c = 0; c < 16; ++c) + pooled[c] = max(0.0, pooled[c]); + + float feat32[32]; + conv2Forward(pooled, feat32); + [unroll] for (int c = 0; c < 32; ++c) + feat32[c] = max(0.0, feat32[c]); + + conv3Forward(feat32, feat64); +} + +void adaptiveCellAverage( + int inW, + int inH, + int dsW, + int dsH, + int gx, + int gy, + out float cellFeat[64]) +{ + int hStart = (gy * dsH) / POOL_GRID_H; + int hEnd = ((gy + 1) * dsH + POOL_GRID_H - 1) / POOL_GRID_H; + int wStart = (gx * dsW) / POOL_GRID_W; + int wEnd = ((gx + 1) * dsW + POOL_GRID_W - 1) / POOL_GRID_W; + hEnd = min(hEnd, dsH); + wEnd = min(wEnd, dsW); + + [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c) + cellFeat[c] = 0.0; + + int count = 0; + for (int dy = hStart; dy < hEnd; ++dy) + { + for (int dx = wStart; dx < wEnd; ++dx) + { + float feat64[CNN_FEATURE_DIM]; + cnnForwardAtDownsampledPixel(inW, inH, dx, dy, feat64); + [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c) + cellFeat[c] += feat64[c]; + count += 1; + } + } + + float invCount = (count > 0) ? (1.0 / float(count)) : 0.0; + [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c) + cellFeat[c] *= invCount; +} + +groupshared float gsPooled[POOL_FEATURE_LEN]; // 1600 floats +groupshared float gsHiddenA[MLP_HIDDEN_DIM]; // 128 floats +groupshared float gsHiddenB[MLP_HIDDEN_DIM]; // 128 floats + +[shader("compute")] +[numthreads(THREAD_GROUP_SIZE, 1, 1)] +void controllerProcess(uint3 gtid : SV_GroupThreadID) +{ + uint inW = 0, inH = 0; + g_InTex.GetDimensions(inW, inH); + + int dsW = max(1u, inW / INPUT_DOWNSAMPLING); + int dsH = max(1u, inH / INPUT_DOWNSAMPLING); + + // Phase 1: pool cells. With THREAD_GROUP_SIZE=32 threads and 25 cells, + // only the first 25 are active in this phase. + // + // Layout note: PyTorch's nn.Flatten on the [N, C, H, W] CNN output + // produces a *channel-major* flat layout — feat[c * H*W + h*W + w]. + // The trunk0 weight matrix was trained against that layout, so + // gsPooled MUST be stored channel-major as well, i.e. + // gsPooled[c * POOL_CELL_COUNT + cell]. + // (cell-major would silently permute every controller output.) + int cell = int(gtid.x); + if (cell < POOL_CELL_COUNT) + { + int gy = cell / POOL_GRID_W; + int gx = cell % POOL_GRID_W; + + float cellFeat[CNN_FEATURE_DIM]; + adaptiveCellAverage(int(inW), int(inH), dsW, dsH, gx, gy, cellFeat); + + [unroll] for (int c = 0; c < CNN_FEATURE_DIM; ++c) + gsPooled[c * POOL_CELL_COUNT + cell] = cellFeat[c]; + } + GroupMemoryBarrierWithGroupSync(); + + // Phase 2: trunk0 (1601 -> 128). 128 output rows are distributed + // across the THREAD_GROUP_SIZE threads. + for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) + { + float v = g_Weights[OFF_TRUNK0_B + o]; + for (int i = 0; i < POOL_FEATURE_LEN; ++i) + v += gsPooled[i] * g_Weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i]; + v += g_Params.priorExposure + * g_Weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN]; + gsHiddenA[o] = max(0.0, v); + } + GroupMemoryBarrierWithGroupSync(); + + // Phase 3: trunk1 (128 -> 128). gsHiddenA -> gsHiddenB. + for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) + { + float v = g_Weights[OFF_TRUNK1_B + o]; + for (int i = 0; i < MLP_HIDDEN_DIM; ++i) + v += gsHiddenA[i] * g_Weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i]; + gsHiddenB[o] = max(0.0, v); + } + GroupMemoryBarrierWithGroupSync(); + + // Phase 4: trunk2 (128 -> 128). gsHiddenB -> gsHiddenA. + for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) + { + float v = g_Weights[OFF_TRUNK2_B + o]; + for (int i = 0; i < MLP_HIDDEN_DIM; ++i) + v += gsHiddenB[i] * g_Weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i]; + gsHiddenA[o] = max(0.0, v); + } + GroupMemoryBarrierWithGroupSync(); + + // Phase 5: heads. + if (gtid.x == 0) + { + float v = g_Weights[OFF_EXP_B]; + for (int i = 0; i < MLP_HIDDEN_DIM; ++i) + v += gsHiddenA[i] * g_Weights[OFF_EXP_W + i]; + g_OutTex[int2(0, 0)] = v; + } + if (gtid.x < uint(COLOR_PARAMS_PER_FRAME)) + { + int o = int(gtid.x); + float v = g_Weights[OFF_COL_B + o]; + for (int i = 0; i < MLP_HIDDEN_DIM; ++i) + v += gsHiddenA[i] * g_Weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i]; + g_OutTex[int2(1 + o, 0)] = v; + } +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua new file mode 100644 index 00000000..f001912c --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua @@ -0,0 +1,33 @@ +-- SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +-- SPDX-License-Identifier: Apache-2.0 + +-- PPISP Controller SPG Launcher. +-- +-- Single shared launcher for every camera. Per-camera differences are +-- carried by the ``weights`` USD attribute, so this file does not need +-- to be regenerated. + +function controllerProcess(inputs, outputs, params) + local in_rgba = inputs["HdrColor"] + assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") + + local weights = params["weights"] + assert(weights, "controllerProcess requires the inputs:weights attribute") + + -- 1x9 single-channel float image holding [exposure, color latents]. + outputs["ControllerParams"] = slang.empty({ 1, 9 }, slang.float) + + return slang.dispatch({ + stage = "compute", + numthreads = { 32, 1, 1 }, + grid = { 1, 1, 1 }, + bind = { + slang.ParameterBlock( + slang.float(params["priorExposure"] or 0.0) + ), + slang.Texture2D(in_rgba), + slang.StructuredBuffer(weights), + slang.RWTexture2D(outputs["ControllerParams"]), + }, + }) +end diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.usda b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.usda new file mode 100644 index 00000000..681802b4 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.usda @@ -0,0 +1,31 @@ +#usda 1.0 +( + defaultPrim = "SlangPPISPController" +) + +def Shader "SlangPPISPController" +{ + uniform token info:implementationSource = "sourceAsset" + uniform asset info:spg:sourceAsset = @ppisp_controller.slang@ + uniform token info:spg:sourceAsset:subIdentifier = "controllerProcess" + + # Optional EXIF-derived prior exposure. Defaults to zero so the controller + # behaves identically to training-time inference when no prior is wired. + float inputs:priorExposure = 0.0 + + # Flat float buffer holding all controller weights in the layout + # encoded by ppisp_controller.slang's OFF_* offsets: + # conv1_weight (16x3) | conv1_bias (16) | + # conv2_weight (32x16) | conv2_bias (32) | + # conv3_weight (64x32) | conv3_bias (64) | + # trunk0_weight (128x1601) | trunk0_bias (128) | + # trunk1_weight (128x128) | trunk1_bias (128) | + # trunk2_weight (128x128) | trunk2_bias (128) | + # exposure_head_weight (128) | exposure_head_bias (1) | + # color_head_weight (8x128) | color_head_bias (8) + # = 241,961 floats per camera. + float[] inputs:weights = [] + + opaque inputs:HdrColor + opaque outputs:ControllerParams +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang new file mode 100644 index 00000000..e84fef70 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang @@ -0,0 +1,212 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +// PPISP (Physically Plausible ISP) SPG Shader — controller-aware variant. +// +// Identical to ppisp_usd_spg.slang in maths. The only difference is that +// `exposureOffset` and the eight colour latents come from a 1x9 single- +// channel float texture produced by the per-camera PPISP controller +// shader (ppisp_controller_.slang), instead of from +// time-sampled USD attributes. +// +// Texel layout of the controller output (matches ppisp_controller.slang): +// (0, 0): exposureOffset +// (1..2, 0): colorLatentBlue.xy +// (3..4, 0): colorLatentRed.xy +// (5..6, 0): colorLatentGreen.xy +// (7..8, 0): colorLatentNeutral.xy + +struct PPISPDynParams +{ + float2 vignettingCenterR; + float vignettingAlpha1R; + float vignettingAlpha2R; + float vignettingAlpha3R; + + float2 vignettingCenterG; + float vignettingAlpha1G; + float vignettingAlpha2G; + float vignettingAlpha3G; + + float2 vignettingCenterB; + float vignettingAlpha1B; + float vignettingAlpha2B; + float vignettingAlpha3B; + + float crfToeR; + float crfShoulderR; + float crfGammaR; + float crfCenterR; + + float crfToeG; + float crfShoulderG; + float crfGammaG; + float crfCenterG; + + float crfToeB; + float crfShoulderB; + float crfGammaB; + float crfCenterB; +}; + +[[vk::binding(0, 1)]] ParameterBlock g_Params; +[[vk::binding(1, 1)]] Texture2D g_InTex; +[[vk::binding(2, 1)]] Texture2D g_ControllerOut; +[[vk::binding(3, 1)]] RWTexture2D g_OutTex; + +static const float2x2 ZCA_BLUE = float2x2( 0.0480542, -0.0043631, -0.0043631, 0.0481283); +static const float2x2 ZCA_RED = float2x2( 0.0580570, -0.0179872, -0.0179872, 0.0431061); +static const float2x2 ZCA_GREEN = float2x2( 0.0433336, -0.0180537, -0.0180537, 0.0580500); +static const float2x2 ZCA_NEUTRAL = float2x2( 0.0128369, -0.0034654, -0.0034654, 0.0128158); + +float3x3 computeHomography(float2 bLat, float2 rLat, float2 gLat, float2 nLat) +{ + float2 bd = mul(ZCA_BLUE, bLat); + float2 rd = mul(ZCA_RED, rLat); + float2 gd = mul(ZCA_GREEN, gLat); + float2 nd = mul(ZCA_NEUTRAL, nLat); + + float3 tB = float3(0.0 + bd.x, 0.0 + bd.y, 1.0); + float3 tR = float3(1.0 + rd.x, 0.0 + rd.y, 1.0); + float3 tG = float3(0.0 + gd.x, 1.0 + gd.y, 1.0); + float3 tGray = float3(1.0 / 3.0 + nd.x, 1.0 / 3.0 + nd.y, 1.0); + + float3x3 T = float3x3(tB.x, tR.x, tG.x, + tB.y, tR.y, tG.y, + tB.z, tR.z, tG.z); + + float3x3 skew = float3x3(0.0, -tGray.z, tGray.y, + tGray.z, 0.0, -tGray.x, + -tGray.y, tGray.x, 0.0); + + float3x3 M = mul(skew, T); + + float3 r0 = M[0]; + float3 r1 = M[1]; + float3 r2 = M[2]; + + float3 lam = cross(r0, r1); + if (dot(lam, lam) < 1.0e-20) + { + lam = cross(r0, r2); + if (dot(lam, lam) < 1.0e-20) + lam = cross(r1, r2); + } + + float3x3 Sinv = float3x3(-1.0, -1.0, 1.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0); + + float3x3 D = float3x3(lam.x, 0.0, 0.0, + 0.0, lam.y, 0.0, + 0.0, 0.0, lam.z); + + float3x3 H = mul(mul(T, D), Sinv); + + float s = H[2][2]; + if (abs(s) > 1.0e-20) + H = H * (1.0 / s); + + return H; +} + +float applyVignetting(float value, float2 uv, float2 opticalCenter, float a1, float a2, float a3) +{ + float2 d = uv - opticalCenter; + float r2 = dot(d, d); + + float falloff = 1.0; + float r2Pow = r2; + falloff += a1 * r2Pow; + r2Pow *= r2; + falloff += a2 * r2Pow; + r2Pow *= r2; + falloff += a3 * r2Pow; + + return value * clamp(falloff, 0.0, 1.0); +} + +float boundedSoftplus(float raw, float minValue) { return minValue + log(1.0 + exp(raw)); } +float sigmoidF(float raw) { return 1.0 / (1.0 + exp(-raw)); } + +float applyCRF(float x, float toeRaw, float shoulderRaw, float gammaRaw, float centerRaw) +{ + x = clamp(x, 0.0, 1.0); + float toe = boundedSoftplus(toeRaw, 0.3); + float shoulder = boundedSoftplus(shoulderRaw, 0.3); + float gamma = boundedSoftplus(gammaRaw, 0.1); + float center = sigmoidF(centerRaw); + + float lerpVal = (shoulder - toe) * center + toe; + float a = (shoulder * center) / lerpVal; + float b = 1.0 - a; + + float y; + if (x <= center) + y = a * pow(x / center, toe); + else + y = 1.0 - b * pow((1.0 - x) / (1.0 - center), shoulder); + return pow(max(0.0, y), gamma); +} + +float3 applyColorCorrection(float3 rgb, float3x3 H) +{ + float intensity = rgb.x + rgb.y + rgb.z; + float3 rgi = float3(rgb.x, rgb.y, intensity); + rgi = mul(H, rgi); + rgi = rgi * (intensity / (rgi.z + 1.0e-5)); + return float3(rgi.x, rgi.y, rgi.z - rgi.x - rgi.y); +} + +[shader("compute")] +[numthreads(16, 16, 1)] +void ppispProcessDyn(uint3 tid : SV_DispatchThreadID) +{ + uint w = 0, h = 0; + g_InTex.GetDimensions(w, h); + if (tid.x >= w || tid.y >= h) + return; + + float4 pixel = g_InTex.Load(int3(tid.xy, 0)); + float3 rgb = pixel.rgb; + + float maxRes = max(float(w), float(h)); + float2 uv = float2(tid.x + 0.5 - float(w) * 0.5, + tid.y + 0.5 - float(h) * 0.5) / maxRes; + + // Read controller output (1x9 single-channel float texture). + float exposureOffset = g_ControllerOut.Load(int3(0, 0, 0)); + float2 colorLatentBlue = float2(g_ControllerOut.Load(int3(1, 0, 0)), + g_ControllerOut.Load(int3(2, 0, 0))); + float2 colorLatentRed = float2(g_ControllerOut.Load(int3(3, 0, 0)), + g_ControllerOut.Load(int3(4, 0, 0))); + float2 colorLatentGreen = float2(g_ControllerOut.Load(int3(5, 0, 0)), + g_ControllerOut.Load(int3(6, 0, 0))); + float2 colorLatentNeutral = float2(g_ControllerOut.Load(int3(7, 0, 0)), + g_ControllerOut.Load(int3(8, 0, 0))); + + rgb = rgb * exp2(exposureOffset); + + rgb.r = applyVignetting(rgb.r, uv, g_Params.vignettingCenterR, + g_Params.vignettingAlpha1R, g_Params.vignettingAlpha2R, g_Params.vignettingAlpha3R); + rgb.g = applyVignetting(rgb.g, uv, g_Params.vignettingCenterG, + g_Params.vignettingAlpha1G, g_Params.vignettingAlpha2G, g_Params.vignettingAlpha3G); + rgb.b = applyVignetting(rgb.b, uv, g_Params.vignettingCenterB, + g_Params.vignettingAlpha1B, g_Params.vignettingAlpha2B, g_Params.vignettingAlpha3B); + + float3x3 H = computeHomography(colorLatentBlue, colorLatentRed, + colorLatentGreen, colorLatentNeutral); + rgb = applyColorCorrection(rgb, H); + + rgb.r = applyCRF(rgb.r, g_Params.crfToeR, g_Params.crfShoulderR, g_Params.crfGammaR, g_Params.crfCenterR); + rgb.g = applyCRF(rgb.g, g_Params.crfToeG, g_Params.crfShoulderG, g_Params.crfGammaG, g_Params.crfCenterG); + rgb.b = applyCRF(rgb.b, g_Params.crfToeB, g_Params.crfShoulderB, g_Params.crfGammaB, g_Params.crfCenterB); + + g_OutTex[tid.xy] = float4(saturate(rgb), 1.0); +} diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua new file mode 100644 index 00000000..20cfdfe5 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua @@ -0,0 +1,68 @@ +-- SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +-- SPDX-License-Identifier: Apache-2.0 + +-- PPISP SPG Launcher (controller-aware variant). +-- +-- Reads ``exposureOffset`` and the eight colour latents from the +-- controller's output texture; the static USD inputs only carry the +-- per-camera vignetting and CRF parameters. The HdrColor input still +-- comes from the RenderProduct's primary AOV. + +function ppispProcessDyn(inputs, outputs, params) + local in_rgba = inputs["HdrColor"] + assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") + + local controller = inputs["ControllerParams"] + assert(controller, "ppispProcessDyn requires a ControllerParams input texture") + + local height = in_rgba.shape[1] + local width = in_rgba.shape[2] + outputs["PPISPColor"] = slang.empty({ height, width }, slang.uchar4) + + local function getFloat2(name) + local p = params[name] + return p and slang.float2(p) or slang.float2(0.0, 0.0) + end + + return slang.dispatch({ + stage = "compute", + numthreads = { 16, 16, 1 }, + grid = { math.ceil(width / 16), math.ceil(height / 16), 1 }, + bind = { + slang.ParameterBlock( + getFloat2("vignettingCenterR"), + slang.float(params["vignettingAlpha1R"] or 0.0), + slang.float(params["vignettingAlpha2R"] or 0.0), + slang.float(params["vignettingAlpha3R"] or 0.0), + + getFloat2("vignettingCenterG"), + slang.float(params["vignettingAlpha1G"] or 0.0), + slang.float(params["vignettingAlpha2G"] or 0.0), + slang.float(params["vignettingAlpha3G"] or 0.0), + + getFloat2("vignettingCenterB"), + slang.float(params["vignettingAlpha1B"] or 0.0), + slang.float(params["vignettingAlpha2B"] or 0.0), + slang.float(params["vignettingAlpha3B"] or 0.0), + + slang.float(params["crfToeR"] or 0.013659), + slang.float(params["crfShoulderR"] or 0.013659), + slang.float(params["crfGammaR"] or 0.378165), + slang.float(params["crfCenterR"] or 0.0), + + slang.float(params["crfToeG"] or 0.013659), + slang.float(params["crfShoulderG"] or 0.013659), + slang.float(params["crfGammaG"] or 0.378165), + slang.float(params["crfCenterG"] or 0.0), + + slang.float(params["crfToeB"] or 0.013659), + slang.float(params["crfShoulderB"] or 0.013659), + slang.float(params["crfGammaB"] or 0.378165), + slang.float(params["crfCenterB"] or 0.0) + ), + slang.Texture2D(in_rgba), + slang.Texture2D(controller), + slang.RWTexture2D(outputs["PPISPColor"]), + }, + }) +end diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda new file mode 100644 index 00000000..97848613 --- /dev/null +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda @@ -0,0 +1,48 @@ +#usda 1.0 +( + defaultPrim = "SlangPPISPDyn" +) + +def Shader "SlangPPISPDyn" +{ + uniform token info:implementationSource = "sourceAsset" + uniform asset info:spg:sourceAsset = @ppisp_usd_spg_dyn.slang@ + uniform token info:spg:sourceAsset:subIdentifier = "ppispProcessDyn" + + # Vignetting (per channel: R, G, B) + float2 inputs:vignettingCenterR = (0.0, 0.0) + float inputs:vignettingAlpha1R = 0.0 + float inputs:vignettingAlpha2R = 0.0 + float inputs:vignettingAlpha3R = 0.0 + + float2 inputs:vignettingCenterG = (0.0, 0.0) + float inputs:vignettingAlpha1G = 0.0 + float inputs:vignettingAlpha2G = 0.0 + float inputs:vignettingAlpha3G = 0.0 + + float2 inputs:vignettingCenterB = (0.0, 0.0) + float inputs:vignettingAlpha1B = 0.0 + float inputs:vignettingAlpha2B = 0.0 + float inputs:vignettingAlpha3B = 0.0 + + # CRF raw parameters (per channel: R, G, B) + float inputs:crfToeR = 0.013659 + float inputs:crfShoulderR = 0.013659 + float inputs:crfGammaR = 0.378165 + float inputs:crfCenterR = 0.0 + + float inputs:crfToeG = 0.013659 + float inputs:crfShoulderG = 0.013659 + float inputs:crfGammaG = 0.378165 + float inputs:crfCenterG = 0.0 + + float inputs:crfToeB = 0.013659 + float inputs:crfShoulderB = 0.013659 + float inputs:crfGammaB = 0.378165 + float inputs:crfCenterB = 0.0 + + # Image inputs/outputs + opaque inputs:HdrColor + opaque inputs:ControllerParams + opaque outputs:PPISPColor +} diff --git a/threedgrut/export/usd/writers/ppisp_controller_writer.py b/threedgrut/export/usd/writers/ppisp_controller_writer.py new file mode 100644 index 00000000..6b2ef855 --- /dev/null +++ b/threedgrut/export/usd/writers/ppisp_controller_writer.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +""" +PPISP Controller USD writer. + +Writes the per-camera PPISP controller as a UsdShade Shader prim that +references the shared ``ppisp_controller.slang`` SPG asset. The trained +controller weights are flattened into a single ``float[] inputs:weights`` +attribute on the Shader prim — the Slang shader picks them up as a +``StructuredBuffer`` at dispatch time. + +The flatten layout must match ``ppisp_controller.slang``'s ``OFF_*`` +constants: + + conv1_weight (16 x 3) | conv1_bias (16) + conv2_weight (32 x 16) | conv2_bias (32) + conv3_weight (64 x 32) | conv3_bias (64) + trunk0_weight (128 x 1601) | trunk0_bias (128) + trunk1_weight (128 x 128) | trunk1_bias (128) + trunk2_weight (128 x 128) | trunk2_bias (128) + exposure_head_weight (128) | exposure_head_bias (1) + color_head_weight (8 x 128)| color_head_bias (8) +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Sequence + +import numpy as np + +from pxr import Sdf, Usd, UsdShade, Vt + +from threedgrut.export.usd.stage_utils import NamedSerialized + +if TYPE_CHECKING: + import torch.nn as nn # noqa: F401 + +log = logging.getLogger(__name__) + + +# Names must match ppisp_controller.slang's bindings and ppisp_controller.slang.usda. +CONTROLLER_INPUT_RENDER_VAR = "HdrColor" +CONTROLLER_OUTPUT_NAME = "ControllerParams" +PRIOR_EXPOSURE_INPUT = "priorExposure" +WEIGHTS_INPUT = "weights" + +CONTROLLER_USDA_FILE = "ppisp_controller.slang.usda" +CONTROLLER_SLANG_FILE = "ppisp_controller.slang" + +# Architecture sizes (mirror ppisp._PPISPController defaults / shader constants). +EXPECTED_SIZES = { + "cnn_feature_dim": 64, + "pool_grid_h": 5, + "pool_grid_w": 5, + "mlp_hidden_dim": 128, + "color_params_per_frame": 8, + "input_downsampling": 3, +} + +# Total weight count. This *must* match ppisp_controller.slang::TOTAL_WEIGHTS. +EXPECTED_WEIGHTS_LEN = ( + 16 * 3 + 16 + + 32 * 16 + 32 + + 64 * 32 + 64 + + 128 * 1601 + 128 + + 128 * 128 + 128 + + 128 * 128 + 128 + + 128 + 1 + + 8 * 128 + 8 +) + + +# --------------------------------------------------------------------------- +# Weight extraction and validation +# --------------------------------------------------------------------------- + + +def _validate_controller_shape(controller) -> None: + """Sanity-check a ``_PPISPController`` matches the shader's hard-coded sizes.""" + cnn_encoder = controller.cnn_encoder + conv1 = cnn_encoder[0] + conv2 = cnn_encoder[3] + conv3 = cnn_encoder[5] + maxpool = cnn_encoder[1] + avgpool = cnn_encoder[6] + + if conv1.in_channels != 3 or conv1.out_channels != 16: + raise ValueError(f"controller conv1 must be 3->16, got {conv1.in_channels}->{conv1.out_channels}") + if conv1.kernel_size != (1, 1): + raise ValueError(f"controller conv1 kernel must be 1x1, got {conv1.kernel_size}") + if conv2.in_channels != 16 or conv2.out_channels != 32: + raise ValueError(f"controller conv2 must be 16->32, got {conv2.in_channels}->{conv2.out_channels}") + if conv3.in_channels != 32 or conv3.out_channels != EXPECTED_SIZES["cnn_feature_dim"]: + raise ValueError( + f"controller conv3 out_channels must be {EXPECTED_SIZES['cnn_feature_dim']}, got {conv3.out_channels}" + ) + if maxpool.kernel_size != EXPECTED_SIZES["input_downsampling"]: + raise ValueError( + f"controller maxpool kernel must be {EXPECTED_SIZES['input_downsampling']}, got {maxpool.kernel_size}" + ) + if maxpool.stride != EXPECTED_SIZES["input_downsampling"]: + raise ValueError( + f"controller maxpool stride must be {EXPECTED_SIZES['input_downsampling']}, got {maxpool.stride}" + ) + + expected_grid = (EXPECTED_SIZES["pool_grid_h"], EXPECTED_SIZES["pool_grid_w"]) + if tuple(avgpool.output_size) != expected_grid: + raise ValueError(f"controller AdaptiveAvgPool2d must be {expected_grid}, got {tuple(avgpool.output_size)}") + + trunk = controller.mlp_trunk + linear_layers = [m for m in trunk if hasattr(m, "weight") and m.weight.dim() == 2] + if len(linear_layers) != 3: + raise ValueError(f"controller MLP trunk must have 3 Linear layers, got {len(linear_layers)}") + + expected_input_dim = ( + EXPECTED_SIZES["pool_grid_h"] + * EXPECTED_SIZES["pool_grid_w"] + * EXPECTED_SIZES["cnn_feature_dim"] + + 1 + ) + if linear_layers[0].in_features != expected_input_dim: + raise ValueError( + f"controller trunk[0].in_features must be {expected_input_dim}, got {linear_layers[0].in_features}" + ) + for idx, layer in enumerate(linear_layers): + if layer.out_features != EXPECTED_SIZES["mlp_hidden_dim"]: + raise ValueError( + f"controller trunk[{idx}].out_features must be {EXPECTED_SIZES['mlp_hidden_dim']}, " + f"got {layer.out_features}" + ) + + if controller.exposure_head.out_features != 1: + raise ValueError("controller exposure_head must produce one output") + if controller.color_head.out_features != EXPECTED_SIZES["color_params_per_frame"]: + raise ValueError( + f"controller color_head must produce {EXPECTED_SIZES['color_params_per_frame']} outputs" + ) + + +def _to_np(t) -> np.ndarray: + import torch + return t.detach().cpu().to(dtype=torch.float32).numpy() + + +def flatten_controller_weights(controller) -> np.ndarray: + """Concatenate all controller weights into one float32 buffer. + + The order must match ``ppisp_controller.slang``'s ``OFF_*`` offsets. + Returns a 1-D ``np.float32`` array of length :data:`EXPECTED_WEIGHTS_LEN`. + """ + _validate_controller_shape(controller) + + cnn_encoder = controller.cnn_encoder + conv1 = cnn_encoder[0] + conv2 = cnn_encoder[3] + conv3 = cnn_encoder[5] + + trunk = controller.mlp_trunk + linear_layers = [m for m in trunk if hasattr(m, "weight") and m.weight.dim() == 2] + + def conv_w(layer) -> np.ndarray: + # PyTorch Conv2d weight: [out, in, kH, kW]. With 1x1 kernels we + # emit row-major [out * in]. + return _to_np(layer.weight).reshape(layer.out_channels, layer.in_channels).reshape(-1) + + parts: List[np.ndarray] = [ + conv_w(conv1), _to_np(conv1.bias).reshape(-1), + conv_w(conv2), _to_np(conv2.bias).reshape(-1), + conv_w(conv3), _to_np(conv3.bias).reshape(-1), + _to_np(linear_layers[0].weight).reshape(-1), _to_np(linear_layers[0].bias).reshape(-1), + _to_np(linear_layers[1].weight).reshape(-1), _to_np(linear_layers[1].bias).reshape(-1), + _to_np(linear_layers[2].weight).reshape(-1), _to_np(linear_layers[2].bias).reshape(-1), + _to_np(controller.exposure_head.weight).reshape(-1), _to_np(controller.exposure_head.bias).reshape(-1), + _to_np(controller.color_head.weight).reshape(-1), _to_np(controller.color_head.bias).reshape(-1), + ] + + flat = np.concatenate(parts).astype(np.float32, copy=False) + if flat.size != EXPECTED_WEIGHTS_LEN: + raise RuntimeError( + f"flatten_controller_weights produced {flat.size} floats; expected {EXPECTED_WEIGHTS_LEN}. " + "Did the controller architecture change?" + ) + if not np.all(np.isfinite(flat)): + raise RuntimeError( + "controller weights contain NaN/Inf; refusing to export. " + "Investigate the trained checkpoint before retrying." + ) + return flat + + +# --------------------------------------------------------------------------- +# USD authoring +# --------------------------------------------------------------------------- + + +def add_controller_shader_to_render_product( + stage: Usd.Stage, + render_product_path: str, + camera_index: int, + controller, + *, + prior_exposure: float | None = None, +) -> UsdShade.Shader: + """Author the controller Shader prim and connect ``HdrColor`` → ``ControllerParams``. + + Returns the created Shader so the caller can wire its output into the + PPISP shader. The PPISP shader is responsible for *consuming* the + output via its dynamic-controller binding. + """ + render_product = stage.GetPrimAtPath(render_product_path) + if not render_product.IsValid(): + raise ValueError(f"RenderProduct not found at path: {render_product_path}") + + # Mark HdrColor RenderVar input as an opaque AOV (no connection needed here). + input_var_path = f"{render_product_path}/{CONTROLLER_INPUT_RENDER_VAR}" + input_var_prim = stage.GetPrimAtPath(input_var_path) + if input_var_prim.IsValid(): + input_var_prim.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) + + shader_prim_name = f"PPISPController_{camera_index}" + shader_path = f"{render_product_path}/{shader_prim_name}" + shader = UsdShade.Shader.Define(stage, shader_path) + shader.GetPrim().GetReferences().AddReference(CONTROLLER_USDA_FILE) + shader.GetPrim().CreateAttribute( + "info:implementationSource", Sdf.ValueTypeNames.Token, custom=False + ).Set("sourceAsset") + shader.GetPrim().CreateAttribute( + "info:spg:sourceAsset", Sdf.ValueTypeNames.Asset, custom=False + ).Set(Sdf.AssetPath(CONTROLLER_SLANG_FILE)) + shader.GetPrim().CreateAttribute( + "info:spg:sourceAsset:subIdentifier", Sdf.ValueTypeNames.Token, custom=False + ).Set("controllerProcess") + + hdr_input = shader.CreateInput(CONTROLLER_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) + hdr_input.GetAttr().SetConnections([Sdf.Path(f"../{CONTROLLER_INPUT_RENDER_VAR}.omni:rtx:aov")]) + + shader.CreateOutput(CONTROLLER_OUTPUT_NAME, Sdf.ValueTypeNames.Opaque) + + prior_input = shader.CreateInput(PRIOR_EXPOSURE_INPUT, Sdf.ValueTypeNames.Float) + prior_input.Set(float(prior_exposure or 0.0)) + + weights = flatten_controller_weights(controller) + weights_input = shader.CreateInput(WEIGHTS_INPUT, Sdf.ValueTypeNames.FloatArray) + weights_input.Set(Vt.FloatArray.FromNumpy(weights)) + + log.debug( + "Authored PPISP controller shader at %s (camera %d, %d weights)", + shader_path, camera_index, weights.size, + ) + return shader + + +# --------------------------------------------------------------------------- +# Sidecar packaging +# --------------------------------------------------------------------------- + + +def get_controller_sidecars() -> List[NamedSerialized]: + """Load the shared controller SPG sidecar files. + + Unlike the dynamic PPISP path, the controller does not need per-camera + sidecar generation: the weights live in USD attributes, so the slang / + lua / usda assets are identical for every camera. + """ + from threedgrut.export.usd.ppisp_spg import _SPG_DIR + filenames = [CONTROLLER_SLANG_FILE, CONTROLLER_SLANG_FILE + ".lua", CONTROLLER_USDA_FILE] + out: List[NamedSerialized] = [] + for name in filenames: + path = _SPG_DIR / name + if path.exists(): + out.append(NamedSerialized(filename=name, serialized=path.read_bytes())) + return out diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index fe9a766f..18fcf8af 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -48,7 +48,10 @@ PPISP_SPG_USDA_FILE = "ppisp_usd_spg.slang.usda" PPISP_SPG_SLANG_FILE = "ppisp_usd_spg.slang" +PPISP_SPG_DYN_USDA_FILE = "ppisp_usd_spg_dyn.slang.usda" +PPISP_SPG_DYN_SLANG_FILE = "ppisp_usd_spg_dyn.slang" PPISP_INPUT_RENDER_VAR = "HdrColor" +PPISP_CONTROLLER_INPUT = "ControllerParams" PPISP_OUTPUT_RENDER_VAR = "PPISPColor" LDR_COLOR_RENDER_VAR = "LdrColor" PPISP_CAMERA_EXPOSURE = 0.0 @@ -110,16 +113,35 @@ def _add_ldr_color_render_var( return render_var_path -def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade.Shader: +def _create_shader_prim( + stage: Usd.Stage, + render_product_path: str, + *, + controller_shader: UsdShade.Shader | None = None, +) -> UsdShade.Shader: """Create the PPISP Shader prim on a RenderProduct. - Wires HdrColor → PPISP → LdrColor and appends LdrColor to orderedVars. + When ``controller_shader`` is None, the static SPG variant is used and + ``exposureOffset`` / colour latents must be authored as USD attributes + on the returned Shader. When ``controller_shader`` is provided, the + dynamic variant is used: the controller's ``ControllerParams`` output is + wired into a new opaque input on the PPISP shader, and the per-frame + exposure / colour params are sourced from the controller at runtime. + + Wires HdrColor → PPISP → LdrColor (and ControllerParams → PPISP when a + controller is present) and appends LdrColor to orderedVars. + Returns the UsdShade.Shader for parameter setting. """ render_product = stage.GetPrimAtPath(render_product_path) if not render_product.IsValid(): raise ValueError(f"RenderProduct not found at path: {render_product_path}") + use_dynamic = controller_shader is not None + usda_file = PPISP_SPG_DYN_USDA_FILE if use_dynamic else PPISP_SPG_USDA_FILE + slang_file = PPISP_SPG_DYN_SLANG_FILE if use_dynamic else PPISP_SPG_SLANG_FILE + sub_identifier = "ppispProcessDyn" if use_dynamic else "ppispProcess" + # Mark HdrColor RenderVar input as an opaque AOV (no connection needed here) input_var_path = f"{render_product_path}/{PPISP_INPUT_RENDER_VAR}" input_var_prim = stage.GetPrimAtPath(input_var_path) @@ -129,23 +151,30 @@ def _create_shader_prim(stage: Usd.Stage, render_product_path: str) -> UsdShade. # PPISP Shader prim referencing the SPG asset definition ppisp_shader_path = f"{render_product_path}/PPISP" shader = UsdShade.Shader.Define(stage, ppisp_shader_path) - shader.GetPrim().GetReferences().AddReference(PPISP_SPG_USDA_FILE) + shader.GetPrim().GetReferences().AddReference(usda_file) # Duplicate the source metadata on the instance. Some Kit SPG/Fabric paths # do not resolve referenced shader metadata when opening packaged USDZ files. shader.GetPrim().CreateAttribute("info:implementationSource", Sdf.ValueTypeNames.Token, custom=False).Set( "sourceAsset" ) shader.GetPrim().CreateAttribute("info:spg:sourceAsset", Sdf.ValueTypeNames.Asset, custom=False).Set( - Sdf.AssetPath(PPISP_SPG_SLANG_FILE) + Sdf.AssetPath(slang_file) ) shader.GetPrim().CreateAttribute("info:spg:sourceAsset:subIdentifier", Sdf.ValueTypeNames.Token, custom=False).Set( - "ppispProcess" + sub_identifier ) # HdrColor opaque input wired to the input RenderVar's AOV hdr_input = shader.CreateInput(PPISP_INPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) hdr_input.GetAttr().SetConnections([Sdf.Path(f"../{PPISP_INPUT_RENDER_VAR}.omni:rtx:aov")]) + if use_dynamic: + controller_input = shader.CreateInput(PPISP_CONTROLLER_INPUT, Sdf.ValueTypeNames.Opaque) + controller_output_path = controller_shader.GetPath().AppendProperty( + f"outputs:{PPISP_CONTROLLER_INPUT}" + ) + controller_input.GetAttr().SetConnections([controller_output_path]) + # PPISPColor opaque output shader.CreateOutput(PPISP_OUTPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) @@ -305,13 +334,16 @@ def add_ppisp_shader_to_render_product( ppisp: PPISP, frame_indices: List[int], fixed_frame_index: int | None = None, + controller_shader: UsdShade.Shader | None = None, ) -> Usd.Prim: """Add a PPISP Shader to a RenderProduct for one physical camera. Per-camera parameters (vignetting, CRF) are written as static USD - attributes. Per-frame parameters (exposure, color latents) are written - with a mean-based default value and one time sample per training frame - at time_code = float(frame_idx). + attributes. Per-frame parameters (exposure, color latents) are either: + - written with mean-based defaults plus per-frame time samples (when + ``controller_shader`` is None and ``fixed_frame_index`` is None), or + - read at runtime from the upstream controller shader when it is + provided (the dynamic SPG variant is selected automatically). Args: stage: USD stage containing the RenderProduct. @@ -321,26 +353,36 @@ def add_ppisp_shader_to_render_product( frame_indices: Global frame indices belonging to this camera. fixed_frame_index: If set, write this one PPISP frame state as static shader inputs instead of authoring animated time samples. + controller_shader: Optional upstream controller Shader whose + ``ControllerParams`` output supplies exposure / colour latents. Returns: The created PPISP Shader prim. """ assert camera_index < ppisp.num_cameras, f"camera_index {camera_index} >= ppisp.num_cameras {ppisp.num_cameras}" - if not frame_indices and fixed_frame_index is None: + if not frame_indices and fixed_frame_index is None and controller_shader is None: log.warning(f"No frames for camera {camera_index} at {render_product_path}, skipping") return stage.GetPseudoRoot() - shader = _create_shader_prim(stage, render_product_path) + shader = _create_shader_prim(stage, render_product_path, controller_shader=controller_shader) _set_vignetting_params(shader, ppisp, camera_index) _set_crf_params(shader, ppisp, camera_index) - if fixed_frame_index is None: + if controller_shader is not None: + # Exposure / colour latents are computed by the controller shader + # at runtime, so we don't author static or time-sampled values here. + pass + elif fixed_frame_index is None: _set_animated_exposure_params(shader, ppisp, frame_indices) _set_animated_color_params(shader, ppisp, frame_indices) else: _set_static_exposure_params(shader, ppisp, fixed_frame_index) _set_static_color_params(shader, ppisp, fixed_frame_index) - log.info(f"Added PPISP shader to {render_product_path} " f"(camera {camera_index}, {len(frame_indices)} frame(s))") + controller_suffix = ", controller" if controller_shader is not None else "" + log.info( + f"Added PPISP shader to {render_product_path} " + f"(camera {camera_index}, {len(frame_indices)} frame(s){controller_suffix})" + ) return shader.GetPrim() @@ -392,6 +434,7 @@ def add_ppisp_to_all_render_products( render_scope_path: str = "/Render", fixed_camera_index: int | None = None, fixed_frame_index: int | None = None, + use_controller: bool = False, ) -> List[Usd.Prim]: """Add PPISP shaders to every RenderProduct in the Render scope. @@ -406,11 +449,19 @@ def add_ppisp_to_all_render_products( RenderProduct instead of matching the RenderProduct camera. fixed_frame_index: If set, use this PPISP frame state as static shader inputs instead of authoring animated exposure/color samples. + use_controller: If True, author a per-camera PPISP controller shader + and wire its output into the PPISP shader, replacing the static / + time-sampled exposure & colour inputs. Requires the controller + sidecars to be packaged alongside the USD output. Returns: List of created PPISP Shader prims. """ from threedgrut.export.usd.writers.camera import _make_usd_prim_name + if use_controller: + from threedgrut.export.usd.writers.ppisp_controller_writer import ( + add_controller_shader_to_render_product, + ) render_scope = stage.GetPrimAtPath(render_scope_path) if not render_scope.IsValid(): @@ -445,6 +496,23 @@ def add_ppisp_to_all_render_products( frame_indices = camera_frame_mapping.get(camera_name, []) _create_ppisp_camera(stage, child) + controller_shader = None + if use_controller: + controllers = getattr(ppisp, "controllers", None) + if controllers is None or int(camera_index) >= len(controllers): + log.warning( + "PPISP controllers missing for camera %s (idx=%d); falling back to " + "static parameters for this RenderProduct.", + camera_name, int(camera_index), + ) + else: + controller_shader = add_controller_shader_to_render_product( + stage=stage, + render_product_path=str(child.GetPath()), + camera_index=int(camera_index), + controller=controllers[int(camera_index)], + ) + shader_prim = add_ppisp_shader_to_render_product( stage=stage, render_product_path=str(child.GetPath()), @@ -452,6 +520,7 @@ def add_ppisp_to_all_render_products( ppisp=ppisp, frame_indices=frame_indices, fixed_frame_index=fixed_frame_index, + controller_shader=controller_shader, ) created.append(shader_prim) diff --git a/tools/render_ppisp_spg/README.md b/tools/render_ppisp_spg/README.md new file mode 100644 index 00000000..16716f2b --- /dev/null +++ b/tools/render_ppisp_spg/README.md @@ -0,0 +1,50 @@ +# render_ppisp_spg + +Headless **slangpy** harness for the PPISP SPG sidecars. Lets you +validate the exported `*.slang` / `*.slang.lua` chain end-to-end without +booting Omniverse Kit. + +## What it does + +- Loads `ppisp_controller.slang` (and the dynamic / static + `ppisp_usd_spg*.slang`) directly from the on-disk SPG sidecar set. +- Strips the `[[vk::binding(*, *)]]` annotations that Kit's SPG layer + consumes (slangpy uses its own auto-binding) and dispatches the same + compute kernel. +- Reads time-sampled USD attributes off a PPISP-bearing + `RenderProduct` and walks frame-by-frame against an HDR input dir. + +## Three entry points + +| Function | Use | +| --- | --- | +| `run_controller(slang, hdr, weights, prior=0)` | Returns the 9-float controller output: `[exposureOffset, blue.xy, red.xy, green.xy, neutral.xy]`. | +| `run_ppisp_dyn(slang, hdr, ctrl_out, vignette, crf)` | Reads colour / exposure from a controller output texture and returns an LDR uint8 image. | +| `run_ppisp_static(slang, hdr, exposure, color_latents, vignette, crf)` | The legacy controller-free path; reads exposure / colour from explicit args. | + +## CLI + +``` +python tools/render_ppisp_spg/render_renderproduct.py \ + out.usdz hdr_inputs/ ldr_outputs/ +``` + +The HDR input layout is one folder per camera-name, with files named +`.{npy,exr,png}`. + +## Validation + +``` +python tools/render_ppisp_spg/validate_controller.py --tol 1e-4 +``` + +Generates a synthetic torch `_PPISPController`, bakes its weights via +`flatten_controller_weights`, dispatches the SPG controller shader, and +compares the 9-element result against the torch reference. Typical max +abs diff is around 4e-6. + +## Dependencies + +`slangpy`, `numpy`, `Pillow`, `usd-core`, and (only for +`validate_controller.py`) `torch`. `OpenEXR`/`Imath` are optional and +only loaded when an `.exr` HDR input is encountered. diff --git a/tools/render_ppisp_spg/__init__.py b/tools/render_ppisp_spg/__init__.py new file mode 100644 index 00000000..52a7a9da --- /dev/null +++ b/tools/render_ppisp_spg/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tools/render_ppisp_spg/render_renderproduct.py b/tools/render_ppisp_spg/render_renderproduct.py new file mode 100644 index 00000000..aabdc7ba --- /dev/null +++ b/tools/render_ppisp_spg/render_renderproduct.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Render a PPISP-bearing USD RenderProduct chain via slangpy. + +Given a USD/USDZ that was exported by ``threedgrut.export.usd.exporter`` +with PPISP Omniverse-native mode (and optionally the controller), and a +folder of HDR input images, this tool walks each ``/Render/`` +RenderProduct, finds its ``PPISP[+ Controller]`` Shader prims, resolves +their parameter values for every authored time sample, and dispatches +the matching ``.slang`` files via :mod:`tools.render_ppisp_spg.spg_runtime`. + +For a controllerless export the per-frame exposure / colour latents are +read off the time-sampled USD attributes. With a controller, the +``priorExposure`` value is read once and the controller shader is +dispatched per frame against the supplied HDR input. + +Required layout for the input HDR images: + + //.exr|.png|.npy + +Outputs are written to ``//.png``. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from dataclasses import asdict +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple + +import numpy as np +from PIL import Image + +from pxr import Sdf, Usd, UsdShade + +# Allow running as a script without installing the tool package. +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from tools.render_ppisp_spg.spg_runtime import ( # noqa: E402 + CrfParams, + VignetteParams, + run_controller, + run_ppisp_dyn, + run_ppisp_static, +) + +logger = logging.getLogger("render_ppisp_spg") + + +CHANNELS = ("R", "G", "B") + + +# --------------------------------------------------------------------------- +# USD parsing +# --------------------------------------------------------------------------- + + +def _resolve_attr_at_time(prim: Usd.Prim, attr_name: str, t: Usd.TimeCode): + attr = prim.GetAttribute(attr_name) + if attr is None or not attr.IsValid(): + return None + return attr.Get(t) + + +def _read_vignetting(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> VignetteParams: + p = VignetteParams() + + def _f(name: str, default: float) -> float: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + return float(v) if v is not None else default + + def _f2(name: str, default: Tuple[float, float]) -> Tuple[float, float]: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + if v is None: + return default + return (float(v[0]), float(v[1])) + + for ch in CHANNELS: + setattr(p, f"center_{ch.lower()}", _f2(f"vignettingCenter{ch}", (0.0, 0.0))) + setattr(p, f"alpha1_{ch.lower()}", _f(f"vignettingAlpha1{ch}", 0.0)) + setattr(p, f"alpha2_{ch.lower()}", _f(f"vignettingAlpha2{ch}", 0.0)) + setattr(p, f"alpha3_{ch.lower()}", _f(f"vignettingAlpha3{ch}", 0.0)) + return p + + +def _read_crf(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> CrfParams: + c = CrfParams() + + def _f(name: str, default: float) -> float: + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + return float(v) if v is not None else default + + for ch in CHANNELS: + chl = ch.lower() + setattr(c, f"toe_{chl}", _f(f"crfToe{ch}", getattr(c, f"toe_{chl}"))) + setattr(c, f"shoulder_{chl}", _f(f"crfShoulder{ch}", getattr(c, f"shoulder_{chl}"))) + setattr(c, f"gamma_{chl}", _f(f"crfGamma{ch}", getattr(c, f"gamma_{chl}"))) + setattr(c, f"center_{chl}", _f(f"crfCenter{ch}", getattr(c, f"center_{chl}"))) + return c + + +def _read_color_latents(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> List[float]: + out: List[float] = [] + for name in ("colorLatentBlue", "colorLatentRed", "colorLatentGreen", "colorLatentNeutral"): + v = _resolve_attr_at_time(ppisp_prim, f"inputs:{name}", t) + if v is None: + out.extend([0.0, 0.0]) + else: + out.extend([float(v[0]), float(v[1])]) + return out + + +def _read_exposure(ppisp_prim: Usd.Prim, t: Usd.TimeCode) -> float: + v = _resolve_attr_at_time(ppisp_prim, "inputs:exposureOffset", t) + return float(v) if v is not None else 0.0 + + +def _slang_asset_path(prim: Usd.Prim) -> Optional[str]: + attr = prim.GetAttribute("info:spg:sourceAsset") + if not attr or not attr.IsValid(): + return None + val = attr.Get() + if val is None: + return None + return val.path if hasattr(val, "path") else str(val) + + +def _find_render_products(stage: Usd.Stage) -> List[Usd.Prim]: + render_scope = stage.GetPrimAtPath("/Render") + if not render_scope.IsValid(): + return [] + return [c for c in render_scope.GetChildren() if c.GetTypeName() == "RenderProduct"] + + +def _find_ppisp_and_controller(rp: Usd.Prim) -> Tuple[Optional[Usd.Prim], Optional[Usd.Prim]]: + ppisp = None + controller = None + for child in rp.GetChildren(): + if child.GetName() == "PPISP": + ppisp = child + elif child.GetName().startswith("PPISPController"): + controller = child + return ppisp, controller + + +def _frame_indices_for_prim(prim: Usd.Prim) -> List[float]: + """Union of authored time samples over the animated PPISP attributes.""" + samples: set = set() + for attr_name in ( + "inputs:exposureOffset", + "inputs:colorLatentBlue", + "inputs:colorLatentRed", + "inputs:colorLatentGreen", + "inputs:colorLatentNeutral", + ): + attr = prim.GetAttribute(attr_name) + if attr and attr.IsValid(): + samples.update(attr.GetTimeSamples() or []) + return sorted(samples) + + +# --------------------------------------------------------------------------- +# HDR image I/O +# --------------------------------------------------------------------------- + + +def _load_hdr(path: Path) -> np.ndarray: + if path.suffix.lower() == ".npy": + arr = np.load(path) + return arr.astype(np.float32) + if path.suffix.lower() in (".png", ".jpg", ".jpeg"): + img = Image.open(path).convert("RGB") + return (np.asarray(img).astype(np.float32) / 255.0) + if path.suffix.lower() == ".exr": + try: + import OpenEXR # type: ignore[import-not-found] + import Imath # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError(f"OpenEXR/Imath required to read {path}: {e}") + f = OpenEXR.InputFile(str(path)) + dw = f.header()["dataWindow"] + w = dw.max.x - dw.min.x + 1 + h = dw.max.y - dw.min.y + 1 + pt = Imath.PixelType(Imath.PixelType.FLOAT) + r, g, b = (np.frombuffer(f.channel(c, pt), dtype=np.float32).reshape(h, w) + for c in ("R", "G", "B")) + return np.stack([r, g, b], axis=-1) + raise RuntimeError(f"unsupported HDR format: {path.suffix}") + + +def _save_png(out_path: Path, image_rgba: np.ndarray) -> None: + out_path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(image_rgba, mode="RGBA").save(out_path) + + +# --------------------------------------------------------------------------- +# Per-camera execution +# --------------------------------------------------------------------------- + + +def _process_render_product( + rp: Usd.Prim, + usd_dir: Path, + hdr_dir: Path, + out_dir: Path, + *, + frames: Optional[Iterable[int]] = None, +) -> None: + cam_name = rp.GetName() + ppisp_prim, controller_prim = _find_ppisp_and_controller(rp) + if ppisp_prim is None: + logger.warning("RenderProduct %s has no PPISP shader prim, skipping", cam_name) + return + + ppisp_slang = _slang_asset_path(ppisp_prim) + ctrl_slang = _slang_asset_path(controller_prim) if controller_prim is not None else None + if ppisp_slang is None: + logger.warning("RenderProduct %s PPISP shader has no info:spg:sourceAsset", cam_name) + return + + ppisp_slang_path = (usd_dir / ppisp_slang).resolve() + if not ppisp_slang_path.exists(): + logger.error("PPISP slang sidecar not found at %s", ppisp_slang_path) + return + ctrl_slang_path = None + if ctrl_slang is not None: + ctrl_slang_path = (usd_dir / ctrl_slang).resolve() + if not ctrl_slang_path.exists(): + logger.error("Controller slang sidecar not found at %s", ctrl_slang_path) + return + + hdr_cam_dir = hdr_dir / cam_name + if not hdr_cam_dir.exists(): + logger.warning("No HDR inputs for camera %s under %s, skipping", cam_name, hdr_dir) + return + + sample_times = _frame_indices_for_prim(ppisp_prim) + if not sample_times and controller_prim is not None: + # Controller-only path: time samples are encoded in the HDR folder names. + sample_times = sorted( + int(p.stem) for p in hdr_cam_dir.iterdir() if p.stem.isdigit() + ) + if frames is not None: + sample_times = [t for t in sample_times if int(t) in set(int(f) for f in frames)] + if not sample_times: + logger.warning("Camera %s has no frames to render", cam_name) + return + + logger.info("Rendering %s (%d frames%s)", + cam_name, len(sample_times), + " + controller" if ctrl_slang_path else "") + + for t in sample_times: + frame_index = int(t) + candidates = [ + hdr_cam_dir / f"{frame_index}.npy", + hdr_cam_dir / f"{frame_index}.exr", + hdr_cam_dir / f"{frame_index}.png", + ] + hdr_path = next((c for c in candidates if c.exists()), None) + if hdr_path is None: + logger.warning("No HDR input for %s frame %d", cam_name, frame_index) + continue + + hdr_image = _load_hdr(hdr_path) + timecode = Usd.TimeCode(float(t)) + vignette = _read_vignetting(ppisp_prim, timecode) + crf = _read_crf(ppisp_prim, timecode) + + if ctrl_slang_path is not None: + prior = _resolve_attr_at_time(controller_prim, "inputs:priorExposure", timecode) or 0.0 + weights_attr = controller_prim.GetAttribute("inputs:weights") + weights_val = weights_attr.Get(timecode) if weights_attr and weights_attr.IsValid() else None + if weights_val is None: + logger.error("Controller for %s has no inputs:weights value, skipping frame", cam_name) + continue + # USD's VtArray-backed ndarray comes back read-only / OWNDATA=False; + # slangpy.create_buffer rejects those, so force a writable copy. + weights = np.array(weights_val, dtype=np.float32, copy=True) + controller_out = run_controller(ctrl_slang_path, hdr_image, weights, + prior_exposure=float(prior)) + ldr = run_ppisp_dyn(ppisp_slang_path, hdr_image, controller_out, + vignette=vignette, crf=crf) + else: + exposure = _read_exposure(ppisp_prim, timecode) + color_latents = _read_color_latents(ppisp_prim, timecode) + ldr = run_ppisp_static(ppisp_slang_path, hdr_image, + exposure_offset=exposure, + color_latents=color_latents, + vignette=vignette, crf=crf) + + _save_png(out_dir / cam_name / f"{frame_index}.png", ldr) + logger.debug(" wrote %s/%s/%d.png", out_dir, cam_name, frame_index) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def _resolve_usd_dir(usd_path: Path) -> Path: + """Slang/usda asset paths are relative to the USD file. For ``.usdz`` we + extract a temporary copy because the SPG sidecars are stored inside the + archive and slangpy needs them on disk.""" + if usd_path.suffix.lower() != ".usdz": + return usd_path.parent + + import tempfile + import zipfile + + target = Path(tempfile.mkdtemp(prefix="ppisp_usdz_")) + with zipfile.ZipFile(usd_path) as zf: + zf.extractall(target) + logger.info("Extracted %s → %s", usd_path, target) + return target + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("usd", type=Path, help="USD or USDZ file from the PPISP exporter") + parser.add_argument("hdr_dir", type=Path, + help="Directory of HDR inputs, organised as /.{npy,exr,png}") + parser.add_argument("out_dir", type=Path, + help="Where to write LDR PNG outputs") + parser.add_argument("--cameras", nargs="*", default=None, + help="Optional list of camera (RenderProduct) names to render") + parser.add_argument("--frames", nargs="*", type=int, default=None, + help="Optional list of frame indices to render") + parser.add_argument("--verbose", "-v", action="count", default=0, + help="Increase logging verbosity") + args = parser.parse_args(argv) + + logging.basicConfig(level=logging.WARNING - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not args.usd.exists(): + logger.error("USD not found: %s", args.usd) + return 2 + + usd_dir = _resolve_usd_dir(args.usd) + if args.usd.suffix.lower() == ".usdz": + # Find the actual default scene file inside the extracted dir. + default_scene = next((p for p in usd_dir.glob("*.usd*") if p.suffix in (".usd", ".usda", ".usdc")), + None) + if default_scene is None: + logger.error("No top-level usd/usda/usdc inside %s", args.usd) + return 2 + usd_path = default_scene + else: + usd_path = args.usd + + stage = Usd.Stage.Open(str(usd_path)) + if stage is None: + logger.error("Failed to open USD stage at %s", usd_path) + return 2 + + products = _find_render_products(stage) + if not products: + logger.error("No RenderProducts found under /Render") + return 1 + + target_names = set(args.cameras) if args.cameras else None + for rp in products: + if target_names is not None and rp.GetName() not in target_names: + continue + _process_render_product( + rp, + usd_dir=usd_dir, + hdr_dir=args.hdr_dir, + out_dir=args.out_dir, + frames=args.frames, + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/spg_runtime.py b/tools/render_ppisp_spg/spg_runtime.py new file mode 100644 index 00000000..027ea6ab --- /dev/null +++ b/tools/render_ppisp_spg/spg_runtime.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +""" +Headless slangpy runtime for the PPISP SPG shader chain. + +This is *not* a full Kit SPG simulator. It executes the compute stages of +the PPISP SPG sidecars directly against a supplied HDR image so the +exported asset can be validated end-to-end without booting Omniverse. + +The harness uses slangpy's low-level pipeline API: load a Slang module, +create a compute pipeline from a chosen entry point, and dispatch with +resources bound through a ``ShaderCursor`` over the root +``ShaderObject``. This matches how SPG itself binds the same shaders. + +Three entry points are available: + +- :func:`run_controller` — ``ppisp_controller_.slang`` → + 9-element ``[exposureOffset, blue.xy, red.xy, green.xy, neutral.xy]``. +- :func:`run_ppisp_dyn` — ``ppisp_usd_spg_dyn.slang``, takes the + controller output texture; returns an HxWx4 uint8 LDR image. +- :func:`run_ppisp_static` — ``ppisp_usd_spg.slang`` (no controller). +""" + +from __future__ import annotations + +import dataclasses +import logging +import math +from pathlib import Path +from typing import Sequence, Tuple + +import numpy as np +import slangpy as spy + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class VignetteParams: + """Per-camera vignetting parameters in shader storage order.""" + center_r: Tuple[float, float] = (0.0, 0.0) + alpha1_r: float = 0.0 + alpha2_r: float = 0.0 + alpha3_r: float = 0.0 + center_g: Tuple[float, float] = (0.0, 0.0) + alpha1_g: float = 0.0 + alpha2_g: float = 0.0 + alpha3_g: float = 0.0 + center_b: Tuple[float, float] = (0.0, 0.0) + alpha1_b: float = 0.0 + alpha2_b: float = 0.0 + alpha3_b: float = 0.0 + + +@dataclasses.dataclass +class CrfParams: + """Per-camera per-channel toe/shoulder/gamma/center raw parameters.""" + toe_r: float = 0.013659 + shoulder_r: float = 0.013659 + gamma_r: float = 0.378165 + center_r: float = 0.0 + toe_g: float = 0.013659 + shoulder_g: float = 0.013659 + gamma_g: float = 0.378165 + center_g: float = 0.0 + toe_b: float = 0.013659 + shoulder_b: float = 0.013659 + gamma_b: float = 0.378165 + center_b: float = 0.0 + + +# --------------------------------------------------------------------------- +# Device + pipeline helpers +# --------------------------------------------------------------------------- + + +def _make_device(slang_dir: Path) -> spy.Device: + return spy.create_device(include_paths=[str(slang_dir)]) + + +_VK_BINDING_RE = __import__("re").compile(r"\[\[vk::binding\([^\]]+\)\]\]\s*") + + +def _build_pipeline(device: spy.Device, slang_path: Path, entry_point_name: str): + """Compile a Slang file and return its compute pipeline. + + The PPISP SPG shaders carry ``[[vk::binding(slot, set)]]`` annotations + that match Kit's SPG descriptor layout. Slangpy uses its own automatic + binding scheme, and the explicit annotations make resource binding + silently miss (the dispatch runs but reads zeroed buffers). We strip + the annotations *for slangpy dispatch only*; the on-disk slang file + used by SPG keeps them. + """ + session = device.slang_session + src = _VK_BINDING_RE.sub("", slang_path.read_text()) + module = session.load_module_from_source( + slang_path.stem, + src, + path=str(slang_path), + ) + entry_point = module.entry_point(entry_point_name) + program = session.link_program([module], [entry_point]) + pipeline = device.create_compute_pipeline(program) + return pipeline, program + + +def _create_hdr_input_texture(device: spy.Device, hdr_image: np.ndarray) -> spy.Texture: + if hdr_image.ndim != 3 or hdr_image.shape[2] not in (3, 4): + raise ValueError(f"hdr_image must be HxWx3 or HxWx4, got shape {hdr_image.shape}") + if hdr_image.dtype != np.float32: + hdr_image = hdr_image.astype(np.float32, copy=False) + h, w, c = hdr_image.shape + if c == 3: + rgba = np.empty((h, w, 4), dtype=np.float32) + rgba[..., :3] = hdr_image + rgba[..., 3] = 1.0 + hdr_image = rgba + return device.create_texture( + width=w, + height=h, + format=spy.Format.rgba32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + data=np.ascontiguousarray(hdr_image), + ) + + +def _create_controller_texture(device: spy.Device, values: np.ndarray) -> spy.Texture: + flat = np.asarray(values, dtype=np.float32).reshape(-1) + if flat.size != 9: + raise ValueError(f"controller values must be 9 floats, got {flat.size}") + # 9x1 single-channel float texture, indexed at (0..8, 0). + return device.create_texture( + width=9, + height=1, + format=spy.Format.r32_float, + usage=spy.TextureUsage.shader_resource, + data=np.ascontiguousarray(flat.reshape(1, 9)), + ) + + +def _read_r32f_row(tex: spy.Texture) -> np.ndarray: + """Read back a 1-row r32_float texture as a flat float32 numpy array.""" + arr = tex.to_numpy() + return np.asarray(arr, dtype=np.float32).reshape(-1) + + +def _read_rgba8(tex: spy.Texture, h: int, w: int) -> np.ndarray: + arr = tex.to_numpy() + return np.asarray(arr, dtype=np.uint8).reshape(h, w, 4) + + +# --------------------------------------------------------------------------- +# Cursor binding helpers +# --------------------------------------------------------------------------- + + +def _set_param_block(cursor: spy.ShaderCursor, block_name: str, fields: dict) -> None: + """Populate a slang ParameterBlock by field name. The cursor we get + from the root object is itself name-addressable, so ``cursor[name]`` + walks into the parameter block automatically.""" + block = cursor[block_name] + for k, v in fields.items(): + block[k] = v + + +def _ceildiv(a: int, b: int) -> int: + return (a + b - 1) // b + + +# --------------------------------------------------------------------------- +# Controller dispatch +# --------------------------------------------------------------------------- + + +def run_controller( + slang_path: str | Path, + hdr_image: np.ndarray, + weights: np.ndarray, + prior_exposure: float = 0.0, + *, + device: spy.Device | None = None, +) -> np.ndarray: + """Dispatch the PPISP controller shader and return its 9 outputs. + + ``weights`` must be a flat float32 buffer matching the layout encoded + in ``ppisp_controller.slang`` (see + :data:`threedgrut.export.usd.writers.ppisp_controller_writer.EXPECTED_WEIGHTS_LEN`). + """ + slang_path = Path(slang_path) + if device is None: + device = _make_device(slang_path.parent) + + pipeline, _ = _build_pipeline(device, slang_path, "controllerProcess") + in_tex = _create_hdr_input_texture(device, hdr_image) + out_tex = device.create_texture( + width=9, + height=1, + format=spy.Format.r32_float, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + ) + flat_weights = np.ascontiguousarray(weights.astype(np.float32, copy=False).reshape(-1)) + weights_buf = device.create_buffer( + element_count=int(flat_weights.size), + struct_size=4, + usage=spy.BufferUsage.shader_resource, + data=flat_weights, + ) + + encoder = device.create_command_encoder() + with encoder.begin_compute_pass() as cp: + shader_obj = cp.bind_pipeline(pipeline) + cur = spy.ShaderCursor(shader_obj) + _set_param_block(cur, "g_Params", {"priorExposure": float(prior_exposure)}) + cur["g_InTex"] = in_tex + cur["g_Weights"] = weights_buf + cur["g_OutTex"] = out_tex + cp.dispatch(spy.math.uint3(32, 1, 1)) + device.submit_command_buffer(encoder.finish()) + device.wait() + + return _read_r32f_row(out_tex)[:9] + + +# --------------------------------------------------------------------------- +# PPISP dispatches +# --------------------------------------------------------------------------- + + +def _vignette_dict(v: VignetteParams) -> dict: + return { + "vignettingCenterR": list(v.center_r), + "vignettingAlpha1R": v.alpha1_r, + "vignettingAlpha2R": v.alpha2_r, + "vignettingAlpha3R": v.alpha3_r, + "vignettingCenterG": list(v.center_g), + "vignettingAlpha1G": v.alpha1_g, + "vignettingAlpha2G": v.alpha2_g, + "vignettingAlpha3G": v.alpha3_g, + "vignettingCenterB": list(v.center_b), + "vignettingAlpha1B": v.alpha1_b, + "vignettingAlpha2B": v.alpha2_b, + "vignettingAlpha3B": v.alpha3_b, + } + + +def _crf_dict(c: CrfParams) -> dict: + return { + "crfToeR": c.toe_r, + "crfShoulderR": c.shoulder_r, + "crfGammaR": c.gamma_r, + "crfCenterR": c.center_r, + "crfToeG": c.toe_g, + "crfShoulderG": c.shoulder_g, + "crfGammaG": c.gamma_g, + "crfCenterG": c.center_g, + "crfToeB": c.toe_b, + "crfShoulderB": c.shoulder_b, + "crfGammaB": c.gamma_b, + "crfCenterB": c.center_b, + } + + +def run_ppisp_dyn( + slang_path: str | Path, + hdr_image: np.ndarray, + controller_output: np.ndarray, + vignette: VignetteParams, + crf: CrfParams, + *, + device: spy.Device | None = None, +) -> np.ndarray: + """Run ``ppisp_usd_spg_dyn.slang`` and return an HxWx4 uint8 LDR image.""" + slang_path = Path(slang_path) + if device is None: + device = _make_device(slang_path.parent) + + pipeline, _ = _build_pipeline(device, slang_path, "ppispProcessDyn") + h, w = hdr_image.shape[:2] + + in_tex = _create_hdr_input_texture(device, hdr_image) + ctrl_tex = _create_controller_texture(device, controller_output) + out_tex = device.create_texture( + width=w, + height=h, + format=spy.Format.rgba8_unorm, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + ) + + encoder = device.create_command_encoder() + with encoder.begin_compute_pass() as cp: + shader_obj = cp.bind_pipeline(pipeline) + cur = spy.ShaderCursor(shader_obj) + _set_param_block(cur, "g_Params", + {**_vignette_dict(vignette), **_crf_dict(crf)}) + cur["g_InTex"] = in_tex + cur["g_ControllerOut"] = ctrl_tex + cur["g_OutTex"] = out_tex + cp.dispatch(spy.math.uint3(_ceildiv(w, 16) * 16, + _ceildiv(h, 16) * 16, 1)) + device.submit_command_buffer(encoder.finish()) + device.wait() + + return _read_rgba8(out_tex, h, w) + + +def run_ppisp_static( + slang_path: str | Path, + hdr_image: np.ndarray, + exposure_offset: float, + color_latents: Sequence[float], + vignette: VignetteParams, + crf: CrfParams, + *, + device: spy.Device | None = None, +) -> np.ndarray: + """Run ``ppisp_usd_spg.slang`` (no controller) and return an LDR uint8 image.""" + slang_path = Path(slang_path) + if len(color_latents) != 8: + raise ValueError(f"color_latents must have 8 entries, got {len(color_latents)}") + if device is None: + device = _make_device(slang_path.parent) + + pipeline, _ = _build_pipeline(device, slang_path, "ppispProcess") + + h, w = hdr_image.shape[:2] + in_tex = _create_hdr_input_texture(device, hdr_image) + out_tex = device.create_texture( + width=w, + height=h, + format=spy.Format.rgba8_unorm, + usage=spy.TextureUsage.shader_resource | spy.TextureUsage.unordered_access, + ) + + fields = { + "exposureOffset": float(exposure_offset), + "colorLatentBlue": [float(color_latents[0]), float(color_latents[1])], + "colorLatentRed": [float(color_latents[2]), float(color_latents[3])], + "colorLatentGreen": [float(color_latents[4]), float(color_latents[5])], + "colorLatentNeutral": [float(color_latents[6]), float(color_latents[7])], + **_vignette_dict(vignette), + **_crf_dict(crf), + } + encoder = device.create_command_encoder() + with encoder.begin_compute_pass() as cp: + shader_obj = cp.bind_pipeline(pipeline) + cur = spy.ShaderCursor(shader_obj) + _set_param_block(cur, "g_Params", fields) + cur["g_InTex"] = in_tex + cur["g_OutTex"] = out_tex + cp.dispatch(spy.math.uint3(_ceildiv(w, 16) * 16, + _ceildiv(h, 16) * 16, 1)) + device.submit_command_buffer(encoder.finish()) + device.wait() + + return _read_rgba8(out_tex, h, w) diff --git a/tools/render_ppisp_spg/validate_controller.py b/tools/render_ppisp_spg/validate_controller.py new file mode 100644 index 00000000..7d3554ff --- /dev/null +++ b/tools/render_ppisp_spg/validate_controller.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Numerical sanity check: generate a controller slang for a torch +``_PPISPController`` with known weights, dispatch it via slangpy, and +compare its 9-element output to the PyTorch forward pass. + +This script does not require the full 3DGRUT environment — only +``torch``, ``numpy``, ``slangpy`` and the in-repo writer module. It +fabricates a controller (without needing a ``ppisp.PPISP`` checkpoint) +by reproducing ``ppisp._PPISPController`` from the public +architecture description. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +# Import the writer directly (bypass threedgrut.__init__ which depends on +# heavy CUDA-only packages we don't need here). +import importlib.util as _ilu # noqa: E402 +import types as _types # noqa: E402 + +# Stub the threedgrut packages so we don't trigger their __init__. +for _pkg in ( + "threedgrut", + "threedgrut.export", + "threedgrut.export.usd", + "threedgrut.export.usd.writers", +): + if _pkg not in sys.modules: + sys.modules[_pkg] = _types.ModuleType(_pkg) + +# Stub stage_utils so the writer can import NamedSerialized. +_stage_utils_stub = _types.ModuleType("threedgrut.export.usd.stage_utils") +import dataclasses as _dc + + +@_dc.dataclass +class _NamedSerialized: + filename: str + serialized: bytes + + +_stage_utils_stub.NamedSerialized = _NamedSerialized +sys.modules["threedgrut.export.usd.stage_utils"] = _stage_utils_stub + +# Stub ppisp_spg package so get_controller_sidecars() can resolve _SPG_DIR. +_ppisp_spg_stub = _types.ModuleType("threedgrut.export.usd.ppisp_spg") +_ppisp_spg_stub._SPG_DIR = ( + Path(__file__).resolve().parents[2] / "threedgrut/export/usd/ppisp_spg" +) +sys.modules["threedgrut.export.usd.ppisp_spg"] = _ppisp_spg_stub + +_writer_path = ( + Path(__file__).resolve().parents[2] + / "threedgrut/export/usd/writers/ppisp_controller_writer.py" +) +_spec = _ilu.spec_from_file_location( + "threedgrut.export.usd.writers.ppisp_controller_writer", str(_writer_path) +) +_writer_mod = _ilu.module_from_spec(_spec) +sys.modules["threedgrut.export.usd.writers.ppisp_controller_writer"] = _writer_mod +_spec.loader.exec_module(_writer_mod) +EXPECTED_SIZES = _writer_mod.EXPECTED_SIZES +flatten_controller_weights = _writer_mod.flatten_controller_weights + +from tools.render_ppisp_spg.spg_runtime import run_controller # noqa: E402 + + +logger = logging.getLogger("validate_controller") + + +def _make_test_controller(seed: int = 0): + """Build a torch module with the same architecture as + ``ppisp._PPISPController``. Importing the real one is preferred but + we duplicate it here so the validator runs without the ppisp package.""" + import torch + from torch import nn + + class _Controller(nn.Module): + def __init__(self): + super().__init__() + cfd = EXPECTED_SIZES["cnn_feature_dim"] + grid = (EXPECTED_SIZES["pool_grid_h"], EXPECTED_SIZES["pool_grid_w"]) + self.cnn_encoder = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=1), + nn.MaxPool2d(EXPECTED_SIZES["input_downsampling"], + stride=EXPECTED_SIZES["input_downsampling"]), + nn.ReLU(inplace=True), + nn.Conv2d(16, 32, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(32, cfd, kernel_size=1), + nn.AdaptiveAvgPool2d(grid), + nn.Flatten(), + ) + in_dim = cfd * grid[0] * grid[1] + 1 + hd = EXPECTED_SIZES["mlp_hidden_dim"] + self.mlp_trunk = nn.Sequential( + nn.Linear(in_dim, hd), nn.ReLU(inplace=True), + nn.Linear(hd, hd), nn.ReLU(inplace=True), + nn.Linear(hd, hd), nn.ReLU(inplace=True), + ) + self.exposure_head = nn.Linear(hd, 1) + self.color_head = nn.Linear(hd, EXPECTED_SIZES["color_params_per_frame"]) + + def forward(self, rgb: torch.Tensor, prior_exposure: torch.Tensor): + features = self.cnn_encoder(rgb.permute(2, 0, 1).unsqueeze(0).detach()) + features = torch.cat([features.squeeze(0), prior_exposure], dim=0) + hidden = self.mlp_trunk(features) + return self.exposure_head(hidden).squeeze(-1), self.color_head(hidden) + + torch.manual_seed(seed) + ctrl = _Controller().eval() + # Mostly-zero weights with a tiny perturbation so outputs are non-trivial. + with torch.no_grad(): + for p in ctrl.parameters(): + p.normal_(0.0, 0.01) + return ctrl + + +def _torch_reference(ctrl, hdr_image: np.ndarray, prior_exposure: float) -> np.ndarray: + import torch + rgb = torch.from_numpy(hdr_image).float() + pe = torch.tensor([prior_exposure], dtype=torch.float32) + with torch.no_grad(): + exposure, color = ctrl(rgb, pe) + return np.concatenate([ + np.array([float(exposure)], dtype=np.float32), + color.cpu().numpy().astype(np.float32), + ]) + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--width", type=int, default=64) + parser.add_argument("--height", type=int, default=48) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--prior", type=float, default=0.25) + parser.add_argument("--tol", type=float, default=1.0e-3, + help="abs tol per output element") + parser.add_argument("--keep", type=Path, default=None, + help="Where to write the generated slang file (defaults to a tmp dir)") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.WARNING - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + ctrl = _make_test_controller(args.seed) + rng = np.random.default_rng(args.seed) + hdr = (rng.random((args.height, args.width, 3), dtype=np.float32) * 0.8 + 0.1) + + expected = _torch_reference(ctrl, hdr, args.prior) + + weights = flatten_controller_weights(ctrl) + slang_path = Path(__file__).resolve().parents[2] / ( + "threedgrut/export/usd/ppisp_spg/ppisp_controller.slang" + ) + actual = run_controller(slang_path, hdr, weights, prior_exposure=args.prior) + diff = np.abs(actual - expected) + + print(f"reference: {expected}") + print(f"slangpy: {actual}") + print(f"abs diff: {diff}") + print(f"max abs diff: {diff.max():.6g} (tol={args.tol})") + + return 0 if diff.max() <= args.tol else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/validate_e2e.py b/tools/render_ppisp_spg/validate_e2e.py new file mode 100644 index 00000000..4f6e4b33 --- /dev/null +++ b/tools/render_ppisp_spg/validate_e2e.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end PPISP SPG export + render validation. + +The pipeline this script exercises: + +1. Build a ``ppisp.PPISP`` module with non-trivial random weights and a + handful of synthetic HDR frames. +2. Author a USD stage with one ``RenderProduct`` per camera, attach the + PPISP shader chain (controller + dynamic PPISP) using the in-repo + writer, and save the USD plus the SPG sidecars to disk. +3. Run the slangpy CLI (`render_renderproduct.py`) against the saved USD + and the synthetic HDR frames to produce LDR PNGs through the slang + shaders. +4. Apply the same PPISP module *in PyTorch* to the same HDR frames, save + them as the reference LDR PNGs. +5. Compare slangpy vs PyTorch images per-frame; report PSNR / max abs + diff. Pass / fail on a configurable PSNR threshold. + +The "training" step is replaced with a perturbed PPISP module because +the validation question is "does the SPG asset reproduce the in-process +PPISP forward pass for these (camera, frame) pairs", not "is the trained +model good". A real trained checkpoint would give the same answer +because the path through both runtimes is identical. +""" + +from __future__ import annotations + +import argparse +import logging +import math +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +from PIL import Image + +import torch +from pxr import Gf, Sdf, Usd, UsdGeom + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.writers.ppisp_writer import ( # noqa: E402 + add_ppisp_to_all_render_products, +) +from threedgrut.export.usd.writers.ppisp_controller_writer import ( # noqa: E402 + get_controller_sidecars, +) +from threedgrut.export.usd.ppisp_spg import ( # noqa: E402 + get_ppisp_spg_dyn_files, +) +from ppisp import PPISP, DEFAULT_PPISP_CONFIG # noqa: E402 + +logger = logging.getLogger("validate_e2e") + + +def _make_perturbed_ppisp(num_cameras: int, num_frames: int, seed: int) -> PPISP: + """Build a PPISP module with non-trivial parameters for every stage.""" + torch.manual_seed(seed) + cfg = DEFAULT_PPISP_CONFIG + ppisp = PPISP(num_cameras=num_cameras, num_frames=num_frames, config=cfg).eval() + with torch.no_grad(): + ppisp.exposure_params.normal_(mean=0.0, std=0.5) + ppisp.color_params.normal_(mean=0.0, std=0.05) + ppisp.vignetting_params.normal_(mean=0.0, std=0.02) + # Keep CRF near identity so the comparison isn't dominated by huge + # nonlinearities; the math is identical between paths regardless. + ppisp.crf_params.add_(torch.randn_like(ppisp.crf_params) * 0.05) + # Perturb every controller's weights so the per-frame override has + # work to do during the dynamic-PPISP path. + for controller in ppisp.controllers: + for p in controller.parameters(): + p.normal_(mean=0.0, std=0.01) + return ppisp + + +def _build_render_product(stage: Usd.Stage, cam_name: str, width: int, height: int) -> Usd.Prim: + rp_path = f"/Render/{cam_name}" + rp = stage.DefinePrim(rp_path, "RenderProduct") + rp.CreateAttribute("resolution", Sdf.ValueTypeNames.Int2).Set(Gf.Vec2i(width, height)) + cam_prim = stage.DefinePrim(f"/World/Cameras/{cam_name}", "Camera") + rp.CreateRelationship("camera").SetTargets([cam_prim.GetPath()]) + hdr = stage.DefinePrim(f"{rp_path}/HdrColor", "RenderVar") + hdr.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set("HdrColor") + hdr.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) + rp.CreateRelationship("orderedVars").SetTargets([Sdf.Path("HdrColor")]) + return rp + + +class _SyntheticDataset: + """Minimal stub matching what build_camera_frame_mapping reads.""" + + def __init__(self, frame_to_camera: List[int], camera_names: List[str]): + self._f2c = list(frame_to_camera) + self._names = list(camera_names) + + def __len__(self) -> int: + return len(self._f2c) + + def get_camera_names(self) -> List[str]: + return list(self._names) + + def get_camera_idx(self, frame_idx: int) -> int: + return int(self._f2c[frame_idx]) + + +def _torch_reference_ldr( + ppisp: PPISP, hdr_image: np.ndarray, camera_idx: int, frame_idx: int +) -> np.ndarray: + """Apply PPISP in PyTorch with the *same* (camera, frame) state the + slang controller path will see at runtime: the controller predicts + exposure / color from the HDR image, while vignetting and CRF use + the per-camera parameters.""" + h, w = hdr_image.shape[:2] + rgb = torch.from_numpy(hdr_image).float() + # Pixel coords like the in-process renderer: integer (x, y). + yy, xx = torch.meshgrid( + torch.arange(h, dtype=torch.float32), + torch.arange(w, dtype=torch.float32), + indexing="ij", + ) + pixel_coords = torch.stack([xx, yy], dim=-1) # [H, W, 2] + + # We want the same path as the slang shader: controller predicts the + # frame state, PPISP applies it. PPISP.forward picks the controller + # path when frame_idx == -1 (novel-view). Pass -1 here so the torch + # reference exercises the controller, matching the slang path. + ppisp_eval = ppisp.eval().to("cuda") + rgb_cuda = rgb.to("cuda") + pixel_coords_cuda = pixel_coords.to("cuda") + with torch.no_grad(): + out = ppisp_eval( + rgb_cuda, + pixel_coords_cuda, + resolution=(w, h), + camera_idx=camera_idx, + frame_idx=-1, + ) + out = out.detach().cpu().numpy() + # The PPISP CUDA kernel saturates internally; convert to uint8 like the + # slang shader does (`saturate(rgb)` -> rgba8_unorm). + ldr = (np.clip(out, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) + rgba = np.empty((h, w, 4), dtype=np.uint8) + rgba[..., :3] = ldr + rgba[..., 3] = 255 + return rgba + + +def _save_png(path: Path, image_rgba: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(image_rgba, mode="RGBA").save(path) + + +def _save_npy_hdr(path: Path, hdr: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + np.save(path, hdr.astype(np.float32)) + + +def _psnr(a: np.ndarray, b: np.ndarray) -> float: + diff = a.astype(np.float32) - b.astype(np.float32) + mse = float((diff * diff).mean()) + if mse <= 0: + return float("inf") + return 20.0 * math.log10(255.0 / math.sqrt(mse)) + + +def _author_stage( + out_dir: Path, + ppisp: PPISP, + cam_names: List[str], + frame_to_camera: List[int], + resolutions: Dict[str, Tuple[int, int]], +) -> Path: + """Build and save the USD stage + ship the SPG sidecars to ``out_dir``.""" + out_dir.mkdir(parents=True, exist_ok=True) + stage = Usd.Stage.CreateNew(str(out_dir / "scene.usda")) + stage.SetMetadata("upAxis", UsdGeom.Tokens.y) + stage.DefinePrim("/World", "Xform") + stage.DefinePrim("/Render", "Scope") + for cam_name, (w, h) in resolutions.items(): + _build_render_product(stage, cam_name, w, h) + + dataset = _SyntheticDataset(frame_to_camera, cam_names) + from threedgrut.export.usd.writers.ppisp_writer import build_camera_frame_mapping + cam_names_built, mapping = build_camera_frame_mapping(dataset) + + add_ppisp_to_all_render_products( + stage=stage, + ppisp=ppisp, + camera_names=cam_names_built, + camera_frame_mapping=mapping, + use_controller=True, + ) + stage.GetRootLayer().Save() + + # Sidecars: shared dyn PPISP + shared controller files. + for s in get_ppisp_spg_dyn_files(): + (out_dir / s.filename).write_bytes(s.serialized) + for s in get_controller_sidecars(): + (out_dir / s.filename).write_bytes(s.serialized) + logger.info("Authored stage at %s with %d sidecars", + out_dir, len(list(out_dir.glob("*.slang*")))) + return out_dir / "scene.usda" + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--num-cameras", type=int, default=2) + parser.add_argument("--frames-per-camera", type=int, default=2) + parser.add_argument("--width", type=int, default=128) + parser.add_argument("--height", type=int, default=96) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--psnr-threshold", type=float, default=40.0, + help="Per-frame minimum PSNR (dB) for pass.") + parser.add_argument("--keep", type=Path, default=None, + help="Keep working dir at this path instead of a tmpdir.") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("PPISP forward requires CUDA.") + + work_dir = args.keep + cleanup = work_dir is None + if work_dir is None: + work_dir = Path(tempfile.mkdtemp(prefix="ppisp_e2e_")) + work_dir.mkdir(parents=True, exist_ok=True) + usd_dir = work_dir / "usd" + hdr_dir = work_dir / "hdr" + ref_dir = work_dir / "reference" + slang_dir = work_dir / "slangpy" + + try: + # ---------------------------------------------------------------- + # 1. Build a non-trivial PPISP and a synthetic frame plan. + # ---------------------------------------------------------------- + cam_names = [f"cam_{i}" for i in range(args.num_cameras)] + frame_to_camera: List[int] = [] + for cam_idx in range(args.num_cameras): + frame_to_camera.extend([cam_idx] * args.frames_per_camera) + num_frames = len(frame_to_camera) + resolutions = {n: (args.width, args.height) for n in cam_names} + ppisp = _make_perturbed_ppisp(args.num_cameras, num_frames, seed=args.seed) + + # ---------------------------------------------------------------- + # 2. Synthesise HDR inputs and the PyTorch reference LDR images. + # ---------------------------------------------------------------- + rng = np.random.default_rng(args.seed) + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + # Smooth HDR with a few high-frequency components so the + # controller and the vignetting see real spatial variation. + yy, xx = np.mgrid[0:args.height, 0:args.width].astype(np.float32) + base = 0.4 + 0.4 * rng.random((3,), dtype=np.float32) + hdr = ( + base[None, None, :] + + 0.15 * np.cos((xx / args.width * 4 + frame_idx) * 2 * np.pi)[..., None] + + 0.15 * np.sin((yy / args.height * 4 + cam_idx) * 2 * np.pi)[..., None] + ).astype(np.float32) + hdr += rng.normal(scale=0.02, size=hdr.shape).astype(np.float32) + hdr = np.clip(hdr, 0.0, 1.5) + _save_npy_hdr(hdr_dir / cam_name / f"{frame_idx}.npy", hdr) + + ref = _torch_reference_ldr(ppisp, hdr, cam_idx, frame_idx) + _save_png(ref_dir / cam_name / f"{frame_idx}.png", ref) + + # ---------------------------------------------------------------- + # 3. Author the USD stage + sidecars on disk. + # ---------------------------------------------------------------- + usd_path = _author_stage( + usd_dir, ppisp, cam_names, frame_to_camera, resolutions + ) + + # ---------------------------------------------------------------- + # 4. Run the slangpy CLI against the saved USD. + # ---------------------------------------------------------------- + cli = Path(__file__).resolve().parent / "render_renderproduct.py" + cmd = [ + sys.executable, str(cli), + str(usd_path), str(hdr_dir), str(slang_dir), + "-vv", + ] + logger.info("Running slangpy CLI: %s", " ".join(cmd)) + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + print(proc.stdout) + print(proc.stderr, file=sys.stderr) + raise SystemExit(f"render_renderproduct.py failed (exit {proc.returncode})") + + # ---------------------------------------------------------------- + # 5. Compare per-frame reference vs slangpy outputs. + # ---------------------------------------------------------------- + worst_psnr = float("inf") + worst_pair = None + all_pass = True + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + ref_img = np.asarray(Image.open(ref_dir / cam_name / f"{frame_idx}.png").convert("RGBA")) + sl_path = slang_dir / cam_name / f"{frame_idx}.png" + if not sl_path.exists(): + logger.error("slangpy output missing: %s", sl_path) + all_pass = False + continue + sl_img = np.asarray(Image.open(sl_path).convert("RGBA")) + psnr = _psnr(ref_img[..., :3], sl_img[..., :3]) + max_abs = int(np.max(np.abs(ref_img[..., :3].astype(int) - sl_img[..., :3].astype(int)))) + ok = psnr >= args.psnr_threshold + print(f" cam={cam_name} frame={frame_idx} " + f"PSNR={psnr:7.3f} dB max|Δ|={max_abs} {'OK' if ok else 'FAIL'}") + if psnr < worst_psnr: + worst_psnr = psnr + worst_pair = (cam_name, frame_idx) + if not ok: + all_pass = False + + print() + print(f"worst frame: {worst_pair} at {worst_psnr:.3f} dB " + f"(threshold {args.psnr_threshold} dB)") + return 0 if all_pass else 1 + + finally: + if cleanup: + shutil.rmtree(work_dir, ignore_errors=True) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/validate_real_ppisp.py b/tools/render_ppisp_spg/validate_real_ppisp.py new file mode 100644 index 00000000..9fa9d966 --- /dev/null +++ b/tools/render_ppisp_spg/validate_real_ppisp.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end validator using the *real* ``ppisp`` package. + +Builds an actual ``ppisp.PPISP`` module (one camera, one frame, default +``PPISPConfig``), runs its controller through both PyTorch and the +slangpy SPG harness, and reports the per-output abs diff. Run this after +``install_env_uv.sh`` so the full env including ``ppisp`` is available. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np + +# Make the in-repo writer importable. The module path goes through +# threedgrut.export.usd.writers.ppisp_controller_writer; that import +# chain pulls heavy CUDA pieces from threedgrut/__init__.py which exist +# in the real env, so we just rely on regular imports here. +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.writers.ppisp_controller_writer import ( # noqa: E402 + flatten_controller_weights, +) +from tools.render_ppisp_spg.spg_runtime import run_controller # noqa: E402 + +logger = logging.getLogger("validate_real_ppisp") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--width", type=int, default=64) + parser.add_argument("--height", type=int, default=48) + parser.add_argument("--prior", type=float, default=0.0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--tol", type=float, default=1.0e-4) + parser.add_argument("--num-cameras", type=int, default=1) + parser.add_argument("--num-frames", type=int, default=1) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + import torch + from ppisp import PPISP, DEFAULT_PPISP_CONFIG + + torch.manual_seed(args.seed) + rng = np.random.default_rng(args.seed) + + ppisp = PPISP(num_cameras=args.num_cameras, num_frames=args.num_frames, + config=DEFAULT_PPISP_CONFIG).eval() + if not ppisp.controllers or len(ppisp.controllers) == 0: + raise SystemExit("PPISP has no controllers — config.use_controller must be True.") + controller = ppisp.controllers[0] + + # Perturb the controller so the output is non-trivial (PPISP initialises + # everything to zero; without weights, the slang/torch outputs would both + # be zero and the validation would be vacuous). + with torch.no_grad(): + for p in controller.parameters(): + p.normal_(0.0, 0.01) + + hdr = (rng.random((args.height, args.width, 3), dtype=np.float32) * 0.6 + 0.2) + + rgb_t = torch.from_numpy(hdr).float().to(controller.exposure_head.weight.device) + pe_t = torch.tensor([args.prior], dtype=torch.float32, device=rgb_t.device) + with torch.no_grad(): + exposure, color = controller(rgb_t, pe_t) + expected = np.concatenate([ + np.array([float(exposure)], dtype=np.float32), + color.detach().cpu().numpy().astype(np.float32), + ]) + + weights = flatten_controller_weights(controller) + slang_path = Path(__file__).resolve().parents[2] / ( + "threedgrut/export/usd/ppisp_spg/ppisp_controller.slang" + ) + actual = run_controller(slang_path, hdr, weights, prior_exposure=args.prior) + + diff = np.abs(actual - expected) + print(f"reference: {expected}") + print(f"slangpy: {actual}") + print(f"abs diff: {diff}") + print(f"max abs diff: {diff.max():.6g} (tol={args.tol})") + + return 0 if diff.max() <= args.tol else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/render_ppisp_spg/validate_trained.py b/tools/render_ppisp_spg/validate_trained.py new file mode 100644 index 00000000..da39ea1e --- /dev/null +++ b/tools/render_ppisp_spg/validate_trained.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end validation against a *trained* checkpoint. + +Workflow: + +1. Load the trained checkpoint (which contains the model + the trained + PPISP module including its controllers). +2. For every val frame, run the gaussian renderer to get the pre-PPISP + HDR image. Save it as ``hdr//.npy``. +3. Apply PPISP in PyTorch (the same novel-view path the SPG shader will + use, i.e. ``frame_idx=-1`` so the controller predicts the per-frame + correction). Save it as ``reference//.png``. +4. Author the controller-aware USD via the production exporter and ship + the SPG sidecars to ``usd/``. +5. Run the slangpy CLI (`render_renderproduct.py`) on the USD with the + HDR inputs from step (2) and write its outputs to ``slangpy/``. +6. Compare reference vs slangpy LDR per frame; report PSNR / max abs + diff. Pass / fail on a configurable PSNR threshold. + +This is the workflow a downstream consumer of the asset would actually +exercise: real trained PPISP, real exporter call, real slang dispatch +through the CLI. +""" + +from __future__ import annotations + +import argparse +import logging +import math +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +from PIL import Image + +import torch +from pxr import Gf, Sdf, Usd, UsdGeom + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.writers.ppisp_writer import ( # noqa: E402 + add_ppisp_to_all_render_products, build_camera_frame_mapping, +) +from threedgrut.export.usd.writers.ppisp_controller_writer import ( # noqa: E402 + get_controller_sidecars, +) +from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_dyn_files # noqa: E402 +from threedgrut.render import Renderer # noqa: E402 + +logger = logging.getLogger("validate_trained") + + +def _save_npy(path: Path, arr: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + np.save(path, arr.astype(np.float32)) + + +def _save_png(path: Path, image_rgba: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(image_rgba, mode="RGBA").save(path) + + +def _psnr(a: np.ndarray, b: np.ndarray) -> float: + diff = a.astype(np.float32) - b.astype(np.float32) + mse = float((diff * diff).mean()) + if mse <= 0: + return float("inf") + return 20.0 * math.log10(255.0 / math.sqrt(mse)) + + +def _to_rgba8(rgb: np.ndarray) -> np.ndarray: + rgb = (np.clip(rgb, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) + h, w, _ = rgb.shape + rgba = np.empty((h, w, 4), dtype=np.uint8) + rgba[..., :3] = rgb + rgba[..., 3] = 255 + return rgba + + +def _build_render_product(stage: Usd.Stage, cam_name: str, width: int, height: int) -> Usd.Prim: + rp_path = f"/Render/{cam_name}" + rp = stage.DefinePrim(rp_path, "RenderProduct") + rp.CreateAttribute("resolution", Sdf.ValueTypeNames.Int2).Set(Gf.Vec2i(int(width), int(height))) + cam_prim = stage.DefinePrim(f"/World/Cameras/{cam_name}", "Camera") + rp.CreateRelationship("camera").SetTargets([cam_prim.GetPath()]) + hdr = stage.DefinePrim(f"{rp_path}/HdrColor", "RenderVar") + hdr.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set("HdrColor") + hdr.CreateAttribute("omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False) + rp.CreateRelationship("orderedVars").SetTargets([Sdf.Path("HdrColor")]) + return rp + + +class _StubDataset: + def __init__(self, frame_to_camera, names): + self.f2c = list(frame_to_camera) + self.names = list(names) + + def __len__(self): return len(self.f2c) + + def get_camera_names(self): return list(self.names) + + def get_camera_idx(self, i): return int(self.f2c[i]) + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--data-path", type=Path, default=None, + help="Override the dataset path stored in the checkpoint.") + parser.add_argument("--out-dir", type=Path, default=None, + help="Working directory (default: tmp). Outputs are kept here.") + parser.add_argument("--psnr-threshold", type=float, default=35.0) + parser.add_argument("--max-frames", type=int, default=None, + help="Limit number of val frames processed.") + parser.add_argument("--use-train", action="store_true", + help="Use train frames instead of val (model has overfit, " + "so the gaussian renderer produces non-trivial HDR even after short runs).") + parser.add_argument("--save-hdr-png", action="store_true", + help="Save the HDR render as a normalised PNG for inspection.") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required.") + + work = args.out_dir or Path(tempfile.mkdtemp(prefix="ppisp_trained_")) + work.mkdir(parents=True, exist_ok=True) + hdr_dir = work / "hdr" + ref_dir = work / "reference" + usd_dir = work / "usd" + slang_dir = work / "slangpy" + + # ------------------------------------------------------------------ + # 1. Load checkpoint via Renderer.from_checkpoint (uses val dataset). + # ------------------------------------------------------------------ + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=str(args.data_path) if args.data_path else "", + out_dir=str(work / "_renderer_unused"), + save_gt=False, + computes_extra_metrics=False, + ) + model = renderer.model + post_processing = renderer.post_processing + if post_processing is None or type(post_processing).__name__ != "PPISP": + raise SystemExit("Checkpoint has no PPISP post-processing module.") + if not getattr(post_processing.config, "use_controller", False): + raise SystemExit("PPISP was trained without a controller; nothing to validate.") + if args.use_train: + # Pull the train dataloader by re-creating it (Renderer doesn't keep one). + from threedgrut.datasets.utils import configure_dataloader_for_platform + from threedgrut import datasets as ds + conf_for_train = renderer.conf + train_dataset, _ = ds.make(name=conf_for_train.dataset.type, + config=conf_for_train, ray_jitter=None) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + **configure_dataloader_for_platform( + {"num_workers": 0, "batch_size": 1, "shuffle": False, "pin_memory": True} + ), + ) + val_dataset = train_dataset + val_dataloader = train_dataloader + logger.info("Using TRAIN frames (model overfit) for validation.") + else: + val_dataset = renderer.dataset + val_dataloader = renderer.dataloader + + cam_names: List[str] + if hasattr(val_dataset, "get_camera_names"): + cam_names = list(val_dataset.get_camera_names()) + else: + cam_names = ["cam_0"] + + frame_to_camera: List[int] = [] + resolutions: Dict[str, Tuple[int, int]] = {} + + # ------------------------------------------------------------------ + # 2-3. Render HDR + PyTorch reference LDR for each val frame. + # ------------------------------------------------------------------ + from threedgrut.utils.render import apply_post_processing # noqa: E402 + + seen = 0 + for frame_idx, batch in enumerate(val_dataloader): + if args.max_frames is not None and seen >= args.max_frames: + break + + gpu_batch = val_dataset.get_gpu_batch_with_intrinsics(batch) + with torch.no_grad(): + outputs = model(gpu_batch) + hdr_tensor = outputs["pred_rgb"][0] # [H, W, 3] + + # If the gaussian render is degenerate (e.g. short training, mismatched + # camera frames), the "HDR" is all zeros and the slang/PyTorch + # comparison becomes vacuous. Fall back to the dataset GT image so the + # comparison still exercises the full PPISP pipeline on real-world + # spatial / colour variation. + if hdr_tensor.abs().max().item() < 1e-6: + logger.warning("Gaussian render is degenerate (all zero); " + "substituting GT image as HDR input.") + gt = gpu_batch.rgb_gt + hdr_tensor = gt[0] if gt.dim() == 4 else gt + + hdr_np = hdr_tensor.detach().cpu().numpy().astype(np.float32) + h, w = hdr_np.shape[:2] + + cam_idx = (val_dataset.get_camera_idx(frame_idx) + if hasattr(val_dataset, "get_camera_idx") else 0) + cam_name = cam_names[cam_idx] if cam_idx < len(cam_names) else "cam_0" + resolutions[cam_name] = (w, h) + frame_to_camera.append(cam_idx) + + _save_npy(hdr_dir / cam_name / f"{frame_idx}.npy", hdr_np) + + # Reference path: same PPISP that the slang shader will execute. + # PPISP.forward picks the controller branch when frame_idx=-1, so + # we mirror that here for an apples-to-apples comparison. + with torch.no_grad(): + outputs_ref = dict(outputs) + outputs_ref["pred_rgb"] = hdr_tensor.unsqueeze(0) + # apply_post_processing expects a batch dim. + ref = apply_post_processing( + post_processing, outputs_ref, gpu_batch, training=False + )["pred_rgb"][0] + ref_np = ref.detach().cpu().numpy().astype(np.float32) + _save_png(ref_dir / cam_name / f"{frame_idx}.png", _to_rgba8(ref_np)) + # Save the pre-quantization float reference so we can quantify the + # numerical drift independent of the rgba8_unorm round-trip. + np.save(ref_dir / cam_name / f"{frame_idx}.npy", ref_np) + seen += 1 + + if not frame_to_camera: + raise SystemExit("No validation frames found in dataset.") + logger.info("Rendered %d val frame(s)", len(frame_to_camera)) + + # ------------------------------------------------------------------ + # 4. Author the controller-aware USD + sidecars. + # ------------------------------------------------------------------ + usd_dir.mkdir(parents=True, exist_ok=True) + stage = Usd.Stage.CreateNew(str(usd_dir / "scene.usda")) + stage.SetMetadata("upAxis", UsdGeom.Tokens.y) + stage.DefinePrim("/World", "Xform") + stage.DefinePrim("/Render", "Scope") + for cam_name, (w, h) in resolutions.items(): + _build_render_product(stage, cam_name, w, h) + + dataset_stub = _StubDataset(frame_to_camera, cam_names) + cam_names_built, mapping = build_camera_frame_mapping(dataset_stub) + add_ppisp_to_all_render_products( + stage=stage, + ppisp=post_processing, + camera_names=cam_names_built, + camera_frame_mapping=mapping, + use_controller=True, + ) + stage.GetRootLayer().Save() + + for s in get_ppisp_spg_dyn_files(): + (usd_dir / s.filename).write_bytes(s.serialized) + for s in get_controller_sidecars(): + (usd_dir / s.filename).write_bytes(s.serialized) + + # ------------------------------------------------------------------ + # 5. Run the slangpy CLI. + # ------------------------------------------------------------------ + cli = Path(__file__).resolve().parent / "render_renderproduct.py" + cmd = [ + sys.executable, str(cli), + str(usd_dir / "scene.usda"), str(hdr_dir), str(slang_dir), "-vv", + ] + logger.info("Running slangpy CLI: %s", " ".join(cmd)) + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + print(proc.stdout); print(proc.stderr, file=sys.stderr) + raise SystemExit(f"render_renderproduct.py failed (exit {proc.returncode})") + + # ------------------------------------------------------------------ + # 5b. Probe: run the controller alone via slangpy on each saved HDR + # and compare its 9-element output to PyTorch. This isolates whether + # any drift in the final image originates in the controller (CNN+MLP) + # or further downstream in the PPISP shader. + # ------------------------------------------------------------------ + from threedgrut.export.usd.writers.ppisp_controller_writer import ( + flatten_controller_weights, + ) + from tools.render_ppisp_spg.spg_runtime import run_controller as _run_ctrl + print("\nController (9-float) drift, slang vs torch:") + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + hdr_np = np.load(hdr_dir / cam_name / f"{frame_idx}.npy") + ctrl = post_processing.controllers[cam_idx] + # Torch reference + rgb_t = torch.from_numpy(hdr_np).float().to("cuda") + pe_t = torch.zeros(1, dtype=torch.float32, device="cuda") + with torch.no_grad(): + exposure, color = ctrl(rgb_t, pe_t) + torch_out = np.concatenate([ + np.array([float(exposure)], dtype=np.float32), + color.detach().cpu().numpy().astype(np.float32), + ]) + # Slang + weights = flatten_controller_weights(ctrl) + slang_out = _run_ctrl(usd_dir / "ppisp_controller.slang", hdr_np, weights, prior_exposure=0.0) + diff = np.abs(slang_out - torch_out) + print(f" frame={frame_idx} torch={torch_out} slang={slang_out} max|Δ|={diff.max():.4g}") + + # ------------------------------------------------------------------ + # 6. Compare images. + # ------------------------------------------------------------------ + all_pass = True + worst = (None, float("inf")) + for frame_idx, cam_idx in enumerate(frame_to_camera): + cam_name = cam_names[cam_idx] + ref_path = ref_dir / cam_name / f"{frame_idx}.png" + sl_path = slang_dir / cam_name / f"{frame_idx}.png" + if not sl_path.exists(): + print(f" cam={cam_name} frame={frame_idx} MISSING slangpy output") + all_pass = False + continue + ref = np.asarray(Image.open(ref_path).convert("RGBA")) + sl = np.asarray(Image.open(sl_path).convert("RGBA")) + psnr = _psnr(ref[..., :3], sl[..., :3]) + max_abs = int(np.max(np.abs(ref[..., :3].astype(int) - sl[..., :3].astype(int)))) + ok = psnr >= args.psnr_threshold + + # Also report a float-domain diff: the slang shader writes through + # rgba8_unorm, so its output is already quantized; we re-quantize the + # PyTorch reference with the same rule and compare the float values + # of the reference to that re-quantized form. This shows whether the + # shader is matching the *post-quantization* spec exactly. + ref_float_path = ref_dir / cam_name / f"{frame_idx}.npy" + if ref_float_path.exists(): + ref_float = np.clip(np.load(ref_float_path), 0.0, 1.0) + sl_float = sl[..., :3].astype(np.float32) / 255.0 + float_diff = ref_float - sl_float + mean_abs = float(np.mean(np.abs(float_diff))) + max_abs_f = float(np.max(np.abs(float_diff))) + print(f" cam={cam_name} frame={frame_idx} PSNR={psnr:7.3f} dB " + f"max|Δ|_u8={max_abs} max|Δ|_float={max_abs_f:.4f} " + f"mean|Δ|_float={mean_abs:.5f} " + f"{'OK' if ok else 'FAIL'}") + else: + print(f" cam={cam_name} frame={frame_idx} PSNR={psnr:7.3f} dB " + f"max|Δ|={max_abs} {'OK' if ok else 'FAIL'}") + if psnr < worst[1]: + worst = ((cam_name, frame_idx), psnr) + if not ok: + all_pass = False + + print() + print(f"worst frame: {worst[0]} @ {worst[1]:.3f} dB (threshold {args.psnr_threshold} dB)") + print(f"work dir: {work}") + return 0 if all_pass else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) From ef4b596c152809386a47c5855b378e59d54335df Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 21:55:39 +0000 Subject: [PATCH 21/42] feat(export): add --ignore-ppisp-controller flag Lets a downstream consumer that doesn't want runtime controller dispatch (e.g. for stricter portability or smaller assets) opt out of the controller export even when the checkpoint contains trained controllers. The exporter falls back to the static SPG path with time-sampled exposure / color attributes derived from ppisp.exposure_params and ppisp.color_params -- the same optimized per-frame parameters PPISP would use when the controller branch is bypassed. No effect on checkpoints trained without a controller. Available both as a CLI flag (--ignore-ppisp-controller) and a YAML key (export_usd.ignore-ppisp-controller). Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 10 +++++++++ threedgrut/export/usd/exporter.py | 28 ++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index e3ccf002..2a11ce58 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -171,6 +171,15 @@ def parse_args(): default=None, help="Optional PPISP frame id to write as static omni-native shader inputs instead of animation.", ) + parser.add_argument( + "--ignore-ppisp-controller", + action="store_true", + help=( + "If the checkpoint contains trained PPISP controllers, ignore them and " + "export the optimized per-frame exposure/color parameters as time-sampled " + "USD attributes instead. Has no effect when the checkpoint has no controllers." + ), + ) parser.add_argument( "--post-processing-bake-epochs", type=int, @@ -423,6 +432,7 @@ def main(): "post_processing_export_frame_id", None, ), + ignore_ppisp_controller=args.ignore_ppisp_controller, post_processing_bake_epochs=_arg_or_conf( args.post_processing_bake_epochs, export_conf, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 15f2a4a9..2e6ae790 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -283,6 +283,7 @@ def __init__( post_processing_export_mode: str = MODE_POST_PROCESSING_EXPORT_BAKED_SH, post_processing_export_camera_id: int | None = None, post_processing_export_frame_id: int | None = None, + ignore_ppisp_controller: bool = False, post_processing_bake_epochs: int = 1, post_processing_bake_learning_rate: float = 1.0e-3, post_processing_bake_camera_id: int = 0, @@ -309,6 +310,11 @@ def __init__( Omniverse-native path; currently PPISP SPG. post_processing_export_camera_id: Optional PPISP camera index to use for every RenderProduct in omni-native mode. + ignore_ppisp_controller: If True, skip the PPISP controller export + even when the checkpoint has trained controllers, and fall back + to time-sampled exposure / colour USD attributes derived from + ``ppisp.exposure_params`` and ``ppisp.color_params``. No effect + on checkpoints that were trained without a controller. post_processing_export_frame_id: Optional PPISP frame index to write as static exposure/color inputs in omni-native mode. post_processing_bake_epochs: Number of sequential passes over the train/reference set. @@ -340,6 +346,7 @@ def __init__( self.post_processing_export_frame_id = ( None if post_processing_export_frame_id is None else int(post_processing_export_frame_id) ) + self.ignore_ppisp_controller = bool(ignore_ppisp_controller) self.post_processing_bake_epochs = int(post_processing_bake_epochs) self.post_processing_bake_learning_rate = float(post_processing_bake_learning_rate) self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) @@ -696,13 +703,26 @@ def _export_ppisp( # The static-frame override modes (fixed_frame_id) intentionally bypass # the controller because the goal is to bake one specific frame's # corrections, not to predict them at runtime. - use_controller = has_controller and fixed_frame_id is None + # ignore_ppisp_controller forces the same fall-back even with animation, + # so consumers that don't want runtime controller dispatch can ship the + # optimized per-frame exposure / colour USD attributes instead. + use_controller = ( + has_controller + and fixed_frame_id is None + and not self.ignore_ppisp_controller + ) if has_controller and fixed_frame_id is not None: logger.info( "PPISP controller present but fixed_frame_id is set; using static " "exposure/color from frame %d instead of the controller.", fixed_frame_id, ) + elif has_controller and self.ignore_ppisp_controller: + logger.info( + "PPISP controller present but ignore_ppisp_controller is set; " + "exporting time-sampled exposure/color from optimized PPISP parameters " + "instead of the runtime controller." + ) from threedgrut.export.usd.ppisp_spg import get_ppisp_spg_files, get_ppisp_spg_dyn_files from threedgrut.export.usd.writers.ppisp_writer import ( @@ -791,6 +811,12 @@ def from_config(cls, conf) -> "USDExporter": "post_processing_export_frame_id", None, ), + ignore_ppisp_controller=_get_export_config_value( + export_conf, + "ignore-ppisp-controller", + "ignore_ppisp_controller", + False, + ), post_processing_bake_epochs=_get_export_config_value( export_conf, "post-processing-bake-epochs", From a2d43b6ceba62226c950053722bd5f79c08edb50 Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 22:13:03 +0000 Subject: [PATCH 22/42] feat(tools): add PPISP controller triage script diagnose_controller.py runs three independent checks against a trained checkpoint and pinpoints which layer is responsible when the controller- driven render disagrees with the optimized-per-frame-params render: H1 -- PyTorch controller(rgb) vs trained exposure_params/color_params. Pure-python check; failure means the controller did not converge to the per-frame state during distillation. H2 -- slang controller vs PyTorch controller on the same HDR. Failure means the slang shader / weight flatten / buffer upload diverges from the trained module. H3 -- ppisp_usd_spg_dyn.slang (reads ControllerParams texture) vs ppisp_usd_spg.slang (USD attributes), both fed the same 9 floats. Failure means the SPG plumbing between the two shaders is broken. Each section prints per-frame numbers plus a clear pass/fail interpretation threshold, and the summary names the failing hypothesis. Saves having to instrument the export end-to-end again the next time something disagrees in Kit. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/render_ppisp_spg/diagnose_controller.py | 255 ++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 tools/render_ppisp_spg/diagnose_controller.py diff --git a/tools/render_ppisp_spg/diagnose_controller.py b/tools/render_ppisp_spg/diagnose_controller.py new file mode 100644 index 00000000..daf73d41 --- /dev/null +++ b/tools/render_ppisp_spg/diagnose_controller.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Triage helper for "controller-driven Omniverse render disagrees with the +optimized-params render". + +Runs three independent checks in order, each isolating one of the three +hypotheses you flagged. + +H1 -- *Did the controller actually learn the optimized params?* + For a sample of frames, compare in-process + exposure, color = controller(gaussian_render(frame), prior=0) + against the trained per-frame parameters + ppisp.exposure_params[frame_idx], ppisp.color_params[frame_idx]. + Pure PyTorch, no SPG/slang. If these disagree, the controller has not + converged to the per-frame state and any SPG export will inherit that. + +H2 -- *Does the slang controller match the PyTorch controller?* + Run controller(rgb, prior) twice on the same input: + a. PyTorch (ppisp.controllers[c]). + b. slangpy on ppisp_controller.slang. + These should agree to ~1e-6 (we measured 3e-7 on bonsai). A larger + delta means the slang shader, the weight flatten, or the buffer + upload disagrees with the trained controller. + +H3 -- *Is the SPG controller -> PPISP plumbing sound?* + Two ways to drive the dynamic PPISP shader on the same HDR: + a. dynamic path: controller slang writes ControllerParams texture, + ppisp_usd_spg_dyn.slang reads it. + b. static path: feed the *same* 9 floats (taken from the PyTorch + controller in step H1/H2) as USD attributes into the legacy + ppisp_usd_spg.slang shader. + These should produce byte-for-byte the same LDR image. If they + disagree, the dynamic shader's texture binding or layout is wrong. + +Usage: + python tools/render_ppisp_spg/diagnose_controller.py \ + --checkpoint runs//ckpt_last.pt \ + --max-frames 4 +""" + +from __future__ import annotations + +import argparse +import logging +import math +import sys +from pathlib import Path +from typing import List + +import numpy as np +import torch + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.writers.ppisp_controller_writer import ( # noqa: E402 + flatten_controller_weights, +) +from threedgrut.render import Renderer # noqa: E402 +from tools.render_ppisp_spg.spg_runtime import ( # noqa: E402 + CrfParams, VignetteParams, run_controller, run_ppisp_dyn, run_ppisp_static, +) + +logger = logging.getLogger("diagnose_controller") + + +PPISP_SPG_DIR = Path(__file__).resolve().parents[2] / "threedgrut/export/usd/ppisp_spg" +CONTROLLER_SLANG = PPISP_SPG_DIR / "ppisp_controller.slang" +PPISP_DYN_SLANG = PPISP_SPG_DIR / "ppisp_usd_spg_dyn.slang" +PPISP_STATIC_SLANG = PPISP_SPG_DIR / "ppisp_usd_spg.slang" + + +def _vignette_for_camera(ppisp, camera_idx: int) -> VignetteParams: + v = ppisp.vignetting_params[camera_idx].detach().cpu().numpy() + p = VignetteParams() + for ch_idx, ch in enumerate(("r", "g", "b")): + setattr(p, f"center_{ch}", (float(v[ch_idx, 0]), float(v[ch_idx, 1]))) + setattr(p, f"alpha1_{ch}", float(v[ch_idx, 2])) + setattr(p, f"alpha2_{ch}", float(v[ch_idx, 3])) + setattr(p, f"alpha3_{ch}", float(v[ch_idx, 4])) + return p + + +def _crf_for_camera(ppisp, camera_idx: int) -> CrfParams: + crf = ppisp.crf_params[camera_idx].detach().cpu().numpy() + p = CrfParams() + for ch_idx, ch in enumerate(("r", "g", "b")): + setattr(p, f"toe_{ch}", float(crf[ch_idx, 0])) + setattr(p, f"shoulder_{ch}", float(crf[ch_idx, 1])) + setattr(p, f"gamma_{ch}", float(crf[ch_idx, 2])) + setattr(p, f"center_{ch}", float(crf[ch_idx, 3])) + return p + + +def _torch_controller(controller, rgb_np: np.ndarray, prior: float = 0.0) -> np.ndarray: + rgb = torch.from_numpy(rgb_np).float().to("cuda") + pe = torch.tensor([prior], dtype=torch.float32, device="cuda") + with torch.no_grad(): + e, c = controller(rgb, pe) + return np.concatenate([ + np.array([float(e)], dtype=np.float32), + c.detach().cpu().numpy().astype(np.float32), + ]) + + +def _gather_frames(renderer: Renderer, max_frames: int): + """For each batch yield (frame_idx, camera_idx, hdr_np).""" + out = [] + for i, batch in enumerate(renderer.dataloader): + if i >= max_frames: + break + gpu_batch = renderer.dataset.get_gpu_batch_with_intrinsics(batch) + with torch.no_grad(): + outputs = renderer.model(gpu_batch) + hdr = outputs["pred_rgb"][0].detach().cpu().numpy().astype(np.float32) + cam = (renderer.dataset.get_camera_idx(i) if hasattr(renderer.dataset, "get_camera_idx") else 0) + out.append((i, int(cam), hdr)) + return out + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--data-path", type=Path, default=None) + parser.add_argument("--max-frames", type=int, default=4) + parser.add_argument("--prior", type=float, default=0.0, + help="priorExposure value to use at inference. Match what " + "you pass at export time (default 0.0).") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA required.") + + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=str(args.data_path) if args.data_path else "", + out_dir="/tmp/_diag_unused", save_gt=False, computes_extra_metrics=False, + ) + pp = renderer.post_processing + if pp is None or type(pp).__name__ != "PPISP": + raise SystemExit("Checkpoint has no PPISP module.") + if not pp.controllers or len(pp.controllers) == 0: + raise SystemExit("Checkpoint PPISP has no controllers.") + + frames = _gather_frames(renderer, args.max_frames) + if not frames: + raise SystemExit("No frames produced by the dataloader.") + + # ------------------------------------------------------------------ + # H1: trained per-frame params vs in-process controller prediction + # ------------------------------------------------------------------ + print("\n=== H1: PyTorch controller(rgb) vs trained per-frame params ===") + print(f"{'frame':>5} {'cam':>3} {'exp_train':>10} {'exp_pred':>10} " + f"{'Δexp(stops)':>12} {'col_max|Δ|':>11}") + h1_max_exp_diff = 0.0 + h1_max_col_diff = 0.0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + pred = _torch_controller(ctrl, hdr, args.prior) + exp_pred = float(pred[0]) + col_pred = pred[1:] + if fidx >= int(pp.exposure_params.shape[0]): + print(f" frame {fidx}: out of range for exposure_params (size {pp.exposure_params.shape[0]})") + continue + exp_train = float(pp.exposure_params[fidx].detach().cpu()) + col_train = pp.color_params[fidx].detach().cpu().numpy().astype(np.float32) + d_exp = exp_pred - exp_train + d_col = float(np.max(np.abs(col_pred - col_train))) + h1_max_exp_diff = max(h1_max_exp_diff, abs(d_exp)) + h1_max_col_diff = max(h1_max_col_diff, d_col) + print(f"{fidx:>5} {cam:>3} {exp_train:>+10.4f} {exp_pred:>+10.4f} " + f"{d_exp:>+12.3f} {d_col:>11.4f}") + print(f" H1 worst: Δexposure = {h1_max_exp_diff:.3f} stops max|Δcolor| = {h1_max_col_diff:.4f}") + print(f" Interpretation: if Δexposure > ~0.3 stops or Δcolor > ~0.05, the controller has") + print(f" not converged to the optimized per-frame state. The static-export path uses the") + print(f" trained values directly, so it will look 'less exposed' than the controller path.") + + # ------------------------------------------------------------------ + # H2: PyTorch controller vs slang controller + # ------------------------------------------------------------------ + print("\n=== H2: PyTorch controller vs slang controller (same HDR) ===") + print(f"{'frame':>5} {'cam':>3} {'max|Δ|':>11}") + h2_max = 0.0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + torch_out = _torch_controller(ctrl, hdr, args.prior) + weights = flatten_controller_weights(ctrl) + slang_out = run_controller(CONTROLLER_SLANG, hdr, weights, prior_exposure=args.prior) + d = float(np.max(np.abs(torch_out - slang_out))) + h2_max = max(h2_max, d) + print(f"{fidx:>5} {cam:>3} {d:>11.3e}") + print(f" H2 worst: max|Δ| = {h2_max:.3e}") + print(f" Interpretation: should be ~3e-7. Anything > 1e-3 means the slang shader,") + print(f" weight flatten, or buffer upload disagrees with the trained controller.") + + # ------------------------------------------------------------------ + # H3: dynamic shader (reads texture) vs static shader (USD attrs), + # both fed the same 9 floats from PyTorch. + # ------------------------------------------------------------------ + print("\n=== H3: slang dyn (reads ControllerParams texture) vs slang static (USD attrs) ===") + print(f"{'frame':>5} {'cam':>3} {'max|Δ|_u8':>11} {'mean|Δ|_u8':>12}") + h3_max_diff = 0 + for fidx, cam, hdr in frames: + ctrl = pp.controllers[cam] + ctrl_out = _torch_controller(ctrl, hdr, args.prior) # 9-float ground truth + vig = _vignette_for_camera(pp, cam) + crf = _crf_for_camera(pp, cam) + + # Dynamic path: controller-output 9-float buffer fed via texture. + ldr_dyn = run_ppisp_dyn(PPISP_DYN_SLANG, hdr, ctrl_out, vig, crf) + + # Static path: same 9 floats as USD attributes. Splits ctrl_out into + # exposure (1) + 4x float2 colour latents in declared order. + exposure = float(ctrl_out[0]) + color = list(ctrl_out[1:].astype(float)) + ldr_stat = run_ppisp_static(PPISP_STATIC_SLANG, hdr, exposure, color, vig, crf) + + diff = np.abs(ldr_dyn[..., :3].astype(int) - ldr_stat[..., :3].astype(int)) + max_d = int(diff.max()); mean_d = float(diff.mean()) + h3_max_diff = max(h3_max_diff, max_d) + print(f"{fidx:>5} {cam:>3} {max_d:>11d} {mean_d:>12.4f}") + print(f" H3 worst: max|Δ|_u8 = {h3_max_diff}") + print(f" Interpretation: should be 0 (or 1 from dispatch ordering). > a few means the") + print(f" dynamic shader's texture binding / texel layout disagrees with what the") + print(f" controller writes — i.e. the SPG plumbing has a bug.") + + # ------------------------------------------------------------------ + # Summary + # ------------------------------------------------------------------ + print("\n=== Summary ===") + h1_bad = h1_max_exp_diff > 0.3 or h1_max_col_diff > 0.05 + h2_bad = h2_max > 1e-3 + h3_bad = h3_max_diff > 3 + verdict = [] + if h1_bad: verdict.append("H1 fails -- controller did not learn the per-frame params (training).") + if h2_bad: verdict.append("H2 fails -- slang controller != PyTorch controller (shader / flatten).") + if h3_bad: verdict.append("H3 fails -- dyn shader != static shader on the same 9 floats (plumbing).") + if not verdict: + print(" All three checks pass within thresholds. If Omniverse still disagrees") + print(" with Python, suspect: (i) Kit applies camera exposure to HdrColor before") + print(" the SPG dispatch, or (ii) Kit's HdrColor scale != gaussian-renderer scale,") + print(" or (iii) priorExposure mismatch between training and the USD attribute.") + else: + for v in verdict: + print(f" - {v}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From ecb1a2a93e4528a05a563fc9b8d6a4c82abcba4f Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 22:30:29 +0000 Subject: [PATCH 23/42] fix(spg): route controller output through omni:rtx:aov RenderVar Previously the PPISP-dyn shader's inputs:ControllerParams was connected directly to PPISPController_n.outputs:ControllerParams. That's valid UsdShade -- the slangpy harness happily resolves it -- but Kit's SPG runtime walks RenderProduct.orderedVars and resolves connections through omni:rtx:aov on RenderVar prims, never directly between Shader prims. A direct Shader -> Shader hop leaves the consumer unbound at dispatch time (or bound to whatever Kit falls back to), which manifests as a much-too- exposed render in Omniverse even though slangpy looks correct. Mirror the existing HdrColor / LdrColor idiom: insert an intermediate "ControllerParams" RenderVar with `opaque omni:rtx:aov.connect = PPISPController_n.outputs:ControllerParams`, add it to orderedVars, and point PPISP.inputs:ControllerParams.connect at the RenderVar's omni:rtx:aov attribute. Wiring is now: HdrColor (RenderVar+aov) -> PPISPController_n.inputs:HdrColor.connect = HdrColor.omni:rtx:aov PPISPController_n.outputs:ControllerParams -> ControllerParams (RenderVar) omni:rtx:aov.connect = PPISPController_n.outputs:ControllerParams -> PPISP.inputs:ControllerParams.connect = ControllerParams.omni:rtx:aov PPISP.outputs:PPISPColor -> LdrColor (RenderVar) omni:rtx:aov.connect = PPISP.outputs:PPISPColor Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/writers/ppisp_controller_writer.py | 28 +++++++++++++++++-- threedgrut/export/usd/writers/ppisp_writer.py | 8 ++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/threedgrut/export/usd/writers/ppisp_controller_writer.py b/threedgrut/export/usd/writers/ppisp_controller_writer.py index 6b2ef855..e59305a6 100644 --- a/threedgrut/export/usd/writers/ppisp_controller_writer.py +++ b/threedgrut/export/usd/writers/ppisp_controller_writer.py @@ -251,9 +251,33 @@ def add_controller_shader_to_render_product( weights_input = shader.CreateInput(WEIGHTS_INPUT, Sdf.ValueTypeNames.FloatArray) weights_input.Set(Vt.FloatArray.FromNumpy(weights)) + # Route the controller output through a RenderVar with omni:rtx:aov, so + # SPG resolves it the same way it resolves HdrColor / LdrColor. Direct + # Shader -> Shader connections work in slangpy but Kit's runtime walks + # AOV connections, not arbitrary UsdShade outputs. + var_path = f"{render_product_path}/{CONTROLLER_OUTPUT_NAME}" + render_var = stage.DefinePrim(var_path, "RenderVar") + render_var.CreateAttribute("sourceName", Sdf.ValueTypeNames.String).Set(CONTROLLER_OUTPUT_NAME) + aov_attr = render_var.CreateAttribute( + "omni:rtx:aov", Sdf.ValueTypeNames.Opaque, custom=False + ) + aov_attr.SetConnections([ + shader.GetPath().AppendProperty(f"outputs:{CONTROLLER_OUTPUT_NAME}") + ]) + + # Add the intermediate var to RenderProduct.orderedVars so SPG discovers it. + ordered_vars_rel = render_product.GetRelationship("orderedVars") + if ordered_vars_rel: + targets = list(ordered_vars_rel.GetTargets()) + path = Sdf.Path(CONTROLLER_OUTPUT_NAME) + if path not in targets: + targets.append(path) + ordered_vars_rel.SetTargets(targets) + log.debug( - "Authored PPISP controller shader at %s (camera %d, %d weights)", - shader_path, camera_index, weights.size, + "Authored PPISP controller shader at %s (camera %d, %d weights), " + "AOV RenderVar at %s", + shader_path, camera_index, weights.size, var_path, ) return shader diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index 18fcf8af..b3a43fb2 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -170,10 +170,12 @@ def _create_shader_prim( if use_dynamic: controller_input = shader.CreateInput(PPISP_CONTROLLER_INPUT, Sdf.ValueTypeNames.Opaque) - controller_output_path = controller_shader.GetPath().AppendProperty( - f"outputs:{PPISP_CONTROLLER_INPUT}" + # Route through the controller's sibling RenderVar's omni:rtx:aov, + # mirroring how PPISP reads HdrColor. SPG only resolves AOV + # connections, not direct Shader -> Shader output references. + controller_input.GetAttr().SetConnections( + [Sdf.Path(f"../{PPISP_CONTROLLER_INPUT}.omni:rtx:aov")] ) - controller_input.GetAttr().SetConnections([controller_output_path]) # PPISPColor opaque output shader.CreateOutput(PPISP_OUTPUT_RENDER_VAR, Sdf.ValueTypeNames.Opaque) From 5c7fd0fe9e1c97b38f6eaf1bf50547b344d1c4af Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 22:34:15 +0000 Subject: [PATCH 24/42] fix(spg): drop 'require' substring from PPISP lua sandbox messages Kit's SPG lua sandbox does a textual deny-list match for the lua module-loader keyword 'require' and rejects any source containing it -- including inside string literals. Both new lua launchers had assert messages that read "... requires ...", which triggered the sandbox validator and caused the entire shader graph build to fail with: LuaSandbox: Lua source rejected - forbidden pattern 'require' found Output outputs:ControllerParams of node PPISPController_0 is missing shape spec. Can't compute output shape. (The shape error is a downstream effect: with the lua rejected, SPG never executes outputs["ControllerParams"] = slang.empty(...) so the shape can't be inferred.) Replace "requires" with "needs" in both messages. Also did a quick audit for other commonly denied tokens (dofile/loadfile/loadstring/ os./io./debug./setfenv/getfenv) -- both luas are clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua | 2 +- threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua index f001912c..a05f2202 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua @@ -12,7 +12,7 @@ function controllerProcess(inputs, outputs, params) assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") local weights = params["weights"] - assert(weights, "controllerProcess requires the inputs:weights attribute") + assert(weights, "controllerProcess needs the inputs:weights attribute") -- 1x9 single-channel float image holding [exposure, color latents]. outputs["ControllerParams"] = slang.empty({ 1, 9 }, slang.float) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua index 20cfdfe5..7d140e21 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua @@ -13,7 +13,7 @@ function ppispProcessDyn(inputs, outputs, params) assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") local controller = inputs["ControllerParams"] - assert(controller, "ppispProcessDyn requires a ControllerParams input texture") + assert(controller, "ppispProcessDyn needs a ControllerParams input texture") local height = in_rgba.shape[1] local width = in_rgba.shape[2] From 4452bd66140c43f7cec9189f67aa10ca26edf2fb Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 23:03:36 +0000 Subject: [PATCH 25/42] fix(spg): probe slang buffer-helper API names at lua dispatch Kit's SPG slang lua sandbox does not expose ``slang.StructuredBuffer`` (it's nil), so the previous launcher failed with "attempt to call a nil value (field 'StructuredBuffer')". Replace the hard-coded call with a small probe that tries every common HLSL/Slang buffer-resource name in turn (StructuredBuffer, RWStructuredBuffer, Buffer, RWBuffer, ByteAddressBuffer, RWByteAddressBuffer) and falls back to an explicit error message that lists every key currently in the ``slang.*`` table. If the probe doesn't find a match the user gets a precise log line naming the available helpers, which we can use to settle on the right one without further guessing. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/ppisp_spg/ppisp_controller.slang.lua | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua index a05f2202..3050b1b5 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua @@ -7,6 +7,28 @@ -- carried by the ``weights`` USD attribute, so this file does not need -- to be regenerated. +-- Bind the controller weight buffer using whichever buffer-helper SPG's +-- slang lua API exposes. The trained weights live as a USD float[] +-- attribute (params["weights"]) and the slang shader reads them as a +-- read-only StructuredBuffer. Names tried in order match common +-- HLSL/Slang resource type names. +local function bind_weights(w) + local fn = + slang.StructuredBuffer + or slang.RWStructuredBuffer + or slang.Buffer + or slang.RWBuffer + or slang.ByteAddressBuffer + or slang.RWByteAddressBuffer + if fn then return fn(w) end + -- Surface what IS available so we can iterate the API name from + -- a Kit log without having to guess. + local names = {} + for k, _ in pairs(slang) do table.insert(names, tostring(k)) end + error("ppisp_controller: no slang buffer-binding helper found. " .. + "slang.* keys = " .. table.concat(names, ", ")) +end + function controllerProcess(inputs, outputs, params) local in_rgba = inputs["HdrColor"] assert(in_rgba and in_rgba.rank == 2, "HdrColor input must be a 2D texture") @@ -26,7 +48,7 @@ function controllerProcess(inputs, outputs, params) slang.float(params["priorExposure"] or 0.0) ), slang.Texture2D(in_rgba), - slang.StructuredBuffer(weights), + bind_weights(weights), slang.RWTexture2D(outputs["ControllerParams"]), }, }) From 206ce092fbf2c1113f92deaef6287e872ef4a252 Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 23:07:36 +0000 Subject: [PATCH 26/42] spg: broaden buffer-helper probe to surface metatable + all candidates The previous probe only listed pairs(slang), which under SPG's slang table iterates only the directly-stored keys (the truncated 'short, full, bo...' from the Kit log). The actual resource helpers like Texture2D/dispatch live behind a __index metatable, so pairs() misses them and the prior probe couldn't tell us anything useful. Expand the candidate list to cover everything plausible (cased and uncased forms: Buffer/buffer, Array/array, FloatArray/floatArray, image/Image, uniform/Uniform, list/List, etc.) and test each via direct field access (which goes through __index). Report both the direct pairs() keys and the candidate list in the error if nothing matches. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/ppisp_spg/ppisp_controller.slang.lua | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua index 3050b1b5..20cdb4cd 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua @@ -10,23 +10,39 @@ -- Bind the controller weight buffer using whichever buffer-helper SPG's -- slang lua API exposes. The trained weights live as a USD float[] -- attribute (params["weights"]) and the slang shader reads them as a --- read-only StructuredBuffer. Names tried in order match common --- HLSL/Slang resource type names. +-- read-only StructuredBuffer. local function bind_weights(w) - local fn = - slang.StructuredBuffer - or slang.RWStructuredBuffer - or slang.Buffer - or slang.RWBuffer - or slang.ByteAddressBuffer - or slang.RWByteAddressBuffer - if fn then return fn(w) end - -- Surface what IS available so we can iterate the API name from - -- a Kit log without having to guess. - local names = {} - for k, _ in pairs(slang) do table.insert(names, tostring(k)) end + -- Probe a long list of plausible names. The first non-nil wins. + local candidates = { + "StructuredBuffer", "RWStructuredBuffer", + "Buffer", "RWBuffer", + "ByteAddressBuffer", "RWByteAddressBuffer", + "ConstantBuffer", + "buffer", "Array", "array", + "float_array", "FloatArray", "floatArray", + "FloatBuffer", "floatBuffer", + "image", "Image", + "uniform", "Uniform", + "list", "List", + } + local hits = {} + for _, name in ipairs(candidates) do + if slang[name] ~= nil then + table.insert(hits, name) + end + end + if #hits > 0 then + return slang[hits[1]](w) + end + -- No buffer helper. List EVERY direct slang.* key plus every + -- candidate we tried (so the metatable surface is also probed via + -- __index above). The error message goes to Kit's log. + local direct = {} + for k, _ in pairs(slang) do table.insert(direct, tostring(k)) end + table.sort(direct) error("ppisp_controller: no slang buffer-binding helper found. " .. - "slang.* keys = " .. table.concat(names, ", ")) + "Tried: " .. table.concat(candidates, ",") .. + " | direct keys = " .. table.concat(direct, ",")) end function controllerProcess(inputs, outputs, params) From b096cb4acc800e80c55a4cde15ab8b4a4eb97694 Mon Sep 17 00:00:00 2001 From: Horde Date: Thu, 30 Apr 2026 23:12:19 +0000 Subject: [PATCH 27/42] fix(spg): rename g_Weights -> weights to match USD attribute name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Kit's SPG was emitting one warning per dispatch: [Warning] [rtx.spg.slang] SpgSlangNode: Failed to find parameter 'params:weights' in shader reflection because the slang shader declared the buffer as ``g_Weights`` while the USD Shader prim exposes it as ``inputs:weights``. SPG's reflection-based auto-binding looks up parameters by name; the mismatch caused the spam even though the explicit lua bind in the dispatch worked. Rename the slang variable to ``weights`` so the names match. Validation still passes at PSNR 63 dB; max|Δ| in the controller's 9-float output remains ~3.58e-7. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/ppisp_spg/ppisp_controller.slang | 42 +++++++++---------- tools/render_ppisp_spg/spg_runtime.py | 2 +- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang index 9ead0090..e809eb58 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang @@ -52,7 +52,7 @@ static const int THREAD_GROUP_SIZE = 32; // --------------------------------------------------------------------------- // Weight buffer offsets (the Python writer flattens weights in this order -// into a single float buffer that gets bound as g_Weights). +// into a single float buffer that gets bound as weights). // --------------------------------------------------------------------------- static const int OFF_CONV1_W = 0; // 16 * 3 = 48 static const int OFF_CONV1_B = OFF_CONV1_W + 16 * 3; // + 16 = 64 @@ -83,7 +83,7 @@ struct PPISPControllerParams [[vk::binding(0, 1)]] ParameterBlock g_Params; [[vk::binding(1, 1)]] Texture2D g_InTex; -[[vk::binding(2, 1)]] StructuredBuffer g_Weights; +[[vk::binding(2, 1)]] StructuredBuffer weights; [[vk::binding(3, 1)]] RWTexture2D g_OutTex; // --------------------------------------------------------------------------- @@ -94,10 +94,10 @@ void conv1Forward(float3 rgb, out float feat[16]) { [unroll] for (int o = 0; o < 16; ++o) { - float v = g_Weights[OFF_CONV1_B + o]; - v += rgb.r * g_Weights[OFF_CONV1_W + o * 3 + 0]; - v += rgb.g * g_Weights[OFF_CONV1_W + o * 3 + 1]; - v += rgb.b * g_Weights[OFF_CONV1_W + o * 3 + 2]; + float v = weights[OFF_CONV1_B + o]; + v += rgb.r * weights[OFF_CONV1_W + o * 3 + 0]; + v += rgb.g * weights[OFF_CONV1_W + o * 3 + 1]; + v += rgb.b * weights[OFF_CONV1_W + o * 3 + 2]; feat[o] = v; } } @@ -106,9 +106,9 @@ void conv2Forward(float fin[16], out float fout[32]) { [unroll] for (int o = 0; o < 32; ++o) { - float v = g_Weights[OFF_CONV2_B + o]; + float v = weights[OFF_CONV2_B + o]; [unroll] for (int i = 0; i < 16; ++i) - v += fin[i] * g_Weights[OFF_CONV2_W + o * 16 + i]; + v += fin[i] * weights[OFF_CONV2_W + o * 16 + i]; fout[o] = v; } } @@ -117,9 +117,9 @@ void conv3Forward(float fin[32], out float fout[64]) { [unroll] for (int o = 0; o < CNN_FEATURE_DIM; ++o) { - float v = g_Weights[OFF_CONV3_B + o]; + float v = weights[OFF_CONV3_B + o]; [unroll] for (int i = 0; i < 32; ++i) - v += fin[i] * g_Weights[OFF_CONV3_W + o * 32 + i]; + v += fin[i] * weights[OFF_CONV3_W + o * 32 + i]; fout[o] = v; } } @@ -241,11 +241,11 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // across the THREAD_GROUP_SIZE threads. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = g_Weights[OFF_TRUNK0_B + o]; + float v = weights[OFF_TRUNK0_B + o]; for (int i = 0; i < POOL_FEATURE_LEN; ++i) - v += gsPooled[i] * g_Weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i]; + v += gsPooled[i] * weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i]; v += g_Params.priorExposure - * g_Weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN]; + * weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN]; gsHiddenA[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -253,9 +253,9 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 3: trunk1 (128 -> 128). gsHiddenA -> gsHiddenB. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = g_Weights[OFF_TRUNK1_B + o]; + float v = weights[OFF_TRUNK1_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * g_Weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenA[i] * weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i]; gsHiddenB[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -263,9 +263,9 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 4: trunk2 (128 -> 128). gsHiddenB -> gsHiddenA. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = g_Weights[OFF_TRUNK2_B + o]; + float v = weights[OFF_TRUNK2_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenB[i] * g_Weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenB[i] * weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i]; gsHiddenA[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -273,17 +273,17 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 5: heads. if (gtid.x == 0) { - float v = g_Weights[OFF_EXP_B]; + float v = weights[OFF_EXP_B]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * g_Weights[OFF_EXP_W + i]; + v += gsHiddenA[i] * weights[OFF_EXP_W + i]; g_OutTex[int2(0, 0)] = v; } if (gtid.x < uint(COLOR_PARAMS_PER_FRAME)) { int o = int(gtid.x); - float v = g_Weights[OFF_COL_B + o]; + float v = weights[OFF_COL_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * g_Weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenA[i] * weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i]; g_OutTex[int2(1 + o, 0)] = v; } } diff --git a/tools/render_ppisp_spg/spg_runtime.py b/tools/render_ppisp_spg/spg_runtime.py index 027ea6ab..c745375b 100644 --- a/tools/render_ppisp_spg/spg_runtime.py +++ b/tools/render_ppisp_spg/spg_runtime.py @@ -219,7 +219,7 @@ def run_controller( cur = spy.ShaderCursor(shader_obj) _set_param_block(cur, "g_Params", {"priorExposure": float(prior_exposure)}) cur["g_InTex"] = in_tex - cur["g_Weights"] = weights_buf + cur["weights"] = weights_buf cur["g_OutTex"] = out_tex cp.dispatch(spy.math.uint3(32, 1, 1)) device.submit_command_buffer(encoder.finish()) From 9800daee26d9564060ae28d2576dfad259475b1a Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 11:46:51 +0000 Subject: [PATCH 28/42] fix(spg): move weights inside ParameterBlock to silence reflection warning Kit's SPG resolves USD ``inputs:foo`` attributes against fields of the slang ParameterBlock -- its reflection lookup is ``params:foo``. With ``weights`` declared as a sibling top-level ``StructuredBuffer`` the lookup failed every dispatch: [Warning] [rtx.spg.slang] SpgSlangNode: Failed to find parameter 'params:weights' in shader reflection The static PPISP shader follows the same convention -- every attribute (exposureOffset, vignetting*, crf*, ...) lives inside the PPISPParams struct that's wrapped by ParameterBlock. Move ``weights`` into PPISPControllerParams alongside priorExposure; update both the lua bind list (now passes the buffer as a positional argument to slang.ParameterBlock) and the slangpy harness's ShaderCursor (cur["g_Params"]["weights"] instead of cur["weights"]). Validation still passes at PSNR 63 dB; controller 9-float drift remains ~3.58e-7 vs torch. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/ppisp_spg/ppisp_controller.slang | 47 ++++++++++--------- .../usd/ppisp_spg/ppisp_controller.slang.lua | 6 ++- tools/render_ppisp_spg/spg_runtime.py | 7 ++- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang index e809eb58..c41588e0 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang @@ -76,15 +76,20 @@ static const int TOTAL_WEIGHTS = OFF_COL_B + 8; // Bindings // --------------------------------------------------------------------------- +// SPG resolves USD ``inputs:foo`` attributes against fields of the slang +// ParameterBlock -- its reflection lookup is ``params:foo``. Putting +// ``weights`` directly inside the ParameterBlock struct lets SPG's +// auto-binding find it by name, and silences the per-dispatch warning +// "Failed to find parameter 'params:weights' in shader reflection". struct PPISPControllerParams { float priorExposure; + StructuredBuffer weights; }; [[vk::binding(0, 1)]] ParameterBlock g_Params; [[vk::binding(1, 1)]] Texture2D g_InTex; -[[vk::binding(2, 1)]] StructuredBuffer weights; -[[vk::binding(3, 1)]] RWTexture2D g_OutTex; +[[vk::binding(2, 1)]] RWTexture2D g_OutTex; // --------------------------------------------------------------------------- // Per-pixel CNN building blocks @@ -94,10 +99,10 @@ void conv1Forward(float3 rgb, out float feat[16]) { [unroll] for (int o = 0; o < 16; ++o) { - float v = weights[OFF_CONV1_B + o]; - v += rgb.r * weights[OFF_CONV1_W + o * 3 + 0]; - v += rgb.g * weights[OFF_CONV1_W + o * 3 + 1]; - v += rgb.b * weights[OFF_CONV1_W + o * 3 + 2]; + float v = g_Params.weights[OFF_CONV1_B + o]; + v += rgb.r * g_Params.weights[OFF_CONV1_W + o * 3 + 0]; + v += rgb.g * g_Params.weights[OFF_CONV1_W + o * 3 + 1]; + v += rgb.b * g_Params.weights[OFF_CONV1_W + o * 3 + 2]; feat[o] = v; } } @@ -106,9 +111,9 @@ void conv2Forward(float fin[16], out float fout[32]) { [unroll] for (int o = 0; o < 32; ++o) { - float v = weights[OFF_CONV2_B + o]; + float v = g_Params.weights[OFF_CONV2_B + o]; [unroll] for (int i = 0; i < 16; ++i) - v += fin[i] * weights[OFF_CONV2_W + o * 16 + i]; + v += fin[i] * g_Params.weights[OFF_CONV2_W + o * 16 + i]; fout[o] = v; } } @@ -117,9 +122,9 @@ void conv3Forward(float fin[32], out float fout[64]) { [unroll] for (int o = 0; o < CNN_FEATURE_DIM; ++o) { - float v = weights[OFF_CONV3_B + o]; + float v = g_Params.weights[OFF_CONV3_B + o]; [unroll] for (int i = 0; i < 32; ++i) - v += fin[i] * weights[OFF_CONV3_W + o * 32 + i]; + v += fin[i] * g_Params.weights[OFF_CONV3_W + o * 32 + i]; fout[o] = v; } } @@ -241,11 +246,11 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // across the THREAD_GROUP_SIZE threads. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = weights[OFF_TRUNK0_B + o]; + float v = g_Params.weights[OFF_TRUNK0_B + o]; for (int i = 0; i < POOL_FEATURE_LEN; ++i) - v += gsPooled[i] * weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i]; + v += gsPooled[i] * g_Params.weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + i]; v += g_Params.priorExposure - * weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN]; + * g_Params.weights[OFF_TRUNK0_W + o * MLP_INPUT_DIM + POOL_FEATURE_LEN]; gsHiddenA[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -253,9 +258,9 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 3: trunk1 (128 -> 128). gsHiddenA -> gsHiddenB. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = weights[OFF_TRUNK1_B + o]; + float v = g_Params.weights[OFF_TRUNK1_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenA[i] * g_Params.weights[OFF_TRUNK1_W + o * MLP_HIDDEN_DIM + i]; gsHiddenB[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -263,9 +268,9 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 4: trunk2 (128 -> 128). gsHiddenB -> gsHiddenA. for (int o = int(gtid.x); o < MLP_HIDDEN_DIM; o += THREAD_GROUP_SIZE) { - float v = weights[OFF_TRUNK2_B + o]; + float v = g_Params.weights[OFF_TRUNK2_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenB[i] * weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenB[i] * g_Params.weights[OFF_TRUNK2_W + o * MLP_HIDDEN_DIM + i]; gsHiddenA[o] = max(0.0, v); } GroupMemoryBarrierWithGroupSync(); @@ -273,17 +278,17 @@ void controllerProcess(uint3 gtid : SV_GroupThreadID) // Phase 5: heads. if (gtid.x == 0) { - float v = weights[OFF_EXP_B]; + float v = g_Params.weights[OFF_EXP_B]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * weights[OFF_EXP_W + i]; + v += gsHiddenA[i] * g_Params.weights[OFF_EXP_W + i]; g_OutTex[int2(0, 0)] = v; } if (gtid.x < uint(COLOR_PARAMS_PER_FRAME)) { int o = int(gtid.x); - float v = weights[OFF_COL_B + o]; + float v = g_Params.weights[OFF_COL_B + o]; for (int i = 0; i < MLP_HIDDEN_DIM; ++i) - v += gsHiddenA[i] * weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i]; + v += gsHiddenA[i] * g_Params.weights[OFF_COL_W + o * MLP_HIDDEN_DIM + i]; g_OutTex[int2(1 + o, 0)] = v; } } diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua index 20cdb4cd..67daccdf 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_controller.slang.lua @@ -60,11 +60,13 @@ function controllerProcess(inputs, outputs, params) numthreads = { 32, 1, 1 }, grid = { 1, 1, 1 }, bind = { + -- weights live inside the ParameterBlock struct so SPG's + -- reflection finds them under "params:weights". slang.ParameterBlock( - slang.float(params["priorExposure"] or 0.0) + slang.float(params["priorExposure"] or 0.0), + bind_weights(weights) ), slang.Texture2D(in_rgba), - bind_weights(weights), slang.RWTexture2D(outputs["ControllerParams"]), }, }) diff --git a/tools/render_ppisp_spg/spg_runtime.py b/tools/render_ppisp_spg/spg_runtime.py index c745375b..d0aec5fb 100644 --- a/tools/render_ppisp_spg/spg_runtime.py +++ b/tools/render_ppisp_spg/spg_runtime.py @@ -217,9 +217,12 @@ def run_controller( with encoder.begin_compute_pass() as cp: shader_obj = cp.bind_pipeline(pipeline) cur = spy.ShaderCursor(shader_obj) - _set_param_block(cur, "g_Params", {"priorExposure": float(prior_exposure)}) + # weights live inside the g_Params ParameterBlock now, so SPG's + # reflection finds them under "params:weights" -- silences the + # "Failed to find parameter 'params:weights'" warning in Kit. + cur["g_Params"]["priorExposure"] = float(prior_exposure) + cur["g_Params"]["weights"] = weights_buf cur["g_InTex"] = in_tex - cur["weights"] = weights_buf cur["g_OutTex"] = out_tex cp.dispatch(spy.math.uint3(32, 1, 1)) device.submit_command_buffer(encoder.finish()) From 3636a1cfa19ba561fbaf60ea318457956f2ee63e Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 12:49:41 +0000 Subject: [PATCH 29/42] feat(bake): warm-start fitted PPISP SH bake with simple-bake init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fit-by-bake export started Adam from the cloned checkpoint's SH state, even though we have a closed-form one-shot bake (simple_bake) that already lands close to the optimum. Use it as the initialization so the fitting loop only refines the residual. Adapter changes: - PostProcessingBakeAdapter.initialize_fit(model, pp) -- new hook, default no-op. - bake_post_processing_into_sh now calls adapter.initialize_fit on the cloned baked_model right before constructing the optimizer. - PPISPPostProcessingBakeAdapter.initialize_fit calls simple_bake with higher_order=True (so the spatial SH coefficients are Jacobian-projected too) and apply_srgb_to_linear=True. The new srgb_to_linear flag on simple_bake matters for the colour- space round-trip: PPISP outputs display-referred values (its CRF folds in gamma-like encoding). Storing those directly as linear SH coefs leaves the asset double-encoded once a downstream consumer applies linear_to_srgb (the validator does, Kit's tonemap does). Applying srgb_to_linear before RGB2SH puts SH back in linear scene- referred space; downstream linear_to_srgb then undoes the encoding exactly. Verified srgb_to_linear∘linear_to_srgb round-trips to fp32 epsilon (max|Δ| ≈ 2.4e-7) on [0, 1]. Existing 19 export tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../export/usd/post_processing_sh_bake.py | 36 +++++++++++++++++++ .../usd/post_processing_sh_simple_bake.py | 22 ++++++++++-- .../utils/post_processing_linear_to_srgb.py | 28 +++++++++++++++ 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index 08ce2d76..46ee0738 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -45,6 +45,15 @@ def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Modul del fixed_post_processing, gpu_batch return rgb + def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: + """Optionally warm-start the SH fit with a closed-form initialization. + + Default is a no-op: the cloned ``baked_model`` keeps its checkpoint + SH coefficients as the starting point. Subclasses (e.g. PPISP) can + override to apply a one-shot bake before Adam takes over. + """ + del baked_model, post_processing + def log_context(self) -> str: return "" @@ -110,6 +119,13 @@ def bake_post_processing_into_sh( baked_model.build_acc() fixed_post_processing = adapter.create_fixed_post_processing(post_processing, device) + # Warm-start the cloned SH state with the adapter's closed-form bake + # (PPISP: simple_bake on the chosen camera/frame, with sRGB→linear so + # the resulting SH lives in linear scene-referred space). Adam takes + # over from there. Reduces the iterations needed and avoids fitting + # from a checkpoint state that's far from the optimum. + adapter.initialize_fit(baked_model, post_processing) + fit_parameters = list(_set_sh_fit_parameters(baked_model)) optimizer = torch.optim.Adam(fit_parameters, lr=learning_rate) train_dataloader = _create_train_dataloader(conf, train_dataset) @@ -308,5 +324,25 @@ def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Modul resolution=(width, height), ) + def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: + """Warm-start with the higher-order simple-bake on the chosen + (camera, frame), in linear scene-referred space.""" + # Late import: avoid pulling ppisp into modules that don't need it. + from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake + + logger.info( + "PPISP SH bake init: applying simple_bake (camera=%d, frame=%d, " + "higher_order=True, apply_srgb_to_linear=True) before fitting.", + self.camera_id, self.frame_id, + ) + simple_bake( + baked_model, + post_processing, + camera_id=self.camera_id, + frame_id=self.frame_id, + higher_order=True, + apply_srgb_to_linear=True, + ) + def log_context(self) -> str: return f" camera={self.camera_id} frame={self.frame_id} vignetting={self.vignetting_mode}" diff --git a/threedgrut/export/usd/post_processing_sh_simple_bake.py b/threedgrut/export/usd/post_processing_sh_simple_bake.py index 58ba76f2..987af401 100644 --- a/threedgrut/export/usd/post_processing_sh_simple_bake.py +++ b/threedgrut/export/usd/post_processing_sh_simple_bake.py @@ -22,6 +22,7 @@ import torch from ppisp import PPISP, ppisp_apply +from threedgrut.utils.post_processing_linear_to_srgb import srgb_to_linear from threedgrut.utils.render import RGB2SH, SH2RGB @@ -119,10 +120,25 @@ def simple_bake( camera_id: int, frame_id: int, higher_order: bool = False, + apply_srgb_to_linear: bool = False, ) -> Tuple[float, torch.Tensor]: - """Mutate SH coefficients with one fixed PPISP camera/frame transform.""" + """Mutate SH coefficients with one fixed PPISP camera/frame transform. + + PPISP outputs display-referred values (its CRF folds in gamma-like + encoding). Storing those directly in linear SH coefficients leaves the + asset double-encoded: downstream consumers that themselves apply a + linear→sRGB step (``threedgrut/utils/post_processing_linear_to_srgb``, + Kit's tonemap, etc.) will gamma-correct on top of an already-encoded + image. ``apply_srgb_to_linear=True`` runs an inverse sRGB on the PPISP + output before ``RGB2SH`` so the SH coefficients land in linear scene- + referred space and a downstream ``linear_to_srgb`` undoes the + transformation cleanly. + """ exposure, color = get_fixed_frame_params(ppisp, frame_id) + def _maybe_srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor: + return srgb_to_linear(rgb) if apply_srgb_to_linear else rgb + if higher_order: with torch.enable_grad(): dc_rgb_linear = SH2RGB(model.features_albedo).detach() @@ -134,7 +150,7 @@ def simple_bake( color=color, ) with torch.no_grad(): - model.features_albedo.copy_(RGB2SH(dc_rgb_baked)) + model.features_albedo.copy_(RGB2SH(_maybe_srgb_to_linear(dc_rgb_baked))) _apply_jacobian_to_specular(model.features_specular, jacobian) else: with torch.no_grad(): @@ -146,6 +162,6 @@ def simple_bake( exposure=exposure, color=color, ) - model.features_albedo.copy_(RGB2SH(dc_rgb_baked)) + model.features_albedo.copy_(RGB2SH(_maybe_srgb_to_linear(dc_rgb_baked))) return exposure, color diff --git a/threedgrut/utils/post_processing_linear_to_srgb.py b/threedgrut/utils/post_processing_linear_to_srgb.py index 93c2000c..657e8245 100644 --- a/threedgrut/utils/post_processing_linear_to_srgb.py +++ b/threedgrut/utils/post_processing_linear_to_srgb.py @@ -66,6 +66,34 @@ def linear_to_srgb(x: torch.Tensor) -> torch.Tensor: ) +def srgb_to_linear(x: torch.Tensor) -> torch.Tensor: + """Inverse of :func:`linear_to_srgb`: sRGB encoded values back to linear. + + Piecewise IEC 61966-2-1 with break point at ``0.04045``: + + .. code-block:: python + + np.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + Round-trips :func:`linear_to_srgb` to fp32 epsilon for ``x`` in [0, 1]; + HDR values (``x > 1``) are passed through the upper branch identically + to the encode side. + + Args: + x: sRGB-encoded tensor (any shape). + + Returns: + Linear values, same shape / dtype / device as ``x``. + """ + limit = 0.04045 + positive_x = torch.clamp(x + 0.055, min=1e-08) + return torch.where( + x < limit, + x / 12.92, + torch.pow(positive_x / 1.055, 2.4), + ) + + class LinearToSrgbPostProcessing(nn.Module): """``nn.Module`` wrapper so linear-to-sRGB can plug into the shared post-processing path. From 066aaddca77263d1558a6d38e1bcd59dd9ab436e Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 13:46:13 +0000 Subject: [PATCH 30/42] feat(bake): interpolated-view sampling for the SH bake fit loop Adds two non-training view samplers to bake_post_processing_into_sh, on top of the existing "iterate the train dataloader" mode: * random-pair-slerp -- pick two distinct training views uniformly at random and slerp between them at random s in [0, 1]. No global structure but cheap. * trajectory -- order the training views along an approximate Hamiltonian path (nearest-neighbour seed + 2-opt refinement) using a position+direction distance, arc-length-parameterise the path on [0, 1], and per step sample a random t and slerp inside the bracketing segment. Closer to the camera continuum a viewer would fly through; better generalisation than discrete training poses. New module ``post_processing_view_interpolation``: - slerp_pose(pose_a, pose_b, s) -- quaternion slerp + translation lerp, fp32-clean to within fp32 epsilon at s=0 / s=1. - order_views_along_trajectory(poses, ...) -- NN + 2-opt with a weighted (position L2 + 1 - cos(forward angle)) metric. Position distances are mean-normalised so the rotation term lives on a comparable scale across scenes. - InterpolatedViewSampler -- wraps a template Batch and emits steps_per_epoch synthetic Batches with only T_to_world replaced. bake_post_processing_into_sh now accepts: view_sampling_mode: "training" | "random-pair-slerp" | "trajectory" interpolated_views_seed: optional int RNG seed trajectory_weight_position / trajectory_weight_rotation: trajectory metric weights. Wired through USDExporter constructor + from_config (YAML keys post-processing-bake-view-mode / -view-seed / -trajectory-weight-{position,rotation}) and the export_usd.py CLI (--post-processing-bake-view-mode etc.). Default remains "training" so existing exports are unchanged. Sanity checked on a 6-pose synthetic circle: shuffled inputs reorder to a Hamiltonian path with uniform 0.2 arc-length steps; existing 19 export tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 57 +++ threedgrut/export/usd/exporter.py | 53 +++ .../export/usd/post_processing_sh_bake.py | 66 ++- .../usd/post_processing_view_interpolation.py | 376 ++++++++++++++++++ 4 files changed, 546 insertions(+), 6 deletions(-) create mode 100644 threedgrut/export/usd/post_processing_view_interpolation.py diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 2a11ce58..2e723b04 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -214,6 +214,35 @@ def parse_args(): "'achromatic-fit' uses chromatic PPISP reference and an achromatic fit-only vignette." ), ) + parser.add_argument( + "--post-processing-bake-view-mode", + type=str, + choices=["training", "random-pair-slerp", "trajectory"], + default=None, + help=( + "Which views the bake fit sees per step. 'training' (default) iterates the train " + "dataloader. 'random-pair-slerp' picks two random training views and slerps. " + "'trajectory' orders views along an NN+2-opt camera path and samples random t in [0,1]." + ), + ) + parser.add_argument( + "--post-processing-bake-view-seed", + type=int, + default=None, + help="Optional RNG seed for the interpolation samplers (None = non-deterministic).", + ) + parser.add_argument( + "--post-processing-bake-trajectory-weight-position", + type=float, + default=None, + help="Trajectory mode only: weight on the (mean-normalised) position term in pose distance.", + ) + parser.add_argument( + "--post-processing-bake-trajectory-weight-rotation", + type=float, + default=None, + help="Trajectory mode only: weight on the (1 - cos(angle)) rotation term in pose distance.", + ) # Dataset path (optional, overrides checkpoint's dataset path) parser.add_argument( @@ -468,6 +497,34 @@ def main(): "ppisp_bake_vignetting_mode", "achromatic-fit", ), + post_processing_bake_view_mode=_arg_or_conf( + args.post_processing_bake_view_mode, + export_conf, + "post-processing-bake-view-mode", + "post_processing_bake_view_mode", + "training", + ), + post_processing_bake_view_seed=_arg_or_conf( + args.post_processing_bake_view_seed, + export_conf, + "post-processing-bake-view-seed", + "post_processing_bake_view_seed", + None, + ), + post_processing_bake_trajectory_weight_position=_arg_or_conf( + args.post_processing_bake_trajectory_weight_position, + export_conf, + "post-processing-bake-trajectory-weight-position", + "post_processing_bake_trajectory_weight_position", + 1.0, + ), + post_processing_bake_trajectory_weight_rotation=_arg_or_conf( + args.post_processing_bake_trajectory_weight_rotation, + export_conf, + "post-processing-bake-trajectory-weight-rotation", + "post_processing_bake_trajectory_weight_rotation", + 0.5, + ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) logger.info("Using ParticleField3DGaussianSplat schema (standard)") diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 2e6ae790..e1a27f66 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -289,6 +289,10 @@ def __init__( post_processing_bake_camera_id: int = 0, post_processing_bake_frame_id: int = 0, ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + post_processing_bake_view_mode: str = "training", + post_processing_bake_view_seed: int | None = None, + post_processing_bake_trajectory_weight_position: float = 1.0, + post_processing_bake_trajectory_weight_rotation: float = 0.5, frames_per_second: float = 1.0, ): """ @@ -324,6 +328,17 @@ def __init__( ppisp_bake_vignetting_mode: "none" disables vignetting in the PPISP reference. "achromatic-fit" keeps chromatic PPISP vignetting in the reference and applies an achromatic estimate only in the fit loss. + post_processing_bake_view_mode: which views the bake fit sees per step. + "training" iterates the train dataloader (default). "random-pair-slerp" + samples two random training views and slerps between them. "trajectory" + orders the training views along an NN+2-opt camera path, parameterises + arc-length on [0, 1], and samples a random t per step. + post_processing_bake_view_seed: optional RNG seed for the interpolation + samplers. None (default) leaves it non-deterministic. + post_processing_bake_trajectory_weight_position: trajectory mode only. + Weight on the (mean-normalised) position term in the pose distance. + post_processing_bake_trajectory_weight_rotation: trajectory mode only. + Weight on the (1 - cos(angle)) rotation term in the pose distance. frames_per_second: Sets stage.timeCodesPerSecond. Time codes are always bare frame indices (float(frame_idx)), so this controls playback speed. Default 1.0 means 1 frame per second of real time. @@ -352,6 +367,16 @@ def __init__( self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) self.post_processing_bake_frame_id = int(post_processing_bake_frame_id) self.ppisp_bake_vignetting_mode = str(ppisp_bake_vignetting_mode) + self.post_processing_bake_view_mode = str(post_processing_bake_view_mode) + self.post_processing_bake_view_seed = ( + None if post_processing_bake_view_seed is None else int(post_processing_bake_view_seed) + ) + self.post_processing_bake_trajectory_weight_position = float( + post_processing_bake_trajectory_weight_position + ) + self.post_processing_bake_trajectory_weight_rotation = float( + post_processing_bake_trajectory_weight_rotation + ) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -435,6 +460,10 @@ def export( adapter=adapter, epochs=self.post_processing_bake_epochs, learning_rate=self.post_processing_bake_learning_rate, + view_sampling_mode=self.post_processing_bake_view_mode, + interpolated_views_seed=self.post_processing_bake_view_seed, + trajectory_weight_position=self.post_processing_bake_trajectory_weight_position, + trajectory_weight_rotation=self.post_processing_bake_trajectory_weight_rotation, ) if uses_omni_native_post_processing_export and not has_ppisp_module: raise ValueError("Omniverse-native post-processing export currently supports PPISP post-processing only.") @@ -847,5 +876,29 @@ def from_config(cls, conf) -> "USDExporter": "ppisp_bake_vignetting_mode", MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, ), + post_processing_bake_view_mode=_get_export_config_value( + export_conf, + "post-processing-bake-view-mode", + "post_processing_bake_view_mode", + "training", + ), + post_processing_bake_view_seed=_get_export_config_value( + export_conf, + "post-processing-bake-view-seed", + "post_processing_bake_view_seed", + None, + ), + post_processing_bake_trajectory_weight_position=_get_export_config_value( + export_conf, + "post-processing-bake-trajectory-weight-position", + "post_processing_bake_trajectory_weight_position", + 1.0, + ), + post_processing_bake_trajectory_weight_rotation=_get_export_config_value( + export_conf, + "post-processing-bake-trajectory-weight-rotation", + "post_processing_bake_trajectory_weight_rotation", + 0.5, + ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index 46ee0738..ef01590e 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -101,8 +101,34 @@ def bake_post_processing_into_sh( epochs: int = 1, learning_rate: float = 1.0e-3, device: str = "cuda", + view_sampling_mode: str = "training", + interpolated_views_seed: int | None = None, + trajectory_weight_position: float = 1.0, + trajectory_weight_rotation: float = 0.5, ): - """Return a cloned model whose SH coefficients approximate fixed post-processing output.""" + """Return a cloned model whose SH coefficients approximate fixed post-processing output. + + ``view_sampling_mode`` controls what the optimizer sees each step: + + * ``"training"`` (default) -- iterate the training dataloader as usual. + * ``"random-pair-slerp"`` -- pick two random training views and slerp + between them at a random ``s ∈ [0, 1]``. + * ``"trajectory"`` -- order the training views along an approximate + Hamiltonian path (NN + 2-opt on a position+direction metric), + arc-length-parameterise the path on ``[0, 1]``, sample random + ``t ∈ [0, 1]``, slerp inside the bracketing segment. + + The interpolated-view modes synthesise a ``Batch`` per step from the + template of the first training batch, replacing ``T_to_world`` with + the interpolated pose. ``steps_per_epoch`` matches + ``len(train_dataloader)`` so total step count is unchanged. + """ + from threedgrut.export.usd.post_processing_view_interpolation import ( + InterpolatedViewSampler, + VIEW_SAMPLING_TRAINING, + normalize_view_sampling_mode, + ) + if not hasattr(model, "clone"): raise TypeError("Post-processing SH bake export requires a cloneable MixtureOfGaussians model.") if train_dataset is None: @@ -111,6 +137,7 @@ def bake_post_processing_into_sh( raise ValueError("Post-processing SH bake export requires a post_processing module.") if epochs < 1: raise ValueError(f"epochs must be >= 1, got {epochs}.") + view_sampling_mode = normalize_view_sampling_mode(view_sampling_mode) adapter.validate(post_processing) reference_model = model.to(device).eval() @@ -129,21 +156,48 @@ def bake_post_processing_into_sh( fit_parameters = list(_set_sh_fit_parameters(baked_model)) optimizer = torch.optim.Adam(fit_parameters, lr=learning_rate) train_dataloader = _create_train_dataloader(conf, train_dataset) + steps_per_epoch = len(train_dataloader) + + sampler: InterpolatedViewSampler | None = None + if view_sampling_mode != VIEW_SAMPLING_TRAINING: + # Cache one real training batch to seed the synthetic sampler with + # valid intrinsics / rays / pixel coords; only T_to_world rotates + # per step. + first_batch = next(iter(train_dataloader)) + template = train_dataset.get_gpu_batch_with_intrinsics(first_batch) + sampler = InterpolatedViewSampler( + train_dataset, + template_gpu_batch=template, + mode=view_sampling_mode, + steps_per_epoch=steps_per_epoch, + seed=interpolated_views_seed, + weight_position=trajectory_weight_position, + weight_rotation=trajectory_weight_rotation, + ) logger.info( - "Fitting %s SH bake on train split: epochs=%s frames_per_epoch=%s%s", + "Fitting %s SH bake: mode=%s epochs=%s steps_per_epoch=%s%s", adapter.name, + view_sampling_mode, epochs, - len(train_dataloader), + steps_per_epoch, adapter.log_context(), ) + + def _gpu_batches(): + if sampler is None: + for batch in train_dataloader: + yield train_dataset.get_gpu_batch_with_intrinsics(batch) + else: + for gpu_batch in sampler: + yield gpu_batch + with torch.enable_grad(): global_step = 0 - total_steps = epochs * len(train_dataloader) + total_steps = epochs * steps_per_epoch for epoch in range(epochs): - for batch in train_dataloader: + for gpu_batch in _gpu_batches(): global_step += 1 - gpu_batch = train_dataset.get_gpu_batch_with_intrinsics(batch) reference_rgb = _render_reference(reference_model, fixed_post_processing, gpu_batch) optimizer.zero_grad(set_to_none=True) diff --git a/threedgrut/export/usd/post_processing_view_interpolation.py b/threedgrut/export/usd/post_processing_view_interpolation.py new file mode 100644 index 00000000..3ab5c072 --- /dev/null +++ b/threedgrut/export/usd/post_processing_view_interpolation.py @@ -0,0 +1,376 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""View samplers for SH-bake fitting. + +The default fit loop iterates the training dataloader, so the optimizer +only sees the discrete set of training poses. This module adds two +interpolation-based samplers: + +* ``random-pair-slerp`` -- pick two distinct training views uniformly at + random, slerp between them at a random ``s in [0, 1]``. Cheap, no + global structure. + +* ``trajectory`` -- order the training views along a smooth path using + nearest-neighbour + 2-opt with a position+direction distance, then + arc-length-parameterise the path on ``[0, 1]``. Each sample picks a + random ``t in [0, 1]``, locates the bracketing pair, and slerps inside + the segment. Closer to the kind of camera continuum a viewer would + fly through; better for fitting a residual that is supposed to + generalise to nearby novel views. + +Both samplers reuse the dataset's per-intrinsic camera-space rays and +pixel-coordinate grid -- only ``T_to_world`` changes per sample. PPISP's +``FixedPPISP`` ignores the per-frame indices on the synthetic batch, so +camera/frame indices on the template are kept as-is. +""" + +from __future__ import annotations + +import logging +import math +from dataclasses import replace +from typing import Iterator, List, Optional, Tuple + +import numpy as np +import torch + +from threedgrut.datasets.protocols import Batch + +logger = logging.getLogger(__name__) + + +VIEW_SAMPLING_TRAINING = "training" +VIEW_SAMPLING_RANDOM_PAIR_SLERP = "random-pair-slerp" +VIEW_SAMPLING_TRAJECTORY = "trajectory" +VIEW_SAMPLING_MODES = { + VIEW_SAMPLING_TRAINING, + VIEW_SAMPLING_RANDOM_PAIR_SLERP, + VIEW_SAMPLING_TRAJECTORY, +} + + +def normalize_view_sampling_mode(mode: Optional[str]) -> str: + normalized = VIEW_SAMPLING_TRAINING if mode is None else str(mode).strip().lower() + if normalized not in VIEW_SAMPLING_MODES: + raise ValueError( + f"Unsupported view sampling mode '{mode}'. " + f"Expected one of: {sorted(VIEW_SAMPLING_MODES)}" + ) + return normalized + + +# --------------------------------------------------------------------------- +# Pose interpolation primitives (numpy, double precision for stability) +# --------------------------------------------------------------------------- + + +def _R_to_quat(R: np.ndarray) -> np.ndarray: + """3x3 rotation -> unit quaternion [w, x, y, z] (Shepperd's method).""" + R = np.asarray(R, dtype=np.float64) + trace = R[0, 0] + R[1, 1] + R[2, 2] + if trace > 0.0: + s = math.sqrt(trace + 1.0) * 2.0 + qw = 0.25 * s + qx = (R[2, 1] - R[1, 2]) / s + qy = (R[0, 2] - R[2, 0]) / s + qz = (R[1, 0] - R[0, 1]) / s + elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]: + s = math.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) * 2.0 + qw = (R[2, 1] - R[1, 2]) / s + qx = 0.25 * s + qy = (R[0, 1] + R[1, 0]) / s + qz = (R[0, 2] + R[2, 0]) / s + elif R[1, 1] > R[2, 2]: + s = math.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) * 2.0 + qw = (R[0, 2] - R[2, 0]) / s + qx = (R[0, 1] + R[1, 0]) / s + qy = 0.25 * s + qz = (R[1, 2] + R[2, 1]) / s + else: + s = math.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) * 2.0 + qw = (R[1, 0] - R[0, 1]) / s + qx = (R[0, 2] + R[2, 0]) / s + qy = (R[1, 2] + R[2, 1]) / s + qz = 0.25 * s + q = np.array([qw, qx, qy, qz], dtype=np.float64) + return q / np.linalg.norm(q) + + +def _quat_to_R(q: np.ndarray) -> np.ndarray: + qw, qx, qy, qz = (q / np.linalg.norm(q)).tolist() + return np.array([ + [1 - 2*(qy*qy + qz*qz), 2*(qx*qy - qz*qw), 2*(qx*qz + qy*qw)], + [2*(qx*qy + qz*qw), 1 - 2*(qx*qx + qz*qz), 2*(qy*qz - qx*qw)], + [2*(qx*qz - qy*qw), 2*(qy*qz + qx*qw), 1 - 2*(qx*qx + qy*qy)], + ], dtype=np.float64) + + +def _slerp_quat(q0: np.ndarray, q1: np.ndarray, s: float) -> np.ndarray: + """Standard quaternion slerp; falls back to lerp+normalise when nearly parallel.""" + q0 = q0 / np.linalg.norm(q0) + q1 = q1 / np.linalg.norm(q1) + d = float(np.dot(q0, q1)) + if d < 0.0: # take the short arc + q1 = -q1 + d = -d + if d > 0.9995: + out = q0 + s * (q1 - q0) + return out / np.linalg.norm(out) + theta = math.acos(max(min(d, 1.0), -1.0)) + sin_theta = math.sin(theta) + a = math.sin((1.0 - s) * theta) / sin_theta + b = math.sin(s * theta) / sin_theta + return a * q0 + b * q1 + + +def slerp_pose(pose_a: np.ndarray, pose_b: np.ndarray, s: float) -> np.ndarray: + """Interpolate a 4x4 c2w pose between ``pose_a`` and ``pose_b`` at ``s in [0, 1]``. + + Rotation: quaternion slerp. Translation: linear lerp. Lower row is left as + ``[0, 0, 0, 1]``. + """ + s = float(np.clip(s, 0.0, 1.0)) + pose_a = np.asarray(pose_a, dtype=np.float64) + pose_b = np.asarray(pose_b, dtype=np.float64) + q_a = _R_to_quat(pose_a[:3, :3]) + q_b = _R_to_quat(pose_b[:3, :3]) + q = _slerp_quat(q_a, q_b, s) + R = _quat_to_R(q) + t = (1.0 - s) * pose_a[:3, 3] + s * pose_b[:3, 3] + out = np.eye(4, dtype=np.float64) + out[:3, :3] = R + out[:3, 3] = t + return out + + +# --------------------------------------------------------------------------- +# Trajectory ordering: nearest-neighbour + 2-opt on a position+direction metric +# --------------------------------------------------------------------------- + + +def _pose_distance_matrix( + poses: np.ndarray, + weight_position: float, + weight_rotation: float, +) -> np.ndarray: + """``D[i, j]`` = weighted (position L2 + 1 - cos(forward angle)).""" + n = poses.shape[0] + pos = poses[:, :3, 3] # (N, 3) + fwd = poses[:, :3, 2] # (N, 3) RDF: +Z = forward + fwd = fwd / np.maximum(np.linalg.norm(fwd, axis=1, keepdims=True), 1e-12) + + # vectorised pairwise position distance + diff = pos[:, None, :] - pos[None, :, :] + d_pos = np.linalg.norm(diff, axis=2) + # normalise by mean pairwise so the rotation term lives on a comparable scale + mean_pos = max(float(d_pos[d_pos > 0].mean()) if (d_pos > 0).any() else 1.0, 1e-9) + + cos_ang = np.clip(fwd @ fwd.T, -1.0, 1.0) + d_rot = 1.0 - cos_ang # in [0, 2] + + return weight_position * (d_pos / mean_pos) + weight_rotation * d_rot + + +def _nearest_neighbour_order(D: np.ndarray, start: int = 0) -> List[int]: + n = D.shape[0] + visited = [False] * n + order = [start] + visited[start] = True + while len(order) < n: + last = order[-1] + # mask visited with +inf + candidates = np.where(visited, np.inf, D[last]) + nxt = int(np.argmin(candidates)) + order.append(nxt) + visited[nxt] = True + return order + + +def _two_opt(order: List[int], D: np.ndarray, max_passes: int = 50) -> List[int]: + """In-place 2-opt swap loop. Stops when a full pass yields no improvement + or when ``max_passes`` is reached.""" + n = len(order) + if n < 4: + return order + for _ in range(max_passes): + improved = False + for i in range(1, n - 2): + for j in range(i + 1, n - 1): + a, b = order[i - 1], order[i] + c, d = order[j], order[j + 1] + # original edges (a,b) + (c,d) + # candidate after reverse: (a,c) + (b,d) + if D[a, c] + D[b, d] + 1e-12 < D[a, b] + D[c, d]: + order[i:j + 1] = order[i:j + 1][::-1] + improved = True + if not improved: + break + return order + + +def order_views_along_trajectory( + poses: np.ndarray, + *, + weight_position: float = 1.0, + weight_rotation: float = 0.5, + start_index: int = 0, + two_opt_passes: int = 50, +) -> Tuple[List[int], np.ndarray]: + """Order ``poses`` along an approximate Hamiltonian path. + + Returns ``(ordered_indices, cum_t)`` where ``cum_t[k] in [0, 1]`` is the + arc-length parameter at the k-th ordered pose. ``cum_t[0] = 0`` and + ``cum_t[-1] = 1``. + """ + poses = np.asarray(poses, dtype=np.float64) + if poses.ndim != 3 or poses.shape[-2:] != (4, 4): + raise ValueError(f"poses must be (N, 4, 4), got {poses.shape}") + n = poses.shape[0] + if n < 2: + return list(range(n)), np.zeros(max(n, 1), dtype=np.float64) + + D = _pose_distance_matrix(poses, weight_position, weight_rotation) + order = _nearest_neighbour_order(D, start=start_index) + order = _two_opt(order, D, max_passes=two_opt_passes) + + cum = np.zeros(n, dtype=np.float64) + for k in range(1, n): + cum[k] = cum[k - 1] + D[order[k - 1], order[k]] + if cum[-1] > 0: + cum = cum / cum[-1] + return order, cum + + +# --------------------------------------------------------------------------- +# Sampler driving the fit loop +# --------------------------------------------------------------------------- + + +class InterpolatedViewSampler: + """Yields ``Batch`` objects with synthetic interpolated poses. + + The sampler grabs one template batch from the training dataset to + cache the per-intrinsic camera-space rays, pixel coords and any + intrinsic dictionaries; only ``T_to_world`` (and ``T_to_world_end``, + which we set to the same pose -- no rolling shutter on synthetic + poses) changes per sample. + + Args: + train_dataset: must implement + :meth:`~threedgrut.datasets.protocols.BoundedMultiViewDataset.get_poses` + and :meth:`get_gpu_batch_with_intrinsics`. + mode: ``"random-pair-slerp"`` or ``"trajectory"``. + steps_per_epoch: how many synthetic batches to emit per pass. + seed: optional RNG seed for reproducibility. + weight_position / weight_rotation: trajectory mode only. + start_index: trajectory mode only. + """ + + def __init__( + self, + train_dataset, + template_gpu_batch: Batch, + mode: str, + steps_per_epoch: int, + *, + seed: Optional[int] = None, + weight_position: float = 1.0, + weight_rotation: float = 0.5, + start_index: int = 0, + ) -> None: + mode = normalize_view_sampling_mode(mode) + if mode == VIEW_SAMPLING_TRAINING: + raise ValueError("InterpolatedViewSampler is only for non-training modes.") + if not hasattr(train_dataset, "get_poses"): + raise TypeError( + "InterpolatedViewSampler requires a dataset exposing get_poses(); " + f"got {type(train_dataset).__name__}." + ) + if not isinstance(template_gpu_batch, Batch): + raise TypeError("template_gpu_batch must be a threedgrut Batch instance.") + self.dataset = train_dataset + self.mode = mode + self.steps_per_epoch = int(steps_per_epoch) + self._rng = np.random.default_rng(seed) + self._template = template_gpu_batch + + poses = np.asarray(train_dataset.get_poses(), dtype=np.float64) + if poses.ndim != 3 or poses.shape[-2:] != (4, 4): + raise ValueError(f"dataset.get_poses() must be (N, 4, 4), got {poses.shape}") + if poses.shape[0] < 2: + raise ValueError("Need at least 2 training views to interpolate.") + self._poses = poses + + if mode == VIEW_SAMPLING_TRAJECTORY: + self._ordered_indices, self._cum_t = order_views_along_trajectory( + poses, + weight_position=weight_position, + weight_rotation=weight_rotation, + start_index=start_index, + ) + logger.info( + "Built %d-view trajectory (NN + 2-opt) for SH-bake interpolation.", + len(self._ordered_indices), + ) + else: + self._ordered_indices = None + self._cum_t = None + + # ------------------------------------------------------------------ + # Pose sampling + # ------------------------------------------------------------------ + + def _sample_pose_random_pair(self) -> np.ndarray: + n = self._poses.shape[0] + i = int(self._rng.integers(0, n)) + j = int(self._rng.integers(0, n - 1)) + if j >= i: + j += 1 # ensures j != i without bias + s = float(self._rng.random()) + return slerp_pose(self._poses[i], self._poses[j], s) + + def _sample_pose_trajectory(self) -> np.ndarray: + t = float(self._rng.random()) + cum = self._cum_t + # Find segment k s.t. cum[k-1] <= t <= cum[k] (with cum[0]=0). + k = int(np.searchsorted(cum, t, side="left")) + k = max(1, min(k, len(cum) - 1)) + denom = max(cum[k] - cum[k - 1], 1e-12) + local_s = float((t - cum[k - 1]) / denom) + a = self._ordered_indices[k - 1] + b = self._ordered_indices[k] + return slerp_pose(self._poses[a], self._poses[b], local_s) + + def _sample_pose(self) -> np.ndarray: + if self.mode == VIEW_SAMPLING_RANDOM_PAIR_SLERP: + return self._sample_pose_random_pair() + return self._sample_pose_trajectory() + + # ------------------------------------------------------------------ + # Batch construction + # ------------------------------------------------------------------ + + def _make_batch(self, pose_np: np.ndarray) -> Batch: + device = self._template.T_to_world.device + dtype = self._template.T_to_world.dtype + T = torch.from_numpy(pose_np).to(device=device, dtype=dtype).unsqueeze(0) + # Same pose for start and end -- no rolling shutter on synthetic views. + return replace(self._template, T_to_world=T, T_to_world_end=T) + + # ------------------------------------------------------------------ + # Iterator protocol + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[Batch]: + for _ in range(self.steps_per_epoch): + yield self._make_batch(self._sample_pose()) + + def __len__(self) -> int: + return self.steps_per_epoch From c2ba74c27ed39df3e6c51c8c8b5f9008a6dc6994 Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 14:52:14 +0000 Subject: [PATCH 31/42] fix(bake): match PPISP fit reference encoding (linear -> sRGB + clamp) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PPISP forward is display-referred, so the fitted SH side must encode to sRGB and clamp before MSE — without this the loss plateaued near 13 dB on real scenes. Mirrors post_processing_sh_bake_validation.py::_fitBakedSh. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/usd/post_processing_sh_bake.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index ef01590e..6e8c7637 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -367,16 +367,20 @@ def create_fixed_post_processing(self, post_processing: nn.Module, device: str) ).eval() def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Module, gpu_batch) -> torch.Tensor: + # PPISP reference is display-referred; baked side must encode to sRGB to match. + from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb + if self.vignetting_mode == MODE_PPISP_BAKE_VIGNETTING_NONE: - return rgb + return torch.clamp(linear_to_srgb(rgb), 0.0, 1.0) _, height, width, _ = rgb.shape - return apply_achromatic_vignetting( + vignetted = apply_achromatic_vignetting( rgb=rgb, ppisp=fixed_post_processing.ppisp, camera_id=fixed_post_processing.camera_id, pixel_coords=gpu_batch.pixel_coords, resolution=(width, height), ) + return torch.clamp(linear_to_srgb(vignetted), 0.0, 1.0) def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: """Warm-start with the higher-order simple-bake on the chosen From 872f50a89d282e0b47c5ad5133474aa3f18a5ca6 Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 18:34:20 +0000 Subject: [PATCH 32/42] fix(bake): compose Jacobian through srgb_to_linear and clip outliers Two bugs in simple_bake's higher_order=True path under apply_srgb_to_linear=True: 1. The Jacobian was dPPISP/dX while the DC bake landed in linear space via srgb_to_linear; the chain rule was missing srgb_to_linear'(PPISP). The composed Jacobian now matches the DC color space. 2. ~0.06% of Gaussians at PPISP-saturation extremes have pathological Jacobians (cond > 1e8, |J|_F > 1e4) that pump features_specular norms from O(1) to O(10^4). These outliers dominate Adam's adaptive variance and stall the fit at ~29 dB. Clipping |J|_F > 5 (and any non-finite J) to identity preserves the rotation for the well-behaved 99.94% while leaving the trained specular intact for outliers. On bonsai (1M Gaussians, 9-epoch fit), the higher-order PPISP warm-start went from 25.4 dB to 36.4 dB -- now within 0.2 dB of the DC-only sRGB warm-start (36.6 dB). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../usd/post_processing_sh_simple_bake.py | 49 +++++++++++++++++-- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/threedgrut/export/usd/post_processing_sh_simple_bake.py b/threedgrut/export/usd/post_processing_sh_simple_bake.py index 987af401..a9267d80 100644 --- a/threedgrut/export/usd/post_processing_sh_simple_bake.py +++ b/threedgrut/export/usd/post_processing_sh_simple_bake.py @@ -17,6 +17,7 @@ from __future__ import annotations +import logging from typing import Tuple import torch @@ -25,6 +26,15 @@ from threedgrut.utils.post_processing_linear_to_srgb import srgb_to_linear from threedgrut.utils.render import RGB2SH, SH2RGB +logger = logging.getLogger(__name__) + +# A handful of Gaussians sit at PPISP-saturation extremes where the chain +# rule through pow() blows up (cond(J) > 1e8, |J|_F > 1e4). Their rotated +# specular norms grow by 4+ orders of magnitude and dominate Adam's +# variance estimate, stalling the fit. Past p99 of |J|_F (~3.4 on bonsai), +# rotations are unreliable; ``5.0`` keeps a small safety margin. +JACOBIAN_FRO_NORM_CLIP = 5.0 + def get_fixed_frame_params( ppisp: PPISP, @@ -76,16 +86,23 @@ def _bake_dc_with_jacobian_through_ppisp( camera_id: int, exposure: float, color: torch.Tensor, + apply_srgb_to_linear: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Run PPISP forward and extract per-Gaussian RGB Jacobians.""" + """Run PPISP forward and extract per-Gaussian RGB Jacobians. + + When ``apply_srgb_to_linear`` is True, both the returned RGB and the + Jacobian correspond to ``srgb_to_linear(PPISP(X))`` so the higher-order + SH rotation stays in the same color space as the DC bake. + """ rgb_in = dc_rgb_linear.detach().clone().requires_grad_(True) - rgb_out = _bake_dc_through_ppisp( + rgb_ppisp = _bake_dc_through_ppisp( dc_rgb_linear=rgb_in, ppisp=ppisp, camera_id=camera_id, exposure=exposure, color=color, ) + rgb_out = srgb_to_linear(rgb_ppisp) if apply_srgb_to_linear else rgb_ppisp num_gaussians = rgb_in.shape[0] jacobian = torch.empty(num_gaussians, 3, 3, device=rgb_in.device, dtype=rgb_in.dtype) @@ -104,13 +121,32 @@ def _bake_dc_with_jacobian_through_ppisp( def _apply_jacobian_to_specular(features_specular: torch.nn.Parameter, jacobian: torch.Tensor) -> None: - """In-place linearization of higher-order SH coefficients by ``jacobian``.""" + """In-place linearization of higher-order SH coefficients by ``jacobian``. + + Gaussians whose Jacobian is non-finite or has Frobenius norm above + :data:`JACOBIAN_FRO_NORM_CLIP` keep their trained specular (i.e. J is + replaced by the identity for them) -- avoids polluting Adam's variance + estimate with rare PPISP-saturation outliers. + """ num_gaussians, total = features_specular.shape if total % 3 != 0: raise ValueError(f"features_specular last-dim ({total}) must be divisible by 3.") num_sh_coeffs = total // 3 specular_rgb = features_specular.view(num_gaussians, num_sh_coeffs, 3) - transformed = torch.einsum("nij,nkj->nki", jacobian, specular_rgb) + + j_fro = torch.linalg.norm(jacobian, ord="fro", dim=(1, 2)) + safe = torch.isfinite(j_fro) & (j_fro <= JACOBIAN_FRO_NORM_CLIP) + eye = torch.eye(3, device=jacobian.device, dtype=jacobian.dtype).expand_as(jacobian) + jacobian_safe = torch.where(safe[:, None, None], jacobian, eye) + n_clipped = int((~safe).sum().item()) + if n_clipped > 0: + logger.info( + "Jacobian rotation clipped on %d/%d gaussians (|J|_F > %.1f or non-finite); " + "their trained features_specular preserved.", + n_clipped, num_gaussians, JACOBIAN_FRO_NORM_CLIP, + ) + + transformed = torch.einsum("nij,nkj->nki", jacobian_safe, specular_rgb) specular_rgb.copy_(transformed) @@ -148,9 +184,12 @@ def _maybe_srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor: camera_id=camera_id, exposure=exposure, color=color, + apply_srgb_to_linear=apply_srgb_to_linear, ) with torch.no_grad(): - model.features_albedo.copy_(RGB2SH(_maybe_srgb_to_linear(dc_rgb_baked))) + # dc_rgb_baked already includes srgb_to_linear when requested, + # so RGB2SH gets the right color space directly. + model.features_albedo.copy_(RGB2SH(dc_rgb_baked)) _apply_jacobian_to_specular(model.features_specular, jacobian) else: with torch.no_grad(): From 32d07119f9b701c204bbf2914a07732d344a3cd4 Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 18:34:25 +0000 Subject: [PATCH 33/42] refactor(bake): drop random-pair-slerp view sampler Empirical sweep showed random-pair sampling was always within noise of training views (sometimes -0.8 dB) and never improved over the trajectory sampler. The trajectory mode is retained for sparse-view datasets where interpolating between adjacent training poses can help generalisation. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 6 +- threedgrut/export/usd/exporter.py | 6 +- .../export/usd/post_processing_sh_bake.py | 7 +- .../usd/post_processing_view_interpolation.py | 72 ++++++------------- 4 files changed, 31 insertions(+), 60 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 2e723b04..4731268f 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -217,12 +217,12 @@ def parse_args(): parser.add_argument( "--post-processing-bake-view-mode", type=str, - choices=["training", "random-pair-slerp", "trajectory"], + choices=["training", "trajectory"], default=None, help=( "Which views the bake fit sees per step. 'training' (default) iterates the train " - "dataloader. 'random-pair-slerp' picks two random training views and slerps. " - "'trajectory' orders views along an NN+2-opt camera path and samples random t in [0,1]." + "dataloader. 'trajectory' orders views along an NN+2-opt camera path and samples " + "random t in [0,1] -- useful when training views are sparse." ), ) parser.add_argument( diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index e1a27f66..c969f2ce 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -329,10 +329,10 @@ def __init__( reference. "achromatic-fit" keeps chromatic PPISP vignetting in the reference and applies an achromatic estimate only in the fit loss. post_processing_bake_view_mode: which views the bake fit sees per step. - "training" iterates the train dataloader (default). "random-pair-slerp" - samples two random training views and slerps between them. "trajectory" + "training" iterates the train dataloader (default). "trajectory" orders the training views along an NN+2-opt camera path, parameterises - arc-length on [0, 1], and samples a random t per step. + arc-length on [0, 1], and samples a random t per step (helpful when + training views are sparse). post_processing_bake_view_seed: optional RNG seed for the interpolation samplers. None (default) leaves it non-deterministic. post_processing_bake_trajectory_weight_position: trajectory mode only. diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index 6e8c7637..e69168b5 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -111,14 +111,13 @@ def bake_post_processing_into_sh( ``view_sampling_mode`` controls what the optimizer sees each step: * ``"training"`` (default) -- iterate the training dataloader as usual. - * ``"random-pair-slerp"`` -- pick two random training views and slerp - between them at a random ``s ∈ [0, 1]``. * ``"trajectory"`` -- order the training views along an approximate Hamiltonian path (NN + 2-opt on a position+direction metric), arc-length-parameterise the path on ``[0, 1]``, sample random - ``t ∈ [0, 1]``, slerp inside the bracketing segment. + ``t ∈ [0, 1]``, slerp inside the bracketing segment. Helpful on + datasets with sparse view coverage. - The interpolated-view modes synthesise a ``Batch`` per step from the + The trajectory mode synthesises a ``Batch`` per step from the template of the first training batch, replacing ``T_to_world`` with the interpolated pose. ``steps_per_epoch`` matches ``len(train_dataloader)`` so total step count is unchanged. diff --git a/threedgrut/export/usd/post_processing_view_interpolation.py b/threedgrut/export/usd/post_processing_view_interpolation.py index 3ab5c072..223720fe 100644 --- a/threedgrut/export/usd/post_processing_view_interpolation.py +++ b/threedgrut/export/usd/post_processing_view_interpolation.py @@ -10,22 +10,14 @@ """View samplers for SH-bake fitting. The default fit loop iterates the training dataloader, so the optimizer -only sees the discrete set of training poses. This module adds two -interpolation-based samplers: - -* ``random-pair-slerp`` -- pick two distinct training views uniformly at - random, slerp between them at a random ``s in [0, 1]``. Cheap, no - global structure. - -* ``trajectory`` -- order the training views along a smooth path using - nearest-neighbour + 2-opt with a position+direction distance, then - arc-length-parameterise the path on ``[0, 1]``. Each sample picks a - random ``t in [0, 1]``, locates the bracketing pair, and slerps inside - the segment. Closer to the kind of camera continuum a viewer would - fly through; better for fitting a residual that is supposed to - generalise to nearby novel views. - -Both samplers reuse the dataset's per-intrinsic camera-space rays and +only sees the discrete set of training poses. The ``trajectory`` sampler +orders the training views along a smooth path (nearest-neighbour + 2-opt +on position+direction), arc-length-parameterises the path on ``[0, 1]``, +then samples random ``t in [0, 1]`` and slerps inside the bracketing +segment. Useful when training views are sparse and a residual fit needs +to generalise to nearby novel views. + +The sampler reuses the dataset's per-intrinsic camera-space rays and pixel-coordinate grid -- only ``T_to_world`` changes per sample. PPISP's ``FixedPPISP`` ignores the per-frame indices on the synthetic batch, so camera/frame indices on the template are kept as-is. @@ -47,11 +39,9 @@ VIEW_SAMPLING_TRAINING = "training" -VIEW_SAMPLING_RANDOM_PAIR_SLERP = "random-pair-slerp" VIEW_SAMPLING_TRAJECTORY = "trajectory" VIEW_SAMPLING_MODES = { VIEW_SAMPLING_TRAINING, - VIEW_SAMPLING_RANDOM_PAIR_SLERP, VIEW_SAMPLING_TRAJECTORY, } @@ -266,11 +256,11 @@ class InterpolatedViewSampler: train_dataset: must implement :meth:`~threedgrut.datasets.protocols.BoundedMultiViewDataset.get_poses` and :meth:`get_gpu_batch_with_intrinsics`. - mode: ``"random-pair-slerp"`` or ``"trajectory"``. + mode: only ``"trajectory"`` is supported. steps_per_epoch: how many synthetic batches to emit per pass. seed: optional RNG seed for reproducibility. - weight_position / weight_rotation: trajectory mode only. - start_index: trajectory mode only. + weight_position / weight_rotation: trajectory distance weights. + start_index: trajectory NN seed index. """ def __init__( @@ -308,35 +298,22 @@ def __init__( raise ValueError("Need at least 2 training views to interpolate.") self._poses = poses - if mode == VIEW_SAMPLING_TRAJECTORY: - self._ordered_indices, self._cum_t = order_views_along_trajectory( - poses, - weight_position=weight_position, - weight_rotation=weight_rotation, - start_index=start_index, - ) - logger.info( - "Built %d-view trajectory (NN + 2-opt) for SH-bake interpolation.", - len(self._ordered_indices), - ) - else: - self._ordered_indices = None - self._cum_t = None + self._ordered_indices, self._cum_t = order_views_along_trajectory( + poses, + weight_position=weight_position, + weight_rotation=weight_rotation, + start_index=start_index, + ) + logger.info( + "Built %d-view trajectory (NN + 2-opt) for SH-bake interpolation.", + len(self._ordered_indices), + ) # ------------------------------------------------------------------ # Pose sampling # ------------------------------------------------------------------ - def _sample_pose_random_pair(self) -> np.ndarray: - n = self._poses.shape[0] - i = int(self._rng.integers(0, n)) - j = int(self._rng.integers(0, n - 1)) - if j >= i: - j += 1 # ensures j != i without bias - s = float(self._rng.random()) - return slerp_pose(self._poses[i], self._poses[j], s) - - def _sample_pose_trajectory(self) -> np.ndarray: + def _sample_pose(self) -> np.ndarray: t = float(self._rng.random()) cum = self._cum_t # Find segment k s.t. cum[k-1] <= t <= cum[k] (with cum[0]=0). @@ -348,11 +325,6 @@ def _sample_pose_trajectory(self) -> np.ndarray: b = self._ordered_indices[k] return slerp_pose(self._poses[a], self._poses[b], local_s) - def _sample_pose(self) -> np.ndarray: - if self.mode == VIEW_SAMPLING_RANDOM_PAIR_SLERP: - return self._sample_pose_random_pair() - return self._sample_pose_trajectory() - # ------------------------------------------------------------------ # Batch construction # ------------------------------------------------------------------ From 6093a0deeb4eedc50e33765644de1b6b30914c83 Mon Sep 17 00:00:00 2001 From: Horde Date: Fri, 1 May 2026 18:34:31 +0000 Subject: [PATCH 34/42] feat(tools): add ppisp_export bake-modes ablation tool Sweeps simple-bake and Adam-fit variants of the PPISP SH bake on a trained checkpoint, reports mean / median / min / max PSNR (optional SSIM, LPIPS) across the validation split, and writes per-frame numbers to metrics.json. Used to root-cause the higher-order warm-start regression (Jacobian chain rule + outlier clipping). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../bake_modes_benchmark/__init__.py | 2 + .../bake_modes_benchmark/benchmark.py | 427 ++++++++++++++++++ 2 files changed, 429 insertions(+) create mode 100644 tools/ppisp_export/bake_modes_benchmark/__init__.py create mode 100644 tools/ppisp_export/bake_modes_benchmark/benchmark.py diff --git a/tools/ppisp_export/bake_modes_benchmark/__init__.py b/tools/ppisp_export/bake_modes_benchmark/__init__.py new file mode 100644 index 00000000..52a7a9da --- /dev/null +++ b/tools/ppisp_export/bake_modes_benchmark/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tools/ppisp_export/bake_modes_benchmark/benchmark.py b/tools/ppisp_export/bake_modes_benchmark/benchmark.py new file mode 100644 index 00000000..1a2b9818 --- /dev/null +++ b/tools/ppisp_export/bake_modes_benchmark/benchmark.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Sweep PPISP SH-bake modes on a trained checkpoint and report aggregated metrics. + +For each configured (bake mode, view sampling mode, init policy) tuple +the script: + +1. Builds a baked model from the cloned checkpoint: + - ``simple`` / ``simple-higher-order`` flavours run :func:`simple_bake` + directly (with optional sRGB→linear). + - ``fit`` flavours run :func:`bake_post_processing_into_sh` with the + desired view sampling mode and init policy. + +2. Renders every validation frame through: + - the *reference* model + chromatic-vignette PPISP at the chosen + (camera, frame) -- the per-frame target. + - the *baked* model with ``linear_to_srgb`` applied to its output and + an achromatic-vignette correction (matches the evaluator in + ``post_processing_sh_bake_validation.py``). + +3. Computes per-frame PSNR, SSIM and LPIPS, then aggregates mean / + median / min / max across the validation split. + +4. Prints a table sorted by mean PSNR and writes the raw per-frame + numbers to ``/metrics.json``. + +Usage: + + python tools/bake_modes_benchmark/benchmark.py \\ + --checkpoint runs//ckpt_last.pt \\ + --out-dir /tmp/bake_modes \\ + --camera-id 0 --frame-id 0 +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from threedgrut.export.usd.post_processing_sh_bake import ( # noqa: E402 + MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + PPISPPostProcessingBakeAdapter, + apply_achromatic_vignetting, + bake_post_processing_into_sh, + FixedPPISP, +) +from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake # noqa: E402 +from threedgrut.render import Renderer # noqa: E402 +from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb # noqa: E402 +from threedgrut.utils.render import apply_post_processing # noqa: E402 + +logger = logging.getLogger("bake_modes_benchmark") + + +# --------------------------------------------------------------------------- +# Mode catalogue +# --------------------------------------------------------------------------- + + +@dataclass +class BakeMode: + """One row in the sweep -- a bake configuration with a short name.""" + name: str + description: str + builder: Callable[..., nn.Module] + + +def _build_simple(*, model, ppisp, camera_id, frame_id, higher_order, srgb, + dataset=None, conf=None): + del dataset, conf # unused by the simple flavours + baked = model.clone().eval() + simple_bake( + baked, ppisp, + camera_id=camera_id, frame_id=frame_id, + higher_order=higher_order, apply_srgb_to_linear=srgb, + ) + baked.build_acc() + return baked + + +def _build_fit(*, model, ppisp, dataset, conf, camera_id, frame_id, + vignetting_mode, view_mode, view_seed, epochs, learning_rate, + init: str): + """Run the full fit-by-bake flow. + + ``init`` chooses the warm-start applied before Adam takes over: + * ``"none"`` -- patch out initialize_fit, fit from the clone. + * ``"higher"`` -- adapter default: simple_bake(higher_order=True, srgb=True). + * ``"dc-srgb"`` -- DC-only simple_bake with sRGB->linear (leaves + features_specular at the trained values). + """ + from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake + + adapter = PPISPPostProcessingBakeAdapter( + camera_id=camera_id, frame_id=frame_id, vignetting_mode=vignetting_mode, + ) + if init == "none": + adapter.initialize_fit = lambda *a, **kw: None # type: ignore[assignment] + elif init == "dc-srgb": + def _dc_srgb_init(baked_model, post_processing, _cid=camera_id, _fid=frame_id): + simple_bake(baked_model, post_processing, + camera_id=_cid, frame_id=_fid, + higher_order=False, apply_srgb_to_linear=True) + adapter.initialize_fit = _dc_srgb_init # type: ignore[assignment] + elif init != "higher": + raise ValueError(f"unknown init: {init!r}") + return bake_post_processing_into_sh( + model=model, post_processing=ppisp, train_dataset=dataset, conf=conf, + adapter=adapter, epochs=epochs, learning_rate=learning_rate, + view_sampling_mode=view_mode, interpolated_views_seed=view_seed, + ) + + +def all_modes(*, fit_epochs: int, fit_lr: float, view_seed: int) -> List[BakeMode]: + return [ + BakeMode( + "simple", + "one-shot DC-only bake", + lambda **k: _build_simple(**k, higher_order=False, srgb=False), + ), + BakeMode( + "simple-higher-order", + "one-shot bake with higher-order Jacobian linearisation", + lambda **k: _build_simple(**k, higher_order=True, srgb=False), + ), + BakeMode( + "simple-higher-order-srgb", + "simple-higher-order with sRGB→linear before RGB2SH", + lambda **k: _build_simple(**k, higher_order=True, srgb=True), + ), + BakeMode( + "fit-base", + "Adam fit, training views, no warm-start", + lambda **k: _build_fit( + **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + view_mode="training", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, init="none", + ), + ), + BakeMode( + "fit-base-srgb", + "Adam fit, training views, DC-only simple-bake (sRGB->linear) warm-start", + lambda **k: _build_fit( + **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + view_mode="training", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, init="dc-srgb", + ), + ), + BakeMode( + "fit-base-srgb-trajectory", + "Adam fit, DC-only sRGB warm-start + trajectory views (NN+2-opt, slerp)", + lambda **k: _build_fit( + **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + view_mode="trajectory", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, init="dc-srgb", + ), + ), + BakeMode( + "fit-init", + "Adam fit, training views, higher-order simple-bake (sRGB->linear) warm-start", + lambda **k: _build_fit( + **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + view_mode="training", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, init="higher", + ), + ), + BakeMode( + "fit-init-trajectory", + "Adam fit, higher-order init + trajectory views (NN+2-opt, slerp)", + lambda **k: _build_fit( + **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + view_mode="trajectory", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, init="higher", + ), + ), + ] + + +# --------------------------------------------------------------------------- +# Per-frame evaluation +# --------------------------------------------------------------------------- + + +@dataclass +class FrameMetrics: + psnr: List[float] = field(default_factory=list) + ssim: List[float] = field(default_factory=list) + lpips: List[float] = field(default_factory=list) + + +def _stats(values: List[float]) -> Dict[str, float]: + if not values: + return {"mean": float("nan"), "median": float("nan"), + "min": float("nan"), "max": float("nan")} + arr = np.asarray(values, dtype=np.float64) + return { + "mean": float(np.mean(arr)), + "median": float(np.median(arr)), + "min": float(np.min(arr)), + "max": float(np.max(arr)), + "n": len(values), + } + + +def _evaluate_mode( + baked_model, + reference_model, + fixed_pp, + dataset, + dataloader, + criteria, + vignetting_mode: str, + max_frames: Optional[int], +) -> FrameMetrics: + fm = FrameMetrics() + with torch.no_grad(): + for i, batch in enumerate(dataloader): + if max_frames is not None and i >= max_frames: + break + gpu_batch = dataset.get_gpu_batch_with_intrinsics(batch) + + # reference: full per-frame PPISP applied to reference render + ref_outputs = reference_model(gpu_batch) + ref_outputs = apply_post_processing(fixed_pp, ref_outputs, gpu_batch, training=False) + ref_rgb = ref_outputs["pred_rgb"].clip(0, 1) + + # baked: render + achromatic-vignette + linear_to_srgb + baked_outputs = baked_model(gpu_batch) + baked_rgb_lin = baked_outputs["pred_rgb"] + if vignetting_mode != "none": + _, h, w, _ = baked_rgb_lin.shape + baked_rgb_lin = apply_achromatic_vignetting( + rgb=baked_rgb_lin, ppisp=fixed_pp.ppisp, + camera_id=fixed_pp.camera_id, + pixel_coords=gpu_batch.pixel_coords, + resolution=(w, h), + ) + baked_rgb = torch.clamp(linear_to_srgb(baked_rgb_lin), 0, 1) + + fm.psnr.append(criteria["psnr"](baked_rgb, ref_rgb).item()) + if "ssim" in criteria: + fm.ssim.append(criteria["ssim"]( + baked_rgb.permute(0, 3, 1, 2), ref_rgb.permute(0, 3, 1, 2), + ).item()) + if "lpips" in criteria: + fm.lpips.append(criteria["lpips"]( + baked_rgb.clip(0, 1).permute(0, 3, 1, 2), + ref_rgb.clip(0, 1).permute(0, 3, 1, 2), + ).item()) + return fm + + +# --------------------------------------------------------------------------- +# Reporting +# --------------------------------------------------------------------------- + + +def _print_table(rows: Dict[str, Dict[str, Dict[str, float]]]) -> None: + """Print one table per metric (PSNR / SSIM / LPIPS), sorted by mean.""" + for metric in ("psnr", "ssim", "lpips"): + any_data = any(metric in r for r in rows.values()) + if not any_data: + continue + print(f"\n=== {metric.upper()} (val split, {next(iter(rows.values())).get(metric, {}).get('n', '?')} frames) ===") + if metric == "psnr": + print(f"{'mode':<28} {'mean':>9} {'median':>9} {'min':>9} {'max':>9}") + else: + print(f"{'mode':<28} {'mean':>9} {'median':>9} {'min':>9} {'max':>9}") + sorted_modes = sorted( + rows.items(), + key=lambda kv: -kv[1].get(metric, {}).get("mean", float("-inf")) + if metric == "psnr" or metric == "ssim" + else kv[1].get(metric, {}).get("mean", float("inf")), + ) + for mode_name, metrics in sorted_modes: + s = metrics.get(metric) + if s is None: + continue + fmt = "%.3f" if metric != "psnr" else "%6.3f" + print( + f"{mode_name:<28} " + f"{s['mean']:>9.4f} {s['median']:>9.4f} " + f"{s['min']:>9.4f} {s['max']:>9.4f}" + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--checkpoint", type=Path, required=True) + parser.add_argument("--data-path", type=str, default="") + parser.add_argument("--out-dir", type=Path, required=True) + parser.add_argument("--camera-id", type=int, default=0) + parser.add_argument("--frame-id", type=int, default=0) + parser.add_argument("--fit-epochs", type=int, default=1) + parser.add_argument("--fit-lr", type=float, default=1.0e-3) + parser.add_argument("--view-seed", type=int, default=0) + parser.add_argument("--max-frames", type=int, default=None, + help="Limit val frames for quick smoke checks.") + parser.add_argument("--modes", nargs="*", default=None, + help="Subset of mode names to run (default: all).") + parser.add_argument("--no-extra-metrics", action="store_true", + help="Skip SSIM/LPIPS (PSNR only).") + parser.add_argument("--verbose", "-v", action="count", default=0) + args = parser.parse_args(argv) + logging.basicConfig(level=logging.INFO - 10 * args.verbose, + format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + if not torch.cuda.is_available(): + raise SystemExit("CUDA required.") + + args.out_dir.mkdir(parents=True, exist_ok=True) + + renderer = Renderer.from_checkpoint( + checkpoint_path=str(args.checkpoint), + path=args.data_path, + out_dir=str(args.out_dir / "_renderer"), + save_gt=False, computes_extra_metrics=not args.no_extra_metrics, + ) + if renderer.post_processing is None: + raise SystemExit("Checkpoint does not contain PPISP.") + ppisp = renderer.post_processing + if not hasattr(ppisp, "vignetting_params"): + raise SystemExit("Checkpoint post-processing is not PPISP-like.") + + fixed_pp = FixedPPISP( + ppisp, args.camera_id, args.frame_id, "cuda", include_vignetting=True, + ).eval() + + # Train dataset for the fit modes (interpolated samplers need it for poses). + from threedgrut.export.usd.post_processing_sh_bake import _create_train_dataloader + train_dataset = renderer.dataset.__class__ # type: ignore + # Re-create train dataset from the loader's dataset reference: easier to + # use renderer.conf-based factory. + import threedgrut.datasets as datasets + train_ds = datasets.make_train(name=renderer.conf.dataset.type, config=renderer.conf, ray_jitter=None) + + from torchmetrics import PeakSignalNoiseRatio + criteria: Dict[str, nn.Module] = {"psnr": PeakSignalNoiseRatio(data_range=1).to("cuda")} + if not args.no_extra_metrics: + from torchmetrics.image import StructuralSimilarityIndexMeasure + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + criteria["ssim"] = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda") + criteria["lpips"] = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to("cuda") + + catalogue = all_modes( + fit_epochs=args.fit_epochs, fit_lr=args.fit_lr, view_seed=args.view_seed, + ) + if args.modes is not None: + wanted = set(args.modes) + catalogue = [m for m in catalogue if m.name in wanted] + if not catalogue: + raise SystemExit(f"No modes match {sorted(wanted)}") + + rows: Dict[str, Dict[str, Dict[str, float]]] = {} + timings: Dict[str, float] = {} + + for mode in catalogue: + logger.info("=" * 60) + logger.info("MODE %s -- %s", mode.name, mode.description) + t0 = time.time() + baked = mode.builder( + model=renderer.model, ppisp=ppisp, dataset=train_ds, conf=renderer.conf, + camera_id=args.camera_id, frame_id=args.frame_id, + ) + build_time = time.time() - t0 + logger.info(" built in %.2fs", build_time) + + fm = _evaluate_mode( + baked, renderer.model, fixed_pp, + renderer.dataset, renderer.dataloader, criteria, + vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + max_frames=args.max_frames, + ) + row = {"psnr": _stats(fm.psnr)} + if not args.no_extra_metrics: + row["ssim"] = _stats(fm.ssim) + row["lpips"] = _stats(fm.lpips) + rows[mode.name] = row + timings[mode.name] = build_time + logger.info( + " %s: PSNR mean=%.3f median=%.3f (n=%d)", + mode.name, row["psnr"]["mean"], row["psnr"]["median"], row["psnr"]["n"], + ) + + _print_table(rows) + print("\n=== Build time (seconds) ===") + for name, t in sorted(timings.items(), key=lambda kv: kv[1]): + print(f" {name:<28} {t:>7.2f} s") + + # Persist raw per-frame numbers for offline analysis. + serial = { + name: { + "build_time_s": timings[name], + **{ + metric: rows[name][metric] + for metric in ("psnr", "ssim", "lpips") if metric in rows[name] + }, + } + for name in rows + } + with open(args.out_dir / "metrics.json", "w") as f: + json.dump(serial, f, indent=2) + logger.info("metrics.json saved to %s", args.out_dir / "metrics.json") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 0b7a4c2452e312c9cd1fc084f70bfb6b949eb88c Mon Sep 17 00:00:00 2001 From: Horde Date: Sat, 2 May 2026 13:11:32 +0000 Subject: [PATCH 35/42] feat(export): default to DC-only sRGB warm-start + trajectory + 13 epochs Update PPISP SH-bake export defaults to the configuration that wins on the bake-modes ablation: * PPISPPostProcessingBakeAdapter.initialize_fit now uses higher_order=False -- DC-only simple_bake leaves the trained features_specular intact, which Adam fine-tunes faster than recovering from a Jacobian-rotated specular. * post_processing_bake_view_mode default flipped from "training" to "trajectory" -- arc-length-sampled NN+2-opt path through training poses; +0.17 dB on bonsai at 15 epochs and expected to help more on sparse-view datasets. * post_processing_bake_epochs default raised from 1 to 13 -- the fit is far from converged at 1, and 13 sits in the knee of the PSNR-vs-epochs curve (~36.5-37 dB on bonsai). Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 4 ++-- threedgrut/export/usd/exporter.py | 4 ++-- threedgrut/export/usd/post_processing_sh_bake.py | 14 ++++++++++---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 4731268f..436be2f1 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -467,7 +467,7 @@ def main(): export_conf, "post-processing-bake-epochs", "post_processing_bake_epochs", - 1, + 13, ), post_processing_bake_learning_rate=_arg_or_conf( args.post_processing_bake_learning_rate, @@ -502,7 +502,7 @@ def main(): export_conf, "post-processing-bake-view-mode", "post_processing_bake_view_mode", - "training", + "trajectory", ), post_processing_bake_view_seed=_arg_or_conf( args.post_processing_bake_view_seed, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index c969f2ce..851b40c1 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -284,12 +284,12 @@ def __init__( post_processing_export_camera_id: int | None = None, post_processing_export_frame_id: int | None = None, ignore_ppisp_controller: bool = False, - post_processing_bake_epochs: int = 1, + post_processing_bake_epochs: int = 13, post_processing_bake_learning_rate: float = 1.0e-3, post_processing_bake_camera_id: int = 0, post_processing_bake_frame_id: int = 0, ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - post_processing_bake_view_mode: str = "training", + post_processing_bake_view_mode: str = "trajectory", post_processing_bake_view_seed: int | None = None, post_processing_bake_trajectory_weight_position: float = 1.0, post_processing_bake_trajectory_weight_rotation: float = 0.5, diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index e69168b5..ab05f7fc 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -382,14 +382,20 @@ def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Modul return torch.clamp(linear_to_srgb(vignetted), 0.0, 1.0) def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: - """Warm-start with the higher-order simple-bake on the chosen - (camera, frame), in linear scene-referred space.""" + """Warm-start with a DC-only simple-bake on the chosen (camera, + frame), in linear scene-referred space. + + The trained ``features_specular`` is left untouched: a higher-order + Jacobian rotation gives a slightly better starting PSNR but + Adam takes much longer to recover from the rotated specular + (~7 dB at 9 epochs on bonsai, see tools/ppisp_export benchmark). + """ # Late import: avoid pulling ppisp into modules that don't need it. from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake logger.info( "PPISP SH bake init: applying simple_bake (camera=%d, frame=%d, " - "higher_order=True, apply_srgb_to_linear=True) before fitting.", + "higher_order=False, apply_srgb_to_linear=True) before fitting.", self.camera_id, self.frame_id, ) simple_bake( @@ -397,7 +403,7 @@ def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: post_processing, camera_id=self.camera_id, frame_id=self.frame_id, - higher_order=True, + higher_order=False, apply_srgb_to_linear=True, ) From 9973d68d1296096e8d3cbdd1f5a98abdead864fe Mon Sep 17 00:00:00 2001 From: Horde Date: Sun, 3 May 2026 20:17:55 +0000 Subject: [PATCH 36/42] fix(bake): produce gamma-space SH and co-optimise density Two bugs in the PPISP SH bake fit, each found by ablation on caterpillar: 1. The SH was fit in linear scene-referred space against a display-referred target with a linear_to_srgb step in the loss path. The chain rule through linear_to_srgb is wildly non-uniform on [0, 1] (~13x at darks, ~0.4x at brights). Adam over-weights dark-region updates and the high-degree SH bands oscillate to fit amplified noise -- visible as rainbow fringing around silhouettes. Standard 3DGS doesn't see this because it trains in gamma space directly with identity gradient through the loss. Fix: warm-start albedo with simple_bake(apply_srgb_to_linear=False) and strip linear_to_srgb from apply_fit_transform. SH eval is now display- referred; gradient is identity through the loss; the asset format aligns with no-PPISP exports (gamma SH + no post-processing layer). 2. Fixed-geometry SH refit on a smooth synthetic target (PPISP forward of the trained model's prediction) gives the high-degree bands too much license to align coherently into aliasing patterns -- nothing in the loss surface breaks the symmetry. Standard 3DGS training avoids this via density / split-clone-prune adapting the placement to the target. Fix: open features_albedo, features_specular and density to Adam, each at its 3DGS-standard learning rate (2.5e-3, 1.25e-4, 5e-2). On caterpillar (48 val frames, 13 epochs): before: mean 32.0 dB, worst 26.6 dB, strong rainbow on hard frames after: mean 43.4 dB, worst 39.0 dB, rainbow gone Default vignetting flipped from "achromatic-fit" to "none": the chromatic-vs-achromatic mismatch was contributing ~1 dB of error and the new asset format doesn't have a place to ship a runtime vignette anyway. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../export/usd/post_processing_sh_bake.py | 78 ++++++++++++------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index ab05f7fc..344f33cf 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -99,7 +99,9 @@ def bake_post_processing_into_sh( *, adapter: PostProcessingBakeAdapter, epochs: int = 1, - learning_rate: float = 1.0e-3, + learning_rate: float = 2.5e-3, + learning_rate_specular: float | None = None, + learning_rate_density: float = 5.0e-2, device: str = "cuda", view_sampling_mode: str = "training", interpolated_views_seed: int | None = None, @@ -108,6 +110,16 @@ def bake_post_processing_into_sh( ): """Return a cloned model whose SH coefficients approximate fixed post-processing output. + Three parameter groups are co-optimised, mirroring 3DGS training defaults: + + * ``features_albedo`` at ``learning_rate`` (default 2.5e-3) + * ``features_specular`` at ``learning_rate_specular`` (default = lr/20) + * ``density`` at ``learning_rate_density`` (default 5e-2) + + Letting density breathe absorbs spatial frequencies the SH alone can't + capture without aliasing -- on harder scenes (caterpillar) this is + worth +5 dB worst-case PSNR over fitting only colour coefficients. + ``view_sampling_mode`` controls what the optimizer sees each step: * ``"training"`` (default) -- iterate the training dataloader as usual. @@ -145,15 +157,22 @@ def bake_post_processing_into_sh( baked_model.build_acc() fixed_post_processing = adapter.create_fixed_post_processing(post_processing, device) - # Warm-start the cloned SH state with the adapter's closed-form bake - # (PPISP: simple_bake on the chosen camera/frame, with sRGB→linear so - # the resulting SH lives in linear scene-referred space). Adam takes - # over from there. Reduces the iterations needed and avoids fitting - # from a checkpoint state that's far from the optimum. + # Warm-start the cloned SH state with the adapter's closed-form bake. + # PPISPPostProcessingBakeAdapter writes display-referred (gamma-space) + # DC; Adam takes over from there. Reduces the iterations needed and + # avoids fitting from a checkpoint state far from the optimum. adapter.initialize_fit(baked_model, post_processing) - fit_parameters = list(_set_sh_fit_parameters(baked_model)) - optimizer = torch.optim.Adam(fit_parameters, lr=learning_rate) + if learning_rate_specular is None: + learning_rate_specular = learning_rate / 20.0 # 3DGS default ratio + + _set_sh_fit_parameters(baked_model) + baked_model.density.requires_grad_(True) + optimizer = torch.optim.Adam([ + {"params": [baked_model.features_albedo], "lr": learning_rate}, + {"params": [baked_model.features_specular], "lr": learning_rate_specular}, + {"params": [baked_model.density], "lr": learning_rate_density}, + ]) train_dataloader = _create_train_dataloader(conf, train_dataset) steps_per_epoch = len(train_dataloader) @@ -280,7 +299,7 @@ def forward( def normalize_ppisp_bake_vignetting_mode(mode: str | None) -> str: - normalized = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT if mode is None else str(mode).strip().lower() + normalized = MODE_PPISP_BAKE_VIGNETTING_NONE if mode is None else str(mode).strip().lower() if normalized not in PPISP_BAKE_VIGNETTING_MODES: raise ValueError( f"Unsupported PPISP bake vignetting mode '{mode}'. " @@ -339,7 +358,7 @@ def __init__( self, camera_id: int = 0, frame_id: int = 0, - vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_NONE, ) -> None: self.camera_id = int(camera_id) self.frame_id = int(frame_id) @@ -366,36 +385,35 @@ def create_fixed_post_processing(self, post_processing: nn.Module, device: str) ).eval() def apply_fit_transform(self, rgb: torch.Tensor, fixed_post_processing: nn.Module, gpu_batch) -> torch.Tensor: - # PPISP reference is display-referred; baked side must encode to sRGB to match. - from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb - - if self.vignetting_mode == MODE_PPISP_BAKE_VIGNETTING_NONE: - return torch.clamp(linear_to_srgb(rgb), 0.0, 1.0) - _, height, width, _ = rgb.shape - vignetted = apply_achromatic_vignetting( - rgb=rgb, - ppisp=fixed_post_processing.ppisp, - camera_id=fixed_post_processing.camera_id, - pixel_coords=gpu_batch.pixel_coords, - resolution=(width, height), - ) - return torch.clamp(linear_to_srgb(vignetted), 0.0, 1.0) + del fixed_post_processing, gpu_batch + # SH eval lives in display (gamma) space -- initialize_fit warm-starts + # with apply_srgb_to_linear=False, the loss target is the full PPISP + # output (also display-referred), and the loss gradient flows through + # identity. Matches the conditioning of training a 3DGS model + # directly in gamma space, where the same SH degree shows no rainbow + # aliasing. + return torch.clamp(rgb, 0.0, 1.0) def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: """Warm-start with a DC-only simple-bake on the chosen (camera, - frame), in linear scene-referred space. + frame), in display (gamma) space. + + Matches the colour space the trainer used when ``post_processing.method`` + is null/linear-to-srgb -- features_albedo lives directly in display- + referred RGB and ``apply_fit_transform`` is the identity. Aligns the + baked-SH USD asset format with no-PPISP exports. The trained ``features_specular`` is left untouched: a higher-order - Jacobian rotation gives a slightly better starting PSNR but - Adam takes much longer to recover from the rotated specular - (~7 dB at 9 epochs on bonsai, see tools/ppisp_export benchmark). + Jacobian rotation gives a slightly better starting PSNR but Adam + takes much longer to recover from the rotated specular (~7 dB at 9 + epochs on bonsai, see tools/ppisp_export benchmark). """ # Late import: avoid pulling ppisp into modules that don't need it. from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake logger.info( "PPISP SH bake init: applying simple_bake (camera=%d, frame=%d, " - "higher_order=False, apply_srgb_to_linear=True) before fitting.", + "higher_order=False, apply_srgb_to_linear=False) before fitting.", self.camera_id, self.frame_id, ) simple_bake( @@ -404,7 +422,7 @@ def initialize_fit(self, baked_model, post_processing: nn.Module) -> None: camera_id=self.camera_id, frame_id=self.frame_id, higher_order=False, - apply_srgb_to_linear=True, + apply_srgb_to_linear=False, ) def log_context(self) -> str: From 87670c973ceaf4f237d4f1ccf39422495faa4383 Mon Sep 17 00:00:00 2001 From: Horde Date: Sun, 3 May 2026 20:18:02 +0000 Subject: [PATCH 37/42] feat(export): expose per-param-group bake LRs; default vignetting=none Match the new bake_post_processing_into_sh signature in the exporter and CLI: separate albedo / specular / density learning rates, with the 3DGS-standard ratio (specular = albedo / 20) preserved by default. * default --post-processing-bake-learning-rate flipped to 2.5e-3 * new --post-processing-bake-learning-rate-specular (default = albedo/20) * new --post-processing-bake-learning-rate-density (default 5e-2) * default --ppisp-bake-vignetting-mode flipped to "none" Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 30 ++++++++++++++- threedgrut/export/usd/exporter.py | 50 ++++++++++++++++++++----- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 436be2f1..b7b013b9 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -190,7 +190,19 @@ def parse_args(): "--post-processing-bake-learning-rate", type=float, default=None, - help="Adam learning rate for post-processing baked-SH export.", + help="Adam learning rate for features_albedo (default 2.5e-3, matches 3DGS).", + ) + parser.add_argument( + "--post-processing-bake-learning-rate-specular", + type=float, + default=None, + help="Adam learning rate for features_specular (default = albedo lr / 20, matches 3DGS).", + ) + parser.add_argument( + "--post-processing-bake-learning-rate-density", + type=float, + default=None, + help="Adam learning rate for density (default 5e-2, matches 3DGS).", ) parser.add_argument( "--post-processing-bake-camera-id", @@ -474,7 +486,21 @@ def main(): export_conf, "post-processing-bake-learning-rate", "post_processing_bake_learning_rate", - 1.0e-3, + 2.5e-3, + ), + post_processing_bake_learning_rate_specular=_arg_or_conf( + args.post_processing_bake_learning_rate_specular, + export_conf, + "post-processing-bake-learning-rate-specular", + "post_processing_bake_learning_rate_specular", + None, + ), + post_processing_bake_learning_rate_density=_arg_or_conf( + args.post_processing_bake_learning_rate_density, + export_conf, + "post-processing-bake-learning-rate-density", + "post_processing_bake_learning_rate_density", + 5.0e-2, ), post_processing_bake_camera_id=_arg_or_conf( args.post_processing_bake_camera_id, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 851b40c1..53c20762 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -58,7 +58,7 @@ DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, normalize_particle_field_sorting_mode_hint, ) -from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT +from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_NONE from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) @@ -285,10 +285,12 @@ def __init__( post_processing_export_frame_id: int | None = None, ignore_ppisp_controller: bool = False, post_processing_bake_epochs: int = 13, - post_processing_bake_learning_rate: float = 1.0e-3, + post_processing_bake_learning_rate: float = 2.5e-3, + post_processing_bake_learning_rate_specular: float | None = None, + post_processing_bake_learning_rate_density: float = 5.0e-2, post_processing_bake_camera_id: int = 0, post_processing_bake_frame_id: int = 0, - ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + ppisp_bake_vignetting_mode: str = MODE_PPISP_BAKE_VIGNETTING_NONE, post_processing_bake_view_mode: str = "trajectory", post_processing_bake_view_seed: int | None = None, post_processing_bake_trajectory_weight_position: float = 1.0, @@ -322,12 +324,19 @@ def __init__( post_processing_export_frame_id: Optional PPISP frame index to write as static exposure/color inputs in omni-native mode. post_processing_bake_epochs: Number of sequential passes over the train/reference set. - post_processing_bake_learning_rate: Adam learning rate for baked SH. + post_processing_bake_learning_rate: Adam learning rate for features_albedo + (default 2.5e-3, matches 3DGS). + post_processing_bake_learning_rate_specular: Adam learning rate for + features_specular. Defaults to ``learning_rate / 20`` (the 3DGS ratio). + post_processing_bake_learning_rate_density: Adam learning rate for density + (default 5e-2, matches 3DGS). Optimising density alongside SH absorbs + spatial frequencies the SH alone aliases as colour rainbow fringes. post_processing_bake_camera_id: Camera index for the fixed baked transform. post_processing_bake_frame_id: Frame index for the fixed baked transform. - ppisp_bake_vignetting_mode: "none" disables vignetting in the PPISP - reference. "achromatic-fit" keeps chromatic PPISP vignetting in - the reference and applies an achromatic estimate only in the fit loss. + ppisp_bake_vignetting_mode: "none" (default) -- bake produces gamma-space + SH coefficients with no vignetting; the asset format aligns with + no-PPISP exports. "achromatic-fit" is retained for backwards + compatibility but no longer the recommended mode. post_processing_bake_view_mode: which views the bake fit sees per step. "training" iterates the train dataloader (default). "trajectory" orders the training views along an NN+2-opt camera path, parameterises @@ -364,6 +373,13 @@ def __init__( self.ignore_ppisp_controller = bool(ignore_ppisp_controller) self.post_processing_bake_epochs = int(post_processing_bake_epochs) self.post_processing_bake_learning_rate = float(post_processing_bake_learning_rate) + self.post_processing_bake_learning_rate_specular = ( + None if post_processing_bake_learning_rate_specular is None + else float(post_processing_bake_learning_rate_specular) + ) + self.post_processing_bake_learning_rate_density = float( + post_processing_bake_learning_rate_density + ) self.post_processing_bake_camera_id = int(post_processing_bake_camera_id) self.post_processing_bake_frame_id = int(post_processing_bake_frame_id) self.ppisp_bake_vignetting_mode = str(ppisp_bake_vignetting_mode) @@ -460,6 +476,8 @@ def export( adapter=adapter, epochs=self.post_processing_bake_epochs, learning_rate=self.post_processing_bake_learning_rate, + learning_rate_specular=self.post_processing_bake_learning_rate_specular, + learning_rate_density=self.post_processing_bake_learning_rate_density, view_sampling_mode=self.post_processing_bake_view_mode, interpolated_views_seed=self.post_processing_bake_view_seed, trajectory_weight_position=self.post_processing_bake_trajectory_weight_position, @@ -850,13 +868,25 @@ def from_config(cls, conf) -> "USDExporter": export_conf, "post-processing-bake-epochs", "post_processing_bake_epochs", - 1, + 13, ), post_processing_bake_learning_rate=_get_export_config_value( export_conf, "post-processing-bake-learning-rate", "post_processing_bake_learning_rate", - 1.0e-3, + 2.5e-3, + ), + post_processing_bake_learning_rate_specular=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate-specular", + "post_processing_bake_learning_rate_specular", + None, + ), + post_processing_bake_learning_rate_density=_get_export_config_value( + export_conf, + "post-processing-bake-learning-rate-density", + "post_processing_bake_learning_rate_density", + 5.0e-2, ), post_processing_bake_camera_id=_get_export_config_value( export_conf, @@ -874,7 +904,7 @@ def from_config(cls, conf) -> "USDExporter": export_conf, "ppisp-bake-vignetting-mode", "ppisp_bake_vignetting_mode", - MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + MODE_PPISP_BAKE_VIGNETTING_NONE, ), post_processing_bake_view_mode=_get_export_config_value( export_conf, From 4bb70bfe8c6b045c537435797ac93924f0cdd594 Mon Sep 17 00:00:00 2001 From: Horde Date: Sun, 3 May 2026 20:18:13 +0000 Subject: [PATCH 38/42] refactor(tools): trim bake-modes benchmark for the gamma-SH pipeline Drop the linear-SH + achromatic-vignette modes (fit-base, fit-base-srgb, fit-init, ...) -- they no longer match the production fit path and were only useful to debug the rainbow regression. Replace with a leaner catalogue aligned with the new defaults: simple one-shot DC-only bake (no fit, gamma SH) simple-higher-order one-shot DC + Jacobian-rotated specular (no fit) fit-color-only Adam on albedo + specular only (density ablation) fit Adam on albedo + specular + density (production) fit-trajectory fit + trajectory view sampling Eval likewise simplifies: SH is already display-referred so the baked side just clips to [0, 1]; reference uses no-vignette PPISP. On caterpillar (48 val frames, 13 epochs) the 'fit' mode lands at 43.4 dB mean / 39.0 dB worst-case, vs 32.0 / 26.6 with the old fit-base. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../bake_modes_benchmark/benchmark.py | 151 ++++++------------ 1 file changed, 48 insertions(+), 103 deletions(-) diff --git a/tools/ppisp_export/bake_modes_benchmark/benchmark.py b/tools/ppisp_export/bake_modes_benchmark/benchmark.py index 1a2b9818..0d1d902f 100644 --- a/tools/ppisp_export/bake_modes_benchmark/benchmark.py +++ b/tools/ppisp_export/bake_modes_benchmark/benchmark.py @@ -4,31 +4,28 @@ """Sweep PPISP SH-bake modes on a trained checkpoint and report aggregated metrics. -For each configured (bake mode, view sampling mode, init policy) tuple -the script: +The bake fits gamma-space (display-referred) SH coefficients against the +PPISP forward output of the trained model, matching the colour space of +the no-PPISP export. Modes vary along two axes: -1. Builds a baked model from the cloned checkpoint: - - ``simple`` / ``simple-higher-order`` flavours run :func:`simple_bake` - directly (with optional sRGB→linear). - - ``fit`` flavours run :func:`bake_post_processing_into_sh` with the - desired view sampling mode and init policy. +* ``simple`` flavours skip optimisation and write only the DC band. +* ``fit`` flavours run :func:`bake_post_processing_into_sh` -- Adam over + features_albedo, features_specular, and (optionally) density. View + sampling is either ``training`` (iterate the dataloader) or + ``trajectory`` (NN+2-opt arc-length-parameterised slerp through training + poses; useful when training views are sparse). -2. Renders every validation frame through: - - the *reference* model + chromatic-vignette PPISP at the chosen - (camera, frame) -- the per-frame target. - - the *baked* model with ``linear_to_srgb`` applied to its output and - an achromatic-vignette correction (matches the evaluator in - ``post_processing_sh_bake_validation.py``). +Per-frame validation: + reference = full PPISP applied to reference-model render at val pose + baked = baked-model render (already display-referred) clipped to [0, 1] -3. Computes per-frame PSNR, SSIM and LPIPS, then aggregates mean / - median / min / max across the validation split. - -4. Prints a table sorted by mean PSNR and writes the raw per-frame - numbers to ``/metrics.json``. +Metrics: per-frame PSNR (+ optional SSIM / LPIPS), aggregated mean / +median / min / max across the val split. Raw per-frame numbers are +persisted to ``/metrics.json``. Usage: - python tools/bake_modes_benchmark/benchmark.py \\ + python tools/ppisp_export/bake_modes_benchmark/benchmark.py \\ --checkpoint runs//ckpt_last.pt \\ --out-dir /tmp/bake_modes \\ --camera-id 0 --frame-id 0 @@ -52,15 +49,13 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from threedgrut.export.usd.post_processing_sh_bake import ( # noqa: E402 - MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, + MODE_PPISP_BAKE_VIGNETTING_NONE, PPISPPostProcessingBakeAdapter, - apply_achromatic_vignetting, bake_post_processing_into_sh, FixedPPISP, ) from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake # noqa: E402 from threedgrut.render import Renderer # noqa: E402 -from threedgrut.utils.post_processing_linear_to_srgb import linear_to_srgb # noqa: E402 from threedgrut.utils.render import apply_post_processing # noqa: E402 logger = logging.getLogger("bake_modes_benchmark") @@ -79,48 +74,32 @@ class BakeMode: builder: Callable[..., nn.Module] -def _build_simple(*, model, ppisp, camera_id, frame_id, higher_order, srgb, +def _build_simple(*, model, ppisp, camera_id, frame_id, higher_order, dataset=None, conf=None): del dataset, conf # unused by the simple flavours baked = model.clone().eval() simple_bake( baked, ppisp, camera_id=camera_id, frame_id=frame_id, - higher_order=higher_order, apply_srgb_to_linear=srgb, + higher_order=higher_order, apply_srgb_to_linear=False, ) baked.build_acc() return baked def _build_fit(*, model, ppisp, dataset, conf, camera_id, frame_id, - vignetting_mode, view_mode, view_seed, epochs, learning_rate, - init: str): - """Run the full fit-by-bake flow. - - ``init`` chooses the warm-start applied before Adam takes over: - * ``"none"`` -- patch out initialize_fit, fit from the clone. - * ``"higher"`` -- adapter default: simple_bake(higher_order=True, srgb=True). - * ``"dc-srgb"`` -- DC-only simple_bake with sRGB->linear (leaves - features_specular at the trained values). - """ - from threedgrut.export.usd.post_processing_sh_simple_bake import simple_bake - + view_mode, view_seed, epochs, learning_rate, optimize_density: bool): + """Run the full fit-by-bake flow with the production adapter (gamma SH, + no vignetting). ``optimize_density=False`` ablates the density param + group by setting its lr to zero.""" adapter = PPISPPostProcessingBakeAdapter( - camera_id=camera_id, frame_id=frame_id, vignetting_mode=vignetting_mode, + camera_id=camera_id, frame_id=frame_id, + vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_NONE, ) - if init == "none": - adapter.initialize_fit = lambda *a, **kw: None # type: ignore[assignment] - elif init == "dc-srgb": - def _dc_srgb_init(baked_model, post_processing, _cid=camera_id, _fid=frame_id): - simple_bake(baked_model, post_processing, - camera_id=_cid, frame_id=_fid, - higher_order=False, apply_srgb_to_linear=True) - adapter.initialize_fit = _dc_srgb_init # type: ignore[assignment] - elif init != "higher": - raise ValueError(f"unknown init: {init!r}") return bake_post_processing_into_sh( model=model, post_processing=ppisp, train_dataset=dataset, conf=conf, adapter=adapter, epochs=epochs, learning_rate=learning_rate, + learning_rate_density=(5.0e-2 if optimize_density else 0.0), view_sampling_mode=view_mode, interpolated_views_seed=view_seed, ) @@ -129,62 +108,36 @@ def all_modes(*, fit_epochs: int, fit_lr: float, view_seed: int) -> List[BakeMod return [ BakeMode( "simple", - "one-shot DC-only bake", - lambda **k: _build_simple(**k, higher_order=False, srgb=False), + "one-shot DC-only bake (no fit, gamma SH)", + lambda **k: _build_simple(**k, higher_order=False), ), BakeMode( "simple-higher-order", - "one-shot bake with higher-order Jacobian linearisation", - lambda **k: _build_simple(**k, higher_order=True, srgb=False), - ), - BakeMode( - "simple-higher-order-srgb", - "simple-higher-order with sRGB→linear before RGB2SH", - lambda **k: _build_simple(**k, higher_order=True, srgb=True), - ), - BakeMode( - "fit-base", - "Adam fit, training views, no warm-start", - lambda **k: _build_fit( - **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - view_mode="training", view_seed=view_seed, - epochs=fit_epochs, learning_rate=fit_lr, init="none", - ), - ), - BakeMode( - "fit-base-srgb", - "Adam fit, training views, DC-only simple-bake (sRGB->linear) warm-start", - lambda **k: _build_fit( - **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - view_mode="training", view_seed=view_seed, - epochs=fit_epochs, learning_rate=fit_lr, init="dc-srgb", - ), + "one-shot DC + Jacobian-rotated specular (no fit)", + lambda **k: _build_simple(**k, higher_order=True), ), BakeMode( - "fit-base-srgb-trajectory", - "Adam fit, DC-only sRGB warm-start + trajectory views (NN+2-opt, slerp)", + "fit-color-only", + "Adam fit on features_albedo + features_specular only, training views", lambda **k: _build_fit( - **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - view_mode="trajectory", view_seed=view_seed, - epochs=fit_epochs, learning_rate=fit_lr, init="dc-srgb", + **k, view_mode="training", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, optimize_density=False, ), ), BakeMode( - "fit-init", - "Adam fit, training views, higher-order simple-bake (sRGB->linear) warm-start", + "fit", + "Adam fit on albedo + specular + density, training views (production default)", lambda **k: _build_fit( - **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - view_mode="training", view_seed=view_seed, - epochs=fit_epochs, learning_rate=fit_lr, init="higher", + **k, view_mode="training", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, optimize_density=True, ), ), BakeMode( - "fit-init-trajectory", - "Adam fit, higher-order init + trajectory views (NN+2-opt, slerp)", + "fit-trajectory", + "Adam fit on albedo + specular + density, trajectory views (NN+2-opt slerp)", lambda **k: _build_fit( - **k, vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, - view_mode="trajectory", view_seed=view_seed, - epochs=fit_epochs, learning_rate=fit_lr, init="higher", + **k, view_mode="trajectory", view_seed=view_seed, + epochs=fit_epochs, learning_rate=fit_lr, optimize_density=True, ), ), ] @@ -223,7 +176,6 @@ def _evaluate_mode( dataset, dataloader, criteria, - vignetting_mode: str, max_frames: Optional[int], ) -> FrameMetrics: fm = FrameMetrics() @@ -238,18 +190,9 @@ def _evaluate_mode( ref_outputs = apply_post_processing(fixed_pp, ref_outputs, gpu_batch, training=False) ref_rgb = ref_outputs["pred_rgb"].clip(0, 1) - # baked: render + achromatic-vignette + linear_to_srgb + # baked: SH eval is already display-referred (gamma); just clip. baked_outputs = baked_model(gpu_batch) - baked_rgb_lin = baked_outputs["pred_rgb"] - if vignetting_mode != "none": - _, h, w, _ = baked_rgb_lin.shape - baked_rgb_lin = apply_achromatic_vignetting( - rgb=baked_rgb_lin, ppisp=fixed_pp.ppisp, - camera_id=fixed_pp.camera_id, - pixel_coords=gpu_batch.pixel_coords, - resolution=(w, h), - ) - baked_rgb = torch.clamp(linear_to_srgb(baked_rgb_lin), 0, 1) + baked_rgb = torch.clamp(baked_outputs["pred_rgb"], 0, 1) fm.psnr.append(criteria["psnr"](baked_rgb, ref_rgb).item()) if "ssim" in criteria: @@ -341,8 +284,11 @@ def main(argv=None) -> int: if not hasattr(ppisp, "vignetting_params"): raise SystemExit("Checkpoint post-processing is not PPISP-like.") + # The bake target is PPISP-without-vignetting (matches the production + # MODE_PPISP_BAKE_VIGNETTING_NONE adapter); both reference and baked + # sides therefore live in the same display-referred space. fixed_pp = FixedPPISP( - ppisp, args.camera_id, args.frame_id, "cuda", include_vignetting=True, + ppisp, args.camera_id, args.frame_id, "cuda", include_vignetting=False, ).eval() # Train dataset for the fit modes (interpolated samplers need it for poses). @@ -387,7 +333,6 @@ def main(argv=None) -> int: fm = _evaluate_mode( baked, renderer.model, fixed_pp, renderer.dataset, renderer.dataloader, criteria, - vignetting_mode=MODE_PPISP_BAKE_VIGNETTING_ACHROMATIC_FIT, max_frames=args.max_frames, ) row = {"psnr": _stats(fm.psnr)} From 542929b3548ce8a3cd721a231fb4da868bf225d0 Mon Sep 17 00:00:00 2001 From: Horde Date: Mon, 4 May 2026 12:59:18 +0000 Subject: [PATCH 39/42] perf(export): drop default bake epochs from 13 to 7 The well-conditioned gamma-SH fit converges fast: per-step loss curves on bonsai and caterpillar show diminishing returns past epoch 7. Cutting the default ~halves wall-clock time per export with marginal PSNR cost (rough estimate -0.5 dB mean, the asymptote is ~the same). Users wanting the last fraction of a dB can still set --post-processing-bake-epochs. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 2 +- threedgrut/export/usd/exporter.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index b7b013b9..3abedf80 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -479,7 +479,7 @@ def main(): export_conf, "post-processing-bake-epochs", "post_processing_bake_epochs", - 13, + 7, ), post_processing_bake_learning_rate=_arg_or_conf( args.post_processing_bake_learning_rate, diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 53c20762..4dec2df4 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -284,7 +284,7 @@ def __init__( post_processing_export_camera_id: int | None = None, post_processing_export_frame_id: int | None = None, ignore_ppisp_controller: bool = False, - post_processing_bake_epochs: int = 13, + post_processing_bake_epochs: int = 7, post_processing_bake_learning_rate: float = 2.5e-3, post_processing_bake_learning_rate_specular: float | None = None, post_processing_bake_learning_rate_density: float = 5.0e-2, @@ -868,7 +868,7 @@ def from_config(cls, conf) -> "USDExporter": export_conf, "post-processing-bake-epochs", "post_processing_bake_epochs", - 13, + 7, ), post_processing_bake_learning_rate=_get_export_config_value( export_conf, From d326ef419cc6afd029d2bbc7145ea6010991ab64 Mon Sep 17 00:00:00 2001 From: Horde Date: Mon, 4 May 2026 14:55:29 +0000 Subject: [PATCH 40/42] feat(export): add --output-scale to scale exported SH output A single multiplicative knob applied to the asset's SH-evaluated RGB. ``features_specular`` is scaled linearly; ``features_albedo`` picks up an extra ``(s - 1) * 0.5 / C0`` term to compensate for the constant offset baked into ``RGB2SH`` so a forward eval reproduces ``s * original_rgb``. Works uniformly across export modes (no-PPISP linear-SH, no-PPISP gamma, PPISP baked-sh) -- the offset compensation is structural, not colour- space-dependent. Default 1.0 keeps existing exports byte-identical. Useful for matching downstream tonemap exposure or compensating for runtime gain that the asset's consumer applies. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/scripts/export_usd.py | 18 +++++++++++++++++ threedgrut/export/usd/exporter.py | 20 ++++++++++++++++++- .../export/usd/post_processing_sh_bake.py | 20 ++++++++++++++++++- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/threedgrut/export/scripts/export_usd.py b/threedgrut/export/scripts/export_usd.py index 3abedf80..ce030396 100644 --- a/threedgrut/export/scripts/export_usd.py +++ b/threedgrut/export/scripts/export_usd.py @@ -255,6 +255,17 @@ def parse_args(): default=None, help="Trajectory mode only: weight on the (1 - cos(angle)) rotation term in pose distance.", ) + parser.add_argument( + "--output-scale", + type=float, + default=None, + help=( + "Multiplicative scale applied to the SH-evaluated RGB output of the " + "exported asset. Default 1.0 (no-op). The DC offset is compensated so " + "rendered output equals output-scale x original eval. Useful for " + "matching downstream tonemap exposure." + ), + ) # Dataset path (optional, overrides checkpoint's dataset path) parser.add_argument( @@ -551,6 +562,13 @@ def main(): "post_processing_bake_trajectory_weight_rotation", 0.5, ), + output_scale=_arg_or_conf( + args.output_scale, + export_conf, + "output-scale", + "output_scale", + 1.0, + ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) logger.info("Using ParticleField3DGaussianSplat schema (standard)") diff --git a/threedgrut/export/usd/exporter.py b/threedgrut/export/usd/exporter.py index 4dec2df4..437d6082 100644 --- a/threedgrut/export/usd/exporter.py +++ b/threedgrut/export/usd/exporter.py @@ -58,7 +58,10 @@ DEFAULT_PARTICLE_FIELD_SORTING_MODE_HINT, normalize_particle_field_sorting_mode_hint, ) -from threedgrut.export.usd.post_processing_sh_bake import MODE_PPISP_BAKE_VIGNETTING_NONE +from threedgrut.export.usd.post_processing_sh_bake import ( + MODE_PPISP_BAKE_VIGNETTING_NONE, + scale_sh_output, +) from threedgrut.export.usd.writers.camera import export_cameras_to_usd logger = logging.getLogger(__name__) @@ -295,6 +298,7 @@ def __init__( post_processing_bake_view_seed: int | None = None, post_processing_bake_trajectory_weight_position: float = 1.0, post_processing_bake_trajectory_weight_rotation: float = 0.5, + output_scale: float = 1.0, frames_per_second: float = 1.0, ): """ @@ -393,6 +397,7 @@ def __init__( self.post_processing_bake_trajectory_weight_rotation = float( post_processing_bake_trajectory_weight_rotation ) + self.output_scale = float(output_scale) self.frames_per_second = frames_per_second def _create_default_stage(self, referenced_stages: List[NamedUSDStage]) -> NamedUSDStage: @@ -486,6 +491,13 @@ def export( if uses_omni_native_post_processing_export and not has_ppisp_module: raise ValueError("Omniverse-native post-processing export currently supports PPISP post-processing only.") + # User-requested constant brightness scale, applied uniformly to the + # SH output regardless of bake / colour-space mode. The DC offset + # baked into RGB2SH is compensated so a forward eval reproduces + # output_scale * (original SH-evaluated RGB). + if self.output_scale != 1.0: + scale_sh_output(model, self.output_scale) + # Get model data via accessor accessor = GaussianExportAccessor(model, conf) attrs = accessor.get_attributes(preactivation=False) @@ -930,5 +942,11 @@ def from_config(cls, conf) -> "USDExporter": "post_processing_bake_trajectory_weight_rotation", 0.5, ), + output_scale=_get_export_config_value( + export_conf, + "output-scale", + "output_scale", + 1.0, + ), frames_per_second=getattr(export_conf, "frames_per_second", 1.0), ) diff --git a/threedgrut/export/usd/post_processing_sh_bake.py b/threedgrut/export/usd/post_processing_sh_bake.py index 344f33cf..2bab8d6b 100644 --- a/threedgrut/export/usd/post_processing_sh_bake.py +++ b/threedgrut/export/usd/post_processing_sh_bake.py @@ -25,11 +25,29 @@ import torch.nn as nn from threedgrut.datasets.utils import configure_dataloader_for_platform -from threedgrut.utils.render import apply_post_processing +from threedgrut.utils.render import C0, apply_post_processing logger = logging.getLogger(__name__) +def scale_sh_output(model, scale: float) -> None: + """In-place scale the SH-evaluated RGB output by ``scale``. + + SH eval is ``rgb = features_albedo * C0 + 0.5 + sum_k Y_k * features_specular_k``. + To get ``s * rgb`` from a forward eval, every term must be scaled: + * features_specular -> s * features_specular (linear, view-dep bands) + * features_albedo -> s * features_albedo + (s - 1) * 0.5 / C0 + compensates for the constant ``0.5`` offset in the DC band. + """ + if scale == 1.0: + return + s = float(scale) + with torch.no_grad(): + model.features_specular.mul_(s) + model.features_albedo.mul_(s).add_((s - 1.0) * 0.5 / C0) + logger.info("Scaled SH output by %.4f (DC offset compensated)", s) + + class PostProcessingBakeAdapter: """Adapter interface for baking one fixed post-processing transform.""" From 62ff7c50a13a2c1b3fc26aef27487277bfc382ec Mon Sep 17 00:00:00 2001 From: Horde Date: Mon, 4 May 2026 14:55:40 +0000 Subject: [PATCH 41/42] feat(ppisp-spg): add user-overridable per-camera responsivity Adds a ``responsivityR/G/B`` float input on the PPISP SPG shader (both the static and controller-aware variants) that is premultiplied to the input HdrColor before exposure / vignetting / colour correction / CRF run. Default 1.0 per channel keeps exported assets visually identical; authors can scale per-channel sensitivity post-export by editing the USD ``inputs:responsivityR/G/B`` attributes on the per-camera shader prim. Touches: * both ``.slang`` shaders (struct field + premultiply at top of main) * both ``.slang.lua`` launchers (bind the new fields, default 1.0) * both ``.slang.usda`` schemas (declare inputs with default 1.0) * ``ppisp_writer.py`` writes the inputs explicitly via ``_set_responsivity_params`` so the attributes round-trip cleanly. Co-Authored-By: Claude Opus 4.7 (1M context) --- threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang | 8 ++++++++ .../export/usd/ppisp_spg/ppisp_usd_spg.slang.lua | 5 +++++ .../export/usd/ppisp_spg/ppisp_usd_spg.slang.usda | 5 +++++ .../export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang | 8 ++++++++ .../export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua | 4 ++++ .../export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda | 5 +++++ threedgrut/export/usd/writers/ppisp_writer.py | 10 ++++++++++ 7 files changed, 45 insertions(+) diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang index b721cb6d..2e20fc55 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang @@ -26,6 +26,13 @@ struct PPISPParams { + // User-overridable per-camera responsivity, premultiplied to the + // input HdrColor before the rest of the pipeline runs. Defaults to + // (1, 1, 1) so the override is a no-op unless explicitly authored. + float responsivityR; + float responsivityG; + float responsivityB; + // Exposure float exposureOffset; @@ -215,6 +222,7 @@ void ppispProcess(uint3 tid : SV_DispatchThreadID) float4 pixel = g_InTex.Load(int3(tid.xy, 0)); float3 rgb = pixel.rgb; + rgb *= float3(g_Params.responsivityR, g_Params.responsivityG, g_Params.responsivityB); // Normalize to [-0.5, 0.5] range based on max dimension (matching CUDA kernel) float maxRes = max(float(w), float(h)); diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua index 971be9ea..716014f0 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.lua @@ -29,6 +29,11 @@ function ppispProcess(inputs, outputs, params) grid = { math.ceil(width / 16), math.ceil(height / 16), 1 }, bind = { slang.ParameterBlock( + -- Per-camera responsivity (premultiplied to input HDR) + slang.float(params["responsivityR"] or 1.0), + slang.float(params["responsivityG"] or 1.0), + slang.float(params["responsivityB"] or 1.0), + -- Exposure slang.float(params["exposureOffset"] or 0.0), diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda index b423a28e..50d72281 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg.slang.usda @@ -9,6 +9,11 @@ def Shader "SlangPPISP" uniform asset info:spg:sourceAsset = @ppisp_usd_spg.slang@ uniform token info:spg:sourceAsset:subIdentifier = "ppispProcess" + # User-overridable per-camera responsivity (premultiplied to input HDR). + float inputs:responsivityR = 1.0 + float inputs:responsivityG = 1.0 + float inputs:responsivityB = 1.0 + # Exposure parameter float inputs:exposureOffset = 0.0 diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang index e84fef70..0e582d74 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang @@ -24,6 +24,13 @@ struct PPISPDynParams { + // User-overridable per-camera responsivity, premultiplied to the + // input HdrColor before the rest of the pipeline runs. Defaults to + // (1, 1, 1) so the override is a no-op unless explicitly authored. + float responsivityR; + float responsivityG; + float responsivityB; + float2 vignettingCenterR; float vignettingAlpha1R; float vignettingAlpha2R; @@ -175,6 +182,7 @@ void ppispProcessDyn(uint3 tid : SV_DispatchThreadID) float4 pixel = g_InTex.Load(int3(tid.xy, 0)); float3 rgb = pixel.rgb; + rgb *= float3(g_Params.responsivityR, g_Params.responsivityG, g_Params.responsivityB); float maxRes = max(float(w), float(h)); float2 uv = float2(tid.x + 0.5 - float(w) * 0.5, diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua index 7d140e21..735b8cba 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.lua @@ -30,6 +30,10 @@ function ppispProcessDyn(inputs, outputs, params) grid = { math.ceil(width / 16), math.ceil(height / 16), 1 }, bind = { slang.ParameterBlock( + slang.float(params["responsivityR"] or 1.0), + slang.float(params["responsivityG"] or 1.0), + slang.float(params["responsivityB"] or 1.0), + getFloat2("vignettingCenterR"), slang.float(params["vignettingAlpha1R"] or 0.0), slang.float(params["vignettingAlpha2R"] or 0.0), diff --git a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda index 97848613..934bcd38 100644 --- a/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda +++ b/threedgrut/export/usd/ppisp_spg/ppisp_usd_spg_dyn.slang.usda @@ -9,6 +9,11 @@ def Shader "SlangPPISPDyn" uniform asset info:spg:sourceAsset = @ppisp_usd_spg_dyn.slang@ uniform token info:spg:sourceAsset:subIdentifier = "ppispProcessDyn" + # User-overridable per-camera responsivity (premultiplied to input HDR). + float inputs:responsivityR = 1.0 + float inputs:responsivityG = 1.0 + float inputs:responsivityB = 1.0 + # Vignetting (per channel: R, G, B) float2 inputs:vignettingCenterR = (0.0, 0.0) float inputs:vignettingAlpha1R = 0.0 diff --git a/threedgrut/export/usd/writers/ppisp_writer.py b/threedgrut/export/usd/writers/ppisp_writer.py index b3a43fb2..83771153 100644 --- a/threedgrut/export/usd/writers/ppisp_writer.py +++ b/threedgrut/export/usd/writers/ppisp_writer.py @@ -200,6 +200,15 @@ def _create_shader_prim( # --------------------------------------------------------------------------- +def _set_responsivity_params(shader: UsdShade.Shader) -> None: + """Author the user-overridable per-channel responsivity inputs (default + 1.0). The shader premultiplies these with the input HDR before the rest + of the PPISP pipeline runs; consumers can override the values per-camera + in the USD asset without re-running the export.""" + for channel in ("R", "G", "B"): + shader.CreateInput(f"responsivity{channel}", Sdf.ValueTypeNames.Float).Set(1.0) + + def _set_vignetting_params(shader: UsdShade.Shader, ppisp: PPISP, camera_index: int) -> None: """Set per-camera vignetting parameters (static). @@ -367,6 +376,7 @@ def add_ppisp_shader_to_render_product( return stage.GetPseudoRoot() shader = _create_shader_prim(stage, render_product_path, controller_shader=controller_shader) + _set_responsivity_params(shader) _set_vignetting_params(shader, ppisp, camera_index) _set_crf_params(shader, ppisp, camera_index) if controller_shader is not None: From 0e791a2a9758771db9ae2312102457747e34fddb Mon Sep 17 00:00:00 2001 From: Nicolas Moenne-Loccoz Date: Tue, 5 May 2026 16:38:35 -0400 Subject: [PATCH 42/42] chore(export): clean up PPISP PR docs Remove local Cursor agent metadata from the branch and fix stale PPISP plan references so the PR content matches the exported tooling. --- .cursor/rules/formatter-defaults.mdc | 10 ---------- docs/ppisp-controller-export-plan.md | 2 +- ...ppsip-to-rtx-pp-plan.md => ppisp-to-rtx-pp-plan.md} | 7 ++----- 3 files changed, 3 insertions(+), 16 deletions(-) delete mode 100644 .cursor/rules/formatter-defaults.mdc rename docs/{ppsip-to-rtx-pp-plan.md => ppisp-to-rtx-pp-plan.md} (98%) diff --git a/.cursor/rules/formatter-defaults.mdc b/.cursor/rules/formatter-defaults.mdc deleted file mode 100644 index a6e41aaf..00000000 --- a/.cursor/rules/formatter-defaults.mdc +++ /dev/null @@ -1,10 +0,0 @@ ---- -description: Formatter defaults for Cursor agents -alwaysApply: true ---- - -# Formatter Defaults - -- Do not run `bazel run //:format` by default. -- Prefer the smallest relevant formatter or validation command for the files changed, such as Python syntax checks, lints, or file-scoped formatters. -- Ask before running repository-wide formatting commands, especially when they may take a long time or touch unrelated files. diff --git a/docs/ppisp-controller-export-plan.md b/docs/ppisp-controller-export-plan.md index eee951f9..67ba67cc 100644 --- a/docs/ppisp-controller-export-plan.md +++ b/docs/ppisp-controller-export-plan.md @@ -137,7 +137,7 @@ Two-pronged approach: PyTorch controller's outputs to within a tight tolerance, using `slangpy` to dispatch the controller shader against a reference image. -2. **Tool: `tools/render_usd_renderproduct/`** — a slangpy-based runner +2. **Tool: `tools/render_ppisp_spg/`** — a slangpy-based runner that opens an exported USD/USDZ, walks `/Render/` prims, finds their SPG shader chain, and replays the chain on a supplied HDR input for every authored time sample. Useful for visual regression and for diff --git a/docs/ppsip-to-rtx-pp-plan.md b/docs/ppisp-to-rtx-pp-plan.md similarity index 98% rename from docs/ppsip-to-rtx-pp-plan.md rename to docs/ppisp-to-rtx-pp-plan.md index 55e36543..9822ddc9 100644 --- a/docs/ppsip-to-rtx-pp-plan.md +++ b/docs/ppisp-to-rtx-pp-plan.md @@ -23,7 +23,7 @@ the shared camera grouping and `/Render`/`RenderProduct` authoring are available ## 1. Context -The current PPISP USD export plan in `docs/ppisp-export-plan.md` uses a custom +The current PPISP USD export plan in `docs/ppisp-controller-export-plan.md` uses a custom SPG shader on each `RenderProduct` because PPISP is a post-blend image-space operator: @@ -243,7 +243,7 @@ Confidence: 0.25. ### Option R0 — SPG-only export -Keep the SPG plan from `docs/ppisp-export-plan.md` as the only PPISP-preserving +Keep the SPG plan from `docs/ppisp-controller-export-plan.md` as the only PPISP-preserving export path. Use USD post-processing only for user-authored artistic settings unrelated to @@ -614,9 +614,6 @@ Rationale: stable in the target Kit version? - Should Gaussian skip-tonemapping be disabled for PPISP USD approximation, or is the exported Gaussian material already authored for that path? -- Should the requested file name keep the `ppsip` spelling, or should a follow-up - rename to `ppisp-to-usd-post-processing-plan.md` be made? - --- ## 9. Current recommendation