Skip to content
Open
129 changes: 114 additions & 15 deletions src/rpc/__tests__/dial.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import {
} from '../../__mocks__/webrtc';
import { withICEServers } from '../__fixtures__/dial-webrtc-options';
import { createMockTransport } from '../../__mocks__/transports';
import { createMockSignalingExchange } from '../__mocks__/signaling-exchanges';
import { ClientChannel } from '../client-channel';
import type { Transport } from '@connectrpc/connect';

vi.mock('../peer');
vi.mock('../signaling-exchange');
Expand All @@ -52,28 +52,33 @@ const setupDialWebRTCMocks = () => {
const peerConnection = createMockPeerConnection();
const dataChannel = createMockDataChannel();
const transport = createMockTransport();
const signalingExchange = createMockSignalingExchange(transport);

vi.mocked(newPeerConnectionForClient).mockResolvedValue({
pc: peerConnection,
dc: dataChannel,
});

vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange);

const optionalWebRTCConfigFn = vi.fn().mockResolvedValue({
config: {
additionalIceServers: [],
disableTrickle: false,
},
});

vi.mocked(createClient).mockReturnValue({
const mockClient = {
optionalWebRTCConfig: optionalWebRTCConfigFn,
} as unknown as ReturnType<typeof createClient>);
} as unknown as ReturnType<typeof createClient>;

vi.mocked(createClient).mockReturnValue(mockClient);
vi.mocked(createGrpcWebTransport).mockReturnValue(transport);

const signalingExchange = {
doExchange: vi.fn().mockResolvedValue(transport),
terminate: vi.fn(),
} as unknown as SignalingExchange;

vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange);

