diff --git a/internal/plugins/ai/openai_compatible/direct_models_call.go b/internal/plugins/ai/openai_compatible/direct_models_call.go index 8718d7e0..e4073792 100644 --- a/internal/plugins/ai/openai_compatible/direct_models_call.go +++ b/internal/plugins/ai/openai_compatible/direct_models_call.go @@ -15,10 +15,15 @@ type Model struct { ID string `json:"id"` } +const errorResponseLimit = 500 + // DirectlyGetModels is used to fetch models directly from the API // when the standard OpenAI SDK method fails due to a nonstandard format. // This is useful for providers like Together that return a direct array of models. -func (c *Client) DirectlyGetModels() ([]string, error) { +func (c *Client) DirectlyGetModels(ctx context.Context) ([]string, error) { + if ctx == nil { + ctx = context.Background() + } baseURL := c.ApiBaseURL.Value if baseURL == "" { return nil, fmt.Errorf("API base URL not configured") @@ -30,7 +35,7 @@ func (c *Client) DirectlyGetModels() ([]string, error) { return nil, fmt.Errorf("failed to create models URL: %w", err) } - req, err := http.NewRequestWithContext(context.Background(), "GET", fullURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) if err != nil { return nil, err } else { @@ -51,8 +56,8 @@ func (c *Client) DirectlyGetModels() ([]string, error) { // Read the response body for debugging bodyBytes, _ := io.ReadAll(resp.Body) bodyString := string(bodyBytes) - if len(bodyString) > 500 { // Truncate if too large - bodyString = bodyString[:500] + "..." + if len(bodyString) > errorResponseLimit { // Truncate if too large + bodyString = bodyString[:errorResponseLimit] + "..." } return nil, fmt.Errorf("unexpected status code: %d from provider %s, response body: %s", resp.StatusCode, c.GetName(), bodyString) diff --git a/internal/plugins/ai/openai_compatible/providers_config.go b/internal/plugins/ai/openai_compatible/providers_config.go index bb07a993..9511e01a 100644 --- a/internal/plugins/ai/openai_compatible/providers_config.go +++ b/internal/plugins/ai/openai_compatible/providers_config.go @@ -1,6 +1,7 @@ package openai_compatible import ( + "context" "os" "strings" @@ -39,7 +40,7 @@ func (c *Client) ListModels() ([]string, error) { return models, nil } - return c.DirectlyGetModels() + return c.DirectlyGetModels(context.Background()) } // ProviderMap is a map of provider name to ProviderConfig for O(1) lookup