Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/sections/user_guide/cli/tools/execute/help.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
usage: uw execute --module MODULE --classname CLASSNAME --task TASK [-h]
[--version] [--config-file PATH] [--schema-file PATH]
usage: uw execute --module MODULE --classname CLASSNAME [-h] [--version]
[--task TASK] [--config-file PATH] [--schema-file PATH]
[--cycle CYCLE] [--leadtime LEADTIME] [--batch] [--dry-run]
[--graph-file PATH] [--key-path KEY[.KEY...]] [--quiet]
[--verbose]
Expand All @@ -11,14 +11,14 @@ Required arguments:
Path to driver module or name of module on sys.path
--classname CLASSNAME
Name of driver class
--task TASK
Task to execute

Optional arguments:
-h, --help
Show help and exit
--version
Show version info and exit
--task TASK
Task to execute
--config-file PATH, -c PATH
Path to UW YAML config file (default: read from stdin)
--schema-file PATH
Expand Down
18 changes: 16 additions & 2 deletions src/uwtools/api/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from traceback import format_exc
from typing import TYPE_CHECKING

from iotaa import tasknames

from uwtools.drivers.support import tasks as _tasks
from uwtools.logging import log
from uwtools.strings import STR
Expand All @@ -28,7 +30,7 @@
def execute(
module: Path | str,
classname: str,
task: str,
task: str | None = None,
schema_file: str | None = None,
config: Path | str | None = None,
cycle: datetime | None = None,
Expand All @@ -48,7 +50,7 @@ def execute(

:param module: Path to driver module or name of module on sys.path.
:param classname: Name of driver class to instantiate.
:param task: Name of driver task to execute.
:param task: Name of driver task to execute. If omitted, a list of available tasks is displayed.
:param schema_file: The JSON Schema file to use for validation.
:param config: Path to config file (read stdin if missing or None).
:param cycle: The cycle.
Expand All @@ -64,6 +66,11 @@ def execute(
if not class_:
return None
assert module_path is not None
if bad_task := task and task not in tasknames(class_):
log.error("%s driver has no task '%s'", class_.__name__, task)
if bad_task or task is None:
_list_available_tasks(module, classname)
return None
args = dict(locals())
accepted = set(getfullargspec(class_).args)
non_optional = {STR.cycle, STR.leadtime}
Expand Down Expand Up @@ -158,4 +165,11 @@ def _get_driver_module_implicit(module: str) -> ModuleType | None:
return None


def _list_available_tasks(module: Path | str, classname: str) -> None:
log.error("Available tasks:")
for taskname, description in tasks(module, classname).items():
log.error(" %s" % taskname)
log.error(" %s" % (description or "No description available."))


__all__ = ["execute", "tasks"]
6 changes: 3 additions & 3 deletions src/uwtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ def _add_subparser_execute(subparsers: Subparsers) -> ModeChecks:
required = parser.add_argument_group(TITLE_REQ_ARG)
_add_arg_module(required)
_add_arg_classname(required)
_add_arg_task(required)
optional = _basic_setup(parser)
_add_arg_task(optional, required=False)
_add_arg_config_file(optional)
_add_arg_schema_file(optional)
_add_arg_cycle(optional)
Expand Down Expand Up @@ -1017,11 +1017,11 @@ def _add_arg_target_dir(group: Group, required: bool = False, helpmsg: str | Non
)


def _add_arg_task(group: Group) -> None:
def _add_arg_task(group: Group, required: bool = True) -> None:
group.add_argument(
_switch(STR.task),
help="Task to execute",
required=True,
required=required,
type=str,
)

Expand Down
61 changes: 35 additions & 26 deletions src/uwtools/tests/api/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,24 @@ def kwargs(args):


@mark.parametrize(("key", "val"), [("batch", True), ("leadtime", 6)])
def test_execute_fail_bad_args(key, kwargs, logged, utc, val):
def test_execute__fail_bad_args(key, kwargs, logged, utc, val):
kwargs.update({"cycle": utc(), key: val})
assert execute.execute(**kwargs) is None
assert logged(f"TestDriver does not accept argument '{key}'")


def test_execute_fail_stdin_not_ok(kwargs, utc):
def test_execute__fail_bad_task(kwargs, logged, utc):
kwargs.update({"cycle": utc(), "task": "foo"})
assert execute.execute(**kwargs) is None
assert logged("TestDriver driver has no task 'foo'")


def test_execute__fail_cannot_load_driver_class(kwargs):
kwargs["module"] = "bad_module_name"
assert execute.execute(**kwargs) is None


def test_execute__fail_stdin_not_ok(kwargs, utc):
kwargs["config"] = None
kwargs["cycle"] = utc()
kwargs["stdin_ok"] = False
Expand All @@ -56,7 +67,7 @@ def test_execute_fail_stdin_not_ok(kwargs, utc):

@mark.parametrize("graph", [True, False])
@mark.parametrize("remove", [[], ["schema_file"]])
def test_execute_pass(graph, kwargs, logged, remove, tmp_path, utc):
def test_execute__pass(graph, kwargs, logged, remove, tmp_path, utc):
for kwarg in remove:
del kwargs[kwarg]
kwargs["cycle"] = utc()
Expand All @@ -67,12 +78,15 @@ def test_execute_pass(graph, kwargs, logged, remove, tmp_path, utc):
with (
patch.object(execute, "_get_driver_class") as gdc,
patch.object(execute, "getfullargspec") as gfa,
patch.object(execute, "tasknames", return_value=["forty_two"]),
):
node = Mock(graph=graph_code) if graph else Mock()
node.ref = 42
driverobj = Mock()
driverobj.forty_two.return_value = node
gdc.return_value = (Mock(return_value=driverobj), kwargs["module"])
class_ = Mock(return_value=driverobj)
setattr(class_, "__name__", "test")
gdc.return_value = (class_, kwargs["module"])
gfa().args = {"batch", "cycle"}
val = execute.execute(**kwargs)
assert val
Expand All @@ -82,77 +96,72 @@ def test_execute_pass(graph, kwargs, logged, remove, tmp_path, utc):
assert graph_file.read_text().strip() == graph_code


def test_execute_fail_cannot_load_driver_class(kwargs):
kwargs["module"] = "bad_module_name"
assert execute.execute(**kwargs) is None


def test_tasks_fail(args, logged, tmp_path):
def test_tasks__fail(args, logged, tmp_path):
module = tmp_path / "not.py"
tasks = execute.tasks(classname=args.classname, module=module)
assert tasks == {}
assert logged("Could not get tasks from class %s in module %s" % (args.classname, module))


def test_tasks_fail_no_cycle(args, kwargs, logged):
def test_tasks__fail_no_cycle(args, kwargs, logged):
assert execute.execute(**kwargs) is None
assert logged("%s requires argument '%s'" % (args.classname, "cycle"))


@mark.parametrize("f", [Path, str])
def test_tasks_pass(args, f):
def test_tasks__pass(args, f):
tasks = execute.tasks(classname=args.classname, module=f(args.module))
assert tasks["forty_two"] == "Forty Two."


def test__get_driver_class_explicit_fail_bad_class(args, logged):
def test__get_driver_class__explicit_fail_bad_class(args, logged):
bad_class = "BadClass"
c, module_path = execute._get_driver_class(classname=bad_class, module=args.module)
assert c is None
assert module_path == args.module
assert logged("Module %s has no class %s" % (args.module, bad_class))


def test__get_driver_class_explicit_fail_bad_name(args, logged):
def test__get_driver_class__explicit_fail_bad_name(args, logged):
bad_name = Path("bad_name")
c, module_path = execute._get_driver_class(classname=args.classname, module=bad_name)
assert c is None
assert module_path is None
assert logged("Could not load module %s" % bad_name)


def test__get_driver_class_explicit_fail_bad_path(args, logged, tmp_path):
def test__get_driver_class__explicit_fail_bad_path(args, logged, tmp_path):
module = tmp_path / "not.py"
c, module_path = execute._get_driver_class(classname=args.classname, module=module)
assert c is None
assert module_path is None
assert logged("Could not load module %s" % module)


def test__get_driver_class_explicit_fail_bad_spec(args, logged):
def test__get_driver_class__explicit_fail_bad_spec(args, logged):
with patch.object(execute, "spec_from_file_location", return_value=None):
c, module_path = execute._get_driver_class(classname=args.classname, module=args.module)
assert c is None
assert module_path is None
assert logged("Could not load module %s" % args.module)


def test__get_driver_class_explicit_pass(args):
def test__get_driver_class__explicit_pass(args):
c, module_path = execute._get_driver_class(classname=args.classname, module=args.module)
assert c
assert c.__name__ == "TestDriver"
assert module_path == args.module


def test__get_driver_class_implicit_pass(args):
def test__get_driver_class__implicit_pass(args):
with patch.object(Path, "cwd", return_value=fixture_path()):
c, module_path = execute._get_driver_class(classname=args.classname, module=args.module)
assert c
assert c.__name__ == "TestDriver"
assert module_path == args.module


def test__get_driver_module_explicit_absolute_fail_syntax_error(args, logged, tmp_path):
def test__get_driver_module__explicit_absolute_fail_syntax_error(args, logged, tmp_path):
module = tmp_path / "module.py"
module.write_text("syntax error\n%s" % args.module.read_text())
assert module.is_absolute()
Expand All @@ -162,37 +171,37 @@ def test__get_driver_module_explicit_absolute_fail_syntax_error(args, logged, tm
assert logged("SyntaxError: invalid syntax")


def test__get_driver_module_explicit_absolute_fail_bad_path(args):
def test__get_driver_module__explicit_absolute_fail_bad_path(args):
assert args.module.is_absolute()
module = args.module.with_suffix(".bad")
assert not execute._get_driver_module_explicit(module=module)


def test__get_driver_module_explicit_absolute_pass(args):
def test__get_driver_module__explicit_absolute_pass(args):
assert args.module.is_absolute()
assert execute._get_driver_module_explicit(module=args.module)


def test__get_driver_module_explicit_relative_fail_bad_path(args):
def test__get_driver_module__explicit_relative_fail_bad_path(args):
args.module = Path(os.path.relpath(args.module)).with_suffix(".bad")
assert not args.module.is_absolute()
assert not execute._get_driver_module_explicit(module=args.module)


def test__get_driver_module_explicit_relative_pass(args):
def test__get_driver_module__explicit_relative_pass(args):
args.module = Path(os.path.relpath(args.module))
assert not args.module.is_absolute()
assert execute._get_driver_module_explicit(module=args.module)


def test__get_driver_module_implicit_pass_full_package():
def test__get_driver_module__implicit_pass_full_package():
assert execute._get_driver_module_implicit("uwtools.tests.fixtures.testdriver")


def test__get_driver_module_implicit_pass():
def test__get_driver_module__implicit_pass():
with patch.object(sys, "path", [str(fixture_path()), *sys.path]):
assert execute._get_driver_module_implicit("testdriver")


def test__get_driver_module_implicit_fail():
def test__get_driver_module__implicit_fail():
assert not execute._get_driver_module_implicit("no.such.module")