|
| 1 | +import functools |
| 2 | +import logging |
| 3 | +import re |
| 4 | +import subprocess |
| 5 | +import time |
| 6 | + |
| 7 | +import click |
| 8 | +from osgeo import osr, ogr |
| 9 | +from pysqlite3 import dbapi2 as sqlite |
| 10 | +from sqlalchemy import Column, ForeignKey, Integer, Table, Text |
| 11 | +from sqlalchemy.orm import sessionmaker |
| 12 | +from sqlalchemy.types import BLOB |
| 13 | + |
| 14 | + |
| 15 | +from .cli_util import add_help_subcommand, tool_environment |
| 16 | +from .crs_util import make_crs, normalise_wkt |
| 17 | +from .exceptions import SubprocessError |
| 18 | +from .geometry import Geometry, GeometryType, geom_envelope, gpkg_geom_to_ogr |
| 19 | +from .repo import KartRepoState, KartRepoFiles |
| 20 | +from .sqlalchemy import TableSet |
| 21 | +from .sqlalchemy.sqlite import sqlite_engine |
| 22 | +from .serialise_util import msg_unpack |
| 23 | + |
# Module-level logger for the spatial-tree indexing machinery.
L = logging.getLogger("kart.spatial_tree")


# Upper bound on the number of S2 cells used to approximate one geometry's
# coverage when indexing (passed to S2RegionCoverer.set_max_cells).
S2_MAX_CELLS_INDEX = 8
# Finest S2 cell level the coverer may descend to (passed to set_max_level).
S2_MAX_LEVEL = 15
| 29 | + |
| 30 | + |
| 31 | +def _revlist_command(repo): |
| 32 | + return [ |
| 33 | + "git", |
| 34 | + "-C", |
| 35 | + repo.path, |
| 36 | + "rev-list", |
| 37 | + "--objects", |
| 38 | + "--filter=object:type=blob", |
| 39 | + "--missing=allow-promisor", |
| 40 | + ] |
| 41 | + |
| 42 | + |
| 43 | +DS_PATH_PATTERN = r'(.+)/\.(sno|table)-dataset/' |
| 44 | + |
| 45 | + |
| 46 | +def _parse_revlist_output(line_iter, rel_path_pattern): |
| 47 | + full_path_pattern = re.compile(DS_PATH_PATTERN + rel_path_pattern) |
| 48 | + |
| 49 | + for line in line_iter: |
| 50 | + parts = line.split(" ", maxsplit=1) |
| 51 | + if len(parts) != 2: |
| 52 | + continue |
| 53 | + oid, path = parts |
| 54 | + |
| 55 | + m = full_path_pattern.match(path) |
| 56 | + if not m: |
| 57 | + continue |
| 58 | + ds_path = m.group(1) |
| 59 | + yield ds_path, oid |
| 60 | + |
| 61 | + |
class CrsHelper:
    """
    Loads all CRS definitions for a particular dataset, and creates
    transforms from each CRS to its geographic equivalent (lat/lon).

    Results are cached per instance, keyed by dataset path, CRS blob OID,
    and CRS WKT respectively.
    """

    def __init__(self, repo):
        self.repo = repo
        # dataset path -> list of transforms (None entries mean identity).
        self.ds_to_transforms = {}
        # Per-instance caches. These replace functools.lru_cache on the
        # methods below: lru_cache on an instance method keys on `self` and
        # keeps the instance alive for the cache's lifetime (ruff B019).
        self._oid_to_transform = {}
        self._wkt_to_transform = {}

    def transforms_for_dataset(self, ds_path):
        """Return (and cache) the list of CRS transforms for the dataset at ds_path."""
        transforms = self.ds_to_transforms.get(ds_path)
        if transforms is None:
            transforms = self._load_transforms_for_dataset(ds_path)
            self.ds_to_transforms[ds_path] = transforms
        return transforms

    def _load_transforms_for_dataset(self, ds_path):
        # Find every CRS definition this dataset has had in any commit, and
        # build one transform per distinct CRS.
        crs_oids = set(self.iter_crs_oids(ds_path))
        transforms = []
        descs = []
        for crs_oid in crs_oids:
            try:
                transform, desc = self.transform_from_oid(crs_oid)
                if transform not in transforms:
                    transforms.append(transform)
                    descs.append(desc)
            except Exception as e:
                # A broken / unparseable CRS shouldn't abort indexing of the
                # whole dataset - skip it with a warning.
                L.warning(
                    f"Couldn't load transform for CRS {crs_oid} at {ds_path}\n{e}"
                )
        L.info(f"Loaded CRS transforms for {ds_path}: {', '.join(descs)}")
        return transforms

    def iter_crs_oids(self, ds_path):
        """Yield the OID of every CRS blob ever attached to the dataset at ds_path."""
        cmd = [
            *_revlist_command(self.repo),
            "--all",
            "--",
            *self.all_crs_paths(ds_path),
        ]
        try:
            r = subprocess.run(
                cmd,
                encoding="utf8",
                check=True,
                capture_output=True,
                env=tool_environment(),
            )
        except subprocess.CalledProcessError as e:
            raise SubprocessError(
                f"There was a problem with git rev-list: {e}", called_process_error=e
            )
        for d, crs_oid in _parse_revlist_output(
            r.stdout.splitlines(), r"meta/crs/[^/]+"
        ):
            assert d == ds_path
            yield crs_oid

    def all_crs_paths(self, ds_path):
        """Yield every path at which CRS definitions may be stored for ds_path."""
        # Delete .sno-dataset if we drop V2 support.
        yield f"{ds_path}/.sno-dataset/meta/crs/"
        yield f"{ds_path}/.table-dataset/meta/crs/"

    def transform_from_oid(self, crs_oid):
        """Return (transform, description) for the CRS blob with the given OID (cached)."""
        result = self._oid_to_transform.get(crs_oid)
        if result is None:
            wkt = normalise_wkt(self.repo[crs_oid].data.decode("utf-8"))
            result = self.transform_from_wkt(wkt)
            self._oid_to_transform[crs_oid] = result
        return result

    def transform_from_wkt(self, wkt):
        """
        Return (transform, description) for the CRS with the given WKT (cached).

        For a geographic CRS the transform is None (coordinates are already
        lat/lon); otherwise it converts the CRS to its geographic base CRS.
        """
        result = self._wkt_to_transform.get(wkt)
        if result is not None:
            return result
        src_crs = make_crs(wkt)
        if src_crs.IsGeographic():
            transform = None
            desc = f"IDENTITY({src_crs.GetAuthorityCode(None)})"
        else:
            target_crs = src_crs.CloneGeogCS()
            transform = osr.CoordinateTransformation(src_crs, target_crs)
            desc = f"{src_crs.GetAuthorityCode(None)} -> {target_crs.GetAuthorityCode(None)}"
        result = (transform, desc)
        self._wkt_to_transform[wkt] = result
        return result
