mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-08 22:08:03 -05:00
292 lines
9.2 KiB
Go
292 lines
9.2 KiB
Go
// Package bedrock provides a plugin to use Amazon Bedrock models.
|
|
// Available models are discovered at runtime by ListModels (foundation models and inference profiles).
// Models must support the Converse and ConverseStream operations.
|
|
// Authentication uses the AWS credential provider chain, similar to the AWS CLI and SDKs
|
|
// https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html
|
|
package bedrock
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/danielmiessler/fabric/internal/domain"
|
|
"github.com/danielmiessler/fabric/internal/plugins"
|
|
"github.com/danielmiessler/fabric/internal/plugins/ai"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/aws/middleware"
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrock"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
|
|
|
"github.com/danielmiessler/fabric/internal/chat"
|
|
)
|
|
|
|
// User-Agent attribution appended to the AWS SDK configuration through
// middleware.AddUserAgentKeyValue wherever this plugin builds its clients.
const (
	// userAgentKey is the key half of the User-Agent key/value pair.
	userAgentKey = "aiosc"
	// userAgentValue is the value half of the User-Agent key/value pair.
	userAgentValue = "fabric"
)
|
|
|
|
// Ensure BedrockClient implements the ai.Vendor interface
|
|
var _ ai.Vendor = (*BedrockClient)(nil)
|
|
|
|
// BedrockClient is a plugin to add support for Amazon Bedrock.
|
|
// It implements the plugins.Plugin interface and provides methods
|
|
// for interacting with AWS Bedrock's Converse and ConverseStream APIs.
|
|
type BedrockClient struct {
|
|
*plugins.PluginBase
|
|
runtimeClient *bedrockruntime.Client
|
|
controlPlaneClient *bedrock.Client
|
|
|
|
bedrockRegion *plugins.SetupQuestion
|
|
}
|
|
|
|
// NewClient returns a new Bedrock plugin client
|
|
func NewClient() (ret *BedrockClient) {
|
|
vendorName := "Bedrock"
|
|
ret = &BedrockClient{}
|
|
|
|
ctx := context.Background()
|
|
cfg, err := config.LoadDefaultConfig(ctx)
|
|
if err != nil {
|
|
// Create a minimal client that will fail gracefully during configuration
|
|
ret.PluginBase = &plugins.PluginBase{
|
|
Name: vendorName,
|
|
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
|
|
ConfigureCustom: func() error {
|
|
return fmt.Errorf("unable to load AWS Config: %w", err)
|
|
},
|
|
}
|
|
ret.bedrockRegion = ret.PluginBase.AddSetupQuestion("AWS Region", true)
|
|
return
|
|
}
|
|
|
|
cfg.APIOptions = append(cfg.APIOptions, middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))
|
|
|
|
runtimeClient := bedrockruntime.NewFromConfig(cfg)
|
|
controlPlaneClient := bedrock.NewFromConfig(cfg)
|
|
|
|
ret.PluginBase = &plugins.PluginBase{
|
|
Name: vendorName,
|
|
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
|
|
ConfigureCustom: ret.configure,
|
|
}
|
|
|
|
ret.runtimeClient = runtimeClient
|
|
ret.controlPlaneClient = controlPlaneClient
|
|
|
|
ret.bedrockRegion = ret.PluginBase.AddSetupQuestion("AWS Region", true)
|
|
|
|
if cfg.Region != "" {
|
|
ret.bedrockRegion.Value = cfg.Region
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// isValidAWSRegion reports whether region is shaped like an AWS region code.
//
// A region is accepted when all of the following hold:
//   - its length is between 5 and 30 bytes;
//   - it consists of at least three non-empty segments separated by single
//     hyphens (e.g. "us-east-1", "us-gov-west-1");
//   - every byte is a lowercase ASCII letter, an ASCII digit, or a hyphen;
//   - it starts with a letter and its final segment is purely numeric.
//
// This is a format check only; it does not verify that the region exists or
// that Bedrock is offered there.
func isValidAWSRegion(region string) bool {
	if len(region) < 5 || len(region) > 30 {
		return false
	}

	segments := 1        // number of hyphen-separated segments seen so far
	segLen := 0          // length of the segment currently being scanned
	segAllDigits := true // whether the current segment contains only digits

	for i := 0; i < len(region); i++ {
		switch c := region[i]; {
		case c == '-':
			// Rejects a leading hyphen and consecutive hyphens.
			if segLen == 0 {
				return false
			}
			segments++
			segLen = 0
			segAllDigits = true
		case c >= '0' && c <= '9':
			// A region code never starts with a digit.
			if i == 0 {
				return false
			}
			segLen++
		case c >= 'a' && c <= 'z':
			segLen++
			segAllDigits = false
		default:
			return false
		}
	}

	// segLen > 0 rejects a trailing hyphen; the last segment must be the
	// numeric suffix (so availability zones such as "us-east-1a" are rejected).
	return segments >= 3 && segLen > 0 && segAllDigits
}
|
|
|
|
// configure initializes the Bedrock clients with the specified AWS region.
|
|
// If no region is specified, the default region from AWS config is used.
|
|
func (c *BedrockClient) configure() error {
|
|
if c.bedrockRegion.Value == "" {
|
|
return nil // Use default region from AWS config
|
|
}
|
|
|
|
// Validate region format
|
|
if !isValidAWSRegion(c.bedrockRegion.Value) {
|
|
return fmt.Errorf("invalid AWS region: %s", c.bedrockRegion.Value)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(c.bedrockRegion.Value))
|
|
if err != nil {
|
|
return fmt.Errorf("unable to load AWS Config with region %s: %w", c.bedrockRegion.Value, err)
|
|
}
|
|
|
|
cfg.APIOptions = append(cfg.APIOptions, middleware.AddUserAgentKeyValue(userAgentKey, userAgentValue))
|
|
|
|
c.runtimeClient = bedrockruntime.NewFromConfig(cfg)
|
|
c.controlPlaneClient = bedrock.NewFromConfig(cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListModels retrieves all available foundation models and inference profiles
|
|
// from AWS Bedrock that can be used with this plugin.
|
|
func (c *BedrockClient) ListModels() ([]string, error) {
|
|
models := []string{}
|
|
ctx := context.Background()
|
|
|
|
foundationModels, err := c.controlPlaneClient.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list foundation models: %w", err)
|
|
}
|
|
|
|
for _, model := range foundationModels.ModelSummaries {
|
|
models = append(models, *model.ModelId)
|
|
}
|
|
|
|
inferenceProfilesPaginator := bedrock.NewListInferenceProfilesPaginator(c.controlPlaneClient, &bedrock.ListInferenceProfilesInput{})
|
|
|
|
for inferenceProfilesPaginator.HasMorePages() {
|
|
inferenceProfiles, err := inferenceProfilesPaginator.NextPage(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list inference profiles: %w", err)
|
|
}
|
|
|
|
for _, profile := range inferenceProfiles.InferenceProfileSummaries {
|
|
models = append(models, *profile.InferenceProfileId)
|
|
}
|
|
}
|
|
|
|
return models, nil
|
|
}
|
|
|
|
// SendStream sends the messages to the Bedrock ConverseStream API
|
|
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
|
|
// Ensure channel is closed on all exit paths to prevent goroutine leaks
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = fmt.Errorf("panic in SendStream: %v", r)
|
|
}
|
|
close(channel)
|
|
}()
|
|
|
|
messages := c.toMessages(msgs)
|
|
|
|
var converseInput = bedrockruntime.ConverseStreamInput{
|
|
ModelId: aws.String(opts.Model),
|
|
Messages: messages,
|
|
InferenceConfig: &types.InferenceConfiguration{
|
|
Temperature: aws.Float32(float32(opts.Temperature)),
|
|
TopP: aws.Float32(float32(opts.TopP))},
|
|
}
|
|
|
|
response, err := c.runtimeClient.ConverseStream(context.Background(), &converseInput)
|
|
if err != nil {
|
|
return fmt.Errorf("bedrock conversestream failed for model %s: %w", opts.Model, err)
|
|
}
|
|
|
|
for event := range response.GetStream().Events() {
|
|
// Possible ConverseStream event types
|
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html#conversation-inference-call-response-converse-stream
|
|
switch v := event.(type) {
|
|
|
|
case *types.ConverseStreamOutputMemberContentBlockDelta:
|
|
text, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText)
|
|
if ok {
|
|
channel <- domain.StreamUpdate{
|
|
Type: domain.StreamTypeContent,
|
|
Content: text.Value,
|
|
}
|
|
}
|
|
|
|
case *types.ConverseStreamOutputMemberMessageStop:
|
|
channel <- domain.StreamUpdate{
|
|
Type: domain.StreamTypeContent,
|
|
Content: "\n",
|
|
}
|
|
return nil // Let defer handle the close
|
|
|
|
case *types.ConverseStreamOutputMemberMetadata:
|
|
if v.Value.Usage != nil {
|
|
channel <- domain.StreamUpdate{
|
|
Type: domain.StreamTypeUsage,
|
|
Usage: &domain.UsageMetadata{
|
|
InputTokens: int(*v.Value.Usage.InputTokens),
|
|
OutputTokens: int(*v.Value.Usage.OutputTokens),
|
|
TotalTokens: int(*v.Value.Usage.TotalTokens),
|
|
},
|
|
}
|
|
}
|
|
|
|
// Unused Events
|
|
case *types.ConverseStreamOutputMemberMessageStart,
|
|
*types.ConverseStreamOutputMemberContentBlockStart,
|
|
*types.ConverseStreamOutputMemberContentBlockStop:
|
|
|
|
default:
|
|
return fmt.Errorf("unknown stream event type: %T", v)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Send sends the messages the Bedrock Converse API
|
|
func (c *BedrockClient) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
|
|
|
|
messages := c.toMessages(msgs)
|
|
|
|
var converseInput = bedrockruntime.ConverseInput{
|
|
ModelId: aws.String(opts.Model),
|
|
Messages: messages,
|
|
}
|
|
response, err := c.runtimeClient.Converse(ctx, &converseInput)
|
|
if err != nil {
|
|
return "", fmt.Errorf("bedrock converse failed for model %s: %w", opts.Model, err)
|
|
}
|
|
|
|
responseText, ok := response.Output.(*types.ConverseOutputMemberMessage)
|
|
if !ok {
|
|
return "", fmt.Errorf("unexpected response type: %T", response.Output)
|
|
}
|
|
|
|
if len(responseText.Value.Content) == 0 {
|
|
return "", fmt.Errorf("empty response content")
|
|
}
|
|
|
|
responseContentBlock := responseText.Value.Content[0]
|
|
text, ok := responseContentBlock.(*types.ContentBlockMemberText)
|
|
if !ok {
|
|
return "", fmt.Errorf("unexpected content block type: %T", responseContentBlock)
|
|
}
|
|
|
|
return text.Value, nil
|
|
}
|
|
|
|
// NeedsRawMode indicates whether the model requires raw mode processing.
|
|
// Bedrock models do not require raw mode.
|
|
func (c *BedrockClient) NeedsRawMode(modelName string) bool {
|
|
return false
|
|
}
|
|
|
|
// toMessages converts the array of input messages from the ChatCompletionMessageType to the
|
|
// Bedrock Converse Message type.
|
|
// The system role messages are mapped to the user role as they contain a mix of system messages,
|
|
// pattern content and user input.
|
|
func (c *BedrockClient) toMessages(inputMessages []*chat.ChatCompletionMessage) (messages []types.Message) {
|
|
for _, msg := range inputMessages {
|
|
roles := map[string]types.ConversationRole{
|
|
chat.ChatMessageRoleUser: types.ConversationRoleUser,
|
|
chat.ChatMessageRoleAssistant: types.ConversationRoleAssistant,
|
|
chat.ChatMessageRoleSystem: types.ConversationRoleUser,
|
|
}
|
|
|
|
role, ok := roles[msg.Role]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
message := types.Message{
|
|
Role: role,
|
|
Content: []types.ContentBlock{&types.ContentBlockMemberText{Value: msg.Content}},
|
|
}
|
|
messages = append(messages, message)
|
|
|
|
}
|
|
|
|
return
|
|
}
|