Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import gds
def get_dataset(dataset, version=None, **dataset_kwargs):
    """
    Construct and return the requested WILDS-style dataset.

    Input:
        dataset (str): Name of the dataset. Must be a member of
            gds.supported_datasets.
        version (str): Dataset version number, e.g., '1.0'.
            Defaults to the latest version (None is passed through to the
            dataset constructor). Any non-None value is converted with str(),
            so e.g. 1.0 and '1.0' are equivalent.
        dataset_kwargs: Other keyword arguments to pass to the dataset
            constructor.
    Output:
        An instance of the dataset class that corresponds to `dataset`,
        constructed as DatasetClass(version=version, **dataset_kwargs).
    Raises:
        ValueError: if `dataset` is not in gds.supported_datasets.
        NotImplementedError: if `dataset` is listed in gds.supported_datasets
            but this function has no loader branch for it.
    """
    if version is not None:
        version = str(version)

    if dataset not in gds.supported_datasets:
        raise ValueError(f'The dataset {dataset} is not recognized. Must be one of {gds.supported_datasets}.')

    # Each branch imports its dataset module lazily, so only the requested
    # dataset's module (and whatever it imports) is loaded.
    if dataset == 'ogb-molpcba':
        from gds.datasets.ogbmolpcba_dataset import OGBPCBADataset
        return OGBPCBADataset(version=version, **dataset_kwargs)

    elif dataset == 'ogb-molhiv':
        from gds.datasets.ogbmolhiv_dataset import OGBHIVDataset
        return OGBHIVDataset(version=version, **dataset_kwargs)

    elif dataset == 'ogbg-ppa':
        from gds.datasets.ogbgppa_dataset import OGBGPPADataset
        return OGBGPPADataset(version=version, **dataset_kwargs)

    elif dataset == 'RotatedMNIST':
        from gds.datasets.rotated_mnist_dataset import RotatedMNISTDataset
        return RotatedMNISTDataset(version=version, **dataset_kwargs)

    elif dataset == 'ColoredMNIST':
        from gds.datasets.colored_mnist_dataset import ColoredMNISTDataset
        return ColoredMNISTDataset(version=version, **dataset_kwargs)

    elif dataset == 'SBM-Environment':
        from gds.datasets.sbm_environment_dataset import SBMEnvironmentDataset
        return SBMEnvironmentDataset(version=version, **dataset_kwargs)

    elif dataset == 'SBM-Isolation':
        from gds.datasets.sbm_isolation_dataset import SBMIsolationDataset
        return SBMIsolationDataset(version=version, **dataset_kwargs)

    elif dataset == 'UPFD':
        from gds.datasets.upfd_dataset import UPFDDataset
        return UPFDDataset(version=version, **dataset_kwargs)

    else:
        # Reached only when gds.supported_datasets lists a name that has no
        # branch above, i.e. the registry and this function are out of sync.
        raise NotImplementedError(
            f'The dataset {dataset} is listed in gds.supported_datasets '
            f'but get_dataset() has no loader for it.'
        )