Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/config/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ export default () => ({
payments: {
url: process.env.PAYMENTS_API_URL,
},
mail: {
url: process.env.MAIL_API_URL,
},
},
apn: {
url: process.env.APN_URL,
Expand Down
12 changes: 12 additions & 0 deletions src/externals/mail/mail.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Module } from '@nestjs/common';
import { ConfigModule } from '@nestjs/config';
import { HttpClientModule } from '../http/http.module';
import { MailService } from './mail.service';

@Module({
imports: [ConfigModule, HttpClientModule],
controllers: [],
providers: [MailService],
exports: [MailService],
})
export class MailServiceModule {}
104 changes: 104 additions & 0 deletions src/externals/mail/mail.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import { Test, type TestingModule } from '@nestjs/testing';
import { type Logger } from '@nestjs/common';
import { createMock } from '@golevelup/ts-jest';
import { ConfigService } from '@nestjs/config';
import { type AxiosResponse, type InternalAxiosRequestConfig } from 'axios';
import { HttpClient } from '../http/http.service';
import { MailService } from './mail.service';

jest.mock('jsonwebtoken', () => ({
sign: jest.fn().mockReturnValue('mock-gateway-jwt'),
}));

describe('MailService', () => {
let service: MailService;
let configService: ConfigService;
let httpClient: HttpClient;

const emptyAxiosResponse = <T>(data: T): AxiosResponse<T> => ({
data,
status: 200,
statusText: 'OK',
headers: {},
config: {} as InternalAxiosRequestConfig,
});

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
MailService,
{ provide: ConfigService, useValue: { get: jest.fn() } },
{ provide: HttpClient, useValue: { get: jest.fn() } },
],
})
.setLogger(createMock<Logger>())
.compile();

service = module.get(MailService);
configService = module.get(ConfigService);
httpClient = module.get(HttpClient);
});

describe('findUserIdByAddress', () => {
const baseUrl = 'https://mail.test';

beforeEach(() => {
jest.spyOn(configService, 'get').mockImplementation((key: string) => {
if (key === 'apis.mail.url') return baseUrl;
if (key === 'isDevelopment') return false;
if (key === 'secrets.gateway') return 'ZHVtbXk=';
return undefined;
});
});

it('When the gateway returns a userId, then it returns that userId', async () => {
const address = '[email protected]';
const userId = 'resolved-user-uuid';
jest
.spyOn(httpClient, 'get')
.mockResolvedValueOnce(emptyAxiosResponse({ userId }));

const result = await service.findUserIdByAddress(address);

expect(result).toBe(userId);
expect(httpClient.get).toHaveBeenCalledWith(
`${baseUrl}/gateway/addresses/${encodeURIComponent(address)}`,
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer mock-gateway-jwt',
'Content-Type': 'application/json',
}),
}),
);
});

it('When the gateway returns no userId, then it returns null', async () => {
jest
.spyOn(httpClient, 'get')
.mockResolvedValueOnce(emptyAxiosResponse({}));

const result = await service.findUserIdByAddress('[email protected]');

expect(result).toBeNull();
});

it('When the gateway responds with 404, then it returns null', async () => {
jest.spyOn(httpClient, 'get').mockRejectedValueOnce({
response: { status: 404 },
});

const result = await service.findUserIdByAddress('[email protected]');

expect(result).toBeNull();
});

it('When the gateway responds with a non-404 error, then it propagates the error', async () => {
const err = new Error('upstream');
jest.spyOn(httpClient, 'get').mockRejectedValueOnce(err);

await expect(service.findUserIdByAddress('[email protected]')).rejects.toThrow(
err,
);
});
});
});
55 changes: 55 additions & 0 deletions src/externals/mail/mail.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { Inject, Injectable } from '@nestjs/common';
import { sign } from 'jsonwebtoken';
import { ConfigService } from '@nestjs/config';
import { HttpClient } from '../http/http.service';

function signToken(duration: string, secret: string, isDevelopment?: boolean) {
return sign({}, Buffer.from(secret, 'base64').toString('utf8'), {
algorithm: 'RS256',
expiresIn: duration,
...(isDevelopment ? { allowInsecureKeySizes: true } : null),
});
}

