1 from algorithms.single_model_algorithm import SingleModelAlgorithm
2 from models.initializer import initialize_model
3
4
class GSN(SingleModelAlgorithm):
    """Single-model algorithm whose training objective is ``self.loss``.

    ``__init__`` builds the model with ``initialize_model`` and moves it to
    ``config.device``. It then forwards the model together with the grouper,
    loss, metric and step count to ``SingleModelAlgorithm.__init__``.
    """

    def __init__(self, config, d_out, grouper, loss,
                 metric, n_train_steps, full_dataset=None):
        """Build the model and initialise the base algorithm.

        Args:
            config: Run configuration. It must expose ``.device``, because
                the model is moved there. It is also passed to
                ``initialize_model`` and to the base class.
            d_out: Forwarded to ``initialize_model``. Presumably this is the
                model's output dimension; TODO confirm.
            grouper: Forwarded to ``SingleModelAlgorithm``.
            loss: Forwarded to ``SingleModelAlgorithm``. ``objective`` later
                calls ``self.loss.compute``.
            metric: Forwarded to ``SingleModelAlgorithm``.
            n_train_steps: Forwarded to ``SingleModelAlgorithm``.
            full_dataset: Optional. It is passed to ``initialize_model`` as
                the ``full_dataset`` keyword. How it is used is not visible
                here.
        """
        # Split across two statements to stay within the line limit (E501).
        model = initialize_model(config, d_out, full_dataset=full_dataset)
        model = model.to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )

    def objective(self, results):
        """Return the training objective for one batch of results.

        Args:
            results: Mapping that must contain the keys ``'y_pred'`` and
                ``'y_true'``.

        Returns:
            Whatever ``self.loss.compute`` returns when called with
            ``return_dict=False``. Presumably this is a scalar loss value;
            TODO confirm against the loss implementation.
        """
        return self.loss.compute(
            results['y_pred'], results['y_true'], return_dict=False
        )