Skip to content

Commit

Permalink
feat: refactor execution install functions into external lib
Browse files Browse the repository at this point in the history
  • Loading branch information
Zer0dot committed Oct 11, 2024
1 parent e8ce17a commit c8f1ff5
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 276 deletions.
6 changes: 4 additions & 2 deletions src/account/ModularAccountBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import {ModularAccountView} from "./ModularAccountView.sol";
import {ModuleManagerInternals} from "./ModuleManagerInternals.sol";
import {TokenReceiver} from "./TokenReceiver.sol";

import {ExecutionInstallLib} from "../libraries/ExecutionInstallLib.sol";

abstract contract ModularAccountBase is
IModularAccount,
ModularAccountView,
Expand Down Expand Up @@ -281,7 +283,7 @@ abstract contract ModularAccountBase is
ExecutionManifest calldata manifest,
bytes calldata moduleInstallData
) external override wrapNativeFunction {
_installExecution(module, manifest, moduleInstallData);
ExecutionInstallLib.installExecution(module, manifest, moduleInstallData);
}

/// @inheritdoc IModularAccount
Expand All @@ -291,7 +293,7 @@ abstract contract ModularAccountBase is
ExecutionManifest calldata manifest,
bytes calldata moduleUninstallData
) external override wrapNativeFunction {
_uninstallExecution(module, manifest, moduleUninstallData);
ExecutionInstallLib.uninstallExecution(module, manifest, moduleUninstallData);
}

/// @inheritdoc IModularAccount
Expand Down
194 changes: 11 additions & 183 deletions src/account/ModuleManagerInternals.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,23 @@
pragma solidity ^0.8.26;

import {IExecutionHookModule} from "@erc6900/reference-implementation/interfaces/IExecutionHookModule.sol";
import {
ExecutionManifest,
ManifestExecutionHook
} from "@erc6900/reference-implementation/interfaces/IExecutionModule.sol";
import {
HookConfig,
IModularAccount,
ModuleEntity,
ValidationConfig
} from "@erc6900/reference-implementation/interfaces/IModularAccount.sol";
import {IModule} from "@erc6900/reference-implementation/interfaces/IModule.sol";
import {IValidationHookModule} from "@erc6900/reference-implementation/interfaces/IValidationHookModule.sol";
import {IValidationModule} from "@erc6900/reference-implementation/interfaces/IValidationModule.sol";
import {HookConfigLib} from "@erc6900/reference-implementation/libraries/HookConfigLib.sol";
import {ModuleEntityLib} from "@erc6900/reference-implementation/libraries/ModuleEntityLib.sol";
import {ValidationConfigLib} from "@erc6900/reference-implementation/libraries/ValidationConfigLib.sol";
import {ERC165Checker} from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol";

import {MAX_PRE_VALIDATION_HOOKS} from "../helpers/Constants.sol";
import {ExecutionLib} from "../libraries/ExecutionLib.sol";
import {KnownSelectorsLib} from "../libraries/KnownSelectorsLib.sol";
import {ExecutionInstallLib} from "../libraries/ExecutionInstallLib.sol";
import {LinkedListSet, LinkedListSetLib} from "../libraries/LinkedListSetLib.sol";
import {MemManagementLib} from "../libraries/MemManagementLib.sol";
import {AccountStorage, ExecutionData, ValidationData, getAccountStorage, toSetValue} from "./AccountStorage.sol";
import {ValidationData, getAccountStorage, toSetValue} from "./AccountStorage.sol";

