diff --git a/packages/bun-usockets/src/bsd.c b/packages/bun-usockets/src/bsd.c index 0e9c435ffa3dea..5a64df6e8ee34c 100644 --- a/packages/bun-usockets/src/bsd.c +++ b/packages/bun-usockets/src/bsd.c @@ -30,11 +30,12 @@ #include #include #include -#include #include #include #include #include +#else /* _WIN32 */ +#include #endif #if defined(__APPLE__) && defined(__aarch64__) @@ -60,9 +61,10 @@ int bsd_sendmmsg(LIBUS_SOCKET_DESCRIPTOR fd, struct udp_sendbuf* sendbuf, int fl errno = EAFNOSUPPORT; return -1; } + int err = WSAGetLastError(); if (ret < 0) { - if (errno == EINTR) continue; - if (errno == EAGAIN || errno == EWOULDBLOCK) return i; + if (err == WSAEINTR) continue; + if (err == WSAEWOULDBLOCK) return i; return ret; } break; @@ -110,7 +112,7 @@ int bsd_recvmmsg(LIBUS_SOCKET_DESCRIPTOR fd, struct udp_recvbuf *recvbuf, int fl while (1) { ssize_t ret = recvfrom(fd, recvbuf->buf, LIBUS_RECV_BUFFER_LENGTH, flags, (struct sockaddr *)&recvbuf->addr, &addr_len); if (ret < 0) { - if (errno == EINTR) continue; + if (WSAGetLastError() == WSAEINTR) continue; return ret; } recvbuf->recvlen = ret; @@ -576,7 +578,7 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_listen_socket(const char *host, int port, int #include #include -static int bsd_create_unix_socket_address(const char *path, size_t path_len, int* dirfd_linux_workaround_for_unix_path_len, struct sockaddr_un *server_address, size_t* addrlen) { +static LIBUS_SOCKET_DESCRIPTOR bsd_create_unix_socket_address(const char *path, size_t path_len, int* dirfd_linux_workaround_for_unix_path_len, struct sockaddr_un *server_address, size_t* addrlen) { memset(server_address, 0, sizeof(struct sockaddr_un)); server_address->sun_family = AF_UNIX; @@ -842,7 +844,13 @@ int bsd_disconnect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd) { int res = connect(fd, &addr, sizeof(addr)); // EAFNOSUPPORT is harmless in this case - we just want to disconnect - if (res == 0 || errno == EAFNOSUPPORT) { + if (res == 0 || +#ifdef _WIN32 + WSAGetLastError() == WSAEAFNOSUPPORT +#else + errno == EAFNOSUPPORT +#endif + ) { return 0; } else { return -1; @@ -880,21 +888,33 @@ int bsd_disconnect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd) { // return 0; // no ecn defaults to 0 // } -static int bsd_do_connect_raw(struct addrinfo *rp, int fd) +static int bsd_do_connect_raw(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR fd) { +#ifdef _WIN32 + do { + if (connect(fd, rp->ai_addr, rp->ai_addrlen) == 0 || WSAGetLastError() == WSAEINPROGRESS) { + return 0; + } + } while (WSAGetLastError() == WSAEINTR); + + return WSAGetLastError(); +#else do { if (connect(fd, rp->ai_addr, rp->ai_addrlen) == 0 || errno == EINPROGRESS) { return 0; } } while (errno == EINTR); - return LIBUS_SOCKET_ERROR; + return errno; +#endif } -static int bsd_do_connect(struct addrinfo *rp, int *fd) +static int bsd_do_connect(struct addrinfo *rp, LIBUS_SOCKET_DESCRIPTOR *fd) { + int lastErr = 0; while (rp != NULL) { - if (bsd_do_connect_raw(rp, *fd) == 0) { + lastErr = bsd_do_connect_raw(rp, *fd); + if (lastErr == 0) { return 0; } @@ -902,124 +922,116 @@ static int bsd_do_connect(struct addrinfo *rp, int *fd) bsd_close_socket(*fd); if (rp == NULL) { + if (lastErr != 0) { + errno = lastErr; + } return LIBUS_SOCKET_ERROR; } - int resultFd = bsd_create_socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + LIBUS_SOCKET_DESCRIPTOR resultFd = bsd_create_socket(rp->ai_family, SOCK_STREAM, 0); if (resultFd < 0) { return LIBUS_SOCKET_ERROR; } *fd = resultFd; } + if (lastErr != 0) { + errno = lastErr; + } + return LIBUS_SOCKET_ERROR; } -LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options) { #ifdef _WIN32 - // The caller (sometimes) uses NULL to indicate localhost. This works fine with getaddrinfo, but not with WSAConnectByName - if (!host) { - host = "localhost"; - } else if (strcmp(host, "0.0.0.0") == 0 || strcmp(host, "::") == 0 || strcmp(host, "[::]") == 0) { - // windows disallows connecting to 0.0.0.0. To emulate POSIX behavior, we connect to localhost instead - // Also see https://docs.libuv.org/en/v1.x/tcp.html#c.uv_tcp_connect - host = "localhost"; - } - // On windows we use WSAConnectByName to speed up connecting to localhost - // The other implementation also works on windows, but is slower - char port_string[16]; - snprintf(port_string, 16, "%d", port); - SOCKET s = socket(AF_INET6, SOCK_STREAM, 0); - if (s == INVALID_SOCKET) { - return LIBUS_SOCKET_ERROR; - } - // https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsaconnectbynamea#remarks - DWORD zero = 0; - if (SOCKET_ERROR == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&zero, sizeof(DWORD))) { - closesocket(s); - return LIBUS_SOCKET_ERROR; - } - if (source_host) { - struct addrinfo *interface_result; - if (!getaddrinfo(source_host, NULL, NULL, &interface_result)) { - int ret = bind(s, interface_result->ai_addr, (socklen_t) interface_result->ai_addrlen); - freeaddrinfo(interface_result); - if (ret == SOCKET_ERROR) { - closesocket(s); - return LIBUS_SOCKET_ERROR; - } +static int convert_null_addr(struct addrinfo *addrinfo, struct addrinfo* result, struct sockaddr_storage *inaddr) { + // 1. check that all addrinfo results are 0.0.0.0 or :: + if (addrinfo->ai_family == AF_INET) { + struct sockaddr_in *addr = (struct sockaddr_in *) addrinfo->ai_addr; + if (addr->sin_addr.s_addr == htonl(INADDR_ANY)) { + memcpy(inaddr, addr, sizeof(struct sockaddr_in)); + ((struct sockaddr_in *) inaddr)->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + memcpy(result, addrinfo, sizeof(struct addrinfo)); + result->ai_addr = (struct sockaddr *) inaddr; + result->ai_next = NULL; + + return 1; + } + } else if (addrinfo->ai_family == AF_INET6) { + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) addrinfo->ai_addr; + if (memcmp(&addr->sin6_addr, &in6addr_any, sizeof(struct in6_addr)) == 0) { + memcpy(inaddr, addr, sizeof(struct sockaddr_in6)); + memcpy(&((struct sockaddr_in6 *) inaddr)->sin6_addr, &in6addr_loopback, sizeof(struct in6_addr)); + + memcpy(result, addrinfo, sizeof(struct addrinfo)); + result->ai_addr = (struct sockaddr *) inaddr; + result->ai_next = NULL; + + return 1; } } - SOCKADDR_STORAGE local; - SOCKADDR_STORAGE remote; - DWORD local_len = sizeof(local); - DWORD remote_len = sizeof(remote); - if (FALSE == WSAConnectByNameA(s, host, port_string, &local_len, (SOCKADDR*)&local, &remote_len, (SOCKADDR*)&remote, NULL, NULL)) { - closesocket(s); - return LIBUS_SOCKET_ERROR; + return 0; +} + +static int is_loopback(struct addrinfo *addrinfo) { + if (addrinfo->ai_family == AF_INET) { + struct sockaddr_in *addr = (struct sockaddr_in *) addrinfo->ai_addr; + return addr->sin_addr.s_addr == htonl(INADDR_LOOPBACK); + } else if (addrinfo->ai_family == AF_INET6) { + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) addrinfo->ai_addr; + return memcmp(&addr->sin6_addr, &in6addr_loopback, sizeof(struct in6_addr)) == 0; + } else { + return 0; } +} +#endif - // See - // - https://stackoverflow.com/questions/60591081/getpeername-always-fails-with-error-code-wsaenotconn - // - https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsaconnectbynamea#remarks - // - // When the WSAConnectByName function returns TRUE, the socket s is in the default state for a connected socket. - // The socket s does not enable previously set properties or options until SO_UPDATE_CONNECT_CONTEXT is set on the socket. - // Use the setsockopt function to set the SO_UPDATE_CONNECT_CONTEXT option. - // - if (SOCKET_ERROR == setsockopt( s, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0 )) { - closesocket(s); +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int options) { + LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(addrinfo->ai_family, SOCK_STREAM, 0); + if (fd == LIBUS_SOCKET_ERROR) { return LIBUS_SOCKET_ERROR; } - return s; -#else - struct addrinfo hints, *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - char port_string[16]; - snprintf(port_string, 16, "%d", port); +#ifdef _WIN32 - if (getaddrinfo(host, port_string, &hints, &result) != 0) { - return LIBUS_SOCKET_ERROR; + // On windows we can't connect to the null address directly. + // To match POSIX behavior, we need to connect to localhost instead. + struct addrinfo alt_result; + struct sockaddr_storage storage; + + if (convert_null_addr(addrinfo, &alt_result, &storage)) { + addrinfo = &alt_result; } - LIBUS_SOCKET_DESCRIPTOR fd = bsd_create_socket(result->ai_family, result->ai_socktype, result->ai_protocol); - if (fd == LIBUS_SOCKET_ERROR) { - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; + // This sets the socket to fail quickly if no connection can be established to localhost, + // instead of waiting for the default 2 seconds. This is necessary because we always try to connect + // using IPv6 first, but it's possible that whatever we want to connect to is only listening on IPv4. + // see https://github.com/libuv/libuv/blob/bf61390769068de603e6deec8e16623efcbe761a/src/win/tcp.c#L806 + TCP_INITIAL_RTO_PARAMETERS retransmit_ioctl; + DWORD bytes; + if (is_loopback(addrinfo)) { + memset(&retransmit_ioctl, 0, sizeof(retransmit_ioctl)); + retransmit_ioctl.Rtt = TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS; + retransmit_ioctl.MaxSynRetransmissions = TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS; + WSAIoctl(fd, + SIO_TCP_INITIAL_RTO, + &retransmit_ioctl, + sizeof(retransmit_ioctl), + NULL, + 0, + &bytes, + NULL, + NULL); } - if (source_host) { - struct addrinfo *interface_result; - if (!getaddrinfo(source_host, NULL, NULL, &interface_result)) { - int ret = bind(fd, interface_result->ai_addr, (socklen_t) interface_result->ai_addrlen); - freeaddrinfo(interface_result); - if (ret == LIBUS_SOCKET_ERROR) { - bsd_close_socket(fd); - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } - } +#endif - if (bsd_do_connect_raw(result, fd) != 0) { - bsd_close_socket(fd); - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } - } else { - if (bsd_do_connect(result, &fd) != 0) { - freeaddrinfo(result); - return LIBUS_SOCKET_ERROR; - } + if (bsd_do_connect(addrinfo, &fd) != 0) { + return LIBUS_SOCKET_ERROR; } - - freeaddrinfo(result); return fd; -#endif } static LIBUS_SOCKET_DESCRIPTOR internal_bsd_create_connect_socket_unix(const char *server_path, size_t len, int options, struct sockaddr_un* server_address, const size_t addrlen) { diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index a918879bf39e7b..7423305ef54009 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -19,6 +19,7 @@ #include "internal/internal.h" #include #include +#include int default_is_low_prio_handler(struct us_socket_t *s) { return 0; @@ -330,7 +331,7 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock us_poll_start(p, context->loop, LIBUS_SOCKET_READABLE); struct us_listen_socket_t *ls = (struct us_listen_socket_t *) p; - + ls->s.connect_state = NULL; ls->s.context = context; ls->s.timeout = 255; ls->s.long_timeout = 255; @@ -343,33 +344,83 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock return ls; } -struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, const char *source_host, int options, int socket_ext_size) { +struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) { #ifndef LIBUS_NO_SSL if (ssl) { - return (struct us_socket_t *) us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, source_host, options, socket_ext_size); + return us_internal_ssl_socket_context_connect((struct us_internal_ssl_socket_context_t *) context, host, port, options, socket_ext_size); } #endif - LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(host, port, source_host, options); + struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size); + c->socket_ext_size = socket_ext_size; + c->context = context; + c->options = options; + c->ssl = ssl > 0; + c->timeout = 255; + c->long_timeout = 255; + c->pending_resolve_callback = 1; + + Bun__addrinfo_get(host, port, c); + +#ifdef _WIN32 + context->loop->uv_loop->active_handles++; +#else + context->loop->num_polls++; +#endif + + return c; +} + +void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) { + // make sure to decrement the active_handles counter, no matter what +#ifdef _WIN32 + c->context->loop->uv_loop->active_handles--; +#else + c->context->loop->num_polls--; +#endif + + c->pending_resolve_callback = 0; + // if the socket was closed while we were resolving the address, free it + if (c->closed) { + us_connecting_socket_free(c); + return; + } + struct addrinfo_result *result = Bun__addrinfo_getRequestResult(c->addrinfo_req); + if (result->error) { + c->error = result->error; + c->context->on_connect_error(c, result->error); + Bun__addrinfo_freeRequest(c->addrinfo_req, 0); + us_connecting_socket_close(0, c); + return; + } + LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(result->info, c->options); if (connect_socket_fd == LIBUS_SOCKET_ERROR) { - return 0; + c->error = errno; + c->context->on_connect_error(c, errno); + Bun__addrinfo_freeRequest(c->addrinfo_req, 1); + us_connecting_socket_close(0, c); + return; } - /* Connect sockets are semi-sockets just like listen sockets */ - struct us_poll_t *p = us_create_poll(context->loop, 0, sizeof(struct us_socket_t) + socket_ext_size); - us_poll_init(p, connect_socket_fd, POLL_TYPE_SEMI_SOCKET); - us_poll_start(p, context->loop, LIBUS_SOCKET_WRITABLE); - - struct us_socket_t *connect_socket = (struct us_socket_t *) p; + Bun__addrinfo_freeRequest(c->addrinfo_req, 0); + struct us_socket_t *s = (struct us_socket_t *)us_create_poll(c->context->loop, 0, sizeof(struct us_socket_t) + c->socket_ext_size); + s->context = c->context; + s->timeout = c->timeout; + s->long_timeout = c->long_timeout; + s->low_prio_state = 0; /* Link it into context so that timeout fires properly */ - connect_socket->context = context; - connect_socket->timeout = 255; - connect_socket->long_timeout = 255; - connect_socket->low_prio_state = 0; - us_internal_socket_context_link_socket(context, connect_socket); + us_internal_socket_context_link_socket(s->context, s); + // TODO check this, specifically how it interacts with the SSL code + memcpy(us_socket_ext(0, s), us_connecting_socket_ext(0, c), c->socket_ext_size); - return connect_socket; + /* Connect sockets are semi-sockets just like listen sockets */ + us_poll_init(&s->p, connect_socket_fd, POLL_TYPE_SEMI_SOCKET); + us_poll_start(&s->p, s->context->loop, LIBUS_SOCKET_WRITABLE); + + // store the socket so we can close it if we need to + c->socket = s; + s->connect_state = c; } struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_context_t *context, const char *server_path, size_t pathlen, int options, int socket_ext_size) { @@ -396,6 +447,7 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_con connect_socket->timeout = 255; connect_socket->long_timeout = 255; connect_socket->low_prio_state = 0; + connect_socket->connect_state = NULL; us_internal_socket_context_link_socket(context, connect_socket); return connect_socket; @@ -431,7 +483,17 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con us_internal_socket_context_unlink_socket(s->context, s); } - struct us_socket_t *new_s = (struct us_socket_t *) us_poll_resize(&s->p, s->context->loop, sizeof(struct us_socket_t) + ext_size); + + struct us_connecting_socket_t *c = s->connect_state; + + struct us_socket_t *new_s = s; + if (ext_size != -1) { + new_s = (struct us_socket_t *) us_poll_resize(&s->p, s->context->loop, sizeof(struct us_socket_t) + ext_size); + if (c) { + c->socket = new_s; + c->context = context; + } + } new_s->timeout = 255; new_s->long_timeout = 255; @@ -526,7 +588,7 @@ void us_socket_context_on_end(int ssl, struct us_socket_context_t *context, stru context->on_end = on_end; } -void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)) { +void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, struct us_connecting_socket_t *(*on_connect_error)(struct us_connecting_socket_t *s, int code)) { #ifndef LIBUS_NO_SSL if (ssl) { us_internal_ssl_socket_context_on_connect_error((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t * (*)(struct us_internal_ssl_socket_t *, int)) on_connect_error); diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index 3be9b950ad27a7..f88048f3cf8c3a 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -107,7 +107,7 @@ enum { struct us_internal_ssl_socket_t { struct us_socket_t s; - SSL *ssl; + SSL *ssl; // this _must_ be the first member after s #if ALLOW_SERVER_RENEGOTIATION unsigned int client_pending_renegotiations; uint64_t last_ssl_renegotiation; @@ -1515,11 +1515,12 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen_unix( socket_ext_size); } -struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect( +// TODO does this need more changes? +struct us_connecting_socket_t *us_internal_ssl_socket_context_connect( struct us_internal_ssl_socket_context_t *context, const char *host, - int port, const char *source_host, int options, int socket_ext_size) { - return (struct us_internal_ssl_socket_t *)us_socket_context_connect( - 0, &context->sc, host, port, source_host, options, + int port, int options, int socket_ext_size) { + return us_socket_context_connect( + 0, &context->sc, host, port, options, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); } @@ -1612,7 +1613,7 @@ void us_internal_ssl_socket_context_on_connect_error( struct us_internal_ssl_socket_t *, int code)) { us_socket_context_on_connect_error( 0, (struct us_socket_context_t *)context, - (struct us_socket_t * (*)(struct us_socket_t *, int)) on_connect_error); + (struct us_connecting_socket_t * (*)(struct us_connecting_socket_t *, int)) on_connect_error); } void *us_internal_ssl_socket_context_ext( @@ -1691,6 +1692,10 @@ void *us_internal_ssl_socket_ext(struct us_internal_ssl_socket_t *s) { return s + 1; } +void *us_internal_connecting_ssl_socket_ext(struct us_connecting_socket_t *s) { + return (char*)(s + 1) + sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t); +} + int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s) { return us_socket_is_shut_down(0, &s->s) || SSL_get_shutdown(s->ssl) & SSL_SENT_SHUTDOWN; @@ -1743,10 +1748,13 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket( struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *s, int ext_size) { // todo: this is completely untested + int new_ext_size = ext_size; + if (ext_size != -1) { + new_ext_size = sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + ext_size; + } return (struct us_internal_ssl_socket_t *)us_socket_context_adopt_socket( 0, &context->sc, &s->s, - sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + - ext_size); + new_ext_size); } struct us_internal_ssl_socket_t * @@ -1873,11 +1881,11 @@ ssl_wrapped_on_connect_error(struct us_internal_ssl_socket_t *s, int code) { context); if (wrapped_context->events.on_connect_error) { - wrapped_context->events.on_connect_error((struct us_socket_t *)s, code); + wrapped_context->events.on_connect_error((struct us_connecting_socket_t *)s, code); } if (wrapped_context->old_events.on_connect_error) { - wrapped_context->old_events.on_connect_error((struct us_socket_t *)s, code); + wrapped_context->old_events.on_connect_error((struct us_connecting_socket_t *)s, code); } return s; } @@ -1949,7 +1957,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( // as well us_socket_context_on_connect_error( 0, context, - (struct us_socket_t * (*)(struct us_socket_t *, int)) + (struct us_connecting_socket_t * (*)(struct us_connecting_socket_t *, int)) ssl_wrapped_on_connect_error); us_socket_context_on_end(0, context, (struct us_socket_t * (*)(struct us_socket_t *)) diff --git a/packages/bun-usockets/src/eventing/epoll_kqueue.c b/packages/bun-usockets/src/eventing/epoll_kqueue.c index 3e60c904522247..67ab45e40eaf2f 100644 --- a/packages/bun-usockets/src/eventing/epoll_kqueue.c +++ b/packages/bun-usockets/src/eventing/epoll_kqueue.c @@ -227,12 +227,20 @@ void us_loop_run_bun_tick(struct us_loop_t *loop, int64_t timeoutMs) { int events = loop->ready_polls[loop->current_ready_poll].events; int error = loop->ready_polls[loop->current_ready_poll].events & (EPOLLERR | EPOLLHUP); #else + + struct kevent64_s *kev = &loop->ready_polls[loop->current_ready_poll]; /* EVFILT_READ, EVFILT_TIME, EVFILT_USER are all mapped to LIBUS_SOCKET_READABLE */ int events = LIBUS_SOCKET_READABLE; - if (loop->ready_polls[loop->current_ready_poll].filter == EVFILT_WRITE) { + if (kev->filter == EVFILT_WRITE) { events = LIBUS_SOCKET_WRITABLE; } - int error = loop->ready_polls[loop->current_ready_poll].flags & (EV_ERROR | EV_EOF); + // see man 2 kqueue + int error = 0; + if (kev->flags & EV_ERROR) { + error = (int)kev->data; + } else if (kev->flags & EV_EOF) { + error = kev->fflags; + } #endif /* Always filter all polls by what they actually poll for (callback polls always poll for readable) */ events &= us_poll_events(poll); diff --git a/packages/bun-usockets/src/eventing/libuv.c b/packages/bun-usockets/src/eventing/libuv.c index e5d10c4a41d70d..a37fd115a6b8fa 100644 --- a/packages/bun-usockets/src/eventing/libuv.c +++ b/packages/bun-usockets/src/eventing/libuv.c @@ -271,6 +271,12 @@ void us_timer_set(struct us_timer_t *t, void (*cb)(struct us_timer_t *t), struct us_internal_callback_t *internal_cb = (struct us_internal_callback_t *)t; + // only add the timer to the event loop once + if (internal_cb->has_added_timer_to_event_loop) { + return; + } + internal_cb->has_added_timer_to_event_loop = 1; + internal_cb->cb = (void (*)(struct us_internal_callback_t *))cb; uv_timer_t *uv_timer = (uv_timer_t *)(internal_cb + 1); diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index 29224ce0148926..0fc37249d8cb6f 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -71,6 +71,20 @@ enum { #define POLL_TYPE_POLLING_MASK 0b11000 #define POLL_TYPE_MASK (POLL_TYPE_KIND_MASK | POLL_TYPE_POLLING_MASK) +/* Bun APIs implemented in Zig */ +void Bun__lock(uint32_t *lock); +void Bun__unlock(uint32_t *lock); + +struct addrinfo_result { + struct addrinfo *info; + int error; +}; + +extern void Bun__addrinfo_get(const char* host, int port, struct us_connecting_socket_t *s); +extern void Bun__addrinfo_freeRequest(void* addrinfo_req, int error); +extern struct addrinfo_result *Bun__addrinfo_getRequestResult(void* addrinfo_req); + + /* Loop related */ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events); @@ -112,6 +126,9 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context, void us_internal_socket_context_unlink_socket( struct us_socket_context_t *context, struct us_socket_t *s); +void us_internal_socket_after_resolve(struct us_connecting_socket_t *s); +int us_internal_handle_dns_results(struct us_loop_t *loop); + /* Sockets are polls */ struct us_socket_t { alignas(LIBUS_EXT_ALIGNMENT) struct us_poll_t p; // 4 bytes @@ -122,6 +139,20 @@ struct us_socket_t { = was in low-prio queue in this iteration */ struct us_socket_context_t *context; struct us_socket_t *prev, *next; + struct us_connecting_socket_t *connect_state; +}; + +struct us_connecting_socket_t { + alignas(LIBUS_EXT_ALIGNMENT) void *addrinfo_req; + struct us_socket_context_t *context; + struct us_connecting_socket_t *next; + struct us_socket_t *socket; + int options; + int socket_ext_size; + unsigned int closed : 1, shutdown : 1, ssl : 1, shutdown_read : 1, pending_resolve_callback : 1; + unsigned char timeout; + unsigned char long_timeout; + int error; }; struct us_wrapped_socket_context_t { @@ -210,7 +241,7 @@ struct us_socket_context_t { struct us_socket_t *(*on_socket_timeout)(struct us_socket_t *); struct us_socket_t *(*on_socket_long_timeout)(struct us_socket_t *); struct us_socket_t *(*on_end)(struct us_socket_t *); - struct us_socket_t *(*on_connect_error)(struct us_socket_t *, int code); + struct us_connecting_socket_t *(*on_connect_error)(struct us_connecting_socket_t *, int code); int (*is_low_prio)(struct us_socket_t *); }; @@ -317,9 +348,9 @@ struct us_listen_socket_t *us_internal_ssl_socket_context_listen_unix( struct us_internal_ssl_socket_context_t *context, const char *path, size_t pathlen, int options, int socket_ext_size); -struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect( +struct us_connecting_socket_t *us_internal_ssl_socket_context_connect( struct us_internal_ssl_socket_context_t *context, const char *host, - int port, const char *source_host, int options, int socket_ext_size); + int port, int options, int socket_ext_size); struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect_unix( struct us_internal_ssl_socket_context_t *context, const char *server_path, @@ -338,6 +369,7 @@ us_internal_ssl_socket_context_ext(struct us_internal_ssl_socket_context_t *s); struct us_internal_ssl_socket_context_t * us_internal_ssl_socket_get_context(struct us_internal_ssl_socket_t *s); void *us_internal_ssl_socket_ext(struct us_internal_ssl_socket_t *s); +void *us_internal_connecting_ssl_socket_ext(struct us_connecting_socket_t *c); int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s); void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s); diff --git a/packages/bun-usockets/src/internal/loop_data.h b/packages/bun-usockets/src/internal/loop_data.h index 7cdcd1979dfe09..0ba5d409c47fe0 100644 --- a/packages/bun-usockets/src/internal/loop_data.h +++ b/packages/bun-usockets/src/internal/loop_data.h @@ -18,6 +18,9 @@ #ifndef LOOP_DATA_H #define LOOP_DATA_H +#include + +// IMPORTANT: When changing this, don't forget to update the zig version in uws.zig as well! struct us_internal_loop_data_t { struct us_timer_t *sweep_timer; struct us_internal_async *wakeup_async; @@ -33,6 +36,11 @@ struct us_internal_loop_data_t { struct us_socket_t *closed_head; struct us_socket_t *low_prio_head; int low_prio_budget; + struct us_connecting_socket_t *dns_ready_head; + struct us_connecting_socket_t *closed_connecting_head; + uint32_t mutex; + void *parent_ptr; + char parent_tag; /* We do not care if this flips or not, it doesn't matter */ long long iteration_nr; }; diff --git a/packages/bun-usockets/src/internal/networking/bsd.h b/packages/bun-usockets/src/internal/networking/bsd.h index 91d62700fdbd8f..ba968380d49dbb 100644 --- a/packages/bun-usockets/src/internal/networking/bsd.h +++ b/packages/bun-usockets/src/internal/networking/bsd.h @@ -40,6 +40,7 @@ #endif /* For socklen_t */ #include +#include #define SETSOCKOPT_PTR_TYPE int * #define LIBUS_SOCKET_ERROR -1 #endif @@ -149,7 +150,7 @@ LIBUS_SOCKET_DESCRIPTOR bsd_create_udp_socket(const char *host, int port); int bsd_connect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd, const char *host, int port); int bsd_disconnect_udp_socket(LIBUS_SOCKET_DESCRIPTOR fd); -LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(const char *host, int port, const char *source_host, int options); +LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket(struct addrinfo *addrinfo, int options); LIBUS_SOCKET_DESCRIPTOR bsd_create_connect_socket_unix(const char *server_path, size_t pathlen, int options); diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index a082766cdcf154..39176eae0e55ed 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -76,6 +76,7 @@ enum { /* Library types publicly available */ struct us_socket_t; +struct us_connecting_socket_t; struct us_timer_t; struct us_socket_context_t; struct us_loop_t; @@ -179,7 +180,7 @@ struct us_socket_events_t { struct us_socket_t *(*on_timeout)(struct us_socket_t *); struct us_socket_t *(*on_long_timeout)(struct us_socket_t *); struct us_socket_t *(*on_end)(struct us_socket_t *); - struct us_socket_t *(*on_connect_error)(struct us_socket_t *, int code); + struct us_connecting_socket_t *(*on_connect_error)(struct us_connecting_socket_t *, int code); void (*on_handshake)(struct us_socket_t*, int success, struct us_bun_verify_error_t verify_error, void* custom_data); }; @@ -243,7 +244,7 @@ void us_socket_context_on_long_timeout(int ssl, struct us_socket_context_t *cont struct us_socket_t *(*on_timeout)(struct us_socket_t *s)); /* This one is only used for when a connecting socket fails in a late stage. */ void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *context, - struct us_socket_t *(*on_connect_error)(struct us_socket_t *s, int code)); + struct us_connecting_socket_t *(*on_connect_error)(struct us_connecting_socket_t *s, int code)); void us_socket_context_on_handshake(int ssl, struct us_socket_context_t *context, void (*on_handshake)(struct us_socket_t *, int success, struct us_bun_verify_error_t verify_error, void* custom_data), void* custom_data); @@ -267,8 +268,8 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls); /* Land in on_open or on_connection_error or return null or return socket */ -struct us_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, - const char *host, int port, const char *source_host, int options, int socket_ext_size); +struct us_connecting_socket_t *us_socket_context_connect(int ssl, struct us_socket_context_t *context, + const char *host, int port, int options, int socket_ext_size); struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_context_t *context, const char *server_path, size_t pathlen, int options, int socket_ext_size); @@ -277,10 +278,12 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_con * Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */ int us_socket_is_established(int ssl, struct us_socket_t *s); +void us_connecting_socket_free(struct us_connecting_socket_t *c); + /* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times. * Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since * you never got the on_open event first. */ -struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s); +void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c); /* Returns the loop for this socket context. */ struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context); @@ -370,6 +373,7 @@ void us_socket_long_timeout(int ssl, struct us_socket_t *s, unsigned int minutes /* Return the user data extension of this socket */ void *us_socket_ext(int ssl, struct us_socket_t *s); +void *us_connecting_socket_ext(int ssl, struct us_connecting_socket_t *c); /* Return the socket context of this socket */ struct us_socket_context_t *us_socket_context(int ssl, struct us_socket_t *s); diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index 780c4a5f782c4b..51111e2dc394ba 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -40,6 +40,13 @@ void us_internal_loop_data_init(struct us_loop_t *loop, void (*wakeup_cb)(struct loop->data.post_cb = post_cb; loop->data.iteration_nr = 0; + loop->data.closed_connecting_head = 0; + loop->data.dns_ready_head = 0; + loop->data.mutex = 0; + + loop->data.parent_ptr = 0; + loop->data.parent_tag = 0; + loop->data.wakeup_async = us_internal_create_async(loop, 1, 0); us_internal_async_set(loop->data.wakeup_async, (void (*)(struct us_internal_async *)) wakeup_cb); } @@ -165,25 +172,63 @@ void us_internal_handle_low_priority_sockets(struct us_loop_t *loop) { } } +// Called when DNS resolution completes +// Does not wake up the loop. +void us_internal_dns_callback(struct us_connecting_socket_t *c, void* addrinfo_req) { + struct us_loop_t *loop = c->context->loop; + Bun__lock(&loop->data.mutex); + c->addrinfo_req = addrinfo_req; + c->next = loop->data.dns_ready_head; + loop->data.dns_ready_head = c; + Bun__unlock(&loop->data.mutex); +} + +// Called when DNS resolution completes +// Wakes up the loop. +// Can be caleld from any thread. +void us_internal_dns_callback_threadsafe(struct us_connecting_socket_t *c, void* addrinfo_req) { + struct us_loop_t *loop = c->context->loop; + us_internal_dns_callback(c, addrinfo_req); + us_wakeup_loop(loop); +} + +void us_internal_drain_pending_dns_resolve(struct us_loop_t *loop, struct us_connecting_socket_t *s) { + while (s) { + struct us_connecting_socket_t *next = s->next; + us_internal_socket_after_resolve(s); + s = next; + } +} + +int us_internal_handle_dns_results(struct us_loop_t *loop) { + struct us_connecting_socket_t *s = __atomic_exchange_n(&loop->data.dns_ready_head, NULL, __ATOMIC_ACQ_REL); + us_internal_drain_pending_dns_resolve(loop, s); + return s != NULL; +} + /* Note: Properly takes the linked list and timeout sweep into account */ void us_internal_free_closed_sockets(struct us_loop_t *loop) { /* Free all closed sockets (maybe it is better to reverse order?) */ - if (loop->data.closed_head) { - for (struct us_socket_t *s = loop->data.closed_head; s; ) { - struct us_socket_t *next = s->next; - us_poll_free((struct us_poll_t *) s, loop); - s = next; - } - loop->data.closed_head = 0; + for (struct us_socket_t *s = loop->data.closed_head; s; ) { + struct us_socket_t *next = s->next; + us_poll_free((struct us_poll_t *) s, loop); + s = next; } - if (loop->data.closed_udp_head) { - for (struct us_udp_socket_t *s = loop->data.closed_udp_head; s; ) { - struct us_udp_socket_t *next = s->next; - us_poll_free((struct us_poll_t *) s, loop); - s = next; - } - loop->data.closed_udp_head = 0; + loop->data.closed_head = 0; + + for (struct us_udp_socket_t *s = loop->data.closed_udp_head; s; ) { + struct us_udp_socket_t *next = s->next; + us_poll_free((struct us_poll_t *) s, loop); + s = next; } + loop->data.closed_udp_head = 0; + + for (struct us_connecting_socket_t *s = loop->data.closed_connecting_head; s; ) { + struct us_connecting_socket_t *next = s->next; + us_free(s); + s = next; + } + loop->data.closed_connecting_head = 0; } void sweep_timer_cb(struct us_internal_callback_t *cb) { @@ -197,11 +242,13 @@ long long us_loop_iteration_number(struct us_loop_t *loop) { /* These may have somewhat different meaning depending on the underlying event library */ void us_internal_loop_pre(struct us_loop_t *loop) { loop->data.iteration_nr++; + us_internal_handle_dns_results(loop); us_internal_handle_low_priority_sockets(loop); loop->data.pre_cb(loop); } void us_internal_loop_post(struct us_loop_t *loop) { + us_internal_handle_dns_results(loop); us_internal_free_closed_sockets(loop); loop->data.post_cb(loop); } @@ -235,8 +282,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* It is perfectly possible to come here with an error */ if (error) { /* Emit error, close without emitting on_close */ - s->context->on_connect_error(s, 0); - us_socket_close_connecting(0, s); + s->context->on_connect_error(s->connect_state, error); + us_connecting_socket_close(0, s->connect_state); s = NULL; } else { /* All sockets poll for readable */ @@ -252,6 +299,12 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) us_socket_timeout(0, s, 0); s->context->on_open(s, 1, 0, 0); + + if (s->connect_state) { + // now that the socket is open, we can release the associated us_connecting_socket_t if it exists + us_connecting_socket_free(s->connect_state); + s->connect_state = NULL; + } } } else { struct us_listen_socket_t *listen_socket = (struct us_listen_socket_t *) p; @@ -273,6 +326,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) struct us_socket_t *s = (struct us_socket_t *) accepted_p; s->context = listen_socket->s.context; + s->connect_state = NULL; s->timeout = 255; s->long_timeout = 255; s->low_prio_state = 0; @@ -401,7 +455,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* Such as epollerr epollhup */ if (error && s) { /* Todo: decide what code we give here */ - s = us_socket_close(0, s, 0, NULL); + s = us_socket_close(0, s, error, NULL); return; } break; diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index fb51e63269df78..35c88b6f018869 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -42,6 +42,10 @@ void us_socket_shutdown_read(int ssl, struct us_socket_t *s) { bsd_shutdown_socket_read(us_poll_fd((struct us_poll_t *) s)); } +void us_connecting_socket_shutdown_read(int ssl, struct us_connecting_socket_t *c) { + c->shutdown_read = 1; +} + void us_socket_remote_address(int ssl, struct us_socket_t *s, char *buf, int *length) { struct bsd_addr_t addr; if (bsd_remote_addr(us_poll_fd(&s->p), &addr) || *length < bsd_addr_get_ip_length(&addr)) { @@ -66,6 +70,10 @@ struct us_socket_context_t *us_socket_context(int ssl, struct us_socket_t *s) { return s->context; } +struct us_socket_context_t *us_connecting_socket_context(int ssl, struct us_connecting_socket_t *c) { + return c->context; +} + void us_socket_timeout(int ssl, struct us_socket_t *s, unsigned int seconds) { if (seconds) { s->timeout = ((unsigned int)s->context->timestamp + ((seconds + 3) >> 2)) % 240; @@ -74,6 +82,14 @@ void us_socket_timeout(int ssl, struct us_socket_t *s, unsigned int seconds) { } } +void us_connecting_socket_timeout(int ssl, struct us_connecting_socket_t *c, unsigned int seconds) { + if (seconds) { + c->timeout = ((unsigned int)c->context->timestamp + ((seconds + 3) >> 2)) % 240; + } else { + c->timeout = 255; + } +} + void us_socket_long_timeout(int ssl, struct us_socket_t *s, unsigned int minutes) { if (minutes) { s->long_timeout = ((unsigned int)s->context->long_timestamp + minutes) % 240; @@ -82,6 +98,14 @@ void us_socket_long_timeout(int ssl, struct us_socket_t *s, unsigned int minutes } } +void us_connecting_socket_long_timeout(int ssl, struct us_connecting_socket_t *c, unsigned int minutes) { + if (minutes) { + c->long_timeout = ((unsigned int)c->context->long_timestamp + minutes) % 240; + } else { + c->long_timeout = 255; + } +} + void us_socket_flush(int ssl, struct us_socket_t *s) { if (!us_socket_is_shut_down(0, s)) { bsd_socket_flush(us_poll_fd((struct us_poll_t *) s)); @@ -92,14 +116,30 @@ int us_socket_is_closed(int ssl, struct us_socket_t *s) { return s->prev == (struct us_socket_t *) s->context; } +int us_connecting_socket_is_closed(int ssl, struct us_connecting_socket_t *c) { + return c->closed; +} + int us_socket_is_established(int ssl, struct us_socket_t *s) { /* Everything that is not POLL_TYPE_SEMI_SOCKET is established */ return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET; } -/* Exactly the same as us_socket_close but does not emit on_close event */ -struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s) { - if (!us_socket_is_closed(0, s)) { +void us_connecting_socket_free(struct us_connecting_socket_t *c) { + // we can't just free c immediately, as it may be enqueued in the dns_ready_head list + // instead, we move it to a close list and free it after the iteration + c->next = c->context->loop->data.closed_connecting_head; + c->context->loop->data.closed_connecting_head = c; +} + +void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { + if (c->closed) return; + c->closed = 1; + + struct us_socket_t *s = c->socket; + if (s) { + c->socket = NULL; + us_internal_socket_context_unlink_socket(s->context, s); us_poll_stop((struct us_poll_t *) s, s->context->loop); bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); @@ -110,13 +150,15 @@ struct us_socket_t *us_socket_close_connecting(int ssl, struct us_socket_t *s) { /* Any socket with prev = context is marked as closed */ s->prev = (struct us_socket_t *) s->context; + } - //return s->context->on_close(s, code, reason); + // we can only schedule the socket to be freed if there is no pending callback + // otherwise, the callback will see that the socket is closed and will free it + if (!c->pending_resolve_callback) { + us_connecting_socket_free(c); } - return s; -} +} -/* Same as above but emits on_close */ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) { if (!us_socket_is_closed(0, s)) { if (s->low_prio_state == 1) { @@ -221,7 +263,6 @@ struct us_socket_t *us_socket_pair(struct us_socket_context_t *ctx, int socket_e /* This is not available for SSL sockets as it makes no sense. */ int us_socket_write2(int ssl, struct us_socket_t *s, const char *header, int header_length, const char *payload, int payload_length) { - if (us_socket_is_closed(ssl, s) || us_socket_is_shut_down(ssl, s)) { return 0; } @@ -272,17 +313,25 @@ void *us_socket_get_native_handle(int ssl, struct us_socket_t *s) { return us_internal_ssl_socket_get_native_handle((struct us_internal_ssl_socket_t *) s); } #endif - return (void *) (uintptr_t) us_poll_fd((struct us_poll_t *) s); } +void *us_connecting_socket_get_native_handle(int ssl, struct us_connecting_socket_t *c) { +#ifndef LIBUS_NO_SSL + // returns the ssl context + if (ssl) { + return *(void **)(c + 1); + } +#endif + return (void *) (uintptr_t) -1; +} + int us_socket_write(int ssl, struct us_socket_t *s, const char *data, int length, int msg_more) { #ifndef LIBUS_NO_SSL if (ssl) { return us_internal_ssl_socket_write((struct us_internal_ssl_socket_t *) s, data, length, msg_more); } #endif - if (us_socket_is_closed(ssl, s) || us_socket_is_shut_down(ssl, s)) { return 0; } @@ -306,16 +355,29 @@ void *us_socket_ext(int ssl, struct us_socket_t *s) { return s + 1; } +void *us_connecting_socket_ext(int ssl, struct us_connecting_socket_t *c) { +#ifndef LIBUS_NO_SSL + if (ssl) { + return us_internal_connecting_ssl_socket_ext(c); + } +#endif + + return c + 1; +} + int us_socket_is_shut_down(int ssl, struct us_socket_t *s) { #ifndef LIBUS_NO_SSL if (ssl) { return us_internal_ssl_socket_is_shut_down((struct us_internal_ssl_socket_t *) s); } #endif - return us_internal_poll_type(&s->p) == POLL_TYPE_SOCKET_SHUT_DOWN; } +int us_connecting_socket_is_shut_down(int ssl, struct us_connecting_socket_t *c) { + return c->shutdown; +} + void us_socket_shutdown(int ssl, struct us_socket_t *s) { #ifndef LIBUS_NO_SSL if (ssl) { @@ -323,7 +385,6 @@ void us_socket_shutdown(int ssl, struct us_socket_t *s) { return; } #endif - /* Todo: should we emit on_close if calling shutdown on an already half-closed socket? * We need more states in that case, we need to track RECEIVED_FIN * so far, the app has to track this and call close as needed */ @@ -334,6 +395,14 @@ void us_socket_shutdown(int ssl, struct us_socket_t *s) { } } +void us_connecting_socket_shutdown(int ssl, struct us_connecting_socket_t *c) { + c->shutdown = 1; +} + +int us_connecting_socket_get_error(int ssl, struct us_connecting_socket_t *c) { + return c->error; +} + /* Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext @@ -400,4 +469,8 @@ void us_socket_unref(struct us_socket_t *s) { uv_unref((uv_handle_t*)s->p.uv_p); #endif // do nothing if not using libuv -} \ No newline at end of file +} + +struct us_loop_t *us_connecting_socket_get_loop(struct us_connecting_socket_t *c) { + return c->context->loop; +} diff --git a/src/async/posix_event_loop.zig b/src/async/posix_event_loop.zig index 2d31c2f5fdc609..f7ca03133773db 100644 --- a/src/async/posix_event_loop.zig +++ b/src/async/posix_event_loop.zig @@ -163,6 +163,7 @@ pub const FilePoll = struct { const FileSink = JSC.WebCore.FileSink.Poll; const DNSResolver = JSC.DNS.DNSResolver; const GetAddrInfoRequest = JSC.DNS.GetAddrInfoRequest; + const Request = JSC.DNS.InternalDNS.Request; const LifecycleScriptSubprocessOutputReader = bun.install.LifecycleScriptSubprocess.OutputReader; const BufferedReader = bun.io.BufferedReader; pub const Owner = bun.TaggedPointerUnion(.{ @@ -186,6 +187,7 @@ pub const FilePoll = struct { DNSResolver, GetAddrInfoRequest, + Request, // LifecycleScriptSubprocessOutputReader, Process, ShellBufferedWriter, // i do not know why, but this has to be here otherwise compiler will complain about dependency loop @@ -411,6 +413,16 @@ pub const FilePoll = struct { loader.onMachportChange(); }, + @field(Owner.Tag, "Request") => { + if (comptime !Environment.isMac) { + unreachable; + } + + log("onUpdate " ++ kqueue_or_epoll ++ " (fd: {}) InternalDNSRequest", .{poll.fd}); + const loader: *Request = ptr.as(Request); + Request.MacAsyncDNS.onMachportChange(loader); + }, + else => { const possible_name = Owner.typeNameFromTag(@intFromEnum(ptr.tag())); log("onUpdate " ++ kqueue_or_epoll ++ " (fd: {}) disconnected? (maybe: {s})", .{ poll.fd, possible_name orelse "" }); diff --git a/src/bun.js/api/bun/dns_resolver.zig b/src/bun.js/api/bun/dns_resolver.zig index 4563b90ce9e7fa..8c2afd805b9350 100644 --- a/src/bun.js/api/bun/dns_resolver.zig +++ b/src/bun.js/api/bun/dns_resolver.zig @@ -1168,6 +1168,380 @@ pub const GlobalData = struct { } }; +pub const InternalDNS = struct { + const log = Output.scoped(.dns, true); + pub const Request = struct { + const Key = struct { + host: ?[:0]const u8, + port: u16, + hash: u64, + + pub fn init(name: ?[:0]const u8, port: u16) @This() { + const hash = if (name) |n| brk: { + var hasher = std.hash.Wyhash.init(0); + hasher.update(n); + const hash = hasher.final(); + break :brk hash; + } else 0; + return .{ + .host = name, + .port = port, + .hash = hash, + }; + } + + pub fn toOwned(this: @This()) @This() { + if (this.host) |host| { + const host_copy = bun.default_allocator.dupeZ(u8, host) catch bun.outOfMemory(); + return .{ + .host = host_copy, + .port = this.port, + .hash = this.hash, + }; + } else { + return this; + } + } + }; + + const Result = extern struct { + info: ?*std.c.addrinfo, + err: c_int, + }; + + pub const MacAsyncDNS = struct { + file_poll: ?*bun.Async.FilePoll = null, + machport: ?*anyopaque = null, + + extern fn getaddrinfo_send_reply(*anyopaque, *const JSC.DNS.LibInfo.GetaddrinfoAsyncHandleReply) bool; + pub fn onMachportChange(this: *Request) void { + if (!getaddrinfo_send_reply(this.libinfo.machport.?, LibInfo.getaddrinfo_async_handle_reply().?)) { + libinfoCallback(@intFromEnum(std.c.E.NOSYS), null, this); + } + } + }; + + key: Key, + result: ?Result = null, + + notify: std.ArrayListUnmanaged(*bun.uws.ConnectingSocket) = .{}, + // number of sockets that have a reference to result or are waiting for the result + // while this is non-zero, this entry cannot be freed + refcount: usize = 0, + valid: bool = true, + + libinfo: if (Environment.isMac) MacAsyncDNS else void = if (Environment.isMac) .{} else {}, + + pub fn deinit(this: *@This()) void { + bun.assert(this.notify.items.len == 0); + if (this.result) |res| { + if (res.info) |info| { + std.c.freeaddrinfo(info); + } + } + if (this.key.host) |host| { + bun.default_allocator.free(host); + } + } + }; + + const GlobalCache = struct { + const MAX_ENTRIES = 256; + + lock: bun.Lock = bun.Lock.init(), + cache: [MAX_ENTRIES]*Request = undefined, + len: usize = 0, + + const This = @This(); + + const CacheResult = union(enum) { + inflight: *Request, + resolved: *Request, + none, + }; + + fn get( + this: *This, + key: Request.Key, + ) ?*Request { + for (this.cache[0..this.len]) |entry| { + if (entry.key.hash == key.hash and entry.key.port == key.port and entry.valid) { + return entry; + } + } + return null; + } + + fn isNearlyFull(this: *This) bool { + // 80% full (value is kind of arbitrary) + return this.len * 5 >= this.cache.len * 4; + } + + fn remove(this: *This, entry: *Request) void { + const len = this.len; + // equivalent of swapRemove + for (0..len) |i| { + if (this.cache[i] == entry) { + this.cache[i] = this.cache[len - 1]; + this.len -= 1; + dns_cache_size = len - 1; + return; + } + } + } + + fn tryPush(this: *This, entry: *Request) bool { + // is the cache full? + if (this.len >= this.cache.len) { + // check if there is an element to evict + for (this.cache[0..this.len]) |*e| { + if (e.*.refcount == 0) { + e.*.deinit(); + e.* = entry; + return true; + } + } + return false; + } else { + // just append to the end + this.cache[this.len] = entry; + this.len += 1; + return true; + } + } + }; + + var global_cache = GlobalCache{}; + + // we just hardcode a STREAM socktype + const hints: std.c.addrinfo = .{ + .addr = null, + .addrlen = 0, + .canonname = null, + .family = std.c.AF.UNSPEC, + .flags = 0, + .next = null, + .protocol = 0, + .socktype = std.c.SOCK.STREAM, + }; + + extern fn us_internal_dns_callback(socket: *bun.uws.ConnectingSocket, req: *Request) void; + extern fn us_internal_dns_callback_threadsafe(socket: *bun.uws.ConnectingSocket, req: *Request) void; + + fn afterResult(req: *Request, info: ?*std.c.addrinfo, err: c_int) void { + // need to acquire the global cache lock to ensure that the notify list is not modified while we are iterating over it + global_cache.lock.lock(); + defer global_cache.lock.unlock(); + + req.result = .{ + .info = info, + .err = err, + }; + for (req.notify.items) |socket| { + us_internal_dns_callback_threadsafe(socket, req); + } + req.notify.clearAndFree(bun.default_allocator); + } + + fn workPoolCallback(req: *Request) void { + var port_buf: [128]u8 = undefined; + const port = std.fmt.bufPrintIntToSlice(&port_buf, req.key.port, 10, .lower, .{}); + port_buf[port.len] = 0; + const portZ = port_buf[0..port.len :0]; + + if (Environment.isWindows) { + const wsa = std.os.windows.ws2_32; + const wsa_hints = wsa.addrinfo{ + .flags = 0, + .family = wsa.AF.UNSPEC, + .socktype = wsa.SOCK.STREAM, + .protocol = 0, + .addrlen = 0, + .canonname = null, + .addr = null, + .next = null, + }; + + var addrinfo: ?*wsa.addrinfo = null; + const err = wsa.getaddrinfo( + if (req.key.host) |host| host.ptr else null, + if (port.len > 0) portZ.ptr else null, + &wsa_hints, + &addrinfo, + ); + afterResult(req, @ptrCast(addrinfo), err); + } else { + var addrinfo: ?*std.c.addrinfo = null; + const err = std.c.getaddrinfo( + if (req.key.host) |host| host.ptr else null, + if (port.len > 0) portZ.ptr else null, + &hints, + &addrinfo, + ); + afterResult(req, addrinfo, @intFromEnum(err)); + } + } + + pub fn lookupLibinfo(req: *Request, socket: *bun.uws.ConnectingSocket) bool { + const getaddrinfo_async_start_ = LibInfo.getaddrinfo_async_start() orelse return false; + const loop = bun.uws.us_connecting_socket_get_loop(socket).internal_loop_data.getParent(); + + var port_buf: [128]u8 = undefined; + const port = std.fmt.bufPrintIntToSlice(&port_buf, req.key.port, 10, .lower, .{}); + port_buf[port.len] = 0; + const portZ = port_buf[0..port.len :0]; + + var machport: ?*anyopaque = null; + const errno = getaddrinfo_async_start_( + &machport, + if (req.key.host) |host| host.ptr else null, + if (port.len > 0) portZ.ptr else null, + &hints, + libinfoCallback, + req, + ); + + if (errno != 0 or machport == null) { + return false; + } + + var poll = bun.Async.FilePoll.init(loop, bun.toFD(@intFromPtr(machport)), .{}, InternalDNSRequest, req); + const rc = poll.register(loop.loop(), .machport, true); + + if (rc == .err) { + poll.deinit(); + return false; + } + + req.libinfo = .{ + .file_poll = poll, + .machport = machport, + }; + + return true; + } + + fn libinfoCallback( + status: i32, + addr_info: ?*std.c.addrinfo, + arg: ?*anyopaque, + ) callconv(.C) void { + const req = bun.cast(*Request, arg); + afterResult(req, addr_info, @intCast(status)); + } + + var dns_cache_hits_completed: usize = 0; + var dns_cache_hits_inflight: usize = 0; + var dns_cache_size: usize = 0; + var dns_cache_misses: usize = 0; + var dns_cache_errors: usize = 0; + var getaddrinfo_calls: usize = 0; + + pub fn createDNSCacheStatsObject(globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSC.JSValue { + const object = JSC.JSValue.createEmptyObject(globalObject, 7); + object.put(globalObject, JSC.ZigString.static("cache_hits_completed"), JSC.JSValue.jsNumber(@atomicLoad(usize, &dns_cache_hits_completed, .Monotonic))); + object.put(globalObject, JSC.ZigString.static("cache_hits_inflight"), JSC.JSValue.jsNumber(@atomicLoad(usize, &dns_cache_hits_inflight, .Monotonic))); + object.put(globalObject, JSC.ZigString.static("size"), JSC.JSValue.jsNumber(@atomicLoad(usize, &dns_cache_size, .Monotonic))); + object.put(globalObject, JSC.ZigString.static("cache_misses"), JSC.JSValue.jsNumber(@atomicLoad(usize, &dns_cache_misses, .Monotonic))); + object.put(globalObject, JSC.ZigString.static("errors"), JSC.JSValue.jsNumber(@atomicLoad(usize, &dns_cache_errors, .Monotonic))); + object.put(globalObject, JSC.ZigString.static("getaddrinfo"), JSC.JSValue.jsNumber(@atomicLoad(usize, &getaddrinfo_calls, .Monotonic))); + return object; + } + + pub fn getDNSCacheStats(globalObject: *JSC.JSGlobalObject) callconv(.C) JSC.JSValue { + return JSC.JSFunction.create(globalObject, "createDNSCacheStatsObject", createDNSCacheStatsObject, 0, .{}); + } + + fn getaddrinfo(_host: ?[*:0]const u8, port: u16, socket: *bun.uws.ConnectingSocket) callconv(.C) void { + const host: ?[:0]const u8 = std.mem.span(_host); + const key = Request.Key.init(host, port); + + global_cache.lock.lock(); + getaddrinfo_calls += 1; + // is there a cache hit? + if (!bun.getRuntimeFeatureFlag("BUN_FEATURE_FLAG_DISABLE_DNS_CACHE")) { + if (global_cache.get(key)) |entry| { + if (entry.result != null) { + log("getaddrinfo({s}:{d}) = cache hit", .{ host orelse "", port }); + // result is already available, we can notify the socket immediately + entry.refcount += 1; + dns_cache_hits_completed += 1; + global_cache.lock.unlock(); + us_internal_dns_callback(socket, entry); + return; + } else { + log("getaddrinfo({s}:{d}) = cache hit (inflight)", .{ host orelse "", port }); + // add this socket to the list of sockets to be notified when the request is resolved + entry.notify.append(bun.default_allocator, socket) catch bun.outOfMemory(); + entry.refcount += 1; + dns_cache_hits_inflight += 1; + global_cache.lock.unlock(); + return; + } + } + } + + // no cache hit, we have to make a new request + const req = bun.default_allocator.create(Request) catch bun.outOfMemory(); + req.* = .{ + .key = key.toOwned(), + .refcount = 1, + }; + req.notify.append(bun.default_allocator, socket) catch bun.outOfMemory(); + _ = global_cache.tryPush(req); + dns_cache_misses += 1; + dns_cache_size = global_cache.len; + global_cache.lock.unlock(); + + // doesn't work yet + if (comptime Environment.isMac) { + if (!bun.getRuntimeFeatureFlag("BUN_FEATURE_FLAG_DISABLE_DNS_CACHE_LIBINFO")) { + const res = lookupLibinfo(req, socket); + log("getaddrinfo({s}:{d}) = cache miss (libinfo)", .{ host orelse "", port }); + if (res) return; + // if we were not able to use libinfo, we fall back to the work pool + } + } + + log("getaddrinfo({s}:{d}) = cache miss (libc)", .{ host orelse "", port }); + // schedule the request to be executed on the work pool + bun.JSC.WorkPool.go(bun.default_allocator, *Request, req, workPoolCallback) catch bun.outOfMemory(); + } + + fn freeaddrinfo(req: *Request, err: c_int) callconv(.C) void { + global_cache.lock.lock(); + defer global_cache.lock.unlock(); + + req.valid = err == 0; + dns_cache_errors += @as(usize, @intFromBool(err != 0)); + + req.refcount -= 1; + if (req.refcount == 0 and (global_cache.isNearlyFull() or !req.valid)) { + log("cache --", .{}); + global_cache.remove(req); + req.deinit(); + } + } + + fn getRequestResult(req: *Request) callconv(.C) *Request.Result { + return &req.result.?; + } +}; + +pub const InternalDNSRequest = InternalDNS.Request; + +comptime { + @export(InternalDNS.getaddrinfo, .{ + .name = "Bun__addrinfo_get", + }); + @export(InternalDNS.freeaddrinfo, .{ + .name = "Bun__addrinfo_freeRequest", + }); + @export(InternalDNS.getRequestResult, .{ + .name = "Bun__addrinfo_getRequestResult", + }); +} + pub const DNSResolver = struct { const log = Output.scoped(.DNSResolver, false); @@ -2538,8 +2912,6 @@ pub const DNSResolver = struct { }, ); } - // pub fn cancel(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSC.JSValue { - // const arguments = callframe.arguments(3); - - // } }; + +pub const getDNSCacheStats = InternalDNS.getDNSCacheStats; diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 0f128ba3145775..ca8c3aa5087419 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -810,7 +810,7 @@ pub const Listener = struct { const globalObject = listener.handlers.globalObject; Socket.dataSetCached(this_socket.getThisValue(globalObject), globalObject, default_data); } - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this_socket); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, this_socket); socket.setTimeout(120000); } @@ -1358,6 +1358,8 @@ fn NewSocket(comptime ssl: bool) type { } pub fn onOpen(this: *This, socket: Socket) void { + // update the internal socket instance to the one that was just connected + this.socket = socket; JSC.markBinding(@src()); log("onOpen ssl: {}", .{comptime ssl}); @@ -1397,7 +1399,7 @@ fn NewSocket(comptime ssl: bool) type { this.socket = socket; if (this.wrapped == .none) { - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, this); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, this); } const handlers = this.handlers; @@ -3020,7 +3022,7 @@ fn NewSocket(comptime ssl: bool) type { tls.poll_ref.ref(this.handlers.vm); // mark both instances on socket data - new_socket.ext(WrappedSocket).?.* = .{ .tcp = raw, .tls = tls }; + new_socket.ext(WrappedSocket).* = .{ .tcp = raw, .tls = tls }; // start TLS handshake after we set ext new_socket.startTLS(!this.handlers.is_server); diff --git a/src/bun.js/api/bun/subprocess.zig b/src/bun.js/api/bun/subprocess.zig index 6d5cbe73dfa617..a80325ba5e6d2f 100644 --- a/src/bun.js/api/bun/subprocess.zig +++ b/src/bun.js/api/bun/subprocess.zig @@ -2044,9 +2044,9 @@ pub const Subprocess = struct { var posix_ipc_info: if (Environment.isPosix) IPC.Socket else void = undefined; if (Environment.isPosix and !is_sync) { if (maybe_ipc_mode != null) { - posix_ipc_info = .{ + posix_ipc_info = IPC.Socket.from( // we initialize ext later in the function - .socket = uws.us_socket_from_fd( + uws.us_socket_from_fd( jsc_vm.rareData().spawnIPCContext(jsc_vm), @sizeOf(*Subprocess), spawned.extra_pipes.items[0].cast(), @@ -2056,7 +2056,7 @@ pub const Subprocess = struct { globalThis.throw("failed to create socket pair", .{}); return .zero; }, - }; + ); } } @@ -2117,8 +2117,7 @@ pub const Subprocess = struct { if (subprocess.ipc_data) |*ipc_data| { if (Environment.isPosix) { - const ptr = posix_ipc_info.ext(*Subprocess); - ptr.?.* = subprocess; + posix_ipc_info.ext(*Subprocess).* = subprocess; } else { if (ipc_data.configureServer( Subprocess, diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index 11d3b4ca8f1f84..59ce94dd0bfd0a 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -1598,6 +1598,7 @@ pub const EventLoop = struct { // _ = actual.addPostHandler(*JSC.EventLoop, this, JSC.EventLoop.afterUSocketsTick); // _ = actual.addPreHandler(*JSC.VM, this.virtual_machine.jsc, JSC.VM.drainMicrotasks); } + bun.uws.Loop.get().internal_loop_data.setParentEventLoop(bun.JSC.EventLoopHandle.init(this)); } /// Asynchronously run the garbage collector and track how much memory is now allocated @@ -1772,6 +1773,7 @@ pub const MiniEventLoop = struct { const loop = MiniEventLoop.init(bun.default_allocator); global = bun.default_allocator.create(MiniEventLoop) catch bun.outOfMemory(); global.* = loop; + global.loop.internal_loop_data.setParentEventLoop(bun.JSC.EventLoopHandle.init(global)); global.env = env orelse bun.DotEnv.instance orelse env_loader: { const map = bun.default_allocator.create(bun.DotEnv.Map) catch bun.outOfMemory(); map.* = bun.DotEnv.Map.init(bun.default_allocator); diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 3a338076821066..8e6eaecedcef5c 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -10,6 +10,7 @@ pub const u_int64_t = c_ulonglong; pub const LIBUS_LISTEN_DEFAULT: i32 = 0; pub const LIBUS_LISTEN_EXCLUSIVE_PORT: i32 = 1; pub const Socket = opaque {}; +pub const ConnectingSocket = opaque {}; const debug = bun.Output.scoped(.uws, false); const uws = @This(); @@ -38,48 +39,130 @@ pub const InternalLoopData = extern struct { closed_head: ?*Socket, low_prio_head: ?*Socket, low_prio_budget: i32, + dns_ready_head: *ConnectingSocket, + closed_connecting_head: *ConnectingSocket, + mutex: u32, // this is actually a bun.Lock + parent_ptr: ?*anyopaque, + parent_tag: c_char, + iteration_nr: c_longlong, pub fn recvSlice(this: *InternalLoopData) []u8 { return this.recv_buf[0..LIBUS_RECV_BUFFER_LENGTH]; } + + pub fn setParentEventLoop(this: *InternalLoopData, parent: bun.JSC.EventLoopHandle) void { + switch (parent) { + .js => |ptr| { + this.parent_tag = 1; + this.parent_ptr = ptr; + }, + .mini => |ptr| { + this.parent_tag = 2; + this.parent_ptr = ptr; + }, + } + } + + pub fn getParent(this: *InternalLoopData) bun.JSC.EventLoopHandle { + const parent = this.parent_ptr orelse @panic("Parent loop not set - pointer is null"); + return switch (this.parent_tag) { + 0 => @panic("Parent loop not set - tag is zero"), + 1 => .{ .js = bun.cast(*bun.JSC.EventLoop, parent) }, + 2 => .{ .mini = bun.cast(*bun.JSC.MiniEventLoop, parent) }, + else => @panic("Parent loop data corrupted - tag is invalid"), + }; + } +}; + +pub const InternalSocket = union(enum) { + done: *Socket, + connecting: *ConnectingSocket, + + pub fn get(this: @This()) ?*Socket { + return switch (this) { + .done => this.done, + .connecting => null, + }; + } + + pub fn eq(this: @This(), other: @This()) bool { + return switch (this) { + .done => switch (other) { + .done => this.done == other.done, + .connecting => false, + }, + .connecting => switch (other) { + .done => false, + .connecting => this.connecting == other.connecting, + }, + }; + } }; pub fn NewSocketHandler(comptime is_ssl: bool) type { return struct { const ssl_int: i32 = @intFromBool(is_ssl); - socket: *Socket, + socket: InternalSocket, const ThisSocket = @This(); pub fn verifyError(this: ThisSocket) us_bun_verify_error_t { - const ssl_error: us_bun_verify_error_t = uws.us_socket_verify_error(comptime ssl_int, this.socket); + const socket = this.socket.get() orelse return std.mem.zeroes(us_bun_verify_error_t); + const ssl_error: us_bun_verify_error_t = uws.us_socket_verify_error(comptime ssl_int, socket); return ssl_error; } pub fn isEstablished(this: ThisSocket) bool { - return us_socket_is_established(comptime ssl_int, this.socket) > 0; + const socket = this.socket.get() orelse return false; + return us_socket_is_established(comptime ssl_int, socket) > 0; } pub fn timeout(this: ThisSocket, seconds: c_uint) void { - return us_socket_timeout(comptime ssl_int, this.socket, seconds); + switch (this.socket) { + .done => |socket| us_socket_timeout(comptime ssl_int, socket, seconds), + .connecting => |socket| us_connecting_socket_timeout(comptime ssl_int, socket, seconds), + } } pub fn setTimeout(this: ThisSocket, seconds: c_uint) void { - if (seconds > 240) { - us_socket_timeout(comptime ssl_int, this.socket, 0); - us_socket_long_timeout(comptime ssl_int, this.socket, seconds / 60); - } else { - us_socket_timeout(comptime ssl_int, this.socket, seconds); - us_socket_long_timeout(comptime ssl_int, this.socket, 0); + switch (this.socket) { + .done => |socket| { + if (seconds > 240) { + us_socket_timeout(comptime ssl_int, socket, 0); + us_socket_long_timeout(comptime ssl_int, socket, seconds / 60); + } else { + us_socket_timeout(comptime ssl_int, socket, seconds); + us_socket_long_timeout(comptime ssl_int, socket, 0); + } + }, + .connecting => |socket| { + if (seconds > 240) { + us_connecting_socket_timeout(comptime ssl_int, socket, 0); + us_connecting_socket_long_timeout(comptime ssl_int, socket, seconds / 60); + } else { + us_connecting_socket_timeout(comptime ssl_int, socket, seconds); + us_connecting_socket_long_timeout(comptime ssl_int, socket, 0); + } + }, } } pub fn setTimeoutMinutes(this: ThisSocket, minutes: c_uint) void { - return us_socket_long_timeout(comptime ssl_int, this.socket, minutes); + switch (this.socket) { + .done => |socket| { + us_socket_timeout(comptime ssl_int, socket, 0); + us_socket_long_timeout(comptime ssl_int, socket, minutes); + }, + .connecting => |socket| { + us_connecting_socket_timeout(comptime ssl_int, socket, 0); + us_connecting_socket_long_timeout(comptime ssl_int, socket, minutes); + }, + } } pub fn startTLS(this: ThisSocket, is_client: bool) void { - _ = us_socket_open(comptime ssl_int, this.socket, @intFromBool(is_client), null, 0); + const socket = this.socket.get() orelse @panic("socket is not open"); + _ = us_socket_open(comptime ssl_int, socket, @intFromBool(is_client), null, 0); } pub fn ssl(this: ThisSocket) *BoringSSL.SSL { @@ -109,34 +192,34 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const ValueType = if (deref) ContextType else *ContextType; fn getValue(socket: *Socket) ValueType { if (comptime ContextType == anyopaque) { - return us_socket_ext(1, socket).?; + return us_socket_ext(1, socket); } if (comptime deref_) { - return (TLSSocket{ .socket = socket }).ext(ContextType).?.*; + return (TLSSocket.from(socket)).ext(ContextType).*; } - return (TLSSocket{ .socket = socket }).ext(ContextType).?; + return (TLSSocket.from(socket)).ext(ContextType); } pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { if (comptime @hasDecl(Fields, "onCreate")) { if (is_client == 0) { Fields.onCreate( - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); } } Fields.onOpen( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); return socket; } pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { Fields.onClose( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), code, reason, ); @@ -145,7 +228,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { Fields.onData( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), buf.?[0..@as(usize, @intCast(len))], ); return socket; @@ -153,28 +236,28 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { Fields.onWritable( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); return socket; } pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { Fields.onTimeout( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); return socket; } pub fn on_long_timeout(socket: *Socket) callconv(.C) ?*Socket { Fields.onLongTimeout( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); return socket; } pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { Fields.onConnectError( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), code, ); return socket; @@ -182,12 +265,12 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { Fields.onEnd( getValue(socket), - TLSSocket{ .socket = socket }, + TLSSocket.from(socket), ); return socket; } pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { - Fields.onHandshake(getValue(socket), TLSSocket{ .socket = socket }, success, verify_error); + Fields.onHandshake(getValue(socket), TLSSocket.from(socket), success, verify_error); } }; @@ -203,67 +286,73 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { .on_long_timeout = SocketHandler.on_long_timeout, }; - const socket = us_socket_wrap_with_tls(ssl_int, this.socket, options, events, socket_ext_size) orelse return null; + const this_socket = this.socket.get() orelse @panic("socket is not open"); + + const socket = us_socket_wrap_with_tls(ssl_int, this_socket, options, events, socket_ext_size) orelse return null; return NewSocketHandler(true).from(socket); } - pub fn getNativeHandle(this: ThisSocket) *NativeSocketHandleType(is_ssl) { - return @as(*NativeSocketHandleType(is_ssl), @ptrCast(us_socket_get_native_handle(comptime ssl_int, this.socket).?)); + pub fn getNativeHandle(this: ThisSocket) ?*NativeSocketHandleType(is_ssl) { + return @ptrCast(switch (this.socket) { + .done => |socket| us_socket_get_native_handle(comptime ssl_int, socket), + .connecting => |socket| us_connecting_socket_get_native_handle(comptime ssl_int, socket), + } orelse return null); } pub inline fn fd(this: ThisSocket) bun.FileDescriptor { if (comptime is_ssl) { @compileError("SSL sockets do not have a file descriptor accessible this way"); } - - if (comptime Environment.isWindows) { + const socket = this.socket.get() orelse return bun.invalid_fd; + return if (comptime Environment.isWindows) // on windows uSockets exposes SOCKET - return bun.toFD(@as(bun.FDImpl.System, @ptrCast(us_socket_get_native_handle(0, this.socket)))); - } - - return bun.toFD(@as(i32, @intCast(@intFromPtr(us_socket_get_native_handle(0, this.socket))))); + bun.toFD(@as(bun.FDImpl.System, @ptrCast(us_socket_get_native_handle(0, socket)))) + else + bun.toFD(@as(i32, @intCast(@intFromPtr(us_socket_get_native_handle(0, socket))))); } pub fn markNeedsMoreForSendfile(this: ThisSocket) void { if (comptime is_ssl) { @compileError("SSL sockets do not support sendfile yet"); } - - us_socket_sendfile_needs_more(this.socket); + const socket = this.socket.get() orelse return; + us_socket_sendfile_needs_more(socket); } - pub fn ext(this: ThisSocket, comptime ContextType: type) ?*ContextType { + pub fn ext(this: ThisSocket, comptime ContextType: type) *ContextType { const alignment = if (ContextType == *anyopaque) @sizeOf(usize) else std.meta.alignment(ContextType); - const ptr = us_socket_ext( - comptime ssl_int, - this.socket, - ) orelse return null; + const ptr = switch (this.socket) { + .done => |sock| us_socket_ext(comptime ssl_int, sock), + .connecting => |sock| us_connecting_socket_ext(comptime ssl_int, sock), + }; return @as(*align(alignment) ContextType, @ptrCast(@alignCast(ptr))); } /// This can be null if the socket was closed. pub fn context(this: ThisSocket) ?*SocketContext { - return us_socket_context( - comptime ssl_int, - this.socket, - ); + switch (this.socket) { + .done => |socket| return us_socket_context(comptime ssl_int, socket), + .connecting => |socket| return us_connecting_socket_context(comptime ssl_int, socket), + } } pub fn flush(this: ThisSocket) void { + const socket = this.socket.get() orelse return; return us_socket_flush( comptime ssl_int, - this.socket, + socket, ); } pub fn write(this: ThisSocket, data: []const u8, msg_more: bool) i32 { + const socket = this.socket.get() orelse return 0; const result = us_socket_write( comptime ssl_int, - this.socket, + socket, data.ptr, // truncate to 31 bits since sign bit exists @as(i32, @intCast(@as(u31, @truncate(data.len)))), @@ -278,9 +367,10 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } pub fn rawWrite(this: ThisSocket, data: []const u8, msg_more: bool) i32 { + const socket = this.socket.get() orelse return 0; return us_socket_raw_write( comptime ssl_int, - this.socket, + socket, data.ptr, // truncate to 31 bits since sign bit exists @as(i32, @intCast(@as(u31, @truncate(data.len)))), @@ -288,24 +378,57 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { ); } pub fn shutdown(this: ThisSocket) void { - debug("us_socket_shutdown({d})", .{@intFromPtr(this.socket)}); - return us_socket_shutdown( - comptime ssl_int, - this.socket, - ); + // debug("us_socket_shutdown({d})", .{@intFromPtr(this.socket)}); + switch (this.socket) { + .done => |socket| { + return us_socket_shutdown( + comptime ssl_int, + socket, + ); + }, + .connecting => |socket| { + return us_connecting_socket_shutdown( + comptime ssl_int, + socket, + ); + }, + } } + pub fn shutdownRead(this: ThisSocket) void { - debug("us_socket_shutdown_read({d})", .{@intFromPtr(this.socket)}); - return us_socket_shutdown_read( - comptime ssl_int, - this.socket, - ); + switch (this.socket) { + .done => |socket| { + // debug("us_socket_shutdown_read({d})", .{@intFromPtr(socket)}); + return us_socket_shutdown_read( + comptime ssl_int, + socket, + ); + }, + .connecting => |socket| { + // debug("us_connecting_socket_shutdown_read({d})", .{@intFromPtr(socket)}); + return us_connecting_socket_shutdown_read( + comptime ssl_int, + socket, + ); + }, + } } + pub fn isShutdown(this: ThisSocket) bool { - return us_socket_is_shut_down( - comptime ssl_int, - this.socket, - ) > 0; + switch (this.socket) { + .done => |socket| { + return us_socket_is_shut_down( + comptime ssl_int, + socket, + ) > 0; + }, + .connecting => |socket| { + return us_connecting_socket_is_shut_down( + comptime ssl_int, + socket, + ) > 0; + }, + } } pub fn isClosedOrHasError(this: ThisSocket) bool { @@ -317,37 +440,73 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } pub fn getError(this: ThisSocket) i32 { - return us_socket_get_error( - comptime ssl_int, - this.socket, - ); + switch (this.socket) { + .done => |socket| { + return us_socket_get_error( + comptime ssl_int, + socket, + ); + }, + .connecting => |socket| { + return us_connecting_socket_get_error( + comptime ssl_int, + socket, + ); + }, + } } pub fn isClosed(this: ThisSocket) bool { - return us_socket_is_closed( - comptime ssl_int, - this.socket, - ) > 0; + switch (this.socket) { + .done => |socket| { + return us_socket_is_closed( + comptime ssl_int, + socket, + ) > 0; + }, + .connecting => |socket| { + return us_connecting_socket_is_closed( + comptime ssl_int, + socket, + ) > 0; + }, + } } + pub fn close(this: ThisSocket, code: i32, reason: ?*anyopaque) void { - debug("us_socket_close({d})", .{@intFromPtr(this.socket)}); - _ = us_socket_close( - comptime ssl_int, - this.socket, - code, - reason, - ); + // debug("us_socket_close({d})", .{@intFromPtr(this.socket)}); + switch (this.socket) { + .done => |socket| { + _ = us_socket_close( + comptime ssl_int, + socket, + code, + reason, + ); + }, + .connecting => |socket| { + _ = us_connecting_socket_close( + comptime ssl_int, + socket, + ); + }, + } } pub fn localPort(this: ThisSocket) i32 { + const socket = this.socket.get() orelse return 0; return us_socket_local_port( comptime ssl_int, - this.socket, + socket, ); } pub fn remoteAddress(this: ThisSocket, buf: [*]u8, length: *i32) void { + const socket = this.socket.get() orelse { + length.* = 0; + return; + }; return us_socket_remote_address( comptime ssl_int, - this.socket, + socket, buf, length, ); @@ -361,10 +520,11 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { /// # Returns /// This function returns a slice of the buffer on success, or null on failure. pub fn localAddressBinary(this: ThisSocket, buf: []u8) ?[]const u8 { + const socket = this.socket.get() orelse return null; var length: i32 = @intCast(buf.len); us_socket_local_address( comptime ssl_int, - this.socket, + socket, buf.ptr, &length, ); @@ -388,11 +548,8 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const addr_v6_len = @sizeOf(std.meta.FieldType(std.os.sockaddr.in6, .addr)); var sa_buf: [addr_v6_len + 1]u8 = undefined; - const binary = this.localAddressBinary(&sa_buf); - if (binary == null) { - return null; - } - const addr_len: usize = binary.?.len; + const binary = this.localAddressBinary(&sa_buf) orelse return null; + const addr_len: usize = binary.len; sa_buf[addr_len] = 0; var ret: ?[*:0]const u8 = null; @@ -426,14 +583,10 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const host_ = allocator.dupeZ(u8, host) catch return null; defer allocator.free(host_); - const socket = us_socket_context_connect(comptime ssl_int, socket_ctx, host_, port, null, 0, @sizeOf(Context)) orelse return null; - const socket_ = ThisSocket{ .socket = socket }; + const socket = us_socket_context_connect(comptime ssl_int, socket_ctx, host_, port, 0, @sizeOf(Context)) orelse return null; + const socket_ = ThisSocket{ .socket = .{ .connecting = socket } }; - var holder = socket_.ext(Context) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - _ = us_socket_close_connecting(comptime ssl_int, socket); - return null; - }; + var holder = socket_.ext(Context); holder.* = ctx; @field(holder, socket_field_name) = socket_; return holder; @@ -459,12 +612,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { this: *This, comptime socket_field_name: ?[]const u8, ) ?ThisSocket { - const socket_ = ThisSocket{ .socket = us_socket_from_fd(ctx, @sizeOf(*anyopaque), bun.socketcast(handle)) orelse return null }; + const socket_ = ThisSocket{ .socket = .{ .done = us_socket_from_fd(ctx, @sizeOf(*anyopaque), bun.socketcast(handle)) orelse return null } }; - const holder = socket_.ext(*anyopaque) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - return null; - }; + const holder = socket_.ext(*anyopaque); holder.* = this; if (comptime socket_field_name) |field| { @@ -500,12 +650,8 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const socket = us_socket_context_connect_unix(comptime ssl_int, socket_ctx, path_, path_.len, 0, 8) orelse return error.FailedToOpenSocket; - const socket_ = ThisSocket{ .socket = socket }; - const holder = socket_.ext(*anyopaque) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - _ = us_socket_close_connecting(comptime ssl_int, socket); - return error.FailedToOpenSocket; - }; + const socket_ = ThisSocket{ .socket = .{ .done = socket } }; + const holder = socket_.ext(*anyopaque); holder.* = ctx; return socket_; } @@ -531,35 +677,18 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { defer if (host) |allocated_host| allocator.free(allocated_host[0..raw_host.len]); - const us_socket = us_socket_context_connect( + const connecting = us_socket_context_connect( comptime ssl_int, socket_ctx, host, port, - null, 0, @sizeOf(*anyopaque), - ) orelse { - if (Environment.isWindows) { - try bun.windows.WSAGetLastError(); - } else { - // TODO(@paperdave): On POSIX, this will call getaddrinfo + socket - // the former of these does not set errno, and usockets does not have - // a way to propogate the error. - // - // This is caught in the wild: https://github.com/oven-sh/bun/issues/6381 - // and we can definitely report a better here. It just is tricky. - } - return error.FailedToOpenSocket; - }; + ); - const socket = ThisSocket{ .socket = us_socket }; + const socket = ThisSocket{ .socket = .{ .connecting = connecting } }; - const holder = socket.ext(*anyopaque) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - _ = us_socket_close_connecting(comptime ssl_int, us_socket); - return error.FailedToOpenSocket; - }; + const holder = socket.ext(*anyopaque); holder.* = ptr; return socket; } @@ -588,30 +717,30 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } if (comptime deref_) { - return (SocketHandlerType{ .socket = socket }).ext(ContextType).?.*; + return (SocketHandlerType.from(socket)).ext(ContextType).*; } - return (SocketHandlerType{ .socket = socket }).ext(ContextType).?; + return (SocketHandlerType.from(socket)).ext(ContextType); } pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { if (comptime @hasDecl(Fields, "onCreate")) { if (is_client == 0) { Fields.onCreate( - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), ); } } Fields.onOpen( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), ); return socket; } pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { Fields.onClose( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), code, reason, ); @@ -620,7 +749,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { Fields.onData( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), buf.?[0..@as(usize, @intCast(len))], ); return socket; @@ -628,21 +757,27 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { Fields.onWritable( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), ); return socket; } pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { Fields.onTimeout( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), ); return socket; } - pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + pub fn on_connect_error(socket: *ConnectingSocket, code: i32) callconv(.C) ?*ConnectingSocket { + const val = if (comptime ContextType == anyopaque) + us_connecting_socket_ext(comptime ssl_int, socket) + else if (comptime deref_) + SocketHandlerType.fromConnecting(socket).ext(ContextType).* + else + SocketHandlerType.fromConnecting(socket).ext(ContextType); Fields.onConnectError( - getValue(socket), - SocketHandlerType{ .socket = socket }, + val, + SocketHandlerType.fromConnecting(socket), code, ); return socket; @@ -650,12 +785,12 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { Fields.onEnd( getValue(socket), - SocketHandlerType{ .socket = socket }, + SocketHandlerType.from(socket), ); return socket; } pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { - Fields.onHandshake(getValue(socket), SocketHandlerType{ .socket = socket }, success, verify_error); + Fields.onHandshake(getValue(socket), SocketHandlerType.from(socket), success, verify_error); } }; @@ -694,34 +829,34 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const ValueType = if (deref) ContextType else *ContextType; fn getValue(socket: *Socket) ValueType { if (comptime ContextType == anyopaque) { - return us_socket_ext(comptime ssl_int, socket).?; + return us_socket_ext(comptime ssl_int, socket); } if (comptime deref_) { - return (ThisSocket{ .socket = socket }).ext(ContextType).?.*; + return (ThisSocket.from(socket)).ext(ContextType).*; } - return (ThisSocket{ .socket = socket }).ext(ContextType).?; + return (ThisSocket.from(socket)).ext(ContextType); } pub fn on_open(socket: *Socket, is_client: i32, _: [*c]u8, _: i32) callconv(.C) ?*Socket { if (comptime @hasDecl(Fields, "onCreate")) { if (is_client == 0) { Fields.onCreate( - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); } } Fields.onOpen( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); return socket; } pub fn on_close(socket: *Socket, code: i32, reason: ?*anyopaque) callconv(.C) ?*Socket { Fields.onClose( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), code, reason, ); @@ -730,7 +865,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_data(socket: *Socket, buf: ?[*]u8, len: i32) callconv(.C) ?*Socket { Fields.onData( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), buf.?[0..@as(usize, @intCast(len))], ); return socket; @@ -738,28 +873,34 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_writable(socket: *Socket) callconv(.C) ?*Socket { Fields.onWritable( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); return socket; } pub fn on_timeout(socket: *Socket) callconv(.C) ?*Socket { Fields.onTimeout( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); return socket; } pub fn on_long_timeout(socket: *Socket) callconv(.C) ?*Socket { Fields.onLongTimeout( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); return socket; } - pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { + pub fn on_connect_error(socket: *ConnectingSocket, code: i32) callconv(.C) ?*ConnectingSocket { + const val = if (comptime ContextType == anyopaque) + us_connecting_socket_ext(comptime ssl_int, socket) + else if (comptime deref_) + ThisSocket.fromConnecting(socket).ext(ContextType).* + else + ThisSocket.fromConnecting(socket).ext(ContextType); Fields.onConnectError( - getValue(socket), - ThisSocket{ .socket = socket }, + val, + ThisSocket.fromConnecting(socket), code, ); return socket; @@ -767,12 +908,12 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn on_end(socket: *Socket) callconv(.C) ?*Socket { Fields.onEnd( getValue(socket), - ThisSocket{ .socket = socket }, + ThisSocket.from(socket), ); return socket; } pub fn on_handshake(socket: *Socket, success: i32, verify_error: us_bun_verify_error_t, _: ?*anyopaque) callconv(.C) void { - Fields.onHandshake(getValue(socket), ThisSocket{ .socket = socket }, success, verify_error); + Fields.onHandshake(getValue(socket), ThisSocket.from(socket), success, verify_error); } }; @@ -797,25 +938,15 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } pub fn from(socket: *Socket) ThisSocket { - return ThisSocket{ .socket = socket }; + return ThisSocket{ .socket = .{ .done = socket } }; } - pub fn adopt( - socket: *Socket, - socket_ctx: *SocketContext, - comptime Context: type, - comptime socket_field_name: []const u8, - ctx: Context, - ) ?*Context { - var adopted = ThisSocket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(Context)) orelse return null }; - var holder = adopted.ext(Context) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - _ = us_socket_close(comptime ssl_int, socket, 0, null); - return null; - }; - holder.* = ctx; - @field(holder, socket_field_name) = adopted; - return holder; + pub fn fromConnecting(connecting: *ConnectingSocket) ThisSocket { + return ThisSocket{ .socket = .{ .connecting = connecting } }; + } + + pub fn fromAny(socket: InternalSocket) ThisSocket { + return ThisSocket{ .socket = socket }; } pub fn adoptPtr( @@ -825,12 +956,12 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { comptime socket_field_name: []const u8, ctx: *Context, ) bool { - var adopted = ThisSocket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(*Context)) orelse return false }; - const holder = adopted.ext(*anyopaque) orelse { - if (comptime bun.Environment.allow_assert) unreachable; - _ = us_socket_close(comptime ssl_int, socket, 0, null); - return false; - }; + // ext_size of -1 means we want to keep the current ext size + // in particular, we don't want to allocate a new socket + const new_socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, -1) orelse return false; + bun.assert(new_socket == socket); + var adopted = ThisSocket.from(new_socket); + const holder = adopted.ext(*anyopaque); holder.* = ctx; @field(ctx, socket_field_name) = adopted; return true; @@ -914,7 +1045,7 @@ pub const SocketContext = opaque { fn timeout(socket: *Socket) callconv(.C) ?*Socket { return socket; } - fn connect_error(socket: *Socket, _: i32) callconv(.C) ?*Socket { + fn connect_error(socket: *ConnectingSocket, _: i32) callconv(.C) ?*ConnectingSocket { return socket; } fn end(socket: *Socket) callconv(.C) ?*Socket { @@ -1220,16 +1351,15 @@ extern fn us_socket_context_on_handshake(ssl: i32, context: ?*SocketContext, on_ extern fn us_socket_context_on_timeout(ssl: i32, context: ?*SocketContext, on_timeout: *const fn (*Socket) callconv(.C) ?*Socket) void; extern fn us_socket_context_on_long_timeout(ssl: i32, context: ?*SocketContext, on_timeout: *const fn (*Socket) callconv(.C) ?*Socket) void; -extern fn us_socket_context_on_connect_error(ssl: i32, context: ?*SocketContext, on_connect_error: *const fn (*Socket, i32) callconv(.C) ?*Socket) void; +extern fn us_socket_context_on_connect_error(ssl: i32, context: ?*SocketContext, on_connect_error: *const fn (*ConnectingSocket, i32) callconv(.C) ?*ConnectingSocket) void; extern fn us_socket_context_on_end(ssl: i32, context: ?*SocketContext, on_end: *const fn (*Socket) callconv(.C) ?*Socket) void; extern fn us_socket_context_ext(ssl: i32, context: ?*SocketContext) ?*anyopaque; pub extern fn us_socket_context_listen(ssl: i32, context: ?*SocketContext, host: ?[*:0]const u8, port: i32, options: i32, socket_ext_size: i32) ?*ListenSocket; pub extern fn us_socket_context_listen_unix(ssl: i32, context: ?*SocketContext, path: [*:0]const u8, pathlen: usize, options: i32, socket_ext_size: i32) ?*ListenSocket; -pub extern fn us_socket_context_connect(ssl: i32, context: ?*SocketContext, host: ?[*:0]const u8, port: i32, source_host: [*c]const u8, options: i32, socket_ext_size: i32) ?*Socket; +pub extern fn us_socket_context_connect(ssl: i32, context: ?*SocketContext, host: ?[*:0]const u8, port: i32, options: i32, socket_ext_size: i32) *ConnectingSocket; pub extern fn us_socket_context_connect_unix(ssl: i32, context: ?*SocketContext, path: [*c]const u8, pathlen: usize, options: i32, socket_ext_size: i32) ?*Socket; pub extern fn us_socket_is_established(ssl: i32, s: ?*Socket) i32; -pub extern fn us_socket_close_connecting(ssl: i32, s: ?*Socket) ?*Socket; pub extern fn us_socket_context_loop(ssl: i32, context: ?*SocketContext) ?*Loop; pub extern fn us_socket_context_adopt_socket(ssl: i32, context: ?*SocketContext, s: ?*Socket, ext_size: i32) ?*Socket; pub extern fn us_create_child_socket_context(ssl: i32, context: ?*SocketContext, context_ext_size: i32) ?*SocketContext; @@ -1322,10 +1452,11 @@ pub const Poll = opaque { }; extern fn us_socket_get_native_handle(ssl: i32, s: ?*Socket) ?*anyopaque; +extern fn us_connecting_socket_get_native_handle(ssl: i32, s: ?*ConnectingSocket) ?*anyopaque; extern fn us_socket_timeout(ssl: i32, s: ?*Socket, seconds: c_uint) void; extern fn us_socket_long_timeout(ssl: i32, s: ?*Socket, seconds: c_uint) void; -extern fn us_socket_ext(ssl: i32, s: ?*Socket) ?*anyopaque; +extern fn us_socket_ext(ssl: i32, s: ?*Socket) *anyopaque; extern fn us_socket_context(ssl: i32, s: ?*Socket) ?*SocketContext; extern fn us_socket_flush(ssl: i32, s: ?*Socket) void; extern fn us_socket_write(ssl: i32, s: ?*Socket, data: [*c]const u8, length: i32, msg_more: i32) i32; @@ -1335,6 +1466,20 @@ extern fn us_socket_shutdown_read(ssl: i32, s: ?*Socket) void; extern fn us_socket_is_shut_down(ssl: i32, s: ?*Socket) i32; extern fn us_socket_is_closed(ssl: i32, s: ?*Socket) i32; extern fn us_socket_close(ssl: i32, s: ?*Socket, code: i32, reason: ?*anyopaque) ?*Socket; + +extern fn us_connecting_socket_timeout(ssl: i32, s: ?*ConnectingSocket, seconds: c_uint) void; +extern fn us_connecting_socket_long_timeout(ssl: i32, s: ?*ConnectingSocket, seconds: c_uint) void; +extern fn us_connecting_socket_ext(ssl: i32, s: ?*ConnectingSocket) *anyopaque; +extern fn us_connecting_socket_context(ssl: i32, s: ?*ConnectingSocket) ?*SocketContext; +extern fn us_connecting_socket_shutdown(ssl: i32, s: ?*ConnectingSocket) void; +extern fn us_connecting_socket_is_closed(ssl: i32, s: ?*ConnectingSocket) i32; +extern fn us_connecting_socket_close(ssl: i32, s: ?*ConnectingSocket) void; +extern fn us_connecting_socket_shutdown_read(ssl: i32, s: ?*ConnectingSocket) void; +extern fn us_connecting_socket_is_shut_down(ssl: i32, s: ?*ConnectingSocket) i32; +extern fn us_connecting_socket_get_error(ssl: i32, s: ?*ConnectingSocket) i32; + +pub extern fn us_connecting_socket_get_loop(s: *ConnectingSocket) *Loop; + // if a TLS socket calls this, it will start SSL instance and call open event will also do TLS handshake if required // will have no effect if the socket is closed or is not TLS extern fn us_socket_open(ssl: i32, s: ?*Socket, is_client: i32, ip: [*c]const u8, ip_length: i32) ?*Socket; @@ -1721,7 +1866,7 @@ pub fn NewApp(comptime ssl: bool) type { } pub fn socket(this: *@This()) NewSocketHandler(ssl) { - return .{ .socket = @ptrCast(this) }; + return NewSocketHandler(ssl).from(@ptrCast(this)); } }; diff --git a/src/http.zig b/src/http.zig index 9b988e66068dd9..b12cd3deb70de3 100644 --- a/src/http.zig +++ b/src/http.zig @@ -44,7 +44,7 @@ const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion; const DeadSocket = opaque {}; var dead_socket = @as(*DeadSocket, @ptrFromInt(1)); //TODO: this needs to be freed when Worker Threads are implemented -var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, *uws.Socket).init(bun.default_allocator); +var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator); var async_http_id: std.atomic.Value(u32) = std.atomic.Value(u32).init(0); const MAX_REDIRECT_URL_LENGTH = 128 * 1024; const print_every = 0; @@ -372,12 +372,12 @@ fn NewHTTPContext(comptime ssl: bool) type { // we manually abort the connection if the hostname doesn't match .reject_unauthorized = 0, }; - this.us_socket_context = uws.us_create_bun_socket_context(ssl_int, http_thread.loop, @sizeOf(usize), opts).?; + this.us_socket_context = uws.us_create_bun_socket_context(ssl_int, http_thread.loop.loop, @sizeOf(usize), opts).?; this.sslCtx().setup(); } else { const opts: uws.us_socket_context_options_t = .{}; - this.us_socket_context = uws.us_create_socket_context(ssl_int, http_thread.loop, @sizeOf(usize), opts).?; + this.us_socket_context = uws.us_create_socket_context(ssl_int, http_thread.loop.loop, @sizeOf(usize), opts).?; } HTTPSocket.configure( @@ -391,7 +391,7 @@ fn NewHTTPContext(comptime ssl: bool) type { /// Attempt to keep the socket alive by reusing it for another request. /// If no space is available, close the socket. pub fn releaseSocket(this: *@This(), socket: HTTPSocket, hostname: []const u8, port: u16) void { - log("releaseSocket(0x{})", .{bun.fmt.hexIntUpper(@intFromPtr(socket.socket))}); + // log("releaseSocket(0x{})", .{bun.fmt.hexIntUpper(@intFromPtr(socket.socket))}); if (comptime Environment.allow_assert) { assert(!socket.isClosed()); @@ -403,7 +403,7 @@ fn NewHTTPContext(comptime ssl: bool) type { if (hostname.len <= MAX_KEEPALIVE_HOSTNAME and !socket.isClosedOrHasError() and socket.isEstablished()) { if (this.pending_sockets.get()) |pending| { - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr()); socket.flush(); socket.timeout(0); socket.setTimeoutMinutes(5); @@ -413,12 +413,12 @@ fn NewHTTPContext(comptime ssl: bool) type { pending.hostname_len = @as(u8, @truncate(hostname.len)); pending.port = port; - log("Keep-Alive release {s}:{d} (0x{})", .{ hostname, port, @intFromPtr(socket.socket) }); + // log("Keep-Alive release {s}:{d} (0x{})", .{ hostname, port, @intFromPtr(socket.socket) }); return; } } - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); socket.close(0, null); } @@ -436,7 +436,7 @@ fn NewHTTPContext(comptime ssl: bool) type { assert(context().pending_sockets.put(pooled)); } - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); socket.close(0, null); if (comptime Environment.allow_assert) { assert(false); @@ -455,7 +455,7 @@ fn NewHTTPContext(comptime ssl: bool) type { .code = if (ssl_error.code == null) "" else ssl_error.code[0..bun.len(ssl_error.code) :0], .reason = if (ssl_error.code == null) "" else ssl_error.reason[0..bun.len(ssl_error.reason) :0], }; - log("onHandshake(0x{}) authorized: {} error: {s}", .{ bun.fmt.hexIntUpper(@intFromPtr(socket.socket)), authorized, handshake_error.code }); + // log("onHandshake(0x{}) authorized: {} error: {s}", .{ bun.fmt.hexIntUpper(@intFromPtr(socket.socket)), authorized, handshake_error.code }); const active = ActiveSocket.from(bun.cast(**anyopaque, ptr).*); if (active.get(HTTPClient)) |client| { @@ -472,7 +472,7 @@ fn NewHTTPContext(comptime ssl: bool) type { return client.firstCall(comptime ssl, socket); } else { // if authorized it self is false, this means that the connection was rejected - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); if (client.state.stage != .done and client.state.stage != .fail) client.fail(error.ConnectionRefused); return; @@ -485,7 +485,7 @@ fn NewHTTPContext(comptime ssl: bool) type { // we can reach here if we are aborted if (!socket.isClosed()) { - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); socket.close(0, null); } } @@ -496,7 +496,7 @@ fn NewHTTPContext(comptime ssl: bool) type { _: ?*anyopaque, ) void { var tagged = ActiveSocket.from(bun.cast(**anyopaque, ptr).*); - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); if (tagged.get(HTTPClient)) |client| { return client.onClose(comptime ssl, socket); @@ -548,7 +548,7 @@ fn NewHTTPContext(comptime ssl: bool) type { socket: HTTPSocket, ) void { var tagged = ActiveSocket.from(bun.cast(**anyopaque, ptr).*); - socket.ext(**anyopaque).?.* = bun.cast( + socket.ext(**anyopaque).* = bun.cast( **anyopaque, ActiveSocket.init(&dead_socket).ptr(), ); @@ -588,7 +588,7 @@ fn NewHTTPContext(comptime ssl: bool) type { var tagged = ActiveSocket.from(@as(**anyopaque, @ptrCast(@alignCast(ptr))).*); { @setRuntimeSafety(false); - socket.ext(**anyopaque).?.* = @as(**anyopaque, @ptrCast(@alignCast(ActiveSocket.init(dead_socket).ptrUnsafe()))); + socket.ext(**anyopaque).* = @as(**anyopaque, @ptrCast(@alignCast(ActiveSocket.init(dead_socket).ptrUnsafe()))); } if (tagged.get(HTTPClient)) |client| { @@ -623,12 +623,12 @@ fn NewHTTPContext(comptime ssl: bool) type { assert(context().pending_sockets.put(socket)); if (http_socket.isClosed()) { - http_socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + http_socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); continue; } if (http_socket.isShutdown() or http_socket.getError() != 0) { - http_socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + http_socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); http_socket.close(0, null); continue; } @@ -663,7 +663,7 @@ fn NewHTTPContext(comptime ssl: bool) type { if (client.isKeepAlivePossible()) { if (this.existingSocket(hostname, port)) |sock| { - sock.ext(**anyopaque).?.* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr()); + sock.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr()); client.allow_retry = true; client.onOpen(comptime ssl, sock); if (comptime ssl) { @@ -692,7 +692,7 @@ const ShutdownQueue = UnboundedQueue(AsyncHTTP, .next); pub const HTTPThread = struct { var http_thread_loaded: std.atomic.Value(bool) = std.atomic.Value(bool).init(false); - loop: *uws.Loop, + loop: *JSC.MiniEventLoop, http_context: NewHTTPContext(false), https_context: NewHTTPContext(true), @@ -742,13 +742,7 @@ pub const HTTPThread = struct { default_arena = Arena.init() catch unreachable; default_allocator = default_arena.allocator(); - const loop = bun.uws.Loop.create(struct { - pub fn wakeup(_: *uws.Loop) callconv(.C) void { - http_thread.drainEvents(); - } - pub fn pre(_: *uws.Loop) callconv(.C) void {} - pub fn post(_: *uws.Loop) callconv(.C) void {} - }); + const loop = bun.JSC.MiniEventLoop.initGlobal(null); if (Environment.isWindows) { _ = std.os.getenvW(comptime bun.strings.w("SystemRoot")) orelse { @@ -785,10 +779,10 @@ pub const HTTPThread = struct { for (this.queued_shutdowns.items) |http| { if (socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| { if (http.is_tls) { - const socket = uws.SocketTLS.from(socket_ptr.value); + const socket = uws.SocketTLS.fromAny(socket_ptr.value); socket.shutdown(); } else { - const socket = uws.SocketTCP.from(socket_ptr.value); + const socket = uws.SocketTCP.fromAny(socket_ptr.value); socket.shutdown(); } } @@ -823,9 +817,9 @@ pub const HTTPThread = struct { fn processEvents(this: *@This()) noreturn { if (comptime Environment.isPosix) { - this.loop.num_polls = @max(2, this.loop.num_polls); + this.loop.loop.num_polls = @max(2, this.loop.loop.num_polls); } else if (comptime Environment.isWindows) { - this.loop.inc(); + this.loop.loop.inc(); } else { @compileError("TODO:"); } @@ -838,7 +832,8 @@ pub const HTTPThread = struct { start_time = std.time.nanoTimestamp(); } Output.flush(); - this.loop.run(); + this.loop.tickOnce(this); + // this.loop.run(); if (comptime Environment.isDebug) { const end = std.time.nanoTimestamp(); threadlog("Waited {any}\n", .{std.fmt.fmtDurationSigned(@as(i64, @truncate(end - start_time)))}); @@ -857,12 +852,12 @@ pub const HTTPThread = struct { }) catch bun.outOfMemory(); } if (this.has_awoken.load(.Monotonic)) - this.loop.wakeup(); + this.loop.loop.wakeup(); } pub fn wakeup(this: *@This()) void { if (this.has_awoken.load(.Monotonic)) - this.loop.wakeup(); + this.loop.loop.wakeup(); } pub fn schedule(this: *@This(), batch: Batch) void { @@ -878,7 +873,7 @@ pub const HTTPThread = struct { } if (this.has_awoken.load(.Monotonic)) - this.loop.wakeup(); + this.loop.loop.wakeup(); } }; @@ -2177,7 +2172,7 @@ pub fn doRedirect(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPContext this.state.response_message_buffer.deinit(); // we need to clean the client reference before closing the socket because we are going to reuse the same ref in a another request - socket.ext(**anyopaque).?.* = bun.cast( + socket.ext(**anyopaque).* = bun.cast( **anyopaque, NewHTTPContext(is_ssl).ActiveSocket.init(&dead_socket).ptr(), ); @@ -2584,7 +2579,7 @@ pub fn closeAndFail(this: *HTTPClient, err: anyerror, comptime is_ssl: bool, soc if (this.state.stage != .fail and this.state.stage != .done) { log("closeAndFail: {s}", .{@errorName(err)}); if (!socket.isClosed()) { - socket.ext(**anyopaque).?.* = bun.cast( + socket.ext(**anyopaque).* = bun.cast( **anyopaque, NewHTTPContext(is_ssl).ActiveSocket.init(&dead_socket).ptr(), ); @@ -2998,7 +2993,7 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon const callback = this.result_callback; if (is_done) { - socket.ext(**anyopaque).?.* = bun.cast(**anyopaque, NewHTTPContext(is_ssl).ActiveSocket.init(&dead_socket).ptr()); + socket.ext(**anyopaque).* = bun.cast(**anyopaque, NewHTTPContext(is_ssl).ActiveSocket.init(&dead_socket).ptr()); if (this.isKeepAlivePossible() and !socket.isClosedOrHasError()) { ctx.releaseSocket( diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 4efb35edae3f01..f8a1ac26e5296e 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -358,12 +358,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { var tcp = this.tcp orelse return; this.tcp = null; - if (!tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), tcp.socket); - } else { + if (tcp.isEstablished()) { tcp.shutdown(); - tcp.close(0, null); } + tcp.close(0, null); } pub fn fail(this: *HTTPClient, code: ErrorCode) void { @@ -426,14 +424,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleOpen(this: *HTTPClient, socket: Socket) void { log("onOpen", .{}); - bun.assert(socket.socket == this.tcp.?.socket); + this.tcp = socket; bun.assert(this.input_body_buf.len > 0); bun.assert(this.to_send.len == 0); if (comptime ssl) { if (this.hostname.len > 0) { - socket.getNativeHandle().configureHTTPClient(this.hostname); + socket.getNativeHandle().?.configureHTTPClient(this.hostname); bun.default_allocator.free(this.hostname); this.hostname = ""; } @@ -456,7 +454,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return; } - bun.assert(socket.socket == this.tcp.?.socket); + bun.assert(socket.socket.eq(this.tcp.?.socket)); if (comptime Environment.allow_assert) bun.assert(!socket.isShutdown()); @@ -498,7 +496,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleEnd(this: *HTTPClient, socket: Socket) void { log("onEnd", .{}); - bun.assert(socket.socket == this.tcp.?.socket); + bun.assert(socket.socket.eq(this.tcp.?.socket)); this.terminate(ErrorCode.ended); } @@ -617,14 +615,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.tcp.?.timeout(0); log("onDidConnect", .{}); - this.outgoing_websocket.?.didConnect(this.tcp.?.socket, overflow.ptr, overflow.len); + this.outgoing_websocket.?.didConnect(this.tcp.?.socket.get().?, overflow.ptr, overflow.len); } pub fn handleWritable( this: *HTTPClient, socket: Socket, ) void { - bun.assert(socket.socket == this.tcp.?.socket); + bun.assert(socket.socket.eq(this.tcp.?.socket)); if (this.to_send.len == 0) return; @@ -646,9 +644,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { // In theory, this could be called immediately // In that case, we set `state` to `failed` and return, expecting the parent to call `destroy`. - pub fn handleConnectError(this: *HTTPClient, socket: Socket, _: c_int) void { + pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { this.tcp = null; - _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), socket.socket); + + // the socket is freed by usockets when the connection fails if (this.state == .reading) { this.terminate(ErrorCode.failed_to_connect); @@ -1029,11 +1028,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { if (this.tcp.isClosed() or this.tcp.isShutdown()) return; - if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); - } else { - this.tcp.close(0, null); - } + this.tcp.close(0, null); } pub fn fail(this: *WebSocket, code: ErrorCode) void { @@ -1668,7 +1663,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn handleEnd(this: *WebSocket, socket: Socket) void { - bun.assert(socket.socket == this.tcp.socket); + bun.assert(socket.socket.eq(this.tcp.socket)); this.terminate(ErrorCode.ended); } @@ -1677,7 +1672,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { socket: Socket, ) void { if (this.close_received) return; - bun.assert(socket.socket == this.tcp.socket); + bun.assert(socket.socket.eq(this.tcp.socket)); const send_buf = this.send_buffer.readableSlice(0); if (send_buf.len == 0) return; diff --git a/src/js/internal-for-testing.ts b/src/js/internal-for-testing.ts index 53cf5877564bd6..c25b9496b5d39f 100644 --- a/src/js/internal-for-testing.ts +++ b/src/js/internal-for-testing.ts @@ -55,3 +55,12 @@ export const nativeFrameForTesting: (callback: () => void) => void = $cpp( "CallSite.cpp", "createNativeFrameForTesting", ); + +export const dnsCacheStats = $zig("dns_resolver.zig", "getDNSCacheStats") as () => { + hits_completed: number; + hits_inflight: number; + size: number; + misses: number; + errors: number; + getaddrinfo: number; +}; diff --git a/src/lock.zig b/src/lock.zig index bb9f4165881f95..76d35810eab3c6 100644 --- a/src/lock.zig +++ b/src/lock.zig @@ -4,9 +4,9 @@ const Futex = @import("./futex.zig"); // Credit: this is copypasta from @kprotty. Thank you @kprotty! pub const Mutex = struct { - state: Atomic(u32) = Atomic(u32).init(UNLOCKED), + state: Atomic(u32) = Atomic(u32).init(UNLOCKED), // if changed update loop.c in usockets - const UNLOCKED = 0; + const UNLOCKED = 0; // if changed update loop.c in usockets const LOCKED = 0b01; const CONTENDED = 0b11; const is_x86 = @import("builtin").target.cpu.arch.isX86(); @@ -125,3 +125,11 @@ pub const Lock = struct { }; pub fn spinCycle() void {} + +export fn Bun__lock(lock: *Lock) void { + lock.lock(); +} + +export fn Bun__unlock(lock: *Lock) void { + lock.unlock(); +} diff --git a/test/js/node/net/node-net-server.test.ts b/test/js/node/net/node-net-server.test.ts index 8303b636716dad..c3816677c86eab 100644 --- a/test/js/node/net/node-net-server.test.ts +++ b/test/js/node/net/node-net-server.test.ts @@ -282,14 +282,9 @@ describe("net.createServer listen", () => { err = e as Error; } - if (process.platform !== "win32") { - expect(err).not.toBeNull(); - expect(err!.message).toBe("Failed to connect"); - expect(err!.name).toBe("ECONNREFUSED"); - } else { - // Bun allows this to work on Windows - expect(err).toBeNull(); - } + expect(err).not.toBeNull(); + expect(err!.message).toBe("Failed to connect"); + expect(err!.name).toBe("ECONNREFUSED"); server.close(); done();