kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
Update shapenet_dmtet_dataset.py
rodzic
b9e13fe51f
commit
57039104ef
|
@ -7,7 +7,7 @@ import json
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
class ShapeNetDMTetDataset(Dataset):
|
class ShapeNetDMTetDataset(Dataset):
|
||||||
def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
|
def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True, extension='pt'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fpath_list = json.load(open(root, 'r'))
|
self.fpath_list = json.load(open(root, 'r'))
|
||||||
self.deform_scale = deform_scale
|
self.deform_scale = deform_scale
|
||||||
|
@ -17,10 +17,12 @@ class ShapeNetDMTetDataset(Dataset):
|
||||||
self.aug = aug
|
self.aug = aug
|
||||||
self.grid_mask = grid_mask.cpu()
|
self.grid_mask = grid_mask.cpu()
|
||||||
self.resolution = self.grid_mask.size(-1)
|
self.resolution = self.grid_mask.size(-1)
|
||||||
|
self.extension = extension
|
||||||
|
assert self.extension in ['pt', 'npy']
|
||||||
|
|
||||||
if filter_meta_path is not None:
|
if filter_meta_path is not None:
|
||||||
self.filter_ids = json.load(open(filter_meta_path, 'r'))
|
self.filter_ids = json.load(open(filter_meta_path, 'r'))
|
||||||
full_id_list = [int(x.rstrip().split('_')[-1][:-3]) for i, x in enumerate(self.fpath_list)]
|
full_id_list = [int(x.rstrip().split('_')[-1][:-len(self.extension)-1]) for i, x in enumerate(self.fpath_list)]
|
||||||
fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids]
|
fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids]
|
||||||
self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list]
|
self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list]
|
||||||
|
|
||||||
|
@ -29,7 +31,10 @@ class ShapeNetDMTetDataset(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
datum = torch.load(self.fpath_list[idx], map_location='cpu')
|
if self.extension == 'pt':
|
||||||
|
datum = torch.load(self.fpath_list[idx], map_location='cpu')
|
||||||
|
else:
|
||||||
|
datum = torch.tensor(np.load(self.fpath_list[idx]))
|
||||||
if self.normalize_sdf:
|
if self.normalize_sdf:
|
||||||
sdf_sign = torch.sign(datum[:, :1])
|
sdf_sign = torch.sign(datum[:, :1])
|
||||||
sdf_sign[sdf_sign == 0] = 1.0
|
sdf_sign[sdf_sign == 0] = 1.0
|
||||||
|
|
Ładowanie…
Reference in New Issue