Skip to content

Commit

Permalink
feat: implement message size validation to prevent excessive payloads (
Browse files Browse the repository at this point in the history
…#1197)

* feat: introduce message size validation

Added MAX_MESSAGE_LENGTH constant and implemented validation to ensure messages do not exceed the maximum allowed size in MobilePortStream and RemoteCommunicationPostMessageStream.

* feat: adding unit tests

* feat: linting

* feat: unit tests

* feat: cleanup

* feat: add size validation ot socket server
  • Loading branch information
abretonc7s authored Jan 15, 2025
1 parent 458c3fa commit a24e803
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 14 deletions.
1 change: 1 addition & 0 deletions packages/sdk-socket-server-next/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const withAdminUI: boolean = process.env.ADMIN_UI === 'true';
const HOUR_IN_SECONDS = 60 * 60;
const THIRTY_DAYS_IN_SECONDS = 30 * 24 * 60 * 60; // expiration time of entries in Redis
export const MAX_CLIENTS_PER_ROOM = 2;
export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit

export const config = {
msgExpiry: HOUR_IN_SECONDS,
Expand Down
17 changes: 16 additions & 1 deletion packages/sdk-socket-server-next/src/protocol/handleMessage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Server, Socket } from 'socket.io';
import { v4 as uuidv4 } from 'uuid';
import { pubClient } from '../analytics-api';
import { config, isDevelopment } from '../config';
import { config, isDevelopment, MAX_MESSAGE_LENGTH } from '../config';
import { getLogger } from '../logger';
import {
increaseRateLimits,
Expand Down Expand Up @@ -58,6 +58,21 @@ export const handleMessage = async ({
let ready = false; // Determines if the keys have been exchanged and both side support the full protocol

try {
// Add message size validation
const messageSize = typeof message === 'string'
? message.length
: JSON.stringify(message).length;

if (messageSize > MAX_MESSAGE_LENGTH) {
logger.warn(`[handleMessage] Message size ${messageSize} exceeds limit of ${MAX_MESSAGE_LENGTH} bytes`, {
channelId,
socketId,
clientIp,
});
callback?.(`Message size ${messageSize} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`);
return;
}

if (clientType) {
// new protocol, get channelConfig
const channelConfigKey = `channel_config:${channelId}`;
Expand Down
2 changes: 2 additions & 0 deletions packages/sdk/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ export const EXTENSION_EVENTS = {
CONNECT: 'connect',
CONNECTED: 'connected',
};

export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit
47 changes: 47 additions & 0 deletions packages/sdk/src/services/MobilePortStream/write.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';
import { write } from './write';

describe('write function', () => {
Expand Down Expand Up @@ -77,4 +78,50 @@ describe('write function', () => {
new Error('MobilePortStream - disconnected'),
);
});

describe('Message Size Validation', () => {
beforeEach(() => {
jest.clearAllMocks();
global.window = {
location: { href: 'http://example.com' },
ReactNativeWebView: { postMessage: mockPostMessage },
} as any;
});

it('should reject messages exceeding MAX_MESSAGE_LENGTH', () => {
const largeData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(MAX_MESSAGE_LENGTH)],
},
};

write(largeData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockPostMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', () => {
const validData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(100)],
},
};

write(validData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith();
expect(mockPostMessage).toHaveBeenCalled();
});
});
});
24 changes: 19 additions & 5 deletions packages/sdk/src/services/MobilePortStream/write.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';

/**
* Handles communication between the in-app browser and MetaMask mobile application.
Expand All @@ -15,6 +16,7 @@ export function write(
cb: (error?: Error | null) => void,
) {
try {
let stringifiedData: string;
if (Buffer.isBuffer(chunk)) {
const data: {
type: 'Buffer';
Expand All @@ -23,18 +25,30 @@ export function write(
} = chunk.toJSON();

data._isBuffer = true;
window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...data, origin: window.location.href }),
);
stringifiedData = JSON.stringify({
...data,
origin: window.location.href,
});
} else {
if (chunk.data) {
chunk.data.toNative = true;
}

window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...chunk, origin: window.location.href }),
stringifiedData = JSON.stringify({
...chunk,
origin: window.location.href,
});
}

if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return cb(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

window.ReactNativeWebView?.postMessage(stringifiedData);
} catch (err) {
return cb(new Error('MobilePortStream - disconnected'));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream'; // Adjust the import based on your project structure
import { METHODS_TO_REDIRECT } from '../../config';
import { MAX_MESSAGE_LENGTH, METHODS_TO_REDIRECT } from '../../config';
import * as loggerModule from '../../utils/logger'; // Adjust the import based on your project structure
import { write } from './write'; // Adjust the import based on your project structure
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { extractMethod } from './extractMethod';
import { write } from './write'; // Adjust the import based on your project structure

jest.mock('./extractMethod');
jest.mock('../Ethereum');
Expand Down Expand Up @@ -162,11 +162,22 @@ describe('write function', () => {
mockIsMobileWeb.mockReturnValue(false);
mockIsSecure.mockReturnValue(true);
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsMetaMaskInstalled.mockReturnValue(true);
mockGetKeyInfo.mockReturnValue({ ecies: { public: 'test_public_key' } });
mockHasDeeplinkProtocol.mockReturnValue(false);
});

it('should redirect if method exists in METHODS_TO_REDIRECT', async () => {
mockExtractMethod.mockReturnValue({
method: Object.keys(METHODS_TO_REDIRECT)[0],
data: {
data: {
jsonrpc: '2.0',
method: Object.keys(METHODS_TO_REDIRECT)[0],
params: [],
},
},
triggeredInstaller: false,
});

await write(
Expand Down Expand Up @@ -239,4 +250,71 @@ describe('write function', () => {
expect(spyLogger).toHaveBeenCalled();
});
});

describe('Message Size Validation', () => {
it('should reject messages exceeding MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return large data
const largeData = {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(MAX_MESSAGE_LENGTH + 1)],
};

mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: largeData,
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

// Don't test for exact error message, just verify it contains the key parts
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockSendMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return valid-sized data
mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(100)],
},
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

expect(callback).toHaveBeenCalledWith();
expect(mockSendMessage).toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream';
import { METHODS_TO_REDIRECT, RPC_METHODS } from '../../config';
import {
METHODS_TO_REDIRECT,
RPC_METHODS,
MAX_MESSAGE_LENGTH,
} from '../../config';
import {
METAMASK_CONNECT_BASE_URL,
METAMASK_DEEPLINK_BASE,
Expand Down Expand Up @@ -57,11 +61,17 @@ export async function write(
deeplinkProtocolAvailable && mobileWeb && authorized;

try {
console.warn(
`[RCPMS: _write()] triggeredInstaller=${triggeredInstaller} activeDeeplinkProtocol=${activeDeeplinkProtocol}`,
);

if (!triggeredInstaller) {
// Check message size before sending
const stringifiedData = JSON.stringify(data?.data);
if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return callback(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

// The only reason not to send via network is because the rpc call will be sent in the deeplink
instance.state.remote
?.sendMessage(data?.data)
Expand Down

0 comments on commit a24e803

Please sign in to comment.