Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.ValueGeneration.Internal;
using Microsoft.EntityFrameworkCore.Infrastructure.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

// ReSharper disable once CheckNamespace
namespace Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -94,6 +95,7 @@ public static IServiceCollection AddCosmos<TContext>(
public static IServiceCollection AddEntityFrameworkCosmos(this IServiceCollection serviceCollection)
{
var builder = new EntityFrameworkServicesBuilder(serviceCollection)
.TryAdd<IStructuralTypeMaterializerSource, CosmosStructuralTypeMaterializerSource>()
.TryAdd<LoggingDefinitions, CosmosLoggingDefinitions>()
.TryAdd<IDatabaseProvider, DatabaseProvider<CosmosOptionsExtension>>()
.TryAdd<IDatabase, CosmosDatabaseWrapper>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,24 +352,67 @@ UnaryExpression unaryExpression

Expression NullSafeUpdate(Expression? expression)
{
Expression updatedMemberExpression = memberExpression.Update(
expression != null ? MatchTypes(expression, memberExpression.Expression!.Type) : expression);
if (expression is null)
{
return memberExpression.Update(expression);
}

var expressionValue = Expression.Parameter(expression.Type);
var assignment = Expression.Assign(expressionValue, expression);

if (expression.Type.IsNullableType() == true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this correctly, when the inner expression is a Nullable<> and the outer isn't nullable, this adds "safe null" compensation logic, checking inner.HasValue() and then calling inner.Value.Foo only if true. Also, you're using a block with a variable and an assignment to make sure that inner only gets evaluated once (nice idea).

What I'm not clear on, is why exactly we need both this block and the one below (where we again check and compensate)...

Also, can you please add a short comment before the block explaining this (basically my explanation above, assuming it's correct), as it's really non-obvious?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for when the query is pojecting Nullable<TStruct>.Value. EF doesn't throw a 'Nullable object must have a value' InvalidOperationException if the value of a projected nullable struct is null, so we need to add a check here. I will add a comment explaining. See Select_nullable_value_type_with_Value for the case this covers

&& !memberExpression.Type.IsNullableType()
&& memberExpression.Expression is MemberExpression innerMember
&& innerMember.Type.IsNullableValueType() == true
&& memberExpression.Member.Name == nameof(Nullable<>.Value))
{
var nullCheck = Expression.Not(
Expression.Property(expressionValue, nameof(Nullable<>.HasValue)));
var conditionalExpression = Expression.Condition(
nullCheck,
Expression.Default(memberExpression.Type),
Expression.Property(expressionValue, nameof(Nullable<>.Value)));

return Expression.Block(
[expressionValue],
assignment,
conditionalExpression);
}

Expression updatedMemberExpression = memberExpression.Update(MatchTypes(expressionValue, memberExpression.Expression!.Type));

if (expression?.Type.IsNullableType() == true)
if (expression.Type.IsNullableType() == true)
{
var nullableReturnType = memberExpression.Type.MakeNullable();
if (!memberExpression.Type.IsNullableType())

if (!updatedMemberExpression.Type.IsNullableType())
{
updatedMemberExpression = Expression.Convert(updatedMemberExpression, nullableReturnType);
}

Expression nullCheck;
if (expression.Type.IsNullableValueType())
{
// For Nullable<T>, use HasValue property instead of equality comparison
// to avoid issues with value types that don't define the == operator
nullCheck = Expression.Not(
Expression.Property(expressionValue, nameof(Nullable<>.HasValue)));
}
else
{
nullCheck = Expression.Equal(expressionValue, Expression.Default(expression.Type));
}

updatedMemberExpression = Expression.Condition(
Expression.Equal(expression, Expression.Default(expression.Type)),
Expression.Constant(null, nullableReturnType),
nullCheck,
Expression.Default(nullableReturnType),
updatedMemberExpression);
}

return updatedMemberExpression;
return Expression.Block(
[expressionValue],
assignment,
updatedMemberExpression);
}
}

Expand Down Expand Up @@ -639,8 +682,21 @@ UnaryExpression unaryExpression
updatedMethodCallExpression = Expression.Convert(updatedMethodCallExpression, nullableReturnType);
}

Expression nullCheck;
if (@object.Type.IsNullableValueType())
{
// For Nullable<T>, use HasValue property instead of equality comparison
// to avoid issues with value types that don't define the == operator
nullCheck = Expression.Not(
Expression.Property(@object, nameof(Nullable<>.HasValue)));
}
else
{
nullCheck = Expression.Equal(@object, Expression.Constant(null, @object.Type));
}

return Expression.Condition(
Expression.Equal(@object, Expression.Default(@object.Type)),
nullCheck,
Expression.Constant(null, nullableReturnType),
updatedMethodCallExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase(
bool trackQueryResults)
: ExpressionVisitor
{
private static readonly MethodInfo GetItemMethodInfo
= typeof(JObject).GetRuntimeProperties()
.Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string))
public static readonly MethodInfo GetItemMethodInfo
= typeof(JToken).GetRuntimeProperties()
.Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(object))
.GetMethod;

