Skip to content

Commit

Permalink
Add NpgsqlConnection ReloadTypesAsync()
Browse files Browse the repository at this point in the history
Possibly fixes #4369 but will probably be replaced by a more general
approach.
  • Loading branch information
Brar committed Sep 13, 2022
1 parent 8469821 commit ac2d3e4
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 31 deletions.
85 changes: 69 additions & 16 deletions src/Npgsql/NpgsqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ async Task<NpgsqlBinaryImporter> BeginBinaryImport(string copyFromCommand, bool
throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand));

CheckReady();
var connector = StartBindingScope(ConnectorBindingScope.Copy);
var connector = await StartBindingScope(ConnectorBindingScope.Copy, async);

LogMessages.StartingBinaryImport(connector.LoggingConfiguration.CopyLogger, connector.Id);
// no point in passing a cancellationToken here, as we register the cancellation in the Init method
Expand Down Expand Up @@ -1217,7 +1217,7 @@ async Task<NpgsqlBinaryExporter> BeginBinaryExport(string copyToCommand, bool as
throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand));

CheckReady();
var connector = StartBindingScope(ConnectorBindingScope.Copy);
var connector = await StartBindingScope(ConnectorBindingScope.Copy, async);

LogMessages.StartingBinaryExport(connector.LoggingConfiguration.CopyLogger, connector.Id);
// no point in passing a cancellationToken here, as we register the cancellation in the Init method
Expand Down Expand Up @@ -1277,7 +1277,7 @@ async Task<TextWriter> BeginTextImport(string copyFromCommand, bool async, Cance
throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand));

CheckReady();
var connector = StartBindingScope(ConnectorBindingScope.Copy);
var connector = await StartBindingScope(ConnectorBindingScope.Copy, async);

LogMessages.StartingTextImport(connector.LoggingConfiguration.CopyLogger, connector.Id);
// no point in passing a cancellationToken here, as we register the cancellation in the Init method
Expand Down Expand Up @@ -1338,7 +1338,7 @@ async Task<TextReader> BeginTextExport(string copyToCommand, bool async, Cancell
throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand));

CheckReady();
var connector = StartBindingScope(ConnectorBindingScope.Copy);
var connector = await StartBindingScope(ConnectorBindingScope.Copy, async);

LogMessages.StartingTextExport(connector.LoggingConfiguration.CopyLogger, connector.Id);
// no point in passing a cancellationToken here, as we register the cancellation in the Init method
Expand Down Expand Up @@ -1399,7 +1399,7 @@ async Task<NpgsqlRawCopyStream> BeginRawBinaryCopy(string copyCommand, bool asyn
throw new ArgumentException("Must contain a COPY TO STDOUT OR COPY FROM STDIN command!", nameof(copyCommand));

CheckReady();
var connector = StartBindingScope(ConnectorBindingScope.Copy);
var connector = await StartBindingScope(ConnectorBindingScope.Copy, async);

LogMessages.StartingRawCopy(connector.LoggingConfiguration.CopyLogger, connector.Id);
// no point in passing a cancellationToken here, as we register the cancellation in the Init method
Expand Down Expand Up @@ -1807,10 +1807,36 @@ async ValueTask<NpgsqlConnector> StartBindingScopeAsync()
}
}

/// <summary>
/// Synchronously starts a binding scope
/// </summary>
/// <param name="scope">The <see cref="ConnectorBindingScope"/> to start</param>
/// <param name="async">The <see cref="ConnectorBindingScope"/> to start</param>
/// <returns>
/// A <see cref="ValueTask"/> representing the connector to use within the started binding scope
/// </returns>
internal ValueTask<NpgsqlConnector> StartBindingScope(ConnectorBindingScope scope, bool async)
=> StartBindingScope(scope, NpgsqlTimeout.Infinite, async, CancellationToken.None);

/// <summary>
/// Synchronously starts a binding scope
/// </summary>
/// <param name="scope">The <see cref="ConnectorBindingScope"/> to start</param>
/// <returns>A connector to use within the started binding scope</returns>
/// <remarks>
/// Warning: Never use this in async methods when multiplexing because it may block and cause a deadlock.
/// </remarks>
internal NpgsqlConnector StartBindingScope(ConnectorBindingScope scope)
=> StartBindingScope(scope, NpgsqlTimeout.Infinite, async: false, CancellationToken.None)
.GetAwaiter().GetResult();
=> StartBindingScope(scope, async: false).GetAwaiter().GetResult();

