import argparse
import os
import time
from collections import defaultdict

try:
    # graph_tool is optional; it is imported early for its side effects when
    # installed (presumably to avoid clashes with torch), and any import
    # failure is silently ignored.
    import graph_tool  # noqa: F401
except Exception:
    pass

import torch
import torch.multiprocessing

import configs.supported as supported
import gds
from algorithms.initializer import initialize_algorithm
from configs.utils import populate_defaults
from train import train, evaluate
from utils import (set_seed, Logger, BatchLogger, log_config,
                   initialize_wandb, close_wandb, ParseKwargs, load,
                   log_group_data, parse_bool, get_model_prefix)
from gds.common.data_loaders import get_train_loader, get_eval_loader
from gds.common.grouper import CombinatorialGrouper
27  
28  
29  
  • E303 Too many blank lines (3)
30 def main():
31 """ to see default hyperparams for each dataset/model, look at configs/ """
32 parser = argparse.ArgumentParser()
33  
34 # Required arguments
  • E501 Line too long (89 > 79 characters)
35 parser.add_argument('-d', '--dataset', choices=gds.supported_datasets, required=True)
  • E501 Line too long (89 > 79 characters)
36 parser.add_argument('-a', '--algorithm', choices=supported.algorithms, required=True)
37 parser.add_argument('-m', '--model', choices=supported.models)
38 parser.add_argument('--seed', type=int)
39 parser.add_argument('--use_frac', type=parse_bool,
  • E128 Continuation line under-indented for visual indent
  • E501 Line too long (108 > 79 characters)
40 help='Convenience parameter that scales all dataset splits down to the specified fraction, '
  • E128 Continuation line under-indented for visual indent
  • E501 Line too long (111 > 79 characters)
41 'for development purposes. Note that this also scales the test set down, so the reported '
  • E128 Continuation line under-indented for visual indent
42 'numbers are not comparable with the full test set.')
  • W293 Blank line contains whitespace
43
  • W293 Blank line contains whitespace
44
  • E303 Too many blank lines (2)
45 # Resume
  • E501 Line too long (90 > 79 characters)
46 parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)
47  
48 # Dataset
49 parser.add_argument('--split_scheme',
  • E501 Line too long (117 > 79 characters)
50 help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
  • E501 Line too long (86 > 79 characters)
51 parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
  • E501 Line too long (92 > 79 characters)
52 parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',
  • E501 Line too long (105 > 79 characters)
53 help='If true, tries to downloads the dataset if it does not exist in root_dir.')
54 parser.add_argument('--version', default=None, type=str)
55  
56 # Loaders
  • E501 Line too long (85 > 79 characters)
57 parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
58 parser.add_argument('--train_loader', choices=['standard', 'group'])
  • E501 Line too long (88 > 79 characters)
59 parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')
  • E501 Line too long (84 > 79 characters)
60 parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')
61 parser.add_argument('--n_groups_per_batch', type=int)
62 parser.add_argument('--batch_size', type=int)
  • E501 Line too long (82 > 79 characters)
63 parser.add_argument('--eval_loader', choices=['standard'], default='standard')
64  
65 # Model
  • E501 Line too long (84 > 79 characters)
66 parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
  • E501 Line too long (108 > 79 characters)
67 help='keyword arguments for model initialization passed as key1=value1 key2=value2')
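    # ParseKwargs collects space-separated key=value pairs into a dict.
    # For example (hypothetical keys, shown only to illustrate the format):
    #   --model_kwargs num_layers=3 dropout=0.5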

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)
    parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs,
                        default={},
                        help='keyword arguments for loss initialization '
                             'passed as key1=value1 key2=value2')

    # Algorithm
    # To be tuned
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--flag_step_size', type=float)
    parser.add_argument('--dann_lambda', type=float)
    parser.add_argument('--mldg_beta', type=float)
    parser.add_argument('--gcl_aug_ratio', type=float)
    parser.add_argument('--parameter', type=float)

    # Not to be tuned
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')
    parser.add_argument('--gsn_id_type', type=str,
                        choices=['cycle_graph', 'path_graph',
                                 'complete_graph', 'binomial_tree'])
    parser.add_argument('--gsn_k', type=int)

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing', type=parse_bool,
                        const=True, nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split', choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--eval_only', type=parse_bool, const=True,
                        nargs='?', default=False)
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_epoch', default=None, type=int,
                        help='If eval_only is set, eval_epoch specifies the '
                             'epoch to evaluate at. By default, the best '
                             'epoch by validation performance is evaluated.')

    # Ablation
    parser.add_argument('--random_split', type=parse_bool, const=True,
                        nargs='?', default=False)

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--root_dir', default='./data',
                        help='The directory where [dataset]/data can be '
                             'found (or should be downloaded to, if it '
                             'does not exist).')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--save_last', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--save_pred', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--no_group_logging', type=parse_bool, const=True,
                        nargs='?')
    parser.add_argument('--use_wandb', type=parse_bool, const=True,
                        nargs='?', default=False)
    parser.add_argument('--progress_bar', type=parse_bool, const=True,
                        nargs='?', default=False)

    parser.add_argument('--jiuhai', action='store_true', default=False)

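    # Example invocation (illustrative values; -d/-a/-m must come from
    # gds.supported_datasets, supported.algorithms and supported.models):
    #   python run_expt.py -d ogb-molpcba -a deepCORAL -m gin_10_layers \
    #       --root_dir ./data --log_dir ./logs --seed 0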
    config = parser.parse_args()
    config = populate_defaults(config)

    # Hard-coded experiment constraints:
    if config.jiuhai:
        if config.dataset in {'ogb-molpcba', 'ogbg-ppa'}:
            if (config.algorithm in {'deepCORAL', 'FLAG', 'GCL'}
                    or config.model in {'gin_10_layers', 'cheb'}):
                raise ValueError(
                    "For Jiuhai's experiments, these are too slow, kill.")

    if config.algorithm == 'MLDG':
        assert config.dataset not in {'ogb-molpcba', 'ogbg-ppa'}

    if config.model == 'cheb' and config.algorithm == 'GCL':
        config.gcl_aug_ratio = 0.1

    # To speed up slow algorithms, rescale the number of epochs.
    if (config.algorithm in {'MLDG', 'FLAG'}
            and config.dataset != 'SBM-Isolation'):
        config.n_epochs = config.n_epochs // 2
    if config.algorithm in {'DANN', 'DANN-G'}:
        config.n_epochs *= 2
    if config.algorithm == 'IRM':
        config.n_epochs = int(config.n_epochs * 1.5)

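    # Normalize the per-algorithm tuned hyperparameter into a single
    # config.parameter field; below it is embedded in the log file names.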
    if config.algorithm == 'deepCORAL':
        config.parameter = config.coral_penalty_weight
    elif config.algorithm == 'DANN' or config.algorithm == 'DANN-G':
        config.parameter = config.dann_lambda
    elif config.algorithm == 'MLDG':
        config.parameter = config.mldg_beta
    elif config.algorithm == 'IRM':
        config.parameter = config.irm_lambda
    elif config.algorithm == 'FLAG':
        config.parameter = config.flag_step_size
    elif config.algorithm == 'GCL':
        config.parameter = config.gcl_aug_ratio
    else:
        config.parameter = None

    # For the 3wlgnn model, we need to set batch_size to 1
    if config.model == '3wlgnn':
        config.batch_size = 1

    # Set device
    config.device = (torch.device("cuda:" + str(config.device))
                     if torch.cuda.is_available() else torch.device("cpu"))
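    # Note: this selects a single GPU by index; multi-GPU training is not
    # handled here.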

    # Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'
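    # mode 'a' appends to existing logs when resuming or re-evaluating;
    # 'w' starts a fresh log.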

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(
        os.path.join(
            config.log_dir,
            f'{config.dataset}_{config.algorithm}_{config.parameter}'
            f'_{config.model}_seed-{config.seed}.txt'),
        mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    if config.algorithm == 'GSN':
        config.dataset_kwargs['gsn_id_type'] = config.gsn_id_type
        config.dataset_kwargs['gsn_k'] = config.gsn_k
    full_dataset = gds.get_dataset(
        dataset=config.dataset,
        version=config.version,
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        random_split=config.random_split,
        subgraph=(config.algorithm == 'GSN'),
        algorithm=config.algorithm,
        model=config.model,
        **config.dataset_kwargs)

    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields)

    datasets = defaultdict(dict)
    if config.use_wandb:
        wandb_runner = initialize_wandb(config)
    for split in full_dataset.split_dict.keys():
        # Only log verbosely for the train and val splits.
        verbose = split in ('train', 'val')

        # Get subset
        if config.use_frac:
            datasets[split]['dataset'] = full_dataset.get_subset(
                split, frac=config.default_frac)
        else:
            datasets[split]['dataset'] = full_dataset.get_subset(
                split, frac=1.0)
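        # Note: default_frac is not defined as a CLI flag above; presumably
        # populate_defaults supplies it when --use_frac is set.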

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        prefix = (f'{config.dataset}_{config.algorithm}_{config.parameter}'
                  f'_{config.model}_seed-{config.seed}_{split}')
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{prefix}_eval.csv'),
            mode=mode, use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{prefix}_algo.csv'),
            mode=mode, use_wandb=(config.use_wandb and verbose))

    # Logging dataset info: show the class breakdown if feasible.
    if (config.no_group_logging and full_dataset.is_classification
            and full_dataset.y_size == 1 and full_dataset.n_classes <= 10):
        log_grouper = CombinatorialGrouper(
            dataset=full_dataset,
            groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    # Initialize algorithm
    algorithm = initialize_algorithm(
        config=config,
        datasets=datasets,
        full_dataset=full_dataset,
        train_grouper=train_grouper)

    model_prefix = get_model_prefix(datasets['train'], config)

    if not config.eval_only:
        # Load saved results if resuming
        resume_success = False
        if resume:
            save_path = model_prefix.with_name(
                model_prefix.name + 'epoch-last_model.pth')
            if not os.path.exists(save_path):
                # Fall back to the latest numbered checkpoint in log_dir.
                # NOTE: this parses names containing 'epoch:', while the
                # paths built in this script use 'epoch-'; the two
                # conventions may not match.
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix.with_name(
                        model_prefix.name + f'epoch-{latest_epoch}_model.pth')
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
                epoch_offset = prev_epoch + 1
                logger.write(f'Resuming from epoch {epoch_offset} with best '
                             f'val metric {best_val_metric}')
                resume_success = True
            except FileNotFoundError:
                pass

        if not resume_success:
            epoch_offset = 0
            best_val_metric = None
        start = time.time()
        train(
            algorithm=algorithm,
            datasets=datasets,
            general_logger=logger,
            result_logger=logger,
            config=config,
            epoch_offset=epoch_offset,
            best_val_metric=best_val_metric)
    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix.with_name(
                model_prefix.name + 'epoch-best_model.pth')
        else:
            eval_model_path = model_prefix.with_name(
                model_prefix.name + f'epoch-{config.eval_epoch}_model.pth')
        best_epoch, best_val_metric = load(algorithm, eval_model_path)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch
        # Mark whether the evaluated epoch is the best one by validation
        # metric; defined unconditionally so it is never unbound.
        is_best = (epoch == best_epoch)
        evaluate(
            algorithm=algorithm,
            datasets=datasets,
            epoch=epoch,
            general_logger=logger,
            result_logger=logger,
            config=config,
            is_best=is_best)

    # have to close wandb runner before closing logger (and stdout)
    if config.use_wandb:
        close_wandb(wandb_runner)
    finish = time.time()
    if not config.eval_only:
        logger.write(f'time(s): {finish - start:.3f}\n')
    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()


if __name__ == '__main__':
    main()