feat(contracts): add replay message support (#437)

Co-authored-by: Max Wolff <maxcwolff@gmail.com>
Co-authored-by: Péter Garamvölgyi <peter@scroll.io>
Co-authored-by: Haichen Shen <shenhaichen@gmail.com>
This commit is contained in:
Xi Lin
2023-05-12 08:10:40 +08:00
committed by GitHub
parent 97e2695207
commit 57737ec2ca
5 changed files with 114 additions and 12 deletions

View File

@@ -215,7 +215,7 @@ function renounceOwnership() external nonpayable
### replayMessage
```solidity
function replayMessage(address _from, address _to, uint256 _value, uint256 _queueIndex, bytes _message, uint32 _oldGasLimit, uint32 _newGasLimit) external nonpayable
function replayMessage(address _from, address _to, uint256 _value, uint256 _queueIndex, bytes _message, uint32 _oldGasLimit, uint32 _newGasLimit, address _refundAddress) external payable
```
Replay an exsisting message.
@@ -233,6 +233,7 @@ Replay an exsisting message.
| _message | bytes | undefined |
| _oldGasLimit | uint32 | undefined |
| _newGasLimit | uint32 | undefined |
| _refundAddress | address | undefined |
### rollup

View File

@@ -44,6 +44,7 @@ interface IL1ScrollMessenger is IScrollMessenger {
/// @param message The content of the message.
/// @param oldGasLimit Original gas limit used to send the message.
/// @param newGasLimit New gas limit to be used for this message.
/// @param refundAddress The address of account who will receive the refunded fee.
function replayMessage(
address from,
address to,
@@ -51,6 +52,7 @@ interface IL1ScrollMessenger is IScrollMessenger {
uint256 queueIndex,
bytes memory message,
uint32 oldGasLimit,
uint32 newGasLimit
) external;
uint32 newGasLimit,
address refundAddress
) external payable;
}

View File

@@ -10,6 +10,7 @@ import {IL1ScrollMessenger} from "./IL1ScrollMessenger.sol";
import {ScrollConstants} from "../libraries/constants/ScrollConstants.sol";
import {IScrollMessenger} from "../libraries/IScrollMessenger.sol";
import {ScrollMessengerBase} from "../libraries/ScrollMessengerBase.sol";
import {AddressAliasHelper} from "../libraries/common/AddressAliasHelper.sol";
import {WithdrawTrieVerifier} from "../libraries/verifier/WithdrawTrieVerifier.sol";
// solhint-disable avoid-low-level-calls
@@ -174,9 +175,54 @@ contract L1ScrollMessenger is PausableUpgradeable, ScrollMessengerBase, IL1Scrol
uint256 _queueIndex,
bytes memory _message,
uint32 _oldGasLimit,
uint32 _newGasLimit
) external override whenNotPaused {
// @todo
uint32 _newGasLimit,
address _refundAddress
) external payable override whenNotPaused {
// We will use a different `queueIndex` for the replaced message. However, the original `queueIndex` or `nonce`
// is encoded in the `_message`. We will check the `xDomainCalldata` in layer 2 to avoid duplicated execution.
// So, only one message will succeed in layer 2. If one of the message is executed successfully, the other one
// will revert with "Message was already successfully executed".
address _messageQueue = messageQueue;
address _counterpart = counterpart;
bytes memory _xDomainCalldata = _encodeXDomainCalldata(_from, _to, _value, _queueIndex, _message);
// compute the expected transaction hash
bytes32 _computedTransactionHash = IL1MessageQueue(_messageQueue).computeTransactionHash(
AddressAliasHelper.applyL1ToL2Alias(address(this)),
_queueIndex,
0,
_counterpart,
_oldGasLimit,
_xDomainCalldata
);
// check the provided message matching with enqueued one.
require(
_computedTransactionHash == IL1MessageQueue(_messageQueue).getCrossDomainMessage(_queueIndex),
"Provided message has not been enqueued"
);
// compute and deduct the messaging fee to fee vault.
uint256 _fee = IL1MessageQueue(_messageQueue).estimateCrossDomainMessageFee(_newGasLimit);
// charge relayer fee
require(msg.value >= _fee, "Insufficient msg.value for fee");
if (_fee > 0) {
(bool _success, ) = feeVault.call{value: _fee}("");
require(_success, "Failed to deduct the fee");
}
// enqueue the new transaction
IL1MessageQueue(_messageQueue).appendCrossDomainMessage(_counterpart, _newGasLimit, _xDomainCalldata);
// refund fee to `_refundAddress`
unchecked {
uint256 _refund = msg.value - _fee;
if (_refund > 0) {
(bool _success, ) = _refundAddress.call{value: _refund}("");
require(_success, "Failed to refund the fee");
}
}
}
/************************
@@ -232,7 +278,7 @@ contract L1ScrollMessenger is PausableUpgradeable, ScrollMessengerBase, IL1Scrol
emit SentMessage(msg.sender, _to, _value, _messageNonce, _gasLimit, _message);
// refund fee to tx.origin
// refund fee to `_refundAddress`
unchecked {
uint256 _refund = msg.value - _fee - _value;
if (_refund > 0) {

View File

@@ -127,9 +127,12 @@ contract L2ScrollMessenger is ScrollMessengerBase, PausableUpgradeable, IL2Scrol
bytes32 _expectedStateRoot = IL1BlockContainer(blockContainer).getStateRoot(_blockHash);
require(_expectedStateRoot != bytes32(0), "Block is not imported");
// @todo fix the actual slot later.
bytes32 _storageKey;
// `mapping(bytes32 => bool) public isL1MessageSent` is the 105-nd slot of contract `L1ScrollMessenger`.
// + 50 from `OwnableUpgradeable`
// + 4 from `ScrollMessengerBase`
// + 50 from `PausableUpgradeable`
// + 2-nd in `L1ScrollMessenger`
assembly {
mstore(0x00, _msgHash)
mstore(0x20, 105)
@@ -159,9 +162,12 @@ contract L2ScrollMessenger is ScrollMessengerBase, PausableUpgradeable, IL2Scrol
bytes32 _expectedStateRoot = IL1BlockContainer(blockContainer).getStateRoot(_blockHash);
require(_expectedStateRoot != bytes32(0), "Block not imported");
// @todo fix the actual slot later.
bytes32 _storageKey;
// `mapping(bytes32 => bool) public isL2MessageExecuted` is the 106-th slot of contract `L1ScrollMessenger`.
// + 50 from `OwnableUpgradeable`
// + 4 from `ScrollMessengerBase`
// + 50 from `PausableUpgradeable`
// + 3-rd in `L1ScrollMessenger`
assembly {
mstore(0x00, _msgHash)
mstore(0x20, 106)

View File

@@ -5,6 +5,7 @@ pragma solidity ^0.8.0;
import {DSTestPlus} from "solmate/test/utils/DSTestPlus.sol";
import {L1MessageQueue} from "../L1/rollup/L1MessageQueue.sol";
import {L2GasPriceOracle} from "../L1/rollup/L2GasPriceOracle.sol";
import {IScrollChain, ScrollChain} from "../L1/rollup/ScrollChain.sol";
import {Whitelist} from "../L2/predeploys/Whitelist.sol";
import {IL1ScrollMessenger, L1ScrollMessenger} from "../L1/L1ScrollMessenger.sol";
@@ -18,6 +19,8 @@ contract L1ScrollMessengerTest is DSTestPlus {
L1ScrollMessenger internal l1Messenger;
ScrollChain internal scrollChain;
L1MessageQueue internal l1MessageQueue;
L2GasPriceOracle internal gasOracle;
Whitelist internal whitelist;
function setUp() public {
// Deploy L2 contracts
@@ -27,11 +30,19 @@ contract L1ScrollMessengerTest is DSTestPlus {
scrollChain = new ScrollChain(0, 0, bytes32(0));
l1MessageQueue = new L1MessageQueue();
l1Messenger = new L1ScrollMessenger();
gasOracle = new L2GasPriceOracle();
whitelist = new Whitelist(address(this));
// Initialize L1 contracts
l1Messenger.initialize(address(l2Messenger), feeVault, address(scrollChain), address(l1MessageQueue));
l1MessageQueue.initialize(address(l1Messenger), address(0));
l1MessageQueue.initialize(address(l1Messenger), address(gasOracle));
gasOracle.initialize();
scrollChain.initialize(address(l1MessageQueue));
gasOracle.updateWhitelist(address(whitelist));
address[] memory _accounts = new address[](1);
_accounts[0] = address(this);
whitelist.updateWhitelistStatus(_accounts, true);
}
function testForbidCallMessageQueueFromL2() external {
@@ -67,7 +78,6 @@ contract L1ScrollMessengerTest is DSTestPlus {
function testSendMessage(uint256 exceedValue, address refundAddress) external {
hevm.assume(refundAddress.code.length == 0);
hevm.assume(uint256(uint160(refundAddress)) > 100); // ignore some precompile contracts
hevm.assume(refundAddress != address(this));
exceedValue = bound(exceedValue, 1, address(this).balance / 2);
@@ -80,4 +90,41 @@ contract L1ScrollMessengerTest is DSTestPlus {
l1Messenger.sendMessage{value: 1 + exceedValue}(address(0), 1, new bytes(0), 0, refundAddress);
assertEq(balanceBefore + exceedValue, refundAddress.balance);
}
}
function testReplayMessage(uint256 exceedValue, address refundAddress) external {
hevm.assume(refundAddress.code.length == 0);
hevm.assume(uint256(uint160(refundAddress)) > 100); // ignore some precompile contracts
exceedValue = bound(exceedValue, 1, address(this).balance / 2);
// append a message
l1Messenger.sendMessage{value: 100}(address(0), 100, new bytes(0), 0, refundAddress);
// Provided message has not been enqueued
hevm.expectRevert("Provided message has not been enqueued");
l1Messenger.replayMessage(address(this), address(0), 101, 0, new bytes(0), 0, 1, refundAddress);
gasOracle.setL2BaseFee(1);
// Insufficient msg.value
hevm.expectRevert("Insufficient msg.value for fee");
l1Messenger.replayMessage(address(this), address(0), 100, 0, new bytes(0), 0, 1, refundAddress);
uint256 _fee = gasOracle.l2BaseFee() * 100;
// refund exceed fee
uint256 balanceBefore = refundAddress.balance;
uint256 feeVaultBefore = feeVault.balance;
l1Messenger.replayMessage{value: _fee + exceedValue}(
address(this),
address(0),
100,
0,
new bytes(0),
0,
100,
refundAddress
);
assertEq(balanceBefore + exceedValue, refundAddress.balance);
assertEq(feeVaultBefore + _fee, feeVault.balance);
}
}