pull/1/head
Zhen Liu 2023-03-11 18:17:59 +08:00
commit 64edf07e23
29 zmienionych plików z 9475 dodań i 0 usunięć

133
.gitignore vendored 100644
Wyświetl plik

@ -0,0 +1,133 @@
__pycache__
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

21
LICENSE 100644
Wyświetl plik

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Zliu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

106
README.md 100644
Wyświetl plik

@ -0,0 +1,106 @@
# MeshDiffusion: Score-based Generative 3D Mesh Modeling (ICLR 2023 Spotlight)
This is the official implementation of MeshDiffusion.
MeshDiffusion is a diffusion model for generating 3D meshes with a direct parametrization of deep marching tetrahedra (DMTet). Please refer to https://meshdiffusion.github.io for more details.
## Getting Started
### Requirements
- Python >= 3.8
- CUDA 11.6
- Pytorch >= 1.6
- Pytorch3D
Install https://github.com/NVlabs/nvdiffrec
### Pretrained Models
Download the files from
## Inference
### Unconditional Generation
Run the following
```
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=uncond_gen \
--config.eval.eval_dir=$OUTPUT_PATH \
--config.eval.ckpt_path=$CKPT_PATH
```
Later run
```
cd nvdiffrec
python eval.py --config $DMTET_CONFIG --sample-path $SAMPLE_PATH
```
where `$SAMPLE_PATH` is the generated sample `.npy` file in `$OUTPUT_PATH`
### Single-view Conditional Generation
First fit a DMTet from a single view of a mesh
```
cd nvdiffrec
python fit_singleview.py --mesh-path $MESH_PATH --angle-ind $ANGLE_IND --out-dir $OUT_DIR --validate $VALIDATE
```
Then use the trained diffusion model to complete the occluded regions
```
cd ..
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=cond_gen \
--config.eval.eval_dir=$EVAL_DIR \
--config.eval.ckpt_path=$CKPT_PATH \
--config.eval.partial_dmtet_path=$OUT_DIR/tets/dmtet.pt \
--config.eval.tet_path=$TET_PATH \
--config.eval.batch_size=$EVAL_BATCH_SIZE
```
Now visualize the completed meshes
```
cd nvdiffrec
python eval.py --config $DMTET_CONFIG --sample-path $SAMPLE_PATH
```
## Training
For ShapeNet, first create a list of paths of all ground-truth meshes and store them as a json file under `./nvdiffrec/data/shapenet_json`.
Then run the following
```
cd nvdiffrec
python fit_dmtets.py
```
Create a meta file for diffusion model training:
```
cd ../metadata/
python save_meta.py
```
Train a diffusion model
```
cd ..
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=train
```
## Texture Completion
Follow the instructions in https://github.com/TEXTurePaper/TEXTurePaper and create text-conditioned textures for the generated meshes.
## Acknowledgement
This repo is adapted from https://github.com/NVlabs/nvdiffrec and https://github.com/yang-song/score_sde_pytorch.

Wyświetl plik

@ -0,0 +1,88 @@
import ml_collections
import torch
def get_default_configs():
    """Build the baseline configuration that the experiment configs start from.

    Returns:
        An ``ml_collections.ConfigDict`` holding the ``training``,
        ``sampling``, ``eval``, ``data``, ``model``, ``optim`` and ``render``
        sub-configs plus the top-level ``seed`` and ``device`` entries.
        String entries set to "PLACEHOLDER" are meant to be overridden.
    """
    cfg = ml_collections.ConfigDict()

    # training
    cfg.training = ml_collections.ConfigDict({
        'batch_size': 64,
        'n_iters': 2400001,
        'snapshot_freq': 50000,
        'log_freq': 50,
        'eval_freq': 100,
        # store additional checkpoints for preemption in cloud computing environments
        'snapshot_freq_for_preemption': 5000,
        # produce samples at each snapshot.
        'snapshot_sampling': True,
        'likelihood_weighting': False,
        'continuous': True,
        'reduce_mean': False,
        'iter_size': 1,
        'loss_type': 'l2',
        'train_dir': "PLACEHOLDER",
    })

    # sampling
    cfg.sampling = ml_collections.ConfigDict({
        'n_steps_each': 1,
        'noise_removal': True,
        'probability_flow': False,
        'snr': 0.075,
    })

    # evaluation
    cfg.eval = ml_collections.ConfigDict({
        'begin_ckpt': 50,
        'end_ckpt': 96,
        'batch_size': 512,
        'enable_sampling': True,
        'num_samples': 50000,
        'enable_loss': True,
        'enable_bpd': False,
        'bpd_dataset': 'test',
        'ckpt_path': "PLACEHOLDER",
        'partial_dmtet_path': "PLACEHOLDER",
        'tet_path': "PLACEHOLDER",
        'freeze_iters': 950,
    })

    # data
    cfg.data = ml_collections.ConfigDict({
        'dataset': 'LSUN',
        'image_size': 256,
        'random_flip': True,
        'uniform_dequantization': False,
        'centered': False,
        'num_channels': 3,
        'num_workers': 4,
        'normalize_sdf': True,
        # metadata for all dataset files
        'meta_path': "PLACEHOLDER",
        # metadata for the list of training samples
        'filter_meta_path': "PLACEHOLDER",
    })

    # model
    cfg.model = ml_collections.ConfigDict({
        'sigma_max': 378,
        'sigma_min': 0.01,
        'num_scales': 2000,
        'beta_min': 0.1,
        'beta_max': 20.,
        'dropout': 0.,
        'embedding_type': 'fourier',
        'deform_scale': 1.0,
    })

    # optimization
    cfg.optim = ml_collections.ConfigDict({
        'weight_decay': 0,
        'optimizer': 'Adam',
        'lr': 2e-4,
        'beta1': 0.9,
        'eps': 1e-8,
        'warmup': 5000,
        'grad_clip': 1.,
    })

    cfg.seed = 42
    # Use the first CUDA device when one is available, otherwise the CPU.
    cfg.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # rendering (left empty here)
    cfg.render = ml_collections.ConfigDict()

    return cfg

78
configs/res128.py 100644
Wyświetl plik

@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Config file for the DDPM-style MeshDiffusion model on ShapeNet at resolution 128."""
from configs.default_configs import get_default_configs
def get_config():
    """Return the resolution-128 ShapeNet configuration.

    Starts from ``get_default_configs()`` and overrides the training,
    sampling, data, model and optimizer entries listed below.
    """
    config = get_default_configs()

    # training
    config.training.update(
        sde='vpsde',
        continuous=False,
        reduce_mean=True,
        batch_size=8,
        lip_scale=None,
        iter_size=4,
        snapshot_freq_for_preemption=1000,
    )

    # sampling
    config.sampling.update(
        method='pc',
        predictor='ancestral_sampling',
        corrector='none',
    )

    # data
    config.data.update(
        dataset='ShapeNet',
        centered=True,
        image_size=128,
        num_channels=4,
        meta_path="PLACEHOLDER",  # metadata for all dataset files
        filter_meta_path="PLACEHOLDER",  # metadata for the list of training samples
        num_workers=8,
        aug=True,
    )

    # model
    config.model.update(
        name='ddpm_res128_v2',
        scale_by_sigma=False,
        num_scales=1000,
        ema_rate=0.9999,
        normalization='GroupNorm',
        nonlinearity='swish',
        nf=128,
        ch_mult=(1, 1, 2, 4, 4, 4),
        num_res_blocks_first=2,
        num_res_blocks=2,
        attn_resolutions=(16,),
        resamp_with_conv=True,
        conditional=True,
        dropout=0.1,
    )

    # optim: the learning rate is derived from training.iter_size set above
    config.optim.lr = 7e-5 / config.training.iter_size * 2.0

    config.eval.batch_size = 7
    config.seed = 42

    return config

79
configs/res64.py 100644
Wyświetl plik

@ -0,0 +1,79 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Config file for the DDPM-style MeshDiffusion model on ShapeNet at resolution 64."""
from configs.default_configs import get_default_configs
def get_config():
    """Return the resolution-64 ShapeNet configuration.

    Starts from ``get_default_configs()`` and overrides the training,
    sampling, data, model, optimizer and evaluation entries listed below.
    """
    config = get_default_configs()

    # training
    config.training.update(
        sde='vpsde',
        continuous=False,
        reduce_mean=True,
        batch_size=48,
        lip_scale=None,
        snapshot_freq_for_preemption=1000,
    )

    # sampling
    config.sampling.update(
        method='pc',
        predictor='ancestral_sampling',
        corrector='none',
    )

    # data
    config.data.update(
        dataset='ShapeNet',
        centered=True,
        image_size=64,
        num_channels=4,
        meta_path="PLACEHOLDER",  # metadata for all dataset files
        filter_meta_path="PLACEHOLDER",  # metadata for the list of training samples
        num_workers=4,
        aug=True,
    )

    # model
    config.model.update(
        name='ddpm_res64',
        scale_by_sigma=False,
        num_scales=1000,
        ema_rate=0.9999,
        normalization='GroupNorm',
        nonlinearity='swish',
        nf=128,
        ch_mult=(1, 1, 2, 4, 4),
        num_res_blocks_first=2,
        num_res_blocks=3,
        attn_resolutions=(16,),
        resamp_with_conv=True,
        conditional=True,
        dropout=0.1,
    )

    # optim
    config.optim.lr = 2e-5

    config.eval.batch_size = 4
    config.eval.eval_dir = "PLACEHOLDER"
    config.seed = 42

    return config

Wyświetl plik

@ -0,0 +1,37 @@
import numpy as np
import torch
import os
import sys
import json
import tqdm
import argparse
def tet_to_grids(vertices, grid_size):
    """Rasterize integer vertex coordinates into a cubic occupancy mask.

    Args:
        vertices: integer tensor whose rows are (x, y, z) grid indices; only
            the first three columns are read.
        grid_size: edge length of the cubic output grid.

    Returns:
        A float tensor of shape (grid_size, grid_size, grid_size) on the same
        device as ``vertices``, equal to 1.0 at every listed coordinate and
        0.0 elsewhere. Duplicate coordinates simply write 1.0 again.
    """
    grid = torch.zeros(grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        # One indexed assignment replaces the per-vertex Python loop (and its
        # three .item() round-trips per vertex); the resulting mask is the same.
        grid[vertices[:, 0], vertices[:, 1], vertices[:, 2]] = 1.0
    return grid
if __name__ == "__main__":
    # Build an occupancy mask for the cropped tet grid of the given resolution
    # and save it as ``grid_mask_<resolution>.pt`` in the working directory.
    parser = argparse.ArgumentParser()
    # Required: without it the script would go looking for "None_tets_cropped.npz".
    parser.add_argument('--resolution', type=int, required=True)
    parser.add_argument('--tet_folder', type=str, default='../nvdiffrec/data/tets/')
    args = parser.parse_args()

    tet_path = os.path.join(args.tet_folder, f'{args.resolution}_tets_cropped.npz')
    # Close the npz archive as soon as the vertex array has been copied out.
    with np.load(tet_path) as tet:
        vertices = torch.tensor(tet['vertices'])

    # Grid spacing is taken as the gap between the two smallest distinct
    # coordinate values (assumes a uniformly spaced grid -- TODO confirm).
    vertices_unique = vertices.unique()
    dx = vertices_unique[1] - vertices_unique[0]

    # Shift so the smallest coordinate maps to index 0, then snap to integers.
    vertices_discretized = torch.round((vertices - vertices.min()) / dx).long()

    grid = tet_to_grids(vertices_discretized, args.resolution)
    torch.save(grid, f'grid_mask_{args.resolution}.pt')

Plik binarny nie jest wyświetlany.

Plik binarny nie jest wyświetlany.

Wyświetl plik

@ -0,0 +1,49 @@
import numpy as np
import torch
import os
import tqdm
import argparse
def tet_to_grids(vertices, values_list, grid_size):
    """Scatter per-vertex values into a 4-channel cubic grid.

    Args:
        vertices: integer tensor whose rows are (x, y, z) grid indices.
        values_list: iterable of per-vertex value tensors. The first entry is
            squeezed and written to channel 0; every later entry is
            transposed and written to channels 1-3.
        grid_size: edge length of the cubic grid.

    Returns:
        A tensor of shape (4, grid_size, grid_size, grid_size) on the device
        of ``vertices``; cells that no vertex maps to stay zero.
    """
    xs, ys, zs = vertices[:, 0], vertices[:, 1], vertices[:, 2]
    out = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for channel, per_vertex in enumerate(values_list):
            if channel:
                out[1:, xs, ys, zs] = per_vertex.transpose(0, 1)
            else:
                out[0, xs, ys, zs] = per_vertex.squeeze()
    return out
if __name__ == "__main__":
    # Convert saved per-shape DMTet dictionaries (``dmt_dict_XXXXX.pt``) under
    # <root>/<source> into dense grids (``grid_XXXXX.pt``) under <root>/<target>.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    for short_flag, long_flag, extra in (
        ('-res', '--resolution', dict(type=int)),
        ('-ss', '--split-size', dict(type=int, default=int(1e8))),
        ('-ind', '--index', dict(type=int)),
        ('-r', '--root', dict(type=str)),
        ('-s', '--source', dict(type=str)),
        ('-t', '--target', dict(type=str)),
    ):
        parser.add_argument(short_flag, long_flag, **extra)
    FLAGS = parser.parse_args()

    # Load the tet grid vertices and snap them to integer grid indices.
    tet_path = f'../nvdiffrec/data/tets/{FLAGS.resolution}_tets_cropped.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    dx = vertices_unique[1] - vertices_unique[0]
    shifted = vertices - vertices.min()
    vertices_discretized = torch.round(shifted / dx).long()

    save_folder = FLAGS.root
    grid_folder = os.path.join(save_folder, FLAGS.target)
    os.makedirs(grid_folder, exist_ok=True)
    tets_folder = os.path.join(save_folder, FLAGS.source)

    # This worker handles the ``split_size`` consecutive indices starting at
    # index * split_size; indices without a source file are skipped.
    first_index = FLAGS.index * FLAGS.split_size
    for k in tqdm.trange(FLAGS.split_size):
        global_index = k + first_index
        tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index))
        if not os.path.exists(tet_path):
            continue
        tet = torch.load(tet_path, map_location="cpu")
        per_vertex_values = (tet['sdf'].unsqueeze(-1), tet['deform'])
        grid = tet_to_grids(vertices_discretized, per_vertex_values, FLAGS.resolution)
        torch.save(grid, os.path.join(grid_folder, 'grid_{:05d}.pt'.format(global_index)))

