mirror of
https://github.com/AthanorLabs/atomic-swap.git
synced 2026-01-09 14:18:03 -05:00
add forwarder contract code check (#253)
This commit is contained in:
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
ethcommon "github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/ethclient"
|
||||
|
||||
"github.com/athanorlabs/go-relayer/impls/gsnforwarder"
|
||||
)
|
||||
|
||||
// expectedSwapFactoryBytecodeHex is generated by deploying an instance of SwapFactory.sol
|
||||
@@ -24,7 +26,10 @@ const (
|
||||
// blocks. See TestForwarderAddressIndexes to update the values.
|
||||
var forwarderAddressIndices = []int{1485, 1523}
|
||||
|
||||
var errInvalidSwapContract = errors.New("given contract address does not contain correct code")
|
||||
var (
|
||||
errInvalidSwapContract = errors.New("given contract address does not contain correct SwapFactory code")
|
||||
errInvalidForwarderContract = errors.New("given contract address does not contain correct Forwarder code")
|
||||
)
|
||||
|
||||
// CheckSwapFactoryContractCode checks that the bytecode at the given address matches the
|
||||
// SwapFactory.sol contract. The trusted forwarder address that the contract was deployed
|
||||
@@ -75,6 +80,36 @@ func CheckSwapFactoryContractCode(
|
||||
return ethcommon.Address{}, errInvalidSwapContract
|
||||
}
|
||||
|
||||
if (forwarderAddress == ethcommon.Address{}) {
|
||||
return forwarderAddress, nil
|
||||
}
|
||||
|
||||
err = checkForwarderContractCode(ctx, ec, forwarderAddress)
|
||||
if err != nil {
|
||||
return ethcommon.Address{}, err
|
||||
}
|
||||
|
||||
// return the trusted forwarder address that was parsed from the deployed contract byte code
|
||||
return forwarderAddress, nil
|
||||
}
|
||||
|
||||
// checkSwapFactoryForwarder checks that the trusted forwarder contract used by
|
||||
// the given swap contract has the expected bytecode.
|
||||
func checkForwarderContractCode(
|
||||
ctx context.Context,
|
||||
ec *ethclient.Client,
|
||||
contractAddr ethcommon.Address,
|
||||
) error {
|
||||
code, err := ec.CodeAt(ctx, contractAddr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedCode := ethcommon.FromHex(gsnforwarder.ForwarderMetaData.Bin)
|
||||
|
||||
if !bytes.Contains(expectedCode, code) {
|
||||
return errInvalidForwarderContract
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,6 +48,14 @@ func getContractCode(t *testing.T, trustedForwarder ethcommon.Address) []byte {
|
||||
return code
|
||||
}
|
||||
|
||||
func TestCheckForwarderContractCode(t *testing.T) {
|
||||
ec, _ := tests.NewEthClient(t)
|
||||
pk := tests.GetMakerTestKey(t)
|
||||
trustedForwarder := deployForwarder(t, ec, pk)
|
||||
err := checkForwarderContractCode(context.Background(), ec, trustedForwarder)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// This test will fail if the compiled SwapFactory contract is updated, but the
|
||||
// expectedSwapFactoryBytecodeHex constant is not updated. Use this test to update the
|
||||
// constant.
|
||||
|
||||
Reference in New Issue
Block a user