Skip to content

Commit 06a470d

Browse files
authored Sep 15, 2021
Adds a new command spatial-tree index (#475)
* Adds a new command `spatial-tree index`. The command only works if pywraps2 is manually installed into the kart venv. Although tests exist, they are currently skipped for this reason. (Not trying to modify the kart build while it is red.) * Address PR comments on spatial-tree command
1 parent b33d7f2 commit 06a470d

File tree

4 files changed

+560
-0
lines changed

4 files changed

+560
-0
lines changed
 

‎kart/cli.py

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"pull": {"pull"},
3737
"resolve": {"resolve"},
3838
"show": {"create-patch", "show"},
39+
"spatial_tree": {"spatial-tree"},
3940
"status": {"status"},
4041
"query": {"query"},
4142
"upgrade": {"upgrade", "upgrade-to-tidy", "upgrade-to-kart"},

‎kart/repo.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class KartRepoFiles:
4747
# Kart-specific files:
4848
MERGE_INDEX = "MERGE_INDEX"
4949
MERGE_BRANCH = "MERGE_BRANCH"
50+
S2_INDEX = "s2_index.db"
5051

5152

5253
class KartRepoState(Enum):

‎kart/spatial_tree.py

+430
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,430 @@
1+
import functools
2+
import logging
3+
import re
4+
import subprocess
5+
import time
6+
7+
import click
8+
from osgeo import osr, ogr
9+
from pysqlite3 import dbapi2 as sqlite
10+
from sqlalchemy import Column, ForeignKey, Integer, Table, Text
11+
from sqlalchemy.orm import sessionmaker
12+
from sqlalchemy.types import BLOB
13+
14+
15+
from .cli_util import add_help_subcommand, tool_environment
16+
from .crs_util import make_crs, normalise_wkt
17+
from .exceptions import SubprocessError
18+
from .geometry import Geometry, GeometryType, geom_envelope, gpkg_geom_to_ogr
19+
from .repo import KartRepoState, KartRepoFiles
20+
from .sqlalchemy import TableSet
21+
from .sqlalchemy.sqlite import sqlite_engine
22+
from .serialise_util import msg_unpack
23+
24+
L = logging.getLogger("kart.spatial_tree")


# S2RegionCoverer settings used when indexing (see update_spatial_tree):
# at most this many cells per feature envelope, subdivided no finer than this level.
S2_MAX_CELLS_INDEX = 8
S2_MAX_LEVEL = 15
29+
30+
31+
def _revlist_command(repo):
32+
return [
33+
"git",
34+
"-C",
35+
repo.path,
36+
"rev-list",
37+
"--objects",
38+
"--filter=object:type=blob",
39+
"--missing=allow-promisor",
40+
]
41+
42+
43+
DS_PATH_PATTERN = r'(.+)/\.(sno|table)-dataset/'
44+
45+
46+
def _parse_revlist_output(line_iter, rel_path_pattern):
47+
full_path_pattern = re.compile(DS_PATH_PATTERN + rel_path_pattern)
48+
49+
for line in line_iter:
50+
parts = line.split(" ", maxsplit=1)
51+
if len(parts) != 2:
52+
continue
53+
oid, path = parts
54+
55+
m = full_path_pattern.match(path)
56+
if not m:
57+
continue
58+
ds_path = m.group(1)
59+
yield ds_path, oid
60+
61+
62+
class CrsHelper:
    """
    Loads all CRS definitions for a particular dataset,
    and creates transforms from each CRS to its geographic (lat/lng) equivalent,
    so that feature geometries can be mapped onto S2 cells.
    """

    def __init__(self, repo):
        self.repo = repo
        # Cache: dataset path -> list of transforms (see transforms_for_dataset).
        self.ds_to_transforms = {}

    def transforms_for_dataset(self, ds_path):
        """Return (and cache) the list of CRS transforms for the given dataset."""
        transforms = self.ds_to_transforms.get(ds_path)
        if transforms is None:
            transforms = self._load_transforms_for_dataset(ds_path)
            self.ds_to_transforms[ds_path] = transforms
        return transforms

    def _load_transforms_for_dataset(self, ds_path):
        # Load every CRS ever attached to this dataset (across all refs) and
        # build one transform per distinct CRS. Unloadable CRSs are logged
        # and skipped rather than aborting the whole index run.
        if ds_path in self.ds_to_transforms:
            return self.ds_to_transforms[ds_path]

        crs_oids = set(self.iter_crs_oids(ds_path))
        transforms = []
        descs = []
        for crs_oid in crs_oids:
            try:
                transform, desc = self.transform_from_oid(crs_oid)
                if transform not in transforms:
                    transforms.append(transform)
                    descs.append(desc)
            except Exception as e:
                L.warning(
                    f"Couldn't load transform for CRS {crs_oid} at {ds_path}\n{e}"
                )
        L.info(f"Loaded CRS transforms for {ds_path}: {', '.join(descs)}")
        return transforms

    def iter_crs_oids(self, ds_path):
        """
        Yield the blob OID of every CRS definition ever committed for ds_path,
        found by running `git rev-list --all` over the dataset's meta/crs paths.
        """
        cmd = [
            *_revlist_command(self.repo),
            "--all",
            "--",
            *self.all_crs_paths(ds_path),
        ]
        try:
            r = subprocess.run(
                cmd,
                encoding="utf8",
                check=True,
                capture_output=True,
                env=tool_environment(),
            )
        except subprocess.CalledProcessError as e:
            raise SubprocessError(
                f"There was a problem with git rev-list: {e}", called_process_error=e
            )
        for d, crs_oid in _parse_revlist_output(
            r.stdout.splitlines(), r"meta/crs/[^/]+"
        ):
            # We filtered rev-list to this dataset's paths, so every match is ours.
            assert d == ds_path
            yield crs_oid

    def all_crs_paths(self, ds_path):
        # Delete .sno-dataset if we drop V2 support.
        yield f"{ds_path}/.sno-dataset/meta/crs/"
        yield f"{ds_path}/.table-dataset/meta/crs/"

    # NOTE(review): lru_cache on instance methods keys on self and keeps the
    # instance alive for the cache's lifetime - acceptable here since one
    # CrsHelper lives for the duration of an index run, but worth confirming.
    @functools.lru_cache()
    def transform_from_oid(self, crs_oid):
        """Build a (transform, description) pair from the CRS blob with this OID."""
        wkt = normalise_wkt(self.repo[crs_oid].data.decode("utf-8"))
        return self.transform_from_wkt(wkt)

    @functools.lru_cache()
    def transform_from_wkt(self, wkt):
        """
        Build a (transform, description) pair from a CRS WKT definition.
        Geographic CRSs need no transform (returned as None); projected CRSs
        get a transform to their underlying geographic CRS.
        """
        src_crs = make_crs(wkt)
        if src_crs.IsGeographic():
            transform = None
            desc = f"IDENTITY({src_crs.GetAuthorityCode(None)})"
        else:
            target_crs = src_crs.CloneGeogCS()
            transform = osr.CoordinateTransformation(src_crs, target_crs)
            desc = f"{src_crs.GetAuthorityCode(None)} -> {target_crs.GetAuthorityCode(None)}"
        return transform, desc
146+
147+
class SpatialTreeTables(TableSet):
    """Tables for associating a variable number of S2 cells with each feature."""

    def __init__(self):
        super().__init__()

        # "blobs" tracks all the features we have indexed (even if they do not overlap any s2 cells).
        self.blobs = Table(
            "blobs",
            self.sqlalchemy_metadata,
            # From a user-perspective, "rowid" is just an arbitrary integer primary key.
            # In more detail: This column aliases to the sqlite rowid of the table.
            # See https://www.sqlite.org/lang_createtable.html#rowid
            # Using the rowid directly as a foreign key (see "blob_cells") means faster joins.
            # The rowid can be used without creating a column that aliases to it, but you shouldn't -
            # rowids might change if they are not aliased. See https://sqlite.org/lang_vacuum.html)
            Column("rowid", Integer, nullable=False, primary_key=True),
            # "blob_id" is the git object ID (the SHA-1 hash) of a feature, in binary (20 bytes).
            # Is equivalent to 40 chars of hex eg: d08c3dd220eea08d8dfd6d4adb84f9936c541d7a
            Column("blob_id", BLOB, nullable=False, unique=True),
            sqlite_autoincrement=True,
        )

        # "blob_cells" associates 0 or more S2 cell tokens with each feature that we have indexed.
        self.blob_cells = Table(
            "blob_cells",
            self.sqlalchemy_metadata,
            # Reference to blobs.rowid.
            Column(
                "blob_rowid",
                Integer,
                ForeignKey("blobs.rowid"),
                nullable=False,
                primary_key=True,
            ),
            # S2 cell token eg "6d6dd90351b31cbf".
            # To locate an S2 cell by token, see https://s2.sidewalklabs.com/regioncoverer/
            # Composite primary key (blob_rowid, cell_token) - each pairing stored once.
            Column(
                "cell_token",
                Text,
                nullable=False,
                primary_key=True,
            ),
        )


# Promote the instance-level Table objects to class-level attributes (see TableSet).
SpatialTreeTables.copy_tables_to_class()
194+
195+
196+
def drop_tables(sess):
    """Drop the spatial-index tables if they exist - child table first, so the
    foreign key from blob_cells to blobs never dangles."""
    for table_name in ("blob_cells", "blobs"):
        sess.execute(f"DROP TABLE IF EXISTS {table_name};")
199+
200+
201+
def iter_feature_oids(repo, commit_spec):
    """
    Yield (dataset_path, feature_oid) for every feature blob reachable from the
    commits in commit_spec, by streaming `git rev-list --objects` output.

    Raises SubprocessError if git rev-list exits with a nonzero status.

    Note: the original version wrapped Popen in `except subprocess.CalledProcessError`,
    but Popen never raises that exception (only subprocess.run(check=True) does),
    so git failures were silently ignored. We now check the exit status explicitly
    once the output stream is exhausted.
    """
    cmd = _revlist_command(repo) + commit_spec
    p = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        encoding="utf8",
        env=tool_environment(),
    )
    try:
        yield from _parse_revlist_output(p.stdout, r"feature/.+")
    finally:
        p.stdout.close()
        retcode = p.wait()
    if retcode != 0:
        e = subprocess.CalledProcessError(retcode, cmd)
        raise SubprocessError(
            f"There was a problem with git rev-list: {e}", called_process_error=e
        )
