Skip to content

Commit

Permalink
[RFC] Parameterize context type of decorators (#137)
Browse files Browse the repository at this point in the history
This allows customizing the way by-type parameters can be read: instead
of always reading data from a `cask.Request`, this allows a decorator to
parameterize the context type and pass it in explicitly from the
`wrapFunction` method.

Essentially, this makes the context parameter of `ArgReader`s (which are
used to translate data from a HTTP request to a scala parameter)
customizable for every decorator.

## Motivation

The idea behind this proposal is to uniformize the way "named" and
"by-type" parameters within the same endpoint are handled. By "named"
parameters I mean parameters which are set from the `Map[String, Input]`
in the delegate function, and by `by-type` parameters I mean parameters
which are computed from an arity zero `ArgReader`, and thus use the
`cask.Request` context to compute their value.

E.g. 

```
@cask.get(/:foo/:bar)
def index(foo: String, bar: Int, cookie1: cask.Cookies, req: cask.Request)
```

In this case, `foo` and `bar` and "named" and `cookie1` and `req` are
"by-type".

Right now, the `wrapFunction` method handles how named parameters are
set, and therefore can centrally do arbitrary pre- and post-processing.
However, by-type parameters must always use a `cask.Request` to compute
any values, which can be problematic if the computation is expensive or
if any kind of state is maintained.

**This pull request removes this asymmetry, by allowing the
`wrapFunction` to compute a custom context which is computed once before
being passed to all `InputReader`s**

## Implementation approach

- add `InputContext` type parameter to decorators
- add a parameter of type `InputContext` to `delegate`. `type Delegate =
(InputContext, Map[String, Input)) => ...`
- change `Decorator.invoke()` to carry a list of input contexts, one for
each parameter list
- also change the `Entrypoint` macros to use these input contexts rather
than hardcoding `cask.Request`
  • Loading branch information
jodersky authored Nov 3, 2024
1 parent e7fcba1 commit d823a8a
Show file tree
Hide file tree
Showing 26 changed files with 208 additions and 103 deletions.
2 changes: 2 additions & 0 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def zippedExamples = T {
build.example.cookies.millSourcePath,
build.example.decorated.millSourcePath,
build.example.decorated2.millSourcePath,
build.example.decoratedContext.millSourcePath,
build.example.endpoints.millSourcePath,
build.example.formJsonPost.millSourcePath,
build.example.httpMethods.millSourcePath,
Expand Down Expand Up @@ -143,6 +144,7 @@ def zippedExamples = T {
.replaceFirst(
"object app extends.*\ntrait AppModule extends CrossScalaModule(.*)\\{",
s"object app extends ScalaModule $$1\\{\n def scalaVersion = \"${scala213}\"")
.replaceAll("build.scala3", s"\"${scala3}\"")
.replaceFirst(
"def ivyDeps = Agg\\[Dep\\]\\(",
"def ivyDeps = Agg(\n ivy\"com.lihaoyi::cask:" + releaseTag + "\","
Expand Down
2 changes: 1 addition & 1 deletion cask/src-2/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import language.experimental.macros

trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
if (metadata0 != null) metadata0
Expand Down
15 changes: 7 additions & 8 deletions cask/src-2/cask/router/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class Macros[C <: blackbox.Context](val c: C) {
def extractMethod(method: MethodSymbol,
curCls: c.universe.Type,
convertToResultType: c.Tree,
ctx: c.Tree,
argReaders: Seq[c.Tree],
annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = {
val baseArgSym = TermName(c.freshName())
Expand All @@ -64,7 +63,7 @@ class Macros[C <: blackbox.Context](val c: C) {
val ctxSymbol = q"${c.fresh[TermName](TermName("ctx"))}"
val argData = for(argListIndex <- method.paramLists.indices) yield{
val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any")
val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAny")
val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAnyRequest")
val flattenedArgLists = method.paramss(argListIndex)
def hasDefault(i: Int) = {
// defaults are numbered globally on a class-level, this means that we
Expand Down Expand Up @@ -108,18 +107,18 @@ class Macros[C <: blackbox.Context](val c: C) {

val argSig =
q"""
cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx](
cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, Any](
${arg.name.toString},
${docUnwrappedType.toString},
$docTree,
$defaultOpt
)($argReader[$docUnwrappedType])
)($argReader[$docUnwrappedType].asInstanceOf[cask.router.ArgReader[$annotDeserializeType, $docUnwrappedType, Any]])
"""

val reader = q"""
cask.router.Runtime.makeReadCall(
$argValuesSymbol($argListIndex),
$ctxSymbol,
$ctxSymbol($argListIndex),
$default,
$argSigsSymbol($argListIndex)($i)
)
Expand Down Expand Up @@ -151,7 +150,7 @@ class Macros[C <: blackbox.Context](val c: C) {
for(argNameCast <- argNameCasts) methodCall = q"$methodCall(..$argNameCast)"

val res = q"""
cask.router.EntryPoint[$curCls, $ctx](
cask.router.EntryPoint[$curCls, Any](
${method.name.toString},
${argSigs.toList},
${methodDoc match{
Expand All @@ -160,9 +159,9 @@ class Macros[C <: blackbox.Context](val c: C) {
}},
(
$baseArgSym: $curCls,
$ctxSymbol: $ctx,
$ctxSymbol: Seq[_],
$argValuesSymbol: Seq[Map[String, Any]],
$argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, $ctx]]]
$argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, Any]]]
) =>
cask.router.Runtime.validate(Seq(..${readArgs.flatten.toList})).map{
case Seq(..${argNames.flatten.toList}) => $convertToResultType($methodCall)
Expand Down
7 changes: 3 additions & 4 deletions cask/src-2/cask/router/RoutesEndpointMetadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ object RoutesEndpointsMetadata{

val routeParts = for{
m <- c.weakTypeOf[T].members
annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _]])
annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _, _]])
if annotations.nonEmpty
} yield {
if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]])) c.abort(
if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]])) c.abort(
annotations.head.tree.pos,
s"Last annotation applied to a function must be an instance of Endpoint, " +
s"not ${annotations.last.tree.tpe}"
)
val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]])
val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]])
if(allEndpoints.length > 1) c.abort(
annotations.last.tree.pos,
s"You can only apply one Endpoint annotation to a function, not " +
Expand All @@ -49,7 +49,6 @@ object RoutesEndpointsMetadata{
m.asInstanceOf[MethodSymbol],
weakTypeOf[T],
q"${annotObjectSyms.last}.convertToResultType",
tq"cask.Request",
annotObjectSyms.reverse.map(annotObjectSym => q"$annotObjectSym.getParamParser"),
annotObjectSyms.reverse.map(annotObjectSym => tq"$annotObjectSym.InputTypeAlias")
)
Expand Down
2 changes: 1 addition & 1 deletion cask/src-3/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import language.experimental.macros

trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
if (metadata0 != null) metadata0
Expand Down
36 changes: 18 additions & 18 deletions cask/src-3/cask/router/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ object Macros {
* This replicates EndpointMetadata.seqify, but in a macro where error
* positions can be controlled.
*/
def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _]]]): Boolean = {
def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _, _]]]): Boolean = {
import quotes.reflect._

var hasErrors = false

def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _]]]): Unit =
def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _, _]]]): Unit =
decorators match {
case Nil =>
case '{ $d: Decorator[outer, inner, _] } :: tail =>
case '{ $d: Decorator[outer, inner, _, _] } :: tail =>
if (TypeRepr.of[inner] <:< prevOuter) {
check(TypeRepr.of[outer], tail)
} else {
hasErrors = true
report.error(
s"required: cask.router.Decorator[_, ${prevOuter.show}, _]",
s"required: cask.router.Decorator[_, ${prevOuter.show}, _, _]",
d
)
}
Expand Down Expand Up @@ -56,7 +56,7 @@ object Macros {

/** Summon the reader for a parameter. */
def summonReader(using Quotes)(
decorator: Expr[Decorator[_,_,_]],
decorator: Expr[Decorator[_,_,_,_]],
param: quotes.reflect.Symbol
): Expr[ArgReader[_, _, _]] = {
import quotes.reflect._
Expand Down Expand Up @@ -143,13 +143,13 @@ object Macros {
*/
def convertToResponse(using Quotes)(
method: quotes.reflect.Symbol,
endpoint: Expr[Endpoint[_, _, _]],
endpoint: Expr[Endpoint[_, _, _, _]],
result: Expr[Any]
): Expr[Any] = {
import quotes.reflect._

val innerReturnedTpt = endpoint.asTerm.tpe.asType match {
case '[Endpoint[_, innerReturned, _]] => TypeRepr.of[innerReturned]
case '[Endpoint[_, innerReturned, _, _]] => TypeRepr.of[innerReturned]
case _ => ???
}

Expand Down Expand Up @@ -186,9 +186,9 @@ object Macros {

def extractMethod[Cls: Type](using q: Quotes)(
method: quotes.reflect.Symbol,
decorators: List[Expr[Decorator[_, _, _]]], // these must also include the endpoint
endpoint: Expr[Endpoint[_, _, _]]
): Expr[EntryPoint[Cls, cask.Request]] = {
decorators: List[Expr[Decorator[_, _, _, _]]], // these must also include the endpoint
endpoint: Expr[Endpoint[_, _, _, _]]
): Expr[EntryPoint[Cls, Any]] = {
import quotes.reflect._

val defaults = getDefaultParams(method)
Expand All @@ -198,7 +198,7 @@ object Macros {

// sometimes we have more params than annotated decorators, for example if
// there are global decorators
val decorator: Option[Expr[Decorator[_, _, _]]] = decorators.lift(idx)
val decorator: Option[Expr[Decorator[_, _, _, _]]] = decorators.lift(idx)

val exprs1 = for (param <- params) yield {
val paramTree = param.tree.asInstanceOf[ValDef]
Expand Down Expand Up @@ -231,35 +231,35 @@ object Macros {
case Some(deco) => summonReader(deco, param)
case None =>
decoTpe match
case '[t] => '{ NoOpParser.instanceAny[t] }
case '[t] => '{ NoOpParser.instanceAnyRequest[t] } // TODO
}

'{
ArgSig[Any, Cls, Any, cask.Request](
ArgSig[Any, Cls, Any, Any](
${Expr(param.name)},
${Expr(paramTpeName)},
doc = None, // TODO
default = ${defaultGetter}
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, cask.Request]])
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, Any]])
}
}
Expr.ofList(exprs1)
}
val sigExprs = Expr.ofList(exprs0)

'{
EntryPoint[Cls, cask.Request](
EntryPoint[Cls, Any](
name = ${Expr(method.name)},
argSignatures = $sigExprs,
doc = None, // TODO
invoke0 = (
clazz: Cls,
ctx: cask.Request,
ctxs: Seq[Any],
argss: Seq[Map[String, Any]],
sigss: Seq[Seq[ArgSig[Any, _, _, cask.Request]]]
sigss: Seq[Seq[ArgSig[Any, _, _, Any]]]
) => {
val parsedArgss: Seq[Seq[Either[Seq[cask.router.Result.ParamError], Any]]] =
sigss.zip(argss).map{ case (sigs, args) =>
(sigss, argss, ctxs).zipped.map { case (sigs, args, ctx) =>
sigs.map{ case sig =>
Runtime.makeReadCall(
args,
Expand Down
12 changes: 6 additions & 6 deletions cask/src-3/cask/router/RoutesEndpointMetadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ object RoutesEndpointsMetadata{

val routeParts: List[Expr[EndpointMetadata[T]]] = for {
m <- TypeRepr.of[T].typeSymbol.memberMethods
annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _]])
annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _, _]])
if (annotations.nonEmpty)
} yield {

if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _]])) {
if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]])) {
report.error(s"Last annotation applied to a function must be an instance of Endpoint, " +
s"not ${annotations.head.tpe.show}",
annotations.head.pos
)
return '{???} // in this case, we can't continue expansion of this macro
}
val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _]])
val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]])
if(allEndpoints.length > 1) {
report.error(
s"You can only apply one Endpoint annotation to a function, not " +
Expand All @@ -41,16 +41,16 @@ object RoutesEndpointsMetadata{
return '{???}
}

val decorators = annotations.map(_.asExprOf[Decorator[_, _, _]])
val decorators = annotations.map(_.asExprOf[Decorator[_, _, _, _]])

if (!Macros.checkDecorators(decorators))
return '{???} // there was a type mismatch in the decorator chain

val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _]]
val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _, _]]
val entrypointExpr = Macros.extractMethod[T](m, decorators, endpointExpr)

