# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
import json
import os
import sys
import time
import traceback
from random import randint
from time import sleep

import cv2
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from lib.dataset.dataset_mesh import DatasetMesh
from lib.dataset.dataset_shapenet import ShapeNetDataset

# Import topology / geometry trainers
from lib.geometry.dmtet_singleview import DMTetGeometry
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo

import lib.render.renderutils as ru
from lib.render import obj
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render

# Passed to DatasetMesh as its radius argument (presumably the camera orbit radius -- confirm).
RADIUS = 2.0

# Colour stops for the depth visualization, as BGR tuples (OpenCV channel order).
color1 = (0, 0, 255)      # red
color2 = (0, 165, 255)    # orange
color3 = (0, 255, 255)    # yellow
color4 = (255, 255, 0)    # cyan
color5 = (255, 0, 0)      # blue
color6 = (128, 64, 64)    # violet
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)

# Linearly interpolate the 6 stops into a 1x256x3 table usable with cv2.LUT.
lut = cv2.resize(colorArr, (256, 1), interpolation=cv2.INTER_LINEAR)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-loss callable ``loss(img, ref)`` selected by ``FLAGS.loss``.

    Every option forwards to ``ru.image_loss`` with a fixed (loss, tonemapper) pair.

    Raises:
        ValueError: if ``FLAGS.loss`` is not one of the supported names.
    """
    # name -> (ru.image_loss `loss` argument, `tonemapper` argument)
    loss_configs = {
        'smape':  ('smape',  'none'),
        'mse':    ('mse',    'none'),
        'logl1':  ('l1',     'log_srgb'),
        'logl2':  ('mse',    'log_srgb'),
        'relmse': ('relmse', 'none'),
    }
    if FLAGS.loss not in loss_configs:
        raise ValueError("Unknown loss type %s" % FLAGS.loss)
    loss_name, tonemapper = loss_configs[FLAGS.loss]
    return lambda img, ref: ru.image_loss(img, ref, loss=loss_name, tonemapper=tonemapper)

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a batch to the GPU and composite its RGBA image over a background.

    Args:
        target: batch dict with at least 'img' ([n, h, w, c], channel 3 is
            alpha), 'mv', 'mvp', 'campos', 'spts' and 'vpts' tensors.
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        A new dict with the same entries, plus 'background' and an 'img' whose
        RGB has been lerped over the background (alpha is kept).

    The listed tensors are moved to CUDA *in place* on ``target`` so repeated
    calls stay cheap. ``target['img']`` keeps the raw RGBA, so calling this
    again on the same batch does not compound the background blend.

    Raises:
        ValueError: on an unknown ``bg_type``.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    for key in ('mv', 'mvp', 'campos', 'img', 'spts', 'vpts'):
        target[key] = target[key].cuda()

    img = target['img']
    bg_shape = img.shape[0:3] + (3,)
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(img.shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(bg_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(bg_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = img[..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(bg_shape, dtype=torch.float32, device='cuda')
    else:
        raise ValueError("Unknown background type %s" % bg_type)

    out = dict(target)
    out['background'] = background
    alpha = img[..., 3:4]
    out['img'] = torch.cat((torch.lerp(background, img[..., 0:3], alpha), alpha), dim=-1)
    return out
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

def _cuda_range(values_min, values_max):
    """Return (min, max) float32 CUDA tensors built from two per-channel value lists."""
    return (torch.tensor(values_min, dtype=torch.float32, device='cuda'),
            torch.tensor(values_max, dtype=torch.float32, device='cuda'))


def _xatlas_uv_tensors(eval_mesh):
    """UV-unwrap ``eval_mesh`` with xatlas; return (uvs float32, faces int64) CUDA tensors."""
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    _vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Widen the index array to uint64 and reinterpret the bits as int64 so an
    # int64 torch index tensor can be built from it.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')
    return uvs, faces


@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Extract the current mesh, UV-unwrap it and bake the 'kd_ks_normal' MLP into 2D textures.

    Returns:
        A ``mesh.Mesh`` with xatlas UVs and a material of Texture2D kd/ks/normal
        maps (clamped to the FLAGS.*_min / *_max ranges) and ``mat['bsdf']``.
    """
    eval_mesh = geometry.getMesh(mat)
    uvs, faces = _xatlas_uv_tensors(eval_mesh)
    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    _mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Append a random 4th kd channel (presumably alpha) for layered rendering.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    kd_min, kd_max = _cuda_range(FLAGS.kd_min, FLAGS.kd_max)
    ks_min, ks_max = _cuda_range(FLAGS.ks_min, FLAGS.ks_max)
    nrm_min, nrm_max = _cuda_range(FLAGS.nrm_min, FLAGS.nrm_max)

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max]),
    })
    return new_mesh


