Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add size option to choose size with pixels #31

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,24 @@
"9:21": (640, 1536),
}

# 1 megapixel sizes
SIZES = {f"{x}x{y}": (x, y) for x, y in ASPECT_RATIOS.values()}
# 0.25 megapixel sizes
SIZES.update({f"{x / 2}x{y / 2}": (x / 2, y / 2) for x, y in ASPECT_RATIOS.values()})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor thing - this should be integer division {x // 2}. if it's {x / 2} then we'll have a trailing .0 on all the strings, e.g. 512.0 instead of 512.



@dataclass
class SharedInputs:
prompt: Input = Input(description="Prompt for generated image")
size: Input = Input(
description="Size of the generated image",
choices=list(SIZES.keys()),
default="1024x1024",
)
aspect_ratio: Input = Input(
description="Aspect ratio for the generated image",
choices=list(ASPECT_RATIOS.keys()),
default="1:1",
default=None,
)
num_outputs: Input = Input(
description="Number of outputs to generate", default=1, le=4, ge=1
Expand Down Expand Up @@ -99,7 +109,7 @@ class SharedInputs:
megapixels: Input = Input(
description="Approximate number of megapixels for generated image",
choices=["1", "0.25"],
default="1",
default=None,
)


Expand Down Expand Up @@ -246,11 +256,25 @@ def predict():
raise Exception("You need to instantiate a predictor for a specific flux model")

def preprocess(
self, aspect_ratio: str, seed: Optional[int], megapixels: str
self,
size: str,
aspect_ratio: str | None,
seed: Optional[int],
megapixels: str | None,
) -> Dict:
width, height = ASPECT_RATIOS.get(aspect_ratio)
if megapixels == "0.25":
width, height = width // 2, height // 2
width, height = SIZES.get(size)

# Backwards compatibility for deprecated aspect_ratio and megapixels inputs
if aspect_ratio is not None or megapixels is not None:
# set defaults
if aspect_ratio is None:
aspect_ratio = "1024x1024"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is off, aspect_ratio should default to 1:1 not 1024x1024

if megapixels is None:
megapixels = "1"

width, height = ASPECT_RATIOS.get(aspect_ratio)
if megapixels == "0.25":
width, height = width // 2, height // 2

if not seed:
seed = int.from_bytes(os.urandom(2), "big")
Expand Down Expand Up @@ -468,6 +492,7 @@ def setup(self) -> None:
def predict(
self,
prompt: str = SHARED_INPUTS.prompt,
size: str = SHARED_INPUTS.size,
aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
num_outputs: int = SHARED_INPUTS.num_outputs,
num_inference_steps: int = Input(
Expand All @@ -483,7 +508,7 @@ def predict(
go_fast: bool = SHARED_INPUTS.go_fast,
megapixels: str = SHARED_INPUTS.megapixels,
) -> List[Path]:
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels)

if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
Expand Down Expand Up @@ -518,6 +543,7 @@ def setup(self) -> None:
def predict(
self,
prompt: str = SHARED_INPUTS.prompt,
size: str = SHARED_INPUTS.size,
aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
image: Path = Input(
description="Input image for image to image mode. The aspect ratio of your output will match this image",
Expand Down Expand Up @@ -549,7 +575,7 @@ def predict(
if image and go_fast:
print("img2img not supported with fp8 quantization; running with bf16")
go_fast = False
hws_kwargs = self.preprocess(aspect_ratio, seed, megapixels)
hws_kwargs = self.preprocess(size, aspect_ratio, seed, megapixels)

if go_fast and not self.disable_fp8:
imgs, np_imgs = self.fp8_predict(
Expand Down
Loading