feat: support WebSocket authentication handling

This commit is contained in:
Shelley Vohr
2025-11-10 21:30:44 +01:00
committed by Charles Kerr
parent d2ae9ed69f
commit f38cc5ab73
8 changed files with 207 additions and 25 deletions

View File

@@ -5,12 +5,14 @@
#include "shell/browser/api/electron_api_web_request.h"
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include "base/containers/fixed_flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/values.h"
#include "content/public/browser/web_contents.h"
@@ -25,6 +27,7 @@
#include "shell/browser/api/electron_api_web_frame_main.h"
#include "shell/browser/electron_browser_context.h"
#include "shell/browser/javascript_environment.h"
#include "shell/browser/login_handler.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/frame_converter.h"
#include "shell/common/gin_converters/gurl_converter.h"
@@ -108,7 +111,7 @@ v8::Local<v8::Value> HttpResponseHeadersToV8(
// Overloaded by multiple types to fill the |details| object.
void ToDictionary(gin_helper::Dictionary* details,
extensions::WebRequestInfo* info) {
const extensions::WebRequestInfo* info) {
details->Set("id", info->id);
details->Set("url", info->url);
details->Set("method", info->method);
@@ -255,7 +258,7 @@ bool WebRequest::RequestFilter::MatchesType(
}
bool WebRequest::RequestFilter::MatchesRequest(
extensions::WebRequestInfo* info) const {
const extensions::WebRequestInfo* info) const {
// Matches URL and type, and does not match exclude URL.
return MatchesURL(info->url, include_url_patterns_) &&
!MatchesURL(info->url, exclude_url_patterns_) &&
@@ -287,6 +290,10 @@ struct WebRequest::BlockedRequest {
net::CompletionOnceCallback callback;
// Only used for onBeforeSendHeaders.
BeforeSendHeadersCallback before_send_headers_callback;
// The callback to invoke for auth. If |auth_callback.is_null()| is false,
// |callback| must be NULL.
// Only valid for OnAuthRequired.
AuthCallback auth_callback;
// Only used for onBeforeSendHeaders.
raw_ptr<net::HttpRequestHeaders> request_headers = nullptr;
// Only used for onHeadersReceived.
@@ -297,6 +304,8 @@ struct WebRequest::BlockedRequest {
std::string status_line;
// Only used for onBeforeRequest.
raw_ptr<GURL> new_url = nullptr;
// Owns the LoginHandler while waiting for auth credentials.
std::unique_ptr<LoginHandler> login_handler;
};
WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
@@ -603,6 +612,36 @@ void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info,
HandleSimpleEvent(SimpleEvent::kOnSendHeaders, info, request, headers);
}
WebRequest::AuthRequiredResponse WebRequest::OnAuthRequired(
const extensions::WebRequestInfo* request_info,
const net::AuthChallengeInfo& auth_info,
WebRequest::AuthCallback callback,
net::AuthCredentials* credentials) {
content::RenderFrameHost* rfh = content::RenderFrameHost::FromID(
request_info->render_process_id, request_info->frame_routing_id);
content::WebContents* web_contents = nullptr;
if (rfh)
web_contents = content::WebContents::FromRenderFrameHost(rfh);
BlockedRequest blocked_request;
blocked_request.auth_callback = std::move(callback);
blocked_requests_[request_info->id] = std::move(blocked_request);
auto login_callback =
base::BindOnce(&WebRequest::OnLoginAuthResult, base::Unretained(this),
request_info->id, credentials);
scoped_refptr<net::HttpResponseHeaders> response_headers =
request_info->response_headers;
blocked_requests_[request_info->id].login_handler =
std::make_unique<LoginHandler>(
auth_info, web_contents,
static_cast<base::ProcessId>(request_info->render_process_id),
request_info->url, response_headers, std::move(login_callback));
return AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING;
}
void WebRequest::OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location) {
@@ -732,6 +771,26 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
info.listener.Run(gin::ConvertToV8(isolate, details));
}
void WebRequest::OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds) {
auto iter = blocked_requests_.find(id);
if (iter == blocked_requests_.end())
NOTREACHED();
AuthRequiredResponse action =
AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION;
if (maybe_creds.has_value()) {
*credentials = maybe_creds.value();
action = AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH;
}
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(iter->second.auth_callback), action));
blocked_requests_.erase(iter);
}
// static
gin_helper::Handle<WebRequest> WebRequest::FromOrCreate(
v8::Isolate* isolate,

View File

@@ -82,6 +82,10 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest>,
void OnSendHeaders(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const net::HttpRequestHeaders& headers) override;
AuthRequiredResponse OnAuthRequired(const extensions::WebRequestInfo* info,
const net::AuthChallengeInfo& auth_info,
AuthCallback callback,
net::AuthCredentials* credentials) override;
void OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location) override;
@@ -157,6 +161,12 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest>,
v8::Local<v8::Value> response);
void OnHeadersReceivedListenerResult(uint64_t id,
v8::Local<v8::Value> response);
// Callback invoked by LoginHandler when auth credentials are supplied via
// the unified 'login' event. Bridges back into WebRequest's AuthCallback.
void OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds);
class RequestFilter {
public:
@@ -174,7 +184,7 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest>,
bool is_match_pattern = true);
void AddType(extensions::WebRequestResourceType type);
bool MatchesRequest(extensions::WebRequestInfo* info) const;
bool MatchesRequest(const extensions::WebRequestInfo* info) const;
private:
bool MatchesURL(const GURL& url,

View File

@@ -42,6 +42,23 @@ LoginHandler::LoginHandler(
response_headers, first_auth_attempt));
}
LoginHandler::LoginHandler(
const net::AuthChallengeInfo& auth_info,
content::WebContents* web_contents,
base::ProcessId process_id,
const GURL& url,
scoped_refptr<net::HttpResponseHeaders> response_headers,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback)
: LoginHandler(auth_info,
web_contents,
/*is_request_for_primary_main_frame=*/false,
/*is_request_for_navigation=*/false,
process_id,
url,
std::move(response_headers),
/*first_auth_attempt=*/false,
std::move(auth_required_callback)) {}
void LoginHandler::EmitEvent(
net::AuthChallengeInfo auth_info,
content::WebContents* web_contents,

View File

@@ -32,6 +32,13 @@ class LoginHandler : public content::LoginDelegate {
scoped_refptr<net::HttpResponseHeaders> response_headers,
bool first_auth_attempt,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback);
LoginHandler(
const net::AuthChallengeInfo& auth_info,
content::WebContents* web_contents,
base::ProcessId process_id,
const GURL& url,
scoped_refptr<net::HttpResponseHeaders> response_headers,
content::LoginDelegate::LoginAuthRequiredCallback auth_required_callback);
~LoginHandler() override;
// disable copy

View File

@@ -374,19 +374,21 @@ void ProxyingWebSocket::OnHeadersReceivedComplete(int error_code) {
ContinueToCompleted();
}
void ProxyingWebSocket::OnAuthRequiredComplete(AuthRequiredResponse rv) {
void ProxyingWebSocket::OnAuthRequiredComplete(
WebRequestAPI::AuthRequiredResponse rv) {
CHECK(auth_required_callback_);
ResumeIncomingMethodCallProcessing();
switch (rv) {
case AuthRequiredResponse::kNoAction:
case AuthRequiredResponse::kCancelAuth:
case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION:
case WebRequestAPI::AuthRequiredResponse::
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH:
std::move(auth_required_callback_).Run(std::nullopt);
break;
case AuthRequiredResponse::kSetAuth:
case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH:
std::move(auth_required_callback_).Run(auth_credentials_);
break;
case AuthRequiredResponse::kIoPending:
case WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING:
NOTREACHED();
}
}
@@ -403,8 +405,13 @@ void ProxyingWebSocket::OnHeadersReceivedCompleteForAuth(
auto continuation = base::BindRepeating(
&ProxyingWebSocket::OnAuthRequiredComplete, weak_factory_.GetWeakPtr());
auto auth_rv = AuthRequiredResponse::kCancelAuth;
auto auth_rv = web_request_api_->OnAuthRequired(
&info_, auth_info, std::move(continuation), &auth_credentials_);
PauseIncomingMethodCallProcessing();
if (auth_rv ==
WebRequestAPI::AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING) {
return;
}
OnAuthRequiredComplete(auth_rv);
}

View File

@@ -37,21 +37,6 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient,
public:
using WebSocketFactory = content::ContentBrowserClient::WebSocketFactory;
// AuthRequiredResponse indicates how an OnAuthRequired call is handled.
enum class AuthRequiredResponse {
// No credentials were provided.
kNoAction,
// AuthCredentials is filled in with a username and password, which should
// be used in a response to the provided auth challenge.
kSetAuth,
// The request should be canceled.
kCancelAuth,
// The action will be decided asynchronously. |callback| will be invoked
// when the decision is made, and one of the other AuthRequiredResponse
// values will be passed in with the same semantics as described above.
kIoPending,
};
ProxyingWebSocket(
WebRequestAPI* web_request_api,
WebSocketFactory factory,
@@ -121,7 +106,7 @@ class ProxyingWebSocket : public network::mojom::WebSocketHandshakeClient,
void ContinueToStartRequest(int error_code);
void OnHeadersReceivedComplete(int error_code);
void ContinueToHeadersReceived();
void OnAuthRequiredComplete(AuthRequiredResponse rv);
void OnAuthRequiredComplete(WebRequestAPI::AuthRequiredResponse rv);
void OnHeadersReceivedCompleteForAuth(const net::AuthChallengeInfo& auth_info,
int rv);
void ContinueToCompleted();

View File

@@ -27,6 +27,23 @@ class WebRequestAPI {
const std::set<std::string>& set_headers,
int error_code)>;
// AuthRequiredResponse indicates how an OnAuthRequired call is handled.
enum class AuthRequiredResponse {
// No credentials were provided.
AUTH_REQUIRED_RESPONSE_NO_ACTION,
// AuthCredentials is filled in with a username and password, which should
// be used in a response to the provided auth challenge.
AUTH_REQUIRED_RESPONSE_SET_AUTH,
// The request should be canceled.
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH,
// The action will be decided asynchronously. |callback| will be invoked
// when the decision is made, and one of the other AuthRequiredResponse
// values will be passed in with the same semantics as described above.
AUTH_REQUIRED_RESPONSE_IO_PENDING,
};
using AuthCallback = base::OnceCallback<void(AuthRequiredResponse)>;
virtual bool HasListener() const = 0;
virtual int OnBeforeRequest(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
@@ -36,6 +53,11 @@ class WebRequestAPI {
const network::ResourceRequest& request,
BeforeSendHeadersCallback callback,
net::HttpRequestHeaders* headers) = 0;
virtual AuthRequiredResponse OnAuthRequired(
const extensions::WebRequestInfo* info,
const net::AuthChallengeInfo& auth_info,
AuthCallback callback,
net::AuthCredentials* credentials) = 0;
virtual int OnHeadersReceived(
extensions::WebRequestInfo* info,
const network::ResourceRequest& request,

View File

@@ -733,5 +733,80 @@ describe('webRequest module', () => {
expect(reqHeaders['/websocket'].foo).to.equal('bar');
expect(reqHeaders['/'].foo).to.equal('bar');
});
it('authenticates a WebSocket via login event', async () => {
const authServer = http.createServer();
const wssAuth = new WebSocket.Server({ noServer: true });
const expected = 'Basic ' + Buffer.from('user:pass').toString('base64');
wssAuth.on('connection', ws => {
ws.send('Authenticated!');
});
authServer.on('upgrade', (req, socket, head) => {
const auth = req.headers.authorization || '';
if (auth !== expected) {
socket.write(
'HTTP/1.1 401 Unauthorized\r\n' +
'WWW-Authenticate: Basic realm="Test"\r\n' +
'Content-Length: 0\r\n' +
'\r\n'
);
socket.destroy();
return;
}
wssAuth.handleUpgrade(req, socket as Socket, head, ws => {
wssAuth.emit('connection', ws, req);
});
});
const { port } = await listen(authServer);
const ses = session.fromPartition(`WebRequestWSAuth-${Date.now()}`);
const contents = (webContents as typeof ElectronInternal.WebContents).create({
session: ses,
sandbox: true
});
defer(() => {
contents.destroy();
authServer.close();
wssAuth.close();
});
ses.webRequest.onBeforeRequest({ urls: ['ws://*/*'] }, (details, callback) => {
callback({});
});
contents.on('login', (event, details: any, _: any, callback: (u: string, p: string) => void) => {
if (details?.url?.startsWith(`ws://localhost:${port}`)) {
event.preventDefault();
callback('user', 'pass');
}
});
await contents.loadFile(path.join(fixturesPath, 'blank.html'));
const message = await contents.executeJavaScript(`new Promise((resolve, reject) => {
let attempts = 0;
function connect() {
attempts++;
const ws = new WebSocket('ws://localhost:${port}');
ws.onmessage = e => resolve(e.data);
ws.onerror = () => {
if (attempts < 3) {
setTimeout(connect, 50);
} else {
reject(new Error('WebSocket auth failed'));
}
};
}
connect();
setTimeout(() => reject(new Error('timeout')), 5000);
});`);
expect(message).to.equal('Authenticated!');
});
});
});