diff --git a/nvdiffrec/fit_singleview.py b/nvdiffrec/fit_singleview.py index 8f57a96..13327e8 100644 --- a/nvdiffrec/fit_singleview.py +++ b/nvdiffrec/fit_singleview.py @@ -732,20 +732,13 @@ if __name__ == "__main__": # Free temporaries / cached memory torch.cuda.empty_cache() ### may slow down training - torch.save({ - 'sdf': geometry.sdf.cpu().detach(), - 'sdf_ema': geometry.sdf_ema.cpu().detach(), - 'deform': (geometry.deform * vert_mask).cpu().detach(), - 'deform_unmasked': geometry.deform.cpu().detach(), - }, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index))) - old_geometry = geometry if FLAGS.local_rank == 0 and FLAGS.validate: validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz_pre/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS) else: - dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index))) + dmt_dict = torch.load(FLAGS.resume_path) if FLAGS.use_ema: geometry.sdf.data[:] = dmt_dict['sdf_ema'] else: @@ -831,7 +824,7 @@ if __name__ == "__main__": 'deform': geometry.deform.cpu().detach(), 'vis': visible_verts.cpu().detach(), 'vis_rast': visible_and_rast_verts.cpu().detach() - }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'.format(global_index))) + }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt')) # ==============================================================================================