@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """Like ``xatlas_uvmap`` but bakes only the 'normal' texture; kd/ks are reused from ``mat``.

    The old code concatenated onto an undefined ``kd`` when ``FLAGS.layers > 1``
    (NameError). kd is never re-baked here, so that branch is dropped.
    """
    eval_mesh = geometry.getMesh(mat)
    uvs, faces = _xatlas_uv_tensors(eval_mesh)
    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    _mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    nrm_min, nrm_max = _cuda_range(FLAGS.nrm_min, FLAGS.nrm_max)

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : mat['kd'],
        'ks'     : mat['ks'],
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max]),
    })
    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the trainable material for optimization.

    Args:
        geometry: provides ``getAABB()`` for the MLP texture domain.
        mlp: if truthy, use one 9-channel MLPTexture3D ('kd_ks_normal');
            otherwise separate trainable 2D kd / ks / normal textures.
        FLAGS: reads kd/ks/nrm limits, texture_res, layers, custom_mip, random_textures.
        init_mat: optional material to initialize from (and to take 'bsdf' from).

    Returns:
        ``material.Material`` with 'bsdf' set to ``init_mat['bsdf']`` or 'pbr'.
    """
    kd_min, kd_max = _cuda_range(FLAGS.kd_min, FLAGS.kd_max)
    ks_min, ks_max = _cuda_range(FLAGS.ks_min, FLAGS.ks_max)
    nrm_min, nrm_max = _cuda_range(FLAGS.nrm_min, FLAGS.nrm_max)

    if mlp:
        # Channel layout: kd RGB (3) + ks (3) + normal (3) = 9.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal': mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt,
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'
    return mat


def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Create a material whose kd/ks are fixed from ``init_mat`` and only the normal is trainable.

    Args:
        geometry: provides ``getAABB()`` for the MLP normal texture.
        mlp: if truthy, the normal is a 3-channel MLPTexture3D, else a trainable 2D texture.
        FLAGS: reads nrm limits, texture_res, custom_mip, random_textures.
        init_mat: required; supplies 'kd', 'ks', 'bsdf' and optionally 'normal'.

    Raises:
        ValueError: if ``init_mat`` is None. kd/ks cannot be built without it;
            previously this surfaced as an opaque TypeError.
    """
    if init_mat is None:
        raise ValueError("initial_guess_material_knownkskd requires init_mat providing 'kd' and 'ks'")

    nrm_min, nrm_max = _cuda_range(FLAGS.nrm_min, FLAGS.nrm_max)

    if mlp:
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[nrm_min, nrm_max])
        normal_map_opt = mlp_map_opt
    elif FLAGS.random_textures or 'normal' not in init_mat:
        normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
    else:
        normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

    mat = material.Material({
        'kd'     : init_mat['kd'],
        'ks'     : init_mat['ks'],
        'normal' : normal_map_opt,
    })
    mat['bsdf'] = init_mat['bsdf']
    return mat

###############################################################################
# Validation & testing
###############################################################################

