diff --git a/libvncserver/httpd.c b/libvncserver/httpd.c index e2de6c707..9ae9cf90e 100644 --- a/libvncserver/httpd.c +++ b/libvncserver/httpd.c @@ -355,6 +355,15 @@ httpProcessInput(rfbScreenInfoPtr rfbScreen) rfbScreen->httpSock = RFB_INVALID_SOCKET; return; } +#ifdef LIBVNCSERVER_WITH_WEBSOCKETS + if (strstr(buf, "\r\nUpgrade: websocket\r\n")) { + /* websocket connection */ + rfbLog("httpd: client asked to upgrade to websockets connection\n"); + rfbNewWebSocketsClient(rfbScreen,rfbScreen->httpSock,buf); + rfbScreen->httpSock = RFB_INVALID_SOCKET; + return; + } +#endif } if (strncmp(buf, "GET ", 4)) { diff --git a/libvncserver/private.h b/libvncserver/private.h index d656e3910..96601bb34 100644 --- a/libvncserver/private.h +++ b/libvncserver/private.h @@ -11,6 +11,12 @@ void rfbRedrawAfterHideCursor(rfbClientPtr cl,sraRegionPtr updateRegion); rfbClientPtr rfbClientIteratorHead(rfbClientIteratorPtr i); +#ifdef LIBVNCSERVER_WITH_WEBSOCKETS +/* from websockets.c */ + +rfbBool webSocketsHandshake(rfbClientPtr cl, char *scheme, const char *httpHeaders); +#endif + /* from tight.c */ #ifdef LIBVNCSERVER_HAVE_LIBZ diff --git a/libvncserver/rfbserver.c b/libvncserver/rfbserver.c index 67ffcbcfd..8a697198f 100644 --- a/libvncserver/rfbserver.c +++ b/libvncserver/rfbserver.c @@ -301,7 +301,8 @@ rfbSetProtocolVersion(rfbScreenInfoPtr rfbScreen, int major_, int minor_) static rfbClientPtr rfbNewTCPOrUDPClient(rfbScreenInfoPtr rfbScreen, rfbSocket sock, - rfbBool isUDP) + rfbBool isUDP, + const char *httpHeaders) { rfbProtocolVersionMsg pv; rfbClientIteratorPtr iterator; @@ -470,10 +471,18 @@ rfbNewTCPOrUDPClient(rfbScreenInfoPtr rfbScreen, #endif #ifdef LIBVNCSERVER_WITH_WEBSOCKETS + if (httpHeaders) { + /* Do websocket handshake with HTTP headers previously read */ + if (!webSocketsHandshake(cl, "ws", httpHeaders)) { + rfbCloseClient(cl); + rfbClientConnectionGone(cl); + return NULL; + } + } /* * Wait a few ms for the client to send WebSockets connection (TLS/SSL or plain) */ - if (!webSocketsCheck(cl)) { + else if (!webSocketsCheck(cl)) { /* Error reporting handled in webSocketsHandshake */ rfbCloseClient(cl); rfbClientConnectionGone(cl); @@ -530,14 +539,20 @@ rfbClientPtr rfbNewClient(rfbScreenInfoPtr rfbScreen, rfbSocket sock) { - return(rfbNewTCPOrUDPClient(rfbScreen,sock,FALSE)); + return(rfbNewTCPOrUDPClient(rfbScreen,sock,FALSE,NULL)); } rfbClientPtr rfbNewUDPClient(rfbScreenInfoPtr rfbScreen) { return((rfbScreen->udpClient= - rfbNewTCPOrUDPClient(rfbScreen,rfbScreen->udpSock,TRUE))); + rfbNewTCPOrUDPClient(rfbScreen,rfbScreen->udpSock,TRUE,NULL))); +} + +rfbClientPtr +rfbNewWebSocketsClient(rfbScreenInfoPtr rfbScreen, rfbSocket sock, const char *httpHeaders) +{ + return(rfbNewTCPOrUDPClient(rfbScreen,sock,FALSE,httpHeaders)); } /* diff --git a/libvncserver/websockets.c b/libvncserver/websockets.c index 9fd96a698..8fd26ad92 100644 --- a/libvncserver/websockets.c +++ b/libvncserver/websockets.c @@ -53,6 +53,7 @@ #include "crypto.h" #include "ws_decode.h" #include "base64.h" +#include "private.h" #if 0 #include <sys/syscall.h> @@ -92,8 +93,6 @@ struct timeval ; #endif -static rfbBool webSocketsHandshake(rfbClientPtr cl, char *scheme); - static int webSocketsEncodeHybi(rfbClientPtr cl, const char *src, int len, char **dst); static int ws_read(void *cl, char *buf, size_t len); @@ -157,15 +156,15 @@ webSocketsCheck (rfbClientPtr cl) rfbLog("Got '%s' WebSockets handshake\n", scheme); - if (!webSocketsHandshake(cl, scheme)) { + if (!webSocketsHandshake(cl, scheme, NULL)) { return FALSE; } /* Start WebSockets framing */ return TRUE; } -static rfbBool -webSocketsHandshake(rfbClientPtr cl, char *scheme) +rfbBool +webSocketsHandshake(rfbClientPtr cl, char *scheme, const char *httpHeaders) { char *buf, *response, *line; int n, linestart = 0, len = 0, llen, base64 = FALSE; @@ -189,7 +188,13 @@ webSocketsHandshake(rfbClientPtr cl, char *scheme) } while (len < WEBSOCKETS_MAX_HANDSHAKE_LEN-1) { - if ((n = rfbReadExactTimeout(cl, buf+len, 1, + if (httpHeaders) { + buf[len] = httpHeaders[0]; + if (!httpHeaders[0]) + break; + httpHeaders++; + } + else if ((n = rfbReadExactTimeout(cl, buf+len, 1, WEBSOCKETS_CLIENT_SEND_WAIT_MS)) <= 0) { if ((n < 0) && (errno == ETIMEDOUT)) { break; diff --git a/rfb/rfb.h b/rfb/rfb.h index 4f275fc86..7b6e7accc 100644 --- a/rfb/rfb.h +++ b/rfb/rfb.h @@ -770,6 +770,7 @@ extern rfbBool webSocketCheckDisconnect(rfbClientPtr cl); extern int webSocketsEncode(rfbClientPtr cl, const char *src, int len, char **dst); extern int webSocketsDecode(rfbClientPtr cl, char *dst, int len); extern rfbBool webSocketsHasDataInBuffer(rfbClientPtr cl); +extern rfbClientPtr rfbNewWebSocketsClient(rfbScreenInfoPtr rfbScreen, rfbSocket sock, const char *httpHeaders); #endif /* rfbserver.c */