# MeshDiffusion/lib/diffusion/trainer.py

import os
import sys
import logging

import numpy as np
import torch
from torch.utils import tensorboard

from . import losses
from . import sde_lib
# Keep the import below for registering all model definitions
from .models import ddpm_res64, ddpm_res128
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.shapenet_dmtet_dataset import ShapeNetDMTetDataset


def train(config):
    """Runs the training pipeline.

    Args:
        config: Configuration to use. The working directory for checkpoints and
            TensorBoard summaries is taken from `config.training.train_dir`; if it
            already contains an intermediate checkpoint, training is resumed from it.
    """
    workdir = config.training.train_dir

    # Create directories for experimental logs
    logging.info("working dir: {:s}".format(workdir))
    tb_dir = os.path.join(workdir, "tensorboard")
    writer = tensorboard.SummaryWriter(tb_dir)

    resolution = config.data.image_size

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)

    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])
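    # On a fresh run no checkpoint-meta file exists yet; restore_checkpoint is
    # expected to return the state unchanged in that case, so initial_step stays 0.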
    json_path = config.data.meta_path
    print("----- Loading grid mask -----")
    logging.info(f"{json_path}, {config.data.filter_meta_path}")

    ### Mask on the tet grid marking regions to ignore
    mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(
        1, 1, resolution, resolution, resolution).to("cuda")
    if hasattr(score_model.module, 'mask'):
        print("----- Assigning mask -----")
        score_model.module.mask.data[:] = mask[:]

    print(f"work dir: {workdir}")
    print("sdf normalized or not: ", config.data.normalize_sdf)
    train_dataset = ShapeNetDMTetDataset(
        json_path, deform_scale=config.model.deform_scale, aug=True, grid_mask=mask,
        filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
        pin_memory=True)

    data_iter = iter(train_loader)
    print("data loader set")
    # Setup SDEs
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

    # Build the one-step training function
    optimize_fn = losses.optimization_manager(config)
    train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                       mask=mask, loss_type=config.training.loss_type)

    num_train_steps = config.training.n_iters
    iter_size = config.training.iter_size

    logging.info("Starting training loop at step %d." % (initial_step // iter_size,))
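    # Gradient accumulation: each outer `step` is one optimizer update accumulated over
    # `iter_size` micro-batches (effective batch size = batch_size * iter_size). The step
    # function is expected to zero gradients when clear_grad is set and to apply the
    # optimizer update when update_param is set. The checkpointed state['step'] counts
    # individual micro-steps, hence the division by iter_size here and below.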
    for step in range(initial_step // iter_size, num_train_steps + 1):
        tmp_loss = 0.0
        for step_inner in range(iter_size):
            try:
                # batch, batch_mask = next(data_iter)
                batch = next(data_iter)
            except StopIteration:
                # StopIteration is thrown when the dataset ends; re-initialize the data loader
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = batch.cuda()

            # Execute one training step
            clear_grad_flag = (step_inner == 0)
            update_param_flag = (step_inner == iter_size - 1)
            loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag)
            loss = loss_dict['loss']
            tmp_loss += loss.item()

        tmp_loss /= iter_size
        if step % config.training.log_freq == 0:
            logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss))
            sys.stdout.flush()
            writer.add_scalar("training_loss", tmp_loss, step)
        # Periodically save a temporary checkpoint to resume training after pre-emption
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            logging.info(f"save meta at iter {step}")
            save_checkpoint(checkpoint_meta_dir, state)

        # Save a checkpoint periodically and at the end of training
        if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
            logging.info(f"save model: {step}-th")
            save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)
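

if __name__ == "__main__":
    # Hypothetical local driver, not part of the original module: MeshDiffusion is
    # normally launched from a separate entry-point script that builds `config`
    # (a score_sde-style config object). The import path and get_config() helper
    # below are assumptions for illustration only; run as a module (e.g.
    # `python -m lib.diffusion.trainer` from the repo root) so the relative
    # imports above resolve.
    from configs.res64 import get_config  # hypothetical config module

    cfg = get_config()
    cfg.training.train_dir = "./logs/debug_run"  # override for a quick local run
    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train(cfg)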