Skip to content

Commit 78ad9c9

Browse files
authored
docs: update exporting_to_jax.md (#1107)
1 parent 05523d6 commit 78ad9c9

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

docs/src/manual/exporting_to_jax.md

+2-8
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
Now we define a python script to run the model using EnzymeJAX.
6060

6161
```python
62-
from enzyme_ad.jax import primitives
62+
from enzyme_ad.jax import hlo_call
6363

6464
import jax
6565
import jax.numpy as jnp
@@ -81,7 +81,7 @@ def run_lux_model(
8181
weight6_3,
8282
bias6_3,
8383
):
84-
return primitives.ffi_call(
84+
return hlo_call(
8585
x,
8686
weight1,
8787
bias1,
@@ -93,13 +93,7 @@ def run_lux_model(
9393
bias6_2,
9494
weight6_3,
9595
bias6_3,
96-
out_shapes=[
97-
jax.core.ShapedArray([4, 10], jnp.float32),
98-
],
99-
fn="main",
10096
source=code,
101-
lang=primitives.LANG_MHLO,
102-
pipeline_options=primitives.JaXPipeline(""),
10397
)
10498

10599

0 commit comments

Comments
 (0)