kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
814 wiersze
37 KiB
Python
814 wiersze
37 KiB
Python
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
|
|
import os
|
|
import time
|
|
import argparse
|
|
import json
|
|
import sys
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import nvdiffrast.torch as dr
|
|
import xatlas
|
|
|
|
# Import data readers / generators
|
|
from lib.dataset.dataset_mesh import DatasetMesh
|
|
from lib.dataset.dataset_shapenet import ShapeNetDataset
|
|
|
|
# Import topology / geometry trainers
|
|
from lib.geometry.dmtet import DMTetGeometry
|
|
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
|
|
|
|
import lib.render.renderutils as ru
|
|
from lib.render import obj
|
|
from lib.render import material
|
|
from lib.render import util
|
|
from lib.render import mesh
|
|
from lib.render import texture
|
|
from lib.render import mlptexture
|
|
from lib.render import light
|
|
from lib.render import render
|
|
|
|
import traceback
|
|
|
|
|
|
RADIUS = 2.0
|
|
|
|
# # Enable to debug back-prop anomalies
|
|
# torch.autograd.set_detect_anomaly(True)
|
|
|
|
# define colors
|
|
color1 = (0, 0, 255) #red
|
|
color2 = (0, 165, 255) #orange
|
|
color3 = (0, 255, 255) #yellow
|
|
color4 = (255, 255, 0) #cyan
|
|
color5 = (255, 0, 0) #blue
|
|
color6 = (128, 64, 64) #violet
|
|
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)
|
|
|
|
# resize lut to 256 (or more) values
|
|
lut = cv2.resize(colorArr, (256,1), interpolation = cv2.INTER_LINEAR)
|
|
|
|
###############################################################################
|
|
# Loss setup
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
|
|
def createLoss(FLAGS):
|
|
if FLAGS.loss == "smape":
|
|
return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
|
|
elif FLAGS.loss == "mse":
|
|
return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
|
|
elif FLAGS.loss == "logl1":
|
|
return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
|
|
elif FLAGS.loss == "logl2":
|
|
return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
|
|
elif FLAGS.loss == "relmse":
|
|
return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
|
|
else:
|
|
assert False
|
|
|
|
###############################################################################
|
|
# Mix background into a dataset image
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
|
|
def prepare_batch(target, bg_type='black'):
|
|
assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
|
|
if bg_type == 'checker':
|
|
background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
|
|
elif bg_type == 'black':
|
|
background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
|
|
elif bg_type == 'white':
|
|
background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
|
|
elif bg_type == 'reference':
|
|
background = target['img'][..., 0:3]
|
|
elif bg_type == 'random':
|
|
background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
|
|
else:
|
|
assert False, "Unknown background type %s" % bg_type
|
|
|
|
target['mv'] = target['mv'].cuda()
|
|
target['mvp'] = target['mvp'].cuda()
|
|
target['campos'] = target['campos'].cuda()
|
|
target['img'] = target['img'].cuda()
|
|
target['background'] = background
|
|
|
|
target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)
|
|
|
|
target['spts'] = target['spts'].cuda()
|
|
target['vpts'] = target['vpts'].cuda()
|
|
return target
|
|
|
|
###############################################################################
|
|
# UV - map geometry & convert to a mesh
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
|
|
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
|
|
eval_mesh = geometry.getMesh(mat)
|
|
|
|
# Create uvs with xatlas
|
|
v_pos = eval_mesh.v_pos.detach().cpu().numpy()
|
|
t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
|
|
vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)
|
|
|
|
# Convert to tensors
|
|
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
|
|
|
|
uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
|
|
faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')
|
|
|
|
new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)
|
|
|
|
mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])
|
|
|
|
if FLAGS.layers > 1:
|
|
kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)
|
|
|
|
kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
|
|
ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
|
|
nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
|
|
|
|
new_mesh.material = material.Material({
|
|
'bsdf' : mat['bsdf'],
|
|
'kd' : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
|
|
'ks' : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
|
|
'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
|
|
})
|
|
|
|
return new_mesh
|
|
|
|
@torch.no_grad()
|
|
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
|
|
eval_mesh = geometry.getMesh(mat)
|
|
|
|
# Create uvs with xatlas
|
|
v_pos = eval_mesh.v_pos.detach().cpu().numpy()
|
|
t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
|
|
vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)
|
|
|
|
# Convert to tensors
|
|
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
|
|
|
|
uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
|
|
faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')
|
|
|
|
new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)
|
|
|
|
mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])
|
|
|
|
if FLAGS.layers > 1:
|
|
kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)
|
|
|
|
nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
|
|
|
|
new_mesh.material = material.Material({
|
|
'bsdf' : mat['bsdf'],
|
|
'kd' : mat['kd'],
|
|
'ks' : mat['ks'],
|
|
'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
|
|
})
|
|
|
|
return new_mesh
|
|
|
|
###############################################################################
|
|
# Utility functions for material
|
|
###############################################################################
|
|
|
|
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
|
|
kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
|
|
ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
|
|
nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
|
|
if mlp:
|
|
mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
|
|
mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
|
|
mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
|
|
mat = material.Material({'kd_ks_normal' : mlp_map_opt})
|
|
else:
|
|
# Setup Kd (albedo) and Ks (x, roughness, metalness) textures
|
|
if FLAGS.random_textures or init_mat is None:
|
|
num_channels = 4 if FLAGS.layers > 1 else 3
|
|
kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
|
|
kd_map_opt = texture.create_trainable(kd_init , FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
|
|
|
|
ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
|
|
ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
|
|
ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
|
|
|
|
ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
|
|
else:
|
|
kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
|
|
ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
|
|
|
|
# Setup normal map
|
|
if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
|
|
normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
|
|
else:
|
|
normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
|
|
|
|
mat = material.Material({
|
|
'kd' : kd_map_opt,
|
|
'ks' : ks_map_opt,
|
|
'normal' : normal_map_opt
|
|
})
|
|
|
|
if init_mat is not None:
|
|
mat['bsdf'] = init_mat['bsdf']
|
|
else:
|
|
mat['bsdf'] = 'pbr'
|
|
|
|
return mat
|
|
|
|
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
|
|
nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
|
|
|
|
if mlp:
|
|
mlp_min = nrm_min
|
|
mlp_max = nrm_max
|
|
mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[mlp_min, mlp_max])
|
|
# mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=None)
|
|
mat = material.Material({
|
|
'kd' : init_mat['kd'],
|
|
'ks' : init_mat['ks'],
|
|
'normal' : mlp_map_opt,
|
|
})
|
|
else:
|
|
# Setup normal map
|
|
if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
|
|
normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
|
|
else:
|
|
normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
|
|
|
|
mat = material.Material({
|
|
'kd' : init_mat['kd'],
|
|
'ks' : init_mat['ks'],
|
|
'normal' : normal_map_opt
|
|
})
|
|
|
|
if init_mat is not None:
|
|
mat['bsdf'] = init_mat['bsdf']
|
|
else:
|
|
mat['bsdf'] = 'pbr'
|
|
|
|
return mat
|
|
|
|
###############################################################################
|
|
# Validation & testing
|
|
###############################################################################
|
|
|
|
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
|
|
result_dict = {}
|
|
with torch.no_grad():
|
|
lgt.build_mips()
|
|
if FLAGS.camera_space_light:
|
|
lgt.xfm(target['mv'])
|
|
lgt.xfm(target['envlight_transform'])
|
|
|
|
try:
|
|
buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
|
|
except:
|
|
buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])
|
|
|
|
result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
|
|
result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
|
|
result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)
|
|
|
|
return result_image, result_dict
|
|
|
|
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
|
|
|
|
# ==============================================================================================
|
|
# Validation loop
|
|
# ==============================================================================================
|
|
mse_values = []
|
|
psnr_values = []
|
|
|
|
dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)
|
|
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
|
|
fout.write('ID, MSE, PSNR\n')
|
|
|
|
print("Running validation")
|
|
for it, target in enumerate(dataloader_validate):
|
|
|
|
# Mix validation background
|
|
target = prepare_batch(target, FLAGS.background)
|
|
|
|
result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)
|
|
|
|
# Compute metrics
|
|
opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
|
|
ref = torch.clamp(result_dict['ref'], 0.0, 1.0)
|
|
|
|
mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()
|
|
mse_values.append(float(mse))
|
|
psnr = util.mse_to_psnr(mse)
|
|
psnr_values.append(float(psnr))
|
|
|
|
line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
|
|
fout.write(str(line))
|
|
|
|
for k in result_dict.keys():
|
|
np_img = result_dict[k].detach().cpu().numpy()
|
|
util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)
|
|
|
|
avg_mse = np.mean(np.array(mse_values))
|
|
avg_psnr = np.mean(np.array(psnr_values))
|
|
line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
|
|
fout.write(str(line))
|
|
print("MSE, PSNR")
|
|
print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
|
|
return avg_psnr
|
|
|
|
###############################################################################
|
|
# Main shape fitter function / optimization loop
|
|
###############################################################################
|
|
|
|
class Trainer(torch.nn.Module):
|
|
def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
|
|
super(Trainer, self).__init__()
|
|
|
|
self.glctx = glctx
|
|
self.geometry = geometry
|
|
self.light = lgt
|
|
self.material = mat
|
|
self.optimize_geometry = optimize_geometry
|
|
self.optimize_light = optimize_light
|
|
self.image_loss_fn = image_loss_fn
|
|
self.FLAGS = FLAGS
|
|
|
|
if not self.optimize_light:
|
|
with torch.no_grad():
|
|
self.light.build_mips()
|
|
|
|
self.params = list(self.material.parameters())
|
|
self.params += list(self.light.parameters()) if optimize_light else []
|
|
self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []
|
|
try:
|
|
self.sdf_params = [self.geometry.sdf]
|
|
except:
|
|
self.sdf_params = []
|
|
self.deform_params = [self.geometry.deform]
|
|
|
|
def forward(self, target, it):
|
|
if self.optimize_light:
|
|
self.light.build_mips()
|
|
if self.FLAGS.camera_space_light:
|
|
self.light.xfm(target['mv'])
|
|
self.light.xfm(target['envlight_transform'])
|
|
|
|
return self.geometry.tick(glctx, target, self.light, self.material, self.image_loss_fn, it, xfm_lgt=target['envlight_transform'])
|
|
|
|
def optimize_mesh(
|
|
glctx,
|
|
geometry,
|
|
opt_material,
|
|
lgt,
|
|
dataset_train,
|
|
dataset_validate,
|
|
FLAGS,
|
|
warmup_iter=0,
|
|
log_interval=10,
|
|
pass_idx=0,
|
|
pass_name="",
|
|
optimize_light=True,
|
|
optimize_geometry=True,
|
|
):
|
|
|
|
# ==============================================================================================
|
|
# Setup torch optimizer
|
|
# ==============================================================================================
|
|
|
|
learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
|
|
learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
|
|
learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
|
|
|
|
def lr_schedule(iter, fraction):
|
|
if iter < warmup_iter:
|
|
return iter / warmup_iter
|
|
return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.
|
|
|
|
# ==============================================================================================
|
|
# Image loss
|
|
# ==============================================================================================
|
|
image_loss_fn = createLoss(FLAGS)
|
|
|
|
trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)
|
|
|
|
if FLAGS.multi_gpu:
|
|
raise NotImplementedError
|
|
# Multi GPU training mode
|
|
import apex
|
|
from apex.parallel import DistributedDataParallel as DDP
|
|
|
|
trainer = DDP(trainer_noddp)
|
|
trainer.train()
|
|
if optimize_geometry:
|
|
optimizer_mesh = apex.optimizers.FusedAdam(trainer_noddp.geo_params, lr=learning_rate_pos)
|
|
scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))
|
|
|
|
optimizer = apex.optimizers.FusedAdam(trainer_noddp.params, lr=learning_rate_mat)
|
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))
|
|
else:
|
|
# Single GPU training mode
|
|
trainer = trainer_noddp
|
|
if optimize_geometry:
|
|
# optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos)
|
|
optimizer_mesh = torch.optim.Adam([
|
|
{'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
|
|
{'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
|
|
])
|
|
# optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos, betas=(0.2, 0.999), eps=1e-5)
|
|
scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))
|
|
|
|
optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
|
|
# optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat, betas=(0.2, 0.999), eps=1e-5)
|
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))
|
|
|
|
# ==============================================================================================
|
|
# Training loop
|
|
# ==============================================================================================
|
|
img_cnt = 0
|
|
img_loss_vec = []
|
|
reg_loss_vec = []
|
|
iter_dur_vec = []
|
|
|
|
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)
|
|
|
|
print("Start training loop...")
|
|
sys.stdout.flush()
|
|
|
|
for it, target in enumerate(dataloader_train):
|
|
|
|
# Mix randomized background into dataset image
|
|
target = prepare_batch(target, 'random')
|
|
|
|
iter_start_time = time.time()
|
|
|
|
|
|
# ==============================================================================================
|
|
# Zero gradients
|
|
# ==============================================================================================
|
|
optimizer.zero_grad()
|
|
if optimize_geometry:
|
|
optimizer_mesh.zero_grad()
|
|
|
|
# ==============================================================================================
|
|
# Training
|
|
# ==============================================================================================
|
|
img_loss, reg_loss = trainer(target, it)
|
|
|
|
# ==============================================================================================
|
|
# Final loss
|
|
# ==============================================================================================
|
|
total_loss = img_loss + reg_loss
|
|
|
|
img_loss_vec.append(img_loss.item())
|
|
reg_loss_vec.append(reg_loss.item())
|
|
|
|
# ==============================================================================================
|
|
# Backpropagate
|
|
# ==============================================================================================
|
|
total_loss.backward()
|
|
|
|
if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
|
|
lgt.base.grad *= 64
|
|
if 'kd_ks_normal' in opt_material:
|
|
opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
|
|
if 'normal' in opt_material and FLAGS.normal_only:
|
|
try:
|
|
opt_material['normal'].encoder.params.grad /= 8.0
|
|
except:
|
|
pass
|
|
|
|
optimizer.step()
|
|
scheduler.step()
|
|
|
|
if optimize_geometry:
|
|
optimizer_mesh.step()
|
|
scheduler_mesh.step()
|
|
|
|
geometry.clamp_deform()
|
|
geometry.update_ema()
|
|
|
|
# ==============================================================================================
|
|
# Clamp trainables to reasonable range
|
|
# ==============================================================================================
|
|
with torch.no_grad():
|
|
if 'kd' in opt_material:
|
|
opt_material['kd'].clamp_()
|
|
if 'ks' in opt_material:
|
|
opt_material['ks'].clamp_()
|
|
if 'normal' in opt_material and not FLAGS.normal_only:
|
|
opt_material['normal'].clamp_()
|
|
opt_material['normal'].normalize_()
|
|
if lgt is not None:
|
|
lgt.clamp_(min=0.0)
|
|
|
|
torch.cuda.current_stream().synchronize()
|
|
iter_dur_vec.append(time.time() - iter_start_time)
|
|
|
|
# ==============================================================================================
|
|
# Logging
|
|
# ==============================================================================================
|
|
if it % log_interval == 0 and FLAGS.local_rank == 0:
|
|
img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
|
|
reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
|
|
iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))
|
|
|
|
remaining_time = (FLAGS.iter-it)*iter_dur_avg
|
|
print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
|
|
(it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
|
|
sys.stdout.flush()
|
|
|
|
return geometry, opt_material
|
|
|
|
#----------------------------------------------------------------------------
|
|
# Main function.
|
|
#----------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
|
|
# sleep(randint(0,15))
|
|
|
|
parser = argparse.ArgumentParser(description='nvdiffrec')
|
|
parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
|
|
parser.add_argument('-i', '--iter', type=int, default=5000)
|
|
parser.add_argument('-b', '--batch', type=int, default=1)
|
|
parser.add_argument('-s', '--spp', type=int, default=1)
|
|
parser.add_argument('-l', '--layers', type=int, default=1)
|
|
parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
|
|
parser.add_argument('-dr', '--display-res', type=int, default=None)
|
|
parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
|
|
parser.add_argument('-di', '--display-interval', type=int, default=0)
|
|
parser.add_argument('-si', '--save-interval', type=int, default=1000)
|
|
parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
|
|
parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
|
|
parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
|
|
parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
|
|
parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
|
|
parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
|
|
parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results')
|
|
parser.add_argument('-bm', '--base-mesh', type=str, default=None)
|
|
parser.add_argument('--validate', type=bool, default=True)
|
|
parser.add_argument('-ind', '--index', type=int)
|
|
parser.add_argument('-ss', '--split-size', type=int, default=10)
|
|
parser.add_argument('--cropped', type=bool, default=True)
|
|
parser.add_argument('-no', '--normal-only', type=bool, default=True)
|
|
parser.add_argument('--meta-path', type=str, default='./data/shapenet_json/chair.json')
|
|
parser.add_argument('-rp', '--resume-path', type=str, default=None)
|
|
parser.add_argument('-ema', '--use-ema', action="store_true")
|
|
|
|
FLAGS = parser.parse_args()
|
|
print(f"parsed arguments")
|
|
global_index = FLAGS.index * FLAGS.split_size
|
|
|
|
FLAGS.mtl_override = None # Override material of model
|
|
FLAGS.dmtet_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
|
|
FLAGS.mesh_scale = 1.0 # Scale of tet grid box. Adjust to cover the model
|
|
FLAGS.env_scale = 1.0 # Env map intensity multiplier
|
|
FLAGS.envmap = None # HDR environment probe
|
|
FLAGS.display = None # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
|
|
FLAGS.camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
|
|
FLAGS.lock_light = False # Disable light optimization in the second pass
|
|
FLAGS.lock_pos = False # Disable vertex position optimization in the second pass
|
|
FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
|
|
FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"]
|
|
FLAGS.laplace_scale = 10000.0 # Weight for sdf regularizer. Default is relative with large weight
|
|
FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training
|
|
FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] # Limits for kd
|
|
FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
|
|
FLAGS.ks_min = [ 0.0, 0.08, 0.0] # Limits for ks
|
|
FLAGS.ks_max = [ 1.0, 1.0, 1.0]
|
|
FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
|
|
FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
|
|
FLAGS.cam_near_far = [0.1, 1000.0]
|
|
FLAGS.learn_light = False
|
|
FLAGS.cropped = True
|
|
FLAGS.use_ema = False
|
|
FLAGS.random_lgt = True
|
|
FLAGS.dataset_flat_shading = False
|
|
|
|
FLAGS.local_rank = 0
|
|
FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
|
|
if FLAGS.multi_gpu:
|
|
if "MASTER_ADDR" not in os.environ:
|
|
os.environ["MASTER_ADDR"] = 'localhost'
|
|
if "MASTER_PORT" not in os.environ:
|
|
os.environ["MASTER_PORT"] = '23456'
|
|
|
|
FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
|
|
torch.cuda.set_device(FLAGS.local_rank)
|
|
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
|
|
|
if FLAGS.config is not None:
|
|
data = json.load(open(FLAGS.config, 'r'))
|
|
for key in data:
|
|
FLAGS.__dict__[key] = data[key]
|
|
|
|
if FLAGS.display_res is None:
|
|
FLAGS.display_res = FLAGS.train_res
|
|
|
|
if FLAGS.local_rank == 0:
|
|
print("Config / Flags:")
|
|
print("---------")
|
|
for key in FLAGS.__dict__.keys():
|
|
print(key, FLAGS.__dict__[key])
|
|
print("---------")
|
|
|
|
os.makedirs(FLAGS.out_dir, exist_ok=True)
|
|
os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
|
|
os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
|
|
os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)
|
|
|
|
|
|
print(f"Using dmt grid of resolution {FLAGS.dmtet_grid}")
|
|
|
|
glctx = dr.RasterizeGLContext()
|
|
|
|
### Default mtl
|
|
mtl_default = {
|
|
'name' : '_default_mat',
|
|
'bsdf': 'diffuse',
|
|
'uniform': True,
|
|
'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
|
|
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
|
|
}
|
|
|
|
|
|
print(f"meta json path {FLAGS.meta_path}")
|
|
shapenet_dataset = ShapeNetDataset(f'{FLAGS.meta_path}')
|
|
|
|
print("Start iterating through objects")
|
|
sys.stdout.flush()
|
|
|
|
|
|
|
|
if len(shapenet_dataset) > 0:
|
|
for k in range(FLAGS.split_size):
|
|
# ==============================================================================================
|
|
# Create data pipeline
|
|
# ==============================================================================================
|
|
|
|
global_index = k + FLAGS.index * FLAGS.split_size
|
|
|
|
print("file path to save: {:s}".format(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))))
|
|
|
|
skip_if_exists = True
|
|
if skip_if_exists and os.path.exists(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))):
|
|
continue
|
|
|
|
try:
|
|
global_index = k + FLAGS.index * FLAGS.split_size
|
|
|
|
if global_index >= len(shapenet_dataset):
|
|
break
|
|
mesh_fname = shapenet_dataset[global_index]
|
|
|
|
print(f"Loading mesh: {mesh_fname}")
|
|
sys.stdout.flush()
|
|
ref_mesh = mesh.load_mesh(mesh_fname, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
|
|
ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)
|
|
|
|
a = ref_mesh.v_nrm.clone()
|
|
ref_mesh = mesh.auto_normals(ref_mesh) ### important
|
|
|
|
print("Loading dataset")
|
|
sys.stdout.flush()
|
|
dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
|
|
dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
|
|
print("Dataset loaded")
|
|
sys.stdout.flush()
|
|
|
|
|
|
# ==============================================================================================
|
|
# Create env light with trainable parameters
|
|
# ==============================================================================================
|
|
|
|
if FLAGS.learn_light:
|
|
lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
|
|
else:
|
|
lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)
|
|
|
|
# ==============================================================================================
|
|
# If no initial guess, use DMtets to create geometry
|
|
# ==============================================================================================
|
|
|
|
# Setup geometry for optimization
|
|
geometry = DMTetGeometry(
|
|
FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
|
|
deform_scale=FLAGS.first_stage_deform
|
|
)
|
|
|
|
# Setup textures, make initial guess from reference if possible
|
|
if not FLAGS.normal_only:
|
|
mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
|
|
else:
|
|
mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
|
|
|
|
print("Start optimization")
|
|
sys.stdout.flush()
|
|
|
|
if FLAGS.resume_path is None:
|
|
# Run optimization
|
|
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
|
|
FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)
|
|
|
|
base_mesh = geometry.getMesh(mat)
|
|
|
|
vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
|
|
vert_mask[geometry.getValidVertsIdx()] = 1
|
|
|
|
# Free temporaries / cached memory
|
|
torch.cuda.empty_cache() ### may slow down training
|
|
|
|
torch.save({
|
|
'sdf': geometry.sdf.cpu().detach(),
|
|
'sdf_ema': geometry.sdf_ema.cpu().detach(),
|
|
'deform': (geometry.deform * vert_mask).cpu().detach(),
|
|
'deform_unmasked': geometry.deform.cpu().detach(),
|
|
}, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
|
|
|
|
old_geometry = geometry
|
|
else:
|
|
dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
|
|
if FLAGS.use_ema:
|
|
geometry.sdf.data[:] = dmt_dict['sdf_ema']
|
|
else:
|
|
geometry.sdf.data[:] = dmt_dict['sdf']
|
|
geometry.deform.data[:] = dmt_dict['deform']
|
|
old_geometry = geometry
|
|
|
|
# Create textured mesh from result
|
|
if FLAGS.normal_only:
|
|
base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
|
|
else:
|
|
base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)
|
|
|
|
|
|
# # ==============================================================================================
|
|
# # Pass 2: Finetune deformation with fixed topology
|
|
# # ==============================================================================================
|
|
geometry = DMTetGeometryFixedTopo(
|
|
geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
|
|
deform_scale=FLAGS.second_stage_deform
|
|
)
|
|
|
|
geometry.sdf_sign.requires_grad = False
|
|
geometry.sdf_abs.requires_grad = False
|
|
geometry.deform.requires_grad = True
|
|
|
|
geometry.deform.data[:] = geometry.deform * FLAGS.first_stage_deform / FLAGS.second_stage_deform
|
|
|
|
if FLAGS.use_ema:
|
|
geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema)
|
|
else:
|
|
geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)
|
|
|
|
geometry.set_init_v_pos()
|
|
|
|
|
|
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
|
|
pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
|
|
optimize_geometry=not FLAGS.lock_pos)
|
|
|
|
vert_mask = torch.zeros_like(geometry.sdf_sign).long().cuda().view(-1, 1)
|
|
vert_mask[geometry.getValidVertsIdx()] = 1
|
|
|
|
torch.save({
|
|
'sdf': geometry.sdf_sign.cpu().detach(),
|
|
'deform': (geometry.deform * vert_mask).cpu().detach(),
|
|
'deform_unmasked': geometry.deform.cpu().detach(),
|
|
},
|
|
os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))
|
|
)
|
|
|
|
if FLAGS.local_rank == 0 and FLAGS.validate:
|
|
validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)
|
|
|
|
# Free temporaries / cached memory
|
|
del geometry
|
|
del ref_mesh
|
|
del dataset_train
|
|
del dataset_validate
|
|
torch.cuda.empty_cache() ### may slow down training
|
|
|
|
print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} finished ============\n\n")
|
|
except Exception as err:
|
|
print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} Failed ============\n\n")
|
|
print(traceback.format_exc())
|
|
print("\n\n")
|
|
continue
|
|
|
|
|
|
|