From db906f8c8a9f8d848eff15606a6b9e22b2e5fd1f Mon Sep 17 00:00:00 2001 From: colin <102356659+colinlyguo@users.noreply.github.com> Date: Fri, 12 May 2023 20:28:21 +0800 Subject: [PATCH] test: add more tests (#466) --- bridge/config/config_test.go | 19 ++-- bridge/sender/sender_test.go | 94 +++++++++++++++++- build/run_tests.sh | 2 +- common/go.mod | 1 - common/go.sum | 2 - common/types/batch_test.go | 54 ++++++++++- common/types/db_test.go | 92 ++++++++++++++++++ common/types/message/message_test.go | 124 +++++++++++++++++++++++- coordinator/config/config_test.go | 137 +++++++++++++++++++++++++++ database/config_test.go | 70 ++++++++++++++ 10 files changed, 581 insertions(+), 14 deletions(-) create mode 100644 common/types/db_test.go create mode 100644 coordinator/config/config_test.go create mode 100644 database/config_test.go diff --git a/bridge/config/config_test.go b/bridge/config/config_test.go index 36c16da5c..b21560dfe 100644 --- a/bridge/config/config_test.go +++ b/bridge/config/config_test.go @@ -13,7 +13,7 @@ import ( func TestConfig(t *testing.T) { t.Run("Success Case", func(t *testing.T) { cfg, err := NewConfig("../config.json") - assert.NoError(t, err, "failed to load config") + assert.NoError(t, err) assert.Len(t, cfg.L1Config.RelayerConfig.MessageSenderPrivateKeys, 1) assert.Len(t, cfg.L2Config.RelayerConfig.MessageSenderPrivateKeys, 1) @@ -23,7 +23,11 @@ func TestConfig(t *testing.T) { assert.NoError(t, err) tmpJSON := fmt.Sprintf("/tmp/%d_bridge_config.json", time.Now().Nanosecond()) - defer func() { _ = os.Remove(tmpJSON) }() + defer func() { + if _, err = os.Stat(tmpJSON); err == nil { + assert.NoError(t, os.Remove(tmpJSON)) + } + }() assert.NoError(t, os.WriteFile(tmpJSON, data, 0644)) @@ -42,14 +46,17 @@ func TestConfig(t *testing.T) { t.Run("Invalid JSON Content", func(t *testing.T) { // Create a temporary file with invalid JSON content - tempFile, err := os.CreateTemp("", "invalid_json_config.json") + tmpFile, err := os.CreateTemp("", "invalid_json_config.json") assert.NoError(t, err) - defer os.Remove(tempFile.Name()) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() - _, err = tempFile.WriteString("{ invalid_json: ") + _, err = tmpFile.WriteString("{ invalid_json: ") assert.NoError(t, err) - _, err = NewConfig(tempFile.Name()) + _, err = NewConfig(tmpFile.Name()) assert.Error(t, err) }) } diff --git a/bridge/sender/sender_test.go b/bridge/sender/sender_test.go index b7c0ba303..147590815 100644 --- a/bridge/sender/sender_test.go +++ b/bridge/sender/sender_test.go @@ -11,7 +11,9 @@ import ( "golang.org/x/sync/errgroup" + "github.com/agiledragon/gomonkey/v2" cmap "github.com/orcaman/concurrent-map" + "github.com/scroll-tech/go-ethereum/accounts/abi/bind" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/core/types" "github.com/scroll-tech/go-ethereum/crypto" @@ -63,7 +65,8 @@ func TestSender(t *testing.T) { t.Run("test pending limit", testPendLimit) t.Run("test min gas limit", testMinGasLimit) - t.Run("test resubmit transaction", func(t *testing.T) { testResubmitTransaction(t) }) + t.Run("test resubmit transaction", testResubmitTransaction) + t.Run("test check pending transaction", testCheckPendingTransaction) t.Run("test 1 account sender", func(t *testing.T) { testBatchSender(t, 1) }) t.Run("test 3 account sender", func(t *testing.T) { testBatchSender(t, 3) }) @@ -152,6 +155,95 @@ func testResubmitTransaction(t *testing.T) { } } +func testCheckPendingTransaction(t *testing.T) { + for _, txType := range txTypes { + cfgCopy := *cfg.L1Config.RelayerConfig.SenderConfig + cfgCopy.TxType = txType + s, err := NewSender(context.Background(), &cfgCopy, privateKeys) + assert.NoError(t, err) + + header := &types.Header{Number: big.NewInt(100), BaseFee: big.NewInt(100)} + confirmed := uint64(100) + receipt := &types.Receipt{Status: types.ReceiptStatusSuccessful, BlockNumber: big.NewInt(90)} + auth := s.auths.getAccount() + tx := types.NewTransaction(auth.Nonce.Uint64(), common.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) + + testCases := []struct { + name string + receipt *types.Receipt + receiptErr error + resubmitErr error + expectedCount int + expectedFound bool + }{ + { + name: "Normal case, transaction receipt exists and successful", + receipt: receipt, + receiptErr: nil, + resubmitErr: nil, + expectedCount: 0, + expectedFound: false, + }, + { + name: "Resubmit case, resubmitTransaction error (not nonce) case", + receipt: receipt, + receiptErr: errors.New("receipt error"), + resubmitErr: errors.New("resubmit error"), + expectedCount: 1, + expectedFound: true, + }, + { + name: "Resubmit case, resubmitTransaction success case", + receipt: receipt, + receiptErr: errors.New("receipt error"), + resubmitErr: nil, + expectedCount: 1, + expectedFound: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var c *ethclient.Client + patchGuard := gomonkey.ApplyMethodFunc(c, "TransactionReceipt", func(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { + return tc.receipt, tc.receiptErr + }) + patchGuard.ApplyPrivateMethod(s, "resubmitTransaction", + func(feeData *FeeData, auth *bind.TransactOpts, tx *types.Transaction) (*types.Transaction, error) { + return tx, tc.resubmitErr + }, + ) + + pendingTx := &PendingTransaction{id: "abc", tx: tx, submitAt: header.Number.Uint64() - s.config.EscalateBlocks - 1} + s.pendingTxs.Set(pendingTx.id, pendingTx) + s.checkPendingTransaction(header, confirmed) + + if tc.receiptErr == nil { + expectedConfirmation := &Confirmation{ + ID: pendingTx.id, + IsSuccessful: tc.receipt.Status == types.ReceiptStatusSuccessful, + TxHash: pendingTx.tx.Hash(), + } + actualConfirmation := <-s.confirmCh + assert.Equal(t, expectedConfirmation, actualConfirmation) + } + + if tc.expectedFound && tc.resubmitErr == nil { + actualPendingTx, found := s.pendingTxs.Get(pendingTx.id) + assert.Equal(t, true, found) + assert.Equal(t, header.Number.Uint64(), actualPendingTx.submitAt) + } + + _, found := s.pendingTxs.Get(pendingTx.id) + assert.Equal(t, tc.expectedFound, found) + assert.Equal(t, tc.expectedCount, s.pendingTxs.Count()) + patchGuard.Reset() + }) + } + s.Stop() + } +} + func testBatchSender(t *testing.T, batchSize int) { for _, txType := range txTypes { for len(privateKeys) < batchSize { diff --git a/build/run_tests.sh b/build/run_tests.sh index a3795fcdc..caacdfe03 100755 --- a/build/run_tests.sh +++ b/build/run_tests.sh @@ -3,7 +3,7 @@ set -uex profile_name=$1 -exclude_dirs=("scroll-tech/bridge/cmd" "scroll-tech/bridge/tests" "scroll-tech/bridge/mock_bridge" "scroll-tech/coordinator/cmd") +exclude_dirs=("scroll-tech/bridge/cmd" "scroll-tech/bridge/tests" "scroll-tech/bridge/mock_bridge" "scroll-tech/coordinator/cmd" "scroll-tech/coordinator/verifier") all_packages=$(go list ./... | grep -v "^scroll-tech/${profile_name}$") coverpkg="scroll-tech/${profile_name}" diff --git a/common/go.mod b/common/go.mod index 29a90088a..f473eefdf 100644 --- a/common/go.mod +++ b/common/go.mod @@ -13,7 +13,6 @@ require ( github.com/scroll-tech/go-ethereum v1.10.14-0.20230321020420-127af384ed04 github.com/stretchr/testify v1.8.2 github.com/urfave/cli/v2 v2.17.2-0.20221006022127-8f469abc00aa - gotest.tools v2.2.0+incompatible ) require ( diff --git a/common/go.sum b/common/go.sum index 66de8f46f..937efc0e0 100644 --- a/common/go.sum +++ b/common/go.sum @@ -655,8 +655,6 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= -gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= diff --git a/common/types/batch_test.go b/common/types/batch_test.go index cee60d2bd..d4e9dcb34 100644 --- a/common/types/batch_test.go +++ b/common/types/batch_test.go @@ -1,10 +1,12 @@ package types import ( + "encoding/json" "math/big" + "os" "testing" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "github.com/scroll-tech/go-ethereum/common" geth_types "github.com/scroll-tech/go-ethereum/core/types" @@ -89,3 +91,53 @@ func TestNewGenesisBatch(t *testing.T) { "wrong genesis batch hash", ) } + +func TestNewBatchData(t *testing.T) { + templateBlockTrace, err := os.ReadFile("../testdata/blockTrace_02.json") + assert.NoError(t, err) + + wrappedBlock := &WrappedBlock{} + assert.NoError(t, json.Unmarshal(templateBlockTrace, wrappedBlock)) + + parentBatch := &BlockBatch{ + Index: 1, + Hash: "0x0000000000000000000000000000000000000000", + StateRoot: "0x0000000000000000000000000000000000000000", + } + batchData1 := NewBatchData(parentBatch, []*WrappedBlock{wrappedBlock}, nil) + assert.NotNil(t, batchData1) + assert.NotNil(t, batchData1.Batch) + assert.Equal(t, "0xac4487c0d8f429dafda3c68cbb8983ac08af83c03c83c365d7df02864f80af37", batchData1.Hash().Hex()) + + templateBlockTrace, err = os.ReadFile("../testdata/blockTrace_03.json") + assert.NoError(t, err) + + wrappedBlock2 := &WrappedBlock{} + assert.NoError(t, json.Unmarshal(templateBlockTrace, wrappedBlock2)) + + parentBatch2 := &BlockBatch{ + Index: batchData1.Batch.BatchIndex, + Hash: batchData1.Hash().Hex(), + StateRoot: batchData1.Batch.NewStateRoot.Hex(), + } + batchData2 := NewBatchData(parentBatch2, []*WrappedBlock{wrappedBlock2}, nil) + assert.NotNil(t, batchData2) + assert.NotNil(t, batchData2.Batch) + assert.Equal(t, "0x8f1447573740b3e75b979879866b8ad02eecf88e1946275eb8cf14ab95876efc", batchData2.Hash().Hex()) +} + +func TestBatchDataTimestamp(t *testing.T) { + // Test case 1: when the batch data contains no blocks. + assert.Equal(t, uint64(0), (&BatchData{}).Timestamp()) + + // Test case 2: when the batch data contains blocks. + batchData := &BatchData{ + Batch: abi.IScrollChainBatch{ + Blocks: []abi.IScrollChainBlockContext{ + {Timestamp: 123456789}, + {Timestamp: 234567891}, + }, + }, + } + assert.Equal(t, uint64(123456789), batchData.Timestamp()) +} diff --git a/common/types/db_test.go b/common/types/db_test.go new file mode 100644 index 000000000..2a4ffedb3 --- /dev/null +++ b/common/types/db_test.go @@ -0,0 +1,92 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRollerProveStatus(t *testing.T) { + tests := []struct { + name string + s RollerProveStatus + want string + }{ + { + "RollerAssigned", + RollerAssigned, + "RollerAssigned", + }, + { + "RollerProofValid", + RollerProofValid, + "RollerProofValid", + }, + { + "RollerProofInvalid", + RollerProofInvalid, + "RollerProofInvalid", + }, + { + "Bad Value", + RollerProveStatus(999), // Invalid value. + "Bad Value: 999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.s.String()) + }) + } +} + +func TestProvingStatus(t *testing.T) { + tests := []struct { + name string + s ProvingStatus + want string + }{ + { + "ProvingTaskUnassigned", + ProvingTaskUnassigned, + "unassigned", + }, + { + "ProvingTaskSkipped", + ProvingTaskSkipped, + "skipped", + }, + { + "ProvingTaskAssigned", + ProvingTaskAssigned, + "assigned", + }, + { + "ProvingTaskProved", + ProvingTaskProved, + "proved", + }, + { + "ProvingTaskVerified", + ProvingTaskVerified, + "verified", + }, + { + "ProvingTaskFailed", + ProvingTaskFailed, + "failed", + }, + { + "Undefined", + ProvingStatus(999), // Invalid value. + "undefined", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, tt.s.String()) + }) + } +} diff --git a/common/types/message/message_test.go b/common/types/message/message_test.go index 15574f349..11ed0b05c 100644 --- a/common/types/message/message_test.go +++ b/common/types/message/message_test.go @@ -1,6 +1,7 @@ package message import ( + "encoding/hex" "testing" "time" @@ -15,13 +16,15 @@ func TestAuthMessageSignAndVerify(t *testing.T) { authMsg := &AuthMsg{ Identity: &Identity{ - Name: "testRoller", + Name: "testName", Timestamp: uint32(time.Now().Unix()), + Version: "testVersion", + Token: "testToken", }, } assert.NoError(t, authMsg.SignWithKey(privkey)) - // check public key. + // Check public key. pk, err := authMsg.PublicKey() assert.NoError(t, err) assert.Equal(t, common.Bytes2Hex(crypto.CompressPubkey(&privkey.PublicKey)), pk) @@ -36,3 +39,120 @@ func TestAuthMessageSignAndVerify(t *testing.T) { pubkey := crypto.CompressPubkey(&privkey.PublicKey) assert.Equal(t, pub, common.Bytes2Hex(pubkey)) } + +func TestGenerateToken(t *testing.T) { + token, err := GenerateToken() + assert.NoError(t, err) + assert.Equal(t, 32, len(token)) +} + +func TestIdentityHash(t *testing.T) { + identity := &Identity{ + Name: "testName", + RollerType: BasicProve, + Timestamp: uint32(1622428800), + Version: "testVersion", + Token: "testToken", + } + hash, err := identity.Hash() + assert.NoError(t, err) + + expectedHash := "b3f152958dc881446fc131a250526139d909710c6b91b4d3281ceded28ce2e32" + assert.Equal(t, expectedHash, hex.EncodeToString(hash)) +} + +func TestProofMessageSignVerifyPublicKey(t *testing.T) { + privkey, err := crypto.GenerateKey() + assert.NoError(t, err) + + proofMsg := &ProofMsg{ + ProofDetail: &ProofDetail{ + ID: "testID", + Type: BasicProve, + Status: StatusOk, + Proof: &AggProof{ + Proof: []byte("testProof"), + Instance: []byte("testInstance"), + FinalPair: []byte("testFinalPair"), + Vk: []byte("testVk"), + BlockCount: 1, + }, + Error: "testError", + }, + } + assert.NoError(t, proofMsg.Sign(privkey)) + + // Test when publicKey is not set. + ok, err := proofMsg.Verify() + assert.NoError(t, err) + assert.Equal(t, true, ok) + + // Test when publicKey is already set. + ok, err = proofMsg.Verify() + assert.NoError(t, err) + assert.Equal(t, true, ok) +} + +func TestProofDetailHash(t *testing.T) { + proofDetail := &ProofDetail{ + ID: "testID", + Type: BasicProve, + Status: StatusOk, + Proof: &AggProof{ + Proof: []byte("testProof"), + Instance: []byte("testInstance"), + FinalPair: []byte("testFinalPair"), + Vk: []byte("testVk"), + BlockCount: 1, + }, + Error: "testError", + } + hash, err := proofDetail.Hash() + assert.NoError(t, err) + expectedHash := "fdfaae752d6fd72a7fdd2ad034ef504d3acda9e691a799323cfa6e371684ba2b" + assert.Equal(t, expectedHash, hex.EncodeToString(hash)) +} + +func TestProveTypeString(t *testing.T) { + basicProve := ProveType(0) + assert.Equal(t, "Basic Prove", basicProve.String()) + + aggregatorProve := ProveType(1) + assert.Equal(t, "Aggregator Prove", aggregatorProve.String()) + + illegalProve := ProveType(3) + assert.Equal(t, "Illegal Prove type", illegalProve.String()) +} + +func TestProofMsgPublicKey(t *testing.T) { + privkey, err := crypto.GenerateKey() + assert.NoError(t, err) + + proofMsg := &ProofMsg{ + ProofDetail: &ProofDetail{ + ID: "testID", + Type: BasicProve, + Status: StatusOk, + Proof: &AggProof{ + Proof: []byte("testProof"), + Instance: []byte("testInstance"), + FinalPair: []byte("testFinalPair"), + Vk: []byte("testVk"), + BlockCount: 1, + }, + Error: "testError", + }, + } + assert.NoError(t, proofMsg.Sign(privkey)) + + // Test when publicKey is not set. + pk, err := proofMsg.PublicKey() + assert.NoError(t, err) + assert.Equal(t, common.Bytes2Hex(crypto.CompressPubkey(&privkey.PublicKey)), pk) + + // Test when publicKey is already set. + proofMsg.publicKey = common.Bytes2Hex(crypto.CompressPubkey(&privkey.PublicKey)) + pk, err = proofMsg.PublicKey() + assert.NoError(t, err) + assert.Equal(t, common.Bytes2Hex(crypto.CompressPubkey(&privkey.PublicKey)), pk) +} diff --git a/coordinator/config/config_test.go b/coordinator/config/config_test.go new file mode 100644 index 000000000..20ce28701 --- /dev/null +++ b/coordinator/config/config_test.go @@ -0,0 +1,137 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConfig(t *testing.T) { + configTemplate := `{ + "roller_manager_config": { + "compression_level": 9, + "rollers_per_session": 1, + "session_attempts": %d, + "collection_time": 180, + "token_time_to_live": 60, + "verifier": { + "mock_mode": true, + "params_path": "", + "agg_vk_path": "" + }, + "max_verifier_workers": %d, + "order_session": "%s" + }, + "db_config": { + "driver_name": "postgres", + "dsn": "postgres://admin:123456@localhost/test?sslmode=disable", + "maxOpenNum": 200, + "maxIdleNum": 20 + }, + "l2_config": { + "endpoint": "/var/lib/jenkins/workspace/SequencerPipeline/MyPrivateNetwork/geth.ipc" + } + }` + + t.Run("Success Case", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "example") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + config := fmt.Sprintf(configTemplate, defaultNumberOfSessionRetryAttempts, defaultNumberOfVerifierWorkers, "ASC") + _, err = tmpFile.WriteString(config) + assert.NoError(t, err) + + cfg, err := NewConfig(tmpFile.Name()) + assert.NoError(t, err) + + data, err := json.Marshal(cfg) + assert.NoError(t, err) + tmpJSON := fmt.Sprintf("/tmp/%d_config.json", time.Now().Nanosecond()) + defer func() { + if _, err = os.Stat(tmpJSON); err == nil { + assert.NoError(t, os.Remove(tmpJSON)) + } + }() + + assert.NoError(t, os.WriteFile(tmpJSON, data, 0644)) + + cfg2, err := NewConfig(tmpJSON) + assert.NoError(t, err) + assert.Equal(t, cfg, cfg2) + }) + + t.Run("File Not Found", func(t *testing.T) { + _, err := NewConfig("non_existent_file.json") + assert.ErrorIs(t, err, os.ErrNotExist) + }) + + t.Run("Invalid JSON Content", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "invalid_json_config.json") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + + _, err = tmpFile.WriteString("{ invalid_json: ") + assert.NoError(t, err) + + _, err = NewConfig(tmpFile.Name()) + assert.Error(t, err) + }) + + t.Run("Invalid Order Session", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "example") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + config := fmt.Sprintf(configTemplate, defaultNumberOfSessionRetryAttempts, defaultNumberOfVerifierWorkers, "INVALID") + _, err = tmpFile.WriteString(config) + assert.NoError(t, err) + + _, err = NewConfig(tmpFile.Name()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "roller config's order session is invalid") + }) + + t.Run("Default MaxVerifierWorkers", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "example") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + config := fmt.Sprintf(configTemplate, defaultNumberOfSessionRetryAttempts, 0, "ASC") + _, err = tmpFile.WriteString(config) + assert.NoError(t, err) + + cfg, err := NewConfig(tmpFile.Name()) + assert.NoError(t, err) + assert.Equal(t, defaultNumberOfVerifierWorkers, cfg.RollerManagerConfig.MaxVerifierWorkers) + }) + + t.Run("Default SessionAttempts", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "example") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + config := fmt.Sprintf(configTemplate, 0, defaultNumberOfVerifierWorkers, "ASC") + _, err = tmpFile.WriteString(config) + assert.NoError(t, err) + + cfg, err := NewConfig(tmpFile.Name()) + assert.NoError(t, err) + assert.Equal(t, uint8(defaultNumberOfSessionRetryAttempts), cfg.RollerManagerConfig.SessionAttempts) + }) +} diff --git a/database/config_test.go b/database/config_test.go new file mode 100644 index 000000000..4c8b723c9 --- /dev/null +++ b/database/config_test.go @@ -0,0 +1,70 @@ +package database + +import ( + "encoding/json" + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConfig(t *testing.T) { + configTemplate := `{ + "dsn": "postgres://postgres:123456@localhost:5444/test?sslmode=disable", + "driver_name": "postgres", + "maxOpenNum": %d, + "maxIdleNum": %d + }` + + t.Run("Success Case", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "example") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + config := fmt.Sprintf(configTemplate, 200, 20) + _, err = tmpFile.WriteString(config) + assert.NoError(t, err) + + cfg, err := NewConfig(tmpFile.Name()) + assert.NoError(t, err) + + data, err := json.Marshal(cfg) + assert.NoError(t, err) + tmpJSON := fmt.Sprintf("/tmp/%d_config.json", time.Now().Nanosecond()) + defer func() { + if _, err = os.Stat(tmpJSON); err == nil { + assert.NoError(t, os.Remove(tmpJSON)) + } + }() + + assert.NoError(t, os.WriteFile(tmpJSON, data, 0644)) + + cfg2, err := NewConfig(tmpJSON) + assert.NoError(t, err) + assert.Equal(t, cfg, cfg2) + }) + + t.Run("File Not Found", func(t *testing.T) { + _, err := NewConfig("non_existent_file.json") + assert.ErrorIs(t, err, os.ErrNotExist) + }) + + t.Run("Invalid JSON Content", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "invalid_json_config.json") + assert.NoError(t, err) + defer func() { + assert.NoError(t, tmpFile.Close()) + assert.NoError(t, os.Remove(tmpFile.Name())) + }() + + _, err = tmpFile.WriteString("{ invalid_json: ") + assert.NoError(t, err) + + _, err = NewConfig(tmpFile.Name()) + assert.Error(t, err) + }) +}