From 6fbb67c39b6c3841547ba930d67338487e683a79 Mon Sep 17 00:00:00 2001 From: Eric Zhong Date: Tue, 9 Jul 2024 10:36:24 -0400 Subject: [PATCH] feat: Add flag if selling to the pair reverts (#33) * Add flag if sells revert * fix test --- src/FeeOnTransferDetector.sol | 11 ++++-- test/FeeOnTransferDetector.t.sol | 24 ++++++++++++ test/mock/MockSellReentrancyFotToken.sol | 50 ++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 test/mock/MockSellReentrancyFotToken.sol diff --git a/src/FeeOnTransferDetector.sol b/src/FeeOnTransferDetector.sol index 953c75b..02bbbae 100644 --- a/src/FeeOnTransferDetector.sol +++ b/src/FeeOnTransferDetector.sol @@ -14,6 +14,7 @@ struct TokenFees { uint256 sellFeeBps; bool feeTakenOnTransfer; bool externalTransferFailed; + bool sellReverted; } /// @notice Detects the buy and sell fee for a fee-on-transfer token @@ -102,7 +103,7 @@ contract FeeOnTransferDetector { /// @param reason the revert reason /// @return the decoded TokenFees struct function parseRevertReason(bytes memory reason) private pure returns (TokenFees memory) { - if (reason.length != 128) { + if (reason.length != 160) { assembly { revert(add(32, reason), mload(reason)) } @@ -128,14 +129,15 @@ contract FeeOnTransferDetector { (bool feeTakenOnTransfer, bool externalTransferFailed) = tryExternalTransferAndRevert(tokenBorrowed, amountBorrowed); - uint256 sellFeeBps = _calculateSellFee(pair, tokenBorrowed, amountBorrowed, buyFeeBps); + (uint256 sellFeeBps, bool sellReverted) = _calculateSellFee(pair, tokenBorrowed, amountBorrowed, buyFeeBps); bytes memory fees = abi.encode( TokenFees({ buyFeeBps: buyFeeBps, sellFeeBps: sellFeeBps, feeTakenOnTransfer: feeTakenOnTransfer, - externalTransferFailed: externalTransferFailed + externalTransferFailed: externalTransferFailed, + sellReverted: sellReverted }) ); @@ -158,7 +160,7 @@ contract FeeOnTransferDetector { /// - in the case where the transfer fails, we set the sell fee to be the same as the buy fee function _calculateSellFee(IUniswapV2Pair pair, ERC20 tokenBorrowed, uint256 amountBorrowed, uint256 buyFeeBps) internal - returns (uint256 sellFeeBps) + returns (uint256 sellFeeBps, bool sellReverted) { uint256 pairBalanceBeforeSell = tokenBorrowed.balanceOf(address(pair)); try this.callTransfer(tokenBorrowed, address(pair), amountBorrowed) { @@ -167,6 +169,7 @@ contract FeeOnTransferDetector { sellFeeBps = sellFee.mulDivUp(BPS, amountBorrowed); } catch (bytes memory) { sellFeeBps = buyFeeBps; + sellReverted = true; } } diff --git a/test/FeeOnTransferDetector.t.sol b/test/FeeOnTransferDetector.t.sol index 02c661d..c34ae4f 100644 --- a/test/FeeOnTransferDetector.t.sol +++ b/test/FeeOnTransferDetector.t.sol @@ -8,6 +8,7 @@ import {MockV2Factory} from "./mock/MockV2Factory.sol"; import {MockFotToken} from "./mock/MockFotToken.sol"; import {MockFotTokenWithExternalFees} from "./mock/MockFotTokenWithExternalFees.sol"; import {MockReentrancyFotToken} from "./mock/MockReentrancyFotToken.sol"; +import {MockSellReentrancyFotToken} from "./mock/MockSellReentrancyFotToken.sol"; import {MockToken} from "./mock/MockToken.sol"; import {MockReenteringPair} from "./mock/MockReenteringPair.sol"; @@ -38,6 +39,7 @@ contract FeeOnTransferDetectorTest is Test { assertEq(fees.sellFeeBps, 500); assertEq(fees.feeTakenOnTransfer, false); assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, false); } function testBasicFotTokenNoPrecisionLoss() public { @@ -55,6 +57,7 @@ contract FeeOnTransferDetectorTest is Test { assertEq(fees.sellFeeBps, 500); assertEq(fees.feeTakenOnTransfer, false); assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, false); } function testBasicFotTokenWithExternalFees() public { @@ -71,6 +74,7 @@ contract FeeOnTransferDetectorTest is Test { assertEq(fees.sellFeeBps, 500); assertEq(fees.feeTakenOnTransfer, true); assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, false); } function testReentrancyFotToken() public { @@ -87,6 +91,24 @@ contract FeeOnTransferDetectorTest is Test { assertEq(fees.sellFeeBps, 500); assertEq(fees.feeTakenOnTransfer, false); assertEq(fees.externalTransferFailed, true); + assertEq(fees.sellReverted, false); + } + + function testSellReentrancyFotToken() public { + MockSellReentrancyFotToken fotToken = new MockSellReentrancyFotToken(500); + MockToken otherToken = new MockToken(); + address pair = factory.deployPair(address(fotToken), address(otherToken)); + fotToken.setPair(pair); + fotToken.mint(pair, 100 ether); + otherToken.mint(pair, 100 ether); + IUniswapV2Pair(pair).sync(); + + TokenFees memory fees = detector.validate(address(fotToken), address(otherToken), 1 ether); + assertEq(fees.buyFeeBps, 500); + assertEq(fees.sellFeeBps, 500); + assertEq(fees.feeTakenOnTransfer, false); + assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, true); } function testBasicFotTokenFuzz(uint16 buyFee, uint16 sellFee) public { @@ -105,6 +127,7 @@ contract FeeOnTransferDetectorTest is Test { assertEq(fees.sellFeeBps, sellFee); assertEq(fees.feeTakenOnTransfer, false); assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, false); } function testBasicFotTokenWithExternalFeesFuzz(uint16 fee) public { @@ -123,6 +146,7 @@ contract FeeOnTransferDetectorTest is Test { bool feeTakenOnTransfer = (fee == 0 && fee == 0) ? false : true; assertEq(fees.feeTakenOnTransfer, feeTakenOnTransfer); assertEq(fees.externalTransferFailed, false); + assertEq(fees.sellReverted, false); } function testTransferFailsErrorPassthrough() public { diff --git a/test/mock/MockSellReentrancyFotToken.sol b/test/mock/MockSellReentrancyFotToken.sol new file mode 100644 index 0000000..0f95f84 --- /dev/null +++ b/test/mock/MockSellReentrancyFotToken.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity =0.8.19; + +import {ERC20} from "solmate/tokens/ERC20.sol"; +import "v2-core/interfaces/IUniswapV2Pair.sol"; + +contract MockSellReentrancyFotToken is ERC20 { + uint256 public taxBps; + address public pair; + + constructor(uint256 _taxBps) ERC20("MockSellReentrancyFotToken", "MSFOT", 18) { + taxBps = _taxBps; + } + + function setPair(address _pair) external { + pair = _pair; + } + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } + + // this token re-enters the pair on sells only (since buys lock the pair) + function transfer(address to, uint256 amount) public override returns (bool) { + balanceOf[msg.sender] -= amount; + + // Cannot overflow because the sum of all user + // balances can't exceed the max uint256 value. + unchecked { + if (to == pair || msg.sender == pair) { + uint256 feeAmount = amount * taxBps / 10000; + balanceOf[to] += amount - feeAmount; + balanceOf[address(this)] += feeAmount; + } else { + balanceOf[to] += amount; + } + } + + // Only add in extra swap for sells + if (to == pair) { + IUniswapV2Pair(pair).token0() == address(this) + ? IUniswapV2Pair(pair).swap(0, 0, address(this), new bytes(0)) + : IUniswapV2Pair(pair).swap(0, 0, address(this), new bytes(0)); + } + + emit Transfer(msg.sender, to, amount); + + return true; + } +}