# utils.py

import argparse
import csv
import os
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch

try:
    import wandb
except ImportError:
    # wandb is optional; it is only needed when logging with use_wandb=True.
    wandb = None

def update_average(prev_avg, prev_counts, curr_avg, curr_counts):
    denom = prev_counts + curr_counts
    if isinstance(curr_counts, torch.Tensor):
        denom += (denom == 0).float()
    elif isinstance(curr_counts, (int, float)):
        if denom == 0:
            return 0.
    else:
        raise ValueError('Type of curr_counts not recognized')
    prev_weight = prev_counts / denom
    curr_weight = curr_counts / denom
    return prev_weight * prev_avg + curr_weight * curr_avg
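
# The Tensor branch above bumps zero denominators to 1 so empty groups keep
# an average of 0 instead of NaN. A minimal usage sketch (the numbers are
# illustrative, not from any training loop):
#
#     avg, n = 0., 0
#     for batch_avg, batch_n in [(2.0, 10), (4.0, 30)]:
#         avg = update_average(avg, n, batch_avg, batch_n)
#         n += batch_n
#     # avg == 3.5, i.e. (2.0 * 10 + 4.0 * 30) / 40
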

# Taken from
# https://sumit-ghosh.com/articles/parsing-dictionary-key-value-pairs-kwargs-argparse-python/
class ParseKwargs(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, dict())
        for value in values:
            key, value_str = value.split('=')
            if value_str.replace('-', '').isnumeric():
                processed_val = int(value_str)
            elif value_str.replace('-', '').replace('.', '').isnumeric():
                processed_val = float(value_str)
            elif value_str in ['True', 'true']:
                processed_val = True
            elif value_str in ['False', 'false']:
                processed_val = False
            else:
                processed_val = value_str
            getattr(namespace, self.dest)[key] = processed_val
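
# Example wiring (the flag name '--model_kwargs' is an assumption, not a
# fixed interface):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs,
#                         default={})
#     args = parser.parse_args(['--model_kwargs', 'dropout=0.5',
#                               'pretrained=True'])
#     # args.model_kwargs == {'dropout': 0.5, 'pretrained': True}
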

def parse_bool(v):
    if v.lower() == 'true':
        return True
    elif v.lower() == 'false':
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def save_model(algorithm, epoch, best_val_metric, prefix, suffix):
    path = prefix.parent / (prefix.name + suffix)
    state = {}
    state['algorithm'] = algorithm.state_dict()
    state['epoch'] = epoch
    state['best_val_metric'] = best_val_metric
    torch.save(state, path)


def load(algorithm, path):
    path = Path(path)
    state = torch.load(path)
    algorithm.load_state_dict(state['algorithm'])
    return state['epoch'], state['best_val_metric']
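
# Round-trip sketch (assumes `algorithm` exposes state_dict() /
# load_state_dict(), and that `prefix` is a Path such as one returned by
# get_model_prefix below; the values are illustrative):
#
#     save_model(algorithm, epoch=3, best_val_metric=0.87,
#                prefix=prefix, suffix='best_model.pth')
#     epoch, best_val = load(algorithm,
#                            prefix.parent / (prefix.name + 'best_model.pth'))
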

def log_group_data(datasets, grouper, logger):
    for k, dataset in datasets.items():
        name = dataset['name']
        dataset = dataset['dataset']
        logger.write(f'{name} data...\n')
        if grouper is None:
            logger.write(f'    n = {len(dataset)}\n')
        else:
            _, group_counts = grouper.metadata_to_group(
                dataset.metadata_array,
                return_counts=True)
            group_counts = group_counts.tolist()
            for group_idx in range(grouper.n_groups):
                logger.write(
                    f'    {grouper.group_str(group_idx)}: '
                    f'n = {group_counts[group_idx]:.0f}\n')
    logger.flush()


class Logger(object):
    def __init__(self, fpath=None, mode='w'):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            self.file = open(Path(fpath), mode)

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return self so `with Logger(...) as logger:` binds the logger.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()
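
# Usage sketch (the log path is illustrative): tee console output to a file.
#
#     logger = Logger('./logs/log.txt')
#     logger.write('Starting run\n')
#     logger.flush()
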

class BatchLogger:
    def __init__(self, csv_path, mode='w', use_wandb=False):
        self.path = Path(csv_path)
        self.mode = mode
        self.file = open(self.path, mode)
        self.is_initialized = False

        # Use Weights and Biases for logging
        self.use_wandb = use_wandb
        if use_wandb:
            self.split = self.path.stem

    def setup(self, log_dict):
        columns = log_dict.keys()
        # Move epoch and batch to the front if in the log_dict
        for key in ['batch', 'epoch']:
            if key in columns:
                columns = [key] + [k for k in columns if k != key]

        self.writer = csv.DictWriter(self.file, fieldnames=columns)
        if (self.mode == 'w'
                or not os.path.exists(self.path)
                or os.path.getsize(self.path) == 0):
            self.writer.writeheader()
        self.is_initialized = True

    def log(self, log_dict):
        if self.is_initialized is False:
            self.setup(log_dict)
        self.writer.writerow(log_dict)
        self.flush()

        if self.use_wandb:
            results = {}
            for key in log_dict:
                new_key = f'{self.split}/{key}'
                results[new_key] = log_dict[key]
            wandb.log(results)

    def flush(self):
        self.file.flush()

    def close(self):
        self.file.close()
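
# Usage sketch (path and columns are illustrative; the CSV column order is
# fixed by the first call to log()):
#
#     train_logger = BatchLogger('./logs/train_eval.csv')
#     train_logger.log({'epoch': 0, 'batch': 10, 'loss': 0.42})
#     train_logger.close()
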

def set_seed(seed):
    """Sets seed"""
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def log_config(config, logger):
    for name, val in vars(config).items():
        logger.write(f'{name.replace("_", " ").capitalize()}: {val}\n')
    logger.write('\n')


def initialize_wandb(config):
    name = f"{config.dataset}_{config.algorithm}_seed-{config.seed}"
    wandb_runner = wandb.init(name=name,
                              project='graphdg',
                              entity='graphnet',
                              config=config,
                              reinit=True)
    return wandb_runner


def close_wandb(wandb_runner):
    wandb_runner.finish()


def save_pred(y_pred, prefix, suffix):
    csv_path = prefix.parent / (prefix.name + suffix + '.csv')
    pth_path = prefix.parent / (prefix.name + suffix + '.pth')
    # Single tensor
    if torch.is_tensor(y_pred):
        df = pd.DataFrame(y_pred.numpy())
        df.to_csv(csv_path, index=False, header=False)
    # Dictionary or list
    elif isinstance(y_pred, (dict, list)):
        torch.save(y_pred, pth_path)
    else:
        raise TypeError("Invalid type for save_pred")
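
# Dispatch sketch (paths are illustrative): tensors go to CSV, dicts/lists
# go to .pth.
#
#     save_pred(torch.tensor([[0.1], [0.9]]),
#               prefix=Path('./logs/run_'), suffix='epoch-0')
#     # -> writes ./logs/run_epoch-0.csv
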

def get_replicate_str(dataset, config):
    if dataset['dataset'].dataset_name == 'poverty':
        replicate_str = f"fold-{config.dataset_kwargs['fold']}"
    else:
        replicate_str = f"seed-{config.seed}"
    return replicate_str


def get_pred_prefix(dataset, config):
    dataset_name = dataset['dataset'].dataset_name
    split = dataset['split']
    replicate_str = get_replicate_str(dataset, config)
    prefix = Path(config.log_dir) / (
        f"{dataset_name}_{config.algorithm}_{config.parameter}_"
        f"{config.model}_{replicate_str}_split-{split}_")

    return prefix


def get_model_prefix(dataset, config):
    dataset_name = dataset['dataset'].dataset_name
    replicate_str = get_replicate_str(dataset, config)
    prefix = Path(config.log_dir) / (
        f"{dataset_name}_{config.algorithm}_{config.parameter}_"
        f"{config.model}_{replicate_str}_")

    return prefix
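
# Resulting layout (config field values assumed):
#     <log_dir>/<dataset>_<algorithm>_<parameter>_<model>_seed-<seed>_
#     <log_dir>/<dataset>_<algorithm>_<parameter>_<model>_seed-<seed>_split-<split>_
# with 'seed-<seed>' replaced by 'fold-<fold>' for the poverty dataset.
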

def move_to(obj, device):
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    elif isinstance(obj, (float, int)):
        return obj
    else:
        # Assume obj is a Tensor or other type
        # (like Batch, for MolPCBA) that supports .to(device)
        return obj.to(device)
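
# Works on arbitrarily nested dicts/lists, e.g. (device string assumed):
#
#     batch = {'x': torch.zeros(2, 3), 'meta': [torch.ones(2), 1.0]}
#     batch = move_to(batch, 'cpu')
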

def detach_and_clone(obj):
    if torch.is_tensor(obj):
        return obj.detach().clone()
    elif isinstance(obj, dict):
        return {k: detach_and_clone(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [detach_and_clone(v) for v in obj]
    elif isinstance(obj, (float, int)):
        return obj
    else:
        raise TypeError("Invalid type for detach_and_clone")


def collate_list(vec):
    """
    If vec is a list of Tensors, it concatenates them all along the first
    dimension.

    If vec is a list of lists, it joins these lists together, but does not
    attempt to recursively collate. This allows each element of the list to
    be, e.g., its own dict.

    If vec is a list of dicts (with the same keys in each dict), it returns
    a single dict with the same keys. For each key, it recursively collates
    all entries in the list.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    elem = vec[0]
    if torch.is_tensor(elem):
        return torch.cat(vec)
    elif isinstance(elem, list):
        return [obj for sublist in vec for obj in sublist]
    elif isinstance(elem, dict):
        return {k: collate_list([d[k] for d in vec]) for k in elem}
    else:
        raise TypeError(
            "Elements of the list to collate must be tensors or dicts.")
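
# Example: collating per-batch dicts of tensors (shapes illustrative):
#
#     batches = [{'y': torch.zeros(4)}, {'y': torch.ones(2)}]
#     collate_list(batches)
#     # -> {'y': <tensor of shape (6,)>}
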

def remove_key(key):
    """
    Returns a function that strips out a key from a dict.
    """

    def remove(d):
        if not isinstance(d, dict):
            raise TypeError("remove_key must take in a dict")
        return {k: v for (k, v) in d.items() if k != key}

    return remove
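
# Example (the 'metadata' key is illustrative):
#
#     drop_meta = remove_key('metadata')
#     drop_meta({'y': 1, 'metadata': 0})
#     # -> {'y': 1}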