forked from CAVEconnectome/EMAnnotationSchemas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
192 lines (140 loc) · 4.98 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from enum import Enum
import marshmallow as mm
import numpy as np
from geoalchemy2.shape import to_shape
from geoalchemy2.types import WKBElement, WKTElement
from marshmallow import INCLUDE
from sqlalchemy.sql.sqltypes import Integer
class MetaDataTypes(Enum):
    """Values placed in a field's ``metadata["field_type"]``.

    Schema fields are tagged with one of these so that SQLAlchemy model
    creation can recognise which custom column kind a field represents.
    """

    # Used by ReferenceAnnotation.target_id.
    REFERENCE = "reference"
    # Used by BoundSpatialPoint.root_id.
    ROOT_ID = "root_id"
    # Used by SpatialPoint.position.
    SPATIAL_POINT = "spatial_point"
    # Used by BoundSpatialPoint.supervoxel_id.
    SUPERVOXEL_ID = "supervoxel_id"
class NumericField(mm.fields.Int):
    """Integer marshmallow field that maps to ``"integer"`` in JSON schema output."""

    def _jsonschema_type_mapping(self):
        # JSON-schema type description for this field.
        mapping = {"type": "integer"}
        return mapping
class SegmentationField(NumericField):
    """Integer field marking a column as holding a segmentation id.

    Used for both ``root_id`` and ``supervoxel_id`` columns. It behaves exactly
    like ``NumericField``; the distinct type lets SQLAlchemy model creation
    recognise the column as a 'segmentation' column.
    """
class PostGISField(mm.fields.Field):
    """Marshmallow field for a 3D point stored in a PostGIS geometry column.

    Loading: a GeoAlchemy2 ``WKBElement``/``WKTElement`` is decoded with
    ``get_geom_from_wkb``; any other value is returned unchanged.
    Dumping: an indexable ``[x, y, z]`` value is rendered as the string
    ``"POINTZ(x y z)"``; ``None`` is dumped as ``None``.
    """

    def _jsonschema_type_mapping(self):
        # JSON-schema type description for this field.
        return {
            "type": "array",
        }

    def _deserialize(self, value, attr, obj, **kwargs):
        if isinstance(value, (WKBElement, WKTElement)):
            return get_geom_from_wkb(value)
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        # marshmallow passes None straight to _serialize; match the built-in
        # fields and emit None rather than failing on ``None[0]``.
        if value is None:
            return None
        return f"POINTZ({value[0]} {value[1]} {value[2]})"
def get_geom_from_wkb(wkb):
    """Decode a GeoAlchemy2 geometry element.

    For geometries with a z coordinate, return ``np.array([x, y, z])`` of dtype
    ``np.uint64`` built from the first coordinate. Otherwise return the shapely
    geometry produced by ``to_shape`` unchanged.
    """
    geom = to_shape(wkb)
    if not geom.has_z:
        return geom
    xs, ys = geom.xy
    # NOTE(review): casting to uint64 truncates fractional coordinates and
    # cannot represent negative ones — confirm this is intended.
    return np.asarray([xs[0], ys[0], geom.z], dtype=np.uint64)
class ReferenceTableField(mm.fields.Field):
    """Integer field holding the id of the annotation being referenced.

    Loading coerces the input with ``int()``. Input that cannot be converted is
    reported as a ``marshmallow.ValidationError`` (so ``Schema.load`` collects it
    under the field name). Dumping returns ``int(value)``, or ``None`` for
    ``None``.
    """

    def _jsonschema_type_mapping(self):
        # JSON-schema type description for this field.
        return {
            "type": "integer",
        }

    def _deserialize(self, value, attr, obj, **kwargs):
        # Always coerce: int(), numeric strings and numpy integers all convert,
        # and int(x) for a plain int is the same value.
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise mm.ValidationError("Not a valid integer.") from err

    def _serialize(self, value, attr, obj, **kwargs):
        # marshmallow passes None straight to _serialize; int(None) would raise.
        if value is None:
            return None
        return int(value)
class IdSchema(mm.Schema):
    """Schema holding a single integer identifier, ``oid``."""

    oid = mm.fields.Int(description="identifier for annotation, unique in type")
