From f38cc5ab7305f5f78fb566f7ec03d39b611d6722 Mon Sep 17 00:00:00 2001 From: Shelley Vohr Date: Mon, 10 Nov 2025 21:30:44 +0100 Subject: [PATCH] feat: support WebSocket authentication handling --- shell/browser/api/electron_api_web_request.cc | 63 +++++++++++++++- shell/browser/api/electron_api_web_request.h | 12 ++- shell/browser/login_handler.cc | 17 +++++ shell/browser/login_handler.h | 7 ++ shell/browser/net/proxying_websocket.cc | 19 +++-- shell/browser/net/proxying_websocket.h | 17 +---- shell/browser/net/web_request_api_interface.h | 22 ++++++ spec/api-web-request-spec.ts | 75 +++++++++++++++++++ 8 files changed, 207 insertions(+), 25 deletions(-) diff --git a/shell/browser/api/electron_api_web_request.cc b/shell/browser/api/electron_api_web_request.cc index b2b0a7f405..9be69c96ed 100644 --- a/shell/browser/api/electron_api_web_request.cc +++ b/shell/browser/api/electron_api_web_request.cc @@ -5,12 +5,14 @@ #include "shell/browser/api/electron_api_web_request.h" #include +#include #include #include #include #include "base/containers/fixed_flat_map.h" #include "base/memory/raw_ptr.h" +#include "base/strings/utf_string_conversions.h" #include "base/task/sequenced_task_runner.h" #include "base/values.h" #include "content/public/browser/web_contents.h" @@ -25,6 +27,7 @@ #include "shell/browser/api/electron_api_web_frame_main.h" #include "shell/browser/electron_browser_context.h" #include "shell/browser/javascript_environment.h" +#include "shell/browser/login_handler.h" #include "shell/common/gin_converters/callback_converter.h" #include "shell/common/gin_converters/frame_converter.h" #include "shell/common/gin_converters/gurl_converter.h" @@ -108,7 +111,7 @@ v8::Local HttpResponseHeadersToV8( // Overloaded by multiple types to fill the |details| object. void ToDictionary(gin_helper::Dictionary* details, - extensions::WebRequestInfo* info) { + const extensions::WebRequestInfo* info) { details->Set("id", info->id); details->Set("url", info->url); details->Set("method", info->method); @@ -255,7 +258,7 @@ bool WebRequest::RequestFilter::MatchesType( } bool WebRequest::RequestFilter::MatchesRequest( - extensions::WebRequestInfo* info) const { + const extensions::WebRequestInfo* info) const { // Matches URL and type, and does not match exclude URL. return MatchesURL(info->url, include_url_patterns_) && !MatchesURL(info->url, exclude_url_patterns_) && @@ -287,6 +290,10 @@ struct WebRequest::BlockedRequest { net::CompletionOnceCallback callback; // Only used for onBeforeSendHeaders. BeforeSendHeadersCallback before_send_headers_callback; + // The callback to invoke for auth. If |auth_callback.is_null()| is false, + // |callback| must be NULL. + // Only valid for OnAuthRequired. + AuthCallback auth_callback; // Only used for onBeforeSendHeaders. raw_ptr request_headers = nullptr; // Only used for onHeadersReceived. @@ -297,6 +304,8 @@ struct WebRequest::BlockedRequest { std::string status_line; // Only used for onBeforeRequest. raw_ptr new_url = nullptr; + // Owns the LoginHandler while waiting for auth credentials. + std::unique_ptr login_handler; }; WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_, @@ -603,6 +612,36 @@ void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info, HandleSimpleEvent(SimpleEvent::kOnSendHeaders, info, request, headers); } +WebRequest::AuthRequiredResponse WebRequest::OnAuthRequired( + const extensions::WebRequestInfo* request_info, + const net::AuthChallengeInfo& auth_info, + WebRequest::AuthCallback callback, + net::AuthCredentials* credentials) { + content::RenderFrameHost* rfh = content::RenderFrameHost::FromID( + request_info->render_process_id, request_info->frame_routing_id); + content::WebContents* web_contents = nullptr; + if (rfh) + web_contents = content::WebContents::FromRenderFrameHost(rfh); + + BlockedRequest blocked_request; + blocked_request.auth_callback = std::move(callback); + blocked_requests_[request_info->id] = std::move(blocked_request); + + auto login_callback = + base::BindOnce(&WebRequest::OnLoginAuthResult, base::Unretained(this), + request_info->id, credentials); + + scoped_refptr response_headers = + request_info->response_headers; + blocked_requests_[request_info->id].login_handler = + std::make_unique( + auth_info, web_contents, + static_cast(request_info->render_process_id), + request_info->url, response_headers, std::move(login_callback)); + + return AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING; +} + void WebRequest::OnBeforeRedirect(extensions::WebRequestInfo* info, const network::ResourceRequest& request, const GURL& new_location) { @@ -732,6 +771,26 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event, info.listener.Run(gin::ConvertToV8(isolate, details)); } +void WebRequest::OnLoginAuthResult( + uint64_t id, + net::AuthCredentials* credentials, + const std::optional& maybe_creds) { + auto iter = blocked_requests_.find(id); + if (iter == blocked_requests_.end()) + NOTREACHED(); + + AuthRequiredResponse action = + AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION; + if (maybe_creds.has_value()) { + *credentials = maybe_creds.value(); + action = AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH; + } + + base::SequencedTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, base::BindOnce(std::move(iter->second.auth_callback), action)); + blocked_requests_.erase(iter); +} + // static gin_helper::Handle WebRequest::FromOrCreate( v8::Isolate* isolate, diff --git a/shell/browser/api/electron_api_web_request.h b/shell/browser/api/electron_api_web_request.h index c2e97e6dc6..9b35b598f4 100644 --- a/shell/browser/api/electron_api_web_request.h +++ b/shell/browser/api/electron_api_web_request.h @@ -82,6 +82,10 @@ class WebRequest final : public gin_helper::DeprecatedWrappable, void OnSendHeaders(extensions::WebRequestInfo* info, const network::ResourceRequest& request, const net::HttpRequestHeaders& headers) override; + AuthRequiredResponse OnAuthRequired(const extensions::WebRequestInfo* info, + const net::AuthChallengeInfo& auth_info, + AuthCallback callback, + net::AuthCredentials* credentials) override; void OnBeforeRedirect(extensions::WebRequestInfo* info, const network::ResourceRequest& request, const GURL& new_location) override; @@ -157,6 +161,12 @@ class WebRequest final : public gin_helper::DeprecatedWrappable, v8::Local response); void OnHeadersReceivedListenerResult(uint64_t id, v8::Local response); + // Callback invoked by LoginHandler when auth credentials are supplied via + // the unified 'login' event. Bridges back into WebRequest's AuthCallback. + void OnLoginAuthResult( + uint64_t id, + net::AuthCredentials* credentials, + const std::optional& maybe_creds); class RequestFilter { public: @@ -174,7 +184,7 @@ class WebRequest final : public gin_helper::DeprecatedWrappable, bool is_match_pattern = true); void AddType(extensions::WebRequestResourceType type); - bool MatchesRequest(extensions::WebRequestInfo* info) const; + bool MatchesRequest(const extensions::WebRequestInfo* info) const; private: bool MatchesURL(const GURL& url, diff --git a/shell/browser/login_handler.cc b/shell/browser/login_handler.cc index 8813b08363..00ddcafc27 100644 --- a/shell/browser/login_handler.cc +++ b/shell/browser/login_handler.cc @@ -42,6 +42,23 @@ LoginHandler::LoginHandler( response_headers, first_auth_attempt)); } +LoginHandler::LoginHandler( + const net::AuthChallengeInfo& auth_info, + content::WebContents* web_contents, + base::ProcessId process_id, + const GURL& url, + scoped_refptr response_headers, + content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback) + : LoginHandler(auth_info, + web_contents, + /*is_request_for_primary_main_frame=*/false, + /*is_request_for_navigation=*/false, + process_id, + url, + std::move(response_headers), + /*first_auth_attempt=*/false, + std::move(auth_required_callback)) {} + void LoginHandler::EmitEvent( net::AuthChallengeInfo auth_info, content::WebContents* web_contents, diff --git a/shell/browser/login_handler.h b/shell/browser/login_handler.h index 144b97cc00..ef73b25933 100644 --- a/shell/browser/login_handler.h +++ b/shell/browser/login_handler.h @@ -32,6 +32,13 @@ class LoginHandler : public content::LoginDelegate { scoped_refptr response_headers, bool first_auth_attempt, content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback); + LoginHandler( + const net::AuthChallengeInfo& auth_info, + content::WebContents* web_contents, + base::ProcessId process_id, + const GURL& url, + scoped_refptr response_headers, + content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback); ~LoginHandler() override; // disable copy diff --git a/shell/browser/net/proxying_websocket.cc b/shell/browser/net/proxying_websocket.cc index 3ce673aee2..3b18d08f98 100644 --- a/shell/browser/net/proxying_websocket.cc +++ b/shell/browser/net/proxying_websocket.cc @@ -374,19 +374,21 @@ void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) { ContinueToCompleted(); } -void ProxyingWebSocket::OnAuthRequiredComplete(AuthRequiredResponse rv) { +void ProxyingWebSocket::OnAuthRequiredComplete( + WebRequestAPI::AuthRequiredResponse rv) { CHECK(auth_required_callback_); ResumeIncomingMethodCallProcessing(); switch (rv) { - case AuthRequiredResponse::kNoAction: - case AuthRequiredResponse::kCancelAuth: + case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION: + case WebRequestAPI::AuthRequiredResponse:: + AUTH_REQUIRED_RESPONSE_CANCEL_AUTH: std::move(auth_required_callback_).Run(std::nullopt); break; - case AuthRequiredResponse::kSetAuth: + case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH: std::move(auth_required_callback_).Run(auth_credentials_); break; - case AuthRequiredResponse::kIoPending: + case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING: NOTREACHED(); } } @@ -403,8 +405,13 @@ void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth( auto continuation = base::BindRepeating( &ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr()); - auto auth_rv = AuthRequiredResponse::kCancelAuth; + auto auth_rv = web_request_api_->OnAuthRequired( + &info_, auth_info, std::move(continuation), &auth_credentials_); PauseIncomingMethodCallProcessing(); + if (auth_rv == + WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING) { + return; + } OnAuthRequiredComplete(auth_rv); } diff --git a/shell/browser/net/proxying_websocket.h b/shell/browser/net/proxying_websocket.h index a369b9b3d4..b29cb49c8f 100644 --- a/shell/browser/net/proxying_websocket.h +++ b/shell/browser/net/proxying_websocket.h @@ -37,21 +37,6 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient, public: using WebSocketFactory = content::ContentBrowserClient::WebSocketFactory; - // AuthRequiredResponse indicates how an OnAuthRequired call is handled. - enum class AuthRequiredResponse { - // No credentials were provided. - kNoAction, - // AuthCredentials is filled in with a username and password, which should - // be used in a response to the provided auth challenge. - kSetAuth, - // The request should be canceled. - kCancelAuth, - // The action will be decided asynchronously. |callback| will be invoked - // when the decision is made, and one of the other AuthRequiredResponse - // values will be passed in with the same semantics as described above. - kIoPending, - }; - ProxyingWebSocket( WebRequestAPI* web_request_api, WebSocketFactory factory, @@ -121,7 +106,7 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient, void ContinueToStartRequest(int error_code); void OnHeadersReceivedComplete(int error_code); void ContinueToHeadersReceived(); - void OnAuthRequiredComplete(AuthRequiredResponse rv); + void OnAuthRequiredComplete(WebRequestAPI::AuthRequiredResponse rv); void OnHeadersReceivedCompleteForAuth(const net::AuthChallengeInfo& auth_info, int rv); void ContinueToCompleted(); diff --git a/shell/browser/net/web_request_api_interface.h b/shell/browser/net/web_request_api_interface.h index d93a6aa943..d2d3ffcb8b 100644 --- a/shell/browser/net/web_request_api_interface.h +++ b/shell/browser/net/web_request_api_interface.h @@ -27,6 +27,23 @@ class WebRequestAPI { const std::set& set_headers, int error_code)>; + // AuthRequiredResponse indicates how an OnAuthRequired call is handled. + enum class AuthRequiredResponse { + // No credentials were provided. + AUTH_REQUIRED_RESPONSE_NO_ACTION, + // AuthCredentials is filled in with a username and password, which should + // be used in a response to the provided auth challenge. + AUTH_REQUIRED_RESPONSE_SET_AUTH, + // The request should be canceled. + AUTH_REQUIRED_RESPONSE_CANCEL_AUTH, + // The action will be decided asynchronously. |callback| will be invoked + // when the decision is made, and one of the other AuthRequiredResponse + // values will be passed in with the same semantics as described above. + AUTH_REQUIRED_RESPONSE_IO_PENDING, + }; + + using AuthCallback = base::OnceCallback; + virtual bool HasListener() const = 0; virtual int OnBeforeRequest(extensions::WebRequestInfo* info, const network::ResourceRequest& request, @@ -36,6 +53,11 @@ class WebRequestAPI { const network::ResourceRequest& request, BeforeSendHeadersCallback callback, net::HttpRequestHeaders* headers) = 0; + virtual AuthRequiredResponse OnAuthRequired( + const extensions::WebRequestInfo* info, + const net::AuthChallengeInfo& auth_info, + AuthCallback callback, + net::AuthCredentials* credentials) = 0; virtual int OnHeadersReceived( extensions::WebRequestInfo* info, const network::ResourceRequest& request, diff --git a/spec/api-web-request-spec.ts b/spec/api-web-request-spec.ts index b2792cf9a6..9000b2d2b1 100644 --- a/spec/api-web-request-spec.ts +++ b/spec/api-web-request-spec.ts @@ -733,5 +733,80 @@ describe('webRequest module', () => { expect(reqHeaders['/websocket'].foo).to.equal('bar'); expect(reqHeaders['/'].foo).to.equal('bar'); }); + + it('authenticates a WebSocket via login event', async () => { + const authServer = http.createServer(); + const wssAuth = new WebSocket.Server({ noServer: true }); + const expected = 'Basic ' + Buffer.from('user:pass').toString('base64'); + + wssAuth.on('connection', ws => { + ws.send('Authenticated!'); + }); + + authServer.on('upgrade', (req, socket, head) => { + const auth = req.headers.authorization || ''; + if (auth !== expected) { + socket.write( + 'HTTP/1.1 401 Unauthorized\r\n' + + 'WWW-Authenticate: Basic realm="Test"\r\n' + + 'Content-Length: 0\r\n' + + '\r\n' + ); + socket.destroy(); + return; + } + + wssAuth.handleUpgrade(req, socket as Socket, head, ws => { + wssAuth.emit('connection', ws, req); + }); + }); + + const { port } = await listen(authServer); + const ses = session.fromPartition(`WebRequestWSAuth-${Date.now()}`); + + const contents = (webContents as typeof ElectronInternal.WebContents).create({ + session: ses, + sandbox: true + }); + + defer(() => { + contents.destroy(); + authServer.close(); + wssAuth.close(); + }); + + ses.webRequest.onBeforeRequest({ urls: ['ws://*/*'] }, (details, callback) => { + callback({}); + }); + + contents.on('login', (event, details: any, _: any, callback: (u: string, p: string) => void) => { + if (details?.url?.startsWith(`ws://localhost:${port}`)) { + event.preventDefault(); + callback('user', 'pass'); + } + }); + + await contents.loadFile(path.join(fixturesPath, 'blank.html')); + + const message = await contents.executeJavaScript(`new Promise((resolve, reject) => { + let attempts = 0; + function connect() { + attempts++; + const ws = new WebSocket('ws://localhost:${port}'); + ws.onmessage = e => resolve(e.data); + ws.onerror = () => { + if (attempts < 3) { + setTimeout(connect, 50); + } else { + reject(new Error('WebSocket auth failed')); + } + }; + } + connect(); + setTimeout(() => reject(new Error('timeout')), 5000); + });`); + + expect(message).to.equal('Authenticated!'); + }); }); });