# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions and modules related to model definition.
"""

import numpy as np
import torch

from .. import sde_lib

# Registry mapping a model name to the class registered under that name.
_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes.

    Can be used bare (`@register_model`) or with a keyword
    (`@register_model(name='foo')`).

    Args:
      cls: The class to register. `None` when the decorator is called with
        keyword arguments only, in which case the real decorator is returned.
      name: Optional registry key. Defaults to `cls.__name__`.

    Returns:
      The registered class (unchanged), or a decorator when `cls` is `None`.

    Raises:
      ValueError: If a model is already registered under the same name.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)


def get_model(name):
    """Return the model class registered under `name`.

    Raises:
      KeyError: If no model was registered under `name`.
    """
    return _MODELS[name]


def get_sigmas(config):
    """Get sigmas --- the set of noise levels for SMLD from config files.

    Args:
      config: A ConfigDict object parsed from the config file

    Returns:
      sigmas: a numpy array of `config.model.num_scales` noise levels,
        geometrically spaced from `config.model.sigma_max` down to
        `config.model.sigma_min`.
    """
    log_sigma_max = np.log(config.model.sigma_max)
    log_sigma_min = np.log(config.model.sigma_min)
    return np.exp(
        np.linspace(log_sigma_max, log_sigma_min, config.model.num_scales))


def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Args:
      config: A ConfigDict object; reads `config.model.beta_min`,
        `config.model.beta_max` and `config.model.num_scales`.

    Returns:
      A dict of numpy arrays / scalars with keys 'betas', 'alphas',
      'alphas_cumprod', 'sqrt_alphas_cumprod', 'sqrt_1m_alphas_cumprod',
      'beta_min', 'beta_max' and 'num_diffusion_timesteps'.
    """
    # NOTE: the number of timesteps is fixed at 1000 while the beta range is
    # scaled by `config.model.num_scales`; parameters need to be adapted if
    # the number of time steps differs from 1000.
    num_diffusion_timesteps = 1000
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(
        beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps,
    }


def create_model(config, use_parallel=True):
    """Create the score model.

    Args:
      config: A ConfigDict object; `config.model.name` selects the registered
        model class, which is instantiated with `config`.
      use_parallel: If `True`, wrap the model in `torch.nn.DataParallel` and
        move it to `config.device`.

    Returns:
      The instantiated model. NOTE: when `use_parallel` is `False` the model
      is returned as constructed and is NOT moved to `config.device`; the
      caller is responsible for device placement in that case.
    """
    score_model = get_model(config.model.name)(config)
    if use_parallel:
        score_model = torch.nn.DataParallel(score_model).to(config.device)
    return score_model


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps. Should
            be interpreted differently for different models.

        Returns:
          The output of `model(x, labels)`.
        """
        # Switch the module into the requested mode on every call, since other
        # code may have toggled it in between.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn


def get_reg_fn(model, train=False):
    """Create a function to give the regularization output of the model.

    Args:
      model: The score model. It may optionally provide a `get_reg(x)` method.
      train: `True` for training and `False` for evaluation.

    Returns:
      A function of `x` returning `model.get_reg(x)`, or zeros shaped like `x`
      when that call is unavailable or fails.
    """

    def model_fn(x):
        """Compute the regularization term for a mini-batch.

        Args:
          x: A mini-batch of input data.

        Returns:
          `model.get_reg(x)` if it succeeds, otherwise `torch.zeros_like(x)`.
        """
        if train:
            model.train()
        else:
            model.eval()
        try:
            return model.get_reg(x)
        except Exception:
            # Deliberate best-effort fallback: models without a usable
            # `get_reg` contribute no regularization. `Exception` (rather than
            # a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
            # NOTE(review): a model wrapped in `torch.nn.DataParallel` does not
            # appear to expose `get_reg` on the wrapper, so this path would
            # always return zeros for such models -- confirm this is intended.
            return torch.zeros_like(x)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False, std_scale=True):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take
        continuous time steps. Not supported here; must be `False`.
      std_scale: whether to scale the score function by the inverse of std.
        When `False` the raw model output is returned. Used for DDIM sampling.

    Returns:
      A score function taking `(x, t)`.

    Raises:
      AssertionError: If `continuous` is true.
      NotImplementedError: If `sde` is not an `sde_lib.VPSDE`.
    """
    model_fn = get_model_fn(model, train=train)
    assert not continuous, 'Continuous-time score models are not supported.'

    if not isinstance(sde, sde_lib.VPSDE):
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    if not std_scale:

        def score_fn(x, t):
            # Map t onto the discrete timestep index range [0, sde.N - 1].
            labels = t * (sde.N - 1)
            return model_fn(x, labels)

    else:

        def score_fn(x, t):
            # For VP-trained models, t=0 corresponds to the lowest noise level
            labels = t * (sde.N - 1)
            score = model_fn(x, labels)
            std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
            # The four trailing singleton axes broadcast a per-sample std
            # against a 5-D model output.
            # NOTE(review): assumes `t` is 1-D (one entry per sample) and the
            # model output is 5-D -- confirm against callers.
            return -score / std[:, None, None, None, None]

    return score_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))