⬅ datasets/download_utils.py source

1 """
2 This file contains utility functions for downloading datasets.
3 The code in this file is taken from the torchvision package,
  • E501 Line too long (90 > 79 characters)
4 specifically, https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py.
5 We package it here to avoid users having to install the rest of torchvision.
6 It is licensed under the following license:
7  
8 BSD 3-Clause License
9  
10 Copyright (c) Soumith Chintala 2016,
11 All rights reserved.
12  
13 Redistribution and use in source and binary forms, with or without
14 modification, are permitted provided that the following conditions are met:
15  
16 * Redistributions of source code must retain the above copyright notice, this
17 list of conditions and the following disclaimer.
18  
19 * Redistributions in binary form must reproduce the above copyright notice,
20 this list of conditions and the following disclaimer in the documentation
21 and/or other materials provided with the distribution.
22  
23 * Neither the name of the copyright holder nor the names of its
24 contributors may be used to endorse or promote products derived from
25 this software without specific prior written permission.
26  
27 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 """
38  
39 import gzip
40 import hashlib
41 import os
42 import os.path
43 import tarfile
44 import zipfile
45 from typing import Any, Callable, List, Iterable, Optional, TypeVar
46  
47 import torch
48 from torch.utils.model_zoo import tqdm
49  
50  
def gen_bar_updater(total) -> Callable[[int, int, int], None]:
    """Create a ``urlretrieve``-style report hook backed by a tqdm bar.

    Args:
        total: Initial total (in bytes) for the progress bar; may be None.

    Returns:
        A callable ``(count, block_size, total_size)`` that advances the bar
        to ``count * block_size`` bytes. If the bar was created without a
        total, the first truthy ``total_size`` it receives is adopted.
    """
    progress_bar = tqdm(total=total, unit='Byte')

    def bar_update(count, block_size, total_size):
        # Adopt the reported size only if none was supplied up front.
        if progress_bar.total is None and total_size:
            progress_bar.total = total_size
        bytes_so_far = count * block_size
        # tqdm.update takes an increment, so convert the absolute position.
        delta = bytes_so_far - progress_bar.n
        progress_bar.update(delta)

    return bar_update
61  
62  
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so
    large files are never loaded into memory at once.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:  # b'' signals end of file
                break
            digest.update(block)
    return digest.hexdigest()
69  
70  
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    """Tell whether the file at ``fpath`` has the MD5 hex digest ``md5``.

    Extra keyword arguments (e.g. ``chunk_size``) are forwarded to
    :func:`calculate_md5`.
    """
    actual_digest = calculate_md5(fpath, **kwargs)
    return md5 == actual_digest
73  
74  
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    """Return True if ``fpath`` is an existing regular file matching ``md5``.

    When ``md5`` is None only the file's existence is checked.
    """
    if os.path.isfile(fpath):
        return True if md5 is None else check_md5(fpath, md5)
    return False
81  
82  
  • E501 Line too long (96 > 79 characters)
