Skip to content

Commit

Permalink
feat[lang]: protect external calls with keyword (#2938)
Browse files Browse the repository at this point in the history
this commit adds `extcall` and `staticcall` keywords to the vyper
language. these are now a requirement for the user to add to distinguish
internal calls from:
1) external calls which can have side effects (`extcall`), and
2) external calls which are guaranteed by the EVM to not have side
   effects (`staticcall`).

`extcall` is used for `nonpayable` or `payable` functions (which emit
the `CALL` opcode), while `staticcall` is used for `view` and `pure`
functions (which emit the `STATICCALL` opcode).

the motivation for this is laid out more in the linked GH issue, but it
is primarily to make it easier to read, audit and analyze vyper
contracts, since you can find the locations of external calls in source
code using text-only techniques, and do not need to analyze (or have
access to the results of an analysis) in order to find where external
calls are. (note that this has become a larger concern with with the
introduction of modules in vyper, since you can no longer distinguish
between internal and external calls just by looking for the `self.`
prefix).

an analysis of some production contracts indicates that the frequency
of external calls has somewhat high variability, but is in the range of
one `extcall` (or `staticcall`) per 10-25 (logical) sloc, with
`staticcalls` being about twice as common. therefore, based on the
semantic vs write load of the keyword, the keyword should be somewhat
easy to type, but it also needs to be long enough and unusual enough to
stand out in a text editor.

the differentiation between `extcall` and `staticcall` was added
because, during testing of the feature, it was found that being able to
additionally infer at the call site whether the external call can have
side effects or not (without needing to reference the function
definition) substantially enhanced readability.

