Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use loom-compatibility, avoid hacking into JDK #163

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ trait CaskMainModule extends CaskModule {
def ivyDeps = T{
Agg(
ivy"io.undertow:undertow-core:2.3.18.Final",
ivy"com.lihaoyi::upickle:4.0.2"
ivy"com.lihaoyi::upickle:4.0.2",
ivy"com.github.sideeffffect:loom-compatibility:0.2.0",
) ++
Agg.when(!isScala3)(ivy"org.scala-lang:scala-reflect:$crossScalaVersion")
}
Expand Down
101 changes: 11 additions & 90 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,111 +1,32 @@
package cask.internal

import com.github.sideeffffect.loom_compatibility.{LoomExecutors, LoomThread, LoomUnavailable}

import java.io.{InputStream, PrintWriter, StringWriter}
import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ForkJoinPool, ThreadFactory}
import java.util.concurrent.{ExecutorService, ForkJoinPool}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup()

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[ExecutorService] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

/**
* Create a default cask virtual thread executor if possible.
* */
def createDefaultCaskVirtualThreadExecutor: Option[ExecutorService] = {
for {
scheduler <- getDefaultVirtualThreadScheduler
executor <- createVirtualThreadExecutor(scheduler)
} yield executor
}

/**
* Try to get the default virtual thread scheduler, or null if not supported.
* */
def getDefaultVirtualThreadScheduler: Option[ForkJoinPool] = {
try {
val virtualThreadClass = Class.forName("java.lang.VirtualThread")
val privateLookup = MethodHandles.privateLookupIn(virtualThreadClass, lookup)
val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, "DEFAULT_SCHEDULER", classOf[ForkJoinPool])
Option(defaultSchedulerField.get().asInstanceOf[ForkJoinPool])
} catch {
case NonFatal(e) =>
//--add-opens java.base/java.lang=ALL-UNNAMED
globalLogger.exception(e)
None
}
def createDefaultCaskVirtualThreadExecutor: Option[ExecutorService] = try {
val loomThread = LoomThread.load()
val loomExecutors = LoomExecutors.load()
val factory = loomThread.ofVirtual().name("cask-virtual-thread-", 0L).factory()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@He-Pin the threads are named, as you can see here.

val executor = loomExecutors.newThreadPerTaskExecutor(factory)
Some(executor)
} catch {
case _: LoomUnavailable => None
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
Expand Down