215+
216+
217+
def update_spatial_tree(repo, commit_spec, verbosity=1, clear_existing=False):
    """
    Index the commits given in commit_spec, and write them to the s2_index.db repo file.

    repo - the Kart repo containing the commits to index, and in which to write the index file.
    commit_spec - a list of commits to index (ancestors of these are implicitly included).
        Commits can be excluded by prefixing with '^' (ancestors of these are implicitly excluded).
        (See git rev-list for the full list of possibilities for specifying commits).
    verbosity - how much non-essential information to output.
    clear_existing - when true, deletes any pre-existing data before re-indexing.
    """
    # Imported lazily: pywraps2 currently has to be manually installed into the kart venv.
    import pywraps2 as s2

    crs_helper = CrsHelper(repo)
    feature_oid_iter = iter_feature_oids(repo, commit_spec)

    s2_coverer = s2.S2RegionCoverer()
    s2_coverer.set_max_cells(S2_MAX_CELLS_INDEX)
    s2_coverer.set_max_level(S2_MAX_LEVEL)

    # Higher verbosity -> more frequent progress output (but never more often
    # than every 100 features).
    progress_every = None
    if verbosity >= 1:
        progress_every = max(100, 100_000 // (10 ** (verbosity - 1)))

    db_path = repo.gitdir_file(KartRepoFiles.S2_INDEX)
    engine = sqlite_engine(db_path)
    # Use sqlalchemy only for DDL - creating / dropping the tables.
    with sessionmaker(bind=engine)() as sess:
        if clear_existing:
            drop_tables(sess)

        SpatialTreeTables.create_all(sess)

    click.echo(f"Indexing {' '.join(commit_spec)} ...")
    t0 = time.monotonic()
    # Counts every feature seen by the loop (including skipped ones).
    # The original reported `i` at the end, which is one less than the number
    # of features actually processed (enumerate yields 0-based indices).
    feature_count = 0

    # Using sqlite directly here instead of sqlalchemy is about 10x faster.
    # Possibly due to huge number of unbatched queries.
    # TODO - investigate further.
    db = sqlite.connect(f"file:{db_path}", uri=True)
    with db:
        dbcur = db.cursor()

        for i, (ds_path, feature_oid) in enumerate(feature_oid_iter):
            feature_count = i + 1
            if i and progress_every and i % progress_every == 0:
                click.echo(f"  {i:,d} features... @{time.monotonic()-t0:.1f}s")

            # Features in datasets with no usable CRS can't be indexed.
            transforms = crs_helper.transforms_for_dataset(ds_path)
            if not transforms:
                continue
            geom = get_geometry(repo, feature_oid)
            if geom is None:
                continue
            try:
                s2_cell_tokens = find_s2_cells(s2_coverer, geom, transforms)
            except Exception as e:
                L.warning(f"Couldn't locate S2 cells for {feature_oid}:\n{e}")
                continue

            # Record the feature in "blobs" (idempotent - reuse an existing row).
            params = (bytes.fromhex(feature_oid),)
            row = dbcur.execute(
                "SELECT rowid FROM blobs WHERE blob_id = ?;", params
            ).fetchone()
            if row:
                rowid = row[0]
            else:
                dbcur.execute("INSERT INTO blobs (blob_id) VALUES (?);", params)
                rowid = dbcur.lastrowid

            if not s2_cell_tokens:
                continue

            # OR IGNORE keeps re-indexing idempotent (composite PK dedupes).
            params = [(rowid, token) for token in s2_cell_tokens]
            dbcur.executemany(
                "INSERT OR IGNORE INTO blob_cells (blob_rowid, cell_token) VALUES (?, ?);",
                params,
            )

    t1 = time.monotonic()
    click.echo(f"Indexed {feature_count} features in {t1-t0:.1f}s")
297+
298+
299+
# Sentinel: this feature's schema (legend) has no geometry column at all.
# Distinct from None, which get_geometry returns when the geometry value itself is absent.
NO_GEOMETRY_COLUMN = object()


def get_geometry(repo, feature_oid):
    """
    Return the Geometry field of the feature blob with the given OID,
    or None if there is no geometry column / no geometry value.
    """
    legend, fields = msg_unpack(repo[feature_oid])
    # Memoise the geometry column's index per legend (schema), so the field
    # list is only scanned once per distinct schema.
    col_id = get_geometry.legend_to_col_id.get(legend)
    if col_id is None:
        col_id = _find_geometry_column(fields)
        get_geometry.legend_to_col_id[legend] = col_id
    return fields[col_id] if col_id is not NO_GEOMETRY_COLUMN else None


# Function-attribute cache: legend -> geometry column index (or NO_GEOMETRY_COLUMN).
get_geometry.legend_to_col_id = {}
312+
313+
314+
def _find_geometry_column(fields):
    """Return the index of the first Geometry field in `fields`, or the
    NO_GEOMETRY_COLUMN sentinel if there isn't one."""
    return next(
        (index for index, field in enumerate(fields) if isinstance(field, Geometry)),
        NO_GEOMETRY_COLUMN,
    )
319+
320+
321+
def find_s2_cells(s2_coverer, geom, transforms):
    """
    Return the set of S2 cell tokens for `geom` under the given transforms.
    Points get the fast single-cell treatment; everything else is covered
    via its bounding envelope using s2_coverer.
    """
    if geom.geometry_type == GeometryType.POINT:
        return _point_f2_cells(s2_coverer, geom, transforms)
    return _general_s2_cells(s2_coverer, geom, transforms)
329+
330+
331+
def _apply_transform(original, transform, overwrite_original=False):
332+
if transform is None:
333+
return original
334+
result = original if overwrite_original else original.Clone()
335+
result.Transform(transform)
336+
return result
337+
338+
339+
# NOTE(review): "f2" in this name looks like a typo for "s2" - kept as-is since
# find_s2_cells calls it by this name; rename both together if fixing.
def _point_f2_cells(s2_coverer, geom, transforms):
    """
    Return the set of S2 cell tokens for a point geometry - one leaf cell per
    distinct transform. s2_coverer is accepted for signature symmetry with
    _general_s2_cells but is not used here.
    """
    import pywraps2 as s2

    g = gpkg_geom_to_ogr(geom)
    # With a single transform we can mutate g in place instead of cloning it.
    one_transform = len(transforms) == 1

    result = set()
    for transform in transforms:
        g_transformed = _apply_transform(g, transform, overwrite_original=one_transform)
        p = g_transformed.GetPoint()[:2]
        # S2 takes (lat, lng); OGR points are (x=lng, y=lat).
        s2_ll = s2.S2LatLng.FromDegrees(p[1], p[0]).Normalized()
        s2_token = s2.S2CellId(s2_ll.ToPoint()).ToToken()
        result.add(s2_token)

    return result
354+
355+
356+
def _general_s2_cells(s2_coverer, geom, transforms):
    """
    Return the set of S2 cell tokens covering `geom`'s bounding envelope,
    unioned over every transform in `transforms`. The geometry itself is
    approximated by its envelope - only the two corner points are reprojected.
    """
    import pywraps2 as s2

    e = geom_envelope(geom)
    if e is None:
        return ()  # Empty.

    # Envelope corners: assumes geom_envelope returns (min-x, max-x, min-y, max-y)
    # - TODO confirm against geom_envelope's contract.
    sw_src = e[0], e[2]
    ne_src = e[1], e[3]

    result = set()
    for transform in transforms:
        s2_ll = []
        for p_src in (sw_src, ne_src):
            # Build a throwaway OGR point so the corner can be reprojected.
            g = ogr.Geometry(ogr.wkbPoint)
            g.AddPoint(*p_src)
            _apply_transform(g, transform, overwrite_original=True)
            p_dest = g.GetPoint()[:2]
            # S2 takes (lat, lng); OGR points are (x=lng, y=lat).
            s2_ll.append(s2.S2LatLng.FromDegrees(p_dest[1], p_dest[0]).Normalized())

        s2_llrect = s2.S2LatLngRect.FromPointPair(*s2_ll)
        for s2_cell_id in s2_coverer.GetCovering(s2_llrect):
            result.add(s2_cell_id.ToToken())

    return result
381+
382+
383+
@add_help_subcommand
@click.group()
@click.pass_context
def spatial_tree(ctx, **kwargs):
    """
    Commands for maintaining an S2-cell based spatial index.

    (Group entry point only - all real work happens in the subcommands below.)
    """
390+
391+
392+
@spatial_tree.command()
@click.option(
    "--all",
    "index_all_commits",
    is_flag=True,
    default=False,
    help=("Index / re-index all existing commits"),
)
@click.option(
    "--clear-existing",
    is_flag=True,
    default=False,
    help=("Clear existing index before re-indexing"),
)
@click.argument(
    "commits",
    nargs=-1,
)
@click.pass_context
def index(ctx, index_all_commits, clear_existing, commits):
    """
    Indexes all features added by the supplied commits and all of their ancestors.
    To stop recursing at a particular ancestor or ancestors (eg to stop at a commit
    that has already been indexed) prefix that commit with a caret: ^COMMIT.
    """
    # Exactly one of --all / an explicit commit list must be supplied.
    if index_all_commits:
        if commits:
            raise click.UsageError("Can't supply both --all and commits to be indexed")
        commit_spec = ["--all"]
    else:
        if not commits:
            raise click.UsageError("No commits to be indexed were supplied")
        commit_spec = list(commits)

    # Indexing is read-only with respect to the working copy, so any repo state is fine.
    repo = ctx.obj.get_repo(allowed_states=KartRepoState.ALL_STATES)
    update_spatial_tree(
        repo,
        commit_spec,
        verbosity=ctx.obj.verbosity + 1,
        clear_existing=clear_existing,
    )

‎tests/test_spatial_tree.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import pytest
2+
3+
from kart.sqlalchemy.sqlite import sqlite_engine
4+
from sqlalchemy.orm import sessionmaker
5+
6+
H = pytest.helpers.helpers()
7+
8+
SKIP_REASON = "pywraps2 not yet included in kart"
9+
10+
11+
@pytest.mark.skip(reason=SKIP_REASON)
def test_index_points_all(data_archive, cli_runner):
    # Indexing --all should give the same results every time.
    # For points, every point should have only one long S2 cell token.
    with data_archive("points.tgz") as repo_path:
        r = cli_runner.invoke(["spatial-tree", "index", "--all"])
        assert r.exit_code == 0, r.stderr

        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 2148
        assert stats.avg_cell_tokens_per_feature == 1
        # Points get leaf-level cells, hence 16-char tokens.
        assert stats.avg_cell_token_length == 16
        # 2148 features but 2143 distinct tokens - a few points share a cell.
        assert stats.distinct_cell_tokens == 2143
24+
25+
26+
@pytest.mark.skip(reason=SKIP_REASON)
def test_index_points_commit_by_commit(data_archive, cli_runner):
    # Indexing one commit at a time should get the same results as indexing --all.
    with data_archive("points.tgz") as repo_path:
        # First index everything up to (and including) HEAD^ ...
        r = cli_runner.invoke(["spatial-tree", "index", H.POINTS.HEAD1_SHA])
        assert r.exit_code == 0, r.stderr
        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 2143

        # ... then just the range HEAD^..HEAD.
        r = cli_runner.invoke(
            ["spatial-tree", "index", H.POINTS.HEAD_SHA, "^" + H.POINTS.HEAD1_SHA]
        )
        assert r.exit_code == 0, r.stderr

        # Totals now match a single --all index run.
        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 2148
        assert stats.avg_cell_tokens_per_feature == 1
        assert stats.avg_cell_token_length == 16
        assert stats.distinct_cell_tokens == 2143
45+
46+
47+
@pytest.mark.skip(reason=SKIP_REASON)
def test_index_points_idempotent(data_archive, cli_runner):
    # Indexing the commits one at a time (and backwards) and then indexing --all should
    # also give the same result, even though everything will have been indexed twice.
    with data_archive("points.tgz") as repo_path:
        # Index the newest range first - only 5 features are in HEAD^..HEAD.
        r = cli_runner.invoke(
            ["spatial-tree", "index", H.POINTS.HEAD_SHA, "^" + H.POINTS.HEAD1_SHA]
        )
        assert r.exit_code == 0, r.stderr
        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 5

        # Then the older history.
        r = cli_runner.invoke(["spatial-tree", "index", H.POINTS.HEAD1_SHA])
        assert r.exit_code == 0, r.stderr
        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 2148

        # Re-indexing everything must not create duplicate rows.
        r = cli_runner.invoke(["spatial-tree", "index", "--all"])
        assert r.exit_code == 0, r.stderr
        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 2148
        assert stats.avg_cell_tokens_per_feature == 1
        assert stats.avg_cell_token_length == 16
        assert stats.distinct_cell_tokens == 2143
71+
72+
73+
@pytest.mark.skip(reason=SKIP_REASON)
def test_index_polygons_all(data_archive, cli_runner):
    # Unlike points, polygons are covered by several coarser cells each,
    # hence the fractional averages and shorter tokens below.
    with data_archive("polygons.tgz") as repo_path:
        r = cli_runner.invoke(["spatial-tree", "index", "--all"])
        assert r.exit_code == 0, r.stderr

        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 228
        assert stats.avg_cell_tokens_per_feature == pytest.approx(7.232, abs=0.001)
        assert stats.avg_cell_token_length == pytest.approx(8.066, abs=0.001)
        assert stats.distinct_cell_tokens == 1360
84+
85+
86+
@pytest.mark.skip(reason=SKIP_REASON)
def test_index_table_all(data_archive, cli_runner):
    # Indexing a non-spatial table should succeed but leave the index empty.
    with data_archive("table.tgz") as repo_path:
        r = cli_runner.invoke(["spatial-tree", "index", "--all"])
        assert r.exit_code == 0, r.stderr

        stats = _get_spatial_tree_stats(repo_path)
        assert stats.features == 0
        assert stats.cell_tokens == 0
95+
96+
97+
def _get_spatial_tree_stats(repo_path):
    """
    Open the S2 index database inside the repo at repo_path and return an
    object with summary stats: features, cell_tokens, and - only when non-zero
    counts make them meaningful - avg_cell_tokens_per_feature,
    avg_cell_token_length and distinct_cell_tokens.
    """
    class Stats:
        # Plain attribute bag - attributes are assigned below.
        pass

    stats = Stats()

    db_path = repo_path / ".kart" / "s2_index.db"
    engine = sqlite_engine(db_path)
    with sessionmaker(bind=engine)() as sess:
        # Integrity check: every blob_cells row must reference an existing blob.
        orphans = sess.execute(
            """
            SELECT blob_rowid FROM blob_cells
            EXCEPT SELECT rowid FROM blobs;
            """
        )
        assert orphans.first() is None

        stats.features = sess.scalar("SELECT COUNT(*) FROM blobs;")
        stats.cell_tokens = sess.scalar("SELECT COUNT(*) FROM blob_cells;")

        # Guard divisions / aggregates so an empty index still returns cleanly.
        if stats.features:
            stats.avg_cell_tokens_per_feature = stats.cell_tokens / stats.features

        if stats.cell_tokens:
            stats.avg_cell_token_length = sess.scalar(
                "SELECT AVG(LENGTH(cell_token)) FROM blob_cells;"
            )
            stats.distinct_cell_tokens = sess.scalar(
                "SELECT COUNT (DISTINCT cell_token) FROM blob_cells;"
            )

    return stats

0 commit comments

Comments
 (0)
Please sign in to comment.