-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
julia> typeof(y)
Slices{ConcretePJRTArray{Float32, 3, 1}, Tuple{Colon, Colon, Int64}, Tuple{Base.OneTo{Int64}}, SubArray{Float32, 2, ConcretePJRTArray{Float32, 3, 1}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}, 1}
julia> @code_hlo sum(y)
module @reactant_sum attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<101x16x32xf32> {enzymexla.memory_effects = []}) -> tensor<16x32xf32> attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<101x16x32xf32>, tensor<f32>) -> tensor<16x32xf32>
return %0 : tensor<16x32xf32>
}
}

We fold quite aggressively, so there is no runtime difference. However, the size of the unoptimized IR scales linearly with the number of slices (i.e., the size of the sliced dimension).
Metadata
Metadata
Assignees
Labels
No labels