Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions env.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
import os

BUCKET_NAME = 'static.netwrck.com'
BUCKET_PATH = 'static/uploads'

# Default Flux pipeline model name. Can be overridden with the
# ``FLUX_MODEL_NAME`` environment variable.
FLUX_MODEL_NAME = os.environ.get(
"FLUX_MODEL_NAME", "black-forest-labs/FLUX.1-schnell"
)

# Toggle loading of the Flux pipeline via the ``ENABLE_FLUX_PIPELINE``
# environment variable. Any value other than "0" enables it.
ENABLE_FLUX_PIPELINE = os.environ.get("ENABLE_FLUX_PIPELINE", "1") != "0"
84 changes: 83 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ControlNetModel,
StableDiffusionXLControlNetPipeline,
AutoPipelineForImage2Image,
FluxPipeline,
)
from diffusers.utils import load_image
from fastapi import FastAPI
Expand All @@ -39,8 +40,14 @@
from starlette.responses import FileResponse
from starlette.responses import JSONResponse
from transformers import set_seed
from cachetools import TTLCache

from env import BUCKET_PATH, BUCKET_NAME
from env import (
BUCKET_PATH,
BUCKET_NAME,
FLUX_MODEL_NAME,
ENABLE_FLUX_PIPELINE,
)
from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
from stable_diffusion_server.bumpy_detection import detect_too_bumpy
from stable_diffusion_server.image_processing import process_image_for_stable_diffusion
Expand Down Expand Up @@ -183,6 +190,27 @@
# token merging
# tomesd.apply_patch(pipe, ratio=0.2) # light speedup

# Flux text-to-image pipeline for experimental usage
flux_pipe = None
if ENABLE_FLUX_PIPELINE:
try:
flux_pipe = FluxPipeline.from_pretrained(
FLUX_MODEL_NAME,
torch_dtype=torch.float16,
)
flux_pipe.enable_model_cpu_offload()
flux_pipe.enable_attention_slicing()
flux_pipe.enable_vae_slicing()
flux_pipe.watermark = None
except Exception as e:
logger.error(f"Failed to load Flux model: {e}")
flux_pipe = None
else:
logger.info("Flux pipeline disabled via environment variable")

# Cache to avoid repeatedly generating identical Flux images.
_flux_cache = TTLCache(maxsize=64, ttl=300)


