Skip to content

Commit ac65980

Browse files
authored
Clean up typing (#64)
* clean up typing * fix tests * lint
1 parent 88cc054 commit ac65980

10 files changed

+60
-44
lines changed

jupyter_events/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def main():
5252
@click.command()
5353
@click.argument("schema")
5454
@click.pass_context
55-
def validate(ctx: click.Context, schema: str):
55+
def validate(ctx: click.Context, schema: str) -> int:
5656
"""Validate a SCHEMA against Jupyter Event's meta schema.
5757
5858
SCHEMA can be a JSON/YAML string or filepath to a schema.

jupyter_events/logger.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import logging
99
import warnings
1010
from datetime import datetime
11-
from pathlib import PurePath
12-
from typing import Callable, Optional, Union
11+
from typing import Any, Callable, Coroutine, Optional, Union
1312

1413
from jsonschema import ValidationError
1514
from pythonjsonlogger import jsonlogger # type:ignore
1615
from traitlets import Dict, Instance, Set, default
1716
from traitlets.config import Config, LoggingConfigurable
1817

18+
from .schema import SchemaType
1919
from .schema_registry import SchemaRegistry
2020
from .traits import Handlers
2121
from .validators import JUPYTER_EVENTS_CORE_VALIDATOR
@@ -131,7 +131,7 @@ def get_handlers():
131131
eventlogger_cfg = Config({"EventLogger": my_cfg})
132132
super()._load_config(eventlogger_cfg, section_names=None, traits=None)
133133

134-
def register_event_schema(self, schema: Union[dict, str, PurePath]):
134+
def register_event_schema(self, schema: SchemaType) -> None:
135135
"""Register this schema with the schema registry.
136136
137137
Get this registered schema using the EventLogger.schema.get() method.
@@ -143,7 +143,7 @@ def register_event_schema(self, schema: Union[dict, str, PurePath]):
143143
self._modified_listeners[key] = set()
144144
self._unmodified_listeners[key] = set()
145145

146-
def register_handler(self, handler: logging.Handler):
146+
def register_handler(self, handler: logging.Handler) -> None:
147147
"""Register a new logging handler to the Event Logger.
148148
149149
All outgoing messages will be formatted as a JSON string.
@@ -164,7 +164,7 @@ def _skip_message(record, **kwargs):
164164
if handler not in self.handlers:
165165
self.handlers.append(handler)
166166

167-
def remove_handler(self, handler: logging.Handler):
167+
def remove_handler(self, handler: logging.Handler) -> None:
168168
"""Remove a logging handler from the logger and list of handlers."""
169169
self._logger.removeHandler(handler)
170170
if handler in self.handlers:
@@ -175,7 +175,7 @@ def add_modifier(
175175
*,
176176
schema_id: Union[str, None] = None,
177177
modifier: Callable[[str, dict], dict],
178-
):
178+
) -> None:
179179
"""Add a modifier (callable) to a registered event.
180180
181181
Parameters
@@ -249,8 +249,8 @@ def add_listener(
249249
*,
250250
modified: bool = True,
251251
schema_id: Union[str, None] = None,
252-
listener: Callable[["EventLogger", str, dict], None],
253-
):
252+
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
253+
) -> None:
254254
"""Add a listener (callable) to a registered event.
255255
256256
Parameters
@@ -304,7 +304,7 @@ def remove_listener(
304304
self,
305305
*,
306306
schema_id: Optional[str] = None,
307-
listener: Callable[["EventLogger", str, dict], None],
307+
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
308308
) -> None:
309309
"""Remove a listener from an event or all events.
310310
@@ -327,7 +327,9 @@ def remove_listener(
327327
self._modified_listeners[schema_id].discard(listener)
328328
self._unmodified_listeners[schema_id].discard(listener)
329329

330-
def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
330+
def emit(
331+
self, *, schema_id: str, data: dict, timestamp_override: Optional[datetime] = None
332+
) -> Optional[dict]:
331333
"""
332334
Record given event with schema has occurred.
333335
@@ -351,7 +353,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
351353
and not self._modified_listeners[schema_id]
352354
and not self._unmodified_listeners[schema_id]
353355
):
354-
return
356+
return None
355357

356358
# If the schema hasn't been registered, raise a warning to make sure
357359
# this was intended.
@@ -362,7 +364,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
362364
"`register_event_schema` method.",
363365
SchemaNotRegistered,
364366
)
365-
return
367+
return None
366368

367369
schema = self.schemas.get(schema_id)
368370

@@ -400,7 +402,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
400402

401403
# callback for removing from finished listeners
402404
# from active listeners set.
403-
def _listener_task_done(task: asyncio.Task):
405+
def _listener_task_done(task: asyncio.Task) -> None:
404406
# If an exception happens, log it to the main
405407
# applications logger
406408
err = task.exception()
@@ -429,7 +431,7 @@ def _listener_task_done(task: asyncio.Task):
429431
self._active_listeners.add(task)
430432

431433
# Remove task from active listeners once its finished.
432-
def _listener_task_done(task: asyncio.Task):
434+
def _listener_task_done(task: asyncio.Task) -> None:
433435
# If an exception happens, log it to the main
434436
# applications logger
435437
err = task.exception()

jupyter_events/schema.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Event schema objects."""
22
import json
33
from pathlib import Path, PurePath
4-
from typing import Type, Union
4+
from typing import Optional, Type, Union
55

6-
from jsonschema import FormatChecker, validators
6+
from jsonschema import FormatChecker, RefResolver, validators
77

88
try:
99
from jsonschema.protocols import Validator
@@ -34,6 +34,9 @@ class EventSchemaFileAbsent(Exception): # noqa
3434
pass
3535

3636

37+
SchemaType = Union[dict, str, PurePath]
38+
39+
3740
class EventSchema:
3841
"""A validated schema that can be used.
3942
@@ -58,10 +61,10 @@ class EventSchema:
5861

5962
def __init__(
6063
self,
61-
schema: Union[dict, str, PurePath],
62-
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore
64+
schema: SchemaType,
65+
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore[assignment]
6366
format_checker: FormatChecker = draft7_format_checker,
64-
resolver=None,
67+
resolver: Optional[RefResolver] = None,
6568
):
6669
"""Initialize an event schema."""
6770
_schema = self._load_schema(schema)
@@ -76,29 +79,29 @@ def __repr__(self):
7679
return json.dumps(self._schema, indent=2)
7780

7881
@staticmethod
79-
def _ensure_yaml_loaded(schema, was_str=False) -> None:
82+
def _ensure_yaml_loaded(schema: SchemaType, was_str: bool = False) -> None:
8083
"""Ensures schema was correctly loaded into a dictionary. Raises
8184
EventSchemaLoadingError otherwise."""
8285
if isinstance(schema, dict):
8386
return
8487

8588
error_msg = "Could not deserialize schema into a dictionary."
8689

87-
def intended_as_path(schema):
90+
def intended_as_path(schema: str) -> bool:
8891
path = Path(schema)
8992
return path.match("*.yml") or path.match("*.yaml") or path.match("*.json")
9093

9194
# detect whether the user specified a string but intended a PurePath to
9295
# generate a more helpful error message
93-
if was_str and intended_as_path(schema):
96+
if was_str and intended_as_path(schema): # type:ignore[arg-type]
9497
error_msg += " Paths to schema files must be explicitly wrapped in a Pathlib object."
9598
else:
9699
error_msg += " Double check the schema and ensure it is in the proper form."
97100

98101
raise EventSchemaLoadingError(error_msg)
99102

100103
@staticmethod
101-
def _load_schema(schema: Union[dict, str, PurePath]) -> dict:
104+
def _load_schema(schema: SchemaType) -> dict:
102105
"""Load a JSON schema from different sources/data types.
103106
104107
`schema` could be a dictionary or serialized string representing the

jupyter_events/schema_registry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ def __init__(self, schemas: Optional[dict] = None):
1515
"""Initialize the registry."""
1616
self._schemas = schemas or {}
1717

18-
def __contains__(self, key: str):
18+
def __contains__(self, key: str) -> bool:
1919
"""Syntax sugar to check if a schema is found in the registry"""
2020
return key in self._schemas
2121

2222
def __repr__(self) -> str:
2323
"""The str repr of the registry."""
2424
return ",\n".join([str(s) for s in self._schemas.values()])
2525

26-
def _add(self, schema_obj: EventSchema):
26+
def _add(self, schema_obj: EventSchema) -> None:
2727
if schema_obj.id in self._schemas:
2828
msg = (
2929
f"The schema, {schema_obj.id}, is already "

jupyter_events/validators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747

4848

49-
def validate_schema(schema: dict):
49+
def validate_schema(schema: dict) -> None:
5050
"""Validate a schema dict."""
5151
try:
5252
# Validate the schema against Jupyter Events metaschema.

pyproject.toml

+14
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,20 @@ exclude_lines = [
133133
"@(abc\\.)?abstractmethod",
134134
]
135135

136+
[tool.mypy]
137+
check_untyped_defs = true
138+
disallow_incomplete_defs = true
139+
no_implicit_optional = true
140+
pretty = true
141+
show_error_context = true
142+
show_error_codes = true
143+
strict_equality = true
144+
warn_unused_configs = true
145+
warn_unused_ignores = true
146+
warn_redundant_casts = true
147+
explicit_package_bases = true
148+
namespace_packages = true
149+
136150
[tool.black]
137151
line-length = 100
138152
skip-string-normalization = true

tests/test_listeners.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ def jp_event_schemas(schema):
2323

2424
async def test_listener_function(jp_event_logger, schema):
2525
event_logger = jp_event_logger
26-
global listener_was_called
2726
listener_was_called = False
2827

2928
async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
30-
global listener_was_called
31-
listener_was_called = True # type: ignore
29+
nonlocal listener_was_called
30+
listener_was_called = True
3231

3332
# Add the modifier
3433
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
@@ -41,12 +40,11 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
4140

4241
async def test_remove_listener_function(jp_event_logger, schema):
4342
event_logger = jp_event_logger
44-
global listener_was_called
4543
listener_was_called = False
4644

4745
async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
48-
global listener_was_called
49-
listener_was_called = True # type: ignore
46+
nonlocal listener_was_called
47+
listener_was_called = True
5048

5149
# Add the modifier
5250
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
@@ -114,15 +112,14 @@ async def test_bad_listener_does_not_break_good_listener(jp_event_logger, schema
114112
h = logging.StreamHandler(log_stream)
115113
app_log.addHandler(h)
116114

117-
global listener_was_called
118115
listener_was_called = False
119116

120117
async def listener_raise_exception(logger: EventLogger, schema_id: str, data: dict) -> None:
121118
raise Exception("This failed") # noqa
122119

123120
async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
124-
global listener_was_called
125-
listener_was_called = True # type: ignore
121+
nonlocal listener_was_called
122+
listener_was_called = True
126123

127124
# Add a bad listener and a good listener and ensure that
128125
# emitting still works and the bad listener's exception is is logged.

tests/test_modifiers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,20 @@ def redact(self, schema_id: str, data: dict) -> dict:
5454
assert output["username"] == "<masked>"
5555

5656

57-
def test_bad_modifier_functions(jp_event_logger, schema: EventSchema):
57+
def test_bad_modifier_functions(jp_event_logger: EventLogger, schema: EventSchema) -> None:
5858
event_logger = jp_event_logger
5959

6060
def modifier_with_extra_args(schema_id: str, data: dict, unknown_arg: dict) -> dict:
6161
return data
6262

6363
with pytest.raises(ModifierError):
64-
event_logger.add_modifier(modifier=modifier_with_extra_args)
64+
event_logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]
6565

6666
# Ensure no modifier was added.
6767
assert len(event_logger._modifiers[schema.id]) == 0
6868

6969

70-
def test_bad_modifier_method(jp_event_logger, schema: EventSchema):
70+
def test_bad_modifier_method(jp_event_logger: EventLogger, schema: EventSchema) -> None:
7171
event_logger = jp_event_logger
7272

7373
class Redactor:
@@ -77,7 +77,7 @@ def redact(self, schema_id: str, data: dict, extra_args: dict) -> dict:
7777
redactor = Redactor()
7878

7979
with pytest.raises(ModifierError):
80-
event_logger.add_modifier(modifier=redactor.redact)
80+
event_logger.add_modifier(modifier=redactor.redact) # type:ignore[arg-type]
8181

8282
# Ensure no modifier was added
8383
assert len(event_logger._modifiers[schema.id]) == 0
@@ -90,7 +90,7 @@ def modifier_with_extra_args(event):
9090
return event
9191

9292
with pytest.raises(ModifierError):
93-
logger.add_modifier(modifier=modifier_with_extra_args)
93+
logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]
9494

9595

9696
def test_remove_modifier(schema, jp_event_logger, jp_read_emitted_events):

tests/test_schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_string_intended_as_path():
5959
def test_unrecognized_type():
6060
"""Validation fails because file is not of valid type."""
6161
with pytest.raises(EventSchemaUnrecognized):
62-
EventSchema(9001)
62+
EventSchema(9001) # type:ignore[arg-type]
6363

6464

6565
def test_invalid_yaml():

tests/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def get_event_data(event, schema, schema_id, version, unredacted_policies):
1616
handler = logging.StreamHandler(sink)
1717

1818
e = EventLogger(handlers=[handler], unredacted_policies=unredacted_policies)
19-
e.register_schema(schema)
19+
e.register_event_schema(schema)
2020

2121
# Record event and read output
22-
e.emit(schema_id, version, deepcopy(event))
22+
e.emit(schema_id=schema_id, data=deepcopy(event))
2323

2424
recorded_event = json.loads(sink.getvalue())
2525
return {key: value for key, value in recorded_event.items() if not key.startswith("__")}

0 commit comments

Comments
 (0)