⬅ optimizer.py source

1 from torch.optim import SGD, Adam
2 from transformers import AdamW
3  
4  
def _trainable_parameters(model):
    """Return the parameters of ``model`` that have ``requires_grad=True``."""
    return [p for p in model.parameters() if p.requires_grad]


def _adamw_param_groups(config, model):
    """Split the trainable parameters of ``model`` into two AdamW groups.

    When ``config.model`` contains ``'bert'`` or ``'gpt'``, parameters whose
    name contains ``'bias'`` or ``'LayerNorm.weight'`` go into the second
    group with weight decay 0.0. Every other trainable parameter goes into
    the first group with ``config.weight_decay``. Both groups are always
    returned (the second may be empty), so the group layout does not depend
    on the model. Parameter order within each group follows
    ``model.named_parameters()``.
    """
    # config.model is presumably a model-name string; only substring
    # membership is checked here.
    if 'bert' in config.model or 'gpt' in config.model:
        no_decay = ('bias', 'LayerNorm.weight')
    else:
        no_decay = ()

    decay, skip_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # Keep frozen parameters out of the optimizer, matching the
            # SGD and Adam branches.
            continue
        if any(pattern in name for pattern in no_decay):
            skip_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': decay, 'weight_decay': config.weight_decay},
        {'params': skip_decay, 'weight_decay': 0.0},
    ]


def initialize_optimizer(config, model):
    """Build the optimizer selected by ``config.optimizer`` for ``model``.

    In every branch, only parameters with ``requires_grad=True`` are given
    to the optimizer.

    Args:
        config: Object exposing ``optimizer`` (``'SGD'``, ``'AdamW'`` or
            ``'Adam'``), ``lr``, ``weight_decay`` and ``optimizer_kwargs``
            (a mapping of extra keyword arguments forwarded to the optimizer
            constructor). The ``'AdamW'`` branch also reads ``model``, whose
            contents decide which parameters are exempt from weight decay.
        model: The module whose parameters will be optimized.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``config.optimizer`` is not one of the supported names.
    """
    if config.optimizer == 'SGD':
        optimizer = SGD(
            _trainable_parameters(model),
            lr=config.lr,
            weight_decay=config.weight_decay,
            **config.optimizer_kwargs)
    elif config.optimizer == 'AdamW':
        # Weight decay is set per parameter group, so it is not passed here.
        optimizer = AdamW(
            _adamw_param_groups(config, model),
            lr=config.lr,
            **config.optimizer_kwargs)
    elif config.optimizer == 'Adam':
        optimizer = Adam(
            _trainable_parameters(model),
            lr=config.lr,
            weight_decay=config.weight_decay,
            **config.optimizer_kwargs)
    else:
        raise ValueError(f'Optimizer {config.optimizer} not recognized.')

    return optimizer