kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
31 wiersze
950 B
Python
31 wiersze
950 B
Python
import torch
|
|
import tensorflow as tf
|
|
import os
|
|
import logging
|
|
|
|
|
|
def restore_checkpoint(ckpt_dir, state, device, strict=False):
    """Load a checkpoint file into ``state`` in place and return it.

    Args:
        ckpt_dir: Path of the checkpoint *file* (despite the name). It is
            passed directly to ``torch.load``; existence is checked with
            ``tf.io.gfile`` so non-local paths supported by TensorFlow's
            gfile layer can be checked too.
        state: Dict holding ``'optimizer'``, ``'model'`` and ``'ema'``
            objects that expose ``load_state_dict``, plus a ``'step'``
            entry. It is mutated in place.
        device: Forwarded to ``torch.load`` as ``map_location``.
        strict: If True, a missing checkpoint raises instead of silently
            returning ``state`` unchanged. This flag does NOT control
            model weight loading, which is always non-strict.

    Returns:
        The same ``state`` dict, updated from the checkpoint if one was
        found.

    Raises:
        FileNotFoundError: If ``strict`` is True and no checkpoint exists
            at ``ckpt_dir``.
    """
    if not tf.io.gfile.exists(ckpt_dir):
        # Create the parent directory, presumably so that a later
        # save_checkpoint() to this path succeeds. A bare filename has no
        # parent component, so there is nothing to create in that case.
        parent_dir = os.path.dirname(ckpt_dir)
        if parent_dir:
            tf.io.gfile.makedirs(parent_dir)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        if strict:
            # A bare `raise` here has no active exception and would surface
            # as "RuntimeError: No active exception to reraise".
            raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}")
        return state

    loaded_state = torch.load(ckpt_dir, map_location=device)
    state['optimizer'].load_state_dict(loaded_state['optimizer'])
    # Model weights are loaded non-strictly: missing/unexpected keys are
    # tolerated (independently of the `strict` argument above).
    state['model'].load_state_dict(loaded_state['model'], strict=False)
    state['ema'].load_state_dict(loaded_state['ema'])
    state['step'] = loaded_state['step']
    return state
|
|
|
|
|
|
def save_checkpoint(ckpt_dir, state):
    """Write the training state to ``ckpt_dir`` using ``torch.save``.

    Args:
        ckpt_dir: Destination file path for the checkpoint (a file, not a
            directory, despite the name).
        state: Dict holding ``'optimizer'``, ``'model'`` and ``'ema'``
            objects that expose ``state_dict()``, plus a ``'step'`` value
            that is stored as-is.
    """
    # Objects that must be reduced to their state dicts before saving,
    # in the order they are serialized.
    stateful_keys = ('optimizer', 'model', 'ema')
    payload = {key: state[key].state_dict() for key in stateful_keys}
    payload['step'] = state['step']
    torch.save(payload, ckpt_dir)