@Injectable()
export class MailService {
constructor(
@Inject(ConfigService)
private readonly configService: ConfigService,
@Inject(HttpClient)
private readonly httpClient: HttpClient,
) {}

private getAuthHeaders() {
const isDevelopment = this.configService.get('isDevelopment');
const jwt = signToken(
'5m',
this.configService.get('secrets.gateway'),
isDevelopment,
);

return {
'Content-Type': 'application/json',
Authorization: `Bearer ${jwt}`,
};
}

async findUserIdByAddress(address: string): Promise<string | null> {
const baseUrl = this.configService.get('apis.mail.url');
const headers = this.getAuthHeaders();

try {
const res = await this.httpClient.get(
`${baseUrl}/gateway/addresses/${encodeURIComponent(address)}`,
{ headers },
);

return res.data?.userId ?? null;
} catch (error) {
if (error?.response?.status === 404) {
return null;
}
throw error;
}
}
}
105 changes: 105 additions & 0 deletions src/modules/auth/auth.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { FeatureLimitService } from '../feature-limit/feature-limit.service';
import { PaymentRequiredException } from '../feature-limit/exceptions/payment-required.exception';
import { PlatformName } from '../../common/constants';
import { ClientEnum } from '../../common/enums/platform.enum';
import { MailService } from '../../externals/mail/mail.service';

