Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for #1 #2

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ _UpgradeReport_Files/
Backup*/
UpgradeLog*.XML


# System.Data.SQLite
*/Test/x64/
*/Test/x86/

############
## Windows
Expand Down
129 changes: 107 additions & 22 deletions CollectionQuery/CollectionQueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,58 +1,143 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Collection;
using NHibernate.Engine;
using NHibernate.Linq;
using NHibernate.Proxy;

namespace NHibernate.CollectionQuery
{
using Tuple = System.Tuple;
using Type = System.Type;

public static class CollectionQueryableExtensions
{
private delegate object SessionQueryableFunc(object session);
private delegate object SelectManyFunc(object ownerQueryable, object collectionSelection);

private static Func<IPersistentCollection, ISessionImplementor> sessionGetter;
private static ConcurrentDictionary<Tuple<Type, Type>, SessionQueryableFunc> sessionQueryableGetters;
private static ConcurrentDictionary<Tuple<Type, Type>, SelectManyFunc> selectManyGetters;

static CollectionQueryableExtensions()
{
sessionGetter = CreateSessionGetter();
sessionQueryableGetters = new ConcurrentDictionary<Tuple<Type, Type>, SessionQueryableFunc>();
selectManyGetters = new ConcurrentDictionary<Tuple<Type, Type>, SelectManyFunc>();
}

private static Func<IPersistentCollection, ISessionImplementor> CreateSessionGetter()
{
var sessionProperty = typeof(AbstractPersistentCollection)
.GetProperty("Session", BindingFlags.Instance | BindingFlags.NonPublic);

var collectionParameter = Expression.Parameter(typeof(IPersistentCollection));

var body = Expression.Property(
Expression.Convert(collectionParameter, typeof(AbstractPersistentCollection)),
sessionProperty
);

return Expression.Lambda<Func<IPersistentCollection, ISessionImplementor>>(body, collectionParameter)
.Compile();
}

private static SessionQueryableFunc CreateSessionQueryableGetter(Tuple<Type, Type> types)
{
var sessionType = types.Item1;
var ownerType = types.Item2;

var queryMethod = typeof(NHibernate.Linq.LinqExtensionMethods)
.GetMethod("Query", new[] { sessionType })
.MakeGenericMethod(ownerType);

var sessionParameter = Expression.Parameter(typeof(object));

var body = Expression.Call(null, queryMethod,
Expression.Convert(sessionParameter, sessionType)
);

return Expression.Lambda<SessionQueryableFunc>(body, sessionParameter)
.Compile();
}

private static SelectManyFunc CreateSelectManyGetter(Tuple<Type, Type> types)
{
var ownerType = types.Item1;
var itemType = types.Item2;

var selectManyMethod = typeof(Queryable).GetMethods()
.First(m =>
{
var parameters = m.GetParameters();
if (m.Name != "SelectMany" || parameters.Length != 2) return false;

var p1 = parameters[1].ParameterType;

return p1.GetGenericTypeDefinition() == typeof(Expression<>)
&& p1.GetGenericArguments()[0].GetGenericTypeDefinition() == typeof(Func<,>);
})
.MakeGenericMethod(ownerType, itemType);

var ownerQueryableParameter = Expression.Parameter(typeof(object));
var collectionSelectorParameter = Expression.Parameter(typeof(object));

// Build the type "Expression<Func<TOwner, IEnumerable<TItem>>"
var selectorType = typeof(Expression<>).MakeGenericType(
GetCollectionSelectorType(ownerType, itemType)
);

var body = Expression.Call(null, selectManyMethod,
Expression.Convert(ownerQueryableParameter, typeof(IQueryable<>).MakeGenericType(ownerType)),
Expression.Convert(collectionSelectorParameter, selectorType)
);

return Expression.Lambda<SelectManyFunc>(body, ownerQueryableParameter, collectionSelectorParameter)
.Compile();
}

public static IQueryable<T> Query<T>(this ICollection<T> source, ISessionImplementor session = null)
{
var persistentCollection = source as IPersistentCollection;
if (persistentCollection == null || persistentCollection.WasInitialized)
return source.AsQueryable();

if (session == null)
session = (ISessionImplementor) typeof (AbstractPersistentCollection)
.GetProperty("Session",
BindingFlags.Instance | BindingFlags.NonPublic)
.GetValue(persistentCollection, null);
var queryMethod = typeof(LinqExtensionMethods)
.GetMethod("Query",
new[]
{
session is ISession
? typeof (ISession)
: typeof (IStatelessSession)
});
session = sessionGetter(persistentCollection);

var ownerProxy = persistentCollection.Owner as INHibernateProxy;
var ownerType = ownerProxy == null
? persistentCollection.Owner.GetType()
: ownerProxy.HibernateLazyInitializer.PersistentClass;
var ownerParameter = Expression.Parameter(ownerType);
var collectionPropertyName = persistentCollection.Role.Split('.').Last();
var selectMany = typeof(Queryable).GetMethods()
.First(x => x.Name == "SelectMany")
.MakeGenericMethod(ownerType, typeof(T));
dynamic predicate = Expression.Lambda(Expression.Equal(ownerParameter,
Expression.Constant(persistentCollection.Owner,
ownerType)),
ownerParameter);
var collectionSelector = Expression.Lambda(typeof(Func<,>)
.MakeGenericType(ownerType,
typeof(IEnumerable<>)
.MakeGenericType(typeof(T))),
var collectionSelector = Expression.Lambda(GetCollectionSelectorType(ownerType, typeof(T)),
Expression.Property(ownerParameter, collectionPropertyName),
ownerParameter);
dynamic ownerQueryable = queryMethod.MakeGenericMethod(ownerType).Invoke(null, new object[] { session });
var sessionType = session is ISession ? typeof(ISession) : typeof(IStatelessSession);
var queryableGetter = sessionQueryableGetters.GetOrAdd(Tuple.Create(sessionType, ownerType), CreateSessionQueryableGetter);
dynamic ownerQueryable = queryableGetter(session);
var ownerQuery = Queryable.Where(ownerQueryable, predicate);
var elementsQuery = selectMany.Invoke(null, new object[] { ownerQuery, collectionSelector });

var selectMany = selectManyGetters.GetOrAdd(Tuple.Create(ownerType, typeof(T)), CreateSelectManyGetter);
var elementsQuery = selectMany(ownerQuery, collectionSelector);
return (IQueryable<T>)elementsQuery;
}

private static Type GetCollectionSelectorType(Type ownerType, Type itemType)
{
// Build the type "Func<TOwner, IEnumerable<TItem>"
return typeof(Func<,>)
.MakeGenericType(ownerType,
typeof(IEnumerable<>).MakeGenericType(itemType)
);
}
}
}
4 changes: 2 additions & 2 deletions CollectionQuery/PersistentQueryableSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public PersistentQueryableSet(ISessionImplementor sessionImplementor)
{
}

public PersistentQueryableSet(ISessionImplementor sessionImplementor, ICollection<T> original)
: base(sessionImplementor, original as Iesi.Collections.Generic.ISet<T> ?? new HashedSet<T>(original))
public PersistentQueryableSet(ISessionImplementor sessionImplementor, Iesi.Collections.Generic.ISet<T> original)
: base(sessionImplementor, original)
{
}

Expand Down
9 changes: 8 additions & 1 deletion CollectionQuery/PersistentQueryableSetType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ public IPersistentCollection Instantiate(ISessionImplementor session, ICollectio

public IPersistentCollection Wrap(ISessionImplementor session, object collection)
{
return new PersistentQueryableSet<T>(session, (ICollection<T>) collection);
var set = collection as Iesi.Collections.Generic.ISet<T>;

if (set == null)
{
set = new HashedSet<T>((ICollection<T>)collection);
}

return new PersistentQueryableSet<T>(session, set);
}

public IEnumerable GetElements(object collection)
Expand Down
6 changes: 4 additions & 2 deletions CollectionQuery/Test/NHibernate.CollectionQuery.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@
<Reference Include="System.Core" />
<Reference Include="Microsoft.CSharp" />
<Reference Include="System.Data" />
<Reference Include="System.Data.SQLite">
<Reference Include="System.Data.SQLite, Version=1.0.85.0, Culture=neutral, PublicKeyToken=db937bc2d44ff139, processorArchitecture=MSIL">
<SpecificVersion>False</SpecificVersion>
<HintPath>..\..\packages\System.Data.SQLite.1.0.85.0\lib\net40\System.Data.SQLite.dll</HintPath>
</Reference>
<Reference Include="System.Data.SQLite.Linq">
<Reference Include="System.Data.SQLite.Linq, Version=1.0.85.0, Culture=neutral, PublicKeyToken=db937bc2d44ff139, processorArchitecture=MSIL">
<SpecificVersion>False</SpecificVersion>
<HintPath>..\..\packages\System.Data.SQLite.1.0.85.0\lib\net40\System.Data.SQLite.Linq.dll</HintPath>
</Reference>
<Reference Include="System.Xml" />
Expand Down
24 changes: 24 additions & 0 deletions CollectionQuery/Test/QueryableCollectionsFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using NHibernate.Dialect;
using NHibernate.Mapping.ByCode;
using NHibernate.Tool.hbm2ddl;
using NHibernate.Linq;
using NUnit.Framework;

namespace NHibernate.CollectionQuery.Test
Expand Down Expand Up @@ -129,5 +130,28 @@ public void AlreadyInitializedCollectionsAreQueriedInMemory()
Assert.AreEqual(0, sessionFactory.Statistics.QueryExecutionCount,
"unexpected query execution"));
}

[Test]
public void UsingQueryExtensionMethod()
{
using (var session = sessionFactory.OpenSession())
{
var foo = session.Get<Foo>(id);
var bar = foo.Bars.Query().SingleOrDefault(b => b.Data == 2);
Assert.AreEqual(2, bar.Data, "invalid element retrieved");
}
}

// In order to duplicate #1 - run this test by itself.
[Test]
public void PreventSelectingWrongSelectManyQueryableMethod()
{
using (var session = sessionFactory.OpenSession())
{
var foo = session.Query<Foo>().FirstOrDefault();
var bar = foo.Bars.AsQueryable().Where(b => b.Data != 0).FirstOrDefault();
Assert.IsNotNull(bar, "no element retrieved");
}
}
}
}