Skip to content

Commit d197c7c

Browse files
committed
Unwrap RCWs when using the global marshalling instance via the COM Marshal APIs.
1 parent 1f225f3 commit d197c7c

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs

+12-4
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,14 @@ internal static int CallICustomQueryInterface(ManagedObjectWrapperHolder holder,
5858

5959
internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInstance(object obj)
6060
{
61+
if (s_globalInstanceForMarshalling == null)
62+
{
63+
return IntPtr.Zero;
64+
}
65+
6166
try
6267
{
63-
return s_globalInstanceForMarshalling is null
64-
? IntPtr.Zero
65-
: s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.TrackerSupport);
68+
return ComInterfaceForObject(obj);
6669
}
6770
catch (ArgumentException)
6871
{
@@ -74,9 +77,14 @@ internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInst
7477

7578
internal static object? GetOrCreateObjectForComInstanceWithGlobalMarshallingInstance(IntPtr comObject, CreateObjectFlags flags)
7679
{
80+
if (s_globalInstanceForMarshalling == null)
81+
{
82+
return IntPtr.Zero;
83+
}
84+
7785
try
7886
{
79-
return s_globalInstanceForMarshalling?.GetOrCreateObjectForComInstance(comObject, flags);
87+
return ComObjectForInterface(comObject, flags);
8088
}
8189
catch (ArgumentNullException)
8290
{

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,12 @@ internal static IntPtr ComInterfaceForObject(object instance)
14041404
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
14051405
}
14061406

1407-
return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.None);
1407+
if (TryGetComInstance(instance, out IntPtr comObject))
1408+
{
1409+
return comObject;
1410+
}
1411+
1412+
return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.TrackerSupport);
14081413
}
14091414

14101415
internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid targetIID)
@@ -1423,15 +1428,15 @@ internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid target
14231428
return comObjectInterface;
14241429
}
14251430

1426-
internal static object ComObjectForInterface(IntPtr externalComObject)
1431+
internal static object ComObjectForInterface(IntPtr externalComObject, CreateObjectFlags flags = CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap)
14271432
{
14281433
if (s_globalInstanceForMarshalling == null)
14291434
{
14301435
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
14311436
}
14321437

14331438
// TrackerObject support and unwrapping matches the built-in semantics that the global marshalling scenario mimics.
1434-
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
1439+
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, flags);
14351440
}
14361441

14371442
internal static IntPtr GetOrCreateTrackerTarget(IntPtr externalComObject)

src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs

+29-5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ extern public static int UpdateTestObjectAsInterface(
4949

5050
private const string ManagedServerTypeName = "ConsumeNETServerTesting";
5151

52+
private const string IID_IUNKNOWN = "00000000-0000-0000-C000-000000000046";
5253
private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";
5354
private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90";
5455
class TestEx : Test
@@ -278,7 +279,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
278279
var testObj = new Test();
279280
IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj);
280281
Assert.NotEqual(IntPtr.Zero, comWrapper1);
281-
Assert.Equal(testObj, registeredWrapper.LastComputeVtablesObject);
282+
Assert.Same(testObj, registeredWrapper.LastComputeVtablesObject);
282283

283284
IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj);
284285
Assert.Equal(comWrapper1, comWrapper2);
@@ -295,7 +296,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
295296
var dispatchObj = new TestEx(IID_IDISPATCH);
296297
IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj);
297298
Assert.NotEqual(IntPtr.Zero, dispatchWrapper);
298-
Assert.Equal(dispatchObj, registeredWrapper.LastComputeVtablesObject);
299+
Assert.Same(dispatchObj, registeredWrapper.LastComputeVtablesObject);
299300

300301
Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject...");
301302
IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj);
@@ -309,7 +310,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
309310
object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw);
310311
Assert.Equal(validateUseRegistered, objWrapper1 is FakeWrapper);
311312
object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw);
312-
Assert.Equal(objWrapper1, objWrapper2);
313+
Assert.Same(objWrapper1, objWrapper2);
313314

314315
Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown...");
315316

@@ -319,6 +320,29 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
319320
Assert.NotEqual(objWrapper1, objWrapper3);
320321

321322
Marshal.Release(trackerObjRaw);
323+
324+
if (validateUseRegistered)
325+
{
326+
Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown and Marshal.GetIUnknownForObject unwrapping...");
327+
// Validate that the object returned by Marshal.GetObjectForIUnknown is the same as the original object passed to
328+
// Marshal.GetIUnknownForObject.
329+
IntPtr comWrapper3 = Marshal.GetIUnknownForObject(testObj);
330+
object unwrappedObj = Marshal.GetObjectForIUnknown(comWrapper3);
331+
Assert.Same(testObj, unwrappedObj);
332+
333+
// Validate that the pointer returned by Marshal.GetIUnknownForObject is the same one that was passed into
334+
// Marshal.GetObjectForIUnknown.
335+
IntPtr trackerObj2 = MockReferenceTrackerRuntime.CreateTrackerObject();
336+
Marshal.ThrowExceptionForHR(Marshal.QueryInterface(trackerObj2, Guid.Parse(IID_IUNKNOWN), out IntPtr trackerObj2Identity));
337+
Marshal.Release(trackerObj2);
338+
339+
object trackerObjectWrapper = Marshal.GetObjectForIUnknown(trackerObj2);
340+
IntPtr trackerObjUnwrapped = Marshal.GetIUnknownForObject(trackerObjectWrapper);
341+
Assert.Equal(trackerObj2Identity, trackerObjUnwrapped);
342+
343+
Marshal.Release(trackerObj2Identity);
344+
Marshal.Release(trackerObjUnwrapped);
345+
}
322346
}
323347

324348
private static void ValidatePInvokes(bool validateUseRegistered)
@@ -362,12 +386,12 @@ private static void ValidateInterfaceMarshaler<T>(UpdateTestObject<T> func, bool
362386

363387
T retObj;
364388
int hr = func(testObj as T, value, out retObj);
365-
Assert.Equal(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
389+
Assert.Same(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
366390
if (shouldSucceed)
367391
{
368392
Assert.True(retObj is Test);
369393
Assert.Equal(value, testObj.GetValue());
370-
Assert.Equal<object>(testObj, retObj);
394+
Assert.Same(testObj, retObj);
371395
}
372396
else
373397
{

0 commit comments

Comments
 (0)