Line too long (114 > 79 characters):
21 def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', random_split=False,Line too long (85 > 79 characters):
25 self.ogb_dataset = PyGColoredMNISTDataset(name='ColoredMNIST', root=root_dir)Line too long (138 > 79 characters):
32 self._y_type = 'float' # although the task is binary classification, the prediction target contains nan value, thus we need floatLine too long (93 > 79 characters):
42 metadata_file_path = os.path.join(self.ogb_dataset.raw_dir, 'ColoredMNIST_group.npy')Line too long (114 > 79 characters):
45 'https://www.dropbox.com/s/ax1ek3yc6n1q739/ColoredMNIST_group.zip?dl=1', self.ogb_dataset.raw_dir)Line too long (103 > 79 characters):
49 self._metadata_array_wo_y = torch.from_numpy(np.load(metadata_file_path)).reshape(-1, 1).long()Line too long (94 > 79 characters):
51 torch.unsqueeze(self.ogb_dataset.data.y, dim=1)), 1)Line too long (93 > 79 characters):
58 val_split_idx = random_index[int(3 / 6 * dataset_size):int(4 / 6 * dataset_size)]Line too long (113 > 79 characters):
63 train_val_group_idx, test_group_idx = torch.tensor(train_val_group_idx), torch.tensor(test_group_idx)Line too long (104 > 79 characters):
66 split_idx = torch.zeros(len(torch.squeeze(self._metadata_array_wo_y)), dtype=torch.bool)Line too long (82 > 79 characters):
68 split_idx += (torch.squeeze(self._metadata_array_wo_y) == idx)Line too long (107 > 79 characters):
71 train_val_split_idx, test_split_idx = split_idx(train_val_group_idx), split_idx(test_group_idx)Line too long (100 > 79 characters):
73 train_val_split_idx = torch.arange(dataset_size, dtype=torch.int64)[train_val_split_idx]Line too long (98 > 79 characters):
76 train_split_idx = train_val_split_idx[random_index[:int(3 / 4 * train_val_sets_size)]]Line too long (96 > 79 characters):
77 val_split_idx = train_val_split_idx[random_index[int(3 / 4 * train_val_sets_size):]]Line too long (91 > 79 characters):
104 - prediction_fn (function): A function that turns y_pred into predicted labels.Line too long (106 > 79 characters):
105 Only None is supported because OGB Evaluators accept binary logitsLine too long (83 > 79 characters):
135 adj = self._sym_normalize_adj(to_dense_adj(graph.edge_index).squeeze())Line too long (83 > 79 characters):
147 # adj_node_feat[3:, node, node] = F.one_hot(node_feat[2].long(), 8)Line too long (85 > 79 characters):
156 deg_inv = torch.where(deg > 0, 1. / torch.sqrt(deg), torch.zeros(deg.size()))