From a0ac5334d162de013614ec791c2e616f1865ccff Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:30:27 -0800 Subject: [PATCH] chore: return error for untested fields in tools.yaml (#239) This only checks within `SourceConfig`, `ToolConfig`, and `AuthSourceConfig`. Error when an unknown field is provided: `2025-01-27T22:43:46.988401-08:00 ERROR "unable to parse tool file at \"tools.yaml\": unable to parse as \"cloud-sql-postgres\": [2:1] unknown field \"extra\"\n 1 | database: test_database\n> 2 | extra: here\n ^\n 3 | instance: toolbox-cloudsql\n 4 | kind: cloud-sql-postgres\n 5 | password: postgres\n 6 | "` Error when a required field is not provided: `2025-01-27T17:49:47.584846-08:00 ERROR "unable to parse tool file at \"tools.yaml\": validation failed: Key: 'Config.Region' Error:Field validation for 'Region' failed on the 'required' tag"` --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> --- cmd/root.go | 2 +- cmd/root_test.go | 8 ++ go.mod | 11 +- go.sum | 24 +++- internal/auth/google/google.go | 6 +- internal/server/config.go | 128 ++++++++++-------- internal/sources/alloydbpg/alloydb_pg.go | 20 +-- internal/sources/alloydbpg/alloydb_pg_test.go | 56 +++++++- .../sources/cloudsqlmssql/cloud_sql_mssql.go | 22 +-- .../cloudsqlmssql/cloud_sql_mssql_test.go | 79 ++++++++++- .../sources/cloudsqlmysql/cloud_sql_mysql.go | 18 +-- .../cloudsqlmysql/cloud_sql_mysql_test.go | 54 +++++++- internal/sources/cloudsqlpg/cloud_sql_pg.go | 18 +-- .../sources/cloudsqlpg/cloud_sql_pg_test.go | 54 +++++++- internal/sources/dgraph/dgraph.go | 8 +- internal/sources/dgraph/dgraph_test.go | 61 +++++++++ internal/sources/mssql/mssql.go | 14 +- internal/sources/mssql/mssql_test.go | 57 ++++++++ internal/sources/mysql/mysql.go | 14 +- internal/sources/mysql/mysql_test.go | 57 ++++++++ internal/sources/neo4j/neo4j.go | 12 +- internal/sources/neo4j/neo4j_test.go | 55 ++++++++ internal/sources/postgres/postgres.go | 14 +- internal/sources/postgres/postgres_test.go | 57 ++++++++ internal/sources/spanner/spanner.go | 12 +- internal/sources/spanner/spanner_test.go | 36 ++++- internal/tools/dgraph/dgraph.go | 10 +- internal/tools/mssqlsql/mssqlsql.go | 10 +- internal/tools/mysqlsql/mysqlsql.go | 10 +- internal/tools/neo4j/neo4j.go | 10 +- internal/tools/parameters.go | 45 +++--- internal/tools/parameters_test.go | 12 +- internal/tools/postgressql/postgressql.go | 10 +- internal/tools/spanner/spanner.go | 10 +- internal/util/util.go | 18 +++ 35 files changed, 824 insertions(+), 208 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 3faa864fc5b..b7df86b3c15 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -128,7 +128,7 @@ type ToolsFile struct { func parseToolsFile(raw []byte) (ToolsFile, error) { var toolsFile ToolsFile // Parse contents - err := yaml.Unmarshal(raw, &toolsFile) + err := yaml.UnmarshalWithOptions(raw, &toolsFile, yaml.Strict()) if err != nil { return toolsFile, err } diff --git a/cmd/root_test.go b/cmd/root_test.go index 6a03bc9ce94..5cf163aa089 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -275,6 +275,8 @@ func TestParseToolFile(t *testing.T) { region: my-region instance: my-instance database: my_db + user: my_user + password: my_pass tools: example_tool: kind: postgres-sql @@ -300,6 +302,8 @@ func TestParseToolFile(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, Tools: server.ToolConfigs{ @@ -362,6 +366,8 @@ func TestParseToolFileWithAuth(t *testing.T) { region: my-region instance: my-instance database: my_db + user: my_user + password: my_pass authSources: my-google-service: kind: google @@ -410,6 +416,8 @@ func TestParseToolFileWithAuth(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, AuthSources: server.AuthSourceConfigs{ diff --git a/go.mod b/go.mod index 2c3a7a66ce0..68bdd00410c 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/httplog/v2 v2.1.1 github.com/go-chi/render v1.0.3 + github.com/go-playground/validator/v10 v10.24.0 github.com/go-sql-driver/mysql v1.8.1 github.com/goccy/go-yaml v1.15.13 github.com/google/go-cmp v0.6.0 @@ -56,8 +57,11 @@ require ( github.com/envoyproxy/go-control-plane v0.13.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect @@ -69,6 +73,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/spf13/pflag v1.0.5 // indirect @@ -84,11 +89,11 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.33.0 // indirect go.opentelemetry.io/proto/otlp v1.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.31.0 // indirect - golang.org/x/net v0.33.0 // indirect + golang.org/x/crypto v0.32.0 // indirect + golang.org/x/net v0.34.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect golang.org/x/sync v0.10.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.8.0 // indirect google.golang.org/genproto v0.0.0-20241209162323-e6fa225c2576 // indirect diff --git a/go.sum b/go.sum index 25d4f2ca0b1..1db228445b8 100644 --- a/go.sum +++ b/go.sum @@ -727,6 +727,8 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= @@ -751,6 +753,14 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.24.0 h1:KHQckvo8G6hlWnrPX4NJJ+aBfWNAE/HH+qdL2cBpCmg= +github.com/go-playground/validator/v10 v10.24.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= @@ -927,6 +937,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lyft/protoc-gen-star v0.6.0/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o= @@ -1059,8 +1071,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1175,8 +1187,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1303,8 +1315,8 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= diff --git a/internal/auth/google/google.go b/internal/auth/google/google.go index 93cd354e9ff..009decd8352 100644 --- a/internal/auth/google/google.go +++ b/internal/auth/google/google.go @@ -30,9 +30,9 @@ var _ auth.AuthSourceConfig = Config{} // Auth source configuration type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - ClientID string `yaml:"clientId"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + ClientID string `yaml:"clientId" validate:"required"` } // Returns the auth source kind diff --git a/internal/server/config.go b/internal/server/config.go index 42c68d3ecaf..e6a15293956 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -135,76 +135,84 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error { } for name, u := range raw { - var k struct { - Kind string `yaml:"kind"` + // Unmarshal to a general type that ensure it capture all fields + var v map[string]any + if err := u.Unmarshal(&v); err != nil { + return fmt.Errorf("unable to unmarshal %q: %w", name, err) } - err := u.Unmarshal(&k) + + kind, ok := v["kind"] + if !ok { + return fmt.Errorf("missing 'kind' field for %q", name) + } + + dec, err := util.NewStrictDecoder(v) if err != nil { - return fmt.Errorf("missing 'kind' field for %q", k) + return fmt.Errorf("error creating decoder: %w", err) } - switch k.Kind { + switch kind { case alloydbpgsrc.SourceKind: actual := alloydbpgsrc.Config{Name: name, IPType: "public"} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case cloudsqlpgsrc.SourceKind: actual := cloudsqlpgsrc.Config{Name: name, IPType: "public"} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case postgressrc.SourceKind: actual := postgressrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case cloudsqlmysqlsrc.SourceKind: actual := cloudsqlmysqlsrc.Config{Name: name, IPType: "public"} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case mysqlsrc.SourceKind: actual := mysqlsrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case spannersrc.SourceKind: actual := spannersrc.Config{Name: name, Dialect: "googlesql"} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case neo4jrc.SourceKind: actual := neo4jrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case cloudsqlmssqlsrc.SourceKind: - actual := cloudsqlmssqlsrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + actual := cloudsqlmssqlsrc.Config{Name: name, IPType: "public"} + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case mssqlsrc.SourceKind: actual := mssqlsrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case dgraphsrc.SourceKind: actual := dgraphsrc.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual default: - return fmt.Errorf("%q is not a valid kind of data source", k.Kind) + return fmt.Errorf("%q is not a valid kind of data source", kind) } } @@ -226,22 +234,29 @@ func (c *AuthSourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) err } for name, u := range raw { - var k struct { - Kind string `yaml:"kind"` + var v map[string]any + if err := u.Unmarshal(&v); err != nil { + return fmt.Errorf("unable to unmarshal %q: %w", name, err) } - err := u.Unmarshal(&k) + + kind, ok := v["kind"] + if !ok { + return fmt.Errorf("missing 'kind' field for %q", name) + } + + dec, err := util.NewStrictDecoder(v) if err != nil { - return fmt.Errorf("missing 'kind' field for %q", k) + return fmt.Errorf("error creating decoder: %w", err) } - switch k.Kind { + switch kind { case google.AuthSourceKind: actual := google.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual default: - return fmt.Errorf("%q is not a valid kind of auth source", k.Kind) + return fmt.Errorf("%q is not a valid kind of auth source", kind) } } return nil @@ -262,52 +277,59 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error { } for name, u := range raw { - var k struct { - Kind string `yaml:"kind"` + var v map[string]any + if err := u.Unmarshal(&v); err != nil { + return fmt.Errorf("unable to unmarshal %q: %w", name, err) } - err := u.Unmarshal(&k) - if err != nil { + + kind, ok := v["kind"] + if !ok { return fmt.Errorf("missing 'kind' field for %q", name) } - switch k.Kind { + + dec, err := util.NewStrictDecoder(v) + if err != nil { + return fmt.Errorf("error creating decoder: %w", err) + } + switch kind { case postgressql.ToolKind: actual := postgressql.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case mysqlsql.ToolKind: actual := mysqlsql.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case spanner.ToolKind: actual := spanner.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case neo4jtool.ToolKind: actual := neo4jtool.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case mssqlsql.ToolKind: actual := mssqlsql.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual case dgraph.ToolKind: actual := dgraph.Config{Name: name} - if err := u.Unmarshal(&actual); err != nil { - return fmt.Errorf("unable to parse as %q: %w", k.Kind, err) + if err := dec.Decode(&actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) } (*c)[name] = actual default: - return fmt.Errorf("%q is not a valid kind of tool", k.Kind) + return fmt.Errorf("%q is not a valid kind of tool", kind) } } diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index 00f4c9e37c5..6a615c0f7d6 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -33,16 +33,16 @@ const SourceKind string = "alloydb-postgres" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string `yaml:"project"` - Region string `yaml:"region"` - Cluster string `yaml:"cluster"` - Instance string `yaml:"instance"` - IPType sources.IPType `yaml:"ipType"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Region string `yaml:"region" validate:"required"` + Cluster string `yaml:"cluster" validate:"required"` + Instance string `yaml:"instance" validate:"required"` + IPType sources.IPType `yaml:"ipType" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/alloydbpg/alloydb_pg_test.go b/internal/sources/alloydbpg/alloydb_pg_test.go index d061cfc62b1..8ca536a273d 100644 --- a/internal/sources/alloydbpg/alloydb_pg_test.go +++ b/internal/sources/alloydbpg/alloydb_pg_test.go @@ -42,6 +42,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { cluster: my-cluster instance: my-instance database: my_db + user: my_user + password: my_pass `, want: map[string]sources.SourceConfig{ "my-pg-instance": alloydbpg.Config{ @@ -53,6 +55,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -68,6 +72,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { instance: my-instance ipType: Public database: my_db + user: my_user + password: my_pass `, want: map[string]sources.SourceConfig{ "my-pg-instance": alloydbpg.Config{ @@ -79,6 +85,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -94,6 +102,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { instance: my-instance ipType: private database: my_db + user: my_user + password: my_pass `, want: map[string]sources.SourceConfig{ "my-pg-instance": alloydbpg.Config{ @@ -105,6 +115,8 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { Instance: "my-instance", IPType: "private", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -126,10 +138,11 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) { } } -func FailParseFromYamlAlloyDBPg(t *testing.T) { +func TestFailParseFromYaml(t *testing.T) { tcs := []struct { desc string in string + err string }{ { desc: "invalid ipType", @@ -143,7 +156,42 @@ func FailParseFromYamlAlloyDBPg(t *testing.T) { instance: my-instance ipType: fail database: my_db + user: my_user + password: my_pass `, + err: "unable to parse as \"alloydb-postgres\": ipType invalid: must be one of \"public\", or \"private\"", + }, + { + desc: "extra field", + in: ` + sources: + my-pg-instance: + kind: alloydb-postgres + project: my-project + region: my-region + cluster: my-cluster + instance: my-instance + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"alloydb-postgres\": [3:1] unknown field \"foo\"\n 1 | cluster: my-cluster\n 2 | database: my_db\n> 3 | foo: bar\n ^\n 4 | instance: my-instance\n 5 | kind: alloydb-postgres\n 6 | password: my_pass\n 7 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-pg-instance: + kind: alloydb-postgres + region: my-region + cluster: my-cluster + instance: my-instance + database: my_db + user: my_user + password: my_pass + `, + err: "unable to parse as \"alloydb-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag", }, } for _, tc := range tcs { @@ -154,7 +202,11 @@ func FailParseFromYamlAlloyDBPg(t *testing.T) { // Parse contents err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) if err == nil { - t.Fatalf("expect parsing to fail: %s", err) + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) } }) } diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index fdf60c29dfe..03895faf948 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -33,16 +33,16 @@ var _ sources.SourceConfig = Config{} type Config struct { // Cloud SQL MSSQL configs - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string `yaml:"project"` - Region string `yaml:"region"` - Instance string `yaml:"instance"` - IPAddress string `yaml:"ipAddress"` - IPType string `yaml:"ipType"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Region string `yaml:"region" validate:"required"` + Instance string `yaml:"instance" validate:"required"` + IPAddress string `yaml:"ipAddress" validate:"required"` + IPType sources.IPType `yaml:"ipType" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { @@ -52,7 +52,7 @@ func (r Config) SourceConfigKind() string { func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { // Initializes a Cloud SQL MSSQL source - db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPAddress, r.IPType, r.User, r.Password, r.Database) + db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPAddress, r.IPType.String(), r.User, r.Password, r.Database) if err != nil { return nil, fmt.Errorf("unable to create db connection: %w", err) } diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go index 939c1fa2820..036148f5025 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go @@ -41,7 +41,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) { instance: my-instance database: my_db ipAddress: localhost - ipType: public + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-instance": cloudsqlmssql.Config{ @@ -53,6 +54,8 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) { IPAddress: "localhost", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -74,3 +77,77 @@ func TestParseFromYamlCloudSQLMssql(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "invalid ipType", + in: ` + sources: + my-instance: + kind: cloud-sql-mssql + project: my-project + region: my-region + instance: my-instance + ipType: fail + database: my_db + ipAddress: localhost + user: my_user + password: my_pass + `, + err: "unable to parse as \"cloud-sql-mssql\": ipType invalid: must be one of \"public\", or \"private\"", + }, + { + desc: "extra field", + in: ` + sources: + my-instance: + kind: cloud-sql-mssql + project: my-project + region: my-region + instance: my-instance + database: my_db + ipAddress: localhost + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"cloud-sql-mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | ipAddress: localhost\n 5 | kind: cloud-sql-mssql\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-instance: + kind: cloud-sql-mssql + region: my-region + instance: my-instance + database: my_db + ipAddress: localhost + user: my_user + password: my_pass + `, + err: "unable to parse as \"cloud-sql-mssql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 6b610c00d5b..d469eda90b0 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -32,15 +32,15 @@ const SourceKind string = "cloud-sql-mysql" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string `yaml:"project"` - Region string `yaml:"region"` - Instance string `yaml:"instance"` - IPType sources.IPType `yaml:"ipType"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Region string `yaml:"region" validate:"required"` + Instance string `yaml:"instance" validate:"required"` + IPType sources.IPType `yaml:"ipType" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql_test.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql_test.go index 86277d63a3d..14730d0dc9a 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql_test.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql_test.go @@ -40,6 +40,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { region: my-region instance: my-instance database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-mysql-instance": cloudsqlmysql.Config{ @@ -50,6 +52,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -64,6 +68,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { instance: my-instance ipType: Public database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-mysql-instance": cloudsqlmysql.Config{ @@ -74,6 +80,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -88,6 +96,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { instance: my-instance ipType: private database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-mysql-instance": cloudsqlmysql.Config{ @@ -98,6 +108,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { Instance: "my-instance", IPType: "private", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -120,10 +132,11 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { } -func FailParseFromYamlCloudSQLMySQL(t *testing.T) { +func TestFailParseFromYaml(t *testing.T) { tcs := []struct { desc string in string + err string }{ { desc: "invalid ipType", @@ -136,7 +149,40 @@ func FailParseFromYamlCloudSQLMySQL(t *testing.T) { instance: my-instance ipType: fail database: my_db + user: my_user + password: my_pass `, + err: "unable to parse as \"cloud-sql-mysql\": ipType invalid: must be one of \"public\", or \"private\"", + }, + { + desc: "extra field", + in: ` + sources: + my-mysql-instance: + kind: cloud-sql-mysql + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"cloud-sql-mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-mysql\n 5 | password: my_pass\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-mysql-instance: + kind: cloud-sql-mysql + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + `, + err: "unable to parse as \"cloud-sql-mysql\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag", }, } for _, tc := range tcs { @@ -147,7 +193,11 @@ func FailParseFromYamlCloudSQLMySQL(t *testing.T) { // Parse contents err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) if err == nil { - t.Fatalf("expect parsing to fail: %s", err) + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) } }) } diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index c8659241ac5..877ded90854 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -32,15 +32,15 @@ const SourceKind string = "cloud-sql-postgres" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string `yaml:"project"` - Region string `yaml:"region"` - Instance string `yaml:"instance"` - IPType sources.IPType `yaml:"ipType"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Region string `yaml:"region" validate:"required"` + Instance string `yaml:"instance" validate:"required"` + IPType sources.IPType `yaml:"ipType" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg_test.go b/internal/sources/cloudsqlpg/cloud_sql_pg_test.go index 9ed54e44f90..a666e3c2097 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg_test.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg_test.go @@ -40,6 +40,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { region: my-region instance: my-instance database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-pg-instance": cloudsqlpg.Config{ @@ -50,6 +52,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -64,6 +68,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { instance: my-instance ipType: Public database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-pg-instance": cloudsqlpg.Config{ @@ -74,6 +80,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { Instance: "my-instance", IPType: "public", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -88,6 +96,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { instance: my-instance ipType: private database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-pg-instance": cloudsqlpg.Config{ @@ -98,6 +108,8 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { Instance: "my-instance", IPType: "private", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -120,10 +132,11 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) { } -func FailParseFromYamlCloudSQLPg(t *testing.T) { +func TestFailParseFromYaml(t *testing.T) { tcs := []struct { desc string in string + err string }{ { desc: "invalid ipType", @@ -136,7 +149,40 @@ func FailParseFromYamlCloudSQLPg(t *testing.T) { instance: my-instance ipType: fail database: my_db + user: my_user + password: my_pass `, + err: "unable to parse as \"cloud-sql-postgres\": ipType invalid: must be one of \"public\", or \"private\"", + }, + { + desc: "extra field", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + project: my-project + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"cloud-sql-postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: cloud-sql-postgres\n 5 | password: my_pass\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-pg-instance: + kind: cloud-sql-postgres + region: my-region + instance: my-instance + database: my_db + user: my_user + password: my_pass + `, + err: "unable to parse as \"cloud-sql-postgres\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag", }, } for _, tc := range tcs { @@ -147,7 +193,11 @@ func FailParseFromYamlCloudSQLPg(t *testing.T) { // Parse contents err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) if err == nil { - t.Fatalf("expect parsing to fail: %s", err) + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) } }) } diff --git a/internal/sources/dgraph/dgraph.go b/internal/sources/dgraph/dgraph.go index 197ba1087ff..32afcde14e2 100644 --- a/internal/sources/dgraph/dgraph.go +++ b/internal/sources/dgraph/dgraph.go @@ -50,9 +50,9 @@ type DgraphClient struct { } type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - DgraphUrl string `yaml:"dgraphUrl"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + DgraphUrl string `yaml:"dgraphUrl" validate:"required"` User string `yaml:"user"` Password string `yaml:"password"` Namespace uint64 `yaml:"namespace"` @@ -108,7 +108,7 @@ func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (* hc := &DgraphClient{ httpClient: &http.Client{}, - baseUrl: r.DgraphUrl, + baseUrl: r.DgraphUrl, HttpToken: &HttpToken{ UserId: r.User, Namespace: r.Namespace, diff --git a/internal/sources/dgraph/dgraph_test.go b/internal/sources/dgraph/dgraph_test.go index 14b073ec0bb..825bf681065 100644 --- a/internal/sources/dgraph/dgraph_test.go +++ b/internal/sources/dgraph/dgraph_test.go @@ -54,6 +54,22 @@ func TestParseFromYamlDgraph(t *testing.T) { }, }, }, + { + desc: "basic example minimal field", + in: ` + sources: + my-dgraph-instance: + kind: dgraph + dgraphUrl: https://localhost:8080 + `, + want: server.SourceConfigs{ + "my-dgraph-instance": dgraph.Config{ + Name: "my-dgraph-instance", + Kind: dgraph.SourceKind, + DgraphUrl: "https://localhost:8080", + }, + }, + }, } for _, tc := range tcs { @@ -74,3 +90,48 @@ func TestParseFromYamlDgraph(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-dgraph-instance: + kind: dgraph + dgraphUrl: https://localhost:8080 + foo: bar + `, + err: "unable to parse as \"dgraph\": [2:1] unknown field \"foo\"\n 1 | dgraphUrl: https://localhost:8080\n> 2 | foo: bar\n ^\n 3 | kind: dgraph", + }, + { + desc: "missing required field", + in: ` + sources: + my-dgraph-instance: + kind: dgraph + `, + err: "unable to parse as \"dgraph\": Key: 'Config.DgraphUrl' Error:Field validation for 'DgraphUrl' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/mssql/mssql.go b/internal/sources/mssql/mssql.go index d1d860aa9c7..b961d8b78a1 100644 --- a/internal/sources/mssql/mssql.go +++ b/internal/sources/mssql/mssql.go @@ -31,13 +31,13 @@ var _ sources.SourceConfig = Config{} type Config struct { // Cloud SQL MSSQL configs - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Host string `yaml:"host"` - Port string `yaml:"port"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/mssql/mssql_test.go b/internal/sources/mssql/mssql_test.go index 2963f384714..d71bfc41f5b 100644 --- a/internal/sources/mssql/mssql_test.go +++ b/internal/sources/mssql/mssql_test.go @@ -39,6 +39,8 @@ func TestParseFromYamlMssql(t *testing.T) { host: 0.0.0.0 port: my-port database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-mssql-instance": mssql.Config{ @@ -47,6 +49,8 @@ func TestParseFromYamlMssql(t *testing.T) { Host: "0.0.0.0", Port: "my-port", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -68,3 +72,56 @@ func TestParseFromYamlMssql(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-mssql-instance: + kind: mssql + host: 0.0.0.0 + port: my-port + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"mssql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mssql\n 5 | password: my_pass\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-mssql-instance: + kind: mssql + host: 0.0.0.0 + port: my-port + database: my_db + user: my_user + `, + err: "unable to parse as \"mssql\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/mysql/mysql.go b/internal/sources/mysql/mysql.go index c1321e8071a..e04f9550f73 100644 --- a/internal/sources/mysql/mysql.go +++ b/internal/sources/mysql/mysql.go @@ -30,13 +30,13 @@ const SourceKind string = "mysql" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Host string `yaml:"host"` - Port string `yaml:"port"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/mysql/mysql_test.go b/internal/sources/mysql/mysql_test.go index 3e312a20eed..633fc9aa214 100644 --- a/internal/sources/mysql/mysql_test.go +++ b/internal/sources/mysql/mysql_test.go @@ -39,6 +39,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { host: 0.0.0.0 port: my-host database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-mysql-instance": mysql.Config{ @@ -47,6 +49,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { Host: "0.0.0.0", Port: "my-host", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -68,3 +72,56 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-mysql-instance: + kind: mysql + host: 0.0.0.0 + port: my-host + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mysql\n 5 | password: my_pass\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-mysql-instance: + kind: mysql + port: my-host + database: my_db + user: my_user + password: my_pass + `, + err: "unable to parse as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/neo4j/neo4j.go b/internal/sources/neo4j/neo4j.go index 5f3de6e9c67..38258a24e95 100644 --- a/internal/sources/neo4j/neo4j.go +++ b/internal/sources/neo4j/neo4j.go @@ -29,12 +29,12 @@ const SourceKind string = "neo4j" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Uri string `yaml:"uri"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Uri string `yaml:"uri" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/neo4j/neo4j_test.go b/internal/sources/neo4j/neo4j_test.go index 000d088ce83..c602474be7d 100644 --- a/internal/sources/neo4j/neo4j_test.go +++ b/internal/sources/neo4j/neo4j_test.go @@ -38,6 +38,8 @@ func TestParseFromYamlNeo4j(t *testing.T) { kind: neo4j uri: neo4j+s://my-host:7687 database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-neo4j-instance": neo4j.Config{ @@ -45,6 +47,8 @@ func TestParseFromYamlNeo4j(t *testing.T) { Kind: neo4j.SourceKind, Uri: "neo4j+s://my-host:7687", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -66,3 +70,54 @@ func TestParseFromYamlNeo4j(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-neo4j-instance: + kind: neo4j + uri: neo4j+s://my-host:7687 + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"neo4j\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | kind: neo4j\n 4 | password: my_pass\n 5 | uri: neo4j+s://my-host:7687\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-neo4j-instance: + kind: neo4j + uri: neo4j+s://my-host:7687 + database: my_db + user: my_user + `, + err: "unable to parse as \"neo4j\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index 9694ca67bfc..9f2aba4d95b 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -29,13 +29,13 @@ const SourceKind string = "postgres" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Host string `yaml:"host"` - Port string `yaml:"port"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/postgres/postgres_test.go b/internal/sources/postgres/postgres_test.go index 8d80c8710c5..09603ebf883 100644 --- a/internal/sources/postgres/postgres_test.go +++ b/internal/sources/postgres/postgres_test.go @@ -39,6 +39,8 @@ func TestParseFromYamlPostgres(t *testing.T) { host: my-host port: 0.0.0.0 database: my_db + user: my_user + password: my_pass `, want: server.SourceConfigs{ "my-pg-instance": postgres.Config{ @@ -47,6 +49,8 @@ func TestParseFromYamlPostgres(t *testing.T) { Host: "my-host", Port: "0.0.0.0", Database: "my_db", + User: "my_user", + Password: "my_pass", }, }, }, @@ -68,3 +72,56 @@ func TestParseFromYamlPostgres(t *testing.T) { } } + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-pg-instance: + kind: postgres + host: my-host + port: 0.0.0.0 + database: my_db + user: my_user + password: my_pass + foo: bar + `, + err: "unable to parse as \"postgres\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: my-host\n 4 | kind: postgres\n 5 | password: my_pass\n 6 | ", + }, + { + desc: "missing required field", + in: ` + sources: + my-pg-instance: + kind: postgres + host: my-host + port: 0.0.0.0 + database: my_db + user: my_user + `, + err: "unable to parse as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/spanner/spanner.go b/internal/sources/spanner/spanner.go index d8276f3bf20..06d066fd367 100644 --- a/internal/sources/spanner/spanner.go +++ b/internal/sources/spanner/spanner.go @@ -30,12 +30,12 @@ const SourceKind string = "spanner" var _ sources.SourceConfig = Config{} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string `yaml:"project"` - Instance string `yaml:"instance"` - Dialect sources.Dialect `yaml:"dialect"` - Database string `yaml:"database"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Instance string `yaml:"instance" validate:"required"` + Dialect sources.Dialect `yaml:"dialect" validate:"required"` + Database string `yaml:"database" validate:"required"` } func (r Config) SourceConfigKind() string { diff --git a/internal/sources/spanner/spanner_test.go b/internal/sources/spanner/spanner_test.go index af4ac54c5f7..5d6a9320c76 100644 --- a/internal/sources/spanner/spanner_test.go +++ b/internal/sources/spanner/spanner_test.go @@ -115,10 +115,11 @@ func TestParseFromYamlSpannerDb(t *testing.T) { } -func FailParseFromYamlSpanner(t *testing.T) { +func TestFailParseFromYaml(t *testing.T) { tcs := []struct { desc string in string + err string }{ { desc: "invalid dialect", @@ -128,9 +129,34 @@ func FailParseFromYamlSpanner(t *testing.T) { kind: spanner project: my-project instance: my-instance - dialect: fail + dialect: fail database: my_db `, + err: "unable to parse as \"spanner\": dialect invalid: must be one of \"googlesql\", or \"postgresql\"", + }, + { + desc: "extra field", + in: ` + sources: + my-spanner-instance: + kind: spanner + project: my-project + instance: my-instance + database: my_db + foo: bar + `, + err: "unable to parse as \"spanner\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | instance: my-instance\n 4 | kind: spanner\n 5 | project: my-project", + }, + { + desc: "missing required field", + in: ` + sources: + my-spanner-instance: + kind: spanner + project: my-project + instance: my-instance + `, + err: "unable to parse as \"spanner\": Key: 'Config.Database' Error:Field validation for 'Database' failed on the 'required' tag", }, } for _, tc := range tcs { @@ -141,7 +167,11 @@ func FailParseFromYamlSpanner(t *testing.T) { // Parse contents err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) if err == nil { - t.Fatalf("expect parsing to fail: %s", err) + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) } }) } diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index c5a57fa1d85..19c49cec548 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -35,11 +35,11 @@ var _ compatibleSource = &dgraph.Source{} var compatibleSources = [...]string{dgraph.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` IsQuery bool `yaml:"isQuery"` Timeout string `yaml:"timeout"` Parameters tools.Parameters `yaml:"parameters"` diff --git a/internal/tools/mssqlsql/mssqlsql.go b/internal/tools/mssqlsql/mssqlsql.go index 4b97dd6116d..cad66cf7767 100644 --- a/internal/tools/mssqlsql/mssqlsql.go +++ b/internal/tools/mssqlsql/mssqlsql.go @@ -39,11 +39,11 @@ var _ compatibleSource = &mssql.Source{} var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` AuthRequired []string `yaml:"authRequired"` Parameters tools.Parameters `yaml:"parameters"` } diff --git a/internal/tools/mysqlsql/mysqlsql.go b/internal/tools/mysqlsql/mysqlsql.go index 4317969bb2c..40a1621aa2e 100644 --- a/internal/tools/mysqlsql/mysqlsql.go +++ b/internal/tools/mysqlsql/mysqlsql.go @@ -39,11 +39,11 @@ var _ compatibleSource = &mysql.Source{} var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` AuthRequired []string `yaml:"authRequired"` Parameters tools.Parameters `yaml:"parameters"` } diff --git a/internal/tools/neo4j/neo4j.go b/internal/tools/neo4j/neo4j.go index 216545658ae..3cf8c2b5c25 100644 --- a/internal/tools/neo4j/neo4j.go +++ b/internal/tools/neo4j/neo4j.go @@ -39,11 +39,11 @@ var _ compatibleSource = &neo4jsc.Source{} var compatibleSources = [...]string{neo4jsc.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` Parameters tools.Parameters `yaml:"parameters"` } diff --git a/internal/tools/parameters.go b/internal/tools/parameters.go index 98b4bf59fe3..504b4d9f615 100644 --- a/internal/tools/parameters.go +++ b/internal/tools/parameters.go @@ -84,16 +84,16 @@ func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} { // Input: {"role": "admin", "$age": 30} // Output: {"$role": "admin", "$age": 30} func (p ParamValues) AsMapWithDollarPrefix() map[string]interface{} { - params := make(map[string]interface{}) + params := make(map[string]interface{}) - for _, param := range p { - key := param.Name - if !strings.HasPrefix(key, "$") { + for _, param := range p { + key := param.Name + if !strings.HasPrefix(key, "$") { key = "$" + key } - params[key] = param.Value - } - return params + params[key] = param.Value + } + return params } func parseFromAuthSource(paramAuthSources []ParamAuthSource, claimsMap map[string]map[string]any) (any, error) { @@ -178,44 +178,50 @@ func (c *Parameters) UnmarshalYAML(unmarshal func(interface{}) error) error { // parseParamFromDelayedUnmarshaler is a helper function that is required to parse // parameters because there are multiple different types func parseParamFromDelayedUnmarshaler(u *util.DelayedUnmarshaler) (Parameter, error) { - var p CommonParameter + var p map[string]any err := u.Unmarshal(&p) if err != nil { - return nil, fmt.Errorf("parameter missing required fields: %w", err) + return nil, fmt.Errorf("error parsing parameters: %w", err) } - switch p.Type { + + t, ok := p["type"] + if !ok { + return nil, fmt.Errorf("parameter is missing 'type' field: %w", err) + } + + switch t { case typeString: a := &StringParameter{} if err := u.Unmarshal(a); err != nil { - return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + return nil, fmt.Errorf("unable to parse as %q: %w", t, err) } return a, nil case typeInt: a := &IntParameter{} if err := u.Unmarshal(a); err != nil { - return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + return nil, fmt.Errorf("unable to parse as %q: %w", t, err) } return a, nil case typeFloat: a := &FloatParameter{} if err := u.Unmarshal(a); err != nil { - return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + return nil, fmt.Errorf("unable to parse as %q: %w", t, err) } return a, nil case typeBool: a := &BooleanParameter{} if err := u.Unmarshal(a); err != nil { - return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + return nil, fmt.Errorf("unable to parse as %q: %w", t, err) } return a, nil case typeArray: a := &ArrayParameter{} if err := u.Unmarshal(a); err != nil { - return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + return nil, fmt.Errorf("unable to parse as %q: %w", t, err) } return a, nil } - return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type) + return nil, fmt.Errorf("%q is not valid type for a parameter!", t) } func (ps Parameters) Manifest() []ParameterManifest { @@ -515,15 +521,14 @@ type ArrayParameter struct { } func (p *ArrayParameter) UnmarshalYAML(unmarshal func(interface{}) error) error { - if err := unmarshal(&p.CommonParameter); err != nil { - return err - } var rawItem struct { - Items util.DelayedUnmarshaler `yaml:"items"` + CommonParameter `yaml:",inline"` + Items util.DelayedUnmarshaler `yaml:"items"` } if err := unmarshal(&rawItem); err != nil { return err } + p.CommonParameter = rawItem.CommonParameter i, err := parseParamFromDelayedUnmarshaler(&rawItem.Items) if err != nil { return fmt.Errorf("unable to parse 'items' field: %w", err) diff --git a/internal/tools/parameters_test.go b/internal/tools/parameters_test.go index e564b71d0a3..6dbddaad551 100644 --- a/internal/tools/parameters_test.go +++ b/internal/tools/parameters_test.go @@ -608,12 +608,12 @@ func TestAuthParametersParse(t *testing.T) { func TestParamValues(t *testing.T) { tcs := []struct { - name string - in tools.ParamValues - wantSlice []any - wantMap map[string]interface{} - wantMapOrdered map[string]interface{} - wantMapWithDollar map[string]interface{} + name string + in tools.ParamValues + wantSlice []any + wantMap map[string]interface{} + wantMapOrdered map[string]interface{} + wantMapWithDollar map[string]interface{} }{ { name: "string", diff --git a/internal/tools/postgressql/postgressql.go b/internal/tools/postgressql/postgressql.go index cab16b2d883..cec03d23d98 100644 --- a/internal/tools/postgressql/postgressql.go +++ b/internal/tools/postgressql/postgressql.go @@ -41,11 +41,11 @@ var _ compatibleSource = &postgres.Source{} var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` AuthRequired []string `yaml:"authRequired"` Parameters tools.Parameters `yaml:"parameters"` } diff --git a/internal/tools/spanner/spanner.go b/internal/tools/spanner/spanner.go index 9cf2e6883ed..3b46b11d1de 100644 --- a/internal/tools/spanner/spanner.go +++ b/internal/tools/spanner/spanner.go @@ -39,11 +39,11 @@ var _ compatibleSource = &spannerdb.Source{} var compatibleSources = [...]string{spannerdb.SourceKind} type Config struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Source string `yaml:"source"` - Description string `yaml:"description"` - Statement string `yaml:"statement"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` AuthRequired []string `yaml:"authRequired"` Parameters tools.Parameters `yaml:"parameters"` } diff --git a/internal/util/util.go b/internal/util/util.go index 255869dc359..ab4642854e5 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -14,6 +14,10 @@ package util import ( + "bytes" + "fmt" + + "github.com/go-playground/validator/v10" yaml "github.com/goccy/go-yaml" ) @@ -39,3 +43,17 @@ type contextKey string // UserAgentKey is the key used to store userAgent within context const UserAgentKey contextKey = "userAgent" + +func NewStrictDecoder(v interface{}) (*yaml.Decoder, error) { + b, err := yaml.Marshal(v) + if err != nil { + return nil, fmt.Errorf("fail to marshal %q: %w", v, err) + } + + dec := yaml.NewDecoder( + bytes.NewReader(b), + yaml.Strict(), + yaml.Validator(validator.New()), + ) + return dec, nil +}