kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
Porównaj commity
6 Commity
45976e53db
...
fef7edf261
Autor | SHA1 | Data |
---|---|---|
Zhen Liu | fef7edf261 | |
Zhen Liu | 9512f9fe67 | |
Zhen Liu | d571b74081 | |
Zhen Liu | 57039104ef | |
Zhen Liu | b9e13fe51f | |
Zhen Liu | 172af55945 |
20
README.md
20
README.md
|
@ -143,6 +143,26 @@ python main_diffusion.py --mode=train --config=$DIFFUSION_CONFIG \
|
|||
|
||||
where `$TRAIN_SPLIT_FILE` is a json list of indices to be included in the training set. Examples in `metadata/train_split/`. For the diffusion model config file, please refer to `configs/res64.py` or `configs/res128.py`.
|
||||
|
||||
#### Training with our dataset
|
||||
|
||||
The provided datasets are stored in `.npy` instead of `.pt`. Run the following instead
|
||||
|
||||
```
|
||||
cd ../metadata/
|
||||
python save_meta.py --data_path $DMTET_NPY_FOLDER --json_path $META_FILE --extension npy
|
||||
```
|
||||
|
||||
Train a diffusion model
|
||||
|
||||
```
|
||||
cd ..
|
||||
|
||||
python main_diffusion.py --mode=train --config=$DIFFUSION_CONFIG \
|
||||
--config.data.meta_path=$META_FILE \
|
||||
--config.data.filter_meta_path=$TRAIN_SPLIT_FILE \
|
||||
--config.data.extension=npy
|
||||
```
|
||||
|
||||
## Texture Generation
|
||||
|
||||
Follow the instructions in https://github.com/TEXTurePaper/TEXTurePaper and create text-conditioned textures for the generated meshes.
|
||||
|
|
|
@ -56,6 +56,7 @@ def get_default_configs():
|
|||
data.normalize_sdf = True
|
||||
data.meta_path = "PLACEHOLDER" ### metadata for all dataset files
|
||||
data.filter_meta_path = "PLACEHOLDER" ### metadata for the list of training samples
|
||||
data.extension = 'pt' ### either 'pt' or 'npy', depending how the data are stored
|
||||
|
||||
# model
|
||||
config.model = model = ml_collections.ConfigDict()
|
||||
|
@ -85,4 +86,4 @@ def get_default_configs():
|
|||
# rendering
|
||||
config.render = render = ml_collections.ConfigDict()
|
||||
|
||||
return config
|
||||
return config
|
||||
|
|
|
@ -7,7 +7,7 @@ import json
|
|||
import argparse
|
||||
|
||||
class ShapeNetDMTetDataset(Dataset):
|
||||
def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True):
|
||||
def __init__(self, root, grid_mask, deform_scale=1.0, aug=False, filter_meta_path=None, normalize_sdf=True, extension='pt'):
|
||||
super().__init__()
|
||||
self.fpath_list = json.load(open(root, 'r'))
|
||||
self.deform_scale = deform_scale
|
||||
|
@ -17,10 +17,12 @@ class ShapeNetDMTetDataset(Dataset):
|
|||
self.aug = aug
|
||||
self.grid_mask = grid_mask.cpu()
|
||||
self.resolution = self.grid_mask.size(-1)
|
||||
self.extension = extension
|
||||
assert self.extension in ['pt', 'npy']
|
||||
|
||||
if filter_meta_path is not None:
|
||||
self.filter_ids = json.load(open(filter_meta_path, 'r'))
|
||||
full_id_list = [int(x.rstrip().split('_')[-1][:-3]) for i, x in enumerate(self.fpath_list)]
|
||||
full_id_list = [int(x.rstrip().split('_')[-1][:-len(self.extension)-1]) for i, x in enumerate(self.fpath_list)]
|
||||
fpath_idx_list = [i for i, x in enumerate(full_id_list) if x in self.filter_ids]
|
||||
self.fpath_list = [self.fpath_list[i] for i in fpath_idx_list]
|
||||
|
||||
|
@ -29,7 +31,10 @@ class ShapeNetDMTetDataset(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
with torch.no_grad():
|
||||
datum = torch.load(self.fpath_list[idx], map_location='cpu')
|
||||
if self.extension == 'pt':
|
||||
datum = torch.load(self.fpath_list[idx], map_location='cpu')
|
||||
else:
|
||||
datum = torch.tensor(np.load(self.fpath_list[idx]))
|
||||
if self.normalize_sdf:
|
||||
sdf_sign = torch.sign(datum[:, :1])
|
||||
sdf_sign[sdf_sign == 0] = 1.0
|
||||
|
|
|
@ -66,7 +66,7 @@ def train(config):
|
|||
|
||||
print("sdf normalized or not: ", config.data.normalize_sdf)
|
||||
train_dataset = ShapeNetDMTetDataset(json_path, deform_scale=config.model.deform_scale, aug=True, grid_mask=mask,
|
||||
filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf)
|
||||
filter_meta_path=config.data.filter_meta_path, normalize_sdf=config.data.normalize_sdf, extension=config.data.extension)
|
||||
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.training.batch_size,
|
||||
|
|
|
@ -6,9 +6,10 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_path', type=str)
|
||||
parser.add_argument('--json_path', type=str)
|
||||
parser.add_argument('--extension', type=str, default='pt')
|
||||
args = parser.parse_args()
|
||||
|
||||
fpath_list = sorted([os.path.join(args.data_path, fname) for fname in os.listdir(args.data_path) if fname.endswith('.pt')])
|
||||
fpath_list = sorted([os.path.join(args.data_path, fname) for fname in os.listdir(args.data_path) if fname.endswith('.' + args.extension)])
|
||||
os.makedirs(args.json_path, exist_ok=True)
|
||||
json.dump(fpath_list, open(args.json_path, 'w'))
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue