import argparse
import os
import time
from collections import defaultdict

try:
    # Optional dependency; kept for its import-time side effects even though
    # it is not referenced directly below.
    import graph_tool  # noqa: F401
except Exception:
    pass
import torch
import torch.multiprocessing

import configs.supported as supported
import gds
from algorithms.initializer import initialize_algorithm
from configs.utils import populate_defaults
from train import train, evaluate
from utils import (set_seed, Logger, BatchLogger, log_config,
                   initialize_wandb, close_wandb, ParseKwargs, load,
                   log_group_data, parse_bool, get_model_prefix)
from gds.common.data_loaders import get_train_loader, get_eval_loader
from gds.common.grouper import CombinatorialGrouper
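
# Example invocation (illustrative values only; see gds.supported_datasets
# and configs.supported for the actual choices):
#   python main.py -d ogbg-ppa -a IRM -m cheb --seed 0 --root_dir ./data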


def main():
31 """ to see default hyperparams for each dataset/model, look at configs/ """
32 parser = argparse.ArgumentParser()
33
34 # Required arguments
-
E501
Line too long (89 > 79 characters)
35 parser.add_argument('-d', '--dataset', choices=gds.supported_datasets, required=True)
-
E501
Line too long (89 > 79 characters)
36 parser.add_argument('-a', '--algorithm', choices=supported.algorithms, required=True)
37 parser.add_argument('-m', '--model', choices=supported.models)
38 parser.add_argument('--seed', type=int)
39 parser.add_argument('--use_frac', type=parse_bool,
-
E128
Continuation line under-indented for visual indent
-
E501
Line too long (108 > 79 characters)
40 help='Convenience parameter that scales all dataset splits down to the specified fraction, '
-
E128
Continuation line under-indented for visual indent
-
E501
Line too long (111 > 79 characters)
41 'for development purposes. Note that this also scales the test set down, so the reported '
-
E128
Continuation line under-indented for visual indent
42 'numbers are not comparable with the full test set.')

    # Resume
    parser.add_argument('--resume', type=parse_bool, const=True, nargs='?',
                        default=False)

    # Dataset
    parser.add_argument('--split_scheme',
                        help='Identifies how the train/val/test split is '
                             'constructed. Choices are dataset-specific.')
    parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs,
                        default={})
    parser.add_argument('--download', default=False, type=parse_bool,
                        const=True, nargs='?',
                        help='If true, tries to download the dataset if it '
                             'does not exist in root_dir.')
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups', type=parse_bool, const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups', type=parse_bool, const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader', choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs,
                        default={},
                        help='keyword arguments for model initialization '
                             'passed as key1=value1 key2=value2')
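    # e.g. --model_kwargs num_layers=5 dropout=0.5 (illustrative key/value
    # pairs; the accepted keys depend on the chosen model's constructor)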

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)
    parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs,
                        default={},
                        help='keyword arguments for loss initialization '
                             'passed as key1=value1 key2=value2')

    # Algorithm
    # To be tuned
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--flag_step_size', type=float)
    parser.add_argument('--dann_lambda', type=float)
    parser.add_argument('--mldg_beta', type=float)
    parser.add_argument('--gcl_aug_ratio', type=float)
    parser.add_argument('--parameter', type=float)

    # Not to be tuned
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')
    parser.add_argument('--gsn_id_type', type=str,
                        choices=['cycle_graph', 'path_graph',
                                 'complete_graph', 'binomial_tree'])
    parser.add_argument('--gsn_k', type=int)

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing', type=parse_bool,
                        const=True, nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split', choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?',
                        default=False)
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_epoch', default=None, type=int,
                        help='If eval_only is set, then eval_epoch allows '
                             'you to specify evaluating at a particular '
                             'epoch. By default, it evaluates the best '
                             'epoch by validation performance.')

    # Ablation
    parser.add_argument('--random_split', type=parse_bool, const=True,
                        nargs='?', default=False)

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--root_dir', default='./data',
                        help='The directory where [dataset]/data can be '
                             'found (or should be downloaded to, if it does '
                             'not exist).')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--save_last', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--save_pred', type=parse_bool, const=True,
                        nargs='?', default=True)
    parser.add_argument('--no_group_logging', type=parse_bool, const=True,
                        nargs='?')
    parser.add_argument('--use_wandb', type=parse_bool, const=True,
                        nargs='?', default=False)
    parser.add_argument('--progress_bar', type=parse_bool, const=True,
                        nargs='?', default=False)

    parser.add_argument('--jiuhai', action='store_true', default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # Hard-coded constraints for Jiuhai's experiments
    if config.jiuhai:
        if config.dataset in {'ogb-molpcba', 'ogbg-ppa'}:
            if (config.algorithm in {'deepCORAL', 'FLAG', 'GCL'}
                    or config.model in {'gin_10_layers', 'cheb'}):
                raise ValueError(
                    "For Jiuhai's experiments, these are too slow, kill.")

    if config.algorithm == 'MLDG':
        assert config.dataset not in {'ogb-molpcba', 'ogbg-ppa'}

    if config.model == 'cheb' and config.algorithm == 'GCL':
        config.gcl_aug_ratio = 0.1

    # To speed up slow algorithms
    if (config.algorithm in ('MLDG', 'FLAG')
            and config.dataset != 'SBM-Isolation'):
        config.n_epochs //= 2
    if config.algorithm in ('DANN', 'DANN-G'):
        config.n_epochs *= 2
    if config.algorithm == 'IRM':
        config.n_epochs = int(config.n_epochs * 1.5)

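    # Expose the algorithm-specific tuned hyperparameter under a common name
    # so it can be embedded in the log file names below.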
    if config.algorithm == 'deepCORAL':
        config.parameter = config.coral_penalty_weight
    elif config.algorithm in ('DANN', 'DANN-G'):
        config.parameter = config.dann_lambda
    elif config.algorithm == 'MLDG':
        config.parameter = config.mldg_beta
    elif config.algorithm == 'IRM':
        config.parameter = config.irm_lambda
    elif config.algorithm == 'FLAG':
        config.parameter = config.flag_step_size
    elif config.algorithm == 'GCL':
        config.parameter = config.gcl_aug_ratio
    else:
        config.parameter = None

    # For the 3wlgnn model, we need to set batch_size to 1
    if config.model == '3wlgnn':
        config.batch_size = 1

    # Set device
    config.device = (torch.device('cuda:' + str(config.device))
                     if torch.cuda.is_available() else torch.device('cpu'))

    # Initialize logs
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(
        os.path.join(
            config.log_dir,
            f'{config.dataset}_{config.algorithm}_{config.parameter}'
            f'_{config.model}_seed-{config.seed}.txt'),
        mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    if config.algorithm == 'GSN':
        config.dataset_kwargs['gsn_id_type'] = config.gsn_id_type
        config.dataset_kwargs['gsn_k'] = config.gsn_k
    full_dataset = gds.get_dataset(
        dataset=config.dataset,
        version=config.version,
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        random_split=config.random_split,
        subgraph=(config.algorithm == 'GSN'),
        algorithm=config.algorithm,
        model=config.model,
        **config.dataset_kwargs)

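    # The grouper partitions examples by combinations of the given metadata
    # fields; it drives the group-aware train loader and group logging below.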
    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields)

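    # datasets[split] gathers everything needed per split: the subset itself,
    # its loader, display name, verbosity flag, and per-split loggers.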
    datasets = defaultdict(dict)
    if config.use_wandb:
        wandb_runner = initialize_wandb(config)
    for split in full_dataset.split_dict.keys():
        # Only log verbosely for the train and val splits
        verbose = split in ('train', 'val')

        # Get subset
        if config.use_frac:
            datasets[split]['dataset'] = full_dataset.get_subset(
                split, frac=config.default_frac)
        else:
            datasets[split]['dataset'] = full_dataset.get_subset(
                split, frac=1.0)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(
                config.log_dir,
                f'{config.dataset}_{config.algorithm}_{config.parameter}'
                f'_{config.model}_seed-{config.seed}_{split}_eval.csv'),
            mode=mode, use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(
                config.log_dir,
                f'{config.dataset}_{config.algorithm}_{config.parameter}'
                f'_{config.model}_seed-{config.seed}_{split}_algo.csv'),
            mode=mode, use_wandb=(config.use_wandb and verbose))

    # Logging dataset info
    # Show class breakdown if feasible
    if (config.no_group_logging and full_dataset.is_classification
            and full_dataset.y_size == 1 and full_dataset.n_classes <= 10):
        log_grouper = CombinatorialGrouper(
            dataset=full_dataset,
            groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    # Initialize algorithm
    algorithm = initialize_algorithm(
        config=config,
        datasets=datasets,
        full_dataset=full_dataset,
        train_grouper=train_grouper)

    model_prefix = get_model_prefix(datasets['train'], config)

    if not config.eval_only:
        # Load saved results if resuming
        resume_success = False
        if resume:
            save_path = model_prefix.with_name(
                model_prefix.name + 'epoch-last_model.pth')
            if not os.path.exists(save_path):
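                # No 'last' checkpoint; fall back to the newest numbered one,
                # assuming checkpoints are named like '...epoch:<N>_*.pth'
                # (the pattern parsed below).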
                epochs = [
                    int(file.split('epoch:')[1].split('_')[0])
                    for file in os.listdir(config.log_dir)
                    if file.endswith('.pth')]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = model_prefix.with_name(
                        model_prefix.name + f'epoch-{latest_epoch}_model.pth')
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
                epoch_offset = prev_epoch + 1
                logger.write(
                    f'Resuming from epoch {epoch_offset} '
                    f'with best val metric {best_val_metric}')
                resume_success = True
            except FileNotFoundError:
                pass

        if not resume_success:
            epoch_offset = 0
            best_val_metric = None
        start = time.time()
        train(
            algorithm=algorithm,
            datasets=datasets,
            general_logger=logger,
            result_logger=logger,
            config=config,
            epoch_offset=epoch_offset,
            best_val_metric=best_val_metric)
    else:
        if config.eval_epoch is None:
            eval_model_path = model_prefix.with_name(
                model_prefix.name + 'epoch-best_model.pth')
        else:
            eval_model_path = model_prefix.with_name(
                model_prefix.name + f'epoch-{config.eval_epoch}_model.pth')
        best_epoch, best_val_metric = load(algorithm, eval_model_path)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch
        # Mark whether we are evaluating the best-performing epoch
        is_best = (epoch == best_epoch)
        evaluate(
            algorithm=algorithm,
            datasets=datasets,
            epoch=epoch,
            general_logger=logger,
            result_logger=logger,
            config=config,
            is_best=is_best)

    # have to close wandb runner before closing logger (and stdout)
    if config.use_wandb:
        close_wandb(wandb_runner)
    finish = time.time()
    if not config.eval_only:
        logger.write(f'time(s): {finish - start:.3f}\n')
    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()


if __name__ == '__main__':
    main()