1
- from typing import Any , Type
2
- from collections .abc import Hashable
3
-
4
1
"""
5
2
Finch performs extensive rewriting and defining of functions. The Finch
6
3
compiler is designed to inspect objects and functions defined by other
22
19
23
20
```python
24
21
from finch import register_property
25
- register_property(complex, '__add__', 'is_associative', lambda obj: True)
22
+
23
+ register_property(complex, "__add__", "is_associative", lambda obj: True)
26
24
```
27
25
28
26
Finch includes a convenience functions to query each property as well,
29
27
for example:
30
28
```python
31
29
from finch import query_property
32
30
from operator import add
33
- query_property(complex, '__add__', 'is_associative')
31
+
32
+ query_property(complex, "__add__", "is_associative")
34
33
# True
35
34
is_associative(add, complex, complex)
36
35
# True
37
36
```
38
37
39
- Properties can be inherited in the same way as methods. First we check whether properties have been defined for the object itself (in the case of functions), then we check For example, if you
40
- register a property for a class, all subclasses of that class will inherit
41
- that property. This allows you to define properties for a class and have
42
- them automatically apply to all subclasses, without having to register the
43
- property for each subclass individually.
38
+ Properties can be inherited in the same way as methods. First we check whether
39
+ properties have been defined for the object itself (in the case of functions), then we
40
+ check For example, if you register a property for a class, all subclasses of that class
41
+ will inherit that property. This allows you to define properties for a class and have
42
+ them automatically apply to all subclasses, without having to register the property for
43
+ each subclass individually.
44
44
"""
45
+
45
46
import operator
46
- from typing import Union
47
+ from collections .abc import Hashable
48
+ from typing import Any , Callable
49
+
47
50
import numpy as np
48
51
49
- _properties = {}
52
+ _properties : dict [ tuple [ Hashable , str , str ], Any ] = {}
50
53
51
- def query_property (obj , attr , prop , * args ):
54
+
55
+ def query_property (obj : Hashable , attr : str , prop : Hashable , * args : Any ) -> Any :
52
56
"""Queries a property of an attribute of an object or class. Properties can
53
57
be overridden by calling register_property on the object or it's class.
54
58
@@ -64,22 +68,28 @@ def query_property(obj, attr, prop, *args):
64
68
Raises:
65
69
NotImplementedError: If the property is not implemented for the given type.
66
70
"""
67
- if isinstance (obj , type ):
68
- T = obj
69
- else :
70
- if isinstance (obj , Hashable ):
71
- if (obj , attr , prop ) in _properties :
72
- return _properties [(obj , attr , prop )](obj , * args )
71
+ T = obj
72
+ if not isinstance (obj , Hashable ):
73
73
T = type (obj )
74
- while True :
75
- if (T , attr , prop ) in _properties :
76
- return _properties [(T , attr , prop )](obj , * args )
77
- if T is object :
78
- break
79
- T = T .__base__
74
+ to_query = {T }
75
+ queried : set [type ] = set ()
76
+ while len (to_query ) != len (queried ):
77
+ to_query_new = to_query .copy ()
78
+ for o in to_query :
79
+ method = _properties .get ((o , attr , prop ), None )
80
+ if method is not None :
81
+ return method (obj , * args )
82
+ queried .add (o )
83
+ if not isinstance (o , type ):
84
+ to_query_new .update (type (o ))
85
+ continue
86
+ to_query_new .update (o .__mro__ )
87
+ to_query = to_query_new
88
+
80
89
raise NotImplementedError (f"Property { prop } not implemented for { type (obj )} " )
81
90
82
- def register_property (cls , attr , prop , f ):
91
+
92
+ def register_property (cls : type , attr : str , prop : str , f : Callable ) -> None :
83
93
"""Registers a property for a class or object.
84
94
85
95
Args:
@@ -90,6 +100,7 @@ def register_property(cls, attr, prop, f):
90
100
"""
91
101
_properties [(cls , attr , prop )] = f
92
102
103
+
93
104
def fill_value (arg : Any ) -> Any :
94
105
"""The fill value for the given argument. The fill value is the
95
106
default value for a tensor when it is created with a given shape and dtype,
@@ -104,11 +115,15 @@ def fill_value(arg: Any) -> Any:
104
115
Raises:
105
116
NotImplementedError: If the fill value is not implemented for the given type.
106
117
"""
107
- return query_property (arg , '__self__' , 'fill_value' )
118
+ return query_property (arg , "__self__" , "fill_value" )
119
+
120
+
121
+ register_property (
122
+ np .ndarray , "__self__" , "fill_value" , lambda x : np .zeros ((), dtype = x .dtype )[()]
123
+ )
108
124
109
- register_property (np .ndarray , '__self__' , 'fill_value' , lambda x : np .zeros ((), dtype = x .dtype )[()])
110
125
111
- def element_type (arg : Any ) -> Type :
126
+ def element_type (arg : Any ) -> type :
112
127
"""The element type of the given argument. The element type is the scalar type of
113
128
the elements in a tensor, which may be different from the data type of the
114
129
tensor.
@@ -122,9 +137,16 @@ def element_type(arg: Any) -> Type:
122
137
Raises:
123
138
NotImplementedError: If the element type is not implemented for the given type.
124
139
"""
125
- return query_property (arg , '__self__' , 'element_type' )
140
+ return query_property (arg , "__self__" , "element_type" )
141
+
142
+
143
+ register_property (
144
+ np .ndarray ,
145
+ "__self__" ,
146
+ "element_type" ,
147
+ lambda x : type (np .zeros ((), dtype = x .dtype )[()]),
148
+ )
126
149
127
- register_property (np .ndarray , '__self__' , 'element_type' , lambda x : type (np .zeros ((), dtype = x .dtype )[()]))
128
150
129
151
def return_type (op : Any , * args : Any ) -> Any :
130
152
"""The return type of the given function on the given argument types.
@@ -136,7 +158,8 @@ def return_type(op: Any, *args: Any) -> Any:
136
158
Returns:
137
159
The return type of op(*args: arg_types)
138
160
"""
139
- return query_property (op , '__call__' , 'return_type' , * args )
161
+ return query_property (op , "__call__" , "return_type" , * args )
162
+
140
163
141
164
StableNumber = (np .number , bool , int , float , complex )
142
165
@@ -158,34 +181,52 @@ def return_type(op: Any, *args: Any) -> Any:
158
181
}
159
182
160
183
for op , (meth , rmeth ) in _reflexive_operators .items ():
161
- register_property (op , '__call__' , 'return_type' , lambda op , a , b : query_property (a , meth , 'return_type' , b ) if hasattr (a , meth ) else query_property (b , rmeth , 'return_type' , a )),
184
+ (
185
+ register_property (
186
+ op ,
187
+ "__call__" ,
188
+ "return_type" ,
189
+ lambda op , a , b , meth = meth , rmeth = rmeth : query_property (
190
+ a , meth , "return_type" , b
191
+ )
192
+ if hasattr (a , meth )
193
+ else query_property (b , rmeth , "return_type" , a ),
194
+ ),
195
+ )
196
+
162
197
def _return_type (meth ):
163
198
def _return_type_closure (a , b ):
164
199
if issubclass (b , StableNumber ):
165
200
return type (getattr (a (True ), meth )(b (True )))
166
- else :
167
- raise TypeError (f"Unsupported operand type for { type (a )} .{ meth } : { type (b )} " )
201
+ raise TypeError (
202
+ f"Unsupported operand type for { type (a )} .{ meth } : { type (b )} "
203
+ )
204
+
168
205
return _return_type_closure
206
+
169
207
for T in StableNumber :
170
- register_property (T , meth , 'return_type' , _return_type (meth ))
171
- register_property (T , rmeth , 'return_type' , _return_type (rmeth ))
208
+ register_property (T , meth , "return_type" , _return_type (meth ))
209
+ register_property (T , rmeth , "return_type" , _return_type (rmeth ))
210
+
172
211
173
212
def is_associative (op : Any ) -> bool :
174
213
"""Returns whether the given function is associative, that is, whether the
175
214
op(op(a, b), c) == op(a, op(b, c)) for all a, b, c.
176
215
177
216
Args:
178
217
op: The function to check.
179
-
218
+
180
219
Returns:
181
220
True if the function can be proven to be associative, False otherwise.
182
221
"""
183
- return query_property (op , '__call__' , 'is_associative' )
222
+ return query_property (op , "__call__" , "is_associative" )
223
+
184
224
185
225
for op in [operator .add , operator .mul , operator .and_ , operator .xor , operator .or_ ]:
186
- register_property (op , ' __call__' , ' is_associative' , lambda op : True )
226
+ register_property (op , " __call__" , " is_associative" , lambda op : True )
187
227
188
- def fixpoint_type (op : Any , z : Any , T : Type ) -> Type :
228
+
229
+ def fixpoint_type (op : Any , z : Any , T : type ) -> type :
189
230
"""Determines the fixpoint type after repeated calling the given operation.
190
231
191
232
Args:
@@ -200,9 +241,12 @@ def fixpoint_type(op: Any, z: Any, T: Type) -> Type:
200
241
R = type (z )
201
242
while R not in S :
202
243
S .add (R )
203
- R = return_type (op , type (z ), T ) # Assuming `op` is a callable that takes `z` and `T` as arguments
244
+ R = return_type (
245
+ op , type (z ), T
246
+ ) # Assuming `op` is a callable that takes `z` and `T` as arguments
204
247
return R
205
248
249
+
206
250
def init_value (op , arg ) -> Any :
207
251
"""Returns the initial value for a reduction operation on the given type.
208
252
@@ -214,17 +258,24 @@ def init_value(op, arg) -> Any:
214
258
The initial value for the given operation and type.
215
259
216
260
Raises:
217
- NotImplementedError: If the initial value is not implemented for the given type and operation.
261
+ NotImplementedError: If the initial value is not implemented for the given type
262
+ and operation.
218
263
"""
219
- return query_property (op , '__call__' , 'init_value' , arg )
264
+ return query_property (op , "__call__" , "init_value" , arg )
265
+
220
266
221
267
for op in [operator .add , operator .mul , operator .and_ , operator .xor , operator .or_ ]:
222
268
(meth , rmeth ) = _reflexive_operators [op ]
223
- register_property (op , '__call__' , 'init_value' , lambda op , arg : query_property (arg , meth , 'init_value' , arg ))
269
+ register_property (
270
+ op ,
271
+ "__call__" ,
272
+ "init_value" ,
273
+ lambda op , arg , meth = meth : query_property (arg , meth , "init_value" , arg ),
274
+ )
224
275
225
276
for T in StableNumber :
226
- register_property (T , ' __add__' , ' init_value' , lambda a , b : a (False ))
227
- register_property (T , ' __mul__' , ' init_value' , lambda a , b : a (True ))
228
- register_property (T , ' __and__' , ' init_value' , lambda a , b : a (True ))
229
- register_property (T , ' __xor__' , ' init_value' , lambda a , b : a (False ))
230
- register_property (T , ' __or__' , ' init_value' , lambda a , b : a (False ))
277
+ register_property (T , " __add__" , " init_value" , lambda a , b : a (False ))
278
+ register_property (T , " __mul__" , " init_value" , lambda a , b : a (True ))
279
+ register_property (T , " __and__" , " init_value" , lambda a , b : a (True ))
280
+ register_property (T , " __xor__" , " init_value" , lambda a , b : a (False ))
281
+ register_property (T , " __or__" , " init_value" , lambda a , b : a (False ))
0 commit comments