Skip to content

Commit 3b8bc01

Browse files
committed
Add a tool to extract features from OSM to retrain
1 parent 0d95159 commit 3b8bc01

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

src/train/coord_walker.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from src.detection.walker import Walker
2+
3+
4+
class CoordWalker(Walker):
5+
def __init__(self, tile, nodes, square_image_length=50, zoom_level=19, step_width=0.66):
6+
super(CoordWalker, self).__init__(tile, square_image_length, zoom_level, step_width)
7+
self.nodes = nodes
8+
9+
def get_tiles(self):
10+
squared_tiles = self._get_squared_tiles(self.nodes)
11+
return squared_tiles

src/train/fetch.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from src.base.tile import Tile
2+
from src.base.bbox import Bbox
3+
from src.base.node import Node
4+
from src.base.tag import Tag
5+
6+
from src.data.orthofoto.wms.wms_api import WmsApi
7+
8+
from src.train.coord_walker import CoordWalker
9+
from src.train.osm_object_walker import OsmObjectWalker
10+
11+
import argparse
12+
13+
def main(args):
14+
coords = list(map(lambda c: Node(*map(float, c.split(','))), args.coord))
15+
bbox = Bbox.from_nodes(coords[0], coords[1])
16+
if args.tags:
17+
tags = map(lambda k, v: Tag(key=k, value=v), map(lambda kv: kv.split('=', 1), args.tags.split(',')))
18+
#walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), Tag(key='public_transport', value='platform'), square_image_length=100)
19+
walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), tags, square_image_length=100)
20+
else:
21+
walker = CoordWalker(Tile(image_api=WmsApi(), bbox=bbox), coords, square_image_length=100)
22+
23+
tiles = walker.get_tiles()
24+
for n, t in enumerate(tiles):
25+
centre_node = t.get_centre_node()
26+
name = "fetch/{0:02.8}_{1:02.8}.png".format(centre_node.latitude, centre_node.longitude)
27+
t.image.save(name, "PNG")
28+
print(name)
29+
30+
31+
if __name__ == "__main__":
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument(
34+
'--tags',
35+
type=str,
36+
default=None,
37+
help='Tag to fetch from OSM: highway=crossing.'
38+
)
39+
40+
parser.add_argument(
41+
'coord',
42+
type=str,
43+
action='store',
44+
nargs='+',
45+
help='lon,lat coord in WGS84, if --tags bbox left,bottom right,top, else list of coords to fetch.')
46+
47+
args = parser.parse_args()
48+
main(args)
49+
50+
# mapproxy-util serve-develop mapproxy.yml
51+
# montage *.png -geometry 100x100+1+1 out.png
52+
# python retrain.py --image_dir retrain-data --print_misclassified_test_images

src/train/osm_object_walker.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from src.base.node import Node
2+
from src.detection.walker import Walker
3+
from src.data.osm.overpass_api import OverpassApi
4+
5+
6+
class OsmObjectWalker(Walker):
7+
def __init__(self, tile, tags, square_image_length=50, zoom_level=19, step_width=0.66):
8+
super(OsmObjectWalker, self).__init__(tile, square_image_length, zoom_level, step_width)
9+
self.tags = tags
10+
11+
def get_tiles(self):
12+
nodes = self._calculate_tile_centres()
13+
squared_tiles = self._get_squared_tiles(nodes)
14+
return squared_tiles
15+
16+
def _calculate_tile_centres(self):
17+
centers = []
18+
19+
# [out:csv(::lat,::lon)][timeout:25];node["public_transport"="platform"]({{bbox}});out;
20+
self.api = OverpassApi()
21+
data = self.api.get(self.tile.bbox, self.tags, nodes=True, ways=False, relations=False, responseformat='csv(::lat,::lon)')
22+
data = list(map(lambda cc: Node(float(cc[0]), float(cc[1])), data[1:]))
23+
print(data)
24+
25+
return data

0 commit comments

Comments
 (0)