mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-05 12:45:11 -05:00
Currently, we return a 401 error immediately after an auth token fails verification. This is not the expected behavior in the following situations: 1. A non-auth tool is invoked with an auth token that is invalid. 2. An auth tool is invoked with all the required auth tokens, but the header also contains an extra, non-required token that is invalid. These requests should pass the authorization check, but they fail under the current implementation. Changes made in this PR: 1. Do not return an error immediately after an auth token verification failure; instead, only log it and continue to the next header iteration. 2. In the parseParams() method, if an auth parameter is missing, return an error telling the user that the auth header is either missing or invalid.
275 lines
8.4 KiB
Go
275 lines
8.4 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/go-chi/render"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
"go.opentelemetry.io/otel/metric"
|
|
)
|
|
|
|
// apiRouter creates a router that represents the routes under /api
|
|
func apiRouter(s *Server) (chi.Router, error) {
|
|
r := chi.NewRouter()
|
|
|
|
r.Use(middleware.AllowContentType("application/json"))
|
|
r.Use(middleware.StripSlashes)
|
|
r.Use(render.SetContentType(render.ContentTypeJSON))
|
|
|
|
r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
|
r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
|
|
|
r.Route("/tool/{toolName}", func(r chi.Router) {
|
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) { toolGetHandler(s, w, r) })
|
|
r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) })
|
|
})
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// toolsetHandler handles the request for information about a Toolset.
|
|
func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get")
|
|
r = r.WithContext(ctx)
|
|
|
|
toolsetName := chi.URLParam(r, "toolsetName")
|
|
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
|
var err error
|
|
defer func() {
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
span.End()
|
|
|
|
status := "success"
|
|
if err != nil {
|
|
status = "error"
|
|
}
|
|
s.instrumentation.ToolsetGet.Add(
|
|
r.Context(),
|
|
1,
|
|
metric.WithAttributes(attribute.String("toolbox.name", toolsetName)),
|
|
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
|
)
|
|
}()
|
|
|
|
toolset, ok := s.toolsets[toolsetName]
|
|
if !ok {
|
|
err = fmt.Errorf("Toolset %q does not exist", toolsetName)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
|
return
|
|
}
|
|
render.JSON(w, r, toolset.Manifest)
|
|
}
|
|
|
|
// toolGetHandler handles requests for a single Tool.
|
|
func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get")
|
|
r = r.WithContext(ctx)
|
|
|
|
toolName := chi.URLParam(r, "toolName")
|
|
span.SetAttributes(attribute.String("tool_name", toolName))
|
|
var err error
|
|
defer func() {
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
span.End()
|
|
|
|
status := "success"
|
|
if err != nil {
|
|
status = "error"
|
|
}
|
|
s.instrumentation.ToolGet.Add(
|
|
r.Context(),
|
|
1,
|
|
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
|
|
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
|
)
|
|
}()
|
|
tool, ok := s.tools[toolName]
|
|
if !ok {
|
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
|
return
|
|
}
|
|
// TODO: this can be optimized later with some caching
|
|
m := tools.ToolsetManifest{
|
|
ServerVersion: s.version,
|
|
ToolsManifest: map[string]tools.Manifest{
|
|
toolName: tool.Manifest(),
|
|
},
|
|
}
|
|
|
|
render.JSON(w, r, m)
|
|
}
|
|
|
|
// toolInvokeHandler handles the API request to invoke a specific Tool.
|
|
func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke")
|
|
r = r.WithContext(ctx)
|
|
|
|
toolName := chi.URLParam(r, "toolName")
|
|
span.SetAttributes(attribute.String("tool_name", toolName))
|
|
var err error
|
|
defer func() {
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
}
|
|
span.End()
|
|
|
|
status := "success"
|
|
if err != nil {
|
|
status = "error"
|
|
}
|
|
s.instrumentation.ToolInvoke.Add(
|
|
r.Context(),
|
|
1,
|
|
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
|
|
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
|
)
|
|
}()
|
|
|
|
tool, ok := s.tools[toolName]
|
|
if !ok {
|
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
|
return
|
|
}
|
|
|
|
// Tool authentication
|
|
// claimsFromAuth maps the name of the authsource to the claims retrieved from it.
|
|
claimsFromAuth := make(map[string]map[string]any)
|
|
for _, aS := range s.authSources {
|
|
claims, err := aS.GetClaimsFromHeader(r.Header)
|
|
if err != nil {
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
continue
|
|
}
|
|
if claims == nil {
|
|
// authSource not present in header
|
|
continue
|
|
}
|
|
claimsFromAuth[aS.GetName()] = claims
|
|
}
|
|
|
|
// Tool authorization check
|
|
verifiedAuthSources := make([]string, len(claimsFromAuth))
|
|
i := 0
|
|
for k := range claimsFromAuth {
|
|
verifiedAuthSources[i] = k
|
|
i++
|
|
}
|
|
|
|
// Check if any of the specified auth sources is verified
|
|
isAuthorized := tool.Authorized(verifiedAuthSources)
|
|
if !isAuthorized {
|
|
err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
|
return
|
|
}
|
|
|
|
var data map[string]any
|
|
if err = decodeJSON(r.Body, &data); err != nil {
|
|
render.Status(r, http.StatusBadRequest)
|
|
err = fmt.Errorf("request body was invalid JSON: %w", err)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
|
return
|
|
}
|
|
|
|
params, err := tool.ParseParams(data, claimsFromAuth)
|
|
if err != nil {
|
|
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
|
return
|
|
}
|
|
|
|
res, err := tool.Invoke(params)
|
|
if err != nil {
|
|
err = fmt.Errorf("error while invoking tool: %w", err)
|
|
s.logger.DebugContext(context.Background(), err.Error())
|
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
|
return
|
|
}
|
|
|
|
_ = render.Render(w, r, &resultResponse{Result: res})
|
|
}
|
|
|
|
// Compile-time check that *resultResponse satisfies render.Renderer.
var _ render.Renderer = (*resultResponse)(nil)
|
|
|
|
// resultResponse is the JSON payload sent back to the client after a tool
// has been invoked successfully.
type resultResponse struct {
	// Result is the result of the tool invocation.
	Result string `json:"result"`
}
|
|
|
|
// Render renders a single payload and respond to the client request.
|
|
func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
|
render.Status(r, http.StatusOK)
|
|
return nil
|
|
}
|
|
|
|
// Compile-time check that *errResponse satisfies render.Renderer.
var _ render.Renderer = (*errResponse)(nil)
|
|
|
|
// newErrResponse is a helper function initalizing an ErrResponse
|
|
func newErrResponse(err error, code int) *errResponse {
|
|
return &errResponse{
|
|
Err: err,
|
|
HTTPStatusCode: code,
|
|
|
|
StatusText: http.StatusText(code),
|
|
ErrorText: err.Error(),
|
|
}
|
|
}
|
|
|
|
// errResponse is the JSON payload sent back to the client when an error has
// been encountered. Only StatusText and ErrorText are serialized.
type errResponse struct {
	// Err is the low-level runtime error; excluded from the JSON output.
	Err error `json:"-"`
	// HTTPStatusCode is the HTTP response status code; excluded from the
	// JSON output.
	HTTPStatusCode int `json:"-"`

	// StatusText is the user-level status message.
	StatusText string `json:"status"`
	// ErrorText is the application-level error message, for debugging. It is
	// omitted from the JSON output when empty.
	ErrorText string `json:"error,omitempty"`
}
|
|
|
|
func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
|
render.Status(r, e.HTTPStatusCode)
|
|
return nil
|
|
}
|
|
|
|
// decodeJSON decodes the first JSON value read from r into v and returns the
// decoder's error, if any. v is passed straight to json.Decoder.Decode.
//
// Numbers that land in interface values are kept as json.Number instead of
// the default float64, which prevents loss between floats and ints.
//
// Whatever is left unread in r is drained and discarded before returning,
// whether or not decoding succeeded; an error from draining is ignored.
func decodeJSON(r io.Reader, v any) error {
	// Best-effort drain on every exit path; there is nothing useful to do
	// with a failure here, so the result is discarded explicitly.
	defer func() { _, _ = io.Copy(io.Discard, r) }()

	dec := json.NewDecoder(r)
	dec.UseNumber()
	return dec.Decode(v)
}
|