Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.ValueGeneration.Internal;
using Microsoft.EntityFrameworkCore.Infrastructure.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

// ReSharper disable once CheckNamespace
namespace Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -94,6 +95,7 @@ public static IServiceCollection AddCosmos<TContext>(
public static IServiceCollection AddEntityFrameworkCosmos(this IServiceCollection serviceCollection)
{
var builder = new EntityFrameworkServicesBuilder(serviceCollection)
.TryAdd<IStructuralTypeMaterializerSource, CosmosStructuralTypeMaterializerSource>()
.TryAdd<LoggingDefinitions, CosmosLoggingDefinitions>()
.TryAdd<IDatabaseProvider, DatabaseProvider<CosmosOptionsExtension>>()
.TryAdd<IDatabase, CosmosDatabaseWrapper>()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
Expand Down Expand Up @@ -350,26 +351,73 @@ UnaryExpression unaryExpression
throw new InvalidOperationException(CoreStrings.TranslationFailed(memberExpression.Print()));
}

Expression NullSafeUpdate(Expression? expression)
Expression NullSafeUpdate(Expression? innerExpression)
{
Expression updatedMemberExpression = memberExpression.Update(
expression != null ? MatchTypes(expression, memberExpression.Expression!.Type) : expression);
if (innerExpression is null)
{
return memberExpression.Update(innerExpression);
}

var expressionValue = Expression.Parameter(innerExpression.Type);
var assignment = Expression.Assign(expressionValue, innerExpression);

// Special case for when query is projecting 'nullable.Value' where 'nullable' is of type Nullable<T>
// In this case we return default(T) when 'nullable' is null
if (innerExpression.Type.IsNullableType()
&& !memberExpression.Type.IsNullableType()
&& memberExpression.Expression is MemberExpression outerMember
&& outerMember.Type.IsNullableValueType()
&& memberExpression.Member.Name == nameof(Nullable<>.Value))
{
// Use HasValue property instead of equality comparison
// to avoid issues with value types that don't define the == operator
var nullCheck = Expression.Not(
Expression.Property(expressionValue, nameof(Nullable<>.HasValue)));
var conditionalExpression = Expression.Condition(
nullCheck,
Expression.Default(memberExpression.Type),
Expression.Property(expressionValue, nameof(Nullable<>.Value)));

return Expression.Block(
[expressionValue],
assignment,
conditionalExpression);
}

Expression updatedMemberExpression = memberExpression.Update(MatchTypes(expressionValue, memberExpression.Expression!.Type));

if (expression?.Type.IsNullableType() == true)
if (innerExpression.Type.IsNullableType())
{
var nullableReturnType = memberExpression.Type.MakeNullable();
if (!memberExpression.Type.IsNullableType())

if (!updatedMemberExpression.Type.IsNullableType())
{
updatedMemberExpression = Expression.Convert(updatedMemberExpression, nullableReturnType);
}

Expression nullCheck;
if (innerExpression.Type.IsNullableValueType())
{
// For Nullable<T>, use HasValue property instead of equality comparison
// to avoid issues with value types that don't define the == operator
nullCheck = Expression.Not(
Expression.Property(expressionValue, nameof(Nullable<>.HasValue)));
}
else
{
nullCheck = Expression.Equal(expressionValue, Expression.Default(innerExpression.Type));
}

updatedMemberExpression = Expression.Condition(
Expression.Equal(expression, Expression.Default(expression.Type)),
Expression.Constant(null, nullableReturnType),
nullCheck,
Expression.Default(nullableReturnType),
updatedMemberExpression);
}

return updatedMemberExpression;
return Expression.Block(
[expressionValue],
assignment,
updatedMemberExpression);
}
}

Expand Down Expand Up @@ -639,8 +687,21 @@ UnaryExpression unaryExpression
updatedMethodCallExpression = Expression.Convert(updatedMethodCallExpression, nullableReturnType);
}

