Skip to content

Commit 2d356dc

Browse files
committed
Big Code Refactor
1 parent 3b4e715 commit 2d356dc

File tree

289 files changed

+340145
-732
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

289 files changed

+340145
-732
lines changed

.vscode/settings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"cmake.sourceDirectory": "C:/Users/reall/Softwares/ComfyUI_windows_portable/ComfyUI/custom_nodes/ComfyUI-3D-Pack/diff-gaussian-rasterization",
33
"python.analysis.extraPaths": [
4-
"./gen_3d_modules"
4+
"./gen_3d_modules",
5+
"./MVs_Algorithms"
56
]
67
}

Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ RUN pip install --no-cache -r requirements.txt
6363
WORKDIR /app/custom_nodes/ComfyUI-3D-Pack/
6464
COPY --chown=user:user requirements.txt requirements_post.txt ./
6565
COPY --chown=user:user simple-knn/ simple-knn/
66-
COPY --chown=user:user tgs/ tgs/
66+
COPY --chown=user:user Gen_3D_Modules/TriplaneGaussian/ Gen_3D_Modules/TriplaneGaussian/
6767
RUN pip install --no-cache -r requirements.txt \
6868
# post requirements installation require gpu, setup
6969
# `nvidia-container-runtime`, for docker, see

MVs_Algorithms/DiffRastMesh/__init__.py

Whitespace-only changes.
+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import random
2+
import tqdm
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
from kiui.mesh_utils import clean_mesh, decimate_mesh
8+
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
9+
from pytorch_msssim import SSIM, MS_SSIM
10+
11+
import comfy.utils
12+
13+
from .diff_mesh_renderer import DiffRastRenderer
14+
15+
from shared_utils.camera_utils import BaseCameraController
16+
from shared_utils.image_utils import prepare_torch_img
17+
18+
class DiffMeshCameraController(BaseCameraController):
    """Orbit-camera controller that renders through a DiffRastRenderer.

    Assumes the BaseCameraController (project class) provides `self.renderer`
    and a camera object `self.cam` exposing `.perspective`, `.H` and `.W` —
    TODO confirm against shared_utils.camera_utils.
    """

    def get_render_result(self, render_pose, bg_color, **kwargs):
        # Pair the requested pose with the controller camera's projection and
        # render at the camera's native resolution with no super-sampling.
        #ssaa = min(2.0, max(0.125, 2 * np.random.random()))
        return self.renderer.render(
            render_pose,
            self.cam.perspective,
            self.cam.H,
            self.cam.W,
            ssaa=1,
            bg_color=bg_color,
            **kwargs,
        )
