Skip to content

Commit 855b1cd

Browse files
author
The TensorFlow Datasets Authors
committed
Merge pull request #10963 from alexhartl:master
PiperOrigin-RevId: 715745355
2 parents 1322866 + ff89242 commit 855b1cd

File tree

2 files changed

+67
-11
lines changed

2 files changed

+67
-11
lines changed

tensorflow_datasets/core/download/downloader.py

+57-11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import urllib
3232

3333
from etils import epath
34+
from tensorflow_datasets.core import lazy_imports_lib
3435
from tensorflow_datasets.core import units
3536
from tensorflow_datasets.core import utils
3637
from tensorflow_datasets.core.download import checksums as checksums_lib
@@ -130,6 +131,43 @@ def _get_filename(response: Response) -> str:
130131
return _basename_from_url(response.url)
131132

132133

134+
def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
  """Process Google Drive confirmation page.

  Extracts the download link from a Google Drive confirmation page.

  Args:
    original_url: The URL the confirmation page was originally retrieved from.
    contents: The confirmation page's HTML.

  Returns:
    download_url: The URL for downloading the file.

  Raises:
    ValueError: If no confirmation `<form>` (or its `action` attribute) can
      be found in the page.
  """
  bs4 = lazy_imports_lib.lazy_imports.bs4
  soup = bs4.BeautifulSoup(contents, 'html.parser')
  form = soup.find('form')
  # `is None` rather than truthiness: a bs4 Tag has no __bool__ and falls
  # back to __len__ (its child count), so a tag that was found but has no
  # children would incorrectly be treated as missing.
  if form is None:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )
  action = form.get('action', '')
  if not action:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )
  # Collect the hidden <input> fields the confirmation form would submit.
  input_names = ['uuid', 'export', 'id', 'confirm']
  params = {}
  for name in input_names:
    input_tag = form.find('input', {'name': name})
    # Explicit None check: <input> is a void element with zero children, so
    # a found tag is falsy and plain `if input_tag:` would silently drop
    # every parameter (including the required `confirm` token).
    if input_tag is not None:
      params[name] = input_tag.get('value', '')
  query_string = urllib.parse.urlencode(params)
  download_url = f'{action}?{query_string}' if query_string else action
  # The form action may be a relative URL; resolve it against the page URL.
  download_url = urllib.parse.urljoin(original_url, download_url)
  return download_url
169+
170+
133171
class _Downloader:
134172
"""Class providing async download API with checksum validation.
135173
@@ -318,11 +356,26 @@ def _open_with_requests(
318356
session.mount(
319357
'https://', requests.adapters.HTTPAdapter(max_retries=retries)
320358
)
321-
if _DRIVE_URL.match(url):
322-
url = _normalize_drive_url(url)
323359
with session.get(url, stream=True, **kwargs) as response:
324-
_assert_status(response)
325-
yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
360+
if (
361+
_DRIVE_URL.match(url)
362+
and 'Content-Disposition' not in response.headers
363+
):
364+
download_url = _process_gdrive_confirmation(url, response.text)
365+
with session.get(
366+
download_url, stream=True, **kwargs
367+
) as download_response:
368+
_assert_status(download_response)
369+
yield (
370+
download_response,
371+
download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE),
372+
)
373+
else:
374+
_assert_status(response)
375+
yield (
376+
response,
377+
response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE),
378+
)
326379

327380

328381
@contextlib.contextmanager
@@ -338,13 +391,6 @@ def _open_with_urllib(
338391
)
339392

340393

341-
def _normalize_drive_url(url: str) -> str:
342-
"""Returns Google Drive url with confirmation token."""
343-
# This bypasses the "Google Drive can't scan this file for viruses" warning
344-
# when downloading large files.
345-
return url + '&confirm=t'
346-
347-
348394
def _assert_status(response: requests.Response) -> None:
349395
"""Ensure the URL response is 200."""
350396
if response.status_code != 200:

tensorflow_datasets/core/download/downloader_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Optional
1919
from unittest import mock
2020

21+
import bs4
2122
from etils import epath
2223
import pytest
2324
from tensorflow_datasets import testing
@@ -36,6 +37,7 @@ def __init__(self, url, content, cookies=None, headers=None, status_code=200):
3637
self.status_code = status_code
3738
# For urllib codepath
3839
self.read = self.raw.read
40+
self.text = ''
3941

4042
def __enter__(self):
4143
return self
@@ -78,6 +80,14 @@ def setUp(self):
7880
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
7981
).start()
8082

83+
bs_mock = mock.MagicMock(spec=bs4.BeautifulSoup)
84+
form_mock = mock.MagicMock()
85+
form_mock.get.return_value = 'x'
86+
bs_mock.find.return_value = form_mock
87+
mock.patch.object(
88+
bs4, 'BeautifulSoup', autospec=True, return_value=bs_mock
89+
).start()
90+
8191
def test_ok(self):
8292
promise = self.downloader.download(self.url, self.tmp_dir)
8393
future = promise.get()

0 commit comments

Comments
 (0)