-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdemo.py
30 lines (23 loc) · 1005 Bytes
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from PIL import Image
import numpy as np
import torch
import os
from rembg import remove
import torch.nn as nn
from unique3d_diffusion import Unique3dDiffuser
def collect_image_paths(img_dir, img):
    """Return the list of image paths the demo should process.

    Args:
        img_dir: Directory whose entries are processed in sorted filename
            order, or ``None`` to process the single file ``img`` instead.
        img: Path of the single image, used only when ``img_dir`` is ``None``.

    Returns:
        A list of paths. Sub-directories and hidden entries (names starting
        with ``.``) of ``img_dir`` are skipped.

    Raises:
        OSError: if ``img_dir`` cannot be listed (e.g. it does not exist).
    """
    if img_dir is None:
        return [img]
    paths = []
    for name in sorted(os.listdir(img_dir)):
        # Skip things like .DS_Store / .gitkeep that are never input images.
        if name.startswith("."):
            continue
        path = os.path.join(img_dir, name)
        # os.listdir also returns sub-directories, which cannot be loaded.
        if os.path.isfile(path):
            paths.append(path)
    return paths


# Command-line entry point: run the Unique3dDiffuser on one image (--img) or
# on every file in a directory (--img_dir), saving results to --output_dir.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_dir", type=str, default="ckpt", help="path to the checkpoint")
    parser.add_argument("--img_dir", type=str, default=None, help="path to image dir")
    parser.add_argument("--img", type=str, default="data/image.png", help="path to the image")
    parser.add_argument("--seed", type=int, default=-1, help="random seed")
    parser.add_argument("--output_dir", type=str, default="output", help="directory where results are saved")
    args = parser.parse_args()

    # Resolve inputs before building the model so a bad --img_dir fails fast,
    # without first paying for checkpoint loading.
    image_paths = collect_image_paths(args.img_dir, args.img)

    model = Unique3dDiffuser(args.ckpt_dir, args.seed)
    for image_path in image_paths:
        image = model.load_image(image_path)
        model(image, save_dir=args.output_dir)