private static readonly PropertyInfo JTokenTypePropertyInfo
Expand Down Expand Up @@ -56,7 +56,7 @@ private readonly IDictionary<Expression, Expression> _ordinalParameterBindings
private List<IncludeExpression> _pendingIncludes
= [];

private static readonly MethodInfo ToObjectWithSerializerMethodInfo
public static readonly MethodInfo ToObjectWithSerializerMethodInfo
= typeof(CosmosProjectionBindingRemovingExpressionVisitorBase)
.GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObjectWithSerializer));

Expand All @@ -72,18 +72,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
string storeName = null;

// Values injected by JObjectInjectingExpressionVisitor
var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand;

if (projectionExpression is UnaryExpression
{
NodeType: ExpressionType.Convert,
Operand: UnaryExpression operand
})
{
// Unwrap EntityProjectionExpression when the root entity is not projected
// That is, this is handling the projection of a non-root entity type.
projectionExpression = operand.Operand;
}
var projectionExpression = binaryExpression.Right.UnwrapTypeConversion(out _);

switch (projectionExpression)
{
Expand Down Expand Up @@ -154,7 +143,10 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
}

break;

case MethodCallExpression jObjectMethodCallExpression
when jObjectMethodCallExpression.Method.IsGenericMethod && jObjectMethodCallExpression.Method.GetGenericMethodDefinition() == ToObjectWithSerializerMethodInfo:
// jobject already uses ToObjectWithSerializerMethodInfo. This can happen because code was generated for complex properties that already leverages jobject correctly.
return binaryExpression;
default:
throw new UnreachableException();
}
Expand All @@ -166,19 +158,27 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var newExpression = (NewExpression)binaryExpression.Right;

EntityProjectionExpression entityProjectionExpression;
if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression)
if (newExpression.Arguments[0] is ComplexPropertyBindingExpression complexPropertyBindingExpression)
{
var projection = GetProjection(projectionBindingExpression);
entityProjectionExpression = (EntityProjectionExpression)projection.Expression;
_materializationContextBindings[parameterExpression] = complexPropertyBindingExpression;
_projectionBindings[complexPropertyBindingExpression] = complexPropertyBindingExpression.JObjectParameter;
}
else
{
var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand;
entityProjectionExpression = (EntityProjectionExpression)projection;
}
EntityProjectionExpression entityProjectionExpression;
if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression)
{
var projection = GetProjection(projectionBindingExpression);
entityProjectionExpression = (EntityProjectionExpression)projection.Expression;
}
else
{
var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand;
entityProjectionExpression = (EntityProjectionExpression)projection;
}

_materializationContextBindings[parameterExpression] = entityProjectionExpression.Object;
_materializationContextBindings[parameterExpression] = entityProjectionExpression.Object;
}

var updatedExpression = New(
newExpression.Constructor,
Expand Down Expand Up @@ -595,7 +595,7 @@ private static Expression AddToCollectionNavigation(
relatedEntity,
Constant(true));

private static readonly MethodInfo PopulateCollectionMethodInfo
public static readonly MethodInfo PopulateCollectionMethodInfo
= typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo()
.GetDeclaredMethod(nameof(PopulateCollection));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor(
IQuerySqlGeneratorFactory querySqlGeneratorFactory)
: ShapedQueryCompilingExpressionVisitor(dependencies, cosmosQueryCompilationContext)
{
private ParameterExpression _parentJObject;
private readonly Type _contextType = cosmosQueryCompilationContext.ContextType;
private readonly bool _threadSafetyChecksEnabled = dependencies.CoreSingletonOptions.AreThreadSafetyChecksEnabled;

Expand All @@ -39,6 +40,7 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery
}

var jTokenParameter = Parameter(typeof(JToken), "jToken");
_parentJObject = jTokenParameter;

var shaperBody = shapedQueryExpression.ShaperExpression;

Expand Down Expand Up @@ -170,4 +172,146 @@ private static PartitionKey GeneratePartitionKey(

return builder.Build();
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override void AddStructuralTypeInitialization(StructuralTypeShaperExpression shaper, ParameterExpression instanceVariable, List<ParameterExpression> variables, List<Expression> expressions)
{
foreach (var complexProperty in shaper.StructuralType.GetComplexProperties())
{
var member = MakeMemberAccess(instanceVariable, complexProperty.GetMemberInfo(true, true));
if (complexProperty.IsCollection)
{
expressions.Add(CreateComplexCollectionAssignmentBlock(member, complexProperty));
}
else
{
expressions.Add(CreateComplexPropertyAssignmentBlock(member, complexProperty));
}
}
}

private int _currentComplexIndex;

private BlockExpression CreateComplexPropertyAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty)
{
var jObjectVariable = Parameter(typeof(JObject), "complexJObject" + ++_currentComplexIndex);
var assignJObjectVariable = Assign(jObjectVariable,
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JObject)),
Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo,
Constant(complexProperty.Name)
)
)
);

