Skip to content

Commit 58b6e93

Browse files
committed
add curl share locks to make things threadsafe
1 parent 63c3d65 commit 58b6e93

File tree

5 files changed

+100
-5
lines changed

5 files changed

+100
-5
lines changed

src/curl.jl

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,55 @@ function buffer_send_data(input::Channel{T}) where T <: ProtoType
4242
end
4343
=#
4444

45+
function share_lock(easy_p::Ptr{Cvoid}, data::curl_lock_data, access::curl_lock_access, userptr::Ptr{Cvoid})
46+
share = unsafe_pointer_to_objref(Ptr{CurlShare}(userptr))::CurlShare
47+
lock(share.locks[data])
48+
nothing
49+
end
50+
51+
function share_unlock(easy_p::Ptr{Cvoid}, data::curl_lock_data, userptr::Ptr{Cvoid})
52+
share = unsafe_pointer_to_objref(Ptr{CurlShare}(userptr))::CurlShare
53+
unlock(share.locks[data])
54+
nothing
55+
end
56+
57+
mutable struct CurlShare
58+
shptr::Ptr{CURLSH}
59+
locks::Vector{ReentrantLock}
60+
closed::Bool
61+
62+
function CurlShare()
63+
shptr = curl_share_init()
64+
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_COOKIE)
65+
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_DNS)
66+
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_PSL)
67+
68+
share_lock_cb = @cfunction(share_lock, Cvoid, (Ptr{Cvoid}, Cuint, Cuint, Ptr{Cvoid}))
69+
share_unlock_cb = @cfunction(share_unlock, Cvoid, (Ptr{Cvoid}, Cuint, Ptr{Cvoid}))
70+
71+
@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_LOCKFUNC::CURLSHoption; share_lock_cb::Ptr{Cvoid})::CURLSHcode
72+
@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_UNLOCKFUNC::CURLSHoption; share_unlock_cb::Ptr{Cvoid})::CURLSHcode
73+
74+
locks = Vector(undef, CURL_LOCK_DATA_LAST)
75+
for idx in 1:CURL_LOCK_DATA_LAST
76+
locks[idx] = ReentrantLock()
77+
end
78+
79+
obj = new(shptr, locks, false)
80+
userptr = pointer_from_objref(obj)
81+
@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_USERDATA::CURLSHoption; userptr::Ptr{Cvoid})::CURLSHcode
82+
obj
83+
end
84+
end
85+
86+
function close(share::CurlShare)
87+
if share.closed
88+
curl_share_cleanup(share.shptr)
89+
share.closed = true
90+
end
91+
nothing
92+
end
93+
4594
function send_data(easy::Curl.Easy, input::Channel{T}, max_send_message_length::Int) where T <: ProtoType
4695
while true
4796
yield()
@@ -95,7 +144,7 @@ function grpc_request_header(request_timeout::Real)
95144
end
96145
end
97146

