From 7c07d2450b95c496cfbfcabc1c7184cf3c2e88d6 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 19 Apr 2025 16:44:36 -0500 Subject: [PATCH 1/5] Enable other reactant tests --- test/ext_reactant/reactant.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index ee91fb35fc..f4703b01bd 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -19,7 +19,7 @@ end (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), - # all arguments must have at least the same length of the firs one + # all arguments must have at least the same length of the first one # a = (Conv((3, 3), 2 => 3),) # b = ((σ = nothing, weight = Float32[-0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815], bias = Float32[0.33333334, 0.33333334, 0.33333334], stride = nothing, pad = nothing, dilation = nothing, groups = nothing),) # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), @@ -29,21 +29,17 @@ end # b = ((layers = ((σ = nothing, weight = Float32[0.2703631 0.15815677 0.2918554; 0.20036785 0.43450722 0.3525422; 0.3541182 0.32077286 0.44091386;;; 0.3233156 0.08538988 0.25763267; 0.413441 0.66042584 0.16991; 0.36993486 0.5990643 0.10123589;;;; 0.45728725 0.500834 0.46808332; 0.3662355 0.35068494 0.27277413; 0.44974697 0.47245422 0.10595817;;; 0.36255562 0.6111583 0.52779496; 0.27237993 0.25857046 0.33643073; 0.6679214 0.066386 0.32072845;;;; -0.4879305 -0.59246373 -0.59834677; -0.55097836 -0.5006755 -0.4233263; -0.72177917 -0.65806544 -0.38224664;;; -0.4765812 -0.6856963 -0.5864509; -0.6547631 -0.55094117 -0.38632843; -0.74521375 -0.3817107 -0.48642716], bias = Float32[0.7159346, 0.7152501, -1.0509125], stride = nothing, pad = nothing, dilation = nothing, groups = nothing), (σ = nothing, weight = Float32[0.32858944 -0.10135343 -0.25303265; -0.13622479 0.023095237 0.1746222; 0.18829267 -0.5047879 0.07125988;;; 0.023820637 -0.06595295 -0.003393827; -0.111125976 0.0023178488 0.08700531; -0.073591515 0.057915907 0.048598815;;; 0.016056929 -0.5129501 -0.15588683; -0.3756476 -0.09993523 -0.45654622; -0.3688693 -0.33078116 -0.4093926;;;;], bias = Float32[0.77964276], stride = nothing, pad = nothing, dilation = nothing, groups = nothing)),),) # (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), - # https://github.com/EnzymeAD/Enzyme-JAX/issues/221 - # (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), + (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), - # error: 'stablehlo.multiply' op requires compatible types for all operands and results - # This requires an issue to be opened. - # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), - # error: inferred shape '[1, 3, 9, 9]' is incompatible with return type of operation 'tensor<1x3x5x5xf32>' - # (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), + (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), - # (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # Apparent correctness issue + (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), ] for (model, x, name) in models_xs @@ -54,8 +50,7 @@ end end models_xs = [ - # %23 = "stablehlo.gather"(%22, %0) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<2x10xf32>, tensor<1x2xi64>) -> tensor<1x1xf32> - # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar + (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar # Structural mismatch? # a = (first ∘ MultiHeadAttention(16; nheads=8),) From 4e4829d6ea2b07e197f94000cddd37e5df91e9fc Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 19 Apr 2025 19:20:54 -0500 Subject: [PATCH 2/5] Update reactant.jl --- test/ext_reactant/reactant.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index f4703b01bd..dd6a523988 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -33,7 +33,8 @@ end (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), - (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), + + # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), From b02526cc3afc79edc47bd7c754eaa3988d563fe5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Apr 2025 13:52:21 -0500 Subject: [PATCH 3/5] Update reactant.jl --- test/ext_reactant/reactant.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index dd6a523988..ca8edc89d0 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -33,7 +33,7 @@ end (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), - + # # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), From 3ca08a140450008f06a78b07559940f2a185c92f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 May 2025 14:18:46 -0500 Subject: [PATCH 4/5] Update reactant.jl --- test/ext_reactant/reactant.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index ca8edc89d0..38aeed4c53 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -33,7 +33,6 @@ end (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), - # # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), From 18c6b29432008a2521c15b3edfdcd8a02c19d806 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 May 2025 17:19:42 -0500 Subject: [PATCH 5/5] Update reactant.jl --- test/ext_reactant/reactant.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index 38aeed4c53..dac0eeeada 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -50,7 +50,7 @@ end end models_xs = [ - (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar + # (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar # Structural mismatch? # a = (first ∘ MultiHeadAttention(16; nheads=8),)