diff --git a/.gitignore b/.gitignore index 24b5204..8b626aa 100644 --- a/.gitignore +++ b/.gitignore @@ -99,7 +99,7 @@ ENV/ [Ll]ib64 [Ll]ocal [Ss]cripts -!scripts\download_model.bat +!scripts/ pyvenv.cfg .venv pip-selfcheck.json diff --git a/scripts/download_model.py b/scripts/download_model.py new file mode 100644 index 0000000..b951b05 --- /dev/null +++ b/scripts/download_model.py @@ -0,0 +1,78 @@ +import os +import sys +import json +import requests +from tqdm import tqdm + +def download_file(url: str, filename: str, download_dir: str): + """Download a file if it does not already exist.""" + + try: + filepath = os.path.join(download_dir, filename) + content_length = int(requests.head(url).headers.get("content-length", 0)) + + # If file already exists and size matches, skip download + if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length: + print(f"{filepath} already exists. Skipping download.") + return + if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length: + print(f"{filepath} already exists but size does not match. Redownloading.") + else: + print(f"Downloading {filename} from {url}") + + # Start download, stream=True allows for progress tracking + response = requests.get(url, stream=True) + + # Check if request was successful + response.raise_for_status() + + # Create progress bar + total_size = int(response.headers.get('content-length', 0)) + progress_bar = tqdm( + total=total_size, + unit='iB', + unit_scale=True, + ncols=70, + file=sys.stdout + ) + + # Write response content to file + with open(filepath, 'wb') as f: + for data in response.iter_content(chunk_size=1024): + f.write(data) + progress_bar.update(len(data)) # Update progress bar + + # Close progress bar + progress_bar.close() + + # Error handling for incomplete downloads + if total_size != 0 and progress_bar.n != total_size: + print("ERROR, something went wrong while downloading") + raise Exception() + + + except Exception as e: + print(f"An error occurred: {e}") + +def main(): + """Main function to download files from URLs in a config file.""" + + # Get JSON config file path + script_dir = os.path.dirname(os.path.realpath(__file__)) + config_file_path = os.path.join(script_dir, "download_models.json") + + # Set download directory + download_dir = "checkpoints" + os.makedirs(download_dir, exist_ok=True) + + # Load URL and filenames from JSON + with open(config_file_path, "r") as f: + config = json.load(f) + + # Download each file specified in config + for url, filename in config.items(): + download_file(url, filename, download_dir) + + +if __name__ == "__main__": + main() diff --git a/scripts/download_models.json b/scripts/download_models.json new file mode 100644 index 0000000..637d553 --- /dev/null +++ b/scripts/download_models.json @@ -0,0 +1,10 @@ +{ + "https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl": "stylegan2_lions_512_pytorch.pkl", + "https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl": "stylegan2_dogs_1024_pytorch.pkl", + "https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl": "stylegan2_horses_256_pytorch.pkl", + "https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl": "stylegan2_elephants_512_pytorch.pkl", + "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl": "stylegan2-ffhq-512x512.pkl", + "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl": "stylegan2-afhqcat-512x512.pkl", + "http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl": "stylegan2-car-config-f.pkl", + "http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl": "stylegan2-cat-config-f.pkl" +}