def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    size: Optional[int] = None,
) -> None:
    """Download a file from a url and place it in root.

    The download is skipped when ``root/filename`` already exists and passes
    :func:`check_integrity`.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in; ``~`` is expanded
            and the directory is created if it does not exist
        filename (str, optional): Name to save the file under. If None, use
            the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not
            check
        size (int, optional): Total size in bytes used for the progress bar.
            If None, the total passed to the report hook is adopted once
            known

    Raises:
        urllib.error.URLError, OSError: If the download fails. A failing
            ``https:`` URL is retried once over plain ``http:`` first.
        RuntimeError: If the file is missing or fails the MD5 check after
            the download.
    """
    # ``import urllib`` alone does not load the ``request``/``error``
    # submodules; import them explicitly so the lookups below always work.
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater(size)
        )
    except (urllib.error.URLError, IOError):
        if not url.startswith('https:'):
            raise
        # Retry over plain http. Only the scheme is rewritten, not any later
        # "https:" that may appear inside the URL (e.g. in a query string).
        # NOTE(review): this downgrade is only integrity-protected when an
        # ``md5`` is supplied.
        url = url.replace('https:', 'http:', 1)
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater(size)
        )

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
126  
127  
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List the sub-directories directly under ``root``.

    Args:
        root (str): Directory to scan; ``~`` is expanded.
        prefix (bool, optional): If exactly ``True``, return each entry joined
            with the expanded ``root``; otherwise return bare directory names.
    """
    base = os.path.expanduser(root)
    found = []
    for entry in os.listdir(base):
        full_path = os.path.join(base, entry)
        if os.path.isdir(full_path):
            found.append(full_path if prefix is True else entry)
    return found
141  
142  
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root.

    Only regular files directly inside ``root`` are considered;
    sub-directories are skipped even if their name matches ``suffix``.

    Args:
        root (str): Path to directory whose files need to be listed; ``~``
            is expanded
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or
            ('.jpg', '.png'). It is passed directly to Python's
            ``str.endswith``, so a tuple matches any of its entries
        prefix (bool, optional): If true, prepends the path to each result,
            otherwise only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [
        p for p in os.listdir(root)
        if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)
    ]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
158  
159  
  • F821 Undefined name 'requests'
  • E501 Line too long (96 > 79 characters)
def _quota_exceeded(
    response: "requests.models.Response",  # type: ignore[name-defined]
) -> bool:
    """Return True if the Google Drive page reports an exceeded quota."""
    page_text = response.text
    return "Google Drive - Quota exceeded" in page_text
162  
163  
  • E501 Line too long (120 > 79 characters)
def download_file_from_google_drive(
    file_id: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a Google Drive file and place it in root.

    The download is skipped when ``root/filename`` already exists and passes
    :func:`check_integrity`.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in; ``~`` is expanded
            and the directory is created if it does not exist
        filename (str, optional): Name to save the file under. If None, use
            the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not
            check

    Raises:
        RuntimeError: If the daily download quota of the file is exceeded,
            or if the file is missing or fails the MD5 check after the
            download.
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests

    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check_integrity already returns False when the file does not exist.
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # The context manager closes the session's pooled connections.
    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)
        # Large files first return a warning page; its cookie carries the
        # token needed to confirm the download.
        token = _get_confirm_token(response)

        if token:
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        # NOTE(review): ``response.text`` reads the whole body into memory,
        # so ``stream=True`` gives no memory benefit for the saved file.
        if _quota_exceeded(response):
            msg = (
                f"The daily quota of the file {filename} is exceeded and it "
                "can't be downloaded. This is a limitation of Google Drive "
                "and can only be overcome by trying again later."
            )
            raise RuntimeError(msg)

        _save_response_content(response, fpath)

    # Verify the result the same way download_url does; previously ``md5``
    # was never checked after downloading.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
205  
206  
  • F821 Undefined name 'requests'
  • E501 Line too long (108 > 79 characters)
def _get_confirm_token(
    response: "requests.models.Response",  # type: ignore[name-defined]
) -> Optional[str]:
    """Return the first cookie value whose name starts with
    ``download_warning``, or None if there is no such cookie."""
    matches = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matches, None)
213  
214  
def _save_response_content(
    response: "requests.models.Response",  # type: ignore[name-defined]
    destination: str,
    chunk_size: int = 32768,
) -> None:
    """Write the body of ``response`` to ``destination``, showing progress.

    Args:
        response: Object whose ``iter_content(chunk_size)`` yields the body
            as byte chunks.
        destination (str): Path of the file to (over)write.
        chunk_size (int, optional): Chunk size passed to ``iter_content``.
    """
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        try:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    pbar.update(len(chunk))
        finally:
            # Close the bar even when the stream fails part-way through;
            # previously it leaked on any exception.
            pbar.close()
227  
228  
def _is_tarxz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tar.xz``."""
    return filename.endswith((".tar.xz",))
231  
232  
def _is_tar(filename: str) -> bool:
    """Return True if ``filename`` ends with a plain ``.tar`` extension."""
    return filename.endswith((".tar",))
235  
236  
def _is_targz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tar.gz`` (``.tgz`` is
    handled separately by ``_is_tgz``)."""
    return filename.endswith((".tar.gz",))
239  
240  
def _is_tgz(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.tgz``."""
    return filename.endswith((".tgz",))
243  
244  
def _is_gzip(filename: str) -> bool:
    """Return True for a ``.gz`` file that is not a ``.tar.gz`` tarball."""
    return not filename.endswith(".tar.gz") and filename.endswith(".gz")
247  
248  
def _is_zip(filename: str) -> bool:
    """Return True if ``filename`` ends with ``.zip``."""
    return filename.endswith((".zip",))