| 145 | + |
| 146 | + |
class SpatialTreeTables(TableSet):
    """Tables for associating a variable number of S2 cells with each feature."""

    def __init__(self):
        super().__init__()

        # "blobs" tracks all the features we have indexed (even if they do not overlap any s2 cells).
        self.blobs = Table(
            "blobs",
            self.sqlalchemy_metadata,
            # From a user-perspective, "rowid" is just an arbitrary integer primary key.
            # In more detail: This column aliases to the sqlite rowid of the table.
            # See https://www.sqlite.org/lang_createtable.html#rowid
            # Using the rowid directly as a foreign key (see "blob_cells") means faster joins.
            # The rowid can be used without creating a column that aliases to it, but you shouldn't -
            # rowids might change if they are not aliased. See https://sqlite.org/lang_vacuum.html)
            Column("rowid", Integer, nullable=False, primary_key=True),
            # "blob_id" is the git object ID (the SHA-1 hash) of a feature, in binary (20 bytes).
            # Is equivalent to 40 chars of hex eg: d08c3dd220eea08d8dfd6d4adb84f9936c541d7a
            Column("blob_id", BLOB, nullable=False, unique=True),
            # AUTOINCREMENT ensures rowids of deleted rows are never reused -
            # they may still be referenced from "blob_cells".
            sqlite_autoincrement=True,
        )

        # "blob_cells" associates 0 or more S2 cell tokens with each feature that we have indexed.
        self.blob_cells = Table(
            "blob_cells",
            self.sqlalchemy_metadata,
            # Reference to blobs.rowid.
            Column(
                "blob_rowid",
                Integer,
                ForeignKey("blobs.rowid"),
                nullable=False,
                primary_key=True,
            ),
            # S2 cell token eg "6d6dd90351b31cbf".
            # To locate an S2 cell by token, see https://s2.sidewalklabs.com/regioncoverer/
            Column(
                "cell_token",
                Text,
                nullable=False,
                primary_key=True,
            ),
        )
