Skip to content

Commit

Permalink
fix: [spearbit-82] Session key init perms (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-alchemy authored and jaypaik committed Jan 25, 2024
1 parent 9320362 commit 00a6716
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 40 deletions.
4 changes: 3 additions & 1 deletion src/plugins/session/ISessionKeyPlugin.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ interface ISessionKeyPlugin {
error InvalidPermissionsUpdate();
error InvalidToken();
error NativeTokenSpendLimitExceeded(address account, address sessionKey);
error LengthMismatch();

// ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
// ┃ Execution functions ┃
Expand All @@ -75,7 +76,8 @@ interface ISessionKeyPlugin {
/// @notice Add a session key.
/// @param sessionKey The session key to register.
/// @param tag An optional tag that can be used to identify the key.
function addSessionKey(address sessionKey, bytes32 tag) external;
/// @param permissionUpdates The initial permission updates to apply to the key.
function addSessionKey(address sessionKey, bytes32 tag, bytes[] calldata permissionUpdates) external;

/// @notice Remove a session key.
/// @param sessionKey The session key to remove.
Expand Down
41 changes: 36 additions & 5 deletions src/plugins/session/SessionKeyPlugin.sol
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi
}

/// @inheritdoc ISessionKeyPlugin
function addSessionKey(address sessionKey, bytes32 tag) public override {
function addSessionKey(address sessionKey, bytes32 tag, bytes[] calldata permissionUpdates) public override {
if (!_sessionKeys.tryAdd(msg.sender, CastLib.toSetValue(sessionKey))) {
// This check ensures no duplicate keys and that the session key is not the zero address.
revert InvalidSessionKey(sessionKey);
Expand All @@ -135,6 +135,11 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi
_updateSessionKeyId(msg.sender, sessionKey, SessionKeyId.wrap(bytes32(++_keyIdCounter[msg.sender])));

emit SessionKeyAdded(msg.sender, sessionKey, tag);

if (permissionUpdates.length > 0) {
// Call the public function internally to update the permissions.
updateKeyPermissions(sessionKey, permissionUpdates);
}
}

/// @inheritdoc ISessionKeyPlugin
Expand Down Expand Up @@ -357,13 +362,39 @@ contract SessionKeyPlugin is ISessionKeyPlugin, SessionKeyPermissions, BasePlugi

/// @inheritdoc BasePlugin
function _onInstall(bytes calldata data) internal override isNotInitialized(msg.sender) {
address[] memory sessionKeysToAdd = abi.decode(data, (address[]));
(address[] memory sessionKeysToAdd, bytes32[] memory tags,) =
abi.decode(data, (address[], bytes32[], bytes[][]));

// Permission updates depend on `calldata` types for the updates, in order to use array slices.
// Unfortunately, solidity does not support `abi.decode` outputting calldata types, so we do it manually.
// The prior `abi.decode` call validates that `data` has a valid encoding of a `bytes[][]` at this
// position, so we can load in the offset and length of the `bytes[][]` directly.
bytes[][] calldata permissionUpdates;
assembly ("memory-safe") {
// Get the offset of the bytes[][] used for permissions updates. The offset for this field is stored at
// the third word of `data`. Note that `data.offset` refers to the start of the actual data contents,
// one word after the length.
let permissionUpdatesRelativeOffset := calldataload(add(data.offset, 0x40))
// We now have the relative offset of the bytes[][] in `data`. We need to add the starting offset
// (aka `data.offset`) to get the absolute offset.
let permissionUpdatesLengthOffset := add(data.offset, permissionUpdatesRelativeOffset)
// Note that solidity expects the field `var.offset` to point to the start of the actual data contents,
// one word after the length. Therefore, we add 0x20 here to get the correct offset.
permissionUpdates.offset := add(0x20, permissionUpdatesLengthOffset)
// Load the length of the bytes[][].
permissionUpdates.length := calldataload(permissionUpdatesLengthOffset)
}

uint256 length = sessionKeysToAdd.length;

if (length != tags.length || length != permissionUpdates.length) {
revert LengthMismatch();
}

for (uint256 i = 0; i < length; ++i) {
// Use the public function to add the session key, set the key id, and emit the event.
// Note that no tags are set when adding keys with this method.
addSessionKey(sessionKeysToAdd[i], bytes32(0));
// Use the public function to add the session key, set the key id, emit the event, and update the
// permissions.
addSessionKey(sessionKeysToAdd[i], tags[i], permissionUpdates[i]);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/session/permissions/SessionKeyPermissions.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract contract SessionKeyPermissions is ISessionKeyPlugin, SessionKeyPermissi
// ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛

/// @inheritdoc ISessionKeyPlugin
function updateKeyPermissions(address sessionKey, bytes[] calldata updates) external override {
function updateKeyPermissions(address sessionKey, bytes[] calldata updates) public override {
(SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(msg.sender, sessionKey);

uint256 length = updates.length;
Expand Down
6 changes: 4 additions & 2 deletions test/account/UpgradeableModularAccountPluginManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ contract UpgradeableModularAccountPluginManagerTest is Test {
factory.createAccount(0, owners1);
}

function test_installPlugin() public {
function test_installPlugin_standard() public {
vm.startPrank(owner2);

bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest()));
Expand All @@ -136,7 +136,9 @@ contract UpgradeableModularAccountPluginManagerTest is Test {
IPluginManager(account2).installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: abi.encode(sessionKeys),
pluginInitData: abi.encode(
sessionKeys, new bytes32[](sessionKeys.length), new bytes[][](sessionKeys.length)
),
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});
Expand Down
90 changes: 70 additions & 20 deletions test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
account1.installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: abi.encode(new address[](0)),
pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)),
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});
Expand All @@ -96,7 +96,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
vm.expectEmit(true, true, true, true);
emit SessionKeyAdded(address(account1), sessionKeyToAdd, bytes32(0));
vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0));

// Check using all view methods
address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1));
Expand All @@ -111,13 +111,13 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
// Zero address session key
address sessionKeyToAdd = address(0);
vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.InvalidSessionKey.selector, address(0)));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0));

