MeshDiffusion/lib/diffusion/utils.py

31 lines
950 B
Python

import torch
import tensorflow as tf
import os
import logging
def restore_checkpoint(ckpt_dir, state, device, strict=False):
    """Load a training checkpoint into ``state`` in place and return it.

    Args:
      ckpt_dir: Path of the checkpoint file. Despite the name, this is a file
        path: it is passed directly to ``torch.load``.
      state: Dict whose ``'optimizer'``, ``'model'`` and ``'ema'`` entries
        expose ``load_state_dict``, plus a ``'step'`` entry. It is updated in
        place.
      device: Forwarded to ``torch.load`` as ``map_location``.
      strict: If True, a missing checkpoint raises instead of returning
        ``state`` unchanged. This flag does NOT affect how the model weights
        are loaded; they are always loaded with ``strict=False``.

    Returns:
      The same ``state`` object, updated from the checkpoint if one exists.

    Raises:
      FileNotFoundError: If ``strict`` is True and nothing exists at
        ``ckpt_dir``.
    """
    if not tf.io.gfile.exists(ckpt_dir):
        if strict:
            # A bare ``raise`` here (no active exception) would only produce
            # "RuntimeError: No active exception to reraise" and hide the cause.
            raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}.")
        # Create the parent directory (presumably so a later save to this
        # path succeeds). dirname() is '' for a bare filename; skip it then.
        parent_dir = os.path.dirname(ckpt_dir)
        if parent_dir:
            tf.io.gfile.makedirs(parent_dir)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    state['optimizer'].load_state_dict(loaded_state['optimizer'])
    # Non-strict: missing/unexpected keys in the saved model weights are
    # tolerated rather than raising.
    state['model'].load_state_dict(loaded_state['model'], strict=False)
    state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state
def save_checkpoint(ckpt_dir, state):
    """Serialize a training state to ``ckpt_dir`` with ``torch.save``.

    Args:
      ckpt_dir: Destination file path. Its parent directory is not created
        here.
      state: Dict whose ``'optimizer'``, ``'model'`` and ``'ema'`` entries
        expose ``state_dict()``, plus a ``'step'`` entry that is stored as-is.
    """
    # Collect the state dicts of the stateful components, in a fixed order.
    checkpoint = {
        component: state[component].state_dict()
        for component in ('optimizer', 'model', 'ema')
    }
    checkpoint['step'] = state['step']
    torch.save(checkpoint, ckpt_dir)