Skip to content

Commit 2773c01

Browse files
authored
Add close/1 (#50)
1 parent 2110f38 commit 2773c01

File tree

6 files changed

+122
-15
lines changed

6 files changed

+122
-15
lines changed

c_src/ex_dtls/native.c

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ UNIFEX_TERM do_init(UnifexEnv *env, char *mode_str, int dtls_srtp,
108108
state->x509 = NULL;
109109
state->mode = 0;
110110
state->hsk_finished = 0;
111+
state->closed = 0;
111112
state->env = unifex_alloc_env(env);
112113

113114
int mode;
@@ -244,6 +245,10 @@ UNIFEX_TERM get_cert_fingerprint(UnifexEnv *env, UnifexPayload *cert) {
244245
}
245246

246247
UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
248+
if (state->closed == 1) {
249+
return do_handshake_result_error_closed(env);
250+
}
251+
247252
SSL_do_handshake(state->ssl);
248253

249254
UnifexPayload **gen_packets = NULL;
@@ -258,14 +263,19 @@ UNIFEX_TERM do_handshake(UnifexEnv *env, State *state) {
258263
} else {
259264
int timeout = get_timeout(state->ssl);
260265
UNIFEX_TERM res_term =
261-
do_handshake_result(env, gen_packets, gen_packets_size, timeout);
266+
do_handshake_result_ok(env, gen_packets, gen_packets_size, timeout);
262267
free_payload_array(gen_packets, gen_packets_size);
263268

264269
return res_term;
265270
}
266271
}
267272

268273
UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
274+
if (state->closed == 1) {
275+
DEBUG("Cannot write, connection closed");
276+
return write_data_result_error_closed(env);
277+
}
278+
269279
if (state->hsk_finished != 1) {
270280
DEBUG("Cannot write, handshake not finished");
271281
return write_data_result_error_handshake_not_finished(env);
@@ -303,6 +313,10 @@ UNIFEX_TERM write_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
303313
}
304314

305315
UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
316+
if (state->closed == 1) {
317+
return handle_data_result_error_closed(env);
318+
}
319+
306320
(void)env;
307321

308322
if (payload->size != 0) {
@@ -332,6 +346,32 @@ UNIFEX_TERM handle_data(UnifexEnv *env, State *state, UnifexPayload *payload) {
332346
}
333347
}
334348

349+
// prefix close with exd (ex_dtls) as close is defined in unistd.h
350+
UNIFEX_TERM exd_close(UnifexEnv *env, State *state) {
351+
if (state->closed == 1) {
352+
return exd_close_result_ok(env, NULL, 0);
353+
}
354+
355+
state->closed = 1;
356+
if (SSL_shutdown(state->ssl) < 0) {
357+
return exd_close_result_ok(env, NULL, 0);
358+
} else {
359+
UnifexPayload **gen_packets = NULL;
360+
int gen_packets_size = 0;
361+
read_pending_data(&gen_packets, &gen_packets_size, state);
362+
363+
if (gen_packets == NULL) {
364+
return unifex_raise(state->env,
365+
"Close failed: couldn't read pending data");
366+
} else {
367+
UNIFEX_TERM res_term =
368+
exd_close_result_ok(env, gen_packets, gen_packets_size);
369+
free_payload_array(gen_packets, gen_packets_size);
370+
return res_term;
371+
}
372+
}
373+
}
374+
335375
UNIFEX_TERM handle_regular_read(State *state, char data[], int ret) {
336376
if (ret > 0) {
337377
UnifexPayload packets;
@@ -351,6 +391,7 @@ UNIFEX_TERM handle_read_error(State *state, int ret) {
351391
int error = SSL_get_error(state->ssl, ret);
352392
switch (error) {
353393
case SSL_ERROR_ZERO_RETURN:
394+
state->closed = 1;
354395
return handle_data_result_error_peer_closed_for_writing(state->env);
355396
case SSL_ERROR_WANT_READ:
356397
DEBUG("SSL WANT READ. This is workaround. Did we get retransmission?");
@@ -452,6 +493,10 @@ UNIFEX_TERM handle_handshake_in_progress(State *state, int ret) {
452493
}
453494

454495
UNIFEX_TERM handle_timeout(UnifexEnv *env, State *state) {
496+
if (state->closed == 1) {
497+
return handle_timeout_result_error_closed(env);
498+
}
499+
455500
long result = DTLSv1_handle_timeout(state->ssl);
456501
if (result != 1)
457502
return handle_timeout_result_ok(env);

c_src/ex_dtls/native.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct State {
1313
X509 *x509;
1414
int mode;
1515
int hsk_finished;
16+
int closed;
1617
};
1718

1819
#include "_generated/native.h"

c_src/ex_dtls/native.spec.exs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,17 @@ spec get_peer_cert(state) :: payload | (nil :: label)
2121

2222
spec get_cert_fingerprint(payload) :: payload
2323

24-
spec do_handshake(state) :: {packets :: [payload], timeout :: int}
24+
spec do_handshake(state) :: {:ok :: label, packets :: [payload], timeout :: int} | {:error :: label, :closed :: label}
2525

26-
spec handle_timeout(state) :: (:ok :: label) | {:retransmit :: label, packets :: [payload], timeout :: int}
26+
spec handle_timeout(state) ::
27+
(:ok :: label)
28+
| {:retransmit :: label, packets :: [payload], timeout :: int}
29+
| {:error :: label, :closed :: label}
2730

28-
spec write_data(state, packets :: payload) :: {:ok :: label, packets :: [payload]} | {:error :: label, :handshake_not_finished :: label}
31+
spec write_data(state, packets :: payload) ::
32+
{:ok :: label, packets :: [payload]}
33+
| {:error :: label, :handshake_not_finished :: label}
34+
| {:error :: label, :closed :: label}
2935

3036
spec handle_data(state, packets :: payload) ::
3137
{:ok :: label, packets :: payload}
@@ -34,5 +40,7 @@ spec handle_data(state, packets :: payload) ::
3440
| {:handshake_finished :: label, client_keying_material :: payload,
3541
server_keying_material :: payload, protection_profile :: int, packets :: [payload]}
3642
| {:error :: label, :peer_closed_for_writing :: label}
37-
| {:error :: label, :handshake_error :: label
38-
}
43+
| {:error :: label, :handshake_error :: label}
44+
| {:error :: label, :closed :: label}
45+
46+
spec exd_close(state) :: {:ok :: label, packets :: [payload]}

lib/ex_dtls.ex

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ defmodule ExDTLS do
139139
140140
`timeout` is a time in ms after which `handle_timeout/1` should be called.
141141
"""
142-
@spec do_handshake(dtls()) :: {packets :: [binary()], timeout :: integer()}
142+
@spec do_handshake(dtls()) ::
143+
{:ok, packets :: [binary()], timeout :: integer()} | {:error, :closed}
143144
defdelegate do_handshake(dtls), to: Native
144145

145146
@doc """
@@ -148,7 +149,7 @@ defmodule ExDTLS do
148149
Generates encrypted packets that need to be passed to the second host.
149150
"""
150151
@spec write_data(dtls(), data :: binary()) ::
151-
{:ok, packets :: [binary()]} | {:error, :handshake_not_finished}
152+
{:ok, packets :: [binary()]} | {:error, :handshake_not_finished | :closed}
152153
defdelegate write_data(dtls, data), to: Native
153154

154155
@doc """
@@ -172,7 +173,7 @@ defmodule ExDTLS do
172173
remote_keying_material :: binary(), protection_profile_t(), packets :: [binary()]}
173174
| {:handshake_finished, local_keying_material :: binary(),
174175
remote_keying_material :: binary(), protection_profile_t()}
175-
| {:error, :handshake_error | :peer_closed_for_writing}
176+
| {:error, :handshake_error | :peer_closed_for_writing | :closed}
176177
def handle_data(dtls, packets) do
177178
case Native.handle_data(dtls, packets) do
178179
{:handshake_finished, lkm, rkm, protection_profile, []} ->
@@ -192,6 +193,16 @@ defmodule ExDTLS do
192193
193194
If there is no timeout to handle, simple `{:ok, dtls()}` tuple is returned.
194195
"""
195-
@spec handle_timeout(dtls()) :: :ok | {:retransmit, packets :: [binary()], timeout :: integer()}
196+
@spec handle_timeout(dtls()) ::
197+
:ok | {:retransmit, packets :: [binary()], timeout :: integer()} | {:error, :closed}
196198
defdelegate handle_timeout(dtls), to: Native
199+
200+
@doc """
201+
Irreversibly closes DTLS session.
202+
203+
If a handshake has been finished, this function will generate `close_notify` DTLS alert
204+
that should be sent to the other side.
205+
"""
206+
@spec close(dtls()) :: {:ok, packets :: [binary()]}
207+
defdelegate close(dtls), to: Native, as: :exd_close
197208
end

test/integration_test.exs

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule ExDTLS.IntegrationTest do
55
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
66
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
77

8-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
8+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
99

1010
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
1111

@@ -17,7 +17,7 @@ defmodule ExDTLS.IntegrationTest do
1717
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
1818
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)
1919

20-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
20+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
2121

2222
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
2323

@@ -34,7 +34,7 @@ defmodule ExDTLS.IntegrationTest do
3434
assert {:error, :handshake_not_finished} = ExDTLS.write_data(sr_dtls, <<1, 2, 3>>)
3535
assert {:error, :handshake_not_finished} = ExDTLS.write_data(cl_dtls, <<1, 2, 3>>)
3636

37-
{packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
37+
{:ok, packets, _timeout} = ExDTLS.do_handshake(cl_dtls)
3838
assert :ok == loop({sr_dtls, false}, {cl_dtls, false}, packets)
3939

4040
msg = <<1, 3, 2, 5>>
@@ -55,11 +55,53 @@ defmodule ExDTLS.IntegrationTest do
5555

5656
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
5757

58-
{packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
58+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
5959
{:handshake_packets, packets, _timeout} = feed_packets(rx_dtls, packets)
6060
assert {:error, :handshake_error} = feed_packets(tx_dtls, packets)
6161
end
6262

63+
describe "close/1" do
64+
test "before handshake has finished (client mode)" do
65+
dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
66+
assert {:ok, []} = ExDTLS.close(dtls)
67+
# assert that handshake can't be started
68+
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
69+
end
70+
71+
test "before handshake has finished (server mode)" do
72+
dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
73+
assert {:ok, []} = ExDTLS.close(dtls)
74+
# assert that handshake can't be started
75+
assert {:error, :closed} = ExDTLS.do_handshake(dtls)
76+
end
77+
78+
test "after handshake has finished (client mode)" do
79+
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
80+
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
81+
82+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
83+
84+
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
85+
assert {:ok, [packet]} = ExDTLS.close(tx_dtls)
86+
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(rx_dtls, packet)
87+
assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls)
88+
assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls)
89+
end
90+
91+
test "after handshake has finished (server mode)" do
92+
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true, verify_peer: true)
93+
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true, verify_peer: true)
94+
95+
{:ok, packets, _timeout} = ExDTLS.do_handshake(tx_dtls)
96+
97+
assert :ok == loop({rx_dtls, false}, {tx_dtls, false}, packets)
98+
assert {:ok, [packet]} = ExDTLS.close(rx_dtls)
99+
assert {:error, :peer_closed_for_writing} = ExDTLS.handle_data(tx_dtls, packet)
100+
assert {:error, :closed} = ExDTLS.handle_timeout(tx_dtls)
101+
assert {:error, :closed} = ExDTLS.handle_timeout(rx_dtls)
102+
end
103+
end
104+
63105
defp loop({_dtls1, true}, {_dtls2, true}, _packets) do
64106
:ok
65107
end

test/retransmission_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule ExDTLS.RetransmissionTest do
55
rx_dtls = ExDTLS.init(mode: :server, dtls_srtp: true)
66
tx_dtls = ExDTLS.init(mode: :client, dtls_srtp: true)
77

8-
{_packets, timeout} = ExDTLS.do_handshake(tx_dtls)
8+
{:ok, _packets, timeout} = ExDTLS.do_handshake(tx_dtls)
99
Process.send_after(self(), {:handle_timeout, :tx}, timeout)
1010
{:retransmit, packets, timeout} = wait_for_timeout(tx_dtls, :tx)
1111
Process.send_after(self(), {:handle_timeout, :tx}, timeout)

0 commit comments

Comments
 (0)