from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR
from transformers import (get_linear_schedule_with_warmup,
                          get_cosine_schedule_with_warmup)


def initialize_scheduler(config, optimizer, n_train_steps):
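    """Construct the LR scheduler named by `config.scheduler`.

    Returns None when no scheduler is configured. Otherwise the scheduler
    is tagged with two bookkeeping flags: `step_every_batch` (step once per
    batch vs. once per epoch) and `use_metric` (whether stepping consumes a
    validation metric).
    """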
    if config.scheduler is None:
        return None
    elif config.scheduler == 'linear_schedule_with_warmup':
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_training_steps=n_train_steps,
            **config.scheduler_kwargs)
        step_every_batch = True
        use_metric = False
    elif config.scheduler == 'cosine_schedule_with_warmup':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_training_steps=n_train_steps,
            **config.scheduler_kwargs)
        step_every_batch = True
        use_metric = False
    elif config.scheduler == 'ReduceLROnPlateau':
        assert config.scheduler_metric_name, \
            f'scheduler metric must be specified for {config.scheduler}'
        scheduler = ReduceLROnPlateau(
            optimizer,
            **config.scheduler_kwargs)
        step_every_batch = False
        use_metric = True
    elif config.scheduler == 'StepLR':
        scheduler = StepLR(optimizer, **config.scheduler_kwargs)
        step_every_batch = False
        use_metric = False
    elif config.scheduler == 'MultiStepLR':
        scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs)
        step_every_batch = False
        use_metric = False
    else:
        raise ValueError(f'Scheduler {config.scheduler} not recognized.')
    # Tag the scheduler so the training loop knows how to step it.
    scheduler.step_every_batch = step_every_batch
    scheduler.use_metric = use_metric
    return scheduler


def step_scheduler(scheduler, metric=None):
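    """Advance `scheduler`, passing `metric` through when it is required."""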
    # ReduceLROnPlateau consumes a validation metric; all other schedulers
    # step unconditionally.
    if isinstance(scheduler, ReduceLROnPlateau):
        assert metric is not None, \
            'ReduceLROnPlateau requires a metric to step'
        scheduler.step(metric)
    else:
        scheduler.step()
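

# A minimal usage sketch, assuming a config object exposing the fields read
# above. `SimpleNamespace` stands in for the project's real config class,
# and the model/optimizer here are placeholders, not part of this module.
if __name__ == '__main__':
    from types import SimpleNamespace

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    config = SimpleNamespace(
        scheduler='linear_schedule_with_warmup',
        scheduler_kwargs={'num_warmup_steps': 100},
        scheduler_metric_name=None)
    scheduler = initialize_scheduler(config, optimizer, n_train_steps=1000)
    for _ in range(1000):
        optimizer.step()  # real code would backprop a loss first
        if scheduler is not None and scheduler.step_every_batch:
            step_scheduler(scheduler)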