Skip to content

Commit be0bc64

Browse files
committed
fix: typing
1 parent 14b3298 commit be0bc64

File tree

6 files changed

+86
-51
lines changed

6 files changed

+86
-51
lines changed

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include stactask/py.typed

pyproject.toml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[tool.mypy]
2+
strict = true
3+
4+
[[tool.mypy.overrides]]
5+
module = [
6+
"boto3utils",
7+
"jsonpath_ng.ext",
8+
"fsspec",
9+
]
10+
ignore_missing_imports = true

stactask/asset_io.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import logging
33
import os
44
from os import path as op
5-
from typing import Dict, List, Optional
5+
from typing import Any, Dict, List, Optional, Union
66

77
import fsspec
88
from boto3utils import s3
9+
from fsspec import AbstractFileSystem
910
from pystac import Item
1011
from pystac.layout import LayoutTemplate
1112

@@ -18,7 +19,7 @@
1819
sem = asyncio.Semaphore(SIMULTANEOUS_DOWNLOADS)
1920

2021

21-
async def download_file(fs, src, dest):
22+
async def download_file(fs: AbstractFileSystem, src: str, dest: str) -> None:
2223
async with sem:
2324
logger.debug(f"{src} start")
2425
await fs._get_file(src, dest)
@@ -32,8 +33,8 @@ async def download_item_assets(
3233
overwrite: bool = False,
3334
path_template: str = "${collection}/${id}",
3435
absolute_path: bool = False,
35-
**kwargs,
36-
):
36+
**kwargs: Any,
37+
) -> Item:
3738
_assets = item.assets.keys() if assets is None else assets
3839

3940
# determine path from template and item
@@ -76,44 +77,56 @@ async def download_item_assets(
7677
return new_item
7778

7879

79-
async def download_items_assets(items, **kwargs):
80+
async def download_items_assets(items: List[Item], **kwargs: Any) -> List[Item]:
8081
tasks = []
8182
for item in items:
8283
tasks.append(asyncio.create_task(download_item_assets(item, **kwargs)))
83-
new_items = await asyncio.gather(*tasks)
84+
new_items: List[Item] = await asyncio.gather(*tasks)
8485
return new_items
8586

8687

8788
def upload_item_assets_to_s3(
8889
item: Item,
8990
assets: Optional[List[str]] = None,
90-
public_assets: List[str] = [],
91+
public_assets: Union[None, List[str], str] = None,
9192
path_template: str = "${collection}/${id}",
9293
s3_urls: bool = False,
93-
headers: Dict = {},
94-
**kwargs,
95-
) -> Dict:
94+
headers: Optional[Dict[str, Any]] = None,
95+
**kwargs: Any,
96+
) -> Item:
9697
"""Upload Item assets to s3 bucket
9798
Args:
9899
item (Item): STAC Item
99100
assets (List[str], optional): List of asset keys to upload. Defaults to None.
100-
public_assets (List[str], optional): List of assets keys that should be public. Defaults to [].
101-
path_template (str, optional): Path string template. Defaults to '${collection}/${id}'.
102-
s3_urls (bool, optional): Return s3 URLs instead of http URLs. Defaults to False.
103-
headers (Dict, optional): Dictionary of headers to set on uploaded assets. Defaults to {},
101+
public_assets (List[str] or str, optional): List of asset keys that should be
102+
public, or the string "ALL" to make every asset public. Defaults to None.
103+
path_template (str, optional): Path string template. Defaults to
104+
'${collection}/${id}'.
105+
s3_urls (bool, optional): Return s3 URLs instead of http URLs. Defaults
106+
to False.
107+
headers (Dict, optional): Dictionary of headers to set on uploaded
108+
assets. Defaults to None, which is treated as an empty dictionary.
104109
Returns:
105110
Item: A new STAC Item with uploaded assets pointing to newly uploaded file URLs
106111
"""
112+
if headers is None:
113+
headers = {}
114+
107115
# deepcopy of item
108116
_item = item.to_dict()
109117

118+
if public_assets is None:
119+
public_assets = []
120+
# determine which assets should be public
121+
elif type(public_assets) is str:
122+
if public_assets == "ALL":
123+
public_assets = list(_item["assets"].keys())
124+
else:
125+
raise ValueError(f"unexpected value for `public_assets`: {public_assets}")
126+
110127
# if assets not provided, upload all assets
111128
_assets = assets if assets is not None else _item["assets"].keys()
112129

113-
# determine which assets should be public
114-
if type(public_assets) is str and public_assets == "ALL":
115-
public_assets = _item["assets"].keys()
116-
117130
for key in [a for a in _assets if a in _item["assets"].keys()]:
118131
asset = _item["assets"][key]
119132
filename = asset["href"]