def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view and assemble a side-by-side visualization.

    Returns:
        (result_image, result_dict). result_image concatenates along width the
        optimized render, reference, any FLAGS.display layers and the mask.
        result_dict maps names ('opt', 'ref', 'geo_normal', ...) to images.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        lgt.xfm(target['envlight_transform'])

        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except Exception:
            # Presumably for geometries whose render() has no EMA support -- fall back to a plain render.
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        if FLAGS.display is not None:
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    if isinstance(lgt, light.EnvironmentLight):
                        result_dict['light_image'] = util.cubemap_to_latlong(lgt.base, FLAGS.display_res)
                        result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'relight' in layer:
                    if not isinstance(layer['relight'], light.EnvironmentLight):
                        layer['relight'] = light.load_env(layer['relight'])
                    # render() returns a buffers dict (see 'shaded' usage above), not a tensor.
                    relit = geometry.render(glctx, target, layer['relight'], opt_material)
                    result_dict['relight'] = util.rgb_to_srgb(relit['shaded'][..., 0:3])[0]
                    result_image = torch.cat([result_image, result_dict['relight']], axis=1)
                elif 'bsdf' in layer:
                    buffers = geometry.render(glctx, target, lgt, opt_material, bsdf=layer['bsdf'])
                    shaded = buffers['shaded'][0, ..., 0:3]
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(shaded)
                    elif layer['bsdf'] == 'normal':
                        # Map [-1, 1] normals to [0, 1] for display.
                        result_dict[layer['bsdf']] = (shaded + 1) * 0.5
                    else:
                        result_dict[layer['bsdf']] = shaded
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif "depth" in layer:
                    depth = buffers['depth'][:, :, :, 0].squeeze().unsqueeze(-1).expand(-1, -1, 1)
                    mask = (depth != 0).float()
                    # Lift zero-depth (background) pixels by 1e3 so they do not determine the minimum.
                    depth_min = ((1 - mask) * 1e3 + depth).min()
                    depth_max = depth.max()
                    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8)
                    depth = depth * mask + (1 - mask) * depth_min
                    depth = depth.expand(-1, -1, 3)
                    # Colour-map through the module-level 256-entry LUT.
                    depth = cv2.LUT(np.array(depth.detach().cpu().numpy() * 255.0, dtype=np.uint8), lut)
                    result_dict['depth'] = depth = (torch.tensor(depth, device=mask.device).float() / 255.0 * mask) + 255. * (1 - mask)
                    result_image = torch.cat([result_image, depth], axis=1)

        # NOTE(review): nesting of this geo-normal/mask section could not be recovered from the
        # collapsed source; it is placed at the top level of the no_grad scope -- confirm.
        buffers = geometry.render(glctx, target, lgt, opt_material)
        camera = target['geo_viewdir'][:, :, :, :3]
        result_dict['geo_normal'] = (util.safe_normalize(buffers['geo_normal'][:, :, :, :3]) * camera).sum(-1, keepdim=False).abs()[0]
        mask = buffers['mask'][0].expand(-1, -1, 3)
        result_image = torch.cat([result_image, mask], axis=1)

    return result_image, result_dict


