@@ -1428,6 +1428,7 @@ def replace_pattern(new_pattern):
1428
1428
self .remove_nodes ,
1429
1429
self .graph_pre_visitor ,
1430
1430
self .graph_post_visitor ,
1431
+ self .as_function ,
1431
1432
)
1432
1433
1433
1434
return [replace_pattern (p ) for p in self ._target_pattern .commute ()]
@@ -1509,21 +1510,23 @@ class RewriteRuleClassBase:
1509
1510
@classmethod
1510
1511
def rule (cls , * args , ** kwargs ):
1511
1512
instance = cls (* args , ** kwargs )
1512
- setup = instance .setup if hasattr (instance , "setup" ) else None
1513
- cleanup = instance .cleanup if hasattr (instance , "cleanup" ) else None
1514
1513
return RewriteRule (
1515
1514
instance .pattern ,
1516
1515
instance .rewrite ,
1517
1516
instance .check ,
1518
1517
name = instance .name ,
1519
1518
remove_nodes = instance .remove_nodes ,
1520
- graph_pre_visitor = setup ,
1521
- graph_post_visitor = cleanup ,
1519
+ graph_pre_visitor = instance .setup ,
1520
+ graph_post_visitor = instance .cleanup ,
1521
+ as_function = instance .as_function ,
1522
1522
)
1523
1523
1524
- def __init__ (self , name : str | None = None , remove_nodes : bool = True ) -> None :
1524
+ def __init__ (
1525
+ self , name : str | None = None , remove_nodes : bool = True , as_function : bool = False
1526
+ ) -> None :
1525
1527
self .name = name or self .__class__ .__name__
1526
1528
self .remove_nodes = remove_nodes
1529
+ self .as_function = as_function
1527
1530
1528
1531
def pattern (self , op , * args , ** kwargs ):
1529
1532
raise NotImplementedError ("Method 'pattern' must be implemented by derived class." )
@@ -1535,30 +1538,52 @@ def check(self, op, *args, **kwargs):
1535
1538
def rewrite (self , op , * args , ** kwargs ):
1536
1539
raise NotImplementedError ("Method 'rewrite' must be implemented by derived class." )
1537
1540
1541
+ def setup (self ):
1542
+ # Optional setup function that can be overridden by derived classes. Used to do
1543
+ # per model/function initialization.
1544
+ pass
1545
+
1546
+ def cleanup (self ):
1547
+ # Optional cleanup function that can be overridden by derived classes. Used to do
1548
+ # per model/function cleanup.
1549
+ pass
1550
+
1538
1551
1539
1552
def _copy_for_function (
1540
1553
inputs : Sequence [ir .Value | None ], nodes : Sequence [ir .Node ], outputs : Sequence [ir .Value ]
1541
1554
):
1542
1555
"""Utility function to extract a subgraph out as a function."""
1543
1556
value_map : dict [ir .Value , ir .Value ] = {}
1544
1557
function_inputs : list [ir .Value ] = []
1558
+ constant_nodes : list [ir .Node ] = []
1545
1559
for input in inputs :
1546
1560
# Create a function input (formal-parameter value) to represent this value:
1547
- if input is None :
1548
- raise NotImplementedError ("None inputs not supported." )
1549
- new_value = ir .Value (
1550
- name = input .name ,
1551
- shape = input .shape ,
1552
- type = input .type ,
1553
- doc_string = input .doc_string ,
1561
+ new_value = (
1562
+ ir .Value (
1563
+ name = input .name ,
1564
+ shape = input .shape ,
1565
+ type = input .type ,
1566
+ doc_string = input .doc_string ,
1567
+ )
1568
+ if input
1569
+ else ir .Value () # dummy parameter for a None input
1554
1570
)
1555
- value_map [input ] = new_value
1571
+ if input is not None :
1572
+ value_map [input ] = new_value
1556
1573
function_inputs .append (new_value )
1557
1574
1558
1575
def copy_value (value : ir .Value | None ) -> ir .Value | None :
1559
1576
if value is None :
1560
1577
return None
1561
1578
if value not in value_map :
1579
+ const_value = value .const_value
1580
+ if const_value is not None :
1581
+ # create a Constant node to represent the value
1582
+ value_attr = ir .AttrTensor ("value" , const_value )
1583
+ const_node = ir .Node ("" , "Constant" , [], [value_attr ])
1584
+ constant_nodes .append (const_node )
1585
+ value_map [value ] = result = const_node .outputs [0 ]
1586
+ return result
1562
1587
raise ValueError (f"Value { value } not found in value_map." )
1563
1588
return value_map [value ]
1564
1589
@@ -1598,7 +1623,7 @@ def copy_node(node: ir.Node) -> ir.Node:
1598
1623
1599
1624
function_nodes = [copy_node (node ) for node in nodes ]
1600
1625
function_outputs = [copy_value (v ) for v in outputs ]
1601
- return (function_inputs , function_nodes , function_outputs )
1626
+ return (function_inputs , constant_nodes + function_nodes , function_outputs )
1602
1627
1603
1628
1604
1629
def _get_new_overload (model : ir .Model , domain : str , name : str ) -> str :
0 commit comments