| 191 | + |
| 192 | + |
# NOTE(review): presumably promotes the instance Table attributes defined in
# __init__ onto the class itself - confirm against TableSet.copy_tables_to_class.
SpatialTreeTables.copy_tables_to_class()
| 194 | + |
| 195 | + |
def drop_tables(sess):
    """Drop the spatial-index tables if they exist (child table first,
    so the foreign-key reference is never left dangling)."""
    for table_name in ("blob_cells", "blobs"):
        sess.execute(f"DROP TABLE IF EXISTS {table_name};")
| 199 | + |
| 200 | + |
def iter_feature_oids(repo, commit_spec):
    """
    Yield (dataset_path, feature_oid) for every feature blob reachable from
    the commits named in commit_spec (a list of `git rev-list` arguments).

    Raises SubprocessError if git rev-list exits non-zero.
    """
    cmd = _revlist_command(repo) + commit_spec
    # Stream rev-list output rather than buffering it all - there may be
    # millions of features.
    p = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        encoding="utf8",
        env=tool_environment(),
    )
    try:
        yield from _parse_revlist_output(p.stdout, r"feature/.+")
    finally:
        # Always close the pipe and reap the child, even if the consumer
        # abandons this generator early.
        p.stdout.close()
        p.wait()
    # BUGFIX: Popen never raises CalledProcessError (only subprocess.run
    # with check=True does), so the old `except CalledProcessError` branch
    # was dead code and rev-list failures went unnoticed. Check explicitly.
    if p.returncode != 0:
        e = subprocess.CalledProcessError(p.returncode, cmd)
        raise SubprocessError(
            f"There was a problem with git rev-list: {e}", called_process_error=e
        )