def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run validation over ``dataset_validate`` and write metrics and images to ``out_dir``.

    Writes ``metrics.txt`` (per-view MSE/PSNR plus averages) and one PNG per
    ``result_dict`` entry per view.

    Returns:
        Mean PSNR over all views.
    """
    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            _result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (it, mse, psnr))

            for name, image in result_dict.items():
                np_img = image.detach().cpu().numpy()
                util.save_image(os.path.join(out_dir, 'val_%06d_%s.png' % (it, name)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE,      PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

class Trainer(torch.nn.Module):
    """Bundle geometry, light and material; ``forward`` runs one ``geometry.tick`` step."""

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A frozen light only needs its mip chain built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            # Geometry without an `sdf` attribute (presumably the fixed-topology variant).
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        """Run one training tick; returns whatever ``geometry.tick`` returns (img_loss, reg_loss)."""
        if self.optimize_light:
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
        # NOTE(review): nesting of this call could not be recovered from the collapsed source;
        # it is applied unconditionally here -- confirm.
        self.light.xfm(target['envlight_transform'])

        # Use the context stored on the trainer, not a module-level `glctx` global.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it,
                                  xfm_lgt=target['envlight_transform'], no_depth_thin=False)


def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
):
    """Fit geometry/material/light to a single fixed validation view.

    Runs a fixed 5000 iterations on one view picked from ``dataset_validate``.
    Each iteration composites the view over a fresh random background.

    Returns:
        (geometry, opt_material) after optimization.
    """
    # Fixed iteration count for this single-view fit (independent of FLAGS.iter).
    num_iters = 5000

    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    if isinstance(FLAGS.learning_rate, (list, tuple)):
        learning_rate = FLAGS.learning_rate[pass_idx]
    else:
        learning_rate = FLAGS.learning_rate
    if isinstance(learning_rate, (list, tuple)):
        learning_rate_pos, learning_rate_mat = learning_rate[0], learning_rate[1]
    else:
        learning_rate_pos = learning_rate_mat = learning_rate

    def lr_schedule(step):
        # Linear warmup, then exponential falloff from 1.0 to 0.1 over 5k steps.
        if step < warmup_iter:
            return step / warmup_iter
        return max(0.0, 10 ** (-(step - warmup_iter) * 0.0002))

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    # Single GPU training mode
    trainer = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    if optimize_geometry:
        optimizer_mesh = torch.optim.Adam([
            {'params': trainer.sdf_params, 'lr': learning_rate_pos},
            {'params': trainer.deform_params, 'lr': learning_rate_pos},
        ])
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lr_schedule)

    optimizer = torch.optim.Adam(trainer.params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    def cycle(iterable):
        # Endless re-iteration; stops (so next() raises) instead of spinning forever if empty.
        while True:
            produced = False
            for item in iterable:
                produced = True
                yield item
            if not produced:
                return

    v_it = cycle(dataloader_validate)
    v_iter_no = 10  # the 10th yielded validation view is used for fitting

    print("Start training loop...")
    sys.stdout.flush()
    for _ in range(v_iter_no):
        v_curr = next(v_it)

    # Move the fixed view to the GPU once. Each iteration then composites a shallow
    # copy, so the raw RGBA in v_curr is never overwritten by an earlier composite.
    for key in ('mv', 'mvp', 'campos', 'img', 'spts', 'vpts'):
        v_curr[key] = v_curr[key].cuda()

    for it in range(num_iters):
        # Mix randomized background into dataset image
        target = prepare_batch(dict(v_curr), 'random')

        ### for robustness, we take the easy way of initializing the tet grid with the gt depth image
        if it < 300 and it % 10 == 0:
            gt_visible_triangles = target['rast_triangle_id'].long()
            gt_verts, gt_faces = target['vpts'], target['faces']
            surface_faces = gt_faces[gt_visible_triangles]
            campos = target['campos'][0]
            init_fn = getattr(geometry, 'init_with_gt_surface', None)
            if init_fn is not None:
                try:
                    init_fn(gt_verts, surface_faces, campos)
                except Exception as err:
                    # Best-effort initialization: report and keep training.
                    print("init_with_gt_surface failed at iter %d: %s" % (it, err))

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except (AttributeError, TypeError):
                # Non-MLP normal texture (no encoder) or no gradient this step.
                pass

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()
            geometry.clamp_deform()

        if FLAGS.use_ema:
            raise NotImplementedError("EMA updates are not supported in this training loop")

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            # Estimate from the actual loop length, not FLAGS.iter.
            remaining_time = (num_iters - it) * iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                  (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg * 1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a CLI boolean. argparse's ``type=bool`` would turn any non-empty string (even 'False') into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results_singleview')
    parser.add_argument('--validate', type=_str2bool, default=True)
    parser.add_argument('-no', '--normal-only', type=_str2bool, default=True)
    parser.add_argument('-ema', '--use-ema', action="store_true")
    parser.add_argument('-rp', '--resume-path', type=str, default=None)
    parser.add_argument('-mp', '--mesh-path', type=str)
    parser.add_argument('-an', '--angle-ind', type=int, help='angle index from 0 to 50')

    FLAGS = parser.parse_args()
    print("parsed arguments")

    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.dmtet_grid          = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale          = 1.0                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.envmap              = None                     # HDR environment probe
    FLAGS.display             = None                     # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light  = False                    # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 10000.0                  # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0] # Limits for kd
    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]
    FLAGS.ks_min              = [ 0.0, 0.08,  0.0]       # Limits for ks
    FLAGS.ks_max              = [ 1.0,  1.0,  1.0]
    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]       # Limits for normal map
    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]
    FLAGS.cam_near_far        = [0.1, 1000.0]
    FLAGS.learn_light         = False
    FLAGS.use_ema             = False                    # NOTE(review): overrides the --use-ema CLI flag; only the config can enable it
    FLAGS.random_lgt          = True
    FLAGS.dataset_flat_shading = False
    FLAGS.local_rank = 0

    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as config_file:
            data = json.load(config_file)
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # The single-view step at the end calls next() angle_ind times. It therefore needs
    # angle_ind >= 1 (it selects the angle_ind-th view). Fail now, not after training.
    if FLAGS.angle_ind is None or FLAGS.angle_ind < 1:
        parser.error("--angle-ind must be >= 1 (the N-th validation view is used)")

    print(f"Out dir: {FLAGS.out_dir}")

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz_pre'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)

    print(f"Using dmtet grid of resolution {FLAGS.dmtet_grid}")

    glctx = dr.RasterizeGLContext()

    ### Default mtl
    mtl_default = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
        'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False),
    }

    print(f"Loading mesh: {FLAGS.mesh_path}")
    sys.stdout.flush()
    ref_mesh = mesh.load_mesh(FLAGS.mesh_path, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
    ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)

    print("Loading dataset")
    sys.stdout.flush()
    dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
    dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
    print("Dataset loaded")
    sys.stdout.flush()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    if FLAGS.learn_light:
        lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
    else:
        lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    geometry = DMTetGeometry(
        FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.first_stage_deform
    )

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

    print("Start optimization")
    sys.stdout.flush()

    if FLAGS.resume_path is None:
        # Run optimization
        geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                                      FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)

        # NOTE(review): base_mesh / vert_mask are not used below (base_mesh is rebuilt after
        # xatlas); kept in case getMesh()/getValidVertsIdx() have side effects.
        base_mesh = geometry.getMesh(mat)
        vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
        vert_mask[geometry.getValidVertsIdx()] = 1

        # Free temporaries / cached memory
        torch.cuda.empty_cache()  ### may slow down training
        old_geometry = geometry

        if FLAGS.local_rank == 0 and FLAGS.validate:
            # The previous path interpolated undefined names (k, FLAGS.index, FLAGS.split_size).
            validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "val_viz_pre/dmtet"), FLAGS)
    else:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        dmt_dict = torch.load(FLAGS.resume_path)
        if FLAGS.use_ema:
            geometry.sdf.data[:] = dmt_dict['sdf_ema']
        else:
            geometry.sdf.data[:] = dmt_dict['sdf']
        geometry.deform.data[:] = dmt_dict['deform']
        old_geometry = geometry

    # Create textured mesh from result
    if FLAGS.normal_only:
        base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
    else:
        base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)

    geometry = DMTetGeometryFixedTopo(
        geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.second_stage_deform
    )
    geometry.sdf_sign.requires_grad = False
    geometry.sdf_abs.requires_grad = False
    geometry.deform.requires_grad = True
    # Rescale deformations so the absolute offsets are unchanged under the new deform_scale.
    geometry.deform.data[:] = geometry.deform * FLAGS.first_stage_deform / FLAGS.second_stage_deform

    if FLAGS.use_ema:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema)  ### use ema
    else:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)
    geometry.set_init_v_pos()

    # ==============================================================================================
    #  Pass 2: Train with fixed topology (mesh)
    # ==============================================================================================
    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                                  pass_idx=1, pass_name="mesh_pass", warmup_iter=100,
                                  optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                                  optimize_geometry=not FLAGS.lock_pos)

    ##### Process single-view tet grid
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)
    v_it = iter(dataloader_validate)
    for _ in range(FLAGS.angle_ind):
        v_curr = next(v_it)
    target = prepare_batch(v_curr, 'random')

    # ==============================================================================================
    #  Infer occluded regions
    # ==============================================================================================
    valid_tet_idx = geometry.getValidTetIdx().long()
    buffers = geometry.render(glctx, target, lgt, mat, get_visible_tets=True)

    ## visible tets (except for rasterized ones)
    visible_tets = torch.zeros(geometry.indices.size(0)).cuda()
    visible_tets[buffers['visible_tet_id'].long()] = 1

    ## to include the rasterized tetrahedra
    visible_and_rast_tets = visible_tets.clone()
    rast_tet_id = valid_tet_idx[buffers['rast_triangle_id'].long()].unique()
    visible_and_rast_tets[rast_tet_id] = 1

    visible_tets = (visible_tets == 1)
    visible_and_rast_tets = (visible_and_rast_tets == 1)

    ## label all tetrahedral vertices associated with any visible tets
    # Allocate on the same device as the tet index tensor used to fill it.
    visible_verts = torch.zeros(geometry.verts.size(0), device=geometry.indices.device)
    vis_vert_inds = geometry.indices[visible_tets].unique()
    visible_verts[vis_vert_inds] = 1

    visible_and_rast_verts = visible_verts.clone()
    vis_and_rast_vert_inds = geometry.indices[visible_and_rast_tets].unique()
    visible_and_rast_verts[vis_and_rast_vert_inds] = 1
    visible_and_rast_verts = visible_and_rast_verts.bool()

    torch.save({
        'sdf': geometry.sdf_sign.cpu().detach(),
        'deform': geometry.deform.cpu().detach(),
        'vis': visible_verts.cpu().detach(),
        'vis_rast': visible_and_rast_verts.cpu().detach(),
    }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'))

    # ==============================================================================================
    if FLAGS.local_rank == 0 and FLAGS.validate:
        validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "val_viz/dmtet"), FLAGS)