Skip to content

Commit

Permalink
PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
infomiho committed Dec 21, 2023
1 parent b0b1a8b commit ce1c89b
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Request, Response } from 'express';
import { verifyPassword, throwInvalidCredentialsError } from "../../../core/auth.js";
import {
createProviderId,
findAuthIdentity,
findAuthWithUserBy,
createAuthToken,
Expand All @@ -20,7 +21,9 @@ export function getLoginRoute({
const fields = req.body ?? {}
ensureValidArgs(fields)

const authIdentity = await findAuthIdentity("email", fields.email)
const authIdentity = await findAuthIdentity(
createProviderId("email", fields.email)
)
if (!authIdentity) {
throwInvalidCredentialsError()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Request, Response } from 'express';
import {
createProviderId,
findAuthIdentity,
doFakeWork,
deserializeAndSanitizeProviderData,
Expand Down Expand Up @@ -29,7 +30,9 @@ export function getRequestPasswordResetRoute({
const args = req.body ?? {};
ensureValidEmail(args);

const authIdentity = await findAuthIdentity("email", args.email);
const authIdentity = await findAuthIdentity(
createProviderId("email", args.email),
);

// User not found or not verified - don't leak information
if (!authIdentity) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Request, Response } from 'express';
import {
createProviderId,
findAuthIdentity,
updateAuthIdentityProviderData,
verifyToken,
Expand All @@ -19,13 +20,14 @@ export async function resetPassword(
try {
const { id: email } = await verifyToken(token);

const authIdentity = await findAuthIdentity('email', email);
const providerId = createProviderId('email', email);
const authIdentity = await findAuthIdentity(providerId);
if (!authIdentity) {
return res.status(400).json({ success: false, message: 'Invalid token' });
}

const providerData = deserializeAndSanitizeProviderData<'email'>(authIdentity.providerData);
await updateAuthIdentityProviderData('email', email, providerData, {
await updateAuthIdentityProviderData(providerId, providerData, {
// The act of resetting the password verifies the email
isEmailVerified: true,
password,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Request, Response } from 'express';
import { EmailFromField } from "../../../email/core/types.js";
import {
createUser,
createProviderId,
findAuthIdentity,
deleteUserByAuthId,
doFakeWork,
Expand Down Expand Up @@ -33,7 +34,8 @@ export function getSignupRoute({
const fields = req.body;
ensureValidArgs(fields);

const existingAuthIdentity = await findAuthIdentity("email", fields.email);
const providerId = createProviderId("email", fields.email);
const existingAuthIdentity = await findAuthIdentity(providerId);
if (existingAuthIdentity) {
const providerData = deserializeAndSanitizeProviderData<'email'>(existingAuthIdentity.providerData);
// User already exists and is verified - don't leak information
Expand All @@ -60,8 +62,7 @@ export function getSignupRoute({
});

const user = await createUser(
'email',
fields.email,
providerId,
newUserProviderData,
userFields,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { emailSender } from '../../../email/index.js';
import { Email } from '../../../email/core/types.js';
import {
rethrowPossibleAuthError,
createProviderId,
updateAuthIdentityProviderData,
findAuthIdentity,
deserializeAndSanitizeProviderData,
Expand Down Expand Up @@ -52,9 +53,10 @@ async function sendEmailAndLogTimestamp(
// so the user can't send multiple requests while
// the email is being sent.
try {
const authIdentity = await findAuthIdentity("email", email);
const providerId = createProviderId("email", email);
const authIdentity = await findAuthIdentity(providerId);
const providerData = deserializeAndSanitizeProviderData<'email'>(authIdentity.providerData);
await updateAuthIdentityProviderData<'email'>('email', email, providerData, {
await updateAuthIdentityProviderData<'email'>(providerId, providerData, {
[field]: new Date()
});
} catch (e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Request, Response } from 'express';
import {
verifyToken,
createProviderId,
findAuthIdentity,
updateAuthIdentityProviderData,
deserializeAndSanitizeProviderData,
Expand All @@ -15,9 +16,10 @@ export async function verifyEmail(
const { token } = req.body;
const { id: email } = await verifyToken(token);

const authIdentity = await findAuthIdentity('email', email);
const providerId = createProviderId('email', email);
const authIdentity = await findAuthIdentity(providerId);
const providerData = deserializeAndSanitizeProviderData<'email'>(authIdentity.providerData);
await updateAuthIdentityProviderData('email', email, providerData, {
await updateAuthIdentityProviderData(providerId, providerData, {
isEmailVerified: true,
});
} catch (e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { verifyPassword, throwInvalidCredentialsError } from '../../../core/auth
import { handleRejection } from '../../../utils.js'

import {
createProviderId,
findAuthIdentity,
findAuthWithUserBy,
createAuthToken,
Expand All @@ -14,7 +15,9 @@ export default handleRejection(async (req, res) => {
const fields = req.body ?? {}
ensureValidArgs(fields)

const authIdentity = await findAuthIdentity('username', fields.username)
const authIdentity = await findAuthIdentity(
createProviderId('username', fields.username),
)
if (!authIdentity) {
throwInvalidCredentialsError()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import prisma from '../../../dbClient.js'
import waspServerConfig from '../../../config.js'
import {
type ProviderName,
createProviderUserId,
createProviderId,
authConfig,
contextWithUserEntity,
createUser,
Expand Down Expand Up @@ -53,16 +53,12 @@ export function createRouter(provider: ProviderConfig, initData: { passportStrat
// TODO: In the future we could make this configurable, possibly associating an external account
// with the currently logged in account, or by some DB lookup.

const providerName = provider.id;
const providerUserId = createProviderUserId(providerProfile.id);
const providerId = createProviderId(provider.id, providerProfile.id);

try {
const existingAuthIdentity = await prisma.{= authIdentityEntityLower =}.findUnique({
where: {
providerName_providerUserId: {
providerName,
providerUserId: providerUserId.id,
},
providerName_providerUserId: providerId,
},
include: {
{= authFieldOnAuthIdentityEntityName =}: {
Expand All @@ -84,8 +80,7 @@ export function createRouter(provider: ProviderConfig, initData: { passportStrat
const userFields = await getUserFields()

const user = await createUser(
providerName,
providerUserId,
providerId,
undefined,
userFields,
)
Expand Down
43 changes: 20 additions & 23 deletions waspc/data/Generator/templates/server/src/auth/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from '../entities/index.js'
import { Prisma } from '@prisma/client';

import { PASSWORD_FIELD, throwValidationError } from './validation.js'
import { throwValidationError } from './validation.js'

{=# additionalSignupFields.isDefined =}
{=& additionalSignupFields.importStatement =}
Expand Down Expand Up @@ -58,30 +58,31 @@ export const authConfig = {
successRedirectPath: "{= successRedirectPath =}",
}

type ProviderUserId = {
id: string;
// ProviderId represents one user account in a specific provider.
// We are packing it into a single object to make it easier to
// make sure that the providerUserId is always lowercased.
type ProviderId = {
providerName: ProviderName;
providerUserId: string;
}

export function createProviderUserId(providerUserId: string): ProviderUserId {
export function createProviderId(providerName: ProviderName, providerUserId: string): ProviderId {
return {
id: providerUserId.toLowerCase(),
providerName,
providerUserId: providerUserId.toLowerCase(),
}
}

export async function findAuthIdentity(providerName: ProviderName, providerUserId: ProviderUserId): Promise<{= authIdentityEntityUpper =} | null> {
export async function findAuthIdentity(providerId: ProviderId): Promise<{= authIdentityEntityUpper =} | null> {
return prisma.{= authIdentityEntityLower =}.findUnique({
where: {
providerName_providerUserId: {
providerName,
providerUserId: providerUserId.id,
}
providerName_providerUserId: providerId,
}
});
}

export async function updateAuthIdentityProviderData<PN extends ProviderName>(
providerName: ProviderName,
providerUserId: ProviderUserId,
providerId: ProviderId,
existingProviderData: PossibleProviderData[PN],
providerDataUpdates: Partial<PossibleProviderData[PN]>,
): Promise<{= authIdentityEntityUpper =}> {
Expand All @@ -95,10 +96,7 @@ export async function updateAuthIdentityProviderData<PN extends ProviderName>(
const serializedProviderData = await serializeProviderData<PN>(newProviderData);
return prisma.{= authIdentityEntityLower =}.update({
where: {
providerName_providerUserId: {
providerName,
providerUserId: providerUserId.id,
}
providerName_providerUserId: providerId,
},
data: { providerData: serializedProviderData },
});
Expand All @@ -115,8 +113,7 @@ export async function findAuthWithUserBy(
}

export async function createUser(
providerName: ProviderName,
providerUserId: ProviderUserId,
providerId: ProviderId,
serializedProviderData?: string,
userFields?: PossibleAdditionalSignupFields,
): Promise<{= userEntityUpper =}> {
Expand All @@ -130,8 +127,8 @@ export async function createUser(
create: {
{= identitiesFieldOnAuthEntityName =}: {
create: {
providerName,
providerUserId: providerUserId.id,
providerName: providerId.providerName,
providerUserId: providerId.providerUserId,
providerData: serializedProviderData,
},
},
Expand Down Expand Up @@ -242,7 +239,7 @@ export function deserializeAndSanitizeProviderData<PN extends ProviderName>(
let data = JSON.parse(providerData) as PossibleProviderData[PN];

if (providerDataHasPasswordField(data) && shouldRemovePasswordField) {
delete data[PASSWORD_FIELD];
delete data.hashedPassword;
}

return data;
Expand All @@ -267,7 +264,7 @@ async function sanitizeProviderData<PN extends ProviderName>(
...providerData,
};
if (providerDataHasPasswordField(data)) {
data[PASSWORD_FIELD] = await hashPassword(data[PASSWORD_FIELD]);
data.hashedPassword = await hashPassword(data.hashedPassword);
}

return data;
Expand All @@ -277,5 +274,5 @@ async function sanitizeProviderData<PN extends ProviderName>(
function providerDataHasPasswordField(
providerData: PossibleProviderData[keyof PossibleProviderData],
): providerData is { hashedPassword: string } {
return PASSWORD_FIELD in providerData;
return 'hashedPassword' in providerData;
}
2 changes: 1 addition & 1 deletion waspc/src/Wasp/Generator/DbGenerator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ genPrismaSchema spec = do
then logAndThrowGeneratorError $ GenericGeneratorError "SQLite (a default database) is not supported in production. To build your Wasp app for production, switch to a different database. Switching to PostgreSQL: https://wasp-lang.dev/docs/data-model/backends#migrating-from-sqlite-to-postgresql ."
else return ("sqlite", "\"file:./dev.db\"")

entities <- DbAuth.injectAuth maybeUserEntity userDefinedEntities
entities <- maybe (return userDefinedEntities) (DbAuth.injectAuth userDefinedEntities) maybeUserEntity

let templateData =
object
Expand Down
4 changes: 2 additions & 2 deletions waspc/src/Wasp/Generator/DbGenerator/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ authIdentityEntityName = "AuthIdentity"
identitiesFieldOnAuthEntityName :: String
identitiesFieldOnAuthEntityName = "identities"

injectAuth :: (String, AS.Entity.Entity) -> [(String, AS.Entity.Entity)] -> Generator [(String, AS.Entity.Entity)]
injectAuth (userEntityName, userEntity) entities = do
injectAuth :: [(String, AS.Entity.Entity)] -> (String, AS.Entity.Entity) -> Generator [(String, AS.Entity.Entity)]
injectAuth entities (userEntityName, userEntity) = do
authEntity <- makeAuthEntity userEntityIdField (userEntityName, userEntity)
authIdentityEntity <- makeAuthIdentityEntity
let entitiesWithAuth = injectAuthIntoUserEntity userEntityName entities
Expand Down

0 comments on commit ce1c89b

Please sign in to comment.