stactask/py.typed

Whitespace-only changes.

stactask/task.py

+42-31
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from pathlib import Path
1212
from shutil import rmtree
1313
from tempfile import mkdtemp
14-
from typing import Any, Dict, List, Optional, Union
14+
from typing import Any, Callable, Dict, List, Optional, Union
1515

1616
import fsspec
17-
from pystac import ItemCollection
17+
from pystac import Item, ItemCollection
1818

1919
from .asset_io import (
2020
download_item_assets,
@@ -27,8 +27,9 @@
2727
# types
2828
PathLike = Union[str, Path]
2929
"""
30-
Tasks can use parameters provided in a `process` Dictionary that is supplied in the ItemCollection
31-
JSON under the "process" field. An example process definition:
30+
Tasks can use parameters provided in a `process` Dictionary that is supplied in
31+
the ItemCollection JSON under the "process" field. An example process
32+
definition:
3233
3334
```
3435
{
@@ -59,7 +60,7 @@ class Task(ABC):
5960

6061
def __init__(
6162
self: "Task",
62-
payload: Dict,
63+
payload: Dict[str, Any],
6364
workdir: Optional[PathLike] = None,
6465
save_workdir: bool = False,
6566
skip_upload: bool = False,
@@ -88,18 +89,18 @@ def __init__(
8889
self._workdir = Path(workdir)
8990
makedirs(self._workdir, exist_ok=True)
9091

91-
def __del__(self):
92+
def __del__(self) -> None:
9293
# remove work directory if not running locally
9394
if not self._save_workdir:
9495
self.logger.debug("Removing work directory %s", self._workdir)
9596
rmtree(self._workdir)
9697

9798
@property
98-
def process_definition(self) -> Dict:
99+
def process_definition(self) -> Dict[str, Any]:
99100
return self._payload.get("process", {})
100101

101102
@property
102-
def parameters(self) -> Dict:
103+
def parameters(self) -> Dict[str, Any]:
103104
task_configs = self.process_definition.get("tasks", [])
104105
if isinstance(task_configs, List):
105106
# tasks is a list
@@ -121,11 +122,11 @@ def parameters(self) -> Dict:
121122
raise ValueError(f"unexpected value for 'tasks': {task_configs}")
122123

123124
@property
124-
def upload_options(self) -> Dict:
125+
def upload_options(self) -> Dict[str, Any]:
125126
return self.process_definition.get("upload_options", {})
126127

127128
@property
128-
def items_as_dicts(self) -> List[Dict]:
129+
def items_as_dicts(self) -> List[Dict[str, Any]]:
129130
return self._payload.get("features", [])
130131

131132
@property
@@ -134,12 +135,12 @@ def items(self) -> ItemCollection:
134135
return ItemCollection.from_dict(items_dict, preserve_dict=True)
135136

136137
@classmethod
137-
def validate(cls, payload: Dict) -> bool:
138+
def validate(cls, payload: Dict[str, Any]) -> bool:
138139
# put validation logic on input Items and process definition here
139140
return True
140141

141142
@classmethod
142-
def add_software_version(cls, items: List[Dict]):
143+
def add_software_version(cls, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
143144
processing_ext = (
144145
"https://stac-extensions.github.io/processing/v1.1.0/schema.json"
145146
)
@@ -153,7 +154,7 @@ def add_software_version(cls, items: List[Dict]):
153154
i["properties"]["processing:software"] = {cls.name: cls.version}
154155
return items
155156

156-
def assign_collections(self):
157+
def assign_collections(self) -> None:
157158
"""Assigns new collection names based on"""
158159
for i, (coll, expr) in itertools.product(
159160
self._payload["features"],
@@ -163,13 +164,18 @@ def assign_collections(self):
163164
i["collection"] = coll
164165

165166
def download_item_assets(
166-
self, item: Dict, path_template: str = "${collection}/${id}", **kwargs
167-
):
168-
"""Download provided asset keys for all items in payload. Assets are saved in workdir in a
169-
directory named by the Item ID, and the items are updated with the new asset hrefs.
167+
self,
168+
item: Item,
169+
path_template: str = "${collection}/${id}",
170+
**kwargs: Any,
171+
) -> Item:
172+
"""Download provided asset keys for all items in payload. Assets are
173+
saved in workdir in a directory named by the Item ID, and the items are
174+
updated with the new asset hrefs.
170175
171176
Args:
172-
assets (Optional[List[str]], optional): List of asset keys to download. Defaults to all assets.
177+
assets (Optional[List[str]], optional): List of asset keys to
178+
download. Defaults to all assets.
173179
"""
174180
outdir = str(self._workdir / path_template)
175181
loop = asyncio.get_event_loop()
@@ -179,16 +185,21 @@ def download_item_assets(
179185
return item
180186

181187
def download_items_assets(
182-
self, items: List[Dict], path_template: str = "${collection}/${id}", **kwargs
183-
):
188+
self,
189+
items: List[Item],
190+
path_template: str = "${collection}/${id}",
191+
**kwargs: Any,
192+
) -> List[Item]:
184193
outdir = str(self._workdir / path_template)
185194
loop = asyncio.get_event_loop()
186195
items = loop.run_until_complete(
187-
download_items_assets(self.items, path_template=outdir, **kwargs)
196+
download_items_assets(items, path_template=outdir, **kwargs)
188197
)
189198
return items
190199

