diff --git a/lib/diffusion/sampling.py b/lib/diffusion/sampling.py index 6de18d6..93269c3 100644 --- a/lib/diffusion/sampling.py +++ b/lib/diffusion/sampling.py @@ -390,7 +390,7 @@ def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr, n_steps=n_steps) def pc_sampler(model, - partial=None, partial_grid_mask=None, partial_channel=0, + partial=None, partial_mask=None, partial_channel=0, freeze_iters=None): """ The PC sampler funciton. @@ -520,7 +520,7 @@ def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1, continuous=False) def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, - partial=None, partial_grid_mask=None, partial_channel=0): + partial=None, partial_mask=None, partial_channel=0): """ The PC sampler funciton. Args: