# 2023-03-11 10:17:59 +00:00
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import time
import argparse
import json
import sys
import cv2
import numpy as np
import torch
import nvdiffrast . torch as dr
import xatlas
# Import data readers / generators
from lib . dataset . dataset_mesh import DatasetMesh
from lib . dataset . dataset_shapenet import ShapeNetDataset
# Import topology / geometry trainers
from lib . geometry . dmtet_singleview import DMTetGeometry
from lib . geometry . dmtet_fixedtopo import DMTetGeometryFixedTopo
import lib . render . renderutils as ru
from lib . render import obj
from lib . render import material
from lib . render import util
from lib . render import mesh
from lib . render import texture
from lib . render import mlptexture
from lib . render import light
from lib . render import render
from random import randint
from time import sleep
import traceback
# Camera orbit radius used when constructing the mesh datasets below.
RADIUS = 2.0

# define colors (OpenCV channel order; labels per the original comments)
color1 = (0, 0, 255)    # red
color2 = (0, 165, 255)  # orange
color3 = (0, 255, 255)  # yellow
color4 = (255, 255, 0)  # cyan
color5 = (255, 0, 0)    # blue
color6 = (128, 64, 64)  # violet
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)

# resize lut to 256 (or more) values — used as a color lookup table for the
# depth visualization in validate_itr (cv2.LUT expects a 256-entry table).
lut = cv2.resize(colorArr, (256, 1), interpolation=cv2.INTER_LINEAR)
###############################################################################
# Loss setup
###############################################################################
@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-loss closure selected by FLAGS.loss.

    Each closure maps (img, ref) -> loss via ru.image_loss with the
    loss/tonemapper combination associated with the FLAGS.loss name.

    Raises:
        ValueError: if FLAGS.loss is not a recognized loss name.
    """
    # Map loss name -> (loss, tonemapper) arguments for ru.image_loss.
    configs = {
        'smape':  ('smape',  'none'),
        'mse':    ('mse',    'none'),
        'logl1':  ('l1',     'log_srgb'),
        'logl2':  ('mse',    'log_srgb'),
        'relmse': ('relmse', 'none'),
    }
    try:
        loss, tonemapper = configs[FLAGS.loss]
    except KeyError:
        # Was `assert False` — a real exception also survives `python -O`.
        raise ValueError("Unknown loss: %s" % FLAGS.loss) from None
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)
###############################################################################
# Mix background into a dataset image
###############################################################################
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Composite a background into a dataset batch and move it to the GPU.

    Args:
        target: batch dict with 'img' of shape [n, h, w, 4] (RGBA) plus
            camera data ('mv', 'mvp', 'campos') and point sets
            ('spts', 'vpts').
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        The same dict with tensors moved to CUDA, 'background' added, and
        'img' alpha-composited over the chosen background.

    Raises:
        ValueError: for an unrecognized bg_type.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        # Was `assert False` — raise so the check survives `python -O`.
        raise ValueError("Unknown background type %s" % bg_type)

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Lerp RGB toward the background by alpha, keep the alpha channel.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    target['spts'] = target['spts'].cuda()
    target['vpts'] = target['vpts'].cuda()
    return target
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-parameterize the optimized geometry with xatlas and bake the MLP
    material ('kd_ks_normal') into 2D kd/ks/normal textures on that layout."""
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors. xatlas returns uint32 indices; reinterpret the
    # widened uint64 buffer as int64 for torch.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Bake the volumetric MLP material into per-texel kd/ks/normal maps.
    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Multi-layer materials get a random extra channel appended to kd.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf': mat['bsdf'],
        'kd': texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks': texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal': texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """UV-parameterize the optimized geometry with xatlas and bake only the
    normal map; kd/ks are carried over unchanged from `mat`."""
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors. xatlas returns uint32 indices; reinterpret the
    # widened uint64 buffer as int64 for torch.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Bake only the normal channel on the new UV layout.
    mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # BUGFIX: the previous `if FLAGS.layers > 1:` branch referenced an
    # undefined local `kd` (dead code copied from xatlas_uvmap) and would
    # raise NameError; it has been removed.

    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf': mat['bsdf'],
        'kd': mat['kd'],
        'ks': mat['ks'],
        'normal': texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
