diff --git a/src/runtime/core/core.ts b/src/runtime/core/core.ts index 13a10b8..406a52b 100644 --- a/src/runtime/core/core.ts +++ b/src/runtime/core/core.ts @@ -2,7 +2,7 @@ import { generateRandomString, alphabet } from "oslo/crypto"; import { checkDbAndTables, type tableNames } from "../database"; import { getOAuthAccountsTableSchema, getSessionsTableSchema, getUsersTableSchema } from "../database/lib/schema"; import { drizzle as drizzleIntegration } from "db0/integrations/drizzle/index"; -import type { checkDbAndTablesParameters, ICreateOrLoginParams, ISlipAuthCoreOptions, SchemasMockValue, SlipAuthSession } from "./types"; +import type { checkDbAndTablesParameters, DrizzleTransaction, ICreateOrLoginParams, ISlipAuthCoreOptions, SchemasMockValue, SlipAuthSession } from "./types"; import { createSlipHooks } from "./hooks"; import { UsersRepository } from "./repositories/UsersRepository"; import { SessionsRepository } from "./repositories/SessionsRepository"; @@ -71,52 +71,56 @@ export class SlipAuthCore { public async registerUserIfMissingInDb( params: ICreateOrLoginParams, ): Promise<[ string, SlipAuthSession]> { - const existingUser = await this.#repos.users.findByEmail(params.email); - - if (!existingUser) { - const userId = this.#createRandomUserId(); - - await this.#repos.users.insert(userId, params.email); - - const _insertedOAuthAccount = await this.#repos.oAuthAccounts.insert(params.email, { - provider_id: params.providerId, - provider_user_id: params.providerUserId, - user_id: userId, - }); - - const sessionFromRegistrationId = this.#createRandomSessionId(); - const sessionFromRegistration = await this.#repos.sessions.insert(sessionFromRegistrationId, { - userId, - expiresAt: Date.now() + this.#sessionMaxAge, - ip: params.ip, - ua: params.ua, - }); - - return [userId, sessionFromRegistration as SlipAuthSession]; - } - - const existingAccount = await this.#repos.oAuthAccounts.findByProviderData( - params.providerId, params.providerUserId, - ); - - if (existingUser && existingAccount?.provider_id !== params.providerId) { - throw new Error("user already have an account with another provider"); - } - - if (existingAccount) { - const sessionFromLoginId = this.#createRandomSessionId(); - const sessionFromLogin = await this.#repos.sessions.insert(sessionFromLoginId, { - userId: existingUser.id, - expiresAt: Date.now() + this.#sessionMaxAge, - ua: params.ua, - ip: params.ip, - }); - const { id, expires_at } = sessionFromLogin; - - return [existingUser.id, { id, expires_at }]; - } - - throw new Error("could not find oauth user"); + return await this.#orm.transaction(async (trx: DrizzleTransaction) => { + const existingUser = await this.#repos.users.findByEmail(params.email, trx); + + if (!existingUser) { + const userId = this.#createRandomUserId(); + + await this.#repos.users.insert({ userId, email: params.email }, trx); + + const _insertedOAuthAccount = await this.#repos.oAuthAccounts.insert(params.email, { + provider_id: params.providerId, + provider_user_id: params.providerUserId, + user_id: userId, + }, trx); + + const sessionFromRegistrationId = this.#createRandomSessionId(); + const sessionFromRegistration = await this.#repos.sessions.insert(sessionFromRegistrationId, { + userId, + expiresAt: Date.now() + this.#sessionMaxAge, + ip: params.ip, + ua: params.ua, + }, trx); + + throw new Error("could not find oauth user"); + return [userId, sessionFromRegistration as SlipAuthSession]; + } + + const existingAccount = await this.#repos.oAuthAccounts.findByProviderData( + params.providerId, params.providerUserId, + ); + + if (existingUser && existingAccount?.provider_id !== params.providerId) { + throw new Error("user already have an account with another provider"); + } + + if (existingAccount) { + const sessionFromLoginId = this.#createRandomSessionId(); + const sessionFromLogin = await this.#repos.sessions.insert(sessionFromLoginId, { + userId: existingUser.id, + expiresAt: Date.now() + this.#sessionMaxAge, + ua: params.ua, + ip: params.ip, + }, trx); + + const { id, expires_at } = sessionFromLogin; + + return [existingUser.id, { id, expires_at }]; + } + + throw new Error("could not find oauth user"); + }); } public setCreateRandomUserId(fn: () => string) { diff --git a/src/runtime/core/repositories/OAuthAccountsRepository.ts b/src/runtime/core/repositories/OAuthAccountsRepository.ts index 8c7d67b..a3e042a 100644 --- a/src/runtime/core/repositories/OAuthAccountsRepository.ts +++ b/src/runtime/core/repositories/OAuthAccountsRepository.ts @@ -1,9 +1,11 @@ import { eq, and } from "drizzle-orm"; import { TableRepository } from "./_repo"; +import type { DrizzleTransaction } from "../types"; export class OAuthAccountsRepository extends TableRepository<"oauthAccounts"> { - async insert(email: string, values: typeof this.table.$inferInsert): Promise { - await this._orm + async insert(email: string, values: typeof this.table.$inferInsert, trx?: DrizzleTransaction): Promise { + const orm = trx || this._orm; + await orm .insert(this.table) .values(values) .run(); diff --git a/src/runtime/core/repositories/SessionsRepository.ts b/src/runtime/core/repositories/SessionsRepository.ts index 1e12e34..1f651b7 100644 --- a/src/runtime/core/repositories/SessionsRepository.ts +++ b/src/runtime/core/repositories/SessionsRepository.ts @@ -1,10 +1,11 @@ import { eq, sql } from "drizzle-orm"; import { TableRepository } from "./_repo"; -import type { ICreateSessionsParams } from "../types"; +import type { ICreateSessionsParams, DrizzleTransaction } from "../types"; export class SessionsRepository extends TableRepository<"sessions"> { - async insert(sessionId: string, { userId, expiresAt, ip, ua }: ICreateSessionsParams): Promise { - await this._orm + async insert(sessionId: string, { userId, expiresAt, ip, ua }: ICreateSessionsParams, trx?: DrizzleTransaction): Promise { + const orm = trx || this._orm; + await orm .insert(this.table) .values({ id: sessionId, diff --git a/src/runtime/core/repositories/UsersRepository.ts b/src/runtime/core/repositories/UsersRepository.ts index d9ed4d4..105aa1c 100644 --- a/src/runtime/core/repositories/UsersRepository.ts +++ b/src/runtime/core/repositories/UsersRepository.ts @@ -1,9 +1,11 @@ import { eq } from "drizzle-orm"; import { TableRepository } from "./_repo"; +import type { DrizzleTransaction } from "../types"; export class UsersRepository extends TableRepository<"users"> { - async insert(userId: string, email: string): Promise { - await this._orm + async insert({ userId, email }: { userId: string, email: string }, trx?: DrizzleTransaction): Promise { + const orm = trx ?? this._orm; + await orm .insert(this.table) .values({ id: userId, diff --git a/src/runtime/core/types.ts b/src/runtime/core/types.ts index 181061b..eec180e 100644 --- a/src/runtime/core/types.ts +++ b/src/runtime/core/types.ts @@ -1,3 +1,4 @@ +import type { drizzle as drizzleIntegration } from "db0/integrations/drizzle/index"; import type { SQLiteTable } from "drizzle-orm/sqlite-core"; import type { checkDbAndTables, tableNames } from "../database"; import { getOAuthAccountsTableSchema, getSessionsTableSchema, getUsersTableSchema } from "../database/lib/schema"; @@ -5,6 +6,8 @@ import { getOAuthAccountsTableSchema, getSessionsTableSchema, getUsersTableSchem export type { tableNames }; export type { supportedConnectors } from "../database"; +export type DrizzleTransaction = Parameters["transaction"]>["0"]>[0]; + export type checkDbAndTablesParameters = Parameters; export interface ICreateOrLoginParams {