251  
252  
  • E501 Line too long (106 > 79 characters)
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` resolves to ``directory`` or a path below it."""
    directory = os.path.realpath(directory)
    target = os.path.realpath(target)
    try:
        return os.path.commonpath([directory, target]) == directory
    except ValueError:
        # commonpath raises for paths without a common root (e.g. different
        # Windows drives); such a target cannot be inside ``directory``.
        return False


def _safe_extract_tar(tar: tarfile.TarFile, path: str) -> None:
    """Extract every member of ``tar`` into ``path``, refusing path traversal."""
    if hasattr(tarfile, "data_filter"):
        # Interpreters with tar extraction filters (3.12+ and security
        # backports): the "data" filter refuses members and links that would
        # land outside ``path`` as well as special files.
        tar.extractall(path=path, filter="data")
        return
    # Older interpreters: best-effort check of member names and link
    # targets before extracting anything.
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if member.issym():
            link_path = os.path.join(os.path.dirname(member_path), member.linkname)
        elif member.islnk():
            link_path = os.path.join(path, member.linkname)
        else:
            link_path = member_path
        for candidate in (member_path, link_path):
            if not _is_within_directory(path, candidate):
                raise tarfile.TarError(
                    "Refusing to extract {!r}: it points outside {!r}".format(
                        member.name, path))
    tar.extractall(path=path)


def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
    """Extract an archive, choosing the format from its file name.

    Supported: ``.tar``, ``.tar.gz``/``.tgz``, ``.tar.xz``, ``.zip`` and a
    plain ``.gz`` file. A plain ``.gz`` file is decompressed to a single file
    in ``to_path``, named after the archive with its last extension removed.

    Args:
        from_path (str): Path of the archive.
        to_path (str, optional): Destination directory. Defaults to the
            directory containing ``from_path``.
        remove_finished (bool, optional): Delete ``from_path`` after a
            successful extraction.

    Raises:
        ValueError: If the file name has no supported archive extension.
        tarfile.TarError: If a tar member or link would be written outside
            ``to_path``.
    """
    import shutil  # only needed by the gzip branch

    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Archives are typically downloaded from the network, so tar members are
    # never extracted unchecked.
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            _safe_extract_tar(tar, to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            _safe_extract_tar(tar, to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            _safe_extract_tar(tar, to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        # Stream the decompressed bytes instead of reading the whole payload
        # into memory at once.
        with gzip.open(from_path, "rb") as zip_f, open(to_path, "wb") as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
278  
279  
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
    size: Optional[int] = None
) -> None:
    """Download an archive with ``download_url`` and unpack it with ``extract_archive``.

    Args:
        url (str): URL of the archive.
        download_root (str): Directory the archive is downloaded into.
        extract_root (str, optional): Directory to extract into. Defaults to
            ``download_root``.
        filename (str, optional): Name to save the archive under. If None,
            use the basename of the URL.
        md5 (str, optional): Expected MD5 checksum of the archive.
        remove_finished (bool, optional): Delete the archive after
            extraction.
        size (int, optional): Archive size in bytes for the progress bar.
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    else:
        # Expand "~" like download_root; the extractors do not expand it, so
        # a literal "~" directory would otherwise be created.
        extract_root = os.path.expanduser(extract_root)
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5, size)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
300  
301  
def iterable_to_str(iterable: Iterable) -> str:
    """Render the items as single-quoted strings separated by ``, ``.

    An empty iterable renders as ``''``.
    """
    joined = "', '".join(map(str, iterable))
    return "'{}'".format(joined)
304  
305  
# Type variable for verify_str_arg: the validated value is returned with the
# same type it was passed in as (str or bytes).
T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    """Check that ``value`` is a string and, optionally, one of ``valid_values``.

    Args:
        value (str or bytes): Value to validate.
        arg (str, optional): Argument name, used in error messages.
        valid_values (iterable, optional): Allowed values. If None, any
            ``str``/``bytes`` value is accepted.
        custom_msg (str, optional): Message used instead of the default one
            when ``value`` is not in ``valid_values``.

    Returns:
        ``value`` unchanged.

    Raises:
        ValueError: If ``value`` is not ``str``/``bytes``, or is not one of
            ``valid_values``.
    """
    # ``torch._six.string_classes`` was ``(str, bytes)`` and was removed in
    # PyTorch 2.0, so the tuple is spelled out directly.
    if not isinstance(value, (str, bytes)):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    # A one-shot iterator (e.g. a generator) would be partly consumed by the
    # membership test and then listed incompletely in the error message, so
    # materialise it first.
    if iter(valid_values) is valid_values:
        valid_values = list(valid_values)

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = ("Unknown value '{value}' for argument {arg}. "
                   "Valid values are {{{valid_values}}}.")
            msg = msg.format(value=value, arg=arg,
                             valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value