⬅ get_dataset.py source

1 import gds
2  
3  
# Maps every dataset name handled by get_dataset to the module path and class
# name that implement it. Modules are imported only when their dataset is
# requested, so loading one dataset never imports the others.
_DATASET_CLASSES = {
    'ogb-molpcba': ('gds.datasets.ogbmolpcba_dataset', 'OGBPCBADataset'),
    'ogb-molhiv': ('gds.datasets.ogbmolhiv_dataset', 'OGBHIVDataset'),
    'ogbg-ppa': ('gds.datasets.ogbgppa_dataset', 'OGBGPPADataset'),
    'RotatedMNIST': ('gds.datasets.rotated_mnist_dataset',
                     'RotatedMNISTDataset'),
    'ColoredMNIST': ('gds.datasets.colored_mnist_dataset',
                     'ColoredMNISTDataset'),
    'SBM-Environment': ('gds.datasets.sbm_environment_dataset',
                        'SBMEnvironmentDataset'),
    'SBM-Isolation': ('gds.datasets.sbm_isolation_dataset',
                      'SBMIsolationDataset'),
    'UPFD': ('gds.datasets.upfd_dataset', 'UPFDDataset'),
}


def get_dataset(dataset, version=None, **dataset_kwargs):
    """Instantiate and return the requested dataset.

    Args:
        dataset (str): Name of the dataset; must be one of
            ``gds.supported_datasets``.
        version (str, optional): Dataset version number, e.g. '1.0'.
            Non-string values are converted with ``str()``. Defaults to
            ``None``, which is passed to the dataset constructor unchanged
            and selects the latest version.
        **dataset_kwargs: Other keyword arguments forwarded verbatim to the
            dataset constructor.

    Returns:
        An instance of the dataset class registered for ``dataset``,
        constructed as ``cls(version=version, **dataset_kwargs)``.

    Raises:
        ValueError: If ``dataset`` is not in ``gds.supported_datasets``.
        NotImplementedError: If ``dataset`` is listed in
            ``gds.supported_datasets`` but has no entry in
            ``_DATASET_CLASSES``.
    """
    import importlib

    if version is not None:
        version = str(version)

    if dataset not in gds.supported_datasets:
        raise ValueError(
            f'The dataset {dataset} is not recognized. '
            f'Must be one of {gds.supported_datasets}.'
        )

    entry = _DATASET_CLASSES.get(dataset)
    if entry is None:
        # The supported list and the loader table have drifted apart; say so
        # explicitly instead of raising a bare NotImplementedError.
        raise NotImplementedError(
            f'Dataset {dataset!r} is listed in gds.supported_datasets but '
            f'get_dataset has no loader registered for it.'
        )

    module_name, class_name = entry
    # Lazy import: only the requested dataset's module is loaded.
    module = importlib.import_module(module_name)
    dataset_cls = getattr(module, class_name)
    return dataset_cls(version=version, **dataset_kwargs)