1 import warnings
2
3 import numpy as np
4 import torch
5
6 from gds.common.utils import get_counts
7 from gds.datasets.gds_dataset import GDSSubset
8
9
class Grouper:
    """
    Abstract interface for assigning data points to groups via metadata.

    A Grouper turns per-example metadata into integer group ids. Training
    and evaluation code use these ids, e.g. to compute per-group accuracies.
    Concrete subclasses override every method below and are expected to
    store the number of groups in ``self._n_groups``.
    """

    def __init__(self):
        # The base class is an interface only; it cannot be built directly.
        raise NotImplementedError()

    @property
    def n_groups(self):
        """Number of groups defined by this Grouper."""
        return self._n_groups

    def metadata_to_group(self, metadata, return_counts=False):
        """
        Map metadata rows to group ids.

        Args:
            - metadata (Tensor): n x d matrix; row j holds the d metadata
              fields of data point j.
            - return_counts (bool): When True, also return per-group counts.
        Output:
            - group (Tensor): Length-n vector with the group of each point.
            - group_counts (Tensor): Only returned if return_counts is True.
              Length-n_groups vector of integers counting how many points
              of the metadata fall in each group.
        """
        raise NotImplementedError()

    def group_str(self, group):
        """
        Human-readable description of one group.

        Args:
            - group (int): The group id.
        Output:
            - group_str (str): Pretty name of that group.
        """
        raise NotImplementedError()

    def group_field_str(self, group):
        """
        Compact name of one group (e.g. for use as a field/column name).

        Args:
            - group (int): The group id.
        Output:
            - group_field_str (str): Name of that group.
        """
        raise NotImplementedError()
49
50 def group_field_str(self, group):
51 """
52 Args:
53 - group (int): A single integer representing a group.
54 Output:
55 - group_str (str): A string containing the name of that group.
56 """
57 raise NotImplementedError
58
59
60 class CombinatorialGrouper(Grouper):
61 def __init__(self, dataset, groupby_fields):
62 """
-
E501
Line too long (93 > 79 characters)
63 CombinatorialGroupers form groups by taking all possible combinations of the metadata
64 fields specified in groupby_fields, in lexicographical order.
65 For example, if:
66 dataset.metadata_fields = ['country', 'time', 'y']
67 groupby_fields = ['country', 'time']
-
E501
Line too long (82 > 79 characters)
68 and if in dataset.metadata, country is in {0, 1} and time is in {0, 1, 2},
69 then the grouper will assign groups in the following way:
70 country = 0, time = 0 -> group 0
71 country = 1, time = 0 -> group 1
72 country = 0, time = 1 -> group 2
73 country = 1, time = 1 -> group 3
74 country = 0, time = 2 -> group 4
75 country = 1, time = 2 -> group 5
76
-
E501
Line too long (80 > 79 characters)
77 If groupby_fields is None, then all data points are assigned to group 0.
78
79 Args:
80 - dataset (WILDSDataset)
81 - groupby_fields (list of str)
82 """
83
84 if isinstance(dataset, GDSSubset):
-
E501
Line too long (92 > 79 characters)
85 raise ValueError("Grouper should be defined for the full dataset, not a subset")
86 self.groupby_fields = groupby_fields
87
88 if groupby_fields is None:
89 self._n_groups = 1
90 else:
91 # We assume that the metadata fields are integers,
-
E501
Line too long (84 > 79 characters)
92 # so we can measure the cardinality of each field by taking its max + 1.
93 # Note that this might result in some empty groups.
-
E501
Line too long (99 > 79 characters)
94 self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if
95 field in groupby_fields]
96 if len(self.groupby_field_indices) != len(self.groupby_fields):
-
E501
Line too long (97 > 79 characters)
97 raise ValueError('At least one group field not found in dataset.metadata_fields')
-
E501
Line too long (84 > 79 characters)
98 grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices]
99 if not isinstance(grouped_metadata, torch.LongTensor):
100 grouped_metadata_long = grouped_metadata.long()
101 if not torch.all(grouped_metadata == grouped_metadata_long):
102 warnings.warn(
-
E501
Line too long (121 > 79 characters)
103 f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long')
104 grouped_metadata = grouped_metadata_long
105 for idx, field in enumerate(self.groupby_fields):
106 min_value = grouped_metadata[:, idx].min()
107 if min_value < 0:
108 raise ValueError(
-
E501
Line too long (114 > 79 characters)
109 f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}")
110 if min_value > 0:
111 warnings.warn(
-
E501
Line too long (141 > 79 characters)
112 f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). This will result in empty groups")
113 self.cardinality = 1 + torch.max(
114 grouped_metadata, dim=0)[0]
115 cumprod = torch.cumprod(self.cardinality, dim=0)
116 self._n_groups = cumprod[-1].item()
117 self.factors_np = np.concatenate(([1], cumprod[:-1]))
118 self.factors = torch.from_numpy(self.factors_np)
119 self.metadata_map = dataset.metadata_map
120
121 def metadata_to_group(self, metadata, return_counts=False):
122 if self.groupby_fields is None:
123 groups = torch.zeros(metadata.shape[0], dtype=torch.long)
124 else:
-
E501
Line too long (82 > 79 characters)
125 groups = metadata[:, self.groupby_field_indices].long() @ self.factors
126
127 if return_counts:
128 group_counts = get_counts(groups, self._n_groups)
129 return groups, group_counts
130 else:
131 return groups
132
133 def group_str(self, group):
134 if self.groupby_fields is None:
135 return 'all'
136
137 # group is just an integer, not a Tensor
138 n = len(self.factors_np)
139 metadata = np.zeros(n)
140 for i in range(n - 1):
-
E501
Line too long (80 > 79 characters)
141 metadata[i] = (group % self.factors_np[i + 1]) // self.factors_np[i]
142 metadata[n - 1] = group // self.factors_np[n - 1]
143 group_name = ''
144 for i in reversed(range(n)):
145 meta_val = int(metadata[i])
146 if self.metadata_map is not None:
147 if self.groupby_fields[i] in self.metadata_map:
-
E501
Line too long (82 > 79 characters)
148 meta_val = self.metadata_map[self.groupby_fields[i]][meta_val]
149 group_name += f'{self.groupby_fields[i]} = {meta_val}, '
150 group_name = group_name[:-2]
151 return group_name
152
153 # a_n = S / x_n
154 # a_{n-1} = (S % x_n) / x_{n-1}
155 # a_{n-2} = (S % x_{n-1}) / x_{n-2}
156 # ...
157 #
158 # g =
159 # a_1 * x_1 +
160 # a_2 * x_2 + ...
161 # a_n * x_n
162
163 def group_field_str(self, group):
-
E501
Line too long (89 > 79 characters)
164 return self.group_str(group).replace('=', ':').replace(',', '_').replace(' ', '')