Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add loom Support. #159

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
java: [ '11', '17', '21' ]
name: Tests for Java ${{ matrix.Java }}
name: Tests local for Java ${{ matrix.Java }}
steps:
- uses: actions/checkout@v3
- name: Setup java
Expand All @@ -27,14 +27,18 @@ jobs:
- name: Run tests
run: |
set -eux
./mill -ikj1 --disable-ticker __.testLocal
if [ "${{ matrix.java }}" == "21" ]; then
JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker __.testLocal
else
./mill -ikj1 --disable-ticker __.testLocal
fi
test-examples:
runs-on: ubuntu-latest
strategy:
matrix:
java: [ '11', '17', '21' ]
name: Tests for Java ${{ matrix.Java }}
name: Tests examples for Java ${{ matrix.Java }}
steps:
- uses: actions/checkout@v3
- name: Setup java
Expand All @@ -45,8 +49,13 @@ jobs:
- name: Run tests
run: |
set -eux
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples
if [ "${{ matrix.java }}" == "21" ]; then
./mill __.publishLocal
JAVA_OPTS='--add-opens java.base/java.lang=ALL-UNNAMED -Dcask.virtual-threads.enabled=true' ./mill -ikj1 --disable-ticker testExamples
else
./mill __.publishLocal
./mill -ikj1 --disable-ticker testExamples
fi
publish-sonatype:
if: github.repository == 'com-lihaoyi/cask' && contains(github.ref, 'refs/tags/')
Expand Down
86 changes: 86 additions & 0 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,89 @@ object cask extends Cross[CaskMainModule](scalaVersions) {
}
}

trait BenchmarkModule extends CrossScalaModule {
def moduleDeps = Seq(cask(crossScalaVersion))
def ivyDeps = Agg[Dep](
)
}

object benchmark extends Cross[BenchmarkModule](build.scalaVersions) with RunModule {

def waitForServer(url: String, maxAttempts: Int = 120): Boolean = {
(1 to maxAttempts).exists { attempt =>
try {
Thread.sleep(3000)
println("Checking server... Attempt " + attempt)
os.proc("curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", url)
.call(check = false)
.exitCode == 0
} catch {
case _: Throwable =>
Thread.sleep(3000)
false
}
}
}

def runBenchMark(projectRoot: os.Path, example: String, vt: Boolean) = {
def runMillBackground(example: String, vt: Boolean) = {
println(s"Running $example with vt: $vt")
println("projectRoot: " + projectRoot)
os.proc(
"mill",
s"example.$example.app[$scala213].run")
.spawn(
cwd = projectRoot,
env = Map("CASK_VIRTUAL_THREAD" -> vt.toString),
stdout = os.Inherit,
stderr = os.Inherit)
}

val duration = "30s"
val threads = "4"
val connections = "100"
val url = "http://localhost:8080/"
val serverApp = runMillBackground(example, vt)

println(s"Waiting for server to start..., vt:$vt")
if (!waitForServer(url)) {
serverApp.destroy()
println("Failed to start server")
sys.exit(1)
}

val results = os.proc("wrk",
"-t", threads,
"-c", connections,
"-d", duration,
url
).call(stderr = os.Pipe)
serverApp.destroyForcibly()
Thread.sleep(1000)

println(s"""\n$example result with ${if (vt) "(virtual threads)" else "(platform threads)"}:""")
println(results.out.text())
}

def runBenchmarks() = T.command {
val projectRoot = T.workspace
if (os.proc("which", "wrk").call(check = false).exitCode != 0) {
println("Error: wrk is not installed. Please install wrk first.")
sys.exit(1)
}
for (example <- Seq(
"staticFilesWithLoom",
"todoDbWithLoom",
"minimalApplicationWithLoom")) {
println(s"target server started, starting run benchmark with wrk for :$example with VT:false")
runBenchMark(projectRoot, example, vt = false)
println(s"target server started, starting run benchmark with wrk for :$example with VT:true")
runBenchMark(projectRoot, example, vt = true)
}

}
}