describe('AuthController', () => {
let authController: AuthController;
Expand All @@ -32,6 +33,7 @@ describe('AuthController', () => {
let cryptoService: DeepMocked<CryptoService>;
let twoFactorAuthService: DeepMocked<TwoFactorAuthService>;
let featureLimitService: DeepMocked<FeatureLimitService>;
let mailService: DeepMocked<MailService>;

beforeEach(async () => {
const moduleRef = await Test.createTestingModule({
Expand All @@ -46,6 +48,7 @@ describe('AuthController', () => {
cryptoService = moduleRef.get(CryptoService);
twoFactorAuthService = moduleRef.get(TwoFactorAuthService);
featureLimitService = moduleRef.get(FeatureLimitService);
mailService = moduleRef.get(MailService);
});

it('should be defined', () => {
Expand Down Expand Up @@ -106,6 +109,73 @@ describe('AuthController', () => {

expect(userUseCases.findByEmail).toHaveBeenCalledWith(emailLowerCase);
});

it('When the email is not a managed mail domain, then it should not call the mail gateway', async () => {
const loginDto = new LoginDto();
loginDto.email = '[email protected]';
const user = newUser({ attributes: { email: loginDto.email } });

jest.spyOn(userUseCases, 'findByEmail').mockResolvedValueOnce(user);
jest.spyOn(keyServerUseCases, 'findUserKeys').mockResolvedValueOnce({
ecc: newKeyServer({ userId: user.id }),
kyber: null,
});
jest.spyOn(cryptoService, 'encryptText').mockReturnValue('x');

await authController.login(loginDto);

expect(mailService.findUserIdByAddress).not.toHaveBeenCalled();
});

it('When the email is managed and the gateway resolves a primary email, then findByEmail uses the primary email', async () => {
const loginDto = new LoginDto();
loginDto.email = '[email protected]';
const canonicalEmail = '[email protected]';
const resolvedUuid = v4();
const loginUser = newUser({
attributes: { email: canonicalEmail, id: 9 },
});

jest
.spyOn(mailService, 'findUserIdByAddress')
.mockResolvedValueOnce(resolvedUuid);
jest.spyOn(userUseCases, 'findByUuid').mockResolvedValueOnce(loginUser);
jest.spyOn(userUseCases, 'findByEmail').mockResolvedValueOnce(loginUser);
jest.spyOn(keyServerUseCases, 'findUserKeys').mockResolvedValueOnce({
ecc: newKeyServer({ userId: loginUser.id }),
kyber: null,
});
jest.spyOn(cryptoService, 'encryptText').mockReturnValue('encryptedText');

await authController.login(loginDto);

expect(mailService.findUserIdByAddress).toHaveBeenCalledWith(
loginDto.email,
);
expect(userUseCases.findByUuid).toHaveBeenCalledWith(resolvedUuid);
expect(userUseCases.findByEmail).toHaveBeenCalledWith(canonicalEmail);
});

it('When the email is managed but the gateway finds no user, then findByEmail uses the original address', async () => {
const loginDto = new LoginDto();
loginDto.email = '[email protected]';
const user = newUser({ attributes: { email: loginDto.email } });

jest
.spyOn(mailService, 'findUserIdByAddress')
.mockResolvedValueOnce(null);
jest.spyOn(userUseCases, 'findByEmail').mockResolvedValueOnce(user);
jest.spyOn(keyServerUseCases, 'findUserKeys').mockResolvedValueOnce({
ecc: newKeyServer({ userId: user.id }),
kyber: null,
});
jest.spyOn(cryptoService, 'encryptText').mockReturnValue('encryptedText');

await authController.login(loginDto);

expect(userUseCases.findByUuid).not.toHaveBeenCalled();
expect(userUseCases.findByEmail).toHaveBeenCalledWith(loginDto.email);
});
});

describe('POST /login/access', () => {
Expand Down Expand Up @@ -192,6 +262,41 @@ describe('AuthController', () => {
},
});
});

it('When the email is managed and resolves to a primary account, then loginAccess receives the resolved email', async () => {
const dto = { ...loginAccessDto };
dto.email = '[email protected]';
const canonicalEmail = '[email protected]';
const eccKey = newKeyServer({ ...dto });
const driveUser = newUser({ attributes: { email: canonicalEmail } });

jest.spyOn(keyServerUseCases, 'parseKeysInput').mockReturnValueOnce({
ecc: eccKey.toJSON(),
kyber: null,
});
jest
.spyOn(mailService, 'findUserIdByAddress')
.mockResolvedValueOnce(driveUser.uuid);
jest.spyOn(userUseCases, 'findByUuid').mockResolvedValueOnce(driveUser);
jest
.spyOn(userUseCases, 'loginAccess')
.mockResolvedValueOnce({ success: true } as any);

await authController.loginAccess(dto);

expect(userUseCases.loginAccess).toHaveBeenCalledWith({
...dto,
email: canonicalEmail,
keys: {
ecc: {
publicKey: eccKey.publicKey,
privateKey: eccKey.privateKey,
revocationKey: eccKey.revocationKey,
},
kyber: null,
},
});
});
});

describe('GET /logout', () => {
Expand Down
21 changes: 20 additions & 1 deletion src/modules/auth/auth.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ import { FeatureLimitService } from '../feature-limit/feature-limit.service';
import { PaymentRequiredException } from '../feature-limit/exceptions/payment-required.exception';
import { Client } from '../../common/decorators/client.decorator';
import { type ClientEnum } from '../../common/enums/platform.enum';
import { MailService } from '../../externals/mail/mail.service';
import { isManagedMailDomain } from './managed-mail-domains';

@ApiTags('Auth')
@Controller('auth')
Expand All @@ -60,8 +62,19 @@ export class AuthController {
private readonly twoFactorAuthService: TwoFactorAuthService,
private readonly authUseCases: AuthUsecases,
private readonly featureLimitService: FeatureLimitService,
private readonly mailService: MailService,
) {}

private async resolveLoginEmail(email: string): Promise<string> {
if (!isManagedMailDomain(email)) return email;

const userId = await this.mailService.findUserIdByAddress(email);
if (!userId) return email;

const user = await this.userUseCases.findByUuid(userId);
return user?.email ?? email;
}

@Post('/login')
@HttpCode(HttpStatus.OK)
@ApiOperation({
Expand All @@ -70,7 +83,7 @@ export class AuthController {
@ApiOkResponse({ description: 'Retrieve details', type: LoginResponseDto })
@Public()
async login(@Body() body: LoginDto): Promise<LoginResponseDto> {
const email = body.email.toLowerCase();
const email = await this.resolveLoginEmail(body.email.toLowerCase());

const user = await this.userUseCases.findByEmail(email);

Expand Down Expand Up @@ -128,8 +141,11 @@ export class AuthController {
revocationKey: body.revocateKey,
});

const email = await this.resolveLoginEmail(body.email.toLowerCase());

const result = await this.userUseCases.loginAccess({
...body,
email,
keys: { kyber, ecc },
});

Expand Down Expand Up @@ -299,8 +315,11 @@ export class AuthController {
revocationKey: body.revocateKey,
});

const email = await this.resolveLoginEmail(body.email.toLowerCase());

const result = await this.userUseCases.loginAccess({
...body,
email,
keys: { kyber, ecc },
platform,
});
Expand Down
Loading
Loading