MeshDiffusion/lib/diffusion/evaler.py

213 lines
7.7 KiB
Python

import os
import sys
import numpy as np
import logging
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from .utils import restore_checkpoint
from . import sampling
def uncond_gen(
    config,
    idx=0,
):
    """Unconditionally sample one batch and save it to ``<eval_dir>/<idx>.npy``.

    Builds the score model together with its optimizer and EMA tracker (all
    three are part of the checkpoint state that ``restore_checkpoint`` fills
    in), sets up a VP-SDE and a sampler, restores the checkpoint, copies the
    EMA weights into the model and runs the sampler once.

    Args:
        config: experiment config. Reads ``config.eval`` (``eval_dir``,
            ``ckpt_path``, ``batch_size``), ``config.model`` (``ema_rate``,
            ``beta_min``, ``beta_max``, ``num_scales``), ``config.data``
            (``image_size``, ``num_channels``), ``config.training.iter_size``
            and ``config.device``.
        idx: stem of the output file name.

    Raises:
        AssertionError: if ``config.eval.ckpt_path`` does not exist.
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        os.makedirs(eval_dir, exist_ok=True)
        # Identity inverse scaler: samples are saved without rescaling.
        inverse_scaler = lambda x: x

        # The optimizer and EMA are only constructed so the checkpoint's
        # state dict can be restored into them.
        score_model = mutils.create_model(config)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales,
        )
        img_size = config.data.image_size
        # Grid mask is placed on config.device (previously hard-coded "cuda")
        # so it matches the device the checkpoint is restored to.
        grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(
            1, img_size, img_size, img_size
        ).to(config.device)
        sampling_eps = 1e-3
        sampling_shape = (
            config.eval.batch_size,
            config.data.num_channels,
            img_size, img_size, img_size,
        )
        sampling_fn = sampling.get_sampling_fn(
            config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask
        )

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        # Sample with the EMA weights rather than the raw trained weights.
        ema.copy_to(score_model.parameters())
        print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")

        save_file_path = os.path.join(eval_dir, f"{idx}.npy")
        # The sampler returns a pair; only the samples are used here.
        samples, _ = sampling_fn(score_model)
        np.save(save_file_path, samples.cpu().numpy())
def slerp(z1, z2, alpha, eps=1e-7):
    """Spherical linear interpolation between two tensors.

    The angle ``theta`` between ``z1`` and ``z2`` is computed from the dot
    product over *all* elements (the tensors are treated as flat vectors).
    The result is::

        sin((1 - alpha) * theta) / sin(theta) * z1
            + sin(alpha * theta) / sin(theta) * z2

    so ``alpha = 0`` gives ``z1`` and ``alpha = 1`` gives ``z2``.

    Args:
        z1: first endpoint tensor.
        z2: second endpoint tensor, same shape as ``z1``.
        alpha: interpolation weight (Python float or tensor).
        eps: when ``|sin(theta)|`` falls below this, the inputs are
            (anti-)parallel and the formula above would divide by ~0.
            Plain linear interpolation is returned instead. For exactly
            anti-parallel inputs, slerp is undefined and the fallback may
            shrink toward zero.

    Returns:
        Interpolated tensor with the shape of ``z1``/``z2``.
    """
    cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
    # Rounding can push the cosine slightly outside [-1, 1], making acos NaN.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
    theta = torch.acos(cos_theta)
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < eps:
        # Degenerate angle: avoid 0/0 and fall back to linear interpolation.
        return (1 - alpha) * z1 + alpha * z2
    return (
        torch.sin((1 - alpha) * theta) / sin_theta * z1
        + torch.sin(alpha * theta) / sin_theta * z2
    )
def uncond_gen_interp(
    config,
    idx=0,
):
    """Generate a batch whose initial noises interpolate between two prior draws.

    Intended for a DDIM sampler. Two noise tensors are drawn with
    ``sde.prior_sampling`` and placed at the first and last batch positions.
    Every position in between is filled by spherical interpolation
    (``slerp``) between those two. The sampler is run from this batch of
    initial states, and the result is saved to ``<eval_dir>/<idx>.npy``.

    Args:
        config: experiment config. Reads ``config.eval`` (``eval_dir``,
            ``ckpt_path``, ``batch_size``), ``config.model``, ``config.data``,
            ``config.training.iter_size`` and ``config.device``.
        idx: stem of the output file name.

    Raises:
        AssertionError: if ``config.eval.ckpt_path`` does not exist.
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        os.makedirs(eval_dir, exist_ok=True)
        # Identity inverse scaler: samples are saved without rescaling.
        inverse_scaler = lambda x: x

        # The optimizer and EMA are only constructed so the checkpoint's
        # state dict can be restored into them.
        score_model = mutils.create_model(config)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales,
        )
        img_size = config.data.image_size
        # Same device as x0 below (previously hard-coded "cuda").
        grid_mask = torch.load(f'./data/grid_mask_{img_size}.pt').view(
            1, img_size, img_size, img_size
        ).to(config.device)
        sampling_eps = 1e-3
        # Previously referenced an undefined name `batch_size` (NameError).
        batch_size = config.eval.batch_size
        sampling_shape = (
            batch_size,
            config.data.num_channels,
            img_size, img_size, img_size,
        )
        sampling_fn = sampling.get_sampling_fn(
            config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask
        )

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        # Sample with the EMA weights rather than the raw trained weights.
        ema.copy_to(score_model.parameters())
        print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
        save_file_path = os.path.join(eval_dir, f"{idx}.npy")

        # Two endpoint noises; intermediate batch entries are slerped between them.
        noise = sde.prior_sampling(
            (2, config.data.num_channels, img_size, img_size, img_size)
        ).to(config.device)
        x0 = torch.zeros(sampling_shape, device=config.device)
        x0[0] = noise[0]
        x0[-1] = noise[1]  # with batch_size == 1 this overwrites x0[0]
        for i in range(1, batch_size - 1):
            # Previously wrote into an undefined `x`; the target is x0.
            x0[i] = slerp(x0[0], x0[-1], i / float(batch_size - 1))

        # The sampler returns a pair; only the samples are used here.
        samples, _ = sampling_fn(score_model, x0=x0)
        np.save(save_file_path, samples.cpu().numpy())
