Line too long (81 > 79 characters):
18 DEFAULT_SPLIT_NAMES = {'train': 'Train', 'val': 'Validation', 'test': 'Test'}Line too long (87 > 79 characters):
30 # since different subsets (e.g., train vs test) might have different transformsLine too long (90 > 79 characters):
64 - transform (function): Any data transformations to be applied to the input x.Line too long (90 > 79 characters):
66 - subset (WILDSSubset): A (potentially subsampled) subset of the WILDSDataset.Line too long (81 > 79 characters):
69 raise ValueError(f"Split {split} not found in dataset's split_dict.")Line too long (81 > 79 characters):
74 split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])Line too long (83 > 79 characters):
80 Convenience function to check that the WILDSDataset is properly configured.Line too long (84 > 79 characters):
87 assert hasattr(self, attr_name), f'WILDSDataset is missing {attr_name}.'Line too long (90 > 79 characters):
92 f'{self.data_dir} does not exist yet. Please generate the dataset first.')Line too long (89 > 79 characters):
100 assert (isinstance(self.y_array, torch.Tensor) or isinstance(self.y_array, list))Line too long (101 > 79 characters):
101 assert isinstance(self.metadata_array, torch.Tensor), 'metadata_array must be a torch.Tensor'Line too long (85 > 79 characters):
164 'compressed_size' is the approximate size of the compressed dataset in bytes.Line too long (81 > 79 characters):
194 A dictionary mapping splits to integer identifiers (used in split_array),Line too long (93 > 79 characters):
212 An array of integers, with split_array[i] representing what split the i-th data pointLine too long (87 > 79 characters):
232 Used for logging and to configure models to produce appropriately-sized output.Line too long (87 > 79 characters):
240 Used for logging and to configure models to produce appropriately-sized output.Line too long (88 > 79 characters):
242 Leave as None if not applicable (e.g., regression or multi-task classification).Line too long (80 > 79 characters):
251 return getattr(self, '_is_classification', (self.n_classes is not None))Line too long (92 > 79 characters):
263 A list of strings naming each column of the metadata table, e.g., ['hospital', 'y'].Line too long (89 > 79 characters):
271 A Tensor of metadata, with the i-th row representing the metadata associated withLine too long (89 > 79 characters):
272 the i-th data point. The columns correspond to the metadata_fields defined above.Line too long (92 > 79 characters):
279 An optional dictionary that, for each metadata field, contains a list that maps fromLine too long (86 > 79 characters):
280 integers (in metadata_array) to a string representing what that integer means.Line too long (94 > 79 characters):
281 This is only used for logging, so that we print out more intelligible metadata values.Line too long (93 > 79 characters):
286 then if metadata_array[i, 0] == 0, the i-th data point belongs to the 'East' hospitalLine too long (88 > 79 characters):
301 Note that we only do a version check for datasets where the download_url is set.Line too long (110 > 79 characters):
307 raise ValueError(f'Version {self.version} not supported. Must be in {self.versions_dict.keys()}.')Line too long (81 > 79 characters):
314 data_dir = os.path.join(root_dir, f'{self.dataset_name}_v{self.version}')Line too long (95 > 79 characters):
316 current_major_version, current_minor_version = tuple(map(int, self.version.split('.')))Line too long (100 > 79 characters):
319 latest_major_version, latest_minor_version = tuple(map(int, self.latest_version.split('.')))Line too long (91 > 79 characters):
323 f'{self.dataset_name} has been updated to version {self.latest_version}.\n'Line too long (150 > 79 characters):
325 f'We highly recommend updating the dataset by not specifying the older version in the command-line argument or dataset constructor.\n'Line too long (91 > 79 characters):
331 f'{self.dataset_name} has been updated to version {self.latest_version}.\n'Line too long (130 > 79 characters):
342 # If the data_dir exists and does not contain the right RELEASE file, but it is not empty and the download_url is not set,Line too long (174 > 79 characters):
354 f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.')Line too long (260 > 79 characters):
357 f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.')Line too long (111 > 79 characters):
362 f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.')Line too long (101 > 79 characters):
366 print(f'You can also download the dataset manually at https://wilds.stanford.edu/downloads.')Line too long (114 > 79 characters):
377 print(f"It took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.")Line too long (135 > 79 characters):
380 f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be corrupted. Please try deleting it and rerunning this command.\n")Line too long (87 > 79 characters):
405 def standard_group_eval(metric, grouper, y_pred, y_true, metadata, aggregate=True):Line too long (95 > 79 characters):
409 - grouper (CombinatorialGrouper): Grouper object that converts metadata into groupsLine too long (93 > 79 characters):
420 results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"Line too long (86 > 79 characters):
422 group_results = metric.compute_group_wise(y_pred, y_true, g, grouper.n_groups)Line too long (85 > 79 characters):
433 f"[n = {group_results[metric.group_count_field(group_idx)]:6.0f}]:\t"Line too long (96 > 79 characters):
434 f"{metric.name} = {group_results[metric.group_metric_field(group_idx)]:5.3f}\n")Line too long (107 > 79 characters):
435 results[f'{metric.worst_group_metric_field}'] = group_results[f'{metric.worst_group_metric_field}']Line too long (107 > 79 characters):
436 results_str += f"Worst-group {metric.name}: {group_results[metric.worst_group_metric_field]:.3f}\n"