feat: support WebSocket authentication handling (#48512)

* feat: support WebSocket authentication handling

* test: add a test

* refactor: route through login instead
This commit is contained in:
Shelley Vohr
2025-11-10 21:30:44 +01:00
committed by GitHub
parent a5cebb6df2
commit 4951b96235
7 changed files with 204 additions and 25 deletions

View File

@@ -5,12 +5,14 @@
#include "shell/browser/api/electron_api_web_request.h"
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include "base/containers/fixed_flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/values.h"
#include "content/public/browser/web_contents.h"
@@ -25,6 +27,7 @@
#include "shell/browser/api/electron_api_web_frame_main.h"
#include "shell/browser/electron_browser_context.h"
#include "shell/browser/javascript_environment.h"
#include "shell/browser/login_handler.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/frame_converter.h"
#include "shell/common/gin_converters/gurl_converter.h"
@@ -100,7 +103,7 @@ v8::Local<v8::Value> HttpResponseHeadersToV8(
// Overloaded by multiple types to fill the |details| object.
void ToDictionary(gin_helper::Dictionary* details,
extensions::WebRequestInfo* info) {
const extensions::WebRequestInfo* info) {
details->Set("id", info->id);
details->Set("url", info->url);
details->Set("method", info->method);
@@ -247,7 +250,7 @@ bool WebRequest::RequestFilter::MatchesType(
}
bool WebRequest::RequestFilter::MatchesRequest(
extensions::WebRequestInfo* info) const {
const extensions::WebRequestInfo* info) const {
// Matches URL and type, and does not match exclude URL.
return MatchesURL(info->url, include_url_patterns_) &&
!MatchesURL(info->url, exclude_url_patterns_) &&
@@ -279,6 +282,10 @@ struct WebRequest::BlockedRequest {
net::CompletionOnceCallback callback;
// Only used for onBeforeSendHeaders.
BeforeSendHeadersCallback before_send_headers_callback;
// The callback to invoke for auth. If |auth_callback.is_null()| is false,
// |callback| must be NULL.
// Only valid for OnAuthRequired.
AuthCallback auth_callback;
// Only used for onBeforeSendHeaders.
raw_ptr<net::HttpRequestHeaders> request_headers = nullptr;
// Only used for onHeadersReceived.
@@ -289,6 +296,8 @@ struct WebRequest::BlockedRequest {
std::string status_line;
// Only used for onBeforeRequest.
raw_ptr<GURL> new_url = nullptr;
// Owns the LoginHandler while waiting for auth credentials.
std::unique_ptr<LoginHandler> login_handler;
};
WebRequest::SimpleListenerInfo::SimpleListenerInfo(RequestFilter filter_,
@@ -588,6 +597,36 @@ void WebRequest::OnSendHeaders(extensions::WebRequestInfo* info,
HandleSimpleEvent(SimpleEvent::kOnSendHeaders, info, request, headers);
}
WebRequest::AuthRequiredResponse WebRequest::OnAuthRequired(
const extensions::WebRequestInfo* request_info,
const net::AuthChallengeInfo& auth_info,
WebRequest::AuthCallback callback,
net::AuthCredentials* credentials) {
content::RenderFrameHost* rfh = content::RenderFrameHost::FromID(
request_info->render_process_id, request_info->frame_routing_id);
content::WebContents* web_contents = nullptr;
if (rfh)
web_contents = content::WebContents::FromRenderFrameHost(rfh);
BlockedRequest blocked_request;
blocked_request.auth_callback = std::move(callback);
blocked_requests_[request_info->id] = std::move(blocked_request);
auto login_callback =
base::BindOnce(&WebRequest::OnLoginAuthResult, base::Unretained(this),
request_info->id, credentials);
scoped_refptr<net::HttpResponseHeaders> response_headers =
request_info->response_headers;
blocked_requests_[request_info->id].login_handler =
std::make_unique<LoginHandler>(
auth_info, web_contents,
static_cast<base::ProcessId>(request_info->render_process_id),
request_info->url, response_headers, std::move(login_callback));
return AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_IO_PENDING;
}
void WebRequest::OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location) {
@@ -717,6 +756,26 @@ void WebRequest::HandleSimpleEvent(SimpleEvent event,
info.listener.Run(gin::ConvertToV8(isolate, details));
}
void WebRequest::OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds) {
auto iter = blocked_requests_.find(id);
if (iter == blocked_requests_.end())
NOTREACHED();
AuthRequiredResponse action =
AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_NO_ACTION;
if (maybe_creds.has_value()) {
*credentials = maybe_creds.value();
action = AuthRequiredResponse::AUTH_REQUIRED_RESPONSE_SET_AUTH;
}
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(iter->second.auth_callback), action));
blocked_requests_.erase(iter);
}
// static
gin_helper::Handle<WebRequest> WebRequest::FromOrCreate(
v8::Isolate* isolate,

View File

@@ -45,6 +45,23 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
const std::set<std::string>& set_headers,
int error_code)>;
// AuthRequiredResponse indicates how an OnAuthRequired call is handled.
enum class AuthRequiredResponse {
// No credentials were provided.
AUTH_REQUIRED_RESPONSE_NO_ACTION,
// AuthCredentials is filled in with a username and password, which should
// be used in a response to the provided auth challenge.
AUTH_REQUIRED_RESPONSE_SET_AUTH,
// The request should be canceled.
AUTH_REQUIRED_RESPONSE_CANCEL_AUTH,
// The action will be decided asynchronously. |callback| will be invoked
// when the decision is made, and one of the other AuthRequiredResponse
// values will be passed in with the same semantics as described above.
AUTH_REQUIRED_RESPONSE_IO_PENDING,
};
using AuthCallback = base::OnceCallback<void(AuthRequiredResponse)>;
// Convenience wrapper around api::Session::FromOrCreate()->WebRequest().
// Creates the Session and WebRequest if they don't already exist.
// Note that the WebRequest is owned by the session, not by the caller.
@@ -83,6 +100,10 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
void OnSendHeaders(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const net::HttpRequestHeaders& headers);
AuthRequiredResponse OnAuthRequired(const extensions::WebRequestInfo* info,
const net::AuthChallengeInfo& auth_info,
AuthCallback callback,
net::AuthCredentials* credentials);
void OnBeforeRedirect(extensions::WebRequestInfo* info,
const network::ResourceRequest& request,
const GURL& new_location);
@@ -158,6 +179,12 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
v8::Local<v8::Value> response);
void OnHeadersReceivedListenerResult(uint64_t id,
v8::Local<v8::Value> response);
// Callback invoked by LoginHandler when auth credentials are supplied via
// the unified 'login' event. Bridges back into WebRequest's AuthCallback.
void OnLoginAuthResult(
uint64_t id,
net::AuthCredentials* credentials,
const std::optional<net::AuthCredentials>& maybe_creds);
class RequestFilter {
public:
@@ -175,7 +202,7 @@ class WebRequest final : public gin_helper::DeprecatedWrappable<WebRequest> {
bool is_match_pattern = true);
void AddType(extensions::WebRequestResourceType type);
bool MatchesRequest(extensions::WebRequestInfo* info) const;
bool MatchesRequest(const extensions::WebRequestInfo* info) const;
private:
bool MatchesURL(const GURL& url,