⬅ common/grouper.py source

1 import warnings
2  
3 import numpy as np
4 import torch
5  
6 from gds.common.utils import get_counts
7 from gds.datasets.gds_dataset import GDSSubset
8  
9  
class Grouper:
    """
    Abstract base for objects that partition data points into groups
    according to their metadata.

    Groupers are used during training and evaluation, e.g., to report
    accuracy separately for each group of data.
    Concrete subclasses set ``self._n_groups`` and override every method
    below; the base class itself cannot be instantiated.
    """

    def __init__(self):
        # Abstract: subclasses provide their own constructor.
        raise NotImplementedError

    @property
    def n_groups(self):
        """
        Total number of groups this Grouper defines (read from
        ``self._n_groups``).
        """
        return self._n_groups

    def metadata_to_group(self, metadata, return_counts=False):
        """
        Map every data point to its group.

        Args:
            - metadata (Tensor): An n x d matrix holding d metadata
                                 fields for each of n points.
            - return_counts (bool): Whether to also return per-group
                                    counts.
        Output:
            - group (Tensor): An n-length vector of group ids.
            - group_counts (Tensor): Only when return_counts is True.
                                     An n_group-length integer vector
                                     with the number of points of
                                     `metadata` in each group.
        """
        raise NotImplementedError

    def group_str(self, group):
        """
        Describe a group in human-readable form.

        Args:
            - group (int): The integer id of one group.
        Output:
            - group_str (str): The pretty name of that group.
        """
        raise NotImplementedError

    def group_field_str(self, group):
        """
        Name a group with a plain (non-pretty) string.

        Args:
            - group (int): The integer id of one group.
        Output:
            - group_str (str): The name of that group.
        """
        raise NotImplementedError
