-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a tool to extract features from OSM to retrain
- Loading branch information
Showing
3 changed files
with
88 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from src.detection.walker import Walker | ||
|
||
|
||
class CoordWalker(Walker): | ||
def __init__(self, tile, nodes, square_image_length=50, zoom_level=19, step_width=0.66): | ||
super(CoordWalker, self).__init__(tile, square_image_length, zoom_level, step_width) | ||
self.nodes = nodes | ||
|
||
def get_tiles(self): | ||
squared_tiles = self._get_squared_tiles(self.nodes) | ||
return squared_tiles |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from src.base.tile import Tile | ||
from src.base.bbox import Bbox | ||
from src.base.node import Node | ||
from src.base.tag import Tag | ||
|
||
from src.data.orthofoto.wms.wms_api import WmsApi | ||
|
||
from src.train.coord_walker import CoordWalker | ||
from src.train.osm_object_walker import OsmObjectWalker | ||
|
||
import argparse | ||
|
||
def main(args): | ||
coords = list(map(lambda c: Node(*map(float, c.split(','))), args.coord)) | ||
bbox = Bbox.from_nodes(coords[0], coords[1]) | ||
if args.tags: | ||
tags = map(lambda k, v: Tag(key=k, value=v), map(lambda kv: kv.split('=', 1), args.tags.split(','))) | ||
#walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), Tag(key='public_transport', value='platform'), square_image_length=100) | ||
walker = OsmObjectWalker(Tile(image_api=WmsApi(), bbox=bbox), tags, square_image_length=100) | ||
else: | ||
walker = CoordWalker(Tile(image_api=WmsApi(), bbox=bbox), coords, square_image_length=100) | ||
|
||
tiles = walker.get_tiles() | ||
for n, t in enumerate(tiles): | ||
centre_node = t.get_centre_node() | ||
name = "fetch/{0:02.8}_{1:02.8}.png".format(centre_node.latitude, centre_node.longitude) | ||
t.image.save(name, "PNG") | ||
print(name) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
'--tags', | ||
type=str, | ||
default=None, | ||
help='Tag to fetch from OSM: highway=crossing.' | ||
) | ||
|
||
parser.add_argument( | ||
'coord', | ||
type=str, | ||
action='store', | ||
nargs='+', | ||
help='lon,lat coord in WGS84, if --tags bbox left,bottom right,top, else list of coords to fetch.') | ||
|
||
args = parser.parse_args() | ||
main(args) | ||
|
||
# mapproxy-util serve-develop mapproxy.yml | ||
# montage *.png -geometry 100x100+1+1 out.png | ||
# python retrain.py --image_dir retrain-data --print_misclassified_test_images |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from src.base.node import Node | ||
from src.detection.walker import Walker | ||
from src.data.osm.overpass_api import OverpassApi | ||
|
||
|
||
class OsmObjectWalker(Walker): | ||
def __init__(self, tile, tags, square_image_length=50, zoom_level=19, step_width=0.66): | ||
super(OsmObjectWalker, self).__init__(tile, square_image_length, zoom_level, step_width) | ||
self.tags = tags | ||
|
||
def get_tiles(self): | ||
nodes = self._calculate_tile_centres() | ||
squared_tiles = self._get_squared_tiles(nodes) | ||
return squared_tiles | ||
|
||
def _calculate_tile_centres(self): | ||
centers = [] | ||
|
||
# [out:csv(::lat,::lon)][timeout:25];node["public_transport"="platform"]({{bbox}});out; | ||
self.api = OverpassApi() | ||
data = self.api.get(self.tile.bbox, self.tags, nodes=True, ways=False, relations=False, responseformat='csv(::lat,::lon)') | ||
data = list(map(lambda cc: Node(float(cc[0]), float(cc[1])), data[1:])) | ||
print(data) | ||
|
||
return data |