Skip to content

Commit 03887b7

Browse files
authored
Merge pull request #5775 from dotty-staging/write-replace
Fix #4440: Do not serialize the content of static objects
2 parents 47c51dd + 95900de commit 03887b7

16 files changed

+166
-21
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+7
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,9 @@ class Definitions {
575575
case List(pt) => (pt isRef StringClass)
576576
case _ => false
577577
}).symbol.asTerm
578+
578579
lazy val JavaSerializableClass: ClassSymbol = ctx.requiredClass("java.io.Serializable")
580+
579581
lazy val ComparableClass: ClassSymbol = ctx.requiredClass("java.lang.Comparable")
580582

581583
lazy val SystemClass: ClassSymbol = ctx.requiredClass("java.lang.System")
@@ -656,6 +658,11 @@ class Definitions {
656658
lazy val Product_productPrefixR: TermRef = ProductClass.requiredMethodRef(nme.productPrefix)
657659
def Product_productPrefix(implicit ctx: Context): Symbol = Product_productPrefixR.symbol
658660

661+
lazy val ModuleSerializationProxyType: TypeRef = ctx.requiredClassRef("scala.runtime.ModuleSerializationProxy")
662+
def ModuleSerializationProxyClass(implicit ctx: Context): ClassSymbol = ModuleSerializationProxyType.symbol.asClass
663+
lazy val ModuleSerializationProxyConstructor: TermSymbol =
664+
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(WildcardType)))
665+
659666
lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic")
660667
def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass
661668
lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape")

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,6 @@ object StdNames {
495495
val productIterator: N = "productIterator"
496496
val productPrefix: N = "productPrefix"
497497
val raw_ : N = "raw"
498-
val readResolve: N = "readResolve"
499498
val reflect: N = "reflect"
500499
val reflectiveSelectable: N = "reflectiveSelectable"
501500
val reify : N = "reify"
@@ -558,6 +557,7 @@ object StdNames {
558557
val withFilterIfRefutable: N = "withFilterIfRefutable$"
559558
val WorksheetWrapper: N = "WorksheetWrapper"
560559
val wrap: N = "wrap"
560+
val writeReplace: N = "writeReplace"
561561
val zero: N = "zero"
562562
val zip: N = "zip"
563563
val nothingRuntimeClass: N = "scala.runtime.Nothing$"

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+4
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,10 @@ object SymDenotations {
680680
*/
681681
def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean = false
682682

683+
/** Is this symbol a class that extends `java.io.Serializable` ? */
684+
def isSerializable(implicit ctx: Context): Boolean =
685+
isClass && derivesFrom(defn.JavaSerializableClass)
686+
683687
/** Is this symbol a class that extends `AnyVal`? */
684688
final def isValueClass(implicit ctx: Context): Boolean = {
685689
val di = initial

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
472472

473473
def toText(const: Constant): Text = const.tag match {
474474
case StringTag => stringText("\"" + escapedString(const.value.toString) + "\"")
475-
case ClazzTag => "classOf[" ~ toText(const.typeValue.classSymbol) ~ "]"
475+
case ClazzTag => "classOf[" ~ toText(const.typeValue) ~ "]"
476476
case CharTag => literalText(s"'${escapedChar(const.charValue)}'")
477477
case LongTag => literalText(const.longValue.toString + "L")
478478
case EnumTag => literalText(const.symbolValue.name.toString)

compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala

+38-9
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ import ValueClasses.isDerivedValueClass
2323
* def productArity: Int
2424
* def productPrefix: String
2525
*
26-
* Special handling:
27-
* protected def readResolve(): AnyRef
26+
* Add to serializable static objects, unless an implementation
27+
* already exists:
28+
* private def writeReplace(): AnyRef
2829
*
2930
* Selectively added to value classes, unless a non-default
3031
* implementation already exists:
@@ -50,8 +51,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
5051
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
5152
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
5253

53-
/** The synthetic methods of the case or value class `clazz`. */
54-
def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
54+
/** If this is a case or value class, return the appropriate additional methods,
55+
* otherwise return nothing.
56+
*/
57+
def caseAndValueMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
5558
val clazzType = clazz.appliedRef
5659
lazy val accessors =
5760
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
@@ -255,12 +258,38 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
255258
*/
256259
def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot)))
257260

258-
symbolsToSynthesize flatMap syntheticDefIfMissing
261+
symbolsToSynthesize.flatMap(syntheticDefIfMissing)
259262
}
260263

261-
def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template =
262-
if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner))
263-
cpy.Template(impl)(body = impl.body ++ syntheticMethods(ctx.owner.asClass))
264+
/** If this is a serializable static object `Foo`, add the method:
265+
*
266+
* private def writeReplace(): AnyRef =
267+
* new scala.runtime.ModuleSerializationProxy(classOf[Foo.type])
268+
*
269+
* unless an implementation already exists, otherwise do nothing.
270+
*/
271+
def serializableObjectMethod(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = {
272+
def hasWriteReplace: Boolean =
273+
clazz.membersNamed(nme.writeReplace)
274+
.filterWithPredicate(s => s.signature == Signature(defn.AnyRefType, isJava = false))
275+
.exists
276+
if (clazz.is(Module) && clazz.isStatic && clazz.isSerializable && !hasWriteReplace) {
277+
val writeReplace = ctx.newSymbol(clazz, nme.writeReplace, Method | Private | Synthetic,
278+
MethodType(Nil, defn.AnyRefType), coord = clazz.coord).entered.asTerm
279+
List(
280+
DefDef(writeReplace,
281+
_ => New(defn.ModuleSerializationProxyType,
282+
defn.ModuleSerializationProxyConstructor,
283+
List(Literal(Constant(clazz.sourceModule.termRef)))))
284+
.withSpan(ctx.owner.span.focus))
285+
}
264286
else
265-
impl
287+
Nil
288+
}
289+
290+
def addSyntheticMethods(impl: Template)(implicit ctx: Context): Template = {
291+
val clazz = ctx.owner.asClass
292+
cpy.Template(impl)(body = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body)
293+
}
294+
266295
}

