@@ -103,17 +103,38 @@ end
103103
# Natural logarithm for Float64 and Float32: each method forwards its argument
# unchanged to an extern symbol (`__nv_log` / `__nv_logf`) through `ccall` with
# the `llvmcall` calling convention and returns the result in the same type.
104104@device_override Base. log (x:: Float64 ) = ccall (" extern __nv_log" , llvmcall, Cdouble, (Cdouble,), x)
105105@device_override Base. log (x:: Float32 ) = ccall (" extern __nv_logf" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 natural logarithm.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hlog` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `log`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.log(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hlog", llvmcall, Float16, (Float16,), x)
    else
        return Float16(log(Float32(x)))
    end
end
# FastMath natural logarithm: only a Float32 method is visible here; it forwards
# to the `__nv_fast_logf` extern.
106113@device_override FastMath. log_fast (x:: Float32 ) = ccall (" extern __nv_fast_logf" , llvmcall, Cfloat, (Cfloat,), x)
107114
# Base-10 logarithm for Float64 and Float32, forwarded to `__nv_log10` /
# `__nv_log10f` respectively.
108115@device_override Base. log10 (x:: Float64 ) = ccall (" extern __nv_log10" , llvmcall, Cdouble, (Cdouble,), x)
109116@device_override Base. log10 (x:: Float32 ) = ccall (" extern __nv_log10f" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 base-10 logarithm.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hlog10` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `log10`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.log10(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hlog10", llvmcall, Float16, (Float16,), x)
    else
        return Float16(log10(Float32(x)))
    end
end
# FastMath base-10 logarithm: Float32 only in this view, forwarded to
# `__nv_fast_log10f`.
110124@device_override FastMath. log10_fast (x:: Float32 ) = ccall (" extern __nv_fast_log10f" , llvmcall, Cfloat, (Cfloat,), x)
111125
# log1p for Float64 and Float32, forwarded to `__nv_log1p` / `__nv_log1pf`.
112126@device_override Base. log1p (x:: Float64 ) = ccall (" extern __nv_log1p" , llvmcall, Cdouble, (Cdouble,), x)
113127@device_override Base. log1p (x:: Float32 ) = ccall (" extern __nv_log1pf" , llvmcall, Cfloat, (Cfloat,), x)
114128
# Base-2 logarithm for Float64 and Float32, forwarded to `__nv_log2` /
# `__nv_log2f`.
115129@device_override Base. log2 (x:: Float64 ) = ccall (" extern __nv_log2" , llvmcall, Cdouble, (Cdouble,), x)
116130@device_override Base. log2 (x:: Float32 ) = ccall (" extern __nv_log2f" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 base-2 logarithm.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hlog2` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `log2`, and rounded back to Float16.
@device_override function Base.log2(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hlog2", llvmcall, Float16, (Float16,), x)
    else
        # The fallback must use `log2`, not `log`: calling the natural
        # logarithm here would return results off by a factor of ln(2) on
        # devices below compute capability 8.0.
        return Float16(log2(Float32(x)))
    end
end
# FastMath base-2 logarithm: Float32 only in this view, forwarded to
# `__nv_fast_log2f`.
117138@device_override FastMath. log2_fast (x:: Float32 ) = ccall (" extern __nv_fast_log2f" , llvmcall, Cfloat, (Cfloat,), x)
118139
# `logb` is declared with `@device_function` rather than as an override of a
# Base method; the Float64 method forwards to `__nv_logb`.
119140@device_function logb (x:: Float64 ) = ccall (" extern __nv_logb" , llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +148,35 @@ end
127148
# Natural exponential for Float64 and Float32, forwarded to `__nv_exp` /
# `__nv_expf` through `ccall` with the `llvmcall` calling convention.
128149@device_override Base. exp (x:: Float64 ) = ccall (" extern __nv_exp" , llvmcall, Cdouble, (Cdouble,), x)
129150@device_override Base. exp (x:: Float32 ) = ccall (" extern __nv_expf" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 natural exponential.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hexp` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `exp`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.exp(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hexp", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp(Float32(x)))
    end
end
# FastMath natural exponential: Float32 only in this view, forwarded to
# `__nv_fast_expf`.
130158@device_override FastMath. exp_fast (x:: Float32 ) = ccall (" extern __nv_fast_expf" , llvmcall, Cfloat, (Cfloat,), x)
131159
# Base-2 exponential for Float64 and Float32, forwarded to `__nv_exp2` /
# `__nv_exp2f`.
132160@device_override Base. exp2 (x:: Float64 ) = ccall (" extern __nv_exp2" , llvmcall, Cdouble, (Cdouble,), x)
133161@device_override Base. exp2 (x:: Float32 ) = ccall (" extern __nv_exp2f" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 base-2 exponential.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hexp2` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `exp2`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.exp2(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hexp2", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp2(Float32(x)))
    end
end
# FastMath base-2 exponential for Float32 and Float64: no dedicated fast
# intrinsic is used here, the call simply forwards to the regular `exp2`.
134169@device_override FastMath. exp2_fast (x:: Union{Float32, Float64} ) = exp2 (x)
135- # TODO : enable once PTX > 7.0 is supported
136- # @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137170
# Base-10 exponential for Float64 and Float32, forwarded to `__nv_exp10` /
# `__nv_exp10f`.
138171@device_override Base. exp10 (x:: Float64 ) = ccall (" extern __nv_exp10" , llvmcall, Cdouble, (Cdouble,), x)
139172@device_override Base. exp10 (x:: Float32 ) = ccall (" extern __nv_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 base-10 exponential.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hexp10` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `exp10`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.exp10(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hexp10", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp10(Float32(x)))
    end
end
# FastMath base-10 exponential: Float32 only in this view, forwarded to
# `__nv_fast_exp10f`.
140180@device_override FastMath. exp10_fast (x:: Float32 ) = ccall (" extern __nv_fast_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
141181
# expm1 for Float64, forwarded to `__nv_expm1`.
142182@device_override Base. expm1 (x:: Float64 ) = ccall (" extern __nv_expm1" , llvmcall, Cdouble, (Cdouble,), x)
204244
# NaN test for Float64 and Float32. The externs (`__nv_isnand` / `__nv_isnanf`)
# return an Int32 flag; comparing it against zero turns it into a Bool.
205245@device_override Base. isnan (x:: Float64 ) = (ccall (" extern __nv_isnand" , llvmcall, Int32, (Cdouble,), x)) != 0
206246@device_override Base. isnan (x:: Float32 ) = (ccall (" extern __nv_isnanf" , llvmcall, Int32, (Cfloat,), x)) != 0
# Float16 NaN test, returning a Bool.
#
# Below compute capability 8.0 the value is widened to Float32 and tested with
# the Float32 `isnan`. Otherwise the `__nv_hisnan` extern is called and its
# Int32 flag is compared against zero.
@device_override function Base.isnan(x::Float16)
    # Targets below compute capability 8.0 take the Float32 route.
    compute_capability() >= sv"8.0" || return isnan(Float32(x))
    flag = ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)
    return flag != 0
end
207254
# `nearbyint` is declared with `@device_function` rather than as an override of
# a Base method; the Float64 and Float32 methods forward to `__nv_nearbyint` /
# `__nv_nearbyintf`.
208255@device_function nearbyint (x:: Float64 ) = ccall (" extern __nv_nearbyint" , llvmcall, Cdouble, (Cdouble,), x)
209256@device_function nearbyint (x:: Float32 ) = ccall (" extern __nv_nearbyintf" , llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +270,26 @@ end
# Absolute value for Int32, Float64 and Float32, forwarded to `__nv_abs`,
# `__nv_fabs` and `__nv_fabsf` respectively; the result type equals the
# argument type.
223270@device_override Base. abs (x:: Int32 ) = ccall (" extern __nv_abs" , llvmcall, Int32, (Int32,), x)
224271@device_override Base. abs (f:: Float64 ) = ccall (" extern __nv_fabs" , llvmcall, Cdouble, (Cdouble,), f)
225272@device_override Base. abs (f:: Float32 ) = ccall (" extern __nv_fabsf" , llvmcall, Cfloat, (Cfloat,), f)
226- # TODO : enable once PTX > 7.0 is supported
227- # @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
# Float16 absolute value.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_habs` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `abs`, and converted back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.abs(f::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_habs", llvmcall, Float16, (Float16,), f)
    else
        return Float16(abs(Float32(f)))
    end
end
# Absolute value for Int64, forwarded to `__nv_llabs`.
228280@device_override Base. abs (x:: Int64 ) = ccall (" extern __nv_llabs" , llvmcall, Int64, (Int64,), x)
229281
230282# # roots and powers
231283
# Square root for Float64 and Float32, forwarded to `__nv_sqrt` / `__nv_sqrtf`.
232284@device_override Base. sqrt (x:: Float64 ) = ccall (" extern __nv_sqrt" , llvmcall, Cdouble, (Cdouble,), x)
233285@device_override Base. sqrt (x:: Float32 ) = ccall (" extern __nv_sqrtf" , llvmcall, Cfloat, (Cfloat,), x)
# Float16 square root.
#
# When the target's compute capability is at least 8.0 the argument is passed
# directly to the `__nv_hsqrt` extern. Otherwise it is widened to Float32,
# evaluated with the Float32 `sqrt`, and rounded back to Float16.
# Both branches return explicitly, matching the other Float16 overrides.
@device_override function Base.sqrt(x::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
    else
        return Float16(sqrt(Float32(x)))
    end
end
# FastMath square root for Float32 and Float64: forwards to the regular `sqrt`.
234293@device_override FastMath. sqrt_fast (x:: Union{Float32, Float64} ) = sqrt (x)
235294
# `rsqrt` is declared with `@device_function` rather than as an override of a
# Base method; the Float64 method forwards to `__nv_rsqrt`.
236295@device_function rsqrt (x:: Float64 ) = ccall (" extern __nv_rsqrt" , llvmcall, Cdouble, (Cdouble,), x)
295354# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
296355# @device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
297356# @device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
# Float16 minimum.
#
# Both operands are widened to Float32 (an exact conversion), reduced with the
# Float32 `min`, and narrowed back to Float16. The narrowing is exact as well,
# because the result is one of the two inputs or NaN.
#
# The direct `__nv_hmin` call was removed. JuliaGPU/CUDA.jl#2111 records that
# the fmin-style intrinsics do not match Julia's NaN handling, which is why
# the Float64/Float32 `__nv_fmin*` overrides are disabled.
# NOTE(review): `__nv_hmin` is presumed to behave like those intrinsics and
# return the non-NaN operand -- confirm against the CUDA half-precision math
# API. If it does, the old compute-capability >= 8.0 branch disagreed with the
# Float32 fallback on NaN input.
@device_override @inline function Base.min(x::Float16, y::Float16)
    return Float16(min(Float32(x), Float32(y)))
end
298364@device_override @inline function Base. min (x:: Float32 , y:: Float32 )
299365 if @static LLVM. version () < v " 14" ? false : (compute_capability () >= sv " 8.0" )
300366 # LLVM 14+ can do the right thing, but only on sm_80+
321387# JuliaGPU/CUDA.jl#2111: fmax semantics wrt. NaN don't match Julia's
322388# @device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
323389# @device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
# Float16 maximum.
#
# Both operands are widened to Float32 (an exact conversion), reduced with the
# Float32 `max`, and narrowed back to Float16. The narrowing is exact as well,
# because the result is one of the two inputs or NaN.
#
# The direct `__nv_hmax` call was removed. JuliaGPU/CUDA.jl#2111 records that
# the fmax-style intrinsics do not match Julia's NaN handling, which is why
# the Float64/Float32 `__nv_fmax*` overrides are disabled.
# NOTE(review): `__nv_hmax` is presumed to behave like those intrinsics and
# return the non-NaN operand -- confirm against the CUDA half-precision math
# API. If it does, the old compute-capability >= 8.0 branch disagreed with the
# Float32 fallback on NaN input.
@device_override @inline function Base.max(x::Float16, y::Float16)
    return Float16(max(Float32(x), Float32(y)))
end
324397@device_override @inline function Base. max (x:: Float32 , y:: Float32 )
325398 if @static LLVM. version () < v " 14" ? false : (compute_capability () >= sv " 8.0" )
326399 # LLVM 14+ can do the right thing, but only on sm_80+
0 commit comments