Skip to content

WIP: add curl share locks to make things threadsafe #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.3'
- '1.5'
- '1' # automatically expands to the latest stable 1.x release of Julia
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "gRPCClient"
uuid = "aaca4a50-36af-4a1d-b878-4c443f2061ad"
authors = ["Tanmay K.M. <[email protected]>"]
version = "0.1.2"
version = "0.1.3"

[deps]
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Expand All @@ -12,7 +12,7 @@ ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"
Downloads = "1.3"
LibCURL = "0.6"
ProtoBuf = "0.11"
julia = "1.3"
julia = "1.5"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
56 changes: 53 additions & 3 deletions src/curl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,55 @@ function buffer_send_data(input::Channel{T}) where T <: ProtoType
end
=#

function share_lock(easy_p::Ptr{Cvoid}, data::curl_lock_data, access::curl_lock_access, userptr::Ptr{Cvoid})
share = unsafe_pointer_to_objref(Ptr{CurlShare}(userptr))::CurlShare
lock(share.locks[data])
nothing
end

function share_unlock(easy_p::Ptr{Cvoid}, data::curl_lock_data, userptr::Ptr{Cvoid})
share = unsafe_pointer_to_objref(Ptr{CurlShare}(userptr))::CurlShare
unlock(share.locks[data])
nothing
end

mutable struct CurlShare
shptr::Ptr{CURLSH}
locks::Vector{ReentrantLock}
closed::Bool

function CurlShare()
shptr = curl_share_init()
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_COOKIE)
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_DNS)
curl_share_setopt(shptr, CURLSHOPT_SHARE, CURL_LOCK_DATA_PSL)

share_lock_cb = @cfunction(share_lock, Cvoid, (Ptr{Cvoid}, Cuint, Cuint, Ptr{Cvoid}))
share_unlock_cb = @cfunction(share_unlock, Cvoid, (Ptr{Cvoid}, Cuint, Ptr{Cvoid}))

@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_LOCKFUNC::CURLSHoption; share_lock_cb::Ptr{Cvoid})::CURLSHcode
@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_UNLOCKFUNC::CURLSHoption; share_unlock_cb::Ptr{Cvoid})::CURLSHcode

locks = Vector(undef, CURL_LOCK_DATA_LAST)
for idx in 1:CURL_LOCK_DATA_LAST
locks[idx] = ReentrantLock()
end

obj = new(shptr, locks, false)
userptr = pointer_from_objref(obj)
@ccall LibCURL.LibCURL_jll.libcurl.curl_share_setopt(shptr::Ptr{CURLSH}, CURLSHOPT_USERDATA::CURLSHoption; userptr::Ptr{Cvoid})::CURLSHcode
obj
end
end

function close(share::CurlShare)
if share.closed
curl_share_cleanup(share.shptr)
share.closed = true
end
nothing
end

function send_data(easy::Curl.Easy, input::Channel{T}, max_send_message_length::Int) where T <: ProtoType
while true
yield()
Expand Down Expand Up @@ -95,7 +144,7 @@ function grpc_request_header(request_timeout::Real)
end
end

