@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""Return training and evaluation/test datasets from config files."""
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
def get_data_scaler(config):
    """Build the forward data normalizer; inputs are assumed to lie in [0, 1].

    Args:
      config: Configuration object exposing `config.data.centered`.

    Returns:
      A callable mapping [0, 1] to [-1, 1] when `config.data.centered` is
      truthy, and the identity function otherwise.
    """

    def _center(x):
        # Rescale [0, 1] to [-1, 1].
        return x * 2. - 1.

    def _identity(x):
        return x

    return _center if config.data.centered else _identity
def get_data_inverse_scaler(config):
    """Build the inverse data normalizer.

    Args:
      config: Configuration object exposing `config.data.centered`.

    Returns:
      A callable mapping [-1, 1] back to [0, 1] when `config.data.centered`
      is truthy, and the identity function otherwise.
    """

    def _uncenter(x):
        # Rescale [-1, 1] to [0, 1].
        return (x + 1.) / 2.

    def _identity(x):
        return x

    return _uncenter if config.data.centered else _identity
def crop_resize(image, resolution):
    """Crop the central square of an image and resize it to `resolution`.

    Args:
      image: An image tensor whose first two axes are indexed as height and
        width.
      resolution: Target side length of the square output.

    Returns:
      The cropped and resized image, cast to `tf.uint8`.
    """
    h, w = tf.shape(image)[0], tf.shape(image)[1]
    # Side of the largest centered square that fits inside the image.
    crop = tf.minimum(h, w)
    image = image[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]
    # The resize call must receive the image itself and be properly closed.
    # `antialias=True` matches the other resize calls in this file.
    # NOTE(review): the interpolation method is left at the library default —
    # confirm whether a specific method was intended.
    image = tf.image.resize(
        image,
        size=(resolution, resolution),
        antialias=True)
    return tf.cast(image, tf.uint8)
def resize_small(image, resolution):
    """Shrink an image so that its shorter side equals `resolution`.

    The aspect ratio is preserved: both sides are scaled by
    `resolution / min(height, width)`.

    Args:
      image: An image tensor whose static shape gives height and width in its
        first two entries.
      resolution: Target length of the shorter side.

    Returns:
      The resized image as returned by `tf.image.resize` with antialiasing.
    """
    h, w = image.shape[0], image.shape[1]
    ratio = resolution / min(h, w)
    # `tf.round` takes (x, name) and returns a float tensor, so the dtype must
    # be applied with an explicit cast; `tf.image.resize` needs an int32 size.
    h = tf.cast(tf.round(h * ratio), tf.int32)
    w = tf.cast(tf.round(w * ratio), tf.int32)
    return tf.image.resize(image, [h, w], antialias=True)
def central_crop(image, size):
    """Extract the centered `size` x `size` window of an image.

    Args:
      image: An image tensor whose static shape gives height and width in its
        first two entries.
      size: Side length of the square crop.

    Returns:
      The result of `tf.image.crop_to_bounding_box` for the centered window.
    """
    offset_height = (image.shape[0] - size) // 2
    offset_width = (image.shape[1] - size) // 2
    return tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, size, size)
def get_dataset(config, uniform_dequantization=False, evaluation=False):
    """Create data loaders for training and evaluation.

    Args:
      config: A ml_collection.ConfigDict parsed from config files.
      uniform_dequantization: If `True`, add uniform dequantization to images.
      evaluation: If `True`, fix number of epochs to 1.

    Returns:
      train_ds, eval_ds, dataset_builder.

    Raises:
      ValueError: If the batch size is not divisible by the device count.
      NotImplementedError: If `config.data.dataset` is not a supported name.
    """
    # Compute batch size for this worker.
    batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
    if batch_size % jax.device_count() != 0:
        raise ValueError(f'Batch sizes ({batch_size}) must be divided by '
                         f'the number of devices ({jax.device_count()})')

    # Reduce this when image resolution is too large and data pointer is stored
    shuffle_buffer_size = 10000
    num_epochs = None if not evaluation else 1

    def _float_resize(img):
        """Convert to float32 and resize to the configured square size."""
        img = tf.image.convert_image_dtype(img, tf.float32)
        return tf.image.resize(
            img, [config.data.image_size, config.data.image_size], antialias=True)

    # Create dataset builders for each dataset. Each branch must be exclusive,
    # otherwise a later definition silently overrides an earlier one.
    if config.data.dataset == 'CIFAR10':
        dataset_builder = tfds.builder('cifar10')
        train_split_name = 'train'
        eval_split_name = 'test'
        resize_op = _float_resize

    elif config.data.dataset == 'SVHN':
        dataset_builder = tfds.builder('svhn_cropped')
        train_split_name = 'train'
        eval_split_name = 'test'
        resize_op = _float_resize

    elif config.data.dataset == 'CELEBA':
        dataset_builder = tfds.builder('celeb_a')
        train_split_name = 'train'
        eval_split_name = 'validation'

        def resize_op(img):
            img = tf.image.convert_image_dtype(img, tf.float32)
            img = central_crop(img, 140)
            img = resize_small(img, config.data.image_size)
            return img

    elif config.data.dataset == 'LSUN':
        dataset_builder = tfds.builder(f'lsun/{config.data.category}')
        train_split_name = 'train'
        eval_split_name = 'validation'

        if config.data.image_size == 128:
            def resize_op(img):
                img = tf.image.convert_image_dtype(img, tf.float32)
                img = resize_small(img, config.data.image_size)
                img = central_crop(img, config.data.image_size)
                return img
        else:
            def resize_op(img):
                img = crop_resize(img, config.data.image_size)
                img = tf.image.convert_image_dtype(img, tf.float32)
                return img

    elif config.data.dataset in ['FFHQ', 'CelebAHQ']:
        # These datasets are read from raw TFRecords; the same records serve
        # as both the training and the evaluation split.
        dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path)
        train_split_name = eval_split_name = 'train'

    else:
        raise NotImplementedError(
            f'Dataset {config.data.dataset} not yet supported.')

    def _augment(img):
        """Apply the optional random flip and uniform dequantization."""
        if config.data.random_flip and not evaluation:
            img = tf.image.random_flip_left_right(img)
        if uniform_dequantization:
            img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
        return img

    # Customize preprocess functions for each dataset.
    if config.data.dataset in ['FFHQ', 'CelebAHQ']:
        def preprocess_fn(d):
            """Decode one serialized record into a float image."""
            sample = tf.io.parse_single_example(d, features={
                'shape': tf.io.FixedLenFeature([3], tf.int64),
                'data': tf.io.FixedLenFeature([], tf.string)})
            data = tf.io.decode_raw(sample['data'], tf.uint8)
            data = tf.reshape(data, sample['shape'])
            # Move the leading axis to the end (axes reordered as (1, 2, 0)).
            data = tf.transpose(data, (1, 2, 0))
            img = tf.image.convert_image_dtype(data, tf.float32)
            return dict(image=_augment(img), label=None)
    else:
        def preprocess_fn(d):
            """Basic preprocessing function scales data to [0, 1) and randomly flips."""
            img = resize_op(d['image'])
            return dict(image=_augment(img), label=d.get('label', None))

    def create_dataset(dataset_builder, split):
        """Build the repeated, shuffled, batched and prefetched pipeline."""
        dataset_options = tf.data.Options()
        dataset_options.experimental_optimization.map_parallelization = True
        dataset_options.experimental_threading.private_threadpool_size = 48
        dataset_options.experimental_threading.max_intra_op_parallelism = 1
        read_config = tfds.ReadConfig(options=dataset_options)
        if isinstance(dataset_builder, tfds.core.DatasetBuilder):
            ds = dataset_builder.as_dataset(
                split=split, shuffle_files=True, read_config=read_config)
        else:
            # A raw tf.data dataset (the TFRecord case): only attach options.
            ds = dataset_builder.with_options(dataset_options)
        ds = ds.repeat(count=num_epochs)
        ds = ds.shuffle(shuffle_buffer_size)
        ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(tf.data.experimental.AUTOTUNE)

    train_ds = create_dataset(dataset_builder, train_split_name)
    eval_ds = create_dataset(dataset_builder, eval_split_name)
    return train_ds, eval_ds, dataset_builder

@ -0,0 +1,49 @@
import os
import sys
import torch
from torch.utils.data import Dataset
import json
import argparse
class ShapeNetDMTetDataset(Dataset):
    """Dataset of per-shape grid tensors loaded lazily from `.pt` files.

    The list of tensor file paths is read from a JSON file. Each item is
    loaded with `torch.load`, optionally sign-normalized and jittered,
    multiplied by `grid_mask`, and zero-padded up to the mask resolution.
    """

    def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
        """Read the file list and store preprocessing options.

        Args:
          root: Path to a JSON file containing the list of tensor file paths.
          grid_mask: Mask tensor; its last dimension defines the target
            resolution. It is moved to the CPU.
          deform_scale: Scale stored in `self.coeff` for the last three of its
            five entries.
          aug: If `True`, add a small random offset in `__getitem__`.
          filter_meta_path: Optional path to a JSON file of ids; only files
            whose parsed id appears in it are kept.
          normalize_sdf: If `True`, replace the sliced SDF values by their
            sign in `__getitem__`.
        """
        # Use context managers so the JSON file handles are always closed.
        with open(root, 'r') as f:
            self.fpath_list = json.load(f)
        self.deform_scale = deform_scale
        self.normalize_sdf = normalize_sdf
        print(f"dataset with sdf normalized: {normalize_sdf}")
        self.coeff = torch.tensor([1.0, 1.0, self.deform_scale, self.deform_scale, self.deform_scale]).view(-1, 1, 1, 1)
        self.aug = aug
        self.grid_mask = grid_mask.cpu()
        self.resolution = self.grid_mask.size(-1)

        if filter_meta_path is not None:
            with open(filter_meta_path, 'r') as f:
                self.filter_ids = json.load(f)
            # Membership is tested once per file, so use a set for O(1) lookup.
            allowed_ids = set(self.filter_ids)
            # The id is the text after the last '_' with the 3-character
            # extension stripped (e.g. "..._123.pt" -> 123).
            self.fpath_list = [
                fpath for fpath in self.fpath_list
                if int(fpath.rstrip().split('_')[-1][:-3]) in allowed_ids
            ]

    def __len__(self):
        """Return the number of (possibly filtered) files."""
        return len(self.fpath_list)

    def __getitem__(self, idx):
        """Load, preprocess and return the tensor stored at index `idx`."""
        with torch.no_grad():
            datum = torch.load(self.fpath_list[idx], map_location='cpu')
            if self.normalize_sdf:
                # NOTE(review): the slice `[:, :1]` is kept exactly as written;
                # confirm that it selects the intended SDF entries.
                sdf_sign = torch.sign(datum[:, :1])
                sdf_sign[sdf_sign == 0] = 1.0  # treat exact zeros as positive
                datum[:, :1] = sdf_sign

            if self.aug:
                # Jitter only locations where the entries after the first are
                # not all zero; one random offset per entry, scaled by the
                # ratio between data size and target resolution.
                nonempty_mask = (datum[1:].abs().sum(dim=0, keepdim=True) != 0)
                datum[1:] = datum[1:] + (torch.rand(3)[:, None, None, None] - 0.5) * 0.01 * nonempty_mask / (datum.size(-1) / self.resolution)

            if datum.size(-1) < self.resolution:
                # Smaller grid: mask with the matching corner of the mask,
                # then zero-pad the three spatial axes up to the resolution.
                size = datum.size(-1)
                datum = datum * self.grid_mask[0, :, :size, :size, :size]
                diff = self.resolution - size
                datum = torch.nn.functional.pad(datum, (0, diff, 0, diff, 0, diff, 0, 0))
            else:
                datum = datum * self.grid_mask[0]

        return datum

@ -0,0 +1,212 @@
import os
import sys
import numpy as np
import logging
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from .utils import restore_checkpoint
from . import sampling
def uncond_gen(
# NOTE(review): the parameter list and the docstring delimiters are not visible
# in this copy. The body reads `config` and `idx`, so both are presumably
# parameters — confirm against the full file.
Unconditional Generation
with torch.no_grad():
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
# Create directory to eval_folder
os.makedirs(eval_dir, exist_ok=True)
# Identity scalers: data is passed through unchanged in both directions.
scaler, inverse_scaler = lambda x: x, lambda x: x
# Initialize model
score_model = mutils.create_model(config)
optimizer = losses.get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
# Container handed to `restore_checkpoint` below.
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
# Setup SDEs
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
img_size = config.data.image_size
# The mask file is looked up relative to the working directory and is moved to
# "cuda" unconditionally (independent of `config.device`).
grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda")
sampling_eps = 1e-3
# NOTE(review): only four dimensions are visible here; a channel entry may be
# missing from this copy — confirm.
sampling_shape = (config.eval.batch_size,
config.data.image_size, config.data.image_size, config.data.image_size)
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
# NOTE(review): `assert` is stripped under `python -O`; a missing checkpoint
# would then surface later inside `restore_checkpoint`.
assert os.path.exists(ckpt_path)
print('ckpt path:', ckpt_path)
state = restore_checkpoint(ckpt_path, state, device=config.device)
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
# Output file is named after `idx` inside the evaluation directory.
save_file_path = os.path.join(eval_dir, f"{idx}.npy")
# The second return value of the sampler is not used here.
samples, n = sampling_fn(score_model)
samples = samples.cpu().numpy()
np.save(save_file_path, samples)
def slerp(z1, z2, alpha):
    """Spherical linear interpolation between two tensors.

    Args:
      z1: Start tensor (returned when `alpha` is 0).
      z2: End tensor of the same shape (returned when `alpha` is 1).
      alpha: Interpolation weight.

    Returns:
      The interpolated tensor. When the angle between `z1` and `z2` is
      numerically degenerate (sin(theta) close to 0, e.g. parallel inputs),
      plain linear interpolation is used instead of dividing by ~0.
    """
    cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    theta = torch.acos(torch.clamp(cos_theta, -1.0, 1.0))
    sin_theta = torch.sin(theta)
    if torch.abs(sin_theta) < 1e-6:
        # The slerp weights are 0/0 here; the linear blend is their limit for
        # parallel inputs.
        return (1 - alpha) * z1 + alpha * z2
    return (
        torch.sin((1 - alpha) * theta) / sin_theta * z1
        + torch.sin(alpha * theta) / sin_theta * z2
    )
def uncond_gen_interp(
# NOTE(review): the parameter list and the docstring delimiters are not visible
# in this copy. The body reads `config` and `idx`, so both are presumably
# parameters — confirm against the full file.
Generation with interpolation between initial noises
Used for DDIM
with torch.no_grad():
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
# Create directory to eval_folder
os.makedirs(eval_dir, exist_ok=True)
# Identity scalers: data is passed through unchanged in both directions.
scaler, inverse_scaler = lambda x: x, lambda x: x
# Initialize model
score_model = mutils.create_model(config)
optimizer = losses.get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
# Container handed to `restore_checkpoint` below.
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
# Setup SDEs
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
img_size = config.data.image_size
# The mask file is looked up relative to the working directory and is moved to
# "cuda" unconditionally (independent of `config.device`).
grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda")
sampling_eps = 1e-3
# NOTE(review): only four dimensions are visible here, while the noise drawn
# below has five (including `config.data.num_channels`) — confirm.
sampling_shape = (config.eval.batch_size,
config.data.image_size, config.data.image_size, config.data.image_size)
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
assert os.path.exists(ckpt_path)
print('ckpt path:', ckpt_path)
state = restore_checkpoint(ckpt_path, state, device=config.device)
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
# Output file is named after `idx` inside the evaluation directory.
save_file_path = os.path.join(eval_dir, f"{idx}.npy")
# Draw two prior samples to serve as the endpoints of the interpolation.
# NOTE(review): the end of this call is not visible in this copy.
noise = sde.prior_sampling(
(2, config.data.num_channels, config.data.image_size, config.data.image_size, config.data.image_size)
# The first and last batch entries are the two endpoint noises.
x0 = torch.zeros(sampling_shape, device=config.device)
x0[0] = noise[0]
x0[-1] = noise[1]
# NOTE(review): `batch_size` and `x` are not assigned anywhere in the visible
# body, while the tensor passed to the sampler is `x0`. This loop presumably
# should fill `x0[i]` over `config.eval.batch_size` entries — confirm and fix.
for i in range(1, batch_size - 1):
x[i] = slerp(x[0], x[-1], i / float(batch_size - 1))
# The second return value of the sampler is not used here.
samples, n = sampling_fn(score_model, x0=x0)
samples = samples.cpu().numpy()
np.save(save_file_path, samples)
def cond_gen(
# NOTE(review): the parameter list and the docstring delimiters are not visible
# in this copy. The body reads `config` and `save_fname`, so both are
# presumably parameters — confirm against the full file.
Conditional Generation with partially completed dmtet from a 2.5D view (converted into a cubic grid)
with torch.no_grad():
eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
# Create directory to eval_folder
os.makedirs(eval_dir, exist_ok=True)
# Identity scalers: data is passed through unchanged in both directions.
scaler, inverse_scaler = lambda x: x, lambda x: x
# Initialize model
score_model = mutils.create_model(config)
optimizer = losses.get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
# Container handed to `restore_checkpoint` below.
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
# Setup SDEs
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
resolution = config.data.image_size
# Here the mask is viewed with two leading singleton axes (1, 1, R, R, R); it is
# moved to "cuda" unconditionally (independent of `config.device`).
grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda")
sampling_eps = 1e-3
# NOTE(review): only four dimensions are visible here; a channel entry may be
# missing from this copy — confirm.
sampling_shape = (config.eval.batch_size,
resolution, resolution, resolution)
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)
assert os.path.exists(ckpt_path)
print('ckpt path:', ckpt_path)
state = restore_checkpoint(ckpt_path, state, device=config.device)
print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
# Output file is named after `save_fname` inside the evaluation directory.
save_file_path = os.path.join(eval_dir, f"{save_fname}.npy")
### Conditional but free gradients; start from small t
# The partial observation is a dict providing an 'sdf' entry and a 'vis' entry.
partial_dict = torch.load(config.eval.partial_dmtet_path)
partial_sdf = partial_dict['sdf']
partial_mask = partial_dict['vis']
### compute the mapping from tet indices to 3D cubic grid vertex indices
tet_path = config.eval.tet_path
tet = np.load(tet_path)
vertices = torch.tensor(tet['vertices'])
# Grid spacing is taken as the gap between the two smallest distinct
# coordinate values (`unique()` returns sorted values).
# assumes the vertices lie on a uniformly spaced grid — TODO confirm
vertices_unique = vertices[:].unique()
dx = vertices_unique[1] - vertices_unique[0]
# Convert vertex coordinates to grid indices by shifting to zero and dividing
# by the spacing.
# NOTE(review): the end of this expression is not visible in this copy; an
# integer cast is presumably applied since the result is used as an index.
ind_to_coord = (torch.round(
(vertices - vertices.min()) / dx)
# Scatter the per-vertex SDF values and visibility flags into dense grids.
partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution))
partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf
partial_mask_grid = torch.zeros((1, 1, resolution, resolution, resolution))
partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float()
# NOTE(review): the arguments of this sampler call are not visible in this copy.
samples, n = sampling_fn(
samples = samples.cpu().numpy()
np.save(save_file_path, samples)

@ -0,0 +1,113 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import torch
import numpy as np
from scipy import integrate
from .models import utils as mutils
def get_div_fn(fn):
    """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.

    Args:
      fn: A callable `fn(x, t)` returning a tensor of the same shape as `x`.

    Returns:
      A function `div_fn(x, t, eps)` that returns, for each batch element,
      the estimate `eps^T (d fn / d x) eps` summed over all non-batch
      dimensions.
    """

    def div_fn(x, t, eps):
        with torch.enable_grad():
            # `torch.autograd.grad` needs a leaf that requires grad. Work on a
            # detached copy so the caller's tensor is left untouched.
            x_in = x.detach().requires_grad_(True)
            fn_eps = torch.sum(fn(x_in, t) * eps)
            grad_fn_eps = torch.autograd.grad(fn_eps, x_in)[0]
        return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

    return div_fn
def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
    """Create a function to compute the unbiased log-likelihood estimate of a given data point.

    Args:
      sde: A `sde_lib.SDE` object that represents the forward SDE.
      inverse_scaler: The inverse data normalizer.
      hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
      rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
      atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
      method: A `str`. The algorithm for the black-box ODE solver.
        See documentation for `scipy.integrate.solve_ivp`.
      eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

    Returns:
      A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
        the latent code, and the number of function evaluations cost by computation.
    """

    def drift_fn(model, x, t):
        """The drift function of the reverse-time SDE."""
        score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
        # Probability flow ODE is a special case of Reverse SDE
        rsde = sde.reverse(score_fn, probability_flow=True)
        return rsde.sde(x, t)[0]

    def div_fn(model, x, t, noise):
        return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

    def likelihood_fn(model, data):
        """Compute an unbiased estimate to the log-likelihood in bits/dim.

        Args:
          model: A score model.
          data: A PyTorch tensor.

        Returns:
          bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
          z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
            probability flow ODE.
          nfe: An integer. The number of function evaluations used for running the black-box ODE solver.

        Raises:
          NotImplementedError: If `hutchinson_type` is not a supported name.
        """
        with torch.no_grad():
            shape = data.shape
            if hutchinson_type == 'Gaussian':
                epsilon = torch.randn_like(data)
            elif hutchinson_type == 'Rademacher':
                # Uniform {0, 1} draws mapped to {-1, +1}.
                epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
            else:
                raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

            def ode_func(t, x):
                # The flat solver state is [sample values..., one log-density
                # change per batch element]; split off the trailing entries.
                sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
                vec_t = torch.ones(sample.shape[0], device=sample.device) * t
                drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
                logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
                return np.concatenate([drift, logp_grad], axis=0)

            init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
            solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
            nfe = solution.nfev
            # Final state of the integration (last column of the trajectory).
            zp = solution.y[:, -1]
            z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
            delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
            prior_logp = sde.prior_logp(z)
            # Convert nats to bits and normalize by the number of dimensions.
            bpd = -(prior_logp + delta_logp) / np.log(2)
            N = np.prod(shape[1:])
            bpd = bpd / N
            # A hack to convert log-likelihoods to bits/dim
            offset = 7. - inverse_scaler(-1.)
            bpd = bpd + offset
            return bpd, z, nfe

    return likelihood_fn

@ -0,0 +1,141 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""All functions related to loss computation and optimization."""
import torch
import torch.optim as optim
import numpy as np
from .models import utils as mutils
from .sde_lib import VPSDE
def get_optimizer(config, params):
"""Returns a `torch.optim` optimizer object based on `config`.

Only `config.optim.optimizer == 'Adam'` is handled; the learning rate, first
beta and epsilon are read from `config.optim`, and the second beta is fixed
at 0.999.
"""
if config.optim.optimizer == 'Adam':
# NOTE(review): the end of this call is not visible in this copy (it stops
# after `eps=...,`), and no `else:` is visible before the `raise` below, which
# presumably belongs to the unsupported-optimizer branch — confirm.
optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
raise NotImplementedError(
f'Optimizer {config.optim.optimizer} not supported yet!')
return optimizer
def optimization_manager(config):
"""Returns an optimize_fn based on `config`."""
# NOTE(review): the rest of this signature is not visible in this copy; the
# body reads `warmup` and `grad_clip`, so both are presumably further
# parameters — confirm their defaults against the full file.
def optimize_fn(optimizer, params, step, lr=config.optim.lr,
"""Optimizes with warmup and gradient clipping (disabled if negative)."""
if warmup > 0:
# Linear learning-rate warmup: scale `lr` by step / warmup, capped at 1.
for g in optimizer.param_groups:
g['lr'] = lr * np.minimum(step / warmup, 1.0)
if grad_clip >= 0:
torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
# NOTE(review): no `optimizer.step()` call is visible in this copy; confirm
# where the parameter update is performed.
return optimize_fn
def get_ddpm_loss_fn(vpsde, train, mask=None, loss_type='l2'):
    """Legacy code to reproduce previous results on DDPM. Not recommended for new work.

    Args:
      vpsde: An object exposing `N`, `sqrt_alphas_cumprod` and
        `sqrt_1m_alphas_cumprod` (the discretized VP SDE).
      train: Forwarded to `mutils.get_model_fn`.
      mask: Optional tensor multiplied into the perturbed input and into the
        element-wise losses. When `None`, no masking is applied.
      loss_type: `'l2'` (squared error) or `'l1'` (absolute error).

    Returns:
      A function `loss_fn(model, batch)` returning a scalar loss tensor.
    """

    def loss_fn(model, batch):
        model_fn = mutils.get_model_fn(model, train=train)
        # One random discrete timestep per batch element.
        labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
        sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
        sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
        noise = torch.randn_like(batch)
        # The four trailing `None` indices broadcast the per-sample
        # coefficients against a 5-D batch.
        perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \
            sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise
        if mask is not None:
            # Guarded: multiplying by the default `mask=None` would raise.
            perturbed_data = perturbed_data * mask
        score = model_fn(perturbed_data, labels)

        if loss_type == 'l2':
            losses = torch.square(score - noise)
        elif loss_type == 'l1':
            losses = torch.abs(score - noise)
        else:
            raise NotImplementedError

        if mask is not None:
            losses = losses * mask
            losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)
            # Rescale so the mean is taken over `mask.sum()` rather than over
            # the total number of mask elements.
            loss = torch.mean(losses) / mask.sum() * np.prod(mask.size())
        else:
            losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)
            loss = torch.mean(losses)
        return loss

    return loss_fn
def get_step_fn(sde, train, optimize_fn=None, mask=None, loss_type='l2'):
"""Create a one-step training/evaluation function.

Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: Selects the training branch of the step function when truthy.
optimize_fn: An optimization function.
mask: Forwarded to `get_ddpm_loss_fn`.
loss_type: Forwarded to `get_ddpm_loss_fn`.

NOTE(review): earlier versions of this docstring also listed `reduce_mean`,
`continuous` and `likelihood_weighting`, which are not parameters of this
function.

Returns:
A one-step function for training or evaluation.
"""
loss_fn = get_ddpm_loss_fn(sde, train, mask=mask, loss_type=loss_type)
def step_fn(state, batch, clear_grad=True, update_param=True):
"""Running one step of training or evaluation.

NOTE(review): an earlier version of this docstring mentioned `jax.lax.scan`,
pmap and jit; nothing in the visible body uses JAX.

Args:
state: A dictionary of training information, containing the score model, optimizer,
EMA status, and number of optimization steps.
batch: A mini-batch of training/evaluation data.
clear_grad: Checked before the loss is computed in the training branch.
update_param: When truthy, `optimize_fn` is called and the step counter advances.

Returns:
loss: The average loss value of this state.
"""
model = state['model']
if train:
optimizer = state['optimizer']
# NOTE(review): no statement is visible under this `if` in this copy;
# presumably gradients are zeroed here — confirm.
if clear_grad:
loss = loss_fn(model, batch)
# NOTE(review): no backward pass is visible in this copy — confirm.
if update_param:
optimize_fn(optimizer, model.parameters(), step=state['step'])
state['step'] += 1
# NOTE(review): the lines below presumably form the evaluation branch (an
# `else:` and the EMA handling are not visible in this copy) — confirm.
with torch.no_grad():
ema = state['ema']
loss = loss_fn(model, batch)
# NOTE(review): the end of the returned dict is not visible in this copy.
return {
'loss': loss,
return step_fn

@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""DDPM model.

This code is the PyTorch equivalent of a reference implementation whose link
is missing from this copy.
"""
import torch
import torch.nn as nn
import functools
import numpy as np
from . import utils, layers, normalization
# Module-level shorthands for building blocks defined in `layers` and
# `normalization`, so the model code below can refer to them by short names.
# RefineBlock = layers.RefineBlock
# ResidualBlock = layers.ResidualBlock
ResnetBlockDDPM = layers.ResnetBlockDDPM
Upsample = layers.Upsample
Downsample = layers.Downsample
conv3x3 = layers.ddpm_conv3x3
conv5x5 = layers.ddpm_conv5x5
transposed_conv6x6 = layers.ddpm_conv6x6_transposed
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init
# U-Net-style DDPM network: modules are appended to one flat list in
# construction order and `forward` consumes them through a running index.
# NOTE(review): several lines appear to be missing from this copy (e.g.
# `super().__init__()`, `else:` branches, and the statements that append
# attention blocks and record skip channels) — confirm against the full file
# before changing any code here.
class DDPMRes128(nn.Module):
def __init__(self, config):
self.act = act = get_act(config)
self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self.nf = nf = config.model.nf
ch_mult = config.model.ch_mult
self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
dropout = config.model.dropout
resamp_with_conv = config.model.resamp_with_conv
# One resolution level per channel multiplier.
self.num_resolutions = num_resolutions = len(ch_mult)
# Side length at each level: the image size halved once per level.
self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] ## manual for 128 to 64
AttnBlock = functools.partial(layers.AttnBlock)
self.conditional = conditional = config.model.conditional
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
if conditional:
# Condition on noise levels.
# Two linear layers (nf -> 4*nf -> 4*nf) applied to the timestep embedding.
modules = [nn.Linear(nf, nf * 4)]
modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
# NOTE(review): `modules` is only created inside the `if conditional:` branch
# above in this copy — confirm how the unconditional case is handled.
self.centered = config.data.centered
channels = config.data.num_channels
##### Pos Encoding
self.img_size = img_size = config.data.image_size
self.num_freq = int(np.log2(img_size))
coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
# Coordinate input is hard-coded off, so `self.coords` is not created here.
self.use_coords = False
if self.use_coords:
# NOTE(review): the end of this `Parameter(...)` call is not visible in this copy.
self.coords = torch.nn.Parameter(
# torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0,
torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size),
### Mask
# Non-trainable all-zero mask of shape (1, 1, S, S, S); registered as a
# parameter so it is part of the module's state.
self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False)
# Downsampling block
self.pos_layer = conv5x5(3, nf, stride=1, padding=2)
self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
modules.append(conv5x5(channels, nf, stride=1, padding=2))
# `hs_c` tracks the channel counts of skip connections; it is popped in the
# upsampling loop and must be empty at the end (see the assert).
hs_c = [nf]
in_ch = nf
for i_level in range(num_resolutions):
# Level 0 always uses 2 residual blocks, regardless of the configured count.
num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
# Residual blocks for this resolution
for i_block in range(num_res_blocks_curr):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
in_ch = out_ch
# NOTE(review): no statement is visible under this `if` in this copy;
# presumably an attention block is appended and `hs_c` is extended — confirm.
if all_resolutions[i_level] in attn_resolutions:
if i_level != num_resolutions - 1:
modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
in_ch = hs_c[-1]
# NOTE(review): `forward` consumes three modules between the down and up
# paths, but none are appended here in this copy — confirm.
# Upsampling block
for i_level in reversed(range(num_resolutions)):
num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
# One more block than on the way down; each consumes one skip connection.
for i_block in range(num_res_blocks_curr + 1):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
in_ch = out_ch
# NOTE(review): no statement is visible under this `if` in this copy.
if all_resolutions[i_level] in attn_resolutions:
if i_level != 0:
modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
assert not hs_c
modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
# modules.append(conv3x3(in_ch, channels, init_scale=0.))
# modules.append(transposed_conv6x6(in_ch, channels, init_scale=0.))
# Output projection back to the data channels, created with init_scale=0.
modules.append(conv5x5(in_ch, channels, init_scale=0., stride=1, padding=2))
self.all_modules = nn.ModuleList(modules)
self.scale_by_sigma = config.model.scale_by_sigma
def forward(self, x, labels):
modules = self.all_modules
# Running index into `modules`; it must equal len(modules) at the end.
m_idx = 0
if self.conditional:
# timestep/scale embedding
timesteps = labels
temb = layers.get_timestep_embedding(timesteps, self.nf)
temb = modules[m_idx](temb)
m_idx += 1
temb = modules[m_idx](self.act(temb))
m_idx += 1
# NOTE(review): presumably the `else:` branch of `if self.conditional:`
# (the `else:` line is not visible in this copy).
temb = None
if self.centered:
# Input is in [-1, 1]
h = x
# NOTE(review): presumably the `else:` branch of `if self.centered:`.
# Input is in [0, 1]
h = 2 * x - 1.
# Downsampling block
# The mask embedding (and, when enabled, the coordinate embedding) is added
# to the first feature map.
if self.use_coords:
hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)]
# NOTE(review): presumably the `else:` branch of `if self.use_coords:`.
hs = [modules[m_idx](h) + self.mask_layer(self.mask)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
num_res_blocks = self.num_res_blocks if i_level != 0 else 2
for i_block in range(num_res_blocks):
h = modules[m_idx](hs[-1], temb)
m_idx += 1
# Apply the next module when the last spatial size is listed in
# `attn_resolutions`.
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
# NOTE(review): the statements that push onto `hs` and apply the downsample
# module are not visible in this copy.
if i_level != self.num_resolutions - 1:
m_idx += 1
h = hs[-1]
# Three modules applied between the down and up paths; the first and third
# receive `temb`, the second does not.
h = modules[m_idx](h, temb)
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
h = modules[m_idx](h, temb)
m_idx += 1
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
num_res_blocks = self.num_res_blocks if i_level != 0 else 2
for i_block in range(num_res_blocks + 1):
# Concatenate the matching skip connection along the channel axis (dim=1).
hspop = hs.pop()
input = torch.cat([h, hspop], dim=1)
h = modules[m_idx](input, temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
if i_level != 0:
h = modules[m_idx](h)
m_idx += 1
# All skip connections must have been consumed.
assert not hs
h = self.act(modules[m_idx](h))
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
assert m_idx == len(modules)
if self.scale_by_sigma:
# Divide the output by sigmas. Useful for training with the NCSN loss.
# The DDPM loss scales the network output by sigma in the loss function,
# so no need of doing it here.
# NOTE(review): three trailing `None` axes give a 4-D tensor; confirm this
# broadcasts as intended against the volumetric `h`.
used_sigmas = self.sigmas[labels, None, None, None]
h = h / used_sigmas
return h

@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""DDPM model.

This code is the PyTorch equivalent of a reference implementation whose link
is missing from this copy.
"""
import torch
import torch.nn as nn
import functools
import numpy as np
from . import utils, layers, normalization
# Module-level shorthands for building blocks defined in `layers` and
# `normalization`, so the model code below can refer to them by short names.
# RefineBlock = layers.RefineBlock
# ResidualBlock = layers.ResidualBlock
ResnetBlockDDPM = layers.ResnetBlockDDPM
Upsample = layers.Upsample
Downsample = layers.Downsample
conv3x3 = layers.ddpm_conv3x3
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init
class DDPMRes64(nn.Module):
def __init__(self, config):
self.act = act = get_act(config)
self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self.nf = nf = config.model.nf
ch_mult = config.model.ch_mult
self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
dropout = config.model.dropout
resamp_with_conv = config.model.resamp_with_conv
self.num_resolutions = num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]
AttnBlock = functools.partial(layers.AttnBlock)
self.conditional = conditional = config.model.conditional
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)
if conditional:
# Condition on noise levels.
modules = [nn.Linear(nf, nf * 4)]
modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
self.centered = config.data.centered
channels = config.data.num_channels
##### Pos Encoding
self.img_size = img_size = config.data.image_size
self.num_freq = int(np.log2(img_size))
coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
self.coords = torch.nn.Parameter(
torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0,
### Mask
self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False)
# Downsampling block
self.pos_layer = conv3x3(3, nf)
self.mask_layer = conv3x3(1, nf)
modules.append(conv3x3(channels, nf))
hs_c = [nf]
in_ch = nf
for i_level in range(num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
if i_level != num_resolutions - 1:
modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
in_ch = hs_c[-1]
# Upsampling block
for i_level in reversed(range(num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
if i_level != 0:
modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
assert not hs_c
modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=0.))
self.all_modules = nn.ModuleList(modules)
self.scale_by_sigma = config.model.scale_by_sigma
def forward(self, x, labels):
  """Run the U-Net style score network on a batch.

  The modules stored in `self.all_modules` are consumed strictly in order via
  `m_idx`; the final assertion checks that every module was used exactly once.

  Args:
    x: Input batch. NOTE(review): presumably `(B, C, D, H, W)` since the stem
      adds `self.coords` / `self.mask` features that are 5D — confirm.
    labels: One entry per sample. Used as timesteps for the embedding when
      `self.conditional` is set, and as indices into `self.sigmas` when
      `self.scale_by_sigma` is set.

  Returns:
    The network output, divided by the per-sample sigma when
    `self.scale_by_sigma` is true.
  """
  modules = self.all_modules
  m_idx = 0
  if self.conditional:
    # timestep/scale embedding
    timesteps = labels
    temb = layers.get_timestep_embedding(timesteps, self.nf)
    temb = modules[m_idx](temb)
    m_idx += 1
    temb = modules[m_idx](self.act(temb))
    m_idx += 1
  else:
    temb = None

  if self.centered:
    # Input is in [-1, 1]
    h = x
  else:
    # Input is in [0, 1]
    h = 2 * x - 1.

  # Downsampling block. The stem output is combined with learned projections
  # of the coordinate grid and of the mask.
  hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)]
  m_idx += 1
  for i_level in range(self.num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks):
      h = modules[m_idx](hs[-1], temb)
      m_idx += 1
      if h.shape[-1] in self.attn_resolutions:
        h = modules[m_idx](h)
        m_idx += 1
      # Keep every intermediate activation as a skip connection; the
      # upsampling path pops exactly one per residual block.
      hs.append(h)
    if i_level != self.num_resolutions - 1:
      hs.append(modules[m_idx](hs[-1]))
      m_idx += 1

  # Middle: residual block -> module taking only `h` -> residual block.
  h = hs[-1]
  h = modules[m_idx](h, temb)
  m_idx += 1
  h = modules[m_idx](h)
  m_idx += 1
  h = modules[m_idx](h, temb)
  m_idx += 1

  # Upsampling block
  # NOTE(review): the attention and upsample steps are placed after the inner
  # loop (once per level); confirm this matches the order in which
  # `__init__` appends modules — the `m_idx` assertion below will fail if not.
  for i_level in reversed(range(self.num_resolutions)):
    for i_block in range(self.num_res_blocks + 1):
      skip = hs.pop()
      h = modules[m_idx](torch.cat([h, skip], dim=1), temb)
      m_idx += 1
    if h.shape[-1] in self.attn_resolutions:
      h = modules[m_idx](h)
      m_idx += 1
    if i_level != 0:
      h = modules[m_idx](h)
      m_idx += 1

  assert not hs
  h = self.act(modules[m_idx](h))
  m_idx += 1
  h = modules[m_idx](h)
  m_idx += 1
  assert m_idx == len(modules)

  if self.scale_by_sigma:
    # Divide the output by sigmas. Useful for training with the NCSN loss.
    # The DDPM loss scales the network output by sigma in the loss function,
    # so no need of doing it here.
    # Reshape to (B, 1, ..., 1) with as many singleton axes as `h` has
    # non-batch axes, so the division broadcasts per sample for any rank.
    used_sigmas = self.sigmas[labels].reshape(-1, *([1] * (h.dim() - 1)))
    h = h / used_sigmas

  return h

@ -0,0 +1,98 @@
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py
from __future__ import division
from __future__ import unicode_literals
import torch
# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
  """Maintains (exponential) moving average of a set of parameters."""

  def __init__(self, parameters, decay, use_num_updates=True):
    """Create the shadow copies that hold the running averages.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; only those with
        `requires_grad` set are tracked.
      decay: The exponential decay; must lie in [0, 1].
      use_num_updates: Whether to use number of updates when computing
        averages (see `update`).

    Raises:
      ValueError: If `decay` is outside [0, 1].
    """
    if decay < 0.0 or decay > 1.0:
      raise ValueError('Decay must be between 0 and 1')
    self.decay = decay
    self.num_updates = 0 if use_num_updates else None
    self.shadow_params = [p.clone().detach()
                          for p in parameters if p.requires_grad]
    self.collected_params = []

  def update(self, parameters):
    """Update currently maintained parameters.

    Call this every time the parameters are updated, such as the result of
    the `optimizer.step()` call.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the same set of
        parameters used to initialize this object.
    """
    decay = self.decay
    if self.num_updates is not None:
      self.num_updates += 1
      # Use a smaller effective decay during the first updates.
      decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
      trainable = [p for p in parameters if p.requires_grad]
      for s_param, param in zip(self.shadow_params, trainable):
        s_param.sub_(one_minus_decay * (s_param - param))

  def copy_to(self, parameters):
    """Copy current parameters into given collection of parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored moving averages.
    """
    trainable = [p for p in parameters if p.requires_grad]
    for s_param, param in zip(self.shadow_params, trainable):
      param.data.copy_(s_param.data)

  def store(self, parameters):
    """Save the current parameters for restoring later.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        temporarily stored.
    """
    self.collected_params = [param.clone() for param in parameters]

  def restore(self, parameters):
    """Restore the parameters stored with the `store` method.

    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored parameters.
    """
    for c_param, param in zip(self.collected_params, parameters):
      param.data.copy_(c_param.data)

  def state_dict(self):
    """Return the decay, update counter and shadow parameters as a dict."""
    return dict(decay=self.decay, num_updates=self.num_updates,
                shadow_params=self.shadow_params)

  def load_state_dict(self, state_dict):
    """Load the fields produced by `state_dict`."""
    self.decay = state_dict['decay']
    self.num_updates = state_dict['num_updates']
    self.shadow_params = state_dict['shadow_params']

@ -0,0 +1,771 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""Common layers for defining score networks.
import math
import string
from functools import partial
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .normalization import ConditionalInstanceNorm3dPlus
def get_act(config):
  """Build the activation module named by `config.model.nonlinearity`.

  The name is matched case-insensitively against 'elu', 'relu', 'lrelu'
  (leaky ReLU with slope 0.2) and 'swish' (SiLU).

  Raises:
    NotImplementedError: If the name is none of the above.
  """
  name = config.model.nonlinearity.lower()
  factories = {
      'elu': nn.ELU,
      'relu': nn.ReLU,
      'lrelu': lambda: nn.LeakyReLU(negative_slope=0.2),
      'swish': nn.SiLU,
  }
  factory = factories.get(name)
  if factory is None:
    raise NotImplementedError('activation function does not exist!')
  return factory()
def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
  """1x1 convolution. Same as NCSNv1/v2.

  Builds an `nn.Conv3d` with kernel size 1 and PyTorch's default
  initialization, then multiplies weight (and bias, if present) by
  `init_scale`. An `init_scale` of exactly 0 is replaced by 1e-10.

  Returns:
    The configured `nn.Conv3d` module.
  """
  init_scale = 1e-10 if init_scale == 0 else init_scale
  conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias,
                   dilation=dilation, padding=padding)
  conv.weight.data *= init_scale
  # `conv.bias` is None when bias=False; scaling it unconditionally would raise.
  if conv.bias is not None:
    conv.bias.data *= init_scale
  return conv
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
  """Ported from JAX.

  Returns an initializer `init(shape, dtype, device)` that samples a tensor
  with variance `scale / denominator`, where the denominator is the fan-in,
  fan-out or their average depending on `mode`.

  Args:
    scale: Numerator of the target variance.
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: "normal" or "uniform".
    in_axis: Index of the input-feature axis in the weight shape.
    out_axis: Index of the output-feature axis in the weight shape.
    dtype: Default dtype of the sampled tensor.
      NOTE(review): default restored as torch.float32 — confirm.
    device: Default device of the sampled tensor.
      NOTE(review): default restored as 'cpu' — confirm.

  Raises (when the returned initializer is called):
    ValueError: For an unknown `mode` or `distribution`.
  """

  def _compute_fans(shape, in_axis=1, out_axis=0):
    # Every axis other than the in/out axes counts as receptive field.
    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
    fan_in = shape[in_axis] * receptive_field_size
    fan_out = shape[out_axis] * receptive_field_size
    return fan_in, fan_out

  def init(shape, dtype=dtype, device=device):
    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
    if mode == "fan_in":
      denominator = fan_in
    elif mode == "fan_out":
      denominator = fan_out
    elif mode == "fan_avg":
      denominator = (fan_in + fan_out) / 2
    else:
      raise ValueError(
        "invalid mode for variance scaling initializer: {}".format(mode))
    variance = scale / denominator
    if distribution == "normal":
      return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
    elif distribution == "uniform":
      # U(-a, a) has variance a^2 / 3, hence a = sqrt(3 * variance).
      return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
    else:
      raise ValueError("invalid distribution for variance scaling initializer")

  return init
def default_init(scale=1.):
  """The same initialization used in DDPM.

  Returns a `variance_scaling` initializer in 'fan_avg' / 'uniform' mode. A
  scale of exactly 0 is replaced by 1e-10.
  """
  effective_scale = scale
  if effective_scale == 0:
    effective_scale = 1e-10
  return variance_scaling(effective_scale, 'fan_avg', 'uniform')
class Dense(nn.Module):
  """Linear layer with `default_init`."""

  def __init__(self):
    # NOTE(review): no layers are defined in the visible code; this only
    # initializes the `nn.Module` base.
    super().__init__()
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
  """1x1 convolution with DDPM initialization.

  Builds an `nn.Conv3d` with kernel size 1 and replaces its weight with a
  fresh sample from `default_init(init_scale)`.
  """
  layer = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
  weight_init = default_init(init_scale)
  layer.weight.data = weight_init(layer.weight.data.shape)
  return layer
def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
  """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.

  Builds an `nn.Conv3d` with kernel size 3, then multiplies weight (and
  bias, if present) by `init_scale`. An `init_scale` of exactly 0 is replaced
  by 1e-10.

  Returns:
    The configured `nn.Conv3d` module.
  """
  init_scale = 1e-10 if init_scale == 0 else init_scale
  conv = nn.Conv3d(in_planes, out_planes, stride=stride, bias=bias,
                   dilation=dilation, padding=padding, kernel_size=3)
  conv.weight.data *= init_scale
  # Several blocks in this file request bias=False, which leaves `conv.bias`
  # as None; only scale the bias when it exists.
  if conv.bias is not None:
    conv.bias.data *= init_scale
  return conv
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
  """3x3 convolution with DDPM initialization.

  Builds an `nn.Conv3d` with kernel size 3 and replaces its weight with a
  fresh sample from `default_init(init_scale)`.
  """
  layer = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
  weight_init = default_init(init_scale)
  layer.weight.data = weight_init(layer.weight.data.shape)
  return layer
def ddpm_conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
  """5x5 convolution with DDPM initialization (stride 2 by default).

  Builds an `nn.Conv3d` with kernel size 5 and replaces its weight with a
  fresh sample from `default_init(init_scale)`.
  """
  layer = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
  weight_init = default_init(init_scale)
  layer.weight.data = weight_init(layer.weight.data.shape)
  return layer
def ddpm_conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
  """5x5 transposed convolution with DDPM initialization (stride 2 by default).

  Builds an `nn.ConvTranspose3d` with kernel size 5 and replaces its weight
  with a fresh sample from `default_init(init_scale)`.
  """
  # NOTE(review): `output_padding` has two entries although the layer is 3D;
  # this looks like a leftover from a 2D version — confirm the intended
  # per-axis values.
  layer = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                             dilation=dilation, bias=bias, output_padding=(0, 1))
  weight_init = default_init(init_scale)
  layer.weight.data = weight_init(layer.weight.data.shape)
  return layer
def ddpm_conv6x6_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
  """6x6 transposed convolution with DDPM initialization (stride 2 by default).

  Builds an `nn.ConvTranspose3d` with kernel size 6 and replaces its weight
  with a fresh sample from `default_init(init_scale)`.
  """
  layer = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=6, stride=stride, padding=padding,
                             dilation=dilation, bias=bias)
  weight_init = default_init(init_scale)
  layer.weight.data = weight_init(layer.weight.data.shape)
  return layer
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
# https://github.com/ermongroup/ncsn
# https://github.com/ermongroup/ncsnv2
class CRPBlock(nn.Module):
  """Chain of pool -> 3x3 conv stages whose outputs are accumulated onto the input.

  Args:
    features: Channel count of input and output.
    n_stages: Number of pool/conv stages.
    act: Activation applied once to the input.
    maxpool: Use max pooling if true, average pooling otherwise (both with
      kernel 5, stride 1, padding 2, so the spatial size is preserved).
  """

  def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
    super().__init__()
    self.convs = nn.ModuleList()
    for i in range(n_stages):
      self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
    self.n_stages = n_stages
    if maxpool:
      self.pool = nn.MaxPool3d(kernel_size=5, stride=1, padding=2)
    else:
      self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2)

    self.act = act

  def forward(self, x):
    x = self.act(x)
    path = x
    for i in range(self.n_stages):
      # Each stage refines the previous stage's output, and every stage's
      # result is added to the running sum.
      path = self.pool(path)
      path = self.convs[i](path)
      x = path + x
    return x
class CondCRPBlock(nn.Module):
  """Conditional variant of the chained pooling block.

  Each stage applies a conditional normalizer, average pooling (kernel 5,
  stride 1, padding 2) and a 3x3 convolution, and its output is added to the
  running sum.

  Args:
    features: Channel count of input and output.
    n_stages: Number of norm/pool/conv stages.
    num_classes: Passed to `normalizer` as its second argument.
    normalizer: Callable building a module invoked as `norm(x, y)`.
    act: Activation applied once to the input.
  """

  def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
    super().__init__()
    self.convs = nn.ModuleList()
    self.norms = nn.ModuleList()
    self.normalizer = normalizer
    for i in range(n_stages):
      self.norms.append(normalizer(features, num_classes, bias=True))
      self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))

    self.n_stages = n_stages
    self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2)
    self.act = act

  def forward(self, x, y):
    x = self.act(x)
    path = x
    for i in range(self.n_stages):
      path = self.norms[i](path, y)
      path = self.pool(path)
      path = self.convs[i](path)
      x = path + x
    return x
class RCUBlock(nn.Module):
  """Stack of residual units, each made of `n_stages` act -> 3x3 conv steps.

  The convolutions are registered as attributes named '{block}_{stage}_conv'
  (1-based), which therefore also determines their state-dict keys.

  Args:
    features: Channel count of input and output.
    n_blocks: Number of residual units.
    n_stages: Number of act/conv steps inside each unit.
    act: Activation module.
  """

  def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
    super().__init__()

    for i in range(n_blocks):
      for j in range(n_stages):
        setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))

    self.stride = 1
    self.n_blocks = n_blocks
    self.n_stages = n_stages
    self.act = act

  def forward(self, x):
    for i in range(self.n_blocks):
      residual = x
      for j in range(self.n_stages):
        x = self.act(x)
        x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)

      # In-place add of the unit's input onto the conv output.
      x += residual
    return x
class CondRCUBlock(nn.Module):
  """Conditional residual units: `n_stages` norm -> act -> 3x3 conv steps each.

  Norms and convolutions are registered as attributes named
  '{block}_{stage}_norm' / '{block}_{stage}_conv' (1-based).

  Args:
    features: Channel count of input and output.
    n_blocks: Number of residual units.
    n_stages: Number of norm/act/conv steps inside each unit.
    num_classes: Passed to `normalizer` as its second argument.
    normalizer: Callable building a module invoked as `norm(x, y)`.
    act: Activation module.
  """

  def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
    super().__init__()

    for i in range(n_blocks):
      for j in range(n_stages):
        setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
        setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))

    self.stride = 1
    self.n_blocks = n_blocks
    self.n_stages = n_stages
    self.act = act
    self.normalizer = normalizer

  def forward(self, x, y):
    for i in range(self.n_blocks):
      residual = x
      for j in range(self.n_stages):
        x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
        x = self.act(x)
        x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)

      # In-place add of the unit's input onto the conv output.
      x += residual
    return x
class MSFBlock(nn.Module):
  """Fuse several feature maps: 3x3 conv each, resize to a common size, sum.

  Args:
    in_planes: List/tuple with the channel count of every input.
    features: Channel count of the fused output.
  """

  def __init__(self, in_planes, features):
    super().__init__()
    assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
    self.convs = nn.ModuleList()
    self.features = features

    for i in range(len(in_planes)):
      self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))

  def forward(self, xs, shape):
    """Return the sum of all inputs after projection and resizing to `shape`.

    Args:
      xs: Sequence of tensors, one per entry of `in_planes`.
      shape: Target spatial size passed to `F.interpolate`.
    """
    sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
    for i in range(len(self.convs)):
      h = self.convs[i](xs[i])
      # The convolutions are 3D, so their batched output is 5D; linear
      # interpolation of 5D tensors requires mode='trilinear' ('bilinear' is
      # only defined for 4D inputs).
      h = F.interpolate(h, size=shape, mode='trilinear', align_corners=True)
      sums += h
    return sums
class CondMSFBlock(nn.Module):
  """Conditional fusion: norm -> 3x3 conv each input, resize, then sum.

  Args:
    in_planes: List/tuple with the channel count of every input.
    features: Channel count of the fused output.
    num_classes: Passed to `normalizer` as its second argument.
    normalizer: Callable building a module invoked as `norm(x, y)`.
  """

  def __init__(self, in_planes, features, num_classes, normalizer):
    super().__init__()
    assert isinstance(in_planes, list) or isinstance(in_planes, tuple)

    self.convs = nn.ModuleList()
    self.norms = nn.ModuleList()
    self.features = features
    self.normalizer = normalizer

    for i in range(len(in_planes)):
      self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
      self.norms.append(normalizer(in_planes[i], num_classes, bias=True))

  def forward(self, xs, y, shape):
    """Return the sum of all inputs after norm, projection and resizing.

    Args:
      xs: Sequence of tensors, one per entry of `in_planes`.
      y: Conditioning labels forwarded to every normalizer.
      shape: Target spatial size passed to `F.interpolate`.
    """
    sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
    for i in range(len(self.convs)):
      h = self.norms[i](xs[i], y)
      h = self.convs[i](h)
      # 5D conv output: linear interpolation must use mode='trilinear'
      # ('bilinear' is only defined for 4D inputs).
      h = F.interpolate(h, size=shape, mode='trilinear', align_corners=True)
      sums += h
    return sums
class RefineBlock(nn.Module):
  """RefineNet-style block: adapt each input, fuse, chained pooling, output convs.

  Args:
    in_planes: List/tuple with the channel count of every input.
    features: Channel count after fusion.
    act: Activation module shared by the sub-blocks.
    start: If true, no fusion block (`self.msf`) is created.
    end: If true, the output RCU block uses 3 residual units instead of 1.
    maxpool: Forwarded to `CRPBlock`.
  """

  def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
    super().__init__()

    assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
    self.n_blocks = n_blocks = len(in_planes)

    self.adapt_convs = nn.ModuleList()
    for i in range(n_blocks):
      self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))

    self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)

    if not start:
      self.msf = MSFBlock(in_planes, features)

    self.crp = CRPBlock(features, 2, act, maxpool=maxpool)

  def forward(self, xs, output_shape):
    assert isinstance(xs, tuple) or isinstance(xs, list)
    hs = []
    for i in range(len(xs)):
      h = self.adapt_convs[i](xs[i])
      # Collect every adapted input; they are fused (or used directly) below.
      hs.append(h)

    if self.n_blocks > 1:
      h = self.msf(hs, output_shape)
    else:
      h = hs[0]

    h = self.crp(h)
    h = self.output_convs(h)

    return h
class CondRefineBlock(nn.Module):
  """Conditional RefineNet-style block.

  Adapts each input with a `CondRCUBlock`, fuses them with `CondMSFBlock`
  (when there is more than one), then applies `CondCRPBlock` and a final
  `CondRCUBlock`.

  Args:
    in_planes: List/tuple with the channel count of every input.
    features: Channel count after fusion.
    num_classes: Forwarded to the conditional sub-blocks.
    normalizer: Forwarded to the conditional sub-blocks.
    act: Activation module shared by the sub-blocks.
    start: If true, no fusion block (`self.msf`) is created.
    end: If true, the output RCU block uses 3 residual units instead of 1.
  """

  def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
    super().__init__()

    assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
    self.n_blocks = n_blocks = len(in_planes)

    self.adapt_convs = nn.ModuleList()
    for i in range(n_blocks):
      # The adapt block must be registered; `forward` indexes
      # `self.adapt_convs[i]`.
      self.adapt_convs.append(
        CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
      )

    self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)

    if not start:
      self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)

    self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)

  def forward(self, xs, y, output_shape):
    assert isinstance(xs, tuple) or isinstance(xs, list)
    hs = []
    for i in range(len(xs)):
      h = self.adapt_convs[i](xs[i], y)
      hs.append(h)

    if self.n_blocks > 1:
      h = self.msf(hs, y, output_shape)
    else:
      h = hs[0]

    h = self.crp(h, y)
    h = self.output_convs(h, y)

    return h
class ConvMeanPool(nn.Module):
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
if not adjust_padding:
conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
self.conv = conv
conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
self.conv = nn.Sequential(
nn.ZeroPad3d((1, 0, 1, 0)),
def forward(self, inputs):
output = self.conv(inputs)
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
return output
class MeanPoolConv(nn.Module):
  """2x2 mean over tensor axes 2 and 3, followed by a convolution.

  Args:
    input_dim: Input channels.
    output_dim: Output channels.
    kernel_size: Convolution kernel size (padding is `kernel_size // 2`).
    biases: Whether the convolution has a bias.
  """

  def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
    super().__init__()
    self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)

  def forward(self, inputs):
    output = inputs
    # NOTE(review): only axes 2 and 3 are subsampled; for a (B, C, D, H, W)
    # tensor the last axis keeps its size — confirm this is intended.
    output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
                  output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
    return self.conv(output)
class UpsampleConv(nn.Module):
  """Repeat channels 4x, pixel-shuffle by a factor of 2, then convolve.

  Args:
    input_dim: Input channels.
    output_dim: Output channels.
    kernel_size: Convolution kernel size (padding is `kernel_size // 2`).
    biases: Whether the convolution has a bias.
  """

  def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
    super().__init__()
    self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
    self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

  def forward(self, inputs):
    output = inputs
    output = torch.cat([output, output, output, output], dim=1)
    # NOTE(review): `nn.PixelShuffle` rearranges the third-from-last axis into
    # the last two; for a 5D (B, C, D, H, W) tensor that axis is D, not the
    # channel axis, so this looks like an incomplete 2D -> 3D port — confirm
    # before using this module.
    output = self.pixelshuffle(output)
    return self.conv(output)
class ConditionalResidualBlock(nn.Module):
  """Pre-activation residual block with conditional normalization.

  The main path is norm -> act -> conv1 -> norm -> act -> conv2; the shortcut
  is the identity when the channel count is unchanged and `resample` is None,
  otherwise a learned projection.

  Args:
    input_dim: Input channels.
    output_dim: Output channels.
    num_classes: Passed to `normalization` as its second argument.
    resample: 'down' or None. Any other value (including the default of 1)
      raises.
    act: Activation module.
    normalization: Callable building a module invoked as `norm(x, y)`.
    adjust_padding: Forwarded to `ConvMeanPool` in the 'down' path.
    dilation: Convolution dilation; None is treated as 1 (no dilation).

  Raises:
    Exception: If `resample` is neither 'down' nor None.
  """

  def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
               normalization=ConditionalInstanceNorm3dPlus, adjust_padding=False, dilation=None):
    super().__init__()
    self.non_linearity = act
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.resample = resample
    self.normalization = normalization
    # The default is None, and `None > 1` raises TypeError on Python 3.
    if dilation is None:
      dilation = 1
    if resample == 'down':
      if dilation > 1:
        self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
        self.normalize2 = normalization(input_dim, num_classes)
        self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
      else:
        self.conv1 = ncsn_conv3x3(input_dim, input_dim)
        self.normalize2 = normalization(input_dim, num_classes)
        self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
        conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

    elif resample is None:
      if dilation > 1:
        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
        self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
        self.normalize2 = normalization(output_dim, num_classes)
        self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
      else:
        # A bare `nn.Conv3d` cannot be called with only the two channel
        # arguments (kernel_size is required); use a 1x1 projection, as the
        # unconditional ResidualBlock does.
        conv_shortcut = partial(nn.Conv3d, kernel_size=1)
        self.conv1 = ncsn_conv3x3(input_dim, output_dim)
        self.normalize2 = normalization(output_dim, num_classes)
        self.conv2 = ncsn_conv3x3(output_dim, output_dim)
    else:
      raise Exception('invalid resample value')

    if output_dim != input_dim or resample is not None:
      self.shortcut = conv_shortcut(input_dim, output_dim)

    self.normalize1 = normalization(input_dim, num_classes)

  def forward(self, x, y):
    output = self.normalize1(x, y)
    output = self.non_linearity(output)
    output = self.conv1(output)
    output = self.normalize2(output, y)
    output = self.non_linearity(output)
    output = self.conv2(output)

    if self.output_dim == self.input_dim and self.resample is None:
      shortcut = x
    else:
      shortcut = self.shortcut(x)

    return shortcut + output
class ResidualBlock(nn.Module):
  """Pre-activation residual block with unconditional normalization.

  The main path is norm -> act -> conv1 -> norm -> act -> conv2; the shortcut
  is the identity when the channel count is unchanged and `resample` is None,
  otherwise a learned projection.

  Args:
    input_dim: Input channels.
    output_dim: Output channels.
    resample: 'down' or None.
    act: Activation module.
    normalization: Callable building a module from a channel count.
    adjust_padding: Forwarded to `ConvMeanPool` in the 'down' path.
    dilation: Convolution dilation; values above 1 select dilated convs.

  Raises:
    Exception: If `resample` is neither 'down' nor None.
  """

  def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
               normalization=nn.InstanceNorm3d, adjust_padding=False, dilation=1):
    super().__init__()
    self.non_linearity = act
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.resample = resample
    self.normalization = normalization
    if resample == 'down':
      if dilation > 1:
        self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
        self.normalize2 = normalization(input_dim)
        self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
      else:
        self.conv1 = ncsn_conv3x3(input_dim, input_dim)
        self.normalize2 = normalization(input_dim)
        self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
        conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

    elif resample is None:
      if dilation > 1:
        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
        self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
        self.normalize2 = normalization(output_dim)
        self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
      else:
        # A bare `nn.Conv3d` would need an explicit kernel_size, so the
        # shortcut uses the 1x1 helper instead.
        conv_shortcut = partial(ncsn_conv1x1)
        self.conv1 = ncsn_conv3x3(input_dim, output_dim)
        self.normalize2 = normalization(output_dim)
        self.conv2 = ncsn_conv3x3(output_dim, output_dim)
    else:
      raise Exception('invalid resample value')

    if output_dim != input_dim or resample is not None:
      self.shortcut = conv_shortcut(input_dim, output_dim)

    self.normalize1 = normalization(input_dim)

  def forward(self, x):
    output = self.normalize1(x)
    output = self.non_linearity(output)
    output = self.conv1(output)
    output = self.normalize2(output)
    output = self.non_linearity(output)
    output = self.conv2(output)

    if self.output_dim == self.input_dim and self.resample is None:
      shortcut = x
    else:
      shortcut = self.shortcut(x)

    return shortcut + output
# Functions below are ported over from the DDPM codebase:
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  """Sinusoidal embedding of a 1-D batch of timesteps.

  The first half of each row holds sines and the second half cosines of
  `timestep * frequency`, with frequencies spaced geometrically from 1 down
  to `1 / max_positions`. Odd `embedding_dim` is zero-padded by one column.

  Args:
    timesteps: 1-D tensor with one timestep per sample.
    embedding_dim: Width of the returned embedding.
    max_positions: Ratio between the largest and smallest frequency.

  Returns:
    Float tensor of shape `(len(timesteps), embedding_dim)` on the device of
    `timesteps`.
  """
  assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
  half_dim = embedding_dim // 2
  # magic number 10000 is from transformers
  # With a single frequency (half_dim == 1) the step size is irrelevant, so
  # clamp the divisor to avoid dividing by zero for embedding_dim of 2 or 3.
  emb = math.log(max_positions) / max(half_dim - 1, 1)
  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
  emb = timesteps.float()[:, None] * emb[None, :]
  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
  if embedding_dim % 2 == 1:  # zero pad
    emb = F.pad(emb, (0, 1), mode='constant')
  assert emb.shape == (timesteps.shape[0], embedding_dim)
  return emb
def _einsum(a, b, c, x, y):
  """Run `torch.einsum` with subscripts given as sequences of letters.

  `a` and `b` label the axes of `x` and `y`; `c` labels the output axes.
  """
  lhs_x, lhs_y, rhs = ''.join(a), ''.join(b), ''.join(c)
  return torch.einsum('{},{}->{}'.format(lhs_x, lhs_y, rhs), x, y)


def contract_inner(x, y):
  """tensordot(x, y, 1)."""
  n_x, n_y = len(x.shape), len(y.shape)
  letters = string.ascii_lowercase
  x_axes = list(letters[:n_x])
  y_axes = list(letters[n_x:n_x + n_y])
  # Sharing one letter sums the last axis of x against the first axis of y.
  y_axes[0] = x_axes[-1]
  result_axes = x_axes[:-1] + y_axes[1:]
  return _einsum(x_axes, y_axes, result_axes, x, y)
class NIN(nn.Module):
  """Per-position linear map over the channel axis of a 5D tensor.

  Args:
    in_dim: Input channels.
    num_units: Output channels.
    init_scale: Scale passed to `default_init` for the weight matrix.
  """

  def __init__(self, in_dim, num_units, init_scale=0.1):
    super().__init__()
    self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
    self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

  def forward(self, x):
    # Move channels last so the contraction acts on the channel axis, then
    # move the new channel axis back to position 1.
    x = x.permute(0, 2, 3, 4, 1)
    y = contract_inner(x, self.W) + self.b
    return y.permute(0, 4, 1, 2, 3)
class AttnBlock(nn.Module):
  """Channel-wise self-attention block.

  Every spatial position attends to every other position of the same sample;
  the attention weight tensor therefore has `(D*H*W)^2` entries per sample.

  Args:
    channels: Channel count of input and output (GroupNorm uses 32 groups).
  """

  def __init__(self, channels):
    super().__init__()
    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
    self.NIN_0 = NIN(channels, channels)
    self.NIN_1 = NIN(channels, channels)
    self.NIN_2 = NIN(channels, channels)
    self.NIN_3 = NIN(channels, channels, init_scale=0.)

  def forward(self, x):
    B, C, D, H, W = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    # Scaled dot product between every query position (d, h, w) and every key
    # position (k, i, j).
    w = torch.einsum('bcdhw,bckij->bdhwkij', q, k) * (int(C) ** (-0.5))
    # Flatten the key positions so softmax normalizes over all of them.
    w = torch.reshape(w, (B, D, H, W, D * H * W))
    w = F.softmax(w, dim=-1)
    w = torch.reshape(w, (B, D, H, W, D, H, W))
    h = torch.einsum('bdhwkij,bckij->bcdhw', w, v)
    h = self.NIN_3(h)
    return x + h
class Upsample(nn.Module):
  """Nearest-neighbour 2x upsampling of a 5D tensor, optionally followed by a 3x3 conv.

  Args:
    channels: Channel count (only used by the optional convolution).
    with_conv: If true, apply `ddpm_conv3x3(channels, channels)` afterwards.
  """

  def __init__(self, channels, with_conv=False):
    super().__init__()
    if with_conv:
      self.Conv_0 = ddpm_conv3x3(channels, channels)
    self.with_conv = with_conv

  def forward(self, x):
    B, C, D, H, W = x.shape
    h = F.interpolate(x, (D * 2, H * 2, W * 2), mode='nearest')
    if self.with_conv:
      h = self.Conv_0(h)
    return h
class Downsample(nn.Module):
  """Halve the three spatial axes of a 5D tensor.

  Uses a stride-2 3x3 convolution when `with_conv` is true, otherwise 2x2x2
  average pooling.

  Args:
    channels: Channel count (only used by the optional convolution).
    with_conv: Select the strided convolution instead of average pooling.
  """

  def __init__(self, channels, with_conv=False):
    super().__init__()
    if with_conv:
      self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
    self.with_conv = with_conv

  def forward(self, x):
    B, C, D, H, W = x.shape
    # Emulate 'SAME' padding
    if self.with_conv:
      # Pad one element at the end of each spatial axis; the conv itself
      # uses padding=0.
      x = F.pad(x, (0, 1, 0, 1, 0, 1))
      x = self.Conv_0(x)
    else:
      x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0)

    assert x.shape == (B, C, D // 2, H // 2, W // 2)
    return x
class ResnetBlockDDPM(nn.Module):
  """The ResNet Blocks used in DDPM.

  Main path: GroupNorm -> act -> conv -> (+ time-embedding bias) ->
  GroupNorm -> act -> dropout -> conv. When the channel count changes, the
  skip path is projected with a 3x3 conv (`conv_shortcut=True`) or a NIN.

  Args:
    act: Activation module.
    in_ch: Input channels (GroupNorm uses 32 groups).
    out_ch: Output channels; defaults to `in_ch`.
    temb_dim: Width of the time embedding, or None for no conditioning layer.
    conv_shortcut: Use a 3x3 conv instead of a NIN on the skip path.
    dropout: Dropout probability.
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
    super().__init__()
    if out_ch is None:
      out_ch = in_ch
    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
    self.act = act
    self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)

    self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
    if in_ch != out_ch:
      if conv_shortcut:
        self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
      else:
        self.NIN_0 = NIN(in_ch, out_ch)
    self.out_ch = out_ch
    self.in_ch = in_ch
    self.conv_shortcut = conv_shortcut

  def forward(self, x, temb=None):
    B, C, D, H, W = x.shape
    assert C == self.in_ch
    out_ch = self.out_ch if self.out_ch else self.in_ch
    h = self.act(self.GroupNorm_0(x))
    h = self.Conv_0(h)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += self.Dense_0(self.act(temb))[:, :, None, None, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)
    if C != out_ch:
      if self.conv_shortcut:
        x = self.Conv_2(x)
      else:
        x = self.NIN_0(x)
    return x + h
# class PositionalEncoding(nn.Module):
# def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
# super().__init__()
# self.dropout = nn.Dropout(p=dropout)
# position = torch.arange(max_len).unsqueeze(1)
# div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
# pe = torch.zeros(max_len, 1, d_model)
# pe[:, 0, 0::2] = torch.sin(position * div_term)
# pe[:, 0, 1::2] = torch.cos(position * div_term)
# self.register_buffer('pe', pe)
# def forward(self, x: Tensor) -> Tensor:
# """
# Args:
# x: Tensor, shape [seq_len, batch_size, embedding_dim]
# """
# x = x + self.pe[:x.size(0)]
# return self.dropout(x)
class ResnetBlockDDPMPosEncoding(nn.Module):
  """DDPM ResNet block with an added sinusoidal positional-encoding branch.

  Identical to the plain DDPM block except that a 3x3 convolution of a
  precomputed cos/sin encoding of the voxel grid is added after the first
  convolution. The encoding is built for a cubic grid of side `img_size`, so
  the block expects inputs with that spatial size.

  Args:
    act: Activation module.
    in_ch: Input channels (GroupNorm uses 32 groups).
    out_ch: Output channels; defaults to `in_ch`.
    temb_dim: Width of the time embedding, or None for no conditioning layer.
    conv_shortcut: Use a 3x3 conv instead of a NIN on the skip path.
    dropout: Dropout probability.
    img_size: Side length of the cubic grid the encoding is built for.
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, img_size=64):
    super().__init__()

    ##### Pos Encoding
    coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size))
    coords = torch.stack([coord_x, coord_y, coord_z])

    self.num_freq = int(np.log2(img_size))
    # One cos and one sin channel per frequency and per coordinate axis.
    pos_encoding = torch.zeros(1, 2 * self.num_freq, 3, img_size, img_size, img_size)
    with torch.no_grad():
      for i in range(self.num_freq):
        pos_encoding[0, 2*i, :, :, :, :] = torch.cos((i+1) * np.pi * coords)
        pos_encoding[0, 2*i + 1, :, :, :, :] = torch.sin((i+1) * np.pi * coords)
    # NOTE(review): stored as a non-trainable parameter, on the assumption
    # that the fixed encoding should not be learned — confirm.
    self.pos_encoding = nn.Parameter(
      pos_encoding.view(1, 2 * self.num_freq * 3, img_size, img_size, img_size) / img_size,
      requires_grad=False)

    if out_ch is None:
      out_ch = in_ch
    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
    self.act = act
    self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
    self.Conv_0_pos = ddpm_conv3x3(2 * self.num_freq * 3, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)

    self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
    if in_ch != out_ch:
      if conv_shortcut:
        self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
      else:
        self.NIN_0 = NIN(in_ch, out_ch)
    self.out_ch = out_ch
    self.in_ch = in_ch
    self.conv_shortcut = conv_shortcut

  def forward(self, x, temb=None):
    B, C, D, H, W = x.shape
    assert C == self.in_ch
    out_ch = self.out_ch if self.out_ch else self.in_ch
    h = self.act(self.GroupNorm_0(x))
    # The positional branch has batch size 1 and is expanded to the batch.
    h = self.Conv_0(h) + self.Conv_0_pos(self.pos_encoding).expand(h.size(0), -1, -1, -1, -1)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += self.Dense_0(self.act(temb))[:, :, None, None, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)
    if C != out_ch:
      if self.conv_shortcut:
        x = self.Conv_2(x)
      else:
        x = self.NIN_0(x)
    return x + h

@ -0,0 +1,223 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers."""
import torch.nn as nn
import torch
import functools
def get_normalization(config, conditional=False):
  """Obtain normalization modules from the config file.

  Reads `config.model.normalization` and returns the matching layer class
  (or, for the conditional case, a partial bound to
  `config.model.num_classes`).

  Raises:
    NotImplementedError: For a conditional norm other than 'InstanceNorm++'.
    ValueError: For an unknown unconditional norm name.
  """
  norm = config.model.normalization

  if conditional:
    if norm != 'InstanceNorm++':
      raise NotImplementedError(f'{norm} not implemented yet.')
    return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)

  if norm == 'InstanceNorm':
    return nn.InstanceNorm3d
  if norm == 'InstanceNorm++':
    return InstanceNorm3dPlus
  if norm == 'VarianceNorm':
    return VarianceNorm3d
  if norm == 'GroupNorm':
    return nn.GroupNorm
  raise ValueError('Unknown normalization: %s' % norm)
class ConditionalBatchNorm3d(nn.Module):
  """Batch norm whose scale (and optional shift) is looked up per class label.

  Args:
    num_features: Channel count of the 5D input.
    num_classes: Number of rows in the embedding table.
    bias: If true, the embedding holds both scale and shift per class.
  """

  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.bn = nn.BatchNorm3d(num_features, affine=False)
    if self.bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      # NOTE(review): this table keeps nn.Embedding's default initialization.
      self.embed = nn.Embedding(num_classes, num_features)

  def forward(self, x, y):
    out = self.bn(x)
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * out
    return out
class ConditionalInstanceNorm3d(nn.Module):
  """Instance norm whose scale (and optional shift) is looked up per class label.

  Args:
    num_features: Channel count of the 5D input.
    num_classes: Number of rows in the embedding table.
    bias: If true, the embedding holds both scale and shift per class.
  """

  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
    if bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      # NOTE(review): this table keeps nn.Embedding's default initialization.
      self.embed = nn.Embedding(num_classes, num_features)

  def forward(self, x, y):
    h = self.instance_norm(x)
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=-1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out
class ConditionalVarianceNorm3d(nn.Module):
  """Divide by the per-channel spatial std, then scale by a per-class gain.

  Args:
    num_features: Channel count of the 5D input.
    num_classes: Number of rows in the gain embedding (initialized N(1, 0.02)).
    bias: Stored on the module; no shift is applied in `forward`.
  """

  def __init__(self, num_features, num_classes, bias=False):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.embed = nn.Embedding(num_classes, num_features)
    self.embed.weight.data.normal_(1, 0.02)

  def forward(self, x, y):
    # Variance over the three spatial axes; the mean is not subtracted from x.
    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
    h = x / torch.sqrt(vars + 1e-5)

    gamma = self.embed(y)
    out = gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out
class VarianceNorm3d(nn.Module):
  """Divide by the per-channel spatial std, then scale by a learned gain.

  Args:
    num_features: Channel count of the 5D input.
    bias: Stored on the module; no shift is applied in `forward`.
  """

  def __init__(self, num_features, bias=False):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.alpha = nn.Parameter(torch.zeros(num_features))
    self.alpha.data.normal_(1, 0.02)

  def forward(self, x):
    # Variance over the three spatial axes; the mean is not subtracted from x.
    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
    h = x / torch.sqrt(vars + 1e-5)

    out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h
    return out
class ConditionalNoneNorm3d(nn.Module):
  """Per-class scale (and optional shift) without any normalization.

  Args:
    num_features: Channel count of the 5D input.
    num_classes: Number of rows in the embedding table.
    bias: If true, the embedding holds both scale and shift per class.
  """

  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    if bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      # NOTE(review): this table keeps nn.Embedding's default initialization.
      self.embed = nn.Embedding(num_classes, num_features)

  def forward(self, x, y):
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=-1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * x
    return out
class NoneNorm3d(nn.Module):
    """Identity layer with the constructor signature of the other norm layers.

    `num_features` and `bias` are accepted and ignored.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        """Return `x` unchanged."""
        return x
class InstanceNorm3dPlus(nn.Module):
    """Instance norm plus a term that re-injects per-channel mean information.

    The per-channel means (over dims 2-4) are standardised across channels,
    scaled by `alpha` and added to the instance-normalised input. The result
    is scaled by `gamma` and, when `bias` is set, shifted by `beta`.

    Args:
      num_features: Number of channels (size of dim 1 of the input).
      bias: Whether to create and apply the `beta` shift.
    """

    def __init__(self, num_features, bias=True):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        means = torch.mean(x, dim=(2, 3, 4))
        # Standardise the channel means across the channel axis.
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)
        # Shared by both the bias and no-bias paths.
        h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
        out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h
        if self.bias:
            out = out + self.beta.view(-1, self.num_features, 1, 1, 1)
        return out
class ConditionalInstanceNorm3dPlus(nn.Module):
    """Class-conditional variant of `InstanceNorm3dPlus`.

    `gamma`, `alpha` and (when `bias` is set) `beta` are looked up per class
    in a single embedding table and split with `chunk`.

    Args:
      num_features: Number of channels (size of dim 1 of the input).
      num_classes: Number of rows in the embedding table.
      bias: Whether the embedding also carries a `beta` shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 3)
            # Initialise scale at N(1, 0.02)
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)
            # Initialise bias at 0
            self.embed.weight.data[:, 2 * num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        means = torch.mean(x, dim=(2, 3, 4))
        # Standardise the channel means across the channel axis.
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)
        if self.bias:
            gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        else:
            gamma, alpha = self.embed(y).chunk(2, dim=-1)
            beta = None
        h = h + means[..., None, None, None] * alpha[..., None, None, None]
        out = gamma.view(-1, self.num_features, 1, 1, 1) * h
        if beta is not None:
            out = out + beta.view(-1, self.num_features, 1, 1, 1)
        return out
def lip_weight_normalization_3d(W, softplus_c):
    """Lipschitz weight normalization based on the L-infinity norm.

    See Eq.9 in [Liu et al 2022].

    Args:
      W: 5-D weight tensor; absolute values are summed over dims 1-4 for
        each index of dim 0.
      softplus_c: The bound that each row sum is compared against.

    Returns:
      `W` multiplied row-wise by `max(softplus_c / rowsum, 1)`.
    """
    # 1e-8 guards against division by zero for all-zero rows.
    absrowsum = torch.sum(torch.abs(W), dim=[1, 2, 3, 4]) + 1e-8
    # NOTE(review): relu(c / s - 1) + 1 equals max(c / s, 1), which scales
    # rows *up* to the bound; a Lipschitz clamp is usually min(c / s, 1) —
    # confirm against the cited equation.
    scale = torch.nn.functional.relu(softplus_c / absrowsum - 1.0) + 1.0
    return W * scale[:, None, None, None, None]

@ -0,0 +1,259 @@
"""Layers used for up-sampling or down-sampling images.

Many functions are ported from https://github.com/NVlabs/stylegan2.
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
# from op import upfirdn2d
# Function ported from StyleGAN2
def get_weight(module,
               shape,
               weight_var='weight',
               kernel_init=None):
    """Get/create weight tensor for a convolution or fully-connected layer.

    Args:
      module: Object exposing `param(name, init, shape)`.
      shape: Shape of the requested weight.
      weight_var: Name under which the weight is stored.
      kernel_init: Initialiser forwarded to `module.param`.

    Returns:
      Whatever `module.param(weight_var, kernel_init, shape)` returns.
    """
    # NOTE(review): the parameter list was rebuilt from the names the body
    # uses; confirm order and defaults against existing callers.
    return module.param(weight_var, kernel_init, shape)
class Conv3d(nn.Module):
    """Conv3d layer with optional fused upsampling or downsampling (StyleGAN2).

    Args:
      in_ch: Number of input channels.
      out_ch: Number of output channels.
      kernel: Odd kernel size used for all three spatial dimensions.
      up: Route `forward` through `upsample_conv_3d`.
      down: Route `forward` through `conv_downsample_3d`.
      resample_kernel: FIR filter forwarded to the resampling helpers.
      use_bias: Whether to create and add a per-channel bias.
      kernel_init: Optional callable mapping a shape to the initial weight.
    """

    def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
                 resample_kernel=(1, 3, 3, 1),
                 use_bias=True,
                 kernel_init=None):
        # nn.Module must be initialised before parameters can be assigned.
        super().__init__()
        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1
        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel, kernel))
        if kernel_init is not None:
            self.weight.data = kernel_init(self.weight.data.shape)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_ch))

        self.up = up
        self.down = down
        self.resample_kernel = resample_kernel
        self.kernel = kernel
        self.use_bias = use_bias

    def forward(self, x):
        if self.up:
            x = upsample_conv_3d(x, self.weight, k=self.resample_kernel)
        elif self.down:
            x = conv_downsample_3d(x, self.weight, k=self.resample_kernel)
        else:
            # Plain convolution; padding of kernel // 2 keeps the spatial size.
            x = F.conv3d(x, self.weight, stride=1, padding=self.kernel // 2)

        if self.use_bias:
            x = x + self.bias.reshape(1, -1, 1, 1, 1)
        return x
def naive_upsample_3d(x, factor=2):
    """Upsample a 5-D tensor by repeating every voxel `factor` times per axis."""
    channels, depth, height, width = x.shape[1:]
    # Insert a singleton axis after each spatial axis, then tile along it.
    expanded = x.reshape(-1, channels, depth, 1, height, 1, width, 1)
    tiled = expanded.repeat(1, 1, 1, factor, 1, factor, 1, factor)
    return tiled.reshape(-1, channels, depth * factor, height * factor, width * factor)
def naive_downsample_3d(x, factor=2):
    """Downsample a 5-D tensor by averaging `factor`-sized cubes."""
    channels, depth, height, width = x.shape[1:]
    # Split each spatial axis into (coarse, factor) and average the factor axes.
    blocks = x.reshape(-1, channels,
                       depth // factor, factor,
                       height // factor, factor,
                       width // factor, factor)
    return blocks.mean(dim=(3, 5, 7))
def upsample_conv_3d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_3d()` followed by `tf.nn.conv3d()`.

    Padding is performed only once at the beginning, not between the
    operations.
    The fused op is considerably more efficient than performing the same
    calculation using standard TensorFlow ops. It supports gradients of
    arbitrary order.

    NOTE(review): the shape notes below were inherited from a 2-D version
    and several statements still look 2-D specific (see inline notes) —
    confirm this function is exercised before relying on it.

    Args:
      x: Input tensor. Dims 1-4 are read here, so presumably
        `[N, C, D, H, W]` — TODO confirm.
      w: Weight tensor; asserted to be 5-D and read as
        `[outC, inC, convD, convH, convW]`.
      k: FIR filter of the shape `[firH, firW]` or `[firN]`
        (separable). The default is `[1] * factor`, which corresponds to
        nearest-neighbor upsampling.
      factor: Integer upsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      The output of `upfirdn2d` applied to the transposed-convolution
      result.
    """
    assert isinstance(factor, int) and factor >= 1
    # Check weight shape.
    assert len(w.shape) == 5
    convD = w.shape[2]
    convH = w.shape[3]
    convW = w.shape[4]
    inC = w.shape[1]
    outC = w.shape[0]
    # NOTE(review): only H == W is checked; convD is not compared.
    assert convW == convH
    # Setup filter kernel.
    if k is None:
        k = [1] * factor
    # NOTE(review): gain uses factor ** 2 (2-D area); confirm for 3-D.
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = (k.shape[0] - factor) - (convW - 1)
    stride = (factor, factor)
    # Determine data dimensions.
    # NOTE(review): `stride` is overwritten with a 4-element list whose first
    # entries are 1, and is then indexed 0..2 and passed to conv_transpose3d,
    # which takes one or three stride values — confirm intended value.
    stride = [1, 1, factor, factor]
    output_shape = ((_shape(x, 2) - 1) * factor + convD, (_shape(x, 3) - 1) * factor + convH, (_shape(x, 4) - 1) * factor + convW)
    output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convD,
                      output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convH,
                      output_shape[2] - (_shape(x, 4) - 1) * stride[2] - convW)
    assert output_padding[0] >= 0 and output_padding[1] >= 0 and output_padding[2] >= 0
    num_groups = _shape(x, 1) // inC
    # Transpose weights.
    w = torch.reshape(w, (num_groups, -1, inC, convD, convH, convW))
    # NOTE(review): negative-step slicing is not accepted by PyTorch tensor
    # indexing (torch.flip is the usual replacement), and only the last two
    # axes are reversed — confirm.
    w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4, 5)
    w = torch.reshape(w, (num_groups * inC, -1, convD, convH, convW))
    x = F.conv_transpose3d(x, w, stride=stride, output_padding=output_padding, padding=0)
    ## Original TF code.
    # x = tf.nn.conv3d_transpose(
    # x,
    # w,
    # output_shape=output_shape,
    # strides=stride,
    # padding='VALID',
    # data_format=data_format)
    ## JAX equivalent
    # NOTE(review): `upfirdn2d` is only referenced by a commented-out import
    # near the top of this file; confirm it is in scope and accepts 5-D input.
    return upfirdn2d(x, torch.tensor(k, device=x.device),
                     pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def conv_downsample_3d(x, w, k=None, factor=2, gain=1):
    """Fused `tf.nn.conv3d()` followed by `downsample_3d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same
    calculation using standard TensorFlow ops. It supports gradients of
    arbitrary order.

    NOTE(review): the shape notes below were inherited from a 2-D version —
    confirm this function is exercised before relying on it.

    Args:
      x: Input tensor — presumably `[N, C, D, H, W]`; TODO confirm.
      w: Weight tensor. It is unpacked into exactly four values below.
      k: FIR filter of the shape `[firH, firW]` or `[firN]`
        (separable). The default is `[1] * factor`, which corresponds to
        average pooling.
      factor: Integer downsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      The output of `F.conv3d` applied to the FIR-filtered input.
    """
    assert isinstance(factor, int) and factor >= 1
    # NOTE(review): a 5-D conv3d weight has five dims, so this four-way
    # unpack would raise ValueError — confirm the expected rank of `w`.
    _outC, _inC, convH, convW = w.shape
    assert convW == convH
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    p = (k.shape[0] - factor) + (convW - 1)
    # NOTE(review): conv3d takes one or three stride values; two are given.
    s = [factor, factor]
    # NOTE(review): `upfirdn2d` is only referenced by a commented-out import
    # near the top of this file; confirm it is in scope.
    x = upfirdn2d(x, torch.tensor(k, device=x.device),
                  pad=((p + 1) // 2, p // 2))
    return F.conv3d(x, w, stride=s, padding=0)
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def _shape(x, dim):
return x.shape[dim]
def upsample_3d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 3d images with the given filter.

    The filter is normalized (see `_setup_kernel`) and then scaled by
    `gain * factor ** 2` before being handed to `upfirdn2d`.

    NOTE(review): the shape notes below were inherited from a 2-D version,
    and `factor ** 2` is the 2-D gain compensation — confirm for 3-D data.

    Args:
      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`
        (as documented originally; TODO confirm for 3-D).
      k: FIR filter of the shape `[firH, firW]` or `[firN]`
        (separable). The default is `[1] * factor`, which corresponds to
        nearest-neighbor upsampling.
      factor: Integer upsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      The output of `upfirdn2d(..., up=factor, ...)`.
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    pad_total = k.shape[0] - factor
    fir = torch.tensor(k, device=x.device)
    padding = ((pad_total + 1) // 2 + factor - 1, pad_total // 2)
    return upfirdn2d(x, fir, up=factor, pad=padding)
def downsample_3d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 3d images with the given filter.

    The filter is normalized (see `_setup_kernel`) and scaled by `gain`
    before being handed to `upfirdn2d`.

    NOTE(review): the shape notes below were inherited from a 2-D version —
    confirm for 3-D data.

    Args:
      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`
        (as documented originally; TODO confirm for 3-D).
      k: FIR filter of the shape `[firH, firW]` or `[firN]`
        (separable). The default is `[1] * factor`, which corresponds to
        average pooling.
      factor: Integer downsampling factor (default: 2).
      gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
      The output of `upfirdn2d(..., down=factor, ...)`.
    """
    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * gain
    pad_total = k.shape[0] - factor
    fir = torch.tensor(k, device=x.device)
    padding = ((pad_total + 1) // 2, pad_total // 2)
    return upfirdn2d(x, fir, down=factor, pad=padding)

@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
"""All functions and modules related to model definition."""
import torch
from .. import sde_lib
import numpy as np
# Registry mapping a model name to its class; written by `register_model`
# and read by `get_model`.
_MODELS = {}
def register_model(cls=None, *, name=None):
    """A decorator for registering model classes.

    Works both bare (`@register_model`) and with a keyword
    (`@register_model(name='foo')`).

    Args:
      cls: The class being decorated, or `None` when called with keywords.
      name: Registry key; defaults to `cls.__name__`.

    Returns:
      The class itself, or the decorator when `cls` is `None`.

    Raises:
      ValueError: If the key is already present in `_MODELS`.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)
def get_model(name):
    """Look up a registered model class by name (KeyError if unknown)."""
    registry = _MODELS
    return registry[name]
def get_sigmas(config):
    """Get sigmas --- the set of noise levels for SMLD from config files.

    Args:
      config: A ConfigDict object parsed from the config file. Reads
        `config.model.sigma_max`, `sigma_min` and `num_scales`.

    Returns:
      sigmas: a numpy array of `num_scales` noise levels, evenly spaced in
        log-space from `sigma_max` down to `sigma_min`.
    """
    log_max = np.log(config.model.sigma_max)
    log_min = np.log(config.model.sigma_min)
    sigmas = np.exp(np.linspace(log_max, log_min, config.model.num_scales))
    return sigmas
def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Args:
      config: Config object; reads `config.model.beta_min`, `beta_max` and
        `num_scales`.

    Returns:
      A dict with the beta/alpha schedules (numpy float64 arrays of length
      1000), the rescaled `beta_min`/`beta_max`, and the step count.
    """
    num_diffusion_timesteps = 1000
    # parameters need to be adapted if number of time steps differs from 1000
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps,
    }
def create_model(config, use_parallel=True):
    """Create the score model.

    The class registered under `config.model.name` is instantiated with
    `config`. Only the `use_parallel` path wraps it in `DataParallel` and
    moves it to `config.device`; otherwise the bare model is returned.
    """
    model_cls = get_model(config.model.name)
    score_model = model_cls(config)
    if not use_parallel:
        return score_model
    return torch.nn.DataParallel(score_model).to(config.device)
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
            for different models.

        Returns:
          The model output.
        """
        # Put the module in the mode the `train` flag asks for; as written
        # before, the flag had no effect on the call.
        if not train:
            model.eval()
        else:
            model.train()
        return model(x, labels)

    return model_fn
def get_reg_fn(model, train=False):
    """Create a function that evaluates the model's regularisation term.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A function of `x` returning `model.get_reg(x)`, or zeros shaped like
      `x` when the model exposes no `get_reg`.
    """

    def model_fn(x):
        """Compute the regularisation output for a mini-batch `x`."""
        # NOTE(review): the branch keywords around the two return statements
        # were lost; "use get_reg when available, else zeros" is the reading
        # implemented here — confirm.
        if not train:
            model.eval()
        else:
            model.train()
        get_reg = getattr(model, 'get_reg', None)
        if get_reg is None:
            return torch.zeros_like(x, device=x.device)
        return get_reg(x)

    return model_fn
def get_score_fn(sde, model, train=False, continuous=False, std_scale=True):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take continuous time steps.
        Only `False` is accepted here.
      std_scale: whether to scale the score function by the inverse of std. Used for DDIM sampling

    Returns:
      A score function.

    Raises:
      NotImplementedError: If `sde` is not an `sde_lib.VPSDE`.
    """
    assert not continuous
    if not isinstance(sde, sde_lib.VPSDE):
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    model_fn = get_model_fn(model, train=train)

    def score_fn(x, t):
        # For VP-trained models, t=0 corresponds to the lowest noise level
        labels = t * (sde.N - 1)
        output = model_fn(x, labels)
        if not std_scale:
            # Raw model output, without the -1/std rescaling.
            return output
        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
        return -output / std[:, None, None, None, None]

    return score_fn
def to_flattened_numpy(x):
    """Detach tensor `x`, move it to the CPU and return it as a 1-D numpy array."""
    array = x.detach().cpu().numpy()
    return array.reshape(-1)
def from_flattened_numpy(x, shape):
    """Reshape the numpy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)

@ -0,0 +1,570 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools
import torch
import numpy as np
import abc
from .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
from scipy import integrate
from . import sde_lib
from .models import utils as mutils
import logging
import tqdm
def register_predictor(cls=None, *, name=None):
    """A decorator for registering predictor classes.

    Works both bare (`@register_predictor`) and with a keyword
    (`@register_predictor(name='foo')`). The class is stored in the
    module-level `_PREDICTORS` dict under `name`, or `cls.__name__`.

    Raises:
      ValueError: If the key is already registered.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _PREDICTORS:
            raise ValueError(f'Already registered predictor with name: {local_name}')
        _PREDICTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)
def register_corrector(cls=None, *, name=None):
    """A decorator for registering corrector classes.

    Works both bare (`@register_corrector`) and with a keyword
    (`@register_corrector(name='foo')`). The class is stored in the
    module-level `_CORRECTORS` dict under `name`, or `cls.__name__`.

    Raises:
      ValueError: If the key is already registered.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _CORRECTORS:
            raise ValueError(f'Already registered corrector with name: {local_name}')
        _CORRECTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)
def get_predictor(name):
    """Look up a registered predictor class by name (KeyError if unknown)."""
    registry = _PREDICTORS
    return registry[name]
def get_corrector(name):
    """Look up a registered corrector class by name (KeyError if unknown)."""
    registry = _CORRECTORS
    return registry[name]
def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False):
    """Create a sampling function.

    Dispatches on `config.sampling.method` (case-insensitive): 'pc' builds a
    predictor-corrector sampler, 'ddim' builds a DDIM sampler.

    Args:
      config: A `ml_collections.ConfigDict` object that contains all configuration information.
      sde: A `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers representing the expected shape of a single sample.
      inverse_scaler: The inverse data normalizer function.
      eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
      grid_mask: Not referenced in the visible part of this function —
        presumably forwarded to the sampler; TODO confirm.
      return_traj: Not referenced in the visible part of this function —
        presumably forwarded to the sampler; TODO confirm.

    Returns:
      A function that takes random states and a replicated training state and outputs samples with the
      trailing dimensions matching `shape`.
    """
    sampler_name = config.sampling.method
    # Probability flow ODE sampling with black-box ODE solvers
    # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
    if sampler_name.lower() == 'pc':
        predictor = get_predictor(config.sampling.predictor.lower())
        corrector = get_corrector(config.sampling.corrector.lower())
        # NOTE(review): the remaining keyword arguments of this call are not
        # visible here.
        sampling_fn = get_pc_sampler(sde=sde,
    elif sampler_name.lower() == 'ddim':
        predictor = get_predictor('ddim')
        # NOTE(review): the remaining keyword arguments of this call are not
        # visible here, and an `else:` appears to be missing before the raise.
        sampling_fn = get_ddim_sampler(sde=sde,
        raise ValueError(f"Sampler name {sampler_name} unknown.")
    return sampling_fn
class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        # Compute the reverse SDE/ODE
        self.rsde = sde.reverse(score_fn, probability_flow)
        self.score_fn = score_fn

    @abc.abstractmethod
    def update_fn(self, x, t):
        """One update of the predictor.

        Args:
          x: A PyTorch tensor representing the current state
          t: A Pytorch tensor representing the current time step.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
class Corrector(abc.ABC):
    """The abstract class for a corrector algorithm."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x, t):
        """One update of the corrector.

        Args:
          x: A PyTorch tensor representing the current state
          t: A PyTorch tensor representing the current time step.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
class EulerMaruyamaPredictor(Predictor):
    """Predictor that takes one step of size -1/N along the reverse SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Return `(x_next, x_mean)`; `x_mean` is the step without noise."""
        step = -1. / self.rsde.N
        noise = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * step
        noise_scale = diffusion[:, None, None, None, None] * np.sqrt(-step)
        x = x_mean + noise_scale * noise
        return x, x_mean
class ReverseDiffusionPredictor(Predictor):
    """Predictor driven by the reverse SDE's own `discretize` rule."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Return `(x_next, x_mean)`; `x_mean` is the step without noise."""
        drift_step, noise_scale = self.rsde.discretize(x, t)
        noise = torch.randn_like(x)
        x_mean = x - drift_step
        x = x_mean + noise_scale[:, None, None, None, None] * noise
        return x, x_mean
class AncestralSamplingPredictor(Predictor):
    """The ancestral sampling predictor; only `sde_lib.VPSDE` is accepted."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
        assert not probability_flow, "Probability flow not supported by ancestral sampling"

    def vpsde_update_fn(self, x, t):
        """One ancestral step using the discrete beta at time `t`."""
        sde = self.sde
        step_index = (t * (sde.N - 1) / sde.T).long()
        beta = sde.discrete_betas.to(t.device)[step_index]
        score = self.score_fn(x, t)
        rescale = torch.sqrt(1. - beta)[:, None, None, None, None]
        x_mean = (x + beta[:, None, None, None, None] * score) / rescale
        noise = torch.randn_like(x)
        x = x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise
        return x, x_mean

    def update_fn(self, x, t):
        if not isinstance(self.sde, sde_lib.VPSDE):
            raise NotImplementedError
        return self.vpsde_update_fn(x, t)
class NonePredictor(Predictor):
    """An empty predictor that does nothing."""

    def __init__(self, sde, score_fn, probability_flow=False):
        # Deliberately skips Predictor.__init__: no reverse SDE is built.
        pass

    def update_fn(self, x, t):
        return x, x
class DDIMPredictor(Predictor):
    """Predictor that delegates each step to `rsde.discretize_ddim`."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t, tprev=None):
        """Return `(x_next, x0_pred)` as produced by `discretize_ddim`."""
        x_next, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev)
        return x_next, x0_pred
class LangevinCorrector(Corrector):
    """Langevin-dynamics corrector whose step size is set from a target SNR.

    Only `sde_lib.VPSDE` is accepted by the constructor.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        """Run `n_steps` Langevin updates and return `(x, x_mean)`.

        `x_mean` is the last state before noise was added; it equals `x`
        when `n_steps` is 0.
        """
        sde = self.sde
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        if isinstance(sde, sde_lib.VPSDE):
            timestep = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[timestep]
        else:
            alpha = torch.ones_like(t)

        # Defined up front so a zero-step corrector still has a return value.
        x_mean = x
        for _ in range(n_steps):
            grad = score_fn(x, t)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise

        return x, x_mean
class AnnealedLangevinDynamics(Corrector):
    """The original annealed Langevin dynamics predictor in NCSN/NCSNv2.

    We include this corrector only for completeness. It was not directly used in our paper.
    Only `sde_lib.VPSDE` is accepted by the constructor.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        """Run `n_steps` Langevin updates and return `(x, x_mean)`.

        `x_mean` is the last state before noise was added; it equals `x`
        when `n_steps` is 0.
        """
        sde = self.sde
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        if isinstance(sde, sde_lib.VPSDE):
            timestep = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[timestep]
        else:
            alpha = torch.ones_like(t)

        # The marginal std is taken once from the incoming state, so the
        # step size is the same for every inner step.
        std = self.sde.marginal_prob(x, t)[1]
        step_size = (target_snr * std) ** 2 * 2 * alpha
        noise_scale = torch.sqrt(step_size * 2)[:, None, None, None, None]

        # Defined up front so a zero-step corrector still has a return value.
        x_mean = x
        for _ in range(n_steps):
            grad = score_fn(x, t)
            noise = torch.randn_like(x)
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + noise * noise_scale

        return x, x_mean
class NoneCorrector(Corrector):
    """An empty corrector that does nothing."""

    def __init__(self, sde, score_fn, snr, n_steps):
        # Deliberately skips Corrector.__init__: nothing needs to be stored.
        pass

    def update_fn(self, x, t):
        return x, x
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):
    """Build a predictor for `model` and apply one update to `x` at time `t`.

    Args:
      x: Current state.
      t: Current time step.
      sde: The forward SDE.
      model: The score model.
      predictor: Predictor class, or `None` for a corrector-only sampler.
      probability_flow: Forwarded to the predictor constructor.
      continuous: Forwarded to `mutils.get_score_fn`.

    Returns:
      The `(x, x_mean)` pair returned by the predictor's `update_fn`.
    """
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    # Corrector-only sampler: fall back to the no-op predictor.
    predictor_cls = NonePredictor if predictor is None else predictor
    predictor_obj = predictor_cls(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps):
    """Build a corrector for `model` and apply one update to `x` at time `t`.

    Args:
      x: Current state.
      t: Current time step.
      sde: The forward SDE.
      model: The score model.
      corrector: Corrector class, or `None` for a predictor-only sampler.
      continuous: Forwarded to `mutils.get_score_fn`.
      snr: Forwarded to the corrector constructor.
      n_steps: Forwarded to the corrector constructor.

    Returns:
      The `(x, x_mean)` pair returned by the corrector's `update_fn`.
    """
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
    # Predictor-only sampler: fall back to the no-op corrector.
    corrector_cls = NoneCorrector if corrector is None else corrector
    corrector_obj = corrector_cls(sde, score_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t)
def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
                   n_steps=1, probability_flow=False, continuous=False,
                   denoise=True, eps=1e-3, device='cuda', grid_mask=None, return_traj=False):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
      sde: An `sde_lib.SDE` object representing the forward SDE.
      shape: A sequence of integers. The expected shape of a single sample.
      predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
      corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
      inverse_scaler: The inverse data normalizer.
      snr: A `float` number. The signal-to-noise ratio for configuring correctors.
      n_steps: An integer. The number of corrector steps per predictor update.
      probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
      continuous: `True` indicates that the score model was continuously trained.
      denoise: If `True`, add one-step denoising to the final samples.
      eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
      device: PyTorch device.
      grid_mask: Tensor multiplied into `x` and `x_mean` after every update.
      return_traj: If `True`, the sampler returns a list of `compute_xzero`
        snapshots instead of the final sample.

    Returns:
      A sampling function that returns samples and the number of function evaluations during sampling.
    """
    # Create predictor & corrector update functions
    # NOTE(review): the argument lists of both `functools.partial(` calls and
    # the tail of the `pc_sampler(` signature are not visible here.
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
    def pc_sampler(model,
                   partial=None, partial_grid_mask=None, partial_channel=0,
        """ The PC sampler function.

        Args:
          model: A score model.
          partial: Optional 5-D tensor; its `partial_channel` channel is
            blended into `x` during sampling.
          partial_grid_mask: Not referenced in the visible body; the body
            uses `partial_mask` instead — TODO confirm.
          partial_channel: Index along dim 1 that `partial` conditions.

        Returns:
          Samples, number of function evaluations.
        """
        with torch.no_grad():
            # NOTE(review): `freeze_iters` and `partial_mask` are used below
            # but are not bound anywhere in the visible code — presumably
            # parameters from the truncated signature; confirm.
            if freeze_iters is None:
                freeze_iters = sde.N + 10 # just some randomly large number greater than sde.N
            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
            def compute_xzero(sde, model, x, t, grid_mask_input):
                # Estimates x0 as (x - sqrt(1 - a_bar) * model_out) / sqrt(a_bar),
                # clamped to [-1, 1] and masked.
                timestep_int = (t * (sde.N - 1) / sde.T).long()
                # NOTE(review): `.cuda()` ignores the `device` argument, and
                # the two `_prev` values are never used.
                alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
                alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
                alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()
                alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()
                score_pred = model(x, t * torch.ones(shape[0], device=x.device))
                x0_pred_scaled = (x - alphas2 * score_pred)
                x0_pred = x0_pred_scaled / alphas1
                x0_pred = x0_pred.clamp(-1, 1)
                return x0_pred * grid_mask_input
            # Initial sample
            x = sde.prior_sampling(shape).to(device)
            assert len(x.size()) == 5
            x = x * grid_mask
            traj_buffer = []
            if partial is not None:
                assert len(partial.size()) == 5
                t = timesteps[0]
                vec_t = torch.ones(shape[0], device=t.device) * t
                x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]
                partial_mean, partial_std = sde.marginal_prob(x, vec_t)
                # NOTE(review): here `partial_std` gets four trailing axes and
                # `sampled_update` is indexed by channel again below, whereas
                # the in-loop version uses three axes and no re-indexing —
                # confirm which is intended.
                sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                x[:, partial_channel] = (
                    x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
                    + sampled_update[:, partial_channel] * partial_mask[:, partial_channel]
                ) * grid_mask[:, partial_channel]
            if partial is not None:
                x_mean = x
                for i in tqdm.trange(sde.N):
                    t = timesteps[i]
                    vec_t = torch.ones(shape[0], device=t.device) * t
                    # Corrector first, then predictor; mask after each.
                    x, x_mean = corrector_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask
                    x, x_mean = predictor_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask
                    if i != sde.N - 1 and i < freeze_iters:
                        # Overwrite the conditioned channel with `partial`
                        # wherever `partial_mask` is set.
                        x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
                        ### add noise to the condition x0_star
                        partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))
                        sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                        x[:, partial_channel] = (
                            x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
                            + sampled_update * partial_mask[:, partial_channel]
                        ) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = x[:, partial_channel]
            # NOTE(review): an `else:` appears to be missing here; as written
            # this unconditional loop would also run after the conditioned
            # loop above. It iterates `sde.N - 1` times, not `sde.N`.
            for i in tqdm.trange(sde.N - 1):
                t = timesteps[i]
                vec_t = torch.ones(shape[0], device=t.device) * t
                x, x_mean = corrector_update_fn(x, vec_t, model=model)
                x, x_mean = x * grid_mask, x_mean * grid_mask
                x, x_mean = predictor_update_fn(x, vec_t, model=model)
                x, x_mean = x * grid_mask, x_mean * grid_mask
                # Snapshot an x0 estimate every 10th step from step 700 on.
                if return_traj and i >= 700 and i % 10 == 0:
                    traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))
            if return_traj:
                return traj_buffer, sde.N * (n_steps + 1)
            return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
    return pc_sampler
def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous):
    """Build a predictor for `model` and apply one DDIM update from `t` to `tprev`.

    The score function is created with `std_scale=False`, i.e. the raw model
    output is passed to the predictor.

    Args:
      x: Current state.
      t: Current time step.
      tprev: Time step to move to.
      sde: The forward SDE.
      model: The score model.
      predictor: Predictor class, or `None` for the no-op predictor.
      probability_flow: Forwarded to the predictor constructor.
      continuous: Must be `False`.

    Returns:
      Whatever the predictor's `update_fn(x, t, tprev)` returns.
    """
    assert not continuous
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False)
    # Corrector-only sampler: fall back to the no-op predictor.
    # NOTE(review): NonePredictor.update_fn takes (x, t) only, so the
    # three-argument call below would fail on that path — confirm it is unused.
    predictor_cls = NonePredictor if predictor is None else predictor
    predictor_obj = predictor_cls(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t, tprev)
def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,
                     denoise=False, eps=1e-3, device='cuda', grid_mask=None):
    """Create a sampler that iterates `ddim_predictor_update_fn` over a sub-sequence of time steps.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers. The expected shape of a single sample.
      predictor: Predictor class — presumably bound into
        `predictor_update_fn`; the call's arguments are not visible here.
      inverse_scaler: The inverse data normalizer.
      n_steps: Only used in the reported evaluation count.
      denoise: If `True`, add one-step denoising to final samples.
      eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
      device: PyTorch device.
      grid_mask: Tensor multiplied into `x` and `x0_pred` after every update.

    Returns:
      A sampling function that returns samples and the number of function evaluations during sampling.
    """
    # NOTE(review): the remaining arguments of this call are not visible here.
    predictor_update_fn = functools.partial(ddim_predictor_update_fn,
    def ddim_sampler(model, schedule='quad', num_steps=100, x0=None,
                     partial=None, partial_grid_mask=None, partial_channel=0):
        """ The DDIM sampler function.

        Args:
          model: A score model.
          schedule: 'uniform' or 'quad'; selects how `seq` is built.
          num_steps: Used by the 'uniform' schedule to derive the stride.
          x0: Optional starting state used instead of a prior sample.
          partial: Optional tensor blended into channel `partial_channel`.
          partial_grid_mask: Not referenced in the visible body; the body
            uses `partial_mask` instead — TODO confirm.
          partial_channel: Index along dim 1 that `partial` conditions.

        Returns:
          Samples, number of function evaluations.
        """
        with torch.no_grad():
            if x0 is not None:
                x = x0 * grid_mask
            # NOTE(review): an `else:` appears to be missing here; as written
            # the prior sample below would always replace `x0`.
            # Initial sample
            x = sde.prior_sampling(shape).to(device)
            x = x * grid_mask
            # NOTE(review): `partial_mask` (here) and `encode` (in the return
            # statement) are not bound anywhere in the visible code — confirm.
            if partial is not None:
                x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
            if schedule == 'uniform':
                skip = sde.N // num_steps
                seq = range(0, sde.N, skip)
            elif schedule == 'quad':
                # NOTE(review): the expression is truncated here, and the
                # literal 100 does not follow `num_steps` — confirm.
                seq = (
                0, np.sqrt(sde.N * 0.8), 100
                ** 2
            seq = [int(s) for s in list(seq)]
            # NOTE(review): this rebinds `timesteps` without `device=`, so
            # `vec_t` below is created on this tensor's device — confirm it
            # matches `x`.
            timesteps = torch.tensor(seq) / sde.N
            # Walk the sub-sequence from the last entry to the first.
            for i in tqdm.tqdm(reversed(range(1, len(timesteps)))):
                t = timesteps[i]
                tprev = timesteps[i - 1]
                vec_t = torch.ones(shape[0], device=t.device) * t
                vec_tprev = torch.ones(shape[0], device=t.device) * tprev
                x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)
                x, x0_pred = x * grid_mask, x0_pred * grid_mask
                if partial is not None:
                    x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
                    x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_mask) + partial * partial_mask
            return inverse_scaler(x0_pred * grid_mask if (denoise and not encode) else x * grid_mask), sde.N * (n_steps + 1)
    return ddim_sampler

@ -0,0 +1,233 @@
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np
import torch.nn.functional as F
import time
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
          N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""

    @abc.abstractmethod
    def sde(self, x, t):
        """Return the `(drift, diffusion)` pair of the SDE at `(x, t)`."""

    @abc.abstractmethod
    def marginal_prob(self, x, t):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""

    @abc.abstractmethod
    def prior_sampling(self, shape):
        """Generate one sample from the prior distribution, $p_T(x)$."""

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
          z: latent code

        Returns:
          log probability density
        """

    def discretize(self, x, t):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
          x: a torch tensor
          t: a torch float representing the time step (from 0 to `self.T`)

        Returns:
          f, G
        """
        dt = 1 / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(self, score_fn, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
          score_fn: A time-dependent score-based model that takes x and t and
            returns the score.
          probability_flow: If `True`, create the reverse-time ODE used for
            probability flow sampling.

        Returns:
          An instance of a subclass of `type(self)` whose `sde`/`discretize`
          describe the reverse process and which additionally offers
          `discretize_ddim` and `discretize_conditional_ddpm`.
        """
        N = self.N
        T = self.T
        sde_fn = self.sde
        discretize_fn = self.discretize
        # The cumulative-alpha tables are only needed by the DDIM/DDPM rules;
        # SDEs that do not define them can still build a reverse SDE/ODE.
        sqrt_alphas_cumprod = getattr(self, 'sqrt_alphas_cumprod', None)
        sqrt_1m_alphas_cumprod = getattr(self, 'sqrt_1m_alphas_cumprod', None)

        def _per_sample(table, index, like):
            """Gather per-sample schedule values, shaped to broadcast over a 5-D batch."""
            return table[index].to(like.device)[:, None, None, None, None]

        # Build the class for reverse-time SDE.
        class RSDE(self.__class__):
            def __init__(self):
                # The parent constructor is intentionally not invoked: only
                # N and the ODE/SDE switch are needed on the reverse object.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
                # Set the diffusion function to zero for ODEs.
                diffusion = 0. if self.probability_flow else diffusion
                return drift, diffusion

            def discretize(self, x, t):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t)
                rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

            def discretize_ddim(self, x, t, tprev=None, encode=False):
                """One deterministic update from time `t` to `tprev`.

                Returns:
                  `(x_new, x0_pred)`, both in double precision. `encode` is
                  accepted for interface compatibility and is not used.
                """
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                score = score_fn(x.float(), t.float())

                alphas1 = _per_sample(sqrt_alphas_cumprod, timestep, x)
                alphas2 = _per_sample(sqrt_1m_alphas_cumprod, timestep, x)
                alphas1_prev = _per_sample(sqrt_alphas_cumprod, timestep_prev, x)
                alphas2_prev = _per_sample(sqrt_1m_alphas_cumprod, timestep_prev, x)
                # Ratios are formed in double precision.
                alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
                alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()

                x0_pred_scaled = x.double() - alphas2.double() * score.double()
                score_scaled_t = x - x0_pred_scaled
                x0_pred = x0_pred_scaled / alphas1
                x_new = (
                    alphas1prev_div_alphas1 * x
                    + (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t.double()
                )
                return x_new, x0_pred

            def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=False):
                """One update from `t` to `tprev` with an optional guidance term.

                `condition=None` disables guidance; any other value is passed
                to `condition_func(x, condition)`.

                Returns:
                  `(x_new, x0_pred)`.
                """
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                score = score_fn(x.float(), t.float())

                alphas1 = _per_sample(sqrt_alphas_cumprod, timestep, x)
                alphas2 = _per_sample(sqrt_1m_alphas_cumprod, timestep, x)
                alphas1_prev = _per_sample(sqrt_alphas_cumprod, timestep_prev, x)
                alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()

                x0_pred_scaled = x.double() - alphas2.double() * score.double()
                bound = alphas1[0].squeeze()
                x0_pred_scaled = x0_pred_scaled.clamp(-bound, bound)
                x0_pred = x0_pred_scaled / alphas1

                if condition is None:
                    condition_update = 0
                else:
                    # NOTE(review): the nesting of this threshold test under
                    # the `else` is an assumption - confirm against history.
                    if (t - 0.99).mean() < 1e-3:
                        x = x0_pred
                    condition_update = condition_func(x.float(), condition)

                x_new = (
                    x - alphas1prev_div_alphas1.double() * condition_update
                )
                return x_new, x0_pred

        return RSDE()
class VPSDE(SDE):
    """Variance Preserving SDE with a linear beta schedule between `beta_min` and `beta_max`."""

    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct a Variance Preserving SDE.

        Args:
          beta_min: value of beta(0)
          beta_max: value of beta(1)
          N: number of discretization steps
        """
        super().__init__(N)
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.N = N
        # Discrete schedule tables live on the GPU.
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda()
        self.alphas = 1. - self.discrete_betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Same table with a leading entry of 1 - 1e-4.
        self.alphas_cumprod_ext = torch.cat(
            [torch.tensor([1.0 - 1e-4]).cuda(), self.alphas_cumprod], dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

    @property
    def T(self):
        """End time of the SDE (always 1)."""
        return 1

    def sde(self, x, t):
        """Return `(drift, diffusion)` at `(x, t)`; `x` is indexed as a 5-D batch."""
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None, None, None] * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(self, x, t):
        """Return the mean and per-sample standard deviation of $p_t(x_t | x)$."""
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        return mean, std

    def prior_sampling(self, shape):
        """Draw a standard-normal tensor of the given shape (on the default device)."""
        return torch.randn(*shape)

    def prior_logp(self, z):
        """Log-density of `z` under a standard normal, summed over dims 1-4."""
        shape = z.shape
        N = np.prod(shape[1:])
        logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2.
        return logps

    def discretize(self, x, t):
        """DDPM discretization."""
        timestep = (t * (self.N - 1) / self.T).long()
        beta = self.discrete_betas.to(x.device)[timestep]
        alpha = self.alphas.to(x.device)[timestep]
        sqrt_beta = torch.sqrt(beta)
        f = torch.sqrt(alpha)[:, None, None, None, None] * x - x
        G = sqrt_beta
        return f, G

@ -0,0 +1,130 @@
import os
import sys
import numpy as np
import logging
# Keep the import below for registering all model definitions
from .models import ddpm_res64, ddpm_res128
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from torch.utils import tensorboard
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.shapenet_dmtet_dataset import ShapeNetDMTetDataset
# Training entry point: builds the model/EMA/optimizer state, resumes from the
# meta checkpoint if one exists, then loops over `config.training.n_iters`
# outer steps, each made of `config.training.iter_size` micro-batches.
# NOTE(review): several lines of this function (a `try:`, the tail of the
# DataLoader call, the docstring terminator) are missing from the reviewed
# text; code lines are reproduced unchanged and the gaps are flagged in place.
def train(config):
"""Runs the training pipeline.

Args:
config: Configuration to use. The working directory for checkpoints and
TensorBoard summaries is read from `config.training.train_dir`; when it
holds a meta checkpoint, training resumes from the step stored there.
"""
workdir = config.training.train_dir
# Create directories for experimental logs
logging.info("working dir: {:s}".format(workdir))
tb_dir = os.path.join(workdir, "tensorboard")
writer = tensorboard.SummaryWriter(tb_dir)
resolution = config.data.image_size
# Initialize model.
score_model = mutils.create_model(config)
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
optimizer = losses.get_optimizer(config, score_model.parameters())
# Mutable training state shared with the step function and the checkpoint helpers.
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)
# Create checkpoints directory
checkpoint_dir = os.path.join(workdir, "checkpoints")
# Intermediate checkpoints to resume training after pre-emption in cloud environments
# (despite its name, `checkpoint_meta_dir` is the path of a single .pth file)
checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
# Resume training when intermediate checkpoints are detected
state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
initial_step = int(state['step'])
json_path = config.data.meta_path
print("----- Assigning mask -----")
logging.info(f"{json_path}, {config.data.filter_meta_path}")
### mask on tet to ignore regions
# The mask file is reshaped to (1, 1, R, R, R) with R = config.data.image_size and moved to the GPU.
mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda")
# `score_model.module` is accessed unconditionally, so the model is presumably wrapped (e.g. DataParallel) - TODO confirm.
if hasattr(score_model.module, 'mask'):
print("----- Assigning mask -----")
score_model.module.mask.data[:] = mask[:]
print(f"work dir: {workdir}")
print("sdf normalized or not: ", config.data.normalize_sdf)
train_dataset = ShapeNetDMTetDataset(json_path, deform_scale=config.model.deform_scale, aug=True, grid_mask=mask,
filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf)
# NOTE(review): the remaining DataLoader arguments and the closing parenthesis
# are missing from the reviewed text - restore them from version control.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.training.batch_size,
data_iter = iter(train_loader)
print("data loader set")
# Setup SDEs
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
# Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(config)
train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
mask=mask, loss_type=config.training.loss_type)
num_train_steps = config.training.n_iters
# In case there are multiple hosts (e.g., TPU pods), only log to host 0
logging.info("Starting training loop at step %d." % (initial_step // config.training.iter_size,))
iter_size = config.training.iter_size
# The stored step is divided by iter_size to obtain the outer-loop index;
# presumably state['step'] counts micro-batches - TODO confirm in the step function.
for step in range(initial_step // iter_size, num_train_steps + 1):
tmp_loss = 0.0
for step_inner in range(iter_size):
# batch, batch_mask = next(data_iter)
# NOTE(review): the `try:` that pairs with the `except` below is missing here.
batch = next(data_iter)
except StopIteration:
# StopIteration is thrown if dataset ends
# reinitialize data loader
data_iter = iter(train_loader)
batch = next(data_iter)
batch = batch.cuda()
# Execute one training step
# The flags are true on the first / last micro-batch of the outer step and are
# forwarded to the step function (looks like gradient accumulation - confirm there).
clear_grad_flag = (step_inner == 0)
update_param_flag = (step_inner == iter_size - 1)
loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag)
loss = loss_dict['loss']
tmp_loss += loss.item()
# Average of the micro-batch losses of this outer step.
tmp_loss /= iter_size
if step % config.training.log_freq == 0:
logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss))
# NOTE(review): this logs `loss` (last micro-batch only) while the line above
# reports the averaged `tmp_loss`; the two curves differ when iter_size > 1.
writer.add_scalar("training_loss", loss, step)
# Save a temporary checkpoint to resume training after pre-emption periodically
if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
logging.info(f"save meta at iter {step}")
save_checkpoint(checkpoint_meta_dir, state)
# Save a checkpoint periodically and generate samples if needed
# (`and` binds tighter than `or`: the final step is always saved)
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
logging.info(f"save model: {step}-th")
save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)

@ -0,0 +1,31 @@
import torch
import tensorflow as tf
import os
import logging
def restore_checkpoint(ckpt_dir, state, device, strict=False):
    """Load model weights and the step counter from a checkpoint file.

    Args:
      ckpt_dir: Path of the checkpoint file (it is handed to `torch.load`).
      state: Dict with at least a 'model' entry; updated in place.
      device: `map_location` used when loading.
      strict: If `True`, a missing checkpoint raises instead of being ignored.

    Returns:
      `state`, unchanged when no checkpoint exists and `strict` is false.

    Raises:
      FileNotFoundError: If the checkpoint is missing and `strict` is true.
    """
    if not tf.io.gfile.exists(ckpt_dir):
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        if strict:
            raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    # Non-strict: parameters missing from either side are skipped.
    state['model'].load_state_dict(loaded_state['model'], strict=False)
    # NOTE(review): 'optimizer' and 'ema' are written by save_checkpoint but
    # not restored here - confirm whether that is intentional.
    state['step'] = loaded_state['step']
    return state
def save_checkpoint(ckpt_dir, state):
    """Serialize optimizer, model, EMA and step counter to `ckpt_dir`.

    Args:
      ckpt_dir: Destination file path (or any target `torch.save` accepts).
      state: Dict with 'optimizer', 'model' and 'ema' entries exposing
        `state_dict()`, plus a 'step' value.
    """
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
    }
    if isinstance(ckpt_dir, (str, os.PathLike)):
        # Write next to the target and rename, so an interrupted save never
        # leaves a truncated checkpoint where a resume would look for it.
        tmp_path = f"{os.fspath(ckpt_dir)}.tmp"
        torch.save(saved_state, tmp_path)
        os.replace(tmp_path, ckpt_dir)
    else:
        torch.save(saved_state, ckpt_dir)

@ -0,0 +1,77 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
class Dataset(torch.utils.data.Dataset):
    """Basic dataset interface"""

    def __init__(self):
        super().__init__()

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, itr):
        # Takes the index argument so that `dataset[i]` on a subclass that
        # forgot to override fails with NotImplementedError, not TypeError.
        raise NotImplementedError

    def collate(self, batch):
        """Merge a list of per-item dicts into one batch dict.

        Tensor entries are concatenated along dim 0. 'resolution', 'spp' and
        the geometry entries ('spts', 'vpts', 'faces', 'rast_triangle_id') are
        taken from the first item only. Optional keys are included only when
        the first item has them; the three '*_second' entries are required.

        Args:
          batch: Non-empty list of item dicts.

        Returns:
          The merged dict.
        """
        first = batch[0]

        def cat(key):
            return torch.cat([item[key] for item in batch], dim=0)

        res_dict = {
            'mv': cat('mv'),
            'mvp': cat('mvp'),
            'campos': cat('campos'),
            'resolution': first['resolution'],
            'spp': first['spp'],
            'img': cat('img'),
        }

        # Shared across the batch: copied from the first item.
        for key in ('spts', 'vpts', 'faces', 'rast_triangle_id'):
            if key in first:
                res_dict[key] = first[key]

        # Per-item buffers: concatenated when present.
        for key in ('depth', 'normal', 'geo_normal', 'geo_viewdir', 'pos', 'mask', 'mask_cont'):
            if key in first:
                res_dict[key] = cat(key)

        if 'envlight_transform' in first:
            if first['envlight_transform'] is not None:
                res_dict['envlight_transform'] = cat('envlight_transform')
            else:
                res_dict['envlight_transform'] = None

        for key in ('depth_second', 'normal_second', 'img_second'):
            res_dict[key] = cat(key)

        return res_dict

@ -0,0 +1,163 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
import sys
from ..render import util
from ..render import mesh
from ..render import render
from ..render import light
from .dataset import Dataset
import kaolin
# Reference dataset using mesh & rendering
class DatasetMesh(Dataset):
    """Dataset that renders views of a fixed reference mesh on demand."""

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False):
        """Store render settings, load the environment light and prepare the mesh.

        Args:
          ref_mesh: Reference mesh to render.
          glctx: Rasterization context forwarded to `render.render_mesh`.
          cam_radius: Camera distance used for the view translation.
          FLAGS: Run options (resolutions, spp, envmap, iteration counts, ...).
          validate: If `True`, serve 50 views on a fixed rotation path instead
            of random training views.
        """
        super().__init__()
        # Init
        self.glctx = glctx
        self.cam_radius = cam_radius
        # The other methods read their options through self.FLAGS.
        self.FLAGS = FLAGS
        self.validate = validate
        self.fovy = np.deg2rad(45)
        self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]
        self.random_lgt = FLAGS.random_lgt
        self.camera_lgt = False
        self.flat_shading = FLAGS.dataset_flat_shading

        # NOTE(review): both prints are assumed to be rank-0 only - confirm.
        if self.FLAGS.local_rank == 0:
            print(f"use flag shading {FLAGS.dataset_flat_shading}")
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())
        # Parenthesised so that the warning stays rank-0 only for BOTH axes.
        texture_too_small = (FLAGS.texture_res[0] < ref_texture_res[0]
                             or FLAGS.texture_res[1] < ref_texture_res[1])
        if self.FLAGS.local_rank == 0 and texture_too_small:
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        print("Loading env map")
        # Load environment map texture
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

        print("Computing tangents")
        try:
            self.ref_mesh = mesh.compute_tangents(ref_mesh)
        except Exception:
            # Best effort: fall back to the mesh without tangents.
            print("Continue without tangents...")
            self.ref_mesh = ref_mesh

    def _rotate_scene(self, itr):
        """Camera for validation view `itr`: one full turn every 50 views."""
        proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # Smooth rotation for display.
        ang = (itr / 50) * np.pi * 2
        mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp

    def _random_scene(self):
        """Camera for a training view with a random rotation/translation."""
        # Setup projection matrix
        iter_res = self.FLAGS.train_res
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # Random rotation/translation matrix for optimization.
        mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.2)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp  # Add batch dimension

    def __len__(self):
        return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch

    def __getitem__(self, itr):
        """Render one view and return its buffers together with the camera."""
        # Randomize scene parameters
        if self.validate:
            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)
            camera_mv = None
        else:
            mv, mvp, campos, iter_res, iter_spp = self._random_scene()
            if self.random_lgt:
                rnd_rot = util.random_rotation()
                camera_mv = rnd_rot.unsqueeze(0).clone()
            elif self.camera_lgt:
                camera_mv = mv.clone()
            else:
                camera_mv = None

        with torch.no_grad():
            render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp,
                                            num_layers=self.FLAGS.layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=self.flat_shading)

        # 50000 surface samples of the reference mesh; detach().clone() makes
        # the same graph-free copy that torch.tensor(tensor) produced.
        sample_points = kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, 50000)[0][0].detach().clone()
        vertex_points = self.ref_mesh.v_pos

        return_dict = {
            'mv': mv,
            'mvp': mvp,
            'campos': campos,
            'resolution': iter_res,
            'spp': iter_spp,
            'img': render_out['shaded'],
            'img_second': render_out['shaded_second'],
            'spts': sample_points,
            'vpts': vertex_points,
            'faces': self.ref_mesh.t_pos_idx,
            'depth': render_out['depth'],
            'normal': render_out['normal'],
            'geo_normal': render_out['geo_normal'],
            'geo_viewdir': render_out['geo_viewdir'],
            'pos': render_out['pos'],
            'envlight_transform': camera_mv,
            'mask': render_out['mask'],
            'mask_cont': render_out['mask_cont'],
            'rast_triangle_id': render_out['rast_triangle_id'],
        }
        return_dict['depth_second'] = render_out['depth_second']
        return_dict['normal_second'] = render_out['normal_second']
        return return_dict

@ -0,0 +1,32 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
import os
import json
# ShapeNet Dataset to get meshes
class ShapeNetDataset(object):
    """List-like view over the mesh entries stored in a JSON file."""

    def __init__(self, shapenet_json):
        """Load the mesh list.

        Args:
          shapenet_json: Path to a JSON file holding the list of meshes.
        """
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(shapenet_json, 'r', encoding='utf-8') as f:
            self.mesh_list = json.load(f)
        print(f"len all data: {len(self.mesh_list)}")

    def __len__(self):
        return len(self.mesh_list)

    def __getitem__(self, idx):
        return self.mesh_list[idx]

@ -0,0 +1,462 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
from ..render import mesh
from ..render import render
from ..render import regularizer
import kaolin
import pytorch3d.ops
from ..render import util as render_utils
import torch.nn.functional as F
from ..render import renderutils as ru
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
class DMTet:
    """Differentiable marching tetrahedra: extracts a triangle mesh from per-vertex SDF values."""

    def __init__(self, device='cuda'):
        """Build the lookup tables.

        Args:
          device: Device for the tables and for the index tensors created
            during extraction. Defaults to 'cuda'.
        """
        self.device = device
        # Row = 4-bit occupancy code of a tet; entries index its 6 edges,
        # -1 marks unused slots.
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device=device)
        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)
        # The 6 edges of a tet as consecutive (start, end) local vertex ids.
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)

    # Utility functions
    def sort_edges(self, edges_ex2):
        """Order each edge so that the smaller vertex id comes first."""
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        """Assign each face a triangle inside a per-tet cell of an N x N UV grid.

        Args:
          faces: Unused; kept for interface compatibility.
          face_gidx: Global face index per triangle (two slots per tet).
          max_idx: Upper bound on the global face index; sets the grid size.

        Returns:
          `(uvs, uv_idx)`: UV coordinates (4 per cell) and per-face indices.
        """
        N = int(np.ceil(np.sqrt((max_idx + 1) // 2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device),
            indexing='ij'
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim=-1).view(-1, 3)

        return uvs, uv_idx

    # Marching tets implementation
    def __call__(self, pos_nx3, sdf_n, tet_fx4):
        """Extract the surface mesh.

        Args:
          pos_nx3: Vertex positions, indexed as (num_vertices, 3).
          sdf_n: One SDF value per vertex; values > 0 count as occupied.
          tet_fx4: Vertex indices of each tetrahedron, (num_tets, 4).

        Returns:
          `(verts, faces, uvs, uv_idx, face_to_valid_tet, valid_vert_idx)`.
        """
        # Topology (integer/boolean work) needs no gradients.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Keep tets that are neither fully inside nor fully outside.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # Edges with exactly one occupied endpoint cross the surface.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=self.device)
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Interpolation stays outside no_grad so that `verts` can carry
        # gradients back to positions and SDF values.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=self.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
        ), dim=0)

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=self.device)[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[num_triangles == 1] * 2,
            torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)

        face_to_valid_tet = torch.cat((
            tet_gidx[num_triangles == 1],
            torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
        ), dim=0)

        valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()

        return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx
# Regularizer
def sdf_reg_loss(sdf, all_edges):
    """Cross-entropy regularizer over edges whose endpoint SDF signs differ.

    For every such edge, each endpoint value is used as a logit against the
    occupancy (`> 0`) of the other endpoint; the two mean losses are summed.

    Args:
      sdf: Per-vertex SDF values.
      all_edges: Vertex index pairs; flattened and regrouped as (-1, 2).

    Returns:
      Scalar loss; zero when no edge has differing signs.
    """
    pairs = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    sign_change = torch.sign(pairs[..., 0]) != torch.sign(pairs[..., 1])
    pairs = pairs[sign_change]
    if pairs.shape[0] == 0:
        # A mean-reduced loss over zero elements is NaN; an empty sum is a
        # zero that is still connected to `sdf` in the autograd graph.
        return pairs.sum()
    first, second = pairs[..., 0], pairs[..., 1]
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    return bce(first, (second > 0).float()) + bce(second, (first > 0).float())
class Buffer(object):
    """Fixed-capacity ring buffer of tensors with a sign-majority average."""

    def __init__(self, shape, capacity, device) -> None:
        """Allocate storage for `capacity` entries of the given `shape`."""
        self.len_curr = 0      # number of valid entries (at most capacity)
        self.pointer = 0       # slot the next push writes to
        self.capacity = capacity
        self.buffer = torch.zeros((capacity, ) + shape, device=device)

    def push(self, x):
        """Push one single data point into the buffer.

        Once the buffer is full the oldest entry is overwritten.
        """
        self.buffer[self.pointer] = x
        self.pointer = (self.pointer + 1) % self.capacity
        if self.len_curr < self.capacity:
            self.len_curr += 1

    def avg(self):
        """Return the sign of the mean sign over the stored entries."""
        # simple windowed avg without exp decay
        return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float()
# Geometry interface
class DMTetGeometry(torch.nn.Module):
def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=1.0, **kwargs):
    """Load the tetrahedral grid and create the learnable SDF and deformation.

    Args:
      grid_res: Grid resolution; selects `data/tets/{grid_res}_tets_cropped.npz`.
      scale: Factor applied to the loaded vertex positions.
      FLAGS: Run options; stored because `tick` reads `self.FLAGS.iter`.
      root: Directory that contains the `data/tets` folder.
      grid_to_tet: Stored as-is on the instance.
      deform_scale: Multiplier applied to vertex offsets in `get_deformed`.
      **kwargs: Ignored.
    """
    super(DMTetGeometry, self).__init__()

    self.FLAGS = FLAGS
    self.grid_res = grid_res
    self.marching_tets = DMTet()
    self.tanh = False
    self.deform_scale = deform_scale
    self.grid_to_tet = grid_to_tet

    # Box kernel of side 2 * padding + 1, used with F.conv2d in `tick`.
    self.padding = 5
    self.smooth_kernel = torch.ones(1, 1, self.padding*2 + 1, self.padding*2 + 1).cuda()

    tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res)))
    self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
    self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')

    # Random init
    sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1

    # Assigning an nn.Parameter attribute already registers it on the module,
    # so no separate register_parameter call is needed.
    self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
    self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)

    # Non-trainable copies refreshed by `update_ema`.
    self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False)
    self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
    self.ema_coeff = 0.9

    self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda')
def generate_edges(self):
    """Compute the unique undirected edges of the tet grid into `self.all_edges`.

    Each edge is stored once, as a row with the smaller vertex id first.
    """
    with torch.no_grad():
        # Local (start, end) vertex ids of the 6 edges of a tet, created on
        # the same device as the tet indices they select from.
        edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=self.indices.device)
        all_edges = self.indices[:, edges].reshape(-1, 2)
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)
def getAABB(self):
    """Return the per-axis minimum and maximum of the undeformed grid vertices."""
    lower = self.verts.min(dim=0).values
    upper = self.verts.max(dim=0).values
    return lower, upper
def getVertNNDist(self):
    """Return, for every tanh-deformed vertex, the k-NN distance value to its nearest other vertex (detached)."""
    offset = 2 / (self.grid_res * 2) * torch.tanh(self.deform)
    v_deformed = (self.verts + offset).unsqueeze(0)
    neighbours = pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2)
    # K=2 because dist(self, self)=0: the last column is the nearest other vertex.
    nearest = neighbours.dists[0, :, -1]
    return nearest.detach()
def getTetCenters(self):
    """Return the centroid of every tetrahedron of the deformed grid (M x 3)."""
    deformed = self.get_deformed()          # size: N x 3
    corners = deformed[self.indices]        # size: M x 4 x 3
    return corners.mean(dim=1)              # size: M x 3
def getValidTetIdx(self):
    """Run marching tets on the current state and return the tet index of each extracted face."""
    # Only the fifth output (face -> tet index) is needed here.
    outputs = self.marching_tets(self.get_deformed(), self.sdf, self.indices)
    tet_gidx = outputs[4]
    return tet_gidx.long()
def getValidVertsIdx(self):
    """Return the unique grid-vertex ids of all tets that produced a face."""
    # Run DM tet to find which tets contribute faces.
    outputs = self.marching_tets(self.get_deformed(), self.sdf, self.indices)
    face_tets = outputs[4].long()
    return self.indices[face_tets].unique()
def getMesh(self, material, noise=0.0, ema=False):
    """Extract the current surface as a mesh.

    Args:
      material: Material attached to the mesh; tangents are only computed
        when it is not `None`.
      noise: Unused; kept for interface compatibility.
      ema: If `True`, extract from `self.sdf_ema` instead of `self.sdf`.

    Returns:
      The mesh, with the contributing grid-vertex ids in `valid_vert_idx`.
    """
    # Run DM tet to get a base mesh
    v_deformed = self.get_deformed(ema=ema)
    if ema:
        sdf = self.sdf_ema
    else:
        sdf = self.sdf
    verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
    imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

    imesh = mesh.auto_normals(imesh)
    if material is not None:
        # Run mesh operations to generate tangent space
        imesh = mesh.compute_tangents(imesh)
    # NOTE(review): assumed to be set regardless of `material` - confirm.
    imesh.valid_vert_idx = valid_vert_idx
    return imesh
def get_deformed(self, no_grad=False, ema=False):
    """Return the grid vertices displaced by the learned deformation.

    The offset is `2 / (grid_res * 2) * d * deform_scale`, where `d` is the
    deformation, passed through `tanh` first when `self.tanh` is set.

    Args:
      no_grad: If `True`, the deformation is detached from the graph.
      ema: Accepted for interface compatibility; not used.
        NOTE(review): `self.deform_ema` is never read here - confirm.
    """
    if no_grad:
        deform = self.deform.detach()
    else:
        deform = self.deform

    if self.tanh:
        v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale
    else:
        v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale
    return v_deformed
def get_angle(self):
    """Compute per-tet, per-vertex direction data against the opposite face.

    For each of the four (face, opposite vertex) combinations of every tet,
    the opposite vertex position has its component along a face reference
    direction removed; the sign of its dot product with that remainder, the
    remainder itself and the vertex id are recorded. Runs without gradients.

    Returns:
      `(directions, dir_vec, vert_inds)` with shapes (M, 4), (M, 4, 3) and
      (M, 4) for M tets.
    """
    with torch.no_grad():
        # (face vertex slots..., opposite vertex slot)
        comb_list = [
            (0, 1, 2, 3),
            (0, 1, 3, 2),
            (0, 2, 3, 1),
            (1, 2, 3, 0)
        ]

        device = self.indices.device
        num_tets = self.indices.size(0)
        directions = torch.zeros(num_tets, 4, device=device)
        dir_vec = torch.zeros(num_tets, 4, 3, device=device)
        vert_inds = torch.zeros(num_tets, 4, dtype=torch.long, device=device)
        vpos_list = self.get_deformed()
        for count, comb in enumerate(comb_list):
            face = self.indices[:, comb[:3]]
            face_pos = vpos_list[face, :]
            face_center = face_pos.mean(1, keepdim=False)
            v = self.indices[:, comb[3]]
            test_vec = vpos_list[v]
            # Unit vector from the face centre to the face's first vertex.
            ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center)
            distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec
            directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0])
            dir_vec[:, count, :] = distance_vec
            vert_inds[:, count] = v

    return directions, dir_vec, vert_inds
def clamp_deform(self):
    """Clamp the learnable deformation and SDF values in place.

    The deformation is limited to [-0.99, 0.99] only when `self.tanh` is off;
    the SDF is limited to [-1.0, 1.0].
    """
    if not self.tanh:
        self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)
    # NOTE(review): the SDF clamp is assumed to be unconditional; its nesting
    # relative to the `if` above could not be recovered - confirm.
    self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0)
def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
    """Extract the current mesh and render it with the camera stored in `target`.

    When `get_visible_tets` is set, the tet centres are forwarded to the
    renderer as `tet_centers`; otherwise `None` is passed.
    """
    opt_mesh = self.getMesh(opt_material, ema=ema)
    tet_centers = None
    if get_visible_tets:
        tet_centers = self.getTetCenters()
    return render.render_mesh(
        glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers)
def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None):
    """Extract the current mesh, render it, and return `(mesh, render buffers)`."""
    opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema)
    buffers = render.render_mesh(
        glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt)
    return opt_mesh, buffers
def update_ema(self, ema_coeff=0.9):
    """Refresh the "EMA" copies of the SDF and the deformation.

    The SDF copy is overwritten with ``self.sdf_buffer.avg()`` and the
    deformation copy with the current deformation values. ``ema_coeff``
    is accepted for interface compatibility but is not used here.
    """
    averaged_sdf = self.sdf_buffer.avg()
    self.sdf_ema.data[:] = averaged_sdf
    current_deform = self.deform.data[:]
    self.deform_ema.data[:] = current_deform
def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
    """Render the mesh extracted from the EMA SDF (``getMesh(..., ema=True)``).

    Returns:
        Whatever ``render.render_mesh`` returns.
    """
    ema_mesh = self.getMesh(opt_material, ema=True)
    return render.render_mesh(
        glctx, ema_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt,
    )
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
    """Run one optimization step's forward pass and return its losses.

    Args:
        glctx: rasterizer context forwarded to the renderer.
        target: dict with the reference data read below ('mvp', 'resolution',
            'mask_cont', 'img', 'img_second', 'depth', 'depth_second', 'spts', ...).
        lgt: light object forwarded to the renderer.
        opt_material: material used to build the mesh.
        loss_fn: callable ``(prediction, reference) -> scalar`` for the color loss.
        iteration: current iteration number; drives the schedules below.
        with_reg: accepted for interface compatibility; not used here.
        xfm_lgt: optional light transform forwarded to the renderer.
        no_depth_thin: when true, the depth losses are masked by depth
            validity and the second-layer loss additionally by the
            first/second layer separation.

    Returns:
        Tuple ``(img_loss, reg_loss)``.
    """
    self.deform.requires_grad = True

    # Every 20th iteration in (200, 2000): project grid vertices into each
    # view and reset those that land where the smoothed target mask is zero.
    if iteration > 200 and iteration < 2000 and iteration % 20 == 0:
        with torch.no_grad():
            v_pos = self.get_deformed()
            v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp'])
            v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:]
            v_pos_camera_discrete = torch.round(
                (v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)
            ).long()
            mask_cont = F.conv2d(
                target['mask_cont'][:, :, :, 0].unsqueeze(1),
                self.smooth_kernel, stride=1, padding=self.padding,
            )[:, 0]
            target_mask = mask_cont == 0
            for k in range(target_mask.size(0)):
                assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0]
                v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0))
                self.sdf.data[v_mask] = 1e-2
                self.deform.data[v_mask] = 0.0

    # ==============================================================================================
    #  Render optimizable object with identical conditions
    # ==============================================================================================
    imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt)

    # ==============================================================================================
    #  Compute loss
    # ==============================================================================================

    # Image-space loss, split into a coverage component and a color component
    color_ref = target['img']
    alpha_scale = 1.0
    img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) * alpha_scale
    img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])

    # Same two terms for the second rendered layer, down-weighted by 0.1.
    color_ref_second = target['img_second']
    img_loss = img_loss + torch.nn.functional.mse_loss(buffers['shaded_second'][..., 3:], color_ref_second[..., 3:]) * alpha_scale * 1e-1
    img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_second[..., 3:], color_ref_second[..., 0:3] * color_ref_second[..., 3:]) * 1e-1

    mask = (target['mask_cont'][:, :, :, 0] == 1.0).float()

    if iteration < 10000:
        depth_scale = 100.0
    else:
        depth_scale = 1.0

    # Periodically shrink the deformation during the early iterations.
    if iteration % 300 == 0 and iteration < 1790:
        self.deform.data[:] *= 0.4

    if no_depth_thin:
        valid_depth_mask = (target['depth_second'] >= 0).float().detach()
        depth_prox_mask = ((target['depth_second'] - target['depth']).abs() >= 5e-3).float().detach()
    else:
        valid_depth_mask = 1.0
        # Previously left undefined on this branch (NameError below).
        depth_prox_mask = 1.0

    depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
    depth_diff_second = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * depth_prox_mask * 1e-1

    # L1 below the threshold, shifted quadratic above it.
    thres = 1.0
    l1_loss_mask = (depth_diff < thres).float()
    l1_loss_mask_second = (depth_diff_second < thres).float()
    depth_loss = (
        l1_loss_mask * depth_diff
        + (1 - l1_loss_mask) * (depth_diff.pow(2) + thres - thres**2)
    ).mean() * 1.0 * depth_scale
    depth_loss_second = (
        l1_loss_mask_second * depth_diff_second
        + (1 - l1_loss_mask_second) * (depth_diff_second.pow(2) + thres - thres**2)
    ).mean() * 1.0 * depth_scale
    # The second term used to start a new statement with a unary "+", so it
    # was evaluated and thrown away; it is now really part of the loss.
    img_loss = img_loss + depth_loss + depth_loss_second

    # SDF regularizer
    iter_thres = 0
    sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres)))

    # Detach the SDF at vertices that belong to surface-crossing tets so the
    # regularizer only produces gradients for the remaining vertices.
    sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device)
    sdf_mask[imesh.valid_vert_idx] = 1.0
    sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask)
    reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 0.1  # Dropoff to 0.01

    # Albedo (k_d) smoothness regularizer
    reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)

    # Visibility regularizer
    reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500)

    # Point-cloud chamfer distance between sampled surface points and target['spts']
    pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
    target_pts = target['spts']
    chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
    reg_loss += chamfer

    return img_loss, reg_loss

@ -0,0 +1,350 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
from ..render import mesh
from ..render import render
from ..render import regularizer
import kaolin
from ..render import util as render_utils
import torch.nn.functional as F
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
class DMTet:
    """Differentiable marching tetrahedra.

    Args:
        device: device that holds the lookup tables. Defaults to ``'cuda'``
            (the previous hard-coded value). Inputs passed to ``__call__``
            must live on the same device.
    """

    def __init__(self, device='cuda'):
        # Row i lists the edge slots (0..5) forming the triangles of
        # occupancy case i; -1 pads unused entries.
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device=device)
        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)
        # Vertex-slot pairs of the six edges of a tetrahedron, flattened.
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)

    # Utility functions
    def sort_edges(self, edges_ex2):
        """Order the two endpoints of every edge ascending.

        Args:
            edges_ex2: (E, 2) integer tensor.

        Returns:
            (E, 1, 2) tensor with the smaller endpoint first.
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        """Assign every face a triangle inside its own cell of an N x N UV grid.

        Args:
            faces: unused; kept for interface compatibility.
            face_gidx: (F,) global face ids (two slots per tetrahedron).
            max_idx: upper bound on the global face ids.

        Returns:
            uvs: (4 * N * N, 2) corner coordinates of the grid cells.
            uv_idx: (F, 3) indices into ``uvs`` for each face.
        """
        device = face_gidx.device
        N = int(np.ceil(np.sqrt((max_idx + 1) // 2)))
        # Broadcasting equivalent of torch.meshgrid(lin, lin) with 'ij'
        # indexing; avoids the implicit-indexing deprecation warning.
        lin = torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=device)
        tex_y = lin[:, None].expand(N, N)
        tex_x = lin[None, :].expand(N, N)

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim=-1).view(-1, 3)

        return uvs, uv_idx

    # Marching tets implementation
    def __call__(self, pos_nx3, sdf_n, tet_fx4, get_tet_gidx=False):
        """Extract a triangle mesh from an SDF sampled on a tetrahedral grid.

        Args:
            pos_nx3: (N, 3) vertex positions.
            sdf_n: (N,) SDF values; ``> 0`` counts as occupied.
            tet_fx4: (F, 4) vertex indices of the tetrahedra.
            get_tet_gidx: also return, per face, the index of its tetrahedron.

        Returns:
            ``(verts, faces, uvs, uv_idx)`` and, when ``get_tet_gidx`` is
            true, a fifth long tensor with the tetrahedron id of each face.
        """
        device = tet_fx4.device

        # The mesh connectivity is discrete, so it is computed without
        # gradient tracking; only the interpolation below is differentiable.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Place one vertex on every sign-changing edge by linear interpolation.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
        ), dim=0)

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=device)[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[num_triangles == 1] * 2,
            torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)

        if get_tet_gidx:
            face_to_valid_tet = torch.cat((
                tet_gidx[num_triangles == 1],
                torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
            ), dim=0)
            return verts, faces, uvs, uv_idx, face_to_valid_tet.long()
        return verts, faces, uvs, uv_idx
# Regularizer
def sdf_reg_loss(sdf, all_edges):
    """Symmetric cross-entropy penalty on edges whose endpoint SDF signs differ.

    Args:
        sdf: (N,) SDF values, used as logits.
        all_edges: integer tensor of vertex index pairs.

    Returns:
        Scalar tensor: BCE-with-logits of each endpoint against the
        occupancy (``> 0``) of the other endpoint, summed over both
        directions (each term mean-reduced over the selected edges).
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    edge_sdf = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    sdf_a, sdf_b = edge_sdf[..., 0], edge_sdf[..., 1]
    # Only edges that cross the surface contribute.
    crossing = torch.sign(sdf_a) != torch.sign(sdf_b)
    sdf_a, sdf_b = sdf_a[crossing], sdf_b[crossing]
    return bce(sdf_a, (sdf_b > 0).float()) + bce(sdf_b, (sdf_a > 0).float())
# Geometry interface
class DMTetGeometryFixedTopo(torch.nn.Module):
    """DMTet geometry whose SDF signs are frozen; only the vertex deformation is optimized.

    The SDF signs are copied from an existing geometry (``dmt_geometry``)
    and the SDF magnitude is fixed to one, so the extracted connectivity
    is determined by the copied signs.
    """

    def __init__(self, dmt_geometry, base_mesh, grid_res, scale, FLAGS, deform_scale=1.0, **kwargs):
        super(DMTetGeometryFixedTopo, self).__init__()

        # Stored because render_with_mesh() and tick() read
        # self.FLAGS.layers / .iter / .laplace_scale.
        self.FLAGS = FLAGS
        self.grid_res = grid_res
        self.marching_tets = DMTet()
        self.initial_guess = base_mesh
        self.scale = scale
        self.tanh = False
        self.deform_scale = deform_scale

        tets = np.load('./data/tets/{}_tets_cropped.npz'.format(self.grid_res))
        self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
        self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')

        self.sdf_sign = torch.nn.Parameter(torch.sign(dmt_geometry.sdf.data + 1e-8).float(), requires_grad=False)
        self.sdf_sign.data[self.sdf_sign.data == 0] = 1.0  ## Avoid ambiguity
        self.register_parameter('sdf_sign', self.sdf_sign)

        self.sdf_abs = torch.nn.Parameter(torch.ones_like(dmt_geometry.sdf), requires_grad=False)
        self.register_parameter('sdf_abs', self.sdf_abs)

        self.deform = torch.nn.Parameter(dmt_geometry.deform.data, requires_grad=True)
        self.register_parameter('deform', self.deform)

        self.sdf_abs_ema = torch.nn.Parameter(self.sdf_abs.clone().detach(), requires_grad=False)
        self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)

    def set_init_v_pos(self):
        """Cache the current mesh vertices as ``initial_guess_v_pos`` (read by ``tick``)."""
        with torch.no_grad():
            v_deformed = self.get_deformed()
            verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices)
            self.initial_guess_v_pos = verts

    def generate_edges(self):
        """Store the unique, endpoint-sorted edges of the tet grid in ``self.all_edges``."""
        with torch.no_grad():
            edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=self.indices.device)
            all_edges = self.indices[:, edges].reshape(-1, 2)
            all_edges_sorted = torch.sort(all_edges, dim=1)[0]
            self.all_edges = torch.unique(all_edges_sorted, dim=0)

    def getAABB(self):
        """Return ``(min, max)`` of the undeformed grid vertices along each axis."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def getVertNNDist(self):
        """Not supported for the fixed-topology geometry.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def getMesh(self, material):
        """Extract the mesh, attach ``material`` and compute normals and tangents."""
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices)
        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        # Run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh

    def getMesh_tet_gidx(self, material):
        """Like ``getMesh`` but also return the tetrahedron index of every face."""
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
            v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        # Run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh, tet_gidx

    def update_ema(self, ema_coeff=0.9):
        """No-op; present so callers can treat all geometries uniformly."""
        return

    def get_deformed(self):
        """Return grid vertices displaced by the (optionally tanh-squashed) deformation."""
        if self.tanh:
            v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) * self.deform_scale
        else:
            v_deformed = self.verts + 2 / (self.grid_res * 2) * self.deform * self.deform_scale
        return v_deformed

    def getValidTetIdx(self):
        """Return, per extracted face, the index of the tetrahedron it came from."""
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
            v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
        return tet_gidx.long()

    def getValidVertsIdx(self):
        """Return the unique grid vertex ids of all tetrahedra that produced faces."""
        # Run DM tet to get a base mesh
        v_deformed = self.get_deformed()
        verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets(
            v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True)
        return self.indices[tet_gidx.long()].unique()

    def getTetCenters(self):
        """Return the centroid of every tetrahedron of the deformed grid."""
        v_deformed = self.get_deformed()  # size: N x 3
        face_verts = v_deformed[self.indices]  # size: M x 4 x 3
        face_centers = face_verts.mean(dim=1)  # size: M x 3
        return face_centers

    def clamp_deform(self):
        """Clamp the deformation to [-0.99, 0.99] in place when tanh is disabled."""
        if not self.tanh:
            self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99)

    def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
        """Render the current mesh; ``ema`` is accepted but not used by this class."""
        opt_mesh = self.getMesh(opt_material)
        tet_centers = self.getTetCenters() if get_visible_tets else None
        return render.render_mesh(
            glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
            msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers)

    def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
        """Render with ``FLAGS.layers`` layers and return ``(mesh, render_output)``."""
        opt_mesh = self.getMesh(opt_material)
        return opt_mesh, render.render_mesh(
            glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
            num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt)

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
        """Run one forward pass and return ``(img_loss, reg_loss)``.

        ``set_init_v_pos()`` must have been called before, because the
        Laplacian regularizer reads ``self.initial_guess_v_pos``.
        ``with_reg`` is accepted for interface compatibility and unused.
        """
        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, xfm_lgt=xfm_lgt)

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        t_iter = iteration / self.FLAGS.iter

        # Image-space loss, split into a coverage component and a color component
        color_ref = target['img']
        img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
        img_loss = img_loss + loss_fn(
            buffers['shaded'][..., 0:3] * color_ref[..., 3:],
            color_ref[..., 0:3] * color_ref[..., 3:]
        )

        mask = target['mask'][:, :, :, 0]

        if no_depth_thin:
            valid_depth_mask = (
                (target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float()
            ).detach()
        else:
            valid_depth_mask = 1.0

        # NOTE(review): the first-layer difference below is overwritten by the
        # second-layer one before it is used, so only the second layer
        # contributes to the loss. Kept as-is; confirm whether both were intended.
        depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
        depth_diff = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * 1e-1
        # L1 below 1.0, quadratic above.
        l1_loss_mask = (depth_diff < 1.0).float()
        img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0

        reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda")

        # Compute regularizer.
        reg_loss += regularizer.laplace_regularizer_const(imesh.v_pos - self.initial_guess_v_pos, imesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter) * 1e-2

        ### Chamfer distance for ShapeNet
        pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
        target_pts = target['spts']
        chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
        reg_loss += chamfer

        return img_loss, reg_loss

@ -0,0 +1,516 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
from ..render import mesh
from ..render import render
from ..render import regularizer
import kaolin
import pytorch3d.ops
from ..render import util as render_utils
import torch.nn.functional as F
from ..render import renderutils as ru
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
class DMTet:
    """Differentiable marching tetrahedra.

    Args:
        device: device that holds the lookup tables. Defaults to ``'cuda'``
            (the previous hard-coded value). Inputs passed to ``__call__``
            must live on the same device.
    """

    def __init__(self, device='cuda'):
        # Row i lists the edge slots (0..5) forming the triangles of
        # occupancy case i; -1 pads unused entries.
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device=device)
        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device=device)
        # Vertex-slot pairs of the six edges of a tetrahedron, flattened.
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=device)

    # Utility functions
    def sort_edges(self, edges_ex2):
        """Order the two endpoints of every edge ascending.

        Args:
            edges_ex2: (E, 2) integer tensor.

        Returns:
            (E, 1, 2) tensor with the smaller endpoint first.
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        """Assign every face a triangle inside its own cell of an N x N UV grid.

        Args:
            faces: unused; kept for interface compatibility.
            face_gidx: (F,) global face ids (two slots per tetrahedron).
            max_idx: upper bound on the global face ids.

        Returns:
            uvs: (4 * N * N, 2) corner coordinates of the grid cells.
            uv_idx: (F, 3) indices into ``uvs`` for each face.
        """
        device = face_gidx.device
        N = int(np.ceil(np.sqrt((max_idx + 1) // 2)))
        # Broadcasting equivalent of torch.meshgrid(lin, lin) with 'ij'
        # indexing; avoids the implicit-indexing deprecation warning.
        lin = torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=device)
        tex_y = lin[:, None].expand(N, N)
        tex_x = lin[None, :].expand(N, N)

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim=-1).view(-1, 3)

        return uvs, uv_idx

    # Marching tets implementation
    def __call__(self, pos_nx3, sdf_n, tet_fx4):
        """Extract a triangle mesh from an SDF sampled on a tetrahedral grid.

        Args:
            pos_nx3: (N, 3) vertex positions.
            sdf_n: (N,) SDF values; ``> 0`` counts as occupied.
            tet_fx4: (F, 4) vertex indices of the tetrahedra.

        Returns:
            ``(verts, faces, uvs, uv_idx, face_to_valid_tet, valid_vert_idx)``:
            mesh vertices and faces, UV coordinates and UV face indices,
            the tetrahedron id of every face, and the unique grid vertex
            ids of all surface-crossing tetrahedra.
        """
        device = tet_fx4.device

        # The mesh connectivity is discrete, so it is computed without
        # gradient tracking; only the interpolation below is differentiable.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Place one vertex on every sign-changing edge by linear interpolation.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(
                input=idx_map[num_triangles == 1],
                dim=1,
                index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]
            ).reshape(-1, 3),
            torch.gather(
                input=idx_map[num_triangles == 2],
                dim=1,
                index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]
            ).reshape(-1, 3),
        ), dim=0)

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=device)[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[num_triangles == 1] * 2,
            torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)

        face_to_valid_tet = torch.cat((
            tet_gidx[num_triangles == 1],
            torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1)
        ), dim=0)

        valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique()

        return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx
# Regularizer
def sdf_reg_loss(sdf, all_edges):
    """Symmetric cross-entropy penalty on edges whose endpoint SDF signs differ.

    Args:
        sdf: (N,) SDF values, used as logits.
        all_edges: integer tensor of vertex index pairs.

    Returns:
        Scalar tensor: BCE-with-logits of each endpoint against the
        occupancy (``> 0``) of the other endpoint, summed over both
        directions (each term mean-reduced over the selected edges).
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    edge_sdf = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    sdf_a, sdf_b = edge_sdf[..., 0], edge_sdf[..., 1]
    # Only edges that cross the surface contribute.
    crossing = torch.sign(sdf_a) != torch.sign(sdf_b)
    sdf_a, sdf_b = sdf_a[crossing], sdf_b[crossing]
    return bce(sdf_a, (sdf_b > 0).float()) + bce(sdf_b, (sdf_a > 0).float())
class Buffer(object):
    """Fixed-capacity ring buffer of equally shaped tensors.

    Args:
        shape: shape (tuple or ``torch.Size``) of one stored item.
        capacity: maximum number of items kept; older ones are overwritten.
        device: device of the backing storage.
    """

    def __init__(self, shape, capacity, device) -> None:
        self.capacity = capacity
        self.pointer = 0   # slot that the next push() writes to
        self.len_curr = 0  # number of slots filled so far
        storage_shape = (capacity, ) + shape
        self.buffer = torch.zeros(storage_shape, device=device)

    def push(self, x):
        '''
        Push one single data point into the buffer
        '''
        slot = self.pointer
        self.buffer[slot] = x
        self.pointer = (slot + 1) % self.capacity
        self.len_curr = min(self.len_curr + 1, self.capacity)

    def avg(self):
        """Return the element-wise majority sign (-1, 0 or +1) of the stored items."""
        stored = self.buffer[:self.len_curr]
        mean_sign = torch.sign(stored).float().mean(dim=0)
        return torch.sign(mean_sign).float()
# Geometry interface
class DMTetGeometry(torch.nn.Module):
def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=2.0, **kwargs):
    """Load the tetrahedral grid and create the optimizable SDF and deformation.

    Args:
        grid_res: grid resolution; selects the ``.npz`` file under ``data/tets``.
        scale: factor applied to the loaded vertex positions.
        FLAGS: run configuration, stored on the instance as ``self.FLAGS``.
        root: directory that contains ``data/tets``.
        grid_to_tet: stored as-is on the instance.
        deform_scale: multiplier applied to the deformation in ``get_deformed``.
        **kwargs: ignored.
    """
    super(DMTetGeometry, self).__init__()

    # Previously accepted but dropped; kept so methods can read the
    # configuration from the instance.
    self.FLAGS = FLAGS
    self.grid_res = grid_res
    self.marching_tets = DMTet()
    self.cropped = True
    self.tanh = False
    self.deform_scale = deform_scale
    self.grid_to_tet = grid_to_tet

    if self.cropped:
        print("use cropped tets")
        tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res)))
    else:
        tets = np.load(os.path.join(root, 'data/tets/{}_tets.npz'.format(self.grid_res)))
    print('tet min and max', tets['vertices'].min() * scale, tets['vertices'].max() * scale)
    self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale
    self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')

    # Random init: uniform values shifted by -0.1 (the clamp is a no-op on
    # the [0, 1) samples and is kept unchanged).
    sdf = torch.rand_like(self.verts[:, 0]).clamp(-1.0, 1.0) - 0.1

    self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
    self.register_parameter('sdf', self.sdf)

    self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
    self.register_parameter('deform', self.deform)

    self.alpha = None

    self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False)
    self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False)
    self.ema_coeff = 0.9

    # History of SDF snapshots; its sign vote feeds update_ema().
    self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda')
def generate_edges(self):
    """Store the unique, endpoint-sorted edges of the tet grid in ``self.all_edges``.

    The result is an (E, 2) long tensor on the same device as
    ``self.indices`` (previously the helper index tensor was always
    created on "cuda").
    """
    with torch.no_grad():
        # Vertex-slot pairs of the six edges of a tetrahedron, flattened.
        edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device=self.indices.device)
        all_edges = self.indices[:, edges].reshape(-1, 2)
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)
def getAABB(self):
    """Return ``(min, max)`` of the undeformed grid vertices along each axis."""
    lower = self.verts.min(dim=0).values
    upper = self.verts.max(dim=0).values
    return lower, upper
def getVertNNDist(self):
    """Return, per grid vertex, the ``dists`` entry of its nearest other vertex.

    NOTE(review): the positions used here always apply ``tanh`` and omit
    ``deform_scale``, unlike ``get_deformed`` - confirm this is intended.
    """
    cell_step = 2 / (self.grid_res * 2)
    v_deformed = (self.verts + cell_step * torch.tanh(self.deform)).unsqueeze(0)
    # K=2 because dist(self, self)=0
    knn = pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2)
    return knn.dists[0, :, -1].detach()
def getTetCenters(self):
    """Return the centroid of every tetrahedron of the deformed grid (M x 3)."""
    deformed = self.get_deformed()       # N x 3
    tet_corners = deformed[self.indices]  # M x 4 x 3
    return tet_corners.mean(dim=1)
def getValidTetIdx(self):
    """Run marching tets on the current state and return its fifth output
    (the tetrahedron index of every extracted face) as a long tensor."""
    deformed = self.get_deformed()
    _, _, _, _, tet_gidx, _ = self.marching_tets(deformed, self.sdf, self.indices)
    return tet_gidx.long()
def getValidVertsIdx(self):
    """Return the unique grid vertex ids of all tetrahedra that produced faces."""
    deformed = self.get_deformed()
    _, _, _, _, tet_gidx, _ = self.marching_tets(deformed, self.sdf, self.indices)
    surface_tets = self.indices[tet_gidx.long()]
    return surface_tets.unique()
def getMesh(self, material, noise=0.0, ema=False):
    """Extract the current surface as a ``mesh.Mesh``.

    Args:
        material: material attached to the mesh; when ``None`` the
            normal/tangent computation is skipped.
        noise: accepted for interface compatibility; not used here.
        ema: use ``self.sdf_ema`` instead of ``self.sdf``.

    Returns:
        The mesh, with ``valid_vert_idx`` (sixth marching-tets output)
        attached as an attribute.
    """
    # Run DM tet to get a base mesh
    v_deformed = self.get_deformed(ema=ema)
    sdf = self.sdf_ema if ema else self.sdf
    verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
    imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

    if material is not None:
        # Run mesh operations to generate tangent space
        imesh = mesh.compute_tangents(mesh.auto_normals(imesh))

    imesh.valid_vert_idx = valid_vert_idx
    return imesh
def getMesh_no_deform(self, material, noise=0.0, ema=False):
    """Extract a mesh from the undeformed grid using only the sign of the SDF.

    ``noise`` is accepted for interface compatibility and unused;
    ``ema`` selects ``self.sdf_ema`` instead of ``self.sdf``.
    """
    # Run DM tet to get a base mesh
    sdf = self.sdf_ema if ema else self.sdf
    verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(
        self.verts, torch.sign(sdf), self.indices)
    imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

    # Run mesh operations to generate tangent space
    imesh = mesh.compute_tangents(mesh.auto_normals(imesh))
    return imesh
def getMesh_no_deform_gd(self, material, noise=0.0, ema=False):
    """Extract a mesh whose vertex positions carry no gradient w.r.t. the deformation.

    The deformation is detached (``get_deformed(no_grad=True)``).
    ``noise`` is unused; ``ema`` selects ``self.sdf_ema`` instead of
    ``self.sdf``.
    """
    # Run DM tet to get a base mesh
    v_deformed = self.get_deformed(no_grad=True)
    sdf = self.sdf_ema if ema else self.sdf
    verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices)
    imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

    # Run mesh operations to generate tangent space
    imesh = mesh.compute_tangents(mesh.auto_normals(imesh))
    return imesh
def get_deformed(self, no_grad=False, ema=False):
    """Return the grid vertices displaced by the learned deformation.

    Args:
        no_grad: detach the deformation so no gradient reaches it.
        ema: accepted for interface compatibility; not used here.

    Returns:
        ``verts + 2 / (2 * grid_res) * d * deform_scale`` where ``d`` is the
        deformation, passed through ``tanh`` when ``self.tanh`` is set.
    """
    offsets = self.deform.detach() if no_grad else self.deform
    if self.tanh:
        offsets = torch.tanh(offsets)
    cell_step = 2 / (self.grid_res * 2)
    return self.verts + cell_step * offsets * self.deform_scale
def get_angle(self):
    """Compute per-tetrahedron, per-face direction data on the deformed grid.

    For each of the four (face, opposite vertex) splits of every
    tetrahedron the method takes the position of the opposite vertex,
    removes its component along the normalized vector pointing from the
    face centroid to the face's first vertex, and records the result.

    Returns:
        directions: (num_tets, 4) float tensor, sign of
            dot(opposite vertex position, residual vector).
        dir_vec: (num_tets, 4, 3) float tensor holding the residual vectors.
        vert_inds: (num_tets, 4) long tensor with the opposite vertex ids.
    """
    with torch.no_grad():
        # (face vertex slots..., opposite vertex slot) for the four faces.
        comb_list = [
            (0, 1, 2, 3),
            (0, 1, 3, 2),
            (0, 2, 3, 1),
            (1, 2, 3, 0),
        ]
        num_tets = self.indices.size(0)
        # Allocate next to the tet indices instead of creating CPU tensors
        # and copying them to the current CUDA device.
        device = self.indices.device
        directions = torch.zeros(num_tets, 4, device=device)
        dir_vec = torch.zeros(num_tets, 4, 3, device=device)
        vert_inds = torch.zeros(num_tets, 4, dtype=torch.long, device=device)
        vpos_list = self.get_deformed()
        for count, comb in enumerate(comb_list):
            face = self.indices[:, comb[:3]]
            face_pos = vpos_list[face, :]
            face_center = face_pos.mean(1, keepdim=False)
            v = self.indices[:, comb[3]]
            test_vec = vpos_list[v]
            ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center)
            distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec
            directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0])
            dir_vec[:, count, :] = distance_vec
            vert_inds[:, count] = v
    return directions, dir_vec, vert_inds
def clamp_deform(self):
    """Clamp the optimizable tensors in place.

    ``deform`` is limited to [-0.99, 0.99] only when the tanh
    parameterization is disabled; ``sdf`` is always limited to [-1, 1].
    """
    if not self.tanh:
        bounded_deform = torch.clamp(self.deform.data, min=-0.99, max=0.99)
        self.deform.data[:] = bounded_deform
    bounded_sdf = torch.clamp(self.sdf.data, min=-1.0, max=1.0)
    self.sdf.data[:] = bounded_sdf
def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False):
    """Extract the current mesh and render it with the camera stored in ``target``.

    Args:
        glctx: rasterizer context forwarded to ``render.render_mesh``.
        target: dict providing 'mvp', 'campos', 'resolution', 'spp' and 'background'.
        lgt: light object forwarded to the renderer.
        opt_material: material attached to the extracted mesh.
        bsdf: optional BSDF override forwarded to the renderer.
        ema: extract the mesh from the EMA SDF instead of the live one.
        xfm_lgt: optional light transform forwarded to the renderer.
        get_visible_tets: when true, tetrahedron centers are passed along as ``tet_centers``.

    Returns:
        Whatever ``render.render_mesh`` returns.
    """
    opt_mesh = self.getMesh(opt_material, ema=ema)
    tet_centers = None
    if get_visible_tets:
        tet_centers = self.getTetCenters()
    return render.render_mesh(
        glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers,
    )
def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None):
    """Render the current mesh and also hand the mesh back to the caller.

    Same inputs as ``render`` (``noise`` and ``ema`` are forwarded to
    ``getMesh``).

    Returns:
        Tuple ``(mesh, render_output)`` where ``render_output`` is the
        value returned by ``render.render_mesh``.
    """
    opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema)
    rendered = render.render_mesh(
        glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt,
    )
    return opt_mesh, rendered
def update_ema(self, ema_coeff=0.9):
    """Refresh the "EMA" copies of the SDF and the deformation.

    The SDF copy is overwritten with ``self.sdf_buffer.avg()`` and the
    deformation copy with the current deformation values; no exponential
    blending takes place, so ``ema_coeff`` is unused.
    """
    averaged_sdf = self.sdf_buffer.avg()
    self.sdf_ema.data[:] = averaged_sdf
    current_deform = self.deform.data[:]
    self.deform_ema.data[:] = current_deform
def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None):
    """Render the mesh extracted from the EMA SDF (``getMesh(..., ema=True)``).

    Returns:
        Whatever ``render.render_mesh`` returns.
    """
    ema_mesh = self.getMesh(opt_material, ema=True)
    return render.render_mesh(
        glctx, ema_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
        spp=target['spp'], msaa=True, background=target['background'],
        bsdf=bsdf, xfm_lgt=xfm_lgt,
    )
def init_with_gt_surface(self, gt_verts, surface_faces, campos):
    """Set the SDF to 1.0 at grid vertices on the camera-facing side of a surface.

    Every grid vertex is matched to the nearest face center of the given
    surface. Face normals are flipped so that they point towards
    ``campos``; vertices whose offset from the matched face center has a
    positive component along that normal get ``sdf = 1.0``.

    Args:
        gt_verts: (V, 3) surface vertex positions.
        surface_faces: (F, 3) vertex indices of the surface triangles.
        campos: camera position, broadcastable against (F, 3).
    """
    with torch.no_grad():
        surface_face_verts = gt_verts[surface_faces]
        surface_centers = surface_face_verts.mean(dim=1)
        v_pos = self.get_deformed()
        results = pytorch3d.ops.knn_points(v_pos[None, ...], surface_centers[None, ...])
        nn_idx = results.idx
        displacement = (v_pos - surface_centers[nn_idx[0, :, 0]])
        view_dirs = campos - surface_centers
        # dim=-1 is explicit: without it torch.cross uses the first
        # dimension of size 3, which is the face axis when F == 3.
        normals = torch.cross(
            surface_face_verts[:, 0] - surface_face_verts[:, 1],
            surface_face_verts[:, 0] - surface_face_verts[:, 2],
            dim=-1,
        )
        # Orient every normal towards the camera.
        mask = ((normals * view_dirs).sum(dim=-1, keepdim=True) >= 0).float()
        normals = normals * mask - normals * (1 - mask)
        outside_verts_idx = ((displacement * normals[nn_idx[0, :, 0]]).sum(dim=-1) > 0)
        self.sdf.data[outside_verts_idx] = 1.0
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True):
# One optimization step: renders the current mesh against `target` and returns the
# pair (img_loss, reg_loss).
# `with_reg` is not read anywhere in the visible body.
# NOTE(review): this block appears without indentation and several structural lines
# (`else:` branches, a closing parenthesis) seem to be missing; the notes below flag
# those spots - confirm against the intended control flow.
#
# Deformation gradients are disabled while iteration < 100.
# NOTE(review): the second pair of assignments presumably belongs to a missing `else:`;
# deform_scale is 2.0 in both cases.
if iteration < 100:
self.deform.requires_grad = False
self.deform_scale = 2.0
self.deform.requires_grad = True
self.deform_scale = 2.0
# Every 20th iteration with 200 < iteration < 2000: project the grid vertices into each
# target view and, for vertices landing on pixels where mask_cont[..., 0] == 0, replace
# their SDF value by its absolute value clamped to [0, 1].
if iteration > 200 and iteration < 2000 and iteration % 20 == 0:
with torch.no_grad():
v_pos = self.get_deformed()
v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp'])
# Perspective divide by the last homogeneous component, keeping x and y only.
v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:]
# Map [-1, 1] to integer pixel indices; resolution[0] is used for both axes
# (assumes a square target resolution - TODO confirm).
v_pos_camera_discrete = ((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long()
target_mask = target['mask_cont'][:, :, :, 0] == 0
for k in range(target_mask.size(0)):
assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0]
# The mask is indexed [view, y, x].
v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0))
# print(v_mask.sum())
self.sdf.data[v_mask] = self.sdf.data[v_mask].abs().clamp(0.0, 1.0)
# ==============================================================================================
# Render optimizable object with identical conditions
# ==============================================================================================
imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt)
# ==============================================================================================
# Compute loss
# ==============================================================================================
# Image-space loss, split into a coverage component and a color component
color_ref = target['img']
# The zero initialisation below is overwritten by the next statement.
img_loss = torch.tensor(0.0).cuda()
# Coverage term: MSE over channels from index 3 onward.
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
# Color term: the first three channels of both images are weighted by the reference's channels 3:.
img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
mask = (target['mask_cont'][:, :, :, 0] == 1.0).float()
# mask_curr is not used again in the visible body.
mask_curr = (buffers['mask_cont'][:, :, :, 0] == 1.0).float()
# Every 300th iteration before 1790, shrink the deformation parameters in place by 0.4.
if iteration % 300 == 0 and iteration < 1790:
self.deform.data[:] *= 0.4
# NOTE(review): the parenthesised expression below is never closed and an `else:`
# appears to be missing before `valid_depth_mask = 1.0`; as written the mask would
# always end up as 1.0. Confirm the intended branch structure.
if no_depth_thin:
valid_depth_mask = (
(target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float()
valid_depth_mask = 1.0
# Depth term: absolute difference on the first depth channel inside the target mask;
# L1 where the difference is below 1.0, squared otherwise, mean weighted by 100.
depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask
l1_loss_mask = (depth_diff < 1.0).float()
img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0
# The zero initialisation below is overwritten by the SDF regularizer assignment.
reg_loss = torch.tensor(0.0).cuda()
# SDF regularizer
iter_thres = 0
# Weight falls linearly from FLAGS.sdf_regularizer to 0.01 over the first quarter of
# FLAGS.iter iterations and stays at 0.01 afterwards.
sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres)))
# SDF entries at imesh.valid_vert_idx are detached, so the regularizer sends no gradient to them.
sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device)
sdf_mask[imesh.valid_vert_idx] = 1.0
sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask)
reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 2.5 # Dropoff to 0.01
# Albedo (k_d) smoothnesss regularizer
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
# Visibility regularizer
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500)
# Chamfer term: 50000 points sampled on the rendered mesh are compared with
# target['spts']; the distance is added to reg_loss without an extra weight.
pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0]
target_pts = target['spts']
chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean()
reg_loss += chamfer
return img_loss, reg_loss

@ -0,0 +1,128 @@
import torch
def _base_sample_points_selected_faces(face_vertices, face_features=None):
"""Base function to sample points over selected faces.
The coordinates of the face vertices are interpolated to generate new samples.
face_vertices (tuple of torch.Tensor):
Coordinates of vertices, corresponding to selected faces to sample from.
A tuple of 3 entries corresponding to each of the face vertices.
Each entry is a torch.Tensor of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`.
face_features (tuple of torch.Tensor, Optional):
Features of face vertices, corresponding to selected faces to sample from.
A tuple of 3 entries corresponding to each of the face vertices.
Each entry is a torch.Tensor of shape
:math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`.
(torch.Tensor, torch.Tensor):
Sampled point coordinates of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`.
Sampled points interpolated features of shape
:math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`.
If `face_vertices_features` arg is not specified, the returned interpolated features are None.
face_vertices0, face_vertices1, face_vertices2 = face_vertices
sampling_shape = tuple(int(d) for d in face_vertices0.shape[:-1]) + (1,)
# u is proximity to middle point between v1 and v2 against v0.
# v is proximity to v2 against v1.
# The probability density for u should be f_U(u) = 2u.
# However, torch.rand use a uniform (f_X(x) = x) distribution,
# so using torch.sqrt we make a change of variable to have the desired density
# f_Y(y) = f_X(y ^ 2) * |d(y ^ 2) / dy| = 2y
u = torch.sqrt(torch.rand(sampling_shape,
v = torch.rand(sampling_shape,
w0 = 1 - u
w1 = u * (1 - v)
w2 = u * v
points = w0 * face_vertices0 + w1 * face_vertices1 + w2 * face_vertices2
features = None
if face_features is not None:
face_features0, face_features1, face_features2 = face_features
features = w0 * face_features0 + w1 * face_features1 + \
w2 * face_features2
return points, features
def sample_points(vertices, faces, num_samples, areas=None, face_features=None):
    r"""Uniformly sample points over the surface of triangle meshes.

    A face is first drawn at random with probability proportional to its area,
    then a location on that face is sampled uniformly.

    If ``face_features`` is given, interpolated features are returned for the
    sampled points as well; otherwise no feature interpolation takes place.

    Args:
        vertices (torch.Tensor):
            The vertices of the meshes, of shape
            :math:`(\text{batch_size}, \text{num_vertices}, 3)`.
        faces (torch.LongTensor):
            The faces of the mesh, of shape :math:`(\text{num_faces}, 3)`.
        num_samples (int):
            The number of points sampled per mesh.
        areas (torch.Tensor, optional):
            The areas of each face, of shape :math:`(\text{batch_size}, \text{num_faces})`.
            Can be precomputed for fast on-the-fly sampling;
            computed here if None (default).
        face_features (torch.Tensor, optional):
            Per-vertex-per-face features, matching ``faces`` order, of shape
            :math:`(\text{batch_size}, \text{num_faces}, 3, \text{feature_dim})`.
            For example, texture uv coordinates would have ``feature_dim`` 2 and
            RGB colors ``feature_dim`` 3.

    Returns:
        (torch.Tensor, torch.LongTensor, (optional) torch.Tensor):
            The pointclouds of shape :math:`(\text{batch_size}, \text{num_samples}, 3)`
            and the indexes of the selected faces, of shape
            :math:`(\text{batch_size}, \text{num_samples})`.
            When ``face_features`` is specified, the interpolated features of shape
            :math:`(\text{batch_size}, \text{num_samples}, \text{feature_dim})`
            are returned as a third element.
    """
    if faces.shape[-1] != 3:
        raise NotImplementedError("sample_points is only implemented for triangle meshes")

    # Positions of the three corners of every face, each (batch_size, num_faces, 3).
    corner_positions = tuple(
        torch.index_select(vertices, 1, faces[:, corner].reshape(-1))
        for corner in range(3))

    if areas is None:
        areas = _base_face_areas(*corner_positions).squeeze(-1)

    # Draw faces proportionally to their area: (batch_size, num_samples).
    face_choices = torch.distributions.Categorical(areas).sample([num_samples]).transpose(0, 1)

    # Corner positions of the chosen faces, each (batch_size, num_samples, 3).
    position_index = face_choices.unsqueeze(-1).repeat(1, 1, 3)
    face_vertices_choices = tuple(
        torch.gather(positions, 1, position_index) for positions in corner_positions)

    # When features are available, pick them for the chosen faces as well.
    face_features_choices = None
    if face_features is not None:
        feat_dim = face_features.shape[-1]
        feature_index = face_choices[..., None, None].repeat(1, 1, 3, feat_dim)
        chosen_features = torch.gather(face_features, 1, feature_index)
        # Split the corner axis into a tuple of (batch_size, num_samples, feat_dim).
        face_features_choices = tuple(chosen_features.unbind(dim=2))

    points, point_features = _base_sample_points_selected_faces(
        face_vertices_choices, face_features_choices)

    if point_features is None:
        return points, face_choices
    return points, face_choices, point_features

@ -0,0 +1,187 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
import nvdiffrast.torch as dr
from . import util
from . import renderutils as ru
import sys
# Utility functions
class cubemap_mip(torch.autograd.Function):
    """Autograd function that produces the next cubemap mip level by 2x2 average pooling.

    ``forward`` and ``backward`` are static methods, as ``torch.autograd.Function``
    requires; the function is used through ``cubemap_mip.apply``.
    """

    @staticmethod
    def forward(ctx, cubemap):
        return util.avg_pool_nhwc(cubemap, (2,2))

    @staticmethod
    def backward(ctx, dout):
        res = dout.shape[1] * 2
        out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
        # Texel-center coordinates in [-1, 1]; identical for all six faces, so built once.
        coords = torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda")
        gy, gx = torch.meshgrid(coords, coords, indexing='ij')
        for s in range(6):
            v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
            # Each fine texel receives a quarter of the gradient looked up in `dout`.
            out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
        return out
# Split-sum environment map light source with automatic mipmap generation
class EnvironmentLight(torch.nn.Module):
# Environment light stored as a cubemap in `self.base`, with `build_mips` deriving the
# `self.specular` mip chain and `self.diffuse` map that `shade` samples.
# NOTE(review): LIGHT_MIN_RES, MIN_ROUGHNESS and MAX_ROUGHNESS are read through `self`
# below but are not defined anywhere in the visible class body - confirm where they come from.
# NOTE(review): the block appears without indentation and some `else:` lines seem to be
# missing; see the notes inside `shade`.
def __init__(self, base, trainable=True):
# `base` is copied (clone + detach) into a Parameter whose requires_grad follows `trainable`.
super(EnvironmentLight, self).__init__()
self.mtx = None
self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=trainable)
print(f"light trainable or not: {trainable}")
if trainable:
# The same Parameter is additionally registered under the name 'env_base'.
self.register_parameter('env_base', self.base)
def xfm(self, mtx):
# Store a transform matrix for the light lookup.
self.mtx = mtx
def clone(self):
# Returns a new light built from a detached copy of `base`.
# The `trainable` flag is not forwarded, so the copy uses the default (True).
return EnvironmentLight(self.base.clone().detach())
def clamp_(self, min=None, max=None):
# In-place clamp of the base cubemap values.
self.base.clamp_(min, max)
def get_mip(self, roughness):
# Map roughness to a (fractional) mip level: values below MAX_ROUGHNESS are spread
# over levels [0, len(specular) - 2], values above it over the last level step.
return torch.where(roughness < self.MAX_ROUGHNESS
, (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
, (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)
def build_mips(self, cutoff=0.99):
# Build the mip chain by repeated cubemap_mip until the resolution is no larger than
# LIGHT_MIN_RES, compute the diffuse map from the last level, then filter each level
# with ru.specular_cubemap at a roughness that grows with the level index.
self.specular = [self.base]
while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
self.specular += [cubemap_mip.apply(self.specular[-1])]
self.diffuse = ru.diffuse_cubemap(self.specular[-1])
for idx in range(len(self.specular) - 1):
roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
# The last level is always filtered with roughness 1.0.
self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)
def regularizer(self):
# Mean absolute deviation of each channel from the per-texel channel average.
white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
return torch.mean(torch.abs(self.base - white))
def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True, xfm_lgt=None):
# Shade with the diffuse map (and, nominally, the specular chain); requires
# `build_mips` to have been called so that `self.diffuse` exists.
wo = util.safe_normalize(view_pos - gb_pos)
# NOTE(review): an `else:` appears to be missing before `diff_col = kd`; as written
# diff_col is always overwritten with kd. Confirm the intended branch.
if specular:
roughness = ks[..., 1:2] # y component
metallic = ks[..., 2:3] # z component
spec_col = (1.0 - metallic)*0.04 + kd * metallic
diff_col = kd * (1.0 - metallic)
diff_col = kd
reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
nrmvec = gb_normal
# A per-call transform `xfm_lgt` takes precedence over the stored `self.mtx`.
if xfm_lgt is not None:
# print(self.mtx.size())
mtx = torch.as_tensor(xfm_lgt, dtype=torch.float32, device='cuda')
reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
elif self.mtx is not None: # Rotate lookup
# The stored-matrix path raises; the statements after the raise cannot run.
raise NotImplementedError
# print(self.mtx.size())
mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
# if self.mtx is not None: # Rotate lookup
# # print(self.mtx.size())
# mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
# reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
# nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
# Diffuse lookup
diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
shaded_col = diffuse * diff_col
# NOTE(review): the specular path raises immediately, so with the default
# specular=True this method raises; the lookup code after the raise presumably
# sits in the same branch and is unreachable - confirm.
if specular:
raise NotImplementedError
# Lookup FG term from lookup texture
NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
fg_uv = torch.cat((NdotV, roughness), dim=-1)
# The lookup table is read from disk once and cached on the instance.
if not hasattr(self, '_FG_LUT'):
self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')
# Roughness adjusted specular env lookup
miplevel = self.get_mip(roughness)
spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')
# Compute aggregate lighting
reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
shaded_col += spec * reflectance
# The first ks channel must sum to exactly zero, so the modulation below is a no-op
# whenever the assert passes.
assert ks[..., 0:1].sum().item() == 0
return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility
# Load and store
# Load from latlong .HDR file
def _load_env_hdr(fn, scale=1.0, trainable=True):
    """Load a latlong image, convert it to a 512x512 cubemap and wrap it in an EnvironmentLight.

    Args:
        fn: path of the image, read with ``util.load_image``.
        scale: factor multiplied onto the loaded image.
        trainable: forwarded to ``EnvironmentLight``.

    Returns:
        The light with its mip chain already built, so it can be used for shading directly.
    """
    print("load env inner loop")
    latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
    print("get cubemap")
    cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
    print("get light object")
    light = EnvironmentLight(cubemap, trainable=trainable)
    print("build mips")
    # The progress messages bracket this step: without it the light has neither
    # `diffuse` nor `specular`, and `shade` fails on the first lookup.
    light.build_mips()
    print("build mips done")
    return light
def load_env(fn, scale=1.0, trainable=True):
    """Load an environment light from ``fn``; only ``.hdr`` files (any letter case) are accepted.

    Any other extension trips an assertion naming the extension.
    """
    extension = os.path.splitext(fn)[1]
    if extension.lower() != ".hdr":
        assert False, "Unknown envlight extension %s" % extension
    return _load_env_hdr(fn, scale, trainable=trainable)
def save_env_map(fn, light):
    """Write the base cubemap of ``light`` to ``fn`` as a 512x1024 latlong image."""
    is_env_light = isinstance(light, EnvironmentLight)
    assert is_env_light, "Can only save EnvironmentLight currently"
    # Kept as a guard for runs where assertions are disabled.
    if is_env_light:
        color = util.cubemap_to_latlong(light.base, [512, 1024])
    pixels = color.detach().cpu().numpy()
    util.save_image_raw(fn, pixels)
# Create trainable env map with random initialization
def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
    """Create an EnvironmentLight whose cubemap holds uniform random values in [bias, bias + scale)."""
    cubemap_shape = (6, base_res, base_res, 3)
    noise = torch.rand(*cubemap_shape, dtype=torch.float32, device='cuda')
    base = noise * scale + bias
    return EnvironmentLight(base)

@ -0,0 +1,199 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
from . import util
from . import texture
# Wrapper to make materials behave like a python dict, but register textures as
# torch.nn.Module parameters.
class Material(torch.nn.Module):
    """Dict-like material container whose entries are stored as module attributes.

    Entries are set with ``material[key] = value`` and stored through ``setattr``, so
    values that are ``torch.nn.Module``/``Parameter`` instances get registered with
    this module. ``keys()`` reports the names that were set through the mapping
    interface.
    """

    def __init__(self, mat_dict):
        super(Material, self).__init__()
        self.mat_keys = set()
        for key in mat_dict.keys():
            self[key] = mat_dict[key]

    def __contains__(self, key):
        return hasattr(self, key)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, val):
        setattr(self, key, val)
        # Track the key so that keys() reflects what has been stored.
        self.mat_keys.add(key)

    def __delitem__(self, key):
        # delattr first: a missing key keeps raising AttributeError as before.
        delattr(self, key)
        self.mat_keys.discard(key)

    def keys(self):
        return self.mat_keys
# .mtl material format loading / storing
def load_mtl(fn, clear_ks=True, avoid_pure_black=False):
# Parse the .mtl file `fn` into a list of Material objects and convert their
# 'kd'/'ks'/'normal' entries to textures. Texture paths are resolved relative to
# the directory of `fn`.
# NOTE(review): the block appears without indentation and several `else:` lines seem
# to be missing (before the generic tensor assignment and before each Texture2D
# fallback); confirm the intended branch structure before editing.
import re
mtl_path = os.path.dirname(fn)
# Read file
with open(fn, 'r') as f:
lines = f.readlines()
# Parse materials
materials = []
for line in lines:
# Split on runs of spaces, tabs or newlines; the first field (lower-cased) is the keyword.
split_line = re.split(' +|\t+|\n+', line.strip())
prefix = split_line[0].lower()
data = split_line[1:]
if 'newmtl' in prefix:
material = Material({'name' : data[0]})
materials += [material]
elif materials:
# String-valued keywords (shader name and texture file names) are stored verbatim.
if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
material[prefix] = data[0]
# With avoid_pure_black, an all-zero kd constant has its first two components
# replaced by 1.0 and 0.75.
# NOTE(review): `'kd' in prefix` also matches 'map_kd'; presumably this test lives
# in an `else`/`elif` of the check above - confirm.
if 'kd' in prefix and avoid_pure_black:
tmp_kd = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
if tmp_kd.sum() == 0.0:
tmp_kd[0] = 1.0
tmp_kd[1] = 0.75
material[prefix] = tmp_kd
# Numeric keywords are stored as float32 CUDA tensors.
material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
# Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
for mat in materials:
if not 'bsdf' in mat:
mat['bsdf'] = 'pbr'
# NOTE(review): each Texture2D(...) line below presumably belongs to a missing
# `else:` (constant value wrapped as a texture when no map file is given).
if 'map_kd' in mat:
mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
# mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']), channels=3)
mat['kd'] = texture.Texture2D(mat['kd'])
if 'map_ks' in mat:
mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
mat['ks'] = texture.Texture2D(mat['ks'])
if 'bump' in mat:
# Normal maps are remapped from [0, 1] to [-1, 1] on load.
mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
# Convert Kd from sRGB to linear RGB
mat['kd'] = texture.srgb_to_rgb(mat['kd'])
if clear_ks:
# Override ORM occlusion (red) channel by zeros. We hijack this channel
for mip in mat['ks'].getMips():
mip[..., 0] = 0.0
return materials
def save_mtl(fn, material):
# Write a single-material .mtl file named 'defaultMat' to `fn`. For each of 'kd', 'ks'
# and 'normal' listed by material.keys(), the texture is saved next to `fn` as
# texture_kd.png / texture_ks.png / texture_n.png and referenced from the file.
# NOTE(review): the block appears without indentation; the constant Kd/Ks/Ka/Tf/Ni/Ns
# lines presumably belong to a missing `else:` for `material is None` - confirm.
folder = os.path.dirname(fn)
with open(fn, "w") as f:
f.write('newmtl defaultMat\n')
if material is not None:
f.write('bsdf %s\n' % material['bsdf'])
if 'kd' in material.keys():
f.write('map_kd texture_kd.png\n')
# kd is converted from linear RGB to sRGB before saving.
texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))
if 'ks' in material.keys():
f.write('map_ks texture_ks.png\n')
texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])
if 'normal' in material.keys():
f.write('bump texture_n.png\n')
# Normals are normalized and remapped from [-1, 1] to [0, 1] for storage.
texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
f.write('Kd 1 1 1\n')
f.write('Ks 0 0 0\n')
f.write('Ka 0 0 0\n')
f.write('Tf 1 1 1\n')
f.write('Ni 1\n')
f.write('Ns 0\n')
# Merge multiple materials into a single uber-material
def _upscale_replicate(x, full_res):
x = x.permute(0, 3, 1, 2)
x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
return x.permute(0, 2, 3, 1).contiguous()
def merge_materials(materials, texcoords, tfaces, mfaces):
    """Merge several materials into one "uber" material that uses a single texture atlas.

    The textures of all materials are scaled to a common resolution, laid out side by
    side horizontally and padded to a power-of-two size. Texture coordinates are
    re-created per (texcoord index, material) pair so each face addresses its own
    region of the atlas.

    Args:
        materials: non-empty sequence of materials sharing the same BSDF; either all
            or none of them have a 'normal' map.
        texcoords: per-vertex (u, v) texture coordinates; may be empty.
        tfaces: per-face texcoord indices; re-indexed in place.
        mfaces: per-face material index.

    Returns:
        ``(uber_material, new_tverts_data, tfaces)`` where ``new_tverts_data`` is the
        list of new (u, v) pairs and ``tfaces`` is the re-indexed input.
    """
    assert len(materials) > 0
    for mat in materials:
        assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
        assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"

    uber_material = Material({
        'name' : 'uber_material',
        'bsdf' : materials[0]['bsdf'],
    })

    textures = ['kd', 'ks', 'normal']

    # Find maximum texture resolution across all materials and textures
    max_res = None
    for mat in materials:
        for tex in textures:
            tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
            max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res

    # Compute size of compound texture and round up to nearest PoT.
    # The builtin `int` replaces the `np.int` alias, which newer NumPy releases no longer provide.
    full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(int)

    # Normalize texture resolution across all materials & combine into a single large texture
    for tex in textures:
        if tex in materials[0]:
            tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
            tex_data = _upscale_replicate(tex_data, full_res)
            uber_material[tex] = texture.Texture2D(tex_data)

    # Compute scaling values for used / unused texture area
    s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]

    # Recompute texture coordinates to coincide with new composite texture
    has_texcoords = len(texcoords) != 0
    new_tverts = {}
    new_tverts_data = []
    for fi in range(len(tfaces)):
        matIdx = mfaces[fi]
        for vi in range(3):
            ti = tfaces[fi][vi]
            per_material = new_tverts.setdefault(ti, {})
            if matIdx not in per_material: # create new vertex
                # Offset texture coordinate (x direction) by material id & scale to local space.
                # Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here.
                if has_texcoords:
                    new_uv = [(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]
                else:
                    # No input coordinates: point at the start of the material's region.
                    new_uv = [(matIdx) / s_coeff[1], 0]
                new_tverts_data.append(new_uv)
                per_material[matIdx] = len(new_tverts_data) - 1
            tfaces[fi][vi] = per_material[matIdx] # reindex vertex

    return uber_material, new_tverts_data, tfaces

@ -0,0 +1,277 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
from . import obj
from . import util
# Base mesh class
class Mesh:
    """Container for triangle-mesh attributes (positions, normals, texcoords, tangents, material).

    Every ``v_*`` attribute has a matching ``t_*_idx`` index tensor. Attributes left as
    ``None`` are filled in from ``base`` when one is given. ``f_nrm`` holds per-face
    (unnormalized) normals: the supplied tensor if any, otherwise the cross product of
    the triangle edges when positions and indices are available, otherwise ``None``.
    """

    def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None,
                 material=None, base=None, f_nrm=None):
        self.v_pos = v_pos
        self.v_nrm = v_nrm
        self.v_tex = v_tex
        self.v_tng = v_tng
        self.t_pos_idx = t_pos_idx
        self.t_nrm_idx = t_nrm_idx
        self.t_tex_idx = t_tex_idx
        self.t_tng_idx = t_tng_idx
        self.material = material

        # Inherit everything that was not given explicitly from the base mesh.
        if base is not None:
            self.copy_none(base)

        if f_nrm is not None:
            self.f_nrm = f_nrm
        elif self.v_pos is not None and self.t_pos_idx is not None:
            v0 = self.v_pos[self.t_pos_idx[:, 0], :]
            v1 = self.v_pos[self.t_pos_idx[:, 1], :]
            v2 = self.v_pos[self.t_pos_idx[:, 2], :]
            # dim=-1 keeps the product on the xyz axis even for meshes with exactly 3 faces.
            self.f_nrm = torch.cross(v1 - v0, v2 - v0, dim=-1)
        else:
            self.f_nrm = None

    def copy_none(self, other):
        """Fill every attribute that is still ``None`` with the value from ``other``.

        ``f_nrm`` is intentionally not copied.
        """
        for attr in ('v_pos', 't_pos_idx', 'v_nrm', 't_nrm_idx', 'v_tex',
                     't_tex_idx', 'v_tng', 't_tng_idx', 'material'):
            if getattr(self, attr) is None:
                setattr(self, attr, getattr(other, attr))

    def clone(self):
        """Return a mesh whose tensors are detached copies; the material object is shared."""
        out = Mesh(base=self)
        for attr in ('v_pos', 't_pos_idx', 'v_nrm', 't_nrm_idx', 'v_tex',
                     't_tex_idx', 'v_tng', 't_tng_idx', 'f_nrm'):
            value = getattr(out, attr)
            if value is not None:
                setattr(out, attr, value.clone().detach())
        return out
# Mesh loading helper
def load_mesh(filename, mtl_override=None, mtl_default=None, use_default=False, no_additional=False):
    """Load a mesh from ``filename``; only the exact extension ``.obj`` is supported.

    Any other extension trips an assertion.
    """
    extension = os.path.splitext(filename)[1]
    if extension != ".obj":
        assert False, "Invalid mesh file extension"
    return obj.load_obj(
        filename,
        clear_ks=True,
        mtl_override=mtl_override,
        mtl_default=mtl_default,
        use_default=use_default,
        no_additional=no_additional,
    )
# Compute AABB
def aabb(mesh):
    """Return ``(min_corner, max_corner)`` of ``mesh.v_pos``, reduced over dim 0."""
    positions = mesh.v_pos
    lower = positions.min(dim=0).values
    upper = positions.max(dim=0).values
    return lower, upper
# Compute AABB with only used vertices
def aabb_clean(mesh):
    """Return ``(min_corner, max_corner)`` over the vertices referenced by ``mesh.t_pos_idx`` only."""
    referenced = mesh.t_pos_idx.unique()
    used_positions = mesh.v_pos[referenced]
    lower = used_positions.min(dim=0).values
    upper = used_positions.max(dim=0).values
    return lower, upper
# Compute unique edge list from attribute/vertex index list
def compute_edges(attr_idx, return_inverse=False):
    """Return the unique undirected edges of a triangle index list.

    Each returned edge stores the smaller index first. With ``return_inverse=True``
    the result of ``torch.unique`` is a tuple that also maps every directed triangle
    edge (three per triangle, in triangle order) to its row in the unique list.
    """
    with torch.no_grad():
        # Columns (0,1), (1,2), (2,0) of every triangle, packed triangle by triangle.
        directed = attr_idx[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2)

        # Put the smaller index first so that both directions of an edge coincide.
        first = directed[:, 0]
        second = directed[:, 1]
        undirected = torch.stack(
            (torch.minimum(first, second), torch.maximum(first, second)), dim=-1)

        # Drop duplicates (and optionally report the inverse mapping).
        return torch.unique(undirected, dim=0, return_inverse=return_inverse)
# Compute unique edge to face mapping from attribute/vertex index list
def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
    """Return, for every unique undirected edge, the indices of the triangles that use it.

    The result has one row per unique edge (ordered as ``torch.unique`` orders the
    sorted edges) and two columns: column 0 holds the triangle that traverses the edge
    from its smaller to its larger index, column 1 the triangle traversing it the other
    way. Slots that no triangle fills stay 0. ``return_inverse`` is accepted for
    signature parity with ``compute_edges`` but is not used.

    All tensors are created on ``attr_idx.device`` rather than unconditionally on CUDA.
    """
    with torch.no_grad():
        device = attr_idx.device

        # Create all edges, packed by triangle
        all_edges = torch.cat((
            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # Swap edge order so min index is always first
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order)
        ), dim=-1)

        # Eliminate duplicates and keep the inverse mapping
        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)

        # Triangle index of every directed edge (three consecutive edges per triangle).
        tris = torch.arange(attr_idx.shape[0], device=device).repeat_interleave(3)

        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64, device=device)

        # Compute edge to face table: the traversal direction selects the column.
        mask0 = order[:,0] == 0
        mask1 = order[:,0] == 1
        tris_per_edge[idx_map[mask0], 0] = tris[mask0]
        tris_per_edge[idx_map[mask1], 1] = tris[mask1]

        return tris_per_edge
# Align base mesh to reference mesh: move & rescale to match bounding boxes.
def unit_size(mesh):
    """Return a copy of ``mesh`` centered on the origin with its longest bounding-box side equal to 2."""
    with torch.no_grad():
        lower, upper = aabb(mesh)
        longest_side = torch.max(upper - lower).item()
        scale = 2 / longest_side
        centered = mesh.v_pos - (upper + lower) / 2
        rescaled = centered * scale
        return Mesh(rescaled, base=mesh)
# Center & scale mesh for rendering
def center_by_reference(base_mesh, ref_aabb, scale):
    """Recenter and rescale ``base_mesh`` using a reference bounding box.

    The box center is moved to the origin and positions are multiplied by
    ``scale`` divided by the longest side of the box. The normalization values
    are printed.
    """
    box_min, box_max = ref_aabb[0], ref_aabb[1]
    center = (box_min + box_max) * 0.5
    scale = scale / torch.max(box_max - box_min).item()
    print('normalization:', center, scale)
    shifted = base_mesh.v_pos - center[None, ...]
    return Mesh(shifted * scale, base=base_mesh)
# Simple smooth vertex normal computation
def auto_normals(imesh):
    """Compute smooth vertex normals by summing face normals at their vertices.

    Face normals are the (unnormalized) cross products of triangle edges, so larger
    faces contribute more. Vertices whose accumulated normal is degenerate get the
    default normal (0, 0, 1).

    Returns:
        A new Mesh based on ``imesh`` with ``v_nrm`` set, ``t_nrm_idx`` equal to
        ``imesh.t_pos_idx`` and ``f_nrm`` holding the face normals.
    """
    i0 = imesh.t_pos_idx[:, 0]
    i1 = imesh.t_pos_idx[:, 1]
    i2 = imesh.t_pos_idx[:, 2]

    v0 = imesh.v_pos[i0, :]
    v1 = imesh.v_pos[i1, :]
    v2 = imesh.v_pos[i2, :]

    # dim=-1 keeps the product on the xyz axis even for meshes with exactly 3 faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(imesh.v_pos)
    for corner in (i0, i1, i2):
        v_nrm.scatter_add_(0, corner[:, None].repeat(1,3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value.
    # The default lives on the same device/dtype as the accumulated normals.
    default_normal = torch.tensor([0.0, 0.0, 1.0], dtype=v_nrm.dtype, device=v_nrm.device)
    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, default_normal)
    v_nrm = util.safe_normalize(v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(v_nrm))

    return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh, f_nrm=face_normals)
# Compute tangent space from texture map coordinates
# Follows http://www.mikktspace.com/ conventions
def compute_tangents(imesh):
    """Compute per-vertex tangents from positions and texture coordinates.

    A tangent is computed per triangle, averaged onto the triangle's normal indices,
    normalized and made perpendicular to the vertex normal. Returns a new Mesh based
    on ``imesh`` with ``v_tng`` set and ``t_tng_idx`` equal to ``imesh.t_nrm_idx``.
    """
    corners = range(0, 3)
    pos = [imesh.v_pos[imesh.t_pos_idx[:, c]] for c in corners]
    tex = [imesh.v_tex[imesh.t_tex_idx[:, c]] for c in corners]
    vn_idx = [imesh.t_nrm_idx[:, c] for c in corners]

    tangents = torch.zeros_like(imesh.v_nrm)
    tansum = torch.zeros_like(imesh.v_nrm)

    # Compute tangent space for each triangle
    uve1, uve2 = tex[1] - tex[0], tex[2] - tex[0]
    pe1, pe2 = pos[1] - pos[0], pos[2] - pos[0]

    nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])
    denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])

    for edge in (uve1, uve2, pe1, pe2):
        assert not torch.isnan(edge).any()

    # Avoid division by zero for degenerated texture coordinates: push the
    # denominator at least 1e-6 away from zero while keeping its sign.
    safe_denom = torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))
    tang = nom / safe_denom
    assert (safe_denom != 0.0).all()
    assert not torch.isnan(nom).any()
    assert not torch.isnan(tang).any()

    # Update all 3 vertices: accumulate the tangent and a hit count per normal index
    ones = torch.ones_like(tang)
    for corner_idx in vn_idx:
        idx = corner_idx[:, None].repeat(1,3)
        tangents.scatter_add_(0, idx, tang)
        tansum.scatter_add_(0, idx, ones)
    tangents = tangents / tansum
    assert not torch.isnan(tangents).any()

    # Normalize and make sure tangent is perpendicular to normal
    tangents = util.safe_normalize(tangents)
    tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(tangents))

    return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)

@ -0,0 +1,104 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import tinycudann as tcnn
import numpy as np
# Small MLP using PyTorch primitives, internal helper class.
# Builds a stack of bias-free Linear layers with ReLU between them, sized from
# cfg['n_input_dims'], cfg['n_neurons'], cfg['n_hidden_layers'] and cfg['n_output_dims'],
# and moves the stack to CUDA.
# NOTE(review): this listing shows no indentation and `_init_weights` ends on a dangling
# `if`; some lines appear to be missing here — confirm against the complete file.
class _MLP(torch.nn.Module):
# cfg: dict with the four size keys listed above.
# loss_scale: factor applied to the network's input gradient in the backward pass (1.0 = no hook).
def __init__(self, cfg, loss_scale=1.0):
super(_MLP, self).__init__()
self.loss_scale = loss_scale
# Input layer: n_input_dims -> n_neurons, followed by ReLU.
net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
# Remaining hidden layers: (n_hidden_layers - 1) blocks of n_neurons -> n_neurons + ReLU.
for i in range(cfg['n_hidden_layers']-1):
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
# Output layer: n_neurons -> n_output_dims, no activation.
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
self.net = torch.nn.Sequential(*net).cuda()
# Only install the hook when it would actually change the gradient.
if self.loss_scale != 1.0:
# The hook replaces the module's input gradient with grad_i[0] * loss_scale.
self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))
# Run the layer stack on x after casting it to float32.
def forward(self, x):
return self.net(x.to(torch.float32))
# Per-module initializer: Kaiming-uniform (ReLU gain) weights for Linear layers.
# NOTE(review): no `self` parameter — presumably a @staticmethod whose decorator is not
# visible here; the body of the trailing `hasattr(m.bias, 'data')` check is not visible either.
def _init_weights(m):
if type(m) == torch.nn.Linear:
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
if hasattr(m.bias, 'data'):
# Outward visible MLP class.
# Combines a tiny-cuda-nn hash-grid encoding of 3D positions with the _MLP helper to
# produce `channels` values per queried position.
# NOTE(review): this listing shows no indentation; the closing braces of `enc_cfg` /
# `mlp_cfg` and the bodies of `clamp_` / `cleanup` are not visible — confirm against the full file.
class MLPTexture3D(torch.nn.Module):
# AABB: indexable pair (AABB[0] = min corner, AABB[1] = max corner) used to normalize query positions.
# channels: number of output values per sample.
# internal_dims: width of the MLP hidden layers.
# hidden: value passed as the MLP's n_hidden_layers.
# min_max: optional pair (min_max[0], min_max[1]) bounding the output range; None disables the sigmoid remap.
def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None):
super(MLPTexture3D, self).__init__()
self.channels = channels
self.internal_dims = internal_dims
self.AABB = AABB
self.min_max = min_max
# Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details
desired_resolution = 4096
base_grid_resolution = 16
num_levels = 16
# Geometric growth factor: base_grid_resolution * per_level_scale**(num_levels-1) == desired_resolution.
per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))
enc_cfg = {
"otype": "HashGrid",
"n_levels": num_levels,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": base_grid_resolution,
"per_level_scale" : per_level_scale
# The encoder's input gradient is divided by this factor while the MLP below is given the
# same factor as its loss_scale (which multiplies its input gradient).
gradient_scaling = 128.0
self.encoder = tcnn.Encoding(3, enc_cfg)
self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))
# Setup MLP
mlp_cfg = {
"n_input_dims" : self.encoder.n_output_dims,
"n_output_dims" : self.channels,
"n_hidden_layers" : hidden,
"n_neurons" : self.internal_dims
self.net = _MLP(mlp_cfg, gradient_scaling)
print("Encoder output: %d dims" % (self.encoder.n_output_dims))
# Sample texture at a given location.
# texc: tensor whose last dimension is 3 (it is flattened with view(-1, 3)).
# Returns a tensor shaped texc.shape[:-1] + (channels,).
def sample(self, texc):
# Map positions into the AABB's normalized coordinates, then clamp to [0, 1].
_texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])
_texc = torch.clamp(_texc, min=0, max=1)
p_enc = self.encoder(_texc.contiguous())
out = self.net.forward(p_enc)
# Sigmoid limit and scale to the allowed range
if self.min_max is not None:
out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]
# In-place clamp with no derivative to make sure values are in valid range after training
# NOTE(review): body not visible in this listing.
def clamp_(self):
# NOTE(review): body not visible in this listing.
def cleanup(self):

@ -0,0 +1,216 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import torch
from . import texture
from . import mesh
from . import material
# Utility functions
def _find_mat(materials, name):
for mat in materials:
if mat['name'] == name:
return mat
return materials[0] # Materials 0 is the default
# Create mesh object from objfile.
# Reads `filename` as text, assembles a material list, parses the `v` / `vt` / `vn` /
# `usemtl` / `f` records and returns mesh.auto_normals() applied to a Mesh built from the
# parsed vertex positions and triangle indices, with all_materials[0] as its material.
#   filename       path of the .obj file; its directory is used to resolve `mtllib` entries
#   clear_ks       forwarded to material.load_mtl for `mtllib` materials
#   mtl_override   if given, this .mtl path is loaded instead of the file's `mtllib` entries
#   mtl_default    if given, used as the first (default) material instead of the built-in one
#   use_default    not referenced in the visible code
#   no_additional  when true, no materials beyond the default are loaded
# NOTE(review): this listing carries no indentation and several structural lines (`else:`,
# `continue`, the braces of the default-material dict, possibly `try`/`except` and the
# material bookkeeping appends) do not appear in it, so the branch structure described in
# the comments below is inferred — confirm against the complete file.
def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=None, use_default=False, no_additional=False):
obj_path = os.path.dirname(filename)
# Read entire file
with open(filename, 'r') as f:
lines = f.readlines()
# Load materials
# Built-in default: a 'pbr' material with constant grey kd and zero ks, created on CUDA.
if mtl_default is None:
all_materials = [
'name' : '_default_mat',
'bsdf' : 'pbr',
'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
# Presumably the `else` side of the mtl_default check: the caller's material becomes entry 0.
print("Load use-defined default mtl")
all_materials = [mtl_default]
if not no_additional:
if mtl_override is None:
for line in lines:
# Blank lines have no tokens (presumably skipped with a `continue` not shown here).
if len(line.split()) == 0:
if line.split()[0] == 'mtllib':
all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks, avoid_pure_black=True) # Read in entire material library
# Presumably the `else` side of the mtl_override check.
all_materials += material.load_mtl(mtl_override)
# Presumably the `else` side of the no_additional check.
print("Skip loading non-default materials")
# load vertices
vertices, texcoords, normals = [], [], []
for line in lines:
if len(line.split()) == 0:
prefix = line.split()[0].lower()
if prefix == 'v':
vertices.append([float(v) for v in line.split()[1:]])
elif prefix == 'vt':
val = [float(v) for v in line.split()[1:]]
# The second texture coordinate is stored flipped (1 - v).
texcoords.append([val[0], 1.0 - val[1]])
elif prefix == 'vn':
normals.append([float(v) for v in line.split()[1:]])
# load faces
activeMatIdx = None
used_materials = []
faces, tfaces, nfaces, mfaces = [], [], [], []
for line in lines:
if len(line.split()) == 0:
prefix = line.split()[0].lower()
if prefix == 'usemtl': # Track used materials
mat = _find_mat(all_materials, line.split()[1])
# NOTE(review): no statement adding `mat` to used_materials is visible before the
# index() lookup below — confirm it exists in the full file.
if not mat in used_materials:
activeMatIdx = used_materials.index(mat)
elif prefix == 'f': # Parse face
vs = line.split()[1:]
nv = len(vs)
# OBJ indices are 1-based; subtract 1 for 0-based indexing. The first vertex is the fan pivot.
vv = vs[0].split('/')
v0 = int(vv[0]) - 1
# t1 = int(vv[1]) - 1 if vv[1] != "" else -1
# n1 = int(vv[2]) - 1 if vv[2] != "" else -1
# NOTE(review): the texture/normal indices are parsed and then set to -1; presumably the
# -1 assignment is a fallback branch (e.g. `except`) that is not visible here. The same
# pattern repeats for the other two corners below.
t0 = int(vv[1]) - 1 if vv[1] != "" else -1
n0 = int(vv[2]) - 1 if vv[2] != "" else -1
t0 = n0 = -1
for i in range(nv - 2): # Triangulate polygons
vv = vs[i + 1].split('/')
v1 = int(vv[0]) - 1
# t1 = int(vv[1]) - 1 if vv[1] != "" else -1
# n1 = int(vv[2]) - 1 if vv[2] != "" else -1
t1 = int(vv[1]) - 1 if vv[1] != "" else -1
n1 = int(vv[2]) - 1 if vv[2] != "" else -1
t1 = n1 = -1
vv = vs[i + 2].split('/')
v2 = int(vv[0]) - 1
# t2 = int(vv[1]) - 1 if vv[1] != "" else -1
# n2 = int(vv[2]) - 1 if vv[2] != "" else -1
t2 = int(vv[1]) - 1 if vv[1] != "" else -1
n2 = int(vv[2]) - 1 if vv[2] != "" else -1
t2 = n2 = -1
# One triangle (v0, v[i+1], v[i+2]) per step of the fan.
faces.append([v0, v1, v2])
tfaces.append([t0, t1, t2])
nfaces.append([n0, n1, n2])
assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
# # Create an "uber" material by combining all textures into a larger texture
# # if len(used_materials) > 1:
# if True:
# uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
# elif len(used_materials) == 1:
# uber_material = used_materials[0]
# else:
# uber_material = None
vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
# texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
# normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
# # normals = None
faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
# tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
# nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
# print(uber_material)
# Only positions and triangle indices are kept: the parsed texcoords / normals and their
# index lists are discarded, and the default material (entry 0) is always used.
uber_material = all_materials[0]
texcoords = normals = tfaces = nfaces = None
# return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
imesh = mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
# Normals are produced by mesh.auto_normals rather than taken from the file.
imesh = mesh.auto_normals(imesh)
return imesh
# Save mesh object to objfile.
# Writes `<folder>/mesh.obj` containing the mesh's vertex positions and triangle indices
# and, when save_material is true, calls material.save_mtl for `<folder>/mesh.mtl`.
#   folder         output directory (joined with the fixed file names)
#   mesh           object providing v_pos, t_pos_idx and material
#   save_material  whether to also write the material file
# NOTE(review): this listing carries no indentation; which statements sit inside the
# `with open(...)` block is inferred from their use of `f`.
def write_obj(folder, mesh, save_material=True):
obj_file = os.path.join(folder, 'mesh.obj')
print("Writing mesh: ", obj_file)
with open(obj_file, "w") as f:
# NOTE(review): the `mtllib` line is disabled while `usemtl defaultMat` is still written below.
# f.write("mtllib mesh.mtl\n")
f.write("g default\n")
v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
# v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
# v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None
# Normals and texture coordinates are deliberately not exported.
v_nrm = None
v_tex = None
t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
# t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
# t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None
print(" writing %d vertices" % len(v_pos))
for v in v_pos:
f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
# if v_tex is not None:
# print(" writing %d texcoords" % len(v_tex))
# assert(len(t_pos_idx) == len(t_tex_idx))
# for v in v_tex:
# f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
# if v_nrm is not None:
# print(" writing %d normals" % len(v_nrm))
# assert(len(t_pos_idx) == len(t_nrm_idx))
# for v in v_nrm:
# f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
# faces
f.write("s 1 \n")
f.write("g pMesh1\n")
f.write("usemtl defaultMat\n")
# Write faces
print(" writing %d faces" % len(t_pos_idx))
for i in range(len(t_pos_idx)):
f.write("f ")
for j in range(3):
# Indices are written 1-based. t_tex_idx / t_nrm_idx are not defined in this function, but
# they are never evaluated because v_tex and v_nrm are None, leaving those slots empty.
f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
# NOTE(review): no newline write after each face is visible in this listing — confirm one
# exists in the full file, otherwise all faces end up on a single line.
if save_material:
mtl_file = os.path.join(folder, 'mesh.mtl')
print("Writing material: ", mtl_file)
material.save_mtl(mtl_file, mesh.material)
print("Done exporting mesh")

@ -0,0 +1,205 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import nvdiffrast.torch as dr
import pytorch3d.ops
from . import util
from . import mesh
# Computes the image gradient, useful for kd/ks smoothness losses.
# Samples `buf` at randomly jittered pixel coordinates and returns the absolute difference
# between the jittered sample and `buf` for all but the last channel, multiplied by the
# last channel of both the sample and `buf`.
#   buf  tensor indexed as buf.shape[0..2] plus a trailing channel axis (presumably NHWC — confirm)
#   std  standard deviation of the normal jitter added to the sampling coordinates
# NOTE(review): the `torch.meshgrid(` call below is not closed in this listing; its final
# argument line appears to be missing — confirm against the complete file.
def image_grad(buf, std=0.01):
# Coordinate grids spanning (-1 + 1/N, 1 - 1/N) with N samples along buf.shape[1] and buf.shape[2].
t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
# Jittered lookup coordinates: per-pixel normal noise added to the stacked (s, t) grid.
tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]
# Computes the average edge length of a mesh.
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
def avg_edge_length(v_pos, t_pos_idx):
    """Return the mean length over the edges reported by mesh.compute_edges.

    v_pos holds the vertex positions and t_pos_idx the triangle vertex
    indices; each edge is a pair of row indices into v_pos.
    """
    edges = mesh.compute_edges(t_pos_idx)
    edge_start = v_pos[edges[:, 0]]
    edge_end = v_pos[edges[:, 1]]
    return torch.mean(util.length(edge_start - edge_end))
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
# https://mgarland.org/class/geom04/material/smoothing.pdf
def laplace_regularizer_const(v_pos, t_pos_idx):
    """Mean squared umbrella-operator Laplacian of the vertex positions.

    For every triangle corner, the two edge vectors leaving that corner are
    accumulated on the corner's vertex together with a weight of 2. The
    accumulated vector is divided by the weight (clamped to at least 1, so
    vertices that belong to no triangle stay at zero) and the mean of the
    squared components is returned as a scalar tensor.
    """
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    umbrella = torch.zeros_like(v_pos)
    weight_sum = torch.zeros_like(v_pos[..., 0:1])

    # For corner k, `first` and `second` are the two other corners of the triangle.
    for k, (first, second) in enumerate(((1, 2), (0, 2), (0, 1))):
        here = corners[k]
        target = t_pos_idx[:, k:k + 1].repeat(1, 3)
        umbrella.scatter_add_(0, target, (corners[first] - here) + (corners[second] - here))

    # Each incident triangle contributes two edge vectors to a vertex.
    per_corner_weight = torch.ones_like(corners[0]) * 2.0
    for k in range(3):
        weight_sum.scatter_add_(0, t_pos_idx[:, k:k + 1], per_corner_weight)

    umbrella = umbrella / torch.clamp(weight_sum, min=1.0)
    return torch.mean(umbrella ** 2)
def scale_dependent_relative_laplace_regularizer_const(v_pos, v_pos_abs, t_pos_idx):
    """Edge-length-normalized Laplacian penalty with lengths taken from a second vertex set.

    For every triangle corner, the two edge vectors of ``v_pos`` leaving that
    corner are divided by the length of the same edge measured in
    ``v_pos_abs`` and accumulated on the corner's vertex. The mean of the
    squared accumulated components is returned as a scalar tensor.

    Args:
        v_pos: (V, 3) vertex positions whose edge vectors are penalized.
        v_pos_abs: (V, 3) vertex positions used only to measure edge lengths.
        t_pos_idx: (F, 3) integer triangle vertex indices.
    """
    # Keeps the square root (and its gradient) finite for zero-length edges.
    eps = 1e-8

    i0, i1, i2 = t_pos_idx[:, 0], t_pos_idx[:, 1], t_pos_idx[:, 2]
    v0, v1, v2 = v_pos[i0, :], v_pos[i1, :], v_pos[i2, :]
    v0_abs, v1_abs, v2_abs = v_pos_abs[i0, :], v_pos_abs[i1, :], v_pos_abs[i2, :]

    def _edge_length(p, q):
        return ((p - q).pow(2).sum(-1, keepdim=True) + eps).sqrt()

    v01_dist = _edge_length(v0_abs, v1_abs)
    v12_dist = _edge_length(v1_abs, v2_abs)
    v20_dist = _edge_length(v2_abs, v0_abs)

    term = torch.zeros_like(v_pos)
    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)

    return torch.mean(term ** 2)
def scale_dependent_laplace_regularizer_const(v_pos, t_pos_idx, stopgd=True):
    """Edge-length-normalized Laplacian penalty.

    For every triangle corner, the two edge vectors leaving that corner are
    divided by their own length and accumulated on the corner's vertex. The
    mean of the squared accumulated components is returned as a scalar tensor.

    Args:
        v_pos: (V, 3) vertex positions.
        t_pos_idx: (F, 3) integer triangle vertex indices.
        stopgd: when true (the default, matching the previous hard-coded
            behavior) the edge lengths are detached, so gradients flow only
            through the edge vectors and not through the normalization.
    """
    # Keeps the square root (and its gradient) finite for zero-length edges.
    eps = 1e-8

    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    def _edge_length(p, q):
        length = ((p - q).pow(2).sum(-1, keepdim=True) + eps).sqrt()
        return length.detach() if stopgd else length

    v01_dist = _edge_length(v0, v1)
    v12_dist = _edge_length(v1, v2)
    v20_dist = _edge_length(v2, v0)

    term = torch.zeros_like(v_pos)
    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1, 3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist)
    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1, 3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist)
    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1, 3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist)

    return torch.mean(term ** 2)
def mesh_repulsion(v_pos, t_pos_idx):
    """Accumulate triangle edge lengths per vertex and return their squares.

    For each triangle, the length of edge (0,1) is added to vertex 0, the
    length of edge (1,2) to vertex 1 and the length of edge (2,0) to vertex 2.
    The scatter index has a single column, so the sums land in column 0 of
    the (V, 3) accumulator; the other columns stay zero. The element-wise
    square of the accumulator is returned without reduction.
    """
    eps = 1e-8
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    accum = torch.zeros_like(v_pos)
    for k in range(3):
        edge = corners[k] - corners[(k + 1) % 3]
        edge_len = (edge.pow(2).sum(-1, keepdim=True) + eps).sqrt()
        accum.scatter_add_(0, t_pos_idx[:, k:k + 1], edge_len)

    return accum ** 2
def laplace_regularizer_const_adaptive(v_pos, t_pos_idx):
    """Umbrella-operator Laplacian penalty weighted by nearest-neighbor spacing.

    The per-vertex Laplacian is built as in laplace_regularizer_const: edge
    vectors accumulated per corner, divided by a weight of 2 per incident
    triangle (clamped to at least 1). Each vertex's mean squared Laplacian
    component is then multiplied by a gradient-free weight derived from the
    K=2 nearest-neighbor query on the positions scaled by 64, and the mean
    over vertices is returned.
    """
    corners = [v_pos[t_pos_idx[:, k], :] for k in range(3)]

    umbrella = torch.zeros_like(v_pos)
    weight_sum = torch.zeros_like(v_pos[..., 0:1])

    # For corner k, `first` and `second` are the two other corners of the triangle.
    for k, (first, second) in enumerate(((1, 2), (0, 2), (0, 1))):
        here = corners[k]
        target = t_pos_idx[:, k:k + 1].repeat(1, 3)
        umbrella.scatter_add_(0, target, (corners[first] - here) + (corners[second] - here))

    per_corner_weight = torch.ones_like(corners[0]) * 2.0
    for k in range(3):
        weight_sum.scatter_add_(0, t_pos_idx[:, k:k + 1], per_corner_weight)

    umbrella = umbrella / torch.clamp(weight_sum, min=1.0)

    # Add a batch dimension and rescale the point cloud for the neighbor query.
    scaled_pts = v_pos.unsqueeze(0) * 64
    with torch.no_grad():
        # K=2 because dist(self, self)=0; the last column is the nearest other point.
        neighbors = pytorch3d.ops.knn.knn_points(scaled_pts, scaled_pts, K=2)
        scale = neighbors.dists[0, :, -1].detach().sqrt().pow(1.5)

    # Per-vertex mean over components, since the vanilla regularizer uses a mean.
    per_vertex = umbrella.pow(2).mean(-1)
    return torch.mean(per_vertex * scale)
# def laplace_regularizer_const_sec_order(v_pos, t_pos_idx):
# term = torch.zeros_like(v_pos)
# norm = torch.zeros_like(v_pos[..., 0:1])
# v0 = v_pos[t_pos_idx[:, 0], :]
# v1 = v_pos[t_pos_idx[:, 1], :]
# v2 = v_pos[t_pos_idx[:, 2], :]
# term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
# term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
# term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
# two = torch.ones_like(v0) * 2.0
# norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
# norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
# norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
# term = term / torch.clamp(norm, min=1.0)
# return torch.mean(term**2)
# Smooth vertex normals
def normal_consistency(v_pos, t_pos_idx):
    """Penalize differing normals on faces that share an edge.

    Face normals are computed from the triangle edge vectors. For each pair
    of faces returned by mesh.compute_edge_to_face_mapping, the clamped dot
    product d of the two normals is mapped to (1 - d) / 2, and the mean over
    all pairs is returned as a scalar tensor.

    Args:
        v_pos: (V, 3) vertex positions.
        t_pos_idx: (F, 3) integer triangle vertex indices.
    """
    # Compute face normals
    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    # dim=-1 is explicit: without it torch.cross uses the first axis of size 3,
    # which is the face axis (not xyz) for a mesh with exactly three faces.
    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

    tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)

    # Fetch normals for both faces sharing an edge
    n0 = face_normals[tris_per_edge[:, 0], :]
    n1 = face_normals[tris_per_edge[:, 1], :]

    # Compute error metric based on normal difference
    term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
    term = (1.0 - term) * 0.5

    return torch.mean(torch.abs(term))

@ -0,0 +1,454 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import nvdiffrast.torch as dr
from . import util
from . import renderutils as ru
from . import light
# ==============================================================================================
# Helper functions
# ==============================================================================================
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Thin wrapper around dr.interpolate.

    The attribute tensor is made contiguous first. Attribute derivatives are
    requested for all attributes only when rasterizer derivatives (rast_db)
    are supplied.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
# ==============================================================================================
# pixel shader
# ==============================================================================================
# Shades per-pixel geometry buffers with a material and returns a dict of output buffers,
# each concatenated with the alpha channel (plus 'alpha' itself).
# NOTE(review): the parameter list is not visible in this listing; the body reads gb_pos,
# gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv, view_pos, lgt, material,
# bsdf and xfm_lgt. The listing also carries no indentation and `else:` lines do not appear,
# so the branch structure noted below is inferred — confirm against the complete file.
def shade(
# Texture lookups
perturbed_nrm = None
alpha_mtl = None
if 'kd_ks_normal' in material:
# Combined texture, used for MLPs because lookups are expensive
# Sampled twice: at gb_pos plus normal noise (std 0.01), and at gb_pos itself.
all_tex_jitter = material['kd_ks_normal'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda"))
all_tex = material['kd_ks_normal'].sample(gb_pos)
assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
# Channel split: the last 3 are the perturbed normal, the 3 before are ks, the rest is kd.
kd, ks, perturbed_nrm = all_tex[..., :-6], all_tex[..., -6:-3], all_tex[..., -3:]
# Compute albedo (kd) gradient, used for material regularizer
kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) / 3
# Presumably the start of an alternative branch that samples separate 'kd' / 'ks' textures
# with texture coordinates.
kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv)
if 'alpha' in material:
# As listed, the two sample() calls below follow the raise directly.
raise NotImplementedError
alpha_mtl = material['alpha'].sample(gb_texc, gb_texc_deriv)
alpha_mtl = material['alpha'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda"))
kd = material['kd'].sample(gb_texc, gb_texc_deriv)
ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
# Presumably a further branch that expands the first entry of the textures' data over the
# pixel grid; kd_jitter is the same tensor as kd here, so kd_grad evaluates to zero.
kd_jitter = kd = material['kd'].data[0].expand(*gb_pos.size())
ks = material['ks'].data[0].expand(*gb_pos.size())[..., 0:3] # skip alpha
kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
# Separate kd into alpha and color, default alpha = 1
alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
# A material-provided alpha overrides the one taken from kd.
if alpha_mtl is not None:
alpha = alpha_mtl
kd = kd[..., 0:3]
# Normal perturbation & normal bend
if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']:
perturbed_nrm = None
# Without tangents, prepare_shading_normal is asked for its use_python path.
use_python = (gb_tangent is None)
gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python)
# Same preparation applied to the geometric normal, with no perturbation.
gb_geo_normal_corrected = ru.prepare_shading_normal(gb_pos, view_pos, None, gb_geometric_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python)
# Evaluate BSDF
assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
# An explicit bsdf argument takes precedence over the material's own setting.
bsdf = material['bsdf'] if bsdf is None else bsdf
if bsdf == 'pbr':
# do not use pbr
raise NotImplementedError
if isinstance(lgt, light.EnvironmentLight):
shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
assert False, "Invalid light type"
elif bsdf == 'diffuse':
if isinstance(lgt, light.EnvironmentLight):
# Diffuse shading uses the corrected geometric normal rather than gb_normal.
shaded_col = lgt.shade(gb_pos, gb_geo_normal_corrected, kd, ks, view_pos, specular=False, xfm_lgt=xfm_lgt)
assert False, "Invalid light type"
# Debug outputs: vectors remapped from [-1, 1] to [0, 1], or the raw material channels.
elif bsdf == 'normal':
shaded_col = (gb_normal + 1.0)*0.5
elif bsdf == 'tangent':
shaded_col = (gb_tangent + 1.0)*0.5
elif bsdf == 'kd':
shaded_col = kd
elif bsdf == 'ks':
shaded_col = ks
assert False, "Invalid BSDF '%s'" % bsdf
# NaN checks on the shaded color and alpha.
# NOTE(review): the statements guarded by these checks are not visible in this listing.
nan_mask = torch.isnan(shaded_col)
if nan_mask.any():
if alpha is not None:
nan_mask = torch.isnan(alpha)
if nan_mask.any():
# Return multiple buffers
buffers = {
'shaded' : torch.cat((shaded_col, alpha), dim=-1),
'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1),
'normal' : torch.cat((gb_normal, alpha), dim=-1),
# Depth is the Euclidean distance from view_pos to the shaded position.
'depth' : torch.cat(((gb_pos - view_pos).pow(2).sum(dim=-1, keepdim=True).sqrt(), alpha), dim=-1),
'pos' : torch.cat((gb_pos, alpha), dim=-1),
'geo_normal': torch.cat((gb_geo_normal_corrected, alpha), dim=-1),
'geo_viewdir': torch.cat((view_pos - gb_pos, alpha), dim=-1),
'alpha' : alpha
return buffers
# ==============================================================================================
# Render a depth slice of the mesh (scene), some limitations:
# - Single mesh
# - Single light
# - Single material
# ==============================================================================================
# Shades one rasterized layer: interpolates position, normals and (optionally) tangents
# from the mesh, calls shade(), adds foreground masks and returns the buffer dict.
# NOTE(review): most of the parameter list is not visible in this listing (only xfm_lgt and
# flat_shading are); the body reads rast, mesh, view_pos, lgt, resolution, spp, msaa and bsdf.
# The listing carries no indentation and `else:` lines do not appear, so the branch
# structure noted below is inferred — confirm against the complete file.
def render_layer(
xfm_lgt = None,
flat_shading = False
full_res = [resolution[0]*spp, resolution[1]*spp]
# Rasterize
# Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
if spp > 1 and msaa:
rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
# Presumably the `else` side: shade directly on the full-resolution rasterizer output.
rast_out_s = rast
# Interpolate attributes
# Interpolate world space position
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
# Compute geometric normals. We need those because of bent normals trick (for bump mapping)
v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
# NOTE(review): torch.cross is called without `dim`; confirm the installed PyTorch resolves
# this to the last axis even when the mesh has exactly three faces.
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
# Every corner of face i points at normal i, so interpolation yields the face normal.
face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
if flat_shading:
# Per-pixel lookup of mesh.f_nrm by the triangle id stored in the last rasterizer channel.
gb_normal = mesh.f_nrm[rast_out_s[:, :, :, -1].long() - 1] # empty triangle get id=0; the first idx starts from 1
gb_normal[rast_out_s[:, :, :, -1].long() == 0] = 0
# Presumably the `else` side: smooth shading with interpolated vertex normals.
assert mesh.v_nrm is not None
gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
if mesh.v_tng is not None:
gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
# Presumably the `else` side: no tangents available.
gb_tangent = None
# Do not use texture coordinate in our case
gb_texc, gb_texc_deriv = None, None
# Shade
buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
view_pos, lgt, mesh.material, bsdf, xfm_lgt=xfm_lgt)
#### get a mask on mesh (used to identify foreground)
# Interpolating a constant 1 gives a positive value wherever a triangle covers the pixel.
mask_cont, _ = interpolate(torch.ones_like(mesh.v_pos[None, :, :1], device=mesh.v_pos.device), rast_out_s, mesh.t_pos_idx.int())
mask = (mask_cont > 0).float()
buffers['mask'] = mask
buffers['mask_cont'] = mask_cont
# Prepare output
# Scale back up to visibility resolution if using MSAA
if spp > 1 and msaa:
for key in buffers.keys():
if buffers[key] is not None:
buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')
# Return buffers
return buffers
# ==============================================================================================
# Render a depth peeled mesh (scene), some limitations:
# - Single mesh
# - Single light
# - Single material
# ==============================================================================================
# Renders a mesh with two depth-peeling layers and returns a dict of composited buffers:
# the buffers produced by render_layer for the first layer, 'shaded_second' /
# 'depth_second' / 'normal_second', the rasterized triangle ids and z-buffer of the first
# layer and, when tet_centers is given, the ids of tetrahedra judged visible.
# NOTE(review): the leading parameters are not visible in this listing; the body reads ctx,
# mesh, mtx_in, view_pos, lgt and resolution in addition to the keyword parameters below.
# The listing carries no indentation and several lines (`else:`, arguments of the pooling
# calls, string delimiters) do not appear, so the structure noted in the comments is
# inferred — confirm against the complete file.
def render_mesh(
spp = 1,
num_layers = 1,
msaa = False,
background = None,
bsdf = None,
xfm_lgt = None,
tet_centers = None,
flat_shading = False
# Convert to a float32 CUDA tensor if needed; 2-D inputs get two singleton dims inserted.
def prepare_input_vector(x):
x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
return x[:, None, None, :] if len(x.shape) == 2 else x
# Blend the given buffer key over the background using the buffer's last channel as alpha
# (masked by rasterizer coverage), optionally antialiased.
def composite_buffer(key, layers, background, antialias):
accum = background
for buffers, rast in layers:
alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
if antialias:
accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
break ## HACK: the first layer only
return accum
# Like composite_buffer, but each layer is blended over the background independently.
# NOTE(review): no statement appending `accum` to accum_list is visible in this listing.
def separate_buffer(key, layers, background, antialias):
accum_list = []
for buffers, rast in layers:
accum = background
alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
if antialias:
accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
return accum_list
assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])
full_res = [resolution[0]*spp, resolution[1]*spp]
# Convert numpy arrays to torch tensors
mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
view_pos = prepare_input_vector(view_pos)
# clip space transform
v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)
# Render all layers front-to-back
# Two layers are peeled: the first goes into `layers`, the second into `layers2`.
with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:
rast, db = peeler.rasterize_next_layer()
layers = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)]
rast_1st_layer = rast
# with torch.no_grad():
if True:
rast, db = peeler.rasterize_next_layer()
layers2 = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)]
# Setup background
if background is not None:
if spp > 1:
background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
# Append a zero alpha channel to the supplied background.
background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)
# Presumably the `else` side: an all-zero 4-channel background at full resolution.
background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
# Composite layers front-to-back
out_buffers = {}
for key in layers[0][0].keys():
if key == 'shaded':
accum = composite_buffer(key, layers, background, True)
elif (key == 'depth' or key == 'pos') and layers[0][0][key] is not None:
# Depth / position use a constant background value of 20 and no antialiasing.
accum = separate_buffer(key, layers, torch.ones_like(layers[0][0][key]) * 20.0, False)
elif ('normal' in key) and layers[0][0][key] is not None:
accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)
elif layers[0][0][key] is not None:
accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), False)
# separate_buffer returns a list, so the first entry is used for depth / pos.
if (key == 'depth' or key == 'pos') and layers[0][0][key] is not None:
out_buffers[key] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
# Downscale to framebuffer resolution. Use avg pooling
out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
# NOTE(review): 'shaded_second' is composited from `layers` (the first peel), unlike
# 'depth_second' / 'normal_second' below which use `layers2` — confirm this is intended.
accum = composite_buffer('shaded', layers, background, True)
out_buffers['shaded_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
accum = separate_buffer('depth', layers2, -1 * torch.ones_like(layers2[0][0]['depth']), False)
out_buffers['depth_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
accum = separate_buffer('normal', layers2, torch.zeros_like(layers2[0][0]['normal']), False)
out_buffers['normal_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0]
# Distinct triangle ids seen in the first layer (last rasterizer channel).
rast_triangle_id = rast_1st_layer[:, :, :, -1].unique()
if rast_triangle_id[0] == 0:
if rast_triangle_id.size(0) > 1:
rast_triangle_id = rast_triangle_id[1:] - 1 ## since by the convention of the rasterizer, 0 = empty
# Presumably the `else` side: only the empty id was present.
rast_triangle_id = None
out_buffers['rast_triangle_id'] = rast_triangle_id
out_buffers['rast_depth'] = rast_1st_layer[:, :, :, -2] # z-buffer
if tet_centers is not None:
with torch.no_grad():
v_pos_clip = v_pos_clip[0]
# The pixel mapping below uses full_res[0] for both axes, hence the square-image requirement.
assert full_res[0] == full_res[1]
homo_transformed_tet_centers = ru.xfm_points(tet_centers[None, ...], mtx_in)
# Perspective divide of the transformed centers.
transformed_tet_centers = homo_transformed_tet_centers[0, :, :3] / homo_transformed_tet_centers[0, :, 3:4]
int_transformed_tet_centers = torch.round((transformed_tet_centers / 2.0 + 0.5) * (full_res[0] - 1)).long() # from the clip space (i.e., [-1, 1]^3) to the nearest integer coordinates in the canvas
### transpose THE "image"
# Swap the first two coordinates so index 0 addresses rows and index 1 columns.
tmp_int_transformed_tet_centers = int_transformed_tet_centers.clone()
int_transformed_tet_centers[:, 0] = tmp_int_transformed_tet_centers[:, 1]
int_transformed_tet_centers[:, 1] = tmp_int_transformed_tet_centers[:, 0]
valid_tet_centers = ((torch.logical_and((int_transformed_tet_centers <= full_res[0] - 1), int_transformed_tet_centers >= 0).float()).prod(dim=-1) == 1) # those tet centers in/on the edge of the clip space
valid_int_transformed_tet_centers = int_transformed_tet_centers[valid_tet_centers]
tet_center_dirs = (tet_centers - view_pos.view(1, 3))
tet_center_depths = tet_center_dirs.pow(2).sum(-1).sqrt()
### Finding occluded tetrahedra
valid_transformed_tet_center_depths = transformed_tet_centers[valid_tet_centers][:, -1] # get the depth in the clip space
valid_tet_ids = torch.arange(tet_centers.size(0)).to(valid_tet_centers.device)[valid_tet_centers]
corrected_rast_depth = out_buffers['rast_depth'].clone().detach()
corrected_rast_depth[rast_1st_layer[:, :, :, -1] == 0] = 100 # for all pixels without any rasterized mesh, just set the depth to a large enough value
# NOTE(review): the next six lines read as explanatory prose, presumably the contents of
# a string literal whose delimiters are not visible in this listing.
Hacky way of finding most of the non-occluded tetrahedra (except for already rasterized ones):
For each pixel, find the min depth in a small neighborhood.
If the center of a tetrahedron (coinciding with this pixel when rasterized) is smaller than this min depth,
this tetrahedron is certainly non-occluded.
Doing this because exact per-pixel comparison for triangular meshes can be costly,
plus we do not need to perfectly finding all visible tetrahedra.
depth_search_range = 7 ### change this value for different resolution in rasterization
# NOTE(review): the arguments of this max_pool2d call are not visible in this listing.
corrected_rast_depth = -torch.nn.functional.max_pool2d(
valid_reference_depth = corrected_rast_depth[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]]
# Passes when the reference depth at the pixel is at least the tet center's depth.
depth_filter = valid_reference_depth >= valid_transformed_tet_center_depths
empty_2d_mask = (rast_1st_layer[:, :, :, -1] == 0)
# NOTE(review): only the final `padding=` argument of this max_pool2d call is visible.
empty_2d_mask = (-torch.nn.functional.max_pool2d(
padding=depth_search_range)).bool() ### similar philosophy for using a neighborhood
empty_filter = empty_2d_mask[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]]
## visible tets are either determined by depth test or emptyness test
out_buffers['visible_tet_id'] = valid_tet_ids[torch.logical_or(empty_filter, depth_filter)]
return out_buffers
# ==============================================================================================
# Render UVs
# ==============================================================================================
def render_uv(ctx, mesh, resolution, mlp_texture):
    """Rasterize the mesh in UV space and sample the MLP texture per texel.

    Returns a 4-tuple: a float coverage mask (1 where a triangle was
    rasterized), the texture channels before the last six, the three
    channels after those, and the normalized last three channels (the
    perturbed normal).
    """
    # Map UVs from [0, 1] to [-1, 1] and pad to (u, v, 0, 1) clip coordinates.
    uv_ndc = mesh.v_tex[None, ...] * 2.0 - 1.0
    z_column = torch.zeros_like(uv_ndc[..., 0:1])
    w_column = torch.ones_like(uv_ndc[..., 0:1])
    uv_homogeneous = torch.cat((uv_ndc, z_column, w_column), dim=-1)

    rast, _ = dr.rasterize(ctx, uv_homogeneous, mesh.t_tex_idx.int(), resolution)

    # World-space position of every texel, used as the texture lookup key.
    world_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

    texels = mlp_texture.sample(world_pos)
    assert texels.shape[-1] in (9, 10), "Combined kd_ks_normal must be 9 or 10 channels"

    coverage = (rast[..., -1:] > 0).float()
    leading_channels = texels[..., :-6]
    middle_channels = texels[..., -6:-3]
    normal = util.safe_normalize(texels[..., -3:])
    return coverage, leading_channels, middle_channels, normal
# ==============================================================================================
# Render UVs (perturbed normals only)
# ==============================================================================================
def render_uv_nrm(ctx, mesh, resolution, mlp_texture):
    """Rasterize the mesh in texture space and return only the sampled normal.

    Same pipeline as a UV-space render: ``mesh.v_tex`` is remapped with
    ``*2 - 1`` and used as clip-space position, world positions are
    interpolated and fed to ``mlp_texture.sample``.

    Returns a 2-tuple: a float mask that is 1 where the last rasterizer
    channel is > 0, and the last three sampled channels passed through
    ``util.safe_normalize``.
    """
    # Texture coordinates remapped by *2 - 1 to serve as clip-space x/y.
    ndc_uv = mesh.v_tex[None, ...] * 2.0 - 1.0

    # Append z = 0 and w = 1 to form four-component clip coordinates.
    pad_like = ndc_uv[..., 0:1]
    clip_pos = torch.cat(
        (ndc_uv, torch.zeros_like(pad_like), torch.ones_like(pad_like)),
        dim=-1,
    )

    rast, _ = dr.rasterize(ctx, clip_pos, mesh.t_tex_idx.int(), resolution)

    # World-space position for every rasterized texel.
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

    all_tex = mlp_texture.sample(gb_pos)

    coverage = (rast[..., -1:] > 0).float()
    return coverage, util.safe_normalize(all_tex[..., -3:])

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]

/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */
#include "common.h"
#include "bsdf.h"
#define SPECULAR_EPSILON 1e-4f
// Lambert functions
// Lambert diffuse term: dot(nrm, wi) / pi, clamped from below at zero.
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
{
    return max(dot(nrm, wi) / M_PI, 0.0f);
}
// Backward of fwdLambert: accumulates into d_nrm and d_wi.
// No gradient is propagated when dot(nrm, wi) <= 0, i.e. where the forward
// max() selected the constant zero.
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
{
    if (dot(nrm, wi) > 0.0f)
    {
        bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
    }
}
// Fresnel Schlick
// Scalar Schlick interpolation between f0 and f90 with weight (1 - cosTheta)^5.
// cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON] first.
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}
// Backward of the scalar fwdFresnelSchlick. Accumulates into d_f0, d_f90 and
// d_cosTheta (callers must initialise them).
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;

    // cosTheta only receives a gradient when the clamp left it unchanged.
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
    }
}
// Per-channel Schlick interpolation between f0 and f90 with the scalar weight
// (1 - cosTheta)^5; cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON].
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}
// Backward of the vec3f fwdFresnelSchlick. Accumulates into d_f0, d_f90 and
// d_cosTheta; the cosTheta gradient is summed over the three channels.
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;

    // cosTheta only receives a gradient when the clamp left it unchanged.
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
    }
}
// Frostbite diffuse
// Frostbite-style diffuse term.
// Returns wiScatter * woScatter * energyFactor when both dot(wi, nrm) and
// dot(wo, nrm) are positive, and 0 otherwise.
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);
    if (wiDotN > 0.0f && woDotN > 0.0f)
    {
        vec3f h = safeNormalize(wo + wi);
        float wiDotH = dot(wi, h);
        float energyBias = 0.5f * linearRoughness;
        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float f0 = 1.f;
        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
        return wiScatter * woScatter * energyFactor;
    }
    else return 0.0f;
}
// Backward of fwdFrostbiteDiffuse. Accumulates into d_nrm, d_wi, d_wo and
// d_linearRoughness. Nothing is propagated unless both dot(wi, nrm) and
// dot(wo, nrm) are positive, matching the forward branch that returns 0.
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);

    if (wiDotN > 0.0f && woDotN > 0.0f)
    {
        // -------------- FWD (recompute intermediates) --------------
        vec3f h = safeNormalize(wo + wi);
        float wiDotH = dot(wi, h);
        float energyBias = 0.5f * linearRoughness;
        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float f0 = 1.f;
        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);

        // -------------- BWD --------------
        // Backprop: return wiScatter * woScatter * energyFactor;
        float d_wiScatter = d_out * woScatter * energyFactor;
        float d_woScatter = d_out * wiScatter * energyFactor;
        float d_energyFactor = d_out * wiScatter * woScatter;

        // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
        float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
        bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);

        // Backprop: float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        // d_f0 / d_f90 are shared with the call above so both contributions add up.
        float d_wiDotN = 0.0f;
        bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);

        // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float d_energyBias = d_f90;
        float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
        d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;

        // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;

        // Backprop: float energyBias = 0.5f * linearRoughness;
        d_linearRoughness += 0.5 * d_energyBias;

        // Backprop: float wiDotH = dot(wi, h);
        vec3f d_h(0);
        bwdDot(wi, h, d_wi, d_h, d_wiDotH);

        // Backprop: vec3f h = safeNormalize(wo + wi);
        vec3f d_wo_wi(0);
        bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
        d_wi += d_wo_wi; d_wo += d_wo_wi;

        // Backprop: woDotN = dot(wo, nrm); wiDotN = dot(wi, nrm);
        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
    }
}
// Ndf GGX
// GGX normal distribution term: alphaSqr / (pi * d^2) with
// d = (alphaSqr - 1) * cosTheta^2 + 1, evaluated on the clamped cosTheta.
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
    return alphaSqr / (d * d * M_PI);
}
// Backward of fwdNdfGGX. Accumulates into d_alphaSqr and d_cosTheta.
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    // Torch only back propagates if clamp doesn't trigger
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
    }
}
// Lambda GGX
// GGX Lambda term: 0.5 * (sqrt(1 + alphaSqr * tan^2(theta)) - 1), with
// tan^2 derived from the clamped cosTheta.
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
    float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
    return res;
}
// Backward of fwdLambdaGGX. Accumulates into d_alphaSqr and d_cosTheta.
// (The forward result itself is not needed here, so it is not recomputed.)
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;

    d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);

    // cosTheta only receives a gradient when the clamp left it unchanged.
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
    }
}
// Masking GGX
// Masking term 1 / (1 + Lambda(cosThetaI) + Lambda(cosThetaO)).
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
{
    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
    return 1.0f / (1.0f + lambdaI + lambdaO);
}
// Backward of fwdMaskingSmithGGXCorrelated. Accumulates into d_alphaSqr,
// d_cosThetaI and d_cosThetaO.
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
{
    // FWD eval
    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);

    // BWD eval
    // Both lambdas enter the denominator symmetrically, so they share one
    // upstream gradient.
    float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
    bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
}
// GGX specular
// GGX specular lobe: F * D * G * 0.25 / dot(wo, nrm).
// alpha is clamped to [min_roughness^2, 1] before squaring. The result is 0
// unless both dot(wo, nrm) and dot(wi, nrm) exceed SPECULAR_EPSILON.
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
{
    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
    vec3f w = F * D * G * 0.25 / woDotN;

    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
    return frontfacing ? w : 0.0f;
}
// Backward of fwdPbrSpecular. Accumulates into d_col, d_nrm, d_wo, d_wi and
// d_alpha. Nothing is propagated when the forward pass returned 0 (i.e. when
// the sample is not front-facing).
__device__ void bwdPbrSpecular(
    const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
    vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
{
    // FWD eval
    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);

    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);

    if (frontfacing)
    {
        // BWD eval
        // Backprop: w = F * D * G * 0.25 / woDotN
        vec3f d_F = d_out * D * G * 0.25f / woDotN;
        float d_D = sum(d_out * F * G * 0.25f / woDotN);
        float d_G = sum(d_out * F * D * 0.25f / woDotN);

        float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));

        vec3f d_f90(0);
        float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
        bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
        bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
        bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);

        // Backprop through the four dot products.
        vec3f d_h(0);
        bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
        bwdDot(wo, h, d_wo, d_h, d_woDotH);
        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);

        // Backprop: h = safeNormalize(wo + wi)
        vec3f d_h_unnorm(0);
        bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
        d_wo += d_h_unnorm;
        d_wi += d_h_unnorm;

        // Backprop: alphaSqr = _alpha * _alpha; skipped when the lower clamp
        // bound was active.
        if (alpha > min_roughness * min_roughness)
            d_alpha += d_alphaSqr * 2 * alpha;
    }
}
// Full PBR BSDF
// Full BSDF: diffuse + GGX specular.
// wo / wi are the normalized directions from pos to view_pos / light_pos.
// arm.x scales the specular colour by (1 - arm.x), arm.y is squared to get
// alpha (and passed unsquared to the Frostbite diffuse), arm.z blends kd into
// the specular colour and removes it from the diffuse colour.
// BSDF == 0 selects the Lambert diffuse term; any other value selects the
// Frostbite diffuse term.
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
{
    vec3f wo = safeNormalize(view_pos - pos);
    vec3f wi = safeNormalize(light_pos - pos);

    float alpha = arm.y * arm.y;
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
    vec3f diffuse = diff_col * diff;
    vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);

    return diffuse + specular;
}
// Backward of fwdPbrBSDF. Accumulates into d_kd, d_arm, d_pos, d_nrm,
// d_view_pos and d_light_pos. The diffuse branch (Lambert when BSDF == 0,
// Frostbite otherwise) must match the one taken in the forward pass.
__device__ void bwdPbrBSDF(
    const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
    vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    vec3f _wi = light_pos - pos;
    vec3f _wo = view_pos - pos;
    vec3f wi = safeNormalize(_wi);
    vec3f wo = safeNormalize(_wo);

    float alpha = arm.y * arm.y;
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);
    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    float d_alpha(0);
    vec3f d_spec_col(0), d_wi(0), d_wo(0);
    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    // Backprop: diffuse = diff_col * diff (scalar diff, so sum over channels)
    float d_diff = sum(diff_col * d_out);
    if (BSDF == 0)
        bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
    else
        bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);

    // Backprop: diff_col = kd * (1.0f - arm.z)
    vec3f d_diff_col = d_out * diff;
    d_kd += d_diff_col * (1.0f - arm.z);
    d_arm.z -= sum(d_diff_col * kd);

    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
    d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
    d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));

    // Backprop: alpha = arm.y * arm.y
    d_arm.y += d_alpha * 2 * arm.y;

    // Backprop: vec3f wi = safeNormalize(light_pos - pos);
    vec3f d__wi(0);
    bwdSafeNormalize(_wi, d__wi, d_wi);
    d_light_pos += d__wi;
    d_pos -= d__wi;

    // Backprop: vec3f wo = safeNormalize(view_pos - pos);
    vec3f d__wo(0);
    bwdSafeNormalize(_wo, d__wo, d_wo);
    d_view_pos += d__wo;
    d_pos -= d__wo;
}
// Kernels
// Element-wise forward Lambert kernel: one thread per (px, py, pz) element.
__global__ void LambertFwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);

    float res = fwdLambert(nrm, wi);

    p.out.store(px, py, pz, res);
}
// Element-wise backward Lambert kernel: one thread per (px, py, pz) element.
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void LambertBwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    vec3f d_nrm(0), d_wi(0);
    bwdLambert(nrm, wi, d_nrm, d_wi, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
}
// Element-wise forward Frostbite diffuse kernel: one thread per (px, py, pz).
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);

    float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);

    p.out.store(px, py, pz, res);
}
// Element-wise backward Frostbite diffuse kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_linearRoughness = 0.0f;
    vec3f d_nrm(0), d_wi(0), d_wo(0);
    bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
    p.wo.store_grad(px, py, pz, d_wo);
    p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
}
// Element-wise forward Fresnel-Schlick kernel: one thread per (px, py, pz).
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);

    p.out.store(px, py, pz, res);
}
// Element-wise backward Fresnel-Schlick kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_f0(0), d_f90(0);
    float d_cosTheta(0);
    bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);

    p.f0.store_grad(px, py, pz, d_f0);
    p.f90.store_grad(px, py, pz, d_f90);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Element-wise forward GGX NDF kernel: one thread per (px, py, pz).
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    float res = fwdNdfGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}
// Element-wise backward GGX NDF kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Element-wise forward GGX Lambda kernel: one thread per (px, py, pz).
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    float res = fwdLambdaGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}
// Element-wise backward GGX Lambda kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Element-wise forward Smith masking kernel: one thread per (px, py, pz).
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);

    float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);

    p.out.store(px, py, pz, res);
}
// Element-wise backward Smith masking kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
    bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
    p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
}
// Element-wise forward GGX specular kernel: one thread per (px, py, pz).
// min_roughness is a kernel-wide scalar taken from the parameter struct.
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);

    vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);

    p.out.store(px, py, pz, res);
}
// Element-wise backward GGX specular kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    float d_alpha(0);
    vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
    bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    p.col.store_grad(px, py, pz, d_col);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wo.store_grad(px, py, pz, d_wo);
    p.wi.store_grad(px, py, pz, d_wi);
    p.alpha.store_grad(px, py, pz, d_alpha);
}
// Element-wise forward full-BSDF kernel: one thread per (px, py, pz).
// min_roughness and the BSDF selector are kernel-wide values taken from the
// parameter struct.
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);

    vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);

    p.out.store(px, py, pz, res);
}
// Element-wise backward full-BSDF kernel: one thread per (px, py, pz).
// d_out is read from p.out (presumably the upstream gradient buffer in the
// backward pass -- confirm against the host-side launcher).
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    // Threads that fall outside the problem size have nothing to do.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
    bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);

    p.kd.store_grad(px, py, pz, d_kd);
    p.arm.store_grad(px, py, pz, d_arm);
    p.pos.store_grad(px, py, pz, d_pos);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.view_pos.store_grad(px, py, pz, d_view_pos);
    p.light_pos.store_grad(px, py, pz, d_light_pos);
}

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import math
import torch
# Vector utility functions
def _dot(x, y):
    """Dot product along the last axis, keeping that axis with size 1."""
    return (x * y).sum(dim=-1, keepdim=True)
def _reflect(x, n):
    """Return ``2 * dot(x, n) * n - x``, i.e. ``x`` mirrored about ``n``.

    The dot product is taken along the last axis.
    """
    projection = _dot(x, n)
    return 2 * projection * n - x
def _safe_normalize(x):
    """Scale ``x`` to unit length along its last axis.

    Delegates to ``torch.nn.functional.normalize`` with its default norm and
    epsilon, so a zero vector is returned as zeros rather than NaN.
    """
    normalize = torch.nn.functional.normalize
    return normalize(x, dim=-1)
def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
    """Blend from the geometric normal towards the smooth normal.

    The blend weight is ``dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD``
    clamped to [0, 1], so the geometric normal takes over as the smooth
    normal turns away from the viewer. With ``two_sided_shading`` both
    normals are first flipped wherever the geometric normal faces away from
    ``view_vec``.

    NOTE(review): ``NORMAL_THRESHOLD`` is expected to be a module-level
    constant; its definition is not visible in this part of the file.
    """
    if two_sided_shading:
        # Swap normal direction for backfacing surfaces.
        facing_viewer = _dot(geom_nrm, view_vec) > 0
        smooth_nrm = torch.where(facing_viewer, smooth_nrm, -smooth_nrm)
        geom_nrm = torch.where(facing_viewer, geom_nrm, -geom_nrm)

    blend = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
    return torch.lerp(geom_nrm, smooth_nrm, blend)
def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
    """Transform a tangent-space normal into the frame of the smooth normal.

    Args:
        perturbed_nrm: tangent-space normal; channels 0, 1, 2 on the last
            axis weight the tangent, bitangent and normal respectively.
        smooth_nrm: interpolated surface normal.
        smooth_tng: interpolated surface tangent.
        opengl: when true the bitangent contribution is subtracted instead
            of added (flipped second tangent-space axis).

    Returns:
        The normalized shading normal.
    """
    # dim=-1 is passed explicitly: without it torch.cross uses the first
    # dimension of size 3, which is wrong whenever a leading (batch/height/
    # width) dimension happens to be 3.
    smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm, dim=-1))

    tangent_part = smooth_tng * perturbed_nrm[..., 0:1]
    bitangent_part = smooth_bitang * perturbed_nrm[..., 1:2]
    # The normal-axis component is clamped so it never points below the surface.
    normal_part = smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)

    if opengl:
        shading_nrm = tangent_part - bitangent_part + normal_part
    else:
        shading_nrm = tangent_part + bitangent_part + normal_part
    return _safe_normalize(shading_nrm)
def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
    """Compute the final shading normal for a surface point.

    The smooth normal is normalized, optionally perturbed by a tangent-space
    normal (only when a tangent is supplied), and finally bent towards the
    geometric normal by ``_bend_normal`` using the view direction
    ``normalize(view_pos - pos)``.

    Args:
        smooth_tng: surface tangent, or ``None`` to skip the tangent-space
            perturbation and use the smooth normal directly.
    """
    smooth_nrm = _safe_normalize(smooth_nrm)
    view_vec = _safe_normalize(view_pos - pos)

    if smooth_tng is None:
        # No tangent frame available: nothing to perturb with.
        shading_nrm = smooth_nrm
    else:
        smooth_tng = _safe_normalize(smooth_tng)
        shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)

    return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
# Simple lambertian diffuse BSDF
def bsdf_lambert(nrm, wi):
    """Lambert diffuse term: ``max(dot(nrm, wi), 0) / pi``."""
    clamped_cos = torch.clamp(_dot(nrm, wi), min=0.0)
    return clamped_cos / math.pi
# Frostbite diffuse
def bsdf_frostbite(nrm, wi, wo, linearRoughness):
    """Frostbite-style diffuse term.

    Returns ``wiScatter * woScatter * energyFactor`` where both
    ``dot(wi, nrm)`` and ``dot(wo, nrm)`` are positive, and zero elsewhere.
    """
    cos_i = _dot(wi, nrm)
    cos_o = _dot(wo, nrm)

    half_vec = _safe_normalize(wo + wi)
    cos_ih = _dot(wi, half_vec)

    energy_bias = 0.5 * linearRoughness
    energy_factor = 1.0 - (0.51 / 1.51) * linearRoughness
    f90 = energy_bias + 2.0 * cos_ih * cos_ih * linearRoughness
    f0 = 1.0

    scatter_i = bsdf_fresnel_shlick(f0, f90, cos_i)
    scatter_o = bsdf_fresnel_shlick(f0, f90, cos_o)
    lobe = scatter_i * scatter_o * energy_factor

    lit = (cos_i > 0.0) & (cos_o > 0.0)
    return torch.where(lit, lobe, torch.zeros_like(lobe))
# Phong specular, loosely based on mitsuba implementation
def bsdf_phong(nrm, wo, wi, N):
    """Phong specular lobe with exponent ``N``.

    Evaluates ``dot(reflect(wo, nrm), wi)^N * dot(nrm, wi) * (N + 2) / (2*pi)``
    with both dot products clamped to [0, 1].
    """
    mirror_dir = _reflect(wo, nrm)
    cos_mirror = torch.clamp(_dot(mirror_dir, wi), min=0.0, max=1.0)
    cos_light = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
    return (cos_mirror ** N) * cos_light * (N + 2) / (2 * math.pi)
# PBR's implementation of GGX specular
# Margin used by the GGX/Fresnel helpers in this module: cosines are clamped to
# [specular_epsilon, 1 - specular_epsilon] and front-facing tests compare
# against it.
specular_epsilon = 1e-4
def bsdf_fresnel_shlick(f0, f90, cosTheta):
    """Schlick Fresnel: ``f0 + (f90 - f0) * (1 - cosTheta)^5``.

    ``cosTheta`` is clamped to ``[specular_epsilon, 1 - specular_epsilon]``
    before evaluation.
    """
    clamped_cos = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    weight = (1.0 - clamped_cos) ** 5.0
    return f0 + (f90 - f0) * weight
def bsdf_ndf_ggx(alphaSqr, cosTheta):
    """GGX normal distribution term ``alphaSqr / (pi * d^2)``.

    ``d = (alphaSqr - 1) * cosTheta^2 + 1`` is evaluated on ``cosTheta``
    clamped to ``[specular_epsilon, 1 - specular_epsilon]``.
    """
    clamped_cos = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    denom = (clamped_cos * alphaSqr - clamped_cos) * clamped_cos + 1
    return alphaSqr / (denom * denom * math.pi)
def bsdf_lambda_ggx(alphaSqr, cosTheta):
    """GGX Lambda term ``0.5 * (sqrt(1 + alphaSqr * tan^2) - 1)``.

    ``tan^2`` is derived from ``cosTheta`` clamped to
    ``[specular_epsilon, 1 - specular_epsilon]``.
    """
    clamped_cos = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    cos_sq = clamped_cos * clamped_cos
    tan_sq = (1.0 - cos_sq) / cos_sq
    return 0.5 * (torch.sqrt(1 + alphaSqr * tan_sq) - 1.0)
def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
    """Masking term ``1 / (1 + Lambda(cosThetaI) + Lambda(cosThetaO))``."""
    lambda_in = bsdf_lambda_ggx(alphaSqr, cosThetaI)
    lambda_out = bsdf_lambda_ggx(alphaSqr, cosThetaO)
    return 1 / (1 + lambda_in + lambda_out)
def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
    """GGX specular lobe ``F * D * G * 0.25 / dot(wo, nrm)``.

    ``alpha`` is clamped to ``[min_roughness^2, 1]`` before being squared.
    The denominator is clamped from below at ``specular_epsilon`` and the
    result is zeroed wherever ``dot(wo, nrm)`` or ``dot(wi, nrm)`` does not
    exceed ``specular_epsilon``.
    """
    clamped_alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
    alpha_sq = clamped_alpha * clamped_alpha

    half_vec = _safe_normalize(wo + wi)
    cos_o = _dot(wo, nrm)
    cos_i = _dot(wi, nrm)
    cos_oh = _dot(wo, half_vec)
    cos_nh = _dot(nrm, half_vec)

    ndf = bsdf_ndf_ggx(alpha_sq, cos_nh)
    masking = bsdf_masking_smith_ggx_correlated(alpha_sq, cos_o, cos_i)
    fresnel = bsdf_fresnel_shlick(col, 1, cos_oh)

    lobe = fresnel * ndf * masking * 0.25 / torch.clamp(cos_o, min=specular_epsilon)

    front_facing = (cos_o > specular_epsilon) & (cos_i > specular_epsilon)
    return torch.where(front_facing, lobe, torch.zeros_like(lobe))
def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
    """Evaluate the full BSDF (diffuse + GGX specular) for a point light.

    Args:
        kd: base colour.
        arm: packed parameters on the last axis: channel 0 is the specular
            strength term, channel 1 the roughness, channel 2 the metalness.
        pos, view_pos, light_pos: positions; the view and light directions
            are ``normalize(view_pos - pos)`` and ``normalize(light_pos - pos)``.
        nrm: shading normal.
        min_roughness: lower bound forwarded to ``bsdf_pbr_specular``.
        BSDF: ``0`` selects the Lambert diffuse term, any other value the
            Frostbite diffuse term.

    Returns:
        ``diffuse + specular``.
    """
    wo = _safe_normalize(view_pos - pos)
    wi = _safe_normalize(light_pos - pos)

    spec_str = arm[..., 0:1]   # x component
    roughness = arm[..., 1:2]  # y component
    metallic = arm[..., 2:3]   # z component

    # ks must be derived from the original kd, before kd is attenuated below.
    ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
    kd = kd * (1.0 - metallic)

    if BSDF == 0:
        diffuse = kd * bsdf_lambert(nrm, wi)
    else:
        diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
    specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
    return diffuse + specular

/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */
#include "common.h"
#include "bsdf.h"
#define SPECULAR_EPSILON 1e-4f
// Lambert functions
// Lambert diffuse term: dot(nrm, wi) / pi, clamped from below at zero.
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
{
    return max(dot(nrm, wi) / M_PI, 0.0f);
}
// Backward of fwdLambert: accumulates into d_nrm and d_wi.
// No gradient is propagated when dot(nrm, wi) <= 0, i.e. where the forward
// max() selected the constant zero.
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
{
    if (dot(nrm, wi) > 0.0f)
    {
        bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
    }
}
// Fresnel Schlick
// Scalar Schlick interpolation between f0 and f90 with weight (1 - cosTheta)^5.
// cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON] first.
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}
// Backward of the scalar fwdFresnelSchlick. Accumulates into d_f0, d_f90 and
// d_cosTheta (callers must initialise them).
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;

    // cosTheta only receives a gradient when the clamp left it unchanged.
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
    }
}
// Per-channel Schlick interpolation between f0 and f90 with the scalar weight
// (1 - cosTheta)^5; cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON].
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}
// Backward of the vec3f fwdFresnelSchlick. Accumulates into d_f0, d_f90 and
// d_cosTheta; the cosTheta gradient is summed over the three channels.
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;

    // cosTheta only receives a gradient when the clamp left it unchanged.
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
    }
}
// Frostbite diffuse
// Frostbite-style diffuse term.
// Returns wiScatter * woScatter * energyFactor when both dot(wi, nrm) and
// dot(wo, nrm) are positive, and 0 otherwise.
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);
    if (wiDotN > 0.0f && woDotN > 0.0f)
    {
        vec3f h = safeNormalize(wo + wi);
        float wiDotH = dot(wi, h);
        float energyBias = 0.5f * linearRoughness;
        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float f0 = 1.f;
        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
        return wiScatter * woScatter * energyFactor;
    }
    else return 0.0f;
}
// Backward pass of fwdFrostbiteDiffuse.
// Re-evaluates the forward intermediates, then walks the forward statements in reverse order,
// accumulating into d_nrm, d_wi, d_wo and d_linearRoughness.
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);

    // The forward pass returns a constant 0 here, so there is nothing to propagate.
    if (!(wiDotN > 0.0f && woDotN > 0.0f))
        return;

    // -------------- FWD --------------
    vec3f h = safeNormalize(wo + wi);
    float wiDotH = dot(wi, h);
    float energyBias = 0.5f * linearRoughness;
    float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
    float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
    float f0 = 1.f;

    float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
    float woScatter = fwdFresnelSchlick(f0, f90, woDotN);

    // -------------- BWD --------------
    // Backprop: return wiScatter * woScatter * energyFactor;
    float d_wiScatter = d_out * woScatter * energyFactor;
    float d_woScatter = d_out * wiScatter * energyFactor;
    float d_energyFactor = d_out * wiScatter * woScatter;

    // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
    // d_f0 is only a sink: f0 is a constant, so its gradient is never read.
    float d_woDotN = 0.0f, d_f0 = 0.0f, d_f90 = 0.0f;
    bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);

    // Backprop: float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
    float d_wiDotN = 0.0f;
    bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);

    // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
    float d_energyBias = d_f90;
    float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
    d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;

    // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
    d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;

    // Backprop: float energyBias = 0.5f * linearRoughness;
    d_linearRoughness += 0.5 * d_energyBias;

    // Backprop: float wiDotH = dot(wi, h);
    vec3f d_h(0);
    bwdDot(wi, h, d_wi, d_h, d_wiDotH);

    // Backprop: vec3f h = safeNormalize(wo + wi);
    vec3f d_wo_wi(0);
    bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
    d_wi += d_wo_wi; d_wo += d_wo_wi;

    // Backprop: the two dot products with the normal.
    bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
    bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
}
// Ndf GGX
// GGX normal distribution term: alphaSqr / (pi * d^2) with d = cos^2 * (alphaSqr - 1) + 1.
// cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON] before use.
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
{
    const float c = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    const float d = (c * alphaSqr - c) * c + 1.0f;
    return alphaSqr / (d * d * M_PI);
}
// Backward pass of fwdNdfGGX. Accumulates into d_alphaSqr and d_cosTheta.
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    const float c = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    const float cosThetaSqr = c * c;

    // Cubed denominator term shared by both partial derivatives.
    const float dCubed = powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f);

    d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * dCubed);

    // Torch only back propagates if clamp doesn't trigger
    const bool unclamped = cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON;
    if (unclamped)
        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * dCubed);
}
// Lambda GGX
// GGX Lambda term: 0.5 * (sqrt(1 + alphaSqr * tan^2(theta)) - 1).
// cosTheta is clamped to [SPECULAR_EPSILON, 1 - SPECULAR_EPSILON] before use.
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
{
    const float c = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    const float cosThetaSqr = c * c;
    const float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
    return 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
}
// Backward pass of fwdLambdaGGX. Accumulates into d_alphaSqr and d_cosTheta.
// The forward value itself is not needed for the derivatives, so it is not recomputed.
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;

    // d/d(alphaSqr) of 0.5 * (sqrt(1 + alphaSqr * tan^2) - 1)
    d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);

    // The clamp passes gradient only where it left cosTheta untouched.
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
        d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
}
// Masking GGX
// Height-correlated Smith masking term: 1 / (1 + Lambda(cosThetaI) + Lambda(cosThetaO)).
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
{
    const float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
    const float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
    return 1.0f / (1.0f + lambdaI + lambdaO);
}
// Backward pass of fwdMaskingSmithGGXCorrelated.
// Accumulates into d_alphaSqr, d_cosThetaI and d_cosThetaO.
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
{
    // FWD eval
    const float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
    const float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);

    // BWD eval
    // Both Lambda terms enter the denominator symmetrically, so they share one upstream gradient.
    const float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
    bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
}
// GGX specular
// GGX specular lobe: F * D * G * 0.25 / (wo . n).
// alpha is clamped to [min_roughness^2, 1]. Returns 0 unless both wo.n and wi.n exceed
// SPECULAR_EPSILON.
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
{
    const float alphaClamped = clamp(alpha, min_roughness * min_roughness, 1.0f);
    const float alphaSqr = alphaClamped * alphaClamped;

    // Half vector and the cosines the three terms depend on.
    vec3f h = safeNormalize(wo + wi);
    const float woDotN = dot(wo, nrm);
    const float wiDotN = dot(wi, nrm);
    const float woDotH = dot(wo, h);
    const float nDotH = dot(nrm, h);

    const float D = fwdNdfGGX(alphaSqr, nDotH);
    const float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);

    vec3f w = F * D * G * 0.25 / woDotN;

    const bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
    return frontfacing ? w : 0.0f;
}
// Backward pass of fwdPbrSpecular.
// Accumulates into d_col, d_nrm, d_wo, d_wi and d_alpha; d_out is the gradient of the loss with
// respect to the forward result. Nothing is accumulated when the forward pass returned its
// constant zero (either direction not front-facing).
__device__ void bwdPbrSpecular(
    const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
    vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
{
    ///////////////////////////////////////////////////////////////////////
    // FWD eval
    const float alphaMin = min_roughness * min_roughness;
    float _alpha = clamp(alpha, alphaMin, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
    if (!frontfacing)
        return;

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);

    ///////////////////////////////////////////////////////////////////////
    // BWD eval
    // Backprop: w = F * D * G * 0.25 / woDotN
    vec3f d_F = d_out * D * G * 0.25f / woDotN;
    float d_D = sum(d_out * F * G * 0.25f / woDotN);
    float d_G = sum(d_out * F * D * 0.25f / woDotN);
    float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));

    // d_f90 is only a sink: f90 is the constant 1, so its gradient is never read.
    vec3f d_f90(0);
    float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
    bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
    bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
    bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);

    // Backprop: the four dot products.
    vec3f d_h(0);
    bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
    bwdDot(wo, h, d_wo, d_h, d_woDotH);
    bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
    bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);

    // Backprop: h = safeNormalize(wo + wi)
    vec3f d_h_unnorm(0);
    bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
    d_wo += d_h_unnorm;
    d_wi += d_h_unnorm;

    // Backprop: alphaSqr = clamp(alpha, alphaMin, 1)^2.
    // The clamp is constant on both sides of its range, so gradient must be blocked above the
    // upper bound as well as below the lower one (the original only tested the lower bound).
    if (alpha > alphaMin && alpha <= 1.0f)
        d_alpha += d_alphaSqr * 2 * alpha;
}
// Full PBR BSDF
// Full PBR BSDF: diffuse term plus GGX specular term.
//   kd    - base colour
//   arm   - packed material parameters: .x scales the specular colour by (1 - x), .y is the
//           roughness (alpha = y^2), .z blends between the dielectric (0.04) and kd-tinted
//           specular colour and removes diffuse colour
//   BSDF  - diffuse model selector: 0 uses fwdLambert, anything else fwdFrostbiteDiffuse
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
{
    vec3f wo = safeNormalize(view_pos - pos);
    vec3f wi = safeNormalize(light_pos - pos);

    float alpha = arm.y * arm.y;
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    // Exactly one diffuse model is evaluated; without the else the Frostbite result would
    // always overwrite the Lambert one.
    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);

    vec3f diffuse = diff_col * diff;
    vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);

    return diffuse + specular;
}
// Backward pass of fwdPbrBSDF.
// Accumulates into d_kd, d_arm, d_pos, d_nrm, d_view_pos and d_light_pos; d_out is the gradient
// of the loss with respect to the forward result. BSDF selects the diffuse model exactly as in
// the forward pass (0 = Lambert, otherwise Frostbite), and only that model is differentiated.
__device__ void bwdPbrBSDF(
    const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
    vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    vec3f _wi = light_pos - pos;
    vec3f _wo = view_pos - pos;
    vec3f wi = safeNormalize(_wi);
    vec3f wo = safeNormalize(_wo);

    float alpha = arm.y * arm.y;
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    float d_alpha(0);
    vec3f d_spec_col(0), d_wi(0), d_wo(0);
    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    // Backprop: diffuse = diff_col * diff
    float d_diff = sum(diff_col * d_out);
    if (BSDF == 0)
        bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
    else
        bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);

    // Backprop: diff_col = kd * (1.0f - arm.z)
    vec3f d_diff_col = d_out * diff;
    d_kd += d_diff_col * (1.0f - arm.z);
    d_arm.z -= sum(d_diff_col * kd);

    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
    d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
    d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));

    // Backprop: alpha = arm.y * arm.y
    d_arm.y += d_alpha * 2 * arm.y;

    // Backprop: vec3f wi = safeNormalize(light_pos - pos);
    vec3f d__wi(0);
    bwdSafeNormalize(_wi, d__wi, d_wi);
    d_light_pos += d__wi;
    d_pos -= d__wi;

    // Backprop: vec3f wo = safeNormalize(view_pos - pos);
    vec3f d__wo(0);
    bwdSafeNormalize(_wo, d__wo, d_wo);
    d_view_pos += d__wo;
    d_pos -= d__wo;
}
// Kernels
// Forward Lambert kernel: evaluates fwdLambert for the element addressed by this thread.
__global__ void LambertFwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);

    float res = fwdLambert(nrm, wi);

    p.out.store(px, py, pz, res);
}
// Backward Lambert kernel: reads the incoming gradient from p.out and writes the gradients of
// nrm and wi through store_grad.
__global__ void LambertBwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    vec3f d_nrm(0), d_wi(0);
    bwdLambert(nrm, wi, d_nrm, d_wi, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
}
// Forward Frostbite diffuse kernel: evaluates fwdFrostbiteDiffuse for this thread's element.
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);

    float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);

    p.out.store(px, py, pz, res);
}
// Backward Frostbite diffuse kernel: reads the incoming gradient from p.out and writes the
// gradients of nrm, wi, wo and linearRoughness through store_grad.
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_linearRoughness = 0.0f;
    vec3f d_nrm(0), d_wi(0), d_wo(0);
    bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
    p.wo.store_grad(px, py, pz, d_wo);
    p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
}
// Forward Fresnel-Schlick kernel: evaluates the vec3f fwdFresnelSchlick for this thread's element.
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);

    p.out.store(px, py, pz, res);
}
// Backward Fresnel-Schlick kernel: reads the incoming gradient from p.out and writes the
// gradients of f0, f90 and cosTheta through store_grad.
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_f0(0), d_f90(0);
    float d_cosTheta(0);
    bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);

    p.f0.store_grad(px, py, pz, d_f0);
    p.f90.store_grad(px, py, pz, d_f90);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Forward GGX NDF kernel: evaluates fwdNdfGGX for this thread's element.
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    float res = fwdNdfGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}
// Backward GGX NDF kernel: reads the incoming gradient from p.out and writes the gradients of
// alphaSqr and cosTheta through store_grad.
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Forward GGX Lambda kernel: evaluates fwdLambdaGGX for this thread's element.
// Shares the NdfGGXParams layout with the NDF kernels.
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    float res = fwdLambdaGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}
// Backward GGX Lambda kernel: reads the incoming gradient from p.out and writes the gradients
// of alphaSqr and cosTheta through store_grad.
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}
// Forward Smith masking kernel: evaluates fwdMaskingSmithGGXCorrelated for this thread's element.
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);

    float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);

    p.out.store(px, py, pz, res);
}
// Backward Smith masking kernel: reads the incoming gradient from p.out and writes the
// gradients of alphaSqr, cosThetaI and cosThetaO through store_grad.
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
    bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
    p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
}
// Forward PBR specular kernel: evaluates fwdPbrSpecular for this thread's element.
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);

    vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);

    p.out.store(px, py, pz, res);
}
// Backward PBR specular kernel: reads the incoming gradient from p.out and writes the
// gradients of col, nrm, wo, wi and alpha through store_grad.
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    float d_alpha(0);
    vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
    bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    p.col.store_grad(px, py, pz, d_col);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wo.store_grad(px, py, pz, d_wo);
    p.wi.store_grad(px, py, pz, d_wi);
    p.alpha.store_grad(px, py, pz, d_alpha);
}
// Forward PBR BSDF kernel: evaluates fwdPbrBSDF for this thread's element.
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);

    vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);

    p.out.store(px, py, pz, res);
}
// Backward PBR BSDF kernel: reads the incoming gradient from p.out and writes the gradients of
// kd, arm, pos, nrm, view_pos and light_pos through store_grad.
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    const unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
    bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);

    p.kd.store_grad(px, py, pz, d_kd);
    p.arm.store_grad(px, py, pz, d_arm);
    p.pos.store_grad(px, py, pz, d_pos);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.view_pos.store_grad(px, py, pz, d_view_pos);
    p.light_pos.store_grad(px, py, pz, d_light_pos);
}

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include "common.h"
// Argument bundle passed by value to LambertFwdKernel / LambertBwdKernel.
struct LambertKernelParams
Tensor nrm; // read with fetch3() as the normal; the backward kernel writes its gradient with store_grad()
Tensor wi; // read with fetch3() as wi; the backward kernel writes its gradient with store_grad()
Tensor out; // forward kernel stores the result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
// Argument bundle passed by value to FrostbiteDiffuseFwdKernel / FrostbiteDiffuseBwdKernel.
struct FrostbiteDiffuseKernelParams
Tensor nrm; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor wi; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor wo; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor linearRoughness; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor out; // forward kernel stores the result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
// Argument bundle passed by value to FresnelShlickFwdKernel / FresnelShlickBwdKernel.
struct FresnelShlickKernelParams
Tensor f0; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor f90; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor cosTheta; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor out; // forward kernel stores the vec3f result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
// Argument bundle shared by the ndfGGX and lambdaGGX forward/backward kernels.
struct NdfGGXParams
Tensor alphaSqr; // read with fetch1(); gradient written with store_grad() in the backward kernels
Tensor cosTheta; // read with fetch1(); gradient written with store_grad() in the backward kernels
Tensor out; // forward kernels store the result here; backward kernels fetch the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
// Argument bundle passed by value to maskingSmithFwdKernel / maskingSmithBwdKernel.
struct MaskingSmithParams
Tensor alphaSqr; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor cosThetaI; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor cosThetaO; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor out; // forward kernel stores the result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
// Argument bundle passed by value to pbrSpecularFwdKernel / pbrSpecularBwdKernel.
struct PbrSpecular
Tensor col; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor nrm; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor wo; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor wi; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor alpha; // read with fetch1(); gradient written with store_grad() in the backward kernel
Tensor out; // forward kernel stores the vec3f result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
float min_roughness; // forwarded unchanged to fwdPbrSpecular / bwdPbrSpecular
// Argument bundle passed by value to pbrBSDFFwdKernel / pbrBSDFBwdKernel.
struct PbrBSDF
Tensor kd; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor arm; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor pos; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor nrm; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor view_pos; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor light_pos; // read with fetch3(); gradient written with store_grad() in the backward kernel
Tensor out; // forward kernel stores the vec3f result here; backward kernel fetches the incoming gradient from here
dim3 gridSize; // extents the kernels compare their thread indices against
float min_roughness; // forwarded unchanged to fwdPbrBSDF / bwdPbrBSDF
int BSDF; // diffuse model selector forwarded to fwdPbrBSDF / bwdPbrBSDF, which test it against 0 (Lambert)

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#include <cuda_runtime.h>
#include <algorithm>
// Block and grid size calculators for kernel launches.
// Picks a 2D thread-block shape for a buffer of size dims.
// Starts from maxWidth x maxHeight and reshapes the block so it does not overhang a narrow or
// short buffer more than necessary. Returns (1, 1, 1) for degenerate inputs.
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
{
    const int maxThreads = maxWidth * maxHeight;
    if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
        return dim3(1, 1, 1); // Degenerate.

    // Start from max size.
    int bw = maxWidth;
    int bh = maxHeight;

    // Optimizations for weirdly sized buffers.
    if (dims.x < bw)
    {
        // Decrease block width to smallest power of two that covers the buffer width.
        for (; (bw >> 1) >= dims.x; bw >>= 1)
        {
        }

        // Maximize height, but never beyond the buffer height.
        bh = maxThreads / bw;
        if (bh > dims.y)
            bh = dims.y;
    }
    else if (dims.y < bh)
    {
        // Halve height and double width until fits completely inside buffer vertically.
        while (bh > dims.y)
        {
            bh >>= 1;
            if (bw < dims.x)
                bw <<= 1;
        }
    }

    // Done.
    return dim3(bw, bh, 1);
}
// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
// Host-side variant: the sub-block shape, at most 32 threads in total, that one warp covers
// inside blockSize (x is filled first, then y, then z).
dim3 getWarpSize(dim3 blockSize)
{
    const unsigned int wx = std::min(blockSize.x, 32u);
    const unsigned int wy = std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y));
    const unsigned int wz = std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z));
    return dim3(wx, wy, wz);
}
// Number of blocks of size blockSize needed to cover dims along each axis (ceiling division).
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
{
    return dim3(
        (dims.x - 1) / blockSize.x + 1,
        (dims.y - 1) / blockSize.y + 1,
        (dims.z - 1) / blockSize.z + 1
    );
}

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include <cuda.h>
#include <stdint.h>
#include "vec3f.h"
#include "vec4f.h"
#include "tensor.h"
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
#ifdef __CUDACC__
#ifdef _MSC_VER
#define M_PI 3.14159265358979323846f
__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
return dim3(
min(blockSize.x, 32u),
min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
dim3 getWarpSize(dim3 blockSize);

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#include "common.h"
#include "cubemap.h"
#include <float.h>
// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
// Weight of cube-map texel (x, y) on an N x N face: the product of the angles the texel spans
// along each face axis, atan(u1) - atan(u0), where [u0, u1] is the texel's extent on the
// [-1, 1] face axis. Returns 1 for N <= 1.
//
// The extent is taken from the texel's real edges, 2x/N - 1 and 2(x+1)/N - 1, so texels on
// both sides of the face centre are weighted symmetrically and odd N is handled as well. For
// even N and texels at or right of the centre this is the same quotient the previous
// |x - N/2| form evaluated.
__device__ float pixel_area(int x, int y, int N)
{
    if (N <= 1)
        return 1;

    const float fN = (float)N;
    float dx = atan((float)(2 * x + 2 - N) / fN) - atan((float)(2 * x - N) / fN);
    float dy = atan((float)(2 * y + 2 - N) / fN) - atan((float)(2 * y - N) / fN);
    return dx * dy;
}
// Unit direction through the centre of texel (x, y) on cube face `side` (0..5) of an N x N
// face. Any other side value yields the zero vector.
__device__ vec3f cube_to_dir(int x, int y, int side, int N)
{
    // Texel centre mapped to [-1, 1] on each face axis.
    const float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
    const float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;

    switch (side)
    {
    case 0: return safeNormalize(vec3f(1, -fy, -fx));
    case 1: return safeNormalize(vec3f(-1, -fy, fx));
    case 2: return safeNormalize(vec3f(fx, 1, fy));
    case 3: return safeNormalize(vec3f(fx, -1, -fy));
    case 4: return safeNormalize(vec3f(fx, -fy, 1));
    case 5: return safeNormalize(vec3f(-fx, -fy, -1));
    default: break;
    }
    return vec3f(0,0,0); // Unreachable
}
// Re-expresses direction v in the local frame of cube face `side` (0..5) by permuting and
// negating its components. Any other side value yields the zero vector.
__device__ vec3f dir_to_side(int side, vec3f v)
{
    switch (side)
    {
    case 0: return vec3f(-v.z, -v.y,  v.x);
    case 1: return vec3f( v.z, -v.y, -v.x);
    case 2: return vec3f( v.x,  v.z,  v.y);
    case 3: return vec3f( v.x, -v.z, -v.y);
    case 4: return vec3f( v.x, -v.y,  v.z);
    case 5: return vec3f(-v.x, -v.y, -v.z);
    default: break;
    }
    return vec3f(0,0,0); // Unreachable
}
// Projects the two boundary rays of a cone of half-angle theta around the 2D direction (x, z)
// onto the plane z = 1 and returns their coordinates in _min / _max.
// A boundary ray whose z component is not safely positive cannot be projected; its extent is
// reported as +/-FLT_MAX (sign taken from its x component) instead of dividing by it.
__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
{
    const float l = sqrtf(x * x + z * z);
    const float t = tan(theta);

    // Boundary rays: the direction offset perpendicular to itself by tan(theta) * l on each side.
    const float pxr = x + z * t * l, pzr = z - x * t * l;
    const float pxl = x - z * t * l, pzl = z + x * t * l;

    if (pzl <= 0.00001f)
        _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
    else
        _min = pxl / pzl;

    if (pzr <= 0.00001f)
        _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
    else
        _max = pxr / pzr;
}
// Texel-space bounding box, on face `side` of an N x N cube map, of the cone of half-angle
// theta around direction v.
//   - theta < PI/4: the box comes from extents_1d on each face axis, clamped to [0, N-1];
//     if the cone misses the face entirely all four outputs are set to -1.
//   - otherwise: the whole face, [0, N-1] x [0, N-1].
__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
{
    vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1

    if (theta < 0.785398f) // PI/4
    {
        float xmin, xmax, ymin, ymax;
        extents_1d(c.x, c.z, theta, xmin, xmax);
        extents_1d(c.y, c.z, theta, ymin, ymax);

        if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
        {
            _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
        }
        else
        {
            // Map [-1, 1] to texel indices and clamp to the face.
            _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
        }
    }
    else
    {
        // Wide cones are not bounded: cover the entire face.
        _xmin = 0;
        _xmax = N - 1;
        _ymin = 0;
        _ymax = N - 1;
    }
}
// Diffuse kernel
// Diffuse cube-map filtering, forward pass.
// For the output texel addressed by this thread, sums every input texel weighted by the
// clamped cosine between the two texel directions times pixel_area / pi.
__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f N = cube_to_dir(px, py, pz, Npx);

    vec3f col(0);

    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        for (int y = 0; y < Npx; ++y)
        {
            for (int x = 0; x < Npx; ++x)
            {
                vec3f L = cube_to_dir(x, y, s, Npx);
                float costheta = min(max(dot(N, L), 0.0f), 0.999f);
                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
                col += p.cubemap.fetch3(x, y, s) * w;
            }
        }
    }

    p.out.store(px, py, pz, col);
}
// Diffuse cube-map filtering, backward pass.
// Scatters this thread's incoming gradient (fetched from p.out) to every input texel with the
// same weight the forward pass used. Many threads add to the same texel, hence atomicAdd.
__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f N = cube_to_dir(px, py, pz, Npx);
    vec3f grad = p.out.fetch3(px, py, pz);

    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        for (int y = 0; y < Npx; ++y)
        {
            for (int x = 0; x < Npx; ++x)
            {
                vec3f L = cube_to_dir(x, y, s, Npx);
                float costheta = min(max(dot(N, L), 0.0f), 0.999f);
                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
            }
        }
    }
}
// GGX splitsum kernel
// GGX normal distribution term used by the specular cube-map kernels:
// alphaSqr / (pi * d^2) with d = cos^2 * (alphaSqr - 1) + 1, cosTheta clamped to [0, 1].
__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
{
    const float c = clamp(cosTheta, 0.0f, 1.0f);
    const float d = (c * alphaSqr - c) * c + 1.0f;
    return alphaSqr / (d * d * M_PI);
}
// For the direction addressed by this thread, computes per cube face the texel bounding box
// of all texels whose direction lies within the cone dot(L, VNR) >= p.costheta_cutoff.
// Per face s, four values are stored at channels s*4 .. s*4+3: min x, max x, min y, max y.
// A face without any texel inside the cone keeps the initial box (min = size - 1, max = 0).
__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
{
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.gridSize.x;
    vec3f VNR = cube_to_dir(px, py, pz, Npx);

    const int TILE_SIZE = 16;

    // Brute force entire cubemap and compute bounds for the cone
    for (int s = 0; s < p.gridSize.z; ++s)
    {
        // Assume empty BBox
        int _min_x = p.gridSize.x - 1, _max_x = 0;
        int _min_y = p.gridSize.y - 1, _max_y = 0;

        // For each (TILE_SIZE x TILE_SIZE) tile
        for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
        {
            for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
            {
                // Compute tile extents
                int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
                int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);

                // Use some blunt interval arithmetics to cull tiles
                vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
                vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);

                float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
                float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
                float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));

                // Upper bound of dot(L, VNR) over the tile's component-wise direction box.
                float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
                if (maxdp >= p.costheta_cutoff)
                {
                    // Test all pixels in tile.
                    for (int y = tsy; y < tey; ++y)
                    {
                        for (int x = tsx; x < tex; ++x)
                        {
                            vec3f L = cube_to_dir(x, y, s, Npx);
                            if (dot(L, VNR) >= p.costheta_cutoff)
                            {
                                _min_x = min(_min_x, x);
                                _max_x = max(_max_x, x);
                                _min_y = min(_min_y, y);
                                _max_y = max(_max_y, y);
                            }
                        }
                    }
                }
            }
        }

        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
    }
}
// Specular (GGX) cube-map filtering, forward pass.
// For the direction addressed by this thread, walks the per-face texel bounds stored in
// p.bounds and accumulates every texel inside the cone dot(L, VNR) >= p.costheta_cutoff,
// weighted by (L . VNR) * ndfGGX * pixel_area / 4.
// Output channels 0..2 hold the weighted colour sum and channel 3 the weight sum.
__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;

    // Threads that fall outside the grid must exit before touching any tensor.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f VNR = cube_to_dir(px, py, pz, Npx);

    float alpha = p.roughness * p.roughness;
    float alphaSqr = alpha * alpha;

    float wsum = 0.0f;
    vec3f col(0);
    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        int xmin, xmax, ymin, ymax;
        xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
        xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
        ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
        ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));

        // Skip faces whose stored box is inverted in x.
        if (xmin <= xmax)
        {
            for (int y = ymin; y <= ymax; ++y)
            {
                for (int x = xmin; x <= xmax; ++x)
                {
                    vec3f L = cube_to_dir(x, y, s, Npx);
                    const float LDotVNR = dot(L, VNR);
                    if (LDotVNR >= p.costheta_cutoff)
                    {
                        vec3f H = safeNormalize(L + VNR);

                        float wiDotN = max(LDotVNR, 0.0f);
                        float VNRDotH = max(dot(VNR, H), 0.0f);

                        float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
                        col += p.cubemap.fetch3(x, y, s) * w;
                        wsum += w;
                    }
                }
            }
        }
    }

    p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
    p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
    p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
    p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
}
// Backward pass of the specular cubemap filter. Each thread recomputes the
// same per-texel weights as the forward kernel for its output texel and
// scatters grad * w into the cubemap gradient buffer.
__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
// Calculate pixel position.
int px = blockIdx.x * blockDim.x + threadIdx.x;
int py = blockIdx.y * blockDim.y + threadIdx.y;
int pz = blockIdx.z;
// Threads that fall outside the launch grid have nothing to do.
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
int Npx = p.cubemap.dims[1];
vec3f VNR = cube_to_dir(px, py, pz, Npx);
// p.out is read here, so in the backward launch it presumably carries the
// incoming gradient of the color channels -- TODO confirm against the caller.
// Only three channels are fetched; no gradient is taken for the weight sum.
vec3f grad = p.out.fetch3(px, py, pz);
float alpha = p.roughness * p.roughness;
float alphaSqr = alpha * alpha;
// NOTE(review): col is declared but never used in the visible backward code.
vec3f col(0);
for (int s = 0; s < p.cubemap.dims[0]; ++s)
// Bounds are stored as 4 consecutive channels per face: xmin, xmax, ymin, ymax.
int xmin, xmax, ymin, ymax;
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
// xmin > xmax marks a face with no contributing texels.
if (xmin <= xmax)
for (int y = ymin; y <= ymax; ++y)
for (int x = xmin; x <= xmax; ++x)
vec3f L = cube_to_dir(x, y, s, Npx);
if (dot(L, VNR) >= p.costheta_cutoff)
vec3f H = safeNormalize(L + VNR);
float wiDotN = max(dot(L, VNR), 0.0f);
float VNRDotH = max(dot(VNR, H), 0.0f);
// Same weight expression as in the forward kernel.
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
// Different threads can reach the same (s, y, x) texel, hence the atomics.
// The cast assumes the gradient buffer holds float32 -- TODO confirm.
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include "common.h"
// Launch parameters for the diffuse cubemap kernels (the kernels themselves are
// not visible in this part of the file).
struct DiffuseCubemapKernelParams
// Cubemap tensor -- presumably the filter input; TODO confirm.
Tensor cubemap;
// Result tensor -- presumably the incoming gradient in the backward pass; TODO confirm.
Tensor out;
// Logical launch extent -- presumably used for the thread bounds check.
dim3 gridSize;
// Launch parameters shared by SpecularCubemapFwdKernel and SpecularCubemapBwdKernel.
struct SpecularCubemapKernelParams
// Source cubemap; the backward kernel accumulates into its d_val buffer.
Tensor cubemap;
// Per output texel, 4 values per face: xmin, xmax, ymin, ymax of the texel rectangle to visit.
Tensor bounds;
// Forward: written with (r, g, b, weight sum). Backward: read as the incoming gradient.
Tensor out;
// Logical launch extent; threads at or beyond it exit early.
dim3 gridSize;
// Texels whose dot product with the output direction is below this value are skipped.
float costheta_cutoff;
// Roughness; the kernels pass roughness^4 to ndfGGX().
float roughness;
// Launch parameters for the kernel that computes the per-face texel bounds
// consumed by the specular cubemap kernels.
struct SpecularBoundsKernelParams
// Texels whose dot product with the output direction reaches this value are included in the bounds.
float costheta_cutoff;
// Receives 4 values per face and output texel: min_x, max_x, min_y, max_y.
Tensor out;
// Logical launch extent; also the range over which the bounds are searched.
dim3 gridSize;

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#include <cuda.h>
#include "common.h"
#include "loss.h"
// Utils
// Derivative of |x| with respect to x: -1 for negative x, +1 for positive x,
// and 0 exactly at x == 0.
__device__ inline float bwdAbs(float x)
{
    if (x == 0.0f)
        return 0.0f;
    return x < 0.0f ? -1.0f : 1.0f;
}
// Sums val across the 32 lanes of a warp using XOR shuffles with offsets
// 1, 2, 4, 8, 16; every lane ends up holding the total.
// The full mask 0xFFFFFFFF requires all 32 lanes to execute the call.
__device__ float warpSum(float val) {
for (int i = 1; i < 32; i *= 2)
val += __shfl_xor_sync(0xFFFFFFFF, val, i);
return val;
// Tonemapping
// Linear -> sRGB transfer curve: 1.055 * x^(1/2.4) - 0.055 above the 0.0031308
// threshold, 12.92 * x below it. Negative inputs are clamped to 0 on the linear segment.
__device__ inline float fwdSRGB(float x)
return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
// Gradient of fwdSRGB: accumulates d_out * d(fwdSRGB)/dx into d_x.
// Nothing is added for x <= 0, where the forward output is the constant 0.
__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
if (x > 0.0031308f)
// d/dx [1.055 * x^(1/2.4)] = (1.055 / 2.4) * x^(1/2.4 - 1) = 0.439583 / x^0.583333
d_x += d_out * 0.439583f / powf(x, 0.583333f);
else if (x > 0.0f)
// Slope of the linear segment.
d_x += d_out * 12.92f;
// Per-channel tonemapper: fwdSRGB(log(x + 1)).
__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
// Gradient of fwdTonemapLogSRGB, per channel. Channels outside the open range
// (0, 65535) are left untouched.
// NOTE(review): the "*= 1 / (x + 1)" step (derivative of log(x + 1)) rescales the
// whole accumulator, so the result is only the chain-rule product when d_x is zero
// on entry; the caller visible in this file passes zero-initialized vectors.
__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
if (x.x > 0.0f && x.x < 65535.0f)
bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
d_x.x *= 1 / (x.x + 1.0f);
if (x.y > 0.0f && x.y < 65535.0f)
bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
d_x.y *= 1 / (x.y + 1.0f);
if (x.z > 0.0f && x.z < 65535.0f)
bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
d_x.z *= 1 / (x.z + 1.0f);
// Relative MSE for one channel: (img - target)^2 / (img^2 + target^2 + eps).
__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
return (img - target) * (img - target) / (img * img + target * target + eps);
// Gradient of fwdRELMSE; accumulates into d_img and d_target.
// With D = img^2 + target^2 + eps:
//   d/d_img    =  2 (img - target) (target (target + img) + eps) / D^2
//   d/d_target = -2 (img - target) (img (target + img) + eps) / D^2
__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
float denom = (target * target + img * img + eps);
d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
// SMAPE for one channel: |img - target| / (img + target + eps).
// NOTE(review): the denominator uses the raw values, not their absolute values;
// the forward kernel in this file clamps its inputs to [0, 65535] before calling this.
__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
return abs(img - target) / (img + target + eps);
// Gradient of fwdSMAPE; accumulates into d_img and d_target.
// With D = img + target + eps and s = sign(img - target):
//   d/d_img    =  s (2 target + eps) / D^2
//   d/d_target = -s (2 img + eps) / D^2
__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
float denom = (target + img + eps);
d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
// Kernels
// Forward image loss. Each thread computes the loss of one pixel (mean over the
// three channels), the values are summed across the warp, and one thread per
// warp tile writes the partial sum to p.out.
__global__ void imgLossFwdKernel(LossKernelParams p)
// Calculate pixel position.
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int pz = blockIdx.z;
// Starts at zero so that threads outside the grid add nothing to the warp sum.
float floss = 0.0f;
// No early return here: every thread must reach warpSum() below.
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
vec3f img = p.img.fetch3(px, py, pz);
vec3f target = p.target.fetch3(px, py, pz);
// Both operands are clamped to [0, 65535] before tonemapping and comparison.
img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
img = fwdTonemapLogSRGB(img);
target = fwdTonemapLogSRGB(target);
vec3f vloss(0);
if (p.loss == LOSS_MSE)
vloss = (img - target) * (img - target);
else if (p.loss == LOSS_RELMSE)
vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
else if (p.loss == LOSS_SMAPE)
vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
// Absolute-difference (L1) term.
vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
// Average over the three channels.
floss = sum(vloss) / 3.0f;
floss = warpSum(floss);
// getWarpSize() is defined elsewhere; presumably the pixel footprint of one warp
// inside the block -- TODO confirm.
dim3 warpSize = getWarpSize(blockDim);
// Only the thread at the origin of each warp footprint writes, at coordinates divided by that footprint.
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
// Backward image loss. Each thread recomputes the forward values of one pixel
// and writes the gradients with respect to p.img and p.target.
//
// p.out is read as the incoming gradient, one value per warp footprint, at the
// same reduced coordinates the forward kernel writes to.
__global__ void imgLossBwdKernel(LossKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;

    // No warp-level communication happens in this kernel, so out-of-range
    // threads can simply exit.
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    dim3 warpSize = getWarpSize(blockDim);

    vec3f _img = p.img.fetch3(px, py, pz);
    vec3f _target = p.target.fetch3(px, py, pz);
    float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);

    /////////////////////////////////////////////////////////////////////
    // FWD
    // Recompute exactly what the forward kernel evaluated: it clamps both
    // operands to [0, 65535] before tonemapping and comparing them. Without
    // the clamp, the gradient of an in-range operand would be computed
    // against an out-of-range partner value the forward pass never saw.
    vec3f img = vec3f(clamp(_img.x, 0.0f, 65535.0f), clamp(_img.y, 0.0f, 65535.0f), clamp(_img.z, 0.0f, 65535.0f));
    vec3f target = vec3f(clamp(_target.x, 0.0f, 65535.0f), clamp(_target.y, 0.0f, 65535.0f), clamp(_target.z, 0.0f, 65535.0f));
    if (p.tonemapper == TONEMAPPER_LOG_SRGB)
    {
        img = fwdTonemapLogSRGB(img);
        target = fwdTonemapLogSRGB(target);
    }

    /////////////////////////////////////////////////////////////////////
    // BWD
    // The forward pass averages the three channels, hence the division by 3.
    vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;

    vec3f d_img(0), d_target(0);
    if (p.loss == LOSS_MSE)
    {
        // Each channel uses its own d_vloss component (the z channel
        // previously used d_vloss.x).
        d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.z * 2 * (img.z - target.z));
        d_target = -d_img;
    }
    else if (p.loss == LOSS_RELMSE)
    {
        bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
        bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
        bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
    }
    else if (p.loss == LOSS_SMAPE)
    {
        bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
        bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
        bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
    }
    else
    {
        // L1 loss.
        d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
        d_target = -d_img;
    }

    if (p.tonemapper == TONEMAPPER_LOG_SRGB)
    {
        // bwdTonemapLogSRGB expects zero-initialized accumulators and the
        // un-tonemapped inputs.
        vec3f d__img(0), d__target(0);
        bwdTonemapLogSRGB(_img, d__img, d_img);
        bwdTonemapLogSRGB(_target, d__target, d_target);
        d_img = d__img; d_target = d__target;
    }

    // Gradient of the clamp: nothing flows to inputs at or outside the bounds.
    if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
    if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
    if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
    if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
    if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
    if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;

    p.img.store_grad(px, py, pz, d_img);
    p.target.store_grad(px, py, pz, d_target);
}

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include "common.h"
// Tonemapper selector stored in LossKernelParams::tonemapper.
// NOTE(review): no enumerators are visible here, yet the loss kernels compare
// against TONEMAPPER_LOG_SRGB -- verify the enumerator list is intact.
enum TonemapperType
// Loss selector stored in LossKernelParams::loss.
// NOTE(review): only LOSS_L1 is visible here, yet the loss kernels also compare
// against LOSS_MSE, LOSS_RELMSE and LOSS_SMAPE -- verify the enumerator list is intact.
enum LossType
LOSS_L1 = 0,
// Launch parameters shared by imgLossFwdKernel and imgLossBwdKernel.
struct LossKernelParams
// Image operand; the backward kernel writes its gradient via store_grad().
Tensor img;
// Reference operand; the backward kernel writes its gradient via store_grad().
Tensor target;
// Forward: per-warp partial loss sums. Backward: read as the incoming gradient.
Tensor out;
// Logical launch extent used for the per-thread bounds check.
dim3 gridSize;
// Tonemapper applied to both operands before the loss.
TonemapperType tonemapper;
// Which per-channel loss to evaluate.
LossType loss;

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#include <cuda.h>
#include <stdio.h>
#include "common.h"
#include "mesh.h"
// Kernels
// Transforms 3D points or vectors by a 4x4 matrix, one element per thread.
//
// The matrix for batch entry pz is staged in shared memory transposed
// (mtx[c][r] = M[r][c]), so the expressions below evaluate M * [pos, 1] when
// p.isPoints is set (4 outputs, translation included) and the upper-left 3x3
// product otherwise (3 outputs, no translation).
__global__ void xfmPointsFwdKernel(XfmKernelParams p)
{
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;

    // The first 16 threads in x of the block load the matrix cooperatively.
    // NOTE(review): this relies on blockDim.x >= 16 -- confirm at the launch site.
    __shared__ float mtx[4][4];
    if (threadIdx.x < 16)
        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
    // Barrier before anyone reads mtx; it must come before the early return
    // below so that every thread of the block reaches it.
    __syncthreads();

    if (px >= p.gridSize.x)
        return;

    vec3f pos(
        p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
    );

    if (p.isPoints)
    {
        // Points: homogeneous transform with w = 1, four output components.
        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
        p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
    }
    else
    {
        // Vectors: rotation/scale part only, three output components.
        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
    }
}
// Backward pass of xfmPointsFwdKernel: writes the gradient with respect to the
// input points, d_pos = M^T * d_out, one element per thread.
//
// p.out is read as the incoming gradient. The matrix is staged in shared
// memory transposed (mtx[c][r] = M[r][c]), exactly as in the forward kernel.
__global__ void xfmPointsBwdKernel(XfmKernelParams p)
{
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;

    // The first 16 threads in x of the block load the matrix cooperatively.
    // NOTE(review): this relies on blockDim.x >= 16 -- confirm at the launch site.
    __shared__ float mtx[4][4];
    if (threadIdx.x < 16)
        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
    // Barrier before anyone reads mtx; it must come before the early return
    // below so that every thread of the block reaches it.
    __syncthreads();

    if (px >= p.gridSize.x)
        return;

    // The input position is not needed: the transform is linear in pos, so
    // its gradient depends only on the matrix and the incoming gradient.
    vec4f d_out(
        p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
    );

    if (p.isPoints)
    {
        // Points: all four output components contribute.
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
    }
    else
    {
        // Vectors: only the three output components exist in the forward pass.
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
    }
}

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include "common.h"
// Launch parameters shared by xfmPointsFwdKernel and xfmPointsBwdKernel.
struct XfmKernelParams
// true: treat inputs as points (translation applied, 4 output components);
// false: treat them as vectors (no translation, 3 output components).
bool isPoints;
// Input positions/vectors; three components are read per element.
Tensor points;
// 4x4 matrix per batch entry; read with broadcasting via nhwcIndex().
Tensor matrix;
// Forward: transformed result. Backward: read as the incoming gradient.
Tensor out;
// gridSize.x bounds the element index handled by each thread.
dim3 gridSize;

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#include "common.h"
#include "normal.h"
// Perturb shading normal by tangent frame
// Builds the shading normal from a tangent-space normal:
//   tangent * n.x + (+/-)bitangent * n.y + normal * max(n.z, 0), then normalized.
// The bitangent is normalize(cross(tangent, normal)); its sign is flipped when
// opengl is true.
__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
return safeNormalize(_shading_nrm);
// Gradient of fwdPerturbNormal. Recomputes the forward intermediates, then
// accumulates into d_perturbed_nrm, d_smooth_nrm and d_smooth_tng given the
// gradient d_out of the normalized shading normal.
__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
// FWD
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
// BWD
// Back through the final normalization.
vec3f d_shading_nrm(0);
bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
vec3f d_smooth_bitng(0);
// Term smooth_nrm * max(perturbed_nrm.z, 0): contributes only when z > 0.
if (perturbed_nrm.z > 0.0f)
d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
// Term (+/-)smooth_bitng * perturbed_nrm.y.
d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
// Term smooth_tng * perturbed_nrm.x.
d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
// Back through the bitangent: normalization, then the cross product.
vec3f d__smooth_bitng(0);
bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
// NOTE(review): bent_nrm_eps is not referenced by the functions visible below,
// which use NORMAL_THRESHOLD (defined elsewhere) instead.
#define bent_nrm_eps 0.001f
// Blends the geometric normal towards the smooth normal with weight
// t = clamp(dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, 0, 1):
// t = 0 returns geom_nrm, t = 1 returns smooth_nrm. The result is not renormalized.
__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
float dp = dot(view_vec, smooth_nrm);
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
return geom_nrm * (1.0f - t) + smooth_nrm * t;
// Gradient of fwdBendNormal, where
//   out = geom_nrm * (1 - t) + smooth_nrm * t,
//   t   = clamp(dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, 0, 1).
// Accumulates into d_view_vec, d_smooth_nrm and d_geom_nrm given d_out.
__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    float dp = dot(view_vec, smooth_nrm);
    float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    if (dp > NORMAL_THRESHOLD)
    {
        // t is saturated at 1: the output is smooth_nrm and nothing else
        // receives a gradient.
        d_smooth_nrm += d_out;
    }
    else
    {
        // out = geom_nrm * (1.0f - t) + smooth_nrm * t;
        d_geom_nrm += d_out * (1.0f - t);
        d_smooth_nrm += d_out * t;

        // d(out)/dt = smooth_nrm - geom_nrm.
        float d_t = sum(d_out * (smooth_nrm - geom_nrm));

        // The clamp passes gradient only while dp / NORMAL_THRESHOLD lies
        // inside [0, 1]; the upper side is handled by the branch above.
        float d_dp = dp < 0.0f ? 0.0f : d_t / NORMAL_THRESHOLD;

        bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
    }
}
// Kernels
// Computes the final shading normal for one pixel per thread: normalizes the
// interpolated normal and tangent, applies the tangent-space perturbation, then
// bends the result towards the geometric normal based on the view direction.
__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
// Calculate pixel position.
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int pz = blockIdx.z;
// Threads that fall outside the launch grid have nothing to do.
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
vec3f pos = p.pos.fetch3(px, py, pz);
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
vec3f smooth_tng = safeNormalize(_smooth_tng);
// Unit vector from the surface position towards the view position.
vec3f view_vec = safeNormalize(view_pos - pos);
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
vec3f res;
// Two-sided shading: when the geometric normal faces away from the viewer,
// bend using the flipped normals.
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
p.out.store(px, py, pz, res);
// Backward pass of PrepareShadingNormalFwdKernel, one pixel per thread.
// Recomputes the forward intermediates and writes gradients for pos, view_pos,
// perturbed_nrm, smooth_nrm, smooth_tng and geom_nrm. p.out is read as the
// incoming gradient of the shading normal.
__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
    vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
    vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
    vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    ///////////////////////////////////////////////////////////////////////////////////////////////////
    // FWD
    // The un-normalized view vector is kept so the gradient can be pushed
    // back through its normalization below.
    vec3f _view_vec = view_pos - pos;
    vec3f view_vec = safeNormalize(_view_vec);

    vec3f smooth_nrm = safeNormalize(_smooth_nrm);
    vec3f smooth_tng = safeNormalize(_smooth_tng);

    vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);

    ///////////////////////////////////////////////////////////////////////////////////////////////////
    // BWD
    vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
    if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
    {
        // The forward pass bent the negated normals, so the gradients with
        // respect to the un-negated inputs pick up a sign flip.
        bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
        d_shading_nrm = -d_shading_nrm;
        d_geom_nrm = -d_geom_nrm;
    }
    else
    {
        bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
    }

    vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
    bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);

    // Back through the three input normalizations.
    vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
    bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
    bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
    bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);

    // _view_vec = view_pos - pos, hence the opposite signs.
    p.pos.store_grad(px, py, pz, -d__view_vec);
    p.view_pos.store_grad(px, py, pz, d__view_vec);
    p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
    p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
    p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
    p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
}

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#include "common.h"
// Launch parameters shared by the PrepareShadingNormal forward and backward kernels.
struct PrepareShadingNormalKernelParams
// Surface position per pixel.
Tensor pos;
// View (camera) position; the view vector is normalize(view_pos - pos).
Tensor view_pos;
// Tangent-space normal applied by fwdPerturbNormal().
Tensor perturbed_nrm;
// Interpolated normal; normalized inside the kernels.
Tensor smooth_nrm;
// Interpolated tangent; normalized inside the kernels.
Tensor smooth_tng;
// Geometric normal used for the two-sided test and for bending.
Tensor geom_nrm;
// Forward: resulting shading normal. Backward: read as the incoming gradient.
Tensor out;
// Logical launch extent used for the per-thread bounds check.
dim3 gridSize;
// two_sided_shading: flip the normals when the geometric normal faces away from the viewer.
// opengl: negates the bitangent term in the perturbation.
bool two_sided_shading, opengl;

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
#if defined(__CUDACC__) && defined(BFLOAT16)
#include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits
// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
struct Tensor
// Value buffer; accessed as float, or as __nv_bfloat16 when fp16 is set and BFLOAT16 is defined.
void* val;
// Gradient buffer written by store_grad(); same element type rules as val.
void* d_val;
// dims: extents consulted for broadcasting in nhwcIndex().
// _dims: extents used for dense indexing in nhwcIndexContinuous().
int dims[4], _dims[4];
// Per-axis element strides used by _nhwcIndex() and nhwcIndex().
int strides[4];
// Despite the name, selects bfloat16 (not IEEE half) access in the BFLOAT16 build.
bool fp16;
#if defined(__CUDA__) && !defined(__CUDA_ARCH__)
// Default constructor, guarded by the preprocessor condition above.
Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
#ifdef __CUDACC__
// Helpers to index and read/write a single element
// _nhwcIndex: plain strided offset, no broadcasting.
__device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
// nhwcIndex: strided offset where any axis of extent 1 ignores its index (broadcast).
__device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
// nhwcIndexContinuous: dense row-major offset computed from _dims, ignoring strides.
__device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }
#ifdef BFLOAT16
// BFLOAT16 build: element type chosen at run time from fp16.
__device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
__device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
__device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
// float32-only variants of the same three accessors.
__device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
__device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
__device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
// Fetch, use broadcasting for tensor dimensions of size 1
// Note the argument order: (x, y, z) maps to (w, h, n) in the index helpers.
__device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
return fetch(nhwcIndex(z, y, x, 0));
// Reads channels 0..2 at (x, y, z).
__device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
return vec3f(
fetch(nhwcIndex(z, y, x, 0)),
fetch(nhwcIndex(z, y, x, 1)),
fetch(nhwcIndex(z, y, x, 2))
// Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
store(_nhwcIndex(z, y, x, 0), _val);
// Writes channels 0..2 at (x, y, z).
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
store(_nhwcIndex(z, y, x, 0), _val.x);
store(_nhwcIndex(z, y, x, 1), _val.y);
store(_nhwcIndex(z, y, x, 2), _val.z);
// Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
// Gradient stores use the dense _dims-based index rather than the strides.
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
// Three-component float vector. Constructors and compound operators are only
// declared when compiling with a CUDA compiler (__CUDACC__).
struct vec3f
float x, y, z;
#ifdef __CUDACC__
// Leaves the components uninitialized.
__device__ vec3f() { }
// Broadcasts one scalar to all components; also enables implicit float -> vec3f
// conversion in mixed expressions such as vec3f * float.
__device__ vec3f(float v) { x = v; y = v; z = v; }
__device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
__device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
// Component-wise compound assignment.
__device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
__device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
__device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
__device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
#ifdef __CUDACC__
// Component-wise binary arithmetic on vec3f (note: operator* is a per-component
// product, not a dot or cross product), followed by unary negation.
__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
// Sum of the three components.
__device__ static inline float sum(vec3f a)
return a.x + a.y + a.z;
// Cross product a x b.
__device__ static inline vec3f cross(vec3f a, vec3f b)
vec3f out;
out.x = a.y * b.z - a.z * b.y;
out.y = a.z * b.x - a.x * b.z;
out.z = a.x * b.y - a.y * b.x;
return out;
// Gradient of cross(a, b): accumulates into d_a and d_b given the gradient
// d_out of the result. With out = a x b this is d_a += b x d_out and
// d_b += d_out x a.
__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
{
    d_a.x += d_out.z * b.y - d_out.y * b.z;
    d_a.y += d_out.x * b.z - d_out.z * b.x;
    d_a.z += d_out.y * b.x - d_out.x * b.y;

    // b.x appears in out.y (+a.z * b.x) and out.z (-a.y * b.x); this term
    // was missing, leaving d_b.x without any gradient.
    d_b.x += d_out.y * a.z - d_out.z * a.y;
    d_b.y += d_out.z * a.x - d_out.x * a.z;
    d_b.z += d_out.x * a.y - d_out.y * a.x;
}
// Dot product of a and b.
__device__ static inline float dot(vec3f a, vec3f b)
return a.x * b.x + a.y * b.y + a.z * b.z;
// Gradient of dot(a, b): d_a += d_out * b and d_b += d_out * a.
__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
// Reflects x about n: 2 * n * dot(n, x) - x. n is not normalized here, so the
// formula is a true reflection only for unit-length n.
__device__ static inline vec3f reflect(vec3f x, vec3f n)
return n * 2.0f * dot(n, x) - x;
// Gradient of reflect(x, n); accumulates into d_x and d_n.
// With out_i = 2 n_i (n . x) - x_i:
//   d out_i / d x_j = 2 n_i n_j - delta_ij
//   d out_i / d n_j = 2 delta_ij (n . x) + 2 n_i x_j
__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
// Returns v / |v|, or the zero vector when |v| is not greater than zero.
__device__ static inline vec3f safeNormalize(vec3f v)
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
return l > 0.0f ? (v / l) : vec3f(0.0f);
// Gradient of safeNormalize; accumulates into d_v. Nothing is added when
// |v| is not greater than zero, matching the constant zero output of the forward function.
// Uses d n_i / d v_j = (delta_ij |v|^2 - v_i v_j) / |v|^3.
__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
if (l > 0.0f)
// fac = 1 / |v|^3. NOTE(review): the literal 1.0 is a double, so this division is done in double precision.
float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;

* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
#pragma once
// Four-component float vector. Constructors are only declared when compiling
// with a CUDA compiler (__CUDACC__).
struct vec4f
float x, y, z, w;
#ifdef __CUDACC__
// Leaves the components uninitialized.
__device__ vec4f() { }
// Broadcasts one scalar to all four components.
__device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
__device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
__device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
# HDR image losses
def _tonemap_srgb(f):
    """Apply the linear-to-sRGB transfer curve element-wise.

    Values above 0.0031308 follow ``1.055 * f**(1/2.4) - 0.055``; all other
    values follow the linear segment ``12.92 * f``.
    """
    threshold = 0.0031308
    # Clamp before the power so the unused branch never sees tiny or negative
    # inputs (torch.where evaluates both branches).
    gamma_segment = 1.055 * torch.pow(torch.clamp(f, min=threshold), 1.0 / 2.4) - 0.055
    linear_segment = 12.92 * f
    return torch.where(f > threshold, gamma_segment, linear_segment)
def _SMAPE(img, target, eps=0.01):
    """Return the mean symmetric absolute percentage error.

    Computes ``mean(|img - target| / (|img| + |target| + eps))``.

    Args:
        img: Predicted tensor.
        target: Reference tensor, broadcastable against ``img``.
        eps: Stabilizer added to the denominator. Previously this argument
            was ignored and 0.01 was always used; the default keeps that
            behavior.
    """
    nom = torch.abs(img - target)
    denom = torch.abs(img) + torch.abs(target) + eps
    return torch.mean(nom / denom)
def _RELMSE(img, target, eps=0.1):
    """Return the mean relative squared error.

    Computes ``mean((img - target)**2 / (img**2 + target**2 + eps))``.

    Args:
        img: Predicted tensor.
        target: Reference tensor, broadcastable against ``img``.
        eps: Stabilizer added to the denominator. Previously this argument
            was ignored and 0.1 was always used; the default keeps that
            behavior.
    """
    diff = img - target
    nom = diff * diff
    denom = img * img + target * target + eps
    return torch.mean(nom / denom)
def image_loss_fn(img, target, loss, tonemapper):
    """Compute an image loss, optionally after log-sRGB tonemapping.

    Args:
        img: Predicted image tensor.
        target: Reference image tensor.
        loss: ``'mse'``, ``'smape'`` or ``'relmse'``; any other value selects
            the L1 loss.
        tonemapper: ``'log_srgb'`` clamps both images to [0, 65535] and maps
            them through ``_tonemap_srgb(log(x + 1))``; any other value
            leaves them untouched.

    Returns:
        Scalar loss tensor.
    """
    if tonemapper == 'log_srgb':
        def _log_srgb(x):
            return _tonemap_srgb(torch.log(torch.clamp(x, min=0, max=65535) + 1))

        img = _log_srgb(img)
        target = _log_srgb(target)

    if loss == 'mse':
        return torch.nn.functional.mse_loss(img, target)
    if loss == 'smape':
        return _SMAPE(img, target)
    if loss == 'relmse':
        return _RELMSE(img, target)
    # Every other loss name falls back to L1.
    return torch.nn.functional.l1_loss(img, target)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import os
import sys
import torch
import torch.utils.cpp_extension
from .bsdf import *
from .loss import *
# C++/Cuda plugin compiler/loader.
_cached_plugin = None
# Builds the 'renderutils_plugin' extension through torch.utils.cpp_extension.load,
# imports it, and caches the module in _cached_plugin so later calls return
# immediately.
def _get_plugin():
# Return cached plugin if already loaded.
global _cached_plugin
if _cached_plugin is not None:
return _cached_plugin
# Make sure we can find the necessary compiler and library binaries.
if os.name == 'nt':
# Looks for an MSVC x64 toolchain directory, trying the editions in the listed
# order and preferring the last path in sorted order; returns None implicitly
# when nothing matches.
def find_cl_path():
import glob
for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
# Appended to PATH for this process (and anything it spawns).
os.environ['PATH'] += ';' + cl_path
# Compiler options.
opts = ['-DNVDR_TORCH']
# Linker options.
# NOTE(review): ldflags is only assigned for 'posix' and 'nt'; any other
# os.name would leave it undefined when it is used below.
if os.name == 'posix':
ldflags = ['-lcuda', '-lnvrtc']
elif os.name == 'nt':
ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
# List of sources.
source_files = [
# Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
# Overwrites the process-wide environment variable.
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
# Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
# NOTE(review): _get_build_directory is a private torch helper (leading underscore).
lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock')
if os.path.exists(lock_fn):
print("Warning: Lock file exists in build directory: '%s'" % lock_fn)
# Compile and load.
# Source names are resolved relative to the directory containing this module.
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts,
extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True,
# build_directory="PLACEHOLDER",
# Import, cache, and return the compiled module.
import renderutils_plugin
_cached_plugin = renderutils_plugin
return _cached_plugin
# Internal kernels, just used for testing functionality
class _fresnel_shlick_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's fresnel_shlick_fwd / fresnel_shlick_bwd kernels."""

    @staticmethod
    def forward(ctx, f0, f90, cosTheta):
        out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False)
        ctx.save_for_backward(f0, f90, cosTheta)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        f0, f90, cosTheta = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)
def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
    '''Evaluate the Fresnel-Schlick term (internal, used for testing).

    f0, f90, cosTheta: Input tensors forwarded unchanged to the implementation.
    use_python: Use the PyTorch implementation (bsdf_fresnel_shlick) instead of the CUDA kernel.

    Returns the tensor produced by the selected implementation.
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_fresnel_shlick(f0, f90, cosTheta)
    else:
        out = _fresnel_shlick_func.apply(f0, f90, cosTheta)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN"
    return out
class _ndf_ggx_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's ndf_ggx_fwd / ndf_ggx_bwd kernels."""

    @staticmethod
    def forward(ctx, alphaSqr, cosTheta):
        out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False)
        ctx.save_for_backward(alphaSqr, cosTheta)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        alphaSqr, cosTheta = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
    '''Evaluate the GGX normal distribution function (internal, used for testing).

    alphaSqr, cosTheta: Input tensors forwarded unchanged to the implementation.
    use_python: Use the PyTorch implementation (bsdf_ndf_ggx) instead of the CUDA kernel.

    Returns the tensor produced by the selected implementation.
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_ndf_ggx(alphaSqr, cosTheta)
    else:
        out = _ndf_ggx_func.apply(alphaSqr, cosTheta)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN"
    return out
class _lambda_ggx_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's lambda_ggx_fwd / lambda_ggx_bwd kernels."""

    @staticmethod
    def forward(ctx, alphaSqr, cosTheta):
        out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False)
        ctx.save_for_backward(alphaSqr, cosTheta)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        alphaSqr, cosTheta = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
    '''Evaluate the GGX lambda term (internal, used for testing).

    alphaSqr, cosTheta: Input tensors forwarded unchanged to the implementation.
    use_python: Use the PyTorch implementation (bsdf_lambda_ggx) instead of the CUDA kernel.

    Returns the tensor produced by the selected implementation.
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_lambda_ggx(alphaSqr, cosTheta)
    else:
        out = _lambda_ggx_func.apply(alphaSqr, cosTheta)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN"
    return out
class _masking_smith_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's masking_smith_fwd / masking_smith_bwd kernels."""

    @staticmethod
    def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
        ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)
        out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        alphaSqr, cosThetaI, cosThetaO = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)
def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
    '''Evaluate the Smith masking term (internal, used for testing).

    alphaSqr, cosThetaI, cosThetaO: Input tensors forwarded unchanged to the implementation.
    use_python: Use the PyTorch implementation (bsdf_masking_smith_ggx_correlated)
                instead of the CUDA kernel.

    Returns the tensor produced by the selected implementation.
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
    else:
        out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
    return out
# Shading normal setup (bump mapping + bent normals)
class _prepare_shading_normal_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's prepare_shading_normal_fwd / _bwd kernels."""

    @staticmethod
    def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
        # The two flags are not tensors, so they are kept on ctx rather than saved.
        ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
        out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
        ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_tensors
        # The kernel's gradient tuple is padded with trailing None entries.
        return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
    '''Takes care of all corner cases and produces a final normal used for shading:
        - Constructs tangent space
        - Flips normal direction based on geometric normal for two sided Shading
        - Perturbs shading normal by normal map
        - Bends backfacing normals towards the camera to avoid shading artifacts

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.

    Args:
        pos: World space g-buffer position.
        view_pos: Camera position in world space (typically using broadcasting).
        perturbed_nrm: Tangent-space normal perturbation from normal map lookup.
                       None selects the unperturbed normal (0, 0, 1).
        smooth_nrm: Interpolated vertex normals.
        smooth_tng: Interpolated vertex tangents.
        geom_nrm: Geometric (face) normals.
        two_sided_shading: Use one/two sided shading
        opengl: Use OpenGL/DirectX normal map conventions
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Final shading normal
    '''
    if perturbed_nrm is None:
        # Default perturbation is created on the CUDA device with shape [1, 1, 1, 3].
        perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]

    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
    else:
        out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
    return out
# BSDF functions
class _lambert_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's lambert_fwd / lambert_bwd kernels."""

    @staticmethod
    def forward(ctx, nrm, wi):
        out = _get_plugin().lambert_fwd(nrm, wi, False)
        ctx.save_for_backward(nrm, wi)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        nrm, wi = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)
def lambert(nrm, wi, use_python=False):
    '''Lambertian bsdf.

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.

    Args:
        nrm: World space shading normal.
        wi: World space light vector.
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Shaded diffuse value with shape [minibatch_size, height, width, 1]
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_lambert(nrm, wi)
    else:
        out = _lambert_func.apply(nrm, wi)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
    return out
class _frostbite_diffuse_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's frostbite_fwd / frostbite_bwd kernels."""

    @staticmethod
    def forward(ctx, nrm, wi, wo, linearRoughness):
        out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)
        ctx.save_for_backward(nrm, wi, wo, linearRoughness)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        nrm, wi, wo, linearRoughness = ctx.saved_tensors
        # The kernel's gradient tuple is padded with one trailing None.
        return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)
def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
    '''Frostbite, normalized Disney Diffuse bsdf.

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.

    Args:
        nrm: World space shading normal.
        wi: World space light vector.
        wo: World space camera vector.
        linearRoughness: Material roughness
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Shaded diffuse value with shape [minibatch_size, height, width, 1]
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_frostbite(nrm, wi, wo, linearRoughness)
    else:
        out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)

    if torch.is_anomaly_enabled():
        # Message now names this function (it previously said "lambert").
        assert torch.all(torch.isfinite(out)), "Output of frostbite_diffuse contains inf or NaN"
    return out
class _pbr_specular_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's pbr_specular_fwd / pbr_specular_bwd kernels."""

    @staticmethod
    def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
        ctx.save_for_backward(col, nrm, wo, wi, alpha)
        # min_roughness is a plain scalar, so it is kept on ctx rather than saved.
        ctx.min_roughness = min_roughness
        out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        col, nrm, wo, wi, alpha = ctx.saved_tensors
        # The kernel's gradient tuple is padded with trailing None entries.
        return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
    '''Physically-based specular bsdf.

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.

    Args:
        col: Specular lobe color
        nrm: World space shading normal.
        wo: World space camera vector.
        wi: World space light vector
        alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]
        min_roughness: Scalar roughness clamping threshold
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Shaded specular color
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness)
    else:
        out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN"
    return out
class _pbr_bsdf_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's pbr_bsdf_fwd / pbr_bsdf_bwd kernels."""

    @staticmethod
    def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
        ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)
        # Non-tensor arguments are kept on ctx. BSDF must be stored here because
        # backward reads ctx.BSDF; it was previously never set.
        ctx.min_roughness = min_roughness
        ctx.BSDF = BSDF
        out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_tensors
        # The kernel's gradient tuple is padded with trailing None entries.
        return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None)
def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False):
    '''Physically-based bsdf, both diffuse & specular lobes

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.

    Args:
        kd: Diffuse albedo.
        arm: Specular parameters (attenuation, linear roughness, metalness).
        pos: World space position.
        nrm: World space shading normal.
        view_pos: Camera position in world space, typically using broadcasting.
        light_pos: Light position in world space, typically using broadcasting.
        min_roughness: Scalar roughness clamping threshold
        bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite'
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Shaded color.
    '''
    # Integer code handed to the implementations: 1 for 'frostbite', 0 for anything else.
    BSDF = 0
    if bsdf == 'frostbite':
        BSDF = 1

    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
    else:
        out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN"
    return out
# cubemap filter with filtering across edges
class _diffuse_cubemap_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's diffuse_cubemap_fwd / diffuse_cubemap_bwd kernels."""

    @staticmethod
    def forward(ctx, cubemap):
        out = _get_plugin().diffuse_cubemap_fwd(cubemap)
        # backward unpacks the cubemap from ctx, so it has to be saved here;
        # previously nothing was saved and the unpack in backward failed.
        ctx.save_for_backward(cubemap)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        cubemap, = ctx.saved_tensors
        cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout)
        return cubemap_grad, None
def diffuse_cubemap(cubemap, use_python=False):
    '''Filter a cubemap with the plugin's diffuse cubemap kernel.

    cubemap: Cubemap tensor passed to the kernel.
    use_python: No PyTorch implementation exists; a true value trips an assertion.

    Returns the filtered result from the kernel.
    '''
    if use_python:
        assert False

    filtered = _diffuse_cubemap_func.apply(cubemap)
    if torch.is_anomaly_enabled():
        all_finite = torch.all(torch.isfinite(filtered))
        assert all_finite, "Output of diffuse_cubemap contains inf or NaN"
    return filtered
class _specular_cubemap(torch.autograd.Function):
    """Autograd wrapper around the plugin's specular_cubemap_fwd / specular_cubemap_bwd kernels."""

    @staticmethod
    def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
        # Note the kernel takes bounds before roughness and the cutoff.
        out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff)
        ctx.save_for_backward(cubemap, bounds)
        ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        cubemap, bounds = ctx.saved_tensors
        cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff)
        # Only the cubemap receives a gradient.
        return cubemap_grad, None, None, None
# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy
def __ndfBounds(res, roughness, cutoff):
    '''Find the cosine of the cutoff angle for a GGX lobe and the matching bounds table.

    res:       Resolution forwarded to the plugin's specular_bounds.
    roughness: Roughness; the lobe is evaluated with alpha^2 = roughness**4.
    cutoff:    Fraction of the cumulative NDF sum that must be reached.

    Returns (cos_cutoff, bounds), where bounds is whatever specular_bounds returns.
    '''
    def ggx_density(alpha_sqr, cos_t):
        cos_t = np.clip(cos_t, 0.0, 1.0)
        denom = (cos_t * alpha_sqr - cos_t) * cos_t + 1.0
        return alpha_sqr / (denom * denom * np.pi)

    # Sample the lobe over angles 0..pi/2 and accumulate its values.
    num_samples = 1000000
    angles = np.linspace(0, np.pi/2.0, num_samples)
    cos_samples = np.cos(angles)
    cumulative = np.cumsum(ggx_density(roughness**4, cos_samples))

    # First sample at which the running sum reaches the requested fraction of the total.
    cutoff_idx = np.argmax(cumulative >= cumulative[..., -1] * cutoff)
    cos_cutoff = cos_samples[cutoff_idx]

    # Lookup table with bounds is computed by the plugin.
    bounds = _get_plugin().specular_bounds(res, cos_cutoff)
    return cos_cutoff, bounds
# Cache of __ndfBounds results, keyed by (cubemap resolution, roughness, cutoff).
__ndfBoundsDict = {}

def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
    '''Filter a cubemap with the plugin's specular cubemap kernel.

    cubemap:   Tensor whose first dimension is 6 and whose next two dimensions are equal.
    roughness: Roughness forwarded to the kernel and used for the lobe bounds.
    cutoff:    Fraction handed to __ndfBounds when computing the lobe bounds.
    use_python: No PyTorch implementation exists; a true value trips an assertion.

    Returns channels 0..2 of the kernel output divided by its remaining channel(s).
    '''
    is_cube = cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2]
    assert is_cube, "Bad shape for cubemap tensor: %s" % str(cubemap.shape)

    if use_python:
        assert False

    # Bounds depend only on (resolution, roughness, cutoff), so compute them once per key.
    cache_key = (cubemap.shape[1], roughness, cutoff)
    if cache_key not in __ndfBoundsDict:
        __ndfBoundsDict[cache_key] = __ndfBounds(*cache_key)
    costheta_cutoff, bounds = __ndfBoundsDict[cache_key]

    out = _specular_cubemap.apply(cubemap, roughness, costheta_cutoff, bounds)
    if torch.is_anomaly_enabled():
        all_finite = torch.all(torch.isfinite(out))
        assert all_finite, "Output of specular_cubemap contains inf or NaN"

    color = out[..., 0:3]
    weight = out[..., 3:]
    return color / weight
# Fast image loss function
class _image_loss_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's image_loss_fwd / image_loss_bwd kernels."""

    @staticmethod
    def forward(ctx, img, target, loss, tonemapper):
        # loss and tonemapper are not tensors, so they are kept on ctx rather than saved.
        ctx.loss, ctx.tonemapper = loss, tonemapper
        ctx.save_for_backward(img, target)
        out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False)
        return out

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        img, target = ctx.saved_tensors
        # The kernel's gradient tuple is padded with trailing None entries.
        return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)
def image_loss(img, target, loss='l1', tonemapper='none', use_python=False):
    '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.

    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.

    Args:
        img: Input image.
        target: Target (reference) image.
        loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']
        tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb']
        use_python: Use PyTorch implementation (for validation)
    Returns:
        Image space loss (scalar value).
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        # image_loss_fn's return value is used as-is.
        out = image_loss_fn(img, target, loss, tonemapper)
    else:
        out = _image_loss_func.apply(img, target, loss, tonemapper)
        # Sum the kernel output and normalise by the first three dimensions of img.
        # NOTE(review): assumes the kernel returns un-normalised values - confirm against the CUDA source.
        out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN"
    return out
# Transform points function
class _xfm_func(torch.autograd.Function):
    """Autograd wrapper around the plugin's xfm_fwd / xfm_bwd kernels."""

    @staticmethod
    def forward(ctx, points, matrix, isPoints):
        ctx.save_for_backward(points, matrix)
        # isPoints is a plain flag, so it is kept on ctx rather than saved.
        ctx.isPoints = isPoints
        return _get_plugin().xfm_fwd(points, matrix, isPoints, False)

    @staticmethod
    def backward(ctx, dout):
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        points, matrix = ctx.saved_tensors
        # Only the first input (points) receives a gradient.
        return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)
def xfm_points(points, matrix, use_python=False):
    '''Transform points.

    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Use PyTorch's torch.matmul (for validation)
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        # Append w = 1 to every point, then multiply by the transposed matrix.
        out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
    else:
        out = _xfm_func.apply(points, matrix, True)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out
def xfm_vectors(vectors, matrix, use_python=False):
    '''Transform vectors.

    Args:
        vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Use PyTorch's torch.matmul (for validation)
    Returns:
        Transformed vectors. The PyTorch path returns the first three components,
        i.e. shape [minibatch_size, num_vertices, 3].
        NOTE(review): the CUDA path's output shape is set by the kernel - confirm it matches.
    '''
    # The two paths are mutually exclusive; previously the CUDA call always
    # overwrote the PyTorch result, so use_python had no effect.
    if use_python:
        # Append w = 0 so the translation column has no effect, then drop w again.
        out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()
    else:
        out = _xfm_func.apply(vectors, matrix, False)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN"
    return out

from setuptools import setup
import torch
import os,glob
from torch.utils.cpp_extension import (CUDAExtension, CppExtension, BuildExtension)
def get_extensions():
# Assemble the list of torch extension modules for this package's setup script.
# NOTE(review): parts of this function (the source list and the CUDAExtension
# arguments) are not visible here; the comments below cover only what is shown.
extensions = []
ext_name = 'nvdiffrec_renderutils'
# prevent ninja from using too many resources
# setdefault keeps any MAX_JOBS value already present in the environment.
os.environ.setdefault('MAX_JOBS', '16')
define_macros = []
# Compiler options.
opts = ['-DNVDR_TORCH']
# Linker options.
# NOTE(review): ldflags is only assigned for 'posix' and 'nt'.
if os.name == 'posix':
ldflags = ['-lcuda', '-lnvrtc']
elif os.name == 'nt':
ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
# List of sources.
source_files = [
# Overrides any existing TORCH_CUDA_ARCH_LIST with a fixed set of architectures.
os.environ['TORCH_CUDA_ARCH_LIST'] = "5.0 6.0 6.1 7.0 7.5 8.0 8.6"
if torch.cuda.is_available():
print(f'Compiling {ext_name} with CUDA')
define_macros += [('WITH_CUDA', None)]
# op_files = glob.glob('./c_src/*')
# extension = CUDAExtension
# NOTE(review): presumably the non-CUDA branch; there is no CPU-only build path - confirm.
raise NotImplementedError
include_path = os.path.abspath('./c_src')
ext_ops = CUDAExtension(
# NOTE(review): linker flags are concatenated onto the compile arguments here - confirm this is intended.
extra_compile_args=opts + ldflags,
libraries=['cuda', 'nvrtc'],
return extensions
cmdclass={'build_ext': BuildExtension},

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru
RES = 4
DTYPE = torch.float32
def relative_loss(name, ref, cuda):
    """Print ``name`` followed by the largest element-wise relative error of ``cuda`` against ``ref``."""
    reference = ref.float()
    candidate = cuda.float()
    abs_err = torch.abs(reference - candidate)
    # The 1e-7 offset is added to the reference before taking its magnitude.
    scale = torch.abs(reference + 1e-7)
    worst = torch.max(abs_err / scale)
    print(name, worst.item())
def test_normal():
    """Compare the CUDA and PyTorch paths of ru.prepare_shading_normal on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)
    perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)
    smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)
    smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)
    geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" bent normal")
    relative_loss("res:", ref, cuda)
    relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
    relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad)
    relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)
    relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)
    relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad)
    relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad)
def test_schlick():
    """Compare the CUDA and PyTorch paths of ru._fresnel_shlick on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    f0_ref = f0_cuda.clone().detach().requires_grad_(True)
    f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    f90_ref = f90_cuda.clone().detach().requires_grad_(True)
    # Scaling yields a non-leaf tensor, so it is re-created as a leaf that can hold .grad.
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Fresnel shlick")
    relative_loss("res:", ref, cuda)
    relative_loss("f0:", f0_ref.grad, f0_cuda.grad)
    relative_loss("f90:", f90_ref.grad, f90_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
def test_ndf_ggx():
    """Compare the CUDA and PyTorch paths of ru._ndf_ggx on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    # Arithmetic yields a non-leaf tensor, so it is re-created as a leaf that can hold .grad.
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Ndf GGX")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
def test_lambda_ggx():
    """Compare the CUDA and PyTorch paths of ru._lambda_ggx on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    # Arithmetic yields a non-leaf tensor, so it is re-created as a leaf that can hold .grad.
    cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
    cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
    cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Lambda GGX")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
def test_masking_smith():
    """Compare the CUDA and PyTorch paths of ru._masking_smith on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
    cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)
    cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Smith masking term")
    relative_loss("res:", ref, cuda)
    relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
    relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad)
    relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad)
def test_lambert():
    """Compare the CUDA and PyTorch paths of ru.lambert on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    normals_ref = normals_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru.lambert(normals_ref, wi_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.lambert(normals_cuda, wi_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Lambert")
    relative_loss("res:", ref, cuda)
    relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
    relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
def test_frostbite():
    """Compare the CUDA and PyTorch paths of ru.frostbite_diffuse on random inputs.

    Prints the maximum relative error of the result and of every input gradient.
    """
    normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    normals_ref = normals_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wo_ref = wo_cuda.clone().detach().requires_grad_(True)
    rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    rough_ref = rough_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Frostbite")
    relative_loss("res:", ref, cuda)
    relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
    relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
    relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
    relative_loss("rough:", rough_ref.grad, rough_cuda.grad)
def test_pbr_specular():
    """Compare the CUDA and PyTorch paths of ru.pbr_specular on random inputs.

    Prints the maximum relative error of the result and of each input gradient
    for which the reference path produced a gradient.
    """
    col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    col_ref = col_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wo_ref = wo_cuda.clone().detach().requires_grad_(True)
    alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Pbr specular")
    relative_loss("res:", ref, cuda)
    if col_ref.grad is not None:
        relative_loss("col:", col_ref.grad, col_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if wi_ref.grad is not None:
        relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
    if wo_ref.grad is not None:
        relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
    if alpha_ref.grad is not None:
        relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)
def test_pbr_bsdf(bsdf):
    """Compare the CUDA and PyTorch paths of ru.pbr_bsdf on random inputs.

    bsdf: Diffuse model name forwarded to ru.pbr_bsdf.

    Prints the maximum relative error of the result and of each input gradient
    for which the reference path produced a gradient.
    """
    kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Pbr BSDF")
    relative_loss("res:", ref, cuda)
    if kd_ref.grad is not None:
        relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
    if arm_ref.grad is not None:
        relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
    if pos_ref.grad is not None:
        relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if view_ref.grad is not None:
        relative_loss("view:", view_ref.grad, view_cuda.grad)
    if light_ref.grad is not None:
        relative_loss("light:", light_ref.grad, light_cuda.grad)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru
RES = 4
DTYPE = torch.float32
def relative_loss(name, ref, cuda):
    """Print ``name`` and the maximum relative deviation of ``cuda`` from ``ref``."""
    expected = ref.float()
    actual = cuda.float()
    # Element-wise |expected - actual| / |expected + 1e-7|, reduced to its maximum.
    deviation = torch.abs(expected - actual) / torch.abs(expected + 1e-7)
    print(name, torch.max(deviation).item())
def test_cubemap():
    """Compare the CUDA and PyTorch paths of ru.filter_cubemap on a random cubemap.

    Prints the maximum relative error of the filtered result and of the cubemap gradient.
    """
    cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True)
    weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda')
    target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda')

    ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    # Backpropagate so the .grad fields compared below are populated.
    ref_loss.backward()

    cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print(" Cubemap:")
    relative_loss("flt:", ref, cuda)
    relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru
RES = 8
DTYPE = torch.float32
def tonemap_srgb(f):
    """Apply the piecewise sRGB transfer curve to ``f`` element-wise."""
    knee = 0.0031308
    # Power segment above the knee, linear segment at or below it.
    above = torch.pow(torch.clamp(f, min=knee), 1.0/2.4)*1.055 - 0.055
    below = 12.92*f
    return torch.where(f > knee, above, below)
def l1(output, target):
    """L1 loss between two images after log + sRGB tonemapping.

    Each image is clamped to [0, 65535], mapped through log(x + 1) and
    tonemap_srgb, and the two results are compared with torch's l1_loss.
    """
    def tonemapped(image):
        clamped = torch.clamp(image, min=0, max=65535)
        return tonemap_srgb(torch.log(clamped + 1))

    return torch.nn.functional.l1_loss(tonemapped(output), tonemapped(target))
def relative_loss(name, ref, cuda):
    """Print ``name`` followed by the maximum relative error of ``cuda`` vs ``ref``.

    Inputs are cast to float32; the denominator is ``|ref + 1e-7|``.
    """
    ref32, cuda32 = ref.float(), cuda.float()
    numerator = torch.abs(ref32 - cuda32)
    denominator = torch.abs(ref32 + 1e-7)
    worst = torch.max(numerator / denominator)
    print(name, worst.item())
def test_loss(loss, tonemapper):
    """Compare the two ``ru.image_loss`` code paths for one loss/tonemapper pair.

    The ``use_python=True`` result is treated as the reference. The loss
    value and the gradients with respect to both the image and the target
    are compared via ``relative_loss``. Requires a CUDA device.

    Args:
        loss: loss name forwarded to ``ru.image_loss``.
        tonemapper: tonemapper name forwarded to ``ru.image_loss``.
    """
    img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    img_ref = img_cuda.clone().detach().requires_grad_(True)
    target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    target_ref = target_cuda.clone().detach().requires_grad_(True)

    ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True)
    # Without backward() the .grad attributes compared below stay None.
    ref_loss.backward()

    cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper)
    cuda_loss.backward()

    print(" Loss: %s, %s" % (loss, tonemapper))
    relative_loss("res:", ref_loss, cuda_loss)
    relative_loss("img:", img_ref.grad, img_cuda.grad)
    relative_loss("target:", target_ref.grad, target_cuda.grad)
# Script entry: run the comparison for each (loss, tonemapper) combination below.
test_loss('l1', 'none')
test_loss('l1', 'log_srgb')
test_loss('mse', 'log_srgb')
test_loss('smape', 'none')
test_loss('relmse', 'none')
test_loss('mse', 'none')

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru
# Number of random points/vectors per batch entry built by the transform tests.
RES = 1024
# Floating-point dtype used for every tensor the tests create.
DTYPE = torch.float32
def tonemap_srgb(f):
    """Map linear values to sRGB with the standard two-segment curve.

    Elements greater than 0.0031308 take the power-law segment; the rest are
    multiplied by 12.92.
    """
    cutoff = 0.0031308
    above_cutoff = f > cutoff
    # The clamp keeps pow away from values below the cutoff.
    curved = torch.pow(torch.clamp(f, min=cutoff), 1.0/2.4)*1.055 - 0.055
    return torch.where(above_cutoff, curved, 12.92*f)
def l1(output, target):
    """Mean absolute error between two images in log-sRGB space.

    Both inputs are clamped to [0, 65535], passed through ``log(x + 1)`` and
    ``tonemap_srgb``, and then compared with ``l1_loss``.
    """
    mapped = [
        tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
        for img in (output, target)
    ]
    return torch.nn.functional.l1_loss(mapped[0], mapped[1])
def relative_loss(name, ref, cuda):
    """Print the largest element-wise relative error between two tensors.

    Both tensors are converted to float32 first. The denominator is
    ``|ref + 1e-7|``, matching the other ``relative_loss`` helpers in these
    test scripts, so exact zeros in ``ref`` no longer yield NaN (0/0) or inf
    and mask the real maximum.

    Args:
        name: label printed in front of the error value.
        ref: reference tensor.
        cuda: tensor to compare against the reference.
    """
    ref = ref.float()
    cuda = cuda.float()
    print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
def test_xfm_points():
    """Compare the two ``ru.xfm_points`` code paths on random points and matrices.

    The ``use_python=True`` result is treated as the reference. The
    transformed output and the gradient with respect to the points are
    compared via ``relative_loss``. Relies on the module-level ``BATCH``,
    ``RES`` and ``DTYPE`` constants and requires a CUDA device.
    """
    points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    points_ref = points_cuda.clone().detach().requires_grad_(True)
    mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
    mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)

    ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref_out, target)
    # Without backward() the .grad attributes compared below stay None.
    ref_loss.backward()

    cuda_out = ru.xfm_points(points_cuda, mtx_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda_out, target)
    cuda_loss.backward()

    relative_loss("res:", ref_out, cuda_out)
    relative_loss("points:", points_ref.grad, points_cuda.grad)
def test_xfm_vectors():
    """Compare the two code paths of ``ru.xfm_vectors`` and ``ru.xfm_points``.

    The ``use_python=True`` results are treated as the reference. Separate
    leaf copies of the same random input are used for the vector and point
    transforms so that their gradients accumulate independently. Relies on
    the module-level ``BATCH``, ``RES`` and ``DTYPE`` constants and requires
    a CUDA device.
    """
    points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    points_ref = points_cuda.clone().detach().requires_grad_(True)
    points_cuda_p = points_cuda.clone().detach().requires_grad_(True)
    points_ref_p = points_cuda.clone().detach().requires_grad_(True)
    mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
    mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)

    # Vector transform: output has three components, so compare against xyz only.
    ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3])
    # Without backward() the .grad attributes compared below stay None.
    ref_loss.backward()

    cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3])
    cuda_loss.backward()

    # Point transform on the independent "_p" copies.
    ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True)
    ref_loss_p = torch.nn.MSELoss()(ref_out_p, target)
    ref_loss_p.backward()

    cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda)
    cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target)
    cuda_loss_p.backward()

    relative_loss("res:", ref_out, cuda_out)
    relative_loss("points:", points_ref.grad, points_cuda.grad)
    relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru
def test_bsdf(BATCH, RES, ITR):
    """Time ``ru.pbr_bsdf`` for the Python path and the default path.

    Random inputs of shape ``[BATCH, RES, RES, 3]`` are created on the GPU.
    Each path is run ``ITR`` times between a pair of CUDA events and the
    elapsed time (milliseconds, as reported by ``Event.elapsed_time``) is
    printed. Relies on the module-level ``DTYPE`` and requires a CUDA device.

    Args:
        BATCH: batch size of the generated inputs.
        RES: spatial resolution of the generated inputs.
        ITR: number of timed iterations per path.
    """
    kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, RES, 3, device='cuda')

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warm-up call so one-time setup cost is not part of the measurement.
    ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)

    print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES))

    # elapsed_time() is only valid once both events have been recorded and
    # completed, hence record() around each loop and a synchronize() after it.
    start.record()
    for i in range(ITR):
        ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF python:", start.elapsed_time(end))

    start.record()
    for i in range(ITR):
        cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF cuda:", start.elapsed_time(end))
# Script entry: benchmark three (BATCH, RES) configurations, 1000 iterations each.
test_bsdf(1, 512, 1000)
test_bsdf(16, 512, 1000)
test_bsdf(1, 2048, 1000)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
import nvdiffrast.torch as dr
from . import util
# Smooth pooling / mip computation with linear gradient upscaling
class texture2d_mip(torch.autograd.Function):
    """2x2 average pooling with a bilinearly upsampled gradient.

    Forward averages 2x2 texel blocks of an NHWC texture. Backward does not
    use the exact pooling gradient: it resamples the incoming gradient
    (scaled by 0.25) at twice the resolution with linear filtering.
    """

    # autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, texture):
        return util.avg_pool_nhwc(texture, (2,2))

    @staticmethod
    def backward(ctx, dout):
        # Texel-centre coordinates of the 2x finer grid, in [0, 1]. Created on
        # the gradient's own device so non-default CUDA devices work too.
        gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device=dout.device),
                                torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device=dout.device),
                                indexing='ij')
        uv = torch.stack((gx, gy), dim=-1)
        return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')
# Simple texture class. A texture can be either
# - A 3D tensor (using auto mipmaps)
# - A list of 3D tensors (full custom mip hierarchy)
class Texture2D(torch.nn.Module):
    """2D texture stored as an NHWC parameter or a list of per-mip parameters.

    ``self.data`` is either a single ``torch.nn.Parameter`` (mip levels are
    then generated on the fly when sampling) or a Python list of parameters
    forming a custom mip hierarchy. ``self.min_max`` optionally holds
    per-channel ``(min, max)`` bounds used by ``clamp_``.
    """

    # Initializes a texture from image data.
    # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
    def __init__(self, init, min_max=None, trainable=True):
        super(Texture2D, self).__init__()

        if isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')
        elif isinstance(init, list) and len(init) == 1:
            # A one-level mip chain is the same as a plain texture.
            init = init[0]

        if isinstance(init, list):
            self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=trainable) for mip in init)
        elif len(init.shape) == 4:
            self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=trainable)
        elif len(init.shape) == 3:
            self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=trainable)
        elif len(init.shape) == 1:
            self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=trainable) # Convert constant to 1x1 tensor
        else:
            # Raised explicitly (not via `assert`) so it also fires under -O.
            raise AssertionError("Invalid texture object")

        self.min_max = min_max

    # Filtered (trilinear) sample texture at a given location
    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
        if isinstance(self.data, list):
            # Custom mip hierarchy: level 0 is the base, the rest are passed as mips.
            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
        elif self.data.shape[1] > 1 and self.data.shape[2] > 1:
            # Build the mip chain on the fly with the smooth pooling op.
            mips = [self.data]
            while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
                mips += [texture2d_mip.apply(mips[-1])]
            out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)
        else:
            # Base level is one texel wide or tall: sample without a mip chain.
            out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
        return out

    def getRes(self):
        """Return the (height, width) of the base mip level."""
        return self.getMips()[0].shape[1:3]

    def getChannels(self):
        """Return the channel count of the base mip level."""
        return self.getMips()[0].shape[3]

    def getMips(self):
        """Return the stored mip levels as a list (one entry without custom mips)."""
        if isinstance(self.data, list):
            return self.data
        return [self.data]

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_(self):
        if self.min_max is not None:
            # In-place edits of a leaf that requires grad are only legal with
            # autograd disabled.
            with torch.no_grad():
                for mip in self.getMips():
                    for i in range(mip.shape[-1]):
                        mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])

    # In-place normalize with no derivative: rescales every texel vector via util.safe_normalize
    def normalize_(self):
        with torch.no_grad():
            for mip in self.getMips():
                # copy_ writes into the parameter; rebinding `mip` would only
                # change the loop variable and leave the texture untouched.
                mip.copy_(util.safe_normalize(mip))
# Helper function to create a trainable texture from a regular texture. The trainable weights are
# initialized with texture data as an initial guess
def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
    """Build a trainable ``Texture2D`` seeded with existing texture data.

    Args:
        init: a ``Texture2D`` holding a single tensor, a numpy array, or a
            tensor (1D constant, 3D HWC or 4D NHWC).
        res: optional target resolution passed to ``util.scale_img_nhwc``.
        auto_mipmaps: when False, an explicit mip chain is generated by
            repeatedly halving the resolution down to 1x1.
        min_max: optional clamp bounds; defaults to those of ``init`` when
            ``init`` is a ``Texture2D``.
    """
    with torch.no_grad():
        if isinstance(init, Texture2D):
            assert isinstance(init.data, torch.Tensor)
            if min_max is None:
                min_max = init.min_max
            init = init.data
        elif isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')

        # Pad to NHWC if needed
        rank = len(init.shape)
        if rank == 1:
            # Extend constant to NHWC tensor
            init = init[None, None, None, :]
        elif rank == 3:
            init = init[None, ...]

        # Scale input to desired resolution.
        if res is not None:
            init = util.scale_img_nhwc(init, res)

        if auto_mipmaps:
            return Texture2D(init, min_max=min_max)

        # Generate custom mipchain
        mip_chain = [init.clone().detach().requires_grad_(True)]
        while True:
            height, width = mip_chain[-1].shape[1], mip_chain[-1].shape[2]
            if not (height > 1 or width > 1):
                break
            new_size = [max(height // 2, 1), max(width // 2, 1)]
            mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)]
        return Texture2D(mip_chain, min_max=min_max)
# Convert texture to and from SRGB
def srgb_to_rgb(texture):
    """Return a new ``Texture2D`` with every mip level converted by ``util.srgb_to_rgb``."""
    converted = [util.srgb_to_rgb(level) for level in texture.getMips()]
    return Texture2D(converted)
def rgb_to_srgb(texture):
    """Return a new ``Texture2D`` with every mip level converted by ``util.rgb_to_srgb``."""
    converted = [util.rgb_to_srgb(level) for level in texture.getMips()]
    return Texture2D(converted)
# Utility functions for loading / storing a texture
def _load_mip2D(fn, lambda_fn=None, channels=None):
    """Load one image file as a float32 CUDA tensor.

    Args:
        fn: image path handed to ``util.load_image``.
        lambda_fn: optional callable applied to the tensor after channel selection.
        channels: optional number of leading channels to keep.
    """
    pixels = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')
    if channels is not None:
        pixels = pixels[..., 0:channels]
    if lambda_fn is not None:
        pixels = lambda_fn(pixels)
    # Return an independent copy that carries no autograd history.
    return pixels.detach().clone()
def load_texture2D(fn, lambda_fn=None, channels=None):
    """Load a single image file into a ``Texture2D``.

    Args:
        fn: image path.
        lambda_fn: optional callable applied to the loaded tensor.
        channels: optional number of leading channels to keep.

    Only the file ``fn`` itself is read; per-mip files are not looked up.
    """
    return Texture2D(_load_mip2D(fn, lambda_fn, channels))
def _save_mip2D(fn, mip, mipidx, lambda_fn):
    """Write one mip level to disk through ``util.save_image``.

    Args:
        fn: destination path of the base image.
        mip: tensor holding the mip level.
        mipidx: mip index appended to the file stem as ``_<idx>``, or None to
            write to ``fn`` unchanged.
        lambda_fn: optional callable applied to ``mip`` before saving.
    """
    if lambda_fn is not None:
        data = lambda_fn(mip).detach().cpu().numpy()
    else:
        data = mip.detach().cpu().numpy()

    if mipidx is None:
        util.save_image(fn, data)
    else:
        base, ext = os.path.splitext(fn)
        util.save_image(base + ("_%d" % mipidx) + ext, data)
def save_texture2D(fn, tex, lambda_fn=None):
    """Save a texture to disk.

    A texture with a custom mip list writes one file per level
    (``<stem>_<idx><ext>``); otherwise the single tensor is written to ``fn``.
    The leading batch entry of each tensor is the one saved.
    """
    if isinstance(tex.data, list):
        for i, mip in enumerate(tex.data):
            _save_mip2D(fn, mip[0,...], i, lambda_fn)
    else:
        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)

# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
import torch
import nvdiffrast.torch as dr
import imageio
# Vector operations
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Dot product over the last dimension, which is kept with size 1."""
    products = x * y
    return products.sum(-1, keepdim=True)
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Return ``2 * dot(x, n) * n - x``, i.e. ``x`` mirrored about ``n``."""
    projection = dot(x, n)
    return 2*projection*n - x
def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Euclidean length over the last dimension (kept with size 1).

    The squared length is clamped to at least ``eps`` before the square
    root, so the result is never below ``sqrt(eps)``.
    """
    # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
    return torch.sqrt(torch.clamp(dot(x,x), min=eps))
def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Divide ``x`` by its clamped length (see ``length``) along the last dimension."""
    magnitude = length(x, eps)
    return x / magnitude
def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    """Append the constant ``w`` as one extra trailing component of ``x``."""
    trailing_pad = (0, 1)  # nothing before, one element after the last dim
    return torch.nn.functional.pad(x, pad=trailing_pad, mode='constant', value=w)
# sRGB color transforms
def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Element-wise linear-to-sRGB transfer curve (no shape checks)."""
    cutoff = 0.0031308
    linear_part = f * 12.92
    # Clamp so pow is never evaluated below the cutoff.
    power_part = torch.pow(torch.clamp(f, cutoff), 1.0/2.4)*1.055 - 0.055
    return torch.where(f <= cutoff, linear_part, power_part)
def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Convert an RGB or RGBA tensor from linear to sRGB.

    The last dimension must hold 3 or 4 channels; a fourth (alpha) channel
    is passed through unchanged.
    """
    channels = f.shape[-1]
    assert channels == 3 or channels == 4
    if channels == 4:
        out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _rgb_to_srgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out
def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Element-wise sRGB-to-linear transfer curve (no shape checks)."""
    cutoff = 0.04045
    linear_part = f / 12.92
    # Clamp so pow is never evaluated below the cutoff.
    power_part = torch.pow((torch.clamp(f, cutoff) + 0.055) / 1.055, 2.4)
    return torch.where(f <= cutoff, linear_part, power_part)
def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Convert an RGB or RGBA tensor from sRGB to linear.

    The last dimension must hold 3 or 4 channels; a fourth (alpha) channel
    is passed through unchanged.
    """
    channels = f.shape[-1]
    assert channels == 3 or channels == 4
    if channels == 4:
        out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _srgb_to_rgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out
def reinhard(f: torch.Tensor) -> torch.Tensor:
    """Element-wise ``f / (1 + f)`` tone-mapping curve."""
    denominator = 1 + f
    return f / denominator
# Metrics (taken from jaxNerf source code, in order to replicate their measurements)
# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266
def mse_to_psnr(mse):
    """Convert an MSE to PSNR, taking the peak pixel value to be 1."""
    # -10 / ln(10) turns a natural log into -10 * log10.
    scale = -10. / np.log(10.)
    return scale * np.log(mse)
def psnr_to_mse(psnr):
    """Convert a PSNR to MSE, taking the peak pixel value to be 1."""
    exponent = -0.1 * np.log(10.) * psnr
    return np.exp(exponent)
# Displacement texture lookup
def get_miplevels(texture: np.ndarray) -> float:
    """Return ``floor(log2(min(height, width)))`` for the first two dimensions."""
    height, width = texture.shape[0], texture.shape[1]
    smallest = min(height, width)
    return np.floor(np.log2(smallest))
def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:
    """Sample an HWC texture at coordinates given in [0, 1].

    Args:
        tex_map: texture with channels last.
        coords: sample positions whose last dimension is (x, y) in [0, 1].
        filter: interpolation mode forwarded to ``grid_sample``.
    """
    # grid_sample expects NCHW input and coordinates in [-1, 1].
    batched_nchw = tex_map[None, ...].permute(0, 3, 1, 2)
    grid = coords[None, None, ...] * 2 - 1
    sampled = torch.nn.functional.grid_sample(batched_nchw, grid, mode=filter, align_corners=False)
    # Back to channels-last, then drop the two leading singleton dims.
    return sampled.permute(0, 2, 3, 1)[0, 0, ...]
# Cubemap utility functions
def cube_to_dir(s, x, y):
    """Map face index ``s`` (0..5) and face coordinates to a 3D direction.

    Returns the unnormalized direction stacked along a new last dimension.
    """
    one = torch.ones_like(x)
    if s == 0:
        components = (one, -y, -x)
    elif s == 1:
        components = (-one, -y, x)
    elif s == 2:
        components = (x, one, y)
    elif s == 3:
        components = (x, -one, -y)
    elif s == 4:
        components = (x, -y, one)
    elif s == 5:
        components = (-x, -y, -one)
    return torch.stack(components, dim=-1)
def latlong_to_cubemap(latlong_map, res):
    """Resample a latitude-longitude map into a six-face cubemap.

    Args:
        latlong_map: HWC lat-long image on the GPU.
        res: ``(height, width)`` of each cubemap face.

    Returns:
        Float32 CUDA tensor of shape ``[6, res[0], res[1], channels]``.
    """
    cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')
    # Texel-centre coordinates in [-1, 1]; identical for all faces, so built once.
    gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
                            torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
                            indexing='ij')
    for s in range(6):
        v = safe_normalize(cube_to_dir(s, gx, gy))

        # Direction -> lat-long texture coordinates in [0, 1].
        tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5
        tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi
        texcoord = torch.cat((tu, tv), dim=-1)

        cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]
    return cubemap
def cubemap_to_latlong(cubemap, res):
    """Resample a six-face cubemap into a latitude-longitude map.

    Args:
        cubemap: cubemap tensor on the GPU.
        res: ``(height, width)`` of the output lat-long image.
    """
    # gy spans the polar angle in (0, 1), gx the azimuth in (-1, 1), both in units of pi.
    gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
                            torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
                            indexing='ij')

    sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi)
    sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi)

    # Inverse of the mapping used by latlong_to_cubemap, where
    # tu = atan2(x, -z) / (2*pi) + 0.5 and tv = acos(y) / pi.
    reflvec = torch.stack((
        sintheta*sinphi,
        costheta,
        -sintheta*cosphi
        ), dim=-1)
    return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0]
# Image scaling
def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    """Resize a single HWC image by delegating to ``scale_img_nhwc``."""
    batched = x[None, ...]
    resized = scale_img_nhwc(batched, size, mag, min)
    return resized[0]
def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    """Resize an NHWC image batch to ``size`` (height, width).

    Args:
        x: input tensor, channels last.
        size: target ``(height, width)``.
        mag: interpolation mode used when the image is not shrunk in both dimensions.
        min: interpolation mode used when both dimensions shrink.

    Raises:
        AssertionError: if one dimension would grow while the other shrinks.
    """
    assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] <= size[0] and x.shape[2] <= size[1]), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
        y = torch.nn.functional.interpolate(y, size, mode=min)
    else: # Magnification
        if mag == 'bilinear' or mag == 'bicubic':
            y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
        else:
            # Other modes (e.g. 'nearest') do not accept align_corners.
            y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor:
    """Average-pool an NHWC tensor with kernel ``size`` and return NHWC."""
    as_nchw = x.permute(0, 3, 1, 2)
    pooled = torch.nn.functional.avg_pool2d(as_nchw, size)
    # Back to channels-last in contiguous memory.
    return pooled.permute(0, 2, 3, 1).contiguous()
# Behaves similar to tf.segment_sum
def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Sum rows of ``data`` that share a segment id (similar to ``tf.segment_sum``).

    Args:
        data: values to reduce along dimension 0.
        segment_ids: int64 ids, either one per row of ``data`` or already
            the same shape as ``data``. The output has one row per run of
            equal consecutive ids, so ids are expected to be grouped and to
            index rows ``0 .. num_segments - 1``.

    Returns:
        Tensor of shape ``[num_segments, *data.shape[1:]]`` on the same
        device and with the same dtype as ``data``.
    """
    num_segments = torch.unique_consecutive(segment_ids).shape[0]

    # Repeats ids until same dimension as data
    if len(segment_ids.shape) == 1:
        repeats = int(np.prod(tuple(data.shape[1:])))
        segment_ids = segment_ids.repeat_interleave(repeats).view(segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    shape = [num_segments] + list(data.shape[1:])
    # Allocate next to the input instead of assuming float32 on 'cuda'.
    result = torch.zeros(*shape, dtype=data.dtype, device=data.device)
    result = result.scatter_add(0, segment_ids, data)
    return result
# Matrix helpers.
def fovx_to_fovy(fovx, aspect):
    """Convert a horizontal field of view (radians) to vertical for the given aspect ratio."""
    half_tan_y = np.tan(fovx / 2) / aspect
    return np.arctan(half_tan_y) * 2.0
def focal_length_to_fovy(focal_length, sensor_height):
    """Return the vertical field of view (radians) for a focal length and sensor height."""
    half_angle = np.arctan(0.5 * sensor_height / focal_length)
    return 2 * half_angle
# Reworked so this matches gluPerspective / glm::perspective, using fovy
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 float32 perspective projection matrix from a vertical FOV.

    Args:
        fovy: vertical field of view in radians.
        aspect: aspect ratio applied to the x scale.
        n: near plane distance.
        f: far plane distance.
        device: device the matrix is created on.

    The y scale is negated (``1/-y``).
    """
    y = np.tan(fovy / 2)
    depth = f - n
    row_x = [1/(y*aspect), 0, 0, 0]
    row_y = [0, 1/-y, 0, 0]
    row_z = [0, 0, -(f+n)/depth, -(2*f*n)/depth]
    row_w = [0, 0, -1, 0]
    return torch.tensor([row_x, row_y, row_z, row_w], dtype=torch.float32, device=device)
# Reworked so this matches gluPerspective / glm::perspective, using fovy
def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 float32 projection matrix for a sub-window of the full frustum.

    Args:
        fovy: vertical field of view of the full frustum, in radians.
        fraction: size of the sub-frustum as a fraction of the full extent.
        rx, ry: offset of the sub-frustum's left/bottom edge, as a fraction
            of the full width/height.
        aspect: aspect ratio of the full frustum.
        n: near plane distance.
        f: far plane distance.
        device: device the matrix is created on.
    """
    half_height = np.tan(fovy / 2)

    # Full frustum
    full_right, full_left = aspect*half_height, -aspect*half_height
    full_top, full_bottom = half_height, -half_height
    full_width = full_right - full_left
    full_height = full_top - full_bottom

    # Sub-frustum bounds
    l = full_left + full_width*rx
    r = l + full_width*fraction
    b = full_bottom + full_height*ry
    t = b + full_height*fraction

    # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix
    rows = [
        [2/(r-l), 0, (r+l)/(r-l), 0],
        [0, -2/(t-b), (t+b)/(t-b), 0],
        [0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
        [0, 0, -1, 0],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def translate(x, y, z, device=None):
    """Return a 4x4 float32 translation matrix with offset (x, y, z) in the last column."""
    rows = [
        [1, 0, 0, x],
        [0, 1, 0, y],
        [0, 0, 1, z],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def rotate_x(a, device=None):
    """Return a 4x4 float32 rotation matrix about the x axis for angle ``a`` (radians)."""
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = [
        [1, 0, 0, 0],
        [0, cos_a, sin_a, 0],
        [0, -sin_a, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def rotate_y(a, device=None):
    """Return a 4x4 float32 rotation matrix about the y axis for angle ``a`` (radians)."""
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = [
        [cos_a, 0, sin_a, 0],
        [0, 1, 0, 0],
        [-sin_a, 0, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def scale(s, device=None):
    """Return a 4x4 float32 uniform scaling matrix with factor ``s``."""
    rows = [
        [s, 0, 0, 0],
        [0, s, 0, 0],
        [0, 0, s, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
def lookAt(eye, at, up):
    """Build a 4x4 view matrix from an eye position, a target and an up vector.

    Args:
        eye: 3-component camera position tensor.
        at: 3-component target position tensor.
        up: 3-component up direction tensor.

    Returns:
        ``rotation @ translation`` with the dtype and device of ``eye``.
    """
    a = eye - at
    w = a / torch.linalg.norm(a)
    # Pass dim explicitly: torch.cross without it is deprecated and picks the
    # first dimension of size 3 instead of the last.
    u = torch.cross(up, w, dim=-1)
    u = u / torch.linalg.norm(u)
    v = torch.cross(w, u, dim=-1)
    # Locals are named so they do not shadow the module-level translate().
    translation = torch.tensor([[1, 0, 0, -eye[0]],
                                [0, 1, 0, -eye[1]],
                                [0, 0, 1, -eye[2]],
                                [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
    rotation = torch.tensor([[u[0], u[1], u[2], 0],
                             [v[0], v[1], v[2], 0],
                             [w[0], w[1], w[2], 0],
                             [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
    return rotation @ translation
def random_rotation_translation(t, device=None):
    """Return a random 4x4 float32 rigid transform.

    The upper-left 3x3 has mutually orthogonal unit rows built from normal
    samples; the translation is drawn uniformly from ``[-t, t]`` per axis.
    Uses NumPy's global random state.
    """
    samples = np.random.normal(size=[3, 3])
    # Orthogonalize with two cross products, keeping the first row's direction.
    row0 = samples[0]
    row1 = np.cross(row0, samples[2])
    row2 = np.cross(row0, row1)
    basis = np.stack((row0, row1, row2))
    basis = basis / np.linalg.norm(basis, axis=1, keepdims=True)

    xform = np.zeros((4, 4))
    xform[:3, :3] = basis
    xform[3, 3] = 1.0
    xform[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(xform, dtype=torch.float32, device=device)
def random_rotation(device=None):
    """Return a random 4x4 float32 transform with zero translation.

    The upper-left 3x3 has mutually orthogonal unit rows built from normal
    samples. Uses NumPy's global random state.
    """
    samples = np.random.normal(size=[3, 3])
    # Orthogonalize with two cross products, keeping the first row's direction.
    row0 = samples[0]
    row1 = np.cross(row0, samples[2])
    row2 = np.cross(row0, row1)
    basis = np.stack((row0, row1, row2))
    basis = basis / np.linalg.norm(basis, axis=1, keepdims=True)

    xform = np.zeros((4, 4))
    xform[:3, :3] = basis
    xform[3, 3] = 1.0
    xform[:3, 3] = np.array([0,0,0]).astype(np.float32)
    return torch.tensor(xform, dtype=torch.float32, device=device)
def batch_random_rotation(batch_size, device=None):
    """Return ``batch_size`` random 4x4 float32 transforms with zero translation.

    Each upper-left 3x3 has mutually orthogonal unit rows built from normal
    samples. Uses NumPy's global random state.

    Returns:
        Tensor of shape ``[batch_size, 4, 4]``.
    """
    m = np.random.normal(size=[batch_size, 3, 3])
    m[:, 1] = np.cross(m[:, 0], m[:, 2])
    m[:, 2] = np.cross(m[:, 0], m[:, 1])
    m = m / np.linalg.norm(m, axis=-1, keepdims=True)
    m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
    m[:, 3, 3] = 1.0
    # NumPy arrays have no .unsqueeze(); a broadcast scalar zeroes the translation.
    m[:, :3, 3] = 0.0
    return torch.tensor(m, dtype=torch.float32, device=device)
# Compute focal points of a set of lines using least squares.
# handy for poorly centered datasets
def lines_focal(o, d):
    """Least-squares focal point of a set of lines.

    Args:
        o: line origins, one 3-vector per row.
        d: line directions, one 3-vector per row (normalized internally).

    Returns:
        ``pinv(S) @ C`` where ``S`` sums ``d d^T - I`` over all lines and
        ``C`` sums ``(d d^T - I) o``.
    """
    d = safe_normalize(d)
    identity = torch.eye(3, dtype=o.dtype, device=o.device)
    d_col = d[..., None]
    # Per-line d d^T - I.
    per_line = d_col @ torch.transpose(d_col, 1, 2) - identity[None, ...]
    S = torch.sum(per_line, dim=0)
    C = torch.sum(per_line @ o[..., None], dim=0).squeeze(1)
    return torch.linalg.pinv(S) @ C
# Cosine sample around a vector N
def cosine_sample(N, size=None):
    """Draw cosine-weighted random directions around the vector ``N``.

    Args:
        N: 3-component tensor giving the lobe axis (normalized internally).
        size: None for a single sample (drawn from NumPy's global random
            state), or a shape tuple for a batch drawn with ``torch.rand``.

    Returns:
        Tensor of shape ``[3]`` when ``size`` is None, else ``[*size, 3]``,
        with the dtype and device of ``N``.
    """
    # construct local frame
    N = N/torch.linalg.norm(N)

    dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device)
    dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device)

    # Keep the longer of the two candidate tangents.
    dx = torch.where((dx0 * dx0).sum() > (dx1 * dx1).sum(), dx0, dx1)
    dx = dx / torch.linalg.norm(dx)
    dy = torch.cross(N, dx, dim=-1)
    dy = dy / torch.linalg.norm(dy)

    # cosine sampling in local frame
    if size is None:
        phi = torch.tensor(2.0 * np.pi * np.random.uniform(), dtype=N.dtype, device=N.device)
        s = torch.tensor(np.random.uniform(), dtype=N.dtype, device=N.device)
    else:
        phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device)
        s = torch.rand(*size, 1, dtype=N.dtype, device=N.device)
    # torch math (not numpy) so the batched path also works for CUDA tensors.
    costheta = torch.sqrt(s)
    sintheta = torch.sqrt(1.0 - s)

    # cartesian vector in local space
    x = torch.cos(phi)*sintheta
    y = torch.sin(phi)*sintheta
    z = costheta

    # local to world
    return dx*x + dy*y + N*z
# Bilinear downsample by 2x.
def bilinear_downsample(x : torch.Tensor) -> torch.Tensor:
    """Downsample an NHWC tensor by 2x with a 4x4 separable [1, 3, 3, 1] kernel.

    Borders are zero-padded by one texel. The result is a permuted view and
    is not made contiguous.

    NOTE: a later definition in this module reuses the name
    ``bilinear_downsample`` with an extra ``spp`` parameter and therefore
    replaces this one at import time.
    """
    channels = x.shape[-1]
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    # One copy of the kernel per channel (depthwise convolution).
    w = w.expand(channels, 1, 4, 4)
    x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=channels)
    return x.permute(0, 2, 3, 1)
# Bilinear downsample log(spp) steps
def bilinear_downsample(x : torch.Tensor, spp) -> torch.Tensor:
    """Downsample an NHWC tensor by 2x, ``int(log2(spp))`` times.

    Each step replicates the border by one texel and applies a 4x4
    separable [1, 3, 3, 1] kernel with stride 2.

    Args:
        x: input tensor, channels last.
        spp: sample count; determines the number of halving steps.

    Returns:
        Contiguous NHWC tensor.
    """
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    g = x.shape[-1]
    # One copy of the kernel per channel (depthwise convolution).
    w = w.expand(g, 1, 4, 4)
    x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    steps = int(np.log2(spp))
    for _ in range(steps):
        xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
        x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g)
    return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
# Singleton initialize GLFW
_glfw_initialized = False
def init_glfw():
    """Initialize GLFW if it has not been initialized yet.

    GLFW is probed by trying to create a small hidden window. If that fails
    with ``NOT_INITIALIZED``, ``glfw.init()`` is called and the module-level
    ``_glfw_initialized`` flag is set. Other GLFW errors raised by the probe
    are ignored.
    """
    global _glfw_initialized
    # Imported outside the try block so glfw.GLFWError in the except clause
    # is always resolvable.
    import glfw
    try:
        glfw.ERROR_REPORTING = 'raise'
        glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
        test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet
    except glfw.GLFWError as e:
        if e.error_code == glfw.NOT_INITIALIZED:
            glfw.init()
            _glfw_initialized = True
# Image display function using OpenGL.
_glfw_window = None
def display_image(image, title=None):
    """Show an image in a (reused) GLFW window using ``glDrawPixels``.

    Args:
        image: HWC array-like with 1 to 4 channels and dtype uint8 or
            float32; the alpha channel of 4-channel input is dropped.
        title: window title; defaults to ``'Debug window'``.

    Returns:
        False if the window has been asked to close, True otherwise.
    """
    # Import OpenGL
    import OpenGL.GL as gl
    import glfw

    # Drop the alpha channel of RGBA input; other inputs are used as-is.
    image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image)
    height, width, channels = image.shape

    # Initialize window.
    init_glfw()
    if title is None:
        title = 'Debug window'
    global _glfw_window
    if _glfw_window is None:
        _glfw_window = glfw.create_window(width, height, title, None, None)
        # GL calls below need a current context.
        glfw.make_context_current(_glfw_window)
        # init_glfw() leaves the VISIBLE hint off, so show the window explicitly.
        glfw.show_window(_glfw_window)
    else:
        glfw.make_context_current(_glfw_window)
        glfw.set_window_title(_glfw_window, title)
        glfw.set_window_size(_glfw_window, width, height)

    # Update window.
    glfw.poll_events()
    gl.glClearColor(0, 0, 0, 1)
    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
    gl.glWindowPos2f(0, 0)
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels]
    gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name]
    # Rows are flipped because glDrawPixels draws bottom-up.
    gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1])
    # Present the drawn pixels.
    glfw.swap_buffers(_glfw_window)
    if glfw.window_should_close(_glfw_window):
        return False
    return True
# Image save/load helper.
def save_image(fn, x : np.ndarray):
    """Save an image with values in [0, 1] as 8-bit, best effort.

    Values are scaled by 255, rounded and clipped to [0, 255]. PNG files are
    written with a low compression level. A failure prints a warning instead
    of raising.
    """
    try:
        pixels = np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)
        if os.path.splitext(fn)[1] == ".png":
            imageio.imwrite(fn, pixels, compress_level=3) # Low compression for faster saving
        else:
            imageio.imwrite(fn, pixels)
    except Exception:
        # Deliberately best effort; KeyboardInterrupt/SystemExit still propagate.
        print("WARNING: FAILED to save image %s" % fn)
def save_image_raw(fn, x : np.ndarray):
    """Write ``x`` to ``fn`` unchanged via imageio, best effort.

    A failure prints a warning instead of raising.
    """
    try:
        imageio.imwrite(fn, x)
    except Exception:
        # Deliberately best effort; KeyboardInterrupt/SystemExit still propagate.
        print("WARNING: FAILED to save image %s" % fn)
def load_image_raw(fn) -> np.ndarray:
    """Read the image at ``fn`` with imageio and return it unmodified."""
    pixels = imageio.imread(fn)
    return pixels
def load_image(fn) -> np.ndarray:
    """Load an image as float32.

    float32 files (HDR) are returned as loaded; any other dtype (LDR) is
    converted to float32 and divided by 255.
    """
    img = load_image_raw(fn)
    is_hdr = img.dtype == np.float32
    return img if is_hdr else img.astype(np.float32) / 255
def time_to_text(x):
    """Format a duration in seconds as hours, minutes or seconds with two decimals.

    More than 3600 s is shown in hours, more than 60 s in minutes, anything
    else in seconds.
    """
    return ("%.2f h" % (x / 3600) if x > 3600
            else "%.2f m" % (x / 60) if x > 60
            else "%.2f s" % x)
def checkerboard(res, checker_size) -> np.ndarray:
    """Create a grey checkerboard image.

    Args:
        res: ``(height, width)`` of the output.
        checker_size: edge length of one square, in pixels.

    Returns:
        Array of shape ``[height, width, 3]`` whose squares alternate
        between 0.66 and 0.33, identical in all three channels.
    """
    period = checker_size*2
    # Round up so the tiled pattern covers the requested resolution.
    tiles_y = (res[0] + period - 1) // period
    tiles_x = (res[1] + period - 1) // period
    pattern = [[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y
    square = np.ones((checker_size, checker_size))
    check = np.kron(pattern, square)*0.33 + 0.33
    check = check[:res[0], :res[1]]
    return np.stack((check,) * 3, axis=-1)