feat(sqlserver): add mssql source (#255)

Add `mssql` source for non-cloud sql server.
This commit is contained in:
Yuan
2025-01-31 11:13:54 -08:00
committed by GitHub
parent 5f9fc762e5
commit 8fca0a95ee
5 changed files with 185 additions and 3 deletions

2
go.mod
View File

@@ -18,6 +18,7 @@ require (
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.1
github.com/microsoft/go-mssqldb v1.8.0
github.com/neo4j/neo4j-go-driver/v5 v5.26.0
github.com/spf13/cobra v1.8.1
go.opentelemetry.io/contrib/propagators/autoprop v0.58.0
@@ -67,7 +68,6 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/microsoft/go-mssqldb v1.8.0 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect

View File

@@ -25,6 +25,7 @@ import (
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
@@ -188,6 +189,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case mssqlsrc.SourceKind:
actual := mssqlsrc.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}

View File

@@ -0,0 +1,102 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mssql
import (
"context"
"database/sql"
"fmt"
"github.com/googleapis/genai-toolbox/internal/sources"
_ "github.com/microsoft/go-mssqldb"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "mssql"
// validate interface
var _ sources.SourceConfig = Config{}
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Host string `yaml:"host"`
Port string `yaml:"port"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
}
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
// Initializes a MSSQL source
db, err := initMssqlConnection(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create db connection: %w", err)
}
// Verify db connection
err = db.PingContext(context.Background())
if err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
Db: db,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
// Cloud SQL MSSQL struct with connection pool
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Db *sql.DB
}
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (s *Source) MSSQLDB() *sql.DB {
// Returns a Cloud SQL MSSQL database connection pool
return s.Db
}
func initMssqlConnection(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
// Open database connection
db, err := sql.Open("sqlserver", dsn)
if err != nil {
return nil, fmt.Errorf("sql.Open: %w", err)
}
return db, nil
}

View File

@@ -0,0 +1,70 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mssql_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
"gopkg.in/yaml.v3"
)
func TestParseFromYamlMssql(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-mssql-instance:
kind: mssql
host: 0.0.0.0
port: my-port
database: my_db
`,
want: server.SourceConfigs{
"my-mssql-instance": mssql.Config{
Name: "my-mssql-instance",
Kind: mssql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
}
})
}
}

View File

@@ -15,12 +15,14 @@
package mssqlsql
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
"github.com/googleapis/genai-toolbox/internal/tools"
)
@@ -32,8 +34,9 @@ type compatibleSource interface {
// validate compatible sources are still compatible
var _ compatibleSource = &cloudsqlmssql.Source{}
var _ compatibleSource = &mssql.Source{}
var compatibleSources = [...]string{cloudsqlmssql.SourceKind}
var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind}
type Config struct {
Name string `yaml:"name"`
@@ -118,7 +121,7 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
namedArgs = append(namedArgs, v)
}
}
rows, err := t.Db.Query(t.Statement, namedArgs...)
rows, err := t.Db.QueryContext(context.Background(), t.Statement, namedArgs...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
}