'{
val entrypoint: EntryPoint[T, cask.Request] = ${entrypointExpr}
val entrypoint: EntryPoint[T, Any] = ${entrypointExpr}

EndpointMetadata[T](
// the Scala 2 version and non-macro code expects decorators to be reversed
Expand Down
2 changes: 1 addition & 1 deletion cask/src/cask/decorators/compress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class compress extends cask.RawDecorator{
.toSeq
.flatMap(_.asScala)
.flatMap(_.split(", "))
val finalResult = delegate(Map()).transform{ case v: cask.Response.Raw =>
val finalResult = delegate(ctx, Map()).transform{ case v: cask.Response.Raw =>
val (newData, newHeaders) = if (acceptEncodings.exists(_.toLowerCase == "gzip")) {
new Response.Data {
def write(out: OutputStream): Unit = {
Expand Down
1 change: 1 addition & 0 deletions cask/src/cask/endpoints/FormEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class postForm(val path: String, override val subpath: Boolean = false)
.createParser(ctx.exchange)
.parseBlocking()
delegate(
ctx,
formData
.iterator()
.asScala
Expand Down
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/JsonEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ abstract class postJsonBase(val path: String, override val subpath: Boolean = fa
} yield obj.toMap
obj match{
case Left(r) => Result.Success(r.map(Response.Data.WritableData(_)))
case Right(params) => delegate(params)
case Right(params) => delegate(ctx, params)
}
}
def wrapPathSegment(s: String): ujson.Value = ujson.Str(s)
Expand All @@ -78,7 +78,7 @@ class getJson(val path: String, override val subpath: Boolean = false)
type InputParser[T] = QueryParamReader[T]

def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}
def wrapPathSegment(s: String) = Seq(s)
}
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/StaticEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class staticFiles(val path: String, headers: Seq[(String, String)] = Nil) extend
type InputParser[T] = QueryParamReader[T]
override def subpath = true
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(Map()).map{t =>
delegate(ctx, Map()).map{t =>
val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx)
cask.model.StaticFile(path, headers ++ contentTypeOpt.map("Content-Type" -> _))
}
Expand All @@ -36,7 +36,7 @@ class staticResources(val path: String,
type InputParser[T] = QueryParamReader[T]
override def subpath = true
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(Map()).map { t =>
delegate(ctx, Map()).map { t =>
val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx)
cask.model.StaticResource(path, resourceRoot, headers ++ contentTypeOpt.map("Content-Type" -> _))
}
Expand Down
2 changes: 1 addition & 1 deletion cask/src/cask/endpoints/WebEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trait WebEndpoint extends HttpEndpoint[Response.Raw, Seq[String]]{
type InputParser[T] = QueryParamReader[T]
def wrapFunction(ctx: Request,
delegate: Delegate): Result[Response.Raw] = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}
def wrapPathSegment(s: String) = Seq(s)
}
Expand Down
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/WebSocketEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ object WebsocketResult{
}

class websocket(val path: String, override val subpath: Boolean = false)
extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String]]{
extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String], Request]{
val methods = Seq("websocket")
type InputParser[T] = QueryParamReader[T]
type OuterReturned = Result[WebsocketResult]
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}

def wrapPathSegment(s: String): Seq[String] = Seq(s)
Expand Down
5 changes: 3 additions & 2 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MainRoutes extends Main with Routes{
* application-wide properties.
*/
abstract class Main{
def mainDecorators: Seq[Decorator[_, _, _]] = Nil
def mainDecorators: Seq[Decorator[_, _, _, _]] = Nil
def allRoutes: Seq[Routes]
def port: Int = 8080
def host: String = "localhost"
Expand Down Expand Up @@ -74,7 +74,7 @@ abstract class Main{

object Main{
class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]],
mainDecorators: Seq[Decorator[_, _, _]],
mainDecorators: Seq[Decorator[_, _, _, _]],
debugMode: Boolean,
handleNotFound: Request => Response.Raw,
handleMethodNotAllowed: Request => Response.Raw,
Expand Down Expand Up @@ -120,6 +120,7 @@ object Main{
routes,
routeBindings,
(mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
Nil,
Nil
) match {
case Result.Success(res) => runner(res)
Expand Down
Loading

0 comments on commit d823a8a

Please sign in to comment.