###############################################################################
# Utility functions for material
###############################################################################
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build the initial trainable material.

    With `mlp` True, a single MLPTexture3D predicts all 9 channels
    (kd rgb, ks rgb, normal xyz) clamped to the per-channel FLAGS limits.
    Otherwise individual 2D textures are created, either randomly
    initialized or copied from `init_mat`.
    """
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
    if mlp:
        # One MLP producing kd[0:3] + ks + normal, clamped per channel.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal': mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # First ks channel kept near zero; the other two are sampled
            # uniformly within their configured limits.
            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())

            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map: default to the straight-up normal (0, 0, 1).
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd': kd_map_opt,
            'ks': ks_map_opt,
            'normal': normal_map_opt
        })

    # Inherit the BSDF from the initial material when given, else default to 'pbr'.
    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'

    return mat
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build an initial material when kd/ks are already known (taken from
    `init_mat` unchanged); only the normal map is made trainable.

    NOTE(review): unlike initial_guess_material, the mlp branch dereferences
    `init_mat` unconditionally, so `init_mat` must not be None there.
    """
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
    if mlp:
        # MLP predicts only the 3 normal channels, clamped to the nrm limits.
        mlp_min = nrm_min
        mlp_max = nrm_max
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[mlp_min, mlp_max])
        # mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=None)
        mat = material.Material({
            'kd': init_mat['kd'],
            'ks': init_mat['ks'],
            'normal': mlp_map_opt,
        })
    else:
        # Setup normal map: default to the straight-up normal (0, 0, 1).
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd': init_mat['kd'],
            'ks': init_mat['ks'],
            'normal': normal_map_opt
        })

    # Inherit the BSDF from the initial material when given, else default to 'pbr'.
    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'

    return mat