abstract contract ModuleManagerInternals is IModularAccount {
using LinkedListSetLib for LinkedListSet;
Expand All @@ -38,7 +31,6 @@ abstract contract ModuleManagerInternals is IModularAccount {
error ExecutionFunctionAlreadySet(bytes4 selector);
error IModuleFunctionNotAllowed(bytes4 selector);
error InterfaceNotSupported(address module);
error NativeFunctionNotAllowed(bytes4 selector);
error NullModule();
error ExecutionHookAlreadySet(HookConfig hookConfig);
error ModuleInstallCallbackFailed(address module, bytes revertReason);
Expand All @@ -49,51 +41,6 @@ abstract contract ModuleManagerInternals is IModularAccount {

// Storage update operations

function _setExecutionFunction(
bytes4 selector,
bool skipRuntimeValidation,
bool allowGlobalValidation,
address module
) internal {
ExecutionData storage _executionData = getAccountStorage().executionData[selector];

if (_executionData.module != address(0)) {
revert ExecutionFunctionAlreadySet(selector);
}

// Make sure incoming execution function does not collide with any native functions (data are stored on the
// account implementation contract)
if (_isNativeFunction(selector)) {
revert NativeFunctionNotAllowed(selector);
}

// Make sure incoming execution function is not a function in IModule
if (KnownSelectorsLib.isIModuleFunction(selector)) {
revert IModuleFunctionNotAllowed(selector);
}

// Also make sure it doesn't collide with functions defined by ERC-4337
// and called by the entry point. This prevents a malicious module from
// sneaking in a function with the same selector as e.g.
// `validatePaymasterUserOp` and turning the account into their own
// personal paymaster.
if (KnownSelectorsLib.isErc4337Function(selector)) {
revert Erc4337FunctionNotAllowed(selector);
}

_executionData.module = module;
_executionData.skipRuntimeValidation = skipRuntimeValidation;
_executionData.allowGlobalValidation = allowGlobalValidation;
}

function _removeExecutionFunction(bytes4 selector) internal {
ExecutionData storage _executionData = getAccountStorage().executionData[selector];

_executionData.module = address(0);
_executionData.skipRuntimeValidation = false;
_executionData.allowGlobalValidation = false;
}

function _removeValidationFunction(ModuleEntity validationFunction) internal {
ValidationData storage _validationData = getAccountStorage().validationData[validationFunction];

Expand All @@ -102,123 +49,6 @@ abstract contract ModuleManagerInternals is IModularAccount {
_validationData.isUserOpValidation = false;
}

function _addExecHooks(LinkedListSet storage hooks, HookConfig hookConfig) internal {
if (!hooks.tryAdd(toSetValue(hookConfig))) {
revert ExecutionHookAlreadySet(hookConfig);
}
}

function _removeExecHooks(LinkedListSet storage hooks, HookConfig hookConfig) internal {
// Todo: use predecessor
hooks.tryRemove(toSetValue(hookConfig));
}

function _installExecution(
address module,
ExecutionManifest calldata manifest,
bytes calldata moduleInstallData
) internal {
AccountStorage storage _storage = getAccountStorage();

if (module == address(0)) {
revert NullModule();
}

// Update components according to the manifest.
uint256 length = manifest.executionFunctions.length;
for (uint256 i = 0; i < length; ++i) {
bytes4 selector = manifest.executionFunctions[i].executionSelector;
bool skipRuntimeValidation = manifest.executionFunctions[i].skipRuntimeValidation;
bool allowGlobalValidation = manifest.executionFunctions[i].allowGlobalValidation;
_setExecutionFunction(selector, skipRuntimeValidation, allowGlobalValidation, module);
}

length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
ExecutionData storage executionData = _storage.executionData[mh.executionSelector];
HookConfig hookConfig = HookConfigLib.packExecHook({
_module: module,
_entityId: mh.entityId,
_hasPre: mh.isPreHook,
_hasPost: mh.isPostHook
});
_addExecHooks(executionData.executionHooks, hookConfig);
}

length = manifest.interfaceIds.length;
for (uint256 i = 0; i < length; ++i) {
_storage.supportedIfaces[manifest.interfaceIds[i]] += 1;
}

_onInstall(module, moduleInstallData, type(IModule).interfaceId);

emit ExecutionInstalled(module, manifest);
}

function _uninstallExecution(address module, ExecutionManifest calldata manifest, bytes calldata uninstallData)
internal
{
AccountStorage storage _storage = getAccountStorage();

// Remove components according to the manifest, in reverse order (by component type) of their installation.

uint256 length = manifest.executionHooks.length;
for (uint256 i = 0; i < length; ++i) {
ManifestExecutionHook memory mh = manifest.executionHooks[i];
ExecutionData storage execData = _storage.executionData[mh.executionSelector];
HookConfig hookConfig = HookConfigLib.packExecHook({
_module: module,
_entityId: mh.entityId,
_hasPre: mh.isPreHook,
_hasPost: mh.isPostHook
});
_removeExecHooks(execData.executionHooks, hookConfig);
}

length = manifest.executionFunctions.length;
for (uint256 i = 0; i < length; ++i) {
bytes4 selector = manifest.executionFunctions[i].executionSelector;
_removeExecutionFunction(selector);
}

length = manifest.interfaceIds.length;
for (uint256 i = 0; i < length; ++i) {
_storage.supportedIfaces[manifest.interfaceIds[i]] -= 1;
}

// Clear the module storage for the account.
bool onUninstallSuccess = _onUninstall(module, uninstallData);

emit ExecutionUninstalled(module, onUninstallSuccess, manifest);
}

function _onInstall(address module, bytes calldata data, bytes4 interfaceId) internal {
if (data.length > 0) {
if (!ERC165Checker.supportsERC165InterfaceUnchecked(module, interfaceId)) {
revert InterfaceNotSupported(module);
}
// solhint-disable-next-line no-empty-blocks
try IModule(module).onInstall(data) {}
catch {
bytes memory revertReason = ExecutionLib.collectReturnData();
revert ModuleInstallCallbackFailed(module, revertReason);
}
}
}

function _onUninstall(address module, bytes calldata data) internal returns (bool onUninstallSuccess) {
onUninstallSuccess = true;
if (data.length > 0) {
// Clear the module storage for the account.
// solhint-disable-next-line no-empty-blocks
try IModule(module).onUninstall(data) {}
catch {
onUninstallSuccess = false;
}
}
}

function _installValidation(
ValidationConfig validationConfig,
bytes4[] calldata selectors,
Expand All @@ -244,15 +74,17 @@ abstract contract ModuleManagerInternals is IModularAccount {
revert PreValidationHookDuplicate();
}

_onInstall(hookConfig.module(), hookData, type(IValidationHookModule).interfaceId);
ExecutionInstallLib.onInstall(
hookConfig.module(), hookData, type(IValidationHookModule).interfaceId
);

continue;
}
// Hook is an execution hook
_validationData.executionHookCount += 1;
_addExecHooks(_validationData.executionHooks, hookConfig);
ExecutionInstallLib.addExecHooks(_validationData.executionHooks, hookConfig);

_onInstall(hookConfig.module(), hookData, type(IExecutionHookModule).interfaceId);
ExecutionInstallLib.onInstall(hookConfig.module(), hookData, type(IExecutionHookModule).interfaceId);
}

for (uint256 i = 0; i < selectors.length; ++i) {
Expand All @@ -266,7 +98,7 @@ abstract contract ModuleManagerInternals is IModularAccount {
_validationData.isSignatureValidation = validationConfig.isSignatureValidation();
_validationData.isUserOpValidation = validationConfig.isUserOpValidation();

_onInstall(validationConfig.module(), installData, type(IValidationModule).interfaceId);
ExecutionInstallLib.onInstall(validationConfig.module(), installData, type(IValidationModule).interfaceId);
emit ValidationInstalled(validationConfig.module(), validationConfig.entityId());
}

Expand Down Expand Up @@ -295,14 +127,14 @@ abstract contract ModuleManagerInternals is IModularAccount {
for (uint256 i = 0; i < validationHooks.length; ++i) {
bytes calldata hookData = hookUninstallDatas[hookIndex];
(address hookModule,) = ModuleEntityLib.unpack(validationHooks[i].moduleEntity());
onUninstallSuccess = onUninstallSuccess && _onUninstall(hookModule, hookData);
onUninstallSuccess = onUninstallSuccess && ExecutionInstallLib.onUninstall(hookModule, hookData);
hookIndex++;
}

for (uint256 i = 0; i < execHooks.length; ++i) {
bytes calldata hookData = hookUninstallDatas[hookIndex];
address hookModule = execHooks[i].module();
onUninstallSuccess = onUninstallSuccess && _onUninstall(hookModule, hookData);
onUninstallSuccess = onUninstallSuccess && ExecutionInstallLib.onUninstall(hookModule, hookData);
hookIndex++;
}
}
Expand All @@ -318,12 +150,8 @@ abstract contract ModuleManagerInternals is IModularAccount {
_validationData.selectors.clear();

(address module, uint32 entityId) = ModuleEntityLib.unpack(validationFunction);
onUninstallSuccess = onUninstallSuccess && _onUninstall(module, uninstallData);
onUninstallSuccess = onUninstallSuccess && ExecutionInstallLib.onUninstall(module, uninstallData);

emit ValidationUninstalled(module, entityId, onUninstallSuccess);
}

function _isNativeFunction(bytes4 selector) internal pure virtual returns (bool) {
return KnownSelectorsLib.isNativeFunction(selector);
}
}
13 changes: 3 additions & 10 deletions src/account/SemiModularAccountBase.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.26;

import {DIRECT_CALL_VALIDATION_ENTITYID, FALLBACK_VALIDATION} from "../helpers/Constants.sol";
import {SignatureType} from "../helpers/SignatureType.sol";
import {ModularAccountBase} from "./ModularAccountBase.sol";
import {IModularAccount, ModuleEntity} from "@erc6900/reference-implementation/interfaces/IModularAccount.sol";
import {ModuleEntityLib} from "@erc6900/reference-implementation/libraries/ModuleEntityLib.sol";
import {IEntryPoint} from "@eth-infinitism/account-abstraction/interfaces/IEntryPoint.sol";
Expand All @@ -9,11 +12,6 @@ import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol";

import {DIRECT_CALL_VALIDATION_ENTITYID, FALLBACK_VALIDATION} from "../helpers/Constants.sol";
import {SignatureType} from "../helpers/SignatureType.sol";
import {SemiModularKnownSelectorsLib} from "../libraries/SemiModularKnownSelectorsLib.sol";
import {ModularAccountBase} from "./ModularAccountBase.sol";

abstract contract SemiModularAccountBase is ModularAccountBase {
using MessageHashUtils for bytes32;
using ModuleEntityLib for ModuleEntity;
Expand Down Expand Up @@ -230,9 +228,4 @@ abstract contract SemiModularAccountBase is ModularAccountBase {
}
return res;
}

// Overrides ModuleManagerInternals
function _isNativeFunction(bytes4 selector) internal pure override returns (bool) {
return SemiModularKnownSelectorsLib.isNativeFunction(selector);
}
}
Loading

0 comments on commit c8f1ff5

Please sign in to comment.