diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs index 5f89431729f0aa..58340f73800701 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs @@ -58,11 +58,14 @@ internal static int CallICustomQueryInterface(ManagedObjectWrapperHolder holder, internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInstance(object obj) { + if (s_globalInstanceForMarshalling == null) + { + return IntPtr.Zero; + } + try { - return s_globalInstanceForMarshalling is null - ? IntPtr.Zero - : s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.TrackerSupport); + return ComInterfaceForObject(obj); } catch (ArgumentException) { @@ -74,9 +77,14 @@ internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInst internal static object? GetOrCreateObjectForComInstanceWithGlobalMarshallingInstance(IntPtr comObject, CreateObjectFlags flags) { + if (s_globalInstanceForMarshalling == null) + { + return null; + } + try { - return s_globalInstanceForMarshalling?.GetOrCreateObjectForComInstance(comObject, flags); + return ComObjectForInterface(comObject, flags); } catch (ArgumentNullException) { diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs index 195ea214189e3a..5c84977c13d965 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs @@ -499,7 +499,7 @@ public static object ConvertNativeComInterfaceToManaged(IntPtr pUnk) #if TARGET_WINDOWS #pragma warning disable CA1416 - return ComWrappers.ComObjectForInterface(pUnk); + return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap); #pragma warning restore CA1416 #else throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop); diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.Com.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.Com.cs index 57bab379b3a02c..944618e17b72a8 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.Com.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.Com.cs @@ -295,7 +295,7 @@ public static object GetTypedObjectForIUnknown(IntPtr pUnk, Type t) [SupportedOSPlatform("windows")] public static object GetObjectForIUnknown(IntPtr pUnk) { - return ComWrappers.ComObjectForInterface(pUnk); + return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap); } [SupportedOSPlatform("windows")] diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 639f72b2d93416..d21499a27f14bb 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -743,13 +743,17 @@ public void DisconnectTracker() internal static object? GetOrCreateObjectFromWrapper(ComWrappers wrapper, IntPtr externalComObject) { - if (s_globalInstanceForTrackerSupport != null && s_globalInstanceForTrackerSupport == wrapper) + if (wrapper is null) + { + return null; + } + if (s_globalInstanceForTrackerSupport == wrapper) { return s_globalInstanceForTrackerSupport.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject); } - else if (s_globalInstanceForMarshalling != null && s_globalInstanceForMarshalling == wrapper) + else if (s_globalInstanceForMarshalling == wrapper) { - return ComObjectForInterface(externalComObject); + return ComObjectForInterface(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap); } else { @@ -1404,7 +1408,12 @@ internal static IntPtr ComInterfaceForObject(object instance) throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance); } - return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.None); + if (TryGetComInstance(instance, out IntPtr comObject)) + { + return comObject; + } + + return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.TrackerSupport); } internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid targetIID) @@ -1423,15 +1432,14 @@ internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid target return comObjectInterface; } - internal static object ComObjectForInterface(IntPtr externalComObject) + internal static object ComObjectForInterface(IntPtr externalComObject, CreateObjectFlags flags) { if (s_globalInstanceForMarshalling == null) { throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance); } - // TrackerObject support and unwrapping matches the built-in semantics that the global marshalling scenario mimics. - return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap); + return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, flags); } internal static IntPtr GetOrCreateTrackerTarget(IntPtr externalComObject) diff --git a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs index a2c91318c0ce92..eaa93ee94fbd86 100644 --- a/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs +++ b/src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs @@ -49,6 +49,7 @@ extern public static int UpdateTestObjectAsInterface( private const string ManagedServerTypeName = "ConsumeNETServerTesting"; + private const string IID_IUNKNOWN = "00000000-0000-0000-C000-000000000046"; private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046"; private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90"; class TestEx : Test @@ -278,7 +279,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) var testObj = new Test(); IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj); Assert.NotEqual(IntPtr.Zero, comWrapper1); - Assert.Equal(testObj, registeredWrapper.LastComputeVtablesObject); + Assert.Same(testObj, registeredWrapper.LastComputeVtablesObject); IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj); Assert.Equal(comWrapper1, comWrapper2); @@ -295,7 +296,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) var dispatchObj = new TestEx(IID_IDISPATCH); IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj); Assert.NotEqual(IntPtr.Zero, dispatchWrapper); - Assert.Equal(dispatchObj, registeredWrapper.LastComputeVtablesObject); + Assert.Same(dispatchObj, registeredWrapper.LastComputeVtablesObject); Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject..."); IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj); @@ -309,7 +310,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw); Assert.Equal(validateUseRegistered, objWrapper1 is FakeWrapper); object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw); - Assert.Equal(objWrapper1, objWrapper2); + Assert.Same(objWrapper1, objWrapper2); Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown..."); @@ -319,6 +320,29 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) Assert.NotEqual(objWrapper1, objWrapper3); Marshal.Release(trackerObjRaw); + + if (validateUseRegistered) + { + Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown and Marshal.GetIUnknownForObject unwrapping..."); + // Validate that the object returned by Marshal.GetObjectForIUnknown is the same as the original object passed to + // Marshal.GetIUnknownForObject. + IntPtr comWrapper3 = Marshal.GetIUnknownForObject(testObj); + object unwrappedObj = Marshal.GetObjectForIUnknown(comWrapper3); + Assert.Same(testObj, unwrappedObj); + + // Validate that the pointer returned by Marshal.GetIUnknownForObject is the same one that was passed into + // Marshal.GetObjectForIUnknown. + IntPtr trackerObj2 = MockReferenceTrackerRuntime.CreateTrackerObject(); + Marshal.ThrowExceptionForHR(Marshal.QueryInterface(trackerObj2, Guid.Parse(IID_IUNKNOWN), out IntPtr trackerObj2Identity)); + Marshal.Release(trackerObj2); + + object trackerObjectWrapper = Marshal.GetObjectForIUnknown(trackerObj2); + IntPtr trackerObjUnwrapped = Marshal.GetIUnknownForObject(trackerObjectWrapper); + Assert.Equal(trackerObj2Identity, trackerObjUnwrapped); + + Marshal.Release(trackerObj2Identity); + Marshal.Release(trackerObjUnwrapped); + } } private static void ValidatePInvokes(bool validateUseRegistered) @@ -362,12 +386,12 @@ private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool T retObj; int hr = func(testObj as T, value, out retObj); - Assert.Equal(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject); + Assert.Same(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject); if (shouldSucceed) { Assert.True(retObj is Test); Assert.Equal(value, testObj.GetValue()); - Assert.Equal(testObj, retObj); + Assert.Same(testObj, retObj); } else {