From 00a671664e97eee5116d051a956bc3669448fbdd Mon Sep 17 00:00:00 2001 From: adam-alchemy <127769144+adam-alchemy@users.noreply.github.com> Date: Mon, 22 Jan 2024 09:29:03 -0800 Subject: [PATCH] fix: [spearbit-82] Session key init perms (#80) --- src/plugins/session/ISessionKeyPlugin.sol | 4 +- src/plugins/session/SessionKeyPlugin.sol | 41 ++++++- .../permissions/SessionKeyPermissions.sol | 2 +- ...gradeableModularAccountPluginManager.t.sol | 6 +- .../SessionKeyPluginWithMultiOwner.t.sol | 90 ++++++++++++---- .../SessionKeyERC20SpendLimits.t.sol | 4 +- .../permissions/SessionKeyGasLimits.t.sol | 4 +- .../SessionKeyNativeTokenSpendLimits.t.sol | 4 +- .../permissions/SessionKeyPermissions.t.sol | 100 +++++++++++++++++- 9 files changed, 215 insertions(+), 40 deletions(-) diff --git a/src/plugins/session/ISessionKeyPlugin.sol b/src/plugins/session/ISessionKeyPlugin.sol index 0d5f57685..85e352873 100644 --- a/src/plugins/session/ISessionKeyPlugin.sol +++ b/src/plugins/session/ISessionKeyPlugin.sol @@ -59,6 +59,7 @@ interface ISessionKeyPlugin { error InvalidPermissionsUpdate(); error InvalidToken(); error NativeTokenSpendLimitExceeded(address account, address sessionKey); + error LengthMismatch(); // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ // ┃ Execution functions ┃ @@ -75,7 +76,8 @@ interface ISessionKeyPlugin { /// @notice Add a session key. /// @param sessionKey The session key to register. /// @param tag An optional tag that can be used to identify the key. - function addSessionKey(address sessionKey, bytes32 tag) external; + /// @param permissionUpdates The initial permission updates to apply to the key. + function addSessionKey(address sessionKey, bytes32 tag, bytes[] calldata permissionUpdates) external; /// @notice Remove a session key. /// @param sessionKey The session key to remove. diff --git a/src/plugins/session/SessionKeyPlugin.sol b/src/plugins/session/SessionKeyPlugin.sol index 9f6bedfb7..cec2e47dc 100644 --- a/src/plugins/session/SessionKeyPlugin.sol +++ b/src/plugins/session/SessionKeyPlugin.sol @@ -122,7 +122,7 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi } /// @inheritdoc ISessionKeyPlugin - function addSessionKey(address sessionKey, bytes32 tag) public override { + function addSessionKey(address sessionKey, bytes32 tag, bytes[] calldata permissionUpdates) public override { if (!_sessionKeys.tryAdd(msg.sender, CastLib.toSetValue(sessionKey))) { // This check ensures no duplicate keys and that the session key is not the zero address. revert InvalidSessionKey(sessionKey); @@ -135,6 +135,11 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi _updateSessionKeyId(msg.sender, sessionKey, SessionKeyId.wrap(bytes32(++_keyIdCounter[msg.sender]))); emit SessionKeyAdded(msg.sender, sessionKey, tag); + + if (permissionUpdates.length > 0) { + // Call the public function internally to update the permissions. + updateKeyPermissions(sessionKey, permissionUpdates); + } } /// @inheritdoc ISessionKeyPlugin @@ -357,13 +362,39 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi /// @inheritdoc BasePlugin function _onInstall(bytes calldata data) internal override isNotInitialized(msg.sender) { - address[] memory sessionKeysToAdd = abi.decode(data, (address[])); + (address[] memory sessionKeysToAdd, bytes32[] memory tags,) = + abi.decode(data, (address[], bytes32[], bytes[][])); + + // Permission updates depend on `calldata` types for the updates, in order to use array slices. + // Unfortunately, solidity does not support `abi.decode` outputting calldata types, so we do it manually. + // The prior `abi.decode` call validates that `data` has a valid encoding of a `bytes[][]` at this + // position, so we can load in the offset and length of the `bytes[][]` directly. + bytes[][] calldata permissionUpdates; + assembly ("memory-safe") { + // Get the offset of the bytes[][] used for permissions updates. The offset for this field is stored at + // the third word of `data`. Note that `data.offset` refers to the start of the actual data contents, + // one word after the length. + let permissionUpdatesRelativeOffset := calldataload(add(data.offset, 0x40)) + // We now have the relative offset of the bytes[][] in `data`. We need to add the starting offset + // (aka `data.offset`) to get the absolute offset. + let permissionUpdatesLengthOffset := add(data.offset, permissionUpdatesRelativeOffset) + // Note that solidity expects the field `var.offset` to point to the start of the actual data contents, + // one word after the length. Therefore, we add 0x20 here to get the correct offset. + permissionUpdates.offset := add(0x20, permissionUpdatesLengthOffset) + // Load the length of the bytes[][]. + permissionUpdates.length := calldataload(permissionUpdatesLengthOffset) + } uint256 length = sessionKeysToAdd.length; + + if (length != tags.length || length != permissionUpdates.length) { + revert LengthMismatch(); + } + for (uint256 i = 0; i < length; ++i) { - // Use the public function to add the session key, set the key id, and emit the event. - // Note that no tags are set when adding keys with this method. - addSessionKey(sessionKeysToAdd[i], bytes32(0)); + // Use the public function to add the session key, set the key id, emit the event, and update the + // permissions. + addSessionKey(sessionKeysToAdd[i], tags[i], permissionUpdates[i]); } } diff --git a/src/plugins/session/permissions/SessionKeyPermissions.sol b/src/plugins/session/permissions/SessionKeyPermissions.sol index 1ab90d887..87e354e4f 100644 --- a/src/plugins/session/permissions/SessionKeyPermissions.sol +++ b/src/plugins/session/permissions/SessionKeyPermissions.sol @@ -21,7 +21,7 @@ abstract contract SessionKeyPermissions is ISessionKeyPlugin, SessionKeyPermissi // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ /// @inheritdoc ISessionKeyPlugin - function updateKeyPermissions(address sessionKey, bytes[] calldata updates) external override { + function updateKeyPermissions(address sessionKey, bytes[] calldata updates) public override { (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(msg.sender, sessionKey); uint256 length = updates.length; diff --git a/test/account/UpgradeableModularAccountPluginManager.t.sol b/test/account/UpgradeableModularAccountPluginManager.t.sol index 178d9f6a0..9bc337f4f 100644 --- a/test/account/UpgradeableModularAccountPluginManager.t.sol +++ b/test/account/UpgradeableModularAccountPluginManager.t.sol @@ -115,7 +115,7 @@ contract UpgradeableModularAccountPluginManagerTest is Test { factory.createAccount(0, owners1); } - function test_installPlugin() public { + function test_installPlugin_standard() public { vm.startPrank(owner2); bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); @@ -136,7 +136,9 @@ contract UpgradeableModularAccountPluginManagerTest is Test { IPluginManager(account2).installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(sessionKeys), + pluginInitData: abi.encode( + sessionKeys, new bytes32[](sessionKeys.length), new bytes[][](sessionKeys.length) + ), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); diff --git a/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol b/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol index c03fb52c3..2b0dfeee9 100644 --- a/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol +++ b/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol @@ -84,7 +84,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -96,7 +96,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { vm.expectEmit(true, true, true, true); emit SessionKeyAdded(address(account1), sessionKeyToAdd, bytes32(0)); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0)); // Check using all view methods address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1)); @@ -111,13 +111,13 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { // Zero address session key address sessionKeyToAdd = address(0); vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.InvalidSessionKey.selector, address(0))); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0)); // Duplicate session key sessionKeyToAdd = makeAddr("sessionKey1"); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0)); vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.InvalidSessionKey.selector, sessionKeyToAdd)); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0)); // Check using all view methods address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1)); @@ -131,8 +131,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { address sessionKey2 = makeAddr("sessionKey2"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0)); vm.stopPrank(); // Check using all view methods @@ -159,11 +159,61 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { assertTrue(sessionKeyPlugin.isSessionKeyOf(address(account1), sessionKey2)); } + function testFuzz_sessionKey_addKeysDuringInstall(uint8 seed) public { + // First uninstall the plugin + vm.prank(owner1); + account1.uninstallPlugin(address(sessionKeyPlugin), "", "", new bytes[](0)); + + // Generate a set of initial session keys + uint256 addressCount = (seed % 16) + 1; + + address[] memory sessionKeysToAdd = new address[](addressCount); + bytes32[] memory tags = new bytes32[](addressCount); + for (uint256 i = 0; i < addressCount; i++) { + sessionKeysToAdd[i] = makeAddr(string.concat(vm.toString(seed), "sessionKey", vm.toString(i))); + tags[i] = bytes32(uint256(i) + seed); + } + + bytes memory onInstallData = abi.encode(sessionKeysToAdd, tags, new bytes[][](addressCount)); + + // Re-install the plugin + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + + for (uint256 i = 0; i < addressCount; i++) { + vm.expectEmit(true, true, true, true); + emit SessionKeyAdded(address(account1), sessionKeysToAdd[i], tags[i]); + } + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: onInstallData, + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Check using all view methods + address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1)); + assertEq(sessionKeys.length, addressCount); + for (uint256 i = 0; i < addressCount; i++) { + // Invert the indexing because the view function will return it in reverse order + assertEq(sessionKeys[sessionKeys.length - 1 - i], sessionKeysToAdd[i]); + assertTrue(sessionKeyPlugin.isSessionKeyOf(address(account1), sessionKeysToAdd[i])); + } + } + function test_sessionKey_rotate_valid() public { // Add the first key address sessionKey1 = makeAddr("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Rotate to the second key address sessionKey2 = makeAddr("sessionKey2"); @@ -186,8 +236,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { address sessionKey1 = makeAddr("sessionKey1"); address sessionKey2 = makeAddr("sessionKey2"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0)); vm.stopPrank(); // Attempt to rotate key 1 to key 2 @@ -201,7 +251,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { // Add the first key address sessionKey1 = makeAddr("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Attempt to rotate to the zero address address zeroAddr = address(0); @@ -215,7 +265,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { (address sessionKey1, uint256 sessionKeyPrivate) = makeAddrAndKey("sessionKey1"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Disable the allowlist and native token spend checking bytes[] memory permissionUpdates = new bytes[](2); permissionUpdates[0] = abi.encodeCall( @@ -261,7 +311,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { (address sessionKey1,) = makeAddrAndKey("sessionKey1"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); Call[] memory calls = new Call[](1); calls[0] = Call({target: recipient, value: 1 wei, data: ""}); @@ -401,8 +451,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { address sessionKey2 = makeAddr("sessionKey2"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0)); bytes32 predecessor = sessionKeyPlugin.findPredecessor(address(account1), sessionKey2); assertEq(predecessor, bytes32(uint256(1))); @@ -424,8 +474,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { address sessionKey2 = makeAddr("sessionKey2"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0)); bytes32 predecessor = sessionKeyPlugin.findPredecessor(address(account1), sessionKey1); assertEq(predecessor, bytes32(bytes20(sessionKey2))); @@ -446,7 +496,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { sessionKeysToAdd[0] = makeAddr("sessionKey1"); vm.startPrank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0), new bytes[](0)); address key2 = makeAddr("sessionKey2"); vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.SessionKeyNotFound.selector, key2)); @@ -460,7 +510,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { sessionKeysToAdd[0] = makeAddr("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0), new bytes[](0)); assertFalse(sessionKeyPlugin.isSessionKeyOf(address(account1), address(1))); } @@ -484,7 +534,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test { vm.startPrank(owner1); for (uint256 i = 0; i < addressCount; i++) { - SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[i], bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[i], bytes32(0), new bytes[](0)); SessionKeyPlugin(address(account1)).updateKeyPermissions(sessionKeysToAdd[i], permissionUpdates); } vm.stopPrank(); diff --git a/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol b/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol index df38f9566..163234ef7 100644 --- a/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol +++ b/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol @@ -86,7 +86,7 @@ contract SessionKeyERC20SpendLimitsTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -95,7 +95,7 @@ contract SessionKeyERC20SpendLimitsTest is Test { (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Disable the allowlist bytes[] memory updates = new bytes[](1); diff --git a/test/plugin/session/permissions/SessionKeyGasLimits.t.sol b/test/plugin/session/permissions/SessionKeyGasLimits.t.sol index d0127cae0..d3ceb4431 100644 --- a/test/plugin/session/permissions/SessionKeyGasLimits.t.sol +++ b/test/plugin/session/permissions/SessionKeyGasLimits.t.sol @@ -77,7 +77,7 @@ contract SessionKeyGasLimitsTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -86,7 +86,7 @@ contract SessionKeyGasLimitsTest is Test { (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); bytes[] memory updates = new bytes[](1); updates[0] = abi.encodeCall( diff --git a/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol b/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol index f1a0245c7..19f23cfac 100644 --- a/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol +++ b/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol @@ -82,7 +82,7 @@ contract SessionKeyNativeTokenSpendLimitsTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -91,7 +91,7 @@ contract SessionKeyNativeTokenSpendLimitsTest is Test { (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Remove the allowlist bytes[] memory updates = new bytes[](1); diff --git a/test/plugin/session/permissions/SessionKeyPermissions.t.sol b/test/plugin/session/permissions/SessionKeyPermissions.t.sol index 6dcdb2a1b..8cd62be49 100644 --- a/test/plugin/session/permissions/SessionKeyPermissions.t.sol +++ b/test/plugin/session/permissions/SessionKeyPermissions.t.sol @@ -48,6 +48,8 @@ contract SessionKeyPermissionsTest is Test { Counter counter2; + event PermissionsUpdated(address indexed account, address indexed sessionKey, bytes[] updates); + function setUp() public { entryPoint = IEntryPoint(address(new EntryPoint())); (owner1, owner1Key) = makeAddrAndKey("owner1"); @@ -90,7 +92,7 @@ contract SessionKeyPermissionsTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: manifestHash, - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -99,7 +101,7 @@ contract SessionKeyPermissionsTest is Test { (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Initialize the interaction targets counter1 = new Counter(); @@ -557,7 +559,7 @@ contract SessionKeyPermissionsTest is Test { address sessionKey2 = makeAddr("sessionKey2"); vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0)); ISessionKeyPlugin.ContractAccessControlType accessControlType1; ISessionKeyPlugin.ContractAccessControlType accessControlType2; @@ -624,7 +626,7 @@ contract SessionKeyPermissionsTest is Test { account1.installPlugin({ plugin: address(sessionKeyPlugin), manifestHash: keccak256(abi.encode(sessionKeyPlugin.pluginManifest())), - pluginInitData: abi.encode(new address[](0)), + pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)), dependencies: dependencies, injectedHooks: new IPluginManager.InjectedHook[](0) }); @@ -632,7 +634,7 @@ contract SessionKeyPermissionsTest is Test { // Re-add the session key vm.prank(owner1); - SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0)); + SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0)); // Assert that the time range is reset (returnedStartTime, returnedEndTime) = sessionKeyPlugin.getKeyTimeRange(address(account1), sessionKey1); @@ -640,6 +642,38 @@ contract SessionKeyPermissionsTest is Test { assertEq(returnedEndTime, uint48(0)); } + function testFuzz_initialSessionKeysWithPermissions(uint256 seed) public { + // Uninstall the plugin + vm.prank(owner1); + account1.uninstallPlugin(address(sessionKeyPlugin), "", "", new bytes[](0)); + + address[] memory sessionKeys = _generateRandomAddresses(seed); + bytes32[] memory tags = new bytes32[](sessionKeys.length); + bytes[][] memory sessionKeyPermissions = new bytes[][](sessionKeys.length); + for (uint256 i = 0; i < sessionKeys.length; i++) { + uint256 modifiedSeed; + unchecked { + modifiedSeed = seed + i; + } + sessionKeyPermissions[i] = _generateRandomPermissionUpdates(modifiedSeed); + } + + // Reinstall the plugin with the session keys + for (uint256 i = 0; i < sessionKeys.length; i++) { + vm.expectEmit(true, true, true, true); + emit PermissionsUpdated(address(account1), sessionKeys[i], sessionKeyPermissions[i]); + } + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(sessionKeys, tags, sessionKeyPermissions), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + function _runSessionKeyExecUserOp( address target, address sessionKey, @@ -677,4 +711,60 @@ contract SessionKeyPermissionsTest is Test { } entryPoint.handleOps(userOps, beneficiary); } + + function _generateRandomAddresses(uint256 seed) internal returns (address[] memory keys) { + uint256 addressCount = (seed % 5) + 1; + + keys = new address[](addressCount); + for (uint256 i = 0; i < addressCount; i++) { + keys[i] = makeAddr(string.concat(vm.toString(seed), "sessionKey", vm.toString(i))); + } + } + + function _generateRandomPermissionUpdates(uint256 seed) internal returns (bytes[] memory updates) { + uint256 updateCount = (seed % 5) + 1; + + updates = new bytes[](updateCount); + + for (uint256 i = 0; i < updateCount; i++) { + uint256 updateType = (seed % 6) + 1; + if (updateType == 1) { + // Set access list type + uint256 accessListType = (seed % 3); + updates[i] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + ISessionKeyPlugin.ContractAccessControlType(accessListType) + ); + } else if (updateType == 2) { + // Update access list address entry + address addr = makeAddr(string.concat(vm.toString(seed), "addr", vm.toString(i))); + bool isOnList = (seed % 2) == 0; + bool checkSelectors = (seed % 3) == 0; + updates[i] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListAddressEntry, (addr, isOnList, checkSelectors) + ); + } else if (updateType == 3) { + // Update access list function entry + address addr = makeAddr(string.concat(vm.toString(seed), "addr", vm.toString(i))); + bytes4 selector = bytes4(uint32(seed)); + bool isOnList = (seed % 2) == 0; + updates[i] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListFunctionEntry, (addr, selector, isOnList) + ); + } else if (updateType == 4) { + // Set time range + uint48 startTime = uint48(seed); + uint48 endTime = uint48(seed << 2); + updates[i] = abi.encodeCall(ISessionKeyPermissionsUpdates.updateTimeRange, (startTime, endTime)); + } else if (updateType == 5) { + // Set required paymaster + address paymaster = makeAddr(string.concat(vm.toString(seed), "paymaster", vm.toString(i))); + updates[i] = abi.encodeCall(ISessionKeyPermissionsUpdates.setRequiredPaymaster, (paymaster)); + } else if (updateType == 6) { + // Set native token spend limit + uint256 limit = seed; + updates[i] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (limit, 0)); + } + } + } }