diff --git a/sample_clip_guided.py b/sample_clip_guided.py
index 54f2d18..5923501 100755
--- a/sample_clip_guided.py
+++ b/sample_clip_guided.py
@@ -49,8 +49,6 @@ def main():
                    help='the batch size')
     p.add_argument('--checkpoint', type=str, required=True,
                    help='the checkpoint to use')
-    p.add_argument('--churn', type=float, default=50.,
-                   help='the amount of noise to add during sampling')
     p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
                    help='the CLIP guidance scale')
     p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
@@ -115,7 +113,7 @@ def run():
         sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
         def sample_fn(n):
             x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0]
-            x_0 = K.sampling.sample_dpm_2(model_fn, x, sigmas, s_churn=args.churn, disable=not accelerator.is_local_main_process)
+            x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process)
             return x_0
         x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
         if accelerator.is_main_process:
diff --git a/train.py b/train.py
index eec6d78..dbfbeb9 100755
--- a/train.py
+++ b/train.py
@@ -242,7 +242,7 @@ def demo():
         n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
         x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
         sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
-        x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
+        x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=not accelerator.is_main_process)
         x_0 = accelerator.gather(x_0)[:args.sample_n]
         if accelerator.is_main_process:
             grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)
@@ -260,7 +260,7 @@ def evaluate():
         sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
         def sample_fn(n):
             x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
-            x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True)
+            x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=True)
             return x_0
         fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size)
        if accelerator.is_main_process:
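
Below is a minimal, self-contained sketch (not part of the patch) of how the DPM++ 2M sampler that the patch switches to is driven: build a Karras sigma schedule with K.sampling.get_sigmas_karras, start from noise scaled to the largest sigma, and call K.sampling.sample_dpmpp_2m on a denoiser. The noise range, tensor shapes, and the placeholder denoiser lambda here are illustrative assumptions, not values taken from this repository.

import torch
import k_diffusion as K

# Assumed noise range for illustration only; the repo reads these from the model config.
sigma_min, sigma_max = 1e-2, 80.

# Placeholder denoiser: a real run would pass the k-diffusion denoiser wrapper around
# the trained model. The samplers call it as denoiser(x, sigma) -> denoised image.
denoiser = lambda x, sigma: x * 0.

# Karras noise schedule with 50 steps, as in the patched demo/evaluate code.
sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device='cpu')

# Start from pure noise at the largest sigma, then run DPM++ 2M deterministically.
x = torch.randn([4, 3, 32, 32]) * sigma_max
x_0 = K.sampling.sample_dpmpp_2m(denoiser, x, sigmas, disable=True)
print(x_0.shape)  # torch.Size([4, 3, 32, 32])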