Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions backend/src/common/guards/rpc-throttle.guard.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import { Test, TestingModule } from '@nestjs/testing';
import { ExecutionContext, HttpException } from '@nestjs/common';
import { ThrottlerException, ThrottlerModuleOptions, ThrottlerStorage } from '@nestjs/throttler';
import { Reflector } from '@nestjs/core';
import { RpcThrottleGuard } from './rpc-throttle.guard';

describe('RpcThrottleGuard', () => {
let guard: RpcThrottleGuard;
let mockExecutionContext: Partial<ExecutionContext>;
let mockRequest: any;
let mockResponse: any;
let mockReflector: any;
let mockStorageService: any;

beforeEach(() => {
// Mock Reflector
mockReflector = {
get: jest.fn().mockReturnValue(undefined),
};

// Mock ThrottlerStorage
mockStorageService = {
increment: jest.fn().mockResolvedValue([1, 60]),
reset: jest.fn(),
};

// Mock ThrottlerModuleOptions
const mockOptions: Partial<ThrottlerModuleOptions> = {
throttlers: [],
};

// Initialize the guard with mocked dependencies
guard = new RpcThrottleGuard(
mockOptions as any,
mockStorageService as any,
mockReflector as any,
);

// Mock response object
mockResponse = {
setHeader: jest.fn(),
};

// Mock request object
mockRequest = {
ip: '127.0.0.1',
method: 'GET',
path: '/savings/my-subscriptions',
user: null,
connection: {
remoteAddress: '127.0.0.1',
},
};

// Mock execution context
mockExecutionContext = {
switchToHttp: jest.fn().mockReturnValue({
getRequest: jest.fn().mockReturnValue(mockRequest),
getResponse: jest.fn().mockReturnValue(mockResponse),
}),
} as any;
});

describe('getTracker', () => {
it('should return user ID-based tracker when user is authenticated', async () => {
mockRequest.user = { id: 'user-123', email: 'test@example.com' };

const tracker = await (guard as any).getTracker(mockRequest);
expect(tracker).toBe('rpc-throttle:user-123');
});

it('should return IP-based tracker when user is not authenticated', async () => {
mockRequest.user = null;
mockRequest.ip = '192.168.0.1';

const tracker = await (guard as any).getTracker(mockRequest);
expect(tracker).toBe('rpc-throttle:192.168.0.1');
});

it('should fallback to connection.remoteAddress if req.ip is not available', async () => {
mockRequest.user = null;
mockRequest.ip = null;
mockRequest.connection.remoteAddress = '10.0.0.1';

const tracker = await (guard as any).getTracker(mockRequest);
expect(tracker).toBe('rpc-throttle:10.0.0.1');
});

it('should return "unknown" if both ip and remoteAddress are unavailable', async () => {
mockRequest.user = null;
mockRequest.ip = null;
mockRequest.connection = null;

const tracker = await (guard as any).getTracker(mockRequest);
expect(tracker).toBe('rpc-throttle:unknown');
});

it('should prefer user ID over IP even if IP is available', async () => {
mockRequest.user = { id: 'user-456', email: 'another@example.com' };
mockRequest.ip = '192.168.1.1';

const tracker = await (guard as any).getTracker(mockRequest);
expect(tracker).toBe('rpc-throttle:user-456');
});
});

describe('onLimitExceeded', () => {
it('should throw ThrottlerException with correct message', async () => {
const context = mockExecutionContext as ExecutionContext;
const limit = 10;
const ttl = 60000; // 1 minute

await expect(
guard.onLimitExceeded(context, limit, ttl),
).rejects.toThrow('Too many RPC requests');
});

it('should set Retry-After header', async () => {
const context = mockExecutionContext as ExecutionContext;
const limit = 10;
const ttl = 60000;

try {
await guard.onLimitExceeded(context, limit, ttl);
} catch (e) {
// Expected to throw
}

expect(mockResponse.setHeader).toHaveBeenCalledWith(
'Retry-After',
expect.any(Number),
);
});

it('should set X-RateLimit headers', async () => {
const context = mockExecutionContext as ExecutionContext;
const limit = 10;
const ttl = 60000;

try {
await guard.onLimitExceeded(context, limit, ttl);
} catch (e) {
// Expected to throw
}

expect(mockResponse.setHeader).toHaveBeenCalledWith(
'X-RateLimit-Limit',
limit,
);
expect(mockResponse.setHeader).toHaveBeenCalledWith(
'X-RateLimit-Remaining',
0,
);
expect(mockResponse.setHeader).toHaveBeenCalledWith(
'X-RateLimit-Reset',
expect.any(String),
);
});

it('should include user ID in log message when available', async () => {
mockRequest.user = { id: 'user-789', email: 'user@example.com' };
const context = mockExecutionContext as ExecutionContext;
const loggerWarnSpy = jest.spyOn((guard as any).logger, 'warn');

try {
await guard.onLimitExceeded(context, 10, 60000);
} catch (e) {
// Expected to throw
}

expect(loggerWarnSpy).toHaveBeenCalledWith(
expect.stringContaining('user-789'),
);
});

it('should calculate Retry-After in seconds', async () => {
const context = mockExecutionContext as ExecutionContext;
const limit = 10;
const ttl = 30000; // 30 seconds

try {
await guard.onLimitExceeded(context, limit, ttl);
} catch (e) {
// Expected to throw
}

expect(mockResponse.setHeader).toHaveBeenCalledWith(
'Retry-After',
30, // 30000ms / 1000 = 30 seconds (rounded up)
);
});
});

describe('User-ID Based Tracking', () => {
it('should allow different users to have independent rate limits', async () => {
const user1Request = {
...mockRequest,
user: { id: 'user-1' },
};
const user2Request = {
...mockRequest,
user: { id: 'user-2' },
};

const tracker1 = await (guard as any).getTracker(user1Request);
const tracker2 = await (guard as any).getTracker(user2Request);

expect(tracker1).not.toBe(tracker2);
expect(tracker1).toBe('rpc-throttle:user-1');
expect(tracker2).toBe('rpc-throttle:user-2');
});

it('should treat different IPs from same user as same tracker', async () => {
mockRequest.user = { id: 'user-same' };
mockRequest.ip = '192.168.0.1';

const tracker1 = await (guard as any).getTracker(mockRequest);

mockRequest.ip = '10.0.0.1'; // Different IP
const tracker2 = await (guard as any).getTracker(mockRequest);

expect(tracker1).toBe(tracker2);
expect(tracker1).toBe('rpc-throttle:user-same');
});
});
});
51 changes: 37 additions & 14 deletions backend/src/common/guards/rpc-throttle.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
HttpStatus,
Logger,
} from '@nestjs/common';
import { ThrottlerGuard } from '@nestjs/throttler';
import { ThrottlerGuard, ThrottlerException } from '@nestjs/throttler';
import { Request } from 'express';

