# Mirrored from https://github.com/lzzcd001/MeshDiffusion
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
|
|
import os
|
|
import time
|
|
import argparse
|
|
import json
|
|
import sys
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import nvdiffrast.torch as dr
|
|
import xatlas
|
|
|
|
# Import data readers / generators
|
|
from lib.dataset.dataset_mesh import DatasetMesh
|
|
from lib.dataset.dataset_shapenet import ShapeNetDataset
|
|
|
|
# Import topology / geometry trainers
|
|
from lib.geometry.dmtet_singleview import DMTetGeometry
|
|
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
|
|
|
|
import lib.render.renderutils as ru
|
|
from lib.render import obj
|
|
from lib.render import material
|
|
from lib.render import util
|
|
from lib.render import mesh
|
|
from lib.render import texture
|
|
from lib.render import mlptexture
|
|
from lib.render import light
|
|
from lib.render import render
|
|
from random import randint
|
|
from time import sleep
|
|
|
|
import traceback
|
|
|
|
# Camera orbit radius used when generating the synthetic mesh dataset views.
RADIUS = 2.0

# define colors
# BGR tuples (OpenCV channel order) forming a blue->red "jet"-like ramp used
# to colorize depth buffers for visualization in validate_itr.
color1 = (0, 0, 255) #red
color2 = (0, 165, 255) #orange
color3 = (0, 255, 255) #yellow
color4 = (255, 255, 0) #cyan
color5 = (255, 0, 0) #blue
color6 = (128, 64, 64) #violet
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)

