-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathtest-txt2img.py
156 lines (135 loc) · 5.21 KB
/
test-txt2img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2022 Dirk Moerenhout. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify it under the terms
# of the GNU General Public License as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with this program. If not,
# see <https://www.gnu.org/licenses/>.
# We need regular expressions support
import re
# We need argparse for handling command line arguments
import argparse
# We need os.path for isdir
import os.path
# Numpy is used to provide a random generator
import numpy
# Needed to set session options
import onnxruntime as ort
from diffusers import OnnxStableDiffusionPipeline, OnnxRuntimeModel
def parse_args(argv=None):
    """Parse the command line options of the txt2img test script.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Directory in current location to load model from",
    )
    parser.add_argument(
        "--size",
        default=512,
        type=int,
        required=False,
        help="Width/Height of the picture, defaults to 512, use 768 when appropriate",
    )
    parser.add_argument(
        "--steps",
        default=30,
        type=int,
        required=False,
        help="Scheduler steps to use",
    )
    parser.add_argument(
        "--scale",
        default=7.5,
        type=float,
        required=False,
        help="Guidance scale (how strict it sticks to the prompt)",
    )
    parser.add_argument(
        "--prompt",
        default="a dog on a lawn with the eifel tower in the background",
        type=str,
        required=False,
        help="Text prompt for generation",
    )
    parser.add_argument(
        "--negprompt",
        default="blurry, low quality",
        type=str,
        required=False,
        help="Negative text prompt for generation (what to avoid)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        required=False,
        help="Seed for generation, allows you to get the exact same image again",
    )
    parser.add_argument(
        "--fixeddims",
        action="store_true",
        help="Pass fixed dimensions to ONNX Runtime. Test purposes only, NOT VRAM FRIENDLY!",
    )
    parser.add_argument(
        "--cpu-textenc", "--cpuclip",
        action="store_true",
        help="Load Text Encoder on CPU to save VRAM",
    )
    parser.add_argument(
        "--cpuvae",
        action="store_true",
        help="Load VAE on CPU, this will always load the Text Encoder on CPU too",
    )
    return parser.parse_args(argv)


def model_basename(model_path: str) -> str:
    """Return the last path component of *model_path*.

    Both ``/`` and ``\\`` count as separators and a single trailing
    separator is ignored, so ``"models/sd15/"`` yields ``"sd15"``.
    """
    # Every part of the pattern is optional, so the search always succeeds;
    # the fallback only guards against a future pattern change.
    match = re.search(r"([^/\\]*)[/\\]?$", model_path)
    return match.group(1) if match else ""


def output_filename(model_name: str, size: int, seed=None) -> str:
    """Build the name of the PNG file the generated picture is saved to.

    The seed is only part of the name when one was given (``0`` counts).
    """
    name = f"testpicture-{model_name}_{size}"
    if seed is not None:
        name += f"_seed{seed}"
    return name + ".png"


def make_generator(seed=None):
    """Return a dedicated NumPy ``RandomState`` for latent generation.

    A private ``RandomState`` produces the same stream as seeding the global
    ``numpy.random`` module with the same seed, without touching global
    state. With ``seed=None`` NumPy picks fresh entropy.
    """
    return numpy.random.RandomState(seed)


def build_session_options(size: int, fixed_dims: bool = False):
    """Create the ONNX Runtime session options used for the pipeline.

    Args:
        size: Width/height of the picture in pixels.
        fixed_dims: When true, pin the UNet's free dimensions by name.

    Returns:
        An ``onnxruntime.SessionOptions`` with memory patterns disabled.
    """
    sess_options = ort.SessionOptions()
    # The pipeline is loaded with DmlExecutionProvider; the ONNX Runtime
    # DirectML documentation asks for memory pattern optimisation to be off.
    sess_options.enable_mem_pattern = False
    if fixed_dims:
        # Latents are 1/8 of the picture size (64 for the default 512), so
        # derive the value instead of hard-coding 64 to keep 768 working.
        latent_size = size // 8
        # NOTE(review): batch 2 presumably reflects the doubled batch used
        # for guidance with a negative prompt - confirm before changing.
        overrides = {
            "unet_sample_batch": 2,
            "unet_sample_channels": 4,
            "unet_sample_height": latent_size,
            "unet_sample_width": latent_size,
            "unet_timestep_batch": 1,
            "unet_ehs_batch": 2,
            "unet_ehs_sequence": 77,
        }
        for dim_name, dim_value in overrides.items():
            sess_options.add_free_dimension_override_by_name(dim_name, dim_value)
    return sess_options


def load_pipeline(model_dir: str, sess_options, text_encoder_on_cpu: bool = False,
                  vae_on_cpu: bool = False):
    """Load the Stable Diffusion ONNX pipeline from *model_dir*.

    Sub-models are loaded through ``DmlExecutionProvider`` unless moved to
    the CPU. ``sess_options`` is passed in every configuration so that the
    session settings (and ``--fixeddims``) are not silently dropped when
    CPU offloading is requested.

    Args:
        model_dir: Directory containing the converted ONNX model.
        sess_options: ``onnxruntime.SessionOptions`` for the pipeline.
        text_encoder_on_cpu: Load the text encoder separately, without an
            explicit provider, so it does not take up VRAM.
        vae_on_cpu: Do the same for the VAE decoder and skip the VAE
            encoder. Implies ``text_encoder_on_cpu``.

    Returns:
        The loaded ``OnnxStableDiffusionPipeline``.
    """
    cpu_modules = {}
    if text_encoder_on_cpu or vae_on_cpu:
        cpu_modules["text_encoder"] = OnnxRuntimeModel.from_pretrained(
            os.path.join(model_dir, "text_encoder"))
    if vae_on_cpu:
        cpu_modules["vae_decoder"] = OnnxRuntimeModel.from_pretrained(
            os.path.join(model_dir, "vae_decoder"))
        # txt2img never encodes an image, so the encoder is not loaded.
        cpu_modules["vae_encoder"] = None
    return OnnxStableDiffusionPipeline.from_pretrained(
        model_dir,
        provider="DmlExecutionProvider",
        sess_options=sess_options,
        **cpu_modules,
    )


def main(argv=None):
    """Generate one test picture with the model given on the command line."""
    args = parse_args(argv)

    # Guard clause: a converted model always has a "unet" sub-directory.
    if not os.path.isdir(os.path.join(args.model, "unet")):
        print("model not found")
        return

    image_name = output_filename(model_basename(args.model), args.size, args.seed)
    pipe = load_pipeline(
        args.model,
        build_session_options(args.size, args.fixeddims),
        text_encoder_on_cpu=args.cpu_textenc or args.cpuvae,
        vae_on_cpu=args.cpuvae,
    )
    # Keyword arguments: the pipeline takes height before width, so the
    # positional order is easy to get wrong.
    image = pipe(
        prompt=args.prompt,
        height=args.size,
        width=args.size,
        num_inference_steps=args.steps,
        guidance_scale=args.scale,
        negative_prompt=args.negprompt,
        generator=make_generator(args.seed),
    ).images[0]
    image.save(image_name)


if __name__ == "__main__":
    main()