Expression nullCheck;
if (@object.Type.IsNullableValueType())
{
// For Nullable<T>, use HasValue property instead of equality comparison
// to avoid issues with value types that don't define the == operator
nullCheck = Expression.Not(
Expression.Property(@object, nameof(Nullable<>.HasValue)));
}
else
{
nullCheck = Expression.Equal(@object, Expression.Constant(null, @object.Type));
}

return Expression.Condition(
Expression.Equal(@object, Expression.Default(@object.Type)),
nullCheck,
Expression.Constant(null, nullableReturnType),
updatedMethodCallExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase(
bool trackQueryResults)
: ExpressionVisitor
{
private static readonly MethodInfo GetItemMethodInfo
= typeof(JObject).GetRuntimeProperties()
.Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string))
public static readonly MethodInfo GetItemMethodInfo
= typeof(JToken).GetRuntimeProperties()
.Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object))
.GetMethod;

private static readonly PropertyInfo JTokenTypePropertyInfo
Expand Down Expand Up @@ -56,7 +56,7 @@ private readonly IDictionary<Expression, Expression> _ordinalParameterBindings
private List<IncludeExpression> _pendingIncludes
= [];

private static readonly MethodInfo ToObjectWithSerializerMethodInfo
public static readonly MethodInfo ToObjectWithSerializerMethodInfo
= typeof(CosmosProjectionBindingRemovingExpressionVisitorBase)
.GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer));

Expand All @@ -72,18 +72,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
string storeName = null;

// Values injected by JObjectInjectingExpressionVisitor
var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand;

if (projectionExpression is UnaryExpression
{
NodeType: ExpressionType.Convert,
Operand: UnaryExpression operand
})
{
// Unwrap EntityProjectionExpression when the root entity is not projected
// That is, this is handling the projection of a non-root entity type.
projectionExpression = operand.Operand;
}
var projectionExpression = binaryExpression.Right.UnwrapTypeConversion(out _);

switch (projectionExpression)
{
Expand Down Expand Up @@ -154,7 +143,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
}

break;

case MethodCallExpression { Method.IsGenericMethod: true } jObjectMethodCallExpression
when jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo:
// JObject assignment already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages JObject correctly.
return binaryExpression;
default:
throw new UnreachableException();
}
Expand All @@ -166,19 +158,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var newExpression = (NewExpression)binaryExpression.Right;

EntityProjectionExpression entityProjectionExpression;
if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression)
if (newExpression.Arguments[0] is ComplexPropertyBindingExpression complexPropertyBindingExpression)
{
var projection = GetProjection(projectionBindingExpression);
entityProjectionExpression = (EntityProjectionExpression)projection.Expression;
_materializationContextBindings[parameterExpression] = complexPropertyBindingExpression;
_projectionBindings[complexPropertyBindingExpression] = complexPropertyBindingExpression.JObjectParameter;
}
else
{
var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand;
entityProjectionExpression = (EntityProjectionExpression)projection;
}
EntityProjectionExpression entityProjectionExpression;
if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression)
{
var projection = GetProjection(projectionBindingExpression);
entityProjectionExpression = (EntityProjectionExpression)projection.Expression;
}
else
{
var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand;
entityProjectionExpression = (EntityProjectionExpression)projection;
}

_materializationContextBindings[parameterExpression] = entityProjectionExpression.Object;
_materializationContextBindings[parameterExpression] = entityProjectionExpression.Object;
}

