1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
|
# Source repo: https://github.com/bknyaz/graph_attention_pool
# Compute superpixels for MNIST/CIFAR-10 using SLIC algorithm
# https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic
import argparse
import datetime
import multiprocessing as mp
import os
import pickle
import random
import numpy as np
import scipy
import scipy.ndimage
import scipy.spatial
import torch
from skimage.segmentation import slic
from torchvision import datasets
from generate_image_dataset import get_dataset_class
def parse_args():
    """Parse the command-line options of the superpixel extraction script.

    Every parsed option is echoed to stdout as ``name value`` before returning.

    Returns:
        argparse.Namespace: with attributes ``dataset``, ``data_dir``, ``out_dir``,
        ``split``, ``threads``, ``n_sp``, ``compactness`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Extract SLIC superpixels from images')
    parser.add_argument('-D', '--dataset', type=str, default='mnist',
                        choices=['mnist', 'cifar10', 'ColoredMNIST', 'RotatedMNIST'])
    parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to the dataset')
    parser.add_argument('-o', '--out_dir', type=str, default='./data', help='path where to save superpixels')
    parser.add_argument('-s', '--split', type=str, default='train', choices=['train', 'val', 'test'])
    parser.add_argument('-t', '--threads', type=int, default=0, help='number of parallel threads')
    parser.add_argument('-n', '--n_sp', type=int, default=75, help='max number of superpixels per image')
    # Must be float: the default (and the value recommended for MNIST) is 0.25,
    # which an int-typed option would reject when given on the command line.
    parser.add_argument('-c', '--compactness', type=float, default=0.25,
                        help='compactness of the SLIC algorithm '
                             '(Balances color proximity and space proximity): '
                             '0.25 is a good value for MNIST '
                             'and 10 for color images like CIFAR-10')
    parser.add_argument('--seed', type=int, default=111, help='seed for shuffling nodes')

    args = parser.parse_args()

    # Echo the effective configuration so that a run can be reproduced from its log.
    for arg in vars(args):
        print(arg, getattr(args, arg))

    return args
def _slic_labels(img, n_segments, compactness):
    """Run SLIC on ``img`` and return its label map, tolerating both scikit-image APIs."""
    has_channels = len(img.shape) > 2
    try:
        # Newer scikit-image: `channel_axis` replaces `multichannel`, and labels start
        # at 1 unless `start_label=0` is requested (process_image asserts 0-based labels).
        return slic(img, n_segments=n_segments, compactness=compactness,
                    channel_axis=-1 if has_channels else None, start_label=0)
    except TypeError:
        # Older scikit-image does not know these keywords; fall back to the legacy call.
        return slic(img, n_segments=n_segments, compactness=compactness, multichannel=has_channels)


def process_image(params):
    """Extract at most ``args.n_sp`` SLIC superpixels from one image and summarize them.

    Args:
        params: tuple ``(img, index, n_images, args, to_print, shuffle)``, packed into a
            single argument so the function can be passed to ``Pool.map``:
            ``img`` is a 2-D (single channel) or 3-D (channels last) array,
            ``index``/``n_images`` are only used for logging and assertion messages,
            ``args`` supplies ``n_sp``, ``dataset``, ``compactness`` and ``split``,
            ``to_print`` enables the per-image log line and ``shuffle`` randomizes the
            order of the superpixels (using the global NumPy RNG).

    Returns:
        tuple ``(sp_intensity, sp_coord, sp_order, superpixels)``:
        float32 per-channel mean value of every superpixel, float32 center of mass
        (row, col) of every superpixel, the int32 superpixel labels in the order used
        for the two previous arrays, and the label map produced by SLIC.

    Raises:
        AssertionError: if the extracted labels are not exactly ``0 .. n-1`` or their
            number is outside ``(0, args.n_sp]``.
    """
    img, index, n_images, args, to_print, shuffle = params

    # assert img.dtype == np.uint8, img.dtype
    # img = (img / 255.).astype(np.float32)

    n_sp_extracted = args.n_sp + 1  # number of actually extracted superpixels (can be different from requested in SLIC)
    # number of superpixels we ask to extract (larger to extract more superpixels - closer to the desired n_sp)
    is_mnist_like = args.dataset in ('mnist', 'RotatedMNIST', 'ColoredMNIST')
    n_sp_query = args.n_sp + (20 if is_mnist_like else 50)
    while n_sp_extracted > args.n_sp:
        superpixels = _slic_labels(img, n_sp_query, args.compactness)
        sp_indices = np.unique(superpixels)
        n_sp_extracted = len(sp_indices)
        n_sp_query -= 1  # reducing the number of superpixels until we get <= n superpixels

    assert n_sp_extracted <= args.n_sp and n_sp_extracted > 0, (args.split, index, n_sp_extracted, args.n_sp)
    assert n_sp_extracted == np.max(superpixels) + 1, (
        'superpixel indices', np.unique(superpixels))  # make sure superpixel indices are numbers from 0 to n-1

    if shuffle:
        ind = np.random.permutation(n_sp_extracted)
    else:
        ind = np.arange(n_sp_extracted)

    sp_order = sp_indices[ind].astype(np.int32)
    if len(img.shape) == 2:
        img = img[:, :, None]  # add a channel axis so that both cases share the code below

    n_ch = img.shape[2]

    sp_intensity, sp_coord = [], []
    for seg in sp_order:
        mask = (superpixels == seg).squeeze()
        avg_value = np.zeros(n_ch)
        for c in range(n_ch):
            avg_value[c] = np.mean(img[:, :, c][mask])
        # Public entry point; the `scipy.ndimage.measurements` namespace is deprecated.
        cntr = np.array(scipy.ndimage.center_of_mass(mask))  # row, col
        sp_intensity.append(avg_value)
        sp_coord.append(cntr)
    sp_intensity = np.array(sp_intensity, np.float32)
    sp_coord = np.array(sp_coord, np.float32)
    if to_print:
        print('image={}/{}, shape={}, min={:.2f}, max={:.2f}, n_sp={}'.format(index + 1, n_images, img.shape,
                                                                              img.min(), img.max(),
                                                                              sp_intensity.shape[0]))

    return sp_intensity, sp_coord, sp_order, superpixels
if __name__ == '__main__':
    # Script entry point: load the requested dataset, extract superpixels for every
    # image (optionally in parallel) and write the results to args.out_dir.

    dt = datetime.datetime.now()
    print('start time:', dt)

    args = parse_args()

    # makedirs also creates missing parent directories and tolerates an existing one.
    os.makedirs(args.out_dir, exist_ok=True)

    random.seed(args.seed)
    np.random.seed(args.seed)  # to make node random permutation reproducible (not tested)

    # Read image data using torchvision
    is_train = args.split.lower() == 'train'
    if args.dataset in ('mnist', 'cifar10'):
        if args.dataset == 'mnist':
            data = datasets.MNIST(args.data_dir, train=is_train, download=True)
            assert args.compactness < 10, ('high compactness can result in bad superpixels on MNIST')
            assert args.n_sp > 1 and args.n_sp < 28 * 28, (
                'the number of superpixels cannot exceed the total number of pixels or be too small')
        else:
            data = datasets.CIFAR10(args.data_dir, train=is_train, download=True)
            assert args.compactness > 1, ('low compactness can result in bad superpixels on CIFAR-10')
            assert args.n_sp > 1 and args.n_sp < 32 * 32, (
                'the number of superpixels cannot exceed the total number of pixels or be too small')

        # torchvision keeps the raw pixels in `data` and the class ids in `targets`.
        images = data.data
        if torch.is_tensor(images):
            images = images.numpy()
        # process_image no longer rescales its input, so map the 8-bit pixels to [0,1] here.
        images = (images / 255.).astype(np.float32)
        labels = np.asarray(data.targets)
        # These datasets have no environments: every image belongs to environment 0.
        data_envs = np.zeros(len(labels), dtype=int)
    elif args.dataset == 'ColoredMNIST' or args.dataset == 'RotatedMNIST':
        datasets_instance = get_dataset_class(args.dataset)(args.data_dir, None, None)
        # Local name chosen so that the imported torchvision `datasets` module is not shadowed.
        env_datasets = datasets_instance.datasets
        environments = datasets_instance.environments

        # Flatten the per-environment datasets, remembering each image's environment index.
        images, labels, data_envs = [], [], []
        for env_idx, (_, env_dataset) in enumerate(zip(environments, env_datasets)):
            for j in range(len(env_dataset)):
                image, label = env_dataset[j]
                images.append(image)
                labels.append(label.item())
                data_envs.append(env_idx)

        images = torch.stack(images)
        if args.dataset == 'ColoredMNIST':
            images = images.permute(0, 2, 3, 1)  # move axis 1 to the last position

        images = images.numpy()  # [0,1]
        # Drop singleton axes, but never the batch axis (a single-image dataset must stay batched).
        singleton_axes = tuple(ax for ax in range(1, images.ndim) if images.shape[ax] == 1)
        if singleton_axes:
            images = np.squeeze(images, axis=singleton_axes)

        labels = np.array(labels)
        data_envs = np.array(data_envs)
    else:
        raise NotImplementedError('unsupported dataset: ' + args.dataset)

    n_images = len(labels)
    jobs = [(images[i], i, n_images, args, True, True) for i in range(n_images)]

    if args.threads <= 0:
        sp_data = [process_image(job) for job in jobs]
    else:
        with mp.Pool(processes=args.threads) as pool:
            sp_data = pool.map(process_image, jobs)

    # The label maps are stored separately from the per-superpixel features.
    superpixels = [result[3] for result in sp_data]
    sp_data = [result[:3] for result in sp_data]

    with open('%s/%s_%dsp_%s.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f:
        pickle.dump((labels.astype(np.int32), sp_data), f, protocol=2)
    with open('%s/%s_%dsp_%s_superpixels.pkl' % (args.out_dir, args.dataset, args.n_sp, args.split), 'wb') as f:
        pickle.dump(superpixels, f, protocol=2)

    with open(f'{args.out_dir}/{args.dataset}_group.npy', 'wb') as f:
        np.save(f, data_envs)

    with open(f'{args.out_dir}/{args.dataset}_images.npy', 'wb') as f:
        np.save(f, images)

    print('done in {}'.format(datetime.datetime.now() - dt))
|