# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
import time
import argparse
import json
import sys
import glob

import tqdm
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import topology / geometry trainers
from lib.geometry.dmtet import DMTetGeometry

import lib.render.renderutils as ru
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render

from pytorch3d.io import save_obj
import pymeshlab

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

# Maps each supported FLAGS.loss name to the (loss, tonemapper) keyword pair
# forwarded to ru.image_loss.
_LOSS_CONFIGS = {
    "smape":  ("smape",  "none"),
    "mse":    ("mse",    "none"),
    "logl1":  ("l1",     "log_srgb"),
    "logl2":  ("mse",    "log_srgb"),
    "relmse": ("relmse", "none"),
}


@torch.no_grad()
def createLoss(FLAGS):
    """Build the image loss selected by ``FLAGS.loss``.

    Args:
        FLAGS: Options namespace; only ``FLAGS.loss`` is read. Accepted values
            are the keys of ``_LOSS_CONFIGS`` ("smape", "mse", "logl1",
            "logl2", "relmse").

    Returns:
        A callable ``fn(img, ref)`` that calls ``ru.image_loss(img, ref, ...)``
        with the matching ``loss`` / ``tonemapper`` keyword arguments.

    Raises:
        ValueError: If ``FLAGS.loss`` is not a supported name. (This used to be
            an ``assert``, which is stripped under ``python -O`` and then made
            the function silently return ``None``.)
    """
    # Note: the no_grad decorator only covers this factory call; it does not
    # wrap the returned callable.
    try:
        loss, tonemapper = _LOSS_CONFIGS[FLAGS.loss]
    except KeyError:
        raise ValueError("Unknown loss type %s" % FLAGS.loss) from None
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)


###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Attach a background image to ``target`` and move its camera tensors to CUDA.

    Args:
        target: Batch dict. ``mv``, ``mvp`` and ``campos`` are replaced by their
            CUDA copies. ``resolution`` sizes the solid/random backgrounds.
            ``img`` is read when present; it is required for
            ``bg_type='reference'``.
        bg_type: One of ``'checker'``, ``'black'``, ``'white'``, ``'reference'``
            or ``'random'``.

    Returns:
        The same ``target`` dict, mutated in place, with ``target['background']`` set.

    Raises:
        ValueError: If ``bg_type`` is not recognized.
    """
    # assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        # Dataset batches are sized from their image. Camera-only poses (e.g.
        # from rotate_scene) have no 'img' entry, so fall back to 'resolution'
        # like the other branches do.
        res = target['img'].shape[1:3] if 'img' in target else target['resolution']
        background = torch.tensor(util.checkerboard(res, 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type in ('black', 'white', 'random'):
        shape = (1, target['resolution'][0], target['resolution'][1], 3)
        fill = {'black': torch.zeros, 'white': torch.ones, 'random': torch.rand}[bg_type]
        background = fill(shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    else:
        raise ValueError("Unknown background type %s" % bg_type)

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['background'] = background
    return target


def _material_limits(FLAGS):
    """Return ``((kd_min, kd_max), (ks_min, ks_max), (nrm_min, nrm_max))`` as float32 CUDA tensors built from FLAGS."""
    def as_tensor(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')
    return ((as_tensor(FLAGS.kd_min), as_tensor(FLAGS.kd_max)),
            (as_tensor(FLAGS.ks_min), as_tensor(FLAGS.ks_max)),
            (as_tensor(FLAGS.nrm_min), as_tensor(FLAGS.nrm_max)))

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current geometry with xatlas and bake its material into 2D textures.

    Args:
        glctx: nvdiffrast rasterization context passed to ``render.render_uv``.
        geometry: Object exposing ``getMesh(mat)``.
        mat: Material passed to ``getMesh``. Its ``'bsdf'`` entry is copied
            into the result, and the extracted mesh's ``material['kd_ks_normal']``
            is what gets baked.
        FLAGS: Reads ``texture_res``, ``layers`` and the kd/ks/nrm min/max limits.

    Returns:
        A new ``mesh.Mesh`` with xatlas UVs and a ``material.Material`` holding
        ``Texture2D`` kd/ks/normal maps.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    _vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors: reinterpret the index buffer as int64 for torch indexing.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    _mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Multi-layer mode expects a 4th kd channel; initialise it randomly.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    (kd_min, kd_max), (ks_min, ks_max), (nrm_min, nrm_max) = _material_limits(FLAGS)

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the trainable material used to start optimization.

    Args:
        geometry: Used only in MLP mode, for ``geometry.getAABB()``.
        mlp: If truthy, use a single 9-channel ``MLPTexture3D`` (kd rgb + ks + normal).
        FLAGS: Reads the kd/ks/nrm limits, ``layers``, ``texture_res``,
            ``custom_mip`` and ``random_textures``.
        init_mat: Optional material providing ``'kd'``, ``'ks'``, optionally
            ``'normal'``, and ``'bsdf'``.

    Returns:
        A ``material.Material``. Its ``'bsdf'`` is taken from ``init_mat`` when
        given, else ``'pbr'``.
    """
    (kd_min, kd_max), (ks_min, ks_max), (nrm_min, nrm_max) = _material_limits(FLAGS)

    if mlp:
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init , FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())

            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'

    return mat

