# NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
# breaks and the 2-space indentation are reconstructed, and the `index` blob ids
# predate the review fixes; `git apply --ignore-whitespace` may be needed.
diff --git a/Project.toml b/Project.toml
index 802aecc..5ce33fb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,14 +5,12 @@ version = "0.2.2"
 
 [deps]
 BitTwiddlingConvenienceFunctions = "62783981-4cbd-42fc-bca8-16325de8dc4b"
-CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
 
 [compat]
 Aqua = "0.8"
 BitTwiddlingConvenienceFunctions = "0.1"
-CPUSummary = "0.1.2, 0.2"
 Static = "1"
 Test = "<0.1, 1.10"
 ThreadingUtilities = "0.5"
diff --git a/src/request.jl b/src/request.jl
index 1f8e190..87880fe 100644
--- a/src/request.jl
+++ b/src/request.jl
@@ -1,13 +1,11 @@
-import CPUSummary
-
 function worker_bits()
-  wts = nextpow2(CPUSummary.sys_threads()) # Typically sys_threads (i.e. Sys.CPU_THREADS) does not change between runs, thus it will precompile well.
-  ws = static(8sizeof(UInt)) # For testing purposes it can be overridden by JULIA_CPU_THREADS,
-  ifelse(Static.lt(wts, ws), ws, wts)
+  # Always return a plain Int: Threads.nthreads() is only known at run time.
+  # Never report fewer than 64 bits, the width worker_mask_count() divides by.
+  max(nextpow2(Threads.nthreads()), 64)
 end
 function worker_mask_count()
   bits = worker_bits()
-  (bits + StaticInt{63}()) ÷ StaticInt{64}() # cld not defined on `StaticInt`
+  cld(bits, 64)
 end
 
 worker_pointer() = Base.unsafe_convert(Ptr{UInt}, pointer_from_objref(WORKERS))
@@ -63,6 +61,31 @@ end
   (ui,), (ft,)
 end
 
+# Method for a runtime `Int` mask count `N` (e.g. `worker_mask_count()`, which
+# now returns an `Int`). Each returned tuple has `N` entries, so the tuple
+# length depends on the value of `N`, not only on its type.
+@inline function _request_threads(
+  num_requested::UInt32,
+  wp::Ptr,
+  N::Int,
+  threadmask
+)
+  N >= 1 || throw(ArgumentError("N must be at least 1, got $N"))
+  # This call is identical for every N, so it is made once before branching.
+  ui, ft, num_requested, wp =
+    __request_threads(num_requested, wp, _first(threadmask))
+  if N == 1
+    return (ui,), (ft,)
+  end
+  uit, ftt = _request_threads(
+    num_requested,
+    wp,
+    N - 1,
+    _remaining(threadmask)
+  )
+  return (ui, uit...), (ft, ftt...)
+end
+
 @inline function _exchange_mask!(wp, ::Nothing)
   all_threads = _atomic_xchg!(wp, zero(UInt))
   all_threads, all_threads
diff --git a/test/runtests.jl b/test/runtests.jl
index d51a37a..37ed7c3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -28,3 +28,4 @@ using Test
   end
 end
 Aqua.test_all(PolyesterWeave)
+include("test_high_thread_count.jl")
diff --git a/test/test_high_thread_count.jl b/test/test_high_thread_count.jl
new file mode 100644
index 0000000..2a60efc
--- /dev/null
+++ b/test/test_high_thread_count.jl
@@ -0,0 +1,32 @@
+using Test
+using PolyesterWeave
+using BitTwiddlingConvenienceFunctions: nextpow2
+
+@testset "High thread count compatibility" begin
+  # Test worker_bits returns Int for all thread counts
+  @test isa(PolyesterWeave.worker_bits(), Int)
+
+  # Test worker_mask_count returns Int
+  @test isa(PolyesterWeave.worker_mask_count(), Int)
+
+  # The multi-word path (worker_mask_count() > 1) is only exercised when Julia
+  # is actually started with more than 64 threads; otherwise check the 1-word case
+  if Threads.nthreads() > 64
+    # With > 64 threads, worker_mask_count() should be 2 or more
+    @test PolyesterWeave.worker_mask_count() >= 2
+
+    # request_threads must not throw; masks and release masks come in pairs
+    threads, torelease = PolyesterWeave.request_threads(10)
+    @test length(threads) == length(torelease)
+
+    # Free the threads
+    PolyesterWeave.free_threads!(torelease)
+  else
+    # With <= 64 threads, worker_mask_count() should be 1
+    @test PolyesterWeave.worker_mask_count() == 1
+  end
+
+  # Test specific values
+  @test PolyesterWeave.worker_bits() == max(64, nextpow2(Threads.nthreads()))
+  @test PolyesterWeave.worker_mask_count() == cld(PolyesterWeave.worker_bits(), 64)
+end