⬅ evaluate.py source

1 import argparse
2 import json
3 import numpy as np
4 import os
5 import torch
6 import urllib.request
7 from ast import literal_eval
8 from typing import Dict, List
9 from urllib.parse import urlparse
10  
11 from gds import benchmark_datasets
12 from gds import get_dataset
13 from gds.datasets.wilds_dataset import GDSDataset, GDSSubset
14  
15 """
16 Evaluate predictions for WILDS datasets.
17  
18 Usage:
19  
  • E501 Line too long (96 > 79 characters)
20 python examples/evaluate.py <Path to directory with predictions> <Path to output directory>
  • E501 Line too long (124 > 79 characters)
21 python examples/evaluate.py <Path to directory with predictions> <Path to output directory> --dataset <A WILDS dataset>
22  
23 """
24  
25  
  • E501 Line too long (82 > 79 characters)
def evaluate_all_benchmarks(predictions_dir: str, output_dir: str, root_dir: str):
    """
    Evaluate predictions for all the WILDS benchmarks.

    Each dataset listed in ``benchmark_datasets`` is evaluated with
    ``evaluate_benchmark``, reading its predictions from the sub-directory
    ``<predictions_dir>/<dataset>``. The per-dataset results are collected
    and written to ``<output_dir>/all_results.json``.

    A dataset whose evaluation raises is reported on stdout and left out of
    the aggregated file; the remaining datasets are still evaluated.

    Parameters:
        predictions_dir (str): Path to the directory that holds one
            sub-directory of predictions per dataset
        output_dir (str): Output directory (created if it does not exist)
        root_dir (str): The directory where datasets can be found
    """
    # Create the output directory before doing any work, so a missing
    # directory does not make the final write fail after every dataset has
    # already been evaluated.
    os.makedirs(output_dir, exist_ok=True)

    all_results: Dict[str, Dict[str, Dict[str, float]]] = dict()
    for dataset in benchmark_datasets:
        try:
            all_results[dataset] = evaluate_benchmark(
                dataset, os.path.join(predictions_dir, dataset), output_dir, root_dir
            )
        except Exception as e:
            # Deliberate best effort: one failing dataset must not abort the
            # evaluation of the others.
            print(f"Could not evaluate predictions for {dataset}:\n{str(e)}")

    # Write out aggregated results to output file
    print(f"Writing complete results to {output_dir}...")
    with open(os.path.join(output_dir, "all_results.json"), "w") as f:
        json.dump(all_results, f, indent=4)