58  
59  
class CombinatorialGrouper(Grouper):
    """
    Grouper that creates one group per combination of values of the
    selected metadata fields.
    """

    def __init__(self, dataset, groupby_fields):
        """
        CombinatorialGroupers form groups by taking all possible
        combinations of the metadata fields specified in groupby_fields.
        The fields are always combined in the order in which they appear
        in dataset.metadata_fields (the order inside groupby_fields does
        not matter), and the first of them varies fastest.
        For example, if:
            dataset.metadata_fields = ['country', 'time', 'y']
            groupby_fields = ['country', 'time']
        and if in dataset.metadata, country is in {0, 1} and time is in
        {0, 1, 2}, then the grouper will assign groups in the following
        way:
            country = 0, time = 0 -> group 0
            country = 1, time = 0 -> group 1
            country = 0, time = 1 -> group 2
            country = 1, time = 1 -> group 3
            country = 0, time = 2 -> group 4
            country = 1, time = 2 -> group 5

        If groupby_fields is None, then all data points are assigned to
        group 0.

        Args:
            - dataset: The full dataset (not a GDSSubset). Its
                       metadata_fields, metadata_array and metadata_map
                       attributes are read.
            - groupby_fields (list of str), or None

        Raises:
            ValueError: if dataset is a GDSSubset, if the number of
                        matched metadata fields differs from
                        len(groupby_fields), or if a selected field
                        contains a negative value.
        """
        if isinstance(dataset, GDSSubset):
            raise ValueError(
                "Grouper should be defined for the full dataset, "
                "not a subset")
        self.groupby_fields = groupby_fields

        if groupby_fields is None:
            self._n_groups = 1
            return

        # Columns are selected in dataset.metadata_fields order, which may
        # differ from the order of groupby_fields. Remember the field name
        # of every selected column so that values are always reported
        # under the field they actually belong to.
        matches = [
            (i, field) for (i, field) in enumerate(dataset.metadata_fields)
            if field in groupby_fields
        ]
        self.groupby_field_indices = [i for (i, _) in matches]
        self._fields_by_column = [field for (_, field) in matches]
        if len(self.groupby_field_indices) != len(self.groupby_fields):
            raise ValueError(
                'At least one group field not found in '
                'dataset.metadata_fields')

        grouped_metadata = dataset.metadata_array[
            :, self.groupby_field_indices]
        if grouped_metadata.dtype != torch.long:
            grouped_metadata_long = grouped_metadata.long()
            # Only warn when the cast actually changes some value.
            if not torch.all(grouped_metadata == grouped_metadata_long):
                warnings.warn(
                    'CombinatorialGrouper: converting metadata with fields '
                    f'[{", ".join(groupby_fields)}] into long')
            grouped_metadata = grouped_metadata_long

        for idx, field in enumerate(self._fields_by_column):
            min_value = grouped_metadata[:, idx].min()
            if min_value < 0:
                raise ValueError(
                    "Metadata for CombinatorialGrouper cannot have values "
                    f"less than 0: {field}, {min_value}")
            if min_value > 0:
                warnings.warn(
                    "Minimum metadata value for CombinatorialGrouper is "
                    f"not 0 ({field}, {min_value}). "
                    "This will result in empty groups")

        # We assume that the metadata fields are integers, so we can
        # measure the cardinality of each field by taking its max + 1.
        # Note that this might result in some empty groups.
        self.cardinality = 1 + torch.max(grouped_metadata, dim=0)[0]
        cumprod = torch.cumprod(self.cardinality, dim=0)
        self._n_groups = cumprod[-1].item()
        # factors[i] is the product of the cardinalities of columns
        # 0..i-1, i.e. the weight of column i in the group id.
        self.factors_np = np.concatenate(([1], cumprod[:-1]))
        self.factors = torch.from_numpy(self.factors_np)
        self.metadata_map = dataset.metadata_map

    def metadata_to_group(self, metadata, return_counts=False):
        """
        Compute the group id of every row of `metadata`.

        Args:
            - metadata (Tensor): An n x d matrix whose columns follow the
                                 same layout as the dataset's metadata.
            - return_counts (bool): If True, return group counts as well.
        Output:
            - groups (Tensor): An n-length vector of group ids.
            - group_counts: Only when return_counts is True; the result of
                            get_counts(groups, n_groups).
        """
        if self.groupby_fields is None:
            groups = torch.zeros(metadata.shape[0], dtype=torch.long)
        else:
            # Weighted sum of the selected columns: sum_i value_i * factor_i
            selected = metadata[:, self.groupby_field_indices].long()
            groups = selected @ self.factors

        if return_counts:
            group_counts = get_counts(groups, self._n_groups)
            return groups, group_counts
        return groups

    def group_str(self, group):
        """
        Args:
            - group (int): A single integer representing a group
                           (a plain integer, not a Tensor).
        Output:
            - group_str (str): 'all' when groupby_fields is None,
                               otherwise 'field = value' pairs joined
                               by ', ', listed from the slowest-varying
                               field to the fastest-varying one. Values
                               are translated through metadata_map when
                               it has an entry for the field.
        """
        if self.groupby_fields is None:
            return 'all'

        factors = self.factors_np
        n = len(factors)
        # Instances that were not built by the current __init__ (e.g.
        # restored from an older pickle) have no column-ordered names.
        fields = getattr(self, '_fields_by_column', self.groupby_fields)

        # The group id is g = a_1 * x_1 + a_2 * x_2 + ... + a_n * x_n,
        # where x_i are the factors and a_i the per-field values, so
        #   a_n     = g // x_n
        #   a_{n-1} = (g % x_n) // x_{n-1}
        #   a_{n-2} = (g % x_{n-1}) // x_{n-2}
        #   ...
        values = [
            int((group % factors[i + 1]) // factors[i])
            for i in range(n - 1)
        ]
        values.append(int(group // factors[n - 1]))

        parts = []
        for field, meta_val in zip(reversed(fields), reversed(values)):
            if self.metadata_map is not None and field in self.metadata_map:
                meta_val = self.metadata_map[field][meta_val]
            parts.append(f'{field} = {meta_val}')
        return ', '.join(parts)

    def group_field_str(self, group):
        """
        Compact form of group_str(group): '=' becomes ':', ',' becomes
        '_', and all spaces are removed.
        """
        pretty = self.group_str(group)
        return pretty.replace('=', ':').replace(',', '_').replace(' ', '')