diff --git a/internal/plugins/ai/anthropic/anthropic.go b/internal/plugins/ai/anthropic/anthropic.go index 3c582935..bf8ddd56 100644 --- a/internal/plugins/ai/anthropic/anthropic.go +++ b/internal/plugins/ai/anthropic/anthropic.go @@ -29,11 +29,7 @@ func NewClient() (ret *Client) { vendorName := "Anthropic" ret = &Client{} - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: ret.configure, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, ret.configure) ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false) ret.ApiBaseURL.Value = defaultBaseUrl diff --git a/internal/plugins/ai/bedrock/bedrock.go b/internal/plugins/ai/bedrock/bedrock.go index 5421b2f6..2d54c8b1 100644 --- a/internal/plugins/ai/bedrock/bedrock.go +++ b/internal/plugins/ai/bedrock/bedrock.go @@ -51,13 +51,9 @@ func NewClient() (ret *BedrockClient) { cfg, err := config.LoadDefaultConfig(ctx) if err != nil { // Create a minimal client that will fail gracefully during configuration - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: func() error { - return fmt.Errorf("unable to load AWS Config: %w", err) - }, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, func() error { + return fmt.Errorf("unable to load AWS Config: %w", err) + }) ret.bedrockRegion = ret.PluginBase.AddSetupQuestion("AWS Region", true) return } @@ -67,11 +63,7 @@ func NewClient() (ret *BedrockClient) { runtimeClient := bedrockruntime.NewFromConfig(cfg) controlPlaneClient := bedrock.NewFromConfig(cfg) - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: ret.configure, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, ret.configure) ret.runtimeClient = runtimeClient ret.controlPlaneClient = controlPlaneClient diff --git a/internal/plugins/ai/gemini/gemini.go b/internal/plugins/ai/gemini/gemini.go index 0fce3e72..172c7219 100644 --- a/internal/plugins/ai/gemini/gemini.go +++ b/internal/plugins/ai/gemini/gemini.go @@ -46,10 +46,7 @@ func NewClient() (ret *Client) { vendorName := "Gemini" ret = &Client{} - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, nil) ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true) diff --git a/internal/plugins/ai/lmstudio/lmstudio.go b/internal/plugins/ai/lmstudio/lmstudio.go index f9cae99f..d4d2fa89 100644 --- a/internal/plugins/ai/lmstudio/lmstudio.go +++ b/internal/plugins/ai/lmstudio/lmstudio.go @@ -27,11 +27,7 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust if configureCustom == nil { configureCustom = ret.configure } - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: configureCustom, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, configureCustom) ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true, fmt.Sprintf("Enter your %v URL (as a reminder, it is usually %v')", vendorName, defaultBaseUrl)) return diff --git a/internal/plugins/ai/ollama/ollama.go b/internal/plugins/ai/ollama/ollama.go index 03317dfe..ea7d797c 100644 --- a/internal/plugins/ai/ollama/ollama.go +++ b/internal/plugins/ai/ollama/ollama.go @@ -24,11 +24,7 @@ func NewClient() (ret *Client) { vendorName := "Ollama" ret = &Client{} - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: ret.configure, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, ret.configure) ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true, "Enter your Ollama URL (as a reminder, it is usually http://localhost:11434')") diff --git a/internal/plugins/ai/openai/openai.go b/internal/plugins/ai/openai/openai.go index e364c66c..2e9a7be2 100644 --- a/internal/plugins/ai/openai/openai.go +++ b/internal/plugins/ai/openai/openai.go @@ -52,11 +52,7 @@ func NewClientCompatibleNoSetupQuestions(vendorName string, configureCustom func configureCustom = ret.configure } - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: configureCustom, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, configureCustom) return } diff --git a/internal/plugins/ai/perplexity/perplexity.go b/internal/plugins/ai/perplexity/perplexity.go index 4ec5f2b1..6a8ab7d4 100644 --- a/internal/plugins/ai/perplexity/perplexity.go +++ b/internal/plugins/ai/perplexity/perplexity.go @@ -31,11 +31,7 @@ type Client struct { func NewClient() *Client { c := &Client{} - c.PluginBase = &plugins.PluginBase{ - Name: providerName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(providerName), - ConfigureCustom: c.Configure, // Assign the Configure method - } + c.PluginBase = plugins.NewVendorPluginBase(providerName, c.Configure) c.APIKey = c.AddSetupQuestion("API_KEY", true) return c } diff --git a/internal/plugins/ai/vertexai/vertexai.go b/internal/plugins/ai/vertexai/vertexai.go index 789c0dcf..93e0ee04 100644 --- a/internal/plugins/ai/vertexai/vertexai.go +++ b/internal/plugins/ai/vertexai/vertexai.go @@ -28,11 +28,7 @@ func NewClient() (ret *Client) { vendorName := "VertexAI" ret = &Client{} - ret.PluginBase = &plugins.PluginBase{ - Name: vendorName, - EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), - ConfigureCustom: ret.configure, - } + ret.PluginBase = plugins.NewVendorPluginBase(vendorName, ret.configure) ret.ProjectID = ret.AddSetupQuestion("Project ID", true) ret.Region = ret.AddSetupQuestion("Region", false) diff --git a/internal/plugins/plugin.go b/internal/plugins/plugin.go index d4fdd404..e97518c6 100644 --- a/internal/plugins/plugin.go +++ b/internal/plugins/plugin.go @@ -36,6 +36,16 @@ func (o *PluginBase) GetName() string { return o.Name } +// NewVendorPluginBase creates a standardized PluginBase for AI vendor plugins. +// This centralizes the common initialization pattern used by all vendors. +func NewVendorPluginBase(name string, configure func() error) *PluginBase { + return &PluginBase{ + Name: name, + EnvNamePrefix: BuildEnvVariablePrefix(name), + ConfigureCustom: configure, + } +} + func (o *PluginBase) GetSetupDescription() (ret string) { if ret = o.SetupDescription; ret == "" { ret = o.GetName() diff --git a/internal/plugins/plugin_test.go b/internal/plugins/plugin_test.go index 1abae83c..16ab20de 100644 --- a/internal/plugins/plugin_test.go +++ b/internal/plugins/plugin_test.go @@ -8,6 +8,43 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewVendorPluginBase(t *testing.T) { + // Test with configure function + configureCalled := false + configureFunc := func() error { + configureCalled = true + return nil + } + + plugin := NewVendorPluginBase("TestVendor", configureFunc) + + assert.Equal(t, "TestVendor", plugin.Name) + assert.Equal(t, "TESTVENDOR_", plugin.EnvNamePrefix) + assert.NotNil(t, plugin.ConfigureCustom) + + // Test that configure function is properly stored + err := plugin.ConfigureCustom() + assert.NoError(t, err) + assert.True(t, configureCalled) +} + +func TestNewVendorPluginBase_NilConfigure(t *testing.T) { + // Test with nil configure function + plugin := NewVendorPluginBase("TestVendor", nil) + + assert.Equal(t, "TestVendor", plugin.Name) + assert.Equal(t, "TESTVENDOR_", plugin.EnvNamePrefix) + assert.Nil(t, plugin.ConfigureCustom) +} + +func TestNewVendorPluginBase_EnvPrefixWithSpaces(t *testing.T) { + // Test that spaces are converted to underscores + plugin := NewVendorPluginBase("LM Studio", nil) + + assert.Equal(t, "LM Studio", plugin.Name) + assert.Equal(t, "LM_STUDIO_", plugin.EnvNamePrefix) +} + func TestConfigurable_AddSetting(t *testing.T) { conf := &PluginBase{ Settings: Settings{},