def cond_gen(
    config,
    save_fname='0',
):
    """Conditional generation from a partially observed DMTet.

    A partially completed DMTet (e.g. obtained from a 2.5D view) is loaded
    from ``config.eval.partial_dmtet_path`` as a dict with ``'sdf'`` and
    ``'vis'`` entries. Both are scattered from tet-vertex order onto a
    ``resolution^3`` cubic grid and passed to the sampler as ``partial`` and
    ``partial_mask``, together with ``config.eval.freeze_iters``. The result
    is saved to ``<eval_dir>/<save_fname>.npy``.

    Args:
        config: experiment config. Reads ``config.eval`` (``eval_dir``,
            ``ckpt_path``, ``batch_size``, ``partial_dmtet_path``,
            ``tet_path``, ``freeze_iters``), ``config.model``,
            ``config.data``, ``config.training.iter_size`` and
            ``config.device``.
        save_fname: stem of the output file name.

    Raises:
        AssertionError: if ``config.eval.ckpt_path`` does not exist.
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        os.makedirs(eval_dir, exist_ok=True)
        # Identity inverse scaler: samples are saved without rescaling.
        inverse_scaler = lambda x: x

        # The optimizer and EMA are only constructed so the checkpoint's
        # state dict can be restored into them.
        score_model = mutils.create_model(config)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales,
        )
        resolution = config.data.image_size
        grid_shape = (1, 1, resolution, resolution, resolution)
        # Placed on config.device (previously hard-coded "cuda").
        grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(*grid_shape).to(config.device)
        sampling_eps = 1e-3
        sampling_shape = (
            config.eval.batch_size,
            config.data.num_channels,
            resolution, resolution, resolution,
        )
        sampling_fn = sampling.get_sampling_fn(
            config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask
        )

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        # Sample with the EMA weights rather than the raw trained weights.
        ema.copy_to(score_model.parameters())
        print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")
        save_file_path = os.path.join(eval_dir, f"{save_fname}.npy")

        # Original note: "Conditional but free gradients; start from small t".
        # The partial observation is loaded on CPU because it is scattered
        # into CPU grids below.
        partial_dict = torch.load(config.eval.partial_dmtet_path, map_location="cpu")
        partial_sdf = partial_dict['sdf']
        partial_mask = partial_dict['vis']

        # Map tet-vertex indices to integer cubic-grid coordinates.
        # Expects an .npz archive with a 'vertices' entry; the `with` block
        # closes the file handle that np.load keeps open for archives.
        with np.load(config.eval.tet_path) as tet:
            vertices = torch.tensor(tet['vertices'])
        # dx = gap between the two smallest distinct coordinate values over
        # all axes. This assumes the vertices lie on a uniformly spaced
        # lattice (TODO confirm for non-default tet files).
        vertices_unique = vertices.unique()
        dx = vertices_unique[1] - vertices_unique[0]
        ind_to_coord = torch.round((vertices - vertices.min()) / dx).long()
        gx, gy, gz = ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]

        partial_sdf_grid = torch.zeros(grid_shape)
        partial_sdf_grid[0, 0, gx, gy, gz] = partial_sdf
        partial_mask_grid = torch.zeros(grid_shape)
        partial_mask_grid[0, 0, gx, gy, gz] = partial_mask.float()

        # The sampler returns a pair; only the samples are used here.
        samples, _ = sampling_fn(
            score_model,
            partial=partial_sdf_grid.to(config.device),
            partial_mask=partial_mask_grid.to(config.device),
            freeze_iters=config.eval.freeze_iters,
        )
        np.save(save_file_path, samples.cpu().numpy())