Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit bffcae6

Browse files
authored
Test Mode and Search Fix (#54) (#56)
* Added SPARSEZOO_TEST_MODE to disable incrementing download count, fixed bug in release version with search * removed unnecessary export line
1 parent 676fdbf commit bffcae6

File tree

5 files changed

+31
-6
lines changed

5 files changed

+31
-6
lines changed

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ PYCHECKGLOBS := 'examples/**/*.py' 'scripts/**/*.py' 'src/**/*.py' 'tests/**/*.p
66
DOCDIR := docs
77
MDCHECKGLOBS := 'docs/**/*.md' 'docs/**/*.rst' 'examples/**/*.md' 'notebooks/**/*.md' 'scripts/**/*.md'
88
MDCHECKFILES := CODE_OF_CONDUCT.md CONTRIBUTING.md DEVELOPING.md README.md
9+
SPARSEZOO_TEST_MODE := "true"
910

1011
BUILD_ARGS := # set nightly to build nightly release
1112
TARGETS := "" # targets for running pytests: full,efficientnet,inception,resnet,vgg,ssd,yolo

src/sparsezoo/requests/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
import os
2020
from typing import Any, List, Union
2121

22+
from sparsezoo.utils import convert_to_bool
2223

23-
__all__ = ["BASE_API_URL", "ModelArgs", "MODELS_API_URL"]
24+
25+
__all__ = ["BASE_API_URL", "ModelArgs", "MODELS_API_URL", "SPARSEZOO_TEST_MODE"]
26+
27+
SPARSEZOO_TEST_MODE = convert_to_bool(os.getenv("SPARSEZOO_TEST_MODE"))
2428

2529
BASE_API_URL = (
2630
os.getenv("SPARSEZOO_API_URL")

src/sparsezoo/requests/download.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import requests
2323

2424
from sparsezoo.requests.authentication import get_auth_header
25-
from sparsezoo.requests.base import MODELS_API_URL, ModelArgs
25+
from sparsezoo.requests.base import MODELS_API_URL, SPARSEZOO_TEST_MODE, ModelArgs
2626

2727

2828
__all__ = ["download_get_request", "DOWNLOAD_PATH"]
@@ -52,11 +52,18 @@ def download_get_request(
5252
if file_name:
5353
url = f"{url}/{file_name}"
5454

55+
download_args = []
56+
5557
if hasattr(args, "release_version") and args.release_version:
56-
url = f"{url}?release_version={args.release_version}"
58+
download_args.append(f"release_version={args.release_version}")
5759

58-
_LOGGER.debug(f"GET download from {url}")
60+
if SPARSEZOO_TEST_MODE:
61+
download_args.append("increment_download=False")
5962

63+
if download_args:
64+
url = f"{url}?{'&'.join(download_args)}"
65+
66+
_LOGGER.debug(f"GET download from {url}")
6067
response = requests.get(url=url, headers=header)
6168
response.raise_for_status()
6269
response_json = response.json()

src/sparsezoo/requests/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def search_get_request(
5959
search_args.extend([f"page={page}", f"page_length={page_length}"])
6060

6161
if args.release_version:
62-
search_args.extend(f"release_version={args.release_version}")
62+
search_args.append(f"release_version={args.release_version}")
6363

6464
search_args = "&".join(search_args)
6565
url = f"{MODELS_API_URL}/{SEARCH_PATH}/{args.model_url_root}?{search_args}"

src/sparsezoo/utils/helpers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818

1919
import errno
2020
import os
21-
from typing import Union
21+
from typing import Any, Union
2222

2323
from tqdm import auto, tqdm, tqdm_notebook
2424

2525

2626
__all__ = [
2727
"CACHE_DIR",
2828
"clean_path",
29+
"convert_to_bool",
2930
"create_dirs",
3031
"create_parent_dirs",
3132
"create_tqdm_auto_constructor",
@@ -43,6 +44,18 @@ def clean_path(path: str) -> str:
4344
return os.path.abspath(os.path.expanduser(path))
4445

4546

47+
def convert_to_bool(val: Any):
48+
"""
49+
:param val: a value
50+
:return: False if value is a Falsy value e.g. 0, f, false, None, otherwise True.
51+
"""
52+
return (
53+
bool(val)
54+
if not isinstance(val, str)
55+
else bool(val) and "f" not in val.lower() and "0" not in val.lower()
56+
)
57+
58+
4659
def create_dirs(path: str):
4760
"""
4861
:param path: the directory path to try and create

0 commit comments

Comments
 (0)