Skip to content

Commit

Permalink
better checkpoint download
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinunger committed Jun 29, 2023
1 parent 3347bbb commit 6888c9d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ ENV/
[Ll]ib64
[Ll]ocal
[Ss]cripts
!scripts\download_model.bat
!scripts/
pyvenv.cfg
.venv
pip-selfcheck.json
Expand Down
78 changes: 78 additions & 0 deletions scripts/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import sys
import json
import requests
from tqdm import tqdm

def download_file(url: str, filename: str, download_dir: str):
"""Download a file if it does not already exist."""

try:
filepath = os.path.join(download_dir, filename)
content_length = int(requests.head(url).headers.get("content-length", 0))

# If file already exists and size matches, skip download
if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length:
print(f"{filepath} already exists. Skipping download.")
return
if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length:
print(f"{filepath} already exists but size does not match. Redownloading.")
else:
print(f"Downloading {filename} from {url}")

# Start download, stream=True allows for progress tracking
response = requests.get(url, stream=True)

# Check if request was successful
response.raise_for_status()

# Create progress bar
total_size = int(response.headers.get('content-length', 0))
progress_bar = tqdm(
total=total_size,
unit='iB',
unit_scale=True,
ncols=70,
file=sys.stdout
)

# Write response content to file
with open(filepath, 'wb') as f:
for data in response.iter_content(chunk_size=1024):
f.write(data)
progress_bar.update(len(data)) # Update progress bar

# Close progress bar
progress_bar.close()

# Error handling for incomplete downloads
if total_size != 0 and progress_bar.n != total_size:
print("ERROR, something went wrong while downloading")
raise Exception()


except Exception as e:
print(f"An error occurred: {e}")

def main():
"""Main function to download files from URLs in a config file."""

# Get JSON config file path
script_dir = os.path.dirname(os.path.realpath(__file__))
config_file_path = os.path.join(script_dir, "download_models.json")

# Set download directory
download_dir = "checkpoints"
os.makedirs(download_dir, exist_ok=True)

# Load URL and filenames from JSON
with open(config_file_path, "r") as f:
config = json.load(f)

# Download each file specified in config
for url, filename in config.items():
download_file(url, filename, download_dir)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions scripts/download_models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl": "stylegan2_lions_512_pytorch.pkl",
"https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl": "stylegan2_dogs_1024_pytorch.pkl",
"https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl": "stylegan2_horses_256_pytorch.pkl",
"https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl": "stylegan2_elephants_512_pytorch.pkl",
"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl": "stylegan2-ffhq-512x512.pkl",
"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl": "stylegan2-afhqcat-512x512.pkl",
"http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl": "stylegan2-car-config-f.pkl",
"http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl": "stylegan2-cat-config-f.pkl"
}

0 comments on commit 6888c9d

Please sign in to comment.