From 9b765f18be8499db501520629bf62ed8afe64d1b Mon Sep 17 00:00:00 2001 From: adam-alchemy <127769144+adam-alchemy@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:16:03 -0800 Subject: [PATCH] fix: [spearbit-77] Refactor session key loading util functions (#93) --- .../permissions/SessionKeyPermissions.sol | 4 ++-- .../permissions/SessionKeyPermissionsBase.sol | 24 ++++++++++++------- .../SessionKeyPermissionsLoupe.sol | 20 +++++++--------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/plugins/session/permissions/SessionKeyPermissions.sol b/src/plugins/session/permissions/SessionKeyPermissions.sol index 87e354e4f..21c1e172e 100644 --- a/src/plugins/session/permissions/SessionKeyPermissions.sol +++ b/src/plugins/session/permissions/SessionKeyPermissions.sol @@ -22,7 +22,7 @@ abstract contract SessionKeyPermissions is ISessionKeyPlugin, SessionKeyPermissi /// @inheritdoc ISessionKeyPlugin function updateKeyPermissions(address sessionKey, bytes[] calldata updates) public override { - (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(msg.sender, sessionKey); + (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKeyData(msg.sender, sessionKey); uint256 length = updates.length; for (uint256 i = 0; i < length; ++i) { @@ -34,7 +34,7 @@ abstract contract SessionKeyPermissions is ISessionKeyPlugin, SessionKeyPermissi /// @inheritdoc ISessionKeyPlugin function resetSessionKeyGasLimitTimestamp(address account, address sessionKey) external override { - (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); if (sessionKeyData.gasLimitResetThisBundle) { sessionKeyData.gasLimitResetThisBundle = false; sessionKeyData.gasLimitTimeInfo.lastUsed = uint48(block.timestamp); diff --git a/src/plugins/session/permissions/SessionKeyPermissionsBase.sol b/src/plugins/session/permissions/SessionKeyPermissionsBase.sol index 8decdb075..1632584dd 100644 --- a/src/plugins/session/permissions/SessionKeyPermissionsBase.sol +++ b/src/plugins/session/permissions/SessionKeyPermissionsBase.sol @@ -103,12 +103,6 @@ abstract contract SessionKeyPermissionsBase is ISessionKeyPlugin { // Internal Functions - function _assertKeyExists(SessionKeyId id, address sessionKey) internal pure { - if (SessionKeyId.unwrap(id) == bytes32(0)) { - revert InvalidSessionKey(sessionKey); - } - } - function _sessionKeyIdOf(address associated, address sessionKey) internal view returns (SessionKeyId keyId) { uint256 prefixAndBatchIndex = uint256(bytes32(SESSION_KEY_ID_PREFIX)); bytes memory associatedStorageKey = @@ -120,6 +114,19 @@ abstract contract SessionKeyPermissionsBase is ISessionKeyPlugin { } } + /// @dev Helper function that loads the session key id and asserts it is registered. + function _loadSessionKeyId(address associated, address sessionKey) + internal + view + returns (SessionKeyId keyId) + { + SessionKeyId id = _sessionKeyIdOf(associated, sessionKey); + if (SessionKeyId.unwrap(id) == bytes32(0)) { + revert InvalidSessionKey(sessionKey); + } + return id; + } + function _updateSessionKeyId(address associated, address sessionKey, SessionKeyId newId) internal { uint256 prefixAndBatchIndex = uint256(bytes32(SESSION_KEY_ID_PREFIX)); bytes memory associatedStorageKey = @@ -146,13 +153,12 @@ abstract contract SessionKeyPermissionsBase is ISessionKeyPlugin { /// @dev Helper function that loads the session key id, asserts it is registered, and returns the session key /// data and the key id. - function _loadSessionKey(address associated, address sessionKey) + function _loadSessionKeyData(address associated, address sessionKey) internal view returns (SessionKeyData storage sessionKeyData, SessionKeyId keyId) { - SessionKeyId id = _sessionKeyIdOf(associated, sessionKey); - _assertKeyExists(id, sessionKey); + SessionKeyId id = _loadSessionKeyId(associated, sessionKey); return (_sessionKeyDataOf(associated, id), id); } diff --git a/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol b/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol index 46e84e584..850e21062 100644 --- a/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol +++ b/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol @@ -15,7 +15,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { view returns (ContractAccessControlType) { - (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); return sessionKeyData.contractAccessControlType; } @@ -25,8 +25,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { view returns (bool isOnList, bool checkSelectors) { - SessionKeyId keyId = _sessionKeyIdOf(account, sessionKey); - _assertKeyExists(keyId, sessionKey); + SessionKeyId keyId = _loadSessionKeyId(account, sessionKey); ContractData storage contractData = _contractDataOf(account, keyId, contractAddress); return (contractData.isOnList, contractData.checkSelectors); } @@ -38,8 +37,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { address contractAddress, bytes4 selector ) external view returns (bool isOnList) { - SessionKeyId keyId = _sessionKeyIdOf(account, sessionKey); - _assertKeyExists(keyId, sessionKey); + SessionKeyId keyId = _loadSessionKeyId(account, sessionKey); FunctionData storage functionData = _functionDataOf(account, keyId, contractAddress, selector); return functionData.isOnList; } @@ -50,7 +48,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { view returns (uint48 validAfter, uint48 validUntil) { - (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); return (sessionKeyData.validAfter, sessionKeyData.validUntil); } @@ -60,7 +58,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { view returns (SpendLimitInfo memory info) { - (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); if (!sessionKeyData.nativeTokenSpendLimitBypassed) { info.hasLimit = true; @@ -79,7 +77,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { view returns (SpendLimitInfo memory) { - (, SessionKeyId keyId) = _loadSessionKey(account, sessionKey); + SessionKeyId keyId = _loadSessionKeyId(account, sessionKey); ContractData storage tokenContractData = _contractDataOf(account, keyId, token); return SpendLimitInfo({ hasLimit: tokenContractData.isERC20WithSpendLimit, @@ -92,9 +90,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { /// @inheritdoc ISessionKeyPlugin function getRequiredPaymaster(address account, address sessionKey) external view returns (address) { - SessionKeyId id = _sessionKeyIdOf(account, sessionKey); - _assertKeyExists(id, sessionKey); - SessionKeyData storage sessionKeyData = _sessionKeyDataOf(account, id); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); return sessionKeyData.hasRequiredPaymaster ? sessionKeyData.requiredPaymaster : address(0); } @@ -105,7 +101,7 @@ abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { override returns (SpendLimitInfo memory info, bool shouldReset) { - (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + (SessionKeyData storage sessionKeyData,) = _loadSessionKeyData(account, sessionKey); shouldReset = sessionKeyData.gasLimitResetThisBundle;