Files
Fabric/internal/plugins/ai/bedrock/bedrock.go

292 lines
9.2 KiB
Go

// Package bedrock provides a plugin to use Amazon Bedrock models.
// Available models are discovered at runtime: ListModels returns the foundation models and
// inference profiles reported by Bedrock's ListFoundationModels and ListInferenceProfiles calls.
// Models must support the Converse and ConverseStream operations.
// Authentication uses the AWS credential provider chain, similar to the AWS CLI and SDKs:
// https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html
package bedrock
import (
	"context"
	"fmt"
	"regexp"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/middleware"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrock"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"

	"github.com/danielmiessler/fabric/internal/chat"
	"github.com/danielmiessler/fabric/internal/domain"
	"github.com/danielmiessler/fabric/internal/plugins"
	"github.com/danielmiessler/fabric/internal/plugins/ai"
)
// userAgentKey is the key of the extra user-agent key/value pair that is
// appended to the SDK API options of every AWS config built in this file.
const userAgentKey = "aiosc"

// userAgentValue is the value paired with userAgentKey; it identifies Fabric
// as the caller.
const userAgentValue = "fabric"
// Ensure BedrockClient implements the ai.Vendor interface.
// This is a compile-time check only: the build fails on this line if
// *BedrockClient is missing a method of ai.Vendor or has a mismatched signature.
var _ ai.Vendor = (*BedrockClient)(nil)
// BedrockClient is a plugin to add support for Amazon Bedrock.
// It implements the plugins.Plugin interface and provides methods
// for interacting with AWS Bedrock's Converse and ConverseStream APIs.
type BedrockClient struct {
	// PluginBase carries the plugin name, env-variable prefix, setup questions
	// and the custom configure hook.
	*plugins.PluginBase
	// runtimeClient performs inference calls (Converse, ConverseStream).
	// It is nil if NewClient could not load the AWS configuration.
	runtimeClient *bedrockruntime.Client
	// controlPlaneClient lists foundation models and inference profiles.
	// It is nil if NewClient could not load the AWS configuration.
	controlPlaneClient *bedrock.Client
	// bedrockRegion is the "AWS Region" setup question. A non-empty value makes
	// configure rebuild both clients for that region.
	bedrockRegion *plugins.SetupQuestion
}
// NewClient returns a new Bedrock plugin client.
//
// It loads the default AWS configuration with config.LoadDefaultConfig.
// If loading fails, the returned client has no AWS clients, and its configure
// hook always reports the load error.
// Otherwise it creates the runtime and control-plane clients, tags their
// requests with the Fabric user-agent pair, and pre-fills the "AWS Region"
// setup question with the loaded region (when one is set).
func NewClient() (ret *BedrockClient) {
	const vendorName = "Bedrock"
	client := &BedrockClient{}

	awsCfg, loadErr := config.LoadDefaultConfig(context.Background())
	if loadErr != nil {
		// Keep a minimal plugin shell so the failure surfaces during setup.
		client.PluginBase = &plugins.PluginBase{
			Name:          vendorName,
			EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
			ConfigureCustom: func() error {
				return fmt.Errorf("unable to load AWS Config: %w", loadErr)
			},
		}
		client.bedrockRegion = client.PluginBase.AddSetupQuestion("AWS Region", true)
		return client
	}

	awsCfg.APIOptions = append(awsCfg.APIOptions, middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))

	client.PluginBase = &plugins.PluginBase{
		Name:            vendorName,
		EnvNamePrefix:   plugins.BuildEnvVariablePrefix(vendorName),
		ConfigureCustom: client.configure,
	}
	client.runtimeClient = bedrockruntime.NewFromConfig(awsCfg)
	client.controlPlaneClient = bedrock.NewFromConfig(awsCfg)
	client.bedrockRegion = client.PluginBase.AddSetupQuestion("AWS Region", true)

	if region := awsCfg.Region; region != "" {
		client.bedrockRegion.Value = region
	}
	return client
}
// awsRegionPattern matches AWS region identifiers such as "us-east-1",
// "eu-west-2", "ap-southeast-2", "us-gov-west-1" and "cn-north-1".
// The format is a lowercase two-letter prefix, one or more lowercase
// hyphen-separated words, and a trailing numeric suffix.
var awsRegionPattern = regexp.MustCompile(`^[a-z]{2}(-[a-z]+)+-[0-9]+$`)

