diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index da5a2d1062..814ec8165b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -931,6 +931,9 @@ + + + Microsoft.Data.SqlClient.SqlMetaData.xml diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs index 395cfed4be..4523908979 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Threading; +using System.Threading.Tasks; namespace Microsoft.Data.SqlClient.SNI { @@ -16,16 +17,20 @@ internal class SNIMarsConnection { private const string s_className = nameof(SNIMarsConnection); + private static QueuedTaskScheduler s_scheduler; + private readonly Guid _connectionId = Guid.NewGuid(); private readonly Dictionary _sessions = new Dictionary(); private readonly byte[] _headerBytes = new byte[SNISMUXHeader.HEADER_LENGTH]; private readonly SNISMUXHeader _currentHeader = new SNISMUXHeader(); private SNIHandle _lowerHandle; - private ushort _nextSessionId = 0; + private int _nextSessionId; private int _currentHeaderByteCount = 0; private int _dataBytesLeft = 0; private SNIPacket _currentPacket; + + /// /// Connection ID /// @@ -45,6 +50,8 @@ public Guid ConnectionId /// Lower handle public SNIMarsConnection(SNIHandle lowerHandle) { + + _nextSessionId = -1; _lowerHandle = lowerHandle; SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "Created MARS Session Id {0}", args0: ConnectionId); _lowerHandle.SetAsyncCallbacks(HandleReceiveComplete, HandleSendComplete); @@ -54,7 +61,8 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async) { lock (this) { - ushort sessionId = _nextSessionId++; + ushort sessionId = unchecked((ushort)(Interlocked.Increment(ref _nextSessionId) % ushort.MaxValue)); + SNIMarsHandle handle = new SNIMarsHandle(this, sessionId, callbackObject, async); _sessions.Add(sessionId, handle); SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, SNI MARS Handle Id {1}, created new MARS Session {2}", args0: ConnectionId, args1: handle?.ConnectionId, args2: sessionId); @@ -68,25 +76,42 @@ public SNIMarsHandle CreateMarsSession(object callbackObject, bool async) /// public uint StartReceive() { - long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent(s_className); - try + using (TrySNIEventScope.Create(nameof(SNIMarsConnection))) + { + if (LocalAppContextSwitches.UseExperimentalMARSThreading +//#if NETCOREAPP31_AND_ABOVE +// && ThreadPool.PendingWorkItemCount > 0 +//#endif + ) + { + LazyInitializer.EnsureInitialized(ref s_scheduler, () => new QueuedTaskScheduler(10, "MARSIOScheduler", useForegroundThreads: false, ThreadPriority.Normal)); + + // will start an async task on the scheduler and immediatley return so this await is safe + return s_scheduler.Factory.StartNew(StartAsyncReceiveLoopForConnection, this).GetAwaiter().GetResult(); + } + else + { + return StartAsyncReceiveLoopForConnection(this); + } + } + + static uint StartAsyncReceiveLoopForConnection(object state) { + SNIMarsConnection connection = (SNIMarsConnection)state; SNIPacket packet = null; - if (ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING) + if (connection.ReceiveAsync(ref packet) == TdsEnums.SNI_SUCCESS_IO_PENDING) { - SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: ConnectionId); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.INFO, "MARS Session Id {0}, Success IO pending.", args0: connection.ConnectionId); return TdsEnums.SNI_SUCCESS_IO_PENDING; } - SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: ConnectionId); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNIMarsConnection), EventType.ERR, "MARS Session Id {0}, Connection not usable.", args0: connection.ConnectionId); return SNICommon.ReportSNIError(SNIProviders.SMUX_PROV, 0, SNICommon.ConnNotUsableError, Strings.SNI_ERROR_19); - } - finally - { - SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID); - } + }; } + + /// /// Send a packet synchronously /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs index 58ac68c7c4..8120e346b9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.cs @@ -282,7 +282,7 @@ public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback) state: callback, CancellationToken.None, TaskContinuationOptions.DenyChildAttach, - TaskScheduler.Default + TaskScheduler.Current // specifically continue on the current scheduler because we may override it for mars ); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs new file mode 100644 index 0000000000..7a56cbc88f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITaskScheduler.cs @@ -0,0 +1,177 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.SNI +{ + /// + /// Provides a TaskScheduler that provides control over priorities, fairness, and the underlying threads utilized. + /// + [DebuggerDisplay("Id={Id}, Queues={DebugQueueCount}, ScheduledTasks = {DebugTaskCount}")] + internal sealed class QueuedTaskScheduler : TaskScheduler, IDisposable + { + /// Cancellation token used for disposal. + private readonly CancellationTokenSource _disposeCancellation = new CancellationTokenSource(); + /// + /// The maximum allowed concurrency level of this scheduler. If custom threads are + /// used, this represents the number of created threads. + /// + private readonly int _concurrencyLevel; + /// Whether we're processing tasks on the current thread. + private static readonly ThreadLocal s_taskProcessingThread = new ThreadLocal(); + + /// The threads used by the scheduler to process work. + private readonly Thread[] _threads; + /// The collection of tasks to be executed on our custom threads. + private readonly BlockingCollection _blockingTaskQueue; + + private readonly TaskFactory _factory; + + /// Initializes the scheduler. + /// The number of threads to create and use for processing work items. + public QueuedTaskScheduler(int threadCount) : this(threadCount, string.Empty, false, ThreadPriority.Normal, 0, null, null) { } + + /// Initializes the scheduler. + /// The number of threads to create and use for processing work items. + /// The name to use for each of the created threads. + /// A Boolean value that indicates whether to use foreground threads instead of background. + /// The priority to assign to each thread. + /// The stack size to use for each thread. + /// An initialization routine to run on each thread. + /// A finalization routine to run on each thread. + public QueuedTaskScheduler( + int threadCount, + string threadName = "", + bool useForegroundThreads = false, + ThreadPriority threadPriority = ThreadPriority.Normal, + int threadMaxStackSize = 0, + Action threadInit = null, + Action threadFinally = null) + { + // Validates arguments (some validation is left up to the Thread type itself). + // If the thread count is 0, default to the number of logical processors. + if (threadCount < 0) + throw new ArgumentOutOfRangeException(nameof(threadCount)); + else if (threadCount == 0) + _concurrencyLevel = Environment.ProcessorCount; + else + _concurrencyLevel = threadCount; + + // Initialize the queue used for storing tasks + _blockingTaskQueue = new BlockingCollection(); + + // Create all of the threads + _threads = new Thread[threadCount]; + for (int i = 0; i < threadCount; i++) + { + _threads[i] = new Thread(() => DispatchLoop(threadInit, threadFinally), threadMaxStackSize) + { + Priority = threadPriority, + IsBackground = !useForegroundThreads, + }; + if (threadName != null) + _threads[i].Name = threadName + " (" + i + ")"; + } + + _factory = new TaskFactory(this); + + // Start all of the threads + foreach (var thread in _threads) + thread.Start(); + } + + public TaskFactory Factory => _factory; + + /// The dispatch loop run by all threads in this scheduler. + /// An initialization routine to run when the thread begins. + /// A finalization routine to run before the thread ends. + private void DispatchLoop(Action threadInit, Action threadFinally) + { + s_taskProcessingThread.Value = true; + threadInit?.Invoke(); + try + { + // If the scheduler is disposed, the cancellation token will be set and + // we'll receive an OperationCanceledException. That OCE should not crash the process. + try + { + // If a thread abort occurs, we'll try to reset it and continue running. + while (true) + { + try + { + // For each task queued to the scheduler, try to execute it. + foreach (var task in _blockingTaskQueue.GetConsumingEnumerable(_disposeCancellation.Token)) + { + // If the task is not null, that means it was queued to this scheduler directly. + // Run it. + if (task != null) + { + bool tried = TryExecuteTask(task); + } + } + } + catch (ThreadAbortException) + { + // If we received a thread abort, and that thread abort was due to shutting down + // or unloading, let it pass through. Otherwise, reset the abort so we can + // continue processing work items. + if (!Environment.HasShutdownStarted && !AppDomain.CurrentDomain.IsFinalizingForUnload()) + { + Thread.ResetAbort(); + } + } + } + } + catch (OperationCanceledException) { } + } + finally + { + // Run a cleanup routine if there was one + threadFinally?.Invoke(); + s_taskProcessingThread.Value = false; + } + } + + /// Queues a task to the scheduler. + /// The task to be queued. + protected override void QueueTask(Task task) + { + // If we've been disposed, no one should be queueing + if (_disposeCancellation.IsCancellationRequested) + { + throw new ObjectDisposedException(GetType().Name); + } + _blockingTaskQueue.Add(task); + } + + /// Tries to execute a task synchronously on the current thread. + /// The task to execute. + /// Whether the task was previously queued. + /// true if the task was executed; otherwise, false. + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => + // If we're already running tasks on this threads, enable inlining + false; // s_taskProcessingThread.Value && TryExecuteTask(task); + + /// Gets the tasks scheduled to this scheduler. + /// An enumerable of all tasks queued to this scheduler. + /// This does not include the tasks on sub-schedulers. Those will be retrieved by the debugger separately. + protected override IEnumerable GetScheduledTasks() + { + // Get all of the tasks, filtering out nulls, which are just placeholders + // for tasks in other sub-schedulers + return _blockingTaskQueue.Where(t => t != null).ToList(); + } + + /// Gets the maximum concurrency level to use when processing tasks. + public override int MaximumConcurrencyLevel => _concurrencyLevel; + + /// Initiates shutdown of the scheduler. + public void Dispose() => _disposeCancellation.Cancel(); + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs index 8e390b21d6..b366359bdd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs @@ -15,13 +15,16 @@ internal static partial class LocalAppContextSwitches internal const string LegacyRowVersionNullString = @"Switch.Microsoft.Data.SqlClient.LegacyRowVersionNullBehavior"; internal const string UseSystemDefaultSecureProtocolsString = @"Switch.Microsoft.Data.SqlClient.UseSystemDefaultSecureProtocols"; internal const string SuppressInsecureTLSWarningString = @"Switch.Microsoft.Data.SqlClient.SuppressInsecureTLSWarning"; + internal const string UseExperimentalMARSThreadingString = @"Switch.Microsoft.Data.SqlClient.UseExperimentalMARSThreading"; private static bool s_makeReadAsyncBlocking; private static bool? s_LegacyRowVersionNullBehavior; private static bool? s_UseSystemDefaultSecureProtocols; - private static bool? s_SuppressInsecureTLSWarning; + private static bool? s_SuppressInsecureTLSWarning; + private static bool? s_useExperimentalMARSThreading; -#if !NETFRAMEWORK + +#if NETCOREAPP31_AND_ABOVE static LocalAppContextSwitches() { IAppContextSwitchOverridesSection appContextSwitch = AppConfigManager.FetchConfigurationSection(AppContextSwitchOverridesSection.Name); @@ -95,5 +98,19 @@ public static bool UseSystemDefaultSecureProtocols return s_UseSystemDefaultSecureProtocols.Value; } } + + public static bool UseExperimentalMARSThreading + { + get + { + if (s_useExperimentalMARSThreading is null) + { + bool result; + result = AppContext.TryGetSwitch(UseExperimentalMARSThreadingString, out result) ? result : false; + s_useExperimentalMARSThreading = result; + } + return s_useExperimentalMARSThreading.Value; + } + } } }