Skip to content

Commit 5b84b7c

Browse files
authored
Add decorator to skip tests and log a warning if URL is unaccessible (#2104)
* Add decorator to skip tests and log a warning if URL is unaccessible * More descriptive error message
1 parent db26565 commit 5b84b7c

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

test/torchtext_unittest/common/torchtext_test_case.py

+11
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,24 @@
55
import shutil
66
import subprocess
77
import tempfile
8+
from urllib.error import HTTPError
89

910
import torch # noqa: F401
1011
from torch.testing._internal.common_utils import TestCase
1112

1213
logger = logging.getLogger(__name__)
1314

1415

16+
def third_party_download(test_func):
17+
def inner(*args, **kwargs):
18+
try:
19+
return test_func(*args, **kwargs)
20+
except HTTPError as e:
21+
logger.warning(f"Cannot access URL in {test_func.__name__}. Error message {e}")
22+
23+
return inner
24+
25+
1526
class TorchtextTestCase(TestCase):
1627
def setUp(self) -> None:
1728
logging.basicConfig(format=("%(asctime)s - %(levelname)s - " "%(name)s - %(message)s"), level=logging.INFO)

test/torchtext_unittest/test_build.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import torchtext.data
77

8-
from .common.torchtext_test_case import TorchtextTestCase
8+
from .common.torchtext_test_case import TorchtextTestCase, third_party_download
99

1010

1111
class TestDataUtils(TorchtextTestCase):
@@ -64,6 +64,7 @@ def test_vectors_get_vecs(self) -> None:
6464
self.assertEqual(token_one_vec.shape[0], vec.dim)
6565
self.assertEqual(vec[tokens[0].lower()], token_one_vec)
6666

67+
@third_party_download
6768
def test_download_charngram_vectors(self) -> None:
6869
# Build a vocab and get vectors twice to test caching.
6970
for _ in range(2):

0 commit comments

Comments
 (0)