MeshDiffusion/lib/diffusion/models/utils.py

213 lines
6.0 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All functions and modules related to model definition.
"""
import torch
from .. import sde_lib
import numpy as np
# Registry of model classes, keyed by the name they were registered under.
_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes.

    Usable bare (``@register_model``) or with an explicit name
    (``@register_model(name='foo')``). The class is stored in the module
    registry under ``name`` when given, otherwise under ``cls.__name__``,
    and is returned unchanged.

    Raises:
      ValueError: if a model is already registered under the same name.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        # Called as @register_model(name=...): return the actual decorator.
        return _register
    return _register(cls)


def get_model(name):
    """Return the model class registered under ``name``.

    Raises:
      KeyError: if nothing is registered under ``name``. The message lists
        the registered names so a typo in the config is easy to spot.
    """
    try:
        return _MODELS[name]
    except KeyError:
        available = ', '.join(sorted(map(str, _MODELS))) or '<none>'
        raise KeyError(
            f'No model registered with name: {name!r}. '
            f'Registered models: {available}') from None
def get_sigmas(config):
    """Get sigmas --- the set of noise levels for SMLD from config files.

    The levels are geometrically spaced (evenly spaced in log space) from
    ``config.model.sigma_max`` down to ``config.model.sigma_min``.

    Args:
      config: A ConfigDict object parsed from the config file; reads
        ``config.model.sigma_max``, ``config.model.sigma_min`` and
        ``config.model.num_scales``.

    Returns:
      sigmas: a NumPy array of ``config.model.num_scales`` noise levels,
        ordered from largest to smallest.
    """
    sigmas = np.exp(
        np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
    return sigmas
def get_ddpm_params(config, *, num_diffusion_timesteps=1000):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    The per-step betas are spaced linearly from
    ``config.model.beta_min / config.model.num_scales`` to
    ``config.model.beta_max / config.model.num_scales``.

    Args:
      config: A ConfigDict-like object; reads ``config.model.beta_min``,
        ``config.model.beta_max`` and ``config.model.num_scales``.
      num_diffusion_timesteps: Number of discrete steps in the schedule.
        Defaults to 1000, the value this function always used before. The
        betas are divided by ``config.model.num_scales``, so when that differs
        from 1000 pass the same value here to keep the schedule consistent.

    Returns:
      A dict with float64 NumPy arrays ``'betas'``, ``'alphas'``,
      ``'alphas_cumprod'``, ``'sqrt_alphas_cumprod'`` and
      ``'sqrt_1m_alphas_cumprod'`` (each of length
      ``num_diffusion_timesteps``), plus ``'beta_min'`` / ``'beta_max'``
      (the endpoint betas multiplied by ``num_diffusion_timesteps - 1``) and
      ``'num_diffusion_timesteps'``.

    Raises:
      ValueError: if ``num_diffusion_timesteps`` is smaller than 1.
    """
    if num_diffusion_timesteps < 1:
        raise ValueError(
            f'num_diffusion_timesteps must be >= 1, got {num_diffusion_timesteps}')
    # The betas are scaled by num_scales; they only match the step count when
    # num_diffusion_timesteps == config.model.num_scales.
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps,
    }
def create_model(config, use_parallel=True):
    """Instantiate the score model named by ``config.model.name``.

    Args:
      config: Config object. ``config.model.name`` selects the registered
        model class, which is then constructed as ``cls(config)``.
      use_parallel: If True, wrap the model in ``torch.nn.DataParallel`` and
        move the wrapper to ``config.device``.

    Returns:
      The ``DataParallel``-wrapped model on ``config.device`` when
      ``use_parallel`` is True; otherwise the bare model exactly as its
      constructor returned it (it is not moved to ``config.device`` here).
    """
    model_cls = get_model(config.model.name)
    net = model_cls(config)
    if not use_parallel:
        # NOTE(review): device placement is skipped on this path; presumably
        # the caller moves the model itself -- confirm.
        return net
    return torch.nn.DataParallel(net).to(config.device)
def get_model_fn(model, train=False):
    """Build a closure that runs ``model`` in the requested mode.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A function ``model_fn(x, labels)``.
    """

    def model_fn(x, labels):
        """Switch ``model`` into train/eval mode, then call it.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps.
            Should be interpreted differently for different models.

        Returns:
          Whatever ``model(x, labels)`` returns.
        """
        # The mode is (re)set on every call, not once at construction time.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
def get_reg_fn(model, train=False):
    """Create a function that returns the model's regularization output.

    Args:
      model: The score model. If it defines a ``get_reg(x)`` method, that
        method is used; otherwise the returned function yields zeros.
      train: `True` for training and `False` for evaluation.

    Returns:
      A function ``reg_fn(x)``.
    """

    def reg_fn(x):
        """Compute the regularization term for ``x``.

        Puts ``model`` into train/eval mode according to ``train``, then
        returns ``model.get_reg(x)`` if the model defines ``get_reg``, or
        ``torch.zeros_like(x)`` if it does not. Exceptions raised inside
        ``get_reg`` propagate rather than being silently replaced by zeros.
        """
        if train:
            model.train()
        else:
            model.eval()
        # Only a *missing* get_reg falls back to zeros; the previous bare
        # `except:` also hid real errors (and KeyboardInterrupt) from get_reg.
        # NOTE(review): torch.nn.DataParallel does not forward custom methods
        # such as get_reg, so wrapped models take the zeros path -- confirm
        # this is intended.
        get_reg = getattr(model, 'get_reg', None)
        if get_reg is None:
            return torch.zeros_like(x)
        return get_reg(x)

    return reg_fn
def get_score_fn(sde, model, train=False, continuous=False, std_scale=True):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Only discrete-time models trained with an `sde_lib.VPSDE` are supported.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: Must be `False`; continuous-time models are not supported.
      std_scale: whether to scale the score function by the inverse of std. Used for DDIM sampling

    Returns:
      A score function ``score_fn(x, t)``.

    Raises:
      NotImplementedError: if ``continuous`` is True or ``sde`` is not an
        `sde_lib.VPSDE`.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if continuous:
        raise NotImplementedError(
            'Continuous-time score models are not supported; use continuous=False.')
    if not isinstance(sde, sde_lib.VPSDE):
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    model_fn = get_model_fn(model, train=train)

    def score_fn(x, t):
        # For VP-trained models, t=0 corresponds to the lowest noise level
        labels = t * (sde.N - 1)
        score = model_fn(x, labels)
        if not std_scale:
            return score
        # NOTE(review): `.long()` truncates toward zero, so a label such as
        # 998.9999 indexes step 998 -- confirm truncation (not rounding) is intended.
        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
        # Reshape std to (batch, 1, ..., 1) so it broadcasts over a score of
        # any rank; identical to std[:, None, None, None, None] for 5-D output.
        std = std.reshape(-1, *([1] * (score.dim() - 1)))
        return -score / std

    return score_fn
def to_flattened_numpy(x):
    """Detach ``x``, move it to the CPU and return it as a 1-D NumPy array."""
    host_array = x.detach().cpu().numpy()
    return host_array.ravel()
def from_flattened_numpy(x, shape):
    """Reshape the flat NumPy array ``x`` to ``shape`` and wrap it as a torch tensor."""
    shaped = x.reshape(shape)
    return torch.from_numpy(shaped)