# Mirrored from https://github.com/lzzcd001/MeshDiffusion
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
|
# property and proprietary rights in and to this material, related
|
|
# documentation and any modifications thereto. Any use, reproduction,
|
|
# disclosure or distribution of this material and related documentation
|
|
# without an express license agreement from NVIDIA CORPORATION or
|
|
# its affiliates is strictly prohibited.
|
|
|
|
import os
|
|
import time
|
|
import argparse
|
|
import json
|
|
import sys
|
|
|
|
import glob
|
|
import tqdm
|
|
|
|
import numpy as np
|
|
import torch
|
|
import nvdiffrast.torch as dr
|
|
import xatlas
|
|
|
|
# Import topology / geometry trainers
|
|
from lib.geometry.dmtet import DMTetGeometry
|
|
|
|
import lib.render.renderutils as ru
|
|
from lib.render import material
|
|
from lib.render import util
|
|
from lib.render import mesh
|
|
from lib.render import texture
|
|
from lib.render import mlptexture
|
|
from lib.render import light
|
|
from lib.render import render
|
|
|
|
from pytorch3d.io import save_obj
|
|
|
|
import pymeshlab
|
|
|
|
RADIUS = 3.0
|
|
|
|
# Enable to debug back-prop anomalies
|
|
# torch.autograd.set_detect_anomaly(True)
|
|
|
|
###############################################################################
|
|
# Loss setup
|
|
###############################################################################
|
|
|
|
def createLoss(FLAGS):
    """Build the image-space loss function selected by ``FLAGS.loss``.

    Returns a callable ``(img, ref) -> loss`` that forwards to
    ``ru.image_loss`` with the matching loss / tonemapper combination.

    Raises:
        ValueError: if ``FLAGS.loss`` is not one of the supported names.

    NOTE: the original decorated this factory with ``@torch.no_grad()``, which
    only wrapped lambda *creation* (a no-op for autograd) — the returned loss
    is still evaluated with gradients enabled, exactly as before.
    """
    # Map user-facing loss name -> (ru.image_loss loss, tonemapper) pair.
    loss_table = {
        "smape":  ('smape',  'none'),
        "mse":    ('mse',    'none'),
        "logl1":  ('l1',     'log_srgb'),
        "logl2":  ('mse',    'log_srgb'),
        "relmse": ('relmse', 'none'),
    }
    try:
        loss_name, tonemapper = loss_table[FLAGS.loss]
    except KeyError:
        # Fixed: was `assert False`, which is silently stripped under `python -O`.
        raise ValueError("Unknown loss '%s'" % FLAGS.loss) from None
    return lambda img, ref: ru.image_loss(img, ref, loss=loss_name, tonemapper=tonemapper)
|
|
|
|
###############################################################################
|
|
# Mix background into a dataset image
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Attach a background image to a dataset batch and move it to the GPU.

    Args:
        target: batch dict holding 'mv', 'mvp', 'campos' camera tensors and
            'resolution'; 'img' is additionally read when bg_type is
            'checker' or 'reference'.
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        The same dict, mutated in place, with a 'background' entry added and
        the camera tensors moved to CUDA.

    Raises:
        ValueError: for an unrecognized ``bg_type``.
    """
    # assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros((1, target['resolution'][0], target['resolution'][1]) + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones((1, target['resolution'][0], target['resolution'][1]) + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand((1, target['resolution'][0], target['resolution'][1]) + (3,), dtype=torch.float32, device='cuda')
    else:
        # Fixed: was `assert False, ...` — asserts vanish under `python -O`,
        # so raise an explicit error for invalid input instead.
        raise ValueError("Unknown background type %s" % bg_type)

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['background'] = background

    return target
|
|
|
|
###############################################################################
|
|
# UV - map geometry & convert to a mesh
|
|
###############################################################################
|
|
|
|
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-parametrize the current mesh with xatlas and bake its MLP material
    into 2D kd / ks / normal textures, returning the textured mesh."""

    def _limit(values):
        # Per-channel clamp limits as CUDA float tensors.
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    src_mesh = geometry.getMesh(mat)

    # xatlas runs on CPU numpy copies of the vertex/triangle data.
    pos_np = src_mesh.v_pos.detach().cpu().numpy()
    tri_np = src_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, atlas_indices, atlas_uvs = xatlas.parametrize(pos_np, tri_np)

    # xatlas yields unsigned indices; reinterpret the bit pattern as int64
    # so torch can consume them.
    tri_int64 = atlas_indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uv_tensor = torch.tensor(atlas_uvs, dtype=torch.float32, device='cuda')
    face_tensor = torch.tensor(tri_int64, dtype=torch.int64, device='cuda')

    uv_mesh = mesh.Mesh(v_tex=uv_tensor, t_tex_idx=face_tensor, base=src_mesh)

    # Rasterize in UV space to bake the volumetric MLP material into texels.
    mask, kd, ks, normal = render.render_uv(glctx, uv_mesh, FLAGS.texture_res, src_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Extra layers get a random alpha channel appended to kd.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    uv_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=[_limit(FLAGS.kd_min), _limit(FLAGS.kd_max)]),
        'ks'     : texture.Texture2D(ks, min_max=[_limit(FLAGS.ks_min), _limit(FLAGS.ks_max)]),
        'normal' : texture.Texture2D(normal, min_max=[_limit(FLAGS.nrm_min), _limit(FLAGS.nrm_max)])
    })

    return uv_mesh
