add forwarder contract code check (#253)

This commit is contained in:
noot
2022-12-06 21:46:04 -05:00
committed by GitHub
parent 83b1c877ff
commit f9aaa5e9d1
3 changed files with 47 additions and 1 deletions

View File

@@ -7,6 +7,8 @@ import (
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/athanorlabs/go-relayer/impls/gsnforwarder"
)
// expectedSwapFactoryBytecodeHex is generated by deploying an instance of SwapFactory.sol
@@ -24,7 +26,10 @@ const (
// blocks. See TestForwarderAddressIndexes to update the values.
var forwarderAddressIndices = []int{1485, 1523}
var errInvalidSwapContract = errors.New("given contract address does not contain correct code")
var (
errInvalidSwapContract = errors.New("given contract address does not contain correct SwapFactory code")
errInvalidForwarderContract = errors.New("given contract address does not contain correct Forwarder code")
)
// CheckSwapFactoryContractCode checks that the bytecode at the given address matches the
// SwapFactory.sol contract. The trusted forwarder address that the contract was deployed
@@ -75,6 +80,36 @@ func CheckSwapFactoryContractCode(
return ethcommon.Address{}, errInvalidSwapContract
}
if (forwarderAddress == ethcommon.Address{}) {
return forwarderAddress, nil
}
err = checkForwarderContractCode(ctx, ec, forwarderAddress)
if err != nil {
return ethcommon.Address{}, err
}
// return the trusted forwarder address that was parsed from the deployed contract byte code
return forwarderAddress, nil
}
// checkSwapFactoryForwarder checks that the trusted forwarder contract used by
// the given swap contract has the expected bytecode.
func checkForwarderContractCode(
ctx context.Context,
ec *ethclient.Client,
contractAddr ethcommon.Address,
) error {
code, err := ec.CodeAt(ctx, contractAddr, nil)
if err != nil {
return err
}
expectedCode := ethcommon.FromHex(gsnforwarder.ForwarderMetaData.Bin)
if !bytes.Contains(expectedCode, code) {
return errInvalidForwarderContract
}
return nil
}

View File

@@ -48,6 +48,14 @@ func getContractCode(t *testing.T, trustedForwarder ethcommon.Address) []byte {
return code
}
func TestCheckForwarderContractCode(t *testing.T) {
ec, _ := tests.NewEthClient(t)
pk := tests.GetMakerTestKey(t)
trustedForwarder := deployForwarder(t, ec, pk)
err := checkForwarderContractCode(context.Background(), ec, trustedForwarder)
require.NoError(t, err)
}
// This test will fail if the compiled SwapFactory contract is updated, but the
// expectedSwapFactoryBytecodeHex constant is not updated. Use this test to update the
// constant.