Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ namespace System.Threading
{
public sealed partial class Thread
{
[ThreadStatic]
private static ApartmentType t_apartmentType;

[ThreadStatic]
private static ComState t_comState;

Expand Down Expand Up @@ -334,14 +331,15 @@ public ApartmentState GetApartmentState()
return _initialApartmentState;
}

switch (GetCurrentApartmentType())
switch (GetCurrentApartmentState())
{
case ApartmentType.STA:
case ApartmentState.STA:
return ApartmentState.STA;
case ApartmentType.MTA:
case ApartmentState.MTA:
return ApartmentState.MTA;
default:
return ApartmentState.Unknown;
// If COM is uninitialized on the current thread, it is assumed to be implicit MTA.
return ApartmentState.MTA;
}
}

Expand Down Expand Up @@ -374,14 +372,29 @@ private bool SetApartmentStateUnchecked(ApartmentState state, bool throwOnError)
}
else
{
// Compat: Setting ApartmentState to Unknown uninitializes COM
UninitializeCom();
}
}

// Clear the cache and check whether new state matches the desired state
t_apartmentType = ApartmentType.Unknown;
// Clear the cache and check whether new state matches the desired state
t_comState &= ~(ComState.STA | ComState.MTA);

retState = GetApartmentState();
retState = GetCurrentApartmentState();
}
else
{
Debug.Assert((t_comState & ComState.MTA) != 0);
retState = ApartmentState.MTA;
}
}

// Special case where we pass in Unknown and get back MTA.
// Once we CoUninitialize the thread, the OS will still
// report the thread as implicitly in the MTA if any
// other thread in the process is CoInitialized.
if ((state == ApartmentState.Unknown) && (retState == ApartmentState.MTA))
{
return true;
}

if (retState != state)
Expand Down Expand Up @@ -415,7 +428,7 @@ private static void InitializeComForThreadPoolThread()
// Process-wide COM is initialized very early before any managed code can run.
// Assume it is done.
// Prevent re-initialization of COM model on threadpool threads from the default one.
t_comState |= ComState.Locked;
t_comState |= ComState.Locked | ComState.MTA;
}

private static void InitializeCom(ApartmentState state = ApartmentState.MTA)
Expand Down Expand Up @@ -527,49 +540,53 @@ internal static void CheckForPendingInterrupt()
}

internal static bool ReentrantWaitsEnabled =>
GetCurrentApartmentType() == ApartmentType.STA;
GetCurrentApartmentState() == ApartmentState.STA;

internal static ApartmentType GetCurrentApartmentType()
// Unlike the public API, this returns ApartmentState.Unknown when COM is uninitialized on the current thread
internal static ApartmentState GetCurrentApartmentState()
{
ApartmentType currentThreadType = t_apartmentType;
if (currentThreadType != ApartmentType.Unknown)
return currentThreadType;
if ((t_comState & (ComState.MTA | ComState.STA)) != 0)
return ((t_comState & ComState.STA) != 0) ? ApartmentState.STA : ApartmentState.MTA;

Interop.APTTYPE aptType;
Interop.APTTYPEQUALIFIER aptTypeQualifier;
int result = Interop.Ole32.CoGetApartmentType(out aptType, out aptTypeQualifier);

ApartmentType type = ApartmentType.Unknown;
ApartmentState state = ApartmentState.Unknown;

switch (result)
{
case HResults.CO_E_NOTINITIALIZED:
type = ApartmentType.None;
Debug.Fail("COM is not initialized");
state = ApartmentState.Unknown;
break;

case HResults.S_OK:
switch (aptType)
{
case Interop.APTTYPE.APTTYPE_STA:
case Interop.APTTYPE.APTTYPE_MAINSTA:
type = ApartmentType.STA;
state = ApartmentState.STA;
break;

case Interop.APTTYPE.APTTYPE_MTA:
type = ApartmentType.MTA;
state = ApartmentState.MTA;
break;

case Interop.APTTYPE.APTTYPE_NA:
switch (aptTypeQualifier)
{
case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_MTA:
state = ApartmentState.MTA;
break;

case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA:
type = ApartmentType.MTA;
state = ApartmentState.Unknown;
break;

case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_STA:
case Interop.APTTYPEQUALIFIER.APTTYPEQUALIFIER_NA_ON_MAINSTA:
type = ApartmentType.STA;
state = ApartmentState.STA;
break;

default:
Expand All @@ -585,24 +602,18 @@ internal static ApartmentType GetCurrentApartmentType()
break;
}

if (type != ApartmentType.Unknown)
t_apartmentType = type;
return type;
}

internal enum ApartmentType : byte
{
Unknown = 0,
None,
STA,
MTA
if (state != ApartmentState.Unknown)
t_comState |= (state == ApartmentState.STA) ? ComState.STA : ComState.MTA;
return state;
}

[Flags]
internal enum ComState : byte
{
InitializedByUs = 1,
Locked = 2,
MTA = 4,
STA = 8
}
}
}
14 changes: 13 additions & 1 deletion src/libraries/System.Threading.Thread/tests/ThreadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,20 @@ public static void GetSetApartmentStateTest_ChangeAfterThreadStarted_Windows(
Assert.Equal(ApartmentState.MTA, getApartmentState(t));
Assert.Equal(0, setApartmentState(t, ApartmentState.MTA));
Assert.Equal(ApartmentState.MTA, getApartmentState(t));
Assert.Equal(setType == 0 ? 0 : 2, setApartmentState(t, ApartmentState.STA)); // cannot be changed after thread is started
Assert.Equal(setType == 0 ? 0 : 2, setApartmentState(t, ApartmentState.STA)); // MTA<->STA cannot be changed directly after thread is started
Assert.Equal(ApartmentState.MTA, getApartmentState(t));

if (!PlatformDetection.IsWindowsNanoServer)
{
Assert.Equal(0, setApartmentState(t, ApartmentState.Unknown)); // Compat quirk: MTA<->STA can be changed by going through Unknown
Assert.Equal(ApartmentState.MTA, getApartmentState(t));
Assert.Equal(0, setApartmentState(t, ApartmentState.STA));
Assert.Equal(ApartmentState.STA, getApartmentState(t));
Assert.Equal(0, setApartmentState(t, ApartmentState.Unknown));
Assert.Equal(ApartmentState.MTA, getApartmentState(t));
Assert.Equal(0, setApartmentState(t, ApartmentState.MTA));
Assert.Equal(ApartmentState.MTA, getApartmentState(t));
}
});
}

Expand Down
Loading