// Duplicate session key
sessionKeyToAdd = makeAddr("sessionKey1");
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0));
vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.InvalidSessionKey.selector, sessionKeyToAdd));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeyToAdd, bytes32(0), new bytes[](0));

// Check using all view methods
address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1));
Expand All @@ -131,8 +131,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
address sessionKey2 = makeAddr("sessionKey2");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0));
vm.stopPrank();

// Check using all view methods
Expand All @@ -159,11 +159,61 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
assertTrue(sessionKeyPlugin.isSessionKeyOf(address(account1), sessionKey2));
}

function testFuzz_sessionKey_addKeysDuringInstall(uint8 seed) public {
// First uninstall the plugin
vm.prank(owner1);
account1.uninstallPlugin(address(sessionKeyPlugin), "", "", new bytes[](0));

// Generate a set of initial session keys
uint256 addressCount = (seed % 16) + 1;

address[] memory sessionKeysToAdd = new address[](addressCount);
bytes32[] memory tags = new bytes32[](addressCount);
for (uint256 i = 0; i < addressCount; i++) {
sessionKeysToAdd[i] = makeAddr(string.concat(vm.toString(seed), "sessionKey", vm.toString(i)));
tags[i] = bytes32(uint256(i) + seed);
}

bytes memory onInstallData = abi.encode(sessionKeysToAdd, tags, new bytes[][](addressCount));

// Re-install the plugin
bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest()));
FunctionReference[] memory dependencies = new FunctionReference[](2);
dependencies[0] = FunctionReferenceLib.pack(
address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER)
);
dependencies[1] = FunctionReferenceLib.pack(
address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF)
);

for (uint256 i = 0; i < addressCount; i++) {
vm.expectEmit(true, true, true, true);
emit SessionKeyAdded(address(account1), sessionKeysToAdd[i], tags[i]);
}
vm.prank(owner1);
account1.installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: onInstallData,
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});

// Check using all view methods
address[] memory sessionKeys = sessionKeyPlugin.sessionKeysOf(address(account1));
assertEq(sessionKeys.length, addressCount);
for (uint256 i = 0; i < addressCount; i++) {
// Invert the indexing because the view function will return it in reverse order
assertEq(sessionKeys[sessionKeys.length - 1 - i], sessionKeysToAdd[i]);
assertTrue(sessionKeyPlugin.isSessionKeyOf(address(account1), sessionKeysToAdd[i]));
}
}