|
|
|
|
###############################################################################
|
|
# Utility functions for material
|
|
###############################################################################
|
|
|
|
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial trainable material for optimization.

    If ``mlp`` is truthy, returns a Material holding a single 9-channel
    MLPTexture3D ('kd_ks_normal': 3 kd + 3 ks + 3 normal channels) defined
    over the geometry's AABB. Otherwise returns a Material with separate 2D
    'kd', 'ks' and 'normal' trainable textures, initialized either randomly
    or from ``init_mat``.

    NOTE: random init draws from torch then numpy RNGs in a fixed order
    (kd, then ksR/ksG/ksB) — keep the statement order if reproducibility of
    seeded runs matters.
    """
    # Per-channel clamp limits, as CUDA tensors.
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
    if mlp:
        # Single volumetric MLP texture covering kd (3, alpha dropped), ks (3), normal (3).
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            # 4th kd channel (alpha) only when rendering multiple layers.
            num_channels = 4 if FLAGS.layers > 1 else 3
            # Uniform random kd inside [kd_min, kd_max] per channel.
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init , FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # ks channels: R (unused slot) kept near zero; G = roughness, B = metalness
            # drawn within their configured limits.
            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())

            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            # Start from the provided material's textures.
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            # Flat default normal (+Z in tangent space).
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    # Inherit the BSDF model from the initial material, defaulting to PBR.
    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'

    return mat
|
|
|
|
###############################################################################
|
|
# Validation & testing
|
|
###############################################################################
|
|
|
|
def rotate_scene(FLAGS, itr):
    """Build a turntable camera pose for display.

    ``itr`` advances the rotation; 50 steps correspond to one full turn.
    Returns a target dict ('mv', 'mvp', 'campos', 'spp', 'resolution') with
    batched CUDA tensors, compatible with validate_itr().
    """
    aspect = FLAGS.display_res[1] / FLAGS.display_res[0]
    proj = util.perspective(np.deg2rad(45), aspect, FLAGS.cam_near_far[0], FLAGS.cam_near_far[1])

    # Smoothly rotate around Y with a slight downward tilt, camera at RADIUS.
    angle = (itr / 50) * np.pi * 2
    view = util.translate(0, 0, -RADIUS) @ (util.rotate_x(-0.4) @ util.rotate_y(angle))
    clip = proj @ view
    eye = torch.linalg.inv(view)[:3, 3]

    return {
        'mv': view[None, ...].cuda(),
        'mvp': clip[None, ...].cuda(),
        'campos': eye[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res
    }
|
|
|
|
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view without gradients.

    Returns (image, buffers) where ``image`` is the sRGB-converted shaded
    render and ``buffers`` is a dict holding it under the key 'opt'.
    """
    out = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            # Keep the light fixed relative to the camera.
            lgt.xfm(target['mv'])

        rendered = geometry.render(glctx, target, lgt, opt_material)

        # Drop the batch dim and convert linear RGB -> sRGB for saving/display.
        out['opt'] = util.rgb_to_srgb(rendered['shaded'][..., 0:3])[0]

    return out['opt'], out
|
|
|
|
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Render every view of the validation dataset and report MSE/PSNR.

    Writes per-view metrics to ``<out_dir>/metrics.txt``, saves each rendered
    buffer as ``val_<it>_<name>.png``, and returns the average PSNR.
    """

    # ==============================================================================================
    # Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            # NOTE(review): validate_itr() in this file only populates 'opt',
            # so this 'ref' lookup would raise KeyError — presumably a variant
            # of validate_itr that also renders the reference image is
            # intended here; confirm before relying on this path.
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # Fixed: dropped the long-deprecated size_average/reduce arguments;
            # reduction='mean' alone is the modern equivalent.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(line)

            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(line)
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
|
|
|
|
###############################################################################
|
|
# Main shape fitter function / optimization loop
|
|
###############################################################################
|
|
|
|
class Trainer(torch.nn.Module):
    """Bundles geometry, material and lighting into a single nn.Module whose
    forward() evaluates the image loss for one training batch.

    Parameters for the optimizer are exposed as ``self.params`` (material,
    plus light when ``optimize_light``) and ``self.geo_params`` (geometry,
    when ``optimize_geometry``).
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A fixed (non-optimized) light only needs its mip chain built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        """Compute the training loss for one batch at iteration ``it``."""
        if self.optimize_light:
            # Light texels change every step, so rebuild mips before rendering.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # BUG FIX: was the bare global `glctx`, which only exists when this
        # file runs as a script; use the context stored on the instance.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)
