⬅ configs/utils.py source

1 import copy
2  
3 from configs.algorithm import algorithm_defaults
4 from configs.data_loader import loader_defaults
5 from configs.datasets import dataset_defaults
6 from configs.model import model_defaults
7 from configs.scheduler import scheduler_defaults
8  
9  
def populate_defaults(config):
    """Populates hyperparameters with defaults implied by choices
    of other hyperparameters.

    Defaults are layered in this order, each layer only filling values
    that are still absent/None in ``config``: dataset, algorithm, data
    loader, model (only if ``config.model`` is set) and scheduler (only if
    ``config.scheduler`` is set).

    Args:
        config: namespace whose attributes are the hyperparameters. It is
            updated in place and also returned.

    Returns:
        The populated config.

    Raises:
        AssertionError: if ``dataset`` or ``algorithm`` is not specified, or
            a required field is still None after populating defaults
            (note: these checks are ``assert`` statements and are skipped
            under ``python -O``).
        ValueError: if the train loader is 'standard' but the caller
            explicitly set ``n_groups_per_batch`` or ``distinct_groups``.
        KeyError: presumably, if the dataset/algorithm/model/scheduler name
            has no entry in its defaults table (TODO confirm those tables
            are plain dicts).
    """
    # Deep-copy the caller's config before any defaults are filled in, so
    # the loader validations at the end can tell values the caller set
    # apart from values populated as defaults.
    orig_config = copy.deepcopy(config)
    assert config.dataset is not None, 'dataset must be specified'
    assert config.algorithm is not None, 'algorithm must be specified'

    # implied defaults from choice of dataset
    config = populate_config(config, dataset_defaults[config.dataset])

    # implied defaults from choice of algorithm
    config = populate_config(config, algorithm_defaults[config.algorithm])

    # implied defaults from choice of loader
    config = populate_config(config, loader_defaults)

    # implied defaults from choice of model
    if config.model:
        config = populate_config(config, model_defaults[config.model])

    # implied defaults from choice of scheduler
    if config.scheduler:
        config = populate_config(config, scheduler_defaults[config.scheduler])

    # misc implied defaults: group logging is turned off when no
    # groupby_fields are set; the flag is then normalized to a bool
    if config.groupby_fields is None:
        config.no_group_logging = True
    config.no_group_logging = bool(config.no_group_logging)

    # basic checks: each field must be set by the caller or by a default
    required_fields = [
        'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size',
        'eval_loader', 'model', 'loss_function',
        'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr',
        'weight_decay',
    ]
    for field in required_fields:
        assert getattr(config, field) is not None, (
            f"Must manually specify {field} for this setup."
        )

    # data loader validations
    # we only raise this error if the train_loader is standard, and
    # n_groups_per_batch or distinct_groups are
    # specified by the user (instead of populated as a default)
    if config.train_loader == 'standard':
        if orig_config.n_groups_per_batch is not None:
            raise ValueError(
                "n_groups_per_batch cannot be specified if the data loader "
                "is 'standard'. Consider using a 'group' data loader "
                "instead.")
        if orig_config.distinct_groups is not None:
            raise ValueError(
                "distinct_groups cannot be specified if the data loader "
                "is 'standard'. Consider using a 'group' data loader "
                "instead.")

    return config
73  
74  
def populate_config(config, template: dict, force_compatibility=False):
    """Populates missing (key, val) pairs in config with (key, val) in template.

    Example usage: populate config with defaults.

    A key counts as missing when it is absent from ``config`` or set to
    None; values already set in ``config`` are never overwritten.
    Dict-valued template entries (kwargs dicts) are merged one level deep:
    each of their sub-keys is filled into ``config.<key>`` by the same rule,
    and ``config.<key>`` is first created as a new empty dict if it is
    absent or None.

    Args:
        - config: namespace; updated in place through ``vars(config)``
        - template: dict of defaults, or None (config is returned unchanged)
        - force_compatibility: option to raise errors if
          config.key != template[key]

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        ValueError: if ``force_compatibility`` is true and a value already
            set in ``config`` differs from the template's value.
    """
    if template is None:
        return config

    d_config = vars(config)
    for key, val in template.items():
        if not isinstance(val, dict):
            # config[key] expected to be a non-index-able value
            if key not in d_config or d_config[key] is None:
                d_config[key] = val
            elif d_config[key] != val and force_compatibility:
                raise ValueError(f"Argument {key} must be set to {val}")
            continue

        # config[key] expected to be a kwarg dict. Create a fresh one when
        # it is absent/None; indexing into it below would otherwise raise.
        # A new dict (not the template's) avoids aliasing the defaults.
        if key not in d_config or d_config[key] is None:
            d_config[key] = {}
        kwargs = d_config[key]
        for kwargs_key, kwargs_val in val.items():
            if kwargs_key not in kwargs or kwargs[kwargs_key] is None:
                kwargs[kwargs_key] = kwargs_val
            elif kwargs[kwargs_key] != kwargs_val and force_compatibility:
                # report the conflicting sub-value, not the whole dict
                raise ValueError(
                    f"Argument {key}[{kwargs_key}] must be set to "
                    f"{kwargs_val}")
    return config