refiner = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-refiner-1.0",
Expand Down Expand Up @@ -463,6 +491,26 @@ async def create_and_upload_image(
return JSONResponse({"path": path})


@app.get("/flux_create_and_upload_image")
async def flux_create_and_upload_image(prompt: str, save_path: str = ""):
"""Generate an image using the Flux pipeline and upload it."""
path_components = save_path.split("/")[0:-1]
final_name = save_path.split("/")[-1]
if not path_components:
path_components = []
save_path = "/".join(path_components) + quote_plus(final_name)
path = get_flux_image_or_upload(prompt, save_path)
return JSONResponse({"path": path})


@app.get("/flux_health")
async def flux_health():
"""Simple health check for the Flux pipeline."""
if flux_pipe is None:
return JSONResponse({"status": "unavailable"}, status_code=503)
return JSONResponse({"status": "ok"})


@app.get("/inpaint_and_upload_image")
async def inpaint_and_upload_image(
prompt: str, image_url: str, mask_url: str, save_path: str = ""
Expand Down Expand Up @@ -577,6 +625,22 @@ def get_image_or_style_transfer_upload_to_cloud_storage(
return link


def get_flux_image_or_upload(prompt: str, save_path: str):
"""Return cached Flux image URL or generate and upload a new one."""
prompt = shorten_too_long_text(prompt)
save_path = shorten_too_long_text(save_path)
if check_if_blob_exists(save_path):
return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
if flux_pipe is None:
raise RuntimeError("Flux pipeline unavailable")
with torch.inference_mode():
bio = create_flux_image_from_prompt(prompt)
if bio is None:
return None
link = upload_to_bucket(save_path, bio, is_bytesio=True)
return link


def get_image_or_create_upload_to_cloud_storage(
prompt: str, width: int, height: int, save_path: str
):
Expand Down Expand Up @@ -1062,6 +1126,24 @@ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
return image_to_bytes(image)


def create_flux_image_from_prompt(prompt: str, n_steps: int = 4) -> bytes:
"""Generate an image with the Flux pipeline and return it as bytes."""
if flux_pipe is None:
raise RuntimeError("Flux pipeline unavailable")
prompt = shorten_too_long_text(prompt)
key = (prompt, n_steps)
if key in _flux_cache:
return _flux_cache[key]
image = flux_pipe(
prompt=prompt,
num_inference_steps=n_steps,
guidance_scale=0.0,
).images[0]
result = image_to_bytes(image)
_flux_cache[key] = result
return result


def shorten_too_long_text(prompt):
if len(prompt) > 200:
# remove stopwords
Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ certifi
charset-normalizer
click
cmake
diffusers
diffusers==0.33.1
exceptiongroup
fastapi
filelock
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ cycler==0.12.1
# via matplotlib
deepcache==0.1.1
# via -r requirements.in
diffusers==0.31.0
diffusers==0.33.1
# via
# -r requirements.in
# deepcache
Expand Down
1 change: 1 addition & 0 deletions stable_diffusion_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

79 changes: 79 additions & 0 deletions tests/test_flux_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
import sys
import importlib
from unittest.mock import patch, MagicMock

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

with patch.dict(sys.modules, {
'cv2': MagicMock(__spec__=MagicMock()),
'google': MagicMock(),
'google.cloud': MagicMock(),
'optimum': MagicMock(__spec__=MagicMock()),
'optimum.quanto': MagicMock(),
'nltk': MagicMock(
__spec__=MagicMock(),
corpus=MagicMock(stopwords=MagicMock(words=lambda lang: [])),
),
}):
with patch('diffusers.DiffusionPipeline.from_pretrained', return_value=MagicMock()) as _:
with patch('diffusers.FluxPipeline.from_pretrained', return_value=MagicMock()) as _:
with patch('diffusers.schedulers.scheduling_lcm.LCMScheduler.from_config', return_value=MagicMock()) as _:
with patch('diffusers.AutoPipelineForImage2Image.from_pipe', return_value=MagicMock()) as _:
main = importlib.import_module('main')

@patch('main.upload_to_bucket')
@patch('main.check_if_blob_exists')
def test_get_flux_image_or_upload_existing(mock_exists, mock_upload):
mock_exists.return_value = True
result = main.get_flux_image_or_upload('test prompt', 'img.png')
assert result == f"https://{main.BUCKET_NAME}/{main.BUCKET_PATH}/img.png"
mock_upload.assert_not_called()

@patch('main.upload_to_bucket')
@patch('main.check_if_blob_exists')
def test_get_flux_image_or_upload_new(mock_exists, mock_upload):
mock_exists.return_value = False
mock_upload.return_value = 'url'
fake_image = MagicMock()
fake_image.images = [MagicMock()]
with patch.object(main, 'flux_pipe') as mock_pipe:
mock_pipe.return_value = fake_image
with patch('main.image_to_bytes', return_value=b'x'):
result = main.get_flux_image_or_upload('test', 'new.png')
assert result == 'url'
mock_upload.assert_called_once()


def test_flux_pipeline_disabled():
with patch.dict(os.environ, {"ENABLE_FLUX_PIPELINE": "0"}):
import importlib
with patch.dict(sys.modules, {
'cv2': MagicMock(__spec__=MagicMock()),
'google': MagicMock(),
'google.cloud': MagicMock(),
}):
with patch('diffusers.DiffusionPipeline.from_pretrained', return_value=MagicMock()):
with patch('diffusers.FluxPipeline.from_pretrained', return_value=MagicMock()):
reloaded = importlib.reload(main)
assert reloaded.flux_pipe is None


def test_create_flux_image_cache():
fake_img = MagicMock()
fake_img.images = [MagicMock()]
with patch.object(main, 'flux_pipe', return_value=fake_img) as mock_pipe:
with patch('main.image_to_bytes', return_value=b'x'):
main._flux_cache.clear()
res1 = main.create_flux_image_from_prompt('abc')
res2 = main.create_flux_image_from_prompt('abc')
assert res1 == res2
assert mock_pipe.call_count == 1


def test_flux_health_endpoint():
from fastapi.testclient import TestClient
client = TestClient(main.app)
with patch.object(main, 'flux_pipe', MagicMock()):
resp = client.get('/flux_health')
assert resp.status_code == 200
Loading