from torch.optim import SGD, Adam
from transformers import AdamW


def initialize_optimizer(config, model):
    # Initialize the optimizer specified by config.optimizer.
    if config.optimizer == 'SGD':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = SGD(
            params,
            lr=config.lr,
            weight_decay=config.weight_decay,
            **config.optimizer_kwargs)
    elif config.optimizer == 'AdamW':
        # For transformer models, exclude bias and LayerNorm weights from
        # weight decay, following standard BERT/GPT fine-tuning practice.
        if 'bert' in config.model or 'gpt' in config.model:
            no_decay = ['bias', 'LayerNorm.weight']
        else:
            no_decay = []

        params = [
            {'params': [p for n, p in model.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': config.weight_decay},
            {'params': [p for n, p in model.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(
            params,
            lr=config.lr,
            **config.optimizer_kwargs)
    elif config.optimizer == 'Adam':
        params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = Adam(
            params,
            lr=config.lr,
            weight_decay=config.weight_decay,
            **config.optimizer_kwargs)
    else:
        raise ValueError(f'Optimizer {config.optimizer} not recognized.')

    return optimizer
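

# Usage sketch (illustration only, not part of the original module):
# `config` must expose `optimizer`, `model`, `lr`, `weight_decay`, and
# `optimizer_kwargs`. The SimpleNamespace, the 'bert-base-uncased' name,
# and the nn.Linear stand-in below are hypothetical example values.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch.nn as nn

    config = SimpleNamespace(
        optimizer='AdamW',
        model='bert-base-uncased',  # triggers the no-decay parameter groups
        lr=2e-5,
        weight_decay=0.01,
        optimizer_kwargs={},
    )
    toy_model = nn.Linear(16, 2)  # stand-in for a real model
    optimizer = initialize_optimizer(config, toy_model)
    print(type(optimizer).__name__, optimizer.defaults.get('lr'))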