feat: add context support to DirectlyGetModels method

## CHANGES

- Add context parameter to DirectlyGetModels method signature
- Add nil context check with Background fallback
- Extract magic number 500 into errorResponseLimit constant
- Update DirectlyGetModels call to pass context.Background
- Import context package in providers_config.go file
This commit is contained in:
Kayvan Sylvan
2025-07-11 12:43:31 -07:00
parent b187a80275
commit b34f249e24
2 changed files with 11 additions and 5 deletions

View File

@@ -15,10 +15,15 @@ type Model struct {
ID string `json:"id"`
}
const errorResponseLimit = 500
// DirectlyGetModels is used to fetch models directly from the API
// when the standard OpenAI SDK method fails due to a nonstandard format.
// This is useful for providers like Together that return a direct array of models.
func (c *Client) DirectlyGetModels() ([]string, error) {
func (c *Client) DirectlyGetModels(ctx context.Context) ([]string, error) {
if ctx == nil {
ctx = context.Background()
}
baseURL := c.ApiBaseURL.Value
if baseURL == "" {
return nil, fmt.Errorf("API base URL not configured")
@@ -30,7 +35,7 @@ func (c *Client) DirectlyGetModels() ([]string, error) {
return nil, fmt.Errorf("failed to create models URL: %w", err)
}
req, err := http.NewRequestWithContext(context.Background(), "GET", fullURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
if err != nil {
return nil, err
} else {
@@ -51,8 +56,8 @@ func (c *Client) DirectlyGetModels() ([]string, error) {
// Read the response body for debugging
bodyBytes, _ := io.ReadAll(resp.Body)
bodyString := string(bodyBytes)
if len(bodyString) > 500 { // Truncate if too large
bodyString = bodyString[:500] + "..."
if len(bodyString) > errorResponseLimit { // Truncate if too large
bodyString = bodyString[:errorResponseLimit] + "..."
}
return nil, fmt.Errorf("unexpected status code: %d from provider %s, response body: %s",
resp.StatusCode, c.GetName(), bodyString)

View File

@@ -1,6 +1,7 @@
package openai_compatible
import (
"context"
"os"
"strings"
@@ -39,7 +40,7 @@ func (c *Client) ListModels() ([]string, error) {
return models, nil
}
return c.DirectlyGetModels()
return c.DirectlyGetModels(context.Background())
}
// ProviderMap is a map of provider name to ProviderConfig for O(1) lookup