Skip to content

Commit

Permalink
Add a tool to extract features from OSM to retrain
Browse files Browse the repository at this point in the history
  • Loading branch information
frodrigo committed Jun 20, 2018
1 parent 0d95159 commit 3b8bc01
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/train/coord_walker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from src.detection.walker import Walker


class CoordWalker(Walker):
def __init__(self, tile, nodes, square_image_length=50, zoom_level=19, step_width=0.66):
super(CoordWalker, self).__init__(tile, square_image_length, zoom_level, step_width)
self.nodes = nodes

def get_tiles(self):
squared_tiles = self._get_squared_tiles(self.nodes)
return squared_tiles
52 changes: 52 additions & 0 deletions src/train/fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from src.base.tile import Tile
from src.base.bbox import Bbox
from src.base.node import Node
from src.base.tag import Tag

from src.data.orthofoto.wms.wms_api import WmsApi

from src.train.coord_walker import CoordWalker
from src.train.osm_object_walker import OsmObjectWalker

import argparse

def main(args):
coords = list(map(lambda c: Node(*map(float, c.split(','))), args.coord))
bbox = Bbox.from_nodes(coords[0], coords[1])
if args.tags:
tags = map(lambda k, v: Tag(key=k, value=v), map(lambda kv: kv.split('=', 1), args.tags.split(',')))
#walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), Tag(key='public_transport', value='platform'), square_image_length=100)
walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), tags, square_image_length=100)
else:
walker = CoordWalker(Tile(image_api=WmsApi(), bbox=bbox), coords, square_image_length=100)

tiles = walker.get_tiles()
for n, t in enumerate(tiles):
centre_node = t.get_centre_node()
name = "fetch/{0:02.8}_{1:02.8}.png".format(centre_node.latitude, centre_node.longitude)
t.image.save(name, "PNG")
print(name)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--tags',
type=str,
default=None,
help='Tag to fetch from OSM: highway=crossing.'
)

parser.add_argument(
'coord',
type=str,
action='store',
nargs='+',
help='lon,lat coord in WGS84, if --tags bbox left,bottom right,top, else list of coords to fetch.')

args = parser.parse_args()
main(args)

# mapproxy-util serve-develop mapproxy.yml
# montage *.png -geometry 100x100+1+1 out.png
# python retrain.py --image_dir retrain-data --print_misclassified_test_images
25 changes: 25 additions & 0 deletions src/train/osm_object_walker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from src.base.node import Node
from src.detection.walker import Walker
from src.data.osm.overpass_api import OverpassApi


class OsmObjectWalker(Walker):
def __init__(self, tile, tags, square_image_length=50, zoom_level=19, step_width=0.66):
super(OsmObjectWalker, self).__init__(tile, square_image_length, zoom_level, step_width)
self.tags = tags

def get_tiles(self):
nodes = self._calculate_tile_centres()
squared_tiles = self._get_squared_tiles(nodes)
return squared_tiles

def _calculate_tile_centres(self):
centers = []

# [out:csv(::lat,::lon)][timeout:25];node["public_transport"="platform"]({{bbox}});out;
self.api = OverpassApi()
data = self.api.get(self.tile.bbox, self.tags, nodes=True, ways=False, relations=False, responseformat='csv(::lat,::lon)')
data = list(map(lambda cc: Node(float(cc[0]), float(cc[1])), data[1:]))
print(data)

return data

0 comments on commit 3b8bc01

Please sign in to comment.