# Mirror of https://github.com/lzzcd001/MeshDiffusion (Python, 50 lines, 2.1 KiB)
import os
|
|
import sys
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import json
|
|
|
|
import argparse
|
|
|
|
class ShapeNetDMTetDataset(Dataset):
    """Dataset that loads one saved tensor per item and fits it to a grid mask.

    Each item is read with ``torch.load`` from a path listed in a JSON file,
    optionally sign-normalized and jittered, multiplied by ``grid_mask`` and
    zero-padded so that its trailing three dimensions equal the mask
    resolution.

    Args:
        root: Path to a JSON file containing the list of tensor file paths.
        grid_mask: Tensor indexed here as ``grid_mask[0, :, :s, :s, :s]``, so
            it must have at least five dimensions; its last dimension defines
            the target resolution. It is moved to the CPU.
        deform_scale: Scale stored in ``self.coeff`` for the last three of its
            five entries. ``self.coeff`` is not used inside this class.
        aug: If true, add a small random per-channel offset to ``datum[1:]``
            on every access.
        filter_meta_path: Optional path to a JSON file holding the integer IDs
            to keep. The ID of a path is the integer found between its last
            ``_`` and its final three characters (presumably a ``.pt``
            extension -- confirm against the data layout).
        normalize_sdf: If true, replace ``datum[:, :1]`` by its sign (zeros
            become ``+1``).
    """

    def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
        super().__init__()
        # Use a context manager so the file handle is closed deterministically.
        with open(root, 'r') as list_file:
            self.fpath_list = json.load(list_file)
        self.deform_scale = deform_scale
        self.normalize_sdf = normalize_sdf
        print(f"dataset with sdf normalized: {normalize_sdf}")
        self.coeff = torch.tensor(
            [1.0, 1.0, self.deform_scale, self.deform_scale, self.deform_scale]
        ).view(-1, 1, 1, 1)
        self.aug = aug
        self.grid_mask = grid_mask.cpu()
        self.resolution = self.grid_mask.size(-1)

        if filter_meta_path is not None:
            with open(filter_meta_path, 'r') as meta_file:
                self.filter_ids = json.load(meta_file)
            # Build the lookup set once: membership tests against the raw list
            # would cost O(len(filter_ids)) for every path.
            allowed_ids = set(self.filter_ids)
            self.fpath_list = [
                fpath
                for fpath in self.fpath_list
                if int(fpath.rstrip().split('_')[-1][:-3]) in allowed_ids
            ]

    def __len__(self):
        """Return the number of (possibly filtered) file paths."""
        return len(self.fpath_list)

    def __getitem__(self, idx):
        """Load, optionally normalize/augment, mask and pad the ``idx``-th tensor.

        Returns:
            The processed tensor; if its trailing size was smaller than
            ``self.resolution`` its last three dimensions are zero-padded at
            the end up to ``self.resolution``.
        """
        with torch.no_grad():
            # NOTE: torch.load unpickles the file; only use trusted data files.
            datum = torch.load(self.fpath_list[idx], map_location='cpu')

            if self.normalize_sdf:
                # NOTE(review): this indexes dim 1 (`[:, :1]`) while the
                # augmentation below indexes dim 0 (`[1:]`); the stored tensor
                # layout is not visible here -- confirm which axis holds the
                # SDF channel.
                sdf_sign = torch.sign(datum[:, :1])
                sdf_sign[sdf_sign == 0] = 1.0  # treat exact zeros as positive
                datum[:, :1] = sdf_sign

            if self.aug:
                # Only perturb locations where datum[1:] is not all zero.
                nonempty_mask = (datum[1:].abs().sum(dim=0, keepdim=True) != 0)
                # One uniform offset in [-0.005, 0.005) per channel of datum[1:].
                jitter = (torch.rand(3)[:, None, None, None] - 0.5) * 0.01
                res_ratio = datum.size(-1) / self.resolution
                datum[1:] = datum[1:] + jitter * nonempty_mask / res_ratio

            size = datum.size(-1)
            if size < self.resolution:
                # Mask with the matching corner of the grid mask, then zero-pad
                # the end of the last three dimensions up to the resolution.
                datum = datum * self.grid_mask[0, :, :size, :size, :size]
                diff = self.resolution - size
                datum = torch.nn.functional.pad(datum, (0, diff, 0, diff, 0, diff, 0, 0))
            else:
                datum = datum * self.grid_mask[0]

            return datum
|