Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiler: fix expect/actual function with default values #428

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ComposeIrGenerationExtension(
metrics
).lower(moduleFragment)

CopyDefaultValuesFromExpectLowering().lower(moduleFragment)
CopyDefaultValuesFromExpectLowering(pluginContext).lower(moduleFragment)

val mangler = when {
pluginContext.platform.isJs() -> JsManglerIr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

package androidx.compose.compiler.plugins.kotlin.lower

import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
import androidx.compose.compiler.plugins.kotlin.hasComposableAnnotation
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.descriptors.MemberDescriptor
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleExpectsForActual
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrExpressionBodyImpl
import org.jetbrains.kotlin.ir.symbols.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.*
import org.jetbrains.kotlin.resolve.descriptorUtil.module
import org.jetbrains.kotlin.resolve.descriptorUtil.propertyIfAccessor
import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleActualsForExpected

/**
* [ComposableFunctionBodyTransformer] relies on presence of default values in
Expand All @@ -37,55 +43,196 @@ import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleExpectsForActual
* This lowering needs to run before [ComposableFunctionBodyTransformer] and
* before [ComposerParamTransformer].
*
* Fixes https://github.com/JetBrains/compose-jb/issues/1407
* Fixes:
* https://github.com/JetBrains/compose-jb/issues/1407
* https://github.com/JetBrains/compose-multiplatform/issues/2816
* https://github.com/JetBrains/compose-multiplatform/issues/2806
*
* This implementation is borrowed from Kotlin's ExpectToActualDefaultValueCopier.
* Currently, it heavily relies on descriptors to find expect for actuals or vice versa:
* findCompatibleActualsForExpected.
* Unlike ExpectToActualDefaultValueCopier, this lowering performs its transformations
* only for functions marked with @Composable annotation or
* for functions with @Composable lambdas in parameters.
*/
@OptIn(ObsoleteDescriptorBasedAPI::class)
class CopyDefaultValuesFromExpectLowering : ModuleLoweringPass {
class CopyDefaultValuesFromExpectLowering(
pluginContext: IrPluginContext
) : ModuleLoweringPass, IrElementTransformerVoid() {

private val symbolTable = pluginContext.symbolTable

private fun isApplicable(declaration: IrFunction): Boolean {
return declaration.hasComposableAnnotation()
|| declaration.valueParameters.any {
it.type.hasAnnotation(ComposeFqNames.Composable)
}
}

override fun visitFunction(declaration: IrFunction): IrStatement {
val original = super.visitFunction(declaration) as? IrFunction ?: return declaration

if (!original.isExpect || !isApplicable(original)) {
return original
}

val actualForExpected = original.findActualForExpected()

original.valueParameters.forEachIndexed { index, expectValueParameter ->
val actualValueParameter = actualForExpected.valueParameters[index]
val expectDefaultValue = expectValueParameter.defaultValue
if (expectDefaultValue != null) {
actualValueParameter.defaultValue = IrExpressionBodyImpl(
expectDefaultValue.startOffset, expectDefaultValue.endOffset,
expectDefaultValue.expression.remapExpectValueSymbols()
.patchDeclarationParents(actualForExpected)
)

// Remove a default value in the expect fun in order to prevent
// Kotlin expect/actual-related lowerings trying to copy the default values again
expectValueParameter.defaultValue = null
}
}
return original
}

override fun lower(module: IrModuleFragment) {
// it uses FunctionDescriptor since current API (findCompatibleExpectedForActual)
// can return only a descriptor
val expectComposables = mutableMapOf<FunctionDescriptor, IrFunction>()

// first pass to find expect functions with default values
module.transformChildrenVoid(object : IrElementTransformerVoid() {
override fun visitFunction(declaration: IrFunction): IrStatement {
if (declaration.isExpect && declaration.hasComposableAnnotation()) {
val hasDefaultValues = declaration.valueParameters.any {
it.defaultValue != null
}
if (hasDefaultValues) {
expectComposables[declaration.descriptor] = declaration
module.transformChildrenVoid(this)
}

private inline fun <reified T : IrFunction> T.findActualForExpected(): T =
symbolTable.referenceFunction(descriptor.findActualForExpect()).owner as T

private fun IrProperty.findActualForExpected(): IrProperty =
symbolTable.referenceProperty(descriptor.findActualForExpect()).owner

private fun IrClass.findActualForExpected(): IrClass =
symbolTable.referenceClass(descriptor.findActualForExpect()).owner

private fun IrEnumEntry.findActualForExpected(): IrEnumEntry =
symbolTable.referenceEnumEntry(descriptor.findActualForExpect()).owner

private inline fun <reified T : MemberDescriptor> T.findActualForExpect(): T {
if (!this.isExpect) error(this)
return (findCompatibleActualsForExpected(module).singleOrNull() ?: error(this)) as T
}

private fun IrExpression.remapExpectValueSymbols(): IrExpression {
class SymbolRemapper : DeepCopySymbolRemapper() {
override fun getReferencedClass(symbol: IrClassSymbol) =
if (symbol.descriptor.isExpect)
symbol.owner.findActualForExpected().symbol
else super.getReferencedClass(symbol)

override fun getReferencedClassOrNull(symbol: IrClassSymbol?) =
symbol?.let { getReferencedClass(it) }

override fun getReferencedClassifier(symbol: IrClassifierSymbol): IrClassifierSymbol = when (symbol) {
is IrClassSymbol -> getReferencedClass(symbol)
is IrTypeParameterSymbol -> remapExpectTypeParameter(symbol).symbol
else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
}

override fun getReferencedConstructor(symbol: IrConstructorSymbol) =
if (symbol.descriptor.isExpect)
symbol.owner.findActualForExpected().symbol
else super.getReferencedConstructor(symbol)

override fun getReferencedFunction(symbol: IrFunctionSymbol): IrFunctionSymbol = when (symbol) {
is IrSimpleFunctionSymbol -> getReferencedSimpleFunction(symbol)
is IrConstructorSymbol -> getReferencedConstructor(symbol)
else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
}

override fun getReferencedSimpleFunction(symbol: IrSimpleFunctionSymbol) = when {
symbol.descriptor.isExpect -> symbol.owner.findActualForExpected().symbol

symbol.descriptor.propertyIfAccessor.isExpect -> {
val property = symbol.owner.correspondingPropertySymbol!!.owner
val actualPropertyDescriptor = property.descriptor.findActualForExpect()
val accessorDescriptor = when (symbol.owner) {
property.getter -> actualPropertyDescriptor.getter!!
property.setter -> actualPropertyDescriptor.setter!!
else -> error("Unexpected accessor of $symbol ${symbol.descriptor}")
}
symbolTable.referenceFunction(accessorDescriptor) as IrSimpleFunctionSymbol
}
return super.visitFunction(declaration)

else -> super.getReferencedSimpleFunction(symbol)
}
})

// second pass to set corresponding default values
module.transformChildrenVoid(object : IrElementTransformerVoid() {
override fun visitFunction(declaration: IrFunction): IrStatement {
if (declaration.descriptor.isActual && declaration.hasComposableAnnotation()) {
val compatibleExpects = declaration.descriptor.findCompatibleExpectsForActual {
module.descriptor == it
}
if (compatibleExpects.isNotEmpty()) {
val expectFun = compatibleExpects.firstOrNull {
it in expectComposables
}?.let {
expectComposables[it]
}

if (expectFun != null) {
declaration.valueParameters.forEachIndexed { index, it ->
it.defaultValue =
it.defaultValue ?: expectFun.valueParameters[index].defaultValue
}
}

override fun getReferencedProperty(symbol: IrPropertySymbol) =
if (symbol.descriptor.isExpect)
symbol.owner.findActualForExpected().symbol
else
super.getReferencedProperty(symbol)

override fun getReferencedEnumEntry(symbol: IrEnumEntrySymbol): IrEnumEntrySymbol =
if (symbol.descriptor.isExpect)
symbol.owner.findActualForExpected().symbol
else
super.getReferencedEnumEntry(symbol)

override fun getReferencedValue(symbol: IrValueSymbol) =
remapExpectValue(symbol)?.symbol ?: super.getReferencedValue(symbol)
}

val symbolRemapper = SymbolRemapper()
acceptVoid(symbolRemapper)
return transform(DeepCopyIrTreeWithSymbols(symbolRemapper, DeepCopyTypeRemapper(symbolRemapper)), data = null)
}

private fun remapExpectTypeParameter(symbol: IrTypeParameterSymbol): IrTypeParameter {
val parameter = symbol.owner
val parent = parameter.parent

return when (parent) {
is IrClass ->
if (!parent.descriptor.isExpect)
parameter
else parent.findActualForExpected().typeParameters[parameter.index]

is IrFunction ->
if (!parent.descriptor.isExpect)
parameter
else parent.findActualForExpected().typeParameters[parameter.index]

else -> error(parent)
}
}

private fun remapExpectValue(symbol: IrValueSymbol): IrValueParameter? {
if (symbol !is IrValueParameterSymbol) {
return null
}

val parameter = symbol.owner
val parent = parameter.parent

return when (parent) {
is IrClass ->
if (!parent.descriptor.isExpect)
null
else {
assert(parameter == parent.thisReceiver)
parent.findActualForExpected().thisReceiver!!
}

is IrFunction ->
if (!parent.descriptor.isExpect)
null
else when (parameter) {
parent.dispatchReceiverParameter -> parent.findActualForExpected().dispatchReceiverParameter!!
parent.extensionReceiverParameter -> parent.findActualForExpected().extensionReceiverParameter!!
else -> {
assert(parent.valueParameters[parameter.index] == parameter)
parent.findActualForExpected().valueParameters[parameter.index]
}
}
return super.visitFunction(declaration)
}
})

else -> error(parent)
}
}
}