Skip to content
Open
149 changes: 76 additions & 73 deletions src/websocket_client.erl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
%% @doc Erlang websocket client
-module(websocket_client).

-export([
start_link/3,
-export([start_link/3,
start_link/4,
cast/2,
send/2
Expand All @@ -18,13 +17,15 @@
-type opts() :: [opt()].

%% @doc Start the websocket client
-spec start_link(URL :: string(), Handler :: module(), HandlerArgs :: list()) ->
-spec start_link(URL :: string() | binary(), Handler :: module(), HandlerArgs :: list()) ->
{ok, pid()} | {error, term()}.
start_link(URL, Handler, HandlerArgs) ->
start_link(URL, Handler, HandlerArgs, []).

start_link(URL, Handler, HandlerArgs, AsyncStart) when is_boolean(AsyncStart) ->
start_link(URL, Handler, HandlerArgs, [{async_start, AsyncStart}]);
start_link(URL, Handler, HandlerArgs, Opts) when is_binary(URL) ->
start_link(erlang:binary_to_list(URL), Handler, HandlerArgs, Opts);
start_link(URL, Handler, HandlerArgs, Opts) when is_list(Opts) ->
case http_uri:parse(URL, [{scheme_defaults, [{ws,80},{wss,443}]}]) of
{ok, {Protocol, _, Host, Port, Path, Query}} ->
Expand Down Expand Up @@ -122,12 +123,11 @@ ws_client_init(Handler, Protocol, Host, Port, Path, Args, Opts) ->
%% @doc Send http upgrade request and validate handshake response challenge
-spec websocket_handshake(WSReq :: websocket_req:req(), [{string(), string()}]) -> {ok, binary()} | {error, term()}.
websocket_handshake(WSReq, ExtraHeaders) ->
[Protocol, Path, Host, Key, Transport, Socket] =
websocket_req:get([protocol, path, host, key, transport, socket], WSReq),
[Path, Host, Key, Transport, Socket] =
websocket_req:get([path, host, key, transport, socket], WSReq),
Handshake = ["GET ", Path, " HTTP/1.1\r\n"
"Host: ", Host, "\r\n"
"Connection: Upgrade\r\n"
"Origin: ", atom_to_binary(Protocol, utf8), "://", Host, "\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Key: ", Key, "\r\n"
"Upgrade: websocket\r\n",
Expand All @@ -153,38 +153,35 @@ receive_handshake(Buffer, Transport, Socket) ->
end.

%% @doc Send frame to server
-spec send(websocket_req:frame(), websocket_req:req()) -> ok | {error, term()}.
send(Frame, WSReq) ->
Socket = websocket_req:socket(WSReq),
Transport = websocket_req:transport(WSReq),
Transport:send(Socket, encode_frame(Frame)).


%% @doc Main loop
-spec websocket_loop(WSReq :: websocket_req:req(), HandlerState :: any(),
Buffer :: binary()) ->
ok.
websocket_loop(WSReq, HandlerState, Buffer) ->
receive
Message -> handle_websocket_message(WSReq, HandlerState, Buffer, Message)
Message -> handle_websocket_message(WSReq, HandlerState, Buffer, Message)
end.

handle_websocket_message(WSReq, HandlerState, Buffer, Message) ->
[Handler, Remaining, Socket] =
websocket_req:get([handler, remaining, socket], WSReq),
case Message of
keepalive ->
case websocket_req:get([keepalive_timer], WSReq) of
[undefined] -> ok;
[OldTimer] -> erlang:cancel_timer(OldTimer)
end,
cancel_keepalive_timer(WSReq),
ok = send({ping, <<>>}, WSReq),
KATimer = erlang:send_after(websocket_req:keepalive(WSReq), self(), keepalive),
websocket_loop(websocket_req:set([{keepalive_timer,KATimer}], WSReq), HandlerState, Buffer);
websocket_loop(websocket_req:keepalive_timer(KATimer, WSReq), HandlerState, Buffer);
{cast, Frame} ->
ok = send(Frame, WSReq),
websocket_loop(WSReq, HandlerState, Buffer);
{_Closed, Socket} ->
websocket_close(WSReq, HandlerState, {remote, closed});
websocket_close(WSReq, HandlerState, remote);
{_TransportType, Socket, Data} ->
case Remaining of
undefined ->
Expand All @@ -195,44 +192,56 @@ handle_websocket_message(WSReq, HandlerState, Buffer, Message) ->
websocket_req:opcode(WSReq), Remaining, Data, Buffer)
end;
Msg ->
Handler = websocket_req:handler(WSReq),
try Handler:websocket_info(Msg, WSReq, HandlerState) of
HandlerResponse ->
handle_response(WSReq, HandlerResponse, Buffer)
catch Class:Reason ->
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Last message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_info, 3, Class, Reason, Msg, HandlerState,
erlang:get_stacktrace()]),
websocket_close(WSReq, HandlerState, Reason)
HandlerResponse ->
handle_response(WSReq, HandlerResponse, Buffer)
catch
_:Reason ->
websocket_close(WSReq, HandlerState, {handler, Reason})
end
end.

-spec cancel_keepalive_timer(websocket_req:req()) -> ok.
cancel_keepalive_timer(WSReq) ->
case websocket_req:keepalive_timer(WSReq) of
undefined ->
ok;
OldTimer ->
erlang:cancel_timer(OldTimer),
ok
end.

-spec websocket_close(WSReq :: websocket_req:req(),
HandlerState :: any(),
Reason :: tuple()) -> ok.
websocket_close(WSReq, HandlerState, Reason) ->
Handler = websocket_req:handler(WSReq),
try Handler:websocket_terminate(Reason, WSReq, HandlerState)
catch Class:Reason2 ->
error_logger:error_msg(
"** Websocket handler ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
try Handler:websocket_terminate(Reason, WSReq, HandlerState) of
_ ->
case Reason of
normal -> ok;
_ -> error_info(Handler, Reason, HandlerState)
end,
exit(Reason)
catch
_:Reason2 ->
error_info(Handler, Reason2, HandlerState),
exit(Reason2)
end.

error_info(Handler, Reason, State) ->
error_logger:error_msg(
"** Websocket handler ~p terminating~n"
"** for the reason ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_terminate, 3, Class, Reason2, HandlerState,
erlang:get_stacktrace()])
end.
[Handler, Reason, State, erlang:get_stacktrace()]).

%% @doc Key sent in initial handshake
-spec generate_ws_key() ->
binary().
generate_ws_key() ->
base64:encode(crypto:rand_bytes(16)).
base64:encode(crypto:strong_rand_bytes(16)).

%% @doc Validate handshake response challenge
-spec validate_handshake(HandshakeResponse :: binary(), Key :: binary()) -> {ok, binary()} | {error, term()}.
Expand Down Expand Up @@ -332,18 +341,24 @@ retrieve_frame(WSReq, HandlerState, Opcode, Len, Data, Buffer) ->
end,
case OpcodeName of
close when byte_size(FullPayload) >= 2 ->
<< CodeBin:2/binary, ClosePayload/binary >> = FullPayload,
<< CodeBin:2/binary, _ClosePayload/binary >> = FullPayload,
Code = binary:decode_unsigned(CodeBin),
Reason = case Code of
1000 -> {normal, ClosePayload};
1002 -> {error, badframe, ClosePayload};
1007 -> {error, badencoding, ClosePayload};
1011 -> {error, handler, ClosePayload};
_ -> {remote, Code, ClosePayload}
% 1000 indicates a normal closure, meaning that the purpose for
% which the connection was established has been fulfilled.
1000 -> normal;

% 1001 indicates that an endpoint is "going away", such as a server
% going down or a browser having navigated away from a page.
1001 -> normal;

% See https://tools.ietf.org/html/rfc6455#section-7.4.1
% for error code descriptions.
_ -> {remote, Code}
end,
websocket_close(WSReq, HandlerState, Reason);
close ->
websocket_close(WSReq, HandlerState, {remote, <<>>});
websocket_close(WSReq, HandlerState, remote);
%% Non-control continuation frame
_ when Opcode < 8, Continuation =/= undefined, Fin == 0 ->
%% Append to previously existing continuation payloads and continue
Expand All @@ -359,36 +374,21 @@ retrieve_frame(WSReq, HandlerState, Opcode, Len, Data, Buffer) ->
try Handler:websocket_handle(
{ContinuationOpcodeName, DefragPayload},
WSReq2, HandlerState) of
HandlerResponse ->
handle_response(websocket_req:remaining(undefined, WSReq1),
HandlerResponse, Rest)
catch Class:Reason ->
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Websocket message was ~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_handle, 3, Class, Reason, {ContinuationOpcodeName, DefragPayload}, HandlerState,
erlang:get_stacktrace()]),
websocket_close(WSReq, HandlerState, Reason)
HandlerResponse ->
handle_response(websocket_req:remaining(undefined, WSReq1),
HandlerResponse, Rest)
catch _:Reason ->
websocket_close(WSReq, HandlerState, {handler, Reason})
end;
_ ->
try Handler:websocket_handle(
{OpcodeName, FullPayload},
WSReq, HandlerState) of
HandlerResponse ->
handle_response(websocket_req:remaining(undefined, WSReq),
HandlerResponse, Rest)
catch Class:Reason ->
error_logger:error_msg(
"** Websocket client ~p terminating in ~p/~p~n"
" for the reason ~p:~p~n"
"** Handler state was ~p~n"
"** Stacktrace: ~p~n~n",
[Handler, websocket_handle, 3, Class, Reason, HandlerState,
erlang:get_stacktrace()]),
websocket_close(WSReq, HandlerState, Reason)
HandlerResponse ->
handle_response(websocket_req:remaining(undefined, WSReq),
HandlerResponse, Rest)
catch _:Reason ->
websocket_close(WSReq, HandlerState, {handler, Reason})
end
end.