compiler/src/dotty/tools/dotc/typer/Applications.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
907907
if (typedArgs.length <= pt.paramInfos.length && !isNamed)
908908
if (typedFn.symbol == defn.Predef_classOf && typedArgs.nonEmpty) {
909909
val arg = typedArgs.head
910-
checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false)
910+
if (!arg.symbol.is(Module)) // Allow `classOf[Foo.type]` if `Foo` is an object
911+
checkClassType(arg.tpe, arg.sourcePos, traitReq = false, stablePrefixReq = false)
911912
}
912913
case _ =>
913914
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copied from https://github.com/scala/scala/blob/2.13.x/src/library/scala/runtime/ModuleSerializationProxy.java
2+
// TODO: Remove this file once we switch to the Scala 2.13 stdlib since it already contains it.
3+
4+
/*
5+
* Scala (https://www.scala-lang.org)
6+
*
7+
* Copyright EPFL and Lightbend, Inc.
8+
*
9+
* Licensed under Apache License 2.0
10+
* (http://www.apache.org/licenses/LICENSE-2.0).
11+
*
12+
* See the NOTICE file distributed with this work for
13+
* additional information regarding copyright ownership.
14+
*/
15+
16+
package scala.runtime;
17+
18+
import java.io.Serializable;
19+
import java.security.AccessController;
20+
import java.security.PrivilegedActionException;
21+
import java.security.PrivilegedExceptionAction;
22+
import java.util.HashSet;
23+
import java.util.Set;
24+
25+
/** A serialization proxy for singleton objects */
26+
public final class ModuleSerializationProxy implements Serializable {
27+
private static final long serialVersionUID = 1L;
28+
private final Class<?> moduleClass;
29+
private static final ClassValue<Object> instances = new ClassValue<Object>() {
30+
@Override
31+
protected Object computeValue(Class<?> type) {
32+
try {
33+
return AccessController.doPrivileged((PrivilegedExceptionAction<Object>) () -> type.getField("MODULE$").get(null));
34+
} catch (PrivilegedActionException e) {
35+
return rethrowRuntime(e.getCause());
36+
}
37+
}
38+
};
39+
40+
private static Object rethrowRuntime(Throwable e) {
41+
Throwable cause = e.getCause();
42+
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
43+
else throw new RuntimeException(cause);
44+
}
45+
46+
public ModuleSerializationProxy(Class<?> moduleClass) {
47+
this.moduleClass = moduleClass;
48+
}
49+
50+
@SuppressWarnings("unused")
51+
private Object readResolve() {
52+
return instances.get(moduleClass);
53+
}
54+
}

library/src/scala/tasty/reflect/Printers.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -643,11 +643,11 @@ trait Printers
643643