trait LocalModule extends CrossScalaModule{
override def millSourcePath = super.millSourcePath / "app"
def moduleDeps = Seq(cask(crossScalaVersion))
Expand All @@ -111,13 +194,16 @@ def zippedExamples = T {
build.example.httpMethods.millSourcePath,
build.example.minimalApplication.millSourcePath,
build.example.minimalApplication2.millSourcePath,
build.example.minimalApplicationWithLoom.millSourcePath,
build.example.redirectAbort.millSourcePath,
build.example.scalatags.millSourcePath,
build.example.staticFiles.millSourcePath,
build.example.staticFilesWithLoom.millSourcePath,
build.example.staticFiles2.millSourcePath,
build.example.todo.millSourcePath,
build.example.todoApi.millSourcePath,
build.example.todoDb.millSourcePath,
build.example.todoDbWithLoom.millSourcePath,
build.example.twirl.millSourcePath,
build.example.variableRoutes.millSourcePath,
build.example.queryParams.millSourcePath,
Expand Down
18 changes: 18 additions & 0 deletions cask/src/cask/internal/ThreadBlockingHandler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cask.internal

import io.undertow.server.{HttpHandler, HttpServerExchange}

import java.util.concurrent.Executor

/**
* A handler that dispatches the request to the given handler using the given executor.
* */
final class ThreadBlockingHandler(executor: Executor, handler: HttpHandler) extends HttpHandler {
require(executor ne null, "Executor should not be null")
require(handler ne null, "Handler should not be null")

def handleRequest(exchange: HttpServerExchange): Unit = {
exchange.startBlocking()
exchange.dispatch(executor, handler)
}
}
140 changes: 120 additions & 20 deletions cask/src/cask/internal/Util.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,121 @@
package cask.internal

import java.io.{InputStream, PrintWriter, StringWriter}

import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import java.io.OutputStream

import java.lang.invoke.{MethodHandles, MethodType}
import java.util.concurrent.{Executor, ExecutorService, ForkJoinPool, ThreadFactory}
import scala.annotation.switch
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.Try
import scala.util.control.NonFatal

object Util {
private val lookup = MethodHandles.lookup()

import cask.util.Logger.Console.globalLogger

/**
* Create a virtual thread executor with the given executor as the scheduler.
* */
def createVirtualThreadExecutor(executor: Executor): Option[ExecutorService] = {
(for {
factory <- Try(createVirtualThreadFactory("cask-handler-executor", executor))
executor <- Try(createNewThreadPerTaskExecutor(factory))
} yield executor).toOption
}

/**
* Create a default cask virtual thread executor if possible.
* */
def createDefaultCaskVirtualThreadExecutor: Option[ExecutorService] = {
for {
scheduler <- getDefaultVirtualThreadScheduler
executor <- createVirtualThreadExecutor(scheduler)
} yield executor
}

/**
* Try to get the default virtual thread scheduler, or null if not supported.
* */
def getDefaultVirtualThreadScheduler: Option[ForkJoinPool] = {
try {
val virtualThreadClass = Class.forName("java.lang.VirtualThread")
val privateLookup = MethodHandles.privateLookupIn(virtualThreadClass, lookup)
val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, "DEFAULT_SCHEDULER", classOf[ForkJoinPool])
Option(defaultSchedulerField.get().asInstanceOf[ForkJoinPool])
} catch {
case NonFatal(e) =>
//--add-opens java.base/java.lang=ALL-UNNAMED
globalLogger.exception(e)
None
}
}

def createNewThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory)
.asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
Comment on lines +67 to +68

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exception is both logged and re-thrown.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm throwing it on purpose.

}
}

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
*
* The executor should run task on platform threads.
*
* returns null if not supported.
*/
def createVirtualThreadFactory(prefix: String,
executor: Executor): ThreadFactory =
try {
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = lookup.findStatic(classOf[Thread], "ofVirtual", MethodType.methodType(ofVirtualClass))
var builder = ofVirtualMethod.invoke()
if (executor != null) {
val clazz = builder.getClass
val privateLookup = MethodHandles.privateLookupIn(
clazz,
lookup
)
val schedulerFieldSetter = privateLookup
.findSetter(clazz, "scheduler", classOf[Executor])
schedulerFieldSetter.invoke(builder, executor)
}
val nameMethod = lookup.findVirtual(ofVirtualClass, "name",
MethodType.methodType(ofVirtualClass, classOf[String], classOf[Long]))
val factoryMethod = lookup.findVirtual(builderClass, "factory", MethodType.methodType(classOf[ThreadFactory]))
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", 0L)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
globalLogger.exception(e)
//--add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create virtual thread factory.", e)
}

def firstFutureOf[T](futures: Seq[Future[T]])(implicit ec: ExecutionContext) = {
val p = Promise[T]
futures.foreach(_.foreach(p.trySuccess))
p.future
}

/**
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
* Convert a string to a C&P-able literal. Basically
* copied verbatim from the uPickle source code.
*/
def literalize(s: IndexedSeq[Char], unicode: Boolean = true) = {
val sb = new StringBuilder
sb.append('"')
Expand Down Expand Up @@ -47,29 +144,30 @@ object Util {
def transferTo(in: InputStream, out: OutputStream) = {
val buffer = new Array[Byte](8192)

while ({
in.read(buffer) match{
while ( {
in.read(buffer) match {
case -1 => false
case n =>
out.write(buffer, 0, n)
true
}
}) ()
}

def pluralize(s: String, n: Int) = {
if (n == 1) s else s + "s"
}

/**
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
* Splits a string into path segments; automatically removes all
* leading/trailing slashes, and ignores empty path segments.
*
* Written imperatively for performance since it's used all over the place.
*/
def splitPath(p: String): collection.IndexedSeq[String] = {
val pLength = p.length
var i = 0
while(i < pLength && p(i) == '/') i += 1
while (i < pLength && p(i) == '/') i += 1
var segmentStart = i
val out = mutable.ArrayBuffer.empty[String]

Expand All @@ -81,7 +179,7 @@ object Util {
segmentStart = i + 1
}

while(i < pLength){
while (i < pLength) {
if (p(i) == '/') complete()
i += 1
}
Expand All @@ -96,33 +194,35 @@ object Util {
pw.flush()
trace.toString
}

def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
val oneLine = s.linesIterator.mkString(" ").split(' ')

lazy val indent = " " * leftOffset

val output = new StringBuilder(oneLine.head)
var currentLineWidth = oneLine.head.length
for(chunk <- oneLine.tail){
for (chunk <- oneLine.tail) {
val addedWidth = currentLineWidth + chunk.length + 1
if (addedWidth > maxWidth){
if (addedWidth > maxWidth) {
output.append("\n" + indent)
output.append(chunk)
currentLineWidth = chunk.length
} else{
} else {
currentLineWidth = addedWidth
output.append(' ')
output.append(chunk)
}
}
output.mkString
}

def sequenceEither[A, B, M[X] <: TraversableOnce[X]](in: M[Either[A, B]])(
implicit cbf: CanBuildFrom[M[Either[A, B]], B, M[B]]): Either[A, M[B]] = {
in.foldLeft[Either[A, mutable.Builder[B, M[B]]]](Right(cbf(in))) {
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
case (acc, el) =>
for (a <- acc; e <- el) yield a += e
}
.map(_.result())
}
}
Loading
Loading