var updatedExpression = New(
newExpression.Constructor,
Expand Down Expand Up @@ -595,7 +595,7 @@ private static Expression AddToCollectionNavigation(
relatedEntity,
Constant(true));

private static readonly MethodInfo PopulateCollectionMethodInfo
public static readonly MethodInfo PopulateCollectionMethodInfo
= typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo()
.GetDeclaredMethod(nameof(PopulateCollection));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor(
IQuerySqlGeneratorFactory querySqlGeneratorFactory)
: ShapedQueryCompilingExpressionVisitor(dependencies, cosmosQueryCompilationContext)
{
private int _currentComplexIndex;
private ParameterExpression _parentJObject;
private readonly Type _contextType = cosmosQueryCompilationContext.ContextType;
private readonly bool _threadSafetyChecksEnabled = dependencies.CoreSingletonOptions.AreThreadSafetyChecksEnabled;

Expand All @@ -39,6 +41,7 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery
}

var jTokenParameter = Parameter(typeof(JToken), "jToken");
_parentJObject = jTokenParameter;

var shaperBody = shapedQueryExpression.ShaperExpression;

Expand Down Expand Up @@ -170,4 +173,124 @@ private static PartitionKey GeneratePartitionKey(

return builder.Build();
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override void AddStructuralTypeInitialization(StructuralTypeShaperExpression shaper, ParameterExpression instanceVariable, List<ParameterExpression> variables, List<Expression> expressions)
{
foreach (var complexProperty in shaper.StructuralType.GetComplexProperties())
{
var member = MakeMemberAccess(instanceVariable, complexProperty.GetMemberInfo(true, true));
expressions.Add(complexProperty.IsCollection
? CreateComplexCollectionAssignmentBlock(member, complexProperty)
: CreateComplexPropertyAssignmentBlock(member, complexProperty));
}
}

private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty)
{
var jObjectVariable = Parameter(typeof(JObject), "complexJObject" + ++_currentComplexIndex);
var assignJObjectVariable = Assign(jObjectVariable,
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)),
Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo,
Constant(complexProperty.Name))));

var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectVariable);
if (complexProperty.IsNullable)
{
materializeExpression = Condition(Equal(jObjectVariable, Constant(null)),
Default(complexProperty.ClrType.MakeNullable()),
ConvertChecked(materializeExpression, complexProperty.ClrType.MakeNullable()));
}

return Block(
[jObjectVariable],
[
assignJObjectVariable,
memberExpression.Assign(materializeExpression)
]
);
}

private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty)
{
var complexJArrayVariable = Variable(
typeof(JArray),
"complexJArray" + ++_currentComplexIndex);

var assignJArrayVariable = Assign(complexJArrayVariable,
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)),
Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo,
Constant(complexProperty.Name))));

var jObjectParameter = Parameter(typeof(JObject), "complexJObject" + _currentComplexIndex);
var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectParameter);

var select = Call(
EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType),
Call(
EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)),
complexJArrayVariable),
Lambda(materializeExpression, jObjectParameter));

Expression populateExpression =
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType),
Constant(complexProperty.GetCollectionAccessor()),
select);

if (complexProperty.IsNullable)
{
populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)),
Default(complexProperty.ClrType.MakeNullable()),
ConvertChecked(populateExpression, complexProperty.ClrType.MakeNullable()));
}

return Block(
[complexJArrayVariable],
[
assignJArrayVariable,
memberExpression.Assign(populateExpression)
]
);
}

private Expression CreateComplexTypeMaterializeExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter)
{
var tempValueBuffer = new ComplexPropertyBindingExpression(complexProperty, jObjectParameter);
var structuralTypeShaperExpression = new StructuralTypeShaperExpression(
complexProperty.ComplexType,
tempValueBuffer,
false);

var oldParentJObject = _parentJObject;
_parentJObject = jObjectParameter;
var materializeExpression = InjectStructuralTypeMaterializers(structuralTypeShaperExpression);
_parentJObject = oldParentJObject;

if (complexProperty.ComplexType.ClrType.IsNullableType())
{
materializeExpression = Condition(Equal(jObjectParameter, Constant(null)),
Default(complexProperty.ComplexType.ClrType),
materializeExpression);
}

return materializeExpression;
}

private sealed class ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter) : Expression
{
public override Type Type => typeof(ValueBuffer);

public override ExpressionType NodeType => ExpressionType.Extension;

public IComplexProperty ComplexProperty { get; } = complexProperty;
public ParameterExpression JObjectParameter { get; } = jObjectParameter;
}
}
Loading
Loading