31
31
import urllib
32
32
33
33
from etils import epath
34
+ from tensorflow_datasets .core import lazy_imports_lib
34
35
from tensorflow_datasets .core import units
35
36
from tensorflow_datasets .core import utils
36
37
from tensorflow_datasets .core .download import checksums as checksums_lib
@@ -130,6 +131,43 @@ def _get_filename(response: Response) -> str:
130
131
return _basename_from_url (response .url )
131
132
132
133
134
def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
  """Extracts the download link from a Google Drive confirmation page.

  The first `<form>` of the page is located, and the values of its `uuid`,
  `export`, `id` and `confirm` `<input>`s (whichever are present) are added as
  query parameters to the form's `action` URL.

  Args:
    original_url: The URL the confirmation page was originally retrieved from.
      Used in error messages and as the base for resolving a relative `action`.
    contents: The confirmation page's HTML.

  Returns:
    The URL for downloading the file.

  Raises:
    ValueError: If the page has no `<form>`, or the form has no (or an empty)
      `action` attribute.
  """
  bs4 = lazy_imports_lib.lazy_imports.bs4
  soup = bs4.BeautifulSoup(contents, 'html.parser')
  form = soup.find('form')
  # A missing form and a missing/empty action are reported identically.
  action = form.get('action', '') if form is not None else ''
  if not action:
    raise ValueError(
        f'Failed to obtain confirmation link for GDrive URL {original_url}.'
    )

  # Collect the <input>s named 'uuid', 'export', 'id' and 'confirm'; inputs
  # that are absent from the form are simply left out of the query.
  params = {}
  for name in ('uuid', 'export', 'id', 'confirm'):
    input_tag = form.find('input', {'name': name})
    if input_tag is not None:
      params[name] = input_tag.get('value', '')

  if params:
    # Merge into the action URL rather than blindly appending '?<query>': the
    # action may already carry a query string or a fragment, in which case a
    # second '?' would produce a malformed URL. Form fields take precedence
    # over same-named parameters already present in the action.
    parts = urllib.parse.urlsplit(action)
    kept = [
        (key, value)
        for key, value in urllib.parse.parse_qsl(
            parts.query, keep_blank_values=True
        )
        if key not in params
    ]
    query_string = urllib.parse.urlencode(kept + list(params.items()))
    action = urllib.parse.urlunsplit(parts._replace(query=query_string))

  # The action may be relative; resolve it against the page's own URL.
  return urllib.parse.urljoin(original_url, action)
169
+
170
+
133
171
class _Downloader :
134
172
"""Class providing async download API with checksum validation.
135
173
@@ -318,11 +356,26 @@ def _open_with_requests(
318
356
session .mount (
319
357
'https://' , requests .adapters .HTTPAdapter (max_retries = retries )
320
358
)
321
- if _DRIVE_URL .match (url ):
322
- url = _normalize_drive_url (url )
323
359
with session .get (url , stream = True , ** kwargs ) as response :
324
- _assert_status (response )
325
- yield (response , response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ))
360
+ if (
361
+ _DRIVE_URL .match (url )
362
+ and 'Content-Disposition' not in response .headers
363
+ ):
364
+ download_url = _process_gdrive_confirmation (url , response .text )
365
+ with session .get (
366
+ download_url , stream = True , ** kwargs
367
+ ) as download_response :
368
+ _assert_status (download_response )
369
+ yield (
370
+ download_response ,
371
+ download_response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ),
372
+ )
373
+ else :
374
+ _assert_status (response )
375
+ yield (
376
+ response ,
377
+ response .iter_content (chunk_size = io .DEFAULT_BUFFER_SIZE ),
378
+ )
326
379
327
380
328
381
@contextlib .contextmanager
@@ -338,13 +391,6 @@ def _open_with_urllib(
338
391
)
339
392
340
393
341
- def _normalize_drive_url (url : str ) -> str :
342
- """Returns Google Drive url with confirmation token."""
343
- # This bypasses the "Google Drive can't scan this file for viruses" warning
344
- # when dowloading large files.
345
- return url + '&confirm=t'
346
-
347
-
348
394
def _assert_status (response : requests .Response ) -> None :
349
395
"""Ensure the URL response is 200."""
350
396
if response .status_code != 200 :
0 commit comments