diff --git a/.gitignore b/.gitignore index a669bda..7d2e40a 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ diff --git a/lib/dataset/datasets.py b/lib/dataset/datasets.py new file mode 100644 index 0000000..6eb2b44 --- /dev/null +++ b/lib/dataset/datasets.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""Return training and evaluation/test datasets from config files.""" +import jax +import tensorflow as tf +import tensorflow_datasets as tfds + + +def get_data_scaler(config): + """Data normalizer. Assume data are always in [0, 1].""" + if config.data.centered: + # Rescale to [-1, 1] + return lambda x: x * 2. - 1. + else: + return lambda x: x + + +def get_data_inverse_scaler(config): + """Inverse data normalizer.""" + if config.data.centered: + # Rescale [-1, 1] to [0, 1] + return lambda x: (x + 1.) / 2. + else: + return lambda x: x + + +def crop_resize(image, resolution): + """Crop and resize an image to the given resolution.""" + crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1]) + h, w = tf.shape(image)[0], tf.shape(image)[1] + image = image[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + image = tf.image.resize( + image, + size=(resolution, resolution), + antialias=True, + method=tf.image.ResizeMethod.BICUBIC) + return tf.cast(image, tf.uint8) + + +def resize_small(image, resolution): + """Shrink an image to the given resolution.""" + h, w = image.shape[0], image.shape[1] + ratio = resolution / min(h, w) + h = tf.round(h * ratio, tf.int32) + w = tf.round(w * ratio, tf.int32) + return tf.image.resize(image, [h, w], antialias=True) + + +def central_crop(image, size): + """Crop the center of an image to the given size.""" + top = (image.shape[0] - size) // 2 + left = (image.shape[1] - size) // 2 + return tf.image.crop_to_bounding_box(image, top, left, size, size) + + +def get_dataset(config, uniform_dequantization=False, evaluation=False): + """Create data loaders for training and evaluation. + + Args: + config: A ml_collection.ConfigDict parsed from config files. + uniform_dequantization: If `True`, add uniform dequantization to images. + evaluation: If `True`, fix number of epochs to 1. + + Returns: + train_ds, eval_ds, dataset_builder. + """ + # Compute batch size for this worker. + batch_size = config.training.batch_size if not evaluation else config.eval.batch_size + if batch_size % jax.device_count() != 0: + raise ValueError(f'Batch sizes ({batch_size} must be divided by' + f'the number of devices ({jax.device_count()})') + + # Reduce this when image resolution is too large and data pointer is stored + shuffle_buffer_size = 10000 + prefetch_size = tf.data.experimental.AUTOTUNE + num_epochs = None if not evaluation else 1 + + # Create dataset builders for each dataset. 
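Before the per-dataset builders below, a quick round-trip sketch for the scaler pair defined above (a minimal standalone sketch: `centered` stands in for `config.data.centered`, and the batch shape is made up):

import tensorflow as tf

centered = True
scaler = (lambda x: x * 2. - 1.) if centered else (lambda x: x)            # get_data_scaler
inverse_scaler = (lambda x: (x + 1.) / 2.) if centered else (lambda x: x)  # get_data_inverse_scaler

x = tf.random.uniform((4, 32, 32, 3))  # fake image batch in [0, 1]
assert float(tf.reduce_max(tf.abs(inverse_scaler(scaler(x)) - x))) < 1e-6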
+ if config.data.dataset == 'CIFAR10': + dataset_builder = tfds.builder('cifar10') + train_split_name = 'train' + eval_split_name = 'test' + + def resize_op(img): + img = tf.image.convert_image_dtype(img, tf.float32) + return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) + + elif config.data.dataset == 'SVHN': + dataset_builder = tfds.builder('svhn_cropped') + train_split_name = 'train' + eval_split_name = 'test' + + def resize_op(img): + img = tf.image.convert_image_dtype(img, tf.float32) + return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) + + elif config.data.dataset == 'CELEBA': + dataset_builder = tfds.builder('celeb_a') + train_split_name = 'train' + eval_split_name = 'validation' + + def resize_op(img): + img = tf.image.convert_image_dtype(img, tf.float32) + img = central_crop(img, 140) + img = resize_small(img, config.data.image_size) + return img + + elif config.data.dataset == 'LSUN': + dataset_builder = tfds.builder(f'lsun/{config.data.category}') + train_split_name = 'train' + eval_split_name = 'validation' + + if config.data.image_size == 128: + def resize_op(img): + img = tf.image.convert_image_dtype(img, tf.float32) + img = resize_small(img, config.data.image_size) + img = central_crop(img, config.data.image_size) + return img + + else: + def resize_op(img): + img = crop_resize(img, config.data.image_size) + img = tf.image.convert_image_dtype(img, tf.float32) + return img + + elif config.data.dataset in ['FFHQ', 'CelebAHQ']: + dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path) + train_split_name = eval_split_name = 'train' + + else: + raise NotImplementedError( + f'Dataset {config.data.dataset} not yet supported.') + + # Customize preprocess functions for each dataset. + if config.data.dataset in ['FFHQ', 'CelebAHQ']: + def preprocess_fn(d): + sample = tf.io.parse_single_example(d, features={ + 'shape': tf.io.FixedLenFeature([3], tf.int64), + 'data': tf.io.FixedLenFeature([], tf.string)}) + data = tf.io.decode_raw(sample['data'], tf.uint8) + data = tf.reshape(data, sample['shape']) + data = tf.transpose(data, (1, 2, 0)) + img = tf.image.convert_image_dtype(data, tf.float32) + if config.data.random_flip and not evaluation: + img = tf.image.random_flip_left_right(img) + if uniform_dequantization: + img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. + return dict(image=img, label=None) + + else: + def preprocess_fn(d): + """Basic preprocessing function scales data to [0, 1) and randomly flips.""" + img = resize_op(d['image']) + if config.data.random_flip and not evaluation: + img = tf.image.random_flip_left_right(img) + if uniform_dequantization: + img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. 
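# Uniform dequantization: stored pixel intensities are discrete multiples of
# 1/255, so adding U[0, 1) noise to the 255-scaled image and dividing by 256
# spreads each 8-bit level over a continuous bin of width 1/256 in [0, 1),
# which keeps continuous densities (and likelihood values) well defined.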
+ + return dict(image=img, label=d.get('label', None)) + + def create_dataset(dataset_builder, split): + dataset_options = tf.data.Options() + dataset_options.experimental_optimization.map_parallelization = True + dataset_options.experimental_threading.private_threadpool_size = 48 + dataset_options.experimental_threading.max_intra_op_parallelism = 1 + read_config = tfds.ReadConfig(options=dataset_options) + if isinstance(dataset_builder, tfds.core.DatasetBuilder): + dataset_builder.download_and_prepare() + ds = dataset_builder.as_dataset( + split=split, shuffle_files=True, read_config=read_config) + else: + ds = dataset_builder.with_options(dataset_options) + ds = ds.repeat(count=num_epochs) + ds = ds.shuffle(shuffle_buffer_size) + ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) + ds = ds.batch(batch_size, drop_remainder=True) + return ds.prefetch(prefetch_size) + + train_ds = create_dataset(dataset_builder, train_split_name) + eval_ds = create_dataset(dataset_builder, eval_split_name) + return train_ds, eval_ds, dataset_builder diff --git a/lib/dataset/shapenet_dmtet_dataset.py b/lib/dataset/shapenet_dmtet_dataset.py new file mode 100644 index 0000000..5782cd4 --- /dev/null +++ b/lib/dataset/shapenet_dmtet_dataset.py @@ -0,0 +1,49 @@ +import os +import sys +import torch +from torch.utils.data import Dataset +import json + +import argparse + +class ShapeNetDMTetDataset(Dataset): + def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True): + super().__init__() + self.fpath_list = json.load(open(root, 'r')) + self.deform_scale = deform_scale + self.normalize_sdf = normalize_sdf + print(f"dataset with sdf normalized: {normalize_sdf}") + self.coeff = torch.tensor([1.0, 1.0, self.deform_scale, self.deform_scale, self.deform_scale]).view(-1, 1, 1, 1) + self.aug = aug + self.grid_mask = grid_mask.cpu() + self.resolution = self.grid_mask.size(-1) + + if filter_meta_path is not None: + self.filter_ids = json.load(open(filter_meta_path, 'r')) + full_id_list = [int(x.rstrip().split('_')[-1][:-3]) for i, x in enumerate(self.fpath_list)] + fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids] + self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list] + + def __len__(self): + return len(self.fpath_list) + + def __getitem__(self, idx): + with torch.no_grad(): + datum = torch.load(self.fpath_list[idx], map_location='cpu') + if self.normalize_sdf: + sdf_sign = torch.sign(datum[:, :1]) + sdf_sign[sdf_sign == 0] = 1.0 + datum[:, :1] = sdf_sign + if self.aug: + nonempty_mask = (datum[1:].abs().sum(dim=0, keepdim=True) != 0) + datum[1:] = datum[1:] + (torch.rand(3)[:, None, None, None] - 0.5) * 0.01 * nonempty_mask / (datum.size(-1) / self.resolution) + + if datum.size(-1) < self.resolution: + datum = datum * self.grid_mask[0, :, :datum.size(-1), :datum.size(-1), :datum.size(-1)] + else: + datum = datum * self.grid_mask[0] + + if datum.size(-1) < self.resolution: + diff = self.resolution - datum.size(-1) + datum = torch.nn.functional.pad(datum, (0, diff, 0, diff, 0, diff, 0, 0)) + return datum diff --git a/lib/diffusion/evaler.py b/lib/diffusion/evaler.py new file mode 100644 index 0000000..e43fcc4 --- /dev/null +++ b/lib/diffusion/evaler.py @@ -0,0 +1,212 @@ +import os +import sys +import numpy as np + +import logging +from . import losses +from .models import utils as mutils +from .models.ema import ExponentialMovingAverage +from . 
import sde_lib +import torch +from .utils import restore_checkpoint +from . import sampling + +def uncond_gen( + config, + idx=0, + ): + """ + Unconditional Generation + """ + with torch.no_grad(): + eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path + # Create directory to eval_folder + os.makedirs(eval_dir, exist_ok=True) + + scaler, inverse_scaler = lambda x: x, lambda x: x + + # Initialize model + score_model = mutils.create_model(config) + optimizer = losses.get_optimizer(config, score_model.parameters()) + ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) + state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) + + # Setup SDEs + sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + + img_size = config.data.image_size + grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda") + + sampling_eps = 1e-3 + sampling_shape = (config.eval.batch_size, + config.data.num_channels, + config.data.image_size, config.data.image_size, config.data.image_size) + sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask) + + assert os.path.exists(ckpt_path) + print('ckpt path:', ckpt_path) + try: + state = restore_checkpoint(ckpt_path, state, device=config.device) + except: + raise + ema.copy_to(score_model.parameters()) + + print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}") + save_file_path = os.path.join(eval_dir, f"{idx}.npy") + + + samples, n = sampling_fn(score_model) + samples = samples.cpu().numpy() + np.save(save_file_path, samples) + + +def slerp(z1, z2, alpha): + ''' + Spherical Linear Interpolation + ''' + theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))) + return ( + torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1 + + torch.sin(alpha * theta) / torch.sin(theta) * z2 + ) + +def uncond_gen_interp( + config, + idx=0, + ): + """ + Generation with interpolation between initial noises + Used for DDIM + """ + with torch.no_grad(): + eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path + # Create directory to eval_folder + os.makedirs(eval_dir, exist_ok=True) + + scaler, inverse_scaler = lambda x: x, lambda x: x + + # Initialize model + score_model = mutils.create_model(config) + optimizer = losses.get_optimizer(config, score_model.parameters()) + ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) + state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) + + # Setup SDEs + sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + + img_size = config.data.image_size + grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(1, img_size, img_size, img_size).to("cuda") + + sampling_eps = 1e-3 + sampling_shape = (config.eval.batch_size, + config.data.num_channels, + config.data.image_size, config.data.image_size, config.data.image_size) + sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask) + + assert os.path.exists(ckpt_path) + print('ckpt path:', ckpt_path) + try: + state = restore_checkpoint(ckpt_path, state, device=config.device) + except: + raise + ema.copy_to(score_model.parameters()) + + print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}") + save_file_path = 
os.path.join(eval_dir, f"{idx}.npy") + + + noise = sde.prior_sampling( (2, config.data.num_channels, config.data.image_size, config.data.image_size, config.data.image_size) ).to(config.device) + + + x0 = torch.zeros(sampling_shape, device=config.device) + x0[0] = noise[0] + x0[-1] = noise[1] + batch_size = sampling_shape[0] + for i in range(1, batch_size - 1): + x0[i] = slerp(x0[0], x0[-1], i / float(batch_size - 1)) + + samples, n = sampling_fn(score_model, x0=x0) + samples = samples.cpu().numpy() + np.save(save_file_path, samples) + + +def cond_gen( + config, + save_fname='0', + ): + """ + Conditional Generation with a partially completed dmtet from a 2.5D view (converted into a cubic grid) + """ + with torch.no_grad(): + eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path + # Create directory to eval_folder + os.makedirs(eval_dir, exist_ok=True) + + scaler, inverse_scaler = lambda x: x, lambda x: x + + # Initialize model + score_model = mutils.create_model(config) + optimizer = losses.get_optimizer(config, score_model.parameters()) + ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) + state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) + + # Setup SDEs + sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + + resolution = config.data.image_size + grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda") + + sampling_eps = 1e-3 + sampling_shape = (config.eval.batch_size, + config.data.num_channels, + resolution, resolution, resolution) + sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask) + + assert os.path.exists(ckpt_path) + print('ckpt path:', ckpt_path) + try: + state = restore_checkpoint(ckpt_path, state, device=config.device) + except: + raise + ema.copy_to(score_model.parameters()) + + print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}") + + + save_file_path = os.path.join(eval_dir, f"{save_fname}.npy") + + ### Conditional but free gradients; start from small t + + partial_dict = torch.load(config.eval.partial_dmtet_path) + partial_sdf = partial_dict['sdf'] + partial_mask = partial_dict['vis'] + + + ### compute the mapping from tet indices to 3D cubic grid vertex indices + tet_path = config.eval.tet_path + tet = np.load(tet_path) + vertices = torch.tensor(tet['vertices']) + vertices_unique = vertices[:].unique() + dx = vertices_unique[1] - vertices_unique[0] + + ind_to_coord = (torch.round( + (vertices - vertices.min()) / dx) + ).long() + + + partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution)) + partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf + partial_mask_grid = torch.zeros((1, 1, resolution, resolution, resolution)) + partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float() + + samples, n = sampling_fn( + score_model, + partial=partial_sdf_grid.cuda(), + partial_mask=partial_mask_grid.cuda(), + freeze_iters=config.eval.freeze_iters + ) + + samples = samples.cpu().numpy() + np.save(save_file_path, samples) + diff --git a/lib/diffusion/likelihood.py b/lib/diffusion/likelihood.py new file mode 100644 index 0000000..606dfa3 --- /dev/null +++ b/lib/diffusion/likelihood.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +# pytype: skip-file +"""Various sampling methods.""" + +import torch +import numpy as np +from scipy import integrate +from .models import utils as mutils + + +def get_div_fn(fn): + """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.""" + + def div_fn(x, t, eps): + with torch.enable_grad(): + x.requires_grad_(True) + fn_eps = torch.sum(fn(x, t) * eps) + grad_fn_eps = torch.autograd.grad(fn_eps, x)[0] + x.requires_grad_(False) + return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape)))) + + return div_fn + + +def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher', + rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5): + """Create a function to compute the unbiased log-likelihood estimate of a given data point. + + Args: + sde: A `sde_lib.SDE` object that represents the forward SDE. + inverse_scaler: The inverse data normalizer. + hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator. + rtol: A `float` number. The relative tolerance level of the black-box ODE solver. + atol: A `float` number. The absolute tolerance level of the black-box ODE solver. + method: A `str`. The algorithm for the black-box ODE solver. + See documentation for `scipy.integrate.solve_ivp`. + eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability. + + Returns: + A function that a batch of data points and returns the log-likelihoods in bits/dim, + the latent code, and the number of function evaluations cost by computation. + """ + + def drift_fn(model, x, t): + """The drift function of the reverse-time SDE.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True) + # Probability flow ODE is a special case of Reverse SDE + rsde = sde.reverse(score_fn, probability_flow=True) + return rsde.sde(x, t)[0] + + def div_fn(model, x, t, noise): + return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise) + + def likelihood_fn(model, data): + """Compute an unbiased estimate to the log-likelihood in bits/dim. + + Args: + model: A score model. + data: A PyTorch tensor. + + Returns: + bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim. + z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the + probability flow ODE. + nfe: An integer. The number of function evaluations used for running the black-box ODE solver. + """ + with torch.no_grad(): + shape = data.shape + if hutchinson_type == 'Gaussian': + epsilon = torch.randn_like(data) + elif hutchinson_type == 'Rademacher': + epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1. 
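# Hutchinson-Skilling trace estimation: for either noise choice
# (E[eps eps^T] = I), E[eps^T J eps] = tr(J), so a single eps gives an
# unbiased, if noisy, estimate of the drift's divergence needed by the
# probability-flow ODE. Rademacher (+/-1) noise typically yields lower
# variance than Gaussian for this estimator.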
+ else: + raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") + + def ode_func(t, x): + sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32) + vec_t = torch.ones(sample.shape[0], device=sample.device) * t + drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t)) + logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon)) + return np.concatenate([drift, logp_grad], axis=0) + + init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0) + solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method) + nfe = solution.nfev + zp = solution.y[:, -1] + z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32) + delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32) + prior_logp = sde.prior_logp(z) + bpd = -(prior_logp + delta_logp) / np.log(2) + N = np.prod(shape[1:]) + bpd = bpd / N + # A hack to convert log-likelihoods to bits/dim + offset = 7. - inverse_scaler(-1.) + bpd = bpd + offset + return bpd, z, nfe + + return likelihood_fn diff --git a/lib/diffusion/losses.py b/lib/diffusion/losses.py new file mode 100644 index 0000000..13c709f --- /dev/null +++ b/lib/diffusion/losses.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All functions related to loss computation and optimization. +""" + +import torch +import torch.optim as optim +import numpy as np +from .models import utils as mutils +from .sde_lib import VPSDE + + +def get_optimizer(config, params): + """Returns a flax optimizer object based on `config`.""" + if config.optim.optimizer == 'Adam': + optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, + weight_decay=config.optim.weight_decay) + else: + raise NotImplementedError( + f'Optimizer {config.optim.optimizer} not supported yet!') + + return optimizer + + +def optimization_manager(config): + """Returns an optimize_fn based on `config`.""" + + def optimize_fn(optimizer, params, step, lr=config.optim.lr, + warmup=config.optim.warmup, + grad_clip=config.optim.grad_clip): + """Optimizes with warmup and gradient clipping (disabled if negative).""" + if warmup > 0: + for g in optimizer.param_groups: + g['lr'] = lr * np.minimum(step / warmup, 1.0) + if grad_clip >= 0: + torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) + optimizer.step() + + return optimize_fn + +def get_ddpm_loss_fn(vpsde, train, mask=None, loss_type='l2'): + """Legacy code to reproduce previous results on DDPM. 
Not recommended for new work.""" + + def loss_fn(model, batch): + model_fn = mutils.get_model_fn(model, train=train) + labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device) + sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device) + sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device) + noise = torch.randn_like(batch) + perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \ + sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise + perturbed_data = perturbed_data * mask + score = model_fn(perturbed_data, labels) + + if loss_type == 'l2': + losses = torch.square(score - noise) + elif loss_type == 'l1': + losses = torch.abs(score - noise) + else: + raise NotImplementedError + + if mask is not None: + losses = losses * mask + losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1) + loss = torch.mean(losses) / mask.sum() * np.prod(mask.size()) + else: + losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1) + loss = torch.mean(losses) + + return loss + + return loss_fn + +def get_step_fn(sde, train, optimize_fn=None, mask=None, loss_type='l2'): + """Create a one-step training/evaluation function. + + Args: + sde: An `sde_lib.SDE` object that represents the forward SDE. + optimize_fn: An optimization function. + reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. + continuous: `True` indicates that the model is defined to take continuous time steps. + likelihood_weighting: If `True`, weight the mixture of score matching losses according to + https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. + + Returns: + A one-step function for training or evaluation. + """ + + loss_fn = get_ddpm_loss_fn(sde, train, mask=mask, loss_type=loss_type) + + def step_fn(state, batch, clear_grad=True, update_param=True): + """Running one step of training or evaluation. + + This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together + for faster execution. + + Args: + state: A dictionary of training information, containing the score model, optimizer, + EMA status, and number of optimization steps. + batch: A mini-batch of training/evaluation data. + + Returns: + loss: The average loss value of this state. + """ + model = state['model'] + if train: + optimizer = state['optimizer'] + if clear_grad: + optimizer.zero_grad() + loss = loss_fn(model, batch) + loss.backward() + if update_param: + optimize_fn(optimizer, model.parameters(), step=state['step']) + state['step'] += 1 + state['ema'].update(model.parameters()) + else: + with torch.no_grad(): + ema = state['ema'] + ema.store(model.parameters()) + ema.copy_to(model.parameters()) + loss = loss_fn(model, batch) + ema.restore(model.parameters()) + + return { + 'loss': loss, + } + + return step_fn \ No newline at end of file diff --git a/lib/diffusion/models/__init__.py b/lib/diffusion/models/__init__.py new file mode 100644 index 0000000..9a19804 --- /dev/null +++ b/lib/diffusion/models/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/lib/diffusion/models/ddpm_res128.py b/lib/diffusion/models/ddpm_res128.py new file mode 100644 index 0000000..14b053a --- /dev/null +++ b/lib/diffusion/models/ddpm_res128.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""DDPM model. + +This code is the pytorch equivalent of: +https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py +""" +import torch +import torch.nn as nn +import functools +import numpy as np + +from . import utils, layers, normalization + +# RefineBlock = layers.RefineBlock +# ResidualBlock = layers.ResidualBlock +ResnetBlockDDPM = layers.ResnetBlockDDPM +Upsample = layers.Upsample +Downsample = layers.Downsample +conv3x3 = layers.ddpm_conv3x3 +conv5x5 = layers.ddpm_conv5x5 +transposed_conv6x6 = layers.ddpm_conv6x6_transposed +get_act = layers.get_act +get_normalization = normalization.get_normalization +default_initializer = layers.default_init + +@utils.register_model(name='ddpm_res128') +class DDPMRes128(nn.Module): + def __init__(self, config): + super().__init__() + self.act = act = get_act(config) + self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + + self.nf = nf = config.model.nf + ch_mult = config.model.ch_mult + self.num_res_blocks = num_res_blocks = config.model.num_res_blocks + self.attn_resolutions = attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + resamp_with_conv = config.model.resamp_with_conv + self.num_resolutions = num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] ## manual for 128 to 64 + + AttnBlock = functools.partial(layers.AttnBlock) + self.conditional = conditional = config.model.conditional + ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) + if conditional: + # Condition on noise levels. 
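# The two Linear layers below form the timestep-embedding MLP: forward()
# builds a sinusoidal embedding of size nf with layers.get_timestep_embedding,
# and these layers project it (Linear, activation, Linear) to 4 * nf, matching
# the temb_dim=4 * nf used by the ResnetBlockDDPM partial defined above.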
+ modules = [nn.Linear(nf, nf * 4)] + modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) + nn.init.zeros_(modules[0].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) + nn.init.zeros_(modules[1].bias) + + self.centered = config.data.centered + channels = config.data.num_channels + + + ##### Pos Encoding + self.img_size = img_size = config.data.image_size + self.num_freq = int(np.log2(img_size)) + coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size)) + self.use_coords = False + if self.use_coords: + self.coords = torch.nn.Parameter( + # torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0, + torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size), + requires_grad=False + ) + #### + + ### Mask + self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False) + + # Downsampling block + self.pos_layer = conv5x5(3, nf, stride=1, padding=2) + self.mask_layer = conv5x5(1, nf, stride=1, padding=2) + modules.append(conv5x5(channels, nf, stride=1, padding=2)) + hs_c = [nf] + in_ch = nf + + + for i_level in range(num_resolutions): + num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2 + # Residual blocks for this resolution + for i_block in range(num_res_blocks_curr): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + if i_level != num_resolutions - 1: + modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + # Upsampling block + for i_level in reversed(range(num_resolutions)): + num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2 + for i_block in range(num_res_blocks_curr + 1): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + in_ch = out_ch + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + if i_level != 0: + modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) + + assert not hs_c + modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) + # modules.append(conv3x3(in_ch, channels, init_scale=0.)) + # modules.append(transposed_conv6x6(in_ch, channels, init_scale=0.)) + modules.append(conv5x5(in_ch, channels, init_scale=0., stride=1, padding=2)) + self.all_modules = nn.ModuleList(modules) + + self.scale_by_sigma = config.model.scale_by_sigma + + def forward(self, x, labels): + modules = self.all_modules + m_idx = 0 + if self.conditional: + # timestep/scale embedding + timesteps = labels + temb = layers.get_timestep_embedding(timesteps, self.nf) + temb = modules[m_idx](temb) + m_idx += 1 + temb = modules[m_idx](self.act(temb)) + m_idx += 1 + else: + temb = None + + if self.centered: + # Input is in [-1, 1] + h = x + else: + # Input is in [0, 1] + h = 2 * x - 1. 
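# Stem: the first convolution's output receives an additive embedding of the
# non-trainable `mask` parameter registered in __init__ (via mask_layer) and,
# when use_coords is enabled, of the raw voxel coordinates (via pos_layer),
# before entering the downsampling path below.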
+ + # Downsampling block + if self.use_coords: + hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)] + else: + hs = [modules[m_idx](h) + self.mask_layer(self.mask)] + m_idx += 1 + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + num_res_blocks = self.num_res_blocks if i_level != 0 else 2 + for i_block in range(num_res_blocks): + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(modules[m_idx](hs[-1])) + m_idx += 1 + + h = hs[-1] + h = modules[m_idx](h, temb) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + h = modules[m_idx](h, temb) + m_idx += 1 + + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + num_res_blocks = self.num_res_blocks if i_level != 0 else 2 + for i_block in range(num_res_blocks + 1): + hspop = hs.pop() + input = torch.cat([h, hspop], dim=1) + h = modules[m_idx](input, temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + if i_level != 0: + h = modules[m_idx](h) + m_idx += 1 + + assert not hs + h = self.act(modules[m_idx](h)) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + assert m_idx == len(modules) + + if self.scale_by_sigma: + # Divide the output by sigmas. Useful for training with the NCSN loss. + # The DDPM loss scales the network output by sigma in the loss function, + # so no need of doing it here. + used_sigmas = self.sigmas[labels, None, None, None] + h = h / used_sigmas + + return h diff --git a/lib/diffusion/models/ddpm_res64.py b/lib/diffusion/models/ddpm_res64.py new file mode 100644 index 0000000..55ab6dc --- /dev/null +++ b/lib/diffusion/models/ddpm_res64.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""DDPM model. + +This code is the pytorch equivalent of: +https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py +""" +import torch +import torch.nn as nn +import functools +import numpy as np + +from . 
import utils, layers, normalization + +# RefineBlock = layers.RefineBlock +# ResidualBlock = layers.ResidualBlock +ResnetBlockDDPM = layers.ResnetBlockDDPM +Upsample = layers.Upsample +Downsample = layers.Downsample +conv3x3 = layers.ddpm_conv3x3 +get_act = layers.get_act +get_normalization = normalization.get_normalization +default_initializer = layers.default_init + +@utils.register_model(name='ddpm_res64') +class DDPMRes64(nn.Module): + def __init__(self, config): + super().__init__() + self.act = act = get_act(config) + self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + + self.nf = nf = config.model.nf + ch_mult = config.model.ch_mult + self.num_res_blocks = num_res_blocks = config.model.num_res_blocks + self.attn_resolutions = attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + resamp_with_conv = config.model.resamp_with_conv + self.num_resolutions = num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] + + AttnBlock = functools.partial(layers.AttnBlock) + self.conditional = conditional = config.model.conditional + ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) + if conditional: + # Condition on noise levels. + modules = [nn.Linear(nf, nf * 4)] + modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) + nn.init.zeros_(modules[0].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) + nn.init.zeros_(modules[1].bias) + + self.centered = config.data.centered + channels = config.data.num_channels + + + ##### Pos Encoding + self.img_size = img_size = config.data.image_size + self.num_freq = int(np.log2(img_size)) + coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size)) + self.coords = torch.nn.Parameter( + torch.stack([coord_x, coord_y, coord_z]).view(1, 3, img_size, img_size, img_size) * 0.0, + requires_grad=False + ) + #### + + ### Mask + self.mask = torch.nn.Parameter(torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False) + + # Downsampling block + self.pos_layer = conv3x3(3, nf) + self.mask_layer = conv3x3(1, nf) + modules.append(conv3x3(channels, nf)) + hs_c = [nf] + in_ch = nf + for i_level in range(num_resolutions): + # Residual blocks for this resolution + for i_block in range(num_res_blocks): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + if i_level != num_resolutions - 1: + modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + # Upsampling block + for i_level in reversed(range(num_resolutions)): + for i_block in range(num_res_blocks + 1): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + in_ch = out_ch + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + if i_level != 0: + modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) + + assert not hs_c + modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) + 
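# Output head: the GroupNorm appended above and the conv3x3 appended below map
# the final feature map back to the data channels; init_scale=0. makes
# default_init fall back to a near-zero (1e-10) scale, so the network's output
# starts out approximately zero, the usual DDPM output-layer initialization.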
modules.append(conv3x3(in_ch, channels, init_scale=0.)) + self.all_modules = nn.ModuleList(modules) + + self.scale_by_sigma = config.model.scale_by_sigma + + def forward(self, x, labels): + modules = self.all_modules + m_idx = 0 + if self.conditional: + # timestep/scale embedding + timesteps = labels + temb = layers.get_timestep_embedding(timesteps, self.nf) + temb = modules[m_idx](temb) + m_idx += 1 + temb = modules[m_idx](self.act(temb)) + m_idx += 1 + else: + temb = None + + if self.centered: + # Input is in [-1, 1] + h = x + else: + # Input is in [0, 1] + h = 2 * x - 1. + + # Downsampling block + hs = [modules[m_idx](h) + self.pos_layer(self.coords) + self.mask_layer(self.mask)] + m_idx += 1 + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + for i_block in range(self.num_res_blocks): + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(modules[m_idx](hs[-1])) + m_idx += 1 + + h = hs[-1] + h = modules[m_idx](h, temb) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + h = modules[m_idx](h, temb) + m_idx += 1 + + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hspop = hs.pop() + input = torch.cat([h, hspop], dim=1) + h = modules[m_idx](input, temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + if i_level != 0: + h = modules[m_idx](h) + m_idx += 1 + + assert not hs + h = self.act(modules[m_idx](h)) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + assert m_idx == len(modules) + + if self.scale_by_sigma: + # Divide the output by sigmas. Useful for training with the NCSN loss. + # The DDPM loss scales the network output by sigma in the loss function, + # so no need of doing it here. + used_sigmas = self.sigmas[labels, None, None, None] + h = h / used_sigmas + + return h diff --git a/lib/diffusion/models/ema.py b/lib/diffusion/models/ema.py new file mode 100644 index 0000000..7355511 --- /dev/null +++ b/lib/diffusion/models/ema.py @@ -0,0 +1,98 @@ +# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py + +from __future__ import division +from __future__ import unicode_literals + +import torch + + +# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + """ + + def __init__(self, parameters, decay, use_num_updates=True): + """ + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the result of + `model.parameters()`. + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing + averages. + """ + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + self.shadow_params = [p.clone().detach() + for p in parameters if p.requires_grad] + self.collected_params = [] + + def update(self, parameters): + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. 
+ """ + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + s_param.sub_(one_minus_decay * (s_param - param)) + + def copy_to(self, parameters): + """ + Copy current parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + param.data.copy_(s_param.data) + + def store(self, parameters): + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + def state_dict(self): + return dict(decay=self.decay, num_updates=self.num_updates, + shadow_params=self.shadow_params) + + def load_state_dict(self, state_dict): + self.decay = state_dict['decay'] + self.num_updates = state_dict['num_updates'] + self.shadow_params = state_dict['shadow_params'] \ No newline at end of file diff --git a/lib/diffusion/models/layers.py b/lib/diffusion/models/layers.py new file mode 100644 index 0000000..51090a8 --- /dev/null +++ b/lib/diffusion/models/layers.py @@ -0,0 +1,771 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +"""Common layers for defining score networks. 
+""" +import math +import string +from functools import partial +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np +from .normalization import ConditionalInstanceNorm3dPlus + + +def get_act(config): + """Get activation functions from the config file.""" + + if config.model.nonlinearity.lower() == 'elu': + return nn.ELU() + elif config.model.nonlinearity.lower() == 'relu': + return nn.ReLU() + elif config.model.nonlinearity.lower() == 'lrelu': + return nn.LeakyReLU(negative_slope=0.2) + elif config.model.nonlinearity.lower() == 'swish': + return nn.SiLU() + else: + raise NotImplementedError('activation function does not exist!') + + +def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0): + """1x1 convolution. Same as NCSNv1/v2.""" + conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation, + padding=padding) + init_scale = 1e-10 if init_scale == 0 else init_scale + conv.weight.data *= init_scale + conv.bias.data *= init_scale + return conv + + +def variance_scaling(scale, mode, distribution, + in_axis=1, out_axis=0, + dtype=torch.float32, + device='cpu'): + """Ported from JAX. """ + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError( + "invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +def default_init(scale=1.): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, 'fan_avg', 'uniform') + + +class Dense(nn.Module): + """Linear layer with `default_init`.""" + def __init__(self): + super().__init__() + + +def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): + """3x3 convolution with PyTorch initialization. 
Same as NCSNv1/NCSNv2.""" + init_scale = 1e-10 if init_scale == 0 else init_scale + conv = nn.Conv3d(in_planes, out_planes, stride=stride, bias=bias, + dilation=dilation, padding=padding, kernel_size=3) + conv.weight.data *= init_scale + conv.bias.data *= init_scale + return conv + + +def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, + dilation=dilation, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + +def ddpm_conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding, + dilation=dilation, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ddpm_conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2): + """3x3 convolution with DDPM initialization.""" + conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding, + dilation=dilation, bias=bias, output_padding=(0, 1)) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ddpm_conv6x6_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2): + """3x3 convolution with DDPM initialization.""" + conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=6, stride=stride, padding=padding, + dilation=dilation, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + + ########################################################################### + # Functions below are ported over from the NCSNv1/NCSNv2 codebase: + # https://github.com/ermongroup/ncsn + # https://github.com/ermongroup/ncsnv2 + ########################################################################### + + +class CRPBlock(nn.Module): + def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True): + super().__init__() + self.convs = nn.ModuleList() + for i in range(n_stages): + self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) + self.n_stages = n_stages + if maxpool: + self.pool = nn.MaxPool3d(kernel_size=5, stride=1, padding=2) + else: + self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2) + + self.act = act + + def forward(self, x): + x = self.act(x) + path = x + for i in range(self.n_stages): + path = self.pool(path) + path = self.convs[i](path) + x = path + x + return x + + +class CondCRPBlock(nn.Module): + def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()): + super().__init__() + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.normalizer = normalizer + for i in range(n_stages): + self.norms.append(normalizer(features, num_classes, bias=True)) + self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) + + self.n_stages = n_stages + self.pool = nn.AvgPool3d(kernel_size=5, stride=1, padding=2) + self.act = act + + def forward(self, x, y): + x = self.act(x) + path = x + for i in range(self.n_stages): + path = self.norms[i](path, y) + path = self.pool(path) + path = self.convs[i](path) + + x = path + x + return x + + +class 
RCUBlock(nn.Module): + def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()): + super().__init__() + + for i in range(n_blocks): + for j in range(n_stages): + setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) + + self.stride = 1 + self.n_blocks = n_blocks + self.n_stages = n_stages + self.act = act + + def forward(self, x): + for i in range(self.n_blocks): + residual = x + for j in range(self.n_stages): + x = self.act(x) + x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) + + x += residual + return x + + +class CondRCUBlock(nn.Module): + def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()): + super().__init__() + + for i in range(n_blocks): + for j in range(n_stages): + setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True)) + setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) + + self.stride = 1 + self.n_blocks = n_blocks + self.n_stages = n_stages + self.act = act + self.normalizer = normalizer + + def forward(self, x, y): + for i in range(self.n_blocks): + residual = x + for j in range(self.n_stages): + x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) + x = self.act(x) + x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) + + x += residual + return x + + +class MSFBlock(nn.Module): + def __init__(self, in_planes, features): + super().__init__() + assert isinstance(in_planes, list) or isinstance(in_planes, tuple) + self.convs = nn.ModuleList() + self.features = features + + for i in range(len(in_planes)): + self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) + + def forward(self, xs, shape): + sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) + for i in range(len(self.convs)): + h = self.convs[i](xs[i]) + h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) + sums += h + return sums + + +class CondMSFBlock(nn.Module): + def __init__(self, in_planes, features, num_classes, normalizer): + super().__init__() + assert isinstance(in_planes, list) or isinstance(in_planes, tuple) + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.features = features + self.normalizer = normalizer + + for i in range(len(in_planes)): + self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) + self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) + + def forward(self, xs, y, shape): + sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) + for i in range(len(self.convs)): + h = self.norms[i](xs[i], y) + h = self.convs[i](h) + h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) + sums += h + return sums + + +class RefineBlock(nn.Module): + def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True): + super().__init__() + + assert isinstance(in_planes, tuple) or isinstance(in_planes, list) + self.n_blocks = n_blocks = len(in_planes) + + self.adapt_convs = nn.ModuleList() + for i in range(n_blocks): + self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) + + self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) + + if not start: + self.msf = MSFBlock(in_planes, features) + + self.crp = CRPBlock(features, 2, act, maxpool=maxpool) + + def forward(self, xs, output_shape): + assert isinstance(xs, tuple) or isinstance(xs, list) + hs = [] + for i in range(len(xs)): + h = self.adapt_convs[i](xs[i]) + 
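# Each incoming feature map is first refined by its own RCUBlock; after this
# loop the refined maps are fused by the MSF block (when there is more than
# one input path) and then passed through CRP pooling and the output RCU stage.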
hs.append(h) + + if self.n_blocks > 1: + h = self.msf(hs, output_shape) + else: + h = hs[0] + + h = self.crp(h) + h = self.output_convs(h) + + return h + + +class CondRefineBlock(nn.Module): + def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False): + super().__init__() + + assert isinstance(in_planes, tuple) or isinstance(in_planes, list) + self.n_blocks = n_blocks = len(in_planes) + + self.adapt_convs = nn.ModuleList() + for i in range(n_blocks): + self.adapt_convs.append( + CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) + ) + + self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act) + + if not start: + self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) + + self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) + + def forward(self, xs, y, output_shape): + assert isinstance(xs, tuple) or isinstance(xs, list) + hs = [] + for i in range(len(xs)): + h = self.adapt_convs[i](xs[i], y) + hs.append(h) + + if self.n_blocks > 1: + h = self.msf(hs, y, output_shape) + else: + h = hs[0] + + h = self.crp(h, y) + h = self.output_convs(h, y) + + return h + + +class ConvMeanPool(nn.Module): + def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False): + super().__init__() + if not adjust_padding: + conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) + self.conv = conv + else: + conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) + + self.conv = nn.Sequential( + nn.ZeroPad3d((1, 0, 1, 0)), + conv + ) + + def forward(self, inputs): + output = self.conv(inputs) + output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], + output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. + return output + + +class MeanPoolConv(nn.Module): + def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): + super().__init__() + self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) + + def forward(self, inputs): + output = inputs + output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], + output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 
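# Averaging the four strided slices above implements 2x2 mean pooling over
# dims 2 and 3 of the 5D tensor only; as with the other blocks ported from the
# NCSNv1/NCSNv2 codebase, the last spatial dimension is left at full
# resolution here.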
+ return self.conv(output) + + +class UpsampleConv(nn.Module): + def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): + super().__init__() + self.conv = nn.Conv3d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) + self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) + + def forward(self, inputs): + output = inputs + output = torch.cat([output, output, output, output], dim=1) + output = self.pixelshuffle(output) + return self.conv(output) + + +class ConditionalResidualBlock(nn.Module): + def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(), + normalization=ConditionalInstanceNorm3dPlus, adjust_padding=False, dilation=None): + super().__init__() + self.non_linearity = act + self.input_dim = input_dim + self.output_dim = output_dim + self.resample = resample + self.normalization = normalization + if resample == 'down': + if dilation > 1: + self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) + self.normalize2 = normalization(input_dim, num_classes) + self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + else: + self.conv1 = ncsn_conv3x3(input_dim, input_dim) + self.normalize2 = normalization(input_dim, num_classes) + self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) + conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) + + elif resample is None: + if dilation > 1: + conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + self.normalize2 = normalization(output_dim, num_classes) + self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) + else: + conv_shortcut = nn.Conv3d + self.conv1 = ncsn_conv3x3(input_dim, output_dim) + self.normalize2 = normalization(output_dim, num_classes) + self.conv2 = ncsn_conv3x3(output_dim, output_dim) + else: + raise Exception('invalid resample value') + + if output_dim != input_dim or resample is not None: + self.shortcut = conv_shortcut(input_dim, output_dim) + + self.normalize1 = normalization(input_dim, num_classes) + + def forward(self, x, y): + output = self.normalize1(x, y) + output = self.non_linearity(output) + output = self.conv1(output) + output = self.normalize2(output, y) + output = self.non_linearity(output) + output = self.conv2(output) + + if self.output_dim == self.input_dim and self.resample is None: + shortcut = x + else: + shortcut = self.shortcut(x) + + return shortcut + output + + +class ResidualBlock(nn.Module): + def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(), + normalization=nn.InstanceNorm3d, adjust_padding=False, dilation=1): + super().__init__() + self.non_linearity = act + self.input_dim = input_dim + self.output_dim = output_dim + self.resample = resample + self.normalization = normalization + if resample == 'down': + if dilation > 1: + self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) + self.normalize2 = normalization(input_dim) + self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + else: + self.conv1 = ncsn_conv3x3(input_dim, input_dim) + self.normalize2 = normalization(input_dim) + self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) + conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) + + elif resample is None: + if dilation > 1: + 
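+        # No resampling: both 3x3 convs and the 3x3 shortcut share the same dilation.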
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) + self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) + self.normalize2 = normalization(output_dim) + self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) + else: + # conv_shortcut = nn.Conv3d ### Something wierd here. + conv_shortcut = partial(ncsn_conv1x1) + self.conv1 = ncsn_conv3x3(input_dim, output_dim) + self.normalize2 = normalization(output_dim) + self.conv2 = ncsn_conv3x3(output_dim, output_dim) + else: + raise Exception('invalid resample value') + + if output_dim != input_dim or resample is not None: + self.shortcut = conv_shortcut(input_dim, output_dim) + + self.normalize1 = normalization(input_dim) + + def forward(self, x): + output = self.normalize1(x) + output = self.non_linearity(output) + output = self.conv1(output) + output = self.normalize2(output) + output = self.non_linearity(output) + output = self.conv2(output) + + if self.output_dim == self.input_dim and self.resample is None: + shortcut = x + else: + shortcut = self.shortcut(x) + + return shortcut + output + + +########################################################################### +# Functions below are ported over from the DDPM codebase: +# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py +########################################################################### + +def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode='constant') + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +def _einsum(a, b, c, x, y): + einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) + return torch.einsum(einsum_str, x, y) + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[:len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + def forward(self, x): + x = x.permute(0, 2, 3, 4, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 4, 1, 2, 3) + + +class AttnBlock(nn.Module): + """Channel-wise self-attention block.""" + def __init__(self, channels): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=0.) 
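+  # The forward pass below implements non-local self-attention over all D*H*W
+  # voxel positions: q, k, v are 1x1 NIN projections of the group-normalized
+  # input, logits are scaled by 1/sqrt(C), softmax runs over the flattened key
+  # positions, and the attended values are added back to x as a residual.
+  # The attention weights have shape (B, D, H, W, D, H, W), i.e. memory grows
+  # as (D*H*W)^2, so this block is typically applied only at coarse resolutions.
+  # Example (hypothetical shapes): for x of shape (2, 64, 8, 8, 8),
+  #   AttnBlock(64)(x) returns a tensor of the same shape.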
+ + def forward(self, x): + B, C, D, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum('bcdhw,bckij->bdhwkij', q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, D, H, W, D * H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, D, H, W, D, H, W)) + h = torch.einsum('bdhwkij,bckij->bcdhw', w, v) + h = self.NIN_3(h) + return x + h + + +class Upsample(nn.Module): + def __init__(self, channels, with_conv=False): + super().__init__() + if with_conv: + self.Conv_0 = ddpm_conv3x3(channels, channels) + self.with_conv = with_conv + + def forward(self, x): + B, C, D, H, W = x.shape + h = F.interpolate(x, (D * 2, H * 2, W * 2), mode='nearest') + if self.with_conv: + h = self.Conv_0(h) + return h + + +class Downsample(nn.Module): + def __init__(self, channels, with_conv=False): + super().__init__() + if with_conv: + self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) + self.with_conv = with_conv + + def forward(self, x): + B, C, D, H, W = x.shape + # Emulate 'SAME' padding + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0) + + assert x.shape == (B, C, D // 2, H // 2, W // 2) + return x + + +class ResnetBlockDDPM(nn.Module): + """The ResNet Blocks used in DDPM.""" + def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1): + super().__init__() + if out_ch is None: + out_ch = in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) + self.act = act + self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) 
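+    # When in_ch != out_ch, a learned shortcut (3x3 conv if `conv_shortcut`,
+    # otherwise a 1x1 NIN projection) maps x to out_ch for the residual sum.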
+ if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + self.out_ch = out_ch + self.in_ch = in_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + B, C, D, H, W = x.shape + assert C == self.in_ch + out_ch = self.out_ch if self.out_ch else self.in_ch + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if C != out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + return x + h + +# class PositionalEncoding(nn.Module): + +# def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): +# super().__init__() +# self.dropout = nn.Dropout(p=dropout) + +# position = torch.arange(max_len).unsqueeze(1) +# div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) +# pe = torch.zeros(max_len, 1, d_model) +# pe[:, 0, 0::2] = torch.sin(position * div_term) +# pe[:, 0, 1::2] = torch.cos(position * div_term) +# self.register_buffer('pe', pe) + +# def forward(self, x: Tensor) -> Tensor: +# """ +# Args: +# x: Tensor, shape [seq_len, batch_size, embedding_dim] +# """ +# x = x + self.pe[:x.size(0)] +# return self.dropout(x) + +class ResnetBlockDDPMPosEncoding(nn.Module): + """The ResNet Blocks used in DDPM.""" + def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, img_size=64): + super().__init__() + ##### Pos Encoding + coord_x, coord_y, coord_z = torch.meshgrid(torch.arange(img_size), torch.arange(img_size), torch.arange(img_size)) + coords = torch.stack([coord_x, coord_y, coord_z]) + self.num_freq = int(np.log2(img_size)) + pos_encoding = torch.zeros(1, 2 * self.num_freq, 3, img_size, img_size, img_size) + with torch.no_grad(): + for i in range(self.num_freq): + pos_encoding[0, 2*i, :, :, :, :] = torch.cos((i+1) * np.pi * coords) + pos_encoding[0, 2*i + 1, :, :, :, :] = torch.sin((i+1) * np.pi * coords) + self.pos_encoding = nn.Parameter( + pos_encoding.view(1, 2 * self.num_freq * 3, img_size, img_size, img_size) / img_size, + requires_grad=False + ) + #### + + if out_ch is None: + out_ch = in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) + self.act = act + self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) + self.Conv_0_pos = ddpm_conv3x3(2 * self.num_freq * 3, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) 
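+    # Same shortcut logic as ResnetBlockDDPM; the only difference in this block
+    # is that Conv_0_pos projects the fixed sinusoidal 3D positional encoding
+    # and adds it to the output of Conv_0 (see forward below).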
+ if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + self.out_ch = out_ch + self.in_ch = in_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + B, C, D, H, W = x.shape + assert C == self.in_ch + out_ch = self.out_ch if self.out_ch else self.in_ch + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + self.Conv_0_pos(self.pos_encoding).expand(h.size(0), -1, -1, -1, -1) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if C != out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + return x + h \ No newline at end of file diff --git a/lib/diffusion/models/normalization.py b/lib/diffusion/models/normalization.py new file mode 100644 index 0000000..ccc6a3e --- /dev/null +++ b/lib/diffusion/models/normalization.py @@ -0,0 +1,223 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalization layers.""" +import torch.nn as nn +import torch +import functools + + +def get_normalization(config, conditional=False): + """Obtain normalization modules from the config file.""" + norm = config.model.normalization + if conditional: + if norm == 'InstanceNorm++': + return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes) + else: + raise NotImplementedError(f'{norm} not implemented yet.') + else: + if norm == 'InstanceNorm': + return nn.InstanceNorm3d + elif norm == 'InstanceNorm++': + return InstanceNorm3dPlus + elif norm == 'VarianceNorm': + return VarianceNorm3d + elif norm == 'GroupNorm': + return nn.GroupNorm + else: + raise ValueError('Unknown normalization: %s' % norm) + + +class ConditionalBatchNorm3d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.bn = nn.BatchNorm3d(num_features, affine=False) + if self.bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + out = self.bn(x) + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=1) + out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1, 1) * out + return out + + +class ConditionalInstanceNorm3d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, 
track_running_stats=False) + if bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + h = self.instance_norm(x) + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=-1) + out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1, 1) * h + return out + + +class ConditionalVarianceNorm3d(nn.Module): + def __init__(self, num_features, num_classes, bias=False): + super().__init__() + self.num_features = num_features + self.bias = bias + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.normal_(1, 0.02) + + def forward(self, x, y): + vars = torch.var(x, dim=(2, 3, 4), keepdim=True) + h = x / torch.sqrt(vars + 1e-5) + + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1, 1) * h + return out + + +class VarianceNorm3d(nn.Module): + def __init__(self, num_features, bias=False): + super().__init__() + self.num_features = num_features + self.bias = bias + self.alpha = nn.Parameter(torch.zeros(num_features)) + self.alpha.data.normal_(1, 0.02) + + def forward(self, x): + vars = torch.var(x, dim=(2, 3, 4), keepdim=True) + h = x / torch.sqrt(vars + 1e-5) + + out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h + return out + + +class ConditionalNoneNorm3d(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + if bias: + self.embed = nn.Embedding(num_classes, num_features * 2) + self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, num_features) + self.embed.weight.data.uniform_() + + def forward(self, x, y): + if self.bias: + gamma, beta = self.embed(y).chunk(2, dim=-1) + out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1) + else: + gamma = self.embed(y) + out = gamma.view(-1, self.num_features, 1, 1, 1) * x + return out + + +class NoneNorm3d(nn.Module): + def __init__(self, num_features, bias=True): + super().__init__() + + def forward(self, x): + return x + + +class InstanceNorm3dPlus(nn.Module): + def __init__(self, num_features, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False) + self.alpha = nn.Parameter(torch.zeros(num_features)) + self.gamma = nn.Parameter(torch.zeros(num_features)) + self.alpha.data.normal_(1, 0.02) + self.gamma.data.normal_(1, 0.02) + if bias: + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + means = torch.mean(x, dim=(2, 3, 4)) + m = torch.mean(means, dim=-1, keepdim=True) + v = torch.var(means, dim=-1, keepdim=True) + means = (means - m) / (torch.sqrt(v + 1e-5)) + h = self.instance_norm(x) + + if self.bias: + h = h + means[..., None, None, None] * self.alpha[..., None, None, None] + out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1, 1) + else: + h = h + means[..., None, None, None] * self.alpha[..., None, 
None, None] + out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + return out + + +class ConditionalInstanceNorm3dPlus(nn.Module): + def __init__(self, num_features, num_classes, bias=True): + super().__init__() + self.num_features = num_features + self.bias = bias + self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False) + if bias: + self.embed = nn.Embedding(num_classes, num_features * 3) + self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) + self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 + else: + self.embed = nn.Embedding(num_classes, 2 * num_features) + self.embed.weight.data.normal_(1, 0.02) + + def forward(self, x, y): + means = torch.mean(x, dim=(2, 3, 4)) + m = torch.mean(means, dim=-1, keepdim=True) + v = torch.var(means, dim=-1, keepdim=True) + means = (means - m) / (torch.sqrt(v + 1e-5)) + h = self.instance_norm(x) + + if self.bias: + gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) + h = h + means[..., None, None, None] * alpha[..., None, None, None] + out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1) + else: + gamma, alpha = self.embed(y).chunk(2, dim=-1) + h = h + means[..., None, None, None] * alpha[..., None, None, None] + out = gamma.view(-1, self.num_features, 1, 1, 1) * h + return out + +def lip_weight_normalization_3d(W, softplus_c): + """ + Lipschitz weight normalization based on the L-infinity norm (see Eq.9 in [Liu et al 2022]) + """ + absrowsum = torch.sum(torch.abs(W), dim=[1,2,3,4]) + 1e-8 + scale = torch.nn.functional.relu(softplus_c/absrowsum - 1.0) + 1.0 + return W * scale[:, None, None, None, None] \ No newline at end of file diff --git a/lib/diffusion/models/up_or_down_sampling.py b/lib/diffusion/models/up_or_down_sampling.py new file mode 100644 index 0000000..34474a3 --- /dev/null +++ b/lib/diffusion/models/up_or_down_sampling.py @@ -0,0 +1,259 @@ +"""Layers used for up-sampling or down-sampling images. + +Many functions are ported from https://github.com/NVlabs/stylegan2. 
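+The 2D routines are adapted here to 3D volumes with layout (N, C, D, H, W).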
+""" + +import torch.nn as nn +import torch +import torch.nn.functional as F +import numpy as np +# from op import upfirdn2d + + +# Function ported from StyleGAN2 +def get_weight(module, + shape, + weight_var='weight', + kernel_init=None): + """Get/create weight tensor for a convolution or fully-connected layer.""" + + return module.param(weight_var, kernel_init, shape) + + +class Conv3d(nn.Module): + """Conv3d layer with optimal upsampling and downsampling (StyleGAN2).""" + + def __init__(self, in_ch, out_ch, kernel, up=False, down=False, + resample_kernel=(1, 3, 3, 1), + use_bias=True, + kernel_init=None): + super().__init__() + assert not (up and down) + assert kernel >= 1 and kernel % 2 == 1 + self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel, kernel)) + if kernel_init is not None: + self.weight.data = kernel_init(self.weight.data.shape) + if use_bias: + self.bias = nn.Parameter(torch.zeros(out_ch)) + + self.up = up + self.down = down + self.resample_kernel = resample_kernel + self.kernel = kernel + self.use_bias = use_bias + + def forward(self, x): + if self.up: + x = upsample_conv_3d(x, self.weight, k=self.resample_kernel) + elif self.down: + x = conv_downsample_3d(x, self.weight, k=self.resample_kernel) + else: + x = F.conv3d(x, self.weight, stride=1, padding=self.kernel // 2) + + if self.use_bias: + x = x + self.bias.reshape(1, -1, 1, 1, 1) + + return x + + +def naive_upsample_3d(x, factor=2): + _N, C, D, H, W = x.shape + x = torch.reshape(x, (-1, C, D, 1, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, D * factor, H * factor, W * factor)) + + +def naive_downsample_3d(x, factor=2): + _N, C, D, H, W = x.shape + x = torch.reshape(x, (-1, C, D // factor, factor, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5, 7)) + + +def upsample_conv_3d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_3d()` followed by `tf.nn.conv3d()`. + + Padding is performed only once at the beginning, not between the + operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Check weight shape. + assert len(w.shape) == 5 + convD = w.shape[2] + convH = w.shape[3] + convW = w.shape[4] + inC = w.shape[1] + outC = w.shape[0] + + assert convW == convH + + # Setup filter kernel. + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor ** 2)) + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + + # Determine data dimensions. 
+ stride = [1, 1, factor, factor] + output_shape = ((_shape(x, 2) - 1) * factor + convD, (_shape(x, 3) - 1) * factor + convH, (_shape(x, 4) - 1) * factor + convW) + output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convD, + output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convH, + output_shape[2] - (_shape(x, 4) - 1) * stride[2] - convW) + assert output_padding[0] >= 0 and output_padding[1] >= 0 and output_padding[2] >= 0 + num_groups = _shape(x, 1) // inC + + # Transpose weights. + w = torch.reshape(w, (num_groups, -1, inC, convD, convH, convW)) + w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4, 5) + w = torch.reshape(w, (num_groups * inC, -1, convD, convH, convW)) + + x = F.conv_transpose3d(x, w, stride=stride, output_padding=output_padding, padding=0) + ## Original TF code. + # x = tf.nn.conv3d_transpose( + # x, + # w, + # output_shape=output_shape, + # strides=stride, + # padding='VALID', + # data_format=data_format) + ## JAX equivalent + + return upfirdn2d(x, torch.tensor(k, device=x.device), + pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + + +def conv_downsample_3d(x, w, k=None, factor=2, gain=1): + """Fused `tf.nn.conv3d()` followed by `downsample_3d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + _outC, _inC, convH, convW = w.shape + assert convW == convH + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d(x, torch.tensor(k, device=x.device), + pad=((p + 1) // 2, p // 2)) + return F.conv3d(x, w, stride=s, padding=0) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def _shape(x, dim): + return x.shape[dim] + + +def upsample_3d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 3d images with the given filter. + + Accepts a batch of 3d images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so + that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the upsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor ** 2)) + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), + up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_3d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 3d images with the given filter. + + Accepts a batch of 3d images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized + so that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the downsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), + down=factor, pad=((p + 1) // 2, p // 2)) diff --git a/lib/diffusion/models/utils.py b/lib/diffusion/models/utils.py new file mode 100644 index 0000000..13b28de --- /dev/null +++ b/lib/diffusion/models/utils.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""All functions and modules related to model definition. +""" + +import torch +from .. import sde_lib +import numpy as np + + +_MODELS = {} + + +def register_model(cls=None, *, name=None): + """A decorator for registering model classes.""" + + def _register(cls): + if name is None: + local_name = cls.__name__ + else: + local_name = name + if local_name in _MODELS: + raise ValueError(f'Already registered model with name: {local_name}') + _MODELS[local_name] = cls + return cls + + if cls is None: + return _register + else: + return _register(cls) + + +def get_model(name): + return _MODELS[name] + + +def get_sigmas(config): + """Get sigmas --- the set of noise levels for SMLD from config files. 
+ Args: + config: A ConfigDict object parsed from the config file + Returns: + sigmas: a jax numpy arrary of noise levels + """ + sigmas = np.exp( + np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) + + return sigmas + + +def get_ddpm_params(config): + """Get betas and alphas --- parameters used in the original DDPM paper.""" + num_diffusion_timesteps = 1000 + # parameters need to be adapted if number of time steps differs from 1000 + beta_start = config.model.beta_min / config.model.num_scales + beta_end = config.model.beta_max / config.model.num_scales + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) + + return { + 'betas': betas, + 'alphas': alphas, + 'alphas_cumprod': alphas_cumprod, + 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, + 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, + 'beta_min': beta_start * (num_diffusion_timesteps - 1), + 'beta_max': beta_end * (num_diffusion_timesteps - 1), + 'num_diffusion_timesteps': num_diffusion_timesteps + } + + +def create_model(config, use_parallel=True): + """Create the score model.""" + model_name = config.model.name + score_model = get_model(model_name)(config) + # score_model = score_model.to(config.device) + score_model = score_model + if use_parallel: + score_model = torch.nn.DataParallel(score_model).to(config.device) + return score_model + + +def get_model_fn(model, train=False): + """Create a function to give the output of the score-based model. + + Args: + model: The score model. + train: `True` for training and `False` for evaluation. + + Returns: + A model function. + """ + + def model_fn(x, labels): + """Compute the output of the score-based model. + + Args: + x: A mini-batch of input data. + labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently + for different models. + + Returns: + A tuple of (model output, new mutable states) + """ + if not train: + model.eval() + return model(x, labels) + else: + model.train() + return model(x, labels) + + return model_fn + +def get_reg_fn(model, train=False): + """Create a function to give the output of the score-based model. + + Args: + model: The score model. + train: `True` for training and `False` for evaluation. + + Returns: + A model function. + """ + + def model_fn(x): + """Compute the output of the score-based model. + + Args: + x: A mini-batch of input data. + labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently + for different models. + + Returns: + A tuple of (model output, new mutable states) + """ + if not train: + model.eval() + try: + return model.get_reg(x) + except: + return torch.zeros_like(x, device=x.device) + else: + model.train() + try: + return model.get_reg(x) + except: + return torch.zeros_like(x, device=x.device) + + return model_fn + +def get_score_fn(sde, model, train=False, continuous=False, std_scale=True): + """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. + + Args: + sde: An `sde_lib.SDE` object that represents the forward SDE. + model: A score model. + train: `True` for training and `False` for evaluation. + continuous: If `True`, the score-based model is expected to directly take continuous time steps. 
+ std_scale: whether to scale the score function by the inverse of std. Used for DDIM sampling + + Returns: + A score function. + """ + model_fn = get_model_fn(model, train=train) + reg_fn = get_reg_fn(model, train=train) + + assert not continuous + if isinstance(sde, sde_lib.VPSDE): + if not std_scale: + def score_fn(x, t): + labels = t * (sde.N - 1) + score = model_fn(x, labels) + return score + else: + def score_fn(x, t): + # For VP-trained models, t=0 corresponds to the lowest noise level + labels = t * (sde.N - 1) + score = model_fn(x, labels) + std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] + + score = -score / std[:, None, None, None, None] + return score + + else: + raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") + + return score_fn + + +def to_flattened_numpy(x): + """Flatten a torch tensor `x` and convert it to numpy.""" + return x.detach().cpu().numpy().reshape((-1,)) + + +def from_flattened_numpy(x, shape): + """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" + return torch.from_numpy(x.reshape(shape)) \ No newline at end of file diff --git a/lib/diffusion/sampling.py b/lib/diffusion/sampling.py new file mode 100644 index 0000000..6de18d6 --- /dev/null +++ b/lib/diffusion/sampling.py @@ -0,0 +1,570 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: skip-file +# pytype: skip-file +"""Various sampling methods.""" +import functools + +import torch +import numpy as np +import abc + +from .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn +from scipy import integrate +from . import sde_lib +from .models import utils as mutils + +import logging +import tqdm + +_CORRECTORS = {} +_PREDICTORS = {} + + +def register_predictor(cls=None, *, name=None): + """A decorator for registering predictor classes.""" + + def _register(cls): + if name is None: + local_name = cls.__name__ + else: + local_name = name + if local_name in _PREDICTORS: + raise ValueError(f'Already registered model with name: {local_name}') + _PREDICTORS[local_name] = cls + return cls + + if cls is None: + return _register + else: + return _register(cls) + + +def register_corrector(cls=None, *, name=None): + """A decorator for registering corrector classes.""" + + def _register(cls): + if name is None: + local_name = cls.__name__ + else: + local_name = name + if local_name in _CORRECTORS: + raise ValueError(f'Already registered model with name: {local_name}') + _CORRECTORS[local_name] = cls + return cls + + if cls is None: + return _register + else: + return _register(cls) + + +def get_predictor(name): + return _PREDICTORS[name] + + +def get_corrector(name): + return _CORRECTORS[name] + + +def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False): + """Create a sampling function. + + Args: + config: A `ml_collections.ConfigDict` object that contains all configuration information. 
+ sde: A `sde_lib.SDE` object that represents the forward SDE. + shape: A sequence of integers representing the expected shape of a single sample. + inverse_scaler: The inverse data normalizer function. + eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability. + + Returns: + A function that takes random states and a replicated training state and outputs samples with the + trailing dimensions matching `shape`. + """ + + sampler_name = config.sampling.method + # Probability flow ODE sampling with black-box ODE solvers + # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases. + if sampler_name.lower() == 'pc': + predictor = get_predictor(config.sampling.predictor.lower()) + corrector = get_corrector(config.sampling.corrector.lower()) + sampling_fn = get_pc_sampler(sde=sde, + shape=shape, + predictor=predictor, + corrector=corrector, + inverse_scaler=inverse_scaler, + snr=config.sampling.snr, + n_steps=config.sampling.n_steps_each, + probability_flow=config.sampling.probability_flow, + continuous=config.training.continuous, + denoise=config.sampling.noise_removal, + eps=eps, + device=config.device, + grid_mask=grid_mask, + return_traj=return_traj) + elif sampler_name.lower() == 'ddim': + predictor = get_predictor('ddim') + sampling_fn = get_ddim_sampler(sde=sde, + shape=shape, + predictor=predictor, + inverse_scaler=inverse_scaler, + n_steps=config.sampling.n_steps_each, + denoise=config.sampling.noise_removal, + eps=eps, + device=config.device, + grid_mask=grid_mask) + else: + raise ValueError(f"Sampler name {sampler_name} unknown.") + + return sampling_fn + + +class Predictor(abc.ABC): + """The abstract class for a predictor algorithm.""" + + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__() + self.sde = sde + # Compute the reverse SDE/ODE + self.rsde = sde.reverse(score_fn, probability_flow) + self.score_fn = score_fn + + @abc.abstractmethod + def update_fn(self, x, t): + """One update of the predictor. + + Args: + x: A PyTorch tensor representing the current state + t: A Pytorch tensor representing the current time step. + + Returns: + x: A PyTorch tensor of the next state. + x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. + """ + pass + + +class Corrector(abc.ABC): + """The abstract class for a corrector algorithm.""" + + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__() + self.sde = sde + self.score_fn = score_fn + self.snr = snr + self.n_steps = n_steps + + @abc.abstractmethod + def update_fn(self, x, t): + """One update of the corrector. + + Args: + x: A PyTorch tensor representing the current state + t: A PyTorch tensor representing the current time step. + + Returns: + x: A PyTorch tensor of the next state. + x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. + """ + pass + + +@register_predictor(name='euler_maruyama') +class EulerMaruyamaPredictor(Predictor): + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow) + + def update_fn(self, x, t): + dt = -1. 
/ self.rsde.N + z = torch.randn_like(x) + drift, diffusion = self.rsde.sde(x, t) + x_mean = x + drift * dt + x = x_mean + diffusion[:, None, None, None, None] * np.sqrt(-dt) * z + return x, x_mean + + +@register_predictor(name='reverse_diffusion') +class ReverseDiffusionPredictor(Predictor): + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow) + + def update_fn(self, x, t): + f, G = self.rsde.discretize(x, t) + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None, None] * z + return x, x_mean + + +@register_predictor(name='ancestral_sampling') +class AncestralSamplingPredictor(Predictor): + """The ancestral sampling predictor. Currently only supports VE/VP SDEs.""" + + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow) + if not isinstance(sde, sde_lib.VPSDE): + raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") + assert not probability_flow, "Probability flow not supported by ancestral sampling" + + def vpsde_update_fn(self, x, t): + sde = self.sde + timestep = (t * (sde.N - 1) / sde.T).long() + beta = sde.discrete_betas.to(t.device)[timestep] + score = self.score_fn(x, t) + x_mean = (x + beta[:, None, None, None, None] * score) / torch.sqrt(1. - beta)[:, None, None, None, None] + noise = torch.randn_like(x) + x = x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise + return x, x_mean + + def update_fn(self, x, t): + if isinstance(self.sde, sde_lib.VPSDE): + return self.vpsde_update_fn(x, t) + else: + raise NotImplementedError + + +@register_predictor(name='none') +class NonePredictor(Predictor): + """An empty predictor that does nothing.""" + + def __init__(self, sde, score_fn, probability_flow=False): + pass + + def update_fn(self, x, t): + return x, x + +@register_predictor(name='ddim') +class DDIMPredictor(Predictor): + def __init__(self, sde, score_fn, probability_flow=False): + super().__init__(sde, score_fn, probability_flow) + + + def update_fn(self, x, t, tprev=None): + x, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev) + return x, x0_pred + +@register_corrector(name='langevin') +class LangevinCorrector(Corrector): + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__(sde, score_fn, snr, n_steps) + if not isinstance(sde, sde_lib.VPSDE): + raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") + + def update_fn(self, x, t): + sde = self.sde + score_fn = self.score_fn + n_steps = self.n_steps + target_snr = self.snr + if isinstance(sde, sde_lib.VPSDE): + timestep = (t * (sde.N - 1) / sde.T).long() + alpha = sde.alphas.to(t.device)[timestep] + else: + alpha = torch.ones_like(t) + + for i in range(n_steps): + grad = score_fn(x, t) + noise = torch.randn_like(x) + grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha + x_mean = x + step_size[:, None, None, None, None] * grad + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise + + return x, x_mean + + +@register_corrector(name='ald') +class AnnealedLangevinDynamics(Corrector): + """The original annealed Langevin dynamics predictor in NCSN/NCSNv2. + + We include this corrector only for completeness. It was not directly used in our paper. 
+ """ + + def __init__(self, sde, score_fn, snr, n_steps): + super().__init__(sde, score_fn, snr, n_steps) + if not isinstance(sde, sde_lib.VPSDE): + raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") + + def update_fn(self, x, t): + sde = self.sde + score_fn = self.score_fn + n_steps = self.n_steps + target_snr = self.snr + if isinstance(sde, sde_lib.VPSDE): + timestep = (t * (sde.N - 1) / sde.T).long() + alpha = sde.alphas.to(t.device)[timestep] + else: + alpha = torch.ones_like(t) + + std = self.sde.marginal_prob(x, t)[1] + + for i in range(n_steps): + grad = score_fn(x, t) + noise = torch.randn_like(x) + step_size = (target_snr * std) ** 2 * 2 * alpha + x_mean = x + step_size[:, None, None, None, None] * grad + x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None, None] + + return x, x_mean + + +@register_corrector(name='none') +class NoneCorrector(Corrector): + """An empty corrector that does nothing.""" + + def __init__(self, sde, score_fn, snr, n_steps): + pass + + def update_fn(self, x, t): + return x, x + + +def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): + """A wrapper that configures and returns the update function of predictors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if predictor is None: + # Corrector-only sampler + predictor_obj = NonePredictor(sde, score_fn, probability_flow) + else: + predictor_obj = predictor(sde, score_fn, probability_flow) + return predictor_obj.update_fn(x, t) + + +def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps): + """A wrapper tha configures and returns the update function of correctors.""" + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) + if corrector is None: + # Predictor-only sampler + corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) + else: + corrector_obj = corrector(sde, score_fn, snr, n_steps) + return corrector_obj.update_fn(x, t) + + +def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr, + n_steps=1, probability_flow=False, continuous=False, + denoise=True, eps=1e-3, device='cuda', grid_mask=None, return_traj=False): + """Create a Predictor-Corrector (PC) sampler. + + Args: + sde: An `sde_lib.SDE` object representing the forward SDE. + shape: A sequence of integers. The expected shape of a single sample. + predictor: A subclass of `sampling.Predictor` representing the predictor algorithm. + corrector: A subclass of `sampling.Corrector` representing the corrector algorithm. + inverse_scaler: The inverse data normalizer. + snr: A `float` number. The signal-to-noise ratio for configuring correctors. + n_steps: An integer. The number of corrector steps per predictor update. + probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor. + continuous: `True` indicates that the score model was continuously trained. + denoise: If `True`, add one-step denoising to the final samples. + eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues. + device: PyTorch device. + + Returns: + A sampling function that returns samples and the number of function evaluations during sampling. 
+ """ + # Create predictor & corrector update functions + predictor_update_fn = functools.partial(shared_predictor_update_fn, + sde=sde, + predictor=predictor, + probability_flow=probability_flow, + continuous=continuous) + corrector_update_fn = functools.partial(shared_corrector_update_fn, + sde=sde, + corrector=corrector, + continuous=continuous, + snr=snr, + n_steps=n_steps) + + def pc_sampler(model, + partial=None, partial_grid_mask=None, partial_channel=0, + freeze_iters=None): + """ The PC sampler funciton. + + Args: + model: A score model. + Returns: + Samples, number of function evaluations. + """ + with torch.no_grad(): + + if freeze_iters is None: + freeze_iters = sde.N + 10 # just some randomly large number greater than sde.N + timesteps = torch.linspace(sde.T, eps, sde.N, device=device) + + + + def compute_xzero(sde, model, x, t, grid_mask_input): + timestep_int = (t * (sde.N - 1) / sde.T).long() + alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda() + alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda() + alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda() + alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda() + score_pred = model(x, t * torch.ones(shape[0], device=x.device)) + x0_pred_scaled = (x - alphas2 * score_pred) + x0_pred = x0_pred_scaled / alphas1 + x0_pred = x0_pred.clamp(-1, 1) + return x0_pred * grid_mask_input + + # Initial sample + x = sde.prior_sampling(shape).to(device) + assert len(x.size()) == 5 + x = x * grid_mask + + traj_buffer = [] + + if partial is not None: + assert len(partial.size()) == 5 + t = timesteps[0] + vec_t = torch.ones(shape[0], device=t.device) * t + x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel] + + partial_mean, partial_std = sde.marginal_prob(x, vec_t) + sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device) + x[:, partial_channel] = ( + x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + + sampled_update[:, partial_channel] * partial_mask[:, partial_channel] + ) * grid_mask[:, partial_channel] + + + if partial is not None: + x_mean = x + for i in tqdm.trange(sde.N): + t = timesteps[i] + vec_t = torch.ones(shape[0], device=t.device) * t + + x, x_mean = corrector_update_fn(x, vec_t, model=model) + x, x_mean = x * grid_mask, x_mean * grid_mask + x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = x * grid_mask, x_mean * grid_mask + + + if i != sde.N - 1 and i < freeze_iters: + + x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel] + x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel] + + ### add noise to the condition x0_star + partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device)) + sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device) + x[:, partial_channel] = ( + x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + + sampled_update * partial_mask[:, partial_channel] + ) * grid_mask[:, partial_channel] + x_mean[:, partial_channel] = x[:, partial_channel] + + else: + + for i in 
tqdm.trange(sde.N - 1): + t = timesteps[i] + + vec_t = torch.ones(shape[0], device=t.device) * t + x, x_mean = corrector_update_fn(x, vec_t, model=model) + x, x_mean = x * grid_mask, x_mean * grid_mask + x, x_mean = predictor_update_fn(x, vec_t, model=model) + x, x_mean = x * grid_mask, x_mean * grid_mask + + if return_traj and i >= 700 and i % 10 == 0: + traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask)) + + if return_traj: + return traj_buffer, sde.N * (n_steps + 1) + return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) + + return pc_sampler + +def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous): + """A wrapper that configures and returns the update function of predictors.""" + assert not continuous + score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False) + if predictor is None: + # Corrector-only sampler + predictor_obj = NonePredictor(sde, score_fn, probability_flow) + else: + predictor_obj = predictor(sde, score_fn, probability_flow) + return predictor_obj.update_fn(x, t, tprev) + +def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1, + denoise=False, eps=1e-3, device='cuda', grid_mask=None): + """Probability flow ODE sampler with the black-box ODE solver. + + Args: + sde: An `sde_lib.SDE` object that represents the forward SDE. + shape: A sequence of integers. The expected shape of a single sample. + inverse_scaler: The inverse data normalizer. + denoise: If `True`, add one-step denoising to final samples. + eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability. + device: PyTorch device. + + Returns: + A sampling function that returns samples and the number of function evaluations during sampling. + """ + + predictor_update_fn = functools.partial(ddim_predictor_update_fn, + sde=sde, + predictor=predictor, + probability_flow=False, + continuous=False) + + def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, + partial=None, partial_grid_mask=None, partial_channel=0): + """ The PC sampler funciton. + + Args: + model: A score model. + Returns: + Samples, number of function evaluations. 
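+
+    Note:
+      `schedule` selects the timestep subsequence ('uniform' strides evenly
+      through the `sde.N` steps, 'quad' uses quadratically spaced steps), `x0`
+      optionally replaces the prior sample as the starting point, and `partial`
+      re-imposes known values on channel `partial_channel` after every update.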
+ """ + with torch.no_grad(): + if x0 is not None: + x = x0 * grid_mask + else: + # Initial sample + x = sde.prior_sampling(shape).to(device) + x = x * grid_mask + + if partial is not None: + x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask + + timesteps = torch.linspace(sde.T, eps, sde.N, device=device) + + if schedule == 'uniform': + skip = sde.N // num_steps + seq = range(0, sde.N, skip) + elif schedule == 'quad': + seq = ( + np.linspace( + 0, np.sqrt(sde.N * 0.8), 100 + ) + ** 2 + ) + seq = [int(s) for s in list(seq)] + + timesteps = torch.tensor(seq) / sde.N + + for i in tqdm.tqdm(reversed(range(1, len(timesteps)))): + t = timesteps[i] + tprev = timesteps[i - 1] + vec_t = torch.ones(shape[0], device=t.device) * t + vec_tprev = torch.ones(shape[0], device=t.device) * tprev + x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev) + x, x0_pred = x * grid_mask, x0_pred * grid_mask + if partial is not None: + x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask + x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_mask) + partial * partial_mask + + return inverse_scaler(x0_pred * grid_mask if (denoise and not encode) else x * grid_mask), sde.N * (n_steps + 1) + return ddim_sampler diff --git a/lib/diffusion/sde_lib.py b/lib/diffusion/sde_lib.py new file mode 100644 index 0000000..37127ff --- /dev/null +++ b/lib/diffusion/sde_lib.py @@ -0,0 +1,233 @@ +"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" +import abc +import torch +import numpy as np +import torch.nn.functional as F +import time + + +class SDE(abc.ABC): + """SDE abstract class. Functions are designed for a mini-batch of inputs.""" + + def __init__(self, N): + """Construct an SDE. + + Args: + N: number of discretization time steps. + """ + super().__init__() + self.N = N + + @property + @abc.abstractmethod + def T(self): + """End time of the SDE.""" + pass + + @abc.abstractmethod + def sde(self, x, t): + pass + + @abc.abstractmethod + def marginal_prob(self, x, t): + """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" + pass + + @abc.abstractmethod + def prior_sampling(self, shape): + """Generate one sample from the prior distribution, $p_T(x)$.""" + pass + + @abc.abstractmethod + def prior_logp(self, z): + """Compute log-density of the prior distribution. + + Useful for computing the log-likelihood via probability flow ODE. + + Args: + z: latent code + Returns: + log probability density + """ + pass + + def discretize(self, x, t): + """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. + + Useful for reverse diffusion sampling and probabiliy flow sampling. + Defaults to Euler-Maruyama discretization. + + Args: + x: a torch tensor + t: a torch float representing the time step (from 0 to `self.T`) + + Returns: + f, G + """ + dt = 1 / self.N + drift, diffusion = self.sde(x, t) + f = drift * dt + G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) + return f, G + + def reverse(self, score_fn, probability_flow=False): + """Create the reverse-time SDE/ODE. + + Args: + score_fn: A time-dependent score-based model that takes x and t and returns the score. + probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 
+ """ + N = self.N + T = self.T + sde_fn = self.sde + discretize_fn = self.discretize + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_1m_alphas_cumprod = self.sqrt_1m_alphas_cumprod + + # Build the class for reverse-time SDE. + class RSDE(self.__class__): + def __init__(self): + self.N = N + self.probability_flow = probability_flow + + @property + def T(self): + return T + + def sde(self, x, t): + """Create the drift and diffusion functions for the reverse SDE/ODE.""" + drift, diffusion = sde_fn(x, t) + score = score_fn(x, t) + drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) + # Set the diffusion function to zero for ODEs. + diffusion = 0. if self.probability_flow else diffusion + return drift, diffusion + + def discretize(self, x, t): + """Create discretized iteration rules for the reverse diffusion sampler.""" + f, G = discretize_fn(x, t) + rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.) + rev_G = torch.zeros_like(G) if self.probability_flow else G + return rev_f, rev_G + + def discretize_ddim(self, x, t, tprev=None, encode=False): + """DDPM discretization.""" + timestep = (t * (N - 1) / T).long() + timestep_prev = (tprev * (N - 1) / T).long() + + score = score_fn(x.float(), t.float()) + + # alphas1prev_div_alphas1 = torch.exp(log_diff) + alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None] + alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None] + alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None] + alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None] + alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double() + alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double() + + + x0_pred_scaled = (x.double() - alphas2.double() * score.double()) + use_clip = False + if use_clip: + x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze()) + score_scaled_t = x - x0_pred_scaled + x0_pred = x0_pred_scaled / alphas1 + + x_new = ( + alphas1prev_div_alphas1.double() * x + + (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2.double()) * score_scaled_t.double() + ) + return x_new, x0_pred + + + def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=False): + """DDPM discretization.""" + timestep = (t * (N - 1) / T).long() + timestep_prev = (tprev * (N - 1) / T).long() + + score = score_fn(x.float(), t.float()) + + # alphas1prev_div_alphas1 = torch.exp(log_diff) + alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None] + alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None] + alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None] + alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double() + + x0_pred_scaled = (x.double() - alphas2.double() * score.double()) + x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze()) + x0_pred = x0_pred_scaled / alphas1 + + if condition is None: + condition_update = 0 + else: + if (t - 0.99).mean() < 1e-3: + x = x0_pred + condition_update = condition_func(x.float(), condition) + + x_new = ( + x - alphas1prev_div_alphas1.double() * condition_update + ) + return x_new, x0_pred + + + return RSDE() + + +class VPSDE(SDE): + def __init__(self, beta_min=0.1, beta_max=20, N=1000): + """Construct a Variance Preserving SDE. 
+ + Args: + beta_min: value of beta(0) + beta_max: value of beta(1) + N: number of discretization steps + """ + super().__init__(N) + self.beta_0 = beta_min + self.beta_1 = beta_max + self.N = N + self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda() + self.alphas = 1. - self.discrete_betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.alphas_cumprod_ext = torch.cat([torch.tensor([1.0 - 1e-4]).cuda(), torch.cumprod(self.alphas, dim=0)], dim=0) + + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) + + self.alphas_cumprod = self.alphas_cumprod + self.alphas_cumprod_ext = self.alphas_cumprod_ext + + @property + def T(self): + return 1 + + def sde(self, x, t): + beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) + drift = -0.5 * beta_t[:, None, None, None, None] * x + diffusion = torch.sqrt(beta_t) + return drift, diffusion + + def marginal_prob(self, x, t): + log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x + std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) + return mean, std + + def prior_sampling(self, shape): + return torch.randn(*shape) + + def prior_logp(self, z): + shape = z.shape + N = np.prod(shape[1:]) + logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2. + return logps + + def discretize(self, x, t): + """DDPM discretization.""" + timestep = (t * (self.N - 1) / self.T).long() + beta = self.discrete_betas.to(x.device)[timestep] + alpha = self.alphas.to(x.device)[timestep] + sqrt_beta = torch.sqrt(beta) + f = torch.sqrt(alpha)[:, None, None, None, None] * x - x + G = sqrt_beta + return f, G \ No newline at end of file diff --git a/lib/diffusion/trainer.py b/lib/diffusion/trainer.py new file mode 100644 index 0000000..9dc5c93 --- /dev/null +++ b/lib/diffusion/trainer.py @@ -0,0 +1,130 @@ +import os +import sys +import numpy as np + +import logging +# Keep the import below for registering all model definitions +from .models import ddpm_res64, ddpm_res128 + +from . import losses +from .models import utils as mutils +from .models.ema import ExponentialMovingAverage +from . import sde_lib +import torch +from torch.utils import tensorboard +from .utils import save_checkpoint, restore_checkpoint +from ..dataset.shapenet_dmtet_dataset import ShapeNetDMTetDataset + +def train(config): + """Runs the training pipeline. + + Args: + config: Configuration to use. + workdir: Working directory for checkpoints and TF summaries. If this + contains checkpoint training will be resumed from the latest checkpoint. + """ + + workdir = config.training.train_dir + # Create directories for experimental logs + logging.info("working dir: {:s}".format(workdir)) + + + tb_dir = os.path.join(workdir, "tensorboard") + writer = tensorboard.SummaryWriter(tb_dir) + + resolution = config.data.image_size + # Initialize model. 
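+ # Note: the `state` dict assembled below bundles the optimizer, the score network, its EMA shadow weights, and the global step, so a checkpoint can fully restore training.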
+ score_model = mutils.create_model(config) + ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) + optimizer = losses.get_optimizer(config, score_model.parameters()) + state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) + + + # Create checkpoints directory + checkpoint_dir = os.path.join(workdir, "checkpoints") + # Intermediate checkpoints to resume training after pre-emption in cloud environments + checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth") + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) + + # Resume training when intermediate checkpoints are detected + state = restore_checkpoint(checkpoint_meta_dir, state, config.device) + initial_step = int(state['step']) + + json_path = config.data.meta_path + print("----- Assigning mask -----") + logging.info(f"{json_path}, {config.data.filter_meta_path}") + + ### mask on tet to ignore regions + mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda") + + if hasattr(score_model.module, 'mask'): + print("----- Assigning mask -----") + score_model.module.mask.data[:] = mask[:] + + print(f"work dir: {workdir}") + + print("sdf normalized or not: ", config.data.normalize_sdf) + train_dataset = ShapeNetDMTetDataset(json_path, deform_scale=config.model.deform_scale, aug=True, grid_mask=mask, + filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf) + + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.training.batch_size, + shuffle=True, + num_workers=config.data.num_workers, + pin_memory=True) + + data_iter = iter(train_loader) + + print("data loader set") + + # Setup SDEs + sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + + # Build one-step training and evaluation functions + optimize_fn = losses.optimization_manager(config) + train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, + mask=mask, loss_type=config.training.loss_type) + + num_train_steps = config.training.n_iters + + # In case there are multiple hosts (e.g., TPU pods), only log to host 0 + logging.info("Starting training loop at step %d." 
% (initial_step // config.training.iter_size,)) + + iter_size = config.training.iter_size + for step in range(initial_step // iter_size, num_train_steps + 1): + tmp_loss = 0.0 + for step_inner in range(iter_size): + try: + # batch, batch_mask = next(data_iter) + batch = next(data_iter) + except StopIteration: + # StopIteration is thrown if dataset ends + # reinitialize data loader + data_iter = iter(train_loader) + batch = next(data_iter) + + batch = batch.cuda() + + # Execute one training step + clear_grad_flag = (step_inner == 0) + update_param_flag = (step_inner == iter_size - 1) + loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag) + loss = loss_dict['loss'] + tmp_loss += loss.item() + + tmp_loss /= iter_size + if step % config.training.log_freq == 0: + logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss)) + sys.stdout.flush() + writer.add_scalar("training_loss", tmp_loss, step) + + # Save a temporary checkpoint to resume training after pre-emption periodically + if step != 0 and step % config.training.snapshot_freq_for_preemption == 0: + logging.info(f"save meta at iter {step}") + save_checkpoint(checkpoint_meta_dir, state) + + # Save a checkpoint periodically and generate samples if needed + if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps: + logging.info(f"save model: {step}-th") + save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state) diff --git a/lib/diffusion/utils.py b/lib/diffusion/utils.py new file mode 100644 index 0000000..8cac065 --- /dev/null +++ b/lib/diffusion/utils.py @@ -0,0 +1,31 @@ +import torch +import tensorflow as tf +import os +import logging + + +def restore_checkpoint(ckpt_dir, state, device, strict=False): + if not tf.io.gfile.exists(ckpt_dir): + tf.io.gfile.makedirs(os.path.dirname(ckpt_dir)) + logging.warning(f"No checkpoint found at {ckpt_dir}. " + f"Returned the same state as input") + if strict: + raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}.") + return state + else: + loaded_state = torch.load(ckpt_dir, map_location=device) + state['optimizer'].load_state_dict(loaded_state['optimizer']) + state['model'].load_state_dict(loaded_state['model'], strict=False) + state['ema'].load_state_dict(loaded_state['ema']) + state['step'] = loaded_state['step'] + return state + + +def save_checkpoint(ckpt_dir, state): + saved_state = { + 'optimizer': state['optimizer'].state_dict(), + 'model': state['model'].state_dict(), + 'ema': state['ema'].state_dict(), + 'step': state['step'] + } + torch.save(saved_state, ckpt_dir) \ No newline at end of file diff --git a/nvdiffrec/lib/dataset/dataset.py b/nvdiffrec/lib/dataset/dataset.py new file mode 100644 index 0000000..5eddffe --- /dev/null +++ b/nvdiffrec/lib/dataset/dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited.
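+ +# Base dataset interface for the nvdiffrec data loaders. `collate` stacks per-view render buffers +# (mv/mvp/campos, images, and any optional depth/normal/mask/point buffers) into one batched dict.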
+ +import torch + +class Dataset(torch.utils.data.Dataset): + """Basic dataset interface""" + def __init__(self): + super().__init__() + + def __len__(self): + raise NotImplementedError + + def __getitem__(self): + raise NotImplementedError + + def collate(self, batch): + iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp'] + res_dict = { + 'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0), + 'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0), + 'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0), + 'resolution' : iter_res, + 'spp' : iter_spp, + 'img' : torch.cat(list([item['img'] for item in batch]), dim=0) + } + + if 'spts' in batch[0]: + res_dict['spts'] = batch[0]['spts'] + if 'vpts' in batch[0]: + res_dict['vpts'] = batch[0]['vpts'] + if 'faces' in batch[0]: + res_dict['faces'] = batch[0]['faces'] + if 'rast_triangle_id' in batch[0]: + res_dict['rast_triangle_id'] = batch[0]['rast_triangle_id'] + + if 'depth' in batch[0]: + res_dict['depth'] = torch.cat(list([item['depth'] for item in batch]), dim=0) + if 'normal' in batch[0]: + res_dict['normal'] = torch.cat(list([item['normal'] for item in batch]), dim=0) + if 'geo_normal' in batch[0]: + res_dict['geo_normal'] = torch.cat(list([item['geo_normal'] for item in batch]), dim=0) + if 'geo_viewdir' in batch[0]: + res_dict['geo_viewdir'] = torch.cat(list([item['geo_viewdir'] for item in batch]), dim=0) + if 'pos' in batch[0]: + res_dict['pos'] = torch.cat(list([item['pos'] for item in batch]), dim=0) + if 'mask' in batch[0]: + res_dict['mask'] = torch.cat(list([item['mask'] for item in batch]), dim=0) + if 'mask_cont' in batch[0]: + res_dict['mask_cont'] = torch.cat(list([item['mask_cont'] for item in batch]), dim=0) + if 'envlight_transform' in batch[0]: + if batch[0]['envlight_transform'] is not None: + res_dict['envlight_transform'] = torch.cat(list([item['envlight_transform'] for item in batch]), dim=0) + else: + res_dict['envlight_transform'] = None + + try: + res_dict['depth_second'] = torch.cat(list([item['depth_second'] for item in batch]), dim=0) + except: + pass + try: + res_dict['normal_second'] = torch.cat(list([item['normal_second'] for item in batch]), dim=0) + except: + pass + try: + res_dict['img_second'] = torch.cat(list([item['img_second'] for item in batch]), dim=0) + except: + pass + + + return res_dict \ No newline at end of file diff --git a/nvdiffrec/lib/dataset/dataset_mesh.py b/nvdiffrec/lib/dataset/dataset_mesh.py new file mode 100644 index 0000000..6a73fa1 --- /dev/null +++ b/nvdiffrec/lib/dataset/dataset_mesh.py @@ -0,0 +1,163 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
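+ +# DatasetMesh renders a reference mesh under random cameras (and optionally randomly rotated +# environment lighting) to produce image, depth, normal, and mask targets, plus sampled surface +# points, used as multi-view supervision.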
+ +import numpy as np +import torch +import sys + +from ..render import util +from ..render import mesh +from ..render import render +from ..render import light + +from .dataset import Dataset + +import kaolin + +############################################################################### +# Reference dataset using mesh & rendering +############################################################################### + +class DatasetMesh(Dataset): + + def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False): + # Init + self.glctx = glctx + self.cam_radius = cam_radius + self.FLAGS = FLAGS + self.validate = validate + self.fovy = np.deg2rad(45) + self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0] + self.random_lgt = FLAGS.random_lgt + self.camera_lgt = False + self.flat_shading = FLAGS.dataset_flat_shading + + + if self.FLAGS.local_rank == 0: + print(f"use flat shading {FLAGS.dataset_flat_shading}") + print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0])) + + # Sanity test training texture resolution + ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes()) + if 'normal' in ref_mesh.material: + ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes()) + if self.FLAGS.local_rank == 0 and (FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]): + print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1])) + + print("Loading env map") + sys.stdout.flush() + # Load environment map texture + self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale) + + print("Computing tangents") + sys.stdout.flush() + try: + self.ref_mesh = mesh.compute_tangents(ref_mesh) + except Exception as e: + print(e) + print("Continue without tangents...") + self.ref_mesh = ref_mesh + + def _rotate_scene(self, itr): + proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) + + # Smooth rotation for display. + ang = (itr / 50) * np.pi * 2 + mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang)) + mvp = proj_mtx @ mv + campos = torch.linalg.inv(mv)[:3, 3] + + return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp + + def _random_scene(self): + # ============================================================================================== + # Setup projection matrix + # ============================================================================================== + iter_res = self.FLAGS.train_res + proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) + + # ============================================================================================== + # Random camera & light position + # ============================================================================================== + + # Random rotation/translation matrix for optimization.
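+ # The camera sits at a fixed radius and is perturbed by a random rotation plus a small (0.2) translation; the camera position is then recovered from the inverse model-view matrix.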
+ mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.2) + mvp = proj_mtx @ mv + campos = torch.linalg.inv(mv)[:3, 3] + + return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp # Add batch dimension + + def __len__(self): + return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch + + def __getitem__(self, itr): + # ============================================================================================== + # Randomize scene parameters + # ============================================================================================== + + if self.validate: + mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr) + camera_mv = None + else: + mv, mvp, campos, iter_res, iter_spp = self._random_scene() + if self.random_lgt: + rnd_rot = util.random_rotation() + camera_mv = rnd_rot.unsqueeze(0).clone() + elif self.camera_lgt: + camera_mv = mv.clone() + else: + camera_mv = None + + + + with torch.no_grad(): + render_out = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp, + num_layers=self.FLAGS.layers, msaa=True, background=None, xfm_lgt=camera_mv, flat_shading=self.flat_shading) + img = render_out['shaded'] + img_second = render_out['shaded_second'] + normal = render_out['normal'] + depth = render_out['depth'] + geo_normal = render_out['geo_normal'] + pos = render_out['pos'] + + sample_points = torch.tensor(kaolin.ops.mesh.sample_points(self.ref_mesh.v_pos.unsqueeze(0), self.ref_mesh.t_pos_idx, 50000)[0][0]) + vertex_points = self.ref_mesh.v_pos + + return_dict = { + 'mv' : mv, + 'mvp' : mvp, + 'campos' : campos, + 'resolution' : iter_res, + 'spp' : iter_spp, + 'img' : img, + 'img_second' : img_second, + 'spts': sample_points, + 'vpts': vertex_points, + 'faces': self.ref_mesh.t_pos_idx, + 'depth': depth, + 'normal': normal, + 'geo_normal': geo_normal, + 'geo_viewdir': render_out['geo_viewdir'], + 'pos': pos, + 'envlight_transform': camera_mv, + 'mask': render_out['mask'], + 'mask_cont': render_out['mask_cont'], + 'rast_triangle_id': render_out['rast_triangle_id'] + } + + try: + return_dict['depth_second'] = render_out['depth_second'] + except: + pass + + try: + return_dict['normal_second'] = render_out['normal_second'] + except: + pass + return return_dict diff --git a/nvdiffrec/lib/dataset/dataset_shapenet.py b/nvdiffrec/lib/dataset/dataset_shapenet.py new file mode 100644 index 0000000..b3b86f5 --- /dev/null +++ b/nvdiffrec/lib/dataset/dataset_shapenet.py @@ -0,0 +1,32 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
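+ +# Thin wrapper over a JSON list of ShapeNet mesh entries; __getitem__ simply returns the stored +# record so the mesh itself can be loaded elsewhere.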
+ +import numpy as np +import torch + +import os +import json + +############################################################################### +# ShapeNet Dataset to get meshes +############################################################################### + +class ShapeNetDataset(object): + def __init__(self, shapenet_json): + with open(shapenet_json, 'r') as f: + self.mesh_list = json.load(f) + + print(f"len all data: {len(self.mesh_list)}") + + def __len__(self): + return len(self.mesh_list) + + + def __getitem__(self, idx): + return self.mesh_list[idx] diff --git a/nvdiffrec/lib/geometry/dmtet.py b/nvdiffrec/lib/geometry/dmtet.py new file mode 100644 index 0000000..1ee43f4 --- /dev/null +++ b/nvdiffrec/lib/geometry/dmtet.py @@ -0,0 +1,462 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from ..render import mesh +from ..render import render +from ..render import regularizer + + +import kaolin +import pytorch3d.ops +from ..render import util as render_utils + +import torch.nn.functional as F + +from ..render import renderutils as ru + +############################################################################### +# Marching tetrahedrons implementation (differentiable), adapted from +# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py +############################################################################### + +class DMTet: + def __init__(self): + self.triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [ 1, 0, 2, -1, -1, -1], + [ 4, 0, 3, -1, -1, -1], + [ 1, 4, 2, 1, 3, 4], + [ 3, 1, 5, -1, -1, -1], + [ 2, 3, 0, 2, 5, 3], + [ 1, 4, 0, 1, 5, 4], + [ 4, 2, 5, -1, -1, -1], + [ 4, 5, 2, -1, -1, -1], + [ 4, 1, 0, 4, 5, 1], + [ 3, 2, 0, 3, 5, 2], + [ 1, 3, 5, -1, -1, -1], + [ 4, 1, 2, 4, 3, 1], + [ 3, 0, 4, -1, -1, -1], + [ 2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device='cuda') + + self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda') + self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda') + + ############################################################################### + # Utility functions + ############################################################################### + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:,0] > edges_ex2[:,1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1-order, dim=1) + + return torch.stack([a, b],-1) + + def map_uv(self, faces, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + indexing='ij' + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = 
tet_idx % N + y = torch.div(tet_idx, N, rounding_mode='trunc') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). view(-1, 3) + + return uvs, uv_idx + + ############################################################################### + # Marching tets implementation + ############################################################################### + + def __call__(self, pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum>0) & (occ_sum<4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda") + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1) + edges_to_interp_sdf[:,-1] *= -1 + + denominator = edges_to_interp_sdf.sum(1,keepdim = True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1,6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat(( + torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3), + torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3), + ), dim=0) + + # Get global face index (static, does not depend on topology) + num_tets = tet_fx4.shape[0] + tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets] + face_gidx = torch.cat(( + tet_gidx[num_triangles == 1]*2, + torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) + ), dim=0) + + uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2) + + face_to_valid_tet = torch.cat(( + tet_gidx[num_triangles == 1], + torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1) + ), dim=0) + + valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique() + + return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) + mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ + 
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) + return sdf_diff + + + +class Buffer(object): + def __init__(self, shape, capacity, device) -> None: + self.len_curr = 0 + self.pointer = 0 + self.capacity = capacity + self.buffer = torch.zeros((capacity, ) + shape, device=device) + + def push(self, x): + ''' + Push one single data point into the buffer + ''' + self.buffer[self.pointer] = x + self.pointer = (self.pointer + 1) % self.capacity + if self.len_curr < self.capacity: + self.len_curr += 1 + + def avg(self): + # simple windowed avg without exp decay + return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float() + +############################################################################### +# Geometry interface +############################################################################### + +class DMTetGeometry(torch.nn.Module): + def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=1.0, **kwargs): + super(DMTetGeometry, self).__init__() + + self.FLAGS = FLAGS + self.grid_res = grid_res + self.marching_tets = DMTet() + self.tanh = False + self.deform_scale = deform_scale + + self.grid_to_tet = grid_to_tet + + self.padding = 5 + self.smooth_kernel = torch.ones(1, 1, self.padding*2 + 1, self.padding*2 + 1).cuda() + + tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res))) + self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.generate_edges() + + # Random init + sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1 + + self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True) + self.register_parameter('sdf', self.sdf) + + self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) + self.register_parameter('deform', self.deform) + + self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False) + self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False) + + self.ema_coeff = 0.9 + self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda') + + + def generate_edges(self): + with torch.no_grad(): + edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda") + all_edges = self.indices[:,edges].reshape(-1,2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def getVertNNDist(self): + v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0) + return (pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2).dists[0, :, -1].detach()) ## K=2 because dist(self, self)=0 + + def getTetCenters(self): + v_deformed = self.get_deformed() # size: N x 3 + face_verts = v_deformed[self.indices] # size: M x 4 x 3 + face_centers = face_verts.mean(dim=1) # size: M x 3 + + return face_centers + + def getValidTetIdx(self): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices) + return tet_gidx.long() + + def getValidVertsIdx(self): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, 
self.indices) + return self.indices[tet_gidx.long()].unique() + + def getMesh(self, material, noise=0.0, ema=False): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed(ema=ema) + + if ema: + # sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff + sdf = self.sdf_ema + else: + sdf = self.sdf + + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + imesh = mesh.auto_normals(imesh) + if material is not None: + # Run mesh operations to generate tangent space + imesh = mesh.compute_tangents(imesh) + imesh.valid_vert_idx = valid_vert_idx + + return imesh + + def get_deformed(self, no_grad=False, ema=False): + if no_grad: + deform = self.deform.detach() + else: + deform = self.deform + + if self.tanh: + # v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) + v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale + else: + v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale + return v_deformed + + def get_angle(self): + with torch.no_grad(): + comb_list = [ + (0, 1, 2, 3), + (0, 1, 3, 2), + (0, 2, 3, 1), + (1, 2, 3, 0) + ] + + directions = torch.zeros(self.indices.size(0), 4).cuda() + dir_vec = torch.zeros(self.indices.size(0), 4, 3).cuda() + vert_inds = torch.zeros(self.indices.size(0), 4).cuda().long() + count = 0 + vpos_list = self.get_deformed() + for comb in comb_list: + face = self.indices[:, comb[:3]] + face_pos = vpos_list[face, :] + face_center = face_pos.mean(1, keepdim=False) + v = self.indices[:, comb[3]] + test_vec = vpos_list[v] + ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center) + distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec + directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0]) + dir_vec[:, count, :] = distance_vec + vert_inds[:, count] = v + count += 1 + return directions, dir_vec, vert_inds + + + def clamp_deform(self): + if not self.tanh: + self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99) + self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0) + + def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False): + opt_mesh = self.getMesh(opt_material, ema=ema) + tet_centers = self.getTetCenters() if get_visible_tets else None + return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers) + + def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None): + opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema) + return opt_mesh, render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt) + + def update_ema(self, ema_coeff=0.9): + self.sdf_buffer.push(self.sdf) + self.sdf_ema.data[:] = self.sdf_buffer.avg() + self.deform_ema.data[:] = self.deform.data[:] + + + def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None): + opt_mesh = self.getMesh(opt_material, ema=True) + return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], 
bsdf=bsdf, xfm_lgt=xfm_lgt) + + def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True): + + self.deform.requires_grad = True + + if iteration > 200 and iteration < 2000 and iteration % 20 == 0: + with torch.no_grad(): + v_pos = self.get_deformed() + v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp']) + v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:] + v_pos_camera_discrete = torch.round((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long() + mask_cont = F.conv2d(target['mask_cont'][:, :, :, 0].unsqueeze(1), self.smooth_kernel, stride=1, padding=self.padding)[:, 0] + target_mask = mask_cont == 0 + for k in range(target_mask.size(0)): + assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0] + v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0)) + self.sdf.data[v_mask] = 1e-2 + self.deform.data[v_mask] = 0.0 + + # ============================================================================================== + # Render optimizable object with identical conditions + # ============================================================================================== + imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt) + + # ============================================================================================== + # Compute loss + # ============================================================================================== + t_iter = iteration / self.FLAGS.iter + + # Image-space loss, split into a coverage component and a color component + color_ref = target['img'] + img_loss = torch.tensor(0.0).cuda() + alpha_scale = 1.0 + img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) * alpha_scale + img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:]) + + + color_ref_second = target['img_second'] + img_loss = img_loss + torch.nn.functional.mse_loss(buffers['shaded_second'][..., 3:], color_ref_second[..., 3:]) * alpha_scale * 1e-1 + img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_second[..., 3:], color_ref_second[..., 0:3] * color_ref_second[..., 3:]) * 1e-1 + + mask = (target['mask_cont'][:, :, :, 0] == 1.0).float() + + if iteration < 10000: + depth_scale = 100.0 + else: + depth_scale = 1.0 + + if iteration % 300 == 0 and iteration < 1790: + self.deform.data[:] *= 0.4 + + if no_depth_thin: + valid_depth_mask = (target['depth_second'] >= 0).float().detach() + depth_prox_mask = ((target['depth_second'] - target['depth']).abs() >= 5e-3).float().detach() + else: + valid_depth_mask = 1.0 + + depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask + depth_diff_second = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * depth_prox_mask * 1e-1 + + thres = 1.0 + l1_loss_mask = (depth_diff < thres).float() + l1_loss_mask_second = (depth_diff_second < thres).float() + + img_loss = img_loss + ( + ( + l1_loss_mask * depth_diff + + (1 - l1_loss_mask) * (depth_diff.pow(2) + thres - thres**2) + ).mean() * 1.0 * depth_scale + + ( + l1_loss_mask_second * depth_diff_second + + (1 - l1_loss_mask_second) * (depth_diff_second.pow(2) + thres - thres**2) + ).mean() * 1.0 * 
depth_scale + ) + + + reg_loss = torch.tensor(0.0).cuda() + + # SDF regularizer + iter_thres = 0 + sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres))) + + sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device) + sdf_mask[imesh.valid_vert_idx] = 1.0 + sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask) + reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 0.1 # Dropoff to 0.01 + + # Albedo (k_d) smoothnesss regularizer + reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500) + + # Visibility regularizer + reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500) + + # pointcloud chamfer distance + pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0] + target_pts = target['spts'] + chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean() + + reg_loss += chamfer + + + return img_loss, reg_loss diff --git a/nvdiffrec/lib/geometry/dmtet_fixedtopo.py b/nvdiffrec/lib/geometry/dmtet_fixedtopo.py new file mode 100644 index 0000000..873d31a --- /dev/null +++ b/nvdiffrec/lib/geometry/dmtet_fixedtopo.py @@ -0,0 +1,350 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
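+ +# Fixed-topology DMTet: the SDF sign pattern is copied from a fitted DMTetGeometry and frozen +# (|sdf| is held at 1), so only the per-vertex deformations remain optimizable and the extracted +# mesh connectivity stays constant during refinement.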
+ +import numpy as np +import torch + +from ..render import mesh +from ..render import render +from ..render import regularizer + +import kaolin +from ..render import util as render_utils +import torch.nn.functional as F + +############################################################################### +# Marching tetrahedrons implementation (differentiable), adapted from +# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py +############################################################################### + +class DMTet: + def __init__(self): + self.triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [ 1, 0, 2, -1, -1, -1], + [ 4, 0, 3, -1, -1, -1], + [ 1, 4, 2, 1, 3, 4], + [ 3, 1, 5, -1, -1, -1], + [ 2, 3, 0, 2, 5, 3], + [ 1, 4, 0, 1, 5, 4], + [ 4, 2, 5, -1, -1, -1], + [ 4, 5, 2, -1, -1, -1], + [ 4, 1, 0, 4, 5, 1], + [ 3, 2, 0, 3, 5, 2], + [ 1, 3, 5, -1, -1, -1], + [ 4, 1, 2, 4, 3, 1], + [ 3, 0, 4, -1, -1, -1], + [ 2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device='cuda') + + self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda') + self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda') + + ############################################################################### + # Utility functions + ############################################################################### + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:,0] > edges_ex2[:,1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1-order, dim=1) + + return torch.stack([a, b],-1) + + def map_uv(self, faces, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + indexing='ij' + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = tet_idx % N + y = torch.div(tet_idx, N, rounding_mode='trunc') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). 
view(-1, 3) + + return uvs, uv_idx + + ############################################################################### + # Marching tets implementation + ############################################################################### + + def __call__(self, pos_nx3, sdf_n, tet_fx4, get_tet_gidx=False): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum>0) & (occ_sum<4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda") + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1) + edges_to_interp_sdf[:,-1] *= -1 + + denominator = edges_to_interp_sdf.sum(1,keepdim = True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1,6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat(( + torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3), + torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3), + ), dim=0) + + # Get global face index (static, does not depend on topology) + num_tets = tet_fx4.shape[0] + tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets] + face_gidx = torch.cat(( + tet_gidx[num_triangles == 1]*2, + torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) + ), dim=0) + + uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2) + + if get_tet_gidx: + face_to_valid_tet = torch.cat(( + tet_gidx[num_triangles == 1], + torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1) + ), dim=0) + + return verts, faces, uvs, uv_idx, face_to_valid_tet.long() + else: + return verts, faces, uvs, uv_idx + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) + mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) + return sdf_diff + +############################################################################### +# Geometry interface +############################################################################### + +class 
DMTetGeometryFixedTopo(torch.nn.Module): + def __init__(self, dmt_geometry, base_mesh, grid_res, scale, FLAGS, deform_scale=1.0, **kwargs): + super(DMTetGeometryFixedTopo, self).__init__() + + self.FLAGS = FLAGS + self.grid_res = grid_res + self.marching_tets = DMTet() + self.initial_guess = base_mesh + self.scale = scale + self.tanh = False + self.deform_scale = deform_scale + + tets = np.load('./data/tets/{}_tets_cropped.npz'.format(self.grid_res)) + + self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.generate_edges() + + self.sdf_sign = torch.nn.Parameter(torch.sign(dmt_geometry.sdf.data + 1e-8).float(), requires_grad=False) + self.sdf_sign.data[self.sdf_sign.data == 0] = 1.0 ## Avoid abiguity + self.register_parameter('sdf_sign', self.sdf_sign) + + self.sdf_abs = torch.nn.Parameter(torch.ones_like(dmt_geometry.sdf), requires_grad=False) + self.register_parameter('sdf_abs', self.sdf_abs) + + self.deform = torch.nn.Parameter(dmt_geometry.deform.data, requires_grad=True) + self.register_parameter('deform', self.deform) + + self.sdf_abs_ema = torch.nn.Parameter(self.sdf_abs.clone().detach(), requires_grad=False) + self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False) + + def set_init_v_pos(self): + with torch.no_grad(): + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices) + self.initial_guess_v_pos = verts + + + def generate_edges(self): + with torch.no_grad(): + edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda") + all_edges = self.indices[:,edges].reshape(-1,2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def getVertNNDist(self): + raise NotImplementedError + v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0) + return (pytorch3d.ops.knn.knn_points(v_deformed, v_deformed, K=2).dists[0, :, -1].detach()) ## K=2 because dist(self, self)=0 + + def getMesh(self, material): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return imesh + + def getMesh_tet_gidx(self, material): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets( + v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return imesh, tet_gidx + + + def update_ema(self, ema_coeff=0.9): + return + + def get_deformed(self): + if self.tanh: + v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) * self.deform_scale + else: + v_deformed = self.verts + 2 / (self.grid_res * 2) * self.deform * self.deform_scale + return v_deformed + + def getValidTetIdx(self): + # Run DM tet to get a 
base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets( + v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True) + return tet_gidx.long() + + def getValidVertsIdx(self): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx = self.marching_tets( + v_deformed, self.sdf_sign * self.sdf_abs.abs(), self.indices, get_tet_gidx=True) + return self.indices[tet_gidx.long()].unique() + + def getTetCenters(self): + v_deformed = self.get_deformed() # size: N x 3 + face_verts = v_deformed[self.indices] # size: M x 4 x 3 + face_centers = face_verts.mean(dim=1) # size: M x 3 + + return face_centers + + def clamp_deform(self): + if not self.tanh: + self.deform.data[:] = self.deform.data.clamp(-0.99, 0.99) + + def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False): + opt_mesh = self.getMesh(opt_material) + tet_centers = self.getTetCenters() if get_visible_tets else None + return render.render_mesh( + glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers) + + + def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None): + opt_mesh = self.getMesh(opt_material) + return opt_mesh, render.render_mesh( + glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt) + + def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True): + + # ============================================================================================== + # Render optimizable object with identical conditions + # ============================================================================================== + imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, xfm_lgt=xfm_lgt) + + # ============================================================================================== + # Compute loss + # ============================================================================================== + t_iter = iteration / self.FLAGS.iter + + # Image-space loss, split into a coverage component and a color component + color_ref = target['img'] + img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) + img_loss = img_loss + loss_fn( + buffers['shaded'][..., 0:3] * color_ref[..., 3:], + color_ref[..., 0:3] * color_ref[..., 3:] + ) + mask = target['mask'][:, :, :, 0] + + + if no_depth_thin: + valid_depth_mask = ( + (target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float() + ).detach() + else: + valid_depth_mask = 1.0 + + depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask + depth_diff = (buffers['depth_second'][:, :, :, :1] - target['depth_second'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask * 1e-1 + + l1_loss_mask = (depth_diff < 1.0).float() + img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0 + + reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda") + + # Compute regularizer. 
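+ # Laplacian term (annealed by (1 - t_iter)) keeps vertices close to the initial marching-tets solution; a chamfer distance to the sampled target points is added below.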
+ reg_loss += regularizer.laplace_regularizer_const(imesh.v_pos - self.initial_guess_v_pos, imesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter) * 1e-2 + + ### Chamfer distance for ShapeNet + pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0] + target_pts = target['spts'] + chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean() + reg_loss += chamfer + + return img_loss, reg_loss \ No newline at end of file diff --git a/nvdiffrec/lib/geometry/dmtet_singleview.py b/nvdiffrec/lib/geometry/dmtet_singleview.py new file mode 100644 index 0000000..3d9af5a --- /dev/null +++ b/nvdiffrec/lib/geometry/dmtet_singleview.py @@ -0,0 +1,516 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from ..render import mesh +from ..render import render +from ..render import regularizer + + +import kaolin +import pytorch3d.ops +from ..render import util as render_utils + +import torch.nn.functional as F + +from ..render import renderutils as ru + +############################################################################### +# Marching tetrahedrons implementation (differentiable), adapted from +# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py +############################################################################### + +class DMTet: + def __init__(self): + self.triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [ 1, 0, 2, -1, -1, -1], + [ 4, 0, 3, -1, -1, -1], + [ 1, 4, 2, 1, 3, 4], + [ 3, 1, 5, -1, -1, -1], + [ 2, 3, 0, 2, 5, 3], + [ 1, 4, 0, 1, 5, 4], + [ 4, 2, 5, -1, -1, -1], + [ 4, 5, 2, -1, -1, -1], + [ 4, 1, 0, 4, 5, 1], + [ 3, 2, 0, 3, 5, 2], + [ 1, 3, 5, -1, -1, -1], + [ 4, 1, 2, 4, 3, 1], + [ 3, 0, 4, -1, -1, -1], + [ 2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device='cuda') + + self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda') + self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda') + + ############################################################################### + # Utility functions + ############################################################################### + + def sort_edges(self, edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:,0] > edges_ex2[:,1]).long() + order = order.unsqueeze(dim=1) + + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1-order, dim=1) + + return torch.stack([a, b],-1) + + def map_uv(self, faces, face_gidx, max_idx): + N = int(np.ceil(np.sqrt((max_idx+1)//2))) + tex_y, tex_x = torch.meshgrid( + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"), + indexing='ij' + ) + + pad = 0.9 / N + + uvs = torch.stack([ + tex_x , tex_y, + tex_x + pad, tex_y, + tex_x + pad, tex_y + pad, + tex_x , tex_y + pad + ], dim=-1).view(-1, 2) + + def _idx(tet_idx, N): + x = tet_idx % N 
+ y = torch.div(tet_idx, N, rounding_mode='trunc') + return y * N + x + + tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N) + tri_idx = face_gidx % 2 + + uv_idx = torch.stack(( + tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 + ), dim = -1). view(-1, 3) + + return uvs, uv_idx + + ############################################################################### + # Marching tets implementation + ############################################################################### + + def __call__(self, pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum>0) & (occ_sum<4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2) + all_edges = self.sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda") + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1) + edges_to_interp_sdf[:,-1] *= -1 + + denominator = edges_to_interp_sdf.sum(1,keepdim = True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + idx_map = idx_map.reshape(-1,6) + + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat(( + torch.gather( + input=idx_map[num_triangles == 1], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 1]][:, :3] + ).reshape(-1,3), + torch.gather( + input=idx_map[num_triangles == 2], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 2]][:, :6] + ).reshape(-1,3), + ), dim=0) + + # Get global face index (static, does not depend on topology) + num_tets = tet_fx4.shape[0] + tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets] + face_gidx = torch.cat(( + tet_gidx[num_triangles == 1]*2, + torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) + ), dim=0) + + uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2) + + + face_to_valid_tet = torch.cat(( + tet_gidx[num_triangles == 1], + torch.stack((tet_gidx[num_triangles == 2], tet_gidx[num_triangles == 2]), dim=-1).view(-1) + ), dim=0) + + valid_vert_idx = tet_fx4[tet_gidx[num_triangles > 0]].long().unique() + + return verts, faces, uvs, uv_idx, face_to_valid_tet.long(), valid_vert_idx + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) + mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ + 
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) + return sdf_diff + + + +class Buffer(object): + def __init__(self, shape, capacity, device) -> None: + self.len_curr = 0 + self.pointer = 0 + self.capacity = capacity + self.buffer = torch.zeros((capacity, ) + shape, device=device) + + def push(self, x): + ''' + Push one single data point into the buffer + ''' + self.buffer[self.pointer] = x + self.pointer = (self.pointer + 1) % self.capacity + if self.len_curr < self.capacity: + self.len_curr += 1 + + def avg(self): + return torch.sign(torch.sign(self.buffer[:self.len_curr]).float().mean(dim=0)).float() + # return self.buffer[:self.len_curr].mean(dim=0) + # return self.buffer[:self.len_curr][-1] + +############################################################################### +# Geometry interface +############################################################################### + +class DMTetGeometry(torch.nn.Module): + def __init__(self, grid_res, scale, FLAGS, root='./', grid_to_tet=None, deform_scale=2.0, **kwargs): + super(DMTetGeometry, self).__init__() + + self.FLAGS = FLAGS + self.grid_res = grid_res + self.marching_tets = DMTet() + self.cropped = True + self.tanh = False + self.deform_scale = deform_scale + + self.grid_to_tet = grid_to_tet + + if self.cropped: + print("use cropped tets") + tets = np.load(os.path.join(root, 'data/tets/{}_tets_cropped.npz'.format(self.grid_res))) + else: + tets = np.load(os.path.join(root, 'data/tets/{}_tets.npz'.format(self.grid_res))) + print('tet min and max', tets['vertices'].min() * scale, tets['vertices'].max() * scale) + self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.generate_edges() + + # Random init + sdf = torch.rand_like(self.verts[:,0]).clamp(-1.0, 1.0) - 0.1 + # sdf = torch.sign(sdf) * 0.1 + # sdf = self.verts.pow(2).sum(dim=-1).sqrt() - 0.5 + + self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True) + self.register_parameter('sdf', self.sdf) + + self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) + self.register_parameter('deform', self.deform) + + self.alpha = None + + self.sdf_ema = torch.nn.Parameter(sdf.clone().detach(), requires_grad=False) + self.deform_ema = torch.nn.Parameter(self.deform.clone().detach(), requires_grad=False) + + # self.ema_coeff = 0.7 + self.ema_coeff = 0.9 + + self.sdf_buffer = Buffer(sdf.size(), capacity=200, device='cuda') + + def generate_edges(self): + with torch.no_grad(): + edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda") + all_edges = self.indices[:,edges].reshape(-1,2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def getVertNNDist(self): + v_deformed = (self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)).unsqueeze(0) + return pytorch3d.ops.knn.knn_points( + v_deformed, v_deformed, K=2 + ).dists[0, :, -1].detach() ## K=2 because dist(self, self)=0 + + def getTetCenters(self): + v_deformed = self.get_deformed() # size: N x 3 + face_verts = v_deformed[self.indices] # size: M x 4 x 3 + face_centers = face_verts.mean(dim=1) # size: M x 3 + + return face_centers + + def getValidTetIdx(self): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + 
verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices) + return tet_gidx.long() + + def getValidVertsIdx(self): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed() + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, self.sdf, self.indices) + return self.indices[tet_gidx.long()].unique() + + def getMesh(self, material, noise=0.0, ema=False): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed(ema=ema) + + if ema: + sdf = self.sdf_ema + else: + sdf = self.sdf + + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + if material is not None: + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + imesh.valid_vert_idx = valid_vert_idx + + return imesh + + def getMesh_no_deform(self, material, noise=0.0, ema=False): + # Run DM tet to get a base mesh + if ema: + # sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff + sdf = self.sdf_ema + else: + sdf = self.sdf + + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(self.verts, torch.sign(sdf), self.indices) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return imesh + + def getMesh_no_deform_gd(self, material, noise=0.0, ema=False): + # Run DM tet to get a base mesh + v_deformed = self.get_deformed(no_grad=True) + + + if ema: + # sdf = self.sdf * (1 - self.ema_coeff) + self.sdf_ema.detach() * self.ema_coeff + sdf = self.sdf_ema + else: + sdf = self.sdf + + verts, faces, uvs, uv_idx, tet_gidx, valid_vert_idx = self.marching_tets(v_deformed, sdf, self.indices) + imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) + + # Run mesh operations to generate tangent space + imesh = mesh.auto_normals(imesh) + imesh = mesh.compute_tangents(imesh) + + return imesh + + def get_deformed(self, no_grad=False, ema=False): + if no_grad: + deform = self.deform.detach() + else: + deform = self.deform + + if self.tanh: + v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(deform) * self.deform_scale + else: + v_deformed = self.verts + 2 / (self.grid_res * 2) * deform * self.deform_scale + return v_deformed + + def get_angle(self): + with torch.no_grad(): + comb_list = [ + (0, 1, 2, 3), + (0, 1, 3, 2), + (0, 2, 3, 1), + (1, 2, 3, 0) + ] + + directions = torch.zeros(self.indices.size(0), 4).cuda() + dir_vec = torch.zeros(self.indices.size(0), 4, 3).cuda() + vert_inds = torch.zeros(self.indices.size(0), 4).cuda().long() + count = 0 + vpos_list = self.get_deformed() + for comb in comb_list: + face = self.indices[:, comb[:3]] + face_pos = vpos_list[face, :] + face_center = face_pos.mean(1, keepdim=False) + v = self.indices[:, comb[3]] + test_vec = vpos_list[v] + ref_vec = render_utils.safe_normalize(vpos_list[face[:, 0]] - face_center) + distance_vec = test_vec - render_utils.dot(test_vec, ref_vec) * ref_vec + directions[:, count] = torch.sign(render_utils.dot(test_vec, distance_vec)[:, 0]) + dir_vec[:, count, :] = distance_vec + vert_inds[:, count] = v + count += 1 + return directions, dir_vec, vert_inds + + + def clamp_deform(self): + if not self.tanh: + self.deform.data[:] = 
self.deform.data.clamp(-0.99, 0.99) + self.sdf.data[:] = self.sdf.data.clamp(-1.0, 1.0) + + def render(self, glctx, target, lgt, opt_material, bsdf=None, ema=False, xfm_lgt=None, get_visible_tets=False): + opt_mesh = self.getMesh(opt_material, ema=ema) + tet_centers = self.getTetCenters() if get_visible_tets else None + return render.render_mesh( + glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt, tet_centers=tet_centers) + + def render_with_mesh(self, glctx, target, lgt, opt_material, bsdf=None, noise=0.0, ema=False, xfm_lgt=None): + opt_mesh = self.getMesh(opt_material, noise=noise, ema=ema) + return opt_mesh, render.render_mesh( + glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt) + + def update_ema(self, ema_coeff=0.9): + self.sdf_buffer.push(self.sdf) + self.sdf_ema.data[:] = self.sdf_buffer.avg() + # self.sdf_ema.data[:] = self.sdf.data[:] * (1 - ema_coeff) + self.sdf_ema.data[:] * ema_coeff + # self.deform_ema.data[:] = self.deform.data[:] * (1 - ema_coeff) + self.deform_ema.data[:] * ema_coeff + self.deform_ema.data[:] = self.deform.data[:] + + + def render_ema(self, glctx, target, lgt, opt_material, bsdf=None, xfm_lgt=None): + opt_mesh = self.getMesh(opt_material, ema=True) + return render.render_mesh( + glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], + msaa=True, background=target['background'], bsdf=bsdf, xfm_lgt=xfm_lgt) + + def init_with_gt_surface(self, gt_verts, surface_faces, campos): + with torch.no_grad(): + surface_face_verts = gt_verts[surface_faces] + surface_centers = surface_face_verts.mean(dim=1) + v_pos = self.get_deformed() + results = pytorch3d.ops.knn_points(v_pos[None, ...], surface_centers[None, ...]) + dists, nn_idx = results.dists, results.idx + displacement = (v_pos - surface_centers[nn_idx[0, :, 0]]) + view_dirs = campos - surface_centers + normals = torch.cross( + surface_face_verts[:, 0] - surface_face_verts[:, 1], surface_face_verts[:, 0] - surface_face_verts[:, 2]) + mask = ((normals * view_dirs).sum(dim=-1, keepdim=True) >= 0).float() + normals = normals * mask - normals * (1 - mask) + outside_verts_idx = ((displacement * normals[nn_idx[0, :, 0]]).sum(dim=-1) > 0) + self.sdf.data[outside_verts_idx] = 1.0 + + + def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, with_reg=True, xfm_lgt=None, no_depth_thin=True): + + if iteration < 100: + self.deform.requires_grad = False + self.deform_scale = 2.0 + else: + self.deform.requires_grad = True + self.deform_scale = 2.0 + + if iteration > 200 and iteration < 2000 and iteration % 20 == 0: + with torch.no_grad(): + v_pos = self.get_deformed() + v_pos_camera_homo = ru.xfm_points(v_pos[None, ...], target['mvp']) + v_pos_camera = v_pos_camera_homo[:, :, :2] / v_pos_camera_homo[:, :, -1:] + v_pos_camera_discrete = ((v_pos_camera * 0.5 + 0.5).clip(0, 1) * (target['resolution'][0] - 1)).long() + target_mask = target['mask_cont'][:, :, :, 0] == 0 + for k in range(target_mask.size(0)): + assert v_pos_camera_discrete[k].min() >= 0 and v_pos_camera_discrete[k].max() < target['resolution'][0] + v_mask = target_mask[k, v_pos_camera_discrete[k, :, 1], v_pos_camera_discrete[k, :, 0]].view(v_pos.size(0)) + # print(v_mask.sum()) + self.sdf.data[v_mask] = self.sdf.data[v_mask].abs().clamp(0.0, 1.0) + + # 
============================================================================================== + # Render optimizable object with identical conditions + # ============================================================================================== + imesh, buffers = self.render_with_mesh(glctx, target, lgt, opt_material, noise=0.0, xfm_lgt=xfm_lgt) + + # ============================================================================================== + # Compute loss + # ============================================================================================== + + # Image-space loss, split into a coverage component and a color component + color_ref = target['img'] + img_loss = torch.tensor(0.0).cuda() + img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) + img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:]) + mask = (target['mask_cont'][:, :, :, 0] == 1.0).float() + mask_curr = (buffers['mask_cont'][:, :, :, 0] == 1.0).float() + + if iteration % 300 == 0 and iteration < 1790: + self.deform.data[:] *= 0.4 + + if no_depth_thin: + valid_depth_mask = ( + (target['depth_second'] >= 0).float() * ((target['depth_second'] - target['depth']).abs() >= 5e-3).float() + ).detach() + else: + valid_depth_mask = 1.0 + + + depth_diff = (buffers['depth'][:, :, :, :1] - target['depth'][:, :, :, :1]).abs() * mask.unsqueeze(-1) * valid_depth_mask + l1_loss_mask = (depth_diff < 1.0).float() + img_loss = img_loss + (l1_loss_mask * depth_diff + (1 - l1_loss_mask) * depth_diff.pow(2)).mean() * 100.0 + + reg_loss = torch.tensor(0.0).cuda() + + # SDF regularizer + iter_thres = 0 + sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * ((iteration - iter_thres) / (self.FLAGS.iter - iter_thres))) + + sdf_mask = torch.zeros_like(self.sdf, device=self.sdf.device) + sdf_mask[imesh.valid_vert_idx] = 1.0 + sdf_masked = self.sdf.detach() * sdf_mask + self.sdf * (1 - sdf_mask) + reg_loss = sdf_reg_loss(sdf_masked, self.all_edges).mean() * sdf_weight * 2.5 # Dropoff to 0.01 + + + # Albedo (k_d) smoothnesss regularizer + reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500) + + # Visibility regularizer + reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 1e0 * min(1.0, iteration / 500) + + pred_points = kaolin.ops.mesh.sample_points(imesh.v_pos.unsqueeze(0), imesh.t_pos_idx, 50000)[0][0] + target_pts = target['spts'] + chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), target_pts.unsqueeze(0)).mean() + reg_loss += chamfer + + + return img_loss, reg_loss diff --git a/nvdiffrec/lib/geometry/utils.py b/nvdiffrec/lib/geometry/utils.py new file mode 100644 index 0000000..5be4f5e --- /dev/null +++ b/nvdiffrec/lib/geometry/utils.py @@ -0,0 +1,128 @@ +import torch + +def _base_sample_points_selected_faces(face_vertices, face_features=None): + """Base function to sample points over selected faces. + The coordinates of the face vertices are interpolated to generate new samples. + Args: + face_vertices (tuple of torch.Tensor): + Coordinates of vertices, corresponding to selected faces to sample from. + A tuple of 3 entries corresponding to each of the face vertices. + Each entry is a torch.Tensor of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`. 
+ face_features (tuple of torch.Tensor, Optional): + Features of face vertices, corresponding to selected faces to sample from. + A tuple of 3 entries corresponding to each of the face vertices. + Each entry is a torch.Tensor of shape + :math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`. + Returns: + (torch.Tensor, torch.Tensor): + Sampled point coordinates of shape :math:`(\\text{batch_size}, \\text{num_samples}, 3)`. + Sampled points interpolated features of shape + :math:`(\\text{batch_size}, \\text{num_samples}, \\text{feature_dim})`. + If `face_vertices_features` arg is not specified, the returned interpolated features are None. + """ + + face_vertices0, face_vertices1, face_vertices2 = face_vertices + + sampling_shape = tuple(int(d) for d in face_vertices0.shape[:-1]) + (1,) + # u is proximity to middle point between v1 and v2 against v0. + # v is proximity to v2 against v1. + # + # The probability density for u should be f_U(u) = 2u. + # However, torch.rand use a uniform (f_X(x) = x) distribution, + # so using torch.sqrt we make a change of variable to have the desired density + # f_Y(y) = f_X(y ^ 2) * |d(y ^ 2) / dy| = 2y + u = torch.sqrt(torch.rand(sampling_shape, + device=face_vertices0.device, + dtype=face_vertices0.dtype)) + + v = torch.rand(sampling_shape, + device=face_vertices0.device, + dtype=face_vertices0.dtype) + w0 = 1 - u + w1 = u * (1 - v) + w2 = u * v + + points = w0 * face_vertices0 + w1 * face_vertices1 + w2 * face_vertices2 + + features = None + if face_features is not None: + face_features0, face_features1, face_features2 = face_features + features = w0 * face_features0 + w1 * face_features1 + \ + w2 * face_features2 + + return points, features + +def sample_points(vertices, faces, num_samples, areas=None, face_features=None): + r"""Uniformly sample points over the surface of triangle meshes. + First face on which the point is sampled is randomly selected, + with the probability of selection being proportional to the area of the face. + then the coordinate on the face is uniformly sampled. + If ``face_features`` is defined for the mesh faces, + the sampled points will be returned with interpolated features as well, + otherwise, no feature interpolation will occur. + Args: + vertices (torch.Tensor): + The vertices of the meshes, of shape + :math:`(\text{batch_size}, \text{num_vertices}, 3)`. + faces (torch.LongTensor): + The faces of the mesh, of shape :math:`(\text{num_faces}, 3)`. + num_samples (int): + The number of point sampled per mesh. + areas (torch.Tensor, optional): + The areas of each face, of shape :math:`(\text{batch_size}, \text{num_faces})`, + can be preprocessed, for fast on-the-fly sampling, + will be computed if None (default). + face_features (torch.Tensor, optional): + Per-vertex-per-face features, matching ``faces`` order, + of shape :math:`(\text{batch_size}, \text{num_faces}, 3, \text{feature_dim})`. + For example: + 1. Texture uv coordinates would be of shape + :math:`(\text{batch_size}, \text{num_faces}, 3, 2)`. + 2. RGB color values would be of shape + :math:`(\text{batch_size}, \text{num_faces}, 3, 3)`. + When specified, it is used to interpolate the features for new sampled points. + See also: + :func:`~kaolin.ops.mesh.index_vertices_by_faces` for conversion of features defined per vertex + and need to be converted to per-vertex-per-face shape of :math:`(\text{num_faces}, 3)`. 
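+ Example:
+ A minimal illustrative call (sampling is stochastic, so only shapes are checked,
+ and ``areas`` is passed explicitly):
+
+ >>> vertices = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]]], device='cuda')
+ >>> faces = torch.tensor([[0, 1, 2], [0, 2, 3]], device='cuda')
+ >>> areas = torch.tensor([[0.5, 0.5]], device='cuda')
+ >>> points, face_idx = sample_points(vertices, faces, num_samples=1000, areas=areas)
+ >>> points.shape, face_idx.shape
+ (torch.Size([1, 1000, 3]), torch.Size([1, 1000]))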
+ Returns: + (torch.Tensor, torch.LongTensor, (optional) torch.Tensor): + the pointclouds of shape :math:`(\text{batch_size}, \text{num_samples}, 3)`, + and the indexes of the faces selected, + of shape :math:`(\text{batch_size}, \text{num_samples})`. + If ``face_features`` arg is specified, then the interpolated features of sampled points of shape + :math:`(\text{batch_size}, \text{num_samples}, \text{feature_dim})` are also returned. + """ + if faces.shape[-1] != 3: + raise NotImplementedError("sample_points is only implemented for triangle meshes") + faces_0, faces_1, faces_2 = torch.split(faces, 1, dim=1) # (num_faces, 3) -> tuple of (num_faces,) + face_v_0 = torch.index_select(vertices, 1, faces_0.reshape(-1)) # (batch_size, num_faces, 3) + face_v_1 = torch.index_select(vertices, 1, faces_1.reshape(-1)) # (batch_size, num_faces, 3) + face_v_2 = torch.index_select(vertices, 1, faces_2.reshape(-1)) # (batch_size, num_faces, 3) + + if areas is None: + areas = _base_face_areas(face_v_0, face_v_1, face_v_2).squeeze(-1) + face_dist = torch.distributions.Categorical(areas) + face_choices = face_dist.sample([num_samples]).transpose(0, 1) + _face_choices = face_choices.unsqueeze(-1).repeat(1, 1, 3) + v0 = torch.gather(face_v_0, 1, _face_choices) # (batch_size, num_samples, 3) + v1 = torch.gather(face_v_1, 1, _face_choices) # (batch_size, num_samples, 3) + v2 = torch.gather(face_v_2, 1, _face_choices) # (batch_size, num_samples, 3) + face_vertices_choices = (v0, v1, v2) + + # UV coordinates are available, make sure to calculate them for sampled points as well + face_features_choices = None + if face_features is not None: + feat_dim = face_features.shape[-1] + # (num_faces, 3) -> tuple of (num_faces,) + _face_choices = face_choices[..., None, None].repeat(1, 1, 3, feat_dim) + face_features_choices = torch.gather(face_features, 1, _face_choices) + face_features_choices = tuple( + tmp_feat.squeeze(2) for tmp_feat in torch.split(face_features_choices, 1, dim=2)) + + points, point_features = _base_sample_points_selected_faces( + face_vertices_choices, face_features_choices) + + if point_features is not None: + return points, face_choices, point_features + else: + return points, face_choices \ No newline at end of file diff --git a/nvdiffrec/lib/render/light.py b/nvdiffrec/lib/render/light.py new file mode 100644 index 0000000..7469ae8 --- /dev/null +++ b/nvdiffrec/lib/render/light.py @@ -0,0 +1,187 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . import util +from . 
import renderutils as ru + +import sys + +###################################################################################### +# Utility functions +###################################################################################### + +class cubemap_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + return util.avg_pool_nhwc(cubemap, (2,2)) + + @staticmethod + def backward(ctx, dout): + res = dout.shape[1] * 2 + out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda") + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), + indexing='ij') + v = util.safe_normalize(util.cube_to_dir(s, gx, gy)) + out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') + return out + +###################################################################################### +# Split-sum environment map light source with automatic mipmap generation +###################################################################################### + +class EnvironmentLight(torch.nn.Module): + LIGHT_MIN_RES = 16 + + MIN_ROUGHNESS = 0.08 + MAX_ROUGHNESS = 0.5 + + def __init__(self, base, trainable=True): + super(EnvironmentLight, self).__init__() + self.mtx = None + self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=trainable) + print(f"light trainable or not: {trainable}") + if trainable: + self.register_parameter('env_base', self.base) + + def xfm(self, mtx): + self.mtx = mtx + + def clone(self): + return EnvironmentLight(self.base.clone().detach()) + + def clamp_(self, min=None, max=None): + self.base.clamp_(min, max) + + def get_mip(self, roughness): + return torch.where(roughness < self.MAX_ROUGHNESS + , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2) + , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2) + + def build_mips(self, cutoff=0.99): + self.specular = [self.base] + while self.specular[-1].shape[1] > self.LIGHT_MIN_RES: + self.specular += [cubemap_mip.apply(self.specular[-1])] + + self.diffuse = ru.diffuse_cubemap(self.specular[-1]) + + for idx in range(len(self.specular) - 1): + roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS + self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff) + self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff) + + def regularizer(self): + white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0 + return torch.mean(torch.abs(self.base - white)) + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True, xfm_lgt=None): + wo = util.safe_normalize(view_pos - gb_pos) + + if specular: + roughness = ks[..., 1:2] # y component + metallic = ks[..., 2:3] # z component + spec_col = (1.0 - metallic)*0.04 + kd * metallic + diff_col = kd * (1.0 - metallic) + else: + diff_col = kd + + reflvec = util.safe_normalize(util.reflect(wo, gb_normal)) + nrmvec = gb_normal + if xfm_lgt is not None: + # print(self.mtx.size()) + mtx = torch.as_tensor(xfm_lgt, dtype=torch.float32, device='cuda') + reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), 
mtx).view(*reflvec.shape) + nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) + elif self.mtx is not None: # Rotate lookup + raise NotImplementedError + # print(self.mtx.size()) + mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') + reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) + nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) + + # if self.mtx is not None: # Rotate lookup + # # print(self.mtx.size()) + # mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') + # reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) + # nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) + + # Diffuse lookup + diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube') + shaded_col = diffuse * diff_col + + if specular: + raise NotImplementedError + # Lookup FG term from lookup texture + NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4) + fg_uv = torch.cat((NdotV, roughness), dim=-1) + if not hasattr(self, '_FG_LUT'): + self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') + fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp') + + # Roughness adjusted specular env lookup + miplevel = self.get_mip(roughness) + spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] 
for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') + + # Compute aggregate lighting + reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2] + shaded_col += spec * reflectance + + assert ks[..., 0:1].sum().item() == 0 + return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility + +###################################################################################### +# Load and store +###################################################################################### + +# Load from latlong .HDR file +def _load_env_hdr(fn, scale=1.0, trainable=True): + print("load env inner loop") + sys.stdout.flush() + latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale + print("get cubemap") + sys.stdout.flush() + cubemap = util.latlong_to_cubemap(latlong_img, [512, 512]) + + print("get light object") + sys.stdout.flush() + l = EnvironmentLight(cubemap, trainable=trainable) + print("build mips") + sys.stdout.flush() + l.build_mips() + print("build mips done") + sys.stdout.flush() + + return l + +def load_env(fn, scale=1.0, trainable=True): + if os.path.splitext(fn)[1].lower() == ".hdr": + return _load_env_hdr(fn, scale, trainable=trainable) + else: + assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1] + +def save_env_map(fn, light): + assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently" + if isinstance(light, EnvironmentLight): + color = util.cubemap_to_latlong(light.base, [512, 1024]) + util.save_image_raw(fn, color.detach().cpu().numpy()) + +###################################################################################### +# Create trainable env map with random initialization +###################################################################################### + +def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25): + base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias + return EnvironmentLight(base) + diff --git a/nvdiffrec/lib/render/material.py b/nvdiffrec/lib/render/material.py new file mode 100644 index 0000000..e12df41 --- /dev/null +++ b/nvdiffrec/lib/render/material.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import util +from . import texture + +###################################################################################### +# Wrapper to make materials behave like a python dict, but register textures as +# torch.nn.Module parameters. 
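+# A Material is indexed like a dict (mat['kd'], mat['ks'], mat['bsdf'], ...) while any
+# texture stored in it remains visible to optimizers through ordinary nn.Module
+# registration. Illustrative sketch (assumes texture.Texture2D is an nn.Module, as it is
+# used elsewhere in this diff):
+#   mat = Material({'bsdf': 'pbr', 'kd': texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], device='cuda'))})
+#   'kd' in mat                        # True
+#   dict(mat.named_parameters())       # contains the kd texture's parameters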
+###################################################################################### +class Material(torch.nn.Module): + def __init__(self, mat_dict): + super(Material, self).__init__() + self.mat_keys = set() + for key in mat_dict.keys(): + self.mat_keys.add(key) + self[key] = mat_dict[key] + + def __contains__(self, key): + return hasattr(self, key) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, val): + self.mat_keys.add(key) + setattr(self, key, val) + + def __delitem__(self, key): + self.mat_keys.remove(key) + delattr(self, key) + + def keys(self): + return self.mat_keys + +###################################################################################### +# .mtl material format loading / storing +###################################################################################### +@torch.no_grad() +def load_mtl(fn, clear_ks=True, avoid_pure_black=False): + import re + mtl_path = os.path.dirname(fn) + + # Read file + with open(fn, 'r') as f: + lines = f.readlines() + + # Parse materials + materials = [] + for line in lines: + split_line = re.split(' +|\t+|\n+', line.strip()) + prefix = split_line[0].lower() + data = split_line[1:] + if 'newmtl' in prefix: + material = Material({'name' : data[0]}) + materials += [material] + elif materials: + if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: + material[prefix] = data[0] + else: + if 'kd' in prefix and avoid_pure_black: + tmp_kd = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + if tmp_kd.sum() == 0.0: + tmp_kd[0] = 1.0 + tmp_kd[1] = 0.75 + material[prefix] = tmp_kd + else: + material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') + + # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps + for mat in materials: + if not 'bsdf' in mat: + mat['bsdf'] = 'pbr' + + if 'map_kd' in mat: + mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) + # mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']), channels=3) + else: + mat['kd'] = texture.Texture2D(mat['kd']) + + if 'map_ks' in mat: + mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) + else: + mat['ks'] = texture.Texture2D(mat['ks']) + + if 'bump' in mat: + mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) + + # Convert Kd from sRGB to linear RGB + mat['kd'] = texture.srgb_to_rgb(mat['kd']) + + if clear_ks: + # Override ORM occlusion (red) channel by zeros. 
We hijack this channel + for mip in mat['ks'].getMips(): + mip[..., 0] = 0.0 + + return materials + +@torch.no_grad() +def save_mtl(fn, material): + folder = os.path.dirname(fn) + with open(fn, "w") as f: + f.write('newmtl defaultMat\n') + if material is not None: + f.write('bsdf %s\n' % material['bsdf']) + if 'kd' in material.keys(): + f.write('map_kd texture_kd.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) + if 'ks' in material.keys(): + f.write('map_ks texture_ks.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) + if 'normal' in material.keys(): + f.write('bump texture_n.png\n') + texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) + else: + f.write('Kd 1 1 1\n') + f.write('Ks 0 0 0\n') + f.write('Ka 0 0 0\n') + f.write('Tf 1 1 1\n') + f.write('Ni 1\n') + f.write('Ns 0\n') + +###################################################################################### +# Merge multiple materials into a single uber-material +###################################################################################### + +def _upscale_replicate(x, full_res): + x = x.permute(0, 3, 1, 2) + x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') + return x.permute(0, 2, 3, 1).contiguous() + +def merge_materials(materials, texcoords, tfaces, mfaces): + assert len(materials) > 0 + for mat in materials: + assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" + assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" + + uber_material = Material({ + 'name' : 'uber_material', + 'bsdf' : materials[0]['bsdf'], + }) + + textures = ['kd', 'ks', 'normal'] + + # Find maximum texture resolution across all materials and textures + max_res = None + for mat in materials: + for tex in textures: + tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) + max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res + + # Compute size of compund texture and round up to nearest PoT + full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) + + # Normalize texture resolution across all materials & combine into a single large texture + for tex in textures: + if tex in materials[0]: + tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x + tex_data = _upscale_replicate(tex_data, full_res) + uber_material[tex] = texture.Texture2D(tex_data) + + # Compute scaling values for used / unused texture area + s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] + + # Recompute texture coordinates to cooincide with new composite texture + new_tverts = {} + new_tverts_data = [] + for fi in range(len(tfaces)): + matIdx = mfaces[fi] + for vi in range(3): + ti = tfaces[fi][vi] + if not (ti in new_tverts): + new_tverts[ti] = {} + if not (matIdx in new_tverts[ti]): # create new vertex + if len(texcoords) == 0: + # continue + new_tverts_data.append([(matIdx) / s_coeff[1], 0]) # Offset texture coodrinate (x direction) by material id & scale to local space. 
Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + else: + new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + + # if not (matIdx in new_tverts[ti]): # create new vertex + # new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here + # new_tverts[ti][matIdx] = len(new_tverts_data) - 1 + tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex + + return uber_material, new_tverts_data, tfaces + diff --git a/nvdiffrec/lib/render/mesh.py b/nvdiffrec/lib/render/mesh.py new file mode 100644 index 0000000..da598c7 --- /dev/null +++ b/nvdiffrec/lib/render/mesh.py @@ -0,0 +1,277 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch + +from . import obj +from . import util + +###################################################################################### +# Base mesh class +###################################################################################### +class Mesh: + def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, + material=None, base=None, f_nrm=None): + self.v_pos = v_pos + self.v_nrm = v_nrm + self.v_tex = v_tex + self.v_tng = v_tng + self.t_pos_idx = t_pos_idx + self.t_nrm_idx = t_nrm_idx + self.t_tex_idx = t_tex_idx + self.t_tng_idx = t_tng_idx + self.material = material + # self.f_nrm = f_nrm + + if base is not None: + self.copy_none(base) + + try: + i0 = self.t_pos_idx[:, 0] + i1 = self.t_pos_idx[:, 1] + i2 = self.t_pos_idx[:, 2] + + v0 = self.v_pos[i0, :] + v1 = self.v_pos[i1, :] + v2 = self.v_pos[i2, :] + + self.f_nrm = face_normals = torch.cross(v1 - v0, v2 - v0) + except: + self.f_nrm = f_nrm + + def copy_none(self, other): + if self.v_pos is None: + self.v_pos = other.v_pos + if self.t_pos_idx is None: + self.t_pos_idx = other.t_pos_idx + if self.v_nrm is None: + self.v_nrm = other.v_nrm + if self.t_nrm_idx is None: + self.t_nrm_idx = other.t_nrm_idx + if self.v_tex is None: + self.v_tex = other.v_tex + if self.t_tex_idx is None: + self.t_tex_idx = other.t_tex_idx + if self.v_tng is None: + self.v_tng = other.v_tng + if self.t_tng_idx is None: + self.t_tng_idx = other.t_tng_idx + if self.material is None: + self.material = other.material + # if self.f_nrm is None: + # self.f_nrm = other.f_nrm + + + def clone(self): + out = Mesh(base=self) + if out.v_pos is not None: + out.v_pos = out.v_pos.clone().detach() + if out.t_pos_idx is not None: + out.t_pos_idx = out.t_pos_idx.clone().detach() + if out.v_nrm is not None: + out.v_nrm = out.v_nrm.clone().detach() + if out.t_nrm_idx is not None: + 
out.t_nrm_idx = out.t_nrm_idx.clone().detach() + if out.v_tex is not None: + out.v_tex = out.v_tex.clone().detach() + if out.t_tex_idx is not None: + out.t_tex_idx = out.t_tex_idx.clone().detach() + if out.v_tng is not None: + out.v_tng = out.v_tng.clone().detach() + if out.t_tng_idx is not None: + out.t_tng_idx = out.t_tng_idx.clone().detach() + if out.f_nrm is not None: + out.f_nrm = out.f_nrm.clone().detach() + return out + +###################################################################################### +# Mesh loeading helper +###################################################################################### + +def load_mesh(filename, mtl_override=None, mtl_default=None, use_default=False, no_additional=False): + name, ext = os.path.splitext(filename) + if ext == ".obj": + return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override, mtl_default=mtl_default, use_default=use_default, no_additional=no_additional) + assert False, "Invalid mesh file extension" + +###################################################################################### +# Compute AABB +###################################################################################### +def aabb(mesh): + return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values + +###################################################################################### +# Compute AABB with only used vertices +###################################################################################### +def aabb_clean(mesh): + v_pos_clean = mesh.v_pos[mesh.t_pos_idx.unique()] + return torch.min(v_pos_clean, dim=0).values, torch.max(v_pos_clean, dim=0).values + +###################################################################################### +# Compute unique edge list from attribute/vertex index list +###################################################################################### +def compute_edges(attr_idx, return_inverse=False): + with torch.no_grad(): + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Eliminate duplicates and return inverse mapping + return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse) + +###################################################################################### +# Compute unique edge to face mapping from attribute/vertex index list +###################################################################################### +def compute_edge_to_face_mapping(attr_idx, return_inverse=False): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat(( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), dim=-1).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat(( + torch.gather(all_edges, 1, order), + torch.gather(all_edges, 1, 1 - order) + ), dim=-1) + + # Elliminate duplicates and return inverse mapping + unique_edges, 
idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:,0] == 0 + mask1 = order[:,0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge + +###################################################################################### +# Align base mesh to reference mesh:move & rescale to match bounding boxes. +###################################################################################### +def unit_size(mesh): + with torch.no_grad(): + vmin, vmax = aabb(mesh) + scale = 2 / torch.max(vmax - vmin).item() + v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin + v_pos = v_pos * scale # Rescale to unit size + + return Mesh(v_pos, base=mesh) + +###################################################################################### +# Center & scale mesh for rendering +###################################################################################### +def center_by_reference(base_mesh, ref_aabb, scale): + center = (ref_aabb[0] + ref_aabb[1]) * 0.5 + scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() + print('normalization:', center, scale) + v_pos = (base_mesh.v_pos - center[None, ...]) * scale + return Mesh(v_pos, base=base_mesh) + +###################################################################################### +# Simple smooth vertex normal computation +###################################################################################### +def auto_normals(imesh): + + i0 = imesh.t_pos_idx[:, 0] + i1 = imesh.t_pos_idx[:, 1] + i2 = imesh.t_pos_idx[:, 2] + + v0 = imesh.v_pos[i0, :] + v1 = imesh.v_pos[i1, :] + v2 = imesh.v_pos[i2, :] + + f_nrm = face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(imesh.v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda')) + v_nrm = util.safe_normalize(v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(v_nrm)) + + return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh, f_nrm=f_nrm) + +###################################################################################### +# Compute tangent space from texture map coordinates +# Follows http://www.mikktspace.com/ conventions +###################################################################################### +def compute_tangents(imesh): + vn_idx = [None] * 3 + pos = [None] * 3 + tex = [None] * 3 + for i in range(0,3): + pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]] + tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]] + vn_idx[i] = imesh.t_nrm_idx[:, i] + + tangents = torch.zeros_like(imesh.v_nrm) + tansum = torch.zeros_like(imesh.v_nrm) + + # Compute tangent space for each triangle + uve1 = tex[1] - tex[0] + uve2 = tex[2] - tex[0] + pe1 = pos[1] - pos[0] + pe2 = pos[2] - pos[0] + + nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]) + denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]) + assert not torch.isnan(uve1).any() + assert not 
torch.isnan(uve2).any() + assert not torch.isnan(pe1).any() + assert not torch.isnan(pe2).any() + + # Avoid division by zero for degenerated texture coordinates + tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) #### ZL: something wrong in this line, not sure why + assert (torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) != 0.0).all() + assert not torch.isnan(nom).any() + assert not torch.isnan(tang).any() + + # Update all 3 vertices + for i in range(0,3): + idx = vn_idx[i][:, None].repeat(1,3) + tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang + tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1 + tangents = tangents / tansum + assert not torch.isnan(tangents).any() + + # Normalize and make sure tangent is perpendicular to normal + tangents = util.safe_normalize(tangents) + tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(tangents)) + + return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh) diff --git a/nvdiffrec/lib/render/mlptexture.py b/nvdiffrec/lib/render/mlptexture.py new file mode 100644 index 0000000..9abe96b --- /dev/null +++ b/nvdiffrec/lib/render/mlptexture.py @@ -0,0 +1,104 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
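+
+# Neural 3D texture: positions inside an AABB are encoded by a tiny-cuda-nn HashGrid and
+# decoded by a small MLP into per-point material channels. With the settings below,
+# per_level_scale = exp(ln(4096 / 16) / (16 - 1)) ~= 1.447, i.e. 16 hash-grid levels that
+# grow geometrically from the 16^3 base resolution to the desired 4096^3 finest level.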
+ +import torch +import tinycudann as tcnn +import numpy as np + +####################################################################################################################################################### +# Small MLP using PyTorch primitives, internal helper class +####################################################################################################################################################### + +class _MLP(torch.nn.Module): + def __init__(self, cfg, loss_scale=1.0): + super(_MLP, self).__init__() + self.loss_scale = loss_scale + net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) + for i in range(cfg['n_hidden_layers']-1): + net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) + net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),) + self.net = torch.nn.Sequential(*net).cuda() + + self.net.apply(self._init_weights) + + if self.loss_scale != 1.0: + self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, )) + + def forward(self, x): + return self.net(x.to(torch.float32)) + + @staticmethod + def _init_weights(m): + if type(m) == torch.nn.Linear: + torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') + if hasattr(m.bias, 'data'): + m.bias.data.fill_(0.0) + +####################################################################################################################################################### +# Outward visible MLP class +####################################################################################################################################################### + +class MLPTexture3D(torch.nn.Module): + def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None): + super(MLPTexture3D, self).__init__() + + self.channels = channels + self.internal_dims = internal_dims + self.AABB = AABB + self.min_max = min_max + + # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details + desired_resolution = 4096 + base_grid_resolution = 16 + num_levels = 16 + per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1)) + + enc_cfg = { + "otype": "HashGrid", + "n_levels": num_levels, + "n_features_per_level": 2, + "log2_hashmap_size": 19, + "base_resolution": base_grid_resolution, + "per_level_scale" : per_level_scale + } + + gradient_scaling = 128.0 + self.encoder = tcnn.Encoding(3, enc_cfg) + self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, )) + + # Setup MLP + mlp_cfg = { + "n_input_dims" : self.encoder.n_output_dims, + "n_output_dims" : self.channels, + "n_hidden_layers" : hidden, + "n_neurons" : self.internal_dims + } + self.net = _MLP(mlp_cfg, gradient_scaling) + print("Encoder output: %d dims" % (self.encoder.n_output_dims)) + + # Sample texture at a given location + def sample(self, texc): + _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] 
- self.AABB[0][None, ...]) + _texc = torch.clamp(_texc, min=0, max=1) + + p_enc = self.encoder(_texc.contiguous()) + out = self.net.forward(p_enc) + + # Sigmoid limit and scale to the allowed range + if self.min_max is not None: + out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] + + return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + pass + + def cleanup(self): + tcnn.free_temporary_memory() diff --git a/nvdiffrec/lib/render/obj.py b/nvdiffrec/lib/render/obj.py new file mode 100644 index 0000000..fce29d8 --- /dev/null +++ b/nvdiffrec/lib/render/obj.py @@ -0,0 +1,216 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import torch + +from . import texture +from . import mesh +from . import material + +###################################################################################### +# Utility functions +###################################################################################### + +def _find_mat(materials, name): + for mat in materials: + if mat['name'] == name: + return mat + return materials[0] # Materials 0 is the default + +###################################################################################### +# Create mesh object from objfile +###################################################################################### + +def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=None, use_default=False, no_additional=False): + obj_path = os.path.dirname(filename) + + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # Load materials + if mtl_default is None: + all_materials = [ + { + 'name' : '_default_mat', + 'bsdf' : 'pbr', + 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), + 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) + } + ] + else: + print("Load use-defined default mtl") + all_materials = [mtl_default] + + if not no_additional: + if mtl_override is None: + for line in lines: + if len(line.split()) == 0: + continue + if line.split()[0] == 'mtllib': + all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks, avoid_pure_black=True) # Read in entire material library + else: + all_materials += material.load_mtl(mtl_override) + else: + print("Skip loading non-default materials") + + + # load vertices + vertices, texcoords, normals = [], [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == 'vn': + normals.append([float(v) for v in line.split()[1:]]) + + print(all_materials) + + # load faces + activeMatIdx = None + used_materials = [] + faces, tfaces, nfaces, mfaces = [], [], [], [] + for line in lines: + if len(line.split()) 
== 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + mat = _find_mat(all_materials, line.split()[1]) + if not mat in used_materials: + used_materials.append(mat) + activeMatIdx = used_materials.index(mat) + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + # t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + # n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + try: + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + n0 = int(vv[2]) - 1 if vv[2] != "" else -1 + except: + t0 = n0 = -1 + for i in range(nv - 2): # Triangulate polygons + vv = vs[i + 1].split('/') + v1 = int(vv[0]) - 1 + # t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + # n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + try: + t1 = int(vv[1]) - 1 if vv[1] != "" else -1 + n1 = int(vv[2]) - 1 if vv[2] != "" else -1 + except: + t1 = n1 = -1 + vv = vs[i + 2].split('/') + v2 = int(vv[0]) - 1 + # t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + # n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + try: + t2 = int(vv[1]) - 1 if vv[1] != "" else -1 + n2 = int(vv[2]) - 1 if vv[2] != "" else -1 + except: + t2 = n2 = -1 + mfaces.append(activeMatIdx) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + assert len(tfaces) == len(faces) and len(nfaces) == len (faces) + + # # Create an "uber" material by combining all textures into a larger texture + # # if len(used_materials) > 1: + # if True: + # uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) + # elif len(used_materials) == 1: + # uber_material = used_materials[0] + # else: + # uber_material = None + + vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') + # texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None + # normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None + # # normals = None + + faces = torch.tensor(faces, dtype=torch.int64, device='cuda') + # tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None + # nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None + + # print(uber_material) + + uber_material = all_materials[0] + texcoords = normals = tfaces = nfaces = None + # return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + + imesh = mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) + imesh = mesh.auto_normals(imesh) + return imesh + +###################################################################################### +# Save mesh object to objfile +###################################################################################### + +def write_obj(folder, mesh, save_material=True): + obj_file = os.path.join(folder, 'mesh.obj') + print("Writing mesh: ", obj_file) + with open(obj_file, "w") as f: + # f.write("mtllib mesh.mtl\n") + f.write("g default\n") + + v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None + # v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None + # v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None + v_nrm = None + v_tex = None + + t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None + # t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None + # t_tex_idx = 
mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None + + print(" writing %d vertices" % len(v_pos)) + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + # if v_tex is not None: + # print(" writing %d texcoords" % len(v_tex)) + # assert(len(t_pos_idx) == len(t_tex_idx)) + # for v in v_tex: + # f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + # if v_nrm is not None: + # print(" writing %d normals" % len(v_nrm)) + # assert(len(t_pos_idx) == len(t_nrm_idx)) + # for v in v_nrm: + # f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) + + # faces + f.write("s 1 \n") + f.write("g pMesh1\n") + f.write("usemtl defaultMat\n") + + # Write faces + print(" writing %d faces" % len(t_pos_idx)) + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) + f.write("\n") + + if save_material: + mtl_file = os.path.join(folder, 'mesh.mtl') + print("Writing material: ", mtl_file) + material.save_mtl(mtl_file, mesh.material) + + print("Done exporting mesh") diff --git a/nvdiffrec/lib/render/regularizer.py b/nvdiffrec/lib/render/regularizer.py new file mode 100644 index 0000000..74c7225 --- /dev/null +++ b/nvdiffrec/lib/render/regularizer.py @@ -0,0 +1,205 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr +import pytorch3d.ops + +from . import util +from . import mesh + +###################################################################################### +# Computes the image gradient, useful for kd/ks smoothness losses +###################################################################################### +def image_grad(buf, std=0.01): + t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"), + torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"), + indexing='ij') + tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...] + tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp') + return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:] + +###################################################################################### +# Computes the avergage edge length of a mesh. +# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients +###################################################################################### +def avg_edge_length(v_pos, t_pos_idx): + e_pos_idx = mesh.compute_edges(t_pos_idx) + edge_len = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]]) + return torch.mean(edge_len) + +###################################################################################### +# Laplacian regularization using umbrella operator (Fujiwara / Desbrun). 
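+# Concretely, the regularizer below accumulates, for each vertex v_i, delta(v_i) = (1/|N(i)|) * sum_{j in N(i)} (v_j - v_i) (the one-ring average minus the vertex, with the count |N(i)| accumulated as 2 per incident triangle) and uses the mean of delta**2 as the loss.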
+# https://mgarland.org/class/geom04/material/smoothing.pdf +###################################################################################### +def laplace_regularizer_const(v_pos, t_pos_idx): + term = torch.zeros_like(v_pos) + norm = torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) + term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) + term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) + + two = torch.ones_like(v0) * 2.0 + norm.scatter_add_(0, t_pos_idx[:, 0:1], two) + norm.scatter_add_(0, t_pos_idx[:, 1:2], two) + norm.scatter_add_(0, t_pos_idx[:, 2:3], two) + + term = term / torch.clamp(norm, min=1.0) + + return torch.mean(term**2) + + +def scale_dependent_relative_laplace_regularizer_const(v_pos, v_pos_abs, t_pos_idx): + term = torch.zeros_like(v_pos) + norm = torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + v0_abs = v_pos_abs[t_pos_idx[:, 0], :] + v1_abs = v_pos_abs[t_pos_idx[:, 1], :] + v2_abs = v_pos_abs[t_pos_idx[:, 2], :] + + eps = 1e-8 + deformable_dist = False + if deformable_dist: + raise NotImplementedError + else: + ## The original distance; does not account for the + v01_dist = ((v0_abs - v1_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v12_dist = ((v1_abs - v2_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v20_dist = ((v2_abs - v0_abs).pow(2).sum(-1, keepdim=True) + eps).sqrt() + + term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist) + term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist) + term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist) + + return torch.mean(term**2) + + +def scale_dependent_laplace_regularizer_const(v_pos, t_pos_idx): + term = torch.zeros_like(v_pos) + norm = torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + eps = 1e-8 + v01_dist = ((v0 - v1).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v12_dist = ((v1 - v2).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v20_dist = ((v2 - v0).pow(2).sum(-1, keepdim=True) + eps).sqrt() + + stopgd = True + if stopgd: + v01_dist = v01_dist.detach() + v12_dist = v12_dist.detach() + v20_dist = v20_dist.detach() + + term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) / v01_dist + (v2 - v0) / v20_dist) + term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) / v01_dist + (v2 - v1) / v12_dist) + term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) / v20_dist + (v1 - v2) / v12_dist) + + return torch.mean(term**2) + + +def mesh_repulsion(v_pos, t_pos_idx): + term = torch.zeros_like(v_pos) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + + eps = 1e-8 + v01_dist = ((v0 - v1).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v12_dist = ((v1 - v2).pow(2).sum(-1, keepdim=True) + eps).sqrt() + v20_dist = ((v2 - v0).pow(2).sum(-1, keepdim=True) + eps).sqrt() + + term.scatter_add_(0, t_pos_idx[:, 0:1], v01_dist) + term.scatter_add_(0, t_pos_idx[:, 1:2], v12_dist) + term.scatter_add_(0, t_pos_idx[:, 2:3], v20_dist) + + return term**2 + +def laplace_regularizer_const_adaptive(v_pos, t_pos_idx): + term = torch.zeros_like(v_pos) + norm = 
torch.zeros_like(v_pos[..., 0:1]) + + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) + term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) + term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) + + two = torch.ones_like(v0) * 2.0 + norm.scatter_add_(0, t_pos_idx[:, 0:1], two) + norm.scatter_add_(0, t_pos_idx[:, 1:2], two) + norm.scatter_add_(0, t_pos_idx[:, 2:3], two) + + term = term / torch.clamp(norm, min=1.0) + + v_pos = v_pos.unsqueeze(0) * 64 + with torch.no_grad(): + scale = (pytorch3d.ops.knn.knn_points(v_pos, v_pos, K=2).dists[0, :, -1].detach()).sqrt().pow(1.5) ## K=2 because dist(self, self)=0 + dist = term.pow(2).mean(-1) ### since the vanilla one uses mean + + return torch.mean(dist * scale) + +# def laplace_regularizer_const_sec_order(v_pos, t_pos_idx): +# term = torch.zeros_like(v_pos) +# norm = torch.zeros_like(v_pos[..., 0:1]) + +# v0 = v_pos[t_pos_idx[:, 0], :] +# v1 = v_pos[t_pos_idx[:, 1], :] +# v2 = v_pos[t_pos_idx[:, 2], :] + +# term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) +# term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) +# term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) + +# two = torch.ones_like(v0) * 2.0 +# norm.scatter_add_(0, t_pos_idx[:, 0:1], two) +# norm.scatter_add_(0, t_pos_idx[:, 1:2], two) +# norm.scatter_add_(0, t_pos_idx[:, 2:3], two) + +# term = term / torch.clamp(norm, min=1.0) + +# return torch.mean(term**2) + +###################################################################################### +# Smooth vertex normals +###################################################################################### +def normal_consistency(v_pos, t_pos_idx): + # Compute face normals + v0 = v_pos[t_pos_idx[:, 0], :] + v1 = v_pos[t_pos_idx[:, 1], :] + v2 = v_pos[t_pos_idx[:, 2], :] + + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + + tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx) + + # Fetch normals for both faces sharind an edge + n0 = face_normals[tris_per_edge[:, 0], :] + n1 = face_normals[tris_per_edge[:, 1], :] + + # Compute error metric based on normal difference + term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0) + term = (1.0 - term) * 0.5 + + return torch.mean(torch.abs(term)) diff --git a/nvdiffrec/lib/render/render.py b/nvdiffrec/lib/render/render.py new file mode 100644 index 0000000..77dd140 --- /dev/null +++ b/nvdiffrec/lib/render/render.py @@ -0,0 +1,454 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch +import nvdiffrast.torch as dr + +from . import util +from . import renderutils as ru +from . 
import light + +# ============================================================================================== +# Helper functions +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + +# ============================================================================================== +# pixel shader +# ============================================================================================== +def shade( + gb_pos, + gb_geometric_normal, + gb_normal, + gb_tangent, + gb_texc, + gb_texc_deriv, + view_pos, + lgt, + material, + bsdf, + xfm_lgt=None + ): + + ################################################################################ + # Texture lookups + ################################################################################ + perturbed_nrm = None + alpha_mtl = None + if 'kd_ks_normal' in material: + # Combined texture, used for MLPs because lookups are expensive + all_tex_jitter = material['kd_ks_normal'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda")) + all_tex = material['kd_ks_normal'].sample(gb_pos) + assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" + kd, ks, perturbed_nrm = all_tex[..., :-6], all_tex[..., -6:-3], all_tex[..., -3:] + # Compute albedo (kd) gradient, used for material regularizer + kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) / 3 + else: + try: + kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv) + if 'alpha' in material: + raise NotImplementedError + try: + alpha_mtl = material['alpha'].sample(gb_texc, gb_texc_deriv) + except: + alpha_mtl = material['alpha'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda")) + kd = material['kd'].sample(gb_texc, gb_texc_deriv) + ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha + kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3 + except: + kd_jitter = kd = material['kd'].data[0].expand(*gb_pos.size()) + ks = material['ks'].data[0].expand(*gb_pos.size())[..., 0:3] # skip alpha + kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3 + + # Separate kd into alpha and color, default alpha = 1 + alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1]) + if alpha_mtl is not None: + alpha = alpha_mtl + kd = kd[..., 0:3] + + ################################################################################ + # Normal perturbation & normal bend + ################################################################################ + if 'no_perturbed_nrm' in material and material['no_perturbed_nrm']: + perturbed_nrm = None + + use_python = (gb_tangent is None) + + gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python) + gb_geo_normal_corrected = ru.prepare_shading_normal(gb_pos, view_pos, None, gb_geometric_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True, use_python=use_python) + + ################################################################################ + # Evaluate BSDF + 
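# The branch below dispatches on the BSDF name: 'diffuse' is shaded with the environment light, 'normal' / 'tangent' / 'kd' / 'ks' are visualization outputs, and 'pbr' is disabled (raises NotImplementedError) in this code path. +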
################################################################################ + + assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type" + bsdf = material['bsdf'] if bsdf is None else bsdf + if bsdf == 'pbr': + # do not use pbr + raise NotImplementedError + if isinstance(lgt, light.EnvironmentLight): + shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True) + else: + assert False, "Invalid light type" + elif bsdf == 'diffuse': + if isinstance(lgt, light.EnvironmentLight): + shaded_col = lgt.shade(gb_pos, gb_geo_normal_corrected, kd, ks, view_pos, specular=False, xfm_lgt=xfm_lgt) + else: + assert False, "Invalid light type" + elif bsdf == 'normal': + shaded_col = (gb_normal + 1.0)*0.5 + elif bsdf == 'tangent': + shaded_col = (gb_tangent + 1.0)*0.5 + elif bsdf == 'kd': + shaded_col = kd + elif bsdf == 'ks': + shaded_col = ks + else: + assert False, "Invalid BSDF '%s'" % bsdf + + nan_mask = torch.isnan(shaded_col) + if nan_mask.any(): + raise + if alpha is not None: + nan_mask = torch.isnan(alpha) + if nan_mask.any(): + raise + + # Return multiple buffers + buffers = { + 'shaded' : torch.cat((shaded_col, alpha), dim=-1), + 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1), + 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1), + 'normal' : torch.cat((gb_normal, alpha), dim=-1), + 'depth' : torch.cat(((gb_pos - view_pos).pow(2).sum(dim=-1, keepdim=True).sqrt(), alpha), dim=-1), + 'pos' : torch.cat((gb_pos, alpha), dim=-1), + 'geo_normal': torch.cat((gb_geo_normal_corrected, alpha), dim=-1), + 'geo_viewdir': torch.cat((view_pos - gb_pos, alpha), dim=-1), + 'alpha' : alpha + } + + + return buffers + +# ============================================================================================== +# Render a depth slice of the mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_layer( + rast, + rast_deriv, + mesh, + view_pos, + lgt, + resolution, + spp, + msaa, + bsdf, + xfm_lgt = None, + flat_shading = False + ): + + full_res = [resolution[0]*spp, resolution[1]*spp] + + ################################################################################ + # Rasterize + ################################################################################ + + # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution + if spp > 1 and msaa: + rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest') + else: + rast_out_s = rast + + ################################################################################ + # Interpolate attributes + ################################################################################ + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int()) + + # Compute geometric normals. 
We need those because of bent normals trick (for bump mapping) + v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :] + v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :] + v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :] + face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) + face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3) + gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int()) + + if flat_shading: + gb_normal = mesh.f_nrm[rast_out_s[:, :, :, -1].long() - 1] # empty triangle get id=0; the first idx starts from 1 + gb_normal[rast_out_s[:, :, :, -1].long() == 0] = 0 + else: + assert mesh.v_nrm is not None + gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int()) + + if mesh.v_tng is not None: + gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents + else: + gb_tangent = None + + # Do not use texture coordinate in our case + gb_texc, gb_texc_deriv = None, None + + + ################################################################################ + # Shade + ################################################################################ + buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv, + view_pos, lgt, mesh.material, bsdf, xfm_lgt=xfm_lgt) + + #### get a mask on mesh (used to identify foreground) + mask_cont, _ = interpolate(torch.ones_like(mesh.v_pos[None, :, :1], device=mesh.v_pos.device), rast_out_s, mesh.t_pos_idx.int()) + mask = (mask_cont > 0).float() + buffers['mask'] = mask + buffers['mask_cont'] = mask_cont + + ################################################################################ + # Prepare output + ################################################################################ + + # Scale back up to visibility resolution if using MSAA + if spp > 1 and msaa: + for key in buffers.keys(): + if buffers[key] is not None: + buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest') + + + # Return buffers + return buffers + +# ============================================================================================== +# Render a depth peeled mesh (scene), some limitations: +# - Single mesh +# - Single light +# - Single material +# ============================================================================================== +def render_mesh( + ctx, + mesh, + mtx_in, + view_pos, + lgt, + resolution, + spp = 1, + num_layers = 1, + msaa = False, + background = None, + bsdf = None, + xfm_lgt = None, + tet_centers = None, + flat_shading = False + ): + + def prepare_input_vector(x): + x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x + return x[:, None, None, :] if len(x.shape) == 2 else x + + def composite_buffer(key, layers, background, antialias): + accum = background + for buffers, rast in layers: + alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] + accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) + if antialias: + accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int()) + break ## HACK: the first layer only + return accum + + def separate_buffer(key, layers, background, antialias): + accum_list = [] + for buffers, rast in layers: + accum = background + alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:] + accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], 
torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha) + if antialias: + accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int()) + accum_list.append(accum) + return accum_list + + assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)" + assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1]) + + full_res = [resolution[0]*spp, resolution[1]*spp] + + # Convert numpy arrays to torch tensors + mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in + view_pos = prepare_input_vector(view_pos) + + # clip space transform + v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in) + + # Render all layers front-to-back + with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler: + rast, db = peeler.rasterize_next_layer() + layers = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)] + rast_1st_layer = rast + # with torch.no_grad(): + if True: + rast, db = peeler.rasterize_next_layer() + layers2 = [(render_layer(rast, db, mesh, view_pos, lgt, resolution, spp, msaa, bsdf, xfm_lgt, flat_shading), rast)] + + # Setup background + if background is not None: + if spp > 1: + background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest') + background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1) + else: + background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda') + + # Composite layers front-to-back + out_buffers = {} + for key in layers[0][0].keys(): + if key == 'shaded': + accum = composite_buffer(key, layers, background, True) + elif (key == 'depth' or key == 'pos') and layers[0][0][key] is not None: + accum = separate_buffer(key, layers, torch.ones_like(layers[0][0][key]) * 20.0, False) + elif ('normal' in key) and layers[0][0][key] is not None: + accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True) + elif layers[0][0][key] is not None: + accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), False) + + if (key == 'depth' or key == 'pos') and layers[0][0][key] is not None: + out_buffers[key] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0] + else: + # Downscale to framebuffer resolution. 
Use avg pooling + out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum + + accum = composite_buffer('shaded', layers, background, True) + out_buffers['shaded_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum + + accum = separate_buffer('depth', layers2, -1 * torch.ones_like(layers2[0][0]['depth']), False) + out_buffers['depth_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0] + + + accum = separate_buffer('normal', layers2, torch.zeros_like(layers2[0][0]['normal']), False) + out_buffers['normal_second'] = util.avg_pool_nhwc(accum[0], spp) if spp > 1 else accum[0] + + rast_triangle_id = rast_1st_layer[:, :, :, -1].unique() + if rast_triangle_id[0] == 0: + if rast_triangle_id.size(0) > 1: + rast_triangle_id = rast_triangle_id[1:] - 1 ## since by the convention of the rasterizer, 0 = empty + else: + rast_triangle_id = None + out_buffers['rast_triangle_id'] = rast_triangle_id + out_buffers['rast_depth'] = rast_1st_layer[:, :, :, -2] # z-buffer + + + + if tet_centers is not None: + with torch.no_grad(): + v_pos_clip = v_pos_clip[0] + assert full_res[0] == full_res[1] + homo_transformed_tet_centers = ru.xfm_points(tet_centers[None, ...], mtx_in) + transformed_tet_centers = homo_transformed_tet_centers[0, :, :3] / homo_transformed_tet_centers[0, :, 3:4] + + int_transformed_tet_centers = torch.round((transformed_tet_centers / 2.0 + 0.5) * (full_res[0] - 1)).long() # from the clip space (i.e., [-1, 1]^3) to the nearest integer coordinates in the canvas + + ### transpose THE "image" + tmp_int_transformed_tet_centers = int_transformed_tet_centers.clone() + int_transformed_tet_centers[:, 0] = tmp_int_transformed_tet_centers[:, 1] + int_transformed_tet_centers[:, 1] = tmp_int_transformed_tet_centers[:, 0] + + + valid_tet_centers = ((torch.logical_and((int_transformed_tet_centers <= full_res[0] - 1), int_transformed_tet_centers >= 0).float()).prod(dim=-1) == 1) # those tet centers in/on the edge of the clip space + valid_int_transformed_tet_centers = int_transformed_tet_centers[valid_tet_centers] + + tet_center_dirs = (tet_centers - view_pos.view(1, 3)) + tet_center_depths = tet_center_dirs.pow(2).sum(-1).sqrt() + + ### Finding occluded tetrahedra + valid_transformed_tet_center_depths = transformed_tet_centers[valid_tet_centers][:, -1] # get the depth in the clip space + valid_tet_ids = torch.arange(tet_centers.size(0)).to(valid_tet_centers.device)[valid_tet_centers] + + + corrected_rast_depth = out_buffers['rast_depth'].clone().detach() + + + corrected_rast_depth[rast_1st_layer[:, :, :, -1] == 0] = 100 # for all pixels without any rasterized mesh, just set the depth to a large enough value + + ''' + Hacky way of finding most of the non-occluded tetrahedra (except for already rasterized ones): + + For each pixel, find the min depth in a small neighborhood. + If the depth of a tetrahedron's center (at the pixel it projects to) is smaller than this min depth, + that tetrahedron is certainly non-occluded. + + We do this because an exact per-pixel comparison against the triangle mesh can be costly, + and we do not need to perfectly find all visible tetrahedra.
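+ (Using the neighborhood minimum keeps the test one-sided: a center whose depth is below that minimum lies in front of every nearby rasterized surface and is safely marked visible, while some visible tetrahedra near depth discontinuities may be missed, which is acceptable here.)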
+ ''' + depth_search_range = 7 ### change this value for different resolution in rasterization + corrected_rast_depth = -torch.nn.functional.max_pool2d( + -corrected_rast_depth, + kernel_size=2*depth_search_range+1, + stride=1, + padding=depth_search_range) + + valid_reference_depth = corrected_rast_depth[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]] + depth_filter = valid_reference_depth >= valid_transformed_tet_center_depths + + + empty_2d_mask = (rast_1st_layer[:, :, :, -1] == 0) + empty_2d_mask = (-torch.nn.functional.max_pool2d( + -empty_2d_mask.float(), + kernel_size=2*depth_search_range+1, + stride=1, + padding=depth_search_range)).bool() ### similar philosophy for using a neighborhood + empty_filter = empty_2d_mask[0, valid_int_transformed_tet_centers[:, 0], valid_int_transformed_tet_centers[:, 1]] + + ## visible tets are either determined by depth test or emptyness test + out_buffers['visible_tet_id'] = valid_tet_ids[torch.logical_or(empty_filter, depth_filter)] + + return out_buffers + +# ============================================================================================== +# Render UVs +# ============================================================================================== +def render_uv(ctx, mesh, resolution, mlp_texture): + + # clip space transform + uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int()) + + # Sample out textures from MLP + all_tex = mlp_texture.sample(gb_pos) + assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels" + perturbed_nrm = all_tex[..., -3:] + return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm) + +# ============================================================================================== +# Render UVs +# ============================================================================================== +def render_uv_nrm(ctx, mesh, resolution, mlp_texture): + + # clip space transform + uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int()) + + # Sample out textures from MLP + all_tex = mlp_texture.sample(gb_pos) + perturbed_nrm = all_tex[..., -3:] + return (rast[..., -1:] > 0).float(), util.safe_normalize(perturbed_nrm) diff --git a/nvdiffrec/lib/render/renderutils/__init__.py b/nvdiffrec/lib/render/renderutils/__init__.py new file mode 100644 index 0000000..f29739f --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith +__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ] diff --git a/nvdiffrec/lib/render/renderutils/bsdf.cu b/nvdiffrec/lib/render/renderutils/bsdf.cu new file mode 100644 index 0000000..c167214 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/bsdf.cu @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "bsdf.h" + +#define SPECULAR_EPSILON 1e-4f + +//------------------------------------------------------------------------ +// Lambert functions + +__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) +{ + return max(dot(nrm, wi) / M_PI, 0.0f); +} + +__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) +{ + if (dot(nrm, wi) > 0.0f) + bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); +} + +//------------------------------------------------------------------------ +// Fresnel Schlick + +__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); + } +} + +__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - 
cosTheta, 4.0f)); + } +} + +//------------------------------------------------------------------------ +// Frostbite diffuse + +__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + return wiScatter * woScatter * energyFactor; + } + else return 0.0f; +} + +__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + // -------------- BWD -------------- + // Backprop: return wiScatter * woScatter * energyFactor; + float d_wiScatter = d_out * woScatter * energyFactor; + float d_woScatter = d_out * wiScatter * energyFactor; + float d_energyFactor = d_out * wiScatter * woScatter; + + // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; + bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); + + // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); + float d_wiDotN = 0.0f; + bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); + + // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float d_energyBias = d_f90; + float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; + d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; + + // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; + + // Backprop: float energyBias = 0.5f * linearRoughness; + d_linearRoughness += 0.5 * d_energyBias; + + // Backprop: float wiDotH = dot(wi, h); + vec3f d_h(0); + bwdDot(wi, h, d_wi, d_h, d_wiDotH); + + // Backprop: vec3f h = safeNormalize(wo + wi); + vec3f d_wo_wi(0); + bwdSafeNormalize(wo + wi, d_wo_wi, d_h); + d_wi += d_wo_wi; d_wo += d_wo_wi; + + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + } +} + +//------------------------------------------------------------------------ +// Ndf GGX + +__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + // Torch only back propagates if clamp doesn't trigger + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - 
SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + } +} + +//------------------------------------------------------------------------ +// Lambda GGX + +__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + return res; +} + +__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + + d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); +} + +//------------------------------------------------------------------------ +// Masking GGX + +__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) +{ + // FWD eval + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + + // BWD eval + float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); + bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); + bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); +} + +//------------------------------------------------------------------------ +// GGX specular + +__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness) +{ + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + return frontfacing ? 
w : 0.0f; +} + +__device__ void bwdPbrSpecular( + const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness, + vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out) +{ + /////////////////////////////////////////////////////////////////////// + // FWD eval + + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + + if (frontfacing) + { + /////////////////////////////////////////////////////////////////////// + // BWD eval + + vec3f d_F = d_out * D * G * 0.25f / woDotN; + float d_D = sum(d_out * F * G * 0.25f / woDotN); + float d_G = sum(d_out * F * D * 0.25f / woDotN); + + float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); + + vec3f d_f90(0); + float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0); + bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F); + bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); + bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); + + vec3f d_h(0); + bwdDot(nrm, h, d_nrm, d_h, d_nDotH); + bwdDot(wo, h, d_wo, d_h, d_woDotH); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + + vec3f d_h_unnorm(0); + bwdSafeNormalize(wo + wi, d_h_unnorm, d_h); + d_wo += d_h_unnorm; + d_wi += d_h_unnorm; + + if (alpha > min_roughness * min_roughness) + d_alpha += d_alphaSqr * 2 * alpha; + } +} + +//------------------------------------------------------------------------ +// Full PBR BSDF + +__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF) +{ + vec3f wo = safeNormalize(view_pos - pos); + vec3f wi = safeNormalize(light_pos - pos); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + vec3f diffuse = diff_col * diff; + vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); + + return diffuse + specular; +} + +__device__ void bwdPbrBSDF( + const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF, + vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _wi = light_pos - pos; + vec3f _wo = view_pos - pos; + vec3f wi = safeNormalize(_wi); + vec3f wo = safeNormalize(_wo); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + + 
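// Note: the forward quantities above (wi, wo, alpha, spec_col, diff_col, diff) are recomputed here rather than cached from the forward kernel, so the backward section below can reuse them directly. +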
//////////////////////////////////////////////////////////////////////// + // BWD + + float d_alpha(0); + vec3f d_spec_col(0), d_wi(0), d_wo(0); + bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + float d_diff = sum(diff_col * d_out); + if (BSDF == 0) + bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); + else + bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff); + + // Backprop: diff_col = kd * (1.0f - arm.z) + vec3f d_diff_col = d_out * diff; + d_kd += d_diff_col * (1.0f - arm.z); + d_arm.z -= sum(d_diff_col * kd); + + // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) + d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; + d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f)); + d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f)); + + // Backprop: alpha = arm.y * arm.y + d_arm.y += d_alpha * 2 * arm.y; + + // Backprop: vec3f wi = safeNormalize(light_pos - pos); + vec3f d__wi(0); + bwdSafeNormalize(_wi, d__wi, d_wi); + d_light_pos += d__wi; + d_pos -= d__wi; + + // Backprop: vec3f wo = safeNormalize(view_pos - pos); + vec3f d__wo(0); + bwdSafeNormalize(_wo, d__wo, d_wo); + d_view_pos += d__wo; + d_pos -= d__wo; +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void LambertFwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + + float res = fwdLambert(nrm, wi); + + p.out.store(px, py, pz, res); +} + +__global__ void LambertBwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + vec3f d_nrm(0), d_wi(0); + bwdLambert(nrm, wi, d_nrm, d_wi, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); +} + +__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + + float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness); + + p.out.store(px, py, pz, res); +} + +__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. 
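+ // One thread per element: x/y come from the block/thread indices and z from the outermost grid dimension; threads outside the launch bounds return early.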
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_linearRoughness = 0.0f; + vec3f d_nrm(0), d_wi(0), d_wo(0); + bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); + p.wo.store_grad(px, py, pz, d_wo); + p.linearRoughness.store_grad(px, py, pz, d_linearRoughness); +} + +__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + + vec3f res = fwdFresnelSchlick(f0, f90, cosTheta); + p.out.store(px, py, pz, res); +} + +__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_f0(0), d_f90(0); + float d_cosTheta(0); + bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out); + + p.f0.store_grad(px, py, pz, d_f0); + p.f90.store_grad(px, py, pz, d_f90); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void ndfGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdNdfGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void ndfGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void lambdaGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdLambdaGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void lambdaGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void maskingSmithFwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO); + + p.out.store(px, py, pz, res); +} + +__global__ void maskingSmithBwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0); + bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosThetaI.store_grad(px, py, pz, d_cosThetaI); + p.cosThetaO.store_grad(px, py, pz, d_cosThetaO); +} + +__global__ void pbrSpecularFwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + + vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness); + + p.out.store(px, py, pz, res); +} + +__global__ void pbrSpecularBwdKernel(PbrSpecular p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + float d_alpha(0); + vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0); + bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + p.col.store_grad(px, py, pz, d_col); + p.nrm.store_grad(px, py, pz, d_nrm); + p.wo.store_grad(px, py, pz, d_wo); + p.wi.store_grad(px, py, pz, d_wi); + p.alpha.store_grad(px, py, pz, d_alpha); +} + +__global__ void pbrBSDFFwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + + vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF); + + p.out.store(px, py, pz, res); +} +__global__ void pbrBSDFBwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0); + bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out); + + p.kd.store_grad(px, py, pz, d_kd); + p.arm.store_grad(px, py, pz, d_arm); + p.pos.store_grad(px, py, pz, d_pos); + p.nrm.store_grad(px, py, pz, d_nrm); + p.view_pos.store_grad(px, py, pz, d_view_pos); + p.light_pos.store_grad(px, py, pz, d_light_pos); +} diff --git a/nvdiffrec/lib/render/renderutils/bsdf.py b/nvdiffrec/lib/render/renderutils/bsdf.py new file mode 100644 index 0000000..ddc67a9 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/bsdf.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
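+ 
+# NOTE: these functions appear to be pure-PyTorch counterparts of the CUDA kernels in bsdf.cu, usable when the custom ops are bypassed (cf. the use_python argument in render.py's shade()).
+# Minimal usage sketch (hypothetical variable names; not part of the original file):
+#   wi = _safe_normalize(light_pos - pos); wo = _safe_normalize(view_pos - pos)
+#   diffuse = kd * bsdf_lambert(nrm, wi)   # clamped cos(theta) / pi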
+ +import math +import torch + +NORMAL_THRESHOLD = 0.1 + +################################################################################ +# Vector utility functions +################################################################################ + +def _dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + +def _reflect(x, n): + return 2*_dot(x, n)*n - x + +def _safe_normalize(x): + return torch.nn.functional.normalize(x, dim = -1) + +def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): + # Swap normal direction for backfacing surfaces + if two_sided_shading: + smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) + geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) + + t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) + return torch.lerp(geom_nrm, smooth_nrm, t) + + +def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): + smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm)) + if opengl: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + else: + shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) + return _safe_normalize(shading_nrm) + +def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): + smooth_nrm = _safe_normalize(smooth_nrm) + view_vec = _safe_normalize(view_pos - pos) + if smooth_tng is None: + shading_nrm = smooth_nrm + else: + smooth_tng = _safe_normalize(smooth_tng) + shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) + return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) + +################################################################################ +# Simple lambertian diffuse BSDF +################################################################################ + +def bsdf_lambert(nrm, wi): + return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi + +################################################################################ +# Frostbite diffuse +################################################################################ + +def bsdf_frostbite(nrm, wi, wo, linearRoughness): + wiDotN = _dot(wi, nrm) + woDotN = _dot(wo, nrm) + + h = _safe_normalize(wo + wi) + wiDotH = _dot(wi, h) + + energyBias = 0.5 * linearRoughness + energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness + f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness + f0 = 1.0 + + wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN) + woScatter = bsdf_fresnel_shlick(f0, f90, woDotN) + res = wiScatter * woScatter * energyFactor + return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res)) + +################################################################################ +# Phong specular, loosely based on mitsuba implementation +################################################################################ + +def bsdf_phong(nrm, wo, wi, N): + dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0) + dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0) + return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi) + +################################################################################ +# PBR's implementation of GGX specular +################################################################################ + +specular_epsilon = 1e-4 + 
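A small numeric sanity check of the lobes defined above, assuming bsdf_lambert and bsdf_phong are in scope (for example by appending this to the module or importing it): head-on Lambert evaluates to 1/pi, Lambert at 60 degrees incidence halves, and the Phong lobe peaks at (N + 2) / (2 * pi).

import math
import torch

z = torch.tensor([[0.0, 0.0, 1.0]])                    # shared normal / view / light direction
wi_60 = torch.tensor([[math.sin(math.radians(60.0)), 0.0, math.cos(math.radians(60.0))]])

print(bsdf_lambert(z, z))       # ~0.3183 = 1/pi
print(bsdf_lambert(z, wi_60))   # ~0.1592 = cos(60 deg) / pi
print(bsdf_phong(z, z, z, 16))  # ~2.8648 = (16 + 2) / (2*pi) at the lobe peak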
+def bsdf_fresnel_shlick(f0, f90, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0 + +def bsdf_ndf_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1 + return alphaSqr / (d * d * math.pi) + +def bsdf_lambda_ggx(alphaSqr, cosTheta): + _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) + cosThetaSqr = _cosTheta * _cosTheta + tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr + res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0) + return res + +def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO): + lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI) + lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO) + return 1 / (1 + lambdaI + lambdaO) + +def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08): + _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0) + alphaSqr = _alpha * _alpha + + h = _safe_normalize(wo + wi) + woDotN = _dot(wo, nrm) + wiDotN = _dot(wi, nrm) + woDotH = _dot(wo, h) + nDotH = _dot(nrm, h) + + D = bsdf_ndf_ggx(alphaSqr, nDotH) + G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN) + F = bsdf_fresnel_shlick(col, 1, woDotH) + + w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon) + + frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon) + return torch.where(frontfacing, w, torch.zeros_like(w)) + +def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + wo = _safe_normalize(view_pos - pos) + wi = _safe_normalize(light_pos - pos) + + spec_str = arm[..., 0:1] # x component + roughness = arm[..., 1:2] # y component + metallic = arm[..., 2:3] # z component + ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str) + kd = kd * (1.0 - metallic) + + if BSDF == 0: + diffuse = kd * bsdf_lambert(nrm, wi) + else: + diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness) + specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness) + return diffuse + specular diff --git a/nvdiffrec/lib/render/renderutils/c_src/bsdf.cu b/nvdiffrec/lib/render/renderutils/c_src/bsdf.cu new file mode 100644 index 0000000..c167214 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/bsdf.cu @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
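The Python bsdf_pbr composition above can be exercised end to end with ordinary autograd, which is what the fused CUDA path replaces for speed. A short sketch, assuming bsdf.py is importable under the name used below (the import path is illustrative; point it at wherever the module sits):

import torch
from bsdf import bsdf_pbr  # illustrative import path for the module above

B, H, W = 1, 4, 4
kd  = torch.rand(B, H, W, 3, requires_grad=True)
arm = torch.rand(B, H, W, 3, requires_grad=True)        # x: spec_str, y: roughness, z: metallic
pos = torch.zeros(B, H, W, 3)
nrm = torch.tensor([0.0, 0.0, 1.0]).expand(B, H, W, 3)
view_pos  = torch.tensor([0.0, 0.0, 3.0]).expand(B, H, W, 3)
light_pos = torch.tensor([1.0, 1.0, 3.0]).expand(B, H, W, 3)

shaded = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, BSDF=0)
shaded.sum().backward()
print(shaded.shape, kd.grad.abs().mean().item(), arm.grad.abs().mean().item())

With BSDF=0 this takes the Lambert diffuse branch; BSDF=1 switches to the Frostbite term, matching the `if BSDF == 0` dispatch above.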
+ */ + +#include "common.h" +#include "bsdf.h" + +#define SPECULAR_EPSILON 1e-4f + +//------------------------------------------------------------------------ +// Lambert functions + +__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) +{ + return max(dot(nrm, wi) / M_PI, 0.0f); +} + +__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) +{ + if (dot(nrm, wi) > 0.0f) + bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); +} + +//------------------------------------------------------------------------ +// Fresnel Schlick + +__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); + } +} + +__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = powf(1.0f - _cosTheta, 5.0f); + return f0 * (1.0f - scale) + f90 * scale; +} + +__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); + d_f0 += d_out * (1.0 - scale); + d_f90 += d_out * scale; + if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); + } +} + +//------------------------------------------------------------------------ +// Frostbite diffuse + +__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + return wiScatter * woScatter * energyFactor; + } + else return 0.0f; +} + +__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) +{ + float wiDotN = dot(wi, nrm); + float woDotN = dot(wo, nrm); + + if (wiDotN > 0.0f && woDotN > 0.0f) + { + vec3f h = safeNormalize(wo + wi); + float wiDotH = dot(wi, h); + + float energyBias = 0.5f * linearRoughness; + float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float f0 = 1.f; + + float 
wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); + float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + + // -------------- BWD -------------- + // Backprop: return wiScatter * woScatter * energyFactor; + float d_wiScatter = d_out * woScatter * energyFactor; + float d_woScatter = d_out * wiScatter * energyFactor; + float d_energyFactor = d_out * wiScatter * woScatter; + + // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN); + float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; + bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); + + // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); + float d_wiDotN = 0.0f; + bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); + + // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; + float d_energyBias = d_f90; + float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; + d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; + + // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; + d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; + + // Backprop: float energyBias = 0.5f * linearRoughness; + d_linearRoughness += 0.5 * d_energyBias; + + // Backprop: float wiDotH = dot(wi, h); + vec3f d_h(0); + bwdDot(wi, h, d_wi, d_h, d_wiDotH); + + // Backprop: vec3f h = safeNormalize(wo + wi); + vec3f d_wo_wi(0); + bwdSafeNormalize(wo + wi, d_wo_wi, d_h); + d_wi += d_wo_wi; d_wo += d_wo_wi; + + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + } +} + +//------------------------------------------------------------------------ +// Ndf GGX + +__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + // Torch only back propagates if clamp doesn't trigger + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) + { + d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); + } +} + +//------------------------------------------------------------------------ +// Lambda GGX + +__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + return res; +} + +__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) +{ + float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); + float cosThetaSqr = _cosTheta * _cosTheta; + float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; + float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); + + d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); + if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - 
SPECULAR_EPSILON) + d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); +} + +//------------------------------------------------------------------------ +// Masking GGX + +__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) +{ + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + return 1.0f / (1.0f + lambdaI + lambdaO); +} + +__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) +{ + // FWD eval + float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); + float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); + + // BWD eval + float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); + bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); + bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); +} + +//------------------------------------------------------------------------ +// GGX specular + +__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness) +{ + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + return frontfacing ? 
w : 0.0f; +} + +__device__ void bwdPbrSpecular( + const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness, + vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out) +{ + /////////////////////////////////////////////////////////////////////// + // FWD eval + + float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); + float alphaSqr = _alpha * _alpha; + + vec3f h = safeNormalize(wo + wi); + float woDotN = dot(wo, nrm); + float wiDotN = dot(wi, nrm); + float woDotH = dot(wo, h); + float nDotH = dot(nrm, h); + + float D = fwdNdfGGX(alphaSqr, nDotH); + float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); + vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH); + vec3f w = F * D * G * 0.25 / woDotN; + bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); + + if (frontfacing) + { + /////////////////////////////////////////////////////////////////////// + // BWD eval + + vec3f d_F = d_out * D * G * 0.25f / woDotN; + float d_D = sum(d_out * F * G * 0.25f / woDotN); + float d_G = sum(d_out * F * D * 0.25f / woDotN); + + float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); + + vec3f d_f90(0); + float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0); + bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F); + bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); + bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); + + vec3f d_h(0); + bwdDot(nrm, h, d_nrm, d_h, d_nDotH); + bwdDot(wo, h, d_wo, d_h, d_woDotH); + bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); + bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); + + vec3f d_h_unnorm(0); + bwdSafeNormalize(wo + wi, d_h_unnorm, d_h); + d_wo += d_h_unnorm; + d_wi += d_h_unnorm; + + if (alpha > min_roughness * min_roughness) + d_alpha += d_alphaSqr * 2 * alpha; + } +} + +//------------------------------------------------------------------------ +// Full PBR BSDF + +__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF) +{ + vec3f wo = safeNormalize(view_pos - pos); + vec3f wi = safeNormalize(light_pos - pos); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + vec3f diffuse = diff_col * diff; + vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); + + return diffuse + specular; +} + +__device__ void bwdPbrBSDF( + const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF, + vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _wi = light_pos - pos; + vec3f _wo = view_pos - pos; + vec3f wi = safeNormalize(_wi); + vec3f wo = safeNormalize(_wo); + + float alpha = arm.y * arm.y; + vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); + vec3f diff_col = kd * (1.0f - arm.z); + float diff = 0.0f; + if (BSDF == 0) + diff = fwdLambert(nrm, wi); + else + diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y); + + 
//////////////////////////////////////////////////////////////////////// + // BWD + + float d_alpha(0); + vec3f d_spec_col(0), d_wi(0), d_wo(0); + bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + float d_diff = sum(diff_col * d_out); + if (BSDF == 0) + bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); + else + bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff); + + // Backprop: diff_col = kd * (1.0f - arm.z) + vec3f d_diff_col = d_out * diff; + d_kd += d_diff_col * (1.0f - arm.z); + d_arm.z -= sum(d_diff_col * kd); + + // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) + d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; + d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f)); + d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f)); + + // Backprop: alpha = arm.y * arm.y + d_arm.y += d_alpha * 2 * arm.y; + + // Backprop: vec3f wi = safeNormalize(light_pos - pos); + vec3f d__wi(0); + bwdSafeNormalize(_wi, d__wi, d_wi); + d_light_pos += d__wi; + d_pos -= d__wi; + + // Backprop: vec3f wo = safeNormalize(view_pos - pos); + vec3f d__wo(0); + bwdSafeNormalize(_wo, d__wo, d_wo); + d_view_pos += d__wo; + d_pos -= d__wo; +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void LambertFwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + + float res = fwdLambert(nrm, wi); + + p.out.store(px, py, pz, res); +} + +__global__ void LambertBwdKernel(LambertKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + vec3f d_nrm(0), d_wi(0); + bwdLambert(nrm, wi, d_nrm, d_wi, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); +} + +__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + + float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness); + + p.out.store(px, py, pz, res); +} + +__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + float linearRoughness = p.linearRoughness.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_linearRoughness = 0.0f; + vec3f d_nrm(0), d_wi(0), d_wo(0); + bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out); + + p.nrm.store_grad(px, py, pz, d_nrm); + p.wi.store_grad(px, py, pz, d_wi); + p.wo.store_grad(px, py, pz, d_wo); + p.linearRoughness.store_grad(px, py, pz, d_linearRoughness); +} + +__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + + vec3f res = fwdFresnelSchlick(f0, f90, cosTheta); + p.out.store(px, py, pz, res); +} + +__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f f0 = p.f0.fetch3(px, py, pz); + vec3f f90 = p.f90.fetch3(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_f0(0), d_f90(0); + float d_cosTheta(0); + bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out); + + p.f0.store_grad(px, py, pz, d_f0); + p.f90.store_grad(px, py, pz, d_f90); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void ndfGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdNdfGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void ndfGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void lambdaGGXFwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float res = fwdLambdaGGX(alphaSqr, cosTheta); + + p.out.store(px, py, pz, res); +} + +__global__ void lambdaGGXBwdKernel(NdfGGXParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosTheta = p.cosTheta.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosTheta(0); + bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosTheta.store_grad(px, py, pz, d_cosTheta); +} + +__global__ void maskingSmithFwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO); + + p.out.store(px, py, pz, res); +} + +__global__ void maskingSmithBwdKernel(MaskingSmithParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + float alphaSqr = p.alphaSqr.fetch1(px, py, pz); + float cosThetaI = p.cosThetaI.fetch1(px, py, pz); + float cosThetaO = p.cosThetaO.fetch1(px, py, pz); + float d_out = p.out.fetch1(px, py, pz); + + float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0); + bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out); + + p.alphaSqr.store_grad(px, py, pz, d_alphaSqr); + p.cosThetaI.store_grad(px, py, pz, d_cosThetaI); + p.cosThetaO.store_grad(px, py, pz, d_cosThetaO); +} + +__global__ void pbrSpecularFwdKernel(PbrSpecular p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + + vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness); + + p.out.store(px, py, pz, res); +} + +__global__ void pbrSpecularBwdKernel(PbrSpecular p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f col = p.col.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f wo = p.wo.fetch3(px, py, pz); + vec3f wi = p.wi.fetch3(px, py, pz); + float alpha = p.alpha.fetch1(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + float d_alpha(0); + vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0); + bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out); + + p.col.store_grad(px, py, pz, d_col); + p.nrm.store_grad(px, py, pz, d_nrm); + p.wo.store_grad(px, py, pz, d_wo); + p.wi.store_grad(px, py, pz, d_wi); + p.alpha.store_grad(px, py, pz, d_alpha); +} + +__global__ void pbrBSDFFwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + + vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF); + + p.out.store(px, py, pz, res); +} +__global__ void pbrBSDFBwdKernel(PbrBSDF p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f kd = p.kd.fetch3(px, py, pz); + vec3f arm = p.arm.fetch3(px, py, pz); + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f nrm = p.nrm.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f light_pos = p.light_pos.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0); + bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out); + + p.kd.store_grad(px, py, pz, d_kd); + p.arm.store_grad(px, py, pz, d_arm); + p.pos.store_grad(px, py, pz, d_pos); + p.nrm.store_grad(px, py, pz, d_nrm); + p.view_pos.store_grad(px, py, pz, d_view_pos); + p.light_pos.store_grad(px, py, pz, d_light_pos); +} diff --git a/nvdiffrec/lib/render/renderutils/c_src/bsdf.h b/nvdiffrec/lib/render/renderutils/c_src/bsdf.h new file mode 100644 index 0000000..59adbf0 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/bsdf.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
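The backward functions in bsdf.cu above are hand-derived closed forms (bwdNdfGGX, bwdLambdaGGX, and so on). A cheap, dependency-free way to validate such a derivation is a central-difference check; a minimal sketch for the two bwdNdfGGX partials:

import math

def ndf_ggx(a2, c):
    # fwdNdfGGX: a2 / (pi * d^2), with d = (a2 - 1) * c^2 + 1
    d = (c * a2 - c) * c + 1.0
    return a2 / (d * d * math.pi)

def ndf_ggx_grads(a2, c):
    # The closed forms used in bwdNdfGGX.
    d = (a2 - 1.0) * c * c + 1.0
    d_a2 = (1.0 - (a2 + 1.0) * c * c) / (math.pi * d ** 3)
    d_c = -(4.0 * (a2 - 1.0) * a2 * c) / (math.pi * d ** 3)
    return d_a2, d_c

a2, c, eps = 0.25, 0.7, 1e-5
num_a2 = (ndf_ggx(a2 + eps, c) - ndf_ggx(a2 - eps, c)) / (2 * eps)
num_c = (ndf_ggx(a2, c + eps) - ndf_ggx(a2, c - eps)) / (2 * eps)
ana_a2, ana_c = ndf_ggx_grads(a2, c)
print(abs(num_a2 - ana_a2) < 1e-6, abs(num_c - ana_c) < 1e-6)  # True True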
+ */ + +#pragma once + +#include "common.h" + +struct LambertKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor out; + dim3 gridSize; +}; + +struct FrostbiteDiffuseKernelParams +{ + Tensor nrm; + Tensor wi; + Tensor wo; + Tensor linearRoughness; + Tensor out; + dim3 gridSize; +}; + +struct FresnelShlickKernelParams +{ + Tensor f0; + Tensor f90; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct NdfGGXParams +{ + Tensor alphaSqr; + Tensor cosTheta; + Tensor out; + dim3 gridSize; +}; + +struct MaskingSmithParams +{ + Tensor alphaSqr; + Tensor cosThetaI; + Tensor cosThetaO; + Tensor out; + dim3 gridSize; +}; + +struct PbrSpecular +{ + Tensor col; + Tensor nrm; + Tensor wo; + Tensor wi; + Tensor alpha; + Tensor out; + dim3 gridSize; + float min_roughness; +}; + +struct PbrBSDF +{ + Tensor kd; + Tensor arm; + Tensor pos; + Tensor nrm; + Tensor view_pos; + Tensor light_pos; + Tensor out; + dim3 gridSize; + float min_roughness; + int BSDF; +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/common.cpp b/nvdiffrec/lib/render/renderutils/c_src/common.cpp new file mode 100644 index 0000000..445895e --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/common.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include +#include + +//------------------------------------------------------------------------ +// Block and grid size calculators for kernel launches. + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) +{ + int maxThreads = maxWidth * maxHeight; + if (maxThreads <= 1 || (dims.x * dims.y) <= 1) + return dim3(1, 1, 1); // Degenerate. + + // Start from max size. + int bw = maxWidth; + int bh = maxHeight; + + // Optimizations for weirdly sized buffers. + if (dims.x < bw) + { + // Decrease block width to smallest power of two that covers the buffer width. + while ((bw >> 1) >= dims.x) + bw >>= 1; + + // Maximize height. + bh = maxThreads / bw; + if (bh > dims.y) + bh = dims.y; + } + else if (dims.y < bh) + { + // Halve height and double width until fits completely inside buffer vertically. + while (bh > dims.y) + { + bh >>= 1; + if (bw < dims.x) + bw <<= 1; + } + } + + // Done. + return dim3(bw, bh, 1); +} + +// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. 
__shfl_xor_sync) +dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + std::min(blockSize.x, 32u), + std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), + std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) + ); +} + +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) +{ + dim3 gridSize; + gridSize.x = (dims.x - 1) / blockSize.x + 1; + gridSize.y = (dims.y - 1) / blockSize.y + 1; + gridSize.z = (dims.z - 1) / blockSize.z + 1; + return gridSize; +} + +//------------------------------------------------------------------------ diff --git a/nvdiffrec/lib/render/renderutils/c_src/common.h b/nvdiffrec/lib/render/renderutils/c_src/common.h new file mode 100644 index 0000000..5abaeeb --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/common.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#include +#include + +#include "vec3f.h" +#include "vec4f.h" +#include "tensor.h" + +dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); +dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); + +#ifdef __CUDACC__ + +#ifdef _MSC_VER +#define M_PI 3.14159265358979323846f +#endif + +__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) +{ + return dim3( + min(blockSize.x, 32u), + min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), + min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) + ); +} + +__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } +#else +dim3 getWarpSize(dim3 blockSize); +#endif \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/cubemap.cu b/nvdiffrec/lib/render/renderutils/c_src/cubemap.cu new file mode 100644 index 0000000..2ce21d8 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/cubemap.cu @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
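The launch helpers in common.cpp / common.h above reduce to simple integer arithmetic: getLaunchGridSize is a ceiling division of the buffer size by the block size, and getWarpSize carves a warp-shaped sub-block (at most 32 threads) out of a block so that kernels can do __shfl_xor_sync reductions. The same arithmetic in Python, for an 8x8x1 block over a 512x512x6 buffer:

def launch_grid_size(block, dims):
    # Mirrors getLaunchGridSize: ceil(dims / block) per axis.
    return tuple((d - 1) // b + 1 for b, d in zip(block, dims))

def warp_size(block):
    # Mirrors getWarpSize: a sub-block of at most 32 threads, widest in x.
    bx, by, bz = block
    return (min(bx, 32),
            min(max(32 // bx, 1), min(32, by)),
            min(max(32 // (bx * by), 1), min(32, bz)))

block = (8, 8, 1)
print(launch_grid_size(block, (512, 512, 6)))  # (64, 64, 6)
print(warp_size(block))                        # (8, 4, 1)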
+ */ + +#include "common.h" +#include "cubemap.h" +#include + +// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf +__device__ float pixel_area(int x, int y, int N) +{ + if (N > 1) + { + int H = N / 2; + x = abs(x - H); + y = abs(y - H); + float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H); + float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H); + return dx * dy; + } + else + return 1; +} + +__device__ vec3f cube_to_dir(int x, int y, int side, int N) +{ + float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f; + float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f; + switch (side) + { + case 0: return safeNormalize(vec3f(1, -fy, -fx)); + case 1: return safeNormalize(vec3f(-1, -fy, fx)); + case 2: return safeNormalize(vec3f(fx, 1, fy)); + case 3: return safeNormalize(vec3f(fx, -1, -fy)); + case 4: return safeNormalize(vec3f(fx, -fy, 1)); + case 5: return safeNormalize(vec3f(-fx, -fy, -1)); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ vec3f dir_to_side(int side, vec3f v) +{ + switch (side) + { + case 0: return vec3f(-v.z, -v.y, v.x); + case 1: return vec3f( v.z, -v.y, -v.x); + case 2: return vec3f( v.x, v.z, v.y); + case 3: return vec3f( v.x, -v.z, -v.y); + case 4: return vec3f( v.x, -v.y, v.z); + case 5: return vec3f(-v.x, -v.y, -v.z); + } + return vec3f(0,0,0); // Unreachable +} + +__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max) +{ + float l = sqrtf(x * x + z * z); + float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l; + float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l; + if (pzl <= 0.00001f) + _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX; + else + _min = pxl / pzl; + if (pzr <= 0.00001f) + _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX; + else + _max = pxr / pzr; +} + +__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax) +{ + vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1 + + if (theta < 0.785398f) // PI/4 + { + float xmin, xmax, ymin, ymax; + extents_1d(c.x, c.z, theta, xmin, xmax); + extents_1d(c.y, c.z, theta, ymin, ymax); + + if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f) + { + _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb + } + else + { + _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1)); + } + } + else + { + _xmin = 0.0f; + _xmax = (float)(N-1); + _ymin = 0.0f; + _ymax = (float)(N-1); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Diffuse kernel +__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. 
+ int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + + vec3f col(0); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + col += p.cubemap.fetch3(x, y, s) * w; + } + } + } + + p.out.store(px, py, pz, col); +} + +__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f N = cube_to_dir(px, py, pz, Npx); + vec3f grad = p.out.fetch3(px, py, pz); + + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + for (int y = 0; y < Npx; ++y) + { + for (int x = 0; x < Npx; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + float costheta = min(max(dot(N, L), 0.0f), 0.999f); + float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////// +// GGX splitsum kernel + +__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta) +{ + float _cosTheta = clamp(cosTheta, 0.0, 1.0f); + float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; + return alphaSqr / (d * d * M_PI); +} + +__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p) +{ + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.gridSize.x; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + const int TILE_SIZE = 16; + + // Brute force entire cubemap and compute bounds for the cone + for (int s = 0; s < p.gridSize.z; ++s) + { + // Assume empty BBox + int _min_x = p.gridSize.x - 1, _max_x = 0; + int _min_y = p.gridSize.y - 1, _max_y = 0; + + // For each (8x8) tile + for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++) + { + for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++) + { + // Compute tile extents + int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE; + int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y); + + // Use some blunt interval arithmetics to cull tiles + vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx); + vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx); + + float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x)); + float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y)); + float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, 
L1.z), max(L2.z, L3.z)); + + float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z); + if (maxdp >= p.costheta_cutoff) + { + // Test all pixels in tile. + for (int y = tsy; y < tey; ++y) + { + for (int x = tsx; x < tex; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + _min_x = min(_min_x, x); + _max_x = max(_max_x, x); + _min_y = min(_min_y, y); + _max_y = max(_max_y, y); + } + } + } + } + } + } + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y); + p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y); + } +} + +__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. + int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + float wsum = 0.0f; + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + col += p.cubemap.fetch3(x, y, s) * w; + wsum += w; + } + } + } + } + } + + p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x); + p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y); + p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z); + p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum); +} + +__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p) +{ + // Calculate pixel position. 
+ int px = blockIdx.x * blockDim.x + threadIdx.x; + int py = blockIdx.y * blockDim.y + threadIdx.y; + int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + int Npx = p.cubemap.dims[1]; + vec3f VNR = cube_to_dir(px, py, pz, Npx); + + vec3f grad = p.out.fetch3(px, py, pz); + + float alpha = p.roughness * p.roughness; + float alphaSqr = alpha * alpha; + + vec3f col(0); + for (int s = 0; s < p.cubemap.dims[0]; ++s) + { + int xmin, xmax, ymin, ymax; + xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0)); + xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1)); + ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2)); + ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3)); + + if (xmin <= xmax) + { + for (int y = ymin; y <= ymax; ++y) + { + for (int x = xmin; x <= xmax; ++x) + { + vec3f L = cube_to_dir(x, y, s, Npx); + if (dot(L, VNR) >= p.costheta_cutoff) + { + vec3f H = safeNormalize(L + VNR); + + float wiDotN = max(dot(L, VNR), 0.0f); + float VNRDotH = max(dot(VNR, H), 0.0f); + + float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f; + + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w); + atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w); + } + } + } + } + } +} diff --git a/nvdiffrec/lib/render/renderutils/c_src/cubemap.h b/nvdiffrec/lib/render/renderutils/c_src/cubemap.h new file mode 100644 index 0000000..f395cc2 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/cubemap.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct DiffuseCubemapKernelParams +{ + Tensor cubemap; + Tensor out; + dim3 gridSize; +}; + +struct SpecularCubemapKernelParams +{ + Tensor cubemap; + Tensor bounds; + Tensor out; + dim3 gridSize; + float costheta_cutoff; + float roughness; +}; + +struct SpecularBoundsKernelParams +{ + float costheta_cutoff; + Tensor out; + dim3 gridSize; +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/loss.cu b/nvdiffrec/lib/render/renderutils/c_src/loss.cu new file mode 100644 index 0000000..aae5272 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/loss.cu @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
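DiffuseCubemapFwdKernel above pre-integrates the cubemap by brute force: for each output direction N it walks every source texel and weights its radiance by max(dot(N, L), 0) * pixel_area(x, y, Npx) / pi. Restated in plain Python for a single pair of texels (a sketch of the weight only, not the full accumulation):

import math

def cube_to_dir(x, y, side, n):
    # Texel center (x, y) on cube face `side` mapped to a unit direction, matching the switch above.
    fx = 2.0 * ((x + 0.5) / n) - 1.0
    fy = 2.0 * ((y + 0.5) / n) - 1.0
    v = {0: (1.0, -fy, -fx), 1: (-1.0, -fy, fx),
         2: (fx, 1.0, fy),   3: (fx, -1.0, -fy),
         4: (fx, -fy, 1.0),  5: (-fx, -fy, -1.0)}[side]
    l = math.sqrt(sum(c * c for c in v))
    return tuple(c / l for c in v)

def pixel_area(x, y, n):
    # Approximate solid angle of a texel, as in the kernel's pixel_area().
    if n <= 1:
        return 1.0
    h = n // 2
    x, y = abs(x - h), abs(y - h)
    dx = math.atan((x + 1) / h) - math.atan(x / h)
    dy = math.atan((y + 1) / h) - math.atan(y / h)
    return dx * dy

npx = 16
n_dir = cube_to_dir(8, 8, 4, npx)   # output texel's direction on the +z face
l_dir = cube_to_dir(3, 5, 4, npx)   # one source texel on the same face
cos_t = min(max(sum(a * b for a, b in zip(n_dir, l_dir)), 0.0), 0.999)
print(cos_t * pixel_area(3, 5, npx) / math.pi)  # that texel's contribution weight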
+ */ + +#include + +#include "common.h" +#include "loss.h" + +//------------------------------------------------------------------------ +// Utils + +__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; } + +__device__ float warpSum(float val) { + for (int i = 1; i < 32; i *= 2) + val += __shfl_xor_sync(0xFFFFFFFF, val, i); + return val; +} + +//------------------------------------------------------------------------ +// Tonemapping + +__device__ inline float fwdSRGB(float x) +{ + return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f); +} + +__device__ inline void bwdSRGB(float x, float &d_x, float d_out) +{ + if (x > 0.0031308f) + d_x += d_out * 0.439583f / powf(x, 0.583333f); + else if (x > 0.0f) + d_x += d_out * 12.92f; +} + +__device__ inline vec3f fwdTonemapLogSRGB(vec3f x) +{ + return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f))); +} + +__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out) +{ + if (x.x > 0.0f && x.x < 65535.0f) + { + bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x); + d_x.x *= 1 / (x.x + 1.0f); + } + if (x.y > 0.0f && x.y < 65535.0f) + { + bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y); + d_x.y *= 1 / (x.y + 1.0f); + } + if (x.z > 0.0f && x.z < 65535.0f) + { + bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z); + d_x.z *= 1 / (x.z + 1.0f); + } +} + +__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f) +{ + return (img - target) * (img - target) / (img * img + target * target + eps); +} + +__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f) +{ + float denom = (target * target + img * img + eps); + d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom); + d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom); +} + +__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f) +{ + return abs(img - target) / (img + target + eps); +} + +__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f) +{ + float denom = (target + img + eps); + d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom); + d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom); +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void imgLossFwdKernel(LossKernelParams p) +{ + // Calculate pixel position. 
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + float floss = 0.0f; + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) + { + vec3f img = p.img.fetch3(px, py, pz); + vec3f target = p.target.fetch3(px, py, pz); + + img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); + target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + vec3f vloss(0); + if (p.loss == LOSS_MSE) + vloss = (img - target) * (img - target); + else if (p.loss == LOSS_RELMSE) + vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); + else if (p.loss == LOSS_SMAPE) + vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); + else + vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); + + floss = sum(vloss) / 3.0f; + } + + floss = warpSum(floss); + + dim3 warpSize = getWarpSize(blockDim); + if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) + p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); +} + +__global__ void imgLossBwdKernel(LossKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + dim3 warpSize = getWarpSize(blockDim); + + vec3f _img = p.img.fetch3(px, py, pz); + vec3f _target = p.target.fetch3(px, py, pz); + float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); + + ///////////////////////////////////////////////////////////////////// + // FWD + + vec3f img = _img, target = _target; + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + img = fwdTonemapLogSRGB(img); + target = fwdTonemapLogSRGB(target); + } + + ///////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; + + vec3f d_img(0), d_target(0); + if (p.loss == LOSS_MSE) + { + d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); + d_target = -d_img; + } + else if (p.loss == LOSS_RELMSE) + { + bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else if (p.loss == LOSS_SMAPE) + { + bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); + bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); + bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); + } + else + { + d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); + d_target = -d_img; + } + + + if (p.tonemapper == TONEMAPPER_LOG_SRGB) + { + vec3f d__img(0), d__target(0); + bwdTonemapLogSRGB(_img, d__img, d_img); + bwdTonemapLogSRGB(_target, d__target, d_target); + d_img = d__img; d_target = d__target; + } + + if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; + if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; + if (_img.z <= 
0.0f || _img.z >= 65535.0f) d_img.z = 0; + if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; + if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0; + if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; + + p.img.store_grad(px, py, pz, d_img); + p.target.store_grad(px, py, pz, d_target); +} \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/loss.h b/nvdiffrec/lib/render/renderutils/c_src/loss.h new file mode 100644 index 0000000..26790bf --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/loss.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +enum TonemapperType +{ + TONEMAPPER_NONE = 0, + TONEMAPPER_LOG_SRGB = 1 +}; + +enum LossType +{ + LOSS_L1 = 0, + LOSS_MSE = 1, + LOSS_RELMSE = 2, + LOSS_SMAPE = 3 +}; + +struct LossKernelParams +{ + Tensor img; + Tensor target; + Tensor out; + dim3 gridSize; + TonemapperType tonemapper; + LossType loss; +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/mesh.cu b/nvdiffrec/lib/render/renderutils/c_src/mesh.cu new file mode 100644 index 0000000..3690ea3 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/mesh.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
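The device functions in loss.cu above are per-channel scalar maps: the kernels clamp to [0, 65535], optionally tonemap, evaluate the loss per channel, average over RGB, and warp-reduce. The RELMSE / SMAPE terms and the log-sRGB tonemap restated for a single channel:

import math

def srgb(x):
    # fwdSRGB
    return 12.92 * max(x, 0.0) if x <= 0.0031308 else 1.055 * max(x, 0.0031308) ** (1.0 / 2.4) - 0.055

def tonemap_log_srgb(x):
    # fwdTonemapLogSRGB, one channel
    return srgb(math.log(x + 1.0))

def relmse(img, target, eps=0.1):
    return (img - target) ** 2 / (img * img + target * target + eps)

def smape(img, target, eps=0.01):
    return abs(img - target) / (img + target + eps)

img, ref = 1.25, 1.0
print(relmse(img, ref))                              # ~0.0235
print(smape(img, ref))                               # ~0.1106
print(tonemap_log_srgb(img), tonemap_log_srgb(ref))  # HDR values compressed before the loss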
+ */ + +#include +#include + +#include "common.h" +#include "mesh.h" + + +//------------------------------------------------------------------------ +// Kernels + +__global__ void xfmPointsFwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + if (p.isPoints) + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]); + p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]); + } + else + { + p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]); + p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]); + p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]); + } +} + +__global__ void xfmPointsBwdKernel(XfmKernelParams p) +{ + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; + + __shared__ float mtx[4][4]; + if (threadIdx.x < 16) + mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); + __syncthreads(); + + if (px >= p.gridSize.x) + return; + + vec3f pos( + p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), + p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) + ); + + vec4f d_out( + p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)), + p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0)) + ); + + if (p.isPoints) + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]); + } + else + { + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]); + p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]); + } +} \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/mesh.h b/nvdiffrec/lib/render/renderutils/c_src/mesh.h new file mode 100644 index 0000000..16e2166 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/mesh.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 
2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +#include "common.h" + +struct XfmKernelParams +{ + bool isPoints; + Tensor points; + Tensor matrix; + Tensor out; + dim3 gridSize; +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/normal.cu b/nvdiffrec/lib/render/renderutils/c_src/normal.cu new file mode 100644 index 0000000..a50e49e --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/normal.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#include "common.h" +#include "normal.h" + +#define NORMAL_THRESHOLD 0.1f + +//------------------------------------------------------------------------ +// Perturb shading normal by tangent frame + +__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl) +{ + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + return safeNormalize(_shading_nrm); +} + +__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); + vec3f smooth_bitng = safeNormalize(_smooth_bitng); + vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + vec3f d_shading_nrm(0); + bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out); + + vec3f d_smooth_bitng(0); + + if (perturbed_nrm.z > 0.0f) + { + d_smooth_nrm += d_shading_nrm * perturbed_nrm.z; + d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm); + } + + d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y; + d_perturbed_nrm.y += (opengl ? 
-1 : 1) * sum(d_shading_nrm * smooth_bitng); + + d_smooth_tng += d_shading_nrm * perturbed_nrm.x; + d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng); + + vec3f d__smooth_bitng(0); + bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng); + + bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng); +} + +//------------------------------------------------------------------------ +#define bent_nrm_eps 0.001f + +__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm) +{ + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + return geom_nrm * (1.0f - t) + smooth_nrm * t; +} + +__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out) +{ + //////////////////////////////////////////////////////////////////////// + // FWD + float dp = dot(view_vec, smooth_nrm); + float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); + + //////////////////////////////////////////////////////////////////////// + // BWD + if (dp > NORMAL_THRESHOLD) + d_smooth_nrm += d_out; + else + { + // geom_nrm * (1.0f - t) + smooth_nrm * t; + d_geom_nrm += d_out * (1.0f - t); + d_smooth_nrm += d_out * t; + float d_t = sum(d_out * (smooth_nrm - geom_nrm)); + + float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD; + + bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp); + } +} + +//------------------------------------------------------------------------ +// Kernels + +__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. + unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f view_vec = safeNormalize(view_pos - pos); + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + vec3f res; + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm); + else + res = fwdBendNormal(view_vec, shading_nrm, geom_nrm); + + p.out.store(px, py, pz, res); +} + +__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) +{ + // Calculate pixel position. 
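// Note: like the other backward kernels in this diff, this kernel re-runs the cheap
// forward computations (safeNormalize, fwdPerturbNormal, fwdBendNormal) instead of
// reading saved intermediates, then walks the chain rule through them in reverse.
// Because view_vec = view_pos - pos, the normalized-view-vector gradient computed
// below is written positively into view_pos and negated into pos at the end.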
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int pz = blockIdx.z; + if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) + return; + + vec3f pos = p.pos.fetch3(px, py, pz); + vec3f view_pos = p.view_pos.fetch3(px, py, pz); + vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); + vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); + vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); + vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); + vec3f d_out = p.out.fetch3(px, py, pz); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // FWD + + vec3f smooth_nrm = safeNormalize(_smooth_nrm); + vec3f smooth_tng = safeNormalize(_smooth_tng); + vec3f _view_vec = view_pos - pos; + vec3f view_vec = safeNormalize(view_pos - pos); + + vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); + + /////////////////////////////////////////////////////////////////////////////////////////////////// + // BWD + + vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0); + if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) + { + bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + d_shading_nrm = -d_shading_nrm; + d_geom_nrm = -d_geom_nrm; + } + else + bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); + + vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0); + bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl); + + vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0); + bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec); + bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm); + bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng); + + p.pos.store_grad(px, py, pz, -d__view_vec); + p.view_pos.store_grad(px, py, pz, d__view_vec); + p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm); + p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm); + p.smooth_tng.store_grad(px, py, pz, d__smooth_tng); + p.geom_nrm.store_grad(px, py, pz, d_geom_nrm); +} \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/normal.h b/nvdiffrec/lib/render/renderutils/c_src/normal.h new file mode 100644 index 0000000..8882c22 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/normal.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
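fwdPerturbNormal above is the standard tangent-space (TBN) normal-map transform: the fetched normal-map vector is rotated into world space by the basis built from the smooth tangent, the derived bitangent, and the smooth normal, with the bitangent sign flipped for OpenGL-style maps and the z component clamped so the perturbed normal cannot dip below the surface. A hedged restatement of that forward path, reusing the vec3f helpers added later in this diff (perturbViaTBN is an illustrative name, not the code path the kernels call):

    __device__ vec3f perturbViaTBN(vec3f n_ts, vec3f N, vec3f T, bool opengl)
    {
        vec3f B = safeNormalize(cross(T, N));             // bitangent from tangent and normal
        float sgn = opengl ? -1.0f : 1.0f;                // OpenGL-style normal maps flip Y
        vec3f n = T * n_ts.x + B * (sgn * n_ts.y) + N * fmaxf(n_ts.z, 0.0f);
        return safeNormalize(n);                          // same result as fwdPerturbNormal
    }

fwdBendNormal then blends this shading normal toward the geometric normal whenever the view direction grazes the surface (dot below NORMAL_THRESHOLD = 0.1), and the two-sided branch in the kernels negates both normals first so back-facing geometry shades consistently.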
+ */ + +#pragma once + +#include "common.h" + +struct PrepareShadingNormalKernelParams +{ + Tensor pos; + Tensor view_pos; + Tensor perturbed_nrm; + Tensor smooth_nrm; + Tensor smooth_tng; + Tensor geom_nrm; + Tensor out; + dim3 gridSize; + bool two_sided_shading, opengl; +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/tensor.h b/nvdiffrec/lib/render/renderutils/c_src/tensor.h new file mode 100644 index 0000000..1dfb4e8 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/tensor.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once +#if defined(__CUDACC__) && defined(BFLOAT16) +#include // bfloat16 is float32 compatible with less mantissa bits +#endif + +//--------------------------------------------------------------------------------- +// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16 + +struct Tensor +{ + void* val; + void* d_val; + int dims[4], _dims[4]; + int strides[4]; + bool fp16; + +#if defined(__CUDA__) && !defined(__CUDA_ARCH__) + Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {} +#endif + +#ifdef __CUDACC__ + // Helpers to index and read/write a single element + __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; } + __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); } + __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; } +#ifdef BFLOAT16 + __device__ inline float fetch(unsigned int idx) const { return fp16 ? 
__bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; } +#else + __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; } + __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; } + __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; } +#endif + + ////////////////////////////////////////////////////////////////////////////////////////// + // Fetch, use broadcasting for tensor dimensions of size 1 + __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const + { + return fetch(nhwcIndex(z, y, x, 0)); + } + + __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const + { + return vec3f( + fetch(nhwcIndex(z, y, x, 0)), + fetch(nhwcIndex(z, y, x, 1)), + fetch(nhwcIndex(z, y, x, 2)) + ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store(_nhwcIndex(z, y, x, 0), _val); + } + + __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store(_nhwcIndex(z, y, x, 0), _val.x); + store(_nhwcIndex(z, y, x, 1), _val.y); + store(_nhwcIndex(z, y, x, 2), _val.z); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val); + } + + __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val) + { + store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x); + store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y); + store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z); + } +#endif + +}; diff --git a/nvdiffrec/lib/render/renderutils/c_src/torch_bindings.cpp b/nvdiffrec/lib/render/renderutils/c_src/torch_bindings.cpp new file mode 100644 index 0000000..64c9e70 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/torch_bindings.cpp @@ -0,0 +1,1062 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
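The Tensor struct above implements broadcasting purely through indexing: in nhwcIndex, any dimension of size 1 contributes nothing to the offset, so a tensor such as a single constant colour of shape [1, 1, 1, 3] is read from the same memory location for every output pixel, while store/store_grad always address a full-resolution buffer (the reduction over broadcast dimensions is left to torch, as the store_grad comment notes). A small standalone host-side illustration of that index rule (plain C++, not part of the extension):

    #include <cstdio>

    // Mirrors Tensor::nhwcIndex: dimensions of size 1 are broadcast by dropping their stride.
    static int nhwcIndex(const int dims[4], const int strides[4], int n, int h, int w, int c)
    {
        return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) +
               (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]);
    }

    int main()
    {
        int dims[4]    = { 1, 1, 1, 3 };  // e.g. one constant RGB value broadcast over an image
        int strides[4] = { 3, 3, 3, 1 };  // contiguous strides for that shape
        // Both pixels read channel 1 from the same element: prints "1 1".
        printf("%d %d\n", nhwcIndex(dims, strides, 0, 0, 0, 1), nhwcIndex(dims, strides, 4, 17, 9, 1));
        return 0;
    }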
+ */ + +#ifdef _MSC_VER +#pragma warning(push, 0) +#include +#pragma warning(pop) +#else +#include +#endif + +#include +#include +#include +#include + +#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); } +#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } +#define CHECK_TENSOR(X, DIMS, CHANNELS) \ + TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \ + TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \ + TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \ + TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels") + +#include "common.h" +#include "loss.h" +#include "normal.h" +#include "cubemap.h" +#include "bsdf.h" +#include "mesh.h" + +#define BLOCK_X 8 +#define BLOCK_Y 8 + +//------------------------------------------------------------------------ +// mesh.cu + +void xfmPointsFwdKernel(XfmKernelParams p); +void xfmPointsBwdKernel(XfmKernelParams p); + +//------------------------------------------------------------------------ +// loss.cu + +void imgLossFwdKernel(LossKernelParams p); +void imgLossBwdKernel(LossKernelParams p); + +//------------------------------------------------------------------------ +// normal.cu + +void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p); +void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p); + +//------------------------------------------------------------------------ +// cubemap.cu + +void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p); +void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p); +void SpecularBoundsKernel(SpecularBoundsKernelParams p); +void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p); +void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p); + +//------------------------------------------------------------------------ +// bsdf.cu + +void LambertFwdKernel(LambertKernelParams p); +void LambertBwdKernel(LambertKernelParams p); + +void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p); +void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p); + +void FresnelShlickFwdKernel(FresnelShlickKernelParams p); +void FresnelShlickBwdKernel(FresnelShlickKernelParams p); + +void ndfGGXFwdKernel(NdfGGXParams p); +void ndfGGXBwdKernel(NdfGGXParams p); + +void lambdaGGXFwdKernel(NdfGGXParams p); +void lambdaGGXBwdKernel(NdfGGXParams p); + +void maskingSmithFwdKernel(MaskingSmithParams p); +void maskingSmithBwdKernel(MaskingSmithParams p); + +void pbrSpecularFwdKernel(PbrSpecular p); +void pbrSpecularBwdKernel(PbrSpecular p); + +void pbrBSDFFwdKernel(PbrBSDF p); +void pbrBSDFBwdKernel(PbrBSDF p); + +//------------------------------------------------------------------------ +// Tensor helpers + +void update_grid(dim3 &gridSize, torch::Tensor x) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); +} + +template +void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... 
vs) +{ + gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); + gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); + gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); + update_grid(gridSize, std::forward(vs)...); +} + +Tensor make_cuda_tensor(torch::Tensor val) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + return res; +} + +Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr) +{ + Tensor res; + for (int i = 0; i < val.dim(); ++i) + { + res.dims[i] = val.size(i); + res.strides[i] = val.stride(i); + } + if (val.dim() == 4) + res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3); + else + res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out + + res.fp16 = val.scalar_type() == torch::kBFloat16; + res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); + res.d_val = nullptr; + if (grad != nullptr) + { + if (val.dim() == 4) + *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + else // 3 + *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); + + res.d_val = res.fp16 ? (void*)grad->data_ptr() : (void*)grad->data_ptr(); + } + return res; +} + +//------------------------------------------------------------------------ +// prepare_shading_normal + +torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16) +{ + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(perturbed_nrm, 4, 3); + CHECK_TENSOR(smooth_nrm, 4, 3); + CHECK_TENSOR(smooth_tng, 4, 3); + CHECK_TENSOR(geom_nrm, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + p.out.fp16 = fp16; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.pos = make_cuda_tensor(pos, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. 
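// Every wrapper in this file launches its kernel the same way: the single kernel
// parameter struct is passed by address through cudaLaunchKernel on the current
// torch CUDA stream. This avoids the <<<...>>> launch syntax, presumably so this
// bindings file can be compiled as ordinary C++ while the kernels themselves stay
// in the .cu files declared above.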
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PrepareShadingNormalKernelParams p; + p.two_sided_shading = two_sided_shading; + p.opengl = opengl; + update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad; + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad); + p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad); + p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad); + p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad); +} + +//------------------------------------------------------------------------ +// lambert + +torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LambertKernelParams p; + update_grid(p.gridSize, nrm, wi); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
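// Backward wrappers follow a shared convention: passing the address of an empty
// torch::Tensor to make_cuda_tensor makes it allocate a full-resolution gradient
// buffer and wire it to d_val, p.out receives the incoming gradient from autograd,
// and the freshly written gradients are returned as a tuple in input order.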
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad); +} + +//------------------------------------------------------------------------ +// frostbite diffuse + +torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16) +{ + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(linearRoughness, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FrostbiteDiffuseKernelParams p; + update_grid(p.gridSize, nrm, wi, wo, linearRoughness); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad; + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(nrm_grad, wi_grad, wo_grad, linearRoughness_grad); +} + +//------------------------------------------------------------------------ +// fresnel_shlick + +torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(f0, 4, 3); + CHECK_TENSOR(f90, 4, 3); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. 
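// getLaunchBlockSize and getLaunchGridSize come from common.h and are not part of
// this diff; judging by their use they presumably clamp the BLOCK_X x BLOCK_Y (8x8)
// tile to the problem size and ceil-divide the broadcast gridSize by the block so
// every pixel is covered, with blockIdx.z indexing the batch dimension.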
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.f0 = make_cuda_tensor(f0, p.gridSize); + p.f90 = make_cuda_tensor(f90, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + FresnelShlickKernelParams p; + update_grid(p.gridSize, f0, f90, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor f0_grad, f90_grad, cosT_grad; + p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad); + p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(f0_grad, f90_grad, cosT_grad); +} + +//------------------------------------------------------------------------ +// ndf_ggd + +torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
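// Note the asymmetry with the forward wrappers: the *_bwd entry points skip the
// CHECK_TENSOR shape/dtype validation, apparently trusting that the tensors mirror
// the ones already checked in the corresponding forward call and that autograd
// supplies a matching incoming gradient.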
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// lambda_ggx + +torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosTheta, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + NdfGGXParams p; + update_grid(p.gridSize, alphaSqr, cosTheta); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosTheta_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosTheta_grad); +} + +//------------------------------------------------------------------------ +// masking_smith + +torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16) +{ + CHECK_TENSOR(alphaSqr, 4, 1); + CHECK_TENSOR(cosThetaI, 4, 1); + CHECK_TENSOR(cosThetaO, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + p.out.fp16 = fp16; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. 
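// The output resolution used here comes from update_grid, which takes the
// per-dimension maximum over all inputs; combined with the broadcast indexing in
// tensor.h this lets callers mix full-resolution tensors with [1, 1, 1, C]-style
// constants in a single op.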
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + MaskingSmithParams p; + update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad; + p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); + p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad); + p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad); +} + +//------------------------------------------------------------------------ +// pbr_specular + +torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16) +{ + CHECK_TENSOR(col, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(wo, 4, 3); + CHECK_TENSOR(wi, 4, 3); + CHECK_TENSOR(alpha, 4, 1); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.col = make_cuda_tensor(col, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.wo = make_cuda_tensor(wo, p.gridSize); + p.wi = make_cuda_tensor(wi, p.gridSize); + p.alpha = make_cuda_tensor(alpha, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrSpecular p; + update_grid(p.gridSize, col, nrm, wo, wi, alpha); + p.min_roughness = min_roughness; + + // Choose launch parameters. 
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad; + p.col = make_cuda_tensor(col, p.gridSize, &col_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); + p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); + p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad); +} + +//------------------------------------------------------------------------ +// pbr_bsdf + +torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16) +{ + CHECK_TENSOR(kd, 4, 3); + CHECK_TENSOR(arm, 4, 3); + CHECK_TENSOR(pos, 4, 3); + CHECK_TENSOR(nrm, 4, 3); + CHECK_TENSOR(view_pos, 4, 3); + CHECK_TENSOR(light_pos, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + p.out.fp16 = fp16; + p.min_roughness = min_roughness; + p.BSDF = BSDF; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + p.kd = make_cuda_tensor(kd, p.gridSize); + p.arm = make_cuda_tensor(arm, p.gridSize); + p.pos = make_cuda_tensor(pos, p.gridSize); + p.nrm = make_cuda_tensor(nrm, p.gridSize); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + PbrBSDF p; + update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); + p.min_roughness = min_roughness; + p.BSDF = BSDF; + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad; + p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad); + p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad); + p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); + p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); + p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); + p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. 
+ void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad); +} + +//------------------------------------------------------------------------ +// filter_cubemap + +torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap) +{ + CHECK_TENSOR(cubemap, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(grad, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + DiffuseCubemapKernelParams p; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +torch::Tensor specular_bounds(int resolution, float costheta_cutoff) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularBoundsKernelParams p; + p.costheta_cutoff = costheta_cutoff; + p.gridSize = dim3(resolution, resolution, 6); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. 
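// The cubemap backward passes (diffuse above, specular below) deviate from
// make_cuda_tensor's usual empty-buffer gradient path: they allocate the cubemap
// gradient with torch::zeros and point d_val at it directly, presumably because
// the backward kernels accumulate contributions from many output texels into each
// source texel rather than writing every gradient element exactly once.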
+ SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff) +{ + CHECK_TENSOR(cubemap, 4, 3); + CHECK_TENSOR(bounds, 4, 6*4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + SpecularCubemapKernelParams p; + p.roughness = roughness; + p.costheta_cutoff = costheta_cutoff; + update_grid(p.gridSize, cubemap); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Setup tensors + torch::Tensor cubemap_grad; + p.cubemap = make_cuda_tensor(cubemap, p.gridSize); + p.bounds = make_cuda_tensor(bounds, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); + + return cubemap_grad; +} + +//------------------------------------------------------------------------ +// loss function + +LossType strToLoss(std::string str) +{ + if (str == "mse") + return LOSS_MSE; + else if (str == "relmse") + return LOSS_RELMSE; + else if (str == "smape") + return LOSS_SMAPE; + else + return LOSS_L1; +} + +torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16) +{ + CHECK_TENSOR(img, 4, 3); + CHECK_TENSOR(target, 4, 3); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.out.fp16 = fp16; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts); + + p.img = make_cuda_tensor(img, p.gridSize); + p.target = make_cuda_tensor(target, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +std::tuple image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + LossKernelParams p; + p.loss = strToLoss(loss); + p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; + update_grid(p.gridSize, img, target); + + // Choose launch parameters. + dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor img_grad, target_grad; + p.img = make_cuda_tensor(img, p.gridSize, &img_grad); + p.target = make_cuda_tensor(target, p.gridSize, &target_grad); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream)); + + return std::tuple(img_grad, target_grad); +} + +//------------------------------------------------------------------------ +// transform function + +torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16) +{ + CHECK_TENSOR(points, 3, 3); + CHECK_TENSOR(matrix, 3, 4); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.out.fp16 = fp16; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. + dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + // Allocate output tensors. + torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); + torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts); + + p.points = make_cuda_tensor(points, p.gridSize); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(out, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream)); + + return out; +} + +torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints) +{ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Extract input parameters. + XfmKernelParams p; + p.isPoints = isPoints; + p.gridSize.x = points.size(1); + p.gridSize.y = 1; + p.gridSize.z = std::max(matrix.size(0), points.size(0)); + + // Choose launch parameters. 
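// The transform ops use a 1D block of BLOCK_X * BLOCK_Y = 64 threads over the point
// index, with blockIdx.z covering the batch; this also guarantees the at-least-16
// threads per block that xfmPointsFwdKernel/xfmPointsBwdKernel need to cooperatively
// stage the 4x4 matrix into shared memory.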
+ dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); + dim3 warpSize = getWarpSize(blockSize); + dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); + + torch::Tensor points_grad; + p.points = make_cuda_tensor(points, p.gridSize, &points_grad); + p.matrix = make_cuda_tensor(matrix, p.gridSize); + p.out = make_cuda_tensor(grad, p.gridSize); + + // Launch CUDA kernel. + void* args[] = { &p }; + NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream)); + + return points_grad; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd"); + m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd"); + m.def("lambert_fwd", &lambert_fwd, "lambert_fwd"); + m.def("lambert_bwd", &lambert_bwd, "lambert_bwd"); + m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd"); + m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd"); + m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd"); + m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd"); + m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd"); + m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd"); + m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd"); + m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd"); + m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd"); + m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd"); + m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd"); + m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd"); + m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd"); + m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd"); + m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd"); + m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd"); + m.def("specular_bounds", &specular_bounds, "specular_bounds"); + m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd"); + m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd"); + m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd"); + m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd"); + m.def("xfm_fwd", &xfm_fwd, "xfm_fwd"); + m.def("xfm_bwd", &xfm_bwd, "xfm_bwd"); +} \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/vec3f.h b/nvdiffrec/lib/render/renderutils/c_src/vec3f.h new file mode 100644 index 0000000..7e67454 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/vec3f.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +struct vec3f +{ + float x, y, z; + +#ifdef __CUDACC__ + __device__ vec3f() { } + __device__ vec3f(float v) { x = v; y = v; z = v; } + __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; } + __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; } + + __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; } + __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; } + __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; } + __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; } +#endif +}; + +#ifdef __CUDACC__ +__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); } +__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); } +__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); } +__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); } +__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); } + +__device__ static inline float sum(vec3f a) +{ + return a.x + a.y + a.z; +} + +__device__ static inline vec3f cross(vec3f a, vec3f b) +{ + vec3f out; + out.x = a.y * b.z - a.z * b.y; + out.y = a.z * b.x - a.x * b.z; + out.z = a.x * b.y - a.y * b.x; + return out; +} + +__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out) +{ + d_a.x += d_out.z * b.y - d_out.y * b.z; + d_a.y += d_out.x * b.z - d_out.z * b.x; + d_a.z += d_out.y * b.x - d_out.x * b.y; + + d_b.x += d_out.y * a.z - d_out.z * a.y; + d_b.y += d_out.z * a.x - d_out.x * a.z; + d_b.z += d_out.x * a.y - d_out.y * a.x; +} + +__device__ static inline float dot(vec3f a, vec3f b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out) +{ + d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z; + d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z; +} + +__device__ static inline vec3f reflect(vec3f x, vec3f n) +{ + return n * 2.0f * dot(n, x) - x; +} + +__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out) +{ + d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z); + d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z); + d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1); + + d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x); + d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y); + d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z)); +} + +__device__ static inline vec3f safeNormalize(vec3f v) +{ + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + return l > 0.0f ? 
(v / l) : vec3f(0.0f); +} + +__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out) +{ + + float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); + if (l > 0.0f) + { + float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f); + d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac; + d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac; + d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac; + } +} + +#endif \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/c_src/vec4f.h b/nvdiffrec/lib/render/renderutils/c_src/vec4f.h new file mode 100644 index 0000000..e3f3077 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/c_src/vec4f.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. + */ + +#pragma once + +struct vec4f +{ + float x, y, z, w; + +#ifdef __CUDACC__ + __device__ vec4f() { } + __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } + __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } + __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } +#endif +}; + diff --git a/nvdiffrec/lib/render/renderutils/loss.py b/nvdiffrec/lib/render/renderutils/loss.py new file mode 100644 index 0000000..92a24c0 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/loss.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
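
Before moving on to the Python sources, a quick PyTorch cross-check of the safeNormalize backward formula in vec3f.h above. This is an editor's illustrative sketch, not part of the patch, and it only covers the non-degenerate case ||v|| > 0.

    import torch

    v = torch.randn(3, dtype=torch.float64, requires_grad=True)
    d_out = torch.randn(3, dtype=torch.float64)

    # Autograd gradient of v / ||v|| for an incoming gradient d_out.
    (v / v.norm()).backward(d_out)

    # Closed form used by bwdSafeNormalize: J = (I * |v|^2 - v v^T) / |v|^3.
    vd = v.detach()
    l2 = vd.dot(vd)
    J = (torch.eye(3, dtype=torch.float64) * l2 - torch.outer(vd, vd)) / l2.pow(1.5)
    assert torch.allclose(v.grad, J @ d_out)
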
+ +import torch + +#---------------------------------------------------------------------------- +# HDR image losses +#---------------------------------------------------------------------------- + +def _tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def _SMAPE(img, target, eps=0.01): + nom = torch.abs(img - target) + denom = torch.abs(img) + torch.abs(target) + 0.01 + return torch.mean(nom / denom) + +def _RELMSE(img, target, eps=0.1): + nom = (img - target) * (img - target) + denom = img * img + target * target + 0.1 + return torch.mean(nom / denom) + +def image_loss_fn(img, target, loss, tonemapper): + if tonemapper == 'log_srgb': + img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) + target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) + + if loss == 'mse': + return torch.nn.functional.mse_loss(img, target) + elif loss == 'smape': + return _SMAPE(img, target) + elif loss == 'relmse': + return _RELMSE(img, target) + else: + return torch.nn.functional.l1_loss(img, target) diff --git a/nvdiffrec/lib/render/renderutils/ops.py b/nvdiffrec/lib/render/renderutils/ops.py new file mode 100644 index 0000000..bb74516 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/ops.py @@ -0,0 +1,556 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import numpy as np +import os +import sys +import torch +import torch.utils.cpp_extension + +from .bsdf import * +from .loss import * + +#---------------------------------------------------------------------------- +# C++/Cuda plugin compiler/loader. + +_cached_plugin = None +def _get_plugin(): + # Return cached plugin if already loaded. + global _cached_plugin + if _cached_plugin is not None: + return _cached_plugin + + # Make sure we can find the necessary compiler and libary binaries. + if os.name == 'nt': + def find_cl_path(): + import glob + for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: + paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ['PATH'] += ';' + cl_path + + # Compiler options. + opts = ['-DNVDR_TORCH'] + + # Linker options. + if os.name == 'posix': + ldflags = ['-lcuda', '-lnvrtc'] + elif os.name == 'nt': + ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] + + # List of sources. + source_files = [ + 'c_src/mesh.cu', + 'c_src/loss.cu', + 'c_src/bsdf.cu', + 'c_src/normal.cu', + 'c_src/cubemap.cu', + 'c_src/common.cpp', + 'c_src/torch_bindings.cpp' + ] + + # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. + os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Try to detect if a stray lock file is left in cache directory and show a warning. 
This sometimes happens on Windows if the build is interrupted at just the right moment. + try: + lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock') + if os.path.exists(lock_fn): + print("Warning: Lock file exists in build directory: '%s'" % lock_fn) + except: + pass + + # Compile and load. + source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] + torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts, + extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True, + # build_directory="PLACEHOLDER", + ) + + # Import, cache, and return the compiled module. + import renderutils_plugin + _cached_plugin = renderutils_plugin + return _cached_plugin + +#---------------------------------------------------------------------------- +# Internal kernels, just used for testing functionality + +class _fresnel_shlick_func(torch.autograd.Function): + @staticmethod + def forward(ctx, f0, f90, cosTheta): + out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False) + ctx.save_for_backward(f0, f90, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + f0, f90, cosTheta = ctx.saved_variables + return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,) + +def _fresnel_shlick(f0, f90, cosTheta, use_python=False): + if use_python: + out = bsdf_fresnel_shlick(f0, f90, cosTheta) + else: + out = _fresnel_shlick_func.apply(f0, f90, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN" + return out + + +class _ndf_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _ndf_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_ndf_ggx(alphaSqr, cosTheta) + else: + out = _ndf_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN" + return out + +class _lambda_ggx_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosTheta): + out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False) + ctx.save_for_backward(alphaSqr, cosTheta) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosTheta = ctx.saved_variables + return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) + +def _lambda_ggx(alphaSqr, cosTheta, use_python=False): + if use_python: + out = bsdf_lambda_ggx(alphaSqr, cosTheta) + else: + out = _lambda_ggx_func.apply(alphaSqr, cosTheta) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN" + return out + +class _masking_smith_func(torch.autograd.Function): + @staticmethod + def forward(ctx, alphaSqr, cosThetaI, cosThetaO): + ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO) + out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False) + return out + + @staticmethod + def backward(ctx, dout): + alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables + return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,) + +def _masking_smith(alphaSqr, cosThetaI, cosThetaO, 
use_python=False):
+    if use_python:
+        out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
+    else:
+        out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)
+
+    if torch.is_anomaly_enabled():
+        assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
+    return out
+
+#----------------------------------------------------------------------------
+# Shading normal setup (bump mapping + bent normals)
+
+class _prepare_shading_normal_func(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
+        ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
+        out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
+        ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables
+        return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
+
+def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
+    '''Takes care of all corner cases and produces a final normal used for shading:
+        - Constructs tangent space
+        - Flips normal direction based on geometric normal for two-sided shading
+        - Perturbs shading normal by normal map
+        - Bends backfacing normals towards the camera to avoid shading artifacts
+
+    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
+
+    Args:
+        pos: World space g-buffer position.
+        view_pos: Camera position in world space (typically using broadcasting).
+        perturbed_nrm: Tangent-space normal perturbation from normal map lookup.
+        smooth_nrm: Interpolated vertex normals.
+        smooth_tng: Interpolated vertex tangents.
+        geom_nrm: Geometric (face) normals.
+        two_sided_shading: Use one/two-sided shading
+        opengl: Use OpenGL/DirectX normal map conventions
+        use_python: Use PyTorch implementation (for validation)
+    Returns:
+        Final shading normal
+    '''
+
+    if perturbed_nrm is None:
+        perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]
+
+    if use_python:
+        out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
+    else:
+        out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
+
+    if torch.is_anomaly_enabled():
+        assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
+    return out
+
+#----------------------------------------------------------------------------
+# BSDF functions
+
+class _lambert_func(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, nrm, wi):
+        out = _get_plugin().lambert_fwd(nrm, wi, False)
+        ctx.save_for_backward(nrm, wi)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        nrm, wi = ctx.saved_variables
+        return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)
+
+def lambert(nrm, wi, use_python=False):
+    '''Lambertian bsdf.
+    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
+
+    Args:
+        nrm: World space shading normal.
+        wi: World space light vector.
+        use_python: Use PyTorch implementation (for validation)
+
+    Returns:
+        Shaded diffuse value with shape [minibatch_size, height, width, 1]
+    '''
+
+    if use_python:
+        out = bsdf_lambert(nrm, wi)
+    else:
+        out = _lambert_func.apply(nrm, wi)
+
+    if torch.is_anomaly_enabled():
+        assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
+    return out
+
+class _frostbite_diffuse_func(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, nrm, wi, wo, linearRoughness):
+        out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)
+        ctx.save_for_backward(nrm, wi, wo, linearRoughness)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        nrm, wi, wo, linearRoughness = ctx.saved_variables
+        return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)
+
+def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
+    '''Frostbite, normalized Disney Diffuse bsdf.
+    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
+
+    Args:
+        nrm: World space shading normal.
+        wi: World space light vector.
+        wo: World space camera vector.
+        linearRoughness: Material roughness
+        use_python: Use PyTorch implementation (for validation)
+
+    Returns:
+        Shaded diffuse value with shape [minibatch_size, height, width, 1]
+    '''
+
+    if use_python:
+        out = bsdf_frostbite(nrm, wi, wo, linearRoughness)
+    else:
+        out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)
+
+    if torch.is_anomaly_enabled():
+        assert torch.all(torch.isfinite(out)), "Output of frostbite_diffuse contains inf or NaN"
+    return out
+
+class _pbr_specular_func(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
+        ctx.save_for_backward(col, nrm, wo, wi, alpha)
+        ctx.min_roughness = min_roughness
+        out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        col, nrm, wo, wi, alpha = ctx.saved_variables
+        return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
+
+def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
+    '''Physically-based specular bsdf.
+    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
+
+    Args:
+        col: Specular lobe color
+        nrm: World space shading normal.
+        wo: World space camera vector.
+ wi: World space light vector + alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1] + min_roughness: Scalar roughness clamping threshold + + use_python: Use PyTorch implementation (for validation) + Returns: + Shaded specular color + ''' + + if use_python: + out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness) + else: + out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN" + return out + +class _pbr_bsdf_func(torch.autograd.Function): + @staticmethod + def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): + ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos) + ctx.min_roughness = min_roughness + ctx.BSDF = BSDF + out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False) + return out + + @staticmethod + def backward(ctx, dout): + kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables + return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None) + +def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False): + '''Physically-based bsdf, both diffuse & specular lobes + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + kd: Diffuse albedo. + arm: Specular parameters (attenuation, linear roughness, metalness). + pos: World space position. + nrm: World space shading normal. + view_pos: Camera position in world space, typically using broadcasting. + light_pos: Light position in world space, typically using broadcasting. + min_roughness: Scalar roughness clamping threshold + bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite' + + use_python: Use PyTorch implementation (for validation) + + Returns: + Shaded color. 
+ ''' + + BSDF = 0 + if bsdf == 'frostbite': + BSDF = 1 + + if use_python: + out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + else: + out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# cubemap filter with filtering across edges + +class _diffuse_cubemap_func(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap): + out = _get_plugin().diffuse_cubemap_fwd(cubemap) + ctx.save_for_backward(cubemap) + return out + + @staticmethod + def backward(ctx, dout): + cubemap, = ctx.saved_variables + cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout) + return cubemap_grad, None + +def diffuse_cubemap(cubemap, use_python=False): + if use_python: + assert False + else: + out = _diffuse_cubemap_func.apply(cubemap) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN" + return out + +class _specular_cubemap(torch.autograd.Function): + @staticmethod + def forward(ctx, cubemap, roughness, costheta_cutoff, bounds): + out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff) + ctx.save_for_backward(cubemap, bounds) + ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff + return out + + @staticmethod + def backward(ctx, dout): + cubemap, bounds = ctx.saved_variables + cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff) + return cubemap_grad, None, None, None + +# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy +def __ndfBounds(res, roughness, cutoff): + def ndfGGX(alphaSqr, costheta): + costheta = np.clip(costheta, 0.0, 1.0) + d = (costheta * alphaSqr - costheta) * costheta + 1.0 + return alphaSqr / (d * d * np.pi) + + # Sample out cutoff angle + nSamples = 1000000 + costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples)) + D = np.cumsum(ndfGGX(roughness**4, costheta)) + idx = np.argmax(D >= D[..., -1] * cutoff) + + # Brute force compute lookup table with bounds + bounds = _get_plugin().specular_bounds(res, costheta[idx]) + + return costheta[idx], bounds +__ndfBoundsDict = {} + +def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False): + assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape) + + if use_python: + assert False + else: + key = (cubemap.shape[1], roughness, cutoff) + if key not in __ndfBoundsDict: + __ndfBoundsDict[key] = __ndfBounds(*key) + out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key]) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN" + return out[..., 0:3] / out[..., 3:] + +#---------------------------------------------------------------------------- +# Fast image loss function + +class _image_loss_func(torch.autograd.Function): + @staticmethod + def forward(ctx, img, target, loss, tonemapper): + ctx.loss, ctx.tonemapper = loss, tonemapper + ctx.save_for_backward(img, target) + out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False) + return out + + @staticmethod + def backward(ctx, dout): + img, target = ctx.saved_variables + return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + 
(None, None, None) + +def image_loss(img, target, loss='l1', tonemapper='none', use_python=False): + '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf. + All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. + + Args: + img: Input image. + target: Target (reference) image. + loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse'] + tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb'] + use_python: Use PyTorch implementation (for validation) + + Returns: + Image space loss (scalar value). + ''' + if use_python: + out = image_loss_fn(img, target, loss, tonemapper) + else: + out = _image_loss_func.apply(img, target, loss, tonemapper) + out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2]) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN" + return out + +#---------------------------------------------------------------------------- +# Transform points function + +class _xfm_func(torch.autograd.Function): + @staticmethod + def forward(ctx, points, matrix, isPoints): + ctx.save_for_backward(points, matrix) + ctx.isPoints = isPoints + return _get_plugin().xfm_fwd(points, matrix, isPoints, False) + + @staticmethod + def backward(ctx, dout): + points, matrix = ctx.saved_variables + return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None) + +def xfm_points(points, matrix, use_python=False): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + if use_python: + out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + else: + out = _xfm_func.apply(points, matrix, True) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + +def xfm_vectors(vectors, matrix, use_python=False): + '''Transform vectors. + Args: + vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + + Returns: + Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. 
+ ''' + + if use_python: + out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous() + else: + out = _xfm_func.apply(vectors, matrix, False) + + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN" + return out + + + diff --git a/nvdiffrec/lib/render/renderutils/setup.py b/nvdiffrec/lib/render/renderutils/setup.py new file mode 100644 index 0000000..c698672 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/setup.py @@ -0,0 +1,61 @@ +from setuptools import setup +import torch +import os,glob +from torch.utils.cpp_extension import (CUDAExtension, CppExtension, BuildExtension) + +def get_extensions(): + extensions = [] + ext_name = 'nvdiffrec_renderutils' + # prevent ninja from using too many resources + os.environ.setdefault('MAX_JOBS', '16') + define_macros = [] + + # Compiler options. + opts = ['-DNVDR_TORCH'] + + # Linker options. + if os.name == 'posix': + ldflags = ['-lcuda', '-lnvrtc'] + elif os.name == 'nt': + ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] + + # List of sources. + source_files = [ + 'c_src/mesh.cu', + 'c_src/loss.cu', + 'c_src/bsdf.cu', + 'c_src/normal.cu', + 'c_src/cubemap.cu', + 'c_src/common.cpp', + 'c_src/torch_bindings.cpp' + ] + + os.environ['TORCH_CUDA_ARCH_LIST'] = "5.0 6.0 6.1 7.0 7.5 8.0 8.6" + + if torch.cuda.is_available(): + print(f'Compiling {ext_name} with CUDA') + define_macros += [('WITH_CUDA', None)] + # op_files = glob.glob('./c_src/*') + # extension = CUDAExtension + else: + raise NotImplementedError + + include_path = os.path.abspath('./c_src') + ext_ops = CUDAExtension( + name=ext_name, + sources=source_files, + include_dirs=[include_path], + define_macros=define_macros, + extra_compile_args=opts + ldflags, + libraries=['cuda', 'nvrtc'], + extra_cuda_cflags=opts, + extra_cflags=opts, + extra_ldflags=ldflags) + extensions.append(ext_ops) + return extensions + +setup( + name='nvdiffrec_renderutils', + ext_modules=get_extensions(), + cmdclass={'build_ext': BuildExtension}, + ) \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/tests/test_bsdf.py b/nvdiffrec/lib/render/renderutils/tests/test_bsdf.py new file mode 100644 index 0000000..b0b60c3 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/tests/test_bsdf.py @@ -0,0 +1,296 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
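
A side note on setup.py above: extra_cflags, extra_cuda_cflags and extra_ldflags are keyword arguments of torch.utils.cpp_extension.load() (the JIT path used in ops.py), not of CUDAExtension, so setuptools typically warns about and ignores them as unknown Extension options. A sketch of the conventional ahead-of-time spelling, reusing the variables already defined in setup.py (ext_name, source_files, include_path, define_macros, opts), might look like this; it is an editor's illustration, not part of the patch.

    ext_ops = CUDAExtension(
        name=ext_name,
        sources=source_files,
        include_dirs=[include_path],
        define_macros=define_macros,
        # Per-compiler flags; linking against cuda/nvrtc is already covered by `libraries`.
        extra_compile_args={'cxx': opts, 'nvcc': opts},
        libraries=['cuda', 'nvrtc'])
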
+ +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_normal(): + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True) + perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True) + smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True) + smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True) + geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" bent normal") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad) + relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad) + relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad) + relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad) + relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad) + +def test_schlick(): + f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f0_ref = f0_cuda.clone().detach().requires_grad_(True) + f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + f90_ref = f90_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Fresnel shlick") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("f0:", f0_ref.grad, 
f0_cuda.grad) + relative_loss("f90:", f90_ref.grad, f90_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_ndf_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Ndf GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_lambda_ggx(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 + cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) + cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambda GGX") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) + +def test_masking_smith(): + alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) + cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True) + cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Smith masking term") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) + relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad) + 
relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad) + +def test_lambert(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.lambert(normals_ref, wi_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.lambert(normals_cuda, wi_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Lambert") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + +def test_frostbite(): + normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + normals_ref = normals_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + rough_ref = rough_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') + + ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Frostbite") + print("-------------------------------------------------------------") + relative_loss("res:", ref, cuda) + relative_loss("nrm:", normals_ref.grad, normals_cuda.grad) + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + relative_loss("rough:", rough_ref.grad, rough_cuda.grad) + +def test_pbr_specular(): + col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + col_ref = col_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wi_ref = wi_cuda.clone().detach().requires_grad_(True) + wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + wo_ref = wo_cuda.clone().detach().requires_grad_(True) + alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) + alpha_ref = alpha_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda) + cuda_loss = torch.nn.MSELoss()(cuda, target) + 
cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr specular") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if col_ref.grad is not None: + relative_loss("col:", col_ref.grad, col_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if wi_ref.grad is not None: + relative_loss("wi:", wi_ref.grad, wi_cuda.grad) + if wo_ref.grad is not None: + relative_loss("wo:", wo_ref.grad, wo_cuda.grad) + if alpha_ref.grad is not None: + relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad) + +def test_pbr_bsdf(bsdf): + kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Pbr BSDF") + print("-------------------------------------------------------------") + + relative_loss("res:", ref, cuda) + if kd_ref.grad is not None: + relative_loss("kd:", kd_ref.grad, kd_cuda.grad) + if arm_ref.grad is not None: + relative_loss("arm:", arm_ref.grad, arm_cuda.grad) + if pos_ref.grad is not None: + relative_loss("pos:", pos_ref.grad, pos_cuda.grad) + if nrm_ref.grad is not None: + relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad) + if view_ref.grad is not None: + relative_loss("view:", view_ref.grad, view_cuda.grad) + if light_ref.grad is not None: + relative_loss("light:", light_ref.grad, light_cuda.grad) + +test_normal() + +test_schlick() +test_ndf_ggx() +test_lambda_ggx() +test_masking_smith() + +test_lambert() +test_frostbite() +test_pbr_specular() +test_pbr_bsdf('lambert') +test_pbr_bsdf('frostbite') diff --git a/nvdiffrec/lib/render/renderutils/tests/test_cubemap.py b/nvdiffrec/lib/render/renderutils/tests/test_cubemap.py new file mode 100644 index 0000000..a1ae0a2 --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/tests/test_cubemap.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 4 +DTYPE = torch.float32 + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_cubemap(): + cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True) + weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda') + target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda') + + ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True) + ref_loss = torch.nn.MSELoss()(ref, target) + ref_loss.backward() + + cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False) + cuda_loss = torch.nn.MSELoss()(cuda, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Cubemap:") + print("-------------------------------------------------------------") + + relative_loss("flt:", ref, cuda) + relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad) + + +test_cubemap() diff --git a/nvdiffrec/lib/render/renderutils/tests/test_loss.py b/nvdiffrec/lib/render/renderutils/tests/test_loss.py new file mode 100644 index 0000000..7a68f3f --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/tests/test_loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
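
The test above exercises ru.filter_cubemap, which is not among the ops bound in the ops.py added earlier in this patch (that file exposes diffuse_cubemap and specular_cubemap). A minimal smoke test of the exposed diffuse_cubemap op might look like the sketch below; it is illustrative only, assumes a CUDA device and a successfully compiled plugin, and checks gradient finiteness rather than a reference value because the python path for this op is not implemented.

    import torch
    import renderutils as ru

    cubemap = torch.rand(6, 16, 16, 3, dtype=torch.float32, device='cuda', requires_grad=True)
    out = ru.diffuse_cubemap(cubemap)
    out.sum().backward()
    assert torch.all(torch.isfinite(cubemap.grad))
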
+ +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +RES = 8 +DTYPE = torch.float32 + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) + +def test_loss(loss, tonemapper): + img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + img_ref = img_cuda.clone().detach().requires_grad_(True) + target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + target_ref = target_cuda.clone().detach().requires_grad_(True) + + ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) + ref_loss.backward() + + cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) + cuda_loss.backward() + + print("-------------------------------------------------------------") + print(" Loss: %s, %s" % (loss, tonemapper)) + print("-------------------------------------------------------------") + + relative_loss("res:", ref_loss, cuda_loss) + relative_loss("img:", img_ref.grad, img_cuda.grad) + relative_loss("target:", target_ref.grad, target_cuda.grad) + + +test_loss('l1', 'none') +test_loss('l1', 'log_srgb') +test_loss('mse', 'log_srgb') +test_loss('smape', 'none') +test_loss('relmse', 'none') +test_loss('mse', 'none') \ No newline at end of file diff --git a/nvdiffrec/lib/render/renderutils/tests/test_mesh.py b/nvdiffrec/lib/render/renderutils/tests/test_mesh.py new file mode 100644 index 0000000..4856c5c --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/tests/test_mesh.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
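
For intuition on the losses exercised above: the SMAPE variant in loss.py normalizes the absolute error by the magnitudes of both images, so a fixed absolute error on a bright pixel contributes far less than the same error on a dark pixel. A small worked example (editor's sketch, not part of the patch):

    import torch

    def smape(img, target, eps=0.01):
        return torch.mean(torch.abs(img - target) / (torch.abs(img) + torch.abs(target) + eps))

    smape(torch.tensor([0.1]), torch.tensor([0.2]))      # ~0.32
    smape(torch.tensor([100.0]), torch.tensor([100.1]))  # ~0.0005
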
+ +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +BATCH = 8 +RES = 1024 +DTYPE = torch.float32 + +torch.manual_seed(0) + +def tonemap_srgb(f): + return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) + +def l1(output, target): + x = torch.clamp(output, min=0, max=65535) + r = torch.clamp(target, min=0, max=65535) + x = tonemap_srgb(torch.log(x + 1)) + r = tonemap_srgb(torch.log(r + 1)) + return torch.nn.functional.l1_loss(x,r) + +def relative_loss(name, ref, cuda): + ref = ref.float() + cuda = cuda.float() + print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) + +def test_xfm_points(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target) + ref_loss.backward() + + cuda_out = ru.xfm_points(points_cuda, mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target) + cuda_loss.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + +def test_xfm_vectors(): + points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + points_ref = points_cuda.clone().detach().requires_grad_(True) + points_cuda_p = points_cuda.clone().detach().requires_grad_(True) + points_ref_p = points_cuda.clone().detach().requires_grad_(True) + mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) + mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) + + ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) + ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) + ref_loss.backward() + + cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) + cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) + cuda_loss.backward() + + ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) + ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) + ref_loss_p.backward() + + cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) + cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) + cuda_loss_p.backward() + + print("-------------------------------------------------------------") + + relative_loss("res:", ref_out, cuda_out) + relative_loss("points:", points_ref.grad, points_cuda.grad) + relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) + +test_xfm_points() +test_xfm_vectors() diff --git a/nvdiffrec/lib/render/renderutils/tests/test_perf.py b/nvdiffrec/lib/render/renderutils/tests/test_perf.py new file mode 100644 index 0000000..ffc143e --- /dev/null +++ b/nvdiffrec/lib/render/renderutils/tests/test_perf.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. 
Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import torch + +import os +import sys +sys.path.insert(0, os.path.join(sys.path[0], '../..')) +import renderutils as ru + +DTYPE=torch.float32 + +def test_bsdf(BATCH, RES, ITR): + kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + kd_ref = kd_cuda.clone().detach().requires_grad_(True) + arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + arm_ref = arm_cuda.clone().detach().requires_grad_(True) + pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + pos_ref = pos_cuda.clone().detach().requires_grad_(True) + nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) + view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + view_ref = view_cuda.clone().detach().requires_grad_(True) + light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) + light_ref = light_cuda.clone().detach().requires_grad_(True) + target = torch.rand(BATCH, RES, RES, 3, device='cuda') + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + + print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES)) + + start.record() + for i in range(ITR): + ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF python:", start.elapsed_time(end)) + + start.record() + for i in range(ITR): + cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) + end.record() + torch.cuda.synchronize() + print("Pbr BSDF cuda:", start.elapsed_time(end)) + +test_bsdf(1, 512, 1000) +test_bsdf(16, 512, 1000) +test_bsdf(1, 2048, 1000) diff --git a/nvdiffrec/lib/render/texture.py b/nvdiffrec/lib/render/texture.py new file mode 100644 index 0000000..6605a7d --- /dev/null +++ b/nvdiffrec/lib/render/texture.py @@ -0,0 +1,187 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import numpy as np +import torch +import nvdiffrast.torch as dr + +from . 
import util + +###################################################################################### +# Smooth pooling / mip computation with linear gradient upscaling +###################################################################################### + +class texture2d_mip(torch.autograd.Function): + @staticmethod + def forward(ctx, texture): + return util.avg_pool_nhwc(texture, (2,2)) + + @staticmethod + def backward(ctx, dout): + gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"), + torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"), + indexing='ij') + uv = torch.stack((gx, gy), dim=-1) + return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp') + +######################################################################################################## +# Simple texture class. A texture can be either +# - A 3D tensor (using auto mipmaps) +# - A list of 3D tensors (full custom mip hierarchy) +######################################################################################################## + +class Texture2D(torch.nn.Module): + # Initializes a texture from image data. + # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays) + def __init__(self, init, min_max=None, trainable=True): + super(Texture2D, self).__init__() + + if isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + elif isinstance(init, list) and len(init) == 1: + init = init[0] + + if isinstance(init, list): + self.data = list(torch.nn.Parameter(mip.clone().detach(), requires_grad=trainable) for mip in init) + elif len(init.shape) == 4: + self.data = torch.nn.Parameter(init.clone().detach(), requires_grad=trainable) + elif len(init.shape) == 3: + self.data = torch.nn.Parameter(init[None, ...].clone().detach(), requires_grad=trainable) + elif len(init.shape) == 1: + self.data = torch.nn.Parameter(init[None, None, None, :].clone().detach(), requires_grad=trainable) # Convert constant to 1x1 tensor + else: + assert False, "Invalid texture object" + + self.min_max = min_max + + # Filtered (trilinear) sample texture at a given location + def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'): + if isinstance(self.data, list): + out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode) + else: + if self.data.shape[1] > 1 and self.data.shape[2] > 1: + mips = [self.data] + while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1: + mips += [texture2d_mip.apply(mips[-1])] + out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode) + else: + out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode) + return out + + def getRes(self): + return self.getMips()[0].shape[1:3] + + def getChannels(self): + return self.getMips()[0].shape[3] + + def getMips(self): + if isinstance(self.data, list): + return self.data + else: + return [self.data] + + # In-place clamp with no derivative to make sure values are in valid range after training + def clamp_(self): + if self.min_max is not None: + for mip in self.getMips(): + for i in range(mip.shape[-1]): + mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i]) + + # In-place clamp with no derivative to make sure values are in valid range after training + def normalize_(self): + with torch.no_grad(): + for mip in self.getMips(): + mip = 
util.safe_normalize(mip) + +######################################################################################################## +# Helper function to create a trainable texture from a regular texture. The trainable weights are +# initialized with texture data as an initial guess +######################################################################################################## + +def create_trainable(init, res=None, auto_mipmaps=True, min_max=None): + with torch.no_grad(): + if isinstance(init, Texture2D): + assert isinstance(init.data, torch.Tensor) + min_max = init.min_max if min_max is None else min_max + init = init.data + elif isinstance(init, np.ndarray): + init = torch.tensor(init, dtype=torch.float32, device='cuda') + + # Pad to NHWC if needed + if len(init.shape) == 1: # Extend constant to NHWC tensor + init = init[None, None, None, :] + elif len(init.shape) == 3: + init = init[None, ...] + + # Scale input to desired resolution. + if res is not None: + init = util.scale_img_nhwc(init, res) + + # Genreate custom mipchain + if not auto_mipmaps: + mip_chain = [init.clone().detach().requires_grad_(True)] + while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1: + new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)] + mip_chain += [util.scale_img_nhwc(mip_chain[-1], new_size)] + return Texture2D(mip_chain, min_max=min_max) + else: + return Texture2D(init, min_max=min_max) + +######################################################################################################## +# Convert texture to and from SRGB +######################################################################################################## + +def srgb_to_rgb(texture): + return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips())) + +def rgb_to_srgb(texture): + return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips())) + +######################################################################################################## +# Utility functions for loading / storing a texture +######################################################################################################## + +def _load_mip2D(fn, lambda_fn=None, channels=None): + imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda') + if channels is not None: + imgdata = imgdata[..., 0:channels] + if lambda_fn is not None: + imgdata = lambda_fn(imgdata) + return imgdata.detach().clone() + +def load_texture2D(fn, lambda_fn=None, channels=None): + base, ext = os.path.splitext(fn) + # if os.path.exists(base + "_0" + ext): + # mips = [] + # while os.path.exists(base + ("_%d" % len(mips)) + ext): + # mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)] + # return Texture2D(mips) + # else: + # return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + return Texture2D(_load_mip2D(fn, lambda_fn, channels)) + +def _save_mip2D(fn, mip, mipidx, lambda_fn): + if lambda_fn is not None: + data = lambda_fn(mip).detach().cpu().numpy() + else: + data = mip.detach().cpu().numpy() + + if mipidx is None: + util.save_image(fn, data) + else: + base, ext = os.path.splitext(fn) + util.save_image(base + ("_%d" % mipidx) + ext, data) + +def save_texture2D(fn, tex, lambda_fn=None): + if isinstance(tex.data, list): + for i, mip in enumerate(tex.data): + _save_mip2D(fn, mip[0,...], i, lambda_fn) + else: + _save_mip2D(fn, tex.data[0,...], None, lambda_fn) diff --git a/nvdiffrec/lib/render/util.py b/nvdiffrec/lib/render/util.py new file mode 100644 index 
0000000..51c9cbd
--- /dev/null
+++ b/nvdiffrec/lib/render/util.py
@@ -0,0 +1,482 @@
+# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import nvdiffrast.torch as dr
+import imageio
+
+#----------------------------------------------------------------------------
+# Vector operations
+#----------------------------------------------------------------------------
+
+def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    return torch.sum(x*y, -1, keepdim=True)
+
+def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+    return 2*dot(x, n)*n - x
+
+def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
+    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
+
+def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
+    return x / length(x, eps)
+
+def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
+    return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
+
+#----------------------------------------------------------------------------
+# sRGB color transforms
+#----------------------------------------------------------------------------
+
+def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
+    return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055)
+
+def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
+    assert f.shape[-1] == 3 or f.shape[-1] == 4
+    out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f)
+    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
+    return out
+
+def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
+    return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4))
+
+def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
+    assert f.shape[-1] == 3 or f.shape[-1] == 4
+    out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f)
+    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
+    return out
+
+def reinhard(f: torch.Tensor) -> torch.Tensor:
+    return f/(1+f)
+
+#-----------------------------------------------------------------------------------
+# Metrics (taken from the jaxNeRF source code, in order to replicate their measurements)
+#
+# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266
+#
+#-----------------------------------------------------------------------------------
+
+def mse_to_psnr(mse):
+    """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
+    return -10. / np.log(10.)
* np.log(mse) + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) * psnr) + +#---------------------------------------------------------------------------- +# Displacement texture lookup +#---------------------------------------------------------------------------- + +def get_miplevels(texture: np.ndarray) -> float: + minDim = min(texture.shape[0], texture.shape[1]) + return np.floor(np.log2(minDim)) + +def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor: + tex_map = tex_map[None, ...] # Add batch dimension + tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW + tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False) + tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC + return tex[0, 0, ...] + +#---------------------------------------------------------------------------- +# Cubemap utility functions +#---------------------------------------------------------------------------- + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -y, -x + elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x + elif s == 2: rx, ry, rz = x, torch.ones_like(x), y + elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y + elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x) + elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x) + return torch.stack((rx, ry, rz), dim=-1) + +def latlong_to_cubemap(latlong_map, res): + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = safe_normalize(cube_to_dir(s, gx, gy)) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] 
= dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + +def cubemap_to_latlong(cubemap, res): + gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + + sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi) + sinphi, cosphi = torch.sin(gx*np.pi), torch.cos(gx*np.pi) + + reflvec = torch.stack(( + sintheta*sinphi, + costheta, + -sintheta*cosphi + ), dim=-1) + return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0] + +#---------------------------------------------------------------------------- +# Image scaling +#---------------------------------------------------------------------------- + +def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + return scale_img_nhwc(x[None, ...], size, mag, min)[0] + +def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor: + # assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other" + assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] <= size[0] and x.shape[2] <= size[1]), "Trying to magnify image in one dimension and minify in the other" + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger + y = torch.nn.functional.interpolate(y, size, mode=min) + else: # Magnification + if mag == 'bilinear' or mag == 'bicubic': + y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True) + else: + y = torch.nn.functional.interpolate(y, size, mode=mag) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor: + y = x.permute(0, 3, 1, 2) # NHWC -> NCHW + y = torch.nn.functional.avg_pool2d(y, size) + return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Behaves similar to tf.segment_sum +#---------------------------------------------------------------------------- + +def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor: + num_segments = torch.unique_consecutive(segment_ids).shape[0] + + # Repeats ids until same dimension as data + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) + + assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" + + shape = [num_segments] + list(data.shape[1:]) + result = torch.zeros(*shape, dtype=torch.float32, device='cuda') + result = result.scatter_add(0, segment_ids, data) + return result + +#---------------------------------------------------------------------------- +# Matrix helpers. 
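+#
+# Notes (added for clarity): all helpers below build 4x4 homogeneous matrices as
+# torch tensors. Angles are in radians and the translation sits in the last
+# column, so the matrices act on column vectors and compose right-to-left.
+# Illustrative composition only (eye/at/up are assumed 3-vectors):
+#
+#   mvp = perspective(fovy=0.7854, aspect=1.0) @ lookAt(eye, at, up)
+#
+# perspective() and perspective_offcenter() follow gluPerspective /
+# glm::perspective but negate the y axis (note the 1/-y and -2/(t-b) terms).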
+#---------------------------------------------------------------------------- + +def fovx_to_fovy(fovx, aspect): + return np.arctan(np.tan(fovx / 2) / aspect) * 2.0 + +def focal_length_to_fovy(focal_length, sensor_height): + return 2 * np.arctan(0.5 * sensor_height / focal_length) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + return torch.tensor([[1/(y*aspect), 0, 0, 0], + [ 0, 1/-y, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None): + y = np.tan(fovy / 2) + + # Full frustum + R, L = aspect*y, -aspect*y + T, B = y, -y + + # Create a randomized sub-frustum + width = (R-L)*fraction + height = (T-B)*fraction + xstart = (R-L)*rx + ystart = (T-B)*ry + + l = L + xstart + r = l + width + b = B + ystart + t = b + height + + # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix + return torch.tensor([[2/(r-l), 0, (r+l)/(r-l), 0], + [ 0, -2/(t-b), (t+b)/(t-b), 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c, s, 0], + [0, -s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def scale(s, device=None): + return torch.tensor([[ s, 0, 0, 0], + [ 0, s, 0, 0], + [ 0, 0, s, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + +def lookAt(eye, at, up): + a = eye - at + w = a / torch.linalg.norm(a) + u = torch.cross(up, w) + u = u / torch.linalg.norm(u) + v = torch.cross(w, u) + translate = torch.tensor([[1, 0, 0, -eye[0]], + [0, 1, 0, -eye[1]], + [0, 0, 1, -eye[2]], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + rotate = torch.tensor([[u[0], u[1], u[2], 0], + [v[0], v[1], v[2], 0], + [w[0], w[1], w[2], 0], + [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device) + return rotate @ translate + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def batch_random_rotation(batch_size, device=None): + m = np.random.normal(size=[batch_size, 3, 3]) + m[:, 1] = np.cross(m[:, 0], m[:, 2]) + m[:, 2] = np.cross(m[:, 0], m[:, 1]) + m = m 
/ np.linalg.norm(m, axis=-1, keepdims=True)
+    m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
+    m[:, 3, 3] = 1.0
+    m[:, :3, 3] = np.zeros((1, 3), dtype=np.float32)  # np.ndarray has no .unsqueeze(); broadcasting covers the batch dim
+    return torch.tensor(m, dtype=torch.float32, device=device)
+
+#----------------------------------------------------------------------------
+# Compute focal points of a set of lines using least squares.
+# Handy for poorly centered datasets.
+#----------------------------------------------------------------------------
+
+def lines_focal(o, d):
+    d = safe_normalize(d)
+    I = torch.eye(3, dtype=o.dtype, device=o.device)
+    S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0)
+    C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1)
+    return torch.linalg.pinv(S) @ C
+
+#----------------------------------------------------------------------------
+# Cosine sample around a vector N
+#----------------------------------------------------------------------------
+@torch.no_grad()
+def cosine_sample(N, size=None):
+    # construct local frame
+    N = N/torch.linalg.norm(N)
+
+    dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device)
+    dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device)
+
+    dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1)
+    dx = dx / torch.linalg.norm(dx)
+    dy = torch.cross(N,dx)
+    dy = dy / torch.linalg.norm(dy)
+
+    # cosine sampling in local frame
+    if size is None:
+        phi = 2.0 * np.pi * np.random.uniform()
+        s = np.random.uniform()
+    else:
+        phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device)
+        s = torch.rand(*size, 1, dtype=N.dtype, device=N.device)
+    costheta = np.sqrt(s)
+    sintheta = np.sqrt(1.0 - s)
+
+    # cartesian vector in local space
+    x = np.cos(phi)*sintheta
+    y = np.sin(phi)*sintheta
+    z = costheta
+
+    # local to world
+    return dx*x + dy*y + N*z
+
+#----------------------------------------------------------------------------
+# Bilinear downsample by 2x.
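+#
+# Notes (added for clarity): the 4x4 weight matrix below is the outer product of
+# the tent filter [1, 3, 3, 1]/8 with itself (hence the /64), applied with
+# stride 2 as a grouped/depthwise conv2d so channels are filtered independently.
+# Input and output are NHWC; H and W are halved.
+#
+# NB: a second bilinear_downsample(x, spp) is defined further down in this file
+# with the same name, so at import time that variant shadows this single-step
+# version; it repeats the same kernel log2(spp) times with replicate padding.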
+#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + w = w.expand(x.shape[-1], 1, 4, 4) + x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1]) + return x.permute(0, 2, 3, 1) + +#---------------------------------------------------------------------------- +# Bilinear downsample log(spp) steps +#---------------------------------------------------------------------------- + +def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor: + w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0 + g = x.shape[-1] + w = w.expand(g, 1, 4, 4) + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + steps = int(np.log2(spp)) + for _ in range(steps): + xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate') + x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g) + return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC + +#---------------------------------------------------------------------------- +# Singleton initialize GLFW +#---------------------------------------------------------------------------- + +_glfw_initialized = False +def init_glfw(): + global _glfw_initialized + try: + import glfw + glfw.ERROR_REPORTING = 'raise' + glfw.default_window_hints() + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet + except glfw.GLFWError as e: + if e.error_code == glfw.NOT_INITIALIZED: + glfw.init() + _glfw_initialized = True + +#---------------------------------------------------------------------------- +# Image display function using OpenGL. +#---------------------------------------------------------------------------- + +_glfw_window = None +def display_image(image, title=None): + # Import OpenGL + import OpenGL.GL as gl + import glfw + + # Zoom image if requested. + image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) + height, width, channels = image.shape + + # Initialize window. + init_glfw() + if title is None: + title = 'Debug window' + global _glfw_window + if _glfw_window is None: + glfw.default_window_hints() + _glfw_window = glfw.create_window(width, height, title, None, None) + glfw.make_context_current(_glfw_window) + glfw.show_window(_glfw_window) + glfw.swap_interval(0) + else: + glfw.make_context_current(_glfw_window) + glfw.set_window_title(_glfw_window, title) + glfw.set_window_size(_glfw_window, width, height) + + # Update window. + glfw.poll_events() + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glWindowPos2f(0, 0) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] + gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] + gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) + glfw.swap_buffers(_glfw_window) + if glfw.window_should_close(_glfw_window): + return False + return True + +#---------------------------------------------------------------------------- +# Image save/load helper. 
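+#
+# Notes (added for clarity): save_image() expects float data in [0, 1] and writes
+# an 8-bit LDR file; save_image_raw() writes the array as-is (useful for float32
+# HDR formats). Both swallow failures and only print a warning. load_image()
+# returns float32: HDR inputs are passed through, LDR inputs are scaled to [0, 1].
+# Illustrative round trip (file name is arbitrary):
+#
+#   save_image("out.png", img)        # img: HWC float in [0, 1]
+#   img2 = load_image("out.png")      # float32 in [0, 1]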
+#---------------------------------------------------------------------------- + +def save_image(fn, x : np.ndarray): + try: + if os.path.splitext(fn)[1] == ".png": + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving + else: + imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) + except: + print("WARNING: FAILED to save image %s" % fn) + +def save_image_raw(fn, x : np.ndarray): + try: + imageio.imwrite(fn, x) + except: + print("WARNING: FAILED to save image %s" % fn) + + +def load_image_raw(fn) -> np.ndarray: + return imageio.imread(fn) + +def load_image(fn) -> np.ndarray: + img = load_image_raw(fn) + if img.dtype == np.float32: # HDR image + return img + else: # LDR image + return img.astype(np.float32) / 255 + +#---------------------------------------------------------------------------- + +def time_to_text(x): + if x > 3600: + return "%.2f h" % (x / 3600) + elif x > 60: + return "%.2f m" % (x / 60) + else: + return "%.2f s" % x + +#---------------------------------------------------------------------------- + +def checkerboard(res, checker_size) -> np.ndarray: + tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) + tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) + check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 + check = check[:res[0], :res[1]] + return np.stack((check, check, check), axis=-1) +
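+
+#----------------------------------------------------------------------------
+# Optional self-test sketch (added for illustration, not part of the original
+# module): exercises checkerboard(), save_image() and load_image(). The output
+# file name is arbitrary.
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    chk = checkerboard((256, 256), 32)       # HWC float checker pattern in [0.33, 0.66]
+    save_image("_checker_test.png", chk)     # written as 8-bit PNG
+    rt = load_image("_checker_test.png")     # read back as float32 in [0, 1]
+    print("checker:", chk.shape, "round trip:", rt.shape, rt.dtype)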