From 57039104ef25c6635482bca9cd880ea66e3ebe91 Mon Sep 17 00:00:00 2001 From: Zhen Liu Date: Wed, 2 Aug 2023 02:19:16 +0200 Subject: [PATCH] Update shapenet_dmtet_dataset.py --- lib/dataset/shapenet_dmtet_dataset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/dataset/shapenet_dmtet_dataset.py b/lib/dataset/shapenet_dmtet_dataset.py index 5782cd4..f9cf7f4 100644 --- a/lib/dataset/shapenet_dmtet_dataset.py +++ b/lib/dataset/shapenet_dmtet_dataset.py @@ -7,7 +7,7 @@ import json import argparse class ShapeNetDMTetDataset(Dataset): - def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True): + def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True, extension='pt'): super().__init__() self.fpath_list = json.load(open(root, 'r')) self.deform_scale = deform_scale @@ -17,10 +17,12 @@ class ShapeNetDMTetDataset(Dataset): self.aug = aug self.grid_mask = grid_mask.cpu() self.resolution = self.grid_mask.size(-1) + self.extension = extension + assert self.extension in ['pt', 'npy'] if filter_meta_path is not None: self.filter_ids = json.load(open(filter_meta_path, 'r')) - full_id_list = [int(x.rstrip().split('_')[-1][:-3]) for i, x in enumerate(self.fpath_list)] + full_id_list = [int(x.rstrip().split('_')[-1][:-len(self.extension)-1]) for i, x in enumerate(self.fpath_list)] fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids] self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list] @@ -29,7 +31,10 @@ class ShapeNetDMTetDataset(Dataset): def __getitem__(self, idx): with torch.no_grad(): - datum = torch.load(self.fpath_list[idx], map_location='cpu') + if self.extension == 'pt': + datum = torch.load(self.fpath_list[idx], map_location='cpu') + else: + datum = torch.tensor(np.load(self.fpath_list[idx])) if self.normalize_sdf: sdf_sign = torch.sign(datum[:, :1]) sdf_sign[sdf_sign == 0] = 1.0