diff --git a/apply_factor.py b/apply_factor.py index 53cd0e04..8018f2a4 100755 --- a/apply_factor.py +++ b/apply_factor.py @@ -21,6 +21,13 @@ default=5, help="scalar factors for moving latent vectors along eigenvector", ) + parser.add_argument( + "-d_num", + "--degree_num", + type=int, + default=3, + help="number of scalar factors for moving latent vectors along eigenvector", + ) parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints") parser.add_argument( "--size", type=int, default=256, help="output image size of the generator" @@ -58,29 +65,23 @@ latent = torch.randn(args.n_sample, 512, device=args.device) latent = g.get_latent(latent) - direction = args.degree * eigvec[:, args.index].unsqueeze(0) + direction = eigvec[:, args.index].unsqueeze(0) - img, _ = g( - [latent], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) - img1, _ = g( - [latent + direction], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) - img2, _ = g( - [latent - direction], - truncation=args.truncation, - truncation_latent=trunc, - input_is_latent=True, - ) + img_list = [] + + for u in torch.linspace(- args.degree, args.degree, args.d_num): + + img_batch, _ = g( + [latent + u * direction], + truncation=args.truncation, + truncation_latent=trunc, + input_is_latent=True, + ) + + img_list.append(img_batch) grid = utils.save_image( - torch.cat([img1, img, img2], 0), + torch.cat(img_list, 0), f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", normalize=True, range=(-1, 1),