import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from gds.common.utils import split_into_groups


def get_train_loader(loader, dataset, batch_size,
                     uniform_over_groups=None, grouper=None,
                     distinct_groups=True, n_groups_per_batch=None,
                     **loader_kwargs):
11 """
12 Constructs and returns the data loader for training.
13 Args:
-
E501
Line too long (99 > 79 characters)
14 - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders,
-
E501
Line too long (104 > 79 characters)
15 which first samples groups and then samples a fixed number of examples belonging
16 to each group.
17 - dataset (WILDSDataset or WILDSSubset): Data
18 - batch_size (int): Batch size
-
E501
Line too long (97 > 79 characters)
19 - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according
20 to the natural data distribution.
-
E501
Line too long (108 > 79 characters)
21 Setting to None applies the defaults for each type of loaders.
-
E501
Line too long (108 > 79 characters)
22 For standard loaders, the default is False. For group loaders,
23 the default is True.
-
E501
Line too long (91 > 79 characters)
24 - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True
-
E501
Line too long (108 > 79 characters)
25 - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders.
-
E501
Line too long (99 > 79 characters)
26 - n_groups_per_batch (int): Number of groups to sample in each minibatch for group loaders.
27 - loader_kwargs: kwargs passed into torch DataLoader initialization.
28 Output:
29 - data loader (DataLoader): Data loader.
30 """
    if loader == 'standard':
        if uniform_over_groups is None or not uniform_over_groups:
            return DataLoader(
                dataset,
                shuffle=True,  # Shuffle training dataset
                sampler=None,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)
        else:
            assert grouper is not None
            groups, group_counts = grouper.metadata_to_group(
                dataset.metadata_array,
                return_counts=True)
            group_weights = 1 / group_counts
            weights = group_weights[groups]

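            # For example (hypothetical counts): with two groups of sizes
            # [90, 10], the per-example weights are [1/90, 1/10], so each
            # group contributes roughly half of the examples drawn.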
            # Replacement needs to be set to True, otherwise we'll run
            # out of minority samples
            sampler = WeightedRandomSampler(weights, len(dataset),
                                            replacement=True)
            return DataLoader(
                dataset,
                shuffle=False,  # The WeightedRandomSampler already shuffles
                sampler=sampler,
                collate_fn=dataset.collate,
                batch_size=batch_size,
                **loader_kwargs)

    elif loader == 'group':
        if uniform_over_groups is None:
            uniform_over_groups = True
        assert grouper is not None
        assert n_groups_per_batch is not None
        if n_groups_per_batch > grouper.n_groups:
            raise ValueError(
                f'n_groups_per_batch was set to {n_groups_per_batch} '
                f'but there are only {grouper.n_groups} groups specified.')

        group_ids = grouper.metadata_to_group(dataset.metadata_array)
        batch_sampler = GroupSampler(
            group_ids=group_ids,
            batch_size=batch_size,
            n_groups_per_batch=n_groups_per_batch,
            uniform_over_groups=uniform_over_groups,
            distinct_groups=distinct_groups)

        # batch_sampler is mutually exclusive with batch_size, shuffle,
        # sampler, and drop_last, so those arguments are left at values
        # the DataLoader treats as unset.
        return DataLoader(dataset,
                          shuffle=None,
                          sampler=None,
                          collate_fn=dataset.collate,
                          batch_sampler=batch_sampler,
                          drop_last=False,
                          **loader_kwargs)

    raise ValueError(f'Unknown loader type: {loader}')
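
# Example usage (a sketch; `train_data` and `grouper` are hypothetical
# stand-ins for a WILDSSubset and a Grouper from this codebase):
#
#     train_loader = get_train_loader(
#         'group', train_data, batch_size=16,
#         grouper=grouper, n_groups_per_batch=2)
#     for x, y_true, metadata in train_loader:
#         ...  # each batch holds 8 examples from each of 2 sampled groups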


def get_eval_loader(loader, dataset, batch_size, grouper=None,
                    **loader_kwargs):
85 """
86 Constructs and returns the data loader for evaluation.
87 Args:
88 - loader (str): Loader type. 'standard' for standard loaders.
89 - dataset (WILDSDataset or WILDSSubset): Data
90 - batch_size (int): Batch size
91 - loader_kwargs: kwargs passed into torch DataLoader initialization.
92 Output:
93 - data loader (DataLoader): Data loader.
94 """
95 if loader == 'standard':
96 return DataLoader(
97 dataset,
98 shuffle=False, # Do not shuffle eval datasets
99 sampler=None,
100 collate_fn=dataset.collate,
101 batch_size=batch_size,
102 **loader_kwargs)
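
# Example usage (a sketch; `val_data` is a hypothetical WILDSSubset):
#
#     val_loader = get_eval_loader('standard', val_data, batch_size=64)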


class GroupSampler:
    """
    Constructs batches by first sampling groups,
    then sampling data from those groups.
    It drops the last batch if it's incomplete.
    """

    def __init__(self, group_ids, batch_size, n_groups_per_batch,
                 uniform_over_groups, distinct_groups):

        if batch_size % n_groups_per_batch != 0:
            raise ValueError(
                f'batch_size ({batch_size}) must be evenly divisible by '
                f'n_groups_per_batch ({n_groups_per_batch}).')
        if len(group_ids) < batch_size:
            raise ValueError(
                f'The dataset has only {len(group_ids)} examples but the '
                f'batch size is {batch_size}. There must be enough '
                f'examples to form at least one complete batch.')

        self.group_ids = group_ids
        self.unique_groups, self.group_indices, unique_counts = \
            split_into_groups(group_ids)

        self.distinct_groups = distinct_groups
        self.n_groups_per_batch = n_groups_per_batch
        self.n_points_per_group = batch_size // n_groups_per_batch

        self.dataset_size = len(group_ids)
        self.num_batches = self.dataset_size // batch_size

        if uniform_over_groups:  # Sample uniformly over groups
            self.group_prob = None
        else:  # Sample a group proportionately to its size
            self.group_prob = (unique_counts.numpy()
                               / unique_counts.numpy().sum())

    def __iter__(self):
        for batch_id in range(self.num_batches):
            # Note that we are selecting group indices rather than groups
            groups_for_batch = np.random.choice(
                len(self.unique_groups),
                size=self.n_groups_per_batch,
                replace=(not self.distinct_groups),
                p=self.group_prob)
            sampled_ids = [
                np.random.choice(
                    self.group_indices[group],
                    size=self.n_points_per_group,
                    # Sample without replacement only if the group is
                    # larger than the sample size
                    replace=(len(self.group_indices[group])
                             <= self.n_points_per_group),
                    p=None)
                for group in groups_for_batch]

            # Flatten
            sampled_ids = np.concatenate(sampled_ids)
            yield sampled_ids

    def __len__(self):
        return self.num_batches
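

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the public API. The stub
    # classes below are hypothetical stand-ins for a WILDSDataset and a
    # Grouper, just detailed enough to exercise get_train_loader and
    # GroupSampler (assuming split_into_groups behaves as in WILDS,
    # returning unique groups, per-group index tensors, and counts).
    import torch

    class _ToyDataset:
        def __init__(self, n=64, n_groups=4):
            self.x = torch.randn(n, 3)
            self.metadata_array = torch.randint(0, n_groups, (n,))
            self.collate = None  # None -> DataLoader uses default_collate

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.metadata_array[idx]

    class _ToyGrouper:
        n_groups = 4

        def metadata_to_group(self, metadata, return_counts=False):
            groups = metadata
            if return_counts:
                counts = torch.bincount(groups, minlength=self.n_groups)
                return groups, counts.float()
            return groups

    data = _ToyDataset()
    loader = get_train_loader('group', data, batch_size=8,
                              grouper=_ToyGrouper(), n_groups_per_batch=2)
    x, metadata = next(iter(loader))
    # Expect x of shape (8, 3): 4 examples from each of 2 sampled groups
    print(x.shape, metadata)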