Mirror of https://github.com/lzzcd001/MeshDiffusion
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
|
|
|
|
import torch
import numpy as np
from scipy import integrate
from .models import utils as mutils


def get_div_fn(fn):
  """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""

  def div_fn(x, t, eps):
    with torch.enable_grad():
      x.requires_grad_(True)
      # One-sample Hutchinson estimate eps^T (d fn / d x) eps, computed from a
      # single vector-Jacobian product instead of the full Jacobian trace.
      fn_eps = torch.sum(fn(x, t) * eps)
      grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
    x.requires_grad_(False)
    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

  return div_fn


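# A minimal sanity check for the estimator above (a sketch, not part of the
# original module): for a linear map fn(x, t) = x @ A.T the exact divergence is
# trace(A), so averaging the Hutchinson estimate over many Rademacher probes
# should approach it. Intended to be pasted into a notebook or test that
# imports this module; `A`, `lin_fn`, `x0`, and `t0` are illustrative names.
#
#   A = torch.randn(4, 4)
#   lin_fn = lambda x, t: x @ A.T
#   div_estimator = get_div_fn(lin_fn)
#   x0, t0 = torch.randn(1, 4), torch.zeros(1)
#   probes = [torch.randint_like(x0, low=0, high=2).float() * 2 - 1. for _ in range(2000)]
#   est = torch.stack([div_estimator(x0, t0, p) for p in probes]).mean()
#   print(est.item(), torch.trace(A).item())  # the two values should be close

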
def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
  """Create a function to compute the unbiased log-likelihood estimate of a given data point.

  Args:
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    inverse_scaler: The inverse data normalizer.
    hutchinson_type: "Rademacher" or "Gaussian". The type of noise for the Hutchinson-Skilling trace estimator.
    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
    method: A `str`. The algorithm used by the black-box ODE solver.
      See the documentation of `scipy.integrate.solve_ivp`.
    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

  Returns:
    A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
    the latent code, and the number of function evaluations used by the ODE solver.
  """

  def drift_fn(model, x, t):
    """The drift function of the reverse-time SDE."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
    # Probability flow ODE is a special case of Reverse SDE
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  def div_fn(model, x, t, noise):
    return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

  def likelihood_fn(model, data):
    """Compute an unbiased estimate of the log-likelihood in bits/dim.

    Args:
      model: A score model.
      data: A PyTorch tensor.

    Returns:
      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    with torch.no_grad():
      shape = data.shape
      if hutchinson_type == 'Gaussian':
        epsilon = torch.randn_like(data)
      elif hutchinson_type == 'Rademacher':
        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
      else:
        raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

      def ode_func(t, x):
        # The ODE state is the flattened batch concatenated with one running
        # log-density change per batch element.
        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
        vec_t = torch.ones(sample.shape[0], device=sample.device) * t
        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
        return np.concatenate([drift, logp_grad], axis=0)

      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      zp = solution.y[:, -1]
      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
      prior_logp = sde.prior_logp(z)
      bpd = -(prior_logp + delta_logp) / np.log(2)
      N = np.prod(shape[1:])
      bpd = bpd / N
      # A hack to convert log-likelihoods to bits/dim
      offset = 7. - inverse_scaler(-1.)
      bpd = bpd + offset
      return bpd, z, nfe

  return likelihood_fn
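

# Example usage (a sketch, not part of the original module): the factory above
# is typically called once with the training-time SDE and data normalizer, and
# the returned closure is then applied to scaled evaluation batches. `sde`,
# `inverse_scaler`, `scaler`, `score_model`, and `eval_batch` are assumed to be
# set up elsewhere (e.g. by the training config) and are illustrative names.
#
#   likelihood_fn = get_likelihood_fn(sde, inverse_scaler)
#   bpd, z, nfe = likelihood_fn(score_model, scaler(eval_batch))
#   print(bpd.mean().item(), nfe)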