| 215 | + |
| 216 | + |
def update_spatial_tree(repo, commit_spec, verbosity=1, clear_existing=False):
    """
    Index the commits given in commit_spec, and write them to the s2_index.db repo file.

    repo - the Kart repo containing the commits to index, and in which to write the index file.
    commit_spec - a list of commits to index (ancestors of these are implicitly included).
        Commits can be excluded by prefixing with '^' (ancestors of these are implicitly excluded).
        (See git rev-list for the full list of possibilities for specifying commits).
    verbosity - how much non-essential information to output.
    clear_existing - when true, deletes any pre-existing data before re-indexing.
    """
    import pywraps2 as s2

    crs_helper = CrsHelper(repo)
    feature_oid_iter = iter_feature_oids(repo, commit_spec)

    s2_coverer = s2.S2RegionCoverer()
    s2_coverer.set_max_cells(S2_MAX_CELLS_INDEX)
    s2_coverer.set_max_level(S2_MAX_LEVEL)

    # Higher verbosity -> more frequent progress output, but never more
    # often than every 100 features.
    progress_every = None
    if verbosity >= 1:
        progress_every = max(100, 100_000 // (10 ** (verbosity - 1)))

    db_path = repo.gitdir_file(KartRepoFiles.S2_INDEX)
    engine = sqlite_engine(db_path)
    with sessionmaker(bind=engine)() as sess:
        if clear_existing:
            drop_tables(sess)

        SpatialTreeTables.create_all(sess)

    click.echo(f"Indexing {' '.join(commit_spec)} ...")
    t0 = time.monotonic()
    feature_count = 0

    # Using sqlite directly here instead of sqlalchemy is about 10x faster.
    # Possibly due to huge number of unbatched queries.
    # TODO - investigate further.
    db = sqlite.connect(f"file:{db_path}", uri=True)
    try:
        # `with db` commits (or rolls back) the transaction - it does NOT
        # close the connection, hence the enclosing try/finally.
        with db:
            dbcur = db.cursor()

            for i, (ds_path, feature_oid) in enumerate(feature_oid_iter):
                # BUGFIX: track the true count - the old code reported `i`
                # (the last enumerate index) at the end, undercounting by one.
                feature_count = i + 1
                if i and progress_every and i % progress_every == 0:
                    click.echo(f" {i:,d} features... @{time.monotonic()-t0:.1f}s")

                transforms = crs_helper.transforms_for_dataset(ds_path)
                if not transforms:
                    # No usable CRS - we can't locate this feature on earth.
                    continue
                geom = get_geometry(repo, feature_oid)
                if geom is None:
                    continue
                try:
                    s2_cell_tokens = find_s2_cells(s2_coverer, geom, transforms)
                except Exception as e:
                    L.warning(f"Couldn't locate S2 cells for {feature_oid}:\n{e}")
                    continue

                # Record the feature in "blobs" (even with no cells), reusing
                # an existing row if this feature was indexed previously.
                params = (bytes.fromhex(feature_oid),)
                row = dbcur.execute(
                    "SELECT rowid FROM blobs WHERE blob_id = ?;", params
                ).fetchone()
                if row:
                    rowid = row[0]
                else:
                    dbcur.execute("INSERT INTO blobs (blob_id) VALUES (?);", params)
                    rowid = dbcur.lastrowid

                if not s2_cell_tokens:
                    continue

                params = [(rowid, token) for token in s2_cell_tokens]
                dbcur.executemany(
                    "INSERT OR IGNORE INTO blob_cells (blob_rowid, cell_token) VALUES (?, ?);",
                    params,
                )
    finally:
        # BUGFIX: the connection was previously never closed.
        db.close()

    t1 = time.monotonic()
    click.echo(f"Indexed {feature_count} features in {t1-t0:.1f}s")
| 297 | + |
| 298 | + |
# Sentinel meaning: this legend's features have no geometry column at all.
NO_GEOMETRY_COLUMN = object()


def get_geometry(repo, feature_oid):
    """Return the Geometry of the feature blob with the given OID, or None if it has none."""
    legend, fields = msg_unpack(repo[feature_oid])
    # The geometry column position is fixed per legend, so find it once per
    # legend and memoise the result.
    col_id = get_geometry.legend_to_col_id.get(legend)
    if col_id is None:
        col_id = _find_geometry_column(fields)
        get_geometry.legend_to_col_id[legend] = col_id
    if col_id is NO_GEOMETRY_COLUMN:
        return None
    return fields[col_id]


# Memoisation of legend -> geometry-column index (see get_geometry).
get_geometry.legend_to_col_id = {}
| 312 | + |
| 313 | + |
def _find_geometry_column(fields):
    """Return the index of the first Geometry value in fields, or NO_GEOMETRY_COLUMN."""
    return next(
        (i for i, field in enumerate(fields) if isinstance(field, Geometry)),
        NO_GEOMETRY_COLUMN,
    )
| 319 | + |
| 320 | + |
def find_s2_cells(s2_coverer, geom, transforms):
    """Return the set of S2 cell tokens covering geom, under each transform in transforms."""
    if geom.geometry_type == GeometryType.POINT:
        # A point needs no region covering - one exact cell per transform.
        return _point_f2_cells(s2_coverer, geom, transforms)
    return _general_s2_cells(s2_coverer, geom, transforms)
| 329 | + |
| 330 | + |
| 331 | +def _apply_transform(original, transform, overwrite_original=False): |
| 332 | + if transform is None: |
| 333 | + return original |
| 334 | + result = original if overwrite_original else original.Clone() |
| 335 | + result.Transform(transform) |
| 336 | + return result |
| 337 | + |
| 338 | + |
def _point_f2_cells(s2_coverer, geom, transforms):
    """Return the set of S2 cell tokens (one per transform) containing the point geom."""
    import pywraps2 as s2

    ogr_point = gpkg_geom_to_ogr(geom)
    # With exactly one transform we can safely transform the point in place.
    in_place = len(transforms) == 1

    tokens = set()
    for transform in transforms:
        transformed = _apply_transform(ogr_point, transform, overwrite_original=in_place)
        x, y = transformed.GetPoint()[:2]
        # S2LatLng takes (lat, lng) == (y, x).
        lat_lng = s2.S2LatLng.FromDegrees(y, x).Normalized()
        tokens.add(s2.S2CellId(lat_lng.ToPoint()).ToToken())

    return tokens
| 354 | + |
| 355 | + |
def _general_s2_cells(s2_coverer, geom, transforms):
    """
    Return the set of S2 cell tokens approximately covering geom's bounding
    envelope, under each transform in transforms.
    """
    import pywraps2 as s2

    envelope = geom_envelope(geom)
    if envelope is None:
        return ()  # Empty geometry - nothing to cover.

    # SW / NE corners of the envelope: x-range is envelope[0]..envelope[1],
    # y-range is envelope[2]..envelope[3].
    src_corners = ((envelope[0], envelope[2]), (envelope[1], envelope[3]))

    tokens = set()
    for transform in transforms:
        corners_ll = []
        for corner in src_corners:
            point = ogr.Geometry(ogr.wkbPoint)
            point.AddPoint(*corner)
            _apply_transform(point, transform, overwrite_original=True)
            x, y = point.GetPoint()[:2]
            # S2LatLng takes (lat, lng) == (y, x).
            corners_ll.append(s2.S2LatLng.FromDegrees(y, x).Normalized())

        rect = s2.S2LatLngRect.FromPointPair(*corners_ll)
        tokens.update(cell_id.ToToken() for cell_id in s2_coverer.GetCovering(rect))

    return tokens
| 381 | + |
| 382 | + |
@add_help_subcommand
@click.group()
@click.pass_context
def spatial_tree(ctx, **kwargs):
    """
    Commands for maintaining an S2-cell based spatial index.

    See the `index` subcommand for building or re-building the index.
    """
| 390 | + |
| 391 | + |
@spatial_tree.command()
@click.option(
    "--all",
    "index_all_commits",
    is_flag=True,
    default=False,
    help=("Index / re-index all existing commits"),
)
@click.option(
    "--clear-existing",
    is_flag=True,
    default=False,
    help=("Clear existing index before re-indexing"),
)
@click.argument(
    "commits",
    nargs=-1,
)
@click.pass_context
def index(ctx, index_all_commits, clear_existing, commits):
    """
    Indexes all features added by the supplied commits and all of their ancestors.
    To stop recursing at a particular ancestor or ancestors (eg to stop at a commit
    that has already been indexed) prefix that commit with a caret: ^COMMIT.
    """
    # --all and an explicit commit list are mutually exclusive, and exactly
    # one of them must be supplied.
    if index_all_commits:
        if commits:
            raise click.UsageError("Can't supply both --all and commits to be indexed")
        commit_spec = ["--all"]
    else:
        if not commits:
            raise click.UsageError("No commits to be indexed were supplied")
        commit_spec = list(commits)

    repo = ctx.obj.get_repo(allowed_states=KartRepoState.ALL_STATES)
    update_spatial_tree(
        repo,
        commit_spec,
        verbosity=ctx.obj.verbosity + 1,
        clear_existing=clear_existing,
    )
0 commit comments