Expand All @@ -400,11 +400,14 @@ handle_response(WSReq, {reply, Frame, HandlerState}, Buffer) ->
%% we can still have more messages in buffer
case websocket_req:remaining(WSReq) of
%% buffer should not contain uncomplete messages
undefined -> retrieve_frame(WSReq, HandlerState, Buffer);
undefined ->
retrieve_frame(WSReq, HandlerState, Buffer);
%% buffer contain uncomplete message that shouldnt be parsed
_ -> websocket_loop(WSReq, HandlerState, Buffer)
_ ->
websocket_loop(WSReq, HandlerState, Buffer)
end;
Reason -> websocket_close(WSReq, HandlerState, Reason)
{error, Reason} ->
websocket_close(WSReq, HandlerState, {local, Reason})
end;
handle_response(WSReq, {ok, HandlerState}, Buffer) ->
%% we can still have more messages in buffer
Expand All @@ -417,7 +420,7 @@ handle_response(WSReq, {ok, HandlerState}, Buffer) ->

handle_response(WSReq, {close, Payload, HandlerState}, _) ->
send({close, Payload}, WSReq),
websocket_close(WSReq, HandlerState, {normal, Payload}).
websocket_close(WSReq, HandlerState, normal).

%% @doc Encodes the data with a header (including a masking key) and
%% masks the data
Expand All @@ -427,7 +430,7 @@ encode_frame({Type, Payload}) ->
Opcode = websocket_req:name_to_opcode(Type),
Len = iolist_size(Payload),
BinLen = payload_length_to_binary(Len),
MaskingKeyBin = crypto:rand_bytes(4),
MaskingKeyBin = crypto:strong_rand_bytes(4),
<< MaskingKey:32 >> = MaskingKeyBin,
Header = << 1:1, 0:3, Opcode:4, 1:1, BinLen/bits, MaskingKeyBin/bits >>,
MaskedPayload = mask_payload(MaskingKey, Payload),
Expand Down
29 changes: 24 additions & 5 deletions src/websocket_client_handler.erl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,28 @@

