diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 9df005ac33..243466b8b6 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -13,6 +13,7 @@ from flytekit.core import tracker from flytekit.core.array_node import array_node from flytekit.core.base_task import PythonTask, TaskResolverMixin +from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.launch_plan import LaunchPlan @@ -36,7 +37,7 @@ class ArrayNodeMapTask(PythonTask): def __init__( self, # TODO: add support for other Flyte entities - python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial], + python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial], concurrency: Optional[int] = None, min_successes: Optional[int] = None, min_success_ratio: Optional[float] = None, @@ -66,10 +67,10 @@ def __init__( isinstance(actual_task, PythonFunctionTask) and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT ) - or isinstance(actual_task, PythonInstanceTask) + or isinstance(actual_task, (PythonInstanceTask, ContainerTask)) ): raise ValueError( - "Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks." + "Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks." 
) n_outputs = len(actual_task.python_interface.outputs) @@ -101,6 +102,9 @@ def __init__( if isinstance(actual_task, PythonInstanceTask): mod = actual_task.task_type f = actual_task.lhs + elif isinstance(actual_task, ContainerTask): + mod = actual_task.task_type + f = actual_task.name else: _, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function) sorted_bounded_inputs = ",".join(sorted(self._bound_inputs)) @@ -192,6 +196,10 @@ def prepare_target(self): """ Alters the underlying run_task command to modify it for map task execution and then resets it after. """ + if isinstance(self._run_task, ContainerTask): + yield + return + self.python_function_task.set_command_fn(self.get_command) try: yield @@ -261,9 +269,10 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) def _literal_map_to_python_input( - self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext + self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None ) -> Dict[str, Any]: - ctx = FlyteContextManager.current_context() + if ctx is None: + ctx = FlyteContextManager.current_context() inputs_interface = self.python_interface.inputs inputs_map = literal_map # If we run locally, we will need to process all of the inputs. If we are running in a remote task execution @@ -381,7 +390,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"], + target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"], concurrency: Optional[int] = None, min_successes: Optional[int] = None, min_success_ratio: float = 1.0, @@ -418,7 +427,7 @@ def map_task( def array_node_map_task( - task_function: PythonFunctionTask, + task_function: Union[PythonFunctionTask, ContainerTask], concurrency: Optional[int] = None, # TODO why no min_successes? 
min_success_ratio: float = 1.0, diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 2d99f3b8c0..12b3fe6267 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -14,7 +14,6 @@ from flytekit.image_spec.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model -from flytekit.models.literals import LiteralMap from flytekit.models.security import Secret, SecurityContext _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" @@ -254,14 +253,12 @@ def _get_output_dict(self, output_directory: str) -> Dict[str, Any]: output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type) return output_dict - def execute(self, **kwargs) -> LiteralMap: + def execute(self, **kwargs) -> Any: try: import docker except ImportError: raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE) - from flytekit.core.type_engine import TypeEngine - ctx = FlyteContext.current_context() # Normalize the input and output directories @@ -289,8 +286,12 @@ def execute(self, **kwargs) -> LiteralMap: container.wait() output_dict = self._get_output_dict(output_directory) - outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict) - return outputs_literal_map + if len(output_dict) == 0: + return None + elif len(output_dict) == 1: + return list(output_dict.values())[0] + elif len(output_dict) > 1: + return tuple(output_dict.values()) def get_container(self, settings: SerializationSettings) -> _task_model.Container: # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 12869644d6..b6e660e815 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -166,11 +166,11 @@ def get_serializable_task( if settings.should_fast_serialize(): # This handles container tasks. 
if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)): - # For fast registration, we'll need to muck with the command, but on - # ly for certain kinds of tasks. Specifically, - # tasks that rely on user code defined in the container. This should be encapsulated by the auto container - # parent class - container._args = prefix_with_fast_execute(settings, container.args) + # For fast registration, we'll need to muck with the command, but + # only for certain kinds of tasks. Specifically, tasks that rely + # on user code defined in the container. This should be + # encapsulated by the auto container parent class + container._args = prefix_with_fast_execute(settings, container.args or []) # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 18481d9e69..a6c4d56fcf 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1,3 +1,5 @@ +import math +import functools import botocore.session import shutil from contextlib import ExitStack, contextmanager @@ -20,7 +22,7 @@ import string from dataclasses import asdict, dataclass -from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow +from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow, ContainerTask, map_task from flytekit.configuration import Config, ImageConfig, SerializationSettings from flytekit.core.launch_plan import reference_launch_plan from flytekit.core.task import reference_task @@ -1358,3 +1360,46 @@ def test_run_wf_with_resource_requests_override(register): ], limits=[], ) + + +def test_container_task_map_execution(): + # NOTE: We only take one output "area" even if this 
calculate-ellipse-area.py + # produces two outputs. This is because a map task can only return one value. + calculate_ellipse_area_python_template_style = ContainerTask( + name="calculate_ellipse_area_python_template_style", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float), + image="ghcr.io/flyteorg/rawcontainers-python:v2", + command=[ + "python", + "calculate-ellipse-area.py", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ], + ) + + @workflow + def wf(a: list[float], b: float) -> list[float]: + partial_task = functools.partial( + calculate_ellipse_area_python_template_style, b=b + ) + res = map_task(partial_task)(a=a) + return res + + def calculate_area(a, b): + return math.pi * a * b + + expected_area = [ + calculate_area(a, b) for a, b in [(3.0, 4.0), (4.0, 4.0), (5.0, 4.0)] + ] + + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.execute(wf, inputs={"a": [3.0, 4.0, 5.0], "b": 4.0}, wait=True, version=VERSION) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=2)) + assert execution.error is None, f"Execution failed with error: {execution.error}" + + assert execution.outputs["o0"] == expected_area diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index a2f35424e4..8da954b303 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,3 +1,5 @@ +import math +import sys import functools from datetime import timedelta import os @@ -9,10 +11,17 @@ import pytest from flyteidl.core import workflow_pb2 as _core_workflow -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask -from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings +from flytekit import dynamic, map_task, task, workflow,
eager, PythonFunctionTask, ImageSpec +from flytekit.configuration import ( + FastSerializationSettings, + Image, + ImageConfig, + SerializationSettings, +) from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver +from flytekit.core.base_task import kwtypes +from flytekit.core.container_task import ContainerTask from flytekit.core.task import TaskMetadata from flytekit.core.type_engine import TypeEngine from flytekit.extras.accelerators import GPUAccelerator @@ -96,12 +105,16 @@ def say_hello(name: str) -> str: ctx = context_manager.FlyteContextManager.current_context() with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION ) + ) ) as ctx: t = map_task(say_hello) - lm = TypeEngine.dict_to_literal_map(ctx, {"name": ["earth", "mars"]}, type_hints={"name": typing.List[str]}) + lm = TypeEngine.dict_to_literal_map( + ctx, {"name": ["earth", "mars"]}, type_hints={"name": typing.List[str]} + ) res = t.dispatch_execute(ctx, lm) assert len(res.literals) == 1 assert res.literals["o0"].scalar.primitive.string_value == "hello earth!" 
@@ -121,7 +134,9 @@ def t1(a: int) -> int: return a + 1 arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) - task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) + task_spec = get_serializable( + OrderedDict(), serialization_settings, arraynode_maptask + ) assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.type == "python-task" @@ -153,14 +168,18 @@ def t1(a: int) -> int: def test_fast_serialization(serialization_settings): - serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True) + serialization_settings.fast_serialization_settings = FastSerializationSettings( + enabled=True + ) @task def t1(a: int) -> int: return a + 1 arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) - task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) + task_spec = get_serializable( + OrderedDict(), serialization_settings, arraynode_maptask + ) assert task_spec.template.container.args == [ "pyflyte-fast-execute", @@ -201,7 +220,11 @@ def t1(a: int) -> int: ({}, {"concurrency": 2}, False), ({}, {"min_successes": 3}, False), ({}, {"min_success_ratio": 0.2}, False), - ({}, {"concurrency": 10, "min_successes": 999, "min_success_ratio": 0.2}, False), + ( + {}, + {"concurrency": 10, "min_successes": 999, "min_success_ratio": 0.2}, + False, + ), ({"concurrency": 1}, {"concurrency": 2}, False), ({"concurrency": 42}, {"concurrency": 42}, True), ({"min_successes": 1}, {"min_successes": 2}, False), @@ -241,7 +264,11 @@ def many_inputs(a: int, b: str, c: float) -> str: return f"{a} - {b} - {c}" m = map_task(many_inputs) - assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]} + assert m.python_interface.inputs == { + "a": List[int], + "b": List[str], + "c": List[float], + } assert ( m.name == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_6b3bd0353da5de6e84d7982921ead2b3-arraynode" @@ 
-315,24 +342,46 @@ def task4(c: list[str], a: list[int], b: list[float]) -> list[str]: m2 = map_task(functools.partial(task2, c=fixed_param_c))(a=param_a, b=param_b) m3 = map_task(functools.partial(task3, c=fixed_param_c))(a=param_a, b=param_b) - m4 = ArrayNodeMapTask(task1, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b) - m5 = ArrayNodeMapTask(task2, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b) - m6 = ArrayNodeMapTask(task3, bound_inputs_values={"c": fixed_param_c})(a=param_a, b=param_b) + m4 = ArrayNodeMapTask(task1, bound_inputs_values={"c": fixed_param_c})( + a=param_a, b=param_b + ) + m5 = ArrayNodeMapTask(task2, bound_inputs_values={"c": fixed_param_c})( + a=param_a, b=param_b + ) + m6 = ArrayNodeMapTask(task3, bound_inputs_values={"c": fixed_param_c})( + a=param_a, b=param_b + ) - assert m1 == m2 == m3 == m4 == m5 == m6 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"] + assert ( + m1 + == m2 + == m3 + == m4 + == m5 + == m6 + == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"] + ) list_param_a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] list_param_b = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] fixed_list_param_c = ["c", "d", "e"] - m7 = map_task(functools.partial(task4, c=fixed_list_param_c))(a=list_param_a, b=list_param_b) - m8 = ArrayNodeMapTask(task4, bound_inputs_values={"c": fixed_list_param_c})(a=list_param_a, b=list_param_b) + m7 = map_task(functools.partial(task4, c=fixed_list_param_c))( + a=list_param_a, b=list_param_b + ) + m8 = ArrayNodeMapTask(task4, bound_inputs_values={"c": fixed_list_param_c})( + a=list_param_a, b=list_param_b + ) - assert m7 == m8 == [ - ['1 - 0.1 - c', '2 - 0.2 - d', '3 - 0.3 - e'], - ['4 - 0.4 - c', '5 - 0.5 - d', '6 - 0.6 - e'], - ['7 - 0.7 - c', '8 - 0.8 - d', '9 - 0.9 - e'] - ] + assert ( + m7 + == m8 + == [ + ["1 - 0.1 - c", "2 - 0.2 - d", "3 - 0.3 - e"], + ["4 - 0.4 - c", "5 - 0.5 - d", "6 - 0.6 - e"], + ["7 - 0.7 - c", "8 - 0.8 - d", "9 - 0.9 - e"], + ] + ) with 
pytest.raises(ValueError): map_task(functools.partial(task1, c=fixed_list_param_c))(a=param_a, b=param_b) @@ -361,15 +410,24 @@ def task1(a: int, b: float, c: str) -> str: param_d = "d" partial_task = functools.partial(task1, c=param_c) - m1 = ArrayNodeMapTask(partial_task, bound_inputs_values={"c": param_d})(a=param_a, b=param_b) + m1 = ArrayNodeMapTask(partial_task, bound_inputs_values={"c": param_d})( + a=param_a, b=param_b + ) assert m1 == ["1 - 0.1 - d", "2 - 0.2 - d", "3 - 0.3 - d"] - with pytest.raises(ValueError, match="bound_inputs and bound_inputs_values should have the same keys if both set"): - ArrayNodeMapTask(task1, bound_inputs_values={"c": param_c}, bound_inputs={"b"})(a=param_a, b=param_b) + with pytest.raises( + ValueError, + match="bound_inputs and bound_inputs_values should have the same keys if both set", + ): + ArrayNodeMapTask(task1, bound_inputs_values={"c": param_c}, bound_inputs={"b"})( + a=param_a, b=param_b + ) # no error raised - ArrayNodeMapTask(task1, bound_inputs_values={"c": param_c}, bound_inputs={"c"})(a=param_a, b=param_b) + ArrayNodeMapTask(task1, bound_inputs_values={"c": param_c}, bound_inputs={"c"})( + a=param_a, b=param_b + ) @task() @@ -385,7 +443,9 @@ def task_2() -> int: def get_wf_bound_input(serialization_settings): @workflow() def wf1() -> List[str]: - return ArrayNodeMapTask(task_1, bound_inputs_values={"a": 1})(b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask(task_1, bound_inputs_values={"a": 1})( + b=[1, 2, 3], c=["a", "b", "c"] + ) return wf1 @@ -393,7 +453,9 @@ def wf1() -> List[str]: def get_wf_partials(serialization_settings): @workflow() def wf2() -> List[str]: - return ArrayNodeMapTask(functools.partial(task_1, a=1))(b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask(functools.partial(task_1, a=1))( + b=[1, 2, 3], c=["a", "b", "c"] + ) return wf2 @@ -403,7 +465,9 @@ def get_wf_bound_input_upstream(serialization_settings): @workflow() def wf3() -> List[str]: a = task_2() - return 
ArrayNodeMapTask(task_1, bound_inputs_values={"a": a})(b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask(task_1, bound_inputs_values={"a": a})( + b=[1, 2, 3], c=["a", "b", "c"] + ) return wf3 @@ -413,7 +477,9 @@ def get_wf_partials_upstream(serialization_settings): @workflow() def wf4() -> List[str]: a = task_2() - return ArrayNodeMapTask(functools.partial(task_1, a=a))(b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask(functools.partial(task_1, a=a))( + b=[1, 2, 3], c=["a", "b", "c"] + ) return wf4 @@ -422,7 +488,9 @@ def get_wf_bound_input_partials_collision(serialization_settings): @workflow() def wf5() -> List[str]: - return ArrayNodeMapTask(functools.partial(task_1, a=1), bound_inputs_values={"a": 2})(b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask( + functools.partial(task_1, a=1), bound_inputs_values={"a": 2} + )(b=[1, 2, 3], c=["a", "b", "c"]) return wf5 @@ -431,7 +499,9 @@ def get_wf_bound_input_overrides(serialization_settings): @workflow() def wf6() -> List[str]: - return ArrayNodeMapTask(task_1, bound_inputs_values={"a": 1})(a=[1, 2, 3], b=[1, 2, 3], c=["a", "b", "c"]) + return ArrayNodeMapTask(task_1, bound_inputs_values={"a": 1})( + a=[1, 2, 3], b=[1, 2, 3], c=["a", "b", "c"] + ) return wf6 @@ -455,16 +525,52 @@ def promise_binding(node_id, var): @pytest.mark.parametrize( ("wf", "upstream_nodes", "expected_inputs"), [ - (get_wf_bound_input, {}, {"a": get_int_binding(1), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - (get_wf_partials, {}, {"a": get_int_binding(1), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - (get_wf_bound_input_upstream, {"n0"}, {"a": promise_binding("n0", "o0"), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - (get_wf_partials_upstream, {"n0"}, {"a": promise_binding("n0", "o0"), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - (get_wf_bound_input_partials_collision, {}, {"a": get_int_binding(2), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - (get_wf_bound_input_overrides, {}, {"a": get_int_binding(1), "b": 
B_BINDINGS_LIST, "c": C_BINDINGS_LIST}), - ] + ( + get_wf_bound_input, + {}, + {"a": get_int_binding(1), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}, + ), + ( + get_wf_partials, + {}, + {"a": get_int_binding(1), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}, + ), + ( + get_wf_bound_input_upstream, + {"n0"}, + { + "a": promise_binding("n0", "o0"), + "b": B_BINDINGS_LIST, + "c": C_BINDINGS_LIST, + }, + ), + ( + get_wf_partials_upstream, + {"n0"}, + { + "a": promise_binding("n0", "o0"), + "b": B_BINDINGS_LIST, + "c": C_BINDINGS_LIST, + }, + ), + ( + get_wf_bound_input_partials_collision, + {}, + {"a": get_int_binding(2), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}, + ), + ( + get_wf_bound_input_overrides, + {}, + {"a": get_int_binding(1), "b": B_BINDINGS_LIST, "c": C_BINDINGS_LIST}, + ), + ], ) -def test_bound_inputs_serialization(wf, upstream_nodes, expected_inputs, serialization_settings): - wf_spec = get_serializable(OrderedDict(), serialization_settings, wf(serialization_settings)) +def test_bound_inputs_serialization( + wf, upstream_nodes, expected_inputs, serialization_settings +): + wf_spec = get_serializable( + OrderedDict(), serialization_settings, wf(serialization_settings) + ) assert len(wf_spec.template.nodes) == len(upstream_nodes) + 1 parent_node = wf_spec.template.nodes[len(upstream_nodes)] @@ -505,7 +611,9 @@ def some_task1(inputs: int) -> int: @workflow def my_wf1() -> typing.List[typing.Optional[int]]: - return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4]) + return map_task(some_task1, min_success_ratio=min_success_ratio)( + inputs=[1, 2, 3, 4] + ) if should_raise_error: with pytest.raises(ValueError): @@ -555,7 +663,9 @@ def t1(a: int) -> typing.Optional[int]: t1, min_success_ratio=0.9, concurrency=10, - metadata=TaskMetadata(retries=2, interruptible=True, timeout=timedelta(seconds=10)) + metadata=TaskMetadata( + retries=2, interruptible=True, timeout=timedelta(seconds=10) + ), ) assert 
arraynode_maptask.metadata.interruptible @@ -563,7 +673,9 @@ def t1(a: int) -> typing.Optional[int]: def wf(x: typing.List[int]): return arraynode_maptask(a=x) - full_state_array_node_map_task = map_task(PythonFunctionTaskExtension(task_config={}, task_function=t1)) + full_state_array_node_map_task = map_task( + PythonFunctionTaskExtension(task_config={}, task_function=t1) + ) @workflow def wf1(x: typing.List[int]): @@ -577,7 +689,9 @@ def wf1(x: typing.List[int]): assert array_node.array_node._min_success_ratio == 0.9 assert array_node.array_node._parallelism == 10 assert not array_node.array_node._is_original_sub_node_interface - assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE + assert ( + array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE + ) task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible @@ -609,9 +723,7 @@ def wf(x: typing.List[int]): def test_serialization_extended_resources_shared_memory(serialization_settings): - @task( - shared_memory="2Gi" - ) + @task(shared_memory="2Gi") def t1(a: int) -> int: return a + 1 @@ -630,30 +742,26 @@ def wf(x: typing.List[int]): def test_supported_node_type(): @task - def test_task(): - ... + def test_task(): ... map_task(test_task) def test_unsupported_node_types(): @dynamic - def test_dynamic(): - ... + def test_dynamic(): ... with pytest.raises(ValueError): map_task(test_dynamic) @eager - def test_eager(): - ... + def test_eager(): ... with pytest.raises(ValueError): map_task(test_eager) @workflow - def test_wf(): - ... + def test_wf(): ... 
with pytest.raises(ValueError): map_task(test_wf) @@ -692,9 +800,11 @@ def say_hello(name: str) -> str: ctx = context_manager.FlyteContextManager.current_context() with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=context_manager.ExecutionState.Mode.TASK_EXECUTION ) + ) ) as ctx: list_strs = ["a", "b", "c"] lt = TypeEngine.to_literal_type(typing.List[str]) @@ -709,9 +819,7 @@ def say_hello(name: str) -> str: ), ) - lm = LiteralMap({ - "name": literal - }) + lm = LiteralMap({"name": literal}) for index, map_input_str in enumerate(list_strs): monkeypatch.setenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "name") @@ -719,5 +827,51 @@ def say_hello(name: str) -> str: t = map_task(say_hello) res = t.dispatch_execute(ctx, lm) assert len(res.literals) == 1 - assert res.literals[f"o{0}"].scalar.primitive.string_value == f"hello {map_input_str}!" + assert ( + res.literals[f"o{0}"].scalar.primitive.string_value + == f"hello {map_input_str}!" + ) monkeypatch.undo() + + +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="Skip if running on windows or macos due to CI Docker environment setup failure", +) +def test_container_task_map_execution(serialization_settings): + # NOTE: We only take one output "area" even if this calculate-ellipse-area.py + # produces two outputs. This is because a map task can only return one value.
+ calculate_ellipse_area_python_template_style = ContainerTask( + name="calculate_ellipse_area_python_template_style", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float), + image="ghcr.io/flyteorg/rawcontainers-python:v2", + command=[ + "python", + "calculate-ellipse-area.py", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ], + ) + + @workflow + def wf(a: list[float], b: float) -> list[float]: + partial_task = functools.partial( + calculate_ellipse_area_python_template_style, b=b + ) + res = map_task(partial_task)(a=a) + return res + + def calculate_area(a, b): + return math.pi * a * b + + expected_area = [ + calculate_area(a, b) for a, b in [(3.0, 4.0), (4.0, 4.0), (5.0, 4.0)] + ] + + res = wf(a=[3.0, 4.0, 5.0], b=4.0) + + assert res == expected_area