Skip to content

Commit 95b0cc4

Browse files
committed
Avoid PyTensor function overhead in OpFromGraph
Also provide pure C-implementation when all Ops allow it.
1 parent 0f5da80 commit 95b0cc4

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

pytensor/compile/builders.py

+46
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,52 @@ def clone(self):
871871
res.fgraph = res.fgraph.clone()
872872
return res
873873

874+
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
875+
from pytensor.link.c.basic import CLinker
876+
from pytensor.link.vm import VMLinker
877+
878+
# FIXME: Don't call self.fn just to get the optimized fgraph
879+
fg = self.fn.maker.fgraph
880+
# fg = self.fgraph
881+
# rewriter = get_default_mode().optimizer
882+
# rewriter(fg)
883+
fg_no_recycling = [
884+
new_o
885+
for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
886+
if old_o in no_recycling
887+
]
888+
889+
node_input_storage = [storage_map[r] for r in node.inputs]
890+
node_output_storage = [storage_map[r] for r in node.outputs]
891+
892+
def create_thunk(linker):
893+
linker.accept(fg, no_recycling=fg_no_recycling)
894+
thunk, _, _ = linker.make_thunk(
895+
input_storage=node_input_storage, output_storage=node_output_storage
896+
)
897+
898+
if isinstance(linker, VMLinker):
899+
# VMs will complain if a non-lazy thunk returns anything
900+
# We wrap it in a function that returns None
901+
def thunk_without_returns():
902+
thunk()
903+
904+
return thunk_without_returns
905+
906+
return thunk
907+
908+
if impl != "py":
909+
try:
910+
# We default to CLinker because it generates code for the whole graph that the compiler can reason about.
911+
# Whereas the VMLinker will compile each node separately and call them in a pre-defined VM.
912+
# It also has less overhead
913+
return create_thunk(linker=CLinker())
914+
except NotImplementedError:
915+
# Some Op doesn't have a C implementation, VM it is
916+
return create_thunk(linker=VMLinker(use_cloop=True, c_thunks=True))
917+
else:
918+
return create_thunk(VMLinker(use_cloop=False, c_thunks=False))
919+
874920
def perform(self, node, inputs, outputs):
875921
variables = self.fn(*inputs)
876922
assert len(variables) == len(outputs)

tests/compile/test_builders.py

+57
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import pytensor.tensor as pt
7+
from pytensor import scan
78
from pytensor.compile import shared
89
from pytensor.compile.builders import OpFromGraph
910
from pytensor.compile.function import function
@@ -740,3 +741,59 @@ def test_debugprint():
740741

741742
for truth, out in zip(exp_res.split("\n"), lines, strict=True):
742743
assert truth.strip() == out.strip()
744+
745+
746+
@pytest.mark.parametrize("kind", ("ofg", "inlined", "scan"))
747+
@pytest.mark.parametrize("mode", ("fast_compile", "fast_run"))
748+
@pytest.mark.parametrize("c_op", (True, False), ids=lambda x: f"c_op={x}")
749+
def test_benchmark(c_op, mode, kind, benchmark):
750+
n = 25
751+
752+
if c_op:
753+
754+
def _f(x):
755+
if isinstance(x, np.ndarray):
756+
y = np.exp(x)
757+
else:
758+
y = pt.exp(x)
759+
y /= y.sum()
760+
return y
761+
else:
762+
763+
def _f(x):
764+
if isinstance(x, np.ndarray):
765+
return np.sort(x)
766+
else:
767+
return pt.sort(x)
768+
769+
x = pt.vector("x")
770+
771+
if kind == "ofg":
772+
f = OpFromGraph([x], [_f(x)])
773+
else:
774+
f = _f
775+
776+
if kind == "scan":
777+
# Scan is included for a reference of how bad the overhead can be
778+
outs, _ = scan(fn=f, outputs_info=[x], n_steps=n)
779+
out = outs[-1]
780+
else:
781+
out = x
782+
for i in range(n):
783+
out = f(out)
784+
785+
compiled_fn = function([x], out, trust_input=True, mode=mode)
786+
compiled_fn.dprint(print_memory_map=True)
787+
compiled_fn.vm.allow_gc = (
788+
False # For fairness to the default VM, since OFG inner VM does not do GC
789+
)
790+
791+
rng = np.random.default_rng(1)
792+
x_test = rng.normal(size=(10,))
793+
794+
res = benchmark(compiled_fn, x_test)
795+
796+
expected_res = x_test
797+
for i in range(n):
798+
expected_res = _f(expected_res)
799+
np.testing.assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)