@@ -663,7 +663,9 @@ defmodule EXLA.Defn do
663663 result =
664664 Value . gather (
665665 tensor ,
666- indices ,
666+ # TODO remove conversion (unsigned indices fail)
667+ # Reported in https://github.com/google/jax/issues/21547
668+ to_type ( indices , { :s , 32 } ) ,
667669 index_vector_dim ,
668670 slice_sizes ,
669671 offset_dims ,
@@ -871,6 +873,10 @@ defmodule EXLA.Defn do
871873 ) do
872874 precision = state . precision
873875
876+ # Ensure both have the same type
877+ left = to_type ( left , ans . type )
878+ right = to_type ( right , ans . type )
879+
874880 Value . dot_general (
875881 left ,
876882 right ,
@@ -1291,6 +1297,9 @@ defmodule EXLA.Defn do
12911297 defp to_operator ( :put_slice , [ % Value { } = tensor , start_indices , slice ] , ans , _state ) do
12921298 tensor = to_type ( tensor , ans . type )
12931299 slice = to_type ( slice , ans . type )
1300+ # TODO remove conversion (unsigned indices fail)
1301+ # Reported in https://github.com/google/jax/issues/21547
1302+ start_indices = Enum . map ( start_indices , & to_type ( & 1 , { :s , 32 } ) )
12941303 Value . dynamic_update_slice ( tensor , slice , start_indices , expr_to_typespec ( ans ) )
12951304 end
12961305
@@ -1313,7 +1322,9 @@ defmodule EXLA.Defn do
13131322
13141323 Value . gather (
13151324 tensor ,
1316- indices ,
1325+ # TODO remove conversion (unsigned indices fail)
1326+ # Reported in https://github.com/google/jax/issues/21547
1327+ to_type ( indices , { :s , 32 } ) ,
13171328 index_vector_dim ,
13181329 slice_sizes ,
13191330 offset_dims ,
@@ -1341,7 +1352,7 @@ defmodule EXLA.Defn do
13411352 defp to_operator ( :sort , [ % Value { } = tensor , opts ] , ans , state ) do
13421353 dimension = opts [ :axis ]
13431354
1344- op =
1355+ operator =
13451356 case opts [ :direction ] do
13461357 :asc -> :less
13471358 :desc -> :greater
@@ -1350,7 +1361,7 @@ defmodule EXLA.Defn do
13501361 arg_typespec = Typespec . tensor ( ans . type , { } )
13511362 arg_typespecs = [ arg_typespec , arg_typespec ]
13521363
1353- comp = sort_computation ( op , ans . type , arg_typespecs , state )
1364+ comp = sort_computation ( operator , ans . type , arg_typespecs , state )
13541365
13551366 Value . sort ( [ tensor ] , comp , dimension , opts [ :stable ] == true , [ expr_to_typespec ( ans ) ] ) |> hd ( )
13561367 end
@@ -1530,30 +1541,45 @@ defmodule EXLA.Defn do
15301541
15311542 ## Computation helpers
15321543
1533- defp sort_computation ( op , type , arg_typespecs , % { builder: % EXLA.MLIR.Function { } = function } ) do
1544+ defp sort_computation (
1545+ operator ,
1546+ type ,
1547+ arg_typespecs ,
1548+ % { builder: % EXLA.MLIR.Function { } = function }
1549+ ) do
15341550 { region , [ lhs , rhs | _ ] } = Function . push_region ( function , arg_typespecs )
15351551
15361552 typespec = Typespec . tensor ( { :pred , 8 } , { } )
15371553
1538- op =
1539- cond do
1540- Nx.Type . integer? ( type ) ->
1541- apply ( Value , op , [ lhs , rhs , typespec ] )
1542-
1543- op == :less ->
1544- is_nan = Value . is_nan ( rhs , typespec )
1545- Value . bitwise_or ( is_nan , Value . less ( lhs , rhs , typespec ) , typespec )
1546-
1547- op == :greater ->
1548- is_nan = Value . is_nan ( lhs , typespec )
1549- Value . bitwise_or ( is_nan , Value . greater ( lhs , rhs , typespec ) , typespec )
1554+ { lhs , rhs } =
1555+ if Nx.Type . float? ( type ) do
1556+ { canonicalize_float_for_sort ( lhs ) , canonicalize_float_for_sort ( rhs ) }
1557+ else
1558+ { lhs , rhs }
15501559 end
15511560
1561+ op = apply ( Value , operator , [ lhs , rhs , typespec , [ total_order: true ] ] )
1562+
15521563 Value . return ( function , [ op ] )
15531564 Function . pop_region ( function )
15541565 region
15551566 end
15561567
1568+ defp canonicalize_float_for_sort ( % Value { function: func } = op ) do
1569+ # Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0).
1570+ # See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253
1571+
1572+ op_typespec = Value . get_typespec ( op )
1573+
1574+ zero = Value . constant ( func , [ 0 ] , Typespec . to_shape ( op_typespec , { } ) )
1575+ zeros = Value . constant ( func , [ 0 ] , op_typespec )
1576+ nans = Value . constant ( func , [ :nan ] , op_typespec )
1577+
1578+ pred_typespec = Typespec . tensor ( { :pred , 8 } , { } )
1579+ op = Value . select ( Value . equal ( op , zero , pred_typespec ) , zeros , op , op_typespec )
1580+ Value . select ( Value . is_nan ( op , pred_typespec ) , nans , op , op_typespec )
1581+ end
1582+
15571583 defp op_computation (
15581584 op ,
15591585 arg_typespecs ,
0 commit comments