class DiffMesh:
    """Optimize a textured mesh to match reference images from known orbit poses.

    Uses differentiable rasterization (DiffRastRenderer) to render the mesh,
    compares renders against masked reference images with an MSE + MS-SSIM
    loss, and optionally optimizes vertex offsets (geometry) with Laplacian /
    normal-consistency regularization and periodic remeshing.
    """

    def __init__(
        self,
        mesh,
        training_iterations,
        batch_size,
        texture_learning_rate,
        train_mesh_geometry,
        geometry_learning_rate,
        ms_ssim_loss_weight,
        remesh_after_n_iteration,
        invert_bg_prob,
        force_cuda_rasterize
    ):
        # CUDA is required: the renderer and the training loop use CUDA events
        # and (optionally) the nvdiffrast CUDA rasterization context.
        self.device = torch.device("cuda")

        self.train_mesh_geometry = train_mesh_geometry
        self.remesh_after_n_iteration = remesh_after_n_iteration

        # prepare main components for optimization
        self.renderer = DiffRastRenderer(mesh, force_cuda_rasterize).to(self.device)

        # Adam over the renderer's trainable params (texture, and geometry if enabled).
        self.optimizer = torch.optim.Adam(self.renderer.get_params(texture_learning_rate, train_mesh_geometry, geometry_learning_rate))
        #self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)
        self.ms_ssim_loss = MS_SSIM(data_range=1, size_average=True, channel=3)
        # Blend weight between the MS-SSIM term and the MSE term of the loss.
        self.lambda_ssim = ms_ssim_loss_weight

        self.training_iterations = training_iterations

        self.batch_size = batch_size

        # Probability of inverting the background color, forwarded to the
        # camera controller (presumably used for background augmentation —
        # TODO confirm against BaseCameraController).
        self.invert_bg_prob = invert_bg_prob

    def prepare_training(self, reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy):
        """Cache reference images/masks as CUDA tensors and set up the camera.

        `reference_images[i]` is indexed as [H, W, ...] (shape[0]/shape[1] are
        taken as height/width); `reference_masks[i]` is [H, W] and gets channel
        and batch dims added before resizing.
        """
        self.ref_imgs_num = len(reference_images)

        self.ref_size_H = reference_images[0].shape[0]
        self.ref_size_W = reference_images[0].shape[1]

        # default camera settings
        self.cam_controller = DiffMeshCameraController(
            self.renderer, self.ref_size_W, self.ref_size_H, reference_orbit_camera_fovy, self.invert_bg_prob, None, self.device
        )

        self.all_ref_cam_poses = reference_orbit_camera_poses

        # prepare reference images and masks
        ref_imgs_torch_list = []
        ref_masks_torch_list = []
        for i in range(self.ref_imgs_num):
            ref_imgs_torch_list.append(prepare_torch_img(reference_images[i].unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))
            ref_masks_torch_list.append(prepare_torch_img(reference_masks[i].unsqueeze(2).unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))

        # Batched [N, C, H, W] tensors on the training device.
        self.ref_imgs_torch = torch.cat(ref_imgs_torch_list, dim=0)
        self.ref_masks_torch = torch.cat(ref_masks_torch_list, dim=0)

    def training(self, decimate_target=5e4):
        """Run the optimization loop.

        Each step renders `batch_size` randomly chosen reference views,
        accumulates MSE + MS-SSIM (and optional geometry regularization)
        losses, and takes one Adam step. Every `remesh_after_n_iteration`
        steps the mesh is cleaned/remeshed and decimated to at most
        `decimate_target` triangles. Must be called after `prepare_training`.
        """
        # CUDA events for (currently unused) wall-clock timing of the loop.
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # Pre-mask all reference images once; each entry is [1, C, H, W].
        ref_imgs_masked = []
        for i in range(self.ref_imgs_num):
            ref_imgs_masked.append((self.ref_imgs_torch[i] * self.ref_masks_torch[i]).unsqueeze(0))

        ref_imgs_num_minus_1 = self.ref_imgs_num-1

        comfy_pbar = comfy.utils.ProgressBar(self.training_iterations)

        for step in tqdm.trange(self.training_iterations):

            ### calculate loss between reference and rendered image from known view
            loss = 0
            masked_rendered_img_batch = []
            masked_ref_img_batch = []
            for _ in range(self.batch_size):

                # Sample a reference view uniformly at random (with replacement).
                i = random.randint(0, ref_imgs_num_minus_1)

                out = self.cam_controller.render_at_pose(self.all_ref_cam_poses[i])

                image = out["image"] # [H, W, 3] in [0, 1]
                image = image.permute(2, 0, 1).contiguous() # [3, H, W] in [0, 1]

                # Mask the render so only foreground pixels contribute to the loss.
                image_masked = (image * self.ref_masks_torch[i]).unsqueeze(0)

                masked_rendered_img_batch.append(image_masked)
                masked_ref_img_batch.append(ref_imgs_masked[i])

            masked_rendered_img_batch_torch = torch.cat(masked_rendered_img_batch, dim=0)
            masked_ref_img_batch_torch = torch.cat(masked_ref_img_batch, dim=0)

            # rgb loss
            loss += (1 - self.lambda_ssim) * F.mse_loss(masked_rendered_img_batch_torch, masked_ref_img_batch_torch)

            # D-SSIM loss
            # [1, 3, H, W] in [0, 1]
            #loss += self.lambda_ssim * (1 - self.ssim_loss(X, Y))
            loss += self.lambda_ssim * (1 - self.ms_ssim_loss(masked_ref_img_batch_torch, masked_rendered_img_batch_torch))

            # Regularization loss
            if self.train_mesh_geometry:
                current_v = self.renderer.mesh.v + self.renderer.v_offsets
                loss += 0.01 * laplacian_smooth_loss(current_v, self.renderer.mesh.f)
                loss += 0.001 * normal_consistency(current_v, self.renderer.mesh.f)
                # Penalize large vertex displacements from the base mesh.
                loss += 0.1 * (self.renderer.v_offsets ** 2).sum(-1).mean()

            # remesh periodically
            if step > 0 and step % self.remesh_after_n_iteration == 0:
                vertices = (self.renderer.mesh.v + self.renderer.v_offsets).detach().cpu().numpy()
                triangles = self.renderer.mesh.f.detach().cpu().numpy()
                vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
                if triangles.shape[0] > decimate_target:
                    vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
                self.renderer.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                self.renderer.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                # NOTE(review): the Adam optimizer created in __init__ still holds
                # the OLD v_offsets Parameter; after this reassignment geometry
                # updates may no longer reach the optimizer — verify, and consider
                # rebuilding the optimizer (or its param group) after a remesh.
                self.renderer.v_offsets = nn.Parameter(torch.zeros_like(self.renderer.mesh.v)).to(self.device)

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            comfy_pbar.update_absolute(step + 1)

        torch.cuda.synchronize()

        self.need_update = True

        print(f"Step: {step}")

        # Bake the optimized offsets/texture back into the mesh object.
        self.renderer.update_mesh()

        ender.record()
        #t = starter.elapsed_time(ender)

    def get_mesh_and_texture(self):
        """Return the optimized mesh and its (baked) albedo texture as a 2-tuple."""
        return (self.renderer.mesh, self.renderer.mesh.albedo, )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import os
