From 99b399528d60f05452408298c9c727afcf260992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Ros?= Date: Thu, 2 May 2024 17:03:54 -0700 Subject: [PATCH] Refactor worker pattern (#545) --- src/YesSql.Core/Data/WorkDispatcher.cs | 90 ++++++++++++++++++++++++ src/YesSql.Core/Data/WorkerQueryKey.cs | 66 ++++++++++------- src/YesSql.Core/Services/DefaultQuery.cs | 10 +-- src/YesSql.Core/Session.cs | 6 +- src/YesSql.Core/Store.cs | 52 +++----------- 5 files changed, 148 insertions(+), 76 deletions(-) create mode 100644 src/YesSql.Core/Data/WorkDispatcher.cs diff --git a/src/YesSql.Core/Data/WorkDispatcher.cs b/src/YesSql.Core/Data/WorkDispatcher.cs new file mode 100644 index 00000000..26288e17 --- /dev/null +++ b/src/YesSql.Core/Data/WorkDispatcher.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Concurrent; +using System.Threading.Tasks; + +#nullable enable + +namespace YesSql.Data; + +internal sealed class WorkDispatcher where TKey : notnull +{ + private readonly ConcurrentDictionary> _workers = new(); + + public async Task ScheduleAsync(TKey key, Func> valueFactory) + { + ArgumentNullException.ThrowIfNull(key); + + while (true) + { + if (_workers.TryGetValue(key, out var task)) + { + return await task; + } + + // This is the task that we'll return to all waiters. We'll complete it when the factory is complete + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + if (_workers.TryAdd(key, tcs.Task)) + { + try + { + var value = await valueFactory(key); + tcs.TrySetResult(value); + return await tcs.Task; + } + catch (Exception ex) + { + // Make sure all waiters see the exception + tcs.SetException(ex); + + throw; + } + finally + { + // We remove the entry if the factory failed so it's not a permanent failure + // and future gets can retry (this could be a pluggable policy) + _workers.TryRemove(key, out _); + } + } + } + } + + public async Task ScheduleAsync(TKey key, TState state, Func> valueFactory) + { + ArgumentNullException.ThrowIfNull(key); + + while (true) + { + if (_workers.TryGetValue(key, out var task)) + { + return await task; + } + + // This is the task that we'll return to all waiters. We'll complete it when the factory is complete + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + if (_workers.TryAdd(key, tcs.Task)) + { + try + { + var value = await valueFactory(key, state); + tcs.TrySetResult(value); + return await tcs.Task; + } + catch (Exception ex) + { + // Make sure all waiters see the exception + tcs.SetException(ex); + + throw; + } + finally + { + // We remove the entry if the factory failed so it's not a permanent failure + // and future gets can retry (this could be a pluggable policy) + _workers.TryRemove(key, out _); + } + } + } + } +} \ No newline at end of file diff --git a/src/YesSql.Core/Data/WorkerQueryKey.cs b/src/YesSql.Core/Data/WorkerQueryKey.cs index 8ab26005..6a347e3a 100644 --- a/src/YesSql.Core/Data/WorkerQueryKey.cs +++ b/src/YesSql.Core/Data/WorkerQueryKey.cs @@ -6,39 +6,44 @@ namespace YesSql.Data /// /// An instance of represents the state of . /// - public readonly struct WorkerQueryKey : IEquatable + public class WorkerQueryKey : IEquatable { private readonly string _prefix; + private readonly long _id; private readonly long[] _ids; private readonly Dictionary _parameters; - private readonly int _hashcode; + private readonly int _hashCode; public WorkerQueryKey(string prefix, long[] ids) { - if (prefix == null) - { - throw new ArgumentNullException(nameof(prefix)); - } - - if (ids == null) - { - throw new ArgumentNullException(nameof(ids)); - } + ArgumentNullException.ThrowIfNull(prefix); + ArgumentNullException.ThrowIfNull(ids); _prefix = prefix; _parameters = null; _ids = ids; - _hashcode = 0; - _hashcode = BuildHashCode(); + _hashCode = BuildHashCode(); + } + + public WorkerQueryKey(string prefix, long id) + { + ArgumentNullException.ThrowIfNull(prefix); + + _prefix = prefix; + _parameters = null; + _id = id; + _hashCode = BuildHashCode(); } public WorkerQueryKey(string prefix, Dictionary parameters) { + ArgumentNullException.ThrowIfNull(prefix); + ArgumentNullException.ThrowIfNull(parameters); + _prefix = prefix; _parameters = parameters; _ids = null; - _hashcode = 0; - _hashcode = BuildHashCode(); + _hashCode = BuildHashCode(); } /// @@ -75,9 +80,9 @@ public bool Equals(WorkerQueryKey other) private int BuildHashCode() { - var combinedHash = 5381; - combinedHash = ((combinedHash << 5) + combinedHash) ^ _prefix.GetHashCode(); + var hashCode = new HashCode(); + hashCode.Add(_prefix); if (_parameters != null) { @@ -85,35 +90,36 @@ private int BuildHashCode() { if (parameter.Key != null) { - combinedHash = ((combinedHash << 5) + combinedHash) ^ parameter.Key.GetHashCode(); + hashCode.Add(parameter.Key); } if (parameter.Value != null) { - combinedHash = ((combinedHash << 5) + combinedHash) ^ parameter.Value.GetHashCode(); + hashCode.Add(parameter.Value); } } - - return combinedHash; } if (_ids != null) { foreach (var id in _ids) { - combinedHash = ((combinedHash << 5) + combinedHash) ^ (int)id; + hashCode.Add(id); } + } - return combinedHash; + if (_id != 0) + { + hashCode.Add(_id); } - return default; + return hashCode.ToHashCode(); } /// public override int GetHashCode() { - return _hashcode; + return _hashCode; } private static bool SameParameters(Dictionary values1, Dictionary values2) @@ -181,5 +187,15 @@ private static bool SameIds(long[] values1, long[] values2) return true; } + + public static bool operator ==(WorkerQueryKey left, WorkerQueryKey right) + { + return left.Equals(right); + } + + public static bool operator !=(WorkerQueryKey left, WorkerQueryKey right) + { + return !(left == right); + } } } diff --git a/src/YesSql.Core/Services/DefaultQuery.cs b/src/YesSql.Core/Services/DefaultQuery.cs index 4cbc08d2..542c2ee6 100644 --- a/src/YesSql.Core/Services/DefaultQuery.cs +++ b/src/YesSql.Core/Services/DefaultQuery.cs @@ -1110,7 +1110,7 @@ public async Task CountAsync() try { - return await _session._store.ProduceAsync(key, static (state) => + return await _session._store.ProduceAsync(key, static (key, state) => { var logger = state.Session._store.Configuration.Logger; @@ -1221,7 +1221,7 @@ protected async Task FirstOrDefaultImpl() _query._queryState._sqlBuilder.Selector("*"); var sql = _query._queryState._sqlBuilder.ToSqlString(); var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters); - return (await _query._session._store.ProduceAsync(key, static (state) => + return (await _query._session._store.ProduceAsync(key, static (key, state) => { var logger = state.Query._session._store.Configuration.Logger; @@ -1239,7 +1239,7 @@ protected async Task FirstOrDefaultImpl() _query._queryState._sqlBuilder.Selector(_query._queryState._documentTable, "*", _query._queryState._store.Configuration.Schema); var sql = _query._queryState._sqlBuilder.ToSqlString(); var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters); - var documents = await _query._session._store.ProduceAsync(key, static (state) => + var documents = await _query._session._store.ProduceAsync(key, static (key, state) => { var logger = state.Query._session._store.Configuration.Logger; @@ -1326,7 +1326,7 @@ internal async Task> ListImpl() var sql = sqlBuilder.ToSqlString(); var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters); - return await _query._session._store.ProduceAsync(key, static (state) => + return await _query._session._store.ProduceAsync(key, static (key, state) => { var logger = state.Query._session._store.Configuration.Logger; @@ -1356,7 +1356,7 @@ internal async Task> ListImpl() var key = new WorkerQueryKey(sql, sqlBuilder.Parameters); - var documents = await _query._session._store.ProduceAsync(key, static (state) => + var documents = await _query._session._store.ProduceAsync(key, static (key, state) => { var logger = state.Query._session._store.Configuration.Logger; diff --git a/src/YesSql.Core/Session.cs b/src/YesSql.Core/Session.cs index 005140b2..0544f024 100644 --- a/src/YesSql.Core/Session.cs +++ b/src/YesSql.Core/Session.cs @@ -406,11 +406,11 @@ private async Task GetDocumentByIdAsync(long id, string collection) var documentTable = Store.Configuration.TableNameConvention.GetDocumentTable(collection); var command = "select * from " + _dialect.QuoteForTableName(_tablePrefix + documentTable, Store.Configuration.Schema) + " where " + _dialect.QuoteForColumnName("Id") + " = @Id"; - var key = new WorkerQueryKey(nameof(GetDocumentByIdAsync), new[] { id }); + var key = new WorkerQueryKey(nameof(GetDocumentByIdAsync), id); try { - var result = await _store.ProduceAsync(key, (state) => + var result = await _store.ProduceAsync(key, (key, state) => { var logger = state.Store.Configuration.Logger; @@ -506,7 +506,7 @@ public async Task> GetAsync(long[] ids, string collection = nu var key = new WorkerQueryKey(nameof(GetAsync), ids); try { - var documents = await _store.ProduceAsync(key, static (state) => + var documents = await _store.ProduceAsync(key, static (key, state) => { var logger = state.Store.Configuration.Logger; diff --git a/src/YesSql.Core/Store.cs b/src/YesSql.Core/Store.cs index bd5de1e4..904ff779 100644 --- a/src/YesSql.Core/Store.cs +++ b/src/YesSql.Core/Store.cs @@ -37,6 +37,8 @@ public class Store : IStore internal readonly ConcurrentDictionary CompiledQueries = new(); + private readonly WorkDispatcher _dispatcher = new(); + internal const int SmallBufferSize = 128; internal const int MediumBufferSize = 512; internal const int LargeBufferSize = 1024; @@ -289,55 +291,19 @@ public IStore RegisterScopedIndexes(IEnumerable indexProviders) /// A key identifying the running work. /// A function containing the logic to execute. /// The result of the work. - internal async Task ProduceAsync(WorkerQueryKey key, Func> work, TState state) + internal Task ProduceAsync(WorkerQueryKey key, Func> work, TState state) { if (!Configuration.QueryGatingEnabled) { - return await work(state); + return work(key, state); } - object content = null; - - while (content == null) - { - // Is there any query already processing the ? - if (!Workers.TryGetValue(key, out var result)) - { - // Multiple threads can potentially reach this point which is fine - // c.f. https://blogs.msdn.microsoft.com/seteplia/2018/10/01/the-danger-of-taskcompletionsourcet-class/ - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - Workers.TryAdd(key, tcs.Task); - - try - { - // The current worker is processed - content = await work(state); - } - catch - { - // An exception occurred in the main worker, we broadcast the null value - content = null; - throw; - } - finally - { - // Remove the worker task before setting the result. - // If the result is null, other threads would potentially - // acquire it otherwise. - Workers.TryRemove(key, out _); + return ProduceAwaitedAsync(key, work, state); + } - // Notify all other awaiters to return the result - tcs.TrySetResult(content); - } - } - else - { - // Another worker is already running, wait for it to finish and reuse the results. - // This value can be null if the worker failed, in this case the loop will run again. - content = await result; - } - } + internal async Task ProduceAwaitedAsync(WorkerQueryKey key, Func> work, TState state) + { + var content = await _dispatcher.ScheduleAsync(key, state, async (key, state) => await work(key, state)); return (T)content; }