-
Notifications
You must be signed in to change notification settings - Fork 55
fix[next-dace]: Add entry-point synchronization #2527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9f3ad7a
6107c44
90f0a2c
3989019
2e34e30
953aa6a
a820ae0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,7 +159,9 @@ def _make_if_region_for_metrics_collection( | |
| return if_region, then_state | ||
|
|
||
|
|
||
| def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: | ||
| def add_instrumentation( | ||
| sdfg: dace.SDFG, gpu: bool, sync_states: tuple[dace.SDFGState, dace.SDFGState] | None | ||
| ) -> None: | ||
| """ | ||
| Instrument SDFG with measurement of total execution time. | ||
|
|
||
|
|
@@ -169,16 +171,51 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: | |
|
|
||
| The execution time is measured in seconds and represented as a 'float64' value. | ||
| It is written to the global array 'SDFG_ARG_METRIC_COMPUTE_TIME'. | ||
|
|
||
| Args: | ||
| sdfg: The SDFG to be instrumented with time measurements. | ||
| gpu: Flag that specifies if the SDFG is targeting GPU execution. | ||
| sync_states: If provided, a tuple of two states, the source state and the | ||
| sync state of the SDFG, containing tasklets with GPU device synchronization. | ||
| """ | ||
| output, _ = sdfg.add_array(gtx_wfdcommon.SDFG_ARG_METRIC_COMPUTE_TIME, [1], dace.float64) | ||
| start_time, _ = sdfg.add_scalar("gt_start_time", dace.int64, transient=True) | ||
| metrics_level = sdfg.add_symbol(gtx_wfdcommon.SDFG_ARG_METRIC_LEVEL, dace.int32) | ||
|
|
||
| #### 1. Synchronize the CUDA device, in order to wait for kernels completion. | ||
| entry_if_region, begin_state = _make_if_region_for_metrics_collection( | ||
| "metrics_entry", metrics_level, sdfg | ||
| ) | ||
| exit_if_region, end_state = _make_if_region_for_metrics_collection( | ||
| "metrics_exit", metrics_level, sdfg | ||
| ) | ||
| if sync_states is None: | ||
| # Use the newly created entry if-region as new source node | ||
| for source_state in sdfg.source_nodes(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This assumes that all nodes in the SDFG are states and we do not have nested control-flow regions, which is currently true. |
||
| if source_state not in [entry_if_region, exit_if_region]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would slightly restructure the code to something like: |
||
| sdfg.add_edge(entry_if_region, source_state, dace.InterstateEdge()) | ||
| source_state.is_start_block = False | ||
| assert sdfg.out_degree(entry_if_region) > 0 | ||
| entry_if_region.is_start_block = True | ||
| # Similarly, the exit if-region as sink node. | ||
| for sink_state in sdfg.sink_nodes(): | ||
| if sink_state not in [entry_if_region, exit_if_region]: | ||
| sdfg.add_edge(sink_state, exit_if_region, dace.InterstateEdge()) | ||
| assert sdfg.in_degree(exit_if_region) > 0 | ||
| else: | ||
| # Keep the existing synchronization points, and put the entry if-region after the entry state | ||
| entry_state, exit_state = sync_states | ||
| for edge in list(sdfg.out_edges(entry_state)): | ||
| sdfg.add_edge(entry_if_region, edge.dst, edge.data) | ||
| sdfg.remove_edge(edge) | ||
| sdfg.add_edge(entry_state, entry_if_region, dace.InterstateEdge()) | ||
| # Put the exit if-region right after the exit state. | ||
| sdfg.add_edge(exit_state, exit_if_region, dace.InterstateEdge()) | ||
|
|
||
| #### 1. Synchronize the CUDA device if the sync states are not provided. | ||
| # Even when the target device is GPU, it can happen that dace emits code without | ||
| # GPU kernels. In this case, the cuda headers are not imported and the SDFG is | ||
| # compiled as plain C++. Therefore, we also check here the schedule of SDFG maps. | ||
| if gpu and _has_gpu_schedule(sdfg): | ||
| if gpu and _has_gpu_schedule(sdfg) and sync_states is None: | ||
| dace_gpu_backend = dace.Config.get("compiler.cuda.backend") | ||
| assert dace_gpu_backend in ["cuda", "hip"], f"GPU backend '{dace_gpu_backend}' is unknown." | ||
|
|
||
|
|
@@ -193,29 +230,19 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: | |
| has_side_effects = False | ||
|
|
||
| #### 2. Timestamp the SDFG entry point. | ||
| entry_if_region, begin_state = _make_if_region_for_metrics_collection( | ||
| "program_entry", metrics_level, sdfg | ||
| ) | ||
|
|
||
| for source_state in sdfg.source_nodes(): | ||
| if source_state is entry_if_region: | ||
| continue | ||
| sdfg.add_edge(entry_if_region, source_state, dace.InterstateEdge()) | ||
| source_state.is_start_block = False | ||
| assert sdfg.out_degree(entry_if_region) > 0 | ||
| entry_if_region.is_start_block = True | ||
|
|
||
| tlet_start_timer = begin_state.add_tasklet( | ||
| "gt_start_timer", | ||
| inputs={}, | ||
| outputs={"time"}, | ||
| code="""\ | ||
| code=sync_code | ||
| + """\ | ||
| auto now = std::chrono::high_resolution_clock::now(); | ||
| time = std::chrono::duration_cast<std::chrono::nanoseconds>( | ||
| now.time_since_epoch() | ||
| ).count(); | ||
| """, | ||
| language=dace.dtypes.Language.CPP, | ||
| side_effects=has_side_effects, | ||
| ) | ||
| begin_state.add_edge( | ||
| tlet_start_timer, | ||
|
|
@@ -226,17 +253,6 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: | |
| ) | ||
|
|
||
| #### 3. Collect the SDFG end timestamp and produce the compute metric. | ||
| exit_if_region, end_state = _make_if_region_for_metrics_collection( | ||
| "program_exit", metrics_level, sdfg | ||
| ) | ||
|
|
||
| for sink_state in sdfg.sink_nodes(): | ||
| if sink_state is exit_if_region: | ||
| continue | ||
| sdfg.add_edge(sink_state, exit_if_region, dace.InterstateEdge()) | ||
| assert sdfg.in_degree(exit_if_region) > 0 | ||
|
|
||
| # Populate the branch that computes the stencil time metric | ||
| tlet_stop_timer = end_state.add_tasklet( | ||
| "gt_stop_timer", | ||
| inputs={"run_cpp_start_time"}, | ||
|
|
@@ -286,70 +302,62 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: | |
| sdfg.validate() | ||
|
|
||
|
|
||
| def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> None: | ||
| def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> tuple[dace.SDFGState, dace.SDFGState] | None: | ||
| """Process the SDFG such that the call is synchronous. | ||
|
|
||
| This means that `CompiledSDFG.fast_call()` will return only after all computations | ||
| have _finished_ and the results are available. This function only has an effect for | ||
| work that runs on the GPU. Furthermore, all work is scheduled on the default stream. | ||
|
|
||
| Todo: Revisit this function once DaCe changes its behaviour in this regard. | ||
| Returns: | ||
| The SDFG entry and exit states, each calling the GPU primitive for device synchronization. | ||
| """ | ||
|
|
||
| if not gpu: | ||
| # This is only a problem on GPU. Dace uses OpenMP on CPU and | ||
| # the OpenMP parallel region creates a synchronization point. | ||
| return | ||
| return None | ||
| elif not _has_gpu_schedule(sdfg): | ||
| # Even when the target device is GPU, it can happen that dace | ||
| # emits code without GPU kernels. In this case, the cuda headers | ||
| # are not imported and the SDFG is compiled as plain C++. | ||
| return | ||
| return None | ||
|
|
||
| assert dace.Config.get("compiler.cuda.max_concurrent_streams") == -1, ( | ||
| f"Expected `max_concurrent_streams == -1` but it was `{dace.Config.get('compiler.cuda.max_concurrent_streams')}`." | ||
| ) | ||
| dace_gpu_backend = dace.Config.get("compiler.cuda.backend") | ||
| assert dace_gpu_backend in ["cuda", "hip"], f"GPU backend '{dace_gpu_backend}' is unknown." | ||
|
|
||
| # If we are using the default stream, things are a bit simpler/harder. For some | ||
| # reasons when using the default stream, DaCe seems to skip _all_ synchronization, | ||
| # for more see [DaCe issue#2120](https://github.com/spcl/dace/issues/2120). | ||
| # Thus the `CompiledSDFG.fast_call()` call is truly asynchronous, i.e. just | ||
| # launches the kernels and then exist. Thus we have to add a synchronization | ||
| # at the end to have a synchronous call. We can not use `SDFG.append_exit_code()` | ||
| # because that code is only run at the `exit()` stage, not after a call. Thus we | ||
| # will generate an SDFGState that contains a Tasklet with the sync call. | ||
| sync_state = sdfg.add_state("sync_state") | ||
| for sink_state in sdfg.sink_nodes(): | ||
| if sink_state is sync_state: | ||
| entry_state = sdfg.add_state("sync_entry") | ||
| for source_state in sdfg.source_nodes(): | ||
| if source_state is entry_state: | ||
| continue | ||
| sdfg.add_edge(sink_state, sync_state, dace.InterstateEdge()) | ||
| assert sdfg.in_degree(sync_state) > 0 | ||
| sdfg.add_edge(entry_state, source_state, dace.InterstateEdge()) | ||
| source_state.is_start_block = False | ||
| assert sdfg.out_degree(entry_state) > 0 | ||
| entry_state.is_start_block = True | ||
|
|
||
| # NOTE: Since the synchronization is done through the Tasklet explicitly, | ||
| # we can disable synchronization for the last state. Might be useless. | ||
| sync_state.nosync = True | ||
| exit_state = sdfg.add_state("sync_exit") | ||
| for sink_state in sdfg.sink_nodes(): | ||
| if sink_state is exit_state: | ||
| continue | ||
| sdfg.add_edge(sink_state, exit_state, dace.InterstateEdge()) | ||
| assert sdfg.in_degree(exit_state) > 0 | ||
|
|
||
| # NOTE: We should actually wrap the `StreamSynchronize` function inside a | ||
| # NOTE: We should actually wrap the `DeviceSynchronize` function inside a | ||
| # `DACE_GPU_CHECK()` macro. However, this only works in GPU context, but | ||
| # here we are in CPU context. Thus we can not do it. | ||
| dace_gpu_backend = dace.Config.get("compiler.cuda.backend") | ||
| assert dace_gpu_backend in ["cuda", "hip"], f"GPU backend '{dace_gpu_backend}' is unknown." | ||
| sync_state.add_tasklet( | ||
| "sync_tlet", | ||
| inputs=set(), | ||
| outputs=set(), | ||
| code=f"{dace_gpu_backend}StreamSynchronize({dace_gpu_backend}StreamDefault);", | ||
| language=dace.dtypes.Language.CPP, | ||
| side_effects=True, | ||
| ) | ||
|
|
||
| # DaCe [still generates a stream](https://github.com/spcl/dace/blob/54c935cfe74a52c5107dc91680e6201ddbf86821/dace/codegen/targets/cuda.py#L467) | ||
| # despite not using it. Thus to be absolutely sure, we will not set that stream | ||
| # to the default stream. | ||
| sdfg.append_init_code( | ||
| f"__dace_gpu_set_all_streams(__state, {dace_gpu_backend}StreamDefault);", | ||
| location="cuda", | ||
| ) | ||
| # NOTE: Since the synchronization is done through the Tasklet explicitly, | ||
| # we can disable synchronization for the state. | ||
| for state in [entry_state, exit_state]: | ||
| state.add_tasklet( | ||
| "sync_tlet", | ||
| inputs=set(), | ||
| outputs=set(), | ||
| code=f"{dace_gpu_backend}DeviceSynchronize();", | ||
| language=dace.dtypes.Language.CPP, | ||
| side_effects=True, | ||
| ) | ||
| state.nosync = True | ||
| return entry_state, exit_state | ||
|
|
||
|
|
||
| @dataclasses.dataclass(frozen=True) | ||
|
|
@@ -436,11 +444,12 @@ def _generate_sdfg_without_configuring_dace( | |
|
|
||
| if self.async_sdfg_call: | ||
| make_sdfg_call_async(sdfg, on_gpu) | ||
| sync_states = None | ||
| else: | ||
| make_sdfg_call_sync(sdfg, on_gpu) | ||
| sync_states = make_sdfg_call_sync(sdfg, on_gpu) | ||
|
|
||
| if self.use_metrics: | ||
| add_instrumentation(sdfg, on_gpu) | ||
| add_instrumentation(sdfg, on_gpu, sync_states) | ||
|
|
||
| return sdfg | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by "source state"? If you mean the first state in the SDFG, then I would suggest using
`sdfg.start_block`/`sdfg.start_state`.