43
main_diffusion.py 100644
Wyświetl plik

@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation"""
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import lib.diffusion.trainer as trainer
import lib.diffusion.evaler as evaler
# Shared absl flag container; main() reads FLAGS.config and FLAGS.mode from it.
FLAGS = flags.FLAGS
# --config: config file for the diffusion model, no default.
# NOTE(review): lock_config=False presumably keeps the loaded config unlocked
# so command-line overrides can add keys -- confirm against ml_collections docs.
config_flags.DEFINE_config_file(
"config", None, "diffusion configs", lock_config=False)
# --mode: which routine main() dispatches to; restricted to the three values below.
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen"], "Running mode")
# Both flags have a None default, so they are marked as required.
flags.mark_flags_as_required(["config", "mode"])
def main(argv):
    """Run the routine selected by ``--mode`` with the loaded config.

    Args:
        argv: positional command-line arguments forwarded by ``app.run``;
            not used here.
    """
    # mode -> (module, function name); the attribute is looked up only for
    # the selected mode.
    routes = {
        'train': (trainer, 'train'),
        'uncond_gen': (evaler, 'uncond_gen'),
        'cond_gen': (evaler, 'cond_gen'),
    }
    route = routes.get(FLAGS.mode)
    if route is not None:
        module, entry_point = route
        getattr(module, entry_point)(FLAGS.config)


if __name__ == "__main__":
    app.run(main)

Wyświetl plik

@ -0,0 +1,14 @@
import os
import json
import argparse
if __name__ == "__main__":
    # Collect every ``.pt`` file directly inside --data_path and write the
    # sorted list of their paths to --json_path as a JSON array.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--json_path', type=str)
    args = parser.parse_args()

    # The original listed an undefined name ``root``; the files live in --data_path.
    fpath_list = sorted(
        os.path.join(args.data_path, fname)
        for fname in os.listdir(args.data_path)
        if fname.endswith('.pt')
    )

    # --json_path is a file: create its parent directory, not a directory at
    # the file path itself (which would make the open() below fail).
    json_dir = os.path.dirname(args.json_path)
    if json_dir:
        os.makedirs(json_dir, exist_ok=True)
    with open(args.json_path, 'w') as fout:
        json.dump(fpath_list, fout)

File diff suppressed because one or more lines are too long

Wyświetl plik

@ -0,0 +1,18 @@
{
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 2048, 2048 ],
"train_res": [1000, 1000],
"batch": 4,
"learning_rate": [0.01, 0.003],
"ks_min" : [0, 0.08, 0.0],
"dmtet_grid" : 128,
"mesh_scale" : 1.1,
"laplace_scale" : 10000,
"display": [{"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}, {"depth": true}],
"background" : "white",
"envmap": "./data/irrmaps/aerodynamics_workshop_2k.hdr",
"tet_path": "./data/tets/128_tets_cropped.npz",
"cropped": true
}

Wyświetl plik

@ -0,0 +1,18 @@
{
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 2048, 2048 ],
"train_res": [1000, 1000],
"batch": 4,
"learning_rate": [0.01, 0.003],
"ks_min" : [0, 0.08, 0.0],
"dmtet_grid" : 64,
"mesh_scale" : 1.1,
"laplace_scale" : 10000,
"display": [{"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}, {"depth": true}],
"background" : "white",
"envmap": "./data/irrmaps/aerodynamics_workshop_2k.hdr",
"tet_path": "./data/tets/64_tets_cropped.npz",
"cropped": true
}

Wyświetl plik

@ -0,0 +1,3 @@
The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
CC0 License.

Plik binarny nie jest wyświetlany.

Plik binarny nie jest wyświetlany.

Plik binarny nie jest wyświetlany.

Plik binarny nie jest wyświetlany.

Wyświetl plik

@ -0,0 +1,6 @@
Place the tet grid files in this folder.
We provide a few example grids. See the main README.md for a download link.
You can also generate your own grids using https://github.com/crawforddoran/quartet
Please see the `generate_tets.py` script for an example.

Wyświetl plik

@ -0,0 +1,75 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import argparse
from collections import defaultdict
def crop_tets(vertices, indices):
    """Drop the outermost layer of a tetrahedral grid.

    A vertex is kept only if, on each of the three axes, its coordinate is
    neither the minimum nor the maximum over all vertices. A tetrahedron is
    kept only if all four of its vertices are kept; its indices are rewritten
    to refer to the compacted vertex array.

    Args:
        vertices: array of shape (N, 3) holding vertex coordinates.
        indices: integer array of shape (M, 4) holding vertex indices per tet.

    Returns:
        ``(vertices_cropped, indices_cropped)``: the surviving vertices in
        their original order and the surviving tets as an ``np.int32`` array
        of shape (K, 4).

    Raises:
        AssertionError: if ``indices`` does not have four columns.
    """
    assert indices.shape[1] == 4

    num_vertices = vertices.shape[0]
    mask = np.ones(num_vertices, dtype=bool)
    for k in range(3):
        coords = vertices[:, k]
        mask &= (coords != np.min(coords)) & (coords != np.max(coords))

    print(f"remaining: {mask.sum()} out of {num_vertices}")
    vertices_cropped = vertices[mask]
    # ``np.bool`` no longer exists in current NumPy; counting the dropped
    # vertices needs no cast at all.
    print(f"{num_vertices - mask.sum()} out of {num_vertices}")

    # Old vertex index -> new index in the compacted array, -1 when dropped.
    remapping = np.full(num_vertices, -1, dtype=np.int64)
    remapping[mask] = np.arange(vertices_cropped.shape[0])

    # Indices outside [0, N) are treated as dropped, as the original
    # defaultdict lookup did, instead of wrapping around or raising.
    in_range = (indices >= 0) & (indices < num_vertices)
    remapped = np.where(in_range, remapping[np.where(in_range, indices, 0)], -1)
    keep = np.all(remapped != -1, axis=1)
    indices_cropped = remapped[keep].astype(np.int32)

    if indices_cropped.size:
        print(vertices_cropped.shape[0], np.min(indices_cropped), np.max(indices_cropped))

    return vertices_cropped, indices_cropped
if __name__ == "__main__":
    # Read ``<resolution>_tets.npz`` from the working directory, crop it and
    # write ``<resolution>_tets_cropped.npz`` next to it.
    cli = argparse.ArgumentParser()
    cli.add_argument('--resolution', type=int)
    resolution = cli.parse_args().resolution

    data = np.load(f'{resolution}_tets.npz')
    new_verts, new_inds = crop_tets(data['vertices'], data['indices'])
    np.savez_compressed(
        f'{resolution}_tets_cropped.npz',
        vertices=new_verts,
        indices=new_inds,
    )

Wyświetl plik

@ -0,0 +1,47 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
'''
This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
to generate a tet grid
1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
2) Run the function below to generate a file `cube_32_tet.tet`
'''
def generate_tetrahedron_grid_file(res=32, root='..'):
    """Run Quartet on its bundled cube mesh to produce a tet grid file.

    Args:
        res: grid resolution; the cell size handed to Quartet is ``1.0 / res``.
        root: directory that contains the ``quartet`` checkout.

    The command is executed inside ``<root>/quartet`` and writes
    ``meshes/cube_<res>_tet.tet`` and ``meshes/cube_boundary_<res>.obj``.
    The command's exit status is not checked.
    """
    frac = 1.0 / res
    # ``res`` is formatted with %d so the default produces ``cube_32_tet.tet``,
    # the name the conversion helper expects; %f gave ``cube_32.000000_tet.tet``.
    command = 'cd %s/quartet; ' % (root) + \
        './quartet meshes/cube.obj %f meshes/cube_%d_tet.tet -s meshes/cube_boundary_%d.obj' % (frac, res, res)
    os.system(command)
'''
This code segment shows how to convert from a quartet .tet file to compressed npz file
'''
def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets'):
    """Convert a Quartet ``.tet`` text file into a compressed ``.npz`` archive.

    The first line of the file is a header whose second and third
    whitespace-separated fields are the vertex count and the tet count. The
    next ``numvertices`` rows are vertex coordinates, followed by ``numtets``
    rows of vertex indices.

    Args:
        quartetfile: path of the ``.tet`` file to read.
        npzfile: output path handed to ``np.savez_compressed``; the arrays
            are stored under the keys ``vertices`` and ``indices``.
    """
    # Only the header is read through this handle; close it right away
    # instead of leaking it for the rest of the process.
    with open(quartetfile, 'r') as file1:
        header = file1.readline()
    fields = header.split()
    numvertices = int(fields[1])
    numtets = int(fields[2])
    print(numvertices, numtets)

    # ndmin=2 keeps a one-row section 2-D instead of collapsing it to 1-D.
    # load vertices
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices, ndmin=2)
    print(vertices.shape)

    # load indices
    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets, ndmin=2)
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)

452
nvdiffrec/eval.py 100644
Wyświetl plik

@ -0,0 +1,452 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import time
import argparse
import json
import sys
import glob
import tqdm
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas
# Import topology / geometry trainers
from lib.geometry.dmtet import DMTetGeometry
import lib.render.renderutils as ru
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render
from pytorch3d.io import save_obj
import pymeshlab
RADIUS = 3.0
# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)
###############################################################################
# Loss setup
###############################################################################
@torch.no_grad()
def createLoss(FLAGS):
    """Return the image loss callable selected by ``FLAGS.loss``.

    The returned function takes ``(img, ref)`` and forwards them to
    ``ru.image_loss`` with the loss and tonemapper paired with the name.
    An unrecognised name trips the assertion at the end.
    """
    # (FLAGS.loss value, loss argument, tonemapper argument)
    variants = (
        ("smape", 'smape', 'none'),
        ("mse", 'mse', 'none'),
        ("logl1", 'l1', 'log_srgb'),
        ("logl2", 'mse', 'log_srgb'),
        ("relmse", 'relmse', 'none'),
    )
    for name, loss_kind, tonemap in variants:
        if FLAGS.loss == name:
            # Bind the pair as defaults so the lambda keeps this variant.
            return lambda img, ref, _loss=loss_kind, _tm=tonemap: ru.image_loss(
                img, ref, loss=_loss, tonemapper=_tm)
    assert False
###############################################################################
# Mix background into a dataset image
###############################################################################
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Attach a background image to a batch and move camera tensors to CUDA.

    Args:
        target: batch dict. ``'mv'``, ``'mvp'`` and ``'campos'`` are replaced
            by their ``.cuda()`` versions; ``'resolution'`` is read for the
            solid and random backgrounds and ``'img'`` for the checker and
            reference backgrounds.
        bg_type: one of ``'checker'``, ``'black'``, ``'white'``,
            ``'reference'`` or ``'random'``.

    Returns:
        The same dict, with ``'background'`` added.
    """
    def _filled(factory):
        # (1, H, W, 3) float32 CUDA tensor produced by the given torch factory.
        height, width = target['resolution'][0], target['resolution'][1]
        return factory((1, height, width) + (3,), dtype=torch.float32, device='cuda')

    # assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        board = util.checkerboard(target['img'].shape[1:3], 8)
        background = torch.tensor(board, dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = _filled(torch.zeros)
    elif bg_type == 'white':
        background = _filled(torch.ones)
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = _filled(torch.rand)
    else:
        assert False, "Unknown background type %s" % bg_type

    for key in ('mv', 'mvp', 'campos'):
        target[key] = target[key].cuda()
    target['background'] = background

    return target
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current geometry with xatlas and bake its material to textures.

    Args:
        glctx: rendering context passed through to ``render.render_uv``.
        geometry: object providing ``getMesh(mat)``.
        mat: material used to extract the mesh; its ``'bsdf'`` entry is
            copied to the new material.
        FLAGS: supplies ``texture_res``, ``layers`` and the ``kd``/``ks``/
            ``nrm`` min and max bounds.

    Returns:
        A new mesh based on the extracted one, carrying the generated UVs and
        a material with 2D ``kd``, ``ks`` and ``normal`` textures.
    """
    def _cuda_f32(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    positions = eval_mesh.v_pos.detach().cpu().numpy()
    triangles = eval_mesh.t_pos_idx.detach().cpu().numpy()
    _, indices, uvs = xatlas.parametrize(positions, triangles)

    # Convert to tensors; the index array is reinterpreted as int64 on the way.
    faces_np = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uv_tensor = _cuda_f32(uvs)
    face_tensor = torch.tensor(faces_np, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uv_tensor, t_tex_idx=face_tensor, base=eval_mesh)

    mask, kd, ks, normal = render.render_uv(
        glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Append one random channel to kd when several layers are requested.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    kd_bounds = [_cuda_f32(FLAGS.kd_min), _cuda_f32(FLAGS.kd_max)]
    ks_bounds = [_cuda_f32(FLAGS.ks_min), _cuda_f32(FLAGS.ks_max)]
    nrm_bounds = [_cuda_f32(FLAGS.nrm_min), _cuda_f32(FLAGS.nrm_max)]

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=kd_bounds),
        'ks'     : texture.Texture2D(ks, min_max=ks_bounds),
        'normal' : texture.Texture2D(normal, min_max=nrm_bounds)
    })

    return new_mesh
###############################################################################
# Utility functions for material
###############################################################################
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the starting material: one MLP texture or trainable 2D textures.

    Args:
        geometry: provides ``getAABB()`` for the MLP texture.
        mlp: when truthy, build a single 9-channel ``'kd_ks_normal'`` MLP
            texture; otherwise build separate ``kd``, ``ks`` and ``normal``
            trainable textures.
        FLAGS: supplies the ``kd``/``ks``/``nrm`` bounds, ``texture_res``,
            ``custom_mip``, ``random_textures`` and ``layers``.
        init_mat: optional material to initialise the textures and the
            ``'bsdf'`` entry from.

    Returns:
        The new material; ``'bsdf'`` is taken from ``init_mat`` when given
        and is ``'pbr'`` otherwise.
    """
    def _bound(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    kd_min, kd_max = _bound(FLAGS.kd_min), _bound(FLAGS.kd_max)
    ks_min, ks_max = _bound(FLAGS.ks_min), _bound(FLAGS.ks_max)
    nrm_min, nrm_max = _bound(FLAGS.nrm_min), _bound(FLAGS.nrm_max)

    if mlp:
        lower = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        upper = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        volume_texture = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[lower, upper])
        mat = material.Material({'kd_ks_normal' : volume_texture})
    else:
        res = FLAGS.texture_res
        auto_mips = not FLAGS.custom_mip
        randomize = FLAGS.random_textures or init_mat is None

        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if randomize:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_noise = torch.rand(size=res + [num_channels], device='cuda')
            kd_span = (kd_max - kd_min)[None, None, 0:num_channels]
            kd_init = kd_noise * kd_span + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init , res, auto_mips, [kd_min, kd_max])

            ks_channels = (
                np.random.uniform(size=res + [1], low=0.0, high=0.01),
                np.random.uniform(size=res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu()),
                np.random.uniform(size=res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu()),
            )
            ks_map_opt = texture.create_trainable(np.concatenate(ks_channels, axis=2), res, auto_mips, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], res, auto_mips, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], res, auto_mips, [ks_min, ks_max])

        # Setup normal map
        if randomize or 'normal' not in init_mat:
            normal_init = np.array([0, 0, 1])
        else:
            normal_init = init_mat['normal']
        normal_map_opt = texture.create_trainable(normal_init, res, auto_mips, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
###############################################################################
# Validation & testing
###############################################################################
def rotate_scene(FLAGS, itr):
    """Build the camera dict for frame ``itr`` of a turntable animation.

    The camera sits ``RADIUS`` away from the origin, is tilted by a fixed
    -0.4 rad about the x axis and completes one turn about the y axis every
    50 frames.

    Args:
        FLAGS: supplies ``display_res`` and ``cam_near_far``.
        itr: frame index.

    Returns:
        A dict with batched CUDA tensors ``'mv'``, ``'mvp'`` and ``'campos'``
        plus ``'spp'`` (1) and ``'resolution'`` (``FLAGS.display_res``).
    """
    aspect = FLAGS.display_res[1] / FLAGS.display_res[0]
    near, far = FLAGS.cam_near_far[0], FLAGS.cam_near_far[1]
    projection = util.perspective(np.deg2rad(45), aspect, near, far)

    # Smooth rotation for display.
    angle = (itr / 50) * np.pi * 2
    tilt_and_spin = util.rotate_x(-0.4) @ util.rotate_y(angle)
    model_view = util.translate(0, 0, -RADIUS) @ tilt_and_spin

    # The camera position is the translation column of the inverse model-view.
    camera_position = torch.linalg.inv(model_view)[:3, 3]

    return {
        'mv': model_view[None, ...].cuda(),
        'mvp': (projection @ model_view)[None, ...].cuda(),
        'campos': camera_position[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res
    }
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one view without gradients and return it in sRGB.

    Args:
        glctx: rendering context handed to ``geometry.render``.
        target: camera/batch dict; ``target['mv']`` is used when
            ``FLAGS.camera_space_light`` is set.
        geometry: object providing ``render(glctx, target, lgt, material)``.
        opt_material: material to render with.
        lgt: light; its mips are rebuilt before rendering.
        FLAGS: supplies ``camera_space_light``.

    Returns:
        ``(image, result_dict)`` where ``result_dict['opt']`` is the same
        tensor as ``image``: the first batch element of the shaded RGB
        channels converted with ``util.rgb_to_srgb``.
    """
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])

        shaded = geometry.render(glctx, target, lgt, opt_material)['shaded']
        opt_image = util.rgb_to_srgb(shaded[..., 0:3])[0]

    return opt_image, {'opt': opt_image}
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Render every validation sample, save the images and log MSE / PSNR.

    Args:
        glctx: rendering context handed to ``validate_itr``.
        geometry: geometry to render.
        opt_material: material to render with.
        lgt: light used for rendering.
        dataset_validate: dataset exposing a ``collate`` function.
        out_dir: directory for ``metrics.txt`` and the ``val_*.png`` images;
            created if missing.
        FLAGS: supplies ``background`` plus whatever ``validate_itr`` reads.

    Returns:
        The PSNR averaged over all samples.

    Raises:
        KeyError: if a sample provides no reference image, neither as
            ``result_dict['ref']`` nor as ``target['img']``.
    """
    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)

            # validate_itr in this file only fills 'opt', so reading
            # result_dict['ref'] unconditionally always raised KeyError.
            # NOTE(review): the fallback mirrors upstream nvdiffrec, which
            # derives the reference from target['img']; upstream also
            # composites the background into that image first -- confirm this
            # is the intended reference here.
            ref_image = result_dict.get('ref')
            if ref_image is None:
                if 'img' not in target:
                    raise KeyError("validation needs a reference image: neither result_dict['ref'] nor target['img'] is available")
                ref_image = util.rgb_to_srgb(target['img'][..., 0:3])[0].to(opt.device)
            ref = torch.clamp(ref_image, 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE,      PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
###############################################################################
# Main shape fitter function / optimization loop
###############################################################################
class Trainer(torch.nn.Module):
    """Bundle geometry, material and light so one call runs one optimisation tick.

    Attributes set by the constructor:
        params:     material parameters, plus the light parameters when
                    ``optimize_light`` is true.
        geo_params: geometry parameters when ``optimize_geometry`` is true,
                    otherwise an empty list.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super().__init__()
        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A fixed light only needs its mips built once, outside autograd.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        """Run ``geometry.tick`` for ``target`` at iteration ``it`` and return its result."""
        if self.optimize_light:
            # The light changes every step, so its mips are rebuilt per call.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # Use the context handed to the constructor; the original relied on a
        # module-level ``glctx`` that only exists when run as a script.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------
# Script entry point: load generated SDF/deformation grids from --sample-path,
# write each one into a DMTet geometry and save one rendered preview per sample.
if __name__ == "__main__":
    def _str2bool(text):
        """argparse ``type``: turn 'true'/'false'-style text into a real bool.

        ``type=bool`` would treat every non-empty string (including 'False')
        as True.
        """
        lowered = str(text).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (text,))

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    # Two values, like --train-res, which is what display_res falls back to below.
    parser.add_argument('-dr', '--display-res', nargs=2, type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet')
    parser.add_argument('-sp', '--sample-path', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    parser.add_argument('--validate', type=_str2bool, default=True)

    FLAGS = parser.parse_args()

    # Defaults that have no command-line switch; a config file may override them.
    FLAGS.mtl_override = None  # Override material of model
    FLAGS.dmtet_grid = 64  # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 2.1  # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0  # Env map intensity multiplier
    FLAGS.envmap = None  # HDR environment probe
    FLAGS.display = None  # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False  # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False  # Disable light optimization in the second pass
    FLAGS.lock_pos = False  # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2  # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative"  # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0  # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True  # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0]  # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0]  # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]  # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.cropped = False
    FLAGS.random_lgt = False

    # Every key of the JSON config overwrites (or adds) the matching FLAGS attribute.
    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as config_file:
            data = json.load(config_file)
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # tet_path has no command-line switch and no default, so it can only come
    # from the config file; fail with a usage error instead of an AttributeError.
    if getattr(FLAGS, 'tet_path', None) is None:
        parser.error("'tet_path' must be provided through the --config file")
    if FLAGS.sample_path is None:
        parser.error("--sample-path (or 'sample_path' in the config file) is required")

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    glctx = dr.RasterizeGLContext()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================
    # Setup geometry for optimization
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(resolution, FLAGS.mesh_scale, FLAGS)
    geometry.deform_scale = FLAGS.deform_scale

    # NOTE(review): `mask` is loaded but not used anywhere below.
    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")

    ### compute the mapping from tet indices to 3D cubic grid vertex indices
    tet_path = FLAGS.tet_path
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    # Spacing between the two smallest distinct coordinate values.
    dx = vertices_unique[1] - vertices_unique[0]
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()
    # Per-axis integer grid indices of every tet vertex (loop-invariant).
    idx_x = vertices_discretized[:, 0]
    idx_y = vertices_discretized[:, 1]
    idx_z = vertices_discretized[:, 2]

    data_all = np.load(FLAGS.sample_path)
    print('shape of generated data', data_all.shape)

    for no_data in tqdm.trange(data_all.shape[0]):
        grid = torch.tensor(data_all[no_data])

        if FLAGS.unnormalized_sdf:
            # The raw-SDF path was never enabled: the original placed its
            # (unreachable) assignment after this raise.
            raise NotImplementedError

        # Channel 0 holds the SDF (only its sign is kept); the remaining
        # channels hold the per-vertex deformation.
        geometry.sdf.data[:] = torch.sign(grid[0, idx_x, idx_y, idx_z]).cuda()
        geometry.deform.data[:] = grid[1:, idx_x, idx_y, idx_z].cuda().transpose(0, 1)
        geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

        ### mtl for visualization
        opt_material = {
            'name' : '_default_mat',
            # 'bsdf' : 'pbr',
            'bsdf' : 'diffuse',
            'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
            'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
        }

        ### create and optimize mesh
        base_mesh = geometry.getMesh(opt_material)

        ### save image (before post-processing)
        v_pose = rotate_scene(FLAGS, 25)  ## pick a pose (pose # from 0 to 50)
        result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
        result_image = result_image.detach().cpu().numpy()
        util.save_image(os.path.join(viz_path, ('%s_%06d.png' % (FLAGS.viz_name, no_data))), result_image)

        # ### save post-processed mesh
        # mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(no_data))
        # save_obj(
        #     verts=base_mesh.v_pos,
        #     faces=base_mesh.t_pos_idx,
        #     f=mesh_savepath
        # )
        # ms = pymeshlab.MeshSet()
        # ms.load_new_mesh(mesh_savepath)
        # ms.meshing_isotropic_explicit_remeshing()
        # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=False)
        # # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
        # ms.meshing_isotropic_explicit_remeshing()
        # ms.apply_filter_script()
        # ms.save_current_mesh(mesh_savepath)

Wyświetl plik

@ -0,0 +1,451 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import time
import argparse
import json
import sys
import glob
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas
# Import topology / geometry trainers
from lib.geometry.dmtet import DMTetGeometry
import lib.render.renderutils as ru
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render
from pytorch3d.io import save_obj
import pymeshlab
# Camera distance for the preview renders: rotate_scene uses it as cam_radius and
# translates the model-view matrix by -RADIUS along z.
RADIUS = 3.0
# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)
###############################################################################
# Loss setup
###############################################################################
@torch.no_grad()
def createLoss(FLAGS):
    """Return the image-loss callable selected by ``FLAGS.loss``.

    The returned function takes ``(img, ref)`` and forwards them to
    ``ru.image_loss`` with the loss kind and tonemapper paired with the name
    in the table below. Any other ``FLAGS.loss`` value hits ``assert False``.
    """
    # (FLAGS.loss value, ru.image_loss `loss`, ru.image_loss `tonemapper`)
    presets = (
        ("smape", 'smape', 'none'),
        ("mse", 'mse', 'none'),
        ("logl1", 'l1', 'log_srgb'),
        ("logl2", 'mse', 'log_srgb'),
        ("relmse", 'relmse', 'none'),
    )
    for name, loss_kind, tonemapper in presets:
        if FLAGS.loss == name:
            # Returning immediately freezes loss_kind/tonemapper for the closure.
            return lambda img, ref: ru.image_loss(img, ref, loss=loss_kind, tonemapper=tonemapper)
    assert False
###############################################################################
# Mix background into a dataset image
###############################################################################
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Attach a background image to ``target`` and move its camera data to CUDA.

    Args:
        target: dict with ``'mv'``, ``'mvp'`` and ``'campos'`` tensors and a
            ``'resolution'`` pair. ``'img'`` is only required for the
            ``'reference'`` background.
        bg_type: one of ``'checker'``, ``'black'``, ``'white'``,
            ``'reference'`` or ``'random'``.

    Returns:
        The same dict, mutated in place: ``'mv'``, ``'mvp'`` and ``'campos'``
        are on CUDA and ``'background'`` holds the chosen background.

    Raises:
        AssertionError: for an unknown ``bg_type``.
    """
    # assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    def _solid_shape():
        # (1, res[0], res[1], 3) background shape taken from target['resolution'].
        return (1, target['resolution'][0], target['resolution'][1], 3)

    if bg_type == 'checker':
        # Dicts built by rotate_scene carry no 'img'; fall back to 'resolution'
        # so the default 'checker' background does not raise a KeyError.
        if 'img' in target:
            board_res = target['img'].shape[1:3]
        else:
            board_res = target['resolution'][0:2]
        background = torch.tensor(util.checkerboard(board_res, 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(_solid_shape(), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(_solid_shape(), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(_solid_shape(), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['background'] = background
    return target
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current mesh with xatlas and bake kd/ks/normal textures.

    The mesh comes from ``geometry.getMesh(mat)``; its material's
    ``'kd_ks_normal'`` entry is rendered into UV space at
    ``FLAGS.texture_res`` by ``render.render_uv``. The returned mesh carries
    the new UVs and a material made of three ``Texture2D`` maps clamped to
    the kd/ks/nrm limits in ``FLAGS``.
    """
    def _cuda_f32(values):
        # Limits are stored as plain lists on FLAGS.
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    source_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (it works on CPU numpy arrays)
    positions = source_mesh.v_pos.detach().cpu().numpy()
    triangles = source_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(positions, triangles)

    # Convert to tensors; the indices go through uint64 and are then viewed as int64
    signed_indices = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(signed_indices, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=source_mesh)

    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, source_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Append one extra random channel to kd
        extra_channel = torch.rand_like(kd[...,0:1])
        kd = torch.cat((kd, extra_channel), dim=-1)

    kd_limits = [_cuda_f32(FLAGS.kd_min), _cuda_f32(FLAGS.kd_max)]
    ks_limits = [_cuda_f32(FLAGS.ks_min), _cuda_f32(FLAGS.ks_max)]
    nrm_limits = [_cuda_f32(FLAGS.nrm_min), _cuda_f32(FLAGS.nrm_max)]

    new_mesh.material = material.Material({
        'bsdf' : mat['bsdf'],
        'kd' : texture.Texture2D(kd, min_max=kd_limits),
        'ks' : texture.Texture2D(ks, min_max=ks_limits),
        'normal' : texture.Texture2D(normal, min_max=nrm_limits)
    })
    return new_mesh
###############################################################################
# Utility functions for material
###############################################################################
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build the starting material for optimisation.

    With ``mlp`` true the material holds a single 9-channel
    ``MLPTexture3D`` under ``'kd_ks_normal'`` covering ``geometry.getAABB()``.
    Otherwise it holds trainable ``'kd'``, ``'ks'`` and ``'normal'`` textures,
    initialised randomly when ``FLAGS.random_textures`` is set or no
    ``init_mat`` is given, and from ``init_mat`` otherwise. The ``'bsdf'``
    entry is copied from ``init_mat`` when present and is ``'pbr'`` otherwise.
    """
    def _cuda_f32(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    kd_min, kd_max = _cuda_f32(FLAGS.kd_min), _cuda_f32(FLAGS.kd_max)
    ks_min, ks_max = _cuda_f32(FLAGS.ks_min), _cuda_f32(FLAGS.ks_max)
    nrm_min, nrm_max = _cuda_f32(FLAGS.nrm_min), _cuda_f32(FLAGS.nrm_max)

    if mlp:
        # Only the first three kd channels take part in the MLP texture limits.
        lower = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        upper = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        field = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[lower, upper])
        mat = material.Material({'kd_ks_normal' : field})
    else:
        randomize = FLAGS.random_textures or init_mat is None

        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if randomize:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_span = (kd_max - kd_min)[None, None, 0:num_channels]
            kd_offset = kd_min[None, None, 0:num_channels]
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * kd_span + kd_offset
            kd_map_opt = texture.create_trainable(kd_init , FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # The first ks channel starts near zero; the other two are drawn
            # uniformly between their configured limits.
            ks_shape = FLAGS.texture_res + [1]
            ksR = np.random.uniform(size=ks_shape, low=0.0, high=0.01)
            ksG = np.random.uniform(size=ks_shape, low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=ks_shape, low=ks_min[2].cpu(), high=ks_max[2].cpu())
            ks_init = np.concatenate((ksR, ksG, ksB), axis=2)
            ks_map_opt = texture.create_trainable(ks_init, FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map
        if randomize or 'normal' not in init_mat:
            normal_init = np.array([0, 0, 1])
        else:
            normal_init = init_mat['normal']
        normal_map_opt = texture.create_trainable(normal_init, FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd' : kd_map_opt,
            'ks' : ks_map_opt,
            'normal' : normal_map_opt
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'
    return mat
###############################################################################
# Validation & testing
###############################################################################
def rotate_scene(FLAGS, itr):
    """Return a camera dict for step ``itr`` of a rotating preview.

    The rotation angle about y is ``itr / 50`` of a full turn, combined with
    a fixed ``rotate_x(-0.4)`` and a translation of ``-RADIUS`` along z. The
    projection uses a 45 degree ``fovy``, the ratio
    ``FLAGS.display_res[1] / FLAGS.display_res[0]`` and the near/far planes
    from ``FLAGS.cam_near_far``.

    Returns:
        dict with batched CUDA tensors ``'mv'``, ``'mvp'`` and ``'campos'``,
        plus ``'spp'`` (1) and ``'resolution'`` (``FLAGS.display_res``).
    """
    aspect = FLAGS.display_res[1] / FLAGS.display_res[0]
    near, far = FLAGS.cam_near_far[0], FLAGS.cam_near_far[1]
    projection = util.perspective(np.deg2rad(45), aspect, near, far)

    # Smooth rotation for display.
    angle = (itr / 50) * np.pi * 2
    model_view = util.translate(0, 0, -RADIUS) @ (util.rotate_x(-0.4) @ util.rotate_y(angle))
    model_view_proj = projection @ model_view
    # Translation column of the inverse model-view matrix.
    camera_position = torch.linalg.inv(model_view)[:3, 3]

    return {
        'mv': model_view[None, ...].cuda(),
        'mvp': model_view_proj[None, ...].cuda(),
        'campos': camera_position[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res
    }
@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render ``geometry`` for one camera ``target`` without tracking gradients.

    Returns:
        ``(result_image, result_dict)`` where ``result_dict['opt']`` is the
        first image of the batch: the first three channels of the
        ``'shaded'`` buffer passed through ``util.rgb_to_srgb``.
        ``result_image`` is that same tensor.
    """
    lgt.build_mips()
    if FLAGS.camera_space_light:
        lgt.xfm(target['mv'])

    buffers = geometry.render(glctx, target, lgt, opt_material)
    shaded_rgb = buffers['shaded'][...,0:3]
    result_image = util.rgb_to_srgb(shaded_rgb)[0]
    return result_image, {'opt': result_image}
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Render every validation sample, log MSE/PSNR and dump the images.

    For each batch of ``dataset_validate`` (batch size 1, collated with
    ``dataset_validate.collate``) the batch is passed through
    ``prepare_batch`` and ``validate_itr``; the clamped ``'opt'`` and
    ``'ref'`` images are compared with a mean-squared error, and one line per
    sample is written to ``<out_dir>/metrics.txt``. Every entry of the result
    dict is also saved as ``val_<index>_<key>.png`` inside ``out_dir``.

    Returns:
        The mean PSNR over all samples (``np.mean`` of the per-sample values).

    NOTE(review): ``validate_itr`` in this file only stores ``'opt'`` in its
    result dict, so the ``result_dict['ref']`` lookup below raises KeyError
    as written -- confirm where the reference image should come from.
    """
    loader = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)
    os.makedirs(out_dir, exist_ok=True)

    mse_log, psnr_log = [], []
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as metrics_file:
        metrics_file.write('ID, MSE, PSNR\n')
        print("Running validation")

        for idx, batch in enumerate(loader):
            # Mix validation background
            batch = prepare_batch(batch, FLAGS.background)
            _, images = validate_itr(glctx, batch, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on images clamped to [0, 1]
            rendered = torch.clamp(images['opt'], 0.0, 1.0)
            reference = torch.clamp(images['ref'], 0.0, 1.0)
            mse = torch.nn.functional.mse_loss(rendered, reference, reduction='mean').item()
            mse_log.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_log.append(float(psnr))
            metrics_file.write("%d, %1.8f, %1.8f\n" % (idx, mse, psnr))

            # Save every image returned for this sample
            for key, image in images.items():
                file_name = 'val_%06d_%s.png' % (idx, key)
                util.save_image(out_dir + '/' + file_name, image.detach().cpu().numpy())

        avg_mse = np.mean(np.array(mse_log))
        avg_psnr = np.mean(np.array(psnr_log))
        metrics_file.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
###############################################################################
# Main shape fitter function / optimization loop
###############################################################################
class Trainer(torch.nn.Module):
    """Bundle geometry, material and light so one call runs one optimisation tick.

    Attributes set by the constructor:
        params:     material parameters, plus the light parameters when
                    ``optimize_light`` is true.
        geo_params: geometry parameters when ``optimize_geometry`` is true,
                    otherwise an empty list.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super().__init__()
        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A fixed light only needs its mips built once, outside autograd.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        """Run ``geometry.tick`` for ``target`` at iteration ``it`` and return its result."""
        if self.optimize_light:
            # The light changes every step, so its mips are rebuilt per call.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # Use the context handed to the constructor; the original relied on a
        # module-level ``glctx`` that only exists when run as a script.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------
# Script entry point: for every .npy file in --sample-folder, write each stored
# SDF/deformation grid into a DMTet geometry, save a rendered preview and export
# a remeshed, smoothed .obj.
if __name__ == "__main__":
    def _str2bool(text):
        """argparse ``type``: turn 'true'/'false'-style text into a real bool.

        ``type=bool`` would treat every non-empty string (including 'False')
        as True.
        """
        lowered = str(text).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (text,))

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    # rotate_scene indexes display_res[0] and [1], so two values are required.
    parser.add_argument('-dr', '--display-res', nargs=2, type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet_traj')
    parser.add_argument('-sf', '--sample-folder', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    parser.add_argument('--validate', type=_str2bool, default=True)

    FLAGS = parser.parse_args()

    # Defaults that have no command-line switch; a config file may override them.
    FLAGS.mtl_override = None  # Override material of model
    FLAGS.dmtet_grid = 64  # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 2.1  # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0  # Env map intensity multiplier
    FLAGS.envmap = None  # HDR environment probe
    FLAGS.display = None  # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False  # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False  # Disable light optimization in the second pass
    FLAGS.lock_pos = False  # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2  # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative"  # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0  # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True  # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0]  # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0]  # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]  # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.cropped = False
    FLAGS.random_lgt = False

    # Every key of the JSON config overwrites (or adds) the matching FLAGS attribute.
    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as config_file:
            data = json.load(config_file)
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # tet_path has no command-line switch and no default, so it can only come
    # from the config file; fail with a usage error instead of an AttributeError.
    if getattr(FLAGS, 'tet_path', None) is None:
        parser.error("'tet_path' must be provided through the --config file")
    if FLAGS.sample_folder is None:
        parser.error("--sample-folder (or 'sample_folder' in the config file) is required")

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    glctx = dr.RasterizeGLContext()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================
    # Setup geometry for optimization
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(resolution, FLAGS.mesh_scale, FLAGS)
    geometry.deform_scale = FLAGS.deform_scale

    # NOTE(review): `mask` is loaded but not used anywhere below.
    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")

    ### compute the mapping from tet indices to 3D cubic grid vertex indices
    tet_path = FLAGS.tet_path
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    # Spacing between the two smallest distinct coordinate values.
    dx = vertices_unique[1] - vertices_unique[0]
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()
    # Per-axis integer grid indices of every tet vertex (loop-invariant).
    idx_x = vertices_discretized[:, 0]
    idx_y = vertices_discretized[:, 1]
    idx_z = vertices_discretized[:, 2]

    # Files are processed in sorted name order; `k` is the file's position in that order.
    filelist = sorted(glob.glob(os.path.join(FLAGS.sample_folder, "*.npy")))
    for k, fpath in enumerate(filelist):
        data_all = np.load(fpath)
        print('shape of generated data', data_all.shape)

        for no_data in range(data_all.shape[0]):
            grid = torch.tensor(data_all[no_data])

            if FLAGS.unnormalized_sdf:
                # The raw-SDF path was never enabled: the original placed its
                # (unreachable) assignment after this raise.
                raise NotImplementedError

            # Channel 0 holds the SDF (only its sign is kept); the remaining
            # channels hold the per-vertex deformation.
            geometry.sdf.data[:] = torch.sign(grid[0, idx_x, idx_y, idx_z]).cuda()
            geometry.deform.data[:] = grid[1:, idx_x, idx_y, idx_z].cuda().transpose(0, 1)
            geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

            ### mtl for visualization
            opt_material = {
                'name' : '_default_mat',
                # 'bsdf' : 'pbr',
                'bsdf' : 'diffuse',
                'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
                'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
            }

            ### create and optimize mesh
            base_mesh = geometry.getMesh(opt_material)

            ### save a preview image from a fixed pose
            v_pose = rotate_scene(FLAGS, 30)
            result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
            result_image = result_image.detach().cpu().numpy()
            util.save_image(os.path.join(viz_path, ('%s_%03d_time%03d.png' % (FLAGS.viz_name, no_data, k))), result_image)

            ### save post-processed mesh (the raw .obj is written first, then
            ### overwritten in place by the remeshed / smoothed result)
            mesh_savepath = os.path.join(mesh_path, '%s_%03d_time%03d.obj' % (FLAGS.viz_name, no_data, k))
            save_obj(
                verts=base_mesh.v_pos,
                faces=base_mesh.t_pos_idx,
                f=mesh_savepath
            )
            ms = pymeshlab.MeshSet()
            ms.load_new_mesh(mesh_savepath)
            ms.meshing_isotropic_explicit_remeshing()
            ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=False)
            # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
            ms.meshing_isotropic_explicit_remeshing()
            ms.apply_filter_script()
            ms.save_current_mesh(mesh_savepath)

Wyświetl plik

@ -0,0 +1,820 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import time
import argparse
import json
import sys
import cv2
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas
# Import data readers / generators
from lib.dataset.dataset_mesh import DatasetMesh
from lib.dataset.dataset_shapenet import ShapeNetDataset
# Import topology / geometry trainers
from lib.geometry.dmtet import DMTetGeometry
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
import lib.render.renderutils as ru
from lib.render import obj
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render
import traceback
# Module-level constant; presumably the camera distance used by this file's
# scene-rotation helper -- TODO confirm against its definition.
RADIUS = 2.0
# # Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)
# define colors
# NOTE(review): (0, 0, 255) is labelled red, so these tuples appear to be in
# OpenCV's BGR channel order -- confirm before using them with RGB images.
color1 = (0, 0, 255) #red
color2 = (0, 165, 255) #orange
color3 = (0, 255, 255) #yellow
color4 = (255, 255, 0) #cyan
color5 = (255, 0, 0) #blue
color6 = (128, 64, 64) #violet
# One row holding the six anchor colours, as uint8.
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)
# resize lut to 256 (or more) values
# Linear interpolation stretches the six anchors into a 256-entry lookup table.
lut = cv2.resize(colorArr, (256,1), interpolation = cv2.INTER_LINEAR)
###############################################################################
# Loss setup
###############################################################################
@torch.no_grad()
def createLoss(FLAGS):
    """Return the image-loss callable selected by ``FLAGS.loss``.

    The returned function takes ``(img, ref)`` and forwards them to
    ``ru.image_loss`` with the loss kind and tonemapper paired with the name
    in the table below. Any other ``FLAGS.loss`` value hits ``assert False``.
    """
    # (FLAGS.loss value, ru.image_loss `loss`, ru.image_loss `tonemapper`)
    presets = (
        ("smape", 'smape', 'none'),
        ("mse", 'mse', 'none'),
        ("logl1", 'l1', 'log_srgb'),
        ("logl2", 'mse', 'log_srgb'),
        ("relmse", 'relmse', 'none'),
    )
    for name, loss_kind, tonemapper in presets:
        if FLAGS.loss == name:
            # Returning immediately freezes loss_kind/tonemapper for the closure.
            return lambda img, ref: ru.image_loss(img, ref, loss=loss_kind, tonemapper=tonemapper)
    assert False
###############################################################################
# Mix background into a dataset image
###############################################################################
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Composite the batch images over a background and move the batch to CUDA.

    Args:
        target: dict with a 4-D ``'img'`` tensor whose last axis carries the
            colour channels followed by one blend-weight channel, plus
            ``'mv'``, ``'mvp'``, ``'campos'``, ``'spts'`` and ``'vpts'``
            tensors.
        bg_type: one of ``'checker'``, ``'black'``, ``'white'``,
            ``'reference'`` or ``'random'``.

    Returns:
        The same dict, mutated in place: all listed tensors are on CUDA,
        ``'background'`` holds the chosen background and ``'img'`` holds the
        colour channels blended over it (with the weight channel appended
        unchanged).

    Raises:
        AssertionError: if ``'img'`` is not 4-D or ``bg_type`` is unknown.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    # Move the image first so that a 'reference' background (a slice of it)
    # lives on the same device as the image it is blended with below.
    target['img'] = target['img'].cuda()
    solid_shape = target['img'].shape[0:3] + (3,)

    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(solid_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(solid_shape, dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(solid_shape, dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['background'] = background

    # Blend colour over the background using channel 3 as the weight, and keep
    # that weight channel as the last channel of the result.
    colour = target['img'][..., 0:3]
    weight = target['img'][..., 3:4]
    target['img'] = torch.cat((torch.lerp(background, colour, weight), weight), dim=-1)

    target['spts'] = target['spts'].cuda()
    target['vpts'] = target['vpts'].cuda()
    return target
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current mesh with xatlas and bake kd/ks/normal textures.

    The mesh comes from ``geometry.getMesh(mat)``; its material's
    ``'kd_ks_normal'`` entry is rendered into UV space at
    ``FLAGS.texture_res`` by ``render.render_uv``. The returned mesh carries
    the new UVs and a material made of three ``Texture2D`` maps clamped to
    the kd/ks/nrm limits in ``FLAGS``.
    """
    def _cuda_f32(values):
        # Limits are stored as plain lists on FLAGS.
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    source_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (it works on CPU numpy arrays)
    positions = source_mesh.v_pos.detach().cpu().numpy()
    triangles = source_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(positions, triangles)

    # Convert to tensors; the indices go through uint64 and are then viewed as int64
    signed_indices = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(signed_indices, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=source_mesh)

    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, source_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Append one extra random channel to kd
        extra_channel = torch.rand_like(kd[...,0:1])
        kd = torch.cat((kd, extra_channel), dim=-1)

    kd_limits = [_cuda_f32(FLAGS.kd_min), _cuda_f32(FLAGS.kd_max)]
    ks_limits = [_cuda_f32(FLAGS.ks_min), _cuda_f32(FLAGS.ks_max)]
    nrm_limits = [_cuda_f32(FLAGS.nrm_min), _cuda_f32(FLAGS.nrm_max)]

    new_mesh.material = material.Material({
        'bsdf' : mat['bsdf'],
        'kd' : texture.Texture2D(kd, min_max=kd_limits),
        'ks' : texture.Texture2D(ks, min_max=ks_limits),
        'normal' : texture.Texture2D(normal, min_max=nrm_limits)
    })
    return new_mesh
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current mesh with xatlas and bake only its normal map.

    The mesh comes from ``geometry.getMesh(mat)``; its material's
    ``'normal'`` entry is rendered into UV space at ``FLAGS.texture_res`` by
    ``render.render_uv_nrm``. The returned mesh carries the new UVs and a
    material that reuses ``mat['bsdf']``, ``mat['kd']`` and ``mat['ks']``
    unchanged and wraps the baked normals in a ``Texture2D`` clamped to
    ``FLAGS.nrm_min`` / ``FLAGS.nrm_max``.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (it works on CPU numpy arrays)
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors; the indices go through uint64 and are then viewed as int64
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # The original extended `kd` here when FLAGS.layers > 1, but no `kd` is
    # produced in this function (so the branch raised UnboundLocalError) and
    # the material below takes kd straight from `mat`; the branch is dropped.

    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf' : mat['bsdf'],
        'kd' : mat['kd'],
        'ks' : mat['ks'],
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })
    return new_mesh
###############################################################################
# Utility functions for material
###############################################################################
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the trainable material used as the starting point for fitting.

    With ``mlp`` set, a single 9-channel ``MLPTexture3D`` over the geometry's
    AABB is stored under ``'kd_ks_normal'``. Otherwise separate trainable
    ``kd`` / ``ks`` / ``normal`` 2D textures are created, either randomly
    (``FLAGS.random_textures`` or no ``init_mat``) or from ``init_mat``.

    The ``'bsdf'`` entry is copied from ``init_mat`` when one is given and
    defaults to ``'pbr'`` otherwise.
    """
    def _limit(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    kd_min, kd_max = _limit(FLAGS.kd_min), _limit(FLAGS.kd_max)
    ks_min, ks_max = _limit(FLAGS.ks_min), _limit(FLAGS.ks_max)
    nrm_min, nrm_max = _limit(FLAGS.nrm_min), _limit(FLAGS.nrm_max)

    if mlp:
        # Only the first three kd limits are used for the MLP texture.
        lower = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        upper = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_texture = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[lower, upper])
        mat = material.Material({'kd_ks_normal': mlp_texture})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            channels = 4 if FLAGS.layers > 1 else 3
            kd_span = (kd_max - kd_min)[None, None, 0:channels]
            kd_offset = kd_min[None, None, 0:channels]
            kd_start = torch.rand(size=FLAGS.texture_res + [channels], device='cuda') * kd_span + kd_offset
            kd_tex = texture.create_trainable(kd_start, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # First ks channel starts close to zero; the other two are drawn
            # from their configured [min, max] range.
            ks_planes = [np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)]
            for channel in (1, 2):
                ks_planes.append(np.random.uniform(
                    size=FLAGS.texture_res + [1],
                    low=ks_min[channel].cpu(),
                    high=ks_max[channel].cpu()))
            ks_tex = texture.create_trainable(
                np.concatenate(tuple(ks_planes), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_tex = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_tex = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_start = np.array([0, 0, 1])
        else:
            normal_start = init_mat['normal']
        normal_tex = texture.create_trainable(normal_start, FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd': kd_tex,
            'ks': ks_tex,
            'normal': normal_tex,
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'
    return mat
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Create a material whose kd/ks are fixed and only the normal is trainable.

    ``kd``, ``ks`` and ``bsdf`` are taken from ``init_mat`` as-is. The normal
    entry is a 3-channel ``MLPTexture3D`` over the geometry's AABB when
    ``mlp`` is set, otherwise a trainable 2D texture initialised from
    ``init_mat['normal']`` (or from the constant ``[0, 0, 1]`` when
    ``FLAGS.random_textures`` is set or ``init_mat`` has no normal entry).

    Raises:
        TypeError: if ``init_mat`` is None. Both branches read
            ``init_mat['kd']`` / ``init_mat['ks']``, so a material is
            required even though the parameter has a ``None`` default.
    """
    if init_mat is None:
        # Previously this surfaced later as an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise TypeError(
            "initial_guess_material_knownkskd requires init_mat providing 'kd', 'ks' and 'bsdf'")

    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        normal_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[nrm_min, nrm_max])
    else:
        # Setup normal map
        if FLAGS.random_textures or 'normal' not in init_mat:
            normal_init = np.array([0, 0, 1])
        else:
            normal_init = init_mat['normal']
        normal_opt = texture.create_trainable(
            normal_init, FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

    mat = material.Material({
        'kd': init_mat['kd'],
        'ks': init_mat['ks'],
        'normal': normal_opt,
    })
    mat['bsdf'] = init_mat['bsdf']
    return mat
###############################################################################
# Validation & testing
###############################################################################
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation batch and pair the result with its reference.

    Rendering is first attempted with ``ema=True``; if that call fails the
    render is repeated without the ``ema`` argument.

    Returns:
        (result_image, result_dict): ``result_dict`` holds the sRGB-converted
        first batch element of the reference (``'ref'``) and of the rendered
        image (``'opt'``); ``result_image`` is ``opt`` and ``ref``
        concatenated along dim 1.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        # NOTE(review): nesting reconstructed from a flattened source; this
        # call is assumed to run unconditionally - confirm against the repo.
        lgt.xfm(target['envlight_transform'])

        try:
            buffers = geometry.render(
                glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except Exception:
            # Best-effort fallback for render() calls that fail with ema=True.
            # `Exception` (instead of a bare except) lets KeyboardInterrupt
            # and SystemExit propagate rather than triggering a second render.
            buffers = geometry.render(
                glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], dim=1)

    return result_image, result_dict
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation set, dump images and metrics, return the mean PSNR.

    For every validation sample the rendered and reference images are written
    to ``out_dir`` as ``val_<index>_<key>.png`` and one ``ID, MSE, PSNR`` row
    is appended to ``out_dir/metrics.txt``, followed by a final averages row.
    """
    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    all_mse = []
    all_psnr = []

    loader = torch.utils.data.DataLoader(
        dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    metrics_path = os.path.join(out_dir, 'metrics.txt')
    with open(metrics_path, 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for idx, batch in enumerate(loader):
            # Mix validation background
            batch = prepare_batch(batch, FLAGS.background)
            _combined, images = validate_itr(glctx, batch, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on images clamped to [0, 1]
            rendered = torch.clamp(images['opt'], 0.0, 1.0)
            reference = torch.clamp(images['ref'], 0.0, 1.0)
            mse = torch.nn.functional.mse_loss(rendered, reference, reduction='mean').item()
            all_mse.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            all_psnr.append(float(psnr))

            fout.write(str("%d, %1.8f, %1.8f\n" % (idx, mse, psnr)))

            for key, image in images.items():
                util.save_image(
                    out_dir + '/' + ('val_%06d_%s.png' % (idx, key)),
                    image.detach().cpu().numpy())

        avg_mse = np.mean(np.array(all_mse))
        avg_psnr = np.mean(np.array(all_psnr))
        fout.write(str("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)))
        print("MSE,      PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))

    return avg_psnr
###############################################################################
# Main shape fitter function / optimization loop
###############################################################################
class Trainer(torch.nn.Module):
    """Bundles geometry, light and material so one call yields the losses.

    Besides holding the components, the constructor collects the parameter
    lists that the optimizers are built from:

    * ``params``        - material parameters, plus light parameters when
                          ``optimize_light`` is set
    * ``geo_params``    - all geometry parameters when ``optimize_geometry``
    * ``sdf_params``    - ``[geometry.sdf]``, or ``[]`` if the geometry has
                          no ``sdf`` attribute
    * ``deform_params`` - ``[geometry.deform]``
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        if not self.optimize_light:
            # A fixed light only needs its mips built once, without gradients.
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

        # Not every geometry type exposes `sdf`; only a missing attribute is
        # tolerated here (previously a bare `except:` hid every error).
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        """Run one geometry tick for ``target`` at iteration ``it``.

        Returns whatever ``geometry.tick`` returns.
        """
        if self.optimize_light:
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
        # NOTE(review): nesting reconstructed from a flattened source; this
        # call is assumed to run unconditionally - confirm against the repo.
        self.light.xfm(target['envlight_transform'])

        # Use the context handed to the constructor. The previous code read a
        # module-level `glctx`, which only exists when run as the main script.
        return self.geometry.tick(
            self.glctx, target, self.light, self.material, self.image_loss_fn, it,
            xfm_lgt=target['envlight_transform'])
# Run one optimization pass over `dataset_train` and return (geometry, opt_material).
#
# Builds a Trainer around the given components, creates an Adam optimizer plus
# LambdaLR scheduler for the material/light parameters (and a second pair for
# the geometry when `optimize_geometry` is set), then iterates once over a
# shuffled DataLoader of `dataset_train`, stepping the optimizers per batch.
#
# NOTE(review): `dataset_validate` and `pass_name` are accepted but never read
# in this body.
def optimize_mesh(
glctx,
geometry,
opt_material,
lgt,
dataset_train,
dataset_validate,
FLAGS,
warmup_iter=0,
log_interval=10,
pass_idx=0,
pass_name="",
optimize_light=True,
optimize_geometry=True,
):
# ==============================================================================================
# Setup torch optimizer
# ==============================================================================================
# FLAGS.learning_rate may be a scalar or a per-pass list/tuple; a per-pass
# entry may itself be a (position_lr, material_lr) pair.
learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
# LR multiplier: linear ramp during warm-up, then 10**(-0.0002 * steps).
# NOTE(review): `fraction` is unused; the division is only reached when
# iter < warmup_iter, so warmup_iter == 0 never divides by zero.
def lr_schedule(iter, fraction):
if iter < warmup_iter:
return iter / warmup_iter
return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.
# ==============================================================================================
# Image loss
# ==============================================================================================
image_loss_fn = createLoss(FLAGS)
trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)
if FLAGS.multi_gpu:
# NOTE(review): the raise below makes the rest of this branch unreachable.
raise NotImplementedError
# Multi GPU training mode
import apex
from apex.parallel import DistributedDataParallel as DDP
trainer = DDP(trainer_noddp)
trainer.train()
if optimize_geometry:
optimizer_mesh = apex.optimizers.FusedAdam(trainer_noddp.geo_params, lr=learning_rate_pos)
scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))
optimizer = apex.optimizers.FusedAdam(trainer_noddp.params, lr=learning_rate_mat)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))
else:
# Single GPU training mode
trainer = trainer_noddp
if optimize_geometry:
# optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos)
# Separate parameter groups for sdf and deform (both currently use
# learning_rate_pos).
optimizer_mesh = torch.optim.Adam([
{'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
{'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
])
# optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos, betas=(0.2, 0.999), eps=1e-5)
scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))
optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
# optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat, betas=(0.2, 0.999), eps=1e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))
# ==============================================================================================
# Training loop
# ==============================================================================================
# NOTE(review): img_cnt is initialised but not used in this body.
img_cnt = 0
img_loss_vec = []
reg_loss_vec = []
iter_dur_vec = []
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)
print("Start training loop...")
sys.stdout.flush()
# The number of iterations is determined by the length of the DataLoader.
for it, target in enumerate(dataloader_train):
# Mix randomized background into dataset image
target = prepare_batch(target, 'random')
iter_start_time = time.time()
# ==============================================================================================
# Zero gradients
# ==============================================================================================
optimizer.zero_grad()
if optimize_geometry:
optimizer_mesh.zero_grad()
# ==============================================================================================
# Training
# ==============================================================================================
img_loss, reg_loss = trainer(target, it)
# ==============================================================================================
# Final loss
# ==============================================================================================
total_loss = img_loss + reg_loss
img_loss_vec.append(img_loss.item())
reg_loss_vec.append(reg_loss.item())
# ==============================================================================================
# Backpropagate
# ==============================================================================================
total_loss.backward()
# Manual gradient rescaling before the optimizer step: light base gradient
# is multiplied by 64, MLP encoder gradients are divided by 8.
if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
lgt.base.grad *= 64
if 'kd_ks_normal' in opt_material:
opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
if 'normal' in opt_material and FLAGS.normal_only:
# NOTE(review): the bare except silently skips the rescale on any error,
# e.g. when the normal entry has no `encoder` attribute.
try:
opt_material['normal'].encoder.params.grad /= 8.0
except:
pass
optimizer.step()
scheduler.step()
if optimize_geometry:
optimizer_mesh.step()
scheduler_mesh.step()
# NOTE(review): indentation was lost in this copy of the file, so it cannot
# be told from here whether the two geometry calls below belong to the
# `if optimize_geometry:` block - confirm against the repository.
geometry.clamp_deform()
geometry.update_ema()
# ==============================================================================================
# Clamp trainables to reasonable range
# ==============================================================================================
with torch.no_grad():
if 'kd' in opt_material:
opt_material['kd'].clamp_()
if 'ks' in opt_material:
opt_material['ks'].clamp_()
if 'normal' in opt_material and not FLAGS.normal_only:
opt_material['normal'].clamp_()
opt_material['normal'].normalize_()
if lgt is not None:
lgt.clamp_(min=0.0)
# Wait for queued GPU work so the measured iteration time is meaningful.
torch.cuda.current_stream().synchronize()
iter_dur_vec.append(time.time() - iter_start_time)
# ==============================================================================================
# Logging
# ==============================================================================================
if it % log_interval == 0 and FLAGS.local_rank == 0:
# Averages cover the most recent `log_interval` iterations.
img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))
# The remaining-time estimate is based on FLAGS.iter, not on the loader length.
remaining_time = (FLAGS.iter-it)*iter_dur_avg
print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
(it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
sys.stdout.flush()
return geometry, opt_material
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------
# Script entry point: parses the CLI, fills in hard-coded defaults on FLAGS,
# optionally overrides them from a JSON config, then fits a DMTet to
# `split_size` consecutive ShapeNet meshes starting at index * split_size.
# Each mesh gets two optimization passes; results are saved under
# <out_dir>/<cat_name>/tets_pre (after pass 1) and .../tets (after pass 2).
if __name__ == "__main__":
# sleep(randint(0,15))
parser = argparse.ArgumentParser(description='nvdiffrec')
parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
parser.add_argument('-i', '--iter', type=int, default=5000)
parser.add_argument('-b', '--batch', type=int, default=1)
parser.add_argument('-s', '--spp', type=int, default=1)
parser.add_argument('-l', '--layers', type=int, default=1)
parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
parser.add_argument('-dr', '--display-res', type=int, default=None)
parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
parser.add_argument('-di', '--display-interval', type=int, default=0)
parser.add_argument('-si', '--save-interval', type=int, default=1000)
parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results')
parser.add_argument('-bm', '--base-mesh', type=str, default=None)
# NOTE(review): with type=bool argparse calls bool() on the raw string, so any
# non-empty value (including "False") parses as True. This applies to
# --validate, --cropped and --normal-only.
parser.add_argument('--validate', type=bool, default=True)
# NOTE(review): --index has no default; if omitted, FLAGS.index is None and
# the multiplication below raises TypeError.
parser.add_argument('-ind', '--index', type=int)
parser.add_argument('-ss', '--split-size', type=int, default=10)
parser.add_argument('--cropped', type=bool, default=True)
parser.add_argument('-no', '--normal-only', type=bool, default=True)
parser.add_argument('--meta-folder', type=str, default='./data/shapenet_json')
parser.add_argument('--cat-name', type=str, default='chair')
parser.add_argument('-rp', '--resume-path', type=str, default=None)
parser.add_argument('-ema', '--use-ema', action="store_true")
FLAGS = parser.parse_args()
print(f"parsed arguments")
# This value is recomputed per object inside the loop below.
global_index = FLAGS.index * FLAGS.split_size
# Hard-coded defaults that have no CLI option; the JSON config loaded further
# down may override any of them.
FLAGS.mtl_override = None # Override material of model
FLAGS.dmtet_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
FLAGS.mesh_scale = 1.0 # Scale of tet grid box. Adjust to cover the model
FLAGS.env_scale = 1.0 # Env map intensity multiplier
FLAGS.envmap = None # HDR environment probe
FLAGS.display = None # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
FLAGS.camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
FLAGS.lock_light = False # Disable light optimization in the second pass
FLAGS.lock_pos = False # Disable vertex position optimization in the second pass
FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer (see paper for details)
FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"]
FLAGS.laplace_scale = 10000.0 # Weight for sdf regularizer. Default is relative with large weight
FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training
FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] # Limits for kd
FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
FLAGS.ks_min = [ 0.0, 0.08, 0.0] # Limits for ks
FLAGS.ks_max = [ 1.0, 1.0, 1.0]
FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
FLAGS.cam_near_far = [0.1, 1000.0]
FLAGS.learn_light = False
# NOTE(review): the next two assignments overwrite whatever --cropped and
# --use-ema were given on the command line.
FLAGS.cropped = True
FLAGS.use_ema = False
FLAGS.random_lgt = True
FLAGS.dataset_flat_shading = False
FLAGS.local_rank = 0
FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
if FLAGS.multi_gpu:
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = 'localhost'
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = '23456'
FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(FLAGS.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
# Every key in the JSON config is written onto FLAGS, replacing both CLI
# values and the defaults set above.
# NOTE(review): the file object returned by open() is never closed explicitly.
if FLAGS.config is not None:
data = json.load(open(FLAGS.config, 'r'))
for key in data:
FLAGS.__dict__[key] = data[key]
if FLAGS.display_res is None:
FLAGS.display_res = FLAGS.train_res
# Results are grouped per category below the requested output directory.
FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.cat_name)
if FLAGS.local_rank == 0:
print("Config / Flags:")
print("---------")
for key in FLAGS.__dict__.keys():
print(key, FLAGS.__dict__[key])
print("---------")
os.makedirs(FLAGS.out_dir, exist_ok=True)
os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)
print(f"Using dmt grid of resolution {FLAGS.dmtet_grid}")
glctx = dr.RasterizeGLContext()
### Default mtl
mtl_default = {
'name' : '_default_mat',
'bsdf': 'diffuse',
'uniform': True,
'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
}
print(f"meta json path {os.path.join(FLAGS.meta_folder, f'{FLAGS.cat_name}.json')}")
shapenet_dataset = ShapeNetDataset(
os.path.join(FLAGS.meta_folder, f'{FLAGS.cat_name}.json'),
shapenet_v1=(FLAGS.cat_name == 'car')
)
print("Start iterating through objects")
sys.stdout.flush()
if len(shapenet_dataset) > 0:
for k in range(FLAGS.split_size):
# ==============================================================================================
# Create data pipeline
# ==============================================================================================
global_index = k + FLAGS.index * FLAGS.split_size
print("file path to save: {:s}".format(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))))
# Objects whose final result file already exists are skipped, which allows
# an interrupted run to be restarted.
skip_if_exists = True
if skip_if_exists and os.path.exists(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))):
continue
# Any exception while processing one object is logged by the handler at the
# bottom and the loop moves on to the next object.
try:
global_index = k + FLAGS.index * FLAGS.split_size
if global_index >= len(shapenet_dataset):
break
mesh_fname = shapenet_dataset[global_index]
print(f"Loading mesh: {mesh_fname}")
sys.stdout.flush()
ref_mesh = mesh.load_mesh(mesh_fname, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)
# NOTE(review): `a` is not read anywhere in the visible code.
a = ref_mesh.v_nrm.clone()
ref_mesh = mesh.auto_normals(ref_mesh) ### important
print("Loading dataset")
sys.stdout.flush()
# NOTE(review): RADIUS is only assigned here for 'car'; for other categories
# it presumably comes from a module-level RADIUS defined outside this
# section - confirm.
if FLAGS.cat_name == 'car':
RADIUS = 2.0
dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
print("Dataset loaded")
sys.stdout.flush()
# ==============================================================================================
# Create env light with trainable parameters
# ==============================================================================================
if FLAGS.learn_light:
lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
else:
lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)
# ==============================================================================================
# If no initial guess, use DMtets to create geometry
# ==============================================================================================
# Setup geometry for optimization
geometry = DMTetGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)
# Setup textures, make initial guess from reference if possible
# Both helpers are called with mlp=True and the default material.
if not FLAGS.normal_only:
mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
else:
mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
print("Start optimization")
sys.stdout.flush()
# Pass 1 is either run from scratch (and checkpointed to tets_pre/) or its
# result is restored from FLAGS.resume_path.
if FLAGS.resume_path is None:
# Run optimization
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)
base_mesh = geometry.getMesh(mat)
# 0/1 mask with ones at the indices returned by getValidVertsIdx(); used to
# zero the deformation everywhere else in the saved 'deform' entry.
vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
vert_mask[geometry.getValidVertsIdx()] = 1
# Free temporaries / cached memory
torch.cuda.empty_cache() ### may slow down training
torch.save({
'sdf': geometry.sdf.cpu().detach(),
'sdf_ema': geometry.sdf_ema.cpu().detach(),
'deform': (geometry.deform * vert_mask).cpu().detach(),
'deform_unmasked': geometry.deform.cpu().detach(),
}, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
old_geometry = geometry
else:
dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
if FLAGS.use_ema:
geometry.sdf.data[:] = dmt_dict['sdf_ema']
else:
geometry.sdf.data[:] = dmt_dict['sdf']
geometry.deform.data[:] = dmt_dict['deform']
old_geometry = geometry
# Create textured mesh from result
if FLAGS.normal_only:
base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
else:
base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)
# # ==============================================================================================
# # Pass 2: Finetune deformation with fixed topology
# # ==============================================================================================
geometry = DMTetGeometryFixedTopo(geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)
# Only the deformation stays trainable in pass 2.
geometry.sdf_sign.requires_grad = False
geometry.sdf_abs.requires_grad = False
geometry.deform.requires_grad = True
# geometry.deform.data[:] = geometry.deform * 2.0 / 3.0
# geometry.deform_scale = 3.0
# Stored deform is rescaled by 0.45 / 1.5 while deform_scale is set to 1.5.
# NOTE(review): presumably this keeps the effective offset unchanged when the
# previous deform_scale was 0.45 - confirm against the geometry classes.
geometry.deform.data[:] = geometry.deform * 0.45 / 1.5
geometry.deform_scale = 1.5
if FLAGS.use_ema:
geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema)
else:
geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)
geometry.set_init_v_pos()
geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
optimize_geometry=not FLAGS.lock_pos)
vert_mask = torch.zeros_like(geometry.sdf_sign).long().cuda().view(-1, 1)
vert_mask[geometry.getValidVertsIdx()] = 1
# Final result: the 'sdf' entry holds the sign tensor of the fixed topology.
torch.save({
'sdf': geometry.sdf_sign.cpu().detach(),
'deform': (geometry.deform * vert_mask).cpu().detach(),
'deform_unmasked': geometry.deform.cpu().detach(),
},
os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))
)
if FLAGS.local_rank == 0 and FLAGS.validate:
validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)
# Free temporaries / cached memory
del geometry
del ref_mesh
del dataset_train
del dataset_validate
torch.cuda.empty_cache() ### may slow down training
print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} finished ============\n\n")
except Exception as err:
# Best effort: report the failure with its traceback and continue.
print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} Failed ============\n\n")
print(traceback.format_exc())
print("\n\n")
continue

