Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import gds 

2 

3 

# Maps each dataset name accepted by get_dataset() to the module that defines
# its class and the class name. Modules are imported lazily (only when that
# dataset is requested), matching the original per-branch local imports.
_DATASET_CLASSES = {
    'ogb-molpcba': ('gds.datasets.ogbmolpcba_dataset', 'OGBPCBADataset'),
    'ogb-molhiv': ('gds.datasets.ogbmolhiv_dataset', 'OGBHIVDataset'),
    'ogbg-ppa': ('gds.datasets.ogbgppa_dataset', 'OGBGPPADataset'),
    'RotatedMNIST': ('gds.datasets.rotated_mnist_dataset', 'RotatedMNISTDataset'),
    'ColoredMNIST': ('gds.datasets.colored_mnist_dataset', 'ColoredMNISTDataset'),
    'SBM-Environment': ('gds.datasets.sbm_environment_dataset', 'SBMEnvironmentDataset'),
    'SBM-Isolation': ('gds.datasets.sbm_isolation_dataset', 'SBMIsolationDataset'),
    'UPFD': ('gds.datasets.upfd_dataset', 'UPFDDataset'),
}


def get_dataset(dataset, version=None, **dataset_kwargs):
    """Instantiate and return the requested dataset.

    Args:
        dataset (str): Name of the dataset; must appear in
            ``gds.supported_datasets``.
        version (str, optional): Dataset version number, e.g. ``'1.0'``.
            Non-string values are converted with ``str()``. ``None`` is
            forwarded unchanged to the dataset constructor (the original
            documentation states this selects the latest version).
        **dataset_kwargs: Other keyword arguments passed through to the
            dataset constructor.

    Returns:
        An instance of the dataset class registered for ``dataset``
        (not the class itself).

    Raises:
        ValueError: If ``dataset`` is not in ``gds.supported_datasets``.
        NotImplementedError: If ``dataset`` is listed in
            ``gds.supported_datasets`` but has no entry in this function's
            registry.
    """
    import importlib

    if version is not None:
        version = str(version)

    if dataset not in gds.supported_datasets:
        raise ValueError(
            f'The dataset {dataset} is not recognized. '
            f'Must be one of {gds.supported_datasets}.'
        )

    try:
        module_name, class_name = _DATASET_CLASSES[dataset]
    except KeyError:
        # Reachable only when gds.supported_datasets and the registry above
        # drift apart; say so explicitly instead of raising a bare error.
        raise NotImplementedError(
            f'The dataset {dataset} is listed in gds.supported_datasets '
            f'but no loader is registered for it in get_dataset.'
        ) from None

    # Import lazily so that only the requested dataset's dependencies load.
    dataset_cls = getattr(importlib.import_module(module_name), class_name)
    return dataset_cls(version=version, **dataset_kwargs)