|
|
|
|
#----------------------------------------------------------------------------
|
|
# Main function.
|
|
#----------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    # Visualization entry point: loads diffusion-generated DMTet samples
    # (SDF sign + per-vertex deformation grids), converts each sample into a
    # mesh, renders a preview image, and writes a post-processed .obj.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet')
    parser.add_argument('-sp', '--sample-path', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    # NOTE(review): argparse `type=bool` treats any non-empty string as True
    # ("--validate False" still yields True) — confirm intended usage.
    parser.add_argument('--validate', type=bool, default=True)
    parser.add_argument('--angle-ind', type=int, default=25, help='z-axis rotation of the object, from 0 to 50')
    parser.add_argument('-ns', '--num-smooth-steps', type=int, default=3, help='number of post-processing Laplacian smoothing steps')

    FLAGS = parser.parse_args()

    # Hard-coded defaults not exposed on the command line (may be overridden
    # by the --config JSON below).
    FLAGS.mtl_override = None # Override material of model
    FLAGS.dmtet_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 2.1 # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0 # Env map intensity multiplier
    FLAGS.envmap = None # HDR environment probe
    FLAGS.display = None # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False # Disable light optimization in the second pass
    FLAGS.lock_pos = False # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0 # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0] # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.cropped = False
    FLAGS.random_lgt = False

    # Any key in the config JSON overrides both CLI args and the defaults above.
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # Output layout: <out_dir>/viz for preview PNGs, <out_dir>/mesh for OBJs.
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    # OpenGL rasterizer context used by all nvdiffrast render calls below.
    glctx = dr.RasterizeGLContext()

    # ==============================================================================================
    # Create env light with trainable parameters
    # ==============================================================================================

    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # ==============================================================================================
    # If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(
        resolution, FLAGS.mesh_scale, FLAGS,
        deform_scale=FLAGS.deform_scale
    )

    # NOTE(review): `mask` is loaded but not referenced anywhere below in this
    # file — possibly leftover from a variant of this script; confirm.
    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")


    ### compute the mapping from tet indices to 3D cubic grid vertex indices
    # NOTE(review): FLAGS.tet_path is never set above — it must be supplied
    # through the --config JSON; verify the config used.
    tet_path = FLAGS.tet_path
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    # Grid spacing, assuming vertices lie on a regular cubic lattice.
    dx = vertices_unique[1] - vertices_unique[0]

    # Integer (i, j, k) lattice coordinates for each tet vertex.
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()

    # Generated samples: one grid per sample, channel 0 = SDF, channels 1:4 = deformation.
    data_all = np.load(FLAGS.sample_path)
    print('shape of generated data', data_all.shape)

    for no_data in tqdm.trange(data_all.shape[0]):

        grid = torch.tensor(data_all[no_data])
        if FLAGS.unnormalized_sdf:
            raise NotImplementedError
            # NOTE(review): dead code — unreachable after the raise above.
            geometry.sdf.data[:] = (
                grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
            ).cuda()
        else:
            # Only the SDF sign is used to fix the DMTet topology.
            geometry.sdf.data[:] = torch.sign(
                grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
            ).cuda()
        # Per-vertex deformation, gathered from the grid and laid out (V, 3).
        geometry.deform.data[:] = (
            grid[1:, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
        ).cuda().transpose(0, 1)

        # Clamp deformations so vertices stay within their lattice cells.
        geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

        ### mtl for visualization
        opt_material = {
            'name' : '_default_mat',
            # 'bsdf' : 'pbr',
            'bsdf' : 'diffuse',
            'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
            'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
        }

        ### create and optimize mesh
        base_mesh = geometry.getMesh(opt_material)

        ### save image (before post-processing)
        v_pose = rotate_scene(FLAGS, FLAGS.angle_ind) ## pick a pose (pose # from 0 to 50)
        result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
        result_image = result_image.detach().cpu().numpy()
        util.save_image(os.path.join(viz_path, ('%s_%06d.png' % (FLAGS.viz_name, no_data))), result_image)

        ### save post-processed mesh
        mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(no_data))
        save_obj(
            verts=base_mesh.v_pos,
            faces=base_mesh.t_pos_idx,
            f=mesh_savepath
        )

        # Post-process in MeshLab: remesh, Laplacian-smooth, remesh again.
        ms = pymeshlab.MeshSet()
        ms.load_new_mesh(mesh_savepath)
        ms.meshing_isotropic_explicit_remeshing()
        ms.apply_coord_laplacian_smoothing(stepsmoothnum=FLAGS.num_smooth_steps, cotangentweight=False)
        # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
        ms.meshing_isotropic_explicit_remeshing()
        # NOTE(review): no filter script is loaded before this call — with an
        # empty script it appears to be a no-op; confirm it is intentional.
        ms.apply_filter_script()
        ms.save_current_mesh(mesh_savepath)
|