Skip to content

Commit 6ac586c

Browse files
committed
fixes
1 parent cb2d20a commit 6ac586c

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

ext/ReactantCUDAExt.jl

+11
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,17 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
419419
return res
420420
end
421421

422+
import Reactant.TracedRNumberOverrides.TracedStepRangeLen
423+
424+
function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
425+
return TracedStepRangeLen(
426+
Adapt.adapt(ReactantKernelAdaptor(), r.ref),
427+
Adapt.adapt(ReactantKernelAdaptor(), r.step),
428+
Adapt.adapt(ReactantKernelAdaptor(), r.len),
429+
Adapt.adapt(ReactantKernelAdaptor(), r.offset),
430+
)
431+
end
432+
422433
# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
423434
# and not the operation itself).
424435
struct LLVMFunc{F,tt}

src/Compiler.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
end
3434

3535
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
36-
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
36+
(isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) && return Base.getfield(obj, field)
3737
return Base.getindex(obj, field)
3838
end
3939

@@ -1471,7 +1471,11 @@ function codegen_flatten!(
14711471
is_sharded &&
14721472
runtime isa Val{:PJRT} &&
14731473
(flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...))
1474-
1474+
Core.println("####")
1475+
Core.println("$(quote
1476+
$(flatten_code...)
1477+
end)")
1478+
Core.println("####")
14751479
return flatten_names, flatten_code
14761480
end
14771481

@@ -1914,6 +1918,8 @@ function compile(f, args; sync=false, kwargs...)
19141918
return result
19151919
end
19161920

1921+
Core.println("$body")
1922+
19171923
return register_thunk(
19181924
fname,
19191925
Tuple{map(Core.Typeof, args)...},

src/TracedRNumber.jl

+29-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,21 @@ function Base.ifelse(
262262
end
263263
end
264264

265+
function Base.:*(x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}) where {T<:TracedRNumber}
266+
zh, zl = Base.mul12(x.hi, y.hi)
267+
hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl)
268+
hi = ifelse(iszero(zh) | !isfinite(zh), zh, hi)
269+
lo = ifelse(iszero(zl) | !isfinite(zl), zl, lo)
270+
271+
return Base.TwicePrecision{T}(hi, lo)
272+
end
273+
274+
function Base.:+(x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}) where {T<:TracedRNumber}
275+
r = x.hi + y.hi
276+
@trace s = abs(x.hi) > abs(y.hi) ? begin; (((x.hi - r) + y.hi) + y.lo) + x.lo; end : begin; (((y.hi - r) + x.hi) + x.lo) + y.lo; end
277+
Base.TwicePrecision(Base.canonicalize2(r, s)...)
278+
end
279+
265280
for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
266281
T = promote_type(T1, T2)
267282
@eval begin
@@ -451,12 +466,25 @@ function Base.iterate(r::TracedStepRangeLen, i::Integer=1)
451466
length(r) < i && return nothing
452467
Base.unsafe_getindex(r, i), i
453468
end
469+
470+
errorcount = Ref(0)
471+
454472
function Base.unsafe_getindex(
455473
r::TracedStepRangeLen{T},
456474
i::Integer,
457475
) where {T}
458476
u = oftype(r.offset, i) - r.offset
459-
@warn T r.ref + u*r.step
477+
# @warn T typeof(r.ref + u*r.step)
478+
# @warn T typeof(r.ref) typeof(u) typeof(r.step)
479+
# test = r.ref*u
480+
# @info which(*, (typeof(r.ref), typeof(u)))
481+
# @info typeof(test)
482+
# return r.ref + u*r.step
483+
# errorcount[] += 1
484+
# if errorcount[] == 1
485+
# errorcount[] = 0
486+
# error("stop")
487+
# end
460488
T(r.ref + u*r.step)
461489
end
462490
function getindex(r::TracedStepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}

0 commit comments

Comments
 (0)