# resize lut to 256 (or more) values
# 1x256x3 uint8 lookup table consumed by cv2.LUT on 8-bit depth images.
lut = cv2.resize(colorArr, (256,1), interpolation = cv2.INTER_LINEAR)
|
|
|
|
###############################################################################
|
|
# Loss setup
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
def createLoss(FLAGS):
    """Build the image-space loss function selected by ``FLAGS.loss``.

    Returns a callable ``(img, ref) -> loss`` that forwards to
    ``ru.image_loss`` with the matching loss/tonemapper combination.

    Raises:
        ValueError: if ``FLAGS.loss`` names an unknown loss.
    """
    # Map flag value -> (ru.image_loss `loss`, `tonemapper`) arguments.
    losses = {
        "smape":  ("smape",  "none"),
        "mse":    ("mse",    "none"),
        "logl1":  ("l1",     "log_srgb"),
        "logl2":  ("mse",    "log_srgb"),
        "relmse": ("relmse", "none"),
    }
    if FLAGS.loss not in losses:
        # Was `assert False`, which is silently stripped under `python -O`.
        raise ValueError("Unknown loss function: %s" % FLAGS.loss)
    loss, tonemapper = losses[FLAGS.loss]
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)
|
|
|
|
###############################################################################
|
|
# Mix background into a dataset image
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a dataset batch to the GPU and composite a background.

    Mutates ``target`` in place: tensors are moved to CUDA, a background of
    the requested kind is stored under ``target['background']``, and the RGB
    channels of ``target['img']`` are alpha-blended against that background.
    Returns the same ``target`` dict.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    img = target['img']
    rgb_shape = img.shape[0:3] + (3,)

    if bg_type == 'checker':
        board = util.checkerboard(img.shape[1:3], 8)
        background = torch.tensor(board, dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(rgb_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(rgb_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        # Use the (pre-transfer) reference image itself as background.
        background = img[..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(rgb_shape, dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    for key in ('mv', 'mvp', 'campos', 'img'):
        target[key] = target[key].cuda()
    target['background'] = background

    # Alpha-composite: rgb' = lerp(bg, rgb, alpha); alpha channel kept as-is.
    rgb, alpha = target['img'][..., 0:3], target['img'][..., 3:4]
    target['img'] = torch.cat((torch.lerp(background, rgb, alpha), alpha), dim=-1)

    target['spts'] = target['spts'].cuda()
    target['vpts'] = target['vpts'].cuda()
    return target
|
|
|
|
###############################################################################
|
|
# UV - map geometry & convert to a mesh
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Bake the volumetric kd/ks/normal material into 2D textures.

    Runs xatlas to generate a UV parametrization of the current mesh, renders
    the MLP material into UV space, and returns a new mesh carrying ordinary
    ``Texture2D`` maps.
    """
    eval_mesh = geometry.getMesh(mat)

    # Generate a UV atlas with xatlas from the current vertex/face arrays.
    vmapping, indices, uvs = xatlas.parametrize(
        eval_mesh.v_pos.detach().cpu().numpy(),
        eval_mesh.t_pos_idx.detach().cpu().numpy())

    # xatlas emits uint32 indices; reinterpret the bit pattern as int64
    # so torch can index with them.
    faces = torch.tensor(indices.astype(np.uint64, casting='same_kind').view(np.int64),
                         dtype=torch.int64, device='cuda')
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Rasterize the MLP material into the UV layout.
    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Extra transparency layer: append a random alpha channel to kd.
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)

    def _limits(lo, hi):
        # Clamp bounds as CUDA tensors for Texture2D's min_max argument.
        return [torch.tensor(lo, dtype=torch.float32, device='cuda'),
                torch.tensor(hi, dtype=torch.float32, device='cuda')]

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=_limits(FLAGS.kd_min, FLAGS.kd_max)),
        'ks'     : texture.Texture2D(ks, min_max=_limits(FLAGS.ks_min, FLAGS.ks_max)),
        'normal' : texture.Texture2D(normal, min_max=_limits(FLAGS.nrm_min, FLAGS.nrm_max)),
    })

    return new_mesh
|
|
|
|
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """Bake only the normal map into a UV texture (normal-only mode).

    Like :func:`xatlas_uvmap` but kd/ks are carried over unchanged from
    ``mat`` and only the MLP normal texture is rendered into UV space.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors (xatlas emits uint32; reinterpret bits as int64)
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Bake the MLP normal texture over the new UV layout.
    mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # BUGFIX: the original had `if FLAGS.layers > 1: kd = torch.cat((kd, ...))`
    # copy-pasted from xatlas_uvmap, but `kd` is never defined here, so that
    # branch raised NameError. kd/ks come straight from `mat`, so it is removed.

    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf' : mat['bsdf'],
        'kd' : mat['kd'],
        'ks' : mat['ks'],
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
|
|
|
|
###############################################################################
|
|
# Utility functions for material
|
|
###############################################################################
|
|
|
|
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial trainable material (kd + ks + normal).

    With ``mlp=True`` a single volumetric MLP texture with 9 output channels
    (kd:3, ks:3, normal:3) is used; otherwise individual trainable 2D
    textures are created, randomly initialized or copied from ``init_mat``.
    """
    def _bounds(lo, hi):
        # Per-channel clamp bounds as CUDA tensors.
        return (torch.tensor(lo, dtype=torch.float32, device='cuda'),
                torch.tensor(hi, dtype=torch.float32, device='cuda'))

    kd_min, kd_max = _bounds(FLAGS.kd_min, FLAGS.kd_max)
    ks_min, ks_max = _bounds(FLAGS.ks_min, FLAGS.ks_max)
    nrm_min, nrm_max = _bounds(FLAGS.nrm_min, FLAGS.nrm_max)

    if mlp:
        # One MLP texture emits kd (first 3 of possibly 4 kd channels), ks, normal.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda')
            kd_init = kd_init * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            ks_channels = [
                np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01),                              # unused x channel
                np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu()),       # roughness
                np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu()),       # metalness
            ]
            ks_map_opt = texture.create_trainable(np.concatenate(ks_channels, axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map; default is the flat normal (0, 0, 1).
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt,
        })

    # Inherit the BSDF from the reference material when available.
    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
|
|
|
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial material when kd/ks are fixed (normal-only mode).

    kd and ks are taken verbatim from ``init_mat``; only the normal map is
    trainable, either as a 3-channel volumetric MLP texture (``mlp=True``)
    or as a trainable 2D texture.
    """
    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        # Volumetric MLP texture with 3 output channels (the normal).
        normal_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[nrm_min, nrm_max])
    else:
        # Setup normal map; default to the flat normal (0, 0, 1).
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

    mat = material.Material({
        'kd'     : init_mat['kd'],
        'ks'     : init_mat['ks'],
        'normal' : normal_opt,
    })

    # Inherit the BSDF from the reference material when available.
    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
|
|
|
###############################################################################
|
|
# Validation & testing
|
|
###############################################################################
|
|
|
|
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view and assemble a side-by-side comparison strip.

    Returns ``(result_image, result_dict)`` where ``result_image`` is the
    horizontal concatenation of the optimized render, the reference, any
    extra display layers, and the coverage mask; ``result_dict`` holds the
    individual images keyed by name.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        # Apply the per-batch environment-light rotation.
        lgt.xfm(target['envlight_transform'])

        # Prefer the EMA weights when the geometry supports them; fall back
        # to a plain render otherwise (best-effort, any failure falls through).
        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except:
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        # Tonemapped reference vs. optimized render (first batch element).
        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        if FLAGS.display is not None:
            white_bg = torch.ones_like(target['background'])
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    # Latlong panorama of the current environment light.
                    if isinstance(lgt, light.EnvironmentLight):
                        result_dict['light_image'] = util.cubemap_to_latlong(lgt.base, FLAGS.display_res)
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'relight' in layer:
                    # Re-render under an alternative environment light
                    # (lazily loaded from the given path on first use).
                    if not isinstance(layer['relight'], light.EnvironmentLight):
                        layer['relight'] = light.load_env(layer['relight'])
                    img = geometry.render(glctx, target, layer['relight'], opt_material)
                    result_dict['relight'] = util.rgb_to_srgb(img[..., 0:3])[0]
                    result_image = torch.cat([result_image, result_dict['relight']], axis=1)
                elif 'bsdf' in layer:
                    # Visualize a single BSDF component (kd / normal / other).
                    buffers = geometry.render(glctx, target, lgt, opt_material, bsdf=layer['bsdf'])
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(buffers['shaded'][0, ..., 0:3])
                    elif layer['bsdf'] == 'normal':
                        # Map normals from [-1, 1] to displayable [0, 1].
                        result_dict[layer['bsdf']] = (buffers['shaded'][0, ..., 0:3] + 1) * 0.5
                    else:
                        result_dict[layer['bsdf']] = buffers['shaded'][0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif "depth" in layer:
                    # Colorized depth: normalize valid (non-zero) depths to
                    # [0, 1], map through the module-level BGR LUT, and paint
                    # the empty background white.
                    depth = buffers['depth'][:, :, :, 0].squeeze().unsqueeze(-1).expand(-1, -1, 1)
                    mask = (depth != 0).float()
                    # Push masked-out pixels to a huge value so min() sees only valid depths.
                    depth_min = ((1 - mask) * 1e3 + depth).min()
                    depth_max = depth.max()
                    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8)
                    depth = depth * mask + (1 - mask) * depth_min
                    depth = depth.expand(-1, -1, 3)
                    depth = cv2.LUT(np.array(depth.detach().cpu().numpy() * 255.0, dtype=np.uint8), lut)
                    result_dict['depth'] = depth = (torch.tensor(depth, device=mask.device).float() / 255.0 * mask) + 255. * (1 - mask)
                    result_image = torch.cat([result_image, depth], axis=1)

        # Extra render to fetch geometric normals and the coverage mask.
        buffers = geometry.render(glctx, target, lgt, opt_material)

        camera = target['geo_viewdir'][:, :, :, :3]

        # |n . v| facing-ratio image of the geometric normals.
        result_dict['geo_normal'] = (util.safe_normalize(buffers['geo_normal'][:, :, :, :3]) * camera).sum(-1, keepdim=False).abs()[0]

        mask = buffers['mask'][0].expand(-1, -1, 3)
        result_image = torch.cat([result_image, mask], axis=1)

    return result_image, result_dict
|
|
|
|
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Render the whole validation set, save images, and report MSE/PSNR.

    Writes per-view metrics to ``<out_dir>/metrics.txt``, saves every entry
    of the per-view result dict as ``val_%06d_%s.png``, and returns the
    average PSNR over the set.
    """

    # ==============================================================================================
    # Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on clamped [0, 1] images.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # Dropped the deprecated `size_average=None, reduce=None` aliases;
            # `reduction='mean'` alone expresses the same behavior.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (it, mse, psnr))

            # Dump every rendered buffer (opt, ref, geo_normal, ...) to disk.
            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
|
|
|
|
###############################################################################
|
|
# Main shape fitter function / optimization loop
|
|
###############################################################################
|
|
|
|
class Trainer(torch.nn.Module):
    """Bundles geometry, light and material into one trainable module.

    ``forward`` runs a single training step and returns whatever
    ``geometry.tick`` returns (the image loss / regularization loss pair).
    Parameter groups are exposed as ``params`` (material + optionally light),
    ``geo_params``, ``sdf_params`` and ``deform_params`` for the optimizers
    built in ``optimize_mesh``.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # With a fixed (non-trainable) light the mip chain can be built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []
        # Fixed-topology geometry exposes no `sdf` attribute; in that case the
        # sdf parameter group is simply empty. (Was a bare `except:`.)
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        if self.optimize_light:
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
            # Apply the per-batch environment-light rotation to the trainable
            # light. NOTE(review): placement inside the optimize_light guard is
            # reconstructed — the render call below also receives xfm_lgt.
            self.light.xfm(target['envlight_transform'])

        # BUGFIX: use the stored rasterizer context rather than the
        # module-level global `glctx`, which only exists when this file is
        # executed as a script.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it, xfm_lgt=target['envlight_transform'], no_depth_thin=False)
|
|
|
|
def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    ):
    """Run one optimization pass over geometry, material and (optionally) light.

    Single-view fitting: one view is drawn from the validation loader and
    trained against repeatedly. Returns the updated ``(geometry, opt_material)``.
    """

    # ==============================================================================================
    # Setup torch optimizer
    # ==============================================================================================

    # Learning rate may be specified per pass (list/tuple indexed by pass_idx)
    # and, within a pass, split as (position_lr, material_lr).
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate

    def lr_schedule(iter, fraction):
        # Linear warmup followed by exponential decay.
        # NOTE(review): `fraction` is accepted but unused; `iter` shadows the builtin.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    # Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    # Single GPU training mode
    trainer = trainer_noddp
    if optimize_geometry:
        # Separate optimizer for the tet-grid SDF values and vertex deformations.
        optimizer_mesh = torch.optim.Adam([
            {'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
            {'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
            ]
        )
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # Material (and optionally light) optimizer.
    optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    # Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    # NOTE(review): single-view fitting — the "validation" loader supplies the
    # training view; note it deliberately uses dataset_train's collate.
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator that restarts the loader when it is exhausted.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)
    # v_iter_no = 25
    v_iter_no = 10

    print("Start training loop...")
    sys.stdout.flush()
    # Advance to the v_iter_no-th view; v_curr is the single view fitted below.
    for _ in range(v_iter_no):
        v_curr = next(v_it)

    # NOTE(review): iteration count is hard-coded to 5000 rather than FLAGS.iter.
    for it in range(5000):
        # Mix randomized background into dataset image
        target = prepare_batch(v_curr, 'random')

        ### for robustness, we take the easy way of initializing the tet grid with the gt depth image
        if it < 300 and it % 10 == 0:
            gt_visible_triangles = target['rast_triangle_id'].long()
            gt_verts, gt_faces = target['vpts'], target['faces']
            surface_faces = gt_faces[gt_visible_triangles]
            campos = target['campos'][0]
            # Best-effort: fixed-topology geometry has no init_with_gt_surface.
            try:
                geometry.init_with_gt_surface(gt_verts, surface_faces, campos)
            except:
                pass

        iter_start_time = time.time()

        # ==============================================================================================
        # Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        # Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        # Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        # Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Manual gradient scaling: boost light gradients, damp encoder gradients.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            # The normal texture may not be MLP-backed (no `encoder`).
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except:
                pass

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        geometry.clamp_deform()
        if FLAGS.use_ema:
            # EMA updates are not supported in this script.
            raise NotImplementedError
            geometry.update_ema()

        # ==============================================================================================
        # Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)

        # Synchronize so the measured iteration time includes GPU work.
        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        # Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            # NOTE(review): remaining time is estimated from FLAGS.iter even
            # though the loop above runs a fixed 5000 iterations.
            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
|
|
|
|
#----------------------------------------------------------------------------
|
|
# Main function.
|
|
#----------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    # Command-line interface for single-view DMTet fitting.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results_singleview')
    # NOTE(review): `type=bool` means any non-empty string (even "False")
    # parses as True for these two flags.
    parser.add_argument('--validate', type=bool, default=True)
    parser.add_argument('-no', '--normal-only', type=bool, default=True)
    parser.add_argument('-ema', '--use-ema', action="store_true")
    parser.add_argument('-rp', '--resume-path', type=str, default=None)
    parser.add_argument('-mp', '--mesh-path', type=str)
    parser.add_argument('-an', '--angle-ind', type=int, help='angle index from 0 to 50')

    FLAGS = parser.parse_args()
    print(f"parsed arguments")
|
|
|
|
    # Defaults that are not exposed on the command line. Any of these (and the
    # parsed arguments above) may be overridden by the JSON config below.
    FLAGS.mtl_override = None # Override material of model
    FLAGS.dmtet_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 1.0 # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0 # Env map intensity multiplier
    FLAGS.envmap = None # HDR environment probe
    FLAGS.display = None # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False # Disable light optimization in the second pass
    FLAGS.lock_pos = False # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0 # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0] # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.use_ema = False
    FLAGS.random_lgt = True
    FLAGS.dataset_flat_shading = False

    FLAGS.local_rank = 0
    # Config file entries override everything set above, including CLI values.
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]
|
|
|
|
    # Display resolution defaults to the training resolution.
    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    print(f"Out dir: {FLAGS.out_dir}")

    # Dump the effective configuration (rank 0 only).
    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    # Output directories for validation renders and saved tet grids.
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz_pre'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)

    print(f"Using dmtet grid of resolution {FLAGS.dmtet_grid}")

    # OpenGL rasterizer context shared by all render calls below.
    glctx = dr.RasterizeGLContext()
|
|
|
|
    ### Default mtl
    # Non-trainable fallback material (flat diffuse) applied to the reference
    # mesh when it lacks its own material / in normal-only mode.
    mtl_default = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
        'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
    }

    print(f"Loading mesh: {FLAGS.mesh_path}")
    sys.stdout.flush()
    ref_mesh = mesh.load_mesh(FLAGS.mesh_path, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
    # Recenter/rescale the reference mesh by its clean AABB.
    ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)
    print("Loading dataset")
    sys.stdout.flush()
    # Synthetic view datasets rendered from the reference mesh.
    dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
    dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
    print("Dataset loaded")
    sys.stdout.flush()
|
|
|
|
|
|
    # ==============================================================================================
    # Create env light with trainable parameters
    # ==============================================================================================

    if FLAGS.learn_light:
        lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
    else:
        lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)

    # ==============================================================================================
    # If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    # NOTE(review): first_stage_deform is expected to come from the JSON config.
    geometry = DMTetGeometry(
        FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.first_stage_deform
    )

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

    print("Start optimization")
    sys.stdout.flush()
|
|
|
|
if FLAGS.resume_path is None:
|
|
# Run optimization
|
|
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
|
|
FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)
|
|
|
|
base_mesh = geometry.getMesh(mat)
|
|
|
|
vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
|
|
vert_mask[geometry.getValidVertsIdx()] = 1
|
|
|
|
# Free temporaries / cached memory
|
|
torch.cuda.empty_cache() ### may slow down training
|
|
|
|
old_geometry = geometry
|
|
|
|
if FLAGS.local_rank == 0 and FLAGS.validate:
|
|
validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz_pre/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)
|
|
|
|
else:
|
|
dmt_dict = torch.load(FLAGS.resume_path)
|
|
if FLAGS.use_ema:
|
|
geometry.sdf.data[:] = dmt_dict['sdf_ema']
|
|
else:
|
|
geometry.sdf.data[:] = dmt_dict['sdf']
|
|
geometry.deform.data[:] = dmt_dict['deform']
|
|
old_geometry = geometry
|
|
|
|
    # Create textured mesh from result
    # Bake the learned material into UV textures over the pass-1 mesh.
    if FLAGS.normal_only:
        base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
    else:
        base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)

    # Freeze the topology found in pass 1; only deformations remain trainable.
    geometry = DMTetGeometryFixedTopo(
        geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.second_stage_deform
    )

    geometry.sdf_sign.requires_grad = False
    geometry.sdf_abs.requires_grad = False
    geometry.deform.requires_grad = True

    # Rescale deformations to account for the new deform_scale.
    geometry.deform.data[:] = geometry.deform * FLAGS.first_stage_deform / FLAGS.second_stage_deform

    # Copy SDF signs from the pass-1 geometry.
    if FLAGS.use_ema:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema) ### use ema
    else:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf) ### use ema

    geometry.set_init_v_pos()