return {
peerConnection,
dataChannel,
Expand Down Expand Up @@ -207,21 +212,18 @@ describe('dialWebRTC', () => {
expect(vi.mocked(peerConnection.close)).toHaveBeenCalled();
});

it('should close peer connection if dialDirect fails', async () => {
it('should propagate error if transport creation fails', async () => {
// Arrange
const { peerConnection, transport } = setupDialWebRTCMocks();
// First call succeeds (getOptionalWebRTCConfig), second call fails (signaling)
vi.mocked(createGrpcWebTransport)
.mockReturnValueOnce(transport)
.mockImplementationOnce(() => {
throw new Error('Transport creation failed');
});
setupDialWebRTCMocks();
vi.mocked(createGrpcWebTransport).mockImplementation(() => {
throw new Error('Transport creation failed');
});

// Act & Assert
await expect(dialWebRTC(TEST_URL, TEST_HOST)).rejects.toThrow(
'Transport creation failed'
);
expect(vi.mocked(peerConnection.close)).toHaveBeenCalled();
expect(newPeerConnectionForClient).not.toHaveBeenCalled();
});

it('should rethrow errors after cleanup', async () => {
Expand Down Expand Up @@ -327,6 +329,103 @@ describe('validateDialOptions', () => {
});
});

describe('resource management', () => {
it('should reuse a single transport for config fetching and signaling', async () => {
// Arrange
setupDialWebRTCMocks();

// Act
await dialWebRTC(TEST_URL, TEST_HOST);

// Assert
expect(createGrpcWebTransport).toHaveBeenCalledTimes(1);
expect(createGrpcWebTransport).toHaveBeenCalledWith({
baseUrl: TEST_URL,
credentials: 'same-origin',
});
});

it('should reuse a single signaling client for config fetching and signaling', async () => {
// Arrange
setupDialWebRTCMocks();

// Act
await dialWebRTC(TEST_URL, TEST_HOST);

// Assert
expect(createClient).toHaveBeenCalledTimes(1);
expect(createClient).toHaveBeenCalledWith(
expect.anything(),
expect.anything()
);
});

it('should not leak transports on successful connection', async () => {
// Arrange
const { transport } = setupDialWebRTCMocks();
const transportCount = { created: 0 };

vi.mocked(createGrpcWebTransport).mockImplementation(() => {
transportCount.created += 1;
return transport;
});

// Act
await dialWebRTC(TEST_URL, TEST_HOST);

// Assert
expect(transportCount.created).toBe(1);
});

it('should not leak transports on connection failure', async () => {
// Arrange
const { transport, signalingExchange } = setupDialWebRTCMocks();
const transportCount = { created: 0 };

vi.mocked(createGrpcWebTransport).mockImplementation(() => {
transportCount.created += 1;
return transport;
});

const error = new Error('Connection failed');
vi.mocked(signalingExchange.doExchange).mockRejectedValueOnce(error);

// Act
await dialWebRTC(TEST_URL, TEST_HOST).catch(() => {
// Ignore error for this test
});

// Assert
expect(transportCount.created).toBe(1);
});

it('should use the same transport reference for both config and signaling', async () => {
// Arrange
setupDialWebRTCMocks();
const capturedTransports: Transport[] = [];

vi.mocked(createClient).mockImplementation(
(_service, capturedTransport) => {
capturedTransports.push(capturedTransport);
return {
optionalWebRTCConfig: vi.fn().mockResolvedValue({
config: {
additionalIceServers: [],
disableTrickle: false,
},
}),
} as unknown as ReturnType<typeof createClient>;
}
);

// Act
await dialWebRTC(TEST_URL, TEST_HOST);

// Assert
expect(capturedTransports.length).toBe(1);
});
});

describe('dialDirect', () => {
afterEach(() => {
vi.restoreAllMocks();
Expand Down
71 changes: 30 additions & 41 deletions src/rpc/dial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,20 +309,24 @@ export interface WebRTCConnection {
dataChannel: RTCDataChannel;
}

const getOptionalWebRTCConfig = async (
const getSignalingClient = async (
signalingAddress: string,
callOpts: CallOptions,
dialOpts?: DialOptions,
signalingExchangeOpts: DialOptions | undefined,
transportCredentialsInclude = false
): Promise<WebRTCConfig> => {
const optsCopy = { ...dialOpts } as DialOptions;
const directTransport = await dialDirect(
) => {
const transport = await dialDirect(
signalingAddress,
optsCopy,
signalingExchangeOpts,
transportCredentialsInclude
);

const signalingClient = createClient(SignalingService, directTransport);
return createClient(SignalingService, transport);
};

const getOptionalWebRTCConfig = async (
callOpts: CallOptions,
signalingClient: ReturnType<typeof createClient<typeof SignalingService>>
): Promise<WebRTCConfig> => {
try {
const resp = await signalingClient.optionalWebRTCConfig({}, callOpts);
return resp.config ?? new WebRTCConfig();
Expand Down Expand Up @@ -363,18 +367,25 @@ export const dialWebRTC = async (
};

/**
* First complete our WebRTC options, gathering any extra information like
* TURN servers from a cloud server.
* First, derive options specifically for signaling against our target. Then
* complete our WebRTC options, gathering any extra information like TURN
* servers from a cloud server. This also creates the transport and signaling
* client that we'll reuse to avoid resource leaks.
*/
const webrtcOpts = await processWebRTCOpts(
const exchangeOpts = processSignalingExchangeOpts(
usableSignalingAddress,
callOpts,
dialOpts,
transportCredentialsInclude
dialOpts
);
// then derive options specifically for signaling against our target.
const exchangeOpts = processSignalingExchangeOpts(

const signalingClient = await getSignalingClient(
usableSignalingAddress,
exchangeOpts,
transportCredentialsInclude
);

const webrtcOpts = await processWebRTCOpts(
signalingClient,
callOpts,
dialOpts
);

Expand All @@ -385,21 +396,6 @@ export const dialWebRTC = async (
);
let successful = false;

let directTransport: Transport;
try {
directTransport = await dialDirect(
usableSignalingAddress,
exchangeOpts,
transportCredentialsInclude
);
} catch (error) {
pc.close();
dc.close();
throw error;
}

const signalingClient = createClient(SignalingService, directTransport);

const exchange = new SignalingExchange(
signalingClient,
callOpts,
Expand Down Expand Up @@ -453,18 +449,11 @@ export const dialWebRTC = async (
};

const processWebRTCOpts = async (
signalingAddress: string,
signalingClient: ReturnType<typeof createClient<typeof SignalingService>>,
callOpts: CallOptions,
dialOpts?: DialOptions,
transportCredentialsInclude = false
dialOpts: DialOptions | undefined
): Promise<DialWebRTCOptions> => {
// Get TURN servers, if any.
const config = await getOptionalWebRTCConfig(
signalingAddress,
callOpts,
dialOpts,
transportCredentialsInclude
);
const config = await getOptionalWebRTCConfig(callOpts, signalingClient);
const additionalIceServers: RTCIceServer[] = config.additionalIceServers.map(
(ice) => {
const iceUrls = [];
Expand Down