# common/utils.py

import numpy as np
import torch
from pandas.api.types import CategoricalDtype


def minimum(numbers, empty_val=0.):
    """NaN-ignoring minimum, with a fallback value for empty inputs."""
    if isinstance(numbers, torch.Tensor):
        if numbers.numel() == 0:
            return torch.tensor(empty_val, device=numbers.device)
        else:
            return numbers[~torch.isnan(numbers)].min()
    elif isinstance(numbers, np.ndarray):
        if numbers.size == 0:
            return np.array(empty_val)
        else:
            return np.nanmin(numbers)
    else:
        if len(numbers) == 0:
            return empty_val
        else:
            return min(numbers)


def maximum(numbers, empty_val=0.):
    """NaN-ignoring maximum, with a fallback value for empty inputs."""
    if isinstance(numbers, torch.Tensor):
        if numbers.numel() == 0:
            return torch.tensor(empty_val, device=numbers.device)
        else:
            return numbers[~torch.isnan(numbers)].max()
    elif isinstance(numbers, np.ndarray):
        if numbers.size == 0:
            return np.array(empty_val)
        else:
            return np.nanmax(numbers)
    else:
        if len(numbers) == 0:
            return empty_val
        else:
            return max(numbers)
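
# Illustrative usage (a sketch, not part of the original module): both helpers
# skip NaNs for tensor/ndarray inputs and fall back to `empty_val` when the
# input is empty; plain Python sequences are passed to min()/max() directly.
#   >>> minimum(torch.tensor([3.0, float('nan'), 1.0]))
#   tensor(1.)
#   >>> maximum(np.array([]))   # empty input -> empty_val
#   array(0.)
#   >>> minimum([5, 2, 9])
#   2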


def split_into_groups(g):
    """
    Args:
        - g (Tensor): Vector of groups
    Returns:
        - groups (Tensor): Unique groups present in g
        - group_indices (list): List of Tensors, where the i-th tensor is the
            indices of the elements of g that equal groups[i].
            Has the same length as len(groups).
        - unique_counts (Tensor): Counts of each element in groups.
            Has the same length as len(groups).
    """
    unique_groups, unique_counts = torch.unique(
        g, sorted=False, return_counts=True)
    group_indices = []
    for group in unique_groups:
        group_indices.append(
            torch.nonzero(g == group, as_tuple=True)[0])
    return unique_groups, group_indices, unique_counts
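
# Illustrative usage (a sketch): for g = [0, 1, 0, 2], the unique group ids,
# the indices belonging to each id, and the per-group counts are returned
# (the ordering of the unique groups is not guaranteed by torch.unique).
#   >>> g = torch.tensor([0, 1, 0, 2])
#   >>> groups, group_indices, counts = split_into_groups(g)
#   >>> groups, counts
#   (tensor([0, 1, 2]), tensor([2, 1, 1]))
#   >>> group_indices[0]   # positions where g == 0
#   tensor([0, 2])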


def get_counts(g, n_groups):
    """
    This differs from split_into_groups in how it handles missing groups.
    get_counts always returns a count Tensor of length n_groups,
    whereas split_into_groups returns a unique_counts Tensor
    whose length is the number of unique groups present in g.
    Args:
        - g (Tensor): Vector of groups
    Returns:
        - counts (Tensor): A Tensor of length n_groups, denoting the count of
            each group.
    """
    unique_groups, unique_counts = torch.unique(
        g, sorted=False, return_counts=True)
    counts = torch.zeros(n_groups, device=g.device)
    counts[unique_groups] = unique_counts.float()
    return counts
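
# Illustrative usage (a sketch): groups that never appear in g still get an
# entry (of zero) in the returned count vector, and counts are float-valued.
#   >>> get_counts(torch.tensor([0, 0, 2]), n_groups=4)
#   tensor([2., 0., 1., 0.])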


def avg_over_groups(v, g, n_groups):
    """
    Args:
        v (Tensor): Vector containing the quantity to average over.
        g (Tensor): Vector of the same length as v, containing group
            information.
        n_groups (int): Total number of groups.
    Returns:
        group_avgs (Tensor): Vector of length n_groups
        group_counts (Tensor)
    """
    import torch_scatter
    assert v.device == g.device
    assert v.numel() == g.numel()
    group_count = get_counts(g, n_groups)
    group_avgs = torch_scatter.scatter(
        src=v, index=g, dim_size=n_groups, reduce='mean')
    return group_avgs, group_count
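
# Illustrative usage (a sketch; requires the optional torch_scatter package).
# Groups with no elements get an average of 0 and a count of 0, e.g. for
#   v = torch.tensor([1.0, 3.0, 10.0]) and g = torch.tensor([0, 0, 1]),
#   avg_over_groups(v, g, n_groups=3) gives averages [2., 10., 0.]
#   and counts [2., 1., 0.].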


def map_to_id_array(df, ordered_map={}):
    """
    Converts each column of df to integer category codes. ordered_map can pin
    an explicit, ordered list of categories for a column; other columns use
    pandas' default category ordering (sorted unique values).
    """
    maps = {}
    array = np.zeros(df.shape)
    for i, c in enumerate(df.columns):
        if c in ordered_map:
            category_type = CategoricalDtype(
                categories=ordered_map[c], ordered=True)
        else:
            category_type = 'category'
        series = df[c].astype(category_type)
        maps[c] = series.cat.categories.values
        array[:, i] = series.cat.codes.values
    return maps, array
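
# Illustrative usage (a sketch): each column becomes pandas categorical codes;
# `ordered_map` pins an explicit category order for selected columns.
#   >>> import pandas as pd
#   >>> df = pd.DataFrame({'split': ['train', 'test', 'train']})
#   >>> maps, arr = map_to_id_array(
#   ...     df, ordered_map={'split': ['train', 'test']})
#   >>> maps['split']
#   array(['train', 'test'], dtype=object)
#   >>> arr[:, 0]
#   array([0., 1., 0.])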


def subsample_idxs(idxs, num=5000, take_rest=False, seed=None):
    # Use an offset of the provided seed for this helper's own RNG stream.
    seed = (seed + 541433) if seed is not None else None
    rng = np.random.default_rng(seed)

    idxs = idxs.copy()
    rng.shuffle(idxs)
    if take_rest:
        idxs = idxs[num:]
    else:
        idxs = idxs[:num]
    return idxs
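
# Illustrative usage (a sketch): with the same seed, take_rest=True returns
# exactly the complement of the subsample, which gives disjoint splits.
#   >>> idxs = np.arange(10)
#   >>> kept = subsample_idxs(idxs, num=6, seed=0)
#   >>> rest = subsample_idxs(idxs, num=6, take_rest=True, seed=0)
#   >>> np.sort(np.concatenate([kept, rest]))
#   array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])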


def shuffle_arr(arr, seed=None):
    seed = (seed + 548207) if seed is not None else None
    rng = np.random.default_rng(seed)

    arr = arr.copy()
    rng.shuffle(arr)
    return arr


def threshold_at_recall(y_pred, y_true, global_recall=60):
    """Calculate the model threshold to use to achieve a desired
    global_recall level. Assumes that y_true is a vector of the true binary
    labels."""
    return np.percentile(y_pred[y_true == 1], 100 - global_recall)
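
# Illustrative usage (a sketch): the returned threshold is the
# (100 - global_recall)-th percentile of the positives' predicted scores, so
# roughly global_recall% of positives score at or above it.
#   >>> y_pred = np.array([0.9, 0.8, 0.3, 0.2, 0.1])
#   >>> y_true = np.array([1, 1, 1, 0, 0])
#   >>> float(threshold_at_recall(y_pred, y_true, global_recall=50))
#   0.8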


def numel(obj):
    if torch.is_tensor(obj):
        return obj.numel()
    elif isinstance(obj, list):
        return len(obj)
    else:
        raise TypeError("Invalid type for numel")
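
# Illustrative usage (a sketch): a uniform element count across tensors and
# plain Python lists.
#   >>> numel(torch.zeros(2, 3))
#   6
#   >>> numel([1, 2, 3])
#   3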