Skip to content

Commit 2f1bfa7

Browse files
authored
Merge branch 'main' into iree-turbine
2 parents 29f34ec + 6cb86a8 commit 2f1bfa7

File tree

11 files changed

+93
-59
lines changed

11 files changed

+93
-59
lines changed

.github/workflows/nightly.yml

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ jobs:
5353
python process_skipfiles.py
5454
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
5555
pip install -e .
56+
pip freeze -l
5657
pyinstaller .\apps\shark_studio\shark_studio.spec
5758
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
5859
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe

apps/shark_studio/api/initializers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def initialize():
5353
clear_tmp_imgs()
5454

5555
from apps.shark_studio.web.utils.file_utils import (
56-
create_checkpoint_folders,
56+
create_model_folders,
5757
)
5858

5959
# Create custom models folders if they don't exist
60-
create_checkpoint_folders()
60+
create_model_folders()
6161

6262
import gradio as gr
6363

apps/shark_studio/api/sd.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def view_json_file(file_path):
602602

603603
global_obj._init()
604604

605-
sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
605+
sd_json = view_json_file(
606+
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
607+
)
606608
sd_kwargs = json.loads(sd_json)
607609
for arg in vars(cmd_opts):
608610
if arg in sd_kwargs:

apps/shark_studio/modules/ckpt_processing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
from tqdm import tqdm
88
from omegaconf import OmegaConf
9+
from diffusers import StableDiffusionPipeline
910
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
1011
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
1112
download_from_original_stable_diffusion_ckpt,
@@ -87,6 +88,7 @@ def process_custom_pipe_weights(custom_weights):
8788
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
8889
custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights)
8990
custom_weights_params = custom_weights
91+
9092
return custom_weights_params, custom_weights_tgt
9193

9294

@@ -98,7 +100,7 @@ def get_civitai_checkpoint(url: str):
98100
base_filename = re.findall(
99101
'"([^"]*)"', response.headers["Content-Disposition"]
100102
)[0]
101-
destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename
103+
destination_path = Path.cwd() / (cmd_opts.model_dir or "models") / base_filename
102104

103105
# we don't have this model downloaded yet
104106
if not destination_path.is_file():

apps/shark_studio/modules/pipeline.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
self.device, self.device_id = clean_device_info(device)
4242
self.import_mlir = import_mlir
4343
self.iree_module_dict = {}
44-
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
44+
self.tmp_dir = get_resource_path(cmd_opts.tmp_dir)
4545
if not os.path.exists(self.tmp_dir):
4646
os.mkdir(self.tmp_dir)
4747
self.tempfiles = {}
@@ -55,9 +55,7 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
5555
# and your model map is populated with any IR - unique model IDs and their static params,
5656
# call this method to get the artifacts associated with your map.
5757
self.pipe_id = self.safe_name(pipe_id)
58-
self.pipe_vmfb_path = Path(
59-
os.path.join(get_checkpoints_path(".."), self.pipe_id)
60-
)
58+
self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
6159
self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
6260
if submodel == "None":
6361
print("\n[LOG] Gathering any pre-compiled artifacts....")

apps/shark_studio/modules/shared_cmd_opts.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def is_valid_file(arg):
339339
p.add_argument(
340340
"--output_dir",
341341
type=str,
342-
default=None,
342+
default=os.path.join(os.getcwd(), "generated_imgs"),
343343
help="Directory path to save the output images and json.",
344344
)
345345

@@ -613,12 +613,27 @@ def is_valid_file(arg):
613613
)
614614

