Skip to content

Commit fd6a37e

Browse files
authored
Merge pull request #245 from Dartvauder/dev
Update app.py
2 parents a46111a + 2c31597 commit fd6a37e

File tree

1 file changed

+33
-38
lines changed

1 file changed

+33
-38
lines changed

LaunchFile/app.py

+33-38
Original file line numberDiff line numberDiff line change
@@ -8623,47 +8623,42 @@ def display_output_file(text_file, image_file, video_file, audio_file, model3d_f
86238623
return text_files, image_files, video_files, audio_files, model3d_files, display_output_file
86248624

86258625

8626-
def download_model(model_name_llm, model_name_sd):
8627-
if not model_name_llm and not model_name_sd:
8628-
return "Please select a model to download"
8629-
8630-
if model_name_llm and model_name_sd:
8631-
return "Please select one model type for downloading"
8632-
8633-
if model_name_llm:
8634-
model_url = ""
8635-
if model_name_llm == "StarlingLM(Transformers7B)":
8636-
model_url = "https://huggingface.co/Nexusflow/Starling-LM-7B-beta"
8637-
elif model_name_llm == "OpenChat3.6(Llama8B.Q4)":
8638-
model_url = "https://huggingface.co/bartowski/openchat-3.6-8b-20240522-GGUF/resolve/main/openchat-3.6-8b-20240522-Q4_K_M.gguf"
8639-
model_path = os.path.join("inputs", "text", "llm_models", model_name_llm)
8640-
8641-
if model_url:
8642-
if model_name_llm == "StarlingLM(Transformers7B)":
8643-
Repo.clone_from(model_url, model_path)
8626+
def download_model(llm_model_url, sd_model_url):
8627+
if not llm_model_url and not sd_model_url:
8628+
return "Please enter at least one model URL to download"
8629+
8630+
messages = []
8631+
8632+
if llm_model_url:
8633+
try:
8634+
if "/" in llm_model_url and "blob/main" not in llm_model_url:
8635+
repo_name = llm_model_url.split("/")[-1]
8636+
model_path = os.path.join("inputs", "text", "llm_models", repo_name)
8637+
os.makedirs(model_path, exist_ok=True)
8638+
Repo.clone_from(f"https://huggingface.co/{llm_model_url}", model_path)
8639+
messages.append(f"LLM model repository {repo_name} downloaded successfully!")
86448640
else:
8645-
response = requests.get(model_url, allow_redirects=True)
8641+
file_name = llm_model_url.split("/")[-1]
8642+
model_path = os.path.join("inputs", "text", "llm_models", file_name)
8643+
response = requests.get(llm_model_url, allow_redirects=True)
86468644
with open(model_path, "wb") as file:
86478645
file.write(response.content)
8648-
return f"LLM model {model_name_llm} downloaded successfully!"
8649-
else:
8650-
return "Invalid LLM model name"
8651-
8652-
if model_name_sd:
8653-
model_url = ""
8654-
if model_name_sd == "Dreamshaper8(SD1.5)":
8655-
model_url = "https://huggingface.co/Lykon/DreamShaper/resolve/main/DreamShaper_8_pruned.safetensors"
8656-
elif model_name_sd == "RealisticVisionV4.0(SDXL)":
8657-
model_url = "https://huggingface.co/SG161222/RealVisXL_V4.0/resolve/main/RealVisXL_V4.0.safetensors"
8658-
model_path = os.path.join("inputs", "image", "sd_models", f"{model_name_sd}")
8659-
8660-
if model_url:
8661-
response = requests.get(model_url, allow_redirects=True)
8646+
messages.append(f"LLM model file {file_name} downloaded successfully!")
8647+
except Exception as e:
8648+
messages.append(f"Error downloading LLM model: {str(e)}")
8649+
8650+
if sd_model_url:
8651+
try:
8652+
file_name = sd_model_url.split("/")[-1]
8653+
model_path = os.path.join("inputs", "image", "sd_models", file_name)
8654+
response = requests.get(sd_model_url, allow_redirects=True)
86628655
with open(model_path, "wb") as file:
86638656
file.write(response.content)
8664-
return f"StableDiffusion model {model_name_sd} downloaded successfully!"
8665-
else:
8666-
return "Invalid StableDiffusion model name"
8657+
messages.append(f"StableDiffusion model file {file_name} downloaded successfully!")
8658+
except Exception as e:
8659+
messages.append(f"Error downloading StableDiffusion model: {str(e)}")
8660+
8661+
return "\n".join(messages)
86678662

86688663

86698664
def settings_interface(language, share_value, debug_value, monitoring_value, auto_launch, api_status, open_api, queue_max_size, status_update_rate, gradio_auth, server_name, server_port, hf_token, theme,
@@ -11348,8 +11343,8 @@ def reload_interface():
1134811343
model_downloader_interface = gr.Interface(
1134911344
fn=download_model,
1135011345
inputs=[
11351-
gr.Dropdown(choices=[None, "StarlingLM(Transformers7B)", "OpenChat3.6(Llama8B.Q4)"], label=_("Download LLM model", lang), value=None),
11352-
gr.Dropdown(choices=[None, "Dreamshaper8(SD1.5)", "RealisticVisionV4.0(SDXL)"], label=_("Download StableDiffusion model", lang), value=None),
11346+
gr.Textbox(label=_("Download LLM model", lang), placeholder="repo-author/repo-name or https://huggingface.co/.../model.gguf"),
11347+
gr.Textbox(label=_("Download StableDiffusion model", lang), placeholder="https://huggingface.co/.../model.safetensors"),
1135311348
],
1135411349
outputs=[
1135511350
gr.Textbox(label=_("Message", lang), type="text"),

0 commit comments

Comments
 (0)