# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
import numpy as np
import torch
import torch.nn.functional as F
import kaolin
import pytorch3d.ops

from ..render import mesh
from ..render import render
from ..render import regularizer
from ..render import util as render_utils
from ..render import renderutils as ru

###############################################################################
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
###############################################################################

class DMTet:
    def __init__(self):
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device='cuda')

        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')

    ###############################################################################
    # Utility functions
    ###############################################################################

    def sort_edges(self, edges_ex2):
        with torch.no_grad():
            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1-order, dim=1)

        return torch.stack([a, b], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        N = int(np.ceil(np.sqrt((max_idx+1)//2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            indexing='ij'
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x,       tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x,       tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim=-1).view(-1, 3)

        return uvs, uv_idx

    ###############################################################################
    # Marching tets implementation
    ###############################################################################

    def __call__(self, pos_nx3, sdf_n, tet_fx4):
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1,2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
            unique_edges = unique_edges.long()

            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device="cuda")
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_sdf[:,-1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1,6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
        ), dim=0)

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[num_triangles == 1] * 2,
            torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)

        face_to_valid_tet = torch.cat((
            tet_gidx[num_triangles == 1],
            torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
        ), dim=0)

        valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()

        return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx

###############################################################################
# Regularizer
###############################################################################

def sdf_reg_loss(sdf, all_edges):
    sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
    mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
    sdf_f1x6x2 = sdf_f1x6x2[mask]
    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
               torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
    return sdf_diff

class Buffer(object):
    def __init__(self, shape, capacity, device) -> None:
        self.len_curr = 0
        self.pointer = 0
        self.capacity = capacity
        self.buffer = torch.zeros((capacity,) + shape, device=device)

    def push(self, x):
        '''
        Push one single data point into the buffer
        '''
        self.buffer[self.pointer] = x
        self.pointer = (self.pointer + 1) % self.capacity
        if self.len_curr < self.capacity:
            self.len_curr += 1

    def avg(self):
        # Windowed majority vote of the sign (no exponential decay):
        # returns -1/0/+1 per element over the currently filled window.
        return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float()

###############################################################################
#  Geometry interface
###############################################################################

class DMTetGeometry(torch.nn.Module):
    def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=1.0, **kwargs):
        super(DMTetGeometry, self).__init__()

        self.FLAGS = FLAGS
        self.grid_res = grid_res
        self.marching_tets = DMTet()
        self.tanh = False
        self.deform_scale = deform_scale
        self.grid_to_tet = grid_to_tet

        self.padding = 5
        self.smooth_kernel = torch.ones(1, 1, self.padding*2 + 1, self.padding*2 + 1).cuda()

        tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res)))
        self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
        self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
        self.generate_edges()

        # Random init
        sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1

        self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
        self.register_parameter('sdf', self.sdf)

        self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
        self.register_parameter('deform', self.deform)

        self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False)
        self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
        self.ema_coeff = 0.9

        self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda')

    def generate_edges(self):
        with torch.no_grad():
            edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device="cuda")
            all_edges = self.indices[:, edges].reshape(-1,2)
            all_edges_sorted = torch.sort(all_edges, dim=1)[0]
            self.all_edges = torch.unique(all_edges_sorted, dim=0)

    def getAABB(self):
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def getVertNNDist(self):
        v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0)
        # K=2 because dist(self, self) = 0
        return pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2).dists[0, :, -1].detach()

    def getTetCenters(self):
        v_deformed = self.get_deformed()        # size: N x 3
        face_verts = v_deformed[self.indices]   # size: M x 4 x 3
        face_centers = face_verts.mean(dim=1)   # size: M x 3
        return face_centers

    def getValidTetIdx(self):
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
        return tet_gidx.long()

    def getValidVertsIdx(self):
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
        return self.indices[tet_gidx.long()].unique()

    def getMesh(self, material, noise=0.0, ema=False):
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed(ema=ema)

        if ema:
            # sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff
            sdf = self.sdf_ema
        else:
            sdf = self.sdf

        verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)

        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
        imesh = mesh.auto_normals(imesh)

        if material is not None:
            # Run mesh operations to generate tangent space
            imesh = mesh.compute_tangents(imesh)
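        # `valid_vert_idx` lists the tet-grid vertices belonging to surface-crossing tetrahedra;
        # tick() uses it to detach those SDF entries so the sign-consistency regularizer only
        # pushes gradients into vertices away from the current surface.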
        imesh.valid_vert_idx = valid_vert_idx

        return imesh

    def get_deformed(self, no_grad=False, ema=False):
        if no_grad:
            deform = self.deform.detach()
        else:
            deform = self.deform

        if self.tanh:
            # v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)
            v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale
        else:
            v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale

        return v_deformed

    def get_angle(self):
        with torch.no_grad():
            comb_list = [
                (0, 1, 2, 3),
                (0, 1, 3, 2),
                (0, 2, 3, 1),
                (1, 2, 3, 0)
            ]
            directions = torch.zeros(self.indices.size(0), 4).cuda()
            dir_vec = torch.zeros(self.indices.size(0), 4, 3).cuda()
            vert_inds = torch.zeros(self.indices.size(0), 4).cuda().long()
            count = 0
            vpos_list = self.get_deformed()
            for comb in comb_list:
                face = self.indices[:, comb[:3]]
                face_pos = vpos_list[face, :]
                face_center = face_pos.mean(1, keepdim=False)

                v = self.indices[:, comb[3]]
                test_vec = vpos_list[v]
                ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center)
                distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec

                directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0])
                dir_vec[:, count, :] = distance_vec
                vert_inds[:, count] = v
                count += 1

        return directions, dir_vec, vert_inds

    def clamp_deform(self):
        if not self.tanh:
            self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)
            self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0)

    def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
        opt_mesh = self.getMesh(opt_material, ema=ema)
        tet_centers = self.getTetCenters() if get_visible_tets else None
        return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
                                  spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf,
                                  xfm_lgt=xfm_lgt, tet_centers=tet_centers)

    def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None):
        opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema)
        return opt_mesh, render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
                                            spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf,
                                            xfm_lgt=xfm_lgt)

    def update_ema(self, ema_coeff=0.9):
        self.sdf_buffer.push(self.sdf)
        self.sdf_ema.data[:] = self.sdf_buffer.avg()
        self.deform_ema.data[:] = self.deform.data[:]

    def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
        opt_mesh = self.getMesh(opt_material, ema=True)
        return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
                                  spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf,
                                  xfm_lgt=xfm_lgt)

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
        self.deform.requires_grad = True

        if iteration > 200 and iteration < 2000 and iteration % 20 == 0:
            with torch.no_grad():
                # Project the deformed grid vertices into the current view and reset SDF/deform
                # for vertices that land far outside the (dilated) object mask.
                v_pos = self.get_deformed()
                v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp'])
                v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:]
                v_pos_camera_discrete = torch.round((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long()
                mask_cont = F.conv2d(target['mask_cont'][:, :, :, 0].unsqueeze(1), self.smooth_kernel, stride=1, padding=self.padding)[:, 0]
                target_mask = mask_cont == 0
                for k in range(target_mask.size(0)):
                    assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0]
                    v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0))
                    self.sdf.data[v_mask] = 1e-2
                    self.deform.data[v_mask] = 0.0

        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt)

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        t_iter = iteration / self.FLAGS.iter

        # Image-space loss, split into a coverage component and a color component
        color_ref = target['img']
        img_loss = torch.tensor(0.0).cuda()
        alpha_scale = 1.0
        img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) * alpha_scale
        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])

        color_ref_second = target['img_second']
        img_loss = img_loss + torch.nn.functional.mse_loss(buffers['shaded_second'][..., 3:], color_ref_second[..., 3:]) * alpha_scale * 1e-1
        img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_second[..., 3:], color_ref_second[..., 0:3] * color_ref_second[..., 3:]) * 1e-1

        mask = (target['mask_cont'][:, :, :, 0] == 1.0).float()

        if iteration < 10000:
            depth_scale = 100.0
        else:
            depth_scale = 1.0

        if iteration % 300 == 0 and iteration < 1790:
            self.deform.data[:] *= 0.4

        if no_depth_thin:
            valid_depth_mask = (target['depth_second'] >= 0).float().detach()
            depth_prox_mask = ((target['depth_second'] - target['depth']).abs() >= 5e-3).float().detach()
        else:
            valid_depth_mask = 1.0
            depth_prox_mask = 1.0  # keep the second-depth term defined in this branch as well

        depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
        depth_diff_second = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * depth_prox_mask * 1e-1

        # Robust depth loss: L1 below the threshold, quadratic (offset for continuity) above it
        thres = 1.0
        l1_loss_mask = (depth_diff < thres).float()
        l1_loss_mask_second = (depth_diff_second < thres).float()
        img_loss = img_loss + (
            (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * (depth_diff.pow(2) + thres - thres**2)).mean() * 1.0 * depth_scale
            + (l1_loss_mask_second * depth_diff_second + (1 - l1_loss_mask_second) * (depth_diff_second.pow(2) + thres - thres**2)).mean() * 1.0 * depth_scale
        )

        reg_loss = torch.tensor(0.0).cuda()

        # SDF regularizer
        iter_thres = 0
        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres)))  # Dropoff to 0.01
        sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device)
        sdf_mask[imesh.valid_vert_idx] = 1.0
        sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask)
        reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 0.1

        # Albedo (k_d) smoothness regularizer
        reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)

        # Visibility regularizer
        reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500)

        # Point cloud chamfer distance
        pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
        target_pts = target['spts']
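        # The chamfer term below is the symmetric point-to-point distance between points sampled
        # on the extracted mesh and the target scan points,
        #   CD(P, Q) = mean_{p in P} min_{q in Q} ||p - q||^2 + mean_{q in Q} min_{p in P} ||q - p||^2,
        # as computed (on squared distances) by kaolin's batched chamfer_distance, hence the
        # unsqueeze(0) calls and the final .mean() over the batch dimension.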
        chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
        reg_loss += chamfer

        return img_loss, reg_loss
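
###############################################################################
# Example (illustrative sketch only, not used by the training pipeline)
###############################################################################

def _example_marching_tets():
    # Minimal sketch of how DMTet extracts a surface: a single tetrahedron whose SDF changes
    # sign across its edges yields one triangle, with vertices placed by linear interpolation
    # of the SDF along the sign-crossing edges. The helper itself is hypothetical and only
    # meant as a usage reference; it assumes a CUDA device, since DMTet allocates its lookup
    # tables on 'cuda'.
    dmtet = DMTet()
    pos_nx3 = torch.tensor([[0.0, 0.0, 0.0],
                            [1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, 0.0, 1.0]], device='cuda')                  # tet vertex positions
    tet_fx4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device='cuda')   # a single tetrahedron
    sdf_n = torch.tensor([-0.5, 0.5, 0.5, 0.5], device='cuda')                # vertex 0 inside, rest outside
    verts, faces, uvs, uv_idx, face_to_tet, valid_verts = dmtet(pos_nx3, sdf_n, tet_fx4)
    # Expected: 3 interpolated vertices (one per sign-crossing edge) and a single triangle.
    return verts, faces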