Skip to content

Commit 59d8ff1

Browse files
committed
Add disconnect
1 parent 2110f38 commit 59d8ff1

File tree

6 files changed

+82
-11
lines changed

6 files changed

+82
-11
lines changed

c_src/ex_dtls/native.c

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ UNIFEX_TERM do_init(UnifexEnv *env, char *mode_str, int dtls_srtp,
108108
state->x509 = NULL;
109109
state->mode = 0;
110110
state->hsk_finished = 0;
111+
state->closed = 0;
111112
state->env = unifex_alloc_env(env);
112113

113114
int mode;
@@ -244,6 +245,10 @@ UNIFEX_TERM get_cert_fingerprint(UnifexEnv *env, UnifexPayload *cert) {
244245
}
245246

246247
UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
248+
if (state->closed == 1) {
249+
return do_handshake_result_error_closed(env);
250+
}
251+
247252
SSL_do_handshake(state->ssl);
248253

249254
UnifexPayload **gen_packets = NULL;
@@ -258,14 +263,19 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
258263
} else {
259264
int timeout = get_timeout(state->ssl);
260265
UNIFEX_TERM res_term =
261-
do_handshake_result(env, gen_packets, gen_packets_size, timeout);
266+
do_handshake_result_ok(env, gen_packets, gen_packets_size, timeout);
262267
free_payload_array(gen_packets, gen_packets_size);
263268

264269
return res_term;
265270
}
266271
}
267272

268273
UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
274+
if (state->closed == 1) {
275+
DEBUG("Cannot write, connection closed");
276+
return write_data_result_error_closed(env);
277+
}
278+
269279
if (state->hsk_finished != 1) {
270280
DEBUG("Cannot write, handshake not finished");
271281
return write_data_result_error_handshake_not_finished(env);
@@ -303,6 +313,10 @@ UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
303313
}
304314

305315
UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
316+
if (state->closed == 1) {
317+
return handle_data_result_error_closed(env);
318+
}
319+
306320
(void)env;
307321

308322
if (payload->size != 0) {
@@ -332,6 +346,31 @@ UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
332346
}
333347
}
334348

