diff --git a/internal/plugins/ai/azure/azure.go b/internal/plugins/ai/azure/azure.go index 6f06ce2f..d45ac5c6 100644 --- a/internal/plugins/ai/azure/azure.go +++ b/internal/plugins/ai/azure/azure.go @@ -35,7 +35,7 @@ type Client struct { apiDeployments []string } -const defaultAPIVersion = "2024-05-01-preview" +const defaultAPIVersion = "2025-04-01-preview" func (oi *Client) configure() error { oi.apiDeployments = parseDeployments(oi.ApiDeployments.Value) @@ -86,6 +86,7 @@ func azureDeploymentMiddleware(req *http.Request, next option.MiddlewareNext) (* "/audio/transcriptions": true, "/audio/translations": true, "/images/generations": true, + "/responses": true, } path := req.URL.Path @@ -156,6 +157,3 @@ func (oi *Client) ListModels() (ret []string, err error) { return } -func (oi *Client) NeedsRawMode(modelName string) bool { - return false -} diff --git a/internal/plugins/ai/azure/azure_test.go b/internal/plugins/ai/azure/azure_test.go index 267b11a4..15e61567 100644 --- a/internal/plugins/ai/azure/azure_test.go +++ b/internal/plugins/ai/azure/azure_test.go @@ -1,6 +1,9 @@ package azure import ( + "bytes" + "io" + "net/http" "testing" ) @@ -27,7 +30,7 @@ func TestClientConfigure(t *testing.T) { client.ApiDeployments.Value = "deployment1,deployment2" client.ApiKey.Value = "test-api-key" client.ApiBaseURL.Value = "https://example.com" - client.ApiVersion.Value = "2024-05-01-preview" + client.ApiVersion.Value = "2025-04-01-preview" err := client.configure() if err != nil { @@ -48,8 +51,8 @@ func TestClientConfigure(t *testing.T) { t.Errorf("Expected ApiClient to be initialized, got nil") } - if client.ApiVersion.Value != "2024-05-01-preview" { - t.Errorf("Expected API version to be '2024-05-01-preview', got %s", client.ApiVersion.Value) + if client.ApiVersion.Value != "2025-04-01-preview" { + t.Errorf("Expected API version to be '2025-04-01-preview', got %s", client.ApiVersion.Value) } } @@ -88,3 +91,81 @@ func TestListModels(t *testing.T) { } } } + +func TestNeedsRawModeInheritsFromParent(t *testing.T) { + client := NewClient() + + tests := []struct { + name string + model string + expected bool + }{ + {"o1 model", "o1", true}, + {"o1-preview", "o1-preview", true}, + {"o3-mini", "o3-mini", true}, + {"o4-mini", "o4-mini", true}, + {"gpt-5", "gpt-5", true}, + {"gpt-5-turbo", "gpt-5-turbo", true}, + {"gpt-4o", "gpt-4o", false}, + {"gpt-4", "gpt-4", false}, + {"regular deployment", "my-deployment", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := client.NeedsRawMode(tt.model) + if result != tt.expected { + t.Errorf("NeedsRawMode(%q) = %v, want %v", tt.model, result, tt.expected) + } + }) + } +} + +func TestMiddlewareResponsesRoute(t *testing.T) { + // Verify /responses is in the deployment routes by testing the middleware + body := `{"model": "gpt-5"}` + req, err := http.NewRequest("POST", "https://example.com/openai/responses", io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + var capturedPath string + mockNext := func(req *http.Request) (*http.Response, error) { + capturedPath = req.URL.Path + return &http.Response{StatusCode: 200}, nil + } + + _, err = azureDeploymentMiddleware(req, mockNext) + if err != nil { + t.Fatalf("Middleware returned error: %v", err) + } + + expected := "/openai/deployments/gpt-5/responses" + if capturedPath != expected { + t.Errorf("Expected path %q, got %q", expected, capturedPath) + } +} + +func TestMiddlewareChatCompletionsRoute(t *testing.T) { + body := `{"model": "gpt-4o"}` + req, err := http.NewRequest("POST", "https://example.com/openai/chat/completions", io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + var capturedPath string + mockNext := func(req *http.Request) (*http.Response, error) { + capturedPath = req.URL.Path + return &http.Response{StatusCode: 200}, nil + } + + _, err = azureDeploymentMiddleware(req, mockNext) + if err != nil { + t.Fatalf("Middleware returned error: %v", err) + } + + expected := "/openai/deployments/gpt-4o/chat/completions" + if capturedPath != expected { + t.Errorf("Expected path %q, got %q", expected, capturedPath) + } +}