@@ -8623,47 +8623,42 @@ def display_output_file(text_file, image_file, video_file, audio_file, model3d_f
8623
8623
return text_files, image_files, video_files, audio_files, model3d_files, display_output_file
8624
8624
8625
8625
8626
def download_model(llm_model_url, sd_model_url):
    """Download an LLM and/or a StableDiffusion model from Hugging Face.

    Args:
        llm_model_url: Either a bare repo id ("author/repo-name"), which is
            git-cloned into inputs/text/llm_models/<repo-name>, or a direct
            file URL (e.g. ".../resolve/main/model.gguf"), which is saved
            into the same directory.
        sd_model_url: A direct file URL to a StableDiffusion checkpoint,
            saved into inputs/image/sd_models.

    Returns:
        One status line per requested download joined with newlines, or a
        prompt string when neither URL was provided. Per-model failures are
        reported in the message rather than raised, so one bad URL does not
        abort the other download.
    """

    def _download_file(url, dest_path):
        # Stream the response to disk in 1 MiB chunks; model files are
        # routinely multiple GB, so response.content would exhaust memory.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        response = requests.get(url, allow_redirects=True, stream=True)
        # Surface HTTP errors (404/403/...) instead of silently saving the
        # server's error page as the model file.
        response.raise_for_status()
        with open(dest_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)

    if not llm_model_url and not sd_model_url:
        return "Please enter at least one model URL to download"

    messages = []

    if llm_model_url:
        try:
            # A bare "author/repo" id (no URL scheme) means "clone the whole
            # repo"; anything starting with http(s) is a direct file link.
            # NOTE: the previous check ('"blob/main" not in url') wrongly
            # treated full ".../resolve/main/..." URLs as repo ids.
            is_repo_id = "/" in llm_model_url and not llm_model_url.startswith(("http://", "https://"))
            if is_repo_id:
                repo_name = llm_model_url.split("/")[-1]
                model_path = os.path.join("inputs", "text", "llm_models", repo_name)
                os.makedirs(model_path, exist_ok=True)
                Repo.clone_from(f"https://huggingface.co/{llm_model_url}", model_path)
                messages.append(f"LLM model repository {repo_name} downloaded successfully!")
            else:
                file_name = llm_model_url.split("/")[-1]
                model_path = os.path.join("inputs", "text", "llm_models", file_name)
                _download_file(llm_model_url, model_path)
                messages.append(f"LLM model file {file_name} downloaded successfully!")
        except Exception as e:
            messages.append(f"Error downloading LLM model: {str(e)}")

    if sd_model_url:
        try:
            file_name = sd_model_url.split("/")[-1]
            model_path = os.path.join("inputs", "image", "sd_models", file_name)
            _download_file(sd_model_url, model_path)
            messages.append(f"StableDiffusion model file {file_name} downloaded successfully!")
        except Exception as e:
            messages.append(f"Error downloading StableDiffusion model: {str(e)}")

    return "\n".join(messages)
8667
8662
8668
8663
8669
8664
def settings_interface(language, share_value, debug_value, monitoring_value, auto_launch, api_status, open_api, queue_max_size, status_update_rate, gradio_auth, server_name, server_port, hf_token, theme,
@@ -11348,8 +11343,8 @@ def reload_interface():
11348
11343
model_downloader_interface = gr.Interface(
11349
11344
fn=download_model,
11350
11345
inputs=[
11351
- gr.Dropdown(choices=[None, "StarlingLM(Transformers7B)", "OpenChat3.6(Llama8B.Q4)"], label=_("Download LLM model", lang), value=None ),
11352
- gr.Dropdown(choices=[None, "Dreamshaper8(SD1.5)", "RealisticVisionV4.0(SDXL)"], label=_("Download StableDiffusion model", lang), value=None ),
11346
+ gr.Textbox( label=_("Download LLM model", lang), placeholder="repo-author/repo-name or https://huggingface.co/.../model.gguf" ),
11347
+ gr.Textbox( label=_("Download StableDiffusion model", lang), placeholder="https://huggingface.co/.../model.safetensors" ),
11353
11348
],
11354
11349
outputs=[
11355
11350
gr.Textbox(label=_("Message", lang), type="text"),
0 commit comments