# service.py — SDXL-Lightning text-to-image service (BentoML).
import bentoml
from PIL.Image import Image
from annotated_types import Le, Ge
from typing_extensions import Annotated
# Base SDXL model: supplies the pipeline config and all non-UNet weights.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
# Hugging Face Hub repo hosting the distilled SDXL-Lightning UNet checkpoints.
REPO = "ByteDance/SDXL-Lightning"
CKPT = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
# Default prompt served when the API caller does not provide one.
sample_prompt = "A girl smiling"
@bentoml.service(
    traffic={"timeout": 300},
    workers=1,
    resources={
        "gpu": 1,
        "gpu_type": "nvidia-l4",
    },
)
class SDXLLightning:
    """BentoML service exposing the SDXL-Lightning 4-step text-to-image pipeline."""

    def __init__(self) -> None:
        """Build the SDXL pipeline on the GPU with the distilled Lightning UNet.

        Downloads the Lightning UNet weights from the Hugging Face Hub,
        loads them into a UNet instantiated from the base model's config,
        and wires that UNet into a fp16 StableDiffusionXLPipeline.
        """
        import torch
        from diffusers import (
            EulerDiscreteScheduler,
            StableDiffusionXLPipeline,
            UNet2DConditionModel,
        )
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file

        # Instantiate the UNet from the base model's config only, then swap
        # in the distilled Lightning weights fetched from the Hub.
        self.unet = UNet2DConditionModel.from_config(
            BASE_MODEL_ID, subfolder="unet"
        ).to("cuda", torch.float16)
        lightning_state = load_file(hf_hub_download(REPO, CKPT), device="cuda")
        self.unet.load_state_dict(lightning_state)

        # Assemble the full SDXL pipeline in fp16 around the Lightning UNet.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            BASE_MODEL_ID,
            unet=self.unet,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")

        # Switch the sampler to trailing timestep spacing — the setting the
        # SDXL-Lightning usage example specifies for these checkpoints.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, timestep_spacing="trailing"
        )

    @bentoml.api
    def txt2img(self, prompt: str = sample_prompt) -> Image:
        """Generate one image for *prompt* and return it as a PIL Image."""
        # 4 denoising steps: must match the 4-step checkpoint named in CKPT.
        steps = 4
        # guidance_scale 0.0 disables classifier-free guidance, matching the
        # settings the Lightning checkpoint is served with upstream.
        cfg_scale = 0.0
        result = self.pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
        )
        return result.images[0]