644644
def keepDefinition(d: Definition): Boolean = {
645645
val flags = d.symbol.flags
646-
def isCaseClassUnOverridableMethod: Boolean = {
646+
def isUndecompilableCaseClassMethod: Boolean = {
647647
// Currently the compiler does not allow overriding some of the methods generated for case classes
648648
d.symbol.flags.is(Flags.Synthetic) &&
649649
(d match {
650-
case DefDef("apply" | "unapply", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true
650+
case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Object) => true
651651
case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
652652
n == "copy" ||
653653
n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method
@@ -657,7 +657,7 @@ trait Printers
657657
})
658658
}
659659
def isInnerModuleObject = d.symbol.flags.is(Flags.Lazy) && d.symbol.flags.is(Flags.Object)
660-
!flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isCaseClassUnOverridableMethod && !isInnerModuleObject
660+
!flags.is(Flags.Param) && !flags.is(Flags.ParamAccessor) && !flags.is(Flags.FieldAccessor) && !isUndecompilableCaseClassMethod && !isInnerModuleObject
661661
}
662662
val stats1 = stats.collect {
663663
case IsDefinition(stat) if keepDefinition(stat) => stat

tests/neg/classOf.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ object Test {
55

66
def f1[T] = classOf[T] // error
77
def f2[T <: String] = classOf[T] // error
8-
val x = classOf[Test.type] // error
8+
val x = classOf[Test.type] // ok
99
val y = classOf[C { type I = String }] // error
1010
val z = classOf[A] // ok
1111
}

tests/run/classof-object.decompiled

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
/** Decompiled from out/runTestFromTasty/run/classof-object/Test.tasty */
2+
object Test {
3+
def main(args: scala.Array[scala.Predef.String]): scala.Unit = if (scala.Predef.classOf[Test.type].==(Test.getClass()).unary_!) dotty.DottyPredef.assertFail() else ()
4+
}

tests/run/classof-object.scala

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
object Test {
2+
def main(args: Array[String]): Unit = {
3+
assert(classOf[Test.type] == Test.getClass)
4+
}
5+
}

tests/run/literals.decompiled

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
object Test {
33
def αρετη: java.lang.String = "alpha rho epsilon tau eta"
44
case class GGG(i: scala.Int) {
5-
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
65
override def hashCode(): scala.Int = {
76
var acc: scala.Int = 767242539
87
acc = scala.runtime.Statics.mix(acc, GGG.this.i)
@@ -24,6 +23,7 @@ object Test {
2423
case _ =>
2524
throw new java.lang.IndexOutOfBoundsException(n.toString())
2625
}
26+
def αα(that: Test.GGG): scala.Int = GGG.this.i.+(that.i)
2727
}
2828
object GGG extends scala.Function1[scala.Int, Test.GGG]
2929
def check_success[a](name: scala.Predef.String, closure: => a, expected: a): scala.Unit = {
@@ -95,4 +95,4 @@ object Test {
9595
val ggg: scala.Int = Test.GGG.apply(1).αα(Test.GGG.apply(2))
9696
Test.check_success[scala.Int]("ggg == 3", ggg, 3)
9797
}
98-
}
98+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import java.io.File
2+
3+
object Module {
4+
val data = new Array[Byte](32 * 1024 * 1024)
5+
}
6+
7+
object Test {
8+
private val readResolve = classOf[scala.runtime.ModuleSerializationProxy].getDeclaredMethod("readResolve")
9+
readResolve.setAccessible(true)
10+
11+
val testClassesDir = new File(Module.getClass.getClassLoader.getResource("Module.class").toURI).getParentFile
12+
def main(args: Array[String]): Unit = {
13+
for (i <- 1 to 256) {
14+
// This would "java.lang.OutOfMemoryError: Java heap space" if ModuleSerializationProxy
15+
// prevented class unloading.
16+
deserializeDynamicLoadedClass()
17+
}
18+
}
19+
20+
def deserializeDynamicLoadedClass(): Unit = {
21+
val loader = new java.net.URLClassLoader(Array(testClassesDir.toURI.toURL), ClassLoader.getSystemClassLoader)
22+
val moduleClass = loader.loadClass("Module$")
23+
assert(moduleClass ne Module.getClass)
24+
val result = readResolve.invoke(new scala.runtime.ModuleSerializationProxy(moduleClass))
25+
assert(result.getClass == moduleClass)
26+
}
27+
}

tests/run/serialize.scala

+14
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,23 @@ object Test {
88
in.readObject.asInstanceOf[T]
99
}
1010

11+
object Foo extends Serializable {}
12+
13+
object Baz extends Serializable {
14+
private def writeReplace(): AnyRef = {
15+
this
16+
}
17+
}
18+
1119
def main(args: Array[String]): Unit = {
1220
val x: PartialFunction[Int, Int] = { case x => x + 1 }
1321
val adder = serializeDeserialize(x)
1422
assert(adder(1) == 2)
23+
24+
val foo = serializeDeserialize(Foo)
25+
assert(foo eq Foo)
26+
27+
val baz = serializeDeserialize(Baz)
28+
assert(baz ne Baz)
1529
}
1630
}

0 commit comments

Comments
 (0)