|
|
|
|
|
|
    # ==============================================================================================
    # Pass 2: Train with fixed topology (mesh)
    # ==============================================================================================
    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                optimize_geometry=not FLAGS.lock_pos)

    ##### Process single-view tet grid
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    # Advance the loader to the requested view angle; v_curr is that view.
    v_it = iter(dataloader_validate)
    for _ in range(FLAGS.angle_ind):
        v_curr = next(v_it)

    target = prepare_batch(v_curr, 'random')
|
|
|
|
    # ==============================================================================================
    # Infer occluded regions
    # ==============================================================================================
    valid_tet_idx = geometry.getValidTetIdx().long()

    # Render with tet-id tracking to discover which tets the view can see.
    buffers = geometry.render(glctx, target, lgt, mat, get_visible_tets=True)

    ## visible tets (except for rasterized ones)
    visible_tets = torch.zeros(geometry.indices.size(0)).cuda()
    visible_tets[buffers['visible_tet_id'].long()] = 1

    ## to include the rasterized tetrahedra
    visible_and_rast_tets = visible_tets.clone()
    # Map rasterized surface triangles back to their owning tets.
    rast_tet_id = valid_tet_idx[buffers['rast_triangle_id'].long()].unique()
    visible_and_rast_tets[rast_tet_id] = 1

    # Convert the 0/1 indicator tensors to boolean masks.
    visible_tets = (visible_tets == 1)
    visible_and_rast_tets = (visible_and_rast_tets == 1)

    ## label all tetrahedral vertices associated with any visible tets
    visible_verts = torch.zeros(geometry.verts.size(0))
    tet_inds = torch.arange(geometry.indices.size(0))
    vis_vert_inds = geometry.indices[visible_tets].unique()
    visible_verts[vis_vert_inds] = 1

    visible_and_rast_verts = visible_verts.clone()
    vis_and_rast_vert_inds = geometry.indices[visible_and_rast_tets].unique()
    visible_and_rast_verts[vis_and_rast_vert_inds] = 1
    visible_and_rast_verts = visible_and_rast_verts.bool()

    # Persist the fitted tet grid plus visibility masks for the diffusion stage.
    torch.save({
        'sdf': geometry.sdf_sign.cpu().detach(),
        'deform': geometry.deform.cpu().detach(),
        'vis': visible_verts.cpu().detach(),
        'vis_rast': visible_and_rast_verts.cpu().detach()
    }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'))

    # ==============================================================================================

    # Final validation renders of the fixed-topology result.
    if FLAGS.local_rank == 0 and FLAGS.validate:
        validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz/dmtet"), FLAGS)
|
|
|