Skip to content

Commit

Permalink
Make SaveAllAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
say25 committed Aug 27, 2019
1 parent 6bdb282 commit 8108ee7
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;

Expand All @@ -9,6 +10,6 @@ namespace PostgreSQLCopyHelper
public interface IPostgreSQLCopyHelper<TEntity>
{
ulong SaveAll(NpgsqlConnection connection, IEnumerable<TEntity> entities);
Task<ulong> SaveAllAsync(NpgsqlConnection connection, IEnumerable<TEntity> entities);
ValueTask<ulong> SaveAllAsync(NpgsqlConnection connection, IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading.Tasks;
using Npgsql;

namespace PostgreSQLCopyHelper.Model
Expand All @@ -9,7 +10,7 @@ internal class ColumnDefinition<TEntity>
{
public string ColumnName { get; set; }

public Action<NpgsqlBinaryImporter, TEntity> Write { get; set; }
public Func<NpgsqlBinaryImporter, TEntity, Task> Write { get; set; }

public override string ToString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Npgsql;
using NpgsqlTypes;
Expand Down Expand Up @@ -38,15 +39,28 @@ public PostgreSQLCopyHelper(string schemaName, string tableName)
}

public ulong SaveAll(NpgsqlConnection connection, IEnumerable<TEntity> entities) =>
SaveAllAsync(connection, entities).GetAwaiter().GetResult();
DoSaveAllAsync(connection, entities).GetAwaiter().GetResult();

public async Task<ulong> SaveAllAsync(NpgsqlConnection connection, IEnumerable<TEntity> entities)
public ValueTask<ulong> SaveAllAsync(NpgsqlConnection connection, IEnumerable<TEntity> entities, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return new ValueTask<ulong>(Task.FromCanceled<ulong>(cancellationToken));
}

using (NoSynchronizationContextScope.Enter())
{
return DoSaveAllAsync(connection, entities);
}
}

private async ValueTask<ulong> DoSaveAllAsync(NpgsqlConnection connection, IEnumerable<TEntity> entities)
{
using (var binaryCopyWriter = connection.BeginBinaryImport(GetCopyCommand()))
{
await WriteToStream(binaryCopyWriter, entities);

return await binaryCopyWriter.Complete(async: true);
return await binaryCopyWriter.CompleteAsync();
}
}

Expand All @@ -59,23 +73,23 @@ public PostgreSQLCopyHelper<TEntity> UsePostgresQuoting(bool enabled = true)

public PostgreSQLCopyHelper<TEntity> Map<TProperty>(string columnName, Func<TEntity, TProperty> propertyGetter, NpgsqlDbType type)
{
return AddColumn(columnName, (writer, entity) => writer.Write(propertyGetter(entity), type));
return AddColumn(columnName, (writer, entity) => writer.WriteAsync(propertyGetter(entity), type));
}

public PostgreSQLCopyHelper<TEntity> MapNullable<TProperty>(string columnName, Func<TEntity, TProperty?> propertyGetter, NpgsqlDbType type)
where TProperty : struct
{
return AddColumn(columnName, (writer, entity) =>
return AddColumn(columnName, async (writer, entity) =>
{
var val = propertyGetter(entity);

if (!val.HasValue)
{
writer.WriteNull();
await writer.WriteNullAsync();
}
else
{
writer.Write(val.Value, type);
await writer.WriteAsync(val.Value, type);
}
});
}
Expand All @@ -84,16 +98,16 @@ private async Task WriteToStream(NpgsqlBinaryImporter writer, IEnumerable<TEntit
{
foreach (var entity in entities)
{
await writer.StartRow(async: true);
await writer.StartRowAsync();

foreach (var columnDefinition in _columns)
{
columnDefinition.Write(writer, entity);
await columnDefinition.Write(writer, entity);
}
}
}

private PostgreSQLCopyHelper<TEntity> AddColumn(string columnName, Action<NpgsqlBinaryImporter, TEntity> action)
private PostgreSQLCopyHelper<TEntity> AddColumn(string columnName, Func<NpgsqlBinaryImporter, TEntity, Task> action)
{
_columns.Add(new ColumnDefinition<TEntity>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Npgsql" Version="4.1.0-preview1" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0-beta2-19367-01">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Npgsql" Version="4.1.0-ci.2184" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Code taken from https://github.com/npgsql/npgsql/blob/dev/src/Npgsql/NoSynchronizationContextScope.cs

using System;
using System.Threading;

namespace PostgreSQLCopyHelper.Utils
{
/// <summary>
/// This mechanism is used to temporarily set the current synchronization context to null while
/// executing Npgsql code, making all await continuations execute on the thread pool. This replaces
/// the need to place ConfigureAwait(false) everywhere, and should be used in all surface async methods,
/// without exception.
///
/// Warning: do not use this directly in async methods, use it in sync wrappers of async methods
/// (see https://github.com/npgsql/npgsql/issues/1593)
/// </summary>
/// <remarks>
/// http://stackoverflow.com/a/28307965/640325
/// </remarks>
internal static class NoSynchronizationContextScope
{
internal static Disposable Enter()
{
var sc = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(null);
return new Disposable(sc);
}

internal struct Disposable : IDisposable
{
private readonly SynchronizationContext _synchronizationContext;

internal Disposable(SynchronizationContext synchronizationContext)
=> _synchronizationContext = synchronizationContext;

public void Dispose()
=> SynchronizationContext.SetSynchronizationContext(_synchronizationContext);
}
}
}

0 comments on commit 8108ee7

Please sign in to comment.