train.py

from tqdm import tqdm
import torch

# process_outputs_functions maps a config key to an output
# post-processing function; assumed to live in configs.supported,
# as in the WILDS reference implementation of this script.
from configs.supported import process_outputs_functions
from utils import (save_model, save_pred, get_pred_prefix,
                   get_model_prefix, detach_and_clone, collate_list)


def run_epoch(algorithm, dataset, general_logger, epoch, config, train):
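    """Run a single epoch of training or evaluation on one split.

    If train is True, the algorithm is updated on each batch;
    otherwise it is run in eval mode with gradients disabled.
    Returns the split's eval results dict and the collated
    predictions for the epoch.
    """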
    if dataset['verbose']:
        general_logger.write(f"\n{dataset['name']}:\n")

    if train:
        algorithm.train()
        torch.set_grad_enabled(True)
    else:
        algorithm.eval()
        torch.set_grad_enabled(False)

    # Not preallocating memory is slower, but it makes it easier to
    # handle different types of data loaders (which might not return
    # exactly the same number of examples per epoch).
    epoch_y_true = []
    epoch_y_pred = []
    epoch_metadata = []

    # Using enumerate(iterator) can sometimes leak memory in some
    # environments, so we increment batch_idx manually.
    batch_idx = 0
    iterator = (tqdm(dataset['loader']) if config.progress_bar
                else dataset['loader'])

    for batch in iterator:
        if train:
            batch_results = algorithm.update(batch)
        else:
            batch_results = algorithm.evaluate(batch)

        # These tensors should already be detached in batch_results,
        # but without cloning them they don't get garbage collected
        # properly in some versions; the extra detach is just for
        # safety.
        epoch_y_true.append(detach_and_clone(batch_results['y_true']))
        y_pred = detach_and_clone(batch_results['y_pred'])
        if config.process_outputs_function is not None:
            y_pred = process_outputs_functions[
                config.process_outputs_function](y_pred)
        epoch_y_pred.append(y_pred)
        epoch_metadata.append(detach_and_clone(batch_results['metadata']))

        if train and (batch_idx + 1) % config.log_every == 0:
            log_results(algorithm, dataset, general_logger, epoch, batch_idx)

        batch_idx += 1

    epoch_y_pred = collate_list(epoch_y_pred)
    epoch_y_true = collate_list(epoch_y_true)
    epoch_metadata = collate_list(epoch_metadata)

    results, results_str = dataset['dataset'].eval(
        epoch_y_pred,
        epoch_y_true,
        epoch_metadata)

    if config.scheduler_metric_split == dataset['split']:
        algorithm.step_schedulers(
            is_epoch=True,
            metrics=results,
            log_access=(not train))

    # Log after updating the scheduler, in case it needs to access
    # the internal logs.
    log_results(algorithm, dataset, general_logger, epoch, batch_idx)

    results['epoch'] = epoch
    dataset['eval_logger'].log(results)
    if dataset['verbose']:
        general_logger.write('Epoch eval:\n')
        general_logger.write(results_str)

    return results, epoch_y_pred


def train(algorithm, datasets, general_logger, result_logger, config,
          epoch_offset, best_val_metric):
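    """Train for config.n_epochs epochs, starting at epoch_offset.

    After each training epoch, evaluates on the val split, tracks
    the best config.val_metric seen so far, saves models and
    predictions as configured, and then evaluates any additional
    splits.
    """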
    for epoch in range(epoch_offset, config.n_epochs):
        general_logger.write('\nEpoch [%d]:\n' % epoch)

        # First run training.
        run_epoch(algorithm, datasets['train'], general_logger, epoch,
                  config, train=True)

        # Then run val.
        val_results, y_pred = run_epoch(algorithm, datasets['val'],
                                        general_logger, epoch, config,
                                        train=False)
        curr_val_metric = val_results[config.val_metric]
        general_logger.write(
            f'Validation {config.val_metric}: {curr_val_metric:.3f}\n')

        if best_val_metric is None:
            is_best = True
        else:
            if config.val_metric_decreasing:
                is_best = curr_val_metric < best_val_metric
            else:
                is_best = curr_val_metric > best_val_metric
        if is_best:
            best_val_metric = curr_val_metric
            general_logger.write(
                f'Epoch {epoch} has the best validation performance '
                'so far.\n')

        save_model_if_needed(algorithm, datasets['val'], epoch, config,
                             is_best, best_val_metric)
        save_pred_if_needed(y_pred, datasets['val'], epoch, config,
                            is_best)

        # Then run everything else.
        if config.evaluate_all_splits:
            additional_splits = [split for split in datasets.keys()
                                 if split not in ['train', 'val']]
        else:
            additional_splits = config.eval_splits
        for split in additional_splits:
            _, y_pred = run_epoch(algorithm, datasets[split],
                                  general_logger, epoch, config,
                                  train=False)
            save_pred_if_needed(y_pred, datasets[split], epoch, config,
                                is_best)

        general_logger.write('\n')


def evaluate(algorithm, datasets, epoch, general_logger, result_logger,
             config, is_best):
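    """Evaluate the algorithm on the configured splits, without
    training.

    Mirrors the eval path of run_epoch: collects predictions,
    labels, and metadata per batch, collates them, and logs
    per-split results.
    """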
    algorithm.eval()
    torch.set_grad_enabled(False)
    for split, dataset in datasets.items():
        if ((not config.evaluate_all_splits)
                and (split not in config.eval_splits)):
            continue
        epoch_y_true = []
        epoch_y_pred = []
        epoch_metadata = []
        iterator = (tqdm(dataset['loader']) if config.progress_bar
                    else dataset['loader'])
        for batch in iterator:
            batch_results = algorithm.evaluate(batch)
            epoch_y_true.append(detach_and_clone(batch_results['y_true']))
            y_pred = detach_and_clone(batch_results['y_pred'])
            if config.process_outputs_function is not None:
                y_pred = process_outputs_functions[
                    config.process_outputs_function](y_pred)
            epoch_y_pred.append(y_pred)
            epoch_metadata.append(
                detach_and_clone(batch_results['metadata']))

        epoch_y_pred = collate_list(epoch_y_pred)
        epoch_y_true = collate_list(epoch_y_true)
        epoch_metadata = collate_list(epoch_metadata)

        results, results_str = dataset['dataset'].eval(
            epoch_y_pred,
            epoch_y_true,
            epoch_metadata)

        results['epoch'] = epoch
        dataset['eval_logger'].log(results)
        general_logger.write(f'Eval split {split} at epoch {epoch}:\n')
        general_logger.write(results_str)

        # Skip saving train preds, since the train loader generally
        # shuffles the data.
        if split != 'train':
            save_pred_if_needed(epoch_y_pred, dataset, epoch, config,
                                is_best, force_save=True)


def log_results(algorithm, dataset, general_logger, epoch, batch_idx):
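    """Flush the algorithm's accumulated log for this epoch/batch
    to the dataset's algo logger, then reset it."""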
    if algorithm.has_log:
        log = algorithm.get_log()
        log['epoch'] = epoch
        log['batch'] = batch_idx
        dataset['algo_logger'].log(log)
        if dataset['verbose']:
            general_logger.write(algorithm.get_pretty_log_str())
        algorithm.reset_log()


def save_pred_if_needed(y_pred, dataset, epoch, config, is_best,
                        force_save=False):
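    """Save predictions at save_step intervals, at the last epoch,
    and/or at the best epoch, according to config."""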
    if config.save_pred:
        prefix = get_pred_prefix(dataset, config)
        if force_save or (config.save_step is not None
                          and (epoch + 1) % config.save_step == 0):
            save_pred(y_pred, prefix, f'epoch-{epoch}_pred')
        if (not force_save) and config.save_last:
            save_pred(y_pred, prefix, 'epoch-last_pred')
        if config.save_best and is_best:
            save_pred(y_pred, prefix, 'epoch-best_pred')


def save_model_if_needed(algorithm, dataset, epoch, config, is_best,
                         best_val_metric):
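    """Save model checkpoints at save_step intervals, at the last
    epoch, and/or at the best epoch, according to config."""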
    prefix = get_model_prefix(dataset, config)
    if config.save_step is not None and (epoch + 1) % config.save_step == 0:
        save_model(algorithm, epoch, best_val_metric, prefix,
                   f'epoch-{epoch}_model.pth')
    if config.save_last:
        save_model(algorithm, epoch, best_val_metric, prefix,
                   'epoch-last_model.pth')
    if config.save_best and is_best:
        save_model(algorithm, epoch, best_val_metric, prefix,
                   'epoch-best_model.pth')