From 3b8bc0102e4f0922a01a2e80cc43619a8e9970cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Rodrigo?= Date: Tue, 19 Jun 2018 00:17:43 +0200 Subject: [PATCH] Add a tool to extract features from OSM to retrain --- src/train/coord_walker.py | 11 +++++++ src/train/fetch.py | 52 ++++++++++++++++++++++++++++++++++ src/train/osm_object_walker.py | 25 ++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 src/train/coord_walker.py create mode 100644 src/train/fetch.py create mode 100644 src/train/osm_object_walker.py diff --git a/src/train/coord_walker.py b/src/train/coord_walker.py new file mode 100644 index 00000000..1151ff8a --- /dev/null +++ b/src/train/coord_walker.py @@ -0,0 +1,11 @@ +from src.detection.walker import Walker + + +class CoordWalker(Walker): + def __init__(self, tile, nodes, square_image_length=50, zoom_level=19, step_width=0.66): + super(CoordWalker, self).__init__(tile, square_image_length, zoom_level, step_width) + self.nodes = nodes + + def get_tiles(self): + squared_tiles = self._get_squared_tiles(self.nodes) + return squared_tiles diff --git a/src/train/fetch.py b/src/train/fetch.py new file mode 100644 index 00000000..be22e8e2 --- /dev/null +++ b/src/train/fetch.py @@ -0,0 +1,52 @@ +from src.base.tile import Tile +from src.base.bbox import Bbox +from src.base.node import Node +from src.base.tag import Tag + +from src.data.orthofoto.wms.wms_api import WmsApi + +from src.train.coord_walker import CoordWalker +from src.train.osm_object_walker import OsmObjectWalker + +import argparse + +def main(args): + coords = list(map(lambda c: Node(*map(float, c.split(','))), args.coord)) + bbox = Bbox.from_nodes(coords[0], coords[1]) + if args.tags: + tags = map(lambda k, v: Tag(key=k, value=v), map(lambda kv: kv.split('=', 1), args.tags.split(','))) + #walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), Tag(key='public_transport', value='platform'), square_image_length=100) + walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), tags, square_image_length=100) + else: + walker = CoordWalker(Tile(image_api=WmsApi(), bbox=bbox), coords, square_image_length=100) + + tiles = walker.get_tiles() + for n, t in enumerate(tiles): + centre_node = t.get_centre_node() + name = "fetch/{0:02.8}_{1:02.8}.png".format(centre_node.latitude, centre_node.longitude) + t.image.save(name, "PNG") + print(name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--tags', + type=str, + default=None, + help='Tag to fetch from OSM: highway=crossing.' + ) + + parser.add_argument( + 'coord', + type=str, + action='store', + nargs='+', + help='lon,lat coord in WGS84, if --tags bbox left,bottom right,top, else list of coords to fetch.') + + args = parser.parse_args() + main(args) + +# mapproxy-util serve-develop mapproxy.yml +# montage *.png -geometry 100x100+1+1 out.png +# python retrain.py --image_dir retrain-data --print_misclassified_test_images diff --git a/src/train/osm_object_walker.py b/src/train/osm_object_walker.py new file mode 100644 index 00000000..79de2035 --- /dev/null +++ b/src/train/osm_object_walker.py @@ -0,0 +1,25 @@ +from src.base.node import Node +from src.detection.walker import Walker +from src.data.osm.overpass_api import OverpassApi + + +class OsmObjectWalker(Walker): + def __init__(self, tile, tags, square_image_length=50, zoom_level=19, step_width=0.66): + super(OsmObjectWalker, self).__init__(tile, square_image_length, zoom_level, step_width) + self.tags = tags + + def get_tiles(self): + nodes = self._calculate_tile_centres() + squared_tiles = self._get_squared_tiles(nodes) + return squared_tiles + + def _calculate_tile_centres(self): + centers = [] + + # [out:csv(::lat,::lon)][timeout:25];node["public_transport"="platform"]({{bbox}});out; + self.api = OverpassApi() + data = self.api.get(self.tile.bbox, self.tags, nodes=True, ways=False, relations=False, responseformat='csv(::lat,::lon)') + data = list(map(lambda cc: Node(float(cc[0]), float(cc[1])), data[1:])) + print(data) + + return data