diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Thread.NativeAot.Windows.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Thread.NativeAot.Windows.cs index e5e691fe6094c2..fd5183d74a99a4 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Thread.NativeAot.Windows.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Thread.NativeAot.Windows.cs @@ -14,9 +14,6 @@ namespace System.Threading { public sealed partial class Thread { - [ThreadStatic] - private static ApartmentType t_apartmentType; - [ThreadStatic] private static ComState t_comState; @@ -334,14 +331,15 @@ public ApartmentState GetApartmentState() return _initialApartmentState; } - switch (GetCurrentApartmentType()) + switch (GetCurrentApartmentState()) { - case ApartmentType.STA: + case ApartmentState.STA: return ApartmentState.STA; - case ApartmentType.MTA: + case ApartmentState.MTA: return ApartmentState.MTA; default: - return ApartmentState.Unknown; + // If COM is uninitialized on the current thread, it is assumed to be implicit MTA. + return ApartmentState.MTA; } } @@ -374,14 +372,29 @@ private bool SetApartmentStateUnchecked(ApartmentState state, bool throwOnError) } else { + // Compat: Setting ApartmentState to Unknown uninitializes COM UninitializeCom(); } - } - // Clear the cache and check whether new state matches the desired state - t_apartmentType = ApartmentType.Unknown; + // Clear the cache and check whether new state matches the desired state + t_comState &= ~(ComState.STA | ComState.MTA); - retState = GetApartmentState(); + retState = GetCurrentApartmentState(); + } + else + { + Debug.Assert((t_comState & ComState.MTA) != 0); + retState = ApartmentState.MTA; + } + } + + // Special case where we pass in Unknown and get back MTA. + // Once we CoUninitialize the thread, the OS will still + // report the thread as implicitly in the MTA if any + // other thread in the process is CoInitialized. + if ((state == ApartmentState.Unknown) && (retState == ApartmentState.MTA)) + { + return true; } if (retState != state) @@ -415,7 +428,7 @@ private static void InitializeComForThreadPoolThread() // Process-wide COM is initialized very early before any managed code can run. // Assume it is done. // Prevent re-initialization of COM model on threadpool threads from the default one. - t_comState |= ComState.Locked; + t_comState |= ComState.Locked | ComState.MTA; } private static void InitializeCom(ApartmentState state = ApartmentState.MTA) @@ -527,24 +540,25 @@ internal static void CheckForPendingInterrupt() } internal static bool ReentrantWaitsEnabled => - GetCurrentApartmentType() == ApartmentType.STA; + GetCurrentApartmentState() == ApartmentState.STA; - internal static ApartmentType GetCurrentApartmentType() + // Unlike the public API, this returns ApartmentState.Unknown when COM is uninitialized on the current thread + internal static ApartmentState GetCurrentApartmentState() { - ApartmentType currentThreadType = t_apartmentType; - if (currentThreadType != ApartmentType.Unknown) - return currentThreadType; + if ((t_comState & (ComState.MTA | ComState.STA)) != 0) + return ((t_comState & ComState.STA) != 0) ? ApartmentState.STA : ApartmentState.MTA; Interop.APTTYPE aptType; Interop.APTTYPEQUALIFIER aptTypeQualifier; int result = Interop.Ole32.CoGetApartmentType(out aptType, out aptTypeQualifier); - ApartmentType type = ApartmentType.Unknown; + ApartmentState state = ApartmentState.Unknown; switch (result) { case HResults.CO_E_NOTINITIALIZED: - type = ApartmentType.None; + Debug.Fail("COM is not initialized"); + state = ApartmentState.Unknown; break; case HResults.S_OK: @@ -552,24 +566,27 @@ internal static ApartmentType GetCurrentApartmentType() { case Interop.APTTYPE.APTTYPE_STA: case Interop.APTTYPE.APTTYPE_MAINSTA: - type = ApartmentType.STA; + state = ApartmentState.STA; break; case Interop.APTTYPE.APTTYPE_MTA: - type = ApartmentType.MTA; + state = ApartmentState.MTA; break; case Interop.APTTYPE.APTTYPE_NA: switch (aptTypeQualifier) { case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_MTA: + state = ApartmentState.MTA; + break; + case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA: - type = ApartmentType.MTA; + state = ApartmentState.Unknown; break; case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_STA: case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_MAINSTA: - type = ApartmentType.STA; + state = ApartmentState.STA; break; default: @@ -585,17 +602,9 @@ internal static ApartmentType GetCurrentApartmentType() break; } - if (type != ApartmentType.Unknown) - t_apartmentType = type; - return type; - } - - internal enum ApartmentType : byte - { - Unknown = 0, - None, - STA, - MTA + if (state != ApartmentState.Unknown) + t_comState |= (state == ApartmentState.STA) ? ComState.STA : ComState.MTA; + return state; } [Flags] @@ -603,6 +612,8 @@ internal enum ComState : byte { InitializedByUs = 1, Locked = 2, + MTA = 4, + STA = 8 } } } diff --git a/src/libraries/System.Threading.Thread/tests/ThreadTests.cs b/src/libraries/System.Threading.Thread/tests/ThreadTests.cs index 972056094a5fff..15390f0dbed35e 100644 --- a/src/libraries/System.Threading.Thread/tests/ThreadTests.cs +++ b/src/libraries/System.Threading.Thread/tests/ThreadTests.cs @@ -247,8 +247,20 @@ public static void GetSetApartmentStateTest_ChangeAfterThreadStarted_Windows( Assert.Equal(ApartmentState.MTA, getApartmentState(t)); Assert.Equal(0, setApartmentState(t, ApartmentState.MTA)); Assert.Equal(ApartmentState.MTA, getApartmentState(t)); - Assert.Equal(setType == 0 ? 0 : 2, setApartmentState(t, ApartmentState.STA)); // cannot be changed after thread is started + Assert.Equal(setType == 0 ? 0 : 2, setApartmentState(t, ApartmentState.STA)); // MTA<->STA cannot be changed directly after thread is started Assert.Equal(ApartmentState.MTA, getApartmentState(t)); + + if (!PlatformDetection.IsWindowsNanoServer) + { + Assert.Equal(0, setApartmentState(t, ApartmentState.Unknown)); // Compat quirk: MTA<->STA can be changed by going through Unknown + Assert.Equal(ApartmentState.MTA, getApartmentState(t)); + Assert.Equal(0, setApartmentState(t, ApartmentState.STA)); + Assert.Equal(ApartmentState.STA, getApartmentState(t)); + Assert.Equal(0, setApartmentState(t, ApartmentState.Unknown)); + Assert.Equal(ApartmentState.MTA, getApartmentState(t)); + Assert.Equal(0, setApartmentState(t, ApartmentState.MTA)); + Assert.Equal(ApartmentState.MTA, getApartmentState(t)); + } }); }