kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
63 wiersze
1.5 KiB
Python
63 wiersze
1.5 KiB
Python
|
"""Config file for training a DDPM on ShapeNet (image size 128, 4 channels)."""
|
||
|
|
||
|
from configs.default_configs import get_default_configs
|
||
|
|
||
|
|
||
|
def get_config():
    """Return the experiment configuration.

    Starts from ``get_default_configs()`` and overrides fields in the
    ``training``, ``sampling``, ``data``, ``model`` and ``optim`` sections,
    then sets the evaluation batch size and the global random seed.

    Returns:
        The object produced by ``get_default_configs()`` with the overrides
        below applied in place.
    """
    config = get_default_configs()

    def _apply(section, overrides):
        # Assign each override as an attribute of ``section``; dict literals
        # keep insertion order, so fields are set in the order listed.
        for field_name, field_value in overrides.items():
            setattr(section, field_name, field_value)

    # training
    _apply(config.training, {
        'sde': 'vpsde',
        'continuous': False,
        'reduce_mean': True,
        'batch_size': 8,
        'lip_scale': None,
        # NOTE(review): presumably the number of gradient-accumulation steps;
        # the learning rate below is divided by it.
        'iter_size': 4,
        'snapshot_freq_for_preemption': 1000,
    })

    # sampling
    _apply(config.sampling, {
        'method': 'pc',
        'predictor': 'ancestral_sampling',
        'corrector': 'none',
    })

    # data
    _apply(config.data, {
        'dataset': 'ShapeNet',
        'centered': True,
        'image_size': 128,
        'num_channels': 4,
        # Both paths must be filled in by the user before training.
        'meta_path': "PLACEHOLDER",         # metadata for all dataset files
        'filter_meta_path': "PLACEHOLDER",  # metadata for the list of training samples
        'num_workers': 8,
        'aug': True,
    })

    # model
    _apply(config.model, {
        'name': 'ddpm_res128_v2',
        'scale_by_sigma': False,
        'num_scales': 1000,
        'ema_rate': 0.9999,
        'normalization': 'GroupNorm',
        'nonlinearity': 'swish',
        'nf': 128,
        'ch_mult': (1, 1, 2, 4, 4, 4),
        'num_res_blocks_first': 2,
        'num_res_blocks': 2,
        'attn_resolutions': (16,),
        'resamp_with_conv': True,
        'conditional': True,
        'dropout': 0.1,
    })

    # optim: base rate 7e-5, divided by iter_size, then doubled
    # (evaluated left to right, with iter_size read after it is set above).
    config.optim.lr = 7e-5 / config.training.iter_size * 2.0

    # evaluation and reproducibility
    config.eval.batch_size = 7
    config.seed = 42

    return config
|