// isValidAWSRegion reports whether region is a well-formed AWS region
// identifier (e.g. us-east-1, eu-west-1, ap-southeast-2).
// It checks the format only. It does not verify that the region exists or that
// Bedrock is offered there.
func isValidAWSRegion(region string) bool {
	// Cheap length bound first; real region names fall well inside it.
	if len(region) < 5 || len(region) > 30 {
		return false
	}
	// Previously this step was `region != ""`, which was always true here and
	// let arbitrary strings through; enforce the actual region shape.
	return awsRegionPattern.MatchString(region)
}
// configure rebuilds the Bedrock runtime and control-plane clients for the
// region entered in the "AWS Region" setup question.
//
// An empty region is a no-op: the clients built by NewClient, and therefore
// the region from the default AWS configuration, stay in effect.
// It returns an error if the region fails isValidAWSRegion or if the AWS
// configuration cannot be loaded for that region.
func (c *BedrockClient) configure() error {
	region := c.bedrockRegion.Value
	switch {
	case region == "":
		// Nothing to override; keep the default-region clients.
		return nil
	case !isValidAWSRegion(region):
		return fmt.Errorf("invalid AWS region: %s", region)
	}

	awsCfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
	if err != nil {
		return fmt.Errorf("unable to load AWS Config with region %s: %w", region, err)
	}
	awsCfg.APIOptions = append(awsCfg.APIOptions,
		middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))

	c.runtimeClient, c.controlPlaneClient = bedrockruntime.NewFromConfig(awsCfg), bedrock.NewFromConfig(awsCfg)
	return nil
}
// ListModels retrieves all available foundation models and inference profiles
// from AWS Bedrock that can be used with this plugin.
//
// Foundation model IDs come first, followed by the IDs of all inference
// profiles (every page of ListInferenceProfiles is consumed). Entries without
// an ID are skipped.
// It returns an error if the control-plane client was never created or if
// either AWS call fails.
func (c *BedrockClient) ListModels() ([]string, error) {
	if c.controlPlaneClient == nil {
		// NewClient leaves the AWS clients nil when the AWS config failed to load;
		// report that instead of panicking on a nil client.
		return nil, fmt.Errorf("bedrock client is not initialized: AWS configuration could not be loaded")
	}
	ctx := context.Background()

	foundationModels, err := c.controlPlaneClient.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{})
	if err != nil {
		return nil, fmt.Errorf("failed to list foundation models: %w", err)
	}
	models := make([]string, 0, len(foundationModels.ModelSummaries))
	for _, model := range foundationModels.ModelSummaries {
		// ModelId is a *string in the SDK; aws.ToString avoids a nil dereference.
		if id := aws.ToString(model.ModelId); id != "" {
			models = append(models, id)
		}
	}

	inferenceProfilesPaginator := bedrock.NewListInferenceProfilesPaginator(c.controlPlaneClient, &bedrock.ListInferenceProfilesInput{})
	for inferenceProfilesPaginator.HasMorePages() {
		inferenceProfiles, err := inferenceProfilesPaginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list inference profiles: %w", err)
		}
		for _, profile := range inferenceProfiles.InferenceProfileSummaries {
			if id := aws.ToString(profile.InferenceProfileId); id != "" {
				models = append(models, id)
			}
		}
	}
	return models, nil
}
// SendStream sends the messages to the Bedrock ConverseStream API and forwards
// the response to channel as it arrives.
//
// The following updates are sent to channel:
//   - each text delta as a domain.StreamTypeContent update;
//   - a single "\n" content update when the model signals the end of the message;
//   - token counts from the trailing metadata event as a domain.StreamTypeUsage update.
//
// channel is always closed before SendStream returns, including on error or
// panic. Errors reported by the event stream itself are returned to the caller.
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
	// Ensure channel is closed on all exit paths so receivers never block forever.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic in SendStream: %v", r)
		}
		close(channel)
	}()

	converseInput := bedrockruntime.ConverseStreamInput{
		ModelId:  aws.String(opts.Model),
		Messages: c.toMessages(msgs),
		InferenceConfig: &types.InferenceConfiguration{
			Temperature: aws.Float32(float32(opts.Temperature)),
			TopP:        aws.Float32(float32(opts.TopP)),
		},
	}

	response, err := c.runtimeClient.ConverseStream(context.Background(), &converseInput)
	if err != nil {
		return fmt.Errorf("bedrock conversestream failed for model %s: %w", opts.Model, err)
	}

	stream := response.GetStream()
	// Release the underlying event-stream reader on every exit path; previously
	// it was never closed. The Close error is not actionable here.
	defer func() { _ = stream.Close() }()

	for event := range stream.Events() {
		// Possible ConverseStream event types
		// https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html#conversation-inference-call-response-converse-stream
		switch v := event.(type) {
		case *types.ConverseStreamOutputMemberContentBlockDelta:
			if text, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText); ok {
				channel <- domain.StreamUpdate{
					Type:    domain.StreamTypeContent,
					Content: text.Value,
				}
			}
		case *types.ConverseStreamOutputMemberMessageStop:
			channel <- domain.StreamUpdate{
				Type:    domain.StreamTypeContent,
				Content: "\n",
			}
			// Keep reading: the metadata event carrying token usage is delivered
			// after messageStop, so returning here would drop it.
		case *types.ConverseStreamOutputMemberMetadata:
			if usage := v.Value.Usage; usage != nil {
				// Token counts are *int32 in the SDK; aws.ToInt32 maps nil to 0.
				channel <- domain.StreamUpdate{
					Type: domain.StreamTypeUsage,
					Usage: &domain.UsageMetadata{
						InputTokens:  int(aws.ToInt32(usage.InputTokens)),
						OutputTokens: int(aws.ToInt32(usage.OutputTokens)),
						TotalTokens:  int(aws.ToInt32(usage.TotalTokens)),
					},
				}
			}
		// Unused Events
		case *types.ConverseStreamOutputMemberMessageStart,
			*types.ConverseStreamOutputMemberContentBlockStart,
			*types.ConverseStreamOutputMemberContentBlockStop:
		default:
			return fmt.Errorf("unknown stream event type: %T", v)
		}
	}

	// Events() is closed both at normal end-of-stream and on failure; Err tells
	// them apart. Previously such failures were silently reported as success.
	if streamErr := stream.Err(); streamErr != nil {
		return fmt.Errorf("bedrock conversestream failed for model %s: %w", opts.Model, streamErr)
	}
	return nil
}
// Send sends the messages to the Bedrock Converse API and returns the text of
// the model's reply.
//
// All text content blocks in the reply are concatenated in order. Non-text
// blocks (e.g. reasoning or tool-use blocks) are skipped, so a reply whose
// first block is not text no longer fails.
// It returns an error in any of these cases:
//   - the call fails;
//   - the output is not a message;
//   - the message has no content;
//   - the message contains no text block.
func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	// NOTE(review): unlike SendStream, no InferenceConfig (temperature/top-p)
	// is sent here, so those options are ignored; confirm this is intentional.
	converseInput := bedrockruntime.ConverseInput{
		ModelId:  aws.String(opts.Model),
		Messages: c.toMessages(msgs),
	}

	response, err := c.runtimeClient.Converse(ctx, &converseInput)
	if err != nil {
		return "", fmt.Errorf("bedrock converse failed for model %s: %w", opts.Model, err)
	}

	responseMessage, ok := response.Output.(*types.ConverseOutputMemberMessage)
	if !ok {
		return "", fmt.Errorf("unexpected response type: %T", response.Output)
	}
	content := responseMessage.Value.Content
	if len(content) == 0 {
		return "", fmt.Errorf("empty response content")
	}

	// Replies usually hold only a few blocks, so simple concatenation is fine.
	var text string
	foundText := false
	for _, block := range content {
		if textBlock, isText := block.(*types.ContentBlockMemberText); isText {
			text += textBlock.Value
			foundText = true
		}
	}
	if !foundText {
		return "", fmt.Errorf("unexpected content block type: %T", content[0])
	}
	return text, nil
}
// NeedsRawMode reports whether modelName has to be processed in raw mode.
// No Bedrock model does, so the result is false for every modelName.
func (c *BedrockClient) NeedsRawMode(modelName string) bool {
	const bedrockNeedsRawMode = false
	return bedrockNeedsRawMode
}
// toMessages converts the array of input messages from the ChatCompletionMessageType to the
// Bedrock Converse Message type.
// The system role messages are mapped to the user role as they contain a mix of system messages,
// pattern content and user input.
func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) {
for _, msg := range inputMessages {
roles := map[string]types.ConversationRole{
chat.ChatMessageRoleUser: types.ConversationRoleUser,
chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
chat.ChatMessageRoleSystem: types.ConversationRoleUser,
}
role, ok := roles[msg.Role]
if !ok {
continue
}
message := types.Message{
Role: role,
Content: []types.ContentBlock{&types.ContentBlockMemberText{Value: msg.Content}},
}
messages = append(messages, message)
}
return
}