
# 164 lines
# 6.9 KiB

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
import sys
from ..render import util
from ..render import mesh
from ..render import render
from ..render import light
from .dataset import Dataset
import kaolin
# Reference dataset using mesh & rendering
class DatasetMesh(Dataset):
    """Dataset whose items are renderings of a fixed reference mesh.

    In validation mode the camera orbits the mesh on a fixed turntable; in
    training mode every item uses a freshly randomized camera pose (and,
    optionally, a randomized environment-light rotation).
    """

    # Number of frames in one validation turntable; also the validation length.
    VALIDATION_FRAMES = 50
    # Number of surface points sampled from the reference mesh for every item.
    NUM_SURFACE_SAMPLES = 50000

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
        """Store the configuration, load the environment light and prepare the mesh.

        Args:
            ref_mesh: Reference mesh; its ``material`` must provide ``'kd'`` and
                ``'ks'`` textures (``'normal'`` is optional).
            glctx: Rasterization context forwarded to ``render.render_mesh``.
            cam_radius: Distance of the camera from the origin along -Z.
            FLAGS: Configuration object; the attributes read here are
                ``train_res``, ``random_lgt``, ``dataset_flat_shading``,
                ``local_rank``, ``texture_res``, ``envmap`` and ``env_scale``.
            validate: If true, items follow the turntable instead of random poses.
        """
        # Init
        self.glctx = glctx
        self.cam_radius = cam_radius
        # All other methods read their settings through self.FLAGS.
        self.FLAGS = FLAGS
        self.validate = validate
        self.fovy = np.deg2rad(45)
        self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]
        self.random_lgt = FLAGS.random_lgt
        self.camera_lgt = False
        self.flat_shading = FLAGS.dataset_flat_shading

        if self.FLAGS.local_rank == 0:
            print(f"use flag shading {FLAGS.dataset_flat_shading}")
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())
        # The rank check must gate BOTH dimension comparisons; without explicit
        # grouping `and` binds tighter than `or` and other ranks would warn too.
        if self.FLAGS.local_rank == 0 and self._texture_res_too_low(FLAGS.texture_res, ref_texture_res):
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        print("Loading env map")
        # Load environment map texture
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

        print("Computing tangents")
        try:
            self.ref_mesh = mesh.compute_tangents(ref_mesh)
        except Exception:
            # Deliberate best effort: fall back to the mesh without tangents
            # rather than aborting dataset construction.
            print("Continue without tangents...")
            self.ref_mesh = ref_mesh

    @staticmethod
    def _texture_res_too_low(requested_res, reference_res):
        """Return True if either requested dimension is below the reference one."""
        return bool(requested_res[0] < reference_res[0] or requested_res[1] < reference_res[1])

    def _rotate_scene(self, itr):
        """Build the turntable camera for validation frame ``itr``.

        Returns:
            ``(mv, mvp, campos, resolution, spp)``; the three tensors get a
            leading batch dimension and are moved to CUDA.
        """
        display_res = self.FLAGS.display_res
        near_far = self.FLAGS.cam_near_far
        proj_mtx = util.perspective(self.fovy, display_res[1] / display_res[0], near_far[0], near_far[1])

        # Smooth rotation for display: one full turn every VALIDATION_FRAMES items.
        ang = (itr / self.VALIDATION_FRAMES) * np.pi * 2
        mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        # Camera position is the translation column of the inverse model-view.
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), display_res, self.FLAGS.spp

    def _random_scene(self):
        """Build a randomly posed camera for a training item.

        Returns:
            ``(mv, mvp, campos, resolution, spp)``; the three tensors get a
            leading batch dimension and are moved to CUDA.
        """
        # ==============================================================================================
        #  Setup projection matrix
        # ==============================================================================================
        iter_res = self.FLAGS.train_res
        near_far = self.FLAGS.cam_near_far
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], near_far[0], near_far[1])

        # ==============================================================================================
        #  Random camera & light position
        # ==============================================================================================
        # Random rotation/translation matrix for optimization.
        mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.2)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        # Add batch dimension
        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp

    def __len__(self):
        """Return the turntable length when validating, else ``(iter + 1) * batch``."""
        if self.validate:
            return self.VALIDATION_FRAMES
        return (self.FLAGS.iter + 1) * self.FLAGS.batch

    def __getitem__(self, itr):
        """Render the reference mesh for item ``itr`` and return the targets.

        Returns:
            A dict with the camera matrices (``'mv'``, ``'mvp'``, ``'campos'``),
            render settings (``'resolution'``, ``'spp'``), the render buffers
            taken from ``render.render_mesh`` and the mesh geometry
            (``'spts'``, ``'vpts'``, ``'faces'``).
        """
        # ==============================================================================================
        #  Randomize scene parameters
        # ==============================================================================================
        if self.validate:
            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
            camera_mv = None
        else:
            mv, mvp, campos, iter_res, iter_spp = self._random_scene()
            if self.random_lgt:
                rnd_rot = util.random_rotation()
                camera_mv = rnd_rot.unsqueeze(0).clone()
            elif self.camera_lgt:
                camera_mv = mv.clone()
            else:
                camera_mv = None

        # Reference data only: no gradients are needed for any of it.
        with torch.no_grad():
            render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
                                            num_layers=self.FLAGS.layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=self.flat_shading)

            # Take the first element of the sampler's result for the single
            # mesh in the batch. clone().detach() is the documented
            # replacement for copy-constructing via torch.tensor(tensor).
            sampled = kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, self.NUM_SURFACE_SAMPLES)
            sample_points = sampled[0][0].clone().detach()
            vertex_points = self.ref_mesh.v_pos

        return {
            'mv': mv,
            'mvp': mvp,
            'campos': campos,
            'resolution': iter_res,
            'spp': iter_spp,
            'img': render_out['shaded'],
            'img_second': render_out['shaded_second'],
            'spts': sample_points,
            'vpts': vertex_points,
            'faces': self.ref_mesh.t_pos_idx,
            'depth': render_out['depth'],
            'normal': render_out['normal'],
            'geo_normal': render_out['geo_normal'],
            'geo_viewdir': render_out['geo_viewdir'],
            'pos': render_out['pos'],
            'envlight_transform': camera_mv,
            'mask': render_out['mask'],
            'mask_cont': render_out['mask_cont'],
            'rast_triangle_id': render_out['rast_triangle_id'],
            'depth_second': render_out['depth_second'],
            'normal_second': render_out['normal_second'],
        }