refactoring/misc updates:
- update and clean up the grammar, especially the `variable_access` rule
  (cf. https://github.com/lark-parser/lark/blob/706190849ee/lark/grammars/python.lark#L192)
- add a proper .parent property to VyperNodes
- add tokenizer changes to make source locations are correct
- ban standalone staticcalls
- update tests -- in some cases, because standalone staticcalls are now
  banned, either an enclosing assignment was added or the mutability of
  the interface was changed.
- rewrite some assert_compile_failed to pytest.raises() along the way
- remove some dead functions

cf. GH issue / VIP 2856 

---------

Co-authored-by: tserg <[email protected]>
  • Loading branch information
charles-cooper and tserg authored Mar 6, 2024
1 parent 327e95a commit 2d232eb
Show file tree
Hide file tree
Showing 54 changed files with 946 additions and 478 deletions.
6 changes: 3 additions & 3 deletions examples/factory/Exchange.vy
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(_token: IERC20, _factory: Factory):
@external
def initialize():
# Anyone can safely call this function because of EXTCODEHASH
self.factory.register()
extcall self.factory.register()


# NOTE: This contract restricts trading to only be done by the factory.
Expand All @@ -31,12 +31,12 @@ def initialize():
@external
def receive(_from: address, _amt: uint256):
assert msg.sender == self.factory.address
success: bool = self.token.transferFrom(_from, self, _amt)
success: bool = extcall self.token.transferFrom(_from, self, _amt)
assert success


@external
def transfer(_to: address, _amt: uint256):
assert msg.sender == self.factory.address
success: bool = self.token.transfer(_to, _amt)
success: bool = extcall self.token.transfer(_to, _amt)
assert success
6 changes: 3 additions & 3 deletions examples/factory/Factory.vy
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def register():
# NOTE: Should do checks that it hasn't already been set,
# which has to be rectified with any upgrade strategy.
exchange: Exchange = Exchange(msg.sender)
self.exchanges[exchange.token()] = exchange
self.exchanges[staticcall exchange.token()] = exchange


@external
def trade(_token1: IERC20, _token2: IERC20, _amt: uint256):
# Perform a straight exchange of token1 to token 2 (1:1 price)
# NOTE: Any practical implementation would need to solve the price oracle problem
self.exchanges[_token1].receive(msg.sender, _amt)
self.exchanges[_token2].transfer(msg.sender, _amt)
extcall self.exchanges[_token1].receive(msg.sender, _amt)
extcall self.exchanges[_token2].transfer(msg.sender, _amt)
12 changes: 6 additions & 6 deletions examples/market_maker/on_chain_market_maker.vy
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ totalTokenQty: public(uint256)
# Constant set in `initiate` that's used to calculate
# the amount of ether/tokens that are exchanged
invariant: public(uint256)
token_address: IERC20
token: IERC20
owner: public(address)

# Sets the on chain market maker with its owner, initial token quantity,
Expand All @@ -17,8 +17,8 @@ owner: public(address)
@payable
def initiate(token_addr: address, token_quantity: uint256):
assert self.invariant == 0
self.token_address = IERC20(token_addr)
self.token_address.transferFrom(msg.sender, self, token_quantity)
self.token = IERC20(token_addr)
extcall self.token.transferFrom(msg.sender, self, token_quantity)
self.owner = msg.sender
self.totalEthQty = msg.value
self.totalTokenQty = token_quantity
Expand All @@ -33,14 +33,14 @@ def ethToTokens():
eth_in_purchase: uint256 = msg.value - fee
new_total_eth: uint256 = self.totalEthQty + eth_in_purchase
new_total_tokens: uint256 = self.invariant // new_total_eth
self.token_address.transfer(msg.sender, self.totalTokenQty - new_total_tokens)
extcall self.token.transfer(msg.sender, self.totalTokenQty - new_total_tokens)
self.totalEthQty = new_total_eth
self.totalTokenQty = new_total_tokens

# Sells tokens to the contract in exchange for ether
@external
def tokensToEth(sell_quantity: uint256):
self.token_address.transferFrom(msg.sender, self, sell_quantity)
extcall self.token.transferFrom(msg.sender, self, sell_quantity)
new_total_tokens: uint256 = self.totalTokenQty + sell_quantity
new_total_eth: uint256 = self.invariant // new_total_tokens
eth_to_send: uint256 = self.totalEthQty - new_total_eth
Expand All @@ -52,5 +52,5 @@ def tokensToEth(sell_quantity: uint256):
@external
def ownerWithdraw():
assert self.owner == msg.sender
self.token_address.transfer(self.owner, self.totalTokenQty)
extcall self.token.transfer(self.owner, self.totalTokenQty)
selfdestruct(self.owner)
20 changes: 10 additions & 10 deletions examples/tokens/ERC4626.vy
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def transferFrom(sender: address, receiver: address, amount: uint256) -> bool:
@view
@external
def totalAssets() -> uint256:
return self.asset.balanceOf(self)
return staticcall self.asset.balanceOf(self)


@view
Expand All @@ -91,7 +91,7 @@ def _convertToAssets(shareAmount: uint256) -> uint256:

# NOTE: `shareAmount = 0` is extremely rare case, not optimizing for it
# NOTE: `totalAssets = 0` is extremely rare case, not optimizing for it
return shareAmount * self.asset.balanceOf(self) // totalSupply
return shareAmount * staticcall self.asset.balanceOf(self) // totalSupply


@view
Expand All @@ -104,7 +104,7 @@ def convertToAssets(shareAmount: uint256) -> uint256:
@internal
def _convertToShares(assetAmount: uint256) -> uint256:
totalSupply: uint256 = self.totalSupply
totalAssets: uint256 = self.asset.balanceOf(self)
totalAssets: uint256 = staticcall self.asset.balanceOf(self)
if totalAssets == 0 or totalSupply == 0:
return assetAmount # 1:1 price

Expand Down Expand Up @@ -133,7 +133,7 @@ def previewDeposit(assets: uint256) -> uint256:
@external
def deposit(assets: uint256, receiver: address=msg.sender) -> uint256:
shares: uint256 = self._convertToShares(assets)
self.asset.transferFrom(msg.sender, self, assets)
extcall self.asset.transferFrom(msg.sender, self, assets)

self.totalSupply += shares
self.balanceOf[receiver] += shares
Expand All @@ -153,7 +153,7 @@ def previewMint(shares: uint256) -> uint256:
assets: uint256 = self._convertToAssets(shares)

# NOTE: Vyper does lazy eval on `and`, so this avoids SLOADs most of the time
if assets == 0 and self.asset.balanceOf(self) == 0:
if assets == 0 and staticcall self.asset.balanceOf(self) == 0:
return shares # NOTE: Assume 1:1 price if nothing deposited yet

return assets
Expand All @@ -163,10 +163,10 @@ def previewMint(shares: uint256) -> uint256:
def mint(shares: uint256, receiver: address=msg.sender) -> uint256:
assets: uint256 = self._convertToAssets(shares)

if assets == 0 and self.asset.balanceOf(self) == 0:
if assets == 0 and staticcall self.asset.balanceOf(self) == 0:
assets = shares # NOTE: Assume 1:1 price if nothing deposited yet

self.asset.transferFrom(msg.sender, self, assets)
extcall self.asset.transferFrom(msg.sender, self, assets)

self.totalSupply += shares
self.balanceOf[receiver] += shares
Expand Down Expand Up @@ -206,7 +206,7 @@ def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.s
self.totalSupply -= shares
self.balanceOf[owner] -= shares

self.asset.transfer(receiver, assets)
extcall self.asset.transfer(receiver, assets)
log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares)
return shares

