import os
import joblib
import numpy as np
from pathlib import Path
from pyxit.estimator import COLORSPACE_RGB, COLORSPACE_TRGB, COLORSPACE_HSV, COLORSPACE_GRAY, _raw_to_trgb, _raw_to_hsv
from shapely import wkt
from shapely.affinity import affine_transform
from skimage.color import rgb2gray
from skimage.util.shape import view_as_windows
from cytomine import CytomineJob
from cytomine.models import ImageInstanceCollection, ImageInstance, AttachedFileCollection, Job, PropertyCollection, \
AnnotationCollection, Annotation
from cytomine.utilities.software import parse_domain_list, str2bool
from sldc import SemanticSegmenter, SSLWorkflowBuilder, StandardOutputLogger, Logger, ImageWindow
from sldc_cytomine import CytomineTileBuilder, CytomineSlide


def extract_windows(image, dims, step):
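    """Extract all fixed-size subwindows of an image, together with an
    identifier encoding the offset of each window.

    Returns
    -------
    subwindows: ndarray [n_windows, prod(dims)]
        Flattened pixel content of every subwindow.
    identifiers: ndarray [n_windows]
        Flat index (row * width + col) of each window's top-left pixel.
    """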
    # extract all subwindows of the input image
    subwindows = view_as_windows(image, dims, step=step)
    subwindows = subwindows.reshape([-1, np.prod(dims)])
    # generate tile identifiers
n_pixels = int(np.prod(image.shape[:2]))
window_ids = np.arange(n_pixels).reshape(image.shape[:2])
identifiers = view_as_windows(window_ids, dims[:2], step=step)
identifiers = identifiers[:, :, 0, 0].reshape([-1])
return subwindows, identifiers


class ExtraTreesSegmenter(SemanticSegmenter):
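    """Semantic segmenter wrapping a trained pyxit model (extra-trees on raw
    subwindows). A tile is segmented by sliding a window over it, predicting
    per-pixel class probabilities for every window and averaging the
    probabilities of overlapping windows.
    """
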
def __init__(self, pyxit, classes=None, background=0, min_std=0, max_mean=255, prediction_step=1):
super(ExtraTreesSegmenter, self).__init__(classes=classes)
self._pyxit = pyxit
self._prediction_step = prediction_step
self._min_std = min_std
self._max_mean = max_mean
self._background = background

    def _process_tile(self, image):
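        """Heuristic background check: the tile is worth processing if at
        least one channel has enough variance (std above `min_std`) or is
        dark enough (mean below `max_mean`).
        """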
channels = [image]
if image.ndim > 2:
channels = [image[:, :, i] for i in range(image.shape[2])]
return np.any([
np.std(c) > self._min_std or np.mean(c) < self._max_mean
for c in channels
])

    def _convert_colorspace(self, image):
colorspace = self._pyxit.colorspace
flattened = image.reshape([-1] if image.ndim == 2 else [-1, image.shape[2]])
if colorspace == COLORSPACE_RGB:
return image
elif colorspace == COLORSPACE_TRGB:
return _raw_to_trgb(flattened).reshape(image.shape)
elif colorspace == COLORSPACE_HSV:
return _raw_to_hsv(flattened).reshape(image.shape)
        elif colorspace == COLORSPACE_GRAY:
            if image.ndim == 2:  # already single-channel
                return image
            # the original code called _raw_to_hsv here, which cannot be
            # reshaped to a single-channel image; converting to grayscale
            # (rescaled to the 0-255 input range) is the assumed intent
            return rgb2gray(image) * 255
else:
raise ValueError("unknown colorspace code '{}'".format(colorspace))

    def segment(self, image):
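        """Segment a tile: return a 2D array of per-pixel class labels, with
        the background class for skipped tiles and masked-out pixels.
        """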
        # extract the mask (alpha channel) if present
        mask = np.ones(image.shape[:2], dtype=bool)
        if image.ndim == 3 and image.shape[2] in (2, 4):
            mask = image[:, :, -1].astype(bool)
            image = np.copy(image[:, :, :-1])  # remove mask from image
        # skip processing if the tile is presumably background (checked via mean & std) or fully outside the mask
if not (self._process_tile(image) and np.any(mask)):
return np.full(image.shape[:2], self._background)
# change colorspace
        image = self._convert_colorspace(image)
# prepare windows
target_height = self._pyxit.target_height
target_width = self._pyxit.target_width
w_dims = [target_height, target_width]
if image.ndim > 2 and image.shape[2] > 1:
w_dims += [image.shape[2]]
subwindows, w_identifiers = extract_windows(image, w_dims, self._prediction_step)
# predict
y = np.array(self._pyxit.base_estimator.predict_proba(subwindows))
cm_dims = list(image.shape[:2]) + [self._pyxit.n_classes_]
        confidence_map = np.zeros(cm_dims, dtype=np.float64)
        pred_count_map = np.zeros(cm_dims[:2], dtype=np.int64)
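        # each subwindow yields a dense per-pixel probability map; accumulate
        # the probabilities of overlapping windows and count the overlaps so
        # that the confidence map can be averaged afterwards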
for row, w_index in enumerate(w_identifiers):
im_width = image.shape[1]
pred_dims = [target_height, target_width, self._pyxit.n_classes_]
x_off, y_off = w_index % im_width, w_index // im_width
confidence_map[y_off:(y_off+target_height), x_off:(x_off+target_width)] += y[:, row, :].reshape(pred_dims)
pred_count_map[y_off:(y_off+target_height), x_off:(x_off+target_width)] += 1
# average over multiple predictions
confidence_map /= np.expand_dims(pred_count_map, axis=2)
        # assign the background class to pixels outside the mask
class_map = np.take(self._pyxit.classes_, np.argmax(confidence_map, axis=2))
class_map[np.logical_not(mask)] = self._background
return class_map


class AnnotationAreaChecker(object):
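    """Filter candidate annotations by their area (in pixels). A negative
    `max_area` disables the upper bound.
    """
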
def __init__(self, min_area, max_area):
self._min_area = min_area
self._max_area = max_area

    def check(self, annot):
if self._max_area < 0:
return self._min_area < annot.area
else:
return self._min_area < annot.area < self._max_area


def change_referential(p, height):
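    """Flip the y-axis of a geometry: Cytomine stores annotations in a
    cartesian referential (origin at the bottom-left corner of the image),
    whereas image processing uses a top-left origin.
    """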
return affine_transform(p, [1, 0, 0, -1, 0, height])


def get_iip_window_from_annotation(slide, annotation, zoom_level):
    """Generate an IIP-compatible ROI window from an annotation at the given zoom level."""
roi_polygon = change_referential(wkt.loads(annotation.location), slide.image_instance.height)
if zoom_level == 0:
return slide.window_from_polygon(roi_polygon)
# recompute the roi so that it matches the iip tile topology
zoom_ratio = 1 / (2 ** zoom_level)
scaled_roi = affine_transform(roi_polygon, [zoom_ratio, 0, 0, zoom_ratio, 0, 0])
min_x, min_y, max_x, max_y = (int(v) for v in scaled_roi.bounds)
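    # snap the bounding box outward onto the 256x256 IIP tile grid,
    # clamping to the image dimensions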
diff_min_x, diff_min_y = min_x % 256, min_y % 256
diff_max_x, diff_max_y = max_x % 256, max_y % 256
min_x -= diff_min_x
min_y -= diff_min_y
max_x = min(slide.width, max_x + 256 - diff_max_x)
max_y = min(slide.height, max_y + 256 - diff_max_y)
return slide.window((min_x, min_y), max_x - min_x, max_y - min_y, scaled_roi)


def extract_images_or_rois(parameters):
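    """Build the list of regions to process, by decreasing priority:
    explicit ROI annotation ids, then a user-provided list of images (or all
    images of the project), optionally restricted to ROI annotations
    carrying a given term.
    """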
id_annotations = parse_domain_list(parameters.cytomine_roi_annotations)
# if ROI annotations are provided
if len(id_annotations) > 0:
        image_cache = dict()  # maps an ImageInstance id to its CytomineSlide
zones = list()
for id_annot in id_annotations:
annotation = Annotation().fetch(id_annot)
if annotation.image not in image_cache:
image_cache[annotation.image] = CytomineSlide(annotation.image, parameters.cytomine_zoom_level)
window = get_iip_window_from_annotation(
image_cache[annotation.image],
annotation,
parameters.cytomine_zoom_level
)
zones.append(window)
return zones
    # otherwise, work at the image level or on ROIs identified by a term
images = ImageInstanceCollection()
if parameters.cytomine_id_images is not None:
id_images = parse_domain_list(parameters.cytomine_id_images)
images.extend([ImageInstance().fetch(_id) for _id in id_images])
else:
images = images.fetch_with_filter("project", parameters.cytomine_id_project)
slides = [CytomineSlide(img, parameters.cytomine_zoom_level) for img in images]
if parameters.cytomine_id_roi_term is None:
return slides
# fetch ROI annotations
collection = AnnotationCollection(
terms=[parameters.cytomine_id_roi_term],
reviewed=parameters.cytomine_reviewed_roi,
showWKT=True
)
collection.fetch_with_filter(project=parameters.cytomine_id_project)
slides_map = {slide.image_instance.id: slide for slide in slides}
regions = list()
for annotation in collection:
if annotation.image not in slides_map:
continue
slide = slides_map[annotation.image]
regions.append(get_iip_window_from_annotation(slide, annotation, parameters.cytomine_zoom_level))
return regions


def main(argv):
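    """Run the segmentation job: load the pyxit model produced by a training
    job, build an SLDC workflow around it and upload the predicted
    annotations to the Cytomine project.
    """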
with CytomineJob.from_cli(argv) as cj:
        # when working at a zoom level > 0, tiles are served on a fixed
        # 256x256 IIP grid, hence the constraints on tile size and overlap
if cj.parameters.cytomine_zoom_level > 0 and (cj.parameters.cytomine_tile_size != 256 or cj.parameters.cytomine_tile_overlap != 0):
raise ValueError("when using zoom_level > 0, tile size should be 256 "
"(given {}) and overlap should be 0 (given {})".format(
cj.parameters.cytomine_tile_size, cj.parameters.cytomine_tile_overlap))
        cj.job.update(progress=1, statusComment="Preparing execution (creating folders, ...).")
# working path
root_path = str(Path.home())
working_path = os.path.join(root_path, "images")
os.makedirs(working_path, exist_ok=True)
# load training information
cj.job.update(progress=5, statusComment="Extract properties from training job.")
train_job = Job().fetch(cj.parameters.cytomine_id_job)
properties = PropertyCollection(train_job).fetch().as_dict()
binary = str2bool(properties["binary"].value)
classes = parse_domain_list(properties["classes"].value)
cj.job.update(progress=10, statusComment="Download the model file.")
attached_files = AttachedFileCollection(train_job).fetch()
model_file = attached_files.find_by_attribute("filename", "model.joblib")
model_filepath = os.path.join(root_path, "model.joblib")
model_file.download(model_filepath, override=True)
pyxit = joblib.load(model_filepath)
# set n_jobs
pyxit.base_estimator.n_jobs = cj.parameters.n_jobs
pyxit.n_jobs = cj.parameters.n_jobs
cj.job.update(progress=45, statusComment="Build workflow.")
builder = SSLWorkflowBuilder()
builder.set_tile_size(cj.parameters.cytomine_tile_size, cj.parameters.cytomine_tile_size)
builder.set_overlap(cj.parameters.cytomine_tile_overlap)
builder.set_tile_builder(CytomineTileBuilder(working_path, n_jobs=cj.parameters.n_jobs))
builder.set_logger(StandardOutputLogger(level=Logger.INFO))
builder.set_n_jobs(1)
builder.set_background_class(0)
        # a tolerance of 0 effectively disables merging, but the merging
        # check procedure still runs (inefficient)
builder.set_distance_tolerance(2 if cj.parameters.union_enabled else 0)
builder.set_segmenter(ExtraTreesSegmenter(
pyxit=pyxit,
classes=classes,
prediction_step=cj.parameters.pyxit_prediction_step,
background=0,
min_std=cj.parameters.tile_filter_min_stddev,
max_mean=cj.parameters.tile_filter_max_mean
))
workflow = builder.get()
area_checker = AnnotationAreaChecker(
min_area=cj.parameters.min_annotation_area,
max_area=cj.parameters.max_annotation_area
)

        def get_term(label):
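            """Map a predicted label to the term ids to attach: in binary
            mode the (optional) user-provided predict term, otherwise the
            label itself.
            """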
if binary:
if "cytomine_id_predict_term" not in cj.parameters or not cj.parameters.cytomine_id_predict_term:
return []
else:
return [int(cj.parameters.cytomine_id_predict_term)]
# multi-class
return [label]
zones = extract_images_or_rois(cj.parameters)
for zone in cj.monitor(zones, start=50, end=90, period=0.05, prefix="Segmenting images/ROIs"):
results = workflow.process(zone)
annotations = AnnotationCollection()
for obj in results:
if not area_checker.check(obj.polygon):
continue
polygon = obj.polygon
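                # bring the polygon back into the full-image referential:
                # window offset, then y-axis flip, then zoom-level upscaling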
if isinstance(zone, ImageWindow):
polygon = affine_transform(polygon, [1, 0, 0, 1, zone.abs_offset_x, zone.abs_offset_y])
polygon = change_referential(polygon, zone.base_image.height)
if cj.parameters.cytomine_zoom_level > 0:
zoom_mult = (2 ** cj.parameters.cytomine_zoom_level)
polygon = affine_transform(polygon, [zoom_mult, 0, 0, zoom_mult, 0, 0])
annotations.append(Annotation(
location=polygon.wkt,
id_terms=get_term(obj.label),
id_project=cj.project.id,
id_image=zone.base_image.image_instance.id
))
annotations.save()
        cj.job.update(status=Job.TERMINATED, statusComment="Finished", progress=100)


if __name__ == "__main__":
import sys
main(sys.argv[1:])