Skip to content

Unwrap RCWs that are passed to Marshal.GetIUnknownForObject when using the global marshalling ComWrappers instance #115436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,14 @@ internal static int CallICustomQueryInterface(ManagedObjectWrapperHolder holder,

internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInstance(object obj)
{
if (s_globalInstanceForMarshalling == null)
{
return IntPtr.Zero;
}

try
{
return s_globalInstanceForMarshalling is null
? IntPtr.Zero
: s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.TrackerSupport);
return ComInterfaceForObject(obj);
}
catch (ArgumentException)
{
Expand All @@ -74,9 +77,14 @@ internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInst

internal static object? GetOrCreateObjectForComInstanceWithGlobalMarshallingInstance(IntPtr comObject, CreateObjectFlags flags)
{
if (s_globalInstanceForMarshalling == null)
{
return null;
}

try
{
return s_globalInstanceForMarshalling?.GetOrCreateObjectForComInstance(comObject, flags);
return ComObjectForInterface(comObject, flags);
}
catch (ArgumentNullException)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ public static object ConvertNativeComInterfaceToManaged(IntPtr pUnk)

#if TARGET_WINDOWS
#pragma warning disable CA1416
return ComWrappers.ComObjectForInterface(pUnk);
return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
#pragma warning restore CA1416
#else
throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public static object GetTypedObjectForIUnknown(IntPtr pUnk, Type t)
[SupportedOSPlatform("windows")]
public static object GetObjectForIUnknown(IntPtr pUnk)
{
return ComWrappers.ComObjectForInterface(pUnk);
return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
}

[SupportedOSPlatform("windows")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,13 +743,17 @@ public void DisconnectTracker()

internal static object? GetOrCreateObjectFromWrapper(ComWrappers wrapper, IntPtr externalComObject)
{
if (s_globalInstanceForTrackerSupport != null && s_globalInstanceForTrackerSupport == wrapper)
if (wrapper is null)
{
return null;
}
if (s_globalInstanceForTrackerSupport == wrapper)
{
return s_globalInstanceForTrackerSupport.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject);
}
else if (s_globalInstanceForMarshalling != null && s_globalInstanceForMarshalling == wrapper)
else if (s_globalInstanceForMarshalling == wrapper)
{
return ComObjectForInterface(externalComObject);
return ComObjectForInterface(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
}
else
{
Expand Down Expand Up @@ -1404,7 +1408,12 @@ internal static IntPtr ComInterfaceForObject(object instance)
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
}

return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.None);
if (TryGetComInstance(instance, out IntPtr comObject))
{
return comObject;
}

return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.TrackerSupport);
}

internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid targetIID)
Expand All @@ -1423,15 +1432,14 @@ internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid target
return comObjectInterface;
}

internal static object ComObjectForInterface(IntPtr externalComObject)
internal static object ComObjectForInterface(IntPtr externalComObject, CreateObjectFlags flags)
{
if (s_globalInstanceForMarshalling == null)
{
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
}

// TrackerObject support and unwrapping matches the built-in semantics that the global marshalling scenario mimics.
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, flags);
}

internal static IntPtr GetOrCreateTrackerTarget(IntPtr externalComObject)
Expand Down
34 changes: 29 additions & 5 deletions src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ extern public static int UpdateTestObjectAsInterface(

private const string ManagedServerTypeName = "ConsumeNETServerTesting";

private const string IID_IUNKNOWN = "00000000-0000-0000-C000-000000000046";
private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";
private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90";
class TestEx : Test
Expand Down Expand Up @@ -278,7 +279,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
var testObj = new Test();
IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj);
Assert.NotEqual(IntPtr.Zero, comWrapper1);
Assert.Equal(testObj, registeredWrapper.LastComputeVtablesObject);
Assert.Same(testObj, registeredWrapper.LastComputeVtablesObject);

IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj);
Assert.Equal(comWrapper1, comWrapper2);
Expand All @@ -295,7 +296,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
var dispatchObj = new TestEx(IID_IDISPATCH);
IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj);
Assert.NotEqual(IntPtr.Zero, dispatchWrapper);
Assert.Equal(dispatchObj, registeredWrapper.LastComputeVtablesObject);
Assert.Same(dispatchObj, registeredWrapper.LastComputeVtablesObject);

Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject...");
IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj);
Expand All @@ -309,7 +310,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw);
Assert.Equal(validateUseRegistered, objWrapper1 is FakeWrapper);
object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw);
Assert.Equal(objWrapper1, objWrapper2);
Assert.Same(objWrapper1, objWrapper2);

Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown...");

Expand All @@ -319,6 +320,29 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
Assert.NotEqual(objWrapper1, objWrapper3);

Marshal.Release(trackerObjRaw);

if (validateUseRegistered)
{
Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown and Marshal.GetIUnknownForObject unwrapping...");
// Validate that the object returned by Marshal.GetObjectForIUnknown is the same as the original object passed to
// Marshal.GetIUnknownForObject.
IntPtr comWrapper3 = Marshal.GetIUnknownForObject(testObj);
object unwrappedObj = Marshal.GetObjectForIUnknown(comWrapper3);
Assert.Same(testObj, unwrappedObj);

// Validate that the pointer returned by Marshal.GetIUnknownForObject is the same one that was passed into
// Marshal.GetObjectForIUnknown.
IntPtr trackerObj2 = MockReferenceTrackerRuntime.CreateTrackerObject();
Marshal.ThrowExceptionForHR(Marshal.QueryInterface(trackerObj2, Guid.Parse(IID_IUNKNOWN), out IntPtr trackerObj2Identity));
Marshal.Release(trackerObj2);

object trackerObjectWrapper = Marshal.GetObjectForIUnknown(trackerObj2);
IntPtr trackerObjUnwrapped = Marshal.GetIUnknownForObject(trackerObjectWrapper);
Assert.Equal(trackerObj2Identity, trackerObjUnwrapped);

Marshal.Release(trackerObj2Identity);
Marshal.Release(trackerObjUnwrapped);
}
}

private static void ValidatePInvokes(bool validateUseRegistered)
Expand Down Expand Up @@ -362,12 +386,12 @@ private static void ValidateInterfaceMarshaler<T>(UpdateTestObject<T> func, bool

T retObj;
int hr = func(testObj as T, value, out retObj);
Assert.Equal(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
Assert.Same(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
if (shouldSucceed)
{
Assert.True(retObj is Test);
Assert.Equal(value, testObj.GetValue());
Assert.Equal<object>(testObj, retObj);
Assert.Same(testObj, retObj);
}
else
{
Expand Down
Loading