Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ControlNetModel,
StableDiffusionXLControlNetPipeline,
AutoPipelineForImage2Image,
FluxPipeline,
)
from diffusers.utils import load_image
from fastapi import FastAPI
Expand Down Expand Up @@ -183,6 +184,20 @@
# token merging
# tomesd.apply_patch(pipe, ratio=0.2) # light speedup

# Flux text-to-image pipeline for experimental usage
try:
flux_pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.float16,
)
flux_pipe.enable_model_cpu_offload()
flux_pipe.enable_attention_slicing()
flux_pipe.enable_vae_slicing()
flux_pipe.watermark = None
except Exception as e:
logger.error(f"Failed to load Flux model: {e}")
flux_pipe = None


refiner = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-refiner-1.0",
Expand Down Expand Up @@ -463,6 +478,18 @@ async def create_and_upload_image(
return JSONResponse({"path": path})


@app.get("/flux_create_and_upload_image")
async def flux_create_and_upload_image(prompt: str, save_path: str = ""):
"""Generate an image using the Flux pipeline and upload it."""
path_components = save_path.split("/")[0:-1]
final_name = save_path.split("/")[-1]
if not path_components:
path_components = []
save_path = "/".join(path_components) + quote_plus(final_name)
path = get_flux_image_or_upload(prompt, save_path)
return JSONResponse({"path": path})


@app.get("/inpaint_and_upload_image")
async def inpaint_and_upload_image(
prompt: str, image_url: str, mask_url: str, save_path: str = ""
Expand Down Expand Up @@ -577,6 +604,22 @@ def get_image_or_style_transfer_upload_to_cloud_storage(
return link


def get_flux_image_or_upload(prompt: str, save_path: str):
"""Return cached Flux image URL or generate and upload a new one."""
prompt = shorten_too_long_text(prompt)
save_path = shorten_too_long_text(save_path)
if check_if_blob_exists(save_path):
return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
if flux_pipe is None:
raise RuntimeError("Flux pipeline unavailable")
with torch.inference_mode():
bio = create_flux_image_from_prompt(prompt)
if bio is None:
return None
link = upload_to_bucket(save_path, bio, is_bytesio=True)
return link


def get_image_or_create_upload_to_cloud_storage(
prompt: str, width: int, height: int, save_path: str
):
Expand Down Expand Up @@ -1062,6 +1105,19 @@ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
return image_to_bytes(image)


def create_flux_image_from_prompt(prompt: str, n_steps: int = 4) -> bytes:
"""Generate an image with the Flux pipeline and return it as bytes."""
if flux_pipe is None:
raise RuntimeError("Flux pipeline unavailable")
prompt = shorten_too_long_text(prompt)
image = flux_pipe(
prompt=prompt,
num_inference_steps=n_steps,
guidance_scale=0.0,
).images[0]
return image_to_bytes(image)


def shorten_too_long_text(prompt):
if len(prompt) > 200:
# remove stopwords
Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ certifi
charset-normalizer
click
cmake
diffusers
diffusers==0.33.1
exceptiongroup
fastapi
filelock
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ cycler==0.12.1
# via matplotlib
deepcache==0.1.1
# via -r requirements.in
diffusers==0.31.0
diffusers==0.33.1
# via
# -r requirements.in
# deepcache
Expand Down
1 change: 1 addition & 0 deletions stable_diffusion_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

34 changes: 34 additions & 0 deletions tests/test_flux_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import sys
import importlib
from unittest.mock import patch, MagicMock

with patch.dict(sys.modules, {
'cv2': MagicMock(__spec__=MagicMock()),
'google': MagicMock(),
'google.cloud': MagicMock(),
}):
with patch('diffusers.DiffusionPipeline.from_pretrained', return_value=MagicMock()) as _:
main = importlib.import_module('main')

@patch('main.upload_to_bucket')
@patch('main.check_if_blob_exists')
def test_get_flux_image_or_upload_existing(mock_exists, mock_upload):
mock_exists.return_value = True
result = main.get_flux_image_or_upload('test prompt', 'img.png')
assert result == f"https://{main.BUCKET_NAME}/{main.BUCKET_PATH}/img.png"
mock_upload.assert_not_called()

@patch('main.upload_to_bucket')
@patch('main.check_if_blob_exists')
def test_get_flux_image_or_upload_new(mock_exists, mock_upload):
mock_exists.return_value = False
mock_upload.return_value = 'url'
fake_image = MagicMock()
fake_image.images = [MagicMock()]
with patch.object(main, 'flux_pipe') as mock_pipe:
mock_pipe.return_value = fake_image
with patch('main.image_to_bytes', return_value=b'x'):
result = main.get_flux_image_or_upload('test', 'new.png')
assert result == 'url'
mock_upload.assert_called_once()