fix some typing smells #658

Open · wants to merge 1 commit into base: main
127 changes: 102 additions & 25 deletions mapchete/commands/convert.py
@@ -1,12 +1,8 @@
import logging
import os
from contextlib import AbstractContextManager
from multiprocessing import cpu_count
from pprint import pformat
from typing import List, Optional, Tuple, Type, Union
from typing import Callable, List, Optional, Tuple, Type, Union

import tilematrix
from rasterio.crs import CRS
from rasterio.enums import Resampling
from shapely.geometry import box
from shapely.geometry.base import BaseGeometry
@@ -17,12 +13,20 @@
from mapchete.config import DaskSettings
from mapchete.enums import Concurrency, DataType, ProcessingMode
from mapchete.errors import JobCancelledError
from mapchete.executor import Executor
from mapchete.executor import get_executor
from mapchete.executor.base import ExecutorBase
from mapchete.formats import available_output_formats
from mapchete.io import MPath, fiona_open, get_best_zoom_level
from mapchete.geometry import reproject_geometry
from mapchete.tile import BufferedTilePyramid
from mapchete.types import MPathLike, ResamplingLike, to_resampling
from mapchete.tile import BufferedTilePyramid, MetatilingValue
from mapchete.types import (
MPathLike,
ResamplingLike,
to_resampling,
ZoomLevelsLike,
BoundsLike,
CRSLike,
)

logger = logging.getLogger(__name__)
OUTPUT_FORMATS = available_output_formats()
@@ -31,11 +35,11 @@
def convert(
input_path: MPathLike,
output_path: MPathLike,
zoom: Optional[Union[int, List[int]]] = None,
zoom: Optional[ZoomLevelsLike] = None,
area: Optional[Union[BaseGeometry, str, dict]] = None,
area_crs: Optional[Union[CRS, str]] = None,
bounds: Optional[Tuple[float, float, float, float]] = None,
bounds_crs: Optional[Union[CRS, str]] = None,
area_crs: Optional[CRSLike] = None,
bounds: Optional[BoundsLike] = None,
bounds_crs: Optional[CRSLike] = None,
point: Optional[Tuple[float, float]] = None,
point_crs: Optional[Tuple[float, float]] = None,
overwrite: bool = False,
@@ -45,7 +49,7 @@ def convert(
clip_geometry: Optional[str] = None,
bidx: Optional[Union[List[int], int]] = None,
output_pyramid: Optional[Union[str, dict, MPathLike]] = None,
output_metatiling: Optional[int] = None,
output_metatiling: Optional[MetatilingValue] = None,
output_format: Optional[str] = None,
output_dtype: Optional[str] = None,
output_geometry_type: Optional[str] = None,
@@ -58,9 +62,9 @@
cog: bool = False,
src_fs_opts: Optional[dict] = None,
dst_fs_opts: Optional[dict] = None,
executor_getter: AbstractContextManager = Executor,
executor_getter: Callable[..., ExecutorBase] = get_executor,
observers: Optional[List[ObserverProtocol]] = None,
retry_on_exception: Tuple[Type[Exception], Type[Exception]] = Exception,
retry_on_exception: Union[Tuple[Type[Exception], ...], Type[Exception]] = Exception,
cancel_on_exception: Type[Exception] = JobCancelledError,
retries: int = 0,
) -> None:
@@ -90,15 +94,86 @@ def convert(
except Exception as e:
raise ValueError(e)

# try to read output grid definition from a file
if not (
isinstance(output_pyramid, str)
and output_pyramid in tilematrix._conf.PYRAMID_PARAMS.keys()
):
try:
output_pyramid = MPath.from_inp(output_pyramid).read_json() # type: ignore
except Exception: # pragma: no cover
pass
# process = "mapchete.processes.convert"
# input = dict(inp=input_path, clip=clip_geometry)

# if output_pyramid:
# # try to read output grid definition from a file
# if isinstance(output_pyramid, dict):
# _params = output_pyramid
# elif (
# isinstance(output_pyramid, str)
# and output_pyramid in tilematrix._conf.PYRAMID_PARAMS.keys()
# ):
# _params = output_pyramid
# else:
# _params = MPath.from_inp(output_pyramid).read_json()

# pyramid_config = PyramidConfig(
# grid=_params,
# metatiling=(
# output_metatiling
# or (
# input_info.output_pyramid.metatiling
# if input_info.output_pyramid
# else 1
# )
# ),
# pixelbuffer=(
# input_info.output_pyramid.pixelbuffer
# if input_info.output_pyramid
# else 0
# ),
# )
# elif input_info.output_pyramid:
# pyramid_config = input_info.output_pyramid.to_config()
# else:
# raise ValueError("Output pyramid required.")

# output = dict(
# {
# k: v
# for k, v in input_info.output_params.items()
# if k not in ["delimiters", "bounds", "mode"]
# },
# path=output_path,
# format=(
# output_format or output_info.driver or input_info.output_params["format"]
# ),
# dtype=output_dtype or input_info.output_params.get("dtype"),
# **creation_options,
# **(
# dict(overviews=True, overviews_resampling=overviews_resampling_method)
# if overviews
# else dict()
# ),
# )
# config_dir = MPath.cwd()
# zoom_levels = (
# zoom
# or input_info.zoom_levels
# or dict(
# min=0,
# max=get_best_zoom_level(input_path, pyramid_config.grid),
# )
# )

# process_parameters = dict(
# scale_ratio=scale_ratio,
# scale_offset=scale_offset,
# resampling=resampling_method,
# band_indexes=bidx,
# )

# process_config = ProcessConfig(
# process=process,
# input=input,
# output=output,
# pyramid=pyramid_config,
# config_dir=config_dir,
# zoom_levels=zoom_levels,
# process_parameters=process_parameters,
# )

# collect mapchete configuration
mapchete_config = dict(
@@ -148,7 +223,7 @@
else dict()
),
),
config_dir=os.getcwd(),
config_dir=MPath.cwd(),
zoom_levels=zoom or input_info.zoom_levels,
process_parameters=dict(
scale_ratio=scale_ratio,
@@ -169,8 +244,10 @@
]
if bidx is not None:
mapchete_config["output"].update(bands=len(bidx))

if mapchete_config["pyramid"] is None:
raise ValueError("Output pyramid required.")

elif mapchete_config["zoom_levels"] is None:
try:
mapchete_config.update(
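
For reference, a minimal usage sketch of the retyped convert() signature shown above. The paths, zoom range, bounds and grid name are illustrative assumptions rather than values from this PR; the point is that the new aliases keep accepting plain Python values:

from mapchete.commands.convert import convert

# all concrete values below are placeholders chosen for illustration
convert(
    "input.tif",                    # MPathLike
    "output_directory",             # MPathLike
    zoom=[0, 8],                    # ZoomLevelsLike: a single int or a [min, max] pair
    bounds=(0.0, 0.0, 10.0, 10.0),  # BoundsLike: (left, bottom, right, top)
    bounds_crs="EPSG:4326",         # CRSLike: a CRS object or an EPSG string
    output_pyramid="geodetic",
    output_metatiling=2,            # MetatilingValue instead of a bare Optional[int]
    output_format="GTiff",
)
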
43 changes: 24 additions & 19 deletions mapchete/commands/execute.py
@@ -1,11 +1,9 @@
"""Execute a process."""

import logging
from contextlib import AbstractContextManager
from multiprocessing import cpu_count
from typing import List, Optional, Tuple, Type, Union
from typing import Callable, List, Literal, Optional, Tuple, Type, Union

from rasterio.crs import CRS
from shapely.geometry.base import BaseGeometry

import mapchete
@@ -14,32 +12,36 @@
from mapchete.config.parse import bounds_from_opts, raw_conf, raw_conf_process_pyramid
from mapchete.enums import Concurrency, ProcessingMode, Status
from mapchete.errors import JobCancelledError
from mapchete.executor import Executor
from mapchete.executor import get_executor
from mapchete.executor.base import ExecutorBase
from mapchete.executor.concurrent_futures import MULTIPROCESSING_DEFAULT_START_METHOD
from mapchete.executor.types import Profiler
from mapchete.processing.profilers import preconfigured_profilers
from mapchete.processing.profilers.time import measure_time
from mapchete.types import MPathLike, Progress
from mapchete.types import MPathLike, Progress, BoundsLike, CRSLike, ZoomLevelsLike

logger = logging.getLogger(__name__)


def execute(
mapchete_config: Union[dict, MPathLike],
zoom: Optional[Union[int, List[int]]] = None,
zoom: Optional[ZoomLevelsLike] = None,
area: Optional[Union[BaseGeometry, str, dict]] = None,
area_crs: Optional[Union[CRS, str]] = None,
bounds: Optional[Tuple[float]] = None,
bounds_crs: Optional[Union[CRS, str]] = None,
area_crs: Optional[CRSLike] = None,
bounds: Optional[BoundsLike] = None,
bounds_crs: Optional[CRSLike] = None,
point: Optional[Tuple[float, float]] = None,
point_crs: Optional[Tuple[float, float]] = None,
tile: Optional[Tuple[int, int, int]] = None,
overwrite: bool = False,
mode: ProcessingMode = ProcessingMode.CONTINUE,
concurrency: Concurrency = Concurrency.none,
workers: int = None,
multiprocessing_start_method: str = None,
workers: Optional[int] = None,
multiprocessing_start_method: Literal[
"fork", "forkserver", "spawn"
] = MULTIPROCESSING_DEFAULT_START_METHOD,
dask_settings: DaskSettings = DaskSettings(),
executor_getter: AbstractContextManager = Executor,
executor_getter: Callable[..., ExecutorBase] = get_executor,
profiling: bool = False,
observers: Optional[List[ObserverProtocol]] = None,
retry_on_exception: Union[Tuple[Type[Exception], ...], Type[Exception]] = Exception,
@@ -85,7 +87,7 @@ def execute(
Reusable Client instance if required. Otherwise a new client will be created.
"""
try:
mode = "overwrite" if overwrite else mode
mode = ProcessingMode.OVERWRITE if overwrite else mode
all_observers = Observers(observers)

if not isinstance(retry_on_exception, tuple):
@@ -95,9 +97,11 @@
all_observers.notify(status=Status.parsing)

if tile:
tile = raw_conf_process_pyramid(raw_conf(mapchete_config)).tile(*tile)
bounds = tile.bounds
zoom = tile.zoom
buffered_tile = raw_conf_process_pyramid(raw_conf(mapchete_config)).tile(
*tile
)
bounds = buffered_tile.bounds
zoom = buffered_tile.zoom
else:
try:
bounds = bounds_from_opts(
@@ -109,6 +113,7 @@
)
except ValueError:
bounds = None
buffered_tile = None

# be careful opening mapchete not as context manager
with mapchete.open(
@@ -138,7 +143,7 @@
all_observers.notify(status=Status.initializing)

# determine tasks
tasks = mp.tasks(zoom=zoom, tile=tile)
tasks = mp.tasks(zoom=zoom, tile=buffered_tile)

if len(tasks) == 0:
all_observers.notify(
@@ -163,10 +168,10 @@
) as executor:
if profiling:
for profiler in preconfigured_profilers:
executor.add_profiler(profiler)
executor.add_profiler(profiler=profiler)
else:
executor.add_profiler(
Profiler(name="time", decorator=measure_time)
profiler=Profiler(name="time", decorator=measure_time)
)
all_observers.notify(
status=Status.running,
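
Likewise, a hedged sketch of calling the retyped execute(); the config path and parameter values are placeholders, while the keyword names and enum values come from the signature above:

from mapchete.commands.execute import execute
from mapchete.enums import Concurrency

execute(
    "process.mapchete",                    # dict or MPathLike; placeholder path
    zoom=[5, 9],                           # ZoomLevelsLike
    bounds=(10.0, 45.0, 12.0, 47.0),       # BoundsLike
    bounds_crs="EPSG:4326",                # CRSLike
    concurrency=Concurrency.processes,
    workers=4,                             # Optional[int] instead of an implicit int = None
    multiprocessing_start_method="spawn",  # constrained by the Literal type
)
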
8 changes: 1 addition & 7 deletions mapchete/config/models.py
@@ -16,7 +16,7 @@
from mapchete.types import BoundsLike, MPathLike, ZoomLevelsLike
from mapchete.validate import validate_values
from mapchete.zoom_levels import ZoomLevels
from mapchete.tile import MetatilingValue
from mapchete.tile import MetatilingValue, PyramidConfig


class OutputConfigBase(BaseModel):
@@ -25,12 +25,6 @@ class OutputConfigBase(BaseModel):
pixelbuffer: NonNegativeInt = 0


class PyramidConfig(BaseModel):
grid: Union[str, dict]
metatiling: MetatilingValue = 1
pixelbuffer: NonNegativeInt = 0


class DaskAdaptOptions(BaseModel):
minimum: int = 0
maximum: int = 20
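
PyramidConfig is no longer defined here; it is imported from mapchete.tile alongside MetatilingValue. A minimal construction sketch from its new location, assuming the relocated class keeps the fields the removed class declared (the grid name is an illustrative choice):

from mapchete.tile import PyramidConfig

pyramid = PyramidConfig(
    grid="geodetic",  # Union[str, dict]: a named grid or a full grid definition
    metatiling=2,     # MetatilingValue, defaults to 1
    pixelbuffer=0,    # NonNegativeInt, defaults to 0
)
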
47 changes: 23 additions & 24 deletions mapchete/executor/__init__.py
@@ -13,30 +13,29 @@
__all__ = ["MULTIPROCESSING_DEFAULT_START_METHOD", "MFuture"]


class Executor:
"""
Executor factory for dask and concurrent.futures executor
"""

def __new__(
cls, *args, concurrency: Optional[Concurrency] = None, **kwargs
) -> ExecutorBase:
concurrency = (
Concurrency.dask
if kwargs.get("dask_scheduler") or kwargs.get("dask_client")
else concurrency
def get_executor(
*args, concurrency: Optional[Concurrency] = None, **kwargs
) -> ExecutorBase:
concurrency = (
Concurrency.dask
if kwargs.get("dask_scheduler") or kwargs.get("dask_client")
else concurrency
)

if concurrency == Concurrency.dask:
return DaskExecutor(*args, **kwargs)

elif concurrency in [None, Concurrency.none]:
return SequentialExecutor(*args, **kwargs)

elif concurrency in [Concurrency.processes, Concurrency.threads]:
return ConcurrentFuturesExecutor(*args, concurrency=concurrency, **kwargs)

else: # pragma: no cover
raise ValueError(
f"concurrency must be one of None, 'processes', 'threads' or 'dask', not {concurrency}"
)

if concurrency == Concurrency.dask:
return DaskExecutor(*args, **kwargs)

elif concurrency in [None, Concurrency.none]:
return SequentialExecutor(*args, **kwargs)

elif concurrency in [Concurrency.processes, Concurrency.threads]:
return ConcurrentFuturesExecutor(*args, concurrency=concurrency, **kwargs)

else: # pragma: no cover
raise ValueError(
f"concurrency must be one of None, 'processes', 'threads' or 'dask', not {concurrency}"
)
# for backwards compatibility
Executor = get_executor
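
get_executor() keeps the dispatch rules of the old Executor factory class, and the module-level alias preserves the old import path. A short sketch of the dispatch, under the assumption that the executors can be constructed without further keyword arguments (passing dask_scheduler or dask_client would force the dask branch):

from mapchete.enums import Concurrency
from mapchete.executor import Executor, get_executor

sequential = get_executor()                               # None / Concurrency.none -> SequentialExecutor
threaded = get_executor(concurrency=Concurrency.threads)  # -> ConcurrentFuturesExecutor
legacy = Executor(concurrency=Concurrency.processes)      # old name still resolves via the alias
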
1 change: 1 addition & 0 deletions mapchete/io/raster/referenced_raster.py
@@ -73,6 +73,7 @@ def __init__(
bounds or array_bounds(self.height, self.width, self.transform)
)
self.__geo_interface__ = mapping(shape(self.bounds))
# logger.debug("%s array has %s", self, pretty_bytes(self.array.size))

@property
def meta(self) -> dict: