diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala index ee5b672560..d2fdf42def 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyConnectionPool.scala @@ -16,7 +16,7 @@ package zio.http.netty.client -import java.net.InetSocketAddress +import java.net.{Inet6Address, InetAddress, InetSocketAddress} import java.util.concurrent.TimeUnit import zio._ @@ -36,6 +36,8 @@ private[netty] trait NettyConnectionPool extends ConnectionPool[JChannel] private[netty] object NettyConnectionPool { + private val HappyEyeballsDelay: Duration = 250.millis + protected def createChannel( channelFactory: JChannelFactory[JChannel], eventLoopGroup: JEventLoopGroup, @@ -106,38 +108,150 @@ private[netty] object NettyConnectionPool { for { resolvedHosts <- dnsResolver.resolve(location.host) - hosts <- Random.shuffle(resolvedHosts.toList) - hostsNec <- ZIO.succeed(NonEmptyChunk.fromIterable(hosts.head, hosts.tail)) - ch <- collectFirstSuccess(hostsNec) { host => - ZIO.suspend { - val bootstrap = new Bootstrap() - .channelFactory(channelFactory) - .group(eventLoopGroup) - .remoteAddress(new InetSocketAddress(host, location.port)) - .withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt)) - .handler(initializer) - localAddress.foreach(bootstrap.localAddress) - - val channelFuture = bootstrap.connect() - val ch = channelFuture.channel() - Scope.addFinalizer { - NettyFutureExecutor.executed { - channelFuture.cancel(true) - ch.close() - }.when(ch.isOpen).ignoreLogged - } *> NettyFutureExecutor.executed(channelFuture).as(ch) - } - } + ch <- + // Use Happy Eyeballs algorithm + happyEyeballsConnect( + resolvedHosts, + channelFactory, + eventLoopGroup, + location, + initializer, + connectionTimeout, + localAddress, + ) } yield ch } - private def collectFirstSuccess[R, E, A, B]( - as: NonEmptyChunk[A], - )(f: A => ZIO[R, E, B])(implicit trace: Trace): ZIO[R, E, B] = { - ZIO.suspendSucceed { - val it = as.iterator - def loop: ZIO[R, E, B] = f(it.next()).catchAll(e => if (it.hasNext) loop else ZIO.fail(e)) - loop + /** + * Attempts to connect to a single address. + */ + private def connectToAddress( + host: InetAddress, + channelFactory: JChannelFactory[JChannel], + eventLoopGroup: JEventLoopGroup, + location: URL.Location.Absolute, + initializer: ChannelInitializer[JChannel], + connectionTimeout: Option[Duration], + localAddress: Option[InetSocketAddress], + )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = { + ZIO.suspend { + val bootstrap = new Bootstrap() + .channelFactory(channelFactory) + .group(eventLoopGroup) + .remoteAddress(new InetSocketAddress(host, location.port)) + .withOption[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout.map(_.toMillis.toInt)) + .handler(initializer) + localAddress.foreach(bootstrap.localAddress) + + val channelFuture = bootstrap.connect() + val ch = channelFuture.channel() + Scope.addFinalizer { + NettyFutureExecutor.executed { + channelFuture.cancel(true) + ch.close() + }.whenDiscard(ch.isOpen).ignoreLogged + }.uninterruptible *> NettyFutureExecutor.executed(channelFuture).as(ch) + } + } + + /** + * Returns a sequence of connection attempts with their delays. Per RFC 8305, + * we start with IPv6, then after firstAddressFamilyDelay we try IPv4, then + * alternate between families. + */ + private def sortAddresses(resolvedHosts: Chunk[InetAddress]): List[InetAddress] = { + val (ipv6Addresses, ipv4Addresses) = resolvedHosts.partition(_.isInstanceOf[Inet6Address]) + val ipv6Iter = ipv6Addresses.iterator + val ipv4Iter = ipv4Addresses.iterator + val builder = List.newBuilder[InetAddress] + builder.sizeHint(resolvedHosts.size) + + // Alternate between families + var useIpv6 = true + while (ipv6Iter.hasNext || ipv4Iter.hasNext) { + + if (useIpv6 && ipv6Iter.hasNext) { + builder += ipv6Iter.next() + } else if (ipv4Iter.hasNext) { + builder += ipv4Iter.next() + } else if (ipv6Iter.hasNext) { + builder += ipv6Iter.next() + } + + useIpv6 = !useIpv6 + } + + builder.result() + } + + /** + * Implements Happy Eyeballs (RFC 8305) connection algorithm. Races connection + * attempts to IPv6 and IPv4 addresses with staggered delays. + */ + private def happyEyeballsConnect( + resolvedHosts: Chunk[InetAddress], + channelFactory: JChannelFactory[JChannel], + eventLoopGroup: JEventLoopGroup, + location: URL.Location.Absolute, + initializer: ChannelInitializer[JChannel], + connectionTimeout: Option[Duration], + localAddress: Option[InetSocketAddress], + )(implicit trace: Trace): ZIO[Scope, Throwable, JChannel] = { + + if (resolvedHosts.isEmpty) { + ZIO.fail(new RuntimeException("No addresses to connect to")) + } else if (resolvedHosts.size == 1) { + connectToAddress( + resolvedHosts.head, + channelFactory, + eventLoopGroup, + location, + initializer, + connectionTimeout, + localAddress, + ) + } else { + val addresses = sortAddresses(resolvedHosts) + for { + lastFailed <- Queue.bounded[Unit](requestedCapacity = 1) + successful <- Ref.make(List.empty[JChannel]) + _ <- ZIO.raceAll( + connectToAddress( + addresses.head, + channelFactory, + eventLoopGroup, + location, + initializer, + connectionTimeout, + localAddress, + ).onExit { + case e: Exit.Success[JChannel] => successful.update(channels => channels :+ e.value) + case _: Exit.Failure[_] => lastFailed.offer(()) + }, + addresses.tail.zipWithIndex.map { case (address, index) => + ZIO.sleep(HappyEyeballsDelay * index.toDouble).raceFirst(lastFailed.take).ignore *> + connectToAddress( + address, + channelFactory, + eventLoopGroup, + location, + initializer, + connectionTimeout, + localAddress, + ).onExit { + case e: Exit.Success[JChannel] => successful.update(channels => channels :+ e.value) + case _: Exit.Failure[_] => lastFailed.offer(()) + } + }, + ) + channels <- successful.get + channel <- channels.headOption match { + case ch: Some[JChannel] => ZIO.succeed(ch.value) + case None => ZIO.fail(new RuntimeException("All connection attempts failed")) + } + _ <- ZIO.foreachDiscard(channels.tail)(ch => ZIO.attempt(ch.close())) + } yield channel + } }