kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
first commit
rodzic
c7fe8274ef
commit
91b21d1cae
|
@ -17,7 +17,6 @@ dist/
|
|||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
"""Return training and evaluation/test datasets from config files."""
|
||||
import jax
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
|
||||
def get_data_scaler(config):
|
||||
"""Data normalizer. Assume data are always in [0, 1]."""
|
||||
if config.data.centered:
|
||||
# Rescale to [-1, 1]
|
||||
return lambda x: x * 2. - 1.
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def get_data_inverse_scaler(config):
|
||||
"""Inverse data normalizer."""
|
||||
if config.data.centered:
|
||||
# Rescale [-1, 1] to [0, 1]
|
||||
return lambda x: (x + 1.) / 2.
|
||||
else:
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def crop_resize(image, resolution):
|
||||
"""Crop and resize an image to the given resolution."""
|
||||
crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
|
||||
h, w = tf.shape(image)[0], tf.shape(image)[1]
|
||||
image = image[(h - crop) // 2:(h + crop) // 2,
|
||||
(w - crop) // 2:(w + crop) // 2]
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
size=(resolution, resolution),
|
||||
antialias=True,
|
||||
method=tf.image.ResizeMethod.BICUBIC)
|
||||
return tf.cast(image, tf.uint8)
|
||||
|
||||
|
||||
def resize_small(image, resolution):
|
||||
"""Shrink an image to the given resolution."""
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
ratio = resolution / min(h, w)
|
||||
h = tf.round(h * ratio, tf.int32)
|
||||
w = tf.round(w * ratio, tf.int32)
|
||||
return tf.image.resize(image, [h, w], antialias=True)
|
||||
|
||||
|
||||
def central_crop(image, size):
|
||||
"""Crop the center of an image to the given size."""
|
||||
top = (image.shape[0] - size) // 2
|
||||
left = (image.shape[1] - size) // 2
|
||||
return tf.image.crop_to_bounding_box(image, top, left, size, size)
|
||||
|
||||
|
||||
def get_dataset(config, uniform_dequantization=False, evaluation=False):
|
||||
"""Create data loaders for training and evaluation.
|
||||
|
||||
Args:
|
||||
config: A ml_collection.ConfigDict parsed from config files.
|
||||
uniform_dequantization: If `True`, add uniform dequantization to images.
|
||||
evaluation: If `True`, fix number of epochs to 1.
|
||||
|
||||
Returns:
|
||||
train_ds, eval_ds, dataset_builder.
|
||||
"""
|
||||
# Compute batch size for this worker.
|
||||
batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
|
||||
if batch_size % jax.device_count() != 0:
|
||||
raise ValueError(f'Batch sizes ({batch_size} must be divided by'
|
||||
f'the number of devices ({jax.device_count()})')
|
||||
|
||||
# Reduce this when image resolution is too large and data pointer is stored
|
||||
shuffle_buffer_size = 10000
|
||||
prefetch_size = tf.data.experimental.AUTOTUNE
|
||||
num_epochs = None if not evaluation else 1
|
||||
|
||||
# Create dataset builders for each dataset.
|
||||
if config.data.dataset == 'CIFAR10':
|
||||
dataset_builder = tfds.builder('cifar10')
|
||||
train_split_name = 'train'
|
||||
eval_split_name = 'test'
|
||||
|
||||
def resize_op(img):
|
||||
img = tf.image.convert_image_dtype(img, tf.float32)
|
||||
return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)
|
||||
|
||||
elif config.data.dataset == 'SVHN':
|
||||
dataset_builder = tfds.builder('svhn_cropped')
|
||||
train_split_name = 'train'
|
||||
eval_split_name = 'test'
|
||||
|
||||
def resize_op(img):
|
||||
img = tf.image.convert_image_dtype(img, tf.float32)
|
||||
return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)
|
||||
|
||||
elif config.data.dataset == 'CELEBA':
|
||||
dataset_builder = tfds.builder('celeb_a')
|
||||
train_split_name = 'train'
|
||||
eval_split_name = 'validation'
|
||||
|
||||
def resize_op(img):
|
||||
img = tf.image.convert_image_dtype(img, tf.float32)
|
||||
img = central_crop(img, 140)
|
||||
img = resize_small(img, config.data.image_size)
|
||||
return img
|
||||
|
||||
elif config.data.dataset == 'LSUN':
|
||||
dataset_builder = tfds.builder(f'lsun/{config.data.category}')
|
||||
train_split_name = 'train'
|
||||
eval_split_name = 'validation'
|
||||
|
||||
if config.data.image_size == 128:
|
||||
def resize_op(img):
|
||||
img = tf.image.convert_image_dtype(img, tf.float32)
|
||||
img = resize_small(img, config.data.image_size)
|
||||
img = central_crop(img, config.data.image_size)
|
||||
return img
|
||||
|
||||
else:
|
||||
def resize_op(img):
|
||||
img = crop_resize(img, config.data.image_size)
|
||||
img = tf.image.convert_image_dtype(img, tf.float32)
|
||||
return img
|
||||
|
||||
elif config.data.dataset in ['FFHQ', 'CelebAHQ']:
|
||||
dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path)
|
||||
train_split_name = eval_split_name = 'train'
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Dataset {config.data.dataset} not yet supported.')
|
||||
|
||||
# Customize preprocess functions for each dataset.
|
||||
if config.data.dataset in ['FFHQ', 'CelebAHQ']:
|
||||
def preprocess_fn(d):
|
||||
sample = tf.io.parse_single_example(d, features={
|
||||
'shape': tf.io.FixedLenFeature([3], tf.int64),
|
||||
'data': tf.io.FixedLenFeature([], tf.string)})
|
||||
data = tf.io.decode_raw(sample['data'], tf.uint8)
|
||||
data = tf.reshape(data, sample['shape'])
|
||||
data = tf.transpose(data, (1, 2, 0))
|
||||
img = tf.image.convert_image_dtype(data, tf.float32)
|
||||
if config.data.random_flip and not evaluation:
|
||||
img = tf.image.random_flip_left_right(img)
|
||||
if uniform_dequantization:
|
||||
img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
|
||||
return dict(image=img, label=None)
|
||||
|
||||
else:
|
||||
def preprocess_fn(d):
|
||||
"""Basic preprocessing function scales data to [0, 1) and randomly flips."""
|
||||
img = resize_op(d['image'])
|
||||
if config.data.random_flip and not evaluation:
|
||||
img = tf.image.random_flip_left_right(img)
|
||||
if uniform_dequantization:
|
||||
img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
|
||||
|
||||
return dict(image=img, label=d.get('label', None))
|
||||
|
||||
def create_dataset(dataset_builder, split):
|
||||
dataset_options = tf.data.Options()
|
||||
dataset_options.experimental_optimization.map_parallelization = True
|
||||
dataset_options.experimental_threading.private_threadpool_size = 48
|
||||
dataset_options.experimental_threading.max_intra_op_parallelism = 1
|
||||
read_config = tfds.ReadConfig(options=dataset_options)
|
||||
if isinstance(dataset_builder, tfds.core.DatasetBuilder):
|
||||
dataset_builder.download_and_prepare()
|
||||
ds = dataset_builder.as_dataset(
|
||||
split=split, shuffle_files=True, read_config=read_config)
|
||||
else:
|
||||
ds = dataset_builder.with_options(dataset_options)
|
||||
ds = ds.repeat(count=num_epochs)
|
||||
ds = ds.shuffle(shuffle_buffer_size)
|
||||
ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds.prefetch(prefetch_size)
|
||||
|
||||
train_ds = create_dataset(dataset_builder, train_split_name)
|
||||
eval_ds = create_dataset(dataset_builder, eval_split_name)
|
||||
return train_ds, eval_ds, dataset_builder
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
|
||||
import argparse
|
||||
|
||||
class ShapeNetDMTetDataset(Dataset):
|
||||
def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
|
||||
super().__init__()
|
||||
self.fpath_list = json.load(open(root, 'r'))
|
||||
self.deform_scale = deform_scale
|
||||
self.normalize_sdf = normalize_sdf
|
||||
print(f"dataset with sdf normalized: {normalize_sdf}")
|
||||
self.coeff = torch.tensor([1.0, 1.0, self.deform_scale, self.deform_scale, self.deform_scale]).view(-1, 1, 1, 1)
|
||||
self.aug = aug
|
||||
self.grid_mask = grid_mask.cpu()
|
||||
self.resolution = self.grid_mask.size(-1)
|
||||
|
||||
if filter_meta_path is not None:
|
||||
self.filter_ids = json.load(open(filter_meta_path, 'r'))
|
||||
full_id_list = [int(x.rstrip().split('_')[-1][:-3]) for i, x in enumerate(self.fpath_list)]
|
||||
fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids]
|
||||
self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.fpath_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with torch.no_grad():
|
||||
datum = torch.load(self.fpath_list[idx], map_location='cpu')
|
||||
if self.normalize_sdf:
|
||||
sdf_sign = torch.sign(datum[:, :1])
|
||||
sdf_sign[sdf_sign == 0] = 1.0
|
||||
datum[:, :1] = sdf_sign
|
||||
if self.aug:
|
||||
nonempty_mask = (datum[1:].abs().sum(dim=0, keepdim=True) != 0)
|
||||
datum[1:] = datum[1:] + (torch.rand(3)[:, None, None, None] - 0.5) * 0.01 * nonempty_mask / (datum.size(-1) / self.resolution)
|
||||
|
||||
if datum.size(-1) < self.resolution:
|
||||
datum = datum * self.grid_mask[0, :, :datum.size(-1), :datum.size(-1), :datum.size(-1)]
|
||||
else:
|
||||
datum = datum * self.grid_mask[0]
|
||||
|
||||
if datum.size(-1) < self.resolution:
|
||||
diff = self.resolution - datum.size(-1)
|
||||
datum = torch.nn.functional.pad(datum, (0, diff, 0, diff, 0, diff, 0, 0))
|
||||
return datum
|
|
@ -0,0 +1,212 @@
|
|||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
from . import losses
|
||||
from .models import utils as mutils
|
||||
from .models.ema import ExponentialMovingAverage
|
||||
from . import sde_lib
|
||||
import torch
|
||||
from .utils import restore_checkpoint
|
||||
from . import sampling
|
||||
|
||||
def uncond_gen(
|
||||
config,
|
||||
idx=0,
|
||||
):
|
||||
"""
|
||||
Unconditional Generation
|
||||
"""
|
||||
with torch.no_grad():
|
||||
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
|
||||
# Create directory to eval_folder
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
|
||||
scaler, inverse_scaler = lambda x: x, lambda x: x
|
||||
|
||||
# Initialize model
|
||||
score_model = mutils.create_model(config)
|
||||
optimizer = losses.get_optimizer(config, score_model.parameters())
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
|
||||
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
|
||||
|
||||
# Setup SDEs
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
|
||||
img_size = config.data.image_size
|
||||
grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda")
|
||||
|
||||
sampling_eps = 1e-3
|
||||
sampling_shape = (config.eval.batch_size,
|
||||
config.data.num_channels,
|
||||
config.data.image_size, config.data.image_size, config.data.image_size)
|
||||
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
|
||||
|
||||
assert os.path.exists(ckpt_path)
|
||||
print('ckpt path:', ckpt_path)
|
||||
try:
|
||||
state = restore_checkpoint(ckpt_path, state, device=config.device)
|
||||
except:
|
||||
raise
|
||||
ema.copy_to(score_model.parameters())
|
||||
|
||||
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
|
||||
save_file_path = os.path.join(eval_dir, f"{idx}.npy")
|
||||
|
||||
|
||||
samples, n = sampling_fn(score_model)
|
||||
samples = samples.cpu().numpy()
|
||||
np.save(save_file_path, samples)
|
||||
|
||||
|
||||
def slerp(z1, z2, alpha):
|
||||
'''
|
||||
Spherical Linear Interpolation
|
||||
'''
|
||||
theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
|
||||
return (
|
||||
torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
|
||||
+ torch.sin(alpha * theta) / torch.sin(theta) * z2
|
||||
)
|
||||
|
||||
def uncond_gen_interp(
|
||||
config,
|
||||
idx=0,
|
||||
):
|
||||
"""
|
||||
Generation with interpolation between initial noises
|
||||
Used for DDIM
|
||||
"""
|
||||
with torch.no_grad():
|
||||
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
|
||||
# Create directory to eval_folder
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
|
||||
scaler, inverse_scaler = lambda x: x, lambda x: x
|
||||
|
||||
# Initialize model
|
||||
score_model = mutils.create_model(config)
|
||||
optimizer = losses.get_optimizer(config, score_model.parameters())
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
|
||||
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
|
||||
|
||||
# Setup SDEs
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
|
||||
img_size = config.data.image_size
|
||||
grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda")
|
||||
|
||||
sampling_eps = 1e-3
|
||||
sampling_shape = (config.eval.batch_size,
|
||||
config.data.num_channels,
|
||||
config.data.image_size, config.data.image_size, config.data.image_size)
|
||||
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
|
||||
|
||||
assert os.path.exists(ckpt_path)
|
||||
print('ckpt path:', ckpt_path)
|
||||
try:
|
||||
state = restore_checkpoint(ckpt_path, state, device=config.device)
|
||||
except:
|
||||
raise
|
||||
ema.copy_to(score_model.parameters())
|
||||
|
||||
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
|
||||
save_file_path = os.path.join(eval_dir, f"{idx}.npy")
|
||||
|
||||
|
||||
noise = sde.prior_sampling(
|
||||
(2, config.data.num_channels, config.data.image_size, config.data.image_size, config.data.image_size)
|
||||
).to(config.device)
|
||||
|
||||
|
||||
x0 = torch.zeros(sampling_shape, device=config.device)
|
||||
x0[0] = noise[0]
|
||||
x0[-1] = noise[1]
|
||||
for i in range(1, batch_size - 1):
|
||||
x[i] = slerp(x[0], x[-1], i / float(batch_size - 1))
|
||||
|
||||
samples, n = sampling_fn(score_model, x0=x0)
|
||||
samples = samples.cpu().numpy()
|
||||
np.save(save_file_path, samples)
|
||||
|
||||
|
||||
def cond_gen(
|
||||
config,
|
||||
save_fname='0',
|
||||
):
|
||||
"""
|
||||
Conditional Generation with partially completed dmtet from a 2.5D view (converted into a cubic grid)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
|
||||
# Create directory to eval_folder
|
||||
os.makedirs(eval_dir, exist_ok=True)
|
||||
|
||||
scaler, inverse_scaler = lambda x: x, lambda x: x
|
||||
|
||||
# Initialize model
|
||||
score_model = mutils.create_model(config)
|
||||
optimizer = losses.get_optimizer(config, score_model.parameters())
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
|
||||
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
|
||||
|
||||
# Setup SDEs
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
|
||||
resolution = config.data.image_size
|
||||
grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda")
|
||||
|
||||
sampling_eps = 1e-3
|
||||
sampling_shape = (config.eval.batch_size,
|
||||
config.data.num_channels,
|
||||
resolution, resolution, resolution)
|
||||
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
|
||||
|
||||
assert os.path.exists(ckpt_path)
|
||||
print('ckpt path:', ckpt_path)
|
||||
try:
|
||||
state = restore_checkpoint(ckpt_path, state, device=config.device)
|
||||
except:
|
||||
raise
|
||||
ema.copy_to(score_model.parameters())
|
||||
|
||||
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
|
||||
|
||||
|
||||
save_file_path = os.path.join(eval_dir, f"{save_fname}.npy")
|
||||
|
||||
### Conditional but free gradients; start from small t
|
||||
|
||||
partial_dict = torch.load(config.eval.partial_dmtet_path)
|
||||
partial_sdf = partial_dict['sdf']
|
||||
partial_mask = partial_dict['vis']
|
||||
|
||||
|
||||
### compute the mapping from tet indices to 3D cubic grid vertex indices
|
||||
tet_path = config.eval.tet_path
|
||||
tet = np.load(tet_path)
|
||||
vertices = torch.tensor(tet['vertices'])
|
||||
vertices_unique = vertices[:].unique()
|
||||
dx = vertices_unique[1] - vertices_unique[0]
|
||||
|
||||
ind_to_coord = (torch.round(
|
||||
(vertices - vertices.min()) / dx)
|
||||
).long()
|
||||
|
||||
|
||||
partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution))
|
||||
partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf
|
||||
partial_mask_grid = torch.zeros((1, 1, resolution, resolution, resolution))
|
||||
partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float()
|
||||
|
||||
samples, n = sampling_fn(
|
||||
score_model,
|
||||
partial=partial_sdf_grid.cuda(),
|
||||
partial_mask=partial_mask_grid.cuda(),
|
||||
freeze_iters=config.eval.freeze_iters
|
||||
)
|
||||
|
||||
samples = samples.cpu().numpy()
|
||||
np.save(save_file_path, samples)
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pytype: skip-file
|
||||
"""Various sampling methods."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy import integrate
|
||||
from .models import utils as mutils
|
||||
|
||||
|
||||
def get_div_fn(fn):
|
||||
"""Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""
|
||||
|
||||
def div_fn(x, t, eps):
|
||||
with torch.enable_grad():
|
||||
x.requires_grad_(True)
|
||||
fn_eps = torch.sum(fn(x, t) * eps)
|
||||
grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
|
||||
x.requires_grad_(False)
|
||||
return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))
|
||||
|
||||
return div_fn
|
||||
|
||||
|
||||
def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
|
||||
rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
|
||||
"""Create a function to compute the unbiased log-likelihood estimate of a given data point.
|
||||
|
||||
Args:
|
||||
sde: A `sde_lib.SDE` object that represents the forward SDE.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
|
||||
rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
|
||||
atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
|
||||
method: A `str`. The algorithm for the black-box ODE solver.
|
||||
See documentation for `scipy.integrate.solve_ivp`.
|
||||
eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.
|
||||
|
||||
Returns:
|
||||
A function that a batch of data points and returns the log-likelihoods in bits/dim,
|
||||
the latent code, and the number of function evaluations cost by computation.
|
||||
"""
|
||||
|
||||
def drift_fn(model, x, t):
|
||||
"""The drift function of the reverse-time SDE."""
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
|
||||
# Probability flow ODE is a special case of Reverse SDE
|
||||
rsde = sde.reverse(score_fn, probability_flow=True)
|
||||
return rsde.sde(x, t)[0]
|
||||
|
||||
def div_fn(model, x, t, noise):
|
||||
return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)
|
||||
|
||||
def likelihood_fn(model, data):
|
||||
"""Compute an unbiased estimate to the log-likelihood in bits/dim.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
data: A PyTorch tensor.
|
||||
|
||||
Returns:
|
||||
bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
|
||||
z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
|
||||
probability flow ODE.
|
||||
nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
shape = data.shape
|
||||
if hutchinson_type == 'Gaussian':
|
||||
epsilon = torch.randn_like(data)
|
||||
elif hutchinson_type == 'Rademacher':
|
||||
epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
|
||||
else:
|
||||
raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")
|
||||
|
||||
def ode_func(t, x):
|
||||
sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
|
||||
vec_t = torch.ones(sample.shape[0], device=sample.device) * t
|
||||
drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
|
||||
logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
|
||||
return np.concatenate([drift, logp_grad], axis=0)
|
||||
|
||||
init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
|
||||
solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
|
||||
nfe = solution.nfev
|
||||
zp = solution.y[:, -1]
|
||||
z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
|
||||
delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
|
||||
prior_logp = sde.prior_logp(z)
|
||||
bpd = -(prior_logp + delta_logp) / np.log(2)
|
||||
N = np.prod(shape[1:])
|
||||
bpd = bpd / N
|
||||
# A hack to convert log-likelihoods to bits/dim
|
||||
offset = 7. - inverse_scaler(-1.)
|
||||
bpd = bpd + offset
|
||||
return bpd, z, nfe
|
||||
|
||||
return likelihood_fn
|
|
@ -0,0 +1,141 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""All functions related to loss computation and optimization.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from .models import utils as mutils
|
||||
from .sde_lib import VPSDE
|
||||
|
||||
|
||||
def get_optimizer(config, params):
|
||||
"""Returns a flax optimizer object based on `config`."""
|
||||
if config.optim.optimizer == 'Adam':
|
||||
optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
|
||||
weight_decay=config.optim.weight_decay)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Optimizer {config.optim.optimizer} not supported yet!')
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def optimization_manager(config):
|
||||
"""Returns an optimize_fn based on `config`."""
|
||||
|
||||
def optimize_fn(optimizer, params, step, lr=config.optim.lr,
|
||||
warmup=config.optim.warmup,
|
||||
grad_clip=config.optim.grad_clip):
|
||||
"""Optimizes with warmup and gradient clipping (disabled if negative)."""
|
||||
if warmup > 0:
|
||||
for g in optimizer.param_groups:
|
||||
g['lr'] = lr * np.minimum(step / warmup, 1.0)
|
||||
if grad_clip >= 0:
|
||||
torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
return optimize_fn
|
||||
|
||||
def get_ddpm_loss_fn(vpsde, train, mask=None, loss_type='l2'):
|
||||
"""Legacy code to reproduce previous results on DDPM. Not recommended for new work."""
|
||||
|
||||
def loss_fn(model, batch):
|
||||
model_fn = mutils.get_model_fn(model, train=train)
|
||||
labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
|
||||
sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
|
||||
sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
|
||||
noise = torch.randn_like(batch)
|
||||
perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \
|
||||
sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise
|
||||
perturbed_data = perturbed_data * mask
|
||||
score = model_fn(perturbed_data, labels)
|
||||
|
||||
if loss_type == 'l2':
|
||||
losses = torch.square(score - noise)
|
||||
elif loss_type == 'l1':
|
||||
losses = torch.abs(score - noise)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if mask is not None:
|
||||
losses = losses * mask
|
||||
losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)
|
||||
loss = torch.mean(losses) / mask.sum() * np.prod(mask.size())
|
||||
else:
|
||||
losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)
|
||||
loss = torch.mean(losses)
|
||||
|
||||
return loss
|
||||
|
||||
return loss_fn
|
||||
|
||||
def get_step_fn(sde, train, optimize_fn=None, mask=None, loss_type='l2'):
|
||||
"""Create a one-step training/evaluation function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
optimize_fn: An optimization function.
|
||||
reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
|
||||
continuous: `True` indicates that the model is defined to take continuous time steps.
|
||||
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
|
||||
https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.
|
||||
|
||||
Returns:
|
||||
A one-step function for training or evaluation.
|
||||
"""
|
||||
|
||||
loss_fn = get_ddpm_loss_fn(sde, train, mask=mask, loss_type=loss_type)
|
||||
|
||||
def step_fn(state, batch, clear_grad=True, update_param=True):
|
||||
"""Running one step of training or evaluation.
|
||||
|
||||
This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
|
||||
for faster execution.
|
||||
|
||||
Args:
|
||||
state: A dictionary of training information, containing the score model, optimizer,
|
||||
EMA status, and number of optimization steps.
|
||||
batch: A mini-batch of training/evaluation data.
|
||||
|
||||
Returns:
|
||||
loss: The average loss value of this state.
|
||||
"""
|
||||
model = state['model']
|
||||
if train:
|
||||
optimizer = state['optimizer']
|
||||
if clear_grad:
|
||||
optimizer.zero_grad()
|
||||
loss = loss_fn(model, batch)
|
||||
loss.backward()
|
||||
if update_param:
|
||||
optimize_fn(optimizer, model.parameters(), step=state['step'])
|
||||
state['step'] += 1
|
||||
state['ema'].update(model.parameters())
|
||||
else:
|
||||
with torch.no_grad():
|
||||
ema = state['ema']
|
||||
ema.store(model.parameters())
|
||||
ema.copy_to(model.parameters())
|
||||
loss = loss_fn(model, batch)
|
||||
ema.restore(model.parameters())
|
||||
|
||||
return {
|
||||
'loss': loss,
|
||||
}
|
||||
|
||||
return step_fn
|
|
@ -0,0 +1,15 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
"""DDPM model.
|
||||
|
||||
This code is the pytorch equivalent of:
|
||||
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
from . import utils, layers, normalization
|
||||
|
||||
# RefineBlock = layers.RefineBlock
|
||||
# ResidualBlock = layers.ResidualBlock
|
||||
ResnetBlockDDPM = layers.ResnetBlockDDPM
|
||||
Upsample = layers.Upsample
|
||||
Downsample = layers.Downsample
|
||||
conv3x3 = layers.ddpm_conv3x3
|
||||
conv5x5 = layers.ddpm_conv5x5
|
||||
transposed_conv6x6 = layers.ddpm_conv6x6_transposed
|
||||
get_act = layers.get_act
|
||||
get_normalization = normalization.get_normalization
|
||||
default_initializer = layers.default_init
|
||||
|
||||
@utils.register_model(name='ddpm_res128')
|
||||
class DDPMRes128(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.act = act = get_act(config)
|
||||
self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
|
||||
|
||||
self.nf = nf = config.model.nf
|
||||
ch_mult = config.model.ch_mult
|
||||
self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
|
||||
dropout = config.model.dropout
|
||||
resamp_with_conv = config.model.resamp_with_conv
|
||||
self.num_resolutions = num_resolutions = len(ch_mult)
|
||||
self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] ## manual for 128 to 64
|
||||
|
||||
AttnBlock = functools.partial(layers.AttnBlock)
|
||||
self.conditional = conditional = config.model.conditional
|
||||
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
|
||||
if conditional:
|
||||
# Condition on noise levels.
|
||||
modules = [nn.Linear(nf, nf * 4)]
|
||||
modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
|
||||
nn.init.zeros_(modules[0].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
|
||||
nn.init.zeros_(modules[1].bias)
|
||||
|
||||
self.centered = config.data.centered
|
||||
channels = config.data.num_channels
|
||||
|
||||
|
||||
##### Pos Encoding
|
||||
self.img_size = img_size = config.data.image_size
|
||||
self.num_freq = int(np.log2(img_size))
|
||||
coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
|
||||
self.use_coords = False
|
||||
if self.use_coords:
|
||||
self.coords = torch.nn.Parameter(
|
||||
# torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0,
|
||||
torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size),
|
||||
requires_grad=False
|
||||
)
|
||||
####
|
||||
|
||||
### Mask
|
||||
self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False)
|
||||
|
||||
# Downsampling block
|
||||
self.pos_layer = conv5x5(3, nf, stride=1, padding=2)
|
||||
self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
|
||||
modules.append(conv5x5(channels, nf, stride=1, padding=2))
|
||||
hs_c = [nf]
|
||||
in_ch = nf
|
||||
|
||||
|
||||
for i_level in range(num_resolutions):
|
||||
num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks_curr):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
in_ch = out_ch
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
hs_c.append(in_ch)
|
||||
if i_level != num_resolutions - 1:
|
||||
modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
|
||||
hs_c.append(in_ch)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(num_resolutions)):
|
||||
num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
|
||||
for i_block in range(num_res_blocks_curr + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
in_ch = out_ch
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
if i_level != 0:
|
||||
modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
|
||||
|
||||
assert not hs_c
|
||||
modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
|
||||
# modules.append(conv3x3(in_ch, channels, init_scale=0.))
|
||||
# modules.append(transposed_conv6x6(in_ch, channels, init_scale=0.))
|
||||
modules.append(conv5x5(in_ch, channels, init_scale=0., stride=1, padding=2))
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
self.scale_by_sigma = config.model.scale_by_sigma
|
||||
|
||||
def forward(self, x, labels):
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
if self.conditional:
|
||||
# timestep/scale embedding
|
||||
timesteps = labels
|
||||
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
||||
temb = modules[m_idx](temb)
|
||||
m_idx += 1
|
||||
temb = modules[m_idx](self.act(temb))
|
||||
m_idx += 1
|
||||
else:
|
||||
temb = None
|
||||
|
||||
if self.centered:
|
||||
# Input is in [-1, 1]
|
||||
h = x
|
||||
else:
|
||||
# Input is in [0, 1]
|
||||
h = 2 * x - 1.
|
||||
|
||||
# Downsampling block
|
||||
if self.use_coords:
|
||||
hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)]
|
||||
else:
|
||||
hs = [modules[m_idx](h) + self.mask_layer(self.mask)]
|
||||
m_idx += 1
|
||||
for i_level in range(self.num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
num_res_blocks = self.num_res_blocks if i_level != 0 else 2
|
||||
for i_block in range(num_res_blocks):
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(modules[m_idx](hs[-1]))
|
||||
m_idx += 1
|
||||
|
||||
h = hs[-1]
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
num_res_blocks = self.num_res_blocks if i_level != 0 else 2
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
hspop = hs.pop()
|
||||
input = torch.cat([h, hspop], dim=1)
|
||||
h = modules[m_idx](input, temb)
|
||||
m_idx += 1
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
if i_level != 0:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
assert not hs
|
||||
h = self.act(modules[m_idx](h))
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
assert m_idx == len(modules)
|
||||
|
||||
if self.scale_by_sigma:
|
||||
# Divide the output by sigmas. Useful for training with the NCSN loss.
|
||||
# The DDPM loss scales the network output by sigma in the loss function,
|
||||
# so no need of doing it here.
|
||||
used_sigmas = self.sigmas[labels, None, None, None]
|
||||
h = h / used_sigmas
|
||||
|
||||
return h
|
|
@ -0,0 +1,199 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
"""DDPM model.
|
||||
|
||||
This code is the pytorch equivalent of:
|
||||
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
from . import utils, layers, normalization
|
||||
|
||||
# RefineBlock = layers.RefineBlock
|
||||
# ResidualBlock = layers.ResidualBlock
|
||||
ResnetBlockDDPM = layers.ResnetBlockDDPM
|
||||
Upsample = layers.Upsample
|
||||
Downsample = layers.Downsample
|
||||
conv3x3 = layers.ddpm_conv3x3
|
||||
get_act = layers.get_act
|
||||
get_normalization = normalization.get_normalization
|
||||
default_initializer = layers.default_init
|
||||
|
||||
@utils.register_model(name='ddpm_res64')
|
||||
class DDPMRes64(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.act = act = get_act(config)
|
||||
self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
|
||||
|
||||
self.nf = nf = config.model.nf
|
||||
ch_mult = config.model.ch_mult
|
||||
self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
|
||||
dropout = config.model.dropout
|
||||
resamp_with_conv = config.model.resamp_with_conv
|
||||
self.num_resolutions = num_resolutions = len(ch_mult)
|
||||
self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]
|
||||
|
||||
AttnBlock = functools.partial(layers.AttnBlock)
|
||||
self.conditional = conditional = config.model.conditional
|
||||
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
|
||||
if conditional:
|
||||
# Condition on noise levels.
|
||||
modules = [nn.Linear(nf, nf * 4)]
|
||||
modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
|
||||
nn.init.zeros_(modules[0].bias)
|
||||
modules.append(nn.Linear(nf * 4, nf * 4))
|
||||
modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
|
||||
nn.init.zeros_(modules[1].bias)
|
||||
|
||||
self.centered = config.data.centered
|
||||
channels = config.data.num_channels
|
||||
|
||||
|
||||
##### Pos Encoding
|
||||
self.img_size = img_size = config.data.image_size
|
||||
self.num_freq = int(np.log2(img_size))
|
||||
coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
|
||||
self.coords = torch.nn.Parameter(
|
||||
torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0,
|
||||
requires_grad=False
|
||||
)
|
||||
####
|
||||
|
||||
### Mask
|
||||
self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False)
|
||||
|
||||
# Downsampling block
|
||||
self.pos_layer = conv3x3(3, nf)
|
||||
self.mask_layer = conv3x3(1, nf)
|
||||
modules.append(conv3x3(channels, nf))
|
||||
hs_c = [nf]
|
||||
in_ch = nf
|
||||
for i_level in range(num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
in_ch = out_ch
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
hs_c.append(in_ch)
|
||||
if i_level != num_resolutions - 1:
|
||||
modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
|
||||
hs_c.append(in_ch)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
in_ch = out_ch
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
if i_level != 0:
|
||||
modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
|
||||
|
||||
assert not hs_c
|
||||
modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
|
||||
modules.append(conv3x3(in_ch, channels, init_scale=0.))
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
self.scale_by_sigma = config.model.scale_by_sigma
|
||||
|
||||
def forward(self, x, labels):
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
if self.conditional:
|
||||
# timestep/scale embedding
|
||||
timesteps = labels
|
||||
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
||||
temb = modules[m_idx](temb)
|
||||
m_idx += 1
|
||||
temb = modules[m_idx](self.act(temb))
|
||||
m_idx += 1
|
||||
else:
|
||||
temb = None
|
||||
|
||||
if self.centered:
|
||||
# Input is in [-1, 1]
|
||||
h = x
|
||||
else:
|
||||
# Input is in [0, 1]
|
||||
h = 2 * x - 1.
|
||||
|
||||
# Downsampling block
|
||||
hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)]
|
||||
m_idx += 1
|
||||
for i_level in range(self.num_resolutions):
|
||||
# Residual blocks for this resolution
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = modules[m_idx](hs[-1], temb)
|
||||
m_idx += 1
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(modules[m_idx](hs[-1]))
|
||||
m_idx += 1
|
||||
|
||||
h = hs[-1]
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hspop = hs.pop()
|
||||
input = torch.cat([h, hspop], dim=1)
|
||||
h = modules[m_idx](input, temb)
|
||||
m_idx += 1
|
||||
if h.shape[-1] in self.attn_resolutions:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
if i_level != 0:
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
|
||||
assert not hs
|
||||
h = self.act(modules[m_idx](h))
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
assert m_idx == len(modules)
|
||||
|
||||
if self.scale_by_sigma:
|
||||
# Divide the output by sigmas. Useful for training with the NCSN loss.
|
||||
# The DDPM loss scales the network output by sigma in the loss function,
|
||||
# so no need of doing it here.
|
||||
used_sigmas = self.sigmas[labels, None, None, None]
|
||||
h = h / used_sigmas
|
||||
|
||||
return h
|
|
@ -0,0 +1,98 @@
|
|||
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
|
||||
class ExponentialMovingAverage:
|
||||
"""
|
||||
Maintains (exponential) moving average of a set of parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, decay, use_num_updates=True):
|
||||
"""
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the result of
|
||||
`model.parameters()`.
|
||||
decay: The exponential decay.
|
||||
use_num_updates: Whether to use number of updates when computing
|
||||
averages.
|
||||
"""
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
self.decay = decay
|
||||
self.num_updates = 0 if use_num_updates else None
|
||||
self.shadow_params = [p.clone().detach()
|
||||
for p in parameters if p.requires_grad]
|
||||
self.collected_params = []
|
||||
|
||||
def update(self, parameters):
|
||||
"""
|
||||
Update currently maintained parameters.
|
||||
|
||||
Call this every time the parameters are updated, such as the result of
|
||||
the `optimizer.step()` call.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
|
||||
parameters used to initialize this object.
|
||||
"""
|
||||
decay = self.decay
|
||||
if self.num_updates is not None:
|
||||
self.num_updates += 1
|
||||
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
one_minus_decay = 1.0 - decay
|
||||
with torch.no_grad():
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
|
||||
def copy_to(self, parameters):
|
||||
"""
|
||||
Copy current parameters into given collection of parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored moving averages.
|
||||
"""
|
||||
parameters = [p for p in parameters if p.requires_grad]
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
param.data.copy_(s_param.data)
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
def state_dict(self):
|
||||
return dict(decay=self.decay, num_updates=self.num_updates,
|
||||
shadow_params=self.shadow_params)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.decay = state_dict['decay']
|
||||
self.num_updates = state_dict['num_updates']
|
||||
self.shadow_params = state_dict['shadow_params']
|
|
@ -0,0 +1,771 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
"""Common layers for defining score networks.
|
||||
"""
|
||||
import math
|
||||
import string
|
||||
from functools import partial
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .normalization import ConditionalInstanceNorm3dPlus
|
||||
|
||||
|
||||
def get_act(config):
|
||||
"""Get activation functions from the config file."""
|
||||
|
||||
if config.model.nonlinearity.lower() == 'elu':
|
||||
return nn.ELU()
|
||||
elif config.model.nonlinearity.lower() == 'relu':
|
||||
return nn.ReLU()
|
||||
elif config.model.nonlinearity.lower() == 'lrelu':
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif config.model.nonlinearity.lower() == 'swish':
|
||||
return nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError('activation function does not exist!')
|
||||
|
||||
|
||||
def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
|
||||
"""1x1 convolution. Same as NCSNv1/v2."""
|
||||
conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
|
||||
padding=padding)
|
||||
init_scale = 1e-10 if init_scale == 0 else init_scale
|
||||
conv.weight.data *= init_scale
|
||||
conv.bias.data *= init_scale
|
||||
return conv
|
||||
|
||||
|
||||
def variance_scaling(scale, mode, distribution,
|
||||
in_axis=1, out_axis=0,
|
||||
dtype=torch.float32,
|
||||
device='cpu'):
|
||||
"""Ported from JAX. """
|
||||
|
||||
def _compute_fans(shape, in_axis=1, out_axis=0):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
fan_in = shape[in_axis] * receptive_field_size
|
||||
fan_out = shape[out_axis] * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
def init(shape, dtype=dtype, device=device):
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
||||
if mode == "fan_in":
|
||||
denominator = fan_in
|
||||
elif mode == "fan_out":
|
||||
denominator = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denominator = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError(
|
||||
"invalid mode for variance scaling initializer: {}".format(mode))
|
||||
variance = scale / denominator
|
||||
if distribution == "normal":
|
||||
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
|
||||
else:
|
||||
raise ValueError("invalid distribution for variance scaling initializer")
|
||||
|
||||
return init
|
||||
|
||||
|
||||
def default_init(scale=1.):
|
||||
"""The same initialization used in DDPM."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
return variance_scaling(scale, 'fan_avg', 'uniform')
|
||||
|
||||
|
||||
class Dense(nn.Module):
|
||||
"""Linear layer with `default_init`."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
|
||||
"""1x1 convolution with DDPM initialization."""
|
||||
conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
|
||||
"""3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
|
||||
init_scale = 1e-10 if init_scale == 0 else init_scale
|
||||
conv = nn.Conv3d(in_planes, out_planes, stride=stride, bias=bias,
|
||||
dilation=dilation, padding=padding, kernel_size=3)
|
||||
conv.weight.data *= init_scale
|
||||
conv.bias.data *= init_scale
|
||||
return conv
|
||||
|
||||
|
||||
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
def ddpm_conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def ddpm_conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, output_padding=(0, 1))
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
def ddpm_conv6x6_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
|
||||
"""3x3 convolution with DDPM initialization."""
|
||||
conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=6, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias)
|
||||
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
|
||||
# https://github.com/ermongroup/ncsn
|
||||
# https://github.com/ermongroup/ncsnv2
|
||||
###########################################################################
|
||||
|
||||
|
||||
class CRPBlock(nn.Module):
|
||||
def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(n_stages):
|
||||
self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
|
||||
self.n_stages = n_stages
|
||||
if maxpool:
|
||||
self.pool = nn.MaxPool3d(kernel_size=5, stride=1, padding=2)
|
||||
else:
|
||||
self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2)
|
||||
|
||||
self.act = act
|
||||
|
||||
def forward(self, x):
|
||||
x = self.act(x)
|
||||
path = x
|
||||
for i in range(self.n_stages):
|
||||
path = self.pool(path)
|
||||
path = self.convs[i](path)
|
||||
x = path + x
|
||||
return x
|
||||
|
||||
|
||||
class CondCRPBlock(nn.Module):
|
||||
def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList()
|
||||
self.norms = nn.ModuleList()
|
||||
self.normalizer = normalizer
|
||||
for i in range(n_stages):
|
||||
self.norms.append(normalizer(features, num_classes, bias=True))
|
||||
self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
|
||||
|
||||
self.n_stages = n_stages
|
||||
self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2)
|
||||
self.act = act
|
||||
|
||||
def forward(self, x, y):
|
||||
x = self.act(x)
|
||||
path = x
|
||||
for i in range(self.n_stages):
|
||||
path = self.norms[i](path, y)
|
||||
path = self.pool(path)
|
||||
path = self.convs[i](path)
|
||||
|
||||
x = path + x
|
||||
return x
|
||||
|
||||
|
||||
class RCUBlock(nn.Module):
|
||||
def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
|
||||
super().__init__()
|
||||
|
||||
for i in range(n_blocks):
|
||||
for j in range(n_stages):
|
||||
setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
|
||||
|
||||
self.stride = 1
|
||||
self.n_blocks = n_blocks
|
||||
self.n_stages = n_stages
|
||||
self.act = act
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.n_blocks):
|
||||
residual = x
|
||||
for j in range(self.n_stages):
|
||||
x = self.act(x)
|
||||
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
|
||||
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class CondRCUBlock(nn.Module):
|
||||
def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
|
||||
super().__init__()
|
||||
|
||||
for i in range(n_blocks):
|
||||
for j in range(n_stages):
|
||||
setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
|
||||
setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
|
||||
|
||||
self.stride = 1
|
||||
self.n_blocks = n_blocks
|
||||
self.n_stages = n_stages
|
||||
self.act = act
|
||||
self.normalizer = normalizer
|
||||
|
||||
def forward(self, x, y):
|
||||
for i in range(self.n_blocks):
|
||||
residual = x
|
||||
for j in range(self.n_stages):
|
||||
x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
|
||||
x = self.act(x)
|
||||
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
|
||||
|
||||
x += residual
|
||||
return x
|
||||
|
||||
|
||||
class MSFBlock(nn.Module):
|
||||
def __init__(self, in_planes, features):
|
||||
super().__init__()
|
||||
assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
|
||||
self.convs = nn.ModuleList()
|
||||
self.features = features
|
||||
|
||||
for i in range(len(in_planes)):
|
||||
self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
|
||||
|
||||
def forward(self, xs, shape):
|
||||
sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
|
||||
for i in range(len(self.convs)):
|
||||
h = self.convs[i](xs[i])
|
||||
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
|
||||
sums += h
|
||||
return sums
|
||||
|
||||
|
||||
class CondMSFBlock(nn.Module):
|
||||
def __init__(self, in_planes, features, num_classes, normalizer):
|
||||
super().__init__()
|
||||
assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.norms = nn.ModuleList()
|
||||
self.features = features
|
||||
self.normalizer = normalizer
|
||||
|
||||
for i in range(len(in_planes)):
|
||||
self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
|
||||
self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
|
||||
|
||||
def forward(self, xs, y, shape):
|
||||
sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
|
||||
for i in range(len(self.convs)):
|
||||
h = self.norms[i](xs[i], y)
|
||||
h = self.convs[i](h)
|
||||
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
|
||||
sums += h
|
||||
return sums
|
||||
|
||||
|
||||
class RefineBlock(nn.Module):
|
||||
def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
|
||||
self.n_blocks = n_blocks = len(in_planes)
|
||||
|
||||
self.adapt_convs = nn.ModuleList()
|
||||
for i in range(n_blocks):
|
||||
self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
|
||||
|
||||
self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
|
||||
|
||||
if not start:
|
||||
self.msf = MSFBlock(in_planes, features)
|
||||
|
||||
self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
|
||||
|
||||
def forward(self, xs, output_shape):
|
||||
assert isinstance(xs, tuple) or isinstance(xs, list)
|
||||
hs = []
|
||||
for i in range(len(xs)):
|
||||
h = self.adapt_convs[i](xs[i])
|
||||
hs.append(h)
|
||||
|
||||
if self.n_blocks > 1:
|
||||
h = self.msf(hs, output_shape)
|
||||
else:
|
||||
h = hs[0]
|
||||
|
||||
h = self.crp(h)
|
||||
h = self.output_convs(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class CondRefineBlock(nn.Module):
|
||||
def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
|
||||
self.n_blocks = n_blocks = len(in_planes)
|
||||
|
||||
self.adapt_convs = nn.ModuleList()
|
||||
for i in range(n_blocks):
|
||||
self.adapt_convs.append(
|
||||
CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
|
||||
)
|
||||
|
||||
self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
|
||||
|
||||
if not start:
|
||||
self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
|
||||
|
||||
self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
|
||||
|
||||
def forward(self, xs, y, output_shape):
|
||||
assert isinstance(xs, tuple) or isinstance(xs, list)
|
||||
hs = []
|
||||
for i in range(len(xs)):
|
||||
h = self.adapt_convs[i](xs[i], y)
|
||||
hs.append(h)
|
||||
|
||||
if self.n_blocks > 1:
|
||||
h = self.msf(hs, y, output_shape)
|
||||
else:
|
||||
h = hs[0]
|
||||
|
||||
h = self.crp(h, y)
|
||||
h = self.output_convs(h, y)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class ConvMeanPool(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
|
||||
super().__init__()
|
||||
if not adjust_padding:
|
||||
conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
||||
self.conv = conv
|
||||
else:
|
||||
conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.ZeroPad3d((1, 0, 1, 0)),
|
||||
conv
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
output = self.conv(inputs)
|
||||
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
|
||||
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
|
||||
return output
|
||||
|
||||
|
||||
class MeanPoolConv(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
||||
|
||||
def forward(self, inputs):
|
||||
output = inputs
|
||||
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
|
||||
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
|
||||
return self.conv(output)
|
||||
|
||||
|
||||
class UpsampleConv(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
||||
self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
|
||||
|
||||
def forward(self, inputs):
|
||||
output = inputs
|
||||
output = torch.cat([output, output, output, output], dim=1)
|
||||
output = self.pixelshuffle(output)
|
||||
return self.conv(output)
|
||||
|
||||
|
||||
class ConditionalResidualBlock(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
|
||||
normalization=ConditionalInstanceNorm3dPlus, adjust_padding=False, dilation=None):
|
||||
super().__init__()
|
||||
self.non_linearity = act
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.resample = resample
|
||||
self.normalization = normalization
|
||||
if resample == 'down':
|
||||
if dilation > 1:
|
||||
self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
|
||||
self.normalize2 = normalization(input_dim, num_classes)
|
||||
self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
||||
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
||||
else:
|
||||
self.conv1 = ncsn_conv3x3(input_dim, input_dim)
|
||||
self.normalize2 = normalization(input_dim, num_classes)
|
||||
self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
|
||||
conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
|
||||
|
||||
elif resample is None:
|
||||
if dilation > 1:
|
||||
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
||||
self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
||||
self.normalize2 = normalization(output_dim, num_classes)
|
||||
self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
|
||||
else:
|
||||
conv_shortcut = nn.Conv3d
|
||||
self.conv1 = ncsn_conv3x3(input_dim, output_dim)
|
||||
self.normalize2 = normalization(output_dim, num_classes)
|
||||
self.conv2 = ncsn_conv3x3(output_dim, output_dim)
|
||||
else:
|
||||
raise Exception('invalid resample value')
|
||||
|
||||
if output_dim != input_dim or resample is not None:
|
||||
self.shortcut = conv_shortcut(input_dim, output_dim)
|
||||
|
||||
self.normalize1 = normalization(input_dim, num_classes)
|
||||
|
||||
def forward(self, x, y):
|
||||
output = self.normalize1(x, y)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv1(output)
|
||||
output = self.normalize2(output, y)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv2(output)
|
||||
|
||||
if self.output_dim == self.input_dim and self.resample is None:
|
||||
shortcut = x
|
||||
else:
|
||||
shortcut = self.shortcut(x)
|
||||
|
||||
return shortcut + output
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
|
||||
normalization=nn.InstanceNorm3d, adjust_padding=False, dilation=1):
|
||||
super().__init__()
|
||||
self.non_linearity = act
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.resample = resample
|
||||
self.normalization = normalization
|
||||
if resample == 'down':
|
||||
if dilation > 1:
|
||||
self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
|
||||
self.normalize2 = normalization(input_dim)
|
||||
self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
||||
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
||||
else:
|
||||
self.conv1 = ncsn_conv3x3(input_dim, input_dim)
|
||||
self.normalize2 = normalization(input_dim)
|
||||
self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
|
||||
conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
|
||||
|
||||
elif resample is None:
|
||||
if dilation > 1:
|
||||
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
||||
self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
||||
self.normalize2 = normalization(output_dim)
|
||||
self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
|
||||
else:
|
||||
# conv_shortcut = nn.Conv3d ### Something wierd here.
|
||||
conv_shortcut = partial(ncsn_conv1x1)
|
||||
self.conv1 = ncsn_conv3x3(input_dim, output_dim)
|
||||
self.normalize2 = normalization(output_dim)
|
||||
self.conv2 = ncsn_conv3x3(output_dim, output_dim)
|
||||
else:
|
||||
raise Exception('invalid resample value')
|
||||
|
||||
if output_dim != input_dim or resample is not None:
|
||||
self.shortcut = conv_shortcut(input_dim, output_dim)
|
||||
|
||||
self.normalize1 = normalization(input_dim)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.normalize1(x)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv1(output)
|
||||
output = self.normalize2(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv2(output)
|
||||
|
||||
if self.output_dim == self.input_dim and self.resample is None:
|
||||
shortcut = x
|
||||
else:
|
||||
shortcut = self.shortcut(x)
|
||||
|
||||
return shortcut + output
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Functions below are ported over from the DDPM codebase:
|
||||
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
||||
###########################################################################
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
|
||||
half_dim = embedding_dim // 2
|
||||
# magic number 10000 is from transformers
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
# emb = math.log(2.) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
||||
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
|
||||
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = F.pad(emb, (0, 1), mode='constant')
|
||||
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
|
||||
def _einsum(a, b, c, x, y):
|
||||
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
|
||||
return torch.einsum(einsum_str, x, y)
|
||||
|
||||
|
||||
def contract_inner(x, y):
|
||||
"""tensordot(x, y, 1)."""
|
||||
x_chars = list(string.ascii_lowercase[:len(x.shape)])
|
||||
y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
|
||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
||||
out_chars = x_chars[:-1] + y_chars[1:]
|
||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
||||
|
||||
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 4, 1)
|
||||
y = contract_inner(x, self.W) + self.b
|
||||
return y.permute(0, 4, 1, 2, 3)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""Channel-wise self-attention block."""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
|
||||
self.NIN_0 = NIN(channels, channels)
|
||||
self.NIN_1 = NIN(channels, channels)
|
||||
self.NIN_2 = NIN(channels, channels)
|
||||
self.NIN_3 = NIN(channels, channels, init_scale=0.)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, D, H, W = x.shape
|
||||
h = self.GroupNorm_0(x)
|
||||
q = self.NIN_0(h)
|
||||
k = self.NIN_1(h)
|
||||
v = self.NIN_2(h)
|
||||
|
||||
w = torch.einsum('bcdhw,bckij->bdhwkij', q, k) * (int(C) ** (-0.5))
|
||||
w = torch.reshape(w, (B, D, H, W, D * H * W))
|
||||
w = F.softmax(w, dim=-1)
|
||||
w = torch.reshape(w, (B, D, H, W, D, H, W))
|
||||
h = torch.einsum('bdhwkij,bckij->bcdhw', w, v)
|
||||
h = self.NIN_3(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, channels, with_conv=False):
|
||||
super().__init__()
|
||||
if with_conv:
|
||||
self.Conv_0 = ddpm_conv3x3(channels, channels)
|
||||
self.with_conv = with_conv
|
||||
|
||||
def forward(self, x):
|
||||
B, C, D, H, W = x.shape
|
||||
h = F.interpolate(x, (D * 2, H * 2, W * 2), mode='nearest')
|
||||
if self.with_conv:
|
||||
h = self.Conv_0(h)
|
||||
return h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, channels, with_conv=False):
|
||||
super().__init__()
|
||||
if with_conv:
|
||||
self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
|
||||
self.with_conv = with_conv
|
||||
|
||||
def forward(self, x):
|
||||
B, C, D, H, W = x.shape
|
||||
# Emulate 'SAME' padding
|
||||
if self.with_conv:
|
||||
x = F.pad(x, (0, 1, 0, 1, 0, 1))
|
||||
x = self.Conv_0(x)
|
||||
else:
|
||||
x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0)
|
||||
|
||||
assert x.shape == (B, C, D // 2, H // 2, W // 2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlockDDPM(nn.Module):
|
||||
"""The ResNet Blocks used in DDPM."""
|
||||
def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
|
||||
super().__init__()
|
||||
if out_ch is None:
|
||||
out_ch = in_ch
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
|
||||
self.act = act
|
||||
self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
|
||||
if in_ch != out_ch:
|
||||
if conv_shortcut:
|
||||
self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
self.NIN_0 = NIN(in_ch, out_ch)
|
||||
self.out_ch = out_ch
|
||||
self.in_ch = in_ch
|
||||
self.conv_shortcut = conv_shortcut
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
B, C, D, H, W = x.shape
|
||||
assert C == self.in_ch
|
||||
out_ch = self.out_ch if self.out_ch else self.in_ch
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
if temb is not None:
|
||||
h += self.Dense_0(self.act(temb))[:, :, None, None, None]
|
||||
h = self.act(self.GroupNorm_1(h))
|
||||
h = self.Dropout_0(h)
|
||||
h = self.Conv_1(h)
|
||||
if C != out_ch:
|
||||
if self.conv_shortcut:
|
||||
x = self.Conv_2(x)
|
||||
else:
|
||||
x = self.NIN_0(x)
|
||||
return x + h
|
||||
|
||||
# class PositionalEncoding(nn.Module):
|
||||
|
||||
# def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
||||
# super().__init__()
|
||||
# self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# position = torch.arange(max_len).unsqueeze(1)
|
||||
# div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
||||
# pe = torch.zeros(max_len, 1, d_model)
|
||||
# pe[:, 0, 0::2] = torch.sin(position * div_term)
|
||||
# pe[:, 0, 1::2] = torch.cos(position * div_term)
|
||||
# self.register_buffer('pe', pe)
|
||||
|
||||
# def forward(self, x: Tensor) -> Tensor:
|
||||
# """
|
||||
# Args:
|
||||
# x: Tensor, shape [seq_len, batch_size, embedding_dim]
|
||||
# """
|
||||
# x = x + self.pe[:x.size(0)]
|
||||
# return self.dropout(x)
|
||||
|
||||
class ResnetBlockDDPMPosEncoding(nn.Module):
|
||||
"""The ResNet Blocks used in DDPM."""
|
||||
def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, img_size=64):
|
||||
super().__init__()
|
||||
##### Pos Encoding
|
||||
coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
|
||||
coords = torch.stack([coord_x, coord_y, coord_z])
|
||||
self.num_freq = int(np.log2(img_size))
|
||||
pos_encoding = torch.zeros(1, 2 * self.num_freq, 3, img_size, img_size, img_size)
|
||||
with torch.no_grad():
|
||||
for i in range(self.num_freq):
|
||||
pos_encoding[0, 2*i, :, :, :, :] = torch.cos((i+1) * np.pi * coords)
|
||||
pos_encoding[0, 2*i + 1, :, :, :, :] = torch.sin((i+1) * np.pi * coords)
|
||||
self.pos_encoding = nn.Parameter(
|
||||
pos_encoding.view(1, 2 * self.num_freq * 3, img_size, img_size, img_size) / img_size,
|
||||
requires_grad=False
|
||||
)
|
||||
####
|
||||
|
||||
if out_ch is None:
|
||||
out_ch = in_ch
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
|
||||
self.act = act
|
||||
self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
|
||||
self.Conv_0_pos = ddpm_conv3x3(2 * self.num_freq * 3, out_ch)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
|
||||
if in_ch != out_ch:
|
||||
if conv_shortcut:
|
||||
self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
self.NIN_0 = NIN(in_ch, out_ch)
|
||||
self.out_ch = out_ch
|
||||
self.in_ch = in_ch
|
||||
self.conv_shortcut = conv_shortcut
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
B, C, D, H, W = x.shape
|
||||
assert C == self.in_ch
|
||||
out_ch = self.out_ch if self.out_ch else self.in_ch
|
||||
h = self.act(self.GroupNorm_0(x))
|
||||
h = self.Conv_0(h) + self.Conv_0_pos(self.pos_encoding).expand(h.size(0), -1, -1, -1, -1)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
if temb is not None:
|
||||
h += self.Dense_0(self.act(temb))[:, :, None, None, None]
|
||||
h = self.act(self.GroupNorm_1(h))
|
||||
h = self.Dropout_0(h)
|
||||
h = self.Conv_1(h)
|
||||
if C != out_ch:
|
||||
if self.conv_shortcut:
|
||||
x = self.Conv_2(x)
|
||||
else:
|
||||
x = self.NIN_0(x)
|
||||
return x + h
|
|
@ -0,0 +1,223 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Normalization layers."""
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import functools
|
||||
|
||||
|
||||
def get_normalization(config, conditional=False):
|
||||
"""Obtain normalization modules from the config file."""
|
||||
norm = config.model.normalization
|
||||
if conditional:
|
||||
if norm == 'InstanceNorm++':
|
||||
return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)
|
||||
else:
|
||||
raise NotImplementedError(f'{norm} not implemented yet.')
|
||||
else:
|
||||
if norm == 'InstanceNorm':
|
||||
return nn.InstanceNorm3d
|
||||
elif norm == 'InstanceNorm++':
|
||||
return InstanceNorm3dPlus
|
||||
elif norm == 'VarianceNorm':
|
||||
return VarianceNorm3d
|
||||
elif norm == 'GroupNorm':
|
||||
return nn.GroupNorm
|
||||
else:
|
||||
raise ValueError('Unknown normalization: %s' % norm)
|
||||
|
||||
|
||||
class ConditionalBatchNorm3d(nn.Module):
|
||||
def __init__(self, num_features, num_classes, bias=True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.bn = nn.BatchNorm3d(num_features, affine=False)
|
||||
if self.bias:
|
||||
self.embed = nn.Embedding(num_classes, num_features * 2)
|
||||
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
||||
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
||||
else:
|
||||
self.embed = nn.Embedding(num_classes, num_features)
|
||||
self.embed.weight.data.uniform_()
|
||||
|
||||
def forward(self, x, y):
|
||||
out = self.bn(x)
|
||||
if self.bias:
|
||||
gamma, beta = self.embed(y).chunk(2, dim=1)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1)
|
||||
else:
|
||||
gamma = self.embed(y)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * out
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalInstanceNorm3d(nn.Module):
|
||||
def __init__(self, num_features, num_classes, bias=True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
|
||||
if bias:
|
||||
self.embed = nn.Embedding(num_classes, num_features * 2)
|
||||
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
||||
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
||||
else:
|
||||
self.embed = nn.Embedding(num_classes, num_features)
|
||||
self.embed.weight.data.uniform_()
|
||||
|
||||
def forward(self, x, y):
|
||||
h = self.instance_norm(x)
|
||||
if self.bias:
|
||||
gamma, beta = self.embed(y).chunk(2, dim=-1)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
|
||||
else:
|
||||
gamma = self.embed(y)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * h
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalVarianceNorm3d(nn.Module):
|
||||
def __init__(self, num_features, num_classes, bias=False):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.embed = nn.Embedding(num_classes, num_features)
|
||||
self.embed.weight.data.normal_(1, 0.02)
|
||||
|
||||
def forward(self, x, y):
|
||||
vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
|
||||
h = x / torch.sqrt(vars + 1e-5)
|
||||
|
||||
gamma = self.embed(y)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * h
|
||||
return out
|
||||
|
||||
|
||||
class VarianceNorm3d(nn.Module):
|
||||
def __init__(self, num_features, bias=False):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.alpha = nn.Parameter(torch.zeros(num_features))
|
||||
self.alpha.data.normal_(1, 0.02)
|
||||
|
||||
def forward(self, x):
|
||||
vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
|
||||
h = x / torch.sqrt(vars + 1e-5)
|
||||
|
||||
out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalNoneNorm3d(nn.Module):
|
||||
def __init__(self, num_features, num_classes, bias=True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
if bias:
|
||||
self.embed = nn.Embedding(num_classes, num_features * 2)
|
||||
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
||||
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
||||
else:
|
||||
self.embed = nn.Embedding(num_classes, num_features)
|
||||
self.embed.weight.data.uniform_()
|
||||
|
||||
def forward(self, x, y):
|
||||
if self.bias:
|
||||
gamma, beta = self.embed(y).chunk(2, dim=-1)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1)
|
||||
else:
|
||||
gamma = self.embed(y)
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * x
|
||||
return out
|
||||
|
||||
|
||||
class NoneNorm3d(nn.Module):
|
||||
def __init__(self, num_features, bias=True):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class InstanceNorm3dPlus(nn.Module):
|
||||
def __init__(self, num_features, bias=True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
|
||||
self.alpha = nn.Parameter(torch.zeros(num_features))
|
||||
self.gamma = nn.Parameter(torch.zeros(num_features))
|
||||
self.alpha.data.normal_(1, 0.02)
|
||||
self.gamma.data.normal_(1, 0.02)
|
||||
if bias:
|
||||
self.beta = nn.Parameter(torch.zeros(num_features))
|
||||
|
||||
def forward(self, x):
|
||||
means = torch.mean(x, dim=(2, 3, 4))
|
||||
m = torch.mean(means, dim=-1, keepdim=True)
|
||||
v = torch.var(means, dim=-1, keepdim=True)
|
||||
means = (means - m) / (torch.sqrt(v + 1e-5))
|
||||
h = self.instance_norm(x)
|
||||
|
||||
if self.bias:
|
||||
h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
|
||||
out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1, 1)
|
||||
else:
|
||||
h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
|
||||
out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalInstanceNorm3dPlus(nn.Module):
|
||||
def __init__(self, num_features, num_classes, bias=True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.bias = bias
|
||||
self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
|
||||
if bias:
|
||||
self.embed = nn.Embedding(num_classes, num_features * 3)
|
||||
self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
|
||||
self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
|
||||
else:
|
||||
self.embed = nn.Embedding(num_classes, 2 * num_features)
|
||||
self.embed.weight.data.normal_(1, 0.02)
|
||||
|
||||
def forward(self, x, y):
|
||||
means = torch.mean(x, dim=(2, 3, 4))
|
||||
m = torch.mean(means, dim=-1, keepdim=True)
|
||||
v = torch.var(means, dim=-1, keepdim=True)
|
||||
means = (means - m) / (torch.sqrt(v + 1e-5))
|
||||
h = self.instance_norm(x)
|
||||
|
||||
if self.bias:
|
||||
gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
|
||||
h = h + means[..., None, None, None] * alpha[..., None, None, None]
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
|
||||
else:
|
||||
gamma, alpha = self.embed(y).chunk(2, dim=-1)
|
||||
h = h + means[..., None, None, None] * alpha[..., None, None, None]
|
||||
out = gamma.view(-1, self.num_features, 1, 1, 1) * h
|
||||
return out
|
||||
|
||||
def lip_weight_normalization_3d(W, softplus_c):
|
||||
"""
|
||||
Lipschitz weight normalization based on the L-infinity norm (see Eq.9 in [Liu et al 2022])
|
||||
"""
|
||||
absrowsum = torch.sum(torch.abs(W), dim=[1,2,3,4]) + 1e-8
|
||||
scale = torch.nn.functional.relu(softplus_c/absrowsum - 1.0) + 1.0
|
||||
return W * scale[:, None, None, None, None]
|
|
@ -0,0 +1,259 @@
|
|||
"""Layers used for up-sampling or down-sampling images.
|
||||
|
||||
Many functions are ported from https://github.com/NVlabs/stylegan2.
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
# from op import upfirdn2d
|
||||
|
||||
|
||||
# Function ported from StyleGAN2
|
||||
def get_weight(module,
|
||||
shape,
|
||||
weight_var='weight',
|
||||
kernel_init=None):
|
||||
"""Get/create weight tensor for a convolution or fully-connected layer."""
|
||||
|
||||
return module.param(weight_var, kernel_init, shape)
|
||||
|
||||
|
||||
class Conv3d(nn.Module):
|
||||
"""Conv3d layer with optimal upsampling and downsampling (StyleGAN2)."""
|
||||
|
||||
def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
use_bias=True,
|
||||
kernel_init=None):
|
||||
super().__init__()
|
||||
assert not (up and down)
|
||||
assert kernel >= 1 and kernel % 2 == 1
|
||||
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel, kernel))
|
||||
if kernel_init is not None:
|
||||
self.weight.data = kernel_init(self.weight.data.shape)
|
||||
if use_bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_ch))
|
||||
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.resample_kernel = resample_kernel
|
||||
self.kernel = kernel
|
||||
self.use_bias = use_bias
|
||||
|
||||
def forward(self, x):
|
||||
if self.up:
|
||||
x = upsample_conv_3d(x, self.weight, k=self.resample_kernel)
|
||||
elif self.down:
|
||||
x = conv_downsample_3d(x, self.weight, k=self.resample_kernel)
|
||||
else:
|
||||
x = F.conv3d(x, self.weight, stride=1, padding=self.kernel // 2)
|
||||
|
||||
if self.use_bias:
|
||||
x = x + self.bias.reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def naive_upsample_3d(x, factor=2):
|
||||
_N, C, D, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, D, 1, H, 1, W, 1))
|
||||
x = x.repeat(1, 1, 1, factor, 1, factor, 1, factor)
|
||||
return torch.reshape(x, (-1, C, D * factor, H * factor, W * factor))
|
||||
|
||||
|
||||
def naive_downsample_3d(x, factor=2):
|
||||
_N, C, D, H, W = x.shape
|
||||
x = torch.reshape(x, (-1, C, D // factor, factor, H // factor, factor, W // factor, factor))
|
||||
return torch.mean(x, dim=(3, 5, 7))
|
||||
|
||||
|
||||
def upsample_conv_3d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `upsample_3d()` followed by `tf.nn.conv3d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the
|
||||
operations.
|
||||
The fused op is considerably more efficient than performing the same
|
||||
calculation
|
||||
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels =
|
||||
x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or
|
||||
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Check weight shape.
|
||||
assert len(w.shape) == 5
|
||||
convD = w.shape[2]
|
||||
convH = w.shape[3]
|
||||
convW = w.shape[4]
|
||||
inC = w.shape[1]
|
||||
outC = w.shape[0]
|
||||
|
||||
assert convW == convH
|
||||
|
||||
# Setup filter kernel.
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * (gain * (factor ** 2))
|
||||
p = (k.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
|
||||
# Determine data dimensions.
|
||||
stride = [1, 1, factor, factor]
|
||||
output_shape = ((_shape(x, 2) - 1) * factor + convD, (_shape(x, 3) - 1) * factor + convH, (_shape(x, 4) - 1) * factor + convW)
|
||||
output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convD,
|
||||
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convH,
|
||||
output_shape[2] - (_shape(x, 4) - 1) * stride[2] - convW)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0 and output_padding[2] >= 0
|
||||
num_groups = _shape(x, 1) // inC
|
||||
|
||||
# Transpose weights.
|
||||
w = torch.reshape(w, (num_groups, -1, inC, convD, convH, convW))
|
||||
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4, 5)
|
||||
w = torch.reshape(w, (num_groups * inC, -1, convD, convH, convW))
|
||||
|
||||
x = F.conv_transpose3d(x, w, stride=stride, output_padding=output_padding, padding=0)
|
||||
## Original TF code.
|
||||
# x = tf.nn.conv3d_transpose(
|
||||
# x,
|
||||
# w,
|
||||
# output_shape=output_shape,
|
||||
# strides=stride,
|
||||
# padding='VALID',
|
||||
# data_format=data_format)
|
||||
## JAX equivalent
|
||||
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
||||
pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
|
||||
|
||||
|
||||
def conv_downsample_3d(x, w, k=None, factor=2, gain=1):
|
||||
"""Fused `tf.nn.conv3d()` followed by `downsample_3d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations.
|
||||
The fused op is considerably more efficient than performing the same
|
||||
calculation
|
||||
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
||||
outChannels]`. Grouped convolution can be performed by `inChannels =
|
||||
x.shape[0] // numGroups`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
average pooling.
|
||||
factor: Integer downsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or
|
||||
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
_outC, _inC, convH, convW = w.shape
|
||||
assert convW == convH
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * gain
|
||||
p = (k.shape[0] - factor) + (convW - 1)
|
||||
s = [factor, factor]
|
||||
x = upfirdn2d(x, torch.tensor(k, device=x.device),
|
||||
pad=((p + 1) // 2, p // 2))
|
||||
return F.conv3d(x, w, stride=s, padding=0)
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
k = np.asarray(k, dtype=np.float32)
|
||||
if k.ndim == 1:
|
||||
k = np.outer(k, k)
|
||||
k /= np.sum(k)
|
||||
assert k.ndim == 2
|
||||
assert k.shape[0] == k.shape[1]
|
||||
return k
|
||||
|
||||
|
||||
def _shape(x, dim):
|
||||
return x.shape[dim]
|
||||
|
||||
|
||||
def upsample_3d(x, k=None, factor=2, gain=1):
|
||||
r"""Upsample a batch of 3d images with the given filter.
|
||||
|
||||
Accepts a batch of 3d images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
||||
and upsamples each image with the given filter. The filter is normalized so
|
||||
that
|
||||
if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`.
|
||||
Pixels outside the image are assumed to be zero, and the filter is padded
|
||||
with
|
||||
zeros so that its shape is a multiple of the upsampling factor.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
nearest-neighbor upsampling.
|
||||
factor: Integer upsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * (gain * (factor ** 2))
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
||||
up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
||||
|
||||
|
||||
def downsample_3d(x, k=None, factor=2, gain=1):
|
||||
r"""Downsample a batch of 3d images with the given filter.
|
||||
|
||||
Accepts a batch of 3d images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
||||
and downsamples each image with the given filter. The filter is normalized
|
||||
so that
|
||||
if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`.
|
||||
Pixels outside the image are assumed to be zero, and the filter is padded
|
||||
with
|
||||
zeros so that its shape is a multiple of the downsampling factor.
|
||||
Args:
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
||||
C]`.
|
||||
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
||||
(separable). The default is `[1] * factor`, which corresponds to
|
||||
average pooling.
|
||||
factor: Integer downsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if k is None:
|
||||
k = [1] * factor
|
||||
k = _setup_kernel(k) * gain
|
||||
p = k.shape[0] - factor
|
||||
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
||||
down=factor, pad=((p + 1) // 2, p // 2))
|
|
@ -0,0 +1,213 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""All functions and modules related to model definition.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from .. import sde_lib
|
||||
import numpy as np
|
||||
|
||||
|
||||
_MODELS = {}
|
||||
|
||||
|
||||
def register_model(cls=None, *, name=None):
|
||||
"""A decorator for registering model classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _MODELS:
|
||||
raise ValueError(f'Already registered model with name: {local_name}')
|
||||
_MODELS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def get_model(name):
|
||||
return _MODELS[name]
|
||||
|
||||
|
||||
def get_sigmas(config):
|
||||
"""Get sigmas --- the set of noise levels for SMLD from config files.
|
||||
Args:
|
||||
config: A ConfigDict object parsed from the config file
|
||||
Returns:
|
||||
sigmas: a jax numpy arrary of noise levels
|
||||
"""
|
||||
sigmas = np.exp(
|
||||
np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
|
||||
|
||||
return sigmas
|
||||
|
||||
|
||||
def get_ddpm_params(config):
|
||||
"""Get betas and alphas --- parameters used in the original DDPM paper."""
|
||||
num_diffusion_timesteps = 1000
|
||||
# parameters need to be adapted if number of time steps differs from 1000
|
||||
beta_start = config.model.beta_min / config.model.num_scales
|
||||
beta_end = config.model.beta_max / config.model.num_scales
|
||||
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
|
||||
sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
|
||||
|
||||
return {
|
||||
'betas': betas,
|
||||
'alphas': alphas,
|
||||
'alphas_cumprod': alphas_cumprod,
|
||||
'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
|
||||
'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
|
||||
'beta_min': beta_start * (num_diffusion_timesteps - 1),
|
||||
'beta_max': beta_end * (num_diffusion_timesteps - 1),
|
||||
'num_diffusion_timesteps': num_diffusion_timesteps
|
||||
}
|
||||
|
||||
|
||||
def create_model(config, use_parallel=True):
|
||||
"""Create the score model."""
|
||||
model_name = config.model.name
|
||||
score_model = get_model(model_name)(config)
|
||||
# score_model = score_model.to(config.device)
|
||||
score_model = score_model
|
||||
if use_parallel:
|
||||
score_model = torch.nn.DataParallel(score_model).to(config.device)
|
||||
return score_model
|
||||
|
||||
|
||||
def get_model_fn(model, train=False):
|
||||
"""Create a function to give the output of the score-based model.
|
||||
|
||||
Args:
|
||||
model: The score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
|
||||
Returns:
|
||||
A model function.
|
||||
"""
|
||||
|
||||
def model_fn(x, labels):
|
||||
"""Compute the output of the score-based model.
|
||||
|
||||
Args:
|
||||
x: A mini-batch of input data.
|
||||
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
||||
for different models.
|
||||
|
||||
Returns:
|
||||
A tuple of (model output, new mutable states)
|
||||
"""
|
||||
if not train:
|
||||
model.eval()
|
||||
return model(x, labels)
|
||||
else:
|
||||
model.train()
|
||||
return model(x, labels)
|
||||
|
||||
return model_fn
|
||||
|
||||
def get_reg_fn(model, train=False):
|
||||
"""Create a function to give the output of the score-based model.
|
||||
|
||||
Args:
|
||||
model: The score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
|
||||
Returns:
|
||||
A model function.
|
||||
"""
|
||||
|
||||
def model_fn(x):
|
||||
"""Compute the output of the score-based model.
|
||||
|
||||
Args:
|
||||
x: A mini-batch of input data.
|
||||
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
||||
for different models.
|
||||
|
||||
Returns:
|
||||
A tuple of (model output, new mutable states)
|
||||
"""
|
||||
if not train:
|
||||
model.eval()
|
||||
try:
|
||||
return model.get_reg(x)
|
||||
except:
|
||||
return torch.zeros_like(x, device=x.device)
|
||||
else:
|
||||
model.train()
|
||||
try:
|
||||
return model.get_reg(x)
|
||||
except:
|
||||
return torch.zeros_like(x, device=x.device)
|
||||
|
||||
return model_fn
|
||||
|
||||
def get_score_fn(sde, model, train=False, continuous=False, std_scale=True):
|
||||
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
model: A score model.
|
||||
train: `True` for training and `False` for evaluation.
|
||||
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
||||
std_scale: whether to scale the score function by the inverse of std. Used for DDIM sampling
|
||||
|
||||
Returns:
|
||||
A score function.
|
||||
"""
|
||||
model_fn = get_model_fn(model, train=train)
|
||||
reg_fn = get_reg_fn(model, train=train)
|
||||
|
||||
assert not continuous
|
||||
if isinstance(sde, sde_lib.VPSDE):
|
||||
if not std_scale:
|
||||
def score_fn(x, t):
|
||||
labels = t * (sde.N - 1)
|
||||
score = model_fn(x, labels)
|
||||
return score
|
||||
else:
|
||||
def score_fn(x, t):
|
||||
# For VP-trained models, t=0 corresponds to the lowest noise level
|
||||
labels = t * (sde.N - 1)
|
||||
score = model_fn(x, labels)
|
||||
std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
|
||||
|
||||
score = -score / std[:, None, None, None, None]
|
||||
return score
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
return score_fn
|
||||
|
||||
|
||||
def to_flattened_numpy(x):
|
||||
"""Flatten a torch tensor `x` and convert it to numpy."""
|
||||
return x.detach().cpu().numpy().reshape((-1,))
|
||||
|
||||
|
||||
def from_flattened_numpy(x, shape):
|
||||
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
||||
return torch.from_numpy(x.reshape(shape))
|
|
@ -0,0 +1,570 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: skip-file
|
||||
# pytype: skip-file
|
||||
"""Various sampling methods."""
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import abc
|
||||
|
||||
from .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
|
||||
from scipy import integrate
|
||||
from . import sde_lib
|
||||
from .models import utils as mutils
|
||||
|
||||
import logging
|
||||
import tqdm
|
||||
|
||||
_CORRECTORS = {}
|
||||
_PREDICTORS = {}
|
||||
|
||||
|
||||
def register_predictor(cls=None, *, name=None):
|
||||
"""A decorator for registering predictor classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _PREDICTORS:
|
||||
raise ValueError(f'Already registered model with name: {local_name}')
|
||||
_PREDICTORS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def register_corrector(cls=None, *, name=None):
|
||||
"""A decorator for registering corrector classes."""
|
||||
|
||||
def _register(cls):
|
||||
if name is None:
|
||||
local_name = cls.__name__
|
||||
else:
|
||||
local_name = name
|
||||
if local_name in _CORRECTORS:
|
||||
raise ValueError(f'Already registered model with name: {local_name}')
|
||||
_CORRECTORS[local_name] = cls
|
||||
return cls
|
||||
|
||||
if cls is None:
|
||||
return _register
|
||||
else:
|
||||
return _register(cls)
|
||||
|
||||
|
||||
def get_predictor(name):
|
||||
return _PREDICTORS[name]
|
||||
|
||||
|
||||
def get_corrector(name):
|
||||
return _CORRECTORS[name]
|
||||
|
||||
|
||||
def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False):
|
||||
"""Create a sampling function.
|
||||
|
||||
Args:
|
||||
config: A `ml_collections.ConfigDict` object that contains all configuration information.
|
||||
sde: A `sde_lib.SDE` object that represents the forward SDE.
|
||||
shape: A sequence of integers representing the expected shape of a single sample.
|
||||
inverse_scaler: The inverse data normalizer function.
|
||||
eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
|
||||
|
||||
Returns:
|
||||
A function that takes random states and a replicated training state and outputs samples with the
|
||||
trailing dimensions matching `shape`.
|
||||
"""
|
||||
|
||||
sampler_name = config.sampling.method
|
||||
# Probability flow ODE sampling with black-box ODE solvers
|
||||
# Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
|
||||
if sampler_name.lower() == 'pc':
|
||||
predictor = get_predictor(config.sampling.predictor.lower())
|
||||
corrector = get_corrector(config.sampling.corrector.lower())
|
||||
sampling_fn = get_pc_sampler(sde=sde,
|
||||
shape=shape,
|
||||
predictor=predictor,
|
||||
corrector=corrector,
|
||||
inverse_scaler=inverse_scaler,
|
||||
snr=config.sampling.snr,
|
||||
n_steps=config.sampling.n_steps_each,
|
||||
probability_flow=config.sampling.probability_flow,
|
||||
continuous=config.training.continuous,
|
||||
denoise=config.sampling.noise_removal,
|
||||
eps=eps,
|
||||
device=config.device,
|
||||
grid_mask=grid_mask,
|
||||
return_traj=return_traj)
|
||||
elif sampler_name.lower() == 'ddim':
|
||||
predictor = get_predictor('ddim')
|
||||
sampling_fn = get_ddim_sampler(sde=sde,
|
||||
shape=shape,
|
||||
predictor=predictor,
|
||||
inverse_scaler=inverse_scaler,
|
||||
n_steps=config.sampling.n_steps_each,
|
||||
denoise=config.sampling.noise_removal,
|
||||
eps=eps,
|
||||
device=config.device,
|
||||
grid_mask=grid_mask)
|
||||
else:
|
||||
raise ValueError(f"Sampler name {sampler_name} unknown.")
|
||||
|
||||
return sampling_fn
|
||||
|
||||
|
||||
class Predictor(abc.ABC):
|
||||
"""The abstract class for a predictor algorithm."""
|
||||
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__()
|
||||
self.sde = sde
|
||||
# Compute the reverse SDE/ODE
|
||||
self.rsde = sde.reverse(score_fn, probability_flow)
|
||||
self.score_fn = score_fn
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_fn(self, x, t):
|
||||
"""One update of the predictor.
|
||||
|
||||
Args:
|
||||
x: A PyTorch tensor representing the current state
|
||||
t: A Pytorch tensor representing the current time step.
|
||||
|
||||
Returns:
|
||||
x: A PyTorch tensor of the next state.
|
||||
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Corrector(abc.ABC):
|
||||
"""The abstract class for a corrector algorithm."""
|
||||
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
super().__init__()
|
||||
self.sde = sde
|
||||
self.score_fn = score_fn
|
||||
self.snr = snr
|
||||
self.n_steps = n_steps
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_fn(self, x, t):
|
||||
"""One update of the corrector.
|
||||
|
||||
Args:
|
||||
x: A PyTorch tensor representing the current state
|
||||
t: A PyTorch tensor representing the current time step.
|
||||
|
||||
Returns:
|
||||
x: A PyTorch tensor of the next state.
|
||||
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@register_predictor(name='euler_maruyama')
|
||||
class EulerMaruyamaPredictor(Predictor):
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
|
||||
def update_fn(self, x, t):
|
||||
dt = -1. / self.rsde.N
|
||||
z = torch.randn_like(x)
|
||||
drift, diffusion = self.rsde.sde(x, t)
|
||||
x_mean = x + drift * dt
|
||||
x = x_mean + diffusion[:, None, None, None, None] * np.sqrt(-dt) * z
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@register_predictor(name='reverse_diffusion')
|
||||
class ReverseDiffusionPredictor(Predictor):
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
|
||||
def update_fn(self, x, t):
|
||||
f, G = self.rsde.discretize(x, t)
|
||||
z = torch.randn_like(x)
|
||||
x_mean = x - f
|
||||
x = x_mean + G[:, None, None, None, None] * z
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@register_predictor(name='ancestral_sampling')
|
||||
class AncestralSamplingPredictor(Predictor):
|
||||
"""The ancestral sampling predictor. Currently only supports VE/VP SDEs."""
|
||||
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
if not isinstance(sde, sde_lib.VPSDE):
|
||||
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
assert not probability_flow, "Probability flow not supported by ancestral sampling"
|
||||
|
||||
def vpsde_update_fn(self, x, t):
|
||||
sde = self.sde
|
||||
timestep = (t * (sde.N - 1) / sde.T).long()
|
||||
beta = sde.discrete_betas.to(t.device)[timestep]
|
||||
score = self.score_fn(x, t)
|
||||
x_mean = (x + beta[:, None, None, None, None] * score) / torch.sqrt(1. - beta)[:, None, None, None, None]
|
||||
noise = torch.randn_like(x)
|
||||
x = x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise
|
||||
return x, x_mean
|
||||
|
||||
def update_fn(self, x, t):
|
||||
if isinstance(self.sde, sde_lib.VPSDE):
|
||||
return self.vpsde_update_fn(x, t)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@register_predictor(name='none')
|
||||
class NonePredictor(Predictor):
|
||||
"""An empty predictor that does nothing."""
|
||||
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
pass
|
||||
|
||||
def update_fn(self, x, t):
|
||||
return x, x
|
||||
|
||||
@register_predictor(name='ddim')
|
||||
class DDIMPredictor(Predictor):
|
||||
def __init__(self, sde, score_fn, probability_flow=False):
|
||||
super().__init__(sde, score_fn, probability_flow)
|
||||
|
||||
|
||||
def update_fn(self, x, t, tprev=None):
|
||||
x, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev)
|
||||
return x, x0_pred
|
||||
|
||||
@register_corrector(name='langevin')
|
||||
class LangevinCorrector(Corrector):
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
super().__init__(sde, score_fn, snr, n_steps)
|
||||
if not isinstance(sde, sde_lib.VPSDE):
|
||||
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
def update_fn(self, x, t):
|
||||
sde = self.sde
|
||||
score_fn = self.score_fn
|
||||
n_steps = self.n_steps
|
||||
target_snr = self.snr
|
||||
if isinstance(sde, sde_lib.VPSDE):
|
||||
timestep = (t * (sde.N - 1) / sde.T).long()
|
||||
alpha = sde.alphas.to(t.device)[timestep]
|
||||
else:
|
||||
alpha = torch.ones_like(t)
|
||||
|
||||
for i in range(n_steps):
|
||||
grad = score_fn(x, t)
|
||||
noise = torch.randn_like(x)
|
||||
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
|
||||
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
||||
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
|
||||
x_mean = x + step_size[:, None, None, None, None] * grad
|
||||
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise
|
||||
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@register_corrector(name='ald')
|
||||
class AnnealedLangevinDynamics(Corrector):
|
||||
"""The original annealed Langevin dynamics predictor in NCSN/NCSNv2.
|
||||
|
||||
We include this corrector only for completeness. It was not directly used in our paper.
|
||||
"""
|
||||
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
super().__init__(sde, score_fn, snr, n_steps)
|
||||
if not isinstance(sde, sde_lib.VPSDE):
|
||||
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
||||
|
||||
def update_fn(self, x, t):
|
||||
sde = self.sde
|
||||
score_fn = self.score_fn
|
||||
n_steps = self.n_steps
|
||||
target_snr = self.snr
|
||||
if isinstance(sde, sde_lib.VPSDE):
|
||||
timestep = (t * (sde.N - 1) / sde.T).long()
|
||||
alpha = sde.alphas.to(t.device)[timestep]
|
||||
else:
|
||||
alpha = torch.ones_like(t)
|
||||
|
||||
std = self.sde.marginal_prob(x, t)[1]
|
||||
|
||||
for i in range(n_steps):
|
||||
grad = score_fn(x, t)
|
||||
noise = torch.randn_like(x)
|
||||
step_size = (target_snr * std) ** 2 * 2 * alpha
|
||||
x_mean = x + step_size[:, None, None, None, None] * grad
|
||||
x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None, None]
|
||||
|
||||
return x, x_mean
|
||||
|
||||
|
||||
@register_corrector(name='none')
|
||||
class NoneCorrector(Corrector):
|
||||
"""An empty corrector that does nothing."""
|
||||
|
||||
def __init__(self, sde, score_fn, snr, n_steps):
|
||||
pass
|
||||
|
||||
def update_fn(self, x, t):
|
||||
return x, x
|
||||
|
||||
|
||||
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
|
||||
"""A wrapper that configures and returns the update function of predictors."""
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
|
||||
if predictor is None:
|
||||
# Corrector-only sampler
|
||||
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
|
||||
else:
|
||||
predictor_obj = predictor(sde, score_fn, probability_flow)
|
||||
return predictor_obj.update_fn(x, t)
|
||||
|
||||
|
||||
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
|
||||
"""A wrapper tha configures and returns the update function of correctors."""
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
|
||||
if corrector is None:
|
||||
# Predictor-only sampler
|
||||
corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
|
||||
else:
|
||||
corrector_obj = corrector(sde, score_fn, snr, n_steps)
|
||||
return corrector_obj.update_fn(x, t)
|
||||
|
||||
|
||||
def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
|
||||
n_steps=1, probability_flow=False, continuous=False,
|
||||
denoise=True, eps=1e-3, device='cuda', grid_mask=None, return_traj=False):
|
||||
"""Create a Predictor-Corrector (PC) sampler.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object representing the forward SDE.
|
||||
shape: A sequence of integers. The expected shape of a single sample.
|
||||
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
|
||||
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
|
||||
n_steps: An integer. The number of corrector steps per predictor update.
|
||||
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
|
||||
continuous: `True` indicates that the score model was continuously trained.
|
||||
denoise: If `True`, add one-step denoising to the final samples.
|
||||
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
|
||||
device: PyTorch device.
|
||||
|
||||
Returns:
|
||||
A sampling function that returns samples and the number of function evaluations during sampling.
|
||||
"""
|
||||
# Create predictor & corrector update functions
|
||||
predictor_update_fn = functools.partial(shared_predictor_update_fn,
|
||||
sde=sde,
|
||||
predictor=predictor,
|
||||
probability_flow=probability_flow,
|
||||
continuous=continuous)
|
||||
corrector_update_fn = functools.partial(shared_corrector_update_fn,
|
||||
sde=sde,
|
||||
corrector=corrector,
|
||||
continuous=continuous,
|
||||
snr=snr,
|
||||
n_steps=n_steps)
|
||||
|
||||
def pc_sampler(model,
|
||||
partial=None, partial_grid_mask=None, partial_channel=0,
|
||||
freeze_iters=None):
|
||||
""" The PC sampler funciton.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
Returns:
|
||||
Samples, number of function evaluations.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
|
||||
if freeze_iters is None:
|
||||
freeze_iters = sde.N + 10 # just some randomly large number greater than sde.N
|
||||
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
|
||||
|
||||
|
||||
|
||||
def compute_xzero(sde, model, x, t, grid_mask_input):
|
||||
timestep_int = (t * (sde.N - 1) / sde.T).long()
|
||||
alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
|
||||
alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
|
||||
alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()
|
||||
alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()
|
||||
score_pred = model(x, t * torch.ones(shape[0], device=x.device))
|
||||
x0_pred_scaled = (x - alphas2 * score_pred)
|
||||
x0_pred = x0_pred_scaled / alphas1
|
||||
x0_pred = x0_pred.clamp(-1, 1)
|
||||
return x0_pred * grid_mask_input
|
||||
|
||||
# Initial sample
|
||||
x = sde.prior_sampling(shape).to(device)
|
||||
assert len(x.size()) == 5
|
||||
x = x * grid_mask
|
||||
|
||||
traj_buffer = []
|
||||
|
||||
if partial is not None:
|
||||
assert len(partial.size()) == 5
|
||||
t = timesteps[0]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]
|
||||
|
||||
partial_mean, partial_std = sde.marginal_prob(x, vec_t)
|
||||
sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
|
||||
x[:, partial_channel] = (
|
||||
x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
|
||||
+ sampled_update[:, partial_channel] * partial_mask[:, partial_channel]
|
||||
) * grid_mask[:, partial_channel]
|
||||
|
||||
|
||||
if partial is not None:
|
||||
x_mean = x
|
||||
for i in tqdm.trange(sde.N):
|
||||
t = timesteps[i]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
|
||||
x, x_mean = corrector_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = x * grid_mask, x_mean * grid_mask
|
||||
x, x_mean = predictor_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = x * grid_mask, x_mean * grid_mask
|
||||
|
||||
|
||||
if i != sde.N - 1 and i < freeze_iters:
|
||||
|
||||
x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
|
||||
x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
|
||||
|
||||
### add noise to the condition x0_star
|
||||
partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))
|
||||
sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
|
||||
x[:, partial_channel] = (
|
||||
x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
|
||||
+ sampled_update * partial_mask[:, partial_channel]
|
||||
) * grid_mask[:, partial_channel]
|
||||
x_mean[:, partial_channel] = x[:, partial_channel]
|
||||
|
||||
else:
|
||||
|
||||
for i in tqdm.trange(sde.N - 1):
|
||||
t = timesteps[i]
|
||||
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
x, x_mean = corrector_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = x * grid_mask, x_mean * grid_mask
|
||||
x, x_mean = predictor_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = x * grid_mask, x_mean * grid_mask
|
||||
|
||||
if return_traj and i >= 700 and i % 10 == 0:
|
||||
traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))
|
||||
|
||||
if return_traj:
|
||||
return traj_buffer, sde.N * (n_steps + 1)
|
||||
return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
|
||||
|
||||
return pc_sampler
|
||||
|
||||
def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous):
|
||||
"""A wrapper that configures and returns the update function of predictors."""
|
||||
assert not continuous
|
||||
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False)
|
||||
if predictor is None:
|
||||
# Corrector-only sampler
|
||||
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
|
||||
else:
|
||||
predictor_obj = predictor(sde, score_fn, probability_flow)
|
||||
return predictor_obj.update_fn(x, t, tprev)
|
||||
|
||||
def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,
|
||||
denoise=False, eps=1e-3, device='cuda', grid_mask=None):
|
||||
"""Probability flow ODE sampler with the black-box ODE solver.
|
||||
|
||||
Args:
|
||||
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
||||
shape: A sequence of integers. The expected shape of a single sample.
|
||||
inverse_scaler: The inverse data normalizer.
|
||||
denoise: If `True`, add one-step denoising to final samples.
|
||||
eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
|
||||
device: PyTorch device.
|
||||
|
||||
Returns:
|
||||
A sampling function that returns samples and the number of function evaluations during sampling.
|
||||
"""
|
||||
|
||||
predictor_update_fn = functools.partial(ddim_predictor_update_fn,
|
||||
sde=sde,
|
||||
predictor=predictor,
|
||||
probability_flow=False,
|
||||
continuous=False)
|
||||
|
||||
def ddim_sampler(model, schedule='quad', num_steps=100, x0=None,
|
||||
partial=None, partial_grid_mask=None, partial_channel=0):
|
||||
""" The PC sampler funciton.
|
||||
|
||||
Args:
|
||||
model: A score model.
|
||||
Returns:
|
||||
Samples, number of function evaluations.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if x0 is not None:
|
||||
x = x0 * grid_mask
|
||||
else:
|
||||
# Initial sample
|
||||
x = sde.prior_sampling(shape).to(device)
|
||||
x = x * grid_mask
|
||||
|
||||
if partial is not None:
|
||||
x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
|
||||
|
||||
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
|
||||
|
||||
if schedule == 'uniform':
|
||||
skip = sde.N // num_steps
|
||||
seq = range(0, sde.N, skip)
|
||||
elif schedule == 'quad':
|
||||
seq = (
|
||||
np.linspace(
|
||||
0, np.sqrt(sde.N * 0.8), 100
|
||||
)
|
||||
** 2
|
||||
)
|
||||
seq = [int(s) for s in list(seq)]
|
||||
|
||||
timesteps = torch.tensor(seq) / sde.N
|
||||
|
||||
for i in tqdm.tqdm(reversed(range(1, len(timesteps)))):
|
||||
t = timesteps[i]
|
||||
tprev = timesteps[i - 1]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
vec_tprev = torch.ones(shape[0], device=t.device) * tprev
|
||||
x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)
|
||||
x, x0_pred = x * grid_mask, x0_pred * grid_mask
|
||||
if partial is not None:
|
||||
x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
|
||||
x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
|
||||
|
||||
return inverse_scaler(x0_pred * grid_mask if (denoise and not encode) else x * grid_mask), sde.N * (n_steps + 1)
|
||||
return ddim_sampler
|
|
@ -0,0 +1,233 @@
|
|||
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
|
||||
import abc
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
|
||||
class SDE(abc.ABC):
|
||||
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""
|
||||
|
||||
def __init__(self, N):
|
||||
"""Construct an SDE.
|
||||
|
||||
Args:
|
||||
N: number of discretization time steps.
|
||||
"""
|
||||
super().__init__()
|
||||
self.N = N
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def T(self):
|
||||
"""End time of the SDE."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sde(self, x, t):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def marginal_prob(self, x, t):
|
||||
"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def prior_sampling(self, shape):
|
||||
"""Generate one sample from the prior distribution, $p_T(x)$."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def prior_logp(self, z):
|
||||
"""Compute log-density of the prior distribution.
|
||||
|
||||
Useful for computing the log-likelihood via probability flow ODE.
|
||||
|
||||
Args:
|
||||
z: latent code
|
||||
Returns:
|
||||
log probability density
|
||||
"""
|
||||
pass
|
||||
|
||||
def discretize(self, x, t):
|
||||
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
|
||||
|
||||
Useful for reverse diffusion sampling and probabiliy flow sampling.
|
||||
Defaults to Euler-Maruyama discretization.
|
||||
|
||||
Args:
|
||||
x: a torch tensor
|
||||
t: a torch float representing the time step (from 0 to `self.T`)
|
||||
|
||||
Returns:
|
||||
f, G
|
||||
"""
|
||||
dt = 1 / self.N
|
||||
drift, diffusion = self.sde(x, t)
|
||||
f = drift * dt
|
||||
G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
|
||||
return f, G
|
||||
|
||||
def reverse(self, score_fn, probability_flow=False):
|
||||
"""Create the reverse-time SDE/ODE.
|
||||
|
||||
Args:
|
||||
score_fn: A time-dependent score-based model that takes x and t and returns the score.
|
||||
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
|
||||
"""
|
||||
N = self.N
|
||||
T = self.T
|
||||
sde_fn = self.sde
|
||||
discretize_fn = self.discretize
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_1m_alphas_cumprod = self.sqrt_1m_alphas_cumprod
|
||||
|
||||
# Build the class for reverse-time SDE.
|
||||
class RSDE(self.__class__):
|
||||
def __init__(self):
|
||||
self.N = N
|
||||
self.probability_flow = probability_flow
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
return T
|
||||
|
||||
def sde(self, x, t):
|
||||
"""Create the drift and diffusion functions for the reverse SDE/ODE."""
|
||||
drift, diffusion = sde_fn(x, t)
|
||||
score = score_fn(x, t)
|
||||
drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
|
||||
# Set the diffusion function to zero for ODEs.
|
||||
diffusion = 0. if self.probability_flow else diffusion
|
||||
return drift, diffusion
|
||||
|
||||
def discretize(self, x, t):
|
||||
"""Create discretized iteration rules for the reverse diffusion sampler."""
|
||||
f, G = discretize_fn(x, t)
|
||||
rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
|
||||
rev_G = torch.zeros_like(G) if self.probability_flow else G
|
||||
return rev_f, rev_G
|
||||
|
||||
def discretize_ddim(self, x, t, tprev=None, encode=False):
|
||||
"""DDPM discretization."""
|
||||
timestep = (t * (N - 1) / T).long()
|
||||
timestep_prev = (tprev * (N - 1) / T).long()
|
||||
|
||||
score = score_fn(x.float(), t.float())
|
||||
|
||||
# alphas1prev_div_alphas1 = torch.exp(log_diff)
|
||||
alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
|
||||
alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
|
||||
alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
|
||||
alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
|
||||
alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
|
||||
alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()
|
||||
|
||||
|
||||
x0_pred_scaled = (x.double() - alphas2.double() * score.double())
|
||||
use_clip = False
|
||||
if use_clip:
|
||||
x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
|
||||
score_scaled_t = x - x0_pred_scaled
|
||||
x0_pred = x0_pred_scaled / alphas1
|
||||
|
||||
x_new = (
|
||||
alphas1prev_div_alphas1.double() * x +
|
||||
(-alphas1prev_div_alphas1 + alphas2prev_div_alphas2.double()) * score_scaled_t.double()
|
||||
)
|
||||
return x_new, x0_pred
|
||||
|
||||
|
||||
def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=False):
|
||||
"""DDPM discretization."""
|
||||
timestep = (t * (N - 1) / T).long()
|
||||
timestep_prev = (tprev * (N - 1) / T).long()
|
||||
|
||||
score = score_fn(x.float(), t.float())
|
||||
|
||||
# alphas1prev_div_alphas1 = torch.exp(log_diff)
|
||||
alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
|
||||
alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
|
||||
alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
|
||||
alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
|
||||
|
||||
x0_pred_scaled = (x.double() - alphas2.double() * score.double())
|
||||
x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
|
||||
x0_pred = x0_pred_scaled / alphas1
|
||||
|
||||
if condition is None:
|
||||
condition_update = 0
|
||||
else:
|
||||
if (t - 0.99).mean() < 1e-3:
|
||||
x = x0_pred
|
||||
condition_update = condition_func(x.float(), condition)
|
||||
|
||||
x_new = (
|
||||
x - alphas1prev_div_alphas1.double() * condition_update
|
||||
)
|
||||
return x_new, x0_pred
|
||||
|
||||
|
||||
return RSDE()
|
||||
|
||||
|
||||
class VPSDE(SDE):
|
||||
def __init__(self, beta_min=0.1, beta_max=20, N=1000):
|
||||
"""Construct a Variance Preserving SDE.
|
||||
|
||||
Args:
|
||||
beta_min: value of beta(0)
|
||||
beta_max: value of beta(1)
|
||||
N: number of discretization steps
|
||||
"""
|
||||
super().__init__(N)
|
||||
self.beta_0 = beta_min
|
||||
self.beta_1 = beta_max
|
||||
self.N = N
|
||||
self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda()
|
||||
self.alphas = 1. - self.discrete_betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.alphas_cumprod_ext = torch.cat([torch.tensor([1.0 - 1e-4]).cuda(), torch.cumprod(self.alphas, dim=0)], dim=0)
|
||||
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
|
||||
|
||||
self.alphas_cumprod = self.alphas_cumprod
|
||||
self.alphas_cumprod_ext = self.alphas_cumprod_ext
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
return 1
|
||||
|
||||
def sde(self, x, t):
|
||||
beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
|
||||
drift = -0.5 * beta_t[:, None, None, None, None] * x
|
||||
diffusion = torch.sqrt(beta_t)
|
||||
return drift, diffusion
|
||||
|
||||
def marginal_prob(self, x, t):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x
|
||||
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
|
||||
return mean, std
|
||||
|
||||
def prior_sampling(self, shape):
|
||||
return torch.randn(*shape)
|
||||
|
||||
def prior_logp(self, z):
|
||||
shape = z.shape
|
||||
N = np.prod(shape[1:])
|
||||
logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2.
|
||||
return logps
|
||||
|
||||
def discretize(self, x, t):
|
||||
"""DDPM discretization."""
|
||||
timestep = (t * (self.N - 1) / self.T).long()
|
||||
beta = self.discrete_betas.to(x.device)[timestep]
|
||||
alpha = self.alphas.to(x.device)[timestep]
|
||||
sqrt_beta = torch.sqrt(beta)
|
||||
f = torch.sqrt(alpha)[:, None, None, None, None] * x - x
|
||||
G = sqrt_beta
|
||||
return f, G
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
# Keep the import below for registering all model definitions
|
||||
from .models import ddpm_res64, ddpm_res128
|
||||
|
||||
from . import losses
|
||||
from .models import utils as mutils
|
||||
from .models.ema import ExponentialMovingAverage
|
||||
from . import sde_lib
|
||||
import torch
|
||||
from torch.utils import tensorboard
|
||||
from .utils import save_checkpoint, restore_checkpoint
|
||||
from ..dataset.shapenet_dmtet_dataset import ShapeNetDMTetDataset
|
||||
|
||||
def train(config):
|
||||
"""Runs the training pipeline.
|
||||
|
||||
Args:
|
||||
config: Configuration to use.
|
||||
workdir: Working directory for checkpoints and TF summaries. If this
|
||||
contains checkpoint training will be resumed from the latest checkpoint.
|
||||
"""
|
||||
|
||||
workdir = config.training.train_dir
|
||||
# Create directories for experimental logs
|
||||
logging.info("working dir: {:s}".format(workdir))
|
||||
|
||||
|
||||
tb_dir = os.path.join(workdir, "tensorboard")
|
||||
writer = tensorboard.SummaryWriter(tb_dir)
|
||||
|
||||
resolution = config.data.image_size
|
||||
# Initialize model.
|
||||
score_model = mutils.create_model(config)
|
||||
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
|
||||
optimizer = losses.get_optimizer(config, score_model.parameters())
|
||||
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
|
||||
|
||||
|
||||
# Create checkpoints directory
|
||||
checkpoint_dir = os.path.join(workdir, "checkpoints")
|
||||
# Intermediate checkpoints to resume training after pre-emption in cloud environments
|
||||
checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
|
||||
|
||||
# Resume training when intermediate checkpoints are detected
|
||||
state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
|
||||
initial_step = int(state['step'])
|
||||
|
||||
json_path = config.data.meta_path
|
||||
print("----- Assigning mask -----")
|
||||
logging.info(f"{json_path}, {config.data.filter_meta_path}")
|
||||
|
||||
### mask on tet to ignore regions
|
||||
mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda")
|
||||
|
||||
if hasattr(score_model.module, 'mask'):
|
||||
print("----- Assigning mask -----")
|
||||
score_model.module.mask.data[:] = mask[:]
|
||||
|
||||
print(f"work dir: {workdir}")
|
||||
|
||||
print("sdf normalized or not: ", config.data.normalize_sdf)
|
||||
train_dataset = ShapeNetDMTetDataset(json_path, deform_scale=config.model.deform_scale, aug=True, grid_mask=mask,
|
||||
filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf)
|
||||
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.training.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=config.data.num_workers,
|
||||
pin_memory=True)
|
||||
|
||||
data_iter = iter(train_loader)
|
||||
|
||||
print("data loader set")
|
||||
|
||||
# Setup SDEs
|
||||
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
|
||||
|
||||
# Build one-step training and evaluation functions
|
||||
optimize_fn = losses.optimization_manager(config)
|
||||
train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
|
||||
mask=mask, loss_type=config.training.loss_type)
|
||||
|
||||
num_train_steps = config.training.n_iters
|
||||
|
||||
# In case there are multiple hosts (e.g., TPU pods), only log to host 0
|
||||
logging.info("Starting training loop at step %d." % (initial_step // config.training.iter_size,))
|
||||
|
||||
iter_size = config.training.iter_size
|
||||
for step in range(initial_step // iter_size, num_train_steps + 1):
|
||||
tmp_loss = 0.0
|
||||
for step_inner in range(iter_size):
|
||||
try:
|
||||
# batch, batch_mask = next(data_iter)
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
# StopIteration is thrown if dataset ends
|
||||
# reinitialize data loader
|
||||
data_iter = iter(train_loader)
|
||||
batch = next(data_iter)
|
||||
|
||||
batch = batch.cuda()
|
||||
|
||||
# Execute one training step
|
||||
clear_grad_flag = (step_inner == 0)
|
||||
update_param_flag = (step_inner == iter_size - 1)
|
||||
loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag)
|
||||
loss = loss_dict['loss']
|
||||
tmp_loss += loss.item()
|
||||
|
||||
tmp_loss /= iter_size
|
||||
if step % config.training.log_freq == 0:
|
||||
logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss))
|
||||
sys.stdout.flush()
|
||||
writer.add_scalar("training_loss", loss, step)
|
||||
|
||||
# Save a temporary checkpoint to resume training after pre-emption periodically
|
||||
if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
|
||||
logging.info(f"save meta at iter {step}")
|
||||
save_checkpoint(checkpoint_meta_dir, state)
|
||||
|
||||
# Save a checkpoint periodically and generate samples if needed
|
||||
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
|
||||
logging.info(f"save model: {step}-th")
|
||||
save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
import tensorflow as tf
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, strict=False):
|
||||
if not tf.io.gfile.exists(ckpt_dir):
|
||||
tf.io.gfile.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
if strict:
|
||||
raise
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
state['optimizer'].load_state_dict(loaded_state['optimizer'])
|
||||
state['model'].load_state_dict(loaded_state['model'], strict=False)
|
||||
state['ema'].load_state_dict(loaded_state['ema'])
|
||||
state['step'] = loaded_state['step']
|
||||
return state
|
||||
|
||||
|
||||
def save_checkpoint(ckpt_dir, state):
|
||||
saved_state = {
|
||||
'optimizer': state['optimizer'].state_dict(),
|
||||
'model': state['model'].state_dict(),
|
||||
'ema': state['ema'].state_dict(),
|
||||
'step': state['step']
|
||||
}
|
||||
torch.save(saved_state, ckpt_dir)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
"""Basic dataset interface"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def collate(self, batch):
|
||||
iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp']
|
||||
res_dict = {
|
||||
'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0),
|
||||
'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0),
|
||||
'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0),
|
||||
'resolution' : iter_res,
|
||||
'spp' : iter_spp,
|
||||
'img' : torch.cat(list([item['img'] for item in batch]), dim=0)
|
||||
}
|
||||
|
||||
if 'spts' in batch[0]:
|
||||
res_dict['spts'] = batch[0]['spts']
|
||||
if 'vpts' in batch[0]:
|
||||
res_dict['vpts'] = batch[0]['vpts']
|
||||
if 'faces' in batch[0]:
|
||||
res_dict['faces'] = batch[0]['faces']
|
||||
if 'rast_triangle_id' in batch[0]:
|
||||
res_dict['rast_triangle_id'] = batch[0]['rast_triangle_id']
|
||||
|
||||
if 'depth' in batch[0]:
|
||||
res_dict['depth'] = torch.cat(list([item['depth'] for item in batch]), dim=0)
|
||||
if 'normal' in batch[0]:
|
||||
res_dict['normal'] = torch.cat(list([item['normal'] for item in batch]), dim=0)
|
||||
if 'geo_normal' in batch[0]:
|
||||
res_dict['geo_normal'] = torch.cat(list([item['geo_normal'] for item in batch]), dim=0)
|
||||
if 'geo_viewdir' in batch[0]:
|
||||
res_dict['geo_viewdir'] = torch.cat(list([item['geo_viewdir'] for item in batch]), dim=0)
|
||||
if 'pos' in batch[0]:
|
||||
res_dict['pos'] = torch.cat(list([item['pos'] for item in batch]), dim=0)
|
||||
if 'mask' in batch[0]:
|
||||
res_dict['mask'] = torch.cat(list([item['mask'] for item in batch]), dim=0)
|
||||
if 'mask_cont' in batch[0]:
|
||||
res_dict['mask_cont'] = torch.cat(list([item['mask_cont'] for item in batch]), dim=0)
|
||||
if 'envlight_transform' in batch[0]:
|
||||
if batch[0]['envlight_transform'] is not None:
|
||||
res_dict['envlight_transform'] = torch.cat(list([item['envlight_transform'] for item in batch]), dim=0)
|
||||
else:
|
||||
res_dict['envlight_transform'] = None
|
||||
|
||||
try:
|
||||
res_dict['depth_second'] = torch.cat(list([item['depth_second'] for item in batch]), dim=0)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
res_dict['normal_second'] = torch.cat(list([item['normal_second'] for item in batch]), dim=0)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
res_dict['img_second'] = torch.cat(list([item['img_second'] for item in batch]), dim=0)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
return res_dict
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from ..render import util
|
||||
from ..render import mesh
|
||||
from ..render import render
|
||||
from ..render import light
|
||||
|
||||
from .dataset import Dataset
|
||||
|
||||
import kaolin
|
||||
|
||||
###############################################################################
|
||||
# Reference dataset using mesh & rendering
|
||||
###############################################################################
|
||||
|
||||
class DatasetMesh(Dataset):
|
||||
|
||||
def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
|
||||
# Init
|
||||
self.glctx = glctx
|
||||
self.cam_radius = cam_radius
|
||||
self.FLAGS = FLAGS
|
||||
self.validate = validate
|
||||
self.fovy = np.deg2rad(45)
|
||||
self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]
|
||||
self.random_lgt = FLAGS.random_lgt
|
||||
self.camera_lgt = False
|
||||
self.flat_shading = FLAGS.dataset_flat_shading
|
||||
|
||||
|
||||
if self.FLAGS.local_rank == 0:
|
||||
print(f"use flag shading {FLAGS.dataset_flat_shading}")
|
||||
print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))
|
||||
|
||||
# Sanity test training texture resolution
|
||||
ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
|
||||
if 'normal' in ref_mesh.material:
|
||||
ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())
|
||||
if self.FLAGS.local_rank == 0 and FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]:
|
||||
print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))
|
||||
|
||||
print("Loading env map")
|
||||
sys.stdout.flush()
|
||||
# Load environment map texture
|
||||
self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)
|
||||
|
||||
print("Computing tangents")
|
||||
sys.stdout.flush()
|
||||
try:
|
||||
self.ref_mesh = mesh.compute_tangents(ref_mesh)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Continue without tangents...")
|
||||
self.ref_mesh = ref_mesh
|
||||
|
||||
def _rotate_scene(self, itr):
|
||||
proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
|
||||
|
||||
# Smooth rotation for display.
|
||||
ang = (itr / 50) * np.pi * 2
|
||||
mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
|
||||
mvp = proj_mtx @ mv
|
||||
campos = torch.linalg.inv(mv)[:3, 3]
|
||||
|
||||
return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp
|
||||
|
||||
def _random_scene(self):
|
||||
# ==============================================================================================
|
||||
# Setup projection matrix
|
||||
# ==============================================================================================
|
||||
iter_res = self.FLAGS.train_res
|
||||
proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
|
||||
|
||||
# ==============================================================================================
|
||||
# Random camera & light position
|
||||
# ==============================================================================================
|
||||
|
||||
# Random rotation/translation matrix for optimization.
|
||||
mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.2)
|
||||
mvp = proj_mtx @ mv
|
||||
campos = torch.linalg.inv(mv)[:3, 3]
|
||||
|
||||
return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp # Add batch dimension
|
||||
|
||||
def __len__(self):
|
||||
return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch
|
||||
|
||||
def __getitem__(self, itr):
|
||||
# ==============================================================================================
|
||||
# Randomize scene parameters
|
||||
# ==============================================================================================
|
||||
|
||||
if self.validate:
|
||||
mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
|
||||
camera_mv = None
|
||||
else:
|
||||
mv, mvp, campos, iter_res, iter_spp = self._random_scene()
|
||||
if self.random_lgt:
|
||||
rnd_rot = util.random_rotation()
|
||||
camera_mv = rnd_rot.unsqueeze(0).clone()
|
||||
elif self.camera_lgt:
|
||||
camera_mv = mv.clone()
|
||||
else:
|
||||
camera_mv = None
|
||||
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
|
||||
num_layers=self.FLAGS.layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=self.flat_shading)
|
||||
img = render_out['shaded']
|
||||
img_second = render_out['shaded_second']
|
||||
normal = render_out['normal']
|
||||
depth = render_out['depth']
|
||||
geo_normal = render_out['geo_normal']
|
||||
pos = render_out['pos']
|
||||
|
||||
sample_points = torch.tensor(kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, 50000)[0][0])
|
||||
vertex_points = self.ref_mesh.v_pos
|
||||
|
||||
return_dict = {
|
||||
'mv' : mv,
|
||||
'mvp' : mvp,
|
||||
'campos' : campos,
|
||||
'resolution' : iter_res,
|
||||
'spp' : iter_spp,
|
||||
'img' : img,
|
||||
'img_second' : img_second,
|
||||
'spts': sample_points,
|
||||
'vpts': vertex_points,
|
||||
'faces': self.ref_mesh.t_pos_idx,
|
||||
'depth': depth,
|
||||
'normal': normal,
|
||||
'geo_normal': geo_normal,
|
||||
'geo_viewdir': render_out['geo_viewdir'],
|
||||
'pos': pos,
|
||||
'envlight_transform': camera_mv,
|
||||
'mask': render_out['mask'],
|
||||
'mask_cont': render_out['mask_cont'],
|
||||
'rast_triangle_id': render_out['rast_triangle_id']
|
||||
}
|
||||
|
||||
try:
|
||||
return_dict['depth_second'] = render_out['depth_second']
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
return_dict['normal_second'] = render_out['normal_second']
|
||||
except:
|
||||
pass
|
||||
return return_dict
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
###############################################################################
|
||||
# ShapeNet Dataset to get meshes
|
||||
###############################################################################
|
||||
|
||||
class ShapeNetDataset(object):
|
||||
def __init__(self, shapenet_json):
|
||||
with open(shapenet_json, 'r') as f:
|
||||
self.mesh_list = json.load(f)
|
||||
|
||||
print(f"len all data: {len(self.mesh_list)}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.mesh_list)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.mesh_list[idx]
|
|
@ -0,0 +1,462 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..render import mesh
|
||||
from ..render import render
|
||||
from ..render import regularizer
|
||||
|
||||
|
||||
import kaolin
|
||||
import pytorch3d.ops
|
||||
from ..render import util as render_utils
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..render import renderutils as ru
|
||||
|
||||
###############################################################################
|
||||
# Marching tetrahedrons implementation (differentiable), adapted from
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
|
||||
###############################################################################
|
||||
|
||||
class DMTet:
|
||||
def __init__(self):
|
||||
self.triangle_table = torch.tensor([
|
||||
[-1, -1, -1, -1, -1, -1],
|
||||
[ 1, 0, 2, -1, -1, -1],
|
||||
[ 4, 0, 3, -1, -1, -1],
|
||||
[ 1, 4, 2, 1, 3, 4],
|
||||
[ 3, 1, 5, -1, -1, -1],
|
||||
[ 2, 3, 0, 2, 5, 3],
|
||||
[ 1, 4, 0, 1, 5, 4],
|
||||
[ 4, 2, 5, -1, -1, -1],
|
||||
[ 4, 5, 2, -1, -1, -1],
|
||||
[ 4, 1, 0, 4, 5, 1],
|
||||
[ 3, 2, 0, 3, 5, 2],
|
||||
[ 1, 3, 5, -1, -1, -1],
|
||||
[ 4, 1, 2, 4, 3, 1],
|
||||
[ 3, 0, 4, -1, -1, -1],
|
||||
[ 2, 0, 1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1]
|
||||
], dtype=torch.long, device='cuda')
|
||||
|
||||
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
|
||||
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
|
||||
|
||||
###############################################################################
|
||||
# Utility functions
|
||||
###############################################################################
|
||||
|
||||
def sort_edges(self, edges_ex2):
|
||||
with torch.no_grad():
|
||||
order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
|
||||
order = order.unsqueeze(dim=1)
|
||||
|
||||
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
||||
b = torch.gather(input=edges_ex2, index=1-order, dim=1)
|
||||
|
||||
return torch.stack([a, b],-1)
|
||||
|
||||
def map_uv(self, faces, face_gidx, max_idx):
|
||||
N = int(np.ceil(np.sqrt((max_idx+1)//2)))
|
||||
tex_y, tex_x = torch.meshgrid(
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
indexing='ij'
|
||||
)
|
||||
|
||||
pad = 0.9 / N
|
||||
|
||||
uvs = torch.stack([
|
||||
tex_x , tex_y,
|
||||
tex_x + pad, tex_y,
|
||||
tex_x + pad, tex_y + pad,
|
||||
tex_x , tex_y + pad
|
||||
], dim=-1).view(-1, 2)
|
||||
|
||||
def _idx(tet_idx, N):
|
||||
x = tet_idx % N
|
||||
y = torch.div(tet_idx, N, rounding_mode='trunc')
|
||||
return y * N + x
|
||||
|
||||
tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
|
||||
tri_idx = face_gidx % 2
|
||||
|
||||
uv_idx = torch.stack((
|
||||
tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
|
||||
), dim = -1). view(-1, 3)
|
||||
|
||||
return uvs, uv_idx
|
||||
|
||||
###############################################################################
|
||||
# Marching tets implementation
|
||||
###############################################################################
|
||||
|
||||
def __call__(self, pos_nx3, sdf_n, tet_fx4):
|
||||
with torch.no_grad():
|
||||
occ_n = sdf_n > 0
|
||||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
|
||||
occ_sum = torch.sum(occ_fx4, -1)
|
||||
valid_tets = (occ_sum>0) & (occ_sum<4)
|
||||
occ_sum = occ_sum[valid_tets]
|
||||
|
||||
# find all vertices
|
||||
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
|
||||
all_edges = self.sort_edges(all_edges)
|
||||
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
|
||||
|
||||
unique_edges = unique_edges.long()
|
||||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
|
||||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
|
||||
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
|
||||
idx_map = mapping[idx_map] # map edges to verts
|
||||
|
||||
interp_v = unique_edges[mask_edges]
|
||||
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
|
||||
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
|
||||
edges_to_interp_sdf[:,-1] *= -1
|
||||
|
||||
denominator = edges_to_interp_sdf.sum(1,keepdim = True)
|
||||
|
||||
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
|
||||
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
||||
|
||||
idx_map = idx_map.reshape(-1,6)
|
||||
|
||||
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
|
||||
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
||||
num_triangles = self.num_triangles_table[tetindex]
|
||||
|
||||
# Generate triangle indices
|
||||
faces = torch.cat((
|
||||
torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
|
||||
torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
|
||||
), dim=0)
|
||||
|
||||
# Get global face index (static, does not depend on topology)
|
||||
num_tets = tet_fx4.shape[0]
|
||||
tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
|
||||
face_gidx = torch.cat((
|
||||
tet_gidx[num_triangles == 1]*2,
|
||||
torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
|
||||
|
||||
face_to_valid_tet = torch.cat((
|
||||
tet_gidx[num_triangles == 1],
|
||||
torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()
|
||||
|
||||
return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx
|
||||
|
||||
###############################################################################
|
||||
# Regularizer
|
||||
###############################################################################
|
||||
|
||||
def sdf_reg_loss(sdf, all_edges):
|
||||
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
||||
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
||||
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
||||
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
||||
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
||||
return sdf_diff
|
||||
|
||||
|
||||
|
||||
class Buffer(object):
|
||||
def __init__(self, shape, capacity, device) -> None:
|
||||
self.len_curr = 0
|
||||
self.pointer = 0
|
||||
self.capacity = capacity
|
||||
self.buffer = torch.zeros((capacity, ) + shape, device=device)
|
||||
|
||||
def push(self, x):
|
||||
'''
|
||||
Push one single data point into the buffer
|
||||
'''
|
||||
self.buffer[self.pointer] = x
|
||||
self.pointer = (self.pointer + 1) % self.capacity
|
||||
if self.len_curr < self.capacity:
|
||||
self.len_curr += 1
|
||||
|
||||
def avg(self):
|
||||
# simple windowed avg without exp decay
|
||||
return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float()
|
||||
|
||||
###############################################################################
|
||||
# Geometry interface
|
||||
###############################################################################
|
||||
|
||||
class DMTetGeometry(torch.nn.Module):
|
||||
def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=1.0, **kwargs):
|
||||
super(DMTetGeometry, self).__init__()
|
||||
|
||||
self.FLAGS = FLAGS
|
||||
self.grid_res = grid_res
|
||||
self.marching_tets = DMTet()
|
||||
self.tanh = False
|
||||
self.deform_scale = deform_scale
|
||||
|
||||
self.grid_to_tet = grid_to_tet
|
||||
|
||||
self.padding = 5
|
||||
self.smooth_kernel = torch.ones(1, 1, self.padding*2 + 1, self.padding*2 + 1).cuda()
|
||||
|
||||
tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res)))
|
||||
self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
|
||||
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
|
||||
self.generate_edges()
|
||||
|
||||
# Random init
|
||||
sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1
|
||||
|
||||
self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
|
||||
self.register_parameter('sdf', self.sdf)
|
||||
|
||||
self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
|
||||
self.register_parameter('deform', self.deform)
|
||||
|
||||
self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False)
|
||||
self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
|
||||
|
||||
self.ema_coeff = 0.9
|
||||
self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda')
|
||||
|
||||
|
||||
def generate_edges(self):
|
||||
with torch.no_grad():
|
||||
edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
|
||||
all_edges = self.indices[:,edges].reshape(-1,2)
|
||||
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
|
||||
self.all_edges = torch.unique(all_edges_sorted, dim=0)
|
||||
|
||||
def getAABB(self):
|
||||
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
||||
|
||||
def getVertNNDist(self):
|
||||
v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0)
|
||||
return (pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2).dists[0, :, -1].detach()) ## K=2 because dist(self, self)=0
|
||||
|
||||
def getTetCenters(self):
|
||||
v_deformed = self.get_deformed() # size: N x 3
|
||||
face_verts = v_deformed[self.indices] # size: M x 4 x 3
|
||||
face_centers = face_verts.mean(dim=1) # size: M x 3
|
||||
|
||||
return face_centers
|
||||
|
||||
def getValidTetIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
|
||||
return tet_gidx.long()
|
||||
|
||||
def getValidVertsIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
|
||||
return self.indices[tet_gidx.long()].unique()
|
||||
|
||||
def getMesh(self, material, noise=0.0, ema=False):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed(ema=ema)
|
||||
|
||||
if ema:
|
||||
# sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff
|
||||
sdf = self.sdf_ema
|
||||
else:
|
||||
sdf = self.sdf
|
||||
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
if material is not None:
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
imesh.valid_vert_idx = valid_vert_idx
|
||||
|
||||
return imesh
|
||||
|
||||
def get_deformed(self, no_grad=False, ema=False):
|
||||
if no_grad:
|
||||
deform = self.deform.detach()
|
||||
else:
|
||||
deform = self.deform
|
||||
|
||||
if self.tanh:
|
||||
# v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale
|
||||
else:
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale
|
||||
return v_deformed
|
||||
|
||||
def get_angle(self):
|
||||
with torch.no_grad():
|
||||
comb_list = [
|
||||
(0, 1, 2, 3),
|
||||
(0, 1, 3, 2),
|
||||
(0, 2, 3, 1),
|
||||
(1, 2, 3, 0)
|
||||
]
|
||||
|
||||
directions = torch.zeros(self.indices.size(0), 4).cuda()
|
||||
dir_vec = torch.zeros(self.indices.size(0), 4, 3).cuda()
|
||||
vert_inds = torch.zeros(self.indices.size(0), 4).cuda().long()
|
||||
count = 0
|
||||
vpos_list = self.get_deformed()
|
||||
for comb in comb_list:
|
||||
face = self.indices[:, comb[:3]]
|
||||
face_pos = vpos_list[face, :]
|
||||
face_center = face_pos.mean(1, keepdim=False)
|
||||
v = self.indices[:, comb[3]]
|
||||
test_vec = vpos_list[v]
|
||||
ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center)
|
||||
distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec
|
||||
directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0])
|
||||
dir_vec[:, count, :] = distance_vec
|
||||
vert_inds[:, count] = v
|
||||
count += 1
|
||||
return directions, dir_vec, vert_inds
|
||||
|
||||
|
||||
def clamp_deform(self):
|
||||
if not self.tanh:
|
||||
self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)
|
||||
self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0)
|
||||
|
||||
def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
|
||||
opt_mesh = self.getMesh(opt_material, ema=ema)
|
||||
tet_centers = self.getTetCenters() if get_visible_tets else None
|
||||
return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers)
|
||||
|
||||
def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None):
|
||||
opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema)
|
||||
return opt_mesh, render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
def update_ema(self, ema_coeff=0.9):
|
||||
self.sdf_buffer.push(self.sdf)
|
||||
self.sdf_ema.data[:] = self.sdf_buffer.avg()
|
||||
self.deform_ema.data[:] = self.deform.data[:]
|
||||
|
||||
|
||||
def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
|
||||
opt_mesh = self.getMesh(opt_material, ema=True)
|
||||
return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
|
||||
|
||||
self.deform.requires_grad = True
|
||||
|
||||
if iteration > 200 and iteration < 2000 and iteration % 20 == 0:
|
||||
with torch.no_grad():
|
||||
v_pos = self.get_deformed()
|
||||
v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp'])
|
||||
v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:]
|
||||
v_pos_camera_discrete = torch.round((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long()
|
||||
mask_cont = F.conv2d(target['mask_cont'][:, :, :, 0].unsqueeze(1), self.smooth_kernel, stride=1, padding=self.padding)[:, 0]
|
||||
target_mask = mask_cont == 0
|
||||
for k in range(target_mask.size(0)):
|
||||
assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0]
|
||||
v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0))
|
||||
self.sdf.data[v_mask] = 1e-2
|
||||
self.deform.data[v_mask] = 0.0
|
||||
|
||||
# ==============================================================================================
|
||||
# Render optimizable object with identical conditions
|
||||
# ==============================================================================================
|
||||
imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt)
|
||||
|
||||
# ==============================================================================================
|
||||
# Compute loss
|
||||
# ==============================================================================================
|
||||
t_iter = iteration / self.FLAGS.iter
|
||||
|
||||
# Image-space loss, split into a coverage component and a color component
|
||||
color_ref = target['img']
|
||||
img_loss = torch.tensor(0.0).cuda()
|
||||
alpha_scale = 1.0
|
||||
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) * alpha_scale
|
||||
img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
|
||||
|
||||
|
||||
color_ref_second = target['img_second']
|
||||
img_loss = img_loss + torch.nn.functional.mse_loss(buffers['shaded_second'][..., 3:], color_ref_second[..., 3:]) * alpha_scale * 1e-1
|
||||
img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_second[..., 3:], color_ref_second[..., 0:3] * color_ref_second[..., 3:]) * 1e-1
|
||||
|
||||
mask = (target['mask_cont'][:, :, :, 0] == 1.0).float()
|
||||
|
||||
if iteration < 10000:
|
||||
depth_scale = 100.0
|
||||
else:
|
||||
depth_scale = 1.0
|
||||
|
||||
if iteration % 300 == 0 and iteration < 1790:
|
||||
self.deform.data[:] *= 0.4
|
||||
|
||||
if no_depth_thin:
|
||||
valid_depth_mask = (target['depth_second'] >= 0).float().detach()
|
||||
depth_prox_mask = ((target['depth_second'] - target['depth']).abs() >= 5e-3).float().detach()
|
||||
else:
|
||||
valid_depth_mask = 1.0
|
||||
|
||||
depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
|
||||
depth_diff_second = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * depth_prox_mask * 1e-1
|
||||
|
||||
thres = 1.0
|
||||
l1_loss_mask = (depth_diff < thres).float()
|
||||
l1_loss_mask_second = (depth_diff_second < thres).float()
|
||||
|
||||
img_loss = img_loss + (
|
||||
(
|
||||
l1_loss_mask * depth_diff
|
||||
+ (1 - l1_loss_mask) * (depth_diff.pow(2) + thres - thres**2)
|
||||
).mean() * 1.0 * depth_scale
|
||||
+ (
|
||||
l1_loss_mask_second * depth_diff_second
|
||||
+ (1 - l1_loss_mask_second) * (depth_diff_second.pow(2) + thres - thres**2)
|
||||
).mean() * 1.0 * depth_scale
|
||||
)
|
||||
|
||||
|
||||
reg_loss = torch.tensor(0.0).cuda()
|
||||
|
||||
# SDF regularizer
|
||||
iter_thres = 0
|
||||
sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres)))
|
||||
|
||||
sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device)
|
||||
sdf_mask[imesh.valid_vert_idx] = 1.0
|
||||
sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask)
|
||||
reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 0.1 # Dropoff to 0.01
|
||||
|
||||
# Albedo (k_d) smoothnesss regularizer
|
||||
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
|
||||
|
||||
# Visibility regularizer
|
||||
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500)
|
||||
|
||||
# pointcloud chamfer distance
|
||||
pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
|
||||
target_pts = target['spts']
|
||||
chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
|
||||
|
||||
reg_loss += chamfer
|
||||
|
||||
|
||||
return img_loss, reg_loss
|
|
@ -0,0 +1,350 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..render import mesh
|
||||
from ..render import render
|
||||
from ..render import regularizer
|
||||
|
||||
import kaolin
|
||||
from ..render import util as render_utils
|
||||
import torch.nn.functional as F
|
||||
|
||||
###############################################################################
|
||||
# Marching tetrahedrons implementation (differentiable), adapted from
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
|
||||
###############################################################################
|
||||
|
||||
class DMTet:
|
||||
def __init__(self):
|
||||
self.triangle_table = torch.tensor([
|
||||
[-1, -1, -1, -1, -1, -1],
|
||||
[ 1, 0, 2, -1, -1, -1],
|
||||
[ 4, 0, 3, -1, -1, -1],
|
||||
[ 1, 4, 2, 1, 3, 4],
|
||||
[ 3, 1, 5, -1, -1, -1],
|
||||
[ 2, 3, 0, 2, 5, 3],
|
||||
[ 1, 4, 0, 1, 5, 4],
|
||||
[ 4, 2, 5, -1, -1, -1],
|
||||
[ 4, 5, 2, -1, -1, -1],
|
||||
[ 4, 1, 0, 4, 5, 1],
|
||||
[ 3, 2, 0, 3, 5, 2],
|
||||
[ 1, 3, 5, -1, -1, -1],
|
||||
[ 4, 1, 2, 4, 3, 1],
|
||||
[ 3, 0, 4, -1, -1, -1],
|
||||
[ 2, 0, 1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1]
|
||||
], dtype=torch.long, device='cuda')
|
||||
|
||||
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
|
||||
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
|
||||
|
||||
###############################################################################
|
||||
# Utility functions
|
||||
###############################################################################
|
||||
|
||||
def sort_edges(self, edges_ex2):
|
||||
with torch.no_grad():
|
||||
order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
|
||||
order = order.unsqueeze(dim=1)
|
||||
|
||||
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
||||
b = torch.gather(input=edges_ex2, index=1-order, dim=1)
|
||||
|
||||
return torch.stack([a, b],-1)
|
||||
|
||||
def map_uv(self, faces, face_gidx, max_idx):
|
||||
N = int(np.ceil(np.sqrt((max_idx+1)//2)))
|
||||
tex_y, tex_x = torch.meshgrid(
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
indexing='ij'
|
||||
)
|
||||
|
||||
pad = 0.9 / N
|
||||
|
||||
uvs = torch.stack([
|
||||
tex_x , tex_y,
|
||||
tex_x + pad, tex_y,
|
||||
tex_x + pad, tex_y + pad,
|
||||
tex_x , tex_y + pad
|
||||
], dim=-1).view(-1, 2)
|
||||
|
||||
def _idx(tet_idx, N):
|
||||
x = tet_idx % N
|
||||
y = torch.div(tet_idx, N, rounding_mode='trunc')
|
||||
return y * N + x
|
||||
|
||||
tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
|
||||
tri_idx = face_gidx % 2
|
||||
|
||||
uv_idx = torch.stack((
|
||||
tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
|
||||
), dim = -1). view(-1, 3)
|
||||
|
||||
return uvs, uv_idx
|
||||
|
||||
###############################################################################
|
||||
# Marching tets implementation
|
||||
###############################################################################
|
||||
|
||||
def __call__(self, pos_nx3, sdf_n, tet_fx4, get_tet_gidx=False):
|
||||
with torch.no_grad():
|
||||
occ_n = sdf_n > 0
|
||||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
|
||||
occ_sum = torch.sum(occ_fx4, -1)
|
||||
valid_tets = (occ_sum>0) & (occ_sum<4)
|
||||
occ_sum = occ_sum[valid_tets]
|
||||
|
||||
# find all vertices
|
||||
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
|
||||
all_edges = self.sort_edges(all_edges)
|
||||
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
|
||||
|
||||
unique_edges = unique_edges.long()
|
||||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
|
||||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
|
||||
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
|
||||
idx_map = mapping[idx_map] # map edges to verts
|
||||
|
||||
interp_v = unique_edges[mask_edges]
|
||||
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
|
||||
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
|
||||
edges_to_interp_sdf[:,-1] *= -1
|
||||
|
||||
denominator = edges_to_interp_sdf.sum(1,keepdim = True)
|
||||
|
||||
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
|
||||
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
||||
|
||||
idx_map = idx_map.reshape(-1,6)
|
||||
|
||||
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
|
||||
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
||||
num_triangles = self.num_triangles_table[tetindex]
|
||||
|
||||
# Generate triangle indices
|
||||
faces = torch.cat((
|
||||
torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
|
||||
torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
|
||||
), dim=0)
|
||||
|
||||
# Get global face index (static, does not depend on topology)
|
||||
num_tets = tet_fx4.shape[0]
|
||||
tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
|
||||
face_gidx = torch.cat((
|
||||
tet_gidx[num_triangles == 1]*2,
|
||||
torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
|
||||
|
||||
if get_tet_gidx:
|
||||
face_to_valid_tet = torch.cat((
|
||||
tet_gidx[num_triangles == 1],
|
||||
torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
return verts, faces, uvs, uv_idx, face_to_valid_tet.long()
|
||||
else:
|
||||
return verts, faces, uvs, uv_idx
|
||||
|
||||
###############################################################################
|
||||
# Regularizer
|
||||
###############################################################################
|
||||
|
||||
def sdf_reg_loss(sdf, all_edges):
|
||||
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
||||
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
||||
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
||||
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
||||
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
||||
return sdf_diff
|
||||
|
||||
###############################################################################
|
||||
# Geometry interface
|
||||
###############################################################################
|
||||
|
||||
class DMTetGeometryFixedTopo(torch.nn.Module):
|
||||
def __init__(self, dmt_geometry, base_mesh, grid_res, scale, FLAGS, deform_scale=1.0, **kwargs):
|
||||
super(DMTetGeometryFixedTopo, self).__init__()
|
||||
|
||||
self.FLAGS = FLAGS
|
||||
self.grid_res = grid_res
|
||||
self.marching_tets = DMTet()
|
||||
self.initial_guess = base_mesh
|
||||
self.scale = scale
|
||||
self.tanh = False
|
||||
self.deform_scale = deform_scale
|
||||
|
||||
tets = np.load('./data/tets/{}_tets_cropped.npz'.format(self.grid_res))
|
||||
|
||||
self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
|
||||
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
|
||||
self.generate_edges()
|
||||
|
||||
self.sdf_sign = torch.nn.Parameter(torch.sign(dmt_geometry.sdf.data + 1e-8).float(), requires_grad=False)
|
||||
self.sdf_sign.data[self.sdf_sign.data == 0] = 1.0 ## Avoid abiguity
|
||||
self.register_parameter('sdf_sign', self.sdf_sign)
|
||||
|
||||
self.sdf_abs = torch.nn.Parameter(torch.ones_like(dmt_geometry.sdf), requires_grad=False)
|
||||
self.register_parameter('sdf_abs', self.sdf_abs)
|
||||
|
||||
self.deform = torch.nn.Parameter(dmt_geometry.deform.data, requires_grad=True)
|
||||
self.register_parameter('deform', self.deform)
|
||||
|
||||
self.sdf_abs_ema = torch.nn.Parameter(self.sdf_abs.clone().detach(), requires_grad=False)
|
||||
self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
|
||||
|
||||
def set_init_v_pos(self):
|
||||
with torch.no_grad():
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices)
|
||||
self.initial_guess_v_pos = verts
|
||||
|
||||
|
||||
def generate_edges(self):
|
||||
with torch.no_grad():
|
||||
edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
|
||||
all_edges = self.indices[:,edges].reshape(-1,2)
|
||||
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
|
||||
self.all_edges = torch.unique(all_edges_sorted, dim=0)
|
||||
|
||||
def getAABB(self):
|
||||
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
||||
|
||||
def getVertNNDist(self):
|
||||
raise NotImplementedError
|
||||
v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0)
|
||||
return (pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2).dists[0, :, -1].detach()) ## K=2 because dist(self, self)=0
|
||||
|
||||
def getMesh(self, material):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
|
||||
return imesh
|
||||
|
||||
def getMesh_tet_gidx(self, material):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
|
||||
v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
|
||||
return imesh, tet_gidx
|
||||
|
||||
|
||||
def update_ema(self, ema_coeff=0.9):
|
||||
return
|
||||
|
||||
def get_deformed(self):
|
||||
if self.tanh:
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) * self.deform_scale
|
||||
else:
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * self.deform * self.deform_scale
|
||||
return v_deformed
|
||||
|
||||
def getValidTetIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
|
||||
v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
|
||||
return tet_gidx.long()
|
||||
|
||||
def getValidVertsIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
|
||||
v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
|
||||
return self.indices[tet_gidx.long()].unique()
|
||||
|
||||
def getTetCenters(self):
|
||||
v_deformed = self.get_deformed() # size: N x 3
|
||||
face_verts = v_deformed[self.indices] # size: M x 4 x 3
|
||||
face_centers = face_verts.mean(dim=1) # size: M x 3
|
||||
|
||||
return face_centers
|
||||
|
||||
def clamp_deform(self):
|
||||
if not self.tanh:
|
||||
self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)
|
||||
|
||||
def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
|
||||
opt_mesh = self.getMesh(opt_material)
|
||||
tet_centers = self.getTetCenters() if get_visible_tets else None
|
||||
return render.render_mesh(
|
||||
glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers)
|
||||
|
||||
|
||||
def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
|
||||
opt_mesh = self.getMesh(opt_material)
|
||||
return opt_mesh, render.render_mesh(
|
||||
glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
|
||||
|
||||
# ==============================================================================================
|
||||
# Render optimizable object with identical conditions
|
||||
# ==============================================================================================
|
||||
imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, xfm_lgt=xfm_lgt)
|
||||
|
||||
# ==============================================================================================
|
||||
# Compute loss
|
||||
# ==============================================================================================
|
||||
t_iter = iteration / self.FLAGS.iter
|
||||
|
||||
# Image-space loss, split into a coverage component and a color component
|
||||
color_ref = target['img']
|
||||
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
|
||||
img_loss = img_loss + loss_fn(
|
||||
buffers['shaded'][..., 0:3] * color_ref[..., 3:],
|
||||
color_ref[..., 0:3] * color_ref[..., 3:]
|
||||
)
|
||||
mask = target['mask'][:, :, :, 0]
|
||||
|
||||
|
||||
if no_depth_thin:
|
||||
valid_depth_mask = (
|
||||
(target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float()
|
||||
).detach()
|
||||
else:
|
||||
valid_depth_mask = 1.0
|
||||
|
||||
depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
|
||||
depth_diff = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * 1e-1
|
||||
|
||||
l1_loss_mask = (depth_diff < 1.0).float()
|
||||
img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0
|
||||
|
||||
reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda")
|
||||
|
||||
# Compute regularizer.
|
||||
reg_loss += regularizer.laplace_regularizer_const(imesh.v_pos - self.initial_guess_v_pos, imesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter) * 1e-2
|
||||
|
||||
### Chamfer distance for ShapeNet
|
||||
pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
|
||||
target_pts = target['spts']
|
||||
chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
|
||||
reg_loss += chamfer
|
||||
|
||||
return img_loss, reg_loss
|
|
@ -0,0 +1,516 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..render import mesh
|
||||
from ..render import render
|
||||
from ..render import regularizer
|
||||
|
||||
|
||||
import kaolin
|
||||
import pytorch3d.ops
|
||||
from ..render import util as render_utils
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..render import renderutils as ru
|
||||
|
||||
###############################################################################
|
||||
# Marching tetrahedrons implementation (differentiable), adapted from
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
|
||||
###############################################################################
|
||||
|
||||
class DMTet:
|
||||
def __init__(self):
|
||||
self.triangle_table = torch.tensor([
|
||||
[-1, -1, -1, -1, -1, -1],
|
||||
[ 1, 0, 2, -1, -1, -1],
|
||||
[ 4, 0, 3, -1, -1, -1],
|
||||
[ 1, 4, 2, 1, 3, 4],
|
||||
[ 3, 1, 5, -1, -1, -1],
|
||||
[ 2, 3, 0, 2, 5, 3],
|
||||
[ 1, 4, 0, 1, 5, 4],
|
||||
[ 4, 2, 5, -1, -1, -1],
|
||||
[ 4, 5, 2, -1, -1, -1],
|
||||
[ 4, 1, 0, 4, 5, 1],
|
||||
[ 3, 2, 0, 3, 5, 2],
|
||||
[ 1, 3, 5, -1, -1, -1],
|
||||
[ 4, 1, 2, 4, 3, 1],
|
||||
[ 3, 0, 4, -1, -1, -1],
|
||||
[ 2, 0, 1, -1, -1, -1],
|
||||
[-1, -1, -1, -1, -1, -1]
|
||||
], dtype=torch.long, device='cuda')
|
||||
|
||||
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
|
||||
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
|
||||
|
||||
###############################################################################
|
||||
# Utility functions
|
||||
###############################################################################
|
||||
|
||||
def sort_edges(self, edges_ex2):
|
||||
with torch.no_grad():
|
||||
order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
|
||||
order = order.unsqueeze(dim=1)
|
||||
|
||||
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
||||
b = torch.gather(input=edges_ex2, index=1-order, dim=1)
|
||||
|
||||
return torch.stack([a, b],-1)
|
||||
|
||||
def map_uv(self, faces, face_gidx, max_idx):
|
||||
N = int(np.ceil(np.sqrt((max_idx+1)//2)))
|
||||
tex_y, tex_x = torch.meshgrid(
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
||||
indexing='ij'
|
||||
)
|
||||
|
||||
pad = 0.9 / N
|
||||
|
||||
uvs = torch.stack([
|
||||
tex_x , tex_y,
|
||||
tex_x + pad, tex_y,
|
||||
tex_x + pad, tex_y + pad,
|
||||
tex_x , tex_y + pad
|
||||
], dim=-1).view(-1, 2)
|
||||
|
||||
def _idx(tet_idx, N):
|
||||
x = tet_idx % N
|
||||
y = torch.div(tet_idx, N, rounding_mode='trunc')
|
||||
return y * N + x
|
||||
|
||||
tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
|
||||
tri_idx = face_gidx % 2
|
||||
|
||||
uv_idx = torch.stack((
|
||||
tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
|
||||
), dim = -1). view(-1, 3)
|
||||
|
||||
return uvs, uv_idx
|
||||
|
||||
###############################################################################
|
||||
# Marching tets implementation
|
||||
###############################################################################
|
||||
|
||||
def __call__(self, pos_nx3, sdf_n, tet_fx4):
|
||||
with torch.no_grad():
|
||||
occ_n = sdf_n > 0
|
||||
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
|
||||
occ_sum = torch.sum(occ_fx4, -1)
|
||||
valid_tets = (occ_sum>0) & (occ_sum<4)
|
||||
occ_sum = occ_sum[valid_tets]
|
||||
|
||||
# find all vertices
|
||||
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
|
||||
all_edges = self.sort_edges(all_edges)
|
||||
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
|
||||
|
||||
unique_edges = unique_edges.long()
|
||||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
|
||||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
|
||||
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
|
||||
idx_map = mapping[idx_map] # map edges to verts
|
||||
|
||||
interp_v = unique_edges[mask_edges]
|
||||
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
|
||||
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
|
||||
edges_to_interp_sdf[:,-1] *= -1
|
||||
|
||||
denominator = edges_to_interp_sdf.sum(1,keepdim = True)
|
||||
|
||||
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
|
||||
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
||||
idx_map = idx_map.reshape(-1,6)
|
||||
|
||||
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
|
||||
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
||||
num_triangles = self.num_triangles_table[tetindex]
|
||||
|
||||
# Generate triangle indices
|
||||
faces = torch.cat((
|
||||
torch.gather(
|
||||
input=idx_map[num_triangles == 1],
|
||||
dim=1,
|
||||
index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]
|
||||
).reshape(-1,3),
|
||||
torch.gather(
|
||||
input=idx_map[num_triangles == 2],
|
||||
dim=1,
|
||||
index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]
|
||||
).reshape(-1,3),
|
||||
), dim=0)
|
||||
|
||||
# Get global face index (static, does not depend on topology)
|
||||
num_tets = tet_fx4.shape[0]
|
||||
tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
|
||||
face_gidx = torch.cat((
|
||||
tet_gidx[num_triangles == 1]*2,
|
||||
torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
|
||||
|
||||
|
||||
face_to_valid_tet = torch.cat((
|
||||
tet_gidx[num_triangles == 1],
|
||||
torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
|
||||
), dim=0)
|
||||
|
||||
valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()
|
||||
|
||||
return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx
|
||||
|
||||
###############################################################################
|
||||
# Regularizer
|
||||
###############################################################################
|
||||
|
||||
def sdf_reg_loss(sdf, all_edges):
|
||||
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
||||
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
||||
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
||||
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
||||
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
||||
return sdf_diff
|
||||
|
||||
|
||||
|
||||
class Buffer(object):
|
||||
def __init__(self, shape, capacity, device) -> None:
|
||||
self.len_curr = 0
|
||||
self.pointer = 0
|
||||
self.capacity = capacity
|
||||
self.buffer = torch.zeros((capacity, ) + shape, device=device)
|
||||
|
||||
def push(self, x):
|
||||
'''
|
||||
Push one single data point into the buffer
|
||||
'''
|
||||
self.buffer[self.pointer] = x
|
||||
self.pointer = (self.pointer + 1) % self.capacity
|
||||
if self.len_curr < self.capacity:
|
||||
self.len_curr += 1
|
||||
|
||||
def avg(self):
|
||||
return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float()
|
||||
# return self.buffer[:self.len_curr].mean(dim=0)
|
||||
# return self.buffer[:self.len_curr][-1]
|
||||
|
||||
###############################################################################
|
||||
# Geometry interface
|
||||
###############################################################################
|
||||
|
||||
class DMTetGeometry(torch.nn.Module):
|
||||
def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=2.0, **kwargs):
|
||||
super(DMTetGeometry, self).__init__()
|
||||
|
||||
self.FLAGS = FLAGS
|
||||
self.grid_res = grid_res
|
||||
self.marching_tets = DMTet()
|
||||
self.cropped = True
|
||||
self.tanh = False
|
||||
self.deform_scale = deform_scale
|
||||
|
||||
self.grid_to_tet = grid_to_tet
|
||||
|
||||
if self.cropped:
|
||||
print("use cropped tets")
|
||||
tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res)))
|
||||
else:
|
||||
tets = np.load(os.path.join(root, 'data/tets/{}_tets.npz'.format(self.grid_res)))
|
||||
print('tet min and max', tets['vertices'].min() * scale, tets['vertices'].max() * scale)
|
||||
self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
|
||||
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
|
||||
self.generate_edges()
|
||||
|
||||
# Random init
|
||||
sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1
|
||||
# sdf = torch.sign(sdf) * 0.1
|
||||
# sdf = self.verts.pow(2).sum(dim=-1).sqrt() - 0.5
|
||||
|
||||
self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
|
||||
self.register_parameter('sdf', self.sdf)
|
||||
|
||||
self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
|
||||
self.register_parameter('deform', self.deform)
|
||||
|
||||
self.alpha = None
|
||||
|
||||
self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False)
|
||||
self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
|
||||
|
||||
# self.ema_coeff = 0.7
|
||||
self.ema_coeff = 0.9
|
||||
|
||||
self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda')
|
||||
|
||||
def generate_edges(self):
|
||||
with torch.no_grad():
|
||||
edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
|
||||
all_edges = self.indices[:,edges].reshape(-1,2)
|
||||
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
|
||||
self.all_edges = torch.unique(all_edges_sorted, dim=0)
|
||||
|
||||
def getAABB(self):
|
||||
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
||||
|
||||
def getVertNNDist(self):
|
||||
v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0)
|
||||
return pytorch3d.ops.knn.knn_points(
|
||||
v_deformed, v_deformed, K=2
|
||||
).dists[0, :, -1].detach() ## K=2 because dist(self, self)=0
|
||||
|
||||
def getTetCenters(self):
|
||||
v_deformed = self.get_deformed() # size: N x 3
|
||||
face_verts = v_deformed[self.indices] # size: M x 4 x 3
|
||||
face_centers = face_verts.mean(dim=1) # size: M x 3
|
||||
|
||||
return face_centers
|
||||
|
||||
def getValidTetIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
|
||||
return tet_gidx.long()
|
||||
|
||||
def getValidVertsIdx(self):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed()
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices)
|
||||
return self.indices[tet_gidx.long()].unique()
|
||||
|
||||
def getMesh(self, material, noise=0.0, ema=False):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed(ema=ema)
|
||||
|
||||
if ema:
|
||||
sdf = self.sdf_ema
|
||||
else:
|
||||
sdf = self.sdf
|
||||
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
if material is not None:
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
imesh.valid_vert_idx = valid_vert_idx
|
||||
|
||||
return imesh
|
||||
|
||||
def getMesh_no_deform(self, material, noise=0.0, ema=False):
|
||||
# Run DM tet to get a base mesh
|
||||
if ema:
|
||||
# sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff
|
||||
sdf = self.sdf_ema
|
||||
else:
|
||||
sdf = self.sdf
|
||||
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(self.verts, torch.sign(sdf), self.indices)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
|
||||
return imesh
|
||||
|
||||
def getMesh_no_deform_gd(self, material, noise=0.0, ema=False):
|
||||
# Run DM tet to get a base mesh
|
||||
v_deformed = self.get_deformed(no_grad=True)
|
||||
|
||||
|
||||
if ema:
|
||||
# sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff
|
||||
sdf = self.sdf_ema
|
||||
else:
|
||||
sdf = self.sdf
|
||||
|
||||
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
|
||||
imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
||||
|
||||
# Run mesh operations to generate tangent space
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
imesh = mesh.compute_tangents(imesh)
|
||||
|
||||
return imesh
|
||||
|
||||
def get_deformed(self, no_grad=False, ema=False):
|
||||
if no_grad:
|
||||
deform = self.deform.detach()
|
||||
else:
|
||||
deform = self.deform
|
||||
|
||||
if self.tanh:
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale
|
||||
else:
|
||||
v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale
|
||||
return v_deformed
|
||||
|
||||
def get_angle(self):
|
||||
with torch.no_grad():
|
||||
comb_list = [
|
||||
(0, 1, 2, 3),
|
||||
(0, 1, 3, 2),
|
||||
(0, 2, 3, 1),
|
||||
(1, 2, 3, 0)
|
||||
]
|
||||
|
||||
directions = torch.zeros(self.indices.size(0), 4).cuda()
|
||||
dir_vec = torch.zeros(self.indices.size(0), 4, 3).cuda()
|
||||
vert_inds = torch.zeros(self.indices.size(0), 4).cuda().long()
|
||||
count = 0
|
||||
vpos_list = self.get_deformed()
|
||||
for comb in comb_list:
|
||||
face = self.indices[:, comb[:3]]
|
||||
face_pos = vpos_list[face, :]
|
||||
face_center = face_pos.mean(1, keepdim=False)
|
||||
v = self.indices[:, comb[3]]
|
||||
test_vec = vpos_list[v]
|
||||
ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center)
|
||||
distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec
|
||||
directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0])
|
||||
dir_vec[:, count, :] = distance_vec
|
||||
vert_inds[:, count] = v
|
||||
count += 1
|
||||
return directions, dir_vec, vert_inds
|
||||
|
||||
|
||||
def clamp_deform(self):
|
||||
if not self.tanh:
|
||||
self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)
|
||||
self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0)
|
||||
|
||||
def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
|
||||
opt_mesh = self.getMesh(opt_material, ema=ema)
|
||||
tet_centers = self.getTetCenters() if get_visible_tets else None
|
||||
return render.render_mesh(
|
||||
glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers)
|
||||
|
||||
def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None):
|
||||
opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema)
|
||||
return opt_mesh, render.render_mesh(
|
||||
glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
def update_ema(self, ema_coeff=0.9):
|
||||
self.sdf_buffer.push(self.sdf)
|
||||
self.sdf_ema.data[:] = self.sdf_buffer.avg()
|
||||
# self.sdf_ema.data[:] = self.sdf.data[:] * (1 - ema_coeff) + self.sdf_ema.data[:] * ema_coeff
|
||||
# self.deform_ema.data[:] = self.deform.data[:] * (1 - ema_coeff) + self.deform_ema.data[:] * ema_coeff
|
||||
self.deform_ema.data[:] = self.deform.data[:]
|
||||
|
||||
|
||||
def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
|
||||
opt_mesh = self.getMesh(opt_material, ema=True)
|
||||
return render.render_mesh(
|
||||
glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
||||
msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
def init_with_gt_surface(self, gt_verts, surface_faces, campos):
|
||||
with torch.no_grad():
|
||||
surface_face_verts = gt_verts[surface_faces]
|
||||
surface_centers = surface_face_verts.mean(dim=1)
|
||||
v_pos = self.get_deformed()
|
||||
results = pytorch3d.ops.knn_points(v_pos[None, ...], surface_centers[None, ...])
|
||||
dists, nn_idx = results.dists, results.idx
|
||||
displacement = (v_pos - surface_centers[nn_idx[0, :, 0]])
|
||||
view_dirs = campos - surface_centers
|
||||
normals = torch.cross(
|
||||
surface_face_verts[:, 0] - surface_face_verts[:, 1], surface_face_verts[:, 0] - surface_face_verts[:, 2])
|
||||
mask = ((normals * view_dirs).sum(dim=-1, keepdim=True) >= 0).float()
|
||||
normals = normals * mask - normals * (1 - mask)
|
||||
outside_verts_idx = ((displacement * normals[nn_idx[0, :, 0]]).sum(dim=-1) > 0)
|
||||
self.sdf.data[outside_verts_idx] = 1.0
|
||||
|
||||
|
||||
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
|
||||
|
||||
if iteration < 100:
|
||||
self.deform.requires_grad = False
|
||||
self.deform_scale = 2.0
|
||||
else:
|
||||
self.deform.requires_grad = True
|
||||
self.deform_scale = 2.0
|
||||
|
||||
if iteration > 200 and iteration < 2000 and iteration % 20 == 0:
|
||||
with torch.no_grad():
|
||||
v_pos = self.get_deformed()
|
||||
v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp'])
|
||||
v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:]
|
||||
v_pos_camera_discrete = ((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long()
|
||||
target_mask = target['mask_cont'][:, :, :, 0] == 0
|
||||
for k in range(target_mask.size(0)):
|
||||
assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0]
|
||||
v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0))
|
||||
# print(v_mask.sum())
|
||||
self.sdf.data[v_mask] = self.sdf.data[v_mask].abs().clamp(0.0, 1.0)
|
||||
|
||||
# ==============================================================================================
|
||||
# Render optimizable object with identical conditions
|
||||
# ==============================================================================================
|
||||
imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt)
|
||||
|
||||
# ==============================================================================================
|
||||
# Compute loss
|
||||
# ==============================================================================================
|
||||
|
||||
# Image-space loss, split into a coverage component and a color component
|
||||
color_ref = target['img']
|
||||
img_loss = torch.tensor(0.0).cuda()
|
||||
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
|
||||
img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
|
||||
mask = (target['mask_cont'][:, :, :, 0] == 1.0).float()
|
||||
mask_curr = (buffers['mask_cont'][:, :, :, 0] == 1.0).float()
|
||||
|
||||
if iteration % 300 == 0 and iteration < 1790:
|
||||
self.deform.data[:] *= 0.4
|
||||
|
||||
if no_depth_thin:
|
||||
valid_depth_mask = (
|
||||
(target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float()
|
||||
).detach()
|
||||
else:
|
||||
valid_depth_mask = 1.0
|
||||
|
||||
|
||||
depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
|
||||
l1_loss_mask = (depth_diff < 1.0).float()
|
||||
img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0
|
||||
|
||||
reg_loss = torch.tensor(0.0).cuda()
|
||||
|
||||
# SDF regularizer
|
||||
iter_thres = 0
|
||||
sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres)))
|
||||
|
||||
sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device)
|
||||
sdf_mask[imesh.valid_vert_idx] = 1.0
|
||||
sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask)
|
||||
reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 2.5 # Dropoff to 0.01
|
||||
|
||||
|
||||
# Albedo (k_d) smoothnesss regularizer
|
||||
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
|
||||
|
||||
# Visibility regularizer
|
||||
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500)
|
||||
|
||||
pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
|
||||
target_pts = target['spts']
|
||||
chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
|
||||
reg_loss += chamfer
|
||||
|
||||
|
||||
return img_loss, reg_loss
|
|
@ -0,0 +1,128 @@
|
|||
import torch
|
||||
|
||||
def _base_sample_points_selected_faces(face_vertices, face_features=None):
|
||||
"""Base function to sample points over selected faces.
|
||||
The coordinates of the face vertices are interpolated to generate new samples.
|
||||
Args:
|
||||
face_vertices (tuple of torch.Tensor):
|
||||
Coordinates of vertices, corresponding to selected faces to sample from.
|
||||
A tuple of 3 entries corresponding to each of the face vertices.
|
||||
Each entry is a torch.Tensor of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`.
|
||||
face_features (tuple of torch.Tensor, Optional):
|
||||
Features of face vertices, corresponding to selected faces to sample from.
|
||||
A tuple of 3 entries corresponding to each of the face vertices.
|
||||
Each entry is a torch.Tensor of shape
|
||||
:math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`.
|
||||
Returns:
|
||||
(torch.Tensor, torch.Tensor):
|
||||
Sampled point coordinates of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`.
|
||||
Sampled points interpolated features of shape
|
||||
:math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`.
|
||||
If `face_vertices_features` arg is not specified, the returned interpolated features are None.
|
||||
"""
|
||||
|
||||
face_vertices0, face_vertices1, face_vertices2 = face_vertices
|
||||
|
||||
sampling_shape = tuple(int(d) for d in face_vertices0.shape[:-1]) + (1,)
|
||||
# u is proximity to middle point between v1 and v2 against v0.
|
||||
# v is proximity to v2 against v1.
|
||||
#
|
||||
# The probability density for u should be f_U(u) = 2u.
|
||||
# However, torch.rand use a uniform (f_X(x) = x) distribution,
|
||||
# so using torch.sqrt we make a change of variable to have the desired density
|
||||
# f_Y(y) = f_X(y ^ 2) * |d(y ^ 2) / dy| = 2y
|
||||
u = torch.sqrt(torch.rand(sampling_shape,
|
||||
device=face_vertices0.device,
|
||||
dtype=face_vertices0.dtype))
|
||||
|
||||
v = torch.rand(sampling_shape,
|
||||
device=face_vertices0.device,
|
||||
dtype=face_vertices0.dtype)
|
||||
w0 = 1 - u
|
||||
w1 = u * (1 - v)
|
||||
w2 = u * v
|
||||
|
||||
points = w0 * face_vertices0 + w1 * face_vertices1 + w2 * face_vertices2
|
||||
|
||||
features = None
|
||||
if face_features is not None:
|
||||
face_features0, face_features1, face_features2 = face_features
|
||||
features = w0 * face_features0 + w1 * face_features1 + \
|
||||
w2 * face_features2
|
||||
|
||||
return points, features
|
||||
|
||||
def sample_points(vertices, faces, num_samples, areas=None, face_features=None):
|
||||
r"""Uniformly sample points over the surface of triangle meshes.
|
||||
First face on which the point is sampled is randomly selected,
|
||||
with the probability of selection being proportional to the area of the face.
|
||||
then the coordinate on the face is uniformly sampled.
|
||||
If ``face_features`` is defined for the mesh faces,
|
||||
the sampled points will be returned with interpolated features as well,
|
||||
otherwise, no feature interpolation will occur.
|
||||
Args:
|
||||
vertices (torch.Tensor):
|
||||
The vertices of the meshes, of shape
|
||||
:math:`(\text{batch_size}, \text{num_vertices}, 3)`.
|
||||
faces (torch.LongTensor):
|
||||
The faces of the mesh, of shape :math:`(\text{num_faces}, 3)`.
|
||||
num_samples (int):
|
||||
The number of point sampled per mesh.
|
||||
areas (torch.Tensor, optional):
|
||||
The areas of each face, of shape :math:`(\text{batch_size}, \text{num_faces})`,
|
||||
can be preprocessed, for fast on-the-fly sampling,
|
||||
will be computed if None (default).
|
||||
face_features (torch.Tensor, optional):
|
||||
Per-vertex-per-face features, matching ``faces`` order,
|
||||
of shape :math:`(\text{batch_size}, \text{num_faces}, 3, \text{feature_dim})`.
|
||||
For example:
|
||||
1. Texture uv coordinates would be of shape
|
||||
:math:`(\text{batch_size}, \text{num_faces}, 3, 2)`.
|
||||
2. RGB color values would be of shape
|
||||
:math:`(\text{batch_size}, \text{num_faces}, 3, 3)`.
|
||||
When specified, it is used to interpolate the features for new sampled points.
|
||||
See also:
|
||||
:func:`~kaolin.ops.mesh.index_vertices_by_faces` for conversion of features defined per vertex
|
||||
and need to be converted to per-vertex-per-face shape of :math:`(\text{num_faces}, 3)`.
|
||||
Returns:
|
||||
(torch.Tensor, torch.LongTensor, (optional) torch.Tensor):
|
||||
the pointclouds of shape :math:`(\text{batch_size}, \text{num_samples}, 3)`,
|
||||
and the indexes of the faces selected,
|
||||
of shape :math:`(\text{batch_size}, \text{num_samples})`.
|
||||
If ``face_features`` arg is specified, then the interpolated features of sampled points of shape
|
||||
:math:`(\text{batch_size}, \text{num_samples}, \text{feature_dim})` are also returned.
|
||||
"""
|
||||
if faces.shape[-1] != 3:
|
||||
raise NotImplementedError("sample_points is only implemented for triangle meshes")
|
||||
faces_0, faces_1, faces_2 = torch.split(faces, 1, dim=1) # (num_faces, 3) -> tuple of (num_faces,)
|
||||
face_v_0 = torch.index_select(vertices, 1, faces_0.reshape(-1)) # (batch_size, num_faces, 3)
|
||||
face_v_1 = torch.index_select(vertices, 1, faces_1.reshape(-1)) # (batch_size, num_faces, 3)
|
||||
face_v_2 = torch.index_select(vertices, 1, faces_2.reshape(-1)) # (batch_size, num_faces, 3)
|
||||
|
||||
if areas is None:
|
||||
areas = _base_face_areas(face_v_0, face_v_1, face_v_2).squeeze(-1)
|
||||
face_dist = torch.distributions.Categorical(areas)
|
||||
face_choices = face_dist.sample([num_samples]).transpose(0, 1)
|
||||
_face_choices = face_choices.unsqueeze(-1).repeat(1, 1, 3)
|
||||
v0 = torch.gather(face_v_0, 1, _face_choices) # (batch_size, num_samples, 3)
|
||||
v1 = torch.gather(face_v_1, 1, _face_choices) # (batch_size, num_samples, 3)
|
||||
v2 = torch.gather(face_v_2, 1, _face_choices) # (batch_size, num_samples, 3)
|
||||
face_vertices_choices = (v0, v1, v2)
|
||||
|
||||
# UV coordinates are available, make sure to calculate them for sampled points as well
|
||||
face_features_choices = None
|
||||
if face_features is not None:
|
||||
feat_dim = face_features.shape[-1]
|
||||
# (num_faces, 3) -> tuple of (num_faces,)
|
||||
_face_choices = face_choices[..., None, None].repeat(1, 1, 3, feat_dim)
|
||||
face_features_choices = torch.gather(face_features, 1, _face_choices)
|
||||
face_features_choices = tuple(
|
||||
tmp_feat.squeeze(2) for tmp_feat in torch.split(face_features_choices, 1, dim=2))
|
||||
|
||||
points, point_features = _base_sample_points_selected_faces(
|
||||
face_vertices_choices, face_features_choices)
|
||||
|
||||
if point_features is not None:
|
||||
return points, face_choices, point_features
|
||||
else:
|
||||
return points, face_choices
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
|
||||
from . import util
|
||||
from . import renderutils as ru
|
||||
|
||||
import sys
|
||||
|
||||
######################################################################################
|
||||
# Utility functions
|
||||
######################################################################################
|
||||
|
||||
class cubemap_mip(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, cubemap):
|
||||
return util.avg_pool_nhwc(cubemap, (2,2))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
res = dout.shape[1] * 2
|
||||
out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
|
||||
for s in range(6):
|
||||
gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
||||
torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
||||
indexing='ij')
|
||||
v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
|
||||
out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
|
||||
return out
|
||||
|
||||
######################################################################################
|
||||
# Split-sum environment map light source with automatic mipmap generation
|
||||
######################################################################################
|
||||
|
||||
class EnvironmentLight(torch.nn.Module):
|
||||
LIGHT_MIN_RES = 16
|
||||
|
||||
MIN_ROUGHNESS = 0.08
|
||||
MAX_ROUGHNESS = 0.5
|
||||
|
||||
def __init__(self, base, trainable=True):
|
||||
super(EnvironmentLight, self).__init__()
|
||||
self.mtx = None
|
||||
self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=trainable)
|
||||
print(f"light trainable or not: {trainable}")
|
||||
if trainable:
|
||||
self.register_parameter('env_base', self.base)
|
||||
|
||||
def xfm(self, mtx):
|
||||
self.mtx = mtx
|
||||
|
||||
def clone(self):
|
||||
return EnvironmentLight(self.base.clone().detach())
|
||||
|
||||
def clamp_(self, min=None, max=None):
|
||||
self.base.clamp_(min, max)
|
||||
|
||||
def get_mip(self, roughness):
|
||||
return torch.where(roughness < self.MAX_ROUGHNESS
|
||||
, (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
|
||||
, (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)
|
||||
|
||||
def build_mips(self, cutoff=0.99):
|
||||
self.specular = [self.base]
|
||||
while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
|
||||
self.specular += [cubemap_mip.apply(self.specular[-1])]
|
||||
|
||||
self.diffuse = ru.diffuse_cubemap(self.specular[-1])
|
||||
|
||||
for idx in range(len(self.specular) - 1):
|
||||
roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
|
||||
self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
|
||||
self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)
|
||||
|
||||
def regularizer(self):
|
||||
white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
|
||||
return torch.mean(torch.abs(self.base - white))
|
||||
|
||||
def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True, xfm_lgt=None):
|
||||
wo = util.safe_normalize(view_pos - gb_pos)
|
||||
|
||||
if specular:
|
||||
roughness = ks[..., 1:2] # y component
|
||||
metallic = ks[..., 2:3] # z component
|
||||
spec_col = (1.0 - metallic)*0.04 + kd * metallic
|
||||
diff_col = kd * (1.0 - metallic)
|
||||
else:
|
||||
diff_col = kd
|
||||
|
||||
reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
|
||||
nrmvec = gb_normal
|
||||
if xfm_lgt is not None:
|
||||
# print(self.mtx.size())
|
||||
mtx = torch.as_tensor(xfm_lgt, dtype=torch.float32, device='cuda')
|
||||
reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
|
||||
nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
|
||||
elif self.mtx is not None: # Rotate lookup
|
||||
raise NotImplementedError
|
||||
# print(self.mtx.size())
|
||||
mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
|
||||
reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
|
||||
nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
|
||||
|
||||
# if self.mtx is not None: # Rotate lookup
|
||||
# # print(self.mtx.size())
|
||||
# mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
|
||||
# reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
|
||||
# nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
|
||||
|
||||
# Diffuse lookup
|
||||
diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
|
||||
shaded_col = diffuse * diff_col
|
||||
|
||||
if specular:
|
||||
raise NotImplementedError
|
||||
# Lookup FG term from lookup texture
|
||||
NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
|
||||
fg_uv = torch.cat((NdotV, roughness), dim=-1)
|
||||
if not hasattr(self, '_FG_LUT'):
|
||||
self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
|
||||
fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')
|
||||
|
||||
# Roughness adjusted specular env lookup
|
||||
miplevel = self.get_mip(roughness)
|
||||
spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')
|
||||
|
||||
# Compute aggregate lighting
|
||||
reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
|
||||
shaded_col += spec * reflectance
|
||||
|
||||
assert ks[..., 0:1].sum().item() == 0
|
||||
return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility
|
||||
|
||||
######################################################################################
|
||||
# Load and store
|
||||
######################################################################################
|
||||
|
||||
# Load from latlong .HDR file
|
||||
def _load_env_hdr(fn, scale=1.0, trainable=True):
|
||||
print("load env inner loop")
|
||||
sys.stdout.flush()
|
||||
latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
|
||||
print("get cubemap")
|
||||
sys.stdout.flush()
|
||||
cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
|
||||
|
||||
print("get light object")
|
||||
sys.stdout.flush()
|
||||
l = EnvironmentLight(cubemap, trainable=trainable)
|
||||
print("build mips")
|
||||
sys.stdout.flush()
|
||||
l.build_mips()
|
||||
print("build mips done")
|
||||
sys.stdout.flush()
|
||||
|
||||
return l
|
||||
|
||||
def load_env(fn, scale=1.0, trainable=True):
|
||||
if os.path.splitext(fn)[1].lower() == ".hdr":
|
||||
return _load_env_hdr(fn, scale, trainable=trainable)
|
||||
else:
|
||||
assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1]
|
||||
|
||||
def save_env_map(fn, light):
|
||||
assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently"
|
||||
if isinstance(light, EnvironmentLight):
|
||||
color = util.cubemap_to_latlong(light.base, [512, 1024])
|
||||
util.save_image_raw(fn, color.detach().cpu().numpy())
|
||||
|
||||
######################################################################################
|
||||
# Create trainable env map with random initialization
|
||||
######################################################################################
|
||||
|
||||
def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
|
||||
base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias
|
||||
return EnvironmentLight(base)
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import util
|
||||
from . import texture
|
||||
|
||||
######################################################################################
|
||||
# Wrapper to make materials behave like a python dict, but register textures as
|
||||
# torch.nn.Module parameters.
|
||||
######################################################################################
|
||||
class Material(torch.nn.Module):
|
||||
def __init__(self, mat_dict):
|
||||
super(Material, self).__init__()
|
||||
self.mat_keys = set()
|
||||
for key in mat_dict.keys():
|
||||
self.mat_keys.add(key)
|
||||
self[key] = mat_dict[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return hasattr(self, key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key, val):
|
||||
self.mat_keys.add(key)
|
||||
setattr(self, key, val)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self.mat_keys.remove(key)
|
||||
delattr(self, key)
|
||||
|
||||
def keys(self):
|
||||
return self.mat_keys
|
||||
|
||||
######################################################################################
|
||||
# .mtl material format loading / storing
|
||||
######################################################################################
|
||||
@torch.no_grad()
|
||||
def load_mtl(fn, clear_ks=True, avoid_pure_black=False):
|
||||
import re
|
||||
mtl_path = os.path.dirname(fn)
|
||||
|
||||
# Read file
|
||||
with open(fn, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Parse materials
|
||||
materials = []
|
||||
for line in lines:
|
||||
split_line = re.split(' +|\t+|\n+', line.strip())
|
||||
prefix = split_line[0].lower()
|
||||
data = split_line[1:]
|
||||
if 'newmtl' in prefix:
|
||||
material = Material({'name' : data[0]})
|
||||
materials += [material]
|
||||
elif materials:
|
||||
if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
|
||||
material[prefix] = data[0]
|
||||
else:
|
||||
if 'kd' in prefix and avoid_pure_black:
|
||||
tmp_kd = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
|
||||
if tmp_kd.sum() == 0.0:
|
||||
tmp_kd[0] = 1.0
|
||||
tmp_kd[1] = 0.75
|
||||
material[prefix] = tmp_kd
|
||||
else:
|
||||
material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
|
||||
|
||||
# Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
|
||||
for mat in materials:
|
||||
if not 'bsdf' in mat:
|
||||
mat['bsdf'] = 'pbr'
|
||||
|
||||
if 'map_kd' in mat:
|
||||
mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
|
||||
# mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']), channels=3)
|
||||
else:
|
||||
mat['kd'] = texture.Texture2D(mat['kd'])
|
||||
|
||||
if 'map_ks' in mat:
|
||||
mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
|
||||
else:
|
||||
mat['ks'] = texture.Texture2D(mat['ks'])
|
||||
|
||||
if 'bump' in mat:
|
||||
mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
|
||||
|
||||
# Convert Kd from sRGB to linear RGB
|
||||
mat['kd'] = texture.srgb_to_rgb(mat['kd'])
|
||||
|
||||
if clear_ks:
|
||||
# Override ORM occlusion (red) channel by zeros. We hijack this channel
|
||||
for mip in mat['ks'].getMips():
|
||||
mip[..., 0] = 0.0
|
||||
|
||||
return materials
|
||||
|
||||
@torch.no_grad()
|
||||
def save_mtl(fn, material):
|
||||
folder = os.path.dirname(fn)
|
||||
with open(fn, "w") as f:
|
||||
f.write('newmtl defaultMat\n')
|
||||
if material is not None:
|
||||
f.write('bsdf %s\n' % material['bsdf'])
|
||||
if 'kd' in material.keys():
|
||||
f.write('map_kd texture_kd.png\n')
|
||||
texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))
|
||||
if 'ks' in material.keys():
|
||||
f.write('map_ks texture_ks.png\n')
|
||||
texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])
|
||||
if 'normal' in material.keys():
|
||||
f.write('bump texture_n.png\n')
|
||||
texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
|
||||
else:
|
||||
f.write('Kd 1 1 1\n')
|
||||
f.write('Ks 0 0 0\n')
|
||||
f.write('Ka 0 0 0\n')
|
||||
f.write('Tf 1 1 1\n')
|
||||
f.write('Ni 1\n')
|
||||
f.write('Ns 0\n')
|
||||
|
||||
######################################################################################
|
||||
# Merge multiple materials into a single uber-material
|
||||
######################################################################################
|
||||
|
||||
def _upscale_replicate(x, full_res):
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
|
||||
return x.permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
def merge_materials(materials, texcoords, tfaces, mfaces):
|
||||
assert len(materials) > 0
|
||||
for mat in materials:
|
||||
assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
|
||||
assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"
|
||||
|
||||
uber_material = Material({
|
||||
'name' : 'uber_material',
|
||||
'bsdf' : materials[0]['bsdf'],
|
||||
})
|
||||
|
||||
textures = ['kd', 'ks', 'normal']
|
||||
|
||||
# Find maximum texture resolution across all materials and textures
|
||||
max_res = None
|
||||
for mat in materials:
|
||||
for tex in textures:
|
||||
tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
|
||||
max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res
|
||||
|
||||
# Compute size of compund texture and round up to nearest PoT
|
||||
full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int)
|
||||
|
||||
# Normalize texture resolution across all materials & combine into a single large texture
|
||||
for tex in textures:
|
||||
if tex in materials[0]:
|
||||
tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
|
||||
tex_data = _upscale_replicate(tex_data, full_res)
|
||||
uber_material[tex] = texture.Texture2D(tex_data)
|
||||
|
||||
# Compute scaling values for used / unused texture area
|
||||
s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]
|
||||
|
||||
# Recompute texture coordinates to cooincide with new composite texture
|
||||
new_tverts = {}
|
||||
new_tverts_data = []
|
||||
for fi in range(len(tfaces)):
|
||||
matIdx = mfaces[fi]
|
||||
for vi in range(3):
|
||||
ti = tfaces[fi][vi]
|
||||
if not (ti in new_tverts):
|
||||
new_tverts[ti] = {}
|
||||
if not (matIdx in new_tverts[ti]): # create new vertex
|
||||
if len(texcoords) == 0:
|
||||
# continue
|
||||
new_tverts_data.append([(matIdx) / s_coeff[1], 0]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
|
||||
new_tverts[ti][matIdx] = len(new_tverts_data) - 1
|
||||
else:
|
||||
new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
|
||||
new_tverts[ti][matIdx] = len(new_tverts_data) - 1
|
||||
|
||||
# if not (matIdx in new_tverts[ti]): # create new vertex
|
||||
# new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
|
||||
# new_tverts[ti][matIdx] = len(new_tverts_data) - 1
|
||||
tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex
|
||||
|
||||
return uber_material, new_tverts_data, tfaces
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import obj
|
||||
from . import util
|
||||
|
||||
######################################################################################
|
||||
# Base mesh class
|
||||
######################################################################################
|
||||
class Mesh:
|
||||
def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None,
|
||||
material=None, base=None, f_nrm=None):
|
||||
self.v_pos = v_pos
|
||||
self.v_nrm = v_nrm
|
||||
self.v_tex = v_tex
|
||||
self.v_tng = v_tng
|
||||
self.t_pos_idx = t_pos_idx
|
||||
self.t_nrm_idx = t_nrm_idx
|
||||
self.t_tex_idx = t_tex_idx
|
||||
self.t_tng_idx = t_tng_idx
|
||||
self.material = material
|
||||
# self.f_nrm = f_nrm
|
||||
|
||||
if base is not None:
|
||||
self.copy_none(base)
|
||||
|
||||
try:
|
||||
i0 = self.t_pos_idx[:, 0]
|
||||
i1 = self.t_pos_idx[:, 1]
|
||||
i2 = self.t_pos_idx[:, 2]
|
||||
|
||||
v0 = self.v_pos[i0, :]
|
||||
v1 = self.v_pos[i1, :]
|
||||
v2 = self.v_pos[i2, :]
|
||||
|
||||
self.f_nrm = face_normals = torch.cross(v1 - v0, v2 - v0)
|
||||
except:
|
||||
self.f_nrm = f_nrm
|
||||
|
||||
def copy_none(self, other):
|
||||
if self.v_pos is None:
|
||||
self.v_pos = other.v_pos
|
||||
if self.t_pos_idx is None:
|
||||
self.t_pos_idx = other.t_pos_idx
|
||||
if self.v_nrm is None:
|
||||
self.v_nrm = other.v_nrm
|
||||
if self.t_nrm_idx is None:
|
||||
self.t_nrm_idx = other.t_nrm_idx
|
||||
if self.v_tex is None:
|
||||
self.v_tex = other.v_tex
|
||||
if self.t_tex_idx is None:
|
||||
self.t_tex_idx = other.t_tex_idx
|
||||
if self.v_tng is None:
|
||||
self.v_tng = other.v_tng
|
||||
if self.t_tng_idx is None:
|
||||
self.t_tng_idx = other.t_tng_idx
|
||||
if self.material is None:
|
||||
self.material = other.material
|
||||
# if self.f_nrm is None:
|
||||
# self.f_nrm = other.f_nrm
|
||||
|
||||
|
||||
def clone(self):
|
||||
out = Mesh(base=self)
|
||||
if out.v_pos is not None:
|
||||
out.v_pos = out.v_pos.clone().detach()
|
||||
if out.t_pos_idx is not None:
|
||||
out.t_pos_idx = out.t_pos_idx.clone().detach()
|
||||
if out.v_nrm is not None:
|
||||
out.v_nrm = out.v_nrm.clone().detach()
|
||||
if out.t_nrm_idx is not None:
|
||||
out.t_nrm_idx = out.t_nrm_idx.clone().detach()
|
||||
if out.v_tex is not None:
|
||||
out.v_tex = out.v_tex.clone().detach()
|
||||
if out.t_tex_idx is not None:
|
||||
out.t_tex_idx = out.t_tex_idx.clone().detach()
|
||||
if out.v_tng is not None:
|
||||
out.v_tng = out.v_tng.clone().detach()
|
||||
if out.t_tng_idx is not None:
|
||||
out.t_tng_idx = out.t_tng_idx.clone().detach()
|
||||
if out.f_nrm is not None:
|
||||
out.f_nrm = out.f_nrm.clone().detach()
|
||||
return out
|
||||
|
||||
######################################################################################
|
||||
# Mesh loeading helper
|
||||
######################################################################################
|
||||
|
||||
def load_mesh(filename, mtl_override=None, mtl_default=None, use_default=False, no_additional=False):
|
||||
name, ext = os.path.splitext(filename)
|
||||
if ext == ".obj":
|
||||
return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override, mtl_default=mtl_default, use_default=use_default, no_additional=no_additional)
|
||||
assert False, "Invalid mesh file extension"
|
||||
|
||||
######################################################################################
|
||||
# Compute AABB
|
||||
######################################################################################
|
||||
def aabb(mesh):
|
||||
return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values
|
||||
|
||||
######################################################################################
|
||||
# Compute AABB with only used vertices
|
||||
######################################################################################
|
||||
def aabb_clean(mesh):
|
||||
v_pos_clean = mesh.v_pos[mesh.t_pos_idx.unique()]
|
||||
return torch.min(v_pos_clean, dim=0).values, torch.max(v_pos_clean, dim=0).values
|
||||
|
||||
######################################################################################
|
||||
# Compute unique edge list from attribute/vertex index list
|
||||
######################################################################################
|
||||
def compute_edges(attr_idx, return_inverse=False):
|
||||
with torch.no_grad():
|
||||
# Create all edges, packed by triangle
|
||||
all_edges = torch.cat((
|
||||
torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
|
||||
torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
|
||||
torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
|
||||
), dim=-1).view(-1, 2)
|
||||
|
||||
# Swap edge order so min index is always first
|
||||
order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
|
||||
sorted_edges = torch.cat((
|
||||
torch.gather(all_edges, 1, order),
|
||||
torch.gather(all_edges, 1, 1 - order)
|
||||
), dim=-1)
|
||||
|
||||
# Eliminate duplicates and return inverse mapping
|
||||
return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse)
|
||||
|
||||
######################################################################################
|
||||
# Compute unique edge to face mapping from attribute/vertex index list
|
||||
######################################################################################
|
||||
def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
|
||||
with torch.no_grad():
|
||||
# Get unique edges
|
||||
# Create all edges, packed by triangle
|
||||
all_edges = torch.cat((
|
||||
torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
|
||||
torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
|
||||
torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
|
||||
), dim=-1).view(-1, 2)
|
||||
|
||||
# Swap edge order so min index is always first
|
||||
order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
|
||||
sorted_edges = torch.cat((
|
||||
torch.gather(all_edges, 1, order),
|
||||
torch.gather(all_edges, 1, 1 - order)
|
||||
), dim=-1)
|
||||
|
||||
# Elliminate duplicates and return inverse mapping
|
||||
unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
|
||||
|
||||
tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()
|
||||
|
||||
tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
|
||||
|
||||
# Compute edge to face table
|
||||
mask0 = order[:,0] == 0
|
||||
mask1 = order[:,0] == 1
|
||||
tris_per_edge[idx_map[mask0], 0] = tris[mask0]
|
||||
tris_per_edge[idx_map[mask1], 1] = tris[mask1]
|
||||
|
||||
return tris_per_edge
|
||||
|
||||
######################################################################################
|
||||
# Align base mesh to reference mesh:move & rescale to match bounding boxes.
|
||||
######################################################################################
|
||||
def unit_size(mesh):
|
||||
with torch.no_grad():
|
||||
vmin, vmax = aabb(mesh)
|
||||
scale = 2 / torch.max(vmax - vmin).item()
|
||||
v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin
|
||||
v_pos = v_pos * scale # Rescale to unit size
|
||||
|
||||
return Mesh(v_pos, base=mesh)
|
||||
|
||||
######################################################################################
|
||||
# Center & scale mesh for rendering
|
||||
######################################################################################
|
||||
def center_by_reference(base_mesh, ref_aabb, scale):
|
||||
center = (ref_aabb[0] + ref_aabb[1]) * 0.5
|
||||
scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()
|
||||
print('normalization:', center, scale)
|
||||
v_pos = (base_mesh.v_pos - center[None, ...]) * scale
|
||||
return Mesh(v_pos, base=base_mesh)
|
||||
|
||||
######################################################################################
|
||||
# Simple smooth vertex normal computation
|
||||
######################################################################################
|
||||
def auto_normals(imesh):
|
||||
|
||||
i0 = imesh.t_pos_idx[:, 0]
|
||||
i1 = imesh.t_pos_idx[:, 1]
|
||||
i2 = imesh.t_pos_idx[:, 2]
|
||||
|
||||
v0 = imesh.v_pos[i0, :]
|
||||
v1 = imesh.v_pos[i1, :]
|
||||
v2 = imesh.v_pos[i2, :]
|
||||
|
||||
f_nrm = face_normals = torch.cross(v1 - v0, v2 - v0)
|
||||
|
||||
# Splat face normals to vertices
|
||||
v_nrm = torch.zeros_like(imesh.v_pos)
|
||||
v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
|
||||
v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
|
||||
v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
|
||||
|
||||
# Normalize, replace zero (degenerated) normals with some default value
|
||||
v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
|
||||
v_nrm = util.safe_normalize(v_nrm)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(v_nrm))
|
||||
|
||||
return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh, f_nrm=f_nrm)
|
||||
|
||||
######################################################################################
|
||||
# Compute tangent space from texture map coordinates
|
||||
# Follows http://www.mikktspace.com/ conventions
|
||||
######################################################################################
|
||||
def compute_tangents(imesh):
|
||||
vn_idx = [None] * 3
|
||||
pos = [None] * 3
|
||||
tex = [None] * 3
|
||||
for i in range(0,3):
|
||||
pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]]
|
||||
tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]]
|
||||
vn_idx[i] = imesh.t_nrm_idx[:, i]
|
||||
|
||||
tangents = torch.zeros_like(imesh.v_nrm)
|
||||
tansum = torch.zeros_like(imesh.v_nrm)
|
||||
|
||||
# Compute tangent space for each triangle
|
||||
uve1 = tex[1] - tex[0]
|
||||
uve2 = tex[2] - tex[0]
|
||||
pe1 = pos[1] - pos[0]
|
||||
pe2 = pos[2] - pos[0]
|
||||
|
||||
nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])
|
||||
denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])
|
||||
assert not torch.isnan(uve1).any()
|
||||
assert not torch.isnan(uve2).any()
|
||||
assert not torch.isnan(pe1).any()
|
||||
assert not torch.isnan(pe2).any()
|
||||
|
||||
# Avoid division by zero for degenerated texture coordinates
|
||||
tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) #### ZL: something wrong in this line, not sure why
|
||||
assert (torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) != 0.0).all()
|
||||
assert not torch.isnan(nom).any()
|
||||
assert not torch.isnan(tang).any()
|
||||
|
||||
# Update all 3 vertices
|
||||
for i in range(0,3):
|
||||
idx = vn_idx[i][:, None].repeat(1,3)
|
||||
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
||||
tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1
|
||||
tangents = tangents / tansum
|
||||
assert not torch.isnan(tangents).any()
|
||||
|
||||
# Normalize and make sure tangent is perpendicular to normal
|
||||
tangents = util.safe_normalize(tangents)
|
||||
tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(tangents))
|
||||
|
||||
return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
import tinycudann as tcnn
|
||||
import numpy as np
|
||||
|
||||
#######################################################################################################################################################
|
||||
# Small MLP using PyTorch primitives, internal helper class
|
||||
#######################################################################################################################################################
|
||||
|
||||
class _MLP(torch.nn.Module):
|
||||
def __init__(self, cfg, loss_scale=1.0):
|
||||
super(_MLP, self).__init__()
|
||||
self.loss_scale = loss_scale
|
||||
net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
|
||||
for i in range(cfg['n_hidden_layers']-1):
|
||||
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
|
||||
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
|
||||
self.net = torch.nn.Sequential(*net).cuda()
|
||||
|
||||
self.net.apply(self._init_weights)
|
||||
|
||||
if self.loss_scale != 1.0:
|
||||
self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x.to(torch.float32))
|
||||
|
||||
@staticmethod
|
||||
def _init_weights(m):
|
||||
if type(m) == torch.nn.Linear:
|
||||
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
|
||||
if hasattr(m.bias, 'data'):
|
||||
m.bias.data.fill_(0.0)
|
||||
|
||||
#######################################################################################################################################################
|
||||
# Outward visible MLP class
|
||||
#######################################################################################################################################################
|
||||
|
||||
class MLPTexture3D(torch.nn.Module):
|
||||
def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None):
|
||||
super(MLPTexture3D, self).__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.internal_dims = internal_dims
|
||||
self.AABB = AABB
|
||||
self.min_max = min_max
|
||||
|
||||
# Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details
|
||||
desired_resolution = 4096
|
||||
base_grid_resolution = 16
|
||||
num_levels = 16
|
||||
per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))
|
||||
|
||||
enc_cfg = {
|
||||
"otype": "HashGrid",
|
||||
"n_levels": num_levels,
|
||||
"n_features_per_level": 2,
|
||||
"log2_hashmap_size": 19,
|
||||
"base_resolution": base_grid_resolution,
|
||||
"per_level_scale" : per_level_scale
|
||||
}
|
||||
|
||||
gradient_scaling = 128.0
|
||||
self.encoder = tcnn.Encoding(3, enc_cfg)
|
||||
self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))
|
||||
|
||||
# Setup MLP
|
||||
mlp_cfg = {
|
||||
"n_input_dims" : self.encoder.n_output_dims,
|
||||
"n_output_dims" : self.channels,
|
||||
"n_hidden_layers" : hidden,
|
||||
"n_neurons" : self.internal_dims
|
||||
}
|
||||
self.net = _MLP(mlp_cfg, gradient_scaling)
|
||||
print("Encoder output: %d dims" % (self.encoder.n_output_dims))
|
||||
|
||||
# Sample texture at a given location
|
||||
def sample(self, texc):
|
||||
_texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])
|
||||
_texc = torch.clamp(_texc, min=0, max=1)
|
||||
|
||||
p_enc = self.encoder(_texc.contiguous())
|
||||
out = self.net.forward(p_enc)
|
||||
|
||||
# Sigmoid limit and scale to the allowed range
|
||||
if self.min_max is not None:
|
||||
out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
||||
|
||||
return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]
|
||||
|
||||
# In-place clamp with no derivative to make sure values are in valid range after training
|
||||
def clamp_(self):
|
||||
pass
|
||||
|
||||
def cleanup(self):
|
||||
tcnn.free_temporary_memory()
|
|
@ -0,0 +1,216 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
from . import texture
|
||||
from . import mesh
|
||||
from . import material
|
||||
|
||||
######################################################################################
|
||||
# Utility functions
|
||||
######################################################################################
|
||||
|
||||
def _find_mat(materials, name):
|
||||
for mat in materials:
|
||||
if mat['name'] == name:
|
||||
return mat
|
||||
return materials[0] # Materials 0 is the default
|
||||
|
||||
######################################################################################
|
||||
# Create mesh object from objfile
|
||||
######################################################################################
|
||||
|
||||
def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=None, use_default=False, no_additional=False):
|
||||
obj_path = os.path.dirname(filename)
|
||||
|
||||
# Read entire file
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Load materials
|
||||
if mtl_default is None:
|
||||
all_materials = [
|
||||
{
|
||||
'name' : '_default_mat',
|
||||
'bsdf' : 'pbr',
|
||||
'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
|
||||
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
|
||||
}
|
||||
]
|
||||
else:
|
||||
print("Load use-defined default mtl")
|
||||
all_materials = [mtl_default]
|
||||
|
||||
if not no_additional:
|
||||
if mtl_override is None:
|
||||
for line in lines:
|
||||
if len(line.split()) == 0:
|
||||
continue
|
||||
if line.split()[0] == 'mtllib':
|
||||
all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks, avoid_pure_black=True) # Read in entire material library
|
||||
else:
|
||||
all_materials += material.load_mtl(mtl_override)
|
||||
else:
|
||||
print("Skip loading non-default materials")
|
||||
|
||||
|
||||
# load vertices
|
||||
vertices, texcoords, normals = [], [], []
|
||||
for line in lines:
|
||||
if len(line.split()) == 0:
|
||||
continue
|
||||
|
||||
prefix = line.split()[0].lower()
|
||||
if prefix == 'v':
|
||||
vertices.append([float(v) for v in line.split()[1:]])
|
||||
elif prefix == 'vt':
|
||||
val = [float(v) for v in line.split()[1:]]
|
||||
texcoords.append([val[0], 1.0 - val[1]])
|
||||
elif prefix == 'vn':
|
||||
normals.append([float(v) for v in line.split()[1:]])
|
||||
|
||||
print(all_materials)
|
||||
|
||||
# load faces
|
||||
activeMatIdx = None
|
||||
used_materials = []
|
||||
faces, tfaces, nfaces, mfaces = [], [], [], []
|
||||
for line in lines:
|
||||
if len(line.split()) == 0:
|
||||
continue
|
||||
|
||||
prefix = line.split()[0].lower()
|
||||
if prefix == 'usemtl': # Track used materials
|
||||
mat = _find_mat(all_materials, line.split()[1])
|
||||
if not mat in used_materials:
|
||||
used_materials.append(mat)
|
||||
activeMatIdx = used_materials.index(mat)
|
||||
elif prefix == 'f': # Parse face
|
||||
vs = line.split()[1:]
|
||||
nv = len(vs)
|
||||
vv = vs[0].split('/')
|
||||
v0 = int(vv[0]) - 1
|
||||
# t1 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
# n1 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
try:
|
||||
t0 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
n0 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
except:
|
||||
t0 = n0 = -1
|
||||
for i in range(nv - 2): # Triangulate polygons
|
||||
vv = vs[i + 1].split('/')
|
||||
v1 = int(vv[0]) - 1
|
||||
# t1 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
# n1 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
try:
|
||||
t1 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
n1 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
except:
|
||||
t1 = n1 = -1
|
||||
vv = vs[i + 2].split('/')
|
||||
v2 = int(vv[0]) - 1
|
||||
# t2 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
# n2 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
try:
|
||||
t2 = int(vv[1]) - 1 if vv[1] != "" else -1
|
||||
n2 = int(vv[2]) - 1 if vv[2] != "" else -1
|
||||
except:
|
||||
t2 = n2 = -1
|
||||
mfaces.append(activeMatIdx)
|
||||
faces.append([v0, v1, v2])
|
||||
tfaces.append([t0, t1, t2])
|
||||
nfaces.append([n0, n1, n2])
|
||||
assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
|
||||
|
||||
# # Create an "uber" material by combining all textures into a larger texture
|
||||
# # if len(used_materials) > 1:
|
||||
# if True:
|
||||
# uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
|
||||
# elif len(used_materials) == 1:
|
||||
# uber_material = used_materials[0]
|
||||
# else:
|
||||
# uber_material = None
|
||||
|
||||
vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
|
||||
# texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
|
||||
# normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
|
||||
# # normals = None
|
||||
|
||||
faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
|
||||
# tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
|
||||
# nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
|
||||
|
||||
# print(uber_material)
|
||||
|
||||
uber_material = all_materials[0]
|
||||
texcoords = normals = tfaces = nfaces = None
|
||||
# return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
|
||||
|
||||
imesh = mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
|
||||
imesh = mesh.auto_normals(imesh)
|
||||
return imesh
|
||||
|
||||
######################################################################################
|
||||
# Save mesh object to objfile
|
||||
######################################################################################
|
||||
|
||||
def write_obj(folder, mesh, save_material=True):
|
||||
obj_file = os.path.join(folder, 'mesh.obj')
|
||||
print("Writing mesh: ", obj_file)
|
||||
with open(obj_file, "w") as f:
|
||||
# f.write("mtllib mesh.mtl\n")
|
||||
f.write("g default\n")
|
||||
|
||||
v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
|
||||
# v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
|
||||
# v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None
|
||||
v_nrm = None
|
||||
v_tex = None
|
||||
|
||||
t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
|
||||
# t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
|
||||
# t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None
|
||||
|
||||
print(" writing %d vertices" % len(v_pos))
|
||||
for v in v_pos:
|
||||
f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
|
||||
|
||||
# if v_tex is not None:
|
||||
# print(" writing %d texcoords" % len(v_tex))
|
||||
# assert(len(t_pos_idx) == len(t_tex_idx))
|
||||
# for v in v_tex:
|
||||
# f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
|
||||
|
||||
# if v_nrm is not None:
|
||||
# print(" writing %d normals" % len(v_nrm))
|
||||
# assert(len(t_pos_idx) == len(t_nrm_idx))
|
||||
# for v in v_nrm:
|
||||
# f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
|
||||
|
||||
# faces
|
||||
f.write("s 1 \n")
|
||||
f.write("g pMesh1\n")
|
||||
f.write("usemtl defaultMat\n")
|
||||
|
||||
# Write faces
|
||||
print(" writing %d faces" % len(t_pos_idx))
|
||||
for i in range(len(t_pos_idx)):
|
||||
f.write("f ")
|
||||
for j in range(3):
|
||||
f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
|
||||
f.write("\n")
|
||||
|
||||
if save_material:
|
||||
mtl_file = os.path.join(folder, 'mesh.mtl')
|
||||
print("Writing material: ", mtl_file)
|
||||
material.save_mtl(mtl_file, mesh.material)
|
||||
|
||||
print("Done exporting mesh")
|
|
@ -0,0 +1,205 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import pytorch3d.ops
|
||||
|
||||
from . import util
|
||||
from . import mesh
|
||||
|
||||
######################################################################################
|
||||
# Computes the image gradient, useful for kd/ks smoothness losses
|
||||
######################################################################################
|
||||
def image_grad(buf, std=0.01):
|
||||
t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
|
||||
torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
|
||||
indexing='ij')
|
||||
tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
|
||||
tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
|
||||
return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]
|
||||
|
||||
######################################################################################
|
||||
# Computes the avergage edge length of a mesh.
|
||||
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
|
||||
######################################################################################
|
||||
def avg_edge_length(v_pos, t_pos_idx):
|
||||
e_pos_idx = mesh.compute_edges(t_pos_idx)
|
||||
edge_len = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]])
|
||||
return torch.mean(edge_len)
|
||||
|
||||
######################################################################################
|
||||
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
|
||||
# https://mgarland.org/class/geom04/material/smoothing.pdf
|
||||
######################################################################################
|
||||
def laplace_regularizer_const(v_pos, t_pos_idx):
|
||||
term = torch.zeros_like(v_pos)
|
||||
norm = torch.zeros_like(v_pos[..., 0:1])
|
||||
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
|
||||
term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
|
||||
term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
|
||||
|
||||
two = torch.ones_like(v0) * 2.0
|
||||
norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
|
||||
norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
|
||||
norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
|
||||
|
||||
term = term / torch.clamp(norm, min=1.0)
|
||||
|
||||
return torch.mean(term**2)
|
||||
|
||||
|
||||
def scale_dependent_relative_laplace_regularizer_const(v_pos, v_pos_abs, t_pos_idx):
|
||||
term = torch.zeros_like(v_pos)
|
||||
norm = torch.zeros_like(v_pos[..., 0:1])
|
||||
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
v0_abs = v_pos_abs[t_pos_idx[:, 0], :]
|
||||
v1_abs = v_pos_abs[t_pos_idx[:, 1], :]
|
||||
v2_abs = v_pos_abs[t_pos_idx[:, 2], :]
|
||||
|
||||
eps = 1e-8
|
||||
deformable_dist = False
|
||||
if deformable_dist:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
## The original distance; does not account for the
|
||||
v01_dist = ((v0_abs - v1_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v12_dist = ((v1_abs - v2_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v20_dist = ((v2_abs - v0_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
|
||||
term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)
|
||||
|
||||
return torch.mean(term**2)
|
||||
|
||||
|
||||
def scale_dependent_laplace_regularizer_const(v_pos, t_pos_idx):
|
||||
term = torch.zeros_like(v_pos)
|
||||
norm = torch.zeros_like(v_pos[..., 0:1])
|
||||
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
eps = 1e-8
|
||||
v01_dist = ((v0 - v1).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v12_dist = ((v1 - v2).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v20_dist = ((v2 - v0).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
|
||||
stopgd = True
|
||||
if stopgd:
|
||||
v01_dist = v01_dist.detach()
|
||||
v12_dist = v12_dist.detach()
|
||||
v20_dist = v20_dist.detach()
|
||||
|
||||
term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)
|
||||
|
||||
return torch.mean(term**2)
|
||||
|
||||
|
||||
def mesh_repulsion(v_pos, t_pos_idx):
|
||||
term = torch.zeros_like(v_pos)
|
||||
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
|
||||
eps = 1e-8
|
||||
v01_dist = ((v0 - v1).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v12_dist = ((v1 - v2).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
v20_dist = ((v2 - v0).pow(2).sum(-1, keepdim=True) + eps).sqrt()
|
||||
|
||||
term.scatter_add_(0, t_pos_idx[:, 0:1], v01_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 1:2], v12_dist)
|
||||
term.scatter_add_(0, t_pos_idx[:, 2:3], v20_dist)
|
||||
|
||||
return term**2
|
||||
|
||||
def laplace_regularizer_const_adaptive(v_pos, t_pos_idx):
|
||||
term = torch.zeros_like(v_pos)
|
||||
norm = torch.zeros_like(v_pos[..., 0:1])
|
||||
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
|
||||
term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
|
||||
term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
|
||||
|
||||
two = torch.ones_like(v0) * 2.0
|
||||
norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
|
||||
norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
|
||||
norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
|
||||
|
||||
term = term / torch.clamp(norm, min=1.0)
|
||||
|
||||
v_pos = v_pos.unsqueeze(0) * 64
|
||||
with torch.no_grad():
|
||||
scale = (pytorch3d.ops.knn.knn_points(v_pos, v_pos, K=2).dists[0, :, -1].detach()).sqrt().pow(1.5) ## K=2 because dist(self, self)=0
|
||||
dist = term.pow(2).mean(-1) ### since the vanilla one uses mean
|
||||
|
||||
return torch.mean(dist * scale)
|
||||
|
||||
# def laplace_regularizer_const_sec_order(v_pos, t_pos_idx):
|
||||
# term = torch.zeros_like(v_pos)
|
||||
# norm = torch.zeros_like(v_pos[..., 0:1])
|
||||
|
||||
# v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
# v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
# v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
# term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
|
||||
# term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
|
||||
# term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
|
||||
|
||||
# two = torch.ones_like(v0) * 2.0
|
||||
# norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
|
||||
# norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
|
||||
# norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
|
||||
|
||||
# term = term / torch.clamp(norm, min=1.0)
|
||||
|
||||
# return torch.mean(term**2)
|
||||
|
||||
######################################################################################
|
||||
# Smooth vertex normals
|
||||
######################################################################################
|
||||
def normal_consistency(v_pos, t_pos_idx):
|
||||
# Compute face normals
|
||||
v0 = v_pos[t_pos_idx[:, 0], :]
|
||||
v1 = v_pos[t_pos_idx[:, 1], :]
|
||||
v2 = v_pos[t_pos_idx[:, 2], :]
|
||||
|
||||
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
||||
|
||||
tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)
|
||||
|
||||
# Fetch normals for both faces sharind an edge
|
||||
n0 = face_normals[tris_per_edge[:, 0], :]
|
||||
n1 = face_normals[tris_per_edge[:, 1], :]
|
||||
|
||||
# Compute error metric based on normal difference
|
||||
term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
|
||||
term = (1.0 - term) * 0.5
|
||||
|
||||
return torch.mean(torch.abs(term))
|
|
@ -0,0 +1,454 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
|
||||
from . import util
|
||||
from . import renderutils as ru
|
||||
from . import light
|
||||
|
||||
# ==============================================================================================
|
||||
# Helper functions
|
||||
# ==============================================================================================
|
||||
def interpolate(attr, rast, attr_idx, rast_db=None):
|
||||
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
|
||||
|
||||
# ==============================================================================================
|
||||
# pixel shader
|
||||
# ==============================================================================================
|
||||
def shade(
|
||||
gb_pos,
|
||||
gb_geometric_normal,
|
||||
gb_normal,
|
||||
gb_tangent,
|
||||
gb_texc,
|
||||
gb_texc_deriv,
|
||||
view_pos,
|
||||
lgt,
|
||||
material,
|
||||
bsdf,
|
||||
xfm_lgt=None
|
||||
):
|
||||
|
||||
################################################################################
|
||||
# Texture lookups
|
||||
################################################################################
|
||||
perturbed_nrm = None
|
||||
alpha_mtl = None
|
||||
if 'kd_ks_normal' in material:
|
||||
# Combined texture, used for MLPs because lookups are expensive
|
||||
all_tex_jitter = material['kd_ks_normal'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda"))
|
||||
all_tex = material['kd_ks_normal'].sample(gb_pos)
|
||||
assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
|
||||
kd, ks, perturbed_nrm = all_tex[..., :-6], all_tex[..., -6:-3], all_tex[..., -3:]
|
||||
# Compute albedo (kd) gradient, used for material regularizer
|
||||
kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) / 3
|
||||
else:
|
||||
try:
|
||||
kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv)
|
||||
if 'alpha' in material:
|
||||
raise NotImplementedError
|
||||
try:
|
||||
alpha_mtl = material['alpha'].sample(gb_texc, gb_texc_deriv)
|
||||
except:
|
||||
alpha_mtl = material['alpha'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda"))
|
||||
kd = material['kd'].sample(gb_texc, gb_texc_deriv)
|
||||
ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
|
||||
kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
|
||||
except:
|
||||
kd_jitter = kd = material['kd'].data[0].expand(*gb_pos.size())
|
||||
ks = material['ks'].data[0].expand(*gb_pos.size())[..., 0:3] # skip alpha
|
||||
kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
|
||||
|
||||
# Separate kd into alpha and color, default alpha = 1
|
||||
alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
|
||||
if alpha_mtl is not None:
|
||||
alpha = alpha_mtl
|
||||
kd = kd[..., 0:3]
|
||||
|
||||
################################################################################
|
||||
# Normal perturbation & normal bend
|
||||
################################################################################
|
||||
if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']:
|
||||
perturbed_nrm = None
|
||||
|
||||
use_python = (gb_tangent is None)
|
||||
|
||||
gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python)
|
||||
gb_geo_normal_corrected = ru.prepare_shading_normal(gb_pos, view_pos, None, gb_geometric_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python)
|
||||
|
||||
################################################################################
|
||||
# Evaluate BSDF
|
||||
################################################################################
|
||||
|
||||
assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
|
||||
bsdf = material['bsdf'] if bsdf is None else bsdf
|
||||
if bsdf == 'pbr':
|
||||
# do not use pbr
|
||||
raise NotImplementedError
|
||||
if isinstance(lgt, light.EnvironmentLight):
|
||||
shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
|
||||
else:
|
||||
assert False, "Invalid light type"
|
||||
elif bsdf == 'diffuse':
|
||||
if isinstance(lgt, light.EnvironmentLight):
|
||||
shaded_col = lgt.shade(gb_pos, gb_geo_normal_corrected, kd, ks, view_pos, specular=False, xfm_lgt=xfm_lgt)
|
||||
else:
|
||||
assert False, "Invalid light type"
|
||||
elif bsdf == 'normal':
|
||||
shaded_col = (gb_normal + 1.0)*0.5
|
||||
elif bsdf == 'tangent':
|
||||
shaded_col = (gb_tangent + 1.0)*0.5
|
||||
elif bsdf == 'kd':
|
||||
shaded_col = kd
|
||||
elif bsdf == 'ks':
|
||||
shaded_col = ks
|
||||
else:
|
||||
assert False, "Invalid BSDF '%s'" % bsdf
|
||||
|
||||
nan_mask = torch.isnan(shaded_col)
|
||||
if nan_mask.any():
|
||||
raise
|
||||
if alpha is not None:
|
||||
nan_mask = torch.isnan(alpha)
|
||||
if nan_mask.any():
|
||||
raise
|
||||
|
||||
# Return multiple buffers
|
||||
buffers = {
|
||||
'shaded' : torch.cat((shaded_col, alpha), dim=-1),
|
||||
'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
|
||||
'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1),
|
||||
'normal' : torch.cat((gb_normal, alpha), dim=-1),
|
||||
'depth' : torch.cat(((gb_pos - view_pos).pow(2).sum(dim=-1, keepdim=True).sqrt(), alpha), dim=-1),
|
||||
'pos' : torch.cat((gb_pos, alpha), dim=-1),
|
||||
'geo_normal': torch.cat((gb_geo_normal_corrected, alpha), dim=-1),
|
||||
'geo_viewdir': torch.cat((view_pos - gb_pos, alpha), dim=-1),
|
||||
'alpha' : alpha
|
||||
}
|
||||
|
||||
|
||||
return buffers
|
||||
|
||||
# ==============================================================================================
|
||||
# Render a depth slice of the mesh (scene), some limitations:
|
||||
# - Single mesh
|
||||
# - Single light
|
||||
# - Single material
|
||||
# ==============================================================================================
|
||||
def render_layer(
|
||||
rast,
|
||||
rast_deriv,
|
||||
mesh,
|
||||
view_pos,
|
||||
lgt,
|
||||
resolution,
|
||||
spp,
|
||||
msaa,
|
||||
bsdf,
|
||||
xfm_lgt = None,
|
||||
flat_shading = False
|
||||
):
|
||||
|
||||
full_res = [resolution[0]*spp, resolution[1]*spp]
|
||||
|
||||
################################################################################
|
||||
# Rasterize
|
||||
################################################################################
|
||||
|
||||
# Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
|
||||
if spp > 1 and msaa:
|
||||
rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
|
||||
else:
|
||||
rast_out_s = rast
|
||||
|
||||
################################################################################
|
||||
# Interpolate attributes
|
||||
################################################################################
|
||||
|
||||
# Interpolate world space position
|
||||
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
|
||||
|
||||
# Compute geometric normals. We need those because of bent normals trick (for bump mapping)
|
||||
v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
|
||||
v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
|
||||
v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
|
||||
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
||||
face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
|
||||
gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
|
||||
|
||||
if flat_shading:
|
||||
gb_normal = mesh.f_nrm[rast_out_s[:, :, :, -1].long() - 1] # empty triangle get id=0; the first idx starts from 1
|
||||
gb_normal[rast_out_s[:, :, :, -1].long() == 0] = 0
|
||||
else:
|
||||
assert mesh.v_nrm is not None
|
||||
gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
|
||||
|
||||
if mesh.v_tng is not None:
|
||||
gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
|
||||
else:
|
||||
gb_tangent = None
|
||||
|
||||
# Do not use texture coordinate in our case
|
||||
gb_texc, gb_texc_deriv = None, None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Shade
|
||||
################################################################################
|
||||
buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
|
||||
view_pos, lgt, mesh.material, bsdf, xfm_lgt=xfm_lgt)
|
||||
|
||||
#### get a mask on mesh (used to identify foreground)
|
||||
mask_cont, _ = interpolate(torch.ones_like(mesh.v_pos[None, :, :1], device=mesh.v_pos.device), rast_out_s, mesh.t_pos_idx.int())
|
||||
mask = (mask_cont > 0).float()
|
||||
buffers['mask'] = mask
|
||||
buffers['mask_cont'] = mask_cont
|
||||
|
||||
################################################################################
|
||||
# Prepare output
|
||||
################################################################################
|
||||
|
||||
# Scale back up to visibility resolution if using MSAA
|
||||
if spp > 1 and msaa:
|
||||
for key in buffers.keys():
|
||||
if buffers[key] is not None:
|
||||
buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')
|
||||
|
||||
|
||||
# Return buffers
|
||||
return buffers
|
||||
|
||||
# ==============================================================================================
|
||||
# Render a depth peeled mesh (scene), some limitations:
|
||||
# - Single mesh
|
||||
# - Single light
|
||||
# - Single material
|
||||
# ==============================================================================================
|
||||
def render_mesh(
|
||||
ctx,
|
||||
mesh,
|
||||
mtx_in,
|
||||
view_pos,
|
||||
lgt,
|
||||
resolution,
|
||||
spp = 1,
|
||||
num_layers = 1,
|
||||
msaa = False,
|
||||
background = None,
|
||||
bsdf = None,
|
||||
xfm_lgt = None,
|
||||
tet_centers = None,
|
||||
flat_shading = False
|
||||
):
|
||||
|
||||
def prepare_input_vector(x):
|
||||
x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
|
||||
return x[:, None, None, :] if len(x.shape) == 2 else x
|
||||
|
||||
def composite_buffer(key, layers, background, antialias):
|
||||
accum = background
|
||||
for buffers, rast in layers:
|
||||
alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
|
||||
accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
|
||||
if antialias:
|
||||
accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
|
||||
break ## HACK: the first layer only
|
||||
return accum
|
||||
|
||||
def separate_buffer(key, layers, background, antialias):
|
||||
accum_list = []
|
||||
for buffers, rast in layers:
|
||||
accum = background
|
||||
alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
|
||||
accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
|
||||
if antialias:
|
||||
accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
|
||||
accum_list.append(accum)
|
||||
return accum_list
|
||||
|
||||
assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
|
||||
assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])
|
||||
|
||||
full_res = [resolution[0]*spp, resolution[1]*spp]
|
||||
|
||||
# Convert numpy arrays to torch tensors
|
||||
mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
|
||||
view_pos = prepare_input_vector(view_pos)
|
||||
|
||||
# clip space transform
|
||||
v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)
|
||||
|
||||
# Render all layers front-to-back
|
||||
with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:
|
||||
rast, db = peeler.rasterize_next_layer()
|
||||
layers = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)]
|
||||
rast_1st_layer = rast
|
||||
# with torch.no_grad():
|
||||
if True:
|
||||
rast, db = peeler.rasterize_next_layer()
|
||||
layers2 = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)]
|
||||
|
||||
# Setup background
|
||||
if background is not None:
|
||||
if spp > 1:
|
||||
background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
|
||||
background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)
|
||||
else:
|
||||
background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
|
||||
|
||||
# Composite layers front-to-back
|
||||
out_buffers = {}
|
||||
for key in layers[0][0].keys():
|
||||
if key == 'shaded':
|
||||
accum = composite_buffer(key, layers, background, True)
|
||||
elif (key == 'depth' or key == 'pos') and layers[0][0][key] is not None:
|
||||
accum = separate_buffer(key, layers, torch.ones_like(layers[0][0][key]) * 20.0, False)
|
||||
elif ('normal' in key) and layers[0][0][key] is not None:
|
||||
accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)
|
||||
elif layers[0][0][key] is not None:
|
||||
accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), False)
|
||||
|
||||
if (key == 'depth' or key == 'pos') and layers[0][0][key] is not None:
|
||||
out_buffers[key] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
|
||||
else:
|
||||
# Downscale to framebuffer resolution. Use avg pooling
|
||||
out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
|
||||
|
||||
accum = composite_buffer('shaded', layers, background, True)
|
||||
out_buffers['shaded_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
|
||||
|
||||
accum = separate_buffer('depth', layers2, -1 * torch.ones_like(layers2[0][0]['depth']), False)
|
||||
out_buffers['depth_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
|
||||
|
||||
|
||||
accum = separate_buffer('normal', layers2, torch.zeros_like(layers2[0][0]['normal']), False)
|
||||
out_buffers['normal_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
|
||||
|
||||
rast_triangle_id = rast_1st_layer[:, :, :, -1].unique()
|
||||
if rast_triangle_id[0] == 0:
|
||||
if rast_triangle_id.size(0) > 1:
|
||||
rast_triangle_id = rast_triangle_id[1:] - 1 ## since by the convention of the rasterizer, 0 = empty
|
||||
else:
|
||||
rast_triangle_id = None
|
||||
out_buffers['rast_triangle_id'] = rast_triangle_id
|
||||
out_buffers['rast_depth'] = rast_1st_layer[:, :, :, -2] # z-buffer
|
||||
|
||||
|
||||
|
||||
if tet_centers is not None:
|
||||
with torch.no_grad():
|
||||
v_pos_clip = v_pos_clip[0]
|
||||
assert full_res[0] == full_res[1]
|
||||
homo_transformed_tet_centers = ru.xfm_points(tet_centers[None, ...], mtx_in)
|
||||
transformed_tet_centers = homo_transformed_tet_centers[0, :, :3] / homo_transformed_tet_centers[0, :, 3:4]
|
||||
|
||||
int_transformed_tet_centers = torch.round((transformed_tet_centers / 2.0 + 0.5) * (full_res[0] - 1)).long() # from the clip space (i.e., [-1, 1]^3) to the nearest integer coordinates in the canvas
|
||||
|
||||
### transpose THE "image"
|
||||
tmp_int_transformed_tet_centers = int_transformed_tet_centers.clone()
|
||||
int_transformed_tet_centers[:, 0] = tmp_int_transformed_tet_centers[:, 1]
|
||||
int_transformed_tet_centers[:, 1] = tmp_int_transformed_tet_centers[:, 0]
|
||||
|
||||
|
||||
valid_tet_centers = ((torch.logical_and((int_transformed_tet_centers <= full_res[0] - 1), int_transformed_tet_centers >= 0).float()).prod(dim=-1) == 1) # those tet centers in/on the edge of the clip space
|
||||
valid_int_transformed_tet_centers = int_transformed_tet_centers[valid_tet_centers]
|
||||
|
||||
tet_center_dirs = (tet_centers - view_pos.view(1, 3))
|
||||
tet_center_depths = tet_center_dirs.pow(2).sum(-1).sqrt()
|
||||
|
||||
### Finding occluded tetrahedra
|
||||
valid_transformed_tet_center_depths = transformed_tet_centers[valid_tet_centers][:, -1] # get the depth in the clip space
|
||||
valid_tet_ids = torch.arange(tet_centers.size(0)).to(valid_tet_centers.device)[valid_tet_centers]
|
||||
|
||||
|
||||
corrected_rast_depth = out_buffers['rast_depth'].clone().detach()
|
||||
|
||||
|
||||
corrected_rast_depth[rast_1st_layer[:, :, :, -1] == 0] = 100 # for all pixels without any rasterized mesh, just set the depth to a large enough value
|
||||
|
||||
'''
|
||||
Hacky way of finding most of the non-occluded tetrahedra (except for already rasterized ones):
|
||||
|
||||
For each pixel, find the min depth in a small neighborhood.
|
||||
If the center of a tetrahedron (coinciding with this pixel when rasterized) is smaller than this min depth,
|
||||
this tetrahedron is certainly non-occluded.
|
||||
|
||||
Doing this because exact per-pixel comparison for triangular meshes can be costly,
|
||||
plus we do not need to perfectly finding all visible tetrahedra.
|
||||
'''
|
||||
depth_search_range = 7 ### change this value for different resolution in rasterization
|
||||
corrected_rast_depth = -torch.nn.functional.max_pool2d(
|
||||
-corrected_rast_depth,
|
||||
kernel_size=2*depth_search_range+1,
|
||||
stride=1,
|
||||
padding=depth_search_range)
|
||||
|
||||
valid_reference_depth = corrected_rast_depth[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]]
|
||||
depth_filter = valid_reference_depth >= valid_transformed_tet_center_depths
|
||||
|
||||
|
||||
empty_2d_mask = (rast_1st_layer[:, :, :, -1] == 0)
|
||||
empty_2d_mask = (-torch.nn.functional.max_pool2d(
|
||||
-empty_2d_mask.float(),
|
||||
kernel_size=2*depth_search_range+1,
|
||||
stride=1,
|
||||
padding=depth_search_range)).bool() ### similar philosophy for using a neighborhood
|
||||
empty_filter = empty_2d_mask[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]]
|
||||
|
||||
## visible tets are either determined by depth test or emptyness test
|
||||
out_buffers['visible_tet_id'] = valid_tet_ids[torch.logical_or(empty_filter, depth_filter)]
|
||||
|
||||
return out_buffers
|
||||
|
||||
# ==============================================================================================
|
||||
# Render UVs
|
||||
# ==============================================================================================
|
||||
def render_uv(ctx, mesh, resolution, mlp_texture):
|
||||
|
||||
# clip space transform
|
||||
uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0
|
||||
|
||||
# pad to four component coordinate
|
||||
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)
|
||||
|
||||
# rasterize
|
||||
rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)
|
||||
|
||||
# Interpolate world space position
|
||||
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())
|
||||
|
||||
# Sample out textures from MLP
|
||||
all_tex = mlp_texture.sample(gb_pos)
|
||||
assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
|
||||
perturbed_nrm = all_tex[..., -3:]
|
||||
return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm)
|
||||
|
||||
# ==============================================================================================
|
||||
# Render UVs
|
||||
# ==============================================================================================
|
||||
def render_uv_nrm(ctx, mesh, resolution, mlp_texture):
|
||||
|
||||
# clip space transform
|
||||
uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0
|
||||
|
||||
# pad to four component coordinate
|
||||
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)
|
||||
|
||||
# rasterize
|
||||
rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)
|
||||
|
||||
# Interpolate world space position
|
||||
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())
|
||||
|
||||
# Sample out textures from MLP
|
||||
all_tex = mlp_texture.sample(gb_pos)
|
||||
perturbed_nrm = all_tex[..., -3:]
|
||||
return (rast[..., -1:] > 0).float(), util.safe_normalize(perturbed_nrm)
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
|
||||
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
|
|
@ -0,0 +1,710 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include "bsdf.h"
|
||||
|
||||
#define SPECULAR_EPSILON 1e-4f
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Lambert functions
|
||||
|
||||
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
|
||||
{
|
||||
return max(dot(nrm, wi) / M_PI, 0.0f);
|
||||
}
|
||||
|
||||
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
|
||||
{
|
||||
if (dot(nrm, wi) > 0.0f)
|
||||
bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Fresnel Schlick
|
||||
|
||||
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = powf(1.0f - _cosTheta, 5.0f);
|
||||
return f0 * (1.0f - scale) + f90 * scale;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
||||
d_f0 += d_out * (1.0 - scale);
|
||||
d_f90 += d_out * scale;
|
||||
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = powf(1.0f - _cosTheta, 5.0f);
|
||||
return f0 * (1.0f - scale) + f90 * scale;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
||||
d_f0 += d_out * (1.0 - scale);
|
||||
d_f90 += d_out * scale;
|
||||
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Frostbite diffuse
|
||||
|
||||
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
|
||||
{
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotN = dot(wo, nrm);
|
||||
if (wiDotN > 0.0f && woDotN > 0.0f)
|
||||
{
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float wiDotH = dot(wi, h);
|
||||
|
||||
float energyBias = 0.5f * linearRoughness;
|
||||
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float f0 = 1.f;
|
||||
|
||||
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
||||
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
|
||||
return wiScatter * woScatter * energyFactor;
|
||||
}
|
||||
else return 0.0f;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
|
||||
{
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotN = dot(wo, nrm);
|
||||
|
||||
if (wiDotN > 0.0f && woDotN > 0.0f)
|
||||
{
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float wiDotH = dot(wi, h);
|
||||
|
||||
float energyBias = 0.5f * linearRoughness;
|
||||
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float f0 = 1.f;
|
||||
|
||||
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
||||
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
|
||||
// -------------- BWD --------------
|
||||
// Backprop: return wiScatter * woScatter * energyFactor;
|
||||
float d_wiScatter = d_out * woScatter * energyFactor;
|
||||
float d_woScatter = d_out * wiScatter * energyFactor;
|
||||
float d_energyFactor = d_out * wiScatter * woScatter;
|
||||
|
||||
// Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
|
||||
bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
|
||||
|
||||
// Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
|
||||
float d_wiDotN = 0.0f;
|
||||
bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
|
||||
|
||||
// Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float d_energyBias = d_f90;
|
||||
float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
|
||||
d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
|
||||
|
||||
// Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
|
||||
|
||||
// Backprop: float energyBias = 0.5f * linearRoughness;
|
||||
d_linearRoughness += 0.5 * d_energyBias;
|
||||
|
||||
// Backprop: float wiDotH = dot(wi, h);
|
||||
vec3f d_h(0);
|
||||
bwdDot(wi, h, d_wi, d_h, d_wiDotH);
|
||||
|
||||
// Backprop: vec3f h = safeNormalize(wo + wi);
|
||||
vec3f d_wo_wi(0);
|
||||
bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
|
||||
d_wi += d_wo_wi; d_wo += d_wo_wi;
|
||||
|
||||
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
||||
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Ndf GGX
|
||||
|
||||
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
||||
return alphaSqr / (d * d * M_PI);
|
||||
}
|
||||
|
||||
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
// Torch only back propagates if clamp doesn't trigger
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
||||
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Lambda GGX
|
||||
|
||||
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
||||
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
||||
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
||||
|
||||
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
||||
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Masking GGX
|
||||
|
||||
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
||||
{
|
||||
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
||||
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
||||
return 1.0f / (1.0f + lambdaI + lambdaO);
|
||||
}
|
||||
|
||||
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
||||
{
|
||||
// FWD eval
|
||||
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
||||
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
||||
|
||||
// BWD eval
|
||||
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
||||
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
||||
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// GGX specular
|
||||
|
||||
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
||||
{
|
||||
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
||||
float alphaSqr = _alpha * _alpha;
|
||||
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float woDotN = dot(wo, nrm);
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotH = dot(wo, h);
|
||||
float nDotH = dot(nrm, h);
|
||||
|
||||
float D = fwdNdfGGX(alphaSqr, nDotH);
|
||||
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
||||
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
||||
vec3f w = F * D * G * 0.25 / woDotN;
|
||||
|
||||
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
||||
return frontfacing ? w : 0.0f;
|
||||
}
|
||||
|
||||
__device__ void bwdPbrSpecular(
|
||||
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
||||
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
||||
{
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// FWD eval
|
||||
|
||||
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
||||
float alphaSqr = _alpha * _alpha;
|
||||
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float woDotN = dot(wo, nrm);
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotH = dot(wo, h);
|
||||
float nDotH = dot(nrm, h);
|
||||
|
||||
float D = fwdNdfGGX(alphaSqr, nDotH);
|
||||
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
||||
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
||||
vec3f w = F * D * G * 0.25 / woDotN;
|
||||
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
||||
|
||||
if (frontfacing)
|
||||
{
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// BWD eval
|
||||
|
||||
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
||||
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
||||
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
||||
|
||||
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
||||
|
||||
vec3f d_f90(0);
|
||||
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
||||
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
||||
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
||||
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
||||
|
||||
vec3f d_h(0);
|
||||
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
||||
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
||||
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
||||
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
||||
|
||||
vec3f d_h_unnorm(0);
|
||||
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
||||
d_wo += d_h_unnorm;
|
||||
d_wi += d_h_unnorm;
|
||||
|
||||
if (alpha > min_roughness * min_roughness)
|
||||
d_alpha += d_alphaSqr * 2 * alpha;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Full PBR BSDF
|
||||
|
||||
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
|
||||
{
|
||||
vec3f wo = safeNormalize(view_pos - pos);
|
||||
vec3f wi = safeNormalize(light_pos - pos);
|
||||
|
||||
float alpha = arm.y * arm.y;
|
||||
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
||||
vec3f diff_col = kd * (1.0f - arm.z);
|
||||
|
||||
float diff = 0.0f;
|
||||
if (BSDF == 0)
|
||||
diff = fwdLambert(nrm, wi);
|
||||
else
|
||||
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
||||
vec3f diffuse = diff_col * diff;
|
||||
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
||||
|
||||
return diffuse + specular;
|
||||
}
|
||||
|
||||
__device__ void bwdPbrBSDF(
|
||||
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
|
||||
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
vec3f _wi = light_pos - pos;
|
||||
vec3f _wo = view_pos - pos;
|
||||
vec3f wi = safeNormalize(_wi);
|
||||
vec3f wo = safeNormalize(_wo);
|
||||
|
||||
float alpha = arm.y * arm.y;
|
||||
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
||||
vec3f diff_col = kd * (1.0f - arm.z);
|
||||
float diff = 0.0f;
|
||||
if (BSDF == 0)
|
||||
diff = fwdLambert(nrm, wi);
|
||||
else
|
||||
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
|
||||
float d_alpha(0);
|
||||
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
||||
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
||||
|
||||
float d_diff = sum(diff_col * d_out);
|
||||
if (BSDF == 0)
|
||||
bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
|
||||
else
|
||||
bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
|
||||
|
||||
// Backprop: diff_col = kd * (1.0f - arm.z)
|
||||
vec3f d_diff_col = d_out * diff;
|
||||
d_kd += d_diff_col * (1.0f - arm.z);
|
||||
d_arm.z -= sum(d_diff_col * kd);
|
||||
|
||||
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
||||
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
||||
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
||||
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
||||
|
||||
// Backprop: alpha = arm.y * arm.y
|
||||
d_arm.y += d_alpha * 2 * arm.y;
|
||||
|
||||
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
||||
vec3f d__wi(0);
|
||||
bwdSafeNormalize(_wi, d__wi, d_wi);
|
||||
d_light_pos += d__wi;
|
||||
d_pos -= d__wi;
|
||||
|
||||
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
||||
vec3f d__wo(0);
|
||||
bwdSafeNormalize(_wo, d__wo, d_wo);
|
||||
d_view_pos += d__wo;
|
||||
d_pos -= d__wo;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Kernels
|
||||
|
||||
__global__ void LambertFwdKernel(LambertKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
|
||||
float res = fwdLambert(nrm, wi);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void LambertBwdKernel(LambertKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
vec3f d_nrm(0), d_wi(0);
|
||||
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
||||
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
}
|
||||
|
||||
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
||||
|
||||
float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_linearRoughness = 0.0f;
|
||||
vec3f d_nrm(0), d_wi(0), d_wo(0);
|
||||
bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
|
||||
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
p.wo.store_grad(px, py, pz, d_wo);
|
||||
p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
|
||||
}
|
||||
|
||||
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f f0 = p.f0.fetch3(px, py, pz);
|
||||
vec3f f90 = p.f90.fetch3(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
|
||||
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f f0 = p.f0.fetch3(px, py, pz);
|
||||
vec3f f90 = p.f90.fetch3(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
vec3f d_f0(0), d_f90(0);
|
||||
float d_cosTheta(0);
|
||||
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
||||
|
||||
p.f0.store_grad(px, py, pz, d_f0);
|
||||
p.f90.store_grad(px, py, pz, d_f90);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosTheta(0);
|
||||
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosTheta(0);
|
||||
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
||||
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
||||
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
||||
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
||||
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
||||
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
||||
}
|
||||
|
||||
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f col = p.col.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float alpha = p.alpha.fetch1(px, py, pz);
|
||||
|
||||
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f col = p.col.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float alpha = p.alpha.fetch1(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
float d_alpha(0);
|
||||
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
||||
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
||||
|
||||
p.col.store_grad(px, py, pz, d_col);
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wo.store_grad(px, py, pz, d_wo);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
p.alpha.store_grad(px, py, pz, d_alpha);
|
||||
}
|
||||
|
||||
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f kd = p.kd.fetch3(px, py, pz);
|
||||
vec3f arm = p.arm.fetch3(px, py, pz);
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
||||
|
||||
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f kd = p.kd.fetch3(px, py, pz);
|
||||
vec3f arm = p.arm.fetch3(px, py, pz);
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
||||
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
||||
|
||||
p.kd.store_grad(px, py, pz, d_kd);
|
||||
p.arm.store_grad(px, py, pz, d_arm);
|
||||
p.pos.store_grad(px, py, pz, d_pos);
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
||||
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
NORMAL_THRESHOLD = 0.1
|
||||
|
||||
################################################################################
|
||||
# Vector utility functions
|
||||
################################################################################
|
||||
|
||||
def _dot(x, y):
|
||||
return torch.sum(x*y, -1, keepdim=True)
|
||||
|
||||
def _reflect(x, n):
|
||||
return 2*_dot(x, n)*n - x
|
||||
|
||||
def _safe_normalize(x):
|
||||
return torch.nn.functional.normalize(x, dim = -1)
|
||||
|
||||
def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
|
||||
# Swap normal direction for backfacing surfaces
|
||||
if two_sided_shading:
|
||||
smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
|
||||
geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
|
||||
|
||||
t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
|
||||
return torch.lerp(geom_nrm, smooth_nrm, t)
|
||||
|
||||
|
||||
def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
|
||||
smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
|
||||
if opengl:
|
||||
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
||||
else:
|
||||
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
||||
return _safe_normalize(shading_nrm)
|
||||
|
||||
def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
||||
smooth_nrm = _safe_normalize(smooth_nrm)
|
||||
view_vec = _safe_normalize(view_pos - pos)
|
||||
if smooth_tng is None:
|
||||
shading_nrm = smooth_nrm
|
||||
else:
|
||||
smooth_tng = _safe_normalize(smooth_tng)
|
||||
shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
|
||||
return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
|
||||
|
||||
################################################################################
|
||||
# Simple lambertian diffuse BSDF
|
||||
################################################################################
|
||||
|
||||
def bsdf_lambert(nrm, wi):
|
||||
return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
|
||||
|
||||
################################################################################
|
||||
# Frostbite diffuse
|
||||
################################################################################
|
||||
|
||||
def bsdf_frostbite(nrm, wi, wo, linearRoughness):
|
||||
wiDotN = _dot(wi, nrm)
|
||||
woDotN = _dot(wo, nrm)
|
||||
|
||||
h = _safe_normalize(wo + wi)
|
||||
wiDotH = _dot(wi, h)
|
||||
|
||||
energyBias = 0.5 * linearRoughness
|
||||
energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
|
||||
f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
|
||||
f0 = 1.0
|
||||
|
||||
wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
|
||||
woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
|
||||
res = wiScatter * woScatter * energyFactor
|
||||
return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
|
||||
|
||||
################################################################################
|
||||
# Phong specular, loosely based on mitsuba implementation
|
||||
################################################################################
|
||||
|
||||
def bsdf_phong(nrm, wo, wi, N):
|
||||
dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
|
||||
dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
|
||||
return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
|
||||
|
||||
################################################################################
|
||||
# PBR's implementation of GGX specular
|
||||
################################################################################
|
||||
|
||||
specular_epsilon = 1e-4
|
||||
|
||||
def bsdf_fresnel_shlick(f0, f90, cosTheta):
|
||||
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
||||
return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
|
||||
|
||||
def bsdf_ndf_ggx(alphaSqr, cosTheta):
|
||||
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
||||
d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
|
||||
return alphaSqr / (d * d * math.pi)
|
||||
|
||||
def bsdf_lambda_ggx(alphaSqr, cosTheta):
|
||||
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
||||
cosThetaSqr = _cosTheta * _cosTheta
|
||||
tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
|
||||
res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
|
||||
return res
|
||||
|
||||
def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
|
||||
lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
|
||||
lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
|
||||
return 1 / (1 + lambdaI + lambdaO)
|
||||
|
||||
def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
|
||||
_alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
|
||||
alphaSqr = _alpha * _alpha
|
||||
|
||||
h = _safe_normalize(wo + wi)
|
||||
woDotN = _dot(wo, nrm)
|
||||
wiDotN = _dot(wi, nrm)
|
||||
woDotH = _dot(wo, h)
|
||||
nDotH = _dot(nrm, h)
|
||||
|
||||
D = bsdf_ndf_ggx(alphaSqr, nDotH)
|
||||
G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
|
||||
F = bsdf_fresnel_shlick(col, 1, woDotH)
|
||||
|
||||
w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
|
||||
|
||||
frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
|
||||
return torch.where(frontfacing, w, torch.zeros_like(w))
|
||||
|
||||
def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
|
||||
wo = _safe_normalize(view_pos - pos)
|
||||
wi = _safe_normalize(light_pos - pos)
|
||||
|
||||
spec_str = arm[..., 0:1] # x component
|
||||
roughness = arm[..., 1:2] # y component
|
||||
metallic = arm[..., 2:3] # z component
|
||||
ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
|
||||
kd = kd * (1.0 - metallic)
|
||||
|
||||
if BSDF == 0:
|
||||
diffuse = kd * bsdf_lambert(nrm, wi)
|
||||
else:
|
||||
diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
|
||||
specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
|
||||
return diffuse + specular
|
|
@ -0,0 +1,710 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include "bsdf.h"
|
||||
|
||||
#define SPECULAR_EPSILON 1e-4f
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Lambert functions
|
||||
|
||||
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
|
||||
{
|
||||
return max(dot(nrm, wi) / M_PI, 0.0f);
|
||||
}
|
||||
|
||||
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
|
||||
{
|
||||
if (dot(nrm, wi) > 0.0f)
|
||||
bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Fresnel Schlick
|
||||
|
||||
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = powf(1.0f - _cosTheta, 5.0f);
|
||||
return f0 * (1.0f - scale) + f90 * scale;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
||||
d_f0 += d_out * (1.0 - scale);
|
||||
d_f90 += d_out * scale;
|
||||
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = powf(1.0f - _cosTheta, 5.0f);
|
||||
return f0 * (1.0f - scale) + f90 * scale;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
||||
d_f0 += d_out * (1.0 - scale);
|
||||
d_f90 += d_out * scale;
|
||||
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Frostbite diffuse
|
||||
|
||||
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
|
||||
{
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotN = dot(wo, nrm);
|
||||
if (wiDotN > 0.0f && woDotN > 0.0f)
|
||||
{
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float wiDotH = dot(wi, h);
|
||||
|
||||
float energyBias = 0.5f * linearRoughness;
|
||||
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float f0 = 1.f;
|
||||
|
||||
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
||||
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
|
||||
return wiScatter * woScatter * energyFactor;
|
||||
}
|
||||
else return 0.0f;
|
||||
}
|
||||
|
||||
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
|
||||
{
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotN = dot(wo, nrm);
|
||||
|
||||
if (wiDotN > 0.0f && woDotN > 0.0f)
|
||||
{
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float wiDotH = dot(wi, h);
|
||||
|
||||
float energyBias = 0.5f * linearRoughness;
|
||||
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float f0 = 1.f;
|
||||
|
||||
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
||||
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
|
||||
// -------------- BWD --------------
|
||||
// Backprop: return wiScatter * woScatter * energyFactor;
|
||||
float d_wiScatter = d_out * woScatter * energyFactor;
|
||||
float d_woScatter = d_out * wiScatter * energyFactor;
|
||||
float d_energyFactor = d_out * wiScatter * woScatter;
|
||||
|
||||
// Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
||||
float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
|
||||
bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
|
||||
|
||||
// Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
|
||||
float d_wiDotN = 0.0f;
|
||||
bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
|
||||
|
||||
// Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
||||
float d_energyBias = d_f90;
|
||||
float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
|
||||
d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
|
||||
|
||||
// Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
||||
d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
|
||||
|
||||
// Backprop: float energyBias = 0.5f * linearRoughness;
|
||||
d_linearRoughness += 0.5 * d_energyBias;
|
||||
|
||||
// Backprop: float wiDotH = dot(wi, h);
|
||||
vec3f d_h(0);
|
||||
bwdDot(wi, h, d_wi, d_h, d_wiDotH);
|
||||
|
||||
// Backprop: vec3f h = safeNormalize(wo + wi);
|
||||
vec3f d_wo_wi(0);
|
||||
bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
|
||||
d_wi += d_wo_wi; d_wo += d_wo_wi;
|
||||
|
||||
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
||||
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Ndf GGX
|
||||
|
||||
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
||||
return alphaSqr / (d * d * M_PI);
|
||||
}
|
||||
|
||||
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
// Torch only back propagates if clamp doesn't trigger
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
||||
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
{
|
||||
d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Lambda GGX
|
||||
|
||||
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
||||
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
||||
float cosThetaSqr = _cosTheta * _cosTheta;
|
||||
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
||||
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
||||
|
||||
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
||||
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
||||
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Masking GGX
|
||||
|
||||
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
||||
{
|
||||
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
||||
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
||||
return 1.0f / (1.0f + lambdaI + lambdaO);
|
||||
}
|
||||
|
||||
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
||||
{
|
||||
// FWD eval
|
||||
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
||||
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
||||
|
||||
// BWD eval
|
||||
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
||||
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
||||
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// GGX specular
|
||||
|
||||
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
||||
{
|
||||
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
||||
float alphaSqr = _alpha * _alpha;
|
||||
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float woDotN = dot(wo, nrm);
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotH = dot(wo, h);
|
||||
float nDotH = dot(nrm, h);
|
||||
|
||||
float D = fwdNdfGGX(alphaSqr, nDotH);
|
||||
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
||||
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
||||
vec3f w = F * D * G * 0.25 / woDotN;
|
||||
|
||||
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
||||
return frontfacing ? w : 0.0f;
|
||||
}
|
||||
|
||||
__device__ void bwdPbrSpecular(
|
||||
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
||||
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
||||
{
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// FWD eval
|
||||
|
||||
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
||||
float alphaSqr = _alpha * _alpha;
|
||||
|
||||
vec3f h = safeNormalize(wo + wi);
|
||||
float woDotN = dot(wo, nrm);
|
||||
float wiDotN = dot(wi, nrm);
|
||||
float woDotH = dot(wo, h);
|
||||
float nDotH = dot(nrm, h);
|
||||
|
||||
float D = fwdNdfGGX(alphaSqr, nDotH);
|
||||
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
||||
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
||||
vec3f w = F * D * G * 0.25 / woDotN;
|
||||
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
||||
|
||||
if (frontfacing)
|
||||
{
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
// BWD eval
|
||||
|
||||
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
||||
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
||||
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
||||
|
||||
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
||||
|
||||
vec3f d_f90(0);
|
||||
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
||||
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
||||
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
||||
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
||||
|
||||
vec3f d_h(0);
|
||||
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
||||
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
||||
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
||||
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
||||
|
||||
vec3f d_h_unnorm(0);
|
||||
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
||||
d_wo += d_h_unnorm;
|
||||
d_wi += d_h_unnorm;
|
||||
|
||||
if (alpha > min_roughness * min_roughness)
|
||||
d_alpha += d_alphaSqr * 2 * alpha;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Full PBR BSDF
|
||||
|
||||
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
|
||||
{
|
||||
vec3f wo = safeNormalize(view_pos - pos);
|
||||
vec3f wi = safeNormalize(light_pos - pos);
|
||||
|
||||
float alpha = arm.y * arm.y;
|
||||
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
||||
vec3f diff_col = kd * (1.0f - arm.z);
|
||||
|
||||
float diff = 0.0f;
|
||||
if (BSDF == 0)
|
||||
diff = fwdLambert(nrm, wi);
|
||||
else
|
||||
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
||||
vec3f diffuse = diff_col * diff;
|
||||
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
||||
|
||||
return diffuse + specular;
|
||||
}
|
||||
|
||||
__device__ void bwdPbrBSDF(
|
||||
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
|
||||
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
vec3f _wi = light_pos - pos;
|
||||
vec3f _wo = view_pos - pos;
|
||||
vec3f wi = safeNormalize(_wi);
|
||||
vec3f wo = safeNormalize(_wo);
|
||||
|
||||
float alpha = arm.y * arm.y;
|
||||
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
||||
vec3f diff_col = kd * (1.0f - arm.z);
|
||||
float diff = 0.0f;
|
||||
if (BSDF == 0)
|
||||
diff = fwdLambert(nrm, wi);
|
||||
else
|
||||
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
|
||||
float d_alpha(0);
|
||||
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
||||
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
||||
|
||||
float d_diff = sum(diff_col * d_out);
|
||||
if (BSDF == 0)
|
||||
bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
|
||||
else
|
||||
bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
|
||||
|
||||
// Backprop: diff_col = kd * (1.0f - arm.z)
|
||||
vec3f d_diff_col = d_out * diff;
|
||||
d_kd += d_diff_col * (1.0f - arm.z);
|
||||
d_arm.z -= sum(d_diff_col * kd);
|
||||
|
||||
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
||||
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
||||
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
||||
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
||||
|
||||
// Backprop: alpha = arm.y * arm.y
|
||||
d_arm.y += d_alpha * 2 * arm.y;
|
||||
|
||||
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
||||
vec3f d__wi(0);
|
||||
bwdSafeNormalize(_wi, d__wi, d_wi);
|
||||
d_light_pos += d__wi;
|
||||
d_pos -= d__wi;
|
||||
|
||||
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
||||
vec3f d__wo(0);
|
||||
bwdSafeNormalize(_wo, d__wo, d_wo);
|
||||
d_view_pos += d__wo;
|
||||
d_pos -= d__wo;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Kernels
|
||||
|
||||
__global__ void LambertFwdKernel(LambertKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
|
||||
float res = fwdLambert(nrm, wi);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void LambertBwdKernel(LambertKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
vec3f d_nrm(0), d_wi(0);
|
||||
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
||||
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
}
|
||||
|
||||
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
||||
|
||||
float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_linearRoughness = 0.0f;
|
||||
vec3f d_nrm(0), d_wi(0), d_wo(0);
|
||||
bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
|
||||
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
p.wo.store_grad(px, py, pz, d_wo);
|
||||
p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
|
||||
}
|
||||
|
||||
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f f0 = p.f0.fetch3(px, py, pz);
|
||||
vec3f f90 = p.f90.fetch3(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
|
||||
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f f0 = p.f0.fetch3(px, py, pz);
|
||||
vec3f f90 = p.f90.fetch3(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
vec3f d_f0(0), d_f90(0);
|
||||
float d_cosTheta(0);
|
||||
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
||||
|
||||
p.f0.store_grad(px, py, pz, d_f0);
|
||||
p.f90.store_grad(px, py, pz, d_f90);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosTheta(0);
|
||||
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosTheta(0);
|
||||
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
||||
}
|
||||
|
||||
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
||||
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
||||
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
||||
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
||||
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
||||
float d_out = p.out.fetch1(px, py, pz);
|
||||
|
||||
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
||||
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
||||
|
||||
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
||||
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
||||
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
||||
}
|
||||
|
||||
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f col = p.col.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float alpha = p.alpha.fetch1(px, py, pz);
|
||||
|
||||
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f col = p.col.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f wo = p.wo.fetch3(px, py, pz);
|
||||
vec3f wi = p.wi.fetch3(px, py, pz);
|
||||
float alpha = p.alpha.fetch1(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
float d_alpha(0);
|
||||
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
||||
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
||||
|
||||
p.col.store_grad(px, py, pz, d_col);
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.wo.store_grad(px, py, pz, d_wo);
|
||||
p.wi.store_grad(px, py, pz, d_wi);
|
||||
p.alpha.store_grad(px, py, pz, d_alpha);
|
||||
}
|
||||
|
||||
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f kd = p.kd.fetch3(px, py, pz);
|
||||
vec3f arm = p.arm.fetch3(px, py, pz);
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
||||
|
||||
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f kd = p.kd.fetch3(px, py, pz);
|
||||
vec3f arm = p.arm.fetch3(px, py, pz);
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
||||
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
||||
|
||||
p.kd.store_grad(px, py, pz, d_kd);
|
||||
p.arm.store_grad(px, py, pz, d_arm);
|
||||
p.pos.store_grad(px, py, pz, d_pos);
|
||||
p.nrm.store_grad(px, py, pz, d_nrm);
|
||||
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
||||
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
struct LambertKernelParams
|
||||
{
|
||||
Tensor nrm;
|
||||
Tensor wi;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct FrostbiteDiffuseKernelParams
|
||||
{
|
||||
Tensor nrm;
|
||||
Tensor wi;
|
||||
Tensor wo;
|
||||
Tensor linearRoughness;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct FresnelShlickKernelParams
|
||||
{
|
||||
Tensor f0;
|
||||
Tensor f90;
|
||||
Tensor cosTheta;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct NdfGGXParams
|
||||
{
|
||||
Tensor alphaSqr;
|
||||
Tensor cosTheta;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct MaskingSmithParams
|
||||
{
|
||||
Tensor alphaSqr;
|
||||
Tensor cosThetaI;
|
||||
Tensor cosThetaO;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct PbrSpecular
|
||||
{
|
||||
Tensor col;
|
||||
Tensor nrm;
|
||||
Tensor wo;
|
||||
Tensor wi;
|
||||
Tensor alpha;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
float min_roughness;
|
||||
};
|
||||
|
||||
struct PbrBSDF
|
||||
{
|
||||
Tensor kd;
|
||||
Tensor arm;
|
||||
Tensor pos;
|
||||
Tensor nrm;
|
||||
Tensor view_pos;
|
||||
Tensor light_pos;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
float min_roughness;
|
||||
int BSDF;
|
||||
};
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Block and grid size calculators for kernel launches.
|
||||
|
||||
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
|
||||
{
|
||||
int maxThreads = maxWidth * maxHeight;
|
||||
if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
|
||||
return dim3(1, 1, 1); // Degenerate.
|
||||
|
||||
// Start from max size.
|
||||
int bw = maxWidth;
|
||||
int bh = maxHeight;
|
||||
|
||||
// Optimizations for weirdly sized buffers.
|
||||
if (dims.x < bw)
|
||||
{
|
||||
// Decrease block width to smallest power of two that covers the buffer width.
|
||||
while ((bw >> 1) >= dims.x)
|
||||
bw >>= 1;
|
||||
|
||||
// Maximize height.
|
||||
bh = maxThreads / bw;
|
||||
if (bh > dims.y)
|
||||
bh = dims.y;
|
||||
}
|
||||
else if (dims.y < bh)
|
||||
{
|
||||
// Halve height and double width until fits completely inside buffer vertically.
|
||||
while (bh > dims.y)
|
||||
{
|
||||
bh >>= 1;
|
||||
if (bw < dims.x)
|
||||
bw <<= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Done.
|
||||
return dim3(bw, bh, 1);
|
||||
}
|
||||
|
||||
// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
|
||||
dim3 getWarpSize(dim3 blockSize)
|
||||
{
|
||||
return dim3(
|
||||
std::min(blockSize.x, 32u),
|
||||
std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
|
||||
std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
|
||||
);
|
||||
}
|
||||
|
||||
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
|
||||
{
|
||||
dim3 gridSize;
|
||||
gridSize.x = (dims.x - 1) / blockSize.x + 1;
|
||||
gridSize.y = (dims.y - 1) / blockSize.y + 1;
|
||||
gridSize.z = (dims.z - 1) / blockSize.z + 1;
|
||||
return gridSize;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "vec3f.h"
|
||||
#include "vec4f.h"
|
||||
#include "tensor.h"
|
||||
|
||||
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
|
||||
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
|
||||
|
||||
#ifdef __CUDACC__
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define M_PI 3.14159265358979323846f
|
||||
#endif
|
||||
|
||||
__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
|
||||
{
|
||||
return dim3(
|
||||
min(blockSize.x, 32u),
|
||||
min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
|
||||
min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
|
||||
);
|
||||
}
|
||||
|
||||
__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
|
||||
#else
|
||||
dim3 getWarpSize(dim3 blockSize);
|
||||
#endif
|
|
@ -0,0 +1,350 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include "cubemap.h"
|
||||
#include <float.h>
|
||||
|
||||
// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
|
||||
__device__ float pixel_area(int x, int y, int N)
|
||||
{
|
||||
if (N > 1)
|
||||
{
|
||||
int H = N / 2;
|
||||
x = abs(x - H);
|
||||
y = abs(y - H);
|
||||
float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
|
||||
float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
|
||||
return dx * dy;
|
||||
}
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
__device__ vec3f cube_to_dir(int x, int y, int side, int N)
|
||||
{
|
||||
float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
|
||||
float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
|
||||
switch (side)
|
||||
{
|
||||
case 0: return safeNormalize(vec3f(1, -fy, -fx));
|
||||
case 1: return safeNormalize(vec3f(-1, -fy, fx));
|
||||
case 2: return safeNormalize(vec3f(fx, 1, fy));
|
||||
case 3: return safeNormalize(vec3f(fx, -1, -fy));
|
||||
case 4: return safeNormalize(vec3f(fx, -fy, 1));
|
||||
case 5: return safeNormalize(vec3f(-fx, -fy, -1));
|
||||
}
|
||||
return vec3f(0,0,0); // Unreachable
|
||||
}
|
||||
|
||||
__device__ vec3f dir_to_side(int side, vec3f v)
|
||||
{
|
||||
switch (side)
|
||||
{
|
||||
case 0: return vec3f(-v.z, -v.y, v.x);
|
||||
case 1: return vec3f( v.z, -v.y, -v.x);
|
||||
case 2: return vec3f( v.x, v.z, v.y);
|
||||
case 3: return vec3f( v.x, -v.z, -v.y);
|
||||
case 4: return vec3f( v.x, -v.y, v.z);
|
||||
case 5: return vec3f(-v.x, -v.y, -v.z);
|
||||
}
|
||||
return vec3f(0,0,0); // Unreachable
|
||||
}
|
||||
|
||||
__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
|
||||
{
|
||||
float l = sqrtf(x * x + z * z);
|
||||
float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
|
||||
float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
|
||||
if (pzl <= 0.00001f)
|
||||
_min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
|
||||
else
|
||||
_min = pxl / pzl;
|
||||
if (pzr <= 0.00001f)
|
||||
_max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
|
||||
else
|
||||
_max = pxr / pzr;
|
||||
}
|
||||
|
||||
__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
|
||||
{
|
||||
vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1
|
||||
|
||||
if (theta < 0.785398f) // PI/4
|
||||
{
|
||||
float xmin, xmax, ymin, ymax;
|
||||
extents_1d(c.x, c.z, theta, xmin, xmax);
|
||||
extents_1d(c.y, c.z, theta, ymin, ymax);
|
||||
|
||||
if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
|
||||
{
|
||||
_xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
|
||||
}
|
||||
else
|
||||
{
|
||||
_xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
||||
_xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
||||
_ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
||||
_ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
_xmin = 0.0f;
|
||||
_xmax = (float)(N-1);
|
||||
_ymin = 0.0f;
|
||||
_ymax = (float)(N-1);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Diffuse kernel
|
||||
__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
int Npx = p.cubemap.dims[1];
|
||||
vec3f N = cube_to_dir(px, py, pz, Npx);
|
||||
|
||||
vec3f col(0);
|
||||
|
||||
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
||||
{
|
||||
for (int y = 0; y < Npx; ++y)
|
||||
{
|
||||
for (int x = 0; x < Npx; ++x)
|
||||
{
|
||||
vec3f L = cube_to_dir(x, y, s, Npx);
|
||||
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
||||
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
||||
col += p.cubemap.fetch3(x, y, s) * w;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.out.store(px, py, pz, col);
|
||||
}
|
||||
|
||||
__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
int Npx = p.cubemap.dims[1];
|
||||
vec3f N = cube_to_dir(px, py, pz, Npx);
|
||||
vec3f grad = p.out.fetch3(px, py, pz);
|
||||
|
||||
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
||||
{
|
||||
for (int y = 0; y < Npx; ++y)
|
||||
{
|
||||
for (int x = 0; x < Npx; ++x)
|
||||
{
|
||||
vec3f L = cube_to_dir(x, y, s, Npx);
|
||||
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
||||
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// GGX splitsum kernel
|
||||
|
||||
__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
|
||||
{
|
||||
float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
|
||||
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
||||
return alphaSqr / (d * d * M_PI);
|
||||
}
|
||||
|
||||
__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
|
||||
{
|
||||
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
int Npx = p.gridSize.x;
|
||||
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
||||
|
||||
const int TILE_SIZE = 16;
|
||||
|
||||
// Brute force entire cubemap and compute bounds for the cone
|
||||
for (int s = 0; s < p.gridSize.z; ++s)
|
||||
{
|
||||
// Assume empty BBox
|
||||
int _min_x = p.gridSize.x - 1, _max_x = 0;
|
||||
int _min_y = p.gridSize.y - 1, _max_y = 0;
|
||||
|
||||
// For each (8x8) tile
|
||||
for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
|
||||
{
|
||||
for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
|
||||
{
|
||||
// Compute tile extents
|
||||
int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
|
||||
int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);
|
||||
|
||||
// Use some blunt interval arithmetics to cull tiles
|
||||
vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
|
||||
vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);
|
||||
|
||||
float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
|
||||
float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
|
||||
float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));
|
||||
|
||||
float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
|
||||
if (maxdp >= p.costheta_cutoff)
|
||||
{
|
||||
// Test all pixels in tile.
|
||||
for (int y = tsy; y < tey; ++y)
|
||||
{
|
||||
for (int x = tsx; x < tex; ++x)
|
||||
{
|
||||
vec3f L = cube_to_dir(x, y, s, Npx);
|
||||
if (dot(L, VNR) >= p.costheta_cutoff)
|
||||
{
|
||||
_min_x = min(_min_x, x);
|
||||
_max_x = max(_max_x, x);
|
||||
_min_y = min(_min_y, y);
|
||||
_max_y = max(_max_y, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
int Npx = p.cubemap.dims[1];
|
||||
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
||||
|
||||
float alpha = p.roughness * p.roughness;
|
||||
float alphaSqr = alpha * alpha;
|
||||
|
||||
float wsum = 0.0f;
|
||||
vec3f col(0);
|
||||
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
||||
{
|
||||
int xmin, xmax, ymin, ymax;
|
||||
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
||||
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
||||
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
||||
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
||||
|
||||
if (xmin <= xmax)
|
||||
{
|
||||
for (int y = ymin; y <= ymax; ++y)
|
||||
{
|
||||
for (int x = xmin; x <= xmax; ++x)
|
||||
{
|
||||
vec3f L = cube_to_dir(x, y, s, Npx);
|
||||
if (dot(L, VNR) >= p.costheta_cutoff)
|
||||
{
|
||||
vec3f H = safeNormalize(L + VNR);
|
||||
|
||||
float wiDotN = max(dot(L, VNR), 0.0f);
|
||||
float VNRDotH = max(dot(VNR, H), 0.0f);
|
||||
|
||||
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
||||
col += p.cubemap.fetch3(x, y, s) * w;
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
|
||||
p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
|
||||
}
|
||||
|
||||
__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
int Npx = p.cubemap.dims[1];
|
||||
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
||||
|
||||
vec3f grad = p.out.fetch3(px, py, pz);
|
||||
|
||||
float alpha = p.roughness * p.roughness;
|
||||
float alphaSqr = alpha * alpha;
|
||||
|
||||
vec3f col(0);
|
||||
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
||||
{
|
||||
int xmin, xmax, ymin, ymax;
|
||||
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
||||
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
||||
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
||||
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
||||
|
||||
if (xmin <= xmax)
|
||||
{
|
||||
for (int y = ymin; y <= ymax; ++y)
|
||||
{
|
||||
for (int x = xmin; x <= xmax; ++x)
|
||||
{
|
||||
vec3f L = cube_to_dir(x, y, s, Npx);
|
||||
if (dot(L, VNR) >= p.costheta_cutoff)
|
||||
{
|
||||
vec3f H = safeNormalize(L + VNR);
|
||||
|
||||
float wiDotN = max(dot(L, VNR), 0.0f);
|
||||
float VNRDotH = max(dot(VNR, H), 0.0f);
|
||||
|
||||
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
||||
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
||||
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
struct DiffuseCubemapKernelParams
|
||||
{
|
||||
Tensor cubemap;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
||||
|
||||
struct SpecularCubemapKernelParams
|
||||
{
|
||||
Tensor cubemap;
|
||||
Tensor bounds;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
float costheta_cutoff;
|
||||
float roughness;
|
||||
};
|
||||
|
||||
struct SpecularBoundsKernelParams
|
||||
{
|
||||
float costheta_cutoff;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
|
@ -0,0 +1,210 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "loss.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Utils
|
||||
|
||||
__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
|
||||
|
||||
__device__ float warpSum(float val) {
|
||||
for (int i = 1; i < 32; i *= 2)
|
||||
val += __shfl_xor_sync(0xFFFFFFFF, val, i);
|
||||
return val;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Tonemapping
|
||||
|
||||
__device__ inline float fwdSRGB(float x)
|
||||
{
|
||||
return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
|
||||
}
|
||||
|
||||
__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
|
||||
{
|
||||
if (x > 0.0031308f)
|
||||
d_x += d_out * 0.439583f / powf(x, 0.583333f);
|
||||
else if (x > 0.0f)
|
||||
d_x += d_out * 12.92f;
|
||||
}
|
||||
|
||||
__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
|
||||
{
|
||||
return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
|
||||
}
|
||||
|
||||
__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
|
||||
{
|
||||
if (x.x > 0.0f && x.x < 65535.0f)
|
||||
{
|
||||
bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
|
||||
d_x.x *= 1 / (x.x + 1.0f);
|
||||
}
|
||||
if (x.y > 0.0f && x.y < 65535.0f)
|
||||
{
|
||||
bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
|
||||
d_x.y *= 1 / (x.y + 1.0f);
|
||||
}
|
||||
if (x.z > 0.0f && x.z < 65535.0f)
|
||||
{
|
||||
bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
|
||||
d_x.z *= 1 / (x.z + 1.0f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
|
||||
{
|
||||
return (img - target) * (img - target) / (img * img + target * target + eps);
|
||||
}
|
||||
|
||||
__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
|
||||
{
|
||||
float denom = (target * target + img * img + eps);
|
||||
d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
|
||||
d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
|
||||
}
|
||||
|
||||
__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
|
||||
{
|
||||
return abs(img - target) / (img + target + eps);
|
||||
}
|
||||
|
||||
__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
|
||||
{
|
||||
float denom = (target + img + eps);
|
||||
d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
|
||||
d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Kernels
|
||||
|
||||
__global__ void imgLossFwdKernel(LossKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
|
||||
float floss = 0.0f;
|
||||
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
|
||||
{
|
||||
vec3f img = p.img.fetch3(px, py, pz);
|
||||
vec3f target = p.target.fetch3(px, py, pz);
|
||||
|
||||
img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
|
||||
target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
|
||||
|
||||
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
||||
{
|
||||
img = fwdTonemapLogSRGB(img);
|
||||
target = fwdTonemapLogSRGB(target);
|
||||
}
|
||||
|
||||
vec3f vloss(0);
|
||||
if (p.loss == LOSS_MSE)
|
||||
vloss = (img - target) * (img - target);
|
||||
else if (p.loss == LOSS_RELMSE)
|
||||
vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
|
||||
else if (p.loss == LOSS_SMAPE)
|
||||
vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
|
||||
else
|
||||
vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
|
||||
|
||||
floss = sum(vloss) / 3.0f;
|
||||
}
|
||||
|
||||
floss = warpSum(floss);
|
||||
|
||||
dim3 warpSize = getWarpSize(blockDim);
|
||||
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
|
||||
p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
|
||||
}
|
||||
|
||||
__global__ void imgLossBwdKernel(LossKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
dim3 warpSize = getWarpSize(blockDim);
|
||||
|
||||
vec3f _img = p.img.fetch3(px, py, pz);
|
||||
vec3f _target = p.target.fetch3(px, py, pz);
|
||||
float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
|
||||
vec3f img = _img, target = _target;
|
||||
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
||||
{
|
||||
img = fwdTonemapLogSRGB(img);
|
||||
target = fwdTonemapLogSRGB(target);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
|
||||
vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
|
||||
|
||||
vec3f d_img(0), d_target(0);
|
||||
if (p.loss == LOSS_MSE)
|
||||
{
|
||||
d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z));
|
||||
d_target = -d_img;
|
||||
}
|
||||
else if (p.loss == LOSS_RELMSE)
|
||||
{
|
||||
bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
||||
bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
||||
bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
||||
}
|
||||
else if (p.loss == LOSS_SMAPE)
|
||||
{
|
||||
bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
||||
bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
||||
bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
||||
}
|
||||
else
|
||||
{
|
||||
d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
|
||||
d_target = -d_img;
|
||||
}
|
||||
|
||||
|
||||
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
||||
{
|
||||
vec3f d__img(0), d__target(0);
|
||||
bwdTonemapLogSRGB(_img, d__img, d_img);
|
||||
bwdTonemapLogSRGB(_target, d__target, d_target);
|
||||
d_img = d__img; d_target = d__target;
|
||||
}
|
||||
|
||||
if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
|
||||
if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
|
||||
if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
|
||||
if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
|
||||
if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
|
||||
if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
|
||||
|
||||
p.img.store_grad(px, py, pz, d_img);
|
||||
p.target.store_grad(px, py, pz, d_target);
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
enum TonemapperType
|
||||
{
|
||||
TONEMAPPER_NONE = 0,
|
||||
TONEMAPPER_LOG_SRGB = 1
|
||||
};
|
||||
|
||||
enum LossType
|
||||
{
|
||||
LOSS_L1 = 0,
|
||||
LOSS_MSE = 1,
|
||||
LOSS_RELMSE = 2,
|
||||
LOSS_SMAPE = 3
|
||||
};
|
||||
|
||||
struct LossKernelParams
|
||||
{
|
||||
Tensor img;
|
||||
Tensor target;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
TonemapperType tonemapper;
|
||||
LossType loss;
|
||||
};
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "mesh.h"
|
||||
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Kernels
|
||||
|
||||
__global__ void xfmPointsFwdKernel(XfmKernelParams p)
|
||||
{
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
||||
|
||||
__shared__ float mtx[4][4];
|
||||
if (threadIdx.x < 16)
|
||||
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
||||
__syncthreads();
|
||||
|
||||
if (px >= p.gridSize.x)
|
||||
return;
|
||||
|
||||
vec3f pos(
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
||||
);
|
||||
|
||||
if (p.isPoints)
|
||||
{
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
|
||||
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void xfmPointsBwdKernel(XfmKernelParams p)
|
||||
{
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
||||
|
||||
__shared__ float mtx[4][4];
|
||||
if (threadIdx.x < 16)
|
||||
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
||||
__syncthreads();
|
||||
|
||||
if (px >= p.gridSize.x)
|
||||
return;
|
||||
|
||||
vec3f pos(
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
||||
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
||||
);
|
||||
|
||||
vec4f d_out(
|
||||
p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
|
||||
p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
|
||||
p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
|
||||
p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
|
||||
);
|
||||
|
||||
if (p.isPoints)
|
||||
{
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
|
||||
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
struct XfmKernelParams
|
||||
{
|
||||
bool isPoints;
|
||||
Tensor points;
|
||||
Tensor matrix;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
};
|
|
@ -0,0 +1,182 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include "normal.h"
|
||||
|
||||
#define NORMAL_THRESHOLD 0.1f
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Perturb shading normal by tangent frame
|
||||
|
||||
__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
|
||||
{
|
||||
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
||||
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
||||
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
||||
return safeNormalize(_shading_nrm);
|
||||
}
|
||||
|
||||
__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
||||
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
||||
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
vec3f d_shading_nrm(0);
|
||||
bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
|
||||
|
||||
vec3f d_smooth_bitng(0);
|
||||
|
||||
if (perturbed_nrm.z > 0.0f)
|
||||
{
|
||||
d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
|
||||
d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
|
||||
}
|
||||
|
||||
d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
|
||||
d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
|
||||
|
||||
d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
|
||||
d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
|
||||
|
||||
vec3f d__smooth_bitng(0);
|
||||
bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
|
||||
|
||||
bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
#define bent_nrm_eps 0.001f
|
||||
|
||||
__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
|
||||
{
|
||||
float dp = dot(view_vec, smooth_nrm);
|
||||
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
||||
return geom_nrm * (1.0f - t) + smooth_nrm * t;
|
||||
}
|
||||
|
||||
__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
|
||||
{
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
float dp = dot(view_vec, smooth_nrm);
|
||||
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
if (dp > NORMAL_THRESHOLD)
|
||||
d_smooth_nrm += d_out;
|
||||
else
|
||||
{
|
||||
// geom_nrm * (1.0f - t) + smooth_nrm * t;
|
||||
d_geom_nrm += d_out * (1.0f - t);
|
||||
d_smooth_nrm += d_out * t;
|
||||
float d_t = sum(d_out * (smooth_nrm - geom_nrm));
|
||||
|
||||
float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
|
||||
|
||||
bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Kernels
|
||||
|
||||
__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
||||
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
||||
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
||||
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
||||
|
||||
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
||||
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
||||
vec3f view_vec = safeNormalize(view_pos - pos);
|
||||
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
||||
|
||||
vec3f res;
|
||||
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
||||
res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
|
||||
else
|
||||
res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
|
||||
|
||||
p.out.store(px, py, pz, res);
|
||||
}
|
||||
|
||||
__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
|
||||
{
|
||||
// Calculate pixel position.
|
||||
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
unsigned int pz = blockIdx.z;
|
||||
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
||||
return;
|
||||
|
||||
vec3f pos = p.pos.fetch3(px, py, pz);
|
||||
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
||||
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
||||
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
||||
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
||||
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
||||
vec3f d_out = p.out.fetch3(px, py, pz);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// FWD
|
||||
|
||||
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
||||
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
||||
vec3f _view_vec = view_pos - pos;
|
||||
vec3f view_vec = safeNormalize(view_pos - pos);
|
||||
|
||||
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// BWD
|
||||
|
||||
vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
|
||||
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
||||
{
|
||||
bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
||||
d_shading_nrm = -d_shading_nrm;
|
||||
d_geom_nrm = -d_geom_nrm;
|
||||
}
|
||||
else
|
||||
bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
||||
|
||||
vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
|
||||
bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);
|
||||
|
||||
vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
|
||||
bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
|
||||
bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
|
||||
bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);
|
||||
|
||||
p.pos.store_grad(px, py, pz, -d__view_vec);
|
||||
p.view_pos.store_grad(px, py, pz, d__view_vec);
|
||||
p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
|
||||
p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
|
||||
p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
|
||||
p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
struct PrepareShadingNormalKernelParams
|
||||
{
|
||||
Tensor pos;
|
||||
Tensor view_pos;
|
||||
Tensor perturbed_nrm;
|
||||
Tensor smooth_nrm;
|
||||
Tensor smooth_tng;
|
||||
Tensor geom_nrm;
|
||||
Tensor out;
|
||||
dim3 gridSize;
|
||||
bool two_sided_shading, opengl;
|
||||
};
|
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#if defined(__CUDACC__) && defined(BFLOAT16)
|
||||
#include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits
|
||||
#endif
|
||||
|
||||
//---------------------------------------------------------------------------------
|
||||
// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
|
||||
|
||||
struct Tensor
|
||||
{
|
||||
void* val;
|
||||
void* d_val;
|
||||
int dims[4], _dims[4];
|
||||
int strides[4];
|
||||
bool fp16;
|
||||
|
||||
#if defined(__CUDA__) && !defined(__CUDA_ARCH__)
|
||||
Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
|
||||
#endif
|
||||
|
||||
#ifdef __CUDACC__
|
||||
// Helpers to index and read/write a single element
|
||||
__device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
|
||||
__device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
|
||||
__device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }
|
||||
#ifdef BFLOAT16
|
||||
__device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
|
||||
__device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
|
||||
__device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
|
||||
#else
|
||||
__device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
|
||||
__device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
|
||||
__device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Fetch, use broadcasting for tensor dimensions of size 1
|
||||
__device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
|
||||
{
|
||||
return fetch(nhwcIndex(z, y, x, 0));
|
||||
}
|
||||
|
||||
__device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
|
||||
{
|
||||
return vec3f(
|
||||
fetch(nhwcIndex(z, y, x, 0)),
|
||||
fetch(nhwcIndex(z, y, x, 1)),
|
||||
fetch(nhwcIndex(z, y, x, 2))
|
||||
);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
||||
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
|
||||
{
|
||||
store(_nhwcIndex(z, y, x, 0), _val);
|
||||
}
|
||||
|
||||
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
||||
{
|
||||
store(_nhwcIndex(z, y, x, 0), _val.x);
|
||||
store(_nhwcIndex(z, y, x, 1), _val.y);
|
||||
store(_nhwcIndex(z, y, x, 2), _val.z);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
||||
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
|
||||
{
|
||||
store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
|
||||
}
|
||||
|
||||
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
||||
{
|
||||
store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
|
||||
store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
|
||||
store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
|
||||
}
|
||||
#endif
|
||||
|
||||
};
|
Plik diff jest za duży
Load Diff
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
struct vec3f
|
||||
{
|
||||
float x, y, z;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ vec3f() { }
|
||||
__device__ vec3f(float v) { x = v; y = v; z = v; }
|
||||
__device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
|
||||
__device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
|
||||
|
||||
__device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
|
||||
__device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
|
||||
__device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
|
||||
__device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
|
||||
__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
|
||||
__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
|
||||
__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
|
||||
__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
|
||||
|
||||
__device__ static inline float sum(vec3f a)
|
||||
{
|
||||
return a.x + a.y + a.z;
|
||||
}
|
||||
|
||||
__device__ static inline vec3f cross(vec3f a, vec3f b)
|
||||
{
|
||||
vec3f out;
|
||||
out.x = a.y * b.z - a.z * b.y;
|
||||
out.y = a.z * b.x - a.x * b.z;
|
||||
out.z = a.x * b.y - a.y * b.x;
|
||||
return out;
|
||||
}
|
||||
|
||||
__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
|
||||
{
|
||||
d_a.x += d_out.z * b.y - d_out.y * b.z;
|
||||
d_a.y += d_out.x * b.z - d_out.z * b.x;
|
||||
d_a.z += d_out.y * b.x - d_out.x * b.y;
|
||||
|
||||
d_b.x += d_out.y * a.z - d_out.z * a.y;
|
||||
d_b.y += d_out.z * a.x - d_out.x * a.z;
|
||||
d_b.z += d_out.x * a.y - d_out.y * a.x;
|
||||
}
|
||||
|
||||
__device__ static inline float dot(vec3f a, vec3f b)
|
||||
{
|
||||
return a.x * b.x + a.y * b.y + a.z * b.z;
|
||||
}
|
||||
|
||||
__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
|
||||
{
|
||||
d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
|
||||
d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
|
||||
}
|
||||
|
||||
__device__ static inline vec3f reflect(vec3f x, vec3f n)
|
||||
{
|
||||
return n * 2.0f * dot(n, x) - x;
|
||||
}
|
||||
|
||||
__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
|
||||
{
|
||||
d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
|
||||
d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
|
||||
d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
|
||||
|
||||
d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
|
||||
d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
|
||||
d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
|
||||
}
|
||||
|
||||
__device__ static inline vec3f safeNormalize(vec3f v)
|
||||
{
|
||||
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
||||
return l > 0.0f ? (v / l) : vec3f(0.0f);
|
||||
}
|
||||
|
||||
__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
|
||||
{
|
||||
|
||||
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
||||
if (l > 0.0f)
|
||||
{
|
||||
float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
|
||||
d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
|
||||
d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
|
||||
d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
struct vec4f
|
||||
{
|
||||
float x, y, z, w;
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ vec4f() { }
|
||||
__device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
|
||||
__device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
|
||||
__device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }
|
||||
#endif
|
||||
};
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# HDR image losses
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _tonemap_srgb(f):
|
||||
return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
|
||||
|
||||
def _SMAPE(img, target, eps=0.01):
|
||||
nom = torch.abs(img - target)
|
||||
denom = torch.abs(img) + torch.abs(target) + 0.01
|
||||
return torch.mean(nom / denom)
|
||||
|
||||
def _RELMSE(img, target, eps=0.1):
|
||||
nom = (img - target) * (img - target)
|
||||
denom = img * img + target * target + 0.1
|
||||
return torch.mean(nom / denom)
|
||||
|
||||
def image_loss_fn(img, target, loss, tonemapper):
|
||||
if tonemapper == 'log_srgb':
|
||||
img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
|
||||
target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))
|
||||
|
||||
if loss == 'mse':
|
||||
return torch.nn.functional.mse_loss(img, target)
|
||||
elif loss == 'smape':
|
||||
return _SMAPE(img, target)
|
||||
elif loss == 'relmse':
|
||||
return _RELMSE(img, target)
|
||||
else:
|
||||
return torch.nn.functional.l1_loss(img, target)
|
|
@ -0,0 +1,556 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
|
||||
from .bsdf import *
|
||||
from .loss import *
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# C++/Cuda plugin compiler/loader.
|
||||
|
||||
_cached_plugin = None
|
||||
def _get_plugin():
|
||||
# Return cached plugin if already loaded.
|
||||
global _cached_plugin
|
||||
if _cached_plugin is not None:
|
||||
return _cached_plugin
|
||||
|
||||
# Make sure we can find the necessary compiler and libary binaries.
|
||||
if os.name == 'nt':
|
||||
def find_cl_path():
|
||||
import glob
|
||||
for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
|
||||
paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
|
||||
if paths:
|
||||
return paths[0]
|
||||
|
||||
# If cl.exe is not on path, try to find it.
|
||||
if os.system("where cl.exe >nul 2>nul") != 0:
|
||||
cl_path = find_cl_path()
|
||||
if cl_path is None:
|
||||
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
||||
os.environ['PATH'] += ';' + cl_path
|
||||
|
||||
# Compiler options.
|
||||
opts = ['-DNVDR_TORCH']
|
||||
|
||||
# Linker options.
|
||||
if os.name == 'posix':
|
||||
ldflags = ['-lcuda', '-lnvrtc']
|
||||
elif os.name == 'nt':
|
||||
ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
|
||||
|
||||
# List of sources.
|
||||
source_files = [
|
||||
'c_src/mesh.cu',
|
||||
'c_src/loss.cu',
|
||||
'c_src/bsdf.cu',
|
||||
'c_src/normal.cu',
|
||||
'c_src/cubemap.cu',
|
||||
'c_src/common.cpp',
|
||||
'c_src/torch_bindings.cpp'
|
||||
]
|
||||
|
||||
# Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
||||
|
||||
# Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
|
||||
try:
|
||||
lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock')
|
||||
if os.path.exists(lock_fn):
|
||||
print("Warning: Lock file exists in build directory: '%s'" % lock_fn)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Compile and load.
|
||||
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
|
||||
torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts,
|
||||
extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True,
|
||||
# build_directory="PLACEHOLDER",
|
||||
)
|
||||
|
||||
# Import, cache, and return the compiled module.
|
||||
import renderutils_plugin
|
||||
_cached_plugin = renderutils_plugin
|
||||
return _cached_plugin
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Internal kernels, just used for testing functionality
|
||||
|
||||
class _fresnel_shlick_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, f0, f90, cosTheta):
|
||||
out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False)
|
||||
ctx.save_for_backward(f0, f90, cosTheta)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
f0, f90, cosTheta = ctx.saved_variables
|
||||
return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)
|
||||
|
||||
def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
|
||||
if use_python:
|
||||
out = bsdf_fresnel_shlick(f0, f90, cosTheta)
|
||||
else:
|
||||
out = _fresnel_shlick_func.apply(f0, f90, cosTheta)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN"
|
||||
return out
|
||||
|
||||
|
||||
class _ndf_ggx_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, alphaSqr, cosTheta):
|
||||
out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False)
|
||||
ctx.save_for_backward(alphaSqr, cosTheta)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
alphaSqr, cosTheta = ctx.saved_variables
|
||||
return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
||||
|
||||
def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
|
||||
if use_python:
|
||||
out = bsdf_ndf_ggx(alphaSqr, cosTheta)
|
||||
else:
|
||||
out = _ndf_ggx_func.apply(alphaSqr, cosTheta)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _lambda_ggx_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, alphaSqr, cosTheta):
|
||||
out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False)
|
||||
ctx.save_for_backward(alphaSqr, cosTheta)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
alphaSqr, cosTheta = ctx.saved_variables
|
||||
return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
||||
|
||||
def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
|
||||
if use_python:
|
||||
out = bsdf_lambda_ggx(alphaSqr, cosTheta)
|
||||
else:
|
||||
out = _lambda_ggx_func.apply(alphaSqr, cosTheta)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _masking_smith_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
|
||||
ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)
|
||||
out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables
|
||||
return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)
|
||||
|
||||
def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
|
||||
if use_python:
|
||||
out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
|
||||
else:
|
||||
out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
|
||||
return out
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Shading normal setup (bump mapping + bent normals)
|
||||
|
||||
class _prepare_shading_normal_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
||||
ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
|
||||
out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
|
||||
ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables
|
||||
return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
|
||||
|
||||
def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
|
||||
'''Takes care of all corner cases and produces a final normal used for shading:
|
||||
- Constructs tangent space
|
||||
- Flips normal direction based on geometric normal for two sided Shading
|
||||
- Perturbs shading normal by normal map
|
||||
- Bends backfacing normals towards the camera to avoid shading artifacts
|
||||
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
||||
|
||||
Args:
|
||||
pos: World space g-buffer position.
|
||||
view_pos: Camera position in world space (typically using broadcasting).
|
||||
perturbed_nrm: Trangent-space normal perturbation from normal map lookup.
|
||||
smooth_nrm: Interpolated vertex normals.
|
||||
smooth_tng: Interpolated vertex tangents.
|
||||
geom_nrm: Geometric (face) normals.
|
||||
two_sided_shading: Use one/two sided shading
|
||||
opengl: Use OpenGL/DirectX normal map conventions
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
Returns:
|
||||
Final shading normal
|
||||
'''
|
||||
|
||||
if perturbed_nrm is None:
|
||||
perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]
|
||||
|
||||
if use_python:
|
||||
out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
||||
else:
|
||||
out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
|
||||
return out
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# BSDF functions
|
||||
|
||||
class _lambert_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, nrm, wi):
|
||||
out = _get_plugin().lambert_fwd(nrm, wi, False)
|
||||
ctx.save_for_backward(nrm, wi)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
nrm, wi = ctx.saved_variables
|
||||
return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)
|
||||
|
||||
def lambert(nrm, wi, use_python=False):
|
||||
'''Lambertian bsdf.
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
||||
|
||||
Args:
|
||||
nrm: World space shading normal.
|
||||
wi: World space light vector.
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
|
||||
Returns:
|
||||
Shaded diffuse value with shape [minibatch_size, height, width, 1]
|
||||
'''
|
||||
|
||||
if use_python:
|
||||
out = bsdf_lambert(nrm, wi)
|
||||
else:
|
||||
out = _lambert_func.apply(nrm, wi)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _frostbite_diffuse_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, nrm, wi, wo, linearRoughness):
|
||||
out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)
|
||||
ctx.save_for_backward(nrm, wi, wo, linearRoughness)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
nrm, wi, wo, linearRoughness = ctx.saved_variables
|
||||
return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)
|
||||
|
||||
def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
|
||||
'''Frostbite, normalized Disney Diffuse bsdf.
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
||||
|
||||
Args:
|
||||
nrm: World space shading normal.
|
||||
wi: World space light vector.
|
||||
wo: World space camera vector.
|
||||
linearRoughness: Material roughness
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
|
||||
Returns:
|
||||
Shaded diffuse value with shape [minibatch_size, height, width, 1]
|
||||
'''
|
||||
|
||||
if use_python:
|
||||
out = bsdf_frostbite(nrm, wi, wo, linearRoughness)
|
||||
else:
|
||||
out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _pbr_specular_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
|
||||
ctx.save_for_backward(col, nrm, wo, wi, alpha)
|
||||
ctx.min_roughness = min_roughness
|
||||
out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
col, nrm, wo, wi, alpha = ctx.saved_variables
|
||||
return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
|
||||
|
||||
def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
|
||||
'''Physically-based specular bsdf.
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
||||
|
||||
Args:
|
||||
col: Specular lobe color
|
||||
nrm: World space shading normal.
|
||||
wo: World space camera vector.
|
||||
wi: World space light vector
|
||||
alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]
|
||||
min_roughness: Scalar roughness clamping threshold
|
||||
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
Returns:
|
||||
Shaded specular color
|
||||
'''
|
||||
|
||||
if use_python:
|
||||
out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness)
|
||||
else:
|
||||
out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _pbr_bsdf_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
|
||||
ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)
|
||||
ctx.min_roughness = min_roughness
|
||||
ctx.BSDF = BSDF
|
||||
out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables
|
||||
return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None)
|
||||
|
||||
def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False):
|
||||
'''Physically-based bsdf, both diffuse & specular lobes
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
||||
|
||||
Args:
|
||||
kd: Diffuse albedo.
|
||||
arm: Specular parameters (attenuation, linear roughness, metalness).
|
||||
pos: World space position.
|
||||
nrm: World space shading normal.
|
||||
view_pos: Camera position in world space, typically using broadcasting.
|
||||
light_pos: Light position in world space, typically using broadcasting.
|
||||
min_roughness: Scalar roughness clamping threshold
|
||||
bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite'
|
||||
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
|
||||
Returns:
|
||||
Shaded color.
|
||||
'''
|
||||
|
||||
BSDF = 0
|
||||
if bsdf == 'frostbite':
|
||||
BSDF = 1
|
||||
|
||||
if use_python:
|
||||
out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
|
||||
else:
|
||||
out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN"
|
||||
return out
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# cubemap filter with filtering across edges
|
||||
|
||||
class _diffuse_cubemap_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, cubemap):
|
||||
out = _get_plugin().diffuse_cubemap_fwd(cubemap)
|
||||
ctx.save_for_backward(cubemap)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
cubemap, = ctx.saved_variables
|
||||
cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout)
|
||||
return cubemap_grad, None
|
||||
|
||||
def diffuse_cubemap(cubemap, use_python=False):
|
||||
if use_python:
|
||||
assert False
|
||||
else:
|
||||
out = _diffuse_cubemap_func.apply(cubemap)
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN"
|
||||
return out
|
||||
|
||||
class _specular_cubemap(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
|
||||
out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff)
|
||||
ctx.save_for_backward(cubemap, bounds)
|
||||
ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
cubemap, bounds = ctx.saved_variables
|
||||
cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff)
|
||||
return cubemap_grad, None, None, None
|
||||
|
||||
# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy
|
||||
def __ndfBounds(res, roughness, cutoff):
|
||||
def ndfGGX(alphaSqr, costheta):
|
||||
costheta = np.clip(costheta, 0.0, 1.0)
|
||||
d = (costheta * alphaSqr - costheta) * costheta + 1.0
|
||||
return alphaSqr / (d * d * np.pi)
|
||||
|
||||
# Sample out cutoff angle
|
||||
nSamples = 1000000
|
||||
costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples))
|
||||
D = np.cumsum(ndfGGX(roughness**4, costheta))
|
||||
idx = np.argmax(D >= D[..., -1] * cutoff)
|
||||
|
||||
# Brute force compute lookup table with bounds
|
||||
bounds = _get_plugin().specular_bounds(res, costheta[idx])
|
||||
|
||||
return costheta[idx], bounds
|
||||
__ndfBoundsDict = {}
|
||||
|
||||
def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
|
||||
assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape)
|
||||
|
||||
if use_python:
|
||||
assert False
|
||||
else:
|
||||
key = (cubemap.shape[1], roughness, cutoff)
|
||||
if key not in __ndfBoundsDict:
|
||||
__ndfBoundsDict[key] = __ndfBounds(*key)
|
||||
out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key])
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN"
|
||||
return out[..., 0:3] / out[..., 3:]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Fast image loss function
|
||||
|
||||
class _image_loss_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, img, target, loss, tonemapper):
|
||||
ctx.loss, ctx.tonemapper = loss, tonemapper
|
||||
ctx.save_for_backward(img, target)
|
||||
out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
img, target = ctx.saved_variables
|
||||
return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)
|
||||
|
||||
def image_loss(img, target, loss='l1', tonemapper='none', use_python=False):
|
||||
'''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.
|
||||
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
||||
|
||||
Args:
|
||||
img: Input image.
|
||||
target: Target (reference) image.
|
||||
loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']
|
||||
tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb']
|
||||
use_python: Use PyTorch implementation (for validation)
|
||||
|
||||
Returns:
|
||||
Image space loss (scalar value).
|
||||
'''
|
||||
if use_python:
|
||||
out = image_loss_fn(img, target, loss, tonemapper)
|
||||
else:
|
||||
out = _image_loss_func.apply(img, target, loss, tonemapper)
|
||||
out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN"
|
||||
return out
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Transform points function
|
||||
|
||||
class _xfm_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, points, matrix, isPoints):
|
||||
ctx.save_for_backward(points, matrix)
|
||||
ctx.isPoints = isPoints
|
||||
return _get_plugin().xfm_fwd(points, matrix, isPoints, False)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
points, matrix = ctx.saved_variables
|
||||
return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)
|
||||
|
||||
def xfm_points(points, matrix, use_python=False):
|
||||
'''Transform points.
|
||||
Args:
|
||||
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
||||
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
||||
use_python: Use PyTorch's torch.matmul (for validation)
|
||||
Returns:
|
||||
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
||||
'''
|
||||
if use_python:
|
||||
out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
||||
else:
|
||||
out = _xfm_func.apply(points, matrix, True)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
||||
return out
|
||||
|
||||
def xfm_vectors(vectors, matrix, use_python=False):
|
||||
'''Transform vectors.
|
||||
Args:
|
||||
vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
||||
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
||||
use_python: Use PyTorch's torch.matmul (for validation)
|
||||
|
||||
Returns:
|
||||
Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
||||
'''
|
||||
|
||||
if use_python:
|
||||
out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()
|
||||
else:
|
||||
out = _xfm_func.apply(vectors, matrix, False)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN"
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
from setuptools import setup
|
||||
import torch
|
||||
import os,glob
|
||||
from torch.utils.cpp_extension import (CUDAExtension, CppExtension, BuildExtension)
|
||||
|
||||
def get_extensions():
|
||||
extensions = []
|
||||
ext_name = 'nvdiffrec_renderutils'
|
||||
# prevent ninja from using too many resources
|
||||
os.environ.setdefault('MAX_JOBS', '16')
|
||||
define_macros = []
|
||||
|
||||
# Compiler options.
|
||||
opts = ['-DNVDR_TORCH']
|
||||
|
||||
# Linker options.
|
||||
if os.name == 'posix':
|
||||
ldflags = ['-lcuda', '-lnvrtc']
|
||||
elif os.name == 'nt':
|
||||
ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
|
||||
|
||||
# List of sources.
|
||||
source_files = [
|
||||
'c_src/mesh.cu',
|
||||
'c_src/loss.cu',
|
||||
'c_src/bsdf.cu',
|
||||
'c_src/normal.cu',
|
||||
'c_src/cubemap.cu',
|
||||
'c_src/common.cpp',
|
||||
'c_src/torch_bindings.cpp'
|
||||
]
|
||||
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = "5.0 6.0 6.1 7.0 7.5 8.0 8.6"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
print(f'Compiling {ext_name} with CUDA')
|
||||
define_macros += [('WITH_CUDA', None)]
|
||||
# op_files = glob.glob('./c_src/*')
|
||||
# extension = CUDAExtension
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
include_path = os.path.abspath('./c_src')
|
||||
ext_ops = CUDAExtension(
|
||||
name=ext_name,
|
||||
sources=source_files,
|
||||
include_dirs=[include_path],
|
||||
define_macros=define_macros,
|
||||
extra_compile_args=opts + ldflags,
|
||||
libraries=['cuda', 'nvrtc'],
|
||||
extra_cuda_cflags=opts,
|
||||
extra_cflags=opts,
|
||||
extra_ldflags=ldflags)
|
||||
extensions.append(ext_ops)
|
||||
return extensions
|
||||
|
||||
setup(
|
||||
name='nvdiffrec_renderutils',
|
||||
ext_modules=get_extensions(),
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
)
|
|
@ -0,0 +1,296 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
||||
import renderutils as ru
|
||||
|
||||
RES = 4
|
||||
DTYPE = torch.float32
|
||||
|
||||
def relative_loss(name, ref, cuda):
|
||||
ref = ref.float()
|
||||
cuda = cuda.float()
|
||||
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
|
||||
|
||||
def test_normal():
|
||||
pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
pos_ref = pos_cuda.clone().detach().requires_grad_(True)
|
||||
view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)
|
||||
perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)
|
||||
smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)
|
||||
smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)
|
||||
geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" bent normal")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
|
||||
relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad)
|
||||
relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)
|
||||
relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)
|
||||
relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad)
|
||||
relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad)
|
||||
|
||||
def test_schlick():
|
||||
f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
f0_ref = f0_cuda.clone().detach().requires_grad_(True)
|
||||
f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
f90_ref = f90_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0
|
||||
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Fresnel shlick")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("f0:", f0_ref.grad, f0_cuda.grad)
|
||||
relative_loss("f90:", f90_ref.grad, f90_cuda.grad)
|
||||
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
||||
|
||||
def test_ndf_ggx():
|
||||
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
||||
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
|
||||
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Ndf GGX")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
||||
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
||||
|
||||
def test_lambda_ggx():
|
||||
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
|
||||
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Lambda GGX")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
||||
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
||||
|
||||
def test_masking_smith():
|
||||
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
||||
cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)
|
||||
cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Smith masking term")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
||||
relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad)
|
||||
relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad)
|
||||
|
||||
def test_lambert():
|
||||
normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
normals_ref = normals_cuda.clone().detach().requires_grad_(True)
|
||||
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.lambert(normals_ref, wi_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.lambert(normals_cuda, wi_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Lambert")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
|
||||
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
||||
|
||||
def test_frostbite():
|
||||
normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
normals_ref = normals_cuda.clone().detach().requires_grad_(True)
|
||||
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
||||
wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
wo_ref = wo_cuda.clone().detach().requires_grad_(True)
|
||||
rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
rough_ref = rough_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Frostbite")
|
||||
print("-------------------------------------------------------------")
|
||||
relative_loss("res:", ref, cuda)
|
||||
relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
|
||||
relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
|
||||
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
||||
relative_loss("rough:", rough_ref.grad, rough_cuda.grad)
|
||||
|
||||
def test_pbr_specular():
|
||||
col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
col_ref = col_cuda.clone().detach().requires_grad_(True)
|
||||
nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
|
||||
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
||||
wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
wo_ref = wo_cuda.clone().detach().requires_grad_(True)
|
||||
alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Pbr specular")
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("res:", ref, cuda)
|
||||
if col_ref.grad is not None:
|
||||
relative_loss("col:", col_ref.grad, col_cuda.grad)
|
||||
if nrm_ref.grad is not None:
|
||||
relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
|
||||
if wi_ref.grad is not None:
|
||||
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
||||
if wo_ref.grad is not None:
|
||||
relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
|
||||
if alpha_ref.grad is not None:
|
||||
relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)
|
||||
|
||||
def test_pbr_bsdf(bsdf):
|
||||
kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
kd_ref = kd_cuda.clone().detach().requires_grad_(True)
|
||||
arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
arm_ref = arm_cuda.clone().detach().requires_grad_(True)
|
||||
pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
pos_ref = pos_cuda.clone().detach().requires_grad_(True)
|
||||
nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
|
||||
view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
view_ref = view_cuda.clone().detach().requires_grad_(True)
|
||||
light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
light_ref = light_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Pbr BSDF")
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("res:", ref, cuda)
|
||||
if kd_ref.grad is not None:
|
||||
relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
|
||||
if arm_ref.grad is not None:
|
||||
relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
|
||||
if pos_ref.grad is not None:
|
||||
relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
|
||||
if nrm_ref.grad is not None:
|
||||
relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
|
||||
if view_ref.grad is not None:
|
||||
relative_loss("view:", view_ref.grad, view_cuda.grad)
|
||||
if light_ref.grad is not None:
|
||||
relative_loss("light:", light_ref.grad, light_cuda.grad)
|
||||
|
||||
test_normal()
|
||||
|
||||
test_schlick()
|
||||
test_ndf_ggx()
|
||||
test_lambda_ggx()
|
||||
test_masking_smith()
|
||||
|
||||
test_lambert()
|
||||
test_frostbite()
|
||||
test_pbr_specular()
|
||||
test_pbr_bsdf('lambert')
|
||||
test_pbr_bsdf('frostbite')
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
||||
import renderutils as ru
|
||||
|
||||
RES = 4
|
||||
DTYPE = torch.float32
|
||||
|
||||
def relative_loss(name, ref, cuda):
|
||||
ref = ref.float()
|
||||
cuda = cuda.float()
|
||||
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
|
||||
|
||||
def test_cubemap():
|
||||
cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True)
|
||||
weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda')
|
||||
target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda')
|
||||
|
||||
ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Cubemap:")
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("flt:", ref, cuda)
|
||||
relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad)
|
||||
|
||||
|
||||
test_cubemap()
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
||||
import renderutils as ru
|
||||
|
||||
RES = 8
|
||||
DTYPE = torch.float32
|
||||
|
||||
def tonemap_srgb(f):
|
||||
return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
|
||||
|
||||
def l1(output, target):
|
||||
x = torch.clamp(output, min=0, max=65535)
|
||||
r = torch.clamp(target, min=0, max=65535)
|
||||
x = tonemap_srgb(torch.log(x + 1))
|
||||
r = tonemap_srgb(torch.log(r + 1))
|
||||
return torch.nn.functional.l1_loss(x,r)
|
||||
|
||||
def relative_loss(name, ref, cuda):
|
||||
ref = ref.float()
|
||||
cuda = cuda.float()
|
||||
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
|
||||
|
||||
def test_loss(loss, tonemapper):
|
||||
img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
img_ref = img_cuda.clone().detach().requires_grad_(True)
|
||||
target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
target_ref = target_cuda.clone().detach().requires_grad_(True)
|
||||
|
||||
ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
print(" Loss: %s, %s" % (loss, tonemapper))
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("res:", ref_loss, cuda_loss)
|
||||
relative_loss("img:", img_ref.grad, img_cuda.grad)
|
||||
relative_loss("target:", target_ref.grad, target_cuda.grad)
|
||||
|
||||
|
||||
test_loss('l1', 'none')
|
||||
test_loss('l1', 'log_srgb')
|
||||
test_loss('mse', 'log_srgb')
|
||||
test_loss('smape', 'none')
|
||||
test_loss('relmse', 'none')
|
||||
test_loss('mse', 'none')
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
||||
import renderutils as ru
|
||||
|
||||
BATCH = 8
|
||||
RES = 1024
|
||||
DTYPE = torch.float32
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
def tonemap_srgb(f):
|
||||
return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
|
||||
|
||||
def l1(output, target):
|
||||
x = torch.clamp(output, min=0, max=65535)
|
||||
r = torch.clamp(target, min=0, max=65535)
|
||||
x = tonemap_srgb(torch.log(x + 1))
|
||||
r = tonemap_srgb(torch.log(r + 1))
|
||||
return torch.nn.functional.l1_loss(x,r)
|
||||
|
||||
def relative_loss(name, ref, cuda):
|
||||
ref = ref.float()
|
||||
cuda = cuda.float()
|
||||
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item())
|
||||
|
||||
def test_xfm_points():
|
||||
points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
points_ref = points_cuda.clone().detach().requires_grad_(True)
|
||||
mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
|
||||
mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
|
||||
ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref_out, target)
|
||||
ref_loss.backward()
|
||||
|
||||
cuda_out = ru.xfm_points(points_cuda, mtx_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda_out, target)
|
||||
cuda_loss.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("res:", ref_out, cuda_out)
|
||||
relative_loss("points:", points_ref.grad, points_cuda.grad)
|
||||
|
||||
def test_xfm_vectors():
|
||||
points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
points_ref = points_cuda.clone().detach().requires_grad_(True)
|
||||
points_cuda_p = points_cuda.clone().detach().requires_grad_(True)
|
||||
points_ref_p = points_cuda.clone().detach().requires_grad_(True)
|
||||
mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
|
||||
mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
|
||||
ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True)
|
||||
ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3])
|
||||
ref_loss.backward()
|
||||
|
||||
cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda)
|
||||
cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3])
|
||||
cuda_loss.backward()
|
||||
|
||||
ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True)
|
||||
ref_loss_p = torch.nn.MSELoss()(ref_out_p, target)
|
||||
ref_loss_p.backward()
|
||||
|
||||
cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda)
|
||||
cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target)
|
||||
cuda_loss_p.backward()
|
||||
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
relative_loss("res:", ref_out, cuda_out)
|
||||
relative_loss("points:", points_ref.grad, points_cuda.grad)
|
||||
relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad)
|
||||
|
||||
test_xfm_points()
|
||||
test_xfm_vectors()
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import torch
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
||||
import renderutils as ru
|
||||
|
||||
DTYPE=torch.float32
|
||||
|
||||
def test_bsdf(BATCH, RES, ITR):
|
||||
kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
kd_ref = kd_cuda.clone().detach().requires_grad_(True)
|
||||
arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
arm_ref = arm_cuda.clone().detach().requires_grad_(True)
|
||||
pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
pos_ref = pos_cuda.clone().detach().requires_grad_(True)
|
||||
nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
|
||||
view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
view_ref = view_cuda.clone().detach().requires_grad_(True)
|
||||
light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
||||
light_ref = light_cuda.clone().detach().requires_grad_(True)
|
||||
target = torch.rand(BATCH, RES, RES, 3, device='cuda')
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
|
||||
|
||||
print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES))
|
||||
|
||||
start.record()
|
||||
for i in range(ITR):
|
||||
ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
print("Pbr BSDF python:", start.elapsed_time(end))
|
||||
|
||||
start.record()
|
||||
for i in range(ITR):
|
||||
cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
print("Pbr BSDF cuda:", start.elapsed_time(end))
|
||||
|
||||
test_bsdf(1, 512, 1000)
|
||||
test_bsdf(16, 512, 1000)
|
||||
test_bsdf(1, 2048, 1000)
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
|
||||
from . import util
|
||||
|
||||
######################################################################################
|
||||
# Smooth pooling / mip computation with linear gradient upscaling
|
||||
######################################################################################
|
||||
|
||||
class texture2d_mip(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, texture):
|
||||
return util.avg_pool_nhwc(texture, (2,2))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"),
|
||||
torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"),
|
||||
indexing='ij')
|
||||
uv = torch.stack((gx, gy), dim=-1)
|
||||
return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')
|
||||
|
||||
########################################################################################################
|
||||
# Simple texture class. A texture can be either
|
||||
# - A 3D tensor (using auto mipmaps)
|
||||
# - A list of 3D tensors (full custom mip hierarchy)
|
||||
########################################################################################################
|
||||
|
||||
class Texture2D(torch.nn.Module):
|
||||
# Initializes a texture from image data.
|
||||
# Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
|
||||
def __init__(self, init, min_max=None, trainable=True):
|
||||
super(Texture2D, self).__init__()
|
||||
|
||||
if isinstance(init, np.ndarray):
|
||||
init = torch.tensor(init, dtype=torch.float32, device='cuda')
|
||||
elif isinstance(init, list) and len(init) == 1:
|
||||
init = init[0]
|
||||
|
||||
if isinstance(init, list):
|
||||
self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=trainable) for mip in init)
|
||||
elif len(init.shape) == 4:
|
||||
self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=trainable)
|
||||
elif len(init.shape) == 3:
|
||||
self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=trainable)
|
||||
elif len(init.shape) == 1:
|
||||
self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=trainable) # Convert constant to 1x1 tensor
|
||||
else:
|
||||
assert False, "Invalid texture object"
|
||||
|
||||
self.min_max = min_max
|
||||
|
||||
# Filtered (trilinear) sample texture at a given location
|
||||
def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
|
||||
if isinstance(self.data, list):
|
||||
out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
|
||||
else:
|
||||
if self.data.shape[1] > 1 and self.data.shape[2] > 1:
|
||||
mips = [self.data]
|
||||
while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
|
||||
mips += [texture2d_mip.apply(mips[-1])]
|
||||
out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)
|
||||
else:
|
||||
out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
|
||||
return out
|
||||
|
||||
def getRes(self):
|
||||
return self.getMips()[0].shape[1:3]
|
||||
|
||||
def getChannels(self):
|
||||
return self.getMips()[0].shape[3]
|
||||
|
||||
def getMips(self):
|
||||
if isinstance(self.data, list):
|
||||
return self.data
|
||||
else:
|
||||
return [self.data]
|
||||
|
||||
# In-place clamp with no derivative to make sure values are in valid range after training
|
||||
def clamp_(self):
|
||||
if self.min_max is not None:
|
||||
for mip in self.getMips():
|
||||
for i in range(mip.shape[-1]):
|
||||
mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])
|
||||
|
||||
# In-place clamp with no derivative to make sure values are in valid range after training
|
||||
def normalize_(self):
|
||||
with torch.no_grad():
|
||||
for mip in self.getMips():
|
||||
mip = util.safe_normalize(mip)
|
||||
|
||||
########################################################################################################
|
||||
# Helper function to create a trainable texture from a regular texture. The trainable weights are
|
||||
# initialized with texture data as an initial guess
|
||||
########################################################################################################
|
||||
|
||||
def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
|
||||
with torch.no_grad():
|
||||
if isinstance(init, Texture2D):
|
||||
assert isinstance(init.data, torch.Tensor)
|
||||
min_max = init.min_max if min_max is None else min_max
|
||||
init = init.data
|
||||
elif isinstance(init, np.ndarray):
|
||||
init = torch.tensor(init, dtype=torch.float32, device='cuda')
|
||||
|
||||
# Pad to NHWC if needed
|
||||
if len(init.shape) == 1: # Extend constant to NHWC tensor
|
||||
init = init[None, None, None, :]
|
||||
elif len(init.shape) == 3:
|
||||
init = init[None, ...]
|
||||
|
||||
# Scale input to desired resolution.
|
||||
if res is not None:
|
||||
init = util.scale_img_nhwc(init, res)
|
||||
|
||||
# Genreate custom mipchain
|
||||
if not auto_mipmaps:
|
||||
mip_chain = [init.clone().detach().requires_grad_(True)]
|
||||
while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:
|
||||
new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)]
|
||||
mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)]
|
||||
return Texture2D(mip_chain, min_max=min_max)
|
||||
else:
|
||||
return Texture2D(init, min_max=min_max)
|
||||
|
||||
########################################################################################################
|
||||
# Convert texture to and from SRGB
|
||||
########################################################################################################
|
||||
|
||||
def srgb_to_rgb(texture):
|
||||
return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips()))
|
||||
|
||||
def rgb_to_srgb(texture):
|
||||
return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips()))
|
||||
|
||||
########################################################################################################
|
||||
# Utility functions for loading / storing a texture
|
||||
########################################################################################################
|
||||
|
||||
def _load_mip2D(fn, lambda_fn=None, channels=None):
|
||||
imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')
|
||||
if channels is not None:
|
||||
imgdata = imgdata[..., 0:channels]
|
||||
if lambda_fn is not None:
|
||||
imgdata = lambda_fn(imgdata)
|
||||
return imgdata.detach().clone()
|
||||
|
||||
def load_texture2D(fn, lambda_fn=None, channels=None):
|
||||
base, ext = os.path.splitext(fn)
|
||||
# if os.path.exists(base + "_0" + ext):
|
||||
# mips = []
|
||||
# while os.path.exists(base + ("_%d" % len(mips)) + ext):
|
||||
# mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)]
|
||||
# return Texture2D(mips)
|
||||
# else:
|
||||
# return Texture2D(_load_mip2D(fn, lambda_fn, channels))
|
||||
return Texture2D(_load_mip2D(fn, lambda_fn, channels))
|
||||
|
||||
def _save_mip2D(fn, mip, mipidx, lambda_fn):
|
||||
if lambda_fn is not None:
|
||||
data = lambda_fn(mip).detach().cpu().numpy()
|
||||
else:
|
||||
data = mip.detach().cpu().numpy()
|
||||
|
||||
if mipidx is None:
|
||||
util.save_image(fn, data)
|
||||
else:
|
||||
base, ext = os.path.splitext(fn)
|
||||
util.save_image(base + ("_%d" % mipidx) + ext, data)
|
||||
|
||||
def save_texture2D(fn, tex, lambda_fn=None):
|
||||
if isinstance(tex.data, list):
|
||||
for i, mip in enumerate(tex.data):
|
||||
_save_mip2D(fn, mip[0,...], i, lambda_fn)
|
||||
else:
|
||||
_save_mip2D(fn, tex.data[0,...], None, lambda_fn)
|
|
@ -0,0 +1,482 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import imageio
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Vector operations
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sum(x*y, -1, keepdim=True)
|
||||
|
||||
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
|
||||
return 2*dot(x, n)*n - x
|
||||
|
||||
def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
|
||||
return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
|
||||
# print(dot(x,x).min())
|
||||
# if torch.isnan(dot(x,x).min()):
|
||||
# raise
|
||||
return torch.sqrt(dot(x,x) + eps) + eps # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
|
||||
|
||||
def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
|
||||
# def safe_normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
|
||||
return x / length(x, eps)
|
||||
|
||||
def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
|
||||
return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# sRGB color transforms
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055)
|
||||
|
||||
def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
|
||||
assert f.shape[-1] == 3 or f.shape[-1] == 4
|
||||
out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f)
|
||||
assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
|
||||
return out
|
||||
|
||||
def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4))
|
||||
|
||||
def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
|
||||
assert f.shape[-1] == 3 or f.shape[-1] == 4
|
||||
out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f)
|
||||
assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
|
||||
return out
|
||||
|
||||
def reinhard(f: torch.Tensor) -> torch.Tensor:
|
||||
return f/(1+f)
|
||||
|
||||
#-----------------------------------------------------------------------------------
|
||||
# Metrics (taken from jaxNerf source code, in order to replicate their measurements)
|
||||
#
|
||||
# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266
|
||||
#
|
||||
#-----------------------------------------------------------------------------------
|
||||
|
||||
def mse_to_psnr(mse):
|
||||
"""Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
|
||||
return -10. / np.log(10.) * np.log(mse)
|
||||
|
||||
def psnr_to_mse(psnr):
|
||||
"""Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
|
||||
return np.exp(-0.1 * np.log(10.) * psnr)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Displacement texture lookup
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def get_miplevels(texture: np.ndarray) -> float:
|
||||
minDim = min(texture.shape[0], texture.shape[1])
|
||||
return np.floor(np.log2(minDim))
|
||||
|
||||
def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:
|
||||
tex_map = tex_map[None, ...] # Add batch dimension
|
||||
tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False)
|
||||
tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC
|
||||
return tex[0, 0, ...]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cubemap utility functions
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def cube_to_dir(s, x, y):
|
||||
if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x
|
||||
elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x
|
||||
elif s == 2: rx, ry, rz = x, torch.ones_like(x), y
|
||||
elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y
|
||||
elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x)
|
||||
elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x)
|
||||
return torch.stack((rx, ry, rz), dim=-1)
|
||||
|
||||
def latlong_to_cubemap(latlong_map, res):
|
||||
cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')
|
||||
for s in range(6):
|
||||
gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
|
||||
torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
|
||||
indexing='ij')
|
||||
v = safe_normalize(cube_to_dir(s, gx, gy))
|
||||
|
||||
tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5
|
||||
tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi
|
||||
texcoord = torch.cat((tu, tv), dim=-1)
|
||||
|
||||
cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]
|
||||
return cubemap
|
||||
|
||||
def cubemap_to_latlong(cubemap, res):
|
||||
gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
|
||||
torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
|
||||
indexing='ij')
|
||||
|
||||
sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi)
|
||||
sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi)
|
||||
|
||||
reflvec = torch.stack((
|
||||
sintheta*sinphi,
|
||||
costheta,
|
||||
-sintheta*cosphi
|
||||
), dim=-1)
|
||||
return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Image scaling
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
|
||||
return scale_img_nhwc(x[None, ...], size, mag, min)[0]
|
||||
|
||||
def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
|
||||
# assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
|
||||
assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] <= size[0] and x.shape[2] <= size[1]), "Trying to magnify image in one dimension and minify in the other"
|
||||
y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
|
||||
y = torch.nn.functional.interpolate(y, size, mode=min)
|
||||
else: # Magnification
|
||||
if mag == 'bilinear' or mag == 'bicubic':
|
||||
y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
|
||||
else:
|
||||
y = torch.nn.functional.interpolate(y, size, mode=mag)
|
||||
return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
|
||||
|
||||
def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor:
|
||||
y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
y = torch.nn.functional.avg_pool2d(y, size)
|
||||
return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Behaves similar to tf.segment_sum
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
|
||||
num_segments = torch.unique_consecutive(segment_ids).shape[0]
|
||||
|
||||
# Repeats ids until same dimension as data
|
||||
if len(segment_ids.shape) == 1:
|
||||
s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long()
|
||||
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])
|
||||
|
||||
assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
|
||||
|
||||
shape = [num_segments] + list(data.shape[1:])
|
||||
result = torch.zeros(*shape, dtype=torch.float32, device='cuda')
|
||||
result = result.scatter_add(0, segment_ids, data)
|
||||
return result
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Matrix helpers.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def fovx_to_fovy(fovx, aspect):
|
||||
return np.arctan(np.tan(fovx / 2) / aspect) * 2.0
|
||||
|
||||
def focal_length_to_fovy(focal_length, sensor_height):
|
||||
return 2 * np.arctan(0.5 * sensor_height / focal_length)
|
||||
|
||||
# Reworked so this matches gluPerspective / glm::perspective, using fovy
|
||||
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
|
||||
y = np.tan(fovy / 2)
|
||||
return torch.tensor([[1/(y*aspect), 0, 0, 0],
|
||||
[ 0, 1/-y, 0, 0],
|
||||
[ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
|
||||
[ 0, 0, -1, 0]], dtype=torch.float32, device=device)
|
||||
|
||||
# Reworked so this matches gluPerspective / glm::perspective, using fovy
|
||||
def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None):
|
||||
y = np.tan(fovy / 2)
|
||||
|
||||
# Full frustum
|
||||
R, L = aspect*y, -aspect*y
|
||||
T, B = y, -y
|
||||
|
||||
# Create a randomized sub-frustum
|
||||
width = (R-L)*fraction
|
||||
height = (T-B)*fraction
|
||||
xstart = (R-L)*rx
|
||||
ystart = (T-B)*ry
|
||||
|
||||
l = L + xstart
|
||||
r = l + width
|
||||
b = B + ystart
|
||||
t = b + height
|
||||
|
||||
# https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix
|
||||
return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0],
|
||||
[ 0, -2/(t-b), (t+b)/(t-b), 0],
|
||||
[ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
|
||||
[ 0, 0, -1, 0]], dtype=torch.float32, device=device)
|
||||
|
||||
def translate(x, y, z, device=None):
|
||||
return torch.tensor([[1, 0, 0, x],
|
||||
[0, 1, 0, y],
|
||||
[0, 0, 1, z],
|
||||
[0, 0, 0, 1]], dtype=torch.float32, device=device)
|
||||
|
||||
def rotate_x(a, device=None):
|
||||
s, c = np.sin(a), np.cos(a)
|
||||
return torch.tensor([[1, 0, 0, 0],
|
||||
[0, c, s, 0],
|
||||
[0, -s, c, 0],
|
||||
[0, 0, 0, 1]], dtype=torch.float32, device=device)
|
||||
|
||||
def rotate_y(a, device=None):
|
||||
s, c = np.sin(a), np.cos(a)
|
||||
return torch.tensor([[ c, 0, s, 0],
|
||||
[ 0, 1, 0, 0],
|
||||
[-s, 0, c, 0],
|
||||
[ 0, 0, 0, 1]], dtype=torch.float32, device=device)
|
||||
|
||||
def scale(s, device=None):
|
||||
return torch.tensor([[ s, 0, 0, 0],
|
||||
[ 0, s, 0, 0],
|
||||
[ 0, 0, s, 0],
|
||||
[ 0, 0, 0, 1]], dtype=torch.float32, device=device)
|
||||
|
||||
def lookAt(eye, at, up):
|
||||
a = eye - at
|
||||
w = a / torch.linalg.norm(a)
|
||||
u = torch.cross(up, w)
|
||||
u = u / torch.linalg.norm(u)
|
||||
v = torch.cross(w, u)
|
||||
translate = torch.tensor([[1, 0, 0, -eye[0]],
|
||||
[0, 1, 0, -eye[1]],
|
||||
[0, 0, 1, -eye[2]],
|
||||
[0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
|
||||
rotate = torch.tensor([[u[0], u[1], u[2], 0],
|
||||
[v[0], v[1], v[2], 0],
|
||||
[w[0], w[1], w[2], 0],
|
||||
[0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
|
||||
return rotate @ translate
|
||||
|
||||
@torch.no_grad()
|
||||
def random_rotation_translation(t, device=None):
|
||||
m = np.random.normal(size=[3, 3])
|
||||
m[1] = np.cross(m[0], m[2])
|
||||
m[2] = np.cross(m[0], m[1])
|
||||
m = m / np.linalg.norm(m, axis=1, keepdims=True)
|
||||
m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
|
||||
m[3, 3] = 1.0
|
||||
m[:3, 3] = np.random.uniform(-t, t, size=[3])
|
||||
return torch.tensor(m, dtype=torch.float32, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def random_rotation(device=None):
|
||||
m = np.random.normal(size=[3, 3])
|
||||
m[1] = np.cross(m[0], m[2])
|
||||
m[2] = np.cross(m[0], m[1])
|
||||
m = m / np.linalg.norm(m, axis=1, keepdims=True)
|
||||
m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
|
||||
m[3, 3] = 1.0
|
||||
m[:3, 3] = np.array([0,0,0]).astype(np.float32)
|
||||
return torch.tensor(m, dtype=torch.float32, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_random_rotation(batch_size, device=None):
|
||||
m = np.random.normal(size=[batch_size, 3, 3])
|
||||
m[:, 1] = np.cross(m[:, 0], m[:, 2])
|
||||
m[:, 2] = np.cross(m[:, 0], m[:, 1])
|
||||
m = m / np.linalg.norm(m, axis=-1, keepdims=True)
|
||||
m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
|
||||
m[:, 3, 3] = 1.0
|
||||
m[:, :3, 3] = np.array([0,0,0]).astype(np.float32).unsqueeze(0)
|
||||
return torch.tensor(m, dtype=torch.float32, device=device)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Compute focal points of a set of lines using least squares.
|
||||
# handy for poorly centered datasets
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def lines_focal(o, d):
|
||||
d = safe_normalize(d)
|
||||
I = torch.eye(3, dtype=o.dtype, device=o.device)
|
||||
S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0)
|
||||
C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1)
|
||||
return torch.linalg.pinv(S) @ C
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cosine sample around a vector N
|
||||
#----------------------------------------------------------------------------
|
||||
@torch.no_grad()
|
||||
def cosine_sample(N, size=None):
|
||||
# construct local frame
|
||||
N = N/torch.linalg.norm(N)
|
||||
|
||||
dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device)
|
||||
dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device)
|
||||
|
||||
dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1)
|
||||
#dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1
|
||||
dx = dx / torch.linalg.norm(dx)
|
||||
dy = torch.cross(N,dx)
|
||||
dy = dy / torch.linalg.norm(dy)
|
||||
|
||||
# cosine sampling in local frame
|
||||
if size is None:
|
||||
phi = 2.0 * np.pi * np.random.uniform()
|
||||
s = np.random.uniform()
|
||||
else:
|
||||
phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device)
|
||||
s = torch.rand(*size, 1, dtype=N.dtype, device=N.device)
|
||||
costheta = np.sqrt(s)
|
||||
sintheta = np.sqrt(1.0 - s)
|
||||
|
||||
# cartesian vector in local space
|
||||
x = np.cos(phi)*sintheta
|
||||
y = np.sin(phi)*sintheta
|
||||
z = costheta
|
||||
|
||||
# local to world
|
||||
return dx*x + dy*y + N*z
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Bilinear downsample by 2x.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bilinear_downsample(x : torch.tensor) -> torch.Tensor:
|
||||
w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
|
||||
w = w.expand(x.shape[-1], 1, 4, 4)
|
||||
x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1])
|
||||
return x.permute(0, 2, 3, 1)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Bilinear downsample log(spp) steps
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:
|
||||
w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
|
||||
g = x.shape[-1]
|
||||
w = w.expand(g, 1, 4, 4)
|
||||
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
steps = int(np.log2(spp))
|
||||
for _ in range(steps):
|
||||
xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
|
||||
x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g)
|
||||
return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Singleton initialize GLFW
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_glfw_initialized = False
|
||||
def init_glfw():
|
||||
global _glfw_initialized
|
||||
try:
|
||||
import glfw
|
||||
glfw.ERROR_REPORTING = 'raise'
|
||||
glfw.default_window_hints()
|
||||
glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
|
||||
test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet
|
||||
except glfw.GLFWError as e:
|
||||
if e.error_code == glfw.NOT_INITIALIZED:
|
||||
glfw.init()
|
||||
_glfw_initialized = True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Image display function using OpenGL.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_glfw_window = None
|
||||
def display_image(image, title=None):
|
||||
# Import OpenGL
|
||||
import OpenGL.GL as gl
|
||||
import glfw
|
||||
|
||||
# Zoom image if requested.
|
||||
image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image)
|
||||
height, width, channels = image.shape
|
||||
|
||||
# Initialize window.
|
||||
init_glfw()
|
||||
if title is None:
|
||||
title = 'Debug window'
|
||||
global _glfw_window
|
||||
if _glfw_window is None:
|
||||
glfw.default_window_hints()
|
||||
_glfw_window = glfw.create_window(width, height, title, None, None)
|
||||
glfw.make_context_current(_glfw_window)
|
||||
glfw.show_window(_glfw_window)
|
||||
glfw.swap_interval(0)
|
||||
else:
|
||||
glfw.make_context_current(_glfw_window)
|
||||
glfw.set_window_title(_glfw_window, title)
|
||||
glfw.set_window_size(_glfw_window, width, height)
|
||||
|
||||
# Update window.
|
||||
glfw.poll_events()
|
||||
gl.glClearColor(0, 0, 0, 1)
|
||||
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
|
||||
gl.glWindowPos2f(0, 0)
|
||||
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
|
||||
gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels]
|
||||
gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name]
|
||||
gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1])
|
||||
glfw.swap_buffers(_glfw_window)
|
||||
if glfw.window_should_close(_glfw_window):
|
||||
return False
|
||||
return True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Image save/load helper.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def save_image(fn, x : np.ndarray):
|
||||
try:
|
||||
if os.path.splitext(fn)[1] == ".png":
|
||||
imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving
|
||||
else:
|
||||
imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8))
|
||||
except:
|
||||
print("WARNING: FAILED to save image %s" % fn)
|
||||
|
||||
def save_image_raw(fn, x : np.ndarray):
|
||||
try:
|
||||
imageio.imwrite(fn, x)
|
||||
except:
|
||||
print("WARNING: FAILED to save image %s" % fn)
|
||||
|
||||
|
||||
def load_image_raw(fn) -> np.ndarray:
|
||||
return imageio.imread(fn)
|
||||
|
||||
def load_image(fn) -> np.ndarray:
|
||||
img = load_image_raw(fn)
|
||||
if img.dtype == np.float32: # HDR image
|
||||
return img
|
||||
else: # LDR image
|
||||
return img.astype(np.float32) / 255
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def time_to_text(x):
|
||||
if x > 3600:
|
||||
return "%.2f h" % (x / 3600)
|
||||
elif x > 60:
|
||||
return "%.2f m" % (x / 60)
|
||||
else:
|
||||
return "%.2f s" % x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def checkerboard(res, checker_size) -> np.ndarray:
|
||||
tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2)
|
||||
tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2)
|
||||
check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33
|
||||
check = check[:res[0], :res[1]]
|
||||
return np.stack((check, check, check), axis=-1)
|
||||
|
Ładowanie…
Reference in New Issue