diff --git a/main.py b/main.py index 4702620..b9a82a4 100644 --- a/main.py +++ b/main.py @@ -30,6 +30,7 @@ ControlNetModel, StableDiffusionXLControlNetPipeline, AutoPipelineForImage2Image, + FluxPipeline, ) from diffusers.utils import load_image from fastapi import FastAPI @@ -183,6 +184,20 @@ # token merging # tomesd.apply_patch(pipe, ratio=0.2) # light speedup +# Flux text-to-image pipeline for experimental usage +try: + flux_pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.float16, + ) + flux_pipe.enable_model_cpu_offload() + flux_pipe.enable_attention_slicing() + flux_pipe.enable_vae_slicing() + flux_pipe.watermark = None +except Exception as e: + logger.error(f"Failed to load Flux model: {e}") + flux_pipe = None + refiner = DiffusionPipeline.from_pretrained( # "stabilityai/stable-diffusion-xl-refiner-1.0", @@ -463,6 +478,18 @@ async def create_and_upload_image( return JSONResponse({"path": path}) +@app.get("/flux_create_and_upload_image") +async def flux_create_and_upload_image(prompt: str, save_path: str = ""): + """Generate an image using the Flux pipeline and upload it.""" + path_components = save_path.split("/")[0:-1] + final_name = save_path.split("/")[-1] + if not path_components: + path_components = [] + save_path = "/".join(path_components) + quote_plus(final_name) + path = get_flux_image_or_upload(prompt, save_path) + return JSONResponse({"path": path}) + + @app.get("/inpaint_and_upload_image") async def inpaint_and_upload_image( prompt: str, image_url: str, mask_url: str, save_path: str = "" @@ -577,6 +604,22 @@ def get_image_or_style_transfer_upload_to_cloud_storage( return link +def get_flux_image_or_upload(prompt: str, save_path: str): + """Return cached Flux image URL or generate and upload a new one.""" + prompt = shorten_too_long_text(prompt) + save_path = shorten_too_long_text(save_path) + if check_if_blob_exists(save_path): + return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}" + if flux_pipe is None: + raise RuntimeError("Flux pipeline unavailable") + with torch.inference_mode(): + bio = create_flux_image_from_prompt(prompt) + if bio is None: + return None + link = upload_to_bucket(save_path, bio, is_bytesio=True) + return link + + def get_image_or_create_upload_to_cloud_storage( prompt: str, width: int, height: int, save_path: str ): @@ -1062,6 +1105,19 @@ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str): return image_to_bytes(image) +def create_flux_image_from_prompt(prompt: str, n_steps: int = 4) -> bytes: + """Generate an image with the Flux pipeline and return it as bytes.""" + if flux_pipe is None: + raise RuntimeError("Flux pipeline unavailable") + prompt = shorten_too_long_text(prompt) + image = flux_pipe( + prompt=prompt, + num_inference_steps=n_steps, + guidance_scale=0.0, + ).images[0] + return image_to_bytes(image) + + def shorten_too_long_text(prompt): if len(prompt) > 200: # remove stopwords diff --git a/requirements.in b/requirements.in index b8a5052..be520a4 100644 --- a/requirements.in +++ b/requirements.in @@ -5,7 +5,7 @@ certifi charset-normalizer click cmake -diffusers +diffusers==0.33.1 exceptiongroup fastapi filelock diff --git a/requirements.txt b/requirements.txt index 409b6fa..ea62030 100644 --- a/requirements.txt +++ b/requirements.txt @@ -62,7 +62,7 @@ cycler==0.12.1 # via matplotlib deepcache==0.1.1 # via -r requirements.in -diffusers==0.31.0 +diffusers==0.33.1 # via # -r requirements.in # deepcache diff --git a/stable_diffusion_server/__init__.py b/stable_diffusion_server/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/stable_diffusion_server/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_flux_pipeline.py b/tests/test_flux_pipeline.py new file mode 100644 index 0000000..f53c34b --- /dev/null +++ b/tests/test_flux_pipeline.py @@ -0,0 +1,34 @@ +import os +import sys +import importlib +from unittest.mock import patch, MagicMock + +with patch.dict(sys.modules, { + 'cv2': MagicMock(__spec__=MagicMock()), + 'google': MagicMock(), + 'google.cloud': MagicMock(), +}): + with patch('diffusers.DiffusionPipeline.from_pretrained', return_value=MagicMock()) as _: + main = importlib.import_module('main') + +@patch('main.upload_to_bucket') +@patch('main.check_if_blob_exists') +def test_get_flux_image_or_upload_existing(mock_exists, mock_upload): + mock_exists.return_value = True + result = main.get_flux_image_or_upload('test prompt', 'img.png') + assert result == f"https://{main.BUCKET_NAME}/{main.BUCKET_PATH}/img.png" + mock_upload.assert_not_called() + +@patch('main.upload_to_bucket') +@patch('main.check_if_blob_exists') +def test_get_flux_image_or_upload_new(mock_exists, mock_upload): + mock_exists.return_value = False + mock_upload.return_value = 'url' + fake_image = MagicMock() + fake_image.images = [MagicMock()] + with patch.object(main, 'flux_pipe') as mock_pipe: + mock_pipe.return_value = fake_image + with patch('main.image_to_bytes', return_value=b'x'): + result = main.get_flux_image_or_upload('test', 'new.png') + assert result == 'url' + mock_upload.assert_called_once()