mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat: Add embeddingModel support (#2121)
First part of the implementation to support semantic search in tools. Second part: https://github.com/googleapis/genai-toolbox/pull/2151
This commit is contained in:
@@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
|
||||
|
||||
server := Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -38,6 +40,8 @@ type ServerConfig struct {
|
||||
SourceConfigs SourceConfigs
|
||||
// AuthServiceConfigs defines what sources of authentication are available for tools.
|
||||
AuthServiceConfigs AuthServiceConfigs
|
||||
// EmbeddingModelConfigs defines a models used to embed parameters.
|
||||
EmbeddingModelConfigs EmbeddingModelConfigs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs ToolConfigs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
@@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
|
||||
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
|
||||
|
||||
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(EmbeddingModelConfigs)
|
||||
// Parse the 'kind' fields for each embedding model
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelKind:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.ToolConfig
|
||||
|
||||
|
||||
@@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
|
||||
server := &Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -25,30 +26,33 @@ import (
|
||||
|
||||
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
|
||||
type ResourceManager struct {
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
embeddingModels map[string]embeddingmodels.EmbeddingModel
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
}
|
||||
|
||||
func NewResourceManager(
|
||||
sourcesMap map[string]sources.Source,
|
||||
authServicesMap map[string]auth.AuthService,
|
||||
embeddingModelsMap map[string]embeddingmodels.EmbeddingModel,
|
||||
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
|
||||
promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset,
|
||||
|
||||
) *ResourceManager {
|
||||
resourceMgr := &ResourceManager{
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
embeddingModels: embeddingModelsMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
}
|
||||
|
||||
return resourceMgr
|
||||
@@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi
|
||||
return authService, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
model, ok := r.embeddingModels[embeddingModelName]
|
||||
return model, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
@@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset,
|
||||
return promptset, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.sources = sourcesMap
|
||||
r.authServices = authServicesMap
|
||||
r.embeddingModels = embeddingModelsMap
|
||||
r.tools = toolsMap
|
||||
r.toolsets = toolsetsMap
|
||||
r.prompts = promptsMap
|
||||
@@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels))
|
||||
for k, v := range r.embeddingModels {
|
||||
copiedMap[k] = v
|
||||
}
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
|
||||
gotSource, _ := resMgr.GetSource("example-source")
|
||||
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
|
||||
@@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
gotSource, _ = resMgr.GetSource("example-source2")
|
||||
if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" {
|
||||
t.Errorf("error updating server, sources (-want +got):\n%s", diff)
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/go-chi/httplog/v2"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
@@ -56,6 +57,7 @@ type Server struct {
|
||||
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
map[string]sources.Source,
|
||||
map[string]auth.AuthService,
|
||||
map[string]embeddingmodels.EmbeddingModel,
|
||||
map[string]tools.Tool,
|
||||
map[string]tools.Toolset,
|
||||
map[string]prompts.Prompt,
|
||||
@@ -91,7 +93,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return s, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
@@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
@@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", ")))
|
||||
|
||||
// Initialize and validate embedding models from configs.
|
||||
embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel)
|
||||
for name, ec := range cfg.EmbeddingModelConfigs {
|
||||
em, err := func() (embeddingmodels.EmbeddingModel, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/embeddingmodel/init",
|
||||
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
|
||||
trace.WithAttributes(attribute.String("model_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
em, err := ec.Initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err)
|
||||
}
|
||||
return em, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
embeddingModelsMap[name] = em
|
||||
}
|
||||
embeddingModelNames := make([]string, 0, len(embeddingModelsMap))
|
||||
for name := range embeddingModelsMap {
|
||||
embeddingModelNames = append(embeddingModelNames, name)
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", ")))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
@@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
@@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
@@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsMap[name] = p
|
||||
}
|
||||
@@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsetsMap[name] = p
|
||||
}
|
||||
@@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", ")))
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
||||
r.Use(httplog.RequestLogger(httpLogger))
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize configs: %w", err)
|
||||
}
|
||||
@@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
@@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
if err != nil {
|
||||
t.Errorf("error updating server: %s", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user