@@ -871,6 +871,52 @@ def clone(self):
871
871
res .fgraph = res .fgraph .clone ()
872
872
return res
873
873
874
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
    """Link the inner graph of this Op into a thunk bound to ``node``'s storage.

    The function graph taken from ``self.fn.maker.fgraph`` is handed to a
    linker, which is asked for a thunk that reads from and writes to the
    storage cells that ``storage_map`` holds for ``node``'s inputs and
    outputs.

    Parameters
    ----------
    node
        Node whose ``inputs`` and ``outputs`` select the storage cells.
    storage_map
        Mapping indexed by every input and output variable of ``node``.
    compute_map
        Accepted for signature compatibility; not used by this method.
    no_recycling
        Container of outer output variables. The inner graph outputs at the
        same positions are forwarded to the linker as its ``no_recycling``.
    impl
        ``"py"`` selects a ``VMLinker`` with ``use_cloop`` and ``c_thunks``
        disabled. Any other value tries a ``CLinker`` first and falls back
        to a ``VMLinker`` with both options enabled when
        ``NotImplementedError`` is raised.

    Returns
    -------
    callable
        The thunk produced by the linker. When the linker is a ``VMLinker``
        the thunk is wrapped in a function that always returns ``None``.
    """
    from pytensor.link.c.basic import CLinker
    from pytensor.link.vm import VMLinker

    # FIXME: Don't call self.fn just to get the optimized fgraph.
    # The intended alternative is to take ``self.fgraph`` and apply the
    # default mode's optimizer to it here instead.
    inner_fgraph = self.fn.maker.fgraph

    # Translate the outer ``no_recycling`` request into inner graph outputs.
    inner_no_recycling = []
    for inner_out, outer_out in zip(inner_fgraph.outputs, node.outputs, strict=True):
        if outer_out in no_recycling:
            inner_no_recycling.append(inner_out)

    input_cells = [storage_map[var] for var in node.inputs]
    output_cells = [storage_map[var] for var in node.outputs]

    def _link_with(linker):
        """Accept the inner graph on ``linker`` and return its thunk."""
        linker.accept(inner_fgraph, no_recycling=inner_no_recycling)
        thunk, _, _ = linker.make_thunk(
            input_storage=input_cells,
            output_storage=output_cells,
        )

        if not isinstance(linker, VMLinker):
            return thunk

        # VMs will complain if a non-lazy thunk returns anything,
        # so hide the return value behind a function that returns None.
        def thunk_without_returns():
            thunk()

        return thunk_without_returns

    if impl == "py":
        return _link_with(VMLinker(use_cloop=False, c_thunks=False))

    try:
        # CLinker is preferred: it generates code for the whole graph that
        # the compiler can reason about, whereas the VMLinker compiles each
        # node separately and calls them in a pre-defined VM. It also has
        # less overhead.
        return _link_with(CLinker())
    except NotImplementedError:
        # Some Op doesn't have a C implementation, VM it is.
        return _link_with(VMLinker(use_cloop=True, c_thunks=True))
919
+
874
920
def perform (self , node , inputs , outputs ):
875
921
variables = self .fn (* inputs )
876
922
assert len (variables ) == len (outputs )
0 commit comments