@@ -1807,7 +1807,7 @@ defmodule EXLA.Defn.ExprTest do
18071807 indices = Nx . tensor ( [ [ 0 ] ] )
18081808 updates = Nx . tensor ( [ 1 ] )
18091809
1810- assert_equal ( indexed_add ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 64 } ) )
1810+ assert_equal ( indexed_add ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 32 } ) )
18111811
18121812 target = Nx . tensor ( [ 0 ] )
18131813 indices = Nx . tensor ( [ [ 0 ] ] )
@@ -1879,7 +1879,7 @@ defmodule EXLA.Defn.ExprTest do
18791879 indices = Nx . tensor ( [ [ 0 ] ] )
18801880 updates = Nx . tensor ( [ 1 ] )
18811881
1882- assert_equal ( indexed_put ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 64 } ) )
1882+ assert_equal ( indexed_put ( target , indices , updates ) , Nx . tensor ( [ 1 ] , type: { :s , 32 } ) )
18831883
18841884 target = Nx . tensor ( [ 0 ] )
18851885 indices = Nx . tensor ( [ [ 0 ] ] )
@@ -1963,7 +1963,7 @@ defmodule EXLA.Defn.ExprTest do
19631963 test "computes the sum across types" do
19641964 assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) |> sum ( ) , Nx . tensor ( 6 ) )
19651965 assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :s , 8 } ) |> sum ( ) , Nx . tensor ( 6 ) )
1966- assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> sum ( ) , Nx . tensor ( 6 , type: { :u , 64 } ) )
1966+ assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> sum ( ) , Nx . tensor ( 6 , type: { :u , 32 } ) )
19671967 assert_equal ( Nx . tensor ( [ 1.0 , 2.0 , 3.0 ] ) |> sum ( ) , Nx . tensor ( 6.0 ) )
19681968
19691969 assert_equal (
@@ -1986,9 +1986,9 @@ defmodule EXLA.Defn.ExprTest do
19861986 defn sum_equal ( t ) , do: Nx . sum ( Nx . equal ( t , 1.0 ) )
19871987
19881988 test "does not overflow" do
1989- assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
1990- assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 3 , type: { :u , 64 } ) )
1991- assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
1989+ assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
1990+ assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 3 , type: { :u , 32 } ) )
1991+ assert_equal ( sum_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
19921992 end
19931993
19941994 defn sum_keep ( t ) , do: Nx . sum ( t , keep_axes: true )
@@ -2011,7 +2011,7 @@ defmodule EXLA.Defn.ExprTest do
20112011 test "computes the product across types" do
20122012 assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) |> product ( ) , Nx . tensor ( 6 ) )
20132013 assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :s , 8 } ) |> product ( ) , Nx . tensor ( 6 ) )
2014- assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> product ( ) , Nx . tensor ( 6 , type: { :u , 64 } ) )
2014+ assert_equal ( Nx . tensor ( [ 1 , 2 , 3 ] , type: { :u , 8 } ) |> product ( ) , Nx . tensor ( 6 , type: { :u , 32 } ) )
20152015 assert_equal ( Nx . tensor ( [ 1.0 , 2.0 , 3.0 ] ) |> product ( ) , Nx . tensor ( 6.0 ) )
20162016
20172017 assert_equal (
@@ -2034,9 +2034,9 @@ defmodule EXLA.Defn.ExprTest do
20342034 defn product_equal ( t ) , do: Nx . product ( Nx . equal ( t , 1.0 ) )
20352035
20362036 test "does not overflow" do
2037- assert_equal ( product_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
2038- assert_equal ( product_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 1 , type: { :u , 64 } ) )
2039- assert_equal ( product_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 0 , type: { :u , 64 } ) )
2037+ assert_equal ( product_equal ( Nx . tensor ( 1 ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
2038+ assert_equal ( product_equal ( Nx . tensor ( [ 1 , 1 , 1 ] ) ) , Nx . tensor ( 1 , type: { :u , 32 } ) )
2039+ assert_equal ( product_equal ( Nx . tensor ( [ 1 , 2 , 3 ] ) ) , Nx . tensor ( 0 , type: { :u , 32 } ) )
20402040 end
20412041
20422042 defn product_keep ( t ) , do: Nx . product ( t , keep_axes: true )
@@ -2416,12 +2416,12 @@ defmodule EXLA.Defn.ExprTest do
24162416 window_max2 ( Nx . tensor ( [ [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] , [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ] ) ) ,
24172417 Nx . tensor ( [
24182418 [
2419- [ - 9_223_372_036_854_775_808 , - 9_223_372_036_854_775_808 ] ,
2420- [ - 9_223_372_036_854_775_808 , 6 ]
2419+ [ - 2_147_483_648 , - 2_147_483_648 ] ,
2420+ [ - 2_147_483_648 , 6 ]
24212421 ] ,
24222422 [
2423- [ - 9_223_372_036_854_775_808 , - 9_223_372_036_854_775_808 ] ,
2424- [ - 9_223_372_036_854_775_808 , 6 ]
2423+ [ - 2_147_483_648 , - 2_147_483_648 ] ,
2424+ [ - 2_147_483_648 , 6 ]
24252425 ]
24262426 ] )
24272427 )
@@ -2482,12 +2482,12 @@ defmodule EXLA.Defn.ExprTest do
24822482 window_min2 ( Nx . tensor ( [ [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] , [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ] ) ) ,
24832483 Nx . tensor ( [
24842484 [
2485- [ 9_223_372_036_854_775_807 , 9_223_372_036_854_775_807 ] ,
2486- [ 9_223_372_036_854_775_807 , 3 ]
2485+ [ 2_147_483_647 , 2_147_483_647 ] ,
2486+ [ 2_147_483_647 , 3 ]
24872487 ] ,
24882488 [
2489- [ 9_223_372_036_854_775_807 , 9_223_372_036_854_775_807 ] ,
2490- [ 9_223_372_036_854_775_807 , 3 ]
2489+ [ 2_147_483_647 , 2_147_483_647 ] ,
2490+ [ 2_147_483_647 , 3 ]
24912491 ]
24922492 ] )
24932493 )
0 commit comments