Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use loom-compatibility, avoid hacking into JDK #163

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
set -eux
if [ "${{ matrix.java }}" == "21" ]; then
JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker __.testLocal
JAVA_OPTS='-Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker __.testLocal
else
./mill -ikj1 --disable-ticker __.testLocal
fi
Expand All @@ -51,7 +51,7 @@ jobs:
set -eux
if [ "${{ matrix.java }}" == "21" ]; then
./mill __.publishLocal
JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker testExamples
JAVA_OPTS='-Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker testExamples
else
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples
Expand Down
3 changes: 2 additions & 1 deletion build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ trait CaskMainModule extends CaskModule {
def ivyDeps = T{
Agg(
ivy"io.undertow:undertow-core:2.3.18.Final",
ivy"com.lihaoyi::upickle:4.0.2"
ivy"com.lihaoyi::upickle:4.0.2",
ivy"com.github.sideeffffect:loom-compatibility:0.2.0",
) ++
Agg.when(!isScala3)(ivy"org.scala-lang:scala-reflect:$crossScalaVersion")
}
Expand Down
101 changes: 11 additions & 90 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,111 +1,32 @@
package cask.internal

import com.github.sideeffffect.loom_compatibility.{LoomExecutors, LoomThread, LoomUnavailable}

import java.io.{InputStream, PrintWriter, StringWriter}
import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream
import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ForkJoinPool, ThreadFactory}
import java.util.concurrent.{ExecutorService, ForkJoinPool}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup()

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[ExecutorService] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

/**
* Create a default cask virtual thread executor if possible.
* */
def createDefaultCaskVirtualThreadExecutor: Option[ExecutorService] = {
for {
scheduler <- getDefaultVirtualThreadScheduler
executor <- createVirtualThreadExecutor(scheduler)
} yield executor
}

/**
* Try to get the default virtual thread scheduler, or null if not supported.
* */
def getDefaultVirtualThreadScheduler: Option[ForkJoinPool] = {
try {
val virtualThreadClass = Class.forName("java.lang.VirtualThread")
val privateLookup = MethodHandles.privateLookupIn(virtualThreadClass, lookup)
val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, "DEFAULT_SCHEDULER", classOf[ForkJoinPool])
Option(defaultSchedulerField.get().asInstanceOf[ForkJoinPool])
} catch {
case NonFatal(e) =>
//--add-opens java.base/java.lang=ALL-UNNAMED
globalLogger.exception(e)
None
}
def createDefaultCaskVirtualThreadExecutor: Option[ExecutorService] = try {
val loomThread = LoomThread.load()
val loomExecutors = LoomExecutors.load()
val factory = loomThread.ofVirtual().name("cask-virtual-thread-", 0L).factory()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@He-Pin the threads are named, as you can see here.

val executor = loomExecutors.newThreadPerTaskExecutor(factory)
Some(executor)
} catch {
case _: LoomUnavailable => None
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
Expand Down
5 changes: 2 additions & 3 deletions docs/pages/1 - Cask - a Scala HTTP micro-framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,8 @@ $$$minimalApplicationWithLoom
Cask can support using Virtual Threads to handle the request out of the box, you can enable it with the next steps:

1. Running cask with Java 21 or later
2. add `--add-opens java.base/java.lang=ALL-UNNAMED` to your JVM options, which is needed to name the virtual threads.
3. add `-Dcask.virtual-threads.enabled=true` to your JVM options, which is needed to enable the virtual threads.
4. tweak the underlying carrier threads with `-Djdk.virtualThreadScheduler.parallelism`, `jdk.virtualThreadScheduler.maxPoolSize` and `jdk.unparker.maxPoolSize`.
2. add `-Dcask.virtual-threads.enabled=true` to your JVM options, which is needed to enable the virtual threads.
3. tweak the underlying carrier threads with `-Djdk.virtualThreadScheduler.parallelism`, `jdk.virtualThreadScheduler.maxPoolSize` and `jdk.unparker.maxPoolSize`.

**Advanced Features**:

Expand Down
1 change: 0 additions & 1 deletion example/httpMethods/package.mill
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,5 @@ trait AppModule extends CrossScalaModule{
ivy"com.lihaoyi::utest::0.8.4",
ivy"com.lihaoyi::requests::0.9.0",
)
def forkArgs = Seq("--add-opens=java.base/java.net=ALL-UNNAMED")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ object MinimalApplicationWithLoom extends cask.MainRoutes {

//TO USE LOOM:
//1. JDK 21 or later is needed.
//2. add VM option: --add-opens java.base/java.lang=ALL-UNNAMED
//3. set system property: cask.virtual-threads.enabled=true
//4. NOTE: `java.util.concurrent.Executors.newVirtualThreadPerTaskExecutor` is using the shared
//2. set system property: cask.virtual-threads.enabled=true
//3. NOTE: `java.util.concurrent.Executors.newVirtualThreadPerTaskExecutor` is using the shared
// ForkJoinPool in VirtualThread. If you want to use a separate ForkJoinPool, you can create
// a new ForkJoinPool instance and pass it to `createVirtualThreadExecutor` method.

Expand Down
6 changes: 1 addition & 5 deletions example/minimalApplicationWithLoom/package.mill
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ trait AppModule extends CrossScalaModule{

val systemProps = Seq(s"-Dcask.virtual-threads.enabled=$envVirtualThread")

val baseArgs = Seq(
"--add-opens", "java.base/java.lang=ALL-UNNAMED"
)

val seq = baseArgs ++ systemProps
val seq = systemProps
println("final forkArgs: " + seq)
seq
}
Expand Down
6 changes: 1 addition & 5 deletions example/staticFilesWithLoom/package.mill
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ trait AppModule extends CrossScalaModule{ app =>

val systemProps = Seq(s"-Dcask.virtual-threads.enabled=$envVirtualThread")

val baseArgs = Seq(
"--add-opens", "java.base/java.lang=ALL-UNNAMED"
)

val seq = baseArgs ++ systemProps
val seq = systemProps
println("final forkArgs: " + seq)
seq
}
Expand Down
6 changes: 1 addition & 5 deletions example/todoDbWithLoom/package.mill
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ trait AppModule extends CrossScalaModule{

val systemProps = Seq(s"-Dcask.virtual-threads.enabled=$envVirtualThread")

val baseArgs = Seq(
"--add-opens", "java.base/java.lang=ALL-UNNAMED"
)

val seq = baseArgs ++ systemProps
val seq = systemProps
println("final forkArgs: " + seq)
seq
}
Expand Down