# MeshDiffusion/lib/dataset/shapenet_dmtet_dataset.py
#
# 50 lines
# 2.1 KiB
# Python

import os
import sys
import torch
from torch.utils.data import Dataset
import json
import argparse
class ShapeNetDMTetDataset(Dataset):
    """Dataset that serves grid tensors stored one-per-file on disk.

    The list of tensor files is read from a JSON file. Each item is loaded
    with ``torch.load`` onto the CPU, optionally sign-normalized, optionally
    augmented and masked, and finally zero-padded so that its last three
    dimensions match the resolution of ``grid_mask``.

    NOTE(review): this class was reconstructed from a copy whose indentation
    had been stripped; the nesting of the statements in ``__getitem__``
    should be confirmed against the canonical file.
    """

    def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
        """Read the file list and store the preprocessing options.

        Args:
            root: Path to a JSON file holding the list of tensor file paths.
            grid_mask: Mask tensor. It is moved to the CPU, the size of its
                last dimension defines ``self.resolution``, and it is indexed
                with five indices in ``__getitem__``.
            deform_scale: Scale stored on the instance and used for the last
                three entries of ``self.coeff``.
            aug: When true, ``__getitem__`` jitters and masks each item.
            filter_meta_path: Optional path to a JSON file with the numeric
                ids to keep; all other paths are dropped from the file list.
            normalize_sdf: When true, ``__getitem__`` replaces ``datum[:, :1]``
                by its sign (with zeros mapped to ``1.0``).
        """
        super().__init__()
        # Context managers release the file handles right after parsing
        # instead of leaving them open until garbage collection.
        with open(root, 'r') as list_file:
            self.fpath_list = json.load(list_file)
        self.deform_scale = deform_scale
        self.normalize_sdf = normalize_sdf
        print(f"dataset with sdf normalized: {normalize_sdf}")
        # Per-entry coefficients of shape (5, 1, 1, 1); not used by the
        # methods of this class, kept as a public attribute.
        self.coeff = torch.tensor([1.0, 1.0, self.deform_scale, self.deform_scale, self.deform_scale]).view(-1, 1, 1, 1)
        self.aug = aug
        self.grid_mask = grid_mask.cpu()
        self.resolution = self.grid_mask.size(-1)

        if filter_meta_path is not None:
            with open(filter_meta_path, 'r') as meta_file:
                self.filter_ids = json.load(meta_file)
            # Build the lookup set once so each membership test is O(1)
            # rather than a linear scan of the loaded id list.
            keep_ids = set(self.filter_ids)
            # The id of a path is the text after its last underscore with the
            # final three characters removed (presumably a ".pt" suffix).
            self.fpath_list = [
                fpath
                for fpath in self.fpath_list
                if int(fpath.rstrip().split('_')[-1][:-3]) in keep_ids
            ]

    def __len__(self):
        """Return the number of tensor files that will be served."""
        return len(self.fpath_list)

    def __getitem__(self, idx):
        """Load, preprocess and pad the tensor stored at position ``idx``.

        Args:
            idx: Index into the (possibly filtered) file list.

        Returns:
            The loaded tensor, zero-padded at the end of its last three
            dimensions up to ``self.resolution`` when it is smaller.
        """
        with torch.no_grad():
            datum = torch.load(self.fpath_list[idx], map_location='cpu')

            if self.normalize_sdf:
                # Keep only the sign; exact zeros are treated as positive.
                # NOTE(review): this slices dim 1, whereas the augmentation
                # below treats dim 0 as the channel axis -- confirm which
                # layout the stored tensors actually use.
                sdf_sign = torch.sign(datum[:, :1])
                sdf_sign[sdf_sign == 0] = 1.0
                datum[:, :1] = sdf_sign

            if self.aug:
                # Only jitter positions where datum[1:] is not all zero.
                nonempty_mask = (datum[1:].abs().sum(dim=0, keepdim=True) != 0)
                # One random offset in [-0.005, 0.005) per entry of datum[1:]
                # along dim 0, divided by the item-to-grid resolution ratio.
                datum[1:] = datum[1:] + (torch.rand(3)[:, None, None, None] - 0.5) * 0.01 * nonempty_mask / (datum.size(-1) / self.resolution)

                # NOTE(review): masking is assumed to belong to the
                # augmentation branch -- confirm against the original file.
                if datum.size(-1) < self.resolution:
                    # Crop the mask to the item's (smaller) spatial extent.
                    datum = datum * self.grid_mask[0, :, :datum.size(-1), :datum.size(-1), :datum.size(-1)]
                else:
                    datum = datum * self.grid_mask[0]

        if datum.size(-1) < self.resolution:
            diff = self.resolution - datum.size(-1)
            # Pad pairs run from the last dimension backwards: the last three
            # dimensions grow by ``diff`` at their end, the fourth-from-last
            # is left untouched.
            datum = torch.nn.functional.pad(datum, (0, diff, 0, diff, 0, diff, 0, 0))

        return datum