Expand All @@ -232,7 +232,7 @@ def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sen
self.totalSupply -= shares
self.balanceOf[owner] -= shares

self.asset.transfer(receiver, assets)
extcall self.asset.transfer(receiver, assets)
log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares)
return assets

Expand All @@ -241,4 +241,4 @@ def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sen
def DEBUG_steal_tokens(amount: uint256):
# NOTE: This is the primary method of mocking share price changes
# do not put in production code!!!
self.asset.transfer(msg.sender, amount)
extcall self.asset.transfer(msg.sender, amount)
2 changes: 1 addition & 1 deletion examples/tokens/ERC721.vy
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def safeTransferFrom(
"""
self._transferFrom(_from, _to, _tokenId, msg.sender)
if _to.is_contract: # check if `_to` is a contract address
returnValue: bytes4 = ERC721Receiver(_to).onERC721Received(msg.sender, _from, _tokenId, _data)
returnValue: bytes4 = extcall ERC721Receiver(_to).onERC721Received(msg.sender, _from, _tokenId, _data)
# Throws if transfer destination is a contract which does not implement 'onERC721Received'
assert returnValue == method_id("onERC721Received(address,address,uint256,bytes)", output_type=bytes4)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"py-evm>=0.7.0a1,<0.8",
"web3==6.0.0",
"tox>=3.15,<4.0",
"lark==1.1.2",
"lark==1.1.9",
"hypothesis[lark]>=5.37.1,<6.0",
"eth-stdlib==0.2.6",
],
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_abi_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def get_counter() -> Bytes[128]: nonpayable
def foo(addr: address) -> (uint256, String[5]):
a: uint256 = 0
b: String[5] = ""
a, b = _abi_decode(Foo(addr).get_counter(), (uint256, String[5]), unwrap_tuple=False)
a, b = _abi_decode(extcall Foo(addr).get_counter(), (uint256, String[5]), unwrap_tuple=False)
return a, b
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_abi_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_counter() -> (uint256, String[6]): nonpayable
@external
def foo(addr: address) -> Bytes[164]:
return _abi_encode(Foo(addr).get_counter(), method_id=0xdeadbeef)
return _abi_encode(extcall Foo(addr).get_counter(), method_id=0xdeadbeef)
"""

c2 = get_contract(contract_2)
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_addmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def test_uint256_addmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_addmod(32, 2, f.foo(32))
interface Foo:
def foo(x: uint256) -> uint256: payable
@external
def foo(f: Foo) -> uint256:
return uint256_addmod(32, 2, extcall f.foo(32))
"""

c1 = side_effects_contract("uint256")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_as_wei_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ def foo(a: {data_type}) -> uint256:

def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> uint256:
return as_wei_value(a.foo(7), "ether")
interface Foo:
def foo(x: uint8) -> uint8: nonpayable
@external
def foo(a: Foo) -> uint256:
return as_wei_value(extcall a.foo(7), "ether")
"""

c1 = side_effects_contract("uint8")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_ceil.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def ceil_param(p: decimal) -> int256:

def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return ceil(a.foo(2.5))
interface Foo:
def foo(x: decimal) -> decimal: payable
@external
def foo(a: Foo) -> int256:
return ceil(extcall a.foo(2.5))
"""

c1 = side_effects_contract("decimal")
Expand Down
17 changes: 2 additions & 15 deletions tests/functional/builtins/codegen/test_create_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,23 @@ def test() -> address:

def test_create_minimal_proxy_to_call(get_contract, w3):
code = """
interface SubContract:
def hello() -> Bytes[100]: view
other: public(address)
@external
def test() -> address:
self.other = create_minimal_proxy_to(self)
return self.other
@external
def hello() -> Bytes[100]:
return b"hello world!"
@external
def test2() -> Bytes[100]:
return SubContract(self.other).hello()
return staticcall SubContract(self.other).hello()
"""

c = get_contract(code)
Expand All @@ -79,30 +72,24 @@ def test2() -> Bytes[100]:

def test_minimal_proxy_exception(w3, get_contract, tx_failed):
code = """
interface SubContract:
def hello(a: uint256) -> Bytes[100]: view
other: public(address)
@external
def test() -> address:
self.other = create_minimal_proxy_to(self)
return self.other
@external
def hello(a: uint256) -> Bytes[100]:
assert a > 0, "invaliddddd"
return b"hello world!"
@external
def test2(a: uint256) -> Bytes[100]:
return SubContract(self.other).hello(a)
return staticcall SubContract(self.other).hello(a)
"""

c = get_contract(code)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/builtins/codegen/test_ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def foo(x: uint256[2]) -> uint256[2]: payable
@external
def foo(a: Foo) -> uint256[2]:
return ecadd([1, 2], a.foo([1, 2]))
return ecadd([1, 2], extcall a.foo([1, 2]))
"""
c1 = side_effects_contract("uint256[2]")
c2 = get_contract(code)
Expand Down Expand Up @@ -148,7 +148,7 @@ def foo(x: uint256) -> uint256: payable
@external
def foo(a: Foo) -> uint256[2]:
return ecmul([1, 2], a.foo(3))
return ecmul([1, 2], extcall a.foo(3))
"""
c1 = side_effects_contract("uint256")
c2 = get_contract(code)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/builtins/codegen/test_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def pub2() -> bool:
@external
def pub3(x: address) -> bool:
self.write_junk_to_memory()
return Mirror(x).test_empty(empty(int128[111]), empty(Bytes[1024]), empty(Bytes[31]))
return staticcall Mirror(x).test_empty(empty(int128[111]), empty(Bytes[1024]), empty(Bytes[31]))
"""
c = get_contract_with_gas_estimation(code)
mirror = get_contract_with_gas_estimation(code)
Expand Down Expand Up @@ -658,7 +658,7 @@ def foo(
@view
@external
def bar(a: address) -> (uint256, Bytes[33], Bytes[65], uint256):
return Foo(a).foo(12, {a}, 42, {b})
return staticcall Foo(a).foo(12, {a}, 42, {b})
"""

c1 = get_contract(code_a)
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_floor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ def floor_param(p: decimal) -> int256:

def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return floor(a.foo(2.5))
interface Foo:
def foo(x: decimal) -> decimal: nonpayable
@external
def foo(a: Foo) -> int256:
return floor(extcall a.foo(2.5))
"""

c1 = side_effects_contract("decimal")
Expand Down
8 changes: 4 additions & 4 deletions tests/functional/builtins/codegen/test_mulmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def test_uint256_mulmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_mulmod(200, 3, f.foo(601))
interface Foo:
def foo(x: uint256) -> uint256: nonpayable
@external
def foo(f: Foo) -> uint256:
return uint256_mulmod(200, 3, extcall f.foo(601))
"""

c1 = side_effects_contract("uint256")
Expand Down
Loading

0 comments on commit 2d232eb

Please sign in to comment.