48  
49  
def evaluate_benchmark(
    dataset_name: str, predictions_dir: str, output_dir: str, root_dir: str
) -> Dict[str, Dict[str, float]]:
    """
    Evaluate across multiple replicates for a single benchmark.

    For every split except "train", the prediction file of each replicate
    is located in ``predictions_dir``, evaluated, and the per-replicate
    metric values are aggregated into a mean and a sample standard deviation
    (``ddof=1``). The aggregated results are also written to
    ``<output_dir>/<dataset_name>_results.json``.

    The metrics that are aggregated are the numeric entries of the results
    dictionary returned by the dataset's own evaluation, so no per-dataset
    metric list has to be maintained here.

    Parameters:
        dataset_name (str): Name of the dataset. See datasets.py for the
            complete list of datasets.
        predictions_dir (str): Path to the directory with predictions.
            NOTE(review): the directory is listed with ``os.listdir``, so it
            has to be a local directory even though individual CSV files are
            later read through a URL-aware helper.
        output_dir (str): Output directory (created if it does not exist)
        root_dir (str): The directory where datasets can be found

    Returns:
        Dictionary keyed by split; each value maps ``<metric>`` to the mean
        over replicates and ``<metric>_std`` to the standard deviation.

    Raises:
        FileNotFoundError: If a replicate has no matching prediction file.
    """

    def get_replicates(dataset_name: str) -> List[str]:
        # "poverty" is replicated over folds; everything else over seeds.
        if dataset_name == "poverty":
            return [f"fold-{fold}" for fold in ["A", "B", "C", "D", "E"]]
        if dataset_name == "camelyon17":
            seeds = range(0, 10)
        elif dataset_name == "civilcomments":
            seeds = range(0, 5)
        else:
            seeds = range(0, 3)
        return [f"seed-{seed}" for seed in seeds]

    def get_prediction_file(
        predictions_dir: str, dataset_name: str, split: str, replicate: str
    ) -> str:
        # Return the name (not the full path) of the first CSV/pth file whose
        # name starts with the run identifier.
        run_id = f"{dataset_name}_split-{split}_{replicate}"
        for file in os.listdir(predictions_dir):
            if file.startswith(run_id) and file.endswith((".csv", ".pth")):
                return file
        raise FileNotFoundError(
            f"Could not find CSV or pth prediction file that starts with {run_id}."
        )

    # Dataset will only be downloaded if it does not exist
    wilds_dataset: GDSDataset = get_dataset(
        dataset=dataset_name, root_dir=root_dir, download=True
    )
    splits: List[str] = list(wilds_dataset.split_dict.keys())
    if "train" in splits:
        splits.remove("train")

    replicates_results: Dict[str, Dict[str, List[float]]] = dict()
    replicates: List[str] = get_replicates(dataset_name)

    # Store the results for each replicate
    for split in splits:
        split_results: Dict[str, List[float]] = {}
        replicates_results[split] = split_results

        for replicate in replicates:
            predictions_file = get_prediction_file(
                predictions_dir, dataset_name, split, replicate
            )
            print(
                f"Processing split={split}, replicate={replicate}, "
                f"predictions_file={predictions_file}..."
            )
            full_path = os.path.join(predictions_dir, predictions_file)

            # GlobalWheat's predictions are a list of dictionaries, so it has
            # to be handled separately
            if dataset_name == "globalwheat":
                metric_results: Dict[str, float] = evaluate_replicate_for_globalwheat(
                    wilds_dataset, split, full_path
                )
            else:
                predicted_labels: torch.Tensor = get_predictions(full_path)
                metric_results = evaluate_replicate(
                    wilds_dataset, split, predicted_labels
                )

            # The metric names come from the evaluation result itself. Only
            # numeric values are kept because they are averaged below.
            for metric, value in metric_results.items():
                if isinstance(value, (int, float, np.number)):
                    split_results.setdefault(metric, []).append(value)

    aggregated_results: Dict[str, Dict[str, float]] = dict()

    # Aggregate results of replicates
    for split in splits:
        aggregated_results[split] = {}
        for metric, replicates_metric_values in replicates_results[split].items():
            # float() keeps the returned/serialized values plain Python floats
            aggregated_results[split][f"{metric}_std"] = float(
                np.std(replicates_metric_values, ddof=1)
            )
            aggregated_results[split][metric] = float(
                np.mean(replicates_metric_values)
            )

    # Write out aggregated results to output file
    print(f"Writing aggregated results for {dataset_name} to {output_dir}...")
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, f"{dataset_name}_results.json"), "w") as f:
        json.dump(aggregated_results, f, indent=4)

    return aggregated_results
151  
152  
def evaluate_replicate(
    dataset: GDSDataset, split: str, predicted_labels: torch.Tensor
) -> Dict[str, float]:
    """
    Evaluate the given predictions and return the appropriate metrics.

    The true labels and metadata are taken from the subset of ``dataset``
    that corresponds to ``split``. When the shape of the predictions differs
    from the shape of the true labels, a trailing dimension is appended to
    the predictions before evaluation (presumably to match labels stored
    with a trailing singleton dimension - TODO confirm).

    Parameters:
        dataset (GDSDataset): A WILDS Dataset
        split (str): split we are evaluating on
        predicted_labels (torch.Tensor): Predictions. The tensor passed in
            is not modified.

    Returns:
        Metrics as a dictionary with metrics as the keys and metric values
        as the values (the first element of what ``dataset.eval`` returns)
    """
    subset: GDSSubset = dataset.get_subset(split)
    metadata: torch.Tensor = subset.metadata_array
    true_labels = subset.y_array
    if predicted_labels.shape != true_labels.shape:
        # Out-of-place unsqueeze: the caller's tensor must keep its shape.
        predicted_labels = predicted_labels.unsqueeze(-1)
    return dataset.eval(predicted_labels, true_labels, metadata)[0]
