Remove more code duplication
yorickvP committed Nov 29, 2024
1 parent e200df8 commit 2cd8d49
Showing 1 changed file with 59 additions and 73 deletions.
predict.py (+59 −73)
@@ -615,6 +615,10 @@ def shared_predict(
         prompt: str,
         num_outputs: int,
         num_inference_steps: int,
+        *,
+        disable_safety_checker: bool,
+        output_format: str,
+        output_quality: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
         image: Path = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
@@ -623,8 +627,13 @@ def shared_predict(
         height: int = 1024,
         mask: Path = None,  # inpainting
     ):
+        if image and go_fast:
+            print(
+                "img2img (or inpainting) not supported with fp8 quantization; running with bf16"
+            )
+            go_fast = False
         if go_fast and not self.disable_fp8:
-            return self.fp8_predict(
+            imgs, np_imgs = self.fp8_predict(
                 prompt=prompt,
                 num_outputs=num_outputs,
                 num_inference_steps=num_inference_steps,
@@ -635,19 +644,28 @@ def shared_predict(
                 width=width,
                 height=height,
             )
-        if self.disable_fp8:
-            print("running bf16 model, fp8 disabled")
-        return self.base_predict(
-            prompt=prompt,
-            num_outputs=num_outputs,
-            num_inference_steps=num_inference_steps,
-            guidance=guidance,
-            image=image,
-            prompt_strength=prompt_strength,
-            seed=seed,
-            width=width,
-            height=height,
-            mask=mask,
-        )
+        else:
+            if self.disable_fp8:
+                print("running bf16 model, fp8 disabled")
+            imgs, np_imgs = self.base_predict(
+                prompt=prompt,
+                num_outputs=num_outputs,
+                num_inference_steps=num_inference_steps,
+                guidance=guidance,
+                image=image,
+                prompt_strength=prompt_strength,
+                seed=seed,
+                width=width,
+                height=height,
+                mask=mask,
+            )
+
+        return self.postprocess(
+            imgs,
+            disable_safety_checker,
+            output_format,
+            output_quality,
+            np_images=np_imgs,
+        )


@@ -674,24 +692,19 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )


 class DevPredictor(Predictor):
     def setup(self) -> None:
@@ -728,15 +741,15 @@ def predict(
         go_fast: bool = SHARED_INPUTS.go_fast,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -745,14 +758,6 @@
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )


 class SchnellLoraPredictor(Predictor):
     def setup(self) -> None:
@@ -782,24 +787,19 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)

         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )


 class DevLoraPredictor(Predictor):
     def setup(self, t5=None, clip=None, ae=None) -> None:
@@ -839,18 +839,17 @@ def predict(
         lora_scale: float = SHARED_INPUTS.lora_scale,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
-
         self.handle_loras(go_fast, lora_weights, lora_scale)

         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -859,14 +858,6 @@
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )


 class HotswapPredictor(BasePredictor):
     def setup(self) -> None:
@@ -965,21 +956,24 @@ def predict(
         else:
             width, height = model.preprocess(aspect_ratio, megapixels=megapixels)

-        model.handle_loras(
-            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
-        )
-
         if image and go_fast:
             print(
                 "Img2img and inpainting not supported with fast fp8 inference; will run in bf16"
             )
             go_fast = False

-        imgs, np_imgs = model.shared_predict(
+        model.handle_loras(
+            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
+        )
+
+        return model.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance_scale,
             image=image,
             prompt_strength=prompt_strength,
@@ -989,14 +983,6 @@
             mask=mask,
         )
-
-        return model.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )


 class TestPredictor(Predictor):
     def setup(self) -> None:
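Taken together, the hunks above apply one pattern: the postprocessing step (safety check, encoding) that every subclass's predict() used to run after shared_predict() now lives inside shared_predict() itself, fed by new keyword-only arguments. A minimal sketch of the resulting shape, with simplified class bodies and argument lists and with stubs standing in for the real inference and encoding code in predict.py (this is illustrative, not the full implementation):

    class Predictor:
        def base_predict(self, **kwargs):
            # Stand-in for the real bf16 inference path.
            return ["img"], ["np_img"]

        def postprocess(self, imgs, disable_safety_checker, output_format,
                        output_quality, np_images=None):
            # Stand-in for safety checking and image encoding.
            return imgs

        def shared_predict(self, prompt, *, disable_safety_checker,
                           output_format, output_quality, **kwargs):
            imgs, np_imgs = self.base_predict(prompt=prompt, **kwargs)
            # Postprocessing now runs once here, not in every subclass.
            return self.postprocess(
                imgs,
                disable_safety_checker,
                output_format,
                output_quality,
                np_images=np_imgs,
            )

    class SchnellPredictor(Predictor):
        def predict(self, prompt, disable_safety_checker=False,
                    output_format="webp", output_quality=80):
            # Each subclass predict() collapses to a single return.
            return self.shared_predict(
                prompt,
                disable_safety_checker=disable_safety_checker,
                output_format=output_format,
                output_quality=output_quality,
            )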