###############################################################################
# Validation & testing
###############################################################################
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation batch and assemble a side-by-side result image.

    Returns:
        (result_image, result_dict): `result_image` concatenates the
        optimized render, the reference, any configured FLAGS.display layers,
        and the coverage mask along the width axis; `result_dict` holds the
        individual images keyed by name.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        lgt.xfm(target['envlight_transform'])

        # Prefer EMA weights when the geometry's render() supports the `ema`
        # keyword; fall back to the plain signature otherwise.
        # (Was a bare `except:` which also swallowed KeyboardInterrupt.)
        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except TypeError:
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        if FLAGS.display is not None:
            white_bg = torch.ones_like(target['background'])
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    if isinstance(lgt, light.EnvironmentLight):
                        result_dict['light_image'] = util.cubemap_to_latlong(lgt.base, FLAGS.display_res)
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'relight' in layer:
                    # Lazily load the relighting probe on first use.
                    if not isinstance(layer['relight'], light.EnvironmentLight):
                        layer['relight'] = light.load_env(layer['relight'])
                    img = geometry.render(glctx, target, layer['relight'], opt_material)
                    result_dict['relight'] = util.rgb_to_srgb(img[..., 0:3])[0]
                    result_image = torch.cat([result_image, result_dict['relight']], axis=1)
                elif 'bsdf' in layer:
                    buffers = geometry.render(glctx, target, lgt, opt_material, bsdf=layer['bsdf'])
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(buffers['shaded'][0, ..., 0:3])
                    elif layer['bsdf'] == 'normal':
                        # Map normals from [-1, 1] to [0, 1] for display.
                        result_dict[layer['bsdf']] = (buffers['shaded'][0, ..., 0:3] + 1) * 0.5
                    else:
                        result_dict[layer['bsdf']] = buffers['shaded'][0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif "depth" in layer:
                    # Normalize foreground depth to [0, 1], then colorize it
                    # through the module-level LUT; background stays white.
                    depth = buffers['depth'][:, :, :, 0].squeeze().unsqueeze(-1).expand(-1, -1, 1)
                    mask = (depth != 0).float()
                    depth_min = ((1 - mask) * 1e3 + depth).min()
                    depth_max = depth.max()
                    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8)
                    depth = depth * mask + (1 - mask) * depth_min
                    depth = depth.expand(-1, -1, 3)
                    depth = cv2.LUT(np.array(depth.detach().cpu().numpy() * 255.0, dtype=np.uint8), lut)
                    result_dict['depth'] = depth = (torch.tensor(depth, device=mask.device).float() / 255.0 * mask) + 255. * (1 - mask)
                    result_image = torch.cat([result_image, depth], axis=1)

        # Geometry-space normal visibility (|n . v|) and coverage mask.
        buffers = geometry.render(glctx, target, lgt, opt_material)
        camera = target['geo_viewdir'][:, :, :, :3]
        result_dict['geo_normal'] = (util.safe_normalize(buffers['geo_normal'][:, :, :, :3]) * camera).sum(-1, keepdim=False).abs()[0]
        mask = buffers['mask'][0].expand(-1, -1, 3)
        result_image = torch.cat([result_image, mask], axis=1)

    return result_image, result_dict
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation loop, saving per-view images and MSE/PSNR metrics.

    Writes `metrics.txt` plus `val_<it>_<key>.png` images into `out_dir`
    and returns the average PSNR over the validation set.
    """
    # ==============================================================================================
    # Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics. The deprecated size_average/reduce kwargs were
            # dropped; reduction='mean' alone is equivalent.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            # Dump every image produced by validate_itr for this view.
            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))

        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))

    return avg_psnr
###############################################################################
# Main shape fitter function / optimization loop
###############################################################################
class Trainer(torch.nn.Module):
    """Bundles geometry, light and material so a single forward() call
    produces the (image_loss, reg_loss) pair for one training batch."""

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A fixed light only needs its mip chain built once, up front.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        # Parameter groups handed to the optimizers in optimize_mesh().
        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            # Geometry without an SDF field (e.g. fixed topology).
            # (Was a bare `except:`.)
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        if self.optimize_light:
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
        self.light.xfm(target['envlight_transform'])
        # BUGFIX: previously referenced the module-global `glctx` (only
        # defined when run as a script); use the stored context instead.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it, xfm_lgt=target['envlight_transform'], no_depth_thin=False)
def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
):
    """Run the single-view DMTet optimization loop.

    Fits `geometry` (and optionally `opt_material`/`lgt`) to a single fixed
    validation batch for 5000 iterations, and returns the updated
    (geometry, opt_material) pair.

    NOTE(review): despite its name, `dataset_train` is only used for its
    `collate` function; the loop trains on one batch drawn from
    `dataset_validate` — presumably intentional for single-view fitting.
    """
    # ==============================================================================================
    # Setup torch optimizer
    # ==============================================================================================
    # FLAGS.learning_rate may be a scalar, a per-pass list/tuple, or a
    # (position_lr, material_lr) pair.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate

    def lr_schedule(iter, fraction):
        # Linear warmup, then exponential decay. `fraction` is unused.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter) * 0.0002))  # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    # Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    # Single GPU training mode
    trainer = trainer_noddp
    if optimize_geometry:
        # Separate optimizer for the geometry (SDF + deformation) parameters.
        optimizer_mesh = torch.optim.Adam([
            {'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
            {'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
        ]
        )
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # Optimizer for material (and optionally light) parameters.
    optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    # Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator: restart from the beginning once exhausted.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)

    # v_iter_no = 25
    v_iter_no = 10
    print("Start training loop...")
    sys.stdout.flush()
    # Draw v_iter_no batches and keep only the LAST one; the whole loop
    # below fits this single fixed view.
    for _ in range(v_iter_no):
        v_curr = next(v_it)

    for it in range(5000):
        # Mix randomized background into dataset image
        target = prepare_batch(v_curr, 'random')

        ### for robustness, we take the easy way of initializing the tet grid with the gt depth image
        if it < 300 and it % 10 == 0:
            gt_visible_triangles = target['rast_triangle_id'].long()
            gt_verts, gt_faces = target['vpts'], target['faces']
            surface_faces = gt_faces[gt_visible_triangles]
            campos = target['campos'][0]
            try:
                geometry.init_with_gt_surface(gt_verts, surface_faces, campos)
            except:
                # NOTE(review): bare except silently skips a failed
                # initialization — best-effort by design, but it hides errors.
                pass

        iter_start_time = time.time()

        # ==============================================================================================
        # Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        # Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        # Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        # Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Manual gradient scaling: boost the environment light, damp the
        # texture MLP encoders.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except:
                # Normal map may not be MLP-backed (no .encoder); skip then.
                pass

        optimizer.step()
        scheduler.step()
        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()
        geometry.clamp_deform()

        if FLAGS.use_ema:
            # EMA updates are not supported in this script.
            raise NotImplementedError
            geometry.update_ema()

        # ==============================================================================================
        # Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)  # Light must be positive

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        # Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter - it) * iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                  (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg * 1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------
if __name__ == " __main__ " :
parser = argparse . ArgumentParser ( description = ' nvdiffrec ' )
parser . add_argument ( ' --config ' , type = str , default = ' ./configs/res64.json ' , help = ' Config file ' )
parser . add_argument ( ' -i ' , ' --iter ' , type = int , default = 5000 )
parser . add_argument ( ' -s ' , ' --spp ' , type = int , default = 1 )
parser . add_argument ( ' -l ' , ' --layers ' , type = int , default = 1 )
parser . add_argument ( ' -r ' , ' --train-res ' , nargs = 2 , type = int , default = [ 512 , 512 ] )
parser . add_argument ( ' -dr ' , ' --display-res ' , type = int , default = None )
parser . add_argument ( ' -tr ' , ' --texture-res ' , nargs = 2 , type = int , default = [ 1024 , 1024 ] )
parser . add_argument ( ' -di ' , ' --display-interval ' , type = int , default = 0 )
parser . add_argument ( ' -si ' , ' --save-interval ' , type = int , default = 1000 )
parser . add_argument ( ' -lr ' , ' --learning-rate ' , type = float , default = 0.01 )
parser . add_argument ( ' -mr ' , ' --min-roughness ' , type = float , default = 0.08 )
parser . add_argument ( ' -mip ' , ' --custom-mip ' , action = ' store_true ' , default = False )
parser . add_argument ( ' -rt ' , ' --random-textures ' , action = ' store_true ' , default = False )
parser . add_argument ( ' -bg ' , ' --background ' , default = ' checker ' , choices = [ ' black ' , ' white ' , ' checker ' , ' reference ' ] )
parser . add_argument ( ' --loss ' , default = ' logl1 ' , choices = [ ' logl1 ' , ' logl2 ' , ' mse ' , ' smape ' , ' relmse ' ] )
parser . add_argument ( ' -o ' , ' --out-dir ' , type = str , default = ' ./dmtet_results_singleview ' )
parser . add_argument ( ' --validate ' , type = bool , default = True )
parser . add_argument ( ' -no ' , ' --normal-only ' , type = bool , default = True )
parser . add_argument ( ' -ema ' , ' --use-ema ' , action = " store_true " )
parser . add_argument ( ' -rp ' , ' --resume-path ' , type = str , default = None )
parser . add_argument ( ' -mp ' , ' --mesh-path ' , type = str )
parser . add_argument ( ' -an ' , ' --angle-ind ' , type = int , help = ' angle index from 0 to 50 ' )
FLAGS = parser . parse_args ( )
print ( f " parsed arguments " )
FLAGS . mtl_override = None # Override material of model
FLAGS . dmtet_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
FLAGS . mesh_scale = 1.0 # Scale of tet grid box. Adjust to cover the model
FLAGS . env_scale = 1.0 # Env map intensity multiplier
FLAGS . envmap = None # HDR environment probe
FLAGS . display = None # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
FLAGS . camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
FLAGS . lock_light = False # Disable light optimization in the second pass
FLAGS . lock_pos = False # Disable vertex position optimization in the second pass
FLAGS . sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
FLAGS . laplace = " relative " # Mesh Laplacian ["absolute", "relative"]
FLAGS . laplace_scale = 10000.0 # Weight for sdf regularizer. Default is relative with large weight
FLAGS . pre_load = True # Pre-load entire dataset into memory for faster training
FLAGS . kd_min = [ 0.0 , 0.0 , 0.0 , 0.0 ] # Limits for kd
FLAGS . kd_max = [ 1.0 , 1.0 , 1.0 , 1.0 ]
FLAGS . ks_min = [ 0.0 , 0.08 , 0.0 ] # Limits for ks
FLAGS . ks_max = [ 1.0 , 1.0 , 1.0 ]
FLAGS . nrm_min = [ - 1.0 , - 1.0 , 0.0 ] # Limits for normal map
FLAGS . nrm_max = [ 1.0 , 1.0 , 1.0 ]
FLAGS . cam_near_far = [ 0.1 , 1000.0 ]
FLAGS . learn_light = False
FLAGS . use_ema = False
FLAGS . random_lgt = True
FLAGS . dataset_flat_shading = False
FLAGS . local_rank = 0
if FLAGS . config is not None :
data = json . load ( open ( FLAGS . config , ' r ' ) )
for key in data :
FLAGS . __dict__ [ key ] = data [ key ]
if FLAGS . display_res is None :
FLAGS . display_res = FLAGS . train_res
print ( f " Out dir: { FLAGS . out_dir } " )
if FLAGS . local_rank == 0 :
print ( " Config / Flags: " )
print ( " --------- " )
for key in FLAGS . __dict__ . keys ( ) :
print ( key , FLAGS . __dict__ [ key ] )
print ( " --------- " )
os . makedirs ( FLAGS . out_dir , exist_ok = True )
os . makedirs ( os . path . join ( FLAGS . out_dir , ' val_viz ' ) , exist_ok = True )
os . makedirs ( os . path . join ( FLAGS . out_dir , ' val_viz_pre ' ) , exist_ok = True )
os . makedirs ( os . path . join ( FLAGS . out_dir , ' tets ' ) , exist_ok = True )
os . makedirs ( os . path . join ( FLAGS . out_dir , ' tets_pre ' ) , exist_ok = True )
print ( f " Using dmtet grid of resolution { FLAGS . dmtet_grid } " )
glctx = dr . RasterizeGLContext ( )
### Default mtl
mtl_default = {
' name ' : ' _default_mat ' ,
' bsdf ' : ' diffuse ' ,
' uniform ' : True ,
' kd ' : texture . Texture2D ( torch . tensor ( [ 0.75 , 0.3 , 0.6 ] , dtype = torch . float32 , device = ' cuda ' ) , trainable = False ) ,
' ks ' : texture . Texture2D ( torch . tensor ( [ 0.0 , 0.0 , 0.0 ] , dtype = torch . float32 , device = ' cuda ' ) , trainable = False )
}
print ( f " Loading mesh: { FLAGS . mesh_path } " )
sys . stdout . flush ( )
ref_mesh = mesh . load_mesh ( FLAGS . mesh_path , FLAGS . mtl_override , mtl_default , use_default = FLAGS . normal_only , no_additional = True )
ref_mesh = mesh . center_by_reference ( ref_mesh , mesh . aabb_clean ( ref_mesh ) , 1.0 )
print ( " Loading dataset " )
sys . stdout . flush ( )
dataset_train = DatasetMesh ( ref_mesh , glctx , RADIUS , FLAGS , validate = False )
dataset_validate = DatasetMesh ( ref_mesh , glctx , RADIUS , FLAGS , validate = True )
print ( " Dataset loaded " )
sys . stdout . flush ( )
# ==============================================================================================
# Create env light with trainable parameters
# ==============================================================================================
if FLAGS.learn_light:
    lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
else:
    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)

# ==============================================================================================
# If no initial guess, use DMtets to create geometry
# ==============================================================================================

# Setup geometry for optimization (stage 1: trainable topology via SDF signs).
geometry = DMTetGeometry(
    FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
    deform_scale=FLAGS.first_stage_deform,
)

# Setup textures, make initial guess from reference if possible.
# normal_only mode uses known kd/ks (geometry-only optimization).
if not FLAGS.normal_only:
    mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
else:
    mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

print("Start optimization")
sys.stdout.flush()
if FLAGS.resume_path is None:
    # Pass 1: jointly optimize topology (SDF) and geometry from scratch.
    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                                  FLAGS, pass_idx=0, pass_name="dmtet_pass1",
                                  optimize_light=FLAGS.learn_light)

    base_mesh = geometry.getMesh(mat)
    # Mark tet-grid vertices that belong to at least one valid (surface) tet.
    vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
    vert_mask[geometry.getValidVertsIdx()] = 1

    # Free temporaries / cached memory
    torch.cuda.empty_cache()  ### may slow down training
    old_geometry = geometry

    if FLAGS.local_rank == 0 and FLAGS.validate:
        # NOTE(review): `k` is not defined in this scope — this line raises
        # NameError if reached; it looks like a leftover loop variable from an
        # earlier split-based version. TODO: confirm intended index.
        validate(glctx, geometry, mat, lgt, dataset_validate,
                 os.path.join(FLAGS.out_dir,
                              f"val_viz_pre/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"),
                 FLAGS)
else:
    # Resume: restore first-stage SDF / deformation from a checkpoint.
    # Expected keys: 'sdf', 'sdf_ema', 'deform'.
    dmt_dict = torch.load(FLAGS.resume_path)
    if FLAGS.use_ema:
        geometry.sdf.data[:] = dmt_dict['sdf_ema']
    else:
        geometry.sdf.data[:] = dmt_dict['sdf']
    geometry.deform.data[:] = dmt_dict['deform']
    old_geometry = geometry
# Create textured mesh from result (UV-unwrapped via xatlas).
if FLAGS.normal_only:
    base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
else:
    base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)

# Stage 2 geometry: topology (SDF signs) is frozen; only vertex deformation
# remains trainable.
geometry = DMTetGeometryFixedTopo(
    geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
    deform_scale=FLAGS.second_stage_deform,
)

geometry.sdf_sign.requires_grad = False
geometry.sdf_abs.requires_grad = False
geometry.deform.requires_grad = True

# Rescale the stage-1 deformation into stage-2 deform_scale units so the
# initial vertex positions are unchanged.
geometry.deform.data[:] = geometry.deform * FLAGS.first_stage_deform / FLAGS.second_stage_deform

if FLAGS.use_ema:
    geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema)  ### use ema
else:
    geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)  ### use raw sdf

geometry.set_init_v_pos()
# ==============================================================================================
# Pass 2: Train with fixed topology (mesh)
# ==============================================================================================
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                              pass_idx=1, pass_name="mesh_pass", warmup_iter=100,
                              optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                              optimize_geometry=not FLAGS.lock_pos)

##### Process single-view tet grid
dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1,
                                                  collate_fn=dataset_train.collate)

# Advance the validation iterator to the requested viewing angle, then build
# a render batch for that view.
v_it = iter(dataloader_validate)
for _ in range(FLAGS.angle_ind):
    v_curr = next(v_it)
target = prepare_batch(v_curr, 'random')

# ==============================================================================================
# Infer occluded regions
# ==============================================================================================
valid_tet_idx = geometry.getValidTetIdx().long()
buffers = geometry.render(glctx, target, lgt, mat, get_visible_tets=True)

## visible tets (except for rasterized ones)
visible_tets = torch.zeros(geometry.indices.size(0)).cuda()
visible_tets[buffers['visible_tet_id'].long()] = 1

## to include the rasterized tetrahedra
visible_and_rast_tets = visible_tets.clone()
rast_tet_id = valid_tet_idx[buffers['rast_triangle_id'].long()].unique()
visible_and_rast_tets[rast_tet_id] = 1

visible_tets = (visible_tets == 1)
visible_and_rast_tets = (visible_and_rast_tets == 1)

## label all tetrahedral vertices associated with any visible tets
visible_verts = torch.zeros(geometry.verts.size(0))
vis_vert_inds = geometry.indices[visible_tets].unique()
visible_verts[vis_vert_inds] = 1

visible_and_rast_verts = visible_verts.clone()
vis_and_rast_vert_inds = geometry.indices[visible_and_rast_tets].unique()
visible_and_rast_verts[vis_and_rast_vert_inds] = 1
visible_and_rast_verts = visible_and_rast_verts.bool()

# Export the optimized tet grid plus per-vertex visibility labels.
# NOTE(review): 'vis' is saved as 0/1 floats while 'vis_rast' is bool —
# preserved as-is; consumers may rely on the dtypes. TODO confirm.
torch.save({
    'sdf': geometry.sdf_sign.cpu().detach(),
    'deform': geometry.deform.cpu().detach(),
    'vis': visible_verts.cpu().detach(),
    'vis_rast': visible_and_rast_verts.cpu().detach(),
}, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'))

# ==============================================================================================
if FLAGS.local_rank == 0 and FLAGS.validate:
    validate(glctx, geometry, mat, lgt, dataset_validate,
             os.path.join(FLAGS.out_dir, f"val_viz/dmtet"), FLAGS)