2+
import math
3+
4+
import numpy as np
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
import nvdiffrast.torch as dr
9+
10+
from kiui.op import inverse_sigmoid
11+
12+
from mesh_processer.mesh import safe_normalize
13+
14+
def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of NHWC images to `size` (target H, W).

    The `min` filter is used when both spatial dims shrink, the `mag` filter
    when they grow; mixed shrink/grow requests are rejected. Returns a
    contiguous NHWC tensor.
    """
    shrinking = x.shape[1] >= size[0] and x.shape[2] >= size[1]
    growing = x.shape[1] < size[0] and x.shape[2] < size[1]
    assert shrinking or growing, "Trying to magnify image in one dimension and minify in the other"

    nchw = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for F.interpolate
    if x.shape[1] > size[0] and x.shape[2] > size[1]:
        # Strict minification: previous size was bigger.
        nchw = torch.nn.functional.interpolate(nchw, size, mode=min)
    elif mag == 'bilinear' or mag == 'bicubic':
        # Magnification with a filter that supports corner alignment.
        nchw = torch.nn.functional.interpolate(nchw, size, mode=mag, align_corners=True)
    else:
        nchw = torch.nn.functional.interpolate(nchw, size, mode=mag)
    return nchw.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
26+
def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a single HWC image by delegating to the batched NHWC resizer."""
    batched = scale_img_nhwc(x[None, ...], size, mag, min)
    return batched[0]