class AnnotationSchema(mm.Schema):
    """Base schema for annotations.

    Input keys that are not declared on the schema are kept in the loaded
    data (``unknown = INCLUDE``). The optional ``valid`` flag loads as ``None``
    when absent (``missing=None``) and dumps as ``False`` when absent
    (``default=False``).
    """

    class Meta:
        unknown = INCLUDE

    valid = mm.fields.Bool(
        required=False,
        description="is this annotation valid",
        default=False,
        missing=None,
    )
class ReferenceAnnotation(AnnotationSchema):
    """An annotation that points at another annotation through ``target_id``.

    ``target_id`` is required and is tagged with ``MetaDataTypes.REFERENCE``
    in its metadata.
    """

    target_id = ReferenceTableField(
        required=True,
        description="annotation this references",
        metadata={"field_type": MetaDataTypes.REFERENCE.value},
        index=True,
    )
class FlatSegmentationReference(AnnotationSchema):
    """Annotation schema that adds no fields beyond ``AnnotationSchema``."""
class TagAnnotation(mm.Schema):
    """A simple annotation carrying a required free-text ``tag``."""

    tag = mm.fields.Str(required=True, description="tag to attach to annotation")
class ReferenceTagAnnotation(ReferenceAnnotation, TagAnnotation):
    """A tag attached to another annotation (fields ``target_id`` and ``tag``)."""
class SpatialPoint(mm.Schema):
    """A single x, y, z position inside the segmented volume.

    Two optional post-load conversions are controlled by the schema context:
    ``numpy`` turns the loaded position into an ``np.uint64`` array, and
    ``postgis`` turns it into a ``"POINTZ(x y z)"`` string.
    """

    position = PostGISField(
        required=True,
        description="spatial position in voxels of x,y,z of annotation",
        postgis_geometry="POINTZ",
        metadata={"field_type": MetaDataTypes.SPATIAL_POINT.value},
        index=True,
    )

    # NOTE(review): marshmallow 3 collects hooks by iterating dir(), so the
    # post_load order appears to follow method names (to_numpy before
    # transform_position). Keep these names stable unless that is confirmed.
    @mm.post_load
    def transform_position(self, data, **kwargs):
        if not self.context.get("postgis", False):
            return data
        pos = data["position"]
        data["position"] = "POINTZ({} {} {})".format(pos[0], pos[1], pos[2])
        return data

    @mm.post_load
    def to_numpy(self, data, **kwargs):
        if not self.context.get("numpy", False):
            return data
        data["position"] = np.asarray(data["position"], dtype=np.uint64)
        return data
class BoundSpatialPoint(SpatialPoint):
    """A spatial point tied to a segmented object.

    Adds ``supervoxel_id`` and ``root_id``; both default to ``None`` on load.
    If the schema context provides a ``bsp_fn`` callable, it is called with
    each loaded item after loading. Its return value is ignored and the same
    item is returned.
    """

    supervoxel_id = SegmentationField(
        missing=None,
        description="supervoxel id of this point",
        metadata={"field_type": MetaDataTypes.SUPERVOXEL_ID.value},
        segmentation_field=True,
    )
    root_id = SegmentationField(
        description="root id of the bound point",
        missing=None,
        metadata={"field_type": MetaDataTypes.ROOT_ID.value},
        segmentation_field=True,
        index=True,
    )

    @mm.post_load
    def convert_point(self, item, **kwargs):
        callback = self.context.get("bsp_fn")
        if callback is None:
            return item
        callback(item)
        return item
class FlatSegmentationReferenceSinglePoint(ReferenceAnnotation):
    """Reference annotation that also carries one bound spatial point, ``pt``."""

    pt = mm.fields.Nested(
        BoundSpatialPoint,
        required=True,
        description="the point to be used for attaching objects to the dynamic segmentation",
    )
class RepresentativePoint(AnnotationSchema):
    """Annotation consisting of a single required bound spatial point, ``pt``."""

    pt = mm.fields.Nested(
        BoundSpatialPoint,
        required=True,
        description="the point to be used for attaching objects to the dynamic segmentation",
    )