191-
def upload_item_assets_to_s3(self, item: Dict, assets: Optional[List[str]] = None):
200+
def upload_item_assets_to_s3(
201+
self, item: Item, assets: Optional[List[str]] = None
202+
) -> Item:
192203
if self._skip_upload:
193204
self.logger.warning("Skipping upload of new and modified assets")
194205
return item
@@ -197,7 +208,7 @@ def upload_item_assets_to_s3(self, item: Dict, assets: Optional[List[str]] = Non
197208

198209
# this should be in PySTAC
199210
@staticmethod
200-
def create_item_from_item(item):
211+
def create_item_from_item(item: Dict[str, Any]) -> Dict[str, Any]:
201212
new_item = deepcopy(item)
202213
# create a derived output item
203214
links = [
@@ -216,7 +227,7 @@ def create_item_from_item(item):
216227
return new_item
217228

218229
@abstractmethod
219-
def process(self, **kwargs) -> List[Dict]:
230+
def process(self, **kwargs: Any) -> List[Dict[str, Any]]:
220231
"""Main task logic - virtual
221232
222233
Returns:
@@ -229,7 +240,7 @@ def process(self, **kwargs) -> List[Dict]:
229240
pass
230241

231242
@classmethod
232-
def handler(cls, payload: Dict, **kwargs) -> Dict[str, Any]:
243+
def handler(cls, payload: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
233244
if "href" in payload or "url" in payload:
234245
# read input
235246
with fsspec.open(payload.get("href", payload.get("url"))) as f:
@@ -249,7 +260,7 @@ def handler(cls, payload: Dict, **kwargs) -> Dict[str, Any]:
249260
raise err
250261

251262
@classmethod
252-
def parse_args(cls, args):
263+
def parse_args(cls, args: List[str]) -> Dict[str, Any]:
253264
dhf = argparse.ArgumentDefaultsHelpFormatter
254265
parser0 = argparse.ArgumentParser(description=cls.description)
255266
parser0.add_argument(
@@ -297,8 +308,8 @@ def parse_args(cls, args):
297308
default=False,
298309
)
299310
h = """ Run local mode
300-
(save-workdir = True, skip-upload = True, skip-validation = True,
301-
workdir = 'local-output', output = 'local-output/output-payload.json') """
311+
(save-workdir = True, skip-upload = True, skip-validation = True,
312+
workdir = 'local-output', output = 'local-output/output-payload.json') """
302313
parser.add_argument("--local", help=h, action="store_true", default=False)
303314

304315
# turn Namespace into dictionary
@@ -322,7 +333,7 @@ def parse_args(cls, args):
322333
return pargs
323334

324335
@classmethod
325-
def cli(cls):
336+
def cli(cls) -> None:
326337
args = cls.parse_args(sys.argv[1:])
327338
cmd = args.pop("command")
328339

@@ -364,9 +375,9 @@ def cli(cls):
364375
from functools import wraps # noqa
365376

366377

367-
def silence_event_loop_closed(func):
378+
def silence_event_loop_closed(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
368379
@wraps(func)
369-
def wrapper(self, *args, **kwargs):
380+
def wrapper(self, *args: Any, **kwargs: Any) -> Any: # type: ignore
370381
try:
371382
return func(self, *args, **kwargs)
372383
except RuntimeError as e:

stactask/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Dict
1+
from typing import Any, Dict
22

33
from jsonpath_ng.ext import parser
44

55

6-
def stac_jsonpath_match(item: Dict, expr: str) -> bool:
6+
def stac_jsonpath_match(item: Dict[str, Any], expr: str) -> bool:
77
"""Match jsonpath expression against STAC JSON.
88
Use https://jsonpath.herokuapp.com/ to experiment with JSONpath
99
and https://regex101.com/ to experiment with regex

0 commit comments

Comments
 (0)