Wyświetl plik

@ -0,0 +1,836 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import time
import argparse
import json
import sys
import cv2
import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas
# Import data readers / generators
from lib.dataset.dataset_mesh import DatasetMesh
from lib.dataset.dataset_shapenet import ShapeNetDataset
# Import topology / geometry trainers
from lib.geometry.dmtet_singleview import DMTetGeometry
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
import lib.render.renderutils as ru
from lib.render import obj
from lib.render import material
from lib.render import util
from lib.render import mesh
from lib.render import texture
from lib.render import mlptexture
from lib.render import light
from lib.render import render
from random import randint
from time import sleep
import traceback
# NOTE(review): presumably the camera distance handed to the dataset
# constructor - confirm at the call site.
RADIUS = 2.0
# define colors
# Going by the labels (red = (0, 0, 255)), the tuples are in B, G, R order.
color1 = (0, 0, 255) #red
color2 = (0, 165, 255) #orange
color3 = (0, 255, 255) #yellow
color4 = (255, 255, 0) #cyan
color5 = (255, 0, 0) #blue
color6 = (128, 64, 64) #violet
# A single row of six colours, stored as uint8.
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)
# resize lut to 256 (or more) values
# Linear interpolation between the six anchor colours yields a 256-entry
# lookup table.
lut = cv2.resize(colorArr, (256,1), interpolation = cv2.INTER_LINEAR)
###############################################################################
# Loss setup
###############################################################################
@torch.no_grad()
def createLoss(FLAGS):
    """Return the image loss callable selected by ``FLAGS.loss``.

    The returned function takes ``(img, ref)`` and forwards them to
    ``ru.image_loss`` with the loss / tonemapper pair for the chosen name.

    Raises:
        AssertionError: if ``FLAGS.loss`` is not one of the supported names.
            Raised explicitly so the check also holds under ``python -O``,
            where a plain ``assert False`` is stripped and the function
            would silently return None.
    """
    # FLAGS.loss -> (loss, tonemapper) as understood by ru.image_loss
    variants = {
        "smape": ('smape', 'none'),
        "mse": ('mse', 'none'),
        "logl1": ('l1', 'log_srgb'),
        "logl2": ('mse', 'log_srgb'),
        "relmse": ('relmse', 'none'),
    }
    if FLAGS.loss not in variants:
        raise AssertionError("Unknown loss type %s" % FLAGS.loss)

    loss, tonemapper = variants[FLAGS.loss]
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)
###############################################################################
# Mix background into a dataset image
###############################################################################
@torch.no_grad()
def prepare_batch(target, bg_type='black', device='cuda'):
    """Move a batch to ``device`` and composite its image over a background.

    ``target['img']`` must be a 4D ``[n, h, w, c]`` tensor whose channel 3 is
    used as the blend weight. After the call ``target['img']`` holds the
    blended colour in channels 0..2 followed by the original channel 3, and
    ``target['background']`` holds the background that was used.

    Args:
        target: batch dict with the keys ``mv``, ``mvp``, ``campos``, ``img``,
            ``spts`` and ``vpts``; it is updated in place and returned.
        bg_type: one of ``'checker'``, ``'black'``, ``'white'``,
            ``'reference'`` (reuse the image colours) or ``'random'``.
        device: device the tensors are moved to; defaults to ``'cuda'``,
            which matches the previous hard-coded behaviour.

    Raises:
        AssertionError: for a non-4D image or an unknown ``bg_type``. The
            latter is raised explicitly so it also triggers under
            ``python -O``.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    # Move the image first so that every background - in particular the
    # 'reference' one, which is a slice of the image - lives on the same
    # device as the image it is blended with.
    img = target['img'].to(device)
    batch_shape = img.shape[0:3] + (3,)

    if bg_type == 'checker':
        background = torch.tensor(
            util.checkerboard(img.shape[1:3], 8), dtype=torch.float32, device=device)[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(batch_shape, dtype=torch.float32, device=device)
    elif bg_type == 'white':
        background = torch.ones(batch_shape, dtype=torch.float32, device=device)
    elif bg_type == 'reference':
        background = img[..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(batch_shape, dtype=torch.float32, device=device)
    else:
        raise AssertionError("Unknown background type %s" % bg_type)

    target['mv'] = target['mv'].to(device)
    target['mvp'] = target['mvp'].to(device)
    target['campos'] = target['campos'].to(device)
    target['background'] = background

    # lerp(background, colour, weight): weight 0 -> background, 1 -> colour.
    weight = img[..., 3:4]
    target['img'] = torch.cat((torch.lerp(background, img[..., 0:3], weight), weight), dim=-1)

    target['spts'] = target['spts'].to(device)
    target['vpts'] = target['vpts'].to(device)
    return target
###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the fitted geometry and bake kd / ks / normal textures.

    Steps: evaluate the mesh via ``geometry.getMesh(mat)``, parametrize it
    with xatlas, render the ``'kd_ks_normal'`` material into texture space
    (``FLAGS.texture_res``) and attach the baked textures, clamped to the
    FLAGS limits, to a new mesh that is returned.
    """
    fitted = geometry.getMesh(mat)

    # Create uvs with xatlas
    _vmapping, atlas_faces, atlas_uvs = xatlas.parametrize(
        fitted.v_pos.detach().cpu().numpy(),
        fitted.t_pos_idx.detach().cpu().numpy(),
    )

    # Convert to tensors
    atlas_faces = atlas_faces.astype(np.uint64, casting='same_kind').view(np.int64)
    v_tex = torch.tensor(atlas_uvs, dtype=torch.float32, device='cuda')
    t_tex_idx = torch.tensor(atlas_faces, dtype=torch.int64, device='cuda')
    textured = mesh.Mesh(v_tex=v_tex, t_tex_idx=t_tex_idx, base=fitted)

    baked = render.render_uv(glctx, textured, FLAGS.texture_res, fitted.material['kd_ks_normal'])
    _mask, kd, ks, normal = baked

    if FLAGS.layers > 1:
        # One additional kd channel filled with uniform random values.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    # Clamp ranges, built in the order kd, ks, normal (min before max).
    ranges = {}
    for key, lo, hi in (
        ('kd', FLAGS.kd_min, FLAGS.kd_max),
        ('ks', FLAGS.ks_min, FLAGS.ks_max),
        ('normal', FLAGS.nrm_min, FLAGS.nrm_max),
    ):
        ranges[key] = [
            torch.tensor(lo, dtype=torch.float32, device='cuda'),
            torch.tensor(hi, dtype=torch.float32, device='cuda'),
        ]

    textured.material = material.Material({
        'bsdf': mat['bsdf'],
        'kd': texture.Texture2D(kd, min_max=ranges['kd']),
        'ks': texture.Texture2D(ks, min_max=ranges['ks']),
        'normal': texture.Texture2D(normal, min_max=ranges['normal']),
    })
    return textured
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """UV-unwrap the fitted geometry and bake only the normal texture.

    The mesh from ``geometry.getMesh(mat)`` is parametrized with xatlas and
    its ``'normal'`` material entry is rendered into texture space at
    ``FLAGS.texture_res``. The returned mesh reuses ``mat['bsdf']``,
    ``mat['kd']`` and ``mat['ks']`` unchanged.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    _vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    _mask, normal = render.render_uv_nrm(
        glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # This function never bakes a kd texture. The `FLAGS.layers > 1` branch
    # that used to sit here tried to extend a local `kd` that was never
    # assigned, failing with UnboundLocalError; kd is taken from `mat` below.

    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf': mat['bsdf'],
        'kd': mat['kd'],
        'ks': mat['ks'],
        'normal': texture.Texture2D(normal, min_max=[nrm_min, nrm_max]),
    })
    return new_mesh
###############################################################################
# Utility functions for material
###############################################################################
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build the initial trainable material for optimization.

    Args:
        geometry: Object exposing ``getAABB()`` (only used when ``mlp`` is true).
        mlp: When true, a single 9-channel ``MLPTexture3D`` is created under the
            ``'kd_ks_normal'`` key; otherwise separate trainable kd/ks/normal
            2D textures are created.
        FLAGS: Configuration; reads the kd/ks/nrm limits, ``texture_res``,
            ``layers``, ``random_textures`` and ``custom_mip``.
        init_mat: Optional material used to seed the 2D textures and the bsdf.

    Returns:
        A ``material.Material`` with a ``'bsdf'`` entry (taken from ``init_mat``
        when given, ``'pbr'`` otherwise).
    """
    def _limit(values):
        return torch.tensor(values, dtype=torch.float32, device='cuda')

    kd_min, kd_max = _limit(FLAGS.kd_min), _limit(FLAGS.kd_max)
    ks_min, ks_max = _limit(FLAGS.ks_min), _limit(FLAGS.ks_max)
    nrm_min, nrm_max = _limit(FLAGS.nrm_min), _limit(FLAGS.nrm_max)

    if mlp:
        # One volumetric texture: 3 kd + ks + normal channels, with matching limits.
        lower = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        upper = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        volume_tex = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[lower, upper])
        mat = material.Material({'kd_ks_normal': volume_tex})
    else:
        res = FLAGS.texture_res
        auto_mips = not FLAGS.custom_mip
        start_random = FLAGS.random_textures or init_mat is None

        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if start_random:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_span = (kd_max - kd_min)[None, None, 0:num_channels]
            kd_offset = kd_min[None, None, 0:num_channels]
            kd_seed = torch.rand(size=res + [num_channels], device='cuda') * kd_span + kd_offset
            kd_tex = texture.create_trainable(kd_seed, res, auto_mips, [kd_min, kd_max])

            ks_r = np.random.uniform(size=res + [1], low=0.0, high=0.01)
            ks_g = np.random.uniform(size=res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ks_b = np.random.uniform(size=res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
            ks_seed = np.concatenate((ks_r, ks_g, ks_b), axis=2)
            ks_tex = texture.create_trainable(ks_seed, res, auto_mips, [ks_min, ks_max])
        else:
            kd_tex = texture.create_trainable(init_mat['kd'], res, auto_mips, [kd_min, kd_max])
            ks_tex = texture.create_trainable(init_mat['ks'], res, auto_mips, [ks_min, ks_max])

        # Setup normal map: constant (0, 0, 1) unless a usable seed exists.
        if start_random or 'normal' not in init_mat:
            nrm_seed = np.array([0, 0, 1])
        else:
            nrm_seed = init_mat['normal']
        nrm_tex = texture.create_trainable(nrm_seed, res, auto_mips, [nrm_min, nrm_max])

        mat = material.Material({
            'kd': kd_tex,
            'ks': ks_tex,
            'normal': nrm_tex,
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'
    return mat
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build a material whose kd/ks are taken from ``init_mat`` and whose normal map is trainable.

    Args:
        geometry: Object exposing ``getAABB()`` (only used when ``mlp`` is true).
        mlp: When true, the normal map is a 3-channel ``MLPTexture3D``;
            otherwise it is a trainable 2D texture.
        FLAGS: Configuration; reads ``nrm_min``, ``nrm_max`` and, for the 2D
            path, ``random_textures``, ``texture_res`` and ``custom_mip``.
        init_mat: Material providing ``'kd'``, ``'ks'`` and ``'bsdf'`` (and
            optionally ``'normal'`` as a seed). Required despite the ``None``
            default, which is kept only for signature compatibility.

    Returns:
        A ``material.Material`` with ``'kd'``, ``'ks'``, ``'normal'`` and ``'bsdf'``.

    Raises:
        ValueError: If ``init_mat`` is ``None``.
    """
    # kd/ks are read from init_mat unconditionally, so fail fast with a clear
    # message instead of an opaque "'NoneType' object is not subscriptable".
    if init_mat is None:
        raise ValueError("initial_guess_material_knownkskd requires init_mat providing 'kd' and 'ks'")

    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        normal_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[nrm_min, nrm_max])
    elif FLAGS.random_textures or 'normal' not in init_mat:
        # No usable seed: start from the constant (0, 0, 1) normal.
        normal_map_opt = texture.create_trainable(
            np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
    else:
        normal_map_opt = texture.create_trainable(
            init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

    mat = material.Material({
        'kd': init_mat['kd'],
        'ks': init_mat['ks'],
        'normal': normal_map_opt,
    })
    mat['bsdf'] = init_mat['bsdf']
    return mat
###############################################################################
# Validation & testing
###############################################################################
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view and assemble a side-by-side comparison image.

    Args:
        glctx: Rasterization context forwarded to ``geometry.render``.
        target: Batch dict; reads ``'img'``, ``'envlight_transform'``,
            ``'geo_viewdir'`` and, when ``FLAGS.camera_space_light`` is set, ``'mv'``.
        geometry: Object exposing ``render(...)`` returning a dict of buffers
            (``'shaded'``, ``'mask'``, ``'geo_normal'``, and ``'depth'`` for the
            depth layer).
        opt_material: Material passed to the renderer.
        lgt: Light exposing ``build_mips()`` and ``xfm(...)``.
        FLAGS: Configuration; reads ``camera_space_light``, ``display`` and
            ``display_res``.

    Returns:
        ``(result_image, result_dict)``: the images concatenated along axis 1,
        and a dict of the individual images keyed by name.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        lgt.xfm(target['envlight_transform'])

        # First try rendering with ema=True; if that call fails for any reason
        # (e.g. the geometry does not take the keyword), render without it.
        # `except Exception` keeps the best-effort fallback but no longer
        # swallows KeyboardInterrupt/SystemExit like the former bare `except:`.
        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except Exception:
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        if FLAGS.display is not None:
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    # Only environment lights have a cubemap to unwrap; skip the
                    # layer otherwise instead of failing on a missing dict entry.
                    if isinstance(lgt, light.EnvironmentLight):
                        result_dict['light_image'] = util.cubemap_to_latlong(lgt.base, FLAGS.display_res)
                        result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'relight' in layer:
                    # Lazily load the relighting probe and cache it on the layer config.
                    if not isinstance(layer['relight'], light.EnvironmentLight):
                        layer['relight'] = light.load_env(layer['relight'])
                    relit = geometry.render(glctx, target, layer['relight'], opt_material)
                    # render() returns a dict of buffers; the shaded image lives
                    # under 'shaded' (the dict itself cannot be sliced).
                    result_dict['relight'] = util.rgb_to_srgb(relit['shaded'][..., 0:3])[0]
                    result_image = torch.cat([result_image, result_dict['relight']], axis=1)
                elif 'bsdf' in layer:
                    buffers = geometry.render(glctx, target, lgt, opt_material, bsdf=layer['bsdf'])
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(buffers['shaded'][0, ..., 0:3])
                    elif layer['bsdf'] == 'normal':
                        # Map [-1, 1] components to [0, 1] for display.
                        result_dict[layer['bsdf']] = (buffers['shaded'][0, ..., 0:3] + 1) * 0.5
                    else:
                        result_dict[layer['bsdf']] = buffers['shaded'][0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif "depth" in layer:
                    # Uses the buffers of the most recent render call.
                    depth = buffers['depth'][:, :, :, 0].squeeze().unsqueeze(-1).expand(-1, -1, 1)
                    mask = (depth != 0).float()
                    # Normalize depth over pixels where depth != 0.
                    depth_min = ((1 - mask) * 1e3 + depth).min()
                    depth_max = depth.max()
                    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8)
                    depth = depth * mask + (1 - mask) * depth_min
                    depth = depth.expand(-1, -1, 3)
                    # `lut` is a lookup table defined elsewhere in this file.
                    depth = cv2.LUT(np.array(depth.detach().cpu().numpy() * 255.0, dtype=np.uint8), lut)
                    result_dict['depth'] = depth = (torch.tensor(depth, device=mask.device).float() / 255.0 * mask) + 255. * (1 - mask)
                    result_image = torch.cat([result_image, depth], axis=1)

        # NOTE(review): this final geometry pass is assumed to run for every
        # call (not only for the depth layer) -- confirm against the intended
        # nesting.
        buffers = geometry.render(glctx, target, lgt, opt_material)
        camera = target['geo_viewdir'][:, :, :, :3]
        # |normalized geo_normal . geo_viewdir| per pixel, first batch element.
        result_dict['geo_normal'] = (util.safe_normalize(buffers['geo_normal'][:, :, :, :3]) * camera).sum(-1, keepdim=False).abs()[0]
        mask = buffers['mask'][0].expand(-1, -1, 3)
        result_image = torch.cat([result_image, mask], axis=1)

    return result_image, result_dict
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation set, write per-view images and metrics, return the mean PSNR.

    For every batch of ``dataset_validate`` the view is rendered through
    ``validate_itr``, the MSE between the clamped ``'opt'`` and ``'ref'`` images
    is computed and converted with ``util.mse_to_psnr``, and every image of the
    result dict is saved under ``out_dir``. Metrics are written to
    ``<out_dir>/metrics.txt``.

    Returns:
        The PSNR averaged over all validation views.
    """
    per_view_mse = []
    per_view_psnr = []

    loader = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    metrics_path = os.path.join(out_dir, 'metrics.txt')
    with open(metrics_path, 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for view_idx, batch in enumerate(loader):
            # Mix validation background
            batch = prepare_batch(batch, FLAGS.background)

            _, images = validate_itr(glctx, batch, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on images clamped to [0, 1]
            rendered = torch.clamp(images['opt'], 0.0, 1.0)
            reference = torch.clamp(images['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(rendered, reference, reduction='mean').item()
            per_view_mse.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            per_view_psnr.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (view_idx, mse, psnr))

            # Dump every produced image for this view.
            for key in images:
                pixels = images[key].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (view_idx, key)), pixels)

        avg_mse = np.mean(np.array(per_view_mse))
        avg_psnr = np.mean(np.array(per_view_psnr))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
###############################################################################
# Main shape fitter function / optimization loop
###############################################################################
class Trainer(torch.nn.Module):
    """Bundles geometry, light and material so one forward call yields the training losses.

    Attributes set in ``__init__``:
        params: Material parameters, plus light parameters when the light is optimized.
        geo_params: Geometry parameters when geometry is optimized, else empty.
        sdf_params: ``[geometry.sdf]`` when the geometry has an ``sdf`` attribute, else empty.
        deform_params: ``[geometry.deform]``.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super().__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A fixed light only needs its mips built once, outside autograd.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

        # Not every geometry exposes an `sdf` attribute; treat it as optional.
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        """Return whatever ``geometry.tick`` produces for this batch and iteration.

        Args:
            target: Batch dict; reads ``'envlight_transform'`` and, when
                ``FLAGS.camera_space_light`` is set, ``'mv'``.
            it: Current iteration index, forwarded to ``geometry.tick``.
        """
        # A trainable light changes every step, so its mips are rebuilt here.
        if self.optimize_light:
            self.light.build_mips()
        if self.FLAGS.camera_space_light:
            self.light.xfm(target['mv'])
        self.light.xfm(target['envlight_transform'])

        # Use the context handed to the constructor rather than relying on a
        # module-level `glctx` global happening to exist.
        return self.geometry.tick(
            self.glctx, target, self.light, self.material, self.image_loss_fn, it,
            xfm_lgt=target['envlight_transform'], no_depth_thin=False)
def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
):
    """Fit geometry and material to a single view taken from ``dataset_validate``.

    Args:
        glctx: Rasterization context.
        geometry: Geometry to optimize; must expose ``deform``, ``clamp_deform()``
            and ``tick(...)`` (``sdf`` and ``init_with_gt_surface`` are optional).
        opt_material: Material to optimize.
        lgt: Light; its ``base`` gradient is rescaled when it is optimized.
        dataset_train: Only its ``collate`` function is used here.
        dataset_validate: Dataset the single training view is drawn from.
        FLAGS: Configuration; reads ``learning_rate``, ``normal_only``,
            ``use_ema``, ``local_rank`` and ``iter`` (for the ETA only).
        warmup_iter: Number of linear learning-rate warm-up steps.
        log_interval: Print averaged losses every this many iterations.
        pass_idx: Index into ``FLAGS.learning_rate`` when it is a list/tuple.
        pass_name: Unused; kept for interface compatibility.
        optimize_light: Whether light parameters are trained.
        optimize_geometry: Whether sdf/deform parameters are trained.

    Returns:
        ``(geometry, opt_material)`` after optimization.

    Raises:
        NotImplementedError: If geometry is optimized with ``FLAGS.use_ema`` set.
    """
    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    # FLAGS.learning_rate may be a scalar, a per-pass sequence, or a per-pass
    # sequence of (position, material) pairs.
    learning_rate = FLAGS.learning_rate
    if isinstance(learning_rate, (list, tuple)):
        learning_rate = learning_rate[pass_idx]
    if isinstance(learning_rate, (list, tuple)):
        learning_rate_pos, learning_rate_mat = learning_rate[0], learning_rate[1]
    else:
        learning_rate_pos = learning_rate_mat = learning_rate

    def lr_schedule(step):
        # Linear warm-up, then exponential falloff (x0.1 every 5000 steps).
        if step < warmup_iter:
            return step / warmup_iter
        return max(0.0, 10 ** (-(step - warmup_iter) * 0.0002))

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    # Single GPU training mode
    trainer = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    if optimize_geometry:
        optimizer_mesh = torch.optim.Adam([
            {'params': trainer.sdf_params, 'lr': learning_rate_pos},
            {'params': trainer.deform_params, 'lr': learning_rate_pos},
        ])
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lr_schedule)

    optimizer = torch.optim.Adam(trainer.params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator that restarts the loader when it is exhausted.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)
    v_iter_no = 10

    print("Start training loop...")
    sys.stdout.flush()

    # Advance to the v_iter_no-th view; that single view is used for every step.
    for _ in range(v_iter_no):
        v_curr = next(v_it)

    # NOTE(review): the iteration count is hard-coded to 5000 while the ETA
    # below is computed from FLAGS.iter -- confirm which one is intended.
    for it in range(5000):
        # Mix randomized background into dataset image
        target = prepare_batch(v_curr, 'random')

        ### for robustness, we take the easy way of initializing the tet grid with the gt depth image
        if it < 300 and it % 10 == 0:
            gt_visible_triangles = target['rast_triangle_id'].long()
            gt_verts, gt_faces = target['vpts'], target['faces']
            surface_faces = gt_faces[gt_visible_triangles]
            campos = target['campos'][0]
            # Best effort: geometries that cannot be (re-)initialised this way
            # are simply left untouched.
            try:
                geometry.init_with_gt_surface(gt_verts, surface_faces, campos)
            except Exception:
                pass

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Per-parameter gradient rescaling before the optimizer step.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            # Only normal maps that have an encoder with a populated gradient
            # are rescaled; anything else is skipped.
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except (AttributeError, TypeError):
                pass

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()
            geometry.clamp_deform()
            if FLAGS.use_ema:
                # EMA updates (geometry.update_ema) are not supported in this script.
                raise NotImplementedError

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)

        # Wait for the GPU so the measured iteration time is meaningful.
        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script flow: parse flags -> load reference mesh and datasets ->
    # pass 1 (DMTet, free topology) -> UV-map -> pass 2 (fixed topology) ->
    # label tet-grid vertices visible from the selected view and save them.

    def _str2bool(value):
        """Parse a command-line boolean (plain ``type=bool`` maps any non-empty string to True)."""
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results_singleview')
    parser.add_argument('--validate', type=_str2bool, default=True)
    parser.add_argument('-no', '--normal-only', type=_str2bool, default=True)
    parser.add_argument('-ema', '--use-ema', action="store_true")
    parser.add_argument('-rp', '--resume-path', type=str, default=None)
    parser.add_argument('-mp', '--mesh-path', type=str)
    parser.add_argument('-an', '--angle-ind', type=int, help='angle index from 0 to 50')

    FLAGS = parser.parse_args()
    print("parsed arguments")

    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.dmtet_grid          = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale          = 1.0                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.envmap              = None                     # HDR environment probe
    FLAGS.display             = None                     # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light  = False                    # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 10000.0                  # Weight for Laplacian regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0] # Limits for kd
    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]
    FLAGS.ks_min              = [ 0.0, 0.08,  0.0]       # Limits for ks
    FLAGS.ks_max              = [ 1.0,  1.0,  1.0]
    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]       # Limits for normal map
    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]
    FLAGS.cam_near_far        = [0.1, 1000.0]
    FLAGS.learn_light         = False
    FLAGS.use_ema             = False                    # NOTE: overrides the --use-ema command-line switch
    FLAGS.random_lgt          = True
    FLAGS.dataset_flat_shading = False

    FLAGS.local_rank = 0

    # Values from the JSON config override both the command line and the
    # defaults assigned above.
    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as config_file:
            data = json.load(config_file)
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # The single-view step below advances the validation loader `angle_ind`
    # times and uses the last batch, so at least one step is required. Fail
    # before training instead of after both optimization passes.
    if FLAGS.angle_ind is None or FLAGS.angle_ind < 1:
        parser.error('--angle-ind must be a positive integer (number of validation views to advance)')

    print(f"Out dir: {FLAGS.out_dir}")

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz_pre'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)

    print(f"Using dmtet grid of resolution {FLAGS.dmtet_grid}")

    glctx = dr.RasterizeGLContext()

    ### Default mtl
    mtl_default = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
        'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
    }

    print(f"Loading mesh: {FLAGS.mesh_path}")
    sys.stdout.flush()
    ref_mesh = mesh.load_mesh(FLAGS.mesh_path, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
    ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)

    print("Loading dataset")
    sys.stdout.flush()
    dataset_train    = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
    dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
    print("Dataset loaded")
    sys.stdout.flush()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    if FLAGS.learn_light:
        lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
    else:
        lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    geometry = DMTetGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

    print("Start optimization")
    sys.stdout.flush()

    if FLAGS.resume_path is None:
        # Run optimization
        geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                        FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)

        base_mesh = geometry.getMesh(mat)

        # 1 for the vertices reported by getValidVertsIdx(), 0 elsewhere; used
        # to zero out the deformation of all other vertices before saving.
        vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
        vert_mask[geometry.getValidVertsIdx()] = 1

        # Free temporaries / cached memory
        torch.cuda.empty_cache() ### may slow down training

        # NOTE(review): `global_index` is not defined in the visible part of
        # this script -- confirm it exists at module level.
        torch.save({
                'sdf': geometry.sdf.cpu().detach(),
                'sdf_ema': geometry.sdf_ema.cpu().detach(),
                'deform': (geometry.deform * vert_mask).cpu().detach(),
                'deform_unmasked': geometry.deform.cpu().detach(),
            }, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))

        old_geometry = geometry

        if FLAGS.local_rank == 0 and FLAGS.validate:
            # NOTE(review): `FLAGS.index`, `k` and `FLAGS.split_size` are not
            # defined in the visible part of this script -- confirm.
            validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz_pre/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)
    else:
        # Resume from a previously saved pass-1 result instead of optimizing.
        dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
        if FLAGS.use_ema:
            geometry.sdf.data[:] = dmt_dict['sdf_ema']
        else:
            geometry.sdf.data[:] = dmt_dict['sdf']
        geometry.deform.data[:] = dmt_dict['deform']
        old_geometry = geometry

    # Create textured mesh from result
    if FLAGS.normal_only:
        base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
    else:
        base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)

    # Switch to the fixed-topology geometry: only the deformation stays trainable.
    geometry = DMTetGeometryFixedTopo(geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)
    geometry.sdf_sign.requires_grad = False
    geometry.sdf_abs.requires_grad = False
    geometry.deform.requires_grad = True

    # Rescale stored deformations by 2/3 before setting deform_scale to 3.0.
    geometry.deform.data[:] = geometry.deform * 2.0 / 3.0
    geometry.deform_scale = 3.0

    if FLAGS.use_ema:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema) ### use ema
    else:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf) ### use the raw (non-ema) sdf
    geometry.set_init_v_pos()

    # ==============================================================================================
    #  Pass 2: Train with fixed topology (mesh)
    # ==============================================================================================
    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                optimize_geometry=not FLAGS.lock_pos)

    ##### Process single-view tet grid
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)
    v_it = iter(dataloader_validate)
    # Advance to the requested view; the last fetched batch is the one used.
    for _ in range(FLAGS.angle_ind):
        v_curr = next(v_it)
    target = prepare_batch(v_curr, 'random')

    # ==============================================================================================
    #  Infer occluded regions
    # ==============================================================================================
    valid_tet_idx = geometry.getValidTetIdx().long()
    buffers = geometry.render(glctx, target, lgt, mat, get_visible_tets=True)

    ## visible tets (except for rasterized ones)
    visible_tets = torch.zeros(geometry.indices.size(0)).cuda()
    visible_tets[buffers['visible_tet_id'].long()] = 1

    ## to include the rasterized tetrahedra
    visible_and_rast_tets = visible_tets.clone()
    rast_tet_id = valid_tet_idx[buffers['rast_triangle_id'].long()].unique()
    visible_and_rast_tets[rast_tet_id] = 1

    visible_tets = (visible_tets == 1)
    visible_and_rast_tets = (visible_and_rast_tets == 1)

    ## label all tetrahedral vertices associated with any visible tets
    # Allocate on the device of the index tensors so the indexed assignments
    # below never mix devices; the results are moved to the CPU when saved.
    visible_verts = torch.zeros(geometry.verts.size(0), device=geometry.indices.device)
    vis_vert_inds = geometry.indices[visible_tets].unique()
    visible_verts[vis_vert_inds] = 1

    visible_and_rast_verts = visible_verts.clone()
    vis_and_rast_vert_inds = geometry.indices[visible_and_rast_tets].unique()
    visible_and_rast_verts[vis_and_rast_vert_inds] = 1
    visible_and_rast_verts = visible_and_rast_verts.bool()

    # The output name has no placeholder, so no index is formatted into it.
    torch.save({
        'sdf': geometry.sdf_sign.cpu().detach(),
        'deform': geometry.deform.cpu().detach(),
        'vis': visible_verts.cpu().detach(),
        'vis_rast': visible_and_rast_verts.cpu().detach()
    }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'))

    # ==============================================================================================
    if FLAGS.local_rank == 0 and FLAGS.validate:
        validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "val_viz/dmtet"), FLAGS)