615615
p.add_argument(
616-
"--ckpt_dir",
616+
"--tmp_dir",
617+
type=str,
618+
default=os.path.join(os.getcwd(), "shark_tmp"),
619+
help="Path to tmp directory",
620+
)
621+
622+
p.add_argument(
623+
"--config_dir",
617624
type=str,
618-
default="../models",
625+
default=os.path.join(os.getcwd(), "configs"),
626+
help="Path to config directory",
627+
)
628+
629+
p.add_argument(
630+
"--model_dir",
631+
type=str,
632+
default=os.path.join(os.getcwd(), "models"),
619633
help="Path to directory where all .ckpts are stored in order to populate "
620634
"them in the web UI.",
621635
)
636+
622637
# TODO: replace API flag when these can be run together
623638
p.add_argument(
624639
"--ui",

apps/shark_studio/web/ui/sd.py

+39-31
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,14 @@ def import_original(original_img, width, height):
231231

232232

233233
def base_model_changed(base_model_id):
234-
new_choices = get_checkpoints(
235-
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
236-
) + get_checkpoints(model_type="checkpoints")
234+
ckpt_path = Path(
235+
os.path.join(
236+
cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id))
237+
)
238+
)
239+
ckpt_path.mkdir(parents=True, exist_ok=True)
240+
241+
new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints")
237242

238243
return gr.Dropdown(
239244
value=new_choices[0] if len(new_choices) > 0 else "None",
@@ -581,21 +586,6 @@ def base_model_changed(base_model_id):
581586
object_fit="fit",
582587
preview=True,
583588
)
584-
with gr.Row():
585-
std_output = gr.Textbox(
586-
value=f"{sd_model_info}\n"
587-
f"Images will be saved at "
588-
f"{get_generated_imgs_path()}",
589-
lines=2,
590-
elem_id="std_output",
591-
show_label=True,
592-
label="Log",
593-
show_copy_button=True,
594-
)
595-
sd_element.load(
596-
logger.read_sd_logs, None, std_output, every=1
597-
)
598-
sd_status = gr.Textbox(visible=False)
599589
with gr.Row():
600590
batch_count = gr.Slider(
601591
1,
@@ -631,19 +621,18 @@ def base_model_changed(base_model_id):
631621
stop_batch = gr.Button("Stop")
632622
with gr.Tab(label="Config", id=102) as sd_tab_config:
633623
with gr.Column(elem_classes=["sd-right-panel"]):
634-
with gr.Row(elem_classes=["fill"]):
635-
Path(get_configs_path()).mkdir(
636-
parents=True, exist_ok=True
637-
)
638-
default_config_file = os.path.join(
639-
get_configs_path(),
640-
"default_sd_config.json",
641-
)
642-
write_default_sd_config(default_config_file)
643-
sd_json = gr.JSON(
644-
elem_classes=["fill"],
645-
value=view_json_file(default_config_file),
646-
)
624+
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
625+
default_config_file = os.path.join(
626+
get_configs_path(),
627+
"default_sd_config.json",
628+
)
629+
write_default_sd_config(default_config_file)
630+
sd_json = gr.JSON(
631+
label="SD Config",
632+
elem_classes=["fill"],
633+
value=view_json_file(default_config_file),
634+
render=False,
635+
)
647636
with gr.Row():
648637
with gr.Column(scale=3):
649638
load_sd_config = gr.FileExplorer(
@@ -706,11 +695,30 @@ def base_model_changed(base_model_id):
706695
inputs=[sd_json, sd_config_name],
707696
outputs=[sd_config_name],
708697
)
698+
with gr.Row(elem_classes=["fill"]):
699+
sd_json.render()
709700
save_sd_config.click(
710701
fn=save_sd_cfg,
711702
inputs=[sd_json, sd_config_name],
712703
outputs=[sd_config_name],
713704
)
705+
with gr.Tab(label="Log", id=103) as sd_tab_log:
706+
with gr.Row():
707+
std_output = gr.Textbox(
708+
value=f"{sd_model_info}\n"
709+
f"Images will be saved at "
710+
f"{get_generated_imgs_path()}",
711+
elem_id="std_output",
712+
show_label=True,
713+
label="Log",
714+
show_copy_button=True,
715+
)
716+
sd_element.load(
717+
logger.read_sd_logs, None, std_output, every=1
718+
)
719+
sd_status = gr.Textbox(visible=False)
720+
with gr.Tab(label="Automation", id=104) as sd_tab_automation:
721+
pass
714722

715723
pull_kwargs = dict(
716724
fn=pull_sd_configs,

apps/shark_studio/web/utils/file_utils.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -66,41 +66,47 @@ def get_resource_path(path):
6666

6767

6868
def get_configs_path() -> Path:
69-
configs = get_resource_path(os.path.join("..", "configs"))
69+
configs = get_resource_path(cmd_opts.config_dir)
7070
if not os.path.exists(configs):
7171
os.mkdir(configs)
72-
return Path(get_resource_path("../configs"))
72+
return Path(configs)
7373

7474

7575
def get_generated_imgs_path() -> Path:
76-
return Path(
77-
cmd_opts.output_dir
78-
if cmd_opts.output_dir
79-
else get_resource_path("../generated_imgs")
80-
)
76+
outputs = get_resource_path(cmd_opts.output_dir)
77+
if not os.path.exists(outputs):
78+
os.mkdir(outputs)
79+
return Path(outputs)
80+
81+
82+
def get_tmp_path() -> Path:
83+
tmpdir = get_resource_path(cmd_opts.model_dir)
84+
if not os.path.exists(tmpdir):
85+
os.mkdir(tmpdir)
86+
return Path(tmpdir)
8187

8288

8389
def get_generated_imgs_todays_subdir() -> str:
8490
return dt.now().strftime("%Y%m%d")
8591

8692

87-
def create_checkpoint_folders():
93+
def create_model_folders():
8894
dir = ["checkpoints", "vae", "lora", "vmfb"]
89-
if not os.path.isdir(cmd_opts.ckpt_dir):
95+
if not os.path.isdir(cmd_opts.model_dir):
9096
try:
91-
os.makedirs(cmd_opts.ckpt_dir)
97+
os.makedirs(cmd_opts.model_dir)
9298
except OSError:
9399
sys.exit(
94-
f"Invalid --ckpt_dir argument, "
95-
f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created."
100+
f"Invalid --model_dir argument, "
101+
f"{cmd_opts.model_dir} folder does not exist, and cannot be created."
96102
)
97103

98104
for root in dir:
99105
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)
100106

101107

102108
def get_checkpoints_path(model_type=""):
103-
return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type))
109+
return get_resource_path(os.path.join(cmd_opts.model_dir, model_type))
104110

