# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""DDPM model.

This code is the pytorch equivalent of:
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
"""
import torch
import torch.nn as nn
import functools
import numpy as np

from . import utils, layers, normalization

ResnetBlockDDPM = layers.ResnetBlockDDPM
Upsample = layers.Upsample
Downsample = layers.Downsample
conv3x3 = layers.ddpm_conv3x3
conv5x5 = layers.ddpm_conv5x5
transposed_conv6x6 = layers.ddpm_conv6x6_transposed
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init


@utils.register_model(name='ddpm_res128')
class DDPMRes128(nn.Module):
  """DDPM-style U-Net for volumetric (3-D) data at 128^3 resolution.

  Down/up-sampling U-Net with ResNet blocks, optional attention at configured
  resolutions, and Fourier/positional conditioning on the noise level.  The
  input is assumed to be 5-D, ``(batch, channels, D, H, W)`` — this matches
  the 5-D ``mask``/``coords`` buffers built below (TODO confirm the conv
  helpers in ``layers`` are the 3-D variants).
  """

  def __init__(self, config):
    """Builds all sub-modules from ``config``.

    Args:
      config: ml_collections-style config with ``model.*`` and ``data.*``
        fields (nf, ch_mult, num_res_blocks, attn_resolutions, dropout,
        resamp_with_conv, conditional, scale_by_sigma, image_size,
        num_channels, centered).
    """
    super().__init__()
    self.act = act = get_act(config)
    # Noise scales; registered as a buffer so it moves with .to(device).
    self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

    self.nf = nf = config.model.nf
    ch_mult = config.model.ch_mult
    self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
    self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    self.num_resolutions = num_resolutions = len(ch_mult)
    # Spatial side length at each U-Net level: image_size, image_size/2, ...
    self.all_resolutions = all_resolutions = [
        config.data.image_size // (2 ** i) for i in range(num_resolutions)]

    ## manual for 128 to 64
    AttnBlock = layers.AttnBlock  # no-arg functools.partial removed: identical behavior
    self.conditional = conditional = config.model.conditional
    ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
                                    temb_dim=4 * nf, dropout=dropout)
    if conditional:
      # Condition on noise levels: two-layer MLP on the timestep embedding.
      modules = [nn.Linear(nf, nf * 4)]
      modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)
      nn.init.zeros_(modules[0].bias)
      modules.append(nn.Linear(nf * 4, nf * 4))
      modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)
      nn.init.zeros_(modules[1].bias)

    self.centered = config.data.centered
    channels = config.data.num_channels

    ##### Pos Encoding
    self.img_size = img_size = config.data.image_size
    self.num_freq = int(np.log2(img_size))
    # Hard-coded off; `forward` only reads self.coords when this is True, so
    # the O(img_size^3) meshgrid is built only under the flag (the original
    # computed it unconditionally and discarded it).
    self.use_coords = False
    if self.use_coords:
      coord_x, coord_y, coord_z = torch.meshgrid(
          torch.arange(img_size), torch.arange(img_size),
          torch.arange(img_size))
      # Parameter with requires_grad=False (not a plain buffer) to keep
      # state_dict / optimizer behavior identical to the original.
      self.coords = torch.nn.Parameter(
          torch.stack([coord_x, coord_y, coord_z]).view(
              1, 3, img_size, img_size, img_size),
          requires_grad=False)

    ### Mask: fixed 5-D zero volume added (through mask_layer) to the stem.
    self.mask = torch.nn.Parameter(
        torch.zeros(1, 1, img_size, img_size, img_size), requires_grad=False)

    # Downsampling block
    self.pos_layer = conv5x5(3, nf, stride=1, padding=2)
    self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
    modules.append(conv5x5(channels, nf, stride=1, padding=2))
    hs_c = [nf]  # channel counts of skip connections, mirrored by hs in forward
    in_ch = nf
    for i_level in range(num_resolutions):
      # Level 0 always uses 2 res-blocks; other levels use the configured count.
      num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks_curr):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
        in_ch = out_ch
        if all_resolutions[i_level] in attn_resolutions:
          modules.append(AttnBlock(channels=in_ch))
        hs_c.append(in_ch)
      if i_level != num_resolutions - 1:
        modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
        hs_c.append(in_ch)

    # Bottleneck: ResNet -> attention -> ResNet.
    in_ch = hs_c[-1]
    modules.append(ResnetBlock(in_ch=in_ch))
    modules.append(AttnBlock(channels=in_ch))
    modules.append(ResnetBlock(in_ch=in_ch))

    # Upsampling block
    for i_level in reversed(range(num_resolutions)):
      num_res_blocks_curr = self.num_res_blocks if i_level != 0 else 2
      # One extra block per level to consume the Downsample skip connection.
      for i_block in range(num_res_blocks_curr + 1):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
        in_ch = out_ch
      if all_resolutions[i_level] in attn_resolutions:
        modules.append(AttnBlock(channels=in_ch))
      if i_level != 0:
        modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))

    assert not hs_c  # every skip connection must be consumed
    modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))
    modules.append(conv5x5(in_ch, channels, init_scale=0., stride=1, padding=2))
    self.all_modules = nn.ModuleList(modules)

    self.scale_by_sigma = config.model.scale_by_sigma

  def forward(self, x, labels):
    """Runs the U-Net.

    Args:
      x: input volume, ``(batch, channels, D, H, W)``; in [-1, 1] if
        ``config.data.centered`` else in [0, 1].
      labels: per-sample noise-level indices, shape ``(batch,)``.

    Returns:
      Tensor of the same shape as ``x`` (optionally divided by sigma).
    """
    modules = self.all_modules
    m_idx = 0  # walks all_modules in the exact order __init__ appended them
    if self.conditional:
      # timestep/scale embedding
      timesteps = labels
      temb = layers.get_timestep_embedding(timesteps, self.nf)
      temb = modules[m_idx](temb)
      m_idx += 1
      temb = modules[m_idx](self.act(temb))
      m_idx += 1
    else:
      temb = None

    if self.centered:
      # Input is in [-1, 1]
      h = x
    else:
      # Input is in [0, 1]; rescale to [-1, 1].
      h = 2 * x - 1.

    # Downsampling block
    if self.use_coords:
      hs = [modules[m_idx](h) + self.pos_layer(self.coords)
            + self.mask_layer(self.mask)]
    else:
      hs = [modules[m_idx](h) + self.mask_layer(self.mask)]
    m_idx += 1
    for i_level in range(self.num_resolutions):
      # Must mirror the block counts used in __init__ (level 0 => 2 blocks).
      num_res_blocks = self.num_res_blocks if i_level != 0 else 2
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks):
        h = modules[m_idx](hs[-1], temb)
        m_idx += 1
        if h.shape[-1] in self.attn_resolutions:
          h = modules[m_idx](h)
          m_idx += 1
        hs.append(h)
      if i_level != self.num_resolutions - 1:
        hs.append(modules[m_idx](hs[-1]))
        m_idx += 1

    # Bottleneck: ResNet -> attention -> ResNet.
    h = hs[-1]
    h = modules[m_idx](h, temb)
    m_idx += 1
    h = modules[m_idx](h)
    m_idx += 1
    h = modules[m_idx](h, temb)
    m_idx += 1

    # Upsampling block
    for i_level in reversed(range(self.num_resolutions)):
      num_res_blocks = self.num_res_blocks if i_level != 0 else 2
      for i_block in range(num_res_blocks + 1):
        # `skip` (was `input`, which shadowed the builtin) is the matching
        # downsampling activation, concatenated along channels.
        skip = hs.pop()
        h = modules[m_idx](torch.cat([h, skip], dim=1), temb)
        m_idx += 1
      if h.shape[-1] in self.attn_resolutions:
        h = modules[m_idx](h)
        m_idx += 1
      if i_level != 0:
        h = modules[m_idx](h)
        m_idx += 1

    assert not hs
    h = self.act(modules[m_idx](h))
    m_idx += 1
    h = modules[m_idx](h)
    m_idx += 1
    assert m_idx == len(modules)

    if self.scale_by_sigma:
      # Divide the output by sigmas. Useful for training with the NCSN loss.
      # The DDPM loss scales the network output by sigma in the loss function,
      # so no need of doing it here.
      # BUGFIX: h is 5-D here, so sigma needs five indexing axes — the original
      # 4-D `[labels, None, None, None]` could not broadcast against
      # (batch, C, D, H, W).
      used_sigmas = self.sigmas[labels, None, None, None, None]
      h = h / used_sigmas
    return h