def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of channel-less NHW images via the NHWC resizer."""
    with_channel = scale_img_nhwc(x[..., None], size, mag, min)
    return with_channel[..., 0]
def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    """Resize a single channel-less HW image via the NHWC resizer."""
    batched = scale_img_nhwc(x[None, ..., None], size, mag, min)
    return batched[0, ..., 0]
def make_divisible(x, m=8):
    """Round `x` up to the nearest multiple of `m`."""
    multiples = math.ceil(x / m)
    return int(multiples * m)
class DiffRastRenderer(nn.Module):
    """Differentiable mesh renderer built on nvdiffrast.

    Holds the mesh plus two trainable parameters: per-vertex offsets
    (`v_offsets`) and the albedo texture in logit space (`raw_albedo`).
    """

    def __init__(self, mesh, force_cuda_rast):

        super().__init__()

        self.mesh = mesh

        # Use the CUDA rasterizer everywhere except on Windows (os.name == 'nt'),
        # unless the caller forces it.
        if force_cuda_rast or os.name != 'nt':
            self.glctx = dr.RasterizeCudaContext()
        else:
            self.glctx = dr.RasterizeGLContext()

        # extract trainable parameters
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v), requires_grad=True)
        # Texture is stored as logits so sigmoid keeps rendered albedo in [0, 1].
        self.raw_albedo = nn.Parameter(inverse_sigmoid(self.mesh.albedo), requires_grad=True)

        # Set by get_params(); when False, v_offsets are ignored during render().
        self.train_geo = False

    def get_params(self, texture_lr, train_geo, geom_lr):
        """Return Adam-style param groups; records whether geometry is trained.

        Always includes the texture logits; adds the vertex offsets only when
        `train_geo` is truthy.
        """

        params = [
            {'params': self.raw_albedo, 'lr': texture_lr},
        ]

        self.train_geo = train_geo
        if train_geo:
            params.append({'params': self.v_offsets, 'lr': geom_lr})

        return params

    def update_mesh(self):
        """Bake the current offsets and texture back into the mesh (detached)."""
        self.mesh.v = (self.mesh.v + self.v_offsets).detach()
        self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())

    # NOTE(review): mutable default argument — safe only because the list is
    # never mutated here; consider `optional_render_types=None` + default inside.
    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear',
               optional_render_types=['depth', 'normal']):
        """Render the mesh from a camera pose.

        Args:
            pose: 4x4 camera-to-world matrix as a numpy float array.
            proj: 4x4 projection matrix as a numpy float array.
            h0, w0: output image height/width.
            ssaa: super-sampling factor; renders at ssaa * (h0, w0) then downscales.
            bg_color: background value composited where alpha < 1.
            texture_filter: nvdiffrast texture filter mode.
            optional_render_types: subset of {'depth', 'normal'} to also produce.

        Returns:
            dict with 'image' [H, W, 3], 'alpha' [H, W, 1], and optionally
            'depth', 'normal' (mapped to [0, 1]) and 'viewcos'.
        """

        # do super-sampling
        if ssaa != 1:
            # Round render size up to a multiple of 8 for the rasterizer.
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0

        results = {}

        # get v
        if self.train_geo:
            v = self.mesh.v + self.v_offsets # [N, 3]
        else:
            v = self.mesh.v

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        # World -> camera (inverse pose) on homogeneous vertices, then -> clip space.
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        #alpha = (rast[0, ..., 3:] > 0).float() # [H, W, 1]
        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, self.mesh.f).clamp(0, 1).squeeze(0) # [H, W, 1] important to enable gradients!

        # render albedo
        texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
        # Sampled logits -> [0, 1] albedo.
        albedo = torch.sigmoid(albedo)

        # render depth
        if 'depth' in optional_render_types:
            # Camera-space z is negative in front of the camera; negate for positive depth.
            depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1]
            depth = depth.squeeze(0) # [H, W, 1]

        # get vn and render normal
        if 'normal' in optional_render_types:
            if self.train_geo:
                # Recompute vertex normals from the offset vertices: area-weighted
                # face normals scattered onto the three corner vertices.
                i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
                v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

                # NOTE(review): dim-less torch.cross is deprecated in newer torch;
                # behavior here relies on the legacy default (first dim of size 3) —
                # consider torch.cross(..., dim=-1).
                face_normals = torch.cross(v1 - v0, v2 - v0)
                face_normals = safe_normalize(face_normals)

                vn = torch.zeros_like(v)
                vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
                vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
                vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

                # Degenerate (near-zero) normals fall back to +z.
                vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
            else:
                vn = self.mesh.vn

            normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
            normal = safe_normalize(normal[0])

            # rotated normal (where [0, 0, 1] always faces camera)
            rot_normal = normal @ pose[:3, :3]
            viewcos = rot_normal[..., [2]]

        # antialias
        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0).contiguous() # [H, W, 3]
        # Composite over the background using the antialiased alpha.
        albedo = alpha * albedo + (1 - alpha) * bg_color

        # ssaa
        if ssaa != 1:
            # Downscale all buffers from the super-sampled size back to (h0, w0).
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            if 'depth' in optional_render_types:
                depth = scale_img_hwc(depth, (h0, w0))
            if 'normal' in optional_render_types:
                normal = scale_img_hwc(normal, (h0, w0))
                viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        if 'depth' in optional_render_types:
            results['depth'] = depth
        if 'normal' in optional_render_types:
            # Map normals from [-1, 1] to [0, 1] for image output.
            results['normal'] = (normal + 1) / 2
            results['viewcos'] = viewcos

        return results

MVs_Algorithms/FlexiCubes/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)