From c95bc28f7d8ca0f8b1b0eb70a62d605584ace5a9 Mon Sep 17 00:00:00 2001 From: Jay Paik Date: Fri, 17 Nov 2023 18:45:44 -0500 Subject: [PATCH] feat: initial commit forge install: account-abstraction ver0.6.0 chore: clean interface comments --- .env.example | 6 + .gasestimates.md | 385 +++++ .github/pull_request_template.md | 25 + .github/workflows/test.yml | 110 ++ .gitignore | 8 + .gitmodules | 14 + .solhint-src.json | 17 + .solhint-test.json | 16 + .storagelayout.md | 82 + .vscode/settings.json | 11 + README.md | 51 + ext/UUPSUpgradeable.sol | 142 ++ foundry.toml | 46 + lib/account-abstraction | 1 + lib/forge-std | 1 + lib/light-account | 1 + lib/openzeppelin-contracts | 1 + package.json | 11 + pnpm-lock.yaml | 474 ++++++ remappings.txt | 5 + script/Counter.s.sol | 12 + slither.config.json | 3 + src/account/AccountExecutor.sol | 229 +++ src/account/AccountLoupe.sol | 141 ++ src/account/AccountStorageInitializable.sol | 50 + src/account/PluginManagerInternals.sol | 962 ++++++++++++ src/account/UpgradeableModularAccount.sol | 777 ++++++++++ src/factory/MultiOwnerMSCAFactory.sol | 120 ++ .../MultiOwnerTokenReceiverMSCAFactory.sol | 129 ++ src/helpers/KnownSelectors.sol | 56 + src/helpers/ValidationDataHelpers.sol | 50 + src/interfaces/IAccountInitializable.sol | 11 + src/interfaces/IAccountLoupe.sol | 57 + src/interfaces/IAccountView.sol | 16 + src/interfaces/IPlugin.sol | 206 +++ src/interfaces/IPluginExecutor.sol | 24 + src/interfaces/IPluginManager.sol | 68 + src/interfaces/IStandardExecutor.sol | 29 + src/interfaces/erc4337/IAccount.sol | 25 + src/interfaces/erc4337/IAggregator.sol | 11 + src/interfaces/erc4337/IEntryPoint.sol | 17 + src/interfaces/erc4337/IPaymaster.sol | 19 + src/interfaces/erc4337/UserOperation.sol | 17 + src/libraries/AccountStorageV1.sol | 125 ++ src/libraries/AssociatedLinkedListSetLib.sol | 501 +++++++ src/libraries/CastLib.sol | 51 + src/libraries/CountableLinkedListSetLib.sol | 77 + src/libraries/FunctionReferenceLib.sol | 41 + src/libraries/LinkedListSetLib.sol | 321 ++++ src/libraries/LinkedListSetUtils.sol | 11 + src/libraries/PluginStorageLib.sol | 58 + src/plugins/BasePlugin.sol | 226 +++ src/plugins/TokenReceiverPlugin.sol | 121 ++ src/plugins/owner/IMultiOwnerPlugin.sol | 82 + src/plugins/owner/MultiOwnerPlugin.sol | 438 ++++++ src/plugins/session/ISessionKeyPlugin.sol | 70 + src/plugins/session/SessionKeyPlugin.sol | 313 ++++ .../ISessionKeyPermissionsPlugin.sol | 159 ++ .../ISessionKeyPermissionsUpdates.sol | 66 + .../permissions/SessionKeyPermissionsBase.sol | 200 +++ .../SessionKeyPermissionsLoupe.sol | 133 ++ .../SessionKeyPermissionsPlugin.sol | 838 +++++++++++ test/TestUtils.sol | 21 + test/Utils.sol | 12 + test/account/AccountExecHooks.t.sol | 1318 +++++++++++++++++ test/account/AccountLoupe.t.sol | 570 +++++++ test/account/AccountPermittedCallHooks.t.sol | 871 +++++++++++ test/account/AccountPreValidationHooks.t.sol | 739 +++++++++ test/account/AccountReturnData.t.sol | 132 ++ .../ExecuteFromPluginPermissions.t.sol | 419 ++++++ test/account/ManifestValidity.t.sol | 188 +++ test/account/UpgradeableModularAccount.t.sol | 532 +++++++ ...gradeableModularAccountPluginManager.t.sol | 963 ++++++++++++ test/account/ValidationIntersection.t.sol | 318 ++++ test/comparison/CompareSimpleAccount.t.sol | 160 ++ test/factory/MultiOwnerMSCAFactoryTest.t.sol | 123 ++ .../MultiOwnerTokenReceiverFactoryTest.t.sol | 191 +++ test/helpers/KnownSelectors.t.sol | 67 + ...AssociatedLinkedListSetLibInvariants.t.sol | 106 ++ test/invariant/LLSLRepro.t.sol | 37 + .../LinkedListSetLibInvariants.t.sol | 78 + .../AssociatedLinkedListSetHandler.sol | 414 ++++++ .../handlers/LinkedListSetHandler.sol | 335 +++++ test/libraries/AccountStorage.t.sol | 50 + .../AssociatedLinkedListSetLib.t.sol | 224 +++ .../libraries/CountableLinkedListSetLib.t.sol | 68 + test/libraries/FunctionReferenceLib.t.sol | 36 + test/libraries/LinkedListSetLib.t.sol | 302 ++++ test/libraries/PluginStorageLib.t.sol | 91 ++ test/mocks/ContractOwner.sol | 19 + test/mocks/Counter.sol | 19 + test/mocks/Counter.t.sol | 24 + test/mocks/MockDiamondStorageContract.sol | 13 + test/mocks/MockPlugin.sol | 104 ++ .../plugins/BadTransferOwnershipPlugin.sol | 72 + test/mocks/plugins/BaseTestPlugin.sol | 12 + test/mocks/plugins/ChangingManifestPlugin.sol | 82 + test/mocks/plugins/ComprehensivePlugin.sol | 279 ++++ .../ExecFromPluginPermissionsMocks.sol | 470 ++++++ test/mocks/plugins/ManifestValidityMocks.sol | 294 ++++ test/mocks/plugins/ReturnDataPluginMocks.sol | 156 ++ test/mocks/plugins/UninstallErrorsPlugin.sol | 71 + test/mocks/plugins/ValidationPluginMocks.sol | 201 +++ test/mocks/tokens/MockERC1155.sol | 12 + test/mocks/tokens/MockERC20.sol | 12 + test/mocks/tokens/MockERC777.sol | 51 + test/plugin/TokenReceiverPlugin.t.sol | 206 +++ test/plugin/owner/MultiOwnerPlugin.t.sol | 260 ++++ .../owner/MultiOwnerPluginIntegration.t.sol | 174 +++ .../SessionKeyPluginWithMultiOwner.t.sol | 339 +++++ .../SessionKeyERC20SpendLimits.t.sol | 1070 +++++++++++++ .../permissions/SessionKeyGasLimits.t.sol | 697 +++++++++ .../SessionKeyNativeTokenSpendLimits.t.sol | 674 +++++++++ .../SessionKeyPermissionsPlugin.t.sol | 573 +++++++ test/upgrade/LightAccountToMSCA.t.sol | 93 ++ test/upgrade/MSCAToMSCA.t.sol | 80 + utils/inspect.sh | 46 + 117 files changed, 22735 insertions(+) create mode 100644 .env.example create mode 100644 .gasestimates.md create mode 100644 .github/pull_request_template.md create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 .solhint-src.json create mode 100644 .solhint-test.json create mode 100644 .storagelayout.md create mode 100644 .vscode/settings.json create mode 100644 README.md create mode 100644 ext/UUPSUpgradeable.sol create mode 100644 foundry.toml create mode 160000 lib/account-abstraction create mode 160000 lib/forge-std create mode 160000 lib/light-account create mode 160000 lib/openzeppelin-contracts create mode 100644 package.json create mode 100644 pnpm-lock.yaml create mode 100644 remappings.txt create mode 100644 script/Counter.s.sol create mode 100644 slither.config.json create mode 100644 src/account/AccountExecutor.sol create mode 100644 src/account/AccountLoupe.sol create mode 100644 src/account/AccountStorageInitializable.sol create mode 100644 src/account/PluginManagerInternals.sol create mode 100644 src/account/UpgradeableModularAccount.sol create mode 100644 src/factory/MultiOwnerMSCAFactory.sol create mode 100644 src/factory/MultiOwnerTokenReceiverMSCAFactory.sol create mode 100644 src/helpers/KnownSelectors.sol create mode 100644 src/helpers/ValidationDataHelpers.sol create mode 100644 src/interfaces/IAccountInitializable.sol create mode 100644 src/interfaces/IAccountLoupe.sol create mode 100644 src/interfaces/IAccountView.sol create mode 100644 src/interfaces/IPlugin.sol create mode 100644 src/interfaces/IPluginExecutor.sol create mode 100644 src/interfaces/IPluginManager.sol create mode 100644 src/interfaces/IStandardExecutor.sol create mode 100644 src/interfaces/erc4337/IAccount.sol create mode 100644 src/interfaces/erc4337/IAggregator.sol create mode 100644 src/interfaces/erc4337/IEntryPoint.sol create mode 100644 src/interfaces/erc4337/IPaymaster.sol create mode 100644 src/interfaces/erc4337/UserOperation.sol create mode 100644 src/libraries/AccountStorageV1.sol create mode 100644 src/libraries/AssociatedLinkedListSetLib.sol create mode 100644 src/libraries/CastLib.sol create mode 100644 src/libraries/CountableLinkedListSetLib.sol create mode 100644 src/libraries/FunctionReferenceLib.sol create mode 100644 src/libraries/LinkedListSetLib.sol create mode 100644 src/libraries/LinkedListSetUtils.sol create mode 100644 src/libraries/PluginStorageLib.sol create mode 100644 src/plugins/BasePlugin.sol create mode 100644 src/plugins/TokenReceiverPlugin.sol create mode 100644 src/plugins/owner/IMultiOwnerPlugin.sol create mode 100644 src/plugins/owner/MultiOwnerPlugin.sol create mode 100644 src/plugins/session/ISessionKeyPlugin.sol create mode 100644 src/plugins/session/SessionKeyPlugin.sol create mode 100644 src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol create mode 100644 src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol create mode 100644 src/plugins/session/permissions/SessionKeyPermissionsBase.sol create mode 100644 src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol create mode 100644 src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol create mode 100644 test/TestUtils.sol create mode 100644 test/Utils.sol create mode 100644 test/account/AccountExecHooks.t.sol create mode 100644 test/account/AccountLoupe.t.sol create mode 100644 test/account/AccountPermittedCallHooks.t.sol create mode 100644 test/account/AccountPreValidationHooks.t.sol create mode 100644 test/account/AccountReturnData.t.sol create mode 100644 test/account/ExecuteFromPluginPermissions.t.sol create mode 100644 test/account/ManifestValidity.t.sol create mode 100644 test/account/UpgradeableModularAccount.t.sol create mode 100644 test/account/UpgradeableModularAccountPluginManager.t.sol create mode 100644 test/account/ValidationIntersection.t.sol create mode 100644 test/comparison/CompareSimpleAccount.t.sol create mode 100644 test/factory/MultiOwnerMSCAFactoryTest.t.sol create mode 100644 test/factory/MultiOwnerTokenReceiverFactoryTest.t.sol create mode 100644 test/helpers/KnownSelectors.t.sol create mode 100644 test/invariant/AssociatedLinkedListSetLibInvariants.t.sol create mode 100644 test/invariant/LLSLRepro.t.sol create mode 100644 test/invariant/LinkedListSetLibInvariants.t.sol create mode 100644 test/invariant/handlers/AssociatedLinkedListSetHandler.sol create mode 100644 test/invariant/handlers/LinkedListSetHandler.sol create mode 100644 test/libraries/AccountStorage.t.sol create mode 100644 test/libraries/AssociatedLinkedListSetLib.t.sol create mode 100644 test/libraries/CountableLinkedListSetLib.t.sol create mode 100644 test/libraries/FunctionReferenceLib.t.sol create mode 100644 test/libraries/LinkedListSetLib.t.sol create mode 100644 test/libraries/PluginStorageLib.t.sol create mode 100644 test/mocks/ContractOwner.sol create mode 100644 test/mocks/Counter.sol create mode 100644 test/mocks/Counter.t.sol create mode 100644 test/mocks/MockDiamondStorageContract.sol create mode 100644 test/mocks/MockPlugin.sol create mode 100644 test/mocks/plugins/BadTransferOwnershipPlugin.sol create mode 100644 test/mocks/plugins/BaseTestPlugin.sol create mode 100644 test/mocks/plugins/ChangingManifestPlugin.sol create mode 100644 test/mocks/plugins/ComprehensivePlugin.sol create mode 100644 test/mocks/plugins/ExecFromPluginPermissionsMocks.sol create mode 100644 test/mocks/plugins/ManifestValidityMocks.sol create mode 100644 test/mocks/plugins/ReturnDataPluginMocks.sol create mode 100644 test/mocks/plugins/UninstallErrorsPlugin.sol create mode 100644 test/mocks/plugins/ValidationPluginMocks.sol create mode 100644 test/mocks/tokens/MockERC1155.sol create mode 100644 test/mocks/tokens/MockERC20.sol create mode 100644 test/mocks/tokens/MockERC777.sol create mode 100644 test/plugin/TokenReceiverPlugin.t.sol create mode 100644 test/plugin/owner/MultiOwnerPlugin.t.sol create mode 100644 test/plugin/owner/MultiOwnerPluginIntegration.t.sol create mode 100644 test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol create mode 100644 test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol create mode 100644 test/plugin/session/permissions/SessionKeyGasLimits.t.sol create mode 100644 test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol create mode 100644 test/plugin/session/permissions/SessionKeyPermissionsPlugin.t.sol create mode 100644 test/upgrade/LightAccountToMSCA.t.sol create mode 100644 test/upgrade/MSCAToMSCA.t.sol create mode 100644 utils/inspect.sh diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..ebc9b5b0 --- /dev/null +++ b/.env.example @@ -0,0 +1,6 @@ +DEPLOYER_PRIVATE_KEY= + +RPC_URL_MAINNET= +RPC_URL_GOERLI= + +ETHERSCAN_API_KEY= diff --git a/.gasestimates.md b/.gasestimates.md new file mode 100644 index 00000000..2563356a --- /dev/null +++ b/.gasestimates.md @@ -0,0 +1,385 @@ +# Gas Estimates +Generated via `bash utils/inspect.sh`. + +--- + +`forge test --gas-report --no-match-path "test/invariant/**/*"` +| lib/account-abstraction/contracts/core/EntryPoint.sol:EntryPoint contract | | | | | | +|---------------------------------------------------------------------------|-----------------|--------|--------|---------|---------| +| Deployment Cost | Deployment Size | | | | | +| 4177081 | 20761 | | | | | +| Function Name | min | avg | median | max | # calls | +| addStake | 47842 | 47842 | 47842 | 47842 | 1 | +| depositTo | 24537 | 24537 | 24537 | 24537 | 17 | +| getNonce | 708 | 1850 | 2708 | 2708 | 14 | +| getUserOpHash | 2133 | 2173 | 2178 | 2317 | 36 | +| handleOps | 25412 | 184086 | 105763 | 1071904 | 25 | +| innerHandleOp | 12528 | 36578 | 32063 | 101716 | 19 | +| receive | 2227 | 17754 | 22127 | 24127 | 33 | + + +| lib/account-abstraction/contracts/core/SenderCreator.sol:SenderCreator contract | | | | | | +|---------------------------------------------------------------------------------|-----------------|--------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 100953 | 532 | | | | | +| Function Name | min | avg | median | max | # calls | +| createSender | 140903 | 538016 | 538016 | 935129 | 4 | + + +| lib/account-abstraction/contracts/samples/SimpleAccount.sol:SimpleAccount contract | | | | | | +|------------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1479733 | 7684 | | | | | +| Function Name | min | avg | median | max | # calls | +| execute | 8933 | 9921 | 10415 | 10415 | 3 | +| initialize | 26217 | 26217 | 26217 | 26217 | 6 | +| validateUserOp | 18601 | 26151 | 26151 | 33701 | 4 | + + +| lib/account-abstraction/contracts/samples/SimpleAccountFactory.sol:SimpleAccountFactory contract | | | | | | +|--------------------------------------------------------------------------------------------------|-----------------|--------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1972541 | 10186 | | | | | +| Function Name | min | avg | median | max | # calls | +| createAccount | 137635 | 137635 | 137635 | 137635 | 6 | +| getAddress | 4675 | 4675 | 4675 | 4675 | 4 | + + +| lib/openzeppelin-contracts/contracts/proxy/ERC1967/ERC1967Proxy.sol:ERC1967Proxy contract | | | | | | +|-------------------------------------------------------------------------------------------|-----------------|--------|--------|---------|---------| +| Deployment Cost | Deployment Size | | | | | +| 70433 | 1164 | | | | | +| Function Name | min | avg | median | max | # calls | +| evilTransferOwnership | 20868 | 20868 | 20868 | 20868 | 1 | +| execute(address,uint256,bytes) | 9270 | 10254 | 10746 | 10746 | 3 | +| execute(address,uint256,bytes,bytes21) | 5614 | 12398 | 10267 | 36633 | 10 | +| executeBatch | 21511 | 21511 | 21511 | 21511 | 1 | +| executeFromPlugin | 15308 | 20129 | 23308 | 23308 | 7 | +| executeImmediateRecovery | 75340 | 81135 | 81135 | 86930 | 2 | +| finalizeRecovery | 41860 | 42896 | 42896 | 43932 | 2 | +| getInstalledPlugins | 1515 | 2207 | 2207 | 2899 | 2 | +| getSessionKeys | 8909 | 8909 | 8909 | 8909 | 1 | +| initialize | 25785 | 25785 | 25785 | 25785 | 1 | +| initiateRecovery | 45141 | 50936 | 50936 | 56731 | 2 | +| installPlugin | 5576 | 440630 | 527859 | 1048359 | 112 | +| owner | 3081 | 7753 | 3081 | 16081 | 107 | +| transferOwnership | 4465 | 15550 | 15550 | 26635 | 2 | +| uninstallPlugin | 92774 | 92774 | 92774 | 92774 | 1 | +| updateSessionKeyPerms | 46409 | 67866 | 70937 | 95425 | 5 | +| updateSessionKeys | 49550 | 53425 | 49550 | 65050 | 8 | +| upgradeToAndCall | 774511 | 790780 | 774511 | 2216373 | 90 | +| validateUserOp | 11570 | 41873 | 48696 | 65631 | 35 | + + +| lib/openzeppelin-contracts/contracts/token/ERC721/presets/ERC721PresetMinterPauserAutoId.sol:ERC721PresetMinterPauserAutoId contract | | | | | | +|--------------------------------------------------------------------------------------------------------------------------------------|-----------------|--------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 2441807 | 12707 | | | | | +| Function Name | min | avg | median | max | # calls | +| mint | 101582 | 101582 | 101582 | 101582 | 8 | +| ownerOf | 765 | 1765 | 1765 | 2765 | 2 | +| safeTransferFrom | 44988 | 47288 | 47288 | 49588 | 2 | + + +| src/account/UpgradeableModularAccount.sol:UpgradeableModularAccount contract | | | | | | +|------------------------------------------------------------------------------|-----------------|--------|--------|---------|---------| +| Deployment Cost | Deployment Size | | | | | +| 7211330 | 36341 | | | | | +| Function Name | min | avg | median | max | # calls | +| evilTransferOwnership | 20518 | 20518 | 20518 | 20518 | 1 | +| execute | 5265 | 12056 | 9924 | 36296 | 10 | +| executeBatch | 21120 | 21120 | 21120 | 21120 | 1 | +| executeFromPlugin | 14977 | 19794 | 22977 | 22977 | 7 | +| executeImmediateRecovery | 74919 | 78464 | 78464 | 82009 | 2 | +| finalizeRecovery | 41610 | 42646 | 42646 | 43682 | 2 | +| getInstalledPlugins | 1193 | 1879 | 1879 | 2565 | 2 | +| getSessionKeys | 8587 | 8587 | 8587 | 8587 | 1 | +| initialize | 301652 | 766787 | 771349 | 2208855 | 94 | +| initiateRecovery | 44756 | 48301 | 48301 | 51846 | 2 | +| installPlugin | 5208 | 439866 | 527498 | 1043462 | 112 | +| onERC1155BatchReceived | 2873 | 3138 | 3138 | 3403 | 2 | +| onERC1155Received | 2873 | 2947 | 2947 | 3022 | 2 | +| onERC721Received | 2774 | 2823 | 2823 | 2873 | 2 | +| owner | 2765 | 5965 | 2765 | 11265 | 107 | +| supportsInterface | 541 | 1541 | 1541 | 2541 | 6 | +| tokensReceived | 2682 | 2777 | 2777 | 2873 | 2 | +| transferOwnership | 4149 | 12967 | 12967 | 21785 | 2 | +| uninstallPlugin | 92509 | 92509 | 92509 | 92509 | 1 | +| updateSessionKeyPerms | 41560 | 62996 | 66070 | 90540 | 5 | +| updateSessionKeys | 49213 | 51963 | 49213 | 60213 | 8 | +| validateUserOp | 6620 | 40473 | 43775 | 60652 | 31 | + + +| src/factory/ModularAccountFactory.sol:ModularAccountFactory contract | | | | | | +|----------------------------------------------------------------------|-----------------|---------|--------|---------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1198650 | 7383 | | | | | +| Function Name | min | avg | median | max | # calls | +| createAccount | 2011 | 847948 | 887633 | 2334550 | 8 | +| getAddress | 1852 | 1852 | 1852 | 1852 | 6 | +| upgradeToAndCall | 776586 | 1016743 | 776586 | 2215029 | 6 | + + +| src/factory/verifiers/ECDSASigVerifier.sol:ECDSASigVerifier contract | | | | | | +|----------------------------------------------------------------------|-----------------|------|--------|------|---------| +| Deployment Cost | Deployment Size | | | | | +| 209057 | 1072 | | | | | +| Function Name | min | avg | median | max | # calls | +| verify | 5317 | 5773 | 5317 | 7142 | 4 | + + +| src/plugins/TokenReceiverPlugin.sol:TokenReceiverPlugin contract | | | | | | +|------------------------------------------------------------------|-----------------|------|--------|------|---------| +| Deployment Cost | Deployment Size | | | | | +| 821665 | 4132 | | | | | +| Function Name | min | avg | median | max | # calls | +| onERC1155BatchReceived | 1110 | 1110 | 1110 | 1110 | 1 | +| onERC1155Received | 789 | 789 | 789 | 789 | 1 | +| onERC721Received | 547 | 547 | 547 | 547 | 1 | +| onInstall | 493 | 493 | 493 | 493 | 4 | +| pluginManifest | 9023 | 9023 | 9023 | 9023 | 8 | +| supportsInterface | 252 | 269 | 278 | 278 | 12 | +| tokensReceived | 640 | 640 | 640 | 640 | 1 | + + +| src/plugins/owner/ExternalOwnerPlugin.sol:ExternalOwnerPlugin contract | | | | | | +|------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1232490 | 6184 | | | | | +| Function Name | min | avg | median | max | # calls | +| isValidSignature | 4382 | 5382 | 5382 | 6383 | 2 | +| onInstall | 24672 | 24672 | 24672 | 24672 | 90 | +| owner | 568 | 1282 | 568 | 2568 | 126 | +| pluginManifest | 18366 | 18366 | 18366 | 18366 | 319 | +| runtimeValidator | 856 | 1150 | 856 | 2899 | 129 | +| supportsInterface | 227 | 287 | 304 | 304 | 501 | +| transferOwnership | 2739 | 13622 | 14089 | 22639 | 18 | +| userOpValidator | 4625 | 5898 | 6625 | 6631 | 11 | + + +| src/plugins/recovery/SocialRecoveryPlugin.sol:SocialRecoveryPlugin contract | | | | | | +|-----------------------------------------------------------------------------|-----------------|-------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 3259804 | 17563 | | | | | +| Function Name | min | avg | median | max | # calls | +| MIN_GUARDIAN_UPDATE_DELAY | 449 | 449 | 449 | 449 | 44 | +| MIN_RECOVERY_DELAY | 273 | 273 | 273 | 273 | 44 | +| accountDataOf(address)((address,uint48,uint48,uint16,uint48,uint48,uint48)) | 1485 | 3110 | 3485 | 5485 | 16 | +| accountDataOf(address)((uint48,uint64)) | 1485 | 3485 | 3485 | 5485 | 13 | +| cancelGuardianUpdate | 804 | 3079 | 3056 | 5401 | 4 | +| cancelRecovery | 921 | 8417 | 2764 | 29716 | 10 | +| executeImmediateRecovery | 2664 | 19401 | 7503 | 75900 | 16 | +| finalizeGuardianUpdate | 1090 | 15263 | 2303 | 78748 | 6 | +| finalizeRecovery | 1368 | 19398 | 17154 | 38887 | 6 | +| guardiansOf | 3486 | 3939 | 3939 | 4392 | 2 | +| initiateGuardianUpdate | 1204 | 44075 | 54624 | 104240 | 15 | +| initiateRecovery | 796 | 21824 | 25914 | 45774 | 26 | +| onInstall | 819 | 81499 | 88576 | 137744 | 51 | +| pendingGuardiansOf | 1648 | 2605 | 1648 | 4486 | 7 | +| pluginManifest | 22433 | 22433 | 22433 | 22433 | 22 | +| preExecutionHook | 1159 | 1597 | 1159 | 2914 | 8 | +| recoveryCancellationTypedDataHash | 893 | 893 | 893 | 893 | 4 | +| recoveryTypedDataHash | 1356 | 1356 | 1356 | 1356 | 22 | +| supportsInterface | 287 | 303 | 312 | 312 | 33 | + + +| src/plugins/security/AccountTimelockPlugin.sol:AccountTimelockPlugin contract | | | | | | +|-------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1227484 | 6159 | | | | | +| Function Name | min | avg | median | max | # calls | +| accountDataOf | 902 | 902 | 902 | 902 | 3 | +| cancelAllQueuedTransactions | 2258 | 2394 | 2394 | 2531 | 2 | +| cancelQueuedTransaction | 2556 | 2960 | 2776 | 3549 | 3 | +| onInstall | 730 | 45821 | 47266 | 47266 | 63 | +| pluginManifest | 12213 | 12213 | 12213 | 12213 | 105 | +| preExecutionHook | 1659 | 2461 | 2259 | 3668 | 4 | +| queueTransaction | 2911 | 23146 | 26519 | 26519 | 7 | +| supportsInterface | 230 | 295 | 307 | 307 | 189 | +| timelockExpiration | 412 | 983 | 412 | 2412 | 7 | +| transactionId | 1378 | 1378 | 1378 | 1378 | 10 | + + +| src/plugins/session/SessionKeyPlugin.sol:SessionKeyPlugin contract | | | | | | +|--------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1441109 | 7226 | | | | | +| Function Name | min | avg | median | max | # calls | +| getSessionKeys | 2378 | 2378 | 2378 | 2378 | 1 | +| onInstall | 922 | 922 | 922 | 922 | 8 | +| pluginManifest | 10671 | 10671 | 10671 | 10671 | 16 | +| supportsInterface | 230 | 291 | 307 | 307 | 42 | +| updateSessionKeys | 45650 | 45650 | 45650 | 45650 | 8 | +| userOpValidator | 5007 | 6461 | 7007 | 7007 | 11 | + + +| src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol:SessionKeyPermissionsPlugin contract | | | | | | +|------------------------------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1489564 | 7468 | | | | | +| Function Name | min | avg | median | max | # calls | +| getAccessControlType | 3159 | 3159 | 3159 | 3159 | 1 | +| onInstall | 495 | 495 | 495 | 495 | 6 | +| pluginManifest | 10061 | 10061 | 10061 | 10061 | 12 | +| preUserOpValidationHook | 2951 | 4616 | 4511 | 7062 | 10 | +| supportsInterface | 281 | 298 | 307 | 307 | 18 | +| updateSessionKeyPerms | 26979 | 48382 | 51462 | 75904 | 5 | + + +| test/invariant/handlers/LinkedListSetHandler.sol:LinkedListSetHandler contract | | | | | | +|--------------------------------------------------------------------------------|-----------------|-------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1228284 | 6163 | | | | | +| Function Name | min | avg | median | max | # calls | +| add | 68269 | 92829 | 102109 | 108109 | 3 | +| addFlagKnown | 463 | 463 | 463 | 463 | 5 | +| addFlagRandom | 4879 | 4879 | 4879 | 4879 | 3 | +| clear | 961 | 4149 | 1961 | 11716 | 4 | +| removeIterate | 500 | 500 | 500 | 500 | 5 | +| removeKnownPrevKey | 627 | 1627 | 1627 | 2627 | 2 | +| removeRandKeyIterate | 5135 | 5135 | 5135 | 5135 | 3 | +| removeRandKnownPrevKey | 7470 | 7470 | 7470 | 7470 | 4 | + + +| test/mocks/Counter.sol:Counter contract | | | | | | +|-----------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 75929 | 407 | | | | | +| Function Name | min | avg | median | max | # calls | +| increment | 3212 | 19022 | 22312 | 22312 | 37 | +| number | 277 | 810 | 277 | 2277 | 15 | +| setNumber | 2347 | 7497 | 2697 | 22247 | 4 | +| supportsInterface | 138 | 138 | 138 | 138 | 7 | + + +| test/mocks/MockBadPlugin.sol:MockBadPlugin contract | | | | | | +|-----------------------------------------------------|-----------------|------|--------|------|---------| +| Deployment Cost | Deployment Size | | | | | +| 513357 | 2592 | | | | | +| Function Name | min | avg | median | max | # calls | +| pluginManifest | 3966 | 3966 | 3966 | 3966 | 1 | +| supportsInterface | 202 | 202 | 202 | 202 | 1 | + + +| test/mocks/MockDiamondStorageContract.sol:MockDiamondStorageContract contract | | | | | | +|-------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 194432 | 1049 | | | | | +| Function Name | min | avg | median | max | # calls | +| initialize | 611 | 10787 | 10787 | 20963 | 2 | + + +| test/mocks/MockERC1155.sol:MockERC1155 contract | | | | | | +|-------------------------------------------------|-----------------|-------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1100548 | 5706 | | | | | +| Function Name | min | avg | median | max | # calls | +| balanceOf | 519 | 1519 | 1519 | 2519 | 20 | +| mint | 27032 | 27032 | 27032 | 27032 | 40 | +| safeBatchTransferFrom | 89629 | 98584 | 98584 | 107540 | 2 | +| safeTransferFrom | 26383 | 29527 | 29527 | 32671 | 2 | + + +| test/mocks/MockERC777.sol:MockERC777 contract | | | | | | +|-----------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 532976 | 2690 | | | | | +| Function Name | min | avg | median | max | # calls | +| balanceOf | 562 | 1562 | 1562 | 2562 | 4 | +| mint | 22539 | 22539 | 22539 | 22539 | 8 | +| transfer | 6822 | 14790 | 14790 | 22759 | 2 | + + +| test/mocks/MockPlugin.sol:MockPlugin contract | | | | | | +|-----------------------------------------------|-----------------|------|--------|------|---------| +| Deployment Cost | Deployment Size | | | | | +| 531569 | 2683 | | | | | +| Function Name | min | avg | median | max | # calls | +| onInstall | 469 | 469 | 469 | 469 | 21 | +| pluginManifest | 3966 | 3966 | 3966 | 3966 | 42 | +| supportsInterface | 250 | 267 | 276 | 276 | 63 | + + +| test/mocks/ModularSimpleAccountFactory.sol:ModularSimpleAccountFactory contract | | | | | | +|---------------------------------------------------------------------------------|-----------------|--------|--------|--------|---------| +| Deployment Cost | Deployment Size | | | | | +| 8768980 | 47847 | | | | | +| Function Name | min | avg | median | max | # calls | +| UNSTAKE_DELAY | 301 | 301 | 301 | 301 | 1 | +| addStake | 57641 | 57641 | 57641 | 57641 | 1 | +| createAccount | 929361 | 929480 | 929361 | 934361 | 84 | +| getAddress | 45743 | 45743 | 45743 | 45743 | 57 | +| upgradeToAndCall | 774102 | 774191 | 774102 | 776602 | 84 | + + +| test/mocks/plugins/BadTransferOwnershipPlugin.sol:BadTransferOwnershipPlugin contract | | | | | | +|---------------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 669310 | 3371 | | | | | +| Function Name | min | avg | median | max | # calls | +| evilTransferOwnership | 18312 | 18312 | 18312 | 18312 | 1 | +| onInstall | 461 | 461 | 461 | 461 | 1 | +| pluginManifest | 5649 | 5649 | 5649 | 5649 | 2 | +| supportsInterface | 245 | 261 | 270 | 270 | 3 | + + +| test/mocks/plugins/ComprehensivePlugin.sol:ComprehensivePlugin contract | | | | | | +|-------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 1033482 | 5190 | | | | | +| Function Name | min | avg | median | max | # calls | +| onInstall | 446 | 446 | 446 | 446 | 1 | +| onUninstall | 490 | 490 | 490 | 490 | 1 | +| pluginManifest | 24717 | 24717 | 24717 | 24717 | 3 | +| supportsInterface | 227 | 244 | 253 | 253 | 3 | + + +| test/mocks/plugins/ValidationPluginMocks.sol:MockUserOpValidation1HookPlugin contract | | | | | | +|---------------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 584827 | 2949 | | | | | +| Function Name | min | avg | median | max | # calls | +| onInstall | 467 | 467 | 467 | 467 | 10 | +| pluginManifest | 6328 | 6328 | 6328 | 6328 | 20 | +| preUserOpValidationHook | 653 | 653 | 653 | 653 | 7 | +| setValidationData | 24671 | 33199 | 24671 | 44571 | 7 | +| supportsInterface | 250 | 267 | 276 | 276 | 30 | +| userOpValidator | 558 | 558 | 558 | 558 | 6 | + + +| test/mocks/plugins/ValidationPluginMocks.sol:MockUserOpValidation2HookPlugin contract | | | | | | +|---------------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 603646 | 3043 | | | | | +| Function Name | min | avg | median | max | # calls | +| onInstall | 470 | 470 | 470 | 470 | 10 | +| pluginManifest | 7728 | 7728 | 7728 | 7728 | 20 | +| preUserOpValidationHook | 656 | 665 | 665 | 675 | 4 | +| setValidationData | 26925 | 36875 | 36875 | 46825 | 2 | +| supportsInterface | 256 | 273 | 282 | 282 | 30 | +| userOpValidator | 561 | 561 | 561 | 561 | 2 | + + +| test/mocks/plugins/ValidationPluginMocks.sol:MockUserOpValidationPlugin contract | | | | | | +|----------------------------------------------------------------------------------|-----------------|-------|--------|-------|---------| +| Deployment Cost | Deployment Size | | | | | +| 544788 | 2749 | | | | | +| Function Name | min | avg | median | max | # calls | +| onInstall | 488 | 488 | 488 | 488 | 10 | +| pluginManifest | 4756 | 4756 | 4756 | 4756 | 20 | +| setValidationData | 22352 | 22352 | 22352 | 22352 | 1 | +| supportsInterface | 247 | 264 | 273 | 273 | 30 | +| userOpValidator | 555 | 555 | 555 | 555 | 1 | + + +| test/plugin/TokenReceiverPlugin.t.sol:TokenReceiverPluginTest contract | | | | | | +|------------------------------------------------------------------------|-----------------|-----|--------|-----|---------| +| Deployment Cost | Deployment Size | | | | | +| 15105146 | 75333 | | | | | +| Function Name | min | avg | median | max | # calls | +| onERC1155Received | 731 | 731 | 731 | 731 | 40 | + + + + +Ran 19 test suites: 174 tests passed, 0 failed, 0 skipped (174 total tests) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..0e144d2c --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,25 @@ + + +## Motivation + + + +## Solution + + \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..2d39989c --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,110 @@ +name: account-contracts Test CI + +on: [pull_request, workflow_dispatch] + +concurrency: + group: ${{github.workflow}}-${{github.ref}} + cancel-in-progress: true + +# Runs linter, tests, and inspection checker in parallel +jobs: + lint: + name: Run Linters + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + with: + version: nightly-5be158ba6dc7c798a6f032026fe60fc01686b33b + - run: forge install + + - run: forge fmt --check + + - name: "Check out the repo" + uses: "actions/checkout@v3" + with: + submodules: "recursive" + + - name: "Install Foundry" + uses: "foundry-rs/foundry-toolchain@v1" + + - name: "Install Pnpm" + uses: "pnpm/action-setup@v2" + with: + version: "8" + + - name: "Install Node.js" + uses: "actions/setup-node@v3" + with: + cache: "pnpm" + node-version: "lts/*" + + - name: "Install the Node.js dependencies" + run: "pnpm install" + + - name: "Lint the contracts" + run: "pnpm lint" + + # check-inspect: + # name: Verify Inspections + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v3 + # - name: Install Foundry + # uses: foundry-rs/foundry-toolchain@v1 + # with: + # version: nightly + + # - run: forge install + # - run: bash ./utils/inspect.sh + + # - run: git status --untracked-files=no --porcelain + # - run: git --no-pager diff + + # - name: Check Inspections + # run: if [[ -n "$(git status --untracked-files=no --porcelain)" ]]; then echo "Inspection difference detected, verify tests are passing and run \`bash ./utils/inspect.sh\` to fix." && exit 1; fi + + test: + name: Run Forge Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + with: + version: nightly-5be158ba6dc7c798a6f032026fe60fc01686b33b + + - name: Install forge dependencies + run: forge install + + - name: Build project + run: forge build --sizes + + - name: Run tests + run: FOUNDRY_PROFILE=deep forge test -vvv + + test-lite: + name: Run Forge Tests [lite build] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + with: + version: nightly-5be158ba6dc7c798a6f032026fe60fc01686b33b + + - name: Install forge dependencies + run: forge install + + - name: Build project + run: FOUNDRY_PROFILE=lite forge build + + - name: Run tests + run: FOUNDRY_PROFILE=lite forge test -vvv diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..53f6fb42 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +# Foundry build and cache directories +out/ +cache/ +node_modules/ + +# coverage +report/ +lcov.info \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..a74893ca --- /dev/null +++ b/.gitmodules @@ -0,0 +1,14 @@ +[submodule "lib/forge-std"] + path = lib/forge-std + url = https://github.com/foundry-rs/forge-std +[submodule "lib/account-abstraction"] + path = lib/account-abstraction + url = https://github.com/eth-infinitism/account-abstraction + branch = ver0.6.0 +[submodule "lib/openzeppelin-contracts"] + path = lib/openzeppelin-contracts + url = https://github.com/OpenZeppelin/openzeppelin-contracts + branch = release-v4.9 +[submodule "lib/light-account"] + path = lib/light-account + url = https://github.com/alchemyplatform/light-account diff --git a/.solhint-src.json b/.solhint-src.json new file mode 100644 index 00000000..5245a9b7 --- /dev/null +++ b/.solhint-src.json @@ -0,0 +1,17 @@ +{ + "extends": "solhint:recommended", + "rules": { + "immutable-vars-naming": ["error"], + "no-unused-import": ["error"], + "compiler-version": ["error", ">=0.8.21"], + "func-visibility": ["error", { "ignoreConstructors": true }], + "max-line-length": ["error", 120], + "func-param-name-mixedcase": ["error"], + "modifier-name-mixedcase": ["error"], + "private-vars-leading-underscore": ["error"], + "ordering": ["warn"], + "no-inline-assembly": "off", + "avoid-low-level-calls": "off", + "no-complex-fallback": "off" + } +} diff --git a/.solhint-test.json b/.solhint-test.json new file mode 100644 index 00000000..69a6a552 --- /dev/null +++ b/.solhint-test.json @@ -0,0 +1,16 @@ +{ + "extends": "solhint:recommended", + "rules": { + "func-name-mixedcase": "off", + "immutable-vars-naming": ["error"], + "no-unused-import": ["error"], + "compiler-version": ["error", ">=0.8.21"], + "func-visibility": ["error", { "ignoreConstructors": true }], + "max-line-length": ["error", 120], + "max-states-count": ["warn", 30], + "modifier-name-mixedcase": ["error"], + "private-vars-leading-underscore": ["error"], + "no-inline-assembly": "off", + "avoid-low-level-calls": "off" + } +} diff --git a/.storagelayout.md b/.storagelayout.md new file mode 100644 index 00000000..809adfdc --- /dev/null +++ b/.storagelayout.md @@ -0,0 +1,82 @@ +# Storage Layouts +Generated via `bash utils/inspect.sh`. + +--- + +`forge inspect --pretty src/account/AccountExecutor.sol:AccountExecutor storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/account/BaseModularAccount.sol:BaseModularAccount storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/account/BaseModularAccountLoupe.sol:BaseModularAccountLoupe storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/account/DiamondStorageInitializable.sol:DiamondStorageInitializable storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/account/UpgradeableModularAccount.sol:UpgradeableModularAccount storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/factory/ModularAccountFactory.sol:ModularAccountFactory storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|---------------|---------|------|--------|-------|-------------------------------------------------------------| +| _owner | address | 0 | 0 | 20 | src/factory/ModularAccountFactory.sol:ModularAccountFactory | +| _pendingOwner | address | 1 | 0 | 20 | src/factory/ModularAccountFactory.sol:ModularAccountFactory | + +`forge inspect --pretty src/factory/ProxyLoader.sol:ProxyLoader storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/factory/verifiers/ECDSASigVerifier.sol:ECDSASigVerifier storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/BasePlugin.sol:BasePlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/TokenReceiverPlugin.sol:TokenReceiverPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/owner/ExternalOwnerPlugin.sol:ExternalOwnerPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|---------|-----------------------------|------|--------|-------|---------------------------------------------------------------| +| _owners | mapping(address => address) | 0 | 0 | 32 | src/plugins/owner/ExternalOwnerPlugin.sol:ExternalOwnerPlugin | + +`forge inspect --pretty src/plugins/recovery/SocialRecoveryPlugin.sol:SocialRecoveryPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------------------|--------------------------------------------------------------|------|--------|-------|--------------------------------------------------------------------| +| _nameFallback | string | 0 | 0 | 32 | src/plugins/recovery/SocialRecoveryPlugin.sol:SocialRecoveryPlugin | +| _versionFallback | string | 1 | 0 | 32 | src/plugins/recovery/SocialRecoveryPlugin.sol:SocialRecoveryPlugin | +| _accountData | mapping(address => struct ISocialRecoveryPlugin.AccountData) | 2 | 0 | 32 | src/plugins/recovery/SocialRecoveryPlugin.sol:SocialRecoveryPlugin | + +`forge inspect --pretty src/plugins/security/AccountTimelockPlugin.sol:AccountTimelockPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|---------------------|---------------------------------------------------------------|------|--------|-------|----------------------------------------------------------------------| +| _seed | uint64 | 0 | 0 | 8 | src/plugins/security/AccountTimelockPlugin.sol:AccountTimelockPlugin | +| _accountData | mapping(address => struct IAccountTimelockPlugin.AccountData) | 1 | 0 | 32 | src/plugins/security/AccountTimelockPlugin.sol:AccountTimelockPlugin | +| _timelockExpiration | mapping(bytes32 => uint256) | 2 | 0 | 32 | src/plugins/security/AccountTimelockPlugin.sol:AccountTimelockPlugin | + +`forge inspect --pretty src/plugins/session/SessionKeyPlugin.sol:SessionKeyPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/session/permissions/SessionKeyPermissionsBase.sol:SessionKeyPermissionsBase storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol:SessionKeyPermissionsLoupe storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + +`forge inspect --pretty src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol:SessionKeyPermissionsPlugin storage-layout` +| Name | Type | Slot | Offset | Bytes | Contract | +|------|------|------|--------|-------|----------| + diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..284161a3 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "solidity.packageDefaultDependenciesContractsDirectory": "src", + "solidity.packageDefaultDependenciesDirectory": "lib", + "solidity.compileUsingRemoteVersion": "v0.8.21", + "editor.formatOnSave": true, + "[solidity]": { + "editor.defaultFormatter": "JuanBlanco.solidity" + }, + "solidity.formatter": "forge", + "search.exclude": { "lib": true } +} diff --git a/README.md b/README.md new file mode 100644 index 00000000..92154419 --- /dev/null +++ b/README.md @@ -0,0 +1,51 @@ +# Alchemy Modular Smart Contract Account (MSCA) + +Contracts for an upgradeable modular smart contract account that is compatible with ERC-4337 and ERC-6900, along with a set of plugins. + +## Development + +### Naming convention + +- `selector` is used for all function selectors. +- `validation` and `validationFunction` are used to replace `validator`. +- `associated` and `associatedFunction` are used to represents `validationFunction` and `hook` + +## Build + +```bash +forge build + +# or use the lite profile to reduce compilation time +FOUNDRY_PROFILE=lite forge build +``` + +## Syntax check + +```bash +pnpm lint:src && pnpm lint:test +``` + +## Test + +```bash +forge test -vvv + +# or use the lite profile to reduce compilation time +FOUNDRY_PROFILE=lite forge test -vvv +``` + +## Generate Inspections + +```bash +bash utils/inspect.sh +``` + +## Static Analysis + +```bash +slither . +``` + +## External Libraries + +We use Solady's highly optimized [UUPSUpgradeable](https://github.com/Vectorized/solady/blob/a061f38f27cd7ae330a86d42d3f15b4e7237f064/src/utils/UUPSUpgradeable.sol) in our contracts diff --git a/ext/UUPSUpgradeable.sol b/ext/UUPSUpgradeable.sol new file mode 100644 index 00000000..71036daf --- /dev/null +++ b/ext/UUPSUpgradeable.sol @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +/// @notice UUPS proxy mixin. +/// @author Solady (https://github.com/vectorized/solady/blob/main/src/utils/UUPSUpgradeable.sol) +/// @author Modified from OpenZeppelin +/// (https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/proxy/utils/UUPSUpgradeable.sol) +/// +/// Note: +/// - This implementation is intended to be used with ERC1967 proxies. +/// See: `LibClone.deployERC1967` and related functions. +/// - This implementation is NOT compatible with legacy OpenZeppelin proxies +/// which do not store the implementation at `_ERC1967_IMPLEMENTATION_SLOT`. +abstract contract UUPSUpgradeable { + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* CUSTOM ERRORS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev The upgrade failed. + error UpgradeFailed(); + + /// @dev The call is from an unauthorized call context. + error UnauthorizedCallContext(); + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* IMMUTABLES */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev For checking if the context is a delegate call. + uint256 private immutable __self = uint256(uint160(address(this))); + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* EVENTS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Emitted when the proxy's implementation is upgraded. + event Upgraded(address indexed implementation); + + /// @dev `keccak256(bytes("Upgraded(address)"))`. + uint256 private constant _UPGRADED_EVENT_SIGNATURE = + 0xbc7cd75a20ee27fd9adebab32041f755214dbc6bffa90cc0225b39da2e5c2d3b; + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* STORAGE */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev The ERC-1967 storage slot for the implementation in the proxy. + /// `uint256(keccak256("eip1967.proxy.implementation")) - 1`. + bytes32 internal constant _ERC1967_IMPLEMENTATION_SLOT = + 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + + /*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/ + /* UUPS OPERATIONS */ + /*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/ + + /// @dev Please override this function to check if `msg.sender` is authorized + /// to upgrade the proxy to `newImplementation`, reverting if not. + /// ``` + /// function _authorizeUpgrade(address) internal override onlyOwner {} + /// ``` + function _authorizeUpgrade(address newImplementation) internal virtual; + + /// @dev Returns the storage slot used by the implementation, + /// as specified in [ERC1822](https://eips.ethereum.org/EIPS/eip-1822). + /// + /// Note: The `notDelegated` modifier prevents accidental upgrades to + /// an implementation that is a proxy contract. + function proxiableUUID() public view virtual notDelegated returns (bytes32) { + // This function must always return `_ERC1967_IMPLEMENTATION_SLOT` to comply with ERC1967. + return _ERC1967_IMPLEMENTATION_SLOT; + } + + /// @dev Upgrades the proxy's implementation to `newImplementation`. + /// Emits a {Upgraded} event. + /// + /// Note: Passing in empty `data` skips the delegatecall to `newImplementation`. + function upgradeToAndCall(address newImplementation, bytes calldata data) + public + payable + virtual + onlyProxy + { + _authorizeUpgrade(newImplementation); + /// @solidity memory-safe-assembly + assembly { + newImplementation := shr(96, shl(96, newImplementation)) // Clears upper 96 bits. + mstore(0x01, 0x52d1902d) // `proxiableUUID()`. + let s := _ERC1967_IMPLEMENTATION_SLOT + // Check if `newImplementation` implements `proxiableUUID` correctly. + if iszero(eq(mload(staticcall(gas(), newImplementation, 0x1d, 0x04, 0x01, 0x20)), s)) { + mstore(0x01, 0x55299b49) // `UpgradeFailed()`. + revert(0x1d, 0x04) + } + // Emit the {Upgraded} event. + log2(codesize(), 0x00, _UPGRADED_EVENT_SIGNATURE, newImplementation) + sstore(s, newImplementation) // Updates the implementation. + + // Perform a delegatecall to `newImplementation` if `data` is non-empty. + if data.length { + // Forwards the `data` to `newImplementation` via delegatecall. + let m := mload(0x40) + calldatacopy(m, data.offset, data.length) + if iszero(delegatecall(gas(), newImplementation, m, data.length, codesize(), 0x00)) + { + // Bubble up the revert if the call reverts. + returndatacopy(m, 0x00, returndatasize()) + revert(m, returndatasize()) + } + } + } + } + + /// @dev Requires that the execution is performed through a proxy. + modifier onlyProxy() { + uint256 s = __self; + /// @solidity memory-safe-assembly + assembly { + // To enable use cases with an immutable default implementation in the bytecode, + // (see: ERC6551Proxy), we don't require that the proxy address must match the + // value stored in the implementation slot, which may not be initialized. + if eq(s, address()) { + mstore(0x00, 0x9f03a026) // `UnauthorizedCallContext()`. + revert(0x1c, 0x04) + } + } + _; + } + + /// @dev Requires that the execution is NOT performed via delegatecall. + /// This is the opposite of `onlyProxy`. + modifier notDelegated() { + uint256 s = __self; + /// @solidity memory-safe-assembly + assembly { + if iszero(eq(s, address())) { + mstore(0x00, 0x9f03a026) // `UnauthorizedCallContext()`. + revert(0x1c, 0x04) + } + } + _; + } +} \ No newline at end of file diff --git a/foundry.toml b/foundry.toml new file mode 100644 index 00000000..71ad0e28 --- /dev/null +++ b/foundry.toml @@ -0,0 +1,46 @@ +[profile.default] +solc = '0.8.21' +via_ir = true +src = 'src' +out = 'out' +test = 'test' +libs = ['lib'] +optimizer = true +optimizer_runs = 100 +ignored_error_codes = [] + +[fuzz] +runs = 500 + +[invariant] +runs=500 +fail_on_revert = true +depth = 10 + +[profile.lite] +solc = '0.8.21' +via_ir = false +optimizer = true +optimizer_runs = 10_000 +ignored_error_codes = [] + +[profile.deep.fuzz] +runs = 10000 + +[profile.deep.invariant] +runs = 5000 +depth = 32 + +[fmt] +line_length = 115 +wrap_comments = true + +[rpc_endpoints] +mainnet = "${RPC_URL_MAINNET}" +goerli = "${RPC_URL_GOERLI}" + +[etherscan] +mainnet = { key = "${ETHERSCAN_API_KEY}" } +goerli = { key = "${ETHERSCAN_API_KEY}" } + +# See more config options https://github.com/foundry-rs/foundry/tree/master/config \ No newline at end of file diff --git a/lib/account-abstraction b/lib/account-abstraction new file mode 160000 index 00000000..187613b0 --- /dev/null +++ b/lib/account-abstraction @@ -0,0 +1 @@ +Subproject commit 187613b0172c3a21cf3496e12cdfa24af04fb510 diff --git a/lib/forge-std b/lib/forge-std new file mode 160000 index 00000000..bdea49f9 --- /dev/null +++ b/lib/forge-std @@ -0,0 +1 @@ +Subproject commit bdea49f9bb3c58c8c35850c3bdc17eaeea756e9a diff --git a/lib/light-account b/lib/light-account new file mode 160000 index 00000000..8f6e5978 --- /dev/null +++ b/lib/light-account @@ -0,0 +1 @@ +Subproject commit 8f6e5978ee61495452f761400526ca2791b269dd diff --git a/lib/openzeppelin-contracts b/lib/openzeppelin-contracts new file mode 160000 index 00000000..9329cfac --- /dev/null +++ b/lib/openzeppelin-contracts @@ -0,0 +1 @@ +Subproject commit 9329cfacd4c7d20bcb43d772d947ff9e39b65df9 diff --git a/package.json b/package.json new file mode 100644 index 00000000..e6510e73 --- /dev/null +++ b/package.json @@ -0,0 +1,11 @@ +{ + "devDependencies": { + "pnpm": "^8.7.5", + "solhint": "^3.6.2" + }, + "scripts": { + "lint": "pnpm lint:src && pnpm lint:test", + "lint:src": "solhint -c .solhint-src.json ./src/**/*.sol", + "lint:test": "solhint -c .solhint-test.json ./test/**/*.sol" + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml new file mode 100644 index 00000000..393efc6e --- /dev/null +++ b/pnpm-lock.yaml @@ -0,0 +1,474 @@ +lockfileVersion: '6.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +devDependencies: + pnpm: + specifier: ^8.7.5 + version: 8.7.5 + solhint: + specifier: ^3.6.2 + version: 3.6.2 + +packages: + + /@babel/code-frame@7.22.13: + resolution: {integrity: sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==} + engines: {node: '>=6.9.0'} + dependencies: + '@babel/highlight': 7.22.13 + chalk: 2.4.2 + dev: true + + /@babel/helper-validator-identifier@7.22.15: + resolution: {integrity: sha512-4E/F9IIEi8WR94324mbDUMo074YTheJmd7eZF5vITTeYchqAi6sYXRLHUVsmkdmY4QjfKTcB2jB7dVP3NaBElQ==} + engines: {node: '>=6.9.0'} + dev: true + + /@babel/highlight@7.22.13: + resolution: {integrity: sha512-C/BaXcnnvBCmHTpz/VGZ8jgtE2aYlW4hxDhseJAWZb7gqGM/qtCK6iZUb0TyKFf7BOUsBH7Q7fkRsDRhg1XklQ==} + engines: {node: '>=6.9.0'} + dependencies: + '@babel/helper-validator-identifier': 7.22.15 + chalk: 2.4.2 + js-tokens: 4.0.0 + dev: true + + /@solidity-parser/parser@0.16.1: + resolution: {integrity: sha512-PdhRFNhbTtu3x8Axm0uYpqOy/lODYQK+MlYSgqIsq2L8SFYEHJPHNUiOTAJbDGzNjjr1/n9AcIayxafR/fWmYw==} + dependencies: + antlr4ts: 0.5.0-alpha.4 + dev: true + + /ajv@6.12.6: + resolution: {integrity: sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==} + dependencies: + fast-deep-equal: 3.1.3 + fast-json-stable-stringify: 2.1.0 + json-schema-traverse: 0.4.1 + uri-js: 4.4.1 + dev: true + + /ajv@8.12.0: + resolution: {integrity: sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==} + dependencies: + fast-deep-equal: 3.1.3 + json-schema-traverse: 1.0.0 + require-from-string: 2.0.2 + uri-js: 4.4.1 + dev: true + + /ansi-regex@5.0.1: + resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} + engines: {node: '>=8'} + dev: true + + /ansi-styles@3.2.1: + resolution: {integrity: sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==} + engines: {node: '>=4'} + dependencies: + color-convert: 1.9.3 + dev: true + + /ansi-styles@4.3.0: + resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} + engines: {node: '>=8'} + dependencies: + color-convert: 2.0.1 + dev: true + + /antlr4@4.13.1: + resolution: {integrity: sha512-kiXTspaRYvnIArgE97z5YVVf/cDVQABr3abFRR6mE7yesLMkgu4ujuyV/sgxafQ8wgve0DJQUJ38Z8tkgA2izA==} + engines: {node: '>=16'} + dev: true + + /antlr4ts@0.5.0-alpha.4: + resolution: {integrity: sha512-WPQDt1B74OfPv/IMS2ekXAKkTZIHl88uMetg6q3OTqgFxZ/dxDXI0EWLyZid/1Pe6hTftyg5N7gel5wNAGxXyQ==} + dev: true + + /argparse@2.0.1: + resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} + dev: true + + /ast-parents@0.0.1: + resolution: {integrity: sha512-XHusKxKz3zoYk1ic8Un640joHbFMhbqneyoZfoKnEGtf2ey9Uh/IdpcQplODdO/kENaMIWsD0nJm4+wX3UNLHA==} + dev: true + + /astral-regex@2.0.0: + resolution: {integrity: sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ==} + engines: {node: '>=8'} + dev: true + + /balanced-match@1.0.2: + resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} + dev: true + + /brace-expansion@2.0.1: + resolution: {integrity: sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==} + dependencies: + balanced-match: 1.0.2 + dev: true + + /callsites@3.1.0: + resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} + engines: {node: '>=6'} + dev: true + + /chalk@2.4.2: + resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==} + engines: {node: '>=4'} + dependencies: + ansi-styles: 3.2.1 + escape-string-regexp: 1.0.5 + supports-color: 5.5.0 + dev: true + + /chalk@4.1.2: + resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} + engines: {node: '>=10'} + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + dev: true + + /color-convert@1.9.3: + resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==} + dependencies: + color-name: 1.1.3 + dev: true + + /color-convert@2.0.1: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + dependencies: + color-name: 1.1.4 + dev: true + + /color-name@1.1.3: + resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} + dev: true + + /color-name@1.1.4: + resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + dev: true + + /commander@10.0.1: + resolution: {integrity: sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==} + engines: {node: '>=14'} + dev: true + + /cosmiconfig@8.3.6: + resolution: {integrity: sha512-kcZ6+W5QzcJ3P1Mt+83OUv/oHFqZHIx8DuxG6eZ5RGMERoLqp4BuGjhHLYGK+Kf5XVkQvqBSmAy/nGWN3qDgEA==} + engines: {node: '>=14'} + peerDependencies: + typescript: '>=4.9.5' + peerDependenciesMeta: + typescript: + optional: true + dependencies: + import-fresh: 3.3.0 + js-yaml: 4.1.0 + parse-json: 5.2.0 + path-type: 4.0.0 + dev: true + + /emoji-regex@8.0.0: + resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} + dev: true + + /error-ex@1.3.2: + resolution: {integrity: sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==} + dependencies: + is-arrayish: 0.2.1 + dev: true + + /escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + dev: true + + /fast-deep-equal@3.1.3: + resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} + dev: true + + /fast-diff@1.3.0: + resolution: {integrity: sha512-VxPP4NqbUjj6MaAOafWeUn2cXWLcCtljklUtZf0Ind4XQ+QPtmA0b18zZy0jIQx+ExRVCR/ZQpBmik5lXshNsw==} + dev: true + + /fast-json-stable-stringify@2.1.0: + resolution: {integrity: sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==} + dev: true + + /fs.realpath@1.0.0: + resolution: {integrity: sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==} + dev: true + + /glob@8.1.0: + resolution: {integrity: sha512-r8hpEjiQEYlF2QU0df3dS+nxxSIreXQS1qRhMJM0Q5NDdR386C7jb7Hwwod8Fgiuex+k0GFjgft18yvxm5XoCQ==} + engines: {node: '>=12'} + dependencies: + fs.realpath: 1.0.0 + inflight: 1.0.6 + inherits: 2.0.4 + minimatch: 5.1.6 + once: 1.4.0 + dev: true + + /has-flag@3.0.0: + resolution: {integrity: sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==} + engines: {node: '>=4'} + dev: true + + /has-flag@4.0.0: + resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} + engines: {node: '>=8'} + dev: true + + /ignore@5.2.4: + resolution: {integrity: sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==} + engines: {node: '>= 4'} + dev: true + + /import-fresh@3.3.0: + resolution: {integrity: sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==} + engines: {node: '>=6'} + dependencies: + parent-module: 1.0.1 + resolve-from: 4.0.0 + dev: true + + /inflight@1.0.6: + resolution: {integrity: sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==} + dependencies: + once: 1.4.0 + wrappy: 1.0.2 + dev: true + + /inherits@2.0.4: + resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} + dev: true + + /is-arrayish@0.2.1: + resolution: {integrity: sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==} + dev: true + + /is-fullwidth-code-point@3.0.0: + resolution: {integrity: sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==} + engines: {node: '>=8'} + dev: true + + /js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + dev: true + + /js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + dependencies: + argparse: 2.0.1 + dev: true + + /json-parse-even-better-errors@2.3.1: + resolution: {integrity: sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==} + dev: true + + /json-schema-traverse@0.4.1: + resolution: {integrity: sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==} + dev: true + + /json-schema-traverse@1.0.0: + resolution: {integrity: sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==} + dev: true + + /lines-and-columns@1.2.4: + resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} + dev: true + + /lodash.truncate@4.4.2: + resolution: {integrity: sha512-jttmRe7bRse52OsWIMDLaXxWqRAmtIUccAQ3garviCqJjafXOfNMO0yMfNpdD6zbGaTU0P5Nz7e7gAT6cKmJRw==} + dev: true + + /lodash@4.17.21: + resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} + dev: true + + /lru-cache@6.0.0: + resolution: {integrity: sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==} + engines: {node: '>=10'} + dependencies: + yallist: 4.0.0 + dev: true + + /minimatch@5.1.6: + resolution: {integrity: sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==} + engines: {node: '>=10'} + dependencies: + brace-expansion: 2.0.1 + dev: true + + /once@1.4.0: + resolution: {integrity: sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==} + dependencies: + wrappy: 1.0.2 + dev: true + + /parent-module@1.0.1: + resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} + engines: {node: '>=6'} + dependencies: + callsites: 3.1.0 + dev: true + + /parse-json@5.2.0: + resolution: {integrity: sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==} + engines: {node: '>=8'} + dependencies: + '@babel/code-frame': 7.22.13 + error-ex: 1.3.2 + json-parse-even-better-errors: 2.3.1 + lines-and-columns: 1.2.4 + dev: true + + /path-type@4.0.0: + resolution: {integrity: sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==} + engines: {node: '>=8'} + dev: true + + /pluralize@8.0.0: + resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} + engines: {node: '>=4'} + dev: true + + /pnpm@8.7.5: + resolution: {integrity: sha512-WI8WZb89Uiq5x2jdz4PcQMG9ovTnXcDCEpoEckPYIT2zD8/+dEhVozPlT7bu3WkBgE0uTARtgyIKAFt+IpW2cQ==} + engines: {node: '>=16.14'} + hasBin: true + dev: true + + /prettier@2.8.8: + resolution: {integrity: sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==} + engines: {node: '>=10.13.0'} + hasBin: true + requiresBuild: true + dev: true + optional: true + + /punycode@2.3.0: + resolution: {integrity: sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==} + engines: {node: '>=6'} + dev: true + + /require-from-string@2.0.2: + resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==} + engines: {node: '>=0.10.0'} + dev: true + + /resolve-from@4.0.0: + resolution: {integrity: sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==} + engines: {node: '>=4'} + dev: true + + /semver@7.5.4: + resolution: {integrity: sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==} + engines: {node: '>=10'} + hasBin: true + dependencies: + lru-cache: 6.0.0 + dev: true + + /slice-ansi@4.0.0: + resolution: {integrity: sha512-qMCMfhY040cVHT43K9BFygqYbUPFZKHOg7K73mtTWJRb8pyP3fzf4Ixd5SzdEJQ6MRUg/WBnOLxghZtKKurENQ==} + engines: {node: '>=10'} + dependencies: + ansi-styles: 4.3.0 + astral-regex: 2.0.0 + is-fullwidth-code-point: 3.0.0 + dev: true + + /solhint@3.6.2: + resolution: {integrity: sha512-85EeLbmkcPwD+3JR7aEMKsVC9YrRSxd4qkXuMzrlf7+z2Eqdfm1wHWq1ffTuo5aDhoZxp2I9yF3QkxZOxOL7aQ==} + hasBin: true + dependencies: + '@solidity-parser/parser': 0.16.1 + ajv: 6.12.6 + antlr4: 4.13.1 + ast-parents: 0.0.1 + chalk: 4.1.2 + commander: 10.0.1 + cosmiconfig: 8.3.6 + fast-diff: 1.3.0 + glob: 8.1.0 + ignore: 5.2.4 + js-yaml: 4.1.0 + lodash: 4.17.21 + pluralize: 8.0.0 + semver: 7.5.4 + strip-ansi: 6.0.1 + table: 6.8.1 + text-table: 0.2.0 + optionalDependencies: + prettier: 2.8.8 + transitivePeerDependencies: + - typescript + dev: true + + /string-width@4.2.3: + resolution: {integrity: sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==} + engines: {node: '>=8'} + dependencies: + emoji-regex: 8.0.0 + is-fullwidth-code-point: 3.0.0 + strip-ansi: 6.0.1 + dev: true + + /strip-ansi@6.0.1: + resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} + engines: {node: '>=8'} + dependencies: + ansi-regex: 5.0.1 + dev: true + + /supports-color@5.5.0: + resolution: {integrity: sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==} + engines: {node: '>=4'} + dependencies: + has-flag: 3.0.0 + dev: true + + /supports-color@7.2.0: + resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} + engines: {node: '>=8'} + dependencies: + has-flag: 4.0.0 + dev: true + + /table@6.8.1: + resolution: {integrity: sha512-Y4X9zqrCftUhMeH2EptSSERdVKt/nEdijTOacGD/97EKjhQ/Qs8RTlEGABSJNNN8lac9kheH+af7yAkEWlgneA==} + engines: {node: '>=10.0.0'} + dependencies: + ajv: 8.12.0 + lodash.truncate: 4.4.2 + slice-ansi: 4.0.0 + string-width: 4.2.3 + strip-ansi: 6.0.1 + dev: true + + /text-table@0.2.0: + resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==} + dev: true + + /uri-js@4.4.1: + resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} + dependencies: + punycode: 2.3.0 + dev: true + + /wrappy@1.0.2: + resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} + dev: true + + /yallist@4.0.0: + resolution: {integrity: sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==} + dev: true diff --git a/remappings.txt b/remappings.txt new file mode 100644 index 00000000..dd9c13a7 --- /dev/null +++ b/remappings.txt @@ -0,0 +1,5 @@ +ds-test/=lib/forge-std/lib/ds-test/src/ +forge-std/=lib/forge-std/src/ +@eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/ +@openzeppelin/=lib/openzeppelin-contracts/ +@alchemy/light-account/=lib/light-account/ diff --git a/script/Counter.s.sol b/script/Counter.s.sol new file mode 100644 index 00000000..fb4224a5 --- /dev/null +++ b/script/Counter.s.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.14; + +import "forge-std/Script.sol"; + +contract CounterScript is Script { + function setUp() public {} + + function run() public { + vm.broadcast(); + } +} diff --git a/slither.config.json b/slither.config.json new file mode 100644 index 00000000..fd69b55d --- /dev/null +++ b/slither.config.json @@ -0,0 +1,3 @@ +{ + "filter_paths": "lib" +} \ No newline at end of file diff --git a/src/account/AccountExecutor.sol b/src/account/AccountExecutor.sol new file mode 100644 index 00000000..7f4e8ebd --- /dev/null +++ b/src/account/AccountExecutor.sol @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; + +import {UserOperation} from "../interfaces/erc4337/UserOperation.sol"; +import {IPlugin} from "../interfaces/IPlugin.sol"; + +/// @title Account Executor +/// @author Alchemy +/// @notice Provides internal functions for executing calls on a modular account. +abstract contract AccountExecutor { + error PluginCallDenied(address plugin); + + /// @dev If the target is a plugin (as determined by its support for the IPlugin interface), revert. + /// This prevents the modular account from calling plugins (both installed and uninstalled) outside + /// of the normal flow (via execution functions installed on the account), which could lead to data + /// inconsistencies and unexpected behavior. + /// @param target The address of the contract to call. + /// @param value The value to send with the call. + /// @param data The call data. + /// @return result The return data of the call, or the error message from the call if call reverts. + function _exec(address target, uint256 value, bytes memory data) internal returns (bytes memory result) { + if (ERC165Checker.supportsInterface(target, type(IPlugin).interfaceId)) { + revert PluginCallDenied(target); + } + + bool success; + (success, result) = target.call{value: value}(data); + + if (!success) { + // Directly bubble up revert messages + assembly ("memory-safe") { + revert(add(result, 32), mload(result)) + } + } + } + + /// @dev Performs an `_executeRaw` for a call buffer holding a call to one of: + /// - Pre Runtime Validation Hook + /// - Runtime Validation + /// - Pre Execution Hook + /// - Pre Permitted Call Hook + /// And if it fails, reverts with the appropriate custom error. + function _executeRuntimePluginFunction(bytes memory buffer, address plugin, bytes4 errorSelector) internal { + if (!_executeRaw(plugin, buffer)) { + _revertOnRuntimePluginFunctionFail(buffer, plugin, errorSelector); + } + } + + function _executeRaw(address plugin, bytes memory buffer) internal returns (bool success) { + assembly ("memory-safe") { + success := + call( + gas(), + plugin, + /*value*/ + 0, + /*argOffset*/ + add(buffer, 0x20), // jump over 32 bytes for length + /*argSize*/ + mload(buffer), + /*retOffset*/ + 0, + /*retSize*/ + 0 + ) + } + } + + function _executeUserOpPluginFunction(bytes memory buffer, address plugin) + internal + returns (uint256 validationData) + { + assembly ("memory-safe") { + switch and( + gt(returndatasize(), 0x1f), + call( + /*forward all gas, but can't use gas opcode due to validation opcode restrictions*/ + not(0), + plugin, + /*value*/ + 0, + /*argOffset*/ + add(buffer, 0x20), // jump over 32 bytes for length + /*argSize*/ + mload(buffer), + /*retOffset*/ + 0, + /*retSize*/ + 0x20 + ) + ) + case 0 { + // If the call failed or did not return enough data, we return 1 (SIG_FAIL) as the validation data + validationData := 1 + } + default { + // Otherwise, we return the first word of the return data as the validation data + validationData := mload(0) + } + } + } + + function _allocateRuntimeCallBuffer(bytes calldata data) internal view returns (bytes memory buffer) { + buffer = abi.encodeWithSelector(bytes4(0), 0, msg.sender, msg.value, data); + } + + function _allocateUserOpCallBuffer(bytes4 selector, UserOperation calldata userOp, bytes32 userOpHash) + internal + pure + returns (bytes memory buffer) + { + buffer = abi.encodeWithSelector(selector, 0, userOp, userOpHash); + } + + /// @dev Updates which plugin function the buffer will call. + function _updatePluginCallBufferSelector(bytes memory buffer, bytes4 pluginSelector) internal pure { + assembly ("memory-safe") { + // We only want to write to the first 4 bytes, so we first load the first word, + // mask out the fist 4 bytes, then OR in the new selector. + let existingWord := mload(add(buffer, 0x20)) + // Clear the upper 4 bytes of the existing word + existingWord := and(existingWord, shr(32, not(0))) + // Clear the lower 28 bytes of the selector + pluginSelector := and(pluginSelector, shl(224, 0xFFFFFFFF)) + // OR in the new selector + existingWord := or(existingWord, pluginSelector) + mstore(add(buffer, 0x20), existingWord) + } + } + + function _updatePluginCallBufferFunctionId(bytes memory buffer, uint8 functionId) internal pure { + assembly ("memory-safe") { + // The function ID is a uint8 type, which is left-padded. + // We do want to mask it, however, because this is an internal function and the upper bits may not be + // cleared. + mstore(add(buffer, 0x24), and(functionId, 0xff)) + } + } + + /// @dev Re-interpret the existing call buffer as just a bytes memory hold msg.data. + /// Since it's already there, and we don't plan on using the buffer again, we can write over the other fields + /// to store calldata length before the data, then return a new memory pointer holding the length. + function _convertRuntimeCallBufferToExecBuffer(bytes memory runtimeCallBuffer) + internal + pure + returns (bytes memory execCallBuffer) + { + if (runtimeCallBuffer.length == 0) { + // There was no existing call buffer. This case is never reached in actual code, but in the event that + // it would be, we would need to re-collect all the calldata. + execCallBuffer = msg.data; + } else { + assembly ("memory-safe") { + // Skip forward to point to the new "length-holding" field. + // Since the existing buffer is already ABI-encoded, we can just skip to the inner callData field. + // This field is location bytes ahead. It skips over: + // - (32 bytes) The original buffer's length field + // - (4 bytes) Selector + // - (32 bytes) Function id + // - (32 bytes) Sender + // - (32 bytes) Value + // - (32 bytes) data offset + // Totoal: 164 bytes + execCallBuffer := add(runtimeCallBuffer, 164) + } + } + } + + /// @dev Used by pre exec hooks to store data for post exec hooks. + function _collectReturnData() internal pure returns (bytes memory returnData) { + assembly ("memory-safe") { + // Allocate a buffer of that size, advancing the memory pointer to the nearest word + returnData := mload(0x40) + mstore(returnData, returndatasize()) + mstore(0x40, and(add(add(returnData, returndatasize()), 0x3f), not(0x1f))) + + // Copy over the return data + returndatacopy(add(returnData, 0x20), 0, returndatasize()) + } + } + + /// @dev This function reverts with one of the following custom error types: + /// - PreRuntimeValidationHookFailed + /// - RuntimeValidationFunctionReverted + /// - PreExecHookReverted + /// Since they all take the same parameters, we can just switch the selector as needed. + /// The last parameter, revertReason, is copied from return data. + function _revertOnRuntimePluginFunctionFail(bytes memory buffer, address plugin, bytes4 errorSelector) + internal + pure + { + assembly ("memory-safe") { + // Call failed, revert with the established error format and the provided selector + // The error format is: + // - Custom error selector + // - plugin address + // - function id + // - byte offset and length of revert reason + // - byte memory revertReason + // Total size: 132 bytes (4 byte selector + 4 * 32 byte words) + length of revert reason + let errorStart := mload(0x40) + // We add the extra size for the abi encoded fields at the same time as the selector, + // which is after the word-alignment step. + // Pad errorSize to nearest word + let errorSize := and(add(returndatasize(), 0x1f), not(0x1f)) + // Add the abi-encoded fields length (128 bytes) and the selector's size (4 bytes) + // to the error size. + errorSize := add(errorSize, 132) + // errorSize := add(errorSize, 132) + // Store the selector in the start of the error buffer. + // Any set lower bits will be cleared with the subsequest mstore. + mstore(errorStart, errorSelector) + mstore(add(errorStart, 0x04), plugin) + // Store the function id in the next word, as retrieved from the buffer + mstore(add(errorStart, 0x24), mload(add(buffer, 0x24))) + // Store the offset and length of the revert reason in the next two words + mstore(add(errorStart, 0x44), 0x60) + mstore(add(errorStart, 0x64), returndatasize()) + + // Copy over the revert reason + returndatacopy(add(errorStart, 0x84), 0, returndatasize()) + + // Revert + revert(errorStart, errorSize) + } + } +} diff --git a/src/account/AccountLoupe.sol b/src/account/AccountLoupe.sol new file mode 100644 index 00000000..a1992ecc --- /dev/null +++ b/src/account/AccountLoupe.sol @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {KnownSelectors} from "../helpers/KnownSelectors.sol"; + +import {IAccountLoupe} from "../interfaces/IAccountLoupe.sol"; + +import {AccountStorageV1} from "../libraries/AccountStorageV1.sol"; +import {CastLib} from "../libraries/CastLib.sol"; +import {CountableLinkedListSetLib} from "../libraries/CountableLinkedListSetLib.sol"; +import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; +import {LinkedListSet, LinkedListSetLib} from "../libraries/LinkedListSetLib.sol"; + +/// @title Account Loupe +/// @author Alchemy +/// @notice Provides view functions for querying the configuration of a modular account. +abstract contract AccountLoupe is IAccountLoupe, AccountStorageV1 { + using LinkedListSetLib for LinkedListSet; + using CountableLinkedListSetLib for LinkedListSet; + + /// @inheritdoc IAccountLoupe + function getExecutionFunctionConfig(bytes4 selector) + external + view + returns (ExecutionFunctionConfig memory config) + { + AccountStorage storage storage_ = _getAccountStorage(); + + if (KnownSelectors.isNativeFunction(selector)) { + config.plugin = address(this); + } else { + config.plugin = storage_.selectorData[selector].plugin; + } + + config.userOpValidationFunction = storage_.selectorData[selector].userOpValidation; + config.runtimeValidationFunction = storage_.selectorData[selector].runtimeValidation; + } + + /// @inheritdoc IAccountLoupe + function getExecutionHooks(bytes4 selector) external view returns (ExecutionHooks[] memory execHooks) { + execHooks = _getHooks(_getAccountStorage().selectorData[selector].executionHooks); + } + + /// @inheritdoc IAccountLoupe + function getPermittedCallHooks(address callingPlugin, bytes4 selector) + external + view + returns (ExecutionHooks[] memory execHooks) + { + PermittedCallData storage permittedCallData = + _getAccountStorage().permittedCalls[_getPermittedCallKey(callingPlugin, selector)]; + + execHooks = _getHooks(permittedCallData.permittedCallHooks); + } + + /// @inheritdoc IAccountLoupe + function getPreValidationHooks(bytes4 selector) + external + view + returns ( + FunctionReference[] memory preUserOpValidationHooks, + FunctionReference[] memory preRuntimeValidationHooks + ) + { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + preUserOpValidationHooks = CastLib.toFunctionReferenceArray(selectorData.preUserOpValidationHooks.getAll()); + preRuntimeValidationHooks = + CastLib.toFunctionReferenceArray(selectorData.preRuntimeValidationHooks.getAll()); + } + + /// @inheritdoc IAccountLoupe + function getInstalledPlugins() external view returns (address[] memory pluginAddresses) { + pluginAddresses = CastLib.toAddressArray(_getAccountStorage().plugins.getAll()); + } + + /// @dev Collects hook data from stored hooks (either execution hooks or permitted call hooks) and prepares it + /// for returning as the `ExecutionHooks` type defined by `IAccountLoupe`. + function _getHooks(HookGroup storage storedHooks) internal view returns (ExecutionHooks[] memory execHooks) { + FunctionReference[] memory preExecHooks = CastLib.toFunctionReferenceArray(storedHooks.preHooks.getAll()); + FunctionReference[] memory postOnlyExecHooks = + CastLib.toFunctionReferenceArray(storedHooks.postOnlyHooks.getAll()); + + uint256 preExecHooksLength = preExecHooks.length; + uint256 postOnlyExecHooksLength = postOnlyExecHooks.length; + uint256 maxExecHooksLength = postOnlyExecHooksLength; + + // There can only be as many associated post hooks to run as there are pre hooks. + for (uint256 i = 0; i < preExecHooksLength;) { + unchecked { + maxExecHooksLength += storedHooks.preHooks.getCount(CastLib.toSetValue(preExecHooks[i])); + ++i; + } + } + + // Overallocate on length - not all of this may get filled up. We set the correct length later. + execHooks = new ExecutionHooks[](maxExecHooksLength); + uint256 actualExecHooksLength = 0; + + for (uint256 i = 0; i < preExecHooksLength;) { + FunctionReference[] memory associatedPostExecHooks = + CastLib.toFunctionReferenceArray(storedHooks.associatedPostHooks[preExecHooks[i]].getAll()); + uint256 associatedPostExecHooksLength = associatedPostExecHooks.length; + + if (associatedPostExecHooksLength > 0) { + for (uint256 j = 0; j < associatedPostExecHooksLength;) { + execHooks[actualExecHooksLength].preExecHook = preExecHooks[i]; + execHooks[actualExecHooksLength].postExecHook = associatedPostExecHooks[j]; + + unchecked { + ++actualExecHooksLength; + ++j; + } + } + } else { + execHooks[actualExecHooksLength].preExecHook = preExecHooks[i]; + + unchecked { + ++actualExecHooksLength; + } + } + + unchecked { + ++i; + } + } + + for (uint256 i = 0; i < postOnlyExecHooksLength;) { + execHooks[actualExecHooksLength].postExecHook = postOnlyExecHooks[i]; + + unchecked { + ++actualExecHooksLength; + ++i; + } + } + + // "Trim" the exec hooks array to the actual length, since we may have overallocated. + assembly ("memory-safe") { + mstore(execHooks, actualExecHooksLength) + } + } +} diff --git a/src/account/AccountStorageInitializable.sol b/src/account/AccountStorageInitializable.sol new file mode 100644 index 00000000..69de828c --- /dev/null +++ b/src/account/AccountStorageInitializable.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.21; + +import {Address} from "@openzeppelin/contracts/utils/Address.sol"; + +import {AccountStorageV1} from "../libraries/AccountStorageV1.sol"; + +/// @title Account Storage Initializable +/// @author Alchemy +/// @notice This enables functions that can be called only once per implementation with the same storage layout +/// @dev Adapted from OpenZeppelin's Initialiazble and modified to use a diamond storage pattern. Removed +/// Initialized() event since the account already emits an event on initialization. +abstract contract AccountStorageInitializable is AccountStorageV1 { + error AlreadyInitialized(); + error AlreadyInitializing(); + + /// @notice Modifier to put on function intended to be called only once per implementation + /// @dev Reverts if the contract has already been initialized + modifier initializer() { + AccountStorage storage storage_ = _getAccountStorage(); + bool isTopLevelCall = !storage_.initializing; + if ( + isTopLevelCall && storage_.initialized < 1 + || !Address.isContract(address(this)) && storage_.initialized == 1 + ) { + storage_.initialized = 1; + if (isTopLevelCall) { + storage_.initializing = true; + } + _; + if (isTopLevelCall) { + storage_.initializing = false; + } + } else { + revert AlreadyInitialized(); + } + } + + /// @notice Internal function to disable calls to initialization functions + /// @dev Reverts if the contract has already been initialized + function _disableInitializers() internal virtual { + AccountStorage storage storage_ = _getAccountStorage(); + if (storage_.initializing) { + revert AlreadyInitializing(); + } + if (storage_.initialized != type(uint8).max) { + storage_.initialized = type(uint8).max; + } + } +} diff --git a/src/account/PluginManagerInternals.sol b/src/account/PluginManagerInternals.sol new file mode 100644 index 00000000..aa125703 --- /dev/null +++ b/src/account/PluginManagerInternals.sol @@ -0,0 +1,962 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol"; + +import {KnownSelectors} from "../helpers/KnownSelectors.sol"; + +import { + IPlugin, + ManifestAssociatedFunction, + ManifestAssociatedFunctionType, + ManifestExecutionHook, + ManifestExternalCallPermission, + ManifestFunction, + PluginManifest +} from "../interfaces/IPlugin.sol"; +import {IPluginManager} from "../interfaces/IPluginManager.sol"; + +import {AccountStorageV1} from "../libraries/AccountStorageV1.sol"; +import {CastLib} from "../libraries/CastLib.sol"; +import {CountableLinkedListSetLib} from "../libraries/CountableLinkedListSetLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../libraries/FunctionReferenceLib.sol"; +import {LinkedListSet, LinkedListSetLib} from "../libraries/LinkedListSetLib.sol"; + +/// @title Plugin Manager Internals +/// @author Alchemy +/// @notice Contains functions to manage the state and behavior of plugin installs and uninstalls. +abstract contract PluginManagerInternals is IPluginManager, AccountStorageV1 { + using LinkedListSetLib for LinkedListSet; + using CountableLinkedListSetLib for LinkedListSet; + + // Grouping of arguments to `uninstallPlugin` to avoid "stack too deep" + // errors when building without via-ir. + struct UninstallPluginArgs { + address plugin; + PluginManifest manifest; + bool forceUninstall; + uint256 callbackGasLimit; + } + + // These flags are used in LinkedListSet values to optimize lookups. + // It's important that they don't overlap with bit 1 and bit 2, which are reserved bits used to indicate + // the sentinel value and the existence of a next value, respectively. + uint16 internal constant _PRE_EXEC_HOOK_HAS_POST_FLAG = 0x0004; // bit 3 + + error ArrayLengthMismatch(); + error DuplicateHookLimitExceeded(bytes4 selector, FunctionReference hook); + error DuplicatePreRuntimeValidationHookLimitExceeded(bytes4 selector, FunctionReference hook); + error DuplicatePreUserOpValidationHookLimitExceeded(bytes4 selector, FunctionReference hook); + error Erc4337FunctionNotAllowed(bytes4 selector); + error ExecutionFunctionAlreadySet(bytes4 selector); + error ExecutionFunctionNotSet(bytes4 selector); + error InvalidDependenciesProvided(); + error InvalidPluginManifest(); + error MissingPluginDependency(address dependency); + error NativeFunctionNotAllowed(bytes4 selector); + error NullFunctionReference(); + error PluginAlreadyInstalled(address plugin); + error PluginApplyHookCallbackFailed(address providingPlugin, bytes revertReason); + error PluginDependencyViolation(address plugin); + error PluginHookUnapplyCallbackFailed(address providingPlugin, bytes revertReason); + error PluginInstallCallbackFailed(address plugin, bytes revertReason); + error PluginInterfaceNotSupported(address plugin); + error PluginNotInstalled(address plugin); + error PluginUninstallCallbackFailed(address plugin, bytes revertReason); + error RuntimeValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction); + error UserOpValidationFunctionAlreadySet(bytes4 selector, FunctionReference validationFunction); + + // Storage update operations + + function _setExecutionFunction(bytes4 selector, address plugin) internal { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + if (selectorData.plugin != address(0)) { + revert ExecutionFunctionAlreadySet(selector); + } + + // make sure incoming execution function does not collide with any native functions (data are stored on the + // account implementation contract) + if (KnownSelectors.isNativeFunction(selector)) { + revert NativeFunctionNotAllowed(selector); + } + // Also make sure it doesn't collide with functions defined by ERC-4337 + // and called by the entry point. This prevents a malicious plugin from + // sneaking in a function with the same selector as e.g. + // `validatePaymasterUserOp` and turning the account into their own + // personal paymaster. + if (KnownSelectors.isErc4337Function(selector)) { + revert Erc4337FunctionNotAllowed(selector); + } + + selectorData.plugin = plugin; + } + + function _addUserOpValidationFunction(bytes4 selector, FunctionReference validationFunction) internal { + _assertNotNullFunction(validationFunction); + + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + if (selectorData.userOpValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + revert UserOpValidationFunctionAlreadySet(selector, validationFunction); + } + + selectorData.userOpValidation = validationFunction; + } + + function _addRuntimeValidationFunction(bytes4 selector, FunctionReference validationFunction) internal { + _assertNotNullFunction(validationFunction); + + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + if (selectorData.runtimeValidation != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + revert RuntimeValidationFunctionAlreadySet(selector, validationFunction); + } + + selectorData.runtimeValidation = validationFunction; + } + + function _addExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) + internal + { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + _addHooks(selectorData.executionHooks, selector, preExecHook, postExecHook); + + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + selectorData.hasPreExecHooks = true; + } + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + selectorData.hasPostOnlyExecHooks = true; + } + } + + function _removeExecHooks(bytes4 selector, FunctionReference preExecHook, FunctionReference postExecHook) + internal + { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + (bool shouldClearHasPreHooks, bool shouldClearHasPostHooks) = + _removeHooks(selectorData.executionHooks, preExecHook, postExecHook); + + if (shouldClearHasPreHooks) { + selectorData.hasPreExecHooks = false; + } + + if (shouldClearHasPostHooks) { + selectorData.hasPostOnlyExecHooks = false; + } + } + + function _enableExecFromPlugin(bytes4 selector, address plugin, AccountStorage storage storage_) internal { + PermittedCallData storage permittedCallData = + storage_.permittedCalls[_getPermittedCallKey(plugin, selector)]; + + // If there are duplicates, this will just enable the flag again. This is not a problem, since the boolean + // will be set to false twice during uninstall, which is fine. + permittedCallData.callPermitted = true; + } + + function _addPermittedCallHooks( + bytes4 selector, + address plugin, + FunctionReference preExecHook, + FunctionReference postExecHook + ) internal { + PermittedCallData storage permittedCallData = + _getAccountStorage().permittedCalls[_getPermittedCallKey(plugin, selector)]; + + _addHooks(permittedCallData.permittedCallHooks, selector, preExecHook, postExecHook); + + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + permittedCallData.hasPrePermittedCallHooks = true; + } + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + permittedCallData.hasPostOnlyPermittedCallHooks = true; + } + } + + function _removePermittedCallHooks( + bytes4 selector, + address plugin, + FunctionReference preExecHook, + FunctionReference postExecHook + ) internal { + PermittedCallData storage permittedCallData = + _getAccountStorage().permittedCalls[_getPermittedCallKey(plugin, selector)]; + + (bool shouldClearHasPreHooks, bool shouldClearHasPostHooks) = + _removeHooks(permittedCallData.permittedCallHooks, preExecHook, postExecHook); + + if (shouldClearHasPreHooks) { + permittedCallData.hasPrePermittedCallHooks = false; + } + + if (shouldClearHasPostHooks) { + permittedCallData.hasPostOnlyPermittedCallHooks = false; + } + } + + function _addHooks( + HookGroup storage hooks, + bytes4 selector, + FunctionReference preExecHook, + FunctionReference postExecHook + ) internal { + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // add pre or pre/post pair of exec hooks + if (!hooks.preHooks.tryIncrement(CastLib.toSetValue(preExecHook))) { + revert DuplicateHookLimitExceeded(selector, preExecHook); + } + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // can ignore return val of tryEnableFlags here as tryIncrement above must have succeeded + hooks.preHooks.tryEnableFlags(CastLib.toSetValue(preExecHook), _PRE_EXEC_HOOK_HAS_POST_FLAG); + if (!hooks.associatedPostHooks[preExecHook].tryIncrement(CastLib.toSetValue(postExecHook))) { + revert DuplicateHookLimitExceeded(selector, postExecHook); + } + } + } else { + if (postExecHook == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // both pre and post hooks cannot be null + revert NullFunctionReference(); + } + + if (!hooks.postOnlyHooks.tryIncrement(CastLib.toSetValue(postExecHook))) { + revert DuplicateHookLimitExceeded(selector, postExecHook); + } + } + } + + function _removeHooks(HookGroup storage hooks, FunctionReference preExecHook, FunctionReference postExecHook) + internal + returns (bool shouldClearHasPreHooks, bool shouldClearHasPostHooks) + { + if (preExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // If decrementing results in removal, this also clears the flag _PRE_EXEC_HOOK_HAS_POST_FLAG. + // Can ignore the return value because the manifest was checked to match the hash. + hooks.preHooks.tryDecrement(CastLib.toSetValue(preExecHook)); + + // Update the cached flag value for the pre-exec hooks, as it may change with a removal. + if (hooks.preHooks.isEmpty()) { + // The "has pre exec hooks" flag should be disabled + shouldClearHasPreHooks = true; + } + + if (postExecHook != FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + // Remove the associated post-exec hook, if it is set to the expected value. + // Can ignore the return value because the manifest was checked to match the hash. + hooks.associatedPostHooks[preExecHook].tryDecrement(CastLib.toSetValue(postExecHook)); + + if (hooks.associatedPostHooks[preExecHook].isEmpty()) { + // We can ignore return val of tryDisableFlags here as tryDecrement above must have succeeded + // in either removing the element or decrementing its count. + hooks.preHooks.tryDisableFlags(CastLib.toSetValue(preExecHook), _PRE_EXEC_HOOK_HAS_POST_FLAG); + } + } + } else { + // If this else branch is reached, it must be a post-only exec hook, because installation would fail + // when both the pre and post exec hooks are empty. + + // Can ignore the return value because the manifest was checked to match the hash. + hooks.postOnlyHooks.tryDecrement(CastLib.toSetValue(postExecHook)); + + // Update the cached flag value for the post-only exec hooks, as it may change with a removal. + if (hooks.postOnlyHooks.isEmpty()) { + // The "has post only hooks" flag should be disabled + shouldClearHasPostHooks = true; + } + } + } + + function _addPreUserOpValidationHook(bytes4 selector, FunctionReference preUserOpValidationHook) internal { + _assertNotNullFunction(preUserOpValidationHook); + + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + if (!selectorData.preUserOpValidationHooks.tryIncrement(CastLib.toSetValue(preUserOpValidationHook))) { + revert DuplicatePreUserOpValidationHookLimitExceeded(selector, preUserOpValidationHook); + } + // add the pre user op validation hook to the cache for the given selector + if (!selectorData.hasPreUserOpValidationHooks) { + selectorData.hasPreUserOpValidationHooks = true; + } + } + + function _removePreUserOpValidationHook(bytes4 selector, FunctionReference preUserOpValidationHook) internal { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + // Can ignore the return value because the manifest was checked to match the hash. + selectorData.preUserOpValidationHooks.tryDecrement(CastLib.toSetValue(preUserOpValidationHook)); + + if (selectorData.preUserOpValidationHooks.isEmpty()) { + selectorData.hasPreUserOpValidationHooks = false; + } + } + + function _addPreRuntimeValidationHook(bytes4 selector, FunctionReference preRuntimeValidationHook) internal { + _assertNotNullFunction(preRuntimeValidationHook); + + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + if (!selectorData.preRuntimeValidationHooks.tryIncrement(CastLib.toSetValue(preRuntimeValidationHook))) { + revert DuplicatePreRuntimeValidationHookLimitExceeded(selector, preRuntimeValidationHook); + } + // add the pre runtime validation hook's existence to the validator cache for the given selector + if (!selectorData.hasPreRuntimeValidationHooks) { + selectorData.hasPreRuntimeValidationHooks = true; + } + } + + function _removePreRuntimeValidationHook(bytes4 selector, FunctionReference preRuntimeValidationHook) + internal + { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + // Can ignore the return value because the manifest was checked to match the hash. + selectorData.preRuntimeValidationHooks.tryDecrement(CastLib.toSetValue(preRuntimeValidationHook)); + + if (selectorData.preRuntimeValidationHooks.isEmpty()) { + selectorData.hasPreRuntimeValidationHooks = false; + } + } + + function _installPlugin( + address plugin, + bytes32 manifestHash, + bytes memory pluginInitData, + FunctionReference[] memory dependencies, + InjectedHook[] memory injectedHooks + ) internal { + AccountStorage storage storage_ = _getAccountStorage(); + + // Check if the plugin exists, also invalidate null address. + if (!storage_.plugins.tryAdd(CastLib.toSetValue(plugin))) { + revert PluginAlreadyInstalled(plugin); + } + + // Check that the plugin supports the IPlugin interface. + if (!ERC165Checker.supportsInterface(plugin, type(IPlugin).interfaceId)) { + revert PluginInterfaceNotSupported(plugin); + } + + // Check manifest hash. + PluginManifest memory manifest = IPlugin(plugin).pluginManifest(); + if (!_isValidPluginManifest(manifest, manifestHash)) { + revert InvalidPluginManifest(); + } + + // Check that the dependencies match the manifest. + if (dependencies.length != manifest.dependencyInterfaceIds.length) { + revert InvalidDependenciesProvided(); + } + + uint256 length = dependencies.length; + for (uint256 i = 0; i < length;) { + // Check the dependency interface id over the address of the dependency. + (address dependencyAddr,) = dependencies[i].unpack(); + + // Check that the dependency is installed. + if (storage_.pluginData[dependencyAddr].manifestHash == bytes32(0)) { + revert MissingPluginDependency(dependencyAddr); + } + + // Check that the dependency supports the expected interface. + if (!ERC165Checker.supportsInterface(dependencyAddr, manifest.dependencyInterfaceIds[i])) { + revert InvalidDependenciesProvided(); + } + + // Increment the dependency's dependents counter. + storage_.pluginData[dependencyAddr].dependentCount += 1; + + unchecked { + ++i; + } + } + + // Add the plugin metadata to the account + storage_.pluginData[plugin].manifestHash = manifestHash; + storage_.pluginData[plugin].dependencies = dependencies; + + // Update components according to the manifest. + // All conflicts should revert. + + // Mark whether or not this plugin may spend native token amounts + if (manifest.canSpendNativeToken) { + storage_.pluginData[plugin].canSpendNativeToken = true; + } + + // Install execution functions + length = manifest.executionFunctions.length; + for (uint256 i = 0; i < length;) { + _setExecutionFunction(manifest.executionFunctions[i], plugin); + + unchecked { + ++i; + } + } + + // Add installed plugin and selectors this plugin can call + length = manifest.permittedExecutionSelectors.length; + for (uint256 i = 0; i < length;) { + _enableExecFromPlugin(manifest.permittedExecutionSelectors[i], plugin, storage_); + + unchecked { + ++i; + } + } + + // Add the permitted external calls to the account. + if (manifest.permitAnyExternalAddress) { + storage_.pluginData[plugin].anyExternalAddressPermitted = true; + } else { + // Only store the specific permitted external calls if "permit any" flag was not set. + length = manifest.permittedExternalCalls.length; + for (uint256 i = 0; i < length;) { + ManifestExternalCallPermission memory externalCallPermission = manifest.permittedExternalCalls[i]; + + PermittedExternalCallData storage permittedExternalCallData = + storage_.permittedExternalCalls[IPlugin(plugin)][externalCallPermission.externalAddress]; + + permittedExternalCallData.addressPermitted = true; + + if (externalCallPermission.permitAnySelector) { + permittedExternalCallData.anySelectorPermitted = true; + } else { + uint256 externalContractSelectorsLength = externalCallPermission.selectors.length; + for (uint256 j = 0; j < externalContractSelectorsLength;) { + permittedExternalCallData.permittedSelectors[externalCallPermission.selectors[j]] = true; + + unchecked { + ++j; + } + } + } + + unchecked { + ++i; + } + } + } + + // Add injected hooks + length = injectedHooks.length; + // Manually set injected hooks array length + StoredInjectedHook[] storage injectedHooksArray = storage_.pluginData[plugin].injectedHooks; + assembly ("memory-safe") { + sstore(injectedHooksArray.slot, length) + } + for (uint256 i = 0; i < length;) { + InjectedHook memory hook = injectedHooks[i]; + + storage_.pluginData[plugin].injectedHooks[i] = StoredInjectedHook({ + providingPlugin: hook.providingPlugin, + selector: hook.selector, + preExecHookFunctionId: hook.injectedHooksInfo.preExecHookFunctionId, + isPostHookUsed: hook.injectedHooksInfo.isPostHookUsed, + postExecHookFunctionId: hook.injectedHooksInfo.postExecHookFunctionId + }); + + // Increment the dependent count for the plugin providing the hook. + storage_.pluginData[hook.providingPlugin].dependentCount += 1; + + if (!storage_.plugins.contains(CastLib.toSetValue(hook.providingPlugin))) { + revert MissingPluginDependency(hook.providingPlugin); + } + + _addPermittedCallHooks( + hook.selector, + plugin, + FunctionReferenceLib.pack(hook.providingPlugin, hook.injectedHooksInfo.preExecHookFunctionId), + hook.injectedHooksInfo.isPostHookUsed + ? FunctionReferenceLib.pack(hook.providingPlugin, hook.injectedHooksInfo.postExecHookFunctionId) + : FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE + ); + + unchecked { + ++i; + } + } + + // Add user operation validation functions + length = manifest.userOpValidationFunctions.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mv = manifest.userOpValidationFunctions[i]; + _addUserOpValidationFunction( + mv.executionSelector, + _resolveManifestFunction( + mv.associatedFunction, plugin, dependencies, ManifestAssociatedFunctionType.NONE + ) + ); + + unchecked { + ++i; + } + } + + // Add runtime validation functions + length = manifest.runtimeValidationFunctions.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mv = manifest.runtimeValidationFunctions[i]; + _addRuntimeValidationFunction( + mv.executionSelector, + _resolveManifestFunction( + mv.associatedFunction, + plugin, + dependencies, + ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW + ) + ); + + unchecked { + ++i; + } + } + + // Add pre user operation validation hooks + length = manifest.preUserOpValidationHooks.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mh = manifest.preUserOpValidationHooks[i]; + _addPreUserOpValidationHook( + mh.executionSelector, + _resolveManifestFunction( + mh.associatedFunction, + plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ) + ); + + unchecked { + ++i; + } + } + + // Add pre runtime validation hooks + length = manifest.preRuntimeValidationHooks.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mh = manifest.preRuntimeValidationHooks[i]; + _addPreRuntimeValidationHook( + mh.executionSelector, + _resolveManifestFunction( + mh.associatedFunction, + plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ) + ); + unchecked { + ++i; + } + } + + // Add pre and post execution hooks + length = manifest.executionHooks.length; + for (uint256 i = 0; i < length;) { + ManifestExecutionHook memory mh = manifest.executionHooks[i]; + _addExecHooks( + mh.executionSelector, + _resolveManifestFunction( + mh.preExecHook, plugin, dependencies, ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ), + _resolveManifestFunction( + mh.postExecHook, plugin, dependencies, ManifestAssociatedFunctionType.NONE + ) + ); + + unchecked { + ++i; + } + } + + // Add pre and post permitted call hooks + length = manifest.permittedCallHooks.length; + for (uint256 i = 0; i < length;) { + _addPermittedCallHooks( + manifest.permittedCallHooks[i].executionSelector, + plugin, + _resolveManifestFunction( + manifest.permittedCallHooks[i].preExecHook, + plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ), + _resolveManifestFunction( + manifest.permittedCallHooks[i].postExecHook, + plugin, + dependencies, + ManifestAssociatedFunctionType.NONE + ) + ); + + unchecked { + ++i; + } + } + + // Add new interface ids the plugin enabled for the account + length = manifest.interfaceIds.length; + for (uint256 i = 0; i < length;) { + storage_.supportedInterfaces[manifest.interfaceIds[i]] += 1; + unchecked { + ++i; + } + } + + // Call injected hooks' onHookApply after all setup, this is before calling plugin onInstall + length = injectedHooks.length; + for (uint256 i = 0; i < length;) { + InjectedHook memory hook = injectedHooks[i]; + + /* solhint-disable no-empty-blocks */ + try IPlugin(hook.providingPlugin).onHookApply( + plugin, hook.injectedHooksInfo, injectedHooks[i].hookApplyData + ) {} catch (bytes memory revertReason) { + revert PluginApplyHookCallbackFailed(hook.providingPlugin, revertReason); + } + /* solhint-enable no-empty-blocks */ + + // zero out hookApplyData to reduce log cost + injectedHooks[i].hookApplyData = new bytes(0); + + unchecked { + ++i; + } + } + + // Initialize the plugin storage for the account. + // solhint-disable-next-line no-empty-blocks + try IPlugin(plugin).onInstall(pluginInitData) {} + catch (bytes memory revertReason) { + revert PluginInstallCallbackFailed(plugin, revertReason); + } + + emit PluginInstalled(plugin, manifestHash, dependencies, injectedHooks); + } + + function _uninstallPlugin( + UninstallPluginArgs memory args, + bytes calldata uninstallData, + bytes[] calldata hookUnapplyData + ) internal { + AccountStorage storage storage_ = _getAccountStorage(); + + // Check if the plugin exists. + if (!storage_.plugins.tryRemove(CastLib.toSetValue(args.plugin))) { + revert PluginNotInstalled(args.plugin); + } + + // Check manifest hash. + if (!_isValidPluginManifest(args.manifest, storage_.pluginData[args.plugin].manifestHash)) { + revert InvalidPluginManifest(); + } + + // Ensure that there are no dependent plugins. + if (storage_.pluginData[args.plugin].dependentCount != 0) { + revert PluginDependencyViolation(args.plugin); + } + + // Remove this plugin as a dependent from its dependencies. + FunctionReference[] memory dependencies = storage_.pluginData[args.plugin].dependencies; + uint256 length = dependencies.length; + for (uint256 i = 0; i < length;) { + FunctionReference dependency = dependencies[i]; + (address dependencyAddr,) = dependency.unpack(); + + // Decrement the dependent count for the dependency function. + storage_.pluginData[dependencyAddr].dependentCount -= 1; + + unchecked { + ++i; + } + } + + // Remove components according to the manifest, in reverse order (by component type) of their installation. + + // Remove pre and post permitted call hooks + length = args.manifest.permittedCallHooks.length; + for (uint256 i = 0; i < length;) { + _removePermittedCallHooks( + args.manifest.permittedCallHooks[i].executionSelector, + args.plugin, + _resolveManifestFunction( + args.manifest.permittedCallHooks[i].preExecHook, + args.plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ), + _resolveManifestFunction( + args.manifest.permittedCallHooks[i].postExecHook, + args.plugin, + dependencies, + ManifestAssociatedFunctionType.NONE + ) + ); + + unchecked { + ++i; + } + } + + // Remove pre and post execution function hooks + length = args.manifest.executionHooks.length; + for (uint256 i = 0; i < length;) { + ManifestExecutionHook memory mh = args.manifest.executionHooks[i]; + _removeExecHooks( + mh.executionSelector, + _resolveManifestFunction( + mh.preExecHook, args.plugin, dependencies, ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ), + _resolveManifestFunction( + mh.postExecHook, args.plugin, dependencies, ManifestAssociatedFunctionType.NONE + ) + ); + + unchecked { + ++i; + } + } + + // Remove pre runtime validation function hooks + length = args.manifest.preRuntimeValidationHooks.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mh = args.manifest.preRuntimeValidationHooks[i]; + + _removePreRuntimeValidationHook( + mh.executionSelector, + _resolveManifestFunction( + mh.associatedFunction, + args.plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ) + ); + + unchecked { + ++i; + } + } + + // Remove pre user op validation function hooks + length = args.manifest.preUserOpValidationHooks.length; + for (uint256 i = 0; i < length;) { + ManifestAssociatedFunction memory mh = args.manifest.preUserOpValidationHooks[i]; + + _removePreUserOpValidationHook( + mh.executionSelector, + _resolveManifestFunction( + mh.associatedFunction, + args.plugin, + dependencies, + ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY + ) + ); + + unchecked { + ++i; + } + } + + // Remove runtime validation function hooks + length = args.manifest.runtimeValidationFunctions.length; + for (uint256 i = 0; i < length;) { + bytes4 executionSelector = args.manifest.runtimeValidationFunctions[i].executionSelector; + storage_.selectorData[executionSelector].runtimeValidation = + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + + unchecked { + ++i; + } + } + + // Remove user op validation function hooks + length = args.manifest.userOpValidationFunctions.length; + for (uint256 i = 0; i < length;) { + bytes4 executionSelector = args.manifest.userOpValidationFunctions[i].executionSelector; + storage_.selectorData[executionSelector].userOpValidation = + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; + + unchecked { + ++i; + } + } + + // Remove permitted external call permissions + if (args.manifest.permitAnyExternalAddress) { + // Only clear if it was set during install time + storage_.pluginData[args.plugin].anyExternalAddressPermitted = false; + } else { + // Only clear the specific permitted external calls if "permit any" flag was not set. + length = args.manifest.permittedExternalCalls.length; + for (uint256 i = 0; i < length;) { + ManifestExternalCallPermission memory externalCallPermission = + args.manifest.permittedExternalCalls[i]; + + PermittedExternalCallData storage permittedExternalCallData = + storage_.permittedExternalCalls[IPlugin(args.plugin)][externalCallPermission.externalAddress]; + + permittedExternalCallData.addressPermitted = false; + + // Only clear this flag if it was set in the constructor. + if (externalCallPermission.permitAnySelector) { + permittedExternalCallData.anySelectorPermitted = false; + } else { + uint256 externalContractSelectorsLength = externalCallPermission.selectors.length; + for (uint256 j = 0; j < externalContractSelectorsLength;) { + permittedExternalCallData.permittedSelectors[externalCallPermission.selectors[j]] = false; + + unchecked { + ++j; + } + } + } + + unchecked { + ++i; + } + } + } + + // Remove injected hooks + length = storage_.pluginData[args.plugin].injectedHooks.length; + for (uint256 i = 0; i < length;) { + StoredInjectedHook memory hook = storage_.pluginData[args.plugin].injectedHooks[i]; + + // Decrement the dependent count for the plugin providing the hook. + storage_.pluginData[hook.providingPlugin].dependentCount -= 1; + + _removePermittedCallHooks( + hook.selector, + args.plugin, + FunctionReferenceLib.pack(hook.providingPlugin, hook.preExecHookFunctionId), + hook.isPostHookUsed + ? FunctionReferenceLib.pack(hook.providingPlugin, hook.postExecHookFunctionId) + : FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE + ); + + unchecked { + ++i; + } + } + + // Remove permitted account execution function call permissions + length = args.manifest.permittedExecutionSelectors.length; + for (uint256 i = 0; i < length;) { + storage_.permittedCalls[_getPermittedCallKey(args.plugin, args.manifest.permittedExecutionSelectors[i])] + .callPermitted = false; + + unchecked { + ++i; + } + } + + // Remove installed execution function + length = args.manifest.executionFunctions.length; + for (uint256 i = 0; i < length;) { + storage_.selectorData[args.manifest.executionFunctions[i]].plugin = address(0); + + unchecked { + ++i; + } + } + + // Decrease supported interface ids' counters + length = args.manifest.interfaceIds.length; + for (uint256 i = 0; i < length;) { + storage_.supportedInterfaces[args.manifest.interfaceIds[i]] -= 1; + + unchecked { + ++i; + } + } + + // Call onHookUnapply on all injected hooks + bool callbacksSucceeded = true; + length = storage_.pluginData[args.plugin].injectedHooks.length; + bool hasUnapplyHookData = hookUnapplyData.length != 0; + if (hasUnapplyHookData && hookUnapplyData.length != length) { + revert ArrayLengthMismatch(); + } + for (uint256 i = 0; i < length;) { + StoredInjectedHook memory hook = storage_.pluginData[args.plugin].injectedHooks[i]; + + /* solhint-disable no-empty-blocks */ + try IPlugin(hook.providingPlugin).onHookUnapply{gas: args.callbackGasLimit}( + args.plugin, + InjectedHooksInfo({ + preExecHookFunctionId: hook.preExecHookFunctionId, + isPostHookUsed: hook.isPostHookUsed, + postExecHookFunctionId: hook.postExecHookFunctionId + }), + hasUnapplyHookData ? hookUnapplyData[i] : bytes("") + ) {} catch (bytes memory revertReason) { + if (!args.forceUninstall) { + revert PluginHookUnapplyCallbackFailed(hook.providingPlugin, revertReason); + } + callbacksSucceeded = false; + emit PluginIgnoredHookUnapplyCallbackFailure(args.plugin, hook.providingPlugin); + } + /* solhint-enable no-empty-blocks */ + + unchecked { + ++i; + } + } + + // Remove the plugin metadata from the account. + delete storage_.pluginData[args.plugin]; + + // Clear the plugin storage for the account. + // solhint-disable-next-line no-empty-blocks + try IPlugin(args.plugin).onUninstall{gas: args.callbackGasLimit}(uninstallData) {} + catch (bytes memory revertReason) { + if (!args.forceUninstall) { + revert PluginUninstallCallbackFailed(args.plugin, revertReason); + } + callbacksSucceeded = false; + emit PluginIgnoredUninstallCallbackFailure(args.plugin); + } + + emit PluginUninstalled(args.plugin, callbacksSucceeded); + } + + function _isValidPluginManifest(PluginManifest memory manifest, bytes32 manifestHash) + internal + pure + returns (bool) + { + return manifestHash == keccak256(abi.encode(manifest)); + } + + function _resolveManifestFunction( + ManifestFunction memory manifestFunction, + address plugin, + FunctionReference[] memory dependencies, + // Indicates which magic value, if any, is permissible for the function to resolve. + ManifestAssociatedFunctionType allowedMagicValue + ) internal pure returns (FunctionReference) { + if (manifestFunction.functionType == ManifestAssociatedFunctionType.SELF) { + return FunctionReferenceLib.pack(plugin, manifestFunction.functionId); + } else if (manifestFunction.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + return dependencies[manifestFunction.dependencyIndex]; + } else if (manifestFunction.functionType == ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW) + { + if (allowedMagicValue == ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW) { + return FunctionReferenceLib._RUNTIME_VALIDATION_ALWAYS_ALLOW; + } else { + revert InvalidPluginManifest(); + } + } else if (manifestFunction.functionType == ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY) { + if (allowedMagicValue == ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY) { + return FunctionReferenceLib._PRE_HOOK_ALWAYS_DENY; + } else { + revert InvalidPluginManifest(); + } + } + return FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE; // Empty checks are done elsewhere + } + + function _assertNotNullFunction(FunctionReference functionReference) internal pure { + if (functionReference == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + revert NullFunctionReference(); + } + } +} diff --git a/src/account/UpgradeableModularAccount.sol b/src/account/UpgradeableModularAccount.sol new file mode 100644 index 00000000..208a7a13 --- /dev/null +++ b/src/account/UpgradeableModularAccount.sol @@ -0,0 +1,777 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; + +import {AccountExecutor} from "./AccountExecutor.sol"; +import {AccountLoupe} from "./AccountLoupe.sol"; +import {AccountStorageInitializable} from "./AccountStorageInitializable.sol"; +import {PluginManagerInternals} from "./PluginManagerInternals.sol"; + +import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationDataHelpers.sol"; + +import {Call, IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; +import {IAccount} from "../interfaces/erc4337/IAccount.sol"; +import {IAccountInitializable} from "../interfaces/IAccountInitializable.sol"; +import {IAccountView} from "../interfaces/IAccountView.sol"; +import {IEntryPoint} from "../interfaces/erc4337/IEntryPoint.sol"; +import {IPlugin, PluginManifest} from "../interfaces/IPlugin.sol"; +import {IPluginExecutor} from "../interfaces/IPluginExecutor.sol"; +import {IPluginManager} from "../interfaces/IPluginManager.sol"; +import {UserOperation} from "../interfaces/erc4337/UserOperation.sol"; + +import {CastLib} from "../libraries/CastLib.sol"; +import {CountableLinkedListSetLib} from "../libraries/CountableLinkedListSetLib.sol"; +import {FunctionReference, FunctionReferenceLib} from "../libraries/FunctionReferenceLib.sol"; +import {LinkedListSet, LinkedListSetLib} from "../libraries/LinkedListSetLib.sol"; +import {UUPSUpgradeable} from "../../ext/UUPSUpgradeable.sol"; + +/// @title Upgradeable Modular Account +/// @author Alchemy +/// @notice An ERC-6900 compatible modular smart contract account (MSCA) that supports upgradeability and plugins. +contract UpgradeableModularAccount is + AccountExecutor, + AccountLoupe, + AccountStorageInitializable, + PluginManagerInternals, + IAccount, + IAccountInitializable, + IAccountView, + IERC165, + IPluginExecutor, + IStandardExecutor, + UUPSUpgradeable +{ + using CountableLinkedListSetLib for LinkedListSet; + using LinkedListSetLib for LinkedListSet; + + /// @dev Struct to hold optional configuration data for uninstalling a plugin. This should be encoded and + /// passed to the `config` parameter of `uninstallPlugin`. + struct UninstallPluginConfig { + // ABI-encoding of a `PluginManifest` to specify the original manifest + // used to install the plugin now being uninstalled, in cases where the + // plugin manifest has changed. If empty, uses the default behavior of + // calling the plugin to get its current manifest. + bytes serializedManifest; + // If true, will complete the uninstall even if `onUninstall` or + // `onHookUnapply` callbacks revert. Available as an escape hatch if a + // plugin is blocking uninstall. + bool forceUninstall; + // Maximum amount of gas allowed for each uninstall callback function + // (`onUninstall` and `onHookUnapply`), or zero to set no limit. Should + // typically be used with `forceUninstall` to remove plugins that are + // preventing uninstallation by consuming all remaining gas. + uint256 callbackGasLimit; + } + + IEntryPoint private immutable _ENTRY_POINT; + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + // As per the EIP-165 spec, no interface should ever match 0xffffffff + bytes4 internal constant _INTERFACE_ID_INVALID = 0xffffffff; + bytes4 internal constant _IERC165_INTERFACE_ID = 0x01ffc9a7; + + event ModularAccountInitialized(IEntryPoint indexed entryPoint); + + error AlwaysDenyRule(); + error AuthorizeUpgradeReverted(bytes revertReason); + error ExecFromPluginNotPermitted(address plugin, bytes4 selector); + error ExecFromPluginExternalNotPermitted(address plugin, address target, uint256 value, bytes data); + error NativeTokenSpendingNotPermitted(address plugin); + error PostExecHookReverted(address plugin, uint8 functionId, bytes revertReason); + error PreExecHookReverted(address plugin, uint8 functionId, bytes revertReason); + error PreRuntimeValidationHookFailed(address plugin, uint8 functionId, bytes revertReason); + error RuntimeValidationFunctionMissing(bytes4 selector); + error RuntimeValidationFunctionReverted(address plugin, uint8 functionId, bytes revertReason); + error UnexpectedAggregator(address plugin, uint8 functionId, address aggregator); + error UnrecognizedFunction(bytes4 selector); + error UserOpNotFromEntryPoint(); + error UserOpValidationFunctionMissing(bytes4 selector); + + constructor(IEntryPoint anEntryPoint) { + _ENTRY_POINT = anEntryPoint; + _disableInitializers(); + } + + // EXTERNAL FUNCTIONS + + /// @inheritdoc IAccountInitializable + function initialize(address[] memory plugins, bytes calldata pluginInitData) external initializer { + (bytes32[] memory manifestHashes, bytes[] memory pluginInstallDatas) = + abi.decode(pluginInitData, (bytes32[], bytes[])); + + uint256 length = plugins.length; + + if (length != manifestHashes.length || length != pluginInstallDatas.length) { + revert ArrayLengthMismatch(); + } + + FunctionReference[] memory emptyDependencies = new FunctionReference[](0); + InjectedHook[] memory emptyInjectedHooks = new InjectedHook[](0); + + for (uint256 i = 0; i < length;) { + _installPlugin( + plugins[i], manifestHashes[i], pluginInstallDatas[i], emptyDependencies, emptyInjectedHooks + ); + + unchecked { + ++i; + } + } + + emit ModularAccountInitialized(_ENTRY_POINT); + } + + receive() external payable {} + + /// @notice Fallback function that routes calls to plugin execution functions. + /// @dev We route calls to execution functions based on incoming msg.sig. If there's no plugin associated with + /// this function selector, revert. + /// @return Data returned from the called execution function. + fallback(bytes calldata) external payable returns (bytes memory) { + SelectorData storage selectorData = _getAccountStorage().selectorData[msg.sig]; + + address execPlugin = selectorData.plugin; + if (execPlugin == address(0)) { + revert UnrecognizedFunction(msg.sig); + } + + // Either reuse the call buffer from runtime validation, or allocate a new one. It may or may not be used + // for pre exec hooks but it will be used for the plugin execution itself. + bytes memory callBuffer = + (msg.sender != address(_ENTRY_POINT)) ? _doRuntimeValidation() : _allocateRuntimeCallBuffer(msg.data); + + bool hasPreExecHooks = selectorData.hasPreExecHooks; + bool hasPostOnlyExecHooks = selectorData.hasPostOnlyExecHooks; + + FunctionReference[] memory postExecHooksToRun; + bytes[] memory postExecHookArgs; + if (hasPreExecHooks) { + // Cache post-exec hooks in memory + (postExecHooksToRun, postExecHookArgs) = _doPreExecHooks(msg.sig, callBuffer); + } + + // execute the function, bubbling up any reverts + bool execSuccess = _executeRaw(execPlugin, _convertRuntimeCallBufferToExecBuffer(callBuffer)); + bytes memory execReturnData = _collectReturnData(); + + if (!execSuccess) { + // Bubble up revert reasons from plugins + assembly ("memory-safe") { + revert(add(execReturnData, 32), mload(execReturnData)) + } + } + + _doCachedPostHooks(postExecHooksToRun, postExecHookArgs); + + if (hasPostOnlyExecHooks) { + _doCachedPostHooks( + CastLib.toFunctionReferenceArray(selectorData.executionHooks.postOnlyHooks.getAll()), + new bytes[](0) + ); + } + + return execReturnData; + } + + /// @inheritdoc IAccount + function validateUserOp(UserOperation calldata userOp, bytes32 userOpHash, uint256 missingAccountFunds) + external + virtual + override + returns (uint256 validationData) + { + if (msg.sender != address(_ENTRY_POINT)) { + revert UserOpNotFromEntryPoint(); + } + + bool hasPreValidationHooks; + + if (userOp.callData.length < 4) { + revert UnrecognizedFunction(bytes4(userOp.callData)); + } + bytes4 selector = _selectorFromCallData(userOp.callData); + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + + FunctionReference userOpValidationFunction = selectorData.userOpValidation; + hasPreValidationHooks = selectorData.hasPreUserOpValidationHooks; + + validationData = + _doUserOpValidation(selector, userOpValidationFunction, userOp, userOpHash, hasPreValidationHooks); + + if (missingAccountFunds != 0) { + // entry point verifies if call succeeds so we don't need to do here + (bool success,) = payable(msg.sender).call{value: missingAccountFunds, gas: type(uint256).max}(""); + (success); + } + } + + /// @inheritdoc IStandardExecutor + function execute(address target, uint256 value, bytes calldata data) + external + payable + override + returns (bytes memory result) + { + (FunctionReference[] memory postExecHooks, bytes[] memory postExecHookArgs) = _preNativeFunction(); + result = _exec(target, value, data); + _postNativeFunction(postExecHooks, postExecHookArgs); + } + + /// @inheritdoc IStandardExecutor + function executeBatch(Call[] calldata calls) external payable override returns (bytes[] memory results) { + (FunctionReference[] memory postExecHooks, bytes[] memory postExecHookArgs) = _preNativeFunction(); + + uint256 callsLength = calls.length; + results = new bytes[](callsLength); + + for (uint256 i = 0; i < callsLength;) { + results[i] = _exec(calls[i].target, calls[i].value, calls[i].data); + + unchecked { + ++i; + } + } + + _postNativeFunction(postExecHooks, postExecHookArgs); + } + + /// @inheritdoc IPluginExecutor + function executeFromPlugin(bytes calldata data) external payable override returns (bytes memory returnData) { + bytes4 selector = _selectorFromCallData(data); + bytes24 permittedCallKey = _getPermittedCallKey(msg.sender, selector); + + AccountStorage storage storage_ = _getAccountStorage(); + PermittedCallData storage permittedCallData = storage_.permittedCalls[permittedCallKey]; + + if (!permittedCallData.callPermitted) { + revert ExecFromPluginNotPermitted(msg.sender, selector); + } + + bytes memory callBuffer = _allocateRuntimeCallBuffer(data); + + FunctionReference[] memory postPermittedCallHooks; + bytes[] memory postPermittedCallHookArgs; + if (permittedCallData.hasPrePermittedCallHooks) { + // Cache post-permitted call hooks in memory + (postPermittedCallHooks, postPermittedCallHookArgs) = + _doPrePermittedCallHooks(permittedCallKey, callBuffer); + } + + SelectorData storage selectorData = storage_.selectorData[selector]; + address execFunctionPlugin = selectorData.plugin; + + if (execFunctionPlugin == address(0)) { + revert UnrecognizedFunction(selector); + } + + FunctionReference[] memory postExecHooks; + bytes[] memory postExecHookArgs; + if (selectorData.hasPreExecHooks) { + // Cache post-exec hooks in memory + (postExecHooks, postExecHookArgs) = _doPreExecHooks(selector, callBuffer); + } + + bool success = _executeRaw(execFunctionPlugin, _convertRuntimeCallBufferToExecBuffer(callBuffer)); + returnData = _collectReturnData(); + + if (!success) { + assembly ("memory-safe") { + revert(add(returnData, 32), mload(returnData)) + } + } + + _doCachedPostHooks(postExecHooks, postExecHookArgs); + + if (selectorData.hasPostOnlyExecHooks) { + _doCachedPostHooks( + CastLib.toFunctionReferenceArray(selectorData.executionHooks.postOnlyHooks.getAll()), + new bytes[](0) + ); + } + + _doCachedPostHooks(postPermittedCallHooks, postPermittedCallHookArgs); + + if (permittedCallData.hasPostOnlyPermittedCallHooks) { + _doCachedPostHooks( + CastLib.toFunctionReferenceArray(permittedCallData.permittedCallHooks.postOnlyHooks.getAll()), + new bytes[](0) + ); + } + + return returnData; + } + + /// @inheritdoc IPluginExecutor + function executeFromPluginExternal(address target, uint256 value, bytes calldata data) + external + payable + returns (bytes memory) + { + AccountStorage storage storage_ = _getAccountStorage(); + address callingPlugin = msg.sender; + + // Make sure plugin is allowed to spend native token. + if (value > 0 && value > msg.value && !storage_.pluginData[callingPlugin].canSpendNativeToken) { + revert NativeTokenSpendingNotPermitted(callingPlugin); + } + + // Check the caller plugin's permission to make this call on the target address. + // + // 1. Check that the target is permitted at all, and if so check that any one of the following is true: + // a. Is any selector permitted? + // b. Is the calldata is empty? (allow empty data calls by default if the target address is permitted) + // c. Is the selector in the call permitted? + // 2. If the target is not permitted, instead check whether all external calls are permitted. + // + // `addressPermitted` can only be true if `anyExternalAddressPermitted` is false, so we can reduce our + // worst-case `sloads` by 1 by not checking `anyExternalAddressPermitted` if `addressPermitted` is true. + // + // We allow calls where the data may be less than 4 bytes - it's up to the calling contract to + // determine how to handle this. + bool isTargetCallPermitted; + if (storage_.permittedExternalCalls[IPlugin(callingPlugin)][target].addressPermitted) { + isTargetCallPermitted = ( + storage_.permittedExternalCalls[IPlugin(callingPlugin)][target].anySelectorPermitted + || data.length == 0 + || storage_.permittedExternalCalls[IPlugin(callingPlugin)][target].permittedSelectors[bytes4(data)] + ); + } else { + isTargetCallPermitted = storage_.pluginData[callingPlugin].anyExternalAddressPermitted; + } + + // If the target is not permitted, check if the caller plugin is permitted to make any external calls. + if (!isTargetCallPermitted) { + revert ExecFromPluginExternalNotPermitted(callingPlugin, target, value, data); + } + + // Run permitted call hooks and execution hooks. `execfuteFromPluginExternal` doesn't use PermittedCallData + // to check call permissions, nor do they have an address entry in SelectorData, so it doesn't make sense + // to use cached booleans (hasPreExecHooks, hasPostOnlyExecHooks, etc.) to conditionally bypass certain + // steps, as it would just be an added `sload` in the nonzero hooks case. + + // Run any pre permitted call hooks specific to this caller and the `executeFromPluginExternal` selector + bytes24 permittedCallKey = + _getPermittedCallKey(callingPlugin, IPluginExecutor.executeFromPluginExternal.selector); + (FunctionReference[] memory postPermittedCallHooks, bytes[] memory postPermittedCallHookArgs) = + _doPrePermittedCallHooks(permittedCallKey, ""); + + // Run any pre exec hooks for the `executeFromPluginExternal` selector + (FunctionReference[] memory postExecHooks, bytes[] memory postExecHookArgs) = + _doPreExecHooks(IPluginExecutor.executeFromPluginExternal.selector, ""); + + // Perform the external call + bytes memory returnData = _exec(target, value, data); + + // Run any post exec hooks for the `executeFromPluginExternal` selector + _doCachedPostHooks(postExecHooks, postExecHookArgs); + + // Run any post only exec hooks for the `executeFromPluginExternal` selector + _doCachedPostHooks( + CastLib.toFunctionReferenceArray( + storage_.selectorData[IPluginExecutor.executeFromPluginExternal.selector] + .executionHooks + .postOnlyHooks + .getAll() + ), + new bytes[](0) + ); + + // Run any post permitted call hooks specific to this caller and the `executeFromPluginExternal` selector + _doCachedPostHooks(postPermittedCallHooks, postPermittedCallHookArgs); + + // Run any post only permitted call hooks specific to this caller and the `executeFromPluginExternal` + // selector + _doCachedPostHooks( + CastLib.toFunctionReferenceArray( + storage_.permittedCalls[permittedCallKey].permittedCallHooks.postOnlyHooks.getAll() + ), + new bytes[](0) + ); + + return returnData; + } + + /// @inheritdoc IPluginManager + function installPlugin( + address plugin, + bytes32 manifestHash, + bytes calldata pluginInitData, + FunctionReference[] calldata dependencies, + InjectedHook[] calldata injectedHooks + ) external override { + (FunctionReference[] memory postExecHooks, bytes[] memory postHookArgs) = _preNativeFunction(); + _installPlugin(plugin, manifestHash, pluginInitData, dependencies, injectedHooks); + _postNativeFunction(postExecHooks, postHookArgs); + } + + /// @inheritdoc IPluginManager + function uninstallPlugin( + address plugin, + bytes calldata config, + bytes calldata pluginUninstallData, + bytes[] calldata hookUnapplyData + ) external override { + (FunctionReference[] memory postExecHooks, bytes[] memory postHookArgs) = _preNativeFunction(); + + UninstallPluginArgs memory args; + args.plugin = plugin; + bool hasSetManifest; + + if (config.length > 0) { + UninstallPluginConfig memory decodedConfig = abi.decode(config, (UninstallPluginConfig)); + if (decodedConfig.serializedManifest.length > 0) { + args.manifest = abi.decode(decodedConfig.serializedManifest, (PluginManifest)); + hasSetManifest = true; + } + args.forceUninstall = decodedConfig.forceUninstall; + args.callbackGasLimit = decodedConfig.callbackGasLimit; + } + if (!hasSetManifest) { + args.manifest = IPlugin(plugin).pluginManifest(); + } + if (args.callbackGasLimit == 0) { + args.callbackGasLimit = type(uint256).max; + } + + _uninstallPlugin(args, pluginUninstallData, hookUnapplyData); + + _postNativeFunction(postExecHooks, postHookArgs); + } + + /// @inheritdoc IERC165 + function supportsInterface(bytes4 interfaceId) external view override returns (bool) { + if (interfaceId == _INTERFACE_ID_INVALID) { + return false; + } + if (interfaceId == _IERC165_INTERFACE_ID) { + return true; + } + + return _getAccountStorage().supportedInterfaces[interfaceId] > 0; + } + + /// @inheritdoc UUPSUpgradeable + function upgradeToAndCall(address newImplementation, bytes calldata data) public payable override onlyProxy { + (FunctionReference[] memory postExecHooks, bytes[] memory postHookArgs) = _preNativeFunction(); + UUPSUpgradeable.upgradeToAndCall(newImplementation, data); + _postNativeFunction(postExecHooks, postHookArgs); + } + + /// @inheritdoc IAccountView + function entryPoint() public view override returns (IEntryPoint) { + return _ENTRY_POINT; + } + + /// @inheritdoc IAccountView + function getNonce() public view virtual override returns (uint256) { + return _ENTRY_POINT.getNonce(address(this), 0); + } + + // INTERNAL FUNCTIONS + + /// @dev Wraps execution of a native function with runtime validation and hooks. Used for upgradeToAndCall, + /// execute, executeBatch, installPlugin, uninstallPlugin. + function _preNativeFunction() + internal + returns (FunctionReference[] memory postExecHooks, bytes[] memory postExecHookArgs) + { + bytes memory callBuffer = ""; + + if (msg.sender != address(_ENTRY_POINT)) { + callBuffer = _doRuntimeValidation(); + } + + (postExecHooks, postExecHookArgs) = _doPreExecHooks(msg.sig, callBuffer); + } + + /// @dev Wraps execution of a native function with runtime validation and hooks. Used for upgradeToAndCall, + /// execute, executeBatch, installPlugin, uninstallPlugin. + function _postNativeFunction(FunctionReference[] memory postExecHooks, bytes[] memory postExecHookArgs) + internal + { + _doCachedPostHooks(postExecHooks, postExecHookArgs); + + _doCachedPostHooks( + CastLib.toFunctionReferenceArray( + _getAccountStorage().selectorData[msg.sig].executionHooks.postOnlyHooks.getAll() + ), + new bytes[](0) + ); + } + + /// @dev To support gas estimation, we don't fail early when the failure is caused by a signature failure. + function _doUserOpValidation( + bytes4 selector, + FunctionReference userOpValidationFunction, + UserOperation calldata userOp, + bytes32 userOpHash, + bool doPreValidationHooks + ) internal returns (uint256 validationData) { + if (userOpValidationFunction == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE) { + revert UserOpValidationFunctionMissing(selector); + } + + bytes memory callBuffer = + _allocateUserOpCallBuffer(IPlugin.preUserOpValidationHook.selector, userOp, userOpHash); + + uint256 currentValidationData; + uint256 preUserOpValidationHooksLength; + + if (doPreValidationHooks) { + // Do preUserOpValidation hooks + FunctionReference[] memory preUserOpValidationHooks = CastLib.toFunctionReferenceArray( + _getAccountStorage().selectorData[selector].preUserOpValidationHooks.getAll() + ); + + preUserOpValidationHooksLength = preUserOpValidationHooks.length; + for (uint256 i = 0; i < preUserOpValidationHooksLength;) { + // FunctionReference preUserOpValidationHook = preUserOpValidationHooks[i]; + + if (preUserOpValidationHooks[i].isEmptyOrMagicValue()) { + // Empty function reference is impossible here due to the element coming from a LinkedListSet. + // Runtime Validation Always Allow is not assignable here. + // Pre Hook Always Deny is the only assignable magic value here. + revert AlwaysDenyRule(); + } + + (address plugin, uint8 functionId) = preUserOpValidationHooks[i].unpack(); + + _updatePluginCallBufferFunctionId(callBuffer, functionId); + + currentValidationData = _executeUserOpPluginFunction(callBuffer, plugin); + + if (uint160(currentValidationData) > 1) { + // If the aggregator is not 0 or 1, it is an unexpected value + revert UnexpectedAggregator(plugin, functionId, address(uint160(currentValidationData))); + } + validationData = _coalescePreValidation(validationData, currentValidationData); + + unchecked { + ++i; + } + } + } + + // Run the user op validation function + { + _updatePluginCallBufferSelector(callBuffer, IPlugin.userOpValidationFunction.selector); + // No magic values are assignable here, and we already checked whether or not the function was empty, + // so we're OK to use the function immediately + (address plugin, uint8 functionId) = userOpValidationFunction.unpack(); + + _updatePluginCallBufferFunctionId(callBuffer, functionId); + + currentValidationData = _executeUserOpPluginFunction(callBuffer, plugin); + + if (preUserOpValidationHooksLength != 0) { + // If we have other validation data we need to coalesce with + validationData = _coalesceValidation(validationData, currentValidationData); + } else { + validationData = currentValidationData; + } + } + } + + function _doRuntimeValidation() internal returns (bytes memory callBuffer) { + AccountStorage storage storage_ = _getAccountStorage(); + FunctionReference runtimeValidationFunction = storage_.selectorData[msg.sig].runtimeValidation; + bool doPreRuntimeValidationHooks = storage_.selectorData[msg.sig].hasPreRuntimeValidationHooks; + + // Allocate the call buffer for preRuntimeValidationHook + callBuffer = _allocateRuntimeCallBuffer(msg.data); + + if (doPreRuntimeValidationHooks) { + _updatePluginCallBufferSelector(callBuffer, IPlugin.preRuntimeValidationHook.selector); + + // run all preRuntimeValidation hooks + FunctionReference[] memory preRuntimeValidationHooks = CastLib.toFunctionReferenceArray( + _getAccountStorage().selectorData[msg.sig].preRuntimeValidationHooks.getAll() + ); + + uint256 preRuntimeValidationHooksLength = preRuntimeValidationHooks.length; + for (uint256 i = 0; i < preRuntimeValidationHooksLength;) { + FunctionReference preRuntimeValidationHook = preRuntimeValidationHooks[i]; + + if (preRuntimeValidationHook.isEmptyOrMagicValue()) { + // The function reference must be the Always Deny magic value in this case, + // because zero and any other magic value is unassignable here. + revert AlwaysDenyRule(); + } + + (address plugin, uint8 functionId) = preRuntimeValidationHook.unpack(); + + _updatePluginCallBufferFunctionId(callBuffer, functionId); + + _executeRuntimePluginFunction(callBuffer, plugin, PreRuntimeValidationHookFailed.selector); + + unchecked { + ++i; + } + } + } + + // Identifier scope limiting + { + if (runtimeValidationFunction.isEmptyOrMagicValue()) { + if ( + runtimeValidationFunction == FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE + && (msg.sig != IPluginManager.installPlugin.selector || msg.sender != address(this)) + ) { + // Runtime calls cannot be made against functions with no + // validator, except in the special case of self-calls to + // `installPlugin`, to enable removing the plugin protecting + // `installPlugin` and installing a different one as part of + // a single batch execution. + revert RuntimeValidationFunctionMissing(msg.sig); + } + // If _RUNTIME_VALIDATION_ALWAYS_ALLOW, or we're in the + // `installPlugin` special case,just let the function finish, + // without the else branch. + } else { + _updatePluginCallBufferSelector(callBuffer, IPlugin.runtimeValidationFunction.selector); + + (address plugin, uint8 functionId) = runtimeValidationFunction.unpack(); + + _updatePluginCallBufferFunctionId(callBuffer, functionId); + + _executeRuntimePluginFunction(callBuffer, plugin, RuntimeValidationFunctionReverted.selector); + } + } + } + + function _doPreExecHooks(bytes4 selector, bytes memory callBuffer) + internal + returns (FunctionReference[] memory, bytes[] memory) + { + SelectorData storage selectorData = _getAccountStorage().selectorData[selector]; + return _doPreHooks( + selectorData.executionHooks.preHooks, selectorData.executionHooks.associatedPostHooks, callBuffer + ); + } + + function _doPrePermittedCallHooks(bytes24 permittedCallKey, bytes memory callBuffer) + internal + returns (FunctionReference[] memory, bytes[] memory) + { + PermittedCallData storage permittedCallData = _getAccountStorage().permittedCalls[permittedCallKey]; + return _doPreHooks( + permittedCallData.permittedCallHooks.preHooks, + permittedCallData.permittedCallHooks.associatedPostHooks, + callBuffer + ); + } + + function _doPreHooks( + LinkedListSet storage preHookSet, + mapping(FunctionReference => LinkedListSet) storage associatedPostHooks, + bytes memory callBuffer + ) internal returns (FunctionReference[] memory postHooks, bytes[] memory postHookArgs) { + FunctionReference[] memory preExecHooks = CastLib.toFunctionReferenceArray(preHookSet.getAll()); + + uint256 preExecHooksLength = preExecHooks.length; + uint256 maxPostHooksToRunLength; + + // There can only be as many associated post hooks to run as there are pre hooks. + for (uint256 i = 0; i < preExecHooksLength;) { + unchecked { + maxPostHooksToRunLength += preHookSet.getCount(CastLib.toSetValue(preExecHooks[i])); + ++i; + } + } + + // Overallocate on length, but not all of this may get filled up. + postHooks = new FunctionReference[](maxPostHooksToRunLength); + postHookArgs = new bytes[](maxPostHooksToRunLength); + uint256 actualPostHooksToRunLength; + + // If not running anything, short-circuit before allocating more memory for the call buffers. + if (preExecHooksLength == 0) { + return (postHooks, postHookArgs); + } + + if (callBuffer.length == 0) { + // Allocate the call buffer for preExecHook. This case MUST NOT be reached by `executeFromPlugin`, + // otherwise the call will execute with the wrong calldata. This case should only be reachable by + // native functions with no runtime validation (i.e. being calling via a user operation). + callBuffer = _allocateRuntimeCallBuffer(msg.data); + } + _updatePluginCallBufferSelector(callBuffer, IPlugin.preExecutionHook.selector); + + for (uint256 i = 0; i < preExecHooksLength;) { + FunctionReference preExecHook = preExecHooks[i]; + + if (preExecHook.isEmptyOrMagicValue()) { + // The function reference must be the Always Deny magic value in this case, + // because zero and any other magic value is unassignable here. + revert AlwaysDenyRule(); + } + + (address plugin, uint8 functionId) = preExecHook.unpack(); + + _updatePluginCallBufferFunctionId(callBuffer, functionId); + + if (preHookSet.flagsEnabled(CastLib.toSetValue(preExecHook), _PRE_EXEC_HOOK_HAS_POST_FLAG)) { + FunctionReference[] memory associatedPostExecHooks = + CastLib.toFunctionReferenceArray(associatedPostHooks[preExecHook].getAll()); + uint256 associatedPostExecHooksLength = associatedPostExecHooks.length; + + for (uint256 j = 0; j < associatedPostExecHooksLength;) { + // Execute the pre-hook as many times as there are unique associated post-hooks. + _executeRuntimePluginFunction(callBuffer, plugin, PreExecHookReverted.selector); + + postHooks[actualPostHooksToRunLength] = associatedPostExecHooks[j]; + postHookArgs[actualPostHooksToRunLength] = abi.decode(_collectReturnData(), (bytes)); + + unchecked { + ++actualPostHooksToRunLength; + ++j; + } + } + } else { + _executeRuntimePluginFunction(callBuffer, plugin, PreExecHookReverted.selector); + } + + unchecked { + ++i; + } + } + + // "Trim" the associated post hook arrays to the actual length, since we may have overallocated. This + // allows for exeuction of post hooks in reverse order. + assembly ("memory-safe") { + mstore(postHooks, actualPostHooksToRunLength) + mstore(postHookArgs, actualPostHooksToRunLength) + } + } + + function _doCachedPostHooks(FunctionReference[] memory postHooks, bytes[] memory postHookArgs) internal { + uint256 postHooksToRunLength = postHooks.length; + bool hasPostHookArgs = postHookArgs.length > 0; + for (uint256 i = postHooksToRunLength; i > 0;) { + FunctionReference postExecHook = postHooks[i - 1]; + (address plugin, uint8 functionId) = postExecHook.unpack(); + // solhint-disable-next-line no-empty-blocks + try IPlugin(plugin).postExecutionHook(functionId, hasPostHookArgs ? postHookArgs[i - 1] : bytes("")) {} + catch (bytes memory revertReason) { + revert PostExecHookReverted(plugin, functionId, revertReason); + } + + unchecked { + --i; + } + } + } + + /// @inheritdoc UUPSUpgradeable + // solhint-disable-next-line no-empty-blocks + function _authorizeUpgrade(address newImplementation) internal override {} + + /// @dev Revert with an appropriate error if the calldata does not include a function selector. + function _selectorFromCallData(bytes calldata data) internal pure returns (bytes4) { + if (data.length < 4) { + revert UnrecognizedFunction(bytes4(data)); + } + return bytes4(data); + } +} diff --git a/src/factory/MultiOwnerMSCAFactory.sol b/src/factory/MultiOwnerMSCAFactory.sol new file mode 100644 index 00000000..3acd4b20 --- /dev/null +++ b/src/factory/MultiOwnerMSCAFactory.sol @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; + +import {IAccountInitializable} from "../interfaces/IAccountInitializable.sol"; +import {IEntryPoint} from "../interfaces/erc4337/IEntryPoint.sol"; + +/// @title Multi Owner Plugin MSCA (Modular Smart Contract Account) Factory +/// @author Alchemy +/// @notice Factory for ERC-6900 compatible upgradeable modular accounts with MultiOwnerPlugin installed. +/// @dev There is a reliance on the assumption that the plugin manifest will remain static, following ERC-6900. If +/// this assumption is broken then account deployments would be bricked. +contract MultiOwnerMSCAFactory is Ownable { + address public immutable MULTI_OWNER_PLUGIN; + address public immutable IMPL; + bytes32 internal immutable _MULTI_OWNER_PLUGIN_MANIFEST_HASH; + IEntryPoint public immutable ENTRYPOINT; + + /// @notice Constructor for the factory + constructor( + address owner, + address multiOwnerPlugin, + address implementation, + bytes32 multiOwnerPluginManifestHash, + IEntryPoint entryPoint + ) { + _transferOwnership(owner); + MULTI_OWNER_PLUGIN = multiOwnerPlugin; + IMPL = implementation; + _MULTI_OWNER_PLUGIN_MANIFEST_HASH = multiOwnerPluginManifestHash; + ENTRYPOINT = entryPoint; + } + + /// @notice Allow contract to receive native currency + receive() external payable {} + + /// @notice Create a modular smart contract account + /// @dev Account address depends on salt, impl addr, plugins and plugin init data + /// @param salt salt for additional entropy for create2 + /// @param owners address array of the owners + function createAccount(uint256 salt, address[] calldata owners) external returns (address addr) { + bytes[] memory pluginInitBytes = new bytes[](1); + pluginInitBytes[0] = abi.encode(owners); + + bytes32 combinedSalt = _getSalt(salt, pluginInitBytes[0]); + addr = Create2.computeAddress( + combinedSalt, keccak256(abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(IMPL, ""))) + ); + + // short circuit if exists + if (addr.code.length == 0) { + // not necessary to check return addr of this arg since next call fails if so + new ERC1967Proxy{salt : combinedSalt}(IMPL, ""); + + address[] memory plugins = new address[](1); + plugins[0] = MULTI_OWNER_PLUGIN; + + bytes32[] memory manifestHashes = new bytes32[](1); + manifestHashes[0] = _MULTI_OWNER_PLUGIN_MANIFEST_HASH; + + IAccountInitializable(addr).initialize(plugins, abi.encode(manifestHashes, pluginInitBytes)); + } + } + + /// @notice Add stake to an entry point + /// @dev only callable by owner + /// @param unstakeDelay unstake delay for the stake + /// @param amount amount of native currency to stake + function addStake(uint32 unstakeDelay, uint256 amount) external payable onlyOwner { + ENTRYPOINT.addStake{value: amount}(unstakeDelay); + } + + /// @notice Start unlocking stake for an entry point + /// @dev only callable by owner + function unlockStake() external onlyOwner { + ENTRYPOINT.unlockStake(); + } + + /// @notice Withdraw stake from an entry point + /// @dev only callable by owner + /// @param to address to send native currency to + function withdrawStake(address payable to) external onlyOwner { + ENTRYPOINT.withdrawStake(to); + } + + /// @notice Withdraw funds from this contract + /// @dev can withdraw stuck erc20s + /// @param to address to send native currency to + /// @param token address of the token to withdraw, 0 address for native currency + /// @param amount amount of the token to withdraw in case of rebasing tokens + function withdraw(address payable to, address token, uint256 amount) external onlyOwner { + if (token == address(0)) { + to.transfer(address(this).balance); + } else { + SafeERC20.safeTransfer(IERC20(token), to, amount); + } + } + + /// @notice Getter for counterfactual address based on input params + /// @param salt salt for additional entropy for create2 + /// @param owners array of addresses of the owner + function getAddress(uint256 salt, address[] calldata owners) external view returns (address) { + return Create2.computeAddress( + _getSalt(salt, abi.encode(owners)), + keccak256(abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(IMPL, ""))) + ); + } + + /// @notice Gets this factory's create2 salt based on the input params + /// @param salt additional entropy for create2 + /// @param owners encoded bytes array of owner addresses + function _getSalt(uint256 salt, bytes memory owners) internal pure returns (bytes32) { + return keccak256(abi.encode(salt, owners)); + } +} diff --git a/src/factory/MultiOwnerTokenReceiverMSCAFactory.sol b/src/factory/MultiOwnerTokenReceiverMSCAFactory.sol new file mode 100644 index 00000000..b5c47716 --- /dev/null +++ b/src/factory/MultiOwnerTokenReceiverMSCAFactory.sol @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Create2} from "@openzeppelin/contracts/utils/Create2.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; + +import {IAccountInitializable} from "../interfaces/IAccountInitializable.sol"; +import {IEntryPoint} from "../interfaces/erc4337/IEntryPoint.sol"; + +/// @title Multi Owner Plugin + Token Receiver MSCA (Modular Smart Contract Account) Factory +/// @author Alchemy +/// @notice Factory for ERC-6900 compatible upgradeable modular accounts with MultiOwnerPlugin and TokenReceiver +/// installed. +/// @dev There is a reliance on the assumption that the plugin manifest will remain static, following ERC-6900. If +/// this assumption is broken then account deployments would be bricked. +contract MultiOwnerTokenReceiverMSCAFactory is Ownable { + address public immutable MULTI_OWNER_PLUGIN; + address public immutable TOKEN_RECEIVER_PLUGIN; + address public immutable IMPL; + bytes32 internal immutable _MULTI_OWNER_PLUGIN_MANIFEST_HASH; + bytes32 internal immutable _TOKEN_RECEIVER_PLUGIN_MANIFEST_HASH; + IEntryPoint public immutable ENTRYPOINT; + + /// @notice Constructor for the factory + constructor( + address owner, + address multiOwnerPlugin, + address tokenReceiverPlugin, + address implementation, + bytes32 multiOwnerPluginManifestHash, + bytes32 tokenReceiverPluginManifestHash, + IEntryPoint entryPoint + ) { + _transferOwnership(owner); + MULTI_OWNER_PLUGIN = multiOwnerPlugin; + TOKEN_RECEIVER_PLUGIN = tokenReceiverPlugin; + IMPL = implementation; + _MULTI_OWNER_PLUGIN_MANIFEST_HASH = multiOwnerPluginManifestHash; + _TOKEN_RECEIVER_PLUGIN_MANIFEST_HASH = tokenReceiverPluginManifestHash; + ENTRYPOINT = entryPoint; + } + + /// @notice Allow contract to receive native currency + receive() external payable {} + + /// @notice Create a modular smart contract account + /// @dev Account address depends on salt, impl addr, plugins and plugin init data + /// @param salt salt for additional entropy for create2 + /// @param owners address array of the owners + function createAccount(uint256 salt, address[] calldata owners) external returns (address addr) { + bytes[] memory pluginInitBytes = new bytes[](2); // empty bytes for TokenReceiverPlugin init + pluginInitBytes[0] = abi.encode(owners); + + bytes32 combinedSalt = _getSalt(salt, pluginInitBytes[0]); + addr = Create2.computeAddress( + combinedSalt, keccak256(abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(IMPL, ""))) + ); + + // short circuit if exists + if (addr.code.length == 0) { + // not necessary to check return addr of this arg since next call fails if so + new ERC1967Proxy{salt : combinedSalt}(IMPL, ""); + + address[] memory plugins = new address[](2); + plugins[0] = MULTI_OWNER_PLUGIN; + plugins[1] = TOKEN_RECEIVER_PLUGIN; + + bytes32[] memory manifestHashes = new bytes32[](2); + manifestHashes[0] = _MULTI_OWNER_PLUGIN_MANIFEST_HASH; + manifestHashes[1] = _TOKEN_RECEIVER_PLUGIN_MANIFEST_HASH; + + IAccountInitializable(addr).initialize(plugins, abi.encode(manifestHashes, pluginInitBytes)); + } + } + + /// @notice Add stake to an entry point + /// @dev only callable by owner + /// @param unstakeDelay unstake delay for the stake + /// @param amount amount of native currency to stake + function addStake(uint32 unstakeDelay, uint256 amount) external payable onlyOwner { + ENTRYPOINT.addStake{value: amount}(unstakeDelay); + } + + /// @notice Start unlocking stake for an entry point + /// @dev only callable by owner + function unlockStake() external onlyOwner { + ENTRYPOINT.unlockStake(); + } + + /// @notice Withdraw stake from an entry point + /// @dev only callable by owner + /// @param to address to send native currency to + function withdrawStake(address payable to) external onlyOwner { + ENTRYPOINT.withdrawStake(to); + } + + /// @notice Withdraw funds from this contract + /// @dev can withdraw stuck erc20s + /// @param to address to send native currency to + /// @param token address of the token to withdraw, 0 address for native currency + /// @param amount amount of the token to withdraw in case of rebasing tokens + function withdraw(address payable to, address token, uint256 amount) external onlyOwner { + if (token == address(0)) { + to.transfer(address(this).balance); + } else { + SafeERC20.safeTransfer(IERC20(token), to, amount); + } + } + + /// @notice Getter for counterfactual address based on input params + /// @param salt salt for additional entropy for create2 + /// @param owners array of addresses of the owner + function getAddress(uint256 salt, address[] calldata owners) external view returns (address) { + return Create2.computeAddress( + _getSalt(salt, abi.encode(owners)), + keccak256(abi.encodePacked(type(ERC1967Proxy).creationCode, abi.encode(IMPL, ""))) + ); + } + + /// @notice Gets this factory's create2 salt based on the input params + /// @param salt additional entropy for create2 + /// @param owners encoded bytes array of owner addresses + function _getSalt(uint256 salt, bytes memory owners) internal pure returns (bytes32) { + return keccak256(abi.encode(salt, owners)); + } +} diff --git a/src/helpers/KnownSelectors.sol b/src/helpers/KnownSelectors.sol new file mode 100644 index 00000000..3395785d --- /dev/null +++ b/src/helpers/KnownSelectors.sol @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; +import {UUPSUpgradeable} from "../../ext/UUPSUpgradeable.sol"; + +import {IAccount} from "../../src/interfaces/erc4337/IAccount.sol"; +import {IAccountInitializable} from "../interfaces/IAccountInitializable.sol"; +import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; +import {IAccountView} from "../../src/interfaces/IAccountView.sol"; +import {IAggregator} from "../../src/interfaces/erc4337/IAggregator.sol"; +import {IPaymaster} from "../../src/interfaces/erc4337/IPaymaster.sol"; +import {IPluginExecutor} from "../interfaces/IPluginExecutor.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {IStandardExecutor} from "../interfaces/IStandardExecutor.sol"; + +/// @title Known Selectors +/// @author Alchemy +/// @notice Library to help to check if a selector is a know function selector of the modular account or ERC-4337 +/// contract. +library KnownSelectors { + function isNativeFunction(bytes4 selector) internal pure returns (bool) { + return + // check against IAccount methods + selector == IAccount.validateUserOp.selector + // check against IAccountView methods + || selector == IAccountView.entryPoint.selector || selector == IAccountView.getNonce.selector + // check against IPluginManager methods + || selector == IPluginManager.installPlugin.selector || selector == IPluginManager.uninstallPlugin.selector + // check against IERC165 methods + || selector == IERC165.supportsInterface.selector + // check against UUPSUpgradeable methods + || selector == UUPSUpgradeable.proxiableUUID.selector + || selector == UUPSUpgradeable.upgradeToAndCall.selector + // check against IStandardExecutor methods + || selector == IStandardExecutor.execute.selector || selector == IStandardExecutor.executeBatch.selector + // check against IPluginExecutor methods + || selector == IPluginExecutor.executeFromPlugin.selector + || selector == IPluginExecutor.executeFromPluginExternal.selector + // check against IAccountInitializable methods + || selector == IAccountInitializable.initialize.selector + // check against IAccountLoupe methods + || selector == IAccountLoupe.getExecutionFunctionConfig.selector + || selector == IAccountLoupe.getExecutionHooks.selector + || selector == IAccountLoupe.getPermittedCallHooks.selector + || selector == IAccountLoupe.getPreValidationHooks.selector + || selector == IAccountLoupe.getInstalledPlugins.selector; + } + + function isErc4337Function(bytes4 selector) internal pure returns (bool) { + return selector == IAggregator.validateSignatures.selector + || selector == IAggregator.validateUserOpSignature.selector + || selector == IAggregator.aggregateSignatures.selector + || selector == IPaymaster.validatePaymasterUserOp.selector || selector == IPaymaster.postOp.selector; + } +} diff --git a/src/helpers/ValidationDataHelpers.sol b/src/helpers/ValidationDataHelpers.sol new file mode 100644 index 00000000..92c83baf --- /dev/null +++ b/src/helpers/ValidationDataHelpers.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +// solhint-disable-next-line private-vars-leading-underscore +function _coalescePreValidation(uint256 validationData1, uint256 validationData2) + pure + returns (uint256 resValidationData) +{ + uint48 validUntil1 = uint48(validationData1 >> 160); + if (validUntil1 == 0) { + validUntil1 = type(uint48).max; + } + uint48 validUntil2 = uint48(validationData2 >> 160); + if (validUntil2 == 0) { + validUntil2 = type(uint48).max; + } + resValidationData = ((validUntil1 > validUntil2) ? uint256(validUntil2) << 160 : uint256(validUntil1) << 160); + + uint48 validAfter1 = uint48(validationData1 >> 208); + uint48 validAfter2 = uint48(validationData2 >> 208); + + resValidationData |= ((validAfter1 < validAfter2) ? uint256(validAfter2) << 208 : uint256(validAfter1) << 208); + + // Once we know that the authorizer field is 0 or 1, we can safely bubble up SIG_FAIL with bitwise OR + resValidationData |= uint160(validationData1) | uint160(validationData2); +} + +// solhint-disable-next-line private-vars-leading-underscore +function _coalesceValidation(uint256 preValidationData, uint256 validationData) + pure + returns (uint256 resValidationData) +{ + uint48 validUntil1 = uint48(preValidationData >> 160); + if (validUntil1 == 0) { + validUntil1 = type(uint48).max; + } + uint48 validUntil2 = uint48(validationData >> 160); + if (validUntil2 == 0) { + validUntil2 = type(uint48).max; + } + resValidationData = ((validUntil1 > validUntil2) ? uint256(validUntil2) << 160 : uint256(validUntil1) << 160); + + uint48 validAfter1 = uint48(preValidationData >> 208); + uint48 validAfter2 = uint48(validationData >> 208); + + resValidationData |= ((validAfter1 < validAfter2) ? uint256(validAfter2) << 208 : uint256(validAfter1) << 208); + + // If prevalidation failed, bubble up failure + resValidationData |= uint160(preValidationData) == 1 ? 1 : uint160(validationData); +} diff --git a/src/interfaces/IAccountInitializable.sol b/src/interfaces/IAccountInitializable.sol new file mode 100644 index 00000000..16d08ece --- /dev/null +++ b/src/interfaces/IAccountInitializable.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +/// @title Account Initializable Interface +interface IAccountInitializable { + /// @notice Initializes the account with a set of plugins + /// @dev No dependencies or hooks can be injected with this installation + /// @param plugins The plugins to install + /// @param pluginInitData The plugin init data for each plugin + function initialize(address[] calldata plugins, bytes calldata pluginInitData) external; +} diff --git a/src/interfaces/IAccountLoupe.sol b/src/interfaces/IAccountLoupe.sol new file mode 100644 index 00000000..aeebb005 --- /dev/null +++ b/src/interfaces/IAccountLoupe.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; + +/// @title Account Loupe Interface +interface IAccountLoupe { + /// @notice Config for an execution function, given a selector + struct ExecutionFunctionConfig { + address plugin; + FunctionReference userOpValidationFunction; + FunctionReference runtimeValidationFunction; + } + + /// @notice Pre and post hooks for a given selector + /// @dev It's possible for one of either `preExecHook` or `postExecHook` to be empty + struct ExecutionHooks { + FunctionReference preExecHook; + FunctionReference postExecHook; + } + + /// @notice Gets the validation functions and plugin address for a selector + /// @dev If the selector is a native function, the plugin address will be the address of the account + /// @param selector The selector to get the configuration for + /// @return The configuration for this selector + function getExecutionFunctionConfig(bytes4 selector) external view returns (ExecutionFunctionConfig memory); + + /// @notice Gets the pre and post execution hooks for a selector + /// @param selector The selector to get the hooks for + /// @return The pre and post execution hooks for this selector + function getExecutionHooks(bytes4 selector) external view returns (ExecutionHooks[] memory); + + /// @notice Gets the pre and post permitted call hooks applied for a plugin calling this selector + /// @param callingPlugin The plugin that is calling the selector + /// @param selector The selector the plugin is calling + /// @return The pre and post permitted call hooks for this selector + function getPermittedCallHooks(address callingPlugin, bytes4 selector) + external + view + returns (ExecutionHooks[] memory); + + /// @notice Gets the pre user op and runtime validation hooks associated with a selector + /// @param selector The selector to get the hooks for + /// @return preUserOpValidationHooks The pre user op validation hooks for this selector + /// @return preRuntimeValidationHooks The pre runtime validation hooks for this selector + function getPreValidationHooks(bytes4 selector) + external + view + returns ( + FunctionReference[] memory preUserOpValidationHooks, + FunctionReference[] memory preRuntimeValidationHooks + ); + + /// @notice Gets an array of all installed plugins + /// @return The addresses of all installed plugins + function getInstalledPlugins() external view returns (address[] memory); +} diff --git a/src/interfaces/IAccountView.sol b/src/interfaces/IAccountView.sol new file mode 100644 index 00000000..f4ea3dcf --- /dev/null +++ b/src/interfaces/IAccountView.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IEntryPoint} from "./erc4337/IEntryPoint.sol"; + +/// @title Account View Interface +interface IAccountView { + /// @notice Gets the entry point for this account + /// @return entryPoint The entry point for this account + function entryPoint() external view returns (IEntryPoint); + + /// @notice Get the account nonce. + /// @dev uses key 0 + /// @return nonce The next account nonce. + function getNonce() external view returns (uint256); +} diff --git a/src/interfaces/IPlugin.sol b/src/interfaces/IPlugin.sol new file mode 100644 index 00000000..b5b2e2c1 --- /dev/null +++ b/src/interfaces/IPlugin.sol @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IPluginManager} from "./IPluginManager.sol"; + +import {UserOperation} from "../interfaces/erc4337/UserOperation.sol"; + +// Forge formatter will displace the first comment for the enum field out of the enum itself, +// so annotating here to prevent that. +// forgefmt: disable-start +enum ManifestAssociatedFunctionType { + // Function is not defined. + NONE, + // Function belongs to this plugin. + SELF, + // Function belongs to an external plugin provided as a dependency during plugin installation. + DEPENDENCY, + // Resolves to a magic value to always bypass runtime validation for a given function. + // This is only assignable on runtime validation functions. If it were to be used on a user op validationFunction, + // it would risk burning gas from the account. When used as a hook in any hook location, it is equivalent to not + // setting a hook and is therefore disallowed. + RUNTIME_VALIDATION_ALWAYS_ALLOW, + // Resolves to a magic value to always fail in a hook for a given function. + // This is only assignable to pre hooks (pre validation and pre execution). It should not be used on + // validation functions themselves, because this is equivalent to leaving the validation functions unset. + // It should not be used in post-exec hooks, because if it is known to always revert, that should happen + // as early as possible to save gas. + PRE_HOOK_ALWAYS_DENY +} +// forgefmt: disable-end + +/// @dev For functions of type `ManifestAssociatedFunctionType.DEPENDENCY`, the MSCA MUST find the plugin address +/// of the function at `dependencies[dependencyIndex]` during the call to `installPlugin(config)`. +struct ManifestFunction { + ManifestAssociatedFunctionType functionType; + uint8 functionId; + uint256 dependencyIndex; +} + +struct ManifestAssociatedFunction { + bytes4 executionSelector; + ManifestFunction associatedFunction; +} + +struct ManifestExecutionHook { + bytes4 executionSelector; + ManifestFunction preExecHook; + ManifestFunction postExecHook; +} + +struct ManifestExternalCallPermission { + address externalAddress; + bool permitAnySelector; + bytes4[] selectors; +} + +/// @dev A struct describing how the plugin should be installed on a modular account. +struct PluginManifest { + // List of ERC-165 interfaceIds to add to account to support introspection checks. + bytes4[] interfaceIds; + // If this plugin depends on other plugins' validation functions and/or hooks, the interface IDs of + // those plugins MUST be provided here, with its position in the array matching the `dependencyIndex` + // members of `ManifestFunction` structs used in the manifest. + bytes4[] dependencyInterfaceIds; + // Execution functions defined in this plugin to be installed on the MSCA. + bytes4[] executionFunctions; + // Plugin execution functions already installed on the MSCA that this plugin will be able to call. + bytes4[] permittedExecutionSelectors; + // External addresses that this plugin will be able to call. + bool permitAnyExternalAddress; + // boolean to indicate whether the plugin needs access to spend native tokens of the account + bool canSpendNativeToken; + ManifestExternalCallPermission[] permittedExternalCalls; + ManifestAssociatedFunction[] userOpValidationFunctions; + ManifestAssociatedFunction[] runtimeValidationFunctions; + ManifestAssociatedFunction[] preUserOpValidationHooks; + ManifestAssociatedFunction[] preRuntimeValidationHooks; + ManifestExecutionHook[] executionHooks; + ManifestExecutionHook[] permittedCallHooks; +} + +/// @dev A struct holding fields to describe the plugin in a purely view context. Intended for front end clients. +struct PluginMetadata { + // A human-readable name of the plugin. + string name; + // The version of the plugin, following the semantic versioning scheme. + string version; + // The author field SHOULD be a username representing the identity of the user or organization + // that created this plugin. + string author; + // String desciptions of the relative sensitivity of specific functions. The selectors MUST be selectors for + // functions implemented by this plugin. + SelectorPermission[] permissionDescriptors; +} + +struct SelectorPermission { + bytes4 functionSelector; + string permissionDescription; +} + +/// @title Plugin Interface +interface IPlugin { + /// @notice Initialize plugin data for the modular account. + /// @dev Called by the modular account during `installPlugin`. + /// @param data Optional bytes array to be decoded and used by the plugin to setup initial plugin data for the + /// modular account. + function onInstall(bytes calldata data) external; + + /// @notice Clear plugin data for the modular account. + /// @dev Called by the modular account during `uninstallPlugin`. + /// @param data Optional bytes array to be decoded and used by the plugin to clear plugin data for the modular + /// account. + function onUninstall(bytes calldata data) external; + + /// @notice Run the pre user operation validation hook specified by the `functionId`. + /// @dev Pre user operation validation hooks MUST NOT return an authorizer value other than 0 or 1. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param userOp The user operation. + /// @param userOpHash The user operation hash. + /// @return Packed validation data for validAfter (6 bytes), validUntil (6 bytes), and authorizer (20 bytes). + function preUserOpValidationHook(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + returns (uint256); + + /// @notice Run the user operation validationFunction specified by the `functionId`. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param userOp The user operation. + /// @param userOpHash The user operation hash. + /// @return Packed validation data for validAfter (6 bytes), validUntil (6 bytes), and authorizer (20 bytes). + function userOpValidationFunction(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + returns (uint256); + + /// @notice Run the pre runtime validation hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data) + external; + + /// @notice Run the runtime validationFunction specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + function runtimeValidationFunction(uint8 functionId, address sender, uint256 value, bytes calldata data) + external; + + /// @notice Run the pre execution hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + /// @return Context to pass to a post execution hook, if present. An empty bytes array MAY be returned. + function preExecutionHook(uint8 functionId, address sender, uint256 value, bytes calldata data) + external + returns (bytes memory); + + /// @notice Run the post execution hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param preExecHookData The context returned by its associated pre execution hook. + function postExecutionHook(uint8 functionId, bytes calldata preExecHookData) external; + + /// @notice A hook that runs when a hook this plugin owns is installed onto another plugin + /// @dev Optional, use to implement any required setup logic + /// @param pluginAppliedOn The plugin that the hook is being applied on + /// @param injectedHooksInfo Contains pre/post exec hook information + /// @param data Any optional data for setup + function onHookApply( + address pluginAppliedOn, + IPluginManager.InjectedHooksInfo calldata injectedHooksInfo, + bytes calldata data + ) external; + + /// @notice A hook that runs when a hook this plugin owns is unapplied from another plugin + /// @dev Optional, use to implement any required unapplied logic + /// @param pluginAppliedOn The plugin that the hook was applied on + /// @param injectedHooksInfo Contains pre/post exec hook information + /// @param data Any optional data for the unapplied call + function onHookUnapply( + address pluginAppliedOn, + IPluginManager.InjectedHooksInfo calldata injectedHooksInfo, + bytes calldata data + ) external; + + /// @notice Describe the contents and intended configuration of the plugin. + /// @dev This manifest MUST stay constant over time. + /// @return A manifest describing the contents and intended configuration of the plugin. + function pluginManifest() external pure returns (PluginManifest memory); + + /// @notice Describe the metadata of the plugin. + /// @dev This metadata MUST stay constant over time. + /// @return A metadata struct describing the plugin. + function pluginMetadata() external pure returns (PluginMetadata memory); +} diff --git a/src/interfaces/IPluginExecutor.sol b/src/interfaces/IPluginExecutor.sol new file mode 100644 index 00000000..07717c02 --- /dev/null +++ b/src/interfaces/IPluginExecutor.sol @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +/// @title Plugin Executor Interface +interface IPluginExecutor { + /// @notice Method from cals made from plugins to other plugin execution functions. Plugins are not allowed to + /// call accounts native functions. + /// @dev Permissions must be granted to the calling plugin for the call to go through + /// @param data The call data for the call. + /// @return The return data from the call. + function executeFromPlugin(bytes calldata data) external payable returns (bytes memory); + + /// @notice Method from calls made from plugins to external addresses. + /// @dev If the target is a plugin, the call SHOULD revert. Permissions must be granted to the calling plugin + /// for the call to go through + /// @param target The address to be called. + /// @param value The value to pass. + /// @param data The data to pass. + /// @return The result of the call + function executeFromPluginExternal(address target, uint256 value, bytes calldata data) + external + payable + returns (bytes memory); +} diff --git a/src/interfaces/IPluginManager.sol b/src/interfaces/IPluginManager.sol new file mode 100644 index 00000000..8ac143b5 --- /dev/null +++ b/src/interfaces/IPluginManager.sol @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; + +/// @title Plugin Manager Interface +interface IPluginManager { + /// @dev Pre/post exec hooks added by the user to limit the scope of a plugin. These hooks are injected at + /// plugin install time + struct InjectedHook { + // The plugin that provides the hook + address providingPlugin; + // Either a plugin-defined execution function, or the native function executeFromPluginExternal + bytes4 selector; + InjectedHooksInfo injectedHooksInfo; + bytes hookApplyData; + } + + struct InjectedHooksInfo { + uint8 preExecHookFunctionId; + bool isPostHookUsed; + uint8 postExecHookFunctionId; + } + + /// @dev Note that we strip hookApplyData from InjectedHooks in this event for gas savings + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + InjectedHook[] injectedHooks + ); + + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + event PluginIgnoredHookUnapplyCallbackFailure(address indexed plugin, address indexed providingPlugin); + event PluginIgnoredUninstallCallbackFailure(address indexed plugin); + + /// @notice Install a plugin to the modular account. + /// @param plugin The plugin to install. + /// @param manifestHash The hash of the plugin manifest. + /// @param pluginInitData Optional data to be decoded and used by the plugin to setup initial plugin data for + /// the modular account. + /// @param dependencies The dependencies of the plugin, as described in the manifest. + /// @param injectedHooks Optional hooks to be injected over permitted calls this plugin may make. Alchemy + /// Accounts only support injected permitted call hooks. + function installPlugin( + address plugin, + bytes32 manifestHash, + bytes calldata pluginInitData, + FunctionReference[] calldata dependencies, + InjectedHook[] calldata injectedHooks + ) external; + + /// @notice Uninstall a plugin from the modular account. + /// @dev Uninstalling owner plugins outside of a replace operation via executeBatch risks losing the account! + /// @param plugin The plugin to uninstall. + /// @param config An optional, implementation-specific field that accounts may use to ensure consistency + /// guarantees. + /// @param pluginUninstallData Optional data to be decoded and used by the plugin to clear plugin data for the + /// modular account. + /// @param hookUnapplyData Optional data to be decoded and used by the plugin to clear injected hooks for the + /// modular account. + function uninstallPlugin( + address plugin, + bytes calldata config, + bytes calldata pluginUninstallData, + bytes[] calldata hookUnapplyData + ) external; +} diff --git a/src/interfaces/IStandardExecutor.sol b/src/interfaces/IStandardExecutor.sol new file mode 100644 index 00000000..dd457435 --- /dev/null +++ b/src/interfaces/IStandardExecutor.sol @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +struct Call { + // The target address for account to call. + address target; + // The value sent with the call. + uint256 value; + // The call data for the call. + bytes data; +} + +/// @title Standard Executor Interface +interface IStandardExecutor { + /// @notice Standard execute method. + /// @dev If the target is a plugin, the call SHOULD revert. + /// @param target The target address for account to call. + /// @param value The value sent with the call. + /// @param data The call data for the call. + /// @return The return data from the call. + function execute(address target, uint256 value, bytes calldata data) external payable returns (bytes memory); + + /// @notice Standard executeBatch method. + /// @dev If the target is a plugin, the call SHOULD revert. If any of the transactions revert, the entire batch + /// reverts + /// @param calls The array of calls. + /// @return An array containing the return data from the calls. + function executeBatch(Call[] calldata calls) external payable returns (bytes[] memory); +} diff --git a/src/interfaces/erc4337/IAccount.sol b/src/interfaces/erc4337/IAccount.sol new file mode 100644 index 00000000..2b10f0ef --- /dev/null +++ b/src/interfaces/erc4337/IAccount.sol @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IEntryPoint} from "./IEntryPoint.sol"; +import {UserOperation} from "./UserOperation.sol"; + +/// @notice Interface for the ERC-4337 account +interface IAccount { + /// @notice Validates a user operation, presumably by checking the signature and nonce. The entry point will + /// call this function to ensure that a user operation sent to it has been authorized, and thus that it should + /// call the account with the operation's call data and charge the account for gas in the absense of a + /// paymaster. If the signature is correctly formatted but invalid, this should return 1; other failures may + /// revert instead. In the case of a success, this can optionally return a signature aggregator and/or a time + /// range during which the operation is valid. + /// @param userOp the operation to be validated + /// @param userOpHash hash of the operation + /// @param missingAccountFunds amount that the account must send to the entry point as part of validation to + /// pay for gas + /// @return validationData Either 1 for an invalid signature, or a packed structure containing an optional + /// aggregator address in the first 20 bytes followed by two 6-byte timestamps representing the "validUntil" + /// and "validAfter" times at which the operation is valid (a "validUntil" of 0 means it is valid forever). + function validateUserOp(UserOperation calldata userOp, bytes32 userOpHash, uint256 missingAccountFunds) + external + returns (uint256 validationData); +} diff --git a/src/interfaces/erc4337/IAggregator.sol b/src/interfaces/erc4337/IAggregator.sol new file mode 100644 index 00000000..6db959a0 --- /dev/null +++ b/src/interfaces/erc4337/IAggregator.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "./UserOperation.sol"; + +/// @notice Interface for the ERC-4337 aggregator +interface IAggregator { + function validateSignatures(UserOperation[] calldata, bytes calldata) external view; + function validateUserOpSignature(UserOperation calldata) external view returns (bytes memory); + function aggregateSignatures(UserOperation[] calldata) external view returns (bytes memory); +} diff --git a/src/interfaces/erc4337/IEntryPoint.sol b/src/interfaces/erc4337/IEntryPoint.sol new file mode 100644 index 00000000..1e011e0d --- /dev/null +++ b/src/interfaces/erc4337/IEntryPoint.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "./UserOperation.sol"; + +/// @notice Interface for the ERC-4337 entry point +interface IEntryPoint { + error FailedOp(uint256 i, string s); + + function depositTo(address) external payable; + function addStake(uint32) external payable; + function unlockStake() external; + function withdrawStake(address payable) external; + function handleOps(UserOperation[] calldata, address payable) external; + function getNonce(address, uint192) external view returns (uint256); + function getUserOpHash(UserOperation calldata) external view returns (bytes32); +} diff --git a/src/interfaces/erc4337/IPaymaster.sol b/src/interfaces/erc4337/IPaymaster.sol new file mode 100644 index 00000000..43066c9a --- /dev/null +++ b/src/interfaces/erc4337/IPaymaster.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "./UserOperation.sol"; + +/// @notice Interface for the ERC-4337 paymaster +interface IPaymaster { + enum PostOpMode { + opSucceeded, + opReverted, + postOpReverted + } + + function validatePaymasterUserOp(UserOperation calldata, bytes32, uint256) + external + returns (bytes memory, uint256); + + function postOp(PostOpMode, bytes calldata, uint256) external; +} diff --git a/src/interfaces/erc4337/UserOperation.sol b/src/interfaces/erc4337/UserOperation.sol new file mode 100644 index 00000000..1a852ea1 --- /dev/null +++ b/src/interfaces/erc4337/UserOperation.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +/// @notice User Operation struct as defined in ERC-4337 +struct UserOperation { + address sender; + uint256 nonce; + bytes initCode; + bytes callData; + uint256 callGasLimit; + uint256 verificationGasLimit; + uint256 preVerificationGas; + uint256 maxFeePerGas; + uint256 maxPriorityFeePerGas; + bytes paymasterAndData; + bytes signature; +} diff --git a/src/libraries/AccountStorageV1.sol b/src/libraries/AccountStorageV1.sol new file mode 100644 index 00000000..4bf00b63 --- /dev/null +++ b/src/libraries/AccountStorageV1.sol @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IPlugin} from "../interfaces/IPlugin.sol"; + +import {FunctionReference} from "../libraries/FunctionReferenceLib.sol"; +import {LinkedListSet} from "../libraries/LinkedListSetLib.sol"; + +/// @title Account Storage V1 +/// @author Alchemy +/// @notice Contains the storage layout for upgradeable modular accounts. +contract AccountStorageV1 { + /// @custom:storage-location erc7201:Alchemy.UpgradeableModularAccount.Storage_V1 + struct AccountStorage { + // AccountStorageInitializable variables + uint8 initialized; + bool initializing; + // Plugin metadata storage + LinkedListSet plugins; + mapping(address => PluginData) pluginData; + // Execution functions and their associated functions + mapping(bytes4 => SelectorData) selectorData; + // bytes24 key = address(calling plugin) || bytes4(selector of execution function) + mapping(bytes24 => PermittedCallData) permittedCalls; + // key = address(calling plugin) || target address + mapping(IPlugin => mapping(address => PermittedExternalCallData)) permittedExternalCalls; + // For ERC165 introspection, each count indicates support from account or an installed plugin + // 0 indicate the account does not support the interface and all plugins that support this interface have + // been uninstalled + mapping(bytes4 => uint256) supportedInterfaces; + } + + struct PluginData { + bool anyExternalAddressPermitted; + // boolean to indicate if the plugin can spend native tokens, if any of the execution function can spend + // native tokens, a plugin is considered to be able to spend native tokens of the accounts + bool canSpendNativeToken; + bytes32 manifestHash; + FunctionReference[] dependencies; + // Tracks the number of times this plugin has been used as a dependency function + uint256 dependentCount; + StoredInjectedHook[] injectedHooks; + } + + /// @dev A version of IPliginManager.InjectedHook used to track injected hooks in storage. Omits the + /// hookApplyData field, which is not needed for storage, and flattens the struct. + struct StoredInjectedHook { + // The plugin that provides the hook + address providingPlugin; + // Either a plugin-defined execution function, or the native function executeFromPluginExternal + bytes4 selector; + // Contents of the InjectedHooksInfo struct + uint8 preExecHookFunctionId; + bool isPostHookUsed; + uint8 postExecHookFunctionId; + } + + /// @dev Represents data associated with a plugin's permission to use `executeFromPlugin` to interact with + /// another plugin installed on the account. + struct PermittedCallData { + bool callPermitted; + // Cached flags indicating whether or not this function has pre permitted call hooks and + // post-only permitted call hooks. + bool hasPrePermittedCallHooks; + bool hasPostOnlyPermittedCallHooks; + HookGroup permittedCallHooks; + } + + /// @dev Represents data associated with a plugin's permission to use `executeFromPluginExternal` to interact + /// with contracts and addresses external to the account and its plugins. + struct PermittedExternalCallData { + // Is this address on the permitted addresses list? If it is, we either have a + // list of allowed selectors, or the flag that allows any selector. + bool addressPermitted; + bool anySelectorPermitted; + mapping(bytes4 => bool) permittedSelectors; + } + + struct HookGroup { + // NOTE: this uses the flag _PRE_EXEC_HOOK_HAS_POST_FLAG to indicate whether + // an element has an associated post-exec hook. + LinkedListSet preHooks; + // bytes21 key = pre exec hook function reference + mapping(FunctionReference => LinkedListSet) associatedPostHooks; + LinkedListSet postOnlyHooks; + } + + /// @dev Represents data associated with a specifc function selector. + struct SelectorData { + // The plugin that implements this execution function. + // If this is a native function, the address must remain address(0). + address plugin; + // Cached flags indicating whether or not this function has pre-execution hooks and + // post-only hooks. Flags for pre-validation hooks stored in the same storage word + // as the validation function itself, to use a warm storage slot when loading. + bool hasPreExecHooks; + bool hasPostOnlyExecHooks; + // The specified validation functions for this function selector. + FunctionReference userOpValidation; + bool hasPreUserOpValidationHooks; + FunctionReference runtimeValidation; + bool hasPreRuntimeValidationHooks; + // The pre validation hooks for this function selector. + LinkedListSet preUserOpValidationHooks; + LinkedListSet preRuntimeValidationHooks; + // The execution hooks for this function selector. + HookGroup executionHooks; + } + + /// @dev the same storage slot will be used versions V1.x.y of upgradeable modular accounts. Follows ERC-7201. + /// bytes = keccak256( + /// abi.encode(uint256(keccak256("Alchemy.UpgradeableModularAccount.Storage_V1")) - 1) + /// ) & ~bytes32(uint256(0xff)); + bytes32 internal constant _V1_STORAGE_SLOT = 0xade46bbfcf6f898a43d541e42556d456ca0bf9b326df8debc0f29d3f811a0300; + + function _getAccountStorage() internal pure returns (AccountStorage storage storage_) { + assembly { + storage_.slot := _V1_STORAGE_SLOT + } + } + + function _getPermittedCallKey(address addr, bytes4 selector) internal pure returns (bytes24) { + return bytes24(bytes20(addr)) | (bytes24(selector) >> 160); + } +} diff --git a/src/libraries/AssociatedLinkedListSetLib.sol b/src/libraries/AssociatedLinkedListSetLib.sol new file mode 100644 index 00000000..51a0c781 --- /dev/null +++ b/src/libraries/AssociatedLinkedListSetLib.sol @@ -0,0 +1,501 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {SetValue, SENTINEL_VALUE, HAS_NEXT_FLAG} from "./LinkedListSetUtils.sol"; + +/// @dev Type representing the set, which is just a storage slot placeholder like the solidity mapping type. +struct AssociatedLinkedListSet { + bytes32 placeholder; +} + +/// @title Associated Linked List Set Library +/// @author Alchemy +/// @notice Provides a set data structure that is enumerable and held in address-associated storage (per the +/// ERC-4337 spec) +library AssociatedLinkedListSetLib { + // Mapping Entry Byte Layout + // | value | 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA____ | + // | meta | 0x____________________________________________________________BBBB | + + // Bit-layout of the meta bytes (2 bytes) + // | user flags | 11111111 11111100 | + // | has next | 00000000 00000010 | + // | sentinel | 00000000 00000001 | + + // Mapping keys exclude the upper 15 bits of the meta bytes, which allows keys to be either a value or the + // sentinel. + + bytes4 internal constant _ASSOCIATED_STORAGE_PREFIX = 0x9cc6c923; // bytes4(keccak256("AssociatedLinkedListSet")) + + // A custom type representing the index of a storage slot + type StoragePointer is bytes32; + + // A custom type representing a pointer to a location in memory beyond the current free memory pointer. + // Holds a fixed-size buffer similar to "bytes memory", but without a length field. + // Care must be taken when using these, as they may be overwritten if ANY memory is allocated after allocating + // a TempBytesMemory. + type TempBytesMemory is bytes32; + + // INTERNAL METHODS + + /// @notice Adds a value to a set. + /// @param set The set to add the value to. + /// @param associated The address the set is associated with. + /// @param value The value to add. + /// @return True if the value was added, false if the value cannot be added (already exists or is zero). + function tryAdd(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + if (unwrappedKey == bytes32(0)) { + // Cannot add the zero value + return false; + } + + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + if (_load(valueSlot) != bytes32(0)) { + // Entry already exists + return false; + } + + // Load the head of the set + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 prev = _load(sentinelSlot); + if (prev == bytes32(0) || isSentinel(prev)) { + // set is empty, need to do: + // map[SENTINEL_VALUE] = unwrappedKey; + // map[unwrappedKey] = SENTINEL_VALUE; + _store(sentinelSlot, unwrappedKey); + _store(valueSlot, SENTINEL_VALUE); + } else { + // set is not empty, need to do: + // map[SENTINEL_VALUE] = unwrappedKey | HAS_NEXT_FLAG; + // map[unwrappedKey] = prev; + _store(sentinelSlot, unwrappedKey | HAS_NEXT_FLAG); + _store(valueSlot, prev); + } + + return true; + } + + /// @notice Removes a value from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the value from + /// @param associated The address the set is associated with + /// @param value The value to remove + /// @return True if the value was removed, false if the value does not exist + function tryRemove(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 nextValue = _load(valueSlot); + if (unwrappedKey == bytes32(0) || nextValue == bytes32(0)) { + // Entry does not exist + return false; + } + + bytes32 prevKey = SENTINEL_VALUE; + bytes32 currentVal; + do { + // Load the current entry + StoragePointer prevSlot = _mapLookup(keyBuffer, prevKey); + currentVal = _load(prevSlot); + bytes32 currentKey = clearFlags(currentVal); + if (currentKey == unwrappedKey) { + // Found the entry + // Set the previous value's next value to the next value, + // and the flags to the current value's flags. + // and the next value's `hasNext` flag to determine whether or not the next value is (or points to) + // the sentinel value. + + // Need to do: + // map[prevKey] = clearFlags(nextValue) | getUserFlags(currentVal) | (nextValue & HAS_NEXT_FLAG); + // map[currentKey] = bytes32(0); + + _store(prevSlot, clearFlags(nextValue) | getUserFlags(currentVal) | (nextValue & HAS_NEXT_FLAG)); + _store(valueSlot, bytes32(0)); + + return true; + } + prevKey = currentKey; + } while (!isSentinel(currentVal) && currentVal != bytes32(0)); + return false; + } + + /// @notice Removes a value from a set, given the previous value in the set. + /// @dev This is an O(1) operation but requires additional knowledge. + /// @param set The set to remove the value from + /// @param associated The address the set is associated with + /// @param value The value to remove + /// @param prev The previous value in the set + /// @return True if the value was removed, false if the value does not exist + function tryRemoveKnown(AssociatedLinkedListSet storage set, address associated, SetValue value, bytes32 prev) + internal + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + prev = clearFlags(prev); + + if (prev == bytes32(0) || unwrappedKey == bytes32(0)) { + return false; + } + + // assert that the previous key's next value is the value to be removed + StoragePointer prevSlot = _mapLookup(keyBuffer, prev); + bytes32 currentValue = _load(prevSlot); + if (clearFlags(currentValue) != unwrappedKey) { + return false; + } + + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 next = _load(valueSlot); + if (next == bytes32(0)) { + // The set didn't actually contain the value + return false; + } + + // Need to do: + // map[prev] = clearFlags(next) | getUserFlags(currentValue) | (next & HAS_NEXT_FLAG); + // map[unwrappedKey] = bytes32(0); + _store(prevSlot, clearFlags(next) | getUserFlags(currentValue) | (next & HAS_NEXT_FLAG)); + _store(valueSlot, bytes32(0)); + + return true; + } + + /// @notice Removes all values from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the values from + /// @param associated The address the set is associated with + function clear(AssociatedLinkedListSet storage set, address associated) internal { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + bytes32 cursor = SENTINEL_VALUE; + + do { + bytes32 cleared = clearFlags(cursor); + StoragePointer cursorSlot = _mapLookup(keyBuffer, cleared); + bytes32 next = _load(cursorSlot); + _store(cursorSlot, bytes32(0)); + cursor = next; + } while (!isSentinel(cursor) && cursor != bytes32(0)); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + _store(sentinelSlot, bytes32(0)); + } + + /// @notice Set the flags on a value in the set. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to set the flags on. + /// @param flags The flags to set. + /// @return True if the set contains the value and the operation succeeds, false otherwise. + function trySetFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + bytes32 unwrappedKey = SetValue.unwrap(value); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + // Ignore the lower 2 bits. + flags &= 0xFFFC; + + // If the set doesn't actually contain the value, return false; + StoragePointer valueSlot = _mapLookup(keyBuffer, unwrappedKey); + bytes32 next = _load(valueSlot); + if (next == bytes32(0)) { + return false; + } + + // Set the flags + _store(valueSlot, clearUserFlags(next) | bytes32(uint256(flags))); + + return true; + } + + /// @notice Set the given flags on a value in the set, preserving the values of other flags. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already enabled, returning true. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to enable the flags on. + /// @param flags The flags to enable. + /// @return True if the operation succeeds or short-circuits due to the flags already being enabled. False + /// otherwise. + function tryEnableFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, associated, value); + if (currFlags & flags == flags) return true; // flags are already enabled + return trySetFlags(set, associated, value, currFlags | flags); + } + + /// @notice Clear the given flags on a value in the set, preserving the values of other flags. + /// @notice If the value is not in the set, this function will still return true. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already disabled, or if set does not contain the value. Short-circuits + /// return true. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to disable the flags on. + /// @param flags The flags to disable. + /// @return True if the operation succeeds, or short-circuits due to the flags already being disabled or if the + /// set does not contain the value. False otherwise. + function tryDisableFlags(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + returns (bool) + { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, associated, value); + if (currFlags & flags == 0) return true; // flags are already disabled + return trySetFlags(set, associated, value, currFlags & ~flags); + } + + /// @notice Checks if a set contains a value + /// @dev This method does not clear the upper bits of `value`, that is expected to be done as part of casting + /// to the correct type. If this function is provided the sentinel value by using the upper bits, this function + /// may returns `true`. + /// @param set The set to check + /// @param associated The address the set is associated with + /// @param value The value to check for + /// @return True if the set contains the value, false otherwise + function contains(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + view + returns (bool) + { + bytes32 unwrappedKey = bytes32(SetValue.unwrap(value)); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer slot = _mapLookup(keyBuffer, unwrappedKey); + return _load(slot) != bytes32(0); + } + + /// @notice Checks if a set is empty + /// @param set The set to check + /// @param associated The address the set is associated with + /// @return True if the set is empty, false otherwise + function isEmpty(AssociatedLinkedListSet storage set, address associated) internal view returns (bool) { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 val = _load(sentinelSlot); + return val == bytes32(0) || isSentinel(val); // either the sentinel is unset, or points to itself + } + + /// @notice Get the flags on a value in the set. + /// @dev The reserved lower 2 bits will not be returned, as those are reserved for the sentinel and has next + /// bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to get the flags from. + /// @return The flags set on the value. + function getFlags(AssociatedLinkedListSet storage set, address associated, SetValue value) + internal + view + returns (uint16) + { + bytes32 unwrappedKey = SetValue.unwrap(value); + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + return uint16(uint256(_load(_mapLookup(keyBuffer, unwrappedKey))) & 0xFFFC); + } + + /// @notice Check if the flags on a value are enabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are enabled, false otherwise. + function flagsEnabled(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + view + returns (bool) + { + flags &= 0xFFFC; + return getFlags(set, associated, value) & flags == flags; + } + + /// @notice Check if the flags on a value are disabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param associated The address the set is associated with. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are disabled, false otherwise. + function flagsDisabled(AssociatedLinkedListSet storage set, address associated, SetValue value, uint16 flags) + internal + view + returns (bool) + { + flags &= 0xFFFC; + return ~(getFlags(set, associated, value)) & flags == flags; + } + + /// @notice Gets all elements in a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to get the elements of. + /// @return res An array of all elements in the set. + function getAll(AssociatedLinkedListSet storage set, address associated) + internal + view + returns (SetValue[] memory res) + { + TempBytesMemory keyBuffer = _allocateTempKeyBuffer(set, associated); + + StoragePointer sentinelSlot = _mapLookup(keyBuffer, SENTINEL_VALUE); + bytes32 cursor = _load(sentinelSlot); + + uint256 count; + while (!isSentinel(cursor) && cursor != bytes32(0)) { + unchecked { + ++count; + } + bytes32 cleared = clearFlags(cursor); + + if (hasNext(cursor)) { + StoragePointer cursorSlot = _mapLookup(keyBuffer, cleared); + cursor = _load(cursorSlot); + } else { + cursor = bytes32(0); + } + } + + res = new SetValue[](count); + + if (count == 0) { + return res; + } + + // Re-allocate the key buffer because we just overwrote it! + keyBuffer = _allocateTempKeyBuffer(set, associated); + + cursor = SENTINEL_VALUE; + for (uint256 i = 0; i < count;) { + StoragePointer cursorSlot = _mapLookup(keyBuffer, cursor); + bytes32 cursorValue = _load(cursorSlot); + bytes32 cleared = clearFlags(cursorValue); + res[i] = SetValue.wrap(bytes30(cleared)); + cursor = cleared; + + unchecked { + ++i; + } + } + } + + function isSentinel(bytes32 value) internal pure returns (bool ret) { + assembly ("memory-safe") { + ret := and(value, 1) + } + } + + function hasNext(bytes32 value) internal pure returns (bool) { + return value & HAS_NEXT_FLAG != 0; + } + + function clearFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0001; + } + + /// @dev Preserves the lower two bits + function clearUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0003; + } + + function getUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & bytes32(uint256(0xFFFC)); + } + + // PRIVATE METHODS + + /// @notice Given an allocated key buffer, returns the storage slot for a given key + function _mapLookup(TempBytesMemory keyBuffer, bytes32 value) private pure returns (StoragePointer slot) { + assembly ("memory-safe") { + // Store the value in the last word. + let keyWord2 := value + mstore(add(keyBuffer, 0x60), keyWord2) + slot := keccak256(keyBuffer, 0x80) + } + } + + /// @notice Allocates a key buffer for a given ID and associated address into scratch space memory. + /// @dev The returned buffer must not be used if any additional memory is allocated after calling this + /// function. + /// @param set The set to allocate the key buffer for. + /// @param associated The address the set is associated with. + /// @return key A key buffer that can be used to lookup values in the set + function _allocateTempKeyBuffer(AssociatedLinkedListSet storage set, address associated) + private + pure + returns (TempBytesMemory key) + { + // Key derivation for an entry + // associated addr (left-padded) || prefix || uint224(0) batchIndex || set storage slot || entry + // Word 1: + // | zeros | 0x000000000000000000000000________________________________________ | + // | address | 0x________________________AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA | + // Word 2: + // | prefix | 0xPPPPPPPP________________________________________________________ | + // | batch index (zero) | 0x________00000000000000000000000000000000000000000000000000000000 | + // Word 3: + // | set storage slot | 0xSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS | + // Word 4: + // | entry value | 0xVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVVV____ | + // | entry meta | 0x____________________________________________________________MMMM | + + // The batch index is for consistency with PluginStorageLib, and the prefix in front of it is + // to prevent any potential crafted collisions where the batch index may be equal to storage slot + // of the ALLS. The prefix is set to the upper bits of the batch index to make it infeasible to + // reach from just incrementing the value. + + // This segment is memory-safe because it only uses the scratch space memory after the value of the free + // memory pointer. + // See https://docs.soliditylang.org/en/v0.8.21/assembly.html#memory-safety + assembly ("memory-safe") { + // Clean upper bits of arguments + associated := and(associated, 0xffffffffffffffffffffffffffffffffffffffff) + + // Use memory past-the-free-memory-pointer without updating it, as this is just scratch space + key := mload(0x40) + // Store the associated address in the first word, left-padded with zeroes + mstore(key, associated) + // Store the prefix and a batch index of 0 + mstore(add(key, 0x20), _ASSOCIATED_STORAGE_PREFIX) + // Store the list's storage slot in the third word + mstore(add(key, 0x40), set.slot) + // Leaves the last word open for the value entry + } + + return key; + } + + /// @dev Loads a value from storage + function _load(StoragePointer ptr) private view returns (bytes32 val) { + assembly ("memory-safe") { + val := sload(ptr) + } + } + + /// @dev Writes a value into storage + function _store(StoragePointer ptr, bytes32 val) private { + assembly ("memory-safe") { + sstore(ptr, val) + } + } +} diff --git a/src/libraries/CastLib.sol b/src/libraries/CastLib.sol new file mode 100644 index 00000000..e5a0b8fc --- /dev/null +++ b/src/libraries/CastLib.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {FunctionReference} from "./FunctionReferenceLib.sol"; +import {SetValue} from "./LinkedListSetUtils.sol"; + +/// @title Cast Library +/// @author Alchemy +/// @notice Library for various data type conversions. +library CastLib { + function toFunctionReferenceArray(SetValue[] memory vals) + internal + pure + returns (FunctionReference[] memory ret) + { + assembly ("memory-safe") { + ret := vals + } + } + + function toAddressArray(SetValue[] memory values) internal pure returns (address[] memory addresses) { + bytes32[] memory valuesBytes; + + assembly ("memory-safe") { + valuesBytes := values + } + + uint256 length = values.length; + for (uint256 i = 0; i < length;) { + valuesBytes[i] >>= 96; + + unchecked { + i++; + } + } + + assembly ("memory-safe") { + addresses := valuesBytes + } + + return addresses; + } + + function toSetValue(FunctionReference functionReference) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(FunctionReference.unwrap(functionReference))); + } + + function toSetValue(address value) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(bytes20(value))); + } +} diff --git a/src/libraries/CountableLinkedListSetLib.sol b/src/libraries/CountableLinkedListSetLib.sol new file mode 100644 index 00000000..9a8cd84d --- /dev/null +++ b/src/libraries/CountableLinkedListSetLib.sol @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {LinkedListSet, LinkedListSetLib} from "./LinkedListSetLib.sol"; +import {SetValue} from "./LinkedListSetUtils.sol"; + +/// @title Countable Linked List Set Library +/// @author Alchemy +/// @notice This library adds the ability to count the number of occurrences of a value in a linked list set. +library CountableLinkedListSetLib { + using LinkedListSetLib for LinkedListSet; + + /// @dev The counter is stored in the upper 8 bits of the the flag bytes, so the maximum value of the counter + /// is 255. This means each value can be included a maximum of 256 times in the set, as the counter is 0 when + /// the value is first added. + uint16 internal constant _MAX_COUNTER_VALUE = 255; + + /// @notice Increment an existing value in the set, or add it if it doesn't exist. + /// @dev The counter is stored in the upper 8 bits of the the flag bytes. Because this library repurposes a + /// portion of the flag bytes to store the counter, it's important to not use the upper 8 bits to store flags. + /// Any existing flags on the upper 8 bits will be interpreted as part of the counter. + /// @param set The set to increment (or add) the value in. + /// @param value The value to increment (or add). + /// @return True if the value was incremented or added, false otherwise. + function tryIncrement(LinkedListSet storage set, SetValue value) internal returns (bool) { + if (!set.contains(value)) { + return set.tryAdd(value); + } + uint16 flags = set.getFlags(value); + // Use the upper 8 bits of the (16-bit) flag for the counter. + uint16 counter = flags >> 8; + if (counter == _MAX_COUNTER_VALUE) { + return false; + } + unchecked { + ++counter; + } + return set.trySetFlags(value, (counter << 8) | (flags & 0xFF)); + } + + /// @notice Decrement an existing value in the set, or remove it if the count has reached 0. + /// @dev The counter is stored in the upper 8 bits of the the flag bytes. Because this library repurposes a + /// portion of the flag bytes to store the counter, it's important to not use the upper 8 bits to store flags. + /// Any existing flags on the upper 8 bits will be interpreted as part of the counter. + /// @param set The set to decrement (or remove) the value in. + /// @param value The value to decrement (or remove). + /// @return True if the value was decremented or removed, false otherwise. + function tryDecrement(LinkedListSet storage set, SetValue value) internal returns (bool) { + if (!set.contains(value)) { + return false; + } + uint16 flags = set.getFlags(value); + // Use the upper 8 bits of the (16-bit) flag for the counter. + uint16 counter = flags >> 8; + if (counter == 0) { + return set.tryRemove(value); + } + unchecked { + --counter; + } + return set.trySetFlags(value, (counter << 8) | (flags & 0xFF)); + } + + /// @notice Get the number of occurrences of a value in the set. + /// @dev The counter is stored in the upper 8 bits of the the flag bytes. Because this library repurposes a + /// portion of the flag bytes to store the counter, it's important to not use the upper 8 bits to store flags. + /// Any existing flags on the upper 8 bits will be interpreted as part of the counter. + /// @return The number of occurrences of the value in the set. + function getCount(LinkedListSet storage set, SetValue value) internal view returns (uint256) { + if (!set.contains(value)) { + return 0; + } + unchecked { + return (set.getFlags(value) >> 8) + 1; + } + } +} diff --git a/src/libraries/FunctionReferenceLib.sol b/src/libraries/FunctionReferenceLib.sol new file mode 100644 index 00000000..145b8a66 --- /dev/null +++ b/src/libraries/FunctionReferenceLib.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +type FunctionReference is bytes21; + +using {eq as ==, notEq as !=} for FunctionReference global; +using FunctionReferenceLib for FunctionReference global; + +/// @title Function Reference Lib +/// @author Alchemy +library FunctionReferenceLib { + // Empty or unset function reference. + FunctionReference internal constant _EMPTY_FUNCTION_REFERENCE = FunctionReference.wrap(bytes21(0)); + // Magic value for runtime validation functions that always allow access. + FunctionReference internal constant _RUNTIME_VALIDATION_ALWAYS_ALLOW = + FunctionReference.wrap(bytes21(uint168(1))); + // Magic value for hooks that should always revert. + FunctionReference internal constant _PRE_HOOK_ALWAYS_DENY = FunctionReference.wrap(bytes21(uint168(2))); + + function pack(address addr, uint8 functionId) internal pure returns (FunctionReference) { + return FunctionReference.wrap(bytes21(bytes20(addr)) | bytes21(uint168(functionId))); + } + + function unpack(FunctionReference fr) internal pure returns (address addr, uint8 functionId) { + bytes21 underlying = FunctionReference.unwrap(fr); + addr = address(bytes20(underlying)); + functionId = uint8(bytes1(underlying << 160)); + } + + function isEmptyOrMagicValue(FunctionReference fr) internal pure returns (bool) { + return FunctionReference.unwrap(fr) <= bytes21(uint168(2)); + } +} + +function eq(FunctionReference a, FunctionReference b) pure returns (bool) { + return FunctionReference.unwrap(a) == FunctionReference.unwrap(b); +} + +function notEq(FunctionReference a, FunctionReference b) pure returns (bool) { + return FunctionReference.unwrap(a) != FunctionReference.unwrap(b); +} diff --git a/src/libraries/LinkedListSetLib.sol b/src/libraries/LinkedListSetLib.sol new file mode 100644 index 00000000..284021f4 --- /dev/null +++ b/src/libraries/LinkedListSetLib.sol @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {SetValue, SENTINEL_VALUE, HAS_NEXT_FLAG} from "./LinkedListSetUtils.sol"; + +struct LinkedListSet { + // Byte Layout + // | value | 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA____ | + // | meta | 0x____________________________________________________________BBBB | + + // Bit-layout of the meta bytes (2 bytes) + // | user flags | 11111111 11111100 | + // | has next | 00000000 00000010 | + // | sentinel | 00000000 00000001 | + + // Key excludes the meta bytes, except for the sentinel value, which is 0x1 + mapping(bytes32 => bytes32) map; +} + +/// @title Linked List Set Library +/// @author Alchemy +/// @notice This library provides a set of functions for managing enumerable sets of bytes30 values. +library LinkedListSetLib { + // INTERNAL METHODS + + /// @notice Add a value to a set. + /// @param set The set to add the value to. + /// @param value The value to add. + /// @return True if the value was added, false if the value cannot be added (already exists or is zero). + function tryAdd(LinkedListSet storage set, SetValue value) internal returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 unwrappedKey = SetValue.unwrap(value); + if (unwrappedKey == bytes32(0) || map[unwrappedKey] != bytes32(0)) return false; + + bytes32 prev = map[SENTINEL_VALUE]; + if (prev == bytes32(0) || isSentinel(prev)) { + // Set is empty + map[SENTINEL_VALUE] = unwrappedKey; + map[unwrappedKey] = SENTINEL_VALUE; + } else { + // set is not empty + map[SENTINEL_VALUE] = unwrappedKey | HAS_NEXT_FLAG; + map[unwrappedKey] = prev; + } + + return true; + } + + /// @notice Remove a value from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the value from. + /// @param value The value to remove. + /// @return True if the value was removed, false if the value does not exist. + function tryRemove(LinkedListSet storage set, SetValue value) internal returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 unwrappedKey = SetValue.unwrap(value); + + bytes32 nextValue = map[unwrappedKey]; + if (unwrappedKey == bytes32(0) || nextValue == bytes32(0)) return false; + + bytes32 prevKey = SENTINEL_VALUE; + bytes32 currentVal; + do { + currentVal = map[prevKey]; + bytes32 currentKey = clearFlags(currentVal); + if (currentKey == unwrappedKey) { + // Set the previous value's next value to the next value, + // and the flags to the current value's flags. + // and the next value's `hasNext` flag to determine whether or not the next value is (or points to) + // the sentinel value. + map[prevKey] = clearFlags(nextValue) | getUserFlags(currentVal) | (nextValue & HAS_NEXT_FLAG); + map[currentKey] = bytes32(0); + + return true; + } + prevKey = currentKey; + } while (!isSentinel(currentVal) && currentVal != bytes32(0)); + return false; + } + + /// @notice Remove a value from a set, given the previous value in the set. + /// @dev This is an O(1) operation but requires additional knowledge. + /// @param set The set to remove the value from. + /// @param value The value to remove. + /// @param prev The previous value in the set. + /// @return True if the value was removed, false if the value does not exist. + function tryRemoveKnown(LinkedListSet storage set, SetValue value, bytes32 prev) internal returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 unwrappedKey = SetValue.unwrap(value); + + // Clear the flag bits of prev + prev = clearFlags(prev); + + if (prev == bytes32(0) || unwrappedKey == bytes32(0)) { + return false; + } + + // assert that the previous value's next value is the value to be removed + bytes32 currentValue = map[prev]; + if (clearFlags(currentValue) != unwrappedKey) { + return false; + } + + bytes32 next = map[unwrappedKey]; + if (next == bytes32(0)) { + // The set didn't actually contain the value + return false; + } + + map[prev] = clearFlags(next) | getUserFlags(currentValue) | (next & HAS_NEXT_FLAG); + map[unwrappedKey] = bytes32(0); + return true; + } + + /// @notice Remove all values from a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to remove the values from. + function clear(LinkedListSet storage set) internal { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 cursor = SENTINEL_VALUE; + + do { + bytes32 next = clearFlags(map[cursor]); + map[cursor] = bytes32(0); + cursor = next; + } while (!isSentinel(cursor) && cursor != bytes32(0)); + + map[SENTINEL_VALUE] = bytes32(0); + } + + /// @notice Set the flags on a value in the set. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// @param set The set containing the value. + /// @param value The value to set the flags on. + /// @param flags The flags to set. + /// @return True if the set contains the value and the operation succeeds, false otherwise. + function trySetFlags(LinkedListSet storage set, SetValue value, uint16 flags) internal returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 unwrappedKey = SetValue.unwrap(value); + + // Ignore the lower 2 bits. + flags &= 0xFFFC; + + // If the set doesn't actually contain the value, return false; + bytes32 next = map[unwrappedKey]; + if (next == bytes32(0)) { + return false; + } + + // Set the flags + map[unwrappedKey] = clearUserFlags(next) | bytes32(uint256(flags)); + + return true; + } + + /// @notice Set the given flags on a value in the set, preserving the values of other flags. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already enabled, returning true. + /// @param set The set containing the value. + /// @param value The value to enable the flags on. + /// @param flags The flags to enable. + /// @return True if the operation succeeds or short-circuits due to the flags already being enabled. False + /// otherwise. + function tryEnableFlags(LinkedListSet storage set, SetValue value, uint16 flags) internal returns (bool) { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, value); + if (currFlags & flags == flags) return true; // flags are already enabled + return trySetFlags(set, value, currFlags | flags); + } + + /// @notice Clear the given flags on a value in the set, preserving the values of other flags. + /// @notice If the value is not in the set, this function will still return true. + /// @dev The user flags can only be set on the upper 14 bits, because the lower two are reserved for the + /// sentinel and has next bit. + /// Short-circuits if the flags are already disabled, or if set does not contain the value. Short-circuits + /// return true. + /// @param set The set containing the value. + /// @param value The value to disable the flags on. + /// @param flags The flags to disable. + /// @return True if the operation succeeds, or short-circuits due to the flags already being disabled or if the + /// set does not contain the value. False otherwise. + function tryDisableFlags(LinkedListSet storage set, SetValue value, uint16 flags) internal returns (bool) { + flags &= 0xFFFC; // Allow short-circuit if lower bits are accidentally set + uint16 currFlags = getFlags(set, value); + if (currFlags & flags == 0) return true; // flags are already disabled + return trySetFlags(set, value, currFlags & ~flags); + } + + /// @notice Check if a set contains a value. + /// @dev This method does not clear the upper bits of `value`, that is expected to be done as part of casting + /// to the correct type. If this function is provided the sentinel value by using the upper bits, this function + /// may returns `true`. + /// @param set The set to check. + /// @param value The value to check for. + /// @return True if the set contains the value, false otherwise. + function contains(LinkedListSet storage set, SetValue value) internal view returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + return map[SetValue.unwrap(value)] != bytes32(0); + } + + /// @notice Check if a set is empty. + /// @param set The set to check. + /// @return True if the set is empty, false otherwise. + function isEmpty(LinkedListSet storage set) internal view returns (bool) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 val = map[SENTINEL_VALUE]; + return val == bytes32(0) || isSentinel(val); // either the sentinel is unset, or points to itself + } + + /// @notice Get the flags on a value in the set. + /// @dev The reserved lower 2 bits will not be returned, as those are reserved for the sentinel and has next + /// bit. + /// @param set The set containing the value. + /// @param value The value to get the flags from. + /// @return The flags set on the value. + function getFlags(LinkedListSet storage set, SetValue value) internal view returns (uint16) { + mapping(bytes32 => bytes32) storage map = set.map; + bytes32 unwrappedKey = SetValue.unwrap(value); + + return uint16(uint256(map[unwrappedKey]) & 0xFFFC); + } + + /// @notice Check if the flags on a value are enabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are enabled, false otherwise. + function flagsEnabled(LinkedListSet storage set, SetValue value, uint16 flags) internal view returns (bool) { + flags &= 0xFFFC; + return getFlags(set, value) & flags == flags; + } + + /// @notice Check if the flags on a value are disabled. + /// @dev The reserved lower 2 bits will be ignored, as those are reserved for the sentinel and has next bit. + /// @param set The set containing the value. + /// @param value The value to check the flags on. + /// @param flags The flags to check. + /// @return True if all of the flags are disabled, false otherwise. + function flagsDisabled(LinkedListSet storage set, SetValue value, uint16 flags) internal view returns (bool) { + flags &= 0xFFFC; + return ~(getFlags(set, value)) & flags == flags; + } + + /// @notice Get all elements in a set. + /// @dev This is an O(n) operation, where n is the number of elements in the set. + /// @param set The set to get the elements of. + /// @return ret An array of all elements in the set. + function getAll(LinkedListSet storage set) internal view returns (SetValue[] memory ret) { + mapping(bytes32 => bytes32) storage map = set.map; + uint256 size; + bytes32 cursor = map[SENTINEL_VALUE]; + + // Dynamically allocate the returned array as we iterate through the set, since we don't know the size + // beforehand. + // This is accomplished by first writing to memory after the free memory pointer, + // then updating the free memory pointer to cover the newly-allocated data. + // To the compiler, writes to memory after the free memory pointer are considered "memory safe". + // See https://docs.soliditylang.org/en/v0.8.21/assembly.html#memory-safety + // Stack variable lifting done when compiling with via-ir will only ever place variables into memory + // locations + // below the current free memory pointer, so it is safe to compile this library with via-ir. + // See https://docs.soliditylang.org/en/v0.8.21/yul.html#memoryguard + assembly ("memory-safe") { + // It is critical that no other memory allocations occur between: + // - loading the value of the free memory pointer into `ret` + // - updating the free memory pointer to point to the newly-allocated data, which is done after all + // the values have been written. + ret := mload(0x40) + } + + while (!isSentinel(cursor) && cursor != bytes32(0)) { + unchecked { + ++size; + } + bytes32 cleared = clearFlags(cursor); + // Place the item into the return array manually. Since the size was just incremented, it will point to + // the next location to write to. + assembly ("memory-safe") { + mstore(add(ret, mul(size, 0x20)), cleared) + } + if (hasNext(cursor)) { + cursor = map[cleared]; + } else { + cursor = bytes32(0); + } + } + + assembly ("memory-safe") { + // Update the free memory pointer with the now-known length of the array. + mstore(0x40, add(ret, mul(add(size, 1), 0x20))) + // Set the length of the array. + mstore(ret, size) + } + } + + function isSentinel(bytes32 value) internal pure returns (bool ret) { + assembly ("memory-safe") { + ret := and(value, 1) + } + } + + function hasNext(bytes32 value) internal pure returns (bool) { + return value & HAS_NEXT_FLAG != 0; + } + + function clearFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0001; + } + + /// @dev Preserves the lower two bits + function clearUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0003; + } + + function getUserFlags(bytes32 val) internal pure returns (bytes32) { + return val & bytes32(uint256(0xFFFC)); + } +} diff --git a/src/libraries/LinkedListSetUtils.sol b/src/libraries/LinkedListSetUtils.sol new file mode 100644 index 00000000..d2b20c79 --- /dev/null +++ b/src/libraries/LinkedListSetUtils.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +type SetValue is bytes30; + +/// @dev The sentinel value is used to indicate the head and tail of the list. +bytes32 constant SENTINEL_VALUE = bytes32(uint256(1)); + +/// @dev Removing the last element will result in this flag not being set correctly, but all operations will +/// function normally, albeit with one extra sload for getAll. +bytes32 constant HAS_NEXT_FLAG = bytes32(uint256(2)); diff --git a/src/libraries/PluginStorageLib.sol b/src/libraries/PluginStorageLib.sol new file mode 100644 index 00000000..949f78a2 --- /dev/null +++ b/src/libraries/PluginStorageLib.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +type StoragePointer is bytes32; + +/// @title Plugin Storage Library +/// @author Alchemy +/// @notice Library for allocating and accessing ERC-4337 address-associated storage within plugins. +library PluginStorageLib { + /// @notice Allocates a memory buffer for an associated storage key, and sets the associated address and batch + /// index. + /// @param addr The address to associate with the storage key. + /// @param batchIndex The batch index to associate with the storage key. + /// @param keySize The size of the key in words, where each word is 32 bytes. Not inclusive of the address and + /// batch index. + /// @return key The allocated memory buffer. + function allocateAssociatedStorageKey(address addr, uint256 batchIndex, uint8 keySize) + internal + pure + returns (bytes memory key) + { + assembly ("memory-safe") { + // Clear any dirty upper bits of keySize to prevent overflow + keySize := and(keySize, 0xff) + + // compute the total size of the buffer, include the address and batch index + let totalSize := add(64, mul(32, keySize)) + + // Allocate memory for the key + key := mload(0x40) + mstore(0x40, add(add(key, totalSize), 32)) + mstore(key, totalSize) + + // Store the address and batch index in the key buffer + mstore(add(key, 32), addr) + mstore(add(key, 64), batchIndex) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input) internal pure returns (StoragePointer ptr) { + assembly ("memory-safe") { + mstore(add(key, 96), input) + ptr := keccak256(add(key, 32), mload(key)) + } + } + + function associatedStorageLookup(bytes memory key, bytes32 input1, bytes32 input2) + internal + pure + returns (StoragePointer ptr) + { + assembly ("memory-safe") { + mstore(add(key, 96), input1) + mstore(add(key, 128), input2) + ptr := keccak256(add(key, 32), mload(key)) + } + } +} diff --git a/src/plugins/BasePlugin.sol b/src/plugins/BasePlugin.sol new file mode 100644 index 00000000..c9c100b4 --- /dev/null +++ b/src/plugins/BasePlugin.sol @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC165} from "@openzeppelin/contracts/utils/introspection/ERC165.sol"; + +import {IPlugin, PluginManifest, PluginMetadata} from "../interfaces/IPlugin.sol"; +import {IPluginManager} from "../interfaces/IPluginManager.sol"; +import {UserOperation} from "../interfaces/erc4337/UserOperation.sol"; + +/// @title Base contract for plugins +/// @dev Implements ERC-165 to support IPlugin's interface, which is a requirement +/// for plugin installation. This also ensures that plugin interactions cannot +/// happen via the standard execution funtions `execute` and `executeBatch`. +/// Note that the plugins implement BasePlugins cannot be installed when creating an account (aka installed in the +/// account constructor) unless onInstall is overriden without checking codesize of caller (account). Checking +/// codesize of account is to prevent EOA from accidentally calling plugin and initiate states which will make it +/// unusable in the future when EOA can be upgraded into an smart contract account. +abstract contract BasePlugin is ERC165, IPlugin { + error AlreadyInitialized(); + error InvalidAction(); + error NotImplemented(); + error NotContractCaller(); + error NotInitialized(); + + modifier isNotInitialized(address account) { + if (_isInitialized(account)) { + revert AlreadyInitialized(); + } + _; + } + + modifier isInitialized(address account) { + if (!_isInitialized(account)) { + revert NotInitialized(); + } + _; + } + + modifier staysInitialized(address account) { + if (!_isInitialized(account)) { + revert NotInitialized(); + } + _; + if (!_isInitialized(account)) { + revert InvalidAction(); + } + } + + /// @notice Initialize plugin data for the modular account. + /// @dev Called by the modular account during `installPlugin`. + /// @param data Optional bytes array to be decoded and used by the plugin to setup initial plugin data for the + /// modular account. + function onInstall(bytes calldata data) external virtual { + if (msg.sender.code.length == 0) { + revert NotContractCaller(); + } + _onInstall(data); + } + + /// @notice Clear plugin data for the modular account. + /// @dev Called by the modular account during `uninstallPlugin`. + /// @param data Optional bytes array to be decoded and used by the plugin to clear plugin data for the modular + /// account. + function onUninstall(bytes calldata data) external virtual { + (data); + revert NotImplemented(); + } + + /// @notice Run the pre user operation validation hook specified by the `functionId`. + /// @dev Pre user operation validation hooks MUST NOT return an authorizer value other than 0 or 1. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param userOp The user operation. + /// @param userOpHash The user operation hash. + /// @return Packed validation data for validAfter (6 bytes), validUntil (6 bytes), and authorizer (20 bytes). + function preUserOpValidationHook(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + virtual + returns (uint256) + { + (functionId, userOp, userOpHash); + revert NotImplemented(); + } + + /// @notice Run the user operation validationFunction specified by the `functionId`. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param userOp The user operation. + /// @param userOpHash The user operation hash. + /// @return Packed validation data for validAfter (6 bytes), validUntil (6 bytes), and authorizer (20 bytes). + function userOpValidationFunction(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + virtual + returns (uint256) + { + (functionId, userOp, userOpHash); + revert NotImplemented(); + } + + /// @notice Run the pre runtime validation hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + function preRuntimeValidationHook(uint8 functionId, address sender, uint256 value, bytes calldata data) + external + virtual + { + (functionId, sender, value, data); + revert NotImplemented(); + } + + /// @notice Run the runtime validationFunction specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + function runtimeValidationFunction(uint8 functionId, address sender, uint256 value, bytes calldata data) + external + virtual + { + (functionId, sender, value, data); + revert NotImplemented(); + } + + /// @notice Run the pre execution hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param sender The caller address. + /// @param value The call value. + /// @param data The calldata sent. + /// @return Context to pass to a post execution hook, if present. An empty bytes array MAY be returned. + function preExecutionHook(uint8 functionId, address sender, uint256 value, bytes calldata data) + external + virtual + returns (bytes memory) + { + (functionId, sender, value, data); + revert NotImplemented(); + } + + /// @notice Run the post execution hook specified by the `functionId`. + /// @dev To indicate the entire call should revert, the function MUST revert. + /// @param functionId An identifier that routes the call to different internal implementations, should there be + /// more than one. + /// @param preExecHookData The context returned by its associated pre execution hook. + function postExecutionHook(uint8 functionId, bytes calldata preExecHookData) external virtual { + (functionId, preExecHookData); + revert NotImplemented(); + } + + /// @notice A hook that runs when a hook this plugin owns is installed onto another plugin + /// @dev Optional, use to implement any required setup logic + /// @param pluginAppliedOn The plugin that the hook is being applied on + /// @param injectedHooksInfo Contains pre/post exec hook information + /// @param data Any optional data for setup + function onHookApply( + address pluginAppliedOn, + IPluginManager.InjectedHooksInfo calldata injectedHooksInfo, + bytes calldata data + ) external virtual { + (pluginAppliedOn, injectedHooksInfo, data); + } + + /// @notice A hook that runs when a hook this plugin owns is unapplied from another plugin + /// @dev Optional, use to implement any required unapplied logic + /// @param pluginAppliedOn The plugin that the hook was applied on + /// @param injectedHooksInfo Contains pre/post exec hook information + /// @param data Any optional data for the unapplied call + function onHookUnapply( + address pluginAppliedOn, + IPluginManager.InjectedHooksInfo calldata injectedHooksInfo, + bytes calldata data + ) external virtual { + (pluginAppliedOn, injectedHooksInfo, data); + } + + /// @notice Describe the contents and intended configuration of the plugin. + /// @dev This manifest MUST stay constant over time. + /// @return A manifest describing the contents and intended configuration of the plugin. + function pluginManifest() external pure virtual returns (PluginManifest memory) { + revert NotImplemented(); + } + + /// @notice Describe the metadata of the plugin. + /// @dev This metadata MUST stay constant over time. + /// @return A metadata struct describing the plugin. + function pluginMetadata() external pure virtual returns (PluginMetadata memory); + + /// @dev Returns true if this contract implements the interface defined by + /// `interfaceId`. See the corresponding + /// https://eips.ethereum.org/EIPS/eip-165#how-interfaces-are-identified[EIP section] + /// to learn more about how these ids are created. + /// + /// This function call must use less than 30 000 gas. + /// + /// Supporting the IPlugin interface is a requirement for plugin installation. This is also used + /// by the modular account to prevent standard execution functions `execute` and `executeBatch` from + /// making calls to plugins. + /// @param interfaceId The interface ID to check for support. + /// @return True if the contract supports `interfaceId`. + function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) { + return interfaceId == type(IPlugin).interfaceId || super.supportsInterface(interfaceId); + } + + /// @notice Initialize plugin data for the modular account. + /// @dev Called by the modular account during `installPlugin`. + /// @param data Optional bytes array to be decoded and used by the plugin to setup initial plugin data for the + /// modular account. + function _onInstall(bytes calldata data) internal virtual { + (data); + revert NotImplemented(); + } + + /// @notice Check if the account has initialized this plugin yet + /// @dev This function should be overwritten for plugins that have state-changing onInstall's + /// @param account The account to check + /// @return True if the account has initialized this plugin + // solhint-disable-next-line no-empty-blocks + function _isInitialized(address account) internal view virtual returns (bool) {} +} diff --git a/src/plugins/TokenReceiverPlugin.sol b/src/plugins/TokenReceiverPlugin.sol new file mode 100644 index 00000000..082f2cd2 --- /dev/null +++ b/src/plugins/TokenReceiverPlugin.sol @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC721Receiver} from "@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol"; +import {IERC777Recipient} from "@openzeppelin/contracts/interfaces/IERC777Recipient.sol"; +import {IERC1155Receiver} from "@openzeppelin/contracts/interfaces/IERC1155Receiver.sol"; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata +} from "../interfaces/IPlugin.sol"; +import {BasePlugin} from "./BasePlugin.sol"; + +/// @title Token Receiver Plugin +/// @author Alchemy +/// @notice This plugin allows modular accounts to receive various types of tokens by implementing +/// required token receiver interfaces. +contract TokenReceiverPlugin is BasePlugin, IERC721Receiver, IERC777Recipient, IERC1155Receiver { + string internal constant _NAME = "Token Receiver Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function tokensReceived(address, address, address, uint256, bytes calldata, bytes calldata) + external + pure + override + // solhint-disable-next-line no-empty-blocks + {} + + function onERC721Received(address, address, uint256, bytes calldata) external pure override returns (bytes4) { + return IERC721Receiver.onERC721Received.selector; + } + + function onERC1155Received(address, address, uint256, uint256, bytes calldata) + external + pure + override + returns (bytes4) + { + return IERC1155Receiver.onERC1155Received.selector; + } + + function onERC1155BatchReceived(address, address, uint256[] calldata, uint256[] calldata, bytes calldata) + external + pure + override + returns (bytes4) + { + return IERC1155Receiver.onERC1155BatchReceived.selector; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + // solhint-disable-next-line no-empty-blocks + function onInstall(bytes calldata) external pure override {} + + /// @inheritdoc BasePlugin + // solhint-disable-next-line no-empty-blocks + function onUninstall(bytes calldata) external pure override {} + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](4); + manifest.executionFunctions[0] = this.tokensReceived.selector; + manifest.executionFunctions[1] = this.onERC721Received.selector; + manifest.executionFunctions[2] = this.onERC1155Received.selector; + manifest.executionFunctions[3] = this.onERC1155BatchReceived.selector; + + // Only runtime validationFunction is needed since callbacks come from token contracts only + ManifestFunction memory alwaysAllowFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }); + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](4); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.tokensReceived.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.onERC721Received.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.onERC1155Received.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[3] = ManifestAssociatedFunction({ + executionSelector: this.onERC1155BatchReceived.selector, + associatedFunction: alwaysAllowFunction + }); + + manifest.interfaceIds = new bytes4[](3); + manifest.interfaceIds[0] = type(IERC721Receiver).interfaceId; + manifest.interfaceIds[1] = type(IERC777Recipient).interfaceId; + manifest.interfaceIds[2] = type(IERC1155Receiver).interfaceId; + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } +} diff --git a/src/plugins/owner/IMultiOwnerPlugin.sol b/src/plugins/owner/IMultiOwnerPlugin.sol new file mode 100644 index 00000000..ac7b48c7 --- /dev/null +++ b/src/plugins/owner/IMultiOwnerPlugin.sol @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "../../interfaces/erc4337/UserOperation.sol"; + +interface IMultiOwnerPlugin { + enum FunctionId { + RUNTIME_VALIDATION_OWNER_OR_SELF, // require owner or self access + USER_OP_VALIDATION_OWNER // require owner access + } + + /// @notice This event is emitted when owners of the account are updated. + /// @param account The account whose ownership changed. + /// @param addedOwners The address array of added owners. + /// @param removedOwners The address array of removed owners. + event OwnerUpdated(address indexed account, address[] addedOwners, address[] removedOwners); + + error InvalidOwner(address owner); + error EmptyOwnersNotAllowed(); + error NotAuthorized(); + error OwnerDoesNotExist(address owner); + + /// @notice Update owners of the account. Owners can update owners. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @param ownersToAdd The address array of owners to be added. + /// @param ownersToRemove The address array of owners to be removed. + function updateOwners(address[] memory ownersToAdd, address[] memory ownersToRemove) external; + + /// @notice Check if an address is an owner of the current account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @param ownerToCheck The owner to check if it is an owner of the current account. + /// @return True if the address is an owner of the account. + function isOwner(address ownerToCheck) external view returns (bool); + + /// @notice Get the owners of the current account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @return The addresses of the owners of the account. + function owners() external view returns (address[] memory); + + /// @notice Gets the EIP712 domain + /// @dev This implementation is different from typical 712 via its use of msg.sender instead. As such, it + /// should only be called from the SCAs that has installed this. See ERC-5267. + function eip712Domain() + external + view + returns ( + bytes1 fields, + string memory name, + string memory version, + uint256 chainId, + address verifyingContract, + bytes32 salt, + uint256[] memory extensions + ); + + /// @notice Check if an address is an owner of `account`. + /// @param account The account to check. + /// @param ownerToCheck The owner to check if it is an owner of the provided account. + /// @return True if the address is an owner of the account. + function isOwnerOf(address account, address ownerToCheck) external view returns (bool); + + /// @notice Get the owners of `account`. + /// @param account The account to get the owners of. + /// @return The addresses of the owners of the account. + function ownersOf(address account) external view returns (address[] memory); + + /// @notice Returns the pre-image of the message hash + /// @dev Assumes that the SCA's implementation of `domainSeparator` is this plugin's + /// @param account SCA to build the message encoding for + /// @param message Message that should be encoded. + /// @return Encoded message. + function encodeMessageData(address account, bytes memory message) external view returns (bytes memory); + + /// @notice Returns hash of a message that can be signed by owners. + /// @param account SCA to build the message hash for + /// @param message Message that should be hashed. + /// @return Message hash. + function getMessageHash(address account, bytes memory message) external view returns (bytes32); +} diff --git a/src/plugins/owner/MultiOwnerPlugin.sol b/src/plugins/owner/MultiOwnerPlugin.sol new file mode 100644 index 00000000..c9404a58 --- /dev/null +++ b/src/plugins/owner/MultiOwnerPlugin.sol @@ -0,0 +1,438 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EIP712} from "@openzeppelin/contracts/utils/cryptography/EIP712.sol"; +import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; + +import {BasePlugin} from "../BasePlugin.sol"; +import {IMultiOwnerPlugin} from "./IMultiOwnerPlugin.sol"; +import {UpgradeableModularAccount, UUPSUpgradeable} from "../../account/UpgradeableModularAccount.sol"; + +import { + ManifestAssociatedFunction, + ManifestAssociatedFunctionType, + ManifestFunction, + PluginManifest, + PluginMetadata, + SelectorPermission +} from "../../interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../interfaces/IStandardExecutor.sol"; +import {UserOperation} from "../../interfaces/erc4337/UserOperation.sol"; + +import { + AssociatedLinkedListSet, AssociatedLinkedListSetLib +} from "../../libraries/AssociatedLinkedListSetLib.sol"; +import {CastLib} from "../../libraries/CastLib.sol"; +import {SetValue} from "../../libraries/LinkedListSetUtils.sol"; + +/// @title Multi Owner Plugin +/// @author Alchemy +/// @notice This plugin allows more than one EOA or smart contract to own a modular account. +/// All owners have equal root access to the account. +/// +/// It supports [ERC-1271](https://eips.ethereum.org/EIPS/eip-1271) signature +/// validation for both validating the signature on user operations and in +/// exposing its own `isValidSignature` method. This only works when the owner of +/// modular account also support ERC-1271. +/// +/// ERC-4337's bundler validation rules limit the types of contracts that can be +/// used as owners to validate user operation signatures. For example, the +/// contract's `isValidSignature` function may not use any forbidden opcodes +/// such as `TIMESTAMP` or `NUMBER`, and the contract may not be an ERC-1967 +/// proxy as it accesses a constant implementation slot not associated with +/// the account, violating storage access rules. This also means that the +/// owner of a modular account may not be another modular account if you want to +/// send user operations through a bundler. +contract MultiOwnerPlugin is BasePlugin, IMultiOwnerPlugin, IERC1271, EIP712 { + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + using ECDSA for bytes32; + + string internal constant _NAME = "Multi Owner Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + bytes32 private constant _TYPE_HASH = + keccak256("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)"); + bytes32 private immutable _HASHED_NAME = keccak256(bytes(_NAME)); + bytes32 private immutable _HASHED_VERSION = keccak256(bytes(_VERSION)); + + constructor() EIP712(_NAME, _VERSION) {} + + // ERC-4337 specific value: signature validation passed + uint256 internal constant _SIG_VALIDATION_PASSED = 0; + // ERC-4337 specific value: signature validation failed + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + // bytes4(keccak256("isValidSignature(bytes32,bytes)")) + bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; + bytes4 internal constant _1271_MAGIC_VALUE_FAILURE = 0xffffffff; + + // keccak256("ERC6900Message(bytes message)"); + bytes32 private constant ERC6900_TYPEHASH = 0xa856bbdae1f2c6e4aa17a75ad7cc5650f184ec4b549174dd7258c9701d663fc6; + + AssociatedLinkedListSet internal _owners; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IMultiOwnerPlugin + function updateOwners(address[] memory ownersToAdd, address[] memory ownersToRemove) + public + isInitialized(msg.sender) + { + _addOwnersOrRevert(_owners, msg.sender, ownersToAdd); + _removeOwnersOrRevert(_owners, msg.sender, ownersToRemove); + + if (_owners.isEmpty(msg.sender)) { + revert EmptyOwnersNotAllowed(); + } + + emit OwnerUpdated(msg.sender, ownersToAdd, ownersToRemove); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IMultiOwnerPlugin + function isOwner(address ownerToCheck) external view returns (bool) { + return isOwnerOf(msg.sender, ownerToCheck); + } + + /// @inheritdoc IMultiOwnerPlugin + function owners() external view returns (address[] memory) { + return ownersOf(msg.sender); + } + + /// @inheritdoc IMultiOwnerPlugin + function eip712Domain() + public + view + override(IMultiOwnerPlugin, EIP712) + returns ( + bytes1 fields, + string memory name, + string memory version, + uint256 chainId, + address verifyingContract, + bytes32 salt, + uint256[] memory extensions + ) + { + (fields, name, version, chainId,, salt, extensions) = super.eip712Domain(); + verifyingContract = msg.sender; + } + + /// @inheritdoc IERC1271 + /// @dev The signature is valid if it is signed by one of the owners' private key + /// (if the owner is an EOA) or if it is a valid ERC-1271 signature from one of the + /// owners (if the owner is a contract). Note that unlike the signature + /// validation used in `validateUserOp`, this does not wrap the digest in + /// an "Ethereum Signed Message" envelope before checking the signature in + /// the EOA-owner case. + function isValidSignature(bytes32 digest, bytes memory signature) public view override returns (bytes4) { + bytes memory messageData = encodeMessageData(msg.sender, abi.encode(digest)); + bytes32 messageHash = keccak256(messageData); + // try to recover through ECDSA + (address signer, ECDSA.RecoverError error) = ECDSA.tryRecover(messageHash, signature); + if (error == ECDSA.RecoverError.NoError) { + if (_owners.contains(msg.sender, CastLib.toSetValue(signer))) { + return _1271_MAGIC_VALUE; + } else { + return _1271_MAGIC_VALUE_FAILURE; + } + } else { + if (_isValidERC1271OwnerTypeSignature(msg.sender, messageHash, signature)) { + return _1271_MAGIC_VALUE; + } + } + + return _1271_MAGIC_VALUE_FAILURE; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IMultiOwnerPlugin + function encodeMessageData(address account, bytes memory message) + public + view + override + returns (bytes memory) + { + bytes32 messageHash = keccak256(abi.encode(ERC6900_TYPEHASH, keccak256(message))); + return abi.encodePacked("\x19\x01", _domainSeparator(account), messageHash); + } + + /// @inheritdoc IMultiOwnerPlugin + function getMessageHash(address account, bytes memory message) public view override returns (bytes32) { + return keccak256(encodeMessageData(account, message)); + } + + /// @inheritdoc IMultiOwnerPlugin + function isOwnerOf(address account, address ownerToCheck) public view returns (bool) { + return _owners.contains(account, CastLib.toSetValue(ownerToCheck)); + } + + /// @inheritdoc IMultiOwnerPlugin + function ownersOf(address account) public view returns (address[] memory) { + return CastLib.toAddressArray(_owners.getAll(account)); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function _onInstall(bytes calldata data) internal override isNotInitialized(msg.sender) { + (address[] memory initialOwners) = abi.decode(data, (address[])); + if (initialOwners.length == 0) { + revert EmptyOwnersNotAllowed(); + } + _addOwnersOrRevert(_owners, msg.sender, initialOwners); + } + + /// @inheritdoc BasePlugin + function onUninstall(bytes calldata) external override { + _owners.clear(msg.sender); + } + + /// @inheritdoc BasePlugin + function userOpValidationFunction(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + view + override + returns (uint256) + { + (address signer, ECDSA.RecoverError error) = + (userOpHash.toEthSignedMessageHash()).tryRecover(userOp.signature); + if (functionId == uint8(FunctionId.USER_OP_VALIDATION_OWNER)) { + if (error == ECDSA.RecoverError.NoError) { + if (isOwnerOf(msg.sender, signer)) { + return _SIG_VALIDATION_PASSED; + } + } else { + if (_isValidERC1271OwnerTypeSignature(msg.sender, userOpHash, userOp.signature)) { + return _SIG_VALIDATION_PASSED; + } + } + return _SIG_VALIDATION_FAILED; + } + revert NotImplemented(); + } + + /// @inheritdoc BasePlugin + function runtimeValidationFunction(uint8 functionId, address sender, uint256, bytes calldata) + external + view + override + { + if (functionId == uint8(FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF)) { + // Validate that the sender is an owner of the account, or self. + if (sender != msg.sender && !isOwnerOf(msg.sender, sender)) { + revert NotAuthorized(); + } + } else { + revert NotImplemented(); + } + } + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](5); + manifest.executionFunctions[0] = this.updateOwners.selector; + manifest.executionFunctions[1] = this.owners.selector; + manifest.executionFunctions[2] = this.isOwner.selector; + manifest.executionFunctions[3] = this.eip712Domain.selector; + manifest.executionFunctions[4] = this.isValidSignature.selector; + + ManifestFunction memory ownerUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION_OWNER), + dependencyIndex: 0 // Unused. + }); + + // Update Modular Account's native functions to use userOpValidationFunction provided by this plugin + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](6); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.updateOwners.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.executeBatch.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[3] = ManifestAssociatedFunction({ + executionSelector: UpgradeableModularAccount.installPlugin.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[4] = ManifestAssociatedFunction({ + executionSelector: UpgradeableModularAccount.uninstallPlugin.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[5] = ManifestAssociatedFunction({ + executionSelector: UUPSUpgradeable.upgradeToAndCall.selector, + associatedFunction: ownerUserOpValidationFunction + }); + + ManifestFunction memory ownerOrSelfRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF), + dependencyIndex: 0 // Unused. + }); + ManifestFunction memory alwaysAllowFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }); + + // Update Modular Account's native functions to use runtimeValidationFunction provided by this plugin + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](10); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.updateOwners.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.executeBatch.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[3] = ManifestAssociatedFunction({ + executionSelector: UpgradeableModularAccount.installPlugin.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[4] = ManifestAssociatedFunction({ + executionSelector: UpgradeableModularAccount.uninstallPlugin.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[5] = ManifestAssociatedFunction({ + executionSelector: UUPSUpgradeable.upgradeToAndCall.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[6] = ManifestAssociatedFunction({ + executionSelector: this.isValidSignature.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[7] = ManifestAssociatedFunction({ + executionSelector: this.isOwner.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[8] = ManifestAssociatedFunction({ + executionSelector: this.owners.selector, + associatedFunction: alwaysAllowFunction + }); + manifest.runtimeValidationFunctions[9] = ManifestAssociatedFunction({ + executionSelector: this.eip712Domain.selector, + associatedFunction: alwaysAllowFunction + }); + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + + // Permission strings + string memory modifyOwnershipPermission = "Modify Ownership"; + + // Permission descriptions + metadata.permissionDescriptors = new SelectorPermission[](1); + metadata.permissionDescriptors[0] = SelectorPermission({ + functionSelector: this.updateOwners.selector, + permissionDescription: modifyOwnershipPermission + }); + + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override returns (bool) { + return interfaceId == type(IMultiOwnerPlugin).interfaceId || super.supportsInterface(interfaceId); + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ Internal ┃ + // ┗━━━━━━━━━━━━━━━┛ + + function _domainSeparator(address account) internal view returns (bytes32) { + return keccak256(abi.encode(_TYPE_HASH, _HASHED_NAME, _HASHED_VERSION, block.chainid, account)); + } + + function _addOwnersOrRevert( + AssociatedLinkedListSet storage ownerSet, + address associated, + address[] memory ownersToAdd + ) private { + uint256 length = ownersToAdd.length; + for (uint256 i = 0; i < length;) { + if (!ownerSet.tryAdd(associated, CastLib.toSetValue(ownersToAdd[i]))) { + revert InvalidOwner(ownersToAdd[i]); + } + + unchecked { + ++i; + } + } + } + + function _removeOwnersOrRevert( + AssociatedLinkedListSet storage ownerSet, + address associated, + address[] memory ownersToRemove + ) private { + uint256 ownersToRemoveLength = ownersToRemove.length; + for (uint256 j; j < ownersToRemoveLength;) { + if (!ownerSet.tryRemove(associated, CastLib.toSetValue(ownersToRemove[j]))) { + revert OwnerDoesNotExist(ownersToRemove[j]); + } + + unchecked { + ++j; + } + } + } + + function _isValidERC1271OwnerTypeSignature(address associated, bytes32 digest, bytes memory signature) + private + view + returns (bool) + { + address[] memory owners_ = ownersOf(associated); + for (uint256 i; i < owners_.length;) { + if (SignatureChecker.isValidERC1271SignatureNow(owners_[i], digest, signature)) { + return true; + } + + unchecked { + ++i; + } + } + return false; + } + + /// @inheritdoc BasePlugin + function _isInitialized(address account) internal view override returns (bool) { + return !_owners.isEmpty(account); + } +} diff --git a/src/plugins/session/ISessionKeyPlugin.sol b/src/plugins/session/ISessionKeyPlugin.sol new file mode 100644 index 00000000..713fb90d --- /dev/null +++ b/src/plugins/session/ISessionKeyPlugin.sol @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Call} from "../../interfaces/IStandardExecutor.sol"; +import {UserOperation} from "../../interfaces/erc4337/UserOperation.sol"; + +interface ISessionKeyPlugin { + enum FunctionId {USER_OP_VALIDATION_SESSION_KEY} + + error InvalidSessionKey(address sessionKey); + error NotAuthorized(address caller); + error SessionKeyAlreadyExists(address sessionKey); + error SessionKeyNotFound(address sessionKey); + error UnableToRemove(address sessionKey); + + struct SessionKeyToRemove { + address sessionKey; + bytes32 predecessor; + } + + /// @notice Perform a batch execution with a session key. + /// @dev The session key address is included as a parameter so context may be preserved across validation and + /// execution. + /// @param calls The array of calls to be performed. + /// @param sessionKey The session key to be used for the execution. + /// @return The array of return data from the executions. + function executeWithSessionKey(Call[] calldata calls, address sessionKey) external returns (bytes[] memory); + + /// @notice Get the session keys of the account. + /// @return The array of session keys of the account. + function getSessionKeys() external view returns (address[] memory); + + /// @notice Check if a session key is a session key of the account. + /// @param sessionKey The session key to check. + /// @return The boolean whether the session key is a session key of the account. + function isSessionKey(address sessionKey) external view returns (bool); + + /// @notice Add and remove session keys from the account. + /// Note that the session keys to remove will be removed prior to any being added, and they will be removed in + /// order from first to last. If the predecessor changes due to a prior removal, the caller should pass in the + /// updated predecessor. + /// @param sessionKeysToAdd The array of session keys to add to the account. + /// @param sessionKeysToRemove The array of session keys to remove from the account, along with their + /// predecessor in the list. + function updateSessionKeys( + address[] calldata sessionKeysToAdd, + SessionKeyToRemove[] calldata sessionKeysToRemove + ) external; + + /// @notice Get the session keys of the account. + /// This function is not added to accounts during installation. + /// @param account The account to get the session keys of. + /// @return The array of session keys of the account. + function sessionKeysOf(address account) external view returns (address[] memory); + + /// @notice Check if a session key is a session key of the account. + /// This function is not added to accounts during installation. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return The boolean whether the session key is a session key of the account. + function isSessionKeyOf(address account, address sessionKey) external view returns (bool); + + /// @notice Get the list predecessor of a session key. This is used as an extra parameter to make removing + /// session keys more efficient. + /// This function is not added to accounts during installation. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return The list predecessor of the session key. + function findPredecessor(address account, address sessionKey) external view returns (bytes32); +} diff --git a/src/plugins/session/SessionKeyPlugin.sol b/src/plugins/session/SessionKeyPlugin.sol new file mode 100644 index 00000000..041ac61f --- /dev/null +++ b/src/plugins/session/SessionKeyPlugin.sol @@ -0,0 +1,313 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; + +import {BasePlugin} from "../BasePlugin.sol"; +import {ISessionKeyPlugin} from "./ISessionKeyPlugin.sol"; + +import {IPlugin} from "../../interfaces/IPlugin.sol"; +import {IPluginExecutor} from "../../interfaces/IPluginExecutor.sol"; +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata, + SelectorPermission +} from "../../interfaces/IPlugin.sol"; +import {Call, IStandardExecutor} from "../../interfaces/IStandardExecutor.sol"; +import {UserOperation} from "../../interfaces/erc4337/UserOperation.sol"; + +import { + AssociatedLinkedListSet, AssociatedLinkedListSetLib +} from "../../libraries/AssociatedLinkedListSetLib.sol"; +import {CastLib} from "../../libraries/CastLib.sol"; +import {SetValue, SENTINEL_VALUE} from "../../libraries/LinkedListSetUtils.sol"; + +/// @title Session Key Plugin +/// @author Alchemy +/// @notice This plugin allows users to set session keys that can be used to validate user operations performing +/// external calls. It does not enforce any permissions on what the keys can do, that must be configured via other +/// plugins with hooks. +contract SessionKeyPlugin is BasePlugin, ISessionKeyPlugin { + using ECDSA for bytes32; + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + + string internal constant _NAME = "Session Key Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + uint256 internal constant _SIG_VALIDATION_PASSED = 0; + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + // Constants used in the manifest + uint256 internal constant _MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION = 0; + uint256 internal constant _MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION = 1; + + AssociatedLinkedListSet internal _sessionKeys; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc ISessionKeyPlugin + function executeWithSessionKey(Call[] calldata calls, address) external returns (bytes[] memory) { + uint256 callsLength = calls.length; + bytes[] memory results = new bytes[](callsLength); + + for (uint256 i = 0; i < callsLength;) { + Call calldata call = calls[i]; + + results[i] = IPluginExecutor(msg.sender).executeFromPluginExternal(call.target, call.value, call.data); + + unchecked { + ++i; + } + } + + return results; + } + + /// @inheritdoc ISessionKeyPlugin + function getSessionKeys() external view returns (address[] memory) { + SetValue[] memory values = _sessionKeys.getAll(msg.sender); + + return CastLib.toAddressArray(values); + } + + /// @inheritdoc ISessionKeyPlugin + function isSessionKey(address sessionKey) external view returns (bool) { + return _sessionKeys.contains(msg.sender, CastLib.toSetValue(sessionKey)); + } + + /// @inheritdoc ISessionKeyPlugin + function updateSessionKeys( + address[] calldata sessionKeysToAdd, + SessionKeyToRemove[] calldata sessionKeysToRemove + ) external { + uint256 length = sessionKeysToRemove.length; + for (uint256 i = 0; i < length;) { + if ( + !_sessionKeys.tryRemoveKnown( + msg.sender, + CastLib.toSetValue(sessionKeysToRemove[i].sessionKey), + sessionKeysToRemove[i].predecessor + ) + ) { + revert UnableToRemove(sessionKeysToRemove[i].sessionKey); + } + + unchecked { + ++i; + } + } + + length = sessionKeysToAdd.length; + for (uint256 i = 0; i < length;) { + if (!_sessionKeys.tryAdd(msg.sender, CastLib.toSetValue(sessionKeysToAdd[i]))) { + revert SessionKeyAlreadyExists(sessionKeysToAdd[i]); + } + + unchecked { + ++i; + } + } + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc ISessionKeyPlugin + function sessionKeysOf(address account) external view returns (address[] memory) { + SetValue[] memory values = _sessionKeys.getAll(account); + + return CastLib.toAddressArray(values); + } + + /// @inheritdoc ISessionKeyPlugin + function isSessionKeyOf(address account, address sessionKey) external view returns (bool) { + return _sessionKeys.contains(account, CastLib.toSetValue(sessionKey)); + } + + /// @inheritdoc ISessionKeyPlugin + function findPredecessor(address account, address sessionKey) external view returns (bytes32) { + address[] memory sessionKeys = CastLib.toAddressArray(_sessionKeys.getAll(account)); + + uint256 length = sessionKeys.length; + bytes32 predecessor = SENTINEL_VALUE; + for (uint256 i = 0; i < length;) { + if (sessionKeys[i] == sessionKey) { + return predecessor; + } + + predecessor = bytes32(bytes20(sessionKeys[i])); + + unchecked { + ++i; + } + } + + revert SessionKeyNotFound(sessionKey); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function userOpValidationFunction(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.USER_OP_VALIDATION_SESSION_KEY)) { + (, address sessionKey) = abi.decode(userOp.callData[4:], (Call[], address)); + bytes32 hash = userOpHash.toEthSignedMessageHash(); + + if (_sessionKeys.contains(msg.sender, CastLib.toSetValue(sessionKey))) { + (address recoveredSig, ECDSA.RecoverError err) = hash.tryRecover(userOp.signature); + if (err == ECDSA.RecoverError.NoError && sessionKey == recoveredSig) { + return _SIG_VALIDATION_PASSED; + } + } + return _SIG_VALIDATION_FAILED; + } + revert NotImplemented(); + } + + /// @inheritdoc BasePlugin + function _onInstall(bytes calldata data) internal override isNotInitialized(msg.sender) { + address[] memory sessionKeysToAdd = abi.decode(data, (address[])); + + uint256 length = sessionKeysToAdd.length; + for (uint256 i = 0; i < length;) { + address sessionKey = sessionKeysToAdd[i]; + + if (sessionKey == address(0)) { + revert InvalidSessionKey(sessionKey); + } + + if (!_sessionKeys.tryAdd(msg.sender, CastLib.toSetValue(sessionKey))) { + revert SessionKeyAlreadyExists(sessionKeysToAdd[i]); + } + + unchecked { + ++i; + } + } + } + + /// @inheritdoc BasePlugin + function onUninstall(bytes calldata) external override { + _sessionKeys.clear(msg.sender); + } + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.dependencyInterfaceIds = new bytes4[](2); + manifest.dependencyInterfaceIds[_MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION] = + type(IPlugin).interfaceId; + manifest.dependencyInterfaceIds[_MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION] = + type(IPlugin).interfaceId; + + manifest.executionFunctions = new bytes4[](4); + manifest.executionFunctions[0] = this.executeWithSessionKey.selector; + manifest.executionFunctions[1] = this.getSessionKeys.selector; + manifest.executionFunctions[2] = this.isSessionKey.selector; + manifest.executionFunctions[3] = this.updateSessionKeys.selector; + + ManifestFunction memory sessionKeyUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION_SESSION_KEY), + dependencyIndex: 0 // Unused. + }); + ManifestFunction memory ownerUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // unused since it's a dependency + dependencyIndex: _MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION + }); + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](2); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.executeWithSessionKey.selector, + associatedFunction: sessionKeyUserOpValidationFunction + }); + manifest.userOpValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.updateSessionKeys.selector, + associatedFunction: ownerUserOpValidationFunction + }); + + // Session keys are only expected to be used for user op validation, so no runtime validation functions are + // set over + // executeWithSessionKey. + ManifestFunction memory alwaysAllowValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }); + + ManifestFunction memory ownerRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // unused since it's a dependency + dependencyIndex: _MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION + }); + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](3); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.getSessionKeys.selector, + associatedFunction: alwaysAllowValidationFunction + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.isSessionKey.selector, + associatedFunction: alwaysAllowValidationFunction + }); + manifest.runtimeValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.updateSessionKeys.selector, + associatedFunction: ownerRuntimeValidationFunction + }); + + manifest.permitAnyExternalAddress = true; + manifest.canSpendNativeToken = true; + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + + // Permission strings + string memory modifyOwnershipPermission = "Modify Session Keys"; + + // Permission descriptions + metadata.permissionDescriptors = new SelectorPermission[](1); + metadata.permissionDescriptors[0] = SelectorPermission({ + functionSelector: this.updateSessionKeys.selector, + permissionDescription: modifyOwnershipPermission + }); + + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override returns (bool) { + return interfaceId == type(ISessionKeyPlugin).interfaceId || super.supportsInterface(interfaceId); + } + + /// @inheritdoc BasePlugin + function _isInitialized(address account) internal view override returns (bool) { + return !_sessionKeys.isEmpty(account); + } +} diff --git a/src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol b/src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol new file mode 100644 index 00000000..551f9eaa --- /dev/null +++ b/src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "../../../interfaces/erc4337/UserOperation.sol"; + +interface ISessionKeyPermissionsPlugin { + enum FunctionId { + PRE_USER_OP_VALIDATION_HOOK_CHECK_PERMISSIONS, + PRE_EXECUTION_HOOK_UPDATE_LIMITS + } + + enum ContractAccessControlType { + ALLOWLIST, // Allowlist is default + DENYLIST, + NONE + } + + struct SpendLimitInfo { + bool hasLimit; + uint256 limit; + uint256 limitUsed; + uint48 refreshInterval; + uint48 lastUsedTime; + } + + /// @notice Emitted when a session key is registered. + /// @param account The account that owns the session key. + /// @param sessionKey The session key that was registered. + /// @param tag The tag that was associated with the key. + event KeyRegistered(address indexed account, address indexed sessionKey, bytes32 indexed tag); + + /// @notice Emitted when a session key's permissions are updated. + /// @param account The account that owns the session key. + /// @param sessionKey The session key that was updated. + /// @param updates The updates that were performed. Updates are ABI-encoded + event PermissionsUpdated(address indexed account, address indexed sessionKey, bytes[] updates); + + /// @notice Emitted when a session key is rotated, which transfers permissions from one key to another. + /// @param account The account that owns the session key. + /// @param oldSessionKey The session key that was rotated away. + /// @param newSessionKey The session key that was rotated to. + event KeyRotated(address indexed account, address indexed oldSessionKey, address indexed newSessionKey); + + error ERC20SpendLimitExceeded(address account, address sessionKey, address token); + error KeyAlreadyRegistered(address sessionKey); + error KeyNotRegistered(address sessionKey); + error InvalidPermissionsUpdate(); + error InvalidToken(); + error NativeTokenSpendLimitExceeded(address account, address sessionKey); + + /// @notice Register a key with the permissions plugin. Without this step, key cannot be used while the + /// permissions plugin is installed. + /// @param sessionKey The session key to register. + /// @param tag An optional tag that can be used to identify the key. + function registerKey(address sessionKey, bytes32 tag) external; + + /// @notice Move a session key's registration status and existing permissions to another session key. + /// @param oldSessionKey The session key to move. + /// @param newSessionKey The session key to move to. + function rotateKey(address oldSessionKey, address newSessionKey) external; + + /// @notice Performs a sequence of updates to a session key's permissions. These updates are abi-encoded calls + /// to the functions defined in `ISessionKeyPermissionsUpdates`, and are not external functions implemented by + /// this contract. + /// @param sessionKey The session key for which to update permissions. + /// @param updates The abi-encoded updates to perform. + function updateKeyPermissions(address sessionKey, bytes[] calldata updates) external; + + /// @notice An externally available function, callable by anyone, that resets the "last used" timestamp on a + /// session key. This helps a session key get "unstuck" if it was used in a setting where every call it made + /// while using a new interval's gas limit reverted. Since this plugin internally tracks when that reset should + /// happen, this function does not need other validation. + /// @param account The account that owns the session key. + /// @param sessionKey The session key to reset. + function resetSessionKeyGasLimitTimestamp(address account, address sessionKey) external; + + /// @notice Get the access control type for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return The access control type for the session key on the account. + function getAccessControlType(address account, address sessionKey) + external + view + returns (ContractAccessControlType); + + /// @notice Get an access control entry for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @param targetAddress The target address to check. + /// @return isOnList Whether the target address is on the list (either allowlist or blocklist depending on the + /// access control type). + /// @return checkSelectors Whether the target address should be checked for selectors during permissions + /// enforcement. + function getAccessControlEntry(address account, address sessionKey, address targetAddress) + external + view + returns (bool isOnList, bool checkSelectors); + + /// @notice Get whether a selector is on the access control list for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @param targetAddress The target address to check. + /// @param selector The selector to check. + /// @return isOnList Whether the selector is on the list (either allowlist or blocklist depending on the + /// access control type). + function isSelectorOnAccessControlList( + address account, + address sessionKey, + address targetAddress, + bytes4 selector + ) external view returns (bool isOnList); + + /// @notice Get the active time range for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return validAfter The time after which the session key is valid. + /// @return validUntil The time until which the session key is valid. + function getKeyTimeRange(address account, address sessionKey) + external + view + returns (uint48 validAfter, uint48 validUntil); + + /// @notice Get the native token spend limit for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return A struct with fields describing the state of native token spending limits on this session key. + function getNativeTokenSpendLimitInfo(address account, address sessionKey) + external + view + returns (SpendLimitInfo memory); + + /// @notice Get the gas spend limit for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return info A struct with fields describing the state of gas spending limits on this session key. + /// @return shouldReset Whether this session key must be reset by calling `resetSessionKeyGasLimitTimestamp` + /// before it can be used. + function getGasSpendLimit(address account, address sessionKey) + external + view + returns (SpendLimitInfo memory info, bool shouldReset); + + /// @notice Get the ERC20 spend limit for a session key on an account. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @param token The token to check. + /// @return A struct with fields describing the state of ERC20 spending limits on this session key. + function getERC20SpendLimitInfo(address account, address sessionKey, address token) + external + view + returns (SpendLimitInfo memory); + + /// @notice Get the required paymaster address for a session key on an account, if any. + /// @param account The account to check. + /// @param sessionKey The session key to check. + /// @return The required paymaster address for this session key on this account, or the zero address if the + /// rule is disabled. + function getRequiredPaymaster(address account, address sessionKey) external view returns (address); +} diff --git a/src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol b/src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol new file mode 100644 index 00000000..3354d3f7 --- /dev/null +++ b/src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ISessionKeyPermissionsPlugin} from "./ISessionKeyPermissionsPlugin.sol"; + +/// @notice This interface defines the functions that may be used to update a session key's permissions. +/// The functions defined here are not actually implemented, but instead are abi-encoded as arguments to the +/// update function `updateKeyPermissions`. +interface ISessionKeyPermissionsUpdates { + /// @notice Update the session key's "access control type". The access control type determines whether the list + /// is treated as an allowlist, denylist, or if the listis ignored. Note that if the list type is changed, the + /// previous elements from the list are not cleared, and instead reinterpretted as entries in the new list + /// type. + /// @param contractAccessControlType The new access control type. + function setAccessListType(ISessionKeyPermissionsPlugin.ContractAccessControlType contractAccessControlType) + external; + + /// @notice Add or remove a contract address from the access list, optionally specifying whether to check + /// selectors. + /// @param contractAddress The contract address to add or remove. + /// @param isOnList Whether the contract address should be on the list. + /// @param checkSelectors Whether to check selectors for the contract address. + function updateAccessListAddressEntry(address contractAddress, bool isOnList, bool checkSelectors) external; + + /// @notice Add or remove a function selector from the access list. + /// @param contractAddress The contract address to add or remove. + /// @param selector The function selector to add or remove. + /// @param isOnList Whether the function selector should be on the list. + function updateAccessListFunctionEntry(address contractAddress, bytes4 selector, bool isOnList) external; + + /// @notice Sets the time range for a session key. + /// @param validAfter The time after which the session key may be used. + /// @param validUntil The time before which the session key may be used. + function updateTimeRange(uint48 validAfter, uint48 validUntil) external; + + /// @notice Sets the native token spend limit for a session key. This specifies how much of the native token + /// the session key may use, optionally with a refresh interval that specifies how often the limit is reset. + /// @param spendLimit The maximum amount of native token the session key may spend. + /// @param refreshInterval The time interval over which the spend limit is enforced. If zero, there is no time + /// interval by which the limit is refreshed. + function setNativeTokenSpendLimit(uint256 spendLimit, uint48 refreshInterval) external; + + /// @notice Sets the ERC-20 spend limit for a session key. + /// @param token The ERC-20 token address. + /// @param spendLimit The maximum amount of the ERC-20 token the session key may spend. + /// @param refreshInterval The time interval over which the spend limit is enforced. If zero, the spend limit + /// is never refreshed. + function setERC20SpendLimit(address token, uint256 spendLimit, uint48 refreshInterval) external; + + /// @notice Sets the gas spend limit for a session key. Note that the session key permissions enforcement will + /// usually overestimate the gas usage per user operation. + /// @dev If the account is staked, a malicious session key user may abuse gas limits to cause reputation damage + /// to the account. This is because when a gas limit is set, there are state updates during validation that can + /// potentially cause future user ops in the same bundle to fail. Intelligent bundlers may re-simulate and + /// remove the latter ops that exceed the gas limits, but this is not a guarantee. + /// @param spendLimit The maximum amount of native token the session key may spend on gas. This will always be + /// the result of an overestimate, however. + /// @param refreshInterval The time interval by which the spend limit is refreshed. If zero, the spend limit is + /// never refreshed. + function setGasSpendLimit(uint256 spendLimit, uint48 refreshInterval) external; + + /// @notice Sets the required paymaster for a session key. + /// @param requiredPaymaster The required paymaster for the session key. If the rule should be removed, this + /// should be address(0). + function setRequiredPaymaster(address requiredPaymaster) external; +} diff --git a/src/plugins/session/permissions/SessionKeyPermissionsBase.sol b/src/plugins/session/permissions/SessionKeyPermissionsBase.sol new file mode 100644 index 00000000..d41a15f8 --- /dev/null +++ b/src/plugins/session/permissions/SessionKeyPermissionsBase.sol @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ISessionKeyPermissionsPlugin} from "./ISessionKeyPermissionsPlugin.sol"; + +import {PluginStorageLib, StoragePointer} from "../../../libraries/PluginStorageLib.sol"; + +abstract contract SessionKeyPermissionsBase is ISessionKeyPermissionsPlugin { + type SessionKeyId is bytes32; + + struct SessionKeyData { + // Contract access control type + ContractAccessControlType contractAccessControlType; + // Key time range: limits when a key may be used. + uint48 validAfter; + uint48 validUntil; + bool hasRequiredPaymaster; + bool hasGasLimit; + bool gasLimitResetThisBundle; + // Native token spend limits + bool nativeTokenSpendLimitBypassed; // By default, spend limits ARE enforced and the limit is zero. + SpendLimitTimeInfo gasLimitTimeInfo; + SpendLimitTimeInfo nativeTokenSpendLimitTimeInfo; + // Required paymaster rule + address requiredPaymaster; + SpendLimit gasLimit; + SpendLimit nativeTokenSpendLimit; + } + + /// @dev These structs are not held in an Associated Enumerable set, so the elements must be emitted from + /// events to use offchain. + struct ContractData { + bool isOnList; + bool checkSelectors; + bool isERC20WithSpendLimit; + SpendLimitTimeInfo erc20SpendLimitTimeInfo; + SpendLimit erc20SpendLimit; + } + + struct FunctionData { + bool isOnList; + } + + // Spending limit info structs. + // Split into two structs to allow custom storage arrangements. + + struct SpendLimitTimeInfo { + uint48 lastUsed; + uint48 refreshInterval; + } + + struct SpendLimit { + uint256 limitAmount; + uint256 limitUsed; + } + + // PluginStorageLib KEY DEFINITIONS + // When adding a new permission type, you must: + // 1. Add new prefixes here for all stored structs + // 2. Define the key derivation and checking functions in a new file + // 3. Use the checking function in PermissionsCheckerPlugin + + // Prefixes: + bytes4 internal constant SESSION_KEY_ID_PREFIX = bytes4(0x1a01dae4); // bytes4(keccak256("SessionKeyId")) + bytes4 internal constant SESSION_KEY_DATA_PREFIX = bytes4(0x16bff296); // bytes4(keccak256("SessionKeyData")) + bytes4 internal constant CONTRACT_DATA_PREFIX = bytes4(0x634c29f5); // bytes4(keccak256("ContractData")) + bytes4 internal constant FUNCTION_DATA_PREFIX = bytes4(0xd50536f0); // bytes4(keccak256("FunctionData")) + + // KEY DERIVATION + // All of these following keys begin with the associated address, + // the prefix, and a uint224 batch index of zero. + + // All PluginStorageLib keys are, at a minimum, 96 bytes long. + // The first word (32 bytes) is the associated address. + // The second word (32 bytes) is the prefix and batch index concatenated. + // Any subsequent words are the key data. + + // SessionKeyId storage key (96 bytes) + // 12 padding zeros || associated address || SESSION_KEY_ID_PREFIX || batch index || 12 padding zero bytes + // || sessionKey + + // SessionKeyData (96 bytes) + // 12 padding zeros || associated address || SESSION_KEY_DATA_PREFIX || batch index || sessionKeyId + + // ContractData (128 bytes) + // 12 padding zeros || associated address || CONTRACT_DATA_PREFIX || batch index || sessionKeyId + // || 12 padding zero bytes || contractAddress + + // FunctionData (128 bytes) + // 12 padding zeros || associated address || FUNCTION_DATA_PREFIX || batch index || sessionKeyId || selector + // || 8 padding zero bytes || contractAddress + + // Storage fields + mapping(address => uint256) internal _keyIdCounter; + + function _sessionKeyIdOf(address associated, address sessionKey) internal view returns (SessionKeyId keyId) { + uint256 prefixAndBatchIndex = uint256(bytes32(SESSION_KEY_ID_PREFIX)); + bytes memory associatedStorageKey = + PluginStorageLib.allocateAssociatedStorageKey(associated, prefixAndBatchIndex, 1); + StoragePointer ptr = + PluginStorageLib.associatedStorageLookup(associatedStorageKey, bytes32(uint256(uint160(sessionKey)))); + assembly ("memory-safe") { + keyId := sload(ptr) + } + } + + function _assertRegistered(SessionKeyId id, address sessionKey) internal pure { + if (SessionKeyId.unwrap(id) == bytes32(0)) { + revert KeyNotRegistered(sessionKey); + } + } + + function _updateSessionKeyId(address associated, address sessionKey, SessionKeyId newId) internal { + uint256 prefixAndBatchIndex = uint256(bytes32(SESSION_KEY_ID_PREFIX)); + bytes memory associatedStorageKey = + PluginStorageLib.allocateAssociatedStorageKey(associated, prefixAndBatchIndex, 1); + StoragePointer ptr = + PluginStorageLib.associatedStorageLookup(associatedStorageKey, bytes32(uint256(uint160(sessionKey)))); + assembly ("memory-safe") { + sstore(ptr, newId) + } + } + + function _sessionKeyDataOf(address associated, SessionKeyId id) + internal + pure + returns (SessionKeyData storage sessionKeyData) + { + uint256 prefixAndBatchIndex = uint256(bytes32(SESSION_KEY_DATA_PREFIX)); + bytes memory associatedStorageKey = + PluginStorageLib.allocateAssociatedStorageKey(associated, prefixAndBatchIndex, 1); + + bytes32 sessionKeyDataKey = bytes32(abi.encodePacked(SESSION_KEY_DATA_PREFIX, SessionKeyId.unwrap(id))); + return _toSessionKeyData(PluginStorageLib.associatedStorageLookup(associatedStorageKey, sessionKeyDataKey)); + } + + /// @dev Helper function that loads the session key id, asserts it is registered, and returns the session key + /// data and the key id. + function _loadSessionKey(address associated, address sessionKey) + internal + view + returns (SessionKeyData storage sessionKeyData, SessionKeyId keyId) + { + SessionKeyId id = _sessionKeyIdOf(associated, sessionKey); + _assertRegistered(id, sessionKey); + return (_sessionKeyDataOf(associated, id), id); + } + + function _contractDataOf(address associated, SessionKeyId id, address contractAddress) + internal + pure + returns (ContractData storage contractData) + { + uint256 prefixAndBatchIndex = uint256(bytes32(CONTRACT_DATA_PREFIX)); + bytes memory associatedStorageKey = + PluginStorageLib.allocateAssociatedStorageKey(associated, prefixAndBatchIndex, 2); + + bytes32 contractDataKey1 = SessionKeyId.unwrap(id); + bytes32 contractDataKey2 = bytes32(bytes20(contractAddress)); + return _toContractData( + PluginStorageLib.associatedStorageLookup(associatedStorageKey, contractDataKey1, contractDataKey2) + ); + } + + function _functionDataOf(address associated, SessionKeyId id, address contractAddress, bytes4 selector) + internal + pure + returns (FunctionData storage functionData) + { + uint256 prefixAndBatchIndex = uint256(bytes32(FUNCTION_DATA_PREFIX)); + bytes memory associatedStorageKey = + PluginStorageLib.allocateAssociatedStorageKey(associated, prefixAndBatchIndex, 2); + + bytes32 functionDataKey1 = SessionKeyId.unwrap(id); + bytes32 functionDataKey2 = bytes32(selector) | bytes32(uint256(uint160(contractAddress))); + return _toFunctionData( + PluginStorageLib.associatedStorageLookup(associatedStorageKey, functionDataKey1, functionDataKey2) + ); + } + + // Storage pointer interpretation + + function _toSessionKeyData(StoragePointer ptr) internal pure returns (SessionKeyData storage sessionKeyData) { + assembly ("memory-safe") { + sessionKeyData.slot := ptr + } + } + + function _toContractData(StoragePointer ptr) internal pure returns (ContractData storage contractData) { + assembly ("memory-safe") { + contractData.slot := ptr + } + } + + function _toFunctionData(StoragePointer ptr) internal pure returns (FunctionData storage functionData) { + assembly ("memory-safe") { + functionData.slot := ptr + } + } +} diff --git a/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol b/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol new file mode 100644 index 00000000..f3bc3a8d --- /dev/null +++ b/src/plugins/session/permissions/SessionKeyPermissionsLoupe.sol @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ISessionKeyPermissionsPlugin} from "./ISessionKeyPermissionsPlugin.sol"; +import {SessionKeyPermissionsBase} from "./SessionKeyPermissionsBase.sol"; + +abstract contract SessionKeyPermissionsLoupe is SessionKeyPermissionsBase { + /// @inheritdoc ISessionKeyPermissionsPlugin + function getAccessControlType(address account, address sessionKey) + external + view + returns (ContractAccessControlType) + { + (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + return sessionKeyData.contractAccessControlType; + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getAccessControlEntry(address account, address sessionKey, address contractAddress) + external + view + returns (bool isOnList, bool checkSelectors) + { + SessionKeyId keyId = _sessionKeyIdOf(account, sessionKey); + _assertRegistered(keyId, sessionKey); + ContractData storage contractData = _contractDataOf(account, keyId, contractAddress); + return (contractData.isOnList, contractData.checkSelectors); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function isSelectorOnAccessControlList( + address account, + address sessionKey, + address contractAddress, + bytes4 selector + ) external view returns (bool isOnList) { + SessionKeyId keyId = _sessionKeyIdOf(account, sessionKey); + _assertRegistered(keyId, sessionKey); + FunctionData storage functionData = _functionDataOf(account, keyId, contractAddress, selector); + return functionData.isOnList; + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getKeyTimeRange(address account, address sessionKey) + external + view + returns (uint48 validAfter, uint48 validUntil) + { + (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + return (sessionKeyData.validAfter, sessionKeyData.validUntil); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getNativeTokenSpendLimitInfo(address account, address sessionKey) + external + view + returns (SpendLimitInfo memory) + { + (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + + bool hasLimit = !sessionKeyData.nativeTokenSpendLimitBypassed; + + if (hasLimit) { + return SpendLimitInfo({ + hasLimit: true, + limit: sessionKeyData.nativeTokenSpendLimit.limitAmount, + limitUsed: sessionKeyData.nativeTokenSpendLimit.limitUsed, + refreshInterval: sessionKeyData.nativeTokenSpendLimitTimeInfo.refreshInterval, + lastUsedTime: sessionKeyData.nativeTokenSpendLimitTimeInfo.lastUsed + }); + } else { + // The fields aren't cleared until the next time they are set, so report zeros. + return SpendLimitInfo({hasLimit: false, limit: 0, limitUsed: 0, refreshInterval: 0, lastUsedTime: 0}); + } + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getERC20SpendLimitInfo(address account, address sessionKey, address token) + external + view + returns (SpendLimitInfo memory) + { + (, SessionKeyId keyId) = _loadSessionKey(account, sessionKey); + ContractData storage tokenContractData = _contractDataOf(account, keyId, token); + return SpendLimitInfo({ + hasLimit: tokenContractData.isERC20WithSpendLimit, + limit: tokenContractData.erc20SpendLimit.limitAmount, + limitUsed: tokenContractData.erc20SpendLimit.limitUsed, + refreshInterval: tokenContractData.erc20SpendLimitTimeInfo.refreshInterval, + lastUsedTime: tokenContractData.erc20SpendLimitTimeInfo.lastUsed + }); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getRequiredPaymaster(address account, address sessionKey) external view returns (address) { + SessionKeyId id = _sessionKeyIdOf(account, sessionKey); + _assertRegistered(id, sessionKey); + SessionKeyData storage sessionKeyData = _sessionKeyDataOf(account, id); + return sessionKeyData.hasRequiredPaymaster ? sessionKeyData.requiredPaymaster : address(0); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function getGasSpendLimit(address account, address sessionKey) + external + view + override + returns (SpendLimitInfo memory, bool) + { + (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + + bool hasLimit = sessionKeyData.hasGasLimit; + bool shouldReset = sessionKeyData.gasLimitResetThisBundle; + + if (hasLimit) { + return ( + SpendLimitInfo({ + hasLimit: true, + limit: sessionKeyData.gasLimit.limitAmount, + limitUsed: sessionKeyData.gasLimit.limitUsed, + refreshInterval: sessionKeyData.gasLimitTimeInfo.refreshInterval, + lastUsedTime: sessionKeyData.gasLimitTimeInfo.lastUsed + }), + shouldReset + ); + } else { + // The fields aren't cleared until the next time they are set, so report zeros. + return ( + SpendLimitInfo({hasLimit: false, limit: 0, limitUsed: 0, refreshInterval: 0, lastUsedTime: 0}), + shouldReset + ); + } + } +} diff --git a/src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol b/src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol new file mode 100644 index 00000000..6db2b1ab --- /dev/null +++ b/src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol @@ -0,0 +1,838 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import {Call} from "../../../interfaces/IStandardExecutor.sol"; +import {BasePlugin} from "../../BasePlugin.sol"; +import {ISessionKeyPermissionsPlugin} from "./ISessionKeyPermissionsPlugin.sol"; +import {ISessionKeyPermissionsUpdates} from "./ISessionKeyPermissionsUpdates.sol"; +import {ISessionKeyPlugin} from "../ISessionKeyPlugin.sol"; +import {SessionKeyPermissionsLoupe} from "./SessionKeyPermissionsLoupe.sol"; + +import { + IPlugin, + ManifestAssociatedFunction, + ManifestAssociatedFunctionType, + ManifestExecutionHook, + ManifestFunction, + PluginManifest, + PluginMetadata, + SelectorPermission +} from "../../../interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../interfaces/IStandardExecutor.sol"; +import {UserOperation} from "../../../interfaces/erc4337/UserOperation.sol"; + +/// @title Session Key Permissions Plugin +/// @author Alchemy +/// @notice This plugin allows users to configure and enforce permissions on session keys that have been +/// added by SessionKeyPlugin. +contract SessionKeyPermissionsPlugin is ISessionKeyPermissionsPlugin, SessionKeyPermissionsLoupe, BasePlugin { + string internal constant _NAME = "Session Key Permissions Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + // Constants used in the manifest + uint256 internal constant _MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION = 0; + uint256 internal constant _MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION = 1; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc ISessionKeyPermissionsPlugin + function registerKey(address sessionKey, bytes32 tag) external override { + SessionKeyId keyId = _sessionKeyIdOf(msg.sender, sessionKey); + if (SessionKeyId.unwrap(keyId) != 0) { + revert KeyAlreadyRegistered(sessionKey); + } + // Register the key with a new ID, and update the ID counter + // We use pre increment to prevent the first id from being zero + _updateSessionKeyId(msg.sender, sessionKey, SessionKeyId.wrap(bytes32(++_keyIdCounter[msg.sender]))); + + emit KeyRegistered(msg.sender, sessionKey, tag); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function rotateKey(address oldSessionKey, address newSessionKey) external override { + SessionKeyId keyId = _sessionKeyIdOf(msg.sender, oldSessionKey); + _assertRegistered(keyId, oldSessionKey); + _updateSessionKeyId(msg.sender, oldSessionKey, SessionKeyId.wrap(bytes32(0))); + _updateSessionKeyId(msg.sender, newSessionKey, keyId); + + emit KeyRotated(msg.sender, oldSessionKey, newSessionKey); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function updateKeyPermissions(address sessionKey, bytes[] calldata updates) external override { + (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(msg.sender, sessionKey); + + uint256 length = updates.length; + for (uint256 i = 0; i < length;) { + _performSessionKeyPermissionsUpdate(keyId, sessionKeyData, updates[i]); + + unchecked { + ++i; + } + } + + emit PermissionsUpdated(msg.sender, sessionKey, updates); + } + + /// @inheritdoc ISessionKeyPermissionsPlugin + function resetSessionKeyGasLimitTimestamp(address account, address sessionKey) external override { + (SessionKeyData storage sessionKeyData,) = _loadSessionKey(account, sessionKey); + if (sessionKeyData.gasLimitResetThisBundle) { + sessionKeyData.gasLimitResetThisBundle = false; + sessionKeyData.gasLimitTimeInfo.lastUsed = uint48(block.timestamp); + } + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function preUserOpValidationHook(uint8 functionId, UserOperation calldata userOp, bytes32) + external + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_CHECK_PERMISSIONS)) { + return _checkUserOpPermissions(userOp); + } + revert NotImplemented(); + } + + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata data) + external + override + returns (bytes memory) + { + if (functionId == uint8(FunctionId.PRE_EXECUTION_HOOK_UPDATE_LIMITS)) { + _updateLimitsPreExec(msg.sender, data); + } + return ""; + } + + /// @inheritdoc BasePlugin + function onInstall(bytes calldata) external override {} + + /// @inheritdoc BasePlugin + function onUninstall(bytes calldata) external override {} + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.dependencyInterfaceIds = new bytes4[](2); + manifest.dependencyInterfaceIds[_MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION] = + type(IPlugin).interfaceId; + manifest.dependencyInterfaceIds[_MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION] = + type(IPlugin).interfaceId; + + manifest.executionFunctions = new bytes4[](3); + manifest.executionFunctions[0] = ISessionKeyPermissionsPlugin.updateKeyPermissions.selector; + manifest.executionFunctions[1] = ISessionKeyPermissionsPlugin.registerKey.selector; + manifest.executionFunctions[2] = ISessionKeyPermissionsPlugin.rotateKey.selector; + + // Associate the owner's user op and runtime validator with permissions config + ManifestFunction memory ownerUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // unused since it's a dependency + dependencyIndex: _MANIFEST_DEPENDENCY_INDEX_OWNER_USER_OP_VALIDATION + }); + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](3); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.updateKeyPermissions.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.registerKey.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.rotateKey.selector, + associatedFunction: ownerUserOpValidationFunction + }); + + ManifestFunction memory ownerRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // unused since it's a dependency + dependencyIndex: _MANIFEST_DEPENDENCY_INDEX_OWNER_RUNTIME_VALIDATION + }); + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](3); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.updateKeyPermissions.selector, + associatedFunction: ownerRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.registerKey.selector, + associatedFunction: ownerRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.rotateKey.selector, + associatedFunction: ownerRuntimeValidationFunction + }); + + // Apply the "enforcing" pre validation hook and pre exec hook + manifest.preUserOpValidationHooks = new ManifestAssociatedFunction[](1); + manifest.preUserOpValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: ISessionKeyPlugin.executeWithSessionKey.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_CHECK_PERMISSIONS), + dependencyIndex: 0 // Unused. + }) + }); + + manifest.executionHooks = new ManifestExecutionHook[](1); + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: ISessionKeyPlugin.executeWithSessionKey.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_EXECUTION_HOOK_UPDATE_LIMITS), + dependencyIndex: 0 // Unused. + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + + // Permission strings + string memory modifyOwnershipPermission = "Modify Session Key Permissions"; + + // Permission descriptions + metadata.permissionDescriptors = new SelectorPermission[](1); + metadata.permissionDescriptors[0] = SelectorPermission({ + functionSelector: this.updateKeyPermissions.selector, + permissionDescription: modifyOwnershipPermission + }); + + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override returns (bool) { + return + interfaceId == type(ISessionKeyPermissionsPlugin).interfaceId || super.supportsInterface(interfaceId); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Internal / Private functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @dev A pre user op validation hook that checks the permissions of the key used to validate the user op. + function _checkUserOpPermissions(UserOperation calldata userOp) internal returns (uint256) { + // Decode the executions array and the session key from the user op's calldata + (Call[] memory calls, address sessionKey) = abi.decode(userOp.callData[4:], (Call[], address)); + + (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(msg.sender, sessionKey); + + // The session key's start time is the max of the key-specific validAfter and any time restrictions imposed + // by spending limits. + uint48 currentValidAfter = sessionKeyData.validAfter; + uint256 nativeTokenSpend; + + uint256 callsLength = calls.length; + // Only return validation success when there is at least one call + bool validationSuccess = calls.length > 0; + for (uint256 i = 0; i < callsLength;) { + Call memory call = calls[i]; + nativeTokenSpend += call.value; + validationSuccess = validationSuccess + && _checkCallPermissions( + sessionKeyData.contractAccessControlType, keyId, call.target, call.value, call.data + ); + + unchecked { + ++i; + } + } + + if (!sessionKeyData.nativeTokenSpendLimitBypassed) { + (bool spendLimitSuccess, uint48 spendLimitValidAfter) = _checkSpendLimitUsage( + nativeTokenSpend, + sessionKeyData.nativeTokenSpendLimitTimeInfo, + sessionKeyData.nativeTokenSpendLimit + ); + validationSuccess = validationSuccess && spendLimitSuccess; + currentValidAfter = _max(currentValidAfter, spendLimitValidAfter); + } + + if (sessionKeyData.hasGasLimit) { + // Multiplier for the verification gas limit is 3 if there is a paymaster, 1 otherwise. + // This is defined in EntryPoint v0.6.0, which uses the limit for the user op validation + paymaster + // validation, then again for up to two more calls of `postOp`. Later versions of the EntryPoint may + // change this scale factor or the usage of the verificaiton gas limit, so this value should be checked + // and updated if porting this plugin to a newer version of 4337. + uint256 multiplier = userOp.paymasterAndData.length > 0 ? 3 : 1; + uint256 maxGasFee = ( + userOp.callGasLimit + userOp.verificationGasLimit * multiplier + userOp.preVerificationGas + ) * userOp.maxFeePerGas; + (bool gasLimitSuccess, uint48 gasLimitValidAfter) = + _checkAndUpdateGasLimitUsage(maxGasFee, sessionKeyData); + validationSuccess = validationSuccess && gasLimitSuccess; + currentValidAfter = _max(currentValidAfter, gasLimitValidAfter); + + // Gas limit checking is the only type of permissions checking that has state changes performed during + // validation. This can potentially cause reputation damage to staked accounts if multiple user + // operations are accepted into the same bundle, then validation for one of the later operations fails + // due to the state change from the first. To protect from this, we require session keys to use their + // own address as the key portion of the user operation nonce field, in order to guarantee that they + // are used sequentially. + if (uint192(userOp.nonce >> 64) != uint192(uint160(sessionKey))) { + validationSuccess = false; + } + } + + if (sessionKeyData.hasRequiredPaymaster) { + // Technically this following line would right-pad the contents of the paymasterAndData bytes field if + // it is <20 bytes, which seems like it could cause false positive matches. However, the EntryPoint + // validates that if the paymasterAndData field is >0 length, then the first 20 bytes must be a valid + // paymaster address, so this is safe. (It would revert with "AA93 invalid paymasterAndData"). + // Additionally, we don't have to worry about a zero-length paymasterAndData being casted to + // address(0), because then the subsequent check will fail, as it is impossible to have + // sessionKeyData.hasRequiredPaymaster == true and a zero address sessionKeyData.requiredPaymaster, by + // how the rule's updating function works. + address userOpPaymaster = address(bytes20(userOp.paymasterAndData)); + validationSuccess = validationSuccess && (userOpPaymaster == sessionKeyData.requiredPaymaster); + } + // Validation return data is 1 in the case of an invalid signature, + // otherwise a packed struct of the aggregator address (0 here), and two + // 6-byte timestamps indicating the start and end times at which the op + // is valid. + return uint160(!validationSuccess ? 1 : 0) | (uint256(sessionKeyData.validUntil) << 160) + | (uint256(currentValidAfter) << (208)); + } + + /// @dev Checks permissions on a per-call basis. Should be run in the pre user op validation hook once per + /// `Call` struct in the user op's calldata. + function _checkCallPermissions( + ContractAccessControlType accessControlType, + SessionKeyId keyId, + address target, + uint256, /*value*/ + bytes memory callData + ) internal view returns (bool validationSuccess) { + validationSuccess = true; + + // This right-pads the selector variable if callData is <4 bytes. + bytes4 selector = bytes4(callData); + + ContractData storage contractData = _contractDataOf(msg.sender, keyId, target); + + // Validate access control + if (accessControlType == ContractAccessControlType.ALLOWLIST) { + if (!contractData.isOnList) return false; + if (!contractData.checkSelectors) return true; + // If selectors are specified, the function must be on the list. + FunctionData storage functionData = _functionDataOf(msg.sender, keyId, target, selector); + validationSuccess = functionData.isOnList; + } else if (accessControlType == ContractAccessControlType.DENYLIST) { + if (!contractData.isOnList) return true; + if (!contractData.checkSelectors) return false; + // If selectors are specified, the function must not be on the list. + // A denylist with function selectors allows function calls that are not on the list. + FunctionData storage functionData = _functionDataOf(msg.sender, keyId, target, selector); + validationSuccess = !functionData.isOnList; + } + } + + /// @dev Runs as a pre exec hook, and updates the spend limits of the session key in use + function _updateLimitsPreExec(address account, bytes calldata callData) internal { + (Call[] memory calls, address sessionKey) = abi.decode(callData[4:], (Call[], address)); + uint256 callsLength = calls.length; + + uint256 newNativeTokenUsage; + + (SessionKeyData storage sessionKeyData, SessionKeyId keyId) = _loadSessionKey(account, sessionKey); + + for (uint256 i = 0; i < callsLength;) { + Call memory call = calls[i]; + newNativeTokenUsage += call.value; + + ContractData storage contractData = _contractDataOf(account, keyId, call.target); + if (contractData.isERC20WithSpendLimit) { + // Tally up the amount being spent in each call to an ERC-20 contract. + // Since this is a runtime-only check, we can interact with the stored limits after each call in + // the batch and can still enforce the limits correctly. + uint256 spendAmount = _getTokenSpendAmount(account, call.target, call.data); + if ( + !_runtimeUpdateSpendLimitUsage( + spendAmount, contractData.erc20SpendLimitTimeInfo, contractData.erc20SpendLimit + ) + ) { + revert ERC20SpendLimitExceeded(msg.sender, sessionKey, call.target); + } + } + + unchecked { + ++i; + } + } + + if (!sessionKeyData.nativeTokenSpendLimitBypassed) { + // Only run this step if a native token limit is set. + if ( + !_runtimeUpdateSpendLimitUsage( + newNativeTokenUsage, + sessionKeyData.nativeTokenSpendLimitTimeInfo, + sessionKeyData.nativeTokenSpendLimit + ) + ) { + revert NativeTokenSpendLimitExceeded(msg.sender, sessionKey); + } + } + + if (sessionKeyData.gasLimitResetThisBundle) { + // If the gas limit was reset during validation, we must reset the flag here and update the last used + // field to `block.timestamp`. Note that if execution reverts, this step will be undone, and the flag + // will remain set on the key. If there is enough gas still within the next interval to support another + // call that succeeds, then the issue will be fixed. If, however, the gas in the upcoming interval is + // exhausted and the flag remains enabled, that session key will be stuck until the owner or another + // actor forces the last used timestamp to be reset and the flag cleared. + sessionKeyData.gasLimitResetThisBundle = false; + sessionKeyData.gasLimitTimeInfo.lastUsed = uint48(block.timestamp); + } + } + + /// @dev For use within user op validation + function _checkSpendLimitUsage(uint256 newUsage, SpendLimitTimeInfo storage timeInfo, SpendLimit storage limit) + internal + view + returns (bool, uint48) + { + bool validationSuccess; + // This value will be coalesced with the overall key's start time to return the max value, so it is ok to + // declare it as zero here and only use it if needed. + uint48 validAfter; + + uint48 lastUsed = timeInfo.lastUsed; + uint48 refreshInterval = timeInfo.refreshInterval; + + uint256 currentUsage = limit.limitUsed; + uint256 spendLimit = limit.limitAmount; + + // Gracefully report SIG_FAIL on overflow, rather than revert. + uint256 newTotalUsage; + unchecked { + newTotalUsage = newUsage + currentUsage; + if (newTotalUsage < newUsage) { + // If we overflow, fail early. + return (false, 0); + } + } + + if (refreshInterval == 0) { + // We don't have a refresh interval reset, so just check that the spend limits are not exceeded. + // The limits are not updated until the pre exec hook, in order to use `block.timestamp`. + validationSuccess = (newTotalUsage <= spendLimit); + } + // RefreshInterval != 0, meaning we have a time period over which the spend limit resets. + else if (newTotalUsage <= spendLimit) { + // The spend amount here fits within the existing interval, + // so we're OK to just accept the result. + validationSuccess = true; + } + // The spend amount does not fit within the current interval. + // It may or may not fit into the next one. + else if (newUsage <= spendLimit) { + // The spend amount fits into the next interval, so we're OK to accept the result, if we + // wait until the refresh and start of the next interval. + validationSuccess = true; + validAfter = lastUsed + refreshInterval; + } else { + // The spend amount does not fit, even into the next interval, + // so we must reject the operation. + validationSuccess = false; + } + + return (validationSuccess, validAfter); + } + + /// @dev For use within user op validation. Gas limits are both checked and updated within the user op + /// validation phase. + function _checkAndUpdateGasLimitUsage(uint256 newUsage, SessionKeyData storage keyData) + internal + returns (bool, uint48) + { + bool validationSuccess; + uint48 validAfter; + + uint48 lastUsed = keyData.gasLimitTimeInfo.lastUsed; + uint48 refreshInterval = keyData.gasLimitTimeInfo.refreshInterval; + + uint256 currentUsage = keyData.gasLimit.limitUsed; + uint256 gasLimit = keyData.gasLimit.limitAmount; + + bool gasLimitResetThisBundle = keyData.gasLimitResetThisBundle; + + // Gracefully report SIG_FAIL on overflow, rather than revert. + uint256 newTotalUsage; + unchecked { + newTotalUsage = newUsage + currentUsage; + if (newTotalUsage < newUsage) { + // If we overflow, fail early. + return (false, 0); + } + } + + if (refreshInterval == 0) { + // We don't have a refresh interval reset, so just check that the gas limits are not exceeded and + // update their amounts. + validationSuccess = newTotalUsage <= gasLimit; + keyData.gasLimit.limitUsed += newUsage; + } + // RefreshInterval != 0, meaning we have a time period over which the gas limit resets. + else if (newTotalUsage <= gasLimit) { + // The gas amount here fits within the existing refresh interval, + // so we're OK to just accept the result. + validationSuccess = true; + keyData.gasLimit.limitUsed = newTotalUsage; + // If this is an incremental usage after a failed "reset" attempt, then enforce this existing + // validAfter window. + validAfter = (gasLimitResetThisBundle ? lastUsed + refreshInterval : 0); + } + // The gas amount does not fit within the current refresh interval. + // It may or may not fit into the next one, provided the next interval usage has not already started. + else if (newUsage <= gasLimit && !gasLimitResetThisBundle) { + // The gas amount fits into the next refresh interval, so we're OK to accept the result, if we + // wait until the start of the next refresh interval. + validationSuccess = true; + validAfter = lastUsed + refreshInterval; + + // NOTE: This section is different than the other spend limit checks, due to how gas limits are + // updated during validation, and the fact that `block.timestamp` is inaccessible during + // validation. + // If we allow this check to complete at this point without updating some state to indicate that a new + // interval has started, there is a risk that this particular call + // path can cause a session key to burn more gas per time than the limit was set at. This can + // happen if the "new interval" case keeps getting triggered while the execution phase reverts, + // due to the fact that those reverts will undo the state change updating the "last used time" + // variable. To address this, we set a flag here called `gasLimitResetThisBundle` to indicate that + // during execution, the plugin should attempt to update the last used time to the current + // `block.timestamp`. + + keyData.gasLimitResetThisBundle = true; + keyData.gasLimit.limitUsed = newUsage; + } else { + // The gas amount does not fit, even into the next refresh interval, + // so we must reject the operation. + validationSuccess = false; + } + + return (validationSuccess, validAfter); + } + + /// @dev Re-check and update the spend limit during the execution phase. + /// We MUST re-check the limits, despite the fact that they are checked during validation. + // This is to protect from the case where multiple user operations are included in the same bundle, which + // can happen either if the account is staked or if the bundle is sent by someone other than a 4337-compliant + // bundler. + function _runtimeUpdateSpendLimitUsage( + uint256 newUsage, + SpendLimitTimeInfo storage timeInfo, + SpendLimit storage limit + ) internal returns (bool) { + uint48 refreshInterval = timeInfo.refreshInterval; + uint48 lastUsed = timeInfo.lastUsed; + uint256 spendLimit = limit.limitAmount; + uint256 currentUsage = limit.limitUsed; + + if (refreshInterval == 0 || lastUsed + refreshInterval > block.timestamp) { + // We either don't have a refresh interval, or the current one is still active. + + // Must re-check the limits to handle changes due to other user ops. + // We manually check for overflows here to give a more informative error message. + uint256 sum; + unchecked { + sum = newUsage + currentUsage; + } + if (sum < newUsage || sum > spendLimit) { + // If we overflow, or if the limit is exceeded, fail here and revert in the parent context. + return false; + } + + // We won't update the refresh interval last used variable now, so just update the spend limit. + limit.limitUsed += newUsage; + } else { + // We have a interval active that is currently resetting. + // Must re-check the amount to handle changes due to other user ops. + // It only needs to fit within the new refresh interval, since the old one has passed. + if (newUsage > spendLimit) { + return false; + } + + // The refresh interval has passed, so we can reset the spend limit to the new usage. + limit.limitUsed = newUsage; + timeInfo.lastUsed = uint48(block.timestamp); + } + + return true; + } + + // ERC-20 decoding logic + + /// @notice Decode the amount of a token a call is sending. + /// @dev This only supports the following standard ERC-20 functions: + /// - transfer(address,uint256) + /// - approve(address,uint256) + /// When decoding the approve function, this will first check the existing allowance of the spender. This + /// lookup is not necessarily in storage associated with the account, so this check should only be used during + /// runtime, not user op validation. + /// @param account The account that is sending the transaction. + /// @param callData The calldata of the transaction. + /// @return The amount of the token being sent. Zero if the call is not recognized as a spend. + function _getTokenSpendAmount(address account, address token, bytes memory callData) + internal + view + returns (uint256) + { + // Get the selector. + // Right-padding with zeroes here is OK, because none of the selectors we're comparing this to have + // trailing zero bytes. + bytes4 selector = bytes4(callData); + + if (selector == IERC20.transfer.selector) { + // Expected length: 68 bytes (4 selector + 32 address + 32 amount) + if (callData.length < 68) { + return 0; + } + + // Load the amount being sent. + // Solidity doesn't support access a whole word from a bytes memory at once, only a single byte, and + // trying to use abi.decode would require copying the data to remove the selector, which is expensive. + // Instead, we use inline assembly to load the amount directly. This is safe because we've checked the + // length of the call data. + uint256 amount; + assembly ("memory-safe") { + // Jump 68 words forward: 32 for the length field, 4 for the selector, and 32 for the to address. + amount := mload(add(callData, 68)) + } + return amount; + } else if (selector == IERC20.approve.selector) { + // Expected length: 68 bytes (4 selector + 32 address + 32 amount) + if (callData.length < 68) { + return 0; + } + // We must check the existing allowance of the spender. + address spender; + assembly ("memory-safe") { + // Jump 36 words forward: 32 for the length field and 4 for the selector. + spender := mload(add(callData, 36)) + // Mask out the upper 12 bytes of the address, since we only care about the lower 20 bytes. + // If the upper bits are nonzero, typically the token contract should revert as the input is + // malformed. We mask it here only as a precaution for tokens that may not fully conform to the ABI + // standard. + spender := and(spender, shr(96, not(0))) + } + uint256 existingAllowance = IERC20(token).allowance(account, spender); + uint256 approveAmount; + assembly ("memory-safe") { + // Jump 68 words forward: 32 for the length field, 4 for the selector, and 32 for the spender + // address. + approveAmount := mload(add(callData, 68)) + } + // We only consider this spending if the new allowance is greater than the existing allowance. + if (approveAmount <= existingAllowance) { + return 0; + } + + // Return the difference between the new allowance and the existing allowance. + // Unchecked is OK here since we've asserted the new allowance is greater than the existing allowance. + unchecked { + return approveAmount - existingAllowance; + } + + // There is an odd edge-case with the approval amount check. Since multiple approves may be batched, if + // the first approve lowers the allowance but the second one raises it by an amount that's allowed + // within the spend limits, the calls will be permitted. This won't let the session key actually spend + // more than expected, but the spender contract may experience their allowance going down from the + // previous amount by more than the spending limit, then back up to the previous amount plus the + // spending limit. + } + // Unrecognzied function selector + return 0; + } + + // Permissions updating functions + + function _performSessionKeyPermissionsUpdate( + SessionKeyId keyId, + SessionKeyData storage sessionKeyData, + bytes calldata update + ) internal { + if (update.length < 4) { + revert InvalidPermissionsUpdate(); + } + + bytes4 updateSelector = bytes4(update); + + // If/else chain to find the right interal update function to perform. + if (updateSelector == ISessionKeyPermissionsUpdates.setAccessListType.selector) { + ContractAccessControlType contractAccessControlType = + abi.decode(update[4:], (ContractAccessControlType)); + _setAccessListType(sessionKeyData, contractAccessControlType); + } else if (updateSelector == ISessionKeyPermissionsUpdates.updateAccessListAddressEntry.selector) { + (address contractAddress, bool isOnList, bool checkSelectors) = + abi.decode(update[4:], (address, bool, bool)); + _updateAccessListAddressEntry(keyId, contractAddress, isOnList, checkSelectors); + } else if (updateSelector == ISessionKeyPermissionsUpdates.updateAccessListFunctionEntry.selector) { + (address contractAddress, bytes4 selector, bool isOnList) = + abi.decode(update[4:], (address, bytes4, bool)); + _updateAccessListFunctionEntry(keyId, contractAddress, selector, isOnList); + } else if (updateSelector == ISessionKeyPermissionsUpdates.updateTimeRange.selector) { + (uint48 validAfter, uint48 validUntil) = abi.decode(update[4:], (uint48, uint48)); + _updateTimeRange(sessionKeyData, validAfter, validUntil); + } else if (updateSelector == ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit.selector) { + (uint256 ethSpendLimit, uint48 refreshInterval) = abi.decode(update[4:], (uint256, uint48)); + _setNativeTokenSpendLimit(sessionKeyData, ethSpendLimit, refreshInterval); + } else if (updateSelector == ISessionKeyPermissionsUpdates.setERC20SpendLimit.selector) { + (address token, uint256 spendLimit, uint48 refreshInterval) = + abi.decode(update[4:], (address, uint256, uint48)); + _setERC20SpendLimit(keyId, token, spendLimit, refreshInterval); + } else if (updateSelector == ISessionKeyPermissionsUpdates.setGasSpendLimit.selector) { + (uint256 spendLimit, uint48 refreshInterval) = abi.decode(update[4:], (uint256, uint48)); + _setGasSpendLimit(sessionKeyData, spendLimit, refreshInterval); + } else if (updateSelector == ISessionKeyPermissionsUpdates.setRequiredPaymaster.selector) { + address requiredPaymaster = abi.decode(update[4:], (address)); + _setRequiredPaymaster(sessionKeyData, requiredPaymaster); + } else { + revert InvalidPermissionsUpdate(); + } + } + + function _setAccessListType( + SessionKeyData storage sessionKeyData, + ContractAccessControlType contractAccessControlType + ) internal { + sessionKeyData.contractAccessControlType = contractAccessControlType; + } + + function _updateAccessListAddressEntry( + SessionKeyId keyId, + address contractAddress, + bool isOnList, + bool checkSelectors + ) internal { + ContractData storage contractData = _contractDataOf(msg.sender, keyId, contractAddress); + contractData.isOnList = isOnList; + contractData.checkSelectors = checkSelectors; + } + + function _updateAccessListFunctionEntry( + SessionKeyId keyId, + address contractAddress, + bytes4 selector, + bool isOnList + ) internal { + FunctionData storage functionData = _functionDataOf(msg.sender, keyId, contractAddress, selector); + functionData.isOnList = isOnList; + } + + function _updateTimeRange(SessionKeyData storage sessionKeyData, uint48 validAfter, uint48 validUntil) + internal + { + sessionKeyData.validAfter = validAfter; + sessionKeyData.validUntil = validUntil; + } + + function _setNativeTokenSpendLimit( + SessionKeyData storage sessionKeyData, + uint256 ethSpendLimit, + uint48 refreshInterval + ) internal { + // The flag we store for native token spend is inverted from other similar flags + sessionKeyData.nativeTokenSpendLimitBypassed = !_updateSpendLimits( + ethSpendLimit, + refreshInterval, + sessionKeyData.nativeTokenSpendLimitTimeInfo, + sessionKeyData.nativeTokenSpendLimit + ); + } + + function _setERC20SpendLimit(SessionKeyId keyId, address token, uint256 spendLimit, uint48 refreshInterval) + internal + { + if (token == address(0)) { + revert InvalidToken(); + } + + ContractData storage tokenContractData = _contractDataOf(msg.sender, keyId, token); + + tokenContractData.isERC20WithSpendLimit = _updateSpendLimits( + spendLimit, + refreshInterval, + tokenContractData.erc20SpendLimitTimeInfo, + tokenContractData.erc20SpendLimit + ); + } + + function _setGasSpendLimit(SessionKeyData storage sessionKeyData, uint256 spendLimit, uint48 refreshInterval) + internal + { + // Start by clearing the reset flag, if set. + if (sessionKeyData.gasLimitResetThisBundle) { + sessionKeyData.gasLimitResetThisBundle = false; + // Don't need to update the last used timestamp, since that will be updated within _updateSpendLimits. + } + sessionKeyData.hasGasLimit = _updateSpendLimits( + spendLimit, refreshInterval, sessionKeyData.gasLimitTimeInfo, sessionKeyData.gasLimit + ); + } + + function _setRequiredPaymaster(SessionKeyData storage sessionKeyData, address requiredPaymaster) internal { + if (requiredPaymaster == address(0)) { + sessionKeyData.hasRequiredPaymaster = false; + } else { + sessionKeyData.hasRequiredPaymaster = true; + sessionKeyData.requiredPaymaster = requiredPaymaster; + } + } + + /// @dev A helper function re-used across the spend limit updating functions. + function _updateSpendLimits( + uint256 newLimit, + uint48 newRefreshInterval, + SpendLimitTimeInfo storage timeInfo, + SpendLimit storage spendLimit + ) internal returns (bool isEnabled) { + if (newLimit == type(uint256).max) { + isEnabled = false; + // This field must be manually cleared to have the expected behavior if the spend limit is re-enabled + // in the future. + // Other fields are implicity replaced once the spend limit is configured. + spendLimit.limitUsed = 0; + } else { + isEnabled = true; + spendLimit.limitAmount = newLimit; + timeInfo.refreshInterval = newRefreshInterval; + if (newRefreshInterval == 0) { + timeInfo.lastUsed = 0; + } else { + timeInfo.lastUsed = uint48(block.timestamp); + } + } + } + + function _max(uint48 a, uint48 b) internal pure returns (uint48) { + return a > b ? a : b; + } +} diff --git a/test/TestUtils.sol b/test/TestUtils.sol new file mode 100644 index 00000000..05b40c6c --- /dev/null +++ b/test/TestUtils.sol @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +contract TestUtils is Test { + function printStorageReadsAndWrites(address addr) internal { + (bytes32[] memory accountReads, bytes32[] memory accountWrites) = vm.accesses(addr); + for (uint256 i = 0; i < accountWrites.length; i++) { + bytes32 valWritten = vm.load(addr, accountWrites[i]); + console.log( + string.concat("write loc: ", vm.toString(accountWrites[i]), " val: ", vm.toString(valWritten)) + ); + } + + for (uint256 i = 0; i < accountReads.length; i++) { + bytes32 valRead = vm.load(addr, accountReads[i]); + console.log(string.concat("read: ", vm.toString(accountReads[i]), " val: ", vm.toString(valRead))); + } + } +} diff --git a/test/Utils.sol b/test/Utils.sol new file mode 100644 index 00000000..6c4a7ef0 --- /dev/null +++ b/test/Utils.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +library Utils { + function reverseAddressArray(address[] calldata array) public pure returns (address[] memory reversedArray) { + uint256 len = array.length; + reversedArray = new address[](len); + for (uint256 i; i < len; i++) { + reversedArray[i] = array[len - i - 1]; + } + } +} diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol new file mode 100644 index 00000000..cdbaf18f --- /dev/null +++ b/test/account/AccountExecHooks.t.sol @@ -0,0 +1,1318 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; +import { + IPlugin, + ManifestExecutionHook, + PluginManifest, + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction +} from "../../src/interfaces/IPlugin.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +contract UpgradeableModularAccountExecHooksTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + MockPlugin public mockPlugin1; + MockPlugin public mockPlugin2; + bytes32 public manifestHash1; + bytes32 public manifestHash2; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + + bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); + uint8 internal constant _PRE_HOOK_FUNCTION_ID_1 = 1; + uint8 internal constant _POST_HOOK_FUNCTION_ID_2 = 2; + uint8 internal constant _PRE_HOOK_FUNCTION_ID_3 = 3; + uint8 internal constant _POST_HOOK_FUNCTION_ID_4 = 4; + + PluginManifest public m1; + PluginManifest public m2; + + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + // emitted by MockPlugin + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + + (owner1, owner1Key) = makeAddrAndKey("owner1"); + address impl = address(new UpgradeableModularAccount(IEntryPoint(address(entryPoint)))); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + address[] memory owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + vm.deal(address(account1), 100 ether); + + entryPoint.depositTo{value: 1 wei}(address(account1)); + + m1.executionFunctions.push(_EXEC_SELECTOR); + + m1.runtimeValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: _EXEC_SELECTOR, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Expected execution: [1, null] + function test_preExecHook_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Expected execution: [1, null] + function test_preExecHook_run() public { + test_preExecHook_install(); + + vm.startPrank(owner1); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + function testFuzz_preExecHook_revertData(bytes memory hookRevertReason) public { + vm.startPrank(owner1); + MockPlugin hookPlugin = _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + vm.mockCallRevert( + address(hookPlugin), + abi.encodeCall(IPlugin.preExecutionHook, (1, owner1, 0, abi.encodeWithSelector(_EXEC_SELECTOR))), + hookRevertReason + ); + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq( + returnData, + abi.encodeWithSelector( + UpgradeableModularAccount.PreExecHookReverted.selector, address(hookPlugin), 1, hookRevertReason + ) + ); + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Expected execution: [1, null] + function test_preExecHook_uninstall() public { + test_preExecHook_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + function test_preExecHook_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq(returnData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_execHookPair_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_execHookPair_run() public { + test_execHookPair_install(); + + vm.startPrank(owner1); + + vm.expectEmit(true, true, true, true); + // pre hook call + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + vm.expectEmit(true, true, true, true); + // exec call + emit ReceivedCall(abi.encodePacked(_EXEC_SELECTOR), 0); + vm.expectEmit(true, true, true, true); + // post hook call + emit ReceivedCall( + abi.encodeCall( + IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, abi.encode(_PRE_HOOK_FUNCTION_ID_1)) + ), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_execHookPair_uninstall() public { + test_execHookPair_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyExecHook_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyExecHook_run() public { + test_postOnlyExecHook_install(); + + vm.startPrank(owner1); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyExecHook_uninstall() public { + test_postOnlyExecHook_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, null] + function test_overlappingPreExecHooks_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + // Install a second plugin that applies the first plugin's hook to the same selector. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _PRE_HOOK_FUNCTION_ID_1); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, null] + function test_overlappingPreExecHooks_run() public { + test_overlappingPreExecHooks_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, null] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, null] + function test_overlappingPreExecHooks_uninstall() public { + test_overlappingPreExecHooks_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre hook to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairs_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install a second plugin that applies the first plugin's hook pair to the same selector. + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _PRE_HOOK_FUNCTION_ID_1); + dependencies[1] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 1 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairs_run() public { + test_overlappingExecHookPairs_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 2] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairs_uninstall() public { + test_overlappingExecHookPairs_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingExecHookPairsOnPost_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_3, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingExecHookPairsOnPost_run() public { + test_overlappingExecHookPairsOnPost_install(); + + vm.startPrank(owner1); + + // Expect each pre hook to be called once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_3, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_3) // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingExecHookPairsOnPost_uninstall() public { + test_overlappingExecHookPairsOnPost_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingExecHookPairsOnPre_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _PRE_HOOK_FUNCTION_ID_1); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_4, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingExecHookPairsOnPre_run() public { + test_overlappingExecHookPairsOnPre_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called twice, each passing data over to their respective post hooks. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 2 + ); + + // Expect each post hook to be called once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_4, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingExecHookPairsOnPre_uninstall() public { + test_overlappingExecHookPairsOnPre_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairsOnPreWithNullPost_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _PRE_HOOK_FUNCTION_ID_1); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairsOnPreWithNullPost_run() public { + test_overlappingExecHookPairsOnPreWithNullPost_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [1, null] + /// Expected execution: [1, 2] + function test_overlappingExecHookPairsOnPreWithNullPost_uninstall() public { + test_overlappingExecHookPairsOnPreWithNullPost_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingExecHookPairsOnPostWithNullPre_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingExecHookPairsOnPostWithNullPre_run() public { + test_overlappingExecHookPairsOnPostWithNullPre_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingExecHookPairsOnPostWithNullPre_uninstall() public { + test_overlappingExecHookPairsOnPostWithNullPre_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostExecHooks_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostExecHooks_run() public { + test_overlappingPostExecHooks_install(); + + vm.startPrank(owner1); + + // Expect the post hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [null, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostExecHooks_uninstall() public { + test_overlappingPostExecHooks_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the pre/post hooks to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_execHooksWithPostOnlyForNativeFunction_install() public { + vm.startPrank(owner1); + + // Install the first plugin. + _installPlugin1WithHooks( + UpgradeableModularAccount.execute.selector, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + // Install the second plugin. + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), _POST_HOOK_FUNCTION_ID_2); + _installPlugin2WithHooks( + UpgradeableModularAccount.execute.selector, + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook pair: [1, 2] + /// Plugin 2 hook pair: [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_execHooksWithPostOnlyForNativeFunction_run() public { + test_execHooksWithPostOnlyForNativeFunction_install(); + + vm.startPrank(owner1); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(UpgradeableModularAccount.execute.selector, address(0), 0, "") + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + + account1.execute(address(0), 0, ""); + + vm.stopPrank(); + } + + function _installPlugin1WithHooks( + bytes4 selector, + ManifestFunction memory preHook, + ManifestFunction memory postHook + ) internal returns (MockPlugin) { + m1.executionHooks.push(ManifestExecutionHook(selector, preHook, postHook)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + return mockPlugin1; + } + + function _installPlugin2WithHooks( + bytes4 selector, + ManifestFunction memory preHook, + ManifestFunction memory postHook, + FunctionReference[] memory dependencies + ) internal { + if (preHook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } + if (postHook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } + + m2.executionHooks.push(ManifestExecutionHook(selector, preHook, postHook)); + + mockPlugin2 = new MockPlugin(m2); + manifestHash2 = keccak256(abi.encode(mockPlugin2.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin2), manifestHash2, dependencies, new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin2), + manifestHash: manifestHash2, + pluginInitData: bytes(""), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _uninstallPlugin(MockPlugin plugin) internal { + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onUninstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + + account1.uninstallPlugin(address(plugin), bytes(""), bytes(""), new bytes[](0)); + } +} diff --git a/test/account/AccountLoupe.t.sol b/test/account/AccountLoupe.t.sol new file mode 100644 index 00000000..23ec5ef4 --- /dev/null +++ b/test/account/AccountLoupe.t.sol @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import { + ManifestAssociatedFunctionType, + ManifestExecutionHook, + ManifestFunction, + PluginManifest +} from "../../src/interfaces/IPlugin.sol"; +import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +contract AccountLoupeTest is Test { + IEntryPoint public entryPoint; + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + ComprehensivePlugin public comprehensivePlugin; + + UpgradeableModularAccount public account1; + + FunctionReference public ownerUserOpValidation; + FunctionReference public ownerRuntimeValidation; + + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + comprehensivePlugin = new ComprehensivePlugin(); + + address[] memory owners = new address[](1); + owners[0] = address(this); + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + + bytes32 manifestHash = keccak256(abi.encode(comprehensivePlugin.pluginManifest())); + account1.installPlugin( + address(comprehensivePlugin), + manifestHash, + "", + new FunctionReference[](0), + new IPluginManager.InjectedHook[](0) + ); + + ownerUserOpValidation = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + ownerRuntimeValidation = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + } + + function test_pluginLoupe_getInstalledPlugins_initial() public { + address[] memory plugins = account1.getInstalledPlugins(); + + assertEq(plugins.length, 2); + + assertEq(plugins[1], address(multiOwnerPlugin)); + assertEq(plugins[0], address(comprehensivePlugin)); + } + + function test_pluginLoupe_getExecutionFunctionConfig_native() public { + bytes4[] memory selectorsToCheck = new bytes4[](5); + FunctionReference[] memory expectedUserOpValidations = new FunctionReference[](5); + FunctionReference[] memory expectedRuntimeValidations = new FunctionReference[](5); + + selectorsToCheck[0] = IStandardExecutor.execute.selector; + expectedUserOpValidations[0] = ownerUserOpValidation; + expectedRuntimeValidations[0] = ownerRuntimeValidation; + + selectorsToCheck[1] = IStandardExecutor.executeBatch.selector; + expectedUserOpValidations[1] = ownerUserOpValidation; + expectedRuntimeValidations[1] = ownerRuntimeValidation; + + selectorsToCheck[2] = UUPSUpgradeable.upgradeToAndCall.selector; + expectedUserOpValidations[2] = ownerUserOpValidation; + expectedRuntimeValidations[2] = ownerRuntimeValidation; + + selectorsToCheck[3] = IPluginManager.installPlugin.selector; + expectedUserOpValidations[3] = ownerUserOpValidation; + expectedRuntimeValidations[3] = ownerRuntimeValidation; + + selectorsToCheck[4] = IPluginManager.uninstallPlugin.selector; + expectedUserOpValidations[4] = ownerUserOpValidation; + expectedRuntimeValidations[4] = ownerRuntimeValidation; + + for (uint256 i = 0; i < selectorsToCheck.length; i++) { + IAccountLoupe.ExecutionFunctionConfig memory config = + account1.getExecutionFunctionConfig(selectorsToCheck[i]); + + assertEq(config.plugin, address(account1)); + assertEq( + FunctionReference.unwrap(config.userOpValidationFunction), + FunctionReference.unwrap(expectedUserOpValidations[i]) + ); + assertEq( + FunctionReference.unwrap(config.runtimeValidationFunction), + FunctionReference.unwrap(expectedRuntimeValidations[i]) + ); + } + } + + function test_pluginLoupe_getExecutionFunctionConfig_plugin() public { + bytes4[] memory selectorsToCheck = new bytes4[](2); + address[] memory expectedPluginAddress = new address[](2); + FunctionReference[] memory expectedUserOpValidations = new FunctionReference[](2); + FunctionReference[] memory expectedRuntimeValidations = new FunctionReference[](2); + + selectorsToCheck[0] = comprehensivePlugin.foo.selector; + expectedPluginAddress[0] = address(comprehensivePlugin); + expectedUserOpValidations[0] = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.USER_OP_VALIDATION) + ); + expectedRuntimeValidations[0] = FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.RUNTIME_VALIDATION) + ); + + selectorsToCheck[1] = multiOwnerPlugin.updateOwners.selector; + expectedPluginAddress[1] = address(multiOwnerPlugin); + expectedUserOpValidations[1] = ownerUserOpValidation; + expectedRuntimeValidations[1] = ownerRuntimeValidation; + + for (uint256 i = 0; i < selectorsToCheck.length; i++) { + IAccountLoupe.ExecutionFunctionConfig memory config = + account1.getExecutionFunctionConfig(selectorsToCheck[i]); + + assertEq(config.plugin, expectedPluginAddress[i]); + assertEq( + FunctionReference.unwrap(config.userOpValidationFunction), + FunctionReference.unwrap(expectedUserOpValidations[i]) + ); + assertEq( + FunctionReference.unwrap(config.runtimeValidationFunction), + FunctionReference.unwrap(expectedRuntimeValidations[i]) + ); + } + } + + function test_pluginLoupe_getExecutionHooks() public { + IAccountLoupe.ExecutionHooks[] memory hooks = account1.getExecutionHooks(comprehensivePlugin.foo.selector); + + assertEq(hooks.length, 2); + + _assertHookEq( + hooks[0], + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_EXECUTION_HOOK) + ), + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + + _assertHookEq( + hooks[1], + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE, + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + } + + function test_pluginLoupe_getPermittedCallHooks() public { + IAccountLoupe.ExecutionHooks[] memory hooks = + account1.getPermittedCallHooks(address(comprehensivePlugin), comprehensivePlugin.foo.selector); + + assertEq(hooks.length, 2); + + _assertHookEq( + hooks[0], + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.PRE_PERMITTED_CALL_EXECUTION_HOOK) + ), + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.POST_PERMITTED_CALL_EXECUTION_HOOK) + ) + ); + + _assertHookEq( + hooks[1], + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE, + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.POST_PERMITTED_CALL_EXECUTION_HOOK) + ) + ); + } + + function test_pluginLoupe_getHooks_multiple() public { + // Add a third set of execution hooks to the account, and validate that it can return all hooks applied + // over the function. + + PluginManifest memory mockPluginManifest; + + mockPluginManifest.executionHooks = new ManifestExecutionHook[](1); + mockPluginManifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }) + }); + + mockPluginManifest.permittedCallHooks = new ManifestExecutionHook[](2); + // Copy over the same hooks from executionHooks. + mockPluginManifest.permittedCallHooks[0] = mockPluginManifest.executionHooks[0]; + mockPluginManifest.permittedCallHooks[1] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }) + }); + + MockPlugin mockPlugin = new MockPlugin(mockPluginManifest); + bytes32 manifestHash = keccak256(abi.encode(mockPlugin.pluginManifest())); + + account1.installPlugin( + address(mockPlugin), manifestHash, "", new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + // Assert that the returned execution hooks are what is expected + + IAccountLoupe.ExecutionHooks[] memory hooks = account1.getExecutionHooks(comprehensivePlugin.foo.selector); + + assertEq(hooks.length, 3); + + _assertHookEq( + hooks[0], + FunctionReferenceLib.pack(address(mockPlugin), uint8(0)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(0)) + ); + + _assertHookEq( + hooks[1], + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_EXECUTION_HOOK) + ), + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + + _assertHookEq( + hooks[2], + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE, + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + + // Assert that the returned permitted call hooks are what is expected + + hooks = account1.getPermittedCallHooks(address(mockPlugin), comprehensivePlugin.foo.selector); + + assertEq(hooks.length, 2); + + _assertHookEq( + hooks[0], + FunctionReferenceLib.pack(address(mockPlugin), uint8(1)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(1)) + ); + + _assertHookEq( + hooks[1], + FunctionReferenceLib.pack(address(mockPlugin), uint8(0)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(0)) + ); + } + + function test_pluginLoupe_getPreValidationHooks() public { + (FunctionReference[] memory preUoHooks, FunctionReference[] memory preRuntimeHooks) = + account1.getPreValidationHooks(comprehensivePlugin.foo.selector); + + // veriry pre UO hooks + assertEq(preUoHooks.length, 2); + assertEq( + FunctionReference.unwrap(preUoHooks[0]), + FunctionReference.unwrap( + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.PRE_USER_OP_VALIDATION_HOOK_2) + ) + ) + ); + assertEq( + FunctionReference.unwrap(preUoHooks[1]), + FunctionReference.unwrap( + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.PRE_USER_OP_VALIDATION_HOOK_1) + ) + ) + ); + + // veriry pre runtime hooks + assertEq(preRuntimeHooks.length, 2); + assertEq( + FunctionReference.unwrap(preRuntimeHooks[0]), + FunctionReference.unwrap( + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.PRE_RUNTIME_VALIDATION_HOOK_2) + ) + ) + ); + assertEq( + FunctionReference.unwrap(preRuntimeHooks[1]), + FunctionReference.unwrap( + FunctionReferenceLib.pack( + address(comprehensivePlugin), + uint8(ComprehensivePlugin.FunctionId.PRE_RUNTIME_VALIDATION_HOOK_1) + ) + ) + ); + } + + function test_pluginLoupe_getExecutionHooks_overlapping() public { + PluginManifest memory mockPluginManifest; + + mockPluginManifest.executionHooks = new ManifestExecutionHook[](9); + + // [0, null] + mockPluginManifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }) + }); + + // [0, null] + mockPluginManifest.executionHooks[1] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }) + }); + + // [1, 2] + mockPluginManifest.executionHooks[2] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + // [1, 2] + mockPluginManifest.executionHooks[3] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + // [3, 2] + mockPluginManifest.executionHooks[4] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 3, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + // [1, 4] + mockPluginManifest.executionHooks[5] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 4, + dependencyIndex: 0 + }) + }); + + // [1, null] + mockPluginManifest.executionHooks[6] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }) + }); + + // [null, 2] + mockPluginManifest.executionHooks[7] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + // [null, 2] + mockPluginManifest.executionHooks[8] = ManifestExecutionHook({ + executionSelector: ComprehensivePlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + MockPlugin mockPlugin = new MockPlugin(mockPluginManifest); + bytes32 manifestHash = keccak256(abi.encode(mockPlugin.pluginManifest())); + + account1.installPlugin( + address(mockPlugin), manifestHash, "", new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + // Assert that the returned execution hooks are what is expected + + IAccountLoupe.ExecutionHooks[] memory hooks = account1.getExecutionHooks(comprehensivePlugin.foo.selector); + + assertEq(hooks.length, 7); + + // [3, 2] + _assertHookEq( + hooks[0], + FunctionReferenceLib.pack(address(mockPlugin), uint8(3)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(2)) + ); + + // [1, 4] + _assertHookEq( + hooks[1], + FunctionReferenceLib.pack(address(mockPlugin), uint8(1)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(4)) + ); + + // [1, 2] + _assertHookEq( + hooks[2], + FunctionReferenceLib.pack(address(mockPlugin), uint8(1)), + FunctionReferenceLib.pack(address(mockPlugin), uint8(2)) + ); + + // [0, null] + _assertHookEq( + hooks[3], + FunctionReferenceLib.pack(address(mockPlugin), uint8(0)), + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE + ); + + _assertHookEq( + hooks[4], + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.PRE_EXECUTION_HOOK) + ), + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + + // [null, 2] + _assertHookEq( + hooks[5], + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE, + FunctionReferenceLib.pack(address(mockPlugin), uint8(2)) + ); + + _assertHookEq( + hooks[6], + FunctionReferenceLib._EMPTY_FUNCTION_REFERENCE, + FunctionReferenceLib.pack( + address(comprehensivePlugin), uint8(ComprehensivePlugin.FunctionId.POST_EXECUTION_HOOK) + ) + ); + } + + function _assertHookEq( + IAccountLoupe.ExecutionHooks memory hook, + FunctionReference preHook, + FunctionReference postHook + ) internal { + assertEq(FunctionReference.unwrap(hook.preExecHook), FunctionReference.unwrap(preHook)); + assertEq(FunctionReference.unwrap(hook.postExecHook), FunctionReference.unwrap(postHook)); + } +} diff --git a/test/account/AccountPermittedCallHooks.t.sol b/test/account/AccountPermittedCallHooks.t.sol new file mode 100644 index 00000000..8fcacc86 --- /dev/null +++ b/test/account/AccountPermittedCallHooks.t.sol @@ -0,0 +1,871 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; +import { + IPlugin, + ManifestExecutionHook, + PluginManifest, + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction +} from "../../src/interfaces/IPlugin.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +/// @dev Unlike execution hooks, permitted call hooks are scoped to the plugin that is executing the call +/// through `executeFromPlugin`. Therefore, different plugins cannot apply overlapping hooks to the same +/// plugin + selector combination. Overlapping hooks in this case can only originate from the same plugin, +/// which is unrealistic but possible. That's what we test here. +contract UpgradeableModularAccountPermittedCallHooksTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + MockPlugin public mockPlugin1; + bytes32 public manifestHash1; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + + bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); + uint8 internal constant _PRE_HOOK_FUNCTION_ID_1 = 1; + uint8 internal constant _POST_HOOK_FUNCTION_ID_2 = 2; + uint8 internal constant _PRE_HOOK_FUNCTION_ID_3 = 3; + uint8 internal constant _POST_HOOK_FUNCTION_ID_4 = 4; + + PluginManifest public m1; + + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + // emitted by MockPlugin + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + + (owner1, owner1Key) = makeAddrAndKey("owner1"); + address impl = address(new UpgradeableModularAccount(IEntryPoint(address(entryPoint)))); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + address[] memory owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + vm.deal(address(account1), 100 ether); + + entryPoint.depositTo{value: 1 wei}(address(account1)); + + m1.executionFunctions.push(_EXEC_SELECTOR); + + m1.runtimeValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: _EXEC_SELECTOR, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + + m1.permittedExecutionSelectors.push(_EXEC_SELECTOR); + } + + /// @dev Plugin hook pair(s): [1, null] + /// Expected execution: [1, null] + function test_prePermittedCallHook_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, null] + /// Expected execution: [1, null] + function test_prePermittedCallHook_run() public { + test_prePermittedCallHook_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, null] + /// Expected execution: [1, null] + function test_prePermittedCallHook_uninstall() public { + test_prePermittedCallHook_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2] + /// Expected execution: [1, 2] + function test_permittedCallHookPair_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2] + /// Expected execution: [1, 2] + function test_permittedCallHookPair_run() public { + test_permittedCallHookPair_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + // pre hook call + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 0 // msg value in call to plugin + ); + vm.expectEmit(true, true, true, true); + // exec call + emit ReceivedCall(abi.encodePacked(_EXEC_SELECTOR), 0); + vm.expectEmit(true, true, true, true); + // post hook call + emit ReceivedCall( + abi.encodeCall( + IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, abi.encode(_PRE_HOOK_FUNCTION_ID_1)) + ), + 0 // msg value in call to plugin + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2] + /// Expected execution: [1, 2] + function test_permittedCallHookPair_uninstall() public { + test_permittedCallHookPair_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyPermittedCallHook_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyPermittedCallHook_run() public { + test_postOnlyPermittedCallHook_install(); + + vm.startPrank(address(mockPlugin1)); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.postExecutionHook, (_POST_HOOK_FUNCTION_ID_2, "")), + 0 // msg value in call to plugin + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2] + /// Expected execution: [null, 2] + function test_postOnlyPermittedCallHook_uninstall() public { + test_postOnlyPermittedCallHook_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, null], [1, null] + /// Expected execution: [1, null] + function test_overlappingPrePermittedCallHooks_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, null], [1, null] + /// Expected execution: [1, null] + function test_overlappingPrePermittedCallHooks_run() public { + test_overlappingPrePermittedCallHooks_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, null], [1, null] + /// Expected execution: [1, null] + function test_overlappingPrePermittedCallHooks_uninstall() public { + test_overlappingPrePermittedCallHooks_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 2] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairs_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 2] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairs_run() public { + test_overlappingPermittedCallHookPairs_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 2] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairs_uninstall() public { + test_overlappingPermittedCallHookPairs_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingPermittedCallHookPairsOnPost_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_3, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingPermittedCallHookPairsOnPost_run() public { + test_overlappingPermittedCallHookPairsOnPost_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect each pre hook to be called once. + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_3, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_3) // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [3, 2] + /// Expected execution: [1, 2], [3, 2] + function test_overlappingPermittedCallHookPairsOnPost_uninstall() public { + test_overlappingPermittedCallHookPairsOnPost_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingPermittedCallHookPairsOnPre_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_4, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingPermittedCallHookPairsOnPre_run() public { + test_overlappingPermittedCallHookPairsOnPre_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called twice, each passing data over to their respective post hooks. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 2 + ); + + // Expect each post hook to be called once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_4, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, 4] + /// Expected execution: [1, 2], [1, 4] + function test_overlappingPermittedCallHookPairsOnPre_uninstall() public { + test_overlappingPermittedCallHookPairsOnPre_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, null] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPost_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, null] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPost_run() public { + test_overlappingPermittedCallHookPairsOnPreWithNullPost_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called just once, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [1, null] + /// Expected execution: [1, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPost_uninstall() public { + test_overlappingPermittedCallHookPairsOnPreWithNullPost_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPre_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _PRE_HOOK_FUNCTION_ID_1, + dependencyIndex: 0 + }), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPre_run() public { + test_overlappingPermittedCallHookPairsOnPreWithNullPre_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the pre hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + _PRE_HOOK_FUNCTION_ID_1, + address(mockPlugin1), // caller + 0, // msg.value in call to account + abi.encodePacked(_EXEC_SELECTOR) + ), + 1 + ); + + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + abi.encode(_PRE_HOOK_FUNCTION_ID_1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [1, 2], [null, 2] + /// Expected execution: [1, 2], [null, 2] + function test_overlappingPermittedCallHookPairsOnPreWithNullPre_uninstall() public { + test_overlappingPermittedCallHookPairsOnPreWithNullPre_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2], [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostPermittedCallHooks_install() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}), + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: _POST_HOOK_FUNCTION_ID_2, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2], [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostPermittedCallHooks_run() public { + test_overlappingPostPermittedCallHooks_install(); + + vm.startPrank(address(mockPlugin1)); + + // Expect the post hook to be called just once. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + _POST_HOOK_FUNCTION_ID_2, + "" // preExecHookData + ), + 1 + ); + + account1.executeFromPlugin(abi.encodePacked(_EXEC_SELECTOR)); + + vm.stopPrank(); + } + + /// @dev Plugin hook pair(s): [null, 2], [null, 2] + /// Expected execution: [null, 2] + function test_overlappingPostPermittedCallHooks_uninstall() public { + test_overlappingPostPermittedCallHooks_install(); + + vm.startPrank(owner1); + + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + function _installPlugin1WithHooks(ManifestFunction memory preHook1, ManifestFunction memory postHook1) + internal + { + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook1, postHook1)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _installPlugin1WithHooks( + ManifestFunction memory preHook1, + ManifestFunction memory postHook1, + ManifestFunction memory preHook2, + ManifestFunction memory postHook2 + ) internal { + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook1, postHook1)); + m1.permittedCallHooks.push(ManifestExecutionHook(_EXEC_SELECTOR, preHook2, postHook2)); + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onInstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _uninstallPlugin(MockPlugin plugin) internal { + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onUninstall, (bytes(""))), 0); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + + account1.uninstallPlugin(address(plugin), bytes(""), bytes(""), new bytes[](0)); + } +} diff --git a/test/account/AccountPreValidationHooks.t.sol b/test/account/AccountPreValidationHooks.t.sol new file mode 100644 index 00000000..1f4488ef --- /dev/null +++ b/test/account/AccountPreValidationHooks.t.sol @@ -0,0 +1,739 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; +import { + IPlugin, + PluginManifest, + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction +} from "../../src/interfaces/IPlugin.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +contract UpgradeableModularAccountPreValidationHooksTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + address payable public beneficiary; + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + MockPlugin public mockPlugin1; + MockPlugin public mockPlugin2; + bytes32 public manifestHash1; + bytes32 public manifestHash2; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + + bytes4 internal constant _EXEC_SELECTOR = bytes4(uint32(1)); + + PluginManifest public m1; + PluginManifest public m2; + + uint256 public constant CALL_GAS_LIMIT = 70000; + uint256 public constant VERIFICATION_GAS_LIMIT = 1000000; + + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + vm.deal(beneficiary, 1 wei); + + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + address[] memory owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + vm.deal(address(account1), 100 ether); + + entryPoint.depositTo{value: 1 wei}(address(account1)); + + m1.executionFunctions.push(_EXEC_SELECTOR); + + m1.runtimeValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: _EXEC_SELECTOR, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preRuntimeValidationHooks_install() public { + vm.startPrank(owner1); + + _installPlugin1WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 1, dependencyIndex: 0}) + ); + + _installPlugin2WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 2, dependencyIndex: 0}), + new FunctionReference[](0) + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preRuntimeValidationHooks_run() public { + test_preRuntimeValidationHooks_install(); + + vm.startPrank(owner1); + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 2, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + function testFuzz_preRuntimeValidationHooks_revert(bytes memory hookRevertReason) public { + vm.startPrank(owner1); + MockPlugin hookPlugin = _installPlugin1WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 1, dependencyIndex: 0}) + ); + + vm.mockCallRevert( + address(hookPlugin), + abi.encodeCall( + IPlugin.preRuntimeValidationHook, (1, owner1, 0, abi.encodeWithSelector(_EXEC_SELECTOR)) + ), + hookRevertReason + ); + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq( + returnData, + abi.encodeWithSelector( + UpgradeableModularAccount.PreRuntimeValidationHookFailed.selector, + address(hookPlugin), + 1, + hookRevertReason + ) + ); + vm.stopPrank(); + } + + function test_preRuntimeValidationHooks_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + ); + + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq(returnData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preRuntimeValidationHooks_uninstall() public { + test_preRuntimeValidationHooks_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect hook 1 to exist, but not hook 2. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 2, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 0 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1] + function test_overlappingPreRuntimeValidationHook_install() public { + vm.startPrank(owner1); + + _installPlugin1WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 1, dependencyIndex: 0}) + ); + + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), 1); + _installPlugin2WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1] + function test_overlappingPreRuntimeValidationHooks_run() public { + test_overlappingPreRuntimeValidationHook_install(); + + vm.startPrank(owner1); + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1] + function test_overlappingPreRuntimeValidationHook_uninstall() public { + test_overlappingPreRuntimeValidationHook_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + // Expect the hook to still exist after uninstalling a plugin with a duplicate hook. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector( + IPlugin.preRuntimeValidationHook.selector, + 1, + owner1, // caller + 0, // msg.value in call to account + abi.encodeWithSelector(_EXEC_SELECTOR) + ), + 1 + ); + + (bool success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertTrue(success); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + // Execution selector should no longer exist. + (success,) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preUserOpValidationHooks_install() public { + vm.startPrank(owner1); + + _installPlugin1WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 1, dependencyIndex: 0}) + ); + + _installPlugin2WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 2, dependencyIndex: 0}), + new FunctionReference[](0) + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preUserOpValidationHooks_run() public { + test_preUserOpValidationHooks_install(); + + vm.startPrank(owner1); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 1, userOp, userOpHash), + 1 + ); + + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 2, userOp, userOpHash), + 1 + ); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 2 + /// Expected execution: [1, 2] + function test_preUserOpValidationHooks_uninstall() public { + test_preUserOpValidationHooks_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + // Expect hook 1 to exist, but not hook 2. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 1, userOp, userOpHash), + 1 + ); + vm.expectCall( + address(mockPlugin2), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 2, userOp, userOpHash), + 0 + ); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + function test_preUserOpValidationHooks_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert(UpgradeableModularAccount.AlwaysDenyRule.selector); + account1.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1, 1] + function test_overlappingPreUserOpValidationHooks_install() public { + vm.startPrank(owner1); + + _installPlugin1WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({functionType: ManifestAssociatedFunctionType.SELF, functionId: 1, dependencyIndex: 0}) + ); + + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack(address(mockPlugin1), 1); + _installPlugin2WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }), + dependencies + ); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1] + function test_overlappingPreUserOpValidationHooks_run() public { + test_overlappingPreUserOpValidationHooks_install(); + + vm.startPrank(owner1); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 1, userOp, userOpHash), + 1 + ); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + vm.stopPrank(); + } + + /// @dev Plugin 1 hook: 1 + /// Plugin 2 hook: 1 + /// Expected execution: [1] + function test_overlappingPreUserOpValidationHooks_uninstall() public { + test_overlappingPreUserOpValidationHooks_install(); + + vm.startPrank(owner1); + + // Uninstall the second plugin. + _uninstallPlugin(mockPlugin2); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + // Expect hook 1 to still exist. + vm.expectCall( + address(mockPlugin1), + abi.encodeWithSelector(IPlugin.preUserOpValidationHook.selector, 1, userOp, userOpHash), + 1 + ); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + // Uninstall the first plugin. + _uninstallPlugin(mockPlugin1); + + vm.stopPrank(); + } + + function _installPlugin1WithPreRuntimeValidationHook(bytes4 selector, ManifestFunction memory hook) + internal + returns (MockPlugin) + { + m1.preRuntimeValidationHooks.push(ManifestAssociatedFunction(selector, hook)); + + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectCall(address(mockPlugin1), abi.encodeCall(IPlugin.onInstall, ("")), 1); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + return mockPlugin1; + } + + function _installPlugin2WithPreRuntimeValidationHook( + bytes4 selector, + ManifestFunction memory hook, + FunctionReference[] memory dependencies + ) internal { + if (hook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } + m2.preRuntimeValidationHooks.push(ManifestAssociatedFunction(selector, hook)); + + mockPlugin2 = new MockPlugin(m2); + manifestHash2 = keccak256(abi.encode(mockPlugin2.pluginManifest())); + + vm.expectCall(address(mockPlugin2), abi.encodeCall(IPlugin.onInstall, ("")), 1); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin2), manifestHash2, dependencies, new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin2), + manifestHash: manifestHash2, + pluginInitData: bytes(""), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _installPlugin1WithPreUserOpValidationHook(bytes4 selector, ManifestFunction memory hook) internal { + // Set up the user op validation function first. + m1.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + FunctionReference[] memory dependencies = new FunctionReference[](1); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + m1.userOpValidationFunctions.push( + ManifestAssociatedFunction({ + executionSelector: selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, + dependencyIndex: 0 + }) + }) + ); + + m1.preUserOpValidationHooks.push(ManifestAssociatedFunction(selector, hook)); + + mockPlugin1 = new MockPlugin(m1); + manifestHash1 = keccak256(abi.encode(mockPlugin1.pluginManifest())); + + vm.expectCall(address(mockPlugin1), abi.encodeCall(IPlugin.onInstall, ("")), 1); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin1), manifestHash1, dependencies, new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin1), + manifestHash: manifestHash1, + pluginInitData: bytes(""), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _installPlugin2WithPreUserOpValidationHook( + bytes4 selector, + ManifestFunction memory hook, + FunctionReference[] memory dependencies + ) internal { + if (hook.functionType == ManifestAssociatedFunctionType.DEPENDENCY) { + m2.dependencyInterfaceIds.push(type(IPlugin).interfaceId); + } + m2.preUserOpValidationHooks.push(ManifestAssociatedFunction(selector, hook)); + + mockPlugin2 = new MockPlugin(m2); + manifestHash2 = keccak256(abi.encode(mockPlugin2.pluginManifest())); + + vm.expectCall(address(mockPlugin2), abi.encodeCall(IPlugin.onInstall, ("")), 1); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(mockPlugin2), manifestHash2, dependencies, new IPluginManager.InjectedHook[](0) + ); + + account1.installPlugin({ + plugin: address(mockPlugin2), + manifestHash: manifestHash2, + pluginInitData: bytes(""), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function _uninstallPlugin(MockPlugin plugin) internal { + vm.expectCall(address(plugin), abi.encodeCall(IPlugin.onUninstall, ("")), 1); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + + account1.uninstallPlugin(address(plugin), bytes(""), bytes(""), new bytes[](0)); + } +} diff --git a/test/account/AccountReturnData.t.sol b/test/account/AccountReturnData.t.sol new file mode 100644 index 00000000..4bcd41b8 --- /dev/null +++ b/test/account/AccountReturnData.t.sol @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; +import {Call} from "../../src/interfaces/IStandardExecutor.sol"; + +import { + RegularResultContract, + ResultCreatorPlugin, + ResultConsumerPlugin +} from "../mocks/plugins/ReturnDataPluginMocks.sol"; +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; + +// Tests all the different ways that return data can be read from plugins through an account +contract AccountReturnDataTest is Test { + IEntryPoint public entryPoint; // Just to be able to construct the factory + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + + RegularResultContract public regularResultContract; + ResultCreatorPlugin public resultCreatorPlugin; + ResultConsumerPlugin public resultConsumerPlugin; + + UpgradeableModularAccount public account; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + regularResultContract = new RegularResultContract(); + resultCreatorPlugin = new ResultCreatorPlugin(); + resultConsumerPlugin = new ResultConsumerPlugin(resultCreatorPlugin, regularResultContract); + + // Create an account with "this" as the owner, so we can execute along the runtime path with regular + // solidity semantics + address[] memory owners = new address[](1); + owners[0] = address(this); + account = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + + // Add the result creator plugin to the account + bytes32 resultCreatorManifestHash = keccak256(abi.encode(resultCreatorPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(resultCreatorPlugin), + manifestHash: resultCreatorManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + // Add the result consumer plugin to the account + bytes32 resultConsumerManifestHash = keccak256(abi.encode(resultConsumerPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(resultConsumerPlugin), + manifestHash: resultConsumerManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests the ability to read the result of plugin execution functions via the account's fallback + function test_returnData_fallback() public { + bytes32 result = ResultCreatorPlugin(address(account)).foo(); + + assertEq(result, keccak256("bar")); + } + + // Tests the ability to read the results of contracts called via IStandardExecutor.execute + function test_returnData_singular_execute() public { + bytes memory returnData = + account.execute(address(regularResultContract), 0, abi.encodeCall(RegularResultContract.foo, ())); + + bytes32 result = abi.decode(returnData, (bytes32)); + + assertEq(result, keccak256("bar")); + } + + // Tests the ability to read the results of multiple contract calls via IStandardExecutor.executeBatch + function test_returnData_executeBatch() public { + Call[] memory calls = new Call[](2); + calls[0] = Call({ + target: address(regularResultContract), + value: 0, + data: abi.encodeCall(RegularResultContract.foo, ()) + }); + calls[1] = Call({ + target: address(regularResultContract), + value: 0, + data: abi.encodeCall(RegularResultContract.bar, ()) + }); + + bytes[] memory returnDatas = account.executeBatch(calls); + + bytes32 result1 = abi.decode(returnDatas[0], (bytes32)); + bytes32 result2 = abi.decode(returnDatas[1], (bytes32)); + + assertEq(result1, keccak256("bar")); + assertEq(result2, keccak256("foo")); + } + + // Tests the ability to read data via executeFromPlugin routing to fallback functions + function test_returnData_execFromPlugin_fallback() public { + bool result = ResultConsumerPlugin(address(account)).checkResultEFPFallback(keccak256("bar")); + + assertTrue(result); + } + + // Tests the ability to read data via executeFromPluginExternal + function test_returnData_execFromPlugin_execute() public { + bool result = ResultConsumerPlugin(address(account)).checkResultEFPExternal( + address(regularResultContract), keccak256("bar") + ); + + assertTrue(result); + } +} diff --git a/test/account/ExecuteFromPluginPermissions.t.sol b/test/account/ExecuteFromPluginPermissions.t.sol new file mode 100644 index 00000000..0862a93b --- /dev/null +++ b/test/account/ExecuteFromPluginPermissions.t.sol @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPlugin} from "../../src/interfaces/IPlugin.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; + +import {Counter} from "../mocks/Counter.sol"; +import {ResultCreatorPlugin} from "../mocks/plugins/ReturnDataPluginMocks.sol"; + +import { + EFPCallerPlugin, + EFPCallerPluginAnyExternal, + EFPCallerPluginAnyExternalCanSpendNativeToken, + EFPExecutionHookPlugin, + EFPExternalPermittedCallHookPlugin, + EFPPermittedCallHookPlugin +} from "../mocks/plugins/ExecFromPluginPermissionsMocks.sol"; + +contract ExecuteFromPluginPermissionsTest is Test { + Counter public counter1; + Counter public counter2; + Counter public counter3; + ResultCreatorPlugin public resultCreatorPlugin; + + IEntryPoint public entryPoint; // Just to be able to construct the factory + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + UpgradeableModularAccount public account; + + EFPCallerPlugin public efpCallerPlugin; + EFPCallerPluginAnyExternal public efpCallerPluginAnyExternal; + EFPCallerPluginAnyExternalCanSpendNativeToken public efpCallerPluginAnyExternalCanSpendNativeToken; + EFPPermittedCallHookPlugin public efpPermittedCallHookPlugin; + EFPExternalPermittedCallHookPlugin public efpExternalPermittedCallHookPlugin; + EFPExecutionHookPlugin public efpExecutionHookPlugin; + + function setUp() public { + // Initialize the interaction targets + counter1 = new Counter(); + counter2 = new Counter(); + counter3 = new Counter(); + resultCreatorPlugin = new ResultCreatorPlugin(); + + // Initialize the contracts needed to use the account. + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + // Initialize the EFP caller plugins, which will attempt to use the permissions system to authorize calls. + efpCallerPlugin = new EFPCallerPlugin(); + efpCallerPluginAnyExternal = new EFPCallerPluginAnyExternal(); + efpCallerPluginAnyExternalCanSpendNativeToken = new EFPCallerPluginAnyExternalCanSpendNativeToken(); + efpPermittedCallHookPlugin = new EFPPermittedCallHookPlugin(); + efpExternalPermittedCallHookPlugin = new EFPExternalPermittedCallHookPlugin(); + efpExecutionHookPlugin = new EFPExecutionHookPlugin(); + + // Create an account with "this" as the owner, so we can execute along the runtime path with regular + // solidity semantics + address[] memory owners = new address[](1); + owners[0] = address(this); + account = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + + // Add the result creator plugin to the account + bytes32 resultCreatorManifestHash = keccak256(abi.encode(resultCreatorPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(resultCreatorPlugin), + manifestHash: resultCreatorManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + // Add the EFP caller plugin to the account + bytes32 efpCallerManifestHash = keccak256(abi.encode(efpCallerPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(efpCallerPlugin), + manifestHash: efpCallerManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Add the EFP caller plugin with any external permissions to the account + bytes32 efpCallerAnyExternalManifestHash = + keccak256(abi.encode(efpCallerPluginAnyExternal.pluginManifest())); + account.installPlugin({ + plugin: address(efpCallerPluginAnyExternal), + manifestHash: efpCallerAnyExternalManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Add the EFP caller plugin with any external permissions and native token spend permission to the account + bytes32 efpCallerAnyExternalCanSpendNativeTokenManifestHash = + keccak256(abi.encode(efpCallerPluginAnyExternalCanSpendNativeToken.pluginManifest())); + account.installPlugin({ + plugin: address(efpCallerPluginAnyExternalCanSpendNativeToken), + manifestHash: efpCallerAnyExternalCanSpendNativeTokenManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Add the EFP caller plugin with permitted call hooks to the account + bytes32 efpPermittedCallHookManifestHash = + keccak256(abi.encode(efpPermittedCallHookPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(efpPermittedCallHookPlugin), + manifestHash: efpPermittedCallHookManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Add the EFP caller plugin with an external permitted call hook to the account + bytes32 efpExternalPermittedCallHookManifestHash = + keccak256(abi.encode(efpExternalPermittedCallHookPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(efpExternalPermittedCallHookPlugin), + manifestHash: efpExternalPermittedCallHookManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Add the EFP caller plugin with execution hooks to the account + bytes32 efpExecutionHookPluginManifestHash = keccak256(abi.encode(efpExecutionHookPlugin.pluginManifest())); + account.installPlugin({ + plugin: address(efpExecutionHookPlugin), + manifestHash: efpExecutionHookPluginManifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Report the addresses to be used in the address constants in ExecFromPluginPermissionsMocks.sol + function test_getPermissionsTestAddresses() public view { + // solhint-disable no-console + console.log("counter1 address: %s", address(counter1)); + console.log("counter2 address: %s", address(counter2)); + console.log("counter3 address: %s", address(counter3)); + console.log("resultCreatorPlugin address: %s", address(resultCreatorPlugin)); + // solhint-enable no-console + } + + function test_executeFromPluginAllowed() public { + bytes memory result = EFPCallerPlugin(address(account)).useEFPPermissionAllowed(); + bytes32 actual = abi.decode(result, (bytes32)); + + assertEq(actual, keccak256("bar")); + } + + function test_executeFromPluginNotAllowed() public { + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginNotPermitted.selector, + address(efpCallerPlugin), + ResultCreatorPlugin.bar.selector + ) + ); + EFPCallerPlugin(address(account)).useEFPPermissionNotAllowed(); + } + + function test_executeFromPluginUnrecognizedFunction() public { + // Permitted but uninstalled selector + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.UnrecognizedFunction.selector, bytes4(keccak256("baz()")) + ) + ); + EFPCallerPlugin(address(account)).passthroughExecuteFromPlugin( + abi.encodeWithSelector(bytes4(keccak256("baz()"))) + ); + + // Invalid selector < 4 bytes + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.UnrecognizedFunction.selector, bytes4(hex"11")) + ); + EFPCallerPlugin(address(account)).passthroughExecuteFromPlugin(hex"11"); + } + + function test_executeFromPluginExternal_Allowed_IndividualSelectors() public { + EFPCallerPlugin(address(account)).setNumberCounter1(17); + uint256 retrievedNumber = EFPCallerPlugin(address(account)).getNumberCounter1(); + + assertEq(retrievedNumber, 17); + } + + function test_executeFromPluginExternal_NotAlowed_IndividualSelectors() public { + EFPCallerPlugin(address(account)).setNumberCounter1(17); + + // Call to increment should fail + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginExternalNotPermitted.selector, + address(efpCallerPlugin), + address(counter1), + 0, + abi.encodePacked(Counter.increment.selector) + ) + ); + EFPCallerPlugin(address(account)).incrementCounter1(); + + uint256 retrievedNumber = EFPCallerPlugin(address(account)).getNumberCounter1(); + + assertEq(retrievedNumber, 17); + } + + function test_executeFromPluginExternal_Allowed_AllSelectors() public { + EFPCallerPlugin(address(account)).setNumberCounter2(17); + EFPCallerPlugin(address(account)).incrementCounter2(); + uint256 retrievedNumber = EFPCallerPlugin(address(account)).getNumberCounter2(); + + assertEq(retrievedNumber, 18); + } + + function test_executeFromPluginExternal_NotAllowed_AllSelectors() public { + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginExternalNotPermitted.selector, + address(efpCallerPlugin), + address(counter3), + 0, + abi.encodeWithSelector(Counter.setNumber.selector, uint256(17)) + ) + ); + EFPCallerPlugin(address(account)).setNumberCounter3(17); + + // Call to increment should fail + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginExternalNotPermitted.selector, + address(efpCallerPlugin), + address(counter3), + 0, + abi.encodePacked(Counter.increment.selector) + ) + ); + EFPCallerPlugin(address(account)).incrementCounter3(); + + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginExternalNotPermitted.selector, + address(efpCallerPlugin), + address(counter3), + 0, + abi.encodePacked(bytes4(keccak256("number()"))) + ) + ); + EFPCallerPlugin(address(account)).getNumberCounter3(); + + // Validate no state changes + assert(counter3.number() == 0); + } + + function test_executeFromPluginExternal_Allowed_AnyContract() public { + // Run full workflow for counter 1 + + EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter1), 0, abi.encodeCall(Counter.setNumber, (17)) + ); + uint256 retrievedNumber = counter1.number(); + assertEq(retrievedNumber, 17); + + EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter1), 0, abi.encodeCall(Counter.increment, ()) + ); + retrievedNumber = counter1.number(); + assertEq(retrievedNumber, 18); + + bytes memory result = EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter1), 0, abi.encodePacked(bytes4(keccak256("number()"))) + ); + retrievedNumber = abi.decode(result, (uint256)); + assertEq(retrievedNumber, 18); + + // Run full workflow for counter 2 + + EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter2), 0, abi.encodeCall(Counter.setNumber, (17)) + ); + retrievedNumber = counter2.number(); + assertEq(retrievedNumber, 17); + + EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter2), 0, abi.encodeCall(Counter.increment, ()) + ); + retrievedNumber = counter2.number(); + assertEq(retrievedNumber, 18); + + result = EFPCallerPluginAnyExternal(address(account)).passthroughExecute( + address(counter2), 0, abi.encodePacked(bytes4(keccak256("number()"))) + ); + retrievedNumber = abi.decode(result, (uint256)); + assertEq(retrievedNumber, 18); + } + + function test_executeFromPluginExternal_NotAllowed_NativeTokenSpending() public { + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.NativeTokenSpendingNotPermitted.selector, + address(efpCallerPluginAnyExternal) + ) + ); + EFPCallerPluginAnyExternal(address(account)).passthroughExecute(address(counter1), 1 ether, ""); + + address recipient = makeAddr("recipient"); + vm.deal(address(efpCallerPluginAnyExternal), 1 ether); + // This function forwards 1 eth from the plugin to the account and tries to send 2 eth to the recipient. + // This is not allowed because there would be a net decrease of the balance on the account. + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.NativeTokenSpendingNotPermitted.selector, + address(efpCallerPluginAnyExternal) + ) + ); + EFPCallerPluginAnyExternal(address(account)).passthroughExecuteWith1Eth(address(recipient), 2 ether, ""); + } + + function test_executeFromPluginExternal_Allowed_NativeTokenSpending() public { + address recipient = makeAddr("recipient"); + + vm.deal(address(efpCallerPluginAnyExternal), 1 ether); + assertEq(address(recipient).balance, 0); + // This function forwards 1 eth from the plugin to the account and sends 1 eth to the recipient. This is + // allowed because there is no net change to the balance on the account. + EFPCallerPluginAnyExternal(address(account)).passthroughExecuteWith1Eth(address(recipient), 1 ether, ""); + assertEq(address(efpCallerPluginAnyExternal).balance, 0); + assertEq(address(recipient).balance, 1 ether); + + vm.deal(address(account), 1 ether); + EFPCallerPluginAnyExternalCanSpendNativeToken(address(account)) + .passthroughExecuteWithNativeTokenSpendPermission(address(recipient), 1 ether, ""); + assertEq(address(recipient).balance, 2 ether); + } + + function test_executeFromPlugin_PermittedCallHooks() public { + assertFalse(efpPermittedCallHookPlugin.preExecHookCalled()); + assertFalse(efpPermittedCallHookPlugin.postExecHookCalled()); + + bytes memory result = EFPPermittedCallHookPlugin(address(account)).performEFPCall(); + + bytes32 actual = abi.decode(result, (bytes32)); + + assertEq(actual, keccak256("bar")); + + assertTrue(efpPermittedCallHookPlugin.preExecHookCalled()); + assertTrue(efpPermittedCallHookPlugin.postExecHookCalled()); + } + + function test_executeFromPluginExternal_PermittedCallHooks() public { + counter1.setNumber(17); + + assertFalse(efpExternalPermittedCallHookPlugin.preExecHookCalled()); + assertFalse(efpExternalPermittedCallHookPlugin.postExecHookCalled()); + + EFPExternalPermittedCallHookPlugin(address(account)).performIncrement(); + + assertTrue(efpExternalPermittedCallHookPlugin.preExecHookCalled()); + assertTrue(efpExternalPermittedCallHookPlugin.postExecHookCalled()); + + uint256 retrievedNumber = counter1.number(); + assertEq(retrievedNumber, 18); + } + + function test_executeFromPlugin_ExecutionHooks() public { + // Expect the pre hook to be called just once. + vm.expectCall( + address(efpExecutionHookPlugin), + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + 1, + address(efpExecutionHookPlugin), // caller + 0, // msg.value in call to account + abi.encodeWithSelector(ResultCreatorPlugin.foo.selector) + ), + 1 + ); + // Expect the post hook to be called twice, with the expected data. + vm.expectCall( + address(efpExecutionHookPlugin), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + 2, + abi.encode(1) // preExecHookData + ), + 1 + ); + vm.expectCall( + address(efpExecutionHookPlugin), + abi.encodeWithSelector( + IPlugin.postExecutionHook.selector, + 2, + "" // preExecHookData (none for this post only hook) + ), + 1 + ); + EFPExecutionHookPlugin(address(account)).performEFPCallWithExecHooks(); + } +} diff --git a/test/account/ManifestValidity.t.sol b/test/account/ManifestValidity.t.sol new file mode 100644 index 00000000..8ff2c516 --- /dev/null +++ b/test/account/ManifestValidity.t.sol @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import { + BadValidationMagicValue_UserOp_Plugin, + BadValidationMagicValue_PreRuntimeValidationHook_Plugin, + BadValidationMagicValue_PreUserOpValidationHook_Plugin, + BadValidationMagicValue_PreExecHook_Plugin, + BadValidationMagicValue_PostExecHook_Plugin, + BadHookMagicValue_UserOpValidationFunction_Plugin, + BadHookMagicValue_RuntimeValidationFunction_Plugin, + BadHookMagicValue_PostExecHook_Plugin +} from "../mocks/plugins/ManifestValidityMocks.sol"; + +contract ManifestValidityTest is Test { + IEntryPoint public entryPoint; // Just to be able to construct the factory + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + + UpgradeableModularAccount public account; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + // Create an account with "this" as the owner, so we can execute along the runtime path with regular + // solidity semantics + address[] memory owners = new address[](1); + owners[0] = address(this); + account = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + } + + // Tests that the plugin manager rejects a plugin with a user op validationFunction set to "validation always + // allow" + function test_ManifestValidity_invalid_ValidationAlwaysAllow_UserOpValidationFunction() public { + BadValidationMagicValue_UserOp_Plugin plugin = new BadValidationMagicValue_UserOp_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a pre-runtime validation hook set to "validation always + // allow" + function test_ManifestValidity_invalid_ValidationAlwaysAllow_PreRuntimeValidationHook() public { + BadValidationMagicValue_PreRuntimeValidationHook_Plugin plugin = + new BadValidationMagicValue_PreRuntimeValidationHook_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a pre-user op validation hook set to "validation always + // allow" + function test_ManifestValidity_invalid_ValidationAlwaysAllow_PreUserOpValidationHook() public { + BadValidationMagicValue_PreUserOpValidationHook_Plugin plugin = + new BadValidationMagicValue_PreUserOpValidationHook_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a pre-execution hook set to "validation always allow" + function test_ManifestValidity_invalid_ValidationAlwaysAllow_PreExecHook() public { + BadValidationMagicValue_PreExecHook_Plugin plugin = new BadValidationMagicValue_PreExecHook_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a post-execution hook set to "validation always allow" + function test_ManifestValidity_invalid_ValidationAlwaysAllow_PostExecHook() public { + BadValidationMagicValue_PostExecHook_Plugin plugin = new BadValidationMagicValue_PostExecHook_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a user op validationFunction set to "hook always deny" + function test_ManifestValidity_invalid_HookAlwaysDeny_UserOpValidation() public { + BadHookMagicValue_UserOpValidationFunction_Plugin plugin = + new BadHookMagicValue_UserOpValidationFunction_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a runtime validationFunction set to "hook always deny" + function test_ManifestValidity_invalid_HookAlwaysDeny_RuntimeValidationFunction() public { + BadHookMagicValue_RuntimeValidationFunction_Plugin plugin = + new BadHookMagicValue_RuntimeValidationFunction_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Tests that the plugin manager rejects a plugin with a post-execution hook set to "hook always deny" + function test_ManifestValidity_invalid_HookAlwaysDeny_PostExecHook() public { + BadHookMagicValue_PostExecHook_Plugin plugin = new BadHookMagicValue_PostExecHook_Plugin(); + + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + account.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } +} diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol new file mode 100644 index 00000000..799aa4cd --- /dev/null +++ b/test/account/UpgradeableModularAccount.t.sol @@ -0,0 +1,532 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {AccountExecutor} from "../../src/account/AccountExecutor.sol"; +import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {SessionKeyPlugin} from "../../src/plugins/session/SessionKeyPlugin.sol"; +import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../src/interfaces/erc4337/UserOperation.sol"; +import {IAccountInitializable} from "../../src/interfaces/IAccountInitializable.sol"; +import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {Counter} from "../mocks/Counter.sol"; +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +contract UpgradeableModularAccountTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + address payable public beneficiary; + MultiOwnerPlugin public multiOwnerPlugin; + TokenReceiverPlugin public tokenReceiverPlugin; + SessionKeyPlugin public sessionKeyPlugin; + MultiOwnerMSCAFactory public factory; + address public accountImplementation; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + + address public owner2; + uint256 public owner2Key; + UpgradeableModularAccount public account2; + + address[] public owners1; + address[] public owners2; + + address public ethRecipient; + Counter public counter; + PluginManifest public manifest; + IPluginManager.InjectedHooksInfo public injectedHooksInfo = IPluginManager.InjectedHooksInfo({ + preExecHookFunctionId: 2, + isPostHookUsed: true, + postExecHookFunctionId: 3 + }); + + uint256 public constant CALL_GAS_LIMIT = 500000; + uint256 public constant VERIFICATION_GAS_LIMIT = 2000000; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + vm.deal(beneficiary, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + tokenReceiverPlugin = new TokenReceiverPlugin(); + sessionKeyPlugin = new SessionKeyPlugin(); + accountImplementation = address(new UpgradeableModularAccount(entryPoint)); + bytes32 manifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + factory = new MultiOwnerMSCAFactory( + address(this), address(multiOwnerPlugin), accountImplementation, manifestHash, entryPoint + ); + + // Compute counterfactual address + owners1 = new address[](1); + owners1[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.getAddress(0, owners1))); + vm.deal(address(account1), 100 ether); + + // Pre-deploy account two for different gas estimates + (owner2, owner2Key) = makeAddrAndKey("owner2"); + owners2 = new address[](1); + owners2[0] = owner2; + account2 = UpgradeableModularAccount(payable(factory.createAccount(0, owners2))); + vm.deal(address(account2), 100 ether); + + ethRecipient = makeAddr("ethRecipient"); + vm.deal(ethRecipient, 1 wei); + counter = new Counter(); + counter.increment(); // amoritze away gas cost of zero->nonzero transition + } + + function test_deployAccount() public { + factory.createAccount(0, owners1); + } + + function test_initialize_revertArrayLengthMismatch() public { + ERC1967Proxy account = new ERC1967Proxy{salt : ""}(accountImplementation, ""); + address[] memory plugins = new address[](2); + bytes memory pluginInitData = abi.encode(new bytes32[](1), new bytes[](1)); + vm.expectRevert(PluginManagerInternals.ArrayLengthMismatch.selector); + IAccountInitializable(address(account)).initialize(plugins, pluginInitData); + + pluginInitData = abi.encode(new bytes32[](2), new bytes[](1)); + vm.expectRevert(PluginManagerInternals.ArrayLengthMismatch.selector); + IAccountInitializable(address(account)).initialize(plugins, pluginInitData); + } + + function test_basicUserOp() public { + address[] memory owners = new address[](1); + owners[0] = owner2; + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (0, owners1))), + callData: abi.encodeCall(MultiOwnerPlugin.updateOwners, (owners, new address[](0))), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_standardExecuteEthSend() public { + address payable recipient = payable(makeAddr("recipient")); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (0, owners1))), + callData: abi.encodeCall(UpgradeableModularAccount(payable(account1)).execute, (recipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(recipient.balance, 1 wei); + } + + function test_postDeploy_ethSend() public { + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(ethRecipient.balance, 2 wei); + } + + function test_debug_upgradeableModularAccount_storageAccesses() public { + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + vm.record(); + entryPoint.handleOps(userOps, beneficiary); + _printStorageReadsAndWrites(address(account2)); + } + + function test_contractInteraction() public { + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + UpgradeableModularAccount.execute, (address(counter), 0, abi.encodeCall(counter.increment, ())) + ), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(counter.number(), 2); + } + + function test_batchExecute() public { + // Performs both an eth send and a contract interaction with counter + Call[] memory calls = new Call[](2); + calls[0] = Call({target: ethRecipient, value: 1 wei, data: ""}); + calls[1] = Call({target: address(counter), value: 0, data: abi.encodeCall(counter.increment, ())}); + + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount(payable(account2)).executeBatch, (calls)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(counter.number(), 2); + assertEq(ethRecipient.balance, 2 wei); + } + + // runtime validation tests + function test_runtime_standardExecuteEthSend() public { + factory.createAccount(0, owners1); + address payable recipient = payable(makeAddr("recipient")); + uint256 balBefore = recipient.balance; + + vm.startPrank(owner1); + UpgradeableModularAccount(payable(account1)).execute(recipient, 1 wei, ""); + assertEq(recipient.balance, balBefore + 1 wei); + } + + function test_runtime_debug_upgradeableModularAccount_storageAccesses() public { + vm.startPrank(owner2); + UpgradeableModularAccount(payable(account2)).execute(ethRecipient, 1 wei, ""); + _printStorageReadsAndWrites(address(account2)); + } + + function test_runtime_contractInteraction() public { + factory.createAccount(0, owners1); + uint256 valueBefore = counter.number(); + + vm.startPrank(owner1); + UpgradeableModularAccount(payable(account1)).execute( + address(counter), 0, abi.encodeCall(counter.increment, ()) + ); + assertEq(counter.number(), valueBefore + 1); + } + + function test_runtime_revertPluginCall() public { + factory.createAccount(0, owners1); + + vm.startPrank(owner1); + + vm.expectRevert( + abi.encodeWithSelector(AccountExecutor.PluginCallDenied.selector, address(multiOwnerPlugin)) + ); + UpgradeableModularAccount(payable(account1)).execute( + address(multiOwnerPlugin), 0, abi.encodeCall(MultiOwnerPlugin.ownersOf, (address(account1))) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: address(multiOwnerPlugin), value: 1 wei, data: ""}); + vm.expectRevert( + abi.encodeWithSelector(AccountExecutor.PluginCallDenied.selector, address(multiOwnerPlugin)) + ); + UpgradeableModularAccount(payable(account1)).executeBatch(calls); + } + + function test_runtime_batchExecute() public { + // Performs both an eth send and a contract interaction with counter + Call[] memory calls = new Call[](2); + calls[0] = Call({target: ethRecipient, value: 1 wei, data: ""}); + calls[1] = Call({target: address(counter), value: 0, data: abi.encodeCall(counter.increment, ())}); + uint256 balBefore = ethRecipient.balance; + + vm.startPrank(owner2); + UpgradeableModularAccount(payable(account2)).executeBatch(calls); + assertEq(counter.number(), 2); + assertEq(ethRecipient.balance, balBefore + 1 wei); + } + + function testFuzz_runtime_revert(bytes memory revertReason) public { + vm.startPrank(owner2); + + bytes memory callData = abi.encodeCall(UpgradeableModularAccount.execute, (beneficiary, 0 wei, "")); + + vm.mockCallRevert( + address(multiOwnerPlugin), + abi.encodeCall( + IPlugin.runtimeValidationFunction, + (uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF), owner2, 0, callData) + ), + revertReason + ); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + (address(multiOwnerPlugin)), + uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF), + revertReason + ) + ); + account2.execute(beneficiary, 0 wei, ""); + } + + function test_view_entryPoint() public { + factory.createAccount(0, owners1); + + assertEq(address(UpgradeableModularAccount(payable(account1)).entryPoint()), address(entryPoint)); + } + + function test_view_getNonce() public { + factory.createAccount(0, owners1); + + assertEq(UpgradeableModularAccount(payable(account1)).getNonce(), 0); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(UpgradeableModularAccount(payable(account1)).getNonce(), 1); + } + + function test_validateUserOp_revertNotFromEntryPoint() public { + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.expectRevert(UpgradeableModularAccount.UserOpNotFromEntryPoint.selector); + account2.validateUserOp(userOp, userOpHash, 0); + } + + function test_validateUserOp_revertUnrecognizedFunction() public { + // Invalid calldata of length 2. + bytes memory callData = hex"12"; + + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: callData, + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.UnrecognizedFunction.selector, bytes4(callData)) + ); + account2.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + + function test_validateUserOp_revertFunctionMissing() public { + PluginManifest memory m; + m.executionFunctions = new bytes4[](1); + bytes4 fooSelector = bytes4(keccak256("foo()")); + m.executionFunctions[0] = fooSelector; + MockPlugin plugin = new MockPlugin(m); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.startPrank(owner2); + // Install a plugin with execution function foo() that does not have an associated user op validation + // function. + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + vm.stopPrank(); + + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(fooSelector), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, fooSelector) + ); + account2.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + + // Internal Functions + + function _printStorageReadsAndWrites(address addr) internal { + (bytes32[] memory accountReads, bytes32[] memory accountWrites) = vm.accesses(addr); + for (uint256 i = 0; i < accountWrites.length; i++) { + bytes32 valWritten = vm.load(addr, accountWrites[i]); + // solhint-disable-next-line no-console + console.log( + string.concat("write loc: ", vm.toString(accountWrites[i]), " val: ", vm.toString(valWritten)) + ); + } + + for (uint256 i = 0; i < accountReads.length; i++) { + bytes32 valRead = vm.load(addr, accountReads[i]); + // solhint-disable-next-line no-console + console.log(string.concat("read: ", vm.toString(accountReads[i]), " val: ", vm.toString(valRead))); + } + } +} diff --git a/test/account/UpgradeableModularAccountPluginManager.t.sol b/test/account/UpgradeableModularAccountPluginManager.t.sol new file mode 100644 index 00000000..95eacb5f --- /dev/null +++ b/test/account/UpgradeableModularAccountPluginManager.t.sol @@ -0,0 +1,963 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {IPaymaster} from "@eth-infinitism/account-abstraction/interfaces/IPaymaster.sol"; + +import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {SessionKeyPlugin} from "../../src/plugins/session/SessionKeyPlugin.sol"; +import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {PluginManifest} from "../../src/interfaces/IPlugin.sol"; +import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {IPluginExecutor} from "../../src/interfaces/IPluginExecutor.sol"; +import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; +import {Call} from "../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; +import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; + +import {Counter} from "../mocks/Counter.sol"; +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import { + CanChangeManifestPluginFactory, CanChangeManifestPlugin +} from "../mocks/plugins/ChangingManifestPlugin.sol"; +import {ComprehensivePlugin} from "../mocks/plugins/ComprehensivePlugin.sol"; +import {UninstallErrorsPlugin} from "../mocks/plugins/UninstallErrorsPlugin.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; + +contract UpgradeableModularAccountPluginManagerTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + address payable public beneficiary; + MultiOwnerPlugin public multiOwnerPlugin; + TokenReceiverPlugin public tokenReceiverPlugin; + SessionKeyPlugin public sessionKeyPlugin; + MultiOwnerMSCAFactory public factory; + address public implementation; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + + address public owner2; + uint256 public owner2Key; + UpgradeableModularAccount public account2; + + address[] public owners1; + address[] public owners2; + + address public ethRecipient; + Counter public counter; + PluginManifest public manifest; + IPluginManager.InjectedHooksInfo public injectedHooksInfo = IPluginManager.InjectedHooksInfo({ + preExecHookFunctionId: 2, + isPostHookUsed: true, + postExecHookFunctionId: 3 + }); + + uint256 public constant CALL_GAS_LIMIT = 500000; + uint256 public constant VERIFICATION_GAS_LIMIT = 2000000; + + event PluginInstalled( + address indexed plugin, + bytes32 manifestHash, + FunctionReference[] dependencies, + IPluginManager.InjectedHook[] injectedHooks + ); + event PluginUninstalled(address indexed plugin, bool indexed callbacksSucceeded); + event PluginIgnoredHookUnapplyCallbackFailure(address indexed plugin, address indexed providingPlugin); + event PluginIgnoredUninstallCallbackFailure(address indexed plugin); + event ReceivedCall(bytes msgData, uint256 msgValue); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + vm.deal(beneficiary, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + tokenReceiverPlugin = new TokenReceiverPlugin(); + sessionKeyPlugin = new SessionKeyPlugin(); + implementation = address(new UpgradeableModularAccount(entryPoint)); + bytes32 manifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + factory = new MultiOwnerMSCAFactory( + address(this), address(multiOwnerPlugin), implementation, manifestHash, entryPoint + ); + + // Compute counterfactual address + owners1 = new address[](1); + owners1[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.getAddress(0, owners1))); + vm.deal(address(account1), 100 ether); + + // Pre-deploy account two for different gas estimates + (owner2, owner2Key) = makeAddrAndKey("owner2"); + owners2 = new address[](1); + owners2[0] = owner2; + account2 = UpgradeableModularAccount(payable(factory.createAccount(0, owners2))); + vm.deal(address(account2), 100 ether); + + ethRecipient = makeAddr("ethRecipient"); + vm.deal(ethRecipient, 1 wei); + counter = new Counter(); + counter.increment(); // amoritze away gas cost of zero->nonzero transition + } + + function test_deployAccount() public { + factory.createAccount(0, owners1); + } + + function test_installPlugin() public { + vm.startPrank(owner2); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + address[] memory sessionKeys = new address[](1); + sessionKeys[0] = owner1; + + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(sessionKeyPlugin), manifestHash, dependencies, new IPluginManager.InjectedHook[](0) + ); + IPluginManager(account2).installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(sessionKeys), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + manifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(tokenReceiverPlugin), + manifestHash, + new FunctionReference[](0), + new IPluginManager.InjectedHook[](0) + ); + IPluginManager(account2).installPlugin({ + plugin: address(tokenReceiverPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(uint48(1 days)), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + address[] memory plugins = IAccountLoupe(account2).getInstalledPlugins(); + assertEq(plugins.length, 3); + assertEq(plugins[0], address(tokenReceiverPlugin)); + assertEq(plugins[1], address(sessionKeyPlugin)); + assertEq(plugins[2], address(multiOwnerPlugin)); + } + + function test_installPlugin_ExecuteFromPlugin_PermittedExecSelectorNotInstalled() public { + vm.startPrank(owner2); + + PluginManifest memory m; + m.permittedExecutionSelectors = new bytes4[](1); + m.permittedExecutionSelectors[0] = IPlugin.onInstall.selector; + + MockPlugin mockPluginWithBadPermittedExec = new MockPlugin(m); + bytes32 manifestHash = keccak256(abi.encode(mockPluginWithBadPermittedExec.pluginManifest())); + + // This call should complete successfully, because we allow installation of plugins with non-existant + // permitted call selectors. + IPluginManager(account2).installPlugin({ + plugin: address(mockPluginWithBadPermittedExec), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_invalidManifest() public { + vm.startPrank(owner2); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + IPluginManager(account2).installPlugin({ + plugin: address(tokenReceiverPlugin), + manifestHash: bytes32(0), + pluginInitData: abi.encode(uint48(1 days)), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_interfaceNotSupported() public { + vm.startPrank(owner2); + + address badPlugin = address(1); + vm.expectRevert( + abi.encodeWithSelector(PluginManagerInternals.PluginInterfaceNotSupported.selector, address(badPlugin)) + ); + IPluginManager(account2).installPlugin({ + plugin: address(badPlugin), + manifestHash: bytes32(0), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_alreadyInstalled() public { + vm.startPrank(owner2); + + bytes32 manifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); + IPluginManager(account2).installPlugin({ + plugin: address(tokenReceiverPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(uint48(1 days)), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.PluginAlreadyInstalled.selector, address(tokenReceiverPlugin) + ) + ); + IPluginManager(account2).installPlugin({ + plugin: address(tokenReceiverPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(uint48(1 days)), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_failWithNativeFunctionSelector() public { + vm.startPrank(owner2); + + PluginManifest memory manifestBad; + manifestBad.executionFunctions = new bytes4[](1); + manifestBad.executionFunctions[0] = IPluginManager.installPlugin.selector; + MockPlugin mockPluginBad = new MockPlugin(manifestBad); + bytes32 manifestHashBad = keccak256(abi.encode(mockPluginBad.pluginManifest())); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.NativeFunctionNotAllowed.selector, IPluginManager.installPlugin.selector + ) + ); + IPluginManager(account2).installPlugin({ + plugin: address(mockPluginBad), + manifestHash: manifestHashBad, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_failWtihErc4337FunctionSelector() public { + vm.startPrank(owner2); + + PluginManifest memory manifestBad; + manifestBad.executionFunctions = new bytes4[](1); + manifestBad.executionFunctions[0] = IPaymaster.validatePaymasterUserOp.selector; + MockPlugin mockPluginBad = new MockPlugin(manifestBad); + bytes32 manifestHashBad = keccak256(abi.encode(mockPluginBad.pluginManifest())); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.Erc4337FunctionNotAllowed.selector, + IPaymaster.validatePaymasterUserOp.selector + ) + ); + IPluginManager(account2).installPlugin({ + plugin: address(mockPluginBad), + manifestHash: manifestHashBad, + pluginInitData: bytes(""), + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_installPlugin_missingDependency() public { + vm.startPrank(owner2); + + address[] memory guardians = new address[](1); + guardians[0] = address(1); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + + // Create a duplicate MultiOwnerPlugin that isn't installed, and attempt to use that as a dependency + MultiOwnerPlugin multiOwnerPlugin2 = new MultiOwnerPlugin(); + + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin2), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin2), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.MissingPluginDependency.selector, address(multiOwnerPlugin2) + ) + ); + IPluginManager(account2).installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(uint48(1 days)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_uninstallPlugin_default() public { + vm.startPrank(owner2); + + ComprehensivePlugin plugin = new ComprehensivePlugin(); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + address[] memory plugins = IAccountLoupe(account2).getInstalledPlugins(); + assertEq(plugins.length, 1); + assertEq(plugins[0], address(multiOwnerPlugin)); + } + + function test_uninstallPlugin_manifestParameter() public { + vm.startPrank(owner2); + + ComprehensivePlugin plugin = new ComprehensivePlugin(); + bytes memory serializedManifest = abi.encode(plugin.pluginManifest()); + bytes32 manifestHash = keccak256(serializedManifest); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: serializedManifest, + forceUninstall: false, + callbackGasLimit: 0 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + address[] memory plugins = IAccountLoupe(account2).getInstalledPlugins(); + assertEq(plugins.length, 1); + assertEq(plugins[0], address(multiOwnerPlugin)); + } + + function test_uninstallPlugin_invalidManifestFails() public { + vm.startPrank(owner2); + + ComprehensivePlugin plugin = new ComprehensivePlugin(); + bytes memory serializedManifest = abi.encode(plugin.pluginManifest()); + bytes32 manifestHash = keccak256(serializedManifest); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Attempt to uninstall with a blank manifest + PluginManifest memory blankManifest; + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: abi.encode(blankManifest), + forceUninstall: false, + callbackGasLimit: 0 + }) + ); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + // The forceUninstall flag doesn't let you succeed if your manifest is + // wrong. + config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: abi.encode(blankManifest), + forceUninstall: true, + callbackGasLimit: 0 + }) + ); + + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + address[] memory plugins = IAccountLoupe(account2).getInstalledPlugins(); + assertEq(plugins.length, 2); + assertEq(plugins[0], address(plugin)); + assertEq(plugins[1], address(multiOwnerPlugin)); + } + + function test_uninstallPlugin_manifestHasChanged() public { + vm.startPrank(owner2); + + CanChangeManifestPlugin plugin = new CanChangeManifestPluginFactory().newPlugin(); + bytes memory serializedManifest = abi.encode(plugin.pluginManifest()); + bytes32 manifestHash = keccak256(serializedManifest); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + plugin.changeManifest(); + + // Call an execution method which only appears in the initial manifest + // to later check that it's been removed. + CanChangeManifestPlugin(address(account2)).someExecutionFunction(); + + // Default uninstall should fail because the manifest has changed. + vm.expectRevert(abi.encodeWithSelector(PluginManagerInternals.InvalidPluginManifest.selector)); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + // Uninstall should succeed with original manifest hash passed in + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: serializedManifest, + forceUninstall: false, + callbackGasLimit: 0 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(plugin), true); + IPluginManager(account2).uninstallPlugin({ + plugin: address(plugin), + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + address[] memory plugins = IAccountLoupe(account2).getInstalledPlugins(); + assertEq(plugins.length, 1); + assertEq(plugins[0], address(multiOwnerPlugin)); + + // Check that the execution function which only appeared in the initial + // manifest has been removed (i.e. the account didn't use the new + // manifest for uninstallation despite being given the old one). + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.UnrecognizedFunction.selector, + CanChangeManifestPlugin.someExecutionFunction.selector + ) + ); + CanChangeManifestPlugin(address(account2)).someExecutionFunction(); + } + + function test_forceOnUninstall() external { + address plugin = _installPluginWithUninstallErrors(false); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.PluginUninstallCallbackFailed.selector, + plugin, + abi.encodeWithSelector(UninstallErrorsPlugin.IntentionalUninstallError.selector) + ) + ); + IPluginManager(account2).uninstallPlugin({ + plugin: plugin, + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: "", + forceUninstall: true, + callbackGasLimit: 0 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(plugin, false); + IPluginManager(account2).uninstallPlugin({ + plugin: plugin, + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_forceOnHookUnapply() external { + (address plugin, address hooksPlugin) = _installPluginWithHookUnapplyErrors(false); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.PluginHookUnapplyCallbackFailed.selector, + hooksPlugin, + abi.encodeWithSelector(UninstallErrorsPlugin.IntentionalUninstallError.selector) + ) + ); + IPluginManager(account2).uninstallPlugin({ + plugin: plugin, + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: "", + forceUninstall: true, + callbackGasLimit: 0 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(plugin, false); + IPluginManager(account2).uninstallPlugin({ + plugin: plugin, + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_onUninstallGasLimit() external { + address plugin = _installPluginWithUninstallErrors(true); + + vm.expectRevert( + abi.encodeWithSelector(PluginManagerInternals.PluginUninstallCallbackFailed.selector, plugin, "") + ); + IPluginManager(account2).uninstallPlugin{gas: 100_000}({ + plugin: plugin, + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + // Just `forceUninstall` isn't enough. + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: "", + forceUninstall: true, + callbackGasLimit: 0 + }) + ); + vm.expectRevert(bytes("")); + IPluginManager(account2).uninstallPlugin{gas: 100_000}({ + plugin: plugin, + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: "", + forceUninstall: true, + callbackGasLimit: 3000 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginIgnoredUninstallCallbackFailure(plugin); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(plugin, false); + IPluginManager(account2).uninstallPlugin{gas: 100_000}({ + plugin: plugin, + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_onHookUnapplyGasLimit() external { + (address plugin, address hooksPlugin) = _installPluginWithHookUnapplyErrors(true); + + vm.expectRevert( + abi.encodeWithSelector( + PluginManagerInternals.PluginHookUnapplyCallbackFailed.selector, hooksPlugin, "" + ) + ); + IPluginManager(account2).uninstallPlugin{gas: 100_000}({ + plugin: plugin, + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + bytes memory config = abi.encode( + UpgradeableModularAccount.UninstallPluginConfig({ + serializedManifest: "", + forceUninstall: true, + callbackGasLimit: 3000 + }) + ); + vm.expectEmit(true, true, true, true); + emit PluginIgnoredHookUnapplyCallbackFailure(plugin, hooksPlugin); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(plugin, false); + IPluginManager(account2).uninstallPlugin{gas: 100_000}({ + plugin: plugin, + config: config, + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_injectHooks() external { + (, MockPlugin newPlugin,) = _installWithInjectHooks(); + + // order of emitting events: pre hook is run, exec function is run, post hook is run + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeWithSelector( + IPlugin.preExecutionHook.selector, + injectedHooksInfo.preExecHookFunctionId, + address(newPlugin), // caller + 0, // msg.value in call to account + abi.encodeCall( + account2.executeFromPluginExternal, + (address(counter), 0, abi.encodePacked(counter.increment.selector)) + ) + ), + 0 // msg value in call to plugin + ); + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall( + IPlugin.postExecutionHook, + (injectedHooksInfo.postExecHookFunctionId, abi.encode(injectedHooksInfo.preExecHookFunctionId)) + ), + 0 // msg value in call to plugin + ); + vm.prank(address(newPlugin)); + account2.executeFromPluginExternal(address(counter), 0, abi.encodePacked(counter.increment.selector)); + } + + function test_injectHooksApplyGoodCalldata() external { + MockPlugin hooksPlugin = _installPluginWithExecHooks(); + + MockPlugin newPlugin = new MockPlugin(manifest); + + bytes32 manifestHash = keccak256(abi.encode(newPlugin.pluginManifest())); + + IPluginManager.InjectedHook[] memory hooks = new IPluginManager.InjectedHook[](1); + bytes memory onApplyData = abi.encode(keccak256("randomdata")); + hooks[0] = IPluginManager.InjectedHook( + address(hooksPlugin), IPluginExecutor.executeFromPluginExternal.selector, injectedHooksInfo, "" + ); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.onHookApply, (address(newPlugin), injectedHooksInfo, onApplyData)), 0 + ); + vm.expectEmit(true, true, true, true); + emit PluginInstalled(address(newPlugin), manifestHash, new FunctionReference[](0), hooks); + + // set the apply data after as the event emits an InjectedHook object after stripping hookApplyData out + hooks[0].hookApplyData = onApplyData; + + vm.prank(owner2); + IPluginManager(account2).installPlugin({ + plugin: address(newPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: hooks + }); + } + + function test_injectHooksMissingPlugin() external { + // hooks plugin not installed + MockPlugin hooksPlugin = MockPlugin(payable(address(1))); + + MockPlugin newPlugin = new MockPlugin(manifest); + + bytes32 manifestHash = keccak256(abi.encode(newPlugin.pluginManifest())); + + IPluginManager.InjectedHook[] memory hooks = new IPluginManager.InjectedHook[](1); + hooks[0] = IPluginManager.InjectedHook( + address(hooksPlugin), IPluginExecutor.executeFromPluginExternal.selector, injectedHooksInfo, "" + ); + + vm.expectRevert( + abi.encodeWithSelector(PluginManagerInternals.MissingPluginDependency.selector, address(hooksPlugin)) + ); + vm.prank(owner2); + IPluginManager(account2).installPlugin({ + plugin: address(newPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: hooks + }); + } + + function test_injectHooksUninstall() external { + (, MockPlugin newPlugin,) = _installWithInjectHooks(); + + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(newPlugin), true); + vm.prank(owner2); + IPluginManager(account2).uninstallPlugin({ + plugin: address(newPlugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_injectHooksBadUninstallDependency() external { + (MockPlugin hooksPlugin,,) = _installWithInjectHooks(); + + vm.prank(owner2); + vm.expectRevert( + abi.encodeWithSelector(PluginManagerInternals.PluginDependencyViolation.selector, address(hooksPlugin)) + ); + IPluginManager(account2).uninstallPlugin({ + plugin: address(hooksPlugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + } + + function test_injectHooksUnapplyGoodCalldata() external { + (, MockPlugin newPlugin,) = _installWithInjectHooks(); + + bytes[] memory injectedHooksDatas = new bytes[](1); + injectedHooksDatas[0] = abi.encode(keccak256("randomdata")); + + vm.expectEmit(true, true, true, true); + emit ReceivedCall( + abi.encodeCall(IPlugin.onHookUnapply, (address(newPlugin), injectedHooksInfo, injectedHooksDatas[0])), + 0 + ); + vm.prank(owner2); + IPluginManager(account2).uninstallPlugin({ + plugin: address(newPlugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: injectedHooksDatas + }); + } + + function test_injectHooksUnapplyBadCalldata() external { + (, MockPlugin newPlugin,) = _installWithInjectHooks(); + + // length != installed hooks length + bytes[] memory injectedHooksDatas = new bytes[](2); + + vm.expectRevert(PluginManagerInternals.ArrayLengthMismatch.selector); + vm.prank(owner2); + IPluginManager(account2).uninstallPlugin({ + plugin: address(newPlugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: injectedHooksDatas + }); + } + + function test_uninstallAndInstallInBatch() external { + // Check that we can uninstall the `MultiOwnerPlugin`, leaving no + // validator on `installPlugin`, and then install a different plugin + // immediately after as part of the same batch execution. This is a + // special case: normally an execution function with no runtime + // validator cannot be runtime-called. + vm.startPrank(owner2); + + ComprehensivePlugin plugin = new ComprehensivePlugin(); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + Call[] memory calls = new Call[](2); + calls[0] = Call({ + target: address(account2), + value: 0, + data: abi.encodeCall(IPluginManager.uninstallPlugin, (address(multiOwnerPlugin), "", "", new bytes[](0))) + }); + calls[1] = Call({ + target: address(account2), + value: 0, + data: abi.encodeCall( + IPluginManager.installPlugin, + (address(plugin), manifestHash, "", new FunctionReference[](0), new IPluginManager.InjectedHook[](0)) + ) + }); + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(multiOwnerPlugin), true); + vm.expectEmit(true, true, true, true); + emit PluginInstalled( + address(plugin), manifestHash, new FunctionReference[](0), new IPluginManager.InjectedHook[](0) + ); + IStandardExecutor(account2).executeBatch(calls); + } + + function test_noNonSelfInstallAfterUninstall() external { + // A companion to the previous test, ensuring that `installPlugin` can't + // be called directly (e.g. not via `execute` or `executeBatch`) if it + // has no validator. + vm.startPrank(owner2); + + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(multiOwnerPlugin), true); + account2.uninstallPlugin({ + plugin: address(multiOwnerPlugin), + config: "", + pluginUninstallData: "", + hookUnapplyData: new bytes[](0) + }); + + ComprehensivePlugin plugin = new ComprehensivePlugin(); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionMissing.selector, + IPluginManager.installPlugin.selector + ) + ); + account2.installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + // Internal Functions + + function _installPluginWithExecHooks() internal returns (MockPlugin plugin) { + vm.startPrank(owner2); + + plugin = new MockPlugin(manifest); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + vm.stopPrank(); + } + + function _installWithInjectHooks() + internal + returns (MockPlugin hooksPlugin, MockPlugin newPlugin, bytes32 manifestHash) + { + hooksPlugin = _installPluginWithExecHooks(); + + manifest.permitAnyExternalAddress = true; + newPlugin = new MockPlugin(manifest); + + manifestHash = keccak256(abi.encode(newPlugin.pluginManifest())); + + IPluginManager.InjectedHook[] memory hooks = new IPluginManager.InjectedHook[](1); + hooks[0] = IPluginManager.InjectedHook( + address(hooksPlugin), IPluginExecutor.executeFromPluginExternal.selector, injectedHooksInfo, "" + ); + + vm.prank(owner2); + vm.expectEmit(true, true, true, true); + emit ReceivedCall(abi.encodeCall(IPlugin.onHookApply, (address(newPlugin), injectedHooksInfo, "")), 0); + vm.expectEmit(true, true, true, true); + emit PluginInstalled(address(newPlugin), manifestHash, new FunctionReference[](0), hooks); + IPluginManager(account2).installPlugin({ + plugin: address(newPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: hooks + }); + } + + function _installPluginWithUninstallErrors(bool shouldDrainGas) internal returns (address) { + vm.startPrank(owner2); + + UninstallErrorsPlugin plugin = new UninstallErrorsPlugin(shouldDrainGas); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + return address(plugin); + } + + function _installPluginWithHookUnapplyErrors(bool shouldDrainGas) + internal + returns (address pluginAddress, address hooksPluginAddress) + { + vm.startPrank(owner2); + + UninstallErrorsPlugin hooksPlugin = new UninstallErrorsPlugin(shouldDrainGas); + IPluginManager(account2).installPlugin({ + plugin: address(hooksPlugin), + manifestHash: keccak256(abi.encode(hooksPlugin.pluginManifest())), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + MockPlugin plugin = new MockPlugin(manifest); + IPluginManager.InjectedHook[] memory hooks = new IPluginManager.InjectedHook[](1); + hooks[0] = IPluginManager.InjectedHook( + address(hooksPlugin), IPluginExecutor.executeFromPluginExternal.selector, injectedHooksInfo, "" + ); + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: keccak256(abi.encode(plugin.pluginManifest())), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: hooks + }); + return (address(plugin), address(hooksPlugin)); + } +} diff --git a/test/account/ValidationIntersection.t.sol b/test/account/ValidationIntersection.t.sol new file mode 100644 index 00000000..66e96e66 --- /dev/null +++ b/test/account/ValidationIntersection.t.sol @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../src/interfaces/erc4337/UserOperation.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import { + MockBaseUserOpValidationPlugin, + MockUserOpValidation1HookPlugin, + MockUserOpValidation2HookPlugin, + MockUserOpValidationPlugin +} from "../mocks/plugins/ValidationPluginMocks.sol"; + +contract ValidationIntersectionTest is Test { + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + IEntryPoint public entryPoint; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account1; + MockUserOpValidationPlugin public noHookPlugin; + MockUserOpValidation1HookPlugin public oneHookPlugin; + MockUserOpValidation2HookPlugin public twoHookPlugin; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + owner1 = makeAddr("owner1"); + + MultiOwnerPlugin multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + MultiOwnerMSCAFactory factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + address[] memory owners1 = new address[](1); + owners1[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners1))); + vm.deal(address(account1), 1 ether); + + noHookPlugin = new MockUserOpValidationPlugin(); + oneHookPlugin = new MockUserOpValidation1HookPlugin(); + twoHookPlugin = new MockUserOpValidation2HookPlugin(); + + vm.startPrank(address(owner1)); + account1.installPlugin({ + plugin: address(noHookPlugin), + manifestHash: keccak256(abi.encode(noHookPlugin.pluginManifest())), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + account1.installPlugin({ + plugin: address(oneHookPlugin), + manifestHash: keccak256(abi.encode(oneHookPlugin.pluginManifest())), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + account1.installPlugin({ + plugin: address(twoHookPlugin), + manifestHash: keccak256(abi.encode(twoHookPlugin.pluginManifest())), + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + vm.stopPrank(); + } + + function testFuzz_validationIntersect_single(uint256 validationData) public { + noHookPlugin.setValidationData(validationData); + + UserOperation memory userOp; + userOp.callData = bytes.concat(noHookPlugin.foo.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(returnedValidationData, validationData); + } + + function test_validationIntersect_authorizer_sigfail_validationFunction() public { + oneHookPlugin.setValidationData( + _SIG_VALIDATION_FAILED, + 0 // returns OK + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + // Down-cast to only check the authorizer + assertEq(uint160(returnedValidationData), _SIG_VALIDATION_FAILED); + } + + function test_validationIntersect_authorizer_sigfail_hook() public { + oneHookPlugin.setValidationData( + 0, // returns OK + _SIG_VALIDATION_FAILED + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + // Down-cast to only check the authorizer + assertEq(uint160(returnedValidationData), _SIG_VALIDATION_FAILED); + } + + function test_validationIntersect_timeBounds_intersect_1() public { + uint48 start1 = uint48(10); + uint48 end1 = uint48(20); + + uint48 start2 = uint48(15); + uint48 end2 = uint48(25); + + oneHookPlugin.setValidationData( + _packValidationData(address(0), start1, end1), _packValidationData(address(0), start2, end2) + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(returnedValidationData, _packValidationData(address(0), start2, end1)); + } + + function test_validationIntersect_timeBounds_intersect_2() public { + uint48 start1 = uint48(10); + uint48 end1 = uint48(20); + + uint48 start2 = uint48(15); + uint48 end2 = uint48(25); + + oneHookPlugin.setValidationData( + _packValidationData(address(0), start2, end2), _packValidationData(address(0), start1, end1) + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(returnedValidationData, _packValidationData(address(0), start2, end1)); + } + + function test_validationIntersect_revert_unexpectedAuthorizer() public { + address badAuthorizer = makeAddr("badAuthorizer"); + + oneHookPlugin.setValidationData( + 0, // returns OK + uint256(uint160(badAuthorizer)) // returns an aggregator, which preValidation hooks are not allowed to + // do. + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.UnexpectedAggregator.selector, + address(oneHookPlugin), + MockBaseUserOpValidationPlugin.FunctionId.PRE_USER_OP_VALIDATION_HOOK_1, + badAuthorizer + ) + ); + account1.validateUserOp(userOp, uoHash, 1 wei); + } + + function test_validationIntersect_validAuthorizer() public { + address goodAuthorizer = makeAddr("goodAuthorizer"); + + oneHookPlugin.setValidationData( + uint256(uint160(goodAuthorizer)), // returns a valid aggregator + 0 // returns OK + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(address(uint160(returnedValidationData)), goodAuthorizer); + } + + function test_validationIntersect_authorizerAndTimeRange() public { + uint48 start1 = uint48(10); + uint48 end1 = uint48(20); + + uint48 start2 = uint48(15); + uint48 end2 = uint48(25); + + address goodAuthorizer = makeAddr("goodAuthorizer"); + + oneHookPlugin.setValidationData( + _packValidationData(goodAuthorizer, start1, end1), _packValidationData(address(0), start2, end2) + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(oneHookPlugin.bar.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(returnedValidationData, _packValidationData(goodAuthorizer, start2, end1)); + } + + function test_validationIntersect_multiplePreValidationHooksIntersect() public { + uint48 start1 = uint48(10); + uint48 end1 = uint48(20); + + uint48 start2 = uint48(15); + uint48 end2 = uint48(25); + + twoHookPlugin.setValidationData( + 0, // returns OK + _packValidationData(address(0), start1, end1), + _packValidationData(address(0), start2, end2) + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(twoHookPlugin.baz.selector); + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + assertEq(returnedValidationData, _packValidationData(address(0), start2, end1)); + } + + function test_validationIntersect_multiplePreValidationHooksSigFail() public { + twoHookPlugin.setValidationData( + 0, // returns OK + 0, // returns OK + _SIG_VALIDATION_FAILED + ); + + UserOperation memory userOp; + userOp.callData = bytes.concat(twoHookPlugin.baz.selector); + + bytes32 uoHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + uint256 returnedValidationData = account1.validateUserOp(userOp, uoHash, 1 wei); + + // Down-cast to only check the authorizer + assertEq(uint160(returnedValidationData), _SIG_VALIDATION_FAILED); + } + + function _unpackValidationData(uint256 validationData) + internal + pure + returns (address authorizer, uint48 validAfter, uint48 validUntil) + { + authorizer = address(uint160(validationData)); + validUntil = uint48(validationData >> 160); + if (validUntil == 0) { + validUntil = type(uint48).max; + } + validAfter = uint48(validationData >> (48 + 160)); + } + + function _packValidationData(address authorizer, uint48 validAfter, uint48 validUntil) + internal + pure + returns (uint256) + { + return uint160(authorizer) | (uint256(validUntil) << 160) | (uint256(validAfter) << (160 + 48)); + } + + function _intersectTimeRange(uint48 validafter1, uint48 validuntil1, uint48 validafter2, uint48 validuntil2) + internal + pure + returns (uint48 validAfter, uint48 validUntil) + { + if (validafter1 < validafter2) { + validAfter = validafter2; + } else { + validAfter = validafter1; + } + if (validuntil1 > validuntil2) { + validUntil = validuntil2; + } else { + validUntil = validuntil1; + } + } +} diff --git a/test/comparison/CompareSimpleAccount.t.sol b/test/comparison/CompareSimpleAccount.t.sol new file mode 100644 index 00000000..61d2560f --- /dev/null +++ b/test/comparison/CompareSimpleAccount.t.sol @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {SimpleAccount} from "@eth-infinitism/account-abstraction/samples/SimpleAccount.sol"; +import {SimpleAccountFactory} from "@eth-infinitism/account-abstraction/samples/SimpleAccountFactory.sol"; + +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../src/interfaces/erc4337/UserOperation.sol"; +import {Counter} from "../mocks/Counter.sol"; + +contract CompareSimpleAccountTest is Test { + using ECDSA for bytes32; + + IEntryPoint public entryPoint; + address payable public beneficiary; + + SimpleAccountFactory public factory; + + // Owner 1 deploys account contract in the same transaction + address public owner1; + uint256 public owner1Key; + address public account1; + + // owner 2 pre-deploys account contract + address public owner2; + uint256 public owner2Key; + address public account2; + + Counter public counter; + + function setUp() public { + EntryPoint ep = new EntryPoint(); + entryPoint = IEntryPoint(address(ep)); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + vm.deal(beneficiary, 1 wei); + + factory = new SimpleAccountFactory(ep); + account1 = factory.getAddress(owner1, 0); + vm.deal(account1, 100 ether); + + counter = new Counter(); + counter.increment(); + + // Pre-generate account 2 + (owner2, owner2Key) = makeAddrAndKey("owner2"); + account2 = address(factory.createAccount(owner2, 0)); + vm.deal(account2, 100 ether); + vm.prank(account2); + entryPoint.depositTo{value: 1 wei}(account2); + } + + function test_SimpleAccount_deploy_basicSend() public { + UserOperation memory userOp = UserOperation({ + sender: account1, + nonce: 0, + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner1, 0))), + callData: abi.encodeCall(SimpleAccount.execute, (beneficiary, 1, "")), + callGasLimit: 5000000, + verificationGasLimit: 5000000, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_SimpleAccount_deploy_empty() public { + UserOperation memory userOp = UserOperation({ + sender: account1, + nonce: 0, + initCode: abi.encodePacked(address(factory), abi.encodeCall(factory.createAccount, (owner1, 0))), + callData: "", + callGasLimit: 5000000, + verificationGasLimit: 5000000, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_SimpleAccount_postDeploy_basicSend() public { + UserOperation memory userOp = UserOperation({ + sender: account2, + nonce: 0, + initCode: "", + callData: abi.encodeCall(SimpleAccount.execute, (beneficiary, 1, "")), + callGasLimit: 5000000, + verificationGasLimit: 5000000, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_SimpleAccount_postDeploy_contractInteraction() public { + UserOperation memory userOp = UserOperation({ + sender: account2, + nonce: 0, + initCode: "", + callData: abi.encodeCall( + SimpleAccount.execute, (address(counter), 0, abi.encodeCall(Counter.increment, ())) + ), + callGasLimit: 5000000, + verificationGasLimit: 5000000, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(counter.number(), 2); + } +} diff --git a/test/factory/MultiOwnerMSCAFactoryTest.t.sol b/test/factory/MultiOwnerMSCAFactoryTest.t.sol new file mode 100644 index 00000000..c5477ece --- /dev/null +++ b/test/factory/MultiOwnerMSCAFactoryTest.t.sol @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; + +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; + +contract MultiOwnerMSCAFactoryTest is Test { + using ECDSA for bytes32; + + EntryPoint public entryPoint; + MultiOwnerMSCAFactory public factory; + MultiOwnerPlugin public multiOwnerPlugin; + address public impl; + + address public notOwner = address(1); + address public owner1 = address(2); + address public owner2 = address(3); + address public badImpl = address(4); + + address[] public owners; + + bytes32 internal constant _IMPLEMENTATION_SLOT = + 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + + function setUp() public { + owners.push(owner1); + owners.push(owner2); + entryPoint = new EntryPoint(); + impl = address(new UpgradeableModularAccount(IEntryPoint(address(entryPoint)))); + multiOwnerPlugin = new MultiOwnerPlugin(); + bytes32 manifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + impl, + manifestHash, + IEntryPoint(address(entryPoint)) + ); + vm.deal(address(this), 100 ether); + } + + function test_addressMatch() public { + address predicted = factory.getAddress(0, owners); + address deployed = factory.createAccount(0, owners); + assertEq(predicted, deployed); + } + + function test_deploy() public { + address deployed = factory.createAccount(0, owners); + + // test that the deployed account is initialized + assertEq(address(UpgradeableModularAccount(payable(deployed)).entryPoint()), address(entryPoint)); + + // test that the deployed account installed owner plugin correctly + address[] memory actualOwners = multiOwnerPlugin.ownersOf(deployed); + assertEq(actualOwners.length, 2); + assertEq(actualOwners[0], owner2); + assertEq(actualOwners[1], owner1); + } + + function test_deployCollision() public { + address deployed = factory.createAccount(0, owners); + + uint256 gasStart = gasleft(); + + // deploy 2nd time which should short circuit + // test for short circuit -> call should cost less than a CREATE2, or 32000 gas + address secondDeploy = factory.createAccount(0, owners); + + assertApproxEqAbs(gasleft(), gasStart, 31999); + assertEq(deployed, secondDeploy); + } + + function test_deployedAccountHasCorrectPlugins() public { + address deployed = factory.createAccount(0, owners); + + // check installed plugins on account + address[] memory plugins = UpgradeableModularAccount(payable(deployed)).getInstalledPlugins(); + assertEq(plugins.length, 1); + assertEq(plugins[0], address(multiOwnerPlugin)); + } + + function test_addStake() public { + assertEq(entryPoint.balanceOf(address(factory)), 0); + vm.deal(address(this), 100 ether); + factory.addStake{value: 10 ether}(10 hours, 10 ether); + assertEq(entryPoint.getDepositInfo(address(factory)).stake, 10 ether); + } + + function test_unlockStake() public { + test_addStake(); + factory.unlockStake(); + assertEq(entryPoint.getDepositInfo(address(factory)).withdrawTime, block.timestamp + 10 hours); + } + + function test_withdrawStake() public { + test_unlockStake(); + vm.warp(10 hours); + vm.expectRevert("Stake withdrawal is not due"); + factory.withdrawStake(payable(address(this))); + assertEq(address(this).balance, 90 ether); + vm.warp(10 hours + 1); + factory.withdrawStake(payable(address(this))); + assertEq(address(this).balance, 100 ether); + } + + function test_withdraw() public { + factory.addStake{value: 10 ether}(10 hours, 1 ether); + assertEq(address(factory).balance, 9 ether); + factory.withdraw(payable(address(this)), address(0), 0); // amount = balance if native currency + assertEq(address(factory).balance, 0); + } + + // to receive funds from withdraw + receive() external payable {} +} diff --git a/test/factory/MultiOwnerTokenReceiverFactoryTest.t.sol b/test/factory/MultiOwnerTokenReceiverFactoryTest.t.sol new file mode 100644 index 00000000..d3dfbd4d --- /dev/null +++ b/test/factory/MultiOwnerTokenReceiverFactoryTest.t.sol @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {ERC721PresetMinterPauserAutoId} from + "@openzeppelin/contracts/token/ERC721/presets/ERC721PresetMinterPauserAutoId.sol"; + +import {MultiOwnerTokenReceiverMSCAFactory} from "../../src/factory/MultiOwnerTokenReceiverMSCAFactory.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {MockERC777} from "../mocks/tokens/MockERC777.sol"; +import {MockERC1155} from "../mocks/tokens/MockERC1155.sol"; + +contract MultiOwnerTokenReceiverMSCAFactoryTest is Test { + using ECDSA for bytes32; + + EntryPoint public entryPoint; + MultiOwnerTokenReceiverMSCAFactory public factory; + MultiOwnerPlugin public multiOwnerPlugin; + TokenReceiverPlugin public tokenReceiverPlugin; + address public impl; + ERC721PresetMinterPauserAutoId public t0; + MockERC777 public t1; + MockERC1155 public t2; + + address public owner1 = address(1); + address public owner2 = address(2); + address public nftHolder = address(3); + + address[] public owners; + uint256[] public tokenIds; + uint256[] public tokenAmts; + uint256[] public zeroTokenAmts; + + uint256 internal constant _TOKEN_AMOUNT = 1 ether; + uint256 internal constant _TOKEN_ID = 0; + uint256 internal constant _BATCH_TOKEN_IDS = 5; + + bytes32 internal constant _IMPLEMENTATION_SLOT = + 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + + function setUp() public { + owners.push(owner1); + owners.push(owner2); + entryPoint = new EntryPoint(); + impl = address(new UpgradeableModularAccount(IEntryPoint(address(entryPoint)))); + multiOwnerPlugin = new MultiOwnerPlugin(); + tokenReceiverPlugin = new TokenReceiverPlugin(); + bytes32 ownerManifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + bytes32 tokenReceiverManifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); + factory = new MultiOwnerTokenReceiverMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(tokenReceiverPlugin), + impl, + ownerManifestHash, + tokenReceiverManifestHash, + IEntryPoint(address(entryPoint)) + ); + vm.deal(nftHolder, 100 ether); + + t0 = new ERC721PresetMinterPauserAutoId("t0", "t0", ""); + t0.mint(nftHolder); + + t1 = new MockERC777(); + t1.mint(nftHolder, _TOKEN_AMOUNT); + + t2 = new MockERC1155(); + t2.mint(nftHolder, _TOKEN_ID, _TOKEN_AMOUNT); + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + t2.mint(nftHolder, i, _TOKEN_AMOUNT); + tokenIds.push(i); + tokenAmts.push(_TOKEN_AMOUNT); + zeroTokenAmts.push(0); + } + } + + function test_addressMatch() public { + address predicted = factory.getAddress(0, owners); + address deployed = factory.createAccount(0, owners); + assertEq(predicted, deployed); + } + + function test_deploy() public { + address deployed = factory.createAccount(0, owners); + + // test that the deployed account is initialized + assertEq(address(UpgradeableModularAccount(payable(deployed)).entryPoint()), address(entryPoint)); + + // test that the deployed account installed owner plugin correctly + address[] memory actualOwners = multiOwnerPlugin.ownersOf(deployed); + assertEq(actualOwners.length, 2); + assertEq(actualOwners[0], owner2); + assertEq(actualOwners[1], owner1); + } + + function test_receiveTokens() public { + address acct = factory.createAccount(0, owners); + + vm.startPrank(nftHolder); + + // test that it can receive tokens + assertEq(t0.ownerOf(_TOKEN_ID), nftHolder); + t0.safeTransferFrom(nftHolder, acct, _TOKEN_ID); + assertEq(t0.ownerOf(_TOKEN_ID), acct); + + assertEq(t1.balanceOf(nftHolder), _TOKEN_AMOUNT); + assertEq(t1.balanceOf(acct), 0); + t1.transfer(acct, _TOKEN_AMOUNT); + assertEq(t1.balanceOf(nftHolder), 0); + assertEq(t1.balanceOf(acct), _TOKEN_AMOUNT); + + assertEq(t2.balanceOf(nftHolder, _TOKEN_ID), _TOKEN_AMOUNT); + assertEq(t2.balanceOf(acct, _TOKEN_ID), 0); + t2.safeTransferFrom(nftHolder, acct, _TOKEN_ID, _TOKEN_AMOUNT, ""); + assertEq(t2.balanceOf(nftHolder, _TOKEN_ID), 0); + assertEq(t2.balanceOf(acct, _TOKEN_ID), _TOKEN_AMOUNT); + + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + assertEq(t2.balanceOf(nftHolder, i), _TOKEN_AMOUNT); + assertEq(t2.balanceOf(acct, i), 0); + } + t2.safeBatchTransferFrom(nftHolder, acct, tokenIds, tokenAmts, ""); + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + assertEq(t2.balanceOf(nftHolder, i), 0); + assertEq(t2.balanceOf(acct, i), _TOKEN_AMOUNT); + } + } + + function test_deployCollision() public { + address deployed = factory.createAccount(0, owners); + + uint256 gasStart = gasleft(); + + // deploy 2nd time which should short circuit + // test for short circuit -> call should cost less than a CREATE2, or 32000 gas + address secondDeploy = factory.createAccount(0, owners); + + assertApproxEqAbs(gasleft(), gasStart, 31999); + assertEq(deployed, secondDeploy); + } + + function test_deployedAccountHasCorrectPlugins() public { + address deployed = factory.createAccount(0, owners); + + // check installed plugins on account + address[] memory plugins = UpgradeableModularAccount(payable(deployed)).getInstalledPlugins(); + assertEq(plugins.length, 2); + assertEq(plugins[0], address(tokenReceiverPlugin)); + assertEq(plugins[1], address(multiOwnerPlugin)); + } + + function test_addStake() public { + assertEq(entryPoint.balanceOf(address(factory)), 0); + vm.deal(address(this), 100 ether); + factory.addStake{value: 10 ether}(10 hours, 10 ether); + assertEq(entryPoint.getDepositInfo(address(factory)).stake, 10 ether); + } + + function test_unlockStake() public { + test_addStake(); + factory.unlockStake(); + assertEq(entryPoint.getDepositInfo(address(factory)).withdrawTime, block.timestamp + 10 hours); + } + + function test_withdrawStake() public { + test_unlockStake(); + vm.warp(10 hours); + vm.expectRevert("Stake withdrawal is not due"); + factory.withdrawStake(payable(address(this))); + assertEq(address(this).balance, 90 ether); + vm.warp(10 hours + 1); + factory.withdrawStake(payable(address(this))); + assertEq(address(this).balance, 100 ether); + } + + function test_withdraw() public { + factory.addStake{value: 10 ether}(10 hours, 1 ether); + assertEq(address(factory).balance, 9 ether); + factory.withdraw(payable(address(this)), address(0), 0); // amount = balance if native currency + assertEq(address(factory).balance, 0); + } + + // to receive funds from withdraw + receive() external payable {} +} diff --git a/test/helpers/KnownSelectors.t.sol b/test/helpers/KnownSelectors.t.sol new file mode 100644 index 00000000..2f122eb4 --- /dev/null +++ b/test/helpers/KnownSelectors.t.sol @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {IERC165} from "@openzeppelin/contracts/utils/introspection/IERC165.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; +import {BaseAccount} from "@eth-infinitism/account-abstraction/core/BaseAccount.sol"; +import {IAggregator} from "@eth-infinitism/account-abstraction/interfaces/IAggregator.sol"; +import {IPaymaster} from "@eth-infinitism/account-abstraction/interfaces/IPaymaster.sol"; + +import {KnownSelectors} from "../../src/helpers/KnownSelectors.sol"; +import {IAccountLoupe} from "../../src/interfaces/IAccountLoupe.sol"; +import {IAccountInitializable} from "../../src/interfaces/IAccountInitializable.sol"; +import {IStandardExecutor} from "../../src/interfaces/IStandardExecutor.sol"; +import {IPluginExecutor} from "../../src/interfaces/IPluginExecutor.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; + +contract KnownSelectorsTest is Test { + function test_isNativeFunction() public { + // account-abstraction BaseAccount methods + assertTrue(KnownSelectors.isNativeFunction(BaseAccount.getNonce.selector)); + assertTrue(KnownSelectors.isNativeFunction(BaseAccount.entryPoint.selector)); + assertTrue(KnownSelectors.isNativeFunction(BaseAccount.validateUserOp.selector)); + + // IPluginManager methods + assertTrue(KnownSelectors.isNativeFunction(IPluginManager.installPlugin.selector)); + assertTrue(KnownSelectors.isNativeFunction(IPluginManager.uninstallPlugin.selector)); + + // IERC165 methods + assertTrue(KnownSelectors.isNativeFunction(IERC165.supportsInterface.selector)); + + // UUPSUpgradeable methods + assertTrue(KnownSelectors.isNativeFunction(UUPSUpgradeable.proxiableUUID.selector)); + assertTrue(KnownSelectors.isNativeFunction(UUPSUpgradeable.upgradeToAndCall.selector)); + + // IStandardExecutor methods + assertTrue(KnownSelectors.isNativeFunction(IStandardExecutor.execute.selector)); + assertTrue(KnownSelectors.isNativeFunction(IStandardExecutor.executeBatch.selector)); + + // IPluginExecutor methods + assertTrue(KnownSelectors.isNativeFunction(IPluginExecutor.executeFromPlugin.selector)); + assertTrue(KnownSelectors.isNativeFunction(IPluginExecutor.executeFromPluginExternal.selector)); + + // IAccountInitializable methods + assertTrue(KnownSelectors.isNativeFunction(IAccountInitializable.initialize.selector)); + + // IAccountLoupe methods + assertTrue(KnownSelectors.isNativeFunction(IAccountLoupe.getExecutionFunctionConfig.selector)); + assertTrue(KnownSelectors.isNativeFunction(IAccountLoupe.getExecutionHooks.selector)); + assertTrue(KnownSelectors.isNativeFunction(IAccountLoupe.getPermittedCallHooks.selector)); + assertTrue(KnownSelectors.isNativeFunction(IAccountLoupe.getPreValidationHooks.selector)); + assertTrue(KnownSelectors.isNativeFunction(IAccountLoupe.getInstalledPlugins.selector)); + + assertFalse(KnownSelectors.isNativeFunction(IPaymaster.validatePaymasterUserOp.selector)); + } + + function test_isErc4337Function() public { + assertTrue(KnownSelectors.isErc4337Function(IAggregator.validateSignatures.selector)); + assertTrue(KnownSelectors.isErc4337Function(IAggregator.validateUserOpSignature.selector)); + assertTrue(KnownSelectors.isErc4337Function(IAggregator.aggregateSignatures.selector)); + assertTrue(KnownSelectors.isErc4337Function(IPaymaster.validatePaymasterUserOp.selector)); + assertTrue(KnownSelectors.isErc4337Function(IPaymaster.postOp.selector)); + + assertFalse(KnownSelectors.isErc4337Function(BaseAccount.validateUserOp.selector)); + } +} diff --git a/test/invariant/AssociatedLinkedListSetLibInvariants.t.sol b/test/invariant/AssociatedLinkedListSetLibInvariants.t.sol new file mode 100644 index 00000000..0ce0e02c --- /dev/null +++ b/test/invariant/AssociatedLinkedListSetLibInvariants.t.sol @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {AssociatedLinkedListSetHandler} from "./handlers/AssociatedLinkedListSetHandler.sol"; + +contract AssociatedLinkedListSetLibInvariantsTest is Test { + AssociatedLinkedListSetHandler public handler; + + // Only use these constants for testing. Enforces uniqueness across ID and associated address, + // Without reducing invariant call depth too much. + address public constant ASSOCIATED_1 = address(uint160(bytes20(keccak256("ASSOCIATED_1")))); + address public constant ASSOCIATED_2 = address(uint160(bytes20(keccak256("ASSOCIATED_2")))); + uint64 public constant ID_1 = 42; + uint64 public constant ID_2 = 115557777; + + function setUp() public { + handler = new AssociatedLinkedListSetHandler(); + + bytes4[] memory selectors = new bytes4[](8); + selectors[0] = handler.add.selector; + selectors[1] = handler.removeIterate.selector; + selectors[2] = handler.removeRandKeyIterate.selector; + selectors[3] = handler.clear.selector; + selectors[4] = handler.removeKnownPrevKey.selector; + selectors[5] = handler.removeRandKnownPrevKey.selector; + selectors[6] = handler.addFlagKnown.selector; + selectors[7] = handler.addFlagRandom.selector; + + targetSelector(FuzzSelector({addr: address(handler), selectors: selectors})); + } + + function invariant_shouldContain() public { + _shouldContain(ASSOCIATED_1, ID_1); + _shouldContain(ASSOCIATED_1, ID_2); + _shouldContain(ASSOCIATED_2, ID_1); + _shouldContain(ASSOCIATED_2, ID_2); + } + + // Doesn't test for no duplicates yet + function invariant_getAllEquivalence() public { + _getAllEquivalence(ASSOCIATED_1, ID_1); + _getAllEquivalence(ASSOCIATED_1, ID_2); + _getAllEquivalence(ASSOCIATED_2, ID_1); + _getAllEquivalence(ASSOCIATED_2, ID_2); + } + + function invariant_flagValidity() public { + _flagValidityCheck(ASSOCIATED_1, ID_1); + _flagValidityCheck(ASSOCIATED_1, ID_2); + _flagValidityCheck(ASSOCIATED_2, ID_1); + _flagValidityCheck(ASSOCIATED_2, ID_2); + } + + function _shouldContain(address associated, uint64 id) internal { + bytes32[] memory vals = handler.referenceEnumerate(associated, id); + + if (vals.length == 0) { + assertTrue(handler.referenceIsEmpty(associated, id)); + assertTrue(handler.associatedIsEmpty(associated, id)); + } else { + assertFalse(handler.referenceIsEmpty(associated, id)); + assertFalse(handler.associatedIsEmpty(associated, id)); + for (uint256 i = 0; i < vals.length; i++) { + bytes30 val = bytes30(vals[i]); + assertTrue(handler.associatedContains(associated, id, val)); + assertTrue(handler.referenceContains(associated, id, val)); + } + } + } + + function _flagValidityCheck(address associated, uint64 id) internal { + (bytes32[] memory keys, uint16[] memory metaFlags) = handler.referenceGetFlags(associated, id); + + for (uint256 i = 0; i < keys.length; i++) { + bytes30 key = bytes30(keys[i]); + uint16 metaFlag = metaFlags[i]; + assertEq(handler.associatedGetFlags(associated, id, key), metaFlag); + } + } + + function _getAllEquivalence(address associated, uint64 id) internal { + bytes32[] memory referenceEnumerate = handler.referenceEnumerate(associated, id); + bytes32[] memory associatedEnumerate = handler.associatedEnumerate(associated, id); + + assertTrue(referenceEnumerate.length == associatedEnumerate.length); + + for (uint256 i = 0; i < referenceEnumerate.length; i++) { + assertTrue(_contains(associatedEnumerate, referenceEnumerate[i])); + } + + for (uint256 i = 0; i < associatedEnumerate.length; i++) { + assertTrue(_contains(referenceEnumerate, associatedEnumerate[i])); + } + } + + function _contains(bytes32[] memory arr, bytes32 val) internal pure returns (bool) { + for (uint256 i = 0; i < arr.length; i++) { + if (arr[i] == val) { + return true; + } + } + return false; + } +} diff --git a/test/invariant/LLSLRepro.t.sol b/test/invariant/LLSLRepro.t.sol new file mode 100644 index 00000000..7102f87a --- /dev/null +++ b/test/invariant/LLSLRepro.t.sol @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {LinkedListSetHandler} from "./handlers/LinkedListSetHandler.sol"; +import {AssociatedLinkedListSetHandler} from "./handlers/AssociatedLinkedListSetHandler.sol"; + +contract LLSLReproTest is Test { + LinkedListSetHandler public handler; + AssociatedLinkedListSetHandler public associatedHandler; + + function setUp() public { + handler = new LinkedListSetHandler(); + associatedHandler = new AssociatedLinkedListSetHandler(); + } + + function test_repro_1() public { + handler.removeRandKeyIterate(0); + handler.add(0xeeeb07e4676e566803e52fe9a102d0fe0c0ae5007215518bffb33d6c07e2); + handler.removeRandKnownPrevKey( + 0xeeeb07e4676e566803e52fe9a102d0fe0c0ae5007215518bffb33d6c07e2, + 0x0000000000000000000000000000000000000000000000000000000000001b01 + ); + } + + function test_repro_2() public { + associatedHandler.removeRandKeyIterate(0, 0, 0); + associatedHandler.add(0xeeeb07e4676e566803e52fe9a102d0fe0c0ae5007215518bffb33d6c07e2, 0, 0); + associatedHandler.removeRandKnownPrevKey( + 0xeeeb07e4676e566803e52fe9a102d0fe0c0ae5007215518bffb33d6c07e2, + 0x0000000000000000000000000000000000000000000000000000000000001b01, + 0, + 0 + ); + } +} diff --git a/test/invariant/LinkedListSetLibInvariants.t.sol b/test/invariant/LinkedListSetLibInvariants.t.sol new file mode 100644 index 00000000..8041e155 --- /dev/null +++ b/test/invariant/LinkedListSetLibInvariants.t.sol @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {LinkedListSetHandler} from "./handlers/LinkedListSetHandler.sol"; + +contract LinkedListSetLibInvariantsTest is Test { + LinkedListSetHandler public handler; + + function setUp() public { + handler = new LinkedListSetHandler(); + + bytes4[] memory selectors = new bytes4[](8); + selectors[0] = handler.add.selector; + selectors[1] = handler.removeIterate.selector; + selectors[2] = handler.removeRandKeyIterate.selector; + selectors[3] = handler.clear.selector; + selectors[4] = handler.removeKnownPrevKey.selector; + selectors[5] = handler.removeRandKnownPrevKey.selector; + selectors[6] = handler.addFlagKnown.selector; + selectors[7] = handler.addFlagRandom.selector; + + targetSelector(FuzzSelector({addr: address(handler), selectors: selectors})); + } + + function invariant_shouldContain() public { + bytes32[] memory vals = handler.referenceEnumerate(); + + if (vals.length == 0) { + assertTrue(handler.referenceIsEmpty()); + assertTrue(handler.libIsEmpty()); + } else { + assertFalse(handler.referenceIsEmpty()); + assertFalse(handler.libIsEmpty()); + for (uint256 i = 0; i < vals.length; i++) { + bytes30 val = bytes30(vals[i]); + assertTrue(handler.libContains(val)); + assertTrue(handler.referenceContains(val)); + } + } + } + + // Doesn't test for no duplicates yet + function invariant_getAllEquivalence() public { + bytes32[] memory referenceEnumerate = handler.referenceEnumerate(); + bytes32[] memory libEnumerate = handler.libEnumerate(); + + assertTrue(referenceEnumerate.length == libEnumerate.length); + + for (uint256 i = 0; i < referenceEnumerate.length; i++) { + assertTrue(_contains(libEnumerate, referenceEnumerate[i])); + } + + for (uint256 i = 0; i < libEnumerate.length; i++) { + assertTrue(_contains(referenceEnumerate, libEnumerate[i])); + } + } + + function invariant_flagValidity() public { + (bytes32[] memory keys, uint16[] memory metaFlags) = handler.referenceGetFlags(); + + for (uint256 i = 0; i < keys.length; i++) { + bytes30 key = bytes30(keys[i]); + uint16 metaFlag = metaFlags[i]; + assertEq(handler.libGetFlags(key), metaFlag); + } + } + + function _contains(bytes32[] memory arr, bytes32 val) internal pure returns (bool) { + for (uint256 i = 0; i < arr.length; i++) { + if (arr[i] == val) { + return true; + } + } + return false; + } +} diff --git a/test/invariant/handlers/AssociatedLinkedListSetHandler.sol b/test/invariant/handlers/AssociatedLinkedListSetHandler.sol new file mode 100644 index 00000000..00a1e916 --- /dev/null +++ b/test/invariant/handlers/AssociatedLinkedListSetHandler.sol @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {CommonBase} from "forge-std/Base.sol"; +import {StdCheats} from "forge-std/StdCheats.sol"; +import {StdUtils} from "forge-std/StdUtils.sol"; + +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; +import { + AssociatedLinkedListSet, + AssociatedLinkedListSetLib +} from "../../../src/libraries/AssociatedLinkedListSetLib.sol"; +import {SetValue} from "../../../src/libraries/LinkedListSetUtils.sol"; + +/// @notice A handler contract for differential invariant testing AssociatedLinkedListSetLib +/// This contract maps logic for adding, removeing, clearing, and inspecting a list +/// to a reference implementation using EnumerableSet.Bytes32Set, which the invariant +/// fuzzer can then use to test the library. +contract AssociatedLinkedListSetHandler is CommonBase, StdCheats, StdUtils { + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableMap for EnumerableMap.Bytes32ToUintMap; + + mapping(address => mapping(uint64 => EnumerableSet.Bytes32Set)) internal _referenceSets; + mapping(address => mapping(uint64 => EnumerableMap.Bytes32ToUintMap)) internal _referenceMeta; + + error FailedToAdd(bytes30 value); + error FailedToAddFlags(uint16 value); + error FailedToGetFlags(uint16 expected, uint16 actual); + error BadAddFlags(bytes30 value, uint16 flags); + error FailedToRemove(bytes30 value); + error ShouldNotRemove(bytes30 value); + error ContainsNotExpected(bytes30 value); + error DoesNotContain(bytes30 value); + error LengthMismatch(uint256 expected, uint256 actual); + + address public constant ASSOCIATED_1 = address(uint160(bytes20(keccak256("ASSOCIATED_1")))); + address public constant ASSOCIATED_2 = address(uint160(bytes20(keccak256("ASSOCIATED_2")))); + + AssociatedLinkedListSet public set1; + AssociatedLinkedListSet public set2; + + uint64 public constant ID_1 = 42; + uint64 public constant ID_2 = 115557777; + + bytes32 internal constant SENTINEL_VALUE = bytes32(uint256(1)); + + constructor() {} + + /// @notice Adds to both copies of the list - the associated one and the reference one + function add(bytes30 val, uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + if (referenceSet.contains(bytes32(val)) || val == bytes30(0)) { + return; // Silently do nothing + } + + referenceSet.add(bytes32(val)); + + bool success = associatedSet.tryAdd(associated, SetValue.wrap(val)); + if (!success) { + revert FailedToAdd(val); + } + } + + /// @notice Removes a key from both sets by its index in the reference implementation. + /// Uses the O(n) iterating remove method. + function removeIterate(uint256 indexToRemove, uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + indexToRemove = bound(indexToRemove, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(indexToRemove)); + + referenceSet.remove(bytes32(value)); + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(value)); + + if (!associatedSet.tryRemove(associated, SetValue.wrap(value))) { + revert FailedToRemove(value); + } + + if (associatedSet.contains(associated, SetValue.wrap(value))) { + revert ContainsNotExpected(value); + } + } + + /// @notice Removes a key from both sets by its index in the reference implementation. + /// Accepts an arbitrary value to attempt to remove that may or may not be in the list. + /// Uses the O(n) iterating remove method. + function removeRandKeyIterate(bytes30 val, uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + if (!referenceSet.contains(bytes32(val))) { + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + if (associatedSet.tryRemove(associated, SetValue.wrap(val))) { + revert ShouldNotRemove(val); + } + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + return; // short-circuit after making assertions. + } + + referenceSet.remove(bytes32(val)); + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(val)); + + if (!associatedSet.tryRemove(associated, SetValue.wrap(val))) { + revert FailedToRemove(val); + } + + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + } + + /// @notice Removes a key by looking up it's predecessor via getAll before submitting the call + /// Uses the O(1) remove method that has knowledge of the previous key. + function removeKnownPrevKey(uint256 index, uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + index = bound(index, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(index)); + + referenceSet.remove(bytes32(value)); + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(value)); + + // Get the previous entry via getAll + SetValue[] memory values = associatedSet.getAll(associated); + if (values.length == 0) { + revert LengthMismatch(0, values.length); + } + + bytes32 prev; + for (uint256 i = 0; i < values.length; i++) { + if (SetValue.unwrap(values[i]) == bytes30(value)) { + if (i == 0) { + prev = SENTINEL_VALUE; + } else { + prev = bytes32(SetValue.unwrap(values[i - 1])); + } + break; + } + } + + if (prev == bytes32(0)) { + revert DoesNotContain(value); + } + + if (!associatedSet.tryRemoveKnown(associated, SetValue.wrap(value), prev)) { + revert FailedToRemove(value); + } + + if (associatedSet.contains(associated, SetValue.wrap(value))) { + revert ContainsNotExpected(value); + } + } + + /// @notice Removes a key using the O(1) remove method that has knowledge of the previous key. + /// Accepts an arbitrary value for the remove and for prev that may or may not be in the list. + function removeRandKnownPrevKey(bytes30 val, bytes32 prev, uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + if (!referenceSet.contains(bytes32(val))) { + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + if (associatedSet.tryRemoveKnown(associated, SetValue.wrap(val), prev)) { + revert ShouldNotRemove(val); + } + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + return; // short-circuit after making assertions. + } + + // Check to see in case it is actually the previous key + SetValue[] memory values = associatedSet.getAll(associated); + if (values.length == 0) { + revert LengthMismatch(0, values.length); + } + bytes32 realPrev; + for (uint256 i = 0; i < values.length; i++) { + if (SetValue.unwrap(values[i]) == bytes30(val)) { + if (i == 0) { + realPrev = SENTINEL_VALUE; + } else { + realPrev = bytes32(SetValue.unwrap(values[i - 1])); + } + break; + } + } + + // Clear the flags of prev to avoid any false test failures. This is ONLY safe to do if the library + // function also performs this clear, otherwise it will result in untested edge cases. + // This clearing is done after the prior check in the case where the value is not contained, to broaden the + // test cases. + prev = AssociatedLinkedListSetLib.clearFlags(prev); + + if (realPrev != prev) { + if (associatedSet.tryRemoveKnown(associated, SetValue.wrap(val), prev)) { + revert ShouldNotRemove(val); + } + return; // short-circuit after making assertions. + } else { + // Somehow, the invariant fuzzer actually generated a real prev value. Process the removal + referenceSet.remove(bytes32(val)); + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(val)); + + if (!associatedSet.tryRemoveKnown(associated, SetValue.wrap(val), prev)) { + revert FailedToRemove(val); + } + if (associatedSet.contains(associated, SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + } + } + + /// @notice Clears both copies of the list - the associated one and the reference one. + function clear(uint256 seedAddr, uint256 seedId) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + while (referenceSet.length() > 0) { + bytes30 value = bytes30(referenceSet.at(0)); + referenceSet.remove(bytes32(value)); + referenceMeta.remove(bytes32(value)); + } + + associatedSet.clear(associated); + } + + function addFlagKnown(uint256 seedAddr, uint256 seedId, uint256 indexToFlag, uint16 flags) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + flags &= 0xFFFC; // Clear the last two bits + + indexToFlag = bound(indexToFlag, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(indexToFlag)); + + if (!associatedSet.trySetFlags(associated, SetValue.wrap(value), flags)) { + revert FailedToAddFlags(flags); + } + + uint16 returnedFlags = associatedSet.getFlags(associated, SetValue.wrap(value)); + if (returnedFlags != flags) { + revert FailedToGetFlags(flags, returnedFlags); + } + + // Add this entry to the reference set. + referenceMeta.set(bytes32(value), flags); + } + + function addFlagRandom(uint256 seedAddr, uint256 seedId, bytes30 key, uint16 flags) external { + AssociatedLinkedListSet storage associatedSet = seedId % 2 == 0 ? set1 : set2; + uint64 id = uint64(seedId % 2 == 0 ? ID_1 : ID_2); + address associated = address(seedAddr % 2 == 0 ? ASSOCIATED_1 : ASSOCIATED_2); + + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + flags &= 0xFFFC; // Clear the last two bits + + if (!referenceSet.contains(bytes32(key))) { + if (associatedSet.trySetFlags(associated, SetValue.wrap(key), flags)) { + revert BadAddFlags(key, flags); + } + } else { + // The value actually exists, add the flags correctly + if (!associatedSet.trySetFlags(associated, SetValue.wrap(key), flags)) { + revert FailedToAddFlags(flags); + } + + uint16 returnedFlags = associatedSet.getFlags(associated, SetValue.wrap(key)); + if (returnedFlags != flags) { + revert FailedToGetFlags(flags, returnedFlags); + } + + // Add this entry to the reference set. + referenceMeta.set(bytes32(key), flags); + } + } + + /// @notice Checks if the associated set contains a value + function associatedContains(address associated, uint64 id, bytes30 val) external view returns (bool) { + AssociatedLinkedListSet storage associatedSet = _mapIdToSet(id); + + return associatedSet.contains(associated, SetValue.wrap(val)); + } + + /// @notice Checks if the reference set contains a value + function referenceContains(address associated, uint64 id, bytes30 val) external view returns (bool) { + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + return referenceSet.contains(bytes32(val)); + } + + /// @notice Checks if the associated set is empty + function associatedIsEmpty(address associated, uint64 id) external view returns (bool) { + AssociatedLinkedListSet storage associatedSet = _mapIdToSet(id); + return associatedSet.isEmpty(associated); + } + + /// @notice Checks if the reference set is empty + function referenceIsEmpty(address associated, uint64 id) external view returns (bool) { + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + return referenceSet.length() == 0; + } + + /// @notice Gets all contents of the reference set + function referenceEnumerate(address associated, uint64 id) external view returns (bytes32[] memory ret) { + EnumerableSet.Bytes32Set storage referenceSet = _referenceSets[associated][id]; + ret = new bytes32[](referenceSet.length()); + for (uint256 i = 0; i < referenceSet.length(); i++) { + ret[i] = referenceSet.at(i); + } + } + + /// @notice Gets all contents of the associated set + function associatedEnumerate(address associated, uint64 id) external view returns (bytes32[] memory ret) { + AssociatedLinkedListSet storage set = _mapIdToSet(id); + SetValue[] memory values = set.getAll(associated); + // Unsafe cast lol + assembly ("memory-safe") { + ret := values + } + } + + function referenceGetFlags(address associated, uint64 id) + external + view + returns (bytes32[] memory keys, uint16[] memory metas) + { + EnumerableMap.Bytes32ToUintMap storage referenceMeta = _referenceMeta[associated][id]; + + keys = new bytes32[](referenceMeta.length()); + metas = new uint16[](referenceMeta.length()); + + for (uint256 i = 0; i < referenceMeta.length(); i++) { + (bytes32 key, uint256 meta) = referenceMeta.at(i); + keys[i] = key; + metas[i] = uint16(meta); + } + } + + function associatedGetFlags(address associated, uint64 id, bytes30 key) external view returns (uint16 meta) { + AssociatedLinkedListSet storage associatedSet = _mapIdToSet(id); + meta = associatedSet.getFlags(associated, SetValue.wrap(key)); + } + + function _mapIdToSet(uint64 id) private view returns (AssociatedLinkedListSet storage associatedSet) { + if (id == ID_1) { + associatedSet = set1; + } else if (id == ID_2) { + associatedSet = set2; + } else { + revert("Invalid id"); + } + } +} diff --git a/test/invariant/handlers/LinkedListSetHandler.sol b/test/invariant/handlers/LinkedListSetHandler.sol new file mode 100644 index 00000000..526e8c41 --- /dev/null +++ b/test/invariant/handlers/LinkedListSetHandler.sol @@ -0,0 +1,335 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {CommonBase} from "forge-std/Base.sol"; +import {StdCheats} from "forge-std/StdCheats.sol"; +import {StdUtils} from "forge-std/StdUtils.sol"; + +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {EnumerableMap} from "@openzeppelin/contracts/utils/structs/EnumerableMap.sol"; +import {LinkedListSetLib, LinkedListSet as EnumerableSetType} from "../../../src/libraries/LinkedListSetLib.sol"; +import {SetValue} from "../../../src/libraries/LinkedListSetUtils.sol"; + +/// @notice A handler contract for differential invariant testing LinkedListSetLib +/// This contract maps logic for adding, removeing, clearing, and inspecting a list +/// to a reference implementation using EnumerableSet.Bytes32Set, which the invariant +/// fuzzer can then use to test the library. +contract LinkedListSetHandler is CommonBase, StdCheats, StdUtils { + using LinkedListSetLib for EnumerableSetType; + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableMap for EnumerableMap.Bytes32ToUintMap; + + EnumerableSet.Bytes32Set internal referenceSet; + EnumerableMap.Bytes32ToUintMap internal referenceMeta; + + EnumerableSetType internal libSet; + + error FailedToAdd(bytes30 value); + error FailedToAddFlags(uint16 value); + error FailedToGetFlags(uint16 expected, uint16 actual); + error BadAddFlags(bytes30 value, uint16 flags); + error FailedToRemove(bytes30 value); + error ShouldNotRemove(bytes30 value); + error ContainsNotExpected(bytes30 value); + error DoesNotContain(bytes30 value); + error LengthMismatch(uint256 expected, uint256 actual); + error MetaDoesNotContain(bytes30 value); + + bytes32 internal constant SENTINEL_VALUE = bytes32(uint256(1)); + + constructor() {} + + /// @notice Adds to both copies of the list - the library one and the reference one + function add(bytes30 val) external { + if (referenceSet.contains(bytes32(val)) || val == bytes30(0)) { + return; // Silently do nothing + } + + referenceSet.add(bytes32(val)); + + bool success = libSet.tryAdd(SetValue.wrap(val)); + if (!success) { + revert FailedToAdd(val); + } + } + + /// @notice Removes a key from both sets by its index in the reference implementation. + /// Uses the O(n) iterating remove method. + function removeIterate(uint256 indexToRemove) external { + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + indexToRemove = bound(indexToRemove, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(indexToRemove)); + + // Assert the value was in the reference set and is now removed. + if (!referenceSet.remove(bytes32(value))) { + revert DoesNotContain(value); + } + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(value)); + + if (!libSet.tryRemove(SetValue.wrap(value))) { + revert FailedToRemove(value); + } + + if (libSet.contains(SetValue.wrap(value))) { + revert ContainsNotExpected(value); + } + } + + /// @notice Removes a key from both sets. + /// Accepts an arbitrary value to attempt to remove that may or may not be in the list. + /// Uses the O(n) iterating remove method. + function removeRandKeyIterate(bytes30 val) external { + if (!referenceSet.contains(bytes32(val))) { + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + if (libSet.tryRemove(SetValue.wrap(val))) { + revert ShouldNotRemove(val); + } + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + return; // short-circuit after making assertions. + } + + if (!referenceSet.remove(bytes32(val))) { + revert DoesNotContain(val); + } + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(val)); + + if (!libSet.tryRemove(SetValue.wrap(val))) { + revert FailedToRemove(val); + } + + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + } + + /// @notice Removes a key by looking up it's predecessor via getAll before submitting the call + /// Uses the O(1) remove method that has knowledge of the previous key. + function removeKnownPrevKey(uint256 index) external { + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + index = bound(index, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(index)); + + referenceSet.remove(bytes32(value)); + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(value)); + + // Get the previous entry via getAll + SetValue[] memory values = libSet.getAll(); + if (values.length == 0) { + revert LengthMismatch(0, values.length); + } + + bytes32 prev; + for (uint256 i = 0; i < values.length; i++) { + if (SetValue.unwrap(values[i]) == bytes30(value)) { + if (i == 0) { + prev = SENTINEL_VALUE; + } else { + prev = bytes32(SetValue.unwrap(values[i - 1])); + } + break; + } + } + + if (prev == bytes32(0)) { + revert DoesNotContain(value); + } + + if (!libSet.tryRemoveKnown(SetValue.wrap(value), prev)) { + revert FailedToRemove(value); + } + + if (libSet.contains(SetValue.wrap(value))) { + revert ContainsNotExpected(value); + } + } + + /// @notice Removes a key using the O(1) remove method that has knowledge of the previous key. + /// Accepts an arbitrary value for the remove and for prev that may or may not be in the list. + function removeRandKnownPrevKey(bytes30 val, bytes32 prev) external { + if (!referenceSet.contains(bytes32(val))) { + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + if (libSet.tryRemoveKnown(SetValue.wrap(val), prev)) { + revert ShouldNotRemove(val); + } + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + return; // short-circuit after making assertions. + } + + // Check to see in case it is actually the previous key + SetValue[] memory values = libSet.getAll(); + if (values.length == 0) { + revert LengthMismatch(0, values.length); + } + bytes32 realPrev; + for (uint256 i = 0; i < values.length; i++) { + if (SetValue.unwrap(values[i]) == bytes30(val)) { + if (i == 0) { + realPrev = SENTINEL_VALUE; + } else { + realPrev = bytes32(SetValue.unwrap(values[i - 1])); + } + break; + } + } + + // Clear the flags of prev to avoid any false test failures. This is ONLY safe to do if the library + // function also performs this clear, otherwise it will result in untested edge cases. + // This clearing is done after the prior check in the case where the value is not contained, to broaden the + // test cases. + prev = LinkedListSetLib.clearFlags(prev); + + if (realPrev != prev) { + if (libSet.tryRemoveKnown(SetValue.wrap(val), prev)) { + revert ShouldNotRemove(val); + } + return; // short-circuit after making assertions. + } else { + // Somehow, the invariant fuzzer actually generated a real prev value. Process the removal + if (!referenceSet.remove(bytes32(val))) { + revert DoesNotContain(val); + } + + // Remove the meta entry if it exists + referenceMeta.remove(bytes32(val)); + + if (!libSet.tryRemoveKnown(SetValue.wrap(val), prev)) { + revert FailedToRemove(val); + } + if (libSet.contains(SetValue.wrap(val))) { + revert ContainsNotExpected(val); + } + } + } + + /// @notice Clears both copies of the list - the library one and the reference one. + function clear() external { + while (referenceSet.length() > 0) { + bytes30 value = bytes30(referenceSet.at(0)); + referenceSet.remove(bytes32(value)); + referenceMeta.remove(bytes32(value)); + } + + libSet.clear(); + } + + function addFlagKnown(uint256 indexToFlag, uint16 flags) external { + if (referenceSet.length() == 0) { + return; // Silently do nothing + } + + flags &= 0xFFFC; // Clear the last two bits + + indexToFlag = bound(indexToFlag, 0, referenceSet.length() - 1); + + bytes30 value = bytes30(referenceSet.at(indexToFlag)); + + if (!libSet.trySetFlags(SetValue.wrap(value), flags)) { + revert FailedToAddFlags(flags); + } + + uint16 returnedFlags = libSet.getFlags(SetValue.wrap(value)); + if (returnedFlags != flags) { + revert FailedToGetFlags(flags, returnedFlags); + } + + // Add this entry to the reference set. + referenceMeta.set(bytes32(value), flags); + } + + function addFlagRandom(bytes30 key, uint16 flags) external { + flags &= 0xFFFC; // Clear the last two bits + + if (!referenceSet.contains(bytes32(key))) { + if (libSet.trySetFlags(SetValue.wrap(key), flags)) { + revert BadAddFlags(key, flags); + } + } else { + // The value actually exists, add the flags correctly + if (!libSet.trySetFlags(SetValue.wrap(key), flags)) { + revert FailedToAddFlags(flags); + } + + uint16 returnedFlags = libSet.getFlags(SetValue.wrap(key)); + if (returnedFlags != flags) { + revert FailedToGetFlags(flags, returnedFlags); + } + + // Add this entry to the reference set. + referenceMeta.set(bytes32(key), flags); + } + } + + /// @notice Checks if the library set contains a value + function libContains(bytes30 val) external view returns (bool) { + return libSet.contains(SetValue.wrap(val)); + } + + /// @notice Checks if the reference set contains a value + function referenceContains(bytes30 val) external view returns (bool) { + return referenceSet.contains(bytes32(val)); + } + + /// @notice Checks if the library set is empty + function libIsEmpty() external view returns (bool) { + return libSet.isEmpty(); + } + + /// @notice Checks if the reference set is empty + function referenceIsEmpty() external view returns (bool) { + return referenceSet.length() == 0; + } + + /// @notice Gets all contents of the reference set + function referenceEnumerate() external view returns (bytes32[] memory ret) { + ret = new bytes32[](referenceSet.length()); + for (uint256 i = 0; i < referenceSet.length(); i++) { + ret[i] = referenceSet.at(i); + } + } + + /// @notice Gets all contents of the library set + function libEnumerate() external view returns (bytes32[] memory ret) { + SetValue[] memory values = libSet.getAll(); + // Unsafe cast lol + assembly ("memory-safe") { + ret := values + } + } + + function referenceGetFlags() external view returns (bytes32[] memory keys, uint16[] memory metas) { + keys = new bytes32[](referenceMeta.length()); + metas = new uint16[](referenceMeta.length()); + + for (uint256 i = 0; i < referenceMeta.length(); i++) { + (bytes32 key, uint256 meta) = referenceMeta.at(i); + keys[i] = key; + metas[i] = uint16(meta); + } + } + + function libGetFlags(bytes30 key) external view returns (uint16 meta) { + meta = libSet.getFlags(SetValue.wrap(key)); + } +} diff --git a/test/libraries/AccountStorage.t.sol b/test/libraries/AccountStorage.t.sol new file mode 100644 index 00000000..ae91dc34 --- /dev/null +++ b/test/libraries/AccountStorage.t.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; +import {AccountStorageV1} from "../../src/libraries/AccountStorageV1.sol"; +import {AccountStorageInitializable} from "../../src/account/AccountStorageInitializable.sol"; +import {MockDiamondStorageContract} from "../mocks/MockDiamondStorageContract.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; + +// Test implementation of AccountStorageInitializable which is contained in UpgradeableModularAccount +contract AccountStorageTest is Test, AccountStorageV1 { + MockDiamondStorageContract public impl; + address public proxy; + + function setUp() external { + impl = new MockDiamondStorageContract(); + proxy = address(new ERC1967Proxy(address(impl), "")); + } + + function test_storageSlotErc7201Formula() external { + bytes32 expected = keccak256( + abi.encode(uint256(keccak256("Alchemy.UpgradeableModularAccount.Storage_V1")) - 1) + ) & ~bytes32(uint256(0xff)); + assertEq(_V1_STORAGE_SLOT, expected); + } + + function test_storageSlotImpl() external { + // disable initializers sets value to uint8(max) + assertEq(uint256(vm.load(address(impl), _V1_STORAGE_SLOT)), type(uint8).max); + + // should revert if we try to initialize again + vm.expectRevert(AccountStorageInitializable.AlreadyInitialized.selector); + impl.initialize(); + } + + function test_storageSlotProxy() external { + // before init, proxy's slot should be empty + assertEq(uint256(vm.load(proxy, _V1_STORAGE_SLOT)), uint256(0)); + + MockDiamondStorageContract(proxy).initialize(); + // post init slot should contains: packed(uint8 initialized = 1, bool initializing = 0) + assertEq(uint256(vm.load(proxy, _V1_STORAGE_SLOT)), uint256(1)); + } + + function testFuzz_permittedCallKey(address addr, bytes4 selector) public { + bytes24 key = _getPermittedCallKey(addr, selector); + assertEq(bytes20(addr), bytes20(key)); + assertEq(bytes4(selector), bytes4(key << 160)); + } +} diff --git a/test/libraries/AssociatedLinkedListSetLib.t.sol b/test/libraries/AssociatedLinkedListSetLib.t.sol new file mode 100644 index 00000000..ef8ed6a4 --- /dev/null +++ b/test/libraries/AssociatedLinkedListSetLib.t.sol @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import { + AssociatedLinkedListSet, + AssociatedLinkedListSetLib +} from "../../src/libraries/AssociatedLinkedListSetLib.sol"; +import {SetValue, SENTINEL_VALUE} from "../../src/libraries/LinkedListSetUtils.sol"; + +contract AssociatedLinkedListSetLibTest is Test { + using AssociatedLinkedListSetLib for AssociatedLinkedListSet; + + AssociatedLinkedListSet internal _set1; + AssociatedLinkedListSet internal _set2; + + address internal _associated = address(this); + + // User-defined function for wrapping from bytes30 (uint240) to SetValue + // Can define a custom one for addresses, uints, etc. + function _getListValue(uint240 value) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(value)); + } + + // A lot of these tests were auto-generated by copilot and manually inspected. + // In addition to these tests, there are also invariant tests in + // test/invariant/AssociatedLinkedListSetLibInvariants.t.sol + + function test_add_contains() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_remove() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + } + + function test_remove_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.tryRemove(_associated, value)); + } + + function test_remove_nonexistent() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set1.tryRemove(_associated, value2)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_remove_nonexistent_empty() public { + SetValue value = _getListValue(12); + + assertFalse(_set1.tryRemove(_associated, value)); + } + + function test_remove_nonexistent_empty2() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set1.tryRemove(_associated, value2)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_add_remove_add() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_add_remove_add_empty() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemove(_associated, value)); + assertFalse(_set1.contains(_associated, value)); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + } + + function test_no_address_collision() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + assertFalse(_set2.contains(_associated, value)); + } + + function test_clear() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + _set1.clear(_associated); + + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_getAll() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.tryAdd(_associated, value2)); + + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 2); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + } + + function test_getAll2() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + SetValue value3 = _getListValue(14); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.tryAdd(_associated, value2)); + assertTrue(_set1.tryAdd(_associated, value3)); + + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 3); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[2]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value2)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value3)); + } + + function test_getAll_empty() public { + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 0); + } + + function test_tryRemoveKnown1() public { + SetValue value = _getListValue(12); + + assertTrue(_set1.tryAdd(_associated, value)); + assertTrue(_set1.contains(_associated, value)); + + assertTrue(_set1.tryRemoveKnown(_associated, value, SENTINEL_VALUE)); + assertFalse(_set1.contains(_associated, value)); + assertTrue(_set1.isEmpty(_associated)); + } + + function test_tryRemoveKnown2() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value1)); + assertTrue(_set1.tryAdd(_associated, value2)); + assertTrue(_set1.contains(_associated, value1)); + assertTrue(_set1.contains(_associated, value2)); + + // Assert that getAll returns the correct values + SetValue[] memory values = _set1.getAll(_associated); + assertEq(values.length, 2); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value1)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set1.tryRemoveKnown(_associated, value1, bytes32(SetValue.unwrap(value2)))); + assertFalse(_set1.contains(_associated, value1)); + assertTrue(_set1.contains(_associated, value2)); + + // Assert that getAll returns the correct values + values = _set1.getAll(_associated); + assertEq(values.length, 1); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set1.tryRemoveKnown(_associated, value2, SENTINEL_VALUE)); + assertFalse(_set1.contains(_associated, value1)); + + assertTrue(_set1.isEmpty(_associated)); + } + + function test_tryRemoveKnown_invalid1() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set1.tryAdd(_associated, value1)); + assertTrue(_set1.tryAdd(_associated, value2)); + + assertFalse(_set1.tryRemoveKnown(_associated, value1, bytes32(SetValue.unwrap(value1)))); + assertTrue(_set1.contains(_associated, value1)); + + assertFalse(_set1.tryRemoveKnown(_associated, value2, bytes32(SetValue.unwrap(value2)))); + assertTrue(_set1.contains(_associated, value2)); + } +} diff --git a/test/libraries/CountableLinkedListSetLib.t.sol b/test/libraries/CountableLinkedListSetLib.t.sol new file mode 100644 index 00000000..d5f8bc9f --- /dev/null +++ b/test/libraries/CountableLinkedListSetLib.t.sol @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {CountableLinkedListSetLib} from "../../src/libraries/CountableLinkedListSetLib.sol"; +import {LinkedListSet, LinkedListSetLib} from "../../src/libraries/LinkedListSetLib.sol"; +import {SetValue} from "../../src/libraries/LinkedListSetUtils.sol"; + +contract CountableLinkedListSetLibTest is Test { + using LinkedListSetLib for LinkedListSet; + using CountableLinkedListSetLib for LinkedListSet; + + LinkedListSet internal _set; + + uint16 internal constant _MAX_COUNTER_VALUE = 255; + + // User-defined function for wrapping from bytes30 (uint240) to SetValue + // Can define a custom one for addresses, uints, etc. + function _getListValue(uint240 value) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(value)); + } + + function test_getCount() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertEq(_set.getCount(value), 1); + _set.tryEnableFlags(value, 0xFF00); + assertEq(_set.getCount(value), 256); + } + + function test_tryIncrement() public { + SetValue value = _getListValue(12); + assertEq(_set.getCount(value), 0); + + for (uint256 i = 0; i < _MAX_COUNTER_VALUE + 1; ++i) { + assertTrue(_set.tryIncrement(value)); + assertEq(_set.getCount(value), i + 1); + } + + assertFalse(_set.tryIncrement(value)); + assertEq(_set.getCount(value), 256); + + assertTrue(_set.contains(value)); + assertFalse(_set.tryAdd(value)); + } + + function test_tryDecrement() public { + SetValue value = _getListValue(12); + assertEq(_set.getCount(value), 0); + assertFalse(_set.tryDecrement(value)); + + for (uint256 i = 0; i < _MAX_COUNTER_VALUE + 1; ++i) { + _set.tryIncrement(value); + } + + for (uint256 i = _MAX_COUNTER_VALUE + 1; i > 0; --i) { + assertTrue(_set.tryDecrement(value)); + assertEq(_set.getCount(value), i - 1); + } + + assertFalse(_set.tryDecrement(value)); + assertEq(_set.getCount(value), 0); + + assertFalse(_set.contains(value)); + assertFalse(_set.tryRemove(value)); + } +} diff --git a/test/libraries/FunctionReferenceLib.t.sol b/test/libraries/FunctionReferenceLib.t.sol new file mode 100644 index 00000000..c71a2bff --- /dev/null +++ b/test/libraries/FunctionReferenceLib.t.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../src/libraries/FunctionReferenceLib.sol"; + +contract FunctionReferenceLibTest is Test { + function testFuzz_functionReference_packing(address addr, uint8 functionId) public { + // console.log("addr: ", addr); + // console.log("functionId: ", vm.toString(functionId)); + FunctionReference fr = FunctionReferenceLib.pack(addr, functionId); + // console.log("packed: ", vm.toString(FunctionReference.unwrap(fr))); + (address addr2, uint8 functionId2) = FunctionReferenceLib.unpack(fr); + // console.log("addr2: ", addr2); + // console.log("functionId2: ", vm.toString(functionId2)); + assertEq(addr, addr2); + assertEq(functionId, functionId2); + } + + function testFuzz_functionReference_operators(FunctionReference a, FunctionReference b) public { + assertTrue(a == a); + assertTrue(b == b); + + if (FunctionReference.unwrap(a) == FunctionReference.unwrap(b)) { + assertTrue(a == b); + assertTrue(b == a); + assertFalse(a != b); + assertFalse(b != a); + } else { + assertTrue(a != b); + assertTrue(b != a); + assertFalse(a == b); + assertFalse(b == a); + } + } +} diff --git a/test/libraries/LinkedListSetLib.t.sol b/test/libraries/LinkedListSetLib.t.sol new file mode 100644 index 00000000..e6b2df4e --- /dev/null +++ b/test/libraries/LinkedListSetLib.t.sol @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {LinkedListSet, LinkedListSetLib} from "../../src/libraries/LinkedListSetLib.sol"; +import {SetValue, SENTINEL_VALUE} from "../../src/libraries/LinkedListSetUtils.sol"; + +// Ported over from test/AssociatedLinkedListSetLib.t.sol, dropping test_no_address_collision +contract LinkedListSetLibTest is Test { + using LinkedListSetLib for LinkedListSet; + + LinkedListSet internal _set; + + // User-defined function for wrapping from bytes30 (uint240) to SetValue + // Can define a custom one for addresses, uints, etc. + function _getListValue(uint240 value) internal pure returns (SetValue) { + return SetValue.wrap(bytes30(value)); + } + + // A lot of these tests were auto-generated by copilot and manually inspected. + // In addition to these tests, there are also invariant tests in + // test/invariant/LinkedListSetLibInvariants.t.sol + + function test_add_contains() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + } + + function test_empty() public { + SetValue value = _getListValue(12); + assertFalse(_set.contains(value)); + assertTrue(_set.isEmpty()); + } + + function test_remove() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + assertTrue(_set.tryRemove(value)); + assertFalse(_set.contains(value)); + } + + function test_remove_empty() public { + SetValue value = _getListValue(12); + assertFalse(_set.tryRemove(value)); + } + + function test_remove_nonexistent() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set.tryRemove(value2)); + assertTrue(_set.contains(value)); + } + + function test_remove_nonexistent_empty() public { + SetValue value = _getListValue(12); + assertFalse(_set.tryRemove(value)); + } + + function test_remove_nonexistent_empty2() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + SetValue value2 = _getListValue(13); + assertFalse(_set.tryRemove(value2)); + assertTrue(_set.contains(value)); + } + + function test_add_remove_add() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + assertTrue(_set.tryRemove(value)); + assertFalse(_set.contains(value)); + + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + } + + function test_add_remove_add_empty() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + assertTrue(_set.tryRemove(value)); + assertFalse(_set.contains(value)); + + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + } + + function test_clear() public { + SetValue value = _getListValue(12); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + _set.clear(); + + assertFalse(_set.contains(value)); + assertTrue(_set.isEmpty()); + } + + function test_getAll() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.tryAdd(value2)); + + SetValue[] memory values = _set.getAll(); + assertEq(values.length, 2); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + } + + function test_getAll2() public { + SetValue value = _getListValue(12); + SetValue value2 = _getListValue(13); + SetValue value3 = _getListValue(14); + assertTrue(_set.tryAdd(value)); + assertTrue(_set.tryAdd(value2)); + assertTrue(_set.tryAdd(value3)); + + SetValue[] memory values = _set.getAll(); + assertEq(values.length, 3); + // Returned set will be in reverse order of added elements + assertEq(SetValue.unwrap(values[2]), SetValue.unwrap(value)); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value2)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value3)); + } + + function test_getAll_empty() public { + SetValue[] memory values = _set.getAll(); + assertEq(values.length, 0); + } + + function test_tryRemoveKnown1() public { + SetValue value = _getListValue(12); + + assertTrue(_set.tryAdd(value)); + assertTrue(_set.contains(value)); + + assertTrue(_set.tryRemoveKnown(value, SENTINEL_VALUE)); + assertFalse(_set.contains(value)); + assertTrue(_set.isEmpty()); + } + + function test_tryRemoveKnown2() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set.tryAdd(value1)); + assertTrue(_set.tryAdd(value2)); + assertTrue(_set.contains(value1)); + assertTrue(_set.contains(value2)); + + // Assert that getAll returns the correct values + SetValue[] memory values = _set.getAll(); + assertEq(values.length, 2); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value1)); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set.tryRemoveKnown(value1, bytes32(SetValue.unwrap(value2)))); + assertFalse(_set.contains(value1)); + assertTrue(_set.contains(value2)); + + // Assert that getAll returns the correct values + values = _set.getAll(); + assertEq(values.length, 1); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + + assertTrue(_set.tryRemoveKnown(value2, SENTINEL_VALUE)); + assertFalse(_set.contains(value1)); + + assertTrue(_set.isEmpty()); + } + + function test_tryRemoveKnown_invalid1() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set.tryAdd(value1)); + assertTrue(_set.tryAdd(value2)); + + assertFalse(_set.tryRemoveKnown(value1, bytes32(SetValue.unwrap(value1)))); + assertTrue(_set.contains(value1)); + + assertFalse(_set.tryRemoveKnown(value2, bytes32(SetValue.unwrap(value2)))); + assertTrue(_set.contains(value2)); + } + + function test_isSentinel() public { + bytes32 val1 = bytes32(uint256(0)); + assertFalse(LinkedListSetLib.isSentinel(val1)); + + bytes32 val2 = bytes32(uint256(1)); + assertTrue(LinkedListSetLib.isSentinel(val2)); + + bytes32 val3 = bytes32(uint256(3)); + assertTrue(LinkedListSetLib.isSentinel(val3)); + + bytes32 val4 = bytes32(uint256(2)); + assertFalse(LinkedListSetLib.isSentinel(val4)); + } + + function test_userFlags_fail_does_not_contain() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set.tryAdd(value1)); + assertTrue(_set.tryAdd(value2)); + + assertFalse(_set.trySetFlags(_getListValue(14), uint8(0xF0))); + } + + function test_userFlags_basic() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set.tryAdd(value1)); + assertTrue(_set.tryAdd(value2)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + assertTrue(_set.trySetFlags(value2, uint8(0x0C))); + + assertEq(_set.getFlags(value1), uint8(0xF0)); + assertEq(_set.getFlags(value2), uint8(0x0C)); + } + + function test_userFlags_getAll() public { + SetValue value1 = _getListValue(12); + SetValue value2 = _getListValue(13); + + assertTrue(_set.tryAdd(value1)); + assertTrue(_set.tryAdd(value2)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + assertTrue(_set.trySetFlags(value2, uint8(0x0C))); + + SetValue[] memory values = _set.getAll(); + assertEq(values.length, 2); + assertEq(SetValue.unwrap(values[0]), SetValue.unwrap(value2)); + assertEq(SetValue.unwrap(values[1]), SetValue.unwrap(value1)); + + assertEq(_set.getFlags(values[0]), uint8(0x0C)); + assertEq(_set.getFlags(values[1]), uint8(0xF0)); + } + + function test_userFlags_tryEnable() public { + SetValue value1 = _getListValue(12); + + assertTrue(_set.tryAdd(value1)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + assertTrue(_set.tryEnableFlags(value1, uint8(0x0C))); + + assertEq(_set.getFlags(value1), uint8(0xFC)); + } + + function test_userFlags_tryDisable() public { + SetValue value1 = _getListValue(12); + + assertTrue(_set.tryAdd(value1)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + assertTrue(_set.tryDisableFlags(value1, uint8(0xC0))); + + assertEq(_set.getFlags(value1), uint8(0x30)); + } + + function test_userFlags_flagsEnabled() public { + SetValue value1 = _getListValue(12); + + assertTrue(_set.tryAdd(value1)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + + assertTrue(_set.flagsEnabled(value1, uint8(0x80))); + assertTrue(_set.flagsEnabled(value1, uint8(0xC0))); + assertFalse(_set.flagsEnabled(value1, uint8(0x0C))); + } + + function test_userFlags_flagsDisabled() public { + SetValue value1 = _getListValue(12); + + assertTrue(_set.tryAdd(value1)); + + assertTrue(_set.trySetFlags(value1, uint8(0xF0))); + + assertFalse(_set.flagsDisabled(value1, uint8(0x80))); + assertFalse(_set.flagsDisabled(value1, uint8(0xC0))); + assertTrue(_set.flagsDisabled(value1, uint8(0x0C))); + } +} diff --git a/test/libraries/PluginStorageLib.t.sol b/test/libraries/PluginStorageLib.t.sol new file mode 100644 index 00000000..e9226ad7 --- /dev/null +++ b/test/libraries/PluginStorageLib.t.sol @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {PluginStorageLib, StoragePointer} from "../../src/libraries/PluginStorageLib.sol"; + +contract PluginStorageLibTest is Test { + using PluginStorageLib for bytes; + using PluginStorageLib for bytes32; + + uint256 public constant FUZZ_ARR_SIZE = 32; + + address public account1; + + struct TestStruct { + uint256 a; + uint256 b; + } + + function setUp() public { + account1 = makeAddr("account1"); + } + + function test_storagePointer() public { + bytes memory key = PluginStorageLib.allocateAssociatedStorageKey(account1, 0, 1); + + StoragePointer ptr = PluginStorageLib.associatedStorageLookup( + key, hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ); + TestStruct storage val = _castPtrToStruct(ptr); + + vm.record(); + val.a = 0xdeadbeef; + val.b = 123; + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + + // printStorageReadsAndWrites(address(this)); + + assertEq(accountWrites.length, 2); + bytes32 expectedKey = keccak256( + abi.encodePacked( + uint256(uint160(account1)), + uint256(0), + hex"00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff" + ) + ); + assertEq(accountWrites[0], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(0xdeadbeef))); + assertEq(accountWrites[1], bytes32(uint256(expectedKey) + 1)); + assertEq(vm.load(address(this), bytes32(uint256(expectedKey) + 1)), bytes32(uint256(123))); + } + + function testFuzz_storagePointer( + address account, + uint256 batchIndex, + bytes32 inputKey, + uint256[FUZZ_ARR_SIZE] calldata values + ) public { + bytes memory key = PluginStorageLib.allocateAssociatedStorageKey(account, batchIndex, 1); + uint256[FUZZ_ARR_SIZE] storage val = + _castPtrToArray(PluginStorageLib.associatedStorageLookup(key, inputKey)); + // Write values to storage + vm.record(); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + val[i] = values[i]; + } + // Assert the writes took place in the right location, and the correct value is stored there + (, bytes32[] memory accountWrites) = vm.accesses(address(this)); + assertEq(accountWrites.length, FUZZ_ARR_SIZE); + for (uint256 i = 0; i < FUZZ_ARR_SIZE; i++) { + bytes32 expectedKey = bytes32( + uint256(keccak256(abi.encodePacked(uint256(uint160(account)), uint256(batchIndex), inputKey))) + i + ); + assertEq(accountWrites[i], expectedKey); + assertEq(vm.load(address(this), expectedKey), bytes32(uint256(values[i]))); + } + } + + function _castPtrToArray(StoragePointer ptr) internal pure returns (uint256[FUZZ_ARR_SIZE] storage val) { + assembly ("memory-safe") { + val.slot := ptr + } + } + + function _castPtrToStruct(StoragePointer ptr) internal pure returns (TestStruct storage val) { + assembly ("memory-safe") { + val.slot := ptr + } + } +} diff --git a/test/mocks/ContractOwner.sol b/test/mocks/ContractOwner.sol new file mode 100644 index 00000000..f29530ac --- /dev/null +++ b/test/mocks/ContractOwner.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; + +contract ContractOwner is IERC1271 { + bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; + + function sign(bytes32 digest) public pure returns (bytes memory) { + return abi.encodePacked("Signed: ", digest); + } + + function isValidSignature(bytes32 digest, bytes memory signature) public pure override returns (bytes4) { + if (keccak256(signature) == keccak256(sign(digest))) { + return _1271_MAGIC_VALUE; + } + return 0xffffffff; + } +} diff --git a/test/mocks/Counter.sol b/test/mocks/Counter.sol new file mode 100644 index 00000000..114403aa --- /dev/null +++ b/test/mocks/Counter.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +/// @title A public counter for anyone to use. +contract Counter { + uint256 public number; + + /// @notice Set the counter's number to a new value. + /// @param newNumber The new number for the counter. + function setNumber(uint256 newNumber) public { + number = newNumber; + } + + /// @notice Increase the counter's value by one. + /// @dev The number is not in an unchecked block, so overflows will revert. + function increment() public { + number++; + } +} diff --git a/test/mocks/Counter.t.sol b/test/mocks/Counter.t.sol new file mode 100644 index 00000000..7c63181b --- /dev/null +++ b/test/mocks/Counter.t.sol @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; +import {Counter} from "./Counter.sol"; + +contract CounterTest is Test { + Counter public counter; + + function setUp() public { + counter = new Counter(); + counter.setNumber(0); + } + + function testIncrement() public { + counter.increment(); + assertEq(counter.number(), 1); + } + + function testSetNumber(uint256 x) public { + counter.setNumber(x); + assertEq(counter.number(), x); + } +} diff --git a/test/mocks/MockDiamondStorageContract.sol b/test/mocks/MockDiamondStorageContract.sol new file mode 100644 index 00000000..76e47dac --- /dev/null +++ b/test/mocks/MockDiamondStorageContract.sol @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {AccountStorageInitializable} from "../../src/account/AccountStorageInitializable.sol"; + +contract MockDiamondStorageContract is AccountStorageInitializable { + constructor() { + _disableInitializers(); + } + + // solhint-disable-next-line no-empty-blocks + function initialize() external initializer {} +} diff --git a/test/mocks/MockPlugin.sol b/test/mocks/MockPlugin.sol new file mode 100644 index 00000000..ae4e263c --- /dev/null +++ b/test/mocks/MockPlugin.sol @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC165} from "@openzeppelin/contracts/utils/introspection/ERC165.sol"; + +import {PluginManifest, PluginMetadata, IPlugin} from "../../src/interfaces/IPlugin.sol"; + +contract MockPlugin is ERC165 { + // It's super inefficient to hold the entire abi-encoded manifest in storage, but this is fine since it's + // just a mock. Note that the reason we do this is to allow copying the entire contents of the manifest + // into storage in a single line, since solidity fails to compile with memory -> storage copying of nested + // dynamic types when compiling without `via-ir` in the lite profile. + // See the error code below: + // Error: Unimplemented feature (/solidity/libsolidity/codegen/ArrayUtils.cpp:228):Copying of type + // struct ManifestAssociatedFunction memory[] memory to storage not yet supported. + bytes internal _manifest; + + string internal constant _NAME = "Mock Plugin Modifiable"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + event ReceivedCall(bytes msgData, uint256 msgValue); + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + constructor(PluginManifest memory _pluginManifest) { + _manifest = abi.encode(_pluginManifest); + } + + // solhint-disable-next-line no-empty-blocks + function foo() public {} + + function _getManifest() internal view returns (PluginManifest memory) { + PluginManifest memory m = abi.decode(_manifest, (PluginManifest)); + return m; + } + + function _castToPure(function() internal view returns (PluginManifest memory) fnIn) + internal + pure + returns (function() internal pure returns (PluginManifest memory) fnOut) + { + assembly { + fnOut := fnIn + } + } + + function pluginManifest() external pure returns (PluginManifest memory) { + return _castToPure(_getManifest)(); + } + + function pluginMetadata() external pure returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } + + /// @dev Returns true if this contract implements the interface defined by + /// `interfaceId`. See the corresponding + /// https://eips.ethereum.org/EIPS/eip-165#how-interfaces-are-identified[EIP section] + /// to learn more about how these ids are created. + /// + /// This function call must use less than 30 000 gas. + /// + /// Supporting the IPlugin interface is a requirement for plugin installation. This is also used + /// by the modular account to prevent standard execution functions `execute` and `executeBatch` from + /// making calls to plugins. + /// @param interfaceId The interface ID to check for support. + /// @return True if the contract supports `interfaceId`. + function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) { + return (interfaceId != 0xffffffff); + } + + /// @dev Hardcode the pre execution hook to return the functionId, which will be passed to the post execution + /// hook. + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata) + external + returns (bytes memory) + { + emit ReceivedCall(msg.data, 0); + return abi.encode(functionId); + } + + receive() external payable {} + + // solhint-disable-next-line no-complex-fallback + fallback() external payable { + emit ReceivedCall(msg.data, msg.value); + if ( + msg.sig == IPlugin.userOpValidationFunction.selector + || msg.sig == IPlugin.runtimeValidationFunction.selector + || msg.sig == IPlugin.preUserOpValidationHook.selector + ) { + // return 0 for userOpVal/runtimeVal/preUserOpValidationHook + assembly { + return(0x00, 0x20) + } + } + } +} diff --git a/test/mocks/plugins/BadTransferOwnershipPlugin.sol b/test/mocks/plugins/BadTransferOwnershipPlugin.sol new file mode 100644 index 00000000..9c8dc3b8 --- /dev/null +++ b/test/mocks/plugins/BadTransferOwnershipPlugin.sol @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import { + ManifestExecutionHook, + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata +} from "../../../src/interfaces/IPlugin.sol"; +import {IPluginManager} from "../../../src/interfaces/IPluginManager.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; +import {IMultiOwnerPlugin} from "../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {IPluginExecutor} from "../../../src/interfaces/IPluginExecutor.sol"; + +contract BadTransferOwnershipPlugin is BaseTestPlugin { + string internal constant _NAME = "Evil Transfer Ownership Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function evilTransferOwnership(address target) external { + address[] memory owners = new address[](1); + owners[0] = target; + IPluginExecutor(msg.sender).executeFromPlugin( + abi.encodeCall(IMultiOwnerPlugin.updateOwners, (owners, new address[](0))) + ); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.evilTransferOwnership.selector; + + manifest.permittedExecutionSelectors = new bytes4[](1); + manifest.permittedExecutionSelectors[0] = IMultiOwnerPlugin.updateOwners.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.evilTransferOwnership.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } + + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } +} diff --git a/test/mocks/plugins/BaseTestPlugin.sol b/test/mocks/plugins/BaseTestPlugin.sol new file mode 100644 index 00000000..8b28a0f0 --- /dev/null +++ b/test/mocks/plugins/BaseTestPlugin.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; +import {PluginMetadata} from "../../../src/interfaces/IPlugin.sol"; + +contract BaseTestPlugin is BasePlugin { + // Don't need to implement this in each context + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + revert NotImplemented(); + } +} diff --git a/test/mocks/plugins/ChangingManifestPlugin.sol b/test/mocks/plugins/ChangingManifestPlugin.sol new file mode 100644 index 00000000..ae7e1be1 --- /dev/null +++ b/test/mocks/plugins/ChangingManifestPlugin.sol @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol"; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata +} from "../../../src/interfaces/IPlugin.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; + +contract CanChangeManifestPluginFactory { + function newPlugin() external returns (CanChangeManifestPlugin) { + return + CanChangeManifestPlugin(address(new ERC1967Proxy(address(new CanChangeManifestPlugin()), bytes("")))); + } +} + +contract CanChangeManifestPlugin is BasePlugin, UUPSUpgradeable { + string internal constant _NAME = "CanChangeManifestPlugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + function someExecutionFunction() external {} + + function changeManifest() external { + _upgradeTo(address(new DidChangeManifestPlugin())); + } + + function onInstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory manifest) { + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.someExecutionFunction.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.someExecutionFunction.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }) + }); + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } + + function _authorizeUpgrade(address) internal virtual override {} +} + +contract DidChangeManifestPlugin is BasePlugin { + string internal constant _NAME = "DidChangeManifestPlugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + function someExecutionFunction() external {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory manifest) {} + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } +} diff --git a/test/mocks/plugins/ComprehensivePlugin.sol b/test/mocks/plugins/ComprehensivePlugin.sol new file mode 100644 index 00000000..16029de7 --- /dev/null +++ b/test/mocks/plugins/ComprehensivePlugin.sol @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "../../../src/interfaces/erc4337/UserOperation.sol"; +import { + ManifestExecutionHook, + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata +} from "../../../src/interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; + +contract ComprehensivePlugin is BaseTestPlugin { + enum FunctionId { + PRE_USER_OP_VALIDATION_HOOK_1, + PRE_USER_OP_VALIDATION_HOOK_2, + USER_OP_VALIDATION, + PRE_RUNTIME_VALIDATION_HOOK_1, + PRE_RUNTIME_VALIDATION_HOOK_2, + RUNTIME_VALIDATION, + PRE_EXECUTION_HOOK, + PRE_PERMITTED_CALL_EXECUTION_HOOK, + POST_EXECUTION_HOOK, + POST_PERMITTED_CALL_EXECUTION_HOOK + } + + string internal constant _NAME = "Comprehensive Plugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function foo() external {} + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function preUserOpValidationHook(uint8 functionId, UserOperation calldata, bytes32) + external + pure + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1)) { + return 0; + } else if (functionId == uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_2)) { + return 0; + } + revert NotImplemented(); + } + + function userOpValidationFunction(uint8 functionId, UserOperation calldata, bytes32) + external + pure + override + returns (uint256) + { + if (functionId == uint8(FunctionId.USER_OP_VALIDATION)) { + return 0; + } + revert NotImplemented(); + } + + function preRuntimeValidationHook(uint8 functionId, address, uint256, bytes calldata) external pure override { + if (functionId == uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_1)) { + return; + } else if (functionId == uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_2)) { + return; + } + revert NotImplemented(); + } + + function runtimeValidationFunction(uint8 functionId, address, uint256, bytes calldata) + external + pure + override + { + if (functionId == uint8(FunctionId.RUNTIME_VALIDATION)) { + return; + } + revert NotImplemented(); + } + + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata) + external + pure + override + returns (bytes memory) + { + if (functionId == uint8(FunctionId.PRE_EXECUTION_HOOK)) { + return ""; + } else if (functionId == uint8(FunctionId.PRE_PERMITTED_CALL_EXECUTION_HOOK)) { + return ""; + } + revert NotImplemented(); + } + + function postExecutionHook(uint8 functionId, bytes calldata) external pure override { + if (functionId == uint8(FunctionId.POST_EXECUTION_HOOK)) { + return; + } else if (functionId == uint8(FunctionId.POST_PERMITTED_CALL_EXECUTION_HOOK)) { + return; + } + revert NotImplemented(); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + ManifestFunction memory fooUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION), + dependencyIndex: 0 // Unused. + }); + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: fooUserOpValidationFunction + }); + + ManifestFunction memory fooRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.RUNTIME_VALIDATION), + dependencyIndex: 0 // Unused. + }); + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: fooRuntimeValidationFunction + }); + + manifest.preUserOpValidationHooks = new ManifestAssociatedFunction[](4); + manifest.preUserOpValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preUserOpValidationHooks[1] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_2), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preUserOpValidationHooks[2] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preUserOpValidationHooks[3] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_2), + dependencyIndex: 0 // Unused. + }) + }); + + manifest.preRuntimeValidationHooks = new ManifestAssociatedFunction[](4); + manifest.preRuntimeValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preRuntimeValidationHooks[1] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_2), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preRuntimeValidationHooks[2] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preRuntimeValidationHooks[3] = ManifestAssociatedFunction({ + executionSelector: IStandardExecutor.execute.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_RUNTIME_VALIDATION_HOOK_2), + dependencyIndex: 0 // Unused. + }) + }); + + manifest.executionHooks = new ManifestExecutionHook[](2); + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.POST_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }) + }); + manifest.executionHooks[1] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.POST_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }) + }); + + manifest.permittedCallHooks = new ManifestExecutionHook[](2); + manifest.permittedCallHooks[0] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_PERMITTED_CALL_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.POST_PERMITTED_CALL_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }) + }); + manifest.permittedCallHooks[1] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.POST_PERMITTED_CALL_EXECUTION_HOOK), + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } + + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + return metadata; + } +} diff --git a/test/mocks/plugins/ExecFromPluginPermissionsMocks.sol b/test/mocks/plugins/ExecFromPluginPermissionsMocks.sol new file mode 100644 index 00000000..73b00888 --- /dev/null +++ b/test/mocks/plugins/ExecFromPluginPermissionsMocks.sol @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + ManifestExternalCallPermission, + ManifestExecutionHook, + PluginManifest +} from "../../../src/interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {IPluginExecutor} from "../../../src/interfaces/IPluginExecutor.sol"; +import {IPlugin} from "../../../src/interfaces/IPlugin.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; +import {FunctionReference} from "../../../src/libraries/FunctionReferenceLib.sol"; + +import {ResultCreatorPlugin} from "./ReturnDataPluginMocks.sol"; +import {Counter} from "../Counter.sol"; + +// Hardcode the counter addresses from ExecuteFromPluginPermissionsTest to be able to have a pure plugin manifest +// easily +address constant counter1 = 0x5615dEB798BB3E4dFa0139dFa1b3D433Cc23b72f; +address constant counter2 = 0x2e234DAe75C793f67A35089C9d99245E1C58470b; +address constant counter3 = 0xF62849F9A0B5Bf2913b396098F7c7019b51A820a; + +contract EFPCallerPlugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](12); + manifest.executionFunctions[0] = this.useEFPPermissionAllowed.selector; + manifest.executionFunctions[1] = this.useEFPPermissionNotAllowed.selector; + manifest.executionFunctions[2] = this.passthroughExecuteFromPlugin.selector; + manifest.executionFunctions[3] = this.setNumberCounter1.selector; + manifest.executionFunctions[4] = this.getNumberCounter1.selector; + manifest.executionFunctions[5] = this.incrementCounter1.selector; + manifest.executionFunctions[6] = this.setNumberCounter2.selector; + manifest.executionFunctions[7] = this.getNumberCounter2.selector; + manifest.executionFunctions[8] = this.incrementCounter2.selector; + manifest.executionFunctions[9] = this.setNumberCounter3.selector; + manifest.executionFunctions[10] = this.getNumberCounter3.selector; + manifest.executionFunctions[11] = this.incrementCounter3.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](12); + + ManifestFunction memory alwaysAllowValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }); + + for (uint256 i = 0; i < manifest.executionFunctions.length; i++) { + manifest.runtimeValidationFunctions[i] = ManifestAssociatedFunction({ + executionSelector: manifest.executionFunctions[i], + associatedFunction: alwaysAllowValidationFunction + }); + } + + // Request permission for "foo" and the non-existent selector "baz", but not "bar", from + // ResultCreatorPlugin + manifest.permittedExecutionSelectors = new bytes4[](2); + manifest.permittedExecutionSelectors[0] = ResultCreatorPlugin.foo.selector; + manifest.permittedExecutionSelectors[1] = bytes4(keccak256("baz()")); + + // Request permission for: + // - `setNumber` and `number` on counter 1 + // - All selectors on counter 2 + // - None on counter 3 + manifest.permittedExternalCalls = new ManifestExternalCallPermission[](2); + + bytes4[] memory selectorsCounter1 = new bytes4[](2); + selectorsCounter1[0] = Counter.setNumber.selector; + selectorsCounter1[1] = bytes4(keccak256("number()")); // Public vars don't automatically get exported + // selectors + + manifest.permittedExternalCalls[0] = ManifestExternalCallPermission({ + externalAddress: counter1, + permitAnySelector: false, + selectors: selectorsCounter1 + }); + + manifest.permittedExternalCalls[1] = ManifestExternalCallPermission({ + externalAddress: counter2, + permitAnySelector: true, + selectors: new bytes4[](0) + }); + + return manifest; + } + + // The manifest requested access to use the plugin-defined method "foo" + function useEFPPermissionAllowed() external returns (bytes memory) { + return IPluginExecutor(msg.sender).executeFromPlugin(abi.encodeCall(ResultCreatorPlugin.foo, ())); + } + + // The manifest has not requested access to use the plugin-defined method "bar", so this should revert. + function useEFPPermissionNotAllowed() external returns (bytes memory) { + return IPluginExecutor(msg.sender).executeFromPlugin(abi.encodeCall(ResultCreatorPlugin.bar, ())); + } + + function passthroughExecuteFromPlugin(bytes calldata data) external returns (bytes memory) { + return IPluginExecutor(msg.sender).executeFromPlugin(data); + } + + // Should be allowed + function setNumberCounter1(uint256 number) external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter1, 0, abi.encodeWithSelector(Counter.setNumber.selector, number) + ); + } + + // Should be allowed + function getNumberCounter1() external returns (uint256) { + bytes memory returnData = IPluginExecutor(msg.sender).executeFromPluginExternal( + counter1, 0, abi.encodePacked(bytes4(keccak256("number()"))) + ); + + return abi.decode(returnData, (uint256)); + } + + // Should not be allowed + function incrementCounter1() external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter1, 0, abi.encodeWithSelector(Counter.increment.selector) + ); + } + + // Should be allowed + function setNumberCounter2(uint256 number) external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter2, 0, abi.encodeWithSelector(Counter.setNumber.selector, number) + ); + } + + // Should be allowed + function getNumberCounter2() external returns (uint256) { + bytes memory returnData = IPluginExecutor(msg.sender).executeFromPluginExternal( + counter2, 0, abi.encodePacked(bytes4(keccak256("number()"))) + ); + + return abi.decode(returnData, (uint256)); + } + + // Should be allowed + function incrementCounter2() external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter2, 0, abi.encodeWithSelector(Counter.increment.selector) + ); + } + + // Should not be allowed + function setNumberCounter3(uint256 number) external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter3, 0, abi.encodeWithSelector(Counter.setNumber.selector, number) + ); + } + + // Should not be allowed + function getNumberCounter3() external returns (uint256) { + bytes memory returnData = IPluginExecutor(msg.sender).executeFromPluginExternal( + counter3, 0, abi.encodePacked(bytes4(keccak256("number()"))) + ); + + return abi.decode(returnData, (uint256)); + } + + // Should not be allowed + function incrementCounter3() external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter3, 0, abi.encodeWithSelector(Counter.increment.selector) + ); + } +} + +contract EFPCallerPluginAnyExternal is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](2); + manifest.executionFunctions[0] = this.passthroughExecute.selector; + manifest.executionFunctions[1] = this.passthroughExecuteWith1Eth.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](2); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.passthroughExecute.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.passthroughExecuteWith1Eth.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permitAnyExternalAddress = true; + + return manifest; + } + + function passthroughExecute(address target, uint256 value, bytes calldata data) + external + payable + returns (bytes memory) + { + return IPluginExecutor(msg.sender).executeFromPluginExternal(target, value, data); + } + + function passthroughExecuteWith1Eth(address target, uint256 value, bytes calldata data) + external + payable + returns (bytes memory) + { + return IPluginExecutor(msg.sender).executeFromPluginExternal{value: 1 ether}(target, value, data); + } +} + +contract EFPCallerPluginAnyExternalCanSpendNativeToken is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.passthroughExecuteWithNativeTokenSpendPermission.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.passthroughExecuteWithNativeTokenSpendPermission.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.canSpendNativeToken = true; + manifest.permitAnyExternalAddress = true; + + return manifest; + } + + function passthroughExecuteWithNativeTokenSpendPermission(address target, uint256 value, bytes calldata data) + external + payable + returns (bytes memory) + { + return IPluginExecutor(msg.sender).executeFromPluginExternal(target, value, data); + } +} + +// Create pre and post permitted call hooks for calling ResultCreatorPlugin.foo via `executeFromPlugin` +contract EFPPermittedCallHookPlugin is BaseTestPlugin { + bool public preExecHookCalled; + bool public postExecHookCalled; + + function preExecutionHook(uint8, address, uint256, bytes calldata) external override returns (bytes memory) { + preExecHookCalled = true; + return "context for post exec hook"; + } + + function postExecutionHook(uint8, bytes calldata preExecHookData) external override { + require( + keccak256(preExecHookData) == keccak256("context for post exec hook"), "Invalid pre exec hook data" + ); + postExecHookCalled = true; + } + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.performEFPCall.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.performEFPCall.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permittedCallHooks = new ManifestExecutionHook[](1); + manifest.permittedCallHooks[0] = ManifestExecutionHook({ + executionSelector: ResultCreatorPlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permittedExecutionSelectors = new bytes4[](1); + manifest.permittedExecutionSelectors[0] = ResultCreatorPlugin.foo.selector; + + return manifest; + } + + function performEFPCall() external returns (bytes memory) { + return IPluginExecutor(msg.sender).executeFromPlugin(abi.encodeCall(ResultCreatorPlugin.foo, ())); + } +} + +// Creates pre and post permitted call hooks for `executeFromPluginExternal` +contract EFPExternalPermittedCallHookPlugin is BaseTestPlugin { + bool public preExecHookCalled; + bool public postExecHookCalled; + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function preExecutionHook(uint8, address, uint256, bytes calldata) external override returns (bytes memory) { + preExecHookCalled = true; + return "context for post exec hook"; + } + + function postExecutionHook(uint8, bytes calldata preExecHookData) external override { + require( + keccak256(preExecHookData) == keccak256("context for post exec hook"), "Invalid pre exec hook data" + ); + postExecHookCalled = true; + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.performIncrement.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.performIncrement.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permittedCallHooks = new ManifestExecutionHook[](1); + manifest.permittedCallHooks[0] = ManifestExecutionHook({ + executionSelector: IPluginExecutor.executeFromPluginExternal.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permitAnyExternalAddress = true; + + return manifest; + } + + function performIncrement() external { + IPluginExecutor(msg.sender).executeFromPluginExternal( + counter1, 0, abi.encodeWithSelector(Counter.increment.selector) + ); + } +} + +// Create pre and post execution hooks for calling ResultCreatorPlugin.foo, and add a function that calls it via +// `executeFromPlugin` +contract EFPExecutionHookPlugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function preExecutionHook(uint8 functionId, address, uint256, bytes calldata) + external + pure + override + returns (bytes memory) + { + return abi.encode(functionId); + } + + function postExecutionHook(uint8, bytes calldata) external pure override { + return; + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.performEFPCallWithExecHooks.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.performEFPCallWithExecHooks.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.executionHooks = new ManifestExecutionHook[](2); + // Pre and post hook + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: ResultCreatorPlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 1, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + // Post only hook + manifest.executionHooks[1] = ManifestExecutionHook({ + executionSelector: ResultCreatorPlugin.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.NONE, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 2, + dependencyIndex: 0 + }) + }); + + manifest.permittedExecutionSelectors = new bytes4[](1); + manifest.permittedExecutionSelectors[0] = ResultCreatorPlugin.foo.selector; + + return manifest; + } + + function performEFPCallWithExecHooks() external returns (bytes memory) { + return IPluginExecutor(msg.sender).executeFromPlugin(abi.encodeCall(ResultCreatorPlugin.foo, ())); + } +} diff --git a/test/mocks/plugins/ManifestValidityMocks.sol b/test/mocks/plugins/ManifestValidityMocks.sol new file mode 100644 index 00000000..9353a002 --- /dev/null +++ b/test/mocks/plugins/ManifestValidityMocks.sol @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + ManifestExecutionHook, + ManifestExternalCallPermission, + PluginManifest +} from "../../../src/interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {IPluginExecutor} from "../../../src/interfaces/IPluginExecutor.sol"; +import {IPlugin} from "../../../src/interfaces/IPlugin.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; +import {FunctionReference} from "../../../src/libraries/FunctionReferenceLib.sol"; + +contract BadValidationMagicValue_UserOp_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + // Illegal assignment: validation always allow only usable on runtime validation functions + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadValidationMagicValue_PreRuntimeValidationHook_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.preRuntimeValidationHooks = new ManifestAssociatedFunction[](1); + + // Illegal assignment: validation always allow only usable on runtime validation functions + manifest.preRuntimeValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadValidationMagicValue_PreUserOpValidationHook_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.preUserOpValidationHooks = new ManifestAssociatedFunction[](1); + + // Illegal assignment: validation always allow only usable on runtime validation functions + manifest.preUserOpValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadValidationMagicValue_PreExecHook_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.executionHooks = new ManifestExecutionHook[](1); + + // Illegal assignment: validation always allow only usable on runtime validation functions + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, // Dummy unimplemented function id, but can be added correctly + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadValidationMagicValue_PostExecHook_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.executionHooks = new ManifestExecutionHook[](1); + + // Illegal assignment: validation always allow only usable on runtime validation functions + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, // Dummy unimplemented function id, but can be added correctly + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadHookMagicValue_UserOpValidationFunction_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadHookMagicValue_RuntimeValidationFunction_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract BadHookMagicValue_PostExecHook_Plugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.executionHooks = new ManifestExecutionHook[](1); + + // Illegal assignment: hook always deny only usable on runtime validation functions + manifest.executionHooks[0] = ManifestExecutionHook({ + executionSelector: this.foo.selector, + preExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: 0, // Dummy unimplemented function id, but can be added correctly + dependencyIndex: 0 + }), + postExecHook: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} diff --git a/test/mocks/plugins/ReturnDataPluginMocks.sol b/test/mocks/plugins/ReturnDataPluginMocks.sol new file mode 100644 index 00000000..62a8ebfe --- /dev/null +++ b/test/mocks/plugins/ReturnDataPluginMocks.sol @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + ManifestExternalCallPermission, + PluginManifest +} from "../../../src/interfaces/IPlugin.sol"; +import {IStandardExecutor} from "../../../src/interfaces/IStandardExecutor.sol"; +import {IPluginExecutor} from "../../../src/interfaces/IPluginExecutor.sol"; +import {IPlugin} from "../../../src/interfaces/IPlugin.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; +import {FunctionReference} from "../../../src/libraries/FunctionReferenceLib.sol"; + +contract RegularResultContract { + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function bar() external pure returns (bytes32) { + return keccak256("foo"); + } +} + +contract ResultCreatorPlugin is BaseTestPlugin { + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function foo() external pure returns (bytes32) { + return keccak256("bar"); + } + + function bar() external pure returns (bytes32) { + return keccak256("foo"); + } + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](2); + manifest.executionFunctions[0] = this.foo.selector; + manifest.executionFunctions[1] = this.bar.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + return manifest; + } +} + +contract ResultConsumerPlugin is BaseTestPlugin { + ResultCreatorPlugin public immutable resultCreator; + RegularResultContract public immutable regularResultContract; + + constructor(ResultCreatorPlugin _resultCreator, RegularResultContract _regularResultContract) { + resultCreator = _resultCreator; + regularResultContract = _regularResultContract; + } + + // Check the return data through the executeFromPlugin fallback case + function checkResultEFPFallback(bytes32 expected) external returns (bool) { + // This result should be allowed based on the manifest permission request + IPluginExecutor(msg.sender).executeFromPlugin(abi.encodeCall(ResultCreatorPlugin.foo, ())); + + bytes32 actual = ResultCreatorPlugin(msg.sender).foo(); + + return actual == expected; + } + + // Check the rturn data through the executeFromPlugin std exec case + function checkResultEFPExternal(address target, bytes32 expected) external returns (bool) { + // This result should be allowed based on the manifest permission request + bytes memory returnData = IPluginExecutor(msg.sender).executeFromPluginExternal( + target, 0, abi.encodeCall(RegularResultContract.foo, ()) + ); + + bytes32 actual = abi.decode(returnData, (bytes32)); + + return actual == expected; + } + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function pluginManifest() external pure override returns (PluginManifest memory) { + // We want to return the address of the immutable RegularResultContract in the permitted external calls + // area of the manifest. + // However, reading from immutable values is not permitted in pure functions. So we use this hack to get + // around that. + // In regular, non-mock plugins, external call targets in the plugin manifest should be constants, not just + // immutbales. + // But to make testing easier, we do this. + + function() internal pure returns (PluginManifest memory) pureManifestGetter; + + function() internal view returns (PluginManifest memory) viewManifestGetter = _innerPluginManifest; + + assembly ("memory-safe") { + pureManifestGetter := viewManifestGetter + } + + return pureManifestGetter(); + } + + function _innerPluginManifest() internal view returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](2); + manifest.executionFunctions[0] = this.checkResultEFPFallback.selector; + manifest.executionFunctions[1] = this.checkResultEFPExternal.selector; + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](2); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.checkResultEFPFallback.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.checkResultEFPExternal.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, + dependencyIndex: 0 + }) + }); + + manifest.permittedExecutionSelectors = new bytes4[](1); + manifest.permittedExecutionSelectors[0] = ResultCreatorPlugin.foo.selector; + + manifest.permittedExternalCalls = new ManifestExternalCallPermission[](1); + + bytes4[] memory allowedSelectors = new bytes4[](1); + allowedSelectors[0] = RegularResultContract.foo.selector; + manifest.permittedExternalCalls[0] = ManifestExternalCallPermission({ + externalAddress: address(regularResultContract), + permitAnySelector: false, + selectors: allowedSelectors + }); + + return manifest; + } +} diff --git a/test/mocks/plugins/UninstallErrorsPlugin.sol b/test/mocks/plugins/UninstallErrorsPlugin.sol new file mode 100644 index 00000000..04f121d6 --- /dev/null +++ b/test/mocks/plugins/UninstallErrorsPlugin.sol @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {StorageSlot} from "@openzeppelin/contracts/utils/StorageSlot.sol"; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata +} from "../../../src/interfaces/IPlugin.sol"; +import {IPluginManager} from "../../../src/interfaces/IPluginManager.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; + +/// Mock plugin that reverts in its uninstall callbacks. Can be configured to +/// either immediately revert or to drain all remaining gas. +contract UninstallErrorsPlugin is BaseTestPlugin { + string internal constant _NAME = "UninstallErrorsPlugin"; + string internal constant _VERSION = "1.0.0"; + string internal constant _AUTHOR = "Alchemy"; + + bool private _shouldDrainGas; + + error IntentionalUninstallError(); + + constructor(bool shouldDrainGas) { + _shouldDrainGas = shouldDrainGas; + } + + function onUninstall(bytes calldata) external override { + _revert(); + } + + function onHookUnapply(address, IPluginManager.InjectedHooksInfo calldata, bytes calldata) + external + virtual + override + { + _revert(); + } + + function pluginManifest() external pure override returns (PluginManifest memory manifest) {} + + function pluginMetadata() external pure override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + + metadata.name = _NAME; + metadata.version = _VERSION; + metadata.author = _AUTHOR; + + return metadata; + } + + function _onInstall(bytes calldata) internal virtual override {} + + function _revert() private { + if (_shouldDrainGas) { + _wasteAllRemainingGas(); + } else { + revert IntentionalUninstallError(); + } + } + + function _wasteAllRemainingGas() private { + for (uint256 i = 0;; i++) { + // Say goodbye to your gas. + StorageSlot.getBooleanSlot(bytes32(i)).value = true; + } + } +} diff --git a/test/mocks/plugins/ValidationPluginMocks.sol b/test/mocks/plugins/ValidationPluginMocks.sol new file mode 100644 index 00000000..2f80add8 --- /dev/null +++ b/test/mocks/plugins/ValidationPluginMocks.sol @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {UserOperation} from "../../../src/interfaces/erc4337/UserOperation.sol"; +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest +} from "../../../src/interfaces/IPlugin.sol"; +import {BaseTestPlugin} from "./BaseTestPlugin.sol"; + +abstract contract MockBaseUserOpValidationPlugin is BaseTestPlugin { + enum FunctionId { + USER_OP_VALIDATION, + PRE_USER_OP_VALIDATION_HOOK_1, + PRE_USER_OP_VALIDATION_HOOK_2 + } + + uint256 internal _userOpValidationFunctionData; + uint256 internal _preUserOpValidationHook1Data; + uint256 internal _preUserOpValidationHook2Data; + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function onInstall(bytes calldata) external override {} + + function onUninstall(bytes calldata) external override {} + + function preUserOpValidationHook(uint8 functionId, UserOperation calldata, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1)) { + return _preUserOpValidationHook1Data; + } else if (functionId == uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_2)) { + return _preUserOpValidationHook2Data; + } + revert NotImplemented(); + } + + function userOpValidationFunction(uint8 functionId, UserOperation calldata, bytes32) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.USER_OP_VALIDATION)) { + return _userOpValidationFunctionData; + } + revert NotImplemented(); + } +} + +contract MockUserOpValidationPlugin is MockBaseUserOpValidationPlugin { + function setValidationData(uint256 userOpValidationFunctionData) external { + _userOpValidationFunctionData = userOpValidationFunctionData; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function foo() external {} + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.foo.selector; + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.foo.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION), + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } +} + +contract MockUserOpValidation1HookPlugin is MockBaseUserOpValidationPlugin { + function setValidationData(uint256 userOpValidationFunctionData, uint256 preUserOpValidationHook1Data) + external + { + _userOpValidationFunctionData = userOpValidationFunctionData; + _preUserOpValidationHook1Data = preUserOpValidationHook1Data; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function bar() external {} + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.bar.selector; + + ManifestFunction memory userOpValidationFunctionRef = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION), + dependencyIndex: 0 // Unused. + }); + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.bar.selector, + associatedFunction: userOpValidationFunctionRef + }); + + manifest.preUserOpValidationHooks = new ManifestAssociatedFunction[](1); + manifest.preUserOpValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.bar.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } +} + +contract MockUserOpValidation2HookPlugin is MockBaseUserOpValidationPlugin { + function setValidationData( + uint256 userOpValidationFunctionData, + uint256 preUserOpValidationHook1Data, + uint256 preUserOpValidationHook2Data + ) external { + _userOpValidationFunctionData = userOpValidationFunctionData; + _preUserOpValidationHook1Data = preUserOpValidationHook1Data; + _preUserOpValidationHook2Data = preUserOpValidationHook2Data; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function baz() external {} + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.baz.selector; + + ManifestFunction memory userOpValidationFunctionRef = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.USER_OP_VALIDATION), + dependencyIndex: 0 // Unused. + }); + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.baz.selector, + associatedFunction: userOpValidationFunctionRef + }); + + manifest.preUserOpValidationHooks = new ManifestAssociatedFunction[](2); + manifest.preUserOpValidationHooks[0] = ManifestAssociatedFunction({ + executionSelector: this.baz.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_1), + dependencyIndex: 0 // Unused. + }) + }); + manifest.preUserOpValidationHooks[1] = ManifestAssociatedFunction({ + executionSelector: this.baz.selector, + associatedFunction: ManifestFunction({ + functionType: ManifestAssociatedFunctionType.SELF, + functionId: uint8(FunctionId.PRE_USER_OP_VALIDATION_HOOK_2), + dependencyIndex: 0 // Unused. + }) + }); + + return manifest; + } +} diff --git a/test/mocks/tokens/MockERC1155.sol b/test/mocks/tokens/MockERC1155.sol new file mode 100644 index 00000000..82fd463d --- /dev/null +++ b/test/mocks/tokens/MockERC1155.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC1155} from "@openzeppelin/contracts/token/ERC1155/ERC1155.sol"; + +contract MockERC1155 is ERC1155 { + constructor() ERC1155("") {} + + function mint(address to, uint256 id, uint256 amount) public { + _mint(to, id, amount, ""); + } +} diff --git a/test/mocks/tokens/MockERC20.sol b/test/mocks/tokens/MockERC20.sol new file mode 100644 index 00000000..916b2b81 --- /dev/null +++ b/test/mocks/tokens/MockERC20.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + constructor(string memory name) ERC20(name, name) {} + + function mint(address to, uint256 amount) public { + _mint(to, amount); + } +} diff --git a/test/mocks/tokens/MockERC777.sol b/test/mocks/tokens/MockERC777.sol new file mode 100644 index 00000000..a6478ea4 --- /dev/null +++ b/test/mocks/tokens/MockERC777.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {IERC777} from "@openzeppelin/contracts/token/ERC777/IERC777.sol"; +import {IERC777Recipient} from "@openzeppelin/contracts/token/ERC777/IERC777Recipient.sol"; + +contract MockERC777 is IERC777 { + string public override name; + string public override symbol; + uint256 public override granularity; + uint256 public override totalSupply; + mapping(address => uint256) public override balanceOf; + + function mint(address to, uint256 amount) public { + balanceOf[to] += amount; + } + + function transfer(address to, uint256 amount) public returns (bool) { + return transferFrom(msg.sender, to, amount); + } + + function transferFrom(address from, address to, uint256 amount) public returns (bool) { + IERC777Recipient(to).tokensReceived(msg.sender, from, to, amount, bytes(""), bytes("")); + balanceOf[from] -= amount; + balanceOf[to] += amount; + return true; + } + + function send(address to, uint256 amount, bytes calldata) public override { + transferFrom(msg.sender, to, amount); + } + + function burn(uint256 amount, bytes calldata) external { + transferFrom(msg.sender, address(0), amount); + } + + function isOperatorFor(address, address) external pure returns (bool) { + return false; + } + + // solhint-disable-next-line no-empty-blocks + function authorizeOperator(address) external {} + // solhint-disable-next-line no-empty-blocks + function revokeOperator(address) external {} + // solhint-disable-next-line no-empty-blocks + function defaultOperators() external view returns (address[] memory a) {} + // solhint-disable-next-line no-empty-blocks + function operatorSend(address, address, uint256, bytes calldata, bytes calldata) external {} + // solhint-disable-next-line no-empty-blocks + function operatorBurn(address, uint256, bytes calldata, bytes calldata) external {} +} diff --git a/test/plugin/TokenReceiverPlugin.t.sol b/test/plugin/TokenReceiverPlugin.t.sol new file mode 100644 index 00000000..35b1f8b8 --- /dev/null +++ b/test/plugin/TokenReceiverPlugin.t.sol @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {IERC721Receiver} from "@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol"; +import {ERC721PresetMinterPauserAutoId} from + "@openzeppelin/contracts/token/ERC721/presets/ERC721PresetMinterPauserAutoId.sol"; +import {IERC777Recipient} from "@openzeppelin/contracts/token/ERC777/IERC777Recipient.sol"; +import {IERC1155Receiver} from "@openzeppelin/contracts/token/ERC1155/IERC1155Receiver.sol"; + +import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; +import {AccountStorageV1} from "../../src/libraries/AccountStorageV1.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; + +import {MockERC777} from "../mocks/tokens/MockERC777.sol"; +import {MockERC1155} from "../mocks/tokens/MockERC1155.sol"; +import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; + +contract TokenReceiverPluginTest is Test, IERC1155Receiver, AccountStorageV1 { + UpgradeableModularAccount public acct; + TokenReceiverPlugin public plugin; + + ERC721PresetMinterPauserAutoId public t0; + MockERC777 public t1; + MockERC1155 public t2; + MultiOwnerMSCAFactory public factory; + MultiOwnerPlugin public multiOwnerPlugin; + IEntryPoint public entryPoint; + + address public owner; + address[] public owners; + + // init dynamic length arrays for use in args + address[] public defaultOperators; + uint256[] public tokenIds; + uint256[] public tokenAmts; + uint256[] public zeroTokenAmts; + + uint256 internal constant _TOKEN_AMOUNT = 1 ether; + uint256 internal constant _TOKEN_ID = 0; + uint256 internal constant _BATCH_TOKEN_IDS = 5; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + multiOwnerPlugin = new MultiOwnerPlugin(); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(new UpgradeableModularAccount(entryPoint)), + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + (owner,) = makeAddrAndKey("owner"); + owners = new address[](1); + owners[0] = owner; + acct = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + plugin = new TokenReceiverPlugin(); + + t0 = new ERC721PresetMinterPauserAutoId("t0", "t0", ""); + t0.mint(address(this)); + + t1 = new MockERC777(); + t1.mint(address(this), _TOKEN_AMOUNT); + + t2 = new MockERC1155(); + t2.mint(address(this), _TOKEN_ID, _TOKEN_AMOUNT); + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + t2.mint(address(this), i, _TOKEN_AMOUNT); + tokenIds.push(i); + tokenAmts.push(_TOKEN_AMOUNT); + zeroTokenAmts.push(0); + } + } + + function _initPlugin() internal { + vm.startPrank(owner); + acct.installPlugin( + address(plugin), + keccak256(abi.encode(plugin.pluginManifest())), + bytes(""), + new FunctionReference[](0), + new IPluginManager.InjectedHook[](0) + ); + vm.stopPrank(); + } + + function test_failERC721Transfer() public { + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.UnrecognizedFunction.selector, IERC721Receiver.onERC721Received.selector + ) + ); + t0.safeTransferFrom(address(this), address(acct), _TOKEN_ID); + } + + function test_passERC721Transfer() public { + _initPlugin(); + assertEq(t0.ownerOf(_TOKEN_ID), address(this)); + t0.safeTransferFrom(address(this), address(acct), _TOKEN_ID); + assertEq(t0.ownerOf(_TOKEN_ID), address(acct)); + } + + function test_failERC777Transfer() public { + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.UnrecognizedFunction.selector, IERC777Recipient.tokensReceived.selector + ) + ); + t1.transfer(address(acct), _TOKEN_AMOUNT); + } + + function test_passERC777Transfer() public { + _initPlugin(); + + assertEq(t1.balanceOf(address(this)), _TOKEN_AMOUNT); + assertEq(t1.balanceOf(address(acct)), 0); + t1.transfer(address(acct), _TOKEN_AMOUNT); + assertEq(t1.balanceOf(address(this)), 0); + assertEq(t1.balanceOf(address(acct)), _TOKEN_AMOUNT); + } + + function test_failERC1155Transfer() public { + // for 1155, reverts are caught in a try catch and bubbled up with a diff reason + vm.expectRevert("ERC1155: transfer to non-ERC1155Receiver implementer"); + t2.safeTransferFrom(address(this), address(acct), _TOKEN_ID, _TOKEN_AMOUNT, ""); + + // for 1155, reverts are caught in a try catch and bubbled up with a diff reason + vm.expectRevert("ERC1155: transfer to non-ERC1155Receiver implementer"); + t2.safeBatchTransferFrom(address(this), address(acct), tokenIds, tokenAmts, ""); + } + + function test_passERC1155Transfer() public { + _initPlugin(); + + assertEq(t2.balanceOf(address(this), _TOKEN_ID), _TOKEN_AMOUNT); + assertEq(t2.balanceOf(address(acct), _TOKEN_ID), 0); + t2.safeTransferFrom(address(this), address(acct), _TOKEN_ID, _TOKEN_AMOUNT, ""); + assertEq(t2.balanceOf(address(this), _TOKEN_ID), 0); + assertEq(t2.balanceOf(address(acct), _TOKEN_ID), _TOKEN_AMOUNT); + + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + assertEq(t2.balanceOf(address(this), i), _TOKEN_AMOUNT); + assertEq(t2.balanceOf(address(acct), i), 0); + } + t2.safeBatchTransferFrom(address(this), address(acct), tokenIds, tokenAmts, ""); + for (uint256 i = 1; i < _BATCH_TOKEN_IDS; i++) { + assertEq(t2.balanceOf(address(this), i), 0); + assertEq(t2.balanceOf(address(acct), i), _TOKEN_AMOUNT); + } + } + + function test_failIntrospection() public { + bool isSupported; + + isSupported = acct.supportsInterface(type(IERC721Receiver).interfaceId); + assertEq(isSupported, false); + isSupported = acct.supportsInterface(type(IERC777Recipient).interfaceId); + assertEq(isSupported, false); + isSupported = acct.supportsInterface(type(IERC1155Receiver).interfaceId); + assertEq(isSupported, false); + } + + function test_passIntrospection() public { + _initPlugin(); + + bool isSupported; + + isSupported = acct.supportsInterface(type(IERC721Receiver).interfaceId); + assertEq(isSupported, true); + isSupported = acct.supportsInterface(type(IERC777Recipient).interfaceId); + assertEq(isSupported, true); + isSupported = acct.supportsInterface(type(IERC1155Receiver).interfaceId); + assertEq(isSupported, true); + } + + /** + * NON-TEST FUNCTIONS - USED SO MINT DOESNT FAIL + */ + function onERC1155Received(address, address, uint256, uint256, bytes calldata) + external + pure + override + returns (bytes4) + { + return IERC1155Receiver.onERC1155Received.selector; + } + + function onERC1155BatchReceived(address, address, uint256[] calldata, uint256[] calldata, bytes calldata) + external + pure + override + returns (bytes4) + { + return IERC1155Receiver.onERC1155BatchReceived.selector; + } + + function supportsInterface(bytes4) external pure override returns (bool) { + return false; + } +} diff --git a/test/plugin/owner/MultiOwnerPlugin.t.sol b/test/plugin/owner/MultiOwnerPlugin.t.sol new file mode 100644 index 00000000..e1d507de --- /dev/null +++ b/test/plugin/owner/MultiOwnerPlugin.t.sol @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; + +import {MultiOwnerPlugin} from "../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IMultiOwnerPlugin} from "../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; +import {IEntryPoint} from "../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../src/interfaces/erc4337/UserOperation.sol"; +import {PluginManifest} from "../../../src/interfaces/IPlugin.sol"; + +import {ContractOwner} from "../../mocks/ContractOwner.sol"; +import {Utils} from "../../Utils.sol"; + +contract MultiOwnerPluginTest is Test { + using ECDSA for bytes32; + + MultiOwnerPlugin public plugin; + IEntryPoint public entryPoint; + + bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; + address public accountA; + address public b; + + address public owner1; + address public owner2; + address public owner3; + ContractOwner public contractOwner; + address[] public ownerArray; + + function setUp() public { + plugin = new MultiOwnerPlugin(); + entryPoint = IEntryPoint(address(new EntryPoint())); + + accountA = address(new EntryPoint()); + b = makeAddr("b"); + owner1 = makeAddr("owner1"); + owner2 = makeAddr("owner2"); + owner3 = makeAddr("owner3"); + contractOwner = new ContractOwner(); + + // set up owners for accountA + ownerArray = new address[](3); + ownerArray[0] = owner1; + ownerArray[1] = owner2; + ownerArray[2] = owner3; + + vm.startPrank(accountA); + plugin.onInstall(abi.encode(ownerArray)); + } + + function test_pluginManifest() public { + PluginManifest memory manifest = plugin.pluginManifest(); + // 5 execution functions + assertEq(5, manifest.executionFunctions.length); + // 5 native + 1 plugin exec func + assertEq(6, manifest.userOpValidationFunctions.length); + // 5 native + 1 plugin exec func + 4 plugin view func + assertEq(10, manifest.runtimeValidationFunctions.length); + } + + function test_onUninstall_success() public { + plugin.onUninstall(abi.encode("")); + address[] memory returnedOwners = plugin.ownersOf(accountA); + assertEq(returnedOwners, plugin.owners()); + assertEq(0, returnedOwners.length); + } + + function test_onInstall_success() public { + address[] memory owners = new address[](1); + owners[0] = owner1; + + vm.startPrank(address(contractOwner)); + plugin.onInstall(abi.encode(owners)); + address[] memory returnedOwners = plugin.ownersOf(address(contractOwner)); + assertEq(returnedOwners, plugin.owners()); + assertEq(returnedOwners.length, 1); + assertEq(returnedOwners[0], owner1); + vm.stopPrank(); + } + + function test_eip712Domain() public { + assertEq(true, plugin.isOwnerOf(accountA, owner2)); + assertEq(false, plugin.isOwnerOf(accountA, address(contractOwner))); + assertEq(true, plugin.isOwner(owner2)); + assertEq(false, plugin.isOwner(address(contractOwner))); + + ( + bytes1 fields, + string memory name, + string memory version, + uint256 chainId, + address verifyingContract, + bytes32 salt, + uint256[] memory extensions + ) = plugin.eip712Domain(); + assertEq(fields, hex"0f"); + assertEq(name, "Multi Owner Plugin"); + assertEq(version, "1.0.0"); + assertEq(chainId, block.chainid); + assertEq(verifyingContract, accountA); + assertEq(salt, bytes32(0)); + assertEq(extensions.length, 0); + } + + function test_updateOwners_failWithEmptyOwners() public { + vm.expectRevert(IMultiOwnerPlugin.EmptyOwnersNotAllowed.selector); + plugin.updateOwners(new address[](0), ownerArray); + } + + function test_updateOwners_success() public { + (address[] memory owners) = plugin.ownersOf(accountA); + assertEq(owners, plugin.owners()); + assertEq(Utils.reverseAddressArray(ownerArray), owners); + + // remove should also work + address[] memory ownersToRemove = new address[](2); + ownersToRemove[0] = owner1; + ownersToRemove[1] = owner2; + + plugin.updateOwners(new address[](0), ownersToRemove); + + (address[] memory newOwnerList) = plugin.ownersOf(accountA); + assertEq(newOwnerList, plugin.owners()); + assertEq(newOwnerList.length, 1); + assertEq(newOwnerList[0], owner3); + } + + function test_updateOwners_failWithNotExist() public { + address[] memory ownersToRemove = new address[](1); + ownersToRemove[0] = address(contractOwner); + + vm.expectRevert( + abi.encodeWithSelector(IMultiOwnerPlugin.OwnerDoesNotExist.selector, address(contractOwner)) + ); + plugin.updateOwners(new address[](0), ownersToRemove); + } + + function testFuzz_isValidSignature_EOAOwner(string memory salt, bytes32 digest) public { + // range bound the possible set of priv keys + (address signer, uint256 privateKey) = makeAddrAndKey(salt); + bytes32 messageDigest = plugin.getMessageHash(address(accountA), abi.encode(digest)); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(privateKey, messageDigest); + + address[] memory ownersToAdd = new address[](1); + ownersToAdd[0] = signer; + + assertEq(plugin.isOwnerOf(accountA, signer), plugin.isOwner(signer)); + + if (!plugin.isOwnerOf(accountA, signer)) { + // sig check should fail + assertEq(bytes4(0xFFFFFFFF), plugin.isValidSignature(digest, abi.encodePacked(r, s, v))); + + plugin.updateOwners(ownersToAdd, new address[](0)); + } + + // sig check should pass + assertEq(_1271_MAGIC_VALUE, plugin.isValidSignature(digest, abi.encodePacked(r, s, v))); + } + + function testFuzz_isValidSignature_ContractOwner(bytes32 digest) public { + address[] memory ownersToAdd = new address[](1); + ownersToAdd[0] = address(contractOwner); + plugin.updateOwners(ownersToAdd, new address[](0)); + + bytes32 messageDigest = plugin.getMessageHash(address(accountA), abi.encode(digest)); + bytes memory signature = contractOwner.sign(messageDigest); + assertEq(_1271_MAGIC_VALUE, plugin.isValidSignature(digest, signature)); + } + + function test_runtimeValidationFunction_OwnerOrSelf() public { + // should pass with owner as sender + plugin.runtimeValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF), owner1, 0, "" + ); + + // should fail without owner as sender + vm.expectRevert(IMultiOwnerPlugin.NotAuthorized.selector); + plugin.runtimeValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF), address(contractOwner), 0, "" + ); + } + + function test_multiOwnerPlugin_sentinelIsNotOwner() public { + assertEq(false, plugin.isOwner(address(1))); + } + + function testFuzz_userOpValidationFunction_ContractOwner(UserOperation memory userOp) public { + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + userOp.signature = contractOwner.sign(userOpHash); + + // should fail without owner access + uint256 resFail = plugin.userOpValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER), userOp, userOpHash + ); + assertEq(resFail, 1); + + address[] memory ownersToAdd = new address[](1); + ownersToAdd[0] = address(contractOwner); + plugin.updateOwners(ownersToAdd, new address[](0)); + + // should pass with owner access + uint256 resSuccess = plugin.userOpValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER), userOp, userOpHash + ); + assertEq(resSuccess, 0); + } + + function testFuzz_userOpValidationFunction_EOAOwner(string memory salt, UserOperation memory userOp) public { + // range bound the possible set of priv keys + (address signer, uint256 privateKey) = makeAddrAndKey(salt); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(privateKey, userOpHash.toEthSignedMessageHash()); + + // sig cannot cover the whole userop struct since userop struct has sig field + userOp.signature = abi.encodePacked(r, s, v); + + address[] memory ownersToAdd = new address[](1); + ownersToAdd[0] = signer; + + assertEq(plugin.isOwnerOf(accountA, signer), plugin.isOwner(signer)); + + // Only check that the signature should fail if the signer is not already an owner + if (!plugin.isOwnerOf(accountA, signer)) { + // should fail without owner access + uint256 resFail = plugin.userOpValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER), userOp, userOpHash + ); + assertEq(resFail, 1); + // add signer to owner + plugin.updateOwners(ownersToAdd, new address[](0)); + } + + // should pass with owner access + uint256 resSuccess = plugin.userOpValidationFunction( + uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER), userOp, userOpHash + ); + assertEq(resSuccess, 0); + } + + function test_pluginInitializeGuards() public { + plugin.onUninstall(bytes("")); + + address[] memory addrArr = new address[](1); + addrArr[0] = address(this); + + // can't transfer owner if not initialized yet + vm.expectRevert(abi.encodeWithSelector(BasePlugin.NotInitialized.selector)); + plugin.updateOwners(addrArr, new address[](0)); + + // can't oninstall twice + plugin.onInstall(abi.encode(addrArr, new address[](0))); + vm.expectRevert(abi.encodeWithSelector(BasePlugin.AlreadyInitialized.selector)); + plugin.onInstall(abi.encode(addrArr, new address[](0))); + } +} diff --git a/test/plugin/owner/MultiOwnerPluginIntegration.t.sol b/test/plugin/owner/MultiOwnerPluginIntegration.t.sol new file mode 100644 index 00000000..0cce360b --- /dev/null +++ b/test/plugin/owner/MultiOwnerPluginIntegration.t.sol @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {IERC1271} from "@openzeppelin/contracts/interfaces/IERC1271.sol"; + +import {UpgradeableModularAccount} from "../../../src/account/UpgradeableModularAccount.sol"; +import {IEntryPoint} from "../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../src/interfaces/erc4337/UserOperation.sol"; +import {MultiOwnerPlugin} from "../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IMultiOwnerPlugin} from "../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {Utils} from "../../Utils.sol"; +import {Call} from "../../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference} from "../../../src/libraries/FunctionReferenceLib.sol"; +import {IPluginManager} from "../../../src/interfaces/IPluginManager.sol"; + +import {Counter} from "../../mocks/Counter.sol"; +import {MultiOwnerMSCAFactory} from "../../../src/factory/MultiOwnerMSCAFactory.sol"; +import {Utils} from "../../Utils.sol"; + +contract MultiOwnerPluginIntegration is Test { + using ECDSA for bytes32; + + // bytes4(keccak256("isValidSignature(bytes32,bytes)")) + bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e; + bytes4 internal constant _1271_MAGIC_VALUE_FAILURE = 0xffffffff; + + IEntryPoint public entryPoint; + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + + Counter public counter; + address payable public beneficiary; + address public user1; + uint256 public user1Key; + + address public owner1; + uint256 public owner1Key; + UpgradeableModularAccount public account; + + address public owner2; + uint256 public owner2Key; + + address[] public owners; + + function setUp() public { + // setup dependencies and helper contract + counter = new Counter(); + entryPoint = IEntryPoint(address(new EntryPoint())); + beneficiary = payable(makeAddr("beneficiary")); + (user1, user1Key) = makeAddrAndKey("user1"); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + (owner2, owner2Key) = makeAddrAndKey("owner2"); + + // setup plugins and factory + multiOwnerPlugin = new MultiOwnerPlugin(); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(new UpgradeableModularAccount(IEntryPoint(address(entryPoint)))), + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + + // setup account with MultiOwnerMSCAFactory + owners = new address[](2); + owners[0] = owner1; + owners[1] = owner2; + account = UpgradeableModularAccount(payable(factory.getAddress(0, owners))); + vm.deal(address(account), 100 ether); + factory.createAccount(0, owners); + } + + function test_ownerPlugin_successInstallation() public { + assertTrue(IMultiOwnerPlugin(address(account)).isOwner(owner1)); + assertTrue(IMultiOwnerPlugin(address(account)).isOwner(owner2)); + assertEq(Utils.reverseAddressArray(owners), IMultiOwnerPlugin(address(account)).owners()); + } + + function test_runtimeValidation_alwaysAllow_isValidSignature() public { + bytes32 digest = bytes32("digest"); + bytes32 messageDigest = multiOwnerPlugin.getMessageHash(address(account), abi.encode(digest)); + bytes memory signature; + + { + // should fail for sig from owner1 due to wrongly encode message + bytes32 messageDigestBad = keccak256( + abi.encodePacked("\x19\x01", keccak256(abi.encode(user1, block.chainid)), abi.encode(digest)) + ); + (uint8 v0, bytes32 r0, bytes32 s0) = vm.sign(owner1Key, messageDigestBad); + signature = abi.encodePacked(r0, s0, v0); + assertEq(_1271_MAGIC_VALUE_FAILURE, IERC1271(address(account)).isValidSignature(digest, signature)); + } + + // should pass for sig from owner1 + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageDigest); + signature = abi.encodePacked(r, s, v); + assertEq(_1271_MAGIC_VALUE, IERC1271(address(account)).isValidSignature(digest, signature)); + + // should pass for sig from owner2 + (uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(owner1Key, messageDigest); + signature = abi.encodePacked(r1, s1, v1); + assertEq(_1271_MAGIC_VALUE, IERC1271(address(account)).isValidSignature(digest, signature)); + + // should fail for sig NOT from owner + (uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(user1Key, messageDigest); + signature = abi.encodePacked(r2, s2, v2); + assertEq(_1271_MAGIC_VALUE_FAILURE, IERC1271(address(account)).isValidSignature(digest, signature)); + } + + function test_runtimeValidation_ownerOrSelf_standardExecute() public { + // should send 1 ETH to user1 by owner + uint256 startBal = user1.balance; + vm.startPrank(owner1); + account.execute(user1, 1 ether, ""); + assertEq(1 ether, user1.balance - startBal); + + // should NOT send 1 ETH to user1 by non-owner + startBal = user1.balance; + vm.startPrank(user1); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + multiOwnerPlugin, + IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF, + abi.encodeWithSelector(IMultiOwnerPlugin.NotAuthorized.selector) + ) + ); + account.execute(user1, 1 ether, ""); + assertEq(0 ether, user1.balance - startBal); + } + + function test_userOpValidation_owner_standardExecute() public { + UserOperation memory userOp = UserOperation({ + sender: address(account), + nonce: 0, + initCode: "", + callData: abi.encodeCall(account.execute, (user1, 1 ether, "")), + callGasLimit: 50000, + verificationGasLimit: 1200000, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // should send 1 ETH to user1 by owner + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + uint256 startBal = user1.balance; + entryPoint.handleOps(userOps, beneficiary); + assertEq(1 ether, user1.balance - startBal); + + // should NOT send 1 ETH to user1 by non-owner + userOp.nonce++; + bytes32 userOpHash2 = entryPoint.getUserOpHash(userOp); + (uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(user1Key, userOpHash2.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r2, s2, v2); + UserOperation[] memory userOps2 = new UserOperation[](1); + userOps2[0] = userOp; + startBal = user1.balance; + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error")); + entryPoint.handleOps(userOps2, beneficiary); + assertEq(0 ether, user1.balance - startBal); + } +} diff --git a/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol b/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol new file mode 100644 index 00000000..566826ec --- /dev/null +++ b/test/plugin/session/SessionKeyPluginWithMultiOwner.t.sol @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../../src/account/UpgradeableModularAccount.sol"; +import {BasePlugin} from "../../../src/plugins/BasePlugin.sol"; +import {IMultiOwnerPlugin} from "../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {ISessionKeyPlugin} from "../../../src/plugins/session/ISessionKeyPlugin.sol"; +import {SessionKeyPlugin} from "../../../src/plugins/session/SessionKeyPlugin.sol"; +import {IEntryPoint} from "../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../src/libraries/FunctionReferenceLib.sol"; +import {Call} from "../../../src/interfaces/IStandardExecutor.sol"; + +import {MultiOwnerMSCAFactory} from "../../../src/factory/MultiOwnerMSCAFactory.sol"; + +contract SessionKeyPluginWithMultiOwnerTest is Test { + using ECDSA for bytes32; + + IEntryPoint entryPoint; + address payable beneficiary; + MultiOwnerPlugin multiOwnerPlugin; + MultiOwnerMSCAFactory factory; + SessionKeyPlugin sessionKeyPlugin; + + address owner1; + uint256 owner1Key; + address[] public owners; + UpgradeableModularAccount account1; + + uint256 constant CALL_GAS_LIMIT = 70000; + uint256 constant VERIFICATION_GAS_LIMIT = 1000000; + + address payable recipient; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + recipient = payable(makeAddr("recipient")); + vm.deal(beneficiary, 1 wei); + vm.deal(recipient, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = + new MultiOwnerMSCAFactory(address(this), address(multiOwnerPlugin), impl, keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), entryPoint); + + sessionKeyPlugin = new SessionKeyPlugin(); + + owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + vm.deal(address(account1), 100 ether); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(new address[](0)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + } + + function test_sessionKey_addKeySuccess() public { + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + address[] memory sessionKeys = SessionKeyPlugin(address(account1)).getSessionKeys(); + assertEq(sessionKeys.length, 1); + assertEq(sessionKeys[0], sessionKeysToAdd[0]); + } + + function test_sessionKey_addAndRemoveKeys() public { + address[] memory sessionKeysToAdd = new address[](2); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + sessionKeysToAdd[1] = makeAddr("sessionKey2"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + SessionKeyPlugin.SessionKeyToRemove[] memory sessionKeysToRemove = + new ISessionKeyPlugin.SessionKeyToRemove[](1); + sessionKeysToRemove[0] = ISessionKeyPlugin.SessionKeyToRemove({ + sessionKey: sessionKeysToAdd[0], + predecessor: bytes32(bytes20(sessionKeysToAdd[1])) + }); + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys(new address[](0), sessionKeysToRemove); + + address[] memory sessionKeys = SessionKeyPlugin(address(account1)).getSessionKeys(); + assertEq(sessionKeys.length, 1); + assertEq(sessionKeys[0], sessionKeysToAdd[1]); + } + + function test_sessionKey_useSessionKey() public { + address[] memory sessionKeysToAdd = new address[](1); + (address sessionKey, uint256 sessionKeyPrivate) = makeAddrAndKey("sessionKey1"); + sessionKeysToAdd[0] = sessionKey; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 1 wei, data: ""}); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + ISessionKeyPlugin(address(sessionKeyPlugin)).executeWithSessionKey, (calls, sessionKey) + ), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKeyPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(recipient.balance, 2 wei); + } + + function testFuzz_sessionKey_userOpValidation_valid(uint16 seed) public { + uint256[] memory privateKeys = _createSessionKeys(uint8(seed)); + + // Pick a random signer to use to validate with + uint256 signerPrivateKey = privateKeys[(seed >> 8) % privateKeys.length]; + address signerAddress = vm.addr(signerPrivateKey); + + // Construct a user op to validate against + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 1 wei, data: ""}); + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + ISessionKeyPlugin(address(sessionKeyPlugin)).executeWithSessionKey, (calls, signerAddress) + ), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPrivateKey, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.prank(address(account1)); + uint256 result = sessionKeyPlugin.userOpValidationFunction( + uint8(ISessionKeyPlugin.FunctionId.USER_OP_VALIDATION_SESSION_KEY), userOp, userOpHash + ); + + assertEq(result, 0); + } + + function testFuzz_sessionKey_userOpValidation_invalid(uint8 sessionKeysSeed, uint64 signerSeed) public { + _createSessionKeys(sessionKeysSeed); + + (address signer, uint256 signerPrivate) = + makeAddrAndKey(string.concat("Signer", vm.toString(uint32(signerSeed)))); + + // The signer should not be a session key of the plugin - this is exceedingly unlikely but checking + // anyways. + vm.assume(!sessionKeyPlugin.isSessionKeyOf(address(account1), signer)); + + // Construct a user op to validate against + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 1 wei, data: ""}); + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeCall( + ISessionKeyPlugin(address(sessionKeyPlugin)).executeWithSessionKey, (calls, signer) + ), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.prank(address(account1)); + uint256 result = sessionKeyPlugin.userOpValidationFunction( + uint8(ISessionKeyPlugin.FunctionId.USER_OP_VALIDATION_SESSION_KEY), userOp, userOpHash + ); + + assertEq(result, 1); + } + + function testFuzz_sessionKey_invalidFunctionId(uint8 functionId, UserOperation memory userOp) public { + vm.assume(functionId != uint8(ISessionKeyPlugin.FunctionId.USER_OP_VALIDATION_SESSION_KEY)); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + vm.expectRevert(abi.encodeWithSelector(BasePlugin.NotImplemented.selector)); + sessionKeyPlugin.userOpValidationFunction(functionId, userOp, userOpHash); + } + + // getPredecessor test case with sentinel value as predecessor + function test_sessionKey_getPredecessor_sentinel() public { + address[] memory sessionKeysToAdd = new address[](2); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + sessionKeysToAdd[1] = makeAddr("sessionKey2"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + SessionKeyPlugin.SessionKeyToRemove[] memory sessionKeysToRemove = + new ISessionKeyPlugin.SessionKeyToRemove[](1); + sessionKeysToRemove[0] = ISessionKeyPlugin.SessionKeyToRemove({ + sessionKey: sessionKeysToAdd[0], + predecessor: sessionKeyPlugin.findPredecessor(address(account1), sessionKeysToAdd[0]) + }); + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys(new address[](0), sessionKeysToRemove); + + address[] memory sessionKeys = SessionKeyPlugin(address(account1)).getSessionKeys(); + assertEq(sessionKeys.length, 1); + assertEq(sessionKeys[0], sessionKeysToAdd[1]); + } + + // getPredecessor test case with address value as predecessor + function test_sessionKey_getPredecessor_address() public { + address[] memory sessionKeysToAdd = new address[](2); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + sessionKeysToAdd[1] = makeAddr("sessionKey2"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + SessionKeyPlugin.SessionKeyToRemove[] memory sessionKeysToRemove = + new ISessionKeyPlugin.SessionKeyToRemove[](1); + sessionKeysToRemove[0] = ISessionKeyPlugin.SessionKeyToRemove({ + sessionKey: sessionKeysToAdd[1], + predecessor: sessionKeyPlugin.findPredecessor(address(account1), sessionKeysToAdd[1]) + }); + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys(new address[](0), sessionKeysToRemove); + + address[] memory sessionKeys = SessionKeyPlugin(address(account1)).getSessionKeys(); + assertEq(sessionKeys.length, 1); + assertEq(sessionKeys[0], sessionKeysToAdd[0]); + } + + function test_sessionKey_getPredecessor_missing() public { + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + address key2 = makeAddr("sessionKey2"); + vm.expectRevert(abi.encodeWithSelector(ISessionKeyPlugin.SessionKeyNotFound.selector, key2)); + sessionKeyPlugin.findPredecessor(address(account1), key2); + } + + function test_sessionKey_doesNotContainSentinelValue() public { + assertFalse(sessionKeyPlugin.isSessionKeyOf(address(account1), address(1))); + + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = makeAddr("sessionKey1"); + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + assertFalse(sessionKeyPlugin.isSessionKeyOf(address(account1), address(1))); + } + + function _createSessionKeys(uint8 seed) internal returns (uint256[] memory privateKeys) { + uint256 addressCount = (seed % 16) + 1; + + address[] memory sessionKeysToAdd = new address[](addressCount); + privateKeys = new uint256[](addressCount); + for (uint256 i = 0; i < addressCount; i++) { + (sessionKeysToAdd[i], privateKeys[i]) = makeAddrAndKey(string.concat("sessionKey", vm.toString(i))); + } + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + } +} diff --git a/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol b/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol new file mode 100644 index 00000000..bf4b846b --- /dev/null +++ b/test/plugin/session/permissions/SessionKeyERC20SpendLimits.t.sol @@ -0,0 +1,1070 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {ISessionKeyPlugin} from "../../../../src/plugins/session/ISessionKeyPlugin.sol"; +import {SessionKeyPlugin} from "../../../../src/plugins/session/SessionKeyPlugin.sol"; +import {ISessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol"; +import {ISessionKeyPermissionsUpdates} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol"; +import {SessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol"; +import {IEntryPoint} from "../../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../../src/libraries/FunctionReferenceLib.sol"; +import {Call} from "../../../../src/interfaces/IStandardExecutor.sol"; + +import {MultiOwnerMSCAFactory} from "../../../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockERC20} from "../../../mocks/tokens/MockERC20.sol"; + +contract SessionKeyERC20SpendLimitsTest is Test { + using ECDSA for bytes32; + + IEntryPoint entryPoint; + address payable beneficiary; + + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + SessionKeyPlugin sessionKeyPlugin; + SessionKeyPermissionsPlugin sessionKeyPermissionsPlugin; + + address owner1; + uint256 owner1Key; + UpgradeableModularAccount account1; + + address sessionKey1; + uint256 sessionKey1Private; + + address recipient1; + address recipient2; + + MockERC20 token1; + MockERC20 token2; + MockERC20 token3; + + // Constants for running user ops + uint256 constant CALL_GAS_LIMIT = 300000; + uint256 constant VERIFICATION_GAS_LIMIT = 1000000; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + + vm.deal(beneficiary, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(new UpgradeableModularAccount(entryPoint)), + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + owner1 = makeAddr("owner"); + address[] memory owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + + vm.deal(address(account1), 100 ether); + + sessionKeyPlugin = new SessionKeyPlugin(); + sessionKeyPermissionsPlugin = new SessionKeyPermissionsPlugin(); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(new address[](0)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + manifestHash = keccak256(abi.encode(sessionKeyPermissionsPlugin.pluginManifest())); + // Can reuse the same dependencies for this installation, because the requirements are the same. + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPermissionsPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Create and add a session key + (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); + + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = sessionKey1; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + // Register the session key with the permissions plugin + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).registerKey(sessionKey1, 0); + + // Disable the allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Create recipients' addresses to receive the tokens + recipient1 = makeAddr("recipient1"); + recipient2 = makeAddr("recipient2"); + + // Create the mock token contracts + token1 = new MockERC20("T1"); + token2 = new MockERC20("T2"); + token3 = new MockERC20("T3"); + } + + function test_sessionKeyERC20SpendLimits_validateSetUp() public { + // Check that the session key is registered + assertTrue(SessionKeyPlugin(address(account1)).isSessionKey(sessionKey1)); + + // Check that the session key is registered with the permissions plugin and has its allowlist set up + // correctly + assertTrue( + sessionKeyPermissionsPlugin.getAccessControlType(address(account1), sessionKey1) + == ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE + ); + } + + function testFuzz_sessionKeyERC20SpendLimits_setLimits( + address token, + uint256 limit, + uint48 refreshInterval, + uint48 timestamp + ) public { + // The zero address is not allowed as a token addr. The next test asserts this. + vm.assume(token != address(0)); + + // Pick a timestamp to warp to + vm.warp(timestamp); + + // Assert that the limit starts out unset + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, token); + + assertFalse(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + + // Set the limit + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (token, limit, refreshInterval)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Verify the limit can be retrieved + spendLimitInfo = sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, token); + + if (limit == type(uint256).max) { + // If the limit is "set" to this value, it is just removed. + // verify that the values are still as they were before. + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + } else { + // The limit is actually set, verify that the values are as expected. + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, limit); + assertEq(spendLimitInfo.refreshInterval, refreshInterval); + assertEq(spendLimitInfo.limitUsed, 0); + if (refreshInterval == 0) { + assertEq(spendLimitInfo.lastUsedTime, 0); + } else { + assertEq(spendLimitInfo.lastUsedTime, timestamp); + } + } + } + + function test_sessionKeyERC20SpendLimits_tokenAddressZeroFails() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(0), 1000, 0)); + vm.expectRevert(abi.encodeWithSelector(ISessionKeyPermissionsPlugin.InvalidToken.selector)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + } + + function test_sessionKeyERC20SpendLimits_enforceLimit_none_basic() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to zero + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 0 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should fail + Call[] memory calls = new Call[](1); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + + // Since the revert happens during execution, we can't check it using vm.expectRevert, since the underlyng + // call to handleOps does not revert. + // Instead, we assert that the transfer call did NOT happen via vm.expectCall with the count set to zero. + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Run a user op that spends 0 wei, should succeed + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 0 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + } + + // Expands on the previous test to cover the case where the spend is batched via multiple method types. + function test_sessionKeyERC20SpendLimits_enforceLimit_none_batch() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to zero + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 0 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // run a multi-execution user op that spends 0 wei, should succeed. + Call[] memory calls = new Call[](3); + calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 0 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 0 wei)), value: 0}); + calls[2] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 0 wei)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data); + vm.expectCall(address(token1), 0 wei, calls[1].data); + vm.expectCall(address(token1), 0 wei, calls[2].data); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + } + + function test_sessionKeyERC20SpendLimits_basic_single() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + // Almost a duplicate of the previous test, but asserts that subsequent calls that exceed the budget cause it + // to fail. + function test_sessionKeyERC20SpendLimits_exceedLimit_single() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + + // Run a user op that spends 1 ether, should fail + + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is not updated, and remains the same as before + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_executeWithSessionKey_success_multipleTransfer() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 3 wei, should succeed + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[2] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient2, 1 wei)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data, 2); + vm.expectCall(address(token1), 0 wei, calls[2].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 3 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_executeWithSessionKey_approveOnlyCountsIncrease() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Preemptively approve the recipient for 0.5 ether + vm.prank(address(account1)); + token1.approve(recipient1, 0.5 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 ether, should succeed + Call[] memory calls = new Call[](1); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 1 ether)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.5 ether); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_executeWithSessionKey_success_multipleApprove() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 3 wei, should succeed + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 3 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 2 wei)), value: 0}); + calls[2] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient2, 1 wei)), value: 0}); + + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token1), 0 wei, calls[1].data, 1); + vm.expectCall(address(token1), 0 wei, calls[2].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 6 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_executeWithSessionKey_success_multipleSpendFunctions() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 3 wei, should succeed + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = Call({ + target: address(token1), + data: abi.encodeCall(token1.approve, (address(account1), 1 wei)), + value: 0 + }); + calls[2] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient2, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token1), 0 wei, calls[1].data, 1); + vm.expectCall(address(token1), 0 wei, calls[2].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 3 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_executeWithSessionKey_success_multipleTokens() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + token2.mint(address(account1), 100 ether); + + // Set spending limit + bytes[] memory updates = new bytes[](2); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 0 days)); + updates[1] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token2), 1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 3 wei, should succeed + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = Call({ + target: address(token2), + data: abi.encodeCall(token2.approve, (address(account1), 1 wei)), + value: 0 + }); + calls[2] = + Call({target: address(token2), data: abi.encodeCall(token2.transfer, (recipient2, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token2), 0 wei, calls[1].data, 1); + vm.expectCall(address(token2), 0 wei, calls[2].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo1 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo2 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token2)); + assertEq(spendLimitInfo1.limit, 1 ether); + assertEq(spendLimitInfo1.limitUsed, 1 wei); + assertEq(spendLimitInfo2.limit, 1 ether); + assertEq(spendLimitInfo2.limitUsed, 2 wei); + } + + function test_executeWithSessionKey_failWithExceedLimit_multipleTransfer() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 wei, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that should fail due to exceeding limit + Call[] memory calls = new Call[](2); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient2, 1 wei)), value: 0}); + // should not call due to revert on ERC20SpendLimitExceeded + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + vm.expectCall(address(token1), 0 wei, calls[1].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 wei); + // limit used should be 0 as all action failed. + assertEq(spendLimitInfo.limitUsed, 0 wei); + } + + function test_executeWithSessionKey_failWithExceedLimit_multipleSpendFunctions() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 wei, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that should fail due to exceeding limit + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = Call({ + target: address(token1), + data: abi.encodeCall(token1.approve, (address(account1), 1 wei)), + value: 0 + }); + calls[2] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient2, 1 wei)), value: 0}); + // should not call due to revert on ERC20SpendLimitExceeded + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + vm.expectCall(address(token1), 0 wei, calls[1].data, 0); + vm.expectCall(address(token1), 0 wei, calls[2].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 wei); + // limit used should be 0 as all action failed. + assertEq(spendLimitInfo.limitUsed, 0 wei); + } + + function test_executeWithSessionKey_failWithExceedLimit_overflow() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 wei, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that should fail due to exceeding limit + Call[] memory calls = new Call[](3); + calls[0] = Call({ + target: address(token1), + data: abi.encodeCall(token1.transfer, (recipient1, type(uint256).max)), + value: 0 + }); + calls[1] = Call({ + target: address(token1), + data: abi.encodeCall(token1.approve, (address(account1), type(uint256).max)), + value: 0 + }); + // should not call due to revert on ERC20SpendLimitExceeded + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + vm.expectCall(address(token1), 0 wei, calls[1].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 wei); + // limit used should be 0 as all action failed. + assertEq(spendLimitInfo.limitUsed, 0 wei); + } + + function test_executeWithSessionKey_failWithExceedLimit_multipleTokens() public { + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + token2.mint(address(account1), 100 ether); + + // Set spending limit + bytes[] memory updates = new bytes[](2); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 wei, 0 days)); + updates[1] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token2), 1 wei, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that should fail due to exceeding limit + Call[] memory calls = new Call[](3); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 2 wei)), value: 0}); + calls[1] = Call({ + target: address(token2), + data: abi.encodeCall(token2.approve, (address(account1), 2 wei)), + value: 0 + }); + calls[2] = + Call({target: address(token2), data: abi.encodeCall(token2.transfer, (recipient2, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + vm.expectCall(address(token2), 0 wei, calls[1].data, 0); + vm.expectCall(address(token2), 0 wei, calls[2].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo1 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo2 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token2)); + assertEq(spendLimitInfo1.limit, 1 wei); + // limit used should be 0 as all action failed. + assertEq(spendLimitInfo1.limitUsed, 0 wei); + assertEq(spendLimitInfo2.limit, 1 wei); + // limit used should be 0 as all action failed. + assertEq(spendLimitInfo2.limitUsed, 0 wei); + } + + function test_executeWithSessionKey_refreshInterval_singleTransfer() public { + // Set the time to the a unix timestamp + uint256 time0 = 1698708080; + vm.warp(time0); + + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ETH, should fail due to over spending + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is NOT updated + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends 1 ether, should succeed + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_executeWithSessionKey_refreshInterval_multipleTransfer() public { + // Set the time to the a unix timestamp + uint256 time0 = 1698708080; + vm.warp(time0); + + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](2); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 2); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 2 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ETH, should fail due to over spending + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is NOT updated + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 2 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends 1 ether, should succeed + calls[0] = Call({ + target: address(token1), + data: abi.encodeCall(token1.transfer, (recipient1, 0.5 ether)), + value: 0 + }); + calls[1] = Call({ + target: address(token1), + data: abi.encodeCall(token1.transfer, (recipient1, 0.5 ether)), + value: 0 + }); + vm.expectCall(address(token1), 0 wei, calls[0].data, 2); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_executeWithSessionKey_refreshInterval_multipleApprove() public { + // Set the time to the a unix timestamp + uint256 time0 = 1698708080; + vm.warp(time0); + + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](2); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token1), 0 wei, calls[1].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit and last used time is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 2 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ETH, should fail due to over spending + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit and last used time is NOT updated + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 2 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends 1 ether, should succeed + calls[0] = Call({ + target: address(token1), + data: abi.encodeCall(token1.approve, (recipient1, 0.5 ether + 1 wei)), + value: 0 + }); + calls[1] = Call({ + target: address(token1), + // previous approved 1 wei is still effective + data: abi.encodeCall(token1.approve, (recipient1, 0.5 ether + 1 wei)), + value: 0 + }); + vm.expectCall(address(token1), 0 wei, calls[0].data, 2); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_executeWithSessionKey_refreshInterval_multipleSpendFunctions() public { + // Set the time to the a unix timestamp + uint256 time0 = 1698708080; + vm.warp(time0); + + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends within limit, should succeed + Call[] memory calls = new Call[](2); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 2 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token1), 0 wei, calls[1].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit and last used time is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 3 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ETH, should fail due to over spending + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit and last used time is NOT updated + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 3 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends 1 ether, should succeed + calls[0] = Call({ + target: address(token1), + data: abi.encodeCall(token1.transfer, (recipient1, 0.5 ether)), + value: 0 + }); + calls[1] = Call({ + target: address(token1), + data: abi.encodeCall(token1.approve, (recipient1, 0.5 ether)), + value: 0 + }); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token1), 0 wei, calls[1].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_executeWithSessionKey_refreshInterval_failWithSomeTokenLimit() public { + // Set the time to the a unix timestamp + uint256 time0 = 1698708080; + vm.warp(time0); + + // Give the account a starting balance + token1.mint(address(account1), 100 ether); + token2.mint(address(account1), 100 ether); + + // Set the limit to 1 ether, over 1 day and 10 days, respectively + bytes[] memory updates = new bytes[](2); + updates[0] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token1), 1 ether, 1 days)); + updates[1] = + abi.encodeCall(ISessionKeyPermissionsUpdates.setERC20SpendLimit, (address(token2), 1 ether, 10 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends max limit in interval, should succeed + Call[] memory calls = new Call[](2); + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 ether)), value: 0}); + calls[1] = + Call({target: address(token2), data: abi.encodeCall(token2.transfer, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token2), 0 wei, calls[1].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo1 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo2 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token2)); + assertEq(spendLimitInfo1.limit, 1 ether); + assertEq(spendLimitInfo1.limitUsed, 1 ether); + assertEq(spendLimitInfo2.limit, 1 ether); + assertEq(spendLimitInfo2.limitUsed, 1 ether); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends within limit for token 2, exceed limit for token 2 , should fail due to over + // spending on token2 + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.transfer, (recipient1, 1 wei)), value: 0}); + calls[1] = + Call({target: address(token2), data: abi.encodeCall(token2.transfer, (recipient1, 1 wei)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 0); + vm.expectCall(address(token2), 0 wei, calls[1].data, 0); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit and last used time is NOT updated + spendLimitInfo1 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + spendLimitInfo2 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token2)); + assertEq(spendLimitInfo1.limit, 1 ether); + assertEq(spendLimitInfo1.limitUsed, 1 ether); + assertEq(spendLimitInfo1.lastUsedTime, time0); + assertEq(spendLimitInfo2.limit, 1 ether); + assertEq(spendLimitInfo2.limitUsed, 1 ether); + assertEq(spendLimitInfo2.lastUsedTime, time0); + + // warp to when the interval resets + vm.warp(time0 + 10 days); + + // Enough time passed, run a user op that spends max limit in interval, should succeed + calls[0] = + Call({target: address(token1), data: abi.encodeCall(token1.approve, (recipient1, 1 ether)), value: 0}); + calls[1] = + Call({target: address(token2), data: abi.encodeCall(token2.approve, (recipient1, 1 ether)), value: 0}); + vm.expectCall(address(token1), 0 wei, calls[0].data, 1); + vm.expectCall(address(token2), 0 wei, calls[1].data, 1); + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo1 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token1)); + spendLimitInfo2 = + sessionKeyPermissionsPlugin.getERC20SpendLimitInfo(address(account1), sessionKey1, address(token2)); + assertEq(spendLimitInfo1.limit, 1 ether); + assertEq(spendLimitInfo1.limitUsed, 1 ether); + assertEq(spendLimitInfo1.lastUsedTime, time0 + 10 days); + assertEq(spendLimitInfo2.limit, 1 ether); + assertEq(spendLimitInfo2.limitUsed, 1 ether); + assertEq(spendLimitInfo2.lastUsedTime, time0 + 10 days); + } + + function _runSessionKeyUserOp(Call[] memory calls, uint256 sessionKeyPrivate, bytes memory expectedError) + internal + { + address sessionKey = vm.addr(sessionKeyPrivate); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKeyPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + if (expectedError.length > 0) { + vm.expectRevert(expectedError); + } + entryPoint.handleOps(userOps, beneficiary); + } +} diff --git a/test/plugin/session/permissions/SessionKeyGasLimits.t.sol b/test/plugin/session/permissions/SessionKeyGasLimits.t.sol new file mode 100644 index 00000000..1ad64178 --- /dev/null +++ b/test/plugin/session/permissions/SessionKeyGasLimits.t.sol @@ -0,0 +1,697 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {ISessionKeyPlugin} from "../../../../src/plugins/session/ISessionKeyPlugin.sol"; +import {SessionKeyPlugin} from "../../../../src/plugins/session/SessionKeyPlugin.sol"; +import {ISessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol"; +import {ISessionKeyPermissionsUpdates} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol"; +import {SessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol"; +import {IEntryPoint} from "../../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../../../src/interfaces/IPluginManager.sol"; +import {Call} from "../../../../src/interfaces/IStandardExecutor.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../../src/libraries/FunctionReferenceLib.sol"; + +import {MultiOwnerMSCAFactory} from "../../../../src/factory/MultiOwnerMSCAFactory.sol"; + +contract SessionKeyGasLimitsTest is Test { + using ECDSA for bytes32; + + IEntryPoint entryPoint; + address payable beneficiary; + + MultiOwnerPlugin public multiOwnerPlugin; + MultiOwnerMSCAFactory public factory; + SessionKeyPlugin sessionKeyPlugin; + SessionKeyPermissionsPlugin sessionKeyPermissionsPlugin; + + address owner1; + uint256 owner1Key; + UpgradeableModularAccount account1; + + address sessionKey1; + uint256 sessionKey1Private; + + address recipient; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + recipient = makeAddr("recipient"); + + vm.deal(beneficiary, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + factory = new MultiOwnerMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(new UpgradeableModularAccount(entryPoint)), + keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), + entryPoint + ); + owner1 = makeAddr("owner"); + address[] memory owners = new address[](1); + owners[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + + vm.deal(address(account1), 100 ether); + + sessionKeyPlugin = new SessionKeyPlugin(); + sessionKeyPermissionsPlugin = new SessionKeyPermissionsPlugin(); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(new address[](0)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + manifestHash = keccak256(abi.encode(sessionKeyPermissionsPlugin.pluginManifest())); + // Can reuse the same dependencies for this installation, because the requirements are the same. + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPermissionsPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Create and add a session key + (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); + + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = sessionKey1; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + // Register the session key with the permissions plugin + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).registerKey(sessionKey1, 0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + } + + function testFuzz_sessionKeyGasLimits_setLimits(uint256 limit, uint48 interval, uint48 timestamp) public { + vm.warp(timestamp); + + // Assert that the limit starts out unset + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertFalse(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + + // Set the limit + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (limit, interval)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Verify that the limit is set + (spendLimitInfo,) = sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + if (limit == type(uint256).max) { + // If the limit is "set" to this value, it is just removed. + // verify that the values are still as they were before. + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + } else { + // The limit is actually set, verify that the values are as expected. + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, limit); + assertEq(spendLimitInfo.refreshInterval, interval); + assertEq(spendLimitInfo.limitUsed, 0); + if (interval == 0) { + assertEq(spendLimitInfo.lastUsedTime, 0); + } else { + assertEq(spendLimitInfo.lastUsedTime, timestamp); + } + } + } + + // gas limit zero + function test_sessionKeyGasLimits_enforceLimit_none() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (0 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // A user op spending any gas should be rejected at this stage + _runSessionKeyUserOp( + 50_000, + 150_000, + 1, + 200_000 wei, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + } + + function testFuzz_sessionKeyGasLimits_nolimit(uint256 gasPrice) public { + gasPrice = bound(gasPrice, 1 wei, 1_000_000_000 gwei); + + uint256 ethToSpend = 1_000_000 * gasPrice; + + // Extra padding amount to cover the duplicate requirement gas for validation + vm.deal(address(account1), ethToSpend); + + _runSessionKeyUserOp(200_000, 800_000, gasPrice, ethToSpend, sessionKey1Private, ""); + } + + function test_sessionKeyGasLimits_enforceLimit_basic_single() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // A basic user op using 0.6 ether in gas should succeed + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // This usage update should be reflected in the limits + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.lastUsedTime, 0); + + // A basic user op using 0.6 ether in gas should now fail + _runSessionKeyUserOp( + 100_000, + 500_000, + 1_000 gwei, + 0.6 ether, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + } + + function test_sessionKeyGasLimits_exceedLimit_single() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // A basic user op using 1.2 ether in gas should fail + _runSessionKeyUserOp( + 100_000, + 500_000, + 2_000 gwei, + 1.2 ether, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + } + + function test_sessionKeyGasLimits_enforceLimit_basic_multipleInBundle() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Construct two user ops that, when bundled together, spend 0.8 ether + UserOperation[] memory userOps = new UserOperation[](2); + userOps[0] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(0, sessionKey1) + ); + userOps[1] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(1, sessionKey1) + ); + + // Run the user ops + entryPoint.handleOps(userOps, beneficiary); + + // This usage update should be reflected in the limits + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.8 ether); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_sessionKeyGasLimits_exceedLimit_multipleInBundle() public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Construct two user ops that, when bundled together, spend 1.6 ether + UserOperation[] memory userOps = new UserOperation[](2); + userOps[0] = _generateAndSignUserOp( + 100_000, 300_000, 2_000 gwei, 0.8 ether, sessionKey1Private, _wrapNonceWithAddr(0, sessionKey1) + ); + userOps[1] = _generateAndSignUserOp( + 100_000, 300_000, 2_000 gwei, 0.8 ether, sessionKey1Private, _wrapNonceWithAddr(1, sessionKey1) + ); + + // Run the user ops + // The second op (index 1) should be the one that fails signature validation. + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 1, "AA24 signature error")); + entryPoint.handleOps(userOps, beneficiary); + + // The lack of usage update should be reflected in the limits + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0 ether); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_sessionKeyGasLimits_refreshInterval_inspectValidationData() public { + // Pick a start time + uint256 time0 = 1698708080; + vm.warp(time0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // A basic user op using 0.6 ether in gas should succeed + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // This usage update should be reflected in the limits + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Inspect the returned time range from validateUserOp to see the higher start time for the next interval. + UserOperation memory userOp = _generateAndSignUserOp( + 100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, _wrapNonceWithAddr(0, sessionKey1) + ); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + + vm.prank(address(entryPoint)); + // NOTE: this causes the last used to advance, which is an intentional side effect of validation. Under + // normal circumstances, if it is not yet due, the validation will revert by the EntryPoint. The account + // protects from stray state updates by asserting that these calls only come from the entrypoint, but we + // mock it here with vm.prank. + uint256 validationData = account1.validateUserOp(userOp, userOpHash, 0); + + uint48 expectedStartTime = uint48(time0 + 1 days); + uint48 actualStartTime = uint48(validationData >> 208); + assertEq(actualStartTime, expectedStartTime); + } + + function test_sessionKeyGasLimits_refreshInterval_single() public { + // Pick a start time + uint256 time0 = 1698708080; + vm.warp(time0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // A basic user op using 0.6 ether in gas should succeed + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // This usage update should be reflected in the limits + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Attempting to use another 0.6 ether now should fail + + _runSessionKeyUserOp( + 100_000, + 500_000, + 1_000 gwei, + 0.6 ether, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA22 expired or not due") + ); + + // Skip forward and run the user op + + skip(1 days + 1 minutes); + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // This usage update should be reflected in the limits + (spendLimitInfo,) = sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + // The last used time SHOULD increment by the actual time passed, not just the interval, if the call + // succeeded. + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days + 1 minutes); + } + + function test_sessionKeyGasLimits_refreshInterval_multipleInBundle() public { + // Pick a start time + uint256 time0 = 1698708080; + vm.warp(time0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Use up 0.6 ether + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // Construct two user ops that, when bundled together, spend 0.8 ether + + UserOperation[] memory userOps = new UserOperation[](2); + userOps[0] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(1, sessionKey1) + ); + userOps[1] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(2, sessionKey1) + ); + + // Run the user ops. This should fail now, with the second one's start time being later. + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 1, "AA22 expired or not due")); + entryPoint.handleOps(userOps, beneficiary); + + // Usage should still be at 0.6 + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Skip forward and run the user ops again. This should succeed now. + skip(1 days + 1 minutes); + + entryPoint.handleOps(userOps, beneficiary); + + // Usage should now be at 0.4 (odd case, since the first one fits in the old interval, but the second one + // doesn't. + + (spendLimitInfo,) = sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.4 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + // The last used time SHOULD increment by the actual time passed, not just the interval, if the call + // succeeded. + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days + 1 minutes); + } + + function test_sessionKeyGasLimits_refreshInterval_multipleInBundle_tryExceedFails() public { + // Pick a start time + uint256 time0 = 1698708080; + vm.warp(time0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Use up 0.8 ether + _runSessionKeyUserOp(100_000, 700_000, 1_000 gwei, 0.8 ether, sessionKey1Private, ""); + + // Construct three user ops that each cost 0.4 ether, and when bundled together, spend 1.2 ether + + UserOperation[] memory userOps = new UserOperation[](3); + userOps[0] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(1, sessionKey1) + ); + userOps[1] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(2, sessionKey1) + ); + userOps[2] = _generateAndSignUserOp( + 100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, _wrapNonceWithAddr(3, sessionKey1) + ); + + // Run the user ops. This should fail now, since even the first one exceeds the spend limit + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA22 expired or not due")); + entryPoint.handleOps(userOps, beneficiary); + + // Usage should still be at 0.8 + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo,) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0.8 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Skip forward and try to run the user ops again. This should still fail, since the third one + // would exceed the next spend limit window. This is somewhat counterintuitive, since it would seem like + // 0.8 + 1.2 ether should fit in two 1 ether intervals. However, the first user op in the 1.2 ether bundler + // starts a new interval and sets the usage to 0.4, meaning the remainder from the previous interval is not + // actually usable. + skip(1 days + 1 minutes); + + vm.expectRevert(abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 2, "AA24 signature error")); + entryPoint.handleOps(userOps, beneficiary); + } + + function test_sessionKeyGasLimits_refreshInterval_resetFlagTracking() public { + // Pick a start time + uint256 time0 = 1698708080; + vm.warp(time0); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Use up 0.6 ether + _runSessionKeyUserOp(100_000, 500_000, 1_000 gwei, 0.6 ether, sessionKey1Private, ""); + + // Try to use up 0.6 ether again, but with a call that reverts during execution. + + UserOperation[] memory userOps = new UserOperation[](1); + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: address(this), value: 0 ether, data: abi.encodeWithSelector(bytes4(0x11223344))}); + userOps[0] = _generateAndSignUserOpWithCustomExecutions( + 100_000, 500_000, 1_000 gwei, 0.6 ether, calls, sessionKey1Private, _wrapNonceWithAddr(1, sessionKey1) + ); + + skip(1 days + 1 minutes); + entryPoint.handleOps(userOps, beneficiary); + + (, bool shouldReset) = sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertTrue(shouldReset, "Session key should report that it needs to be reset"); + } + + function test_sessionKeyGasLimits_refreshInterval_resetFlag_fixWithExtraUO() public { + // Run the above test + test_sessionKeyGasLimits_refreshInterval_resetFlagTracking(); + + // Now, attempt to fix it by running a user op that does not exceed the limit, does not revert in + // execution, and resets the flag. + + // Use up 0.6 ether + _runSessionKeyUserOp(100_000, 300_000, 1_000 gwei, 0.4 ether, sessionKey1Private, ""); + + // The reset flag should now be false + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo, bool shouldReset) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertFalse(shouldReset, "Session key should report that it does not need to be reset"); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.lastUsedTime, block.timestamp); + } + + function test_sessionKeyGasLimits_refreshInterval_resetFlag_fixWithOwnerReset() public { + // Run the above test + test_sessionKeyGasLimits_refreshInterval_resetFlagTracking(); + + // Now, attempt to fix it by calling the update function with the same parameters as before. + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setGasSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // The reset flag should now be false + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo, bool shouldReset) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertFalse(shouldReset, "Session key should report that it does not need to be reset"); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.lastUsedTime, block.timestamp); + } + + function test_sessionKeyGasLimits_refreshInterval_resetFlag_fixWithPublicReset() public { + // Run the above test + test_sessionKeyGasLimits_refreshInterval_resetFlagTracking(); + + // Now, attempt to fix it by calling the public reset function + sessionKeyPermissionsPlugin.resetSessionKeyGasLimitTimestamp(address(account1), sessionKey1); + + // The reset flag should now be false + (ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo, bool shouldReset) = + sessionKeyPermissionsPlugin.getGasSpendLimit(address(account1), sessionKey1); + + assertFalse(shouldReset, "Session key should report that it does not need to be reset"); + assertEq(spendLimitInfo.limitUsed, 0.6 ether); + assertEq(spendLimitInfo.lastUsedTime, block.timestamp); + } + + function _getMaxGasCostPerUserOp(UserOperation memory userOp) internal pure returns (uint256) { + uint256 multiplier = userOp.paymasterAndData.length > 0 ? 3 : 1; + uint256 maxGasFee = ( + userOp.callGasLimit + userOp.verificationGasLimit * multiplier + userOp.preVerificationGas + ) * userOp.maxFeePerGas; + return maxGasFee; + } + + function _generateAndSignUserOp( + uint256 callGasLimit, + uint256 verificationGasLimit, + uint256 maxFeePerGas, + uint256 expectedEtherValue, + uint256 sessionKeyPrivate, + uint256 nonce + ) internal returns (UserOperation memory) { + address sessionKey = vm.addr(sessionKeyPrivate); + + // Just creates a dummy call, since the values being checked are only in the user op's gas fields. + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 0 ether, data: ""}); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: nonce, + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey)), + callGasLimit: callGasLimit, + verificationGasLimit: verificationGasLimit, + preVerificationGas: 0, + maxFeePerGas: maxFeePerGas, + maxPriorityFeePerGas: 0, + paymasterAndData: "", + signature: "" + }); + + // Double-check that the parameters given actually result in a expected native token usage amount + + assertEq( + _getMaxGasCostPerUserOp(userOp), + expectedEtherValue, + "Mismatch between expect gas fee and actual gas fee" + ); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKeyPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + return userOp; + } + + function _generateAndSignUserOpWithCustomExecutions( + uint256 callGasLimit, + uint256 verificationGasLimit, + uint256 maxFeePerGas, + uint256 expectedEtherValue, + Call[] memory calls, + uint256 sessionKeyPrivate, + uint256 nonce + ) internal returns (UserOperation memory) { + address sessionKey = vm.addr(sessionKeyPrivate); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: nonce, + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey)), + callGasLimit: callGasLimit, + verificationGasLimit: verificationGasLimit, + preVerificationGas: 0, + maxFeePerGas: maxFeePerGas, + maxPriorityFeePerGas: 0, + paymasterAndData: "", + signature: "" + }); + + // Double-check that the parameters given actually result in a expected native token usage amount + + assertEq( + _getMaxGasCostPerUserOp(userOp), + expectedEtherValue, + "Mismatch between expect gas fee and actual gas fee" + ); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKeyPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + return userOp; + } + + function _runSessionKeyUserOp( + uint256 callGasLimit, + uint256 verificationGasLimit, + uint256 maxFeePerGas, + uint256 expectedEtherValue, + uint256 sessionKeyPrivate, + bytes memory expectedError + ) internal { + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = _generateAndSignUserOp( + callGasLimit, + verificationGasLimit, + maxFeePerGas, + expectedEtherValue, + sessionKeyPrivate, + entryPoint.getNonce(address(account1), uint192(uint160(vm.addr(sessionKeyPrivate)))) + ); + + if (expectedError.length > 0) { + vm.expectRevert(expectedError); + } + entryPoint.handleOps(userOps, beneficiary); + } + + function _wrapNonceWithAddr(uint64 nonce, address addr) internal pure returns (uint256) { + return nonce | uint256(uint160(addr)) << 64; + } +} diff --git a/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol b/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol new file mode 100644 index 00000000..fc26e473 --- /dev/null +++ b/test/plugin/session/permissions/SessionKeyNativeTokenSpendLimits.t.sol @@ -0,0 +1,674 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test, console} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {ISessionKeyPlugin} from "../../../../src/plugins/session/ISessionKeyPlugin.sol"; +import {SessionKeyPlugin} from "../../../../src/plugins/session/SessionKeyPlugin.sol"; +import {ISessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol"; +import {ISessionKeyPermissionsUpdates} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol"; +import {SessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol"; +import {IEntryPoint} from "../../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../../src/libraries/FunctionReferenceLib.sol"; +import {Call} from "../../../../src/interfaces/IStandardExecutor.sol"; + +import {MultiOwnerMSCAFactory} from "../../../../src/factory/MultiOwnerMSCAFactory.sol"; + +contract SessionKeyNativeTokenSpendLimitsTest is Test { + using ECDSA for bytes32; + + IEntryPoint entryPoint; + address payable beneficiary; + + MultiOwnerPlugin multiOwnerPlugin; + MultiOwnerMSCAFactory factory; + SessionKeyPlugin sessionKeyPlugin; + SessionKeyPermissionsPlugin sessionKeyPermissionsPlugin; + + address owner1; + uint256 owner1Key; + UpgradeableModularAccount account1; + + address sessionKey1; + uint256 sessionKey1Private; + + address recipient1; + address recipient2; + address recipient3; + + // Constants for running user ops + uint256 constant CALL_GAS_LIMIT = 300000; + uint256 constant VERIFICATION_GAS_LIMIT = 1000000; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + + vm.deal(beneficiary, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + bytes32 multiOwnerPluginManifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + factory = + new MultiOwnerMSCAFactory(address(this), address(multiOwnerPlugin), address(new UpgradeableModularAccount(entryPoint)), multiOwnerPluginManifestHash, entryPoint); + + sessionKeyPlugin = new SessionKeyPlugin(); + sessionKeyPermissionsPlugin = new SessionKeyPermissionsPlugin(); + + address[] memory owners1 = new address[](1); + owners1[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners1))); + vm.deal(address(account1), 100 ether); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(new address[](0)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + manifestHash = keccak256(abi.encode(sessionKeyPermissionsPlugin.pluginManifest())); + // Can reuse the same dependencies for this installation, because the requirements are the same. + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPermissionsPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Create and add a session key + (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); + + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = sessionKey1; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + // Register the session key with the permissions plugin + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).registerKey(sessionKey1, 0); + + // Remove the allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Create recipient addresses to receive ether + recipient1 = makeAddr("recipient1"); + recipient2 = makeAddr("recipient2"); + recipient3 = makeAddr("recipient3"); + } + + function test_sessionKeyNativeTokenSpendLimits_validateSetUp() public { + // Check that the session key is registered + assertTrue(SessionKeyPlugin(address(account1)).isSessionKey(sessionKey1)); + + // Check that the session key is registered with the permissions plugin and has its allowlist set up + // correctly + assertTrue( + sessionKeyPermissionsPlugin.getAccessControlType(address(account1), sessionKey1) + == ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE + ); + } + + function testFuzz_sessionKeyNativeTokenSpendLimits_setLimits(uint256 limit, uint48 interval, uint48 timestamp) + public + { + vm.warp(timestamp); + + // Assert that the limit starts out set, and at zero + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + + // Set the limit + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (limit, interval)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Verify the limit can be retrieved + spendLimitInfo = sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + if (limit == type(uint256).max) { + // If the limit is "set" to this value, it is just removed. + // verify that the values are still as they were before. + assertEq(spendLimitInfo.limit, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.limitUsed, 0); + } else { + // The limit is actually set, verify that the values are as expected. + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, limit); + assertEq(spendLimitInfo.refreshInterval, interval); + assertEq(spendLimitInfo.limitUsed, 0); + if (interval == 0) { + assertEq(spendLimitInfo.lastUsedTime, 0); + } else { + assertEq(spendLimitInfo.lastUsedTime, timestamp); + } + } + } + + function test_sessionKeyNativeTokenSpendLimits_enforceLimit_none() public { + // The limit starts out at zero + + // Run a user op that spends 1 wei, should fail + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + // Run a user op that spends 0 wei, should succeed + calls[0] = Call({target: recipient1, value: 0, data: "somedata"}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Run a multi-execution user op that spends 0 wei, should succeed + calls = new Call[](2); + calls[0] = Call({target: recipient1, value: 0, data: "somedata1"}); + calls[1] = Call({target: recipient2, value: 0, data: "somedata2"}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + } + + function test_sessionKeyNativeTokenSpendLimits_basic_single() public { + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + + // Run a user op that spends 1 ether, should fail + + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + } + + function test_sessionKeyNativeTokenSpendLimits_exceedLimit_single() public { + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Attempt to run an execution spending 1 ether, should fail + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + // Tests basic enforcement of spend limits when using more than one execution in a user op. + function test_sessionKeyNativeTokenSpendLimits_basic_multi() public { + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a multi execution user op spending 3 wei, should succeed + Call[] memory calls = new Call[](3); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + calls[1] = Call({target: recipient2, value: 1 wei, data: ""}); + calls[2] = Call({target: recipient3, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 3 wei); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_sessionKeyNativeTokenSpendLimits_exceedLimit_multi() public { + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Attempt to run a multi execution user op spending 1.5 ether, should fail + Call[] memory calls = new Call[](3); + calls[0] = Call({target: recipient1, value: 0.5 ether, data: ""}); + calls[1] = Call({target: recipient2, value: 0.5 ether, data: ""}); + calls[2] = Call({target: recipient3, value: 0.5 ether, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + // Assert that the limit is NOT updated + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 0); + assertEq(spendLimitInfo.refreshInterval, 0); + // Assert that the last used time is not updated when the interval is unset. + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + function test_sessionKeyNativeTokenSpendLimits_refreshInterval_single() public { + // Set the time to the current unix timestamp as of writing + uint256 time0 = 1698708080; + vm.warp(time0); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ether, should fail + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + // The execution will be valid at a later time when the interval resets, but not right now. + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA22 expired or not due") + ); + + // Assert that the limit is NOT updated + spendLimitInfo = sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, block.timestamp); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run a user op that spends 1 ether, should succeed + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_sessionKeyNativeTokenSpendLimits_refreshInterval_multi() public { + // Set the time to the current unix timestamp as of writing + uint256 time0 = 1698708080; + vm.warp(time0); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Run a user op that spends 1 ether, should fail + calls = new Call[](2); + calls[0] = Call({target: recipient1, value: 0.5 ether, data: ""}); + calls[1] = Call({target: recipient2, value: 0.5 ether, data: ""}); + + _runSessionKeyUserOp( + calls, + sessionKey1Private, + // The execution will be valid at a later time when the interval resets, but not right now. + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA22 expired or not due") + ); + + // Assert that the limit is NOT updated + spendLimitInfo = sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertTrue(spendLimitInfo.hasLimit); + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, block.timestamp); + + // warp to when the interval resets + vm.warp(time0 + 1 days); + + // Run the previous user op that spends 1 ether, should succeed + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + spendLimitInfo = sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0 + 1 days); + } + + function test_sessionKeyNativeTokenSpendLimits_basic_refreshInterval_takeMaxStartTime() public { + // Tests the behavior of the session key spending limits to return the higher starting time between the + // key's time and the spending limit's time. + + // Set the time to the current unix timestamp as of writing + uint256 time0 = 1698708080; + vm.warp(time0); + + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 1 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Run a user op that spends 1 wei, should succeed + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 wei, data: ""}); + + _runSessionKeyUserOp(calls, sessionKey1Private, ""); + + // Assert that the limit is now updated and the last used timestamp is set. + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 wei); + assertEq(spendLimitInfo.refreshInterval, 1 days); + assertEq(spendLimitInfo.lastUsedTime, time0); + + // Assert that if we try to run a user op sending 1 ether, + // then it will return the current time + the interval. + + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + UserOperation memory uo = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(uo); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + uo.signature = abi.encodePacked(r, s, v); + + vm.prank(address(entryPoint)); + uint256 result = account1.validateUserOp(uo, userOpHash, 0); + uint48 expectedStartTime = uint48(time0 + 1 days); + uint48 actualStartTime = uint48(result >> 208); + assertEq(actualStartTime, expectedStartTime); + + // Set the key's time limit to a value greater than the limit's start time. + uint256 keyStartTime = time0 + 2 days; + + bytes[] memory updates2 = new bytes[](1); + updates2[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.updateTimeRange, (uint48(keyStartTime), 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates2); + + // Assert that the later start time is returned (key time range) + vm.prank(address(entryPoint)); + result = account1.validateUserOp(uo, userOpHash, 0); + + expectedStartTime = uint48(keyStartTime); + actualStartTime = uint48(result >> 208); + assertEq(actualStartTime, expectedStartTime); + + // Set the key's time limit to a value less than the limit's start time. + keyStartTime = time0 + 12 hours; + + updates2[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.updateTimeRange, (uint48(keyStartTime), 0)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates2); + + // Assert that the later start time is returned (spend limit) + vm.prank(address(entryPoint)); + result = account1.validateUserOp(uo, userOpHash, 0); + + expectedStartTime = uint48(time0 + 1 days); + actualStartTime = uint48(result >> 208); + assertEq(actualStartTime, expectedStartTime); + } + + // This test protects against an attack vector where a staked account can submit multiple user operations to + // the same bundle, and because of the evaluation order of the user op validations and calls, can get two + // user ops to pass the spend limit validation when by the limit amount, only one should be able to execute. + // This is why the pre execution hook in the permission checker plugin re-checks the amounts being spent, and + // may revert. + function test_sessionKeyNativeTokenSpendLimits_multiUserOpBundle_check_noInterval() public { + // Set the limit to 1 ether + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (1 ether, 0 days)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Prepare a user op bundle that attempts to spend 1 ether twice. + // The second call should revert because the first call will have updated the limit. + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient1, value: 1 ether, data: ""}); + + UserOperation memory userOp1 = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp1); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + userOp1.signature = abi.encodePacked(r, s, v); + + UserOperation memory userOp2 = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 1), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + userOpHash = entryPoint.getUserOpHash(userOp2); + (v, r, s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + userOp2.signature = abi.encodePacked(r, s, v); + + // Since handleOps will succeed because nothing will revert during validation, we have to make assertions + // about world state using the recipient's balance. + assertEq(recipient1.balance, 0); + + UserOperation[] memory userOps = new UserOperation[](2); + userOps[0] = userOp1; + userOps[1] = userOp2; + + // The second one should revert during execution, via the re-check phase in pre exec hooks. + // We don't have a good way to check this from a Foundry test without almost fully reimplementing the + // EntryPoint's logic, so instead we will just assert that the call to handleOps succeeds and the + // recipient's balance is only 1 eth after the fact. + entryPoint.handleOps(userOps, beneficiary); + + assertEq(recipient1.balance, 1 ether); + + // Assert that the spend limit is maxed out now. + ISessionKeyPermissionsPlugin.SpendLimitInfo memory spendLimitInfo = + sessionKeyPermissionsPlugin.getNativeTokenSpendLimitInfo(address(account1), sessionKey1); + + assertEq(spendLimitInfo.limit, 1 ether); + assertEq(spendLimitInfo.limitUsed, 1 ether); + assertEq(spendLimitInfo.refreshInterval, 0); + assertEq(spendLimitInfo.lastUsedTime, 0); + } + + // There's an additional pre exec revert that I haven't been able to trigger in a real example, when a + // usage of a session key with a native token spend limit reaches the "new time interval" section, but the + // amount being spent exceeds the new limit. This seems to be impossible to reach because any prior usage would + // reset the last used timestamp, and any call to `setNativeTokenSpendLimit` would also reset the last used + // timestamp. I'm leaving this note here for now in case someone can find a way to trigger it. + // + // There's also the possibility that this is impossible to trigger, and the check is redundant. But I don't + // want to remove it for now, since it's a fairly cheap check that may prevent a dangerous issue. + // + // function test_sessionKeyNativeTokenSpendLimits_multiUserOpBundle_check_interval() + + function _runSessionKeyUserOp(Call[] memory calls, uint256 sessionKeyPrivate, bytes memory expectedError) + internal + { + address sessionKey = vm.addr(sessionKeyPrivate); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKeyPrivate, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + if (expectedError.length > 0) { + vm.expectRevert(expectedError); + } + entryPoint.handleOps(userOps, beneficiary); + } +} diff --git a/test/plugin/session/permissions/SessionKeyPermissionsPlugin.t.sol b/test/plugin/session/permissions/SessionKeyPermissionsPlugin.t.sol new file mode 100644 index 00000000..a7c699d7 --- /dev/null +++ b/test/plugin/session/permissions/SessionKeyPermissionsPlugin.t.sol @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../../../src/account/UpgradeableModularAccount.sol"; +import {IMultiOwnerPlugin} from "../../../../src/plugins/owner/IMultiOwnerPlugin.sol"; +import {MultiOwnerPlugin} from "../../../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {ISessionKeyPlugin} from "../../../../src/plugins/session/ISessionKeyPlugin.sol"; +import {SessionKeyPlugin} from "../../../../src/plugins/session/SessionKeyPlugin.sol"; +import {ISessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsPlugin.sol"; +import {ISessionKeyPermissionsUpdates} from + "../../../../src/plugins/session/permissions/ISessionKeyPermissionsUpdates.sol"; +import {SessionKeyPermissionsPlugin} from + "../../../../src/plugins/session/permissions/SessionKeyPermissionsPlugin.sol"; +import {IEntryPoint} from "../../../../src/interfaces/erc4337/IEntryPoint.sol"; +import {UserOperation} from "../../../../src/interfaces/erc4337/UserOperation.sol"; +import {IPluginManager} from "../../../../src/interfaces/IPluginManager.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../../src/libraries/FunctionReferenceLib.sol"; +import {Call} from "../../../../src/interfaces/IStandardExecutor.sol"; + +import {Counter} from "../../../mocks/Counter.sol"; +import {MultiOwnerMSCAFactory} from "../../../../src/factory/MultiOwnerMSCAFactory.sol"; + +contract SessionKeyPermissionsPluginTest is Test { + using ECDSA for bytes32; + + IEntryPoint entryPoint; + address payable beneficiary; + MultiOwnerPlugin multiOwnerPlugin; + MultiOwnerMSCAFactory factory; + SessionKeyPlugin sessionKeyPlugin; + SessionKeyPermissionsPlugin sessionKeyPermissionsPlugin; + + address owner1; + uint256 owner1Key; + UpgradeableModularAccount account1; + + address sessionKey1; + uint256 sessionKey1Private; + + uint256 constant CALL_GAS_LIMIT = 70000; + uint256 constant VERIFICATION_GAS_LIMIT = 1000000; + + address payable recipient; + + Counter counter1; + + Counter counter2; + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + (owner1, owner1Key) = makeAddrAndKey("owner1"); + beneficiary = payable(makeAddr("beneficiary")); + recipient = payable(makeAddr("recipient")); + + vm.deal(beneficiary, 1 wei); + vm.deal(recipient, 1 wei); + + multiOwnerPlugin = new MultiOwnerPlugin(); + address impl = address(new UpgradeableModularAccount(entryPoint)); + + factory = + new MultiOwnerMSCAFactory(address(this), address(multiOwnerPlugin), impl, keccak256(abi.encode(multiOwnerPlugin.pluginManifest())), entryPoint); + + sessionKeyPlugin = new SessionKeyPlugin(); + sessionKeyPermissionsPlugin = new SessionKeyPermissionsPlugin(); + + address[] memory owners1 = new address[](1); + owners1[0] = owner1; + account1 = UpgradeableModularAccount(payable(factory.createAccount(0, owners1))); + vm.deal(address(account1), 100 ether); + + bytes32 manifestHash = keccak256(abi.encode(sessionKeyPlugin.pluginManifest())); + FunctionReference[] memory dependencies = new FunctionReference[](2); + dependencies[0] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + dependencies[1] = FunctionReferenceLib.pack( + address(multiOwnerPlugin), uint8(IMultiOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPlugin), + manifestHash: manifestHash, + pluginInitData: abi.encode(new address[](0)), + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + manifestHash = keccak256(abi.encode(sessionKeyPermissionsPlugin.pluginManifest())); + // Can reuse the same dependencies for this installation, because the requirements are the same. + vm.prank(owner1); + account1.installPlugin({ + plugin: address(sessionKeyPermissionsPlugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: dependencies, + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + + // Create and add a session key + (sessionKey1, sessionKey1Private) = makeAddrAndKey("sessionKey1"); + + address[] memory sessionKeysToAdd = new address[](1); + sessionKeysToAdd[0] = sessionKey1; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + sessionKeysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + // Register the session key with the permissions plugin + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).registerKey(sessionKey1, 0); + + // Initialize the interaction targets + + counter1 = new Counter(); + counter1.increment(); + + counter2 = new Counter(); + counter2.increment(); + } + + function test_sessionPerms_validateSetUp() public { + assertEq( + uint8(sessionKeyPermissionsPlugin.getAccessControlType(address(account1), sessionKey1)), + uint8(ISessionKeyPermissionsPlugin.ContractAccessControlType.ALLOWLIST) + ); + } + + function test_sessionPerms_contractDefaultAllowList() public { + _runSessionKeyExecUserOp( + address(counter1), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.increment, ()), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + // Call should fail before removing the allowlist + assertEq(counter1.number(), 1); + + // Remove the allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Call should succeed after removing the allowlist + _runSessionKeyExecUserOp( + address(counter1), sessionKey1, sessionKey1Private, abi.encodeCall(Counter.increment, ()), 0 wei, "" + ); + + assertEq(counter1.number(), 2); + } + + function test_sessionPerms_contractAllowList() public { + // Add the allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListAddressEntry, (address(counter1), true, false) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Call should succeed after adding the allowlist + _runSessionKeyExecUserOp( + address(counter1), sessionKey1, sessionKey1Private, abi.encodeCall(Counter.increment, ()), 0 wei, "" + ); + + assertEq(counter1.number(), 2); + + // // Call should fail for contract not on allowlist + _runSessionKeyExecUserOp( + address(counter2), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.increment, ()), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + assertEq(counter2.number(), 1); + } + + function test_sessionPerms_contractDenyList() public { + // Add the denylist + bytes[] memory updates = new bytes[](2); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.DENYLIST) + ); + updates[1] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListAddressEntry, (address(counter1), true, false) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Check that the call should fail after adding the denylist + _runSessionKeyExecUserOp( + address(counter1), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.increment, ()), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + assertEq(counter1.number(), 1); + + // Call should suceed for contract not on denylist + _runSessionKeyExecUserOp( + address(counter2), sessionKey1, sessionKey1Private, abi.encodeCall(Counter.increment, ()), 0 wei, "" + ); + + assertEq(counter2.number(), 2); + } + + function test_sessionPerms_selectorAllowList() public { + // Add the allowlist + bytes[] memory updates = new bytes[](2); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListAddressEntry, (address(counter1), true, true) + ); + updates[1] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListFunctionEntry, + (address(counter1), Counter.increment.selector, true) + ); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Call should succeed after adding the allowlist + _runSessionKeyExecUserOp( + address(counter1), sessionKey1, sessionKey1Private, abi.encodeCall(Counter.increment, ()), 0 wei, "" + ); + + assertEq(counter1.number(), 2); + + // Call should fail for function not on allowlist + _runSessionKeyExecUserOp( + address(counter1), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.setNumber, (5)), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + assertEq(counter1.number(), 2); + } + + function test_sessionPerms_selectorDenyList() public { + // Add the denylist + bytes[] memory updates = new bytes[](3); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.DENYLIST) + ); + updates[1] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListAddressEntry, (address(counter1), true, true) + ); + updates[2] = abi.encodeCall( + ISessionKeyPermissionsUpdates.updateAccessListFunctionEntry, + (address(counter1), Counter.increment.selector, true) + ); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Call should fail after adding the denylist + _runSessionKeyExecUserOp( + address(counter1), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.increment, ()), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + assertEq(counter1.number(), 1); + + // Call should succeed for function not on denylist + _runSessionKeyExecUserOp( + address(counter1), sessionKey1, sessionKey1Private, abi.encodeCall(Counter.setNumber, (5)), 0 wei, "" + ); + + assertEq(counter1.number(), 5); + } + + function testFuzz_sessionKeyTimeRange(uint48 startTime, uint48 endTime) public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.updateTimeRange, (startTime, endTime)); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: address(0), value: 0, data: ""}); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.prank(address(entryPoint)); + uint256 validationData = account1.validateUserOp(userOp, userOpHash, 0); + + // Assert the correct time range fields are returned + // Only check the end time field if it wasn't zero, which is interpretted as a max value by 4337. + if (endTime != 0) { + assertEq(uint48(validationData >> 160), endTime); + } + assertEq(uint48(validationData >> 208), startTime); + } + + function test_rotateKey_basic() public { + // Remove the default allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + (address sessionKey2, uint256 sessionKey2Private) = makeAddrAndKey("sessionKey2"); + + // Add the session key to the account + address[] memory keysToAdd = new address[](1); + keysToAdd[0] = sessionKey2; + + vm.prank(owner1); + SessionKeyPlugin(address(account1)).updateSessionKeys( + keysToAdd, new SessionKeyPlugin.SessionKeyToRemove[](0) + ); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).rotateKey(sessionKey1, sessionKey2); + + // Attempting to use the previous key should fail + _runSessionKeyExecUserOp( + address(counter1), + sessionKey1, + sessionKey1Private, + abi.encodeCall(Counter.increment, ()), + 0 wei, + abi.encodeWithSelector(IEntryPoint.FailedOp.selector, 0, "AA24 signature error") + ); + + // Attempting to use the new key should succeed + _runSessionKeyExecUserOp( + address(counter1), sessionKey2, sessionKey2Private, abi.encodeCall(Counter.increment, ()), 0 wei, "" + ); + } + + function test_rotateKey_permissionsTransfer() public { + // Set a time range on the key + uint48 startTime = uint48(block.timestamp); + uint48 endTime = uint48(block.timestamp + 1000); + + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.updateTimeRange, (startTime, endTime)); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Rotate the key + address sessionKey2 = makeAddr("sessionKey2"); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).rotateKey(sessionKey1, sessionKey2); + + // Check the rotated key's time range + (uint48 returnedStartTime, uint48 returnedEndTime) = + sessionKeyPermissionsPlugin.getKeyTimeRange(address(account1), sessionKey2); + + assertEq(returnedStartTime, startTime); + assertEq(returnedEndTime, endTime); + } + + function testFuzz_sessionKeyPermissions_setRequiredPaymaster(address requiredPaymaster) public { + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setRequiredPaymaster, (requiredPaymaster)); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Check the required paymaster + address returnedRequiredPaymaster = + sessionKeyPermissionsPlugin.getRequiredPaymaster(address(account1), sessionKey1); + assertEq(returnedRequiredPaymaster, requiredPaymaster); + + // Set the required paymaster to zero + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setRequiredPaymaster, (address(0))); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // Check the required paymaster + returnedRequiredPaymaster = + sessionKeyPermissionsPlugin.getRequiredPaymaster(address(account1), sessionKey1); + assertEq(returnedRequiredPaymaster, address(0)); + } + + function testFuzz_sessionKeyPermissions_checkRequiredPaymaster( + address requiredPaymaster, + address providedPaymaster + ) public { + // Disable the allowlist and disable native token spend checking. + bytes[] memory updates = new bytes[](2); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + updates[1] = abi.encodeCall(ISessionKeyPermissionsUpdates.setNativeTokenSpendLimit, (type(uint256).max, 0)); + + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + vm.assume(providedPaymaster != address(0)); + + // First validate a user op with the paymaster set, without the required paymaster rule. + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 1 wei, data: ""}); + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: abi.encodePacked(providedPaymaster), + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.prank(address(entryPoint)); + uint256 validationData = account1.validateUserOp(userOp, userOpHash, 0); + + // Assert that validation passes + assertEq(uint160(validationData), 0); + + // Now set the required paymaster rule and validate again. + bytes[] memory updates2 = new bytes[](1); + updates2[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setRequiredPaymaster, (requiredPaymaster)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates2); + + vm.prank(address(entryPoint)); + validationData = account1.validateUserOp(userOp, userOpHash, 0); + + if (requiredPaymaster == providedPaymaster || requiredPaymaster == address(0)) { + // Assert that validation passes + assertEq(uint160(validationData), 0); + } else { + // Assert that validation fails + assertEq(uint160(validationData), 1); + } + } + + function test_sessionKeyPerms_requiredPaymaster_partialAddressFails() public { + // Disable the allowlist + bytes[] memory updates = new bytes[](1); + updates[0] = abi.encodeCall( + ISessionKeyPermissionsUpdates.setAccessListType, + (ISessionKeyPermissionsPlugin.ContractAccessControlType.NONE) + ); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + // create a paymaster address that would match, if right-padded with zeroes + address paymasterAddr = 0x1234123412341234000000000000000000000000; + // Add it + updates[0] = abi.encodeCall(ISessionKeyPermissionsUpdates.setRequiredPaymaster, (paymasterAddr)); + vm.prank(owner1); + SessionKeyPermissionsPlugin(address(account1)).updateKeyPermissions(sessionKey1, updates); + + Call[] memory calls = new Call[](1); + calls[0] = Call({target: recipient, value: 1 wei, data: ""}); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey1)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: abi.encodePacked(uint64(0x1234123412341234)), + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(sessionKey1Private, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + vm.expectRevert("AA93 invalid paymasterAndData"); + entryPoint.handleOps(userOps, beneficiary); + } + + function _runSessionKeyExecUserOp( + address target, + address sessionKey, + uint256 privateKey, + bytes memory callData, + uint256 value, + bytes memory revertReason + ) internal { + Call[] memory calls = new Call[](1); + calls[0] = Call({target: target, value: value, data: callData}); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: entryPoint.getNonce(address(account1), 0), + initCode: "", + callData: abi.encodeCall(ISessionKeyPlugin.executeWithSessionKey, (calls, sessionKey)), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(privateKey, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + if (revertReason.length > 0) { + vm.expectRevert(revertReason); + } + entryPoint.handleOps(userOps, beneficiary); + } +} diff --git a/test/upgrade/LightAccountToMSCA.t.sol b/test/upgrade/LightAccountToMSCA.t.sol new file mode 100644 index 00000000..2e7b3ce7 --- /dev/null +++ b/test/upgrade/LightAccountToMSCA.t.sol @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {LightAccount} from "@alchemy/light-account/src/LightAccount.sol"; +import {LightAccountFactory} from "@alchemy/light-account/src/LightAccountFactory.sol"; +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {IEntryPoint as IMSCAEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; + +import {MockERC20} from "../mocks/tokens/MockERC20.sol"; + +contract LightAccountToMSCATest is Test { + IEntryPoint public entryPoint; + IMSCAEntryPoint public mscaEntryPoint; + + MockERC20 public token1; + + address public owner; + address[] public owners; + LightAccount public lightAccount; + + MultiOwnerPlugin public multiOwnerPlugin; + address public mscaImpl; + + event ModularAccountInitialized(IMSCAEntryPoint indexed entryPoint); + + function setUp() public { + entryPoint = IEntryPoint(address(new EntryPoint())); + mscaEntryPoint = IMSCAEntryPoint(address(entryPoint)); + (owner,) = makeAddrAndKey("owner"); + + // set up light account + LightAccountFactory lightAccountFactory = new LightAccountFactory(entryPoint); + lightAccount = lightAccountFactory.createAccount(owner, 1); + vm.deal(address(lightAccount), 2 ether); + + // setup mock tokens + token1 = new MockERC20("T1"); + token1.mint(address(lightAccount), 1 ether); + + // setup MSCA + multiOwnerPlugin = new MultiOwnerPlugin(); + mscaImpl = address(new UpgradeableModularAccount(mscaEntryPoint)); + } + + function test_verifySetup() public { + assertEq(token1.balanceOf(address(lightAccount)), 1 ether); + assertEq(token1.balanceOf(owner), 0 ether); + + address[] memory returnedOwners = multiOwnerPlugin.ownersOf(address(lightAccount)); + assertEq(returnedOwners, new address[](0)); + assertEq(payable(lightAccount).balance, 2 ether); + assertEq(payable(owner).balance, 0); + } + + function test_upgrade() public { + // setup data for msca upgrade + owners = new address[](1); + owners[0] = owner; + address[] memory plugins = new address[](1); + plugins[0] = address(multiOwnerPlugin); + bytes32[] memory manifestHashes = new bytes32[](1); + manifestHashes[0] = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + bytes[] memory pluginInitBytes = new bytes[](1); + pluginInitBytes[0] = abi.encode(owners); + + // upgrade to msca + vm.startPrank(owner); + vm.expectEmit(true, true, true, true); + emit ModularAccountInitialized(mscaEntryPoint); + lightAccount.upgradeToAndCall( + mscaImpl, + abi.encodeCall( + UpgradeableModularAccount.initialize, (plugins, abi.encode(manifestHashes, pluginInitBytes)) + ) + ); + + // verify upgrade success + address[] memory returnedOwners = multiOwnerPlugin.ownersOf(address(lightAccount)); + assertEq(returnedOwners, owners); + assertEq(token1.balanceOf(address(lightAccount)), 1 ether); + + // verify can do basic transaction + lightAccount.execute(owner, 1 ether, ""); + assertEq(payable(lightAccount).balance, 1 ether); + assertEq(payable(owner).balance, 1 ether); + } +} diff --git a/test/upgrade/MSCAToMSCA.t.sol b/test/upgrade/MSCAToMSCA.t.sol new file mode 100644 index 00000000..aa3cc51f --- /dev/null +++ b/test/upgrade/MSCAToMSCA.t.sol @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.21; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; + +import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; +import {MultiOwnerTokenReceiverMSCAFactory} from "../../src/factory/MultiOwnerTokenReceiverMSCAFactory.sol"; +import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; +import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; +import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; + +import {Utils} from "../Utils.sol"; +import {MockERC20} from "../mocks/tokens/MockERC20.sol"; + +contract MSCAToMSCATest is Test { + IEntryPoint public entryPoint; + + MockERC20 public token1; + + address[] public owners; + UpgradeableModularAccount public msca; + + MultiOwnerPlugin public multiOwnerPlugin; + TokenReceiverPlugin public tokenReceiverPlugin; + address public mscaImpl1; + address public mscaImpl2; + + event Upgraded(address indexed implementation); + + function setUp() public { + owners.push(makeAddr("owner1")); + owners.push(makeAddr("owner2")); + entryPoint = IEntryPoint(address(new EntryPoint())); + mscaImpl1 = address(new UpgradeableModularAccount(entryPoint)); + mscaImpl2 = address(new UpgradeableModularAccount(entryPoint)); + multiOwnerPlugin = new MultiOwnerPlugin(); + tokenReceiverPlugin = new TokenReceiverPlugin(); + bytes32 ownerManifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); + bytes32 tokenReceiverManifestHash = keccak256(abi.encode(tokenReceiverPlugin.pluginManifest())); + MultiOwnerTokenReceiverMSCAFactory factory = new MultiOwnerTokenReceiverMSCAFactory( + address(this), + address(multiOwnerPlugin), + address(tokenReceiverPlugin), + mscaImpl1, + ownerManifestHash, + tokenReceiverManifestHash, + entryPoint + ); + msca = UpgradeableModularAccount(payable(factory.createAccount(0, owners))); + vm.deal(address(msca), 2 ether); + + // setup mock tokens + token1 = new MockERC20("T1"); + token1.mint(address(msca), 1 ether); + } + + function test_sameStorageSlot_upgradeToAndCall() public { + vm.startPrank(owners[0]); + + // upgrade to mscaImpl2 + vm.expectEmit(true, true, true, true); + emit Upgraded(mscaImpl2); + msca.upgradeToAndCall(mscaImpl2, ""); + + // verify account storage is the same + (, bytes memory returnData) = address(msca).call(abi.encodeWithSelector(MultiOwnerPlugin.owners.selector)); + address[] memory returnedOwners = abi.decode(returnData, (address[])); + assertEq(Utils.reverseAddressArray(returnedOwners), owners); + assertEq(token1.balanceOf(address(msca)), 1 ether); + + // verify can do basic transaction + msca.execute(owners[0], 1 ether, ""); + assertEq(payable(msca).balance, 1 ether); + assertEq(payable(owners[0]).balance, 1 ether); + + vm.stopPrank(); + } +} diff --git a/utils/inspect.sh b/utils/inspect.sh new file mode 100644 index 00000000..5f3e862c --- /dev/null +++ b/utils/inspect.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Generate inspection files in MD format for all primary contracts in src. + +CONTRACT_FILES=($(find src -iname '*.sol' | sort)) + +rm .storagelayout.md +rm .gasestimates.md + +echo "# Storage Layouts" >> .storagelayout.md +echo "Generated via \`bash utils/inspect.sh\`." >> .storagelayout.md +echo "" >> .storagelayout.md +echo "---" >> .storagelayout.md +echo "" >> .storagelayout.md + +echo "# Gas Estimates" >> .gasestimates.md +echo "Generated via \`bash utils/inspect.sh\`." >> .gasestimates.md +echo "" >> .gasestimates.md +echo "---" >> .gasestimates.md +echo "" >> .gasestimates.md +echo "\`forge test --gas-report --no-match-path \"test/invariant/**/*\"\`" >> .gasestimates.md +# Sed strings to strip color data and only start printing after the first '|' character, to exclude previous contents (compilation, test results, etc) +forge test --gas-report --no-match-path "test/invariant/**/*" | sed -r "s/\x1B\[([0-9]{1,3}(;[0-9]{1,2};?)?)?[mGK]//g" | sed -nr '/\|/,$p' >> .gasestimates.md + +for index in ${!CONTRACT_FILES[*]}; do + # echo "${CONTRACT_NAMES[$index]} is in ${CONTRACT_FILES[$index]}" + CONTRACT_NAME=$(basename -s ".sol" ${CONTRACT_FILES[${index}]}) + # echo ${CONTRACT_NAME} + # If file does not contain a contract named the same as the filename, discard from inspection (e.g. libraries). + if ! grep -q "contract ${CONTRACT_NAME}" ${CONTRACT_FILES[$index]}; then + # echo "Skipping ${CONTRACT_NAME}" + continue + fi + + # Show command names in files + echo "\`forge inspect --pretty ${CONTRACT_FILES[$index]}:${CONTRACT_NAME} storage-layout\`" >> .storagelayout.md + forge inspect --pretty ${CONTRACT_FILES[$index]}:${CONTRACT_NAME} storage-layout >> .storagelayout.md + echo "" >> .storagelayout.md + + # echo "\`forge inspect ${CONTRACT_FILES[$index]}:${CONTRACT_NAME} gasestimates\`" >> .gasestimates.md + # echo "\`\`\`json" >> .gasestimates.md + # forge inspect ${CONTRACT_FILES[$index]}:${CONTRACT_NAME} gasestimates >> .gasestimates.md + # echo "\`\`\`" >> .gasestimates.md + # echo "" >> .gasestimates.md + +done