Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow_sdxl_controlnet #411

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

# 4.35.1

* Added allow_sdxl_controlnet worker key

# 4.35.0

* Added ability to generate QR-code images
Expand Down
12 changes: 12 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ def __init__(self):
help="If True, this worker will pick up requests requesting ControlNet.",
location="json",
)
self.job_pop_parser.add_argument(
"allow_sdxl_controlnet",
type=bool,
required=False,
default=False,
help="If True, this worker will pick up requests requesting SDXL ControlNet.",
location="json",
)
self.job_pop_parser.add_argument(
"allow_lora",
type=bool,
Expand Down Expand Up @@ -519,6 +527,10 @@ def __init__(self, api):
default=True,
description="If True, this worker will pick up requests requesting ControlNet.",
),
"allow_sdxl_controlnet": fields.Boolean(
default=True,
description="If True, this worker will pick up requests requesting SDXL ControlNet.",
),
"allow_lora": fields.Boolean(
default=True,
description="If True, this worker will pick up requests requesting LoRas.",
Expand Down
8 changes: 8 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,14 @@ def __init__(self, api):
default=None,
description="If True, this worker supports and allows lora requests.",
),
"controlnet": fields.Boolean(
default=None,
description="If True, this worker supports and allows controlnet requests.",
),
"sdxl_controlnet": fields.Boolean(
default=None,
description="If True, this worker supports and allows SDXL controlnet requests.",
),
"max_length": fields.Integer(
example=80,
description="The maximum tokens this worker can generate.",
Expand Down
3 changes: 2 additions & 1 deletion horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def validate(self):
if self.params.get("workflow") == "qr_code":
# QR-code pipeline cannot do batching currently
self.args["disable_batching"] = True
if not all(model_reference.get_model_baseline(model_name).startswith("stable diffusion 1") for model_name in self.args.models):
if not all(model_reference.get_model_baseline(model_name) in ['stable_diffusion 1', 'stable_diffusion_xl'] for model_name in self.args.models):
raise e.BadRequest("QR Code controlnet only works with SD 1.5 models currently", rc="ControlNetMismatch.")
if self.params.get("extra_texts") is None or len(self.params.get("extra_texts")) == 0:
raise e.BadRequest("This request requires you pass the required extra texts for this workflow.", rc="MissingExtraTexts.")
Expand Down Expand Up @@ -588,6 +588,7 @@ def check_in(self):
allow_unsafe_ipaddr=self.args.allow_unsafe_ipaddr,
allow_post_processing=self.args.allow_post_processing,
allow_controlnet=self.args.allow_controlnet,
allow_sdxl_controlnet=self.args.allow_sdxl_controlnet,
allow_lora=self.args.allow_lora,
priority_usernames=self.priority_usernames,
)
Expand Down
10 changes: 10 additions & 0 deletions horde/classes/stable/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ImageWorker(Worker):
allow_painting = db.Column(db.Boolean, default=True, nullable=False)
allow_post_processing = db.Column(db.Boolean, default=True, nullable=False)
allow_controlnet = db.Column(db.Boolean, default=False, nullable=False)
allow_sdxl_controlnet = db.Column(db.Boolean, default=False, nullable=False)
allow_lora = db.Column(db.Boolean, default=False, nullable=False)
wtype = "image"

Expand All @@ -36,6 +37,7 @@ def check_in(self, max_pixels, **kwargs):
self.allow_painting = kwargs.get("allow_painting", True)
self.allow_post_processing = kwargs.get("allow_post_processing", True)
self.allow_controlnet = kwargs.get("allow_controlnet", False)
self.allow_sdxl_controlnet = kwargs.get("allow_sdxl_controlnet", False)
self.allow_lora = kwargs.get("allow_lora", False)
if len(self.get_model_names()) == 0:
self.set_models(["stable_diffusion"])
Expand Down Expand Up @@ -116,7 +118,14 @@ def can_generate(self, waiting_prompt):
if not check_bridge_capability("image_is_control", self.bridge_agent):
return [False, "bridge_version"]
if not self.allow_controlnet:
return [False, "controlnet"]
if waiting_prompt.params.get("workflow") == "qr_code":
if not check_bridge_capability("controlnet", self.bridge_agent):
return [False, "bridge_version"]
if not check_bridge_capability("qr_code", self.bridge_agent):
return [False, "bridge_version"]
if not self.allow_sdxl_controlnet:
return [False, "controlnet"]
if waiting_prompt.params.get("hires_fix") and not check_bridge_capability("hires_fix", self.bridge_agent):
return [False, "bridge_version"]
if (
Expand Down Expand Up @@ -169,6 +178,7 @@ def get_details(self, details_privilege=0):
ret_dict["painting"] = self.allow_painting if check_bridge_capability("inpainting", self.bridge_agent) else False
ret_dict["post-processing"] = self.allow_post_processing
ret_dict["controlnet"] = self.allow_controlnet
ret_dict["sdxl_controlnet"] = self.allow_sdxl_controlnet
ret_dict["lora"] = self.allow_lora
return ret_dict

Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
HORDE_VERSION = "4.35.0"
HORDE_VERSION = "4.35.1"

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
1 change: 1 addition & 0 deletions sql_statements/4.35.1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE workers ADD COLUMN allow_sdxl_controlnet BOOLEAN default false not null;
1 change: 1 addition & 0 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"allow_unsafe_ipaddr": True,
"allow_post_processing": True,
"allow_controlnet": True,
"allow_sdxl_controlnet": True,
"allow_lora": True,
}
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
Expand Down
1 change: 1 addition & 0 deletions tests/test_image_extra_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_simple_image_gen(api_key: str, HORDE_URL: str, CIVERSION: str) -> None:
"allow_unsafe_ipaddr": True,
"allow_post_processing": True,
"allow_controlnet": True,
"allow_sdxl_controlnet": True,
"allow_lora": True,
}
pop_req = requests.post(f"{protocol}://{HORDE_URL}/api/v2/generate/pop", json=pop_dict, headers=headers)
Expand Down
Loading