Skip to content

Commit 67de947

Browse files
authored
chore(types): Type-clean /actions (189 errors) (#1361)
Type-cleaned all files under `nemoguard/actions` and added them to pyright pre-commit hooks so type-coverage doesn't regress.
1 parent 77de2a8 commit 67de947

18 files changed

+660
-198
lines changed

nemoguardrails/actions/action_dispatcher.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import inspect
2020
import logging
2121
import os
22+
from importlib.machinery import ModuleSpec
2223
from pathlib import Path
23-
from typing import Any, Dict, List, Optional, Tuple, Union
24+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
2425

2526
from langchain.chains.base import Chain
2627
from langchain_core.runnables import Runnable
@@ -51,7 +52,7 @@ def __init__(
5152
"""
5253
log.info("Initializing action dispatcher")
5354

54-
self._registered_actions = {}
55+
self._registered_actions: Dict[str, Union[Type, Callable[..., Any]]] = {}
5556

5657
if load_all_actions:
5758
# TODO: check for better way to find actions dir path or use constants.py
@@ -78,9 +79,12 @@ def __init__(
7879
# Last, but not least, if there was a config path, we try to load actions
7980
# from there as well.
8081
if config_path:
81-
config_path = config_path.split(",")
82-
for path in config_path:
83-
self.load_actions_from_path(Path(path.strip()))
82+
split_config_path: List[str] = config_path.split(",")
83+
84+
# Don't load actions if we have an empty list
85+
if split_config_path:
86+
for path in split_config_path:
87+
self.load_actions_from_path(Path(path.strip()))
8488

8589
# If there are any imported paths, we load the actions from there as well.
8690
if import_paths:
@@ -120,26 +124,28 @@ def load_actions_from_path(self, path: Path):
120124
)
121125

122126
def register_action(
123-
self, action: callable, name: Optional[str] = None, override: bool = True
127+
self, action: Callable, name: Optional[str] = None, override: bool = True
124128
):
125129
"""Registers an action with the given name.
126130
127131
Args:
128-
action (callable): The action function.
132+
action (Callable): The action function.
129133
name (Optional[str]): The name of the action. Defaults to None.
130134
override (bool): If an action already exists, whether it should be overridden or not.
131135
"""
132136
if name is None:
133137
action_meta = getattr(action, "action_meta", None)
134-
name = action_meta["name"] if action_meta else action.__name__
138+
action_name = action_meta["name"] if action_meta else action.__name__
139+
else:
140+
action_name = name
135141

136142
# If we're not allowed to override, we stop.
137-
if name in self._registered_actions and not override:
143+
if action_name in self._registered_actions and not override:
138144
return
139145

140-
self._registered_actions[name] = action
146+
self._registered_actions[action_name] = action
141147

142-
def register_actions(self, actions_obj: any, override: bool = True):
148+
def register_actions(self, actions_obj: Any, override: bool = True):
143149
"""Registers all the actions from the given object.
144150
145151
Args:
@@ -167,7 +173,7 @@ def has_registered(self, name: str) -> bool:
167173
name = self._normalize_action_name(name)
168174
return name in self.registered_actions
169175

170-
def get_action(self, name: str) -> callable:
176+
def get_action(self, name: str) -> Optional[Callable]:
171177
"""Get the registered action by name.
172178
173179
Args:
@@ -181,7 +187,7 @@ def get_action(self, name: str) -> callable:
181187

182188
async def execute_action(
183189
self, action_name: str, params: Dict[str, Any]
184-
) -> Tuple[Union[str, Dict[str, Any]], str]:
190+
) -> Tuple[Union[Optional[str], Dict[str, Any]], str]:
185191
"""Execute a registered action.
186192
187193
Args:
@@ -195,16 +201,21 @@ async def execute_action(
195201
action_name = self._normalize_action_name(action_name)
196202

197203
if action_name in self._registered_actions:
198-
log.info(f"Executing registered action: {action_name}")
199-
fn = self._registered_actions.get(action_name, None)
204+
log.info("Executing registered action: %s", action_name)
205+
maybe_fn: Optional[Callable] = self._registered_actions.get(
206+
action_name, None
207+
)
208+
if not maybe_fn:
209+
raise Exception(f"Action '{action_name}' is not registered.")
200210

211+
fn = cast(Callable, maybe_fn)
201212
# Actions that are registered as classes are initialized lazy, when
202213
# they are first used.
203214
if inspect.isclass(fn):
204215
fn = fn()
205216
self._registered_actions[action_name] = fn
206217

207-
if fn is not None:
218+
if fn:
208219
try:
209220
# We support both functions and classes as actions
210221
if inspect.isfunction(fn) or inspect.ismethod(fn):
@@ -245,7 +256,17 @@ async def execute_action(
245256
result = await runnable.ainvoke(input=params)
246257
else:
247258
# TODO: there should be a common base class here
248-
result = fn.run(**params)
259+
fn_run_func = getattr(fn, "run", None)
260+
if not callable(fn_run_func):
261+
raise Exception(
262+
f"No 'run' method defined for action '{action_name}'."
263+
)
264+
265+
fn_run_func_with_signature = cast(
266+
Callable[[], Union[Optional[str], Dict[str, Any]]],
267+
fn_run_func,
268+
)
269+
result = fn_run_func_with_signature(**params)
249270
return result, "success"
250271

251272
# We forward LLM Call exceptions
@@ -288,6 +309,7 @@ def _load_actions_from_module(filepath: str):
288309
"""
289310
action_objects = {}
290311
filename = os.path.basename(filepath)
312+
module = None
291313

292314
if not os.path.isfile(filepath):
293315
log.error(f"{filepath} does not exist or is not a file.")
@@ -298,13 +320,16 @@ def _load_actions_from_module(filepath: str):
298320
log.debug(f"Analyzing file {filename}")
299321
# Import the module from the file
300322

301-
spec = importlib.util.spec_from_file_location(filename, filepath)
302-
if spec is None:
323+
spec: Optional[ModuleSpec] = importlib.util.spec_from_file_location(
324+
filename, filepath
325+
)
326+
if not spec:
303327
log.error(f"Failed to create a module spec from {filepath}.")
304328
return action_objects
305329

306330
module = importlib.util.module_from_spec(spec)
307-
spec.loader.exec_module(module)
331+
if spec.loader:
332+
spec.loader.exec_module(module)
308333

309334
# Loop through all members in the module and check for the `@action` decorator
310335
# If class has action decorator is_action class member is true
@@ -313,19 +338,25 @@ def _load_actions_from_module(filepath: str):
313338
obj, "action_meta"
314339
):
315340
try:
316-
action_objects[obj.action_meta["name"]] = obj
317-
log.info(f"Added {obj.action_meta['name']} to actions")
341+
actionable_name: str = getattr(obj, "action_meta").get("name")
342+
action_objects[actionable_name] = obj
343+
log.info(f"Added {actionable_name} to actions")
318344
except Exception as e:
319345
log.error(
320-
f"Failed to register {obj.action_meta['name']} in action dispatcher due to exception {e}"
346+
f"Failed to register {name} in action dispatcher due to exception {e}"
321347
)
322348
except Exception as e:
349+
if module is None:
350+
raise RuntimeError(f"Failed to load actions from module at {filepath}.")
351+
if not module.__file__:
352+
raise RuntimeError(f"No file found for module {module} at {filepath}.")
353+
323354
try:
324355
relative_filepath = Path(module.__file__).relative_to(Path.cwd())
325356
except ValueError:
326357
relative_filepath = Path(module.__file__).resolve()
327358
log.error(
328-
f"Failed to register {filename} from {relative_filepath} in action dispatcher due to exception: {e}"
359+
f"Failed to register {filename} in action dispatcher due to exception: {e}"
329360
)
330361

331362
return action_objects

nemoguardrails/actions/actions.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,42 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass, field
17-
from typing import Any, Callable, List, Optional, TypedDict, Union
18-
19-
20-
class ActionMeta(TypedDict, total=False):
17+
from typing import (
18+
Any,
19+
Callable,
20+
List,
21+
Optional,
22+
Protocol,
23+
Type,
24+
TypedDict,
25+
TypeVar,
26+
Union,
27+
cast,
28+
)
29+
30+
31+
class ActionMeta(TypedDict):
2132
name: str
2233
is_system_action: bool
2334
execute_async: bool
2435
output_mapping: Optional[Callable[[Any], bool]]
2536

2637

38+
# Create a TypeVar to represent the decorated function or class
39+
T = TypeVar("T", bound=Union[Callable[..., Any], Type[Any]])
40+
41+
2742
def action(
2843
is_system_action: bool = False,
2944
name: Optional[str] = None,
3045
execute_async: bool = False,
3146
output_mapping: Optional[Callable[[Any], bool]] = None,
32-
) -> Callable[[Union[Callable, type]], Union[Callable, type]]:
47+
) -> Callable[[T], T]:
3348
"""Decorator to mark a function or class as an action.
3449
3550
Args:
3651
is_system_action (bool): Flag indicating if the action is a system action.
37-
name (Optional[str]): The name to associate with the action.
52+
name (str): The name to associate with the action.
3853
execute_async: Whether the function should be executed in async mode.
3954
output_mapping (Optional[Callable[[Any], bool]]): A function to interpret the action's result.
4055
It accepts the return value (e.g. the first element of a tuple) and return True if the output
@@ -44,16 +59,19 @@ def action(
4459
callable: The decorated function or class.
4560
"""
4661

47-
def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]:
62+
def decorator(fn_or_cls: Union[Callable, Type]) -> Union[Callable, Type]:
4863
"""Inner decorator function to add metadata to the action.
4964
5065
Args:
5166
fn_or_cls: The function or class being decorated.
5267
"""
5368
fn_or_cls_target = getattr(fn_or_cls, "__func__", fn_or_cls)
5469

70+
# Action name is optional for the decorator, but mandatory for ActionMeta TypedDict
71+
action_name: str = cast(str, name or fn_or_cls.__name__)
72+
5573
action_meta: ActionMeta = {
56-
"name": name or fn_or_cls.__name__,
74+
"name": action_name,
5775
"is_system_action": is_system_action,
5876
"execute_async": execute_async,
5977
"output_mapping": output_mapping,
@@ -62,7 +80,7 @@ def decorator(fn_or_cls: Union[Callable, type]) -> Union[Callable, type]:
6280
setattr(fn_or_cls_target, "action_meta", action_meta)
6381
return fn_or_cls
6482

65-
return decorator
83+
return decorator # pyright: ignore (TODO - resolve how the Actionable Protocol doesn't resolve the issue)
6684

6785

6886
@dataclass

nemoguardrails/actions/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import Optional
17+
from typing import Any, Dict, Optional
1818

1919
from nemoguardrails.actions.actions import ActionResult, action
2020
from nemoguardrails.utils import new_event_dict
@@ -37,13 +37,13 @@ async def create_event(
3737
ActionResult: An action result containing the created event.
3838
"""
3939

40-
event_dict = new_event_dict(
40+
event_dict: Dict[str, Any] = new_event_dict(
4141
event["_type"], **{k: v for k, v in event.items() if k != "_type"}
4242
)
4343

4444
# We add basic support for referring variables as values
4545
for k, v in event_dict.items():
4646
if isinstance(v, str) and v[0] == "$":
47-
event_dict[k] = context.get(v[1:])
47+
event_dict[k] = context.get(v[1:], None) if context else None
4848

4949
return ActionResult(events=[event_dict])

nemoguardrails/actions/langchain/safetools.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,27 @@
1919
"""
2020

2121
import logging
22+
from typing import TYPE_CHECKING
2223

2324
from nemoguardrails.actions.validation import validate_input, validate_response
2425

2526
log = logging.getLogger(__name__)
2627

28+
# Include these outside the try .. except so the Type-checker knows they're always imported
29+
if TYPE_CHECKING:
30+
from langchain_community.utilities import (
31+
ApifyWrapper,
32+
BingSearchAPIWrapper,
33+
GoogleSearchAPIWrapper,
34+
GoogleSerperAPIWrapper,
35+
OpenWeatherMapAPIWrapper,
36+
SearxSearchWrapper,
37+
SerpAPIWrapper,
38+
WikipediaAPIWrapper,
39+
WolframAlphaAPIWrapper,
40+
ZapierNLAWrapper,
41+
)
42+
2743
try:
2844
from langchain_community.utilities import (
2945
ApifyWrapper,

0 commit comments

Comments
 (0)