174  
175  
def evaluate_replicate_for_globalwheat(
    dataset: GDSDataset, split: str, path_to_predictions: str
) -> Dict[str, float]:
    """
    Evaluate GlobalWheat predictions that were saved to a file.

    Parameters:
        dataset (GDSDataset): A WILDS Dataset
        split (str): split we are evaluating on
        path_to_predictions (str): Path to a file readable by ``torch.load``

    Returns:
        The first element of what ``dataset.eval`` returns for the loaded
        predictions, the true labels of the split and the split's metadata
    """
    # NOTE(review): torch.load unpickles the file; only load prediction files
    # from a trusted source.
    loaded_predictions = torch.load(path_to_predictions)

    split_subset: GDSSubset = dataset.get_subset(split)
    split_metadata: torch.Tensor = split_subset.metadata_array

    # Gather the labels one by one from the underlying full dataset, following
    # the subset's indices.
    ground_truth = []
    for index in split_subset.indices:
        ground_truth.append(split_subset.dataset.y_array[index])

    evaluation = dataset.eval(loaded_predictions, ground_truth, split_metadata)
    return evaluation[0]
184  
185  
def get_predictions(path: str) -> torch.Tensor:
    """
    Extract out the predictions from the file at path.

    Every non-blank line is parsed as one Python literal with
    ``ast.literal_eval``; the parsed values are stacked into a NumPy array
    and converted to a tensor.

    Parameters:
        path (str): Path to the file that has the predicted labels. Can be a URL.

    Return:
        Tensor representing predictions
    """
    if is_path_url(path):
        # The HTTP response yields bytes, which literal_eval rejects, so the
        # payload is decoded to text first. The context manager closes the
        # connection.
        with urllib.request.urlopen(path) as response:
            lines = response.read().decode("utf-8").splitlines()
    else:
        with open(path, mode="r") as file:
            lines = file.readlines()

    # Strip once per line and skip blank lines
    stripped_lines = (line.strip() for line in lines)
    predicted_labels = [literal_eval(line) for line in stripped_lines if line]
    return torch.from_numpy(np.array(predicted_labels))
205  
206  
def is_path_url(path: str) -> bool:
    """
    Returns True if the path is a URL.

    A path counts as a URL only when it parses into a non-empty scheme,
    network location and path component; e.g. "https://host/file.csv" is a
    URL while "/tmp/file.csv" and "https://host" are not.

    Parameters:
        path (str): The path to check

    Returns:
        True if ``path`` is a URL, False otherwise (including when it cannot
        be parsed at all)
    """
    try:
        result = urlparse(path)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed URL (e.g. unbalanced IPv6 brackets);
        # AttributeError/TypeError: the argument is not string-like.
        return False
    return all([result.scheme, result.netloc, result.path])
216  
217  
def main(argv=None):
    """
    Command-line entry point: parse the arguments and run the evaluation.

    When ``--dataset`` is given only that dataset is evaluated; otherwise
    all benchmark datasets are evaluated.

    Parameters:
        argv (list of str, optional): Arguments to parse. Defaults to None,
            in which case argparse reads them from ``sys.argv``.
    """
    # The parser is built here rather than at module level so that main()
    # does not depend on a global `args` that only exists when the file is
    # executed as a script.
    parser = argparse.ArgumentParser(
        description="Evaluate predictions for WILDS datasets."
    )
    parser.add_argument(
        "predictions_dir",
        type=str,
        help="Path to prediction CSV or pth files.",
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Path to output directory.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=benchmark_datasets,
        help="WILDS dataset to evaluate for.",
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        default="data",
        help="The directory where the datasets can be found (or should be downloaded to, if they do not exist).",
    )
    args = parser.parse_args(argv)

    if args.dataset:
        evaluate_benchmark(
            args.dataset, args.predictions_dir, args.output_dir, args.root_dir
        )
    else:
        print("A dataset was not specified. Evaluating for all WILDS datasets...")
        evaluate_all_benchmarks(args.predictions_dir, args.output_dir, args.root_dir)
    print("\nDone.")


if __name__ == "__main__":
    main()