# datasets/gds_dataset.py

import os
import time

import numpy as np
import torch


class GDSDataset:
    """
    Shared dataset class for all GDS datasets.
    Each data point in the dataset is an (x, y, metadata) tuple, where:
    - x is the input features
    - y is the target
    - metadata is a vector of relevant information, e.g., domain.
      For convenience, metadata also contains y.
    """
    DEFAULT_SPLITS = {'train': 0, 'val': 1, 'test': 2}
    DEFAULT_SPLIT_NAMES = {'train': 'Train', 'val': 'Validation',
                           'test': 'Test'}

    def __init__(self, root_dir, download, split_scheme):
        if len(self._metadata_array.shape) == 1:
            self._metadata_array = self._metadata_array.unsqueeze(1)
        self.check_init()

    def __len__(self):
        return len(self.y_array)

    def __getitem__(self, idx):
        # Any transformations are handled by the GDSSubset,
        # since different subsets (e.g., train vs. test) might have
        # different transforms.
        x = self.get_input(idx)
        y = self.y_array[idx]
        metadata = self.metadata_array[idx]
        return x, y, metadata

    def get_input(self, idx):
        """
        Args:
            - idx (int): Index of a data point
        Output:
            - x (Tensor): Input features of the idx-th data point
        """
        raise NotImplementedError

    def eval(self, y_pred, y_true, metadata):
        """
        Args:
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
            - metadata (Tensor): Metadata
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        raise NotImplementedError

    def get_subset(self, split, frac=1.0, transform=None):
        """
        Args:
            - split (str): Split identifier, e.g., 'train', 'val', 'test'.
                           Must be in self.split_dict.
            - frac (float): What fraction of the split to randomly sample.
                            Used for fast development on a small dataset.
            - transform (function): Any data transformations to be applied
                                    to the input x.
        Output:
            - subset (GDSSubset): A (potentially subsampled) subset of the
                                  GDSDataset.
        """
        if split not in self.split_dict:
            raise ValueError(
                f"Split {split} not found in dataset's split_dict.")
        split_mask = self.split_array == self.split_dict[split]
        split_idx = np.where(split_mask)[0]
        if frac < 1.0:
            num_to_retain = int(np.round(float(len(split_idx)) * frac))
            split_idx = np.sort(
                np.random.permutation(split_idx)[:num_to_retain])
        subset = GDSSubset(self, split_idx, transform)
        return subset
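
    # Hedged usage sketch: `MyGDSDataset` is a hypothetical concrete
    # subclass (not part of this file); frac=0.1 keeps a random 10% of
    # the split for fast development.
    #
    #     dataset = MyGDSDataset(root_dir='data', download=True,
    #                            split_scheme='official')
    #     train_set = dataset.get_subset('train', frac=0.1)
    #     x, y, metadata = train_set[0]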

    def check_init(self):
        """
        Convenience function to check that the GDSDataset is properly
        configured.
        """
        required_attrs = ['_dataset_name', '_data_dir',
                          '_split_scheme', '_split_array',
                          '_y_array', '_y_size',
                          '_metadata_fields', '_metadata_array']
        for attr_name in required_attrs:
            assert hasattr(self, attr_name), \
                f'GDSDataset is missing {attr_name}.'

        # Check that the data directory exists
        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. '
                f'Please generate the dataset first.')

        # Check splits
        assert self.split_dict.keys() == self.split_names.keys()
        assert 'train' in self.split_dict
        assert 'val' in self.split_dict

        # Check the form of the required arrays
        assert isinstance(self.y_array, (torch.Tensor, list))
        assert isinstance(self.metadata_array, torch.Tensor), \
            'metadata_array must be a torch.Tensor'

        # Check that dimensions match
        assert len(self.y_array) == len(self.metadata_array)
        assert len(self.split_array) == len(self.metadata_array)

        # Check metadata
        assert len(self.metadata_array.shape) == 2
        assert len(self.metadata_fields) == self.metadata_array.shape[1]

        # Check that it is not both classification and detection
        assert not (self.is_classification and self.is_detection)

        # For convenience, y is included in metadata_fields if y_size == 1
        if self.y_size == 1:
            assert 'y' in self.metadata_fields
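
    # A minimal sketch of a concrete subclass that would satisfy
    # check_init. All names below are hypothetical; a real dataset would
    # also download or load its files and ensure self.data_dir exists.
    #
    #     class ToyGDSDataset(GDSDataset):
    #         def __init__(self, root_dir, download, split_scheme):
    #             self._dataset_name = 'toy'
    #             self._data_dir = root_dir
    #             self._split_scheme = split_scheme
    #             self._split_array = np.array([0, 0, 1, 2])
    #             self._y_array = torch.tensor([0, 1, 0, 1])
    #             self._y_size = 1
    #             self._metadata_fields = ['domain', 'y']
    #             self._metadata_array = torch.stack(
    #                 [torch.tensor([0, 0, 1, 1]), self._y_array], dim=1)
    #             super().__init__(root_dir, download, split_scheme)
    #
    #         def get_input(self, idx):
    #             return torch.randn(3)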

    @property
    def latest_version(self):
        def is_later(u, v):
            """Returns true if u is a later version than v."""
            u_major, u_minor = tuple(map(int, u.split('.')))
            v_major, v_minor = tuple(map(int, v.split('.')))
            if (u_major > v_major) or (
                    (u_major == v_major) and (u_minor > v_minor)):
                return True
            else:
                return False

        latest_version = '0.0'
        for key in self.versions_dict.keys():
            if is_later(key, latest_version):
                latest_version = key
        return latest_version
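
    # Note: versions are compared numerically per component, so
    # is_later('1.10', '1.9') is True, whereas plain string comparison
    # would give '1.10' < '1.9'.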

    @property
    def dataset_name(self):
        """
        A string that identifies the dataset, e.g., 'amazon', 'camelyon17'.
        """
        return self._dataset_name

    @property
    def version(self):
        """
        A string that identifies the dataset version, e.g., '1.0'.
        """
        if self._version is None:
            return self.latest_version
        else:
            return self._version

    @property
    def versions_dict(self):
        """
        A dictionary where each key is a version string (e.g., '1.0')
        and each value is a dictionary containing the 'download_url' and
        'compressed_size' keys.

        'download_url' is the URL for downloading the dataset archive.
        If None, the dataset cannot be downloaded automatically
        (e.g., because it first requires accepting a usage agreement).

        'compressed_size' is the approximate size of the compressed
        dataset in bytes.
        """
        return self._versions_dict

    @property
    def data_dir(self):
        """
        The full path to the folder in which the dataset is stored.
        """
        return self._data_dir

    @property
    def collate(self):
        """
        Torch function to collate items in a batch.
        By default returns None -> uses default torch collate.
        """
        return getattr(self, '_collate', None)

    @property
    def split_scheme(self):
        """
        A string identifier of how the split is constructed,
        e.g., 'standard', 'mixed-to-test', 'user', etc.
        """
        return self._split_scheme

    @property
    def split_dict(self):
        """
        A dictionary mapping splits to integer identifiers (used in
        split_array), e.g., {'train': 0, 'val': 1, 'test': 2}.
        Keys should match up with split_names.
        """
        return getattr(self, '_split_dict', GDSDataset.DEFAULT_SPLITS)

    @property
    def split_names(self):
        """
        A dictionary mapping splits to their pretty names,
        e.g., {'train': 'Train', 'val': 'Validation', 'test': 'Test'}.
        Keys should match up with split_dict.
        """
        return getattr(self, '_split_names', GDSDataset.DEFAULT_SPLIT_NAMES)

    @property
    def split_array(self):
        """
        An array of integers, with split_array[i] representing what split
        the i-th data point belongs to.
        """
        return self._split_array

    @property
    def y_array(self):
        """
        A Tensor of targets (e.g., labels for classification tasks),
        with y_array[i] representing the target of the i-th data point.
        y_array[i] can contain multiple elements.
        """
        return self._y_array

    @property
    def y_size(self):
        """
        The number of dimensions/elements in the target, i.e.,
        len(y_array[i]).
        For standard classification/regression tasks, y_size = 1.
        For multi-task or structured prediction settings, y_size > 1.
        Used for logging and to configure models to produce
        appropriately-sized output.
        """
        return self._y_size

    @property
    def n_classes(self):
        """
        Number of classes for single-task classification datasets.
        Used for logging and to configure models to produce
        appropriately-sized output.
        None by default.
        Leave as None if not applicable (e.g., regression or multi-task
        classification).
        """
        return getattr(self, '_n_classes', None)

    @property
    def is_classification(self):
        """
        Boolean. True if the task is classification, and false otherwise.
        """
        return getattr(self, '_is_classification',
                       (self.n_classes is not None))

    @property
    def is_detection(self):
        """
        Boolean. True if the task is detection, and false otherwise.
        """
        return getattr(self, '_is_detection', False)

    @property
    def metadata_fields(self):
        """
        A list of strings naming each column of the metadata table,
        e.g., ['hospital', 'y']. Must include 'y'.
        """
        return self._metadata_fields

    @property
    def metadata_array(self):
        """
        A Tensor of metadata, with the i-th row representing the metadata
        associated with the i-th data point. The columns correspond to the
        metadata_fields defined above.
        """
        return self._metadata_array

    @property
    def metadata_map(self):
        """
        An optional dictionary that, for each metadata field, contains a
        list that maps from integers (in metadata_array) to a string
        representing what that integer means.
        This is only used for logging, so that we print out more
        intelligible metadata values.
        Each key must be in metadata_fields.
        For example, if we have
            metadata_fields = ['hospital', 'y']
            metadata_map = {'hospital': ['East', 'West']}
        then if metadata_array[i, 0] == 0, the i-th data point belongs to
        the 'East' hospital, while if metadata_array[i, 0] == 1, it
        belongs to the 'West' hospital.
        """
        return getattr(self, '_metadata_map', None)

    @property
    def original_resolution(self):
        """
        Original image resolution for image datasets.
        """
        return getattr(self, '_original_resolution', None)

    def initialize_data_dir(self, root_dir, download):
        """
        Helper function for downloading/updating the dataset if required.
        Note that we only do a version check for datasets where the
        download_url is set.
        Currently, this includes all datasets except Yelp.
        Datasets for which we don't control the download, like Yelp,
        might not handle versions similarly.
        """
        if self.version not in self.versions_dict:
            raise ValueError(
                f'Version {self.version} not supported. '
                f'Must be in {self.versions_dict.keys()}.')

        download_url = self.versions_dict[self.version]['download_url']
        compressed_size = self.versions_dict[self.version]['compressed_size']

        os.makedirs(root_dir, exist_ok=True)

        data_dir = os.path.join(
            root_dir, f'{self.dataset_name}_v{self.version}')
        version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt')
        current_major_version, current_minor_version = tuple(
            map(int, self.version.split('.')))

        # Check if we specified the latest version. Otherwise, print a
        # warning.
        latest_major_version, latest_minor_version = tuple(
            map(int, self.latest_version.split('.')))
        if latest_major_version > current_major_version:
            print(
                '*****************************\n'
                f'{self.dataset_name} has been updated to version '
                f'{self.latest_version}.\n'
                f'You are currently using version {self.version}.\n'
                'We highly recommend updating the dataset by not '
                'specifying the older version in the command-line '
                'argument or dataset constructor.\n'
                'See https://wilds.stanford.edu/changelog for changes.\n'
                '*****************************\n')
        elif latest_minor_version > current_minor_version:
            print(
                '*****************************\n'
                f'{self.dataset_name} has been updated to version '
                f'{self.latest_version}.\n'
                f'You are currently using version {self.version}.\n'
                'Please consider updating the dataset.\n'
                'See https://wilds.stanford.edu/changelog for changes.\n'
                '*****************************\n')

        # If the data_dir exists and contains the right RELEASE file,
        # we assume the dataset is correctly set up
        if os.path.exists(data_dir) and os.path.exists(version_file):
            return data_dir

        # If the data_dir exists and does not contain the right RELEASE
        # file, but it is not empty and the download_url is not set,
        # we assume the dataset is correctly set up
        if ((os.path.exists(data_dir)) and
                (len(os.listdir(data_dir)) > 0) and
                (download_url is None)):
            return data_dir

        # Otherwise, we assume the dataset needs to be downloaded.
        # If download is False, raise an error.
        if not download:
            if download_url is None:
                raise FileNotFoundError(
                    f'The {self.dataset_name} dataset could not be found '
                    f'in {data_dir}. {self.dataset_name} cannot be '
                    'automatically downloaded. Please download it '
                    'manually.')
            else:
                raise FileNotFoundError(
                    f'The {self.dataset_name} dataset could not be found '
                    f'in {data_dir}. Initialize the dataset with '
                    'download=True to download the dataset. If you are '
                    'using the example script, run with --download. This '
                    'might take some time for large datasets.')

        # Otherwise, proceed with downloading.
        if download_url is None:
            raise ValueError(
                f'Sorry, {self.dataset_name} cannot be automatically '
                'downloaded. Please download it manually.')

        from gds.datasets.download_utils import download_and_extract_archive
        print(f'Downloading dataset to {data_dir}...')
        print('You can also download the dataset manually at '
              'https://wilds.stanford.edu/downloads.')
        try:
            start_time = time.time()
            download_and_extract_archive(
                url=download_url,
                download_root=data_dir,
                filename='archive.tar.gz',
                remove_finished=True,
                size=compressed_size)

            download_time_in_minutes = (time.time() - start_time) / 60
            print(f'It took {round(download_time_in_minutes, 2)} minutes '
                  'to download and uncompress the dataset.')
        except Exception as e:
            print(
                f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be "
                'corrupted. Please try deleting it and rerunning this '
                'command.\n')
            print('Exception:', e)

        return data_dir
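
    # Hedged usage sketch: a concrete subclass would typically call this
    # from its __init__ before loading any files, e.g.
    #
    #     self._data_dir = self.initialize_data_dir(root_dir, download)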

    @staticmethod
    def standard_eval(metric, y_pred, y_true):
        """
        Args:
            - metric (Metric): Metric to use for eval
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        results = {
            **metric.compute(y_pred, y_true),
        }
        results_str = (
            f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
        )
        return results, results_str
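
    # Hedged sketch: `metric` is any Metric-like object exposing
    # compute(), name, and agg_metric_field (e.g., an accuracy metric).
    #
    #     results, results_str = GDSDataset.standard_eval(
    #         metric, y_pred, y_true)
    #     print(results_str)  # e.g., "Average acc: 0.923"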

    @staticmethod
    def standard_group_eval(metric, grouper, y_pred, y_true, metadata,
                            aggregate=True):
        """
        Args:
            - metric (Metric): Metric to use for eval
            - grouper (CombinatorialGrouper): Grouper object that converts
              metadata into groups
            - y_pred (Tensor): Predicted targets
            - y_true (Tensor): True targets
            - metadata (Tensor): Metadata
        Output:
            - results (dict): Dictionary of results
            - results_str (str): Pretty print version of the results
        """
        results, results_str = {}, ''
        if aggregate:
            results.update(metric.compute(y_pred, y_true))
            results_str += (
                f"Average {metric.name}: "
                f"{results[metric.agg_metric_field]:.3f}\n")
        g = grouper.metadata_to_group(metadata)
        group_results = metric.compute_group_wise(
            y_pred, y_true, g, grouper.n_groups)
        for group_idx in range(grouper.n_groups):
            group_str = grouper.group_field_str(group_idx)
            group_metric = group_results[metric.group_metric_field(group_idx)]
            group_counts = group_results[metric.group_count_field(group_idx)]
            results[f'{metric.name}_{group_str}'] = group_metric
            results[f'count_{group_str}'] = group_counts
            if group_counts == 0:
                continue
            results_str += (
                f' {grouper.group_str(group_idx)} '
                f"[n = {group_counts:6.0f}]:\t"
                f"{metric.name} = {group_metric:5.3f}\n")
        results[metric.worst_group_metric_field] = group_results[
            metric.worst_group_metric_field]
        results_str += (
            f"Worst-group {metric.name}: "
            f"{group_results[metric.worst_group_metric_field]:.3f}\n")
        return results, results_str
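
    # Hedged sketch: per the docstring, `grouper` is a
    # CombinatorialGrouper built over this dataset's metadata fields
    # (the field name below is illustrative).
    #
    #     grouper = CombinatorialGrouper(dataset, ['hospital'])
    #     results, results_str = GDSDataset.standard_group_eval(
    #         metric, grouper, y_pred, y_true, metadata)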


class GDSSubset(GDSDataset):
    def __init__(self, dataset, indices, transform, do_transform_y=False):
        """
        This acts like `torch.utils.data.Subset`, but on `GDSDataset`s.
        We pass in `transform` (which is used for data augmentation)
        explicitly because it can potentially vary between the training
        and test subsets.

        `do_transform_y` (bool): When this is false (the default),
                                 `self.transform` acts only on `x`.
                                 Set this to true if `self.transform`
                                 should operate on `(x, y)` instead of
                                 just `x`.
        """
        self.dataset = dataset
        self.indices = indices
        inherited_attrs = ['_dataset_name', '_data_dir', '_collate',
                           '_split_scheme', '_split_dict', '_split_names',
                           '_y_size', '_n_classes',
                           '_metadata_fields', '_metadata_map']
        for attr_name in inherited_attrs:
            if hasattr(dataset, attr_name):
                setattr(self, attr_name, getattr(dataset, attr_name))
        self.transform = transform
        self.do_transform_y = do_transform_y

    def __getitem__(self, idx):
        x, y, metadata = self.dataset[self.indices[idx]]
        if self.transform is not None:
            if self.do_transform_y:
                x, y = self.transform(x, y)
            else:
                x = self.transform(x)
        return x, y, metadata

    def __len__(self):
        return len(self.indices)

    @property
    def split_array(self):
        return self.dataset._split_array[self.indices]

    @property
    def y_array(self):
        return self.dataset._y_array[self.indices]

    @property
    def metadata_array(self):
        return self.dataset.metadata_array[self.indices]

    def eval(self, y_pred, y_true, metadata):
        return self.dataset.eval(y_pred, y_true, metadata)
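
# Hedged usage sketch: iterating over a subset with a standard PyTorch
# DataLoader; the batch size and shuffling are illustrative choices.
# `collate` defaults to None, which gives torch's default collation.
#
#     from torch.utils.data import DataLoader
#     train_loader = DataLoader(train_set, shuffle=True, batch_size=16,
#                               collate_fn=train_set.collate)
#     for x, y, metadata in train_loader:
#         ...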