349+
UNIFEX_TERM disconnect(UnifexEnv *env, State *state) {
350+
if (state->closed == 1) {
351+
return disconnect_result_ok(env, NULL, 0);
352+
}
353+
354+
state->closed = 1;
355+
if (SSL_shutdown(state->ssl) < 0) {
356+
return disconnect_result_ok(env, NULL, 0);
357+
} else {
358+
UnifexPayload **gen_packets = NULL;
359+
int gen_packets_size = 0;
360+
read_pending_data(&gen_packets, &gen_packets_size, state);
361+
362+
if (gen_packets == NULL) {
363+
return unifex_raise(state->env,
364+
"Disconnect failed: couldn't read pending data");
365+
} else {
366+
UNIFEX_TERM res_term =
367+
disconnect_result_ok(env, gen_packets, gen_packets_size);
368+
free_payload_array(gen_packets, gen_packets_size);
369+
return res_term;
370+
}
371+
}
372+
}
373+
335374
UNIFEX_TERM handle_regular_read(State *state, char data[], int ret) {
336375
if (ret > 0) {
337376
UnifexPayload packets;

c_src/ex_dtls/native.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct State {
1313
X509 *x509;
1414
int mode;
1515
int hsk_finished;
16+
int closed;
1617
};
1718

1819
#include "_generated/native.h"

c_src/ex_dtls/native.spec.exs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,17 @@ spec get_peer_cert(state) :: payload | (nil :: label)
2121

2222
spec get_cert_fingerprint(payload) :: payload
2323

24-
spec do_handshake(state) :: {packets :: [payload], timeout :: int}
24+
spec do_handshake(state) :: {:ok :: label, packets :: [payload], timeout :: int} | {:error :: label, :closed :: label}
2525

26-
spec handle_timeout(state) :: (:ok :: label) | {:retransmit :: label, packets :: [payload], timeout :: int}
26+
spec handle_timeout(state) ::
27+
(:ok :: label)
28+
| {:retransmit :: label, packets :: [payload], timeout :: int}
29+
| {:error :: label, :closed :: label}
2730

28-
spec write_data(state, packets :: payload) :: {:ok :: label, packets :: [payload]} | {:error :: label, :handshake_not_finished :: label}
31+
spec write_data(state, packets :: payload) ::
32+
{:ok :: label, packets :: [payload]}
33+
| {:error :: label, :handshake_not_finished :: label}
34+
| {:error :: label, :closed :: label}
2935

3036
spec handle_data(state, packets :: payload) ::
3137
{:ok :: label, packets :: payload}
@@ -34,5 +40,7 @@ spec handle_data(state, packets :: payload) ::
3440
| {:handshake_finished :: label, client_keying_material :: payload,
3541
server_keying_material :: payload, protection_profile :: int, packets :: [payload]}
3642
| {:error :: label, :peer_closed_for_writing :: label}
37-
| {:error :: label, :handshake_error :: label
38-
}
43+
| {:error :: label, :handshake_error :: label}
44+
| {:error :: label, :closed :: label}
45+
46+
spec disconnect(state) :: {:ok :: label, packets :: [payload]}

lib/ex_dtls.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,7 @@ defmodule ExDTLS do
194194
"""
195195
@spec handle_timeout(dtls()) :: :ok | {:retransmit, packets :: [binary()], timeout :: integer()}
196196
defdelegate handle_timeout(dtls), to: Native
197+
198+
@spec disconnect(dtls()) :: {:ok, packets :: [binary()]}
199+
defdelegate disconnect(dtls), to: Native
197200
end

test/integration_test.exs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule ExDTLS.IntegrationTest do
55
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
66
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
77

8-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
8+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
99

1010
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
1111

@@ -17,7 +17,7 @@ defmodule ExDTLS.IntegrationTest do
1717
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
1818
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)
1919

20-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
20+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
2121

2222
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
2323

@@ -34,7 +34,7 @@ defmodule ExDTLS.IntegrationTest do
3434
assert {:error, :handshake_not_finished} = ExDTLS.write_data(sr_dtls, <<1, 2, 3>>)
3535
assert {:error, :handshake_not_finished} = ExDTLS.write_data(cl_dtls, <<1, 2, 3>>)
3636

37-
{packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
37+
{:ok, packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
3838
assert :ok == loop({sr_dtls, false}, {cl_dtls, false}, packets)
3939

4040
msg = <<1, 3, 2, 5>>
@@ -55,11 +55,31 @@ defmodule ExDTLS.IntegrationTest do
5555

5656
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
5757

58-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
58+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
5959
{:handshake_packets, packets, _timeout} = feed_packets(rx_dtls, packets)
6060
assert {:error, :handshake_error} = feed_packets(tx_dtls, packets)
6161
end
6262

63+
describe "disconnect" do
64+
test "before handshake has finished" do
65+
dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
66+
assert {:ok, []} = ExDTLS.disconnect(dtls)
67+
# assert that handshake can't be started
68+
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
69+
end
70+
71+
test "after handshake has finished" do
72+
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
73+
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
74+
75+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
76+
77+
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
78+
assert {:ok, [packet]} = ExDTLS.disconnect(rx_dtls)
79+
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(tx_dtls, packet)
80+
end
81+
end
82+
6383
defp loop({_dtls1, true}, {_dtls2, true}, _packets) do
6484
:ok
6585
end

test/retransmission_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule ExDTLS.RetransmissionTest do
55
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
66
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)
77

8-
{_packets, timeout} = ExDTLS.do_handshake(tx_dtls)
8+
{:ok, _packets, timeout} = ExDTLS.do_handshake(tx_dtls)
99
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
1010
{:retransmit, packets, timeout} = wait_for_timeout(tx_dtls, :tx)
1111
Process.send_after(self(), {:handle_timeout, :tx}, timeout)

0 commit comments

Comments
 (0)