function test_sessionKey_rotate_valid() public {
// Add the first key
address sessionKey1 = makeAddr("sessionKey1");
vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

// Rotate to the second key
address sessionKey2 = makeAddr("sessionKey2");
Expand All @@ -186,8 +236,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
address sessionKey1 = makeAddr("sessionKey1");
address sessionKey2 = makeAddr("sessionKey2");
vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0));
vm.stopPrank();

// Attempt to rotate key 1 to key 2
Expand All @@ -201,7 +251,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
// Add the first key
address sessionKey1 = makeAddr("sessionKey1");
vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

// Attempt to rotate to the zero address
address zeroAddr = address(0);
Expand All @@ -215,7 +265,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
(address sessionKey1, uint256 sessionKeyPrivate) = makeAddrAndKey("sessionKey1");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));
// Disable the allowlist and native token spend checking
bytes[] memory permissionUpdates = new bytes[](2);
permissionUpdates[0] = abi.encodeCall(
Expand Down Expand Up @@ -261,7 +311,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
(address sessionKey1,) = makeAddrAndKey("sessionKey1");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

Call[] memory calls = new Call[](1);
calls[0] = Call({target: recipient, value: 1 wei, data: ""});
Expand Down Expand Up @@ -401,8 +451,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
address sessionKey2 = makeAddr("sessionKey2");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0));

bytes32 predecessor = sessionKeyPlugin.findPredecessor(address(account1), sessionKey2);
assertEq(predecessor, bytes32(uint256(1)));
Expand All @@ -424,8 +474,8 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
address sessionKey2 = makeAddr("sessionKey2");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey2, bytes32(0), new bytes[](0));

bytes32 predecessor = sessionKeyPlugin.findPredecessor(address(account1), sessionKey1);
assertEq(predecessor, bytes32(bytes20(sessionKey2)));
Expand All @@ -446,7 +496,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
sessionKeysToAdd[0] = makeAddr("sessionKey1");

vm.startPrank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0), new bytes[](0));

address key2 = makeAddr("sessionKey2");
vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.SessionKeyNotFound.selector, key2));
Expand All @@ -460,7 +510,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {
sessionKeysToAdd[0] = makeAddr("sessionKey1");

vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[0], bytes32(0), new bytes[](0));

assertFalse(sessionKeyPlugin.isSessionKeyOf(address(account1), address(1)));
}
Expand All @@ -484,7 +534,7 @@ contract SessionKeyPluginWithMultiOwnerTest is Test {

vm.startPrank(owner1);
for (uint256 i = 0; i < addressCount; i++) {
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[i], bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKeysToAdd[i], bytes32(0), new bytes[](0));
SessionKeyPlugin(address(account1)).updateKeyPermissions(sessionKeysToAdd[i], permissionUpdates);
}
vm.stopPrank();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ contract SessionKeyERC20SpendLimitsTest is Test {
account1.installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: abi.encode(new address[](0)),
pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)),
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});
Expand All @@ -95,7 +95,7 @@ contract SessionKeyERC20SpendLimitsTest is Test {
(sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1");

vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

// Disable the allowlist
bytes[] memory updates = new bytes[](1);
Expand Down
4 changes: 2 additions & 2 deletions test/plugin/session/permissions/SessionKeyGasLimits.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ contract SessionKeyGasLimitsTest is Test {
account1.installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: abi.encode(new address[](0)),
pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)),
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});
Expand All @@ -86,7 +86,7 @@ contract SessionKeyGasLimitsTest is Test {
(sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1");

vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

bytes[] memory updates = new bytes[](1);
updates[0] = abi.encodeCall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ contract SessionKeyNativeTokenSpendLimitsTest is Test {
account1.installPlugin({
plugin: address(sessionKeyPlugin),
manifestHash: manifestHash,
pluginInitData: abi.encode(new address[](0)),
pluginInitData: abi.encode(new address[](0), new bytes32[](0), new bytes[][](0)),
dependencies: dependencies,
injectedHooks: new IPluginManager.InjectedHook[](0)
});
Expand All @@ -91,7 +91,7 @@ contract SessionKeyNativeTokenSpendLimitsTest is Test {
(sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1");

vm.prank(owner1);
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0));
SessionKeyPlugin(address(account1)).addSessionKey(sessionKey1, bytes32(0), new bytes[](0));

// Remove the allowlist
bytes[] memory updates = new bytes[](1);
Expand Down
Loading

0 comments on commit 00a6716

Please sign in to comment.