var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectVariable);
if (complexProperty.IsNullable)
{
materializeExpression = Condition(Equal(jObjectVariable, Constant(null)),
Default(complexProperty.ClrType.MakeNullable()),
ConvertChecked(materializeExpression, complexProperty.ClrType.MakeNullable())
);
}

return Block(
[jObjectVariable],
[
assignJObjectVariable,
memberExpression.Assign(materializeExpression)
]
);
}

private BlockExpression CreateComplexCollectionAssignmentBlock(MemberExpression memberExpression, IComplexProperty complexProperty)
{
var complexJArrayVariable = Variable(
typeof(JArray),
"complexJArray" + ++_currentComplexIndex);

var assignJArrayVariable = Assign(complexJArrayVariable,
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.ToObjectWithSerializerMethodInfo.MakeGenericMethod(typeof(JArray)),
Call(_parentJObject, CosmosProjectionBindingRemovingExpressionVisitorBase.GetItemMethodInfo,
Constant(complexProperty.Name)
)
)
);
var jObjectParameter = Parameter(typeof(JObject), "complexJObject" + _currentComplexIndex);
var materializeExpression = CreateComplexTypeMaterializeExpression(complexProperty, jObjectParameter);

var select = Call(
EnumerableMethods.Select.MakeGenericMethod(typeof(JObject), complexProperty.ComplexType.ClrType),
Call(
EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)),
complexJArrayVariable),
Lambda(materializeExpression, jObjectParameter));

Expression populateExpression =
Call(
CosmosProjectionBindingRemovingExpressionVisitorBase.PopulateCollectionMethodInfo.MakeGenericMethod(complexProperty.ComplexType.ClrType, complexProperty.ClrType),
Constant(complexProperty.GetCollectionAccessor()),
select
);

if (complexProperty.IsNullable)
{
populateExpression = Condition(Equal(complexJArrayVariable, Constant(null)),
Default(complexProperty.ClrType.MakeNullable()),
ConvertChecked(populateExpression, complexProperty.ClrType.MakeNullable())
);
}

return Block(
[complexJArrayVariable],
[
assignJArrayVariable,
memberExpression.Assign(populateExpression)
]
);
}

private Expression CreateComplexTypeMaterializeExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter)
{
var tempValueBuffer = new ComplexPropertyBindingExpression(complexProperty, jObjectParameter);
var structuralTypeShaperExpression = new StructuralTypeShaperExpression(
complexProperty.ComplexType,
tempValueBuffer,
false);

var oldParentJObject = _parentJObject;
_parentJObject = jObjectParameter;
var materializeExpression = InjectStructuralTypeMaterializers(structuralTypeShaperExpression);
_parentJObject = oldParentJObject;

if (complexProperty.ComplexType.ClrType.IsNullableType())
{
materializeExpression = Condition(Equal(jObjectParameter, Constant(null)),
Default(complexProperty.ComplexType.ClrType),
materializeExpression
);
}

return materializeExpression;
}

private class ComplexPropertyBindingExpression : Expression
{
public ComplexPropertyBindingExpression(IComplexProperty complexProperty, ParameterExpression jObjectParameter)
{
ComplexProperty = complexProperty;
JObjectParameter = jObjectParameter;
}

public override Type Type => typeof(ValueBuffer);

public override ExpressionType NodeType => ExpressionType.Extension;

public IComplexProperty ComplexProperty { get; }
public ParameterExpression JObjectParameter { get; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Query.Internal;

#pragma warning disable EF1001 // StructuralTypeMaterializerSource is pubternal

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class CosmosStructuralTypeMaterializerSource(StructuralTypeMaterializerSourceDependencies dependencies)
: StructuralTypeMaterializerSource(dependencies)
{
/// <summary>
/// Complex properties are not handled in the initial materialization expression,
/// So we can more easily generate the necessary nested materialization expressions later in CosmosShapedQueryCompilingExpressionVisitor.
/// </summary>
protected override bool ReadComplexTypeDirectly(IComplexType complexType)
=> false;
}
Loading
Loading