105111

106112
def get_checkpoints(model_type="checkpoints"):

apps/shark_studio/web/utils/tmp_configs.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import shutil
33
from time import time
44

5-
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
5+
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
6+
7+
shark_tmp = cmd_opts.tmp_dir # os.path.join(os.getcwd(), "shark_tmp/")
68

79

810
def clear_tmp_mlir():

setup_venv.ps1

-1
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,4 @@ else {python -m venv .\shark.venv\}
8989
python -m pip install --upgrade pip
9090
pip install wheel
9191
pip install -r requirements.txt
92-
9392
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

shark/shark_importer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import hashlib
88

9+
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
910

1011
def create_hash(file_name):
1112
with open(file_name, "rb") as f:
@@ -120,7 +121,7 @@ def import_mlir(
120121
is_dynamic=False,
121122
tracing_required=False,
122123
func_name="forward",
123-
save_dir="./shark_tmp/",
124+
save_dir=cmd_opts.tmp_dir, #"./shark_tmp/",
124125
mlir_type="linalg",
125126
):
126127
if self.frontend in ["torch", "pytorch"]:
@@ -806,7 +807,7 @@ def save_mlir(
806807
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
807808
)
808809
if dir == "":
809-
dir = os.path.join(".", "shark_tmp")
810+
dir = cmd_opts.tmp_dir, #os.path.join(".", "shark_tmp")
810811
mlir_path = os.path.join(dir, model_name_mlir)
811812
print(f"saving {model_name_mlir} to {dir}")
812813
if not os.path.exists(dir):

0 commit comments

Comments
 (0)