-type state() :: any().
-type keepalive() :: integer().
-type close_type() :: normal | error | remote.

-type close_reason() ::
% Either:
% - The websocket was closed by a handler via a `{closed, Reason, State}` tuple
% returned from websocket_handle/3 or websocket_info/3.
% - A 'close' frame was received with code 1000 or 1001.
normal |
% The local end failed to send (see http://erlang.org/doc/man/gen_tcp.html#send-2
% or http://erlang.org/doc/man/ssl.html#send-2). The second element in the
% tuple is the same term that was wrapped in an `{error, Reason}` tuple by
% `send/2`, i.e. `{error, closed}` will become `{local, closed}`, and not
% `{local, {error, closed}}`.
{local, term()} |
% The remote end either closed abruptly, or closed after sending a 'close' frame
% without a status code.
remote |
% The remote end closed with a status code (see https://tools.ietf.org/html/rfc6455#section-7.4.1).
{remote, integer()} |
% An asynchronous exception was raised during message handling, either in
% websocket_handle/3 or websocket_info/3. The term raised is passed as the
% second element in this tuple.
{handler, term()}.

-callback init(list(), websocket_req:req()) ->
{ok, state()}
Expand All @@ -16,8 +37,6 @@
-callback websocket_info(any(), websocket_req:req(), state()) ->
{ok, state()}
| {reply, websocket_req:frame(), state()}
| {close, binary(), state()}.
| {close, binary(), state()}.

-callback websocket_terminate({close_type(), term()} | {close_type(), integer(), binary()},
websocket_req:req(), state()) ->
ok.
-callback websocket_terminate(close_reason(), websocket_req:req(), state()) -> ok.
9 changes: 9 additions & 0 deletions src/websocket_req.erl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
port/2, port/1,
path/2, path/1,
keepalive/2, keepalive/1,
keepalive_timer/2, keepalive_timer/1,
socket/2, socket/1,
transport/2, transport/1,
handler/2, handler/1,
Expand Down Expand Up @@ -134,6 +135,14 @@ keepalive(K, Req) ->
Req#websocket_req{keepalive = K}.


-spec keepalive_timer(req()) -> undefined | reference().
keepalive_timer(#websocket_req{keepalive_timer = K}) -> K.

-spec keepalive_timer(reference(), req()) -> req().
keepalive_timer(K, Req) ->
Req#websocket_req{keepalive_timer = K}.


-spec socket(req()) -> inet:socket() | ssl:sslsocket().
socket(#websocket_req{socket = S}) -> S.

Expand Down