/**
Expand All @@ -18,6 +18,12 @@ import { Request } from 'express';
* Configuration:
* - GET /savings/my-subscriptions: 10 requests per minute per User ID
* - Other RPC endpoints: configurable via decorator
*
* Key Features:
* - User-ID based tracking (when authenticated) to prevent IP-shifting bypasses
* - IP fallback for unauthenticated RPC calls
* - Automatic Retry-After and X-RateLimit header injection
* - Custom error messages with limit/TTL information
*/
@Injectable()
export class RpcThrottleGuard extends ThrottlerGuard {
Expand All @@ -26,40 +32,57 @@ export class RpcThrottleGuard extends ThrottlerGuard {
/**
* Override getTracker to use User ID instead of IP address
* This ensures rate limiting is per-user, not per-IP
*
* Prioritization:
* 1. If user is authenticated (has req.user.id), use User ID
* 2. Otherwise fall back to IP address
*/
protected async getTracker(req: Record<string, any>): Promise<string> {
// Extract user ID from JWT token in request
// Extract user ID from JWT token in request (set by JwtAuthGuard)
const user = req.user;

if (!user || !user.id) {
this.logger.warn(
`RpcThrottleGuard: No user found in request to ${req.path}`,
);
// Fallback to IP if no user (shouldn't happen with JwtAuthGuard)
return req.ip || 'unknown';
if (user && user.id) {
// Use User ID for authenticated requests
return `rpc-throttle:${user.id}`;
}

return `rpc-throttle:${user.id}`;
// Fallback to IP for unauthenticated requests
const ip = req.ip || req.connection?.remoteAddress || 'unknown';
this.logger.debug(
`RpcThrottleGuard: Using IP-based tracking for ${req.method} ${req.path} (IP: ${ip})`,
);
return `rpc-throttle:${ip}`;
}

/**
* Override onLimitExceeded to provide custom error response
* Override onLimitExceeded to throw ThrottlerException (429)
* This integrates seamlessly with NestJS error handling
*/
async onLimitExceeded(
context: ExecutionContext,
limit: number,
ttl: number,
): Promise<void> {
const request = context.switchToHttp().getRequest<Request>();
const response = context.switchToHttp().getResponse();
const user = (request as any).user;

// Log the rate limit breach
this.logger.warn(
`RPC rate limit exceeded for user ${user?.id || 'unknown'} on ${request.method} ${request.path}. Limit: ${limit} requests per ${ttl}ms`,
`[RPC Rate Limit] User/IP: ${user?.id || request.ip || 'unknown'} | ` +
`Route: ${request.method} ${request.path} | ` +
`Limit: ${limit} req/${Math.round(ttl / 1000)}s`,
);

throw new HttpException(
`Too many RPC requests. Maximum ${limit} requests per ${Math.round(ttl / 1000)} seconds allowed.`,
HttpStatus.TOO_MANY_REQUESTS,
// Set Retry-After header (standard HTTP 429 behavior)
response.setHeader('Retry-After', Math.ceil(ttl / 1000));
response.setHeader('X-RateLimit-Limit', limit);
response.setHeader('X-RateLimit-Remaining', 0);
response.setHeader('X-RateLimit-Reset', new Date(Date.now() + ttl).toISOString());

// Throw ThrottlerException which results in HTTP 429
throw new ThrottlerException(
`Too many RPC requests. Maximum ${limit} requests per ${Math.round(ttl / 1000)} seconds allowed.`
);
}
}
5 changes: 3 additions & 2 deletions backend/src/modules/savings/savings.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { UpdateGoalDto } from './dto/update-goal.dto';
import { ProductDetailsDto } from './dto/product-details.dto';
import { JwtAuthGuard } from '../../auth/guards/jwt-auth.guard';
import { CurrentUser } from '../../common/decorators/current-user.decorator';
import { RpcThrottleGuard } from '../../common/guards/rpc-throttle.guard';
import {
SavingsGoalProgress,
UserSubscriptionWithLiveBalance,
Expand Down Expand Up @@ -129,7 +130,7 @@ export class SavingsController {

@Get('my-subscriptions')
@Throttle({ rpc: { limit: 10, ttl: 60000 } })
@UseGuards(JwtAuthGuard)
@UseGuards(JwtAuthGuard, RpcThrottleGuard)
@ApiBearerAuth()
@ApiOperation({ summary: 'Get current user subscriptions' })
@ApiResponse({ status: 200, description: 'List of user subscriptions' })
Expand All @@ -143,7 +144,7 @@ export class SavingsController {

@Get('my-goals')
@Throttle({ rpc: { limit: 10, ttl: 60000 } })
@UseGuards(JwtAuthGuard)
@UseGuards(JwtAuthGuard, RpcThrottleGuard)
@ApiBearerAuth()
@ApiOperation({
summary:
Expand Down
Loading
Loading