Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
Expand Down Expand Up @@ -79,19 +77,10 @@ private async Task CreateDefaultCollectionsAsync(AutomaticallyConfirmOrganizatio
return;
}

await collectionRepository.CreateAsync(
new Collection
{
OrganizationId = request.Organization!.Id,
Name = request.DefaultUserCollectionName,
Type = CollectionType.DefaultUserCollection
},
groups: null,
[new CollectionAccessSelection
{
Id = request.OrganizationUser!.Id,
Manage = true
}]);
await collectionRepository.UpsertDefaultCollectionAsync(
request.Organization!.Id,
request.OrganizationUser!.Id,
request.DefaultUserCollectionName);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
Expand Down Expand Up @@ -274,21 +273,10 @@ private async Task CreateDefaultCollectionAsync(OrganizationUser organizationUse
return;
}

var defaultCollection = new Collection
{
OrganizationId = organizationUser.OrganizationId,
Name = defaultUserCollectionName,
Type = CollectionType.DefaultUserCollection
};
var collectionUser = new CollectionAccessSelection
{
Id = organizationUser.Id,
ReadOnly = false,
HidePasswords = false,
Manage = true
};

await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]);
await _collectionRepository.UpsertDefaultCollectionAsync(
organizationUser.OrganizationId,
organizationUser.Id,
defaultUserCollectionName);
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Core/Repositories/ICollectionRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,14 @@ Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable<Guid> col
/// <param name="defaultCollectionName">The encrypted string to use as the default collection name.</param>
/// <returns></returns>
Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName);

/// <summary>
/// Creates a default user collection for the specified organization user if they do not already have one.
/// This operation is idempotent - calling it multiple times will not create duplicate collections.
/// </summary>
/// <param name="organizationId">The Organization ID.</param>
/// <param name="organizationUserId">The Organization User ID to create/find a default collection for.</param>
/// <param name="defaultCollectionName">The encrypted string to use as the default collection name.</param>
/// <returns>True if a new collection was created; false if the user already had a default collection.</returns>
Task<bool> UpsertDefaultCollectionAsync(Guid organizationId, Guid organizationUserId, string defaultCollectionName);
}
24 changes: 24 additions & 0 deletions src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,30 @@ public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable
}
}

public async Task<bool> UpsertDefaultCollectionAsync(Guid organizationId, Guid organizationUserId, string defaultCollectionName)
{
using (var connection = new SqlConnection(ConnectionString))
{
var collectionId = CoreHelpers.GenerateComb();
var now = DateTime.UtcNow;
var parameters = new DynamicParameters();
parameters.Add("@CollectionId", collectionId);
parameters.Add("@OrganizationId", organizationId);
parameters.Add("@OrganizationUserId", organizationUserId);
parameters.Add("@Name", defaultCollectionName);
parameters.Add("@CreationDate", now);
parameters.Add("@RevisionDate", now);
parameters.Add("@WasCreated", dbType: DbType.Boolean, direction: ParameterDirection.Output);

await connection.ExecuteAsync(
$"[{Schema}].[Collection_UpsertDefaultCollection]",
parameters,
commandType: CommandType.StoredProcedure);

return parameters.Get<bool>("@WasCreated");
}
}

private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId)
{
const string sql = @"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,34 @@ public static void SetupEntityFramework(this IServiceCollection services, string
{
if (provider == SupportedDatabaseProviders.Postgres)
{
options.UseNpgsql(connectionString, b => b.MigrationsAssembly("PostgresMigrations"));
options.UseNpgsql(connectionString, b =>
{
b.MigrationsAssembly("PostgresMigrations");
b.EnableRetryOnFailure();
});
// Handle NpgSql Legacy Support for `timestamp without timezone` issue
AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true);
}
else if (provider == SupportedDatabaseProviders.MySql)
{
options.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString),
b => b.MigrationsAssembly("MySqlMigrations"));
b =>
{
b.MigrationsAssembly("MySqlMigrations");
b.EnableRetryOnFailure();
});
}
else if (provider == SupportedDatabaseProviders.Sqlite)
{
// SQLite doesn't support EnableRetryOnFailure
options.UseSqlite(connectionString, b => b.MigrationsAssembly("SqliteMigrations"));
}
else if (provider == SupportedDatabaseProviders.SqlServer)
{
options.UseSqlServer(connectionString);
options.UseSqlServer(connectionString, b =>
{
b.EnableRetryOnFailure();
});
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,83 @@ public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable
await dbContext.SaveChangesAsync();
}

public async Task<bool> UpsertDefaultCollectionAsync(Guid organizationId, Guid organizationUserId, string defaultCollectionName)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);

// Use EF's execution strategy to handle transient failures (including deadlocks)
var strategy = dbContext.Database.CreateExecutionStrategy();

