Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix execute ir pointer #856

Closed
wants to merge 8 commits into from
Closed

Fix execute ir pointer #856

wants to merge 8 commits into from

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Mar 8, 2025

No description provided.

@wsmoses wsmoses requested review from giordano and avik-pal March 8, 2025 02:37
@wsmoses
Copy link
Member Author

wsmoses commented Mar 8, 2025

@avik-pal after this bugfix goes in we should probably fix the calling conv of the other ones to be a similar llvmcall

src/xla/XLA.jl Outdated
Comment on lines 120 to 133
@static if !Sys.isapple()
lljit = Enzyme.LLVM.JuliaOJIT()
jd_main = Enzyme.LLVM.JITDylib(lljit)

for name in ("XLAExecute", "XLAExecuteSharded", "ifrt_loaded_executable_execute")
ptr = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, name)
Enzyme.LLVM.define(
jd_main,
Enzyme.Compiler.JIT.absolute_symbol_materialization(
Enzyme.LLVM.mangle(lljit, name), ptr
),
)
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This a better solution than doing RTLD_GLOBAL.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean that with this we don't actually need JuliaPackaging/Yggdrasil#10706? I'd be happy to revert that change 🙂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is still required even with that (from local testing).

Specifically loading a precompiled .so fails with just this and is fixed by the global

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll also note here that we don’t use the JIT from llvm.jl. The exclusive purpose of this is such that these symbols are findable from Julia’s JIT

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the concrete error when loading a plgimage?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will get when by computer, but cannot find symbol XLAExecuteSharded when loading GordonBell2025/…/.so when doing PRONTOLab/GB-25#31

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this code without the global works normally without the precompiled code done in the linked issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/home/wmoses/git/Enzyme.jl/julia11/julia: symbol lookup error: /home/wmoses/.julia/compiled/v1.11/GordonBell25/qFJ42_9Shjc.so: undefined symbol: XLAExecuteSharded

@giordano
Copy link
Member

giordano commented Mar 9, 2025

Besides x86_64 macOS failures due to #867, all integration tests are failing on Linux (which is affected by making all symbols global): https://github.com/EnzymeAD/Reactant.jl/actions/runs/13740137086/job/38428385024?pr=856. Sounds like we really need a solution to avoid that.

@vchuravy
Copy link
Member

vchuravy commented Mar 9, 2025

Sounds like we really need a solution to avoid that.

I have an idea.

Comment on lines +130 to +131
N1 = $(length(libname)+1)
N2 = $(length(fname)+1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
N1 = $(length(libname)+1)
N2 = $(length(fname)+1)
N1 = $(length(libname) + 1)
N2 = $(length(fname) + 1)

@@ -84,7 +84,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.83"
Reactant_jll = "0.0.84"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably test with the old version (without rtld_global)

Suggested change
Reactant_jll = "0.0.84"
Reactant_jll = "0.0.83"

@giordano
Copy link
Member

On the merge commit of this PR: https://github.com/EnzymeAD/Reactant.jl/actions/runs/13764654846/job/38488238961#step:9:539

findfirst / findlast: Error During Test at /home/runner/work/Reactant.jl/Reactant.jl/test/sorting.jl:182
  Test threw exception
  Expression: ffirstlinindices(x) == #= /home/runner/work/Reactant.jl/Reactant.jl/test/sorting.jl:182 =# @jit(findfirst(x_ra))
  ArgumentError: invalid index: nothing of type Nothing
  Stacktrace:
    [1] to_index(i::Nothing)
      @ Base ./indices.jl:315
    [2] to_index(A::LinearIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, i::Nothing)
      @ Base ./indices.jl:292
    [3] to_indices
      @ ./indices.jl:368 [inlined]
    [4] to_indices
      @ ./indices.jl:360 [inlined]
    [5] getindex
      @ ./abstractarray.jl:1312 [inlined]
    [6] (::Main.var"##Sorting#243".var"#ffirstlinindices#36")(x::Matrix{Bool})
      @ Main.var"##Sorting#243" ~/work/Reactant.jl/Reactant.jl/test/sorting.jl:177
    [7] macro expansion
      @ /opt/hostedtoolcache/julia/1.11.3/x64/share/julia/stdlib/v1.11/Test/src/Test.jl:676 [inlined]
    [8] macro expansion
      @ ~/work/Reactant.jl/Reactant.jl/test/sorting.jl:182 [inlined]
    [9] macro expansion
      @ /opt/hostedtoolcache/julia/1.11.3/x64/share/julia/stdlib/v1.11/Test/src/Test.jl:1704 [inlined]
   [10] top-level scope
      @ ~/work/Reactant.jl/Reactant.jl/test/sorting.jl:174

I've never seen it before, is that new with this change?

@wsmoses
Copy link
Member Author

wsmoses commented Mar 10, 2025

odd....in any case this is resolved by the other pr (which was just merged), so closing this

@wsmoses wsmoses closed this Mar 10, 2025
@giordano giordano deleted the exir branch March 10, 2025 15:41
@giordano
Copy link
Member

Sorry, I posted in the wrong PR, it was meant to be in #869 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants