diff --git a/test/account/AccountExecHooks.t.sol b/test/account/AccountExecHooks.t.sol index 132f2597f..cdbaf18f2 100644 --- a/test/account/AccountExecHooks.t.sol +++ b/test/account/AccountExecHooks.t.sol @@ -176,6 +176,26 @@ contract UpgradeableModularAccountExecHooksTest is Test { vm.stopPrank(); } + function test_preExecHook_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithHooks( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }), + ManifestFunction({functionType: ManifestAssociatedFunctionType.NONE, functionId: 0, dependencyIndex: 0}) + ); + + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq(returnData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + + vm.stopPrank(); + } + /// @dev Plugin 1 hook pair: [1, 2] /// Expected execution: [1, 2] function test_execHookPair_install() public { diff --git a/test/account/AccountPreValidationHooks.t.sol b/test/account/AccountPreValidationHooks.t.sol index 611ed3c39..1f4488ef4 100644 --- a/test/account/AccountPreValidationHooks.t.sol +++ b/test/account/AccountPreValidationHooks.t.sol @@ -181,6 +181,25 @@ contract UpgradeableModularAccountPreValidationHooksTest is Test { vm.stopPrank(); } + function test_preRuntimeValidationHooks_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithPreRuntimeValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + ); + + (bool success, bytes memory returnData) = address(account1).call(abi.encodeWithSelector(_EXEC_SELECTOR)); + assertFalse(success); + assertEq(returnData, abi.encodeWithSelector(UpgradeableModularAccount.AlwaysDenyRule.selector)); + + vm.stopPrank(); + } + /// @dev Plugin 1 hook: 1 /// Plugin 2 hook: 2 /// Expected execution: [1, 2] @@ -438,6 +457,43 @@ contract UpgradeableModularAccountPreValidationHooksTest is Test { vm.stopPrank(); } + function test_preUserOpValidationHooks_revertAlwaysDeny() public { + vm.startPrank(owner1); + + _installPlugin1WithPreUserOpValidationHook( + _EXEC_SELECTOR, + ManifestFunction({ + functionType: ManifestAssociatedFunctionType.PRE_HOOK_ALWAYS_DENY, + functionId: 0, + dependencyIndex: 0 + }) + ); + + vm.stopPrank(); + + UserOperation memory userOp = UserOperation({ + sender: address(account1), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(_EXEC_SELECTOR), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert(UpgradeableModularAccount.AlwaysDenyRule.selector); + account1.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + /// @dev Plugin 1 hook: 1 /// Plugin 2 hook: 1 /// Expected execution: [1, 1] diff --git a/test/account/UpgradeableModularAccount.t.sol b/test/account/UpgradeableModularAccount.t.sol index 86d0cd0bd..799aa4cd2 100644 --- a/test/account/UpgradeableModularAccount.t.sol +++ b/test/account/UpgradeableModularAccount.t.sol @@ -3,10 +3,12 @@ pragma solidity ^0.8.21; import {Test, console} from "forge-std/Test.sol"; +import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; import {AccountExecutor} from "../../src/account/AccountExecutor.sol"; +import {PluginManagerInternals} from "../../src/account/PluginManagerInternals.sol"; import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAccount.sol"; import {IMultiOwnerPlugin} from "../../src/plugins/owner/IMultiOwnerPlugin.sol"; import {MultiOwnerPlugin} from "../../src/plugins/owner/MultiOwnerPlugin.sol"; @@ -14,14 +16,15 @@ import {SessionKeyPlugin} from "../../src/plugins/session/SessionKeyPlugin.sol"; import {TokenReceiverPlugin} from "../../src/plugins/TokenReceiverPlugin.sol"; import {IEntryPoint} from "../../src/interfaces/erc4337/IEntryPoint.sol"; import {UserOperation} from "../../src/interfaces/erc4337/UserOperation.sol"; -import {PluginManifest} from "../../src/interfaces/IPlugin.sol"; +import {IAccountInitializable} from "../../src/interfaces/IAccountInitializable.sol"; import {IPlugin, PluginManifest} from "../../src/interfaces/IPlugin.sol"; import {IPluginManager} from "../../src/interfaces/IPluginManager.sol"; import {Call} from "../../src/interfaces/IStandardExecutor.sol"; -import {PluginManifest} from "../../src/interfaces/IPlugin.sol"; +import {FunctionReference} from "../../src/libraries/FunctionReferenceLib.sol"; import {Counter} from "../mocks/Counter.sol"; import {MultiOwnerMSCAFactory} from "../../src/factory/MultiOwnerMSCAFactory.sol"; +import {MockPlugin} from "../mocks/MockPlugin.sol"; contract UpgradeableModularAccountTest is Test { using ECDSA for bytes32; @@ -32,6 +35,7 @@ contract UpgradeableModularAccountTest is Test { TokenReceiverPlugin public tokenReceiverPlugin; SessionKeyPlugin public sessionKeyPlugin; MultiOwnerMSCAFactory public factory; + address public accountImplementation; address public owner1; uint256 public owner1Key; @@ -65,10 +69,10 @@ contract UpgradeableModularAccountTest is Test { multiOwnerPlugin = new MultiOwnerPlugin(); tokenReceiverPlugin = new TokenReceiverPlugin(); sessionKeyPlugin = new SessionKeyPlugin(); - address implementation = address(new UpgradeableModularAccount(entryPoint)); + accountImplementation = address(new UpgradeableModularAccount(entryPoint)); bytes32 manifestHash = keccak256(abi.encode(multiOwnerPlugin.pluginManifest())); factory = new MultiOwnerMSCAFactory( - address(this), address(multiOwnerPlugin), implementation, manifestHash, entryPoint + address(this), address(multiOwnerPlugin), accountImplementation, manifestHash, entryPoint ); // Compute counterfactual address @@ -94,6 +98,18 @@ contract UpgradeableModularAccountTest is Test { factory.createAccount(0, owners1); } + function test_initialize_revertArrayLengthMismatch() public { + ERC1967Proxy account = new ERC1967Proxy{salt : ""}(accountImplementation, ""); + address[] memory plugins = new address[](2); + bytes memory pluginInitData = abi.encode(new bytes32[](1), new bytes[](1)); + vm.expectRevert(PluginManagerInternals.ArrayLengthMismatch.selector); + IAccountInitializable(address(account)).initialize(plugins, pluginInitData); + + pluginInitData = abi.encode(new bytes32[](2), new bytes[](1)); + vm.expectRevert(PluginManagerInternals.ArrayLengthMismatch.selector); + IAccountInitializable(address(account)).initialize(plugins, pluginInitData); + } + function test_basicUserOp() public { address[] memory owners = new address[](1); owners[0] = owner2; @@ -395,6 +411,106 @@ contract UpgradeableModularAccountTest is Test { assertEq(UpgradeableModularAccount(payable(account1)).getNonce(), 1); } + function test_validateUserOp_revertNotFromEntryPoint() public { + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeCall(UpgradeableModularAccount.execute, (ethRecipient, 1 wei, "")), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.expectRevert(UpgradeableModularAccount.UserOpNotFromEntryPoint.selector); + account2.validateUserOp(userOp, userOpHash, 0); + } + + function test_validateUserOp_revertUnrecognizedFunction() public { + // Invalid calldata of length 2. + bytes memory callData = hex"12"; + + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: callData, + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.UnrecognizedFunction.selector, bytes4(callData)) + ); + account2.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + + function test_validateUserOp_revertFunctionMissing() public { + PluginManifest memory m; + m.executionFunctions = new bytes4[](1); + bytes4 fooSelector = bytes4(keccak256("foo()")); + m.executionFunctions[0] = fooSelector; + MockPlugin plugin = new MockPlugin(m); + bytes32 manifestHash = keccak256(abi.encode(plugin.pluginManifest())); + + vm.startPrank(owner2); + // Install a plugin with execution function foo() that does not have an associated user op validation + // function. + IPluginManager(account2).installPlugin({ + plugin: address(plugin), + manifestHash: manifestHash, + pluginInitData: "", + dependencies: new FunctionReference[](0), + injectedHooks: new IPluginManager.InjectedHook[](0) + }); + vm.stopPrank(); + + UserOperation memory userOp = UserOperation({ + sender: address(account2), + nonce: 0, + initCode: "", + callData: abi.encodeWithSelector(fooSelector), + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 1, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner2Key, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + vm.startPrank(address(entryPoint)); + vm.expectRevert( + abi.encodeWithSelector(UpgradeableModularAccount.UserOpValidationFunctionMissing.selector, fooSelector) + ); + account2.validateUserOp(userOp, userOpHash, 0); + vm.stopPrank(); + } + // Internal Functions function _printStorageReadsAndWrites(address addr) internal {