/// <summary>
/// Synchronously starts a temporary binding scope
/// </summary>
/// <param name="connector">A connector to execute the temporary commands</param>
/// <returns>An <see cref="IDisposable"/> which ends the temporary binding scope when it is disposed</returns>
/// <remarks>
/// Warning: Never use this in async methods when multiplexing because it may block and cause a deadlock.
/// </remarks>
internal EndScopeDisposable StartTemporaryBindingScope(out NpgsqlConnector connector)
{
connector = StartBindingScope(ConnectorBindingScope.Temporary);
Expand Down Expand Up @@ -2042,6 +2068,9 @@ public void UnprepareAll()
/// </summary>
public void ReloadTypes()
{
if (Settings.Multiplexing)
throw new NotSupportedException();

CheckReady();
using var scope = StartTemporaryBindingScope(out var connector);
connector.LoadDatabaseInfo(
Expand All @@ -2053,17 +2082,41 @@ public void ReloadTypes()
// Increment the change counter on the global type mapper. This will make conn.Open() pick up the
// new DatabaseInfo and set up a new connection type mapper
TypeMapping.GlobalTypeMapper.Instance.RecordChange();
}

if (Settings.Multiplexing)
/// <inheritdoc cref="ReloadTypes"/>
public async Task ReloadTypesAsync()
{
CheckReady();
var connector = await StartBindingScope(ConnectorBindingScope.Temporary, true);
try
{
await connector.LoadDatabaseInfo(
forceReload: true,
NpgsqlTimeout.Infinite,
async: true,
CancellationToken.None);

// Increment the change counter on the global type mapper. This will make conn.Open() pick up the
// new DatabaseInfo and set up a new connection type mapper
TypeMapping.GlobalTypeMapper.Instance.RecordChange();

if (Settings.Multiplexing)
{
var multiplexingTypeMapper = ((MultiplexingDataSource)NpgsqlDataSource).MultiplexingTypeMapper!;
Debug.Assert(multiplexingTypeMapper == connector.TypeMapper,
"A connector must reference the exact same TypeMapper the MultiplexingConnectorPool does");
// It's very probable that we've called ReloadTypes on the different connection than
// the MultiplexingConnectorPool references.
// Which means, we have to explicitly call Reset after we change the connector's DatabaseInfo to reload type mappings.
multiplexingTypeMapper.Connector.DatabaseInfo = connector.TypeMapper.DatabaseInfo;
multiplexingTypeMapper.Reset();
}

}
finally
{
var multiplexingTypeMapper = ((MultiplexingDataSource)NpgsqlDataSource).MultiplexingTypeMapper!;
Debug.Assert(multiplexingTypeMapper == connector.TypeMapper,
"A connector must reference the exact same TypeMapper the MultiplexingConnectorPool does");
// It's very probable that we've called ReloadTypes on the different connection than
// the MultiplexingConnectorPool references.
// Which means, we have to explicitly call Reset after we change the connector's DatabaseInfo to reload type mappings.
multiplexingTypeMapper.Connector.DatabaseInfo = connector.TypeMapper.DatabaseInfo;
multiplexingTypeMapper.Reset();
EndBindingScope(ConnectorBindingScope.Temporary);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/Npgsql/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#nullable enable
Npgsql.NpgsqlConnection.ReloadTypesAsync() -> System.Threading.Tasks.Task!
Npgsql.NpgsqlLoggingConfiguration
static Npgsql.NpgsqlLoggingConfiguration.InitializeLogging(Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory, bool parameterLoggingEnabled = false) -> void
*REMOVED*Npgsql.NpgsqlConnection.Settings.get -> Npgsql.NpgsqlConnectionStringBuilder!
Expand Down
27 changes: 18 additions & 9 deletions test/Npgsql.Tests/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,8 +1039,7 @@ public void Clone()
}

[Test, IssueLink("https://github.com/npgsql/npgsql/issues/824")]
[NonParallelizable]
public async Task ReloadTypes()
public async Task ReloadTypes([Values]bool async)
{
if (IsMultiplexing)
return;
Expand All @@ -1049,19 +1048,29 @@ public async Task ReloadTypes()
using (var conn = await OpenConnectionAsync(connectionString))
using (var conn2 = await OpenConnectionAsync(connectionString))
{
Assert.That(await conn.ExecuteScalarAsync("SELECT EXISTS (SELECT * FROM pg_type WHERE typname='reload_types_enum')"),
Is.False);
await conn.ExecuteNonQueryAsync("CREATE TYPE pg_temp.reload_types_enum AS ENUM ('First', 'Second')");
Assert.That(() => conn.TypeMapper.MapEnum<ReloadTypesEnum>(), Throws.Exception.TypeOf<ArgumentException>());
conn.ReloadTypes();
conn.TypeMapper.MapEnum<ReloadTypesEnum>();
await using var tmpEnum = await CreateEnum<ReloadTypesEnum>(conn, out var enumName, out _);
Assert.That(() => conn.TypeMapper.MapEnum<ReloadTypesEnum>(enumName), Throws.Exception.TypeOf<ArgumentException>());
if (async)
await conn.ReloadTypesAsync();
else
{
if (IsMultiplexing)
{
Assert.That(() => conn.ReloadTypes(), Throws.InvalidOperationException);
return;
}
// ReSharper disable once MethodHasAsyncOverload
conn.ReloadTypes();
}

conn.TypeMapper.MapEnum<ReloadTypesEnum>(enumName);

// Make sure conn2 picks up the new type after a pooled close
var connId = conn2.ProcessID;
conn2.Close();
conn2.Open();
Assert.That(conn2.ProcessID, Is.EqualTo(connId), "Didn't get the same connector back");
conn2.TypeMapper.MapEnum<ReloadTypesEnum>();
conn2.TypeMapper.MapEnum<ReloadTypesEnum>(enumName);
}
}
enum ReloadTypesEnum { First, Second };
Expand Down
10 changes: 4 additions & 6 deletions test/Npgsql.Tests/ReaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,18 +362,17 @@ public async Task GetDataTypeName(string typeName, string? normalizedName = null
Assert.That(reader.GetDataTypeName(0), Is.EqualTo(normalizedName));
}

[Test]
[Test, IssueLink("https://github.com/npgsql/npgsql/issues/4369")]
public async Task GetDataTypeName_enum()
{
var csb = new NpgsqlConnectionStringBuilder(ConnectionString)
{
MaxPoolSize = 1
MaxPoolSize = 1,
};
await using var conn = await OpenConnectionAsync(csb);
await using var _ = await GetTempTypeName(conn, out var typeName);
await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS ENUM ('one')");
await Task.Yield(); // TODO: fix multiplexing deadlock bug
conn.ReloadTypes();
await conn.ReloadTypesAsync();
await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn);
await using var reader = await cmd.ExecuteReaderAsync(Behavior);
await reader.ReadAsync();
Expand All @@ -390,8 +389,7 @@ public async Task GetDataTypeName_domain()
await using var conn = await OpenConnectionAsync(csb);
await using var _ = await GetTempTypeName(conn, out var typeName);
await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {typeName} AS VARCHAR(10)");
await Task.Yield(); // TODO: fix multiplexing deadlock bug
conn.ReloadTypes();
await conn.ReloadTypesAsync();
await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn);
await using var reader = await cmd.ExecuteReaderAsync(Behavior);
await reader.ReadAsync();
Expand Down
22 changes: 22 additions & 0 deletions test/Npgsql.Tests/TestUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Npgsql.NameTranslation;
using NUnit.Framework;

namespace Npgsql.Tests
Expand Down Expand Up @@ -189,6 +190,27 @@ internal static Task<IAsyncDisposable> CreateTempSchema(NpgsqlConnection conn, o
TaskContinuationOptions.OnlyOnRanToCompletion);
}


/// <summary>
/// Creates a schema with a unique name, usable for a single test, and returns an <see cref="IDisposable"/> to
/// drop it at the end of the test.
/// </summary>
internal static Task<IAsyncDisposable> CreateEnum<TEnum>(NpgsqlConnection conn, out string enumName, out string[] enumValues, bool parallelizable = true, INpgsqlNameTranslator? nameTranslator = null)
where TEnum : struct, Enum
{
nameTranslator ??= new NpgsqlSnakeCaseNameTranslator();
var enumType = typeof(TEnum);
enumName = nameTranslator.TranslateTypeName(enumType.Name);
enumValues = Enum.GetNames(enumType).Select(n => nameTranslator.TranslateMemberName(n)).ToArray();
if (parallelizable)
enumName += Interlocked.Increment(ref _tempTypeCounter);
return conn.ExecuteNonQueryAsync($"DROP TYPE IF EXISTS {enumName} CASCADE; CREATE TYPE {enumName} AS ENUM ('{string.Join("', '", enumValues)}')")
.ContinueWith<IAsyncDisposable>(
(_, name) => new DatabaseObjectDropper(conn, (string)name!, "TYPE"),
enumName,
TaskContinuationOptions.OnlyOnRanToCompletion);
}

/// <summary>
/// Generates a unique table name, usable for a single test, and drops it if it already exists.
/// Actual creation of the table is the responsibility of the caller.
Expand Down

0 comments on commit ac2d3e4

Please sign in to comment.