Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ version = "0.2.2"

[deps]
BitTwiddlingConvenienceFunctions = "62783981-4cbd-42fc-bca8-16325de8dc4b"
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"

[compat]
Aqua = "0.8"
BitTwiddlingConvenienceFunctions = "0.1"
CPUSummary = "0.1.2, 0.2"
Static = "1"
Test = "<0.1, 1.10"
ThreadingUtilities = "0.5"
Expand Down
33 changes: 28 additions & 5 deletions src/request.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import CPUSummary

function worker_bits()
wts = nextpow2(CPUSummary.sys_threads()) # Typically sys_threads (i.e. Sys.CPU_THREADS) does not change between runs, thus it will precompile well.
wts = nextpow2(Threads.nthreads()) # Typically sys_threads (i.e. Sys.CPU_THREADS) does not change between runs, thus it will precompile well.
ws = static(8sizeof(UInt)) # For testing purposes it can be overridden by JULIA_CPU_THREADS,
ifelse(Static.lt(wts, ws), ws, wts)
# Always return Int to avoid type instability with high thread counts
ifelse(wts < 64, 64, wts)
end
function worker_mask_count()
bits = worker_bits()
(bits + StaticInt{63}()) ÷ StaticInt{64}() # cld not defined on `StaticInt`
cld(bits, 64)
end

worker_pointer() = Base.unsafe_convert(Ptr{UInt}, pointer_from_objref(WORKERS))
Expand Down Expand Up @@ -63,6 +62,30 @@ end
(ui,), (ft,)
end

# Handle regular Int for type stability with high thread counts
@inline function _request_threads(
num_requested::UInt32,
wp::Ptr,
N::Int,
threadmask
)
if N == 1
ui, ft, num_requested, wp =
__request_threads(num_requested, wp, _first(threadmask))
return (ui,), (ft,)
else
ui, ft, num_requested, wp =
__request_threads(num_requested, wp, _first(threadmask))
uit, ftt = _request_threads(
num_requested,
wp,
N - 1,
_remaining(threadmask)
)
return (ui, uit...), (ft, ftt...)
end
end

@inline function _exchange_mask!(wp, ::Nothing)
all_threads = _atomic_xchg!(wp, zero(UInt))
all_threads, all_threads
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ using Test
end
end
Aqua.test_all(PolyesterWeave)
include("test_high_thread_count.jl")
32 changes: 32 additions & 0 deletions test/test_high_thread_count.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Test
using PolyesterWeave
using BitTwiddlingConvenienceFunctions: nextpow2

@testset "High thread count compatibility" begin
# Test worker_bits returns Int for all thread counts
@test isa(PolyesterWeave.worker_bits(), Int)

# Test worker_mask_count returns Int
@test isa(PolyesterWeave.worker_mask_count(), Int)

# Test that request_threads works with high thread counts
# This simulates the case where worker_mask_count() > 1
if Threads.nthreads() > 64
# With > 64 threads, worker_mask_count() should be 2 or more
@test PolyesterWeave.worker_mask_count() >= 2

# Test that request_threads doesn't throw
threads, torelease = PolyesterWeave.request_threads(10)
@test length(threads) >= 0 # May get 0 if no threads available

# Free the threads
PolyesterWeave.free_threads!(torelease)
else
# With <= 64 threads, worker_mask_count() should be 1
@test PolyesterWeave.worker_mask_count() == 1
end

# Test specific values
@test PolyesterWeave.worker_bits() == max(64, nextpow2(Threads.nthreads()))
@test PolyesterWeave.worker_mask_count() == cld(PolyesterWeave.worker_bits(), 64)
end
Loading