-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxml_manipulation.py
173 lines (143 loc) · 5.8 KB
/
xml_manipulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Helper functions for reading and editing XML annotation trees
(bounding boxes, image size and class names).
"""
def get_xml_bbox(root):
    """
    Collect the bounding box of every direct "object" child of root.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
    Returns:
        list of lists, one [xmin, ymin, xmax, ymax] (integers) per object,
        in the order the objects appear under root
    """
    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    return [
        [int(node.find(f"bndbox/{tag}").text) for tag in corner_tags]
        for node in root.findall("object")
    ]
def get_xml_bbox_area(root):
    """
    Compute the area of every bounding box found under root.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
    Returns:
        list of int, (ymax - ymin) * (xmax - xmin) for each direct "object"
        child, in document order
    """

    def _corner(node, tag):
        # read one integer corner value from the object's bndbox
        return int(node.find(f"bndbox/{tag}").text)

    sizes = []
    for node in root.findall("object"):
        left, top = _corner(node, "xmin"), _corner(node, "ymin")
        right, bottom = _corner(node, "xmax"), _corner(node, "ymax")
        sizes.append((bottom - top) * (right - left))
    return sizes
def get_xml_width(root):
    """
    Read the image width recorded in the annotation.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
    Returns:
        int, value of the size/width element
    """
    return int(root.find("size/width").text)
def get_xml_height(root):
    """
    Read the image height recorded in the annotation.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
    Returns:
        int, value of the size/height element
    """
    return int(root.find("size/height").text)
def get_targeted_xml_object(root, wanted_bbox_list):
    """
    Return the "object" elements whose bounding box appears in wanted_bbox_list.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
        wanted_bbox_list: list of lists, each [xmin, ymin, xmax, ymax]; an object
            matches when its own box, built as a list of ints, is `in` this list
    Returns:
        list of the matching element objects, in document order
    """
    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    return [
        node
        for node in root.findall("object")
        if [int(node.find(f"bndbox/{tag}").text) for tag in corner_tags]
        in wanted_bbox_list
    ]
def get_xml_class_list(XML_PATH):
    """
    List the distinct object names used across the annotation files in XML_PATH.
    Arguments:
        XML_PATH: directory holding the annotation files; each entry returned by
            tools.get_filelist(XML_PATH) is opened as "<entry>.xml" inside it
            (presumably get_filelist yields names without extension - confirm)
    Returns:
        list of the "name" texts of all objects, each listed once, in the order
        they are first encountered
    """
    # imports are kept local to the function
    import os
    from tqdm import tqdm
    import xml.etree.ElementTree as ET
    from tools import get_filelist

    # dict keys keep first-seen order, giving an ordered de-duplication
    seen = {}
    for stem in tqdm(get_filelist(XML_PATH)):
        annotation = ET.parse(os.path.join(XML_PATH, f"{stem}.xml")).getroot()
        for node in annotation.findall("object"):
            seen.setdefault(node.find("name").text, None)
    return list(seen)
def add_xml_object(root, name, bbox_coordinates):
    """
    Append a new "object" element describing one labelled box to root.
    The new element gets, in this order: name, pose ("Unspecified"),
    truncated ("0"), difficult ("0") and a bndbox with xmin, ymin, xmax, ymax.
    Arguments:
        root: xml.Elementree object, element that receives the new object
        name: text written into the object's "name" element
        bbox_coordinates: indexable with positions 0..3 holding
            xmin, ymin, xmax, ymax; each is stored as str(value)
    Returns:
        None, root is modified in place
    """
    # import library
    import xml.etree.ElementTree as ET

    entry = ET.SubElement(root, "object")

    # fixed header fields, written in this order under the new object
    header = (
        ("name", name),
        ("pose", "Unspecified"),
        ("truncated", "0"),
        ("difficult", "0"),
    )
    for tag, value in header:
        ET.SubElement(entry, tag).text = value

    box = ET.SubElement(entry, "bndbox")
    for position, tag in enumerate(("xmin", "ymin", "xmax", "ymax")):
        corner = ET.SubElement(box, tag)
        corner.text = str(bbox_coordinates[position])
def drop_xml_small_bbox(root, max_area=400, list_to_keep=None):
    """
    Remove unwanted "object" elements from root and report the boxes that remain.
    Each filter is active only when its argument is not None:
        - max_area: a box qualifies for removal when its area is <= max_area
        - list_to_keep: a box qualifies for removal when it is not in the list
    When both filters are active a box is removed only if both agree (it is
    small AND not listed). When both are None a message is printed once and
    nothing is removed.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
        max_area: int or None, boxes whose area is less than or equal to this
            threshold qualify for removal
        list_to_keep: list or None, a list of lists contains bounding boxes
            coordinates [xmin, ymin, xmax, ymax] to be kept
    Returns:
        remain_bbox: list, a list of lists contains remaining bounding boxes
    """
    use_area = max_area is not None
    use_keep = list_to_keep is not None
    if not (use_area or use_keep):
        print("Invalid input: both conditions cannot be NoneType")

    remain_bbox = []  # store bboxes kept
    # findall returns a list, so removing children while looping is safe
    for node in root.findall("object"):
        xmin = int(node.find("bndbox/xmin").text)
        ymin = int(node.find("bndbox/ymin").text)
        xmax = int(node.find("bndbox/xmax").text)
        ymax = int(node.find("bndbox/ymax").text)
        box = [xmin, ymin, xmax, ymax]
        area = (ymax - ymin) * (xmax - xmin)

        # A disabled filter never vetoes removal; the comparison is only
        # evaluated when its argument is not None, so no TypeError can occur.
        too_small = (not use_area) or area <= max_area
        unlisted = (not use_keep) or box not in list_to_keep

        if (use_area or use_keep) and too_small and unlisted:
            root.remove(node)
        else:
            remain_bbox.append(box)
    return remain_bbox
def update_crop_xml_bbox(root, w, h, x_shift, y_shift):
    """
    Rewrite the annotation in place so that it describes a cropped image.
    size/width and size/height are overwritten with w and h, and every object's
    box corners are moved by subtracting the shift. The shifted values are
    written as-is; they are not clipped to the new image size.
    Arguments:
        root: xml.Elementree object, root element of the parsed xml tree
        w: int, new image width after cropped
        h: int, new image height after cropped
        x_shift: int, number of pixel shifted from original image, in x-direction
        y_shift: int, number of pixel shifted from original image, in y-direction
    Returns:
        None, root is modified in place
    """
    root.find("size/width").text = str(w)
    root.find("size/height").text = str(h)

    # corner tag paired with the offset subtracted from it, in rewrite order
    offsets = (
        ("xmin", x_shift),
        ("ymin", y_shift),
        ("xmax", x_shift),
        ("ymax", y_shift),
    )
    for node in root.findall("object"):
        for tag, offset in offsets:
            path = f"bndbox/{tag}"
            shifted = str(int(node.find(path).text) - offset)
            node.find(path).text = shifted