⬅ algorithms/ERM.py source

1 from algorithms.single_model_algorithm import SingleModelAlgorithm
2 from models.initializer import initialize_model
3  
4  
class ERM(SingleModelAlgorithm):
    """Plain loss-minimizing algorithm built on ``SingleModelAlgorithm``.

    The constructor initializes a model, moves it to ``config.device`` and
    forwards it, together with the other constructor arguments, to the
    parent class. The training objective is the configured loss on the
    batch predictions versus targets; this class adds no extra terms.
    """

    def __init__(
        self,
        config,
        d_out,
        grouper,
        loss,
        metric,
        n_train_steps,
    ):
        """Create the model and initialize the parent algorithm.

        Args:
            config: Run configuration. It is passed to ``initialize_model``
                and to the parent constructor, and its ``device`` attribute
                selects where the model is placed.
            d_out: Forwarded to ``initialize_model`` (presumably the model's
                output size -- confirm in ``models.initializer``).
            grouper: Forwarded unchanged to the parent constructor.
            loss: Forwarded unchanged to the parent constructor; ``objective``
                later reads ``self.loss`` (assumes the parent stores it).
            metric: Forwarded unchanged to the parent constructor.
            n_train_steps: Forwarded unchanged to the parent constructor.
        """
        network = initialize_model(config, d_out)
        network = network.to(config.device)
        # Hand the prepared model and all remaining settings to the base class.
        super().__init__(
            config=config,
            model=network,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )

    def objective(self, results):
        """Compute the training loss for one batch of model outputs.

        Args:
            results: Mapping that must provide ``'y_pred'`` (predictions)
                and ``'y_true'`` (targets).

        Returns:
            Whatever ``self.loss.compute`` yields when called with
            ``return_dict=False`` (presumably a single loss value rather
            than a dict -- confirm against the loss class).
        """
        predictions = results['y_pred']
        targets = results['y_true']
        return self.loss.compute(predictions, targets, return_dict=False)