feat: Ops.throw #835

Open · wants to merge 4 commits into base: main
Conversation

@avik-pal (Collaborator) commented Mar 3, 2025

Still need to register the FFI on the C++ end. Is that the best way to do this, or do we have a way of registering function calls on the Julia end?

```julia
julia> @code_hlo error("dimension mismatch")
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")
module @reactant_error attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main() {
    %c = stablehlo.constant dense<true> : tensor<i1>
    stablehlo.custom_call @throw(%c) {api_version = 4 : i32, backend_config = {error_message = "dimension mismatch"}, has_side_effect = true} : (tensor<i1>) -> ()
    return
  }
}
```
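As a hedged sketch of what the Julia-side half of such a registration could look like: expose a Julia function as a C-callable pointer that the C++ FFI handler invokes when the `@throw` custom call fires. The name `reactant_err` matches the frame visible in the stacktrace later in this thread; the registration mechanism itself is illustrative, not Reactant's actual API.

```julia
# Illustrative sketch only: a Julia callback the runtime could call with the
# error message. `reactant_err(msg::Cstring)` appears in the stacktrace below;
# how the pointer reaches the C++ FFI handler is an assumption here.
function reactant_err(msg::Cstring)
    Base.error(unsafe_string(msg))  # surface the runtime message as a Julia error
end

# `@cfunction` yields a pointer with C calling convention; a hypothetical
# C++-side `register_error_handler` would stash and later invoke it.
const ERR_HANDLER = @cfunction(reactant_err, Cvoid, (Cstring,))
```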

@avik-pal (Collaborator, Author) commented Mar 3, 2025

Though I am not super sure we should do this by default. There are a lot of use cases for keeping the runtime-error functionality separate, but with this, all errors will be embedded into the MLIR by default.

@avik-pal avik-pal changed the title feat: overlay error with a custom_call feat: Ops.throw Mar 3, 2025
@avik-pal avik-pal requested a review from wsmoses March 3, 2025 19:49
@avik-pal (Collaborator, Author) commented Mar 3, 2025

```julia
julia> fn(x) = begin
           Reactant.Ops.throw("dimension mismatch")
           return x .+ 1
       end
fn (generic function with 1 method)

julia> @jit fn(Reactant.to_rarray(rand(2)))
ERROR: INTERNAL: dimension mismatch

Stacktrace:
 [1] reactant_err(msg::Cstring)
   @ Reactant.XLA ~/reactant/Reactant.jl/src/xla/Utils.jl:12
 [2] macro expansion
   @ ~/reactant/Reactant.jl/src/xla/PJRT/LoadedExecutable.jl:175 [inlined]
 [3] execute_sharded
   @ ~/reactant/Reactant.jl/src/xla/PJRT/LoadedExecutable.jl:144 [inlined]
 [4] macro expansion
   @ ~/reactant/Reactant.jl/src/Compiler.jl:1435 [inlined]
 [5] (::Reactant.Compiler.Thunk{typeof(fn), Symbol("##fn_reactant#280"), Tuple{ConcretePJRTArray{…}}, false})(args::ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}})
   @ Reactant.Compiler ~/reactant/Reactant.jl/src/Compiler.jl:1690
 [6] top-level scope
   @ ~/reactant/Reactant.jl/src/Compiler.jl:1036
Some type information was truncated. Use `show(err)` to see complete types.
```

```julia
@noinline function throw(
    msg::String,
    condition::Union{TracedRNumber{Bool},Nothing}=nothing;
    location=mlir_stacktrace("throw", @__FILE__, @__LINE__)
```
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change:

```diff
-location=mlir_stacktrace("throw", @__FILE__, @__LINE__)
+location=mlir_stacktrace("throw", @__FILE__, @__LINE__),
```
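For reference, a hedged usage sketch of the two call shapes this signature permits; `bad_shape` is a hypothetical name standing in for any `TracedRNumber{Bool}` computed in the trace:

```julia
# Unconditional: the compiled program always raises at this point.
Reactant.Ops.throw("dimension mismatch")

# Guarded: raise only when the traced boolean is true at runtime.
Reactant.Ops.throw("dimension mismatch", bad_shape)
```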

@avik-pal (Collaborator, Author) commented Mar 3, 2025

Need to fix CUDA before merging this.

2 participants