98-
function easy_handle(maxage::Clong, keepalive::Clong, negotiation::Symbol, revocation::Bool, request_timeout::Real)
147+
function easy_handle(curlshare::Ptr{CURLSH}, maxage::Clong, keepalive::Clong, negotiation::Symbol, revocation::Bool, request_timeout::Real)
99148
easy = Curl.Easy()
100149
http_version = (negotiation === :http2) ? CURL_HTTP_VERSION_2_0 :
101150
(negotiation === :http2_tls) ? CURL_HTTP_VERSION_2TLS :
@@ -105,6 +154,7 @@ function easy_handle(maxage::Clong, keepalive::Clong, negotiation::Symbol, revoc
105154
Curl.setopt(easy, CURLOPT_PIPEWAIT, Clong(1))
106155
Curl.setopt(easy, CURLOPT_POST, Clong(1))
107156
Curl.setopt(easy, CURLOPT_HTTPHEADER, grpc_request_header(request_timeout))
157+
Curl.setopt(easy, CURLOPT_SHARE, curlshare)
108158
if !revocation
109159
Curl.setopt(easy, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NO_REVOKE)
110160
end
@@ -172,7 +222,7 @@ function set_connect_timeout(easy::Curl.Easy, timeout::Real)
172222
end
173223
end
174224

175-
function grpc_request(downloader::Downloader, url::String, input::Channel{T1}, output::Channel{T2};
225+
function grpc_request(curlshare::Ptr{CURLSH}, downloader::Downloader, url::String, input::Channel{T1}, output::Channel{T2};
176226
maxage::Clong = typemax(Clong),
177227
keepalive::Clong = 60,
178228
negotiation::Symbol = :http2_prior_knowledge,
@@ -182,7 +232,7 @@ function grpc_request(downloader::Downloader, url::String, input::Channel{T1}, o
182232
max_recv_message_length::Int = DEFAULT_MAX_RECV_MESSAGE_LENGTH,
183233
max_send_message_length::Int = DEFAULT_MAX_SEND_MESSAGE_LENGTH,
184234
verbose::Bool = false)::gRPCStatus where {T1 <: ProtoType, T2 <: ProtoType}
185-
Curl.with_handle(easy_handle(maxage, keepalive, negotiation, revocation, request_timeout)) do easy
235+
Curl.with_handle(easy_handle(curlshare, maxage, keepalive, negotiation, revocation, request_timeout)) do easy
186236
# setup the request
187237
Curl.set_url(easy, url)
188238
Curl.set_timeout(easy, request_timeout)

src/gRPCClient.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ProtoBuf
66

77
import Downloads: Curl
88
import ProtoBuf: call_method
9+
import Base: close
910

1011
export gRPCController, gRPCChannel, gRPCException, gRPCServiceCallException, gRPCMessageTooLargeException, gRPCStatus, gRPCCheck, StatusCode
1112

src/grpc.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,22 @@ the server.
146146
struct gRPCChannel <: ProtoRpcChannel
147147
downloader::Downloader
148148
baseurl::String
149+
curlshare::CurlShare # TODO: this should be optional to avoid unnecessary overheads when possible
149150

150151
function gRPCChannel(baseurl::String)
151152
downloader = Downloader(; grace=Inf)
152153
Curl.init!(downloader.multi)
153154
Curl.setopt(downloader.multi, CURLMOPT_PIPELINING, CURLPIPE_MULTIPLEX)
154155
endswith(baseurl, '/') && (baseurl = baseurl[1:end-1])
155-
new(downloader, baseurl)
156+
new(downloader, baseurl, CurlShare())
156157
end
157158
end
158159

160+
function close(channel::gRPCChannel)
161+
close(channel.curlshare)
162+
nothing
163+
end
164+
159165
function to_delimited_message_bytes(msg, max_message_length::Int)
160166
iob = IOBuffer()
161167
limitiob = LimitIO(iob, max_message_length)
@@ -193,7 +199,7 @@ function call_method(channel::gRPCChannel, service::ServiceDescriptor, method::M
193199
end
194200
function call_method(channel::gRPCChannel, service::ServiceDescriptor, method::MethodDescriptor, controller::gRPCController, input::Channel{T1}, outchannel::Channel{T2}) where {T1 <: ProtoType, T2 <: ProtoType}
195201
url = string(channel.baseurl, "/", service.name, "/", method.name)
196-
status_future = @async grpc_request(channel.downloader, url, input, outchannel;
202+
status_future = @async grpc_request(channel.curlshare.shptr, channel.downloader, url, input, outchannel;
197203
maxage = controller.maxage,
198204
keepalive = controller.keepalive,
199205
negotiation = controller.negotiation,

test/runtests_routeguide.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Downloads
55
using Random
66
using Sockets
77
using Test
8+
using Base.Threads
89

910
const SERVER_RELEASE = "https://github.com/JuliaComputing/gRPCClient.jl/releases/download/testserver_v0.2/"
1011
function server_binary()
@@ -88,6 +89,11 @@ server_endpoint = isempty(ARGS) ? "http://localhost:10000/" : ARGS[1]
8889
@debug("testing routeclinet...")
8990
test_clients(server_endpoint)
9091

92+
if Threads.nthreads() > 1
93+
@debug("testing multithreaded clients...", threadcount=Threads.nthreads())
94+
test_threaded_clients(server_endpoint)
95+
end
96+
9197
kill(serverproc)
9298
@info("stopped test server")
9399
end

test/test_routeclient.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,29 @@ end
139139

140140
function test_async_client(server_endpoint::String)
141141
client = RouteGuideClient(server_endpoint; verbose=false)
142+
try
143+
test_async_client(client)
144+
finally
145+
close(client.channel)
146+
end
147+
nothing
148+
end
149+
function test_async_client(client::RouteGuideClient)
142150
@testset "GetFeature" begin
143151
test_async_get_feature(client)
144152
end
145153
end
146154

147155
function test_blocking_client(server_endpoint::String)
148156
client = RouteGuideBlockingClient(server_endpoint; verbose=false)
157+
try
158+
test_blocking_client(client)
159+
finally
160+
close(client.channel)
161+
end
162+
nothing
163+
end
164+
function test_blocking_client(client::RouteGuideBlockingClient)
149165
@testset "request response" begin
150166
test_get_feature(client)
151167
end
@@ -171,4 +187,20 @@ function test_clients(server_endpoint::String)
171187
test_blocking_client(server_endpoint)
172188
@info("testing async client")
173189
test_async_client(server_endpoint)
190+
end
191+
192+
function test_threaded_clients(server_endpoint::String)
193+
@info("testing threaded blocking client")
194+
Threads.@threads for idx in 1:2
195+
client = RouteGuideBlockingClient(server_endpoint; verbose=false)
196+
test_blocking_client(client)
197+
close(client.channel)
198+
end
199+
200+
# @info("testing threaded async client")
201+
# Threads.@threads for idx in 1:10
202+
# client = RouteGuideClient(server_endpoint; verbose=false)
203+
# test_async_client(client)
204+
# close(client.channel)
205+
# end
174206
end

0 commit comments

Comments
 (0)