###############################################################################
# Validation & testing
###############################################################################

def rotate_scene(FLAGS, itr):
    """Build a display camera dict for turntable frame ``itr``.

    The view matrix is ``translate(0, 0, -RADIUS) @ rotate_x(-0.4) @ rotate_y(ang)``
    with ``ang = itr / 50`` of a full turn, so ``itr`` in [0, 50) covers one revolution.

    Returns:
        Dict with CUDA ``mv``, ``mvp`` and ``campos`` (each with a leading batch
        axis of 1), ``spp=1`` and ``resolution=FLAGS.display_res``.
    """
    fovy = np.deg2rad(45)
    cam_radius = RADIUS
    proj_mtx = util.perspective(fovy, FLAGS.display_res[1] / FLAGS.display_res[0], FLAGS.cam_near_far[0], FLAGS.cam_near_far[1])

    # Smooth rotation for display.
    ang = (itr / 50) * np.pi * 2
    mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
    mvp = proj_mtx @ mv
    campos = torch.linalg.inv(mv)[:3, 3]

    return {
        'mv': mv[None, ...].cuda(),
        'mvp': mvp[None, ...].cuda(),
        'campos': campos[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res
    }


def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one view without gradients.

    Returns:
        ``(result_image, result_dict)``. ``result_dict['opt']`` is the sRGB
        render, taking element ``[0]`` of the shaded buffer; ``result_image``
        is the same tensor. When ``target`` has an ``'img'`` entry,
        ``result_dict['ref']`` holds its sRGB RGB channels (element ``[0]``),
        moved to the render's device.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])

        buffers = geometry.render(glctx, target, lgt, opt_material)

        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        if 'img' in target:
            # validate() compares against 'ref'; camera-only poses have no image.
            # NOTE(review): prepare_batch does not composite 'img' over the
            # background, so transparent pixels are compared as stored — confirm.
            ref_rgb = target['img'][..., 0:3].to(result_dict['opt'].device)
            result_dict['ref'] = util.rgb_to_srgb(ref_rgb)[0]
        result_image = result_dict['opt']
    return result_image, result_dict


def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Render every view of ``dataset_validate``, save images and per-view metrics.

    Writes ``metrics.txt`` (ID, MSE, PSNR per view plus averages) and one
    ``val_<id>_<key>.png`` per entry of the result dict into ``out_dir``.

    Returns:
        The mean PSNR over all views.
    """
    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            for k, img in result_dict.items():
                np_img = img.detach().cpu().numpy()
                util.save_image(os.path.join(out_dir, 'val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

class Trainer(torch.nn.Module):
    """Bundle geometry, material and light. ``forward`` runs one ``geometry.tick``.

    ``params`` holds the material parameters (plus light parameters when
    ``optimize_light``). ``geo_params`` holds the geometry parameters when
    ``optimize_geometry``, else it is empty.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super().__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A frozen light only needs its mip chain built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        if self.optimize_light:
            # The light changes every step, so rebuild its mips each call.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # Use the context this trainer was built with, not a module-level global.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)

#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a CLI boolean. argparse's ``type=bool`` would turn any non-empty string, even 'False', into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    # Two values, like --train-res: the code below indexes display_res[0] and display_res[1].
    parser.add_argument('-dr', '--display-res', nargs=2, type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet')
    parser.add_argument('-sp', '--sample-path', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    parser.add_argument('--validate', type=_str2bool, default=True)
    parser.add_argument('--angle-ind', type=int, default=25, help='z-axis rotation of the object, from 0 to 50')
    parser.add_argument('-ns', '--num-smooth-steps', type=int, default=3, help='number of post-processing Laplacian smoothing steps')
    parser.add_argument('--tet-path', type=str, default=None, help='Path to the .npz tet grid (can also be set as "tet_path" in --config)')

    FLAGS = parser.parse_args()

    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.dmtet_grid          = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale          = 2.1                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.envmap              = None                     # HDR environment probe
    FLAGS.display             = None                     # Conf validation window/display. E.g. [{"relight" : }]
    FLAGS.camera_space_light  = False                    # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 10000.0                  # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0] # Limits for kd
    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]
    FLAGS.ks_min              = [ 0.0, 0.08,  0.0]       # Limits for ks
    FLAGS.ks_max              = [ 1.0,  1.0,  1.0]
    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]       # Limits for normal map
    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]
    FLAGS.cam_near_far        = [0.1, 1000.0]
    FLAGS.learn_light         = False
    FLAGS.cropped             = False
    FLAGS.random_lgt          = False

    # Config-file entries override both CLI values and the defaults above.
    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as config_file:
            data = json.load(config_file)
        for key, value in data.items():
            setattr(FLAGS, key, value)

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # tet_path used to be read unconditionally and raised AttributeError when
    # no config provided it; fail early with an actionable message instead.
    if FLAGS.tet_path is None:
        parser.error('a tet grid is required: pass --tet-path or set "tet_path" in the --config file')

    if FLAGS.unnormalized_sdf:
        raise NotImplementedError("--unnormalized_sdf is not supported yet")

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    glctx = dr.RasterizeGLContext()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================

    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(
        resolution, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.deform_scale
    )

    # NOTE(review): `mask` is not used anywhere below — confirm whether it is still needed.
    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")

    ### compute the mapping from tet indices to 3D cubic grid vertex indices
    tet = np.load(FLAGS.tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices.unique()
    # Lattice spacing, taken from the two smallest distinct coordinates; this
    # assumes a regular grid with the same spacing on every axis.
    dx = vertices_unique[1] - vertices_unique[0]
    vertices_discretized = (torch.round((vertices - vertices.min()) / dx)).long()
    ix, iy, iz = vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]

    data_all = np.load(FLAGS.sample_path)
    print('shape of generated data', data_all.shape)

    ### mtl for visualization (identical for every sample, so built once)
    opt_material = {
        'name' : '_default_mat',
        # 'bsdf' : 'pbr',
        'bsdf' : 'diffuse',
        'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
        'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
    }

    # Fixed display camera for every sample (pose index 0 to 50).
    v_pose = rotate_scene(FLAGS, FLAGS.angle_ind)

    for no_data in tqdm.trange(data_all.shape[0]):
        grid = torch.tensor(data_all[no_data])
        # Gather every channel at each tet vertex: row 0 feeds the SDF, rows 1:
        # feed the per-vertex deformation (as assigned below).
        samples = grid[:, ix, iy, iz]
        geometry.sdf.data[:] = torch.sign(samples[0]).cuda()
        geometry.deform.data[:] = samples[1:].cuda().transpose(0, 1)
        geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

        ### extract the mesh from the current tet grid
        base_mesh = geometry.getMesh(opt_material)

        ### save image (before post-processing)
        # prepare_batch is re-applied each time so per-call backgrounds are regenerated.
        result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
        result_image = result_image.detach().cpu().numpy()
        util.save_image(os.path.join(viz_path, ('%s_%06d.png' % (FLAGS.viz_name, no_data))), result_image)

        ### save post-processed mesh
        # The raw mesh is written first so pymeshlab can load it; the processed
        # result then overwrites the same file.
        mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(no_data))
        save_obj(
            verts=base_mesh.v_pos,
            faces=base_mesh.t_pos_idx,
            f=mesh_savepath
        )

        ms = pymeshlab.MeshSet()
        ms.load_new_mesh(mesh_savepath)
        ms.meshing_isotropic_explicit_remeshing()
        ms.apply_coord_laplacian_smoothing(stepsmoothnum=FLAGS.num_smooth_steps, cotangentweight=False)
        # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
        ms.meshing_isotropic_explicit_remeshing()
        # NOTE(review): if this pymeshlab version records applied filters into the
        # current filter script, this re-runs the pipeline above a second time — confirm intent.
        ms.apply_filter_script()
        ms.save_current_mesh(mesh_savepath)