mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
feat(sqlserver): add mssql source (#255)
Add `mssql` source for non-cloud sql server.
This commit is contained in:
2
go.mod
2
go.mod
@@ -18,6 +18,7 @@ require (
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.1
|
||||
github.com/microsoft/go-mssqldb v1.8.0
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.26.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.58.0
|
||||
@@ -67,7 +68,6 @@ require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.8.0 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
mssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
mysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
@@ -188,6 +189,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case mssqlsrc.SourceKind:
|
||||
actual := mssqlsrc.Config{Name: name}
|
||||
if err := u.Unmarshal(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
|
||||
102
internal/sources/mssql/mssql.go
Normal file
102
internal/sources/mssql/mssql.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
_ "github.com/microsoft/go-mssqldb"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Host string `yaml:"host"`
|
||||
Port string `yaml:"port"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a MSSQL source
|
||||
db, err := initMssqlConnection(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create db connection: %w", err)
|
||||
}
|
||||
|
||||
// Verify db connection
|
||||
err = db.PingContext(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Db: db,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// Cloud SQL MSSQL struct with connection pool
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) MSSQLDB() *sql.DB {
|
||||
// Returns a Cloud SQL MSSQL database connection pool
|
||||
return s.Db
|
||||
}
|
||||
|
||||
func initMssqlConnection(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s:%s?database=%s", user, pass, host, port, dbname)
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlserver", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
70
internal/sources/mssql/mssql_test.go
Normal file
70
internal/sources/mssql/mssql_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mssql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMssql(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-mssql-instance:
|
||||
kind: mssql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-mssql-instance": mssql.Config{
|
||||
Name: "my-mssql-instance",
|
||||
Kind: mssql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15,12 +15,14 @@
|
||||
package mssqlsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
@@ -32,8 +34,9 @@ type compatibleSource interface {
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &cloudsqlmssql.Source{}
|
||||
var _ compatibleSource = &mssql.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cloudsqlmssql.SourceKind}
|
||||
var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
@@ -118,7 +121,7 @@ func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
namedArgs = append(namedArgs, v)
|
||||
}
|
||||
}
|
||||
rows, err := t.Db.Query(t.Statement, namedArgs...)
|
||||
rows, err := t.Db.QueryContext(context.Background(), t.Statement, namedArgs...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user