function easy_handle(maxage::Clong, keepalive::Clong, negotiation::Symbol, revocation::Bool, request_timeout::Real)
function easy_handle(curlshare::Ptr{CURLSH}, maxage::Clong, keepalive::Clong, negotiation::Symbol, revocation::Bool, request_timeout::Real)
easy = Curl.Easy()
http_version = (negotiation === :http2) ? CURL_HTTP_VERSION_2_0 :
(negotiation === :http2_tls) ? CURL_HTTP_VERSION_2TLS :
Expand All @@ -105,6 +154,7 @@ function easy_handle(maxage::Clong, keepalive::Clong, negotiation::Symbol, revoc
Curl.setopt(easy, CURLOPT_PIPEWAIT, Clong(1))
Curl.setopt(easy, CURLOPT_POST, Clong(1))
Curl.setopt(easy, CURLOPT_HTTPHEADER, grpc_request_header(request_timeout))
Curl.setopt(easy, CURLOPT_SHARE, curlshare)
if !revocation
Curl.setopt(easy, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NO_REVOKE)
end
Expand Down Expand Up @@ -172,7 +222,7 @@ function set_connect_timeout(easy::Curl.Easy, timeout::Real)
end
end

function grpc_request(downloader::Downloader, url::String, input::Channel{T1}, output::Channel{T2};
function grpc_request(curlshare::Ptr{CURLSH}, downloader::Downloader, url::String, input::Channel{T1}, output::Channel{T2};
maxage::Clong = typemax(Clong),
keepalive::Clong = 60,
negotiation::Symbol = :http2_prior_knowledge,
Expand All @@ -182,7 +232,7 @@ function grpc_request(downloader::Downloader, url::String, input::Channel{T1}, o
max_recv_message_length::Int = DEFAULT_MAX_RECV_MESSAGE_LENGTH,
max_send_message_length::Int = DEFAULT_MAX_SEND_MESSAGE_LENGTH,
verbose::Bool = false)::gRPCStatus where {T1 <: ProtoType, T2 <: ProtoType}
Curl.with_handle(easy_handle(maxage, keepalive, negotiation, revocation, request_timeout)) do easy
Curl.with_handle(easy_handle(curlshare, maxage, keepalive, negotiation, revocation, request_timeout)) do easy
# setup the request
Curl.set_url(easy, url)
Curl.set_timeout(easy, request_timeout)
Expand Down
1 change: 1 addition & 0 deletions src/gRPCClient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ProtoBuf

import Downloads: Curl
import ProtoBuf: call_method
import Base: close

export gRPCController, gRPCChannel, gRPCException, gRPCServiceCallException, gRPCMessageTooLargeException, gRPCStatus, gRPCCheck, StatusCode

Expand Down
10 changes: 8 additions & 2 deletions src/grpc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,22 @@ the server.
struct gRPCChannel <: ProtoRpcChannel
downloader::Downloader
baseurl::String
curlshare::CurlShare # TODO: this should be optional to avoid unnecessary overheads when possible

function gRPCChannel(baseurl::String)
downloader = Downloader(; grace=Inf)
Curl.init!(downloader.multi)
Curl.setopt(downloader.multi, CURLMOPT_PIPELINING, CURLPIPE_MULTIPLEX)
endswith(baseurl, '/') && (baseurl = baseurl[1:end-1])
new(downloader, baseurl)
new(downloader, baseurl, CurlShare())
end
end

function close(channel::gRPCChannel)
close(channel.curlshare)
nothing
end

function to_delimited_message_bytes(msg, max_message_length::Int)
iob = IOBuffer()
limitiob = LimitIO(iob, max_message_length)
Expand Down Expand Up @@ -193,7 +199,7 @@ function call_method(channel::gRPCChannel, service::ServiceDescriptor, method::M
end
function call_method(channel::gRPCChannel, service::ServiceDescriptor, method::MethodDescriptor, controller::gRPCController, input::Channel{T1}, outchannel::Channel{T2}) where {T1 <: ProtoType, T2 <: ProtoType}
url = string(channel.baseurl, "/", service.name, "/", method.name)
status_future = @async grpc_request(channel.downloader, url, input, outchannel;
status_future = @async grpc_request(channel.curlshare.shptr, channel.downloader, url, input, outchannel;
maxage = controller.maxage,
keepalive = controller.keepalive,
negotiation = controller.negotiation,
Expand Down
8 changes: 7 additions & 1 deletion test/runtests_routeguide.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Downloads
using Random
using Sockets
using Test
using Base.Threads

const SERVER_RELEASE = "https://github.com/JuliaComputing/gRPCClient.jl/releases/download/testserver_v0.2/"
function server_binary()
Expand Down Expand Up @@ -79,7 +80,7 @@ server_endpoint = isempty(ARGS) ? "http://localhost:10000/" : ARGS[1]
else
@info("skipping code generation on Windows to avoid needing batch file execution permissions")
end

test_timeout_header_values()

include("test_routeclient.jl")
Expand All @@ -91,6 +92,11 @@ server_endpoint = isempty(ARGS) ? "http://localhost:10000/" : ARGS[1]
@debug("testing async safety...")
test_task_safety(server_endpoint)

if Threads.nthreads() > 1
@debug("testing multithreaded clients...", threadcount=Threads.nthreads())
test_threaded_clients(server_endpoint)
end

kill(serverproc)
@info("stopped test server")
end
Expand Down
53 changes: 52 additions & 1 deletion test/test_routeclient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,29 @@ end

function test_async_client(server_endpoint::String)
client = RouteGuideClient(server_endpoint; verbose=false)
try
test_async_client(client)
finally
close(client.channel)
end
nothing
end
function test_async_client(client::RouteGuideClient)
@testset "GetFeature" begin
test_async_get_feature(client)
end
end

function test_blocking_client(server_endpoint::String)
client = RouteGuideBlockingClient(server_endpoint; verbose=false)
try
test_blocking_client(client)
finally
close(client.channel)
end
nothing
end
function test_blocking_client(client::RouteGuideBlockingClient)
@testset "request response" begin
test_get_feature(client)
end
Expand Down Expand Up @@ -189,4 +205,39 @@ function test_task_safety(server_endpoint::String)
end
end
end
end
end

function test_threaded_clients(server_endpoint::String)
@info("testing threaded blocking client")

testsetslck = ReentrantLock()
topts = Test.get_testset()
function recordts(ts)
lock(testsetslck) do
for result in ts.results
Test.record(topts, result)
end
end
end

Test.TESTSET_PRINT_ENABLE[] = false
Threads.@threads for _idx in 1:10
ts = @testset "threaded blocking client" begin
client = RouteGuideBlockingClient(server_endpoint; verbose=false)
test_blocking_client(client)
close(client.channel)
end
recordts(ts)
end

@info("testing threaded async client")
Threads.@threads for _idx in 1:10
ts = @testset "threaded async client" begin
client = RouteGuideClient(server_endpoint; verbose=false)
test_async_client(client)
close(client.channel)
end
recordts(ts)
end
Test.TESTSET_PRINT_ENABLE[] = true
end