@@ -55,19 +55,23 @@ class MapFusionVertical(transformation.SingleStateTransformation):
5555
5656 In order to determine if an intermediate can be removed or has to be kept, it is in general
5757 necessary to scan the whole SDFG, which is the default behaviour. There are two ways to
58- speed this up. The first way is to set `assume_always_shared` to `True`. In this case the
59- transformation will not perform the scan, but assume that the data is shared, i.e. used
60- somewhere else. This might lead to dead data flow.
61- The second way is to use the transformation inside a pipeline, which includes the
62- `FindSingleUseData` analysis pass, see note below. If the result of this pass is present then the
63- transformation will use it instead to determine if a intermediate can be removed. Note that
64- `assume_always_shared` takes precedence. For this pattern the `FullMapFusion` pass is provided,
65- that combines the analysis pass and `MapFusionVertical`.
58+ speed this up. The first way is to use the transformation inside a pipeline, which includes the
59+ `FindSingleUseData` analysis pass, see note below. If the result of this pass is present then
60+ the transformation will use it instead to determine if an intermediate can be removed.
61+ The second way is to specify either `assume_always_shared` or `assume_always_single_use_data`
62+ (see note below), which instructs the transformation to assume that the intermediate is either
63+ shared, i.e. will become an output of the fused Map, or that it is not used anywhere else
64+ and will be removed (which means that if it is specified but the data is used somewhere else,
65+ an invalid SDFG will be produced). However, the transformation will always perform local checks,
66+ i.e. it checks the degree of the intermediate and whether it refers to non-transient data.
67+ Furthermore, if a pipeline result is present it will take precedence over these flags.
6668
6769 :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
6870 :param only_toplevel_maps: Only consider Maps that are at the top.
6971 :param strict_dataflow: Which dataflow mode should be used, see above.
7072 :param assume_always_shared: Assume that all intermediates are shared.
73+ :param assume_always_single_use_data: Assume that all intermediates are single use data,
74+ i.e. are no longer needed.
7175 :param consolidate_edges_only_if_not_extending: If `True`, the default is `False`,
7276 the transformation will only consolidate edges if this does not lead to an
7377 extension of the subset.
@@ -76,12 +80,15 @@ class MapFusionVertical(transformation.SingleStateTransformation):
7680 goes to the same AccessNode.
7781
7882 :note: This transformation modifies more nodes than it matches.
79- :note: If `assume_always_shared` is `True` then the transformation will assume that
80- all intermediates are shared. This avoids the problems mentioned above with
81- the cache at the expense of the creation of dead dataflow.
83+ :note: While it is always "safe" to specify the `assume_always_shared`, this is not true
84+ for `assume_always_single_use_data`. Specifying it when the data is used somewhere
85+ else will lead to invalid behaviour. Thus only specify it if you know what you are doing.
86+ :note: The flags `assume_always_shared` and `assume_always_single_use_data` are intended for
87+ cases where it is clear from the usage context what should happen to the intermediate. This is
88+ most often, but not always, the case if `can_be_applied_to()` or `apply_to()` are used.
8289 :note: Because of [issue#1911](https://github.com/spcl/dace/issues/1911) the `can_be_applied()`
83- can not use the pipeline result and will thus scan the whole SDFG. The `FullMapFusion`
84- pass is not affected by this.
90+ can not use the pipeline result and will thus scan the whole SDFG. The `FullMapFusion`
91+ pass is not affected by this.
8592 """
8693
8794 first_map_exit = transformation .transformation .PatternNode (nodes .MapExit )
@@ -99,16 +106,23 @@ class MapFusionVertical(transformation.SingleStateTransformation):
99106 default = False ,
100107 desc = "Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope." ,
101108 )
109+
102110 strict_dataflow = properties .Property (
103111 dtype = bool ,
104112 default = True ,
105113 desc = "If `True` then the transformation will ensure a more stricter data flow." ,
106114 )
115+
107116 assume_always_shared = properties .Property (
108117 dtype = bool ,
109118 default = False ,
110119 desc = "If `True` then all intermediates will be classified as shared." ,
111120 )
121+ assume_always_single_use_data = properties .Property (
122+ dtype = bool ,
123+ default = False ,
124+ desc = "If `True` then all intermediates are classified as single use data." ,
125+ )
112126
113127 never_consolidate_edges = properties .Property (
114128 dtype = bool ,
@@ -127,6 +141,7 @@ def __init__(
127141 only_toplevel_maps : Optional [bool ] = None ,
128142 strict_dataflow : Optional [bool ] = None ,
129143 assume_always_shared : Optional [bool ] = None ,
144+ assume_always_single_use_data : Optional [bool ] = None ,
130145 consolidate_edges_only_if_not_extending : Optional [bool ] = None ,
131146 never_consolidate_edges : Optional [bool ] = None ,
132147 ** kwargs : Any ,
@@ -140,11 +155,16 @@ def __init__(
140155 self .strict_dataflow = strict_dataflow
141156 if assume_always_shared is not None :
142157 self .assume_always_shared = assume_always_shared
158+ if assume_always_single_use_data is not None :
159+ self .assume_always_single_use_data = assume_always_single_use_data
143160 if never_consolidate_edges is not None :
144161 self .never_consolidate_edges = never_consolidate_edges
145162 if consolidate_edges_only_if_not_extending is not None :
146163 self .consolidate_edges_only_if_not_extending = consolidate_edges_only_if_not_extending
147164
165+ if self .assume_always_shared and self .assume_always_single_use_data :
166+ raise ValueError ('Specified both `assume_always_single_use_data` and `assume_always_shared`.' )
167+
148168 # See comment in `is_shared_data()` for more information.
149169 # NOTE: `_pipeline_result` will take precedence over this value.
150170 self ._single_use_data : Optional [Dict [dace .SDFG , Set [str ]]] = None
@@ -1271,62 +1291,69 @@ def is_shared_data(
12711291 ) -> bool :
12721292 """Tests if `data` is shared data, i.e. it can not be removed from the SDFG.
12731293
1274- Depending on the situation, the function will not perform a scan of the whole SDFG:
1275- 1) If `assume_always_shared` was set to `True`, the function will return `True` unconditionally.
1276- 2) If `data` is non transient then the function will return `True`, as non transient data
1277- must be reconstructed always.
1278- 3) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge
1294+ The function returns `True` if `data` refers to shared data and `False` otherwise.
1295+ The process to determine this is as follows:
1296+ 1) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge
12791297 it is classified as shared.
1280- 2) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed.
1281- 3) The function will perform a scan.
1298+ 2) If `data` refers to non transient memory, the function returns `True`.
1299+ 3) If `FindSingleUseData` is in the pipeline it will be used. I.e. it will check if `data`
1300+ is in the set of single use data and return `False` if it is and `True` otherwise.
1301+ 4) If `assume_always_shared` is `True` the function will return `True`.
1302+ 5) If `assume_always_single_use_data` is `True` the function will return `False`.
1303+ 6) The function will perform a scan of the SDFG.
1304+
1305+ This order means that if a pipeline is present, then it will take precedence.
12821306
12831307 :param data: The transient that should be checked.
12841308 :param state: The state in which the fusion is performed.
12851309 :param sdfg: The SDFG in which we want to perform the fusing.
1286-
12871310 """
1288- # `assume_always_shared` takes precedence.
1289- if self .assume_always_shared :
1290- return True
1291-
1292- # If `data` is non transient then return `True` as the intermediate can not be removed.
1293- if not data .desc (sdfg ).transient :
1294- return True
1311+ assert not (self .assume_always_shared and self .assume_always_single_use_data )
12951312
12961313 # This means the data is consumed by multiple Maps, through the same AccessNode, in this state
12971314 # Note currently multiple incoming edges are not handled, but in the spirit of this function
12981315 # we consider such AccessNodes as shared, because we can not remove the intermediate.
1316+ # TODO(phimuell): If one of the two Maps has multiple connections to the intermediate,
1317+ # then this detection will fail here. This is not a problem because this is currently
1318+ # not supported. But still.
12991319 if state .out_degree (data ) > 1 :
13001320 return True
13011321 if state .in_degree (data ) > 1 :
13021322 return True
13031323
1304- # NOTE: Actually, if this transformation is run through the `FullMapFusion` pass, it should
1305- # read the results from `FindSingelUseData`, that was computed because it is a dependent
1306- # pass through the `self._pipeline_results` which is set by the `SingleStateTransformation`.
1307- # However, this member is only set during when `apply()` is called, but not during
1308- # `can_be_applied()`, see [issue#1911](https://github.com/spcl/dace/issues/1911).
1309- # Because, the whole goal of this separation of scanning and fusion was to make the
1310- # transformation stateless, the member `_single_use_data` was introduced. If it is set
1311- # then we use it otherwise we use the scanner.
1312- # This value is set for example by the `FullMapFusion` pass.
1313- # TODO(phimuell): Change this once the issue is resolved.
1324+ # Non transient data must be reconstructed anyways, so it is by definition shared.
1325+ if not data .desc (sdfg ).transient :
1326+ return True
13141327
1328+ # NOTE: Actually, if this transformation is run inside a pipeline that specifies
1329+ # `FindSingleUseData` as a dependent pass, it should read the cached data through
1330+ # `self._pipeline_results`. However, this member is only set during the `apply()`
1331+ # function but not during `can_be_applied()`, see [issue#1911](https://github.com/spcl/dace/issues/1911).
1332+ # Since we also need the information during `can_be_applied()`, we would still scan the
1333+ # SDFG. To avoid that, the special member `_single_use_data` was introduced, which
1334+ # allows the set to be specified from the outside. This is not nice, because it makes
1335+ # the transformation stateful and every parent transformation must set it.
1336+ # TODO(phimuell): Change this once the issue is resolved.
13151337 single_use_data = None
13161338 if self ._pipeline_results is not None and "FindSingelUseData" in self ._pipeline_results :
13171339 single_use_data = self ._pipeline_results ["FindSingelUseData" ]
1318-
13191340 elif self ._single_use_data is not None :
13201341 single_use_data = self ._single_use_data
1321- else :
1322- # We have to perform the full scan of the SDFG.
1323- return self ._scan_sdfg_if_data_is_shared (data = data , state = state , sdfg = sdfg )
1324-
1325- assert single_use_data is not None
1326- assert sdfg in single_use_data , (
1327- f"`_single_use_data` was set, but does not contain information about the SDFG '{ sdfg .name } '." )
1328- single_use_data_sdfg : Set [str ] = single_use_data [sdfg ]
1329- return data .data not in single_use_data_sdfg
1342+ # The single use data was present so scan it.
1343+ if single_use_data is not None :
1344+ assert sdfg in single_use_data
1345+ return data .data not in single_use_data [sdfg ]
1346+
1347+ # If we are here, then we were unable to locate a previous result of `FindSingleUseData`.
1348+ # However, before we perform a scan, we check if one of the shortcuts, i.e. `assume_always_*`,
1349+ # was specified and use it.
1350+ if self .assume_always_shared :
1351+ return True
1352+ elif self .assume_always_single_use_data :
1353+ return False
1354+
1355+ # Perform the real scan.
1356+ return self ._scan_sdfg_if_data_is_shared (data = data , state = state , sdfg = sdfg )
13301357
13311358 def _scan_sdfg_if_data_is_shared (
13321359 self ,
0 commit comments