diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index da5a2d1062..814ec8165b 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -931,6 +931,9 @@
+
+
+
Microsoft.Data.SqlClient.SqlMetaData.xml
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs
index 395cfed4be..4523908979 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs
@@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.Data.SqlClient.SNI
{
@@ -16,16 +17,20 @@ internal class SNIMarsConnection
{
private const string s_className = nameof(SNIMarsConnection);
+ private static QueuedTaskScheduler s_scheduler;
+
private readonly Guid _connectionId = Guid.NewGuid();
private readonly Dictionary _sessions = new Dictionary();
private readonly byte[] _headerBytes = new byte[SNISMUXHeader.HEADER_LENGTH];
private readonly SNISMUXHeader _currentHeader = new SNISMUXHeader();
private SNIHandle _lowerHandle;
- private ushort _nextSessionId = 0;
+ private int _nextSessionId;
private int _currentHeaderByteCount = 0;
private int _dataBytesLeft = 0;
private SNIPacket _currentPacket;
+
+
///
/// Connection ID
///
@@ -45,6 +50,8 @@ public Guid ConnectionId
/// Lower handle
public SNIMarsConnection(SNIHandle lowerHandle)
{
+
+ _nextSessionId = -1;
_lowerHandle = lowerHandle;
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Created MARS Session Id {0}", args0: ConnectionId);
_lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete);
@@ -54,7 +61,8 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
{
lock (this)
{
- ushort sessionId = _nextSessionId++;
+ ushort sessionId = unchecked((ushort)(Interlocked.Increment(ref _nextSessionId) % ushort.MaxValue));
+
SNIMarsHandle handle = new SNIMarsHandle(this, sessionId, callbackObject, async);
_sessions.Add(sessionId, handle);
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, SNI MARS Handle Id {1}, created new MARS Session {2}", args0: ConnectionId, args1: handle?.ConnectionId, args2: sessionId);
@@ -68,25 +76,42 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
///
public uint StartReceive()
{
- long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className);
- try
+ using (TrySNIEventScope.Create(nameof(SNIMarsConnection)))
+ {
+ if (LocalAppContextSwitches.UseExperimentalMARSThreading
+//#if NETCOREAPP31_AND_ABOVE
+// && ThreadPool.PendingWorkItemCount > 0
+//#endif
+ )
+ {
+ LazyInitializer.EnsureInitialized(ref s_scheduler, () => new QueuedTaskScheduler(10, "MARSIOScheduler", useForegroundThreads: false, ThreadPriority.Normal));
+
+ // will start an async task on the scheduler and immediatley return so this await is safe
+ return s_scheduler.Factory.StartNew(StartAsyncReceiveLoopForConnection, this).GetAwaiter().GetResult();
+ }
+ else
+ {
+ return StartAsyncReceiveLoopForConnection(this);
+ }
+ }
+
+ static uint StartAsyncReceiveLoopForConnection(object state)
{
+ SNIMarsConnection connection = (SNIMarsConnection)state;
SNIPacket packet = null;
- if (ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING)
+ if (connection.ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING)
{
- SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: ConnectionId);
+ SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: connection.ConnectionId);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
- SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: ConnectionId);
+ SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: connection.ConnectionId);
return SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, 0, SNICommon.ConnNotUsableError, Strings.SNI_ERROR_19);
- }
- finally
- {
- SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
- }
+ };
}
+
+
///
/// Send a packet synchronously
///
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs
index 58ac68c7c4..8120e346b9 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs
@@ -282,7 +282,7 @@ public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
state: callback,
CancellationToken.None,
TaskContinuationOptions.DenyChildAttach,
- TaskScheduler.Default
+ TaskScheduler.Current // specifically continue on the current scheduler because we may override it for mars
);
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs
new file mode 100644
index 0000000000..7a56cbc88f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs
@@ -0,0 +1,177 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Data.SqlClient.SNI
+{
+ ///
+ /// Provides a TaskScheduler that provides control over priorities, fairness, and the underlying threads utilized.
+ ///
+ [DebuggerDisplay("Id={Id}, Queues={DebugQueueCount}, ScheduledTasks = {DebugTaskCount}")]
+ internal sealed class QueuedTaskScheduler : TaskScheduler, IDisposable
+ {
+ /// Cancellation token used for disposal.
+ private readonly CancellationTokenSource _disposeCancellation = new CancellationTokenSource();
+ ///
+ /// The maximum allowed concurrency level of this scheduler. If custom threads are
+ /// used, this represents the number of created threads.
+ ///
+ private readonly int _concurrencyLevel;
+ /// Whether we're processing tasks on the current thread.
+ private static readonly ThreadLocal s_taskProcessingThread = new ThreadLocal();
+
+ /// The threads used by the scheduler to process work.
+ private readonly Thread[] _threads;
+ /// The collection of tasks to be executed on our custom threads.
+ private readonly BlockingCollection _blockingTaskQueue;
+
+ private readonly TaskFactory _factory;
+
+ /// Initializes the scheduler.
+ /// The number of threads to create and use for processing work items.
+ public QueuedTaskScheduler(int threadCount) : this(threadCount, string.Empty, false, ThreadPriority.Normal, 0, null, null) { }
+
+ /// Initializes the scheduler.
+ /// The number of threads to create and use for processing work items.
+ /// The name to use for each of the created threads.
+ /// A Boolean value that indicates whether to use foreground threads instead of background.
+ /// The priority to assign to each thread.
+ /// The stack size to use for each thread.
+ /// An initialization routine to run on each thread.
+ /// A finalization routine to run on each thread.
+ public QueuedTaskScheduler(
+ int threadCount,
+ string threadName = "",
+ bool useForegroundThreads = false,
+ ThreadPriority threadPriority = ThreadPriority.Normal,
+ int threadMaxStackSize = 0,
+ Action threadInit = null,
+ Action threadFinally = null)
+ {
+ // Validates arguments (some validation is left up to the Thread type itself).
+ // If the thread count is 0, default to the number of logical processors.
+ if (threadCount < 0)
+ throw new ArgumentOutOfRangeException(nameof(threadCount));
+ else if (threadCount == 0)
+ _concurrencyLevel = Environment.ProcessorCount;
+ else
+ _concurrencyLevel = threadCount;
+
+ // Initialize the queue used for storing tasks
+ _blockingTaskQueue = new BlockingCollection();
+
+ // Create all of the threads
+ _threads = new Thread[threadCount];
+ for (int i = 0; i < threadCount; i++)
+ {
+ _threads[i] = new Thread(() => DispatchLoop(threadInit, threadFinally), threadMaxStackSize)
+ {
+ Priority = threadPriority,
+ IsBackground = !useForegroundThreads,
+ };
+ if (threadName != null)
+ _threads[i].Name = threadName + " (" + i + ")";
+ }
+
+ _factory = new TaskFactory(this);
+
+ // Start all of the threads
+ foreach (var thread in _threads)
+ thread.Start();
+ }
+
+ public TaskFactory Factory => _factory;
+
+ /// The dispatch loop run by all threads in this scheduler.
+ /// An initialization routine to run when the thread begins.
+ /// A finalization routine to run before the thread ends.
+ private void DispatchLoop(Action threadInit, Action threadFinally)
+ {
+ s_taskProcessingThread.Value = true;
+ threadInit?.Invoke();
+ try
+ {
+ // If the scheduler is disposed, the cancellation token will be set and
+ // we'll receive an OperationCanceledException. That OCE should not crash the process.
+ try
+ {
+ // If a thread abort occurs, we'll try to reset it and continue running.
+ while (true)
+ {
+ try
+ {
+ // For each task queued to the scheduler, try to execute it.
+ foreach (var task in _blockingTaskQueue.GetConsumingEnumerable(_disposeCancellation.Token))
+ {
+ // If the task is not null, that means it was queued to this scheduler directly.
+ // Run it.
+ if (task != null)
+ {
+ bool tried = TryExecuteTask(task);
+ }
+ }
+ }
+ catch (ThreadAbortException)
+ {
+ // If we received a thread abort, and that thread abort was due to shutting down
+ // or unloading, let it pass through. Otherwise, reset the abort so we can
+ // continue processing work items.
+ if (!Environment.HasShutdownStarted && !AppDomain.CurrentDomain.IsFinalizingForUnload())
+ {
+ Thread.ResetAbort();
+ }
+ }
+ }
+ }
+ catch (OperationCanceledException) { }
+ }
+ finally
+ {
+ // Run a cleanup routine if there was one
+ threadFinally?.Invoke();
+ s_taskProcessingThread.Value = false;
+ }
+ }
+
+ /// Queues a task to the scheduler.
+ /// The task to be queued.
+ protected override void QueueTask(Task task)
+ {
+ // If we've been disposed, no one should be queueing
+ if (_disposeCancellation.IsCancellationRequested)
+ {
+ throw new ObjectDisposedException(GetType().Name);
+ }
+ _blockingTaskQueue.Add(task);
+ }
+
+ /// Tries to execute a task synchronously on the current thread.
+ /// The task to execute.
+ /// Whether the task was previously queued.
+ /// true if the task was executed; otherwise, false.
+ protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) =>
+ // If we're already running tasks on this threads, enable inlining
+ false; // s_taskProcessingThread.Value && TryExecuteTask(task);
+
+ /// Gets the tasks scheduled to this scheduler.
+ /// An enumerable of all tasks queued to this scheduler.
+ /// This does not include the tasks on sub-schedulers. Those will be retrieved by the debugger separately.
+ protected override IEnumerable GetScheduledTasks()
+ {
+ // Get all of the tasks, filtering out nulls, which are just placeholders
+ // for tasks in other sub-schedulers
+ return _blockingTaskQueue.Where(t => t != null).ToList();
+ }
+
+ /// Gets the maximum concurrency level to use when processing tasks.
+ public override int MaximumConcurrencyLevel => _concurrencyLevel;
+
+ /// Initiates shutdown of the scheduler.
+ public void Dispose() => _disposeCancellation.Cancel();
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs
index 8e390b21d6..b366359bdd 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs
@@ -15,13 +15,16 @@ internal static partial class LocalAppContextSwitches
internal const string LegacyRowVersionNullString = @"Switch.Microsoft.Data.SqlClient.LegacyRowVersionNullBehavior";
internal const string UseSystemDefaultSecureProtocolsString = @"Switch.Microsoft.Data.SqlClient.UseSystemDefaultSecureProtocols";
internal const string SuppressInsecureTLSWarningString = @"Switch.Microsoft.Data.SqlClient.SuppressInsecureTLSWarning";
+ internal const string UseExperimentalMARSThreadingString = @"Switch.Microsoft.Data.SqlClient.UseExperimentalMARSThreading";
private static bool s_makeReadAsyncBlocking;
private static bool? s_LegacyRowVersionNullBehavior;
private static bool? s_UseSystemDefaultSecureProtocols;
- private static bool? s_SuppressInsecureTLSWarning;
+ private static bool? s_SuppressInsecureTLSWarning;
+ private static bool? s_useExperimentalMARSThreading;
-#if !NETFRAMEWORK
+
+#if NETCOREAPP31_AND_ABOVE
static LocalAppContextSwitches()
{
IAppContextSwitchOverridesSection appContextSwitch = AppConfigManager.FetchConfigurationSection(AppContextSwitchOverridesSection.Name);
@@ -95,5 +98,19 @@ public static bool UseSystemDefaultSecureProtocols
return s_UseSystemDefaultSecureProtocols.Value;
}
}
+
+ public static bool UseExperimentalMARSThreading
+ {
+ get
+ {
+ if (s_useExperimentalMARSThreading is null)
+ {
+ bool result;
+ result = AppContext.TryGetSwitch(UseExperimentalMARSThreadingString, out result) ? result : false;
+ s_useExperimentalMARSThreading = result;
+ }
+ return s_useExperimentalMARSThreading.Value;
+ }
+ }
}
}