kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
164 wiersze
6.9 KiB
Python
164 wiersze
6.9 KiB
Python
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
|
|
|
|
import numpy as np
|
|
import torch
|
|
import sys
|
|
|
|
from ..render import util
|
|
from ..render import mesh
|
|
from ..render import render
|
|
from ..render import light
|
|
|
|
from .dataset import Dataset
|
|
|
|
import kaolin
|
|
|
|
###############################################################################
# Reference dataset using mesh & rendering
###############################################################################
|
|
|
|
class DatasetMesh(Dataset):
    """Dataset that renders a reference mesh on the fly from sampled cameras.

    Every item is produced by rasterising ``ref_mesh`` with
    ``render.render_mesh`` (under ``torch.no_grad()``).

    - Training mode (``validate=False``): each item uses a fresh random
      rotation/translation camera at distance ``cam_radius``.
    - Validation mode (``validate=True``): the camera orbits the object with a
      fixed tilt (``rotate_x(-0.4)``), so consecutive indices form a turntable
      sequence of ``VALIDATION_FRAMES`` frames.

    Args:
        ref_mesh: Reference mesh. The code reads ``t_pos_idx``, ``v_pos`` and a
            ``material`` mapping with ``'kd'`` and ``'ks'`` entries (and
            optionally ``'normal'``), each providing ``getRes()``.
        glctx: Rasterisation context forwarded to ``render.render_mesh``.
        cam_radius: Distance of the camera from the origin along -z.
        FLAGS: Configuration namespace. Read attributes include
            ``train_res``, ``display_res``, ``texture_res``, ``cam_near_far``,
            ``spp``, ``layers``, ``envmap``, ``env_scale``, ``random_lgt``,
            ``dataset_flat_shading``, ``local_rank``, ``iter`` and ``batch``.
        validate: Select the deterministic validation orbit instead of random
            training views.
    """

    # Number of frames in the validation orbit. The camera completes one full
    # revolution over this many indices, and it is also the dataset length in
    # validation mode.
    VALIDATION_FRAMES = 50
    # Number of surface points sampled from the reference mesh for every item.
    NUM_SURFACE_SAMPLES = 50000

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
        self.glctx = glctx
        self.cam_radius = cam_radius
        self.FLAGS = FLAGS
        self.validate = validate
        self.fovy = np.deg2rad(45)
        self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]
        self.random_lgt = FLAGS.random_lgt
        # Never enabled here; the matching branch in __getitem__ only runs if a
        # caller sets this to True after construction.
        self.camera_lgt = False
        self.flat_shading = FLAGS.dataset_flat_shading

        # Only the local rank-0 process emits the informational messages below.
        is_main_rank = self.FLAGS.local_rank == 0

        if is_main_rank:
            print(f"use flat shading {FLAGS.dataset_flat_shading}")
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution against the reference textures.
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())

        # Evaluate the resolution test on its own. Previously `rank and a or b`
        # parsed as `(rank and a) or b`, which let every rank print the warning.
        texture_too_small = (FLAGS.texture_res[0] < ref_texture_res[0]
                             or FLAGS.texture_res[1] < ref_texture_res[1])
        if is_main_rank and texture_too_small:
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        # Load environment map texture
        print("Loading env map", flush=True)
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

        print("Computing tangents", flush=True)
        try:
            self.ref_mesh = mesh.compute_tangents(ref_mesh)
        except Exception as e:
            # Best effort: if tangent computation fails for any reason, report
            # it and keep using the mesh without tangents.
            print(e)
            print("Continue without tangents...")
            self.ref_mesh = ref_mesh

    @staticmethod
    def _batched_camera(proj_mtx, mv):
        """Return ``(mv, mvp, campos)``, each with a leading batch axis, on CUDA.

        Args:
            proj_mtx: Projection matrix from ``util.perspective``.
            mv: Model-view matrix.
        """
        mvp = proj_mtx @ mv
        # The translation part of the inverted model-view matrix is the camera
        # position in world space.
        campos = torch.linalg.inv(mv)[:3, 3]
        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda()

    def _rotate_scene(self, itr):
        """Camera for validation frame ``itr`` of the turntable orbit.

        Returns:
            Tuple ``(mv, mvp, campos, resolution, spp)``. The matrices have a
            leading batch axis and live on CUDA. The resolution is
            ``FLAGS.display_res`` and spp is ``FLAGS.spp``.
        """
        display_res = self.FLAGS.display_res
        near, far = self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        proj_mtx = util.perspective(self.fovy, display_res[1] / display_res[0], near, far)

        # Smooth rotation for display: one full turn every VALIDATION_FRAMES frames.
        ang = (itr / self.VALIDATION_FRAMES) * np.pi * 2
        mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))

        mv, mvp, campos = self._batched_camera(proj_mtx, mv)
        return mv, mvp, campos, display_res, self.FLAGS.spp

    def _random_scene(self):
        """Random training camera.

        Returns:
            Tuple ``(mv, mvp, campos, resolution, spp)``. The matrices have a
            leading batch axis and live on CUDA. The resolution is
            ``FLAGS.train_res`` and spp is ``FLAGS.spp``.
        """
        # Setup projection matrix
        iter_res = self.FLAGS.train_res
        near, far = self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], near, far)

        # Random rotation/translation matrix for optimization. The 0.2 argument
        # is presumably the translation magnitude (see util.random_rotation_translation).
        mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.2)

        mv, mvp, campos = self._batched_camera(proj_mtx, mv)
        return mv, mvp, campos, iter_res, self.FLAGS.spp

    def __len__(self):
        """VALIDATION_FRAMES in validation mode, else ``(FLAGS.iter + 1) * FLAGS.batch``."""
        if self.validate:
            return self.VALIDATION_FRAMES
        return (self.FLAGS.iter + 1) * self.FLAGS.batch

    def __getitem__(self, itr):
        """Render one view of the reference mesh.

        Args:
            itr: Item index. In validation mode it selects the orbit angle. In
                training mode it is unused, because a new random camera is
                drawn on every call.

        Returns:
            Dict with camera data (``mv``, ``mvp``, ``campos``, ``resolution``,
            ``spp``), render outputs (``img``, ``img_second``, ``depth``,
            ``normal``, ``geo_normal``, ``geo_viewdir``, ``pos``, ``mask``,
            ``mask_cont``, ``rast_triangle_id``), mesh data (``spts`` sampled
            surface points, ``vpts`` vertices, ``faces``) and the light
            transform ``envlight_transform`` (may be None). The keys
            ``depth_second`` and ``normal_second`` are included only when the
            renderer produced them.
        """
        # Randomize scene parameters
        if self.validate:
            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
            camera_mv = None
        else:
            mv, mvp, campos, iter_res, iter_spp = self._random_scene()
            if self.random_lgt:
                camera_mv = util.random_rotation().unsqueeze(0).clone()
            elif self.camera_lgt:
                camera_mv = mv.clone()
            else:
                camera_mv = None

        with torch.no_grad():
            render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
                                            num_layers=self.FLAGS.layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=self.flat_shading)

        # sample_points returns (points, face_choices). Take the points of the
        # single batch entry. detach().clone() is the documented equivalent of
        # torch.tensor(<tensor>) and avoids its copy-construct UserWarning.
        sampled = kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, self.NUM_SURFACE_SAMPLES)
        sample_points = sampled[0][0].detach().clone()

        return_dict = {
            'mv': mv,
            'mvp': mvp,
            'campos': campos,
            'resolution': iter_res,
            'spp': iter_spp,
            'img': render_out['shaded'],
            'img_second': render_out['shaded_second'],
            'spts': sample_points,
            'vpts': self.ref_mesh.v_pos,
            'faces': self.ref_mesh.t_pos_idx,
            'depth': render_out['depth'],
            'normal': render_out['normal'],
            'geo_normal': render_out['geo_normal'],
            'geo_viewdir': render_out['geo_viewdir'],
            'pos': render_out['pos'],
            'envlight_transform': camera_mv,
            'mask': render_out['mask'],
            'mask_cont': render_out['mask_cont'],
            'rast_triangle_id': render_out['rast_triangle_id'],
        }

        # These outputs may be absent from render_out. Copy them only when
        # present instead of swallowing every exception with a bare except.
        for key in ('depth_second', 'normal_second'):
            if key in render_out:
                return_dict[key] = render_out[key]

        return return_dict
|