return await strategy.ExecuteAsync(async () =>
{
// Use SERIALIZABLE isolation level to prevent race conditions during concurrent calls
using var transaction = await dbContext.Database.BeginTransactionAsync(System.Data.IsolationLevel.Serializable);

try
{
// Check if this organization user already has a default collection
// SERIALIZABLE ensures this SELECT acquires range locks
var existingDefaultCollection = await (
from c in dbContext.Collections
join cu in dbContext.CollectionUsers on c.Id equals cu.CollectionId
where cu.OrganizationUserId == organizationUserId
&& c.OrganizationId == organizationId
&& c.Type == CollectionType.DefaultUserCollection
select c
).FirstOrDefaultAsync();

// If collection already exists, return false (not created)
if (existingDefaultCollection != null)
{
await transaction.CommitAsync();
return false;
}

// Create new default collection
var collectionId = CoreHelpers.GenerateComb();
var now = DateTime.UtcNow;

var collection = new Collection
{
Id = collectionId,
OrganizationId = organizationId,
Name = defaultCollectionName,
ExternalId = null,
CreationDate = now,
RevisionDate = now,
Type = CollectionType.DefaultUserCollection,
DefaultUserCollectionEmail = null
};

var collectionUser = new CollectionUser
{
CollectionId = collectionId,
OrganizationUserId = organizationUserId,
ReadOnly = false,
HidePasswords = false,
Manage = true
};

await dbContext.Collections.AddAsync(collection);
await dbContext.CollectionUsers.AddAsync(collectionUser);
await dbContext.SaveChangesAsync();

// Bump user account revision dates
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collectionId, organizationId);
await dbContext.SaveChangesAsync();

await transaction.CommitAsync();
return true;
}
catch
{
await transaction.RollbackAsync();
throw;
}
});
}

private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId)
{
var results = await dbContext.OrganizationUsers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
-- This procedure prevents duplicate "My Items" collections for users by checking
-- if a default collection already exists before attempting to create one.

CREATE PROCEDURE [dbo].[Collection_UpsertDefaultCollection]
@CollectionId UNIQUEIDENTIFIER,
@OrganizationId UNIQUEIDENTIFIER,
@OrganizationUserId UNIQUEIDENTIFIER,
@Name VARCHAR(MAX),
@CreationDate DATETIME2(7),
@RevisionDate DATETIME2(7),
@WasCreated BIT OUTPUT
AS
BEGIN
SET NOCOUNT ON

-- Use SERIALIZABLE isolation level to prevent race conditions during concurrent calls
SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;
BEGIN TRANSACTION;

BEGIN TRY
DECLARE @ExistingCollectionId UNIQUEIDENTIFIER;

-- Check if this organization user already has a default collection
-- SERIALIZABLE ensures range locks prevent concurrent insertions
SELECT @ExistingCollectionId = c.Id
FROM [dbo].[Collection] c
INNER JOIN [dbo].[CollectionUser] cu ON cu.CollectionId = c.Id
WHERE cu.OrganizationUserId = @OrganizationUserId
AND c.OrganizationId = @OrganizationId
AND c.Type = 1; -- CollectionType.DefaultUserCollection

-- If collection already exists, return early
IF @ExistingCollectionId IS NOT NULL
BEGIN
SET @WasCreated = 0;
COMMIT TRANSACTION;
RETURN;
END

-- Create new default collection
SET @WasCreated = 1;

-- Insert Collection
INSERT INTO [dbo].[Collection]
(
[Id],
[OrganizationId],
[Name],
[ExternalId],
[CreationDate],
[RevisionDate],
[DefaultUserCollectionEmail],
[Type]
)
VALUES
(
@CollectionId,
@OrganizationId,
@Name,
NULL, -- ExternalId
@CreationDate,
@RevisionDate,
NULL, -- DefaultUserCollectionEmail
1 -- CollectionType.DefaultUserCollection
);

-- Insert CollectionUser
INSERT INTO [dbo].[CollectionUser]
(
[CollectionId],
[OrganizationUserId],
[ReadOnly],
[HidePasswords],
[Manage]
)
VALUES
(
@CollectionId,
@OrganizationUserId,
0, -- ReadOnly = false
0, -- HidePasswords = false
1 -- Manage = true
);

-- Bump user account revision dates
EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @CollectionId, @OrganizationId;

COMMIT TRANSACTION;
END TRY
BEGIN CATCH
IF @@TRANCOUNT > 0
ROLLBACK TRANSACTION;
THROW;
END CATCH
END
GO
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Bit.Core.AdminConsole.Utilities.v2.Validation;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
Expand Down Expand Up @@ -203,14 +202,10 @@ public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectio

await sutProvider.GetDependency<ICollectionRepository>()
.Received(1)
.CreateAsync(
Arg.Is<Collection>(c =>
c.OrganizationId == organization.Id &&
c.Name == defaultCollectionName &&
c.Type == CollectionType.DefaultUserCollection),
Arg.Is<IEnumerable<CollectionAccessSelection>>(groups => groups == null),
Arg.Is<IEnumerable<CollectionAccessSelection>>(access =>
access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null));
.UpsertDefaultCollectionAsync(
organization.Id,
organizationUser.Id,
defaultCollectionName);
}

[Theory]
Expand Down Expand Up @@ -252,9 +247,10 @@ public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectio

await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>());
.UpsertDefaultCollectionAsync(
Arg.Any<Guid>(),
Arg.Any<Guid>(),
Arg.Any<string>());
}

[Theory]
Expand Down Expand Up @@ -290,9 +286,10 @@ public async Task AutomaticallyConfirmOrganizationUserAsync_WhenCreateDefaultCol

var collectionException = new Exception("Collection creation failed");
sutProvider.GetDependency<ICollectionRepository>()
.CreateAsync(Arg.Any<Collection>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
Arg.Any<IEnumerable<CollectionAccessSelection>>())
.UpsertDefaultCollectionAsync(
Arg.Any<Guid>(),
Arg.Any<Guid>(),
Arg.Any<string>())
.ThrowsAsync(collectionException);

// Act
Expand Down
Loading
Loading